122 lines
2.8 KiB
Python
122 lines
2.8 KiB
Python
# SPDX-License-Identifier: MPL-2.0
|
|
|
|
import sqlite3
|
|
from random import expovariate
|
|
import typing
|
|
import aiohttp
|
|
|
|
async def make_sentence(cfg):
    """Generate one new post by prompting the TextSynth completions API.

    Stored posts are read from the SQLite database at ``cfg["db_path"]``,
    a random selection of them (biased towards the first rows of the
    ``ORDER BY POST_ID DESC`` result) is joined into a prompt of at most
    8192 characters, and the API is asked to continue it by one line.

    Keys read from ``cfg``:
        textsynth_engine_id: engine name; when falsy it is set (in ``cfg``
            itself) to ``"gptneox_20B"``.
        textsynth_token: API token sent as a Bearer token.
        db_path: path of the SQLite database containing a ``posts`` table.
        learn_from_cw: when truthy, posts with a summary are used too,
            unless their summary is listed in ``ignored_cws``.
        ignored_cws: summaries to exclude (only read if ``learn_from_cw``).

    Returns:
        The generated text, with literal ``\\n`` sequences turned back into
        real newlines.

    Raises:
        ValueError: if the token is not set, the query returns no posts, or
            the API answers with a status code other than 200.
    """
    # set default
    if not cfg["textsynth_engine_id"]:
        cfg["textsynth_engine_id"] = "gptneox_20B"

    if not cfg["textsynth_token"]:
        raise ValueError("textsynth_token is not set, create an account at textsynth.com")

    db = sqlite3.connect(cfg["db_path"])
    try:
        db.text_factory = str
        c = db.cursor()
        if cfg["learn_from_cw"]:
            # one "?" placeholder per ignored CW; the values are bound by sqlite
            ignored_cws_query_params = "(" + ",".join("?" * len(cfg["ignored_cws"])) + ")"
            rows = c.execute(
                f"""
                SELECT content
                FROM posts
                WHERE
                    summary IS NULL
                    OR summary NOT IN {ignored_cws_query_params}
                ORDER BY POST_ID DESC
                """,
                cfg["ignored_cws"],
            ).fetchall()
        else:
            rows = c.execute(
                """
                SELECT content
                FROM posts
                WHERE summary IS NULL
                ORDER BY POST_ID DESC
                LIMIT 10000
                """,
            ).fetchall()
    finally:
        # the rows are fully fetched, so the connection is no longer needed;
        # closing here also covers the error paths
        db.close()

    if not rows:
        raise ValueError("Database is empty! Try running main.py.")

    # take the first column of each row and escape newlines, so that every
    # post occupies exactly one line of the prompt
    toots: list[str] = [row[0].replace("\n", "\\n") for row in rows]

    prompt_budget = 8192
    picked: list[str] = []
    picked_chars = 0
    # stop when the pool is exhausted: `% len(toots)` would divide by zero
    while toots and picked_chars < prompt_budget:
        # small indexes are the most likely, so rows at the front of the
        # query result (highest POST_ID) are picked more often than later
        # ones; int() truncates towards zero
        index = int(expovariate(1 / 10) % len(toots))
        # move the toot from the pool to the selection
        toot = toots.pop(index)
        picked.append(toot)
        picked_chars += len(toot) + 1  # +1 for the joining \n

    # concatenate the toots and keep only the last 8192 characters
    prompt = "\n".join(picked)[-prompt_budget:]

    # POST https://api.textsynth.com/v1/engines/{engine_id}/completions
    url = "https://api.textsynth.com/v1/engines/{}/completions".format(cfg["textsynth_engine_id"])
    headers = {
        "Authorization": "Bearer {}".format(cfg["textsynth_token"]),
    }
    payload = {
        "prompt": prompt + "\n",
        "stop": "\n",
        "max_tokens": 200,
    }

    post = None
    # one session is reused for all attempts
    async with aiohttp.ClientSession() as session:
        # NOTE: retries are unbounded; the loop only ends on a usable
        # completion or on a non-200 response
        while post is None:
            async with session.post(url, headers=headers, json=payload) as resp:
                if resp.status != 200:
                    raise ValueError("TextSynth API returned status code {}".format(resp.status))
                data = await resp.json()

            text = data["text"]
            if not text:
                # empty completion: just generate a new sentence
                continue
            # check whether the post only consists of mentions
            if all(word.startswith("@") for word in text.split()):
                # generate a new sentence
                continue
            post = text

    # undo the newline escaping and hand the post back
    return post.replace("\\n", "\n")
|