#!/usr/bin/env python3
# SPDX-License-Identifier: AGPL-3.0-only

"""Reverse proxy that rewrites a byte pattern in HTML responses.

Configuration comes from the environment:

* ``PATTERN`` -- bytes to search for in ``text/html`` response bodies (non-empty)
* ``REPL``    -- bytes to substitute for each occurrence of ``PATTERN``
* ``UPSTREAM``-- ``host[:port]`` of the upstream server (spoken to over plain http)
* ``BIND``    -- address to listen on; a value containing ``/`` is a Unix socket path
* ``PORT``    -- TCP port to listen on when ``BIND`` is not a socket path (default 8080)
"""

import io
import os
import asyncio
import aiohttp
import contextlib
import contextvars
from yarl import URL
from aiohttp import web


class ContextVar:
	"""Wrapper around contextvars.ContextVar whose value can be read by calling it.

	``var()`` is shorthand for ``var.get()``; ``get``/``set``/``reset`` are
	forwarded unchanged to the wrapped variable.
	"""

	__slots__ = frozenset({'_var'})

	def __init__(self, *a, **kw):
		self._var = contextvars.ContextVar(*a, **kw)

	def __call__(self, *args):
		return self._var.get(*args)

	def get(self, *args):
		return self._var.get(*args)

	def set(self, *args):
		return self._var.set(*args)

	def reset(self, *args):
		return self._var.reset(*args)


pattern = os.environ['PATTERN'].encode()
if not pattern:
	# An empty pattern "matches" between every pair of bytes; that is never
	# what is wanted, and it would make the streaming search loop forever.
	raise SystemExit('PATTERN must not be empty')
repl = os.environ['REPL'].encode()
upstream = 'http://' + os.environ['UPSTREAM']
bind = os.environ['BIND']
port = int(os.environ.get('PORT', '8080'))

# The aiohttp.ClientSession shared by all handlers; set in amain().
http = ContextVar('http')


async def _rewrite_stream(chunks, write):
	"""Copy the async iterable of byte *chunks* to *write*, replacing pattern with repl.

	The output equals ``b''.join(chunks).replace(pattern, repl)``, but at most
	``len(pattern) - 1`` bytes are held back between chunks. Memory therefore
	stays bounded no matter how long a "line" is, and a match split across a
	chunk boundary is still found.
	"""
	keep = len(pattern) - 1
	tail = b''
	async for chunk in chunks:
		buf = tail + chunk
		out = []
		start = 0
		while (idx := buf.find(pattern, start)) != -1:
			out.append(buf[start:idx])
			out.append(repl)
			start = idx + len(pattern)
		# No complete match starts in buf[start:cut]: one starting there would
		# have to run past the end of buf. The last `keep` bytes, however, may
		# be the beginning of a match that continues in the next chunk.
		cut = max(start, len(buf) - keep)
		out.append(buf[start:cut])
		tail = buf[cut:]
		if data := b''.join(out):
			await write(data)
	if tail:
		await write(tail)


async def handler(request):
	"""Forward *request* to the upstream and stream the upstream's response back.

	Bodies of ``text/html`` responses have every occurrence of ``pattern``
	replaced with ``repl``; all other responses are passed through unchanged.
	"""
	headers = request.headers.copy()
	# By default aiohttp's client transparently decompresses response bodies,
	# but the upstream's Content-Encoding / Content-Length headers are copied
	# to the client below unchanged. Ask for an uncompressed body so the bytes
	# we forward match the headers we forward, and so `pattern` can be found.
	headers['Accept-Encoding'] = 'identity'
	async with http().request(
		request.method,
		# Use the raw (still percent-encoded) path so escapes such as %2F are
		# forwarded as-is rather than decoded.
		URL(upstream).with_path(request.rel_url.raw_path, encoded=True),
		headers=headers,
		params=request.query.copy(),
		data=await request.read() if request.can_read_body else None,
		# proxy redirects as-is
		allow_redirects=False,
	) as upstream_resp:
		rewrite = upstream_resp.content_type == 'text/html'
		resp_headers = upstream_resp.headers.copy()
		if rewrite:
			# The replacement can change the body length, so the upstream's
			# Content-Length no longer applies; without it aiohttp streams the
			# body with chunked encoding (or close-delimited on HTTP/1.0).
			resp_headers.popall('Content-Length', None)
		resp = web.StreamResponse(status=upstream_resp.status, headers=resp_headers)
		await resp.prepare(request)
		chunks = upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE)
		if rewrite:
			await _rewrite_stream(chunks, resp.write)
		else:
			async for chunk in chunks:
				await resp.write(chunk)
		await resp.write_eof()
		return resp


async def amain():
	"""Serve the proxy on BIND (Unix socket) or BIND:PORT (TCP) until cancelled."""
	srv = web.Server(handler)
	runner = web.ServerRunner(srv)
	async with aiohttp.ClientSession() as http_:
		# Must be set before the site starts: handler() reads the session via
		# http(), which presumably works because the server's tasks inherit a
		# copy of this context.
		http.set(http_)
		await runner.setup()
		try:
			# A '/' in BIND means a filesystem path: listen on a Unix socket.
			if '/' in bind:
				site = web.UnixSite(runner, bind)
			else:
				site = web.TCPSite(runner, host=bind, port=port)
			await site.start()
			await asyncio.sleep(float('inf'))
		finally:
			# Shut the server down before the client session is closed, so
			# in-flight handlers never see a closed session.
			await runner.cleanup()


def main():
	"""Run the proxy; Ctrl-C exits quietly."""
	with contextlib.suppress(KeyboardInterrupt):
		asyncio.run(amain())


if __name__ == '__main__':
	main()