eval-ai-library 0.1.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of eval-ai-library might be problematic.

Files changed (34)
  1. eval_ai_library-0.1.0.dist-info/METADATA +753 -0
  2. eval_ai_library-0.1.0.dist-info/RECORD +34 -0
  3. eval_ai_library-0.1.0.dist-info/WHEEL +5 -0
  4. eval_ai_library-0.1.0.dist-info/licenses/LICENSE +21 -0
  5. eval_ai_library-0.1.0.dist-info/top_level.txt +1 -0
  6. eval_lib/__init__.py +122 -0
  7. eval_lib/agent_metrics/__init__.py +12 -0
  8. eval_lib/agent_metrics/knowledge_retention_metric/knowledge_retention.py +231 -0
  9. eval_lib/agent_metrics/role_adherence_metric/role_adherence.py +251 -0
  10. eval_lib/agent_metrics/task_success_metric/task_success_rate.py +347 -0
  11. eval_lib/agent_metrics/tools_correctness_metric/tool_correctness.py +106 -0
  12. eval_lib/datagenerator/datagenerator.py +230 -0
  13. eval_lib/datagenerator/document_loader.py +510 -0
  14. eval_lib/datagenerator/prompts.py +192 -0
  15. eval_lib/evaluate.py +335 -0
  16. eval_lib/evaluation_schema.py +63 -0
  17. eval_lib/llm_client.py +286 -0
  18. eval_lib/metric_pattern.py +229 -0
  19. eval_lib/metrics/__init__.py +25 -0
  20. eval_lib/metrics/answer_precision_metric/answer_precision.py +405 -0
  21. eval_lib/metrics/answer_relevancy_metric/answer_relevancy.py +195 -0
  22. eval_lib/metrics/bias_metric/bias.py +114 -0
  23. eval_lib/metrics/contextual_precision_metric/contextual_precision.py +102 -0
  24. eval_lib/metrics/contextual_recall_metric/contextual_recall.py +91 -0
  25. eval_lib/metrics/contextual_relevancy_metric/contextual_relevancy.py +169 -0
  26. eval_lib/metrics/custom_metric/custom_eval.py +303 -0
  27. eval_lib/metrics/faithfulness_metric/faithfulness.py +140 -0
  28. eval_lib/metrics/geval/geval.py +326 -0
  29. eval_lib/metrics/restricted_refusal_metric/restricted_refusal.py +102 -0
  30. eval_lib/metrics/toxicity_metric/toxicity.py +113 -0
  31. eval_lib/price.py +37 -0
  32. eval_lib/py.typed +1 -0
  33. eval_lib/testcases_schema.py +27 -0
  34. eval_lib/utils.py +99 -0
@@ -0,0 +1,347 @@
+ # task_success_rate.py
+ """
+ Task Success Rate Metric: Evaluates whether the AI assistant successfully helped
+ the user achieve their goal in a conversation.
+
+ Score calculation: Softmax aggregation of success criteria verdicts
+ """
+ import json
+ from typing import List, Dict, Any, Tuple
+ from eval_lib.testcases_schema import ConversationalEvalTestCase
+ from eval_lib.metric_pattern import ConversationalMetricPattern
+ from eval_lib.llm_client import chat_complete
+ from eval_lib.utils import score_agg, extract_json_block
+
+
+ # Verdict weights for task completion levels
+ VERDICT_WEIGHTS = {
+     "fully": 1.0,    # Criterion completely satisfied
+     "mostly": 0.9,   # Criterion largely satisfied with minor gaps
+     "partial": 0.7,  # Criterion partially satisfied
+     "minor": 0.3,    # Criterion minimally addressed
+     "none": 0.0      # Criterion not satisfied at all
+ }
+
+ # Configuration constants
+ MAX_CRITERIA = 10
+ LINK_CRITERION = "The user got the link to the requested resource."
+
+
+ class TaskSuccessRateMetric(ConversationalMetricPattern):
+     """
+     Evaluates whether an AI assistant successfully helped the user complete
+     their intended task across a multi-turn conversation.
+     """
+
+     name = "taskSuccessRateMetric"
+     template_cls = None
+
+     def __init__(
+         self,
+         model: str,
+         threshold: float = 0.7,
+         temperature: float = 1.1,
+     ):
+         """
+         Initialize Task Success Rate metric.
+
+         Args:
+             model: LLM model name
+             threshold: Success threshold (0.0-1.0)
+             temperature: Score aggregation temperature for softmax
+         """
+         super().__init__(model=model, threshold=threshold)
+         self.temperature = temperature
+
+     # ==================== HELPER METHODS ====================
+
+     @staticmethod
+     def _render_dialogue(turns) -> str:
+         """Convert conversation turns into readable format"""
+         return "\n".join(
+             f"{i+1}. User: {t.input}\n Assistant: {t.actual_output}"
+             for i, t in enumerate(turns)
+         )
+
+     @staticmethod
+     def _prompt_label_help() -> str:
+         """Explanation of task success verdict levels"""
+         return """Rate task success criteria satisfaction (worst → best):
+
+ none – criterion not satisfied at all
+ minor – criterion minimally addressed
+ partial – criterion partially satisfied
+ mostly – criterion largely satisfied with minor gaps
+ fully – criterion completely satisfied"""
+
+     @staticmethod
+     def _prompt_criteria_few_shot() -> str:
+         """Few-shot examples for criteria generation"""
+         return """Example 1:
+ User goal: Order a pizza online
+ Criteria: [
+ "The assistant provided available pizza options.",
+ "The user received an order confirmation number."
+ ]
+
+ Example 2:
+ User goal: Reset an email password
+ Criteria: [
+ "The assistant gave a working password-reset link.",
+ "The user confirmed they could log in."
+ ]"""
+
+     # ==================== CORE EVALUATION STEPS ====================
+
+     async def _infer_user_goal(self, dialogue: str) -> Tuple[str, float]:
+         """
+         Infer the user's primary goal from the conversation.
+
+         Args:
+             dialogue: Formatted conversation text
+
+         Returns:
+             Tuple of (user_goal_description, llm_cost)
+         """
+         prompt = (
+             "You will be shown an ENTIRE conversation between a user and an assistant.\n"
+             "Write ONE concise sentence describing the user's PRIMARY GOAL in this conversation.\n\n"
+             f"CONVERSATION:\n{dialogue}\n\n"
+             "User goal:"
+         )
+
+         text, cost = await chat_complete(
+             self.model,
+             messages=[{"role": "user", "content": prompt}],
+             temperature=0.0
+         )
+
+         return text.strip(), cost or 0.0
+
+     async def _generate_success_criteria(self, goal: str) -> Tuple[List[str], float]:
+         """
+         Generate concrete success criteria for the user's goal.
+
+         Args:
+             goal: The inferred user goal
+
+         Returns:
+             Tuple of (criteria_list, llm_cost)
+         """
+         prompt = (
+             f"{self._prompt_criteria_few_shot()}\n\n"
+             f"Now do the same for the next case.\n\n"
+             f"User goal: {goal}\n\n"
+             f"List up to {MAX_CRITERIA} concrete SUCCESS CRITERIA that could realistically be satisfied "
+             f"within a brief chat of 2–5 turns. "
+             "Then **add** this exact sentence: "
+             f"\"{LINK_CRITERION}\"\n\n"
+             "Each criterion must be a short, observable statement.\n"
+             "Return only a JSON array of strings."
+         )
+
+         text, cost = await chat_complete(
+             self.model,
+             messages=[{"role": "user", "content": prompt}],
+             temperature=0.0
+         )
+
+         try:
+             raw_json = extract_json_block(text)
+             criteria = json.loads(raw_json)
+
+             if not isinstance(criteria, list):
+                 raise ValueError("Expected JSON array of criteria")
+
+             # Ensure LINK_CRITERION is included
+             if LINK_CRITERION not in criteria:
+                 criteria.append(LINK_CRITERION)
+
+             # Keep LINK_CRITERION first and limit to MAX_CRITERIA
+             if len(criteria) > MAX_CRITERIA:
+                 criteria = (
+                     [LINK_CRITERION] +
+                     [c for c in criteria if c != LINK_CRITERION][:MAX_CRITERIA - 1]
+                 )
+
+             # Truncate to MAX_CRITERIA
+             criteria = criteria[:MAX_CRITERIA]
+
+             return criteria, cost or 0.0
+
+         except Exception as e:
+             raise RuntimeError(
+                 f"Failed to parse success criteria: {e}\n{text}")
+
+     async def _generate_verdicts(
+         self,
+         goal: str,
+         criteria: List[str],
+         dialogue: str
+     ) -> Tuple[List[Dict[str, str]], float, float]:
+         """
+         Generate verdicts for each success criterion.
+
+         Args:
+             goal: The user's goal
+             criteria: List of success criteria
+             dialogue: Formatted conversation text
+
+         Returns:
+             Tuple of (verdicts_list, aggregated_score, llm_cost)
+         """
+         prompt = (
+             f"{self._prompt_label_help()}\n\n"
+             f"USER GOAL: {goal}\n\n"
+             f"FULL DIALOGUE:\n{dialogue}\n\n"
+             f"SUCCESS CRITERIA (as JSON array):\n{json.dumps(criteria, ensure_ascii=False)}\n\n"
+             "For **each** criterion, decide how well it is satisfied at the END of the dialogue.\n"
+             "Use exactly one of: fully, mostly, partial, minor, none.\n\n"
+             "Return JSON array with **exactly the same length and order** as the criteria list:\n"
+             "[{\"verdict\":\"fully|mostly|partial|minor|none\",\"reason\":\"<one sentence>\"}, …]\n\n"
+             "No extra text."
+         )
+
+         text, cost = await chat_complete(
+             self.model,
+             messages=[{"role": "user", "content": prompt}],
+             temperature=0.0
+         )
+
+         try:
+             raw_json = extract_json_block(text)
+             verdicts = json.loads(raw_json)
+
+             if not isinstance(verdicts, list):
+                 raise ValueError("Expected JSON array of verdicts")
+
+             # Ensure verdicts match criteria length
+             if len(verdicts) != len(criteria):
+                 # Pad or truncate to match
+                 if len(verdicts) < len(criteria):
+                     verdicts.extend(
+                         [{"verdict": "none", "reason": "Missing evaluation"}] * (len(criteria) - len(verdicts)))
+                 else:
+                     verdicts = verdicts[:len(criteria)]
+
+             # Calculate aggregated score from verdicts
+             weights = [VERDICT_WEIGHTS.get(
+                 v.get("verdict", "none"), 0.0) for v in verdicts]
+             score = round(score_agg(weights, temperature=self.temperature), 4)
+
+             return verdicts, score, cost or 0.0
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to parse verdicts: {e}\n{text}")
+
+     async def _summarize_verdicts(
+         self,
+         verdicts: List[Dict[str, str]]
+     ) -> Tuple[str, float]:
+         """
+         Generate concise summary of task success assessment.
+
+         Args:
+             verdicts: List of verdict objects with reasons
+
+         Returns:
+             Tuple of (summary_text, llm_cost)
+         """
+         # Take up to 6 most relevant verdicts for summary
+         bullets = "\n".join(f"- {v['reason']}" for v in verdicts[:6])
+
+         prompt = (
+             "Write a concise (max 2 sentences) overall assessment of task success, "
+             "based on these observations:\n\n"
+             f"{bullets}\n\n"
+             "Summary:"
+         )
+
+         text, cost = await chat_complete(
+             self.model,
+             messages=[{"role": "user", "content": prompt}],
+             temperature=0.0
+         )
+
+         return text.strip(), cost or 0.0
+
+     # ==================== MAIN EVALUATION ====================
+
+     async def evaluate(self, test_case: ConversationalEvalTestCase) -> Dict[str, Any]:
+         """
+         Evaluate task success rate across conversation turns.
+
+         Steps:
+         1. Format dialogue into readable text
+         2. Infer user's primary goal from conversation
+         3. Generate concrete success criteria for the goal
+         4. Generate verdicts for each criterion (fully/mostly/partial/minor/none)
+         5. Aggregate verdicts into final score using softmax
+         6. Generate summary explanation
+         7. Build comprehensive evaluation log
+
+         Args:
+             test_case: Conversational test case with multiple turns
+
+         Returns:
+             Evaluation results with score, success, reason, cost, and detailed log
+         """
+         total_cost = 0.0
+
+         # Step 1: Format dialogue
+         dialogue_text = self._render_dialogue(test_case.turns)
+
+         # Step 2: Infer user goal
+         user_goal, cost = await self._infer_user_goal(dialogue_text)
+         total_cost += cost
+
+         # Step 3: Generate success criteria
+         success_criteria, cost = await self._generate_success_criteria(user_goal)
+         total_cost += cost
+
+         # Step 4: Generate verdicts for each criterion
+         verdicts, verdict_score, cost = await self._generate_verdicts(
+             user_goal,
+             success_criteria,
+             dialogue_text
+         )
+         total_cost += cost
+
+         # Step 5: Generate summary explanation
+         summary, cost = await self._summarize_verdicts(verdicts)
+         total_cost += cost
+
+         # Step 6: Determine success
+         final_score = verdict_score
+         success = final_score >= self.threshold
+
+         # Step 7: Build evaluation log
+         evaluation_log = {
+             "dialogue": dialogue_text,
+             "comment_dialogue": "Full conversation text used for task success evaluation.",
+             "number_of_turns": len(test_case.turns),
+             "comment_number_of_turns": "Total conversation turns analyzed.",
+             "user_goal": user_goal,
+             "comment_user_goal": "LLM-inferred primary goal the user wanted to achieve.",
+             "success_criteria": success_criteria,
+             "comment_success_criteria": f"Auto-generated checklist of {len(success_criteria)} observable criteria for task completion.",
+             "verdicts": verdicts,
+             "comment_verdicts": "LLM-generated verdicts assessing each criterion (fully/mostly/partial/minor/none).",
+             "verdict_weights": {i: VERDICT_WEIGHTS.get(v["verdict"], 0.0) for i, v in enumerate(verdicts)},
+             "comment_verdict_weights": "Numeric weights assigned to each verdict for score calculation.",
+             "final_score": final_score,
+             "comment_final_score": f"Softmax aggregation of verdict weights (temperature={self.temperature}).",
+             "threshold": self.threshold,
+             "success": success,
+             "comment_success": "Whether the task success score meets the required threshold.",
+             "final_reason": summary,
+             "comment_reasoning": "Concise explanation of the overall task completion assessment."
+         }
+
+         return {
+             "score": final_score,
+             "success": success,
+             "reason": summary,
+             "evaluation_cost": round(total_cost, 6),
+             "evaluation_log": evaluation_log
+         }
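
Note on the score calculation: the aggregation above calls score_agg from eval_lib/utils.py, which is not part of this hunk, so the exact formula is not visible here. As rough orientation only, here is a minimal sketch of one plausible temperature-controlled softmax aggregation over the verdict weights; the name softmax_score_agg and the formula are assumptions, not the package's code.

import math
from typing import List

def softmax_score_agg(weights: List[float], temperature: float = 1.1) -> float:
    # Hypothetical stand-in for eval_lib.utils.score_agg: softmax-normalize the
    # verdict weights (scaled by temperature) and return the expectation of the
    # weights under that distribution, so strong verdicts count for more than
    # they would in a plain average.
    if not weights:
        return 0.0
    exps = [math.exp(w / temperature) for w in weights]
    total = sum(exps)
    return sum((e / total) * w for e, w in zip(exps, weights))

# Example: verdicts ["fully", "mostly", "none"] map to weights [1.0, 0.9, 0.0];
# softmax_score_agg([1.0, 0.9, 0.0], temperature=1.1) gives roughly 0.79,
# versus a plain mean of roughly 0.63.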
@@ -0,0 +1,106 @@
+ '''
+ Tool Correctness Metric: Evaluates whether the correct tools were called
+ during the execution of an AI agent.
+ Score calculation: Proportion of expected tools correctly called
+ '''
+
+ from typing import Dict, Any, List
+ from eval_lib.metric_pattern import MetricPattern
+ from eval_lib.testcases_schema import EvalTestCase
+
+
+ class ToolCorrectnessMetric(MetricPattern):
+     name = "toolCorrectnessMetric"
+
+     def __init__(
+         self,
+         threshold: float = 0.5,
+         evaluation_params: List[str] = [],
+         should_exact_match: bool = False,
+         should_consider_ordering: bool = False
+     ):
+         super().__init__(model=None, threshold=threshold)
+         self.evaluation_params = evaluation_params
+         self.should_exact_match = should_exact_match
+         self.should_consider_ordering = should_consider_ordering
+
+     async def evaluate(self, test_case: EvalTestCase) -> Dict[str, Any]:
+         self.tools_called = test_case.tools_called or []
+         self.expected_tools = test_case.expected_tools or []
+
+         score = self.calculate_score()
+         reason = self.generate_reason()
+
+         return {
+             "score": score,
+             "success": score >= self.threshold,
+             "reason": reason,
+             "evaluation_cost": 0.0  # No LLM cost for this metric
+         }
+
+     def generate_reason(self) -> str:
+         called_names = self.tools_called
+         expected_names = self.expected_tools
+
+         if self.should_exact_match:
+             if self.calculate_exact_match_score() == 1.0:
+                 return f"Exact match: all expected tools {expected_names} were called exactly."
+             else:
+                 return f"Mismatch: expected {expected_names}, called {called_names}."
+         elif self.should_consider_ordering:
+             lcs, weighted = self.compute_weighted_lcs()
+             if weighted == len(self.expected_tools):
+                 return "Correct tool usage and order."
+             else:
+                 return f"Incomplete or unordered: expected {expected_names}, got {called_names}."
+         else:
+             used_expected = set(self.tools_called) & set(self.expected_tools)
+             missing = set(self.expected_tools) - used_expected
+             if not missing:
+                 return f"All expected tools {expected_names} were called."
+             else:
+                 return f"Missing tools {list(missing)}. Expected {expected_names}, got {called_names}."
+
+     def calculate_score(self) -> float:
+         if self.should_exact_match:
+             return self.calculate_exact_match_score()
+         elif self.should_consider_ordering:
+             _, score = self.compute_weighted_lcs()
+             return score / len(self.expected_tools) if self.expected_tools else 0.0
+         else:
+             return self.calculate_non_exact_match_score()
+
+     def calculate_exact_match_score(self) -> float:
+         if len(self.tools_called) != len(self.expected_tools):
+             return 0.0
+         for i in range(len(self.tools_called)):
+             if self.tools_called[i] != self.expected_tools[i]:
+                 return 0.0
+         return 1.0
+
+     def calculate_non_exact_match_score(self) -> float:
+         match_count = 0
+         used = set()
+         for expected in self.expected_tools:
+             for i, called in enumerate(self.tools_called):
+                 if i in used:
+                     continue
+                 if expected == called:
+                     match_count += 1
+                     used.add(i)
+                     break
+         return match_count / len(self.expected_tools) if self.expected_tools else 0.0
+
+     def compute_weighted_lcs(self):
+         m, n = len(self.expected_tools), len(self.tools_called)
+         dp = [[0.0] * (n + 1) for _ in range(m + 1)]
+
+         for i in range(1, m + 1):
+             for j in range(1, n + 1):
+                 if self.expected_tools[i - 1] == self.tools_called[j - 1]:
+                     dp[i][j] = dp[i - 1][j - 1] + 1
+                 else:
+                     dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
+
+         score = dp[m][n]
+         return [], score
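
For reference, a worked example of the three scoring modes defined above on a single call sequence. It sets tools_called and expected_tools directly (exactly as evaluate() does) because the full EvalTestCase constructor is not part of this hunk; the import path is taken from the file list above, and the tool names are illustrative.

from eval_lib.agent_metrics.tools_correctness_metric.tool_correctness import ToolCorrectnessMetric

expected = ["search_docs", "summarize", "send_email"]
called = ["search_docs", "send_email", "summarize"]

# exact match: same length, but order differs at index 1           -> 0.0
# unordered overlap: all 3 expected names appear among the calls   -> 3/3 = 1.0
# ordering-aware: the longest common subsequence has length 2      -> 2/3 ~ 0.67

metric = ToolCorrectnessMetric(threshold=0.5, should_consider_ordering=True)
metric.tools_called, metric.expected_tools = called, expected
print(metric.calculate_score())   # 0.666...
print(metric.generate_reason())   # "Incomplete or unordered: expected [...], got [...]"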
@@ -0,0 +1,230 @@
+ from typing import List
+ from eval_lib.llm_client import chat_complete
+ from .document_loader import load_documents, chunk_documents
+ import math
+ from eval_lib.llm_client import get_embeddings
+ import numpy as np
+ from .prompts import dataset_generation_prompt, dataset_generation_from_scratch_prompt
+ from eval_lib.utils import extract_json_block
+ import asyncio
+ import random
+ import json
+
+
+ async def retry_async(fn, *args, retries=4, base_delay=0.6, max_delay=6.0,
+                       retriable_statuses=(429, 500, 502, 503, 504),
+                       **kwargs):
+     """
+     fn: a coroutine that may raise exceptions such as
+       - an HTTPException-like error exposing .status_code
+       - an Exception whose message contains 'Service Unavailable' or similar
+     """
+     attempt = 0
+     while True:
+         try:
+             return await fn(*args, **kwargs)
+         except Exception as e:
+             attempt += 1
+             status = getattr(e, "status_code", None)
+             msg = str(e).lower()
+
+             retriable = (status in retriable_statuses) or any(
+                 s in msg for s in ["service unavailable", "temporarily unavailable",
+                                    "gateway timeout", "bad gateway", "timeout"])
+             if attempt > retries or not retriable:
+                 raise
+
+             # exponential backoff + jitter
+             delay = min(max_delay, base_delay * (2 ** (attempt - 1)))
+             delay += random.uniform(0, 0.4)
+             await asyncio.sleep(delay)
+
+
+ class DatasetGenerator:
+
+     def __init__(
+         self,
+         *,
+         model: str,
+         input_format: str,
+         expected_output_format: str,
+         agent_description: str,
+         test_types: List[str],
+         question_length: str = "mixed",
+         question_openness: str = "mixed",
+         chunk_size: int = 1024,
+         chunk_overlap: int = 100,
+         temperature: float = 0.3,
+         max_rows: int = 10,
+         trap_density: float = 0.1,
+         language: str = "en",
+         max_chunks: int = 30,
+         relevance_margin: float = 1.5,
+         embedding_model: str = "openai:text-embedding-3-small",
+     ):
+         self.model = model
+         self.input_format = input_format
+         self.expected_output_format = expected_output_format
+         self.agent_description = agent_description
+         self.test_types = test_types
+         self.question_length = question_length
+         self.question_openness = question_openness
+         self.chunk_size = chunk_size
+         self.chunk_overlap = chunk_overlap
+         self.temperature = temperature
+         self.max_rows = max_rows
+         self.trap_density = trap_density
+         self.language = language
+         self.max_chunks = max_chunks
+         self.relevance_margin = relevance_margin
+         self.embedding_model = embedding_model
+
+     async def generate_from_scratch(self) -> List[dict]:
+         prompt = dataset_generation_from_scratch_prompt(
+             max_rows=self.max_rows,
+             agent_description=self.agent_description,
+             input_format=self.input_format,
+             expected_output_format=self.expected_output_format,
+             test_types=self.test_types,
+             question_length=self.question_length,
+             question_openness=self.question_openness,
+             trap_density=self.trap_density,
+             language=self.language
+         )
+
+         raw, _ = await chat_complete(
+             llm=self.model,
+             messages=[{"role": "user", "content": prompt}],
+             temperature=self.temperature,
+         )
+
+         try:
+             raw_json = extract_json_block(raw)
+             data = json.loads(raw_json)
+             assert isinstance(data, list), "not a JSON array"
+             return data
+         except Exception as exc:
+             raise RuntimeError(f"Failed to parse dataset:\n{exc}\n\n{raw}")
+
+     async def generate_from_documents(self, file_paths: List[str]) -> List[dict]:
+
+         docs = load_documents(file_paths)
+         doc_chunks = chunk_documents(docs,
+                                      chunk_size=self.chunk_size,
+                                      chunk_overlap=self.chunk_overlap)
+
+         chunks_text = [d.page_content for d in doc_chunks]
+         if not chunks_text:
+             raise ValueError("No text extracted from documents.")
+
+         ranked_chunks = await self._rank_chunks_by_relevance(chunks_text)
+
+         total_chunks = len(ranked_chunks)
+         rows_per_chunk = max(1, math.ceil(self.max_rows / total_chunks))
+
+         needed_chunks = math.ceil(self.max_rows / rows_per_chunk)
+         top_k = min(int(needed_chunks * self.relevance_margin),
+                     self.max_chunks)
+         selected_chunks = ranked_chunks[:top_k]
+
+         dataset: list[dict] = []
+
+         MAX_PROMPT_CHARS = 24_000
+
+         for chunk in selected_chunks:
+
+             safe_chunk = chunk if len(
+                 chunk) <= MAX_PROMPT_CHARS else chunk[:MAX_PROMPT_CHARS]
+
+             prompt = dataset_generation_prompt(
+                 chunk=safe_chunk,
+                 rows_per_chunk=rows_per_chunk,
+                 agent_description=self.agent_description,
+                 input_format=self.input_format,
+                 expected_output_format=self.expected_output_format,
+                 test_types=self.test_types,
+                 question_length=self.question_length,
+                 question_openness=self.question_openness,
+                 trap_density=self.trap_density,
+                 language=self.language
+             )
+
+             raw, _ = await retry_async(
+                 chat_complete,
+                 llm=self.model,
+                 messages=[{"role": "user", "content": prompt}],
+                 temperature=self.temperature,
+             )
+
+             try:
+                 chunk_data = json.loads(extract_json_block(raw))
+                 assert isinstance(chunk_data, list)
+                 dataset.extend(chunk_data)
+             except Exception as exc:
+                 raise RuntimeError(f"Chunk parsing error:\n{exc}\n\n{raw}")
+
+             if len(dataset) >= self.max_rows:
+                 break
+
+         return dataset[: self.max_rows]
+
+     async def _rank_chunks_by_relevance(self, chunks: list[str]) -> list[str]:
+         """
+         Rank chunks by cosine similarity between their embeddings and a query
+         built from the agent description and test types.
+         """
+         # rough token estimate (about 4 characters per token)
+         def approx_tokens(s: str) -> int:
+             return max(1, len(s) // 4)
+
+         # restrict length of each chunk for embedding (e.g., to ~8k tokens)
+         MAX_EMBED_TOKENS_PER_INPUT = 8000
+         MAX_EMBED_CHARS_PER_INPUT = MAX_EMBED_TOKENS_PER_INPUT * 4
+
+         truncated_chunks = [
+             c if len(
+                 c) <= MAX_EMBED_CHARS_PER_INPUT else c[:MAX_EMBED_CHARS_PER_INPUT]
+             for c in chunks
+         ]
+
+         # limit tokens per request
+         TOKEN_BUDGET_PER_REQUEST = 280_000
+
+         # divide into batches by total tokens
+         batches: list[list[str]] = []
+         cur: list[str] = []
+         cur_tokens = 0
+         for c in truncated_chunks:
+             t = approx_tokens(c)
+             if cur and (cur_tokens + t) > TOKEN_BUDGET_PER_REQUEST:
+                 batches.append(cur)
+                 cur = [c]
+                 cur_tokens = t
+             else:
+                 cur.append(c)
+                 cur_tokens += t
+         if cur:
+             batches.append(cur)
+
+         # embedding for the query
+         query = self.agent_description + " " + " ".join(self.test_types)
+         q_vec, _ = await retry_async(get_embeddings, model=self.embedding_model, texts=[query])
+         q_vec = q_vec[0]
+
+         # go through batches, accumulating embeddings
+         all_vecs = []
+         for batch in batches:
+             vecs, _ = await retry_async(get_embeddings, model=self.embedding_model, texts=batch)
+             all_vecs.extend(vecs)
+
+         # cosine similarity between the query and each chunk embedding
+         q_norm = np.linalg.norm(q_vec) + 1e-7
+         sims = [
+             float(np.dot(q_vec, v) / (q_norm * (np.linalg.norm(v) + 1e-7)))
+             for v in all_vecs
+         ]
+
+         # sort chunks by similarity, most relevant first
+         ranked = [c for _, c in sorted(
+             zip(sims, chunks), key=lambda x: x[0], reverse=True)]
+         return ranked
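
Finally, a hedged usage sketch of DatasetGenerator, based only on the constructor and methods shown in this hunk. The model string, file path, and test_types values are illustrative placeholders; the exact model identifier format accepted by chat_complete lives in eval_lib/llm_client.py, which is not shown here.

import asyncio
from eval_lib.datagenerator.datagenerator import DatasetGenerator

async def main():
    gen = DatasetGenerator(
        model="openai:gpt-4o-mini",          # placeholder; format guessed from the embedding_model default
        input_format="single user question",
        expected_output_format="short factual answer",
        agent_description="Support assistant for a billing FAQ",
        test_types=["factual", "trap"],
        max_rows=10,
    )
    # Chunks the documents, ranks chunks by embedding similarity to the agent
    # description, then asks the LLM for rows_per_chunk dataset rows per chunk.
    rows = await gen.generate_from_documents(["docs/billing_faq.pdf"])
    print(len(rows), rows[:1])

asyncio.run(main())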