odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
retrieval/aggregators.py CHANGED
@@ -1,707 +1,707 @@
1
+ """
2
+ Aggregation utilities for insight-ready features in Odin KG Engine.
3
+ These functions analyze retrieval paths and generate quantitative patterns
4
+ for discovery triage and briefs-ready evidence.
5
+
6
+ Drop-in safe: timezone-aware, schema-noise guards, small-n gates, and
7
+ explainable scores.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import List, Dict, Any, Tuple, Optional
13
+ from collections import Counter, defaultdict
14
+ from datetime import datetime, timezone
15
+ import statistics
16
+ import logging
17
+ import math
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # ----------------------------
22
+ # Relation filtering & guards
23
+ # ----------------------------
24
+
25
+ BAN_RELATIONS = {
26
+ "is", "object", "value", "from", "to", "object type", "object value",
27
+ "object ID", "relationship", "related to", "page", "table", "id",
28
+ "chunkid", "imagemetadata"
29
+ }
30
+
31
+ # Allowed only if BOTH endpoints are NOT schema crumbs (see keep_relation)
32
+ COND_RELATIONS = {"contains", "includes"}
33
+
34
+ SCHEMA_LABELS = {"page", "table", "imagemetadata", "id", "chunkid", "object"}
35
+
36
+
37
+ def keep_relation(
38
+ rel: Optional[str],
39
+ u_label: Optional[str] = None,
40
+ v_label: Optional[str] = None
41
+ ) -> bool:
42
+ """
43
+ Return True if a relation should be kept for insight aggregation.
44
+ Enforces banlist and conditional relations logic.
45
+ """
46
+ if not rel:
47
+ return False
48
+ rel = str(rel).strip().lower()
49
+
50
+ if rel in BAN_RELATIONS:
51
+ return False
52
+
53
+ if rel in COND_RELATIONS:
54
+ # Drop if either endpoint looks like schema/meta
55
+ if (u_label and u_label.lower() in SCHEMA_LABELS) or (v_label and v_label.lower() in SCHEMA_LABELS):
56
+ return False
57
+
58
+ return True
59
+
60
+
61
+ # ----------------------------
62
+ # Timestamp & math utilities
63
+ # ----------------------------
64
+
65
+ def _now_utc() -> datetime:
66
+ return datetime.now(timezone.utc)
67
+
68
+
69
+ def parse_timestamp(ts: Any) -> Optional[datetime]:
70
+ """
71
+ Parse various timestamp formats to timezone-aware UTC datetime.
72
+ Supports ISO strings (with or without Z) and unix epoch (int/float).
73
+ """
74
+ if ts is None:
75
+ return None
76
+ try:
77
+ if isinstance(ts, str):
78
+ # Normalize trailing Z to +00:00
79
+ iso = ts.replace("Z", "+00:00")
80
+ dt = datetime.fromisoformat(iso)
81
+ if dt.tzinfo is None:
82
+ dt = dt.replace(tzinfo=timezone.utc)
83
+ return dt.astimezone(timezone.utc)
84
+ if isinstance(ts, (int, float)):
85
+ return datetime.fromtimestamp(float(ts), tz=timezone.utc)
86
+ except Exception as e:
87
+ logger.debug(f"parse_timestamp failed for {ts!r}: {e}")
88
+ return None
89
+
90
+
91
+ def geometric_mean(values: List[float]) -> float:
92
+ """
93
+ Numerically stable geometric mean on [1e-6, 1.0], returns 0.0 if empty.
94
+ """
95
+ if not values:
96
+ return 0.0
97
+ clamped = [max(1e-6, min(1.0, float(v))) for v in values]
98
+ return math.exp(sum(math.log(v) for v in clamped) / len(clamped))
99
+
100
+
101
+ # ----------------------------
102
+ # Core aggregators
103
+ # ----------------------------
104
+
105
+ def extract_motifs(paths: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
106
+ """
107
+ Cluster paths by relation sequence (motifs) and compute statistics.
108
+
109
+ Returns:
110
+ List of dicts with:
111
+ - pattern: "rel1->rel2->..."
112
+ - edge_count: int
113
+ - path_count: int (unique paths supporting this motif)
114
+ - avg_edge_conf: float
115
+ - median_recency_days: float | None
116
+ """
117
+ motif_stats: Dict[str, Dict[str, Any]] = defaultdict(
118
+ lambda: {"edge_confs": [], "timestamps": [], "path_ids": set(), "edge_count": 0}
119
+ )
120
+
121
+ for path_idx, path in enumerate(paths):
122
+ edges = path.get("edges", []) or []
123
+ if not edges:
124
+ continue
125
+
126
+ # Build motif relation sequence (filtered)
127
+ rel_seq: List[str] = []
128
+ filtered_edges: List[Dict[str, Any]] = []
129
+ for e in edges:
130
+ rel = e.get("relation", e.get("relationship"))
131
+ u_label = e.get("u_label")
132
+ v_label = e.get("v_label")
133
+ if keep_relation(rel, u_label, v_label):
134
+ rel_seq.append(str(rel).strip())
135
+ filtered_edges.append(e)
136
+
137
+ if not rel_seq:
138
+ continue
139
+
140
+ motif_pattern = "->".join(rel_seq)
141
+ path_id = path.get("id", f"path_{path_idx}")
142
+
143
+ # Collect stats
144
+ for e in filtered_edges:
145
+ motif_stats[motif_pattern]["edge_count"] += 1
146
+ conf = e.get("confidence", e.get("weight", 1.0))
147
+ try:
148
+ conf = float(conf)
149
+ except Exception:
150
+ conf = 1.0
151
+ motif_stats[motif_pattern]["edge_confs"].append(conf)
152
+
153
+ ts = e.get("created_at", e.get("timestamp"))
154
+ dt = parse_timestamp(ts)
155
+ if dt:
156
+ motif_stats[motif_pattern]["timestamps"].append(dt)
157
+
158
+ motif_stats[motif_pattern]["path_ids"].add(path_id)
159
+
160
+ motifs: List[Dict[str, Any]] = []
161
+ now = _now_utc()
162
+
163
+ for pattern, stats in motif_stats.items():
164
+ # Median recency (days)
165
+ rec_days: Optional[float] = None
166
+ if stats["timestamps"]:
167
+ try:
168
+ days_ago = [max(0.0, (now - dt).total_seconds() / 86400.0) for dt in stats["timestamps"]]
169
+ if days_ago:
170
+ rec_days = statistics.median(days_ago)
171
+ except Exception as e:
172
+ logger.warning(f"Error calculating recency for motif {pattern}: {e}")
173
+
174
+ motifs.append(
175
+ {
176
+ "pattern": pattern,
177
+ "edge_count": int(stats["edge_count"]),
178
+ "path_count": int(len(stats["path_ids"])),
179
+ "avg_edge_conf": round(statistics.mean(stats["edge_confs"]), 4) if stats["edge_confs"] else 0.0,
180
+ "median_recency_days": rec_days,
181
+ }
182
+ )
183
+
184
+ motifs.sort(key=lambda x: (x["edge_count"], x["path_count"]), reverse=True)
185
+ return motifs
186
+
187
+
188
+ def calculate_relation_shares(paths: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
189
+ """
190
+ Calculate the distribution of relations across all edges in paths (filtered).
191
+
192
+ Returns:
193
+ Dict[relation] = { "count": int, "share": float }
194
+ """
195
+ relation_counts: Counter = Counter()
196
+ total_edges = 0
197
+
198
+ for path in paths:
199
+ for e in path.get("edges", []) or []:
200
+ rel = e.get("relation", e.get("relationship"))
201
+ u_label = e.get("u_label")
202
+ v_label = e.get("v_label")
203
+ if not keep_relation(rel, u_label, v_label):
204
+ continue
205
+ relation_counts[str(rel).strip()] += 1
206
+ total_edges += 1
207
+
208
+ # Avoid division by zero
209
+ denom = max(1, total_edges)
210
+ shares = {rel: {"count": cnt, "share": round(cnt / denom, 4)} for rel, cnt in relation_counts.items()}
211
+ return shares
212
+
213
+
214
+ def compute_provenance_coverage(paths: List[Dict[str, Any]]) -> float:
215
+ """
216
+ Fraction of edges that carry any provenance/document id.
217
+ """
218
+ prov_edges = 0
219
+ total = 0
220
+ for p in paths:
221
+ for e in p.get("edges", []) or []:
222
+ total += 1
223
+ prov = e.get("provenance")
224
+ src = e.get("source_doc")
225
+ has = False
226
+ if isinstance(prov, dict):
227
+ has = bool(prov.get("document_id"))
228
+ elif isinstance(prov, list):
229
+ # any item with document_id
230
+ has = any(bool(it.get("document_id")) for it in prov if isinstance(it, dict))
231
+ has = has or bool(src)
232
+ prov_edges += int(has)
233
+ return (prov_edges / total) if total else 0.0
234
+
235
+
236
+ def compute_recency_score(paths: List[Dict[str, Any]], half_life_days: float = 30.0) -> float:
237
+ """
238
+ Recency score in [0,1] as average of exp(-age/half_life) over edges with timestamps.
239
+ """
240
+ if half_life_days <= 0:
241
+ return 0.0
242
+ now = _now_utc()
243
+ vals: List[float] = []
244
+ for p in paths:
245
+ for e in p.get("edges", []) or []:
246
+ ts = e.get("created_at", e.get("timestamp"))
247
+ dt = parse_timestamp(ts)
248
+ if not dt:
249
+ continue
250
+ age_days = max(0.0, (now - dt).total_seconds() / 86400.0)
251
+ vals.append(math.exp(-age_days / half_life_days))
252
+ return statistics.mean(vals) if vals else 0.0
253
+
254
+
255
+ def estimate_label_coverage(paths: List[Dict[str, Any]]) -> float:
256
+ """
257
+ Estimate fraction of labeled endpoints across edges (0..1).
258
+ """
259
+ labeled = 0
260
+ total = 0
261
+ for p in paths:
262
+ for e in p.get("edges", []) or []:
263
+ total += 2
264
+ labeled += int(bool(e.get("u_label")))
265
+ labeled += int(bool(e.get("v_label")))
266
+ return (labeled / total) if total else 0.0
267
+
268
+
269
+ def compute_motif_density(motifs: List[Dict[str, Any]], topk: int = 2) -> float:
270
+ """
271
+ Share of edges that fall into the top-K motifs (by edge_count).
272
+ """
273
+ if not motifs:
274
+ return 0.0
275
+ counts = [int(m.get("edge_count", m.get("count", 0))) for m in motifs]
276
+ counts.sort(reverse=True)
277
+ top = sum(counts[:topk])
278
+ total = sum(counts)
279
+ return (top / total) if total else 0.0
280
+
281
+
282
+ def extract_snippet_anchors(
283
+ paths: List[Dict[str, Any]],
284
+ max_per_path: int = 2,
285
+ max_total: int = 16
286
+ ) -> List[Dict[str, Any]]:
287
+ """
288
+ Extract provenance anchors from top paths, preferring doc diversity.
289
+
290
+ Returns:
291
+ List of { path_idx, edge_id, document_id, span? }
292
+ """
293
+ anchors: List[Dict[str, Any]] = []
294
+ seen_docs = set()
295
+
296
+ for path_idx, path in enumerate(paths[:10]): # focus on top-10 paths
297
+ if len(anchors) >= max_total:
298
+ break
299
+ edges = path.get("edges", []) or []
300
+ path_local: List[Dict[str, Any]] = []
301
+
302
+ for e in edges:
303
+ # Skip edges without allowed relations to keep anchors on-semantic
304
+ rel = e.get("relation", e.get("relationship"))
305
+ if not keep_relation(rel, e.get("u_label"), e.get("v_label")):
306
+ continue
307
+
308
+ prov = e.get("provenance")
309
+ doc_id = None
310
+ span = None
311
+
312
+ if isinstance(prov, dict):
313
+ doc_id = prov.get("document_id")
314
+ span = prov.get("char_span")
315
+ elif isinstance(prov, list) and prov:
316
+ # take first with document_id
317
+ for it in prov:
318
+ if isinstance(it, dict) and it.get("document_id"):
319
+ doc_id = it["document_id"]
320
+ span = it.get("char_span")
321
+ break
322
+
323
+ if not doc_id:
324
+ doc_id = e.get("source_doc")
325
+
326
+ if doc_id and doc_id not in seen_docs:
327
+ anchor = {
328
+ "path_idx": path_idx,
329
+ "edge_id": e.get("_id", e.get("id")),
330
+ "document_id": doc_id,
331
+ }
332
+ if span is not None:
333
+ anchor["span"] = span
334
+ path_local.append(anchor)
335
+
336
+ if len(path_local) >= max_per_path or len(anchors) + len(path_local) >= max_total:
337
+ break
338
+
339
+ # prefer per-path diversity
340
+ anchors.extend(path_local[:max_per_path])
341
+
342
+ return anchors
343
+
344
+
345
+ def surprise_vs_priors(relation_shares: Dict[str, Dict[str, Any]], priors: Dict[str, float]) -> Dict[str, float]:
346
+ """
347
+ Absolute deviation of observed share vs prior for each relation.
348
+ """
349
+ s: Dict[str, float] = {}
350
+ priors = priors or {}
351
+ for rel, stats in relation_shares.items():
352
+ obs = float(stats.get("share", 0.0))
353
+ prior = float(priors.get(rel, 0.0))
354
+ s[rel] = abs(obs - prior)
355
+ return s
356
+
357
+
358
+ def compute_baseline_comparison(
359
+ current_aggregates: Dict[str, Any],
360
+ baseline_aggregates: Optional[Dict[str, Any]] = None,
361
+ min_baseline_edges: int = 10
362
+ ) -> Dict[str, Any]:
363
+ """
364
+ Compare current aggregates with baseline to compute deltas & relative changes.
365
+ Adds note if baseline support is low.
366
+ """
367
+ if not baseline_aggregates:
368
+ return {"surprise": {}, "deltas": {}, "relative_changes": {}, "note": "no_baseline"}
369
+
370
+ comparison = {"surprise": {}, "deltas": {}, "relative_changes": {}, "pp_change": {}}
371
+
372
+ cur_shares = current_aggregates.get("relation_share", {})
373
+ base_shares = baseline_aggregates.get("relation_share", {})
374
+
375
+ base_total_edges = sum(v.get("count", 0) for v in base_shares.values())
376
+ if base_total_edges < min_baseline_edges:
377
+ comparison["note"] = "baseline_low_support"
378
+
379
+ for rel in set(cur_shares.keys()) | set(base_shares.keys()):
380
+ c = cur_shares.get(rel, {})
381
+ b = base_shares.get(rel, {})
382
+ c_share = float(c.get("share", 0.0))
383
+ b_share = float(b.get("share", 0.0))
384
+ delta = c_share - b_share
385
+ comparison["surprise"][rel] = round(delta, 4)
386
+ comparison["pp_change"][rel] = round(100.0 * delta, 2)
387
+ if b_share > 0:
388
+ comparison["relative_changes"][rel] = round((c_share - b_share) / b_share, 4)
389
+
390
+ # Motif deltas (edge_count-based)
391
+ cur_motifs = {m["pattern"]: int(m.get("edge_count", m.get("count", 0))) for m in current_aggregates.get("motifs", [])}
392
+ base_motifs = {m["pattern"]: int(m.get("edge_count", m.get("count", 0))) for m in baseline_aggregates.get("motifs", [])}
393
+ motif_deltas = {}
394
+ for pattern in set(cur_motifs) | set(base_motifs):
395
+ motif_deltas[pattern] = cur_motifs.get(pattern, 0) - base_motifs.get(pattern, 0)
396
+ comparison["deltas"]["motifs"] = motif_deltas
397
+
398
+ return comparison
399
+
400
+
401
+ def generate_aggregates(
402
+ paths: List[Dict[str, Any]],
403
+ baseline_paths: Optional[List[Dict[str, Any]]] = None,
404
+ *,
405
+ priors: Optional[Dict[str, float]] = None,
406
+ half_life_days: float = 30.0,
407
+ topk_motif_density: int = 2,
408
+ min_current_edges: int = 20,
409
+ min_baseline_edges: int = 10
410
+ ) -> Dict[str, Any]:
411
+ """
412
+ Generate comprehensive aggregates for insight generation & discovery triage.
413
+
414
+ Returns a dict with:
415
+ - motifs, relation_share, snippet_anchors
416
+ - baseline (if provided) and comparison
417
+ - summary: totals, motif_density, label_coverage, provenance, recency, low_support
418
+ - surprise_priors (if priors provided)
419
+ - dominant_relation snapshot (share, prior, surprise)
420
+ """
421
+ motifs = extract_motifs(paths)
422
+ relation_share = calculate_relation_shares(paths)
423
+ anchors = extract_snippet_anchors(paths)
424
+
425
+ # Summary metrics
426
+ total_edges = sum(v["count"] for v in relation_share.values())
427
+ prov = compute_provenance_coverage(paths)
428
+ recency = compute_recency_score(paths, half_life_days=half_life_days)
429
+ label_cov = estimate_label_coverage(paths)
430
+ motif_dens = compute_motif_density(motifs, topk=topk_motif_density)
431
+
432
+ aggregates: Dict[str, Any] = {
433
+ "motifs": motifs,
434
+ "relation_share": relation_share,
435
+ "snippet_anchors": anchors,
436
+ "summary": {
437
+ "total_paths": len(paths),
438
+ "unique_motifs": len(motifs),
439
+ "unique_relations": len(relation_share),
440
+ "total_edges": total_edges,
441
+ "provenance": round(prov, 4),
442
+ "recency": round(recency, 4),
443
+ "label_coverage": round(label_cov, 4),
444
+ "motif_density": round(motif_dens, 4),
445
+ "has_baseline": baseline_paths is not None,
446
+ "low_support": total_edges < min_current_edges
447
+ }
448
+ }
449
+
450
+ # Baseline comparison (optional)
451
+ if baseline_paths is not None:
452
+ base_agg = {
453
+ "motifs": extract_motifs(baseline_paths),
454
+ "relation_share": calculate_relation_shares(baseline_paths)
455
+ }
456
+ aggregates["baseline"] = base_agg
457
+ aggregates["comparison"] = compute_baseline_comparison(
458
+ {"motifs": motifs, "relation_share": relation_share},
459
+ base_agg,
460
+ min_baseline_edges=min_baseline_edges
461
+ )
462
+
463
+ # Priors surprise (optional)
464
+ if priors:
465
+ s = surprise_vs_priors(relation_share, priors)
466
+ aggregates["surprise_priors"] = {k: round(v, 4) for k, v in s.items()}
467
+
468
+ # Dominant relation snapshot
469
+ if relation_share:
470
+ dominant_rel = max(relation_share.items(), key=lambda kv: kv[1]["share"])[0]
471
+ dom = {"relation": dominant_rel, "share": relation_share[dominant_rel]["share"]}
472
+ if priors:
473
+ dom["prior"] = float(priors.get(dominant_rel, 0.0))
474
+ dom["surprise_vs_prior"] = round(abs(dom["share"] - dom["prior"]), 4)
475
+ if "comparison" in aggregates:
476
+ dom["delta_vs_baseline"] = aggregates["comparison"]["surprise"].get(dominant_rel, 0.0)
477
+ aggregates["dominant_relation"] = dom
478
+
479
+ return aggregates
480
+
481
+
482
+ # ----------------------------
483
+ # Insight & triage scoring
484
+ # ----------------------------
485
+
486
+ def decompose_insight_score(
487
+ paths: List[Dict[str, Any]],
488
+ *,
489
+ evidence_strength: float,
490
+ community_relevance: float,
491
+ insight_score: float
492
+ ) -> Dict[str, Any]:
493
+ """
494
+ Decompose an insight score into drivers.
495
+
496
+ Returns:
497
+ {
498
+ "value": float, "label": "High|Medium|Low",
499
+ "drivers": {...},
500
+ "quality_gate": {...}
501
+ }
502
+ """
503
+ # Path-level strengths
504
+ path_geo_confs: List[float] = []
505
+ edge_confs_all: List[float] = []
506
+ ppr_scores: List[float] = []
507
+
508
+ for path in paths:
509
+ edges = path.get("edges", []) or []
510
+ e_confs = []
511
+ for e in edges:
512
+ c = e.get("confidence", e.get("weight", 1.0))
513
+ try:
514
+ e_confs.append(float(c))
515
+ edge_confs_all.append(float(c))
516
+ except Exception:
517
+ edge_confs_all.append(1.0)
518
+ if e_confs:
519
+ path_geo_confs.append(geometric_mean(e_confs))
520
+
521
+ if "ppr_score" in path:
522
+ try:
523
+ ppr_scores.append(float(path["ppr_score"]))
524
+ except Exception:
525
+ pass
526
+
527
+ drivers = {
528
+ "path_strength": round(statistics.mean(path_geo_confs), 3) if path_geo_confs else 0.0,
529
+ "ppr_strength": round(statistics.mean(ppr_scores), 3) if ppr_scores else 0.0,
530
+ "edge_conf_strength": round(statistics.mean(edge_confs_all), 3) if edge_confs_all else 0.0,
531
+ "evidence_strength_f": round(float(evidence_strength), 3),
532
+ "community_relevance_f": round(float(community_relevance), 3),
533
+ "insight_score_f": round(float(insight_score), 3)
534
+ }
535
+
536
+ if insight_score >= 0.7:
537
+ label = "High"
538
+ elif insight_score >= 0.4:
539
+ label = "Medium"
540
+ else:
541
+ label = "Low"
542
+
543
+ quality_gate = {
544
+ "meets_evidence_floor": evidence_strength >= 0.5,
545
+ "has_strong_paths": drivers["path_strength"] >= 0.4,
546
+ "recommendation": "proceed" if insight_score >= 0.4 else "gather_more_evidence"
547
+ }
548
+
549
+ return {"value": round(insight_score, 2), "label": label, "drivers": drivers, "quality_gate": quality_gate}
550
+
551
+
552
+ def compute_triage_score(
553
+ *,
554
+ provenance: float,
555
+ recency: float,
556
+ surprise: float,
557
+ motif_density: float,
558
+ controllability: float,
559
+ label_coverage: Optional[float] = None,
560
+ low_support: bool = False
561
+ ) -> Tuple[int, Dict[str, float]]:
562
+ """
563
+ Compute triage score per contract:
564
+ score = 25*prov + 25*rec + 25*surprise + 15*motif + 10*control
565
+ All inputs expected in [0,1]. Returns (score_int_0_100, components_dict).
566
+
567
+ Guards:
568
+ - If label_coverage < 0.8, cap motif_density at 0.3 and subtract 15 points.
569
+ - If low_support, subtract 40% of the score.
570
+ """
571
+ # Coerce bounds
572
+ clamp = lambda x: max(0.0, min(1.0, float(x)))
573
+ provenance = clamp(provenance)
574
+ recency = clamp(recency)
575
+ surprise = clamp(surprise)
576
+ motif_density = clamp(motif_density)
577
+ controllability = clamp(controllability)
578
+ label_coverage = clamp(label_coverage) if label_coverage is not None else None
579
+
580
+ penalty = 0.0
581
+ if label_coverage is not None and label_coverage < 0.8:
582
+ motif_density = min(motif_density, 0.3)
583
+ penalty += 15.0
584
+
585
+ base_score = (
586
+ 25.0 * provenance
587
+ + 25.0 * recency
588
+ + 25.0 * surprise
589
+ + 15.0 * motif_density
590
+ + 10.0 * controllability
591
+ )
592
+
593
+ if low_support:
594
+ base_score *= 0.6 # subtract 40%
595
+
596
+ score = max(0.0, min(100.0, base_score - penalty))
597
+ components = {
598
+ "provenance": round(provenance, 4),
599
+ "recency": round(recency, 4),
600
+ "surprise": round(surprise, 4),
601
+ "motif_density": round(motif_density, 4),
602
+ "controllability": round(controllability, 4),
603
+ "label_coverage": round(label_coverage, 4) if label_coverage is not None else None,
604
+ "penalty": round(penalty, 2),
605
+ "low_support": bool(low_support)
606
+ }
607
+ return int(round(score)), components
608
+
609
+
610
+ # ----------------------------
611
+ # Convenience: one-shot features for Discovery
612
+ # ----------------------------
613
+
614
+ def build_opportunity_features(
615
+ paths: List[Dict[str, Any]],
616
+ baseline_paths: Optional[List[Dict[str, Any]]] = None,
617
+ *,
618
+ priors: Optional[Dict[str, float]] = None,
619
+ half_life_days: float = 30.0,
620
+ controllability: float = 1.0,
621
+ min_current_edges: int = 20,
622
+ min_baseline_edges: int = 10
623
+ ) -> Dict[str, Any]:
624
+ """
625
+ Convenience wrapper to produce everything Discovery needs to compute triage.
626
+
627
+ Returns dict:
628
+ {
629
+ "aggregates": { ... }, # from generate_aggregates
630
+ "triage": {
631
+ "score": 0-100,
632
+ "components": {...},
633
+ "dominant_relation": {...}
634
+ }
635
+ }
636
+ """
637
+ aggs = generate_aggregates(
638
+ paths,
639
+ baseline_paths,
640
+ priors=priors,
641
+ half_life_days=half_life_days,
642
+ min_current_edges=min_current_edges,
643
+ min_baseline_edges=min_baseline_edges,
644
+ )
645
+
646
+ # Surprise component selection:
647
+ # Prefer priors-based surprise on dominant relation; fall back to baseline delta abs().
648
+ surprise_component = 0.0
649
+ dom = aggs.get("dominant_relation") or {}
650
+ if priors and "surprise_vs_prior" in dom:
651
+ surprise_component = abs(float(dom.get("surprise_vs_prior", 0.0)))
652
+ elif "comparison" in aggs and "delta_vs_baseline" in dom:
653
+ surprise_component = abs(float(dom.get("delta_vs_baseline", 0.0)))
654
+
655
+ # Compute triage
656
+ summary = aggs["summary"]
657
+ score, comps = compute_triage_score(
658
+ provenance=summary["provenance"],
659
+ recency=summary["recency"],
660
+ surprise=surprise_component,
661
+ motif_density=summary["motif_density"],
662
+ controllability=controllability,
663
+ label_coverage=summary["label_coverage"],
664
+ low_support=summary["low_support"]
665
+ )
666
+
667
+ triage = {
668
+ "score": score,
669
+ "components": comps,
670
+ "dominant_relation": dom
671
+ }
672
+
673
+ return {"aggregates": aggs, "triage": triage}
674
+
675
+
676
+ # ----------------------------
677
+ # Optional: basic self-test (remove in prod if undesired)
678
+ # ----------------------------
679
+ if __name__ == "__main__":
680
+ logging.basicConfig(level=logging.INFO)
681
+ # Minimal synthetic example
682
+ paths_example = [
683
+ {
684
+ "id": "p1",
685
+ "edges": [
686
+ {"u_label": "ClaimForm", "v_label": "Document", "relation": "supporting_documentation",
687
+ "confidence": 0.9, "created_at": "2025-07-20T12:00:00Z", "provenance": {"document_id": "docA"}},
688
+ {"u_label": "Document", "v_label": "OfficerSignature", "relation": "signed_by",
689
+ "confidence": 0.95, "created_at": "2025-07-21T12:00:00Z", "provenance": {"document_id": "docB"}},
690
+ ],
691
+ "ppr_score": 0.12
692
+ },
693
+ {
694
+ "id": "p2",
695
+ "edges": [
696
+ {"u_label": "AssessmentReport", "v_label": "Document", "relation": "attached",
697
+ "confidence": 0.85, "created_at": "2025-07-22T15:30:00Z", "provenance": {"document_id": "docC"}},
698
+ ],
699
+ "ppr_score": 0.08
700
+ }
701
+ ]
702
+
703
+ priors_example = {"supporting_documentation": 0.30, "signed_by": 0.10, "attached": 0.20}
704
+
705
+ out = build_opportunity_features(paths_example, priors=priors_example, controllability=1.0)
706
+ from pprint import pprint
707
+ pprint(out)
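
For reference, a minimal usage sketch of the triage contract implemented by compute_triage_score in retrieval/aggregators.py above (score = 25*provenance + 25*recency + 25*surprise + 15*motif_density + 10*controllability). It assumes the 0.2.0 wheel is installed so the module imports as retrieval.aggregators, as listed in the files-changed table; the component values below are illustrative only, not taken from the package:

from retrieval.aggregators import compute_triage_score

# All components are expected in [0, 1]; out-of-range values are clamped.
score, components = compute_triage_score(
    provenance=0.8,        # fraction of edges carrying a document id
    recency=0.6,           # mean exp(-age / half_life) over timestamped edges
    surprise=0.2,          # |observed share - prior| for the dominant relation
    motif_density=0.5,     # share of edges in the top-K motifs
    controllability=1.0,
    label_coverage=0.9,    # >= 0.8, so no motif cap and no 15-point penalty
    low_support=False,     # True would scale the weighted sum by 0.6
)

# Weighted sum: 25*0.8 + 25*0.6 + 25*0.2 + 15*0.5 + 10*1.0 = 57.5, so score == 58
print(score, components)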