pleroma-ebooks/generators/textsynth.py

124 lines
2.9 KiB
Python

# SPDX-License-Identifier: MPL-2.0
import sqlite3
from random import expovariate
import typing
import aiohttp
from utils import http_session_factory
# Maximum number of characters of existing posts sent to the API as the prompt.
_MAX_PROMPT_CHARS = 8192


def _load_toots(cfg) -> list[str]:
    """Return the ``content`` column of the candidate posts, highest POST_ID first.

    Reads ``cfg["db_path"]``, ``cfg["learn_from_cw"]`` and (when learning from
    content-warned posts) ``cfg["ignored_cws"]``.

    Raises:
        ValueError: if the query returns no rows.
    """
    db = sqlite3.connect(cfg["db_path"])
    try:
        db.text_factory = str
        cursor = db.cursor()
        if cfg["learn_from_cw"]:
            # One "?" placeholder per ignored content warning, so the values
            # are bound as parameters instead of being spliced into the SQL.
            placeholders = "(" + ",".join("?" * len(cfg["ignored_cws"])) + ")"
            rows = cursor.execute(
                f"""
                SELECT content
                FROM posts
                WHERE
                    summary IS NULL
                    OR summary NOT IN {placeholders}
                ORDER BY POST_ID DESC
                """,
                cfg["ignored_cws"],
            ).fetchall()
        else:
            rows = cursor.execute(
                """
                SELECT content
                FROM posts
                WHERE summary IS NULL
                ORDER BY POST_ID DESC
                LIMIT 10000
                """,
            ).fetchall()
    finally:
        # Close the connection even when the query fails, and before any
        # network traffic starts.
        db.close()

    if not rows:
        raise ValueError("Database is empty! Try running main.py.")
    # Each row is a 1-tuple; keep only the content string.
    return [row[0] for row in rows]


def _build_prompt(toots: list[str]) -> str:
    """Sample posts at random and join them into a prompt, one post per line.

    Real newlines inside a post are escaped to a literal backslash-n so that a
    newline in the prompt always separates two posts. Sampling is without
    replacement and stops once ``_MAX_PROMPT_CHARS`` characters are collected
    or the pool is exhausted; the result is truncated to its last
    ``_MAX_PROMPT_CHARS`` characters.
    """
    pool = [toot.replace("\n", "\\n") for toot in toots]
    picked: list[str] = []
    collected = 0  # running character count, +1 per post for the joining "\n"
    # "pool and ..." guards against a database holding less text than the
    # budget: without it the modulo below would divide by zero once every
    # post has been consumed.
    while pool and collected < _MAX_PROMPT_CHARS:
        # Exponentially distributed index (mean 10), wrapped into range and
        # truncated: entries near the front of the list (the highest POST_IDs,
        # presumably the newest posts) are picked more often, but later ones
        # still have a chance.
        index = int(expovariate(1 / 10) % len(pool))
        toot = pool.pop(index)
        picked.append(toot)
        collected += len(toot) + 1
    return "\n".join(picked)[-_MAX_PROMPT_CHARS:]


async def make_sentence(cfg):
    """Generate a new post by asking the TextSynth completion API to continue
    a prompt built from posts stored in the local database.

    Args:
        cfg: configuration mapping. Reads ``textsynth_engine_id`` (set to
            ``"gptneox_20B"`` in place when falsy), ``textsynth_token``,
            ``db_path``, ``learn_from_cw`` and ``ignored_cws``.

    Returns:
        The generated text, with escaped newlines turned back into real ones.

    Raises:
        ValueError: if no token is configured, the database has no usable
            posts, or the API answers with a non-200 status code.
    """
    # set default
    if not cfg["textsynth_engine_id"]:
        cfg["textsynth_engine_id"] = "gptneox_20B"
    if not cfg["textsynth_token"]:
        raise ValueError("textsynth_token is not set, create an account at textsynth.com")

    prompt = _build_prompt(_load_toots(cfg))

    # POST https://api.textsynth.com/v1/engines/{engine_id}/completions
    url = "https://api.textsynth.com/v1/engines/{}/completions".format(cfg["textsynth_engine_id"])
    headers = {
        "Authorization": "Bearer {}".format(cfg["textsynth_token"]),
    }
    payload = {
        # The trailing newline starts a fresh line, and "stop" ends the
        # completion at the next newline, so exactly one post is generated.
        "prompt": prompt + "\n",
        "stop": "\n",
        "max_tokens": 200,
    }

    async with http_session_factory() as session:
        # NOTE(review): retries are unbounded, as in the original; a
        # persistently empty or mention-only completion loops forever.
        while True:
            async with session.post(url, headers=headers, json=payload) as resp:
                if resp.status != 200:
                    raise ValueError("TextSynth API returned status code {}".format(resp.status))
                data = await resp.json()

            post: str = data["text"]
            if not post:
                # Empty completion: just generate a new sentence.
                continue
            # Reject completions made up solely of mentions (this also covers
            # whitespace-only text, where there are no words at all).
            if all(word.startswith("@") for word in post.split()):
                continue
            # Undo the newline escaping applied when the prompt was built.
            return post.replace("\\n", "\n")