#!/usr/bin/env python3
# SPDX-License-Identifier: AGPL-3.0-only

import io
import os
import re
import sys
import utils
import random
import asyncio
import aiohttp
import contextlib
import contextvars
import qtoml as toml
from yarl import URL
from aiohttp import web
from utils import build_dprint_factory, ContextVar, MultiReplacer


def load_config():
    """Load the TOML configuration and pre-process every host entry.

    The path is taken from ``sys.argv[1]`` when given, otherwise
    ``config.toml`` in the current directory is used.

    For each host whose ``replace`` flag is on (it defaults to ``True``):

    * ``mime_types`` becomes a ``frozenset`` (empty when not configured),
    * ``replacer`` is a ``MultiReplacer`` built from the ``replacements``
      list, with each ``pattern``/``repl`` encoded to bytes,
    * ``excluded_routes`` is sorted longest-first, or set to ``()`` when the
      key is missing.

    Hosts with ``replace = false`` are left untouched apart from the flag.

    Returns:
        The parsed (and mutated) configuration mapping.
    """
    path = sys.argv[1] if len(sys.argv) > 1 else 'config.toml'
    # TOML documents are UTF-8; don't depend on the locale's default encoding
    with open(path, encoding='utf-8') as f:
        config = toml.load(f)

    for host in config['hosts'].values():
        if not host.setdefault('replace', True):
            continue

        host['mime_types'] = frozenset(host.get('mime_types', ()))
        host['replacer'] = MultiReplacer({
            replacement['pattern'].encode(): replacement['repl'].encode()
            for replacement in host['replacements']
        })
        try:
            host['excluded_routes'].sort(key=len, reverse=True)
        except KeyError:
            host['excluded_routes'] = ()

    return config


config = load_config()
dprint_factory = build_dprint_factory(config.get('debug'))
dprint = ContextVar('dprint')
http = ContextVar('http')


async def handler(request):
    """Entry point handed to ``web.Server``: proxy one request.

    Returns:
        The response produced by ``_handler``. It has to be returned: the
        low-level server expects the handler's return value to be the
        response object.

    Raises:
        web.HTTPServiceUnavailable: the upstream could not be connected to.
    """
    try:
        return await _handler(request)
    except aiohttp.ClientConnectorError:
        raise web.HTTPServiceUnavailable(
            text=f'{request.host} is temporarily unavailable. Try again later.',
        )


def _base_mime_type(content_type):
    """Strip a structured-syntax prefix from a MIME type.

    ``application/activity+json`` → ``application/json``. Values without a
    ``/`` or without a ``+`` in the subtype are returned unchanged.
    """
    main_type, slash, subtype = content_type.partition('/')
    if not slash:
        return content_type

    _, plus, base = subtype.partition('+')
    if not plus:
        return content_type

    return main_type + slash + base


def _should_replace(hconfig, request, upstream_resp):
    """Tell whether the upstream body must go through the host's replacer.

    Emits a debug message naming the reason whenever the answer is ``False``.
    """
    # must be checked first: hosts with replace = false carry no other derived keys
    if not hconfig['replace']:
        dprint()('Not replacing for this host')
        return False

    for excluded_route in hconfig['excluded_routes']:
        if request.path.startswith(excluded_route):
            dprint()('Not replacing for', excluded_route)
            return False

    base_mime_type = _base_mime_type(upstream_resp.content_type)
    # an empty mime_types set means "replace for every MIME type"
    if hconfig['mime_types'] and base_mime_type not in hconfig['mime_types']:
        dprint()(
            f'Not configured to replace for MIME type: {upstream_resp.content_type}',
            f'(base MIME type: {base_mime_type})',
        )
        return False

    return True


async def _handler(request):
    """Forward ``request`` to the configured upstream and stream the answer back.

    The host configuration is looked up by ``request.host``. The upstream
    response is either passed through unchanged or piped through the host's
    replacer, depending on ``_should_replace``.

    Returns:
        The finished ``web.StreamResponse``.

    Raises:
        web.HTTPNotFound: no host entry exists for ``request.host``.
    """
    try:
        hconfig = config['hosts'][request.host]
    except KeyError:
        # this should only happen if our reverse proxy is set to proxy everything to us
        raise web.HTTPNotFound()

    dprint.set(dprint_factory())

    async with http().request(
        request.method,
        # TODO support upstream unix sockets
        URL(hconfig['upstream']).with_path(request.path),
        headers=request.headers.copy(),
        params=request.query.copy(),
        data=await request.read() if request.can_read_body else None,
        # proxy redirects as-is
        allow_redirects=False,
    ) as upstream_resp:
        dprint()(request.method, request.host, request.path, '→', hconfig['upstream'])

        headers = upstream_resp.headers.copy()
        had_content_encoding = 'Content-Encoding' in headers
        # we're not using gzip here so don't confuse our client
        with contextlib.suppress(KeyError):
            del headers['Content-Encoding']

        # Decide before prepare(): once the headers are sent they can't be fixed up.
        should_replace = _should_replace(hconfig, request, upstream_resp)

        # The upstream Content-Length no longer describes what we send when the
        # body is rewritten, or when the client session already decompressed an
        # encoded body (aiohttp's default auto_decompress). Forwarding it makes
        # aiohttp cap the written body at that length, so drop it and let the
        # response be framed without it.
        if should_replace or had_content_encoding:
            with contextlib.suppress(KeyError):
                del headers['Content-Length']

        resp = web.StreamResponse(status=upstream_resp.status, headers=headers)
        await resp.prepare(request)

        if not should_replace:
            return await proxy_passthrough(upstream_resp, resp)

        dprint()('replacing', hconfig['replacer'].replacements)
        return await proxy_replace(hconfig, upstream_resp, resp)


async def proxy_replace(hconfig, upstream_resp, resp):
    """Copy the upstream body to ``resp`` one line at a time, applying the
    host's replacer to each line, then finish the response.

    Returns:
        ``resp`` after EOF has been written.
    """
    # iter_lines when
    # NOTE(review): each line is replaced on its own, so presumably a pattern
    # spanning a line break can't match — confirm against MultiReplacer.
    while (line := await upstream_resp.content.readline()):
        await resp.write(hconfig['replacer'].replace(line))

    return await finalize_resp(resp)


async def proxy_passthrough(upstream_resp, resp):
    """Copy the upstream body to ``resp`` unmodified, in chunks of
    ``io.DEFAULT_BUFFER_SIZE`` bytes, then finish the response.

    Returns:
        ``resp`` after EOF has been written.
    """
    async for chunk in upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE):
        await resp.write(chunk)

    return await finalize_resp(resp)


async def finalize_resp(resp):
    """Write EOF on ``resp`` and return it."""
    await resp.write_eof()
    return resp


async def amain():
    """Start the proxy server and serve until cancelled.

    Binds to a unix socket when the (environment-expanded) ``bind`` setting
    contains a ``/``, otherwise to TCP on ``bind`` and ``port`` (default
    8080). A single ``aiohttp.ClientSession`` is shared through the ``http``
    context variable for all upstream requests.
    """
    srv = web.Server(
        handler,
        # It's not our job to hand out 413s, that's the job of upstream and downstream
        request_factory=utils.make_unlimited_request,
    )
    runner = web.ServerRunner(srv)
    await runner.setup()

    bind = os.path.expandvars(config['bind'])
    bind_unix = '/' in bind
    if bind_unix:
        site = web.UnixSite(runner, bind)
    else:
        port = config.get('port', 8080)
        site = web.TCPSite(runner, host=bind, port=port)

    async with aiohttp.ClientSession() as http_:
        http.set(http_)
        await site.start()
        print('Listening on', bind if bind_unix else f'{bind}:{port}', end='...\n')

        if bind_unix:
            # the socket only exists once the site has started
            os.chmod(bind, config.get('socket_mod', 0o775))

        # sleep forever; the server runs in the background of this task
        await asyncio.sleep(float('inf'))


def main():
    """Run ``amain`` until interrupted; Ctrl-C exits without a traceback."""
    with contextlib.suppress(KeyboardInterrupt):
        asyncio.run(amain())


if __name__ == '__main__':
    main()