unique_toolkit 0.8.15__py3-none-any.whl → 0.8.16__py3-none-any.whl
This diff compares the contents of two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes exactly as they appear in the public registry.
- unique_toolkit/evals/config.py +36 -0
- unique_toolkit/evals/context_relevancy/prompts.py +56 -0
- unique_toolkit/evals/context_relevancy/schema.py +88 -0
- unique_toolkit/evals/context_relevancy/service.py +241 -0
- unique_toolkit/evals/hallucination/constants.py +61 -0
- unique_toolkit/evals/hallucination/hallucination_evaluation.py +92 -0
- unique_toolkit/evals/hallucination/prompts.py +79 -0
- unique_toolkit/evals/hallucination/service.py +57 -0
- unique_toolkit/evals/hallucination/utils.py +213 -0
- unique_toolkit/evals/output_parser.py +48 -0
- unique_toolkit/evals/tests/test_context_relevancy_service.py +252 -0
- unique_toolkit/evals/tests/test_output_parser.py +80 -0
- unique_toolkit/history_manager/history_manager.py +3 -8
- unique_toolkit/language_model/schemas.py +8 -0
- {unique_toolkit-0.8.15.dist-info → unique_toolkit-0.8.16.dist-info}/METADATA +4 -1
- {unique_toolkit-0.8.15.dist-info → unique_toolkit-0.8.16.dist-info}/RECORD +18 -6
- {unique_toolkit-0.8.15.dist-info → unique_toolkit-0.8.16.dist-info}/LICENSE +0 -0
- {unique_toolkit-0.8.15.dist-info → unique_toolkit-0.8.16.dist-info}/WHEEL +0 -0
unique_toolkit/evals/hallucination/service.py (new file)
@@ -0,0 +1,57 @@
+import logging
+
+from unique_toolkit.app.schemas import ChatEvent
+from unique_toolkit.evals.config import EvaluationMetricConfig
+from unique_toolkit.evals.schemas import EvaluationMetricInput, EvaluationMetricResult
+
+
+from .constants import hallucination_metric_default_config
+from .utils import check_hallucination
+
+SYSTEM_MSG_KEY = "systemPrompt"
+USER_MSG_KEY = "userPrompt"
+SYSTEM_MSG_DEFAULT_KEY = "systemPromptDefault"
+USER_MSG_DEFAULT_KEY = "userPromptDefault"
+
+
+class HallucinationEvaluator:
+    def __init__(self, event: ChatEvent):
+        self.event = event
+
+        self.logger = logging.getLogger(f"HallucinationEvaluator.{__name__}")
+
+    async def analyze(
+        self,
+        input: EvaluationMetricInput,
+        config: EvaluationMetricConfig = hallucination_metric_default_config,
+    ) -> EvaluationMetricResult | None:
+        """
+        Analyzes the level of hallucination in the generated output by comparing it with the input
+        and the provided contexts or history. The analysis classifies the hallucination level as:
+        - low
+        - medium
+        - high
+
+        If no contexts or history are referenced in the generated output, the method verifies
+        that the output does not contain any relevant information to answer the question.
+
+        This method calls `check_hallucination` to perform the actual analysis. The `check_hallucination`
+        function handles the evaluation using the company ID from the event, the provided input, and the configuration.
+
+        Args:
+            input (EvaluationMetricInput): The input data used for evaluation, including the generated output and reference information.
+            config (EvaluationMetricConfig, optional): Configuration settings for the evaluation. Defaults to `hallucination_metric_default_config`.
+
+        Returns:
+            EvaluationMetricResult | None: The result of the evaluation, indicating the level of hallucination. Returns `None` if the analysis cannot be performed.
+
+        Raises:
+            EvaluatorException: If the context texts are empty, required fields are missing, or an error occurs during the evaluation.
+        """
+        if config.enabled is False:
+            self.logger.info("Hallucination metric is not enabled.")
+            return None
+
+        return await check_hallucination(
+            company_id=self.event.company_id, input=input, config=config
+        )
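For orientation, a minimal usage sketch of the new evaluator (not part of the package): it assumes `event` is the `ChatEvent` handed to the assistant and that `EvaluationMetricInput` accepts `input_text`, `output_text`, and `context_texts` keyword arguments, as the test fixtures further down in this diff suggest.

```python
# Hypothetical usage sketch, not part of the diff.
from unique_toolkit.evals.hallucination.service import HallucinationEvaluator
from unique_toolkit.evals.schemas import EvaluationMetricInput


async def score_answer(event, question: str, answer: str, contexts: list[str]):
    evaluator = HallucinationEvaluator(event)
    # Field names assumed from the test fixtures and from how utils.py reads the input.
    metric_input = EvaluationMetricInput(
        input_text=question,
        output_text=answer,
        context_texts=contexts,
    )
    # Uses hallucination_metric_default_config unless a config is passed in;
    # returns None when the metric is disabled in that config.
    result = await evaluator.analyze(metric_input)
    if result is not None:
        print(f"hallucination level: {result.value} ({result.reason})")
    return result
```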
unique_toolkit/evals/hallucination/utils.py (new file)
@@ -0,0 +1,213 @@
+import logging
+from string import Template
+
+from unique_toolkit.content.schemas import ContentChunk
+from unique_toolkit.language_model.schemas import (
+    LanguageModelMessages,
+    LanguageModelStreamResponse,
+    LanguageModelSystemMessage,
+    LanguageModelUserMessage,
+)
+from unique_toolkit.language_model.service import LanguageModelService
+from unique_toolkit.evals.config import EvaluationMetricConfig
+from unique_toolkit.evals.exception import EvaluatorException
+from unique_toolkit.evals.output_parser import parse_eval_metric_result
+from unique_toolkit.evals.schemas import (
+    EvaluationMetricInput,
+    EvaluationMetricName,
+    EvaluationMetricResult,
+)
+
+
+from .constants import (
+    SYSTEM_MSG_DEFAULT_KEY,
+    SYSTEM_MSG_KEY,
+    USER_MSG_DEFAULT_KEY,
+    USER_MSG_KEY,
+    hallucination_required_input_fields,
+)
+from .prompts import (
+    HALLUCINATION_METRIC_SYSTEM_MSG,
+    HALLUCINATION_METRIC_SYSTEM_MSG_DEFAULT,
+    HALLUCINATION_METRIC_USER_MSG,
+    HALLUCINATION_METRIC_USER_MSG_DEFAULT,
+)
+
+
+async def check_hallucination(
+    company_id: str,
+    input: EvaluationMetricInput,
+    config: EvaluationMetricConfig,
+) -> EvaluationMetricResult:
+    """
+    Analyzes the level of hallucination in the generated output by comparing it with the provided input
+    and the contexts or history. The analysis classifies the hallucination level as:
+    - low
+    - medium
+    - high
+
+    If no contexts or history are referenced in the generated output, the method checks that the output
+    does not contain any relevant information to answer the question.
+
+    This method performs the following steps:
+    1. Checks if the hallucination metric is enabled using the provided `config`.
+    2. Logs the start of the analysis using the provided `logger`.
+    3. Validates the required fields in the `input` data.
+    4. Retrieves the messages using the `_get_msgs` method.
+    5. Calls `LanguageModelService.complete_async_util` to get a completion result.
+    6. Parses and returns the evaluation metric result based on the content of the completion result.
+
+    Args:
+        company_id (str): The company ID for the analysis.
+        input (EvaluationMetricInput): The input data used for evaluation, including the generated output and reference information.
+        config (EvaluationMetricConfig, optional): Configuration settings for the evaluation. Defaults to `hallucination_metric_default_config`.
+        logger (Optional[logging.Logger], optional): The logger used for logging information and errors. Defaults to the logger for the current module.
+
+    Returns:
+        EvaluationMetricResult | None: The result of the evaluation, indicating the level of hallucination. Returns `None` if the metric is not enabled or if an error occurs.
+
+    Raises:
+        EvaluatorException: If the context texts are empty, required fields are missing, or an error occurs during the evaluation.
+    """
+
+    logger = logging.getLogger(f"check_hallucination.{__name__}")
+
+    model_name = config.language_model.name
+    logger.info(f"Analyzing level of hallucination with {model_name}.")
+
+    input.validate_required_fields(hallucination_required_input_fields)
+
+    try:
+        msgs = _get_msgs(input, config, logger)
+        result = await LanguageModelService.complete_async_util(
+            company_id=company_id, messages=msgs, model_name=model_name
+        )
+        result_content = result.choices[0].message.content
+        if not result_content:
+            error_message = "Hallucination evaluation did not return a result."
+            raise EvaluatorException(
+                error_message=error_message,
+                user_message=error_message,
+            )
+        return parse_eval_metric_result(
+            result_content,  # type: ignore
+            EvaluationMetricName.HALLUCINATION,
+        )
+    except Exception as e:
+        error_message = "Error occurred during hallucination metric analysis"
+        raise EvaluatorException(
+            error_message=f"{error_message}: {e}",
+            user_message=error_message,
+            exception=e,
+        )
+
+
+def _get_msgs(
+    input: EvaluationMetricInput,
+    config: EvaluationMetricConfig,
+    logger: logging.Logger,
+):
+    """
+    Composes the messages for hallucination analysis based on the provided input and configuration.
+
+    This method decides how to compose the messages based on the availability of context texts and history
+    message texts in the `input`
+
+    Args:
+        input (EvaluationMetricInput): The input data that includes context texts and history message texts
+        for the analysis.
+        config (EvaluationMetricConfig): The configuration settings for composing messages.
+        logger (Optional[logging.Logger], optional): The logger used for logging debug information.
+        Defaults to the logger for the current module.
+
+    Returns:
+        The composed messages as per the provided input and configuration. The exact type and structure
+        depend on the implementation of the `compose_msgs` and `compose_msgs_default` methods.
+
+    """
+    if input.context_texts or input.history_messages:
+        logger.debug("Using context / history for hallucination evaluation.")
+        return _compose_msgs(input, config)
+    else:
+        logger.debug("No contexts and history provided for hallucination evaluation.")
+        return _compose_msgs_default(input, config)
+
+
+def _compose_msgs(
+    input: EvaluationMetricInput,
+    config: EvaluationMetricConfig,
+):
+    """
+    Composes the hallucination analysis messages.
+    """
+    system_msg_content = _get_system_prompt_with_contexts(config)
+    system_msg = LanguageModelSystemMessage(content=system_msg_content)
+
+    user_msg_templ = Template(_get_user_prompt_with_contexts(config))
+    user_msg_content = user_msg_templ.substitute(
+        input_text=input.input_text,
+        contexts_text=input.get_joined_context_texts(tag_name="reference"),
+        history_messages_text=input.get_joined_history_texts(tag_name="conversation"),
+        output_text=input.output_text,
+    )
+    user_msg = LanguageModelUserMessage(content=user_msg_content)
+    return LanguageModelMessages([system_msg, user_msg])
+
+
+def _compose_msgs_default(
+    input: EvaluationMetricInput,
+    config: EvaluationMetricConfig,
+):
+    """
+    Composes the hallucination analysis prompt without messages.
+    """
+    system_msg_content = _get_system_prompt_default(config)
+    system_msg = LanguageModelSystemMessage(content=system_msg_content)
+
+    user_msg_templ = Template(_get_user_prompt_default(config))
+    user_msg_content = user_msg_templ.substitute(
+        input_text=input.input_text,
+        output_text=input.output_text,
+    )
+    user_msg = LanguageModelUserMessage(content=user_msg_content)
+    return LanguageModelMessages([system_msg, user_msg])
+
+
+def _get_system_prompt_with_contexts(config: EvaluationMetricConfig):
+    return config.custom_prompts.setdefault(
+        SYSTEM_MSG_KEY,
+        HALLUCINATION_METRIC_SYSTEM_MSG,
+    )
+
+
+def _get_user_prompt_with_contexts(config: EvaluationMetricConfig):
+    return config.custom_prompts.setdefault(
+        USER_MSG_KEY,
+        HALLUCINATION_METRIC_USER_MSG,
+    )
+
+
+def _get_system_prompt_default(config: EvaluationMetricConfig):
+    return config.custom_prompts.setdefault(
+        SYSTEM_MSG_DEFAULT_KEY,
+        HALLUCINATION_METRIC_SYSTEM_MSG_DEFAULT,
+    )
+
+
+def _get_user_prompt_default(config: EvaluationMetricConfig):
+    return config.custom_prompts.setdefault(
+        USER_MSG_DEFAULT_KEY,
+        HALLUCINATION_METRIC_USER_MSG_DEFAULT,
+    )
+
+
+def context_text_from_stream_response(
+    response: LanguageModelStreamResponse, selected_chunks: list[ContentChunk]
+):
+    response_references = response.message.references
+    reference_ids = [reference.source_id for reference in response_references]
+    filtered_contexts: list[str] = []
+    for chunk in selected_chunks:
+        if f"{chunk.id}_{chunk.chunk_id}" in reference_ids:
+            filtered_contexts.append(chunk.text)
+    return filtered_contexts
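Because the `_get_*_prompt_*` helpers call `custom_prompts.setdefault`, an entry under the corresponding key replaces the built-in prompt. A hedged sketch of such an override (not part of the package): the constructor fields follow the test fixtures later in this diff, the `$` placeholders mirror what `_compose_msgs` substitutes, and whether `custom_prompts` is accepted as a constructor keyword is an assumption.

```python
# Hypothetical sketch, not part of the diff: overriding the hallucination prompts.
from unique_toolkit.chat.service import LanguageModelName
from unique_toolkit.evals.config import EvaluationMetricConfig
from unique_toolkit.evals.hallucination.constants import SYSTEM_MSG_KEY, USER_MSG_KEY
from unique_toolkit.evals.schemas import EvaluationMetricName
from unique_toolkit.language_model.infos import LanguageModelInfo

config = EvaluationMetricConfig(
    enabled=True,
    name=EvaluationMetricName.HALLUCINATION,
    language_model=LanguageModelInfo.from_name(LanguageModelName.AZURE_GPT_4o_2024_0806),
    # Assumed keyword; the code above only shows that config.custom_prompts is a dict.
    custom_prompts={
        SYSTEM_MSG_KEY: "You are a strict fact checker.",
        USER_MSG_KEY: (
            "Question: $input_text\n"
            "References: $contexts_text\n"
            "Conversation: $history_messages_text\n"
            "Answer: $output_text\n"
            "Classify the hallucination level as low, medium or high."
        ),
    },
)
```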
unique_toolkit/evals/output_parser.py (new file)
@@ -0,0 +1,48 @@
+from unique_toolkit.language_model.utils import convert_string_to_json
+from unique_toolkit.evals.context_relevancy.schema import (
+    EvaluationSchemaStructuredOutput,
+)
+from unique_toolkit.evals.exception import EvaluatorException
+from unique_toolkit.evals.schemas import (
+    EvaluationMetricName,
+    EvaluationMetricResult,
+)
+
+
+def parse_eval_metric_result(
+    result: str,
+    metric_name: EvaluationMetricName,
+):
+    """
+    Parses the evaluation metric result.
+    """
+
+    try:
+        parsed_result = convert_string_to_json(result)
+    except Exception as e:
+        error_message = "Error occurred during parsing the evaluation metric result"
+        raise EvaluatorException(
+            user_message=f"{error_message}.",
+            error_message=f"{error_message}: {str(e)}",
+        )
+
+    return EvaluationMetricResult(
+        name=metric_name,
+        value=parsed_result.get("value", "None"),
+        reason=parsed_result.get("reason", "None"),
+    )
+
+
+def parse_eval_metric_result_structured_output(
+    result: EvaluationSchemaStructuredOutput,
+    metric_name: EvaluationMetricName,
+) -> EvaluationMetricResult:
+    """
+    Parses the evaluation metric result.
+    """
+    return EvaluationMetricResult(
+        name=metric_name,
+        value=result.value,
+        reason=result.reason,
+        fact_list=[item.fact for item in result.fact_list],
+    )
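One detail worth noting: `parse_eval_metric_result` does not fail on missing keys; it falls back to the string `"None"` via `dict.get(..., "None")`. A small sketch (not part of the package) illustrating that behaviour, which the tests later in this diff also exercise:

```python
# Hypothetical sketch, not part of the diff.
from unique_toolkit.evals.output_parser import parse_eval_metric_result
from unique_toolkit.evals.schemas import EvaluationMetricName

result = parse_eval_metric_result('{"value": "low"}', EvaluationMetricName.HALLUCINATION)
print(result.value)   # -> "low"
print(result.reason)  # -> "None" (string fallback, not Python None)
```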
unique_toolkit/evals/tests/test_context_relevancy_service.py (new file)
@@ -0,0 +1,252 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from unique_toolkit.app.schemas import ChatEvent
+from unique_toolkit.chat.service import LanguageModelName
+from unique_toolkit.language_model.infos import (
+    LanguageModelInfo,
+)
+from unique_toolkit.language_model.schemas import (
+    LanguageModelAssistantMessage,
+    LanguageModelCompletionChoice,
+    LanguageModelMessages,
+)
+from unique_toolkit.language_model.service import LanguageModelResponse
+from unique_toolkit.evals.config import EvaluationMetricConfig
+from unique_toolkit.evals.context_relevancy.prompts import (
+    CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG,
+)
+from unique_toolkit.evals.context_relevancy.schema import (
+    EvaluationSchemaStructuredOutput,
+)
+from unique_toolkit.evals.context_relevancy.service import (
+    ContextRelevancyEvaluator,
+)
+from unique_toolkit.evals.exception import EvaluatorException
+from unique_toolkit.evals.schemas import (
+    EvaluationMetricInput,
+    EvaluationMetricName,
+    EvaluationMetricResult,
+)
+
+
+@pytest.fixture
+def event():
+    event = MagicMock(spec=ChatEvent)
+    event.payload = MagicMock()
+    event.payload.user_message = MagicMock()
+    event.payload.user_message.text = "Test query"
+    event.user_id = "user_0"
+    event.company_id = "company_0"
+    return event
+
+
+@pytest.fixture
+def evaluator(event):
+    return ContextRelevancyEvaluator(event)
+
+
+@pytest.fixture
+def basic_config():
+    return EvaluationMetricConfig(
+        enabled=True,
+        name=EvaluationMetricName.CONTEXT_RELEVANCY,
+        language_model=LanguageModelInfo.from_name(
+            LanguageModelName.AZURE_GPT_4o_2024_0806
+        ),
+    )
+
+
+@pytest.fixture
+def structured_config(basic_config):
+    model_info = LanguageModelInfo.from_name(LanguageModelName.AZURE_GPT_4o_2024_0806)
+    return EvaluationMetricConfig(
+        enabled=True,
+        name=EvaluationMetricName.CONTEXT_RELEVANCY,
+        language_model=model_info,
+    )
+
+
+@pytest.fixture
+def sample_input():
+    return EvaluationMetricInput(
+        input_text="test query",
+        context_texts=["test context 1", "test context 2"],
+    )
+
+
+@pytest.mark.asyncio
+async def test_analyze_disabled(evaluator, sample_input, basic_config):
+    basic_config.enabled = False
+    result = await evaluator.analyze(sample_input, basic_config)
+    assert result is None
+
+
+@pytest.mark.asyncio
+async def test_analyze_empty_context(evaluator, basic_config):
+    input_with_empty_context = EvaluationMetricInput(
+        input_text="test query", context_texts=[]
+    )
+
+    with pytest.raises(EvaluatorException) as exc_info:
+        await evaluator.analyze(input_with_empty_context, basic_config)
+
+    assert "No context texts provided." in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_analyze_regular_output(evaluator, sample_input, basic_config):
+    mock_result = LanguageModelResponse(
+        choices=[
+            LanguageModelCompletionChoice(
+                index=0,
+                message=LanguageModelAssistantMessage(
+                    content="""{
+                        "value": "high",
+                        "reason": "Test reason"
+                    }"""
+                ),
+                finish_reason="stop",
+            )
+        ]
+    )
+
+    with patch.object(
+        evaluator.language_model_service,
+        "complete_async",
+        return_value=mock_result,
+    ) as mock_complete:
+        result = await evaluator.analyze(sample_input, basic_config)
+
+        assert isinstance(result, EvaluationMetricResult)
+        assert result.value.lower() == "high"
+        mock_complete.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_analyze_structured_output(evaluator, sample_input, structured_config):
+    mock_result = LanguageModelResponse(
+        choices=[
+            LanguageModelCompletionChoice(
+                index=0,
+                message=LanguageModelAssistantMessage(
+                    content="HIGH",
+                    parsed={"value": "high", "reason": "Test reason"},
+                ),
+                finish_reason="stop",
+            )
+        ]
+    )
+
+    structured_output_schema = EvaluationSchemaStructuredOutput
+
+    with patch.object(
+        evaluator.language_model_service,
+        "complete_async",
+        return_value=mock_result,
+    ) as mock_complete:
+        result = await evaluator.analyze(
+            sample_input, structured_config, structured_output_schema
+        )
+        assert isinstance(result, EvaluationMetricResult)
+        assert result.value.lower() == "high"
+        mock_complete.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_analyze_structured_output_validation_error(
+    evaluator, sample_input, structured_config
+):
+    mock_result = LanguageModelResponse(
+        choices=[
+            LanguageModelCompletionChoice(
+                index=0,
+                message=LanguageModelAssistantMessage(
+                    content="HIGH", parsed={"invalid": "data"}
+                ),
+                finish_reason="stop",
+            )
+        ]
+    )
+
+    structured_output_schema = EvaluationSchemaStructuredOutput
+
+    with patch.object(
+        evaluator.language_model_service,
+        "complete_async",
+        return_value=mock_result,
+    ):
+        with pytest.raises(EvaluatorException) as exc_info:
+            await evaluator.analyze(
+                sample_input, structured_config, structured_output_schema
+            )
+        assert "Error occurred during structured output validation" in str(
+            exc_info.value
+        )
+
+
+@pytest.mark.asyncio
+async def test_analyze_regular_output_empty_response(
+    evaluator, sample_input, basic_config
+):
+    mock_result = LanguageModelResponse(
+        choices=[
+            LanguageModelCompletionChoice(
+                index=0,
+                message=LanguageModelAssistantMessage(content=""),
+                finish_reason="stop",
+            )
+        ]
+    )
+
+    with patch.object(
+        evaluator.language_model_service,
+        "complete_async",
+        return_value=mock_result,
+    ):
+        with pytest.raises(EvaluatorException) as exc_info:
+            await evaluator.analyze(sample_input, basic_config)
+        assert "did not return a result" in str(exc_info.value)
+
+
+def test_compose_msgs_regular(evaluator, sample_input, basic_config):
+    messages = evaluator._compose_msgs(
+        sample_input, basic_config, enable_structured_output=False
+    )
+
+    assert isinstance(messages, LanguageModelMessages)
+    assert messages.root[0].content == CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG
+    assert isinstance(messages.root[1].content, str)
+    assert "test query" in messages.root[1].content
+    assert "test context 1" in messages.root[1].content
+    assert "test context 2" in messages.root[1].content
+
+
+def test_compose_msgs_structured(evaluator, sample_input, structured_config):
+    messages = evaluator._compose_msgs(
+        sample_input, structured_config, enable_structured_output=True
+    )
+
+    assert isinstance(messages, LanguageModelMessages)
+    assert len(messages.root) == 2
+    assert (
+        messages.root[0].content != CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG
+    )  # Should use structured output prompt
+    assert isinstance(messages.root[1].content, str)
+    assert "test query" in messages.root[1].content
+    assert "test context 1" in messages.root[1].content
+    assert "test context 2" in messages.root[1].content
+
+
+@pytest.mark.asyncio
+async def test_analyze_unknown_error(evaluator, sample_input, basic_config):
+    with patch.object(
+        evaluator.language_model_service,
+        "complete_async",
+        side_effect=Exception("Unknown error"),
+    ):
+        with pytest.raises(EvaluatorException) as exc_info:
+            await evaluator.analyze(sample_input, basic_config)
+        assert "Unknown error occurred during context relevancy metric analysis" in str(
+            exc_info.value
+        )
unique_toolkit/evals/tests/test_output_parser.py (new file)
@@ -0,0 +1,80 @@
+import pytest
+
+from unique_toolkit.evals.context_relevancy.schema import EvaluationSchemaStructuredOutput, Fact
+from unique_toolkit.evals.exception import EvaluatorException
+from unique_toolkit.evals.output_parser import parse_eval_metric_result, parse_eval_metric_result_structured_output
+from unique_toolkit.evals.schemas import EvaluationMetricName, EvaluationMetricResult
+
+
+
+
+def test_parse_eval_metric_result_success():
+    # Test successful parsing with all fields
+    result = '{"value": "high", "reason": "Test reason"}'
+    parsed = parse_eval_metric_result(result, EvaluationMetricName.CONTEXT_RELEVANCY)
+
+    assert isinstance(parsed, EvaluationMetricResult)
+    assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+    assert parsed.value == "high"
+    assert parsed.reason == "Test reason"
+    assert parsed.fact_list == []
+
+
+def test_parse_eval_metric_result_missing_fields():
+    # Test parsing with missing fields (should use default "None")
+    result = '{"value": "high"}'
+    parsed = parse_eval_metric_result(result, EvaluationMetricName.CONTEXT_RELEVANCY)
+
+    assert isinstance(parsed, EvaluationMetricResult)
+    assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+    assert parsed.value == "high"
+    assert parsed.reason == "None"
+    assert parsed.fact_list == []
+
+
+def test_parse_eval_metric_result_invalid_json():
+    # Test parsing with invalid JSON
+    result = "invalid json"
+    with pytest.raises(EvaluatorException) as exc_info:
+        parse_eval_metric_result(result, EvaluationMetricName.CONTEXT_RELEVANCY)
+
+    assert "Error occurred during parsing the evaluation metric result" in str(
+        exc_info.value
+    )
+
+
+def test_parse_eval_metric_result_structured_output_basic():
+    # Test basic structured output without fact list
+    result = EvaluationSchemaStructuredOutput(value="high", reason="Test reason")
+    parsed = parse_eval_metric_result_structured_output(
+        result, EvaluationMetricName.CONTEXT_RELEVANCY
+    )
+
+    assert isinstance(parsed, EvaluationMetricResult)
+    assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+    assert parsed.value == "high"
+    assert parsed.reason == "Test reason"
+    assert parsed.fact_list == []
+
+
+def test_parse_eval_metric_result_structured_output_with_facts():
+    # Test structured output with fact list
+    result = EvaluationSchemaStructuredOutput(
+        value="high",
+        reason="Test reason",
+        fact_list=[
+            Fact(fact="Fact 1"),
+            Fact(fact="Fact 2"),
+        ],
+    )
+    parsed = parse_eval_metric_result_structured_output(
+        result, EvaluationMetricName.CONTEXT_RELEVANCY
+    )
+
+    assert isinstance(parsed, EvaluationMetricResult)
+    assert parsed.name == EvaluationMetricName.CONTEXT_RELEVANCY
+    assert parsed.value == "high"
+    assert parsed.reason == "Test reason"
+    assert parsed.fact_list == ["Fact 1", "Fact 2"]
+    assert isinstance(parsed.fact_list, list)
+    assert len(parsed.fact_list) == 2  # None fact should be filtered out
unique_toolkit/history_manager/history_manager.py
@@ -18,6 +18,7 @@ from unique_toolkit.language_model.schemas import (
     LanguageModelFunction,
     LanguageModelMessage,
     LanguageModelMessageRole,
+    LanguageModelMessages,
     LanguageModelSystemMessage,
     LanguageModelToolMessage,
     LanguageModelUserMessage,
@@ -218,13 +219,7 @@ class HistoryManager:
         rendered_user_message_string: str,
         rendered_system_message_string: str,
         remove_from_text: Callable[[str], Awaitable[str]]
-    ) -> list[
-        LanguageModelMessage
-        | LanguageModelToolMessage
-        | LanguageModelAssistantMessage
-        | LanguageModelSystemMessage
-        | LanguageModelUserMessage
-    ]:
+    ) -> LanguageModelMessages:
         messages = await self._token_reducer.get_history_for_model_call(
             original_user_message=original_user_message,
             rendered_user_message_string=rendered_user_message_string,
@@ -232,4 +227,4 @@
             loop_history=self._loop_history,
             remove_from_text=remove_from_text,
         )
-        return messages
+        return messages
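The history call now returns the `LanguageModelMessages` container instead of a plain list of message types. Callers that indexed the list directly can unwrap it via `.root`, as the new tests above do; a short sketch (not part of the package, and the assumption that each message exposes `role` and `content` attributes is mine):

```python
# Hypothetical sketch, not part of the diff: consuming a LanguageModelMessages value.
from unique_toolkit.language_model.schemas import (
    LanguageModelMessages,
    LanguageModelSystemMessage,
    LanguageModelUserMessage,
)

messages = LanguageModelMessages(
    [
        LanguageModelSystemMessage(content="You are a helpful assistant."),
        LanguageModelUserMessage(content="Hello"),
    ]
)
# `.root` exposes the underlying list, mirroring how the tests above access it.
for message in messages.root:
    print(message.role, message.content)
```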
unique_toolkit/language_model/schemas.py
@@ -129,6 +129,14 @@ class LanguageModelStreamResponse(BaseModel):
     message: LanguageModelStreamResponseMessage
     tool_calls: list[LanguageModelFunction] | None = None
 
+    def is_empty(self) -> bool:
+        """
+        Check if the stream response is empty.
+        An empty stream response has no text and no tool calls.
+        """
+        return not self.message.original_text and not self.tool_calls
+
+
     def to_openai_param(self) -> ChatCompletionAssistantMessageParam:
         return ChatCompletionAssistantMessageParam(
             role="assistant",
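A short sketch of how the new `is_empty()` helper might be combined with `context_text_from_stream_response` from the hallucination utils above (not part of the package; `response` and `selected_chunks` are assumed to come from an earlier streaming call and search step):

```python
# Hypothetical sketch, not part of the diff.
from unique_toolkit.evals.hallucination.utils import context_text_from_stream_response


def referenced_contexts(response, selected_chunks) -> list[str]:
    # `response` is assumed to be a LanguageModelStreamResponse as defined above.
    if response.is_empty():
        # No original_text and no tool calls: nothing was generated, so no references.
        return []
    return context_text_from_stream_response(response, selected_chunks)
```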