replacer/replacer.py

114 lines
3.7 KiB
Python
Executable File

#!/usr/bin/env python3
# SPDX-License-Identifier: AGPL-3.0-only
import io
import os
import re
import sys
import utils
import random
import asyncio
import aiohttp
import contextlib
import contextvars
from yarl import URL
from aiohttp import web
import pytomlpp as toml
from utils import build_dprint_factory, ContextVar
def load_config():
    """Load the TOML configuration and pre-process every host entry.

    The config path is the first command-line argument, falling back to
    ``config.toml`` in the current working directory.

    Every entry under ``hosts`` gets a ``replace`` flag (default ``True``).
    Entries whose flag is false are otherwise left untouched. For the rest:

    * ``pattern`` is compiled to a bytes regex; the source string is kept
      as ``pattern_decoded``.
    * ``repl`` is encoded to bytes; the source string is kept as
      ``repl_decoded``.
    * ``mime_types`` becomes a ``frozenset`` (empty when not configured).
    * ``excluded_routes`` is sorted longest-first (empty list when not
      configured).

    Returns:
        The parsed configuration mapping, with host entries updated in place.

    Raises:
        KeyError: if ``hosts`` is missing, or a replacing host lacks
            ``pattern`` or ``repl``.
    """
    path = sys.argv[1] if len(sys.argv) > 1 else 'config.toml'
    # TOML is UTF-8 by specification, so don't depend on the locale encoding.
    with open(path, encoding='utf-8') as f:
        config = toml.load(f)

    for host in config['hosts'].values():
        if not host.setdefault('replace', True):
            continue
        host['pattern_decoded'] = host['pattern']
        host['pattern'] = re.compile(host['pattern'].encode())
        host['repl_decoded'] = host['repl']
        host['repl'] = host['repl'].encode()
        host['mime_types'] = frozenset(host.get('mime_types', ()))
        # Optional, like mime_types: a host with no exclusions should not
        # abort startup with a KeyError.
        host.setdefault('excluded_routes', []).sort(key=len, reverse=True)

    return config
# Parsed once at import time; handler() and amain() read this module-level mapping.
config = load_config()
# Factory called once per request (see handler) to obtain that request's print function.
# It is configured from the optional top-level `debug` setting.
# NOTE(review): the exact behaviour lives in utils.build_dprint_factory; presumably the
# returned printer is a no-op when debug is off. Confirm.
dprint = build_dprint_factory(config.get('debug'))
# Holder for the shared aiohttp.ClientSession: amain() stores it with http.set(...)
# and handler() retrieves it by calling http().
# NOTE(review): this is utils.ContextVar, not contextvars.ContextVar; presumably a
# thin callable wrapper around it. Confirm in utils.
http = ContextVar('http')
async def handler(request):
    """Proxy one request to the host's upstream and stream the answer back.

    The host configuration is looked up by ``request.host``. The request
    method, path, headers, query and body are forwarded to the configured
    ``upstream``; redirects are not followed but relayed as-is. The body is
    then either rewritten line by line (``proxy_replace``) or copied through
    untouched (``proxy_passthrough``).

    Args:
        request: the incoming aiohttp server request.

    Returns:
        The ``web.StreamResponse`` that was streamed to the client.

    Raises:
        KeyError: if ``request.host`` has no entry under ``hosts``.
    """
    hconfig = config['hosts'][request.host]
    print = dprint()
    async with http().request(
        request.method,
        # TODO support upstream unix sockets
        URL(hconfig['upstream']).with_path(request.path),
        headers=request.headers.copy(),
        params=request.query.copy(),
        data=await request.read() if request.can_read_body else None,
        # proxy redirects as-is
        allow_redirects=False,
    ) as upstream_resp:
        print(request.method, request.host, request.path, '', hconfig['upstream'])

        # Decide before preparing the response: whether the body will be
        # rewritten determines which upstream headers are still truthful.
        replacing = _should_replace(
            hconfig, request.path, upstream_resp.content_type, print,
        )

        headers = upstream_resp.headers.copy()
        # The client session hands us the already-decoded body, so the
        # upstream's Content-Encoding no longer describes what we send.
        was_encoded = 'Content-Encoding' in headers
        # we're not using gzip here so don't confuse our client
        with contextlib.suppress(KeyError):
            del headers['Content-Encoding']
        if replacing or was_encoded:
            # The upstream Content-Length describes the original (possibly
            # compressed, un-substituted) body. Forwarding it would make
            # aiohttp cut the rewritten body short or leave the client
            # waiting for bytes that never arrive. Without it the response
            # is sent with a length-independent framing instead.
            with contextlib.suppress(KeyError):
                del headers['Content-Length']

        resp = web.StreamResponse(status=upstream_resp.status, headers=headers)
        await resp.prepare(request)

        if not replacing:
            return await proxy_passthrough(upstream_resp, resp)

        print('replacing', repr(hconfig['pattern_decoded']), 'with', repr(hconfig['repl_decoded']))
        return await proxy_replace(hconfig, upstream_resp, resp)


def _should_replace(hconfig, path, content_type, log):
    """Return True if the response body must be rewritten for this request.

    When the answer is False, the reason is reported through ``log``.
    Checks run in this order: host-level ``replace`` flag, excluded route
    prefixes, then the MIME type allow-list (an empty list allows all).
    """
    # Must be checked first: non-replacing hosts carry none of the other keys.
    if not hconfig['replace']:
        log('Not replacing for this host')
        return False

    for excluded_route in hconfig['excluded_routes']:
        if path.startswith(excluded_route):
            log('Not replacing for', excluded_route)
            return False

    if hconfig['mime_types'] and content_type not in hconfig['mime_types']:
        log('Not configured to replace for MIME type', content_type)
        return False

    return True
async def proxy_replace(hconfig, upstream_resp, resp):
    """Relay the upstream body line by line, applying the host's substitution.

    Each line read from ``upstream_resp`` is passed through the compiled
    ``hconfig['pattern']`` with ``hconfig['repl']`` as replacement and
    written to ``resp``. Matching is per line, so a pattern cannot match
    across a line break.

    Returns:
        ``resp`` after the end of the body has been written.
    """
    # Iterating the stream yields one line per step and stops at EOF.
    async for line in upstream_resp.content:
        rewritten = hconfig['pattern'].sub(hconfig['repl'], line)
        await resp.write(rewritten)
    return await finalize_resp(resp)
async def proxy_passthrough(upstream_resp, resp):
    """Relay the upstream body to ``resp`` unchanged.

    The body is copied in reads of at most ``io.DEFAULT_BUFFER_SIZE`` bytes.

    Returns:
        ``resp`` after the end of the body has been written.
    """
    body = upstream_resp.content
    # An empty read marks the end of the upstream body.
    while (chunk := await body.read(io.DEFAULT_BUFFER_SIZE)):
        await resp.write(chunk)
    return await finalize_resp(resp)
async def finalize_resp(resp):
    """Write the end-of-body marker to ``resp`` and return the same object.

    Shared tail of ``proxy_replace`` and ``proxy_passthrough`` so both can
    ``return await finalize_resp(resp)``.
    """
    await resp.write_eof()
    return resp
async def amain():
    """Start the proxy server and serve until the task is cancelled.

    Binds to ``config['bind']`` after environment-variable expansion. A
    value containing ``/`` is treated as a unix socket path; anything else
    is a TCP host, used together with ``config['port']`` (default 8080).
    For unix sockets the socket file mode is set to ``config['socket_mod']``
    (default ``0o775``) once the site is listening.

    The runner is cleaned up on the way out, including on cancellation, so
    listening sockets and open connections are shut down instead of being
    abandoned.
    """
    # It's not our job to hand out 413s, that's the job of upstream and downstream
    srv = web.Server(handler, request_factory=utils.make_unlimited_request)
    runner = web.ServerRunner(srv)
    await runner.setup()

    bind = os.path.expandvars(config['bind'])
    if (unix := '/' in bind):
        site = web.UnixSite(runner, bind)
    else:
        site = web.TCPSite(runner, host=bind, port=(port := config.get('port', 8080)))

    async with aiohttp.ClientSession() as http_:
        # Publish the session before the site starts accepting connections.
        # NOTE(review): presumably required so request handlers see the
        # value; confirm against utils.ContextVar.
        http.set(http_)
        try:
            await site.start()
            print('Listening on', bind if unix else f'{bind}:{port}', end='...\n')
            if unix:
                # config.get() cannot raise KeyError, so no suppression is
                # needed here; a chmod failure should surface.
                os.chmod(bind, config.get('socket_mod', 0o775))
            # Park forever; the server runs until this task is cancelled.
            await asyncio.sleep(float('inf'))
        finally:
            # Stop the site and close connections while the client session
            # is still open for any in-flight handlers.
            await runner.cleanup()
def main():
    """Run the proxy until it is interrupted; Ctrl-C exits without a traceback."""
    try:
        asyncio.run(amain())
    except KeyboardInterrupt:
        # An interrupt is the expected way to stop the server, not an error.
        pass
# Start the proxy only when executed as a script, not when imported.
if __name__ == '__main__':
    main()