add basic migration support

io 2 years ago
parent 191214dbd6
commit b906abe2b1

@ -9,10 +9,10 @@ import pendulum
import operator
import aiosqlite
import contextlib
from utils import shield
from pleroma import Pleroma
from bs4 import BeautifulSoup
from functools import partial
from utils import shield, suppress
from typing import Iterable, NewType
from third_party.utils import extract_post_content
@ -26,6 +26,8 @@ UTC = pendulum.timezone('UTC')
JSON_CONTENT_TYPE = 'application/json'
ACTIVITYPUB_CONTENT_TYPE = 'application/activity+json'
class PostFetcher:
def __init__(self, *, config):
self.config = config
@ -47,10 +49,26 @@ class PostFetcher:
self._db = await stack.enter_async_context(aiosqlite.connect(self.config['db_path']))
await self._maybe_run_migrations()
self._db.row_factory = aiosqlite.Row
self._ctx_stack = stack
return self
async def _maybe_run_migrations(self):
async with self._db.cursor() as cur, suppress(aiosqlite.OperationalError):
if await (await cur.execute('SELECT migration_version FROM migrations')).fetchone(): return
await self._run_migrations()
async def _run_migrations(self):
# TODO proper migrations, not just "has the schema ever been run" migrations
async with await (anyio.Path(__file__).parent/'schema.sql').open() as f:
schema = await
async with self._db.cursor() as cur:
await cur.executescript(schema)
await cur.execute('INSERT INTO migrations (migration_version) VALUES (?)', (MIGRATION_VERSION,))
async def __aexit__(self, *excinfo):
return await self._ctx_stack.__aexit__(*excinfo)

@ -6,3 +6,7 @@ CREATE TABLE posts (
-- UTC Unix timestamp in seconds
published_at REAL NOT NULL
CREATE TABLE migrations (
migration_version INTEGER NOT NULL

@ -1,7 +1,25 @@
# SPDX-License-Identifier: AGPL-3.0-only
import anyio
import contextlib
from functools import wraps
from datetime import datetime, timezone
def as_corofunc(f):
async def wrapped(*args, **kwargs):
# can't decide if i want an `anyio.sleep(0)` here.
return f(*args, **kwargs)
return wrapped
def as_async_cm(cls):
@wraps(cls, updated=()) # cls.__dict__ doesn't support .update()
class wrapped(cls, contextlib.AbstractAsyncContextManager):
__aenter__ = as_corofunc(cls.__enter__)
__aexit__ = as_corofunc(cls.__exit__)
return wrapped
suppress = as_async_cm(contextlib.suppress)
def shield(f):