dprint doesn't need to be a ctxman

This commit is contained in:
Kay Faraday 2021-09-08 09:51:40 +00:00
parent 99cf301fd6
commit 551e21def6
2 changed files with 12 additions and 17 deletions

View File

@ -14,7 +14,7 @@ import contextvars
from yarl import URL from yarl import URL
from aiohttp import web from aiohttp import web
import pytomlpp as toml import pytomlpp as toml
from utils import build_dprint, ContextVar from utils import build_dprint_factory, ContextVar
def load_config(): def load_config():
with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f: with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f:
@ -32,12 +32,13 @@ def load_config():
return config return config
config = load_config() config = load_config()
dprint = build_dprint(config.get('debug')) dprint = build_dprint_factory(config.get('debug'))
http = ContextVar('http') http = ContextVar('http')
async def handler(request): async def handler(request):
hconfig = config['hosts'][request.host] hconfig = config['hosts'][request.host]
async with dprint() as print, http().request( print = dprint()
async with http().request(
request.method, request.method,
# TODO support upstream unix sockets # TODO support upstream unix sockets
URL(hconfig['upstream']).with_path(request.path), URL(hconfig['upstream']).with_path(request.path),

View File

@ -2,7 +2,6 @@
import random import random
import asyncio import asyncio
import contextlib
import contextvars import contextvars
from aiohttp import web from aiohttp import web
@ -17,23 +16,18 @@ class ContextVar:
def set(self, *args): return self._var.set(*args) def set(self, *args): return self._var.set(*args)
def reset(self, *args): return self._var.reset(*args) def reset(self, *args): return self._var.reset(*args)
class asyncnullcontext(contextlib.AbstractAsyncContextManager): def build_dprint_factory(debug: bool):
async def __aenter__(self): return None
async def __aexit__(self, *excinfo): return None
def build_dprint(debug: bool):
if debug: if debug:
class dprint(asyncnullcontext): def dprint_factory():
# 8 byte pseudo-random request tag # 8 byte pseudo-random request tag
def __init__(self): self.req_id = ''.join(random.choices('0123456789abcdef', k=8 * 2)) req_id = ''.join(random.choices('0123456789abcdef', k=8 * 2))
def print(self, *args, **kwargs): print(f'[{self.req_id}]', *args, **kwargs) def dprint(*args, **kwargs): print(f'[{req_id}]', *args, **kwargs)
async def __aenter__(self): return self.print return dprint
else: else:
class dprint(asyncnullcontext): def dprint(*args, **kwargs): return None
async def __aenter__(self): def dprint_factory(): return dprint
return lambda *args, **kwargs: None
return dprint return dprint_factory
def make_unlimited_request(*args): def make_unlimited_request(*args):
return web.BaseRequest(*args, loop=asyncio.get_running_loop(), client_max_size=None) return web.BaseRequest(*args, loop=asyncio.get_running_loop(), client_max_size=None)