Compare commits

...

2 Commits

3 changed files with 135 additions and 3 deletions

View File

@ -13,5 +13,7 @@
"overlap_ratio": 0.7,
"generation_mode": "markov",
"access_token": "",
"db_path": ""
"db_path": "",
"textsynth_token": null,
"textsynth_engine_id": "gptneox_20B"
}

121
generators/textsynth.py Normal file
View File

@ -0,0 +1,121 @@
# SPDX-License-Identifier: MPL-2.0
import sqlite3
from random import expovariate
import typing
import aiohttp
async def make_sentence(cfg):
# set default
if not cfg["textsynth_engine_id"]:
cfg["textsynth_engine_id"] = "gptneox_20B"
if not cfg["textsynth_token"]:
raise ValueError("textsynth_token is not set, create an account at textsynth.com")
db = sqlite3.connect(cfg["db_path"])
db.text_factory = str
c = db.cursor()
if cfg['learn_from_cw']:
ignored_cws_query_params = "(" + ",".join("?" * len(cfg["ignored_cws"])) + ")"
toots = c.execute(
f"""
SELECT content
FROM posts
WHERE
summary IS NULL
OR summary NOT IN {ignored_cws_query_params}
ORDER BY POST_ID DESC
""",
cfg["ignored_cws"],
).fetchall()
else:
toots = c.execute(
"""
SELECT content
FROM posts
WHERE summary IS NULL
ORDER POST_ID DESC
LIMIT 10000
""",
).fetchall()
if not toots:
raise ValueError("Database is empty! Try running main.py.")
# get first element of tuple
toots: list[str] = [toot[0] for toot in toots]
# replace stuff
toots = [toot.translate(str.maketrans({
ord('\n'): "\\n",
})) for toot in toots]
new_toots = []
def sum_of_chars(list_of_strings: list[str]) -> int:
return sum(len(string) + 1 for string in list_of_strings) # +1 for \n
while sum_of_chars(new_toots) < 8192:
index = expovariate(1 / 10) % len(toots) # more likely to pick newer posts but older ones are also sometimes picked
# round index to nearest int
index = int(index)
# remove toot from list
toot = toots.pop(index)
# add toot to new list
new_toots.append(toot)
toots = new_toots
# concatenate toots
toots = "\n".join(toots)
# truncate to last 8192 characters
toots = toots[-8192:]
# raise ValueError("toots: " + toots)
# POST https://api.textsynth.com/v1/engines/{engine_id}/completions
# using aiohttp
post = None
while post is None:
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.textsynth.com/v1/engines/{}/completions".format(cfg["textsynth_engine_id"]),
headers={
"Authorization": "Bearer {}".format(cfg["textsynth_token"]),
},
json={
"prompt": toots + "\n",
"stop": "\n",
"max_tokens": 200
},
) as resp:
if resp.status != 200:
raise ValueError("TextSynth API returned status code {}".format(resp.status))
data = await resp.json()
if not data["text"]:
# raise ValueError("TextSynth API returned empty text")
# just generate a new sentence
continue
post: str = data["text"]
# check wether the post only consists of mentions
# split by words
words = post.split()
# check if all words are mentions
if all(word.startswith("@") for word in words):
# generate a new sentence
post = None
continue
db.close()
# replace stuff in post
post = post.replace("\\n", "\n")
# and post it
return post

View File

@ -7,6 +7,7 @@ import html
import enum
import json
import shutil
import inspect
import sqlite3
import argparse
import itertools
@ -18,6 +19,7 @@ from bs4 import BeautifulSoup
TextGenerationMode = enum.Enum('TextGenerationMode', """
markov
gpt_2
textsynth
""".split())
def arg_parser_factory(*, description):
@ -67,7 +69,14 @@ async def make_post(cfg, *, mode=TextGenerationMode.markov):
from generators.markov import make_sentence
elif mode is TextGenerationMode.gpt_2:
from generators.gpt_2 import make_sentence
elif mode is TextGenerationMode.textsynth:
from generators.textsynth import make_sentence
# return await anyio.to_process.run_sync(make_sentence, cfg)
# check if inspect.iscoroutinefunction(object)
if inspect.iscoroutinefunction(make_sentence):
return await make_sentence(cfg)
else:
return await anyio.to_process.run_sync(make_sentence, cfg)
def extract_post_content(text):