aixtools 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aixtools might be problematic. Click here for more details.
- aixtools/_version.py +2 -2
- aixtools/agents/agent.py +26 -7
- aixtools/agents/print_nodes.py +54 -0
- aixtools/agents/prompt.py +2 -2
- aixtools/compliance/private_data.py +1 -1
- aixtools/evals/discovery.py +174 -0
- aixtools/evals/evals.py +74 -0
- aixtools/evals/run_evals.py +110 -0
- aixtools/logging/log_objects.py +24 -23
- aixtools/mcp/client.py +148 -2
- aixtools/server/__init__.py +0 -6
- aixtools/server/path.py +88 -31
- aixtools/testing/aix_test_model.py +9 -1
- aixtools/tools/doctor/mcp_tool_doctor.py +79 -0
- aixtools/tools/doctor/tool_doctor.py +4 -0
- aixtools/tools/doctor/tool_recommendation.py +5 -0
- aixtools/utils/config.py +0 -1
- {aixtools-0.1.10.dist-info → aixtools-0.2.0.dist-info}/METADATA +186 -30
- {aixtools-0.1.10.dist-info → aixtools-0.2.0.dist-info}/RECORD +23 -55
- aixtools-0.2.0.dist-info/entry_points.txt +4 -0
- aixtools-0.2.0.dist-info/top_level.txt +1 -0
- aixtools/server/workspace_privacy.py +0 -65
- aixtools-0.1.10.dist-info/entry_points.txt +0 -2
- aixtools-0.1.10.dist-info/top_level.txt +0 -5
- docker/mcp-base/Dockerfile +0 -33
- docker/mcp-base/zscaler.crt +0 -28
- notebooks/example_faulty_mcp_server.ipynb +0 -74
- notebooks/example_mcp_server_stdio.ipynb +0 -76
- notebooks/example_raw_mcp_client.ipynb +0 -84
- notebooks/example_tool_doctor.ipynb +0 -65
- scripts/config.sh +0 -28
- scripts/lint.sh +0 -32
- scripts/log_view.sh +0 -18
- scripts/run_example_mcp_server.sh +0 -14
- scripts/run_faulty_mcp_server.sh +0 -13
- scripts/run_server.sh +0 -29
- scripts/test.sh +0 -30
- tests/unit/__init__.py +0 -0
- tests/unit/a2a/__init__.py +0 -0
- tests/unit/a2a/google_sdk/__init__.py +0 -0
- tests/unit/a2a/google_sdk/pydantic_ai_adapter/__init__.py +0 -0
- tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_agent_executor.py +0 -188
- tests/unit/a2a/google_sdk/pydantic_ai_adapter/test_storage.py +0 -156
- tests/unit/a2a/google_sdk/test_card.py +0 -114
- tests/unit/a2a/google_sdk/test_remote_agent_connection.py +0 -413
- tests/unit/a2a/google_sdk/test_utils.py +0 -208
- tests/unit/agents/__init__.py +0 -0
- tests/unit/agents/test_prompt.py +0 -363
- tests/unit/compliance/test_private_data.py +0 -329
- tests/unit/google/__init__.py +0 -1
- tests/unit/google/test_client.py +0 -233
- tests/unit/mcp/__init__.py +0 -0
- tests/unit/mcp/test_client.py +0 -242
- tests/unit/server/__init__.py +0 -0
- tests/unit/server/test_path.py +0 -225
- tests/unit/server/test_utils.py +0 -362
- tests/unit/utils/__init__.py +0 -0
- tests/unit/utils/test_files.py +0 -146
- tests/unit/vault/__init__.py +0 -0
- tests/unit/vault/test_vault.py +0 -246
- {tests → aixtools/evals}/__init__.py +0 -0
- {aixtools-0.1.10.dist-info → aixtools-0.2.0.dist-info}/WHEEL +0 -0
|
@@ -1,156 +0,0 @@
|
|
|
1
|
-
"""Tests for the Pydantic AI adapter storage module."""
|
|
2
|
-
|
|
3
|
-
import unittest
|
|
4
|
-
from unittest.mock import MagicMock
|
|
5
|
-
|
|
6
|
-
from aixtools.a2a.google_sdk.pydantic_ai_adapter.storage import (
|
|
7
|
-
PydanticAiAgentHistoryStorage,
|
|
8
|
-
InMemoryHistoryStorage,
|
|
9
|
-
)
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class TestInMemoryHistoryStorage(unittest.TestCase):
    """Unit tests covering the store/get behaviour of ``InMemoryHistoryStorage``."""

    def setUp(self):
        # Every test gets its own, initially empty, storage instance.
        self.storage = InMemoryHistoryStorage()

    def test_init(self):
        """A freshly constructed storage has an empty backing ``storage`` dict."""
        self.assertEqual(self.storage.storage, {})

    def test_get_nonexistent_task(self):
        """Looking up a task id that was never stored yields ``None``."""
        self.assertIsNone(self.storage.get("nonexistent_task"))

    def test_store_and_get_messages(self):
        """Messages stored under a task id come back unchanged from ``get``."""
        task_id = "test_task_1"
        # Plain MagicMocks are enough to stand in for stored message objects.
        stored = [MagicMock() for _ in range(2)]

        self.storage.store(task_id, stored)
        fetched = self.storage.get(task_id)

        self.assertIsNotNone(fetched)
        self.assertEqual(fetched, stored)
        self.assertEqual(len(fetched), 2)

    def test_store_overwrites_existing(self):
        """A second ``store`` for the same task id replaces the earlier messages."""
        task_id = "test_task_3"
        first_batch = [MagicMock()]
        second_batch = [MagicMock(), MagicMock()]

        self.storage.store(task_id, first_batch)
        self.storage.store(task_id, second_batch)
        fetched = self.storage.get(task_id)

        self.assertIsNotNone(fetched)
        self.assertEqual(fetched, second_batch)
        self.assertEqual(len(fetched), 2)
        self.assertNotEqual(fetched, first_batch)

    def test_multiple_tasks(self):
        """Histories for different task ids are kept separately."""
        expected = {
            "task_1": [MagicMock()],
            "task_2": [MagicMock(), MagicMock()],
        }
        for task_id, batch in expected.items():
            self.storage.store(task_id, batch)

        fetched = {task_id: self.storage.get(task_id) for task_id in expected}

        for result in fetched.values():
            self.assertIsNotNone(result)
        for task_id, batch in expected.items():
            self.assertEqual(fetched[task_id], batch)
        self.assertNotEqual(fetched["task_1"], fetched["task_2"])

    def test_store_empty_list(self):
        """An empty message list is stored and returned as an empty list, not ``None``."""
        task_id = "empty_task"

        self.storage.store(task_id, [])
        fetched = self.storage.get(task_id)

        self.assertIsNotNone(fetched)
        self.assertEqual(fetched, [])
        self.assertEqual(len(fetched), 0)

    def test_get_after_multiple_stores(self):
        """After several stores, ``get`` returns the most recently stored list."""
        task_id = "update_task"
        history = [[MagicMock()], [MagicMock()], [MagicMock(), MagicMock()]]

        for batch in history:
            self.storage.store(task_id, batch)
        fetched = self.storage.get(task_id)

        self.assertIsNotNone(fetched)
        self.assertEqual(fetched, history[-1])

    def test_storage_isolation(self):
        """Two separate storage instances do not share state for the same task id."""
        task_id = "isolation_test"
        left, right = InMemoryHistoryStorage(), InMemoryHistoryStorage()
        left_batch, right_batch = [MagicMock()], [MagicMock()]

        left.store(task_id, left_batch)
        right.store(task_id, right_batch)
        left_result = left.get(task_id)
        right_result = right.get(task_id)

        self.assertIsNotNone(left_result)
        self.assertIsNotNone(right_result)
        self.assertEqual(left_result, left_batch)
        self.assertEqual(right_result, right_batch)
        self.assertNotEqual(left_result, right_result)
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
class TestPydanticAiAgentHistoryStorageInterface(unittest.TestCase):
    """Checks on the abstract ``PydanticAiAgentHistoryStorage`` contract."""

    def test_cannot_instantiate_abstract_class(self):
        """Calling the abstract base class directly raises ``TypeError``."""
        self.assertRaises(TypeError, PydanticAiAgentHistoryStorage)

    def test_inmemory_implements_interface(self):
        """``InMemoryHistoryStorage`` subclasses the interface and exposes callable get/store."""
        storage = InMemoryHistoryStorage()

        self.assertIsInstance(storage, PydanticAiAgentHistoryStorage)

        required = ("get", "store")
        # First confirm the attributes exist, then that each one is callable.
        for method_name in required:
            self.assertTrue(hasattr(storage, method_name))
        for method_name in required:
            self.assertTrue(callable(getattr(storage, method_name)))
|
|
@@ -1,114 +0,0 @@
|
|
|
1
|
-
"""Tests for the A2A card module."""
|
|
2
|
-
|
|
3
|
-
import unittest
|
|
4
|
-
from unittest.mock import AsyncMock, MagicMock, patch
|
|
5
|
-
|
|
6
|
-
import httpx
|
|
7
|
-
from a2a.client import A2ACardResolver
|
|
8
|
-
from a2a.types import AgentCard
|
|
9
|
-
|
|
10
|
-
from aixtools.a2a.google_sdk.card import get_agent_card
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class TestCard(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``get_agent_card`` in the A2A card module."""

    def setUp(self):
        self.test_agent_host = "http://localhost:9999"

    @staticmethod
    def _install_resolver(mock_resolver_class):
        """Make the patched resolver class return a spec'd async resolver mock and return it."""
        resolver = AsyncMock(spec=A2ACardResolver)
        mock_resolver_class.return_value = resolver
        return resolver

    @staticmethod
    def _fake_card():
        """Build an ``AgentCard`` stand-in whose ``model_dump_json`` returns a fixed string."""
        card = MagicMock(spec=AgentCard)
        card.model_dump_json.return_value = '{"test": "data"}'
        return card

    @patch("aixtools.a2a.google_sdk.card.A2ACardResolver")
    async def test_get_agent_card_success(self, mock_resolver_class):
        """A successful fetch returns the resolver's card, with ``url`` set to the host."""
        http_client = AsyncMock(spec=httpx.AsyncClient)
        resolver = self._install_resolver(mock_resolver_class)
        card = self._fake_card()
        resolver.get_agent_card.return_value = card

        result = await get_agent_card(http_client, self.test_agent_host)

        self.assertEqual(result, card)
        mock_resolver_class.assert_called_once_with(
            httpx_client=http_client,
            base_url=self.test_agent_host,
        )
        resolver.get_agent_card.assert_called_once()
        self.assertEqual(result.url, self.test_agent_host)

    @patch("aixtools.a2a.google_sdk.card.A2ACardResolver")
    async def test_get_agent_card_failure(self, mock_resolver_class):
        """A resolver error is surfaced to the caller as a ``RuntimeError``."""
        http_client = AsyncMock(spec=httpx.AsyncClient)
        resolver = self._install_resolver(mock_resolver_class)
        resolver.get_agent_card.side_effect = Exception("Failed to fetch card")

        with self.assertRaises(RuntimeError) as ctx:
            await get_agent_card(http_client, self.test_agent_host)

        self.assertIn("Failed to fetch the public agent card", str(ctx.exception))

    @patch("aixtools.a2a.google_sdk.card.logger")
    @patch("aixtools.a2a.google_sdk.card.A2ACardResolver")
    async def test_get_agent_card_logging(self, mock_resolver_class, mock_logger):
        """A successful fetch emits exactly two info-level log calls."""
        http_client = AsyncMock(spec=httpx.AsyncClient)
        resolver = self._install_resolver(mock_resolver_class)
        resolver.get_agent_card.return_value = self._fake_card()

        await get_agent_card(http_client, self.test_agent_host)

        mock_logger.info.assert_called()
        self.assertEqual(mock_logger.info.call_count, 2)

    @patch("aixtools.a2a.google_sdk.card.logger")
    @patch("aixtools.a2a.google_sdk.card.A2ACardResolver")
    async def test_get_agent_card_error_logging(self, mock_resolver_class, mock_logger):
        """A failed fetch logs one error call carrying ``exc_info``."""
        http_client = AsyncMock(spec=httpx.AsyncClient)
        resolver = self._install_resolver(mock_resolver_class)
        resolver.get_agent_card.side_effect = Exception("Test error")

        with self.assertRaises(RuntimeError):
            await get_agent_card(http_client, self.test_agent_host)

        mock_logger.error.assert_called_once()
        positional, keyword = mock_logger.error.call_args
        self.assertIn("Critical error fetching public agent card", positional[0])
        self.assertTrue(keyword.get('exc_info'))
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|
|
@@ -1,413 +0,0 @@
|
|
|
1
|
-
"""Tests for the remote agent connection module."""
|
|
2
|
-
|
|
3
|
-
import unittest
|
|
4
|
-
from unittest.mock import AsyncMock, MagicMock, patch, call
|
|
5
|
-
|
|
6
|
-
from a2a.client import Client
|
|
7
|
-
from a2a.types import AgentCard, Message, Task, TaskState, TaskStatus, TaskQueryParams
|
|
8
|
-
|
|
9
|
-
from aixtools.a2a.google_sdk.remote_agent_connection import (
|
|
10
|
-
RemoteAgentConnection,
|
|
11
|
-
is_in_terminal_state,
|
|
12
|
-
is_in_terminal_or_interrupted_state,
|
|
13
|
-
)
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class TestRemoteAgentConnection(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``RemoteAgentConnection`` messaging and polling."""

    def setUp(self):
        self.mock_card = MagicMock(spec=AgentCard)
        self.mock_client = AsyncMock(spec=Client)
        self.connection = RemoteAgentConnection(self.mock_card, self.mock_client)

    @staticmethod
    def _task(state, task_id=None):
        """Return a ``Task`` mock whose ``status.state`` is *state* (and ``id`` when given)."""
        task = MagicMock(spec=Task)
        if task_id is not None:
            task.id = task_id
        task.status = MagicMock(spec=TaskStatus)
        task.status.state = state
        return task

    @staticmethod
    async def _stream(*events):
        """Async generator yielding *events* in order, used as ``client.send_message`` output."""
        for event in events:
            yield event

    def test_get_agent_card(self):
        """``get_agent_card`` hands back the card the connection was built with."""
        self.assertEqual(self.connection.get_agent_card(), self.mock_card)

    async def test_send_message_returns_message(self):
        """When the stream yields a ``Message``, that message is returned."""
        reply = MagicMock(spec=Message)
        trailing_task = MagicMock(spec=Task)
        self.mock_client.send_message.return_value = self._stream(reply, (trailing_task,))

        outgoing = MagicMock(spec=Message)
        result = await self.connection.send_message(outgoing)

        self.assertEqual(result, reply)
        self.mock_client.send_message.assert_called_once_with(outgoing)

    async def test_send_message_returns_terminal_task(self):
        """A task yielded in a terminal (completed) state is returned."""
        done = self._task(TaskState.completed)
        self.mock_client.send_message.return_value = self._stream((done,))

        outgoing = MagicMock(spec=Message)
        result = await self.connection.send_message(outgoing)

        self.assertEqual(result, done)

    async def test_send_message_returns_interrupted_task(self):
        """A task yielded in an interrupted (input_required) state is returned."""
        paused = self._task(TaskState.input_required)
        self.mock_client.send_message.return_value = self._stream((paused,))

        outgoing = MagicMock(spec=Message)
        result = await self.connection.send_message(outgoing)

        self.assertEqual(result, paused)

    async def test_send_message_returns_last_task(self):
        """With only working tasks in the stream, the last one seen is returned."""
        first = self._task(TaskState.working)
        second = self._task(TaskState.working)
        self.mock_client.send_message.return_value = self._stream((first,), (second,))

        outgoing = MagicMock(spec=Message)
        result = await self.connection.send_message(outgoing)

        self.assertEqual(result, second)

    async def test_send_message_handles_exception(self):
        """An exception from the client propagates unchanged to the caller."""
        boom = Exception("Test error")
        self.mock_client.send_message.side_effect = boom

        outgoing = MagicMock(spec=Message)

        with self.assertRaises(Exception) as ctx:
            await self.connection.send_message(outgoing)

        self.assertEqual(ctx.exception, boom)

    async def test_send_message_returns_none_when_no_events(self):
        """An empty event stream produces ``None``."""
        self.mock_client.send_message.return_value = self._stream()

        outgoing = MagicMock(spec=Message)
        result = await self.connection.send_message(outgoing)

        self.assertIsNone(result)

    async def test_send_message_with_polling_returns_message(self):
        """If ``send_message`` returns a ``Message``, polling returns it directly."""
        reply = MagicMock(spec=Message)
        self.connection.send_message = AsyncMock(return_value=reply)

        outgoing = MagicMock(spec=Message)
        result = await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(result, reply)
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_returns_terminal_task(self):
        """If ``send_message`` returns a terminal task, polling returns it directly."""
        done = self._task(TaskState.completed)
        self.connection.send_message = AsyncMock(return_value=done)

        outgoing = MagicMock(spec=Message)
        result = await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(result, done)
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_polls_until_terminal(self):
        """A working task is re-fetched via ``get_task`` until it completes."""
        initial = self._task(TaskState.working, task_id="task123")
        intermediate = self._task(TaskState.working)
        final = self._task(TaskState.completed)

        self.connection.send_message = AsyncMock(return_value=initial)
        # The first poll is still working; the second reports completion.
        self.mock_client.get_task.side_effect = [intermediate, final]

        outgoing = MagicMock(spec=Message)

        with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep:
            result = await self.connection.send_message_with_polling(
                outgoing, sleep_time=0.1, max_iter=10
            )

        self.assertEqual(result, final)
        self.connection.send_message.assert_called_once_with(outgoing)
        self.mock_client.get_task.assert_has_calls(
            [call(TaskQueryParams(id="task123"))] * 2
        )
        # One sleep per polling round.
        self.assertEqual(mock_sleep.call_count, 2)
        mock_sleep.assert_called_with(0.1)

    async def test_send_message_with_polling_polls_until_interrupted(self):
        """Polling stops once the task reaches an interrupted (input_required) state."""
        initial = self._task(TaskState.working, task_id="task456")
        final = self._task(TaskState.input_required)

        self.connection.send_message = AsyncMock(return_value=initial)
        self.mock_client.get_task.return_value = final

        outgoing = MagicMock(spec=Message)

        with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep:
            result = await self.connection.send_message_with_polling(
                outgoing, sleep_time=0.05, max_iter=5
            )

        self.assertEqual(result, final)
        self.mock_client.get_task.assert_called_once_with(TaskQueryParams(id="task456"))
        mock_sleep.assert_called_once_with(0.05)

    async def test_send_message_with_polling_timeout_exception(self):
        """Polling gives up with an error after ``max_iter`` unfinished rounds."""
        sleep_time, max_iter = 0.1, 3
        initial = self._task(TaskState.working, task_id="task789")
        still_working = self._task(TaskState.working)

        self.connection.send_message = AsyncMock(return_value=initial)
        self.mock_client.get_task.return_value = still_working

        outgoing = MagicMock(spec=Message)

        with patch('asyncio.sleep', new_callable=AsyncMock), \
                self.assertRaises(Exception) as ctx:
            await self.connection.send_message_with_polling(
                outgoing, sleep_time=sleep_time, max_iter=max_iter
            )

        expected_timeout = max_iter * sleep_time
        self.assertIn(f"Task did not complete in {expected_timeout} seconds", str(ctx.exception))
        self.assertEqual(self.mock_client.get_task.call_count, max_iter)

    async def test_send_message_with_polling_no_task_returned(self):
        """A ``None`` result from ``send_message`` raises ``ValueError``."""
        self.connection.send_message = AsyncMock(return_value=None)

        outgoing = MagicMock(spec=Message)

        with self.assertRaises(ValueError) as ctx:
            await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(str(ctx.exception), "No task or message returned from send_message")
        self.connection.send_message.assert_called_once_with(outgoing)

    async def test_send_message_with_polling_custom_parameters(self):
        """Custom ``sleep_time``/``max_iter`` values are honoured."""
        initial = self._task(TaskState.working, task_id="custom_task")
        final = self._task(TaskState.completed)

        self.connection.send_message = AsyncMock(return_value=initial)
        self.mock_client.get_task.return_value = final

        outgoing = MagicMock(spec=Message)

        with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep:
            result = await self.connection.send_message_with_polling(
                outgoing, sleep_time=0.5, max_iter=100
            )

        self.assertEqual(result, final)
        mock_sleep.assert_called_once_with(0.5)

    async def test_send_message_with_polling_default_parameters(self):
        """Without explicit arguments, polling sleeps 0.2 seconds between rounds."""
        initial = self._task(TaskState.working, task_id="default_task")
        final = self._task(TaskState.failed)

        self.connection.send_message = AsyncMock(return_value=initial)
        self.mock_client.get_task.return_value = final

        outgoing = MagicMock(spec=Message)

        with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep:
            result = await self.connection.send_message_with_polling(outgoing)

        self.assertEqual(result, final)
        mock_sleep.assert_called_once_with(0.2)
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
class TestHelperFunctions(unittest.TestCase):
    """Unit tests for the task-state predicate helpers."""

    @staticmethod
    def _task_in(state):
        """Return a ``Task`` mock whose ``status.state`` is *state*."""
        task = MagicMock(spec=Task)
        task.status = MagicMock(spec=TaskStatus)
        task.status.state = state
        return task

    def test_is_in_terminal_state_completed(self):
        """``completed`` is reported as terminal."""
        self.assertTrue(is_in_terminal_state(self._task_in(TaskState.completed)))

    def test_is_in_terminal_state_canceled(self):
        """``canceled`` is reported as terminal."""
        self.assertTrue(is_in_terminal_state(self._task_in(TaskState.canceled)))

    def test_is_in_terminal_state_failed(self):
        """``failed`` is reported as terminal."""
        self.assertTrue(is_in_terminal_state(self._task_in(TaskState.failed)))

    def test_is_in_terminal_state_running(self):
        """``working`` (in progress) is not terminal."""
        self.assertFalse(is_in_terminal_state(self._task_in(TaskState.working)))

    def test_is_in_terminal_or_interrupted_state_input_required(self):
        """``input_required`` counts as terminal-or-interrupted."""
        task = self._task_in(TaskState.input_required)
        self.assertTrue(is_in_terminal_or_interrupted_state(task))

    def test_is_in_terminal_or_interrupted_state_unknown(self):
        """``unknown`` counts as terminal-or-interrupted."""
        task = self._task_in(TaskState.unknown)
        self.assertTrue(is_in_terminal_or_interrupted_state(task))

    def test_is_in_terminal_or_interrupted_state_completed(self):
        """Terminal states such as ``completed`` also count as terminal-or-interrupted."""
        task = self._task_in(TaskState.completed)
        self.assertTrue(is_in_terminal_or_interrupted_state(task))

    def test_is_in_terminal_or_interrupted_state_running(self):
        """``working`` (in progress) is neither terminal nor interrupted."""
        task = self._task_in(TaskState.working)
        self.assertFalse(is_in_terminal_or_interrupted_state(task))
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|