unique_toolkit 0.8.14__py3-none-any.whl → 0.8.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. unique_toolkit/_common/default_language_model.py +6 -0
  2. unique_toolkit/_common/token/image_token_counting.py +67 -0
  3. unique_toolkit/_common/token/token_counting.py +196 -0
  4. unique_toolkit/evals/config.py +36 -0
  5. unique_toolkit/evals/context_relevancy/prompts.py +56 -0
  6. unique_toolkit/evals/context_relevancy/schema.py +88 -0
  7. unique_toolkit/evals/context_relevancy/service.py +241 -0
  8. unique_toolkit/evals/hallucination/constants.py +61 -0
  9. unique_toolkit/evals/hallucination/hallucination_evaluation.py +92 -0
  10. unique_toolkit/evals/hallucination/prompts.py +79 -0
  11. unique_toolkit/evals/hallucination/service.py +57 -0
  12. unique_toolkit/evals/hallucination/utils.py +213 -0
  13. unique_toolkit/evals/output_parser.py +48 -0
  14. unique_toolkit/evals/tests/test_context_relevancy_service.py +252 -0
  15. unique_toolkit/evals/tests/test_output_parser.py +80 -0
  16. unique_toolkit/history_manager/history_construction_with_contents.py +307 -0
  17. unique_toolkit/history_manager/history_manager.py +80 -111
  18. unique_toolkit/history_manager/loop_token_reducer.py +457 -0
  19. unique_toolkit/language_model/schemas.py +8 -0
  20. unique_toolkit/reference_manager/reference_manager.py +15 -2
  21. {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/METADATA +7 -1
  22. {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/RECORD +24 -7
  23. {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/LICENSE +0 -0
  24. {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/WHEEL +0 -0

unique_toolkit/evals/tests/test_context_relevancy_service.py
@@ -0,0 +1,252 @@
+ from unittest.mock import MagicMock, patch
+
+ import pytest
+ from unique_toolkit.app.schemas import ChatEvent
+ from unique_toolkit.chat.service import LanguageModelName
+ from unique_toolkit.language_model.infos import (
+     LanguageModelInfo,
+ )
+ from unique_toolkit.language_model.schemas import (
+     LanguageModelAssistantMessage,
+     LanguageModelCompletionChoice,
+     LanguageModelMessages,
+ )
+ from unique_toolkit.language_model.service import LanguageModelResponse
+ from unique_toolkit.evals.config import EvaluationMetricConfig
+ from unique_toolkit.evals.context_relevancy.prompts import (
+     CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG,
+ )
+ from unique_toolkit.evals.context_relevancy.schema import (
+     EvaluationSchemaStructuredOutput,
+ )
+ from unique_toolkit.evals.context_relevancy.service import (
+     ContextRelevancyEvaluator,
+ )
+ from unique_toolkit.evals.exception import EvaluatorException
+ from unique_toolkit.evals.schemas import (
+     EvaluationMetricInput,
+     EvaluationMetricName,
+     EvaluationMetricResult,
+ )
+
+
+ @pytest.fixture
+ def event():
+     event = MagicMock(spec=ChatEvent)
+     event.payload = MagicMock()
+     event.payload.user_message = MagicMock()
+     event.payload.user_message.text = "Test query"
+     event.user_id = "user_0"
+     event.company_id = "company_0"
+     return event
+
+
+ @pytest.fixture
+ def evaluator(event):
+     return ContextRelevancyEvaluator(event)
+
+
+ @pytest.fixture
+ def basic_config():
+     return EvaluationMetricConfig(
+         enabled=True,
+         name=EvaluationMetricName.CONTEXT_RELEVANCY,
+         language_model=LanguageModelInfo.from_name(
+             LanguageModelName.AZURE_GPT_4o_2024_0806
+         ),
+     )
+
+
+ @pytest.fixture
+ def structured_config(basic_config):
+     model_info = LanguageModelInfo.from_name(LanguageModelName.AZURE_GPT_4o_2024_0806)
+     return EvaluationMetricConfig(
+         enabled=True,
+         name=EvaluationMetricName.CONTEXT_RELEVANCY,
+         language_model=model_info,
+     )
+
+
+ @pytest.fixture
+ def sample_input():
+     return EvaluationMetricInput(
+         input_text="test query",
+         context_texts=["test context 1", "test context 2"],
+     )
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_disabled(evaluator, sample_input, basic_config):
+     basic_config.enabled = False
+     result = await evaluator.analyze(sample_input, basic_config)
+     assert result is None
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_empty_context(evaluator, basic_config):
+     input_with_empty_context = EvaluationMetricInput(
+         input_text="test query", context_texts=[]
+     )
+
+     with pytest.raises(EvaluatorException) as exc_info:
+         await evaluator.analyze(input_with_empty_context, basic_config)
+
+     assert "No context texts provided." in str(exc_info.value)
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_regular_output(evaluator, sample_input, basic_config):
+     mock_result = LanguageModelResponse(
+         choices=[
+             LanguageModelCompletionChoice(
+                 index=0,
+                 message=LanguageModelAssistantMessage(
+                     content="""{
+                         "value": "high",
+                         "reason": "Test reason"
+                     }"""
+                 ),
+                 finish_reason="stop",
+             )
+         ]
+     )
+
+     with patch.object(
+         evaluator.language_model_service,
+         "complete_async",
+         return_value=mock_result,
+     ) as mock_complete:
+         result = await evaluator.analyze(sample_input, basic_config)
+
+         assert isinstance(result, EvaluationMetricResult)
+         assert result.value.lower() == "high"
+         mock_complete.assert_called_once()
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_structured_output(evaluator, sample_input, structured_config):
+     mock_result = LanguageModelResponse(
+         choices=[
+             LanguageModelCompletionChoice(
+                 index=0,
+                 message=LanguageModelAssistantMessage(
+                     content="HIGH",
+                     parsed={"value": "high", "reason": "Test reason"},
+                 ),
+                 finish_reason="stop",
+             )
+         ]
+     )
+
+     structured_output_schema = EvaluationSchemaStructuredOutput
+
+     with patch.object(
+         evaluator.language_model_service,
+         "complete_async",
+         return_value=mock_result,
+     ) as mock_complete:
+         result = await evaluator.analyze(
+             sample_input, structured_config, structured_output_schema
+         )
+         assert isinstance(result, EvaluationMetricResult)
+         assert result.value.lower() == "high"
+         mock_complete.assert_called_once()
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_structured_output_validation_error(
+     evaluator, sample_input, structured_config
+ ):
+     mock_result = LanguageModelResponse(
+         choices=[
+             LanguageModelCompletionChoice(
+                 index=0,
+                 message=LanguageModelAssistantMessage(
+                     content="HIGH", parsed={"invalid": "data"}
+                 ),
+                 finish_reason="stop",
+             )
+         ]
+     )
+
+     structured_output_schema = EvaluationSchemaStructuredOutput
+
+     with patch.object(
+         evaluator.language_model_service,
+         "complete_async",
+         return_value=mock_result,
+     ):
+         with pytest.raises(EvaluatorException) as exc_info:
+             await evaluator.analyze(
+                 sample_input, structured_config, structured_output_schema
+             )
+         assert "Error occurred during structured output validation" in str(
+             exc_info.value
+         )
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_regular_output_empty_response(
+     evaluator, sample_input, basic_config
+ ):
+     mock_result = LanguageModelResponse(
+         choices=[
+             LanguageModelCompletionChoice(
+                 index=0,
+                 message=LanguageModelAssistantMessage(content=""),
+                 finish_reason="stop",
+             )
+         ]
+     )
+
+     with patch.object(
+         evaluator.language_model_service,
+         "complete_async",
+         return_value=mock_result,
+     ):
+         with pytest.raises(EvaluatorException) as exc_info:
+             await evaluator.analyze(sample_input, basic_config)
+         assert "did not return a result" in str(exc_info.value)
+
+
+ def test_compose_msgs_regular(evaluator, sample_input, basic_config):
+     messages = evaluator._compose_msgs(
+         sample_input, basic_config, enable_structured_output=False
+     )
+
+     assert isinstance(messages, LanguageModelMessages)
+     assert messages.root[0].content == CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG
+     assert isinstance(messages.root[1].content, str)
+     assert "test query" in messages.root[1].content
+     assert "test context 1" in messages.root[1].content
+     assert "test context 2" in messages.root[1].content
+
+
+ def test_compose_msgs_structured(evaluator, sample_input, structured_config):
+     messages = evaluator._compose_msgs(
+         sample_input, structured_config, enable_structured_output=True
+     )
+
+     assert isinstance(messages, LanguageModelMessages)
+     assert len(messages.root) == 2
+     assert (
+         messages.root[0].content != CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG
+     )  # Should use structured output prompt
+     assert isinstance(messages.root[1].content, str)
+     assert "test query" in messages.root[1].content
+     assert "test context 1" in messages.root[1].content
+     assert "test context 2" in messages.root[1].content
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_unknown_error(evaluator, sample_input, basic_config):
+     with patch.object(
+         evaluator.language_model_service,
+         "complete_async",
+         side_effect=Exception("Unknown error"),
+     ):
+         with pytest.raises(EvaluatorException) as exc_info:
+             await evaluator.analyze(sample_input, basic_config)
+         assert "Unknown error occurred during context relevancy metric analysis" in str(
+             exc_info.value
+         )
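
The tests above drive the new ContextRelevancyEvaluator through its analyze() API with both plain JSON and structured output. Purely as orientation, here is a minimal usage sketch inferred from those fixtures; the ChatEvent wiring and the example query/context strings are assumptions, not part of this diff:

from unique_toolkit.chat.service import LanguageModelName
from unique_toolkit.evals.config import EvaluationMetricConfig
from unique_toolkit.evals.context_relevancy.service import ContextRelevancyEvaluator
from unique_toolkit.evals.schemas import EvaluationMetricInput, EvaluationMetricName
from unique_toolkit.language_model.infos import LanguageModelInfo


async def score_context_relevancy(event):
    # `event` is a ChatEvent provided by the surrounding app; its construction is assumed here.
    evaluator = ContextRelevancyEvaluator(event)
    config = EvaluationMetricConfig(
        enabled=True,
        name=EvaluationMetricName.CONTEXT_RELEVANCY,
        language_model=LanguageModelInfo.from_name(
            LanguageModelName.AZURE_GPT_4o_2024_0806
        ),
    )
    metric_input = EvaluationMetricInput(
        input_text="What does the travel policy cover?",  # hypothetical query
        context_texts=["Employees may book economy-class flights."],  # hypothetical context
    )
    # Per the tests, analyze() yields an EvaluationMetricResult (value + reason),
    # or None when the metric is disabled in the config.
    return await evaluator.analyze(metric_input, config)
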

unique_toolkit/evals/tests/test_output_parser.py
@@ -0,0 +1,80 @@
+ import pytest
+
+ from unique_toolkit.evals.context_relevancy.schema import EvaluationSchemaStructuredOutput, Fact
+ from unique_toolkit.evals.exception import EvaluatorException
+ from unique_toolkit.evals.output_parser import parse_eval_metric_result, parse_eval_metric_result_structured_output
+ from unique_toolkit.evals.schemas import EvaluationMetricName, EvaluationMetricResult
+
+
+
+
+ def test_parse_eval_metric_result_success():
+     # Test successful parsing with all fields
+     result = '{"value": "high", "reason": "Test reason"}'
+     parsed = parse_eval_metric_result(result, EvaluationMetricName.CONTEXT_RELEVANCY)
+
+     assert isinstance(parsed, EvaluationMetricResult)
+     assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+     assert parsed.value == "high"
+     assert parsed.reason == "Test reason"
+     assert parsed.fact_list == []
+
+
+ def test_parse_eval_metric_result_missing_fields():
+     # Test parsing with missing fields (should use default "None")
+     result = '{"value": "high"}'
+     parsed = parse_eval_metric_result(result, EvaluationMetricName.CONTEXT_RELEVANCY)
+
+     assert isinstance(parsed, EvaluationMetricResult)
+     assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+     assert parsed.value == "high"
+     assert parsed.reason == "None"
+     assert parsed.fact_list == []
+
+
+ def test_parse_eval_metric_result_invalid_json():
+     # Test parsing with invalid JSON
+     result = "invalid json"
+     with pytest.raises(EvaluatorException) as exc_info:
+         parse_eval_metric_result(result, EvaluationMetricName.CONTEXT_RELEVANCY)
+
+     assert "Error occurred during parsing the evaluation metric result" in str(
+         exc_info.value
+     )
+
+
+ def test_parse_eval_metric_result_structured_output_basic():
+     # Test basic structured output without fact list
+     result = EvaluationSchemaStructuredOutput(value="high", reason="Test reason")
+     parsed = parse_eval_metric_result_structured_output(
+         result, EvaluationMetricName.CONTEXT_RELEVANCY
+     )
+
+     assert isinstance(parsed, EvaluationMetricResult)
+     assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+     assert parsed.value == "high"
+     assert parsed.reason == "Test reason"
+     assert parsed.fact_list == []
+
+
+ def test_parse_eval_metric_result_structured_output_with_facts():
+     # Test structured output with fact list
+     result = EvaluationSchemaStructuredOutput(
+         value="high",
+         reason="Test reason",
+         fact_list=[
+             Fact(fact="Fact 1"),
+             Fact(fact="Fact 2"),
+         ],
+     )
+     parsed = parse_eval_metric_result_structured_output(
+         result, EvaluationMetricName.CONTEXT_RELEVANCY
+     )
+
+     assert isinstance(parsed, EvaluationMetricResult)
+     assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+     assert parsed.value == "high"
+     assert parsed.reason == "Test reason"
+     assert parsed.fact_list == ["Fact 1", "Fact 2"]
+     assert isinstance(parsed.fact_list, list)
+     assert len(parsed.fact_list) == 2  # None fact should be filtered out

unique_toolkit/history_manager/history_construction_with_contents.py
@@ -0,0 +1,307 @@
+ import base64
+ import mimetypes
+
+ from datetime import datetime
+ from enum import StrEnum
+
+ import numpy as np
+ import tiktoken
+
+ from pydantic import RootModel
+
+ from _common.token.token_counting import num_tokens_per_language_model_message
+ from chat.service import ChatService
+ from content.service import ContentService
+ from language_model.schemas import LanguageModelMessages
+ from unique_toolkit.app import ChatEventUserMessage
+ from unique_toolkit.chat.schemas import ChatMessage
+ from unique_toolkit.chat.schemas import ChatMessageRole as ChatRole
+ from unique_toolkit.content.schemas import Content
+ from unique_toolkit.language_model import LanguageModelMessageRole as LLMRole
+ from unique_toolkit.language_model.infos import EncoderName
+
+
+
+ # TODO: Test this once it moves into the unique toolkit
+
+ map_chat_llm_message_role = {
+     ChatRole.USER: LLMRole.USER,
+     ChatRole.ASSISTANT: LLMRole.ASSISTANT,
+ }
+
+
+ class ImageMimeType(StrEnum):
+     JPEG = "image/jpeg"
+     PNG = "image/png"
+     GIF = "image/gif"
+     BMP = "image/bmp"
+     WEBP = "image/webp"
+     TIFF = "image/tiff"
+     SVG = "image/svg+xml"
+
+
+ class FileMimeType(StrEnum):
+     PDF = "application/pdf"
+     DOCX = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+     DOC = "application/msword"
+     XLSX = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+     XLS = "application/vnd.ms-excel"
+     PPTX = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
+     CSV = "text/csv"
+     HTML = "text/html"
+     MD = "text/markdown"
+     TXT = "text/plain"
+
+
+ class ChatMessageWithContents(ChatMessage):
+     contents: list[Content] = []
+
+
+ class ChatHistoryWithContent(RootModel):
+     root: list[ChatMessageWithContents]
+
+     @classmethod
+     def from_chat_history_and_contents(
+         cls,
+         chat_history: list[ChatMessage],
+         chat_contents: list[Content],
+     ):
+         combined = chat_contents + chat_history
+         combined.sort(key=lambda x: x.created_at or datetime.min)
+
+         grouped_elements = []
+         content_container = []
+
+         # Content is collected and added to the next chat message
+         for c in combined:
+             if isinstance(c, ChatMessage):
+                 grouped_elements.append(
+                     ChatMessageWithContents(
+                         contents=content_container.copy(),
+                         **c.model_dump(),
+                     ),
+                 )
+                 content_container.clear()
+             else:
+                 content_container.append(c)
+
+         return cls(root=grouped_elements)
+
+     def __iter__(self):
+         return iter(self.root)
+
+     def __getitem__(self, item):
+         return self.root[item]
+
+
+ def is_image_content(filename: str) -> bool:
+     mimetype, _ = mimetypes.guess_type(filename)
+
+     if not mimetype:
+         return False
+
+     return mimetype in ImageMimeType.__members__.values()
+
+
+ def is_file_content(filename: str) -> bool:
+     mimetype, _ = mimetypes.guess_type(filename)
+
+     if not mimetype:
+         return False
+
+     return mimetype in FileMimeType.__members__.values()
+
+
+ def get_chat_history_with_contents(
+     user_message: ChatEventUserMessage,
+     chat_id: str,
+     chat_history: list[ChatMessage],
+     content_service: ContentService,
+ ) -> ChatHistoryWithContent:
+     last_user_message = ChatMessage(
+         id=user_message.id,
+         chat_id=chat_id,
+         text=user_message.text,
+         originalText=user_message.original_text,
+         role=ChatRole.USER,
+         gpt_request=None,
+         created_at=datetime.fromisoformat(user_message.created_at),
+     )
+     if len(chat_history) > 0 and last_user_message.id == chat_history[-1].id:
+         pass
+     else:
+         chat_history.append(last_user_message)
+
+     chat_contents = content_service.search_contents(
+         where={
+             "ownerId": {
+                 "equals": chat_id,
+             },
+         },
+     )
+
+     return ChatHistoryWithContent.from_chat_history_and_contents(
+         chat_history,
+         chat_contents,
+     )
+
+
+ def download_encoded_images(
+     contents: list[Content],
+     content_service: ContentService,
+     chat_id: str,
+ ) -> list[str]:
+     base64_encoded_images = []
+     for im in contents:
+         if is_image_content(im.key):
+             try:
+                 file_bytes = content_service.download_content_to_bytes(
+                     content_id=im.id,
+                     chat_id=chat_id,
+                 )
+
+                 mime_type, _ = mimetypes.guess_type(im.key)
+                 encoded_string = base64.b64encode(file_bytes).decode("utf-8")
+                 image_string = f"data:{mime_type};base64," + encoded_string
+                 base64_encoded_images.append(image_string)
+             except Exception as e:
+                 print(e)
+     return base64_encoded_images
+
+
+ class FileContentSerialization(StrEnum):
+     NONE = "none"
+     FILE_NAME = "file_name"
+
+
+ class ImageContentInclusion(StrEnum):
+     NONE = "none"
+     ALL = "all"
+
+
+ def file_content_serialization(
+     file_contents: list[Content],
+     file_content_serialization: FileContentSerialization,
+ ) -> str:
+     match file_content_serialization:
+         case FileContentSerialization.NONE:
+             return ""
+         case FileContentSerialization.FILE_NAME:
+             file_names = [
+                 f"- Uploaded file: {f.key} at {f.created_at}"
+                 for f in file_contents
+             ]
+             return "\n".join(
+                 [
+                     "Files Uploaded to Chat can be accessed by internal search tool if available:\n",
+                 ]
+                 + file_names,
+             )
+
+
+ def get_full_history_with_contents(
+     user_message: ChatEventUserMessage,
+     chat_id: str,
+     chat_service: ChatService,
+     content_service: ContentService,
+     include_images: ImageContentInclusion = ImageContentInclusion.ALL,
+     file_content_serialization_type: FileContentSerialization = FileContentSerialization.FILE_NAME,
+ ) -> LanguageModelMessages:
+     grouped_elements = get_chat_history_with_contents(
+         user_message=user_message,
+         chat_id=chat_id,
+         chat_history=chat_service.get_full_history(),
+         content_service=content_service,
+     )
+
+     builder = LanguageModelMessages([]).builder()
+     for c in grouped_elements:
+         # LanguageModelUserMessage has no field original_content
+         text = c.original_content if c.original_content else c.content
+         if text is None:
+             if c.role == ChatRole.USER:
+                 raise ValueError(
+                     "Content or original_content of LanguageModelMessages should exist.",
+                 )
+             text = ""
+
+         if len(c.contents) > 0:
+             file_contents = [
+                 co for co in c.contents if is_file_content(co.key)
+             ]
+             image_contents = [
+                 co for co in c.contents if is_image_content(co.key)
+             ]
+
+             content = (
+                 text
+                 + "\n\n"
+                 + file_content_serialization(
+                     file_contents,
+                     file_content_serialization_type,
+                 )
+             )
+             content = content.strip()
+
+             if include_images and len(image_contents) > 0:
+                 builder.image_message_append(
+                     content=content,
+                     images=download_encoded_images(
+                         contents=image_contents,
+                         content_service=content_service,
+                         chat_id=chat_id,
+                     ),
+                     role=map_chat_llm_message_role[c.role],
+                 )
+             else:
+                 builder.message_append(
+                     role=map_chat_llm_message_role[c.role],
+                     content=content,
+                 )
+         else:
+             builder.message_append(
+                 role=map_chat_llm_message_role[c.role],
+                 content=text,
+             )
+     return builder.build()
+
+
+ def get_full_history_as_llm_messages(
+     chat_service: ChatService,
+ ) -> LanguageModelMessages:
+     chat_history = chat_service.get_full_history()
+
+     map_chat_llm_message_role = {
+         ChatRole.USER: LLMRole.USER,
+         ChatRole.ASSISTANT: LLMRole.ASSISTANT,
+     }
+
+     builder = LanguageModelMessages([]).builder()
+     for c in chat_history:
+         builder.message_append(
+             role=map_chat_llm_message_role[c.role],
+             content=c.content or "",
+         )
+     return builder.build()
+
+
+
+ def limit_to_token_window(
+     messages: LanguageModelMessages,
+     token_limit: int,
+     encoding_name: EncoderName = EncoderName.O200K_BASE,
+ ) -> LanguageModelMessages:
+     encoder = tiktoken.get_encoding(encoding_name)
+     token_per_message_reversed = num_tokens_per_language_model_message(
+         messages,
+         encode=encoder.encode,
+     )
+
+     to_take: list[bool] = (
+         np.cumsum(token_per_message_reversed) < token_limit
+     ).tolist()
+     to_take.reverse()
+
+     return LanguageModelMessages(
+         root=[m for m, tt in zip(messages, to_take, strict=False) if tt],
+     )
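
For orientation only: limit_to_token_window appears to keep the most recent messages whose cumulative token count (counted in reverse chronological order) stays below the limit. A minimal sketch of combining it with get_full_history_as_llm_messages, assuming a configured ChatService instance and that the module is importable under the path listed above (unique_toolkit.history_manager.history_construction_with_contents); the budget value is an assumption:

from unique_toolkit.history_manager.history_construction_with_contents import (
    get_full_history_as_llm_messages,
    limit_to_token_window,
)


def build_trimmed_history(chat_service):
    # chat_service: a configured ChatService instance (construction not shown in this diff).
    messages = get_full_history_as_llm_messages(chat_service)
    # Keep only the most recent messages that fit into an (assumed) 8k-token budget.
    return limit_to_token_window(messages, token_limit=8_000)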