unique_toolkit 0.8.11__py3-none-any.whl → 0.8.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unique_toolkit/_common/validators.py +37 -2
- unique_toolkit/content/service.py +15 -1
- unique_toolkit/embedding/service.py +11 -0
- unique_toolkit/evals/evaluation_manager.py +206 -0
- unique_toolkit/evals/exception.py +5 -0
- unique_toolkit/evals/schemas.py +100 -0
- unique_toolkit/history_manager/history_manager.py +261 -0
- unique_toolkit/history_manager/utils.py +174 -0
- unique_toolkit/language_model/service.py +11 -0
- unique_toolkit/postprocessor/postprocessor_manager.py +122 -0
- unique_toolkit/reference_manager/reference_manager.py +19 -0
- unique_toolkit/short_term_memory/persistent_short_term_memory_manager.py +140 -0
- unique_toolkit/thinking_manager/thinking_manager.py +102 -0
- unique_toolkit/tools/schemas.py +0 -1
- unique_toolkit/tools/tool.py +1 -1
- unique_toolkit/tools/tool_manager.py +26 -10
- {unique_toolkit-0.8.11.dist-info → unique_toolkit-0.8.13.dist-info}/METADATA +13 -1
- {unique_toolkit-0.8.11.dist-info → unique_toolkit-0.8.13.dist-info}/RECORD +20 -12
- {unique_toolkit-0.8.11.dist-info → unique_toolkit-0.8.13.dist-info}/LICENSE +0 -0
- {unique_toolkit-0.8.11.dist-info → unique_toolkit-0.8.13.dist-info}/WHEEL +0 -0
unique_toolkit/history_manager/utils.py ADDED

```diff
@@ -0,0 +1,174 @@
+import json
+import logging
+from copy import deepcopy
+
+from unique_toolkit.content.schemas import ContentChunk, ContentMetadata
+from unique_toolkit.language_model.schemas import (
+    LanguageModelAssistantMessage,
+    LanguageModelMessage,
+    LanguageModelToolMessage,
+)
+from unique_toolkit.tools.schemas import Source
+from unique_toolkit.tools.utils.source_handling.schema import (
+    SourceFormatConfig,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def convert_tool_interactions_to_content_messages(
+    loop_history: list[LanguageModelMessage],
+) -> list[LanguageModelMessage]:
+    new_loop_history = []
+    copy_loop_history = deepcopy(loop_history)
+
+    for message in copy_loop_history:
+        if isinstance(message, LanguageModelAssistantMessage) and message.tool_calls:
+            new_loop_history.append(_convert_tool_call_to_content(message))
+
+        elif isinstance(message, LanguageModelToolMessage):
+            new_loop_history.append(_convert_tool_call_response_to_content(message))
+        else:
+            new_loop_history.append(message)
+
+    return new_loop_history
+
+
+def _convert_tool_call_to_content(
+    assistant_message: LanguageModelAssistantMessage,
+) -> LanguageModelAssistantMessage:
+    assert assistant_message.tool_calls is not None
+    new_content = "The assistant requested the following tool_call:"
+    for tool_call in assistant_message.tool_calls:
+        new_content += (
+            f"\n\n- {tool_call.function.name}: {tool_call.function.arguments}"
+        )
+    assistant_message.tool_calls = None
+    assistant_message.content = new_content
+
+    return assistant_message
+
+
+def _convert_tool_call_response_to_content(
+    tool_message: LanguageModelToolMessage,
+) -> LanguageModelAssistantMessage:
+    new_content = f"The assistant received the following tool_call_response: {tool_message.name}, {tool_message.content}"
+    assistant_message = LanguageModelAssistantMessage(
+        content=new_content,
+    )
+    return assistant_message
+
+
+def transform_chunks_to_string(
+    content_chunks: list[ContentChunk],
+    max_source_number: int,
+    cfg: SourceFormatConfig | None,
+    full_sources_serialize_dump: bool = False,
+) -> str:
+    """Transform content chunks into a string of sources.
+
+    Args:
+        content_chunks (list[ContentChunk]): The content chunks to transform
+        max_source_number (int): The maximum source number to use
+        feature_full_sources (bool, optional): Whether to include the full source object. Defaults to False which is the old format.
+
+    Returns:
+        str: String for the tool call response
+    """
+    if not content_chunks:
+        return "No relevant sources found."
+    if full_sources_serialize_dump:
+        sources = [
+            Source(
+                source_number=max_source_number + i,
+                key=chunk.key,
+                id=chunk.id,
+                order=chunk.order,
+                content=chunk.text,
+                chunk_id=chunk.chunk_id,
+                metadata=(
+                    _format_metadata(chunk.metadata, cfg) or None
+                    if chunk.metadata
+                    else None
+                ),
+                url=chunk.url,
+            ).model_dump(
+                exclude_none=True,
+                exclude_defaults=True,
+                by_alias=True,
+            )
+            for i, chunk in enumerate(content_chunks)
+        ]
+    else:
+        sources = [
+            {
+                "source_number": max_source_number + i,
+                "content": chunk.text,
+                **(
+                    {"metadata": meta}
+                    if (
+                        meta := _format_metadata(chunk.metadata, cfg)
+                    )  # only add when not empty
+                    else {}
+                ),
+            }
+            for i, chunk in enumerate(content_chunks)
+        ]
+    return json.dumps(sources)
+
+
+def load_sources_from_string(
+    source_string: str,
+) -> list[Source] | None:
+    """Transform JSON string from language model tool message in the tool call response into Source objects"""
+
+    try:
+        # First, try parsing as JSON (new format)
+        sources_data = json.loads(source_string)
+        return [Source.model_validate(source) for source in sources_data]
+    except (json.JSONDecodeError, ValueError):
+        logger.warning("Failed to parse source string")
+        return None
+
+
+def _format_metadata(
+    metadata: ContentMetadata | None,
+    cfg: SourceFormatConfig | None,
+) -> str:
+    """
+    Build the concatenated tag string from the chunk's metadata
+    and the templates found in cfg.sections.
+    Example result:
+        "<|topic|>GenAI<|/topic|>\n<|date|>This document is from: 2025-04-29<|/date|>\n"
+    """
+    if metadata is None:
+        return ""
+
+    if cfg is None or not cfg.sections:
+        # If no configuration is provided, return empty string
+        return ""
+
+    meta_dict = metadata.model_dump(exclude_none=True, by_alias=True)
+    sections = cfg.sections
+
+    parts: list[str] = []
+    for key, template in sections.items():
+        if key in meta_dict:
+            parts.append(template.format(meta_dict[key]))
+
+    return "".join(parts)
+
+
+### In case we do not want any formatting of metadata we could use this function
+# def _filtered_metadata(
+#     meta: ContentMetadata | None,
+#     cfg: SourceFormatConfig,
+# ) -> dict[str, str] | None:
+#     if meta is None:
+#         return None
+
+#     allowed = set(cfg.sections)
+#     raw = meta.model_dump(exclude_none=True, by_alias=True)
+#     pruned = {k: v for k, v in raw.items() if k in allowed}
+#     return pruned or None
```
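The new utils module pairs a serializer (`transform_chunks_to_string`) with a parser (`load_sources_from_string`). A minimal sketch of that contract, exercising only the paths that need no `ContentChunk` construction (its full constructor is not part of this diff):

```python
# Minimal sketch of the serializer/parser contract in history_manager/utils.py.
# No ContentChunk objects are built here because their constructor is not shown
# in this diff; the empty-list and invalid-JSON paths illustrate the behavior.
from unique_toolkit.history_manager.utils import (
    load_sources_from_string,
    transform_chunks_to_string,
)

# With no chunks, the serializer short-circuits to a fixed sentence.
assert transform_chunks_to_string([], max_source_number=0, cfg=None) == (
    "No relevant sources found."
)

# Non-JSON tool output is tolerated: the parser logs a warning and returns None.
assert load_sources_from_string("not valid json") is None
```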
unique_toolkit/language_model/service.py CHANGED

```diff
@@ -6,6 +6,7 @@ from typing_extensions import deprecated
 
 from unique_toolkit._common.validate_required_values import validate_required_values
 from unique_toolkit.app.schemas import BaseEvent, ChatEvent, Event
+from unique_toolkit.app.unique_settings import UniqueSettings
 from unique_toolkit.content.schemas import ContentChunk
 from unique_toolkit.language_model.constants import (
     DEFAULT_COMPLETE_TEMPERATURE,
@@ -89,6 +90,16 @@ class LanguageModelService:
         """
         return cls(company_id=event.company_id, user_id=event.user_id)
 
+    @classmethod
+    def from_settings(cls, settings: UniqueSettings):
+        """
+        Initialize the LanguageModelService with a settings object.
+        """
+        return cls(
+            company_id=settings.auth.company_id.get_secret_value(),
+            user_id=settings.auth.user_id.get_secret_value(),
+        )
+
     @property
     @deprecated(
         "The event property is deprecated and will be removed in a future version."
```
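A minimal sketch of the new `from_settings` constructor. How the `UniqueSettings` object is populated (environment, config file, etc.) is outside this diff and is left as an assumption:

```python
# Hedged sketch: construct a LanguageModelService from a UniqueSettings object.
# The settings object is assumed to be provided by the surrounding application;
# how it is loaded is not part of this diff.
from unique_toolkit.app.unique_settings import UniqueSettings
from unique_toolkit.language_model.service import LanguageModelService


def build_language_model_service(settings: UniqueSettings) -> LanguageModelService:
    # from_settings unwraps the secret-wrapped company_id and user_id internally.
    return LanguageModelService.from_settings(settings)
```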
unique_toolkit/postprocessor/postprocessor_manager.py ADDED

```diff
@@ -0,0 +1,122 @@
+from abc import ABC
+import asyncio
+from logging import Logger
+
+from unique_toolkit.chat.service import ChatService
+from unique_toolkit.language_model.schemas import (
+    LanguageModelMessage,
+    LanguageModelStreamResponse,
+)
+from unique_toolkit.tools.utils.execution.execution import Result, SafeTaskExecutor
+
+
+class Postprocessor(ABC):
+    def __init__(self, name: str):
+        self.name = name
+
+    def get_name(self) -> str:
+        return self.name
+
+    async def run(self, loop_response: LanguageModelStreamResponse) -> str:
+        raise NotImplementedError("Subclasses must implement this method.")
+
+    async def apply_postprocessing_to_response(
+        self, loop_response: LanguageModelStreamResponse
+    ) -> bool:
+        raise NotImplementedError(
+            "Subclasses must implement this method to apply post-processing to the response."
+        )
+
+    async def remove_from_text(self, text) -> str:
+        raise NotImplementedError(
+            "Subclasses must implement this method to remove post-processing from the message."
+        )
+
+
+class PostprocessorManager:
+    """
+    Manages and executes postprocessors for modifying and refining responses.
+
+    This class is responsible for:
+    - Storing and managing a collection of postprocessor instances.
+    - Executing postprocessors asynchronously to refine loop responses.
+    - Applying modifications to assistant messages based on postprocessor results.
+    - Providing utility methods for text manipulation using postprocessors.
+
+    Key Features:
+    - Postprocessor Management: Allows adding and retrieving postprocessor instances.
+    - Asynchronous Execution: Runs all postprocessors concurrently for efficiency.
+    - Response Modification: Applies postprocessing changes to assistant messages when necessary.
+    - Text Cleanup: Supports removing specific patterns or content from text using postprocessors.
+    - Error Handling: Logs warnings for any postprocessors that fail during execution.
+
+    The PostprocessorManager serves as a centralized system for managing and applying postprocessing logic to enhance response quality and consistency.
+    """
+
+    _postprocessors: list[Postprocessor] = []
+
+    def __init__(
+        self,
+        logger: Logger,
+        chat_service: ChatService,
+    ):
+        self._logger = logger
+        self._chat_service = chat_service
+
+    def add_postprocessor(self, postprocessor: Postprocessor):
+        self._postprocessors.append(postprocessor)
+
+    def get_postprocessors(self, name: str) -> list[Postprocessor]:
+        return self._postprocessors
+
+    async def run_postprocessors(
+        self,
+        loop_response: LanguageModelStreamResponse,
+    ) -> None:
+        task_executor = SafeTaskExecutor(
+            logger=self._logger,
+        )
+
+        tasks = [
+            task_executor.execute_async(
+                self.execute_postprocessors,
+                loop_response=loop_response,
+                postprocessor_instance=postprocessor,
+            )
+            for postprocessor in self._postprocessors
+        ]
+        postprocessor_results = await asyncio.gather(*tasks)
+
+        for i, result in enumerate(postprocessor_results):
+            if not result.success:
+                self._logger.warning(
+                    f"Postprocessor {self._postprocessors[i].get_name()} failed to run."
+                )
+
+        modification_results = [
+            postprocessor.apply_postprocessing_to_response(loop_response)
+            for postprocessor in self._postprocessors
+        ]
+
+        has_been_modified = any(modification_results)
+
+        if has_been_modified:
+            self._chat_service.modify_assistant_message(
+                content=loop_response.message.text,
+                message_id=loop_response.message.id,
+            )
+
+    async def execute_postprocessors(
+        self,
+        loop_response: LanguageModelStreamResponse,
+        postprocessor_instance: Postprocessor,
+    ) -> None:
+        await postprocessor_instance.run(loop_response)
+
+    async def remove_from_text(
+        self,
+        text: str,
+    ) -> str:
+        for postprocessor in self._postprocessors:
+            text = await postprocessor.remove_from_text(text)
+        return text
```
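A hedged sketch of a custom postprocessor built on the new `Postprocessor` base class; the disclaimer text and class name are illustrative, not part of the toolkit:

```python
# Hedged sketch of a Postprocessor subclass. DISCLAIMER and the class name are
# illustrative assumptions; only the base-class interface comes from this diff.
from unique_toolkit.language_model.schemas import LanguageModelStreamResponse
from unique_toolkit.postprocessor.postprocessor_manager import Postprocessor

DISCLAIMER = "\n\n_Generated content - please verify before use._"


class DisclaimerPostprocessor(Postprocessor):
    def __init__(self) -> None:
        super().__init__(name="disclaimer")

    async def run(self, loop_response: LanguageModelStreamResponse) -> str:
        # Called concurrently by PostprocessorManager.run_postprocessors.
        return (loop_response.message.text or "") + DISCLAIMER

    async def apply_postprocessing_to_response(
        self, loop_response: LanguageModelStreamResponse
    ) -> bool:
        # Mutate the streamed message; returning True signals a modification.
        loop_response.message.text = (loop_response.message.text or "") + DISCLAIMER
        return True

    async def remove_from_text(self, text: str) -> str:
        # Strip the disclaimer again, e.g. before feeding text back to the model.
        return text.replace(DISCLAIMER, "")
```

Instances are registered with `PostprocessorManager.add_postprocessor(...)` and executed together via `run_postprocessors`.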
unique_toolkit/reference_manager/reference_manager.py CHANGED

```diff
@@ -9,6 +9,25 @@ class tool_chunks:
 
 
 class ReferenceManager:
+    """
+    Manages content chunks and references extracted from tool responses.
+
+    This class is responsible for:
+    - Extracting and storing referenceable content chunks from tool responses.
+    - Managing a collection of content chunks and their associated references.
+    - Providing methods to retrieve, replace, and manipulate chunks and references.
+    - Supporting the retrieval of the latest references and their corresponding chunks.
+
+    Key Features:
+    - Chunk Extraction: Extracts content chunks from tool responses and organizes them for reference.
+    - Reference Management: Tracks references to content chunks and allows for easy retrieval.
+    - Latest Reference Access: Provides methods to fetch the most recent references and their associated chunks.
+    - Flexible Chunk Replacement: Allows for replacing the current set of chunks with a new list.
+    - Reference-to-Chunk Mapping: Matches references to their corresponding chunks based on source IDs.
+
+    The ReferenceManager serves as a utility for managing and linking content chunks with references, enabling efficient content tracking and retrieval.
+    """
+
     def __init__(self):
         self._tool_chunks: dict[str, tool_chunks] = {}
         self._chunks: list[ContentChunk] = []
```
unique_toolkit/short_term_memory/persistent_short_term_memory_manager.py ADDED

```diff
@@ -0,0 +1,140 @@
+import base64
+import zlib
+from logging import getLogger
+from typing import Generic, Type, TypeVar
+
+from pydantic import BaseModel
+from unique_toolkit.short_term_memory.schemas import ShortTermMemory
+from unique_toolkit.short_term_memory.service import ShortTermMemoryService
+from unique_toolkit.tools.utils.execution.execution import SafeTaskExecutor
+
+
+TSchema = TypeVar("TSchema", bound=BaseModel)
+
+
+logger = getLogger(__name__)
+
+
+def _default_short_term_memory_name(schema: type[BaseModel]) -> str:
+    return f"{schema.__name__}Key"
+
+
+def _compress_data_zlib_base64(data: str) -> str:
+    """Compress data using ZLIB and encode as base64 string."""
+    compressed = zlib.compress(data.encode("utf-8"))
+    return base64.b64encode(compressed).decode("utf-8")
+
+
+def _decompress_data_zlib_base64(compressed_data: str) -> str:
+    """Decompress base64 encoded ZLIB data."""
+    decoded = base64.b64decode(compressed_data.encode("utf-8"))
+    return zlib.decompress(decoded).decode("utf-8")
+
+
+class PersistentShortMemoryManager(Generic[TSchema]):
+    """
+    Manages the storage, retrieval, and processing of short-term memory in a persistent manner.
+
+    This class is responsible for:
+    - Saving and loading short-term memory data, both synchronously and asynchronously.
+    - Compressing and decompressing memory data for efficient storage.
+    - Validating and processing memory data using a predefined schema.
+    - Logging the status of memory operations, such as whether memory was found or saved.
+
+    Key Features:
+    - Persistent Storage: Integrates with a short-term memory service to store and retrieve memory data.
+    - Compression Support: Compresses memory data before saving and decompresses it upon retrieval.
+    - Schema Validation: Ensures memory data adheres to a specified schema for consistency.
+    - Synchronous and Asynchronous Operations: Supports both sync and async methods for flexibility.
+    - Logging and Debugging: Provides detailed logs for memory operations, including success and failure cases.
+
+    The PersistentShortMemoryManager is designed to handle short-term memory efficiently, ensuring data integrity and optimized storage.
+    """
+    def __init__(
+        self,
+        short_term_memory_service: ShortTermMemoryService,
+        short_term_memory_schema: Type[TSchema],
+        short_term_memory_name: str | None = None,
+    ) -> None:
+        self._short_term_memory_name = (
+            short_term_memory_name
+            if short_term_memory_name
+            else _default_short_term_memory_name(short_term_memory_schema)
+        )
+        self._short_term_memory_schema = short_term_memory_schema
+        self._short_term_memory_service = short_term_memory_service
+
+        self._executor = SafeTaskExecutor(
+            log_exceptions=False,
+        )
+
+    def _log_not_found(self) -> None:
+        logger.warning(
+            f"No short term memory found for chat {self._short_term_memory_service.chat_id} and key {self._short_term_memory_name}"
+        )
+
+    def _log_found(self) -> None:
+        logger.debug(
+            f"Short term memory found for chat {self._short_term_memory_service.chat_id} and key {self._short_term_memory_name}"
+        )
+
+    def _find_latest_memory_sync(self) -> ShortTermMemory | None:
+        result = self._executor.execute(
+            self._short_term_memory_service.find_latest_memory,
+            self._short_term_memory_name,
+        )
+
+        self._log_not_found() if not result.success else self._log_found()
+
+        return result.unpack(default=None)
+
+    async def _find_latest_memory_async(self) -> ShortTermMemory | None:
+        result = await self._executor.execute_async(
+            self._short_term_memory_service.find_latest_memory_async,
+            self._short_term_memory_name,
+        )
+
+        self._log_not_found() if not result.success else self._log_found()
+
+        return result.unpack(default=None)
+
+    def save_sync(self, short_term_memory: TSchema) -> None:
+        json_data = short_term_memory.model_dump_json()
+        compressed_data = _compress_data_zlib_base64(json_data)
+        logger.info(
+            f"Saving memory with {len(compressed_data)} characters compressed from {len(json_data)} characters for memory {self._short_term_memory_name}"
+        )
+        self._short_term_memory_service.create_memory(
+            key=self._short_term_memory_name,
+            value=compressed_data,
+        )
+
+    async def save_async(self, short_term_memory: TSchema) -> None:
+        json_data = short_term_memory.model_dump_json()
+        compressed_data = _compress_data_zlib_base64(json_data)
+        logger.info(
+            f"Saving memory with {len(compressed_data)} characters compressed from {len(json_data)} characters for memory {self._short_term_memory_name}"
+        )
+        await self._short_term_memory_service.create_memory_async(
+            key=self._short_term_memory_name,
+            value=compressed_data,
+        )
+
+    def _process_compressed_memory(
+        self, memory: ShortTermMemory | None
+    ) -> TSchema | None:
+        if memory is not None and memory.data is not None:
+            if isinstance(memory.data, str):
+                data = _decompress_data_zlib_base64(memory.data)
+                return self._short_term_memory_schema.model_validate_json(data)
+            elif isinstance(memory.data, dict):
+                return self._short_term_memory_schema.model_validate(memory.data)
+        return None
+
+    def load_sync(self) -> TSchema | None:
+        memory: ShortTermMemory | None = self._find_latest_memory_sync()
+        return self._process_compressed_memory(memory)
+
+    async def load_async(self) -> TSchema | None:
+        memory: ShortTermMemory | None = await self._find_latest_memory_async()
+        return self._process_compressed_memory(memory)
```
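A minimal sketch of using `PersistentShortMemoryManager` with a Pydantic schema. The `ConversationState` model is illustrative, and the `ShortTermMemoryService` instance is assumed to be created elsewhere:

```python
# Hedged sketch: persist a small Pydantic state object across chat turns.
# ConversationState is an illustrative schema; the ShortTermMemoryService is
# assumed to be constructed by the surrounding application.
from pydantic import BaseModel

from unique_toolkit.short_term_memory.persistent_short_term_memory_manager import (
    PersistentShortMemoryManager,
)
from unique_toolkit.short_term_memory.service import ShortTermMemoryService


class ConversationState(BaseModel):
    turn_count: int = 0
    last_topic: str | None = None


async def remember(service: ShortTermMemoryService) -> ConversationState | None:
    manager = PersistentShortMemoryManager(
        short_term_memory_service=service,
        short_term_memory_schema=ConversationState,
        # Without an explicit name, the key defaults to "ConversationStateKey".
    )
    # Data is JSON-dumped, zlib-compressed, and base64-encoded before storage.
    await manager.save_async(ConversationState(turn_count=1, last_topic="GenAI"))
    return await manager.load_async()
```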
unique_toolkit/thinking_manager/thinking_manager.py ADDED

```diff
@@ -0,0 +1,102 @@
+from logging import Logger
+from pydantic import BaseModel, Field
+
+from unique_toolkit.chat.service import ChatService
+from unique_toolkit.language_model.schemas import (
+    LanguageModelAssistantMessage,
+    LanguageModelStreamResponse,
+)
+from unique_toolkit.tools.tool_progress_reporter import (
+    ToolProgressReporter,
+)
+
+
+class ThinkingManagerConfig(BaseModel):
+    thinking_steps_display: bool = Field(
+        default=True, description="Whether to display thinking steps in the chat."
+    )
+
+
+class ThinkingManager:
+    """
+    Manages the display and tracking of thinking steps during response generation.
+
+    This class is responsible for:
+    - Tracking and formatting thinking steps as part of the response process.
+    - Updating the tool progress reporter with the latest thinking step information.
+    - Managing the display of thinking steps in the assistant's response.
+    - Closing and finalizing the thinking steps section when the process is complete.
+
+    Key Features:
+    - Thinking Step Tracking: Maintains a sequential log of thinking steps with step numbers.
+    - Configurable Display: Supports enabling or disabling the display of thinking steps based on configuration.
+    - Integration with Tool Progress: Updates the tool progress reporter to reflect the current thinking state.
+    - Dynamic Response Updates: Modifies the assistant's response to include or finalize thinking steps.
+    - Flexible Formatting: Formats thinking steps in a structured and user-friendly HTML-like format.
+
+    The ThinkingManager enhances transparency and user understanding by providing a clear view of the assistant's reasoning process.
+    """
+    def __init__(
+        self,
+        logger: Logger,
+        config: ThinkingManagerConfig,
+        tool_progress_reporter: ToolProgressReporter,
+        chat_service: ChatService,
+    ):
+        self._chat_service = chat_service
+        self._config = config
+        self._thinking_steps = ""
+        self._thinking_step_number = 1
+        self._tool_progress_reporter = tool_progress_reporter
+
+    def thinking_is_displayed(self) -> bool:
+        return self._config.thinking_steps_display
+
+    def update_tool_progress_reporter(self, loop_response: LanguageModelStreamResponse):
+        if self._config.thinking_steps_display and (
+            not loop_response.message.text
+            == self._tool_progress_reporter._progress_start_text
+        ):
+            self._tool_progress_reporter.tool_statuses = {}
+            self._tool_progress_reporter._progress_start_text = (
+                loop_response.message.text
+            )
+
+    def update_start_text(
+        self, start_text: str, loop_response: LanguageModelStreamResponse
+    ) -> str:
+        if not self._config.thinking_steps_display:
+            return start_text
+        if not loop_response.message.original_text:
+            return start_text
+        if loop_response.message.original_text == "":
+            return start_text
+
+        update_message = loop_response.message.original_text
+
+        if start_text == "":
+            self._thinking_steps = f"\n<i><b>Step 1:</b>\n{update_message}</i>\n"
+            start_text = f"""<details open>\n<summary><b>Thinking steps</b></summary>\n{self._thinking_steps}\n</details>\n\n---\n\n"""
+        else:
+            self._thinking_steps += f"\n\n<i><b>Step {self._thinking_step_number}:</b>\n{update_message}</i>\n\n"
+            start_text = f"""<details open>\n<summary><b>Thinking steps</b></summary>\n<i>{self._thinking_steps}\n\n</i>\n</details>\n\n---\n\n"""
+
+        self._thinking_step_number += 1
+        return start_text
+
+    def close_thinking_steps(self, loop_response: LanguageModelStreamResponse):
+        if not self._config.thinking_steps_display:
+            return
+        if not self._thinking_steps:
+            return
+        if not loop_response.message.text:
+            return
+        if not loop_response.message.text.startswith("<details open>"):
+            return
+
+        loop_response.message.text = loop_response.message.text.replace(
+            "<details open>", "<details>"
+        )
+
+        self._chat_service.modify_assistant_message(content=loop_response.message.text)
+        return
```
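A hedged sketch of wiring up the new `ThinkingManager`. The logger, `ToolProgressReporter`, and `ChatService` are assumed to come from the surrounding agent loop:

```python
# Hedged sketch: build a ThinkingManager. The dependencies (logger, reporter,
# chat_service) are assumed to exist in the calling application; only the
# constructor and config shown in this diff are used.
from logging import Logger

from unique_toolkit.chat.service import ChatService
from unique_toolkit.thinking_manager.thinking_manager import (
    ThinkingManager,
    ThinkingManagerConfig,
)
from unique_toolkit.tools.tool_progress_reporter import ToolProgressReporter


def build_thinking_manager(
    logger: Logger,
    reporter: ToolProgressReporter,
    chat_service: ChatService,
) -> ThinkingManager:
    # Set thinking_steps_display=False to suppress the collapsible
    # "Thinking steps" section in the assistant's response.
    config = ThinkingManagerConfig(thinking_steps_display=True)
    return ThinkingManager(
        logger=logger,
        config=config,
        tool_progress_reporter=reporter,
        chat_service=chat_service,
    )
```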
unique_toolkit/tools/schemas.py CHANGED

unique_toolkit/tools/tool.py CHANGED

```diff
@@ -18,7 +18,7 @@ from unique_toolkit.language_model.schemas import (
 from unique_toolkit.language_model.service import LanguageModelService
 
 
-from unique_toolkit.
+from unique_toolkit.evals.schemas import EvaluationMetricName
 from unique_toolkit.tools.agent_chunks_handler import AgentChunksHandler
 from unique_toolkit.tools.config import ToolBuildConfig, ToolSelectionPolicy
 from unique_toolkit.tools.schemas import BaseToolConfig, ToolCallResponse, ToolPrompts
```