diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 3b48152d5..e3bac8f7e 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -97,8 +97,10 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None: def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: if redirect_uri is not None: - # Validate redirect_uri against client's registered redirect URIs - if self.redirect_uris is None or redirect_uri not in self.redirect_uris: + # Validate redirect_uri against client's registered redirect URIs. + # Pydantic URL equality is type-strict across AnyUrl subclasses, so + # compare canonical serialized values instead of object identity. + if self.redirect_uris is None or str(redirect_uri) not in {str(uri) for uri in self.redirect_uris}: raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client") return redirect_uri elif self.redirect_uris is not None and len(self.redirect_uris) == 1: diff --git a/tests/shared/test_auth.py b/tests/shared/test_auth.py index 7463bc5a8..c277e21e0 100644 --- a/tests/shared/test_auth.py +++ b/tests/shared/test_auth.py @@ -1,9 +1,14 @@ """Tests for OAuth 2.0 shared code.""" import pytest -from pydantic import ValidationError +from pydantic import AnyHttpUrl, AnyUrl, ValidationError -from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata +from mcp.shared.auth import ( + InvalidRedirectUriError, + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, +) def test_oauth(): @@ -138,3 +143,25 @@ def test_invalid_non_empty_url_still_rejected(): } with pytest.raises(ValidationError): OAuthClientMetadata.model_validate(data) + + +def test_redirect_uri_validation_accepts_equivalent_pydantic_url_types(): + """Pydantic URL subclasses with the same serialized URI should match.""" + client_info = OAuthClientInformationFull( + client_id="client-1", + redirect_uris=[AnyHttpUrl("https://example.com/callback")], + ) + + redirect_uri = client_info.validate_redirect_uri(AnyUrl("https://example.com/callback")) + + assert str(redirect_uri) == "https://example.com/callback" + + +def test_redirect_uri_validation_rejects_unregistered_equivalent_type(): + client_info = OAuthClientInformationFull( + client_id="client-1", + redirect_uris=[AnyHttpUrl("https://example.com/callback")], + ) + + with pytest.raises(InvalidRedirectUriError): + client_info.validate_redirect_uri(AnyUrl("https://evil.example/callback"))