ibm-watsonx-orchestrate-evaluation-framework 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ibm-watsonx-orchestrate-evaluation-framework might be problematic.
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/METADATA +4 -1
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/RECORD +49 -39
- wxo_agentic_evaluation/analyze_run.py +822 -344
- wxo_agentic_evaluation/arg_configs.py +39 -2
- wxo_agentic_evaluation/data_annotator.py +22 -4
- wxo_agentic_evaluation/description_quality_checker.py +29 -4
- wxo_agentic_evaluation/evaluation_package.py +197 -18
- wxo_agentic_evaluation/external_agent/external_validate.py +3 -1
- wxo_agentic_evaluation/external_agent/types.py +1 -1
- wxo_agentic_evaluation/inference_backend.py +105 -108
- wxo_agentic_evaluation/llm_matching.py +104 -2
- wxo_agentic_evaluation/llm_user.py +2 -2
- wxo_agentic_evaluation/main.py +147 -38
- wxo_agentic_evaluation/metrics/__init__.py +5 -0
- wxo_agentic_evaluation/metrics/evaluations.py +124 -0
- wxo_agentic_evaluation/metrics/llm_as_judge.py +4 -3
- wxo_agentic_evaluation/metrics/metrics.py +64 -1
- wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
- wxo_agentic_evaluation/prompt/semantic_matching_prompt.jinja2 +41 -9
- wxo_agentic_evaluation/prompt/template_render.py +20 -2
- wxo_agentic_evaluation/quick_eval.py +23 -11
- wxo_agentic_evaluation/record_chat.py +18 -10
- wxo_agentic_evaluation/red_teaming/attack_evaluator.py +169 -100
- wxo_agentic_evaluation/red_teaming/attack_generator.py +63 -40
- wxo_agentic_evaluation/red_teaming/attack_list.py +78 -8
- wxo_agentic_evaluation/red_teaming/attack_runner.py +71 -14
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics.json +783 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics.json +600 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +10 -10
- wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +103 -39
- wxo_agentic_evaluation/resource_map.py +3 -1
- wxo_agentic_evaluation/service_instance.py +12 -3
- wxo_agentic_evaluation/service_provider/__init__.py +129 -9
- wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
- wxo_agentic_evaluation/service_provider/model_proxy_provider.py +415 -17
- wxo_agentic_evaluation/service_provider/ollama_provider.py +393 -22
- wxo_agentic_evaluation/service_provider/provider.py +130 -10
- wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +52 -0
- wxo_agentic_evaluation/service_provider/watsonx_provider.py +480 -52
- wxo_agentic_evaluation/type.py +15 -5
- wxo_agentic_evaluation/utils/__init__.py +44 -3
- wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
- wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
- wxo_agentic_evaluation/utils/messages_parser.py +30 -0
- wxo_agentic_evaluation/utils/parsers.py +71 -0
- wxo_agentic_evaluation/utils/utils.py +140 -20
- wxo_agentic_evaluation/wxo_client.py +81 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/top_level.txt +0 -0
wxo_agentic_evaluation/arg_configs.py

@@ -1,5 +1,6 @@
 import os
 from dataclasses import dataclass, field
+from enum import StrEnum
 from typing import List, Optional, Union

 from wxo_agentic_evaluation import __file__
@@ -30,7 +31,27 @@ class LLMUserConfig:
 @dataclass
 class ProviderConfig:
     model_id: str = field(default="meta-llama/llama-3-405b-instruct")
-    provider: str = field(
+    provider: str = field(
+        default_factory=lambda: (
+            "gateway"
+            if os.getenv("USE_GATEWAY_MODEL_PROVIDER", "").lower() == "true"
+            else "watsonx"
+        )
+    )
+    embedding_model_id: str = field(
+        default="sentence-transformers/all-minilm-l6-v2"
+    )
+
+
+@dataclass
+class CustomMetricsConfig:
+    paths: Optional[list[str]] = field(default=None)
+    llmaaj_config: ProviderConfig = field(default_factory=ProviderConfig)
+
+
+@dataclass
+class ExtractorsConfig:
+    paths: Optional[list[str]] = field(default=None)


 @dataclass
@@ -41,12 +62,18 @@ class TestConfig:
     wxo_lite_version: str
     provider_config: ProviderConfig = field(default_factory=ProviderConfig)
     llm_user_config: LLMUserConfig = field(default_factory=LLMUserConfig)
+    custom_metrics_config: CustomMetricsConfig = field(
+        default_factory=CustomMetricsConfig
+    )
+    extrators_config: ExtractorsConfig = field(default_factory=ExtractorsConfig)
     enable_verbose_logging: bool = True
     enable_manual_user_input: bool = False
     skip_available_results: bool = False
     data_annotation_run: bool = False
     num_workers: int = 2
     n_runs: int = 1
+    similarity_threshold: float = 0.8
+    enable_fuzzy_matching: bool = False


 @dataclass
@@ -59,22 +86,32 @@ class AttackConfig:
     enable_verbose_logging: bool = True
     enable_manual_user_input: bool = False
     num_workers: int = 2
+    skip_available_results: bool = True


 @dataclass
 class AttackGeneratorConfig:
     attacks_list: Union[List[str], str]
     datasets_path: Union[List[str], str]
-
+    agents_list_or_path: Union[List[str], str]
     target_agent_name: str
+    auth_config: AuthConfig
     output_dir: str = None
     max_variants: int = None


+class AnalyzeMode(StrEnum):
+    default = "default"
+    enhanced = "enhanced"
+
+
 @dataclass
 class AnalyzeConfig:
     data_path: str
     tool_definition_path: Optional[str] = None
+    mode: str = AnalyzeMode.default
+    num_workers: int = 10
+    run: int = -1


 @dataclass
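The practical effect of the ProviderConfig change is that the default provider is now decided per instantiation from the USE_GATEWAY_MODEL_PROVIDER environment variable. A minimal, standalone sketch of that selection logic (standard library only; this ProviderConfig is a simplified stand-in for the package's dataclass, not the real class):

    import os
    from dataclasses import dataclass, field

    @dataclass
    class ProviderConfig:
        # Simplified stand-in for wxo_agentic_evaluation.arg_configs.ProviderConfig.
        model_id: str = "meta-llama/llama-3-405b-instruct"
        provider: str = field(
            default_factory=lambda: (
                "gateway"
                if os.getenv("USE_GATEWAY_MODEL_PROVIDER", "").lower() == "true"
                else "watsonx"
            )
        )

    os.environ["USE_GATEWAY_MODEL_PROVIDER"] = "true"
    print(ProviderConfig().provider)  # gateway
    os.environ["USE_GATEWAY_MODEL_PROVIDER"] = "false"
    print(ProviderConfig().provider)  # watsonx

Because the value comes from a default_factory rather than a module-level constant, the environment variable is re-read each time a config object is created.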
wxo_agentic_evaluation/data_annotator.py

@@ -3,7 +3,10 @@ import collections
 import json
 from typing import Dict, List, Optional

-from wxo_agentic_evaluation.arg_configs import
+from wxo_agentic_evaluation.arg_configs import (
+    ChatRecordingConfig,
+    KeywordsGenerationConfig,
+)
 from wxo_agentic_evaluation.prompt.template_render import (
     LlamaKeywordsGenerationTemplateRenderer,
 )
@@ -223,11 +226,23 @@ class DataAnnotator:
         return goals, goal_details, previous

     def _process_summarization(
-        self,
+        self,
+        previous: str,
+        goals: Dict,
+        goal_details: List,
+        config: ChatRecordingConfig = None,
     ) -> None:
         """Process summarization step"""
         summarize_step = None
         # we assume single summary step at the end
+        extra_kwargs = {}
+        instance_url = getattr(config, "service_url", None)
+        token = getattr(config, "token", None)
+        if instance_url:
+            extra_kwargs["instance_url"] = instance_url
+        if token:
+            extra_kwargs["token"] = token
+
         for message in self.messages[::-1]:
             if message.role == "assistant":
                 provider = get_provider(
@@ -237,6 +252,7 @@ class DataAnnotator:
                         "decoding_method": "greedy",
                         "max_new_tokens": 256,
                     },
+                    **extra_kwargs,
                 )
                 kw_generator = KeywordsGenerationLLM(
                     provider=provider,
@@ -261,10 +277,12 @@ class DataAnnotator:
         else:
             goals[previous] = ["summarize"]

-    def generate(self) -> Dict:
+    def generate(self, config: ChatRecordingConfig = None) -> Dict:
         """Generate the final dataset"""
         goals, goal_details, previous = self._process_tool_calls()
-        self._process_summarization(
+        self._process_summarization(
+            previous, goals, goal_details, config=config
+        )

         return {
             "agent": self.initial_data.agent,
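The data_annotator.py change threads an optional ChatRecordingConfig down to the provider factory by building an extra_kwargs dict from whatever attributes the config actually carries. A small sketch of that getattr-based pattern (the config class here is a simplified stand-in, not the package's real ChatRecordingConfig):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeRecordingConfig:
        # Stand-in exposing the two attributes the annotator looks for.
        service_url: Optional[str] = None
        token: Optional[str] = None

    def build_extra_kwargs(config) -> dict:
        # Mirrors _process_summarization: only forward values that are
        # actually present on the (possibly None) config object.
        extra_kwargs = {}
        instance_url = getattr(config, "service_url", None)
        token = getattr(config, "token", None)
        if instance_url:
            extra_kwargs["instance_url"] = instance_url
        if token:
            extra_kwargs["token"] = token
        return extra_kwargs

    print(build_extra_kwargs(None))  # {}
    print(build_extra_kwargs(FakeRecordingConfig("https://wxo.local", "abc")))
    # {'instance_url': 'https://wxo.local', 'token': 'abc'}

The same kwargs are then splatted into get_provider(...), so a run without a recording config behaves exactly as before.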
wxo_agentic_evaluation/description_quality_checker.py

@@ -5,6 +5,7 @@ from typing import List

 import rich

+from wxo_agentic_evaluation.metrics.metrics import DescriptionQualityMetric
 from wxo_agentic_evaluation.prompt.template_render import (
     BadToolDescriptionRenderer,
 )
@@ -15,6 +16,9 @@ from wxo_agentic_evaluation.tool_planner import (
     parse_json_string,
 )
 from wxo_agentic_evaluation.type import ToolDefinition
+from wxo_agentic_evaluation.utils.gateway_provider_utils import (
+    get_provider_kwargs,
+)
 from wxo_agentic_evaluation.utils.utils import safe_divide


@@ -60,12 +64,23 @@ class DescriptionQualityInspector:
         root_dir, "prompt", "bad_tool_descriptions_prompt.jinja2"
     )

+    DEFAULT_PROVIDER_KWARGS = {
+        "model_id": LLM_MODEL_ID,
+        "params": LLM_PARAMS,
+    }
+
     def __init__(self, llm_client=None):
+
         if llm_client is None:
+
+            provider_kwargs = get_provider_kwargs(
+                **self.DEFAULT_PROVIDER_KWARGS,
+            )
+
             llm_client = get_provider(
-
-                params=self.LLM_PARAMS,
+                **provider_kwargs,
             )
+
         self.llm_client = llm_client
         self.template = BadToolDescriptionRenderer(
             self.BAD_TOOL_DESCRIPTIONS_DETECTOR_PATH
@@ -106,7 +121,9 @@ class DescriptionQualityInspector:
         )
         return tool_definitions

-    def detect_bad_description(
+    def detect_bad_description(
+        self, tool_definition: ToolDefinition
+    ) -> DescriptionQualityMetric:
         """
         Detects if a tool description is 'bad' using an LLM judge.
         A 'bad' description is one that:
@@ -119,6 +136,10 @@ class DescriptionQualityInspector:
         Returns:
             bool: True if the description is 'bad', False otherwise.
         """
+
+        if tool_definition.tool_description is None:
+            return DescriptionQualityMetric(tool_name=tool_definition.tool_name)
+
         prompt = self.template.render(tool_definition=tool_definition)
         response = self.llm_client.query(prompt)

@@ -137,7 +158,11 @@ class DescriptionQualityInspector:
             response_data=response_data
         )

-        return
+        return DescriptionQualityMetric(
+            tool_name=tool_definition.tool_name,
+            description_score=final_description_score,
+            threshold=self.CLASSIFICATION_SCORE_THRESHOLD,
+        )

     def _calculate_score(self, response_data: dict) -> float:
         """
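With this change detect_bad_description returns a DescriptionQualityMetric (tool_name, description_score, threshold) instead of a bare value, and short-circuits when the tool has no description at all. A rough sketch of how a caller might consume that object; the field names come from the diff above, but the metric class below is a simplified stand-in and the thresholding direction is an assumption, not the package's documented behaviour:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class DescriptionQualityMetric:
        # Stand-in mirroring the fields used in the diff above.
        tool_name: str
        description_score: Optional[float] = None
        threshold: Optional[float] = None

    def looks_bad(metric: DescriptionQualityMetric) -> bool:
        # Assumption: a score at or above the threshold flags the description
        # as "bad"; tools without a score (no description) are left unflagged.
        if metric.description_score is None or metric.threshold is None:
            return False
        return metric.description_score >= metric.threshold

    print(looks_bad(DescriptionQualityMetric("get_weather", 0.9, 0.75)))  # True
    print(looks_bad(DescriptionQualityMetric("get_weather")))             # False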
wxo_agentic_evaluation/evaluation_package.py

@@ -1,19 +1,27 @@
 import json
 import os
-from
+from gc import enable
+from typing import Any, Dict, List, Optional

 import rich
+from dateutil import parser

 from wxo_agentic_evaluation import __file__
 from wxo_agentic_evaluation.data_annotator import ERROR_KEYWORDS
 from wxo_agentic_evaluation.llm_matching import LLMMatcher
 from wxo_agentic_evaluation.llm_rag_eval import LLMJudge
 from wxo_agentic_evaluation.llm_safety_eval import LLMSafetyJudge
+from wxo_agentic_evaluation.metrics.evaluations import (
+    Evaluation,
+    Extractor,
+    Metric,
+)
 from wxo_agentic_evaluation.metrics.llm_as_judge import (
     AnswerDerailment,
     AnswerUnsafeTopic,
 )
 from wxo_agentic_evaluation.metrics.metrics import (
+    CustomEvalMetrics,
     KeywordSemanticSearchMetric,
     KnowledgeBaseMetrics,
     TextMatchType,
@@ -28,7 +36,12 @@ from wxo_agentic_evaluation.prompt.template_render import (
     UnsafeTopicTemplateRenderer,
 )
 from wxo_agentic_evaluation.resource_map import ResourceMap
-from wxo_agentic_evaluation.
+from wxo_agentic_evaluation.service_instance import tenant_setup
+from wxo_agentic_evaluation.service_provider import (
+    USE_GATEWAY_MODEL_PROVIDER,
+    get_provider,
+)
+from wxo_agentic_evaluation.service_provider.provider import Provider
 from wxo_agentic_evaluation.type import (
     ContentType,
     ConversationalSearch,
@@ -76,12 +89,18 @@ DUMMY_GRAPH_NODE_NAME = "dummy-goal"
 class EvaluationPackage:
     def __init__(
         self,
-        test_case_name,
-        ground_truth,
-        messages,
+        test_case_name: str,
+        ground_truth: EvaluationData,
+        messages: list[Message],
         conversational_search_data: List[ConversationalSearch] = None,
         resource_map: ResourceMap = None,
         is_attack_evaluation: bool = False,
+        config=None,
+        custom_evals: Optional[list[Evaluation]] = None,
+        custom_llmaaj_client: Optional[Provider] = None,
+        extractors: Optional[list[Extractor]] = None,
+        similarity_threshold=0.8,
+        enable_fuzzy_matching=False,
     ):
         self.tool_dictionary = (
             {
@@ -103,16 +122,56 @@ class EvaluationPackage:
             else []
         )

-        self.messages = messages
+        self.messages: List[Message] = messages
         self.conversational_search_data = conversational_search_data
         self.is_attack_evaluation = is_attack_evaluation
         self.ground_truth = ground_truth
         self.test_case_name = test_case_name
         self.resource_map = resource_map
+        self.custom_evals = custom_evals
+        self.custom_llmaaj_client = custom_llmaaj_client
+        self.extractors = extractors
+        self.enable_fuzzy_matching = enable_fuzzy_matching

         if not self.is_attack_evaluation:
             self.validate_ground_truth(self.ground_truth, self.test_case_name)

+        extra_kwargs = {}
+
+        if USE_GATEWAY_MODEL_PROVIDER:
+
+            if resource_map and hasattr(resource_map, "wxo_client"):
+                wxo_client = resource_map.wxo_client
+
+                if hasattr(wxo_client, "service_url"):
+                    extra_kwargs["instance_url"] = wxo_client.service_url
+
+                if hasattr(wxo_client, "api_key"):
+                    extra_kwargs["token"] = wxo_client.api_key
+
+            elif config:
+                auth = getattr(config, "auth_config", None)
+
+                if auth:
+                    instance_url = getattr(auth, "url", None)
+                    token = getattr(auth, "token", None)
+
+                    if instance_url:
+                        extra_kwargs["instance_url"] = instance_url
+
+                    if token:
+                        extra_kwargs["token"] = token
+            else:
+                token, instance_url, env = tenant_setup(
+                    service_url=None, tenant_name="local"
+                )
+                if instance_url:
+                    extra_kwargs["instance_url"] = instance_url
+
+                if token:
+                    extra_kwargs["token"] = token
+
+        # output response matching
         self.matcher = LLMMatcher(
             llm_client=get_provider(
                 model_id="meta-llama/llama-3-405b-instruct",
@@ -121,6 +180,8 @@ class EvaluationPackage:
                     "decoding_method": "greedy",
                     "max_new_tokens": 10,
                 },
+                embedding_model_id="sentence-transformers/all-minilm-l6-v2",
+                **extra_kwargs,
             ),
             keyword_template=KeywordMatchingTemplateRenderer(
                 KEYWORD_MATCHING_PROMPT_PATH
@@ -128,7 +189,10 @@ class EvaluationPackage:
             semantic_template=SemanticMatchingTemplateRenderer(
                 SEMANTIC_MATCHING_PROMPT_PATH
             ),
+            similarity_threshold=similarity_threshold,
+            enable_fuzzy_matching=enable_fuzzy_matching,
         )
+        # only used for RAG evaluation
         self.rag_llm_as_a_judge = LLMJudge(
             llm_client=get_provider(
                 model_id="meta-llama/llama-3-405b-instruct",
@@ -137,6 +201,7 @@ class EvaluationPackage:
                     "decoding_method": "greedy",
                     "max_new_tokens": 4096,
                 },
+                **extra_kwargs,
             ),
             faithfulness=FaithfulnessTemplateRenderer(FAITHFULNESS_PROMPT_PATH),
             answer_relevancy=AnswerRelevancyTemplateRenderer(
@@ -151,6 +216,7 @@ class EvaluationPackage:
                     "decoding_method": "greedy",
                     "max_new_tokens": 4096,
                 },
+                **extra_kwargs,
             ),
             answer_derailment=DerailmentTemplateRenderer(
                 DERAILMENT_PROMPT_PATH
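When USE_GATEWAY_MODEL_PROVIDER is set, the constructor above resolves the gateway credentials in a fixed order: the resource map's wxo_client first, then the run config's auth_config, then a local tenant_setup() fallback, and the resulting kwargs are passed into every get_provider call. A simplified sketch of that precedence (stand-in objects only, not the package's real clients or tenant_setup):

    def resolve_gateway_kwargs(resource_map=None, config=None, tenant_setup=None) -> dict:
        # Mirrors the precedence in EvaluationPackage.__init__:
        # resource_map.wxo_client -> config.auth_config -> tenant_setup fallback.
        extra_kwargs = {}
        if resource_map is not None and hasattr(resource_map, "wxo_client"):
            client = resource_map.wxo_client
            if getattr(client, "service_url", None):
                extra_kwargs["instance_url"] = client.service_url
            if getattr(client, "api_key", None):
                extra_kwargs["token"] = client.api_key
        elif config is not None and getattr(config, "auth_config", None):
            auth = config.auth_config
            if getattr(auth, "url", None):
                extra_kwargs["instance_url"] = auth.url
            if getattr(auth, "token", None):
                extra_kwargs["token"] = auth.token
        elif tenant_setup is not None:
            token, instance_url, _env = tenant_setup(service_url=None, tenant_name="local")
            if instance_url:
                extra_kwargs["instance_url"] = instance_url
            if token:
                extra_kwargs["token"] = token
        return extra_kwargs

    class _Client:
        service_url = "https://wxo.example"
        api_key = "abc123"

    class _ResourceMap:
        wxo_client = _Client()

    print(resolve_gateway_kwargs(resource_map=_ResourceMap()))
    # {'instance_url': 'https://wxo.example', 'token': 'abc123'}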
@@ -303,8 +369,48 @@ class EvaluationPackage:
         return str(data).lower()

     @staticmethod
+    def _compare_as_date_or_number(normalized_actual, normalized_expected):
+        """
+        Attempts to compare two normalized values as dates or numbers.
+
+        Args:
+            normalized_actual: The actual value from tool call
+            normalized_expected: The expected value from ground truth
+
+        Returns:
+            tuple: (conversion_succeeded, values_match)
+                - conversion_succeeded: True if values could be converted to numbers or dates
+                - values_match: True if converted values match
+        """
+        # Try to convert to numbers
+        try:
+            num_actual = float(normalized_actual)
+            num_expected = float(normalized_expected)
+            # Conversion succeeded, check if values match
+            return (
+                True,
+                abs(num_actual - num_expected) <= 0.001,
+            )  # Small epsilon for float comparison
+        except (ValueError, TypeError):
+            pass
+
+        # Try to convert to dates
+        try:
+            date_actual = parser.parse(normalized_actual)
+            date_expected = parser.parse(normalized_expected)
+            # Conversion succeeded, check if values match
+            return True, date_actual == date_expected
+        except (ValueError, TypeError):
+            pass
+
+        # If we get here, neither number nor date conversion worked
+        return False, False
+
     def _check_if_args_match_with_ignore(
-
+        self,
+        actual_args: dict[str, str],
+        expected_args: dict[str, str],
+        enable_fuzzy_matching: bool = False,
     ) -> bool:
         """
         This function checks if a registered tool call matches with the goal node when:
@@ -313,21 +419,50 @@ class EvaluationPackage:
             actual_args (dict): Made during inference.
             expected_args (dict): Defined in the test case/ground truth.
         Returns:
-            bool: True if match with keyword parameters ignored | False otherwise (
+            bool: True if match with keyword parameters ignored | False otherwise (arguments were not corrected).
         """
-
         if set(actual_args.keys()) != set(expected_args.keys()):
             return False

+        ## now we go through and check each parameter
         for key in actual_args:
+            normalized_actual = EvaluationPackage.normalize_args(
+                actual_args[key]
+            )
+            normalized_expected = EvaluationPackage.normalize_args(
+                expected_args[key]
+            )
+
+            # 1. If the args are an ignored keyword or exactly equal, continue to next parameter
             if (
-
-
-
-
-
-
+                normalized_expected == RESERVED_KEYWORD_FOR_GROUND_TRUTH_ARGS
+            ) or (normalized_actual == normalized_expected):
+                continue
+            else:
+                # if they're not equal, and fuzzy matching is enabled, do fuzzy.
+                if enable_fuzzy_matching:
+                    # 3. Check date/number conversion
+                    conversion_succeeded, values_match = (
+                        EvaluationPackage._compare_as_date_or_number(
+                            normalized_actual, normalized_expected
+                        )
+                    )
+                    # If conversion succeeded and values match, continue to next parameter
+                    if conversion_succeeded and values_match:
+                        continue
+                    # If conversion succeeded but values don't match, return False
+                    if conversion_succeeded and not values_match:
+                        return False
+                    # 4. If conversion failed, try cosine matching. If this fails, return false for the function
+                    if not self.matcher.cosine_similarity_semantic_match(
+                        normalized_actual, normalized_expected
+                    ):
+                        return False
+                else:
+                    # If they're not equal and fuzzy matching is not enabled, return false
+                    return False

+        # If we've made it through all parameters without returning False, return True
         return True

     def traverse(self):
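The new _compare_as_date_or_number helper is the heart of the enable_fuzzy_matching path: it first tries a numeric comparison with a small epsilon, then a dateutil parse, and only reports a failed conversion if both attempts raise. A standalone sketch of the same idea (requires the python-dateutil package, which the diff imports as "from dateutil import parser"):

    from dateutil import parser

    def compare_as_date_or_number(actual: str, expected: str):
        """Return (conversion_succeeded, values_match), mirroring the helper above."""
        # 1. Numeric comparison with a small epsilon to absorb float noise.
        try:
            return True, abs(float(actual) - float(expected)) <= 0.001
        except (ValueError, TypeError):
            pass
        # 2. Date comparison via dateutil's permissive parser.
        try:
            return True, parser.parse(actual) == parser.parse(expected)
        except (ValueError, TypeError):
            pass
        # 3. Neither interpretation applied.
        return False, False

    print(compare_as_date_or_number("42", "42.0"))                # (True, True)
    print(compare_as_date_or_number("2024-01-05", "Jan 5 2024"))  # (True, True)
    print(compare_as_date_or_number("blue", "azure"))             # (False, False)

If both conversions fail, the caller in _check_if_args_match_with_ignore falls back to cosine-similarity matching before giving up on the argument pair.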
@@ -399,8 +534,10 @@ class EvaluationPackage:
                         goal_detail.args
                     )
                     or self._check_if_args_match_with_ignore(
-                        msg_tool_call["args"],
-
+                        msg_tool_call["args"],
+                        goal_detail.args,
+                        enable_fuzzy_matching=self.enable_fuzzy_matching,
+                    )  # TODO arjun-gupta1 9/29/25: make this also return the method of matching (llm, fuzzy, cosine similarity) so we can write it out to analyze_run.py results
                 ):
                     labelled_messages.append(goal_detail.name)
                     labelled_messages_without_text_step.append(
@@ -470,6 +607,7 @@ class EvaluationPackage:
                 if message.event == EventTypes.message_created
                 and message.role == "assistant"
             ]
+
             keyword_semantic_list = []
             for message in assistant_responses:
                 for goal_detail in self.text_list:
@@ -478,7 +616,10 @@ class EvaluationPackage:
                         message.content, goal_detail.keywords
                     )
                     semantic_match: bool = self.matcher.semantic_match(
-
+                        self.messages[0].content,
+                        prediction=message.content,
+                        ground_truth=goal_detail.response,
+                        enable_fuzzy_matching=self.enable_fuzzy_matching,
                     )
                     keyword_semantic_match = KeywordSemanticSearchMetric(
                         keyword_match=keyword_match,
@@ -513,6 +654,29 @@ class EvaluationPackage:
         else:
             return TextMatchType.text_mismatch.value

+    def generate_custom_metrics(
+        self, extracted_context: Dict[str, Any]
+    ) -> Optional[CustomEvalMetrics]:
+        if self.custom_evals is None:
+            return None
+
+        results: list[Metric] = []
+        for evaluation in self.custom_evals:
+            # TODO: cleanup. The compute method returns a Metric but pydantic thinks it is different.
+            # Probably because of some path issue when we auto-discover metrics
+            evaluate_result = evaluation.evaluate(
+                messages=self.messages,
+                ground_truth=self.ground_truth,
+                extracted_context=extracted_context,
+            )
+            if evaluate_result is not None:
+                results.append(Metric(**evaluate_result.model_dump()))
+
+        custom_eval_results = CustomEvalMetrics(
+            dataset_name=self.test_case_name, custom_metrics=results
+        )
+        return custom_eval_results
+
     def generate_summary(self):
         llm_steps = 0
         total_step = 0
@@ -525,6 +689,16 @@ class EvaluationPackage:
             message_with_reasons,
         ) = self.traverse()

+        extracted_context = {}
+        if self.extractors is not None and self.custom_evals is not None:
+            for extractor in self.extractors:
+                context = extractor.extract(
+                    messages=self.messages,
+                    ground_truth=self.ground_truth,
+                    matcher=self.matcher,
+                )
+                extracted_context[extractor.name] = context
+
         is_success = self.is_topological_sort(
             self.ground_truth.goals, labelled_messages
         )
@@ -545,6 +719,10 @@ class EvaluationPackage:
         knowledge_base_metric_summary = (
             self.generate_knowledge_base_metric_summary()
         )
+
+        custom_metric_summary = self.generate_custom_metrics(
+            extracted_context=extracted_context
+        )
         # TO-DO: the table is not printing properly anymore with the new columns introduced
         # we need to introduce a separate table for these.

@@ -558,6 +736,7 @@ class EvaluationPackage:
             knowledge_base_metric_summary,
             message_with_reasons,
             metrics,
+            custom_metric_summary,
         )

     def _get_messages_by_role_before_cs(
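generate_summary() now runs any registered extractors over the transcript, feeds the collected context into generate_custom_metrics(), and appends the result to the summary tuple. The diff implies an Extractor with a name and an extract(messages, ground_truth, matcher) method, and an Evaluation with an evaluate(messages, ground_truth, extracted_context) method; the classes below are hypothetical stand-ins written against those call sites, not the real wxo_agentic_evaluation.metrics.evaluations API:

    from typing import Any, Dict, List, Optional

    class WordCountExtractor:
        # Hypothetical extractor: exposes .name and .extract(...), the two
        # things generate_summary() uses in the diff above.
        name = "word_count"

        def extract(self, messages: List[Any], ground_truth: Any, matcher: Any) -> int:
            return sum(len(str(getattr(m, "content", m)).split()) for m in messages)

    class ConcisenessEval:
        # Hypothetical evaluation: .evaluate(...) receives the extracted context
        # keyed by extractor name and returns a metric-like object (or None).
        def evaluate(self, messages, ground_truth, extracted_context: Dict[str, Any]) -> Optional[dict]:
            words = extracted_context.get("word_count", 0)
            return {"name": "conciseness", "value": 1.0 if words <= 200 else 0.0}

    # Minimal wiring mirroring generate_summary(): extract first, then evaluate.
    messages = ["Hello, how can I help?", "Your order #1234 has shipped."]
    extractors = [WordCountExtractor()]
    custom_evals = [ConcisenessEval()]

    extracted_context = {e.name: e.extract(messages, ground_truth=None, matcher=None) for e in extractors}
    results = [ev.evaluate(messages, ground_truth=None, extracted_context=extracted_context) for ev in custom_evals]
    print(extracted_context, results)

In the framework itself the custom evaluation and extractor paths are supplied through the new CustomMetricsConfig and ExtractorsConfig entries in TestConfig shown earlier.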
wxo_agentic_evaluation/external_agent/external_validate.py

@@ -74,7 +74,9 @@ class ExternalAgentValidation:
         payload = {"stream": True}
         payload["messages"] = messages
         resp = requests.post(
-            url=self.service_url,
+            url=self.service_url,
+            headers=self.header,
+            json=payload,
         )
         success, logged_events = self._validate_streaming_response(resp)

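The external_validate.py fix sends the authentication headers and the JSON body along with the streaming request; previously only the URL was passed to requests.post. A minimal sketch of the corrected call shape (the URL and header values below are placeholders, not the framework's real configuration):

    import requests

    # Placeholder values; the real class uses self.service_url and self.header.
    service_url = "https://example.com/chat/completions"
    headers = {"Authorization": "Bearer <token>"}
    payload = {"stream": True, "messages": [{"role": "user", "content": "hi"}]}

    resp = requests.post(url=service_url, headers=headers, json=payload, stream=True)
    for line in resp.iter_lines():
        if line:
            print(line.decode("utf-8"))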