move config to toml and add optional debug logging

This commit is contained in:
Kay Faraday 2021-09-07 07:09:18 +00:00
parent 980e421c4d
commit 695c7145af
4 changed files with 82 additions and 31 deletions

2
.gitignore vendored
View File

@ -1 +1,3 @@
__pycache__/
*.toml
!*.example.toml

58
app.py
View File

@ -4,61 +4,55 @@
import io
import os
import sys
import random
import asyncio
import aiohttp
import contextlib
import contextvars
from yarl import URL
from aiohttp import web
import pytomlpp as toml
from utils import build_dprint, ContextVar
class ContextVar:
"""
contextvars.ContextVar, but you can call the var to get its value, rather than calling var.get()
"""
__slots__ = frozenset({'_var'})
def __init__(self, *a, **kw): self._var = contextvars.ContextVar(*a, **kw)
def __call__(self, *args): return self._var.get(*args) # convenience
def get(self, *args): return self._var.get(*args)
def set(self, *args): return self._var.set(*args)
def reset(self, *args): return self._var.reset(*args)
with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f:
config = toml.load(f)
pattern = os.environ['PATTERN'].encode()
repl = os.environ['REPL'].encode()
# TODO support upstream unix sockets
upstream = os.environ['UPSTREAM']
bind = sys.argv[1]
socket_mod = int(os.environ.get('SOCKET_MODE', '775'), 8)
replace_for = os.environ.get('REPLACE_FOR_HOST')
try:
port = int(sys.argv[2])
except IndexError:
port = 8080
for host in config['hosts'].values():
if not host.get('replace', True): continue
host['pattern_decoded'] = host['pattern']
host['pattern'] = host['pattern'].encode()
host['repl_decoded'] = host['repl']
host['repl'] = host['repl'].encode()
dprint = build_dprint(config.get('debug'))
http = ContextVar('http')
async def handler(request):
async with http().request(
hconfig = config['hosts'][request.host]
async with dprint() as print, http().request(
request.method,
URL(upstream).with_path(request.path),
# TODO support upstream unix sockets
URL(hconfig['upstream']).with_path(request.path),
headers=request.headers.copy(),
params=request.query.copy(),
data=await request.read() if request.can_read_body else None,
# proxy redirects as-is
allow_redirects=False,
) as upstream_resp:
print('proceeding for', request.host, request.path, '', hconfig['upstream'])
headers = upstream_resp.headers.copy()
# we're not using gzip here so don't confuse our client
with contextlib.suppress(KeyError): del headers['Content-Encoding']
resp = web.StreamResponse(status=upstream_resp.status, headers=headers)
await resp.prepare(request)
if (
upstream_resp.content_type == 'text/html'
and (replace_for is None or request.host == replace_for)
):
if upstream_resp.content_type == 'text/html' and hconfig.get('replace', True):
# iter_lines when
print('replacing', repr(hconfig['pattern_decoded']), 'with', repr(hconfig['repl_decoded']))
while (line := await upstream_resp.content.readline()):
await resp.write(line.replace(pattern, repl))
await resp.write(line.replace(hconfig['pattern'], hconfig['repl']))
else:
print('not replacing')
async for chunk in upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE):
await resp.write(chunk)
@ -69,16 +63,18 @@ async def amain():
srv = web.Server(handler)
runner = web.ServerRunner(srv)
await runner.setup()
if '/' in bind:
bind = os.path.expandvars(config['bind'])
if (unix := '/' in bind):
site = web.UnixSite(runner, bind)
else:
site = web.TCPSite(runner, host=bind, port=port)
site = web.TCPSite(runner, host=bind, port=(port := config.get('port', 8080)))
async with aiohttp.ClientSession() as http_:
http.set(http_)
await site.start()
print('Listening on', bind if unix else f'{bind}:{port}', end='...\n')
if '/' in bind:
with contextlib.suppress(KeyError):
os.chmod(bind, socket_mod)
os.chmod(bind, config.get('socket_mod', 0o775))
await asyncio.sleep(float('inf'))
def main():

19
config.example.toml Normal file
View File

@ -0,0 +1,19 @@
bind = '127.0.0.1'
# you can also bind to a unix socket, and environment vars are supported:
# bind = '$RUNTIME_DIRECTORY/replacer.sock'
port = '8000'
# if using a unix socket, you can specify the file access mode here. it defaults to 0o775.
socket_mod = 0o777
# whether to enable request logging. disabled by default
debug = True
# replace foo with bar for site1.example
[hosts."site1.example"]
upstream = 'http://localhost:3001'
pattern = 'foo'
repl = 'bar'
# just pass through site2.example unmodified
[hosts."site2.example"]
upstream = 'http://localhost:3001'
replace = false

34
utils.py Normal file
View File

@ -0,0 +1,34 @@
# SPDX-License-Identifier: AGPL-3.0-only
import random
import contextlib
import contextvars
class ContextVar:
"""
contextvars.ContextVar, but you can call the var to get its value, rather than calling var.get()
"""
__slots__ = frozenset({'_var'})
def __init__(self, *a, **kw): self._var = contextvars.ContextVar(*a, **kw)
def __call__(self, *args): return self._var.get(*args) # convenience
def get(self, *args): return self._var.get(*args)
def set(self, *args): return self._var.set(*args)
def reset(self, *args): return self._var.reset(*args)
class asyncnullcontext(contextlib.AbstractAsyncContextManager):
async def __aenter__(self): return None
async def __aexit__(self, *excinfo): return None
def build_dprint(debug: bool):
if debug:
class dprint(asyncnullcontext):
# 8 byte pseudo-random request tag
def __init__(self): self.req_id = ''.join(random.choices('0123456789abcdef', k=8 * 2))
def print(self, *args, **kwargs): print(f'[{self.req_id}]', *args, **kwargs)
async def __aenter__(self): return self.print
else:
class dprint(asyncnullcontext):
async def __aenter__(self):
return lambda *args, **kwargs: None
return dprint