OntoLearner 1.4.7-py3-none-any.whl → 1.4.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ontolearner/VERSION +1 -1
- ontolearner/base/learner.py +15 -12
- ontolearner/learner/label_mapper.py +1 -1
- ontolearner/learner/retriever.py +24 -3
- ontolearner/learner/taxonomy_discovery/__init__.py +18 -0
- ontolearner/learner/taxonomy_discovery/alexbek.py +500 -0
- ontolearner/learner/taxonomy_discovery/rwthdbis.py +1082 -0
- ontolearner/learner/taxonomy_discovery/sbunlp.py +402 -0
- ontolearner/learner/taxonomy_discovery/skhnlp.py +1138 -0
- ontolearner/learner/term_typing/__init__.py +17 -0
- ontolearner/learner/term_typing/alexbek.py +1262 -0
- ontolearner/learner/term_typing/rwthdbis.py +379 -0
- ontolearner/learner/term_typing/sbunlp.py +478 -0
- ontolearner/learner/text2onto/__init__.py +16 -0
- ontolearner/learner/text2onto/alexbek.py +1219 -0
- ontolearner/learner/text2onto/sbunlp.py +598 -0
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.8.dist-info}/METADATA +4 -1
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.8.dist-info}/RECORD +20 -8
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.8.dist-info}/WHEEL +0 -0
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.8.dist-info}/licenses/LICENSE +0 -0
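Most of the added lines come from the three new learner families listed above (taxonomy_discovery, term_typing, text2onto), each shipped as its own submodule. As an illustrative sketch only (the names re-exported by the new `__init__.py` files are not visible in this diff), the submodules can be addressed by the paths in the file list, for example:

    # Hypothetical import sketch based on the file paths above; only
    # RWTHDBISSFTLearner (defined in the hunk below) is actually shown in this diff.
    import ontolearner.learner.term_typing.sbunlp
    import ontolearner.learner.text2onto.alexbek
    from ontolearner.learner.taxonomy_discovery.rwthdbis import RWTHDBISSFTLearner

The remainder of this page shows the largest of the new files, ontolearner/learner/taxonomy_discovery/rwthdbis.py (+1082 lines).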
ontolearner/learner/taxonomy_discovery/rwthdbis.py
@@ -0,0 +1,1082 @@
+# Copyright (c) 2025 SciKnowOrg
+#
+# Licensed under the MIT License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/MIT
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import random
+import re
+import platform
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Callable
+from functools import partial
+from tqdm.auto import tqdm
+import g4f
+from g4f.client import Client as _G4FClient
+import torch
+from datasets import Dataset, DatasetDict
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSequenceClassification,
+    DataCollatorWithPadding,
+    Trainer,
+    TrainingArguments,
+    set_seed,
+)
+
+from ...base import AutoLearner
+
+
+class RWTHDBISSFTLearner(AutoLearner):
+    """
+    Supervised classifier for (parent, child) taxonomy edges.
+
+    Model input format:
+        "<relation template> ## <optional context>"
+
+    Context building:
+        If no `context_json_path` is provided, the learner precomputes a fixed-name
+        context file `rwthdbis_onto_processed.json` under `output_dir/context/`
+        from the ontology terms and stores the path in `self.context_json_path`.
+
+    Attributes:
+        model_name: Hugging Face model identifier.
+        output_dir: Directory where checkpoints and tokenizer are saved/loaded.
+        min_predictions: If no candidate is predicted positive, return the top-k
+            by positive probability (k = min_predictions).
+        max_length: Maximum tokenized length for inputs.
+        per_device_train_batch_size: Micro-batch size per device.
+        gradient_accumulation_steps: Gradient accumulation steps.
+        num_train_epochs: Number of training epochs.
+        learning_rate: Optimizer LR.
+        weight_decay: Weight decay for AdamW.
+        logging_steps: Logging interval for Trainer.
+        save_strategy: HF saving strategy (e.g., 'epoch').
+        save_total_limit: Max checkpoints to keep.
+        fp16: Enable FP16 mixed precision.
+        bf16: Enable BF16 mixed precision (on supported hardware).
+        seed: Random seed for reproducibility.
+        negative_ratio: Number of negatives per positive during training.
+        bidirectional_templates: If True, also add reversed template examples.
+        context_json_path: Path to the preprocessed term-context JSON. If None,
+            the file is generated with the fixed prefix `rwthdbis_onto_*`.
+        ontology_name: Logical dataset/domain label used in prompts and filtering
+            (filenames still use the fixed `rwthdbis_onto_*` prefix).
+        device: user-defined argument as 'cuda' or 'cpu'.
+        model: Loaded/initialized `AutoModelForSequenceClassification`.
+        tokenizer: Loaded/initialized `AutoTokenizer`.
+    """
+
+    # Sentences containing any of these phrases are pruned from term_info.
+    _CONTEXT_REMOVALS = [
+        "couldn't find any",
+        "does not require",
+        "assist you further",
+        "feel free to",
+        "I'm currently unable",
+        "the search results",
+        "I'm unable to",
+        "recommend referring directly",
+        "bear with me",
+        "searching for the most relevant information",
+        "I'm currently checking the most relevant",
+        "already in English",
+        "require further",
+        "any additional information",
+        "already an English",
+        "don't have information",
+        "I'm sorry,",
+        "For further exploration",
+        "For more detailed information",
+    ]
+
+    def __init__(
+        self,
+        min_predictions: int = 1,
+        model_name: str = "distilroberta-base",
+        output_dir: str = "./results/taxonomy-discovery",
+        device: str = "cpu",
+        max_length: int = 256,
+        per_device_train_batch_size: int = 8,
+        gradient_accumulation_steps: int = 4,
+        num_train_epochs: int = 1,
+        learning_rate: float = 2e-5,
+        weight_decay: float = 0.01,
+        logging_steps: int = 25,
+        save_strategy: str = "epoch",
+        save_total_limit: int = 1,
+        fp16: bool = True,
+        bf16: bool = False,
+        seed: int = 42,
+        negative_ratio: int = 5,
+        bidirectional_templates: bool = True,
+        context_json_path: Optional[str] = None,
+        ontology_name: str = "Geonames",
+    ) -> None:
+        """
+        Initialize the taxonomy-edge learner and set training/inference knobs.
+
+        Notes:
+            - Output artifacts are written under `output_dir`, including
+              the model weights and tokenizer (for later `from_pretrained` loads).
+            - If `context_json_path` is not provided, a new context file named
+              `rwthdbis_onto_processed.json` is generated under `output_dir/context/`.
+        """
+        super().__init__()
+
+        self.model_name = model_name
+        safe_model_name = model_name.replace("/", "__")
+
+        resolved_output = output_dir.format(model_name=safe_model_name)
+        self.output_dir = str(Path(resolved_output))
+        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
+
+        # Store provided argument values as-is (types are enforced by callers).
+        self.min_predictions = min_predictions
+        self.max_length = max_length
+        self.per_device_train_batch_size = per_device_train_batch_size
+        self.gradient_accumulation_steps = gradient_accumulation_steps
+        self.num_train_epochs = num_train_epochs
+        self.learning_rate = learning_rate
+        self.weight_decay = weight_decay
+        self.logging_steps = logging_steps
+        self.save_strategy = save_strategy
+        self.save_total_limit = save_total_limit
+        self.fp16 = fp16
+        self.bf16 = bf16
+        self.seed = seed
+
+        self.negative_ratio = negative_ratio
+        self.bidirectional_templates = bidirectional_templates
+        self.context_json_path = context_json_path
+
+        self.ontology_name = ontology_name
+        self.device = device
+        self.model: Optional[AutoModelForSequenceClassification] = None
+        self.tokenizer: Optional[AutoTokenizer] = None
+
+        # Context caches built from the context JSON.
+        self._context_exact: Dict[str, str] = {}  # lower(term) -> info
+        self._context_rows: List[
+            Dict[str, str]
+        ] = []  # [{'term': str, 'term_info': str}, ...]
+
+    def _is_windows(self) -> bool:
+        """Return True if the current OS is Windows (NT)."""
+        return (os.name == "nt") or (platform.system().lower() == "windows")
+
+    def _normalize_text(self, raw_text: str, *, drop_questions: bool = False) -> str:
+        """
+        Normalize plain text consistently across the pipeline.
+
+        Operations:
+            - Remove markdown-like link patterns (e.g., '[[1]](http://...)').
+            - Replace newlines with spaces; collapse repeated spaces.
+            - Optionally drop sentences containing '?' (useful for model generations).
+
+        Args:
+            raw_text: Input text to normalize.
+            drop_questions: If True, filter out sentences with '?'.
+
+        Returns:
+            str: Cleaned single-line string.
+        """
+        if raw_text is None:
+            return ""
+        text = str(raw_text)
+
+        # Remove simple markdown link artifacts like [[1]](http://...)
+        text = re.sub(r"\[\[\d+\]\]\(https?://[^\)]+\)", "", text)
+
+        # Replace newlines with spaces and collapse multiple spaces
+        text = text.replace("\n", " ")
+        text = re.sub(r"\s{2,}", " ", text)
+
+        if drop_questions:
+            sentences = [s.strip() for s in text.split(".")]
+            sentences = [s for s in sentences if s and "?" not in s]
+            text = ". ".join(sentences)
+
+        return text.strip()
+
+    def _default_gpt_inference_with_dataset(self, term: str, dataset_name: str) -> str:
+        """
+        Generate a plain-text description for `term`, conditioned on `dataset_name`,
+        via g4f (best-effort). Falls back to an empty string on failure.
+
+        The raw output is then normalized with `_normalize_text(drop_questions=True)`.
+
+        Args:
+            term: Term to describe.
+            dataset_name: Ontology/domain name used in the prompt.
+
+        Returns:
+            str: Cleaned paragraph describing the term, or "" on failure.
+        """
+        prompt = (
+            f"Here is a: {term}, which is of domain name :{dataset_name}, translate it into english, "
+            "Provide as detailed a definition of this term as possible in plain text.without any markdown format."
+            "No reference link in result. "
+            "- Focus on intrinsic properties; do not name other entities or explicit relationships.\n"
+            "- Include classification/type, defining features, scope/scale, roles/functions, and measurable attributes when applicable.\n"
+            "Output: Plain text paragraphs only, neutral and factual."
+            f"Make sure all provided information can be used for discovering implicit relation of other {dataset_name} term, but don't mention the relation in result."
+        )
+
+        try:
+            client = _G4FClient()
+            response = client.chat.completions.create(
+                model=g4f.models.default,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            raw_text = (
+                response.choices[0].message.content
+                if response and response.choices
+                else ""
+            )
+        except Exception:
+            raw_text = ""  # best-effort fallback
+
+        return self._normalize_text(raw_text, drop_questions=True)
+
+    def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
+        """
+        AutoLearner hook: route to training or prediction.
+
+        Args:
+            data: Ontology-like object (has `.taxonomies` or `.type_taxonomies.taxonomies`).
+            test: If True, run inference; otherwise, train a model.
+
+        Returns:
+            If test=True, a list of accepted edges as dicts with keys `parent` and `child`;
+            otherwise None.
+        """
+        return self._predict_pairs(data) if test else self._train_from_pairs(data)
+
+    def _train_from_pairs(self, train_data: Any) -> None:
+        """
+        Train a binary classifier from ontology pairs.
+
+        Steps:
+            1) (Re)build the term-context JSON unless `context_json_path` is set.
+            2) Extract positive (parent, child) edges from `train_data`.
+            3) Sample negatives at `negative_ratio`.
+            4) Tokenize, instantiate HF Trainer, train, and save.
+
+        Args:
+            train_data: Ontology-like object with `.type_taxonomies.taxonomies`
+                (preferred) or `.taxonomies`, each item providing `parent` and `child`.
+
+        Raises:
+            ValueError: If no positive pairs are found.
+
+        Side Effects:
+            - Writes a trained model to `self.output_dir` (via `trainer.save_model`).
+            - Writes the tokenizer to `self.output_dir` (via `save_pretrained`).
+            - Sets `self.context_json_path` if it was previously unset.
+              The generated context file is named `rwthdbis_onto_processed.json`.
+        """
+        # Always (re)build context from ontology unless an explicit file is provided
+        if not self.context_json_path:
+            context_dir = Path(self.output_dir) / "context"
+            context_dir.mkdir(parents=True, exist_ok=True)
+            processed_context_file = context_dir / "rwthdbis_onto_processed.json"
+
+            # Remove stale file then regenerate
+            if processed_context_file.exists():
+                try:
+                    processed_context_file.unlink()
+                except Exception:
+                    pass
+
+            self.preprocess_context_from_ontology(
+                ontology=train_data,
+                processed_dir=context_dir,
+                dataset_name=self.ontology_name,
+                num_workers=max(1, min(os.cpu_count() or 2, 4)),
+                provider=partial(
+                    self._default_gpt_inference_with_dataset,
+                    dataset_name=self.ontology_name,
+                ),
+                max_retries=5,
+            )
+            self.context_json_path = str(processed_context_file)
+
+        # Reproducibility
+        set_seed(self.seed)
+        random.seed(self.seed)
+        torch.manual_seed(self.seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(self.seed)
+
+        # Build labeled pairs from ontology; context comes from preprocessed map
+        positive_pairs = self._extract_positive_pairs(train_data)
+        if not positive_pairs:
+            raise ValueError("No positive (parent, child) pairs found in train_data.")
+
+        entity_names = sorted(
+            {parent for parent, _ in positive_pairs}
+            | {child for _, child in positive_pairs}
+        )
+        negative_pairs = self._generate_negatives(
+            positives=positive_pairs,
+            entities=entity_names,
+            ratio=self.negative_ratio,
+        )
+
+        labels, input_texts = self._build_text_dataset(positive_pairs, negative_pairs)
+        dataset_dict = DatasetDict(
+            {"train": Dataset.from_dict({"label": labels, "text": input_texts})}
+        )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        # Ensure a pad token exists for robust padding across models.
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = (
+                getattr(self.tokenizer, "eos_token", None)
+                or getattr(self.tokenizer, "sep_token", None)
+                or getattr(self.tokenizer, "cls_token", None)
+            )
+
+        def tokenize_batch(batch: Dict[str, List[str]]):
+            """Tokenize a batch of input texts for HF Datasets mapping."""
+            return self.tokenizer(
+                batch["text"], truncation=True, max_length=self.max_length
+            )
+
+        tokenized_dataset = dataset_dict.map(
+            tokenize_batch, batched=True, remove_columns=["text"]
+        )
+        data_collator = DataCollatorWithPadding(self.tokenizer)
+
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            self.model_name,
+            num_labels=2,
+            id2label={0: "incorrect", 1: "correct"},
+            label2id={"incorrect": 0, "correct": 1},
+        )
+        # Ensure model has a pad_token_id if tokenizer provides one.
+        if (
+            getattr(self.model.config, "pad_token_id", None) is None
+            and self.tokenizer.pad_token_id is not None
+        ):
+            self.model.config.pad_token_id = self.tokenizer.pad_token_id
+
+        training_args = TrainingArguments(
+            output_dir=self.output_dir,
+            learning_rate=self.learning_rate,
+            per_device_train_batch_size=self.per_device_train_batch_size,
+            gradient_accumulation_steps=self.gradient_accumulation_steps,
+            num_train_epochs=self.num_train_epochs,
+            weight_decay=self.weight_decay,
+            save_strategy=self.save_strategy,
+            save_total_limit=self.save_total_limit,
+            logging_steps=self.logging_steps,
+            dataloader_pin_memory=bool(torch.cuda.is_available()),
+            fp16=self.fp16,
+            bf16=self.bf16,
+            report_to="none",
+            save_safetensors=True,
+        )
+
+        trainer = Trainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=tokenized_dataset["train"],
+            tokenizer=self.tokenizer,
+            data_collator=data_collator,
+        )
+        trainer.train()
+        trainer.save_model()
+        # Persist tokenizer alongside the model for from_pretrained() loads.
+        self.tokenizer.save_pretrained(self.output_dir)
+
+    def _predict_pairs(self, eval_data: Any) -> List[Dict[str, str]]:
+        """
+        Score candidate pairs and return those predicted as positive.
+
+        If no pair is predicted positive but `min_predictions` > 0, the top-k
+        pairs by positive probability are returned.
+
+        Args:
+            eval_data: Ontology-like object with either `.pairs` (preferred) or
+                `.type_taxonomies.taxonomies` / `.taxonomies`.
+
+        Returns:
+            list[dict]: Each dict has keys `parent` and `child`.
+        """
+        import torch.nn.functional as F
+
+        self._ensure_loaded_for_inference()
+
+        candidate_pairs = self._extract_pairs_for_eval(eval_data)
+        if not candidate_pairs:
+            return []
+
+        accepted_pairs: List[Dict[str, str]] = []
+        scored_candidates: List[Tuple[float, str, str, int]] = []
+
+        self.model.eval()
+        with torch.no_grad():
+            for parent_term, child_term in candidate_pairs:
+                input_text = self._format_input(parent_term, child_term)
+                inputs = self.tokenizer(
+                    input_text,
+                    return_tensors="pt",
+                    truncation=True,
+                    max_length=self.max_length,
+                )
+                inputs = {key: tensor.to(self.device) for key, tensor in inputs.items()}
+                logits = self.model(**inputs).logits
+                probabilities = F.softmax(logits, dim=-1).squeeze(0)
+                p_positive = float(probabilities[1].item())
+                predicted_label = int(torch.argmax(logits, dim=-1).item())
+                scored_candidates.append(
+                    (p_positive, parent_term, child_term, predicted_label)
+                )
+                if predicted_label == 1:
+                    accepted_pairs.append({"parent": parent_term, "child": child_term})
+
+        if accepted_pairs:
+            return accepted_pairs
+
+        top_k = max(0, int(self.min_predictions))
+        if top_k == 0:
+            return []
+        scored_candidates.sort(key=lambda item: item[0], reverse=True)
+        return [
+            {"parent": parent_term, "child": child_term}
+            for (_prob, parent_term, child_term, _pred) in scored_candidates[:top_k]
+        ]
+
+    def _ensure_loaded_for_inference(self) -> None:
+        """
+        Load model and tokenizer from `self.output_dir` if not already loaded.
+
+        Side Effects:
+            - Sets `self.model` and `self.tokenizer`.
+            - Moves the model to `self.device`.
+            - Ensures `tokenizer.pad_token_id` is set if model config provides one.
+        """
+        if self.model is not None and self.tokenizer is not None:
+            return
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            self.output_dir
+        ).to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.output_dir)
+        if (
+            self.tokenizer.pad_token_id is None
+            and getattr(self.model.config, "pad_token_id", None) is not None
+        ):
+            self.tokenizer.pad_token_id = self.model.config.pad_token_id
+
+    def _load_context_map(self) -> None:
+        """
+        Populate in-memory maps from the context JSON (`self.context_json_path`).
+
+        Builds:
+            - `_context_exact`: dict mapping lowercased term → term_info.
+            - `_context_rows`: list of dict rows with 'term' and 'term_info'.
+
+        If `context_json_path` is falsy or loading fails, both structures become empty.
+        """
+        if not self.context_json_path:
+            self._context_exact = {}
+            self._context_rows = []
+            return
+        try:
+            rows = json.load(open(self.context_json_path, "r", encoding="utf-8"))
+            self._context_exact = {
+                str(row.get("term", "")).strip().lower(): str(
+                    row.get("term_info", "")
+                ).strip()
+                for row in rows
+            }
+            self._context_rows = [
+                {
+                    "term": str(row.get("term", "")),
+                    "term_info": str(row.get("term_info", "")),
+                }
+                for row in rows
+            ]
+        except Exception:
+            self._context_exact = {}
+            self._context_rows = []
+
+    def _lookup_context_info(self, raw_term: str) -> str:
+        """
+        Retrieve textual context for a term using exact and simple fuzzy matching.
+
+        - Exact: lowercased term lookup in `_context_exact`.
+        - Fuzzy: split `raw_term` by commas, strip whitespace; treat each piece
+          as a case-insensitive substring against row['term'].
+
+        Args:
+            raw_term: Original term string (possibly comma-separated).
+
+        Returns:
+            str: Concatenated matches' term_info ('.' joined). Empty string if none.
+        """
+        if not raw_term:
+            return ""
+        term_key = raw_term.strip().lower()
+        if term_key in self._context_exact:
+            return self._context_exact[term_key]
+
+        subterms = [re.sub(r"\s+", "", piece) for piece in raw_term.split(",")]
+        matched_infos: List[str] = []
+        for subterm in subterms:
+            if not subterm:
+                continue
+            lower_subterm = subterm.lower()
+            for row in self._context_rows:
+                if lower_subterm in row["term"].lower():
+                    info = row.get("term_info", "")
+                    if info:
+                        matched_infos.append(info)
+                    break  # one hit per subterm
+        return ".".join(matched_infos)
+
+    def _extract_positive_pairs(self, ontology_obj: Any) -> List[Tuple[str, str]]:
+        """
+        Extract positive (parent, child) edges from an ontology-like object.
+
+        Reads from `ontology_obj.type_taxonomies.taxonomies` (preferred) or
+        falls back to `ontology_obj.taxonomies`. Each item must expose `parent`
+        and `child` as attributes or dict keys.
+
+        Returns:
+            list[tuple[str, str]]: (parent, child) pairs (may be empty).
+        """
+        type_taxonomies = getattr(ontology_obj, "type_taxonomies", None)
+        items = (
+            getattr(type_taxonomies, "taxonomies", None)
+            if type_taxonomies is not None
+            else getattr(ontology_obj, "taxonomies", None)
+        )
+        pairs: List[Tuple[str, str]] = []
+        if items:
+            for item in items:
+                parent_term = (
+                    getattr(item, "parent", None)
+                    if not isinstance(item, dict)
+                    else item.get("parent")
+                )
+                child_term = (
+                    getattr(item, "child", None)
+                    if not isinstance(item, dict)
+                    else item.get("child")
+                )
+                if parent_term and child_term:
+                    pairs.append((str(parent_term), str(child_term)))
+        return pairs
+
+    def _extract_pairs_for_eval(self, ontology_obj: Any) -> List[Tuple[str, str]]:
+        """
+        Extract candidate pairs for evaluation.
+
+        Prefers `ontology_obj.pairs` if present; otherwise falls back to the
+        positive pairs from the ontology (see `_extract_positive_pairs`).
+
+        Returns:
+            list[tuple[str, str]]: Candidate (parent, child) pairs.
+        """
+        candidate_pairs = getattr(ontology_obj, "pairs", None)
+        if candidate_pairs:
+            pairs: List[Tuple[str, str]] = []
+            for item in candidate_pairs:
+                parent_term = (
+                    getattr(item, "parent", None)
+                    if not isinstance(item, dict)
+                    else item.get("parent")
+                )
+                child_term = (
+                    getattr(item, "child", None)
+                    if not isinstance(item, dict)
+                    else item.get("child")
+                )
+                if parent_term and child_term:
+                    pairs.append((str(parent_term), str(child_term)))
+            return pairs
+        return self._extract_positive_pairs(ontology_obj)
+
+    def _generate_negatives(
+        self,
+        positives: List[Tuple[str, str]],
+        entities: List[str],
+        ratio: int,
+    ) -> List[Tuple[str, str]]:
+        """
+        Sample negative edges by excluding known positives and self-pairs.
+
+        Constructs the cartesian product of entities (excluding (x, x)),
+        removes all known positives, and samples up to `ratio * len(positives)`
+        negatives uniformly at random.
+
+        Args:
+            positives: Known positive edges.
+            entities: Unique set/list of entity terms.
+            ratio: Target negatives per positive (lower-bounded by 1×).
+
+        Returns:
+            list[tuple[str, str]]: Sampled negative pairs (may be smaller).
+        """
+        positive_set = set(positives)
+        all_possible = {
+            (parent, child)
+            for parent in entities
+            for child in entities
+            if parent != child
+        }
+        negative_candidates = list(all_possible - positive_set)
+
+        target_count = max(len(positive_set) * max(1, ratio), len(positive_set))
+        sample_count = min(target_count, len(negative_candidates))
+        return (
+            random.sample(negative_candidates, k=sample_count)
+            if sample_count > 0
+            else []
+        )
+
+    def _build_text_dataset(
+        self,
+        positives: List[Tuple[str, str]],
+        negatives: List[Tuple[str, str]],
+    ) -> Tuple[List[int], List[str]]:
+        """
+        Create parallel lists of labels and input texts for HF Datasets.
+
+        Builds formatted inputs using `_format_input`, and duplicates examples in
+        the reverse direction if `bidirectional_templates` is True.
+
+        Returns:
+            tuple[list[int], list[str]]: (labels, input_texts) where labels are
+            1 for positive and 0 for negative.
+        """
+        self._load_context_map()
+
+        labels: List[int] = []
+        input_texts: List[str] = []
+
+        def add_example(parent_term: str, child_term: str, label_value: int) -> None:
+            """Append one (and optionally reversed) example to the dataset."""
+            input_texts.append(self._format_input(parent_term, child_term))
+            labels.append(label_value)
+            if self.bidirectional_templates:
+                input_texts.append(
+                    self._format_input(child_term, parent_term, reverse=True)
+                )
+                labels.append(label_value)
+
+        for parent_term, child_term in positives:
+            add_example(parent_term, child_term, 1)
+        for parent_term, child_term in negatives:
+            add_example(parent_term, child_term, 0)
+
+        return labels, input_texts
+
+    def _format_input(
+        self, parent_term: str, child_term: str, reverse: bool = False
+    ) -> str:
+        """
+        Format a (parent, child) pair into relation text + optional context.
+
+        Returns:
+            str: "<relation template> [## Context. 'parent': ... 'child': ...]"
+        """
+        relation_text = (
+            f"{child_term} is a subclass / child / subtype / descendant class of {parent_term}"
+            if reverse
+            else f"{parent_term} is the superclass / parent / supertype / ancestor class of {child_term}"
+        )
+
+        parent_info = self._lookup_context_info(parent_term)
+        child_info = self._lookup_context_info(child_term)
+        if not parent_info and not child_info:
+            return relation_text
+
+        context_text = (
+            f"## Context. '{parent_term}': {parent_info} '{child_term}': {child_info}"
+        )
+        return f"{relation_text} {context_text}"
+
+    def _fill_bucket_threaded(
+        self, bucket_rows: List[dict], output_path: Path, provider: Callable[[str], str]
+    ) -> None:
+        """
+        Populate a shard with provider-generated `term_info` using threads.
+
+        Resumes from `output_path` if it already exists, periodically writes
+        progress (every ~10 items), and finally dumps the full bucket to disk.
+        """
+        start_index = 0
+        try:
+            if output_path.is_file():
+                existing_rows = json.load(open(output_path, "r", encoding="utf-8"))
+                if isinstance(existing_rows, list) and existing_rows:
+                    bucket_rows[: len(existing_rows)] = existing_rows
+                    start_index = len(existing_rows)
+        except Exception:
+            pass
+
+        for row_index in range(start_index, len(bucket_rows)):
+            try:
+                bucket_rows[row_index]["term_info"] = provider(
+                    bucket_rows[row_index]["term"]
+                )
+            except Exception:
+                bucket_rows[row_index]["term_info"] = ""
+            if row_index % 10 == 1:
+                json.dump(
+                    bucket_rows[: row_index + 1],
+                    open(output_path, "w", encoding="utf-8"),
+                    ensure_ascii=False,
+                    indent=2,
+                )
+
+        json.dump(
+            bucket_rows,
+            open(output_path, "w", encoding="utf-8"),
+            ensure_ascii=False,
+            indent=2,
+        )
+
+    def _merge_part_files(
+        self, dataset_name: str, merged_path: Path, shard_paths: List[Path]
+    ) -> None:
+        """
+        Merge shard files into one JSON and filter boilerplate sentences.
+
+        - Reads shard lists/dicts from `shard_paths`.
+        - Drops sentences that contain markers in `_CONTEXT_REMOVALS` or the
+          `dataset_name` string.
+        - Normalizes the remaining text via `_normalize_text`.
+        - Writes merged JSON to `merged_path`, then best-effort deletes shards.
+        """
+        merged_rows: List[dict] = []
+        for shard_path in shard_paths:
+            try:
+                if not shard_path.is_file():
+                    continue
+                part_content = json.load(open(shard_path, "r", encoding="utf-8"))
+                if isinstance(part_content, list):
+                    merged_rows.extend(part_content)
+                elif isinstance(part_content, dict):
+                    merged_rows.append(part_content)
+            except Exception:
+                continue
+
+        removal_markers = list(self._CONTEXT_REMOVALS) + [dataset_name]
+        for row in merged_rows:
+            term_info_raw = str(row.get("term_info", ""))
+            kept_sentences: List[str] = []
+            for sentence in term_info_raw.split("."):
+                sentence_no_links = re.sub(
+                    r"\[\[\d+\]\]\(https?://[^\)]+\)", "", sentence
+                )
+                if any(marker in sentence_no_links for marker in removal_markers):
+                    continue
+                kept_sentences.append(sentence_no_links)
+            row["term_info"] = self._normalize_text(
+                ".".join(kept_sentences), drop_questions=False
+            )
+
+        merged_path.parent.mkdir(parents=True, exist_ok=True)
+        json.dump(
+            merged_rows,
+            open(merged_path, "w", encoding="utf-8"),
+            ensure_ascii=False,
+            indent=4,
+        )
+
+        # best-effort cleanup
+        for shard_path in shard_paths:
+            try:
+                os.remove(shard_path)
+            except Exception:
+                pass
+
+    def _execute_for_terms(
+        self,
+        terms: List[str],
+        merged_path: Path,
+        shard_paths: List[Path],
+        provider: Callable[[str], str],
+        dataset_name: str,
+        num_workers: int = 2,
+    ) -> None:
+        """
+        Generate context for `terms`, writing shards to `shard_paths`, then merge.
+
+        Always uses threads (pickling-safe for instance methods).
+        Shows a tqdm progress bar and merges shards at the end.
+        """
+        worker_count = max(1, min(num_workers, os.cpu_count() or 2, 4))
+        all_rows = [
+            {"id": index, "term": term, "term_info": ""}
+            for index, term in enumerate(terms)
+        ]
+
+        buckets: List[List[dict]] = [[] for _ in range(worker_count)]
+        for reversed_index, row in enumerate(reversed(all_rows)):
+            buckets[reversed_index % worker_count].append(row)
+
+        total_rows = len(terms)
+        progress_bar = tqdm(
+            total=total_rows, desc=f"{dataset_name} generation (threads)"
+        )
+
+        def run_bucket(bucket_rows: List[dict], out_path: Path) -> int:
+            self._fill_bucket_threaded(bucket_rows, out_path, provider)
+            return len(bucket_rows)
+
+        with ThreadPoolExecutor(max_workers=worker_count) as pool:
+            futures = [
+                pool.submit(
+                    run_bucket, buckets[bucket_index], shard_paths[bucket_index]
+                )
+                for bucket_index in range(worker_count)
+            ]
+            for future in as_completed(futures):
+                completed_count = future.result()
+                if progress_bar:
+                    progress_bar.update(completed_count)
+        if progress_bar:
+            progress_bar.close()
+
+        self._merge_part_files(dataset_name, merged_path, shard_paths)
+
+    def _re_infer_short_entries(
+        self,
+        merged_path: Path,
+        re_shard_paths: List[Path],
+        re_merged_path: Path,
+        provider: Callable[[str], str],
+        dataset_name: str,
+        num_workers: int,
+    ) -> int:
+        """
+        Re-query terms whose `term_info` is too short (< 50 chars).
+
+        Process:
+            - Read `merged_path`.
+            - Filter boilerplate using `_CONTEXT_REMOVALS` and `dataset_name`.
+            - Split into short/long groups by length 50.
+            - Regenerate short group with `provider` in parallel (threads).
+            - Merge regenerated + long back into `merged_path`.
+
+        Returns:
+            int: Count of rows still < 50 chars after re-inference.
+        """
+        merged_rows = json.load(open(merged_path, "r", encoding="utf-8"))
+
+        removal_markers = list(self._CONTEXT_REMOVALS) + [dataset_name]
+        short_rows: List[dict] = []
+        long_rows: List[dict] = []
+
+        for row in merged_rows:
+            term_info_raw = str(row.get("term_info", ""))
+            sentences = term_info_raw.split(".")
+            for marker in removal_markers:
+                sentences = [
+                    sentence if marker not in sentence else "" for sentence in sentences
+                ]
+            filtered_info = self._normalize_text(
+                ".".join(sentences), drop_questions=False
+            )
+            row["term_info"] = filtered_info
+
+            (short_rows if len(filtered_info) < 50 else long_rows).append(row)
+
+        worker_count = max(1, min(num_workers, os.cpu_count() or 2, 4))
+        buckets: List[List[dict]] = [[] for _ in range(worker_count)]
+        for row_index, row in enumerate(short_rows):
+            buckets[row_index % worker_count].append(row)
+
+        # Clean old re-inference shards
+        for path in re_shard_paths:
+            try:
+                os.remove(path)
+            except Exception:
+                pass
+
+        total_candidates = len(short_rows)
+        progress_bar = tqdm(
+            total=total_candidates, desc=f"{dataset_name} re-inference (threads)"
+        )
+
+        def run_bucket(bucket_rows: List[dict], out_path: Path) -> int:
+            self._fill_bucket_threaded(bucket_rows, out_path, provider)
+            return len(bucket_rows)
+
+        with ThreadPoolExecutor(max_workers=worker_count) as pool:
+            futures = [
+                pool.submit(
+                    run_bucket, buckets[bucket_index], re_shard_paths[bucket_index]
+                )
+                for bucket_index in range(worker_count)
+            ]
+            for future in as_completed(futures):
+                completed_count = future.result()
+                if progress_bar:
+                    progress_bar.update(completed_count)
+        if progress_bar:
+            progress_bar.close()
+
+        # Merge and write back
+        self._merge_part_files(dataset_name, re_merged_path, re_shard_paths)
+        new_rows = (
+            json.load(open(re_merged_path, "r", encoding="utf-8"))
+            if re_merged_path.is_file()
+            else []
+        )
+        final_rows = long_rows + new_rows
+        json.dump(
+            final_rows,
+            open(merged_path, "w", encoding="utf-8"),
+            ensure_ascii=False,
+            indent=4,
+        )
+
+        remaining_short = sum(
+            1 for row in final_rows if len(str(row.get("term_info", ""))) < 50
+        )
+        return remaining_short
+
+    def _extract_terms_from_ontology(self, ontology: Any) -> List[str]:
+        """
+        Collect unique term names from `ontology.type_taxonomies.taxonomies`,
+        falling back to `ontology.taxonomies` if needed.
+
+        Returns:
+            list[str]: Sorted unique term list.
+        """
+        type_taxonomies = getattr(ontology, "type_taxonomies", None)
+        taxonomies = (
+            getattr(type_taxonomies, "taxonomies", None)
+            if type_taxonomies is not None
+            else getattr(ontology, "taxonomies", None)
+        )
+        unique_terms: set[str] = set()
+        if taxonomies:
+            for row in taxonomies:
+                parent_term = (
+                    getattr(row, "parent", None)
+                    if not isinstance(row, dict)
+                    else row.get("parent")
+                )
+                child_term = (
+                    getattr(row, "child", None)
+                    if not isinstance(row, dict)
+                    else row.get("child")
+                )
+                if parent_term:
+                    unique_terms.add(str(parent_term))
+                if child_term:
+                    unique_terms.add(str(child_term))
+        return sorted(unique_terms)
+
+    def preprocess_context_from_ontology(
+        self,
+        ontology: Any,
+        processed_dir: str | Path,
+        dataset_name: str = "GeoNames",
+        num_workers: int = 2,
+        provider: Optional[Callable[[str], str]] = None,
+        max_retries: int = 5,
+    ) -> Path:
+        """
+        Build `{id, term, term_info}` rows from an ontology object.
+
+        Always regenerates the fixed-name file `rwthdbis_onto_processed.json`,
+        performing:
+            - Parallel generation of term_info in shards (`_execute_for_terms`),
+            - Re-inference rounds for short entries (`_re_infer_short_entries`),
+            - Final merge and cleanup,
+            - Updates `self.context_json_path`.
+
+        Filenames under `processed_dir`:
+            - merged: `rwthdbis_onto_processed.json`
+            - shards: `rwthdbis_onto_type_part{idx}.json`
+            - re-infer shards: `rwthdbis_onto_re_inference{idx}.json`
+            - re-infer merged: `rwthdbis_onto_Types_re_inference.json`
+
+        Returns:
+            Path: The merged context JSON path (`rwthdbis_onto_processed.json`).
+        """
+        provider = provider or partial(
+            self._default_gpt_inference_with_dataset, dataset_name=dataset_name
+        )
+
+        processed_dir = Path(processed_dir)
+        processed_dir.mkdir(parents=True, exist_ok=True)
+
+        merged_path = processed_dir / "rwthdbis_onto_processed.json"
+        if merged_path.exists():
+            try:
+                merged_path.unlink()
+            except Exception:
+                pass
+
+        worker_count = max(1, min(num_workers, os.cpu_count() or 2, 4))
+        shard_paths = [
+            processed_dir / f"rwthdbis_onto_type_part{index}.json"
+            for index in range(worker_count)
+        ]
+        re_shard_paths = [
+            processed_dir / f"rwthdbis_onto_re_inference{index}.json"
+            for index in range(worker_count)
+        ]
+        re_merged_path = processed_dir / "rwthdbis_onto_Types_re_inference.json"
+
+        # Remove any leftover shards
+        for path in shard_paths + re_shard_paths + [re_merged_path]:
+            try:
+                if path.exists():
+                    path.unlink()
+            except Exception:
+                pass
+
+        unique_terms = self._extract_terms_from_ontology(ontology)
+        print(f"[Preprocess] Unique terms from ontology: {len(unique_terms)}")
+
+        self._execute_for_terms(
+            terms=unique_terms,
+            merged_path=merged_path,
+            shard_paths=shard_paths,
+            provider=provider,
+            dataset_name=dataset_name,
+            num_workers=worker_count,
+        )
+
+        retry_round = 0
+        while retry_round < max_retries:
+            remaining_count = self._re_infer_short_entries(
+                merged_path=merged_path,
+                re_shard_paths=re_shard_paths,
+                re_merged_path=re_merged_path,
+                provider=provider,
+                dataset_name=dataset_name,
+                num_workers=worker_count,
+            )
+            print(
+                f"[Preprocess] Re-infer round {retry_round + 1} done. Remaining short entries: {remaining_count}"
+            )
+            retry_round += 1
+            if remaining_count == 0:
+                break
+
+        print(f"[Preprocess] Done. Merged context at: {merged_path}")
+        self.context_json_path = str(merged_path)
+        return merged_path