replacer/replacer.py

154 lines
4.5 KiB
Python
Executable File

#!/usr/bin/env python3
# SPDX-License-Identifier: AGPL-3.0-only
import io
import os
import re
import sys
import utils
import random
import asyncio
import aiohttp
import contextlib
import contextvars
import qtoml as toml
from yarl import URL
from aiohttp import web
from utils import build_dprint_factory, ContextVar, MultiReplacer
def load_config():
    """Load the TOML configuration and pre-process every host entry.

    The file path is the first command-line argument, falling back to
    ``config.toml`` in the current working directory.

    Each entry of ``config['hosts']`` gets a ``replace`` key (default
    ``True``).  Hosts for which replacing is enabled additionally get:

    * ``mime_types`` -- a ``frozenset`` of the configured MIME types
      (empty when none are configured),
    * ``replacer`` -- a ``MultiReplacer`` built from the ``pattern``/``repl``
      pairs listed under ``replacements``, encoded to bytes,
    * ``excluded_routes`` -- sorted longest first, or ``()`` when the key
      is absent.

    Returns:
        The parsed and augmented configuration mapping.

    Raises:
        KeyError: if ``hosts`` is missing, or a host with replacing enabled
            has no ``replacements``.
    """
    path = sys.argv[1] if len(sys.argv) > 1 else 'config.toml'
    # TOML documents are UTF-8 by specification; do not depend on the
    # locale's preferred encoding.
    with open(path, encoding='utf-8') as f:
        config = toml.load(f)

    for host in config['hosts'].values():
        if not host.setdefault('replace', True):
            # Nothing else is needed for hosts that are only passed through.
            continue

        host['mime_types'] = frozenset(host.get('mime_types', ()))
        host['replacer'] = MultiReplacer({
            replacement['pattern'].encode(): replacement['repl'].encode()
            for replacement in host['replacements']
        })
        try:
            host['excluded_routes'].sort(key=len, reverse=True)
        except KeyError:
            host['excluded_routes'] = ()

    return config
# Parsed once at import time; the request handler reads per-host settings
# from config['hosts'].
config = load_config()
# Factory producing debug printers, built from the optional top-level
# `debug` setting.
dprint_factory = build_dprint_factory(config.get('debug'))
# Holds the debug printer of the request being handled (set in _handler).
# Values of these variables are read by calling the variable: dprint(), http().
dprint = ContextVar('dprint')
# Holds the aiohttp.ClientSession used to talk to upstream (set in amain).
http = ContextVar('http')
async def handler(request):
    """Serve one incoming request by proxying it to the configured upstream.

    Delegates to ``_handler`` and returns the response it produced, so the
    server receives a response object rather than ``None``.

    Raises:
        web.HTTPServiceUnavailable: if the connection to upstream could not
            be established (``aiohttp.ClientConnectorError``).
    """
    try:
        # The response must be handed back to the server; awaiting without
        # returning would make this handler return None.
        return await _handler(request)
    except aiohttp.ClientConnectorError:
        raise web.HTTPServiceUnavailable(
            text=f'{request.host} is temporarily unavailable. Try again later.',
        )
def _base_mime_type(content_type):
    """Drop the part of a MIME subtype that precedes a ``+``.

    ``application/activity+json`` becomes ``application/json``.  A value
    without a ``/``, or whose subtype has no ``+``, is returned unchanged.
    """
    main_type, slash, subtype = content_type.partition('/')
    if not slash:
        return content_type
    _, plus, base = subtype.partition('+')
    if not plus:
        return content_type
    return main_type + slash + base


def _should_replace(hconfig, path, content_type):
    """Decide whether a response body is to be run through the replacer.

    The reason for the decision is reported through the current debug
    printer.  Replacing is skipped when it is disabled for the host, when
    *path* starts with one of the host's excluded routes, or when the host
    restricts MIME types and the base type of *content_type* is not listed.
    """
    if not hconfig['replace']:
        dprint()('Not replacing for this host')
        return False

    for excluded_route in hconfig['excluded_routes']:
        if path.startswith(excluded_route):
            dprint()('Not replacing for', excluded_route)
            return False

    base_mime_type = _base_mime_type(content_type)
    # An empty mime_types set means "no restriction".
    if hconfig['mime_types'] and base_mime_type not in hconfig['mime_types']:
        dprint()(
            f'Not configured to replace for MIME type: {content_type}',
            f'(base MIME type: {base_mime_type})',
        )
        return False

    dprint()('replacing', hconfig['replacer'].replacements)
    return True


async def _handler(request):
    """Forward *request* to the upstream of its host and stream the answer back.

    The host configuration is looked up by ``request.host``.  The request
    method, path, headers, query and body are sent to the configured
    upstream without following redirects.  The upstream status and headers
    are copied onto a ``web.StreamResponse``; the body is then either
    rewritten by ``proxy_replace`` or copied by ``proxy_passthrough``.

    Returns:
        The finalized ``web.StreamResponse``.

    Raises:
        web.HTTPNotFound: if no host entry exists for ``request.host``.
    """
    try:
        hconfig = config['hosts'][request.host]
    except KeyError:
        # this should only happen if our reverse proxy is set to proxy everything to us
        raise web.HTTPNotFound()

    dprint.set(dprint_factory())

    async with http().request(
        request.method,
        # TODO support upstream unix sockets
        URL(hconfig['upstream']).with_path(request.path),
        headers=request.headers.copy(),
        params=request.query.copy(),
        data=await request.read() if request.can_read_body else None,
        # proxy redirects as-is
        allow_redirects=False,
    ) as upstream_resp:
        dprint()(request.method, request.host, request.path, '', hconfig['upstream'])

        # Decide before the headers are sent, because the decision affects
        # which headers are still valid.
        replace = _should_replace(hconfig, request.path, upstream_resp.content_type)

        headers = upstream_resp.headers.copy()
        was_encoded = 'Content-Encoding' in headers
        # we're not using gzip here so don't confuse our client
        with contextlib.suppress(KeyError):
            del headers['Content-Encoding']
        if replace or was_encoded:
            # The bytes we send no longer have the length upstream declared
            # (rewritten by the replacer and/or no longer content-encoded),
            # so the upstream Content-Length must not be forwarded.
            with contextlib.suppress(KeyError):
                del headers['Content-Length']

        resp = web.StreamResponse(status=upstream_resp.status, headers=headers)
        await resp.prepare(request)

        if replace:
            return await proxy_replace(hconfig, upstream_resp, resp)
        return await proxy_passthrough(upstream_resp, resp)
async def proxy_replace(hconfig, upstream_resp, resp):
    """Stream the upstream body to *resp*, rewriting it one line at a time.

    Every line read from ``upstream_resp.content`` is passed through the
    host's ``replacer`` before being written.  Returns the finalized *resp*.
    """
    # NOTE(review): each line is handed to the replacer on its own, so a
    # pattern spanning a line break presumably cannot match -- confirm
    # against MultiReplacer.
    upstream_body = upstream_resp.content
    while True:
        line = await upstream_body.readline()
        if not line:
            # An empty read marks the end of the upstream body.
            break
        rewritten = hconfig['replacer'].replace(line)
        await resp.write(rewritten)
    return await finalize_resp(resp)
async def proxy_passthrough(upstream_resp, resp):
    """Copy the upstream body to *resp* unmodified and finalize it.

    The body is read with a chunk size of ``io.DEFAULT_BUFFER_SIZE``.
    Returns the finalized *resp*.
    """
    chunks = upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE)
    async for data in chunks:
        await resp.write(data)
    return await finalize_resp(resp)
async def finalize_resp(resp):
    """Write the end-of-body marker on *resp* and return the same object."""
    await resp.write_eof()
    return resp
async def amain():
    """Start the proxy server and keep serving until the task is cancelled.

    Listens on a unix socket when the configured ``bind`` value (after
    environment-variable expansion) contains a ``/``, otherwise on TCP at
    ``bind`` and the configured ``port`` (default 8080).  A unix socket is
    chmod-ed to ``socket_mod`` (default ``0o775``) once it exists.

    The server runner is cleaned up when this coroutine exits, including on
    cancellation.
    """
    srv = web.Server(
        handler,
        # It's not our job to hand out 413s, that's the job of upstream and downstream
        request_factory=utils.make_unlimited_request,
    )
    runner = web.ServerRunner(srv)
    await runner.setup()

    bind = os.path.expandvars(config['bind'])
    bind_unix = '/' in bind
    if bind_unix:
        site = web.UnixSite(runner, bind)
        listen_addr = bind
    else:
        port = config.get('port', 8080)
        site = web.TCPSite(runner, host=bind, port=port)
        listen_addr = f'{bind}:{port}'

    async with aiohttp.ClientSession() as http_:
        # NOTE(review): set before site.start(), presumably so that request
        # handling started by the server can see the session -- confirm
        # against utils.ContextVar.
        http.set(http_)
        try:
            await site.start()
            print('Listening on', listen_addr, end='...\n')
            if bind_unix:
                os.chmod(bind, config.get('socket_mod', 0o775))
            # Nothing else to do in this task; sleep until cancelled.
            await asyncio.sleep(float('inf'))
        finally:
            # Stop the listening site and shut the server down while the
            # client session is still open.
            await runner.cleanup()
def main():
    """Run ``amain`` on a fresh event loop; Ctrl-C ends it without a traceback."""
    try:
        asyncio.run(amain())
    except KeyboardInterrupt:
        # Interrupting the server is the normal way to stop it.
        pass
# Start the server only when this file is executed as a script.
if __name__ == '__main__':
    main()