eval-ai-library 0.1.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of eval-ai-library might be problematic.
- eval_ai_library-0.1.0.dist-info/METADATA +753 -0
- eval_ai_library-0.1.0.dist-info/RECORD +34 -0
- eval_ai_library-0.1.0.dist-info/WHEEL +5 -0
- eval_ai_library-0.1.0.dist-info/licenses/LICENSE +21 -0
- eval_ai_library-0.1.0.dist-info/top_level.txt +1 -0
- eval_lib/__init__.py +122 -0
- eval_lib/agent_metrics/__init__.py +12 -0
- eval_lib/agent_metrics/knowledge_retention_metric/knowledge_retention.py +231 -0
- eval_lib/agent_metrics/role_adherence_metric/role_adherence.py +251 -0
- eval_lib/agent_metrics/task_success_metric/task_success_rate.py +347 -0
- eval_lib/agent_metrics/tools_correctness_metric/tool_correctness.py +106 -0
- eval_lib/datagenerator/datagenerator.py +230 -0
- eval_lib/datagenerator/document_loader.py +510 -0
- eval_lib/datagenerator/prompts.py +192 -0
- eval_lib/evaluate.py +335 -0
- eval_lib/evaluation_schema.py +63 -0
- eval_lib/llm_client.py +286 -0
- eval_lib/metric_pattern.py +229 -0
- eval_lib/metrics/__init__.py +25 -0
- eval_lib/metrics/answer_precision_metric/answer_precision.py +405 -0
- eval_lib/metrics/answer_relevancy_metric/answer_relevancy.py +195 -0
- eval_lib/metrics/bias_metric/bias.py +114 -0
- eval_lib/metrics/contextual_precision_metric/contextual_precision.py +102 -0
- eval_lib/metrics/contextual_recall_metric/contextual_recall.py +91 -0
- eval_lib/metrics/contextual_relevancy_metric/contextual_relevancy.py +169 -0
- eval_lib/metrics/custom_metric/custom_eval.py +303 -0
- eval_lib/metrics/faithfulness_metric/faithfulness.py +140 -0
- eval_lib/metrics/geval/geval.py +326 -0
- eval_lib/metrics/restricted_refusal_metric/restricted_refusal.py +102 -0
- eval_lib/metrics/toxicity_metric/toxicity.py +113 -0
- eval_lib/price.py +37 -0
- eval_lib/py.typed +1 -0
- eval_lib/testcases_schema.py +27 -0
- eval_lib/utils.py +99 -0
eval_lib/agent_metrics/task_success_metric/task_success_rate.py
@@ -0,0 +1,347 @@
# task_success_rate.py
"""
Task Success Rate Metric: Evaluates whether the AI assistant successfully helped
the user achieve their goal in a conversation.

Score calculation: Softmax aggregation of success criteria verdicts
"""
import json
from typing import List, Dict, Any, Tuple
from eval_lib.testcases_schema import ConversationalEvalTestCase
from eval_lib.metric_pattern import ConversationalMetricPattern
from eval_lib.llm_client import chat_complete
from eval_lib.utils import score_agg, extract_json_block


# Verdict weights for task completion levels
VERDICT_WEIGHTS = {
    "fully": 1.0,    # Criterion completely satisfied
    "mostly": 0.9,   # Criterion largely satisfied with minor gaps
    "partial": 0.7,  # Criterion partially satisfied
    "minor": 0.3,    # Criterion minimally addressed
    "none": 0.0      # Criterion not satisfied at all
}

# Configuration constants
MAX_CRITERIA = 10
LINK_CRITERION = "The user got the link to the requested resource."


class TaskSuccessRateMetric(ConversationalMetricPattern):
    """
    Evaluates whether an AI assistant successfully helped the user complete
    their intended task across a multi-turn conversation.
    """

    name = "taskSuccessRateMetric"
    template_cls = None

    def __init__(
        self,
        model: str,
        threshold: float = 0.7,
        temperature: float = 1.1,
    ):
        """
        Initialize Task Success Rate metric.

        Args:
            model: LLM model name
            threshold: Success threshold (0.0-1.0)
            temperature: Score aggregation temperature for softmax
        """
        super().__init__(model=model, threshold=threshold)
        self.temperature = temperature

    # ==================== HELPER METHODS ====================

    @staticmethod
    def _render_dialogue(turns) -> str:
        """Convert conversation turns into readable format"""
        return "\n".join(
            f"{i+1}. User: {t.input}\n Assistant: {t.actual_output}"
            for i, t in enumerate(turns)
        )

    @staticmethod
    def _prompt_label_help() -> str:
        """Explanation of task success verdict levels"""
        return """Rate task success criteria satisfaction (worst → best):

none – criterion not satisfied at all
minor – criterion minimally addressed
partial – criterion partially satisfied
mostly – criterion largely satisfied with minor gaps
fully – criterion completely satisfied"""

    @staticmethod
    def _prompt_criteria_few_shot() -> str:
        """Few-shot examples for criteria generation"""
        return """Example 1:
User goal: Order a pizza online
Criteria: [
  "The assistant provided available pizza options.",
  "The user received an order confirmation number."
]

Example 2:
User goal: Reset an email password
Criteria: [
  "The assistant gave a working password-reset link.",
  "The user confirmed they could log in."
]"""

    # ==================== CORE EVALUATION STEPS ====================

    async def _infer_user_goal(self, dialogue: str) -> Tuple[str, float]:
        """
        Infer the user's primary goal from the conversation.

        Args:
            dialogue: Formatted conversation text

        Returns:
            Tuple of (user_goal_description, llm_cost)
        """
        prompt = (
            "You will be shown an ENTIRE conversation between a user and an assistant.\n"
            "Write ONE concise sentence describing the user's PRIMARY GOAL in this conversation.\n\n"
            f"CONVERSATION:\n{dialogue}\n\n"
            "User goal:"
        )

        text, cost = await chat_complete(
            self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0
        )

        return text.strip(), cost or 0.0

    async def _generate_success_criteria(self, goal: str) -> Tuple[List[str], float]:
        """
        Generate concrete success criteria for the user's goal.

        Args:
            goal: The inferred user goal

        Returns:
            Tuple of (criteria_list, llm_cost)
        """
        prompt = (
            f"{self._prompt_criteria_few_shot()}\n\n"
            f"Now do the same for the next case.\n\n"
            f"User goal: {goal}\n\n"
            f"List up to {MAX_CRITERIA} concrete SUCCESS CRITERIA that could realistically be satisfied "
            f"within a brief chat of 2–5 turns. "
            "Then **add** this exact sentence: "
            f"\"{LINK_CRITERION}\"\n\n"
            "Each criterion must be a short, observable statement.\n"
            "Return only a JSON array of strings."
        )

        text, cost = await chat_complete(
            self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0
        )

        try:
            raw_json = extract_json_block(text)
            criteria = json.loads(raw_json)

            if not isinstance(criteria, list):
                raise ValueError("Expected JSON array of criteria")

            # Ensure LINK_CRITERION is included
            if LINK_CRITERION not in criteria:
                criteria.append(LINK_CRITERION)

            # Keep LINK_CRITERION first and limit to MAX_CRITERIA
            if len(criteria) > MAX_CRITERIA:
                criteria = (
                    [LINK_CRITERION] +
                    [c for c in criteria if c != LINK_CRITERION][:MAX_CRITERIA - 1]
                )

            # Truncate to MAX_CRITERIA
            criteria = criteria[:MAX_CRITERIA]

            return criteria, cost or 0.0

        except Exception as e:
            raise RuntimeError(
                f"Failed to parse success criteria: {e}\n{text}")

    async def _generate_verdicts(
        self,
        goal: str,
        criteria: List[str],
        dialogue: str
    ) -> Tuple[List[Dict[str, str]], float, float]:
        """
        Generate verdicts for each success criterion.

        Args:
            goal: The user's goal
            criteria: List of success criteria
            dialogue: Formatted conversation text

        Returns:
            Tuple of (verdicts_list, aggregated_score, llm_cost)
        """
        prompt = (
            f"{self._prompt_label_help()}\n\n"
            f"USER GOAL: {goal}\n\n"
            f"FULL DIALOGUE:\n{dialogue}\n\n"
            f"SUCCESS CRITERIA (as JSON array):\n{json.dumps(criteria, ensure_ascii=False)}\n\n"
            "For **each** criterion, decide how well it is satisfied at the END of the dialogue.\n"
            "Use exactly one of: fully, mostly, partial, minor, none.\n\n"
            "Return JSON array with **exactly the same length and order** as the criteria list:\n"
            "[{\"verdict\":\"fully|mostly|partial|minor|none\",\"reason\":\"<one sentence>\"}, …]\n\n"
            "No extra text."
        )

        text, cost = await chat_complete(
            self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0
        )

        try:
            raw_json = extract_json_block(text)
            verdicts = json.loads(raw_json)

            if not isinstance(verdicts, list):
                raise ValueError("Expected JSON array of verdicts")

            # Ensure verdicts match criteria length
            if len(verdicts) != len(criteria):
                # Pad or truncate to match
                if len(verdicts) < len(criteria):
                    verdicts.extend(
                        [{"verdict": "none", "reason": "Missing evaluation"}] * (len(criteria) - len(verdicts)))
                else:
                    verdicts = verdicts[:len(criteria)]

            # Calculate aggregated score from verdicts
            weights = [VERDICT_WEIGHTS.get(
                v.get("verdict", "none"), 0.0) for v in verdicts]
            score = round(score_agg(weights, temperature=self.temperature), 4)

            return verdicts, score, cost or 0.0

        except Exception as e:
            raise RuntimeError(f"Failed to parse verdicts: {e}\n{text}")

    async def _summarize_verdicts(
        self,
        verdicts: List[Dict[str, str]]
    ) -> Tuple[str, float]:
        """
        Generate concise summary of task success assessment.

        Args:
            verdicts: List of verdict objects with reasons

        Returns:
            Tuple of (summary_text, llm_cost)
        """
        # Take up to 6 most relevant verdicts for summary
        bullets = "\n".join(f"- {v['reason']}" for v in verdicts[:6])

        prompt = (
            "Write a concise (max 2 sentences) overall assessment of task success, "
            "based on these observations:\n\n"
            f"{bullets}\n\n"
            "Summary:"
        )

        text, cost = await chat_complete(
            self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0
        )

        return text.strip(), cost or 0.0

    # ==================== MAIN EVALUATION ====================

    async def evaluate(self, test_case: ConversationalEvalTestCase) -> Dict[str, Any]:
        """
        Evaluate task success rate across conversation turns.

        Steps:
        1. Format dialogue into readable text
        2. Infer user's primary goal from conversation
        3. Generate concrete success criteria for the goal
        4. Generate verdicts for each criterion (fully/mostly/partial/minor/none)
        5. Aggregate verdicts into final score using softmax
        6. Generate summary explanation
        7. Build comprehensive evaluation log

        Args:
            test_case: Conversational test case with multiple turns

        Returns:
            Evaluation results with score, success, reason, cost, and detailed log
        """
        total_cost = 0.0

        # Step 1: Format dialogue
        dialogue_text = self._render_dialogue(test_case.turns)

        # Step 2: Infer user goal
        user_goal, cost = await self._infer_user_goal(dialogue_text)
        total_cost += cost

        # Step 3: Generate success criteria
        success_criteria, cost = await self._generate_success_criteria(user_goal)
        total_cost += cost

        # Step 4: Generate verdicts for each criterion
        verdicts, verdict_score, cost = await self._generate_verdicts(
            user_goal,
            success_criteria,
            dialogue_text
        )
        total_cost += cost

        # Step 5: Generate summary explanation
        summary, cost = await self._summarize_verdicts(verdicts)
        total_cost += cost

        # Step 6: Determine success
        final_score = verdict_score
        success = final_score >= self.threshold

        # Step 7: Build evaluation log
        evaluation_log = {
            "dialogue": dialogue_text,
            "comment_dialogue": "Full conversation text used for task success evaluation.",
            "number_of_turns": len(test_case.turns),
            "comment_number_of_turns": "Total conversation turns analyzed.",
            "user_goal": user_goal,
            "comment_user_goal": "LLM-inferred primary goal the user wanted to achieve.",
            "success_criteria": success_criteria,
            "comment_success_criteria": f"Auto-generated checklist of {len(success_criteria)} observable criteria for task completion.",
            "verdicts": verdicts,
            "comment_verdicts": "LLM-generated verdicts assessing each criterion (fully/mostly/partial/minor/none).",
            "verdict_weights": {i: VERDICT_WEIGHTS.get(v["verdict"], 0.0) for i, v in enumerate(verdicts)},
            "comment_verdict_weights": "Numeric weights assigned to each verdict for score calculation.",
            "final_score": final_score,
            "comment_final_score": f"Softmax aggregation of verdict weights (temperature={self.temperature}).",
            "threshold": self.threshold,
            "success": success,
            "comment_success": "Whether the task success score meets the required threshold.",
            "final_reason": summary,
            "comment_reasoning": "Concise explanation of the overall task completion assessment."
        }

        return {
            "score": final_score,
            "success": success,
            "reason": summary,
            "evaluation_cost": round(total_cost, 6),
            "evaluation_log": evaluation_log
        }
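For orientation, a minimal usage sketch of the metric above follows. It is not part of the package diff: the ConversationalEvalTestCase constructor and turn schema are not shown here, so the keyword arguments, turn objects, and model name are assumptions; only the attributes the metric actually reads (turns, turn.input, turn.actual_output) and the returned keys (score, success, reason, evaluation_cost, evaluation_log) come from the file itself.

# usage_sketch_task_success.py (illustrative only, not part of eval-ai-library)
import asyncio
from types import SimpleNamespace

from eval_lib.testcases_schema import ConversationalEvalTestCase  # constructor kwargs assumed
from eval_lib.agent_metrics.task_success_metric.task_success_rate import TaskSuccessRateMetric


async def main() -> None:
    # Hypothetical turns; the metric only reads .input and .actual_output on each turn.
    turns = [
        SimpleNamespace(input="I need to reset my email password.",
                        actual_output="Here is the reset link: https://example.com/reset"),
        SimpleNamespace(input="Thanks, that worked.",
                        actual_output="Great, glad it's sorted!"),
    ]
    test_case = ConversationalEvalTestCase(turns=turns)  # schema assumed, not shown in this diff

    metric = TaskSuccessRateMetric(model="gpt-4o-mini", threshold=0.7)  # model name is a placeholder
    result = await metric.evaluate(test_case)

    # evaluate() returns score, success, reason, evaluation_cost, and evaluation_log.
    print(result["score"], result["success"])
    print(result["reason"])


asyncio.run(main())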
eval_lib/agent_metrics/tools_correctness_metric/tool_correctness.py
@@ -0,0 +1,106 @@
'''
Tool Correctness Metric: Evaluates whether the correct tools were called
during the execution of an AI agent.
Score calculation: Proportion of expected tools correctly called
'''

from typing import Dict, Any, List
from eval_lib.metric_pattern import MetricPattern
from eval_lib.testcases_schema import EvalTestCase


class ToolCorrectnessMetric(MetricPattern):
    name = "toolCorrectnessMetric"

    def __init__(
        self,
        threshold: float = 0.5,
        evaluation_params: List[str] = [],
        should_exact_match: bool = False,
        should_consider_ordering: bool = False
    ):
        super().__init__(model=None, threshold=threshold)
        self.evaluation_params = evaluation_params
        self.should_exact_match = should_exact_match
        self.should_consider_ordering = should_consider_ordering

    async def evaluate(self, test_case: EvalTestCase) -> Dict[str, Any]:
        self.tools_called = test_case.tools_called or []
        self.expected_tools = test_case.expected_tools or []

        score = self.calculate_score()
        reason = self.generate_reason()

        return {
            "score": score,
            "success": score >= self.threshold,
            "reason": reason,
            "evaluation_cost": 0.0  # No LLM cost for this metric
        }

    def generate_reason(self) -> str:
        called_names = self.tools_called
        expected_names = self.expected_tools

        if self.should_exact_match:
            if self.calculate_exact_match_score() == 1.0:
                return f"Exact match: all expected tools {expected_names} were called exactly."
            else:
                return f"Mismatch: expected {expected_names}, called {called_names}."
        elif self.should_consider_ordering:
            lcs, weighted = self.compute_weighted_lcs()
            if weighted == len(self.expected_tools):
                return "Correct tool usage and order."
            else:
                return f"Incomplete or unordered: expected {expected_names}, got {called_names}."
        else:
            used_expected = set(self.tools_called) & set(self.expected_tools)
            missing = set(self.expected_tools) - used_expected
            if not missing:
                return f"All expected tools {expected_names} were called."
            else:
                return f"Missing tools {list(missing)}. Expected {expected_names}, got {called_names}."

    def calculate_score(self) -> float:
        if self.should_exact_match:
            return self.calculate_exact_match_score()
        elif self.should_consider_ordering:
            _, score = self.compute_weighted_lcs()
            return score / len(self.expected_tools) if self.expected_tools else 0.0
        else:
            return self.calculate_non_exact_match_score()

    def calculate_exact_match_score(self) -> float:
        if len(self.tools_called) != len(self.expected_tools):
            return 0.0
        for i in range(len(self.tools_called)):
            if self.tools_called[i] != self.expected_tools[i]:
                return 0.0
        return 1.0

    def calculate_non_exact_match_score(self) -> float:
        match_count = 0
        used = set()
        for expected in self.expected_tools:
            for i, called in enumerate(self.tools_called):
                if i in used:
                    continue
                if expected == called:
                    match_count += 1
                    used.add(i)
                    break
        return match_count / len(self.expected_tools) if self.expected_tools else 0.0

    def compute_weighted_lcs(self):
        m, n = len(self.expected_tools), len(self.tools_called)
        dp = [[0.0] * (n + 1) for _ in range(m + 1)]

        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if self.expected_tools[i - 1] == self.tools_called[j - 1]:
                    dp[i][j] = dp[i - 1][j - 1] + 1
                else:
                    dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

        score = dp[m][n]
        return [], score
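A similar sketch for the tool-correctness metric above, which makes no LLM calls. The EvalTestCase constructor is not shown in this diff, so its keyword arguments are assumptions; the metric itself only reads tools_called and expected_tools.

# usage_sketch_tool_correctness.py (illustrative only, not part of eval-ai-library)
import asyncio

from eval_lib.testcases_schema import EvalTestCase  # constructor kwargs assumed
from eval_lib.agent_metrics.tools_correctness_metric.tool_correctness import ToolCorrectnessMetric


async def main() -> None:
    test_case = EvalTestCase(
        input="Book a table for two tonight",    # assumed field
        actual_output="Done, booked for 7 pm.",  # assumed field
        tools_called=["search_restaurants", "create_booking"],
        expected_tools=["search_restaurants", "create_booking", "send_confirmation"],
    )

    # Default mode: unordered containment check over the expected tools.
    metric = ToolCorrectnessMetric(threshold=0.5)
    result = await metric.evaluate(test_case)

    print(result["score"])   # 2 of 3 expected tools matched -> ~0.67
    print(result["reason"])  # names the missing tool(s)


asyncio.run(main())

With should_exact_match=True the score is all-or-nothing, and with should_consider_ordering=True it is the longest-common-subsequence length divided by the number of expected tools.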
eval_lib/datagenerator/datagenerator.py
@@ -0,0 +1,230 @@
from typing import List
from eval_lib.llm_client import chat_complete
from .document_loader import load_documents, chunk_documents
import math
from eval_lib.llm_client import get_embeddings
import numpy as np
from .prompts import dataset_generation_prompt, dataset_generation_from_scratch_prompt
from eval_lib.utils import extract_json_block
import asyncio
import random
import json


async def retry_async(fn, *args, retries=4, base_delay=0.6, max_delay=6.0,
                      retriable_statuses=(429, 500, 502, 503, 504),
                      **kwargs):
    """
    fn is a coroutine that may raise exceptions such as:
      - an HTTPException-like error exposing .status_code
      - an Exception whose message contains 'Service Unavailable' or similar
    """
    attempt = 0
    while True:
        try:
            return await fn(*args, **kwargs)
        except Exception as e:
            attempt += 1
            status = getattr(e, "status_code", None)
            msg = str(e).lower()

            retriable = (status in retriable_statuses) or any(
                s in msg for s in ["service unavailable", "temporarily unavailable",
                                   "gateway timeout", "bad gateway", "timeout"])
            if attempt > retries or not retriable:
                raise

            # exponential backoff with jitter
            delay = min(max_delay, base_delay * (2 ** (attempt - 1)))
            delay += random.uniform(0, 0.4)
            await asyncio.sleep(delay)


class DatasetGenerator:

    def __init__(
        self,
        *,
        model: str,
        input_format: str,
        expected_output_format: str,
        agent_description: str,
        test_types: List[str],
        question_length: str = "mixed",
        question_openness: str = "mixed",
        chunk_size: int = 1024,
        chunk_overlap: int = 100,
        temperature: float = 0.3,
        max_rows: int = 10,
        trap_density: float = 0.1,
        language: str = "en",
        max_chunks: int = 30,
        relevance_margin: float = 1.5,
        embedding_model: str = "openai:text-embedding-3-small",
    ):
        self.model = model
        self.input_format = input_format
        self.expected_output_format = expected_output_format
        self.agent_description = agent_description
        self.test_types = test_types
        self.question_length = question_length
        self.question_openness = question_openness
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.temperature = temperature
        self.max_rows = max_rows
        self.trap_density = trap_density
        self.language = language
        self.max_chunks = max_chunks
        self.relevance_margin = relevance_margin
        self.embedding_model = embedding_model

    async def generate_from_scratch(self) -> List[dict]:
        prompt = dataset_generation_from_scratch_prompt(
            max_rows=self.max_rows,
            agent_description=self.agent_description,
            input_format=self.input_format,
            expected_output_format=self.expected_output_format,
            test_types=self.test_types,
            question_length=self.question_length,
            question_openness=self.question_openness,
            trap_density=self.trap_density,
            language=self.language
        )

        raw, _ = await chat_complete(
            llm=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=self.temperature,
        )

        try:
            raw_json = extract_json_block(raw)
            data = json.loads(raw_json)
            assert isinstance(data, list), "not a JSON array"
            return data
        except Exception as exc:
            raise RuntimeError(f"Failed to parse dataset:\n{exc}\n\n{raw}")

    async def generate_from_documents(self, file_paths: List[str]) -> List[dict]:

        docs = load_documents(file_paths)
        doc_chunks = chunk_documents(docs,
                                     chunk_size=self.chunk_size,
                                     chunk_overlap=self.chunk_overlap)

        chunks_text = [d.page_content for d in doc_chunks]
        if not chunks_text:
            raise ValueError("No text extracted from documents.")

        ranked_chunks = await self._rank_chunks_by_relevance(chunks_text)

        total_chunks = len(ranked_chunks)
        rows_per_chunk = max(1, math.ceil(self.max_rows / total_chunks))

        needed_chunks = math.ceil(self.max_rows / rows_per_chunk)
        top_k = min(int(needed_chunks * self.relevance_margin),
                    self.max_chunks)
        selected_chunks = ranked_chunks[:top_k]

        dataset: list[dict] = []

        MAX_PROMPT_CHARS = 24_000

        for chunk in selected_chunks:

            safe_chunk = chunk if len(
                chunk) <= MAX_PROMPT_CHARS else chunk[:MAX_PROMPT_CHARS]

            prompt = dataset_generation_prompt(
                chunk=safe_chunk,
                rows_per_chunk=rows_per_chunk,
                agent_description=self.agent_description,
                input_format=self.input_format,
                expected_output_format=self.expected_output_format,
                test_types=self.test_types,
                question_length=self.question_length,
                question_openness=self.question_openness,
                trap_density=self.trap_density,
                language=self.language
            )

            raw, _ = await retry_async(
                chat_complete,
                llm=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=self.temperature,
            )

            try:
                chunk_data = json.loads(extract_json_block(raw))
                assert isinstance(chunk_data, list)
                dataset.extend(chunk_data)
            except Exception as exc:
                raise RuntimeError(f"Chunk parsing error:\n{exc}\n\n{raw}")

            if len(dataset) >= self.max_rows:
                break

        return dataset[: self.max_rows]

    async def _rank_chunks_by_relevance(self, chunks: list[str]) -> list[str]:
        """
        Rank chunks by cosine similarity between their embeddings and a query
        built from the agent description and test types.
        """
        # rough token estimate (~4 characters per token)
        def approx_tokens(s: str) -> int:
            return max(1, len(s) // 4)

        # restrict length of each chunk for embedding (e.g., to ~8k tokens)
        MAX_EMBED_TOKENS_PER_INPUT = 8000
        MAX_EMBED_CHARS_PER_INPUT = MAX_EMBED_TOKENS_PER_INPUT * 4

        truncated_chunks = [
            c if len(
                c) <= MAX_EMBED_CHARS_PER_INPUT else c[:MAX_EMBED_CHARS_PER_INPUT]
            for c in chunks
        ]

        # limit tokens per request
        TOKEN_BUDGET_PER_REQUEST = 280_000

        # divide into batches by total tokens
        batches: list[list[str]] = []
        cur: list[str] = []
        cur_tokens = 0
        for c in truncated_chunks:
            t = approx_tokens(c)
            if cur and (cur_tokens + t) > TOKEN_BUDGET_PER_REQUEST:
                batches.append(cur)
                cur = [c]
                cur_tokens = t
            else:
                cur.append(c)
                cur_tokens += t
        if cur:
            batches.append(cur)

        # embedding for the query
        query = self.agent_description + " " + " ".join(self.test_types)
        q_vec, _ = await retry_async(get_embeddings, model=self.embedding_model, texts=[query])
        q_vec = q_vec[0]

        # go through the batches, accumulating embeddings
        all_vecs = []
        for batch in batches:
            vecs, _ = await retry_async(get_embeddings, model=self.embedding_model, texts=batch)
            all_vecs.extend(vecs)

        q_norm = np.linalg.norm(q_vec) + 1e-7
        sims = [
            float(np.dot(q_vec, v) / (q_norm * (np.linalg.norm(v) + 1e-7)))
            for v in all_vecs
        ]

        # sort chunks by similarity, most relevant first
        ranked = [c for _, c in sorted(
            zip(sims, chunks), key=lambda x: x[0], reverse=True)]
        return ranked
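Finally, a usage sketch for the dataset generator above. It is not part of the package diff: the file path and model identifiers are placeholders; the embedding-model default in the code suggests a "provider:model" string format, but the exact identifiers accepted by chat_complete are not shown here.

# usage_sketch_datagenerator.py (illustrative only, not part of eval-ai-library)
import asyncio

from eval_lib.datagenerator.datagenerator import DatasetGenerator


async def main() -> None:
    generator = DatasetGenerator(
        model="openai:gpt-4o-mini",          # placeholder chat-model identifier
        input_format="single user question",
        expected_output_format="short factual answer",
        agent_description="Customer-support assistant for a SaaS billing product",
        test_types=["faithfulness", "answer_relevancy"],
        max_rows=10,
    )

    # Chunks the documents, ranks chunks by embedding similarity to the agent
    # description, then prompts the LLM per chunk until max_rows rows exist.
    rows = await generator.generate_from_documents(["docs/billing_faq.pdf"])  # placeholder path
    print(len(rows))


asyncio.run(main())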