unique_toolkit 1.7.0__py3-none-any.whl → 1.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. unique_toolkit/agentic/tools/a2a/__init__.py +19 -3
  2. unique_toolkit/agentic/tools/a2a/config.py +12 -52
  3. unique_toolkit/agentic/tools/a2a/evaluation/__init__.py +10 -3
  4. unique_toolkit/agentic/tools/a2a/evaluation/_utils.py +66 -0
  5. unique_toolkit/agentic/tools/a2a/evaluation/config.py +19 -3
  6. unique_toolkit/agentic/tools/a2a/evaluation/evaluator.py +82 -89
  7. unique_toolkit/agentic/tools/a2a/manager.py +2 -2
  8. unique_toolkit/agentic/tools/a2a/postprocessing/__init__.py +9 -1
  9. unique_toolkit/agentic/tools/a2a/postprocessing/{display.py → _display.py} +16 -7
  10. unique_toolkit/agentic/tools/a2a/postprocessing/_utils.py +19 -0
  11. unique_toolkit/agentic/tools/a2a/postprocessing/config.py +24 -0
  12. unique_toolkit/agentic/tools/a2a/postprocessing/postprocessor.py +109 -110
  13. unique_toolkit/agentic/tools/a2a/postprocessing/test/test_consolidate_references.py +665 -0
  14. unique_toolkit/agentic/tools/a2a/postprocessing/test/test_display.py +54 -75
  15. unique_toolkit/agentic/tools/a2a/postprocessing/test/test_postprocessor_reference_functions.py +53 -45
  16. unique_toolkit/agentic/tools/a2a/tool/__init__.py +4 -0
  17. unique_toolkit/agentic/tools/a2a/{memory.py → tool/_memory.py} +1 -1
  18. unique_toolkit/agentic/tools/a2a/{schema.py → tool/_schema.py} +0 -6
  19. unique_toolkit/agentic/tools/a2a/tool/config.py +63 -0
  20. unique_toolkit/agentic/tools/a2a/{service.py → tool/service.py} +108 -65
  21. unique_toolkit/agentic/tools/config.py +2 -2
  22. unique_toolkit/agentic/tools/tool_manager.py +1 -2
  23. {unique_toolkit-1.7.0.dist-info → unique_toolkit-1.8.1.dist-info}/METADATA +8 -1
  24. {unique_toolkit-1.7.0.dist-info → unique_toolkit-1.8.1.dist-info}/RECORD +26 -20
  25. {unique_toolkit-1.7.0.dist-info → unique_toolkit-1.8.1.dist-info}/LICENSE +0 -0
  26. {unique_toolkit-1.7.0.dist-info → unique_toolkit-1.8.1.dist-info}/WHEEL +0 -0
@@ -2,11 +2,13 @@ import re
2
2
 
3
3
  import pytest
4
4
 
5
- from unique_toolkit.agentic.tools.a2a.config import ResponseDisplayMode
6
- from unique_toolkit.agentic.tools.a2a.postprocessing.display import (
5
+ from unique_toolkit.agentic.tools.a2a.postprocessing._display import (
6
+ _build_sub_agent_answer_display,
7
7
  _DetailsResponseDisplayHandler,
8
- build_sub_agent_answer_display,
9
- remove_sub_agent_answer_from_text,
8
+ _remove_sub_agent_answer_from_text,
9
+ )
10
+ from unique_toolkit.agentic.tools.a2a.postprocessing.config import (
11
+ SubAgentResponseDisplayMode,
10
12
  )
11
13
 
12
14
 
@@ -14,12 +16,12 @@ class TestDetailsResponseDisplayHandler:
14
16
  """Test suite for DetailsResponseDisplayHandler class."""
15
17
 
16
18
  @pytest.fixture
17
- def open_handler(self):
19
+ def open_handler(self) -> _DetailsResponseDisplayHandler:
18
20
  """Create a handler with open mode."""
19
21
  return _DetailsResponseDisplayHandler(mode="open")
20
22
 
21
23
  @pytest.fixture
22
- def closed_handler(self):
24
+ def closed_handler(self) -> _DetailsResponseDisplayHandler:
23
25
  """Create a handler with closed mode."""
24
26
  return _DetailsResponseDisplayHandler(mode="closed")
25
27
 
@@ -201,82 +203,59 @@ class TestDisplayFunctions:
201
203
  "answer": "This is a test answer.",
202
204
  }
203
205
 
204
- def test_build_sub_agent_answer_display_details_open(self, sample_data):
205
- """Test building sub-agent answer display with DETAILS_OPEN mode."""
206
- result = build_sub_agent_answer_display(
207
- display_name=sample_data["display_name"],
208
- display_mode=ResponseDisplayMode.DETAILS_OPEN,
209
- answer=sample_data["answer"],
210
- assistant_id=sample_data["assistant_id"],
211
- )
212
-
213
- assert "<details open>" in result
214
- assert sample_data["display_name"] in result
215
- assert sample_data["answer"] in result
216
- assert sample_data["assistant_id"] in result
217
-
218
- def test_build_sub_agent_answer_display_details_closed(self, sample_data):
219
- """Test building sub-agent answer display with DETAILS_CLOSED mode."""
220
- result = build_sub_agent_answer_display(
221
- display_name=sample_data["display_name"],
222
- display_mode=ResponseDisplayMode.DETAILS_CLOSED,
223
- answer=sample_data["answer"],
224
- assistant_id=sample_data["assistant_id"],
225
- )
226
-
227
- assert "<details>" in result
228
- assert "<details open>" not in result
229
- assert sample_data["display_name"] in result
230
- assert sample_data["answer"] in result
231
- assert sample_data["assistant_id"] in result
232
-
233
- def test_build_sub_agent_answer_display_hidden_mode(self, sample_data):
234
- """Test building sub-agent answer display with HIDDEN mode."""
235
- result = build_sub_agent_answer_display(
206
+ @pytest.mark.parametrize(
207
+ "display_mode,expected_content,not_expected_content",
208
+ [
209
+ (SubAgentResponseDisplayMode.DETAILS_OPEN, "<details open>", None),
210
+ (SubAgentResponseDisplayMode.DETAILS_CLOSED, "<details>", "<details open>"),
211
+ (SubAgentResponseDisplayMode.HIDDEN, "", None),
212
+ ],
213
+ )
214
+ def test_build_sub_agent_answer_display(
215
+ self, sample_data, display_mode, expected_content, not_expected_content
216
+ ):
217
+ """Test building sub-agent answer display with different modes."""
218
+ result = _build_sub_agent_answer_display(
236
219
  display_name=sample_data["display_name"],
237
- display_mode=ResponseDisplayMode.HIDDEN,
220
+ display_mode=display_mode,
238
221
  answer=sample_data["answer"],
239
222
  assistant_id=sample_data["assistant_id"],
240
223
  )
241
224
 
242
- assert result == ""
243
-
244
- def test_remove_sub_agent_answer_from_text_details_open(self, sample_data):
245
- """Test removing sub-agent answer from text with DETAILS_OPEN mode."""
246
- # First build the display
247
- display_html = build_sub_agent_answer_display(
248
- display_name=sample_data["display_name"],
249
- display_mode=ResponseDisplayMode.DETAILS_OPEN,
250
- answer=sample_data["answer"],
251
- assistant_id=sample_data["assistant_id"],
252
- )
253
-
254
- text_with_display = f"Before\n{display_html}\nAfter"
255
-
256
- result = remove_sub_agent_answer_from_text(
257
- display_mode=ResponseDisplayMode.DETAILS_OPEN,
258
- text=text_with_display,
259
- assistant_id=sample_data["assistant_id"],
260
- )
261
-
262
- assert "Before" in result
263
- assert "After" in result
264
- assert sample_data["answer"] not in result
265
-
266
- def test_remove_sub_agent_answer_from_text_details_closed(self, sample_data):
267
- """Test removing sub-agent answer from text with DETAILS_CLOSED mode."""
225
+ if display_mode == SubAgentResponseDisplayMode.HIDDEN:
226
+ assert result == ""
227
+ else:
228
+ assert expected_content in result
229
+ assert sample_data["display_name"] in result
230
+ assert sample_data["answer"] in result
231
+ assert sample_data["assistant_id"] in result
232
+
233
+ if not_expected_content:
234
+ assert not_expected_content not in result
235
+
236
+ @pytest.mark.parametrize(
237
+ "display_mode",
238
+ [
239
+ SubAgentResponseDisplayMode.DETAILS_OPEN,
240
+ SubAgentResponseDisplayMode.DETAILS_CLOSED,
241
+ ],
242
+ )
243
+ def test_remove_sub_agent_answer_from_text_details_modes(
244
+ self, sample_data, display_mode
245
+ ):
246
+ """Test removing sub-agent answer from text with DETAILS_OPEN and DETAILS_CLOSED modes."""
268
247
  # First build the display
269
- display_html = build_sub_agent_answer_display(
248
+ display_html = _build_sub_agent_answer_display(
270
249
  display_name=sample_data["display_name"],
271
- display_mode=ResponseDisplayMode.DETAILS_CLOSED,
250
+ display_mode=display_mode,
272
251
  answer=sample_data["answer"],
273
252
  assistant_id=sample_data["assistant_id"],
274
253
  )
275
254
 
276
255
  text_with_display = f"Before\n{display_html}\nAfter"
277
256
 
278
- result = remove_sub_agent_answer_from_text(
279
- display_mode=ResponseDisplayMode.DETAILS_CLOSED,
257
+ result = _remove_sub_agent_answer_from_text(
258
+ display_mode=display_mode,
280
259
  text=text_with_display,
281
260
  assistant_id=sample_data["assistant_id"],
282
261
  )
@@ -288,8 +267,8 @@ class TestDisplayFunctions:
288
267
  def test_remove_sub_agent_answer_from_text_hidden_mode(self, sample_data):
289
268
  """Test removing sub-agent answer from text with HIDDEN mode."""
290
269
  text = "Some text here"
291
- result = remove_sub_agent_answer_from_text(
292
- display_mode=ResponseDisplayMode.HIDDEN,
270
+ result = _remove_sub_agent_answer_from_text(
271
+ display_mode=SubAgentResponseDisplayMode.HIDDEN,
293
272
  text=text,
294
273
  assistant_id=sample_data["assistant_id"],
295
274
  )
@@ -301,9 +280,9 @@ class TestDisplayFunctions:
301
280
  original_text = "This is the original text."
302
281
 
303
282
  # Build display
304
- display_html = build_sub_agent_answer_display(
283
+ display_html = _build_sub_agent_answer_display(
305
284
  display_name=sample_data["display_name"],
306
- display_mode=ResponseDisplayMode.DETAILS_OPEN,
285
+ display_mode=SubAgentResponseDisplayMode.DETAILS_OPEN,
307
286
  answer=sample_data["answer"],
308
287
  assistant_id=sample_data["assistant_id"],
309
288
  )
@@ -312,8 +291,8 @@ class TestDisplayFunctions:
312
291
  text_with_display = f"{original_text}\n{display_html}"
313
292
 
314
293
  # Remove display
315
- result = remove_sub_agent_answer_from_text(
316
- display_mode=ResponseDisplayMode.DETAILS_OPEN,
294
+ result = _remove_sub_agent_answer_from_text(
295
+ display_mode=SubAgentResponseDisplayMode.DETAILS_OPEN,
317
296
  text=text_with_display,
318
297
  assistant_id=sample_data["assistant_id"],
319
298
  )
@@ -1,4 +1,6 @@
1
- from unique_toolkit.agentic.tools.a2a.postprocessing.postprocessor import (
1
+ import pytest
2
+
3
+ from unique_toolkit.agentic.tools.a2a.postprocessing._utils import (
2
4
  _replace_references_in_text,
3
5
  _replace_references_in_text_non_overlapping,
4
6
  )
@@ -30,19 +32,21 @@ class TestReplaceReferencesInTextNonOverlapping:
30
32
  result = _replace_references_in_text_non_overlapping(text, ref_map)
31
33
  assert result == text
32
34
 
33
- def test_empty_ref_map(self):
34
- """Test with empty reference map."""
35
- text = "This text has<sup>1</sup> references but empty map."
36
- ref_map = {}
37
- result = _replace_references_in_text_non_overlapping(text, ref_map)
38
- assert result == text
39
-
40
- def test_empty_text(self):
41
- """Test with empty text."""
42
- text = ""
43
- ref_map = {1: 5}
35
+ @pytest.mark.parametrize(
36
+ "text,ref_map,expected",
37
+ [
38
+ (
39
+ "This text has<sup>1</sup> references but empty map.",
40
+ {},
41
+ "This text has<sup>1</sup> references but empty map.",
42
+ ),
43
+ ("", {1: 5}, ""),
44
+ ],
45
+ )
46
+ def test_empty_inputs(self, text, ref_map, expected):
47
+ """Test with empty reference map or empty text."""
44
48
  result = _replace_references_in_text_non_overlapping(text, ref_map)
45
- assert result == ""
49
+ assert result == expected
46
50
 
47
51
  def test_reference_not_in_map(self):
48
52
  """Test with references in text that are not in the map."""
@@ -92,20 +96,24 @@ class TestReplaceReferencesInTextNonOverlapping:
92
96
  expected = "Good<sup>10</sup> and bad<sup>abc</sup> and<sup></sup>."
93
97
  assert result == expected
94
98
 
95
- def test_zero_reference_number(self):
96
- """Test with zero as reference number."""
97
- text = "Zero reference<sup>0</sup> here."
98
- ref_map = {0: 100}
99
- result = _replace_references_in_text_non_overlapping(text, ref_map)
100
- expected = "Zero reference<sup>100</sup> here."
101
- assert result == expected
102
-
103
- def test_negative_reference_numbers(self):
104
- """Test with negative reference numbers (edge case)."""
105
- text = "Negative<sup>-1</sup> reference."
106
- ref_map = {-1: 5}
99
+ @pytest.mark.parametrize(
100
+ "text,ref_map,expected",
101
+ [
102
+ (
103
+ "Zero reference<sup>0</sup> here.",
104
+ {0: 100},
105
+ "Zero reference<sup>100</sup> here.",
106
+ ),
107
+ (
108
+ "Negative<sup>-1</sup> reference.",
109
+ {-1: 5},
110
+ "Negative<sup>5</sup> reference.",
111
+ ),
112
+ ],
113
+ )
114
+ def test_special_reference_numbers(self, text, ref_map, expected):
115
+ """Test with zero and negative reference numbers."""
107
116
  result = _replace_references_in_text_non_overlapping(text, ref_map)
108
- expected = "Negative<sup>5</sup> reference."
109
117
  assert result == expected
110
118
 
111
119
 
@@ -152,26 +160,26 @@ class TestReplaceReferencesInText:
152
160
  expected = "A<sup>4</sup>B<sup>1</sup>C<sup>2</sup>D<sup>3</sup>."
153
161
  assert result == expected
154
162
 
155
- def test_empty_ref_map(self):
156
- """Test with empty reference map."""
157
- text = "Text with<sup>1</sup> references."
158
- ref_map = {}
159
- result = _replace_references_in_text(text, ref_map)
160
- assert result == text
161
-
162
- def test_empty_text(self):
163
- """Test with empty text."""
164
- text = ""
165
- ref_map = {1: 2}
166
- result = _replace_references_in_text(text, ref_map)
167
- assert result == ""
168
-
169
- def test_no_references_in_text(self):
170
- """Test with text containing no references."""
171
- text = "This text has no references."
172
- ref_map = {1: 10, 2: 20}
163
+ @pytest.mark.parametrize(
164
+ "text,ref_map,expected",
165
+ [
166
+ (
167
+ "Text with<sup>1</sup> references.",
168
+ {},
169
+ "Text with<sup>1</sup> references.",
170
+ ),
171
+ ("", {1: 2}, ""),
172
+ (
173
+ "This text has no references.",
174
+ {1: 10, 2: 20},
175
+ "This text has no references.",
176
+ ),
177
+ ],
178
+ )
179
+ def test_edge_cases(self, text, ref_map, expected):
180
+ """Test edge cases: empty reference map, empty text, and text with no references."""
173
181
  result = _replace_references_in_text(text, ref_map)
174
- assert result == text
182
+ assert result == expected
175
183
 
176
184
  def test_single_reference_no_overlap(self):
177
185
  """Test single reference with no overlap potential."""
@@ -0,0 +1,4 @@
1
+ from unique_toolkit.agentic.tools.a2a.tool.config import SubAgentToolConfig
2
+ from unique_toolkit.agentic.tools.a2a.tool.service import SubAgentTool
3
+
4
+ __all__ = ["SubAgentTool", "SubAgentToolConfig"]
@@ -2,7 +2,7 @@ from unique_toolkit import ShortTermMemoryService
2
2
  from unique_toolkit.agentic.short_term_memory_manager.persistent_short_term_memory_manager import (
3
3
  PersistentShortMemoryManager,
4
4
  )
5
- from unique_toolkit.agentic.tools.a2a.schema import SubAgentShortTermMemorySchema
5
+ from unique_toolkit.agentic.tools.a2a.tool._schema import SubAgentShortTermMemorySchema
6
6
 
7
7
 
8
8
  def _get_short_term_memory_name(assistant_id: str) -> str:
@@ -1,15 +1,9 @@
1
1
  from pydantic import BaseModel
2
2
 
3
- from unique_toolkit.agentic.tools.schemas import ToolCallResponse
4
-
5
3
 
6
4
  class SubAgentToolInput(BaseModel):
7
5
  user_message: str
8
6
 
9
7
 
10
- class SubAgentToolCallResponse(ToolCallResponse):
11
- assistant_message: str
12
-
13
-
14
8
  class SubAgentShortTermMemorySchema(BaseModel):
15
9
  chat_id: str
@@ -0,0 +1,63 @@
1
+ from pydantic import Field
2
+
3
+ from unique_toolkit._common.pydantic_helpers import get_configuration_dict
4
+ from unique_toolkit.agentic.tools.schemas import BaseToolConfig
5
+
6
+ DEFAULT_PARAM_DESCRIPTION_SUB_AGENT_USER_MESSAGE = """
7
+ This is the message that will be sent to the sub-agent.
8
+ """.strip()
9
+
10
+ DEFAULT_FORMAT_INFORMATION_SUB_AGENT_SYSTEM_MESSAGE = """
11
+ NEVER mention any references from sub-agent answers in your response.
12
+ """.strip()
13
+
14
+
15
+ class SubAgentToolConfig(BaseToolConfig):
16
+ model_config = get_configuration_dict()
17
+
18
+ assistant_id: str = Field(
19
+ default="",
20
+ description="The unique identifier of the assistant to use for the sub-agent.",
21
+ )
22
+ chat_id: str | None = Field(
23
+ default=None,
24
+ description="The chat ID to use for the sub-agent conversation. If None, a new chat will be created.",
25
+ )
26
+ reuse_chat: bool = Field(
27
+ default=True,
28
+ description="Whether to reuse the existing chat or create a new one for each sub-agent call.",
29
+ )
30
+
31
+ tool_description_for_system_prompt: str = Field(
32
+ default="",
33
+ description="Description of the tool that will be included in the system prompt.",
34
+ )
35
+ tool_description: str = Field(
36
+ default="",
37
+ description="Description of the tool that will be included in the tools sent to the model.",
38
+ )
39
+ param_description_sub_agent_user_message: str = Field(
40
+ default=DEFAULT_PARAM_DESCRIPTION_SUB_AGENT_USER_MESSAGE,
41
+ description="Description of the user message parameter that will be sent to the model.",
42
+ )
43
+ tool_format_information_for_system_prompt: str = Field(
44
+ default=DEFAULT_FORMAT_INFORMATION_SUB_AGENT_SYSTEM_MESSAGE,
45
+ description="Format information that will be included in the system prompt to guide response formatting.",
46
+ )
47
+ tool_description_for_user_prompt: str = Field(
48
+ default="",
49
+ description="Description of the tool that will be included in the user prompt.",
50
+ )
51
+ tool_format_information_for_user_prompt: str = Field(
52
+ default="",
53
+ description="Format information that will be included in the user prompt to guide response formatting.",
54
+ )
55
+
56
+ poll_interval: float = Field(
57
+ default=1.0,
58
+ description="Time interval in seconds between polling attempts when waiting for sub-agent response.",
59
+ )
60
+ max_wait: float = Field(
61
+ default=120.0,
62
+ description="Maximum time in seconds to wait for the sub-agent response before timing out.",
63
+ )
@@ -1,3 +1,5 @@
1
+ import asyncio
2
+ import contextlib
1
3
  from typing import Protocol, override
2
4
 
3
5
  import unique_sdk
@@ -5,16 +7,16 @@ from pydantic import Field, create_model
5
7
  from unique_sdk.utils.chat_in_space import send_message_and_wait_for_completion
6
8
 
7
9
  from unique_toolkit.agentic.evaluation.schemas import EvaluationMetricName
8
- from unique_toolkit.agentic.tools.a2a.config import (
9
- SubAgentToolConfig,
10
- )
11
- from unique_toolkit.agentic.tools.a2a.memory import (
10
+ from unique_toolkit.agentic.tools.a2a.tool._memory import (
12
11
  get_sub_agent_short_term_memory_manager,
13
12
  )
14
- from unique_toolkit.agentic.tools.a2a.schema import (
13
+ from unique_toolkit.agentic.tools.a2a.tool._schema import (
15
14
  SubAgentShortTermMemorySchema,
16
15
  SubAgentToolInput,
17
16
  )
17
+ from unique_toolkit.agentic.tools.a2a.tool.config import (
18
+ SubAgentToolConfig,
19
+ )
18
20
  from unique_toolkit.agentic.tools.agent_chunks_hanlder import AgentChunksHandler
19
21
  from unique_toolkit.agentic.tools.factory import ToolFactory
20
22
  from unique_toolkit.agentic.tools.schemas import ToolCallResponse
@@ -34,10 +36,17 @@ from unique_toolkit.language_model.schemas import LanguageModelMessage
34
36
  class SubAgentResponseSubscriber(Protocol):
35
37
  def notify_sub_agent_response(
36
38
  self,
37
- sub_agent_assistant_id: str,
38
39
  response: unique_sdk.Space.Message,
40
+ sub_agent_assistant_id: str,
41
+ sequence_number: int,
39
42
  ) -> None: ...
40
43
 
44
+ """
45
+ Notify the subscriber that a sub agent response has been received.
46
+ Important: The subscriber should NOT modify the response in place.
47
+ The sequence number is a 1-indexed counter that is incremented for each concurrent run of the same sub agent.
48
+ """
49
+
41
50
 
42
51
  class SubAgentTool(Tool[SubAgentToolConfig]):
43
52
  name: str = "SubAgentTool"
@@ -66,12 +75,18 @@ class SubAgentTool(Tool[SubAgentToolConfig]):
66
75
  self._subscribers: list[SubAgentResponseSubscriber] = []
67
76
  self._should_run_evaluation = False
68
77
 
69
- def display_name(self) -> str:
70
- return self._display_name
78
+ # Synchronization state
79
+ self._sequence_number = 1
80
+ self._lock = asyncio.Lock()
71
81
 
72
82
  def subscribe(self, subscriber: SubAgentResponseSubscriber) -> None:
73
83
  self._subscribers.append(subscriber)
74
84
 
85
+ @override
86
+ def display_name(self) -> str:
87
+ return self._display_name
88
+
89
+ @override
75
90
  def tool_description(self) -> LanguageModelToolDescription:
76
91
  tool_input_model_with_description = create_model(
77
92
  "SubAgentToolInput",
@@ -87,27 +102,98 @@ class SubAgentTool(Tool[SubAgentToolConfig]):
87
102
  parameters=tool_input_model_with_description,
88
103
  )
89
104
 
105
+ @override
90
106
  def tool_description_for_system_prompt(self) -> str:
91
107
  return self.config.tool_description_for_system_prompt
92
108
 
109
+ @override
93
110
  def tool_format_information_for_system_prompt(self) -> str:
94
111
  return self.config.tool_format_information_for_system_prompt
95
112
 
113
+ @override
96
114
  def tool_description_for_user_prompt(self) -> str:
97
115
  return self.config.tool_description_for_user_prompt
98
116
 
117
+ @override
99
118
  def tool_format_information_for_user_prompt(self) -> str:
100
119
  return self.config.tool_format_information_for_user_prompt
101
120
 
121
+ @override
102
122
  def evaluation_check_list(self) -> list[EvaluationMetricName]:
103
123
  return [EvaluationMetricName.SUB_AGENT] if self._should_run_evaluation else []
104
124
 
125
+ @override
105
126
  def get_evaluation_checks_based_on_tool_response(
106
127
  self,
107
128
  tool_response: ToolCallResponse,
108
129
  ) -> list[EvaluationMetricName]:
109
130
  return []
110
131
 
132
+ @override
133
+ async def run(self, tool_call: LanguageModelFunction) -> ToolCallResponse:
134
+ tool_input = SubAgentToolInput.model_validate(tool_call.arguments)
135
+
136
+ if self._lock.locked():
137
+ await self._notify_progress(
138
+ tool_call=tool_call,
139
+ message=f"Waiting for another run of `{self.display_name()}` to finish",
140
+ state=ProgressState.STARTED,
141
+ )
142
+
143
+ # When reusing the chat id, executing the sub agent in parrallel leads to race conditions and undefined behavior.
144
+ # To avoid this, we use a lock to serialize the execution of the same sub agent.
145
+ context = self._lock if self.config.reuse_chat else contextlib.nullcontext()
146
+
147
+ async with context:
148
+ sequence_number = self._sequence_number
149
+ self._sequence_number += 1
150
+
151
+ await self._notify_progress(
152
+ tool_call=tool_call,
153
+ message=tool_input.user_message,
154
+ state=ProgressState.RUNNING,
155
+ )
156
+
157
+ # Check if there is a saved chat id in short term memory
158
+ chat_id = await self._get_chat_id()
159
+
160
+ response = await self._execute_and_handle_timeout(
161
+ tool_user_message=tool_input.user_message,
162
+ chat_id=chat_id,
163
+ tool_call=tool_call,
164
+ )
165
+
166
+ self._should_run_evaluation |= (
167
+ response["assessment"] is not None and len(response["assessment"]) > 0
168
+ ) # Run evaluation if any sub agent returned an assessment
169
+
170
+ self._notify_subscribers(response, sequence_number)
171
+
172
+ if chat_id is None:
173
+ await self._save_chat_id(response["chatId"])
174
+
175
+ if response["text"] is None:
176
+ raise ValueError("No response returned from sub agent")
177
+
178
+ await self._notify_progress(
179
+ tool_call=tool_call,
180
+ message=tool_input.user_message,
181
+ state=ProgressState.FINISHED,
182
+ )
183
+
184
+ return ToolCallResponse(
185
+ id=tool_call.id, # type: ignore
186
+ name=tool_call.name,
187
+ content=response["text"],
188
+ )
189
+
190
+ @override
191
+ def get_tool_call_result_for_loop_history(
192
+ self,
193
+ tool_response: ToolCallResponse,
194
+ agent_chunks_handler: AgentChunksHandler,
195
+ ) -> LanguageModelMessage: ... # Empty as method is deprecated
196
+
111
197
  async def _get_chat_id(self) -> str | None:
112
198
  if not self.config.reuse_chat:
113
199
  return None
@@ -145,8 +231,21 @@ class SubAgentTool(Tool[SubAgentToolConfig]):
145
231
  state=state,
146
232
  )
147
233
 
234
+ def _notify_subscribers(
235
+ self, response: unique_sdk.Space.Message, sequence_number: int
236
+ ) -> None:
237
+ for subsciber in self._subscribers:
238
+ subsciber.notify_sub_agent_response(
239
+ sub_agent_assistant_id=self.config.assistant_id,
240
+ response=response,
241
+ sequence_number=sequence_number,
242
+ )
243
+
148
244
  async def _execute_and_handle_timeout(
149
- self, tool_user_message: str, chat_id: str, tool_call: LanguageModelFunction
245
+ self,
246
+ tool_user_message: str,
247
+ chat_id: str | None,
248
+ tool_call: LanguageModelFunction,
150
249
  ) -> unique_sdk.Space.Message:
151
250
  try:
152
251
  return await send_message_and_wait_for_completion(
@@ -170,61 +269,5 @@ class SubAgentTool(Tool[SubAgentToolConfig]):
170
269
  "Timeout while waiting for response from sub agent. The user should consider increasing the max wait time.",
171
270
  ) from e
172
271
 
173
- def _notify_subscribers(self, response: unique_sdk.Space.Message) -> None:
174
- for subsciber in self._subscribers:
175
- subsciber.notify_sub_agent_response(
176
- sub_agent_assistant_id=self.config.assistant_id,
177
- response=response,
178
- )
179
-
180
- async def run(self, tool_call: LanguageModelFunction) -> ToolCallResponse:
181
- tool_input = SubAgentToolInput.model_validate(tool_call.arguments)
182
-
183
- await self._notify_progress(
184
- tool_call=tool_call,
185
- message=tool_input.user_message,
186
- state=ProgressState.RUNNING,
187
- )
188
-
189
- # Check if there is a saved chat id in short term memory
190
- chat_id = await self._get_chat_id()
191
-
192
- response = await self._execute_and_handle_timeout(
193
- tool_user_message=tool_input.user_message, # type: ignore
194
- chat_id=chat_id, # type: ignore
195
- tool_call=tool_call,
196
- )
197
-
198
- self._should_run_evaluation = (
199
- response["assessment"] is not None and len(response["assessment"]) > 0
200
- )
201
-
202
- self._notify_subscribers(response)
203
-
204
- if chat_id is None and self.config.reuse_chat:
205
- await self._save_chat_id(response["chatId"])
206
-
207
- if response["text"] is None:
208
- raise ValueError("No response returned from sub agent")
209
-
210
- await self._notify_progress(
211
- tool_call=tool_call,
212
- message=tool_input.user_message,
213
- state=ProgressState.FINISHED,
214
- )
215
-
216
- return ToolCallResponse(
217
- id=tool_call.id, # type: ignore
218
- name=tool_call.name,
219
- content=response["text"],
220
- )
221
-
222
- @override
223
- def get_tool_call_result_for_loop_history(
224
- self,
225
- tool_response: ToolCallResponse,
226
- agent_chunks_handler: AgentChunksHandler,
227
- ) -> LanguageModelMessage: ... # Empty as method is deprecated
228
-
229
272
 
230
273
  ToolFactory.register_tool(SubAgentTool, SubAgentToolConfig)