Basic protection against invalid domain names (#680)
This commit is contained in:
parent
d07482f5a8
commit
5d508a17ec
|
@ -3,6 +3,36 @@ import pytest
|
||||||
from users.models import Domain
|
from users.models import Domain
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_domain():
|
||||||
|
"""
|
||||||
|
Tests that a valid domain is valid
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert Domain.is_valid_domain("example.com")
|
||||||
|
assert Domain.is_valid_domain("xn----gtbspbbmkef.xn--p1ai")
|
||||||
|
assert Domain.is_valid_domain("underscore_subdomain.example.com")
|
||||||
|
assert Domain.is_valid_domain("something.versicherung")
|
||||||
|
assert Domain.is_valid_domain("11.com")
|
||||||
|
assert Domain.is_valid_domain("a.cn")
|
||||||
|
assert Domain.is_valid_domain("sub1.sub2.sample.co.uk")
|
||||||
|
assert Domain.is_valid_domain("somerandomexample.xn--fiqs8s")
|
||||||
|
assert not Domain.is_valid_domain("über.com")
|
||||||
|
assert not Domain.is_valid_domain("example.com:4444")
|
||||||
|
assert not Domain.is_valid_domain("example.-com")
|
||||||
|
assert not Domain.is_valid_domain("foo@bar.com")
|
||||||
|
assert not Domain.is_valid_domain("example.")
|
||||||
|
assert not Domain.is_valid_domain("example.com.")
|
||||||
|
assert not Domain.is_valid_domain("-example.com")
|
||||||
|
assert not Domain.is_valid_domain("_example.com")
|
||||||
|
assert not Domain.is_valid_domain("_example._com")
|
||||||
|
assert not Domain.is_valid_domain("example_.com")
|
||||||
|
assert not Domain.is_valid_domain("example")
|
||||||
|
assert not Domain.is_valid_domain("a......b.com")
|
||||||
|
assert not Domain.is_valid_domain("a.123")
|
||||||
|
assert not Domain.is_valid_domain("123.123")
|
||||||
|
assert not Domain.is_valid_domain("123.123.123.123")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
def test_recursive_block():
|
def test_recursive_block():
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import ssl
|
import ssl
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
@ -8,6 +9,7 @@ import httpx
|
||||||
import pydantic
|
import pydantic
|
||||||
import urlman
|
import urlman
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.core.exceptions import ValidationError
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
from core.models import Config
|
from core.models import Config
|
||||||
|
@ -53,6 +55,14 @@ class DomainStates(StateGraph):
|
||||||
return cls.outdated
|
return cls.outdated
|
||||||
|
|
||||||
|
|
||||||
|
def _domain_validator(value: str):
|
||||||
|
if not Domain.is_valid_domain(value):
|
||||||
|
raise ValidationError(
|
||||||
|
"%(value)s is not a valid domain",
|
||||||
|
params={"value": value},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Domain(StatorModel):
|
class Domain(StatorModel):
|
||||||
"""
|
"""
|
||||||
Represents a domain that a user can have an account on.
|
Represents a domain that a user can have an account on.
|
||||||
|
@ -71,7 +81,9 @@ class Domain(StatorModel):
|
||||||
display domains for now, until we start doing better probing.
|
display domains for now, until we start doing better probing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
domain = models.CharField(max_length=250, primary_key=True)
|
domain = models.CharField(
|
||||||
|
max_length=250, primary_key=True, validators=[_domain_validator]
|
||||||
|
)
|
||||||
service_domain = models.CharField(
|
service_domain = models.CharField(
|
||||||
max_length=250,
|
max_length=250,
|
||||||
null=True,
|
null=True,
|
||||||
|
@ -119,6 +131,19 @@ class Domain(StatorModel):
|
||||||
class Meta:
|
class Meta:
|
||||||
indexes: list = []
|
indexes: list = []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_valid_domain(cls, domain: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a domain is valid, domain must be lowercase
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
re.match(
|
||||||
|
r"^(?:[a-z0-9](?:[a-z0-9-_]{0,61}[a-z0-9])?\.)+[a-z0-9][a-z0-9-_]{0,61}[a-z]$",
|
||||||
|
domain,
|
||||||
|
)
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_remote_domain(cls, domain: str) -> "Domain":
|
def get_remote_domain(cls, domain: str) -> "Domain":
|
||||||
return cls.objects.get_or_create(domain=domain.lower(), local=False)[0]
|
return cls.objects.get_or_create(domain=domain.lower(), local=False)[0]
|
||||||
|
|
|
@ -18,6 +18,8 @@ def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity:
|
||||||
domain = domain_instance.domain
|
domain = domain_instance.domain
|
||||||
else:
|
else:
|
||||||
username, domain = handle.split("@", 1)
|
username, domain = handle.split("@", 1)
|
||||||
|
if not Domain.is_valid_domain(domain):
|
||||||
|
raise Http404("Invalid domain")
|
||||||
# Resolve the domain to the display domain
|
# Resolve the domain to the display domain
|
||||||
domain_instance = Domain.get_domain(domain)
|
domain_instance = Domain.get_domain(domain)
|
||||||
if domain_instance is None:
|
if domain_instance is None:
|
||||||
|
|
Loading…
Reference in New Issue