Basic protection against invalid domain names (#680)
This commit is contained in:
parent
d07482f5a8
commit
5d508a17ec
|
@ -3,6 +3,36 @@ import pytest
|
|||
from users.models import Domain
|
||||
|
||||
|
||||
def test_valid_domain():
|
||||
"""
|
||||
Tests that a valid domain is valid
|
||||
"""
|
||||
|
||||
assert Domain.is_valid_domain("example.com")
|
||||
assert Domain.is_valid_domain("xn----gtbspbbmkef.xn--p1ai")
|
||||
assert Domain.is_valid_domain("underscore_subdomain.example.com")
|
||||
assert Domain.is_valid_domain("something.versicherung")
|
||||
assert Domain.is_valid_domain("11.com")
|
||||
assert Domain.is_valid_domain("a.cn")
|
||||
assert Domain.is_valid_domain("sub1.sub2.sample.co.uk")
|
||||
assert Domain.is_valid_domain("somerandomexample.xn--fiqs8s")
|
||||
assert not Domain.is_valid_domain("über.com")
|
||||
assert not Domain.is_valid_domain("example.com:4444")
|
||||
assert not Domain.is_valid_domain("example.-com")
|
||||
assert not Domain.is_valid_domain("foo@bar.com")
|
||||
assert not Domain.is_valid_domain("example.")
|
||||
assert not Domain.is_valid_domain("example.com.")
|
||||
assert not Domain.is_valid_domain("-example.com")
|
||||
assert not Domain.is_valid_domain("_example.com")
|
||||
assert not Domain.is_valid_domain("_example._com")
|
||||
assert not Domain.is_valid_domain("example_.com")
|
||||
assert not Domain.is_valid_domain("example")
|
||||
assert not Domain.is_valid_domain("a......b.com")
|
||||
assert not Domain.is_valid_domain("a.123")
|
||||
assert not Domain.is_valid_domain("123.123")
|
||||
assert not Domain.is_valid_domain("123.123.123.123")
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_recursive_block():
|
||||
"""
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
import ssl
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
@ -8,6 +9,7 @@ import httpx
|
|||
import pydantic
|
||||
import urlman
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
|
||||
from core.models import Config
|
||||
|
@ -53,6 +55,14 @@ class DomainStates(StateGraph):
|
|||
return cls.outdated
|
||||
|
||||
|
||||
def _domain_validator(value: str):
|
||||
if not Domain.is_valid_domain(value):
|
||||
raise ValidationError(
|
||||
"%(value)s is not a valid domain",
|
||||
params={"value": value},
|
||||
)
|
||||
|
||||
|
||||
class Domain(StatorModel):
|
||||
"""
|
||||
Represents a domain that a user can have an account on.
|
||||
|
@ -71,7 +81,9 @@ class Domain(StatorModel):
|
|||
display domains for now, until we start doing better probing.
|
||||
"""
|
||||
|
||||
domain = models.CharField(max_length=250, primary_key=True)
|
||||
domain = models.CharField(
|
||||
max_length=250, primary_key=True, validators=[_domain_validator]
|
||||
)
|
||||
service_domain = models.CharField(
|
||||
max_length=250,
|
||||
null=True,
|
||||
|
@ -119,6 +131,19 @@ class Domain(StatorModel):
|
|||
class Meta:
|
||||
indexes: list = []
|
||||
|
||||
@classmethod
|
||||
def is_valid_domain(cls, domain: str) -> bool:
|
||||
"""
|
||||
Check if a domain is valid, domain must be lowercase
|
||||
"""
|
||||
return (
|
||||
re.match(
|
||||
r"^(?:[a-z0-9](?:[a-z0-9-_]{0,61}[a-z0-9])?\.)+[a-z0-9][a-z0-9-_]{0,61}[a-z]$",
|
||||
domain,
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_remote_domain(cls, domain: str) -> "Domain":
|
||||
return cls.objects.get_or_create(domain=domain.lower(), local=False)[0]
|
||||
|
|
|
@ -18,6 +18,8 @@ def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity:
|
|||
domain = domain_instance.domain
|
||||
else:
|
||||
username, domain = handle.split("@", 1)
|
||||
if not Domain.is_valid_domain(domain):
|
||||
raise Http404("Invalid domain")
|
||||
# Resolve the domain to the display domain
|
||||
domain_instance = Domain.get_domain(domain)
|
||||
if domain_instance is None:
|
||||
|
|
Loading…
Reference in New Issue