odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/utils/batch_utils.py
CHANGED

@@ -1,493 +1,493 @@

The hunk removes and re-adds lines 1-492 with textually identical content (a whitespace- or line-ending-only rewrite); line 493 ("return True") is the only unchanged context line. The file content appears once below instead of being duplicated on both sides of the diff:
"""
Batch processing utilities for NPLL ground rules
Handles efficient batching and sampling of ground rules for MLN computations
"""

import torch
import numpy as np
from typing import List, Dict, Set, Tuple, Optional, Iterator, Any
from collections import defaultdict
import random
import logging
from dataclasses import dataclass

from ..core import Triple, LogicalRule, GroundRule, KnowledgeGraph
from ..utils.config import NPLLConfig

logger = logging.getLogger(__name__)


@dataclass
class GroundRuleBatch:
    """
    Batch of ground rules for efficient processing

    Contains ground rules and associated metadata for batch operations
    """
    ground_rules: List[GroundRule]
    rule_indices: torch.Tensor  # Which logical rule each ground rule belongs to
    fact_indices: Dict[Triple, int]  # Mapping from facts to batch indices
    batch_facts: List[Triple]  # All unique facts in this batch
    batch_size: int

    def __post_init__(self):
        """Validate batch consistency"""
        assert len(self.ground_rules) == self.batch_size, \
            f"Inconsistent batch size: {len(self.ground_rules)} vs {self.batch_size}"

        assert len(self.rule_indices) == self.batch_size, \
            f"Rule indices length mismatch: {len(self.rule_indices)} vs {self.batch_size}"

    def get_fact_truth_matrix(self, fact_assignments: Dict[Triple, bool]) -> torch.Tensor:
        """
        Create truth value matrix for facts in this batch

        Returns:
            Tensor of shape [batch_size, max_facts_per_rule] with truth values
        """
        max_facts = max(len(gr.get_all_facts()) for gr in self.ground_rules) if self.ground_rules else 0

        if max_facts == 0:
            return torch.zeros(self.batch_size, 0, dtype=torch.bool)

        truth_matrix = torch.zeros(self.batch_size, max_facts, dtype=torch.bool)

        for i, ground_rule in enumerate(self.ground_rules):
            facts = ground_rule.get_all_facts()
            for j, fact in enumerate(facts):
                if j < max_facts:
                    truth_matrix[i, j] = fact_assignments.get(fact, False)

        return truth_matrix

    def evaluate_ground_rules(self, fact_assignments: Dict[Triple, bool]) -> torch.Tensor:
        """
        Evaluate all ground rules in batch

        Returns:
            Boolean tensor indicating which ground rules are satisfied
        """
        satisfaction = torch.zeros(self.batch_size, dtype=torch.bool)

        for i, ground_rule in enumerate(self.ground_rules):
            # Check body satisfaction
            body_satisfied = all(
                fact_assignments.get(fact, False)
                for fact in ground_rule.body_facts
            )

            if not body_satisfied:
                # Body false -> rule vacuously true
                satisfaction[i] = True
            else:
                # Body true -> check head
                head_satisfied = fact_assignments.get(ground_rule.head_fact, False)
                satisfaction[i] = head_satisfied

        return satisfaction


class GroundRuleSampler:
    """
    Samples ground rules for efficient MLN training and inference

    Paper Section 4.2: "this paper randomly samples batches of ground rules to form datasets,
    wherein the ground rules are approximately independent of each batch"
    """

    def __init__(self, config: NPLLConfig, random_seed: Optional[int] = None):
        self.config = config
        self.batch_size = config.batch_size
        self.max_ground_rules = config.max_ground_rules

        if random_seed is not None:
            random.seed(random_seed)
            np.random.seed(random_seed)
            torch.manual_seed(random_seed)

    def sample_ground_rules(self, all_ground_rules: List[GroundRule],
                            num_batches: int = 1,
                            sampling_strategy: str = "uniform") -> List[GroundRuleBatch]:
        """
        Sample batches of ground rules

        Args:
            all_ground_rules: All available ground rules
            num_batches: Number of batches to create
            sampling_strategy: 'uniform', 'weighted', or 'stratified'

        Returns:
            List of GroundRuleBatch objects
        """
        if not all_ground_rules:
            return []

        total_rules = len(all_ground_rules)
        rules_per_batch = min(self.batch_size, total_rules // num_batches) if num_batches > 1 else min(self.batch_size, total_rules)

        batches = []

        for batch_idx in range(num_batches):
            if sampling_strategy == "uniform":
                sampled_rules = self._uniform_sampling(all_ground_rules, rules_per_batch)
            elif sampling_strategy == "weighted":
                sampled_rules = self._weighted_sampling(all_ground_rules, rules_per_batch)
            elif sampling_strategy == "stratified":
                sampled_rules = self._stratified_sampling(all_ground_rules, rules_per_batch)
            else:
                sampled_rules = self._uniform_sampling(all_ground_rules, rules_per_batch)

            if sampled_rules:
                batch = self._create_batch_from_rules(sampled_rules)
                batches.append(batch)

        logger.debug(f"Created {len(batches)} ground rule batches with avg size {rules_per_batch}")
        return batches

    def _uniform_sampling(self, ground_rules: List[GroundRule],
                          sample_size: int) -> List[GroundRule]:
        """Uniform random sampling of ground rules"""
        if sample_size >= len(ground_rules):
            return ground_rules.copy()

        return random.sample(ground_rules, sample_size)

    def _weighted_sampling(self, ground_rules: List[GroundRule],
                           sample_size: int) -> List[GroundRule]:
        """
        Weighted sampling based on rule confidence/support
        Higher confidence rules are more likely to be sampled
        """
        if sample_size >= len(ground_rules):
            return ground_rules.copy()

        # Use parent rule confidence as weight
        weights = [gr.parent_rule.confidence for gr in ground_rules]

        # Normalize weights
        total_weight = sum(weights)
        if total_weight > 0:
            weights = [w / total_weight for w in weights]
        else:
            weights = [1.0 / len(weights)] * len(weights)

        # Sample with replacement
        sampled_indices = np.random.choice(
            len(ground_rules),
            size=sample_size,
            p=weights,
            replace=False if sample_size <= len(ground_rules) else True
        )

        return [ground_rules[i] for i in sampled_indices]

    def _stratified_sampling(self, ground_rules: List[GroundRule],
                             sample_size: int) -> List[GroundRule]:
        """
        Stratified sampling ensuring representation from different rule types
        """
        if sample_size >= len(ground_rules):
            return ground_rules.copy()

        # Group by parent rule type
        rule_type_groups = defaultdict(list)
        for gr in ground_rules:
            rule_type_groups[gr.parent_rule.rule_type].append(gr)

        # Sample proportionally from each group
        sampled_rules = []
        remaining_samples = sample_size

        for rule_type, type_rules in rule_type_groups.items():
            # Proportional allocation
            group_sample_size = min(
                len(type_rules),
                max(1, int(remaining_samples * len(type_rules) / len(ground_rules)))
            )

            if group_sample_size > 0:
                group_sample = random.sample(type_rules, group_sample_size)
                sampled_rules.extend(group_sample)
                remaining_samples -= group_sample_size

        # If we need more samples, fill randomly
        if remaining_samples > 0 and len(sampled_rules) < sample_size:
            remaining_rules = [gr for gr in ground_rules if gr not in sampled_rules]
            if remaining_rules:
                additional_samples = min(remaining_samples, len(remaining_rules))
                additional_rules = random.sample(remaining_rules, additional_samples)
                sampled_rules.extend(additional_rules)

        return sampled_rules[:sample_size]

    def _create_batch_from_rules(self, ground_rules: List[GroundRule]) -> GroundRuleBatch:
        """Create GroundRuleBatch from list of ground rules"""
        if not ground_rules:
            return GroundRuleBatch(
                ground_rules=[],
                rule_indices=torch.tensor([]),
                fact_indices={},
                batch_facts=[],
                batch_size=0
            )

        # Extract rule indices (assuming rules are indexed by their position in logical_rules list)
        rule_indices = []
        unique_facts = set()

        # Build parent rule ID to index mapping (this should be provided by MLN)
        rule_id_to_idx = {}
        for i, gr in enumerate(ground_rules):
            if gr.parent_rule.rule_id not in rule_id_to_idx:
                rule_id_to_idx[gr.parent_rule.rule_id] = len(rule_id_to_idx)

            rule_indices.append(rule_id_to_idx[gr.parent_rule.rule_id])

            # Collect all unique facts
            unique_facts.update(gr.get_all_facts())

        # Create fact indexing
        batch_facts = list(unique_facts)
        fact_indices = {fact: i for i, fact in enumerate(batch_facts)}

        return GroundRuleBatch(
            ground_rules=ground_rules,
            rule_indices=torch.tensor(rule_indices, dtype=torch.long),
            fact_indices=fact_indices,
            batch_facts=batch_facts,
            batch_size=len(ground_rules)
        )

    def create_batches_for_training(self, ground_rules: List[GroundRule],
                                    shuffle: bool = True) -> List[GroundRuleBatch]:
        """
        Create batches specifically for training

        Args:
            ground_rules: All ground rules to batch
            shuffle: Whether to shuffle before batching

        Returns:
            List of training batches
        """
        if not ground_rules:
            return []

        # Shuffle if requested
        rules_to_batch = ground_rules.copy()
        if shuffle:
            random.shuffle(rules_to_batch)

        # Create sequential batches
        batches = []
        for i in range(0, len(rules_to_batch), self.batch_size):
            batch_rules = rules_to_batch[i:i + self.batch_size]
            batch = self._create_batch_from_rules(batch_rules)
            batches.append(batch)

        return batches


class FactBatchProcessor:
    """
    Processes facts in batches for efficient scoring and probability computation
    """

    def __init__(self, config: NPLLConfig):
        self.config = config
        self.batch_size = config.batch_size

    def create_fact_batches(self, facts: List[Triple],
                            batch_size: Optional[int] = None) -> List[List[Triple]]:
        """Create batches of facts for processing"""
        batch_size = batch_size or self.batch_size

        batches = []
        for i in range(0, len(facts), batch_size):
            batch = facts[i:i + batch_size]
            batches.append(batch)

        return batches

    def process_fact_batches(self, fact_batches: List[List[Triple]],
                             processor_func) -> List[Any]:
        """Process batches using provided function"""
        results = []

        for batch in fact_batches:
            batch_result = processor_func(batch)
            results.append(batch_result)

        return results


class MemoryEfficientBatcher:
    """
    Memory-efficient batching for large-scale ground rule processing
    Uses generators to avoid loading all data into memory
    """

    def __init__(self, config: NPLLConfig):
        self.config = config
        self.batch_size = config.batch_size

    def create_ground_rule_iterator(self, ground_rules: List[GroundRule],
                                    shuffle: bool = True) -> Iterator[GroundRuleBatch]:
        """
        Create iterator over ground rule batches for memory efficiency

        Args:
            ground_rules: All ground rules
            shuffle: Whether to shuffle order

        Yields:
            GroundRuleBatch objects
        """
        if shuffle:
            indices = list(range(len(ground_rules)))
            random.shuffle(indices)
        else:
            indices = list(range(len(ground_rules)))

        for i in range(0, len(indices), self.batch_size):
            batch_indices = indices[i:i + self.batch_size]
            batch_rules = [ground_rules[idx] for idx in batch_indices]

            # Create batch
            batch = self._create_efficient_batch(batch_rules)
            yield batch

    def _create_efficient_batch(self, ground_rules: List[GroundRule]) -> GroundRuleBatch:
        """Create batch with minimal memory overhead"""
        if not ground_rules:
            return GroundRuleBatch([], torch.tensor([]), {}, [], 0)

        # Efficient fact collection using sets
        all_facts = set()
        rule_indices = []

        # Single pass to collect facts and rule indices
        for i, gr in enumerate(ground_rules):
            all_facts.update(gr.get_all_facts())
            # Use hash of rule_id as index for efficiency
            rule_indices.append(hash(gr.parent_rule.rule_id) % 1000)

        batch_facts = list(all_facts)
        fact_indices = {fact: i for i, fact in enumerate(batch_facts)}

        return GroundRuleBatch(
            ground_rules=ground_rules,
            rule_indices=torch.tensor(rule_indices, dtype=torch.long),
            fact_indices=fact_indices,
            batch_facts=batch_facts,
            batch_size=len(ground_rules)
        )


class AdaptiveBatcher:
    """
    Adaptive batching that adjusts batch size based on memory usage and performance
    """

    def __init__(self, config: NPLLConfig, initial_batch_size: Optional[int] = None):
        self.config = config
        self.current_batch_size = initial_batch_size or config.batch_size
        self.min_batch_size = max(1, config.batch_size // 4)
        self.max_batch_size = config.batch_size * 2

        # Performance tracking
        self.performance_history = []
        self.memory_usage_history = []

    def adapt_batch_size(self, processing_time: float, memory_usage: float,
                         target_time: float = 1.0):
        """
        Adapt batch size based on performance metrics

        Args:
            processing_time: Time taken to process current batch
            target_time: Target processing time per batch
            memory_usage: Memory usage for current batch
        """
        self.performance_history.append(processing_time)
        self.memory_usage_history.append(memory_usage)

        # Keep only recent history
        max_history = 10
        if len(self.performance_history) > max_history:
            self.performance_history = self.performance_history[-max_history:]
            self.memory_usage_history = self.memory_usage_history[-max_history:]

        # Adjust based on performance
        if processing_time > target_time * 1.5:
            # Too slow, decrease batch size
            new_batch_size = max(self.min_batch_size, int(self.current_batch_size * 0.8))
        elif processing_time < target_time * 0.5:
            # Too fast, increase batch size
            new_batch_size = min(self.max_batch_size, int(self.current_batch_size * 1.2))
        else:
            # Good performance, keep current size
            new_batch_size = self.current_batch_size

        if new_batch_size != self.current_batch_size:
            logger.debug(f"Adapted batch size from {self.current_batch_size} to {new_batch_size}")
            self.current_batch_size = new_batch_size

    def get_current_batch_size(self) -> int:
        """Get current adaptive batch size"""
        return self.current_batch_size


def create_ground_rule_sampler(config: NPLLConfig, seed: Optional[int] = None) -> GroundRuleSampler:
    """Factory function to create ground rule sampler"""
    return GroundRuleSampler(config, seed)


def verify_batch_utils():
    """Verify batch utility implementations"""
    from ..utils.config import default_config
    from ..core import Entity, Relation, load_knowledge_graph_from_triples
    from ..core.logical_rules import Variable, Atom, RuleType

    # Create test data
    test_triples = [
        ("A", "r1", "B"),
        ("B", "r2", "C"),
        ("A", "r3", "C")
    ]

    kg = load_knowledge_graph_from_triples(test_triples)

    # Create test rule and ground rules
    r1, r2, r3 = Relation("r1"), Relation("r2"), Relation("r3")
    x, y, z = Variable('x'), Variable('y'), Variable('z')

    test_rule = LogicalRule(
        rule_id="test_rule",
        body=[Atom(r1, (x, y)), Atom(r2, (y, z))],
        head=Atom(r3, (x, z)),
        rule_type=RuleType.TRANSITIVITY
    )

    ground_rules = test_rule.generate_ground_rules(kg, max_groundings=10)

    # Test sampler
    sampler = GroundRuleSampler(default_config, seed=42)
    batches = sampler.sample_ground_rules(ground_rules, num_batches=2)

    assert len(batches) <= 2, "Should create at most 2 batches"

    for batch in batches:
        assert batch.batch_size == len(batch.ground_rules), "Batch size consistency"
        assert len(batch.rule_indices) == batch.batch_size, "Rule indices length"

    # Test memory-efficient batcher
    efficient_batcher = MemoryEfficientBatcher(default_config)
    batch_iterator = efficient_batcher.create_ground_rule_iterator(ground_rules)

    batches_from_iterator = list(batch_iterator)
    assert len(batches_from_iterator) > 0, "Should create batches from iterator"

    logger.info("Batch utilities verified successfully")

    return True
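For orientation, here is a minimal end-to-end usage sketch of the sampling and batching API in this file. It is an illustration, not part of the package: it is modeled on verify_batch_utils() above, and the npll.* import paths are assumed from the file list at the top of this diff.

# Usage sketch (illustrative, not shipped in the wheel). Mirrors the setup in
# verify_batch_utils(); npll.* import paths are assumed from the file list.
from npll.utils.config import default_config
from npll.core import Relation, LogicalRule, load_knowledge_graph_from_triples
from npll.core.logical_rules import Variable, Atom, RuleType
from npll.utils.batch_utils import GroundRuleSampler, MemoryEfficientBatcher

# Tiny knowledge graph: r1(A, B), r2(B, C), r3(A, C)
kg = load_knowledge_graph_from_triples(
    [("A", "r1", "B"), ("B", "r2", "C"), ("A", "r3", "C")]
)

# One transitivity rule: r1(x, y) AND r2(y, z) -> r3(x, z)
x, y, z = Variable("x"), Variable("y"), Variable("z")
rule = LogicalRule(
    rule_id="transitivity",
    body=[Atom(Relation("r1"), (x, y)), Atom(Relation("r2"), (y, z))],
    head=Atom(Relation("r3"), (x, z)),
    rule_type=RuleType.TRANSITIVITY,
)
ground_rules = rule.generate_ground_rules(kg, max_groundings=10)

# Sample stratified batches with a fixed seed (note the keyword is
# random_seed, not seed), then stream batches lazily and evaluate them
# against an assignment that marks every involved fact true.
sampler = GroundRuleSampler(default_config, random_seed=42)
batches = sampler.sample_ground_rules(ground_rules, num_batches=2,
                                      sampling_strategy="stratified")

fact_assignments = {f: True for gr in ground_rules for f in gr.get_all_facts()}
batcher = MemoryEfficientBatcher(default_config)
for batch in batcher.create_ground_rule_iterator(ground_rules, shuffle=False):
    satisfied = batch.evaluate_ground_rules(fact_assignments)
    print(f"{int(satisfied.sum())}/{batch.batch_size} ground rules satisfied")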
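The AdaptiveBatcher in this file implements a simple threshold feedback rule: against the default 1.0 s target, a batch slower than 1.5x the target shrinks the size to int(0.8 * b), floored at max(1, batch_size // 4), and one faster than 0.5x the target grows it to int(1.2 * b), capped at 2 * batch_size. A short sketch of that loop follows; the timings are invented, and default_config is assumed to carry the batch_size attribute the constructor reads.

# Illustrative AdaptiveBatcher feedback loop; timings are invented and
# default_config.batch_size is assumed to exist (the constructor reads it).
from npll.utils.config import default_config
from npll.utils.batch_utils import AdaptiveBatcher

batcher = AdaptiveBatcher(default_config)  # starts at config.batch_size

for elapsed in (2.0, 2.0, 0.3):  # two slow batches, then a fast one
    # > 1.5 * target shrinks the size by 20%; < 0.5 * target grows it by 20%
    batcher.adapt_batch_size(processing_time=elapsed, memory_usage=0.0)
    print(batcher.get_current_batch_size())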