odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
retrieval/beam.py
CHANGED
|
@@ -1,127 +1,127 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
from typing import List, Tuple, Dict, Optional, Set
|
|
3
|
-
from dataclasses import dataclass
|
|
4
|
-
import heapq, math, time
|
|
5
|
-
from .adapters import GraphAccessor, NodeId, RelId
|
|
6
|
-
from .confidence import EdgeConfidenceProvider, ConstantConfidence
|
|
7
|
-
from .budget import SearchBudget, BudgetTracker
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
@dataclass
|
|
11
|
-
class BeamParams:
|
|
12
|
-
hop_limit: int = 3
|
|
13
|
-
beam_width: int = 64
|
|
14
|
-
max_paths: int = 200
|
|
15
|
-
lambdas: Tuple[float, float, float, float] = (0.6, 0.2, 0.15, 0.05)
|
|
16
|
-
allowed_relations: Optional[Set[RelId]] = None
|
|
17
|
-
max_out_degree: Optional[int] = None
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def default_recency(edge_timestamp: Optional[float], now_ts: Optional[float] = None, tau_days: float = 90.0) -> float:
|
|
21
|
-
if edge_timestamp is None:
|
|
22
|
-
return 1.0
|
|
23
|
-
if now_ts is None:
|
|
24
|
-
now_ts = time.time()
|
|
25
|
-
dt_days = max(0.0, (now_ts - edge_timestamp) / (60 * 60 * 24))
|
|
26
|
-
return math.exp(-dt_days / tau_days)
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def safe_logp(x: float) -> float:
|
|
30
|
-
return math.log(max(x, 1e-12))
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def beam_search(
|
|
34
|
-
accessor: GraphAccessor,
|
|
35
|
-
community_id: str,
|
|
36
|
-
seeds: List[NodeId],
|
|
37
|
-
ppr_scores: List[Tuple[NodeId, float]],
|
|
38
|
-
budget: Optional[SearchBudget] = None,
|
|
39
|
-
beam_params: BeamParams = BeamParams(),
|
|
40
|
-
conf_provider: EdgeConfidenceProvider = ConstantConfidence(0.8),
|
|
41
|
-
edge_type_prior: Optional[Dict[RelId, float]] = None,
|
|
42
|
-
edge_timestamp_lookup=None,
|
|
43
|
-
) -> Dict[str, object]:
|
|
44
|
-
|
|
45
|
-
bt = BudgetTracker(budget or SearchBudget(max_paths=beam_params.max_paths))
|
|
46
|
-
L1, L2, L3, L4 = beam_params.lambdas
|
|
47
|
-
ppr = {n: p for n, p in ppr_scores}
|
|
48
|
-
edge_type_prior = edge_type_prior or {}
|
|
49
|
-
|
|
50
|
-
heap: List[Tuple[float, List[NodeId], List[Tuple[NodeId, RelId, NodeId]]]] = []
|
|
51
|
-
for s in seeds:
|
|
52
|
-
heapq.heappush(heap, (0.0, [s], []))
|
|
53
|
-
|
|
54
|
-
best_paths = []
|
|
55
|
-
|
|
56
|
-
def score_extension(u: NodeId, rel: RelId, v: NodeId) -> float:
|
|
57
|
-
p1 = ppr.get(v, 1e-12)
|
|
58
|
-
w_edge = edge_type_prior.get(rel, 1.0)
|
|
59
|
-
c = conf_provider.confidence(u, rel, v)
|
|
60
|
-
rec = default_recency(edge_timestamp_lookup(u, rel, v) if edge_timestamp_lookup else None)
|
|
61
|
-
return L1 * safe_logp(p1) + L2 * safe_logp(w_edge) + L3 * safe_logp(c) + L4 * safe_logp(rec)
|
|
62
|
-
|
|
63
|
-
early_stop_reason = None
|
|
64
|
-
for hop in range(1, beam_params.hop_limit + 1):
|
|
65
|
-
next_heap: List[Tuple[float, List[NodeId], List[Tuple[NodeId, RelId, NodeId]]]] = []
|
|
66
|
-
while heap and not bt.over():
|
|
67
|
-
logscore, path_nodes, path_edges = heapq.heappop(heap)
|
|
68
|
-
u = path_nodes[-1]
|
|
69
|
-
bt.tick_nodes(1)
|
|
70
|
-
out_iter = accessor.iter_out(u)
|
|
71
|
-
if beam_params.max_out_degree is not None:
|
|
72
|
-
# Degree cap: take only first N neighbors
|
|
73
|
-
out_iter = list(out_iter)[: beam_params.max_out_degree]
|
|
74
|
-
for v, rel, _ in out_iter:
|
|
75
|
-
if bt.over():
|
|
76
|
-
break
|
|
77
|
-
bt.tick_edges(1)
|
|
78
|
-
if v in path_nodes:
|
|
79
|
-
continue
|
|
80
|
-
if beam_params.allowed_relations is not None and rel not in beam_params.allowed_relations:
|
|
81
|
-
continue
|
|
82
|
-
inc = score_extension(u, rel, v)
|
|
83
|
-
new_score = logscore + inc
|
|
84
|
-
new_nodes = path_nodes + [v]
|
|
85
|
-
new_edges = path_edges + [(u, rel, v)]
|
|
86
|
-
heapq.heappush(next_heap, (new_score, new_nodes, new_edges))
|
|
87
|
-
if len(next_heap) > beam_params.beam_width:
|
|
88
|
-
heapq.heappop(next_heap)
|
|
89
|
-
if bt.timed_out():
|
|
90
|
-
early_stop_reason = early_stop_reason or "timeout"
|
|
91
|
-
break
|
|
92
|
-
|
|
93
|
-
next_heap.sort(key=lambda x: x[0], reverse=True)
|
|
94
|
-
for sc, nodes, edges in next_heap:
|
|
95
|
-
best_paths.append((sc, nodes, edges))
|
|
96
|
-
bt.tick_paths(1)
|
|
97
|
-
if bt.over():
|
|
98
|
-
break
|
|
99
|
-
heap = next_heap
|
|
100
|
-
if bt.over():
|
|
101
|
-
if early_stop_reason is None:
|
|
102
|
-
# Determine reason
|
|
103
|
-
if bt.usage.nodes >= bt.budget.max_nodes:
|
|
104
|
-
early_stop_reason = "max_nodes"
|
|
105
|
-
elif bt.usage.edges >= bt.budget.max_edges:
|
|
106
|
-
early_stop_reason = "max_edges"
|
|
107
|
-
elif bt.usage.paths >= bt.budget.max_paths:
|
|
108
|
-
early_stop_reason = "max_paths"
|
|
109
|
-
else:
|
|
110
|
-
early_stop_reason = "budget_exhausted"
|
|
111
|
-
break
|
|
112
|
-
|
|
113
|
-
best_paths.sort(key=lambda x: x[0], reverse=True)
|
|
114
|
-
return {
|
|
115
|
-
"paths": [
|
|
116
|
-
{
|
|
117
|
-
"score": float(sc),
|
|
118
|
-
"nodes": ns,
|
|
119
|
-
"edges": [{"u": u, "rel": r, "v": v} for (u, r, v) in es],
|
|
120
|
-
}
|
|
121
|
-
for sc, ns, es in best_paths[: beam_params.max_paths]
|
|
122
|
-
],
|
|
123
|
-
"used_budget": bt.usage.__dict__,
|
|
124
|
-
"trace": {"beam_width": beam_params.beam_width, "hop_limit": beam_params.hop_limit, "early_stop_reason": early_stop_reason},
|
|
125
|
-
}
|
|
126
|
-
|
|
127
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import List, Tuple, Dict, Optional, Set
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
import heapq, math, time
|
|
5
|
+
from .adapters import GraphAccessor, NodeId, RelId
|
|
6
|
+
from .confidence import EdgeConfidenceProvider, ConstantConfidence
|
|
7
|
+
from .budget import SearchBudget, BudgetTracker
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
|
|
11
|
+
class BeamParams:
|
|
12
|
+
hop_limit: int = 3
|
|
13
|
+
beam_width: int = 64
|
|
14
|
+
max_paths: int = 200
|
|
15
|
+
lambdas: Tuple[float, float, float, float] = (0.6, 0.2, 0.15, 0.05)
|
|
16
|
+
allowed_relations: Optional[Set[RelId]] = None
|
|
17
|
+
max_out_degree: Optional[int] = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def default_recency(edge_timestamp: Optional[float], now_ts: Optional[float] = None, tau_days: float = 90.0) -> float:
|
|
21
|
+
if edge_timestamp is None:
|
|
22
|
+
return 1.0
|
|
23
|
+
if now_ts is None:
|
|
24
|
+
now_ts = time.time()
|
|
25
|
+
dt_days = max(0.0, (now_ts - edge_timestamp) / (60 * 60 * 24))
|
|
26
|
+
return math.exp(-dt_days / tau_days)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def safe_logp(x: float) -> float:
|
|
30
|
+
return math.log(max(x, 1e-12))
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def beam_search(
|
|
34
|
+
accessor: GraphAccessor,
|
|
35
|
+
community_id: str,
|
|
36
|
+
seeds: List[NodeId],
|
|
37
|
+
ppr_scores: List[Tuple[NodeId, float]],
|
|
38
|
+
budget: Optional[SearchBudget] = None,
|
|
39
|
+
beam_params: BeamParams = BeamParams(),
|
|
40
|
+
conf_provider: EdgeConfidenceProvider = ConstantConfidence(0.8),
|
|
41
|
+
edge_type_prior: Optional[Dict[RelId, float]] = None,
|
|
42
|
+
edge_timestamp_lookup=None,
|
|
43
|
+
) -> Dict[str, object]:
|
|
44
|
+
|
|
45
|
+
bt = BudgetTracker(budget or SearchBudget(max_paths=beam_params.max_paths))
|
|
46
|
+
L1, L2, L3, L4 = beam_params.lambdas
|
|
47
|
+
ppr = {n: p for n, p in ppr_scores}
|
|
48
|
+
edge_type_prior = edge_type_prior or {}
|
|
49
|
+
|
|
50
|
+
heap: List[Tuple[float, List[NodeId], List[Tuple[NodeId, RelId, NodeId]]]] = []
|
|
51
|
+
for s in seeds:
|
|
52
|
+
heapq.heappush(heap, (0.0, [s], []))
|
|
53
|
+
|
|
54
|
+
best_paths = []
|
|
55
|
+
|
|
56
|
+
def score_extension(u: NodeId, rel: RelId, v: NodeId) -> float:
|
|
57
|
+
p1 = ppr.get(v, 1e-12)
|
|
58
|
+
w_edge = edge_type_prior.get(rel, 1.0)
|
|
59
|
+
c = conf_provider.confidence(u, rel, v)
|
|
60
|
+
rec = default_recency(edge_timestamp_lookup(u, rel, v) if edge_timestamp_lookup else None)
|
|
61
|
+
return L1 * safe_logp(p1) + L2 * safe_logp(w_edge) + L3 * safe_logp(c) + L4 * safe_logp(rec)
|
|
62
|
+
|
|
63
|
+
early_stop_reason = None
|
|
64
|
+
for hop in range(1, beam_params.hop_limit + 1):
|
|
65
|
+
next_heap: List[Tuple[float, List[NodeId], List[Tuple[NodeId, RelId, NodeId]]]] = []
|
|
66
|
+
while heap and not bt.over():
|
|
67
|
+
logscore, path_nodes, path_edges = heapq.heappop(heap)
|
|
68
|
+
u = path_nodes[-1]
|
|
69
|
+
bt.tick_nodes(1)
|
|
70
|
+
out_iter = accessor.iter_out(u)
|
|
71
|
+
if beam_params.max_out_degree is not None:
|
|
72
|
+
# Degree cap: take only first N neighbors
|
|
73
|
+
out_iter = list(out_iter)[: beam_params.max_out_degree]
|
|
74
|
+
for v, rel, _ in out_iter:
|
|
75
|
+
if bt.over():
|
|
76
|
+
break
|
|
77
|
+
bt.tick_edges(1)
|
|
78
|
+
if v in path_nodes:
|
|
79
|
+
continue
|
|
80
|
+
if beam_params.allowed_relations is not None and rel not in beam_params.allowed_relations:
|
|
81
|
+
continue
|
|
82
|
+
inc = score_extension(u, rel, v)
|
|
83
|
+
new_score = logscore + inc
|
|
84
|
+
new_nodes = path_nodes + [v]
|
|
85
|
+
new_edges = path_edges + [(u, rel, v)]
|
|
86
|
+
heapq.heappush(next_heap, (new_score, new_nodes, new_edges))
|
|
87
|
+
if len(next_heap) > beam_params.beam_width:
|
|
88
|
+
heapq.heappop(next_heap)
|
|
89
|
+
if bt.timed_out():
|
|
90
|
+
early_stop_reason = early_stop_reason or "timeout"
|
|
91
|
+
break
|
|
92
|
+
|
|
93
|
+
next_heap.sort(key=lambda x: x[0], reverse=True)
|
|
94
|
+
for sc, nodes, edges in next_heap:
|
|
95
|
+
best_paths.append((sc, nodes, edges))
|
|
96
|
+
bt.tick_paths(1)
|
|
97
|
+
if bt.over():
|
|
98
|
+
break
|
|
99
|
+
heap = next_heap
|
|
100
|
+
if bt.over():
|
|
101
|
+
if early_stop_reason is None:
|
|
102
|
+
# Determine reason
|
|
103
|
+
if bt.usage.nodes >= bt.budget.max_nodes:
|
|
104
|
+
early_stop_reason = "max_nodes"
|
|
105
|
+
elif bt.usage.edges >= bt.budget.max_edges:
|
|
106
|
+
early_stop_reason = "max_edges"
|
|
107
|
+
elif bt.usage.paths >= bt.budget.max_paths:
|
|
108
|
+
early_stop_reason = "max_paths"
|
|
109
|
+
else:
|
|
110
|
+
early_stop_reason = "budget_exhausted"
|
|
111
|
+
break
|
|
112
|
+
|
|
113
|
+
best_paths.sort(key=lambda x: x[0], reverse=True)
|
|
114
|
+
return {
|
|
115
|
+
"paths": [
|
|
116
|
+
{
|
|
117
|
+
"score": float(sc),
|
|
118
|
+
"nodes": ns,
|
|
119
|
+
"edges": [{"u": u, "rel": r, "v": v} for (u, r, v) in es],
|
|
120
|
+
}
|
|
121
|
+
for sc, ns, es in best_paths[: beam_params.max_paths]
|
|
122
|
+
],
|
|
123
|
+
"used_budget": bt.usage.__dict__,
|
|
124
|
+
"trace": {"beam_width": beam_params.beam_width, "hop_limit": beam_params.hop_limit, "early_stop_reason": early_stop_reason},
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
|
retrieval/budget.py
CHANGED
|
@@ -1,60 +1,60 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
from dataclasses import dataclass
|
|
3
|
-
import time
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
@dataclass
|
|
7
|
-
class SearchBudget:
|
|
8
|
-
max_nodes: int = 2_000
|
|
9
|
-
max_edges: int = 10_000
|
|
10
|
-
max_ms: int = 1_500
|
|
11
|
-
max_paths: int = 200
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
@dataclass
|
|
15
|
-
class Usage:
|
|
16
|
-
nodes: int = 0
|
|
17
|
-
edges: int = 0
|
|
18
|
-
ms: int = 0
|
|
19
|
-
paths: int = 0
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class BudgetTracker:
|
|
23
|
-
def __init__(self, budget: SearchBudget):
|
|
24
|
-
self.budget = budget
|
|
25
|
-
self.usage = Usage()
|
|
26
|
-
self._start = time.perf_counter()
|
|
27
|
-
|
|
28
|
-
def tick_nodes(self, n=1):
|
|
29
|
-
self.usage.nodes += n
|
|
30
|
-
|
|
31
|
-
def tick_edges(self, n=1):
|
|
32
|
-
self.usage.edges += n
|
|
33
|
-
|
|
34
|
-
def tick_paths(self, n=1):
|
|
35
|
-
self.usage.paths += n
|
|
36
|
-
|
|
37
|
-
def timed_out(self) -> bool:
|
|
38
|
-
self.usage.ms = int((time.perf_counter() - self._start) * 1000)
|
|
39
|
-
return self.usage.ms >= self.budget.max_ms
|
|
40
|
-
|
|
41
|
-
def over(self) -> bool:
|
|
42
|
-
self.timed_out()
|
|
43
|
-
b, u = self.budget, self.usage
|
|
44
|
-
return (
|
|
45
|
-
u.nodes >= b.max_nodes
|
|
46
|
-
or u.edges >= b.max_edges
|
|
47
|
-
or u.paths >= b.max_paths
|
|
48
|
-
or u.ms >= b.max_ms
|
|
49
|
-
)
|
|
50
|
-
|
|
51
|
-
def left(self) -> SearchBudget:
|
|
52
|
-
self.timed_out()
|
|
53
|
-
return SearchBudget(
|
|
54
|
-
max_nodes=max(0, self.budget.max_nodes - self.usage.nodes),
|
|
55
|
-
max_edges=max(0, self.budget.max_edges - self.usage.edges),
|
|
56
|
-
max_ms=max(0, self.budget.max_ms - self.usage.ms),
|
|
57
|
-
max_paths=max(0, self.budget.max_paths - self.usage.paths),
|
|
58
|
-
)
|
|
59
|
-
|
|
60
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
|
|
7
|
+
class SearchBudget:
|
|
8
|
+
max_nodes: int = 2_000
|
|
9
|
+
max_edges: int = 10_000
|
|
10
|
+
max_ms: int = 1_500
|
|
11
|
+
max_paths: int = 200
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
|
|
15
|
+
class Usage:
|
|
16
|
+
nodes: int = 0
|
|
17
|
+
edges: int = 0
|
|
18
|
+
ms: int = 0
|
|
19
|
+
paths: int = 0
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class BudgetTracker:
|
|
23
|
+
def __init__(self, budget: SearchBudget):
|
|
24
|
+
self.budget = budget
|
|
25
|
+
self.usage = Usage()
|
|
26
|
+
self._start = time.perf_counter()
|
|
27
|
+
|
|
28
|
+
def tick_nodes(self, n=1):
|
|
29
|
+
self.usage.nodes += n
|
|
30
|
+
|
|
31
|
+
def tick_edges(self, n=1):
|
|
32
|
+
self.usage.edges += n
|
|
33
|
+
|
|
34
|
+
def tick_paths(self, n=1):
|
|
35
|
+
self.usage.paths += n
|
|
36
|
+
|
|
37
|
+
def timed_out(self) -> bool:
|
|
38
|
+
self.usage.ms = int((time.perf_counter() - self._start) * 1000)
|
|
39
|
+
return self.usage.ms >= self.budget.max_ms
|
|
40
|
+
|
|
41
|
+
def over(self) -> bool:
|
|
42
|
+
self.timed_out()
|
|
43
|
+
b, u = self.budget, self.usage
|
|
44
|
+
return (
|
|
45
|
+
u.nodes >= b.max_nodes
|
|
46
|
+
or u.edges >= b.max_edges
|
|
47
|
+
or u.paths >= b.max_paths
|
|
48
|
+
or u.ms >= b.max_ms
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
def left(self) -> SearchBudget:
|
|
52
|
+
self.timed_out()
|
|
53
|
+
return SearchBudget(
|
|
54
|
+
max_nodes=max(0, self.budget.max_nodes - self.usage.nodes),
|
|
55
|
+
max_edges=max(0, self.budget.max_edges - self.usage.edges),
|
|
56
|
+
max_ms=max(0, self.budget.max_ms - self.usage.ms),
|
|
57
|
+
max_paths=max(0, self.budget.max_paths - self.usage.paths),
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
|