154 lines
4.5 KiB
Python
Executable File
154 lines
4.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# SPDX-License-Identifier: AGPL-3.0-only
|
|
|
|
import io
|
|
import os
|
|
import re
|
|
import sys
|
|
import utils
|
|
import random
|
|
import asyncio
|
|
import aiohttp
|
|
import contextlib
|
|
import contextvars
|
|
import qtoml as toml
|
|
from yarl import URL
|
|
from aiohttp import web
|
|
from utils import build_dprint_factory, ContextVar, MultiReplacer
|
|
|
|
def load_config():
    """Load and preprocess the TOML configuration.

    The path is the first command-line argument, falling back to
    ``config.toml``. For every host whose ``replace`` setting is true (it
    defaults to true), this precomputes:

    * ``mime_types`` -- a frozenset of base MIME types to rewrite
      (an empty set means every type is rewritten);
    * ``replacer`` -- a ``MultiReplacer`` built from the UTF-8 encoded
      ``pattern`` -> ``repl`` pairs listed under ``replacements``;
    * ``excluded_routes`` -- path prefixes whose responses are passed through
      untouched, sorted longest first (an empty tuple when not configured).

    Hosts with ``replace = false`` only get the defaulted ``replace`` key.

    Returns:
        The parsed configuration mapping, with the per-host keys above filled in.
    """
    path = sys.argv[1] if len(sys.argv) > 1 else 'config.toml'
    # TOML documents are UTF-8 by specification; don't depend on the locale's
    # default encoding, or non-ASCII patterns get mis-decoded.
    with open(path, encoding='utf-8') as f:
        config = toml.load(f)

    for host in config['hosts'].values():
        if not host.setdefault('replace', True):
            continue

        host['mime_types'] = frozenset(host.get('mime_types', ()))

        host['replacer'] = MultiReplacer({
            replacement['pattern'].encode(): replacement['repl'].encode()
            for replacement in host['replacements']
        })

        try:
            # longest prefix first
            host['excluded_routes'].sort(key=len, reverse=True)
        except KeyError:
            host['excluded_routes'] = ()

    return config
|
|
|
|
# Parsed configuration, loaded once at import time
# (path from sys.argv[1], else config.toml).
config = load_config()

# Factory for debug printers; _handler creates a fresh one per request.
# Driven by the optional top-level `debug` config key.
dprint_factory = build_dprint_factory(config.get('debug'))

# Per-request debug printer, set by _handler. NOTE(review): utils.ContextVar
# presumably is a callable wrapper whose call returns the current value
# (inferred from the `dprint()(...)` / `http()` call sites) -- confirm in utils.
dprint = ContextVar('dprint')

# Shared aiohttp client session, set by amain before the server starts.
http = ContextVar('http')
|
|
|
|
async def handler(request):
    """Top-level request handler registered with the aiohttp server.

    Delegates to ``_handler`` and returns the response it produced. An
    upstream that cannot be connected to is reported to the client as a 503.

    Raises:
        web.HTTPServiceUnavailable: if connecting to the upstream fails.
    """
    try:
        # aiohttp's low-level server requires the handler to return the
        # response object, even one that is already prepared and finished;
        # returning None makes it raise "Missing return statement".
        return await _handler(request)
    except aiohttp.ClientConnectorError:
        raise web.HTTPServiceUnavailable(
            text=f'{request.host} is temporarily unavailable. Try again later.',
        )
|
|
|
|
async def _handler(request):
    """Proxy *request* to its host's upstream and stream the response back.

    Looks up the host configuration by ``request.host`` and forwards the
    method, path, headers, query string and body to ``hconfig['upstream']``
    without following redirects. The upstream response is then streamed to
    the client either verbatim or rewritten by ``hconfig['replacer']``.

    Raises:
        web.HTTPNotFound: if ``request.host`` is not a configured host.

    Returns:
        The finished ``web.StreamResponse``.
    """
    try:
        hconfig = config['hosts'][request.host]
    except KeyError:
        # this should only happen if our reverse proxy is set to proxy everything to us
        raise web.HTTPNotFound()

    dprint.set(dprint_factory())

    async with http().request(
        request.method,
        # TODO support upstream unix sockets
        URL(hconfig['upstream']).with_path(request.path),
        headers=request.headers.copy(),
        params=request.query.copy(),
        data=await request.read() if request.can_read_body else None,
        # proxy redirects as-is
        allow_redirects=False,
    ) as upstream_resp:
        dprint()(request.method, request.host, request.path, '→', hconfig['upstream'])

        # Decide before sending headers: the decision affects Content-Length.
        replace = _should_replace(hconfig, request.path, upstream_resp.content_type)

        headers = upstream_resp.headers.copy()
        # We're not sending gzip here (presumably the aiohttp client already
        # decompressed the body), so drop Content-Encoding to avoid confusing
        # our client.
        was_encoded = headers.popall('Content-Encoding', None) is not None
        if was_encoded or replace:
            # The upstream Content-Length describes the compressed and/or
            # unreplaced body. Left in place, aiohttp would truncate a longer
            # body or leave the client waiting on a shorter one. Without it,
            # aiohttp picks chunked encoding or closes the connection instead.
            headers.popall('Content-Length', None)

        resp = web.StreamResponse(status=upstream_resp.status, headers=headers)
        await resp.prepare(request)

        if not replace:
            return await proxy_passthrough(upstream_resp, resp)

        dprint()('replacing', hconfig['replacer'].replacements)
        return await proxy_replace(hconfig, upstream_resp, resp)


def _should_replace(hconfig, path, content_type):
    """Return whether the body for *path* with *content_type* should be rewritten.

    Returns False (logging why via dprint) when replacement is disabled for
    the host, when *path* starts with one of the host's excluded routes, or
    when the host restricts MIME types and the base type isn't among them.
    """
    if not hconfig['replace']:
        dprint()('Not replacing for this host')
        return False

    for excluded_route in hconfig['excluded_routes']:
        if path.startswith(excluded_route):
            dprint()('Not replacing for', excluded_route)
            return False

    base_mime_type = _base_mime_type(content_type)
    if hconfig['mime_types'] and base_mime_type not in hconfig['mime_types']:
        dprint()(
            f'Not configured to replace for MIME type: {content_type}',
            f'(base MIME type: {base_mime_type})',
        )
        return False

    return True


def _base_mime_type(content_type):
    """Strip a ``+suffix`` subtype: application/activity+json → application/json.

    Values without a ``/`` or without a ``+`` in the subtype are returned as-is.
    """
    major, slash, subtype = content_type.partition('/')
    if not slash:
        return content_type
    _, plus, suffix = subtype.partition('+')
    if not plus:
        return content_type
    return f'{major}/{suffix}'
|
|
|
|
async def proxy_replace(hconfig, upstream_resp, resp):
    """Stream the upstream body to *resp*, rewriting it line by line.

    Each ``\\n``-terminated line (and a final unterminated remainder, if any)
    is passed on its own to ``hconfig['replacer'].replace``. Replacements
    therefore never see text that spans a line break.

    Lines are reassembled from raw chunks instead of being read with
    ``StreamReader.readline``. aiohttp's ``readline`` raises
    ``ValueError("Chunk too big")`` once a single line outgrows the reader's
    buffer limit, which large minified JSON/HTML bodies can easily do.

    Returns:
        *resp*, after its body has been terminated.
    """
    replace = hconfig['replacer'].replace
    # Bytes of the current line that has not yet seen its terminating newline.
    partial = bytearray()
    async for chunk in upstream_resp.content.iter_any():
        # `partial` never contains a newline, so only the new chunk is searched.
        cut = chunk.rfind(b'\n') + 1
        if not cut:
            # still inside one (possibly very long) line
            partial += chunk
            continue
        partial += chunk[:cut]
        complete, partial = bytes(partial), bytearray(chunk[cut:])
        # `complete` ends with b'\n'; split it into the same lines readline yields.
        await resp.write(b''.join(
            replace(line + b'\n') for line in complete[:-1].split(b'\n')
        ))

    # The body may end without a trailing newline.
    if partial:
        await resp.write(replace(bytes(partial)))

    return await finalize_resp(resp)
|
|
|
|
async def proxy_passthrough(upstream_resp, resp):
    """Copy the upstream body to *resp* unchanged, then finish the response.

    The body is forwarded in chunks of at most ``io.DEFAULT_BUFFER_SIZE`` bytes.

    Returns:
        *resp*, after its body has been terminated.
    """
    chunks = upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE)
    async for block in chunks:
        await resp.write(block)
    return await finalize_resp(resp)
|
|
|
|
async def finalize_resp(resp):
    """Terminate the body of *resp* and return that same response object.

    Used as the final step of proxy_replace and proxy_passthrough.
    """
    response = resp
    await response.write_eof()
    return response
|
|
|
|
async def amain():
    """Run the proxy server until cancelled.

    Binds to ``config['bind']`` after expanding environment variables. A value
    containing ``/`` is treated as a Unix socket path, which is then chmod-ed
    to ``config['socket_mod']`` (default 0o775). Anything else is a TCP host,
    used with ``config['port']`` (default 8080). A single shared
    ``aiohttp.ClientSession`` is published through the ``http`` context
    variable for the request handlers.
    """
    srv = web.Server(
        handler,
        # It's not our job to hand out 413s, that's the job of upstream and downstream
        request_factory=utils.make_unlimited_request,
    )
    runner = web.ServerRunner(srv)
    await runner.setup()

    bind = os.path.expandvars(config['bind'])
    if (bind_unix := '/' in bind):
        site = web.UnixSite(runner, bind)
    else:
        site = web.TCPSite(runner, host=bind, port=(port := config.get('port', 8080)))

    async with aiohttp.ClientSession() as http_:
        # Set before site.start(). Presumably handlers see this value because
        # asyncio tasks copy the context that is current when they are created.
        http.set(http_)
        try:
            await site.start()
            print('Listening on', bind if bind_unix else f'{bind}:{port}', end='...\n')
            if bind_unix:
                os.chmod(bind, config.get('socket_mod', 0o775))
            await asyncio.sleep(float('inf'))
        finally:
            # Shut the server down while the client session used by in-flight
            # handlers is still open; previously the runner was never cleaned up.
            await runner.cleanup()
|
|
|
|
def main():
    """Script entry point: run ``amain`` and exit quietly on Ctrl-C."""
    try:
        asyncio.run(amain())
    except KeyboardInterrupt:
        # Interrupting is the normal way to stop the server; skip the traceback.
        pass


if __name__ == '__main__':
    main()
|