support multiple patterns (closes #1)
This commit is contained in:
parent
97065e65b6
commit
af1622442b
|
@ -5,14 +5,15 @@ port = '8000'
|
||||||
# if using a unix socket, you can specify the file access mode here. it defaults to 0o775.
|
# if using a unix socket, you can specify the file access mode here. it defaults to 0o775.
|
||||||
socket_mod = 0o777
|
socket_mod = 0o777
|
||||||
# whether to enable request logging. disabled by default
|
# whether to enable request logging. disabled by default
|
||||||
debug = True
|
debug = true
|
||||||
|
|
||||||
# replace foo with bar for site1.example
|
# replace foo with bar and baz with quux for site1.example
|
||||||
[hosts."site1.example"]
|
[hosts."site1.example"]
|
||||||
upstream = 'http://localhost:3001'
|
upstream = 'http://localhost:3001'
|
||||||
# these can be regexes
|
replacements = [
|
||||||
pattern = 'foo'
|
{pattern = 'foo', repl = 'bar'},
|
||||||
repl = 'bar'
|
{pattern = 'baz', repl = 'quux'},
|
||||||
|
]
|
||||||
# which mime types to replace for. defaults to all mime types.
|
# which mime types to replace for. defaults to all mime types.
|
||||||
# `x/y` also matches `x/a+y`, `x/b+y`, `x/c+y` etc.
|
# `x/y` also matches `x/a+y`, `x/b+y`, `x/c+y` etc.
|
||||||
mime_types = ['text/html', 'application/json']
|
mime_types = ['text/html', 'application/json']
|
||||||
|
|
19
replacer.py
19
replacer.py
|
@ -12,9 +12,9 @@ import aiohttp
|
||||||
import contextlib
|
import contextlib
|
||||||
import contextvars
|
import contextvars
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
from aiohttp import web
|
|
||||||
import pytomlpp as toml
|
import pytomlpp as toml
|
||||||
from utils import build_dprint_factory, ContextVar
|
from aiohttp import web
|
||||||
|
from utils import build_dprint_factory, ContextVar, MultiReplacer
|
||||||
|
|
||||||
def load_config():
|
def load_config():
|
||||||
with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f:
|
with open(sys.argv[1] if len(sys.argv) > 1 else 'config.toml') as f:
|
||||||
|
@ -22,11 +22,14 @@ def load_config():
|
||||||
|
|
||||||
for host in config['hosts'].values():
|
for host in config['hosts'].values():
|
||||||
if not host.setdefault('replace', True): continue
|
if not host.setdefault('replace', True): continue
|
||||||
host['pattern_decoded'] = host['pattern']
|
|
||||||
host['pattern'] = re.compile(host['pattern'].encode())
|
|
||||||
host['repl_decoded'] = host['repl']
|
|
||||||
host['repl'] = host['repl'].encode()
|
|
||||||
host['mime_types'] = frozenset(host.get('mime_types', ()))
|
host['mime_types'] = frozenset(host.get('mime_types', ()))
|
||||||
|
|
||||||
|
host['replacer'] = MultiReplacer({
|
||||||
|
replacement['pattern'].encode(): replacement['repl'].encode()
|
||||||
|
for replacement
|
||||||
|
in host['replacements']
|
||||||
|
})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
host['excluded_routes'].sort(key=len, reverse=True)
|
host['excluded_routes'].sort(key=len, reverse=True)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -99,13 +102,13 @@ async def _handler(request):
|
||||||
)
|
)
|
||||||
return await proxy_passthrough(upstream_resp, resp)
|
return await proxy_passthrough(upstream_resp, resp)
|
||||||
|
|
||||||
dprint()('replacing', repr(hconfig['pattern_decoded']), 'with', repr(hconfig['repl_decoded']))
|
dprint()('replacing', hconfig['replacer'].replacements)
|
||||||
return await proxy_replace(hconfig, upstream_resp, resp)
|
return await proxy_replace(hconfig, upstream_resp, resp)
|
||||||
|
|
||||||
async def proxy_replace(hconfig, upstream_resp, resp):
|
async def proxy_replace(hconfig, upstream_resp, resp):
|
||||||
# iter_lines when
|
# iter_lines when
|
||||||
while (line := await upstream_resp.content.readline()):
|
while (line := await upstream_resp.content.readline()):
|
||||||
await resp.write(hconfig['pattern'].sub(hconfig['repl'], line))
|
await resp.write(hconfig['replacer'].replace(line))
|
||||||
|
|
||||||
return await finalize_resp(resp)
|
return await finalize_resp(resp)
|
||||||
|
|
||||||
|
|
15
utils.py
15
utils.py
|
@ -1,9 +1,11 @@
|
||||||
# SPDX-License-Identifier: AGPL-3.0-only
|
# SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
|
||||||
|
import re
|
||||||
import random
|
import random
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextvars
|
import contextvars
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
class ContextVar:
|
class ContextVar:
|
||||||
"""
|
"""
|
||||||
|
@ -16,6 +18,19 @@ class ContextVar:
|
||||||
def set(self, *args): return self._var.set(*args)
|
def set(self, *args): return self._var.set(*args)
|
||||||
def reset(self, *args): return self._var.reset(*args)
|
def reset(self, *args): return self._var.reset(*args)
|
||||||
|
|
||||||
|
class MultiReplacer:
|
||||||
|
__slots__ = frozenset({'replacements', '_pattern'})
|
||||||
|
|
||||||
|
def __init__(self, replacements: Dict[bytes, bytes]): # TODO support AnyStr
|
||||||
|
self.replacements = replacements
|
||||||
|
self._pattern = re.compile(b'(?:%b)' % b'|'.join(map(re.escape, self.replacements)))
|
||||||
|
|
||||||
|
def replace(self, s):
|
||||||
|
return self._pattern.sub(self._replacement, s)
|
||||||
|
|
||||||
|
def _replacement(self, match):
|
||||||
|
return self.replacements[match[0]]
|
||||||
|
|
||||||
def build_dprint_factory(debug: bool):
|
def build_dprint_factory(debug: bool):
|
||||||
if debug:
|
if debug:
|
||||||
def dprint_factory():
|
def dprint_factory():
|
||||||
|
|
Loading…
Reference in New Issue