From 551e21def6554ba77a264fadcc1c0ea1c98742a2 Mon Sep 17 00:00:00 2001 From: Kay Faraday Date: Wed, 8 Sep 2021 09:51:40 +0000 Subject: [PATCH] dprint doesn't need to be a ctxman --- replacer.py | 7 ++++--- utils.py | 22 ++++++++-------------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/replacer.py b/replacer.py index a9b34d9..5770905 100755 --- a/replacer.py +++ b/replacer.py @@ -14,7 +14,7 @@ import contextvars from yarl import URL from aiohttp import web import pytomlpp as toml -from utils import build_dprint, ContextVar +from utils import build_dprint_factory, ContextVar def load_config(): with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f: @@ -32,12 +32,13 @@ def load_config(): return config config = load_config() -dprint = build_dprint(config.get('debug')) +dprint = build_dprint_factory(config.get('debug')) http = ContextVar('http') async def handler(request): hconfig = config['hosts'][request.host] - async with dprint() as print, http().request( + print = dprint() + async with http().request( request.method, # TODO support upstream unix sockets URL(hconfig['upstream']).with_path(request.path), diff --git a/utils.py b/utils.py index 51d9f07..53b2eb8 100644 --- a/utils.py +++ b/utils.py @@ -2,7 +2,6 @@ import random import asyncio -import contextlib import contextvars from aiohttp import web @@ -17,23 +16,18 @@ class ContextVar: def set(self, *args): return self._var.set(*args) def reset(self, *args): return self._var.reset(*args) -class asyncnullcontext(contextlib.AbstractAsyncContextManager): - async def __aenter__(self): return None - async def __aexit__(self, *excinfo): return None - -def build_dprint(debug: bool): +def build_dprint_factory(debug: bool): if debug: - class dprint(asyncnullcontext): + def dprint_factory(): # 8 byte pseudo-random request tag - def __init__(self): self.req_id = ''.join(random.choices('0123456789abcdef', k=8 * 2)) - def print(self, *args, **kwargs): print(f'[{self.req_id}]', *args, **kwargs) - async def __aenter__(self): return self.print + req_id = ''.join(random.choices('0123456789abcdef', k=8 * 2)) + def dprint(*args, **kwargs): print(f'[{req_id}]', *args, **kwargs) + return dprint else: - class dprint(asyncnullcontext): - async def __aenter__(self): - return lambda *args, **kwargs: None + def dprint(*args, **kwargs): return None + def dprint_factory(): return dprint - return dprint + return dprint_factory def make_unlimited_request(*args): return web.BaseRequest(*args, loop=asyncio.get_running_loop(), client_max_size=None)