unique_toolkit 0.8.14__py3-none-any.whl → 0.8.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unique_toolkit/_common/default_language_model.py +6 -0
- unique_toolkit/_common/token/image_token_counting.py +67 -0
- unique_toolkit/_common/token/token_counting.py +196 -0
- unique_toolkit/evals/config.py +36 -0
- unique_toolkit/evals/context_relevancy/prompts.py +56 -0
- unique_toolkit/evals/context_relevancy/schema.py +88 -0
- unique_toolkit/evals/context_relevancy/service.py +241 -0
- unique_toolkit/evals/hallucination/constants.py +61 -0
- unique_toolkit/evals/hallucination/hallucination_evaluation.py +92 -0
- unique_toolkit/evals/hallucination/prompts.py +79 -0
- unique_toolkit/evals/hallucination/service.py +57 -0
- unique_toolkit/evals/hallucination/utils.py +213 -0
- unique_toolkit/evals/output_parser.py +48 -0
- unique_toolkit/evals/tests/test_context_relevancy_service.py +252 -0
- unique_toolkit/evals/tests/test_output_parser.py +80 -0
- unique_toolkit/history_manager/history_construction_with_contents.py +307 -0
- unique_toolkit/history_manager/history_manager.py +80 -111
- unique_toolkit/history_manager/loop_token_reducer.py +457 -0
- unique_toolkit/language_model/schemas.py +8 -0
- unique_toolkit/reference_manager/reference_manager.py +15 -2
- {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/METADATA +7 -1
- {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/RECORD +24 -7
- {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/LICENSE +0 -0
- {unique_toolkit-0.8.14.dist-info → unique_toolkit-0.8.16.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
from unittest.mock import MagicMock, patch
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from unique_toolkit.app.schemas import ChatEvent
|
|
5
|
+
from unique_toolkit.chat.service import LanguageModelName
|
|
6
|
+
from unique_toolkit.language_model.infos import (
|
|
7
|
+
LanguageModelInfo,
|
|
8
|
+
)
|
|
9
|
+
from unique_toolkit.language_model.schemas import (
|
|
10
|
+
LanguageModelAssistantMessage,
|
|
11
|
+
LanguageModelCompletionChoice,
|
|
12
|
+
LanguageModelMessages,
|
|
13
|
+
)
|
|
14
|
+
from unique_toolkit.language_model.service import LanguageModelResponse
|
|
15
|
+
from unique_toolkit.evals.config import EvaluationMetricConfig
|
|
16
|
+
from unique_toolkit.evals.context_relevancy.prompts import (
|
|
17
|
+
CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG,
|
|
18
|
+
)
|
|
19
|
+
from unique_toolkit.evals.context_relevancy.schema import (
|
|
20
|
+
EvaluationSchemaStructuredOutput,
|
|
21
|
+
)
|
|
22
|
+
from unique_toolkit.evals.context_relevancy.service import (
|
|
23
|
+
ContextRelevancyEvaluator,
|
|
24
|
+
)
|
|
25
|
+
from unique_toolkit.evals.exception import EvaluatorException
|
|
26
|
+
from unique_toolkit.evals.schemas import (
|
|
27
|
+
EvaluationMetricInput,
|
|
28
|
+
EvaluationMetricName,
|
|
29
|
+
EvaluationMetricResult,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@pytest.fixture
def event():
    """Provide a ChatEvent double carrying a user query and tenant identifiers."""
    user_message = MagicMock()
    user_message.text = "Test query"

    mock_event = MagicMock(spec=ChatEvent)
    mock_event.payload = MagicMock()
    mock_event.payload.user_message = user_message
    mock_event.user_id = "user_0"
    mock_event.company_id = "company_0"
    return mock_event
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@pytest.fixture
def evaluator(event):
    """Build the ContextRelevancyEvaluator under test from the mocked event."""
    context_relevancy_evaluator = ContextRelevancyEvaluator(event)
    return context_relevancy_evaluator
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@pytest.fixture
def basic_config():
    """Return an enabled context-relevancy metric config backed by GPT-4o (2024-08-06)."""
    gpt_4o = LanguageModelInfo.from_name(LanguageModelName.AZURE_GPT_4o_2024_0806)
    return EvaluationMetricConfig(
        enabled=True,
        name=EvaluationMetricName.CONTEXT_RELEVANCY,
        language_model=gpt_4o,
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@pytest.fixture
def structured_config(basic_config):
    """Return the config used by the structured-output tests.

    It has the same field values as ``basic_config``. The structured-output
    tests select structured mode through their call arguments rather than
    through this config. The ``basic_config`` fixture is requested but not
    used here.
    """
    return EvaluationMetricConfig(
        enabled=True,
        name=EvaluationMetricName.CONTEXT_RELEVANCY,
        language_model=LanguageModelInfo.from_name(
            LanguageModelName.AZURE_GPT_4o_2024_0806
        ),
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@pytest.fixture
def sample_input():
    """Return a query with two context snippets, shared by most tests."""
    contexts = ["test context 1", "test context 2"]
    return EvaluationMetricInput(input_text="test query", context_texts=contexts)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@pytest.mark.asyncio
async def test_analyze_disabled(evaluator, sample_input, basic_config):
    """A disabled metric short-circuits and yields no result."""
    basic_config.enabled = False

    assert await evaluator.analyze(sample_input, basic_config) is None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@pytest.mark.asyncio
async def test_analyze_empty_context(evaluator, basic_config):
    """Analysis without any context texts is rejected with an EvaluatorException."""
    no_context_input = EvaluationMetricInput(input_text="test query", context_texts=[])

    with pytest.raises(EvaluatorException) as raised:
        await evaluator.analyze(no_context_input, basic_config)

    assert "No context texts provided." in str(raised.value)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@pytest.mark.asyncio
async def test_analyze_regular_output(evaluator, sample_input, basic_config):
    """A plain JSON completion is parsed into an EvaluationMetricResult."""
    assistant_message = LanguageModelAssistantMessage(
        content="""{
                        "value": "high",
                        "reason": "Test reason"
                    }"""
    )
    completion = LanguageModelResponse(
        choices=[
            LanguageModelCompletionChoice(
                index=0, message=assistant_message, finish_reason="stop"
            )
        ]
    )

    mocked_call = patch.object(
        evaluator.language_model_service, "complete_async", return_value=completion
    )
    with mocked_call as complete_async:
        result = await evaluator.analyze(sample_input, basic_config)

    assert isinstance(result, EvaluationMetricResult)
    assert result.value.lower() == "high"
    complete_async.assert_called_once()
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@pytest.mark.asyncio
async def test_analyze_structured_output(evaluator, sample_input, structured_config):
    """When a schema is supplied, the message's ``parsed`` payload drives the result."""
    assistant_message = LanguageModelAssistantMessage(
        content="HIGH",
        parsed={"value": "high", "reason": "Test reason"},
    )
    completion = LanguageModelResponse(
        choices=[
            LanguageModelCompletionChoice(
                index=0, message=assistant_message, finish_reason="stop"
            )
        ]
    )

    mocked_call = patch.object(
        evaluator.language_model_service, "complete_async", return_value=completion
    )
    with mocked_call as complete_async:
        result = await evaluator.analyze(
            sample_input, structured_config, EvaluationSchemaStructuredOutput
        )

    assert isinstance(result, EvaluationMetricResult)
    assert result.value.lower() == "high"
    complete_async.assert_called_once()
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@pytest.mark.asyncio
async def test_analyze_structured_output_validation_error(
    evaluator, sample_input, structured_config
):
    """A parsed payload that does not match the schema raises EvaluatorException."""
    bad_message = LanguageModelAssistantMessage(
        content="HIGH", parsed={"invalid": "data"}
    )
    completion = LanguageModelResponse(
        choices=[
            LanguageModelCompletionChoice(
                index=0, message=bad_message, finish_reason="stop"
            )
        ]
    )

    mocked_call = patch.object(
        evaluator.language_model_service, "complete_async", return_value=completion
    )
    with mocked_call, pytest.raises(EvaluatorException) as raised:
        await evaluator.analyze(
            sample_input, structured_config, EvaluationSchemaStructuredOutput
        )

    assert "Error occurred during structured output validation" in str(raised.value)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@pytest.mark.asyncio
async def test_analyze_regular_output_empty_response(
    evaluator, sample_input, basic_config
):
    """An empty completion text is surfaced as an EvaluatorException."""
    empty_completion = LanguageModelResponse(
        choices=[
            LanguageModelCompletionChoice(
                index=0,
                message=LanguageModelAssistantMessage(content=""),
                finish_reason="stop",
            )
        ]
    )

    mocked_call = patch.object(
        evaluator.language_model_service,
        "complete_async",
        return_value=empty_completion,
    )
    with mocked_call, pytest.raises(EvaluatorException) as raised:
        await evaluator.analyze(sample_input, basic_config)

    assert "did not return a result" in str(raised.value)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def test_compose_msgs_regular(evaluator, sample_input, basic_config):
    """Without structured output, the default system prompt is used.

    The user message must embed the query and every context text.
    """
    composed = evaluator._compose_msgs(
        sample_input, basic_config, enable_structured_output=False
    )

    assert isinstance(composed, LanguageModelMessages)
    system_msg = composed.root[0]
    assert system_msg.content == CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG

    user_msg = composed.root[1]
    assert isinstance(user_msg.content, str)
    for expected in ("test query", "test context 1", "test context 2"):
        assert expected in user_msg.content
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def test_compose_msgs_structured(evaluator, sample_input, structured_config):
    """Structured output swaps in a dedicated system prompt.

    The user message must still embed the query and every context text.
    """
    composed = evaluator._compose_msgs(
        sample_input, structured_config, enable_structured_output=True
    )

    assert isinstance(composed, LanguageModelMessages)
    assert len(composed.root) == 2
    system_msg, user_msg = composed.root

    # A structured-output prompt replaces the default system message.
    assert system_msg.content != CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG
    assert isinstance(user_msg.content, str)
    for expected in ("test query", "test context 1", "test context 2"):
        assert expected in user_msg.content
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
@pytest.mark.asyncio
async def test_analyze_unknown_error(evaluator, sample_input, basic_config):
    """Unexpected failures from the model call are wrapped in EvaluatorException."""
    failing_call = patch.object(
        evaluator.language_model_service,
        "complete_async",
        side_effect=Exception("Unknown error"),
    )
    with failing_call, pytest.raises(EvaluatorException) as raised:
        await evaluator.analyze(sample_input, basic_config)

    expected_message = "Unknown error occurred during context relevancy metric analysis"
    assert expected_message in str(raised.value)
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from unique_toolkit.evals.context_relevancy.schema import EvaluationSchemaStructuredOutput, Fact
|
|
4
|
+
from unique_toolkit.evals.exception import EvaluatorException
|
|
5
|
+
from unique_toolkit.evals.output_parser import parse_eval_metric_result, parse_eval_metric_result_structured_output
|
|
6
|
+
from unique_toolkit.evals.schemas import EvaluationMetricName, EvaluationMetricResult
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def test_parse_eval_metric_result_success():
    """A JSON payload with both value and reason is parsed field by field."""
    raw_json = '{"value": "high", "reason": "Test reason"}'

    parsed = parse_eval_metric_result(raw_json, EvaluationMetricName.CONTEXT_RELEVANCY)

    assert isinstance(parsed, EvaluationMetricResult)
    assert (parsed.name, parsed.value, parsed.reason, parsed.fact_list) == (
        EvaluationMetricName.CONTEXT_RELEVANCY,
        "high",
        "Test reason",
        [],
    )
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_parse_eval_metric_result_missing_fields():
    """A missing ``reason`` falls back to the string "None"."""
    raw_json = '{"value": "high"}'

    parsed = parse_eval_metric_result(raw_json, EvaluationMetricName.CONTEXT_RELEVANCY)

    assert isinstance(parsed, EvaluationMetricResult)
    assert (parsed.name, parsed.value, parsed.reason, parsed.fact_list) == (
        EvaluationMetricName.CONTEXT_RELEVANCY,
        "high",
        "None",
        [],
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_parse_eval_metric_result_invalid_json():
    """Non-JSON input is reported as an EvaluatorException."""
    with pytest.raises(EvaluatorException) as raised:
        parse_eval_metric_result("invalid json", EvaluationMetricName.CONTEXT_RELEVANCY)

    expected_message = "Error occurred during parsing the evaluation metric result"
    assert expected_message in str(raised.value)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def test_parse_eval_metric_result_structured_output_basic():
    """Structured output without facts maps onto a result with an empty fact list."""
    structured = EvaluationSchemaStructuredOutput(value="high", reason="Test reason")

    parsed = parse_eval_metric_result_structured_output(
        structured, EvaluationMetricName.CONTEXT_RELEVANCY
    )

    assert isinstance(parsed, EvaluationMetricResult)
    assert (parsed.name, parsed.value, parsed.reason, parsed.fact_list) == (
        EvaluationMetricName.CONTEXT_RELEVANCY,
        "high",
        "Test reason",
        [],
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_parse_eval_metric_result_structured_output_with_facts():
    """Facts in the structured output are flattened to their text, in order."""
    result = EvaluationSchemaStructuredOutput(
        value="high",
        reason="Test reason",
        fact_list=[
            Fact(fact="Fact 1"),
            Fact(fact="Fact 2"),
        ],
    )

    parsed = parse_eval_metric_result_structured_output(
        result, EvaluationMetricName.CONTEXT_RELEVANCY
    )

    assert isinstance(parsed, EvaluationMetricResult)
    assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
    assert parsed.value == "high"
    assert parsed.reason == "Test reason"
    # Both facts are kept and reduced to their ``fact`` strings. No None fact
    # is supplied here, so this test does not exercise any filtering.
    assert isinstance(parsed.fact_list, list)
    assert parsed.fact_list == ["Fact 1", "Fact 2"]
|
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
import base64
import logging
import mimetypes
from datetime import datetime, timezone
from enum import StrEnum

import numpy as np
import tiktoken
from pydantic import RootModel

from unique_toolkit._common.token.token_counting import (
    num_tokens_per_language_model_message,
)
from unique_toolkit.app import ChatEventUserMessage
from unique_toolkit.chat.schemas import ChatMessage
from unique_toolkit.chat.schemas import ChatMessageRole as ChatRole
from unique_toolkit.chat.service import ChatService
from unique_toolkit.content.schemas import Content
from unique_toolkit.content.service import ContentService
from unique_toolkit.language_model import LanguageModelMessageRole as LLMRole
from unique_toolkit.language_model.infos import EncoderName
from unique_toolkit.language_model.schemas import LanguageModelMessages
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# TODO: Test this once it moves into the unique toolkit

# Translation table from chat-message roles to language-model roles. Only
# user and assistant turns are mapped; indexing it with any other role raises
# KeyError.
map_chat_llm_message_role: dict[ChatRole, LLMRole] = {
    ChatRole.USER: LLMRole.USER,
    ChatRole.ASSISTANT: LLMRole.ASSISTANT,
}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ImageMimeType(StrEnum):
    """MIME types that ``is_image_content`` recognizes as images.

    Members are ``StrEnum`` values, so they compare equal to the plain strings
    returned by ``mimetypes.guess_type``.
    """

    JPEG = "image/jpeg"
    PNG = "image/png"
    GIF = "image/gif"
    BMP = "image/bmp"
    WEBP = "image/webp"
    TIFF = "image/tiff"
    SVG = "image/svg+xml"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class FileMimeType(StrEnum):
|
|
44
|
+
PDF = "application/pdf"
|
|
45
|
+
DOCX = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
|
46
|
+
DOC = "application/msword"
|
|
47
|
+
XLSX = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
|
48
|
+
XLS = "application/vnd.ms-excel"
|
|
49
|
+
PPTX = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
|
50
|
+
CSV = "text/csv"
|
|
51
|
+
HTML = "text/html"
|
|
52
|
+
MD = "text/markdown"
|
|
53
|
+
TXT = "text/plain"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class ChatMessageWithContents(ChatMessage):
    """A chat message together with the contents (uploads) grouped onto it."""

    # Contents attached to this message; empty when nothing was attached.
    contents: list[Content] = []
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _chronological_sort_key(element: ChatMessage | Content) -> datetime:
    """Return a timezone-aware sort key for a chat message or a content.

    Missing timestamps sort first. Naive timestamps are interpreted as UTC so
    they can be ordered against timezone-aware ones; comparing naive and aware
    datetimes directly raises ``TypeError``.
    """
    created_at = element.created_at
    if created_at is None:
        return datetime.min.replace(tzinfo=timezone.utc)
    if created_at.tzinfo is None:
        return created_at.replace(tzinfo=timezone.utc)
    return created_at


class ChatHistoryWithContent(RootModel):
    """Chat history in which each message carries the contents preceding it."""

    root: list[ChatMessageWithContents]

    @classmethod
    def from_chat_history_and_contents(
        cls,
        chat_history: list[ChatMessage],
        chat_contents: list[Content],
    ):
        """Interleave contents with chat messages by ``created_at``.

        Each content is attached to the first chat message that follows it
        chronologically. Contents are placed before messages in the combined
        list and the sort is stable, so a content sharing a message's
        timestamp is attached to that message. Contents created after the
        last chat message have no following message and are not included.

        Args:
            chat_history: Chat messages in any order.
            chat_contents: Uploaded contents belonging to the chat.

        Returns:
            A history of ``ChatMessageWithContents`` in chronological order.
        """
        combined = chat_contents + chat_history
        # The sort key tolerates None and mixed naive/aware timestamps.
        combined.sort(key=_chronological_sort_key)

        grouped_elements = []
        content_container = []

        # Content is collected and added to the next chat message
        for element in combined:
            if isinstance(element, ChatMessage):
                grouped_elements.append(
                    ChatMessageWithContents(
                        contents=content_container.copy(),
                        **element.model_dump(),
                    ),
                )
                content_container.clear()
            else:
                content_container.append(element)

        return cls(root=grouped_elements)

    def __iter__(self):
        return iter(self.root)

    def __getitem__(self, item):
        return self.root[item]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def is_image_content(filename: str) -> bool:
    """Return True when the MIME type guessed from *filename* is a known image type.

    Filenames whose type ``mimetypes`` cannot guess are treated as non-images.
    """
    guessed_type = mimetypes.guess_type(filename)[0]
    return bool(guessed_type) and guessed_type in ImageMimeType.__members__.values()
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def is_file_content(filename: str) -> bool:
    """Return True when the MIME type guessed from *filename* is a known document type.

    Filenames whose type ``mimetypes`` cannot guess are treated as non-files.
    """
    guessed_type = mimetypes.guess_type(filename)[0]
    return bool(guessed_type) and guessed_type in FileMimeType.__members__.values()
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def get_chat_history_with_contents(
    user_message: ChatEventUserMessage,
    chat_id: str,
    chat_history: list[ChatMessage],
    content_service: ContentService,
) -> ChatHistoryWithContent:
    """Build the chat history, including the current user message, with contents.

    The current user message is appended unless it is already the last entry
    of *chat_history* (matched by id). All contents owned by the chat are
    fetched and grouped onto the messages that follow them.

    Args:
        user_message: The user message of the current event.
        chat_id: Id of the chat; also used as the content owner id.
        chat_history: Previous chat messages. This list is not modified.
        content_service: Service used to search the chat's contents.

    Returns:
        The chronologically grouped history.
    """
    last_user_message = ChatMessage(
        id=user_message.id,
        chat_id=chat_id,
        text=user_message.text,
        originalText=user_message.original_text,
        role=ChatRole.USER,
        gpt_request=None,
        created_at=datetime.fromisoformat(user_message.created_at),
    )

    # Copy before appending so the caller's list is left untouched.
    history = list(chat_history)
    if not history or history[-1].id != last_user_message.id:
        history.append(last_user_message)

    chat_contents = content_service.search_contents(
        where={
            "ownerId": {
                "equals": chat_id,
            },
        },
    )

    return ChatHistoryWithContent.from_chat_history_and_contents(
        history,
        chat_contents,
    )
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def download_encoded_images(
    contents: list[Content],
    content_service: ContentService,
    chat_id: str,
) -> list[str]:
    """Download the image contents and return them as base64 ``data:`` URLs.

    Non-image contents are skipped. This is best effort: an image that fails
    to download or encode is logged with its traceback and skipped, so the
    result may be shorter than the number of image contents.

    Args:
        contents: Contents to consider; only those recognized as images are used.
        content_service: Service used to download the raw bytes.
        chat_id: Chat the contents belong to.

    Returns:
        One ``data:<mime>;base64,<payload>`` string per successfully processed image.
    """
    logger = logging.getLogger(__name__)
    base64_encoded_images: list[str] = []
    for content in contents:
        if not is_image_content(content.key):
            continue
        try:
            file_bytes = content_service.download_content_to_bytes(
                content_id=content.id,
                chat_id=chat_id,
            )
            mime_type, _ = mimetypes.guess_type(content.key)
            encoded_string = base64.b64encode(file_bytes).decode("utf-8")
        except Exception:
            # Previously print(e); keep best-effort but make failures visible.
            logger.exception("Failed to download or encode image content %s", content.id)
            continue
        base64_encoded_images.append(f"data:{mime_type};base64," + encoded_string)
    return base64_encoded_images
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class FileContentSerialization(StrEnum):
    """How uploaded file contents are rendered into a message's text."""

    # Do not mention uploaded files at all.
    NONE = "none"
    # List each uploaded file by key and creation time.
    FILE_NAME = "file_name"
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class ImageContentInclusion(StrEnum):
    """Whether image contents are attached to history messages."""

    # NOTE: as a StrEnum, NONE is the non-empty string "none" and therefore
    # truthy; compare against the member instead of testing truthiness.
    NONE = "none"
    ALL = "all"
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def file_content_serialization(
|
|
183
|
+
file_contents: list[Content],
|
|
184
|
+
file_content_serialization: FileContentSerialization,
|
|
185
|
+
) -> str:
|
|
186
|
+
match file_content_serialization:
|
|
187
|
+
case FileContentSerialization.NONE:
|
|
188
|
+
return ""
|
|
189
|
+
case FileContentSerialization.FILE_NAME:
|
|
190
|
+
file_names = [
|
|
191
|
+
f"- Uploaded file: {f.key} at {f.created_at}"
|
|
192
|
+
for f in file_contents
|
|
193
|
+
]
|
|
194
|
+
return "\n".join(
|
|
195
|
+
[
|
|
196
|
+
"Files Uploaded to Chat can be accessed by internal search tool if available:\n",
|
|
197
|
+
]
|
|
198
|
+
+ file_names,
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def get_full_history_with_contents(
|
|
203
|
+
user_message: ChatEventUserMessage,
|
|
204
|
+
chat_id: str,
|
|
205
|
+
chat_service: ChatService,
|
|
206
|
+
content_service: ContentService,
|
|
207
|
+
include_images: ImageContentInclusion = ImageContentInclusion.ALL,
|
|
208
|
+
file_content_serialization_type: FileContentSerialization = FileContentSerialization.FILE_NAME,
|
|
209
|
+
) -> LanguageModelMessages:
|
|
210
|
+
grouped_elements = get_chat_history_with_contents(
|
|
211
|
+
user_message=user_message,
|
|
212
|
+
chat_id=chat_id,
|
|
213
|
+
chat_history=chat_service.get_full_history(),
|
|
214
|
+
content_service=content_service,
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
builder = LanguageModelMessages([]).builder()
|
|
218
|
+
for c in grouped_elements:
|
|
219
|
+
# LanguageModelUserMessage has not field original content
|
|
220
|
+
text = c.original_content if c.original_content else c.content
|
|
221
|
+
if text is None:
|
|
222
|
+
if c.role == ChatRole.USER:
|
|
223
|
+
raise ValueError(
|
|
224
|
+
"Content or original_content of LanguageModelMessages should exist.",
|
|
225
|
+
)
|
|
226
|
+
text = ""
|
|
227
|
+
|
|
228
|
+
if len(c.contents) > 0:
|
|
229
|
+
file_contents = [
|
|
230
|
+
co for co in c.contents if is_file_content(co.key)
|
|
231
|
+
]
|
|
232
|
+
image_contents = [
|
|
233
|
+
co for co in c.contents if is_image_content(co.key)
|
|
234
|
+
]
|
|
235
|
+
|
|
236
|
+
content = (
|
|
237
|
+
text
|
|
238
|
+
+ "\n\n"
|
|
239
|
+
+ file_content_serialization(
|
|
240
|
+
file_contents,
|
|
241
|
+
file_content_serialization_type,
|
|
242
|
+
)
|
|
243
|
+
)
|
|
244
|
+
content = content.strip()
|
|
245
|
+
|
|
246
|
+
if include_images and len(image_contents) > 0:
|
|
247
|
+
builder.image_message_append(
|
|
248
|
+
content=content,
|
|
249
|
+
images=download_encoded_images(
|
|
250
|
+
contents=image_contents,
|
|
251
|
+
content_service=content_service,
|
|
252
|
+
chat_id=chat_id,
|
|
253
|
+
),
|
|
254
|
+
role=map_chat_llm_message_role[c.role],
|
|
255
|
+
)
|
|
256
|
+
else:
|
|
257
|
+
builder.message_append(
|
|
258
|
+
role=map_chat_llm_message_role[c.role],
|
|
259
|
+
content=content,
|
|
260
|
+
)
|
|
261
|
+
else:
|
|
262
|
+
builder.message_append(
|
|
263
|
+
role=map_chat_llm_message_role[c.role],
|
|
264
|
+
content=text,
|
|
265
|
+
)
|
|
266
|
+
return builder.build()
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def get_full_history_as_llm_messages(
    chat_service: ChatService,
) -> LanguageModelMessages:
    """Convert the chat's full history into text-only language-model messages.

    Messages without content are sent as empty strings. Roles are translated
    with the module-level ``map_chat_llm_message_role`` table instead of a
    shadowing local copy.

    Args:
        chat_service: Source of the full chat history.

    Returns:
        The history as ``LanguageModelMessages``.
    """
    builder = LanguageModelMessages([]).builder()
    for message in chat_service.get_full_history():
        builder.message_append(
            role=map_chat_llm_message_role[message.role],
            content=message.content or "",
        )
    return builder.build()
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def limit_to_token_window(
    messages: LanguageModelMessages,
    token_limit: int,
    encoding_name: EncoderName = EncoderName.O200K_BASE,
) -> LanguageModelMessages:
    """Keep only the messages whose running token total stays below *token_limit*.

    Per-message token counts come from ``num_tokens_per_language_model_message``,
    using the tiktoken encoding *encoding_name*. The cumulative sum of those
    counts is compared strictly against the limit. The resulting keep-mask is
    then reversed before it is matched against *messages*.

    NOTE(review): the reversal assumes the helper returns counts newest-first
    (the original local name was ``token_per_message_reversed``). If the
    counts are in message order, the mask is misaligned and the wrong messages
    are kept. Confirm against ``num_tokens_per_language_model_message``.
    """
    tokenizer = tiktoken.get_encoding(encoding_name)
    token_counts = num_tokens_per_language_model_message(
        messages,
        encode=tokenizer.encode,
    )

    within_budget = np.cumsum(token_counts) < token_limit
    keep_mask: list[bool] = within_budget.tolist()[::-1]

    kept = [
        message
        for message, keep in zip(messages, keep_mask, strict=False)
        if keep
    ]
    return LanguageModelMessages(root=kept)
|