odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
retrieval/ppr/engines.py CHANGED
@@ -1,257 +1,257 @@
The hunk deletes and re-adds the entire 257-line file; the removed and added text are identical as rendered here, so the file body is shown once.

from __future__ import annotations
from typing import Dict, List, Tuple, Optional
from collections import defaultdict, deque
from dataclasses import dataclass
import random
from ..adapters import GraphAccessor, NodeId
from ..budget import SearchBudget, BudgetTracker


@dataclass
class PPRParams:
    alpha: float = 0.2
    eps: float = 1e-6
    num_walks: int = 2000
    walk_len: int = 50
    topn: int = 200


@dataclass
class PPRResult:
    scores: List[Tuple[NodeId, float]]
    mass: float
    used_budget: Dict[str, int]
    trace: Dict[str, object]


def build_alias_table(weighted_neighbors: List[Tuple[NodeId, float]]):
    if not weighted_neighbors:
        return [], [], []
    total = sum(max(0.0, w) for _, w in weighted_neighbors) or 1.0
    probs = [(n, w / total) for n, w in weighted_neighbors]
    n = len(probs)
    scaled = [p * n for _, p in probs]
    alias, prob = [0] * n, [0.0] * n
    small, large = [], []
    for i, sp in enumerate(scaled):
        (small if sp < 1 else large).append(i)
    while small and large:
        s, l = small.pop(), large.pop()
        prob[s] = scaled[s]
        alias[s] = l
        scaled[l] = scaled[l] - (1 - prob[s])
        (small if scaled[l] < 1 else large).append(l)
    for i in small + large:
        prob[i] = 1.0
        alias[i] = i
    nodes = [n for n, _ in probs]
    return nodes, prob, alias


def alias_draw(nodes, prob, alias):
    if not nodes:
        return None
    i = random.randrange(len(nodes))
    return nodes[i] if random.random() < prob[i] else nodes[alias[i]]
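
build_alias_table and alias_draw implement Walker's alias method, which preprocesses a weighted neighbor list into two tables supporting O(1) weighted draws. A minimal sanity check, as a sketch assuming the wheel is installed so the two functions above are importable, compares empirical draw frequencies against the normalized weights:

import random
from collections import Counter

from retrieval.ppr.engines import build_alias_table, alias_draw

weighted = [("a", 1.0), ("b", 3.0), ("c", 6.0)]  # normalizes to 0.1 / 0.3 / 0.6
nodes, prob, alias = build_alias_table(weighted)

random.seed(0)
N = 100_000
counts = Counter(alias_draw(nodes, prob, alias) for _ in range(N))
total = sum(w for _, w in weighted)
for node, w in weighted:
    print(node, w / total, counts[node] / N)  # expected vs. empirical frequency
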
class PushPPREngine:
    def __init__(self, accessor: GraphAccessor, community_id: str):
        self.A = accessor
        self.cid = community_id

    def run(self, seeds: List[NodeId], params: PPRParams, budget: Optional[SearchBudget] = None, personalization: Optional[Dict[NodeId, float]] = None) -> PPRResult:
        bt = BudgetTracker(budget or SearchBudget())
        p: Dict[NodeId, float] = defaultdict(float)
        r: Dict[NodeId, float] = defaultdict(float)
        q: deque[NodeId] = deque()

        if personalization is None:
            # Default behavior if no personalization is provided (uniform over seeds)
            seeds = seeds or []
            if not seeds:
                return PPRResult([], 0.0, bt.usage.__dict__, {"engine": "push", "iters": 0, "cache_hit": False})
            init_mass = 1.0 / len(seeds)
            for s in seeds:
                r[s] += init_mass
                q.append(s)
        else:
            # Use provided personalization vector
            if not personalization:
                return PPRResult([], 0.0, bt.usage.__dict__, {"engine": "push", "iters": 0, "cache_hit": False})
            for s, mass in personalization.items():
                r[s] += mass
                q.append(s)

        early_stop_reason = None
        iters = 0
        while q and not bt.over():
            u = q.popleft()
            iters += 1
            ru = r[u]
            if ru <= 0:
                continue
            p[u] += params.alpha * ru
            residual = (1 - params.alpha) * ru
            r[u] = 0.0

            nbrs = list(self.A.iter_out(u))
            deg = len(nbrs)
            if deg == 0:
                continue
            share = residual / deg
            for v, _, _ in nbrs:
                r[v] += share
                bt.tick_edges(1)
                if r[v] / max(1, self.A.degree(v)) > params.eps:
                    q.append(v)
                    bt.tick_nodes(1)
            if bt.timed_out():
                early_stop_reason = "timeout"
                break

        if not early_stop_reason and bt.over():
            early_stop_reason = "budget_exhausted"

        items = sorted(p.items(), key=lambda kv: kv[1], reverse=True)[: params.topn]
        mass = sum(p.values())
        return PPRResult(scores=items, mass=mass, used_budget=bt.usage.__dict__, trace={"engine": "push", "iters": iters, "cache_hit": False, "early_stop_reason": early_stop_reason})
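
The push engine only touches the accessor through iter_out(u), yielding (neighbor, relation, weight) triples, and degree(v), so a duck-typed toy graph is enough to exercise it. A hypothetical driver sketch; ToyGraph is illustrative, not part of the package, and it assumes the packaged SearchBudget defaults are usable as-is:

from retrieval.ppr.engines import PushPPREngine, PPRParams

class ToyGraph:
    """Illustrative accessor: adjacency as node -> [(neighbor, relation, weight), ...]."""
    def __init__(self, out):
        self.out = out
    def iter_out(self, u):
        return iter(self.out.get(u, []))
    def degree(self, v):
        return len(self.out.get(v, []))

g = ToyGraph({
    "a": [("b", "r", 1.0), ("c", "r", 1.0)],
    "b": [("c", "r", 1.0)],
    "c": [("a", "r", 1.0)],
})
res = PushPPREngine(g, community_id="toy").run(
    seeds=["a"], params=PPRParams(alpha=0.2, eps=1e-8, topn=10)
)
print(res.scores)  # [(node, settled PPR mass), ...] sorted descending
print(res.mass)    # total settled mass; approaches 1.0 as eps shrinks
print(res.trace)   # {"engine": "push", "iters": ..., "early_stop_reason": ...}
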
class MonteCarloPPREngine:
    def __init__(self, accessor: GraphAccessor, community_id: str, walk_index=None):
        self.A = accessor
        self.cid = community_id
        self._alias_cache: Dict[NodeId, Tuple[List[NodeId], List[float], List[int]]] = {}
        self.walk_index = walk_index

    def _alias_for(self, u: NodeId):
        if u in self._alias_cache:
            return self._alias_cache[u]
        nbrs = list(self.A.iter_out(u))
        table = build_alias_table([(v, w) for v, _, w in nbrs])
        self._alias_cache[u] = table
        return table

    def run(self, seeds: List[NodeId], params: PPRParams, budget: Optional[SearchBudget] = None, personalization: Optional[Dict[NodeId, float]] = None) -> PPRResult:
        bt = BudgetTracker(budget or SearchBudget())
        if not seeds and (personalization is None or not personalization):
            return PPRResult([], 0.0, bt.usage.__dict__, {"engine": "mc", "iters": 0, "cache_hit": False})

        hits: Dict[NodeId, int] = defaultdict(int)
        # Optional pre-hit sampling from walk index to save MC effort
        if self.walk_index is not None:
            for s in seeds:
                for v, c in self.walk_index.sample_hits(s).items():
                    hits[v] += int(c)

        # Prepare for weighted random choice if personalization is provided
        personalization_nodes = list(personalization.keys()) if personalization else seeds
        personalization_weights = list(personalization.values()) if personalization else [1.0] * len(seeds)
        # Normalize weights for random.choices if provided
        total_personalization_weight = sum(personalization_weights)
        if total_personalization_weight == 0:
            # Fall back to uniform if all weights are zero
            personalization_weights = [1.0] * len(personalization_nodes)
            total_personalization_weight = float(len(personalization_nodes))
        normalized_personalization_weights = [w / total_personalization_weight for w in personalization_weights]

        early_stop_reason = None
        for _ in range(params.num_walks):
            if bt.over():
                early_stop_reason = "budget_exhausted"
                break

            # Start walk from personalized distribution if available, else uniform from seeds
            if personalization:
                u = random.choices(personalization_nodes, weights=normalized_personalization_weights, k=1)[0]
            else:
                u = random.choice(seeds)

            for _ in range(params.walk_len):
                if bt.over():  # Check budget before each step
                    early_stop_reason = "budget_exhausted"
                    break
                hits[u] += 1
                bt.tick_nodes(1)
                if random.random() < params.alpha:
                    # Teleport back to personalized distribution if available, else uniform from seeds
                    if personalization:
                        u = random.choices(personalization_nodes, weights=normalized_personalization_weights, k=1)[0]
                    else:
                        u = random.choice(seeds)
                    continue
                nodes, prob, alias = self._alias_for(u)
                if not nodes:
                    # If dangling, teleport back to personalized distribution if available, else uniform from seeds
                    if personalization:
                        u = random.choices(personalization_nodes, weights=normalized_personalization_weights, k=1)[0]
                    else:
                        u = random.choice(seeds)
                    break
                u = alias_draw(nodes, prob, alias)
                bt.tick_edges(1)
            if bt.timed_out():
                early_stop_reason = "timeout"
                break

        total = float(sum(hits.values()) or 1.0)
        scores = sorted(((n, c / total) for n, c in hits.items()), key=lambda kv: kv[1], reverse=True)[: params.topn]
        return PPRResult(scores=scores, mass=1.0, used_budget=bt.usage.__dict__, trace={"engine": "mc", "iters": params.num_walks, "cache_hit": False, "early_stop_reason": early_stop_reason})
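
Personalization only changes the restart behavior here: walks start and teleport according to the normalized weight vector instead of uniformly over the seed set. A sketch reusing the hypothetical ToyGraph above:

from retrieval.ppr.engines import MonteCarloPPREngine, PPRParams

mc = MonteCarloPPREngine(g, community_id="toy")
res = mc.run(
    seeds=["a"],  # fallback start set; unused when personalization is given
    params=PPRParams(alpha=0.2, num_walks=5000, walk_len=50, topn=10),
    personalization={"a": 0.8, "b": 0.2},
)
print(res.scores[:3])  # visit frequencies, normalized to sum to 1 over visited nodes
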
class BiPPREngine:
    """
    Bidirectional PPR for source→target scoring:
    - Reverse push from targets builds fingerprints
    - Forward random walks from source intersect with reverse mass
    """

    def __init__(self, graph: GraphAccessor, alpha=0.15, rmax=1e-4):
        self.G = graph
        self.alpha, self.rmax = alpha, rmax

    def _reverse_push(self, targets: List[NodeId]) -> Dict[NodeId, float]:
        p_t, r_t = defaultdict(float), defaultdict(float)
        q = deque()
        tgt = list(targets)
        if not tgt:
            return p_t
        mass = 1.0 / len(tgt)
        for t in tgt:
            r_t[t] = mass
            q.append(t)
        while q:
            u = q.popleft()
            if r_t[u] / max(1, self.G.in_degree(u)) <= self.rmax:
                continue
            push = (1 - self.alpha) * r_t[u]
            p_t[u] += self.alpha * r_t[u]
            r_t[u] = 0.0
            deg = self.G.in_degree(u)
            if deg == 0:
                continue
            share = push / deg
            for v in self.G.in_neighbors(u):
                r_t[v] += share
                if r_t[v] / max(1, self.G.in_degree(v)) > self.rmax:
                    q.append(v)
        return p_t

    def score(self, source: NodeId, targets: List[NodeId], walks=5000) -> List[Tuple[NodeId, float]]:
        fp = self._reverse_push(targets)
        hits = defaultdict(int)
        for _ in range(walks):
            u = source
            while True:
                hits[u] += 1
                if random.random() < self.alpha:
                    break
                nbrs = list(self.G.out_neighbors(u))
                if not nbrs:
                    break
                u = random.choice(nbrs)
        Z = float(sum(hits.values()) or 1)
        return sorted(((t, fp.get(t, 0.0) * hits.get(t, 0) / Z) for t in targets), key=lambda kv: kv[1], reverse=True)
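
BiPPREngine needs reverse adjacency (in_degree, in_neighbors) in addition to out_neighbors, so a toy graph has to index both directions. A sketch with a hypothetical bidirectional stub (ToyBiGraph is illustrative, not a package class):

from collections import defaultdict
from retrieval.ppr.engines import BiPPREngine

class ToyBiGraph:
    def __init__(self, edges):  # edges: [(u, v), ...]
        self.outs, self.ins = defaultdict(list), defaultdict(list)
        for u, v in edges:
            self.outs[u].append(v)
            self.ins[v].append(u)
    def out_neighbors(self, u):
        return self.outs[u]
    def in_neighbors(self, u):
        return self.ins[u]
    def in_degree(self, u):
        return len(self.ins[u])

bg = ToyBiGraph([("s", "a"), ("a", "t"), ("s", "b"), ("b", "t"), ("t", "s")])
bi = BiPPREngine(bg, alpha=0.15, rmax=1e-4)
print(bi.score("s", targets=["t", "a"], walks=2000))  # [(target, score), ...] descending
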
retrieval/ppr/global_pr.py CHANGED
@@ -1,76 +1,76 @@
As with engines.py, the removed and added text are identical as rendered; the 76-line file body is shown once.

from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Iterable

from ..adapters import GraphAccessor, NodeId


@dataclass
class GlobalPRParams:
    alpha: float = 0.15
    tol: float = 1e-8
    max_iter: int = 100


class GlobalPR:
    """
    Power-iteration PageRank over nodes visible to the GraphAccessor in a given community.
    Handles dangling nodes by redistributing mass to personalization (uniform over nodes).
    """

    def __init__(self, accessor: GraphAccessor, community_id: str):
        self.A = accessor
        self.cid = community_id
        self.pr: Dict[NodeId, float] = {}

    def fit(self, params: GlobalPRParams = GlobalPRParams()) -> Dict[NodeId, float]:
        nodes = list(self.A.nodes(self.cid))
        if not nodes:
            self.pr = {}
            return self.pr
        n = len(nodes)
        idx: Dict[NodeId, int] = {u: i for i, u in enumerate(nodes)}
        pr_prev = [1.0 / n] * n
        alpha = params.alpha
        teleport = 1.0 / n

        # Precompute out-neighbor indices
        out_idx: Dict[int, list[int]] = {}
        for u in nodes:
            ui = idx[u]
            nbrs = [idx[v] for v, _, _ in self.A.iter_out(u)]
            out_idx[ui] = nbrs

        for _ in range(params.max_iter):
            pr = [0.0] * n
            dangling_mass = 0.0
            for ui in range(n):
                nbrs = out_idx[ui]
                if not nbrs:
                    dangling_mass += (1.0 - alpha) * pr_prev[ui]
                    pr[ui] += alpha * pr_prev[ui]
                    continue
                share = (1.0 - alpha) * pr_prev[ui] / len(nbrs)
                pr[ui] += alpha * pr_prev[ui]
                for vj in nbrs:
                    pr[vj] += share
            # Redistribute dangling mass to the teleport set (uniform personalization)
            if dangling_mass > 0:
                add = dangling_mass * teleport
                pr = [x + add for x in pr]

            # Teleportation to uniform as well
            pr = [alpha * teleport + (1.0 - alpha) * (x - alpha * teleport) for x in pr]

            # Normalize and check convergence (L1)
            s = sum(pr) or 1.0
            pr = [x / s for x in pr]
            diff = sum(abs(pr[i] - pr_prev[i]) for i in range(n))
            pr_prev = pr
            if diff < params.tol:
                break

        self.pr = {u: pr_prev[idx[u]] for u in nodes}
        return self.pr
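
GlobalPR only calls nodes(community_id) and iter_out(u), so the same duck-typing trick applies. A sketch on a three-node community with one dangling node (ToyCommunity is hypothetical, assuming the wheel is installed):

from retrieval.ppr.global_pr import GlobalPR, GlobalPRParams

class ToyCommunity:
    def __init__(self, out):
        self.out = out  # node -> [(neighbor, relation, weight), ...]
    def nodes(self, community_id):
        return list(self.out.keys())
    def iter_out(self, u):
        return iter(self.out.get(u, []))

g = ToyCommunity({
    "a": [("b", "r", 1.0)],
    "b": [("a", "r", 1.0), ("c", "r", 1.0)],
    "c": [],  # dangling: its mass is redistributed uniformly
})
pr = GlobalPR(g, community_id="toy").fit(GlobalPRParams(alpha=0.15, tol=1e-10))
print(sorted(pr.items(), key=lambda kv: kv[1], reverse=True))
print(sum(pr.values()))  # ≈ 1.0 after per-iteration normalization
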