spooling 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spooling/evals.py ADDED
@@ -0,0 +1,611 @@
1
+ """Spooling eval runner backed by the Strands Evals SDK.
2
+
3
+ The rubric catalog lives in the `eval_rubrics` table. Each row either names
4
+ a Strands `Evaluator` subclass (`evaluator_type` column) or is a function
5
+ rubric handled by our own registry below. When a rubric is run we:
6
+
7
+ 1. Load the rubric row + the target trace (or span) + its descendants.
8
+ 2. Build a Strands `EvaluationData` from the trace by collecting the first
9
+ user message as `input`, the assistant's final output as `actual_output`,
10
+ and the tool-name sequence as `actual_trajectory`.
11
+ 3. Instantiate the Strands evaluator with an Ollama model (`qwen2.5:7b` by
12
+ default) so it works out of the box with no API key required.
13
+ 4. Call `evaluator.evaluate(data)` and persist the first `EvaluationOutput`
14
+ into our `evals` table (score / test_pass / reason / label).
15
+
16
+ Function rubrics (deterministic, non-LLM) still live in the
17
+ `_FUNCTION_RUBRICS` registry. They short-circuit the Strands path.
18
+
19
+ Custom rubrics are just rows with `evaluator_type='OutputEvaluator'` and a
20
+ `rubric_text` value — anyone can add one via POST /api/evals/rubrics.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import json
26
+ import os
27
+ from dataclasses import dataclass
28
+ from datetime import datetime, timezone
29
+ from typing import Any, Callable, Optional
30
+
31
+ from spooling.db import get_connection
32
+
33
+
34
+ # --- function rubric registry (deterministic) ------------------------------
35
+
36
# A grader is called with the target row and the list of child span rows and
# returns an EvalResult.
FunctionGrader = Callable[[dict, list[dict]], "EvalResult"]
# Rubric id -> grader. Entries are added by the register_function_rubric decorator.
_FUNCTION_RUBRICS: dict[str, FunctionGrader] = {}
38
+
39
+
40
@dataclass
class EvalResult:
    """Outcome of running one rubric against a trace or span.

    Every field defaults to ``None`` so a grader only fills in what applies.
    """

    # Numeric score, or None when the rubric produced no score.
    score: Optional[float] = None
    # Pass/fail verdict, or None when no verdict was reached.
    passed: Optional[bool] = None
    # Short tag describing the outcome.
    label: Optional[str] = None
    # Free-text explanation accompanying the outcome.
    rationale: Optional[str] = None
    # Extra key/value details attached to the result.
    attrs: Optional[dict[str, Any]] = None
47
+
48
+
49
def register_function_rubric(rubric_id: str):
    """Return a decorator that files a grader under ``rubric_id``.

    The decorated function is stored in ``_FUNCTION_RUBRICS`` (replacing any
    previous entry for the same id) and returned unchanged.
    """

    def decorator(grader: FunctionGrader) -> FunctionGrader:
        _FUNCTION_RUBRICS[rubric_id] = grader
        return grader

    return decorator
54
+
55
+
56
# Minimum score (1 - tool error rate) for the tool-error-rate rubric to pass.
PASS_SCORE_THRESHOLD = 0.95


@register_function_rubric("tool-error-rate")
def _tool_error_rate(target: dict, children: list[dict]) -> EvalResult:
    """Grade a target by the fraction of its tool spans that errored.

    A child counts as a tool span when its ``kind`` is ``"tool"``, and as an
    error when ``tool_is_error`` is truthy or ``status`` equals ``"error"``.
    The score is ``1 - error rate`` rounded to three decimals. With no tool
    spans the rubric is skipped (score and verdict are ``None``).
    ``target`` is accepted for the grader signature but not inspected.
    """
    tools = [span for span in children if span["kind"] == "tool"]
    total = len(tools)
    if total == 0:
        return EvalResult(
            score=None,
            passed=None,
            label="no-tools",
            rationale="No tool spans in trace — skipped.",
        )

    failed = 0
    for span in tools:
        if span.get("tool_is_error") or span.get("status") == "error":
            failed += 1

    rate = failed / total
    score = round(1 - rate, 3)
    verdict = score >= PASS_SCORE_THRESHOLD
    return EvalResult(
        score=score,
        passed=verdict,
        label=f"{failed}/{total} errors",
        rationale=f"Tool error rate: {rate:.1%} (pass threshold: {1 - PASS_SCORE_THRESHOLD:.0%}).",
        attrs={"tool_count": total, "error_count": failed},
    )
77
+
78
+
79
+ # --- Strands evaluator factory ---------------------------------------------
80
+
81
+ # Default Ollama host + model. Overridable via spool-agent settings row.
82
+ #
83
+ # We default to `qwen2.5:7b` because Strands evaluators use tool-calling
84
+ # under the hood to return structured output, and Ollama's gemma3 family
85
+ # does not expose tool support. The 7b size is the smallest local model
86
+ # that's reliable at structured output under pressure from the evaluator
87
+ # prompts. Users can override via the SPOOLING_JUDGE_MODEL env var or the
88
+ # `spool-agent` settings row (judge_model field).
89
# Base URL of the Ollama daemon; $SPOOLING_OLLAMA_HOST replaces the local default.
DEFAULT_OLLAMA_HOST = os.getenv("SPOOLING_OLLAMA_HOST", "http://localhost:11434")
# Judge model id; $SPOOLING_JUDGE_MODEL replaces the built-in default.
DEFAULT_JUDGE_MODEL = os.getenv("SPOOLING_JUDGE_MODEL", "qwen2.5:7b")
91
+
92
+
93
+ def _judge_config() -> dict:
94
+ """Load judge config from the spool-agent providers row, or fall back."""
95
+ try:
96
+ conn = get_connection()
97
+ row = conn.execute(
98
+ "SELECT config FROM providers WHERE id = 'spooling-agent'"
99
+ ).fetchone()
100
+ conn.close()
101
+ except Exception:
102
+ row = None
103
+ cfg = (row or {}).get("config") if row else {}
104
+ if isinstance(cfg, str):
105
+ try:
106
+ cfg = json.loads(cfg)
107
+ except Exception:
108
+ cfg = {}
109
+ return cfg or {}
110
+
111
+
112
+ def _pick_judge_model(cfg: dict) -> str:
113
+ """Pick the Ollama model to use as the Strands judge.
114
+
115
+ Precedence:
116
+ 1. `judge_model` field in the spool-agent settings row (explicit override)
117
+ 2. $SPOOLING_JUDGE_MODEL env var (via DEFAULT_JUDGE_MODEL)
118
+ 3. `qwen2.5:3b` default
119
+
120
+ Note: the `model` field on spool-agent is the *chat* model (often gemma),
121
+ not the judge model. Gemma doesn't support tool-calling so it can't serve
122
+ as a Strands judge even though it works fine for the chat page.
123
+ """
124
+ judge = (cfg or {}).get("judge_model")
125
+ if judge:
126
+ return judge
127
+ return DEFAULT_JUDGE_MODEL
128
+
129
+
130
def _make_ollama_model(cfg: dict):
    """Create the Strands ``OllamaModel`` used as the judge.

    The host comes from the `ollama_url` entry of ``cfg`` when present and
    truthy, otherwise from ``DEFAULT_OLLAMA_HOST``; the model id is whatever
    ``_pick_judge_model`` selects for the same config.
    """
    from strands.models.ollama import OllamaModel

    settings = cfg or {}
    return OllamaModel(
        host=settings.get("ollama_url") or DEFAULT_OLLAMA_HOST,
        model_id=_pick_judge_model(cfg),
    )
137
+
138
+
139
+ # Evaluators that auto-parse traces and require a Strands Session object
140
+ # as `actual_trajectory`. Everything else can accept a plain list of tool
141
+ # dicts or nothing at all.
142
+ _SESSION_EVALUATORS = {
143
+ "HelpfulnessEvaluator",
144
+ "CoherenceEvaluator",
145
+ "ConcisenessEvaluator",
146
+ "FaithfulnessEvaluator",
147
+ "HarmfulnessEvaluator",
148
+ "ResponseRelevanceEvaluator",
149
+ "ToolSelectionAccuracyEvaluator",
150
+ "ToolParameterAccuracyEvaluator",
151
+ "GoalSuccessRateEvaluator",
152
+ }
153
+
154
+
155
def _evaluator_factory(evaluator_type: str, rubric_text: Optional[str], model):
    """Build the Strands evaluator named ``evaluator_type`` around ``model``.

    The class is looked up by name on ``strands_evals.evaluators``.
    `OutputEvaluator` and `TrajectoryEvaluator` are additionally given a
    rubric string: ``rubric_text`` when truthy, otherwise a generic default.

    Raises:
        ValueError: if no attribute with that name exists on the module.
    """
    from strands_evals import evaluators as strands_evaluators

    evaluator_cls = getattr(strands_evaluators, evaluator_type, None)
    if evaluator_cls is None:
        raise ValueError(f"Unknown Strands evaluator: {evaluator_type}")

    # Only the two rubric-driven evaluators take a `rubric` argument.
    if evaluator_type not in ("OutputEvaluator", "TrajectoryEvaluator"):
        return evaluator_cls(model=model)

    fallback_rubric = (
        "Pass if the output directly and correctly addresses the user's "
        "request. Score 0-1 based on accuracy and completeness."
    )
    return evaluator_cls(rubric=rubric_text or fallback_rubric, model=model)
172
+
173
+
174
+ # --- Trace → Strands Session / EvaluationData extraction -----------------
175
+
176
+ # Local models fall over on very long evaluator prompts. Cap the number of
177
+ # tool spans we replay per Session so the judge stays reliable, and trim
178
+ # each tool's argument/output payload before handing it off.
179
+ _MAX_TOOL_SPANS_PER_TRACE = 20
180
+ _MAX_TOOL_OUTPUT_CHARS = 400
181
+ _MAX_TOOL_ARG_CHARS = 300
182
+
183
+
184
+ def _trim_tool_input(payload: dict) -> dict:
185
+ out = {}
186
+ for k, v in (payload or {}).items():
187
+ if isinstance(v, str) and len(v) > _MAX_TOOL_ARG_CHARS:
188
+ out[k] = v[:_MAX_TOOL_ARG_CHARS] + "…"
189
+ else:
190
+ out[k] = v
191
+ return out
192
+
193
+
194
def _build_strands_session(conn, trace_id: str):
    """Materialize a Strands-shaped Session from our stored spans.

    The Strands evaluators that auto-parse traces (HelpfulnessEvaluator,
    ToolSelectionAccuracyEvaluator, etc.) require an `actual_trajectory`
    that's a `Session` object, not a list. We translate each spool span
    into the corresponding Strands span type:

    - one synthetic root `AgentInvocationSpan` (first user message, last
      assistant message, the set of tool names seen in the trace),
    - an `InferenceSpan` per `llm_call` span,
    - a `ToolExecutionSpan` per `tool` span that has a tool name.

    Args:
        conn: open database connection; only read queries are issued.
        trace_id: id of the trace to convert.

    Returns:
        A `Session` holding a single trace, or ``None`` when the trace row
        does not exist.

    Raises:
        RuntimeError: if the root `AgentInvocationSpan` cannot be built.
    """
    from strands_evals.types.trace import (
        Session, Trace as StrandsTrace, InferenceSpan, ToolExecutionSpan,
        AgentInvocationSpan, ToolConfig,
        SpanInfo, SpanType, Role, ContentType,
        UserMessage, AssistantMessage, TextContent,
        ToolCall, ToolResult,
    )

    trace_row = conn.execute(
        "SELECT id, session_id, started_at, ended_at FROM traces WHERE id = %s",
        (trace_id,),
    ).fetchone()
    if not trace_row:
        return None
    # Traces without a session id use the trace id as the session key.
    session_id = trace_row["session_id"] or trace_id

    spans = conn.execute(
        """SELECT id, parent_id, kind, name, started_at, ended_at,
                  input_tokens, output_tokens, cache_read_tokens, cache_write_tokens,
                  model, tool_name, tool_input, tool_output, tool_is_error,
                  sequence
           FROM spans WHERE trace_id = %s ORDER BY sequence""",
        (trace_id,),
    ).fetchall()

    # Pull the message content keyed by session_id so we can attach it to
    # the inference spans. Blank messages are excluded by the query.
    msg_rows = conn.execute(
        """SELECT role, content, timestamp FROM messages
           WHERE session_id = %s
             AND COALESCE(length(trim(content)), 0) > 0
           ORDER BY timestamp ASC NULLS LAST""",
        (session_id,),
    ).fetchall()
    # Each message is capped at 2000 characters.
    user_messages = [m["content"][:2000] for m in msg_rows if m["role"] == "user"]
    assistant_messages = [m["content"][:2000] for m in msg_rows if m["role"] == "assistant"]

    def _ts(val) -> datetime:
        # Anything that is not already a datetime (including None) is
        # replaced by the current UTC time.
        if isinstance(val, datetime):
            return val
        return datetime.now(tz=timezone.utc)

    # Keep the span list tractable for the judge model: when the trace has
    # more than _MAX_TOOL_SPANS_PER_TRACE tool spans, keep only the first N
    # of them (non-tool spans are never dropped).
    tool_span_rows = [sp for sp in spans if sp["kind"] == "tool"]
    if len(tool_span_rows) > _MAX_TOOL_SPANS_PER_TRACE:
        keep_tool_ids = {sp["id"] for sp in tool_span_rows[:_MAX_TOOL_SPANS_PER_TRACE]}
        spans = [sp for sp in spans if sp["kind"] != "tool" or sp["id"] in keep_tool_ids]

    built_spans = []
    user_ix = 0
    assistant_ix = 0

    # Top-level AgentInvocationSpan wrapping the whole conversation. Trace-
    # level Strands evaluators (HelpfulnessEvaluator etc.) require at least
    # one of these to anchor the agent-turn input/output.
    if spans:
        top_start = _ts(trace_row["started_at"] or spans[0]["started_at"])
        top_end = _ts(trace_row["ended_at"] or spans[-1]["ended_at"] or spans[-1]["started_at"])
        top_info = SpanInfo(
            trace_id=trace_id,
            span_id=f"{trace_id}-root",
            session_id=session_id,
            parent_span_id=None,
            start_time=top_start,
            end_time=top_end,
        )
        first_user = user_messages[0] if user_messages else "(no user input)"
        last_assistant = assistant_messages[-1] if assistant_messages else "(no assistant output)"
        tool_names = sorted({sp["tool_name"] for sp in spans if sp.get("tool_name")})
        try:
            built_spans.append(AgentInvocationSpan(
                span_info=top_info,
                metadata={},
                span_type=SpanType.AGENT_INVOCATION,
                user_prompt=first_user,
                agent_response=last_assistant,
                available_tools=[ToolConfig(name=n) for n in tool_names],
            ))
        except Exception as e:
            # Bubble the error up so the caller reports it instead of
            # silently running on an empty session.
            raise RuntimeError(f"AgentInvocationSpan build failed: {e}") from e

    for sp in spans:
        start = _ts(sp["started_at"])
        end = _ts(sp["ended_at"] or sp["started_at"])
        span_info = SpanInfo(
            trace_id=trace_id,
            span_id=sp["id"],
            session_id=session_id,
            parent_span_id=sp["parent_id"],
            start_time=start,
            end_time=end,
        )

        if sp["kind"] == "llm_call":
            # The n-th llm_call span is paired with the n-th user message
            # and the n-th assistant message; placeholders fill any gap.
            user_text = user_messages[user_ix] if user_ix < len(user_messages) else "(no user message)"
            assistant_text = (
                assistant_messages[assistant_ix] if assistant_ix < len(assistant_messages)
                else "(no assistant output)"
            )
            user_ix += 1
            assistant_ix += 1

            messages = [
                UserMessage(role=Role.USER, content=[TextContent(content_type=ContentType.TEXT, text=user_text)]),
                AssistantMessage(role=Role.ASSISTANT, content=[TextContent(content_type=ContentType.TEXT, text=assistant_text)]),
            ]
            try:
                built_spans.append(InferenceSpan(
                    span_info=span_info,
                    metadata={"model": sp.get("model") or ""},
                    span_type=SpanType.INFERENCE,
                    messages=messages,
                ))
            except Exception:
                # If InferenceSpan pydantic validation rejects the shape on a
                # particular version, skip the span rather than fail the trace.
                continue

        elif sp["kind"] == "tool" and sp.get("tool_name"):
            tool_input = sp.get("tool_input")
            if isinstance(tool_input, str):
                try:
                    tool_input = json.loads(tool_input)
                except Exception:
                    tool_input = {"raw": tool_input}
            if not isinstance(tool_input, dict):
                tool_input = {}

            tool_output = sp.get("tool_output") or ""
            if not isinstance(tool_output, str):
                # Presumably the driver can hand back decoded JSON here, as
                # it can for tool_input — TODO confirm. Slicing a non-string
                # would raise and fail the whole trace, so serialise first.
                tool_output = json.dumps(tool_output, default=str)
            is_error = bool(sp.get("tool_is_error"))
            # ToolResult.error is a string (error message), not a bool.
            # On success we pass None; on failure we reuse the tool output.
            error_text = tool_output[:_MAX_TOOL_OUTPUT_CHARS] if is_error else None
            content_text = "" if is_error else tool_output[:_MAX_TOOL_OUTPUT_CHARS]
            built_spans.append(ToolExecutionSpan(
                span_info=span_info,
                metadata={},
                span_type=SpanType.TOOL_EXECUTION,
                tool_call=ToolCall(
                    name=sp["tool_name"],
                    arguments=_trim_tool_input(tool_input),
                    tool_call_id=sp["id"],
                ),
                tool_result=ToolResult(
                    content=content_text,
                    error=error_text,
                    tool_call_id=sp["id"],
                ),
            ))

    strands_trace = StrandsTrace(
        spans=built_spans,
        trace_id=trace_id,
        session_id=session_id,
    )
    return Session(traces=[strands_trace], session_id=session_id)
363
+
364
+
365
def _extract_evaluation_data(conn, trace_id: str, target: dict, needs_session: bool):
    """Assemble a Strands ``EvaluationData`` for one trace.

    - input: content of the earliest non-blank user message in the trace's
      session, or a placeholder; capped at 4000 characters
    - actual_output: content of the latest non-blank assistant message, or
      a placeholder; capped at 4000 characters
    - actual_trajectory: the result of ``_build_strands_session`` when
      ``needs_session`` is true; otherwise a list of ``{"tool": name}``
      dicts in span sequence order, or ``None`` when there are no tools
    - metadata: the trace id plus ``provider_id`` / ``project`` from ``target``
    """
    from strands_evals.types import EvaluationData

    first_user_sql = """SELECT content FROM messages
           WHERE session_id = (SELECT session_id FROM traces WHERE id = %s)
             AND role = 'user'
             AND COALESCE(length(trim(content)), 0) > 0
           ORDER BY timestamp ASC NULLS LAST LIMIT 1"""
    last_assistant_sql = """SELECT content FROM messages
           WHERE session_id = (SELECT session_id FROM traces WHERE id = %s)
             AND role = 'assistant'
             AND COALESCE(length(trim(content)), 0) > 0
           ORDER BY timestamp DESC NULLS LAST LIMIT 1"""

    user_row = conn.execute(first_user_sql, (trace_id,)).fetchone()
    prompt = (user_row.get("content") if user_row else None) or "(no user input)"

    assistant_row = conn.execute(last_assistant_sql, (trace_id,)).fetchone()
    answer = (assistant_row.get("content") if assistant_row else None) or "(no assistant output)"

    trajectory: Any
    if not needs_session:
        name_rows = conn.execute(
            """SELECT tool_name FROM spans
               WHERE trace_id = %s AND kind = 'tool' AND tool_name IS NOT NULL
               ORDER BY sequence""",
            (trace_id,),
        ).fetchall()
        steps = [{"tool": row["tool_name"]} for row in name_rows]
        # An empty tool list is reported as "no trajectory".
        trajectory = steps if steps else None
    else:
        trajectory = _build_strands_session(conn, trace_id)

    metadata = {
        "trace_id": trace_id,
        "provider_id": target.get("provider_id"),
        "project": target.get("project"),
    }
    return EvaluationData(
        input=prompt[:4000],
        actual_output=answer[:4000],
        actual_trajectory=trajectory,
        metadata=metadata,
    )
421
+
422
+
423
+ # --- Loaders ----------------------------------------------------------------
424
+
425
+ def _load_rubric(conn, rubric_id: str) -> dict | None:
426
+ row = conn.execute(
427
+ """SELECT id, name, description, kind, target_kind,
428
+ evaluator_type, rubric_text, model_id, config
429
+ FROM eval_rubrics WHERE id = %s""",
430
+ (rubric_id,),
431
+ ).fetchone()
432
+ return dict(row) if row else None
433
+
434
+
435
+ def _target_row(conn, rubric: dict, trace_id: str, span_id: Optional[str]) -> dict | None:
436
+ if rubric["target_kind"] == "trace":
437
+ row = conn.execute("SELECT * FROM traces WHERE id = %s", (trace_id,)).fetchone()
438
+ if not row:
439
+ return None
440
+ target = dict(row)
441
+ target["_target_kind"] = "trace"
442
+ return target
443
+ sid = span_id
444
+ if not sid:
445
+ cfg = rubric.get("config") or {}
446
+ if isinstance(cfg, str):
447
+ cfg = json.loads(cfg)
448
+ wanted = cfg.get("span_kind") if isinstance(cfg, dict) else None
449
+ if wanted:
450
+ row = conn.execute(
451
+ "SELECT * FROM spans WHERE trace_id = %s AND kind = %s ORDER BY sequence LIMIT 1",
452
+ (trace_id, wanted),
453
+ ).fetchone()
454
+ if row:
455
+ sid = row["id"]
456
+ if not sid:
457
+ return None
458
+ row = conn.execute("SELECT * FROM spans WHERE id = %s", (sid,)).fetchone()
459
+ if not row:
460
+ return None
461
+ target = dict(row)
462
+ target["_target_kind"] = "span"
463
+ return target
464
+
465
+
466
+ def _trace_children(conn, trace_id: str) -> list[dict]:
467
+ rows = conn.execute(
468
+ "SELECT * FROM spans WHERE trace_id = %s ORDER BY sequence", (trace_id,)
469
+ ).fetchall()
470
+ return [dict(r) for r in rows]
471
+
472
+
473
+ def _span_descendants(conn, span_id: str) -> list[dict]:
474
+ rows = conn.execute(
475
+ """WITH RECURSIVE descendants AS (
476
+ SELECT * FROM spans WHERE id = %s
477
+ UNION ALL
478
+ SELECT s.* FROM spans s JOIN descendants d ON s.parent_id = d.id
479
+ )
480
+ SELECT * FROM descendants WHERE id <> %s""",
481
+ (span_id, span_id),
482
+ ).fetchall()
483
+ return [dict(r) for r in rows]
484
+
485
+
486
+ # --- Runners ----------------------------------------------------------------
487
+
488
def _run_strands_evaluator(conn, rubric: dict, target: dict, trace_id: str) -> EvalResult:
    """Run the rubric's Strands evaluator and map its first output to an EvalResult.

    The stages are: build the judge model, instantiate the evaluator, extract
    the evaluation data, evaluate. A failure in any stage is not raised; it is
    returned as an ``EvalResult`` with no score and a label naming the stage
    (`model-init-error`, `factory-error`, `extract-error`, `ollama-down` /
    `judge-error`, `no-output`). This function itself writes nothing to the
    database; ``conn`` is only handed to the data extraction step.
    """

    def _aborted(label: str, exc: Exception) -> EvalResult:
        # Setup-stage failures keep the first 300 characters of the message.
        return EvalResult(score=None, passed=None, label=label, rationale=str(exc)[:300])

    try:
        cfg = _judge_config()
        model = _make_ollama_model(cfg)
        model_id = _pick_judge_model(cfg)
    except Exception as exc:
        return _aborted("model-init-error", exc)

    evaluator_type = rubric["evaluator_type"]
    try:
        evaluator = _evaluator_factory(evaluator_type, rubric.get("rubric_text"), model)
    except Exception as exc:
        return _aborted("factory-error", exc)

    needs_session = evaluator_type in _SESSION_EVALUATORS
    try:
        data = _extract_evaluation_data(conn, trace_id, target, needs_session=needs_session)
    except Exception as exc:
        return _aborted("extract-error", exc)

    try:
        outputs = evaluator.evaluate(data)
    except Exception as exc:
        detail = str(exc)
        # Connection failures are labelled separately from other judge errors.
        if "ConnectError" in detail or "Connection refused" in detail:
            label = "ollama-down"
        else:
            label = "judge-error"
        return EvalResult(
            score=None,
            passed=None,
            label=label,
            rationale=detail[:400],
            attrs={"judge_model": model_id},
        )

    if not outputs:
        return EvalResult(score=None, passed=None, label="no-output", rationale="Evaluator returned no outputs.")

    # Only the first output is used.
    first = outputs[0]
    score = None if first.score is None else float(first.score)
    verdict = None if first.test_pass is None else bool(first.test_pass)
    return EvalResult(
        score=score,
        passed=verdict,
        label=first.label or "judged",
        rationale=(first.reason or "")[:2000],
        attrs={"judge_model": model_id, "judge_cost_usd": 0.0},
    )
533
+
534
+
535
def run_rubric(
    rubric_id: str,
    trace_id: str,
    span_id: Optional[str] = None,
) -> Optional[int]:
    """Run one rubric against a trace or span and store the result.

    Function rubrics are dispatched to the grader registered under
    ``rubric_id``; `llm_judge` rubrics go through the Strands evaluator path.
    The outcome is inserted into the `evals` table and committed.

    Returns:
        The id of the inserted `evals` row, or ``None`` when nothing was run
        (unknown rubric, unresolvable target, no registered grader, missing
        `evaluator_type`, or an unrecognised rubric kind).
    """
    conn = get_connection()
    try:
        rubric = _load_rubric(conn, rubric_id)
        if not rubric:
            return None

        target = _target_row(conn, rubric, trace_id, span_id)
        if not target:
            return None

        # Span targets are stored under the span's own trace id and span id;
        # trace targets have no span id.
        if target["_target_kind"] == "trace":
            is_trace = True
            resolved_trace_id = trace_id
            resolved_span_id = None
        else:
            is_trace = False
            resolved_trace_id = target.get("trace_id")
            resolved_span_id = target.get("id")

        kind = rubric["kind"]
        if kind == "function":
            grader = _FUNCTION_RUBRICS.get(rubric_id)
            if not grader:
                return None
            if is_trace:
                children = _trace_children(conn, trace_id)
            else:
                children = _span_descendants(conn, target["id"])
            result = grader(target, children)
        elif kind == "llm_judge":
            if not rubric.get("evaluator_type"):
                return None
            result = _run_strands_evaluator(conn, rubric, target, resolved_trace_id)
        else:
            return None

        attrs = result.attrs or {}
        params = (
            rubric_id, resolved_trace_id, resolved_span_id,
            result.score, result.passed, result.label, result.rationale,
            attrs.get("judge_model"),
            float(attrs.get("judge_cost_usd", 0.0)),
            json.dumps(attrs),
        )
        inserted = conn.execute(
            """INSERT INTO evals (
                   rubric_id, trace_id, span_id, score, passed, label, rationale,
                   judge_model, judge_cost_usd, attrs
               ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
               RETURNING id""",
            params,
        ).fetchone()
        conn.commit()
        if not inserted:
            return None
        return inserted["id"]
    finally:
        conn.close()
586
+
587
+
588
def run_rubric_bulk(rubric_id: str, since: Optional[datetime] = None) -> dict:
    """Run a rubric against every trace, newest first.

    When ``since`` is given only traces whose `started_at` is at or after it
    are considered. The trace ids are collected first and the listing
    connection is closed before ``run_rubric`` is called once per trace.

    Returns:
        A summary dict with the rubric id, the number of traces visited, how
        many produced an eval row ("scored") and how many did not ("skipped").
    """
    conn = get_connection()
    try:
        if since:
            cursor = conn.execute(
                "SELECT id FROM traces WHERE started_at >= %s ORDER BY started_at DESC",
                (since,),
            )
        else:
            cursor = conn.execute("SELECT id FROM traces ORDER BY started_at DESC")
        trace_ids = [record["id"] for record in cursor.fetchall()]
    finally:
        conn.close()

    # run_rubric returns None whenever it did not store a result.
    scored = sum(1 for tid in trace_ids if run_rubric(rubric_id, tid) is not None)
    total = len(trace_ids)
    return {
        "rubric": rubric_id,
        "traces": total,
        "scored": scored,
        "skipped": total - scored,
    }