OntoLearner 1.4.7__py3-none-any.whl → 1.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1082 @@
1
+ # Copyright (c) 2025 SciKnowOrg
2
+ #
3
+ # Licensed under the MIT License (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://opensource.org/licenses/MIT
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ import random
18
+ import re
19
+ import platform
20
+ from concurrent.futures import ThreadPoolExecutor, as_completed
21
+ from pathlib import Path
22
+ from typing import Any, Dict, List, Optional, Tuple, Callable
23
+ from functools import partial
24
+ from tqdm.auto import tqdm
25
+ import g4f
26
+ from g4f.client import Client as _G4FClient
27
+ import torch
28
+ from datasets import Dataset, DatasetDict
29
+ from transformers import (
30
+ AutoTokenizer,
31
+ AutoModelForSequenceClassification,
32
+ DataCollatorWithPadding,
33
+ Trainer,
34
+ TrainingArguments,
35
+ set_seed,
36
+ )
37
+
38
+ from ...base import AutoLearner
39
+
40
+
41
+ class RWTHDBISSFTLearner(AutoLearner):
42
+ """
43
+ Supervised classifier for (parent, child) taxonomy edges.
44
+
45
+ Model input format:
46
+ "<relation template> ## <optional context>"
47
+
48
+ Context building:
49
+ If no `context_json_path` is provided, the learner precomputes a fixed-name
50
+ context file `rwthdbis_onto_processed.json` under `output_dir/context/`
51
+ from the ontology terms and stores the path in `self.context_json_path`.
52
+
53
+ Attributes:
54
+ model_name: Hugging Face model identifier.
55
+ output_dir: Directory where checkpoints and tokenizer are saved/loaded.
56
+ min_predictions: If no candidate is predicted positive, return the top-k
57
+ by positive probability (k = min_predictions).
58
+ max_length: Maximum tokenized length for inputs.
59
+ per_device_train_batch_size: Micro-batch size per device.
60
+ gradient_accumulation_steps: Gradient accumulation steps.
61
+ num_train_epochs: Number of training epochs.
62
+ learning_rate: Optimizer LR.
63
+ weight_decay: Weight decay for AdamW.
64
+ logging_steps: Logging interval for Trainer.
65
+ save_strategy: HF saving strategy (e.g., 'epoch').
66
+ save_total_limit: Max checkpoints to keep.
67
+ fp16: Enable FP16 mixed precision.
68
+ bf16: Enable BF16 mixed precision (on supported hardware).
69
+ seed: Random seed for reproducibility.
70
+ negative_ratio: Number of negatives per positive during training.
71
+ bidirectional_templates: If True, also add reversed template examples.
72
+ context_json_path: Path to the preprocessed term-context JSON. If None,
73
+ the file is generated with the fixed prefix `rwthdbis_onto_*`.
74
+ ontology_name: Logical dataset/domain label used in prompts and filtering
75
+ (filenames still use the fixed `rwthdbis_onto_*` prefix).
76
+ device: Device to run the model on, e.g. 'cuda' or 'cpu'.
77
+ model: Loaded/initialized `AutoModelForSequenceClassification`.
78
+ tokenizer: Loaded/initialized `AutoTokenizer`.
79
+ """
80
+
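For orientation, a minimal usage sketch (the public `AutoLearner` entry point is not shown in this file, so the internal hook defined below is called directly; `train_ontology` / `eval_ontology` are assumed ontology-like objects as described in the docstring):

    learner = RWTHDBISSFTLearner(
        model_name="distilroberta-base",
        output_dir="./results/taxonomy-discovery",
        device="cpu",
        ontology_name="Geonames",
    )
    learner._taxonomy_discovery(train_ontology, test=False)        # train and save to output_dir
    edges = learner._taxonomy_discovery(eval_ontology, test=True)  # [{'parent': ..., 'child': ...}, ...]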
81
+ # Sentences containing any of these phrases are pruned from term_info.
82
+ _CONTEXT_REMOVALS = [
83
+ "couldn't find any",
84
+ "does not require",
85
+ "assist you further",
86
+ "feel free to",
87
+ "I'm currently unable",
88
+ "the search results",
89
+ "I'm unable to",
90
+ "recommend referring directly",
91
+ "bear with me",
92
+ "searching for the most relevant information",
93
+ "I'm currently checking the most relevant",
94
+ "already in English",
95
+ "require further",
96
+ "any additional information",
97
+ "already an English",
98
+ "don't have information",
99
+ "I'm sorry,",
100
+ "For further exploration",
101
+ "For more detailed information",
102
+ ]
103
+
104
+ def __init__(
105
+ self,
106
+ min_predictions: int = 1,
107
+ model_name: str = "distilroberta-base",
108
+ output_dir: str = "./results/taxonomy-discovery",
109
+ device: str = "cpu",
110
+ max_length: int = 256,
111
+ per_device_train_batch_size: int = 8,
112
+ gradient_accumulation_steps: int = 4,
113
+ num_train_epochs: int = 1,
114
+ learning_rate: float = 2e-5,
115
+ weight_decay: float = 0.01,
116
+ logging_steps: int = 25,
117
+ save_strategy: str = "epoch",
118
+ save_total_limit: int = 1,
119
+ fp16: bool = True,
120
+ bf16: bool = False,
121
+ seed: int = 42,
122
+ negative_ratio: int = 5,
123
+ bidirectional_templates: bool = True,
124
+ context_json_path: Optional[str] = None,
125
+ ontology_name: str = "Geonames",
126
+ ) -> None:
127
+ """
128
+ Initialize the taxonomy-edge learner and set training/inference knobs.
129
+
130
+ Notes:
131
+ - Output artifacts are written under `output_dir`, including
132
+ the model weights and tokenizer (for later `from_pretrained` loads).
133
+ - If `context_json_path` is not provided, a new context file named
134
+ `rwthdbis_onto_processed.json` is generated under `output_dir/context/`.
135
+ """
136
+ super().__init__()
137
+
138
+ self.model_name = model_name
139
+ safe_model_name = model_name.replace("/", "__")
140
+
141
+ resolved_output = output_dir.format(model_name=safe_model_name)
142
+ self.output_dir = str(Path(resolved_output))
143
+ Path(self.output_dir).mkdir(parents=True, exist_ok=True)
144
+
145
+ # Store provided argument values as-is (types are enforced by callers).
146
+ self.min_predictions = min_predictions
147
+ self.max_length = max_length
148
+ self.per_device_train_batch_size = per_device_train_batch_size
149
+ self.gradient_accumulation_steps = gradient_accumulation_steps
150
+ self.num_train_epochs = num_train_epochs
151
+ self.learning_rate = learning_rate
152
+ self.weight_decay = weight_decay
153
+ self.logging_steps = logging_steps
154
+ self.save_strategy = save_strategy
155
+ self.save_total_limit = save_total_limit
156
+ self.fp16 = fp16
157
+ self.bf16 = bf16
158
+ self.seed = seed
159
+
160
+ self.negative_ratio = negative_ratio
161
+ self.bidirectional_templates = bidirectional_templates
162
+ self.context_json_path = context_json_path
163
+
164
+ self.ontology_name = ontology_name
165
+ self.device = device
166
+ self.model: Optional[AutoModelForSequenceClassification] = None
167
+ self.tokenizer: Optional[AutoTokenizer] = None
168
+
169
+ # Context caches built from the context JSON.
170
+ self._context_exact: Dict[str, str] = {} # lower(term) -> info
171
+ self._context_rows: List[
172
+ Dict[str, str]
173
+ ] = [] # [{'term': str, 'term_info': str}, ...]
174
+
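Note that `output_dir` may contain a `{model_name}` placeholder; slashes in the model id are replaced with `__` before formatting. A small sketch of the resulting path:

    learner = RWTHDBISSFTLearner(
        model_name="FacebookAI/roberta-base",
        output_dir="./results/{model_name}",
    )
    # checkpoints and the tokenizer are written under ./results/FacebookAI__roberta-base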
175
+ def _is_windows(self) -> bool:
176
+ """Return True if the current OS is Windows (NT)."""
177
+ return (os.name == "nt") or (platform.system().lower() == "windows")
178
+
179
+ def _normalize_text(self, raw_text: str, *, drop_questions: bool = False) -> str:
180
+ """
181
+ Normalize plain text consistently across the pipeline.
182
+
183
+ Operations:
184
+ - Remove markdown-like link patterns (e.g., '[[1]](http://...)').
185
+ - Replace newlines with spaces; collapse repeated spaces.
186
+ - Optionally drop sentences containing '?' (useful for model generations).
187
+
188
+ Args:
189
+ raw_text: Input text to normalize.
190
+ drop_questions: If True, filter out sentences with '?'.
191
+
192
+ Returns:
193
+ str: Cleaned single-line string.
194
+ """
195
+ if raw_text is None:
196
+ return ""
197
+ text = str(raw_text)
198
+
199
+ # Remove simple markdown link artifacts like [[1]](http://...)
200
+ text = re.sub(r"\[\[\d+\]\]\(https?://[^\)]+\)", "", text)
201
+
202
+ # Replace newlines with spaces and collapse multiple spaces
203
+ text = text.replace("\n", " ")
204
+ text = re.sub(r"\s{2,}", " ", text)
205
+
206
+ if drop_questions:
207
+ sentences = [s.strip() for s in text.split(".")]
208
+ sentences = [s for s in sentences if s and "?" not in s]
209
+ text = ". ".join(sentences)
210
+
211
+ return text.strip()
212
+
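A quick illustration of the normalization above (illustrative input):

    raw = "GeoNames[[1]](https://example.org) is a database.\nIs it free? Yes, it is."
    learner._normalize_text(raw, drop_questions=True)
    # link artifact removed, newline flattened, and the sentence containing '?' dropped:
    # -> "GeoNames is a database"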
213
+ def _default_gpt_inference_with_dataset(self, term: str, dataset_name: str) -> str:
214
+ """
215
+ Generate a plain-text description for `term`, conditioned on `dataset_name`,
216
+ via g4f (best-effort). Falls back to an empty string on failure.
217
+
218
+ The raw output is then normalized with `_normalize_text(drop_questions=True)`.
219
+
220
+ Args:
221
+ term: Term to describe.
222
+ dataset_name: Ontology/domain name used in the prompt.
223
+
224
+ Returns:
225
+ str: Cleaned paragraph describing the term, or "" on failure.
226
+ """
227
+ prompt = (
228
+ f"Here is a: {term}, which is of domain name :{dataset_name}, translate it into english, "
229
+ "Provide as detailed a definition of this term as possible in plain text.without any markdown format."
230
+ "No reference link in result. "
231
+ "- Focus on intrinsic properties; do not name other entities or explicit relationships.\n"
232
+ "- Include classification/type, defining features, scope/scale, roles/functions, and measurable attributes when applicable.\n"
233
+ "Output: Plain text paragraphs only, neutral and factual."
234
+ f"Make sure all provided information can be used for discovering implicit relation of other {dataset_name} term, but don't mention the relation in result."
235
+ )
236
+
237
+ try:
238
+ client = _G4FClient()
239
+ response = client.chat.completions.create(
240
+ model=g4f.models.default,
241
+ messages=[{"role": "user", "content": prompt}],
242
+ )
243
+ raw_text = (
244
+ response.choices[0].message.content
245
+ if response and response.choices
246
+ else ""
247
+ )
248
+ except Exception:
249
+ raw_text = "" # best-effort fallback
250
+
251
+ return self._normalize_text(raw_text, drop_questions=True)
252
+
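The g4f call above is best-effort and non-deterministic; any `Callable[[str], str]` can be swapped in as the provider instead. A sketch with a hypothetical offline lookup:

    def offline_provider(term: str) -> str:
        # Hypothetical local descriptions replacing the g4f call.
        descriptions = {
            "lake": "An inland body of standing water, typically freshwater.",
            "river": "A natural flowing watercourse draining toward a sea or lake.",
        }
        return descriptions.get(term.lower(), "")

    learner.preprocess_context_from_ontology(
        ontology=train_ontology,
        processed_dir="./results/taxonomy-discovery/context",
        dataset_name="Geonames",
        provider=offline_provider,
    )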
253
+ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
254
+ """
255
+ AutoLearner hook: route to training or prediction.
256
+
257
+ Args:
258
+ data: Ontology-like object (has `.taxonomies` or `.type_taxonomies.taxonomies`).
259
+ test: If True, run inference; otherwise, train a model.
260
+
261
+ Returns:
262
+ If test=True, a list of accepted edges as dicts with keys `parent` and `child`;
263
+ otherwise None.
264
+ """
265
+ return self._predict_pairs(data) if test else self._train_from_pairs(data)
266
+
267
+ def _train_from_pairs(self, train_data: Any) -> None:
268
+ """
269
+ Train a binary classifier from ontology pairs.
270
+
271
+ Steps:
272
+ 1) (Re)build the term-context JSON unless `context_json_path` is set.
273
+ 2) Extract positive (parent, child) edges from `train_data`.
274
+ 3) Sample negatives at `negative_ratio`.
275
+ 4) Tokenize, instantiate HF Trainer, train, and save.
276
+
277
+ Args:
278
+ train_data: Ontology-like object with `.type_taxonomies.taxonomies`
279
+ (preferred) or `.taxonomies`, each item providing `parent` and `child`.
280
+
281
+ Raises:
282
+ ValueError: If no positive pairs are found.
283
+
284
+ Side Effects:
285
+ - Writes a trained model to `self.output_dir` (via `trainer.save_model`).
286
+ - Writes the tokenizer to `self.output_dir` (via `save_pretrained`).
287
+ - Sets `self.context_json_path` if it was previously unset.
288
+ The generated context file is named `rwthdbis_onto_processed.json`.
289
+ """
290
+ # Always (re)build context from ontology unless an explicit file is provided
291
+ if not self.context_json_path:
292
+ context_dir = Path(self.output_dir) / "context"
293
+ context_dir.mkdir(parents=True, exist_ok=True)
294
+ processed_context_file = context_dir / "rwthdbis_onto_processed.json"
295
+
296
+ # Remove stale file then regenerate
297
+ if processed_context_file.exists():
298
+ try:
299
+ processed_context_file.unlink()
300
+ except Exception:
301
+ pass
302
+
303
+ self.preprocess_context_from_ontology(
304
+ ontology=train_data,
305
+ processed_dir=context_dir,
306
+ dataset_name=self.ontology_name,
307
+ num_workers=max(1, min(os.cpu_count() or 2, 4)),
308
+ provider=partial(
309
+ self._default_gpt_inference_with_dataset,
310
+ dataset_name=self.ontology_name,
311
+ ),
312
+ max_retries=5,
313
+ )
314
+ self.context_json_path = str(processed_context_file)
315
+
316
+ # Reproducibility
317
+ set_seed(self.seed)
318
+ random.seed(self.seed)
319
+ torch.manual_seed(self.seed)
320
+ if torch.cuda.is_available():
321
+ torch.cuda.manual_seed_all(self.seed)
322
+
323
+ # Build labeled pairs from ontology; context comes from preprocessed map
324
+ positive_pairs = self._extract_positive_pairs(train_data)
325
+ if not positive_pairs:
326
+ raise ValueError("No positive (parent, child) pairs found in train_data.")
327
+
328
+ entity_names = sorted(
329
+ {parent for parent, _ in positive_pairs}
330
+ | {child for _, child in positive_pairs}
331
+ )
332
+ negative_pairs = self._generate_negatives(
333
+ positives=positive_pairs,
334
+ entities=entity_names,
335
+ ratio=self.negative_ratio,
336
+ )
337
+
338
+ labels, input_texts = self._build_text_dataset(positive_pairs, negative_pairs)
339
+ dataset_dict = DatasetDict(
340
+ {"train": Dataset.from_dict({"label": labels, "text": input_texts})}
341
+ )
342
+
343
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
344
+ # Ensure a pad token exists for robust padding across models.
345
+ if self.tokenizer.pad_token is None:
346
+ self.tokenizer.pad_token = (
347
+ getattr(self.tokenizer, "eos_token", None)
348
+ or getattr(self.tokenizer, "sep_token", None)
349
+ or getattr(self.tokenizer, "cls_token", None)
350
+ )
351
+
352
+ def tokenize_batch(batch: Dict[str, List[str]]):
353
+ """Tokenize a batch of input texts for HF Datasets mapping."""
354
+ return self.tokenizer(
355
+ batch["text"], truncation=True, max_length=self.max_length
356
+ )
357
+
358
+ tokenized_dataset = dataset_dict.map(
359
+ tokenize_batch, batched=True, remove_columns=["text"]
360
+ )
361
+ data_collator = DataCollatorWithPadding(self.tokenizer)
362
+
363
+ self.model = AutoModelForSequenceClassification.from_pretrained(
364
+ self.model_name,
365
+ num_labels=2,
366
+ id2label={0: "incorrect", 1: "correct"},
367
+ label2id={"incorrect": 0, "correct": 1},
368
+ )
369
+ # Ensure model has a pad_token_id if tokenizer provides one.
370
+ if (
371
+ getattr(self.model.config, "pad_token_id", None) is None
372
+ and self.tokenizer.pad_token_id is not None
373
+ ):
374
+ self.model.config.pad_token_id = self.tokenizer.pad_token_id
375
+
376
+ training_args = TrainingArguments(
377
+ output_dir=self.output_dir,
378
+ learning_rate=self.learning_rate,
379
+ per_device_train_batch_size=self.per_device_train_batch_size,
380
+ gradient_accumulation_steps=self.gradient_accumulation_steps,
381
+ num_train_epochs=self.num_train_epochs,
382
+ weight_decay=self.weight_decay,
383
+ save_strategy=self.save_strategy,
384
+ save_total_limit=self.save_total_limit,
385
+ logging_steps=self.logging_steps,
386
+ dataloader_pin_memory=bool(torch.cuda.is_available()),
387
+ fp16=self.fp16 and torch.cuda.is_available(),  # FP16 requires CUDA; fall back to FP32 on CPU-only hosts.
388
+ bf16=self.bf16,
389
+ report_to="none",
390
+ save_safetensors=True,
391
+ )
392
+
393
+ trainer = Trainer(
394
+ model=self.model,
395
+ args=training_args,
396
+ train_dataset=tokenized_dataset["train"],
397
+ tokenizer=self.tokenizer,
398
+ data_collator=data_collator,
399
+ )
400
+ trainer.train()
401
+ trainer.save_model()
402
+ # Persist tokenizer alongside the model for from_pretrained() loads.
403
+ self.tokenizer.save_pretrained(self.output_dir)
404
+
405
+ def _predict_pairs(self, eval_data: Any) -> List[Dict[str, str]]:
406
+ """
407
+ Score candidate pairs and return those predicted as positive.
408
+
409
+ If no pair is predicted positive but `min_predictions` > 0, the top-k
410
+ pairs by positive probability are returned.
411
+
412
+ Args:
413
+ eval_data: Ontology-like object with either `.pairs` (preferred) or
414
+ `.type_taxonomies.taxonomies` / `.taxonomies`.
415
+
416
+ Returns:
417
+ list[dict]: Each dict has keys `parent` and `child`.
418
+ """
419
+ import torch.nn.functional as F
420
+
421
+ self._ensure_loaded_for_inference()
422
+
423
+ candidate_pairs = self._extract_pairs_for_eval(eval_data)
424
+ if not candidate_pairs:
425
+ return []
426
+
427
+ accepted_pairs: List[Dict[str, str]] = []
428
+ scored_candidates: List[Tuple[float, str, str, int]] = []
429
+
430
+ self.model.eval()
431
+ with torch.no_grad():
432
+ for parent_term, child_term in candidate_pairs:
433
+ input_text = self._format_input(parent_term, child_term)
434
+ inputs = self.tokenizer(
435
+ input_text,
436
+ return_tensors="pt",
437
+ truncation=True,
438
+ max_length=self.max_length,
439
+ )
440
+ inputs = {key: tensor.to(self.device) for key, tensor in inputs.items()}
441
+ logits = self.model(**inputs).logits
442
+ probabilities = F.softmax(logits, dim=-1).squeeze(0)
443
+ p_positive = float(probabilities[1].item())
444
+ predicted_label = int(torch.argmax(logits, dim=-1).item())
445
+ scored_candidates.append(
446
+ (p_positive, parent_term, child_term, predicted_label)
447
+ )
448
+ if predicted_label == 1:
449
+ accepted_pairs.append({"parent": parent_term, "child": child_term})
450
+
451
+ if accepted_pairs:
452
+ return accepted_pairs
453
+
454
+ top_k = max(0, int(self.min_predictions))
455
+ if top_k == 0:
456
+ return []
457
+ scored_candidates.sort(key=lambda item: item[0], reverse=True)
458
+ return [
459
+ {"parent": parent_term, "child": child_term}
460
+ for (_prob, parent_term, child_term, _pred) in scored_candidates[:top_k]
461
+ ]
462
+
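The `min_predictions` fallback means inference only returns an empty list when it is set to 0. Illustrative values:

    # No candidate was classified as positive and min_predictions == 2.
    # scored_candidates, sorted by p_positive (prob, parent, child, predicted_label):
    #   (0.48, "waterbody", "lake", 0)
    #   (0.41, "waterbody", "river", 0)
    #   (0.07, "lake", "waterbody", 0)
    # -> returned edges:
    #   [{"parent": "waterbody", "child": "lake"}, {"parent": "waterbody", "child": "river"}]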
463
+ def _ensure_loaded_for_inference(self) -> None:
464
+ """
465
+ Load model and tokenizer from `self.output_dir` if not already loaded.
466
+
467
+ Side Effects:
468
+ - Sets `self.model` and `self.tokenizer`.
469
+ - Moves the model to `self.device`.
470
+ - Ensures `tokenizer.pad_token_id` is set if model config provides one.
471
+ """
472
+ if self.model is not None and self.tokenizer is not None:
473
+ # The model may already be loaded (e.g. right after training) but live on
+ # another device; make sure it matches `self.device` before returning.
+ self.model.to(self.device)
+ return
474
+ self.model = AutoModelForSequenceClassification.from_pretrained(
475
+ self.output_dir
476
+ ).to(self.device)
477
+ self.tokenizer = AutoTokenizer.from_pretrained(self.output_dir)
478
+ if (
479
+ self.tokenizer.pad_token_id is None
480
+ and getattr(self.model.config, "pad_token_id", None) is not None
481
+ ):
482
+ self.tokenizer.pad_token_id = self.model.config.pad_token_id
483
+
484
+ def _load_context_map(self) -> None:
485
+ """
486
+ Populate in-memory maps from the context JSON (`self.context_json_path`).
487
+
488
+ Builds:
489
+ - `_context_exact`: dict mapping lowercased term → term_info.
490
+ - `_context_rows`: list of dict rows with 'term' and 'term_info'.
491
+
492
+ If `context_json_path` is falsy or loading fails, both structures become empty.
493
+ """
494
+ if not self.context_json_path:
495
+ self._context_exact = {}
496
+ self._context_rows = []
497
+ return
498
+ try:
499
+ rows = json.load(open(self.context_json_path, "r", encoding="utf-8"))
500
+ self._context_exact = {
501
+ str(row.get("term", "")).strip().lower(): str(
502
+ row.get("term_info", "")
503
+ ).strip()
504
+ for row in rows
505
+ }
506
+ self._context_rows = [
507
+ {
508
+ "term": str(row.get("term", "")),
509
+ "term_info": str(row.get("term_info", "")),
510
+ }
511
+ for row in rows
512
+ ]
513
+ except Exception:
514
+ self._context_exact = {}
515
+ self._context_rows = []
516
+
517
+ def _lookup_context_info(self, raw_term: str) -> str:
518
+ """
519
+ Retrieve textual context for a term using exact and simple fuzzy matching.
520
+
521
+ - Exact: lowercased term lookup in `_context_exact`.
522
+ - Fuzzy: split `raw_term` by commas and remove all whitespace from each piece; treat each piece
523
+ as a case-insensitive substring against row['term'].
524
+
525
+ Args:
526
+ raw_term: Original term string (possibly comma-separated).
527
+
528
+ Returns:
529
+ str: Concatenated matches' term_info ('.' joined). Empty string if none.
530
+ """
531
+ if not raw_term:
532
+ return ""
533
+ term_key = raw_term.strip().lower()
534
+ if term_key in self._context_exact:
535
+ return self._context_exact[term_key]
536
+
537
+ subterms = [re.sub(r"\s+", "", piece) for piece in raw_term.split(",")]
538
+ matched_infos: List[str] = []
539
+ for subterm in subterms:
540
+ if not subterm:
541
+ continue
542
+ lower_subterm = subterm.lower()
543
+ for row in self._context_rows:
544
+ if lower_subterm in row["term"].lower():
545
+ info = row.get("term_info", "")
546
+ if info:
547
+ matched_infos.append(info)
548
+ break # one hit per subterm
549
+ return ".".join(matched_infos)
550
+
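To illustrate the fuzzy branch (illustrative context rows):

    # _context_rows = [{"term": "Freshwater Lake", "term_info": "A lake with low salinity"},
    #                  {"term": "Reservoir",       "term_info": "An artificial lake"}]
    # _lookup_context_info("Lake, Reservoir"):
    #   no exact entry for "lake, reservoir", so each comma piece ("Lake", "Reservoir")
    #   is matched as a case-insensitive substring; the first hits are joined with '.':
    #   -> "A lake with low salinity.An artificial lake"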
551
+ def _extract_positive_pairs(self, ontology_obj: Any) -> List[Tuple[str, str]]:
552
+ """
553
+ Extract positive (parent, child) edges from an ontology-like object.
554
+
555
+ Reads from `ontology_obj.type_taxonomies.taxonomies` (preferred) or
556
+ falls back to `ontology_obj.taxonomies`. Each item must expose `parent`
557
+ and `child` as attributes or dict keys.
558
+
559
+ Returns:
560
+ list[tuple[str, str]]: (parent, child) pairs (may be empty).
561
+ """
562
+ type_taxonomies = getattr(ontology_obj, "type_taxonomies", None)
563
+ items = (
564
+ getattr(type_taxonomies, "taxonomies", None)
565
+ if type_taxonomies is not None
566
+ else getattr(ontology_obj, "taxonomies", None)
567
+ )
568
+ pairs: List[Tuple[str, str]] = []
569
+ if items:
570
+ for item in items:
571
+ parent_term = (
572
+ getattr(item, "parent", None)
573
+ if not isinstance(item, dict)
574
+ else item.get("parent")
575
+ )
576
+ child_term = (
577
+ getattr(item, "child", None)
578
+ if not isinstance(item, dict)
579
+ else item.get("child")
580
+ )
581
+ if parent_term and child_term:
582
+ pairs.append((str(parent_term), str(child_term)))
583
+ return pairs
584
+
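Both attribute-style items and plain dicts are accepted, so a minimal stand-in object (an assumption for illustration, not a class from this package) is enough:

    from types import SimpleNamespace

    train_ontology = SimpleNamespace(
        type_taxonomies=SimpleNamespace(
            taxonomies=[
                {"parent": "waterbody", "child": "lake"},
                {"parent": "waterbody", "child": "river"},
            ]
        )
    )
    # _extract_positive_pairs(train_ontology) -> [("waterbody", "lake"), ("waterbody", "river")]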
585
+ def _extract_pairs_for_eval(self, ontology_obj: Any) -> List[Tuple[str, str]]:
586
+ """
587
+ Extract candidate pairs for evaluation.
588
+
589
+ Prefers `ontology_obj.pairs` if present; otherwise falls back to the
590
+ positive pairs from the ontology (see `_extract_positive_pairs`).
591
+
592
+ Returns:
593
+ list[tuple[str, str]]: Candidate (parent, child) pairs.
594
+ """
595
+ candidate_pairs = getattr(ontology_obj, "pairs", None)
596
+ if candidate_pairs:
597
+ pairs: List[Tuple[str, str]] = []
598
+ for item in candidate_pairs:
599
+ parent_term = (
600
+ getattr(item, "parent", None)
601
+ if not isinstance(item, dict)
602
+ else item.get("parent")
603
+ )
604
+ child_term = (
605
+ getattr(item, "child", None)
606
+ if not isinstance(item, dict)
607
+ else item.get("child")
608
+ )
609
+ if parent_term and child_term:
610
+ pairs.append((str(parent_term), str(child_term)))
611
+ return pairs
612
+ return self._extract_positive_pairs(ontology_obj)
613
+
614
+ def _generate_negatives(
615
+ self,
616
+ positives: List[Tuple[str, str]],
617
+ entities: List[str],
618
+ ratio: int,
619
+ ) -> List[Tuple[str, str]]:
620
+ """
621
+ Sample negative edges by excluding known positives and self-pairs.
622
+
623
+ Constructs the cartesian product of entities (excluding (x, x)),
624
+ removes all known positives, and samples up to `ratio * len(positives)`
625
+ negatives uniformly at random.
626
+
627
+ Args:
628
+ positives: Known positive edges.
629
+ entities: Unique set/list of entity terms.
630
+ ratio: Target negatives per positive (lower-bounded by 1×).
631
+
632
+ Returns:
633
+ list[tuple[str, str]]: Sampled negative pairs (may be smaller).
634
+ """
635
+ positive_set = set(positives)
636
+ all_possible = {
637
+ (parent, child)
638
+ for parent in entities
639
+ for child in entities
640
+ if parent != child
641
+ }
642
+ negative_candidates = list(all_possible - positive_set)
643
+
644
+ target_count = max(len(positive_set) * max(1, ratio), len(positive_set))
645
+ sample_count = min(target_count, len(negative_candidates))
646
+ return (
647
+ random.sample(negative_candidates, k=sample_count)
648
+ if sample_count > 0
649
+ else []
650
+ )
651
+
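A quick count check for the sampling above (illustrative numbers):

    # 100 distinct positive edges over 50 entities, negative_ratio = 5:
    #   candidate negatives = 50 * 49 - 100 = 2350 ordered pairs
    #   target_count        = max(100 * 5, 100) = 500
    #   sampled negatives   = min(500, 2350)    = 500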
652
+ def _build_text_dataset(
653
+ self,
654
+ positives: List[Tuple[str, str]],
655
+ negatives: List[Tuple[str, str]],
656
+ ) -> Tuple[List[int], List[str]]:
657
+ """
658
+ Create parallel lists of labels and input texts for HF Datasets.
659
+
660
+ Builds formatted inputs using `_format_input`, and duplicates examples in
661
+ the reverse direction if `bidirectional_templates` is True.
662
+
663
+ Returns:
664
+ tuple[list[int], list[str]]: (labels, input_texts) where labels are
665
+ 1 for positive and 0 for negative.
666
+ """
667
+ self._load_context_map()
668
+
669
+ labels: List[int] = []
670
+ input_texts: List[str] = []
671
+
672
+ def add_example(parent_term: str, child_term: str, label_value: int) -> None:
673
+ """Append one (and optionally reversed) example to the dataset."""
674
+ input_texts.append(self._format_input(parent_term, child_term))
675
+ labels.append(label_value)
676
+ if self.bidirectional_templates:
677
+ input_texts.append(
678
+ self._format_input(child_term, parent_term, reverse=True)
679
+ )
680
+ labels.append(label_value)
681
+
682
+ for parent_term, child_term in positives:
683
+ add_example(parent_term, child_term, 1)
684
+ for parent_term, child_term in negatives:
685
+ add_example(parent_term, child_term, 0)
686
+
687
+ return labels, input_texts
688
+
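For one positive pair with `bidirectional_templates=True`, two rows are appended; as written, the argument swap combined with `reverse=True` keeps the parent first while flipping the template wording:

    # positives = [("waterbody", "lake")]:
    #   label 1: "waterbody is the superclass / parent / supertype / ancestor class of lake"
    #   label 1: "waterbody is a subclass / child / subtype / descendant class of lake"
    # (context suffixes omitted; negatives expand the same way with label 0)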
689
+ def _format_input(
690
+ self, parent_term: str, child_term: str, reverse: bool = False
691
+ ) -> str:
692
+ """
693
+ Format a (parent, child) pair into relation text + optional context.
694
+
695
+ Returns:
696
+ str: "<relation template> [## Context. 'parent': ... 'child': ...]"
697
+ """
698
+ relation_text = (
699
+ f"{child_term} is a subclass / child / subtype / descendant class of {parent_term}"
700
+ if reverse
701
+ else f"{parent_term} is the superclass / parent / supertype / ancestor class of {child_term}"
702
+ )
703
+
704
+ parent_info = self._lookup_context_info(parent_term)
705
+ child_info = self._lookup_context_info(child_term)
706
+ if not parent_info and not child_info:
707
+ return relation_text
708
+
709
+ context_text = (
710
+ f"## Context. '{parent_term}': {parent_info} '{child_term}': {child_info}"
711
+ )
712
+ return f"{relation_text} {context_text}"
713
+
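Concretely, for the pair ("waterbody", "lake") (context strings are illustrative):

    # Without any context entries, only the relation template is returned:
    #   "waterbody is the superclass / parent / supertype / ancestor class of lake"
    # With context for both terms, the '## Context.' suffix is appended:
    #   "waterbody is the superclass / parent / supertype / ancestor class of lake "
    #   "## Context. 'waterbody': A body of water. 'lake': An inland body of standing water."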
714
+ def _fill_bucket_threaded(
715
+ self, bucket_rows: List[dict], output_path: Path, provider: Callable[[str], str]
716
+ ) -> None:
717
+ """
718
+ Populate one shard (bucket) with provider-generated `term_info`; runs inside a single worker thread.
719
+
720
+ Resumes from `output_path` if it already exists, periodically writes
721
+ progress (every ~10 items), and finally dumps the full bucket to disk.
722
+ """
723
+ start_index = 0
724
+ try:
725
+ if output_path.is_file():
726
+ existing_rows = json.load(open(output_path, "r", encoding="utf-8"))
727
+ if isinstance(existing_rows, list) and existing_rows:
728
+ bucket_rows[: len(existing_rows)] = existing_rows
729
+ start_index = len(existing_rows)
730
+ except Exception:
731
+ pass
732
+
733
+ for row_index in range(start_index, len(bucket_rows)):
734
+ try:
735
+ bucket_rows[row_index]["term_info"] = provider(
736
+ bucket_rows[row_index]["term"]
737
+ )
738
+ except Exception:
739
+ bucket_rows[row_index]["term_info"] = ""
740
+ if row_index % 10 == 1:
741
+ json.dump(
742
+ bucket_rows[: row_index + 1],
743
+ open(output_path, "w", encoding="utf-8"),
744
+ ensure_ascii=False,
745
+ indent=2,
746
+ )
747
+
748
+ json.dump(
749
+ bucket_rows,
750
+ open(output_path, "w", encoding="utf-8"),
751
+ ensure_ascii=False,
752
+ indent=2,
753
+ )
754
+
755
+ def _merge_part_files(
756
+ self, dataset_name: str, merged_path: Path, shard_paths: List[Path]
757
+ ) -> None:
758
+ """
759
+ Merge shard files into one JSON and filter boilerplate sentences.
760
+
761
+ - Reads shard lists/dicts from `shard_paths`.
762
+ - Drops sentences that contain markers in `_CONTEXT_REMOVALS` or the
763
+ `dataset_name` string.
764
+ - Normalizes the remaining text via `_normalize_text`.
765
+ - Writes merged JSON to `merged_path`, then best-effort deletes shards.
766
+ """
767
+ merged_rows: List[dict] = []
768
+ for shard_path in shard_paths:
769
+ try:
770
+ if not shard_path.is_file():
771
+ continue
772
+ part_content = json.load(open(shard_path, "r", encoding="utf-8"))
773
+ if isinstance(part_content, list):
774
+ merged_rows.extend(part_content)
775
+ elif isinstance(part_content, dict):
776
+ merged_rows.append(part_content)
777
+ except Exception:
778
+ continue
779
+
780
+ removal_markers = list(self._CONTEXT_REMOVALS) + [dataset_name]
781
+ for row in merged_rows:
782
+ term_info_raw = str(row.get("term_info", ""))
783
+ kept_sentences: List[str] = []
784
+ for sentence in term_info_raw.split("."):
785
+ sentence_no_links = re.sub(
786
+ r"\[\[\d+\]\]\(https?://[^\)]+\)", "", sentence
787
+ )
788
+ if any(marker in sentence_no_links for marker in removal_markers):
789
+ continue
790
+ kept_sentences.append(sentence_no_links)
791
+ row["term_info"] = self._normalize_text(
792
+ ".".join(kept_sentences), drop_questions=False
793
+ )
794
+
795
+ merged_path.parent.mkdir(parents=True, exist_ok=True)
796
+ json.dump(
797
+ merged_rows,
798
+ open(merged_path, "w", encoding="utf-8"),
799
+ ensure_ascii=False,
800
+ indent=4,
801
+ )
802
+
803
+ # best-effort cleanup
804
+ for shard_path in shard_paths:
805
+ try:
806
+ os.remove(shard_path)
807
+ except Exception:
808
+ pass
809
+
810
+ def _execute_for_terms(
811
+ self,
812
+ terms: List[str],
813
+ merged_path: Path,
814
+ shard_paths: List[Path],
815
+ provider: Callable[[str], str],
816
+ dataset_name: str,
817
+ num_workers: int = 2,
818
+ ) -> None:
819
+ """
820
+ Generate context for `terms`, writing shards to `shard_paths`, then merge.
821
+
822
+ Always uses threads (instance methods need no pickling, unlike with process pools).
823
+ Shows a tqdm progress bar and merges shards at the end.
824
+ """
825
+ worker_count = max(1, min(num_workers, os.cpu_count() or 2, 4))
826
+ all_rows = [
827
+ {"id": index, "term": term, "term_info": ""}
828
+ for index, term in enumerate(terms)
829
+ ]
830
+
831
+ buckets: List[List[dict]] = [[] for _ in range(worker_count)]
832
+ for reversed_index, row in enumerate(reversed(all_rows)):
833
+ buckets[reversed_index % worker_count].append(row)
834
+
835
+ total_rows = len(terms)
836
+ progress_bar = tqdm(
837
+ total=total_rows, desc=f"{dataset_name} generation (threads)"
838
+ )
839
+
840
+ def run_bucket(bucket_rows: List[dict], out_path: Path) -> int:
841
+ self._fill_bucket_threaded(bucket_rows, out_path, provider)
842
+ return len(bucket_rows)
843
+
844
+ with ThreadPoolExecutor(max_workers=worker_count) as pool:
845
+ futures = [
846
+ pool.submit(
847
+ run_bucket, buckets[bucket_index], shard_paths[bucket_index]
848
+ )
849
+ for bucket_index in range(worker_count)
850
+ ]
851
+ for future in as_completed(futures):
852
+ completed_count = future.result()
853
+ if progress_bar:
854
+ progress_bar.update(completed_count)
855
+ if progress_bar:
856
+ progress_bar.close()
857
+
858
+ self._merge_part_files(dataset_name, merged_path, shard_paths)
859
+
860
+ def _re_infer_short_entries(
861
+ self,
862
+ merged_path: Path,
863
+ re_shard_paths: List[Path],
864
+ re_merged_path: Path,
865
+ provider: Callable[[str], str],
866
+ dataset_name: str,
867
+ num_workers: int,
868
+ ) -> int:
869
+ """
870
+ Re-query terms whose `term_info` is too short (< 50 chars).
871
+
872
+ Process:
873
+ - Read `merged_path`.
874
+ - Filter boilerplate using `_CONTEXT_REMOVALS` and `dataset_name`.
875
+ - Split into short/long groups by length 50.
876
+ - Regenerate short group with `provider` in parallel (threads).
877
+ - Merge regenerated + long back into `merged_path`.
878
+
879
+ Returns:
880
+ int: Count of rows still < 50 chars after re-inference.
881
+ """
882
+ merged_rows = json.load(open(merged_path, "r", encoding="utf-8"))
883
+
884
+ removal_markers = list(self._CONTEXT_REMOVALS) + [dataset_name]
885
+ short_rows: List[dict] = []
886
+ long_rows: List[dict] = []
887
+
888
+ for row in merged_rows:
889
+ term_info_raw = str(row.get("term_info", ""))
890
+ sentences = term_info_raw.split(".")
891
+ for marker in removal_markers:
892
+ sentences = [
893
+ sentence if marker not in sentence else "" for sentence in sentences
894
+ ]
895
+ filtered_info = self._normalize_text(
896
+ ".".join(sentences), drop_questions=False
897
+ )
898
+ row["term_info"] = filtered_info
899
+
900
+ (short_rows if len(filtered_info) < 50 else long_rows).append(row)
901
+
902
+ worker_count = max(1, min(num_workers, os.cpu_count() or 2, 4))
903
+ buckets: List[List[dict]] = [[] for _ in range(worker_count)]
904
+ for row_index, row in enumerate(short_rows):
905
+ buckets[row_index % worker_count].append(row)
906
+
907
+ # Clean old re-inference shards
908
+ for path in re_shard_paths:
909
+ try:
910
+ os.remove(path)
911
+ except Exception:
912
+ pass
913
+
914
+ total_candidates = len(short_rows)
915
+ progress_bar = tqdm(
916
+ total=total_candidates, desc=f"{dataset_name} re-inference (threads)"
917
+ )
918
+
919
+ def run_bucket(bucket_rows: List[dict], out_path: Path) -> int:
920
+ self._fill_bucket_threaded(bucket_rows, out_path, provider)
921
+ return len(bucket_rows)
922
+
923
+ with ThreadPoolExecutor(max_workers=worker_count) as pool:
924
+ futures = [
925
+ pool.submit(
926
+ run_bucket, buckets[bucket_index], re_shard_paths[bucket_index]
927
+ )
928
+ for bucket_index in range(worker_count)
929
+ ]
930
+ for future in as_completed(futures):
931
+ completed_count = future.result()
932
+ if progress_bar:
933
+ progress_bar.update(completed_count)
934
+ if progress_bar:
935
+ progress_bar.close()
936
+
937
+ # Merge and write back
938
+ self._merge_part_files(dataset_name, re_merged_path, re_shard_paths)
939
+ new_rows = (
940
+ json.load(open(re_merged_path, "r", encoding="utf-8"))
941
+ if re_merged_path.is_file()
942
+ else []
943
+ )
944
+ final_rows = long_rows + new_rows
945
+ json.dump(
946
+ final_rows,
947
+ open(merged_path, "w", encoding="utf-8"),
948
+ ensure_ascii=False,
949
+ indent=4,
950
+ )
951
+
952
+ remaining_short = sum(
953
+ 1 for row in final_rows if len(str(row.get("term_info", ""))) < 50
954
+ )
955
+ return remaining_short
956
+
957
+ def _extract_terms_from_ontology(self, ontology: Any) -> List[str]:
958
+ """
959
+ Collect unique term names from `ontology.type_taxonomies.taxonomies`,
960
+ falling back to `ontology.taxonomies` if needed.
961
+
962
+ Returns:
963
+ list[str]: Sorted unique term list.
964
+ """
965
+ type_taxonomies = getattr(ontology, "type_taxonomies", None)
966
+ taxonomies = (
967
+ getattr(type_taxonomies, "taxonomies", None)
968
+ if type_taxonomies is not None
969
+ else getattr(ontology, "taxonomies", None)
970
+ )
971
+ unique_terms: set[str] = set()
972
+ if taxonomies:
973
+ for row in taxonomies:
974
+ parent_term = (
975
+ getattr(row, "parent", None)
976
+ if not isinstance(row, dict)
977
+ else row.get("parent")
978
+ )
979
+ child_term = (
980
+ getattr(row, "child", None)
981
+ if not isinstance(row, dict)
982
+ else row.get("child")
983
+ )
984
+ if parent_term:
985
+ unique_terms.add(str(parent_term))
986
+ if child_term:
987
+ unique_terms.add(str(child_term))
988
+ return sorted(unique_terms)
989
+
990
+ def preprocess_context_from_ontology(
991
+ self,
992
+ ontology: Any,
993
+ processed_dir: str | Path,
994
+ dataset_name: str = "GeoNames",
995
+ num_workers: int = 2,
996
+ provider: Optional[Callable[[str], str]] = None,
997
+ max_retries: int = 5,
998
+ ) -> Path:
999
+ """
1000
+ Build `{id, term, term_info}` rows from an ontology object.
1001
+
1002
+ Always regenerates the fixed-name file `rwthdbis_onto_processed.json`,
1003
+ performing:
1004
+ - Parallel generation of term_info in shards (`_execute_for_terms`),
1005
+ - Re-inference rounds for short entries (`_re_infer_short_entries`),
1006
+ - Final merge and cleanup,
1007
+ - An update of `self.context_json_path`.
1008
+
1009
+ Filenames under `processed_dir`:
1010
+ - merged: `rwthdbis_onto_processed.json`
1011
+ - shards: `rwthdbis_onto_type_part{idx}.json`
1012
+ - re-infer shards: `rwthdbis_onto_re_inference{idx}.json`
1013
+ - re-infer merged: `rwthdbis_onto_Types_re_inference.json`
1014
+
1015
+ Returns:
1016
+ Path: The merged context JSON path (`rwthdbis_onto_processed.json`).
1017
+ """
1018
+ provider = provider or partial(
1019
+ self._default_gpt_inference_with_dataset, dataset_name=dataset_name
1020
+ )
1021
+
1022
+ processed_dir = Path(processed_dir)
1023
+ processed_dir.mkdir(parents=True, exist_ok=True)
1024
+
1025
+ merged_path = processed_dir / "rwthdbis_onto_processed.json"
1026
+ if merged_path.exists():
1027
+ try:
1028
+ merged_path.unlink()
1029
+ except Exception:
1030
+ pass
1031
+
1032
+ worker_count = max(1, min(num_workers, os.cpu_count() or 2, 4))
1033
+ shard_paths = [
1034
+ processed_dir / f"rwthdbis_onto_type_part{index}.json"
1035
+ for index in range(worker_count)
1036
+ ]
1037
+ re_shard_paths = [
1038
+ processed_dir / f"rwthdbis_onto_re_inference{index}.json"
1039
+ for index in range(worker_count)
1040
+ ]
1041
+ re_merged_path = processed_dir / "rwthdbis_onto_Types_re_inference.json"
1042
+
1043
+ # Remove any leftover shards
1044
+ for path in shard_paths + re_shard_paths + [re_merged_path]:
1045
+ try:
1046
+ if path.exists():
1047
+ path.unlink()
1048
+ except Exception:
1049
+ pass
1050
+
1051
+ unique_terms = self._extract_terms_from_ontology(ontology)
1052
+ print(f"[Preprocess] Unique terms from ontology: {len(unique_terms)}")
1053
+
1054
+ self._execute_for_terms(
1055
+ terms=unique_terms,
1056
+ merged_path=merged_path,
1057
+ shard_paths=shard_paths,
1058
+ provider=provider,
1059
+ dataset_name=dataset_name,
1060
+ num_workers=worker_count,
1061
+ )
1062
+
1063
+ retry_round = 0
1064
+ while retry_round < max_retries:
1065
+ remaining_count = self._re_infer_short_entries(
1066
+ merged_path=merged_path,
1067
+ re_shard_paths=re_shard_paths,
1068
+ re_merged_path=re_merged_path,
1069
+ provider=provider,
1070
+ dataset_name=dataset_name,
1071
+ num_workers=worker_count,
1072
+ )
1073
+ print(
1074
+ f"[Preprocess] Re-infer round {retry_round + 1} done. Remaining short entries: {remaining_count}"
1075
+ )
1076
+ retry_round += 1
1077
+ if remaining_count == 0:
1078
+ break
1079
+
1080
+ print(f"[Preprocess] Done. Merged context at: {merged_path}")
1081
+ self.context_json_path = str(merged_path)
1082
+ return merged_path
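Each row of the merged `rwthdbis_onto_processed.json` has the shape built in `_execute_for_terms` (illustrative values):

    # {"id": 0, "term": "lake", "term_info": "An inland body of standing water ..."}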