floeval 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- floeval/__init__.py +5 -0
- floeval/api/__init__.py +8 -0
- floeval/api/agent_evaluation.py +318 -0
- floeval/api/dataset.py +149 -0
- floeval/api/dataset_loaders/__init__.py +0 -0
- floeval/api/dataset_loaders/agent_file_loader.py +169 -0
- floeval/api/dataset_loaders/base.py +29 -0
- floeval/api/dataset_loaders/local_file_loader.py +69 -0
- floeval/api/evaluation.py +650 -0
- floeval/api/metrics/__init__.py +8 -0
- floeval/api/metrics/base.py +33 -0
- floeval/api/metrics/custom/__init__.py +17 -0
- floeval/api/metrics/custom/context.py +75 -0
- floeval/api/metrics/custom/criteria.py +264 -0
- floeval/api/metrics/custom/decorator.py +223 -0
- floeval/api/metrics/custom/llm_helper.py +171 -0
- floeval/api/metrics/registry.py +170 -0
- floeval/cli/__init__.py +28 -0
- floeval/cli/export.py +19 -0
- floeval/cli/main.py +86 -0
- floeval/cli/parse_evaluate.py +279 -0
- floeval/cli/parse_generate.py +73 -0
- floeval/cli/utils.py +88 -0
- floeval/config/__init__.py +1 -0
- floeval/config/schemas/__init__.py +0 -0
- floeval/config/schemas/deepeval/__init__.py +33 -0
- floeval/config/schemas/io/__init__.py +0 -0
- floeval/config/schemas/io/agent_dataset.py +240 -0
- floeval/config/schemas/io/dataset.py +106 -0
- floeval/config/schemas/io/llm.py +102 -0
- floeval/config/schemas/prompts.py +19 -0
- floeval/core/__init__.py +3 -0
- floeval/core/execution/__init__.py +3 -0
- floeval/core/execution/base.py +45 -0
- floeval/core/execution/llm_executor.py +199 -0
- floeval/core/execution/response_synthesizer.py +178 -0
- floeval/core/execution/trace.py +16 -0
- floeval/flotorch/__init__.py +55 -0
- floeval/flotorch/adk/__init__.py +32 -0
- floeval/flotorch/adk/agent.py +174 -0
- floeval/flotorch/adk/llm.py +134 -0
- floeval/flotorch/adk/memory.py +239 -0
- floeval/flotorch/adk/sessions.py +18 -0
- floeval/flotorch/adk/utils/__init__.py +29 -0
- floeval/flotorch/adk/utils/adk_utils.py +353 -0
- floeval/flotorch/adk/utils/warning_utils.py +64 -0
- floeval/flotorch/runner.py +122 -0
- floeval/flotorch/sdk/__init__.py +12 -0
- floeval/flotorch/sdk/llm.py +86 -0
- floeval/flotorch/sdk/memory.py +121 -0
- floeval/flotorch/sdk/utils/__init__.py +9 -0
- floeval/flotorch/sdk/utils/http_utils.py +120 -0
- floeval/flotorch/sdk/utils/memory_utils.py +254 -0
- floeval/metric_providers/__init__.py +12 -0
- floeval/metric_providers/builtin/__init__.py +7 -0
- floeval/metric_providers/builtin/agent_metrics.py +200 -0
- floeval/metric_providers/builtin/metrics.py +35 -0
- floeval/metric_providers/deepeval/__init__.py +36 -0
- floeval/metric_providers/deepeval/adapter.py +182 -0
- floeval/metric_providers/deepeval/custom_adapter.py +314 -0
- floeval/metric_providers/deepeval/metrics.py +339 -0
- floeval/metric_providers/ragas/__init__.py +54 -0
- floeval/metric_providers/ragas/adapter.py +245 -0
- floeval/metric_providers/ragas/agent_metrics.py +181 -0
- floeval/metric_providers/ragas/custom_adapter.py +184 -0
- floeval/metric_providers/ragas/metrics.py +553 -0
- floeval/utils/__init__.py +7 -0
- floeval/utils/agent_trace/__init__.py +31 -0
- floeval/utils/agent_trace/decorator.py +157 -0
- floeval/utils/agent_trace/helpers.py +107 -0
- floeval/utils/agent_trace/langchain_adapter.py +63 -0
- floeval/utils/agent_trace/patchers/__init__.py +1 -0
- floeval/utils/agent_trace/patchers/langchain_callback.py +73 -0
- floeval/utils/agent_trace/patchers/openai_patcher.py +118 -0
- floeval/utils/agent_trace/trace_collector.py +160 -0
- floeval/utils/agent_trace/trace_context.py +73 -0
- floeval/utils/loaders.py +82 -0
- floeval/utils/ragas_results.py +49 -0
- floeval-0.1.0b1.dist-info/METADATA +146 -0
- floeval-0.1.0b1.dist-info/RECORD +84 -0
- floeval-0.1.0b1.dist-info/WHEEL +5 -0
- floeval-0.1.0b1.dist-info/entry_points.txt +2 -0
- floeval-0.1.0b1.dist-info/licenses/LICENSE +201 -0
- floeval-0.1.0b1.dist-info/top_level.txt +1 -0
floeval/__init__.py
ADDED
floeval/api/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
"""Public API for Floeval - Evaluation Framework."""
|
|
2
|
+
|
|
3
|
+
from floeval.api.dataset import DatasetLoader
|
|
4
|
+
from floeval.api.evaluation import Evaluation
|
|
5
|
+
from floeval.api.metrics.registry import MetricRegistry
|
|
6
|
+
from floeval.config.schemas.io.dataset import Dataset
|
|
7
|
+
|
|
8
|
+
__all__ = ["Evaluation", "Dataset", "DatasetLoader", "MetricRegistry"]
|
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
"""Agent evaluation orchestrator."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import inspect
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Any, Awaitable, Callable, Mapping
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel, Field
|
|
12
|
+
|
|
13
|
+
import floeval.metric_providers # noqa: F401 - trigger metric registration
|
|
14
|
+
from floeval.api.metrics.base import BaseMetric, MetricResult
|
|
15
|
+
from floeval.api.metrics.registry import MetricRegistry
|
|
16
|
+
from floeval.config.schemas.io.agent_dataset import (
|
|
17
|
+
AgentDataset,
|
|
18
|
+
AgentSample,
|
|
19
|
+
_to_display_str,
|
|
20
|
+
)
|
|
21
|
+
from floeval.config.schemas.io.llm import OpenAIProviderConfig
|
|
22
|
+
from floeval.core.execution.llm_executor import OpenAIProvider
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
MetricSpec = BaseMetric | str | dict[str, Any]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class AgentEvaluationResult(BaseModel):
    """Outcome of one agent evaluation run.

    Both fields default to empty containers, so an instance can be created
    without arguments.
    """

    # One dict per evaluated sample.
    sample_results: list[dict[str, Any]] = Field(default_factory=list)
    # Run-level aggregate information.
    summary: dict[str, Any] = Field(default_factory=dict)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class AgentEvaluation:
    """Agent evaluation orchestrator.

    Supports Mode 1 (pre-captured traces), Mode 2 (partial + agent callable),
    Mode 4 (partial + agent_runner e.g. FloTorchRunner).

    Metric specs are resolved to metric instances once, in ``__init__``.
    ``run`` / ``arun`` then evaluate every full sample with every resolved
    metric, one after the other.
    """

    # Params stripped when a metric factory rejects the kwargs it was given.
    _INJECTED_PARAM_KEYS = ("llm_config", "llm_provider", "adapter")

    def __init__(
        self,
        dataset: AgentDataset,
        metrics: list[MetricSpec],
        llm_config: OpenAIProviderConfig | None = None,
        agent: (
            Callable[[str], str | Any]
            | Callable[[str], Awaitable[str | Any]]
            | None
        ) = None,
        agent_runner: Any | None = None,
        default_provider: str | None = "builtin",
        metric_params: Mapping[str, dict[str, Any]] | None = None,
    ):
        """Store the configuration and resolve ``metrics`` to instances.

        Args:
            dataset: Samples to evaluate (full, partial, or both).
            metrics: Metric instances, ``"metric"`` / ``"provider:metric"``
                strings, or dicts with ``id`` and optional ``provider`` /
                ``params`` keys.
            llm_config: LLM settings handed to metrics that accept them.
            agent: Callable used to produce traces for partial samples.
            agent_runner: Runner object used to produce traces for partial
                samples; takes precedence over ``agent``.
            default_provider: Provider used when a spec names none; a falsy
                value falls back to ``"builtin"``.
            metric_params: Default constructor params keyed by ``"metric"``
                or ``"provider:metric"``.

        Raises:
            ValueError: A dict spec has no ``id``.
            TypeError: A spec is not a metric, dict, or string, or a metric
                factory rejects its params.
        """
        self.dataset = dataset
        self.metrics = metrics
        self.llm_config = llm_config
        self.agent = agent
        self.agent_runner = agent_runner
        self.default_provider = default_provider or "builtin"
        self.metric_params = dict(metric_params or {})
        self._registry = MetricRegistry()
        self._resolved_metrics = self._resolve_metrics(metrics)

    def _resolve_metrics(self, specs: list[MetricSpec]) -> list[BaseMetric]:
        """Resolve metric specs to instances, preserving their order.

        Raises:
            ValueError: A dict spec has no ``id``.
            TypeError: A spec is of an unsupported type.
        """
        resolved: list[BaseMetric] = []

        for spec in specs:
            if isinstance(spec, BaseMetric):
                # Ready-made instance: only backfill llm_config when the
                # metric exposes the attribute and left it unset.
                if (
                    self.llm_config is not None
                    and hasattr(spec, "llm_config")
                    and spec.llm_config is None
                ):
                    spec.llm_config = self.llm_config
                resolved.append(spec)
            elif isinstance(spec, dict):
                metric_id = spec.get("id")
                if not metric_id:
                    raise ValueError("Metric dict spec must include 'id'")
                # NOTE(review): default_provider is always truthy after
                # __init__, so the resolve_best() fallback below is never
                # reached; string specs DO go through resolve_best(). Kept
                # as-is to preserve behavior - confirm which is intended.
                provider = (
                    spec.get("provider")
                    or self.default_provider
                    or self._registry.resolve_best(metric_id, self.default_provider)
                )
                params = spec.get("params", {}) or {}
                resolved.append(self._create_metric(provider, metric_id, params))
            elif isinstance(spec, str):
                if ":" in spec:
                    # Only the first colon separates provider from metric id.
                    provider, metric_id = spec.split(":", 1)
                else:
                    metric_id = spec
                    provider = self._registry.resolve_best(
                        metric_id, self.default_provider
                    )
                resolved.append(self._create_metric(provider, metric_id, params={}))
            else:
                raise TypeError(f"Invalid metric spec: {spec!r}")

        return resolved

    def _merge_params(
        self, provider: str, metric_id: str, params: dict[str, Any]
    ) -> dict[str, Any]:
        """Merge user-provided params with evaluation-level defaults.

        Precedence (highest to lowest):
        1) Explicit params passed in the metric spec dict
        2) metric_params entry keyed by "provider:metric"
        3) metric_params entry keyed by "metric"

        Returns:
            A new dict; neither ``params`` nor ``metric_params`` is mutated.
        """
        merged: dict[str, Any] = {}
        merged.update(self.metric_params.get(metric_id, {}))
        merged.update(self.metric_params.get(f"{provider}:{metric_id}", {}))
        merged.update(params)
        return merged

    def _inject_context_params(
        self, provider: str, metric_id: str, merged: dict[str, Any]
    ) -> dict[str, Any]:
        """Inject context-derived params based on metric constructor signature.

        Uses introspection to avoid hardcoding metric names. Injects:
        - llm_config: when metric accepts it and not in merged
        - llm_provider: when metric accepts it, not in merged, and we have llm_config

        Nothing is injected when no llm_config was configured, when the
        registry has no factory for the metric, or when the factory signature
        cannot be inspected.

        Returns:
            ``merged`` itself, updated in place.
        """
        metric_factory = self._registry.get_class(provider, metric_id)
        if metric_factory is None or self.llm_config is None:
            return merged

        try:
            if callable(metric_factory) and not isinstance(metric_factory, type):
                sig = inspect.signature(metric_factory)
            else:
                sig = inspect.signature(metric_factory.__init__)
        except (TypeError, ValueError, AttributeError):
            return merged

        accepted = sig.parameters

        if "llm_config" in accepted and "llm_config" not in merged:
            merged["llm_config"] = self.llm_config

        if "llm_provider" in accepted and "llm_provider" not in merged:
            merged["llm_provider"] = OpenAIProvider(
                config_name=f"{provider}:{metric_id}",
                **self.llm_config.model_dump(),
            )

        return merged

    def _create_metric(
        self, provider: str, metric_id: str, params: dict[str, Any]
    ) -> BaseMetric:
        """Create metric instance with dynamic dependency injection.

        If the factory raises ``TypeError`` and the params contain any of
        ``llm_config`` / ``llm_provider`` / ``adapter``, creation is retried
        once without them.

        Raises:
            TypeError: The original error, when the retry is not applicable
                or fails as well.
        """
        merged = self._merge_params(provider, metric_id, params)
        merged = self._inject_context_params(provider, metric_id, merged)

        try:
            return self._registry.create(provider, metric_id, **merged)
        except TypeError:
            dropped = [k for k in self._INJECTED_PARAM_KEYS if k in merged]
            if dropped:
                logger.debug(
                    "Metric %s:%s rejected its params; retrying without %s",
                    provider,
                    metric_id,
                    dropped,
                )
                fallback = {k: v for k, v in merged.items() if k not in dropped}
                try:
                    return self._registry.create(provider, metric_id, **fallback)
                except TypeError:
                    # Retry failed too: surface the original error below.
                    pass
            raise

    @staticmethod
    def _missing_agent_error() -> ValueError:
        """Build the error raised when partial samples cannot be completed."""
        return ValueError(
            "Dataset has partial samples but no agent or agent_runner provided. "
            "Pass agent (callable) for Mode 2 or agent_runner for Mode 4."
        )

    def _ensure_full_samples(self) -> list[AgentSample]:
        """Ensure all samples have traces. Run agent/runner if partial.

        ``agent_runner`` is preferred over ``agent``; a runner exposing
        ``run_on_dataset`` handles the whole batch, otherwise ``run`` is
        called once per sample.

        Raises:
            ValueError: Partial samples exist but neither ``agent`` nor
                ``agent_runner`` was provided.
        """
        if not self.dataset.is_partial:
            return self.dataset.all_full

        partial = self.dataset.all_partial
        if not partial:
            return []

        runner = self.agent_runner
        if runner is not None:
            if hasattr(runner, "run_on_dataset"):
                return runner.run_on_dataset(partial)
            return [
                AgentSample.from_partial(p, runner.run(_to_display_str(p.user_input)))
                for p in partial
            ]

        if self.agent is not None:
            from floeval.utils.agent_trace import TraceCollector

            return TraceCollector(self.agent).collect(partial)

        raise self._missing_agent_error()

    async def _ensure_full_samples_async(self) -> list[AgentSample]:
        """Async variant: run agent/runner without nesting asyncio.run().

        Runner entry points are tried in this order: ``run_on_dataset_async``,
        ``run_on_dataset`` (in an executor thread), then per sample ``arun``
        (when it is a coroutine function) or ``run`` (in an executor thread).

        Raises:
            ValueError: Partial samples exist but neither ``agent`` nor
                ``agent_runner`` was provided.
        """
        if not self.dataset.is_partial:
            return self.dataset.all_full

        partial = self.dataset.all_partial
        if not partial:
            return []

        runner = self.agent_runner
        if runner is not None:
            if hasattr(runner, "run_on_dataset_async"):
                return await runner.run_on_dataset_async(partial)

            loop = asyncio.get_running_loop()
            if hasattr(runner, "run_on_dataset"):
                # Sync run_on_dataset uses asyncio.run() internally; we're already in a loop.
                # Run it in a thread so asyncio.run() gets a fresh loop in that thread.
                return await loop.run_in_executor(
                    None, lambda: runner.run_on_dataset(partial)
                )

            # Decide once which per-sample entry point to use.
            use_arun = hasattr(runner, "arun") and inspect.iscoroutinefunction(
                runner.arun
            )
            full = []
            for p in partial:
                text = _to_display_str(p.user_input)
                if use_arun:
                    trace = await runner.arun(text)
                else:
                    # Bind text as a default so the lambda keeps this sample's value.
                    trace = await loop.run_in_executor(
                        None, lambda t=text: runner.run(t)
                    )
                full.append(AgentSample.from_partial(p, trace))
            return full

        if self.agent is not None:
            from floeval.utils.agent_trace import TraceCollector

            collector = TraceCollector(self.agent)
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, lambda: collector.collect(partial))

        raise self._missing_agent_error()

    def run(self) -> AgentEvaluationResult:
        """Run evaluation synchronously.

        Drives ``arun`` with ``asyncio.run``, so it cannot be called from a
        thread that already has a running event loop; await ``arun`` there.
        """
        return asyncio.run(self.arun())

    async def arun(self) -> AgentEvaluationResult:
        """Run evaluation asynchronously.

        Returns:
            A result whose ``sample_results`` holds one row per sample
            (``user_input``, ``final_response``, ``reference_outcome`` and a
            ``metrics`` dict of ``{"score", "metadata"}`` per metric name) and
            whose ``summary`` maps each metric name to the mean of its
            non-``None`` scores. When there is nothing to evaluate, the
            summary is ``{"error": "No full samples to evaluate"}``.
        """
        full_samples = await self._ensure_full_samples_async()

        if not full_samples:
            return AgentEvaluationResult(
                sample_results=[],
                summary={"error": "No full samples to evaluate"},
            )

        # Serializing every trace is costly; only do it when it will be logged.
        if self.dataset.is_partial and logger.isEnabledFor(logging.DEBUG):
            captured = [
                {
                    "user_input": s.user_input,
                    "trace": {
                        "messages": [m.model_dump() for m in s.trace.messages],
                        "final_response": s.trace.final_response,
                    },
                }
                for s in full_samples
            ]
            logger.debug(
                "Captured traces (full dataset for evaluation): %s",
                json.dumps(captured, indent=2, default=str),
            )

        # Name and sync/async dispatch do not depend on the sample: compute once.
        prepared = [
            (
                getattr(metric, "name", metric.__class__.__name__),
                metric,
                inspect.iscoroutinefunction(getattr(metric, "aevaluate", None)),
            )
            for metric in self._resolved_metrics
        ]

        sample_results: list[dict[str, Any]] = []
        metric_scores: dict[str, list[float]] = {}

        for sample in full_samples:
            row: dict[str, Any] = {
                "user_input": sample.user_input,
                "final_response": sample.trace.final_response,
                "reference_outcome": sample.reference_outcome,
                "metrics": {},
            }

            for name, metric, is_async in prepared:
                try:
                    if is_async:
                        result = await metric.aevaluate(sample)
                    else:
                        result = metric.evaluate(sample)
                except Exception as e:
                    # One failing metric must not abort the whole run; record
                    # the error on the sample instead.
                    logger.exception("Metric %s failed for sample: %s", name, e)
                    result = MetricResult(score=None, metadata={"error": str(e)})

                row["metrics"][name] = {
                    "score": result.score,
                    "metadata": result.metadata,
                }
                if result.score is not None:
                    metric_scores.setdefault(name, []).append(result.score)

            sample_results.append(row)

        summary = {
            name: sum(scores) / len(scores)
            for name, scores in metric_scores.items()
            if scores
        }

        return AgentEvaluationResult(
            sample_results=sample_results,
            summary=summary,
        )
|
floeval/api/dataset.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""Dataset and Sample classes.
|
|
2
|
+
|
|
3
|
+
The RAGAS adapter supports Pydantic models via `model_dump()`.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Dict, Sequence, Union
|
|
11
|
+
|
|
12
|
+
from floeval.api.dataset_loaders.base import BaseDatasetLoader
|
|
13
|
+
from floeval.api.dataset_loaders.local_file_loader import get_loader_for_file
|
|
14
|
+
from floeval.config.schemas.io.dataset import (
|
|
15
|
+
Dataset,
|
|
16
|
+
PartialDataset,
|
|
17
|
+
PartialSample,
|
|
18
|
+
Sample,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def dataset_from_dict(
    data: Dict[str, Any], partial_dataset: bool
) -> Dataset | PartialDataset:
    """Build a dataset from an already-parsed mapping.

    Expected shape::

        {
            "samples": [
                {
                    "user_input": "What is the capital of France?",
                    "llm_response": "The capital of France is Paris."
                },
                ...
            ]
        }

    A missing ``"samples"`` key is treated as an empty list. Each entry is
    expanded as keyword arguments into the sample model.

    Args:
        data: Mapping holding the dataset, typically loaded from JSON.
        partial_dataset: When truthy, build ``PartialSample`` objects and
            return a ``PartialDataset``; otherwise build ``Sample`` objects
            and return a ``Dataset``.

    Returns:
        A ``Dataset`` or ``PartialDataset`` instance.
    """
    # Pick the sample/dataset model pair once instead of branching twice.
    if partial_dataset:
        sample_cls, dataset_cls = PartialSample, PartialDataset
    else:
        sample_cls, dataset_cls = Sample, Dataset

    entries = data.get("samples", [])
    return dataset_cls(samples=[sample_cls(**entry) for entry in entries])
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def dataset_from_json(
    path: Union[str, Path], partial_dataset: bool
) -> Dataset | PartialDataset:
    """Read a UTF-8 JSON file and turn its content into a dataset.

    Args:
        path: Location of the JSON file.
        partial_dataset: Forwarded to ``dataset_from_dict``.

    Returns:
        Whatever ``dataset_from_dict`` builds from the parsed content.
    """
    with open(path, encoding="utf-8") as handle:
        payload = json.loads(handle.read())
    return dataset_from_dict(payload, partial_dataset=partial_dataset)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def dataset_from_file(
    ds_path: Union[str, Path], partial_dataset: bool
) -> Dataset | PartialDataset:
    """Load a dataset from a file, choosing the loader from the file itself.

    Args:
        ds_path: Location of the dataset file.
        partial_dataset: When truthy, return a ``PartialDataset`` built from
            the loader's partial samples; otherwise a ``Dataset``.

    Returns:
        A ``Dataset`` or ``PartialDataset`` instance.

    Raises:
        FileNotFoundError: ``ds_path`` is not an existing regular file.
    """
    file_path = Path(ds_path)
    if not file_path.is_file():
        raise FileNotFoundError(f"Dataset file not found: {file_path}")

    # The lookup takes the path as a string; the loader receives the Path.
    loader_cls: type[BaseDatasetLoader] = get_loader_for_file(str(file_path))

    if not partial_dataset:
        return Dataset(samples=loader_cls.to_samples(file_path))
    return PartialDataset(samples=loader_cls.to_partial_samples(file_path))
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def dataset_from_samples(
    samples: Sequence[Union[Sample, PartialSample, Dict[str, Any]]],
    partial_dataset: bool,
) -> Dataset | PartialDataset:
    """Convenience helper: build dataset from sample objects or flat dicts.

    Args:
        samples: Sequence of Sample/PartialSample objects or dicts
            convertible to the target sample type. PartialSample objects
            are only accepted when ``partial_dataset`` is true.
        partial_dataset: If True, convert to PartialSample and return
            PartialDataset.

    Returns:
        Dataset or PartialDataset containing the provided samples.

    Raises:
        ValueError: An element is of an unsupported type.
    """
    if partial_dataset:
        partial_samples = []
        for s in samples:
            if isinstance(s, Sample):
                # Rebuild from the dumped fields so the result is a PartialSample.
                partial_samples.append(PartialSample(**s.model_dump()))
            elif isinstance(s, PartialSample):
                partial_samples.append(s)
            elif isinstance(s, dict):
                partial_samples.append(PartialSample(**s))
            else:
                # Message lists every accepted type (Sample is accepted too).
                raise ValueError(
                    f"Invalid sample type: {type(s)}; "
                    "expected Sample, PartialSample, or dict"
                )
        return PartialDataset(samples=partial_samples)

    # --- complete dataset case (containing llm_response field) --
    full_samples = []
    for s in samples:
        if isinstance(s, Sample):
            full_samples.append(s)
        elif isinstance(s, dict):
            full_samples.append(Sample(**s))
        else:
            raise ValueError(f"Invalid sample type: {type(s)}; expected Sample or dict")
    return Dataset(samples=full_samples)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class DatasetLoader:
    """Single entry point for building datasets from dicts, files or samples.

    Every method is a static delegator to the module-level function of the
    matching name; ``partial_dataset`` defaults to ``False`` here.
    """

    @staticmethod
    def from_dict(
        data: Dict[str, Any], partial_dataset: bool = False
    ) -> Dataset | PartialDataset:
        """Build a dataset from a mapping with a ``"samples"`` list."""
        return dataset_from_dict(data, partial_dataset=partial_dataset)

    @staticmethod
    def from_json(
        path: Union[str, Path], partial_dataset: bool = False
    ) -> Dataset | PartialDataset:
        """Build a dataset from the JSON file at ``path``."""
        return dataset_from_json(path, partial_dataset=partial_dataset)

    @staticmethod
    def from_file(
        ds_path: Union[str, Path], partial_dataset: bool = False
    ) -> Dataset | PartialDataset:
        """Build a dataset from a file, picking the loader from the file."""
        return dataset_from_file(ds_path, partial_dataset=partial_dataset)

    @staticmethod
    def from_samples(
        samples: Sequence[Union[Sample, Dict[str, Any]]],
        partial_dataset: bool = False,
    ) -> Dataset | PartialDataset:
        """Build a dataset from sample objects or flat dicts."""
        return dataset_from_samples(samples, partial_dataset=partial_dataset)
|
|
File without changes
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""Dataset loader for agent evaluation with robust error handling."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from floeval.config.schemas.io.agent_dataset import (
|
|
8
|
+
AgentDataset,
|
|
9
|
+
AgentSample,
|
|
10
|
+
AgentTrace,
|
|
11
|
+
AIMessage,
|
|
12
|
+
HumanMessage,
|
|
13
|
+
PartialAgentSample,
|
|
14
|
+
ToolCall,
|
|
15
|
+
ToolMessage,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DatasetLoadError(Exception):
    """Signal that a dataset file could not be turned into a dataset.

    Raised for unsupported file formats, malformed content and samples that
    fail to parse.
    """
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AgentDatasetLoader:
    """Load agent datasets from files.

    Supports ``.json`` (an object with a ``samples`` array) and ``.jsonl``
    (one sample object per non-blank line). A sample without a ``trace`` key
    is loaded as a partial sample; one with a trace as a full sample.
    """

    @staticmethod
    def from_file(path: str | Path) -> AgentDataset:
        """Load dataset from JSON or JSONL file.

        The extension is matched case-insensitively.

        Raises:
            FileNotFoundError: File does not exist.
            DatasetLoadError: File format invalid or parsing failed.
        """
        path = Path(path)

        if not path.exists():
            raise FileNotFoundError(f"Dataset file not found: {path}")

        suffix = path.suffix.lower()
        try:
            if suffix == ".jsonl":
                return AgentDatasetLoader._load_jsonl(path)
            if suffix == ".json":
                return AgentDatasetLoader._load_json(path)
            raise DatasetLoadError(
                f"Unsupported format: {path.suffix}. Use .json or .jsonl"
            )
        except DatasetLoadError:
            # Already carries a precise message; do not wrap it again.
            raise
        except Exception as e:
            raise DatasetLoadError(f"Failed to load dataset: {e}") from e

    @staticmethod
    def _parse_reference_tool_calls(data: dict) -> list[ToolCall] | None:
        """Parse reference tool calls from data dict.

        Returns:
            ``None`` when the ``reference_tool_calls`` key is missing or
            empty, otherwise one ``ToolCall`` per entry.
        """
        raw = data.get("reference_tool_calls")
        if not raw:
            return None
        return [ToolCall(**tc) for tc in raw]

    @staticmethod
    def _load_jsonl(path: Path) -> AgentDataset:
        """Load JSONL file, skipping blank lines.

        Raises:
            DatasetLoadError: A line is not valid JSON, a sample fails to
                parse (the 1-based line number is reported), or the file
                holds no samples.
        """
        samples = []

        with open(path, encoding="utf-8") as f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue

                try:
                    data = json.loads(line)
                    samples.append(AgentDatasetLoader._parse_sample(data))
                except json.JSONDecodeError as e:
                    raise DatasetLoadError(
                        f"Invalid JSON on line {line_num}: {e}"
                    ) from e
                except Exception as e:
                    raise DatasetLoadError(
                        f"Error parsing sample on line {line_num}: {e}"
                    ) from e

        if not samples:
            raise DatasetLoadError(f"No valid samples found in {path}")

        return AgentDataset(samples=samples)

    @staticmethod
    def _load_json(path: Path) -> AgentDataset:
        """Load JSON file shaped as ``{"samples": [...]}``.

        Raises:
            DatasetLoadError: The top-level value is not an object, it has
                no ``samples`` array, the array is empty, or a sample fails
                to parse (the 0-based index is reported).
        """
        with open(path, encoding="utf-8") as f:
            data = json.load(f)

        if not isinstance(data, dict):
            raise DatasetLoadError("JSON must be an object with 'samples' array")

        if "samples" not in data:
            raise DatasetLoadError("JSON must have 'samples' key")

        if not isinstance(data["samples"], list):
            raise DatasetLoadError("'samples' must be an array")

        samples = []
        for index, raw_sample in enumerate(data["samples"]):
            try:
                samples.append(AgentDatasetLoader._parse_sample(raw_sample))
            except Exception as e:
                # Mirror the JSONL loader: say which sample is broken.
                raise DatasetLoadError(
                    f"Error parsing sample at index {index}: {e}"
                ) from e

        if not samples:
            raise DatasetLoadError("'samples' array is empty")

        return AgentDataset(samples=samples)

    @staticmethod
    def _parse_sample(data: dict) -> AgentSample | PartialAgentSample:
        """Parse dict to sample (auto-detect partial vs full).

        Raises:
            ValueError: ``data`` is not a dict, or a trace message has a
                role other than ``human``, ``ai`` or ``tool``.
            KeyError: ``user_input`` or the trace's ``messages`` is missing.
        """
        if not isinstance(data, dict):
            # Without this check a string sample would go through a substring
            # test on "trace" and fail with a misleading indexing error.
            raise ValueError(
                f"Sample must be a JSON object, got {type(data).__name__}"
            )

        if "trace" not in data:
            return PartialAgentSample(
                user_input=data["user_input"],
                reference_outcome=data.get("reference_outcome"),
                reference_tool_calls=AgentDatasetLoader._parse_reference_tool_calls(data),
                metadata=data.get("metadata", {}),
            )

        trace_data = data["trace"]
        messages = []

        for msg_data in trace_data["messages"]:
            role = msg_data.get("role", "")

            if role == "human":
                messages.append(
                    HumanMessage(content=msg_data.get("content", ""))
                )
            elif role == "ai":
                # "tool_calls": null is treated like a missing key.
                tool_calls = [
                    ToolCall(**tc) for tc in msg_data.get("tool_calls") or []
                ]
                messages.append(
                    AIMessage(
                        content=msg_data.get("content", ""),
                        tool_calls=tool_calls,
                    )
                )
            elif role == "tool":
                messages.append(
                    ToolMessage(
                        content=msg_data.get("content", ""),
                        tool_name=msg_data.get("tool_name", ""),
                        tool_call_id=msg_data.get("tool_call_id"),
                    )
                )
            else:
                raise ValueError(f"Unknown role: {role}")

        trace = AgentTrace(
            messages=messages,
            final_response=trace_data.get("final_response", ""),
            metadata=trace_data.get("metadata", {}),
        )

        return AgentSample(
            user_input=data["user_input"],
            trace=trace,
            reference_outcome=data.get("reference_outcome"),
            reference_tool_calls=AgentDatasetLoader._parse_reference_tool_calls(data),
            metadata=data.get("metadata", {}),
        )
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from floeval.config.schemas.io.dataset import PartialSample, Sample
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseDatasetLoader(ABC):
    """Interface that every dataset loader implements.

    Both hooks are abstract class methods, so a concrete loader must override
    both before it can be instantiated.
    """

    @classmethod
    @abstractmethod
    def to_samples(cls, *args, **kwargs) -> list[Sample]:
        """Convert raw data into ``Sample`` objects.

        Returns:
            list[Sample]: The samples that make up the dataset.
        """

    @classmethod
    @abstractmethod
    def to_partial_samples(cls, *args, **kwargs) -> list[PartialSample]:
        """Convert raw data into ``PartialSample`` objects.

        PartialSample allows an empty llm_response field (used when creating
        a PartialDataset).

        Returns:
            list[PartialSample]: Samples with the llm_response field empty.
        """
|