Basic protection against invalid domain names (#680)

This commit is contained in:
Henri Dickson 2023-12-13 04:04:41 -05:00 committed by GitHub
parent d07482f5a8
commit 5d508a17ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 1 deletions

View File

@ -3,6 +3,36 @@ import pytest
from users.models import Domain from users.models import Domain
def test_valid_domain():
    """
    Checks Domain.is_valid_domain against names it must accept and
    names it must reject
    """
    # Names the validator is expected to accept.
    accepted = (
        "example.com",
        "xn----gtbspbbmkef.xn--p1ai",
        "underscore_subdomain.example.com",
        "something.versicherung",
        "11.com",
        "a.cn",
        "sub1.sub2.sample.co.uk",
        "somerandomexample.xn--fiqs8s",
    )
    # Names the validator is expected to reject: non-ASCII, ports,
    # user@host forms, trailing dots, bad label edges, bare labels,
    # empty labels and numeric final labels.
    rejected = (
        "über.com",
        "example.com:4444",
        "example.-com",
        "foo@bar.com",
        "example.",
        "example.com.",
        "-example.com",
        "_example.com",
        "_example._com",
        "example_.com",
        "example",
        "a......b.com",
        "a.123",
        "123.123",
        "123.123.123.123",
    )

    for name in accepted:
        assert Domain.is_valid_domain(name)
    for name in rejected:
        assert not Domain.is_valid_domain(name)
@pytest.mark.django_db @pytest.mark.django_db
def test_recursive_block(): def test_recursive_block():
""" """

View File

@ -1,5 +1,6 @@
import json import json
import logging import logging
import re
import ssl import ssl
from functools import cached_property from functools import cached_property
from typing import Optional from typing import Optional
@ -8,6 +9,7 @@ import httpx
import pydantic import pydantic
import urlman import urlman
from django.conf import settings from django.conf import settings
from django.core.exceptions import ValidationError
from django.db import models from django.db import models
from core.models import Config from core.models import Config
@ -53,6 +55,14 @@ class DomainStates(StateGraph):
return cls.outdated return cls.outdated
def _domain_validator(value: str):
    """
    Field validator: raises ValidationError when Domain.is_valid_domain
    rejects the given value, and returns None otherwise.
    """
    # Domain is looked up at call time, so it is fine that the class is
    # defined further down the module.
    if Domain.is_valid_domain(value):
        return
    message = "%(value)s is not a valid domain"
    raise ValidationError(message, params={"value": value})
class Domain(StatorModel): class Domain(StatorModel):
""" """
Represents a domain that a user can have an account on. Represents a domain that a user can have an account on.
@ -71,7 +81,9 @@ class Domain(StatorModel):
display domains for now, until we start doing better probing. display domains for now, until we start doing better probing.
""" """
domain = models.CharField(max_length=250, primary_key=True) domain = models.CharField(
max_length=250, primary_key=True, validators=[_domain_validator]
)
service_domain = models.CharField( service_domain = models.CharField(
max_length=250, max_length=250,
null=True, null=True,
@ -119,6 +131,19 @@ class Domain(StatorModel):
class Meta: class Meta:
indexes: list = [] indexes: list = []
@classmethod
def is_valid_domain(cls, domain: str) -> bool:
    """
    Check if a domain is valid, domain must be lowercase

    The name must consist of at least two dot-separated labels:

    * every label except the last is 1-63 characters, starts and ends
      with a lowercase letter or digit, and may contain ``-`` or ``_``
      in between;
    * the last label is 2-63 characters, starts with a lowercase letter
      or digit and ends with a lowercase letter (so purely numeric final
      labels, and therefore dotted IPv4 addresses, are rejected).

    Returns True when the whole string matches, False otherwise.
    """
    # fullmatch (rather than match + "$") is deliberate: "$" also matches
    # just before a trailing newline, which would let "example.com\n"
    # through as a valid domain.
    return (
        re.fullmatch(
            r"(?:[a-z0-9](?:[a-z0-9_-]{0,61}[a-z0-9])?\.)+[a-z0-9][a-z0-9_-]{0,61}[a-z]",
            domain,
        )
        is not None
    )
@classmethod @classmethod
def get_remote_domain(cls, domain: str) -> "Domain": def get_remote_domain(cls, domain: str) -> "Domain":
return cls.objects.get_or_create(domain=domain.lower(), local=False)[0] return cls.objects.get_or_create(domain=domain.lower(), local=False)[0]

View File

@ -18,6 +18,8 @@ def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity:
domain = domain_instance.domain domain = domain_instance.domain
else: else:
username, domain = handle.split("@", 1) username, domain = handle.split("@", 1)
if not Domain.is_valid_domain(domain):
raise Http404("Invalid domain")
# Resolve the domain to the display domain # Resolve the domain to the display domain
domain_instance = Domain.get_domain(domain) domain_instance = Domain.get_domain(domain)
if domain_instance is None: if domain_instance is None: