75 lines
2.1 KiB
Python
75 lines
2.1 KiB
Python
|
#!/usr/bin/env python3
|
||
|
# SPDX-License-Identifier: AGPL-3.0-only
|
||
|
|
||
|
import io
|
||
|
import os
|
||
|
import asyncio
|
||
|
import aiohttp
|
||
|
import contextlib
|
||
|
import contextvars
|
||
|
from yarl import URL
|
||
|
from aiohttp import web
|
||
|
|
||
|
class ContextVar:
    """Thin wrapper around :class:`contextvars.ContextVar`.

    Behaves like the stdlib class, except that the wrapper itself is
    callable: ``var()`` is shorthand for ``var.get()`` and accepts the same
    optional fallback argument.
    """

    __slots__ = frozenset({'_var'})

    def __init__(self, *args, **kwargs):
        # All construction arguments (name, optional default=) go straight
        # to the wrapped stdlib variable.
        self._var = contextvars.ContextVar(*args, **kwargs)

    def get(self, *args):
        """Return the current value, or the fallback/default; see ContextVar.get."""
        return self._var.get(*args)

    # Calling the wrapper is exactly the same as calling get().
    __call__ = get

    def set(self, *args):
        """Set the value in the current context and return the reset Token."""
        return self._var.set(*args)

    def reset(self, *args):
        """Restore the value that was current before the matching set()."""
        return self._var.reset(*args)
|
||
|
|
||
|
# Configuration, read from the environment once at import time.
# Byte strings searched for / substituted in proxied text/html bodies.
pattern, repl = (os.environ[name].encode() for name in ('PATTERN', 'REPL'))
# Upstream "host[:port]"; requests are always sent to it over plain HTTP.
upstream = f"http://{os.environ['UPSTREAM']}"
# Listen address: a value containing '/' is used as a Unix socket path,
# anything else as a TCP host name/address.
bind = os.environ['BIND']
port = int(os.getenv('PORT', '8080'))
# Holds the shared aiohttp.ClientSession that request handlers read.
http = ContextVar('http')
|
||
|
|
||
|
# Hop-by-hop headers (RFC 9110 section 7.6.1) describe a single connection
# and must not be forwarded by a proxy in either direction.
_HOP_BY_HOP = frozenset({
    'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
    'proxy-connection', 'te', 'trailer', 'trailers', 'transfer-encoding',
    'upgrade',
})


def _filtered_headers(headers, drop=()):
    """Return *headers* as ``(name, value)`` pairs without hop-by-hop fields.

    :param headers: a case-insensitive multidict (``request.headers`` or a
        client response's ``headers``). Repeated fields such as
        ``Set-Cookie`` are kept as separate pairs.
    :param drop: additional lower-case header names to omit.
    """
    excluded = set(_HOP_BY_HOP).union(drop)
    # Any header listed in Connection is hop-by-hop for this hop as well.
    excluded.update(
        token.strip().lower()
        for token in headers.get('Connection', '').split(',')
        if token.strip()
    )
    return [(name, value) for name, value in headers.items()
            if name.lower() not in excluded]


async def _replace_stream(chunks, old, new):
    """Yield bytes from the async iterable *chunks* with *old* replaced by *new*.

    Matches that straddle chunk boundaries are still found: up to
    ``len(old) - 1`` trailing bytes are held back until the next chunk
    arrives. The concatenated output equals
    ``b''.join(all_chunks).replace(old, new)``. Memory use stays bounded,
    unlike reading line by line. An empty *old* passes data through unchanged.
    """
    if not old:
        async for chunk in chunks:
            yield chunk
        return
    keep = len(old) - 1
    pending = b''
    async for chunk in chunks:
        buf = pending + chunk
        parts = []
        start = 0
        while (hit := buf.find(old, start)) != -1:
            parts.extend((buf[start:hit], new))
            start = hit + len(old)
        # Bytes at or after `cut` could start a match that the next chunk
        # completes, so hold them back.
        cut = max(start, len(buf) - keep)
        parts.append(buf[start:cut])
        pending = buf[cut:]
        if out := b''.join(parts):
            yield out
    if pending:
        yield pending


async def handler(request):
    """Reverse-proxy *request* to ``upstream`` and stream the reply back.

    The method, path, query, request headers (minus hop-by-hop ones) and
    body are forwarded, and redirects are passed through as-is. For
    ``text/html`` responses, every occurrence of ``pattern`` in the body is
    replaced by ``repl``. Other bodies are relayed chunk by chunk.
    """
    # Drop the client's Accept-Encoding so aiohttp advertises only the
    # codings it can decode itself. Drop Content-Length because aiohttp
    # recomputes it from `data`.
    req_headers = _filtered_headers(
        request.headers, drop=('accept-encoding', 'content-length'))
    async with http().request(
        request.method,
        # Keep the path percent-encoded: request.path is decoded, and
        # re-encoding it would e.g. turn %2F into a literal '/'.
        URL(upstream).with_path(request.rel_url.raw_path, encoded=True),
        headers=req_headers,
        params=request.query.copy(),
        data=await request.read() if request.can_read_body else None,
        # proxy redirects as-is
        allow_redirects=False,
    ) as upstream_resp:
        # With the session's default auto_decompress, the body we read is
        # already decoded, and HTML may also be rewritten. So upstream's
        # Content-Encoding and Content-Length no longer describe what we
        # send; let aiohttp frame the body itself.
        resp = web.StreamResponse(
            status=upstream_resp.status,
            headers=_filtered_headers(
                upstream_resp.headers,
                drop=('content-encoding', 'content-length')),
        )
        await resp.prepare(request)

        body = upstream_resp.content.iter_chunked(io.DEFAULT_BUFFER_SIZE)
        if upstream_resp.content_type == 'text/html':
            body = _replace_stream(body, pattern, repl)
        async for chunk in body:
            await resp.write(chunk)

        await resp.write_eof()
        return resp
|
||
|
|
||
|
async def amain():
    """Serve the proxy on ``bind`` (Unix socket or TCP) until cancelled.

    One aiohttp ClientSession is shared by all requests through the ``http``
    context variable. The server is always shut down with
    ``runner.cleanup()``, including on cancellation. This happens before the
    session that request handlers use is closed.
    """
    runner = web.ServerRunner(web.Server(handler))
    async with aiohttp.ClientSession() as session:
        # NOTE: set before site.start(); handler tasks are expected to pick
        # the session up via context propagation from this point on.
        http.set(session)
        await runner.setup()
        try:
            # A '/' in BIND means a filesystem path, i.e. a Unix-domain socket.
            if '/' in bind:
                site = web.UnixSite(runner, bind)
            else:
                site = web.TCPSite(runner, host=bind, port=port)
            await site.start()
            await asyncio.sleep(float('inf'))
        finally:
            # Stop the listening sites and shut the server down before the
            # `async with` above closes the client session.
            await runner.cleanup()
|
||
|
|
||
|
def main():
    """Run :func:`amain` until it finishes or the user presses Ctrl-C."""
    try:
        asyncio.run(amain())
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the proxy; exit quietly.
        pass


if __name__ == '__main__':
    main()
|
||
|
|