Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public void testDirectNotificationTrigger() {
mockPushNotificationSender.sendNotification(testTask);

// Verify it was captured
Queue<Task> captured = mockPushNotificationSender.getCapturedTasks();
Queue<Task> captured = mockPushNotificationSender.getCapturedEvents();
assertEquals(1, captured.size());
assertEquals("direct-test-task", captured.peek().getId());
Comment on lines +91 to 93
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The type of the captured variable is incorrect. mockPushNotificationSender.getCapturedEvents() returns a Queue<StreamingEventKind>, not a Queue<Task>. This will cause a compilation error. You need to update the variable type and add a cast to access Task-specific methods. It's also safer to check the type before casting.

        Queue<StreamingEventKind> captured = mockPushNotificationSender.getCapturedEvents();
        assertEquals(1, captured.size());
        assertTrue(captured.peek() instanceof Task, "Captured event should be a Task");
        assertEquals("direct-test-task", ((Task) captured.peek()).getId());

}
Expand Down Expand Up @@ -151,7 +151,7 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep
boolean notificationReceived = false;

while (System.currentTimeMillis() < end) {
if (!mockPushNotificationSender.getCapturedTasks().isEmpty()) {
if (!mockPushNotificationSender.getCapturedEvents().isEmpty()) {
notificationReceived = true;
break;
}
Expand All @@ -161,7 +161,7 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep
assertTrue(notificationReceived, "Timeout waiting for push notification.");

// Step 6: Verify the captured notification
Queue<Task> capturedTasks = mockPushNotificationSender.getCapturedTasks();
Queue<Task> capturedTasks = mockPushNotificationSender.getCapturedEvents();

// Verify the notification contains the correct task with artifacts
Task notifiedTaskWithArtifact = capturedTasks.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

import io.a2a.spec.StreamingEventKind;
import jakarta.annotation.Priority;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Alternative;

import io.a2a.server.tasks.PushNotificationSender;
import io.a2a.spec.Task;

/**
* Mock implementation of PushNotificationSender for integration testing.
Expand All @@ -19,18 +19,18 @@
@Priority(100)
public class MockPushNotificationSender implements PushNotificationSender {

private final Queue<Task> capturedTasks = new ConcurrentLinkedQueue<>();
private final Queue<StreamingEventKind> capturedEvents = new ConcurrentLinkedQueue<>();

@Override
public void sendNotification(Task task) {
capturedTasks.add(task);
public void sendNotification(StreamingEventKind kind) {
capturedEvents.add(kind);
}

public Queue<Task> getCapturedTasks() {
return capturedTasks;
public Queue<StreamingEventKind> getCapturedEvents() {
return capturedEvents;
}

public void clear() {
capturedTasks.clear();
capturedEvents.clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
import static io.a2a.client.http.A2AHttpClient.APPLICATION_JSON;
import static io.a2a.client.http.A2AHttpClient.CONTENT_TYPE;
import static io.a2a.common.A2AHeaders.X_A2A_NOTIFICATION_TOKEN;
import static io.a2a.spec.Message.MESSAGE;
import static io.a2a.spec.Task.TASK;
import static io.a2a.spec.TaskArtifactUpdateEvent.ARTIFACT_UPDATE;
import static io.a2a.spec.TaskStatusUpdateEvent.STATUS_UPDATE;

import io.a2a.spec.Message;
import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.TaskArtifactUpdateEvent;
import io.a2a.spec.TaskStatusUpdateEvent;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

Expand Down Expand Up @@ -42,34 +51,45 @@ public BasePushNotificationSender(PushNotificationConfigStore configStore, A2AHt
}

@Override
public void sendNotification(Task task) {
List<PushNotificationConfig> pushConfigs = configStore.getInfo(task.getId());
public void sendNotification(StreamingEventKind kind) {
String taskId = switch (kind.getKind()) {
case TASK -> ((Task) kind).getId();
case MESSAGE -> ((Message)kind).getTaskId();
case STATUS_UPDATE -> ((TaskStatusUpdateEvent)kind).getTaskId();
case ARTIFACT_UPDATE -> ((TaskArtifactUpdateEvent)kind).getTaskId();
default -> null;
};
Comment on lines +55 to +61
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since this project uses a modern Java version (as indicated by the use of sealed interfaces), you can leverage pattern matching for switch. This will make the code more concise, readable, and type-safe by avoiding explicit string comparisons and casts.

        String taskId = switch (kind) {
            case Task t -> t.getId();
            case Message m -> m.getTaskId();
            case TaskStatusUpdateEvent e -> e.getTaskId();
            case TaskArtifactUpdateEvent e -> e.getTaskId();
        };

if (taskId == null) {
return;
}

List<PushNotificationConfig> pushConfigs = configStore.getInfo(taskId);
if (pushConfigs == null || pushConfigs.isEmpty()) {
return;
}

List<CompletableFuture<Boolean>> dispatchResults = pushConfigs
.stream()
.map(pushConfig -> dispatch(task, pushConfig))
.map(pushConfig -> dispatch(kind, pushConfig))
.toList();
CompletableFuture<Void> allFutures = CompletableFuture.allOf(dispatchResults.toArray(new CompletableFuture[0]));
CompletableFuture<Boolean> dispatchResult = allFutures.thenApply(v -> dispatchResults.stream()
.allMatch(CompletableFuture::join));
try {
boolean allSent = dispatchResult.get();
if (! allSent) {
LOGGER.warn("Some push notifications failed to send for taskId: " + task.getId());
if (!allSent) {
LOGGER.warn("Some push notifications failed to send for taskId: " + taskId);
}
} catch (InterruptedException | ExecutionException e) {
LOGGER.warn("Some push notifications failed to send for taskId " + task.getId() + ": {}", e.getMessage(), e);
LOGGER.warn("Some push notifications failed to send for taskId " + taskId + ": {}", e.getMessage(), e);
}
}

private CompletableFuture<Boolean> dispatch(Task task, PushNotificationConfig pushInfo) {
return CompletableFuture.supplyAsync(() -> dispatchNotification(task, pushInfo));
private CompletableFuture<Boolean> dispatch(StreamingEventKind kind, PushNotificationConfig pushInfo) {
return CompletableFuture.supplyAsync(() -> dispatchNotification(kind, pushInfo));
}

private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) {
private boolean dispatchNotification(StreamingEventKind kind, PushNotificationConfig pushInfo) {
String url = pushInfo.url();
String token = pushInfo.token();

Expand All @@ -80,7 +100,7 @@ private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo)

String body;
try {
body = Utils.OBJECT_MAPPER.writeValueAsString(task);
body = Utils.OBJECT_MAPPER.writeValueAsString(kind);
} catch (JsonProcessingException e) {
LOGGER.debug("Error writing value as string: {}", e.getMessage(), e);
return false;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package io.a2a.server.tasks;

import io.a2a.spec.Task;
import io.a2a.spec.StreamingEventKind;

/**
* Interface for sending push notifications for tasks.
*/
public interface PushNotificationSender {

/**
* Sends a push notification containing the latest task state.
* @param task the task
* Sends a push notification with a payload related to the task.
* @param kind the payload to push
*/
void sendNotification(Task task);
void sendNotification(StreamingEventKind kind);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import io.a2a.spec.Message;
import io.a2a.spec.Part;
import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.TextPart;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -67,6 +71,7 @@ class TestPostBuilder implements A2AHttpClient.PostBuilder {
@Override
public PostBuilder body(String body) {
this.body = body;
System.out.println("body = " + body);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This System.out.println appears to be leftover debugging code and should be removed before merging.

return this;
}

Expand All @@ -80,6 +85,7 @@ public A2AHttpResponse post() throws IOException, InterruptedException {
Task task = Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE);
tasks.add(task);
Comment on lines 85 to 86
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The mock HTTP client is hardcoded to deserialize the request body into a Task object. With the changes in this PR, the payload can be any StreamingEventKind, such as a Message. This will cause a JsonProcessingException when testSendNotificationWithMessage is run. The mock client needs to be updated to handle different event types. I suggest deserializing to StreamingEventKind and then conditionally adding to the tasks list if it's a Task. You will also need to update the new test testSendNotificationWithMessage to assert against the correct event type (a Message), which may require adding a new list to TestHttpClient to store all captured events.

Suggested change
Task task = Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE);
tasks.add(task);
StreamingEventKind event = Utils.OBJECT_MAPPER.readValue(body, StreamingEventKind.class);
if (event instanceof Task task) {
tasks.add(task);
}

urls.add(url);
System.out.println(requestHeaders);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This System.out.println appears to be leftover debugging code and should be removed before merging.

headers.add(new java.util.HashMap<>(requestHeaders));

return new A2AHttpResponse() {
Expand All @@ -95,7 +101,7 @@ public boolean success() {

@Override
public String body() {
return "";
return body;
}
};
} finally {
Expand Down Expand Up @@ -316,4 +322,45 @@ public void testSendNotificationHttpError() {
// Verify no tasks were successfully processed due to the error
assertEquals(0, testHttpClient.tasks.size());
}

@Test
public void testSendNotificationWithMessage() throws InterruptedException {
String taskId = "task_send_notification_with_message";
Task taskData = createSampleTask(taskId, TaskState.COMPLETED);
PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", "unique_token");

// Set up the configuration in the store
configStore.setInfo(taskId, config);

// Set up latch to wait for async completion
testHttpClient.latch = new CountDownLatch(1);

Message message = new Message.Builder()
.taskId(taskId)
.messageId("task_push_notification_message")
.parts(Collections.singletonList(new TextPart("Message for task " + taskId)))
.role(Message.Role.USER)
.build();
sender.sendNotification(message);

// Wait for the async operation to complete
assertTrue(testHttpClient.latch.await(5, TimeUnit.SECONDS), "HTTP call should complete within 5 seconds");

// Verify the task was sent via HTTP
assertEquals(1, testHttpClient.tasks.size());
Task sentTask = testHttpClient.tasks.get(0);
assertEquals(taskData.getId(), sentTask.getId());
Comment on lines +350 to +352
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

These assertions are incorrect. You are sending a Message object, but asserting that a Task was captured. The PushNotificationSender sends the Message object itself. This test will fail because TestHttpClient is not equipped to handle Message payloads and will throw an exception during deserialization, as noted in another comment. The TestHttpClient needs to be updated to capture StreamingEventKind objects, and this test should assert that a Message was received.


// Verify that the X-A2A-Notification-Token header is sent with the correct token
assertEquals(1, testHttpClient.headers.size());
Map<String, String> sentHeaders = testHttpClient.headers.get(0);
assertEquals(2, sentHeaders.size());
assertTrue(sentHeaders.containsKey(A2AHeaders.X_A2A_NOTIFICATION_TOKEN));
assertEquals(config.token(), sentHeaders.get(A2AHeaders.X_A2A_NOTIFICATION_TOKEN));
// Content-Type header should always be present
assertTrue(sentHeaders.containsKey(CONTENT_TYPE));
assertEquals(APPLICATION_JSON, sentHeaders.get(CONTENT_TYPE));

}

}
Loading