dprint doesn't need to be a ctxman

This commit is contained in:
Kay Faraday 2021-09-08 09:51:40 +00:00
parent 99cf301fd6
commit 551e21def6
2 changed files with 12 additions and 17 deletions

View File

@ -14,7 +14,7 @@ import contextvars
from yarl import URL
from aiohttp import web
import pytomlpp as toml
from utils import build_dprint, ContextVar
from utils import build_dprint_factory, ContextVar
def load_config():
with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f:
@ -32,12 +32,13 @@ def load_config():
return config
config = load_config()
dprint = build_dprint(config.get('debug'))
dprint = build_dprint_factory(config.get('debug'))
http = ContextVar('http')
async def handler(request):
hconfig = config['hosts'][request.host]
async with dprint() as print, http().request(
print = dprint()
async with http().request(
request.method,
# TODO support upstream unix sockets
URL(hconfig['upstream']).with_path(request.path),

View File

@ -2,7 +2,6 @@
import random
import asyncio
import contextlib
import contextvars
from aiohttp import web
@ -17,23 +16,18 @@ class ContextVar:
def set(self, *args): return self._var.set(*args)
def reset(self, *args): return self._var.reset(*args)
class asyncnullcontext(contextlib.AbstractAsyncContextManager):
async def __aenter__(self): return None
async def __aexit__(self, *excinfo): return None
def build_dprint(debug: bool):
def build_dprint_factory(debug: bool):
if debug:
class dprint(asyncnullcontext):
def dprint_factory():
# 8 byte pseudo-random request tag
def __init__(self): self.req_id = ''.join(random.choices('0123456789abcdef', k=8 * 2))
def print(self, *args, **kwargs): print(f'[{self.req_id}]', *args, **kwargs)
async def __aenter__(self): return self.print
req_id = ''.join(random.choices('0123456789abcdef', k=8 * 2))
def dprint(*args, **kwargs): print(f'[{req_id}]', *args, **kwargs)
return dprint
else:
class dprint(asyncnullcontext):
async def __aenter__(self):
return lambda *args, **kwargs: None
def dprint(*args, **kwargs): return None
def dprint_factory(): return dprint
return dprint
return dprint_factory
def make_unlimited_request(*args):
return web.BaseRequest(*args, loop=asyncio.get_running_loop(), client_max_size=None)