replacer/app.py

75 lines
2.1 KiB
Python
Executable File

#!/usr/bin/env python3
# SPDX-License-Identifier: AGPL-3.0-only
import io
import os
import asyncio
import aiohttp
import contextlib
import contextvars
from yarl import URL
from aiohttp import web
class ContextVar:
    """Callable wrapper around :class:`contextvars.ContextVar`.

    Calling the instance is shorthand for ``get``: ``var()`` == ``var.get()``
    and ``var(default)`` == ``var.get(default)``.  ``get``, ``set`` and
    ``reset`` forward their arguments unchanged to the wrapped variable.
    """

    __slots__ = frozenset({'_var'})

    def __init__(self, *args, **kwargs):
        self._var = contextvars.ContextVar(*args, **kwargs)

    def get(self, *args):
        return self._var.get(*args)

    # Convenience alias so the variable can be read as ``var()``.
    __call__ = get

    def set(self, *args):
        return self._var.set(*args)

    def reset(self, *args):
        return self._var.reset(*args)
# Runtime configuration, read once from the environment at import time.
# A missing required variable raises KeyError here, before anything starts.
pattern = os.environ['PATTERN'].encode()  # bytes searched for in text/html bodies
repl = os.environ['REPL'].encode()  # bytes written in place of each match
upstream = f"http://{os.environ['UPSTREAM']}"  # UPSTREAM is appended to the scheme
bind = os.environ['BIND']  # a path containing '/' means a Unix socket, else a TCP host
port = int(os.getenv('PORT', '8080'))  # TCP port; not used for Unix sockets
# Holds the shared aiohttp client session; amain() sets it.
http = ContextVar('http')
async def handler(request):
async with http().request(
request.method,
URL(upstream).with_path(request.path),
headers=request.headers.copy(),
params=request.query.copy(),
data=await request.read() if request.can_read_body else None,
# proxy redirects as-is
allow_redirects=False,
) as upstream_resp:
resp = web.StreamResponse(status=upstream_resp.status, headers=upstream_resp.headers)
await resp.prepare(request)
if upstream_resp.content_type == 'text/html':
# iter_lines when
while (line := await upstream_resp.content.readline()):
await resp.write(line.replace(pattern, repl))
else:
async for chunk in upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE):
await resp.write(chunk)
await resp.write_eof()
return resp
async def amain():
    """Run the proxy server until this task is cancelled.

    Listens on the Unix socket path ``bind`` when it contains a ``/``,
    otherwise on TCP ``bind``:``port``.  One aiohttp client session is shared
    by all requests through the ``http`` context variable.  On exit, including
    cancellation, the server runner is cleaned up before the session closes.
    """
    async with contextlib.AsyncExitStack() as stack:
        session = await stack.enter_async_context(aiohttp.ClientSession())
        # Set before the listener starts, so connection handling started by
        # site.start() inherits a context in which the session is visible.
        http.set(session)
        runner = web.ServerRunner(web.Server(handler))
        await runner.setup()
        # Registered after the session, so it runs first on unwind: the server
        # is shut down while the session it depends on is still open.
        stack.push_async_callback(runner.cleanup)
        if '/' in bind:
            site = web.UnixSite(runner, bind)
        else:
            site = web.TCPSite(runner, host=bind, port=port)
        await site.start()
        await asyncio.sleep(float('inf'))
def main():
    """Run amain() on a fresh event loop; Ctrl-C exits without a traceback."""
    try:
        asyncio.run(amain())
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()