unique_toolkit 0.8.12__py3-none-any.whl → 0.8.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,174 @@
1
+ import json
2
+ import logging
3
+ from copy import deepcopy
4
+
5
+ from unique_toolkit.content.schemas import ContentChunk, ContentMetadata
6
+ from unique_toolkit.language_model.schemas import (
7
+ LanguageModelAssistantMessage,
8
+ LanguageModelMessage,
9
+ LanguageModelToolMessage,
10
+ )
11
+ from unique_toolkit.tools.schemas import Source
12
+ from unique_toolkit.tools.utils.source_handling.schema import (
13
+ SourceFormatConfig,
14
+ )
15
+
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def convert_tool_interactions_to_content_messages(
21
+ loop_history: list[LanguageModelMessage],
22
+ ) -> list[LanguageModelMessage]:
23
+ new_loop_history = []
24
+ copy_loop_history = deepcopy(loop_history)
25
+
26
+ for message in copy_loop_history:
27
+ if isinstance(message, LanguageModelAssistantMessage) and message.tool_calls:
28
+ new_loop_history.append(_convert_tool_call_to_content(message))
29
+
30
+ elif isinstance(message, LanguageModelToolMessage):
31
+ new_loop_history.append(_convert_tool_call_response_to_content(message))
32
+ else:
33
+ new_loop_history.append(message)
34
+
35
+ return new_loop_history
36
+
37
+
38
+ def _convert_tool_call_to_content(
39
+ assistant_message: LanguageModelAssistantMessage,
40
+ ) -> LanguageModelAssistantMessage:
41
+ assert assistant_message.tool_calls is not None
42
+ new_content = "The assistant requested the following tool_call:"
43
+ for tool_call in assistant_message.tool_calls:
44
+ new_content += (
45
+ f"\n\n- {tool_call.function.name}: {tool_call.function.arguments}"
46
+ )
47
+ assistant_message.tool_calls = None
48
+ assistant_message.content = new_content
49
+
50
+ return assistant_message
51
+
52
+
53
+ def _convert_tool_call_response_to_content(
54
+ tool_message: LanguageModelToolMessage,
55
+ ) -> LanguageModelAssistantMessage:
56
+ new_content = f"The assistant received the following tool_call_response: {tool_message.name}, {tool_message.content}"
57
+ assistant_message = LanguageModelAssistantMessage(
58
+ content=new_content,
59
+ )
60
+ return assistant_message
61
+
62
+
63
+ def transform_chunks_to_string(
64
+ content_chunks: list[ContentChunk],
65
+ max_source_number: int,
66
+ cfg: SourceFormatConfig | None,
67
+ full_sources_serialize_dump: bool = False,
68
+ ) -> str:
69
+ """Transform content chunks into a string of sources.
70
+
71
+ Args:
72
+ content_chunks (list[ContentChunk]): The content chunks to transform
73
+ max_source_number (int): The maximum source number to use
74
+ cfg (SourceFormatConfig | None): Optional configuration controlling how chunk metadata is formatted.
+ full_sources_serialize_dump (bool, optional): Whether to serialize the full Source object. Defaults to False, which keeps the old format.
75
+
76
+ Returns:
77
+ str: String for the tool call response
78
+ """
79
+ if not content_chunks:
80
+ return "No relevant sources found."
81
+ if full_sources_serialize_dump:
82
+ sources = [
83
+ Source(
84
+ source_number=max_source_number + i,
85
+ key=chunk.key,
86
+ id=chunk.id,
87
+ order=chunk.order,
88
+ content=chunk.text,
89
+ chunk_id=chunk.chunk_id,
90
+ metadata=(
91
+ _format_metadata(chunk.metadata, cfg) or None
92
+ if chunk.metadata
93
+ else None
94
+ ),
95
+ url=chunk.url,
96
+ ).model_dump(
97
+ exclude_none=True,
98
+ exclude_defaults=True,
99
+ by_alias=True,
100
+ )
101
+ for i, chunk in enumerate(content_chunks)
102
+ ]
103
+ else:
104
+ sources = [
105
+ {
106
+ "source_number": max_source_number + i,
107
+ "content": chunk.text,
108
+ **(
109
+ {"metadata": meta}
110
+ if (
111
+ meta := _format_metadata(chunk.metadata, cfg)
112
+ ) # only add when not empty
113
+ else {}
114
+ ),
115
+ }
116
+ for i, chunk in enumerate(content_chunks)
117
+ ]
118
+ return json.dumps(sources)
119
+
120
+
121
+ def load_sources_from_string(
122
+ source_string: str,
123
+ ) -> list[Source] | None:
124
+ """Transform JSON string from language model tool message in the tool call response into Source objects"""
125
+
126
+ try:
127
+ # Parse as JSON (new format); return None if parsing or validation fails
128
+ sources_data = json.loads(source_string)
129
+ return [Source.model_validate(source) for source in sources_data]
130
+ except (json.JSONDecodeError, ValueError):
131
+ logger.warning("Failed to parse source string")
132
+ return None
133
+
134
+
135
+ def _format_metadata(
136
+ metadata: ContentMetadata | None,
137
+ cfg: SourceFormatConfig | None,
138
+ ) -> str:
139
+ """
140
+ Build the concatenated tag string from the chunk's metadata
141
+ and the templates found in cfg.sections.
142
+ Example result:
143
+ "<|topic|>GenAI<|/topic|>\n<|date|>This document is from: 2025-04-29<|/date|>\n"
144
+ """
145
+ if metadata is None:
146
+ return ""
147
+
148
+ if cfg is None or not cfg.sections:
149
+ # If no configuration is provided, return empty string
150
+ return ""
151
+
152
+ meta_dict = metadata.model_dump(exclude_none=True, by_alias=True)
153
+ sections = cfg.sections
154
+
155
+ parts: list[str] = []
156
+ for key, template in sections.items():
157
+ if key in meta_dict:
158
+ parts.append(template.format(meta_dict[key]))
159
+
160
+ return "".join(parts)
161
+
162
+
163
+ ### If we do not want any formatting of the metadata, we could use this function instead
164
+ # def _filtered_metadata(
165
+ # meta: ContentMetadata | None,
166
+ # cfg: SourceFormatConfig,
167
+ # ) -> dict[str, str] | None:
168
+ # if meta is None:
169
+ # return None
170
+
171
+ # allowed = set(cfg.sections)
172
+ # raw = meta.model_dump(exclude_none=True, by_alias=True)
173
+ # pruned = {k: v for k, v in raw.items() if k in allowed}
174
+ # return pruned or None
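A minimal usage sketch of the two helpers above; the ContentChunk constructor arguments are assumptions inferred from the attributes the code reads (id, chunk_id, key, order, text, url), and only the full_sources_serialize_dump path is exercised.

from unique_toolkit.content.schemas import ContentChunk

# Assumed ContentChunk fields; the real model may require additional arguments.
chunks = [
    ContentChunk(
        id="cont_1",
        chunk_id="chunk_1",
        key="report.pdf",
        order=0,
        text="Revenue grew 12% year over year.",
        url=None,
    )
]

# Serialize the chunks into the tool-call response string (full Source objects).
source_string = transform_chunks_to_string(
    content_chunks=chunks,
    max_source_number=1,
    cfg=None,  # no metadata formatting configured
    full_sources_serialize_dump=True,
)

# Later, recover Source objects from the tool message content.
sources = load_sources_from_string(source_string)
if sources:
    print(sources[0].source_number, sources[0].content)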
@@ -0,0 +1,122 @@
1
+ from abc import ABC
2
+ import asyncio
3
+ from logging import Logger
4
+
5
+ from unique_toolkit.chat.service import ChatService
6
+ from unique_toolkit.language_model.schemas import (
7
+ LanguageModelMessage,
8
+ LanguageModelStreamResponse,
9
+ )
10
+ from unique_toolkit.tools.utils.execution.execution import Result, SafeTaskExecutor
11
+
12
+
13
+ class Postprocessor(ABC):
14
+ def __init__(self, name: str):
15
+ self.name = name
16
+
17
+ def get_name(self) -> str:
18
+ return self.name
19
+
20
+ async def run(self, loop_response: LanguageModelStreamResponse) -> str:
21
+ raise NotImplementedError("Subclasses must implement this method.")
22
+
23
+ async def apply_postprocessing_to_response(
24
+ self, loop_response: LanguageModelStreamResponse
25
+ ) -> bool:
26
+ raise NotImplementedError(
27
+ "Subclasses must implement this method to apply post-processing to the response."
28
+ )
29
+
30
+ async def remove_from_text(self, text: str) -> str:
31
+ raise NotImplementedError(
32
+ "Subclasses must implement this method to remove post-processing from the message."
33
+ )
34
+
35
+
36
+ class PostprocessorManager:
37
+ """
38
+ Manages and executes postprocessors for modifying and refining responses.
39
+
40
+ This class is responsible for:
41
+ - Storing and managing a collection of postprocessor instances.
42
+ - Executing postprocessors asynchronously to refine loop responses.
43
+ - Applying modifications to assistant messages based on postprocessor results.
44
+ - Providing utility methods for text manipulation using postprocessors.
45
+
46
+ Key Features:
47
+ - Postprocessor Management: Allows adding and retrieving postprocessor instances.
48
+ - Asynchronous Execution: Runs all postprocessors concurrently for efficiency.
49
+ - Response Modification: Applies postprocessing changes to assistant messages when necessary.
50
+ - Text Cleanup: Supports removing specific patterns or content from text using postprocessors.
51
+ - Error Handling: Logs warnings for any postprocessors that fail during execution.
52
+
53
+ The PostprocessorManager serves as a centralized system for managing and applying postprocessing logic to enhance response quality and consistency.
54
+ """
55
+
56
+ _postprocessors: list[Postprocessor]
57
+
58
+ def __init__(
59
+ self,
60
+ logger: Logger,
61
+ chat_service: ChatService,
62
+ ):
63
+ self._logger = logger
64
+ self._chat_service = chat_service
+ # initialized per instance so postprocessors are not shared across managers
+ self._postprocessors = []
65
+
66
+ def add_postprocessor(self, postprocessor: Postprocessor):
67
+ self._postprocessors.append(postprocessor)
68
+
69
+ def get_postprocessors(self, name: str) -> list[Postprocessor]:
70
+ return self._postprocessors
71
+
72
+ async def run_postprocessors(
73
+ self,
74
+ loop_response: LanguageModelStreamResponse,
75
+ ) -> None:
76
+ task_executor = SafeTaskExecutor(
77
+ logger=self._logger,
78
+ )
79
+
80
+ tasks = [
81
+ task_executor.execute_async(
82
+ self.execute_postprocessors,
83
+ loop_response=loop_response,
84
+ postprocessor_instance=postprocessor,
85
+ )
86
+ for postprocessor in self._postprocessors
87
+ ]
88
+ postprocessor_results = await asyncio.gather(*tasks)
89
+
90
+ for i, result in enumerate(postprocessor_results):
91
+ if not result.success:
92
+ self._logger.warning(
93
+ f"Postprocessor {self._postprocessors[i].get_name()} failed to run."
94
+ )
95
+
96
+ modification_results = await asyncio.gather(*[
97
+ postprocessor.apply_postprocessing_to_response(loop_response)
98
+ for postprocessor in self._postprocessors
99
+ ])
100
+
101
+ has_been_modified = any(modification_results)
102
+
103
+ if has_been_modified:
104
+ self._chat_service.modify_assistant_message(
105
+ content=loop_response.message.text,
106
+ message_id=loop_response.message.id,
107
+ )
108
+
109
+ async def execute_postprocessors(
110
+ self,
111
+ loop_response: LanguageModelStreamResponse,
112
+ postprocessor_instance: Postprocessor,
113
+ ) -> None:
114
+ await postprocessor_instance.run(loop_response)
115
+
116
+ async def remove_from_text(
117
+ self,
118
+ text: str,
119
+ ) -> str:
120
+ for postprocessor in self._postprocessors:
121
+ text = await postprocessor.remove_from_text(text)
122
+ return text
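A hedged sketch of how a concrete postprocessor could plug into the manager above; the DisclaimerPostprocessor class and its disclaimer text are illustrative assumptions, not part of the package.

class DisclaimerPostprocessor(Postprocessor):
    """Illustrative subclass: appends a short disclaimer to the final answer."""

    DISCLAIMER = "\n\n_Generated content, please verify._"

    def __init__(self):
        super().__init__(name="disclaimer")

    async def run(self, loop_response: LanguageModelStreamResponse) -> str:
        # Inspect or enrich the streamed response here; nothing to do in this sketch.
        return loop_response.message.text or ""

    async def apply_postprocessing_to_response(
        self, loop_response: LanguageModelStreamResponse
    ) -> bool:
        if loop_response.message.text and self.DISCLAIMER not in loop_response.message.text:
            loop_response.message.text += self.DISCLAIMER
            return True  # signals that the assistant message needs to be rewritten
        return False

    async def remove_from_text(self, text: str) -> str:
        return text.replace(self.DISCLAIMER, "")


# Assumed instances for illustration:
# manager = PostprocessorManager(logger=logger, chat_service=chat_service)
# manager.add_postprocessor(DisclaimerPostprocessor())
# await manager.run_postprocessors(loop_response)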
@@ -9,6 +9,25 @@ class tool_chunks:
9
9
 
10
10
 
11
11
  class ReferenceManager:
12
+ """
13
+ Manages content chunks and references extracted from tool responses.
14
+
15
+ This class is responsible for:
16
+ - Extracting and storing referenceable content chunks from tool responses.
17
+ - Managing a collection of content chunks and their associated references.
18
+ - Providing methods to retrieve, replace, and manipulate chunks and references.
19
+ - Supporting the retrieval of the latest references and their corresponding chunks.
20
+
21
+ Key Features:
22
+ - Chunk Extraction: Extracts content chunks from tool responses and organizes them for reference.
23
+ - Reference Management: Tracks references to content chunks and allows for easy retrieval.
24
+ - Latest Reference Access: Provides methods to fetch the most recent references and their associated chunks.
25
+ - Flexible Chunk Replacement: Allows for replacing the current set of chunks with a new list.
26
+ - Reference-to-Chunk Mapping: Matches references to their corresponding chunks based on source IDs.
27
+
28
+ The ReferenceManager serves as a utility for managing and linking content chunks with references, enabling efficient content tracking and retrieval.
29
+ """
30
+
12
31
  def __init__(self):
13
32
  self._tool_chunks: dict[str, tool_chunks] = {}
14
33
  self._chunks: list[ContentChunk] = []
@@ -0,0 +1,140 @@
1
+ import base64
2
+ import zlib
3
+ from logging import getLogger
4
+ from typing import Generic, Type, TypeVar
5
+
6
+ from pydantic import BaseModel
7
+ from unique_toolkit.short_term_memory.schemas import ShortTermMemory
8
+ from unique_toolkit.short_term_memory.service import ShortTermMemoryService
9
+ from unique_toolkit.tools.utils.execution.execution import SafeTaskExecutor
10
+
11
+
12
+ TSchema = TypeVar("TSchema", bound=BaseModel)
13
+
14
+
15
+ logger = getLogger(__name__)
16
+
17
+
18
+ def _default_short_term_memory_name(schema: type[BaseModel]) -> str:
19
+ return f"{schema.__name__}Key"
20
+
21
+
22
+ def _compress_data_zlib_base64(data: str) -> str:
23
+ """Compress data using ZLIB and encode as base64 string."""
24
+ compressed = zlib.compress(data.encode("utf-8"))
25
+ return base64.b64encode(compressed).decode("utf-8")
26
+
27
+
28
+ def _decompress_data_zlib_base64(compressed_data: str) -> str:
29
+ """Decompress base64 encoded ZLIB data."""
30
+ decoded = base64.b64decode(compressed_data.encode("utf-8"))
31
+ return zlib.decompress(decoded).decode("utf-8")
32
+
33
+
34
+ class PersistentShortMemoryManager(Generic[TSchema]):
35
+ """
36
+ Manages the storage, retrieval, and processing of short-term memory in a persistent manner.
37
+
38
+ This class is responsible for:
39
+ - Saving and loading short-term memory data, both synchronously and asynchronously.
40
+ - Compressing and decompressing memory data for efficient storage.
41
+ - Validating and processing memory data using a predefined schema.
42
+ - Logging the status of memory operations, such as whether memory was found or saved.
43
+
44
+ Key Features:
45
+ - Persistent Storage: Integrates with a short-term memory service to store and retrieve memory data.
46
+ - Compression Support: Compresses memory data before saving and decompresses it upon retrieval.
47
+ - Schema Validation: Ensures memory data adheres to a specified schema for consistency.
48
+ - Synchronous and Asynchronous Operations: Supports both sync and async methods for flexibility.
49
+ - Logging and Debugging: Provides detailed logs for memory operations, including success and failure cases.
50
+
51
+ The PersistentShortMemoryManager is designed to handle short-term memory efficiently, ensuring data integrity and optimized storage.
52
+ """
53
+ def __init__(
54
+ self,
55
+ short_term_memory_service: ShortTermMemoryService,
56
+ short_term_memory_schema: Type[TSchema],
57
+ short_term_memory_name: str | None = None,
58
+ ) -> None:
59
+ self._short_term_memory_name = (
60
+ short_term_memory_name
61
+ if short_term_memory_name
62
+ else _default_short_term_memory_name(short_term_memory_schema)
63
+ )
64
+ self._short_term_memory_schema = short_term_memory_schema
65
+ self._short_term_memory_service = short_term_memory_service
66
+
67
+ self._executor = SafeTaskExecutor(
68
+ log_exceptions=False,
69
+ )
70
+
71
+ def _log_not_found(self) -> None:
72
+ logger.warning(
73
+ f"No short term memory found for chat {self._short_term_memory_service.chat_id} and key {self._short_term_memory_name}"
74
+ )
75
+
76
+ def _log_found(self) -> None:
77
+ logger.debug(
78
+ f"Short term memory found for chat {self._short_term_memory_service.chat_id} and key {self._short_term_memory_name}"
79
+ )
80
+
81
+ def _find_latest_memory_sync(self) -> ShortTermMemory | None:
82
+ result = self._executor.execute(
83
+ self._short_term_memory_service.find_latest_memory,
84
+ self._short_term_memory_name,
85
+ )
86
+
87
+ self._log_not_found() if not result.success else self._log_found()
88
+
89
+ return result.unpack(default=None)
90
+
91
+ async def _find_latest_memory_async(self) -> ShortTermMemory | None:
92
+ result = await self._executor.execute_async(
93
+ self._short_term_memory_service.find_latest_memory_async,
94
+ self._short_term_memory_name,
95
+ )
96
+
97
+ self._log_not_found() if not result.success else self._log_found()
98
+
99
+ return result.unpack(default=None)
100
+
101
+ def save_sync(self, short_term_memory: TSchema) -> None:
102
+ json_data = short_term_memory.model_dump_json()
103
+ compressed_data = _compress_data_zlib_base64(json_data)
104
+ logger.info(
105
+ f"Saving memory with {len(compressed_data)} characters compressed from {len(json_data)} characters for memory {self._short_term_memory_name}"
106
+ )
107
+ self._short_term_memory_service.create_memory(
108
+ key=self._short_term_memory_name,
109
+ value=compressed_data,
110
+ )
111
+
112
+ async def save_async(self, short_term_memory: TSchema) -> None:
113
+ json_data = short_term_memory.model_dump_json()
114
+ compressed_data = _compress_data_zlib_base64(json_data)
115
+ logger.info(
116
+ f"Saving memory with {len(compressed_data)} characters compressed from {len(json_data)} characters for memory {self._short_term_memory_name}"
117
+ )
118
+ await self._short_term_memory_service.create_memory_async(
119
+ key=self._short_term_memory_name,
120
+ value=compressed_data,
121
+ )
122
+
123
+ def _process_compressed_memory(
124
+ self, memory: ShortTermMemory | None
125
+ ) -> TSchema | None:
126
+ if memory is not None and memory.data is not None:
127
+ if isinstance(memory.data, str):
128
+ data = _decompress_data_zlib_base64(memory.data)
129
+ return self._short_term_memory_schema.model_validate_json(data)
130
+ elif isinstance(memory.data, dict):
131
+ return self._short_term_memory_schema.model_validate(memory.data)
132
+ return None
133
+
134
+ def load_sync(self) -> TSchema | None:
135
+ memory: ShortTermMemory | None = self._find_latest_memory_sync()
136
+ return self._process_compressed_memory(memory)
137
+
138
+ async def load_async(self) -> TSchema | None:
139
+ memory: ShortTermMemory | None = await self._find_latest_memory_async()
140
+ return self._process_compressed_memory(memory)
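A hedged usage sketch of the manager above; the ConversationState schema and the injected ShortTermMemoryService instance are illustrative assumptions, while save_async/load_async are taken from the class as written.

class ConversationState(BaseModel):
    """Illustrative schema for the short-term memory payload."""

    last_topic: str
    open_questions: list[str] = []


async def remember_conversation(service: ShortTermMemoryService) -> ConversationState | None:
    manager = PersistentShortMemoryManager(
        short_term_memory_service=service,
        short_term_memory_schema=ConversationState,
        # short_term_memory_name defaults to "ConversationStateKey"
    )
    await manager.save_async(
        ConversationState(last_topic="pricing", open_questions=["renewal date?"])
    )
    # Loads the latest memory, decompresses it, and validates it against the schema.
    return await manager.load_async()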
@@ -0,0 +1,102 @@
1
+ from logging import Logger
2
+ from pydantic import BaseModel, Field
3
+
4
+ from unique_toolkit.chat.service import ChatService
5
+ from unique_toolkit.language_model.schemas import (
6
+ LanguageModelAssistantMessage,
7
+ LanguageModelStreamResponse,
8
+ )
9
+ from unique_toolkit.tools.tool_progress_reporter import (
10
+ ToolProgressReporter,
11
+ )
12
+
13
+
14
+ class ThinkingManagerConfig(BaseModel):
15
+ thinking_steps_display: bool = Field(
16
+ default=True, description="Whether to display thinking steps in the chat."
17
+ )
18
+
19
+
20
+ class ThinkingManager:
21
+ """
22
+ Manages the display and tracking of thinking steps during response generation.
23
+
24
+ This class is responsible for:
25
+ - Tracking and formatting thinking steps as part of the response process.
26
+ - Updating the tool progress reporter with the latest thinking step information.
27
+ - Managing the display of thinking steps in the assistant's response.
28
+ - Closing and finalizing the thinking steps section when the process is complete.
29
+
30
+ Key Features:
31
+ - Thinking Step Tracking: Maintains a sequential log of thinking steps with step numbers.
32
+ - Configurable Display: Supports enabling or disabling the display of thinking steps based on configuration.
33
+ - Integration with Tool Progress: Updates the tool progress reporter to reflect the current thinking state.
34
+ - Dynamic Response Updates: Modifies the assistant's response to include or finalize thinking steps.
35
+ - Flexible Formatting: Formats thinking steps in a structured and user-friendly HTML-like format.
36
+
37
+ The ThinkingManager enhances transparency and user understanding by providing a clear view of the assistant's reasoning process.
38
+ """
39
+ def __init__(
40
+ self,
41
+ logger: Logger,
42
+ config: ThinkingManagerConfig,
43
+ tool_progress_reporter: ToolProgressReporter,
44
+ chat_service: ChatService,
45
+ ):
46
+ self._chat_service = chat_service
47
+ self._config = config
48
+ self._thinking_steps = ""
49
+ self._thinking_step_number = 1
50
+ self._tool_progress_reporter = tool_progress_reporter
51
+
52
+ def thinking_is_displayed(self) -> bool:
53
+ return self._config.thinking_steps_display
54
+
55
+ def update_tool_progress_reporter(self, loop_response: LanguageModelStreamResponse):
56
+ if self._config.thinking_steps_display and (
57
+ loop_response.message.text
58
+ != self._tool_progress_reporter._progress_start_text
59
+ ):
60
+ self._tool_progress_reporter.tool_statuses = {}
61
+ self._tool_progress_reporter._progress_start_text = (
62
+ loop_response.message.text
63
+ )
64
+
65
+ def update_start_text(
66
+ self, start_text: str, loop_response: LanguageModelStreamResponse
67
+ ) -> str:
68
+ if not self._config.thinking_steps_display:
69
+ return start_text
70
+ if not loop_response.message.original_text:
71
+ return start_text
72
+ if loop_response.message.original_text == "":
73
+ return start_text
74
+
75
+ update_message = loop_response.message.original_text
76
+
77
+ if start_text == "":
78
+ self._thinking_steps = f"\n<i><b>Step 1:</b>\n{update_message}</i>\n"
79
+ start_text = f"""<details open>\n<summary><b>Thinking steps</b></summary>\n{self._thinking_steps}\n</details>\n\n---\n\n"""
80
+ else:
81
+ self._thinking_steps += f"\n\n<i><b>Step {self._thinking_step_number}:</b>\n{update_message}</i>\n\n"
82
+ start_text = f"""<details open>\n<summary><b>Thinking steps</b></summary>\n<i>{self._thinking_steps}\n\n</i>\n</details>\n\n---\n\n"""
83
+
84
+ self._thinking_step_number += 1
85
+ return start_text
86
+
87
+ def close_thinking_steps(self, loop_response: LanguageModelStreamResponse):
88
+ if not self._config.thinking_steps_display:
89
+ return
90
+ if not self._thinking_steps:
91
+ return
92
+ if not loop_response.message.text:
93
+ return
94
+ if not loop_response.message.text.startswith("<details open>"):
95
+ return
96
+
97
+ loop_response.message.text = loop_response.message.text.replace(
98
+ "<details open>", "<details>"
99
+ )
100
+
101
+ self._chat_service.modify_assistant_message(content=loop_response.message.text)
102
+ return
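A hedged sketch of the intended call sequence; the pre-built ThinkingManager instance and the list of collected stream responses are assumptions for illustration.

def render_with_thinking_steps(
    manager: ThinkingManager,
    loop_responses: list[LanguageModelStreamResponse],  # assumed to be collected elsewhere
) -> None:
    start_text = ""
    for loop_response in loop_responses:
        # Fold each intermediate answer into the "Thinking steps" <details> block.
        start_text = manager.update_start_text(start_text, loop_response)
        manager.update_tool_progress_reporter(loop_response)
    if loop_responses:
        # Collapse the block once the final answer has been streamed.
        manager.close_thinking_steps(loop_responses[-1])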
@@ -134,5 +134,4 @@ class ToolPrompts(BaseModel):
134
134
  display_name: str
135
135
  tool_description: str
136
136
  tool_format_information_for_system_prompt: str
137
- tool_format_information_for_system_prompt: str
138
137
  input_model: dict[str, Any]
@@ -18,7 +18,7 @@ from unique_toolkit.language_model.schemas import (
18
18
  from unique_toolkit.language_model.service import LanguageModelService
19
19
 
20
20
 
21
- from unique_toolkit.evaluators.schemas import EvaluationMetricName
21
+ from unique_toolkit.evals.schemas import EvaluationMetricName
22
22
  from unique_toolkit.tools.agent_chunks_handler import AgentChunksHandler
23
23
  from unique_toolkit.tools.config import ToolBuildConfig, ToolSelectionPolicy
24
24
  from unique_toolkit.tools.schemas import BaseToolConfig, ToolCallResponse, ToolPrompts
@@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
4
4
  from unique_toolkit.app.schemas import ChatEvent
5
5
  from unique_toolkit.language_model.schemas import (
6
6
  LanguageModelFunction,
7
+ LanguageModelTool,
7
8
  LanguageModelToolDescription,
8
9
  )
9
10
  from unique_toolkit.tools.config import ToolBuildConfig
@@ -12,6 +13,7 @@ from unique_toolkit.tools.schemas import ToolCallResponse, ToolPrompts
12
13
  from unique_toolkit.tools.tool import Tool
13
14
  from unique_toolkit.tools.tool_progress_reporter import ToolProgressReporter
14
15
  from unique_toolkit.tools.utils.execution.execution import Result, SafeTaskExecutor
16
+ from unique_toolkit.evals.schemas import EvaluationMetricName
15
17
 
16
18
 
17
19
  class ForcedToolOption:
@@ -67,16 +69,18 @@ class ToolManager:
67
69
  tool_progress_reporter: ToolProgressReporter,
68
70
  ):
69
71
  self._logger = logger
70
- self._config = config
72
+ self._config = config
71
73
  self._tool_progress_reporter = tool_progress_reporter
72
74
  self._tools = []
73
75
  self._tool_choices = event.payload.tool_choices
74
76
  self._disabled_tools = event.payload.disabled_tools
77
+ # use a set to avoid duplicate evaluation metrics
78
+ self._tool_evaluation_check_list: set[EvaluationMetricName] = set()
75
79
  self._init__tools(event)
76
80
 
77
- def _init__tools(self, event: ChatEvent) -> None:
81
+ def _init__tools(self, event: ChatEvent) -> None:
78
82
  tool_choices = self._tool_choices
79
- tool_configs = self._config .tools
83
+ tool_configs = self._config.tools
80
84
  self._logger.info("Initializing tool definitions...")
81
85
  self._logger.info(f"Tool choices: {tool_choices}")
82
86
  self._logger.info(f"Tool configs: {tool_configs}")
@@ -105,6 +109,9 @@ class ToolManager:
105
109
 
106
110
  self._tools.append(t)
107
111
 
112
+ def get_evaluation_check_list(self) -> list[EvaluationMetricName]:
113
+ return list(self._tool_evaluation_check_list)
114
+
108
115
  def log_loaded_tools(self):
109
116
  self._logger.info(f"Loaded tools: {[tool.name for tool in self._tools]}")
110
117
 
@@ -118,9 +125,15 @@ class ToolManager:
118
125
  return None
119
126
 
120
127
  def get_forced_tools(self) -> list[ForcedToolOption]:
121
- return [ForcedToolOption(t.name) for t in self._tools if t.name in self._tool_choices]
128
+ return [
129
+ ForcedToolOption(t.name)
130
+ for t in self._tools
131
+ if t.name in self._tool_choices
132
+ ]
122
133
 
123
- def get_tool_definitions(self) -> list[LanguageModelToolDescription]:
134
+ def get_tool_definitions(
135
+ self,
136
+ ) -> list[LanguageModelTool | LanguageModelToolDescription]:
124
137
  return [tool.tool_description() for tool in self._tools]
125
138
 
126
139
  def get_tool_prompts(self) -> list[ToolPrompts]:
@@ -137,19 +150,19 @@ class ToolManager:
137
150
  )
138
151
  num_tool_calls = len(tool_calls)
139
152
 
140
- if num_tool_calls > self._config .max_tool_calls:
153
+ if num_tool_calls > self._config.max_tool_calls:
141
154
  self._logger.warning(
142
155
  (
143
156
  "Number of tool calls %s exceeds the allowed maximum of %s."
144
157
  "The tool calls will be reduced to the first %s."
145
158
  ),
146
159
  num_tool_calls,
147
- self._config .max_tool_calls,
148
- self._config .max_tool_calls,
160
+ self._config.max_tool_calls,
161
+ self._config.max_tool_calls,
149
162
  )
150
- tool_calls = tool_calls[: self._config .max_tool_calls]
163
+ tool_calls = tool_calls[: self._config.max_tool_calls]
151
164
 
152
- tool_call_responses = await self._execute_parallelized(tool_calls=tool_calls)
165
+ tool_call_responses = await self._execute_parallelized(tool_calls)
153
166
  return tool_call_responses
154
167
 
155
168
  async def _execute_parallelized(
@@ -194,6 +207,9 @@ class ToolManager:
194
207
  tool_response: ToolCallResponse = await tool_instance.run(
195
208
  tool_call=tool_call
196
209
  )
210
+ evaluation_checks = tool_instance.evaluation_check_list()
211
+ self._tool_evaluation_check_list.update(evaluation_checks)
212
+
197
213
  return tool_response
198
214
 
199
215
  return ToolCallResponse(