diff --git a/.gitignore b/.gitignore index c18dd8d..3497d6b 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ __pycache__/ +*.toml +!*.example.toml diff --git a/app.py b/app.py index 28c7871..2a8c327 100755 --- a/app.py +++ b/app.py @@ -4,61 +4,55 @@ import io import os import sys +import random import asyncio import aiohttp import contextlib import contextvars from yarl import URL from aiohttp import web +import pytomlpp as toml +from utils import build_dprint, ContextVar -class ContextVar: - """ - contextvars.ContextVar, but you can call the var to get its value, rather than calling var.get() - """ - __slots__ = frozenset({'_var'}) - def __init__(self, *a, **kw): self._var = contextvars.ContextVar(*a, **kw) - def __call__(self, *args): return self._var.get(*args) # convenience - def get(self, *args): return self._var.get(*args) - def set(self, *args): return self._var.set(*args) - def reset(self, *args): return self._var.reset(*args) +with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f: + config = toml.load(f) -pattern = os.environ['PATTERN'].encode() -repl = os.environ['REPL'].encode() -# TODO support upstream unix sockets -upstream = os.environ['UPSTREAM'] -bind = sys.argv[1] -socket_mod = int(os.environ.get('SOCKET_MODE', '775'), 8) -replace_for = os.environ.get('REPLACE_FOR_HOST') -try: - port = int(sys.argv[2]) -except IndexError: - port = 8080 +for host in config['hosts'].values(): + if not host.get('replace', True): continue + host['pattern_decoded'] = host['pattern'] + host['pattern'] = host['pattern'].encode() + host['repl_decoded'] = host['repl'] + host['repl'] = host['repl'].encode() + +dprint = build_dprint(config.get('debug')) http = ContextVar('http') async def handler(request): - async with http().request( + hconfig = config['hosts'][request.host] + async with dprint() as print, http().request( request.method, - URL(upstream).with_path(request.path), + # TODO support upstream unix sockets + 
URL(hconfig['upstream']).with_path(request.path), headers=request.headers.copy(), params=request.query.copy(), data=await request.read() if request.can_read_body else None, # proxy redirects as-is allow_redirects=False, ) as upstream_resp: + print('proceeding for', request.host, request.path, '→', hconfig['upstream']) headers = upstream_resp.headers.copy() # we're not using gzip here so don't confuse our client with contextlib.suppress(KeyError): del headers['Content-Encoding'] resp = web.StreamResponse(status=upstream_resp.status, headers=headers) await resp.prepare(request) - if ( - upstream_resp.content_type == 'text/html' - and (replace_for is None or request.host == replace_for) - ): + if upstream_resp.content_type == 'text/html' and hconfig.get('replace', True): # iter_lines when + print('replacing', repr(hconfig['pattern_decoded']), 'with', repr(hconfig['repl_decoded'])) while (line := await upstream_resp.content.readline()): - await resp.write(line.replace(pattern, repl)) + await resp.write(line.replace(hconfig['pattern'], hconfig['repl'])) else: + print('not replacing') async for chunk in upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE): await resp.write(chunk) @@ -69,16 +63,18 @@ async def amain(): srv = web.Server(handler) runner = web.ServerRunner(srv) await runner.setup() - if '/' in bind: + bind = os.path.expandvars(config['bind']) + if (unix := '/' in bind): site = web.UnixSite(runner, bind) else: - site = web.TCPSite(runner, host=bind, port=port) + site = web.TCPSite(runner, host=bind, port=(port := config.get('port', 8080))) async with aiohttp.ClientSession() as http_: http.set(http_) await site.start() + print('Listening on', bind if unix else f'{bind}:{port}', end='...\n') if '/' in bind: with contextlib.suppress(KeyError): - os.chmod(bind, socket_mod) + os.chmod(bind, config.get('socket_mod', 0o775)) await asyncio.sleep(float('inf')) def main(): diff --git a/config.example.toml b/config.example.toml new file mode 100644 index 
0000000..34790a3 --- /dev/null +++ b/config.example.toml @@ -0,0 +1,19 @@ +bind = '127.0.0.1' +# you can also bind to a unix socket, and environment vars are supported: +# bind = '$RUNTIME_DIRECTORY/replacer.sock' +port = '8000' +# if using a unix socket, you can specify the file access mode here. it defaults to 0o775. +socket_mod = 0o777 +# whether to enable request logging. disabled by default +debug = true + +# replace foo with bar for site1.example +[hosts."site1.example"] +upstream = 'http://localhost:3001' +pattern = 'foo' +repl = 'bar' + +# just pass through site2.example unmodified +[hosts."site2.example"] +upstream = 'http://localhost:3001' +replace = false diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..e67b70b --- /dev/null +++ b/utils.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: AGPL-3.0-only + +import random +import contextlib +import contextvars + +class ContextVar: + """ + contextvars.ContextVar, but you can call the var to get its value, rather than calling var.get() + """ + __slots__ = frozenset({'_var'}) + def __init__(self, *a, **kw): self._var = contextvars.ContextVar(*a, **kw) + def __call__(self, *args): return self._var.get(*args) # convenience + def get(self, *args): return self._var.get(*args) + def set(self, *args): return self._var.set(*args) + def reset(self, *args): return self._var.reset(*args) + +class asyncnullcontext(contextlib.AbstractAsyncContextManager): + async def __aenter__(self): return None + async def __aexit__(self, *excinfo): return None + +def build_dprint(debug: bool): + if debug: + class dprint(asyncnullcontext): + # 8 byte pseudo-random request tag + def __init__(self): self.req_id = ''.join(random.choices('0123456789abcdef', k=8 * 2)) + def print(self, *args, **kwargs): print(f'[{self.req_id}]', *args, **kwargs) + async def __aenter__(self): return self.print + else: + class dprint(asyncnullcontext): + async def __aenter__(self): + return lambda *args, **kwargs: None + + return dprint