aixtools 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aixtools might be problematic. Click here for more details.
- aixtools/_version.py +2 -2
- aixtools/a2a/app.py +1 -1
- aixtools/a2a/google_sdk/__init__.py +0 -0
- aixtools/a2a/google_sdk/card.py +27 -0
- aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py +199 -0
- aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py +26 -0
- aixtools/a2a/google_sdk/remote_agent_connection.py +88 -0
- aixtools/a2a/google_sdk/utils.py +59 -0
- aixtools/agents/prompt.py +97 -0
- aixtools/context.py +5 -0
- aixtools/google/client.py +25 -0
- aixtools/logging/logging_config.py +45 -0
- aixtools/mcp/client.py +274 -0
- aixtools/mcp/faulty_mcp.py +7 -7
- aixtools/server/utils.py +3 -3
- aixtools/utils/config.py +6 -0
- aixtools/utils/files.py +17 -0
- aixtools/utils/utils.py +7 -0
- {aixtools-0.1.3.dist-info → aixtools-0.1.5.dist-info}/METADATA +3 -1
- {aixtools-0.1.3.dist-info → aixtools-0.1.5.dist-info}/RECORD +45 -14
- {aixtools-0.1.3.dist-info → aixtools-0.1.5.dist-info}/top_level.txt +1 -0
- scripts/test.sh +23 -0
- tests/__init__.py +0 -0
- tests/unit/__init__.py +0 -0
- tests/unit/a2a/__init__.py +0 -0
- tests/unit/a2a/google_sdk/__init__.py +0 -0
- tests/unit/a2a/google_sdk/pydantic_ai_adapter/__init__.py +0 -0
- tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_agent_executor.py +188 -0
- tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_storage.py +156 -0
- tests/unit/a2a/google_sdk/test_card.py +114 -0
- tests/unit/a2a/google_sdk/test_remote_agent_connection.py +413 -0
- tests/unit/a2a/google_sdk/test_utils.py +208 -0
- tests/unit/agents/__init__.py +0 -0
- tests/unit/agents/test_prompt.py +363 -0
- tests/unit/google/__init__.py +1 -0
- tests/unit/google/test_client.py +233 -0
- tests/unit/mcp/__init__.py +0 -0
- tests/unit/mcp/test_client.py +242 -0
- tests/unit/server/__init__.py +0 -0
- tests/unit/server/test_path.py +225 -0
- tests/unit/server/test_utils.py +362 -0
- tests/unit/utils/__init__.py +0 -0
- tests/unit/utils/test_files.py +146 -0
- aixtools/a2a/__init__.py +0 -5
- {aixtools-0.1.3.dist-info → aixtools-0.1.5.dist-info}/WHEEL +0 -0
- {aixtools-0.1.3.dist-info → aixtools-0.1.5.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
"""Tests for the remote agent connection module."""
|
|
2
|
+
|
|
3
|
+
import unittest
|
|
4
|
+
from unittest.mock import AsyncMock, MagicMock, patch, call
|
|
5
|
+
|
|
6
|
+
from a2a.client import Client
|
|
7
|
+
from a2a.types import AgentCard, Message, Task, TaskState, TaskStatus, TaskQueryParams
|
|
8
|
+
|
|
9
|
+
from aixtools.a2a.google_sdk.remote_agent_connection import (
|
|
10
|
+
RemoteAgentConnection,
|
|
11
|
+
is_in_terminal_state,
|
|
12
|
+
is_in_terminal_or_interrupted_state,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TestRemoteAgentConnection(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``RemoteAgentConnection``.

    The a2a ``Client`` is an ``AsyncMock`` so each test can script what
    ``send_message`` streams and what ``get_task`` returns while polling.
    """

    def setUp(self):
        self.mock_card = MagicMock(spec=AgentCard)
        self.mock_client = AsyncMock(spec=Client)
        self.connection = RemoteAgentConnection(self.mock_card, self.mock_client)

    @staticmethod
    def _make_task(state, task_id=None):
        """Return a ``Task`` mock whose ``status.state`` is *state*.

        ``id`` is assigned only when *task_id* is given, so tasks that are
        never looked up by id stay exactly as bare as in a hand-built mock.
        """
        task = MagicMock(spec=Task)
        if task_id is not None:
            task.id = task_id
        task.status = MagicMock(spec=TaskStatus)
        task.status.state = state
        return task

    def _stream(self, *events):
        """Make the mocked client's ``send_message`` yield *events* in order.

        Events are yielded exactly as given (a ``Message`` mock or a 1-tuple
        wrapping a ``Task`` mock, as these tests use). No events means an
        empty stream.
        """
        async def event_stream():
            for event in events:
                yield event

        self.mock_client.send_message.return_value = event_stream()

    def test_get_agent_card(self):
        """get_agent_card hands back the card given to the constructor."""
        self.assertEqual(self.connection.get_agent_card(), self.mock_card)

    async def test_send_message_returns_message(self):
        """A streamed Message is returned and the request is forwarded once."""
        streamed_message = MagicMock(spec=Message)
        trailing_task = MagicMock(spec=Task)
        self._stream(streamed_message, (trailing_task,))

        outgoing = MagicMock(spec=Message)
        result = await self.connection.send_message(outgoing)

        self.assertEqual(result, streamed_message)
        self.mock_client.send_message.assert_called_once_with(outgoing)

    async def test_send_message_returns_terminal_task(self):
        """A task streamed in the ``completed`` state is returned."""
        done = self._make_task(TaskState.completed)
        self._stream((done,))

        result = await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(result, done)

    async def test_send_message_returns_interrupted_task(self):
        """A task streamed in the ``input_required`` state is returned."""
        paused = self._make_task(TaskState.input_required)
        self._stream((paused,))

        result = await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(result, paused)

    async def test_send_message_returns_last_task(self):
        """With only ``working`` tasks streamed, the final one is returned."""
        first = self._make_task(TaskState.working)
        second = self._make_task(TaskState.working)
        self._stream((first,), (second,))

        result = await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(result, second)

    async def test_send_message_handles_exception(self):
        """An error raised by the client propagates unchanged."""
        failure = Exception("Test error")
        self.mock_client.send_message.side_effect = failure

        with self.assertRaises(Exception) as raised:
            await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(raised.exception, failure)

    async def test_send_message_returns_none_when_no_events(self):
        """An empty stream yields ``None``."""
        self._stream()

        result = await self.connection.send_message(MagicMock(spec=Message))

        self.assertIsNone(result)

    async def test_send_message_with_polling_returns_message(self):
        """A Message from send_message short-circuits polling."""
        reply = MagicMock(spec=Message)
        self.connection.send_message = AsyncMock(return_value=reply)
        outgoing = MagicMock(spec=Message)

        result = await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(result, reply)
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_returns_terminal_task(self):
        """An already-completed task is returned without polling."""
        done = self._make_task(TaskState.completed)
        self.connection.send_message = AsyncMock(return_value=done)
        outgoing = MagicMock(spec=Message)

        result = await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(result, done)
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_polls_until_terminal(self):
        """get_task is polled by the initial task's id until it completes."""
        started = self._make_task(TaskState.working, task_id="task123")
        still_running = self._make_task(TaskState.working)
        finished = self._make_task(TaskState.completed)

        self.connection.send_message = AsyncMock(return_value=started)
        # First poll is still working; second poll reports completion.
        self.mock_client.get_task.side_effect = [still_running, finished]
        outgoing = MagicMock(spec=Message)

        with patch('asyncio.sleep', new_callable=AsyncMock) as fake_sleep:
            result = await self.connection.send_message_with_polling(
                outgoing, sleep_time=0.1, max_iter=10
            )

        self.assertEqual(result, finished)
        self.connection.send_message.assert_called_once_with(outgoing)

        query = call(TaskQueryParams(id="task123"))
        self.mock_client.get_task.assert_has_calls([query, query])

        # One sleep per polling iteration.
        self.assertEqual(fake_sleep.call_count, 2)
        fake_sleep.assert_called_with(0.1)

    async def test_send_message_with_polling_polls_until_interrupted(self):
        """Polling stops once the task needs input."""
        started = self._make_task(TaskState.working, task_id="task456")
        waiting = self._make_task(TaskState.input_required)

        self.connection.send_message = AsyncMock(return_value=started)
        self.mock_client.get_task.return_value = waiting

        with patch('asyncio.sleep', new_callable=AsyncMock) as fake_sleep:
            result = await self.connection.send_message_with_polling(
                MagicMock(spec=Message), sleep_time=0.05, max_iter=5
            )

        self.assertEqual(result, waiting)
        self.mock_client.get_task.assert_called_once_with(TaskQueryParams(id="task456"))
        fake_sleep.assert_called_once_with(0.05)

    async def test_send_message_with_polling_timeout_exception(self):
        """Exhausting max_iter raises with a timeout message."""
        started = self._make_task(TaskState.working, task_id="task789")
        self.connection.send_message = AsyncMock(return_value=started)
        # Every poll reports the task as still working.
        self.mock_client.get_task.return_value = self._make_task(TaskState.working)

        with patch('asyncio.sleep', new_callable=AsyncMock):
            with self.assertRaises(Exception) as raised:
                await self.connection.send_message_with_polling(
                    MagicMock(spec=Message), sleep_time=0.1, max_iter=3
                )

        # Keep the float product (not a literal 0.3): the message embeds the
        # repr of max_iter * sleep_time, which is not exactly 0.3.
        timeout_seconds = 3 * 0.1
        self.assertIn(f"Task did not complete in {timeout_seconds} seconds", str(raised.exception))

        # One get_task call per allowed iteration.
        self.assertEqual(self.mock_client.get_task.call_count, 3)

    async def test_send_message_with_polling_no_task_returned(self):
        """A ``None`` from send_message is rejected with ValueError."""
        self.connection.send_message = AsyncMock(return_value=None)
        outgoing = MagicMock(spec=Message)

        with self.assertRaises(ValueError) as raised:
            await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(str(raised.exception), "No task or message returned from send_message")
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_custom_parameters(self):
        """An explicit sleep_time is passed through to asyncio.sleep."""
        started = self._make_task(TaskState.working, task_id="custom_task")
        finished = self._make_task(TaskState.completed)

        self.connection.send_message = AsyncMock(return_value=started)
        self.mock_client.get_task.return_value = finished

        with patch('asyncio.sleep', new_callable=AsyncMock) as fake_sleep:
            result = await self.connection.send_message_with_polling(
                MagicMock(spec=Message), sleep_time=0.5, max_iter=100
            )

        self.assertEqual(result, finished)
        fake_sleep.assert_called_once_with(0.5)

    async def test_send_message_with_polling_default_parameters(self):
        """Without arguments, polling sleeps 0.2s and stops on ``failed``."""
        started = self._make_task(TaskState.working, task_id="default_task")
        failed = self._make_task(TaskState.failed)

        self.connection.send_message = AsyncMock(return_value=started)
        self.mock_client.get_task.return_value = failed

        with patch('asyncio.sleep', new_callable=AsyncMock) as fake_sleep:
            result = await self.connection.send_message_with_polling(MagicMock(spec=Message))

        self.assertEqual(result, failed)
        # Default sleep_time is 0.2.
        fake_sleep.assert_called_once_with(0.2)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class TestHelperFunctions(unittest.TestCase):
    """Unit tests for the task-state predicates of the connection module."""

    @staticmethod
    def _task(state):
        """Return a ``Task`` mock whose ``status.state`` is *state*."""
        task = MagicMock(spec=Task)
        task.status = MagicMock(spec=TaskStatus)
        task.status.state = state
        return task

    def test_is_in_terminal_state_completed(self):
        """``completed`` is terminal."""
        self.assertTrue(is_in_terminal_state(self._task(TaskState.completed)))

    def test_is_in_terminal_state_canceled(self):
        """``canceled`` is terminal."""
        self.assertTrue(is_in_terminal_state(self._task(TaskState.canceled)))

    def test_is_in_terminal_state_failed(self):
        """``failed`` is terminal."""
        self.assertTrue(is_in_terminal_state(self._task(TaskState.failed)))

    def test_is_in_terminal_state_running(self):
        """``working`` is not terminal."""
        self.assertFalse(is_in_terminal_state(self._task(TaskState.working)))

    def test_is_in_terminal_or_interrupted_state_input_required(self):
        """``input_required`` counts as interrupted."""
        self.assertTrue(is_in_terminal_or_interrupted_state(self._task(TaskState.input_required)))

    def test_is_in_terminal_or_interrupted_state_unknown(self):
        """``unknown`` counts as terminal-or-interrupted."""
        self.assertTrue(is_in_terminal_or_interrupted_state(self._task(TaskState.unknown)))

    def test_is_in_terminal_or_interrupted_state_completed(self):
        """Terminal states such as ``completed`` are also accepted."""
        self.assertTrue(is_in_terminal_or_interrupted_state(self._task(TaskState.completed)))

    def test_is_in_terminal_or_interrupted_state_running(self):
        """``working`` is neither terminal nor interrupted."""
        self.assertFalse(is_in_terminal_or_interrupted_state(self._task(TaskState.working)))
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
if __name__ == '__main__':
    # Allow running this test module directly (python <file>) as well as via a test runner.
    unittest.main()
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
"""Tests for the A2A utils module."""
|
|
2
|
+
|
|
3
|
+
import unittest
|
|
4
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
from a2a.client import ClientConfig, ClientFactory, A2ACardResolver
|
|
8
|
+
from a2a.server.agent_execution import RequestContext
|
|
9
|
+
from a2a.types import AgentCard
|
|
10
|
+
|
|
11
|
+
from aixtools.a2a.google_sdk.utils import (
|
|
12
|
+
_AgentCardResolver,
|
|
13
|
+
get_a2a_clients,
|
|
14
|
+
get_session_id_tuple,
|
|
15
|
+
)
|
|
16
|
+
from aixtools.a2a.google_sdk.remote_agent_connection import RemoteAgentConnection
|
|
17
|
+
from aixtools.context import DEFAULT_USER_ID, DEFAULT_SESSION_ID
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TestAgentCardResolver(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``_AgentCardResolver``: construction, card registration and lookup."""

    def setUp(self):
        self.mock_client = AsyncMock(spec=httpx.AsyncClient)
        self.resolver = _AgentCardResolver(self.mock_client)

    def test_init(self):
        """The httpx client is wrapped in a ClientConfig for a fresh ClientFactory."""
        factory = MagicMock(spec=ClientFactory)

        with patch("aixtools.a2a.google_sdk.utils.ClientFactory", return_value=factory) as factory_cls:
            resolver = _AgentCardResolver(self.mock_client)

        # The factory is built exactly once, from a ClientConfig carrying our client.
        factory_cls.assert_called_once()
        config = factory_cls.call_args.args[0]
        self.assertIsInstance(config, ClientConfig)
        self.assertEqual(config.httpx_client, self.mock_client)

        # Instance state: same client, the patched factory, and no clients yet.
        self.assertEqual(resolver._httpx_client, self.mock_client)
        self.assertEqual(resolver._a2a_client_factory, factory)
        self.assertEqual(resolver.clients, {})

    def test_register_agent_card(self):
        """Registering a card creates a client and stores a connection under the card name."""
        card = MagicMock(spec=AgentCard)
        card.name = "test_agent"

        a2a_client = MagicMock()
        factory = MagicMock(spec=ClientFactory)
        factory.create.return_value = a2a_client
        self.resolver._a2a_client_factory = factory

        connection = MagicMock(spec=RemoteAgentConnection)

        with patch(
            "aixtools.a2a.google_sdk.utils.RemoteAgentConnection", return_value=connection
        ) as connection_cls:
            self.resolver.register_agent_card(card)

        factory.create.assert_called_once_with(card)
        connection_cls.assert_called_once_with(card, a2a_client)
        self.assertEqual(self.resolver.clients["test_agent"], connection)

    async def test_retrieve_card(self):
        """A card fetched from the well-known path is passed to register_agent_card."""
        card_resolver = AsyncMock(spec=A2ACardResolver)
        card = MagicMock(spec=AgentCard)
        card.name = "test_agent"
        card_resolver.get_agent_card.return_value = card

        with patch(
            "aixtools.a2a.google_sdk.utils.A2ACardResolver", return_value=card_resolver
        ) as resolver_cls:
            with patch.object(self.resolver, 'register_agent_card') as register:
                await self.resolver.retrieve_card("http://test.com")

        # The first card path tried is /.well-known/agent-card.json.
        resolver_cls.assert_called_with(self.mock_client, "http://test.com", "/.well-known/agent-card.json")
        card_resolver.get_agent_card.assert_called_once()
        register.assert_called_once_with(card)

    async def test_get_a2a_clients(self):
        """Every host is resolved and the stored client mapping is returned."""
        hosts = ["http://agent1.com", "http://agent2.com"]

        with patch.object(self.resolver, 'retrieve_card') as retrieve:
            # Make the awaited retrieve_card call resolve to None.
            retrieve.return_value = None

            self.resolver.clients = {
                "agent1": MagicMock(spec=RemoteAgentConnection),
                "agent2": MagicMock(spec=RemoteAgentConnection),
            }

            result = await self.resolver.get_a2a_clients(hosts)

            self.assertEqual(retrieve.call_count, 2)
            retrieve.assert_any_call("http://agent1.com")
            retrieve.assert_any_call("http://agent2.com")
            self.assertEqual(result, self.resolver.clients)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class TestGetA2AClients(unittest.IsolatedAsyncioTestCase):
    """Unit tests for the module-level ``get_a2a_clients`` helper."""

    async def test_get_a2a_clients(self):
        """Identity headers are set on the httpx client and hosts reach the resolver."""
        http_client = AsyncMock()

        card_resolver = AsyncMock(spec=_AgentCardResolver)
        expected_clients = {"agent1": MagicMock(), "agent2": MagicMock()}
        card_resolver.get_a2a_clients.return_value = expected_clients

        hosts = ["http://agent1.com", "http://agent2.com"]

        with patch(
            "aixtools.a2a.google_sdk.utils._AgentCardResolver", return_value=card_resolver
        ) as resolver_cls:
            with patch(
                "aixtools.a2a.google_sdk.utils.httpx.AsyncClient", return_value=http_client
            ) as client_cls:
                result = await get_a2a_clients(("user123", "session456"), hosts)

        # The (user_id, session_id) tuple becomes request headers.
        client_cls.assert_called_once_with(
            headers={"user-id": "user123", "session-id": "session456"},
            timeout=60.0,
        )
        resolver_cls.assert_called_once_with(http_client)
        card_resolver.get_a2a_clients.assert_called_once_with(hosts)
        self.assertEqual(result, expected_clients)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class TestGetSessionIdTuple(unittest.TestCase):
    """Unit tests for ``get_session_id_tuple``: header lookup with defaults."""

    @staticmethod
    def _context(state):
        """Return a ``RequestContext`` mock whose ``call_context.state`` is *state*."""
        ctx = MagicMock(spec=RequestContext)
        ctx.call_context.state = state
        return ctx

    def test_get_session_id_tuple_with_headers(self):
        """Both ids are read from the request headers when present."""
        ctx = self._context({"headers": {"user-id": "test_user", "session-id": "test_session"}})

        self.assertEqual(get_session_id_tuple(ctx), ("test_user", "test_session"))

    def test_get_session_id_tuple_with_partial_headers(self):
        """A missing session-id header falls back to DEFAULT_SESSION_ID."""
        ctx = self._context({"headers": {"user-id": "test_user"}})

        self.assertEqual(get_session_id_tuple(ctx), ("test_user", DEFAULT_SESSION_ID))

    def test_get_session_id_tuple_no_headers(self):
        """Without a headers entry, both defaults are used."""
        ctx = self._context({})

        self.assertEqual(get_session_id_tuple(ctx), (DEFAULT_USER_ID, DEFAULT_SESSION_ID))

    def test_get_session_id_tuple_empty_headers(self):
        """An empty headers dict also yields both defaults."""
        ctx = self._context({"headers": {}})

        self.assertEqual(get_session_id_tuple(ctx), (DEFAULT_USER_ID, DEFAULT_SESSION_ID))
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
if __name__ == '__main__':
    # Allow running this test module directly (python <file>) as well as via a test runner.
    unittest.main()
|
|
File without changes
|