From 5ea236006f91f5e470dba420f2da7d28548798f6 Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Sun, 24 May 2026 22:16:18 +0200 Subject: [PATCH 1/4] fix(auth): make get_access_token per-request in stateful sessions --- .../server/auth/middleware/auth_context.py | 23 ++++ src/mcp/server/runner.py | 7 +- .../test_get_access_token_streamable_http.py | 100 ++++++++++++++++++ 3 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 tests/server/auth/test_get_access_token_streamable_http.py diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 1d34a5546b..f34b98cefd 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -1,5 +1,8 @@ import contextvars +from contextvars import Token + +from starlette.requests import Request from starlette.types import ASGIApp, Receive, Scope, Send from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser @@ -20,6 +23,26 @@ def get_access_token() -> AccessToken | None: return auth_user.access_token if auth_user else None +def _push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None: + """Set auth context for the current task from an incoming request. + + This is primarily used by server transports where request handlers may run + in background tasks that are not part of the original ASGI request task. + """ + if request is None: + return None + user = getattr(request, "user", None) + if isinstance(user, AuthenticatedUser): + return auth_context_var.set(user) + return None + + +def _pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None: + if token is None: + return + auth_context_var.reset(token) + + class AuthContextMiddleware: """Middleware that extracts the authenticated user from the request and sets it in a contextvar for easy access throughout the request lifecycle. diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 4f8d23b8dd..96a9d3284a 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -26,6 +26,7 @@ from pydantic import BaseModel, ValidationError from typing_extensions import TypeVar +from mcp.server.auth.middleware.auth_context import _pop_auth_context, _push_auth_context_from_request from mcp.server.connection import Connection from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext from mcp.server.models import InitializationOptions @@ -259,7 +260,11 @@ async def _inner() -> HandlerResult: return result call = self._compose_server_middleware(ctx, method, params, _inner) - result = _dump_result(await call()) + auth_token = _push_auth_context_from_request(ctx.request) + try: + result = _dump_result(await call()) + finally: + _pop_auth_context(auth_token) if method == "initialize": # Commit only on chain success, so a middleware veto leaves no state. # Race-free: the read loop is parked until this call returns. diff --git a/tests/server/auth/test_get_access_token_streamable_http.py b/tests/server/auth/test_get_access_token_streamable_http.py new file mode 100644 index 0000000000..3e0bd24272 --- /dev/null +++ b/tests/server/auth/test_get_access_token_streamable_http.py @@ -0,0 +1,100 @@ +import time + +import httpx +import pytest +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.routing import Mount + +from mcp import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend +from mcp.server.auth.provider import AccessToken +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + + +class _EchoTokenVerifier: + """Accepts any bearer token and echoes it back as the verified AccessToken.""" + + async def verify_token(self, token: str) -> AccessToken | None: + return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600) + + +async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + access = get_access_token() + text = access.token if access else "" + return CallToolResult(content=[TextContent(type="text", text=text)]) + + +async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})]) + + +class _MutableBearerAuth(httpx.Auth): + def __init__(self, token: str) -> None: + self.token = token + + def auth_flow(self, request: httpx.Request): + request.headers["Authorization"] = f"Bearer {self.token}" + yield request + + +@pytest.mark.anyio +async def test_get_access_token_reflects_current_request_in_stateful_session() -> None: + host = "testserver" + + server = Server( + "auth-test-server", + on_call_tool=_handle_whoami, + on_list_tools=_handle_list_tools, + ) + + security = TransportSecuritySettings( + allowed_hosts=[host, f"{host}:*"], + allowed_origins=[f"http://{host}:*"], + ) + session_manager = StreamableHTTPSessionManager(app=server, security_settings=security, stateless=False) + + asgi_app = Starlette( + routes=[Mount("/mcp", app=session_manager.handle_request)], + middleware=[ + Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())), + Middleware(AuthContextMiddleware), + ], + lifespan=lambda app: session_manager.run(), + ) + + auth = _MutableBearerAuth("token-A") + async with asgi_app.router.lifespan_context(asgi_app): + async with ( + httpx.ASGITransport(asgi_app) as transport, + httpx.AsyncClient( + transport=transport, + base_url=f"http://{host}", + auth=auth, + timeout=httpx.Timeout(30, read=30), + follow_redirects=True, + ) as http_client, + ): + transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client) + async with Client(transport_ctx) as client: + r1 = await client.call_tool("whoami", {}) + assert isinstance(r1.content[0], TextContent) + assert r1.content[0].text == "token-A" + + auth.token = "token-B" + r2 = await client.call_tool("whoami", {}) + assert isinstance(r2.content[0], TextContent) + assert r2.content[0].text == "token-B" From caf980b649df30b46cb04050f629bfc1519d6e95 Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Sun, 24 May 2026 22:19:07 +0200 Subject: [PATCH 2/4] fix(auth): avoid Request.user assertion without auth middleware --- src/mcp/server/auth/middleware/auth_context.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index f34b98cefd..1edbc57de6 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -31,7 +31,16 @@ def _push_auth_context_from_request(request: Request | None) -> Token[Authentica """ if request is None: return None - user = getattr(request, "user", None) + # Avoid Request.user, which asserts AuthenticationMiddleware is installed. + user = None + scope = getattr(request, "scope", None) + if isinstance(scope, dict): + user = scope.get("user") + if user is None: + try: + user = getattr(request, "user", None) + except AssertionError: + user = None if isinstance(user, AuthenticatedUser): return auth_context_var.set(user) return None From 190f1014a87b8fe81b33ab014d86af6d8ff1497c Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Sun, 24 May 2026 22:22:53 +0200 Subject: [PATCH 3/4] chore(auth): type-safe auth context push --- src/mcp/server/auth/middleware/auth_context.py | 11 ++++------- src/mcp/server/runner.py | 6 +++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 1edbc57de6..0d7b3d6cbf 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -1,6 +1,6 @@ import contextvars - from contextvars import Token +from typing import Any from starlette.requests import Request from starlette.types import ASGIApp, Receive, Scope, Send @@ -23,7 +23,7 @@ def get_access_token() -> AccessToken | None: return auth_user.access_token if auth_user else None -def _push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None: +def push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None: """Set auth context for the current task from an incoming request. This is primarily used by server transports where request handlers may run @@ -32,10 +32,7 @@ def _push_auth_context_from_request(request: Request | None) -> Token[Authentica if request is None: return None # Avoid Request.user, which asserts AuthenticationMiddleware is installed. - user = None - scope = getattr(request, "scope", None) - if isinstance(scope, dict): - user = scope.get("user") + user: Any | None = request.scope.get("user") if user is None: try: user = getattr(request, "user", None) @@ -46,7 +43,7 @@ def _push_auth_context_from_request(request: Request | None) -> Token[Authentica return None -def _pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None: +def pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None: if token is None: return auth_context_var.reset(token) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 96a9d3284a..aacb9043d6 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -26,7 +26,7 @@ from pydantic import BaseModel, ValidationError from typing_extensions import TypeVar -from mcp.server.auth.middleware.auth_context import _pop_auth_context, _push_auth_context_from_request +from mcp.server.auth.middleware.auth_context import pop_auth_context, push_auth_context_from_request from mcp.server.connection import Connection from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext from mcp.server.models import InitializationOptions @@ -260,11 +260,11 @@ async def _inner() -> HandlerResult: return result call = self._compose_server_middleware(ctx, method, params, _inner) - auth_token = _push_auth_context_from_request(ctx.request) + auth_token = push_auth_context_from_request(ctx.request) try: result = _dump_result(await call()) finally: - _pop_auth_context(auth_token) + pop_auth_context(auth_token) if method == "initialize": # Commit only on chain success, so a middleware veto leaves no state. # Race-free: the read loop is parked until this call returns. From 00174fcf2a302ff3ffce5282ade17c714b4a3981 Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Mon, 25 May 2026 01:09:41 +0200 Subject: [PATCH 4/4] fix auth context reset for streamable HTTP --- .../server/auth/middleware/auth_context.py | 4 +- .../test_get_access_token_streamable_http.py | 55 +++++++++---------- 2 files changed, 27 insertions(+), 32 deletions(-) diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 0d7b3d6cbf..31eb58b5b9 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -38,9 +38,7 @@ def push_auth_context_from_request(request: Request | None) -> Token[Authenticat user = getattr(request, "user", None) except AssertionError: user = None - if isinstance(user, AuthenticatedUser): - return auth_context_var.set(user) - return None + return auth_context_var.set(user if isinstance(user, AuthenticatedUser) else None) def pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None: diff --git a/tests/server/auth/test_get_access_token_streamable_http.py b/tests/server/auth/test_get_access_token_streamable_http.py index 3e0bd24272..9edf068477 100644 --- a/tests/server/auth/test_get_access_token_streamable_http.py +++ b/tests/server/auth/test_get_access_token_streamable_http.py @@ -14,7 +14,6 @@ from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend from mcp.server.auth.provider import AccessToken from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from mcp.server.transport_security import TransportSecuritySettings from mcp.types import ( CallToolRequestParams, CallToolResult, @@ -43,14 +42,34 @@ async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequest class _MutableBearerAuth(httpx.Auth): - def __init__(self, token: str) -> None: + def __init__(self, token: str | None) -> None: self.token = token def auth_flow(self, request: httpx.Request): - request.headers["Authorization"] = f"Bearer {self.token}" + if self.token is not None: + request.headers["Authorization"] = f"Bearer {self.token}" yield request +async def _call_whoami(asgi_app: Starlette, host: str, token: str | None) -> str: + auth = _MutableBearerAuth(token) + async with ( + httpx.ASGITransport(asgi_app) as transport, + httpx.AsyncClient( + transport=transport, + base_url=f"http://{host}", + auth=auth, + timeout=httpx.Timeout(30, read=30), + follow_redirects=True, + ) as http_client, + ): + transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client) + async with Client(transport_ctx) as client: # pragma: no branch + result = await client.call_tool("whoami", {}) + assert isinstance(result.content[0], TextContent) + return result.content[0].text + + @pytest.mark.anyio async def test_get_access_token_reflects_current_request_in_stateful_session() -> None: host = "testserver" @@ -61,11 +80,7 @@ async def test_get_access_token_reflects_current_request_in_stateful_session() - on_list_tools=_handle_list_tools, ) - security = TransportSecuritySettings( - allowed_hosts=[host, f"{host}:*"], - allowed_origins=[f"http://{host}:*"], - ) - session_manager = StreamableHTTPSessionManager(app=server, security_settings=security, stateless=False) + session_manager = StreamableHTTPSessionManager(app=server, stateless=False) asgi_app = Starlette( routes=[Mount("/mcp", app=session_manager.handle_request)], @@ -76,25 +91,7 @@ async def test_get_access_token_reflects_current_request_in_stateful_session() - lifespan=lambda app: session_manager.run(), ) - auth = _MutableBearerAuth("token-A") async with asgi_app.router.lifespan_context(asgi_app): - async with ( - httpx.ASGITransport(asgi_app) as transport, - httpx.AsyncClient( - transport=transport, - base_url=f"http://{host}", - auth=auth, - timeout=httpx.Timeout(30, read=30), - follow_redirects=True, - ) as http_client, - ): - transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client) - async with Client(transport_ctx) as client: - r1 = await client.call_tool("whoami", {}) - assert isinstance(r1.content[0], TextContent) - assert r1.content[0].text == "token-A" - - auth.token = "token-B" - r2 = await client.call_tool("whoami", {}) - assert isinstance(r2.content[0], TextContent) - assert r2.content[0].text == "token-B" + assert await _call_whoami(asgi_app, host, "token-A") == "token-A" + assert await _call_whoami(asgi_app, host, "token-B") == "token-B" + assert await _call_whoami(asgi_app, host, None) == ""