aixtools-0.1.10-py3-none-any.whl → aixtools-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aixtools might be problematic.

Files changed (62)
  1. aixtools/_version.py +2 -2
  2. aixtools/agents/agent.py +26 -7
  3. aixtools/agents/print_nodes.py +54 -0
  4. aixtools/agents/prompt.py +2 -2
  5. aixtools/compliance/private_data.py +1 -1
  6. aixtools/evals/discovery.py +174 -0
  7. aixtools/evals/evals.py +74 -0
  8. aixtools/evals/run_evals.py +110 -0
  9. aixtools/logging/log_objects.py +24 -23
  10. aixtools/mcp/client.py +148 -2
  11. aixtools/server/__init__.py +0 -6
  12. aixtools/server/path.py +88 -31
  13. aixtools/testing/aix_test_model.py +9 -1
  14. aixtools/tools/doctor/mcp_tool_doctor.py +79 -0
  15. aixtools/tools/doctor/tool_doctor.py +4 -0
  16. aixtools/tools/doctor/tool_recommendation.py +5 -0
  17. aixtools/utils/config.py +0 -1
  18. {aixtools-0.1.10.dist-info → aixtools-0.2.0.dist-info}/METADATA +186 -30
  19. {aixtools-0.1.10.dist-info → aixtools-0.2.0.dist-info}/RECORD +23 -55
  20. aixtools-0.2.0.dist-info/entry_points.txt +4 -0
  21. aixtools-0.2.0.dist-info/top_level.txt +1 -0
  22. aixtools/server/workspace_privacy.py +0 -65
  23. aixtools-0.1.10.dist-info/entry_points.txt +0 -2
  24. aixtools-0.1.10.dist-info/top_level.txt +0 -5
  25. docker/mcp-base/Dockerfile +0 -33
  26. docker/mcp-base/zscaler.crt +0 -28
  27. notebooks/example_faulty_mcp_server.ipynb +0 -74
  28. notebooks/example_mcp_server_stdio.ipynb +0 -76
  29. notebooks/example_raw_mcp_client.ipynb +0 -84
  30. notebooks/example_tool_doctor.ipynb +0 -65
  31. scripts/config.sh +0 -28
  32. scripts/lint.sh +0 -32
  33. scripts/log_view.sh +0 -18
  34. scripts/run_example_mcp_server.sh +0 -14
  35. scripts/run_faulty_mcp_server.sh +0 -13
  36. scripts/run_server.sh +0 -29
  37. scripts/test.sh +0 -30
  38. tests/unit/__init__.py +0 -0
  39. tests/unit/a2a/__init__.py +0 -0
  40. tests/unit/a2a/google_sdk/__init__.py +0 -0
  41. tests/unit/a2a/google_sdk/pydantic_ai_adapter/__init__.py +0 -0
  42. tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_agent_executor.py +0 -188
  43. tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_storage.py +0 -156
  44. tests/unit/a2a/google_sdk/test_card.py +0 -114
  45. tests/unit/a2a/google_sdk/test_remote_agent_connection.py +0 -413
  46. tests/unit/a2a/google_sdk/test_utils.py +0 -208
  47. tests/unit/agents/__init__.py +0 -0
  48. tests/unit/agents/test_prompt.py +0 -363
  49. tests/unit/compliance/test_private_data.py +0 -329
  50. tests/unit/google/__init__.py +0 -1
  51. tests/unit/google/test_client.py +0 -233
  52. tests/unit/mcp/__init__.py +0 -0
  53. tests/unit/mcp/test_client.py +0 -242
  54. tests/unit/server/__init__.py +0 -0
  55. tests/unit/server/test_path.py +0 -225
  56. tests/unit/server/test_utils.py +0 -362
  57. tests/unit/utils/__init__.py +0 -0
  58. tests/unit/utils/test_files.py +0 -146
  59. tests/unit/vault/__init__.py +0 -0
  60. tests/unit/vault/test_vault.py +0 -246
  61. {tests → aixtools/evals}/__init__.py +0 -0
  62. {aixtools-0.1.10.dist-info → aixtools-0.2.0.dist-info}/WHEEL +0 -0
--- a/tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_storage.py
+++ /dev/null
@@ -1,156 +0,0 @@
1
- """Tests for the Pydantic AI adapter storage module."""
2
-
3
- import unittest
4
- from unittest.mock import MagicMock
5
-
6
- from aixtools.a2a.google_sdk.pydantic_ai_adapter.storage import (
7
- PydanticAiAgentHistoryStorage,
8
- InMemoryHistoryStorage,
9
- )
10
-
11
-
12
class TestInMemoryHistoryStorage(unittest.TestCase):
    """Behavioural tests for ``InMemoryHistoryStorage``.

    Each test receives a fresh storage instance from ``setUp``. Message
    histories are lists of ``MagicMock`` placeholders: the tests only check
    that the storage hands back what it was given.
    """

    def setUp(self):
        self.storage = InMemoryHistoryStorage()

    def test_init(self):
        """A new storage starts with an empty backing dict."""
        self.assertEqual(self.storage.storage, {})

    def test_get_nonexistent_task(self):
        """Looking up an unknown task id yields ``None``."""
        self.assertIsNone(self.storage.get("nonexistent_task"))

    def test_store_and_get_messages(self):
        """Stored messages can be read back under the same task id."""
        key = "test_task_1"
        history = [MagicMock(), MagicMock()]

        self.storage.store(key, history)
        fetched = self.storage.get(key)

        self.assertIsNotNone(fetched)
        self.assertEqual(fetched, history)
        self.assertEqual(len(fetched), 2)

    def test_store_overwrites_existing(self):
        """A second ``store`` for the same task id replaces the first one."""
        key = "test_task_3"
        first = [MagicMock()]
        second = [MagicMock(), MagicMock()]

        self.storage.store(key, first)
        self.storage.store(key, second)
        fetched = self.storage.get(key)

        self.assertIsNotNone(fetched)
        self.assertEqual(fetched, second)
        self.assertEqual(len(fetched), 2)
        self.assertNotEqual(fetched, first)

    def test_multiple_tasks(self):
        """Histories stored under different task ids do not interfere."""
        histories = {
            "task_1": [MagicMock()],
            "task_2": [MagicMock(), MagicMock()],
        }
        for key, history in histories.items():
            self.storage.store(key, history)

        first = self.storage.get("task_1")
        second = self.storage.get("task_2")

        self.assertIsNotNone(first)
        self.assertIsNotNone(second)
        self.assertEqual(first, histories["task_1"])
        self.assertEqual(second, histories["task_2"])
        self.assertNotEqual(first, second)

    def test_store_empty_list(self):
        """An empty history is stored and returned, not reported as missing."""
        key = "empty_task"
        nothing = []

        self.storage.store(key, nothing)
        fetched = self.storage.get(key)

        self.assertIsNotNone(fetched)
        self.assertEqual(fetched, nothing)
        self.assertEqual(len(fetched), 0)

    def test_get_after_multiple_stores(self):
        """After several stores, only the latest history is returned."""
        key = "update_task"
        batches = [[MagicMock()], [MagicMock()], [MagicMock(), MagicMock()]]
        for batch in batches:
            self.storage.store(key, batch)

        fetched = self.storage.get(key)

        self.assertIsNotNone(fetched)
        self.assertEqual(fetched, batches[-1])

    def test_storage_isolation(self):
        """Two storage instances keep independent state for the same key."""
        key = "isolation_test"
        left, right = InMemoryHistoryStorage(), InMemoryHistoryStorage()
        left_history, right_history = [MagicMock()], [MagicMock()]

        left.store(key, left_history)
        right.store(key, right_history)

        from_left = left.get(key)
        from_right = right.get(key)

        self.assertIsNotNone(from_left)
        self.assertIsNotNone(from_right)
        self.assertEqual(from_left, left_history)
        self.assertEqual(from_right, right_history)
        self.assertNotEqual(from_left, from_right)
135
-
136
-
137
class TestPydanticAiAgentHistoryStorageInterface(unittest.TestCase):
    """Checks on the abstract ``PydanticAiAgentHistoryStorage`` contract."""

    def test_cannot_instantiate_abstract_class(self):
        """Instantiating the abstract base directly raises ``TypeError``."""
        with self.assertRaises(TypeError):
            PydanticAiAgentHistoryStorage()

    def test_inmemory_implements_interface(self):
        """``InMemoryHistoryStorage`` subclasses the base and exposes ``get``/``store``."""
        storage = InMemoryHistoryStorage()

        self.assertIsInstance(storage, PydanticAiAgentHistoryStorage)

        required = ("get", "store")
        # First confirm presence, then callability, for each required method.
        for method_name in required:
            self.assertTrue(hasattr(storage, method_name))
        for method_name in required:
            self.assertTrue(callable(getattr(storage, method_name)))
--- a/tests/unit/a2a/google_sdk/test_card.py
+++ /dev/null
@@ -1,114 +0,0 @@
1
- """Tests for the A2A card module."""
2
-
3
- import unittest
4
- from unittest.mock import AsyncMock, MagicMock, patch
5
-
6
- import httpx
7
- from a2a.client import A2ACardResolver
8
- from a2a.types import AgentCard
9
-
10
- from aixtools.a2a.google_sdk.card import get_agent_card
11
-
12
-
13
class TestCard(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``get_agent_card`` in the A2A card module.

    Every test patches ``A2ACardResolver`` and passes an ``AsyncMock`` in
    place of the httpx client, so neither a real resolver nor a real client
    is involved.
    """

    def setUp(self):
        self.test_agent_host = "http://localhost:9999"

    @staticmethod
    def _wire_resolver(resolver_cls):
        """Make the patched resolver class return an async resolver mock.

        Returns ``(client, resolver)`` so the caller can configure the
        resolver's ``get_agent_card`` behaviour.
        """
        client = AsyncMock(spec=httpx.AsyncClient)
        resolver = AsyncMock(spec=A2ACardResolver)
        resolver_cls.return_value = resolver
        return client, resolver

    @staticmethod
    def _fake_card():
        """Build an ``AgentCard`` stand-in whose JSON dump is a fixed string."""
        card = MagicMock(spec=AgentCard)
        card.model_dump_json.return_value = '{"test": "data"}'
        return card

    @patch("aixtools.a2a.google_sdk.card.A2ACardResolver")
    async def test_get_agent_card_success(self, mock_resolver_class):
        """The resolved card is returned, with its ``url`` set to the agent host."""
        client, resolver = self._wire_resolver(mock_resolver_class)
        card = self._fake_card()
        resolver.get_agent_card.return_value = card

        result = await get_agent_card(client, self.test_agent_host)

        self.assertEqual(result, card)
        mock_resolver_class.assert_called_once_with(
            httpx_client=client,
            base_url=self.test_agent_host,
        )
        resolver.get_agent_card.assert_called_once()
        self.assertEqual(result.url, self.test_agent_host)

    @patch("aixtools.a2a.google_sdk.card.A2ACardResolver")
    async def test_get_agent_card_failure(self, mock_resolver_class):
        """A resolver error surfaces as ``RuntimeError`` with a descriptive message."""
        client, resolver = self._wire_resolver(mock_resolver_class)
        resolver.get_agent_card.side_effect = Exception("Failed to fetch card")

        with self.assertRaises(RuntimeError) as raised:
            await get_agent_card(client, self.test_agent_host)

        self.assertIn("Failed to fetch the public agent card", str(raised.exception))

    @patch("aixtools.a2a.google_sdk.card.logger")
    @patch("aixtools.a2a.google_sdk.card.A2ACardResolver")
    async def test_get_agent_card_logging(self, mock_resolver_class, mock_logger):
        """A successful fetch emits exactly two info-level log calls."""
        client, resolver = self._wire_resolver(mock_resolver_class)
        resolver.get_agent_card.return_value = self._fake_card()

        await get_agent_card(client, self.test_agent_host)

        mock_logger.info.assert_called()
        self.assertEqual(mock_logger.info.call_count, 2)

    @patch("aixtools.a2a.google_sdk.card.logger")
    @patch("aixtools.a2a.google_sdk.card.A2ACardResolver")
    async def test_get_agent_card_error_logging(self, mock_resolver_class, mock_logger):
        """A failed fetch logs one error call with ``exc_info`` enabled."""
        client, resolver = self._wire_resolver(mock_resolver_class)
        resolver.get_agent_card.side_effect = Exception("Test error")

        with self.assertRaises(RuntimeError):
            await get_agent_card(client, self.test_agent_host)

        mock_logger.error.assert_called_once()
        logged_args, logged_kwargs = mock_logger.error.call_args
        self.assertIn("Critical error fetching public agent card", logged_args[0])
        self.assertTrue(logged_kwargs.get("exc_info"))


if __name__ == "__main__":
    # Allow running this module directly as a script.
    unittest.main()
--- a/tests/unit/a2a/google_sdk/test_remote_agent_connection.py
+++ /dev/null
@@ -1,413 +0,0 @@
1
- """Tests for the remote agent connection module."""
2
-
3
- import unittest
4
- from unittest.mock import AsyncMock, MagicMock, patch, call
5
-
6
- from a2a.client import Client
7
- from a2a.types import AgentCard, Message, Task, TaskState, TaskStatus, TaskQueryParams
8
-
9
- from aixtools.a2a.google_sdk.remote_agent_connection import (
10
- RemoteAgentConnection,
11
- is_in_terminal_state,
12
- is_in_terminal_or_interrupted_state,
13
- )
14
-
15
-
16
class TestRemoteAgentConnection(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``RemoteAgentConnection``.

    The A2A ``Client`` is an ``AsyncMock``. Its ``send_message`` is given a
    hand-built async generator to simulate the event stream. Polling tests
    patch ``asyncio.sleep`` with an ``AsyncMock``, so no real waiting occurs.
    """

    def setUp(self):
        self.mock_card = MagicMock(spec=AgentCard)
        self.mock_client = AsyncMock(spec=Client)
        self.connection = RemoteAgentConnection(self.mock_card, self.mock_client)

    @staticmethod
    def _task(state, task_id=None):
        """Return a ``Task`` mock whose ``status.state`` is *state*.

        ``id`` is assigned only when *task_id* is given; otherwise the mock's
        ``id`` attribute is left untouched.
        """
        task = MagicMock(spec=Task)
        if task_id is not None:
            task.id = task_id
        task.status = MagicMock(spec=TaskStatus)
        task.status.state = state
        return task

    @staticmethod
    def _stream(*events):
        """Return an async generator that yields *events* in order."""
        async def generator():
            for event in events:
                yield event

        return generator()

    def test_get_agent_card(self):
        """``get_agent_card`` returns the card passed to the constructor."""
        self.assertEqual(self.connection.get_agent_card(), self.mock_card)

    async def test_send_message_returns_message(self):
        """A ``Message`` event is returned even if a task event follows it."""
        reply = MagicMock(spec=Message)
        trailing_task = MagicMock(spec=Task)
        self.mock_client.send_message.return_value = self._stream(reply, (trailing_task,))

        outgoing = MagicMock(spec=Message)
        result = await self.connection.send_message(outgoing)

        self.assertEqual(result, reply)
        self.mock_client.send_message.assert_called_once_with(outgoing)

    async def test_send_message_returns_terminal_task(self):
        """A task in a terminal state (completed) is returned."""
        done = self._task(TaskState.completed)
        self.mock_client.send_message.return_value = self._stream((done,))

        result = await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(result, done)

    async def test_send_message_returns_interrupted_task(self):
        """A task in an interrupted state (input_required) is returned."""
        waiting = self._task(TaskState.input_required)
        self.mock_client.send_message.return_value = self._stream((waiting,))

        result = await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(result, waiting)

    async def test_send_message_returns_last_task(self):
        """When no streamed task is terminal or interrupted, the last one is returned."""
        first = self._task(TaskState.working)
        last = self._task(TaskState.working)
        self.mock_client.send_message.return_value = self._stream((first,), (last,))

        result = await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(result, last)

    async def test_send_message_handles_exception(self):
        """The exception raised by the client reaches the caller as the same object."""
        boom = Exception("Test error")
        self.mock_client.send_message.side_effect = boom

        with self.assertRaises(Exception) as raised:
            await self.connection.send_message(MagicMock(spec=Message))

        self.assertEqual(raised.exception, boom)

    async def test_send_message_returns_none_when_no_events(self):
        """An empty event stream yields ``None``."""
        self.mock_client.send_message.return_value = self._stream()

        result = await self.connection.send_message(MagicMock(spec=Message))

        self.assertIsNone(result)

    async def test_send_message_with_polling_returns_message(self):
        """A ``Message`` from ``send_message`` is passed straight through."""
        reply = MagicMock(spec=Message)
        self.connection.send_message = AsyncMock(return_value=reply)

        outgoing = MagicMock(spec=Message)
        result = await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(result, reply)
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_returns_terminal_task(self):
        """An already-terminal task from ``send_message`` is returned as-is."""
        done = self._task(TaskState.completed)
        self.connection.send_message = AsyncMock(return_value=done)

        outgoing = MagicMock(spec=Message)
        result = await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(result, done)
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_polls_until_terminal(self):
        """``get_task`` is polled, sleeping each round, until a terminal state."""
        started = self._task(TaskState.working, task_id="task123")
        still_running = self._task(TaskState.working)
        finished = self._task(TaskState.completed)

        self.connection.send_message = AsyncMock(return_value=started)
        # First poll still running, second poll finished.
        self.mock_client.get_task.side_effect = [still_running, finished]

        outgoing = MagicMock(spec=Message)
        with patch('asyncio.sleep', new_callable=AsyncMock) as fake_sleep:
            result = await self.connection.send_message_with_polling(
                outgoing, sleep_time=0.1, max_iter=10
            )

        self.assertEqual(result, finished)
        self.connection.send_message.assert_called_once_with(outgoing)

        poll = call(TaskQueryParams(id="task123"))
        self.mock_client.get_task.assert_has_calls([poll, poll])

        self.assertEqual(fake_sleep.call_count, 2)
        fake_sleep.assert_called_with(0.1)

    async def test_send_message_with_polling_polls_until_interrupted(self):
        """Polling also stops once the task reaches input_required."""
        started = self._task(TaskState.working, task_id="task456")
        waiting = self._task(TaskState.input_required)

        self.connection.send_message = AsyncMock(return_value=started)
        self.mock_client.get_task.return_value = waiting

        outgoing = MagicMock(spec=Message)
        with patch('asyncio.sleep', new_callable=AsyncMock) as fake_sleep:
            result = await self.connection.send_message_with_polling(
                outgoing, sleep_time=0.05, max_iter=5
            )

        self.assertEqual(result, waiting)
        self.mock_client.get_task.assert_called_once_with(TaskQueryParams(id="task456"))
        fake_sleep.assert_called_once_with(0.05)

    async def test_send_message_with_polling_timeout_exception(self):
        """Exhausting ``max_iter`` polls raises an exception naming the timeout."""
        started = self._task(TaskState.working, task_id="task789")
        running = self._task(TaskState.working)

        self.connection.send_message = AsyncMock(return_value=started)
        self.mock_client.get_task.return_value = running

        max_iter, sleep_time = 3, 0.1
        outgoing = MagicMock(spec=Message)
        with patch('asyncio.sleep', new_callable=AsyncMock), \
                self.assertRaises(Exception) as raised:
            await self.connection.send_message_with_polling(
                outgoing, sleep_time=sleep_time, max_iter=max_iter
            )

        # The reported timeout is max_iter * sleep_time.
        expected_timeout = max_iter * sleep_time
        self.assertIn(f"Task did not complete in {expected_timeout} seconds", str(raised.exception))

        self.assertEqual(self.mock_client.get_task.call_count, max_iter)

    async def test_send_message_with_polling_no_task_returned(self):
        """``ValueError`` is raised when ``send_message`` returns ``None``."""
        self.connection.send_message = AsyncMock(return_value=None)

        outgoing = MagicMock(spec=Message)
        with self.assertRaises(ValueError) as raised:
            await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(str(raised.exception), "No task or message returned from send_message")
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_custom_parameters(self):
        """A caller-supplied ``sleep_time`` is the value passed to ``asyncio.sleep``."""
        started = self._task(TaskState.working, task_id="custom_task")
        finished = self._task(TaskState.completed)

        self.connection.send_message = AsyncMock(return_value=started)
        self.mock_client.get_task.return_value = finished

        outgoing = MagicMock(spec=Message)
        with patch('asyncio.sleep', new_callable=AsyncMock) as fake_sleep:
            result = await self.connection.send_message_with_polling(
                outgoing, sleep_time=0.5, max_iter=100
            )

        self.assertEqual(result, finished)
        fake_sleep.assert_called_once_with(0.5)

    async def test_send_message_with_polling_default_parameters(self):
        """Without overrides, each polling round sleeps for 0.2 seconds."""
        started = self._task(TaskState.working, task_id="default_task")
        failed = self._task(TaskState.failed)  # failed is a terminal state

        self.connection.send_message = AsyncMock(return_value=started)
        self.mock_client.get_task.return_value = failed

        outgoing = MagicMock(spec=Message)
        with patch('asyncio.sleep', new_callable=AsyncMock) as fake_sleep:
            result = await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(result, failed)
        fake_sleep.assert_called_once_with(0.2)
334
-
335
-
336
class TestHelperFunctions(unittest.TestCase):
    """Tests for the task-state predicates in the remote agent connection module."""

    @staticmethod
    def _task_in(state):
        """Return a ``Task`` mock whose ``status.state`` is *state*."""
        task = MagicMock(spec=Task)
        task.status = MagicMock(spec=TaskStatus)
        task.status.state = state
        return task

    def test_is_in_terminal_state_completed(self):
        """``completed`` counts as terminal."""
        self.assertTrue(is_in_terminal_state(self._task_in(TaskState.completed)))

    def test_is_in_terminal_state_canceled(self):
        """``canceled`` counts as terminal."""
        self.assertTrue(is_in_terminal_state(self._task_in(TaskState.canceled)))

    def test_is_in_terminal_state_failed(self):
        """``failed`` counts as terminal."""
        self.assertTrue(is_in_terminal_state(self._task_in(TaskState.failed)))

    def test_is_in_terminal_state_running(self):
        """``working`` is not terminal."""
        self.assertFalse(is_in_terminal_state(self._task_in(TaskState.working)))

    def test_is_in_terminal_or_interrupted_state_input_required(self):
        """``input_required`` counts as interrupted."""
        self.assertTrue(
            is_in_terminal_or_interrupted_state(self._task_in(TaskState.input_required))
        )

    def test_is_in_terminal_or_interrupted_state_unknown(self):
        """``unknown`` is accepted by the terminal-or-interrupted check."""
        self.assertTrue(
            is_in_terminal_or_interrupted_state(self._task_in(TaskState.unknown))
        )

    def test_is_in_terminal_or_interrupted_state_completed(self):
        """Terminal states such as ``completed`` also satisfy the combined check."""
        self.assertTrue(
            is_in_terminal_or_interrupted_state(self._task_in(TaskState.completed))
        )

    def test_is_in_terminal_or_interrupted_state_running(self):
        """``working`` is neither terminal nor interrupted."""
        self.assertFalse(
            is_in_terminal_or_interrupted_state(self._task_in(TaskState.working))
        )


if __name__ == "__main__":
    # Allow running this module directly as a script.
    unittest.main()