make third_party.utils.make_toot async
This commit is contained in:
parent
4e4619fbe0
commit
d0965d437b
2
gen.py
2
gen.py
|
@ -22,7 +22,7 @@ async def main():
|
|||
args = parse_args()
|
||||
cfg = utils.load_config(args.cfg)
|
||||
|
||||
toot = utils.make_toot(cfg, mode=utils.TextGenerationMode.__members__[args.mode])
|
||||
toot = await utils.make_post(cfg, mode=utils.TextGenerationMode.__members__[args.mode])
|
||||
if cfg['strip_paired_punctuation']:
|
||||
toot = re.sub(r"[\[\]\(\)\{\}\"“”«»„]", "", toot)
|
||||
if not args.simulate:
|
||||
|
|
|
@ -12,6 +12,7 @@ import argparse
|
|||
import itertools
|
||||
import json5 as json
|
||||
import multiprocessing
|
||||
import anyio.to_process
|
||||
from random import randint
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
@ -61,37 +62,13 @@ def remove_mention(cfg, sentence):
|
|||
|
||||
return sentence
|
||||
|
||||
def _wrap_pipe(f):
|
||||
def g(pout, *args, **kwargs):
|
||||
try:
|
||||
pout.send(f(*args, **kwargs))
|
||||
except ValueError as exc:
|
||||
pout.send(exc.args[0])
|
||||
return g
|
||||
|
||||
def make_toot(cfg, *, mode=TextGenerationMode.markov):
|
||||
toot = None
|
||||
pin, pout = multiprocessing.Pipe(False)
|
||||
|
||||
async def make_post(cfg, *, mode=TextGenerationMode.markov):
|
||||
if mode is TextGenerationMode.markov:
|
||||
from generators.markov import make_sentence
|
||||
elif mode is TextGenerationMode.gpt_2:
|
||||
from generators.gpt_2 import make_sentence
|
||||
else:
|
||||
raise ValueError('Invalid text generation mode')
|
||||
|
||||
p = multiprocessing.Process(target=_wrap_pipe(make_sentence), args=[pout, cfg])
|
||||
p.start()
|
||||
p.join(5) # wait 5 seconds to get something
|
||||
if p.is_alive(): # if it's still trying to make a toot after 5 seconds
|
||||
p.terminate()
|
||||
p.join()
|
||||
else:
|
||||
toot = pin.recv()
|
||||
|
||||
if toot is None:
|
||||
toot = 'Toot generation failed! Contact io@csdisaster.club for assistance.'
|
||||
return toot
|
||||
return await anyio.to_process.run_sync(make_sentence, cfg)
|
||||
|
||||
def extract_post_content(text):
|
||||
soup = BeautifulSoup(text, "html.parser")
|
||||
|
|
7
utils.py
7
utils.py
|
@ -1,12 +1,11 @@
|
|||
# SPDX-License-Identifier: AGPL-3.0-only
|
||||
|
||||
import anyio
|
||||
import functools
|
||||
from bs4 import BeautifulSoup
|
||||
from functools import wraps
|
||||
|
||||
def shield(f):
|
||||
@functools.wraps(f)
|
||||
@wraps(f)
|
||||
async def shielded(*args, **kwargs):
|
||||
with anyio.CancelScope(shield=True) as cs:
|
||||
with anyio.CancelScope(shield=True):
|
||||
return await f(*args, **kwargs)
|
||||
return shielded
|
||||
|
|
Loading…
Reference in New Issue