add textsynth #2
|
@ -13,5 +13,7 @@
|
||||||
"overlap_ratio": 0.7,
|
"overlap_ratio": 0.7,
|
||||||
"generation_mode": "markov",
|
"generation_mode": "markov",
|
||||||
"access_token": "",
|
"access_token": "",
|
||||||
"db_path": ""
|
"db_path": "",
|
||||||
}
|
"textsynth_token": null,
|
||||||
|
"textsynth_engine_id": "gptneox_20B"
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,121 @@
|
||||||
|
# SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
from random import expovariate
|
||||||
|
import typing
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
async def make_sentence(cfg):
|
||||||
|
# set default
|
||||||
|
if not cfg["textsynth_engine_id"]:
|
||||||
|
cfg["textsynth_engine_id"] = "gptneox_20B"
|
||||||
|
|
||||||
|
if not cfg["textsynth_token"]:
|
||||||
|
raise ValueError("textsynth_token is not set, create an account at textsynth.com")
|
||||||
|
|
||||||
|
db = sqlite3.connect(cfg["db_path"])
|
||||||
|
db.text_factory = str
|
||||||
|
c = db.cursor()
|
||||||
|
if cfg['learn_from_cw']:
|
||||||
|
ignored_cws_query_params = "(" + ",".join("?" * len(cfg["ignored_cws"])) + ")"
|
||||||
|
toots = c.execute(
|
||||||
|
f"""
|
||||||
|
SELECT content
|
||||||
|
FROM posts
|
||||||
|
WHERE
|
||||||
|
summary IS NULL
|
||||||
|
OR summary NOT IN {ignored_cws_query_params}
|
||||||
|
ORDER BY POST_ID DESC
|
||||||
|
""",
|
||||||
|
cfg["ignored_cws"],
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
toots = c.execute(
|
||||||
|
"""
|
||||||
|
SELECT content
|
||||||
|
FROM posts
|
||||||
|
WHERE summary IS NULL
|
||||||
|
ORDER POST_ID DESC
|
||||||
|
LIMIT 10000
|
||||||
|
""",
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
if not toots:
|
||||||
|
raise ValueError("Database is empty! Try running main.py.")
|
||||||
|
|
||||||
|
|
||||||
|
# get first element of tuple
|
||||||
|
toots: list[str] = [toot[0] for toot in toots]
|
||||||
|
|
||||||
|
# replace stuff
|
||||||
|
toots = [toot.translate(str.maketrans({
|
||||||
|
ord('\n'): "\\n",
|
||||||
|
})) for toot in toots]
|
||||||
|
|
||||||
|
new_toots = []
|
||||||
|
|
||||||
|
def sum_of_chars(list_of_strings: list[str]) -> int:
|
||||||
|
return sum(len(string) + 1 for string in list_of_strings) # +1 for \n
|
||||||
|
|
||||||
|
while sum_of_chars(new_toots) < 8192:
|
||||||
|
index = expovariate(1 / 10) % len(toots) # more likely to pick newer posts but older ones are also sometimes picked
|
||||||
|
# round index to nearest int
|
||||||
|
index = int(index)
|
||||||
|
# remove toot from list
|
||||||
|
toot = toots.pop(index)
|
||||||
|
# add toot to new list
|
||||||
|
new_toots.append(toot)
|
||||||
|
|
||||||
|
toots = new_toots
|
||||||
|
|
||||||
|
# concatenate toots
|
||||||
|
toots = "\n".join(toots)
|
||||||
|
|
||||||
|
# truncate to last 8192 characters
|
||||||
|
toots = toots[-8192:]
|
||||||
|
|
||||||
|
# raise ValueError("toots: " + toots)
|
||||||
|
|
||||||
|
# POST https://api.textsynth.com/v1/engines/{engine_id}/completions
|
||||||
|
|
||||||
|
# using aiohttp
|
||||||
|
|
||||||
|
post = None
|
||||||
|
while post is None:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
"https://api.textsynth.com/v1/engines/{}/completions".format(cfg["textsynth_engine_id"]),
|
||||||
|
headers={
|
||||||
|
"Authorization": "Bearer {}".format(cfg["textsynth_token"]),
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"prompt": toots + "\n",
|
||||||
|
"stop": "\n",
|
||||||
|
"max_tokens": 200
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise ValueError("TextSynth API returned status code {}".format(resp.status))
|
||||||
|
data = await resp.json()
|
||||||
|
if not data["text"]:
|
||||||
|
# raise ValueError("TextSynth API returned empty text")
|
||||||
|
# just generate a new sentence
|
||||||
|
continue
|
||||||
|
post: str = data["text"]
|
||||||
|
|
||||||
|
# check wether the post only consists of mentions
|
||||||
|
# split by words
|
||||||
|
words = post.split()
|
||||||
|
# check if all words are mentions
|
||||||
|
if all(word.startswith("@") for word in words):
|
||||||
|
# generate a new sentence
|
||||||
|
post = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
# replace stuff in post
|
||||||
|
post = post.replace("\\n", "\n")
|
||||||
|
|
||||||
|
# and post it
|
||||||
|
return post
|
|
@ -7,6 +7,7 @@ import html
|
||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
|
import inspect
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import argparse
|
import argparse
|
||||||
import itertools
|
import itertools
|
||||||
|
@ -18,6 +19,7 @@ from bs4 import BeautifulSoup
|
||||||
TextGenerationMode = enum.Enum('TextGenerationMode', """
|
TextGenerationMode = enum.Enum('TextGenerationMode', """
|
||||||
markov
|
markov
|
||||||
gpt_2
|
gpt_2
|
||||||
|
textsynth
|
||||||
""".split())
|
""".split())
|
||||||
|
|
||||||
def arg_parser_factory(*, description):
|
def arg_parser_factory(*, description):
|
||||||
|
@ -67,8 +69,15 @@ async def make_post(cfg, *, mode=TextGenerationMode.markov):
|
||||||
from generators.markov import make_sentence
|
from generators.markov import make_sentence
|
||||||
elif mode is TextGenerationMode.gpt_2:
|
elif mode is TextGenerationMode.gpt_2:
|
||||||
from generators.gpt_2 import make_sentence
|
from generators.gpt_2 import make_sentence
|
||||||
|
elif mode is TextGenerationMode.textsynth:
|
||||||
|
from generators.textsynth import make_sentence
|
||||||
|
|
||||||
return await anyio.to_process.run_sync(make_sentence, cfg)
|
# return await anyio.to_process.run_sync(make_sentence, cfg)
|
||||||
|
# check if inspect.iscoroutinefunction(object)
|
||||||
|
if inspect.iscoroutinefunction(make_sentence):
|
||||||
|
return await make_sentence(cfg)
|
||||||
|
else:
|
||||||
|
return await anyio.to_process.run_sync(make_sentence, cfg)
|
||||||
|
|
||||||
def extract_post_content(text):
|
def extract_post_content(text):
|
||||||
soup = BeautifulSoup(text, "html.parser")
|
soup = BeautifulSoup(text, "html.parser")
|
||||||
|
|
Loading…
Reference in New Issue