# SPDX-License-Identifier: MPL-2.0

import sqlite3
from random import expovariate
import typing

import aiohttp

from utils import http_session_factory

# Maximum number of characters of past posts sent to TextSynth as the prompt.
_PROMPT_CHAR_LIMIT = 8192


def _fetch_toots(cfg) -> list[str]:
	"""Return the `content` of stored posts, ordered by POST_ID descending.

	When cfg["learn_from_cw"] is truthy, posts whose summary is NULL or not listed
	in cfg["ignored_cws"] are returned (no row limit). Otherwise only posts with a
	NULL summary are returned, capped at 10000 rows.

	The connection is always closed before returning, even if the query fails.
	"""
	db = sqlite3.connect(cfg["db_path"])
	try:
		db.text_factory = str
		c = db.cursor()
		if cfg['learn_from_cw']:
			ignored_cws = cfg["ignored_cws"]
			# one "?" placeholder per ignored CW, e.g. "(?,?,?)"
			ignored_cws_query_params = "(" + ",".join("?" * len(ignored_cws)) + ")"
			rows = c.execute(
				f"""
					SELECT content
					FROM posts
					WHERE
						summary IS NULL
						OR summary NOT IN {ignored_cws_query_params}
					ORDER BY POST_ID DESC
				""",
				ignored_cws,
			).fetchall()
		else:
			rows = c.execute(
				"""
					SELECT content
					FROM posts
					WHERE summary IS NULL
					ORDER BY POST_ID DESC
					LIMIT 10000
				""",
			).fetchall()
	finally:
		db.close()
	# each row is a 1-tuple; keep only the content string
	return [row[0] for row in rows]


def _sample_prompt(toots: list[str]) -> str:
	"""Randomly pick toots until about _PROMPT_CHAR_LIMIT characters are collected.

	Indices are drawn from an exponential distribution (mean 10), wrapped modulo the
	remaining count, so early entries in `toots` are picked more often than later ones.
	Each toot is picked at most once. Sampling stops early if the pool runs out, which
	previously crashed with ZeroDivisionError.

	Returns the picked toots joined by "\\n", truncated to the last
	_PROMPT_CHAR_LIMIT characters.
	"""
	remaining = list(toots)
	picked: list[str] = []
	total_chars = 0
	while remaining and total_chars < _PROMPT_CHAR_LIMIT:
		index = int(expovariate(1 / 10) % len(remaining))
		toot = remaining.pop(index)
		picked.append(toot)
		total_chars += len(toot) + 1  # +1 for the joining "\n"
	return "\n".join(picked)[-_PROMPT_CHAR_LIMIT:]


def _is_only_mentions(post: str) -> bool:
	"""Return True if every whitespace-separated word of `post` starts with "@".

	An empty or whitespace-only post also counts (all() of nothing is True).
	"""
	return all(word.startswith("@") for word in post.split())


async def make_sentence(cfg):
	"""Generate one new post from stored posts using the TextSynth completions API.

	Reads cfg keys: "textsynth_engine_id" (set to "gptneox_20B" in `cfg` if falsy),
	"textsynth_token", "db_path", "learn_from_cw" and "ignored_cws" (the last only
	when learn_from_cw is truthy).

	Returns the generated text, with literal "\\n" sequences turned back into newlines.

	Raises:
		ValueError: if textsynth_token is not set, the database yields no posts, or the
			API responds with a non-200 status.
	"""
	# set default (note: this mutates the caller's cfg)
	if not cfg["textsynth_engine_id"]:
		cfg["textsynth_engine_id"] = "gptneox_20B"

	if not cfg["textsynth_token"]:
		raise ValueError("textsynth_token is not set, create an account at textsynth.com")

	toots = _fetch_toots(cfg)
	if not toots:
		raise ValueError("Database is empty! Try running main.py.")

	# Escape newlines so each post occupies one prompt line; "\n" then separates posts
	# and doubles as the completion stop sequence.
	escape_newlines = str.maketrans({ord('\n'): "\\n"})
	prompt = _sample_prompt([toot.translate(escape_newlines) for toot in toots])

	# POST https://api.textsynth.com/v1/engines/{engine_id}/completions
	post = None
	async with http_session_factory() as session:
		# NOTE(review): retries without bound while the API keeps returning empty or
		# mentions-only text; consider capping attempts.
		while post is None:
			async with session.post(
				"https://api.textsynth.com/v1/engines/{}/completions".format(cfg["textsynth_engine_id"]),
				headers={
					"Authorization": "Bearer {}".format(cfg["textsynth_token"]),
				},
				json={
					"prompt": prompt + "\n",
					"stop": "\n",
					"max_tokens": 200
				},
			) as resp:
				if resp.status != 200:
					raise ValueError("TextSynth API returned status code {}".format(resp.status))
				data = await resp.json()

			text = data["text"]
			# regenerate on empty output or output consisting only of @mentions
			if not text or _is_only_mentions(text):
				continue
			post = text

	# undo the newline escaping applied to the prompt
	return post.replace("\\n", "\n")