odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/utils/batch_utils.py CHANGED
@@ -1,493 +1,493 @@
The hunk spans the whole file, but the 0.1.0 and 0.2.0 bodies it shows are line-for-line identical (a whole-file rewrite with no content change), so the shared content appears once below.
"""
Batch processing utilities for NPLL ground rules
Handles efficient batching and sampling of ground rules for MLN computations
"""

import torch
import numpy as np
from typing import List, Dict, Set, Tuple, Optional, Iterator, Any
from collections import defaultdict
import random
import logging
from dataclasses import dataclass

from ..core import Triple, LogicalRule, GroundRule, KnowledgeGraph
from ..utils.config import NPLLConfig

logger = logging.getLogger(__name__)


@dataclass
class GroundRuleBatch:
    """
    Batch of ground rules for efficient processing

    Contains ground rules and associated metadata for batch operations
    """
    ground_rules: List[GroundRule]
    rule_indices: torch.Tensor  # Which logical rule each ground rule belongs to
    fact_indices: Dict[Triple, int]  # Mapping from facts to batch indices
    batch_facts: List[Triple]  # All unique facts in this batch
    batch_size: int

    def __post_init__(self):
        """Validate batch consistency"""
        assert len(self.ground_rules) == self.batch_size, \
            f"Inconsistent batch size: {len(self.ground_rules)} vs {self.batch_size}"

        assert len(self.rule_indices) == self.batch_size, \
            f"Rule indices length mismatch: {len(self.rule_indices)} vs {self.batch_size}"

    def get_fact_truth_matrix(self, fact_assignments: Dict[Triple, bool]) -> torch.Tensor:
        """
        Create truth value matrix for facts in this batch

        Returns:
            Tensor of shape [batch_size, max_facts_per_rule] with truth values
        """
        max_facts = max(len(gr.get_all_facts()) for gr in self.ground_rules) if self.ground_rules else 0

        if max_facts == 0:
            return torch.zeros(self.batch_size, 0, dtype=torch.bool)

        truth_matrix = torch.zeros(self.batch_size, max_facts, dtype=torch.bool)

        for i, ground_rule in enumerate(self.ground_rules):
            facts = ground_rule.get_all_facts()
            for j, fact in enumerate(facts):
                if j < max_facts:
                    truth_matrix[i, j] = fact_assignments.get(fact, False)

        return truth_matrix

    def evaluate_ground_rules(self, fact_assignments: Dict[Triple, bool]) -> torch.Tensor:
        """
        Evaluate all ground rules in batch

        Returns:
            Boolean tensor indicating which ground rules are satisfied
        """
        satisfaction = torch.zeros(self.batch_size, dtype=torch.bool)

        for i, ground_rule in enumerate(self.ground_rules):
            # Check body satisfaction
            body_satisfied = all(
                fact_assignments.get(fact, False)
                for fact in ground_rule.body_facts
            )

            if not body_satisfied:
                # Body false -> rule vacuously true
                satisfaction[i] = True
            else:
                # Body true -> check head
                head_satisfied = fact_assignments.get(ground_rule.head_fact, False)
                satisfaction[i] = head_satisfied

        return satisfaction
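For orientation, a minimal sketch of the satisfaction semantics implemented by evaluate_ground_rules above (body false -> vacuously true; body true -> the head decides). StubGroundRule is a hypothetical stand-in that mimics only the attributes the batch touches; the real GroundRule type lives in npll.core:

import torch
from dataclasses import dataclass
from typing import Tuple

Fact = Tuple[str, str, str]  # illustrative stand-in for npll.core.Triple

@dataclass(frozen=True)
class StubGroundRule:
    # Only what GroundRuleBatch reads: body_facts, head_fact, get_all_facts()
    body_facts: Tuple[Fact, ...]
    head_fact: Fact

    def get_all_facts(self):
        return list(self.body_facts) + [self.head_fact]

gr = StubGroundRule(
    body_facts=(("A", "r1", "B"), ("B", "r2", "C")),
    head_fact=("A", "r3", "C"),
)
batch = GroundRuleBatch(
    ground_rules=[gr],
    rule_indices=torch.tensor([0]),
    fact_indices={},
    batch_facts=[],
    batch_size=1,
)

# Body true, head false -> violated; body false -> vacuously satisfied.
violated = batch.evaluate_ground_rules(
    {("A", "r1", "B"): True, ("B", "r2", "C"): True, ("A", "r3", "C"): False}
)
vacuous = batch.evaluate_ground_rules({("A", "r1", "B"): False})
assert violated.tolist() == [False] and vacuous.tolist() == [True]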
class GroundRuleSampler:
    """
    Samples ground rules for efficient MLN training and inference

    Paper Section 4.2: "this paper randomly samples batches of ground rules to form datasets,
    wherein the ground rules are approximately independent of each batch"
    """

    def __init__(self, config: NPLLConfig, random_seed: Optional[int] = None):
        self.config = config
        self.batch_size = config.batch_size
        self.max_ground_rules = config.max_ground_rules

        if random_seed is not None:
            random.seed(random_seed)
            np.random.seed(random_seed)
            torch.manual_seed(random_seed)

    def sample_ground_rules(self, all_ground_rules: List[GroundRule],
                            num_batches: int = 1,
                            sampling_strategy: str = "uniform") -> List[GroundRuleBatch]:
        """
        Sample batches of ground rules

        Args:
            all_ground_rules: All available ground rules
            num_batches: Number of batches to create
            sampling_strategy: 'uniform', 'weighted', or 'stratified'

        Returns:
            List of GroundRuleBatch objects
        """
        if not all_ground_rules:
            return []

        total_rules = len(all_ground_rules)
        # Guard with max(1, ...) so num_batches > total_rules still yields
        # non-empty samples instead of silently producing zero batches
        if num_batches > 1:
            rules_per_batch = max(1, min(self.batch_size, total_rules // num_batches))
        else:
            rules_per_batch = min(self.batch_size, total_rules)

        batches = []

        for batch_idx in range(num_batches):
            if sampling_strategy == "weighted":
                sampled_rules = self._weighted_sampling(all_ground_rules, rules_per_batch)
            elif sampling_strategy == "stratified":
                sampled_rules = self._stratified_sampling(all_ground_rules, rules_per_batch)
            else:
                # "uniform" and any unrecognized strategy fall back to uniform sampling
                sampled_rules = self._uniform_sampling(all_ground_rules, rules_per_batch)

            if sampled_rules:
                batch = self._create_batch_from_rules(sampled_rules)
                batches.append(batch)

        logger.debug(f"Created {len(batches)} ground rule batches with target size {rules_per_batch}")
        return batches

    def _uniform_sampling(self, ground_rules: List[GroundRule],
                          sample_size: int) -> List[GroundRule]:
        """Uniform random sampling of ground rules"""
        if sample_size >= len(ground_rules):
            return ground_rules.copy()

        return random.sample(ground_rules, sample_size)

    def _weighted_sampling(self, ground_rules: List[GroundRule],
                           sample_size: int) -> List[GroundRule]:
        """
        Weighted sampling based on rule confidence/support
        Higher confidence rules are more likely to be sampled
        """
        if sample_size >= len(ground_rules):
            return ground_rules.copy()

        # Use parent rule confidence as weight
        weights = [gr.parent_rule.confidence for gr in ground_rules]

        # Normalize weights
        total_weight = sum(weights)
        if total_weight > 0:
            weights = [w / total_weight for w in weights]
        else:
            weights = [1.0 / len(weights)] * len(weights)

        # Sample without replacement; the guard above ensures
        # sample_size < len(ground_rules)
        sampled_indices = np.random.choice(
            len(ground_rules),
            size=sample_size,
            p=weights,
            replace=False
        )

        return [ground_rules[i] for i in sampled_indices]

    def _stratified_sampling(self, ground_rules: List[GroundRule],
                             sample_size: int) -> List[GroundRule]:
        """
        Stratified sampling ensuring representation from different rule types
        """
        if sample_size >= len(ground_rules):
            return ground_rules.copy()

        # Group by parent rule type
        rule_type_groups = defaultdict(list)
        for gr in ground_rules:
            rule_type_groups[gr.parent_rule.rule_type].append(gr)

        # Sample proportionally from each group
        sampled_rules = []
        remaining_samples = sample_size

        for rule_type, type_rules in rule_type_groups.items():
            # Proportional allocation
            group_sample_size = min(
                len(type_rules),
                max(1, int(remaining_samples * len(type_rules) / len(ground_rules)))
            )

            if group_sample_size > 0:
                group_sample = random.sample(type_rules, group_sample_size)
                sampled_rules.extend(group_sample)
                remaining_samples -= group_sample_size

        # If we still need more samples, fill randomly from the leftovers
        if remaining_samples > 0 and len(sampled_rules) < sample_size:
            remaining_rules = [gr for gr in ground_rules if gr not in sampled_rules]
            if remaining_rules:
                additional_samples = min(remaining_samples, len(remaining_rules))
                additional_rules = random.sample(remaining_rules, additional_samples)
                sampled_rules.extend(additional_rules)

        return sampled_rules[:sample_size]

    def _create_batch_from_rules(self, ground_rules: List[GroundRule]) -> GroundRuleBatch:
        """Create GroundRuleBatch from list of ground rules"""
        if not ground_rules:
            return GroundRuleBatch(
                ground_rules=[],
                rule_indices=torch.tensor([]),
                fact_indices={},
                batch_facts=[],
                batch_size=0
            )

        # Rule indices are local to this batch, assigned in first-seen order;
        # a global rule-id-to-index mapping would have to be supplied by the MLN
        rule_indices = []
        unique_facts = set()

        rule_id_to_idx = {}
        for gr in ground_rules:
            if gr.parent_rule.rule_id not in rule_id_to_idx:
                rule_id_to_idx[gr.parent_rule.rule_id] = len(rule_id_to_idx)

            rule_indices.append(rule_id_to_idx[gr.parent_rule.rule_id])

            # Collect all unique facts
            unique_facts.update(gr.get_all_facts())

        # Create fact indexing
        batch_facts = list(unique_facts)
        fact_indices = {fact: i for i, fact in enumerate(batch_facts)}

        return GroundRuleBatch(
            ground_rules=ground_rules,
            rule_indices=torch.tensor(rule_indices, dtype=torch.long),
            fact_indices=fact_indices,
            batch_facts=batch_facts,
            batch_size=len(ground_rules)
        )

    def create_batches_for_training(self, ground_rules: List[GroundRule],
                                    shuffle: bool = True) -> List[GroundRuleBatch]:
        """
        Create batches specifically for training

        Args:
            ground_rules: All ground rules to batch
            shuffle: Whether to shuffle before batching

        Returns:
            List of training batches
        """
        if not ground_rules:
            return []

        # Shuffle if requested
        rules_to_batch = ground_rules.copy()
        if shuffle:
            random.shuffle(rules_to_batch)

        # Create sequential batches
        batches = []
        for i in range(0, len(rules_to_batch), self.batch_size):
            batch_rules = rules_to_batch[i:i + self.batch_size]
            batch = self._create_batch_from_rules(batch_rules)
            batches.append(batch)

        return batches
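A hedged usage sketch for the three strategies. SimpleNamespace stands in for NPLLConfig (only batch_size and max_ground_rules are read by __init__), and StubRule/StubGrounding are hypothetical objects exposing just the parent_rule fields the samplers consult (rule_id, confidence, rule_type):

from dataclasses import dataclass
from types import SimpleNamespace
from typing import Tuple

Fact = Tuple[str, str, str]

@dataclass(frozen=True)
class StubRule:
    rule_id: str
    confidence: float
    rule_type: str

@dataclass(frozen=True)
class StubGrounding:
    parent_rule: StubRule
    body_facts: Tuple[Fact, ...]
    head_fact: Fact

    def get_all_facts(self):
        return list(self.body_facts) + [self.head_fact]

rules = [StubRule("r-trans", 0.9, "transitivity"),
         StubRule("r-sym", 0.4, "symmetry")]
groundings = [
    StubGrounding(rules[i % 2],
                  ((f"e{i}", "r1", f"e{i + 1}"),),
                  (f"e{i}", "r3", f"e{i + 1}"))
    for i in range(20)
]

cfg = SimpleNamespace(batch_size=4, max_ground_rules=1000)  # NPLLConfig stand-in
sampler = GroundRuleSampler(cfg, random_seed=0)
for strategy in ("uniform", "weighted", "stratified"):
    batches = sampler.sample_ground_rules(groundings, num_batches=2,
                                          sampling_strategy=strategy)
    assert all(b.batch_size <= cfg.batch_size for b in batches)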
class FactBatchProcessor:
    """
    Processes facts in batches for efficient scoring and probability computation
    """

    def __init__(self, config: NPLLConfig):
        self.config = config
        self.batch_size = config.batch_size

    def create_fact_batches(self, facts: List[Triple],
                            batch_size: Optional[int] = None) -> List[List[Triple]]:
        """Create batches of facts for processing"""
        batch_size = batch_size or self.batch_size

        batches = []
        for i in range(0, len(facts), batch_size):
            batch = facts[i:i + batch_size]
            batches.append(batch)

        return batches

    def process_fact_batches(self, fact_batches: List[List[Triple]],
                             processor_func) -> List[Any]:
        """Process batches using provided function"""
        results = []

        for batch in fact_batches:
            batch_result = processor_func(batch)
            results.append(batch_result)

        return results
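A small sketch of the fact batcher; the SimpleNamespace config is again a stand-in, and len serves as a trivial processor_func:

from types import SimpleNamespace

cfg = SimpleNamespace(batch_size=3)
processor = FactBatchProcessor(cfg)

facts = [("A", "r1", "B"), ("B", "r2", "C"), ("A", "r3", "C"), ("C", "r1", "D")]
fact_batches = processor.create_fact_batches(facts)  # [[three facts], [one fact]]
# Any callable taking a batch works as processor_func; len just counts.
assert processor.process_fact_batches(fact_batches, len) == [3, 1]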
class MemoryEfficientBatcher:
    """
    Memory-efficient batching for large-scale ground rule processing
    Uses generators to avoid loading all data into memory
    """

    def __init__(self, config: NPLLConfig):
        self.config = config
        self.batch_size = config.batch_size

    def create_ground_rule_iterator(self, ground_rules: List[GroundRule],
                                    shuffle: bool = True) -> Iterator[GroundRuleBatch]:
        """
        Create iterator over ground rule batches for memory efficiency

        Args:
            ground_rules: All ground rules
            shuffle: Whether to shuffle order

        Yields:
            GroundRuleBatch objects
        """
        indices = list(range(len(ground_rules)))
        if shuffle:
            random.shuffle(indices)

        for i in range(0, len(indices), self.batch_size):
            batch_indices = indices[i:i + self.batch_size]
            batch_rules = [ground_rules[idx] for idx in batch_indices]

            yield self._create_efficient_batch(batch_rules)

    def _create_efficient_batch(self, ground_rules: List[GroundRule]) -> GroundRuleBatch:
        """Create batch with minimal memory overhead"""
        if not ground_rules:
            return GroundRuleBatch([], torch.tensor([]), {}, [], 0)

        # Efficient fact collection using sets
        all_facts = set()
        rule_indices = []

        # Single pass to collect facts and rule indices
        for gr in ground_rules:
            all_facts.update(gr.get_all_facts())
            # Bucket rules by hash of rule_id; cheap, but indices can collide,
            # so this is not a stable global rule index
            rule_indices.append(hash(gr.parent_rule.rule_id) % 1000)

        batch_facts = list(all_facts)
        fact_indices = {fact: i for i, fact in enumerate(batch_facts)}

        return GroundRuleBatch(
            ground_rules=ground_rules,
            rule_indices=torch.tensor(rule_indices, dtype=torch.long),
            fact_indices=fact_indices,
            batch_facts=batch_facts,
            batch_size=len(ground_rules)
        )
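A streaming sketch, reusing groundings and the SimpleNamespace stand-in from the sampler sketch above; only one GroundRuleBatch is materialized at a time, which is the point of the generator:

from types import SimpleNamespace

batcher = MemoryEfficientBatcher(SimpleNamespace(batch_size=4))
seen = 0
for batch in batcher.create_ground_rule_iterator(groundings, shuffle=False):
    seen += batch.batch_size  # process, then let the batch be garbage-collected
assert seen == len(groundings)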
class AdaptiveBatcher:
    """
    Adaptive batching that adjusts batch size based on memory usage and performance
    """

    def __init__(self, config: NPLLConfig, initial_batch_size: Optional[int] = None):
        self.config = config
        self.current_batch_size = initial_batch_size or config.batch_size
        self.min_batch_size = max(1, config.batch_size // 4)
        self.max_batch_size = config.batch_size * 2

        # Performance tracking
        self.performance_history = []
        self.memory_usage_history = []

    def adapt_batch_size(self, processing_time: float, memory_usage: float,
                         target_time: float = 1.0):
        """
        Adapt batch size based on performance metrics

        Args:
            processing_time: Time taken to process current batch
            memory_usage: Memory usage for current batch (recorded, but not
                yet consulted by the adjustment rule below)
            target_time: Target processing time per batch
        """
        self.performance_history.append(processing_time)
        self.memory_usage_history.append(memory_usage)

        # Keep only recent history
        max_history = 10
        if len(self.performance_history) > max_history:
            self.performance_history = self.performance_history[-max_history:]
            self.memory_usage_history = self.memory_usage_history[-max_history:]

        # Adjust based on performance
        if processing_time > target_time * 1.5:
            # Too slow, decrease batch size
            new_batch_size = max(self.min_batch_size, int(self.current_batch_size * 0.8))
        elif processing_time < target_time * 0.5:
            # Too fast, increase batch size
            new_batch_size = min(self.max_batch_size, int(self.current_batch_size * 1.2))
        else:
            # Good performance, keep current size
            new_batch_size = self.current_batch_size

        if new_batch_size != self.current_batch_size:
            logger.debug(f"Adapted batch size from {self.current_batch_size} to {new_batch_size}")
            self.current_batch_size = new_batch_size

    def get_current_batch_size(self) -> int:
        """Get current adaptive batch size"""
        return self.current_batch_size
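A feedback-loop sketch for the adaptive batcher: measured wall-clock time per chunk drives the size within [min_batch_size, max_batch_size]; memory_usage is recorded but, as noted in the docstring, unused by the rule. The config and the process stub are stand-ins:

import time
from types import SimpleNamespace

adaptive = AdaptiveBatcher(SimpleNamespace(batch_size=64))  # min 16, max 128

def process(chunk):
    time.sleep(0.001)  # placeholder for real batch work

work = list(range(1000))
i = 0
while i < len(work):
    size = adaptive.get_current_batch_size()
    start = time.perf_counter()
    process(work[i:i + size])
    elapsed = time.perf_counter() - start
    adaptive.adapt_batch_size(elapsed, memory_usage=0.0, target_time=0.05)
    i += size
# With elapsed << target_time, the size ratchets up toward max_batch_size.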
def create_ground_rule_sampler(config: NPLLConfig, seed: Optional[int] = None) -> GroundRuleSampler:
    """Factory function to create ground rule sampler"""
    return GroundRuleSampler(config, seed)


def verify_batch_utils():
    """Verify batch utility implementations"""
    from ..utils.config import default_config
    from ..core import Relation, load_knowledge_graph_from_triples
    from ..core.logical_rules import Variable, Atom, RuleType

    # Create test data
    test_triples = [
        ("A", "r1", "B"),
        ("B", "r2", "C"),
        ("A", "r3", "C")
    ]

    kg = load_knowledge_graph_from_triples(test_triples)

    # Create test rule and ground rules
    r1, r2, r3 = Relation("r1"), Relation("r2"), Relation("r3")
    x, y, z = Variable('x'), Variable('y'), Variable('z')

    test_rule = LogicalRule(
        rule_id="test_rule",
        body=[Atom(r1, (x, y)), Atom(r2, (y, z))],
        head=Atom(r3, (x, z)),
        rule_type=RuleType.TRANSITIVITY
    )

    ground_rules = test_rule.generate_ground_rules(kg, max_groundings=10)

    # Test sampler (the keyword is random_seed, matching __init__ above)
    sampler = GroundRuleSampler(default_config, random_seed=42)
    batches = sampler.sample_ground_rules(ground_rules, num_batches=2)

    assert len(batches) <= 2, "Should create at most 2 batches"

    for batch in batches:
        assert batch.batch_size == len(batch.ground_rules), "Batch size consistency"
        assert len(batch.rule_indices) == batch.batch_size, "Rule indices length"

    # Test memory-efficient batcher
    efficient_batcher = MemoryEfficientBatcher(default_config)
    batch_iterator = efficient_batcher.create_ground_rule_iterator(ground_rules)

    batches_from_iterator = list(batch_iterator)
    assert len(batches_from_iterator) > 0, "Should create batches from iterator"

    logger.info("Batch utilities verified successfully")

    return True