Skip to content

Commit 847f18e

Browse files
authored
fix(jsonrpc, rest): extensions support in get_card methods in json-rpc and rest transports (#564)
`Headers` are now updated with `extensions` before the `get_agent_card` call which has headers as input parameters. - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. Fixes #504 🦕 Release-As: 0.3.19
1 parent 213d9f8 commit 847f18e

File tree

4 files changed

+224
-47
lines changed

4 files changed

+224
-47
lines changed

src/a2a/client/transports/jsonrpc.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,14 @@ async def get_card(
378378
extensions: list[str] | None = None,
379379
) -> AgentCard:
380380
"""Retrieves the agent's card."""
381+
modified_kwargs = update_extension_header(
382+
self._get_http_args(context),
383+
extensions if extensions is not None else self.extensions,
384+
)
381385
card = self.agent_card
382386
if not card:
383387
resolver = A2ACardResolver(self.httpx_client, self.url)
384-
card = await resolver.get_agent_card(
385-
http_kwargs=self._get_http_args(context)
386-
)
388+
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
387389
self._needs_extended_card = (
388390
card.supports_authenticated_extended_card
389391
)
@@ -393,10 +395,6 @@ async def get_card(
393395
return card
394396

395397
request = GetAuthenticatedExtendedCardRequest(id=str(uuid4()))
396-
modified_kwargs = update_extension_header(
397-
self._get_http_args(context),
398-
extensions if extensions is not None else self.extensions,
399-
)
400398
payload, modified_kwargs = await self._apply_interceptors(
401399
request.method,
402400
request.model_dump(mode='json', exclude_none=True),

src/a2a/client/transports/rest.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -370,12 +370,14 @@ async def get_card(
370370
extensions: list[str] | None = None,
371371
) -> AgentCard:
372372
"""Retrieves the agent's card."""
373+
modified_kwargs = update_extension_header(
374+
self._get_http_args(context),
375+
extensions if extensions is not None else self.extensions,
376+
)
373377
card = self.agent_card
374378
if not card:
375379
resolver = A2ACardResolver(self.httpx_client, self.url)
376-
card = await resolver.get_agent_card(
377-
http_kwargs=self._get_http_args(context)
378-
)
380+
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
379381
self._needs_extended_card = (
380382
card.supports_authenticated_extended_card
381383
)
@@ -384,10 +386,6 @@ async def get_card(
384386
if not self._needs_extended_card:
385387
return card
386388

387-
modified_kwargs = update_extension_header(
388-
self._get_http_args(context),
389-
extensions if extensions is not None else self.extensions,
390-
)
391389
_, modified_kwargs = await self._apply_interceptors(
392390
{},
393391
modified_kwargs,

tests/client/transports/test_jsonrpc_client.py

Lines changed: 94 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,14 @@ async def async_iterable_from_list(
114114
yield item
115115

116116

117+
def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]):
118+
headers = mock_kwargs.get('headers', {})
119+
assert HTTP_EXTENSION_HEADER in headers
120+
header_value = headers[HTTP_EXTENSION_HEADER]
121+
actual_extensions = {e.strip() for e in header_value.split(',')}
122+
assert actual_extensions == expected_extensions
123+
124+
117125
class TestA2ACardResolver:
118126
BASE_URL = 'http://example.com'
119127
AGENT_CARD_PATH = AGENT_CARD_WELL_KNOWN_PATH
@@ -823,18 +831,13 @@ async def test_send_message_with_default_extensions(
823831
mock_httpx_client.post.assert_called_once()
824832
_, mock_kwargs = mock_httpx_client.post.call_args
825833

826-
headers = mock_kwargs.get('headers', {})
827-
assert HTTP_EXTENSION_HEADER in headers
828-
header_value = headers[HTTP_EXTENSION_HEADER]
829-
actual_extensions_list = [e.strip() for e in header_value.split(',')]
830-
actual_extensions = set(actual_extensions_list)
831-
832-
expected_extensions = {
833-
'https://example.com/test-ext/v1',
834-
'https://example.com/test-ext/v2',
835-
}
836-
assert len(actual_extensions_list) == 2
837-
assert actual_extensions == expected_extensions
834+
_assert_extensions_header(
835+
mock_kwargs,
836+
{
837+
'https://example.com/test-ext/v1',
838+
'https://example.com/test-ext/v2',
839+
},
840+
)
838841

839842
@pytest.mark.asyncio
840843
@patch('a2a.client.transports.jsonrpc.aconnect_sse')
@@ -870,8 +873,83 @@ async def test_send_message_streaming_with_new_extensions(
870873
mock_aconnect_sse.assert_called_once()
871874
_, kwargs = mock_aconnect_sse.call_args
872875

873-
headers = kwargs.get('headers', {})
874-
assert HTTP_EXTENSION_HEADER in headers
875-
assert (
876-
headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2'
876+
_assert_extensions_header(
877+
kwargs,
878+
{
879+
'https://example.com/test-ext/v2',
880+
},
881+
)
882+
883+
@pytest.mark.asyncio
884+
async def test_get_card_no_card_provided_with_extensions(
885+
self, mock_httpx_client: AsyncMock
886+
):
887+
"""Test get_card with extensions set in Client when no card is initially provided.
888+
Tests that the extensions are added to the HTTP GET request."""
889+
extensions = [
890+
'https://example.com/test-ext/v1',
891+
'https://example.com/test-ext/v2',
892+
]
893+
client = JsonRpcTransport(
894+
httpx_client=mock_httpx_client,
895+
url=TestJsonRpcTransport.AGENT_URL,
896+
extensions=extensions,
897+
)
898+
mock_response = AsyncMock(spec=httpx.Response)
899+
mock_response.status_code = 200
900+
mock_response.json.return_value = AGENT_CARD.model_dump(mode='json')
901+
mock_httpx_client.get.return_value = mock_response
902+
903+
await client.get_card()
904+
905+
mock_httpx_client.get.assert_called_once()
906+
_, mock_kwargs = mock_httpx_client.get.call_args
907+
908+
_assert_extensions_header(
909+
mock_kwargs,
910+
{
911+
'https://example.com/test-ext/v1',
912+
'https://example.com/test-ext/v2',
913+
},
914+
)
915+
916+
@pytest.mark.asyncio
917+
async def test_get_card_with_extended_card_support_with_extensions(
918+
self, mock_httpx_client: AsyncMock
919+
):
920+
"""Test get_card with extensions passed to get_card call when extended card support is enabled.
921+
Tests that the extensions are added to the RPC request."""
922+
extensions = [
923+
'https://example.com/test-ext/v1',
924+
'https://example.com/test-ext/v2',
925+
]
926+
agent_card = AGENT_CARD.model_copy(
927+
update={'supports_authenticated_extended_card': True}
928+
)
929+
client = JsonRpcTransport(
930+
httpx_client=mock_httpx_client,
931+
agent_card=agent_card,
932+
extensions=extensions,
933+
)
934+
935+
rpc_response = {
936+
'id': '123',
937+
'jsonrpc': '2.0',
938+
'result': AGENT_CARD_EXTENDED.model_dump(mode='json'),
939+
}
940+
with patch.object(
941+
client, '_send_request', new_callable=AsyncMock
942+
) as mock_send_request:
943+
mock_send_request.return_value = rpc_response
944+
await client.get_card(extensions=extensions)
945+
946+
mock_send_request.assert_called_once()
947+
_, mock_kwargs = mock_send_request.call_args[0]
948+
949+
_assert_extensions_header(
950+
mock_kwargs,
951+
{
952+
'https://example.com/test-ext/v1',
953+
'https://example.com/test-ext/v2',
954+
},
877955
)

tests/client/transports/test_rest_client.py

Lines changed: 120 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
from a2a.client import create_text_message_object
1010
from a2a.client.transports.rest import RestTransport
1111
from a2a.extensions.common import HTTP_EXTENSION_HEADER
12-
from a2a.types import AgentCard, MessageSendParams, Role
12+
from a2a.types import (
13+
AgentCapabilities,
14+
AgentCard,
15+
AgentSkill,
16+
MessageSendParams,
17+
Role,
18+
)
1319

1420

1521
@pytest.fixture
@@ -32,6 +38,14 @@ async def async_iterable_from_list(
3238
yield item
3339

3440

41+
def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]):
42+
headers = mock_kwargs.get('headers', {})
43+
assert HTTP_EXTENSION_HEADER in headers
44+
header_value = headers[HTTP_EXTENSION_HEADER]
45+
actual_extensions = {e.strip() for e in header_value.split(',')}
46+
assert actual_extensions == expected_extensions
47+
48+
3549
class TestRestTransportExtensions:
3650
@pytest.mark.asyncio
3751
async def test_send_message_with_default_extensions(
@@ -67,18 +81,13 @@ async def test_send_message_with_default_extensions(
6781
mock_build_request.assert_called_once()
6882
_, kwargs = mock_build_request.call_args
6983

70-
headers = kwargs.get('headers', {})
71-
assert HTTP_EXTENSION_HEADER in headers
72-
header_value = kwargs['headers'][HTTP_EXTENSION_HEADER]
73-
actual_extensions_list = [e.strip() for e in header_value.split(',')]
74-
actual_extensions = set(actual_extensions_list)
75-
76-
expected_extensions = {
77-
'https://example.com/test-ext/v1',
78-
'https://example.com/test-ext/v2',
79-
}
80-
assert len(actual_extensions_list) == 2
81-
assert actual_extensions == expected_extensions
84+
_assert_extensions_header(
85+
kwargs,
86+
{
87+
'https://example.com/test-ext/v1',
88+
'https://example.com/test-ext/v2',
89+
},
90+
)
8291

8392
@pytest.mark.asyncio
8493
@patch('a2a.client.transports.rest.aconnect_sse')
@@ -114,8 +123,102 @@ async def test_send_message_streaming_with_new_extensions(
114123
mock_aconnect_sse.assert_called_once()
115124
_, kwargs = mock_aconnect_sse.call_args
116125

117-
headers = kwargs.get('headers', {})
118-
assert HTTP_EXTENSION_HEADER in headers
119-
assert (
120-
headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2'
126+
_assert_extensions_header(
127+
kwargs,
128+
{
129+
'https://example.com/test-ext/v2',
130+
},
131+
)
132+
133+
@pytest.mark.asyncio
134+
async def test_get_card_no_card_provided_with_extensions(
135+
self, mock_httpx_client: AsyncMock
136+
):
137+
"""Test get_card with extensions set in Client when no card is initially provided.
138+
Tests that the extensions are added to the HTTP GET request."""
139+
extensions = [
140+
'https://example.com/test-ext/v1',
141+
'https://example.com/test-ext/v2',
142+
]
143+
client = RestTransport(
144+
httpx_client=mock_httpx_client,
145+
url='http://agent.example.com/api',
146+
extensions=extensions,
147+
)
148+
149+
mock_response = AsyncMock(spec=httpx.Response)
150+
mock_response.status_code = 200
151+
mock_response.json.return_value = {
152+
'name': 'Test Agent',
153+
'description': 'Test Agent Description',
154+
'url': 'http://agent.example.com/api',
155+
'version': '1.0.0',
156+
'default_input_modes': ['text'],
157+
'default_output_modes': ['text'],
158+
'capabilities': AgentCapabilities().model_dump(),
159+
'skills': [],
160+
}
161+
mock_httpx_client.get.return_value = mock_response
162+
163+
await client.get_card()
164+
165+
mock_httpx_client.get.assert_called_once()
166+
_, mock_kwargs = mock_httpx_client.get.call_args
167+
168+
_assert_extensions_header(
169+
mock_kwargs,
170+
{
171+
'https://example.com/test-ext/v1',
172+
'https://example.com/test-ext/v2',
173+
},
174+
)
175+
176+
@pytest.mark.asyncio
177+
async def test_get_card_with_extended_card_support_with_extensions(
178+
self, mock_httpx_client: AsyncMock
179+
):
180+
"""Test get_card with extensions passed to get_card call when extended card support is enabled.
181+
Tests that the extensions are added to the GET request."""
182+
extensions = [
183+
'https://example.com/test-ext/v1',
184+
'https://example.com/test-ext/v2',
185+
]
186+
agent_card = AgentCard(
187+
name='Test Agent',
188+
description='Test Agent Description',
189+
url='http://agent.example.com/api',
190+
version='1.0.0',
191+
default_input_modes=['text'],
192+
default_output_modes=['text'],
193+
capabilities=AgentCapabilities(),
194+
skills=[],
195+
supports_authenticated_extended_card=True,
196+
)
197+
client = RestTransport(
198+
httpx_client=mock_httpx_client,
199+
agent_card=agent_card,
200+
)
201+
202+
mock_response = AsyncMock(spec=httpx.Response)
203+
mock_response.status_code = 200
204+
mock_response.json.return_value = agent_card.model_dump(mode='json')
205+
mock_httpx_client.send.return_value = mock_response
206+
207+
with patch.object(
208+
client, '_send_get_request', new_callable=AsyncMock
209+
) as mock_send_get_request:
210+
mock_send_get_request.return_value = agent_card.model_dump(
211+
mode='json'
212+
)
213+
await client.get_card(extensions=extensions)
214+
215+
mock_send_get_request.assert_called_once()
216+
_, _, mock_kwargs = mock_send_get_request.call_args[0]
217+
218+
_assert_extensions_header(
219+
mock_kwargs,
220+
{
221+
'https://example.com/test-ext/v1',
222+
'https://example.com/test-ext/v2',
223+
},
121224
)

0 commit comments

Comments
 (0)