unique_toolkit 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,204 @@
+from unittest.mock import AsyncMock
+
+import pytest
+from unique_toolkit.chat.service import ChatService
+from unique_toolkit.content.schemas import ContentReference
+from unique_toolkit.language_model.schemas import LanguageModelFunction
+from unique_toolkit.tools.tool_progress_reporter import (
+    DUMMY_REFERENCE_PLACEHOLDER,
+    ProgressState,
+    ToolExecutionStatus,
+    ToolProgressReporter,
+    ToolWithToolProgressReporter,
+    track_tool_progress,
+)
+
+
+@pytest.fixture
+def chat_service():
+    return AsyncMock(spec=ChatService)
+
+
+@pytest.fixture
+def tool_progress_reporter(chat_service):
+    return ToolProgressReporter(chat_service)
+
+
+@pytest.fixture
+def tool_call():
+    return LanguageModelFunction(id="test_id", name="test_tool")
+
+
+class TestToolProgressReporter:
+    @pytest.mark.asyncio
+    async def test_notify_from_tool_call(self, tool_progress_reporter, tool_call):
+        # Arrange
+        name = "Test Tool"
+        message = "Processing..."
+        state = ProgressState.RUNNING
+        references = [
+            ContentReference(
+                sequence_number=1,
+                id="1",
+                message_id="1",
+                name="1",
+                source="1",
+                source_id="1",
+                url="1",
+            )
+        ]
+
+        # Act
+        await tool_progress_reporter.notify_from_tool_call(
+            tool_call=tool_call,
+            name=name,
+            message=message,
+            state=state,
+            references=references,
+            requires_new_assistant_message=True,
+        )
+
+        # Assert
+        assert tool_call.id in tool_progress_reporter.tool_statuses
+        status = tool_progress_reporter.tool_statuses[tool_call.id]
+        assert status.name == name
+        assert status.message == message
+        assert status.state == state
+        assert status.references == references
+        assert tool_progress_reporter.requires_new_assistant_message is True
+
+    def test_replace_placeholders(self, tool_progress_reporter):
+        # Arrange
+        message = (
+            f"Test{DUMMY_REFERENCE_PLACEHOLDER}message{DUMMY_REFERENCE_PLACEHOLDER}"
+        )
+
+        # Act
+        result = tool_progress_reporter._replace_placeholders(message, start_number=1)
+
+        # Assert
+        assert result == "Test<sup>1</sup>message<sup>2</sup>"
+
+    def test_correct_reference_sequence(self, tool_progress_reporter):
+        # Arrange
+        references = [
+            ContentReference(
+                sequence_number=0,
+                id="1",
+                message_id="1",
+                name="1",
+                source="1",
+                source_id="1",
+                url="1",
+            ),
+            ContentReference(
+                sequence_number=0,
+                id="2",
+                message_id="2",
+                name="2",
+                source="2",
+                source_id="2",
+                url="2",
+            ),
+        ]
+
+        # Act
+        result = tool_progress_reporter._correct_reference_sequence(
+            references, start_number=1
+        )
+
+        # Assert
+        assert len(result) == 2
+        assert result[0].sequence_number == 1
+        assert result[1].sequence_number == 2
+
+    @pytest.mark.asyncio
+    async def test_publish_updates_chat_service(
+        self, tool_progress_reporter, tool_call
+    ):
+        # Arrange
+        status = ToolExecutionStatus(
+            name="Test Tool",
+            message="Test message",
+            state=ProgressState.FINISHED,
+            references=[
+                ContentReference(
+                    sequence_number=1,
+                    id="1",
+                    message_id="1",
+                    name="1",
+                    source="1",
+                    source_id="1",
+                    url="1",
+                )
+            ],
+        )
+        tool_progress_reporter.tool_statuses[tool_call.id] = status
+
+        # Act
+        await tool_progress_reporter.publish()
+
+        # Assert
+        tool_progress_reporter.chat_service.modify_assistant_message_async.assert_called_once()
+
+
+class TestToolProgressDecorator:
+    class DummyTool(ToolWithToolProgressReporter):
+        def __init__(self, tool_progress_reporter):
+            self.tool_progress_reporter = tool_progress_reporter
+
+        @track_tool_progress(
+            message="Processing",
+            on_start_state=ProgressState.STARTED,
+            on_success_state=ProgressState.FINISHED,
+            on_success_message="Completed",
+            requires_new_assistant_message=True,
+        )
+        async def execute(self, tool_call, notification_tool_name):
+            return {
+                "references": [
+                    ContentReference(
+                        sequence_number=1,
+                        id="1",
+                        message_id="1",
+                        name="1",
+                        source="1",
+                        source_id="1",
+                        url="1",
+                    )
+                ]
+            }
+
+    @pytest.mark.asyncio
+    async def test_decorator_success_flow(self, tool_progress_reporter, tool_call):
+        # Arrange
+        tool = self.DummyTool(tool_progress_reporter)
+
+        # Act
+        await tool.execute(tool_call, "Test Tool")
+
+        # Assert
+        assert len(tool_progress_reporter.tool_statuses) == 1
+        status = tool_progress_reporter.tool_statuses[tool_call.id]
+        assert status.state == ProgressState.FINISHED
+        assert status.message == "Completed"
+
+    @pytest.mark.asyncio
+    async def test_decorator_error_flow(self, tool_progress_reporter, tool_call):
+        # Arrange
+        class ErrorTool(ToolWithToolProgressReporter):
+            def __init__(self, tool_progress_reporter):
+                self.tool_progress_reporter = tool_progress_reporter
+
+            @track_tool_progress(message="Processing")
+            async def execute(self, tool_call, notification_tool_name):
+                raise ValueError("Test error")
+
+        tool = ErrorTool(tool_progress_reporter)
+
+        # Act & Assert
+        with pytest.raises(ValueError):
+            await tool.execute(tool_call, "Test Tool")
+
+        status = tool_progress_reporter.tool_statuses[tool_call.id]
+        assert status.state == ProgressState.FAILED
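The test above pins down how reference placeholders are numbered: each occurrence becomes an incrementing superscript marker starting at start_number. As a rough, self-contained sketch of that behaviour only (the real logic lives in ToolProgressReporter._replace_placeholders, and PLACEHOLDER below is a made-up stand-in, not the toolkit's DUMMY_REFERENCE_PLACEHOLDER constant):

# Illustrative sketch of the numbering asserted in test_replace_placeholders.
PLACEHOLDER = "[reference]"  # stand-in value, not the toolkit's constant

def replace_placeholders(message: str, start_number: int = 1) -> str:
    # Each placeholder occurrence becomes an incrementing <sup>N</sup> marker.
    parts = message.split(PLACEHOLDER)
    result = parts[0]
    for offset, rest in enumerate(parts[1:]):
        result += f"<sup>{start_number + offset}</sup>{rest}"
    return result

assert (
    replace_placeholders(f"Test{PLACEHOLDER}message{PLACEHOLDER}")
    == "Test<sup>1</sup>message<sup>2</sup>"
)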
@@ -0,0 +1,168 @@
+from abc import ABC, abstractmethod
+from enum import StrEnum
+from logging import getLogger
+from typing import Generic, TypeVar
+from typing import Any, cast
+
+from pydantic import Field
+from typing_extensions import deprecated
+from unique_toolkit.app.schemas import ChatEvent
+from unique_toolkit.chat.service import (
+    ChatService,
+)
+from unique_toolkit.language_model import LanguageModelToolDescription
+from unique_toolkit.language_model.schemas import (
+    LanguageModelFunction,
+    LanguageModelMessage,
+)
+from unique_toolkit.language_model.service import LanguageModelService
+
+
+from unique_toolkit.evaluators.schemas import EvaluationMetricName
+from unique_toolkit.tools.agent_chunks_handler import AgentChunksHandler
+from unique_toolkit.tools.config import ToolBuildConfig, ToolSelectionPolicy
+from unique_toolkit.tools.schemas import BaseToolConfig, ToolCallResponse, ToolPrompts
+from unique_toolkit.tools.tool_progress_reporter import ToolProgressReporter
+
+ConfigType = TypeVar("ConfigType", bound=BaseToolConfig)
+
+
+class Tool(ABC, Generic[ConfigType]):
+    name: str
+    settings: ToolBuildConfig
+
+    def display_name(self) -> str:
+        """The display name of the tool."""
+        return self.settings.display_name
+
+    def icon(self) -> str:
+        """The icon of the tool."""
+        return self.settings.icon
+
+    def selection_policy(self) -> ToolSelectionPolicy:
+        """The selection policy of the tool."""
+        return self.settings.selection_policy
+
+    def is_exclusive(self) -> bool:
+        """Whether the tool is exclusive or not."""
+        return self.settings.is_exclusive
+
+    def is_enabled(self) -> bool:
+        """Whether the tool is enabled or not."""
+        return self.settings.is_enabled
+
+    @abstractmethod
+    def tool_description(self) -> LanguageModelToolDescription:
+        raise NotImplementedError
+
+    def tool_description_as_json(self) -> dict[str, Any]:
+        parameters = self.tool_description().parameters
+        if not isinstance(parameters, dict):
+            return parameters.model_json_schema()
+        else:
+            return cast("dict[str, Any]", parameters)
+
+    @abstractmethod
+    def tool_description_for_system_prompt(self) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def tool_format_information_for_system_prompt(self) -> str:
+        raise NotImplementedError
+
+    def tool_format_reminder_for_user_prompt(self) -> str:
+        """A short reminder for the user prompt for formatting rules for the tool.
+        You can use this if the LLM fails to follow the formatting rules.
+        """
+        return ""
+
+    @deprecated("Do not use as is bound to loop agent only")
+    @abstractmethod
+    def get_tool_call_result_for_loop_history(
+        self,
+        tool_response: ToolCallResponse,
+        agent_chunks_handler: AgentChunksHandler,
+    ) -> LanguageModelMessage:
+        raise NotImplementedError
+
+    @deprecated(
+        "Do not use. The tool should not determine how"
+        "it is checked. This should be defined by the user"
+        "of the tool."
+    )
+    @abstractmethod
+    def evaluation_check_list(self) -> list[EvaluationMetricName]:
+        raise NotImplementedError
+
+    @abstractmethod
+    async def run(self, tool_call: LanguageModelFunction) -> ToolCallResponse:
+        raise NotImplementedError
+
+    @deprecated(
+        "Do not use as the evaluation checks should not be determined by\n"
+        "the tool. The decision on what check should be done is up to the\n"
+        "user of the tool or the dev.",
+    )
+    @abstractmethod
+    def get_evaluation_checks_based_on_tool_response(
+        self,
+        tool_response: ToolCallResponse,
+    ) -> list[EvaluationMetricName]:
+        raise NotImplementedError
+
+    def get_tool_prompts(self) -> ToolPrompts:
+        return ToolPrompts(
+            name=self.name,
+            display_name=self.display_name(),
+            tool_description=self.tool_description().description,
+            tool_format_information_for_system_prompt=self.tool_format_information_for_system_prompt(),
+            input_model=self.tool_description_as_json(),
+        )
+
+    # Properties that we should soon deprecate
+
+    @property
+    @deprecated("Never reuse event. Dangerous")
+    def event(self) -> ChatEvent:
+        return self._event
+
+    @property
+    @deprecated("Do not use this property as directly tied to chat frontend")
+    def chat_service(self) -> ChatService:
+        return self._chat_service
+
+    @property
+    @deprecated("Do not use this property as directly tied to chat frontend")
+    def language_model_service(self) -> LanguageModelService:
+        return self._language_model_service
+
+    @property
+    @deprecated("Do not use this as directly tied to chat frontend")
+    def tool_progress_reporter(self) -> ToolProgressReporter | None:
+        return self._tool_progress_reporter
+
+    def __init__(
+        self,
+        config: ConfigType,
+        event: ChatEvent,
+        tool_progress_reporter: ToolProgressReporter | None = None,
+    ):
+        self.settings = ToolBuildConfig(
+            name=self.name,
+            configuration=config,  # type: ignore
+            # the ToolBuildConfig has a wrong type in it to be fixed later.
+        )
+
+        self.config = config
+        module_name = "default overwrite for module name"
+        self.logger = getLogger(f"{module_name}.{__name__}")
+        self.debug_info: dict = {}
+
+        # TODO: Remove these properties as soon as possible
+        self._event: ChatEvent = event
+        self._tool_progress_reporter: ToolProgressReporter | None = (
+            tool_progress_reporter
+        )
+
+        self._chat_service = ChatService(event)
+        self._language_model_service = LanguageModelService(event)
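One detail worth noting from the file above: Tool.tool_description_as_json accepts either a pydantic parameter model or an already-built JSON-schema dict and normalises both to a plain dict. A minimal sketch of that dispatch, under the assumption of pydantic v2; SearchParams and parameters_as_json are illustrative names, not part of the toolkit:

from typing import Any

from pydantic import BaseModel


# Hypothetical example parameter model.
class SearchParams(BaseModel):
    query: str
    top_k: int = 5


def parameters_as_json(parameters: type[BaseModel] | dict[str, Any]) -> dict[str, Any]:
    if not isinstance(parameters, dict):
        # Pydantic models expose their JSON schema directly.
        return parameters.model_json_schema()
    return parameters


assert "properties" in parameters_as_json(SearchParams)
assert parameters_as_json({"type": "object"}) == {"type": "object"}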
@@ -0,0 +1,242 @@
+import asyncio
+from logging import Logger, getLogger
+from pydantic import BaseModel, Field
+from unique_toolkit.app.schemas import ChatEvent
+from unique_toolkit.language_model.schemas import (
+    LanguageModelFunction,
+    LanguageModelToolDescription,
+)
+from unique_toolkit.tools.config import ToolBuildConfig
+from unique_toolkit.tools.factory import ToolFactory
+from unique_toolkit.tools.schemas import ToolCallResponse, ToolPrompts
+from unique_toolkit.tools.tool import Tool
+from unique_toolkit.tools.tool_progress_reporter import ToolProgressReporter
+from unique_toolkit.tools.utils.execution.execution import Result, SafeTaskExecutor
+
+
+class ForcedToolOption:
+    type: str = "function"
+
+    def __init__(self, name: str):
+        self.name = name
+
+
+class ToolManagerConfig(BaseModel):
+    tools: list[ToolBuildConfig] = Field(
+        default=[],
+        description="List of tools that the agent can use.",
+    )
+
+    max_tool_calls: int = Field(
+        default=10,
+        ge=1,
+        description="Maximum number of tool calls that can be executed in one iteration.",
+    )
+
+    def __init__(self, tools: list[ToolBuildConfig], max_tool_calls: int = 10):
+        self.tools = tools
+        self.max_tool_calls = max_tool_calls
+
+
+class ToolManager:
+    """
+    Manages the tools available to the agent and executes tool calls.
+
+    This class is responsible for:
+    - Initializing tools based on the provided configuration and runtime events.
+    - Filtering tools based on availability, exclusivity, and user-defined constraints.
+    - Managing the lifecycle of tools, including retrieval, execution, and logging.
+    - Executing tool calls in parallel when possible to optimize performance.
+    - Enforcing limits on the number of tool calls and handling duplicate requests.
+
+    Key Features:
+    - Dynamic Tool Initialization: Tools are dynamically selected and initialized
+      based on runtime events and user preferences.
+    - Parallel Execution: Supports asynchronous execution of tools for efficiency.
+    - Error Handling: Provides detailed error messages and logs for failed tool calls.
+    - Scalability: Designed to handle a large number of tools and tool calls efficiently.
+
+    Only the ToolManager is allowed to interact with the tools directly.
+    """
+
+    def __init__(
+        self,
+        logger: Logger,
+        config: ToolManagerConfig,
+        event: ChatEvent,
+        tool_progress_reporter: ToolProgressReporter,
+    ):
+        self._logger = logger
+        self._config = config
+        self._tool_progress_reporter = tool_progress_reporter
+        self._tools = []
+        self._tool_choices = event.payload.tool_choices
+        self._disabled_tools = event.payload.disabled_tools
+        self._init__tools(event)
+
+    def _init__tools(self, event: ChatEvent) -> None:
+        tool_choices = self._tool_choices
+        tool_configs = self._config.tools
+        self._logger.info("Initializing tool definitions...")
+        self._logger.info(f"Tool choices: {tool_choices}")
+        self._logger.info(f"Tool configs: {tool_configs}")
+
+        self.available_tools = [
+            ToolFactory.build_tool_with_settings(
+                t.name,
+                t,
+                t.configuration,
+                event,
+                tool_progress_reporter=self._tool_progress_reporter,
+            )
+            for t in tool_configs
+        ]
+
+        for t in self.available_tools:
+            if t.is_exclusive():
+                self._tools = [t]
+                return
+            if not t.is_enabled():
+                continue
+            if t.name in self._disabled_tools:
+                continue
+            if len(tool_choices) > 0 and t.name not in tool_choices:
+                continue
+
+            self._tools.append(t)
+
+    def log_loaded_tools(self):
+        self._logger.info(f"Loaded tools: {[tool.name for tool in self._tools]}")
+
+    def get_tools(self) -> list[Tool]:
+        return self._tools
+
+    def get_tool_by_name(self, name: str) -> Tool | None:
+        for tool in self._tools:
+            if tool.name == name:
+                return tool
+        return None
+
+    def get_forced_tools(self) -> list[ForcedToolOption]:
+        return [ForcedToolOption(t.name) for t in self._tools if t.name in self._tool_choices]
+
+    def get_tool_definitions(self) -> list[LanguageModelToolDescription]:
+        return [tool.tool_description() for tool in self._tools]
+
+    def get_tool_prompts(self) -> list[ToolPrompts]:
+        return [tool.get_tool_prompts() for tool in self._tools]
+
+    async def execute_selected_tools(
+        self,
+        tool_calls: list[LanguageModelFunction],
+    ) -> list[ToolCallResponse]:
+        tool_calls = tool_calls
+
+        tool_calls = self.filter_duplicate_tool_calls(
+            tool_calls=tool_calls,
+        )
+        num_tool_calls = len(tool_calls)
+
+        if num_tool_calls > self._config.max_tool_calls:
+            self._logger.warning(
+                (
+                    "Number of tool calls %s exceeds the allowed maximum of %s."
+                    "The tool calls will be reduced to the first %s."
+                ),
+                num_tool_calls,
+                self._config.max_tool_calls,
+                self._config.max_tool_calls,
+            )
+            tool_calls = tool_calls[: self._config.max_tool_calls]
+
+        tool_call_responses = await self._execute_parallelized(tool_calls=tool_calls)
+        return tool_call_responses
+
+    async def _execute_parallelized(
+        self,
+        tool_calls: list[LanguageModelFunction],
+    ) -> list[ToolCallResponse]:
+        self._logger.info("Execute tool calls")
+
+        task_executor = SafeTaskExecutor(
+            logger=self._logger,
+        )
+
+        # Create tasks for each tool call
+        tasks = [
+            task_executor.execute_async(
+                self.execute_tool_call,
+                tool_call=tool_call,
+            )
+            for tool_call in tool_calls
+        ]
+
+        # Wait until all tasks are finished
+        tool_call_results = await asyncio.gather(*tasks)
+        tool_call_results_unpacked: list[ToolCallResponse] = []
+        for i, result in enumerate(tool_call_results):
+            unpacked_tool_call_result = self._create_tool_call_response(
+                result, tool_calls[i]
+            )
+            tool_call_results_unpacked.append(unpacked_tool_call_result)
+
+        return tool_call_results_unpacked
+
+    async def execute_tool_call(
+        self, tool_call: LanguageModelFunction
+    ) -> ToolCallResponse:
+        self._logger.info(f"Processing tool call: {tool_call.name}")
+
+        tool_instance = self.get_tool_by_name(tool_call.name)
+
+        if tool_instance:
+            # Execute the tool
+            tool_response: ToolCallResponse = await tool_instance.run(
+                tool_call=tool_call
+            )
+            return tool_response
+
+        return ToolCallResponse(
+            id=tool_call.id,  # type: ignore
+            name=tool_call.name,
+            error_message=f"Tool of name {tool_call.name} not found",
+        )
+
+    def _create_tool_call_response(
+        self, result: Result[ToolCallResponse], tool_call: LanguageModelFunction
+    ) -> ToolCallResponse:
+        if not result.success:
+            return ToolCallResponse(
+                id=tool_call.id or "unknown_id",
+                name=tool_call.name,
+                error_message=str(result.exception),
+            )
+        unpacked = result.unpack()
+        if not isinstance(unpacked, ToolCallResponse):
+            return ToolCallResponse(
+                id=tool_call.id or "unknown_id",
+                name=tool_call.name,
+                error_message="Tool call response is not of type ToolCallResponse",
+            )
+        return unpacked
+
+    def filter_duplicate_tool_calls(
+        self,
+        tool_calls: list[LanguageModelFunction],
+    ) -> list[LanguageModelFunction]:
+        """
+        Filter out duplicate tool calls based on name and arguments.
+        """
+
+        unique_tool_calls = []
+
+        for call in tool_calls:
+            if all(not call == other_call for other_call in unique_tool_calls):
+                unique_tool_calls.append(call)
+
+        if len(tool_calls) != len(unique_tool_calls):
+            self._logger = getLogger(__name__)
+            self._logger.warning(
+                f"Filtered out {len(tool_calls) - len(unique_tool_calls)} duplicate tool calls."
+            )
+        return unique_tool_calls
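The ToolManager above deduplicates tool calls by plain equality before capping their number and running them in parallel. A self-contained sketch of that filtering, assuming that LanguageModelFunction instances compare equal when their name and arguments match (as pydantic models do); ToolCall below is a hypothetical stand-in, not the toolkit's schema:

from pydantic import BaseModel


# Hypothetical stand-in for LanguageModelFunction: pydantic models compare
# field-by-field, so two calls with the same name and arguments are equal.
class ToolCall(BaseModel):
    name: str
    arguments: dict


def filter_duplicate_tool_calls(tool_calls: list[ToolCall]) -> list[ToolCall]:
    # Keep only the first occurrence of each (name, arguments) combination,
    # mirroring ToolManager.filter_duplicate_tool_calls.
    unique_tool_calls: list[ToolCall] = []
    for call in tool_calls:
        if all(call != other for other in unique_tool_calls):
            unique_tool_calls.append(call)
    return unique_tool_calls


calls = [
    ToolCall(name="search", arguments={"query": "alpha"}),
    ToolCall(name="search", arguments={"query": "alpha"}),  # duplicate, dropped
    ToolCall(name="search", arguments={"query": "beta"}),
]
assert len(filter_duplicate_tool_calls(calls)) == 2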
@@ -88,28 +88,21 @@ class ToolProgressReporter:
             references=references,
         )
         self.requires_new_assistant_message = (
-            self.requires_new_assistant_message
-            or requires_new_assistant_message
+            self.requires_new_assistant_message or requires_new_assistant_message
         )
         await self.publish()
 
     async def publish(self):
         messages = []
         all_references = []
-        for item in sorted(
-            self.tool_statuses.values(), key=lambda x: x.timestamp
-        ):
+        for item in sorted(self.tool_statuses.values(), key=lambda x: x.timestamp):
             references = item.references
             start_number = len(all_references) + 1
             message = self._replace_placeholders(item.message, start_number)
-            references = self._correct_reference_sequence(
-                references, start_number
-            )
+            references = self._correct_reference_sequence(references, start_number)
             all_references.extend(references)
 
-            messages.append(
-                f"{ARROW}**{item.name} {item.state.value}**: {message}"
-            )
+            messages.append(f"{ARROW}**{item.name} {item.state.value}**: {message}")
 
         await self.chat_service.modify_assistant_message_async(
             content=self._progress_start_text + "\n\n" + "\n\n".join(messages),