support multiple patterns (closes #1)

This commit is contained in:
Kay Faraday 2022-06-12 01:53:08 +00:00
parent 97065e65b6
commit af1622442b
3 changed files with 32 additions and 13 deletions

View File

@ -5,14 +5,15 @@ port = '8000'
# if using a unix socket, you can specify the file access mode here. it defaults to 0o775. # if using a unix socket, you can specify the file access mode here. it defaults to 0o775.
socket_mod = 0o777 socket_mod = 0o777
# whether to enable request logging. disabled by default # whether to enable request logging. disabled by default
debug = True debug = true
# replace foo with bar for site1.example # replace foo with bar and baz with quux for site1.example
[hosts."site1.example"] [hosts."site1.example"]
upstream = 'http://localhost:3001' upstream = 'http://localhost:3001'
# these can be regexes replacements = [
pattern = 'foo' {pattern = 'foo', repl = 'bar'},
repl = 'bar' {pattern = 'baz', repl = 'quux'},
]
# which mime types to replace for. defaults to all mime types. # which mime types to replace for. defaults to all mime types.
# `x/y` also matches `x/a+y`, `x/b+y`, `x/c+y` etc. # `x/y` also matches `x/a+y`, `x/b+y`, `x/c+y` etc.
mime_types = ['text/html', 'application/json'] mime_types = ['text/html', 'application/json']

View File

@ -12,9 +12,9 @@ import aiohttp
import contextlib import contextlib
import contextvars import contextvars
from yarl import URL from yarl import URL
from aiohttp import web
import pytomlpp as toml import pytomlpp as toml
from utils import build_dprint_factory, ContextVar from aiohttp import web
from utils import build_dprint_factory, ContextVar, MultiReplacer
def load_config(): def load_config():
with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f: with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f:
@ -22,11 +22,14 @@ def load_config():
for host in config['hosts'].values(): for host in config['hosts'].values():
if not host.setdefault('replace', True): continue if not host.setdefault('replace', True): continue
host['pattern_decoded'] = host['pattern']
host['pattern'] = re.compile(host['pattern'].encode())
host['repl_decoded'] = host['repl']
host['repl'] = host['repl'].encode()
host['mime_types'] = frozenset(host.get('mime_types', ())) host['mime_types'] = frozenset(host.get('mime_types', ()))
host['replacer'] = MultiReplacer({
replacement['pattern'].encode(): replacement['repl'].encode()
for replacement
in host['replacements']
})
try: try:
host['excluded_routes'].sort(key=len, reverse=True) host['excluded_routes'].sort(key=len, reverse=True)
except KeyError: except KeyError:
@ -99,13 +102,13 @@ async def _handler(request):
) )
return await proxy_passthrough(upstream_resp, resp) return await proxy_passthrough(upstream_resp, resp)
dprint()('replacing', repr(hconfig['pattern_decoded']), 'with', repr(hconfig['repl_decoded'])) dprint()('replacing', hconfig['replacer'].replacements)
return await proxy_replace(hconfig, upstream_resp, resp) return await proxy_replace(hconfig, upstream_resp, resp)
async def proxy_replace(hconfig, upstream_resp, resp): async def proxy_replace(hconfig, upstream_resp, resp):
# iter_lines when # iter_lines when
while (line := await upstream_resp.content.readline()): while (line := await upstream_resp.content.readline()):
await resp.write(hconfig['pattern'].sub(hconfig['repl'], line)) await resp.write(hconfig['replacer'].replace(line))
return await finalize_resp(resp) return await finalize_resp(resp)

View File

@ -1,9 +1,11 @@
# SPDX-License-Identifier: AGPL-3.0-only # SPDX-License-Identifier: AGPL-3.0-only
import re
import random import random
import asyncio import asyncio
import contextvars import contextvars
from aiohttp import web from aiohttp import web
from typing import Dict
class ContextVar: class ContextVar:
""" """
@ -16,6 +18,19 @@ class ContextVar:
def set(self, *args): return self._var.set(*args) def set(self, *args): return self._var.set(*args)
def reset(self, *args): return self._var.reset(*args) def reset(self, *args): return self._var.reset(*args)
class MultiReplacer:
__slots__ = frozenset({'replacements', '_pattern'})
def __init__(self, replacements: Dict[bytes, bytes]): # TODO support AnyStr
self.replacements = replacements
self._pattern = re.compile(b'(?:%b)' % b'|'.join(map(re.escape, self.replacements)))
def replace(self, s):
return self._pattern.sub(self._replacement, s)
def _replacement(self, match):
return self.replacements[match[0]]
def build_dprint_factory(debug: bool): def build_dprint_factory(debug: bool):
if debug: if debug:
def dprint_factory(): def dprint_factory():