unique_toolkit 0.8.15__py3-none-any.whl → 0.8.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,57 @@
+ import logging
+
+ from unique_toolkit.app.schemas import ChatEvent
+ from unique_toolkit.evals.config import EvaluationMetricConfig
+ from unique_toolkit.evals.schemas import EvaluationMetricInput, EvaluationMetricResult
+
+
+ from .constants import hallucination_metric_default_config
+ from .utils import check_hallucination
+
+ SYSTEM_MSG_KEY = "systemPrompt"
+ USER_MSG_KEY = "userPrompt"
+ SYSTEM_MSG_DEFAULT_KEY = "systemPromptDefault"
+ USER_MSG_DEFAULT_KEY = "userPromptDefault"
+
+
+ class HallucinationEvaluator:
+     def __init__(self, event: ChatEvent):
+         self.event = event
+
+         self.logger = logging.getLogger(f"HallucinationEvaluator.{__name__}")
+
+     async def analyze(
+         self,
+         input: EvaluationMetricInput,
+         config: EvaluationMetricConfig = hallucination_metric_default_config,
+     ) -> EvaluationMetricResult | None:
+         """
+         Analyzes the level of hallucination in the generated output by comparing it with the input
+         and the provided contexts or history. The analysis classifies the hallucination level as:
+         - low
+         - medium
+         - high
+
+         If no contexts or history are referenced in the generated output, the method verifies
+         that the output does not contain any relevant information to answer the question.
+
+         This method calls `check_hallucination` to perform the actual analysis. The `check_hallucination`
+         function handles the evaluation using the company ID from the event, the provided input, and the configuration.
+
+         Args:
+             input (EvaluationMetricInput): The input data used for evaluation, including the generated output and reference information.
+             config (EvaluationMetricConfig, optional): Configuration settings for the evaluation. Defaults to `hallucination_metric_default_config`.
+
+         Returns:
+             EvaluationMetricResult | None: The result of the evaluation, indicating the level of hallucination. Returns `None` if the metric is not enabled.
+
+         Raises:
+             EvaluatorException: If the context texts are empty, required fields are missing, or an error occurs during the evaluation.
+         """
+         if config.enabled is False:
+             self.logger.info("Hallucination metric is not enabled.")
+             return None
+
+         return await check_hallucination(
+             company_id=self.event.company_id, input=input, config=config
+         )
@@ -0,0 +1,213 @@
+ import logging
+ from string import Template
+
+ from unique_toolkit.content.schemas import ContentChunk
+ from unique_toolkit.language_model.schemas import (
+     LanguageModelMessages,
+     LanguageModelStreamResponse,
+     LanguageModelSystemMessage,
+     LanguageModelUserMessage,
+ )
+ from unique_toolkit.language_model.service import LanguageModelService
+ from unique_toolkit.evals.config import EvaluationMetricConfig
+ from unique_toolkit.evals.exception import EvaluatorException
+ from unique_toolkit.evals.output_parser import parse_eval_metric_result
+ from unique_toolkit.evals.schemas import (
+     EvaluationMetricInput,
+     EvaluationMetricName,
+     EvaluationMetricResult,
+ )
+
+
+ from .constants import (
+     SYSTEM_MSG_DEFAULT_KEY,
+     SYSTEM_MSG_KEY,
+     USER_MSG_DEFAULT_KEY,
+     USER_MSG_KEY,
+     hallucination_required_input_fields,
+ )
+ from .prompts import (
+     HALLUCINATION_METRIC_SYSTEM_MSG,
+     HALLUCINATION_METRIC_SYSTEM_MSG_DEFAULT,
+     HALLUCINATION_METRIC_USER_MSG,
+     HALLUCINATION_METRIC_USER_MSG_DEFAULT,
+ )
+
+
+ async def check_hallucination(
+     company_id: str,
+     input: EvaluationMetricInput,
+     config: EvaluationMetricConfig,
+ ) -> EvaluationMetricResult:
+     """
+     Analyzes the level of hallucination in the generated output by comparing it with the provided input
+     and the contexts or history. The analysis classifies the hallucination level as:
+     - low
+     - medium
+     - high
+
+     If no contexts or history are referenced in the generated output, the function checks that the output
+     does not contain any relevant information to answer the question.
+
+     This function performs the following steps:
+     1. Logs the start of the analysis.
+     2. Validates the required fields in the `input` data.
+     3. Composes the messages using the `_get_msgs` helper.
+     4. Calls `LanguageModelService.complete_async_util` to get a completion result.
+     5. Checks that the completion returned content.
+     6. Parses and returns the evaluation metric result based on the content of the completion result.
+
+     Args:
+         company_id (str): The company ID for the analysis.
+         input (EvaluationMetricInput): The input data used for evaluation, including the generated output and reference information.
+         config (EvaluationMetricConfig): Configuration settings for the evaluation.
+
+     Returns:
+         EvaluationMetricResult: The result of the evaluation, indicating the level of hallucination.
+
+     Raises:
+         EvaluatorException: If the context texts are empty, required fields are missing, or an error occurs during the evaluation.
+     """
+
+     logger = logging.getLogger(f"check_hallucination.{__name__}")
+
+     model_name = config.language_model.name
+     logger.info(f"Analyzing level of hallucination with {model_name}.")
+
+     input.validate_required_fields(hallucination_required_input_fields)
+
+     try:
+         msgs = _get_msgs(input, config, logger)
+         result = await LanguageModelService.complete_async_util(
+             company_id=company_id, messages=msgs, model_name=model_name
+         )
+         result_content = result.choices[0].message.content
+         if not result_content:
+             error_message = "Hallucination evaluation did not return a result."
+             raise EvaluatorException(
+                 error_message=error_message,
+                 user_message=error_message,
+             )
+         return parse_eval_metric_result(
+             result_content,  # type: ignore
+             EvaluationMetricName.HALLUCINATION,
+         )
+     except Exception as e:
+         error_message = "Error occurred during hallucination metric analysis"
+         raise EvaluatorException(
+             error_message=f"{error_message}: {e}",
+             user_message=error_message,
+             exception=e,
+         )
+
+
+ def _get_msgs(
+     input: EvaluationMetricInput,
+     config: EvaluationMetricConfig,
+     logger: logging.Logger,
+ ):
+     """
+     Composes the messages for hallucination analysis based on the provided input and configuration.
+
+     This function decides how to compose the messages based on the availability of context texts and history
+     message texts in the `input`.
+
+     Args:
+         input (EvaluationMetricInput): The input data that includes context texts and history message texts
+             for the analysis.
+         config (EvaluationMetricConfig): The configuration settings for composing messages.
+         logger (logging.Logger): The logger used for logging debug information.
+
+     Returns:
+         The composed messages as per the provided input and configuration. The exact type and structure
+         depend on the implementation of the `_compose_msgs` and `_compose_msgs_default` functions.
+
+     """
+     if input.context_texts or input.history_messages:
+         logger.debug("Using context / history for hallucination evaluation.")
+         return _compose_msgs(input, config)
+     else:
+         logger.debug("No contexts and history provided for hallucination evaluation.")
+         return _compose_msgs_default(input, config)
+
+
+ def _compose_msgs(
+     input: EvaluationMetricInput,
+     config: EvaluationMetricConfig,
+ ):
+     """
+     Composes the hallucination analysis messages using the provided contexts and history.
+     """
+     system_msg_content = _get_system_prompt_with_contexts(config)
+     system_msg = LanguageModelSystemMessage(content=system_msg_content)
+
+     user_msg_templ = Template(_get_user_prompt_with_contexts(config))
+     user_msg_content = user_msg_templ.substitute(
+         input_text=input.input_text,
+         contexts_text=input.get_joined_context_texts(tag_name="reference"),
+         history_messages_text=input.get_joined_history_texts(tag_name="conversation"),
+         output_text=input.output_text,
+     )
+     user_msg = LanguageModelUserMessage(content=user_msg_content)
+     return LanguageModelMessages([system_msg, user_msg])
+
+
+ def _compose_msgs_default(
+     input: EvaluationMetricInput,
+     config: EvaluationMetricConfig,
+ ):
+     """
+     Composes the hallucination analysis messages when no contexts or history are available.
+     """
+     system_msg_content = _get_system_prompt_default(config)
+     system_msg = LanguageModelSystemMessage(content=system_msg_content)
+
+     user_msg_templ = Template(_get_user_prompt_default(config))
+     user_msg_content = user_msg_templ.substitute(
+         input_text=input.input_text,
+         output_text=input.output_text,
+     )
+     user_msg = LanguageModelUserMessage(content=user_msg_content)
+     return LanguageModelMessages([system_msg, user_msg])
+
+
+ def _get_system_prompt_with_contexts(config: EvaluationMetricConfig):
+     return config.custom_prompts.setdefault(
+         SYSTEM_MSG_KEY,
+         HALLUCINATION_METRIC_SYSTEM_MSG,
+     )
+
+
+ def _get_user_prompt_with_contexts(config: EvaluationMetricConfig):
+     return config.custom_prompts.setdefault(
+         USER_MSG_KEY,
+         HALLUCINATION_METRIC_USER_MSG,
+     )
+
+
+ def _get_system_prompt_default(config: EvaluationMetricConfig):
+     return config.custom_prompts.setdefault(
+         SYSTEM_MSG_DEFAULT_KEY,
+         HALLUCINATION_METRIC_SYSTEM_MSG_DEFAULT,
+     )
+
+
+ def _get_user_prompt_default(config: EvaluationMetricConfig):
+     return config.custom_prompts.setdefault(
+         USER_MSG_DEFAULT_KEY,
+         HALLUCINATION_METRIC_USER_MSG_DEFAULT,
+     )
+
+
+ def context_text_from_stream_response(
+     response: LanguageModelStreamResponse, selected_chunks: list[ContentChunk]
+ ):
+     response_references = response.message.references
+     reference_ids = [reference.source_id for reference in response_references]
+     filtered_contexts: list[str] = []
+     for chunk in selected_chunks:
+         if f"{chunk.id}_{chunk.chunk_id}" in reference_ids:
+             filtered_contexts.append(chunk.text)
+     return filtered_contexts
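Because the prompt getters above use `dict.setdefault`, pre-populating `config.custom_prompts` under the same keys overrides the built-in prompts. A hedged sketch of that override (the helper name is hypothetical; the user prompt must keep the `$`-placeholders that `_compose_msgs` substitutes via `string.Template`):

```python
from unique_toolkit.evals.config import EvaluationMetricConfig


def with_custom_hallucination_prompts(config: EvaluationMetricConfig) -> EvaluationMetricConfig:
    """Hypothetical helper: pre-set prompts so the setdefault calls above keep them."""
    config.custom_prompts["systemPrompt"] = (  # SYSTEM_MSG_KEY
        "You are a strict grader. Classify hallucination as low, medium or high."
    )
    config.custom_prompts["userPrompt"] = (  # USER_MSG_KEY; keep the Template placeholders
        "Question: $input_text\n"
        "References:\n$contexts_text\n"
        "Conversation:\n$history_messages_text\n"
        "Answer to grade: $output_text"
    )
    return config
```

The same pattern applies to the `*Default` keys used when no contexts or history are present.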
@@ -0,0 +1,48 @@
+ from unique_toolkit.language_model.utils import convert_string_to_json
+ from unique_toolkit.evals.context_relevancy.schema import (
+     EvaluationSchemaStructuredOutput,
+ )
+ from unique_toolkit.evals.exception import EvaluatorException
+ from unique_toolkit.evals.schemas import (
+     EvaluationMetricName,
+     EvaluationMetricResult,
+ )
+
+
+ def parse_eval_metric_result(
+     result: str,
+     metric_name: EvaluationMetricName,
+ ):
+     """
+     Parses the evaluation metric result from a raw JSON string.
+     """
+
+     try:
+         parsed_result = convert_string_to_json(result)
+     except Exception as e:
+         error_message = "Error occurred during parsing the evaluation metric result"
+         raise EvaluatorException(
+             user_message=f"{error_message}.",
+             error_message=f"{error_message}: {str(e)}",
+         )
+
+     return EvaluationMetricResult(
+         name=metric_name,
+         value=parsed_result.get("value", "None"),
+         reason=parsed_result.get("reason", "None"),
+     )
+
+
+ def parse_eval_metric_result_structured_output(
+     result: EvaluationSchemaStructuredOutput,
+     metric_name: EvaluationMetricName,
+ ) -> EvaluationMetricResult:
+     """
+     Parses the evaluation metric result from a structured output object.
+     """
+     return EvaluationMetricResult(
+         name=metric_name,
+         value=result.value,
+         reason=result.reason,
+         fact_list=[item.fact for item in result.fact_list],
+     )
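Note that `parse_eval_metric_result` only raises when the payload is not valid JSON; missing keys fall back to the literal string `"None"`. A small illustration of that behaviour (it mirrors the tests further down and is not part of the diff):

```python
from unique_toolkit.evals.output_parser import parse_eval_metric_result
from unique_toolkit.evals.schemas import EvaluationMetricName

# Valid JSON with a missing "reason": value is parsed, reason defaults to the string "None".
result = parse_eval_metric_result('{"value": "low"}', EvaluationMetricName.HALLUCINATION)
assert result.value == "low" and result.reason == "None"

# Non-JSON content raises EvaluatorException rather than returning a partial result.
# parse_eval_metric_result("not json", EvaluationMetricName.HALLUCINATION)
```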
@@ -0,0 +1,252 @@
+ from unittest.mock import MagicMock, patch
+
+ import pytest
+ from unique_toolkit.app.schemas import ChatEvent
+ from unique_toolkit.chat.service import LanguageModelName
+ from unique_toolkit.language_model.infos import (
+     LanguageModelInfo,
+ )
+ from unique_toolkit.language_model.schemas import (
+     LanguageModelAssistantMessage,
+     LanguageModelCompletionChoice,
+     LanguageModelMessages,
+ )
+ from unique_toolkit.language_model.service import LanguageModelResponse
+ from unique_toolkit.evals.config import EvaluationMetricConfig
+ from unique_toolkit.evals.context_relevancy.prompts import (
+     CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG,
+ )
+ from unique_toolkit.evals.context_relevancy.schema import (
+     EvaluationSchemaStructuredOutput,
+ )
+ from unique_toolkit.evals.context_relevancy.service import (
+     ContextRelevancyEvaluator,
+ )
+ from unique_toolkit.evals.exception import EvaluatorException
+ from unique_toolkit.evals.schemas import (
+     EvaluationMetricInput,
+     EvaluationMetricName,
+     EvaluationMetricResult,
+ )
+
+
+ @pytest.fixture
+ def event():
+     event = MagicMock(spec=ChatEvent)
+     event.payload = MagicMock()
+     event.payload.user_message = MagicMock()
+     event.payload.user_message.text = "Test query"
+     event.user_id = "user_0"
+     event.company_id = "company_0"
+     return event
+
+
+ @pytest.fixture
+ def evaluator(event):
+     return ContextRelevancyEvaluator(event)
+
+
+ @pytest.fixture
+ def basic_config():
+     return EvaluationMetricConfig(
+         enabled=True,
+         name=EvaluationMetricName.CONTEXT_RELEVANCY,
+         language_model=LanguageModelInfo.from_name(
+             LanguageModelName.AZURE_GPT_4o_2024_0806
+         ),
+     )
+
+
+ @pytest.fixture
+ def structured_config(basic_config):
+     model_info = LanguageModelInfo.from_name(LanguageModelName.AZURE_GPT_4o_2024_0806)
+     return EvaluationMetricConfig(
+         enabled=True,
+         name=EvaluationMetricName.CONTEXT_RELEVANCY,
+         language_model=model_info,
+     )
+
+
+ @pytest.fixture
+ def sample_input():
+     return EvaluationMetricInput(
+         input_text="test query",
+         context_texts=["test context 1", "test context 2"],
+     )
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_disabled(evaluator, sample_input, basic_config):
+     basic_config.enabled = False
+     result = await evaluator.analyze(sample_input, basic_config)
+     assert result is None
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_empty_context(evaluator, basic_config):
+     input_with_empty_context = EvaluationMetricInput(
+         input_text="test query", context_texts=[]
+     )
+
+     with pytest.raises(EvaluatorException) as exc_info:
+         await evaluator.analyze(input_with_empty_context, basic_config)
+
+     assert "No context texts provided." in str(exc_info.value)
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_regular_output(evaluator, sample_input, basic_config):
+     mock_result = LanguageModelResponse(
+         choices=[
+             LanguageModelCompletionChoice(
+                 index=0,
+                 message=LanguageModelAssistantMessage(
+                     content="""{
+                         "value": "high",
+                         "reason": "Test reason"
+                     }"""
+                 ),
+                 finish_reason="stop",
+             )
+         ]
+     )
+
+     with patch.object(
+         evaluator.language_model_service,
+         "complete_async",
+         return_value=mock_result,
+     ) as mock_complete:
+         result = await evaluator.analyze(sample_input, basic_config)
+
+         assert isinstance(result, EvaluationMetricResult)
+         assert result.value.lower() == "high"
+         mock_complete.assert_called_once()
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_structured_output(evaluator, sample_input, structured_config):
+     mock_result = LanguageModelResponse(
+         choices=[
+             LanguageModelCompletionChoice(
+                 index=0,
+                 message=LanguageModelAssistantMessage(
+                     content="HIGH",
+                     parsed={"value": "high", "reason": "Test reason"},
+                 ),
+                 finish_reason="stop",
+             )
+         ]
+     )
+
+     structured_output_schema = EvaluationSchemaStructuredOutput
+
+     with patch.object(
+         evaluator.language_model_service,
+         "complete_async",
+         return_value=mock_result,
+     ) as mock_complete:
+         result = await evaluator.analyze(
+             sample_input, structured_config, structured_output_schema
+         )
+         assert isinstance(result, EvaluationMetricResult)
+         assert result.value.lower() == "high"
+         mock_complete.assert_called_once()
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_structured_output_validation_error(
+     evaluator, sample_input, structured_config
+ ):
+     mock_result = LanguageModelResponse(
+         choices=[
+             LanguageModelCompletionChoice(
+                 index=0,
+                 message=LanguageModelAssistantMessage(
+                     content="HIGH", parsed={"invalid": "data"}
+                 ),
+                 finish_reason="stop",
+             )
+         ]
+     )
+
+     structured_output_schema = EvaluationSchemaStructuredOutput
+
+     with patch.object(
+         evaluator.language_model_service,
+         "complete_async",
+         return_value=mock_result,
+     ):
+         with pytest.raises(EvaluatorException) as exc_info:
+             await evaluator.analyze(
+                 sample_input, structured_config, structured_output_schema
+             )
+         assert "Error occurred during structured output validation" in str(
+             exc_info.value
+         )
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_regular_output_empty_response(
+     evaluator, sample_input, basic_config
+ ):
+     mock_result = LanguageModelResponse(
+         choices=[
+             LanguageModelCompletionChoice(
+                 index=0,
+                 message=LanguageModelAssistantMessage(content=""),
+                 finish_reason="stop",
+             )
+         ]
+     )
+
+     with patch.object(
+         evaluator.language_model_service,
+         "complete_async",
+         return_value=mock_result,
+     ):
+         with pytest.raises(EvaluatorException) as exc_info:
+             await evaluator.analyze(sample_input, basic_config)
+         assert "did not return a result" in str(exc_info.value)
+
+
+ def test_compose_msgs_regular(evaluator, sample_input, basic_config):
+     messages = evaluator._compose_msgs(
+         sample_input, basic_config, enable_structured_output=False
+     )
+
+     assert isinstance(messages, LanguageModelMessages)
+     assert messages.root[0].content == CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG
+     assert isinstance(messages.root[1].content, str)
+     assert "test query" in messages.root[1].content
+     assert "test context 1" in messages.root[1].content
+     assert "test context 2" in messages.root[1].content
+
+
+ def test_compose_msgs_structured(evaluator, sample_input, structured_config):
+     messages = evaluator._compose_msgs(
+         sample_input, structured_config, enable_structured_output=True
+     )
+
+     assert isinstance(messages, LanguageModelMessages)
+     assert len(messages.root) == 2
+     assert (
+         messages.root[0].content != CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG
+     )  # Should use structured output prompt
+     assert isinstance(messages.root[1].content, str)
+     assert "test query" in messages.root[1].content
+     assert "test context 1" in messages.root[1].content
+     assert "test context 2" in messages.root[1].content
+
+
+ @pytest.mark.asyncio
+ async def test_analyze_unknown_error(evaluator, sample_input, basic_config):
+     with patch.object(
+         evaluator.language_model_service,
+         "complete_async",
+         side_effect=Exception("Unknown error"),
+     ):
+         with pytest.raises(EvaluatorException) as exc_info:
+             await evaluator.analyze(sample_input, basic_config)
+         assert "Unknown error occurred during context relevancy metric analysis" in str(
+             exc_info.value
+         )
@@ -0,0 +1,80 @@
+ import pytest
+
+ from unique_toolkit.evals.context_relevancy.schema import EvaluationSchemaStructuredOutput, Fact
+ from unique_toolkit.evals.exception import EvaluatorException
+ from unique_toolkit.evals.output_parser import parse_eval_metric_result, parse_eval_metric_result_structured_output
+ from unique_toolkit.evals.schemas import EvaluationMetricName, EvaluationMetricResult
+
+
+
+
+ def test_parse_eval_metric_result_success():
+     # Test successful parsing with all fields
+     result = '{"value": "high", "reason": "Test reason"}'
+     parsed = parse_eval_metric_result(result, EvaluationMetricName.CONTEXT_RELEVANCY)
+
+     assert isinstance(parsed, EvaluationMetricResult)
+     assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+     assert parsed.value == "high"
+     assert parsed.reason == "Test reason"
+     assert parsed.fact_list == []
+
+
+ def test_parse_eval_metric_result_missing_fields():
+     # Test parsing with missing fields (should use default "None")
+     result = '{"value": "high"}'
+     parsed = parse_eval_metric_result(result, EvaluationMetricName.CONTEXT_RELEVANCY)
+
+     assert isinstance(parsed, EvaluationMetricResult)
+     assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+     assert parsed.value == "high"
+     assert parsed.reason == "None"
+     assert parsed.fact_list == []
+
+
+ def test_parse_eval_metric_result_invalid_json():
+     # Test parsing with invalid JSON
+     result = "invalid json"
+     with pytest.raises(EvaluatorException) as exc_info:
+         parse_eval_metric_result(result, EvaluationMetricName.CONTEXT_RELEVANCY)
+
+     assert "Error occurred during parsing the evaluation metric result" in str(
+         exc_info.value
+     )
+
+
+ def test_parse_eval_metric_result_structured_output_basic():
+     # Test basic structured output without fact list
+     result = EvaluationSchemaStructuredOutput(value="high", reason="Test reason")
+     parsed = parse_eval_metric_result_structured_output(
+         result, EvaluationMetricName.CONTEXT_RELEVANCY
+     )
+
+     assert isinstance(parsed, EvaluationMetricResult)
+     assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+     assert parsed.value == "high"
+     assert parsed.reason == "Test reason"
+     assert parsed.fact_list == []
+
+
+ def test_parse_eval_metric_result_structured_output_with_facts():
+     # Test structured output with fact list
+     result = EvaluationSchemaStructuredOutput(
+         value="high",
+         reason="Test reason",
+         fact_list=[
+             Fact(fact="Fact 1"),
+             Fact(fact="Fact 2"),
+         ],
+     )
+     parsed = parse_eval_metric_result_structured_output(
+         result, EvaluationMetricName.CONTEXT_RELEVANCY
+     )
+
+     assert isinstance(parsed, EvaluationMetricResult)
+     assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+     assert parsed.value == "high"
+     assert parsed.reason == "Test reason"
+     assert parsed.fact_list == ["Fact 1", "Fact 2"]
+     assert isinstance(parsed.fact_list, list)
+     assert len(parsed.fact_list) == 2  # None fact should be filtered out
@@ -18,6 +18,7 @@ from unique_toolkit.language_model.schemas import (
      LanguageModelFunction,
      LanguageModelMessage,
      LanguageModelMessageRole,
+     LanguageModelMessages,
      LanguageModelSystemMessage,
      LanguageModelToolMessage,
      LanguageModelUserMessage,
@@ -218,13 +219,7 @@ class HistoryManager:
          rendered_user_message_string: str,
          rendered_system_message_string: str,
          remove_from_text: Callable[[str], Awaitable[str]]
-     ) -> list[
-         LanguageModelMessage
-         | LanguageModelToolMessage
-         | LanguageModelAssistantMessage
-         | LanguageModelSystemMessage
-         | LanguageModelUserMessage
-     ]:
+     ) -> LanguageModelMessages:
          messages = await self._token_reducer.get_history_for_model_call(
              original_user_message=original_user_message,
              rendered_user_message_string=rendered_user_message_string,
@@ -232,4 +227,4 @@ class HistoryManager:
              loop_history=self._loop_history,
              remove_from_text=remove_from_text,
          )
-         return messages.root
+         return messages
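This changes the method's return type from a plain message list to the `LanguageModelMessages` container. A hedged caller-side sketch, assuming the container still exposes the wrapped list via `.root` (which the removed `return messages.root` line and the tests above rely on):

```python
from unique_toolkit.language_model.schemas import LanguageModelMessages


def as_message_list(messages: LanguageModelMessages):
    # Callers that previously received a list from HistoryManager can unwrap
    # the new container via .root; the container itself can be passed straight
    # to APIs that accept LanguageModelMessages.
    return messages.root
```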
@@ -129,6 +129,14 @@ class LanguageModelStreamResponse(BaseModel):
      message: LanguageModelStreamResponseMessage
      tool_calls: list[LanguageModelFunction] | None = None
 
+     def is_empty(self) -> bool:
+         """
+         Check if the stream response is empty.
+         An empty stream response has no text and no tool calls.
+         """
+         return not self.message.original_text and not self.tool_calls
+
+
      def to_openai_param(self) -> ChatCompletionAssistantMessageParam:
          return ChatCompletionAssistantMessageParam(
              role="assistant",