move config to toml and add optional debug logging

This commit is contained in:
Kay Faraday 2021-09-07 07:09:18 +00:00
parent 980e421c4d
commit 695c7145af
4 changed files with 82 additions and 31 deletions

2
.gitignore vendored
View File

@ -1 +1,3 @@
__pycache__/ __pycache__/
*.toml
!*.example.toml

58
app.py
View File

@ -4,61 +4,55 @@
import io import io
import os import os
import sys import sys
import random
import asyncio import asyncio
import aiohttp import aiohttp
import contextlib import contextlib
import contextvars import contextvars
from yarl import URL from yarl import URL
from aiohttp import web from aiohttp import web
import pytomlpp as toml
from utils import build_dprint, ContextVar
class ContextVar: with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f:
""" config = toml.load(f)
contextvars.ContextVar, but you can call the var to get its value, rather than calling var.get()
"""
__slots__ = frozenset({'_var'})
def __init__(self, *a, **kw): self._var = contextvars.ContextVar(*a, **kw)
def __call__(self, *args): return self._var.get(*args) # convenience
def get(self, *args): return self._var.get(*args)
def set(self, *args): return self._var.set(*args)
def reset(self, *args): return self._var.reset(*args)
pattern = os.environ['PATTERN'].encode() for host in config['hosts'].values():
repl = os.environ['REPL'].encode() if not host.get('replace', True): continue
# TODO support upstream unix sockets host['pattern_decoded'] = host['pattern']
upstream = os.environ['UPSTREAM'] host['pattern'] = host['pattern'].encode()
bind = sys.argv[1] host['repl_decoded'] = host['repl']
socket_mod = int(os.environ.get('SOCKET_MODE', '775'), 8) host['repl'] = host['repl'].encode()
replace_for = os.environ.get('REPLACE_FOR_HOST')
try: dprint = build_dprint(config.get('debug'))
port = int(sys.argv[2])
except IndexError:
port = 8080
http = ContextVar('http') http = ContextVar('http')
async def handler(request): async def handler(request):
async with http().request( hconfig = config['hosts'][request.host]
async with dprint() as print, http().request(
request.method, request.method,
URL(upstream).with_path(request.path), # TODO support upstream unix sockets
URL(hconfig['upstream']).with_path(request.path),
headers=request.headers.copy(), headers=request.headers.copy(),
params=request.query.copy(), params=request.query.copy(),
data=await request.read() if request.can_read_body else None, data=await request.read() if request.can_read_body else None,
# proxy redirects as-is # proxy redirects as-is
allow_redirects=False, allow_redirects=False,
) as upstream_resp: ) as upstream_resp:
print('proceeding for', request.host, request.path, '', hconfig['upstream'])
headers = upstream_resp.headers.copy() headers = upstream_resp.headers.copy()
# we're not using gzip here so don't confuse our client # we're not using gzip here so don't confuse our client
with contextlib.suppress(KeyError): del headers['Content-Encoding'] with contextlib.suppress(KeyError): del headers['Content-Encoding']
resp = web.StreamResponse(status=upstream_resp.status, headers=headers) resp = web.StreamResponse(status=upstream_resp.status, headers=headers)
await resp.prepare(request) await resp.prepare(request)
if ( if upstream_resp.content_type == 'text/html' and hconfig.get('replace', True):
upstream_resp.content_type == 'text/html'
and (replace_for is None or request.host == replace_for)
):
# iter_lines when # iter_lines when
print('replacing', repr(hconfig['pattern_decoded']), 'with', repr(hconfig['repl_decoded']))
while (line := await upstream_resp.content.readline()): while (line := await upstream_resp.content.readline()):
await resp.write(line.replace(pattern, repl)) await resp.write(line.replace(hconfig['pattern'], hconfig['repl']))
else: else:
print('not replacing')
async for chunk in upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE): async for chunk in upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE):
await resp.write(chunk) await resp.write(chunk)
@ -69,16 +63,18 @@ async def amain():
srv = web.Server(handler) srv = web.Server(handler)
runner = web.ServerRunner(srv) runner = web.ServerRunner(srv)
await runner.setup() await runner.setup()
if '/' in bind: bind = os.path.expandvars(config['bind'])
if (unix := '/' in bind):
site = web.UnixSite(runner, bind) site = web.UnixSite(runner, bind)
else: else:
site = web.TCPSite(runner, host=bind, port=port) site = web.TCPSite(runner, host=bind, port=(port := config.get('port', 8080)))
async with aiohttp.ClientSession() as http_: async with aiohttp.ClientSession() as http_:
http.set(http_) http.set(http_)
await site.start() await site.start()
print('Listening on', bind if unix else f'{bind}:{port}', end='...\n')
if '/' in bind: if '/' in bind:
with contextlib.suppress(KeyError): with contextlib.suppress(KeyError):
os.chmod(bind, socket_mod) os.chmod(bind, config.get('socket_mod', 0o775))
await asyncio.sleep(float('inf')) await asyncio.sleep(float('inf'))
def main(): def main():

19
config.example.toml Normal file
View File

@ -0,0 +1,19 @@
bind = '127.0.0.1'
# you can also bind to a unix socket, and environment vars are supported:
# bind = '$RUNTIME_DIRECTORY/replacer.sock'
port = '8000'
# if using a unix socket, you can specify the file access mode here. it defaults to 0o775.
socket_mod = 0o777
# whether to enable request logging. disabled by default
debug = True
# replace foo with bar for site1.example
[hosts."site1.example"]
upstream = 'http://localhost:3001'
pattern = 'foo'
repl = 'bar'
# just pass through site2.example unmodified
[hosts."site2.example"]
upstream = 'http://localhost:3001'
replace = false

34
utils.py Normal file
View File

@ -0,0 +1,34 @@
# SPDX-License-Identifier: AGPL-3.0-only
import random
import contextlib
import contextvars
class ContextVar:
"""
contextvars.ContextVar, but you can call the var to get its value, rather than calling var.get()
"""
__slots__ = frozenset({'_var'})
def __init__(self, *a, **kw): self._var = contextvars.ContextVar(*a, **kw)
def __call__(self, *args): return self._var.get(*args) # convenience
def get(self, *args): return self._var.get(*args)
def set(self, *args): return self._var.set(*args)
def reset(self, *args): return self._var.reset(*args)
class asyncnullcontext(contextlib.AbstractAsyncContextManager):
async def __aenter__(self): return None
async def __aexit__(self, *excinfo): return None
def build_dprint(debug: bool):
if debug:
class dprint(asyncnullcontext):
# 8 byte pseudo-random request tag
def __init__(self): self.req_id = ''.join(random.choices('0123456789abcdef', k=8 * 2))
def print(self, *args, **kwargs): print(f'[{self.req_id}]', *args, **kwargs)
async def __aenter__(self): return self.print
else:
class dprint(asyncnullcontext):
async def __aenter__(self):
return lambda *args, **kwargs: None
return dprint