odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
retrieval/cache.py
CHANGED
@@ -1,159 +1,159 @@
(Every line in the hunk is removed and re-added with identical content, so the file body is shown once below.)

```python
"""
Caching layer for GraphAccessor to prevent network hammering during PPR.

PRODUCTION FIX: PPR's push algorithm repeatedly calls accessor.iter_out(u) for
the same nodes, causing excessive network traffic. CachedGraphAccessor wraps any
accessor and caches neighbor lookups.
"""

from __future__ import annotations
from typing import Iterable, Optional, List, Tuple
from collections import OrderedDict

from .adapters import GraphAccessor, NodeId, RelId


class CachedGraphAccessor:
    """
    Wraps a GraphAccessor with LRU caching for neighbor queries.

    Critical for production: Prevents "network hammer" issue where PPR
    makes repeated calls to iter_out() for the same nodes, each hitting
    the database/network.

    Usage:
        base_accessor = ArangoCommunityAccessor(db, community_id="insurance")
        cached_accessor = CachedGraphAccessor(base_accessor, cache_size=5000)

        # Now PPR won't hammer the network
        orchestrator = RetrievalOrchestrator(accessor=cached_accessor, ...)
    """

    def __init__(self, base: GraphAccessor, cache_size: int = 5000):
        """
        Args:
            base: The underlying GraphAccessor to wrap
            cache_size: Maximum number of nodes to cache (default: 5000)
        """
        self.base = base
        self.cache_size = cache_size

        # LRU caches for outbound and inbound neighbors
        self._out_cache: OrderedDict[NodeId, List[Tuple[NodeId, RelId, float]]] = OrderedDict()
        self._in_cache: OrderedDict[NodeId, List[Tuple[NodeId, RelId, float]]] = OrderedDict()

        # Stats for monitoring
        self._hits = 0
        self._misses = 0

    def iter_out(self, node: NodeId) -> Iterable[Tuple[NodeId, RelId, float]]:
        """Get outbound neighbors with caching."""
        if node in self._out_cache:
            # Cache hit - move to end (LRU)
            self._out_cache.move_to_end(node)
            self._hits += 1
            return iter(self._out_cache[node])

        # Cache miss - fetch from base accessor
        self._misses += 1
        neighbors = list(self.base.iter_out(node))

        # Store in cache with LRU eviction
        if len(self._out_cache) >= self.cache_size:
            self._out_cache.popitem(last=False)  # Remove oldest
        self._out_cache[node] = neighbors

        return iter(neighbors)

    def iter_in(self, node: NodeId) -> Iterable[Tuple[NodeId, RelId, float]]:
        """Get inbound neighbors with caching."""
        if node in self._in_cache:
            # Cache hit - move to end (LRU)
            self._in_cache.move_to_end(node)
            self._hits += 1
            return iter(self._in_cache[node])

        # Cache miss - fetch from base accessor
        self._misses += 1
        neighbors = list(self.base.iter_in(node))

        # Store in cache with LRU eviction
        if len(self._in_cache) >= self.cache_size:
            self._in_cache.popitem(last=False)  # Remove oldest
        self._in_cache[node] = neighbors

        return iter(neighbors)

    def nodes(self, community_id: Optional[str] = None) -> Iterable[NodeId]:
        """Pass through to base accessor (no caching)."""
        return self.base.nodes(community_id)

    def get_node(self, node_id: NodeId, fields: Optional[List[str]] = None) -> dict:
        """Pass through to base accessor (node lookups are typically one-off)."""
        return self.base.get_node(node_id, fields)

    def degree(self, node: NodeId) -> int:
        """Pass through to base accessor."""
        return self.base.degree(node)

    def community_seed_norm(self, community_id: str, seeds: List[str]) -> List[str]:
        """Pass through to base accessor."""
        return self.base.community_seed_norm(community_id, seeds)

    def clear_cache(self):
        """Clear all caches. Useful for memory management or testing."""
        self._out_cache.clear()
        self._in_cache.clear()
        self._hits = 0
        self._misses = 0

    def cache_stats(self) -> dict:
        """
        Return cache statistics for monitoring.

        Returns:
            Dict with hit rate, sizes, and utilization metrics
        """
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0

        return {
            "hits": self._hits,
            "misses": self._misses,
            "total_requests": total,
            "hit_rate": hit_rate,
            "out_cache_size": len(self._out_cache),
            "in_cache_size": len(self._in_cache),
            "max_cache_size": self.cache_size,
            "out_cache_utilization": len(self._out_cache) / self.cache_size,
            "in_cache_utilization": len(self._in_cache) / self.cache_size,
        }

    def warm_cache(self, nodes: List[NodeId], direction: str = "out"):
        """
        Pre-populate cache for a list of nodes.
        Useful for batch operations where you know which nodes will be accessed.

        Args:
            nodes: List of node IDs to pre-fetch
            direction: "out" or "in" for outbound/inbound neighbors
        """
        if direction == "out":
            for node in nodes:
                if node not in self._out_cache:
                    neighbors = list(self.base.iter_out(node))
                    if len(self._out_cache) >= self.cache_size:
                        self._out_cache.popitem(last=False)
                    self._out_cache[node] = neighbors
        elif direction == "in":
            for node in nodes:
                if node not in self._in_cache:
                    neighbors = list(self.base.iter_in(node))
                    if len(self._in_cache) >= self.cache_size:
                        self._in_cache.popitem(last=False)
                    self._in_cache[node] = neighbors

    # Delegate all other methods to base accessor
    def __getattr__(self, name):
        """Delegate unknown methods to the base accessor."""
        return getattr(self.base, name)
```
retrieval/confidence.py
CHANGED
@@ -1,88 +1,88 @@
(Every line in the hunk is removed and re-added with identical content, so the file body is shown once below.)

```python
from __future__ import annotations
from typing import Protocol, Tuple, Optional, Callable, List
from collections import OrderedDict
from math import isfinite
import torch

NodeId = str
RelId = str


class EdgeConfidenceProvider(Protocol):
    def confidence(self, u: NodeId, rel: RelId, v: NodeId) -> float: ...
    def confidence_batch(self, edges: List[Tuple[NodeId, RelId, NodeId]]) -> List[float]: ...


class ConstantConfidence:
    def __init__(self, value: float = 0.8):
        self.value = max(1e-6, min(1.0, value))

    def confidence(self, u, rel, v) -> float:
        return self.value

    def confidence_batch(self, edges: List[Tuple[NodeId, RelId, NodeId]]) -> List[float]:
        return [self.value] * len(edges)


class NPLLConfidence:
    """
    Wraps an NPLL model for retrieval-time scoring with LRU caching.
    Uses the model's scoring module for batched probabilities when available.

    PRODUCTION FIX: Uses bounded LRU cache to prevent memory leaks in long-running processes.
    """

    def __init__(self, npll_model, cache_size: int = 10000):
        """
        Args:
            npll_model: Trained NPLL model for scoring triples
            cache_size: Maximum number of cached confidence scores (default: 10K)
        """
        self.model = npll_model
        self.cache_size = cache_size
        self._cache: OrderedDict[Tuple[str, str, str], float] = OrderedDict()

    def confidence(self, u: NodeId, rel: RelId, v: NodeId) -> float:
        return self.confidence_batch([(u, rel, v)])[0]

    def confidence_batch(self, edges: List[Tuple[NodeId, RelId, NodeId]]) -> List[float]:
        todo = [(u, r, v) for (u, r, v) in edges if (u, r, v) not in self._cache]
        if todo:
            heads, rels, tails = zip(*todo)
            self.model.eval()
            with torch.no_grad():
                scores = self.model.scoring_module.forward_with_names(list(heads), list(rels), list(tails))
                # Don't apply per-group temperature scaling (requires group_ids we don't have)
                probs = self.model.probability_transform(scores, apply_temperature=False)
            for (u, r, v), p in zip(todo, probs.tolist()):
                confidence = max(1e-6, float(p)) if isfinite(p) else 1e-6

                # LRU eviction: remove oldest if at capacity
                if len(self._cache) >= self.cache_size:
                    self._cache.popitem(last=False)  # Remove oldest (FIFO)

                self._cache[(u, r, v)] = confidence

        # Move accessed items to end (LRU behavior)
        result = []
        for edge in edges:
            conf = self._cache[edge]
            # Move to end to mark as recently used
            self._cache.move_to_end(edge)
            result.append(conf)

        return result

    def clear_cache(self):
        """Clear the confidence cache. Useful for testing or memory management."""
        self._cache.clear()

    def cache_stats(self) -> dict:
        """Return cache statistics for monitoring."""
        return {
            "size": len(self._cache),
            "max_size": self.cache_size,
            "utilization": len(self._cache) / self.cache_size if self.cache_size > 0 else 0,
        }
```
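Both providers satisfy the `EdgeConfidenceProvider` protocol, so callers can take either interchangeably. A short sketch using `ConstantConfidence`, which needs no model (the edge identifiers below are made up for illustration; `NPLLConfidence` would additionally need a trained NPLL model, not shown here):

```python
from retrieval.confidence import ConstantConfidence

provider = ConstantConfidence(value=1.7)   # out-of-range values are clamped into (0, 1]
edges = [("claim:1", "filed_by", "person:9"),        # illustrative edge IDs
         ("person:9", "insured_under", "policy:3")]

print(provider.confidence(*edges[0]))    # 1.0 (clamped from 1.7)
print(provider.confidence_batch(edges))  # [1.0, 1.0]
```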
retrieval/eval.py
CHANGED
@@ -1,49 +1,49 @@
(Every line in the hunk is removed and re-added with identical content, so the file body is shown once below.)

```python
from __future__ import annotations
from typing import List, Tuple, Set
from sklearn.isotonic import IsotonicRegression
import numpy as np
import numpy as np  # duplicated import, present as-is in the published file


def recall_at_k(predicted: List[Tuple[str, float]], relevant: Set[str], k: int = 10) -> float:
    if not predicted or not relevant:
        return 0.0
    top = [n for n, _ in predicted[:k]]
    hits = sum(1 for n in top if n in relevant)
    denom = min(k, len(relevant)) or 1
    return hits / denom


def expected_calibration_error(probs: np.ndarray, labels: np.ndarray, n_bins: int = 10) -> float:
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    N = len(probs)
    for i in range(n_bins):
        lo, hi = bins[i], bins[i + 1]
        mask = (probs > lo) & (probs <= hi)
        if not np.any(mask):
            continue
        acc = labels[mask].mean()
        conf = probs[mask].mean()
        ece += np.abs(conf - acc) * (mask.sum() / max(N, 1))
    return float(ece)


class SimpleLLMCalibrator:
    """
    Wraps isotonic regression to calibrate LLM self-reported confidences.
    """
    def __init__(self):
        self.iso = IsotonicRegression(out_of_bounds='clip')
        self.is_fitted = False

    def fit(self, raw_conf: np.ndarray, labels: np.ndarray):
        self.iso.fit(raw_conf, labels)
        self.is_fitted = True

    def transform(self, raw_conf: np.ndarray) -> np.ndarray:
        if not self.is_fitted:
            return raw_conf
        return self.iso.transform(raw_conf)
```
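A short sketch of the three utilities together; all inputs below are toy data invented for illustration, not package output.

```python
import numpy as np
from retrieval.eval import recall_at_k, expected_calibration_error, SimpleLLMCalibrator

# recall@k over a ranked prediction list
predicted = [("doc1", 0.9), ("doc2", 0.7), ("doc3", 0.4)]
print(recall_at_k(predicted, relevant={"doc1", "doc3"}, k=2))  # 0.5: one of two relevant docs in top-2

# ECE over raw confidences vs. binary correctness labels
probs = np.array([0.9, 0.8, 0.3, 0.2])
labels = np.array([1, 1, 0, 1])
print(expected_calibration_error(probs, labels, n_bins=10))

# Fit the isotonic calibrator on observed (confidence, correctness) pairs,
# then reuse the fitted monotone mapping on new raw scores.
cal = SimpleLLMCalibrator()
cal.fit(probs, labels)
print(cal.transform(np.array([0.25, 0.85])))
```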