#!/usr/bin/env python3
# SPDX-License-Identifier: AGPL-3.0-only

"""Reverse proxy that rewrites upstream response bodies with a per-host regex.

The TOML configuration is read from the path given as the first command-line
argument, or ``config.toml`` when none is given.
"""

import io
import os
import re
import sys
import utils
import random
import asyncio
import aiohttp
import contextlib
import contextvars
from yarl import URL
from aiohttp import web
import pytomlpp as toml
from utils import build_dprint_factory, ContextVar


def load_config():
	"""Load and normalise the proxy configuration.

	The path is ``sys.argv[1]`` if given, else ``config.toml``. For every entry
	of ``config['hosts']`` whose ``replace`` flag is true (the default):

	* the original ``pattern``/``repl`` strings are kept under
	  ``pattern_decoded``/``repl_decoded`` for logging;
	* ``pattern`` is compiled to a bytes regex and ``repl`` is encoded to bytes;
	* ``mime_types`` becomes a frozenset (empty means "any MIME type");
	* ``excluded_routes`` defaults to an empty list and is sorted longest-first.

	Returns the parsed configuration mapping.
	"""
	path = sys.argv[1] if len(sys.argv) > 1 else 'config.toml'
	# TOML is UTF-8 by specification; don't depend on the locale's encoding.
	with open(path, encoding='utf-8') as f:
		config = toml.load(f)

	for host in config['hosts'].values():
		if not host.setdefault('replace', True):
			continue
		host['pattern_decoded'] = host['pattern']
		host['pattern'] = re.compile(host['pattern'].encode())
		host['repl_decoded'] = host['repl']
		host['repl'] = host['repl'].encode()
		host['mime_types'] = frozenset(host.get('mime_types', ()))
		# Exclusions are optional; previously a host without them raised KeyError.
		host.setdefault('excluded_routes', []).sort(key=len, reverse=True)

	return config


config = load_config()
dprint = build_dprint_factory(config.get('debug'))
# Holds the shared aiohttp.ClientSession; set in amain() before serving starts.
http = ContextVar('http')


def _replace_skip_reason(hconfig, path, content_type):
	"""Return the debug-message parts explaining why the body is NOT rewritten.

	Returns None when the body should be rewritten with the host's pattern.
	"""
	if not hconfig['replace']:
		return ('Not replacing for this host',)
	for excluded_route in hconfig['excluded_routes']:
		if path.startswith(excluded_route):
			return ('Not replacing for', excluded_route)
	if hconfig['mime_types'] and content_type not in hconfig['mime_types']:
		return ('Not configured to replace for MIME type', content_type)
	return None


async def handler(request):
	"""Proxy *request* to the upstream configured for its Host.

	Status and headers are copied from the upstream response and the body is
	streamed back, rewritten with the host's pattern/repl unless replacement
	is disabled for the host, the path is excluded, or the MIME type is not
	listed. Requests for a Host missing from the config get a 421.
	"""
	print = dprint()
	hconfig = config['hosts'].get(request.host)
	if hconfig is None:
		print('No configuration for host', request.host)
		return web.Response(status=421, text='Misdirected Request\n')

	async with http().request(
		request.method,
		# TODO support upstream unix sockets
		URL(hconfig['upstream']).with_path(request.path),
		headers=request.headers.copy(),
		params=request.query.copy(),
		data=await request.read() if request.can_read_body else None,
		# proxy redirects as-is
		allow_redirects=False,
	) as upstream_resp:
		print(request.method, request.host, request.path, '→', hconfig['upstream'])

		skip_reason = _replace_skip_reason(hconfig, request.path, upstream_resp.content_type)

		headers = upstream_resp.headers.copy()
		# The client session decodes compressed bodies, so the encoding header no
		# longer describes what we send; don't confuse our client with it.
		decompressed = headers.popall('Content-Encoding', None) is not None
		# Message framing is aiohttp's job: it enables chunked encoding itself
		# when no Content-Length is set.
		headers.popall('Transfer-Encoding', None)
		# The upstream length is wrong once the body was decompressed or will be
		# rewritten. A stale Content-Length makes aiohttp truncate longer output,
		# or leaves the client waiting for bytes that never come.
		if decompressed or skip_reason is None:
			headers.popall('Content-Length', None)

		resp = web.StreamResponse(status=upstream_resp.status, headers=headers)
		await resp.prepare(request)

		if skip_reason is not None:
			print(*skip_reason)
			return await proxy_passthrough(upstream_resp, resp)

		print('replacing', repr(hconfig['pattern_decoded']), 'with', repr(hconfig['repl_decoded']))
		return await proxy_replace(hconfig, upstream_resp, resp)


async def proxy_replace(hconfig, upstream_resp, resp):
	"""Stream the upstream body into *resp*, applying the host's regex per line.

	Each line keeps its trailing ``b'\\n'`` when matched, so a pattern cannot
	span a newline; a final line without a newline is processed too. Lines are
	assembled here rather than with ``StreamReader.readline()``, which raises
	``ValueError`` once a single line exceeds aiohttp's buffer limit (common
	for minified JS/CSS). The trade-off: one very long line is held in memory
	in full before it is rewritten.
	"""
	pattern, repl = hconfig['pattern'], hconfig['repl']
	pending = bytearray()
	async for chunk in upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE):
		pending += chunk
		# Only the new chunk can contain a newline we haven't seen yet.
		newline = chunk.rfind(b'\n')
		if newline < 0:
			continue
		cut = len(pending) - len(chunk) + newline + 1
		complete = bytes(pending[:cut])
		del pending[:cut]
		# `complete` ends with b'\n', so split() yields a trailing b'' to drop.
		lines = complete.split(b'\n')[:-1]
		await resp.write(b''.join(pattern.sub(repl, line + b'\n') for line in lines))
	if pending:
		await resp.write(pattern.sub(repl, bytes(pending)))
	return await finalize_resp(resp)


async def proxy_passthrough(upstream_resp, resp):
	"""Copy the upstream body into *resp* unchanged, chunk by chunk."""
	async for chunk in upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE):
		await resp.write(chunk)
	return await finalize_resp(resp)


async def finalize_resp(resp):
	"""Signal end-of-body on *resp* and return it."""
	await resp.write_eof()
	return resp


async def amain():
	"""Start the proxy and serve until cancelled (e.g. by Ctrl-C).

	``config['bind']`` (with environment variables expanded) is treated as a
	Unix socket path if it contains ``/``, otherwise as a TCP host listening on
	``config['port']`` (default 8080). Unix sockets are chmod-ed to
	``config['socket_mod']`` (default 0o775).
	"""
	async with aiohttp.ClientSession() as http_:
		# Set before the site starts so connection handlers see the session.
		http.set(http_)
		# It's not our job to hand out 413s, that's the job of upstream and downstream
		srv = web.Server(handler, request_factory=utils.make_unlimited_request)
		runner = web.ServerRunner(srv)
		await runner.setup()
		try:
			bind = os.path.expandvars(config['bind'])
			if (unix := '/' in bind):
				site = web.UnixSite(runner, bind)
			else:
				site = web.TCPSite(runner, host=bind, port=(port := config.get('port', 8080)))

			await site.start()
			print('Listening on', bind if unix else f'{bind}:{port}', end='...\n')
			if unix:
				os.chmod(bind, config.get('socket_mod', 0o775))
			await asyncio.sleep(float('inf'))
		finally:
			# Close listeners and connections before the client session goes away.
			await runner.cleanup()


def main():
	"""Run the proxy; Ctrl-C exits quietly."""
	with contextlib.suppress(KeyboardInterrupt):
		asyncio.run(amain())


if __name__ == '__main__':
	main()