odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
retrieval/cache.py CHANGED
@@ -1,159 +1,159 @@
- """
- Caching layer for GraphAccessor to prevent network hammering during PPR.
-
- PRODUCTION FIX: PPR's push algorithm repeatedly calls accessor.iter_out(u) for
- the same nodes, causing excessive network traffic. CachedGraphAccessor wraps any
- accessor and caches neighbor lookups.
- """
-
- from __future__ import annotations
- from typing import Iterable, Optional, List, Tuple
- from collections import OrderedDict
-
- from .adapters import GraphAccessor, NodeId, RelId
-
-
- class CachedGraphAccessor:
-     """
-     Wraps a GraphAccessor with LRU caching for neighbor queries.
-
-     Critical for production: Prevents "network hammer" issue where PPR
-     makes repeated calls to iter_out() for the same nodes, each hitting
-     the database/network.
-
-     Usage:
-         base_accessor = ArangoCommunityAccessor(db, community_id="insurance")
-         cached_accessor = CachedGraphAccessor(base_accessor, cache_size=5000)
-
-         # Now PPR won't hammer the network
-         orchestrator = RetrievalOrchestrator(accessor=cached_accessor, ...)
-     """
-
-     def __init__(self, base: GraphAccessor, cache_size: int = 5000):
-         """
-         Args:
-             base: The underlying GraphAccessor to wrap
-             cache_size: Maximum number of nodes to cache (default: 5000)
-         """
-         self.base = base
-         self.cache_size = cache_size
-
-         # LRU caches for outbound and inbound neighbors
-         self._out_cache: OrderedDict[NodeId, List[Tuple[NodeId, RelId, float]]] = OrderedDict()
-         self._in_cache: OrderedDict[NodeId, List[Tuple[NodeId, RelId, float]]] = OrderedDict()
-
-         # Stats for monitoring
-         self._hits = 0
-         self._misses = 0
-
-     def iter_out(self, node: NodeId) -> Iterable[Tuple[NodeId, RelId, float]]:
-         """Get outbound neighbors with caching."""
-         if node in self._out_cache:
-             # Cache hit - move to end (LRU)
-             self._out_cache.move_to_end(node)
-             self._hits += 1
-             return iter(self._out_cache[node])
-
-         # Cache miss - fetch from base accessor
-         self._misses += 1
-         neighbors = list(self.base.iter_out(node))
-
-         # Store in cache with LRU eviction
-         if len(self._out_cache) >= self.cache_size:
-             self._out_cache.popitem(last=False)  # Remove oldest
-         self._out_cache[node] = neighbors
-
-         return iter(neighbors)
-
-     def iter_in(self, node: NodeId) -> Iterable[Tuple[NodeId, RelId, float]]:
-         """Get inbound neighbors with caching."""
-         if node in self._in_cache:
-             # Cache hit - move to end (LRU)
-             self._in_cache.move_to_end(node)
-             self._hits += 1
-             return iter(self._in_cache[node])
-
-         # Cache miss - fetch from base accessor
-         self._misses += 1
-         neighbors = list(self.base.iter_in(node))
-
-         # Store in cache with LRU eviction
-         if len(self._in_cache) >= self.cache_size:
-             self._in_cache.popitem(last=False)  # Remove oldest
-         self._in_cache[node] = neighbors
-
-         return iter(neighbors)
-
-     def nodes(self, community_id: Optional[str] = None) -> Iterable[NodeId]:
-         """Pass through to base accessor (no caching)."""
-         return self.base.nodes(community_id)
-
-     def get_node(self, node_id: NodeId, fields: Optional[List[str]] = None) -> dict:
-         """Pass through to base accessor (node lookups are typically one-off)."""
-         return self.base.get_node(node_id, fields)
-
-     def degree(self, node: NodeId) -> int:
-         """Pass through to base accessor."""
-         return self.base.degree(node)
-
-     def community_seed_norm(self, community_id: str, seeds: List[str]) -> List[str]:
-         """Pass through to base accessor."""
-         return self.base.community_seed_norm(community_id, seeds)
-
-     def clear_cache(self):
-         """Clear all caches. Useful for memory management or testing."""
-         self._out_cache.clear()
-         self._in_cache.clear()
-         self._hits = 0
-         self._misses = 0
-
-     def cache_stats(self) -> dict:
-         """
-         Return cache statistics for monitoring.
-
-         Returns:
-             Dict with hit rate, sizes, and utilization metrics
-         """
-         total = self._hits + self._misses
-         hit_rate = self._hits / total if total > 0 else 0
-
-         return {
-             "hits": self._hits,
-             "misses": self._misses,
-             "total_requests": total,
-             "hit_rate": hit_rate,
-             "out_cache_size": len(self._out_cache),
-             "in_cache_size": len(self._in_cache),
-             "max_cache_size": self.cache_size,
-             "out_cache_utilization": len(self._out_cache) / self.cache_size,
-             "in_cache_utilization": len(self._in_cache) / self.cache_size,
-         }
-
-     def warm_cache(self, nodes: List[NodeId], direction: str = "out"):
-         """
-         Pre-populate cache for a list of nodes.
-         Useful for batch operations where you know which nodes will be accessed.
-
-         Args:
-             nodes: List of node IDs to pre-fetch
-             direction: "out" or "in" for outbound/inbound neighbors
-         """
-         if direction == "out":
-             for node in nodes:
-                 if node not in self._out_cache:
-                     neighbors = list(self.base.iter_out(node))
-                     if len(self._out_cache) >= self.cache_size:
-                         self._out_cache.popitem(last=False)
-                     self._out_cache[node] = neighbors
-         elif direction == "in":
-             for node in nodes:
-                 if node not in self._in_cache:
-                     neighbors = list(self.base.iter_in(node))
-                     if len(self._in_cache) >= self.cache_size:
-                         self._in_cache.popitem(last=False)
-                     self._in_cache[node] = neighbors
-
-     # Delegate all other methods to base accessor
-     def __getattr__(self, name):
-         """Delegate unknown methods to the base accessor."""
-         return getattr(self.base, name)
+ """
+ Caching layer for GraphAccessor to prevent network hammering during PPR.
+
+ PRODUCTION FIX: PPR's push algorithm repeatedly calls accessor.iter_out(u) for
+ the same nodes, causing excessive network traffic. CachedGraphAccessor wraps any
+ accessor and caches neighbor lookups.
+ """
+
+ from __future__ import annotations
+ from typing import Iterable, Optional, List, Tuple
+ from collections import OrderedDict
+
+ from .adapters import GraphAccessor, NodeId, RelId
+
+
+ class CachedGraphAccessor:
+     """
+     Wraps a GraphAccessor with LRU caching for neighbor queries.
+
+     Critical for production: Prevents "network hammer" issue where PPR
+     makes repeated calls to iter_out() for the same nodes, each hitting
+     the database/network.
+
+     Usage:
+         base_accessor = ArangoCommunityAccessor(db, community_id="insurance")
+         cached_accessor = CachedGraphAccessor(base_accessor, cache_size=5000)
+
+         # Now PPR won't hammer the network
+         orchestrator = RetrievalOrchestrator(accessor=cached_accessor, ...)
+     """
+
+     def __init__(self, base: GraphAccessor, cache_size: int = 5000):
+         """
+         Args:
+             base: The underlying GraphAccessor to wrap
+             cache_size: Maximum number of nodes to cache (default: 5000)
+         """
+         self.base = base
+         self.cache_size = cache_size
+
+         # LRU caches for outbound and inbound neighbors
+         self._out_cache: OrderedDict[NodeId, List[Tuple[NodeId, RelId, float]]] = OrderedDict()
+         self._in_cache: OrderedDict[NodeId, List[Tuple[NodeId, RelId, float]]] = OrderedDict()
+
+         # Stats for monitoring
+         self._hits = 0
+         self._misses = 0
+
+     def iter_out(self, node: NodeId) -> Iterable[Tuple[NodeId, RelId, float]]:
+         """Get outbound neighbors with caching."""
+         if node in self._out_cache:
+             # Cache hit - move to end (LRU)
+             self._out_cache.move_to_end(node)
+             self._hits += 1
+             return iter(self._out_cache[node])
+
+         # Cache miss - fetch from base accessor
+         self._misses += 1
+         neighbors = list(self.base.iter_out(node))
+
+         # Store in cache with LRU eviction
+         if len(self._out_cache) >= self.cache_size:
+             self._out_cache.popitem(last=False)  # Remove oldest
+         self._out_cache[node] = neighbors
+
+         return iter(neighbors)
+
+     def iter_in(self, node: NodeId) -> Iterable[Tuple[NodeId, RelId, float]]:
+         """Get inbound neighbors with caching."""
+         if node in self._in_cache:
+             # Cache hit - move to end (LRU)
+             self._in_cache.move_to_end(node)
+             self._hits += 1
+             return iter(self._in_cache[node])
+
+         # Cache miss - fetch from base accessor
+         self._misses += 1
+         neighbors = list(self.base.iter_in(node))
+
+         # Store in cache with LRU eviction
+         if len(self._in_cache) >= self.cache_size:
+             self._in_cache.popitem(last=False)  # Remove oldest
+         self._in_cache[node] = neighbors
+
+         return iter(neighbors)
+
+     def nodes(self, community_id: Optional[str] = None) -> Iterable[NodeId]:
+         """Pass through to base accessor (no caching)."""
+         return self.base.nodes(community_id)
+
+     def get_node(self, node_id: NodeId, fields: Optional[List[str]] = None) -> dict:
+         """Pass through to base accessor (node lookups are typically one-off)."""
+         return self.base.get_node(node_id, fields)
+
+     def degree(self, node: NodeId) -> int:
+         """Pass through to base accessor."""
+         return self.base.degree(node)
+
+     def community_seed_norm(self, community_id: str, seeds: List[str]) -> List[str]:
+         """Pass through to base accessor."""
+         return self.base.community_seed_norm(community_id, seeds)
+
+     def clear_cache(self):
+         """Clear all caches. Useful for memory management or testing."""
+         self._out_cache.clear()
+         self._in_cache.clear()
+         self._hits = 0
+         self._misses = 0
+
+     def cache_stats(self) -> dict:
+         """
+         Return cache statistics for monitoring.
+
+         Returns:
+             Dict with hit rate, sizes, and utilization metrics
+         """
+         total = self._hits + self._misses
+         hit_rate = self._hits / total if total > 0 else 0
+
+         return {
+             "hits": self._hits,
+             "misses": self._misses,
+             "total_requests": total,
+             "hit_rate": hit_rate,
+             "out_cache_size": len(self._out_cache),
+             "in_cache_size": len(self._in_cache),
+             "max_cache_size": self.cache_size,
+             "out_cache_utilization": len(self._out_cache) / self.cache_size,
+             "in_cache_utilization": len(self._in_cache) / self.cache_size,
+         }
+
+     def warm_cache(self, nodes: List[NodeId], direction: str = "out"):
+         """
+         Pre-populate cache for a list of nodes.
+         Useful for batch operations where you know which nodes will be accessed.
+
+         Args:
+             nodes: List of node IDs to pre-fetch
+             direction: "out" or "in" for outbound/inbound neighbors
+         """
+         if direction == "out":
+             for node in nodes:
+                 if node not in self._out_cache:
+                     neighbors = list(self.base.iter_out(node))
+                     if len(self._out_cache) >= self.cache_size:
+                         self._out_cache.popitem(last=False)
+                     self._out_cache[node] = neighbors
+         elif direction == "in":
+             for node in nodes:
+                 if node not in self._in_cache:
+                     neighbors = list(self.base.iter_in(node))
+                     if len(self._in_cache) >= self.cache_size:
+                         self._in_cache.popitem(last=False)
+                     self._in_cache[node] = neighbors
+
+     # Delegate all other methods to base accessor
+     def __getattr__(self, name):
+         """Delegate unknown methods to the base accessor."""
+         return getattr(self.base, name)
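
A minimal usage sketch of the CachedGraphAccessor shown above, assuming a toy in-memory accessor; ToyAccessor and its edge data are invented for illustration and stand in for a real GraphAccessor such as the Arango-backed adapters.

# Sketch only: ToyAccessor is a hypothetical stand-in, not part of the package.
from retrieval.cache import CachedGraphAccessor

class ToyAccessor:
    def __init__(self, edges):
        self.edges = edges      # {node: [(neighbor, rel, weight), ...]}
        self.calls = 0          # backend hits, to show the cache working

    def iter_out(self, node):
        self.calls += 1
        return iter(self.edges.get(node, []))

    def iter_in(self, node):
        return iter(())

toy = ToyAccessor({"a": [("b", "rel_x", 1.0), ("c", "rel_y", 0.5)]})
cached = CachedGraphAccessor(toy, cache_size=100)

for _ in range(10):             # repeated PPR-style lookups of the same node
    list(cached.iter_out("a"))

print(toy.calls)                           # 1 backend call; the rest were cache hits
print(cached.cache_stats()["hit_rate"])    # 0.9 for this access pattern
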
retrieval/confidence.py CHANGED
@@ -1,88 +1,88 @@
- from __future__ import annotations
- from typing import Protocol, Tuple, Optional, Callable, List
- from collections import OrderedDict
- from math import isfinite
- import torch
-
- NodeId = str
- RelId = str
-
-
- class EdgeConfidenceProvider(Protocol):
-     def confidence(self, u: NodeId, rel: RelId, v: NodeId) -> float: ...
-     def confidence_batch(self, edges: List[Tuple[NodeId, RelId, NodeId]]) -> List[float]: ...
-
-
- class ConstantConfidence:
-     def __init__(self, value: float = 0.8):
-         self.value = max(1e-6, min(1.0, value))
-
-     def confidence(self, u, rel, v) -> float:
-         return self.value
-
-     def confidence_batch(self, edges: List[Tuple[NodeId, RelId, NodeId]]) -> List[float]:
-         return [self.value] * len(edges)
-
-
- class NPLLConfidence:
-     """
-     Wraps an NPLL model for retrieval-time scoring with LRU caching.
-     Uses the model's scoring module for batched probabilities when available.
-
-     PRODUCTION FIX: Uses bounded LRU cache to prevent memory leaks in long-running processes.
-     """
-
-     def __init__(self, npll_model, cache_size: int = 10000):
-         """
-         Args:
-             npll_model: Trained NPLL model for scoring triples
-             cache_size: Maximum number of cached confidence scores (default: 10K)
-         """
-         self.model = npll_model
-         self.cache_size = cache_size
-         self._cache: OrderedDict[Tuple[str, str, str], float] = OrderedDict()
-
-     def confidence(self, u: NodeId, rel: RelId, v: NodeId) -> float:
-         return self.confidence_batch([(u, rel, v)])[0]
-
-     def confidence_batch(self, edges: List[Tuple[NodeId, RelId, NodeId]]) -> List[float]:
-         todo = [(u, r, v) for (u, r, v) in edges if (u, r, v) not in self._cache]
-         if todo:
-             heads, rels, tails = zip(*todo)
-             self.model.eval()
-             with torch.no_grad():
-                 scores = self.model.scoring_module.forward_with_names(list(heads), list(rels), list(tails))
-                 # Don't apply per-group temperature scaling (requires group_ids we don't have)
-                 probs = self.model.probability_transform(scores, apply_temperature=False)
-             for (u, r, v), p in zip(todo, probs.tolist()):
-                 confidence = max(1e-6, float(p)) if isfinite(p) else 1e-6
-
-                 # LRU eviction: remove oldest if at capacity
-                 if len(self._cache) >= self.cache_size:
-                     self._cache.popitem(last=False)  # Remove oldest (FIFO)
-
-                 self._cache[(u, r, v)] = confidence
-
-         # Move accessed items to end (LRU behavior)
-         result = []
-         for edge in edges:
-             conf = self._cache[edge]
-             # Move to end to mark as recently used
-             self._cache.move_to_end(edge)
-             result.append(conf)
-
-         return result
-
-     def clear_cache(self):
-         """Clear the confidence cache. Useful for testing or memory management."""
-         self._cache.clear()
-
-     def cache_stats(self) -> dict:
-         """Return cache statistics for monitoring."""
-         return {
-             "size": len(self._cache),
-             "max_size": self.cache_size,
-             "utilization": len(self._cache) / self.cache_size if self.cache_size > 0 else 0,
-         }
-
-
+ from __future__ import annotations
+ from typing import Protocol, Tuple, Optional, Callable, List
+ from collections import OrderedDict
+ from math import isfinite
+ import torch
+
+ NodeId = str
+ RelId = str
+
+
+ class EdgeConfidenceProvider(Protocol):
+     def confidence(self, u: NodeId, rel: RelId, v: NodeId) -> float: ...
+     def confidence_batch(self, edges: List[Tuple[NodeId, RelId, NodeId]]) -> List[float]: ...
+
+
+ class ConstantConfidence:
+     def __init__(self, value: float = 0.8):
+         self.value = max(1e-6, min(1.0, value))
+
+     def confidence(self, u, rel, v) -> float:
+         return self.value
+
+     def confidence_batch(self, edges: List[Tuple[NodeId, RelId, NodeId]]) -> List[float]:
+         return [self.value] * len(edges)
+
+
+ class NPLLConfidence:
+     """
+     Wraps an NPLL model for retrieval-time scoring with LRU caching.
+     Uses the model's scoring module for batched probabilities when available.
+
+     PRODUCTION FIX: Uses bounded LRU cache to prevent memory leaks in long-running processes.
+     """
+
+     def __init__(self, npll_model, cache_size: int = 10000):
+         """
+         Args:
+             npll_model: Trained NPLL model for scoring triples
+             cache_size: Maximum number of cached confidence scores (default: 10K)
+         """
+         self.model = npll_model
+         self.cache_size = cache_size
+         self._cache: OrderedDict[Tuple[str, str, str], float] = OrderedDict()
+
+     def confidence(self, u: NodeId, rel: RelId, v: NodeId) -> float:
+         return self.confidence_batch([(u, rel, v)])[0]
+
+     def confidence_batch(self, edges: List[Tuple[NodeId, RelId, NodeId]]) -> List[float]:
+         todo = [(u, r, v) for (u, r, v) in edges if (u, r, v) not in self._cache]
+         if todo:
+             heads, rels, tails = zip(*todo)
+             self.model.eval()
+             with torch.no_grad():
+                 scores = self.model.scoring_module.forward_with_names(list(heads), list(rels), list(tails))
+                 # Don't apply per-group temperature scaling (requires group_ids we don't have)
+                 probs = self.model.probability_transform(scores, apply_temperature=False)
+             for (u, r, v), p in zip(todo, probs.tolist()):
+                 confidence = max(1e-6, float(p)) if isfinite(p) else 1e-6
+
+                 # LRU eviction: remove oldest if at capacity
+                 if len(self._cache) >= self.cache_size:
+                     self._cache.popitem(last=False)  # Remove oldest (FIFO)
+
+                 self._cache[(u, r, v)] = confidence
+
+         # Move accessed items to end (LRU behavior)
+         result = []
+         for edge in edges:
+             conf = self._cache[edge]
+             # Move to end to mark as recently used
+             self._cache.move_to_end(edge)
+             result.append(conf)
+
+         return result
+
+     def clear_cache(self):
+         """Clear the confidence cache. Useful for testing or memory management."""
+         self._cache.clear()
+
+     def cache_stats(self) -> dict:
+         """Return cache statistics for monitoring."""
+         return {
+             "size": len(self._cache),
+             "max_size": self.cache_size,
+             "utilization": len(self._cache) / self.cache_size if self.cache_size > 0 else 0,
+         }
+
+
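
A minimal sketch of the EdgeConfidenceProvider protocol above using ConstantConfidence; the edge triples are invented for illustration, and a trained model would swap in NPLLConfidence for calibrated scores.

# Sketch only: edge IDs below are hypothetical.
from retrieval.confidence import ConstantConfidence

provider = ConstantConfidence(value=0.8)
edges = [("acct_1", "owns", "policy_9"), ("policy_9", "covers", "vehicle_3")]

print(provider.confidence(*edges[0]))      # 0.8
print(provider.confidence_batch(edges))    # [0.8, 0.8]

# With a trained NPLL model, one would instead construct
# NPLLConfidence(npll_model, cache_size=10000) and call the same
# confidence()/confidence_batch() methods; results are served through
# its bounded LRU cache.
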
retrieval/eval.py CHANGED
@@ -1,49 +1,49 @@
- from __future__ import annotations
- from typing import List, Tuple, Set
- from sklearn.isotonic import IsotonicRegression
- import numpy as np
- import numpy as np
-
-
- def recall_at_k(predicted: List[Tuple[str, float]], relevant: Set[str], k: int = 10) -> float:
-     if not predicted or not relevant:
-         return 0.0
-     top = [n for n, _ in predicted[:k]]
-     hits = sum(1 for n in top if n in relevant)
-     denom = min(k, len(relevant)) or 1
-     return hits / denom
-
-
- def expected_calibration_error(probs: np.ndarray, labels: np.ndarray, n_bins: int = 10) -> float:
-     bins = np.linspace(0.0, 1.0, n_bins + 1)
-     ece = 0.0
-     N = len(probs)
-     for i in range(n_bins):
-         lo, hi = bins[i], bins[i + 1]
-         mask = (probs > lo) & (probs <= hi)
-         if not np.any(mask):
-             continue
-         acc = labels[mask].mean()
-         conf = probs[mask].mean()
-         ece += np.abs(conf - acc) * (mask.sum() / max(N, 1))
-     return float(ece)
-
-
- class SimpleLLMCalibrator:
-     """
-     Wraps isotonic regression to calibrate LLM self-reported confidences.
-     """
-     def __init__(self):
-         self.iso = IsotonicRegression(out_of_bounds='clip')
-         self.is_fitted = False
-
-     def fit(self, raw_conf: np.ndarray, labels: np.ndarray):
-         self.iso.fit(raw_conf, labels)
-         self.is_fitted = True
-
-     def transform(self, raw_conf: np.ndarray) -> np.ndarray:
-         if not self.is_fitted:
-             return raw_conf
-         return self.iso.transform(raw_conf)
-
-
+ from __future__ import annotations
+ from typing import List, Tuple, Set
+ from sklearn.isotonic import IsotonicRegression
+ import numpy as np
+ import numpy as np
+
+
+ def recall_at_k(predicted: List[Tuple[str, float]], relevant: Set[str], k: int = 10) -> float:
+     if not predicted or not relevant:
+         return 0.0
+     top = [n for n, _ in predicted[:k]]
+     hits = sum(1 for n in top if n in relevant)
+     denom = min(k, len(relevant)) or 1
+     return hits / denom
+
+
+ def expected_calibration_error(probs: np.ndarray, labels: np.ndarray, n_bins: int = 10) -> float:
+     bins = np.linspace(0.0, 1.0, n_bins + 1)
+     ece = 0.0
+     N = len(probs)
+     for i in range(n_bins):
+         lo, hi = bins[i], bins[i + 1]
+         mask = (probs > lo) & (probs <= hi)
+         if not np.any(mask):
+             continue
+         acc = labels[mask].mean()
+         conf = probs[mask].mean()
+         ece += np.abs(conf - acc) * (mask.sum() / max(N, 1))
+     return float(ece)
+
+
+ class SimpleLLMCalibrator:
+     """
+     Wraps isotonic regression to calibrate LLM self-reported confidences.
+     """
+     def __init__(self):
+         self.iso = IsotonicRegression(out_of_bounds='clip')
+         self.is_fitted = False
+
+     def fit(self, raw_conf: np.ndarray, labels: np.ndarray):
+         self.iso.fit(raw_conf, labels)
+         self.is_fitted = True
+
+     def transform(self, raw_conf: np.ndarray) -> np.ndarray:
+         if not self.is_fitted:
+             return raw_conf
+         return self.iso.transform(raw_conf)
+
+
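
A minimal sketch of the eval helpers above on synthetic arrays; all numbers and document IDs are invented for illustration.

# Sketch only: synthetic confidences, labels, and rankings.
import numpy as np
from retrieval.eval import expected_calibration_error, SimpleLLMCalibrator, recall_at_k

probs = np.array([0.9, 0.8, 0.7, 0.4, 0.2, 0.1])
labels = np.array([1, 1, 0, 0, 0, 1])

# Binned gap between confidence and observed accuracy.
print(expected_calibration_error(probs, labels, n_bins=5))

# Fit isotonic calibration on raw LLM confidences vs. correctness,
# then reuse the fitted mapping on new scores.
cal = SimpleLLMCalibrator()
cal.fit(probs, labels)
print(cal.transform(np.array([0.85, 0.15])))

# Recall@k over a ranked node list against a relevant set.
ranked = [("doc_a", 0.9), ("doc_b", 0.7), ("doc_c", 0.4)]
print(recall_at_k(ranked, {"doc_a", "doc_c"}, k=2))
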