contexttrace 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,604 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import sqlite3
5
+ import time
6
+ import uuid
7
+ from pathlib import Path
8
+ from typing import Any, Optional
9
+
10
+
11
# Version string written to the ``schema_meta`` table by ``_init_db``.
SCHEMA_VERSION = 1


class _ClosingConnection(sqlite3.Connection):
    """Connection whose ``with`` block also closes it.

    The stock ``sqlite3.Connection`` context manager only commits or rolls
    back; it leaves the connection open. The store opens one connection per
    operation, so without this every call would leak a connection.
    """

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            # Commit on success / roll back on error, exactly as the base class.
            return super().__exit__(exc_type, exc_value, traceback)
        finally:
            self.close()


class SQLiteTraceStore:
    """SQLite-backed store for traces and everything attached to them.

    Persists traces, chunks, answers, citation checks, failure reports,
    agent events, eval runs and eval questions. Each public method opens its
    own short-lived connection, which is committed and closed when the
    method's ``with`` block ends.
    """

    def __init__(self, path: str = ".contexttrace/contexttrace.db") -> None:
        """Create the parent directory if needed and initialise the schema."""
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def create_trace(self, *, project: str, query: str, metadata: dict[str, Any]) -> dict[str, Any]:
        """Insert a new trace with status ``started`` and return it as a dict."""
        trace_id = _new_id("trace")
        project_id = _new_id("project")
        now = _now()
        with self._connect() as db:
            db.execute(
                """
                INSERT INTO traces (id, project_id, project, query, metadata_json, status, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (trace_id, project_id, project, query, _json(metadata), "started", now, now),
            )
        return self.get_trace(trace_id)

    def upsert_chunks(self, trace_id: str, chunks: list[dict[str, Any]], *, selected: bool) -> int:
        """Insert or update chunks for a trace and return ``len(chunks)``.

        A chunk is identified by its ``chunk_id`` key, falling back to ``id``
        and then to a generated ``chunk_<position>`` name. Updating an existing
        chunk never clears its ``selected`` flag; it can only be turned on.
        """
        with self._connect() as db:
            current_count = db.execute(
                "SELECT COUNT(*) FROM chunks WHERE trace_id = ?",
                (trace_id,),
            ).fetchone()[0]
            for offset, chunk in enumerate(chunks):
                position = current_count + offset
                chunk_id = str(chunk.get("chunk_id") or chunk.get("id") or f"chunk_{position}")
                content = str(chunk.get("content") or "")
                metadata_json = _json(chunk.get("metadata") or {})
                # One timestamp per row so created_at == updated_at on insert.
                now = _now()
                existing = db.execute(
                    "SELECT id, selected FROM chunks WHERE trace_id = ? AND chunk_id = ?",
                    (trace_id, chunk_id),
                ).fetchone()
                if existing is not None:
                    db.execute(
                        """
                        UPDATE chunks
                        SET content = ?, source = ?, metadata_json = ?, relevance_score = ?,
                            selected = ?, updated_at = ?
                        WHERE id = ?
                        """,
                        (
                            content,
                            chunk.get("source"),
                            metadata_json,
                            chunk.get("relevance_score"),
                            bool(existing["selected"]) or selected,
                            now,
                            existing["id"],
                        ),
                    )
                else:
                    db.execute(
                        """
                        INSERT INTO chunks (
                            id, trace_id, chunk_id, content, source, metadata_json,
                            relevance_score, position, selected, created_at, updated_at
                        )
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                        """,
                        (
                            _new_id("chunk"),
                            trace_id,
                            chunk_id,
                            content,
                            chunk.get("source"),
                            metadata_json,
                            chunk.get("relevance_score"),
                            position,
                            selected,
                            now,
                            now,
                        ),
                    )
            self._set_status(db, trace_id, "context_logged" if selected else "retrieval_logged")
        return len(chunks)

    def mark_context(self, trace_id: str, chunk_ids: list[str]) -> int:
        """Flag the given chunks as selected.

        Returns the number of rows for this trace whose ``chunk_id`` is in
        *chunk_ids* and that are selected afterwards. An empty *chunk_ids*
        returns 0 without touching the database.
        """
        if not chunk_ids:
            return 0
        with self._connect() as db:
            # Only "?" placeholders are interpolated; values stay parameterised.
            placeholders = ",".join("?" for _ in chunk_ids)
            params = [trace_id, *chunk_ids]
            db.execute(
                f"UPDATE chunks SET selected = 1, updated_at = ? WHERE trace_id = ? AND chunk_id IN ({placeholders})",
                [_now(), *params],
            )
            accepted = db.execute(
                f"SELECT COUNT(*) FROM chunks WHERE trace_id = ? AND chunk_id IN ({placeholders}) AND selected = 1",
                params,
            ).fetchone()[0]
            self._set_status(db, trace_id, "context_logged")
        return int(accepted)

    def save_answer(self, trace_id: str, answer: dict[str, Any]) -> None:
        """Insert or overwrite the single answer row of a trace."""
        with self._connect() as db:
            existing = db.execute("SELECT id FROM answers WHERE trace_id = ?", (trace_id,)).fetchone()
            now = _now()
            fields = (
                str(answer.get("answer") or ""),
                answer.get("model"),
                _json(answer.get("usage") or {}),
                _json(answer.get("metadata") or {}),
            )
            if existing is not None:
                db.execute(
                    """
                    UPDATE answers
                    SET answer = ?, model = ?, usage_json = ?, metadata_json = ?, updated_at = ?
                    WHERE trace_id = ?
                    """,
                    (*fields, now, trace_id),
                )
            else:
                db.execute(
                    """
                    INSERT INTO answers (id, trace_id, answer, model, usage_json, metadata_json, created_at, updated_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (_new_id("answer"), trace_id, *fields, now, now),
                )
            self._set_status(db, trace_id, "answer_logged")

    def save_citations(self, trace_id: str, citations: list[dict[str, Any]]) -> int:
        """Replace the trace's citation checks with ``pending`` rows.

        Each citation must provide ``claim`` and ``source_chunk_id``; a missing
        key raises ``KeyError``. Returns ``len(citations)``.
        """
        with self._connect() as db:
            db.execute("DELETE FROM citation_checks WHERE trace_id = ?", (trace_id,))
            for citation in citations:
                now = _now()
                db.execute(
                    """
                    INSERT INTO citation_checks (
                        id, trace_id, claim, source_chunk_id, support_status,
                        support_score, rationale, metadata_json, created_at, updated_at
                    )
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        _new_id("check"),
                        trace_id,
                        citation["claim"],
                        citation["source_chunk_id"],
                        "pending",
                        None,
                        None,
                        _json(citation.get("metadata") or {}),
                        now,
                        now,
                    ),
                )
            self._set_status(db, trace_id, "citations_logged")
        return len(citations)

    def save_evaluation(self, trace_id: str, evaluation: dict[str, Any]) -> None:
        """Store evaluation verdicts and the failure report of a trace.

        Entries of ``evaluation["citation_checks"]`` update the stored checks
        that match on claim and source chunk id. The failure report is inserted
        or overwritten (one row per trace).
        """
        failure = evaluation.get("failure") or {}
        with self._connect() as db:
            for check in evaluation.get("citation_checks") or []:
                db.execute(
                    """
                    UPDATE citation_checks
                    SET support_status = ?, support_score = ?, rationale = ?, updated_at = ?
                    WHERE trace_id = ? AND claim = ? AND source_chunk_id = ?
                    """,
                    (
                        check.get("verdict"),
                        check.get("support_score"),
                        check.get("reason"),
                        _now(),
                        trace_id,
                        check.get("claim"),
                        check.get("source_chunk_id"),
                    ),
                )
            existing = db.execute("SELECT id FROM failure_reports WHERE trace_id = ?", (trace_id,)).fetchone()
            now = _now()
            fields = (
                failure.get("failure_type") or "unknown",
                failure.get("severity") or "medium",
                failure.get("root_cause") or "",
                failure.get("suggested_fix") or "",
                _json(evaluation.get("scores") or {}),
                _json(evaluation.get("reliability") or {}),
                _json(evaluation.get("metadata") or {}),
            )
            if existing is not None:
                db.execute(
                    """
                    UPDATE failure_reports
                    SET failure_type = ?, severity = ?, root_cause = ?, suggested_fix = ?,
                        scores_json = ?, reliability_json = ?, metadata_json = ?, updated_at = ?
                    WHERE trace_id = ?
                    """,
                    (*fields, now, trace_id),
                )
            else:
                db.execute(
                    """
                    INSERT INTO failure_reports (
                        id, trace_id, failure_type, severity, root_cause, suggested_fix,
                        scores_json, reliability_json, metadata_json, created_at, updated_at
                    )
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (_new_id("failure"), trace_id, *fields, now, now),
                )
            self._set_status(db, trace_id, "evaluated")

    def add_agent_event(self, trace_id: str, payload: dict[str, Any]) -> dict[str, Any]:
        """Append an agent event and return its receipt.

        ``payload["event_type"]`` is required (``KeyError`` otherwise). Falsy
        but non-``None`` ``input_json`` / ``output_json`` values are stored
        as-is; only ``None`` is replaced by ``{}``.
        """
        event_id = _new_id("agent_event")
        input_value = payload.get("input_json")
        output_value = payload.get("output_json")
        with self._connect() as db:
            db.execute(
                """
                INSERT INTO agent_events (
                    id, trace_id, event_type, name, input_json, output_json,
                    metadata_json, latency_ms, error_message, created_at
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    event_id,
                    trace_id,
                    payload["event_type"],
                    payload.get("name"),
                    _json(input_value if input_value is not None else {}),
                    _json(output_value if output_value is not None else {}),
                    _json(payload.get("metadata_json") or {}),
                    payload.get("latency_ms"),
                    payload.get("error_message"),
                    _now(),
                ),
            )
            self._set_status(db, trace_id, "agent_event_logged")
        return {"trace_id": trace_id, "event_id": event_id, "accepted": 1}

    def create_eval_run(self, *, dataset: str, endpoint: Optional[str], summary: dict[str, Any]) -> str:
        """Insert an eval run and return its generated id."""
        run_id = _new_id("eval_run")
        with self._connect() as db:
            db.execute(
                """
                INSERT INTO eval_runs (id, dataset, endpoint, summary_json, created_at)
                VALUES (?, ?, ?, ?, ?)
                """,
                (run_id, dataset, endpoint, _json(summary), _now()),
            )
        return run_id

    def add_eval_question(self, *, eval_run_id: str, question: dict[str, Any], trace_id: Optional[str], position: int) -> None:
        """Attach a question (and optionally its trace) to an eval run."""
        with self._connect() as db:
            db.execute(
                """
                INSERT INTO eval_questions (id, eval_run_id, trace_id, question_json, position, created_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (_new_id("eval_question"), eval_run_id, trace_id, _json(question), position, _now()),
            )

    def last_eval_run(self) -> Optional[dict[str, Any]]:
        """Return the most recently created eval run, or ``None`` if none exist."""
        with self._connect() as db:
            # rowid breaks ties between runs created within the same second.
            row = db.execute(
                "SELECT * FROM eval_runs ORDER BY created_at DESC, rowid DESC LIMIT 1"
            ).fetchone()
        return _eval_run_from_row(row) if row else None

    def get_eval_run(self, eval_run_id: str) -> dict[str, Any]:
        """Return an eval run with its questions ordered by position.

        Raises:
            KeyError: if no eval run has the given id.
        """
        with self._connect() as db:
            row = db.execute("SELECT * FROM eval_runs WHERE id = ?", (eval_run_id,)).fetchone()
            if row is None:
                raise KeyError("Eval run not found: %s" % eval_run_id)
            questions = db.execute(
                "SELECT * FROM eval_questions WHERE eval_run_id = ? ORDER BY position, rowid",
                (eval_run_id,),
            ).fetchall()
        result = _eval_run_from_row(row)
        result["questions"] = [
            {
                "id": question["id"],
                "trace_id": question["trace_id"],
                "question": _loads(question["question_json"], {}),
                "position": question["position"],
                "created_at": question["created_at"],
            }
            for question in questions
        ]
        return result

    def list_eval_runs(self, *, limit: int = 20) -> list[dict[str, Any]]:
        """Return up to *limit* eval runs, newest first."""
        with self._connect() as db:
            rows = db.execute(
                "SELECT * FROM eval_runs ORDER BY created_at DESC, rowid DESC LIMIT ?",
                (limit,),
            ).fetchall()
        return [_eval_run_from_row(row) for row in rows]

    def get_trace(self, trace_id: str) -> dict[str, Any]:
        """Return a trace with its chunks, answer, checks, report and events.

        Raises:
            KeyError: if no trace has the given id.
        """
        with self._connect() as db:
            trace = db.execute("SELECT * FROM traces WHERE id = ?", (trace_id,)).fetchone()
            if trace is None:
                raise KeyError("Trace not found: %s" % trace_id)
            # Timestamps have one-second resolution, so rowid keeps rows that
            # were written within the same second in insertion order.
            chunks = db.execute(
                "SELECT * FROM chunks WHERE trace_id = ? ORDER BY position, created_at, rowid",
                (trace_id,),
            ).fetchall()
            answer = db.execute("SELECT * FROM answers WHERE trace_id = ?", (trace_id,)).fetchone()
            checks = db.execute(
                "SELECT * FROM citation_checks WHERE trace_id = ? ORDER BY created_at, rowid",
                (trace_id,),
            ).fetchall()
            failure = db.execute("SELECT * FROM failure_reports WHERE trace_id = ?", (trace_id,)).fetchone()
            events = db.execute(
                "SELECT * FROM agent_events WHERE trace_id = ? ORDER BY created_at, rowid",
                (trace_id,),
            ).fetchall()
        return _trace_from_rows(trace, chunks, answer, checks, failure, events)

    def list_traces(self, *, limit: int = 20) -> list[dict[str, Any]]:
        """Return up to *limit* fully loaded traces, most recently updated first."""
        with self._connect() as db:
            rows = db.execute(
                "SELECT id FROM traces ORDER BY updated_at DESC, rowid DESC LIMIT ?",
                (limit,),
            ).fetchall()
        return [self.get_trace(row["id"]) for row in rows]

    def last_trace(self) -> Optional[dict[str, Any]]:
        """Return the most recently updated trace, or ``None`` if none exist."""
        traces = self.list_traces(limit=1)
        return traces[0] if traces else None

    def trace_count(self) -> int:
        """Return the number of stored traces."""
        with self._connect() as db:
            return int(db.execute("SELECT COUNT(*) FROM traces").fetchone()[0])

    def _init_db(self) -> None:
        """Switch to WAL journaling, create missing tables, record the schema version."""
        with self._connect() as db:
            db.execute("PRAGMA journal_mode=WAL")
            db.execute(
                "CREATE TABLE IF NOT EXISTS schema_meta (key TEXT PRIMARY KEY, value TEXT NOT NULL)"
            )
            db.execute(
                """
                CREATE TABLE IF NOT EXISTS traces (
                    id TEXT PRIMARY KEY,
                    project_id TEXT NOT NULL,
                    project TEXT NOT NULL,
                    query TEXT NOT NULL,
                    metadata_json TEXT NOT NULL,
                    status TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )
                """
            )
            db.execute(
                """
                CREATE TABLE IF NOT EXISTS chunks (
                    id TEXT PRIMARY KEY,
                    trace_id TEXT NOT NULL,
                    chunk_id TEXT NOT NULL,
                    content TEXT NOT NULL,
                    source TEXT,
                    metadata_json TEXT NOT NULL,
                    relevance_score REAL,
                    position INTEGER NOT NULL,
                    selected INTEGER NOT NULL DEFAULT 0,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    UNIQUE(trace_id, chunk_id)
                )
                """
            )
            db.execute(
                """
                CREATE TABLE IF NOT EXISTS answers (
                    id TEXT PRIMARY KEY,
                    trace_id TEXT NOT NULL UNIQUE,
                    answer TEXT NOT NULL,
                    model TEXT,
                    usage_json TEXT NOT NULL,
                    metadata_json TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )
                """
            )
            db.execute(
                """
                CREATE TABLE IF NOT EXISTS citation_checks (
                    id TEXT PRIMARY KEY,
                    trace_id TEXT NOT NULL,
                    claim TEXT NOT NULL,
                    source_chunk_id TEXT NOT NULL,
                    support_status TEXT NOT NULL,
                    support_score REAL,
                    rationale TEXT,
                    metadata_json TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )
                """
            )
            db.execute(
                """
                CREATE TABLE IF NOT EXISTS failure_reports (
                    id TEXT PRIMARY KEY,
                    trace_id TEXT NOT NULL UNIQUE,
                    failure_type TEXT NOT NULL,
                    severity TEXT NOT NULL,
                    root_cause TEXT NOT NULL,
                    suggested_fix TEXT NOT NULL,
                    scores_json TEXT NOT NULL,
                    reliability_json TEXT NOT NULL,
                    metadata_json TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )
                """
            )
            db.execute(
                """
                CREATE TABLE IF NOT EXISTS agent_events (
                    id TEXT PRIMARY KEY,
                    trace_id TEXT NOT NULL,
                    event_type TEXT NOT NULL,
                    name TEXT,
                    input_json TEXT NOT NULL,
                    output_json TEXT NOT NULL,
                    metadata_json TEXT NOT NULL,
                    latency_ms REAL,
                    error_message TEXT,
                    created_at TEXT NOT NULL
                )
                """
            )
            db.execute(
                """
                CREATE TABLE IF NOT EXISTS eval_runs (
                    id TEXT PRIMARY KEY,
                    dataset TEXT NOT NULL,
                    endpoint TEXT,
                    summary_json TEXT NOT NULL,
                    created_at TEXT NOT NULL
                )
                """
            )
            db.execute(
                """
                CREATE TABLE IF NOT EXISTS eval_questions (
                    id TEXT PRIMARY KEY,
                    eval_run_id TEXT NOT NULL,
                    trace_id TEXT,
                    question_json TEXT NOT NULL,
                    position INTEGER NOT NULL,
                    created_at TEXT NOT NULL
                )
                """
            )
            db.execute(
                "INSERT OR REPLACE INTO schema_meta (key, value) VALUES ('schema_version', ?)",
                (str(SCHEMA_VERSION),),
            )

    def _connect(self) -> sqlite3.Connection:
        """Open a connection that yields ``sqlite3.Row`` rows.

        The returned connection is meant to be used as ``with self._connect()
        as db:``; leaving that block commits (or rolls back on error) and then
        closes the connection.
        """
        db = sqlite3.connect(str(self.path), factory=_ClosingConnection)
        db.row_factory = sqlite3.Row
        return db

    def _set_status(self, db: sqlite3.Connection, trace_id: str, status: str) -> None:
        """Set a trace's status and bump its ``updated_at`` on the given connection."""
        db.execute(
            "UPDATE traces SET status = ?, updated_at = ? WHERE id = ?",
            (status, _now(), trace_id),
        )
478
+
479
+
480
def _trace_from_rows(
    trace: sqlite3.Row,
    chunks: list[sqlite3.Row],
    answer: Optional[sqlite3.Row],
    checks: list[sqlite3.Row],
    failure: Optional[sqlite3.Row],
    events: list[sqlite3.Row],
) -> dict[str, Any]:
    """Assemble the dict representation of a trace from its database rows.

    ``answer`` and ``evaluation`` are ``None`` when the corresponding row is
    absent. The evaluation only lists citation checks whose status is no
    longer ``"pending"``.
    """
    check_items = []
    for check_row in checks:
        check_items.append(
            {
                "id": check_row["id"],
                "claim": check_row["claim"],
                "source_chunk_id": check_row["source_chunk_id"],
                "support_status": check_row["support_status"],
                "support_score": check_row["support_score"],
                "rationale": check_row["rationale"],
                "metadata": _loads(check_row["metadata_json"], {}),
            }
        )

    if failure is None:
        evaluation_item = None
    else:
        resolved = []
        for item in check_items:
            if item["support_status"] == "pending":
                continue
            resolved.append(
                {
                    "claim": item["claim"],
                    "source_chunk_id": item["source_chunk_id"],
                    "verdict": item["support_status"],
                    # Missing score / rationale are reported as 0.0 / "".
                    "support_score": item["support_score"] or 0.0,
                    "reason": item["rationale"] or "",
                }
            )
        evaluation_item = {
            "scores": _loads(failure["scores_json"], {}),
            "reliability": _loads(failure["reliability_json"], {}),
            "citation_checks": resolved,
            "failure": {
                "failure_type": failure["failure_type"],
                "severity": failure["severity"],
                "root_cause": failure["root_cause"],
                "suggested_fix": failure["suggested_fix"],
            },
        }

    chunk_items = []
    for chunk_row in chunks:
        chunk_items.append(
            {
                "id": chunk_row["id"],
                "chunk_id": chunk_row["chunk_id"],
                "content": chunk_row["content"],
                "source": chunk_row["source"],
                "metadata": _loads(chunk_row["metadata_json"], {}),
                "relevance_score": chunk_row["relevance_score"],
                "position": chunk_row["position"],
                "selected": bool(chunk_row["selected"]),
            }
        )

    answer_item = None
    if answer is not None:
        answer_item = {
            "id": answer["id"],
            "answer": answer["answer"],
            "model": answer["model"],
            "usage": _loads(answer["usage_json"], {}),
            "metadata": _loads(answer["metadata_json"], {}),
        }

    event_items = []
    for event_row in events:
        event_items.append(
            {
                "id": event_row["id"],
                "trace_id": event_row["trace_id"],
                "event_type": event_row["event_type"],
                "name": event_row["name"],
                "input_json": _loads(event_row["input_json"], {}),
                "output_json": _loads(event_row["output_json"], {}),
                "metadata_json": _loads(event_row["metadata_json"], {}),
                "latency_ms": event_row["latency_ms"],
                "error_message": event_row["error_message"],
                "created_at": event_row["created_at"],
            }
        )

    payload = {
        "id": trace["id"],
        "project_id": trace["project_id"],
        "project": trace["project"],
        "query": trace["query"],
        "metadata": _loads(trace["metadata_json"], {}),
        "status": trace["status"],
        "chunks": chunk_items,
        "answer": answer_item,
        "citation_checks": check_items,
        "agent_events": event_items,
        "evaluation": evaluation_item,
        "created_at": trace["created_at"],
        "updated_at": trace["updated_at"],
    }
    return payload
576
+
577
+
578
def _eval_run_from_row(row: sqlite3.Row) -> dict[str, Any]:
    """Convert an ``eval_runs`` row to a dict, decoding its JSON summary."""
    summary = _loads(row["summary_json"], {})
    return dict(
        id=row["id"],
        dataset=row["dataset"],
        endpoint=row["endpoint"],
        summary=summary,
        created_at=row["created_at"],
    )
586
+
587
+
588
+ def _json(value: Any) -> str:
589
+ return json.dumps(value, sort_keys=True)
590
+
591
+
592
+ def _loads(value: str, default: Any) -> Any:
593
+ try:
594
+ return json.loads(value) if value else default
595
+ except json.JSONDecodeError:
596
+ return default
597
+
598
+
599
+ def _new_id(prefix: str) -> str:
600
+ return "%s_%s" % (prefix, uuid.uuid4().hex[:12])
601
+
602
+
603
+ def _now() -> str:
604
+ return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
@@ -0,0 +1,50 @@
1
+ from __future__ import annotations
2
+
3
+ import operator
4
+ import re
5
+ from dataclasses import dataclass
6
+ from typing import Any, Iterable
7
+
8
+
9
# Comparison callables keyed by the operator symbols a threshold may use.
_OPERATORS = {
    ">": operator.gt,
    ">=": operator.ge,
    "<": operator.lt,
    "<=": operator.le,
}


@dataclass(frozen=True)
class Threshold:
    """A single ``<metric><operator><value>`` condition, e.g. ``failure_rate>0.25``."""

    metric: str
    operator: str
    value: float

    def evaluate(self, metrics: dict[str, Any]) -> bool:
        """Return whether the metric's value satisfies this condition.

        Raises:
            ValueError: if the metric is absent from *metrics*, or its value
                cannot be converted to ``float`` (for example ``None``).
            KeyError: if ``self.operator`` is not one of the supported symbols.
        """
        if self.metric not in metrics:
            raise ValueError("Metric '%s' is not available for threshold evaluation." % self.metric)
        raw = metrics[self.metric]
        try:
            actual = float(raw)
        except (TypeError, ValueError) as err:
            # float(None) raises TypeError; report every unusable value the
            # same way as a missing metric so callers handle one error type.
            raise ValueError(
                "Metric '%s' has a non-numeric value %r; cannot evaluate threshold." % (self.metric, raw)
            ) from err
        return bool(_OPERATORS[self.operator](actual, self.value))

    def describe(self, metrics: dict[str, Any]) -> str:
        """Return a message showing the metric's current value next to the condition."""
        actual = metrics.get(self.metric)
        return "%s=%s violates %s%s" % (self.metric, actual, self.operator, self.value)
31
+
32
+
33
def parse_threshold(value: str) -> Threshold:
    """Parse text such as ``failure_rate>0.25`` into a ``Threshold``.

    The text must consist of an identifier, one of ``>=``, ``<=``, ``>``,
    ``<``, and an unsigned decimal number; whitespace around each part is
    ignored.

    Raises:
        ValueError: if *value* does not have that shape.
    """
    pattern = re.compile(r"\s*([A-Za-z_][A-Za-z0-9_]*)\s*(>=|<=|>|<)\s*([0-9]*\.?[0-9]+)\s*")
    parsed = pattern.fullmatch(value)
    if parsed is None:
        raise ValueError("Threshold must look like failure_rate>0.25 or citation_support<0.80.")
    return Threshold(
        metric=parsed.group(1),
        operator=parsed.group(2),
        value=float(parsed.group(3)),
    )
39
+
40
+
41
+ def parse_thresholds(values: Iterable[str]) -> list[Threshold]:
42
+ return [parse_threshold(value) for value in values]
43
+
44
+
45
def threshold_failures(metrics: dict[str, Any], thresholds: Iterable[Threshold]) -> list[str]:
    """Return the descriptions of every threshold that evaluates true for *metrics*.

    Thresholds are checked in iteration order; an exception raised by a
    threshold's ``evaluate`` propagates to the caller.
    """
    return [
        threshold.describe(metrics)
        for threshold in thresholds
        if threshold.evaluate(metrics)
    ]
+ return failures