forked from KayFaraday/pleroma-ebooks
add textsynth
This commit is contained in:
parent
b73526a895
commit
879eb32b2d
|
@ -13,5 +13,7 @@
|
|||
"overlap_ratio": 0.7,
|
||||
"generation_mode": "markov",
|
||||
"access_token": "",
|
||||
"db_path": ""
|
||||
}
|
||||
"db_path": "",
|
||||
"textsynth_token": null,
|
||||
"textsynth_engine_id": "gptneox_20B"
|
||||
}
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
# SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
import sqlite3
|
||||
from random import expovariate
|
||||
import typing
|
||||
import aiohttp
|
||||
|
||||
async def make_sentence(cfg):
|
||||
# set default
|
||||
if not cfg["textsynth_engine_id"]:
|
||||
cfg["textsynth_engine_id"] = "gptneox_20B"
|
||||
|
||||
if not cfg["textsynth_token"]:
|
||||
raise ValueError("textsynth_token is not set, create an account at textsynth.com")
|
||||
|
||||
db = sqlite3.connect(cfg["db_path"])
|
||||
db.text_factory = str
|
||||
c = db.cursor()
|
||||
if cfg['learn_from_cw']:
|
||||
ignored_cws_query_params = "(" + ",".join("?" * len(cfg["ignored_cws"])) + ")"
|
||||
toots = c.execute(
|
||||
f"""
|
||||
SELECT content
|
||||
FROM posts
|
||||
WHERE
|
||||
summary IS NULL
|
||||
OR summary NOT IN {ignored_cws_query_params}
|
||||
ORDER BY POST_ID DESC
|
||||
""",
|
||||
cfg["ignored_cws"],
|
||||
).fetchall()
|
||||
else:
|
||||
toots = c.execute(
|
||||
"""
|
||||
SELECT content
|
||||
FROM posts
|
||||
WHERE summary IS NULL
|
||||
ORDER POST_ID DESC
|
||||
LIMIT 10000
|
||||
""",
|
||||
).fetchall()
|
||||
|
||||
if not toots:
|
||||
raise ValueError("Database is empty! Try running main.py.")
|
||||
|
||||
|
||||
# get first element of tuple
|
||||
toots: list[str] = [toot[0] for toot in toots]
|
||||
|
||||
# replace stuff
|
||||
toots = [toot.translate(str.maketrans({
|
||||
ord('\n'): "\\n",
|
||||
})) for toot in toots]
|
||||
|
||||
new_toots = []
|
||||
|
||||
def sum_of_chars(list_of_strings: list[str]) -> int:
|
||||
return sum(len(string) + 1 for string in list_of_strings) # +1 for \n
|
||||
|
||||
while sum_of_chars(new_toots) < 8192:
|
||||
index = expovariate(1 / 10) % len(toots) # more likely to pick newer posts but older ones are also sometimes picked
|
||||
# round index to nearest int
|
||||
index = int(index)
|
||||
# remove toot from list
|
||||
toot = toots.pop(index)
|
||||
# add toot to new list
|
||||
new_toots.append(toot)
|
||||
|
||||
toots = new_toots
|
||||
|
||||
# concatenate toots
|
||||
toots = "\n".join(toots)
|
||||
|
||||
# truncate to last 8192 characters
|
||||
toots = toots[-8192:]
|
||||
|
||||
# raise ValueError("toots: " + toots)
|
||||
|
||||
# POST https://api.textsynth.com/v1/engines/{engine_id}/completions
|
||||
|
||||
# using aiohttp
|
||||
|
||||
post = None
|
||||
while post is None:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.textsynth.com/v1/engines/{}/completions".format(cfg["textsynth_engine_id"]),
|
||||
headers={
|
||||
"Authorization": "Bearer {}".format(cfg["textsynth_token"]),
|
||||
},
|
||||
json={
|
||||
"prompt": toots + "\n",
|
||||
"stop": "\n",
|
||||
"max_tokens": 200
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise ValueError("TextSynth API returned status code {}".format(resp.status))
|
||||
data = await resp.json()
|
||||
if not data["text"]:
|
||||
# raise ValueError("TextSynth API returned empty text")
|
||||
# just generate a new sentence
|
||||
continue
|
||||
post: str = data["text"]
|
||||
|
||||
# check wether the post only consists of mentions
|
||||
# split by words
|
||||
words = post.split()
|
||||
# check if all words are mentions
|
||||
if all(word.startswith("@") for word in words):
|
||||
# generate a new sentence
|
||||
post = None
|
||||
continue
|
||||
|
||||
db.close()
|
||||
|
||||
# replace stuff in post
|
||||
post = post.replace("\\n", "\n")
|
||||
|
||||
# and post it
|
||||
return post
|
|
@ -7,6 +7,7 @@ import html
|
|||
import enum
|
||||
import json
|
||||
import shutil
|
||||
import inspect
|
||||
import sqlite3
|
||||
import argparse
|
||||
import itertools
|
||||
|
@ -18,6 +19,7 @@ from bs4 import BeautifulSoup
|
|||
TextGenerationMode = enum.Enum('TextGenerationMode', """
|
||||
markov
|
||||
gpt_2
|
||||
textsynth
|
||||
""".split())
|
||||
|
||||
def arg_parser_factory(*, description):
|
||||
|
@ -67,8 +69,15 @@ async def make_post(cfg, *, mode=TextGenerationMode.markov):
|
|||
from generators.markov import make_sentence
|
||||
elif mode is TextGenerationMode.gpt_2:
|
||||
from generators.gpt_2 import make_sentence
|
||||
elif mode is TextGenerationMode.textsynth:
|
||||
from generators.textsynth import make_sentence
|
||||
|
||||
return await anyio.to_process.run_sync(make_sentence, cfg)
|
||||
# return await anyio.to_process.run_sync(make_sentence, cfg)
|
||||
# check if inspect.iscoroutinefunction(object)
|
||||
if inspect.iscoroutinefunction(make_sentence):
|
||||
return await make_sentence(cfg)
|
||||
else:
|
||||
return await anyio.to_process.run_sync(make_sentence, cfg)
|
||||
|
||||
def extract_post_content(text):
|
||||
soup = BeautifulSoup(text, "html.parser")
|
||||
|
|
Loading…
Reference in New Issue