odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
retrieval/ppr/indexes.py
CHANGED
|
@@ -1,78 +1,78 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
from dataclasses import dataclass
|
|
3
|
-
from typing import Dict, List, Iterable, Optional
|
|
4
|
-
import random
|
|
5
|
-
|
|
6
|
-
NodeId = int
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
@dataclass
|
|
10
|
-
class WalkIndexConfig:
|
|
11
|
-
omega: int = 10 # walks per node (cap)
|
|
12
|
-
rmax: float = 1e-3 # residual threshold knob (for capacity heuristics)
|
|
13
|
-
alpha: float = 0.15 # teleport
|
|
14
|
-
seed: int = 42
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class RandomWalkIndex:
|
|
18
|
-
"""
|
|
19
|
-
FIRM-style random-walk index (skeleton):
|
|
20
|
-
- stores short geometric walks per node to accelerate SSPPR queries
|
|
21
|
-
- supports O(1) expected-time updates under random arrival model (sketch)
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
def __init__(self, cfg: WalkIndexConfig):
|
|
25
|
-
self.cfg = cfg
|
|
26
|
-
self.walks: Dict[NodeId, List[List[NodeId]]] = {}
|
|
27
|
-
random.seed(cfg.seed)
|
|
28
|
-
|
|
29
|
-
def build(self, graph, nodes: Optional[Iterable[NodeId]] = None):
|
|
30
|
-
nodes = nodes or graph.nodes()
|
|
31
|
-
for u in nodes:
|
|
32
|
-
self.walks[u] = self._sample_walks(graph, u, self.cfg.omega)
|
|
33
|
-
|
|
34
|
-
def _sample_walks(self, graph, u: NodeId, k: int) -> List[List[NodeId]]:
|
|
35
|
-
walks: List[List[NodeId]] = []
|
|
36
|
-
for _ in range(k):
|
|
37
|
-
path = [u]
|
|
38
|
-
v = u
|
|
39
|
-
while True:
|
|
40
|
-
if random.random() < self.cfg.alpha:
|
|
41
|
-
break
|
|
42
|
-
nbrs = list(graph.out_neighbors(v))
|
|
43
|
-
if not nbrs:
|
|
44
|
-
break
|
|
45
|
-
v = random.choice(nbrs)
|
|
46
|
-
path.append(v)
|
|
47
|
-
walks.append(path)
|
|
48
|
-
return walks
|
|
49
|
-
|
|
50
|
-
def on_edge_insert(self, graph, u: NodeId, v: NodeId):
|
|
51
|
-
if u not in self.walks:
|
|
52
|
-
return
|
|
53
|
-
W = self.walks[u]
|
|
54
|
-
target = max(1, int(graph.out_degree(u) * self.cfg.rmax * self.cfg.omega))
|
|
55
|
-
while len(W) < target:
|
|
56
|
-
W.append(self._sample_walks(graph, u, 1)[0])
|
|
57
|
-
while len(W) > target and W:
|
|
58
|
-
W.pop()
|
|
59
|
-
|
|
60
|
-
def on_edge_delete(self, graph, u: NodeId, v: NodeId):
|
|
61
|
-
if u not in self.walks:
|
|
62
|
-
return
|
|
63
|
-
W = self.walks[u]
|
|
64
|
-
for _ in range(min(2, len(W))):
|
|
65
|
-
if W:
|
|
66
|
-
W.pop()
|
|
67
|
-
target = max(1, int(graph.out_degree(u) * self.cfg.rmax * self.cfg.omega))
|
|
68
|
-
while len(W) < target:
|
|
69
|
-
W.append(self._sample_walks(graph, u, 1)[0])
|
|
70
|
-
|
|
71
|
-
def sample_hits(self, source: NodeId) -> Dict[NodeId, int]:
|
|
72
|
-
counts: Dict[NodeId, int] = {}
|
|
73
|
-
for w in self.walks.get(source, []):
|
|
74
|
-
for x in w:
|
|
75
|
-
counts[x] = counts.get(x, 0) + 1
|
|
76
|
-
return counts
|
|
77
|
-
|
|
78
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Dict, List, Iterable, Optional
|
|
4
|
+
import random
|
|
5
|
+
|
|
6
|
+
NodeId = int  # type alias: graph nodes are identified by plain integers throughout this module
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class WalkIndexConfig:
    """Tuning knobs for RandomWalkIndex; all fields have defaults."""

    # Maximum number of stored walks per node (used by build()).
    omega: int = 10  # walks per node (cap)
    # Residual threshold knob; feeds the per-node capacity heuristic
    # target = max(1, int(out_degree * rmax * omega)) used on edge updates.
    rmax: float = 1e-3  # residual threshold knob (for capacity heuristics)
    # Teleport probability: each walk step terminates with this probability,
    # so walk lengths are geometric with mean ~= 1/alpha.
    alpha: float = 0.15  # teleport
    # RNG seed consumed by RandomWalkIndex.__init__ for reproducible sampling.
    seed: int = 42
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RandomWalkIndex:
    """
    FIRM-style random-walk index (skeleton):
    - stores short geometric walks per node to accelerate SSPPR queries
    - supports O(1) expected-time updates under random arrival model (sketch)

    Walks are alpha-terminated: each step ends the walk with probability
    ``cfg.alpha``, so expected walk length is about ``1 / cfg.alpha``.
    """

    def __init__(self, cfg: WalkIndexConfig):
        """Create an empty index.

        Fix: the original called ``random.seed(cfg.seed)``, silently reseeding
        the process-global RNG and perturbing every other user of ``random``.
        A private ``random.Random`` instance gives the same reproducibility
        without the module-level side effect.
        """
        self.cfg = cfg
        # node -> list of stored walks; each walk is a node path starting at the node
        self.walks: Dict[NodeId, List[List[NodeId]]] = {}
        self._rng = random.Random(cfg.seed)

    def build(self, graph, nodes: Optional[Iterable[NodeId]] = None):
        """Sample ``cfg.omega`` walks for every node in *nodes* (default: all nodes).

        Fix: the original ``nodes or graph.nodes()`` treated an explicitly
        empty iterable the same as ``None`` and rebuilt the whole graph;
        an ``is None`` check preserves the caller's intent.
        """
        if nodes is None:
            nodes = graph.nodes()
        for u in nodes:
            self.walks[u] = self._sample_walks(graph, u, self.cfg.omega)

    def _sample_walks(self, graph, u: NodeId, k: int) -> List[List[NodeId]]:
        """Sample *k* independent alpha-terminated walks starting at *u*."""
        walks: List[List[NodeId]] = []
        for _ in range(k):
            path = [u]
            v = u
            # Continue each step with probability 1 - alpha.
            while self._rng.random() >= self.cfg.alpha:
                nbrs = list(graph.out_neighbors(v))
                if not nbrs:  # dangling node: walk ends early
                    break
                v = self._rng.choice(nbrs)
                path.append(v)
            walks.append(path)
        return walks

    def _capacity(self, graph, u: NodeId) -> int:
        """Heuristic per-node walk capacity; at least one walk is always kept."""
        return max(1, int(graph.out_degree(u) * self.cfg.rmax * self.cfg.omega))

    def _resize(self, graph, u: NodeId, target: int) -> None:
        """Grow or shrink the stored walk list for *u* to exactly *target* walks."""
        W = self.walks[u]
        while len(W) < target:
            W.append(self._sample_walks(graph, u, 1)[0])
        del W[target:]  # no-op when already at or below target

    def on_edge_insert(self, graph, u: NodeId, v: NodeId):
        """Restore *u*'s walk pool to its capacity after inserting edge (u, v)."""
        if u not in self.walks:
            return
        self._resize(graph, u, self._capacity(graph, u))

    def on_edge_delete(self, graph, u: NodeId, v: NodeId):
        """Drop up to two (possibly stale) walks for *u*, then restore capacity."""
        if u not in self.walks:
            return
        W = self.walks[u]
        # Discard the last two stored walks; _resize resamples fresh ones if
        # the capacity heuristic still calls for them.
        del W[max(0, len(W) - 2):]
        self._resize(graph, u, self._capacity(graph, u))

    def sample_hits(self, source: NodeId) -> Dict[NodeId, int]:
        """Return visit counts aggregated over all stored walks of *source*."""
        counts: Dict[NodeId, int] = {}
        for w in self.walks.get(source, []):
            for x in w:
                counts[x] = counts.get(x, 0) + 1
        return counts
|
|
77
|
+
|
|
78
|
+
|
retrieval/ppr.py
CHANGED
|
@@ -1,156 +1,156 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
from typing import Dict, List, Tuple, Optional
|
|
3
|
-
from collections import defaultdict, deque
|
|
4
|
-
from dataclasses import dataclass
|
|
5
|
-
import random
|
|
6
|
-
from .adapters import GraphAccessor, NodeId
|
|
7
|
-
from .budget import SearchBudget, BudgetTracker
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
@dataclass
|
|
11
|
-
class PPRParams:
|
|
12
|
-
alpha: float = 0.2
|
|
13
|
-
eps: float = 1e-6
|
|
14
|
-
num_walks: int = 2000
|
|
15
|
-
walk_len: int = 50
|
|
16
|
-
topn: int = 200
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@dataclass
|
|
20
|
-
class PPRResult:
|
|
21
|
-
scores: List[Tuple[NodeId, float]]
|
|
22
|
-
mass: float
|
|
23
|
-
used_budget: Dict[str, int]
|
|
24
|
-
trace: Dict[str, object]
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def build_alias_table(weighted_neighbors: List[Tuple[NodeId, float]]):
|
|
28
|
-
if not weighted_neighbors:
|
|
29
|
-
return [], [], []
|
|
30
|
-
total = sum(max(0.0, w) for _, w in weighted_neighbors) or 1.0
|
|
31
|
-
probs = [(n, w / total) for n, w in weighted_neighbors]
|
|
32
|
-
n = len(probs)
|
|
33
|
-
scaled = [p * n for _, p in probs]
|
|
34
|
-
alias, prob = [0] * n, [0.0] * n
|
|
35
|
-
small, large = [], []
|
|
36
|
-
for i, sp in enumerate(scaled):
|
|
37
|
-
(small if sp < 1 else large).append(i)
|
|
38
|
-
while small and large:
|
|
39
|
-
s, l = small.pop(), large.pop()
|
|
40
|
-
prob[s] = scaled[s]
|
|
41
|
-
alias[s] = l
|
|
42
|
-
scaled[l] = scaled[l] - (1 - prob[s])
|
|
43
|
-
(small if scaled[l] < 1 else large).append(l)
|
|
44
|
-
for i in small + large:
|
|
45
|
-
prob[i] = 1.0
|
|
46
|
-
alias[i] = i
|
|
47
|
-
nodes = [n for n, _ in probs]
|
|
48
|
-
return nodes, prob, alias
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
def alias_draw(nodes, prob, alias):
|
|
52
|
-
if not nodes:
|
|
53
|
-
return None
|
|
54
|
-
i = random.randrange(len(nodes))
|
|
55
|
-
return nodes[i] if random.random() < prob[i] else nodes[alias[i]]
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
class PushPPREngine:
|
|
59
|
-
def __init__(self, accessor: GraphAccessor, community_id: str):
|
|
60
|
-
self.A = accessor
|
|
61
|
-
self.cid = community_id
|
|
62
|
-
|
|
63
|
-
def run(self, seeds: List[NodeId], params: PPRParams, budget: Optional[SearchBudget] = None) -> PPRResult:
|
|
64
|
-
bt = BudgetTracker(budget or SearchBudget())
|
|
65
|
-
p: Dict[NodeId, float] = defaultdict(float)
|
|
66
|
-
r: Dict[NodeId, float] = defaultdict(float)
|
|
67
|
-
q: deque[NodeId] = deque()
|
|
68
|
-
|
|
69
|
-
seeds = seeds or []
|
|
70
|
-
if not seeds:
|
|
71
|
-
return PPRResult([], 0.0, bt.usage.__dict__, {"engine": "push", "iters": 0, "cache_hit": False})
|
|
72
|
-
|
|
73
|
-
for s in seeds:
|
|
74
|
-
r[s] += 1.0 / len(seeds)
|
|
75
|
-
q.append(s)
|
|
76
|
-
|
|
77
|
-
iters = 0
|
|
78
|
-
while q and not bt.over():
|
|
79
|
-
u = q.popleft()
|
|
80
|
-
iters += 1
|
|
81
|
-
ru = r[u]
|
|
82
|
-
if ru <= 0:
|
|
83
|
-
continue
|
|
84
|
-
p[u] += params.alpha * ru
|
|
85
|
-
residual = (1 - params.alpha) * ru
|
|
86
|
-
r[u] = 0.0
|
|
87
|
-
|
|
88
|
-
nbrs = list(self.A.iter_out(u))
|
|
89
|
-
deg = len(nbrs)
|
|
90
|
-
if deg == 0:
|
|
91
|
-
continue
|
|
92
|
-
share = residual / deg
|
|
93
|
-
for v, _, _ in nbrs:
|
|
94
|
-
r[v] += share
|
|
95
|
-
bt.tick_edges(1)
|
|
96
|
-
if r[v] / max(1, self.A.degree(v)) > params.eps:
|
|
97
|
-
q.append(v)
|
|
98
|
-
bt.tick_nodes(1)
|
|
99
|
-
if bt.timed_out():
|
|
100
|
-
break
|
|
101
|
-
|
|
102
|
-
items = sorted(p.items(), key=lambda kv: kv[1], reverse=True)[: params.topn]
|
|
103
|
-
mass = sum(p.values())
|
|
104
|
-
return PPRResult(scores=items, mass=mass, used_budget=bt.usage.__dict__, trace={"engine": "push", "iters": iters, "cache_hit": False})
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
class MonteCarloPPREngine:
|
|
108
|
-
def __init__(self, accessor: GraphAccessor, community_id: str, walk_index=None):
|
|
109
|
-
self.A = accessor
|
|
110
|
-
self.cid = community_id
|
|
111
|
-
self._alias_cache: Dict[NodeId, Tuple[List[NodeId], List[float], List[int]]] = {}
|
|
112
|
-
self.walk_index = walk_index
|
|
113
|
-
|
|
114
|
-
def _alias_for(self, u: NodeId):
|
|
115
|
-
if u in self._alias_cache:
|
|
116
|
-
return self._alias_cache[u]
|
|
117
|
-
nbrs = list(self.A.iter_out(u))
|
|
118
|
-
table = build_alias_table([(v, w) for v, _, w in nbrs])
|
|
119
|
-
self._alias_cache[u] = table
|
|
120
|
-
return table
|
|
121
|
-
|
|
122
|
-
def run(self, seeds: List[NodeId], params: PPRParams, budget: Optional[SearchBudget] = None) -> PPRResult:
|
|
123
|
-
bt = BudgetTracker(budget or SearchBudget())
|
|
124
|
-
if not seeds:
|
|
125
|
-
return PPRResult([], 0.0, bt.usage.__dict__, {"engine": "mc", "iters": 0, "cache_hit": False})
|
|
126
|
-
|
|
127
|
-
hits: Dict[NodeId, int] = defaultdict(int)
|
|
128
|
-
# Optional pre-hit sampling from walk index to save MC effort
|
|
129
|
-
if self.walk_index is not None:
|
|
130
|
-
for s in seeds:
|
|
131
|
-
for v, c in self.walk_index.sample_hits(s).items():
|
|
132
|
-
hits[v] += int(c)
|
|
133
|
-
for _ in range(params.num_walks):
|
|
134
|
-
if bt.over():
|
|
135
|
-
break
|
|
136
|
-
u = random.choice(seeds)
|
|
137
|
-
for _ in range(params.walk_len):
|
|
138
|
-
hits[u] += 1
|
|
139
|
-
bt.tick_nodes(1)
|
|
140
|
-
if random.random() < params.alpha:
|
|
141
|
-
u = random.choice(seeds)
|
|
142
|
-
continue
|
|
143
|
-
nodes, prob, alias = self._alias_for(u)
|
|
144
|
-
if not nodes:
|
|
145
|
-
u = random.choice(seeds)
|
|
146
|
-
continue
|
|
147
|
-
u = alias_draw(nodes, prob, alias)
|
|
148
|
-
bt.tick_edges(1)
|
|
149
|
-
if bt.timed_out():
|
|
150
|
-
break
|
|
151
|
-
|
|
152
|
-
total = float(sum(hits.values())) or 1.0
|
|
153
|
-
scores = sorted(((n, c / total) for n, c in hits.items()), key=lambda kv: kv[1], reverse=True)[: params.topn]
|
|
154
|
-
return PPRResult(scores=scores, mass=1.0, used_budget=bt.usage.__dict__, trace={"engine": "mc", "iters": params.num_walks, "cache_hit": False})
|
|
155
|
-
|
|
156
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Dict, List, Tuple, Optional
|
|
3
|
+
from collections import defaultdict, deque
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
import random
|
|
6
|
+
from .adapters import GraphAccessor, NodeId
|
|
7
|
+
from .budget import SearchBudget, BudgetTracker
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class PPRParams:
    """Shared parameter bundle for the PPR engines in this module."""

    # Teleport (restart) probability used by both push and Monte Carlo engines.
    alpha: float = 0.2
    # Push admission threshold: a node is re-queued only while
    # residual / degree exceeds eps (PushPPREngine only).
    eps: float = 1e-6
    # Total number of Monte Carlo walks (MonteCarloPPREngine only).
    num_walks: int = 2000
    # Maximum steps per Monte Carlo walk (MonteCarloPPREngine only).
    walk_len: int = 50
    # Number of top-scoring nodes retained in PPRResult.scores.
    topn: int = 200
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class PPRResult:
    """Outcome of a single PPR computation."""

    # (node, score) pairs sorted by descending score, truncated to params.topn.
    scores: List[Tuple[NodeId, float]]
    # Probability mass accounted for: push reports sum(p); MC always reports 1.0.
    mass: float
    # Snapshot of the BudgetTracker usage counters (its usage object's __dict__).
    used_budget: Dict[str, int]
    # Diagnostics, e.g. {"engine": "push"|"mc", "iters": ..., "cache_hit": ...}.
    trace: Dict[str, object]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def build_alias_table(weighted_neighbors: List[Tuple[NodeId, float]]):
    """Build a Walker/Vose alias table for O(1) weighted neighbor sampling.

    Args:
        weighted_neighbors: (node, weight) pairs; negative weights are
            treated as zero.

    Returns:
        ``(nodes, prob, alias)`` consumable by ``alias_draw``; all three are
        empty lists when *weighted_neighbors* is empty.

    Fix: the original clamped negative weights when computing the total but
    divided the *raw* weight in each node's probability, which could yield
    negative probability entries; weights are now clamped consistently.
    """
    if not weighted_neighbors:
        return [], [], []
    weights = [max(0.0, w) for _, w in weighted_neighbors]
    total = sum(weights) or 1.0  # guard against an all-zero weight vector
    probs = [(node, w / total) for (node, _), w in zip(weighted_neighbors, weights)]
    n = len(probs)
    scaled = [p * n for _, p in probs]
    alias, prob = [0] * n, [0.0] * n
    small, large = [], []
    for i, sp in enumerate(scaled):
        (small if sp < 1 else large).append(i)
    while small and large:
        s, l = small.pop(), large.pop()
        prob[s] = scaled[s]
        alias[s] = l
        # The large bucket donates the mass that tops up the small bucket.
        scaled[l] = scaled[l] - (1 - prob[s])
        (small if scaled[l] < 1 else large).append(l)
    for i in small + large:  # leftover buckets are (numerically) full
        prob[i] = 1.0
        alias[i] = i
    nodes = [node for node, _ in probs]
    return nodes, prob, alias
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def alias_draw(nodes, prob, alias):
    """Draw one node from an alias table produced by build_alias_table.

    Returns None for an empty table. Picks a bucket uniformly, then takes
    either the bucket's own node or its alias partner.
    """
    if not nodes:
        return None
    bucket = random.randrange(len(nodes))
    if random.random() < prob[bucket]:
        return nodes[bucket]
    return nodes[alias[bucket]]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class PushPPREngine:
    """Forward-push personalized PageRank over a GraphAccessor.

    Maintains a settled-mass vector ``p`` and a residual vector ``r`` and
    repeatedly pushes residual mass from queued nodes to their out-neighbors,
    stopping when the queue drains or the search budget is exhausted.
    """

    def __init__(self, accessor: GraphAccessor, community_id: str):
        # accessor supplies iter_out(u) and degree(v); community_id is stored
        # but not consulted by the computation below.
        self.A = accessor
        self.cid = community_id

    def run(self, seeds: List[NodeId], params: PPRParams, budget: Optional[SearchBudget] = None) -> PPRResult:
        """Run push-PPR from *seeds* and return the top ``params.topn`` scores.

        Each seed's residual starts at 1/len(seeds) (uniform teleport
        distribution over the seed set). Empty seeds yield an empty result.
        """
        bt = BudgetTracker(budget or SearchBudget())
        p: Dict[NodeId, float] = defaultdict(float)   # settled PPR mass per node
        r: Dict[NodeId, float] = defaultdict(float)   # residual (unpushed) mass
        q: deque[NodeId] = deque()

        seeds = seeds or []
        if not seeds:
            return PPRResult([], 0.0, bt.usage.__dict__, {"engine": "push", "iters": 0, "cache_hit": False})

        for s in seeds:
            r[s] += 1.0 / len(seeds)
            q.append(s)

        iters = 0
        while q and not bt.over():
            u = q.popleft()
            iters += 1
            ru = r[u]
            if ru <= 0:
                # Duplicates are allowed in the queue; this node was already
                # pushed since being enqueued, so there is nothing left to do.
                continue
            p[u] += params.alpha * ru           # settle the alpha fraction at u
            residual = (1 - params.alpha) * ru  # spread the remainder to neighbors
            r[u] = 0.0

            # iter_out yields (neighbor, <edge-meta>, weight) triples; the push
            # below splits residual uniformly, ignoring edge weights.
            nbrs = list(self.A.iter_out(u))
            deg = len(nbrs)
            if deg == 0:
                # Dangling node: the non-settled residual is dropped, so
                # sum(p) can end below 1 (surfaced to callers via `mass`).
                continue
            share = residual / deg
            for v, _, _ in nbrs:
                r[v] += share
                bt.tick_edges(1)
                # Re-queue v only while its degree-normalized residual exceeds
                # eps (classic local-push admission rule).
                if r[v] / max(1, self.A.degree(v)) > params.eps:
                    q.append(v)
                    bt.tick_nodes(1)
            if bt.timed_out():
                break

        items = sorted(p.items(), key=lambda kv: kv[1], reverse=True)[: params.topn]
        mass = sum(p.values())
        return PPRResult(scores=items, mass=mass, used_budget=bt.usage.__dict__, trace={"engine": "push", "iters": iters, "cache_hit": False})
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class MonteCarloPPREngine:
    """Monte Carlo personalized PageRank via alpha-terminated random walks.

    Scores are normalized visit frequencies over all walk steps, optionally
    warm-started with precomputed hit counts from a walk index.
    """

    def __init__(self, accessor: GraphAccessor, community_id: str, walk_index=None):
        self.A = accessor
        self.cid = community_id
        # Per-node alias tables memoized for the engine's lifetime.
        # NOTE(review): this cache is unbounded and is not invalidated when
        # the underlying graph changes — confirm engines are short-lived.
        self._alias_cache: Dict[NodeId, Tuple[List[NodeId], List[float], List[int]]] = {}
        self.walk_index = walk_index

    def _alias_for(self, u: NodeId):
        """Return (building and caching on first use) the alias table for *u*."""
        if u in self._alias_cache:
            return self._alias_cache[u]
        nbrs = list(self.A.iter_out(u))
        # iter_out yields (neighbor, <edge-meta>, weight) triples; only the
        # neighbor and its weight feed the alias table.
        table = build_alias_table([(v, w) for v, _, w in nbrs])
        self._alias_cache[u] = table
        return table

    def run(self, seeds: List[NodeId], params: PPRParams, budget: Optional[SearchBudget] = None) -> PPRResult:
        """Estimate PPR for *seeds*; returns the top ``params.topn`` nodes.

        Every step of every walk records a hit at the current node; both
        teleport (probability alpha) and dead ends restart at a random seed.
        """
        bt = BudgetTracker(budget or SearchBudget())
        if not seeds:
            return PPRResult([], 0.0, bt.usage.__dict__, {"engine": "mc", "iters": 0, "cache_hit": False})

        hits: Dict[NodeId, int] = defaultdict(int)
        # Optional pre-hit sampling from walk index to save MC effort
        if self.walk_index is not None:
            for s in seeds:
                for v, c in self.walk_index.sample_hits(s).items():
                    hits[v] += int(c)
        for _ in range(params.num_walks):
            if bt.over():
                break
            u = random.choice(seeds)
            for _ in range(params.walk_len):
                hits[u] += 1
                bt.tick_nodes(1)
                if random.random() < params.alpha:
                    u = random.choice(seeds)  # teleport: restart at a seed
                    continue
                nodes, prob, alias = self._alias_for(u)
                if not nodes:
                    u = random.choice(seeds)  # dead end: restart at a seed
                    continue
                u = alias_draw(nodes, prob, alias)
                bt.tick_edges(1)
            if bt.timed_out():
                break

        # Normalize visit counts into frequencies (guarding the all-empty case).
        total = float(sum(hits.values())) or 1.0
        scores = sorted(((n, c / total) for n, c in hits.items()), key=lambda kv: kv[1], reverse=True)[: params.topn]
        return PPRResult(scores=scores, mass=1.0, used_budget=bt.usage.__dict__, trace={"engine": "mc", "iters": params.num_walks, "cache_hit": False})
|
|
155
|
+
|
|
156
|
+
|
retrieval/ppr_cache.py
CHANGED
|
@@ -1,25 +1,25 @@
|
|
|
1
|
-
import json
|
|
2
|
-
import hashlib
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
def _key(community_id: str, seeds: list[str], alpha: float, engine: str, **kwargs) -> str:
|
|
6
|
-
seed_hash = hashlib.md5(json.dumps(sorted(seeds)).encode(), usedforsecurity=False).hexdigest()
|
|
7
|
-
prior_hash = kwargs.get('prior_hash', '')
|
|
8
|
-
return f"{community_id}:{engine}:{alpha:.4f}:{seed_hash}:{prior_hash}"
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class PPRCache:
|
|
12
|
-
def __init__(self, capacity: int = 256):
|
|
13
|
-
self.capacity = capacity
|
|
14
|
-
self._cache: dict[str, object] = {}
|
|
15
|
-
|
|
16
|
-
def get(self, key: str):
|
|
17
|
-
return self._cache.get(key)
|
|
18
|
-
|
|
19
|
-
def put(self, key: str, value: object):
|
|
20
|
-
if len(self._cache) >= self.capacity:
|
|
21
|
-
# Simple FIFO eviction
|
|
22
|
-
self._cache.pop(next(iter(self._cache)))
|
|
23
|
-
self._cache[key] = value
|
|
24
|
-
|
|
25
|
-
|
|
1
|
+
import json
|
|
2
|
+
import hashlib
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def _key(community_id: str, seeds: list[str], alpha: float, engine: str, **kwargs) -> str:
    """Build a deterministic cache key for a PPR request.

    Seeds are sorted first so seed order never changes the key; MD5 serves
    purely as a fingerprint (usedforsecurity=False), not for security.
    """
    canonical_seeds = json.dumps(sorted(seeds))
    seed_hash = hashlib.md5(canonical_seeds.encode(), usedforsecurity=False).hexdigest()
    prior_hash = kwargs.get('prior_hash', '')
    return ":".join((community_id, engine, f"{alpha:.4f}", seed_hash, prior_hash))
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PPRCache:
    """Bounded in-memory cache for PPR results with FIFO eviction.

    Relies on dict insertion order (Python 3.7+): the first key in the dict
    is always the oldest insertion.
    """

    def __init__(self, capacity: int = 256):
        # capacity <= 0 disables caching entirely (put() becomes a no-op).
        self.capacity = capacity
        self._cache: dict[str, object] = {}

    def get(self, key: str):
        """Return the cached value for *key*, or None when absent."""
        return self._cache.get(key)

    def put(self, key: str, value: object):
        """Insert or overwrite *key*, evicting the oldest entry when full.

        Fixes two defects in the original:
        - capacity <= 0 raised StopIteration (``next(iter({}))`` on the
          empty dict); it now simply disables the cache.
        - overwriting an existing key while at capacity needlessly evicted
          the oldest entry even though the cache size would not grow.
        """
        if self.capacity <= 0:
            return
        if key not in self._cache and len(self._cache) >= self.capacity:
            # Simple FIFO eviction: drop the oldest inserted key.
            self._cache.pop(next(iter(self._cache)))
        self._cache[key] = value
|
|
24
|
+
|
|
25
|
+
|