diff --git a/requirements.txt b/requirements.txt
index fd89b83..2047d56 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,6 +21,7 @@ pydantic~=1.10.2
pyld~=2.0.3
pylibmc~=1.6.3
pymemcache~=4.0.0
+pytest-asyncio~=0.20.3
python-dateutil~=2.8.2
python-dotenv~=0.21.0
redis~=4.4.0
diff --git a/tests/users/models/test_identity.py b/tests/users/models/test_identity.py
index 182f2a9..27cd8c4 100644
--- a/tests/users/models/test_identity.py
+++ b/tests/users/models/test_identity.py
@@ -1,5 +1,6 @@
import pytest
from asgiref.sync import async_to_sync
+from pytest_httpx import HTTPXMock
from core.models import Config
from users.models import Domain, Identity, User
@@ -176,3 +177,57 @@ def test_fetch_actor(httpx_mock, config_system):
assert identity.image_uri == "https://example.com/image.jpg"
assert identity.summary == "<p>A test user</p>"
assert "ts-a-faaaake" in identity.public_key
+
+
+@pytest.mark.django_db
+@pytest.mark.asyncio
+async def test_fetch_webfinger_url(httpx_mock: HTTPXMock, config_system):
+    """
+    Ensures that we can deal with various kinds of webfinger URLs
+    """
+
+    # With no host-meta registered the probe fails, so it should be the default
+    assert (
+        await Identity.fetch_webfinger_url("example.com")
+        == "https://example.com/.well-known/webfinger?resource={uri}"
+    )
+
+    # Inject a host-meta directing it to a subdomain
+    httpx_mock.add_response(
+        url="https://example.com/.well-known/host-meta",
+        text="""<?xml version="1.0" encoding="UTF-8"?>
+        <XRD xmlns="http://docs.oasis-open.org/ns/xri/xrd-1.0">
+        <Link rel="lrdd" template="https://fedi.example.com/.well-known/webfinger?resource={uri}"/>
+        </XRD>""",
+    )
+    assert (
+        await Identity.fetch_webfinger_url("example.com")
+        == "https://fedi.example.com/.well-known/webfinger?resource={uri}"
+    )
+
+    # Inject a host-meta directing it to a different URL format
+    httpx_mock.add_response(
+        url="https://example.com/.well-known/host-meta",
+        text="""<?xml version="1.0" encoding="UTF-8"?>
+        <XRD xmlns="http://docs.oasis-open.org/ns/xri/xrd-1.0">
+        <Link rel="lrdd" template="https://example.com/amazing-webfinger?query={uri}"/>
+        </XRD>""",
+    )
+    assert (
+        await Identity.fetch_webfinger_url("example.com")
+        == "https://example.com/amazing-webfinger?query={uri}"
+    )
+
+    # Inject a host-meta directing it to a different url THAT SUPPORTS XML ONLY
+    # (we want to ignore that one and fall back to the default URL)
+    httpx_mock.add_response(
+        url="https://example.com/.well-known/host-meta",
+        text="""<?xml version="1.0" encoding="UTF-8"?>
+        <XRD xmlns="http://docs.oasis-open.org/ns/xri/xrd-1.0">
+        <Link rel="lrdd" type="application/xrd+xml" template="https://xml.example.com/.well-known/webfinger?resource={uri}"/>
+        </XRD>""",
+    )
+    assert (
+        await Identity.fetch_webfinger_url("example.com")
+        == "https://example.com/.well-known/webfinger?resource={uri}"
+    )
diff --git a/users/models/identity.py b/users/models/identity.py
index 39c6a3c..a0cdf28 100644
--- a/users/models/identity.py
+++ b/users/models/identity.py
@@ -601,14 +601,11 @@ class Identity(StatorModel):
### Actor/Webfinger fetching ###
@classmethod
- async def fetch_webfinger(cls, handle: str) -> tuple[str | None, str | None]:
+ async def fetch_webfinger_url(cls, domain: str) -> str:
"""
- Given a username@domain handle, returns a tuple of
- (actor uri, canonical handle) or None, None if it does not resolve.
+ Given a domain (hostname), returns the webfinger URL template to use
+ (the caller fills it in via str.format(uri=...)), based on probing host-meta.
"""
- domain = handle.split("@")[1].lower()
- webfinger_url = f"https://{domain}/.well-known/webfinger?resource={{uri}}"
-
async with httpx.AsyncClient(
timeout=settings.SETUP.REMOTE_TIMEOUT,
headers={"User-Agent": settings.TAKAHE_USER_AGENT},
@@ -626,13 +623,29 @@ class Identity(StatorModel):
if response.status_code == 200 and response.content.strip():
tree = etree.fromstring(response.content)
template = tree.xpath(
- "string(.//*[local-name() = 'Link' and @rel='lrdd']/@template)"
+ "string(.//*[local-name() = 'Link' and @rel='lrdd' and (not(@type) or @type='application/jrd+json')]/@template)"
)
if template:
- webfinger_url = template
+ return template  # host-meta gave an untyped or JRD+JSON lrdd template
except (httpx.RequestError, etree.ParseError):
pass
+ return f"https://{domain}/.well-known/webfinger?resource={{uri}}"
+
+ @classmethod
+ async def fetch_webfinger(cls, handle: str) -> tuple[str | None, str | None]:
+ """
+ Given a username@domain handle, returns a tuple of
+ (actor uri, canonical handle) or None, None if it does not resolve.
+ """
+ domain = handle.split("@")[1].lower()
+ webfinger_url = await cls.fetch_webfinger_url(domain)
+
+ # Query the URL template returned by fetch_webfinger_url for this handle
+ async with httpx.AsyncClient(
+ timeout=settings.SETUP.REMOTE_TIMEOUT,
+ headers={"User-Agent": settings.TAKAHE_USER_AGENT},
+ ) as client:
try:
response = await client.get(
webfinger_url.format(uri=f"acct:{handle}"),
@@ -640,12 +653,14 @@ class Identity(StatorModel):
                     headers={"Accept": "application/json"},
                 )
                 response.raise_for_status()
             except httpx.HTTPError as ex:
+                # Keep catching HTTPError: it covers transport failures (RequestError)
+                # and the HTTPStatusError from raise_for_status(), which has .response
                 response = getattr(ex, "response", None)
                 if (
                     response
                     and response.status_code < 500
-                    and response.status_code not in [404, 410]
+                    and response.status_code not in [401, 403, 404, 410]
                 ):
                     raise ValueError(
                         f"Client error fetching webfinger: {response.status_code}",