unique_toolkit 0.8.14__py3-none-any.whl → 0.8.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unique_toolkit/_common/default_language_model.py +6 -0
- unique_toolkit/_common/token/image_token_counting.py +67 -0
- unique_toolkit/_common/token/token_counting.py +196 -0
- unique_toolkit/evals/config.py +36 -0
- unique_toolkit/evals/context_relevancy/prompts.py +56 -0
- unique_toolkit/evals/context_relevancy/schema.py +88 -0
- unique_toolkit/evals/context_relevancy/service.py +241 -0
- unique_toolkit/evals/hallucination/constants.py +61 -0
- unique_toolkit/evals/hallucination/hallucination_evaluation.py +92 -0
- unique_toolkit/evals/hallucination/prompts.py +79 -0
- unique_toolkit/evals/hallucination/service.py +57 -0
- unique_toolkit/evals/hallucination/utils.py +213 -0
- unique_toolkit/evals/output_parser.py +48 -0
- unique_toolkit/evals/tests/test_context_relevancy_service.py +252 -0
- unique_toolkit/evals/tests/test_output_parser.py +80 -0
- unique_toolkit/history_manager/history_construction_with_contents.py +307 -0
- unique_toolkit/history_manager/history_manager.py +80 -111
- unique_toolkit/history_manager/loop_token_reducer.py +457 -0
- unique_toolkit/language_model/schemas.py +8 -0
- unique_toolkit/reference_manager/reference_manager.py +15 -2
- {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/METADATA +7 -1
- {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/RECORD +24 -7
- {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/LICENSE +0 -0
- {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/WHEEL +0 -0
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
2
|
from logging import Logger
|
|
3
|
-
from typing import Awaitable, Callable
|
|
3
|
+
from typing import Annotated, Awaitable, Callable
|
|
4
4
|
|
|
5
5
|
from pydantic import BaseModel, Field
|
|
6
6
|
|
|
7
|
+
import tiktoken
|
|
7
8
|
from unique_toolkit.app.schemas import ChatEvent
|
|
8
9
|
|
|
9
10
|
|
|
@@ -17,17 +18,55 @@ from unique_toolkit.language_model.schemas import (
|
|
|
17
18
|
LanguageModelFunction,
|
|
18
19
|
LanguageModelMessage,
|
|
19
20
|
LanguageModelMessageRole,
|
|
21
|
+
LanguageModelMessages,
|
|
22
|
+
LanguageModelSystemMessage,
|
|
20
23
|
LanguageModelToolMessage,
|
|
21
24
|
LanguageModelUserMessage,
|
|
22
25
|
)
|
|
23
26
|
|
|
24
27
|
from unique_toolkit.tools.schemas import ToolCallResponse
|
|
25
|
-
from unique_toolkit.content.utils import count_tokens
|
|
26
28
|
from unique_toolkit.history_manager.utils import transform_chunks_to_string
|
|
27
29
|
|
|
30
|
+
from _common.validators import LMI
|
|
31
|
+
from history_manager.loop_token_reducer import LoopTokenReducer
|
|
32
|
+
from reference_manager.reference_manager import ReferenceManager
|
|
33
|
+
from tools.config import get_configuration_dict
|
|
34
|
+
|
|
35
|
+
DeactivatedNone = Annotated[
|
|
36
|
+
None,
|
|
37
|
+
Field(title="Deactivated", description="None"),
|
|
38
|
+
]
|
|
28
39
|
|
|
29
40
|
class HistoryManagerConfig(BaseModel):
|
|
30
41
|
|
|
42
|
+
class InputTokenDistributionConfig(BaseModel):
|
|
43
|
+
model_config = get_configuration_dict(frozen=True)
|
|
44
|
+
|
|
45
|
+
percent_for_history: float = Field(
|
|
46
|
+
default=0.6,
|
|
47
|
+
ge=0.0,
|
|
48
|
+
lt=1.0,
|
|
49
|
+
description="The fraction of the max input tokens that will be reserved for the history.",
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
def max_history_tokens(self, max_input_token: int) -> int:
|
|
53
|
+
return int(self.percent_for_history * max_input_token)
|
|
54
|
+
|
|
55
|
+
class UploadedContentConfig(BaseModel):
|
|
56
|
+
model_config = get_configuration_dict()
|
|
57
|
+
|
|
58
|
+
user_context_window_limit_warning: str = Field(
|
|
59
|
+
default="The uploaded content is too large to fit into the ai model. "
|
|
60
|
+
"Unique AI will search for relevant sections in the material and if needed combine the data with knowledge base content",
|
|
61
|
+
description="Message to show when using the Internal Search instead of upload and chat tool due to context window limit. Jinja template.",
|
|
62
|
+
)
|
|
63
|
+
percent_for_uploaded_content: float = Field(
|
|
64
|
+
default=0.6,
|
|
65
|
+
ge=0.0,
|
|
66
|
+
lt=1.0,
|
|
67
|
+
description="The fraction of the max input tokens that will be reserved for the uploaded content.",
|
|
68
|
+
)
|
|
69
|
+
|
|
31
70
|
class ExperimentalFeatures(BaseModel):
|
|
32
71
|
def __init__(self, full_sources_serialize_dump: bool = False):
|
|
33
72
|
self.full_sources_serialize_dump = full_sources_serialize_dump
|
|
@@ -48,6 +87,20 @@ class HistoryManagerConfig(BaseModel):
|
|
|
48
87
|
description="The maximum number of tokens to keep in the history.",
|
|
49
88
|
)
|
|
50
89
|
|
|
90
|
+
uploaded_content_config: (
|
|
91
|
+
Annotated[
|
|
92
|
+
UploadedContentConfig,
|
|
93
|
+
Field(title="Active"),
|
|
94
|
+
]
|
|
95
|
+
| DeactivatedNone
|
|
96
|
+
) = UploadedContentConfig()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
input_token_distribution: InputTokenDistributionConfig = Field(
|
|
100
|
+
default=InputTokenDistributionConfig(),
|
|
101
|
+
description="Configuration for the input token distribution.",
|
|
102
|
+
)
|
|
103
|
+
|
|
51
104
|
|
|
52
105
|
class HistoryManager:
|
|
53
106
|
"""
|
|
@@ -78,11 +131,20 @@ class HistoryManager:
|
|
|
78
131
|
logger: Logger,
|
|
79
132
|
event: ChatEvent,
|
|
80
133
|
config: HistoryManagerConfig,
|
|
134
|
+
language_model: LMI,
|
|
135
|
+
reference_manager: ReferenceManager,
|
|
81
136
|
):
|
|
82
137
|
self._config = config
|
|
83
138
|
self._logger = logger
|
|
84
|
-
self.
|
|
85
|
-
self.
|
|
139
|
+
self._language_model = language_model
|
|
140
|
+
self._token_reducer = LoopTokenReducer(
|
|
141
|
+
logger=self._logger,
|
|
142
|
+
event=event,
|
|
143
|
+
config=self._config,
|
|
144
|
+
language_model=self._language_model,
|
|
145
|
+
reference_manager=reference_manager,
|
|
146
|
+
)
|
|
147
|
+
|
|
86
148
|
|
|
87
149
|
def has_no_loop_messages(self) -> bool:
|
|
88
150
|
return len(self._loop_history) == 0
|
|
@@ -150,112 +212,19 @@ class HistoryManager:
|
|
|
150
212
|
def add_assistant_message(self, message: LanguageModelAssistantMessage) -> None:
|
|
151
213
|
self._loop_history.append(message)
|
|
152
214
|
|
|
153
|
-
|
|
215
|
+
|
|
216
|
+
async def get_history_for_model_call(
|
|
154
217
|
self,
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
) ->
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
# Get uploaded files
|
|
167
|
-
uploaded_files = self._content_service.search_content_on_chat(
|
|
168
|
-
chat_id=self._chat_service.chat_id
|
|
218
|
+
original_user_message: str,
|
|
219
|
+
rendered_user_message_string: str,
|
|
220
|
+
rendered_system_message_string: str,
|
|
221
|
+
remove_from_text: Callable[[str], Awaitable[str]]
|
|
222
|
+
) -> LanguageModelMessages:
|
|
223
|
+
messages = await self._token_reducer.get_history_for_model_call(
|
|
224
|
+
original_user_message=original_user_message,
|
|
225
|
+
rendered_user_message_string=rendered_user_message_string,
|
|
226
|
+
rendered_system_message_string=rendered_system_message_string,
|
|
227
|
+
loop_history=self._loop_history,
|
|
228
|
+
remove_from_text=remove_from_text,
|
|
169
229
|
)
|
|
170
|
-
|
|
171
|
-
full_history = await self._chat_service.get_full_history_async()
|
|
172
|
-
|
|
173
|
-
merged_history = self._merge_history_and_uploads(full_history, uploaded_files)
|
|
174
|
-
|
|
175
|
-
if postprocessing_step is not None:
|
|
176
|
-
merged_history = postprocessing_step(merged_history)
|
|
177
|
-
|
|
178
|
-
limited_history = self._limit_to_token_window(
|
|
179
|
-
merged_history, self._config.max_history_tokens
|
|
180
|
-
)
|
|
181
|
-
|
|
182
|
-
# Add current user message if not already in history
|
|
183
|
-
# we grab it fresh from the db so it must contain all the messages this code is not needed anymore below currently it's left in for explainability
|
|
184
|
-
# current_user_msg = LanguageModelUserMessage(
|
|
185
|
-
# content=self.event.payload.user_message.text
|
|
186
|
-
# )
|
|
187
|
-
# if not any(
|
|
188
|
-
# msg.role == LanguageModelMessageRole.USER
|
|
189
|
-
# and msg.content == current_user_msg.content
|
|
190
|
-
# for msg in complete_history
|
|
191
|
-
# ):
|
|
192
|
-
# complete_history.append(current_user_msg)
|
|
193
|
-
|
|
194
|
-
# # Add final assistant response - this should be available when this method is called
|
|
195
|
-
# if (
|
|
196
|
-
# hasattr(self, "loop_response")
|
|
197
|
-
# and self.loop_response
|
|
198
|
-
# and self.loop_response.message.text
|
|
199
|
-
# ):
|
|
200
|
-
# complete_history.append(
|
|
201
|
-
# LanguageModelAssistantMessage(
|
|
202
|
-
# content=self.loop_response.message.text
|
|
203
|
-
# )
|
|
204
|
-
# )
|
|
205
|
-
# else:
|
|
206
|
-
# self.logger.warning(
|
|
207
|
-
# "Called get_complete_conversation_history_after_streaming_no_tool_calls but no loop_response.message.text is available"
|
|
208
|
-
# )
|
|
209
|
-
|
|
210
|
-
return limited_history
|
|
211
|
-
|
|
212
|
-
def _merge_history_and_uploads(
|
|
213
|
-
self, history: list[ChatMessage], uploads: list[Content]
|
|
214
|
-
) -> list[LanguageModelMessage]:
|
|
215
|
-
# Assert that all content have a created_at
|
|
216
|
-
content_with_created_at = [content for content in uploads if content.created_at]
|
|
217
|
-
sorted_history = sorted(
|
|
218
|
-
history + content_with_created_at,
|
|
219
|
-
key=lambda x: x.created_at or datetime.min,
|
|
220
|
-
)
|
|
221
|
-
|
|
222
|
-
msg_builder = MessagesBuilder()
|
|
223
|
-
for msg in sorted_history:
|
|
224
|
-
if isinstance(msg, Content):
|
|
225
|
-
msg_builder.user_message_append(
|
|
226
|
-
f"Uploaded file: {msg.key}, ContentId: {msg.id}"
|
|
227
|
-
)
|
|
228
|
-
else:
|
|
229
|
-
msg_builder.messages.append(
|
|
230
|
-
LanguageModelMessage(
|
|
231
|
-
role=LanguageModelMessageRole(msg.role),
|
|
232
|
-
content=msg.content,
|
|
233
|
-
)
|
|
234
|
-
)
|
|
235
|
-
return msg_builder.messages
|
|
236
|
-
|
|
237
|
-
def _limit_to_token_window(
|
|
238
|
-
self, messages: list[LanguageModelMessage], token_limit: int
|
|
239
|
-
) -> list[LanguageModelMessage]:
|
|
240
|
-
selected_messages = []
|
|
241
|
-
token_count = 0
|
|
242
|
-
for msg in messages[::-1]:
|
|
243
|
-
msg_token_count = count_tokens(str(msg.content))
|
|
244
|
-
if token_count + msg_token_count > token_limit:
|
|
245
|
-
break
|
|
246
|
-
selected_messages.append(msg)
|
|
247
|
-
token_count += msg_token_count
|
|
248
|
-
return selected_messages[::-1]
|
|
249
|
-
|
|
250
|
-
async def remove_post_processing_manipulations(
|
|
251
|
-
self, remove_from_text: Callable[[str], Awaitable[str]]
|
|
252
|
-
) -> list[LanguageModelMessage]:
|
|
253
|
-
messages = await self.get_history()
|
|
254
|
-
for message in messages:
|
|
255
|
-
if isinstance(message.content, str):
|
|
256
|
-
message.content = await remove_from_text(message.content)
|
|
257
|
-
else:
|
|
258
|
-
self._logger.warning(
|
|
259
|
-
f"Skipping message with unsupported content type: {type(message.content)}"
|
|
260
|
-
)
|
|
261
|
-
return messages
|
|
230
|
+
return messages
|
|
@@ -0,0 +1,457 @@
|
|
|
1
|
+
|
|
2
|
+
import json
|
|
3
|
+
from logging import Logger
|
|
4
|
+
from typing import Awaitable, Callable
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
import tiktoken
|
|
8
|
+
from _common.token.token_counting import num_token_for_language_model_messages
|
|
9
|
+
from _common.validators import LMI
|
|
10
|
+
from app.schemas import ChatEvent
|
|
11
|
+
from chat.service import ChatService
|
|
12
|
+
from content.schemas import ContentChunk
|
|
13
|
+
from content.service import ContentService
|
|
14
|
+
from history_manager.history_construction_with_contents import FileContentSerialization, get_full_history_with_contents
|
|
15
|
+
from history_manager.history_manager import HistoryManagerConfig
|
|
16
|
+
from language_model.schemas import LanguageModelAssistantMessage, LanguageModelMessage, LanguageModelMessageRole, LanguageModelMessages, LanguageModelSystemMessage, LanguageModelToolMessage, LanguageModelUserMessage
|
|
17
|
+
from reference_manager.reference_manager import ReferenceManager
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SourceReductionResult(BaseModel):
    """Outcome of reducing the sources of a single tool message.

    The offsets are running counters threaded through successive calls while
    walking the loop history.
    """

    # Pydantic v2 style configuration; the nested ``class Config`` form is
    # deprecated since pydantic 2.0.
    model_config = {"arbitrary_types_allowed": True}

    # The (possibly rebuilt) tool message with fewer sources.
    message: LanguageModelToolMessage
    # Chunks kept for this tool message, taken from the reference manager's global list.
    reduced_chunks: list[ContentChunk]
    # Position in the reference manager's global chunk list after this message.
    chunk_offset: int
    # Next source number to assign after this message's kept sources.
    source_offset: int
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class LoopTokenReducer:
    """Build the message list for a model call and keep it under the input token limit.

    Messages are assembled as ``[system message] + chat history from the database
    + loop history``. While the token count exceeds the language model's
    ``token_limits.token_limit_input`` and at least one tool call still has more
    than one source chunk, the last source of each tool response in the loop
    history is dropped and the reference manager is updated to match.

    NOTE(review): this module imports ``HistoryManagerConfig`` from
    ``history_manager.history_manager`` at import time, and that module imports
    ``LoopTokenReducer`` from here at import time, which is a circular import.
    Consider moving the config models to their own module or guarding one side
    with ``typing.TYPE_CHECKING``.
    """

    # Fallback tiktoken encoding when the model does not declare a usable one.
    _DEFAULT_ENCODING = "cl100k_base"

    def __init__(
        self,
        logger: Logger,
        event: ChatEvent,
        config: HistoryManagerConfig,
        reference_manager: ReferenceManager,
        language_model: LMI,
    ):
        self._config = config
        self._logger = logger
        self._reference_manager = reference_manager
        self._language_model = language_model
        self._encoder = self._get_encoder(language_model)
        self._chat_service = ChatService(event)
        self._content_service = ContentService.from_event(event)
        self._user_message = event.payload.user_message
        self._chat_id = event.payload.chat_id

    def _get_encoder(self, language_model: LMI) -> tiktoken.Encoding:
        """Return the tiktoken encoding used for counting tokens of ``language_model``.

        ``tiktoken.get_encoding`` only accepts *encoding* names (e.g.
        ``"cl100k_base"``), not model names, and raises ``ValueError`` for
        anything else. The model's ``encoder_name`` is therefore preferred, and
        unknown names fall back to ``cl100k_base``.
        """
        # NOTE(review): assumes LMI may expose an ``encoder_name``; getattr keeps
        # this safe if it does not.
        encoder_name = getattr(language_model, "encoder_name", None)
        # Enum members (e.g. a StrEnum) carry the tiktoken name in ``.value``.
        name = getattr(encoder_name, "value", encoder_name) or self._DEFAULT_ENCODING
        try:
            return tiktoken.get_encoding(str(name))
        except ValueError:
            self._logger.warning(
                f"Unknown tiktoken encoding '{name}', falling back to {self._DEFAULT_ENCODING}."
            )
            return tiktoken.get_encoding(self._DEFAULT_ENCODING)

    async def get_history_for_model_call(
        self,
        original_user_message: str,
        rendered_user_message_string: str,
        rendered_system_message_string: str,
        loop_history: list[LanguageModelMessage],
        remove_from_text: Callable[[str], Awaitable[str]],
    ) -> LanguageModelMessages:
        """Return the messages for the next model call, reduced to fit the input limit.

        Args:
            original_user_message: The user's raw message text as stored in the chat.
            rendered_user_message_string: The rendered user prompt that replaces it.
            rendered_system_message_string: Content of the leading system message.
            loop_history: Messages produced so far in the current loop (e.g. assistant
                and tool messages); tool messages may have their sources reduced.
            remove_from_text: Async cleaner applied to every string-content message
                of the database history.

        Returns:
            System message + database history + (possibly reduced) loop history.

        Side effects:
            Source reduction updates the chunks held by the reference manager.
        """
        # The database history does not depend on ``loop_history``, so fetch and
        # prepare it once instead of once per reduction round.
        history_from_db = await self._get_prepared_history_from_db(
            original_user_message,
            rendered_user_message_string,
            remove_from_text,
        )
        messages = self._assemble_messages(
            rendered_system_message_string, history_from_db, loop_history
        )
        token_count = self._count_message_tokens(messages)
        self._log_token_usage(token_count)

        while self._exceeds_token_limit(token_count):
            token_count_before_reduction = token_count
            loop_history = self._handle_token_limit_exceeded(loop_history)
            messages = self._assemble_messages(
                rendered_system_message_string, history_from_db, loop_history
            )
            token_count = self._count_message_tokens(messages)
            self._log_token_usage(token_count)
            # Stop once a reduction round no longer shrinks the prompt; otherwise
            # the loop could spin forever.
            if token_count >= token_count_before_reduction:
                break

        token_limit = self._language_model.token_limits.token_limit_input
        if token_count > token_limit:
            self._logger.warning(
                f"Messages still use {token_count} tokens, above the input limit of "
                f"{token_limit}; no further source reduction is possible."
            )

        return messages

    def _exceeds_token_limit(self, token_count: int) -> bool:
        """Check if token count exceeds the maximum allowed limit and if at least one tool call has more than one source."""
        # At least one tool call should have more than one chunk as answer
        has_multiple_chunks_for_a_tool_call = any(
            len(chunks) > 1
            for chunks in self._reference_manager.get_chunks_of_all_tools()
        )

        # TODO: This is not fully correct at the moment as the token_count
        # include system_prompt and user question already
        # TODO: There is a problem if we exceed but only have one chunk per tool call
        exceeds_limit = (
            token_count > self._language_model.token_limits.token_limit_input
        )

        return has_multiple_chunks_for_a_tool_call and exceeds_limit

    def _count_message_tokens(self, messages: LanguageModelMessages) -> int:
        """Count tokens in messages using the configured encoding model."""
        return num_token_for_language_model_messages(
            messages=messages,
            encode=self._encoder.encode,
        )

    def _log_token_usage(self, token_count: int) -> None:
        """Log the current token usage of the assembled messages."""
        self._logger.info(f"Token messages: {token_count}")

    async def _construct_history(
        self,
        original_user_message: str,
        rendered_user_message_string: str,
        rendered_system_message_string: str,
        loop_history: list[LanguageModelMessage],
        remove_from_text: Callable[[str], Awaitable[str]],
    ) -> LanguageModelMessages:
        """Fetch the database history and assemble the full message list in one step."""
        history_from_db = await self._get_prepared_history_from_db(
            original_user_message,
            rendered_user_message_string,
            remove_from_text,
        )
        return self._assemble_messages(
            rendered_system_message_string, history_from_db, loop_history
        )

    async def _get_prepared_history_from_db(
        self,
        original_user_message: str,
        rendered_user_message_string: str,
        remove_from_text: Callable[[str], Awaitable[str]],
    ) -> list[LanguageModelMessage]:
        """Fetch the cleaned, windowed DB history with the rendered user message in place."""
        history_from_db = await self._get_history_from_db(remove_from_text)
        return self._replace_user_message(
            history_from_db, original_user_message, rendered_user_message_string
        )

    def _assemble_messages(
        self,
        rendered_system_message_string: str,
        history_from_db: list[LanguageModelMessage],
        loop_history: list[LanguageModelMessage],
    ) -> LanguageModelMessages:
        """Concatenate system message, database history and loop history."""
        system_message = LanguageModelSystemMessage(
            content=rendered_system_message_string
        )
        return LanguageModelMessages(
            [system_message] + history_from_db + loop_history,
        )

    def _handle_token_limit_exceeded(
        self, loop_history: list[LanguageModelMessage]
    ) -> list[LanguageModelMessage]:
        """Handle case where token limit is exceeded by reducing sources in tool responses."""
        self._logger.warning(
            f"Length of messages is exceeds limit of {self._language_model.token_limits.token_limit_input} tokens. "
            "Reducing number of sources per tool call.",
        )

        return self._reduce_message_length_by_reducing_sources_in_tool_response(
            loop_history
        )

    def _replace_user_message(
        self,
        history: list[LanguageModelMessage],
        original_user_message: str,
        rendered_user_message_string: str,
    ) -> list[LanguageModelMessage]:
        """Replace the raw user message at the end of ``history`` with its rendered form.

        While merging uploaded contents, the history may have appended extra text
        to the user's message. That extra text is preserved: only the first
        occurrence of ``original_user_message`` is removed, and the remainder is
        prefixed with ``rendered_user_message_string``. The last message is
        modified in place.

        If ``history`` is empty or does not end with a user message, a new user
        message containing the rendered text is appended instead.
        """
        if not history or history[-1].role != LanguageModelMessageRole.USER:
            return history + [
                LanguageModelUserMessage(content=rendered_user_message_string),
            ]

        last_message = history[-1]
        content = last_message.content

        if isinstance(content, list):
            # Rewrite only the last text part so that parts added when merging
            # with contents are left untouched.
            for part in reversed(content):
                if not (isinstance(part, dict) and part.get("type") == "text"):
                    continue
                text = part.get("text", "")
                if isinstance(text, str):
                    added_to_message_by_history = text.replace(
                        original_user_message, "", 1
                    )
                    part["text"] = (
                        rendered_user_message_string + added_to_message_by_history
                    )
                    break
        elif content is None or isinstance(content, str):
            # Empty content still receives the rendered prompt instead of losing it.
            added_to_message_by_history = (content or "").replace(
                original_user_message, "", 1
            )
            last_message.content = (
                rendered_user_message_string + added_to_message_by_history
            )
        else:
            self._logger.warning(
                f"Cannot render user message with content type: {type(content)}"
            )
        return history

    async def _get_history_from_db(
        self, remove_from_text: Callable[[str], Awaitable[str]]
    ) -> list[LanguageModelMessage]:
        """Return the cleaned chat history, limited to the configured history token share.

        File contents are serialized as file names only when
        ``uploaded_content_config`` is deactivated (``None``). Messages are then
        passed through ``remove_from_text``, windowed from the newest backwards,
        and trimmed so that the history starts with a user message.

        Returns:
            list[LanguageModelMessage]: The history
        """
        full_history = get_full_history_with_contents(
            user_message=self._user_message,
            chat_id=self._chat_id,
            chat_service=self._chat_service,
            content_service=self._content_service,
            file_content_serialization_type=(
                FileContentSerialization.NONE
                if self._config.uploaded_content_config
                else FileContentSerialization.FILE_NAME
            ),
        )

        full_history.root = await self._clean_messages(
            full_history.root, remove_from_text
        )

        limited_history_messages = self._limit_to_token_window(
            full_history.root,
            self._config.input_token_distribution.max_history_tokens(
                self._language_model.token_limits.token_limit_input,
            ),
        )

        # Keep at least the newest message even if it alone exceeds the window.
        if len(limited_history_messages) == 0:
            limited_history_messages = full_history.root[-1:]

        self._logger.info(
            f"Reduced history to {len(limited_history_messages)} messages from {len(full_history.root)}",
        )

        return self.ensure_last_message_is_user_message(limited_history_messages)

    def _limit_to_token_window(
        self, messages: list[LanguageModelMessage], token_limit: int
    ) -> list[LanguageModelMessage]:
        """Return the longest suffix of ``messages`` whose content fits in ``token_limit``."""
        selected_messages = []
        token_count = 0
        for msg in messages[::-1]:
            msg_token_count = self._count_tokens(str(msg.content))
            if token_count + msg_token_count > token_limit:
                break
            selected_messages.append(msg)
            token_count += msg_token_count
        return selected_messages[::-1]

    async def _clean_messages(
        self,
        messages: list[
            LanguageModelMessage
            | LanguageModelToolMessage
            | LanguageModelAssistantMessage
            | LanguageModelSystemMessage
            | LanguageModelUserMessage
        ],
        remove_from_text: Callable[[str], Awaitable[str]],
    ) -> list[LanguageModelMessage]:
        """Apply ``remove_from_text`` in place to every string-content message."""
        for message in messages:
            if isinstance(message.content, str):
                message.content = await remove_from_text(message.content)
            else:
                self._logger.warning(
                    f"Skipping message with unsupported content type: {type(message.content)}"
                )
        return messages

    def _count_tokens(self, text: str) -> int:
        """Number of tokens of ``text`` under the configured encoding."""
        return len(self._encoder.encode(text))

    def ensure_last_message_is_user_message(
        self, limited_history_messages: list[LanguageModelMessage]
    ) -> list[LanguageModelMessage]:
        """Drop leading messages until the history starts with a user message.

        Despite its name, this trims the *front* of the history. The token window
        can cut in the middle of an exchange, for example leaving an assistant or
        tool message without its preceding user turn, which confuses the model.
        Returns an empty list when no user message is present.
        """
        first_user_idx = next(
            (
                idx
                for idx, message in enumerate(limited_history_messages)
                if message.role == LanguageModelMessageRole.USER
            ),
            len(limited_history_messages),
        )

        # FIXME: This might reduce the history by a lot if we have a lot of tool calls / references in the history. Could make sense to summarize the messages and include
        # FIXME: We should remove chunks no longer in history from handler
        return limited_history_messages[first_user_idx:]

    def _reduce_message_length_by_reducing_sources_in_tool_response(
        self,
        history: list[LanguageModelMessage],
    ) -> list[LanguageModelMessage]:
        """
        Reduce the message length by removing the last source result of each tool call.
        If there is only one source for a tool call, the tool call message is returned unchanged.
        The reference manager's chunk list is replaced with the kept chunks.
        """
        history_reduced: list[LanguageModelMessage] = []
        content_chunks_reduced: list[ContentChunk] = []
        chunk_offset = 0
        source_offset = 0

        for message in history:
            if self._should_reduce_message(message):
                result = self._reduce_sources_in_tool_message(
                    message,  # type: ignore
                    chunk_offset,
                    source_offset,
                )
                content_chunks_reduced.extend(result.reduced_chunks)
                history_reduced.append(result.message)
                chunk_offset = result.chunk_offset
                source_offset = result.source_offset
            else:
                history_reduced.append(message)

        self._reference_manager.replace(chunks=content_chunks_reduced)
        return history_reduced

    def _should_reduce_message(self, message: LanguageModelMessage) -> bool:
        """Determine if a message should have its sources reduced."""
        return (
            message.role == LanguageModelMessageRole.TOOL
            and isinstance(message, LanguageModelToolMessage)
        )

    def _reduce_sources_in_tool_message(
        self,
        message: LanguageModelToolMessage,
        chunk_offset: int,
        source_offset: int,
    ) -> SourceReductionResult:
        """
        Reduce the sources in the tool message by removing the last source.
        If there is only one source, that source is kept; with no sources the
        message is returned unchanged.
        """
        tool_chunks = self._reference_manager.get_chunks_of_tool(message.tool_call_id)
        num_sources = len(tool_chunks)

        if num_sources == 0:
            return SourceReductionResult(
                message=message,
                reduced_chunks=[],
                chunk_offset=chunk_offset,
                source_offset=source_offset,
            )

        # Keep all chunks but the last one; a single chunk is kept as-is.
        reduced_chunks = tool_chunks if num_sources == 1 else tool_chunks[:-1]
        num_kept = len(reduced_chunks)

        # NOTE(review): assumes the reference manager's global chunk list is
        # ordered like the tool messages in the loop history; confirm.
        content_chunks_reduced = self._reference_manager.get_chunks()[
            chunk_offset : chunk_offset + num_kept
        ]
        if num_sources > 1:
            self._reference_manager.replace_chunks_of_tool(
                message.tool_call_id,
                reduced_chunks,
            )

        # Create new message with reduced sources
        new_message = self._create_tool_call_message_with_reduced_sources(
            message=message,
            content_chunks=reduced_chunks,
            source_offset=source_offset,
        )

        return SourceReductionResult(
            message=new_message,
            reduced_chunks=content_chunks_reduced,
            # The global list still holds all original chunks, so skip them all.
            chunk_offset=chunk_offset + num_sources,
            source_offset=source_offset + num_kept,
        )

    def _create_tool_call_message_with_reduced_sources(
        self,
        message: LanguageModelToolMessage,
        content_chunks: list[ContentChunk] | None = None,
        source_offset: int = 0,
    ) -> LanguageModelToolMessage:
        """Rebuild a tool message from its kept chunks, dispatching on tool type."""
        # Handle special case for TableSearch tool
        if message.name == "TableSearch":
            return self._create_reduced_table_search_message(
                message, content_chunks, source_offset
            )

        # Handle empty content case
        if not content_chunks:
            return self._create_reduced_empty_sources_message(message)

        # Handle standard content chunks
        return self._create_reduced_standard_sources_message(
            message, content_chunks, source_offset
        )

    def _create_reduced_table_search_message(
        self,
        message: LanguageModelToolMessage,
        content_chunks: list[ContentChunk] | None,
        source_offset: int,
    ) -> LanguageModelToolMessage:
        """
        Create a message for TableSearch tool.

        Note: TableSearch content consists of a single result with SQL results,
        not content chunks. Content that is not a JSON object is kept unchanged
        rather than failing the whole model call.

        Raises:
            ValueError: if the message content is neither a string nor a dict.
        """
        if not content_chunks:
            content = message.content
        else:
            if isinstance(message.content, str):
                try:
                    content_dict = json.loads(message.content)
                except json.JSONDecodeError:
                    content_dict = None
            elif isinstance(message.content, dict):
                content_dict = message.content
            else:
                raise ValueError(
                    f"Unexpected content type: {type(message.content)}"
                )

            if isinstance(content_dict, dict):
                content = json.dumps(
                    {
                        "source_number": source_offset,
                        "content": content_dict.get("content"),
                    }
                )
            else:
                self._logger.warning(
                    "TableSearch tool message content is not a JSON object; keeping it unchanged."
                )
                content = message.content

        return LanguageModelToolMessage(
            content=content,
            tool_call_id=message.tool_call_id,
            name=message.name,
        )

    def _create_reduced_empty_sources_message(
        self,
        message: LanguageModelToolMessage,
    ) -> LanguageModelToolMessage:
        """Create a message when no content chunks are available."""
        return LanguageModelToolMessage(
            content="No relevant sources found.",
            tool_call_id=message.tool_call_id,
            name=message.name,
        )

    def _create_reduced_standard_sources_message(
        self,
        message: LanguageModelToolMessage,
        content_chunks: list[ContentChunk],
        source_offset: int,
    ) -> LanguageModelToolMessage:
        """Create a message listing the chunks, numbered from ``source_offset``.

        The content is the Python ``str()`` of a list of dicts, not JSON.
        """
        sources = [
            {
                "source_number": source_offset + i,
                "content": chunk.text,
            }
            for i, chunk in enumerate(content_chunks)
        ]

        return LanguageModelToolMessage(
            content=str(sources),
            tool_call_id=message.tool_call_id,
            name=message.name,
        )
|