diff --git a/README.md b/README.md index 1880c02be..7ef5915e0 100644 --- a/README.md +++ b/README.md @@ -270,6 +270,82 @@ a2a.blocking.consumption.timeout.seconds=5 **Note:** The reference server implementations (Quarkus-based) automatically include the MicroProfile Config integration, so properties work out of the box in `application.properties`. +### 5. Task Authorization (Optional) + +The SDK includes an opt-in SPI for per-user task authorization. When enabled, every `RequestHandler` operation checks whether the authenticated user is allowed to access the target task before proceeding. When no provider is present, all operations are permitted (the default). + +> **⚠ Security note:** For multi-user deployments, a `TaskAuthorizationProvider` **must** be configured. Without one, all operations are permitted regardless of authentication — any authenticated user can read, modify, or cancel any task. Production deployments should use a fail-closed ownership policy (deny access when ownership is unknown). + +#### Providing an implementation + +Create an `@ApplicationScoped` CDI bean that implements `TaskAuthorizationProvider`: + +```java +import jakarta.enterprise.context.ApplicationScoped; +import org.a2aproject.sdk.server.auth.TaskAuthorizationProvider; +import org.a2aproject.sdk.server.auth.TaskOperation; +import org.a2aproject.sdk.server.ServerCallContext; + +@ApplicationScoped +public class MyTaskAuthorizationProvider implements TaskAuthorizationProvider { + + @Override + public boolean checkRead(ServerCallContext context, String taskId, TaskOperation op) { + // Return true to allow, false to deny. + // Denied reads throw TaskNotFoundError — the caller cannot distinguish + // "not found" from "not authorized", preventing information leakage. + String owner = ownershipStore.get(taskId); + if (owner == null) { + return false; // fail-closed: unknown ownership → deny + } + return owner.equals(context.getUser().getUsername()); + } + + @Override + public boolean checkWrite(ServerCallContext context, String taskId, TaskOperation op) { + return checkRead(context, taskId, op); + } + + @Override + public boolean checkCreate(ServerCallContext context, TaskOperation op) { + return context.getUser().isAuthenticated(); + } + + @Override + public boolean isTaskRecorded(String taskId) { + return ownershipStore.contains(taskId); + } + + @Override + public void recordOwnership(ServerCallContext context, String taskId, TaskOperation op) { + ownershipStore.put(taskId, context.getUser().getUsername()); + } +} +``` + +No additional configuration is required — the SDK automatically discovers the bean via CDI and wires it into the request pipeline. See the [`TaskAuthorizationProvider`](server-common/src/main/java/org/a2aproject/sdk/server/auth/TaskAuthorizationProvider.java) javadoc for the full contract, including thread-safety requirements and ownership-recording semantics. + +#### User identity in ServerCallContext + +Authorization decisions rely on `context.getUser()` returning the authenticated user. How the user is populated depends on the transport: + +- **JSON-RPC and REST**: The Quarkus route handler extracts the user from the Vert.x routing context (`rc.userContext()`) and sets it on `ServerCallContext` directly. +- **gRPC**: The reference server includes a `QuarkusCallContextFactory` CDI bean that injects the Quarkus `SecurityIdentity` and maps it to the `ServerCallContext` `User`. This happens automatically when using the reference gRPC module. If you provide your own `CallContextFactory`, you are responsible for populating the user. + +> **Note:** When task authorization is required, always obtain `RequestHandler` through CDI injection. Manual instantiation via `DefaultRequestHandler.create()` bypasses the `AuthorizationRequestHandlerDecorator` and all authorization checks. + +#### How it works + +| Operation | Check | +|-----------|-------| +| `getTask`, `subscribeToTask`, `getTaskPushNotificationConfig`, `listTaskPushNotificationConfigs` | `checkRead` | +| `cancelTask`, `createTaskPushNotificationConfig`, `deleteTaskPushNotificationConfig` | `checkWrite` | +| `messageSend` / `messageSendStream` (existing task) | `checkWrite` | +| `messageSend` / `messageSendStream` (new task) | `checkCreate`, then `recordOwnership` after creation | +| `listTasks` | Filtering pushed to `TaskStore.list()` — calls `checkRead` per task | + +> Task authorization addresses task isolation for deployments that enable `TaskAuthorizationProvider` with a fail-closed ownership policy. Multi-user deployments must configure this as a required security setting, and should avoid policies that allow unknown ownership by default. + ### Serving Older Protocol Versions (Backward Compatibility) The A2A Java SDK includes compatibility layers that allow your server to accept requests from clients using older protocol versions. Each compatibility layer is a separate set of modules that you add to your project as needed. **No changes to your `AgentExecutor` are needed** — the compatibility layer converts older protocol requests to v1.0 internally before delegating to your agent. diff --git a/compat-0.3/reference/grpc/src/main/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/QuarkusCallContextFactory_v0_3.java b/compat-0.3/reference/grpc/src/main/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/QuarkusCallContextFactory_v0_3.java new file mode 100644 index 000000000..47dc07028 --- /dev/null +++ b/compat-0.3/reference/grpc/src/main/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/QuarkusCallContextFactory_v0_3.java @@ -0,0 +1,78 @@ +package org.a2aproject.sdk.compat03.server.grpc.quarkus; + +import static org.a2aproject.sdk.server.ServerCallContext.TRANSPORT_KEY; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Instance; +import jakarta.inject.Inject; + +import io.grpc.Context; +import io.grpc.Metadata; +import io.grpc.stub.StreamObserver; +import io.quarkus.security.identity.SecurityIdentity; +import org.a2aproject.sdk.compat03.conversion.A2AProtocol_v0_3; +import org.a2aproject.sdk.compat03.transport.grpc.context.GrpcContextKeys_v0_3; +import org.a2aproject.sdk.compat03.transport.grpc.handler.CallContextFactory_v0_3; +import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.AuthenticatedUser; +import org.a2aproject.sdk.server.auth.UnauthenticatedUser; +import org.a2aproject.sdk.server.auth.User; +import org.a2aproject.sdk.spec.TransportProtocol; + +@ApplicationScoped +public class QuarkusCallContextFactory_v0_3 implements CallContextFactory_v0_3 { + + @Inject + Instance securityIdentityInstance; + + @Override + public ServerCallContext create(StreamObserver responseObserver) { + User user; + if (securityIdentityInstance.isResolvable()) { + SecurityIdentity securityIdentity = securityIdentityInstance.get(); + if (!securityIdentity.isAnonymous()) { + user = new AuthenticatedUser(securityIdentity.getPrincipal().getName()); + } else { + user = UnauthenticatedUser.INSTANCE; + } + } else { + user = UnauthenticatedUser.INSTANCE; + } + + Map state = new HashMap<>(); + state.put(TRANSPORT_KEY, TransportProtocol.GRPC); + state.put("grpc_response_observer", responseObserver); + + Context currentContext = Context.current(); + if (currentContext != null) { + state.put("grpc_context", currentContext); + io.grpc.Metadata grpcMetadata = GrpcContextKeys_v0_3.METADATA_KEY.get(currentContext); + if (grpcMetadata != null) { + state.put("grpc_metadata", grpcMetadata); + Map headers = new HashMap<>(); + for (String key : grpcMetadata.keys()) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + continue; + } + headers.put(key, grpcMetadata.get(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER))); + } + state.put("headers", headers); + } + String methodName = GrpcContextKeys_v0_3.METHOD_NAME_KEY.get(currentContext); + if (methodName != null) { + state.put("grpc_method_name", methodName); + } + String peerInfo = GrpcContextKeys_v0_3.PEER_INFO_KEY.get(currentContext); + if (peerInfo != null) { + state.put("grpc_peer_info", peerInfo); + } + } + + return new ServerCallContext(user, state, new HashSet<>(), A2AProtocol_v0_3.PROTOCOL_VERSION); + } +} diff --git a/compat-0.3/reference/grpc/src/test/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/QuarkusA2AGrpc_v0_3_WithTaskAuthorizationTest.java b/compat-0.3/reference/grpc/src/test/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/QuarkusA2AGrpc_v0_3_WithTaskAuthorizationTest.java new file mode 100644 index 000000000..8bb82b83a --- /dev/null +++ b/compat-0.3/reference/grpc/src/test/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/QuarkusA2AGrpc_v0_3_WithTaskAuthorizationTest.java @@ -0,0 +1,65 @@ +package org.a2aproject.sdk.compat03.server.grpc.quarkus; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import org.a2aproject.sdk.compat03.client.ClientBuilder_v0_3; +import org.a2aproject.sdk.compat03.client.transport.grpc.GrpcTransport_v0_3; +import org.a2aproject.sdk.compat03.client.transport.grpc.GrpcTransportConfigBuilder_v0_3; +import org.a2aproject.sdk.compat03.client.transport.spi.interceptors.auth.AuthInterceptor_v0_3; +import org.a2aproject.sdk.compat03.conversion.AbstractA2AServerWithTaskAuthorizationTest_v0_3; +import org.a2aproject.sdk.compat03.conversion.TaskAuthorizationTestProfile_v0_3; +import org.a2aproject.sdk.compat03.spec.TransportProtocol_v0_3; +import org.junit.jupiter.api.AfterAll; + +@QuarkusTest +@TestProfile(TaskAuthorizationTestProfile_v0_3.class) +public class QuarkusA2AGrpc_v0_3_WithTaskAuthorizationTest extends AbstractA2AServerWithTaskAuthorizationTest_v0_3 { + + private static final Map channels = new ConcurrentHashMap<>(); + + public QuarkusA2AGrpc_v0_3_WithTaskAuthorizationTest() { + super(8081); + } + + @Override + protected String getTransportProtocol() { + return TransportProtocol_v0_3.GRPC.asString(); + } + + @Override + protected String getTransportUrl() { + return "localhost:8081"; + } + + @Override + protected void configureTransportWithCredentials(ClientBuilder_v0_3 builder, String username, String password) { + AuthInterceptor_v0_3 authInterceptor = new AuthInterceptor_v0_3( + (schemeName, context) -> BASIC_AUTH_SCHEME_NAME.equals(schemeName) + ? getEncodedCredentials(username, password) : null); + builder.withTransport(GrpcTransport_v0_3.class, new GrpcTransportConfigBuilder_v0_3() + .channelFactory(target -> { + ManagedChannel channel = ManagedChannelBuilder.forTarget(target).usePlaintext().build(); + channels.put(username, channel); + return channel; + }) + .addInterceptor(authInterceptor)); + } + + @AfterAll + static void closeChannels() { + channels.values().forEach(ch -> { + ch.shutdownNow(); + try { + ch.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + } +} diff --git a/compat-0.3/reference/jsonrpc/src/main/java/org/a2aproject/sdk/compat03/server/apps/quarkus/A2AServerRoutes_v0_3.java b/compat-0.3/reference/jsonrpc/src/main/java/org/a2aproject/sdk/compat03/server/apps/quarkus/A2AServerRoutes_v0_3.java index e4bf2434c..bbd89645c 100644 --- a/compat-0.3/reference/jsonrpc/src/main/java/org/a2aproject/sdk/compat03/server/apps/quarkus/A2AServerRoutes_v0_3.java +++ b/compat-0.3/reference/jsonrpc/src/main/java/org/a2aproject/sdk/compat03/server/apps/quarkus/A2AServerRoutes_v0_3.java @@ -58,6 +58,7 @@ import org.a2aproject.sdk.compat03.util.Utils_v0_3; import org.a2aproject.sdk.server.PublicAgentCard; import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.AuthenticatedUser; import org.a2aproject.sdk.server.auth.UnauthenticatedUser; import org.a2aproject.sdk.server.auth.User; import org.a2aproject.sdk.server.common.quarkus.SseResponseWriter; @@ -298,17 +299,8 @@ private ServerCallContext createCallContext(RoutingContext rc) { if (rc.user() == null) { user = UnauthenticatedUser.INSTANCE; } else { - user = new User() { - @Override - public boolean isAuthenticated() { - return rc.userContext().authenticated(); - } - - @Override - public String getUsername() { - return rc.user().subject(); - } - }; + String subject = rc.user().subject(); + user = new AuthenticatedUser(subject != null ? subject : ""); } Map state = new HashMap<>(); diff --git a/compat-0.3/reference/jsonrpc/src/test/java/org/a2aproject/sdk/compat03/server/apps/quarkus/QuarkusA2AJSONRPC_v0_3_WithTaskAuthorizationVertxTest.java b/compat-0.3/reference/jsonrpc/src/test/java/org/a2aproject/sdk/compat03/server/apps/quarkus/QuarkusA2AJSONRPC_v0_3_WithTaskAuthorizationVertxTest.java new file mode 100644 index 000000000..c932a4dd7 --- /dev/null +++ b/compat-0.3/reference/jsonrpc/src/test/java/org/a2aproject/sdk/compat03/server/apps/quarkus/QuarkusA2AJSONRPC_v0_3_WithTaskAuthorizationVertxTest.java @@ -0,0 +1,47 @@ +package org.a2aproject.sdk.compat03.server.apps.quarkus; + +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import io.vertx.core.Vertx; +import jakarta.inject.Inject; +import org.a2aproject.sdk.client.http.VertxA2AHttpClient; +import org.a2aproject.sdk.compat03.client.ClientBuilder_v0_3; +import org.a2aproject.sdk.compat03.client.transport.jsonrpc.JSONRPCTransport_v0_3; +import org.a2aproject.sdk.compat03.client.transport.jsonrpc.JSONRPCTransportConfigBuilder_v0_3; +import org.a2aproject.sdk.compat03.client.transport.spi.interceptors.auth.AuthInterceptor_v0_3; +import org.a2aproject.sdk.compat03.conversion.AbstractA2AServerWithTaskAuthorizationTest_v0_3; +import org.a2aproject.sdk.compat03.conversion.TaskAuthorizationTestProfile_v0_3; +import org.a2aproject.sdk.compat03.spec.TransportProtocol_v0_3; + +@QuarkusTest +@TestProfile(TaskAuthorizationTestProfile_v0_3.class) +public class QuarkusA2AJSONRPC_v0_3_WithTaskAuthorizationVertxTest extends AbstractA2AServerWithTaskAuthorizationTest_v0_3 { + + @Inject + Vertx vertx; + + public QuarkusA2AJSONRPC_v0_3_WithTaskAuthorizationVertxTest() { + super(8081); + } + + @Override + protected String getTransportProtocol() { + return TransportProtocol_v0_3.JSONRPC.asString(); + } + + @Override + protected String getTransportUrl() { + return "http://localhost:8081"; + } + + @Override + protected void configureTransportWithCredentials(ClientBuilder_v0_3 builder, String username, String password) { + AuthInterceptor_v0_3 authInterceptor = new AuthInterceptor_v0_3( + (schemeName, context) -> BASIC_AUTH_SCHEME_NAME.equals(schemeName) + ? getEncodedCredentials(username, password) : null); + builder.withTransport(JSONRPCTransport_v0_3.class, + new JSONRPCTransportConfigBuilder_v0_3() + .httpClient(new VertxA2AHttpClient(vertx)) + .addInterceptor(authInterceptor)); + } +} diff --git a/compat-0.3/reference/rest/src/main/java/org/a2aproject/sdk/compat03/server/rest/quarkus/A2AServerRoutes_v0_3.java b/compat-0.3/reference/rest/src/main/java/org/a2aproject/sdk/compat03/server/rest/quarkus/A2AServerRoutes_v0_3.java index da53ed42d..5c80c0230 100644 --- a/compat-0.3/reference/rest/src/main/java/org/a2aproject/sdk/compat03/server/rest/quarkus/A2AServerRoutes_v0_3.java +++ b/compat-0.3/reference/rest/src/main/java/org/a2aproject/sdk/compat03/server/rest/quarkus/A2AServerRoutes_v0_3.java @@ -45,6 +45,7 @@ import org.a2aproject.sdk.compat03.transport.rest.handler.RestHandler_v0_3.HTTPRestStreamingResponse; import org.a2aproject.sdk.server.PublicAgentCard; import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.AuthenticatedUser; import org.a2aproject.sdk.server.auth.UnauthenticatedUser; import org.a2aproject.sdk.server.auth.User; import org.a2aproject.sdk.server.common.quarkus.SseResponseWriter; @@ -406,24 +407,8 @@ private ServerCallContext createCallContext(RoutingContext rc, String jsonRpcMet if (rc.user() == null) { user = UnauthenticatedUser.INSTANCE; } else { - user = new User() { - @Override - public boolean isAuthenticated() { - if (rc.userContext() != null) { - return rc.userContext().authenticated(); - } - return false; - } - - @Override - public String getUsername() { - if (rc.user() != null) { - String subject = rc.user().subject(); - return subject != null ? subject : ""; - } - return ""; - } - }; + String subject = rc.user().subject(); + user = new AuthenticatedUser(subject != null ? subject : ""); } Map state = new HashMap<>(); diff --git a/compat-0.3/reference/rest/src/test/java/org/a2aproject/sdk/compat03/server/rest/quarkus/QuarkusA2ARest_v0_3_WithTaskAuthorizationVertxTest.java b/compat-0.3/reference/rest/src/test/java/org/a2aproject/sdk/compat03/server/rest/quarkus/QuarkusA2ARest_v0_3_WithTaskAuthorizationVertxTest.java new file mode 100644 index 000000000..5dd10ac7b --- /dev/null +++ b/compat-0.3/reference/rest/src/test/java/org/a2aproject/sdk/compat03/server/rest/quarkus/QuarkusA2ARest_v0_3_WithTaskAuthorizationVertxTest.java @@ -0,0 +1,47 @@ +package org.a2aproject.sdk.compat03.server.rest.quarkus; + +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import io.vertx.core.Vertx; +import jakarta.inject.Inject; +import org.a2aproject.sdk.client.http.VertxA2AHttpClient; +import org.a2aproject.sdk.compat03.client.ClientBuilder_v0_3; +import org.a2aproject.sdk.compat03.client.transport.rest.RestTransport_v0_3; +import org.a2aproject.sdk.compat03.client.transport.rest.RestTransportConfigBuilder_v0_3; +import org.a2aproject.sdk.compat03.client.transport.spi.interceptors.auth.AuthInterceptor_v0_3; +import org.a2aproject.sdk.compat03.conversion.AbstractA2AServerWithTaskAuthorizationTest_v0_3; +import org.a2aproject.sdk.compat03.conversion.TaskAuthorizationTestProfile_v0_3; +import org.a2aproject.sdk.compat03.spec.TransportProtocol_v0_3; + +@QuarkusTest +@TestProfile(TaskAuthorizationTestProfile_v0_3.class) +public class QuarkusA2ARest_v0_3_WithTaskAuthorizationVertxTest extends AbstractA2AServerWithTaskAuthorizationTest_v0_3 { + + @Inject + Vertx vertx; + + public QuarkusA2ARest_v0_3_WithTaskAuthorizationVertxTest() { + super(8081); + } + + @Override + protected String getTransportProtocol() { + return TransportProtocol_v0_3.HTTP_JSON.asString(); + } + + @Override + protected String getTransportUrl() { + return "http://localhost:8081"; + } + + @Override + protected void configureTransportWithCredentials(ClientBuilder_v0_3 builder, String username, String password) { + AuthInterceptor_v0_3 authInterceptor = new AuthInterceptor_v0_3( + (schemeName, context) -> BASIC_AUTH_SCHEME_NAME.equals(schemeName) + ? getEncodedCredentials(username, password) : null); + builder.withTransport(RestTransport_v0_3.class, + new RestTransportConfigBuilder_v0_3() + .httpClient(new VertxA2AHttpClient(vertx)) + .addInterceptor(authInterceptor)); + } +} diff --git a/compat-0.3/server-conversion/src/test/java/org/a2aproject/sdk/compat03/conversion/AbstractA2AServerWithTaskAuthorizationTest_v0_3.java b/compat-0.3/server-conversion/src/test/java/org/a2aproject/sdk/compat03/conversion/AbstractA2AServerWithTaskAuthorizationTest_v0_3.java new file mode 100644 index 000000000..b0898fe1b --- /dev/null +++ b/compat-0.3/server-conversion/src/test/java/org/a2aproject/sdk/compat03/conversion/AbstractA2AServerWithTaskAuthorizationTest_v0_3.java @@ -0,0 +1,216 @@ +package org.a2aproject.sdk.compat03.conversion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.a2aproject.sdk.compat03.client.Client_v0_3; +import org.a2aproject.sdk.compat03.client.ClientBuilder_v0_3; +import org.a2aproject.sdk.compat03.client.TaskEvent_v0_3; +import org.a2aproject.sdk.compat03.client.TaskUpdateEvent_v0_3; +import org.a2aproject.sdk.compat03.client.config.ClientConfig_v0_3; +import org.a2aproject.sdk.compat03.spec.A2AClientException_v0_3; +import org.a2aproject.sdk.compat03.spec.AgentCapabilities_v0_3; +import org.a2aproject.sdk.compat03.spec.AgentCard_v0_3; +import org.a2aproject.sdk.compat03.spec.AgentInterface_v0_3; +import org.a2aproject.sdk.compat03.spec.HTTPAuthSecurityScheme_v0_3; +import org.a2aproject.sdk.compat03.spec.Message_v0_3; +import org.a2aproject.sdk.compat03.spec.Task_v0_3; +import org.a2aproject.sdk.compat03.spec.TaskIdParams_v0_3; +import org.a2aproject.sdk.compat03.spec.TaskNotFoundError_v0_3; +import org.a2aproject.sdk.compat03.spec.TaskQueryParams_v0_3; +import org.a2aproject.sdk.compat03.spec.TaskState_v0_3; +import org.a2aproject.sdk.compat03.spec.TextPart_v0_3; +import org.junit.jupiter.api.Test; + +/** + * Abstract base class for v0.3 task authorization integration tests. + *

+ * Mirrors {@link org.a2aproject.sdk.server.apps.common.AbstractA2AServerWithTaskAuthorizationTest} + * but uses v0.3 client types. Tests verify that + * {@link org.a2aproject.sdk.server.auth.TaskAuthorizationProvider} is enforced + * through the v0.3 compatibility layer. + *

+ * Note: v0.3 has no {@code listTasks()} API, so {@code testListTasksShowsOnlyOwnTasks} + * is not included. + */ +public abstract class AbstractA2AServerWithTaskAuthorizationTest_v0_3 { + + protected static final String USER_A = "testuser"; + protected static final String USER_A_PASSWORD = "testpass"; + protected static final String USER_B = "userB"; + protected static final String USER_B_PASSWORD = "passB"; + protected static final String BASIC_AUTH_SCHEME_NAME = "basicAuth"; + + protected final int serverPort; + + protected AbstractA2AServerWithTaskAuthorizationTest_v0_3(int serverPort) { + this.serverPort = serverPort; + } + + protected abstract String getTransportProtocol(); + + protected abstract String getTransportUrl(); + + protected abstract void configureTransportWithCredentials(ClientBuilder_v0_3 builder, String username, String password); + + protected Client_v0_3 createClient(String username, String password) throws A2AClientException_v0_3 { + AgentCard_v0_3 agentCard = createTestAgentCard(); + ClientConfig_v0_3 clientConfig = new ClientConfig_v0_3.Builder().setStreaming(false).build(); + ClientBuilder_v0_3 clientBuilder = Client_v0_3.builder(agentCard).clientConfig(clientConfig); + configureTransportWithCredentials(clientBuilder, username, password); + return clientBuilder.build(); + } + + private AgentCard_v0_3 createTestAgentCard() { + return new AgentCard_v0_3.Builder() + .name("test-card") + .description("A test agent card") + .url(getTransportUrl()) + .version("1.0") + .preferredTransport(getTransportProtocol()) + .capabilities(new AgentCapabilities_v0_3.Builder() + .streaming(false) + .pushNotifications(false) + .stateTransitionHistory(false) + .build()) + .defaultInputModes(List.of("text")) + .defaultOutputModes(List.of("text")) + .skills(List.of()) + .additionalInterfaces(List.of(new AgentInterface_v0_3(getTransportProtocol(), getTransportUrl()))) + .securitySchemes(Map.of( + BASIC_AUTH_SCHEME_NAME, + new HTTPAuthSecurityScheme_v0_3.Builder() + .scheme("basic") + .description("HTTP Basic authentication") + .build())) + .security(List.of(Map.of(BASIC_AUTH_SCHEME_NAME, List.of()))) + .build(); + } + + protected static String getEncodedCredentials(String username, String password) { + return Base64.getEncoder().encodeToString((username + ":" + password).getBytes(StandardCharsets.UTF_8)); + } + + protected Task_v0_3 sendMessageAndGetTask(Client_v0_3 client, String messageText) throws Exception { + Message_v0_3 message = new Message_v0_3.Builder() + .messageId(UUID.randomUUID().toString()) + .role(Message_v0_3.Role.USER) + .parts(new TextPart_v0_3("a2a-local:" + messageText)) + .build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference receivedTask = new AtomicReference<>(); + AtomicReference errorRef = new AtomicReference<>(); + + client.sendMessage(message, List.of((event, agentCard) -> { + if (event instanceof TaskEvent_v0_3 te) { + receivedTask.set(te.getTask()); + if (te.getTask().status().state() == TaskState_v0_3.COMPLETED) { + latch.countDown(); + } + } else if (event instanceof TaskUpdateEvent_v0_3 tue) { + receivedTask.set(tue.getTask()); + if (tue.getTask().status().state() == TaskState_v0_3.COMPLETED) { + latch.countDown(); + } + } + }), error -> { + errorRef.set(error); + latch.countDown(); + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS), "Task should complete within timeout"); + assertNull(errorRef.get(), "Should not have received an error: " + errorRef.get()); + + Task_v0_3 task = receivedTask.get(); + assertNotNull(task, "Should have received a task"); + assertEquals(TaskState_v0_3.COMPLETED, task.status().state()); + return task; + } + + @Test + public void testOwnerCanGetOwnTask() throws Exception { + Client_v0_3 clientA = createClient(USER_A, USER_A_PASSWORD); + Task_v0_3 task = sendMessageAndGetTask(clientA, "owner-get-test"); + + Task_v0_3 retrieved = clientA.getTask(new TaskQueryParams_v0_3(task.id())); + assertNotNull(retrieved); + assertEquals(task.id(), retrieved.id()); + } + + @Test + public void testNonOwnerCannotGetTask() throws Exception { + Client_v0_3 clientA = createClient(USER_A, USER_A_PASSWORD); + Task_v0_3 task = sendMessageAndGetTask(clientA, "non-owner-get-test"); + + Client_v0_3 clientB = createClient(USER_B, USER_B_PASSWORD); + A2AClientException_v0_3 error = assertThrows(A2AClientException_v0_3.class, () -> + clientB.getTask(new TaskQueryParams_v0_3(task.id()))); + assertInstanceOf(TaskNotFoundError_v0_3.class, error.getCause()); + } + + @Test + public void testOwnerCanCancelOwnTask() throws Exception { + Client_v0_3 clientA = createClient(USER_A, USER_A_PASSWORD); + Task_v0_3 task = sendMessageAndGetTask(clientA, "owner-cancel-test"); + + try { + clientA.cancelTask(new TaskIdParams_v0_3(task.id())); + } catch (A2AClientException_v0_3 e) { + assertFalse(e.getCause() instanceof TaskNotFoundError_v0_3, + "Owner should not get TaskNotFoundError when canceling own task"); + } + } + + @Test + public void testNonOwnerCannotCancelTask() throws Exception { + Client_v0_3 clientA = createClient(USER_A, USER_A_PASSWORD); + Task_v0_3 task = sendMessageAndGetTask(clientA, "non-owner-cancel-test"); + + Client_v0_3 clientB = createClient(USER_B, USER_B_PASSWORD); + A2AClientException_v0_3 error = assertThrows(A2AClientException_v0_3.class, () -> + clientB.cancelTask(new TaskIdParams_v0_3(task.id()))); + assertInstanceOf(TaskNotFoundError_v0_3.class, error.getCause()); + } + + @Test + public void testUnauthorizedLooksLikeNotFound() throws Exception { + Client_v0_3 clientA = createClient(USER_A, USER_A_PASSWORD); + Task_v0_3 task = sendMessageAndGetTask(clientA, "info-hiding-test"); + + Client_v0_3 clientB = createClient(USER_B, USER_B_PASSWORD); + + A2AClientException_v0_3 unauthorizedError = assertThrows(A2AClientException_v0_3.class, () -> + clientB.getTask(new TaskQueryParams_v0_3(task.id()))); + A2AClientException_v0_3 notFoundError = assertThrows(A2AClientException_v0_3.class, () -> + clientB.getTask(new TaskQueryParams_v0_3(UUID.randomUUID().toString()))); + + assertInstanceOf(TaskNotFoundError_v0_3.class, unauthorizedError.getCause()); + assertInstanceOf(TaskNotFoundError_v0_3.class, notFoundError.getCause()); + } + + @Test + public void testBothUsersCanCreateTasks() throws Exception { + Client_v0_3 clientA = createClient(USER_A, USER_A_PASSWORD); + Task_v0_3 taskA = sendMessageAndGetTask(clientA, "create-test-a"); + assertNotNull(taskA.id()); + + Client_v0_3 clientB = createClient(USER_B, USER_B_PASSWORD); + Task_v0_3 taskB = sendMessageAndGetTask(clientB, "create-test-b"); + assertNotNull(taskB.id()); + } +} diff --git a/compat-0.3/server-conversion/src/test/java/org/a2aproject/sdk/compat03/conversion/TaskAuthorizationTestProfile_v0_3.java b/compat-0.3/server-conversion/src/test/java/org/a2aproject/sdk/compat03/conversion/TaskAuthorizationTestProfile_v0_3.java new file mode 100644 index 000000000..79f8784fc --- /dev/null +++ b/compat-0.3/server-conversion/src/test/java/org/a2aproject/sdk/compat03/conversion/TaskAuthorizationTestProfile_v0_3.java @@ -0,0 +1,18 @@ +package org.a2aproject.sdk.compat03.conversion; + +import java.util.HashMap; +import java.util.Map; + +import io.quarkus.test.junit.QuarkusTestProfile; + +public class TaskAuthorizationTestProfile_v0_3 extends AuthTestProfile_v0_3 { + + @Override + public Map getConfigOverrides() { + Map config = new HashMap<>(super.getConfigOverrides()); + config.put("quarkus.security.users.embedded.users.userB", "passB"); + config.put("quarkus.security.users.embedded.roles.userB", "user"); + config.put("test.task-authorization.enabled", "true"); + return config; + } +} diff --git a/compat-0.3/tests/server-common/src/main/java/org/a2aproject/sdk/compat03/conversion/AgentExecutorProducer_v0_3.java b/compat-0.3/tests/server-common/src/main/java/org/a2aproject/sdk/compat03/conversion/AgentExecutorProducer_v0_3.java index 1e06f9b47..01a04aabd 100644 --- a/compat-0.3/tests/server-common/src/main/java/org/a2aproject/sdk/compat03/conversion/AgentExecutorProducer_v0_3.java +++ b/compat-0.3/tests/server-common/src/main/java/org/a2aproject/sdk/compat03/conversion/AgentExecutorProducer_v0_3.java @@ -38,6 +38,15 @@ public void execute(RequestContext context, AgentEmitter agentEmitter) throws A2 return; } + // Local handling: creates a task with startWork() + addArtifact() + complete() + if (input.startsWith("a2a-local:")) { + String payload = input.substring("a2a-local:".length()); + agentEmitter.startWork(); + agentEmitter.addArtifact(List.of(new TextPart("Handled locally: " + payload))); + agentEmitter.complete(); + return; + } + // Special handling for multi-event test (routed by message content) if (input.startsWith("multi-event:first")) { agentEmitter.startWork(); diff --git a/compat-0.3/tests/server-common/src/main/java/org/a2aproject/sdk/compat03/conversion/TestTaskAuthorizationProvider_v0_3.java b/compat-0.3/tests/server-common/src/main/java/org/a2aproject/sdk/compat03/conversion/TestTaskAuthorizationProvider_v0_3.java new file mode 100644 index 000000000..a03c92e40 --- /dev/null +++ b/compat-0.3/tests/server-common/src/main/java/org/a2aproject/sdk/compat03/conversion/TestTaskAuthorizationProvider_v0_3.java @@ -0,0 +1,48 @@ +package org.a2aproject.sdk.compat03.conversion; + +import java.util.concurrent.ConcurrentHashMap; + +import jakarta.enterprise.context.ApplicationScoped; + +import io.quarkus.arc.Unremovable; +import io.quarkus.arc.properties.IfBuildProperty; +import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.TaskAuthorizationProvider; +import org.a2aproject.sdk.server.auth.TaskOperation; + +@ApplicationScoped +@Unremovable +@IfBuildProperty(name = "test.task-authorization.enabled", stringValue = "true", enableIfMissing = false) +public class TestTaskAuthorizationProvider_v0_3 implements TaskAuthorizationProvider { + + private final ConcurrentHashMap taskOwners = new ConcurrentHashMap<>(); + + @Override + public boolean checkRead(ServerCallContext context, String taskId, TaskOperation operation) { + String owner = taskOwners.get(taskId); + // Intentionally fail-open for testing; production implementations should fail-closed (deny unknown tasks) + return owner == null || owner.equals(context.getUser().getUsername()); + } + + @Override + public boolean checkWrite(ServerCallContext context, String taskId, TaskOperation operation) { + String owner = taskOwners.get(taskId); + // Intentionally fail-open for testing; production implementations should fail-closed (deny unknown tasks) + return owner == null || owner.equals(context.getUser().getUsername()); + } + + @Override + public boolean checkCreate(ServerCallContext context, TaskOperation operation) { + return context.getUser().isAuthenticated(); + } + + @Override + public boolean isTaskRecorded(String taskId) { + return taskOwners.containsKey(taskId); + } + + @Override + public void recordOwnership(ServerCallContext context, String taskId, TaskOperation operation) { + taskOwners.putIfAbsent(taskId, context.getUser().getUsername()); + } +} diff --git a/extras/task-store-database-jpa/src/main/java/org/a2aproject/sdk/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java b/extras/task-store-database-jpa/src/main/java/org/a2aproject/sdk/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java index b22d3fca5..32dda8953 100644 --- a/extras/task-store-database-jpa/src/main/java/org/a2aproject/sdk/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java +++ b/extras/task-store-database-jpa/src/main/java/org/a2aproject/sdk/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java @@ -10,6 +10,8 @@ import jakarta.enterprise.context.ApplicationScoped; import jakarta.enterprise.event.Event; import jakarta.enterprise.inject.Alternative; +import jakarta.enterprise.inject.Any; +import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; import jakarta.persistence.EntityManager; import jakarta.persistence.PersistenceContext; @@ -20,6 +22,9 @@ import org.a2aproject.sdk.extras.common.events.TaskFinalizedEvent; import org.a2aproject.sdk.jsonrpc.common.json.JsonProcessingException; import org.a2aproject.sdk.jsonrpc.common.wrappers.ListTasksResult; +import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.TaskAuthorizationProvider; +import org.a2aproject.sdk.server.auth.TaskOperation; import org.a2aproject.sdk.server.config.A2AConfigProvider; import org.a2aproject.sdk.server.tasks.TaskStateProvider; import org.a2aproject.sdk.server.tasks.TaskStore; @@ -30,6 +35,7 @@ import org.a2aproject.sdk.spec.Message; import org.a2aproject.sdk.util.PageToken; import org.a2aproject.sdk.spec.Task; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,6 +56,19 @@ public class JpaDatabaseTaskStore implements TaskStore, TaskStateProvider { @Inject A2AConfigProvider configProvider; + private final @org.jspecify.annotations.Nullable TaskAuthorizationProvider authorizationProvider; + + public JpaDatabaseTaskStore() { + this.authorizationProvider = null; + } + + @Inject + public JpaDatabaseTaskStore(@Any Instance authorizationProviderInstance) { + this.authorizationProvider = authorizationProviderInstance.isResolvable() + ? authorizationProviderInstance.get() + : null; + } + /** * Grace period for task finalization in replicated scenarios (seconds). * After a task reaches a final state, this is the minimum time to wait before cleanup @@ -250,109 +269,81 @@ public boolean isTaskFinalized(String taskId) { @Transactional @Override - public ListTasksResult list(ListTasksParams params) { + public ListTasksResult list(ListTasksParams params, @Nullable ServerCallContext context) { LOGGER.debug("Listing tasks with params: contextId={}, status={}, pageSize={}, pageToken={}", params.contextId(), params.status(), params.pageSize(), params.pageToken()); try { - // Parse pageToken once at the beginning - PageToken pageToken = PageToken.fromString(params.pageToken()); - Instant tokenTimestamp = pageToken != null ? pageToken.timestamp() : null; - String tokenId = pageToken != null ? pageToken.id() : null; - - // Build dynamic JPQL query with WHERE clauses for filtering - StringBuilder queryBuilder = new StringBuilder("SELECT t FROM JpaTask t WHERE 1=1"); - StringBuilder countQueryBuilder = new StringBuilder("SELECT COUNT(t) FROM JpaTask t WHERE 1=1"); - - // Apply contextId filter using denormalized column - if (params.contextId() != null) { - queryBuilder.append(" AND t.contextId = :contextId"); - countQueryBuilder.append(" AND t.contextId = :contextId"); - } - - // Apply status filter using denormalized column - if (params.status() != null) { - queryBuilder.append(" AND t.state = :state"); - countQueryBuilder.append(" AND t.state = :state"); - } - - // Apply statusTimestampAfter filter using denormalized timestamp column - if (params.statusTimestampAfter() != null) { - queryBuilder.append(" AND t.statusTimestamp > :statusTimestampAfter"); - countQueryBuilder.append(" AND t.statusTimestamp > :statusTimestampAfter"); - } - - // Apply pagination cursor using keyset pagination for composite sort (timestamp DESC, id ASC) - if (tokenTimestamp != null) { - // Keyset pagination: get tasks where timestamp < tokenTimestamp OR (timestamp = tokenTimestamp AND id > tokenId) - queryBuilder.append(" AND (t.statusTimestamp < :tokenTimestamp OR (t.statusTimestamp = :tokenTimestamp AND t.id > :tokenId))"); - } - - // Sort by status timestamp descending (most recent first), then by ID for stable ordering - queryBuilder.append(" ORDER BY t.statusTimestamp DESC, t.id ASC"); - - // Create and configure the main query - TypedQuery query = em.createQuery(queryBuilder.toString(), JpaTask.class); - - // Set filter parameters - if (params.contextId() != null) { - query.setParameter("contextId", params.contextId()); - } - if (params.status() != null) { - query.setParameter("state", params.status().name()); - } - if (params.statusTimestampAfter() != null) { - query.setParameter("statusTimestampAfter", params.statusTimestampAfter()); - } - if (tokenTimestamp != null) { - query.setParameter("tokenTimestamp", tokenTimestamp); - query.setParameter("tokenId", tokenId); - } - - // Apply page size limit (+1 to check for next page) int pageSize = params.getEffectivePageSize(); - query.setMaxResults(pageSize + 1); - - // Execute query and deserialize tasks - List jpaTasksPage = query.getResultList(); - // Determine if there are more results - boolean hasMore = jpaTasksPage.size() > pageSize; - if (hasMore) { - jpaTasksPage = jpaTasksPage.subList(0, pageSize); - } + // Build base WHERE clause (without cursor — shared across iterations) + String baseWhereClause = buildBaseWhereClause(params); + + List tasks; + boolean hasMore; + int totalSize; + + if (authorizationProvider != null && context != null) { + // Iterative fetch: accumulate pageSize authorized results across DB pages + tasks = new ArrayList<>(pageSize); + PageToken cursor = PageToken.fromString(params.pageToken()); + boolean dbExhausted = false; + int maxIterations = 10; + int iterations = 0; + + while (tasks.size() < pageSize && !dbExhausted && iterations < maxIterations) { + iterations++; + int remaining = pageSize - tasks.size(); + int limit = remaining + 1; + TypedQuery query = createPageQuery( + baseWhereClause, params, cursor, limit); + List batch = query.getResultList(); + + dbExhausted = batch.size() < limit; + + int processedCount = 0; + for (JpaTask jpaTask : batch) { + processedCount++; + Task task = deserializeTask(jpaTask); + if (authorizationProvider.checkRead(context, task.id(), TaskOperation.LIST_TASKS)) { + tasks.add(task); + if (tasks.size() == pageSize) { + break; + } + } + } + + // Advance cursor to last fetched DB row for next iteration + if (processedCount > 0) { + JpaTask last = batch.get(processedCount - 1); + cursor = new PageToken(last.getStatusTimestamp(), last.getId()); + } + } - // Get total count of matching tasks - TypedQuery countQuery = em.createQuery(countQueryBuilder.toString(), Long.class); - if (params.contextId() != null) { - countQuery.setParameter("contextId", params.contextId()); - } - if (params.status() != null) { - countQuery.setParameter("state", params.status().name()); - } - if (params.statusTimestampAfter() != null) { - countQuery.setParameter("statusTimestampAfter", params.statusTimestampAfter()); - } - int totalSize = countQuery.getSingleResult().intValue(); - - // Deserialize tasks from JSON - List tasks = new ArrayList<>(); - for (JpaTask jpaTask : jpaTasksPage) { - try { - tasks.add(jpaTask.getTask()); - } catch (JsonProcessingException e) { - LOGGER.error("Failed to deserialize task with ID: {}", jpaTask.getId(), e); - throw new TaskSerializationException(jpaTask.getId(), - "Failed to deserialize task during list operation", e); + hasMore = !dbExhausted; + // Use the authorized count to avoid leaking the existence of unauthorized tasks + totalSize = tasks.size(); + } else { + // Single fetch — no authorization filtering needed + totalSize = executeCountQuery(baseWhereClause, params); + PageToken cursor = PageToken.fromString(params.pageToken()); + TypedQuery query = createPageQuery( + baseWhereClause, params, cursor, pageSize + 1); + List jpaTasksPage = query.getResultList(); + + hasMore = jpaTasksPage.size() > pageSize; + int batchEnd = Math.min(jpaTasksPage.size(), pageSize); + + tasks = new ArrayList<>(batchEnd); + for (int i = 0; i < batchEnd; i++) { + tasks.add(deserializeTask(jpaTasksPage.get(i))); } } - // Determine next page token (timestamp:ID of last task if there are more results) - // Format: "timestamp_millis:taskId" for keyset pagination + // Determine next page token from the last returned task String nextPageToken = null; if (hasMore && !tasks.isEmpty()) { Task lastTask = tasks.get(tasks.size() - 1); - // All tasks have timestamps (TaskStatus canonical constructor ensures this) Instant timestamp = lastTask.status().timestamp().toInstant(); nextPageToken = new PageToken(timestamp, lastTask.id()).toString(); } @@ -369,13 +360,74 @@ public ListTasksResult list(ListTasksParams params) { return new ListTasksResult(transformedTasks, totalSize, transformedTasks.size(), nextPageToken); } catch (PersistenceException e) { - // Database errors from query creation, execution, or count LOGGER.error("Database query failed during list operation", e); - throw new TaskPersistenceException(null, // No single taskId for list operation + throw new TaskPersistenceException(null, "Database query failed during list operation", e); } } + private String buildBaseWhereClause(ListTasksParams params) { + StringBuilder sb = new StringBuilder(" WHERE 1=1"); + if (params.contextId() != null) { + sb.append(" AND t.contextId = :contextId"); + } + if (params.status() != null) { + sb.append(" AND t.state = :state"); + } + if (params.statusTimestampAfter() != null) { + sb.append(" AND t.statusTimestamp > :statusTimestampAfter"); + } + return sb.toString(); + } + + private TypedQuery createPageQuery(String baseWhereClause, + ListTasksParams params, @org.jspecify.annotations.Nullable PageToken cursor, int maxResults) { + StringBuilder jpql = new StringBuilder("SELECT t FROM JpaTask t").append(baseWhereClause); + if (cursor != null) { + jpql.append(" AND (t.statusTimestamp < :tokenTimestamp" + + " OR (t.statusTimestamp = :tokenTimestamp AND t.id > :tokenId))"); + } + jpql.append(" ORDER BY t.statusTimestamp DESC, t.id ASC"); + + TypedQuery query = em.createQuery(jpql.toString(), JpaTask.class); + setFilterParameters(query, params); + if (cursor != null) { + query.setParameter("tokenTimestamp", cursor.timestamp()); + query.setParameter("tokenId", cursor.id()); + } + query.setMaxResults(maxResults); + return query; + } + + private int executeCountQuery(String baseWhereClause, ListTasksParams params) { + String jpql = "SELECT COUNT(t) FROM JpaTask t" + baseWhereClause; + TypedQuery countQuery = em.createQuery(jpql, Long.class); + setFilterParameters(countQuery, params); + return countQuery.getSingleResult().intValue(); + } + + private void setFilterParameters(TypedQuery query, ListTasksParams params) { + if (params.contextId() != null) { + query.setParameter("contextId", params.contextId()); + } + if (params.status() != null) { + query.setParameter("state", params.status().name()); + } + if (params.statusTimestampAfter() != null) { + query.setParameter("statusTimestampAfter", params.statusTimestampAfter()); + } + } + + private Task deserializeTask(JpaTask jpaTask) { + try { + return jpaTask.getTask(); + } catch (JsonProcessingException e) { + LOGGER.error("Failed to deserialize task with ID: {}", jpaTask.getId(), e); + throw new TaskSerializationException(jpaTask.getId(), + "Failed to deserialize task during list operation", e); + } + } + private Task transformTask(Task task, int historyLength, boolean includeArtifacts) { // Limit history if needed (keep most recent N messages) List history = task.history(); diff --git a/extras/task-store-database-jpa/src/test/java/org/a2aproject/sdk/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java b/extras/task-store-database-jpa/src/test/java/org/a2aproject/sdk/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java index ababb7977..f95a11fa1 100644 --- a/extras/task-store-database-jpa/src/test/java/org/a2aproject/sdk/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java +++ b/extras/task-store-database-jpa/src/test/java/org/a2aproject/sdk/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java @@ -291,7 +291,7 @@ public void testListTasksEmpty() { .contextId("non-existent-context-12345") .tenant("tenant") .build(); - ListTasksResult result = taskStore.list(params); + ListTasksResult result = taskStore.list(params, null); assertNotNull(result); assertEquals(0, result.totalSize()); @@ -331,7 +331,7 @@ public void testListTasksFilterByContextId() { .contextId("context-A") .tenant("tenant") .build(); - ListTasksResult result = taskStore.list(params); + ListTasksResult result = taskStore.list(params, null); assertEquals(2, result.totalSize()); assertEquals(2, result.pageSize()); @@ -371,7 +371,7 @@ public void testListTasksFilterByStatus() { .tenant("tenant") .status(TaskState.TASK_STATE_WORKING) .build(); - ListTasksResult result = taskStore.list(params); + ListTasksResult result = taskStore.list(params, null); assertEquals(1, result.totalSize()); assertEquals(1, result.pageSize()); @@ -411,7 +411,7 @@ public void testListTasksCombinedFilters() { .tenant("tenant") .status(TaskState.TASK_STATE_WORKING) .build(); - ListTasksResult result = taskStore.list(params); + ListTasksResult result = taskStore.list(params, null); assertEquals(1, result.totalSize()); assertEquals(1, result.pageSize()); @@ -441,7 +441,7 @@ public void testListTasksPagination() { .tenant("tenant") .pageSize(2) .build(); - ListTasksResult result1 = taskStore.list(params1); + ListTasksResult result1 = taskStore.list(params1, null); assertEquals(5, result1.totalSize()); assertEquals(2, result1.pageSize()); @@ -455,7 +455,7 @@ public void testListTasksPagination() { .pageSize(2) .pageToken(result1.nextPageToken()) .build(); - ListTasksResult result2 = taskStore.list(params2); + ListTasksResult result2 = taskStore.list(params2, null); assertEquals(5, result2.totalSize()); assertEquals(2, result2.pageSize()); @@ -468,7 +468,7 @@ public void testListTasksPagination() { .pageSize(2) .pageToken(result2.nextPageToken()) .build(); - ListTasksResult result3 = taskStore.list(params3); + ListTasksResult result3 = taskStore.list(params3, null); assertEquals(5, result3.totalSize()); assertEquals(1, result3.pageSize()); @@ -535,7 +535,7 @@ public void testListTasksPaginationWithDifferentTimestamps() { .tenant("tenant") .pageSize(2) .build(); - ListTasksResult result1 = taskStore.list(params1); + ListTasksResult result1 = taskStore.list(params1, null); assertEquals(5, result1.totalSize()); assertEquals(2, result1.pageSize()); @@ -558,7 +558,7 @@ public void testListTasksPaginationWithDifferentTimestamps() { .pageSize(2) .pageToken(result1.nextPageToken()) .build(); - ListTasksResult result2 = taskStore.list(params2); + ListTasksResult result2 = taskStore.list(params2, null); assertEquals(5, result2.totalSize()); assertEquals(2, result2.pageSize()); @@ -575,7 +575,7 @@ public void testListTasksPaginationWithDifferentTimestamps() { .pageSize(2) .pageToken(result2.nextPageToken()) .build(); - ListTasksResult result3 = taskStore.list(params3); + ListTasksResult result3 = taskStore.list(params3, null); assertEquals(5, result3.totalSize()); assertEquals(1, result3.pageSize()); @@ -624,7 +624,7 @@ public void testListTasksHistoryLimiting() { .tenant("tenant") .historyLength(3) .build(); - ListTasksResult result = taskStore.list(params); + ListTasksResult result = taskStore.list(params, null); assertEquals(1, result.tasks().size()); Task retrieved = result.tasks().get(0); @@ -661,7 +661,7 @@ public void testListTasksArtifactInclusion() { .contextId("context-artifact-unique") .tenant("tenant") .build(); - ListTasksResult resultWithout = taskStore.list(paramsWithoutArtifacts); + ListTasksResult resultWithout = taskStore.list(paramsWithoutArtifacts, null); assertEquals(1, resultWithout.tasks().size()); assertTrue(resultWithout.tasks().get(0).artifacts().isEmpty(), @@ -673,7 +673,7 @@ public void testListTasksArtifactInclusion() { .tenant("tenant") .includeArtifacts(true) .build(); - ListTasksResult resultWith = taskStore.list(paramsWithArtifacts); + ListTasksResult resultWith = taskStore.list(paramsWithArtifacts, null); assertEquals(1, resultWith.tasks().size()); assertEquals(1, resultWith.tasks().get(0).artifacts().size(), @@ -699,7 +699,7 @@ public void testListTasksDefaultPageSize() { .contextId("context-default-pagesize") .tenant("tenant") .build(); - ListTasksResult result = taskStore.list(params); + ListTasksResult result = taskStore.list(params, null); assertEquals(100, result.totalSize()); assertEquals(50, result.pageSize(), "Default page size should be 50"); @@ -725,7 +725,7 @@ public void testListTasksInvalidPageTokenFormat() { .build(); try { - taskStore.list(params1); + taskStore.list(params1, null); throw new AssertionError("Expected InvalidParamsError for legacy ID-only pageToken"); } catch (org.a2aproject.sdk.spec.InvalidParamsError e) { // Expected - legacy format not supported @@ -741,7 +741,7 @@ public void testListTasksInvalidPageTokenFormat() { .build(); try { - taskStore.list(params2); + taskStore.list(params2, null); throw new AssertionError("Expected InvalidParamsError for malformed timestamp"); } catch (org.a2aproject.sdk.spec.InvalidParamsError e) { // Expected - malformed timestamp @@ -786,7 +786,7 @@ public void testListTasksOrderingById() { .contextId("context-order") .tenant("tenant") .build(); - ListTasksResult result = taskStore.list(params); + ListTasksResult result = taskStore.list(params, null); assertEquals(3, result.tasks().size()); assertEquals("task-order-a", result.tasks().get(0).id()); diff --git a/reference/grpc/src/main/java/org/a2aproject/sdk/server/grpc/quarkus/QuarkusCallContextFactory.java b/reference/grpc/src/main/java/org/a2aproject/sdk/server/grpc/quarkus/QuarkusCallContextFactory.java new file mode 100644 index 000000000..b2fc8252a --- /dev/null +++ b/reference/grpc/src/main/java/org/a2aproject/sdk/server/grpc/quarkus/QuarkusCallContextFactory.java @@ -0,0 +1,96 @@ +package org.a2aproject.sdk.server.grpc.quarkus; + +import static org.a2aproject.sdk.server.ServerCallContext.TRANSPORT_KEY; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Instance; +import jakarta.inject.Inject; + +import io.grpc.Context; +import io.grpc.Metadata; +import io.grpc.stub.StreamObserver; +import io.quarkus.security.identity.SecurityIdentity; +import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.AuthenticatedUser; +import org.a2aproject.sdk.server.auth.UnauthenticatedUser; +import org.a2aproject.sdk.server.auth.User; +import org.a2aproject.sdk.server.extensions.A2AExtensions; +import org.a2aproject.sdk.spec.TransportProtocol; +import org.a2aproject.sdk.transport.grpc.context.GrpcContextKeys; +import org.a2aproject.sdk.transport.grpc.handler.CallContextFactory; + +@ApplicationScoped +public class QuarkusCallContextFactory implements CallContextFactory { + + @Inject + Instance securityIdentityInstance; + + @Override + public ServerCallContext create(StreamObserver responseObserver) { + User user; + if (securityIdentityInstance.isResolvable()) { + SecurityIdentity securityIdentity = securityIdentityInstance.get(); + if (!securityIdentity.isAnonymous()) { + user = new AuthenticatedUser(securityIdentity.getPrincipal().getName()); + } else { + user = UnauthenticatedUser.INSTANCE; + } + } else { + user = UnauthenticatedUser.INSTANCE; + } + + Map state = new HashMap<>(); + state.put(TRANSPORT_KEY, TransportProtocol.GRPC); + state.put("grpc_response_observer", responseObserver); + + Context currentContext = Context.current(); + if (currentContext != null) { + state.put("grpc_context", currentContext); + io.grpc.Metadata grpcMetadata = GrpcContextKeys.METADATA_KEY.get(currentContext); + if (grpcMetadata != null) { + state.put("grpc_metadata", grpcMetadata); + Map headers = new HashMap<>(); + for (String key : grpcMetadata.keys()) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + continue; + } + headers.put(key, grpcMetadata.get(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER))); + } + state.put("headers", headers); + } + String methodName = GrpcContextKeys.GRPC_METHOD_NAME_KEY.get(currentContext); + if (methodName != null) { + state.put("grpc_method_name", methodName); + } + String peerInfo = GrpcContextKeys.PEER_INFO_KEY.get(currentContext); + if (peerInfo != null) { + state.put("grpc_peer_info", peerInfo); + } + } + + String requestedVersion = null; + try { + requestedVersion = GrpcContextKeys.VERSION_HEADER_KEY.get(); + } catch (Exception e) { + // Context not available + } + + Set requestedExtensions = new HashSet<>(); + try { + String extensionsHeader = GrpcContextKeys.EXTENSIONS_HEADER_KEY.get(); + if (extensionsHeader != null) { + requestedExtensions = A2AExtensions.getRequestedExtensions(List.of(extensionsHeader)); + } + } catch (Exception e) { + // Context not available + } + + return new ServerCallContext(user, state, requestedExtensions, requestedVersion); + } +} diff --git a/reference/grpc/src/test/java/org/a2aproject/sdk/server/grpc/quarkus/QuarkusA2AGrpcWithTaskAuthorizationTest.java b/reference/grpc/src/test/java/org/a2aproject/sdk/server/grpc/quarkus/QuarkusA2AGrpcWithTaskAuthorizationTest.java new file mode 100644 index 000000000..41515683a --- /dev/null +++ b/reference/grpc/src/test/java/org/a2aproject/sdk/server/grpc/quarkus/QuarkusA2AGrpcWithTaskAuthorizationTest.java @@ -0,0 +1,73 @@ +package org.a2aproject.sdk.server.grpc.quarkus; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import jakarta.inject.Inject; +import org.a2aproject.sdk.client.ClientBuilder; +import org.a2aproject.sdk.client.transport.grpc.GrpcTransport; +import org.a2aproject.sdk.client.transport.grpc.GrpcTransportConfigBuilder; +import org.a2aproject.sdk.client.transport.spi.interceptors.auth.AuthInterceptor; +import org.a2aproject.sdk.server.PublicAgentCard; +import org.a2aproject.sdk.server.apps.common.AbstractA2AServerWithTaskAuthorizationTest; +import org.a2aproject.sdk.server.apps.common.TaskAuthorizationTestProfile; +import org.a2aproject.sdk.spec.AgentCard; +import org.a2aproject.sdk.spec.TransportProtocol; +import org.junit.jupiter.api.AfterAll; + +@QuarkusTest +@TestProfile(TaskAuthorizationTestProfile.class) +public class QuarkusA2AGrpcWithTaskAuthorizationTest extends AbstractA2AServerWithTaskAuthorizationTest { + + @Inject + @PublicAgentCard + AgentCard agentCard; + + private static final Map channels = new ConcurrentHashMap<>(); + + @Override + protected String getTransportProtocol() { + return TransportProtocol.GRPC.asString(); + } + + @Override + protected String getTransportUrl() { + return "localhost:8081"; + } + + @Override + protected AgentCard fetchAgentCardFromServer() { + return agentCard; + } + + @Override + protected void configureTransportWithCredentials(ClientBuilder builder, String username, String password) { + AuthInterceptor authInterceptor = new AuthInterceptor( + (schemeName, context) -> BASIC_AUTH_SCHEME_NAME.equals(schemeName) + ? getEncodedCredentials(username, password) : null); + builder.withTransport(GrpcTransport.class, new GrpcTransportConfigBuilder() + .channelFactory(target -> { + ManagedChannel channel = ManagedChannelBuilder.forTarget(target).usePlaintext().build(); + channels.put(username, channel); + return channel; + }) + .addInterceptor(authInterceptor)); + } + + @AfterAll + static void closeChannels() { + channels.values().forEach(ch -> { + ch.shutdownNow(); + try { + ch.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + } +} diff --git a/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/A2AServerRoutes.java b/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/A2AServerRoutes.java index 1f6bc0e23..41818baa5 100644 --- a/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/A2AServerRoutes.java +++ b/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/A2AServerRoutes.java @@ -63,6 +63,7 @@ import org.a2aproject.sdk.jsonrpc.common.wrappers.SubscribeToTaskRequest; import org.a2aproject.sdk.server.AgentCardCacheMetadata; import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.AuthenticatedUser; import org.a2aproject.sdk.server.auth.UnauthenticatedUser; import org.a2aproject.sdk.server.auth.User; import org.a2aproject.sdk.server.common.quarkus.SseResponseWriter; @@ -557,17 +558,8 @@ private ServerCallContext createCallContext(RoutingContext rc) { if (rc.user() == null) { user = UnauthenticatedUser.INSTANCE; } else { - user = new User() { - @Override - public boolean isAuthenticated() { - return rc.userContext().authenticated(); - } - - @Override - public String getUsername() { - return rc.user().subject(); - } - }; + String subject = rc.user().subject(); + user = new AuthenticatedUser(subject != null ? subject : ""); } Map state = new HashMap<>(); // TODO Python's impl has diff --git a/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/CallContextFactory.java b/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/CallContextFactory.java index e6c48b617..e9e82bb10 100644 --- a/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/CallContextFactory.java +++ b/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/CallContextFactory.java @@ -30,10 +30,7 @@ * public ServerCallContext build(RoutingContext rc) { * // Extract user from Quarkus security context * User user = (rc.user() == null) ? UnauthenticatedUser.INSTANCE : - * new User() { - * public boolean isAuthenticated() { return rc.userContext().authenticated(); } - * public String getUsername() { return rc.user().subject(); } - * }; + * new AuthenticatedUser(rc.user().subject()); * * // Extract custom data from routing context * String orgId = rc.request().getHeader("X-Organization-ID"); diff --git a/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/package-info.java b/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/package-info.java index de3d9e89c..b4b4501c5 100644 --- a/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/package-info.java +++ b/reference/jsonrpc/src/main/java/org/a2aproject/sdk/server/apps/quarkus/package-info.java @@ -126,10 +126,7 @@ * public ServerCallContext build(RoutingContext rc) { * // Extract user from Quarkus security context * User user = (rc.user() == null) ? UnauthenticatedUser.INSTANCE : - * new User() { - * public boolean isAuthenticated() { return rc.userContext().authenticated(); } - * public String getUsername() { return rc.user().subject(); } - * }; + * new AuthenticatedUser(rc.user().subject()); * * // Extract custom data from routing context * Map state = new HashMap<>(); diff --git a/reference/jsonrpc/src/test/java/org/a2aproject/sdk/server/apps/quarkus/QuarkusA2AJSONRPCWithTaskAuthorizationVertxTest.java b/reference/jsonrpc/src/test/java/org/a2aproject/sdk/server/apps/quarkus/QuarkusA2AJSONRPCWithTaskAuthorizationVertxTest.java new file mode 100644 index 000000000..e928224e4 --- /dev/null +++ b/reference/jsonrpc/src/test/java/org/a2aproject/sdk/server/apps/quarkus/QuarkusA2AJSONRPCWithTaskAuthorizationVertxTest.java @@ -0,0 +1,43 @@ +package org.a2aproject.sdk.server.apps.quarkus; + +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import io.vertx.core.Vertx; +import jakarta.inject.Inject; +import org.a2aproject.sdk.client.ClientBuilder; +import org.a2aproject.sdk.client.http.VertxA2AHttpClient; +import org.a2aproject.sdk.client.transport.jsonrpc.JSONRPCTransport; +import org.a2aproject.sdk.client.transport.jsonrpc.JSONRPCTransportConfigBuilder; +import org.a2aproject.sdk.client.transport.spi.interceptors.auth.AuthInterceptor; +import org.a2aproject.sdk.server.apps.common.AbstractA2AServerWithTaskAuthorizationTest; +import org.a2aproject.sdk.server.apps.common.TaskAuthorizationTestProfile; +import org.a2aproject.sdk.spec.TransportProtocol; + +@QuarkusTest +@TestProfile(TaskAuthorizationTestProfile.class) +public class QuarkusA2AJSONRPCWithTaskAuthorizationVertxTest extends AbstractA2AServerWithTaskAuthorizationTest { + + @Inject + Vertx vertx; + + @Override + protected String getTransportProtocol() { + return TransportProtocol.JSONRPC.asString(); + } + + @Override + protected String getTransportUrl() { + return "http://localhost:8081"; + } + + @Override + protected void configureTransportWithCredentials(ClientBuilder builder, String username, String password) { + AuthInterceptor authInterceptor = new AuthInterceptor( + (schemeName, context) -> BASIC_AUTH_SCHEME_NAME.equals(schemeName) + ? getEncodedCredentials(username, password) : null); + builder.withTransport(JSONRPCTransport.class, + new JSONRPCTransportConfigBuilder() + .httpClient(new VertxA2AHttpClient(vertx)) + .addInterceptor(authInterceptor)); + } +} diff --git a/reference/rest/src/main/java/org/a2aproject/sdk/server/rest/quarkus/A2AServerRoutes.java b/reference/rest/src/main/java/org/a2aproject/sdk/server/rest/quarkus/A2AServerRoutes.java index 1fc25b08d..46eff8c37 100644 --- a/reference/rest/src/main/java/org/a2aproject/sdk/server/rest/quarkus/A2AServerRoutes.java +++ b/reference/rest/src/main/java/org/a2aproject/sdk/server/rest/quarkus/A2AServerRoutes.java @@ -27,6 +27,7 @@ import org.a2aproject.sdk.common.A2AHeaders; import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.AuthenticatedUser; import org.a2aproject.sdk.server.auth.UnauthenticatedUser; import org.a2aproject.sdk.server.auth.User; import org.a2aproject.sdk.server.extensions.A2AExtensions; @@ -888,23 +889,8 @@ private ServerCallContext createCallContext(RoutingContext rc, String jsonRpcMet if (rc.user() == null) { user = UnauthenticatedUser.INSTANCE; } else { - user = new User() { - @Override - public boolean isAuthenticated() { - if (rc.userContext() != null) { - return rc.userContext().authenticated(); - } - return false; - } - - @Override - public String getUsername() { - if (rc.user() != null && rc.user().subject() != null) { - return rc.user().subject(); - } - return ""; - } - }; + String subject = rc.user().subject(); + user = new AuthenticatedUser(subject != null ? subject : ""); } Map state = new HashMap<>(); diff --git a/reference/rest/src/test/java/org/a2aproject/sdk/server/rest/quarkus/QuarkusA2ARestWithTaskAuthorizationVertxTest.java b/reference/rest/src/test/java/org/a2aproject/sdk/server/rest/quarkus/QuarkusA2ARestWithTaskAuthorizationVertxTest.java new file mode 100644 index 000000000..e2809a6cf --- /dev/null +++ b/reference/rest/src/test/java/org/a2aproject/sdk/server/rest/quarkus/QuarkusA2ARestWithTaskAuthorizationVertxTest.java @@ -0,0 +1,43 @@ +package org.a2aproject.sdk.server.rest.quarkus; + +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import io.vertx.core.Vertx; +import jakarta.inject.Inject; +import org.a2aproject.sdk.client.ClientBuilder; +import org.a2aproject.sdk.client.http.VertxA2AHttpClient; +import org.a2aproject.sdk.client.transport.rest.RestTransport; +import org.a2aproject.sdk.client.transport.rest.RestTransportConfigBuilder; +import org.a2aproject.sdk.client.transport.spi.interceptors.auth.AuthInterceptor; +import org.a2aproject.sdk.server.apps.common.AbstractA2AServerWithTaskAuthorizationTest; +import org.a2aproject.sdk.server.apps.common.TaskAuthorizationTestProfile; +import org.a2aproject.sdk.spec.TransportProtocol; + +@QuarkusTest +@TestProfile(TaskAuthorizationTestProfile.class) +public class QuarkusA2ARestWithTaskAuthorizationVertxTest extends AbstractA2AServerWithTaskAuthorizationTest { + + @Inject + Vertx vertx; + + @Override + protected String getTransportProtocol() { + return TransportProtocol.HTTP_JSON.asString(); + } + + @Override + protected String getTransportUrl() { + return "http://localhost:8081"; + } + + @Override + protected void configureTransportWithCredentials(ClientBuilder builder, String username, String password) { + AuthInterceptor authInterceptor = new AuthInterceptor( + (schemeName, context) -> BASIC_AUTH_SCHEME_NAME.equals(schemeName) + ? getEncodedCredentials(username, password) : null); + builder.withTransport(RestTransport.class, + new RestTransportConfigBuilder() + .httpClient(new VertxA2AHttpClient(vertx)) + .addInterceptor(authInterceptor)); + } +} diff --git a/server-common/pom.xml b/server-common/pom.xml index a117dcc23..69513abc9 100644 --- a/server-common/pom.xml +++ b/server-common/pom.xml @@ -92,6 +92,11 @@ mockito-core test + + org.mockito + mockito-junit-jupiter + test + ch.qos.logback logback-classic diff --git a/server-common/src/main/java/org/a2aproject/sdk/server/auth/AuthenticatedUser.java b/server-common/src/main/java/org/a2aproject/sdk/server/auth/AuthenticatedUser.java new file mode 100644 index 000000000..255b2020a --- /dev/null +++ b/server-common/src/main/java/org/a2aproject/sdk/server/auth/AuthenticatedUser.java @@ -0,0 +1,19 @@ +package org.a2aproject.sdk.server.auth; + +import org.a2aproject.sdk.util.Assert; + +public record AuthenticatedUser(String username) implements User { + public AuthenticatedUser { + Assert.checkNotNullParam("username", username); + } + + @Override + public boolean isAuthenticated() { + return true; + } + + @Override + public String getUsername() { + return username; + } +} diff --git a/server-common/src/main/java/org/a2aproject/sdk/server/auth/TaskAuthorizationProvider.java b/server-common/src/main/java/org/a2aproject/sdk/server/auth/TaskAuthorizationProvider.java new file mode 100644 index 000000000..8387001c3 --- /dev/null +++ b/server-common/src/main/java/org/a2aproject/sdk/server/auth/TaskAuthorizationProvider.java @@ -0,0 +1,159 @@ +package org.a2aproject.sdk.server.auth; + +import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.spec.A2AError; + +/** + * SPI for per-user task authorization. + *

+ * Implementers provide a CDI bean ({@code @ApplicationScoped}) implementing this interface + * to control which users can read, write, or create tasks. When no implementation is provided, + * all operations are permitted. + * + *

Providing an implementation

+ *

+ * Create an {@code @ApplicationScoped} CDI bean that implements this interface. The SDK + * automatically discovers it and wires it into the request pipeline — no additional + * configuration is required. + * + *

{@code
+ * @ApplicationScoped
+ * public class MyTaskAuthorizationProvider implements TaskAuthorizationProvider {
+ *
+ *     @Override
+ *     public boolean checkRead(ServerCallContext context, String taskId, TaskOperation op) {
+ *         User user = context.getUser();
+ *         // look up ownership in your backing store
+ *         return isOwner(user, taskId);
+ *     }
+ *
+ *     @Override
+ *     public boolean checkWrite(ServerCallContext context, String taskId, TaskOperation op) {
+ *         return checkRead(context, taskId, op); // same rule
+ *     }
+ *
+ *     @Override
+ *     public boolean checkCreate(ServerCallContext context, TaskOperation op) {
+ *         return context.getUser().isAuthenticated();
+ *     }
+ *
+ *     @Override
+ *     public boolean isTaskRecorded(String taskId) {
+ *         return ownershipStore.contains(taskId);
+ *     }
+ *
+ *     @Override
+ *     public void recordOwnership(ServerCallContext context, String taskId, TaskOperation op) {
+ *         ownershipStore.put(taskId, context.getUser().getUsername());
+ *     }
+ * }
+ * }
+ * + *

Behavior

+ *

+ * When a provider is present, the SDK enforces authorization as follows: + *

    + *
  • {@code onGetTask}, {@code onSubscribeToTask}, {@code onGetTaskPushNotificationConfig}, + * {@code onListTaskPushNotificationConfigs} — call {@link #checkRead}
  • + *
  • {@code onCancelTask}, {@code onCreateTaskPushNotificationConfig}, + * {@code onDeleteTaskPushNotificationConfig} — call {@link #checkWrite}
  • + *
  • {@code onMessageSend}, {@code onMessageSendStream} — call {@link #checkWrite} if an + * existing task ID is provided, otherwise call {@link #checkCreate}; after the delegate + * returns, call {@link #recordOwnership} if a new task was created
  • + *
  • {@code onListTasks} — filtering is pushed down to the {@code TaskStore}, which calls + * {@link #checkRead} per task to exclude unauthorized entries
  • + *
+ * Denied operations throw {@code TaskNotFoundError} — the caller cannot distinguish + * "does not exist" from "not authorized", preventing information leakage. + * + *

Thread safety

+ *

+ * Implementations must be thread-safe. Methods will be called concurrently from multiple + * requests. + * + *

Ownership recording

+ *

+ * {@link #recordOwnership} is only triggered by {@code onMessageSend} and + * {@code onMessageSendStream} — the methods that can create tasks. + * Other methods ({@code onGetTask}, {@code onCancelTask}, etc.) do not trigger recording. + * {@link #checkRead}/{@link #checkWrite} may be called for tasks the provider has no ownership + * data for (e.g., legacy tasks created before the provider was enabled). For production + * deployments, a fail-closed policy is recommended: deny access when no ownership + * data exists. An {@code owner == null → allow} policy is only appropriate for testing or + * single-user deployments. If enabling the provider on an existing deployment, consider a + * migration step to backfill ownership for pre-existing tasks. + * + *

Common pitfalls

+ *
    + *
  • TOCTOU race on ownership recording: The {@link #isTaskRecorded} → + * {@link #recordOwnership} sequence is not atomic. Two concurrent {@code onMessageSend} + * calls for the same new task can both see {@code isTaskRecorded()} return {@code false} + * and both call {@code recordOwnership}. Implementations must use atomic-insert patterns + * (e.g., {@code ConcurrentMap.putIfAbsent}, {@code INSERT ... ON CONFLICT DO NOTHING}) + * so the first writer wins and the second is a harmless no-op.
  • + *
  • CDI injection requirement: When task authorization is required, always obtain + * {@code RequestHandler} through CDI injection. Manual instantiation via + * {@code DefaultRequestHandler.create()} bypasses the + * {@code AuthorizationRequestHandlerDecorator}.
  • + *
+ * + * @see TaskOperation + */ +public interface TaskAuthorizationProvider { + + /** + * Check whether the current user is allowed to read the given task. + * + * @param context the server call context containing the authenticated user + * @param taskId the task being accessed + * @param operation which RequestHandler method triggered the check + * @return {@code true} to allow, {@code false} to deny + * @throws A2AError if the authorization check itself fails + */ + boolean checkRead(ServerCallContext context, String taskId, TaskOperation operation) throws A2AError; + + /** + * Check whether the current user is allowed to write to the given task. + * + * @param context the server call context containing the authenticated user + * @param taskId the task being accessed + * @param operation which RequestHandler method triggered the check + * @return {@code true} to allow, {@code false} to deny + * @throws A2AError if the authorization check itself fails + */ + boolean checkWrite(ServerCallContext context, String taskId, TaskOperation operation) throws A2AError; + + /** + * Check whether the current user is allowed to create a new task. + * + * @param context the server call context containing the authenticated user + * @param operation which RequestHandler method triggered the check + * @return {@code true} to allow, {@code false} to deny + * @throws A2AError if the authorization check itself fails + */ + boolean checkCreate(ServerCallContext context, TaskOperation operation) throws A2AError; + + /** + * Check whether the given task is already known to this provider. + * Used to avoid redundant {@link #recordOwnership} calls. + * + * @param taskId the task to check + * @return {@code true} if ownership has already been recorded for this task + * @throws A2AError if the check itself fails + */ + boolean isTaskRecorded(String taskId) throws A2AError; + + /** + * Record that the current user owns the given task. Called after task creation + * via {@code onMessageSend} or {@code onMessageSendStream}. + *

+ * Must be idempotent. Concurrent requests for the same unrecorded task may both + * call this method before either completes. + * + * @param context the server call context containing the authenticated user + * @param taskId the newly created task + * @param operation which RequestHandler method triggered the recording + * @throws A2AError if recording fails + */ + void recordOwnership(ServerCallContext context, String taskId, TaskOperation operation) throws A2AError; +} diff --git a/server-common/src/main/java/org/a2aproject/sdk/server/auth/TaskOperation.java b/server-common/src/main/java/org/a2aproject/sdk/server/auth/TaskOperation.java new file mode 100644 index 000000000..ca0e76fe6 --- /dev/null +++ b/server-common/src/main/java/org/a2aproject/sdk/server/auth/TaskOperation.java @@ -0,0 +1,18 @@ +package org.a2aproject.sdk.server.auth; + +/** + * Identifies which {@link org.a2aproject.sdk.server.requesthandlers.RequestHandler} operation + * triggered an authorization check. + */ +public enum TaskOperation { + GET_TASK, + LIST_TASKS, + CANCEL_TASK, + MESSAGE_SEND, + MESSAGE_SEND_STREAM, + SUBSCRIBE_TO_TASK, + CREATE_TASK_PUSH_NOTIFICATION_CONFIG, + GET_TASK_PUSH_NOTIFICATION_CONFIG, + LIST_TASK_PUSH_NOTIFICATION_CONFIGS, + DELETE_TASK_PUSH_NOTIFICATION_CONFIG +} diff --git a/server-common/src/main/java/org/a2aproject/sdk/server/auth/User.java b/server-common/src/main/java/org/a2aproject/sdk/server/auth/User.java index 4d2e56185..fe1c439a2 100644 --- a/server-common/src/main/java/org/a2aproject/sdk/server/auth/User.java +++ b/server-common/src/main/java/org/a2aproject/sdk/server/auth/User.java @@ -1,6 +1,12 @@ package org.a2aproject.sdk.server.auth; +import org.jspecify.annotations.Nullable; + public interface User { boolean isAuthenticated(); String getUsername(); + + default @Nullable Object getAttribute(String key) { + return null; + } } diff --git a/server-common/src/main/java/org/a2aproject/sdk/server/requesthandlers/AuthorizationRequestHandlerDecorator.java b/server-common/src/main/java/org/a2aproject/sdk/server/requesthandlers/AuthorizationRequestHandlerDecorator.java new file mode 100644 index 000000000..0e75c474c --- /dev/null +++ b/server-common/src/main/java/org/a2aproject/sdk/server/requesthandlers/AuthorizationRequestHandlerDecorator.java @@ -0,0 +1,254 @@ +package org.a2aproject.sdk.server.requesthandlers; + +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicBoolean; + +import jakarta.annotation.PostConstruct; +import jakarta.annotation.Priority; +import jakarta.decorator.Decorator; +import jakarta.decorator.Delegate; +import jakarta.enterprise.inject.Any; +import jakarta.enterprise.inject.Instance; +import jakarta.inject.Inject; + +import org.a2aproject.sdk.jsonrpc.common.wrappers.ListTasksResult; +import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.TaskAuthorizationProvider; +import org.a2aproject.sdk.spec.A2AError; +import org.a2aproject.sdk.spec.CancelTaskParams; +import org.a2aproject.sdk.spec.DeleteTaskPushNotificationConfigParams; +import org.a2aproject.sdk.spec.EventKind; +import org.a2aproject.sdk.spec.GetTaskPushNotificationConfigParams; +import org.a2aproject.sdk.spec.ListTaskPushNotificationConfigsParams; +import org.a2aproject.sdk.spec.ListTaskPushNotificationConfigsResult; +import org.a2aproject.sdk.spec.ListTasksParams; +import org.a2aproject.sdk.spec.MessageSendParams; +import org.a2aproject.sdk.spec.StreamingEventKind; +import org.a2aproject.sdk.spec.Task; +import org.a2aproject.sdk.spec.TaskArtifactUpdateEvent; +import org.a2aproject.sdk.spec.TaskIdParams; +import org.a2aproject.sdk.spec.TaskNotFoundError; +import org.a2aproject.sdk.spec.TaskPushNotificationConfig; +import org.a2aproject.sdk.spec.TaskQueryParams; +import org.a2aproject.sdk.spec.TaskStatusUpdateEvent; +import org.a2aproject.sdk.server.auth.TaskOperation; +import org.jspecify.annotations.Nullable; + +@Decorator +@Priority(50) +public class AuthorizationRequestHandlerDecorator implements RequestHandler { + + @Inject + @Delegate + @Any + private RequestHandler delegate; + + @Inject + @Any + Instance authorizationProviderInstance; + + private @Nullable TaskAuthorizationProvider authorizationProvider; + + public AuthorizationRequestHandlerDecorator() { + } + + AuthorizationRequestHandlerDecorator(RequestHandler delegate, + @Nullable TaskAuthorizationProvider authorizationProvider) { + this.delegate = delegate; + this.authorizationProvider = authorizationProvider; + } + + @PostConstruct + void init() { + if (authorizationProviderInstance != null) { + authorizationProvider = authorizationProviderInstance.isResolvable() + ? authorizationProviderInstance.get() + : null; + } + } + + private Flow.Publisher wrapPublisherForOwnership( + Flow.Publisher publisher, + ServerCallContext context, + TaskOperation operation, + TaskAuthorizationProvider provider) { + return subscriber -> publisher.subscribe(new Flow.Subscriber<>() { + private final AtomicBoolean ownershipChecked = new AtomicBoolean(false); + private final AtomicBoolean done = new AtomicBoolean(false); + @SuppressWarnings("NullAway") + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription s) { + this.subscription = s; + subscriber.onSubscribe(s); + } + + @Override + public void onNext(StreamingEventKind event) { + if (done.get()) { + return; + } + if (!ownershipChecked.get()) { + String taskId = extractTaskId(event); + if (taskId != null) { + ownershipChecked.set(true); + try { + if (!provider.isTaskRecorded(taskId)) { + provider.recordOwnership(context, taskId, operation); + } + } catch (A2AError e) { + done.set(true); + subscription.cancel(); + subscriber.onError(e); + return; + } + } + } + subscriber.onNext(event); + } + + @Override + public void onError(Throwable t) { + if (done.compareAndSet(false, true)) { + subscriber.onError(t); + } + } + + @Override + public void onComplete() { + if (done.compareAndSet(false, true)) { + subscriber.onComplete(); + } + } + }); + } + + private void enforceCreate(ServerCallContext context, TaskOperation operation) throws A2AError { + if (authorizationProvider != null && !authorizationProvider.checkCreate(context, operation)) { + throw new TaskNotFoundError(); + } + } + + private @Nullable String extractTaskId(Object event) { + if (event instanceof Task task) { + return task.id(); + } else if (event instanceof TaskStatusUpdateEvent e) { + return e.taskId(); + } else if (event instanceof TaskArtifactUpdateEvent e) { + return e.taskId(); + } + return null; + } + + private void recordOwnershipIfNeeded(ServerCallContext context, @Nullable String taskId, + TaskOperation operation) throws A2AError { + if (authorizationProvider != null && taskId != null && !authorizationProvider.isTaskRecorded(taskId)) { + authorizationProvider.recordOwnership(context, taskId, operation); + } + } + + private void enforceWrite(ServerCallContext context, String taskId, TaskOperation operation) throws A2AError { + if (authorizationProvider != null && !authorizationProvider.checkWrite(context, taskId, operation)) { + throw new TaskNotFoundError(); + } + } + + private void enforceRead(ServerCallContext context, String taskId, TaskOperation operation) throws A2AError { + if (authorizationProvider != null && !authorizationProvider.checkRead(context, taskId, operation)) { + throw new TaskNotFoundError(); + } + } + + @Override + public Task onGetTask(TaskQueryParams params, ServerCallContext context) throws A2AError { + enforceRead(context, params.id(), TaskOperation.GET_TASK); + return delegate.onGetTask(params, context); + } + + @Override + public ListTasksResult onListTasks(ListTasksParams params, ServerCallContext context) throws A2AError { + return delegate.onListTasks(params, context); + } + + @Override + public Task onCancelTask(CancelTaskParams params, ServerCallContext context) throws A2AError { + enforceWrite(context, params.id(), TaskOperation.CANCEL_TASK); + return delegate.onCancelTask(params, context); + } + + @Override + public EventKind onMessageSend(MessageSendParams params, ServerCallContext context) throws A2AError { + String taskId = params.message().taskId(); + if (taskId != null) { + enforceWrite(context, taskId, TaskOperation.MESSAGE_SEND); + } else { + enforceCreate(context, TaskOperation.MESSAGE_SEND); + } + EventKind result = delegate.onMessageSend(params, context); + String resultTaskId = extractTaskId(result); + recordOwnershipIfNeeded(context, resultTaskId, TaskOperation.MESSAGE_SEND); + return result; + } + + @Override + public Flow.Publisher onMessageSendStream(MessageSendParams params, + ServerCallContext context) throws A2AError { + String taskId = params.message().taskId(); + if (taskId != null) { + enforceWrite(context, taskId, TaskOperation.MESSAGE_SEND_STREAM); + } else { + enforceCreate(context, TaskOperation.MESSAGE_SEND_STREAM); + } + Flow.Publisher publisher = delegate.onMessageSendStream(params, context); + if (authorizationProvider != null) { + publisher = wrapPublisherForOwnership(publisher, context, TaskOperation.MESSAGE_SEND_STREAM, + authorizationProvider); + } + return publisher; + } + + @Override + public TaskPushNotificationConfig onCreateTaskPushNotificationConfig(TaskPushNotificationConfig params, + ServerCallContext context) throws A2AError { + String taskId = params.taskId(); + if (taskId != null) { + enforceWrite(context, taskId, TaskOperation.CREATE_TASK_PUSH_NOTIFICATION_CONFIG); + } + // taskId is required by the spec; if null, the delegate will reject with InvalidParamsError + return delegate.onCreateTaskPushNotificationConfig(params, context); + } + + @Override + public TaskPushNotificationConfig onGetTaskPushNotificationConfig(GetTaskPushNotificationConfigParams params, + ServerCallContext context) throws A2AError { + enforceRead(context, params.taskId(), TaskOperation.GET_TASK_PUSH_NOTIFICATION_CONFIG); + return delegate.onGetTaskPushNotificationConfig(params, context); + } + + @Override + public Flow.Publisher onSubscribeToTask(TaskIdParams params, + ServerCallContext context) throws A2AError { + enforceRead(context, params.id(), TaskOperation.SUBSCRIBE_TO_TASK); + return delegate.onSubscribeToTask(params, context); + } + + @Override + public ListTaskPushNotificationConfigsResult onListTaskPushNotificationConfigs( + ListTaskPushNotificationConfigsParams params, ServerCallContext context) throws A2AError { + enforceRead(context, params.id(), TaskOperation.LIST_TASK_PUSH_NOTIFICATION_CONFIGS); + return delegate.onListTaskPushNotificationConfigs(params, context); + } + + @Override + public void onDeleteTaskPushNotificationConfig(DeleteTaskPushNotificationConfigParams params, + ServerCallContext context) throws A2AError { + enforceWrite(context, params.taskId(), TaskOperation.DELETE_TASK_PUSH_NOTIFICATION_CONFIG); + delegate.onDeleteTaskPushNotificationConfig(params, context); + } + + @Override + public void validateRequestedTask(@Nullable String requestedTaskId) throws A2AError { + delegate.validateRequestedTask(requestedTaskId); + } +} diff --git a/server-common/src/main/java/org/a2aproject/sdk/server/requesthandlers/DefaultRequestHandler.java b/server-common/src/main/java/org/a2aproject/sdk/server/requesthandlers/DefaultRequestHandler.java index 8287a07a8..58f31f581 100644 --- a/server-common/src/main/java/org/a2aproject/sdk/server/requesthandlers/DefaultRequestHandler.java +++ b/server-common/src/main/java/org/a2aproject/sdk/server/requesthandlers/DefaultRequestHandler.java @@ -353,7 +353,7 @@ public ListTasksResult onListTasks(ListTasksParams params, ServerCallContext con } } - ListTasksResult result = taskStore.list(params); + ListTasksResult result = taskStore.list(params, context); LOGGER.debug("Found {} tasks (total: {})", result.pageSize(), result.totalSize()); return result; } diff --git a/server-common/src/main/java/org/a2aproject/sdk/server/tasks/InMemoryTaskStore.java b/server-common/src/main/java/org/a2aproject/sdk/server/tasks/InMemoryTaskStore.java index 7bce70841..ddfbde45b 100644 --- a/server-common/src/main/java/org/a2aproject/sdk/server/tasks/InMemoryTaskStore.java +++ b/server-common/src/main/java/org/a2aproject/sdk/server/tasks/InMemoryTaskStore.java @@ -6,8 +6,14 @@ import java.util.concurrent.ConcurrentMap; import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Any; +import jakarta.enterprise.inject.Instance; +import jakarta.inject.Inject; import org.a2aproject.sdk.jsonrpc.common.wrappers.ListTasksResult; +import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.TaskAuthorizationProvider; +import org.a2aproject.sdk.server.auth.TaskOperation; import org.a2aproject.sdk.spec.Artifact; import org.a2aproject.sdk.spec.ListTasksParams; import org.a2aproject.sdk.spec.Message; @@ -67,7 +73,7 @@ * .statusTimestampBefore(Instant.now().minus(Duration.ofHours(48))) * .build(); * - * List oldTasks = taskStore.list(params).tasks(); + * List oldTasks = taskStore.list(params, context).tasks(); * oldTasks.stream() * .filter(task -> task.status().state().isFinal()) * .forEach(task -> taskStore.delete(task.id())); @@ -86,6 +92,22 @@ public class InMemoryTaskStore implements TaskStore, TaskStateProvider { private final ConcurrentMap tasks = new ConcurrentHashMap<>(); + private final @Nullable TaskAuthorizationProvider authorizationProvider; + + public InMemoryTaskStore() { + this.authorizationProvider = null; + } + + @Inject + public InMemoryTaskStore(@Any Instance authorizationProviderInstance) { + this.authorizationProvider = authorizationProviderInstance.isResolvable() + ? authorizationProviderInstance.get() + : null; + } + + InMemoryTaskStore(@Nullable TaskAuthorizationProvider authorizationProvider) { + this.authorizationProvider = authorizationProvider; + } @Override public void save(Task task, boolean isReplicated) { @@ -104,7 +126,7 @@ public void delete(String taskId) { } @Override - public ListTasksResult list(ListTasksParams params) { + public ListTasksResult list(ListTasksParams params, @Nullable ServerCallContext context) { // Filter and sort tasks in a single stream pipeline List allFilteredTasks = tasks.values().stream() .filter(task -> params.contextId() == null || params.contextId().equals(task.contextId())) @@ -114,6 +136,8 @@ public ListTasksResult list(ListTasksParams params) { (task.status() != null && task.status().timestamp() != null && task.status().timestamp().toInstant().isAfter(params.statusTimestampAfter()))) + .filter(task -> authorizationProvider == null || context == null || + authorizationProvider.checkRead(context, task.id(), TaskOperation.LIST_TASKS)) .sorted(Comparator.comparing( (Task t) -> (t.status() != null && t.status().timestamp() != null) // Truncate to milliseconds for consistency with pageToken precision diff --git a/server-common/src/main/java/org/a2aproject/sdk/server/tasks/TaskStore.java b/server-common/src/main/java/org/a2aproject/sdk/server/tasks/TaskStore.java index 86a6b76cb..6f798cc7d 100644 --- a/server-common/src/main/java/org/a2aproject/sdk/server/tasks/TaskStore.java +++ b/server-common/src/main/java/org/a2aproject/sdk/server/tasks/TaskStore.java @@ -1,6 +1,7 @@ package org.a2aproject.sdk.server.tasks; import org.a2aproject.sdk.jsonrpc.common.wrappers.ListTasksResult; +import org.a2aproject.sdk.server.ServerCallContext; import org.a2aproject.sdk.spec.ListTasksParams; import org.a2aproject.sdk.spec.Task; import org.jspecify.annotations.Nullable; @@ -82,7 +83,7 @@ * conflicts appropriately (last-write-wins, optimistic locking, etc.). * *

List Operation Performance

- * The {@link #list(org.a2aproject.sdk.spec.ListTasksParams)} method may need to scan and filter + * The {@link #list(ListTasksParams, ServerCallContext)} method may need to scan and filter * many tasks. Database implementations should: *
    *
  • Use indexes on contextId, status, lastUpdatedAt
  • @@ -189,8 +190,23 @@ public interface TaskStore { /** * List tasks with optional filtering and pagination. + *

    + * Authorization filtering: When a + * {@link org.a2aproject.sdk.server.auth.TaskAuthorizationProvider TaskAuthorizationProvider} + * bean is present, implementations must call + * {@link org.a2aproject.sdk.server.auth.TaskAuthorizationProvider#checkRead checkRead} for + * each candidate task and exclude tasks for which the check returns {@code false}. + * The filtering should be applied before pagination so that page sizes are correct + * from the caller's perspective. If no provider is present, all tasks are returned. + *

    + * ⚠ Custom implementation warning: Returning unfiltered results bypasses the + * authorization model and can leak tasks belonging to other users. Custom implementations + * must apply per-task {@code checkRead} filtering before pagination. The + * {@code TaskAuthorizationProvider} should be declared as a CDI dependency and injected + * via the constructor or {@code @Inject}. * * @param params the filtering and pagination parameters + * @param context the server call context (used for authorization filtering) * @return the list of tasks matching the criteria with pagination info * @throws TaskSerializationException if any persisted task data cannot be deserialized during listing * (corrupted JSON in database) @@ -198,5 +214,5 @@ public interface TaskStore { * connection error) * @throws TaskStoreException for other listing failures not covered by specific subclasses */ - ListTasksResult list(ListTasksParams params); + ListTasksResult list(ListTasksParams params, @Nullable ServerCallContext context); } diff --git a/server-common/src/main/java/org/a2aproject/sdk/server/tasks/TaskStoreException.java b/server-common/src/main/java/org/a2aproject/sdk/server/tasks/TaskStoreException.java index b4146b91e..d3f127a9a 100644 --- a/server-common/src/main/java/org/a2aproject/sdk/server/tasks/TaskStoreException.java +++ b/server-common/src/main/java/org/a2aproject/sdk/server/tasks/TaskStoreException.java @@ -19,7 +19,7 @@ *

  • {@code save(Task, boolean)} - Task persistence failures
  • *
  • {@code get(String)} - Task retrieval failures
  • *
  • {@code delete(String)} - Task deletion failures
  • - *
  • {@code list(ListTasksParams)} - Task listing failures
  • + *
  • {@code list(ListTasksParams, ServerCallContext)} - Task listing failures
  • *
* *

Error Handling Pattern

diff --git a/server-common/src/test/java/org/a2aproject/sdk/server/requesthandlers/AuthorizationRequestHandlerDecoratorTest.java b/server-common/src/test/java/org/a2aproject/sdk/server/requesthandlers/AuthorizationRequestHandlerDecoratorTest.java new file mode 100644 index 000000000..901a92b1b --- /dev/null +++ b/server-common/src/test/java/org/a2aproject/sdk/server/requesthandlers/AuthorizationRequestHandlerDecoratorTest.java @@ -0,0 +1,475 @@ +package org.a2aproject.sdk.server.requesthandlers; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Flow; +import java.util.concurrent.TimeUnit; + +import org.a2aproject.sdk.jsonrpc.common.wrappers.ListTasksResult; +import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.TaskAuthorizationProvider; +import org.a2aproject.sdk.server.auth.TaskOperation; +import org.a2aproject.sdk.spec.A2AError; +import org.a2aproject.sdk.spec.CancelTaskParams; +import org.a2aproject.sdk.spec.DeleteTaskPushNotificationConfigParams; +import org.a2aproject.sdk.spec.EventKind; +import org.a2aproject.sdk.spec.StreamingEventKind; +import org.a2aproject.sdk.spec.GetTaskPushNotificationConfigParams; +import org.a2aproject.sdk.spec.ListTaskPushNotificationConfigsParams; +import org.a2aproject.sdk.spec.ListTasksParams; +import org.a2aproject.sdk.spec.Message; +import org.a2aproject.sdk.spec.MessageSendParams; +import org.a2aproject.sdk.spec.Task; +import org.a2aproject.sdk.spec.TaskIdParams; +import org.a2aproject.sdk.spec.TaskNotFoundError; +import org.a2aproject.sdk.spec.TaskPushNotificationConfig; +import org.a2aproject.sdk.spec.TaskQueryParams; +import org.a2aproject.sdk.spec.TaskState; +import org.a2aproject.sdk.spec.TaskStatus; +import org.a2aproject.sdk.spec.TaskStatusUpdateEvent; +import org.a2aproject.sdk.spec.TextPart; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class AuthorizationRequestHandlerDecoratorTest { + + @Mock + private RequestHandler delegate; + + @Mock + private ServerCallContext context; + + @Mock + private TaskAuthorizationProvider authorizationProvider; + + private static Task testTask(String id) { + return Task.builder() + .id(id) + .contextId("ctx-1") + .status(new TaskStatus(TaskState.TASK_STATE_COMPLETED)) + .history(Collections.emptyList()) + .artifacts(Collections.emptyList()) + .build(); + } + + @Nested + class NoProviderTests { + + private AuthorizationRequestHandlerDecorator decorator; + + @BeforeEach + void setUp() { + decorator = new AuthorizationRequestHandlerDecorator(delegate, null); + } + + @Test + void onGetTask_delegatesWithoutChecks() throws A2AError { + TaskQueryParams params = new TaskQueryParams("task-1", null, null); + Task expected = testTask("task-1"); + when(delegate.onGetTask(params, context)).thenReturn(expected); + + Task result = decorator.onGetTask(params, context); + + assertEquals(expected, result); + verify(delegate).onGetTask(params, context); + } + + @Test + void onListTasks_delegatesWithoutChecks() throws A2AError { + ListTasksParams params = new ListTasksParams(); + ListTasksResult expected = new ListTasksResult(Collections.emptyList(), 0, 0, null); + when(delegate.onListTasks(params, context)).thenReturn(expected); + + ListTasksResult result = decorator.onListTasks(params, context); + + assertEquals(expected, result); + verify(delegate).onListTasks(params, context); + } + + @Test + void onCancelTask_delegatesWithoutChecks() throws A2AError { + CancelTaskParams params = new CancelTaskParams("task-1"); + Task expected = testTask("task-1"); + when(delegate.onCancelTask(params, context)).thenReturn(expected); + + Task result = decorator.onCancelTask(params, context); + + assertEquals(expected, result); + verify(delegate).onCancelTask(params, context); + } + } + + @Nested + class CheckReadTests { + + private AuthorizationRequestHandlerDecorator decorator; + + @BeforeEach + void setUp() { + decorator = new AuthorizationRequestHandlerDecorator(delegate, authorizationProvider); + } + + @Test + void onGetTask_allowed() throws A2AError { + TaskQueryParams params = new TaskQueryParams("task-1", null, null); + Task expected = testTask("task-1"); + when(authorizationProvider.checkRead(context, "task-1", TaskOperation.GET_TASK)).thenReturn(true); + when(delegate.onGetTask(params, context)).thenReturn(expected); + + Task result = decorator.onGetTask(params, context); + + assertEquals(expected, result); + } + + @Test + void onGetTask_denied() throws A2AError { + TaskQueryParams params = new TaskQueryParams("task-1", null, null); + when(authorizationProvider.checkRead(context, "task-1", TaskOperation.GET_TASK)).thenReturn(false); + + assertThrows(TaskNotFoundError.class, () -> decorator.onGetTask(params, context)); + verifyNoInteractions(delegate); + } + + @Test + void onSubscribeToTask_denied() throws A2AError { + TaskIdParams params = new TaskIdParams("task-1"); + when(authorizationProvider.checkRead(context, "task-1", TaskOperation.SUBSCRIBE_TO_TASK)).thenReturn(false); + + assertThrows(TaskNotFoundError.class, () -> decorator.onSubscribeToTask(params, context)); + verifyNoInteractions(delegate); + } + + @Test + void onGetTaskPushNotificationConfig_denied() throws A2AError { + GetTaskPushNotificationConfigParams params = new GetTaskPushNotificationConfigParams("task-1", "config-1"); + when(authorizationProvider.checkRead(context, "task-1", TaskOperation.GET_TASK_PUSH_NOTIFICATION_CONFIG)) + .thenReturn(false); + + assertThrows(TaskNotFoundError.class, + () -> decorator.onGetTaskPushNotificationConfig(params, context)); + verifyNoInteractions(delegate); + } + + @Test + void onListTaskPushNotificationConfigs_denied() throws A2AError { + ListTaskPushNotificationConfigsParams params = new ListTaskPushNotificationConfigsParams("task-1"); + when(authorizationProvider.checkRead(context, "task-1", TaskOperation.LIST_TASK_PUSH_NOTIFICATION_CONFIGS)) + .thenReturn(false); + + assertThrows(TaskNotFoundError.class, + () -> decorator.onListTaskPushNotificationConfigs(params, context)); + verifyNoInteractions(delegate); + } + + @Test + void spiException_propagates() throws A2AError { + TaskQueryParams params = new TaskQueryParams("task-1", null, null); + A2AError spiError = new A2AError(-32000, "Authorization service unavailable", null) {}; + when(authorizationProvider.checkRead(context, "task-1", TaskOperation.GET_TASK)).thenThrow(spiError); + + A2AError thrown = assertThrows(A2AError.class, () -> decorator.onGetTask(params, context)); + assertEquals(spiError, thrown); + verifyNoInteractions(delegate); + } + } + + @Nested + class CheckWriteTests { + + private AuthorizationRequestHandlerDecorator decorator; + + @BeforeEach + void setUp() { + decorator = new AuthorizationRequestHandlerDecorator(delegate, authorizationProvider); + } + + @Test + void onCancelTask_allowed() throws A2AError { + CancelTaskParams params = new CancelTaskParams("task-1"); + Task expected = testTask("task-1"); + when(authorizationProvider.checkWrite(context, "task-1", TaskOperation.CANCEL_TASK)).thenReturn(true); + when(delegate.onCancelTask(params, context)).thenReturn(expected); + + Task result = decorator.onCancelTask(params, context); + + assertEquals(expected, result); + } + + @Test + void onCancelTask_denied() throws A2AError { + CancelTaskParams params = new CancelTaskParams("task-1"); + when(authorizationProvider.checkWrite(context, "task-1", TaskOperation.CANCEL_TASK)).thenReturn(false); + + assertThrows(TaskNotFoundError.class, () -> decorator.onCancelTask(params, context)); + verifyNoInteractions(delegate); + } + + @Test + void onCreateTaskPushNotificationConfig_denied() throws A2AError { + TaskPushNotificationConfig params = TaskPushNotificationConfig.builder() + .id("config-1").taskId("task-1").url("https://example.com/webhook").build(); + when(authorizationProvider.checkWrite(context, "task-1", + TaskOperation.CREATE_TASK_PUSH_NOTIFICATION_CONFIG)).thenReturn(false); + + assertThrows(TaskNotFoundError.class, + () -> decorator.onCreateTaskPushNotificationConfig(params, context)); + verifyNoInteractions(delegate); + } + + @Test + void onDeleteTaskPushNotificationConfig_denied() throws A2AError { + DeleteTaskPushNotificationConfigParams params = + new DeleteTaskPushNotificationConfigParams("task-1", "config-1"); + when(authorizationProvider.checkWrite(context, "task-1", + TaskOperation.DELETE_TASK_PUSH_NOTIFICATION_CONFIG)).thenReturn(false); + + assertThrows(TaskNotFoundError.class, + () -> decorator.onDeleteTaskPushNotificationConfig(params, context)); + verifyNoInteractions(delegate); + } + + @Test + void onMessageSend_existingTask_denied() throws A2AError { + Message message = Message.builder().messageId("m-1").role(Message.Role.ROLE_USER) + .taskId("task-1").parts(new TextPart("hello")).build(); + MessageSendParams params = new MessageSendParams(message, null, null, null); + when(authorizationProvider.checkWrite(context, "task-1", TaskOperation.MESSAGE_SEND)).thenReturn(false); + + assertThrows(TaskNotFoundError.class, () -> decorator.onMessageSend(params, context)); + verifyNoInteractions(delegate); + } + + @Test + void onMessageSendStream_existingTask_denied() throws A2AError { + Message message = Message.builder().messageId("m-1").role(Message.Role.ROLE_USER) + .taskId("task-1").parts(new TextPart("hello")).build(); + MessageSendParams params = new MessageSendParams(message, null, null, null); + when(authorizationProvider.checkWrite(context, "task-1", TaskOperation.MESSAGE_SEND_STREAM)) + .thenReturn(false); + + assertThrows(TaskNotFoundError.class, () -> decorator.onMessageSendStream(params, context)); + verifyNoInteractions(delegate); + } + } + + @Nested + class CheckCreateAndOwnershipTests { + + private AuthorizationRequestHandlerDecorator decorator; + + @BeforeEach + void setUp() { + decorator = new AuthorizationRequestHandlerDecorator(delegate, authorizationProvider); + } + + private MessageSendParams newTaskParams() { + Message message = Message.builder().messageId("m-1").role(Message.Role.ROLE_USER) + .parts(new TextPart("hello")).build(); + return new MessageSendParams(message, null, null, null); + } + + @Test + void onMessageSend_newTask_checkCreateDenied() throws A2AError { + when(authorizationProvider.checkCreate(context, TaskOperation.MESSAGE_SEND)).thenReturn(false); + + assertThrows(TaskNotFoundError.class, () -> decorator.onMessageSend(newTaskParams(), context)); + verifyNoInteractions(delegate); + } + + @Test + void onMessageSend_newTask_createsTask_recordsOwnership() throws A2AError { + Task createdTask = testTask("new-task-1"); + when(authorizationProvider.checkCreate(context, TaskOperation.MESSAGE_SEND)).thenReturn(true); + when(delegate.onMessageSend(any(), eq(context))).thenReturn(createdTask); + when(authorizationProvider.isTaskRecorded("new-task-1")).thenReturn(false); + + EventKind result = decorator.onMessageSend(newTaskParams(), context); + + assertEquals(createdTask, result); + verify(authorizationProvider).recordOwnership(context, "new-task-1", TaskOperation.MESSAGE_SEND); + } + + @Test + void onMessageSend_newTask_alreadyRecorded_skipsOwnership() throws A2AError { + Task createdTask = testTask("new-task-1"); + when(authorizationProvider.checkCreate(context, TaskOperation.MESSAGE_SEND)).thenReturn(true); + when(delegate.onMessageSend(any(), eq(context))).thenReturn(createdTask); + when(authorizationProvider.isTaskRecorded("new-task-1")).thenReturn(true); + + decorator.onMessageSend(newTaskParams(), context); + + verify(authorizationProvider, never()).recordOwnership(any(), any(), any()); + } + + @Test + void onMessageSend_returnsMessage_noOwnershipRecording() throws A2AError { + Message response = Message.builder().messageId("resp-1").role(Message.Role.ROLE_AGENT) + .parts(new TextPart("response")).build(); + when(authorizationProvider.checkCreate(context, TaskOperation.MESSAGE_SEND)).thenReturn(true); + when(delegate.onMessageSend(any(), eq(context))).thenReturn(response); + + EventKind result = decorator.onMessageSend(newTaskParams(), context); + + assertEquals(response, result); + verify(authorizationProvider, never()).isTaskRecorded(any()); + verify(authorizationProvider, never()).recordOwnership(any(), any(), any()); + } + + @Test + void onMessageSend_existingTask_allowed_recordsOwnershipIfNeeded() throws A2AError { + Message message = Message.builder().messageId("m-1").role(Message.Role.ROLE_USER) + .taskId("existing-task").parts(new TextPart("hello")).build(); + MessageSendParams params = new MessageSendParams(message, null, null, null); + Task resultTask = testTask("existing-task"); + when(authorizationProvider.checkWrite(context, "existing-task", TaskOperation.MESSAGE_SEND)) + .thenReturn(true); + when(delegate.onMessageSend(params, context)).thenReturn(resultTask); + when(authorizationProvider.isTaskRecorded("existing-task")).thenReturn(false); + + decorator.onMessageSend(params, context); + + verify(authorizationProvider).recordOwnership(context, "existing-task", TaskOperation.MESSAGE_SEND); + } + } + + @Nested + class StreamingOwnershipTests { + + private AuthorizationRequestHandlerDecorator decorator; + + @BeforeEach + void setUp() { + decorator = new AuthorizationRequestHandlerDecorator(delegate, authorizationProvider); + } + + private MessageSendParams newTaskStreamParams() { + Message message = Message.builder().messageId("m-1").role(Message.Role.ROLE_USER) + .parts(new TextPart("hello")).build(); + return new MessageSendParams(message, null, null, null); + } + + @Test + void onMessageSendStream_newTask_recordsOwnershipOnFirstTaskEvent() throws Exception { + when(authorizationProvider.checkCreate(context, TaskOperation.MESSAGE_SEND_STREAM)).thenReturn(true); + when(authorizationProvider.isTaskRecorded("stream-task-1")).thenReturn(false); + + TaskStatusUpdateEvent statusEvent = new TaskStatusUpdateEvent( + "stream-task-1", new TaskStatus(TaskState.TASK_STATE_WORKING), "ctx-1", null); + + Flow.Publisher sourcePublisher = subscriber -> { + subscriber.onSubscribe(new Flow.Subscription() { + @Override + public void request(long n) { + subscriber.onNext(statusEvent); + subscriber.onComplete(); + } + + @Override + public void cancel() { + } + }); + }; + when(delegate.onMessageSendStream(any(), eq(context))).thenReturn(sourcePublisher); + + Flow.Publisher result = decorator.onMessageSendStream(newTaskStreamParams(), context); + + List received = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(1); + result.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + s.request(10); + } + + @Override + public void onNext(StreamingEventKind item) { + received.add(item); + } + + @Override + public void onError(Throwable t) { + latch.countDown(); + } + + @Override + public void onComplete() { + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertEquals(1, received.size()); + assertEquals(statusEvent, received.get(0)); + verify(authorizationProvider).recordOwnership(context, "stream-task-1", TaskOperation.MESSAGE_SEND_STREAM); + } + + @Test + void onMessageSendStream_messageOnly_noOwnershipRecording() throws Exception { + when(authorizationProvider.checkCreate(context, TaskOperation.MESSAGE_SEND_STREAM)).thenReturn(true); + + Message response = Message.builder().messageId("resp-1").role(Message.Role.ROLE_AGENT) + .parts(new TextPart("response")).build(); + + Flow.Publisher sourcePublisher = subscriber -> { + subscriber.onSubscribe(new Flow.Subscription() { + @Override + public void request(long n) { + subscriber.onNext(response); + subscriber.onComplete(); + } + + @Override + public void cancel() { + } + }); + }; + when(delegate.onMessageSendStream(any(), eq(context))).thenReturn(sourcePublisher); + + Flow.Publisher result = decorator.onMessageSendStream(newTaskStreamParams(), context); + + CountDownLatch latch = new CountDownLatch(1); + result.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + s.request(10); + } + + @Override + public void onNext(StreamingEventKind item) { + } + + @Override + public void onError(Throwable t) { + latch.countDown(); + } + + @Override + public void onComplete() { + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + verify(authorizationProvider, never()).isTaskRecorded(any()); + verify(authorizationProvider, never()).recordOwnership(any(), any(), any()); + } + } +} diff --git a/server-common/src/test/java/org/a2aproject/sdk/server/tasks/InMemoryTaskStoreAuthorizationTest.java b/server-common/src/test/java/org/a2aproject/sdk/server/tasks/InMemoryTaskStoreAuthorizationTest.java new file mode 100644 index 000000000..37bddf46b --- /dev/null +++ b/server-common/src/test/java/org/a2aproject/sdk/server/tasks/InMemoryTaskStoreAuthorizationTest.java @@ -0,0 +1,108 @@ +package org.a2aproject.sdk.server.tasks; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +import java.util.Collections; + +import org.a2aproject.sdk.jsonrpc.common.wrappers.ListTasksResult; +import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.TaskAuthorizationProvider; +import org.a2aproject.sdk.server.auth.TaskOperation; +import org.a2aproject.sdk.spec.ListTasksParams; +import org.a2aproject.sdk.spec.Task; +import org.a2aproject.sdk.spec.TaskState; +import org.a2aproject.sdk.spec.TaskStatus; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class InMemoryTaskStoreAuthorizationTest { + + @Mock + private TaskAuthorizationProvider authorizationProvider; + + @Mock + private ServerCallContext context; + + private InMemoryTaskStore store; + + private static Task testTask(String id) { + return Task.builder() + .id(id) + .contextId("ctx-1") + .status(new TaskStatus(TaskState.TASK_STATE_COMPLETED)) + .history(Collections.emptyList()) + .artifacts(Collections.emptyList()) + .build(); + } + + @BeforeEach + void setUp() { + store = new InMemoryTaskStore(authorizationProvider); + } + + @Test + void list_filtersUnauthorizedTasks() throws Exception { + store.save(testTask("task-1"), false); + store.save(testTask("task-2"), false); + store.save(testTask("task-3"), false); + + when(authorizationProvider.checkRead(eq(context), eq("task-1"), eq(TaskOperation.LIST_TASKS))) + .thenReturn(true); + when(authorizationProvider.checkRead(eq(context), eq("task-2"), eq(TaskOperation.LIST_TASKS))) + .thenReturn(false); + when(authorizationProvider.checkRead(eq(context), eq("task-3"), eq(TaskOperation.LIST_TASKS))) + .thenReturn(true); + + ListTasksParams params = new ListTasksParams(); + ListTasksResult result = store.list(params, context); + + assertEquals(2, result.tasks().size()); + assertEquals(2, result.totalSize()); + assertTrue(result.tasks().stream().anyMatch(t -> t.id().equals("task-1"))); + assertTrue(result.tasks().stream().anyMatch(t -> t.id().equals("task-3"))); + assertTrue(result.tasks().stream().noneMatch(t -> t.id().equals("task-2"))); + } + + @Test + void list_noProvider_returnsAllTasks() throws Exception { + InMemoryTaskStore storeNoAuth = new InMemoryTaskStore((TaskAuthorizationProvider) null); + storeNoAuth.save(testTask("task-1"), false); + storeNoAuth.save(testTask("task-2"), false); + + ListTasksParams params = new ListTasksParams(); + ListTasksResult result = storeNoAuth.list(params, context); + + assertEquals(2, result.tasks().size()); + } + + @Test + void list_paginationCorrectWithFiltering() throws Exception { + for (int i = 1; i <= 5; i++) { + store.save(testTask("task-" + i), false); + } + + when(authorizationProvider.checkRead(eq(context), eq("task-1"), eq(TaskOperation.LIST_TASKS))) + .thenReturn(true); + when(authorizationProvider.checkRead(eq(context), eq("task-2"), eq(TaskOperation.LIST_TASKS))) + .thenReturn(false); + when(authorizationProvider.checkRead(eq(context), eq("task-3"), eq(TaskOperation.LIST_TASKS))) + .thenReturn(true); + when(authorizationProvider.checkRead(eq(context), eq("task-4"), eq(TaskOperation.LIST_TASKS))) + .thenReturn(false); + when(authorizationProvider.checkRead(eq(context), eq("task-5"), eq(TaskOperation.LIST_TASKS))) + .thenReturn(true); + + ListTasksParams params = new ListTasksParams(); + ListTasksResult result = store.list(params, context); + + assertEquals(3, result.totalSize()); + assertEquals(3, result.pageSize()); + } +} diff --git a/tests/multiversion/grpc/src/test/java/org/a2aproject/sdk/tests/multiversion/grpc/MultiVersionGrpcWithTaskAuthorizationTest.java b/tests/multiversion/grpc/src/test/java/org/a2aproject/sdk/tests/multiversion/grpc/MultiVersionGrpcWithTaskAuthorizationTest.java new file mode 100644 index 000000000..3f7c9e1d3 --- /dev/null +++ b/tests/multiversion/grpc/src/test/java/org/a2aproject/sdk/tests/multiversion/grpc/MultiVersionGrpcWithTaskAuthorizationTest.java @@ -0,0 +1,72 @@ +package org.a2aproject.sdk.tests.multiversion.grpc; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import jakarta.inject.Inject; +import org.a2aproject.sdk.client.ClientBuilder; +import org.a2aproject.sdk.client.transport.grpc.GrpcTransport; +import org.a2aproject.sdk.client.transport.grpc.GrpcTransportConfigBuilder; +import org.a2aproject.sdk.client.transport.spi.interceptors.auth.AuthInterceptor; +import org.a2aproject.sdk.server.PublicAgentCard; +import org.a2aproject.sdk.server.apps.common.AbstractA2AServerWithTaskAuthorizationTest; +import org.a2aproject.sdk.spec.AgentCard; +import org.a2aproject.sdk.spec.TransportProtocol; +import org.junit.jupiter.api.AfterAll; + +@QuarkusTest +@TestProfile(TaskAuthorizationTestProfile.class) +public class MultiVersionGrpcWithTaskAuthorizationTest extends AbstractA2AServerWithTaskAuthorizationTest { + + @Inject + @PublicAgentCard + AgentCard agentCard; + + private static final Map channels = new ConcurrentHashMap<>(); + + @Override + protected String getTransportProtocol() { + return TransportProtocol.GRPC.asString(); + } + + @Override + protected String getTransportUrl() { + return "localhost:8081"; + } + + @Override + protected AgentCard fetchAgentCardFromServer() { + return agentCard; + } + + @Override + protected void configureTransportWithCredentials(ClientBuilder builder, String username, String password) { + AuthInterceptor authInterceptor = new AuthInterceptor( + (schemeName, context) -> BASIC_AUTH_SCHEME_NAME.equals(schemeName) + ? getEncodedCredentials(username, password) : null); + builder.withTransport(GrpcTransport.class, new GrpcTransportConfigBuilder() + .channelFactory(target -> { + ManagedChannel channel = ManagedChannelBuilder.forTarget(target).usePlaintext().build(); + channels.put(username, channel); + return channel; + }) + .addInterceptor(authInterceptor)); + } + + @AfterAll + static void closeChannels() { + channels.values().forEach(ch -> { + ch.shutdownNow(); + try { + ch.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + } +} diff --git a/tests/multiversion/grpc/src/test/java/org/a2aproject/sdk/tests/multiversion/grpc/TaskAuthorizationTestProfile.java b/tests/multiversion/grpc/src/test/java/org/a2aproject/sdk/tests/multiversion/grpc/TaskAuthorizationTestProfile.java new file mode 100644 index 000000000..b4cdb2522 --- /dev/null +++ b/tests/multiversion/grpc/src/test/java/org/a2aproject/sdk/tests/multiversion/grpc/TaskAuthorizationTestProfile.java @@ -0,0 +1,18 @@ +package org.a2aproject.sdk.tests.multiversion.grpc; + +import java.util.HashMap; +import java.util.Map; + +import io.quarkus.test.junit.QuarkusTestProfile; + +public class TaskAuthorizationTestProfile extends AuthTestProfile { + + @Override + public Map getConfigOverrides() { + Map config = new HashMap<>(super.getConfigOverrides()); + config.put("quarkus.security.users.embedded.users.userB", "passB"); + config.put("quarkus.security.users.embedded.roles.userB", "user"); + config.put("test.task-authorization.enabled", "true"); + return config; + } +} diff --git a/tests/multiversion/jsonrpc/src/test/java/org/a2aproject/sdk/tests/multiversion/jsonrpc/MultiVersionJSONRPCWithTaskAuthorizationTest.java b/tests/multiversion/jsonrpc/src/test/java/org/a2aproject/sdk/tests/multiversion/jsonrpc/MultiVersionJSONRPCWithTaskAuthorizationTest.java new file mode 100644 index 000000000..ff4a2c400 --- /dev/null +++ b/tests/multiversion/jsonrpc/src/test/java/org/a2aproject/sdk/tests/multiversion/jsonrpc/MultiVersionJSONRPCWithTaskAuthorizationTest.java @@ -0,0 +1,42 @@ +package org.a2aproject.sdk.tests.multiversion.jsonrpc; + +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import io.vertx.core.Vertx; +import jakarta.inject.Inject; +import org.a2aproject.sdk.client.ClientBuilder; +import org.a2aproject.sdk.client.http.VertxA2AHttpClient; +import org.a2aproject.sdk.client.transport.jsonrpc.JSONRPCTransport; +import org.a2aproject.sdk.client.transport.jsonrpc.JSONRPCTransportConfigBuilder; +import org.a2aproject.sdk.client.transport.spi.interceptors.auth.AuthInterceptor; +import org.a2aproject.sdk.server.apps.common.AbstractA2AServerWithTaskAuthorizationTest; +import org.a2aproject.sdk.spec.TransportProtocol; + +@QuarkusTest +@TestProfile(TaskAuthorizationTestProfile.class) +public class MultiVersionJSONRPCWithTaskAuthorizationTest extends AbstractA2AServerWithTaskAuthorizationTest { + + @Inject + Vertx vertx; + + @Override + protected String getTransportProtocol() { + return TransportProtocol.JSONRPC.asString(); + } + + @Override + protected String getTransportUrl() { + return "http://localhost:8081"; + } + + @Override + protected void configureTransportWithCredentials(ClientBuilder builder, String username, String password) { + AuthInterceptor authInterceptor = new AuthInterceptor( + (schemeName, context) -> BASIC_AUTH_SCHEME_NAME.equals(schemeName) + ? getEncodedCredentials(username, password) : null); + builder.withTransport(JSONRPCTransport.class, + new JSONRPCTransportConfigBuilder() + .httpClient(new VertxA2AHttpClient(vertx)) + .addInterceptor(authInterceptor)); + } +} diff --git a/tests/multiversion/jsonrpc/src/test/java/org/a2aproject/sdk/tests/multiversion/jsonrpc/TaskAuthorizationTestProfile.java b/tests/multiversion/jsonrpc/src/test/java/org/a2aproject/sdk/tests/multiversion/jsonrpc/TaskAuthorizationTestProfile.java new file mode 100644 index 000000000..6440a1a6a --- /dev/null +++ b/tests/multiversion/jsonrpc/src/test/java/org/a2aproject/sdk/tests/multiversion/jsonrpc/TaskAuthorizationTestProfile.java @@ -0,0 +1,18 @@ +package org.a2aproject.sdk.tests.multiversion.jsonrpc; + +import java.util.HashMap; +import java.util.Map; + +import io.quarkus.test.junit.QuarkusTestProfile; + +public class TaskAuthorizationTestProfile extends AuthTestProfile { + + @Override + public Map getConfigOverrides() { + Map config = new HashMap<>(super.getConfigOverrides()); + config.put("quarkus.security.users.embedded.users.userB", "passB"); + config.put("quarkus.security.users.embedded.roles.userB", "user"); + config.put("test.task-authorization.enabled", "true"); + return config; + } +} diff --git a/tests/multiversion/rest/src/test/java/org/a2aproject/sdk/tests/multiversion/rest/MultiVersionRestWithTaskAuthorizationTest.java b/tests/multiversion/rest/src/test/java/org/a2aproject/sdk/tests/multiversion/rest/MultiVersionRestWithTaskAuthorizationTest.java new file mode 100644 index 000000000..5469ea18e --- /dev/null +++ b/tests/multiversion/rest/src/test/java/org/a2aproject/sdk/tests/multiversion/rest/MultiVersionRestWithTaskAuthorizationTest.java @@ -0,0 +1,42 @@ +package org.a2aproject.sdk.tests.multiversion.rest; + +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import io.vertx.core.Vertx; +import jakarta.inject.Inject; +import org.a2aproject.sdk.client.ClientBuilder; +import org.a2aproject.sdk.client.http.VertxA2AHttpClient; +import org.a2aproject.sdk.client.transport.rest.RestTransport; +import org.a2aproject.sdk.client.transport.rest.RestTransportConfigBuilder; +import org.a2aproject.sdk.client.transport.spi.interceptors.auth.AuthInterceptor; +import org.a2aproject.sdk.server.apps.common.AbstractA2AServerWithTaskAuthorizationTest; +import org.a2aproject.sdk.spec.TransportProtocol; + +@QuarkusTest +@TestProfile(TaskAuthorizationTestProfile.class) +public class MultiVersionRestWithTaskAuthorizationTest extends AbstractA2AServerWithTaskAuthorizationTest { + + @Inject + Vertx vertx; + + @Override + protected String getTransportProtocol() { + return TransportProtocol.HTTP_JSON.asString(); + } + + @Override + protected String getTransportUrl() { + return "http://localhost:8081"; + } + + @Override + protected void configureTransportWithCredentials(ClientBuilder builder, String username, String password) { + AuthInterceptor authInterceptor = new AuthInterceptor( + (schemeName, context) -> BASIC_AUTH_SCHEME_NAME.equals(schemeName) + ? getEncodedCredentials(username, password) : null); + builder.withTransport(RestTransport.class, + new RestTransportConfigBuilder() + .httpClient(new VertxA2AHttpClient(vertx)) + .addInterceptor(authInterceptor)); + } +} diff --git a/tests/multiversion/rest/src/test/java/org/a2aproject/sdk/tests/multiversion/rest/TaskAuthorizationTestProfile.java b/tests/multiversion/rest/src/test/java/org/a2aproject/sdk/tests/multiversion/rest/TaskAuthorizationTestProfile.java new file mode 100644 index 000000000..1d2d62f7f --- /dev/null +++ b/tests/multiversion/rest/src/test/java/org/a2aproject/sdk/tests/multiversion/rest/TaskAuthorizationTestProfile.java @@ -0,0 +1,18 @@ +package org.a2aproject.sdk.tests.multiversion.rest; + +import java.util.HashMap; +import java.util.Map; + +import io.quarkus.test.junit.QuarkusTestProfile; + +public class TaskAuthorizationTestProfile extends AuthTestProfile { + + @Override + public Map getConfigOverrides() { + Map config = new HashMap<>(super.getConfigOverrides()); + config.put("quarkus.security.users.embedded.users.userB", "passB"); + config.put("quarkus.security.users.embedded.roles.userB", "user"); + config.put("test.task-authorization.enabled", "true"); + return config; + } +} diff --git a/tests/server-common/src/test/java/org/a2aproject/sdk/server/apps/common/AbstractA2AServerWithTaskAuthorizationTest.java b/tests/server-common/src/test/java/org/a2aproject/sdk/server/apps/common/AbstractA2AServerWithTaskAuthorizationTest.java new file mode 100644 index 000000000..df321fc02 --- /dev/null +++ b/tests/server-common/src/test/java/org/a2aproject/sdk/server/apps/common/AbstractA2AServerWithTaskAuthorizationTest.java @@ -0,0 +1,228 @@ +package org.a2aproject.sdk.server.apps.common; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import org.a2aproject.sdk.client.Client; +import org.a2aproject.sdk.client.ClientBuilder; +import org.a2aproject.sdk.client.TaskEvent; +import org.a2aproject.sdk.client.TaskUpdateEvent; +import org.a2aproject.sdk.client.config.ClientConfig; +import org.a2aproject.sdk.jsonrpc.common.json.JsonUtil; +import org.a2aproject.sdk.jsonrpc.common.wrappers.ListTasksResult; +import org.a2aproject.sdk.spec.A2AClientException; +import org.a2aproject.sdk.spec.AgentCard; +import org.a2aproject.sdk.spec.CancelTaskParams; +import org.a2aproject.sdk.spec.ListTasksParams; +import org.a2aproject.sdk.spec.Message; +import org.a2aproject.sdk.spec.Task; +import org.a2aproject.sdk.spec.TaskNotFoundError; +import org.a2aproject.sdk.spec.TaskQueryParams; +import org.a2aproject.sdk.spec.TaskState; +import org.a2aproject.sdk.spec.TextPart; +import org.junit.jupiter.api.Test; + +/** + * Abstract base class for task authorization integration tests. + *

+ * Verifies that {@link org.a2aproject.sdk.server.auth.TaskAuthorizationProvider} + * is enforced end-to-end through the transport layer with two distinct users. + */ +public abstract class AbstractA2AServerWithTaskAuthorizationTest { + + protected static final String USER_A = "testuser"; + protected static final String USER_A_PASSWORD = "testpass"; + protected static final String USER_B = "userB"; + protected static final String USER_B_PASSWORD = "passB"; + protected static final String BASIC_AUTH_SCHEME_NAME = "basicAuth"; + + protected abstract String getTransportProtocol(); + + protected abstract String getTransportUrl(); + + protected abstract void configureTransportWithCredentials(ClientBuilder builder, String username, String password); + + protected Client createClient(String username, String password) throws A2AClientException { + AgentCard agentCard = fetchAgentCardFromServer(); + ClientConfig clientConfig = new ClientConfig.Builder().setStreaming(false).build(); + ClientBuilder clientBuilder = Client.builder(agentCard).clientConfig(clientConfig); + configureTransportWithCredentials(clientBuilder, username, password); + return clientBuilder.build(); + } + + protected AgentCard fetchAgentCardFromServer() { + try { + HttpClient httpClient = HttpClient.newHttpClient(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(getTransportUrl() + "/.well-known/agent-card.json")) + .GET() + .build(); + HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + if (response.statusCode() != 200) { + throw new RuntimeException("Failed to fetch agent card: " + response.statusCode()); + } + return JsonUtil.fromJson(response.body(), AgentCard.class); + } catch (Exception e) { + throw new RuntimeException("Failed to fetch AgentCard from server", e); + } + } + + protected static String getEncodedCredentials(String username, String password) { + return Base64.getEncoder().encodeToString((username + ":" + password).getBytes(StandardCharsets.UTF_8)); + } + + protected Task sendMessageAndGetTask(Client client, String messageText) throws Exception { + Message message = Message.builder() + .messageId(UUID.randomUUID().toString()) + .role(Message.Role.ROLE_USER) + .parts(new TextPart("a2a-local:" + messageText)) + .build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference receivedTask = new AtomicReference<>(); + AtomicReference errorRef = new AtomicReference<>(); + + client.sendMessage(message, List.of((event, agentCard) -> { + if (event instanceof TaskEvent te) { + receivedTask.set(te.getTask()); + if (te.getTask().status().state() == TaskState.TASK_STATE_COMPLETED) { + latch.countDown(); + } + } else if (event instanceof TaskUpdateEvent tue) { + receivedTask.set(tue.getTask()); + if (tue.getTask().status().state() == TaskState.TASK_STATE_COMPLETED) { + latch.countDown(); + } + } + }), error -> { + errorRef.set(error); + latch.countDown(); + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS), "Task should complete within timeout"); + assertNull(errorRef.get(), "Should not have received an error: " + errorRef.get()); + + Task task = receivedTask.get(); + assertNotNull(task, "Should have received a task"); + assertEquals(TaskState.TASK_STATE_COMPLETED, task.status().state()); + return task; + } + + @Test + public void testOwnerCanGetOwnTask() throws Exception { + Client clientA = createClient(USER_A, USER_A_PASSWORD); + Task task = sendMessageAndGetTask(clientA, "owner-get-test"); + + Task retrieved = clientA.getTask(new TaskQueryParams(task.id())); + assertNotNull(retrieved); + assertEquals(task.id(), retrieved.id()); + } + + @Test + public void testNonOwnerCannotGetTask() throws Exception { + Client clientA = createClient(USER_A, USER_A_PASSWORD); + Task task = sendMessageAndGetTask(clientA, "non-owner-get-test"); + + Client clientB = createClient(USER_B, USER_B_PASSWORD); + A2AClientException error = assertThrows(A2AClientException.class, () -> + clientB.getTask(new TaskQueryParams(task.id()))); + assertInstanceOf(TaskNotFoundError.class, error.getCause()); + } + + @Test + public void testOwnerCanCancelOwnTask() throws Exception { + Client clientA = createClient(USER_A, USER_A_PASSWORD); + Task task = sendMessageAndGetTask(clientA, "owner-cancel-test"); + + try { + clientA.cancelTask(new CancelTaskParams(task.id())); + } catch (A2AClientException e) { + // UnsupportedOperationError is acceptable (task already completed) + // but TaskNotFoundError means auth failed + assertFalse(e.getCause() instanceof TaskNotFoundError, + "Owner should not get TaskNotFoundError when canceling own task"); + } + } + + @Test + public void testNonOwnerCannotCancelTask() throws Exception { + Client clientA = createClient(USER_A, USER_A_PASSWORD); + Task task = sendMessageAndGetTask(clientA, "non-owner-cancel-test"); + + Client clientB = createClient(USER_B, USER_B_PASSWORD); + A2AClientException error = assertThrows(A2AClientException.class, () -> + clientB.cancelTask(new CancelTaskParams(task.id()))); + assertInstanceOf(TaskNotFoundError.class, error.getCause()); + } + + @Test + public void testListTasksShowsOnlyOwnTasks() throws Exception { + Client clientA = createClient(USER_A, USER_A_PASSWORD); + Task taskA1 = sendMessageAndGetTask(clientA, "list-test-a1"); + Task taskA2 = sendMessageAndGetTask(clientA, "list-test-a2"); + + Client clientB = createClient(USER_B, USER_B_PASSWORD); + Task taskB1 = sendMessageAndGetTask(clientB, "list-test-b1"); + + ListTasksParams listParams = ListTasksParams.builder().tenant("").build(); + + ListTasksResult resultA = clientA.listTasks(listParams); + Set taskIdsA = resultA.tasks().stream().map(Task::id).collect(Collectors.toSet()); + assertTrue(taskIdsA.contains(taskA1.id()), "UserA should see taskA1"); + assertTrue(taskIdsA.contains(taskA2.id()), "UserA should see taskA2"); + assertFalse(taskIdsA.contains(taskB1.id()), "UserA should NOT see taskB1"); + + ListTasksResult resultB = clientB.listTasks(listParams); + Set taskIdsB = resultB.tasks().stream().map(Task::id).collect(Collectors.toSet()); + assertTrue(taskIdsB.contains(taskB1.id()), "UserB should see taskB1"); + assertFalse(taskIdsB.contains(taskA1.id()), "UserB should NOT see taskA1"); + assertFalse(taskIdsB.contains(taskA2.id()), "UserB should NOT see taskA2"); + } + + @Test + public void testUnauthorizedLooksLikeNotFound() throws Exception { + Client clientA = createClient(USER_A, USER_A_PASSWORD); + Task task = sendMessageAndGetTask(clientA, "info-hiding-test"); + + Client clientB = createClient(USER_B, USER_B_PASSWORD); + + A2AClientException unauthorizedError = assertThrows(A2AClientException.class, () -> + clientB.getTask(new TaskQueryParams(task.id()))); + A2AClientException notFoundError = assertThrows(A2AClientException.class, () -> + clientB.getTask(new TaskQueryParams(UUID.randomUUID().toString()))); + + assertInstanceOf(TaskNotFoundError.class, unauthorizedError.getCause()); + assertInstanceOf(TaskNotFoundError.class, notFoundError.getCause()); + } + + @Test + public void testBothUsersCanCreateTasks() throws Exception { + Client clientA = createClient(USER_A, USER_A_PASSWORD); + Task taskA = sendMessageAndGetTask(clientA, "create-test-a"); + assertNotNull(taskA.id()); + + Client clientB = createClient(USER_B, USER_B_PASSWORD); + Task taskB = sendMessageAndGetTask(clientB, "create-test-b"); + assertNotNull(taskB.id()); + } +} diff --git a/tests/server-common/src/test/java/org/a2aproject/sdk/server/apps/common/TaskAuthorizationTestProfile.java b/tests/server-common/src/test/java/org/a2aproject/sdk/server/apps/common/TaskAuthorizationTestProfile.java new file mode 100644 index 000000000..503142ab3 --- /dev/null +++ b/tests/server-common/src/test/java/org/a2aproject/sdk/server/apps/common/TaskAuthorizationTestProfile.java @@ -0,0 +1,18 @@ +package org.a2aproject.sdk.server.apps.common; + +import java.util.HashMap; +import java.util.Map; + +import io.quarkus.test.junit.QuarkusTestProfile; + +public class TaskAuthorizationTestProfile extends AuthTestProfile { + + @Override + public Map getConfigOverrides() { + Map config = new HashMap<>(super.getConfigOverrides()); + config.put("quarkus.security.users.embedded.users.userB", "passB"); + config.put("quarkus.security.users.embedded.roles.userB", "user"); + config.put("test.task-authorization.enabled", "true"); + return config; + } +} diff --git a/tests/server-common/src/test/java/org/a2aproject/sdk/server/apps/common/TestTaskAuthorizationProvider.java b/tests/server-common/src/test/java/org/a2aproject/sdk/server/apps/common/TestTaskAuthorizationProvider.java new file mode 100644 index 000000000..3febb92ed --- /dev/null +++ b/tests/server-common/src/test/java/org/a2aproject/sdk/server/apps/common/TestTaskAuthorizationProvider.java @@ -0,0 +1,48 @@ +package org.a2aproject.sdk.server.apps.common; + +import java.util.concurrent.ConcurrentHashMap; + +import jakarta.enterprise.context.ApplicationScoped; + +import io.quarkus.arc.Unremovable; +import io.quarkus.arc.properties.IfBuildProperty; +import org.a2aproject.sdk.server.ServerCallContext; +import org.a2aproject.sdk.server.auth.TaskAuthorizationProvider; +import org.a2aproject.sdk.server.auth.TaskOperation; + +@ApplicationScoped +@Unremovable +@IfBuildProperty(name = "test.task-authorization.enabled", stringValue = "true", enableIfMissing = false) +public class TestTaskAuthorizationProvider implements TaskAuthorizationProvider { + + private final ConcurrentHashMap taskOwners = new ConcurrentHashMap<>(); + + @Override + public boolean checkRead(ServerCallContext context, String taskId, TaskOperation operation) { + String owner = taskOwners.get(taskId); + // Intentionally fail-open for testing; production implementations should fail-closed (deny unknown tasks) + return owner == null || owner.equals(context.getUser().getUsername()); + } + + @Override + public boolean checkWrite(ServerCallContext context, String taskId, TaskOperation operation) { + String owner = taskOwners.get(taskId); + // Intentionally fail-open for testing; production implementations should fail-closed (deny unknown tasks) + return owner == null || owner.equals(context.getUser().getUsername()); + } + + @Override + public boolean checkCreate(ServerCallContext context, TaskOperation operation) { + return context.getUser().isAuthenticated(); + } + + @Override + public boolean isTaskRecorded(String taskId) { + return taskOwners.containsKey(taskId); + } + + @Override + public void recordOwnership(ServerCallContext context, String taskId, TaskOperation operation) { + taskOwners.putIfAbsent(taskId, context.getUser().getUsername()); + } +}