support multiple patterns (closes #1)

This commit is contained in:
Kay Faraday 2022-06-12 01:53:08 +00:00
parent 97065e65b6
commit af1622442b
3 changed files with 32 additions and 13 deletions

View File

@ -5,14 +5,15 @@ port = '8000'
# if using a unix socket, you can specify the file access mode here. it defaults to 0o775.
socket_mod = 0o777
# whether to enable request logging. disabled by default
debug = True
debug = true
# replace foo with bar for site1.example
# replace foo with bar and baz with quux for site1.example
[hosts."site1.example"]
upstream = 'http://localhost:3001'
# these can be regexes
pattern = 'foo'
repl = 'bar'
replacements = [
{pattern = 'foo', repl = 'bar'},
{pattern = 'baz', repl = 'quux'},
]
# which mime types to replace for. defaults to all mime types.
# `x/y` also matches `x/a+y`, `x/b+y`, `x/c+y` etc.
mime_types = ['text/html', 'application/json']

View File

@ -12,9 +12,9 @@ import aiohttp
import contextlib
import contextvars
from yarl import URL
from aiohttp import web
import pytomlpp as toml
from utils import build_dprint_factory, ContextVar
from aiohttp import web
from utils import build_dprint_factory, ContextVar, MultiReplacer
def load_config():
with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f:
@ -22,11 +22,14 @@ def load_config():
for host in config['hosts'].values():
if not host.setdefault('replace', True): continue
host['pattern_decoded'] = host['pattern']
host['pattern'] = re.compile(host['pattern'].encode())
host['repl_decoded'] = host['repl']
host['repl'] = host['repl'].encode()
host['mime_types'] = frozenset(host.get('mime_types', ()))
host['replacer'] = MultiReplacer({
replacement['pattern'].encode(): replacement['repl'].encode()
for replacement
in host['replacements']
})
try:
host['excluded_routes'].sort(key=len, reverse=True)
except KeyError:
@ -99,13 +102,13 @@ async def _handler(request):
)
return await proxy_passthrough(upstream_resp, resp)
dprint()('replacing', repr(hconfig['pattern_decoded']), 'with', repr(hconfig['repl_decoded']))
dprint()('replacing', hconfig['replacer'].replacements)
return await proxy_replace(hconfig, upstream_resp, resp)
async def proxy_replace(hconfig, upstream_resp, resp):
# iter_lines when
while (line := await upstream_resp.content.readline()):
await resp.write(hconfig['pattern'].sub(hconfig['repl'], line))
await resp.write(hconfig['replacer'].replace(line))
return await finalize_resp(resp)

View File

@ -1,9 +1,11 @@
# SPDX-License-Identifier: AGPL-3.0-only
import re
import random
import asyncio
import contextvars
from aiohttp import web
from typing import Dict
class ContextVar:
"""
@ -16,6 +18,19 @@ class ContextVar:
def set(self, *args): return self._var.set(*args)
def reset(self, *args): return self._var.reset(*args)
class MultiReplacer:
__slots__ = frozenset({'replacements', '_pattern'})
def __init__(self, replacements: Dict[bytes, bytes]): # TODO support AnyStr
self.replacements = replacements
self._pattern = re.compile(b'(?:%b)' % b'|'.join(map(re.escape, self.replacements)))
def replace(self, s):
return self._pattern.sub(self._replacement, s)
def _replacement(self, match):
return self.replacements[match[0]]
def build_dprint_factory(debug: bool):
if debug:
def dprint_factory():