spooling 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,407 @@
1
+ """Experiments: scenario-based evaluation with Strands Cases + Evaluators.
2
+
3
+ This module wraps `strands_evals.Experiment` so you can define a bundle of
4
+ test cases + evaluators once, run them against any task function (typically
5
+ a Strands Agent you want to grade), and persist the reports into Spooling.
6
+
7
+ Two execution patterns are supported:
8
+
9
+ 1. **Plain** — `task_fn(case)` runs a fresh agent once per case and returns
10
+ `{"output": str, "trajectory": Session | None}`. Output-level evaluators read
11
+ the string; trace-level evaluators (Helpfulness, Trajectory, GoalSuccessRate) read the session.
12
+ 2. **Simulated** — `task_fn` is built from `ActorSimulator.from_case_for_user_simulator`
13
+ to drive a multi-turn back-and-forth between a user persona and the
14
+ agent under test. Only the final agent response string is returned.
15
+
16
+ Experiments and their runs land in two tables (see migrations/004_experiments.sql):
17
+ `experiments` stores the catalog; `experiment_runs` stores each run's
18
+ reports plus the trace ids generated along the way so you can click
19
+ through from a run to the Spooling trace it produced.
20
+
21
+ CLI:
22
+ spooling experiment create --file cases.json
23
+ spooling experiment list
24
+ spooling experiment run --id <exp-id>
25
+ spooling experiment show --run <run-id>
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import json
31
+ import uuid
32
+ from dataclasses import dataclass
33
+ from datetime import datetime, timezone
34
+ from typing import Any, Callable, Optional
35
+
36
+ from spooling.db import get_connection
37
+
38
+
39
+ # --- Strands imports deferred to runtime so evals.py can import us freely ---
40
+
41
def _strands_eval_classes():
    """Import the Strands evaluation building blocks on demand.

    The imports run at call time instead of module import time, so importing
    this module does not itself import `strands_evals`.

    Returns:
        A 4-tuple ``(Case, Experiment, evaluators_module, EvaluationData)``.
    """
    from strands_evals import Case, Experiment, evaluators
    from strands_evals.types import EvaluationData

    return (Case, Experiment, evaluators, EvaluationData)
46
+
47
+
48
def _ollama_judge_model():
    """Build the judge model from the shared `spooling.evals` configuration.

    Reusing `_judge_config` and `_make_ollama_model` keeps experiments on the
    same judge model settings as the evals module (described upstream as the
    Ollama + qwen default).
    """
    # Imported lazily, at call time, rather than when this module is loaded.
    from spooling.evals import _judge_config, _make_ollama_model

    judge_config = _judge_config()
    return _make_ollama_model(judge_config)
52
+
53
+
54
+ # --- Catalog operations ----------------------------------------------------
55
+
56
@dataclass
class ExperimentSpec:
    """Catalog entry describing one experiment: its cases, evaluators and knobs.

    The fields correspond one-to-one to the columns written to the
    `experiments` table.
    """

    # Experiment identifier; used as the upsert conflict target.
    id: str
    # Human-readable experiment name.
    name: str
    # Optional free-text description.
    description: Optional[str]
    # Test cases: [{name, input, expected_output?, metadata?}, ...]
    cases: list[dict]
    # Evaluator specs: [{type: "HelpfulnessEvaluator"}, {type: "OutputEvaluator", rubric: "..."}]
    evaluators: list[dict]
    # Optional knobs (e.g. {"simulated": true, "max_turns": 5})
    config: dict
64
+
65
+
66
def create_experiment(spec: ExperimentSpec) -> str:
    """Insert an experiment spec, or overwrite the row that shares its id.

    The list/dict fields are serialized with `json.dumps` before being bound
    as query parameters.

    Args:
        spec: The experiment definition to store.

    Returns:
        The id of the stored experiment (``spec.id``).
    """
    upsert_sql = """INSERT INTO experiments (id, name, description, cases, evaluators, config)
           VALUES (%s, %s, %s, %s, %s, %s)
           ON CONFLICT (id) DO UPDATE SET
               name = EXCLUDED.name,
               description = EXCLUDED.description,
               cases = EXCLUDED.cases,
               evaluators = EXCLUDED.evaluators,
               config = EXCLUDED.config"""

    conn = get_connection()
    try:
        params = (
            spec.id,
            spec.name,
            spec.description,
            json.dumps(spec.cases),
            json.dumps(spec.evaluators),
            json.dumps(spec.config),
        )
        conn.execute(upsert_sql, params)
        conn.commit()
    finally:
        # Always release the connection, whether or not the upsert succeeded.
        conn.close()
    return spec.id
90
+
91
+
92
def load_experiment(experiment_id: str) -> Optional[ExperimentSpec]:
    """Fetch one experiment from the catalog.

    Args:
        experiment_id: Id of the experiment to look up.

    Returns:
        The matching `ExperimentSpec`, or ``None`` when no row has that id.
    """
    conn = get_connection()
    try:
        cursor = conn.execute(
            "SELECT * FROM experiments WHERE id = %s", (experiment_id,)
        )
        record = cursor.fetchone()
    finally:
        conn.close()

    if record:
        # Empty/NULL collection columns are normalized to empty containers.
        return ExperimentSpec(
            id=record["id"],
            name=record["name"],
            description=record.get("description"),
            cases=record.get("cases") or [],
            evaluators=record.get("evaluators") or [],
            config=record.get("config") or {},
        )
    return None
110
+
111
+
112
def list_experiments() -> list[dict]:
    """Return a summary of every experiment, newest first.

    Each dict carries ``id``, ``name``, ``description``, ``created_at`` plus
    the ``case_count`` and ``evaluator_count`` computed in SQL.
    """
    summary_sql = """SELECT id, name, description, created_at,
                  jsonb_array_length(cases) AS case_count,
                  jsonb_array_length(evaluators) AS evaluator_count
           FROM experiments ORDER BY created_at DESC"""

    conn = get_connection()
    try:
        cursor = conn.execute(summary_sql)
        records = cursor.fetchall()
    finally:
        conn.close()
    return list(map(dict, records))
124
+
125
+
126
+ # --- Factory helpers for Strands Cases / Evaluators ------------------------
127
+
128
def _build_strands_cases(spec: ExperimentSpec):
    """Turn the raw case dicts of *spec* into Strands `Case` objects.

    A case without a (truthy) ``name`` is named ``case-<index>``; ``input`` is
    required, so a case dict lacking it raises ``KeyError``.
    """
    case_cls = _strands_eval_classes()[0]
    return [
        case_cls(
            name=raw.get("name") or f"case-{index}",
            input=raw["input"],
            expected_output=raw.get("expected_output"),
            metadata=raw.get("metadata") or {},
        )
        for index, raw in enumerate(spec.cases)
    ]
139
+
140
+
141
def _build_strands_evaluators(spec: ExperimentSpec) -> list:
    """Instantiate the Strands evaluators named in *spec*.

    Each entry's ``type`` is looked up by name on the `strands_evals`
    evaluators module. `OutputEvaluator` and `TrajectoryEvaluator` are built
    with a rubric (the entry's ``rubric`` or a generic default); every other
    evaluator is built with the judge model only.

    Entries whose ``type`` is missing, not a string, or not found on the
    evaluators module are skipped, and a diagnostic is printed for each so a
    typo in a spec does not silently drop an evaluator.

    Returns:
        The list of evaluator instances, possibly empty.
    """
    _Case, _Experiment, ev_mod, _EvaluationData = _strands_eval_classes()
    model = _ollama_judge_model()
    default_rubric = (
        "Pass if the output directly and correctly addresses the user's "
        "request. Score 0-1 based on accuracy and completeness."
    )
    rubric_types = ("OutputEvaluator", "TrajectoryEvaluator")

    evaluators = []
    for entry in spec.evaluators:
        type_name = entry.get("type")
        cls = None
        # getattr() requires a string name; anything else counts as unknown.
        if type_name and isinstance(type_name, str):
            cls = getattr(ev_mod, type_name, None)
        if cls is None:
            print(
                "[spooling.experiments] skipping evaluator with "
                f"missing or unknown type: {type_name!r}"
            )
            continue

        if type_name in rubric_types:
            rubric = entry.get("rubric") or default_rubric
            evaluators.append(cls(rubric=rubric, model=model))
        else:
            evaluators.append(cls(model=model))
    return evaluators
161
+
162
+
163
+ # --- Task function adapters ------------------------------------------------
164
+
165
def _plain_task_fn(experiment_id: str, captured_trace_ids: list) -> Callable:
    """Return a task function that runs a fresh Strands Agent per case.

    The agent call is wrapped with an in-memory OTel exporter and a
    `StrandsInMemorySessionMapper`, so the returned payload has both
    ``output`` (string) and ``trajectory`` (Strands Session, or ``None`` when
    the spans could not be mapped). Trace-level evaluators need the Session;
    output-level evaluators just read the string.

    As a side effect, each captured Session is ingested into Spooling via
    `remote_otel.ingest_strands_session` with a provider id of
    ``experiment:<experiment_id>``, and the resulting trace id is appended to
    *captured_trace_ids*.

    Mapping and ingest are best-effort: a failure in either is printed and the
    case still produces its output.

    Args:
        experiment_id: Id used to tag ingested traces.
        captured_trace_ids: List mutated in place with each ingested trace id.

    Returns:
        A callable ``fn(case) -> {"output": str, "trajectory": session}``.
    """
    from strands import Agent
    from strands_evals.telemetry import StrandsEvalsTelemetry
    from strands_evals.mappers import StrandsInMemorySessionMapper
    from spooling.remote_otel import ingest_strands_session

    telemetry = StrandsEvalsTelemetry().setup_in_memory_exporter()
    exporter = telemetry.in_memory_exporter
    mapper = StrandsInMemorySessionMapper()
    model = _ollama_judge_model()

    def _fn(case):
        session_id = case.session_id or uuid.uuid4().hex
        # Reset the exporter before each case (presumably so spans from the
        # previous case do not leak into this one — confirm against exporter).
        if hasattr(exporter, "clear"):
            exporter.clear()

        agent = Agent(
            model=model,
            trace_attributes={
                "gen_ai.conversation.id": session_id,
                "session.id": session_id,
            },
            callback_handler=None,
        )
        response = agent(case.input)

        try:
            spans = list(exporter.get_finished_spans())
            session = mapper.map_to_session(spans, session_id=session_id)
        except Exception as e:
            # Best-effort: keep the output, but report why no trajectory exists.
            print(f"[spooling.experiments] session mapping failed: {type(e).__name__}: {e}")
            session = None

        if session is not None:
            try:
                tid = ingest_strands_session(
                    session,
                    provider_id=f"experiment:{experiment_id}",
                    project=None,
                    title=f"{case.name or 'case'}: {str(case.input)[:60]}",
                )
                captured_trace_ids.append(tid)
            except Exception as e:
                print(f"[spooling.experiments] trace ingest failed: {type(e).__name__}: {e}")

        return {"output": str(response), "trajectory": session}

    return _fn
220
+
221
+
222
def _simulated_task_fn(max_turns: int = 5) -> Callable:
    """Return a task function that runs a simulated multi-turn conversation.

    An `ActorSimulator` user persona drives the conversation against a fresh
    Strands Agent. The agent always answers ``case.input`` once; further turns
    happen only while the simulator has a next message and fewer than
    *max_turns* turns have run. The final agent response is returned as a
    string.

    The simulation is best-effort: if the simulator cannot be built the case
    runs as a single turn, and if the simulator fails mid-conversation the
    loop stops with the last agent response. Both situations are printed.

    NOTE(review): the original note says traces from the conversation land in
    Spooling via the live `spooling.sdk` tracer if configured; nothing in this
    function sets that up, so confirm it against the SDK.

    Args:
        max_turns: Upper bound on agent turns per case.

    Returns:
        A callable ``fn(case) -> str``.
    """
    from strands import Agent
    from strands_evals import ActorSimulator

    model = _ollama_judge_model()

    def _fn(case):
        try:
            simulator = ActorSimulator.from_case_for_user_simulator(
                case=case, max_turns=max_turns,
            )
        except Exception as e:
            print(
                "[spooling.experiments] user simulator unavailable, "
                f"running a single turn: {type(e).__name__}: {e}"
            )
            simulator = None

        agent = Agent(model=model)
        user_message = case.input
        agent_response = ""
        turns = 0

        while True:
            resp = agent(user_message)
            agent_response = str(resp)
            turns += 1
            if simulator is None or turns >= max_turns:
                break
            if not simulator.has_next():
                break
            try:
                user_result = simulator.act(agent_response)
                user_message = str(user_result.structured_output.message)
            except Exception as e:
                print(
                    "[spooling.experiments] user simulator failed after "
                    f"{turns} turn(s): {type(e).__name__}: {e}"
                )
                break

        return agent_response

    return _fn
265
+
266
+
267
+ # --- Running ---------------------------------------------------------------
268
+
269
def run_experiment(experiment_id: str) -> str:
    """Run an experiment and persist a new `experiment_runs` row.

    The run row is inserted with status ``'running'`` before evaluation
    starts. On success it is updated to ``'complete'`` with the serialized
    reports, per-evaluator overall scores and the trace ids captured by the
    task function. If evaluation *or* report serialization fails, the row is
    updated to ``'error'`` with the (truncated) message and the exception is
    re-raised, so a run is not left in ``'running'`` by those failures.

    Args:
        experiment_id: Id of a stored experiment.

    Returns:
        The id of the new run (``run-<12 hex chars>``).

    Raises:
        ValueError: If the experiment does not exist, or it yields no cases
            or no evaluators.
        Exception: Whatever the evaluation or serialization step raised.
    """
    spec = load_experiment(experiment_id)
    if spec is None:
        raise ValueError(f"Unknown experiment: {experiment_id}")

    _Case, Experiment, _ev_mod, _EvaluationData = _strands_eval_classes()

    cases = _build_strands_cases(spec)
    evaluators = _build_strands_evaluators(spec)

    if not cases or not evaluators:
        raise ValueError(
            "Experiment needs at least one case and one evaluator."
        )

    config = spec.config or {}
    simulated = bool(config.get("simulated"))
    max_turns = int(config.get("max_turns") or 5)
    # Only the plain task function appends to this list.
    captured_trace_ids: list[str] = []
    task_fn = (
        _simulated_task_fn(max_turns=max_turns)
        if simulated
        else _plain_task_fn(experiment_id=spec.id, captured_trace_ids=captured_trace_ids)
    )

    run_id = f"run-{uuid.uuid4().hex[:12]}"
    conn = get_connection()
    try:
        conn.execute(
            """INSERT INTO experiment_runs (id, experiment_id, status)
               VALUES (%s, %s, 'running')""",
            (run_id, experiment_id),
        )
        conn.commit()

        try:
            experiment = Experiment(cases=cases, evaluators=evaluators)
            # Materialize once: the reports are iterated twice below.
            reports = list(experiment.run_evaluations(task_fn))

            # Serialize for JSONB. default=str stringifies values json cannot
            # encode natively (e.g. Decimal/datetime inside a dumped report).
            reports_json = json.dumps(
                [_report_to_dict(r) for r in reports], default=str
            )
            overall_json = json.dumps(
                {
                    r.evaluator_name: (
                        float(r.overall_score) if r.overall_score is not None else None
                    )
                    for r in reports
                },
                default=str,
            )
        except Exception as e:
            conn.execute(
                """UPDATE experiment_runs
                   SET status = 'error', finished_at = now(), error = %s
                   WHERE id = %s""",
                (str(e)[:1000], run_id),
            )
            conn.commit()
            raise

        conn.execute(
            """UPDATE experiment_runs
               SET status = 'complete',
                   finished_at = now(),
                   reports = %s,
                   overall_scores = %s,
                   created_trace_ids = %s
               WHERE id = %s""",
            (
                reports_json,
                overall_json,
                json.dumps(captured_trace_ids),
                run_id,
            ),
        )
        conn.commit()
    finally:
        conn.close()
    return run_id
344
+
345
+
346
+ def _report_to_dict(report) -> dict:
347
+ """Serialize a Strands EvaluationReport for JSONB storage."""
348
+ try:
349
+ return report.model_dump()
350
+ except Exception:
351
+ return {
352
+ "evaluator_name": getattr(report, "evaluator_name", None),
353
+ "overall_score": getattr(report, "overall_score", None),
354
+ "scores": list(getattr(report, "scores", []) or []),
355
+ "test_passes": list(getattr(report, "test_passes", []) or []),
356
+ "reasons": list(getattr(report, "reasons", []) or []),
357
+ }
358
+
359
+
360
def load_run(run_id: str) -> Optional[dict]:
    """Fetch a single experiment run.

    Args:
        run_id: Id of the run to look up.

    Returns:
        The `experiment_runs` row as a dict, or ``None`` when no row matches.
    """
    conn = get_connection()
    try:
        cursor = conn.execute(
            "SELECT * FROM experiment_runs WHERE id = %s", (run_id,)
        )
        record = cursor.fetchone()
    finally:
        conn.close()

    if not record:
        return None
    return dict(record)
369
+
370
+
371
def list_runs(experiment_id: Optional[str] = None, limit: int = 20) -> list[dict]:
    """List recent experiment runs, most recently started first.

    Args:
        experiment_id: When truthy, only runs of this experiment are returned.
        limit: Maximum number of rows to return.

    Returns:
        One dict per run with ``id``, ``experiment_id``, ``started_at``,
        ``finished_at``, ``status`` and ``overall_scores``.
    """
    # Choose the statement and its bound parameters first, then run it once.
    if experiment_id:
        query = """SELECT id, experiment_id, started_at, finished_at,
                      status, overall_scores
               FROM experiment_runs WHERE experiment_id = %s
               ORDER BY started_at DESC LIMIT %s"""
        params = (experiment_id, limit)
    else:
        query = """SELECT id, experiment_id, started_at, finished_at,
                      status, overall_scores
               FROM experiment_runs
               ORDER BY started_at DESC LIMIT %s"""
        params = (limit,)

    conn = get_connection()
    try:
        records = conn.execute(query, params).fetchall()
    finally:
        conn.close()
    return [dict(record) for record in records]
393
+
394
+
395
+ # --- File-based spec loader (for CLI `spooling experiment create --file`) ----
396
+
397
def load_spec_from_file(path: str) -> ExperimentSpec:
    """Build an `ExperimentSpec` from a JSON file.

    The file must contain a JSON object with at least a ``name`` key. A
    missing or empty ``id`` is replaced by a generated ``exp-<10 hex chars>``
    id. ``cases``, ``evaluators`` and ``config`` default to empty containers
    when absent *or* ``null``, so the spec never carries ``None`` there.

    Args:
        path: Path of the JSON spec file.

    Returns:
        The parsed experiment spec (not yet persisted).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the ``name`` key is missing.
    """
    # Read as UTF-8 explicitly instead of relying on the platform locale's
    # default encoding.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return ExperimentSpec(
        id=data.get("id") or f"exp-{uuid.uuid4().hex[:10]}",
        name=data["name"],
        description=data.get("description"),
        cases=data.get("cases") or [],
        evaluators=data.get("evaluators") or [],
        config=data.get("config") or {},
    )