OntoLearner 1.4.7-py3-none-any.whl → 1.4.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1219 @@
+ # Copyright (c) 2025 SciKnowOrg
+ #
+ # Licensed under the MIT License (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://opensource.org/licenses/MIT
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, List, Optional, Tuple, Iterable
+ import json
+ from json.decoder import JSONDecodeError
+ import os
+ import random
+ import re
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ from ...base import AutoLearner, AutoLLM
+
+ try:
+     from outlines.models import Transformers as OutlinesTFModel
+     from outlines.generate import json as outlines_generate_json
+     from pydantic import BaseModel
+
+     class _PredictedTypesSchema(BaseModel):
+         """Schema used when generating structured JSON { "types": [...] }."""
+
+         types: List[str]
+
+     OUTLINES_AVAILABLE: bool = True
+ except Exception:
+     # If outlines is unavailable, we will fall back to greedy decoding + regex parsing.
+     OUTLINES_AVAILABLE = False
+     _PredictedTypesSchema = None
+     OutlinesTFModel = None
+     outlines_generate_json = None
+
+
+ class LocalAutoLLM(AutoLLM):
+     """
+     Minimal local LLM helper.
+
+     - Inherits AutoLLM but overrides load/generate to avoid label_mapper.
+     - Optional 4-bit loading with `load_in_4bit=True` in .load().
+     - Greedy decoding by default (deterministic).
+     """
+
+     def __init__(self, device: str = "cpu", token: str = "") -> None:
+         """
+         Initialize the local LLM holder.
+
+         Parameters
+         ----------
+         device : str
+             Execution device: "cpu" or "cuda".
+         token : str
+             Optional auth token for private model hubs.
+         """
+         super().__init__(label_mapper=None, device=device, token=token)
+         self.model: Optional[AutoModelForCausalLM] = None
+         self.tokenizer: Optional[AutoTokenizer] = None
+
+     def load(self, model_id: str, *, load_in_4bit: bool = False) -> None:
+         """
+         Load a Hugging Face causal model + tokenizer and set deterministic
+         generation defaults.
+
+         Parameters
+         ----------
+         model_id : str
+             Model identifier resolvable by HF `from_pretrained`.
+         load_in_4bit : bool
+             If True and bitsandbytes is available, load using 4-bit quantization.
+         """
+         # Tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_id, padding_side="left", token=self.token
+         )
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         # Model (optionally quantized)
+         if load_in_4bit:
+             from transformers import BitsAndBytesConfig
+
+             quantization_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_use_double_quant=True,
+                 bnb_4bit_compute_dtype=torch.bfloat16,
+             )
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 device_map="auto",
+                 quantization_config=quantization_config,
+                 token=self.token,
+             )
+         else:
+             device_map = (
+                 "auto" if (self.device != "cpu" and torch.cuda.is_available()) else None
+             )
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 device_map=device_map,
+                 torch_dtype=torch.bfloat16
+                 if torch.cuda.is_available()
+                 else torch.float32,
+                 token=self.token,
+             )
+
+         # Deterministic generation defaults
+         generation_cfg = self.model.generation_config
+         generation_cfg.do_sample = False
+         generation_cfg.temperature = None
+         generation_cfg.top_k = None
+         generation_cfg.top_p = None
+         generation_cfg.num_beams = 1
+
+     def generate(self, prompts: List[str], max_new_tokens: int = 128) -> List[str]:
+         """
+         Greedy-generate continuations for a list of prompts.
+
+         Parameters
+         ----------
+         prompts : List[str]
+             Prompts to generate for (batched).
+         max_new_tokens : int
+             Maximum number of new tokens per continuation.
+
+         Returns
+         -------
+         List[str]
+             Decoded new-token texts (no special tokens, stripped).
+         """
+         if self.model is None or self.tokenizer is None:
+             raise RuntimeError(
+                 "Call .load(model_id) on LocalAutoLLM before generate()."
+             )
+
+         tokenized_batch = self.tokenizer(
+             prompts, return_tensors="pt", padding=True, truncation=True
+         )
+         input_seq_len = tokenized_batch["input_ids"].shape[1]
+         tokenized_batch = {
+             k: v.to(self.model.device) for k, v in tokenized_batch.items()
+         }
+
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **tokenized_batch,
+                 max_new_tokens=max_new_tokens,
+                 pad_token_id=self.tokenizer.eos_token_id,
+                 do_sample=False,
+                 num_beams=1,
+             )
+
+         # Only return the newly generated part for each row in the batch
+         continuation_token_ids = outputs[:, input_seq_len:]
+         return [
+             self.tokenizer.decode(row, skip_special_tokens=True).strip()
+             for row in continuation_token_ids
+         ]
+
+
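
For orientation, a minimal usage sketch of the `LocalAutoLLM` helper added above (illustrative only, not part of the package; the model id is an example reused from a default that appears later in this file):

    llm = LocalAutoLLM(device="cuda")
    llm.load("Qwen/Qwen2.5-1.5B-Instruct")  # optionally load_in_4bit=True
    outputs = llm.generate(
        ["Title: Grasslands\nList the key terms as JSON."],
        max_new_tokens=64,
    )
    print(outputs[0])  # decoded continuation only, special tokens stripped
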
+ class AlexbekFewShotLearner(AutoLearner):
+     """
+     Text2Onto learner for LLMs4OL Task A (term & type extraction).
+
+     Public API (A1 + convenience):
+       - fit(train_docs_jsonl, terms2doc_json, sample_size=24, seed=42)
+       - predict_terms(docs_test_jsonl, out_jsonl, max_new_tokens=128, few_shot_k=6) -> int
+       - predict_types(docs_test_jsonl, out_jsonl, max_new_tokens=128, few_shot_k=6) -> int
+       - evaluate_extraction_f1(gold_item2docs_json, preds_jsonl, key="term"|"type") -> float
+
+     Option A (A2, term→types) bridge:
+       - predict_types_from_terms(...)
+         Reads your A1 results (docs→terms), predicts types for each term, and
+         writes two files: terms2types_pred.json + types2docs_pred.json
+     """
+
+     def __init__(self, model: LocalAutoLLM, device: str = "cpu", **_: Any) -> None:
+         """
+         Initialize learner state and canned prompts.
+
+         Parameters
+         ----------
+         model : LocalAutoLLM
+             Loaded local LLM helper instance.
+         device : str
+             Device name ("cpu" or "cuda").
+         """
+         super().__init__(**_)
+         self.model = model
+         self.device = device
+
+         # Few-shot exemplars for A1 (Docs→Terms) and for Docs→Types:
+         # Each exemplar is a tuple: (title, text, gold_list)
+         self._fewshot_terms_docs: List[Tuple[str, str, List[str]]] = []
+         self._fewshot_types_docs: List[Tuple[str, str, List[str]]] = []
+
+         # System prompts
+         self._system_prompt_terms = (
+             "You are an expert in ontology term extraction.\n"
+             "Extract only terms that explicitly appear in the document.\n"
+             'Answer strictly as JSON: {"terms": ["..."]}\n'
+         )
+         self._system_prompt_types = (
+             "You are an expert in ontology type classification.\n"
+             "List ontology *types* that characterize the document’s terminology.\n"
+             'Answer strictly as JSON: {"types": ["..."]}\n'
+         )
+
+         # Compiled regexes for robust JSON extraction from LLM outputs
+         self._json_object_regex = re.compile(r"\{[^{}]*\}", re.S)
+         self._json_array_regex = re.compile(r"\[[^\]]*\]", re.S)
+
+         # Term→Types (Option A) specific prompt
+         self._system_prompt_term_to_types = (
+             "You are an expert in ontology and semantic type classification.\n"
+             "Given a term, predict its semantic types from the domain-specific ontology.\n"
+             'Answer strictly as JSON:\n{"types": ["type1", "type2", "..."]}'
+         )
+
+     def fit(
+         self,
+         *,
+         train_docs_jsonl: str,
+         terms2doc_json: str,
+         sample_size: int = 24,
+         seed: int = 42,
+     ) -> None:
+         """
+         Build internal few-shot exemplars from a labeled training split.
+
+         Parameters
+         ----------
+         train_docs_jsonl : str
+             Path to JSONL (or tolerant JSON/JSONL) with train documents.
+         terms2doc_json : str
+             JSON mapping item -> [doc_id, ...]; an "item" can be a term or a type.
+         sample_size : int
+             Number of exemplar documents to keep for few-shot prompting.
+         seed : int
+             RNG seed for reproducible sampling.
+         """
+         rng = random.Random(seed)
+
+         # Load documents and map doc_id -> row
+         document_map = self._load_documents_jsonl(train_docs_jsonl)
+         if not document_map:
+             raise FileNotFoundError(f"No documents found in: {train_docs_jsonl}")
+
+         # Load item -> [doc_ids]
+         item_to_docs_map = self._load_json(terms2doc_json)
+         if not isinstance(item_to_docs_map, dict):
+             raise ValueError(
+                 f"{terms2doc_json} must be a JSON dict mapping item -> [doc_ids]"
+             )
+
+         # Reverse mapping: doc_id -> [items]
+         doc_id_to_items_map: Dict[str, List[str]] = {}
+         for item_label, doc_id_list in item_to_docs_map.items():
+             for doc_id in doc_id_list:
+                 doc_id_to_items_map.setdefault(doc_id, []).append(item_label)
+
+         # Build candidate exemplars (title, text, gold_list)
+         exemplar_candidates: List[Tuple[str, str, List[str]]] = []
+         for doc_id, labeled_items in doc_id_to_items_map.items():
+             doc_row = document_map.get(doc_id)
+             if not doc_row:
+                 continue
+             doc_title = str(doc_row.get("title", ""))  # be defensive (may be None)
+             doc_text = self._to_text(doc_row.get("text", ""))  # string-ify list if needed
+             if not doc_text:
+                 continue
+             gold_items = self._unique_preserve(
+                 [s for s in labeled_items if isinstance(s, str)]
+             )
+             if gold_items:
+                 exemplar_candidates.append((doc_title, doc_text, gold_items))
+
+         if not exemplar_candidates:
+             raise RuntimeError(
+                 "No candidate docs with items found to build few-shot exemplars."
+             )
+
+         chosen_exemplars = rng.sample(
+             exemplar_candidates, k=min(sample_size, len(exemplar_candidates))
+         )
+         # Reuse exemplars for both docs→terms and docs→types prompting
+         self._fewshot_terms_docs = chosen_exemplars
+         self._fewshot_types_docs = chosen_exemplars
+
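
The two training inputs to `fit` are plain files. A sketch of the expected shapes, continuing the example above (all paths, ids, and labels are made up):

    # docs_train.jsonl: one document per line, e.g.
    #   {"id": "doc_001", "title": "Grasslands", "text": "Grasslands are ..."}
    # terms2doc.json: item (term or type) -> list of doc ids, e.g.
    #   {"grassland": ["doc_001"], "ecosystem": ["doc_001", "doc_002"]}
    learner = AlexbekFewShotLearner(model=llm, device="cuda")
    learner.fit(
        train_docs_jsonl="docs_train.jsonl",
        terms2doc_json="terms2doc.json",
        sample_size=24,
        seed=42,
    )
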
+     def predict_terms(
+         self,
+         *,
+         docs_test_jsonl: str,
+         out_jsonl: str,
+         max_new_tokens: int = 128,
+         few_shot_k: int = 6,
+     ) -> int:
+         """
+         Extract terms that explicitly appear in each document.
+
+         Writes one JSON object per line:
+             {"id": "<doc_id>", "terms": ["...", "...", ...]}
+
+         Parameters
+         ----------
+         docs_test_jsonl : str
+             Path to test/dev documents in JSONL or tolerant JSON/JSONL.
+         out_jsonl : str
+             Output JSONL path where predictions are written (one line per doc).
+         max_new_tokens : int
+             Max generation length.
+         few_shot_k : int
+             Number of few-shot exemplars to prepend per prompt.
+
+         Returns
+         -------
+         int
+             Number of lines written (i.e., number of processed documents).
+         """
+         if self.model is None or self.model.model is None:
+             raise RuntimeError("Load a model first: learner.model.load(MODEL_ID, ...)")
+
+         test_documents = self._load_documents_jsonl(docs_test_jsonl)
+         prompts: List[str] = []
+         document_order: List[str] = []
+
+         for document_id, document_row in test_documents.items():
+             title = str(document_row.get("title", ""))
+             text = self._to_text(document_row.get("text", ""))
+
+             fewshot_block = self._format_fewshot_block(
+                 self._system_prompt_terms,
+                 self._fewshot_terms_docs,
+                 key="terms",
+                 k=few_shot_k,
+             )
+             user_block = self._format_user_block(title, text)
+
+             prompts.append(f"{fewshot_block}\n{user_block}\nAssistant:")
+             document_order.append(document_id)
+
+         generations = self.model.generate(prompts, max_new_tokens=max_new_tokens)
+         parsed_term_lists = [
+             self._parse_json_list(generated, key="terms") for generated in generations
+         ]
+
+         os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
+         lines_written = 0
+         with open(out_jsonl, "w", encoding="utf-8") as fp_out:
+             for document_id, term_list in zip(document_order, parsed_term_lists):
+                 payload = {"id": document_id, "terms": self._unique_preserve(term_list)}
+                 fp_out.write(json.dumps(payload, ensure_ascii=False) + "\n")
+                 lines_written += 1
+         return lines_written
+
+     def predict_types(
+         self,
+         *,
+         docs_test_jsonl: str,
+         out_jsonl: str,
+         max_new_tokens: int = 128,
+         few_shot_k: int = 6,
+     ) -> int:
+         """
+         Predict ontology types that characterize each document’s terminology.
+
+         Writes one JSON object per line:
+             {"id": "<doc_id>", "types": ["...", "...", ...]}
+
+         Parameters
+         ----------
+         docs_test_jsonl : str
+             Path to test/dev documents in JSONL or tolerant JSON/JSONL.
+         out_jsonl : str
+             Output JSONL path where predictions are written (one line per doc).
+         max_new_tokens : int
+             Max generation length.
+         few_shot_k : int
+             Number of few-shot exemplars to prepend per prompt.
+
+         Returns
+         -------
+         int
+             Number of lines written (i.e., number of processed documents).
+         """
+         if self.model is None or self.model.model is None:
+             raise RuntimeError("Load a model first: learner.model.load(MODEL_ID, ...)")
+
+         test_documents = self._load_documents_jsonl(docs_test_jsonl)
+         prompts: List[str] = []
+         document_order: List[str] = []
+
+         for document_id, document_row in test_documents.items():
+             title = str(document_row.get("title", ""))
+             text = self._to_text(document_row.get("text", ""))
+
+             fewshot_block = self._format_fewshot_block(
+                 self._system_prompt_types,
+                 self._fewshot_types_docs,
+                 key="types",
+                 k=few_shot_k,
+             )
+             user_block = self._format_user_block(title, text)
+
+             prompts.append(f"{fewshot_block}\n{user_block}\nAssistant:")
+             document_order.append(document_id)
+
+         generations = self.model.generate(prompts, max_new_tokens=max_new_tokens)
+         parsed_type_lists = [
+             self._parse_json_list(generated, key="types") for generated in generations
+         ]
+
+         os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
+         lines_written = 0
+         with open(out_jsonl, "w", encoding="utf-8") as fp_out:
+             for document_id, type_list in zip(document_order, parsed_type_lists):
+                 payload = {"id": document_id, "types": self._unique_preserve(type_list)}
+                 fp_out.write(json.dumps(payload, ensure_ascii=False) + "\n")
+                 lines_written += 1
+         return lines_written
+
+     def evaluate_extraction_f1(
+         self,
+         gold_item2docs_json: str,
+         preds_jsonl: str,
+         *,
+         key: str = "term",
+     ) -> float:
+         """
+         Compute micro-F1 over (doc_id, item) pairs.
+
+         Parameters
+         ----------
+         gold_item2docs_json : str
+             JSON mapping item -> [doc_ids].
+         preds_jsonl : str
+             JSONL lines like {"id": "...", "terms": [...]} or {"id": "...", "types": [...]}.
+         key : str
+             "term" or "type" depending on what you are evaluating.
+
+         Returns
+         -------
+         float
+             Micro-averaged F1 score.
+         """
+         item_to_doc_ids: Dict[str, List[str]] = self._load_json(gold_item2docs_json)
+
+         # Build gold: doc_id -> set(items)
+         gold_doc_to_items: Dict[str, set] = {}
+         for item_label, doc_id_list in item_to_doc_ids.items():
+             for document_id in doc_id_list:
+                 gold_doc_to_items.setdefault(document_id, set()).add(
+                     self._norm(item_label)
+                 )
+
+         # Build predictions: doc_id -> set(items)
+         pred_doc_to_items: Dict[str, set] = {}
+         with open(preds_jsonl, "r", encoding="utf-8") as fp_in:
+             for line in fp_in:
+                 row = json.loads(line.strip())
+                 document_id = str(row.get("id", ""))
+                 items_list = row.get("terms" if key == "term" else "types", [])
+                 pred_doc_to_items[document_id] = {
+                     self._norm(x) for x in items_list if isinstance(x, str)
+                 }
+
+         # Micro counts
+         true_positive = false_positive = false_negative = 0
+         all_document_ids = set(gold_doc_to_items.keys()) | set(pred_doc_to_items.keys())
+         for document_id in all_document_ids:
+             gold_set = gold_doc_to_items.get(document_id, set())
+             pred_set = pred_doc_to_items.get(document_id, set())
+             true_positive += len(gold_set & pred_set)
+             false_positive += len(pred_set - gold_set)
+             false_negative += len(gold_set - pred_set)
+
+         precision = (
+             true_positive / (true_positive + false_positive)
+             if (true_positive + false_positive)
+             else 0.0
+         )
+         recall = (
+             true_positive / (true_positive + false_negative)
+             if (true_positive + false_negative)
+             else 0.0
+         )
+         f1 = (
+             2 * precision * recall / (precision + recall)
+             if (precision + recall)
+             else 0.0
+         )
+         return f1
+
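
As a quick check of the counting above: with gold = {a, b} and predictions = {b, c} for a single document, TP = 1, FP = 1, FN = 1, so precision = recall = 0.5 and F1 = 2 · 0.5 · 0.5 / (0.5 + 0.5) = 0.5.
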
+     def predict_types_from_terms(
+         self,
+         *,
+         doc_terms_jsonl: Optional[str] = None,  # formerly a1_results_jsonl
+         doc_terms_list: Optional[List[Dict]] = None,  # formerly a1_results_list
+         few_shot_jsonl: Optional[str] = None,  # JSONL lines: {"term": "...", "types": [...]}
+         rag_terms_json: Optional[str] = None,  # JSON list; items may contain "term" and "RAG": [...]
+         random_few_shot: Optional[int] = 3,
+         model_id: str = "Qwen/Qwen2.5-1.5B-Instruct",
+         use_structured_output: bool = True,
+         seed: int = 42,
+         out_terms2types: str = "terms2types_pred.json",
+         out_types2docs: str = "types2docs_pred.json",
+     ) -> Dict[str, Any]:
+         """
+         Predict types for each unique term extracted per document and derive a
+         types→docs map.
+
+         Parameters
+         ----------
+         doc_terms_jsonl : Optional[str]
+             Path to JSONL with lines like {"id": "...", "terms": [...]}, or a
+             JSON file with {"results": [...]}.
+         doc_terms_list : Optional[List[Dict]]
+             In-memory results like [{"id": "...", "extracted_terms": [...]}] or
+             [{"id": "...", "terms": [...]}].
+         few_shot_jsonl : Optional[str]
+             Global few-shot exemplars: one JSON object per line with
+             {"term": "...", "types": [...]}.
+         rag_terms_json : Optional[str]
+             Optional per-term RAG exemplars: a JSON list of
+             {"term": "...", "RAG": [{"term": "...", "types": [...]}]}.
+         random_few_shot : Optional[int]
+             If provided, randomly select up to this many few-shot examples for
+             each prediction.
+         model_id : str
+             HF model id used specifically for term→types predictions.
+         use_structured_output : bool
+             If True and outlines is available, enforce structured
+             {"types": [...]} output.
+         seed : int
+             Random seed for reproducibility.
+         out_terms2types : str
+             Output JSON path for a list of {"term": "...", "predicted_types": [...]}.
+         out_types2docs : str
+             Output JSON path for a dict {"TYPE": [doc_ids, ...], ...}.
+
+         Returns
+         -------
+         Dict[str, Any]
+             Summary with predictions and counts.
+         """
+         torch.manual_seed(seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(seed)
+
+         # Load normalized document→terms results
+         doc_term_extractions = self._load_doc_term_extractions(
+             results_json_path=doc_terms_jsonl,
+             in_memory_results=doc_terms_list,
+         )
+         if not doc_term_extractions:
+             raise ValueError(
+                 "No document→terms results provided (doc_terms_jsonl/doc_terms_list)."
+             )
+
+         # Prepare unique term list and term→doc occurrences
+         unique_terms = self._collect_unique_terms_from_extractions(doc_term_extractions)
+         term_to_doc_ids_map = self._build_term_to_doc_ids(doc_term_extractions)
+
+         # Load optional global few-shot examples
+         global_few_shot_examples: List[Dict] = []
+         if few_shot_jsonl and os.path.exists(few_shot_jsonl):
+             with open(few_shot_jsonl, "r", encoding="utf-8") as few_shot_file:
+                 for raw_line in few_shot_file:
+                     raw_line = raw_line.strip()
+                     if not raw_line:
+                         continue
+                     try:
+                         json_obj = json.loads(raw_line)
+                     except Exception:
+                         continue
+                     if (
+                         isinstance(json_obj, dict)
+                         and "term" in json_obj
+                         and "types" in json_obj
+                     ):
+                         global_few_shot_examples.append(json_obj)
+
+         # Optional per-term RAG examples: {normalized_term -> [examples]}
+         rag_examples_lookup: Dict[str, List[Dict]] = {}
+         if rag_terms_json and os.path.exists(rag_terms_json):
+             try:
+                 rag_payload = self._load_json(rag_terms_json)
+                 if isinstance(rag_payload, list):
+                     for rag_item in rag_payload:
+                         if isinstance(rag_item, dict):
+                             normalized_term = self._normalize_term(
+                                 rag_item.get("term", "")
+                             )
+                             rag_examples_lookup[normalized_term] = rag_item.get(
+                                 "RAG", []
+                             )
+             except Exception:
+                 pass
+
+         # Load a small chat LLM dedicated to Term→Types
+         typing_model, typing_tokenizer = self._load_llm_for_types(model_id)
+
+         # Predict types per term
+         term_to_predicted_types_list: List[Dict] = []
+         for term_text in unique_terms:
+             normalized_term = self._normalize_term(term_text)
+
+             # Prefer per-term RAG examples for this term, else use global few-shot
+             few_shot_examples_for_term = (
+                 rag_examples_lookup.get(normalized_term, None)
+                 or global_few_shot_examples
+             )
+
+             # Build conversation and prompt
+             conversation_messages = self._build_conv_for_type_infer(
+                 term=term_text,
+                 few_shot_examples=few_shot_examples_for_term,
+                 random_k=random_few_shot,
+             )
+             typing_prompt_string = self._apply_chat_template_safe_types(
+                 typing_tokenizer, conversation_messages
+             )
+
+             predicted_types: List[str] = []
+             raw_generation_text: str = ""
+
+             # Structured JSON path (if requested and available)
+             if (
+                 use_structured_output
+                 and OUTLINES_AVAILABLE
+                 and _PredictedTypesSchema is not None
+             ):
+                 try:
+                     outlines_model = OutlinesTFModel(typing_model, typing_tokenizer)  # type: ignore
+                     generator = outlines_generate_json(outlines_model, _PredictedTypesSchema)  # type: ignore
+                     structured = generator(typing_prompt_string, max_tokens=512)
+                     predicted_types = [
+                         label for label in structured.types if isinstance(label, str)
+                     ]
+                     raw_generation_text = json.dumps(
+                         {"types": predicted_types}, ensure_ascii=False
+                     )
+                 except Exception:
+                     # Fall back to greedy decoding (this also disables the
+                     # structured path for the remaining terms)
+                     use_structured_output = False
+
+             # Greedy decode fallback
+             if (
+                 not use_structured_output
+                 or not OUTLINES_AVAILABLE
+                 or _PredictedTypesSchema is None
+             ):
+                 tokenized_prompt = typing_tokenizer(
+                     typing_prompt_string,
+                     return_tensors="pt",
+                     truncation=True,
+                     max_length=2048,
+                 )
+                 if torch.cuda.is_available():
+                     tokenized_prompt = {
+                         name: tensor.cuda() for name, tensor in tokenized_prompt.items()
+                     }
+                 with torch.no_grad():
+                     output_ids = typing_model.generate(
+                         **tokenized_prompt,
+                         max_new_tokens=256,
+                         do_sample=False,
+                         num_beams=1,
+                         pad_token_id=typing_tokenizer.eos_token_id,
+                     )
+                 new_token_span = output_ids[0][tokenized_prompt["input_ids"].shape[1] :]
+                 raw_generation_text = typing_tokenizer.decode(
+                     new_token_span, skip_special_tokens=True
+                 )
+                 predicted_types = self._extract_types_from_text(raw_generation_text)
+
+             term_to_predicted_types_list.append(
+                 {
+                     "term": term_text,
+                     "predicted_types": sorted(set(predicted_types)),
+                 }
+             )
+
+         # Build types→docs from (term→types) and (term→docs)
+         types_to_doc_id_set: Dict[str, set] = {}
+         for term_prediction in term_to_predicted_types_list:
+             normalized_term = self._normalize_term(term_prediction["term"])
+             doc_ids_for_term = term_to_doc_ids_map.get(normalized_term, [])
+             for type_label in term_prediction.get("predicted_types", []):
+                 types_to_doc_id_set.setdefault(type_label, set()).update(
+                     doc_ids_for_term
+                 )
+
+         types_to_doc_ids: Dict[str, List[str]] = {
+             type_label: sorted(doc_id_set)
+             for type_label, doc_id_set in types_to_doc_id_set.items()
+         }
+
+         # Save outputs
+         os.makedirs(os.path.dirname(out_terms2types) or ".", exist_ok=True)
+         with open(out_terms2types, "w", encoding="utf-8") as fp_terms2types:
+             json.dump(
+                 term_to_predicted_types_list,
+                 fp_terms2types,
+                 ensure_ascii=False,
+                 indent=2,
+             )
+
+         os.makedirs(os.path.dirname(out_types2docs) or ".", exist_ok=True)
+         with open(out_types2docs, "w", encoding="utf-8") as fp_types2docs:
+             json.dump(types_to_doc_ids, fp_types2docs, ensure_ascii=False, indent=2)
+
+         # Free GPU memory used by the typing model, if any
+         del typing_model, typing_tokenizer
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         return {
+             "terms2types_pred": term_to_predicted_types_list,
+             "types2docs_pred": types_to_doc_ids,
+             "unique_terms": len(unique_terms),
+             "types_count": len(types_to_doc_ids),
+         }
+
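
A sketch of how the bridge is called, and the shapes of the two artifacts it writes (paths and labels are illustrative):

    summary = learner.predict_types_from_terms(
        doc_terms_jsonl="terms_pred.jsonl",          # A1 output (docs -> terms)
        few_shot_jsonl="term_types_examples.jsonl",  # optional global exemplars
    )
    # terms2types_pred.json: [{"term": "grassland", "predicted_types": ["Ecosystem"]}, ...]
    # types2docs_pred.json:  {"Ecosystem": ["doc_001", "doc_002"], ...}
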
+     def _load_json(self, path: str) -> Any:
+         """Load a JSON file from disk and return its parsed object (dict or list)."""
+         with open(path, "r", encoding="utf-8") as file_obj:
+             return json.load(file_obj)
+
+     def _iter_json_objects(self, blob: str) -> Iterable[Dict[str, Any]]:
+         """
+         Iterate over *all* JSON objects found inside a string.
+
+         Supports cases where multiple JSON objects are concatenated back-to-back
+         in a single line. It skips stray commas/whitespace between objects.
+
+         Parameters
+         ----------
+         blob : str
+             A string that may contain one or more JSON objects.
+
+         Yields
+         ------
+         Dict[str, Any]
+             Each parsed JSON object.
+         """
+         json_decoder = json.JSONDecoder()
+         cursor_index, text_length = 0, len(blob)
+         while cursor_index < text_length:
+             # Skip whitespace/commas between objects
+             while cursor_index < text_length and blob[cursor_index] in " \t\r\n,":
+                 cursor_index += 1
+             if cursor_index >= text_length:
+                 break
+             try:
+                 json_obj, end_index = json_decoder.raw_decode(blob, idx=cursor_index)
+             except JSONDecodeError:
+                 # Can't decode from this position; stop scanning this chunk
+                 break
+             yield json_obj
+             cursor_index = end_index
+
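
Concretely, the scanner above splits concatenated objects that a single `json.loads` call would reject (illustrative):

    blob = '{"id": "d1"} {"id": "d2"},{"id": "d3"}'
    records = list(learner._iter_json_objects(blob))
    # -> [{"id": "d1"}, {"id": "d2"}, {"id": "d3"}]
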
+     def _load_documents_jsonl(self, path: str) -> Dict[str, Dict[str, Any]]:
+         """
+         Robust reader that supports:
+           • True JSONL (one object per line)
+           • Lines with multiple concatenated JSON objects
+           • Whole file as a JSON array
+
+         Returns
+         -------
+         Dict[str, Dict[str, Any]]
+             Mapping doc_id -> full document row.
+         """
+         documents_by_id: Dict[str, Dict[str, Any]] = {}
+
+         with open(path, "r", encoding="utf-8") as file_obj:
+             content = file_obj.read().strip()
+
+         # Case A: whole-file JSON array
+         if content.startswith("["):
+             try:
+                 json_array = json.loads(content)
+                 if isinstance(json_array, list):
+                     for record in json_array:
+                         if not isinstance(record, dict):
+                             continue
+                         document_id = str(
+                             record.get("id")
+                             or record.get("doc_id")
+                             or (record.get("doc") or {}).get("id")
+                             or ""
+                         )
+                         if document_id:
+                             documents_by_id[document_id] = record
+                     return documents_by_id
+             except Exception:
+                 # Fall back to line-wise handling if array parsing fails
+                 pass
+
+         # Case B: treat as JSONL-ish; parse *all* objects per line
+         for raw_line in content.splitlines():
+             line = raw_line.strip()
+             if not line:
+                 continue
+             for record in self._iter_json_objects(line):
+                 if not isinstance(record, dict):
+                     continue
+                 document_id = str(
+                     record.get("id")
+                     or record.get("doc_id")
+                     or (record.get("doc") or {}).get("id")
+                     or ""
+                 )
+                 if document_id:
+                     documents_by_id[document_id] = record
+
+         return documents_by_id
+
+     def _to_text(self, text_field: Any) -> str:
+         """
+         Convert a 'text' field into a single string (handles list-of-strings).
+
+         Parameters
+         ----------
+         text_field : Any
+             The value found under "text" in the dataset row.
+
+         Returns
+         -------
+         str
+             A single-string representation of the text.
+         """
+         if isinstance(text_field, str):
+             return text_field
+         if isinstance(text_field, list):
+             return " ".join(str(part) for part in text_field)
+         return str(text_field) if text_field is not None else ""
+
+     def _unique_preserve(self, values: List[str]) -> List[str]:
+         """
+         Deduplicate values while preserving the original order.
+
+         Parameters
+         ----------
+         values : List[str]
+             Sequence possibly containing duplicates.
+
+         Returns
+         -------
+         List[str]
+             Sequence without duplicates, order preserved.
+         """
+         seen_values: set = set()
+         ordered_values: List[str] = []
+         for candidate in values:
+             if candidate not in seen_values:
+                 seen_values.add(candidate)
+                 ordered_values.append(candidate)
+         return ordered_values
+
+     def _norm(self, text: str) -> str:
+         """
+         Lowercased, single-spaced normalization (for comparisons).
+
+         Parameters
+         ----------
+         text : str
+             Input string.
+
+         Returns
+         -------
+         str
+             Normalized string.
+         """
+         return " ".join(text.lower().split())
+
+     def _normalize_term(self, term: str) -> str:
+         """
+         Normalization tailored for term keys / lookups.
+
+         Parameters
+         ----------
+         term : str
+             Term to normalize.
+
+         Returns
+         -------
+         str
+             Lowercased, trimmed, single-spaced term.
+         """
+         return " ".join(str(term).strip().split()).lower()
+
+     def _format_fewshot_block(
+         self,
+         system_prompt: str,
+         fewshot_examples: List[Tuple[str, str, List[str]]],
+         *,
+         key: str,
+         k: int = 6,
+     ) -> str:
+         """
+         Render a few-shot block like:
+
+             <SYSTEM PROMPT>
+
+             ### Example
+             User:
+             Title: ...
+             <text>
+             Assistant:
+             {"terms": [...]} or {"types": [...]}
+
+         Parameters
+         ----------
+         system_prompt : str
+             Instructional system text to prepend.
+         fewshot_examples : List[Tuple[str, str, List[str]]]
+             Examples as (title, text, labels_list).
+         key : str
+             Either "terms" or "types" depending on the task.
+         k : int
+             Number of examples to include.
+
+         Returns
+         -------
+         str
+             Formatted few-shot block text.
+         """
+         lines: List[str] = [system_prompt.strip(), ""]
+         for example_title, example_text, gold_list in fewshot_examples[:k]:
+             lines.append("### Example")
+             lines.append(f"User:\nTitle: {example_title}\n{example_text}")
+             lines.append(
+                 f'Assistant:\n{{"{key}": '
+                 + json.dumps(gold_list, ensure_ascii=False)
+                 + "}"
+             )
+         return "\n".join(lines)
+
+     def _format_user_block(self, title: str, text: str) -> str:
+         """
+         Format the 'Task' block for the current document.
+
+         Parameters
+         ----------
+         title : str
+             Document title.
+         text : str
+             Document text (single string).
+
+         Returns
+         -------
+         str
+             Formatted user block.
+         """
+         return f"### Task\nUser:\nTitle: {title}\n{text}"
+
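
Put together, the two formatting helpers above yield prompts of this shape for `predict_terms` (one exemplar shown; document content is made up):

    You are an expert in ontology term extraction.
    Extract only terms that explicitly appear in the document.
    Answer strictly as JSON: {"terms": ["..."]}

    ### Example
    User:
    Title: Grasslands
    Grasslands are ecosystems dominated by grasses.
    Assistant:
    {"terms": ["grassland", "ecosystem"]}
    ### Task
    User:
    Title: <test document title>
    <test document text>
    Assistant:
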
+     def _parse_json_list(self, generated_text: str, *, key: str) -> List[str]:
+         """
+         Extract a list from model output, trying:
+           1) A JSON object with the key ({"terms": [...]} or {"types": [...]}).
+           2) Any top-level JSON array.
+           3) Fallback: comma-split.
+
+         Parameters
+         ----------
+         generated_text : str
+             Raw generation text to parse.
+         key : str
+             "terms" or "types".
+
+         Returns
+         -------
+         List[str]
+             Parsed strings (best-effort).
+         """
+         # 1) Try a JSON object and read the key
+         try:
+             object_match = self._json_object_regex.search(generated_text)
+             if object_match:
+                 json_obj = json.loads(object_match.group(0))
+                 json_array = json_obj.get(key)
+                 if isinstance(json_array, list):
+                     return [value for value in json_array if isinstance(value, str)]
+         except Exception:
+             pass
+
+         # 2) Any JSON array
+         try:
+             array_match = self._json_array_regex.search(generated_text)
+             if array_match:
+                 json_array = json.loads(array_match.group(0))
+                 if isinstance(json_array, list):
+                     return [value for value in json_array if isinstance(value, str)]
+         except Exception:
+             pass
+
+         # 3) Fallback: comma-split (last resort)
+         if "," in generated_text:
+             return [
+                 part.strip().strip('"').strip("'")
+                 for part in generated_text.split(",")
+                 if part.strip()
+             ]
+         return []
+
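
The three parsing tiers above behave like this (illustrative inputs):

    learner._parse_json_list('{"terms": ["cell", "tissue"]}', key="terms")  # tier 1 -> ["cell", "tissue"]
    learner._parse_json_list('["cell", "tissue"]', key="terms")             # tier 2 -> ["cell", "tissue"]
    learner._parse_json_list("cell, tissue", key="terms")                   # tier 3 -> ["cell", "tissue"]
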
+     def _apply_chat_template_safe_types(
+         self, tokenizer: AutoTokenizer, messages: List[Dict[str, str]]
+     ) -> str:
+         """
+         Safely build a prompt string for chat models. Uses the model's chat template
+         when available; otherwise falls back to a simple concatenation.
+         """
+         try:
+             return tokenizer.apply_chat_template(
+                 messages, add_generation_prompt=True, tokenize=False
+             )
+         except Exception:
+             system_text = next(
+                 (m["content"] for m in messages if m.get("role") == "system"), ""
+             )
+             last_user_text = next(
+                 (m["content"] for m in reversed(messages) if m.get("role") == "user"),
+                 "",
+             )
+             return f"{system_text}\n\nUser:\n{last_user_text}\n\nAssistant:"
+
+     def _build_conv_for_type_infer(
+         self,
+         term: str,
+         few_shot_examples: Optional[List[Dict]] = None,
+         random_k: Optional[int] = None,
+     ) -> List[Dict[str, str]]:
+         """
+         Create a chat-style conversation for a single term→types query,
+         optionally prepending few-shot examples.
+         """
+         messages: List[Dict[str, str]] = [
+             {"role": "system", "content": self._system_prompt_term_to_types}
+         ]
+         examples = list(few_shot_examples or [])
+         if random_k and len(examples) > random_k:
+             examples = random.sample(examples, random_k)
+         for exemplar in examples:
+             example_term = exemplar.get("term", "")
+             example_types = exemplar.get("types", [])
+             messages.append({"role": "user", "content": f"Term: {example_term}"})
+             messages.append(
+                 {
+                     "role": "assistant",
+                     "content": json.dumps({"types": example_types}, ensure_ascii=False),
+                 }
+             )
+         messages.append({"role": "user", "content": f"Term: {term}"})
+         return messages
+
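
For a term with one few-shot exemplar, the conversation assembled above looks like this (the term and type labels are made up; the system content is the term→types prompt defined in __init__):

    [
        {"role": "system", "content": "You are an expert in ontology and semantic type classification. ..."},
        {"role": "user", "content": "Term: mitochondrion"},
        {"role": "assistant", "content": '{"types": ["cellular component"]}'},
        {"role": "user", "content": "Term: ribosome"},
    ]
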
+     def _extract_types_from_text(self, generated_text: str) -> List[str]:
+         """
+         Parse {"types": [...]} from a free-form generation.
+         """
+         try:
+             object_match = re.search(r'\{[^}]*"types"[^}]*\}', generated_text)
+             if object_match:
+                 json_obj = json.loads(object_match.group(0))
+                 types_array = json_obj.get("types", [])
+                 return [
+                     type_label
+                     for type_label in types_array
+                     if isinstance(type_label, str)
+                 ]
+         except Exception:
+             pass
+         return []
+
+     def _load_llm_for_types(
+         self, model_id: str
+     ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
+         """
+         Load a *separate* small chat model for Term→Types (keeps LocalAutoLLM untouched).
+         """
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+         model = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+             device_map="auto" if torch.cuda.is_available() else None,
+         )
+         return model, tokenizer
+
+     def _load_doc_term_extractions(
+         self,
+         *,
+         results_json_path: Optional[str] = None,
+         in_memory_results: Optional[List[Dict]] = None,
+     ) -> List[Dict]:
+         """
+         Normalize document→terms outputs to a list of:
+             {"id": "<doc_id>", "extracted_terms": ["...", ...]}
+
+         Accepts either:
+           - in_memory_results (list of dicts)
+           - results_json_path pointing to:
+               • a JSONL file with lines: {"id": "...", "terms": [...]}
+               • OR a JSON file with {"results": [{"id": ..., "extracted_terms": [...]}, ...]}
+               • OR a JSON list of dicts
+         """
+         normalized_records: List[Dict] = []
+
+         def _coerce_to_record(source_row: Dict) -> Optional[Dict]:
+             document_id = str(source_row.get("id", "")) or str(source_row.get("doc_id", ""))
+             if not document_id:
+                 return None
+             terms = source_row.get("extracted_terms")
+             if terms is None:
+                 terms = source_row.get("terms")
+             if (
+                 terms is None
+                 and "payload" in source_row
+                 and isinstance(source_row["payload"], dict)
+             ):
+                 terms = source_row["payload"].get("terms")
+             if not isinstance(terms, list):
+                 terms = []
+             return {
+                 "id": document_id,
+                 "extracted_terms": [t for t in terms if isinstance(t, str)],
+             }
+
+         if in_memory_results is not None:
+             for source_row in in_memory_results:
+                 coerced_record = _coerce_to_record(source_row)
+                 if coerced_record:
+                     normalized_records.append(coerced_record)
+             return normalized_records
+
+         if not results_json_path:
+             raise ValueError("Provide results_json_path or in_memory_results")
+
+         # Detect JSON vs JSONL by extension (best-effort)
+         if results_json_path.endswith(".jsonl"):
+             with open(results_json_path, "r", encoding="utf-8") as file_in:
+                 for raw_line in file_in:
+                     raw_line = raw_line.strip()
+                     if not raw_line:
+                         continue
+                     # Multiple concatenated objects per line? Iterate them all.
+                     for json_obj in self._iter_json_objects(raw_line):
+                         if isinstance(json_obj, dict):
+                             coerced_record = _coerce_to_record(json_obj)
+                             if coerced_record:
+                                 normalized_records.append(coerced_record)
+         else:
+             payload_obj = self._load_json(results_json_path)
+             if isinstance(payload_obj, dict) and "results" in payload_obj:
+                 for source_row in payload_obj["results"]:
+                     coerced_record = _coerce_to_record(source_row)
+                     if coerced_record:
+                         normalized_records.append(coerced_record)
+             elif isinstance(payload_obj, list):
+                 for source_row in payload_obj:
+                     if isinstance(source_row, dict):
+                         coerced_record = _coerce_to_record(source_row)
+                         if coerced_record:
+                             normalized_records.append(coerced_record)
+
+         return normalized_records
+
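
The normalizer above accepts any of these on-disk shapes and coerces them to {"id", "extracted_terms"} records (paths and ids illustrative):

    # a) JSONL, one record per line:      {"id": "doc_001", "terms": ["grassland", "soil"]}
    # b) JSON object with a results list: {"results": [{"id": "doc_001", "extracted_terms": ["grassland"]}]}
    # c) bare JSON list of such records
    records = learner._load_doc_term_extractions(results_json_path="terms_pred.jsonl")
    # -> [{"id": "doc_001", "extracted_terms": ["grassland", "soil"]}, ...]
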
+     def _collect_unique_terms_from_extractions(
+         self, doc_term_extractions: List[Dict]
+     ) -> List[str]:
+         """
+         Collect unique terms (original casing) from normalized document→terms results.
+         """
+         seen_normalized_terms: set = set()
+         ordered_unique_terms: List[str] = []
+         for record in doc_term_extractions:
+             for term_text in record.get("extracted_terms", []):
+                 normalized = self._normalize_term(term_text)
+                 if normalized and normalized not in seen_normalized_terms:
+                     seen_normalized_terms.add(normalized)
+                     ordered_unique_terms.append(term_text.strip())
+         return ordered_unique_terms
+
+     def _build_term_to_doc_ids(
+         self, doc_term_extractions: List[Dict]
+     ) -> Dict[str, List[str]]:
+         """
+         Build lookup: normalized_term -> sorted unique list of doc_ids.
+         """
+         term_to_doc_set: Dict[str, set] = {}
+         for record in doc_term_extractions:
+             document_id = str(record.get("id", ""))
+             for term_text in record.get("extracted_terms", []):
+                 normalized = self._normalize_term(term_text)
+                 if not normalized or not document_id:
+                     continue
+                 term_to_doc_set.setdefault(normalized, set()).add(document_id)
+         return {
+             normalized_term: sorted(doc_ids)
+             for normalized_term, doc_ids in term_to_doc_set.items()
+         }
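
Taken together, a minimal end-to-end run of the public API added in this file (a sketch under assumed file names and the example model id; not shipped in the package):

    llm = LocalAutoLLM(device="cuda")
    llm.load("Qwen/Qwen2.5-1.5B-Instruct")

    learner = AlexbekFewShotLearner(model=llm, device="cuda")
    learner.fit(train_docs_jsonl="docs_train.jsonl", terms2doc_json="terms2doc.json")

    # A1: docs -> terms, then score (doc_id, term) pairs against gold
    learner.predict_terms(docs_test_jsonl="docs_test.jsonl", out_jsonl="terms_pred.jsonl")
    f1 = learner.evaluate_extraction_f1("gold_terms2docs.json", "terms_pred.jsonl", key="term")

    # A2 (Option A): terms -> types, plus a derived types -> docs map
    learner.predict_types_from_terms(doc_terms_jsonl="terms_pred.jsonl")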
+ }