unique_toolkit 0.8.15__py3-none-any.whl → 0.8.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unique_toolkit/_common/token/token_counting.py +2 -3
- unique_toolkit/_common/validators.py +7 -0
- unique_toolkit/debug_info_manager/debug_info_manager.py +19 -0
- unique_toolkit/evals/config.py +36 -0
- unique_toolkit/evals/context_relevancy/prompts.py +56 -0
- unique_toolkit/evals/context_relevancy/schema.py +88 -0
- unique_toolkit/evals/context_relevancy/service.py +241 -0
- unique_toolkit/evals/hallucination/constants.py +61 -0
- unique_toolkit/evals/hallucination/hallucination_evaluation.py +91 -0
- unique_toolkit/evals/hallucination/prompts.py +79 -0
- unique_toolkit/evals/hallucination/service.py +57 -0
- unique_toolkit/evals/hallucination/utils.py +213 -0
- unique_toolkit/evals/output_parser.py +48 -0
- unique_toolkit/evals/tests/test_context_relevancy_service.py +252 -0
- unique_toolkit/evals/tests/test_output_parser.py +80 -0
- unique_toolkit/history_manager/history_construction_with_contents.py +4 -4
- unique_toolkit/history_manager/history_manager.py +51 -66
- unique_toolkit/history_manager/loop_token_reducer.py +17 -17
- unique_toolkit/language_model/schemas.py +8 -0
- unique_toolkit/postprocessor/postprocessor_manager.py +1 -2
- unique_toolkit/tools/factory.py +7 -2
- unique_toolkit/tools/tool.py +0 -2
- unique_toolkit/tools/tool_manager.py +0 -3
- {unique_toolkit-0.8.15.dist-info → unique_toolkit-0.8.17.dist-info}/METADATA +7 -1
- {unique_toolkit-0.8.15.dist-info → unique_toolkit-0.8.17.dist-info}/RECORD +27 -15
- unique_toolkit/tools/agent_chunks_handler.py +0 -62
- {unique_toolkit-0.8.15.dist-info → unique_toolkit-0.8.17.dist-info}/LICENSE +0 -0
- {unique_toolkit-0.8.15.dist-info → unique_toolkit-0.8.17.dist-info}/WHEEL +0 -0
|
@@ -5,15 +5,14 @@ import json
|
|
|
5
5
|
from typing import Any, Callable
|
|
6
6
|
|
|
7
7
|
from pydantic import BaseModel
|
|
8
|
+
from unique_toolkit._common.token.image_token_counting import calculate_image_tokens_from_base64
|
|
8
9
|
from unique_toolkit.language_model import (
|
|
9
10
|
LanguageModelMessage,
|
|
10
11
|
LanguageModelMessages,
|
|
11
12
|
LanguageModelName,
|
|
12
13
|
)
|
|
13
14
|
|
|
14
|
-
|
|
15
|
-
calculate_image_tokens_from_base64,
|
|
16
|
-
)
|
|
15
|
+
|
|
17
16
|
|
|
18
17
|
|
|
19
18
|
class SpecialToolCallingTokens(BaseModel):
|
|
@@ -28,6 +28,13 @@ LMI = Annotated[
|
|
|
28
28
|
),
|
|
29
29
|
]
|
|
30
30
|
|
|
31
|
+
def get_LMI_default_field(llm_name: LanguageModelName, **kwargs) -> Any:
    """Build a pydantic ``Field`` whose default is the model info for ``llm_name``.

    The runtime default is ``LanguageModelInfo.from_name(llm_name)``, while the
    ``default`` entry placed in ``json_schema_extra`` is ``llm_name`` itself.

    Args:
        llm_name: Name of the language model to use as the default.
        **kwargs: Extra keyword arguments forwarded to ``Field`` unchanged.

    Returns:
        The value returned by ``Field``.
    """
    default_info = LanguageModelInfo.from_name(llm_name)
    schema_extra = {"default": llm_name}
    return Field(default=default_info, json_schema_extra=schema_extra, **kwargs)
|
|
37
|
+
|
|
31
38
|
|
|
32
39
|
def serialize_lmi(model: LanguageModelInfo) -> str | LanguageModelInfo:
|
|
33
40
|
if model.provider == LanguageModelProvider.CUSTOM:
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from unique_toolkit.content.schemas import ContentChunk, ContentReference
|
|
2
|
+
from unique_toolkit.tools.schemas import ToolCallResponse
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class DebugInfoManager:
    """Accumulates debug information in a dictionary kept on the instance."""

    def __init__(self):
        # "tools" is seeded up front so tool entries can be appended directly.
        self.debug_info = {"tools": []}

    def extract_tool_debug_info(self, tool_call_responses: list[ToolCallResponse]):
        """Append one ``{"name": ..., "data": ...}`` entry per tool call response."""
        for response in tool_call_responses:
            entry = {"name": response.name, "data": response.debug_info}
            self.debug_info["tools"].append(entry)

    def add(self, key, value):
        """Store ``value`` under ``key``, replacing any existing entry for it."""
        # A fresh dict is built instead of updating in place, so a dict handed
        # out earlier by get() is not changed by this call.
        self.debug_info = {**self.debug_info, key: value}

    def get(self):
        """Return the current debug-info dictionary (the stored object itself)."""
        return self.debug_info
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from humps import camelize
|
|
4
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
5
|
+
|
|
6
|
+
from unique_toolkit._common.validators import LMI
|
|
7
|
+
from unique_toolkit.language_model.infos import LanguageModelInfo, LanguageModelName
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
from .schemas import (
|
|
11
|
+
EvaluationMetricName,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Shared pydantic settings for the evaluation config models: aliases are
# generated with ``camelize``, population by field name stays allowed,
# arbitrary types are permitted and default values are validated too.
model_config = ConfigDict(
    validate_default=True,
    arbitrary_types_allowed=True,
    populate_by_name=True,
    alias_generator=camelize,
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class EvaluationMetricConfig(BaseModel):
    # Settings for a single evaluation metric.
    # NOTE: documented with comments rather than a class docstring, because
    # pydantic copies a model's docstring into its generated JSON schema.
    model_config = model_config

    # The metric stays off unless it is explicitly enabled.
    enabled: bool = False
    # No default: each config must state which metric it configures.
    name: EvaluationMetricName
    language_model: LMI = LanguageModelInfo.from_name(
        LanguageModelName.AZURE_GPT_35_TURBO_0125
    )
    additional_llm_options: dict[str, Any] = Field(
        description="Additional options to pass to the language model.",
        default={},
    )
    # String-to-string mappings, all empty by default.
    custom_prompts: dict[str, str] = {}
    score_to_label: dict[str, str] = {}
    score_to_title: dict[str, str] = {}
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG = """
|
|
2
|
+
You will receive an input and a set of contexts.
|
|
3
|
+
Your task is to evaluate how relevant the contexts are to the input text.
|
|
4
|
+
|
|
5
|
+
Use the following rating scale to generate a score:
|
|
6
|
+
[low] - The contexts are not relevant to the input.
|
|
7
|
+
[medium] - The contexts are somewhat relevant to the input.
|
|
8
|
+
[high] - The contexts are highly relevant to the input.
|
|
9
|
+
|
|
10
|
+
Your answer must be in JSON format:
|
|
11
|
+
{
|
|
12
|
+
"reason": Your explanation of your judgement of the evaluation,
|
|
13
|
+
"value": decision, must be one of the following ["low", "medium", "high"]
|
|
14
|
+
}
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG_STRUCTURED_OUTPUT = """
|
|
18
|
+
You will receive an input and a set of contexts.
|
|
19
|
+
Your task is to evaluate how relevant the contexts are to the input text.
|
|
20
|
+
Further you should extract relevant facts from the contexts.
|
|
21
|
+
|
|
22
|
+
# Output Format
|
|
23
|
+
- Generate data according to the provided data schema.
|
|
24
|
+
- Ensure the output adheres to the format required by the pydantic object.
|
|
25
|
+
- All necessary fields should be populated as per the data schema guidelines.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
CONTEXT_RELEVANCY_METRIC_USER_MSG = """
|
|
29
|
+
Here is the data:
|
|
30
|
+
|
|
31
|
+
Input:
|
|
32
|
+
'''
|
|
33
|
+
$input_text
|
|
34
|
+
'''
|
|
35
|
+
|
|
36
|
+
Contexts:
|
|
37
|
+
'''
|
|
38
|
+
$context_texts
|
|
39
|
+
'''
|
|
40
|
+
|
|
41
|
+
Answer as JSON:
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
CONTEXT_RELEVANCY_METRIC_USER_MSG_STRUCTURED_OUTPUT = """
|
|
45
|
+
Here is the data:
|
|
46
|
+
|
|
47
|
+
Input:
|
|
48
|
+
'''
|
|
49
|
+
$input_text
|
|
50
|
+
'''
|
|
51
|
+
|
|
52
|
+
Contexts:
|
|
53
|
+
'''
|
|
54
|
+
$context_texts
|
|
55
|
+
'''
|
|
56
|
+
"""
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field, create_model
|
|
2
|
+
from pydantic.json_schema import SkipJsonSchema
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict
|
|
7
|
+
|
|
8
|
+
from unique_toolkit.tools.config import get_configuration_dict
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class StructuredOutputModel(BaseModel):
    # Base class for the structured-output schemas in this module; fields not
    # declared on a model are rejected (``extra="forbid"``).
    # NOTE: kept as a comment rather than a docstring so the generated JSON
    # schema is unchanged.
    model_config = ConfigDict(extra="forbid")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class StructuredOutputConfig(BaseModel):
    # Switches and description texts for structured-output evaluation.
    # The ``*_description`` values are read by
    # ``EvaluationSchemaStructuredOutput.get_with_descriptions`` and attached
    # to the fields of the schema it builds.
    model_config = get_configuration_dict()

    # --- feature switches (both off by default) ---
    enabled: bool = Field(
        description="Whether to use structured output for the evaluation.",
        default=False,
    )
    extract_fact_list: bool = Field(
        description="Whether to extract a list of relevant facts from context chunks with structured output.",
        default=False,
    )

    # --- descriptions for the ``reason`` and ``value`` fields ---
    reason_description: str = Field(
        description="The description of the reason field for structured output.",
        default="A brief explanation justifying your evaluation decision.",
    )
    value_description: str = Field(
        description="The description of the value field for structured output.",
        default="Assessment of how relevant the facts are to the query. Must be one of: ['low', 'medium', 'high'].",
    )

    # --- descriptions used only when a fact list is extracted ---
    fact_description: str = Field(
        description="The description of the fact field for structured output.",
        default="A fact is an information that is directly answers the user's query. Make sure to emphasize the important information from the fact with bold text.",
    )
    fact_list_description: str = Field(
        description="The description of the fact list field for structured output.",
        default="A list of relevant facts extracted from the source that supports or answers the user's query.",
    )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class Fact(StructuredOutputModel):
    # One extracted fact, carried as plain text. Used as the element type of
    # ``EvaluationSchemaStructuredOutput.fact_list``.
    fact: str
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class EvaluationSchemaStructuredOutput(StructuredOutputModel):
    # Shape of a structured evaluation answer: a reason, a value and an
    # optional list of facts.
    # NOTE: comments are used instead of a class docstring so that the JSON
    # schema generated from this model is unchanged.
    reason: str
    value: str

    fact_list: list[Fact] = Field(default_factory=list[Fact])

    @classmethod
    def get_with_descriptions(cls, config: StructuredOutputConfig):
        """Create a subclass of ``cls`` whose fields carry descriptions from ``config``.

        When ``config.extract_fact_list`` is set, ``fact_list`` becomes a
        described list of a described ``Fact`` subclass with no default.
        Otherwise ``fact_list`` is wrapped in ``SkipJsonSchema`` and defaults
        to an empty list.

        Args:
            config: Supplies the switch and the description texts.

        Returns:
            The model class produced by ``create_model``.
        """
        if not config.extract_fact_list:
            fact_list_field = (
                SkipJsonSchema[list[Fact]],
                Field(default_factory=list[Fact]),
            )
        else:
            described_fact = create_model(
                "Fact",
                fact=(str, Field(..., description=config.fact_description)),
                __base__=Fact,
            )
            fact_list_field = (
                list[described_fact],
                Field(description=config.fact_list_description),
            )

        # Field order is kept as reason, value, fact_list.
        described_fields = {
            "reason": (str, Field(..., description=config.reason_description)),
            "value": (str, Field(..., description=config.value_description)),
            "fact_list": fact_list_field,
        }
        return create_model(
            "EvaluationSchemaStructuredOutputWithDescription",
            __base__=cls,
            **described_fields,
        )
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ValidationError
|
|
4
|
+
from unique_toolkit.app.schemas import ChatEvent
|
|
5
|
+
from unique_toolkit.chat.service import ChatService
|
|
6
|
+
from unique_toolkit.language_model.infos import (
|
|
7
|
+
LanguageModelInfo,
|
|
8
|
+
LanguageModelName,
|
|
9
|
+
ModelCapabilities,
|
|
10
|
+
)
|
|
11
|
+
from unique_toolkit.language_model.prompt import Prompt
|
|
12
|
+
from unique_toolkit.language_model.schemas import (
|
|
13
|
+
LanguageModelMessages,
|
|
14
|
+
)
|
|
15
|
+
from unique_toolkit.language_model.service import (
|
|
16
|
+
LanguageModelService,
|
|
17
|
+
)
|
|
18
|
+
from unique_toolkit.evals.config import EvaluationMetricConfig
|
|
19
|
+
from unique_toolkit.evals.context_relevancy.schema import (
|
|
20
|
+
EvaluationSchemaStructuredOutput,
|
|
21
|
+
)
|
|
22
|
+
from unique_toolkit.evals.exception import EvaluatorException
|
|
23
|
+
from unique_toolkit.evals.output_parser import (
|
|
24
|
+
parse_eval_metric_result,
|
|
25
|
+
parse_eval_metric_result_structured_output,
|
|
26
|
+
)
|
|
27
|
+
from unique_toolkit.evals.schemas import (
|
|
28
|
+
EvaluationMetricInput,
|
|
29
|
+
EvaluationMetricInputFieldName,
|
|
30
|
+
EvaluationMetricName,
|
|
31
|
+
EvaluationMetricResult,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
from .prompts import (
|
|
36
|
+
CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG,
|
|
37
|
+
CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG_STRUCTURED_OUTPUT,
|
|
38
|
+
CONTEXT_RELEVANCY_METRIC_USER_MSG,
|
|
39
|
+
CONTEXT_RELEVANCY_METRIC_USER_MSG_STRUCTURED_OUTPUT,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Keys under which the system and user prompts are stored in
# ``EvaluationMetricConfig.custom_prompts``.
SYSTEM_MSG_KEY = "systemPrompt"
USER_MSG_KEY = "userPrompt"

# Configuration used by ``ContextRelevancyEvaluator.analyze`` when the caller
# passes none: the metric is disabled and the non-structured prompts are set.
default_config = EvaluationMetricConfig(
    name=EvaluationMetricName.CONTEXT_RELEVANCY,
    enabled=False,
    language_model=LanguageModelInfo.from_name(
        LanguageModelName.AZURE_GPT_4o_2024_1120
    ),
    custom_prompts={
        SYSTEM_MSG_KEY: CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG,
        USER_MSG_KEY: CONTEXT_RELEVANCY_METRIC_USER_MSG,
    },
)

# Input fields passed to ``validate_required_fields`` before a relevancy run.
relevancy_required_input_fields = [
    EvaluationMetricInputFieldName.INPUT_TEXT,
    EvaluationMetricInputFieldName.CONTEXT_TEXTS,
]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ContextRelevancyEvaluator:
    """Rates how relevant a set of context texts is to an input text via an LLM."""

    def __init__(
        self,
        event: ChatEvent,
    ):
        self.chat_service = ChatService(event)
        self.language_model_service = LanguageModelService(event)
        self.logger = logging.getLogger(f"ContextRelevancyEvaluator.{__name__}")

    async def analyze(
        self,
        input: EvaluationMetricInput,
        config: EvaluationMetricConfig = default_config,
        structured_output_schema: type[BaseModel] | None = None,
    ) -> EvaluationMetricResult | None:
        """
        Analyzes the level of relevancy of a context by comparing
        it with the input text.

        Args:
            input (EvaluationMetricInput): The input for the metric.
            config (EvaluationMetricConfig): The configuration for the metric.
            structured_output_schema (type[BaseModel] | None): Schema for
                structured output. Only used when the configured language
                model lists the structured-output capability; otherwise the
                regular (text) path is taken.

        Returns:
            EvaluationMetricResult | None: ``None`` when the metric is disabled.

        Raises:
            EvaluatorException: If the context texts are empty or required fields are missing or error occurred during evaluation.
        """
        if config.enabled is False:
            self.logger.info("Context relevancy metric is not enabled.")
            return None

        input.validate_required_fields(relevancy_required_input_fields)

        if not input.context_texts:
            error_message = "No context texts provided."
            raise EvaluatorException(
                user_message=error_message,
                error_message=error_message,
            )

        use_structured_output = bool(
            structured_output_schema
            and ModelCapabilities.STRUCTURED_OUTPUT
            in config.language_model.capabilities
        )

        try:
            if use_structured_output:
                return await self._handle_structured_output(
                    input, config, structured_output_schema
                )
            return await self._handle_regular_output(input, config)
        except EvaluatorException:
            # The helpers already raise EvaluatorException with a specific
            # message; re-wrapping it as "Unknown error" would hide that.
            raise
        except Exception as e:
            error_message = (
                "Unknown error occurred during context relevancy metric analysis"
            )
            raise EvaluatorException(
                error_message=f"{error_message}: {e}",
                user_message=error_message,
                exception=e,
            ) from e

    async def _handle_structured_output(
        self,
        input: EvaluationMetricInput,
        config: EvaluationMetricConfig,
        structured_output_schema: type[BaseModel],
    ) -> EvaluationMetricResult:
        """Handle the structured output case for context relevancy evaluation.

        Raises:
            EvaluatorException: If the parsed model answer does not validate
                against ``EvaluationSchemaStructuredOutput``.
        """
        self.logger.info("Using structured output for context relevancy evaluation.")
        msgs = self._compose_msgs(input, config, enable_structured_output=True)
        result = await self.language_model_service.complete_async(
            messages=msgs,
            model_name=config.language_model.name,
            structured_output_model=structured_output_schema,
            structured_output_enforce_schema=True,
            other_options=config.additional_llm_options,
        )

        try:
            result_content = EvaluationSchemaStructuredOutput.model_validate(
                result.choices[0].message.parsed
            )
        except ValidationError as e:
            error_message = "Error occurred during structured output validation of the context relevancy evaluation."
            raise EvaluatorException(
                error_message=error_message,
                user_message=error_message,
                exception=e,
            ) from e

        return parse_eval_metric_result_structured_output(
            result_content, EvaluationMetricName.CONTEXT_RELEVANCY
        )

    async def _handle_regular_output(
        self,
        input: EvaluationMetricInput,
        config: EvaluationMetricConfig,
    ) -> EvaluationMetricResult:
        """Handle the regular output case for context relevancy evaluation.

        Raises:
            EvaluatorException: If the model answer is empty or not a string.
        """
        msgs = self._compose_msgs(input, config, enable_structured_output=False)
        result = await self.language_model_service.complete_async(
            messages=msgs,
            model_name=config.language_model.name,
            other_options=config.additional_llm_options,
        )

        result_content = result.choices[0].message.content
        if not result_content or not isinstance(result_content, str):
            error_message = "Context relevancy evaluation did not return a result."
            raise EvaluatorException(
                error_message=error_message,
                user_message=error_message,
            )

        return parse_eval_metric_result(
            result_content, EvaluationMetricName.CONTEXT_RELEVANCY
        )

    def _compose_msgs(
        self,
        input: EvaluationMetricInput,
        config: EvaluationMetricConfig,
        enable_structured_output: bool,
    ) -> LanguageModelMessages:
        """
        Composes the messages for the relevancy metric.

        Returns a system message followed by a user message; the user prompt
        receives the input text and the joined context texts.
        """
        system_msg_content = self._get_system_prompt(config, enable_structured_output)
        system_msg = Prompt(system_msg_content).to_system_msg()

        user_msg = Prompt(
            self._get_user_prompt(config, enable_structured_output),
            input_text=input.input_text,
            context_texts=input.get_joined_context_texts(),
        ).to_user_msg()

        return LanguageModelMessages([system_msg, user_msg])

    def _get_system_prompt(
        self,
        config: EvaluationMetricConfig,
        enable_structured_output: bool,
    ) -> str:
        """Return the configured system prompt, or the built-in one for the mode.

        A prompt stored under ``SYSTEM_MSG_KEY`` in ``config.custom_prompts``
        always wins. The lookup is read-only: writing the fallback back into
        the config (as ``setdefault`` would) makes the first call's mode stick
        for all later calls on the same, possibly shared, config object.
        """
        if (
            enable_structured_output
            and ModelCapabilities.STRUCTURED_OUTPUT
            in config.language_model.capabilities
        ):
            fallback = CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG_STRUCTURED_OUTPUT
        else:
            fallback = CONTEXT_RELEVANCY_METRIC_SYSTEM_MSG
        return config.custom_prompts.get(SYSTEM_MSG_KEY, fallback)

    def _get_user_prompt(
        self,
        config: EvaluationMetricConfig,
        enable_structured_output: bool,
    ) -> str:
        """Return the configured user prompt, or the built-in one for the mode.

        A prompt stored under ``USER_MSG_KEY`` in ``config.custom_prompts``
        always wins. The lookup does not modify ``config``.
        """
        if enable_structured_output:
            fallback = CONTEXT_RELEVANCY_METRIC_USER_MSG_STRUCTURED_OUTPUT
        else:
            fallback = CONTEXT_RELEVANCY_METRIC_USER_MSG
        return config.custom_prompts.get(USER_MSG_KEY, fallback)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from pydantic import Field
|
|
4
|
+
|
|
5
|
+
from unique_toolkit._common.validators import LMI
|
|
6
|
+
from unique_toolkit.evals.config import EvaluationMetricConfig
|
|
7
|
+
from unique_toolkit.evals.hallucination.prompts import (
|
|
8
|
+
HALLUCINATION_METRIC_SYSTEM_MSG,
|
|
9
|
+
HALLUCINATION_METRIC_SYSTEM_MSG_DEFAULT,
|
|
10
|
+
HALLUCINATION_METRIC_USER_MSG,
|
|
11
|
+
HALLUCINATION_METRIC_USER_MSG_DEFAULT,
|
|
12
|
+
)
|
|
13
|
+
from unique_toolkit.evals.schemas import (
|
|
14
|
+
EvaluationMetricInputFieldName,
|
|
15
|
+
EvaluationMetricName,
|
|
16
|
+
)
|
|
17
|
+
from unique_toolkit.language_model.infos import LanguageModelInfo, LanguageModelName
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Keys of the prompt entries in ``HallucinationConfig.custom_prompts``.
SYSTEM_MSG_KEY = "systemPrompt"
USER_MSG_KEY = "userPrompt"
# Keys for the "default" prompt variants stored alongside the two above.
SYSTEM_MSG_DEFAULT_KEY = "systemPromptDefault"
USER_MSG_DEFAULT_KEY = "userPromptDefault"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class HallucinationConfig(EvaluationMetricConfig):
    # Configuration for the hallucination metric. Redeclares the fields of
    # ``EvaluationMetricConfig`` with hallucination-specific defaults.
    # NOTE: comments are used instead of a class docstring so the generated
    # JSON schema is unchanged.
    enabled: bool = False
    name: EvaluationMetricName = EvaluationMetricName.HALLUCINATION
    language_model: LMI = LanguageModelInfo.from_name(
        LanguageModelName.AZURE_GPT_35_TURBO_0125
    )
    additional_llm_options: dict[str, Any] = Field(
        description="Additional options to pass to the language model.",
        default={},
    )
    # All four prompt variants are registered by default.
    custom_prompts: dict = {
        SYSTEM_MSG_KEY: HALLUCINATION_METRIC_SYSTEM_MSG,
        USER_MSG_KEY: HALLUCINATION_METRIC_USER_MSG,
        SYSTEM_MSG_DEFAULT_KEY: HALLUCINATION_METRIC_SYSTEM_MSG_DEFAULT,
        USER_MSG_DEFAULT_KEY: HALLUCINATION_METRIC_USER_MSG_DEFAULT,
    }
    # Upper-cased score -> label. A low hallucination score is the good
    # outcome, hence "LOW" maps to "GREEN".
    score_to_label: dict = {"LOW": "GREEN", "MEDIUM": "YELLOW", "HIGH": "RED"}
    # Upper-cased score -> title text.
    score_to_title: dict = {
        "LOW": "No Hallucination Detected",
        "MEDIUM": "Hallucination Warning",
        "HIGH": "High Hallucination",
    }
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Module-level instance built entirely from the ``HallucinationConfig`` defaults.
hallucination_metric_default_config = HallucinationConfig()

# Input fields listed as required for the hallucination metric.
hallucination_required_input_fields = [
    EvaluationMetricInputFieldName.INPUT_TEXT,
    EvaluationMetricInputFieldName.CONTEXT_TEXTS,
    EvaluationMetricInputFieldName.HISTORY_MESSAGES,
    EvaluationMetricInputFieldName.OUTPUT_TEXT,
]
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from unique_toolkit.app.schemas import ChatEvent
|
|
4
|
+
from unique_toolkit.chat.schemas import (
|
|
5
|
+
ChatMessageAssessmentLabel,
|
|
6
|
+
ChatMessageAssessmentStatus,
|
|
7
|
+
ChatMessageAssessmentType,
|
|
8
|
+
)
|
|
9
|
+
from unique_toolkit.evals.evaluation_manager import Evaluation
|
|
10
|
+
from unique_toolkit.evals.hallucination.utils import check_hallucination
|
|
11
|
+
from unique_toolkit.evals.schemas import (
|
|
12
|
+
EvaluationAssessmentMessage,
|
|
13
|
+
EvaluationMetricInput,
|
|
14
|
+
EvaluationMetricName,
|
|
15
|
+
EvaluationMetricResult,
|
|
16
|
+
)
|
|
17
|
+
from unique_toolkit.evals.hallucination.constants import (
|
|
18
|
+
HallucinationConfig,
|
|
19
|
+
)
|
|
20
|
+
from unique_toolkit.reference_manager.reference_manager import (
|
|
21
|
+
ReferenceManager,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from unique_toolkit.language_model.schemas import (
|
|
25
|
+
LanguageModelStreamResponse,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class HallucinationEvaluation(Evaluation):
    """Evaluation that checks a loop response for hallucination."""

    def __init__(
        self,
        config: HallucinationConfig,
        event: ChatEvent,
        reference_manager: ReferenceManager,
    ):
        self.config = config
        self._company_id = event.company_id
        self._user_id = event.user_id
        self._reference_manager = reference_manager
        self._user_message = event.payload.user_message.text
        super().__init__(EvaluationMetricName.HALLUCINATION)

    async def run(
        self, loop_response: LanguageModelStreamResponse
    ) -> EvaluationMetricResult:  # type: ignore
        """Run the hallucination check and set ``is_positive`` on its result.

        The context texts are the texts of the reference manager's chunks and
        the output text is the text of ``loop_response.message``.
        """
        context_texts = [
            chunk.text for chunk in self._reference_manager.get_chunks()
        ]
        metric_input = EvaluationMetricInput(
            input_text=self._user_message,
            context_texts=context_texts,
            history_messages=[],  # TODO include loop_history messages
            output_text=loop_response.message.text,
        )

        evaluation_result: EvaluationMetricResult = await check_hallucination(
            company_id=self._company_id,
            input=metric_input,
            config=self.config,
        )

        # Scores without a configured label count as "RED", i.e. not positive.
        label = self.config.score_to_label.get(
            evaluation_result.value.upper(), "RED"
        )
        evaluation_result.is_positive = label != "RED"
        return evaluation_result

    def get_assessment_type(self) -> ChatMessageAssessmentType:
        """Return the assessment type of this evaluation."""
        return ChatMessageAssessmentType.HALLUCINATION

    async def evaluation_metric_to_assessment(
        self, evaluation_result: EvaluationMetricResult
    ) -> EvaluationAssessmentMessage:
        """Convert a metric result into an assessment message.

        Title and label are looked up by the upper-cased result value. A
        missing title falls back to the raw value; a missing label falls back
        to the upper-cased value.
        """
        score_key = evaluation_result.value.upper()
        title = self.config.score_to_title.get(score_key, evaluation_result.value)
        label = ChatMessageAssessmentLabel(
            self.config.score_to_label.get(score_key, score_key)
        )

        if evaluation_result.error:
            status = ChatMessageAssessmentStatus.ERROR
        else:
            status = ChatMessageAssessmentStatus.DONE

        return EvaluationAssessmentMessage(
            status=status,
            title=title,
            explanation=evaluation_result.reason,
            label=label,
            type=self.get_assessment_type(),
        )
|