odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
benchmarks/__init__.py
CHANGED

@@ -1,17 +1,17 @@

Every line is marked as removed and re-added, but the text is identical on both sides (the underlying change is not visible in the extracted diff), so the file is reproduced once:

"""
Odin Benchmarks: Academic validation against standard KG datasets.

This module provides:
- Standard dataset loaders (FB15k-237, WN18RR)
- KG completion metrics (MRR, Hits@K)
- Benchmark runners for NPLL evaluation
- Ablation study tools
"""

from .metrics import mrr, hits_at_k, evaluate_rankings
from .datasets import load_fb15k237, load_wn18rr, BenchmarkDataset

__all__ = [
    "mrr", "hits_at_k", "evaluate_rankings",
    "load_fb15k237", "load_wn18rr", "BenchmarkDataset"
]
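The module's public surface (metrics, loaders, and the BenchmarkDataset container) is the same on both sides of the diff. A minimal usage sketch of these exports, relying only on the loader and dataclass defined in benchmarks/datasets.py below (the metric functions' signatures are not visible in this diff, so they are omitted):

# Sketch: load a standard benchmark through the package's public exports.
from benchmarks import load_fb15k237, BenchmarkDataset

dataset: BenchmarkDataset = load_fb15k237()    # downloads and caches on first call
print(dataset.num_entities)                    # 14,541 per the loader's docstring
print(dataset.num_relations)                   # 237
print(dataset.num_train, dataset.num_valid, dataset.num_test)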
benchmarks/datasets.py
CHANGED

@@ -1,284 +1,284 @@

All 284 lines are marked as removed and re-added with identical text (the underlying change is not visible in the extracted diff), so the file is reproduced once:

"""
Standard Knowledge Graph Benchmark Datasets

Provides loaders for:
- FB15k-237: Freebase subset (14,541 entities, 237 relations)
- WN18RR: WordNet subset (40,943 entities, 11 relations)

Datasets are downloaded from standard sources and cached locally.
"""

import os
import urllib.request
import tarfile
import zipfile
from pathlib import Path
from typing import List, Tuple, Set, Dict, Optional
from dataclasses import dataclass
import logging

logger = logging.getLogger(__name__)

# Dataset URLs - using villmow/datasets_knowledge_embedding (reliable raw files)
DATASET_BASE_URLS = {
    "fb15k237": "https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237",
    "wn18rr": "https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/WN18RR",
}

# Default cache directory
DEFAULT_CACHE_DIR = Path.home() / ".cache" / "odin_benchmarks"


@dataclass
class BenchmarkDataset:
    """Container for a benchmark dataset."""
    name: str
    train_triples: List[Tuple[str, str, str]]
    valid_triples: List[Tuple[str, str, str]]
    test_triples: List[Tuple[str, str, str]]
    entities: List[str]
    relations: List[str]

    @property
    def num_entities(self) -> int:
        return len(self.entities)

    @property
    def num_relations(self) -> int:
        return len(self.relations)

    @property
    def num_train(self) -> int:
        return len(self.train_triples)

    @property
    def num_valid(self) -> int:
        return len(self.valid_triples)

    @property
    def num_test(self) -> int:
        return len(self.test_triples)

    def get_train_set(self) -> Set[Tuple[str, str, str]]:
        return set(self.train_triples)

    def get_all_triples(self) -> Set[Tuple[str, str, str]]:
        return set(self.train_triples + self.valid_triples + self.test_triples)

    def __repr__(self) -> str:
        return (
            f"BenchmarkDataset({self.name})\n"
            f" Entities: {self.num_entities:,}\n"
            f" Relations: {self.num_relations}\n"
            f" Train: {self.num_train:,}\n"
            f" Valid: {self.num_valid:,}\n"
            f" Test: {self.num_test:,}"
        )


def _ensure_dir(path: Path):
    """Ensure directory exists."""
    path.mkdir(parents=True, exist_ok=True)


def _download_file(url: str, dest: Path):
    """Download a file with progress."""
    import ssl
    import certifi

    logger.info(f"Downloading {url}...")

    # Try with certifi SSL context first, fall back to unverified
    try:
        ssl_context = ssl.create_default_context(cafile=certifi.where())
        with urllib.request.urlopen(url, context=ssl_context) as response:
            with open(dest, 'wb') as out_file:
                out_file.write(response.read())
    except (ImportError, ssl.SSLError):
        # Fallback: disable SSL verification (for development only)
        logger.warning("SSL verification disabled - using unverified context")
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        with urllib.request.urlopen(url, context=ssl_context) as response:
            with open(dest, 'wb') as out_file:
                out_file.write(response.read())

    logger.info(f"Downloaded to {dest}")


def _extract_tar_gz(archive: Path, dest_dir: Path):
    """Extract a tar.gz archive."""
    logger.info(f"Extracting {archive}...")
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest_dir)


def _extract_zip(archive: Path, dest_dir: Path):
    """Extract a zip archive."""
    logger.info(f"Extracting {archive}...")
    with zipfile.ZipFile(archive, 'r') as zip_ref:
        zip_ref.extractall(dest_dir)


def _load_triples(filepath: Path) -> List[Tuple[str, str, str]]:
    """Load triples from a TSV file (head, relation, tail)."""
    triples = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) == 3:
                h, r, t = parts
                triples.append((h, r, t))
    return triples


def _extract_entities_and_relations(
    triples: List[Tuple[str, str, str]]
) -> Tuple[List[str], List[str]]:
    """Extract unique entities and relations from triples."""
    entities = set()
    relations = set()
    for h, r, t in triples:
        entities.add(h)
        entities.add(t)
        relations.add(r)
    return sorted(entities), sorted(relations)


def _download_dataset_files(base_url: str, dataset_dir: Path):
    """Download train/valid/test files for a dataset."""
    _ensure_dir(dataset_dir)

    for split in ["train", "valid", "test"]:
        file_path = dataset_dir / f"{split}.txt"
        if not file_path.exists():
            url = f"{base_url}/{split}.txt"
            _download_file(url, file_path)


def load_fb15k237(cache_dir: Optional[Path] = None) -> BenchmarkDataset:
    """
    Load FB15k-237 dataset.

    FB15k-237 is a subset of Freebase with:
    - 14,541 entities
    - 237 relations
    - 310,116 triples

    This version removes inverse relations from FB15k to prevent
    data leakage during evaluation.

    Args:
        cache_dir: Directory to cache downloaded data

    Returns:
        BenchmarkDataset object
    """
    cache_dir = Path(cache_dir) if cache_dir else DEFAULT_CACHE_DIR
    _ensure_dir(cache_dir)

    dataset_dir = cache_dir / "FB15k-237"

    # Download if not cached
    if not (dataset_dir / "train.txt").exists():
        _download_dataset_files(DATASET_BASE_URLS["fb15k237"], dataset_dir)

    # Load splits
    train = _load_triples(dataset_dir / "train.txt")
    valid = _load_triples(dataset_dir / "valid.txt")
    test = _load_triples(dataset_dir / "test.txt")

    # Extract vocab
    all_triples = train + valid + test
    entities, relations = _extract_entities_and_relations(all_triples)

    return BenchmarkDataset(
        name="FB15k-237",
        train_triples=train,
        valid_triples=valid,
        test_triples=test,
        entities=entities,
        relations=relations,
    )


def load_wn18rr(cache_dir: Optional[Path] = None) -> BenchmarkDataset:
    """
    Load WN18RR dataset.

    WN18RR is a subset of WordNet with:
    - 40,943 entities
    - 11 relations
    - 93,003 triples

    This version removes inverse relations from WN18 to prevent
    data leakage during evaluation.

    Args:
        cache_dir: Directory to cache downloaded data

    Returns:
        BenchmarkDataset object
    """
    cache_dir = Path(cache_dir) if cache_dir else DEFAULT_CACHE_DIR
    _ensure_dir(cache_dir)

    dataset_dir = cache_dir / "WN18RR"

    # Download if not cached
    if not (dataset_dir / "train.txt").exists():
        _download_dataset_files(DATASET_BASE_URLS["wn18rr"], dataset_dir)

    # Load splits
    train = _load_triples(dataset_dir / "train.txt")
    valid = _load_triples(dataset_dir / "valid.txt")
    test = _load_triples(dataset_dir / "test.txt")

    # Extract vocab
    all_triples = train + valid + test
    entities, relations = _extract_entities_and_relations(all_triples)

    return BenchmarkDataset(
        name="WN18RR",
        train_triples=train,
        valid_triples=valid,
        test_triples=test,
        entities=entities,
        relations=relations,
    )


def dataset_to_kg(dataset: BenchmarkDataset):
    """
    Convert BenchmarkDataset to Odin KnowledgeGraph.

    Args:
        dataset: BenchmarkDataset object

    Returns:
        KnowledgeGraph object suitable for NPLL training
    """
    from npll.core import KnowledgeGraph

    kg = KnowledgeGraph()

    # Add all training triples as known facts
    for h, r, t in dataset.train_triples:
        kg.add_known_fact(h, r, t)

    return kg


# Quick test
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    print("Loading FB15k-237...")
    fb = load_fb15k237()
    print(fb)
    print()

    print("Loading WN18RR...")
    wn = load_wn18rr()
    print(wn)
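The __main__ block above only loads and prints the datasets. A short sketch of how the loader output feeds the rest of the package, using only names defined in this file (dataset_to_kg imports npll.core.KnowledgeGraph internally; the actual benchmark driver lives in benchmarks/run_npll_benchmark.py and is not reproduced in this diff):

# Sketch: convert a loaded benchmark into the KnowledgeGraph consumed by NPLL.
from benchmarks.datasets import load_wn18rr, dataset_to_kg

wn = load_wn18rr()                     # cached under ~/.cache/odin_benchmarks/WN18RR
kg = dataset_to_kg(wn)                 # training triples become known facts

# Validation/test triples stay outside the KG so they can be ranked against it.
held_out = wn.valid_triples + wn.test_triples
print(f"{wn.num_train:,} training facts, {len(held_out):,} held-out triples")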