aixtools 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aixtools might be problematic; see the registry's advisory for this release for more details.

Files changed (46)
  1. aixtools/_version.py +2 -2
  2. aixtools/a2a/app.py +1 -1
  3. aixtools/a2a/google_sdk/__init__.py +0 -0
  4. aixtools/a2a/google_sdk/card.py +27 -0
  5. aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py +199 -0
  6. aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py +26 -0
  7. aixtools/a2a/google_sdk/remote_agent_connection.py +88 -0
  8. aixtools/a2a/google_sdk/utils.py +59 -0
  9. aixtools/agents/prompt.py +97 -0
  10. aixtools/context.py +5 -0
  11. aixtools/google/client.py +25 -0
  12. aixtools/logging/logging_config.py +45 -0
  13. aixtools/mcp/client.py +274 -0
  14. aixtools/mcp/faulty_mcp.py +7 -7
  15. aixtools/server/utils.py +3 -3
  16. aixtools/utils/config.py +6 -0
  17. aixtools/utils/files.py +17 -0
  18. aixtools/utils/utils.py +7 -0
  19. {aixtools-0.1.3.dist-info → aixtools-0.1.5.dist-info}/METADATA +3 -1
  20. {aixtools-0.1.3.dist-info → aixtools-0.1.5.dist-info}/RECORD +45 -14
  21. {aixtools-0.1.3.dist-info → aixtools-0.1.5.dist-info}/top_level.txt +1 -0
  22. scripts/test.sh +23 -0
  23. tests/__init__.py +0 -0
  24. tests/unit/__init__.py +0 -0
  25. tests/unit/a2a/__init__.py +0 -0
  26. tests/unit/a2a/google_sdk/__init__.py +0 -0
  27. tests/unit/a2a/google_sdk/pydantic_ai_adapter/__init__.py +0 -0
  28. tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_agent_executor.py +188 -0
  29. tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_storage.py +156 -0
  30. tests/unit/a2a/google_sdk/test_card.py +114 -0
  31. tests/unit/a2a/google_sdk/test_remote_agent_connection.py +413 -0
  32. tests/unit/a2a/google_sdk/test_utils.py +208 -0
  33. tests/unit/agents/__init__.py +0 -0
  34. tests/unit/agents/test_prompt.py +363 -0
  35. tests/unit/google/__init__.py +1 -0
  36. tests/unit/google/test_client.py +233 -0
  37. tests/unit/mcp/__init__.py +0 -0
  38. tests/unit/mcp/test_client.py +242 -0
  39. tests/unit/server/__init__.py +0 -0
  40. tests/unit/server/test_path.py +225 -0
  41. tests/unit/server/test_utils.py +362 -0
  42. tests/unit/utils/__init__.py +0 -0
  43. tests/unit/utils/test_files.py +146 -0
  44. aixtools/a2a/__init__.py +0 -5
  45. {aixtools-0.1.3.dist-info → aixtools-0.1.5.dist-info}/WHEEL +0 -0
  46. {aixtools-0.1.3.dist-info → aixtools-0.1.5.dist-info}/entry_points.txt +0 -0
tests/unit/a2a/google_sdk/test_remote_agent_connection.py @@ -0,0 +1,413 @@
1
+ """Tests for the remote agent connection module."""
2
+
3
+ import unittest
4
+ from unittest.mock import AsyncMock, MagicMock, patch, call
5
+
6
+ from a2a.client import Client
7
+ from a2a.types import AgentCard, Message, Task, TaskState, TaskStatus, TaskQueryParams
8
+
9
+ from aixtools.a2a.google_sdk.remote_agent_connection import (
10
+ RemoteAgentConnection,
11
+ is_in_terminal_state,
12
+ is_in_terminal_or_interrupted_state,
13
+ )
14
+
15
+
class TestRemoteAgentConnection(unittest.IsolatedAsyncioTestCase):
    """Exercise RemoteAgentConnection's direct-send and send-then-poll paths.

    The a2a ``Client`` is an ``AsyncMock``: its ``send_message`` is fed
    hand-built async generators and its ``get_task`` returns canned poll
    results. ``asyncio.sleep`` is patched out wherever polling happens.
    """

    def setUp(self):
        self.mock_card = MagicMock(spec=AgentCard)
        self.mock_client = AsyncMock(spec=Client)
        self.connection = RemoteAgentConnection(self.mock_card, self.mock_client)

    # -- fixture builders -------------------------------------------------

    @staticmethod
    def _task_in(state, task_id=None):
        """Return a ``Task`` mock whose ``status.state`` is *state*.

        ``id`` is only assigned when *task_id* is given.
        """
        task = MagicMock(spec=Task)
        if task_id is not None:
            task.id = task_id
        task.status = MagicMock(spec=TaskStatus)
        task.status.state = state
        return task

    @staticmethod
    async def _stream(*events):
        """Async generator replaying *events* in order; yields nothing if empty."""
        for event in events:
            yield event

    # -- get_agent_card ---------------------------------------------------

    def test_get_agent_card(self):
        """get_agent_card hands back the card passed to the constructor."""
        self.assertEqual(self.connection.get_agent_card(), self.mock_card)

    # -- send_message -----------------------------------------------------
    # Task events are yielded as tuples whose first element is the task;
    # a Message event is yielded bare.

    async def test_send_message_returns_message(self):
        """A Message event is returned as soon as it arrives on the stream."""
        message_event = MagicMock(spec=Message)
        trailing_task = MagicMock(spec=Task)
        self.mock_client.send_message.return_value = self._stream(
            message_event, (trailing_task,)
        )

        outgoing = MagicMock(spec=Message)
        received = await self.connection.send_message(outgoing)

        self.assertEqual(received, message_event)
        self.mock_client.send_message.assert_called_once_with(outgoing)

    async def test_send_message_returns_terminal_task(self):
        """A task already in a terminal state is returned."""
        finished = self._task_in(TaskState.completed)
        self.mock_client.send_message.return_value = self._stream((finished,))

        received = await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(received, finished)

    async def test_send_message_returns_interrupted_task(self):
        """A task waiting for input (interrupted) is returned."""
        paused = self._task_in(TaskState.input_required)
        self.mock_client.send_message.return_value = self._stream((paused,))

        received = await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(received, paused)

    async def test_send_message_returns_last_task(self):
        """When no event is terminal or interrupted, the final task seen wins."""
        first = self._task_in(TaskState.working)
        second = self._task_in(TaskState.working)
        self.mock_client.send_message.return_value = self._stream((first,), (second,))

        received = await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(received, second)

    async def test_send_message_handles_exception(self):
        """Errors raised by the client propagate unchanged to the caller."""
        boom = Exception("Test error")
        self.mock_client.send_message.side_effect = boom

        with self.assertRaises(Exception) as raised:
            await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(raised.exception, boom)

    async def test_send_message_returns_none_when_no_events(self):
        """An empty event stream yields a None result."""
        self.mock_client.send_message.return_value = self._stream()

        self.assertIsNone(await self.connection.send_message(MagicMock(spec=Message)))

    # -- send_message_with_polling ----------------------------------------

    async def test_send_message_with_polling_returns_message(self):
        """A Message from send_message is passed straight through (no polling)."""
        reply = MagicMock(spec=Message)
        self.connection.send_message = AsyncMock(return_value=reply)
        outgoing = MagicMock(spec=Message)

        received = await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(received, reply)
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_returns_terminal_task(self):
        """An already-terminal task from send_message is returned without polling."""
        finished = self._task_in(TaskState.completed)
        self.connection.send_message = AsyncMock(return_value=finished)
        outgoing = MagicMock(spec=Message)

        received = await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(received, finished)
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_polls_until_terminal(self):
        """A working task is re-fetched by id until it reaches a terminal state."""
        started = self._task_in(TaskState.working, task_id="task123")
        still_running = self._task_in(TaskState.working)
        done = self._task_in(TaskState.completed)

        self.connection.send_message = AsyncMock(return_value=started)
        # First poll: still working; second poll: completed.
        self.mock_client.get_task.side_effect = [still_running, done]
        outgoing = MagicMock(spec=Message)

        with patch("asyncio.sleep", new_callable=AsyncMock) as fake_sleep:
            received = await self.connection.send_message_with_polling(
                outgoing, sleep_time=0.1, max_iter=10
            )

        self.assertEqual(received, done)
        self.connection.send_message.assert_called_once_with(outgoing)

        # Both polls must query the id of the task send_message returned.
        self.mock_client.get_task.assert_has_calls(
            [call(TaskQueryParams(id="task123"))] * 2
        )

        # One sleep precedes each of the two polls.
        self.assertEqual(fake_sleep.call_count, 2)
        fake_sleep.assert_called_with(0.1)

    async def test_send_message_with_polling_polls_until_interrupted(self):
        """Polling stops once the task enters the input_required state."""
        started = self._task_in(TaskState.working, task_id="task456")
        paused = self._task_in(TaskState.input_required)

        self.connection.send_message = AsyncMock(return_value=started)
        self.mock_client.get_task.return_value = paused

        with patch("asyncio.sleep", new_callable=AsyncMock) as fake_sleep:
            received = await self.connection.send_message_with_polling(
                MagicMock(spec=Message), sleep_time=0.05, max_iter=5
            )

        self.assertEqual(received, paused)
        self.mock_client.get_task.assert_called_once_with(TaskQueryParams(id="task456"))
        fake_sleep.assert_called_once_with(0.05)

    async def test_send_message_with_polling_timeout_exception(self):
        """Exhausting max_iter polls without finishing raises an error."""
        started = self._task_in(TaskState.working, task_id="task789")
        forever_running = self._task_in(TaskState.working)

        self.connection.send_message = AsyncMock(return_value=started)
        self.mock_client.get_task.return_value = forever_running

        with patch("asyncio.sleep", new_callable=AsyncMock), \
                self.assertRaises(Exception) as raised:
            await self.connection.send_message_with_polling(
                MagicMock(spec=Message), sleep_time=0.1, max_iter=3
            )

        # NOTE(review): 3 * 0.1 evaluates to 0.30000000000000004; this assertion
        # assumes the module formats the same max_iter * sleep_time product.
        expected_timeout = 3 * 0.1
        self.assertIn(
            f"Task did not complete in {expected_timeout} seconds",
            str(raised.exception),
        )

        # Exactly max_iter polls were attempted.
        self.assertEqual(self.mock_client.get_task.call_count, 3)

    async def test_send_message_with_polling_no_task_returned(self):
        """A None result from send_message is rejected with ValueError."""
        self.connection.send_message = AsyncMock(return_value=None)
        outgoing = MagicMock(spec=Message)

        with self.assertRaises(ValueError) as raised:
            await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(
            str(raised.exception), "No task or message returned from send_message"
        )
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_custom_parameters(self):
        """Caller-supplied sleep_time is used between polls."""
        started = self._task_in(TaskState.working, task_id="custom_task")
        done = self._task_in(TaskState.completed)

        self.connection.send_message = AsyncMock(return_value=started)
        self.mock_client.get_task.return_value = done

        with patch("asyncio.sleep", new_callable=AsyncMock) as fake_sleep:
            received = await self.connection.send_message_with_polling(
                MagicMock(spec=Message), sleep_time=0.5, max_iter=100
            )

        self.assertEqual(received, done)
        fake_sleep.assert_called_once_with(0.5)

    async def test_send_message_with_polling_default_parameters(self):
        """Without overrides, polling sleeps 0.2s and a failed task is terminal."""
        started = self._task_in(TaskState.working, task_id="default_task")
        failed = self._task_in(TaskState.failed)

        self.connection.send_message = AsyncMock(return_value=started)
        self.mock_client.get_task.return_value = failed

        with patch("asyncio.sleep", new_callable=AsyncMock) as fake_sleep:
            received = await self.connection.send_message_with_polling(
                MagicMock(spec=Message)
            )

        self.assertEqual(received, failed)
        # The default sleep_time this test expects is 0.2.
        fake_sleep.assert_called_once_with(0.2)
334
+
335
+
class TestHelperFunctions(unittest.TestCase):
    """Check the module-level task-state predicates on mocked tasks."""

    @staticmethod
    def _task_in(state):
        """Return a ``Task`` mock whose ``status.state`` is *state*."""
        task = MagicMock(spec=Task)
        task.status = MagicMock(spec=TaskStatus)
        task.status.state = state
        return task

    # -- is_in_terminal_state ---------------------------------------------

    def test_is_in_terminal_state_completed(self):
        """completed counts as terminal."""
        self.assertTrue(is_in_terminal_state(self._task_in(TaskState.completed)))

    def test_is_in_terminal_state_canceled(self):
        """canceled counts as terminal."""
        self.assertTrue(is_in_terminal_state(self._task_in(TaskState.canceled)))

    def test_is_in_terminal_state_failed(self):
        """failed counts as terminal."""
        self.assertTrue(is_in_terminal_state(self._task_in(TaskState.failed)))

    def test_is_in_terminal_state_running(self):
        """working is not terminal."""
        self.assertFalse(is_in_terminal_state(self._task_in(TaskState.working)))

    # -- is_in_terminal_or_interrupted_state -----------------------------

    def test_is_in_terminal_or_interrupted_state_input_required(self):
        """input_required counts as interrupted."""
        verdict = is_in_terminal_or_interrupted_state(
            self._task_in(TaskState.input_required)
        )
        self.assertTrue(verdict)

    def test_is_in_terminal_or_interrupted_state_unknown(self):
        """unknown is treated as interrupted."""
        verdict = is_in_terminal_or_interrupted_state(self._task_in(TaskState.unknown))
        self.assertTrue(verdict)

    def test_is_in_terminal_or_interrupted_state_completed(self):
        """Terminal states (here: completed) also satisfy the combined predicate."""
        verdict = is_in_terminal_or_interrupted_state(
            self._task_in(TaskState.completed)
        )
        self.assertTrue(verdict)

    def test_is_in_terminal_or_interrupted_state_running(self):
        """working is neither terminal nor interrupted."""
        verdict = is_in_terminal_or_interrupted_state(self._task_in(TaskState.working))
        self.assertFalse(verdict)
410
+
411
+
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
tests/unit/a2a/google_sdk/test_utils.py @@ -0,0 +1,208 @@
1
+ """Tests for the A2A utils module."""
2
+
3
+ import unittest
4
+ from unittest.mock import AsyncMock, MagicMock, patch
5
+
6
+ import httpx
7
+ from a2a.client import ClientConfig, ClientFactory, A2ACardResolver
8
+ from a2a.server.agent_execution import RequestContext
9
+ from a2a.types import AgentCard
10
+
11
+ from aixtools.a2a.google_sdk.utils import (
12
+ _AgentCardResolver,
13
+ get_a2a_clients,
14
+ get_session_id_tuple,
15
+ )
16
+ from aixtools.a2a.google_sdk.remote_agent_connection import RemoteAgentConnection
17
+ from aixtools.context import DEFAULT_USER_ID, DEFAULT_SESSION_ID
18
+
19
+
class TestAgentCardResolver(unittest.IsolatedAsyncioTestCase):
    """Exercise _AgentCardResolver construction, registration and discovery."""

    def setUp(self):
        self.mock_client = AsyncMock(spec=httpx.AsyncClient)
        self.resolver = _AgentCardResolver(self.mock_client)

    def test_init(self):
        """The constructor builds a ClientFactory around the given httpx client."""
        factory = MagicMock(spec=ClientFactory)

        with patch(
            "aixtools.a2a.google_sdk.utils.ClientFactory", return_value=factory
        ) as factory_cls:
            resolver = _AgentCardResolver(self.mock_client)

        # The factory receives a ClientConfig wrapping the same httpx client.
        factory_cls.assert_called_once()
        config = factory_cls.call_args.args[0]
        self.assertIsInstance(config, ClientConfig)
        self.assertEqual(config.httpx_client, self.mock_client)

        # Internal state starts out wired up with no registered clients.
        self.assertEqual(resolver._httpx_client, self.mock_client)
        self.assertEqual(resolver._a2a_client_factory, factory)
        self.assertEqual(resolver.clients, {})

    def test_register_agent_card(self):
        """Registering a card stores a RemoteAgentConnection keyed by card name."""
        card = MagicMock(spec=AgentCard)
        card.name = "test_agent"

        a2a_client = MagicMock()
        factory = MagicMock(spec=ClientFactory)
        factory.create.return_value = a2a_client
        self.resolver._a2a_client_factory = factory

        connection = MagicMock(spec=RemoteAgentConnection)

        with patch(
            "aixtools.a2a.google_sdk.utils.RemoteAgentConnection",
            return_value=connection,
        ) as connection_cls:
            self.resolver.register_agent_card(card)

        factory.create.assert_called_once_with(card)
        connection_cls.assert_called_once_with(card, a2a_client)
        self.assertEqual(self.resolver.clients["test_agent"], connection)

    async def test_retrieve_card(self):
        """retrieve_card fetches the card from the host and registers it."""
        card_resolver = AsyncMock(spec=A2ACardResolver)
        card = MagicMock(spec=AgentCard)
        card.name = "test_agent"
        card_resolver.get_agent_card.return_value = card

        with patch(
            "aixtools.a2a.google_sdk.utils.A2ACardResolver", return_value=card_resolver
        ) as resolver_cls:
            with patch.object(self.resolver, "register_agent_card") as register:
                await self.resolver.retrieve_card("http://test.com")

            # The resolver is expected to be built for the first card path tried.
            resolver_cls.assert_called_with(
                self.mock_client, "http://test.com", "/.well-known/agent-card.json"
            )
            card_resolver.get_agent_card.assert_called_once()
            register.assert_called_once_with(card)

    async def test_get_a2a_clients(self):
        """get_a2a_clients retrieves every host's card and returns the client map."""
        hosts = ["http://agent1.com", "http://agent2.com"]

        # patch.object picks AsyncMock automatically for async methods; the
        # awaited result is pinned to None.
        with patch.object(self.resolver, "retrieve_card", return_value=None) as retrieve:
            self.resolver.clients = {
                "agent1": MagicMock(spec=RemoteAgentConnection),
                "agent2": MagicMock(spec=RemoteAgentConnection),
            }

            result = await self.resolver.get_a2a_clients(hosts)

            self.assertEqual(retrieve.call_count, 2)
            for host in hosts:
                retrieve.assert_any_call(host)

            self.assertEqual(result, self.resolver.clients)
117
+
118
+
class TestGetA2AClients(unittest.IsolatedAsyncioTestCase):
    """Exercise the module-level get_a2a_clients coroutine."""

    async def test_get_a2a_clients(self):
        """Session headers go on the httpx client; discovery is delegated to the resolver."""
        http_client = AsyncMock()

        resolver = AsyncMock(spec=_AgentCardResolver)
        expected_clients = {"agent1": MagicMock(), "agent2": MagicMock()}
        resolver.get_a2a_clients.return_value = expected_clients

        session = ("user123", "session456")
        hosts = ["http://agent1.com", "http://agent2.com"]

        with patch(
            "aixtools.a2a.google_sdk.utils._AgentCardResolver", return_value=resolver
        ) as resolver_cls:
            with patch(
                "aixtools.a2a.google_sdk.utils.httpx.AsyncClient",
                return_value=http_client,
            ) as client_cls:
                result = await get_a2a_clients(session, hosts)

                # The (user, session) tuple becomes request headers.
                client_cls.assert_called_once_with(
                    headers={"user-id": "user123", "session-id": "session456"},
                    timeout=60.0,
                )
                resolver_cls.assert_called_once_with(http_client)
                resolver.get_a2a_clients.assert_called_once_with(hosts)
                self.assertEqual(result, expected_clients)
155
+
156
+
class TestGetSessionIdTuple(unittest.TestCase):
    """Check how get_session_id_tuple reads ids from request-context headers."""

    @staticmethod
    def _context_with_state(state):
        """Return a RequestContext mock whose ``call_context.state`` is *state*."""
        context = MagicMock(spec=RequestContext)
        context.call_context.state = state
        return context

    def test_get_session_id_tuple_with_headers(self):
        """Both ids are taken from the headers when present."""
        context = self._context_with_state(
            {"headers": {"user-id": "test_user", "session-id": "test_session"}}
        )
        self.assertEqual(get_session_id_tuple(context), ("test_user", "test_session"))

    def test_get_session_id_tuple_with_partial_headers(self):
        """A missing session-id header falls back to DEFAULT_SESSION_ID."""
        context = self._context_with_state({"headers": {"user-id": "test_user"}})
        self.assertEqual(
            get_session_id_tuple(context), ("test_user", DEFAULT_SESSION_ID)
        )

    def test_get_session_id_tuple_no_headers(self):
        """With no headers entry at all, both defaults are used."""
        context = self._context_with_state({})
        self.assertEqual(
            get_session_id_tuple(context), (DEFAULT_USER_ID, DEFAULT_SESSION_ID)
        )

    def test_get_session_id_tuple_empty_headers(self):
        """An empty headers dict also yields both defaults."""
        context = self._context_with_state({"headers": {}})
        self.assertEqual(
            get_session_id_tuple(context), (DEFAULT_USER_ID, DEFAULT_SESSION_ID)
        )
205
+
206
+
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
File without changes