From 879eb32b2d0e902b9a760a0ff03965d9b5b59821 Mon Sep 17 00:00:00 2001
From: PeachyDelight
Date: Sun, 19 Jun 2022 02:59:40 +0200
Subject: [PATCH] add textsynth

---
 config.defaults.json    |   6 +-
 generators/textsynth.py | 121 ++++++++++++++++++++++++++++++++++++++++
 third_party/utils.py    |  11 +++-
 3 files changed, 135 insertions(+), 3 deletions(-)
 create mode 100644 generators/textsynth.py

diff --git a/config.defaults.json b/config.defaults.json
index 91fafad..e45c552 100644
--- a/config.defaults.json
+++ b/config.defaults.json
@@ -13,5 +13,7 @@
 	"overlap_ratio": 0.7,
 	"generation_mode": "markov",
 	"access_token": "",
-	"db_path": ""
-}
\ No newline at end of file
+	"db_path": "",
+	"textsynth_token": null,
+	"textsynth_engine_id": "gptneox_20B"
+}
diff --git a/generators/textsynth.py b/generators/textsynth.py
new file mode 100644
index 0000000..1db8114
--- /dev/null
+++ b/generators/textsynth.py
@@ -0,0 +1,121 @@
+# SPDX-License-Identifier: MPL-2.0
+
+import sqlite3
+from random import expovariate
+import typing
+import aiohttp
+
+async def make_sentence(cfg):
+    # set default
+    if not cfg["textsynth_engine_id"]:
+        cfg["textsynth_engine_id"] = "gptneox_20B"
+
+    if not cfg["textsynth_token"]:
+        raise ValueError("textsynth_token is not set, create an account at textsynth.com")
+
+    db = sqlite3.connect(cfg["db_path"])
+    db.text_factory = str
+    c = db.cursor()
+    if cfg['learn_from_cw']:
+        ignored_cws_query_params = "(" + ",".join("?" * len(cfg["ignored_cws"])) + ")"
+        toots = c.execute(
+            f"""
+            SELECT content
+            FROM posts
+            WHERE
+                summary IS NULL
+                OR summary NOT IN {ignored_cws_query_params}
+            ORDER BY POST_ID DESC
+            """,
+            cfg["ignored_cws"],
+        ).fetchall()
+    else:
+        toots = c.execute(
+            """
+            SELECT content
+            FROM posts
+            WHERE summary IS NULL
+            ORDER BY POST_ID DESC
+            LIMIT 10000
+            """,
+        ).fetchall()
+
+    if not toots:
+        raise ValueError("Database is empty! Try running main.py.")
+
+
+    # get first element of tuple
+    toots: list[str] = [toot[0] for toot in toots]
+
+    # replace stuff
+    toots = [toot.translate(str.maketrans({
+        ord('\n'): "\\n",
+    })) for toot in toots]
+
+    new_toots = []
+
+    def sum_of_chars(list_of_strings: list[str]) -> int:
+        return sum(len(string) + 1 for string in list_of_strings) # +1 for \n
+
+    while toots and sum_of_chars(new_toots) < 8192: # stop once the prompt is long enough or we run out of posts
+        index = expovariate(1 / 10) % len(toots) # more likely to pick newer posts but older ones are also sometimes picked
+        # round index to nearest int
+        index = int(index)
+        # remove toot from list
+        toot = toots.pop(index)
+        # add toot to new list
+        new_toots.append(toot)
+
+    toots = new_toots
+
+    # concatenate toots
+    toots = "\n".join(toots)
+
+    # truncate to last 8192 characters
+    toots = toots[-8192:]
+
+    # raise ValueError("toots: " + toots)
+
+    # POST https://api.textsynth.com/v1/engines/{engine_id}/completions
+
+    # using aiohttp
+
+    post = None
+    while post is None:
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                "https://api.textsynth.com/v1/engines/{}/completions".format(cfg["textsynth_engine_id"]),
+                headers={
+                    "Authorization": "Bearer {}".format(cfg["textsynth_token"]),
+                },
+                json={
+                    "prompt": toots + "\n",
+                    "stop": "\n",
+                    "max_tokens": 200
+                },
+            ) as resp:
+                if resp.status != 200:
+                    raise ValueError("TextSynth API returned status code {}".format(resp.status))
+                data = await resp.json()
+                if not data["text"]:
+                    # raise ValueError("TextSynth API returned empty text")
+                    # just generate a new sentence
+                    continue
+                post: str = data["text"]
+
+        # check whether the post only consists of mentions
+        # split by words
+        words = post.split()
+        # check if all words are mentions
+        if all(word.startswith("@") for word in words):
+            # generate a new sentence
+            post = None
+            continue
+
+    db.close()
+
+    # replace stuff in post
+    post = post.replace("\\n", "\n")
+
+    # and post it
+    return post
diff --git a/third_party/utils.py b/third_party/utils.py
index b6de7ae..2a3014e 100644
--- a/third_party/utils.py
+++ b/third_party/utils.py
@@ -7,6 +7,7 @@ import html
 import enum
 import json
 import shutil
+import inspect
 import sqlite3
 import argparse
 import itertools
@@ -18,6 +19,7 @@ from bs4 import BeautifulSoup
 TextGenerationMode = enum.Enum('TextGenerationMode', """
 	markov
 	gpt_2
+	textsynth
 """.split())
 
 def arg_parser_factory(*, description):
@@ -67,8 +69,15 @@ async def make_post(cfg, *, mode=TextGenerationMode.markov):
 		from generators.markov import make_sentence
 	elif mode is TextGenerationMode.gpt_2:
 		from generators.gpt_2 import make_sentence
+	elif mode is TextGenerationMode.textsynth:
+		from generators.textsynth import make_sentence
 
-	return await anyio.to_process.run_sync(make_sentence, cfg)
+	# return await anyio.to_process.run_sync(make_sentence, cfg)
+	# coroutine generators (like textsynth) are awaited directly; synchronous ones still run in a worker process
+	if inspect.iscoroutinefunction(make_sentence):
+		return await make_sentence(cfg)
+	else:
+		return await anyio.to_process.run_sync(make_sentence, cfg)
 
 def extract_post_content(text):
 	soup = BeautifulSoup(text, "html.parser")
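
Usage note (not part of the patch): a minimal sketch of how the new generator could be exercised on its own, assuming main.py has already populated the SQLite post database and a valid TextSynth API token is available. The config keys mirror config.defaults.json; the database path and token value below are placeholders, not values taken from the repository.

    # standalone driver for generators/textsynth.py -- illustrative only
    import asyncio
    from generators.textsynth import make_sentence

    cfg = {
        "db_path": "posts.db",                 # path to the SQLite database written by main.py (placeholder)
        "learn_from_cw": False,                # skip content-warned posts entirely
        "ignored_cws": [],
        "textsynth_token": "YOUR_API_TOKEN",   # placeholder; sign up at textsynth.com
        "textsynth_engine_id": "gptneox_20B",  # default engine id from config.defaults.json
    }

    # make_sentence is a coroutine, so it is awaited directly instead of being run in a subprocess
    print(asyncio.run(make_sentence(cfg)))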