OntoLearner 1.4.10__py3-none-any.whl → 1.4.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
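
The 1.4.11 wheel replaces the prompt-only LocalAutoLLM / AlexbekFewShotLearner pair with a single retrieval-augmented AlexbekRAGFewShotLearner (two stages: document -> terms, then term -> types, each stage prompted with retrieved few-shot examples). The snippet below is a minimal usage sketch and is not part of the released files: the model ids, sample data, and the exact fit/predict signatures are assumptions inferred from the docstrings in the diff.

    # Minimal usage sketch; data values and fit/predict signatures are assumed, not taken from the package.
    learner = AlexbekRAGFewShotLearner(
        llm_model_id="Qwen/Qwen2.5-1.5B-Instruct",                    # any HF chat model id or local path
        retriever_model_id="sentence-transformers/all-MiniLM-L6-v2",
        device="cpu",
        top_k=3,
    )
    learner.load()  # LearnerPipeline would call learner.load(model_id=llm_id)

    train_data = {
        "documents":   [{"doc_id": "d1", "title": "Example", "text": "Gene expression is ..."}],
        "terms2docs":  {"gene expression": ["d1"]},
        "terms2types": {"gene expression": ["biological process"]},
    }
    test_data = {"documents": [{"doc_id": "d2", "title": "Example 2", "text": "..."}]}

    learner.fit(train_data, task="text2onto")            # builds the doc and term retrieval indices
    preds = learner.predict(test_data, task="text2onto")
    # preds -> {"terms": [{"doc_id": ..., "term": ...}, ...], "types": [{"doc_id": ..., "type": ...}, ...]}
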
@@ -12,1208 +12,587 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, List, Optional, Tuple, Iterable
16
15
  import json
17
- from json.decoder import JSONDecodeError
18
- import os
19
- import random
20
16
  import re
17
+ from typing import Any, Dict, List, Optional
18
+ from collections import defaultdict
21
19
 
22
20
  import torch
23
21
  from transformers import AutoTokenizer, AutoModelForCausalLM
24
22
 
25
- from ...base import AutoLearner, AutoLLM
23
+ from ...base import AutoLearner, AutoRetriever
26
24
 
27
- try:
28
- from outlines.models import Transformers as OutlinesTFModel
29
- from outlines.generate import json as outlines_generate_json
30
- from pydantic import BaseModel
31
-
32
- class _PredictedTypesSchema(BaseModel):
33
- """Schema used when generating structured JSON { "types": [...] }."""
34
-
35
- types: List[str]
36
-
37
- OUTLINES_AVAILABLE: bool = True
38
- except Exception:
39
- # If outlines is unavailable, we will fall back to greedy decoding + regex parsing.
40
- OUTLINES_AVAILABLE = False
41
- _PredictedTypesSchema = None
42
- OutlinesTFModel = None
43
- outlines_generate_json = None
44
-
45
-
46
- class LocalAutoLLM(AutoLLM):
25
+ class AlexbekRAGFewShotLearner(AutoLearner):
47
26
  """
48
- Minimal local LLM helper.
49
-
50
- - Inherits AutoLLM but overrides load/generate to avoid label_mapper.
51
- - Optional 4-bit loading with `load_in_4bit=True` in .load().
52
- - Greedy decoding by default (deterministic).
27
+ What it does (2-stage):
28
+ 1) doc -> terms
29
+ - retrieve top-k similar TRAIN documents (each has gold OL terms)
30
+ - build a few-shot chat prompt: (doc -> {"terms":[...]}) examples + target doc
31
+ - generate JSON {"terms":[...]} and parse it
32
+
33
+ 2) term -> types
34
+ - retrieve top-k similar TRAIN terms (each has gold types)
35
+ - build a few-shot chat prompt: (term -> {"types":[...]}) examples + target term
36
+ - generate JSON {"types":[...]} and parse it
37
+
38
+ Training behavior (fit):
39
+ - builds two retrieval indices:
40
+ * doc_retriever index over JSON strings of train docs (with "OL" field = gold terms)
41
+ * term_retriever index over JSON strings of train term->types examples
42
+
43
+ Prediction behavior (predict):
44
+ - returns a dict compatible with OntoLearner evaluation_report:
45
+ {
46
+ "terms": [{"doc_id": "...", "term": "..."}, ...],
47
+ "types": [{"doc_id": "...", "type": "..."}, ...],
48
+ }
49
+
50
+ Expected data format for task="text2onto":
51
+ data = {
52
+ "documents": [ {"id"/"doc_id": str, "title": str, "text": str, ...}, ... ],
53
+ "terms2docs": { term(str): [doc_id(str), ...], ... }
54
+ "terms2types": { term(str): [type(str), ...], ... }
55
+ }
56
+
57
+ IMPORTANT:
58
+ - LearnerPipeline calls learner.load(model_id=llm_id). We accept that and override llm_model_id.
59
+ - We override tasks_data_former() so AutoLearner.fit/predict does NOT rewrite text2onto dicts.
60
+ - Device placement: we put the model exactly on the device string the user provides
61
+ ("cpu", "cuda", "cuda:0", "cuda:1", ...). No device_map="auto".
53
62
  """
54
63
 
55
- def __init__(self, device: str = "cpu", token: str = "") -> None:
64
+ TERM2TYPES_SYSTEM_PROMPT = (
65
+ "You are an expert in ontology and semantic type classification. Your task is to predict "
66
+ "the semantic types for given terms based on their context and similar examples.\n\n"
67
+ "Given a term, you should predict its semantic types from the domain-specific ontology. "
68
+ "Use the provided examples to understand the patterns and relationships between terms and their types.\n\n"
69
+ "Output your response as a JSON object with the following structure:\n"
70
+ '{\n "types": ["type1", "type2", ...]\n}\n\n'
71
+ "The types should be relevant semantic categories that best describe the given term."
72
+ )
73
+
74
+ DOC2TERMS_SYSTEM_PROMPT = (
75
+ "You are an expert in ontology term extraction.\n\n"
76
+ "TASK: Extract specific, relevant ontology terms from scientific documents.\n\n"
77
+ "INSTRUCTIONS:\n"
78
+ "- The following conversation contains few-shot examples showing correct term extraction patterns\n"
79
+ "- Study these examples carefully to understand the extraction style and approach\n"
80
+ "- Follow the EXACT same pattern and style demonstrated in the examples\n"
81
+ "- Extract only terms that actually appear in the document text\n"
82
+ "- Focus on domain-specific terminology, concepts, and technical terms\n\n"
83
+ "- The first three user-assistant conversation pairs serve as few-shot examples\n"
84
+ "- Each example shows: user provides a document, assistant extracts relevant terms\n"
85
+ "- Pay attention to the extraction patterns and term selection criteria in these examples\n\n"
86
+ "DO:\n"
87
+ "- Extract terms that are EXPLICITLY mentioned in the LAST document\n"
88
+ "- Follow the SAME extraction pattern as shown in examples\n"
89
+ "- Return unique terms without duplicates\n"
90
+ "- Use the same JSON format as demonstrated\n\n"
91
+ "DON'T:\n"
92
+ "- Hallucinate or invent terms not present in last the document\n"
93
+ "- Repeat the same term multiple times\n"
94
+ "- Deviate from the extraction style shown in examples\n\n"
95
+ "OUTPUT FORMAT: Return a JSON object with a single field 'terms' containing a list of extracted terms."
96
+ )
97
+
98
+ def __init__(
99
+ self,
100
+ llm_model_id: str,
101
+ retriever_model_id: str = "sentence-transformers/all-MiniLM-L6-v2",
102
+ device: str = "cpu",
103
+ top_k: int = 3,
104
+ max_new_tokens: int = 256,
105
+ max_input_length: int = 2048,
106
+ use_tfidf: bool = False,
107
+ seed: int = 42,
108
+ restrict_to_known_types: bool = True,
109
+ hf_token: str = "",
110
+ local_files_only: bool = False,
111
+ **kwargs: Any,
112
+ ):
56
113
  """
57
- Initialize the local LLM holder.
58
-
59
114
  Parameters
60
115
  ----------
61
- device : str
62
- Execution device: "cpu" or "cuda".
63
- token : str
64
- Optional auth token for private model hubs.
65
- """
66
- super().__init__(label_mapper=None, device=device, token=token)
116
+ llm_model_id:
117
+ HuggingFace model id OR local path to a downloaded model directory.
118
+ retriever_model_id:
119
+ SentenceTransformer model id OR local path to a downloaded SBERT directory.
120
+ device:
121
+ Exact device string to place model on ("cpu", "cuda", "cuda:0", ...).
122
+ top_k:
123
+ Number of retrieved examples for few-shot prompting in each stage.
124
+ max_new_tokens:
125
+ Max tokens to generate for each prompt.
126
+ max_input_length:
127
+ Max prompt length before truncation.
128
+ use_tfidf:
129
+ If docs include TF-IDF suggestions (key "TF-IDF" or "tfidf_terms"), include them in prompts.
130
+ seed:
131
+ Seed for reproducibility.
132
+ restrict_to_known_types:
133
+ If True, append allowed type label list (from training) to system prompt in term->types stage.
134
+ This helps exact-match evaluation by discouraging invented labels.
135
+ hf_token:
136
+ HuggingFace token for gated models (optional).
137
+ local_files_only:
138
+ If True, Transformers will not try to reach the internet (requires local cache / local path).
139
+ """
140
+ super().__init__(**kwargs)
141
+
142
+ self.llm_model_id: str = llm_model_id
143
+ self.retriever_model_id: str = retriever_model_id
144
+ self.device: str = device
145
+ self.top_k: int = int(top_k)
146
+ self.max_new_tokens: int = int(max_new_tokens)
147
+ self.max_input_length: int = int(max_input_length)
148
+ self.use_tfidf: bool = bool(use_tfidf)
149
+ self.seed: int = int(seed)
150
+ self.restrict_to_known_types: bool = bool(restrict_to_known_types)
151
+ self.hf_token: str = hf_token or ""
152
+ self.local_files_only: bool = bool(local_files_only)
153
+
67
154
  self.model: Optional[AutoModelForCausalLM] = None
68
155
  self.tokenizer: Optional[AutoTokenizer] = None
156
+ self._loaded: bool = False
69
157
 
70
- def load(self, model_id: str, *, load_in_4bit: bool = False) -> None:
71
- """
72
- Load a Hugging Face causal model + tokenizer and set deterministic
73
- generation defaults.
74
-
75
- Parameters
76
- ----------
77
- model_id : str
78
- Model identifier resolvable by HF `from_pretrained`.
79
- load_in_4bit : bool
80
- If True and bitsandbytes is available, load using 4-bit quantization.
81
- """
82
- # Tokenizer
83
- self.tokenizer = AutoTokenizer.from_pretrained(
84
- model_id, padding_side="left", token=self.token
85
- )
86
- if self.tokenizer.pad_token is None:
87
- self.tokenizer.pad_token = self.tokenizer.eos_token
88
-
89
- # Model (optionally quantized)
90
- if load_in_4bit:
91
- from transformers import BitsAndBytesConfig
158
+ # Internal retrievers (always used in method-1, even in "llm-only" pipeline mode)
159
+ self.doc_retriever = AutoRetriever()
160
+ self.term_retriever = AutoRetriever()
92
161
 
93
- quantization_config = BitsAndBytesConfig(
94
- load_in_4bit=True,
95
- bnb_4bit_quant_type="nf4",
96
- bnb_4bit_use_double_quant=True,
97
- bnb_4bit_compute_dtype=torch.bfloat16,
98
- )
99
- self.model = AutoModelForCausalLM.from_pretrained(
100
- model_id,
101
- device_map="auto",
102
- quantization_config=quantization_config,
103
- token=self.token,
104
- )
105
- else:
106
- device_map = (
107
- "auto" if (self.device != "cpu" and torch.cuda.is_available()) else None
108
- )
109
- self.model = AutoModelForCausalLM.from_pretrained(
110
- model_id,
111
- device_map=device_map,
112
- torch_dtype=torch.bfloat16
113
- if torch.cuda.is_available()
114
- else torch.float32,
115
- token=self.token,
116
- )
162
+ # Indexed corpora as JSON strings
163
+ self._doc_examples_json: List[str] = []
164
+ self._term_examples_json: List[str] = []
117
165
 
118
- # Deterministic generation defaults
119
- generation_cfg = self.model.generation_config
120
- generation_cfg.do_sample = False
121
- generation_cfg.temperature = None
122
- generation_cfg.top_k = None
123
- generation_cfg.top_p = None
124
- generation_cfg.num_beams = 1
166
+ # Cached allowed type labels (for optional restriction)
167
+ self._allowed_types: List[str] = []
125
168
 
126
- def generate(self, prompts: List[str], max_new_tokens: int = 128) -> List[str]:
169
+ def tasks_data_former(self, data: Any, task: str, test: bool = False):
127
170
  """
128
- Greedy-generate continuations for a list of prompts.
171
+ Override base formatter: for task='text2onto' return data unchanged.
172
+ """
173
+ if task == "text2onto":
174
+ return data
175
+ return super().tasks_data_former(data=data, task=task, test=test)
129
176
 
130
- Parameters
131
- ----------
132
- prompts : List[str]
133
- Prompts to generate for (batched).
134
- max_new_tokens : int
135
- Maximum number of new tokens per continuation.
136
-
137
- Returns
138
- -------
139
- List[str]
140
- Decoded new-token texts (no special tokens, stripped).
177
+ def load(self, **kwargs: Any):
141
178
  """
142
- if self.model is None or self.tokenizer is None:
143
- raise RuntimeError(
144
- "Call .load(model_id) on LocalAutoLLM before generate()."
145
- )
179
+ Called by LearnerPipeline as: learner.load(model_id=llm_id)
146
180
 
147
- tokenized_batch = self.tokenizer(
148
- prompts, return_tensors="pt", padding=True, truncation=True
149
- )
150
- input_seq_len = tokenized_batch["input_ids"].shape[1]
151
- tokenized_batch = {
152
- k: v.to(self.model.device) for k, v in tokenized_batch.items()
153
- }
181
+ We accept overrides via kwargs:
182
+ - model_id / llm_model_id
183
+ - device, top_k, max_new_tokens, max_input_length, use_tfidf, seed, restrict_to_known_types
184
+ - hf_token, local_files_only
185
+ """
186
+ model_id = kwargs.get("model_id") or kwargs.get("llm_model_id")
187
+ if model_id:
188
+ self.llm_model_id = str(model_id)
154
189
 
155
- with torch.no_grad():
156
- outputs = self.model.generate(
157
- **tokenized_batch,
158
- max_new_tokens=max_new_tokens,
159
- pad_token_id=self.tokenizer.eos_token_id,
160
- do_sample=False,
161
- num_beams=1,
162
- )
190
+ for k in [
191
+ "device",
192
+ "top_k",
193
+ "max_new_tokens",
194
+ "max_input_length",
195
+ "use_tfidf",
196
+ "seed",
197
+ "restrict_to_known_types",
198
+ "hf_token",
199
+ "local_files_only",
200
+ "retriever_model_id",
201
+ ]:
202
+ if k in kwargs:
203
+ setattr(self, k, kwargs[k])
163
204
 
164
- # Only return the newly generated part for each row in the batch
165
- continuation_token_ids = outputs[:, input_seq_len:]
166
- return [
167
- self.tokenizer.decode(row, skip_special_tokens=True).strip()
168
- for row in continuation_token_ids
169
- ]
205
+ if self._loaded:
206
+ return
170
207
 
208
+ torch.manual_seed(self.seed)
209
+ if torch.cuda.is_available():
210
+ torch.cuda.manual_seed_all(self.seed)
171
211
 
172
- class AlexbekFewShotLearner(AutoLearner):
173
- """
174
- Text2Onto learner for LLMS4OL Task A (term & type extraction).
175
-
176
- Public API (A1 + convenience):
177
- - fit(train_docs_jsonl, terms2doc_json, sample_size=24, seed=42)
178
- - predict_terms(docs_test_jsonl, out_jsonl, max_new_tokens=128, few_shot_k=6) -> int
179
- - predict_types(docs_test_jsonl, out_jsonl, max_new_tokens=128, few_shot_k=6) -> int
180
- - evaluate_extraction_f1(gold_item2docs_json, preds_jsonl, key="term"|"type") -> float
181
-
182
- Option A (A2, term→types) bridge:
183
- - predict_types_from_terms_option_a(...)
184
- Reads your A1 results (docs→terms), predicts types for each term, and
185
- writes two files: terms2types_pred.json + types2docs_pred.json
186
- """
212
+ dev = str(self.device).strip()
213
+ if dev.startswith("cuda") and not torch.cuda.is_available():
214
+ raise RuntimeError(f"Device was set to '{dev}', but CUDA is not available.")
187
215
 
188
- def __init__(self, model: LocalAutoLLM, device: str = "cpu", **_: Any) -> None:
189
- """
190
- Initialize learner state and canned prompts.
216
+ dtype = torch.bfloat16 if dev.startswith("cuda") else torch.float32
191
217
 
192
- Parameters
193
- ----------
194
- model : LocalAutoLLM
195
- Loaded local LLM helper instance.
196
- device : str
197
- Device name ("cpu" or "cuda").
198
- """
199
- super().__init__(**_)
200
- self.model = model
201
- self.device = device
202
-
203
- # Few-shot exemplars for A1 (Docs→Terms) and for Docs→Types:
204
- # Each exemplar is a tuple: (title, text, gold_list)
205
- self._fewshot_terms_docs: List[Tuple[str, str, List[str]]] = []
206
- self._fewshot_types_docs: List[Tuple[str, str, List[str]]] = []
207
-
208
- # System prompts
209
- self._system_prompt_terms = (
210
- "You are an expert in ontology term extraction.\n"
211
- "Extract only terms that explicitly appear in the document.\n"
212
- 'Answer strictly as JSON: {"terms": ["..."]}\n'
213
- )
214
- self._system_prompt_types = (
215
- "You are an expert in ontology type classification.\n"
216
- "List ontology *types* that characterize the document’s terminology.\n"
217
- 'Answer strictly as JSON: {"types": ["..."]}\n'
218
- )
218
+ tok_kwargs: Dict[str, Any] = {"local_files_only": self.local_files_only}
219
+ if self.hf_token:
220
+ tok_kwargs["token"] = self.hf_token
221
+ try:
222
+ self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_id, **tok_kwargs)
223
+ except TypeError:
224
+ tok_kwargs.pop("token", None)
225
+ if self.hf_token:
226
+ tok_kwargs["use_auth_token"] = self.hf_token
227
+ self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_id, **tok_kwargs)
219
228
 
220
- # Compiled regex for robust JSON extraction from LLM outputs
221
- self._json_object_regex = re.compile(r"\{[^{}]*\}", re.S)
222
- self._json_array_regex = re.compile(r"\[[^\]]*\]", re.S)
229
+ if self.tokenizer.pad_token is None:
230
+ self.tokenizer.pad_token = self.tokenizer.eos_token
223
231
 
224
- # Term→Types (Option A) specific prompt
225
- self._system_prompt_term_to_types = (
226
- "You are an expert in ontology and semantic type classification.\n"
227
- "Given a term, predict its semantic types from the domain-specific ontology.\n"
228
- 'Answer strictly as JSON:\n{"types": ["type1", "type2", "..."]}'
229
- )
230
232
 
231
- def fit(
232
- self,
233
- *,
234
- train_docs_jsonl: str,
235
- terms2doc_json: str,
236
- sample_size: int = 24,
237
- seed: int = 42,
238
- ) -> None:
239
- """
240
- Build internal few-shot exemplars from a labeled training split.
233
+ model_kwargs: Dict[str, Any] = {"local_files_only": self.local_files_only}
234
+ if self.hf_token:
235
+ model_kwargs["token"] = self.hf_token
241
236
 
242
- Parameters
243
- ----------
244
- train_docs_jsonl : str
245
- Path to JSONL (or tolerant JSON/JSONL) with train documents.
246
- terms2doc_json : str
247
- JSON mapping item -> [doc_id,...]; "item" can be a term or type.
248
- sample_size : int
249
- Number of exemplar documents to keep for few-shot prompting.
250
- seed : int
251
- RNG seed for reproducible sampling.
252
- """
253
- rng = random.Random(seed)
254
-
255
- # Load documents and map doc_id -> row
256
- document_map = self._load_documents_jsonl(train_docs_jsonl)
257
- if not document_map:
258
- raise FileNotFoundError(f"No documents found in: {train_docs_jsonl}")
259
-
260
- # Load item -> [doc_ids]
261
- item_to_docs_map = self._load_json(terms2doc_json)
262
- if not isinstance(item_to_docs_map, dict):
263
- raise ValueError(
264
- f"{terms2doc_json} must be a JSON dict mapping item -> [doc_ids]"
237
+ try:
238
+ self.model = AutoModelForCausalLM.from_pretrained(
239
+ self.llm_model_id,
240
+ dtype=dtype,
241
+ **model_kwargs,
265
242
  )
266
-
267
- # Reverse mapping: doc_id -> [items]
268
- doc_id_to_items_map: Dict[str, List[str]] = {}
269
- for item_label, doc_id_list in item_to_docs_map.items():
270
- for doc_id in doc_id_list:
271
- doc_id_to_items_map.setdefault(doc_id, []).append(item_label)
272
-
273
- # Build candidate exemplars (title, text, gold_list)
274
- exemplar_candidates: List[Tuple[str, str, List[str]]] = []
275
- for doc_id, labeled_items in doc_id_to_items_map.items():
276
- doc_row = document_map.get(doc_id)
277
- if not doc_row:
278
- continue
279
- doc_title = str(doc_row.get("title", "")) # be defensive (may be None)
280
- doc_text = self._to_text(
281
- doc_row.get("text", "")
282
- ) # string-ify list if needed
283
- if not doc_text:
284
- continue
285
- gold_items = self._unique_preserve(
286
- [s for s in labeled_items if isinstance(s, str)]
243
+ except TypeError:
244
+ model_kwargs.pop("token", None)
245
+ if self.hf_token:
246
+ model_kwargs["use_auth_token"] = self.hf_token
247
+ self.model = AutoModelForCausalLM.from_pretrained(
248
+ self.llm_model_id,
249
+ torch_dtype=dtype,
250
+ **model_kwargs,
287
251
  )
288
- if gold_items:
289
- exemplar_candidates.append((doc_title, doc_text, gold_items))
290
252
 
291
- if not exemplar_candidates:
292
- raise RuntimeError(
293
- "No candidate docs with items found to build few-shot exemplars."
294
- )
253
+ self.model = self.model.to(dev)
295
254
 
296
- chosen_exemplars = rng.sample(
297
- exemplar_candidates, k=min(sample_size, len(exemplar_candidates))
298
- )
299
- # Reuse exemplars for both docs→terms and docs→types prompting
300
- self._fewshot_terms_docs = chosen_exemplars
301
- self._fewshot_types_docs = chosen_exemplars
255
+ self.doc_retriever.load(self.retriever_model_id)
256
+ self.term_retriever.load(self.retriever_model_id)
302
257
 
303
- def predict_terms(
304
- self,
305
- *,
306
- docs_test_jsonl: str,
307
- out_jsonl: str,
308
- max_new_tokens: int = 128,
309
- few_shot_k: int = 6,
310
- ) -> int:
311
- """
312
- Extract terms that explicitly appear in each document.
258
+ self._loaded = True
313
259
 
314
- Writes one JSON object per line:
315
- {"id": "<doc_id>", "terms": ["...", "...", ...]}
316
260
 
317
- Parameters
318
- ----------
319
- docs_test_jsonl : str
320
- Path to test/dev documents in JSONL or tolerant JSON/JSONL.
321
- out_jsonl : str
322
- Output JSONL path where predictions are written (one line per doc).
323
- max_new_tokens : int
324
- Max generation length.
325
- few_shot_k : int
326
- Number of few-shot exemplars to prepend per prompt.
327
-
328
- Returns
329
- -------
330
- int
331
- Number of lines written (i.e., number of processed documents).
261
+ def _format_doc(self, title: str, text: str, tfidf: Optional[List[str]] = None) -> str:
332
262
  """
333
- if self.model is None or self.model.model is None:
334
- raise RuntimeError("Load a model first: learner.model.load(MODEL_ID, ...)")
335
-
336
- test_documents = self._load_documents_jsonl(docs_test_jsonl)
337
- prompts: List[str] = []
338
- document_order: List[str] = []
339
-
340
- for document_id, document_row in test_documents.items():
341
- title = str(document_row.get("title", ""))
342
- text = self._to_text(document_row.get("text", ""))
343
-
344
- fewshot_block = self._format_fewshot_block(
345
- self._system_prompt_terms,
346
- self._fewshot_terms_docs,
347
- key="terms",
348
- k=few_shot_k,
349
- )
350
- user_block = self._format_user_block(title, text)
351
-
352
- prompts.append(f"{fewshot_block}\n{user_block}\nAssistant:")
353
- document_order.append(document_id)
354
-
355
- generations = self.model.generate(prompts, max_new_tokens=max_new_tokens)
356
- parsed_term_lists = [
357
- self._parse_json_list(generated, key="terms") for generated in generations
358
- ]
359
-
360
- os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
361
- lines_written = 0
362
- with open(out_jsonl, "w", encoding="utf-8") as fp_out:
363
- for document_id, term_list in zip(document_order, parsed_term_lists):
364
- payload = {"id": document_id, "terms": self._unique_preserve(term_list)}
365
- fp_out.write(json.dumps(payload, ensure_ascii=False) + "\n")
366
- lines_written += 1
367
- return lines_written
368
-
369
- def predict_types(
370
- self,
371
- *,
372
- docs_test_jsonl: str,
373
- out_jsonl: str,
374
- max_new_tokens: int = 128,
375
- few_shot_k: int = 6,
376
- ) -> int:
263
+ Format doc as the retriever query and as the user prompt content.
377
264
  """
378
- Predict ontology types that characterize each document’s terminology.
379
-
380
- Writes one JSON object per line:
381
- {"id": "<doc_id>", "types": ["...", "...", ...]}
265
+ s = f"Title: {title}\n\nText:\n{text}"
266
+ if tfidf:
267
+ s += f"\n\nTF-IDF based suggestions: {tfidf}"
268
+ return s
382
269
 
383
- Parameters
384
- ----------
385
- docs_test_jsonl : str
386
- Path to test/dev documents in JSONL or tolerant JSON/JSONL.
387
- out_jsonl : str
388
- Output JSONL path where predictions are written (one line per doc).
389
- max_new_tokens : int
390
- Max generation length.
391
- few_shot_k : int
392
- Number of few-shot exemplars to prepend per prompt.
393
-
394
- Returns
395
- -------
396
- int
397
- Number of lines written (i.e., number of processed documents).
270
+ def _apply_chat_template(self, conversation: List[Dict[str, str]]) -> str:
398
271
  """
399
- if self.model is None or self.model.model is None:
400
- raise RuntimeError("Load a model first: learner.model.load(MODEL_ID, ...)")
401
-
402
- test_documents = self._load_documents_jsonl(docs_test_jsonl)
403
- prompts: List[str] = []
404
- document_order: List[str] = []
405
-
406
- for document_id, document_row in test_documents.items():
407
- title = str(document_row.get("title", ""))
408
- text = self._to_text(document_row.get("text", ""))
409
-
410
- fewshot_block = self._format_fewshot_block(
411
- self._system_prompt_types,
412
- self._fewshot_types_docs,
413
- key="types",
414
- k=few_shot_k,
415
- )
416
- user_block = self._format_user_block(title, text)
417
-
418
- prompts.append(f"{fewshot_block}\n{user_block}\nAssistant:")
419
- document_order.append(document_id)
420
-
421
- generations = self.model.generate(prompts, max_new_tokens=max_new_tokens)
422
- parsed_type_lists = [
423
- self._parse_json_list(generated, key="types") for generated in generations
424
- ]
425
-
426
- os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
427
- lines_written = 0
428
- with open(out_jsonl, "w", encoding="utf-8") as fp_out:
429
- for document_id, type_list in zip(document_order, parsed_type_lists):
430
- payload = {"id": document_id, "types": self._unique_preserve(type_list)}
431
- fp_out.write(json.dumps(payload, ensure_ascii=False) + "\n")
432
- lines_written += 1
433
- return lines_written
434
-
435
- def evaluate_extraction_f1(
436
- self,
437
- gold_item2docs_json: str,
438
- preds_jsonl: str,
439
- *,
440
- key: str = "term",
441
- ) -> float:
272
+ Convert conversation into a single prompt string using the tokenizer's chat template if available.
442
273
  """
443
- Compute micro-F1 over (doc_id, item) pairs.
444
-
445
- Parameters
446
- ----------
447
- gold_item2docs_json : str
448
- JSON mapping item -> [doc_ids].
449
- preds_jsonl : str
450
- JSONL lines like {"id": "...", "terms":[...]} or {"id":"...","types":[...]}.
451
- key : str
452
- "term" or "type" depending on what you are evaluating.
453
-
454
- Returns
455
- -------
456
- float
457
- Micro-averaged F1 score.
458
- """
459
- item_to_doc_ids: Dict[str, List[str]] = self._load_json(gold_item2docs_json)
460
-
461
- # Build gold: doc_id -> set(items)
462
- gold_doc_to_items: Dict[str, set] = {}
463
- for item_label, doc_id_list in item_to_doc_ids.items():
464
- for document_id in doc_id_list:
465
- gold_doc_to_items.setdefault(document_id, set()).add(
466
- self._norm(item_label)
467
- )
468
-
469
- # Build predictions: doc_id -> set(items)
470
- pred_doc_to_items: Dict[str, set] = {}
471
- with open(preds_jsonl, "r", encoding="utf-8") as fp_in:
472
- for line in fp_in:
473
- row = json.loads(line.strip())
474
- document_id = str(row.get("id", ""))
475
- items_list = row.get("terms" if key == "term" else "types", [])
476
- pred_doc_to_items[document_id] = {
477
- self._norm(x) for x in items_list if isinstance(x, str)
478
- }
274
+ assert self.tokenizer is not None
275
+ if hasattr(self.tokenizer, "apply_chat_template"):
276
+ return self.tokenizer.apply_chat_template(
277
+ conversation, add_generation_prompt=True, tokenize=False
278
+ )
479
279
 
480
- # Micro counts
481
- true_positive = false_positive = false_negative = 0
482
- all_document_ids = set(gold_doc_to_items.keys()) | set(pred_doc_to_items.keys())
483
- for document_id in all_document_ids:
484
- gold_set = gold_doc_to_items.get(document_id, set())
485
- pred_set = pred_doc_to_items.get(document_id, set())
486
- true_positive += len(gold_set & pred_set)
487
- false_positive += len(pred_set - gold_set)
488
- false_negative += len(gold_set - pred_set)
489
-
490
- precision = (
491
- true_positive / (true_positive + false_positive)
492
- if (true_positive + false_positive)
493
- else 0.0
494
- )
495
- recall = (
496
- true_positive / (true_positive + false_negative)
497
- if (true_positive + false_negative)
498
- else 0.0
499
- )
500
- f1 = (
501
- 2 * precision * recall / (precision + recall)
502
- if (precision + recall)
503
- else 0.0
504
- )
505
- return f1
280
+ parts = []
281
+ for t in conversation:
282
+ parts.append(f"{t['role'].upper()}:\n{t['content']}\n")
283
+ parts.append("ASSISTANT:\n")
284
+ return "\n".join(parts)
285
+
286
+ def _extract_first_json_obj(self, text: str) -> Optional[dict]:
287
+ """
288
+ Extract the first valid JSON object from generated text by scanning balanced {...}.
289
+ """
290
+ starts = [i for i, ch in enumerate(text) if ch == "{"]
291
+
292
+ for s in starts:
293
+ depth = 0
294
+ for e in range(s, len(text)):
295
+ if text[e] == "{":
296
+ depth += 1
297
+ elif text[e] == "}":
298
+ depth -= 1
299
+ if depth == 0:
300
+ candidate = text[s : e + 1].strip().replace("\n", " ")
301
+ try:
302
+ return json.loads(candidate)
303
+ except Exception:
304
+ try:
305
+ candidate2 = re.sub(r"'", '"', candidate)
306
+ return json.loads(candidate2)
307
+ except Exception:
308
+ pass
309
+ break
310
+ return None
311
+
312
+ def _dedup_clean(self, items: List[str]) -> List[str]:
313
+ """
314
+ Normalize and deduplicate strings (case-insensitive).
315
+ """
316
+ out: List[str] = []
317
+ seen = set()
318
+ for x in items or []:
319
+ if not isinstance(x, str):
320
+ continue
321
+ x2 = re.sub(r"\s+", " ", x.strip())
322
+ if not x2:
323
+ continue
324
+ k = x2.lower()
325
+ if k in seen:
326
+ continue
327
+ seen.add(k)
328
+ out.append(x2)
329
+ return out
506
330
 
507
- def predict_types_from_terms(
508
- self,
509
- *,
510
- doc_terms_jsonl: Optional[str] = None, # formerly a1_results_jsonl
511
- doc_terms_list: Optional[List[Dict]] = None, # formerly a1_results_list
512
- few_shot_jsonl: Optional[
513
- str
514
- ] = None, # JSONL lines: {"term":"...", "types":[...]}
515
- rag_terms_json: Optional[
516
- str
517
- ] = None, # JSON list; items may contain "term" and "RAG":[...]
518
- random_few_shot: Optional[int] = 3,
519
- model_id: str = "Qwen/Qwen2.5-1.5B-Instruct",
520
- use_structured_output: bool = True,
521
- seed: int = 42,
522
- out_terms2types: str = "terms2types_pred.json",
523
- out_types2docs: str = "types2docs_pred.json",
524
- ) -> Dict[str, Any]:
331
+ def _doc_id(self, d: Dict[str, Any]) -> str:
332
+ """
333
+ Extract doc_id from common keys: doc_id, id, docid.
525
334
  """
526
- Predict types for each unique term extracted per document and derive a types→docs map.
335
+ return str(d.get("doc_id") or d.get("id") or d.get("docid") or "")
527
336
 
528
- Parameters
529
- ----------
530
- doc_terms_jsonl : Optional[str]
531
- Path to JSONL with lines like {"id": "...", "terms": [...]} or a JSON with {"results":[...]}.
532
- doc_terms_list : Optional[List[Dict]]
533
- In-memory results like [{"id":"...","extracted_terms":[...]}] or {"id":"...","terms":[...]}.
534
- few_shot_jsonl : Optional[str]
535
- Global few-shot exemplars: one JSON object per line with {"term": "...", "types":[...]}.
536
- rag_terms_json : Optional[str]
537
- Optional per-term RAG exemplars: a JSON list of {"term": "...", "RAG":[{"term": "...", "types":[...]}]}.
538
- random_few_shot : Optional[int]
539
- If provided, randomly select up to this many few-shot examples for each prediction.
540
- model_id : str
541
- HF model id used specifically for term→types predictions.
542
- use_structured_output : bool
543
- If True and outlines is available, enforce structured {"types":[...]} output.
544
- seed : int
545
- Random seed for reproducibility.
546
- out_terms2types : str
547
- Output JSON path for list of {"term": "...", "predicted_types":[...]}.
548
- out_types2docs : str
549
- Output JSON path for dict {"TYPE":[doc_ids,...], ...}.
550
-
551
- Returns
552
- -------
553
- Dict[str, Any]
554
- Summary with predictions and counts.
337
+ def _extract_documents(self, data: Any) -> List[Dict[str, Any]]:
555
338
  """
556
- torch.manual_seed(seed)
557
- if torch.cuda.is_available():
558
- torch.cuda.manual_seed(seed)
339
+ Accept list-of-docs OR dict with 'documents'/'docs'.
340
+ """
341
+ if isinstance(data, list):
342
+ return data
343
+ if isinstance(data, dict):
344
+ if isinstance(data.get("documents"), list):
345
+ return data["documents"]
346
+ if isinstance(data.get("docs"), list):
347
+ return data["docs"]
348
+ raise ValueError("Expected dict with 'documents' (or 'docs'), or a list of docs.")
559
349
 
560
- # Load normalized document→terms results
561
- doc_term_extractions = self._load_doc_term_extractions(
562
- results_json_path=doc_terms_jsonl,
563
- in_memory_results=doc_terms_list,
564
- )
565
- if not doc_term_extractions:
566
- raise ValueError(
567
- "No document→terms results provided (doc_terms_jsonl/doc_terms_list)."
568
- )
350
+ def _normalize_terms2docs(self, raw_terms2docs: Any, docs: List[Dict[str, Any]]) -> Dict[str, List[str]]:
351
+ """
352
+ Normalize mapping to: term -> [doc_id, ...].
569
353
 
570
- # Prepare unique term list and term→doc occurrences
571
- unique_terms = self._collect_unique_terms_from_extractions(doc_term_extractions)
572
- term_to_doc_ids_map = self._build_term_to_doc_ids(doc_term_extractions)
573
-
574
- # Load optional global few-shot examples
575
- global_few_shot_examples: List[Dict] = []
576
- if few_shot_jsonl and os.path.exists(few_shot_jsonl):
577
- with open(few_shot_jsonl, "r", encoding="utf-8") as few_shot_file:
578
- for raw_line in few_shot_file:
579
- raw_line = raw_line.strip()
580
- if not raw_line:
581
- continue
582
- try:
583
- json_obj = json.loads(raw_line)
584
- except Exception:
585
- continue
586
- if (
587
- isinstance(json_obj, dict)
588
- and "term" in json_obj
589
- and "types" in json_obj
590
- ):
591
- global_few_shot_examples.append(json_obj)
592
-
593
- # Optional per-term RAG examples: {normalized_term -> [examples]}
594
- rag_examples_lookup: Dict[str, List[Dict]] = {}
595
- if rag_terms_json and os.path.exists(rag_terms_json):
596
- try:
597
- rag_payload = self._load_json(rag_terms_json)
598
- if isinstance(rag_payload, list):
599
- for rag_item in rag_payload:
600
- if isinstance(rag_item, dict):
601
- normalized_term = self._normalize_term(
602
- rag_item.get("term", "")
603
- )
604
- rag_examples_lookup[normalized_term] = rag_item.get(
605
- "RAG", []
606
- )
607
- except Exception:
608
- pass
354
+ If the caller accidentally provides an inverted mapping: doc_id -> [term, ...],
355
+ we detect it (keys mostly match doc_ids) and invert it.
356
+ """
357
+ if not isinstance(raw_terms2docs, dict) or not raw_terms2docs:
358
+ return {}
609
359
 
610
- # Load a small chat LLM dedicated to Term→Types
611
- typing_model, typing_tokenizer = self._load_llm_for_types(model_id)
360
+ doc_ids = {self._doc_id(d) for d in docs}
361
+ doc_ids.discard("")
612
362
 
613
- # Predict types per term
614
- term_to_predicted_types_list: List[Dict] = []
615
- for term_text in unique_terms:
616
- normalized_term = self._normalize_term(term_text)
363
+ keys = list(raw_terms2docs.keys())
364
+ sample = keys[:200]
365
+ hits = sum(1 for k in sample if str(k) in doc_ids)
617
366
 
618
- # Prefer per-term RAG for this term, else use global few-shot
619
- few_shot_examples_for_term = (
620
- rag_examples_lookup.get(normalized_term, None)
621
- or global_few_shot_examples
622
- )
367
+ if sample and hits >= int(0.6 * len(sample)):
368
+ term2docs: Dict[str, List[str]] = defaultdict(list)
369
+ for did, terms in raw_terms2docs.items():
370
+ did = str(did)
371
+ if did not in doc_ids:
372
+ continue
373
+ for t in (terms or []):
374
+ if isinstance(t, str) and t.strip():
375
+ term2docs[t.strip()].append(did)
376
+ return {t: self._dedup_clean(ds) for t, ds in term2docs.items()}
377
+
378
+ norm: Dict[str, List[str]] = {}
379
+ for term, doc_list in raw_terms2docs.items():
380
+ if not isinstance(term, str) or not term.strip():
381
+ continue
382
+ docs_norm = [str(d) for d in (doc_list or []) if str(d)]
383
+ if docs_norm:
384
+ norm[term.strip()] = self._dedup_clean(docs_norm)
385
+ return norm
623
386
 
624
- # Build conversation and prompt
625
- conversation_messages = self._build_conv_for_type_infer(
626
- term=term_text,
627
- few_shot_examples=few_shot_examples_for_term,
628
- random_k=random_few_shot,
629
- )
630
- typing_prompt_string = self._apply_chat_template_safe_types(
631
- typing_tokenizer, conversation_messages
632
- )
387
+ def _generate(self, prompt: str) -> str:
388
+ """
389
+ Deterministic single-prompt generation (no sampling).
390
+ Returns decoded completion only.
391
+ """
392
+ assert self.model is not None and self.tokenizer is not None
633
393
 
634
- predicted_types: List[str] = []
635
- raw_generation_text: str = ""
636
-
637
- # Structured JSON path (if requested and available)
638
- if (
639
- use_structured_output
640
- and OUTLINES_AVAILABLE
641
- and _PredictedTypesSchema is not None
642
- ):
643
- try:
644
- outlines_model = OutlinesTFModel(typing_model, typing_tokenizer) # type: ignore
645
- generator = outlines_generate_json(
646
- outlines_model, _PredictedTypesSchema
647
- ) # type: ignore
648
- structured = generator(typing_prompt_string, max_tokens=512)
649
- predicted_types = [
650
- label for label in structured.types if isinstance(label, str)
651
- ]
652
- raw_generation_text = json.dumps(
653
- {"types": predicted_types}, ensure_ascii=False
654
- )
655
- except Exception:
656
- # Fall back to greedy decoding
657
- use_structured_output = False
658
-
659
- # Greedy decode fallback
660
- if (
661
- not use_structured_output
662
- or not OUTLINES_AVAILABLE
663
- or _PredictedTypesSchema is None
664
- ):
665
- tokenized_prompt = typing_tokenizer(
666
- typing_prompt_string,
667
- return_tensors="pt",
668
- truncation=True,
669
- max_length=2048,
670
- )
671
- if torch.cuda.is_available():
672
- tokenized_prompt = {
673
- name: tensor.cuda() for name, tensor in tokenized_prompt.items()
674
- }
675
- with torch.no_grad():
676
- output_ids = typing_model.generate(
677
- **tokenized_prompt,
678
- max_new_tokens=256,
679
- do_sample=False,
680
- num_beams=1,
681
- pad_token_id=typing_tokenizer.eos_token_id,
682
- )
683
- new_token_span = output_ids[0][tokenized_prompt["input_ids"].shape[1] :]
684
- raw_generation_text = typing_tokenizer.decode(
685
- new_token_span, skip_special_tokens=True
686
- )
687
- predicted_types = self._extract_types_from_text(raw_generation_text)
688
-
689
- term_to_predicted_types_list.append(
690
- {
691
- "term": term_text,
692
- "predicted_types": sorted(set(predicted_types)),
693
- }
694
- )
394
+ enc = self.tokenizer(
395
+ prompt,
396
+ return_tensors="pt",
397
+ truncation=True,
398
+ max_length=self.max_input_length,
399
+ )
400
+ enc = {k: v.to(self.model.device) for k, v in enc.items()}
695
401
 
696
- # 7) Build types→docs from (term→types) and (term→docs)
697
- types_to_doc_id_set: Dict[str, set] = {}
698
- for term_prediction in term_to_predicted_types_list:
699
- normalized_term = self._normalize_term(term_prediction["term"])
700
- doc_ids_for_term = term_to_doc_ids_map.get(normalized_term, [])
701
- for type_label in term_prediction.get("predicted_types", []):
702
- types_to_doc_id_set.setdefault(type_label, set()).update(
703
- doc_ids_for_term
704
- )
705
-
706
- types_to_doc_ids: Dict[str, List[str]] = {
707
- type_label: sorted(doc_id_set)
708
- for type_label, doc_id_set in types_to_doc_id_set.items()
709
- }
710
-
711
- # 8) Save outputs
712
- os.makedirs(os.path.dirname(out_terms2types) or ".", exist_ok=True)
713
- with open(out_terms2types, "w", encoding="utf-8") as fp_terms2types:
714
- json.dump(
715
- term_to_predicted_types_list,
716
- fp_terms2types,
717
- ensure_ascii=False,
718
- indent=2,
402
+ with torch.no_grad():
403
+ out = self.model.generate(
404
+ **enc,
405
+ max_new_tokens=self.max_new_tokens,
406
+ do_sample=False,
407
+ num_beams=1,
408
+ pad_token_id=self.tokenizer.eos_token_id,
719
409
  )
720
410
 
721
- os.makedirs(os.path.dirname(out_types2docs) or ".", exist_ok=True)
722
- with open(out_types2docs, "w", encoding="utf-8") as fp_types2docs:
723
- json.dump(types_to_doc_ids, fp_types2docs, ensure_ascii=False, indent=2)
411
+ gen_tokens = out[0][enc["input_ids"].shape[1] :]
412
+ return self.tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
724
413
 
725
- # Cleanup VRAM if any
726
- del typing_model, typing_tokenizer
727
- if torch.cuda.is_available():
728
- torch.cuda.empty_cache()
729
-
730
- return {
731
- "terms2types_pred": term_to_predicted_types_list,
732
- "types2docs_pred": types_to_doc_ids,
733
- "unique_terms": len(unique_terms),
734
- "types_count": len(types_to_doc_ids),
735
- }
736
-
737
- def _load_json(self, path: str) -> Dict[str, Any]:
738
- """Load a JSON file from disk and return its parsed object."""
739
- with open(path, "r", encoding="utf-8") as file_obj:
740
- return json.load(file_obj)
741
-
742
- def _iter_json_objects(self, blob: str) -> Iterable[Dict[str, Any]]:
414
+ def _retrieve_doc_fewshot(self, doc: Dict[str, Any]) -> List[Dict[str, Any]]:
743
415
  """
744
- Iterate over *all* JSON objects found inside a string.
745
-
746
- Supports cases where multiple JSON objects are concatenated back-to-back
747
- in a single line. It skips stray commas/whitespace between objects.
748
-
749
- Parameters
750
- ----------
751
- blob : str
752
- A string that may contain one or more JSON objects.
753
-
754
- Yields
755
- ------
756
- Dict[str, Any]
757
- Each parsed JSON object.
416
+ Retrieve top-k doc examples (JSON dicts) for few-shot doc->terms prompting.
758
417
  """
759
- json_decoder = json.JSONDecoder()
760
- cursor_index, text_length = 0, len(blob)
761
- while cursor_index < text_length:
762
- # Skip whitespace/commas between objects
763
- while cursor_index < text_length and blob[cursor_index] in " \t\r\n,":
764
- cursor_index += 1
765
- if cursor_index >= text_length:
766
- break
418
+ q = self._format_doc(doc.get("title", ""), doc.get("text", ""))
419
+ hits = self.doc_retriever.retrieve([q], top_k=self.top_k)[0]
420
+
421
+ out: List[Dict[str, Any]] = []
422
+ for h in hits:
767
423
  try:
768
- json_obj, end_index = json_decoder.raw_decode(blob, idx=cursor_index)
769
- except JSONDecodeError:
770
- # Can't decode from this position; stop scanning this chunk
771
- break
772
- yield json_obj
773
- cursor_index = end_index
774
-
775
- def _load_documents_jsonl(self, path: str) -> Dict[str, Dict[str, Any]]:
424
+ out.append(json.loads(h))
425
+ except Exception:
426
+ continue
427
+ return out
428
+
429
+ def _retrieve_term_fewshot(self, term: str) -> List[Dict[str, Any]]:
776
430
  """
777
- Robust reader that supports:
778
- • True JSONL (one object per line)
779
- • Lines with multiple concatenated JSON objects
780
- • Whole file as a JSON array
781
-
782
- Returns
783
- -------
784
- Dict[str, Dict[str, Any]]
785
- Mapping doc_id -> full document row.
431
+ Retrieve top-k term examples (JSON dicts) for few-shot term->types prompting.
786
432
  """
787
- documents_by_id: Dict[str, Dict[str, Any]] = {}
433
+ hits = self.term_retriever.retrieve([term], top_k=self.top_k)[0]
788
434
 
789
- with open(path, "r", encoding="utf-8") as file_obj:
790
- content = file_obj.read().strip()
791
-
792
- # Case A: whole-file JSON array
793
- if content.startswith("["):
435
+ out: List[Dict[str, Any]] = []
436
+ for h in hits:
794
437
  try:
795
- json_array = json.loads(content)
796
- if isinstance(json_array, list):
797
- for record in json_array:
798
- if not isinstance(record, dict):
799
- continue
800
- document_id = str(
801
- record.get("id")
802
- or record.get("doc_id")
803
- or (record.get("doc") or {}).get("id")
804
- or ""
805
- )
806
- if document_id:
807
- documents_by_id[document_id] = record
808
- return documents_by_id
438
+ out.append(json.loads(h))
809
439
  except Exception:
810
- # Fall back to line-wise handling if array parsing fails
811
- pass
812
-
813
- # Case B: treat as JSONL-ish; parse *all* objects per line
814
- for raw_line in content.splitlines():
815
- line = raw_line.strip()
816
- if not line:
817
440
  continue
818
- for record in self._iter_json_objects(line):
819
- if not isinstance(record, dict):
820
- continue
821
- document_id = str(
822
- record.get("id")
823
- or record.get("doc_id")
824
- or (record.get("doc") or {}).get("id")
825
- or ""
826
- )
827
- if document_id:
828
- documents_by_id[document_id] = record
829
-
830
- return documents_by_id
831
-
832
- def _to_text(self, text_field: Any) -> str:
833
- """
834
- Convert a 'text' field into a single string (handles list-of-strings).
835
-
836
- Parameters
837
- ----------
838
- text_field : Any
839
- The value found under "text" in the dataset row.
441
+ return out
840
442
 
841
- Returns
842
- -------
843
- str
844
- A single-string representation of the text.
443
+ def _doc_to_terms(self, doc: Dict[str, Any]) -> List[str]:
845
444
  """
846
- if isinstance(text_field, str):
847
- return text_field
848
- if isinstance(text_field, list):
849
- return " ".join(str(part) for part in text_field)
850
- return str(text_field) if text_field is not None else ""
851
-
852
- def _unique_preserve(self, values: List[str]) -> List[str]:
445
+ Predict terms for a document using few-shot prompting + doc retrieval.
853
446
  """
854
- Deduplicate values while preserving the original order.
447
+ fewshot = self._retrieve_doc_fewshot(doc)
855
448
 
856
- Parameters
857
- ----------
858
- values : List[str]
859
- Sequence possibly containing duplicates.
449
+ convo: List[Dict[str, str]] = [{"role": "system", "content": self.DOC2TERMS_SYSTEM_PROMPT}]
860
450
 
861
- Returns
862
- -------
863
- List[str]
864
- Sequence without duplicates, order preserved.
865
- """
866
- seen_values: set = set()
867
- ordered_values: List[str] = []
868
- for candidate in values:
869
- if candidate not in seen_values:
870
- seen_values.add(candidate)
871
- ordered_values.append(candidate)
872
- return ordered_values
873
-
874
- def _norm(self, text: str) -> str:
875
- """
876
- Lowercased, single-spaced normalization (for comparisons).
451
+ for ex in fewshot:
452
+ ex_tfidf = ex.get("TF-IDF") or ex.get("tfidf_terms") or []
453
+ convo += [
454
+ {
455
+ "role": "user",
456
+ "content": self._format_doc(
457
+ ex.get("title", ""),
458
+ ex.get("text", ""),
459
+ ex_tfidf if self.use_tfidf else None,
460
+ ),
461
+ },
462
+ {
463
+ "role": "assistant",
464
+ "content": json.dumps({"terms": ex.get("OL", [])}, ensure_ascii=False),
465
+ },
466
+ ]
877
467
 
878
- Parameters
879
- ----------
880
- text : str
881
- Input string.
468
+ tfidf = doc.get("TF-IDF") or doc.get("tfidf_terms") or []
469
+ convo.append(
470
+ {
471
+ "role": "user",
472
+ "content": self._format_doc(
473
+ doc.get("title", ""),
474
+ doc.get("text", ""),
475
+ tfidf if self.use_tfidf else None,
476
+ ),
477
+ }
478
+ )
882
479
 
883
- Returns
884
- -------
885
- str
886
- Normalized string.
887
- """
888
- return " ".join(text.lower().split())
480
+ prompt = self._apply_chat_template(convo)
481
+ gen = self._generate(prompt)
482
+ parsed = self._extract_first_json_obj(gen) or {}
483
+ return self._dedup_clean(parsed.get("terms", []))
889
484
 
890
- def _normalize_term(self, term: str) -> str:
485
+ def _term_to_types(self, term: str) -> List[str]:
486
+ """
487
+ Predict types for a term using few-shot prompting + term retrieval.
891
488
  """
892
- Normalization tailored for term keys / lookups.
489
+ fewshot = self._retrieve_term_fewshot(term)
893
490
 
894
- Parameters
895
- ----------
896
- term : str
897
- Term to normalize.
491
+ system = self.TERM2TYPES_SYSTEM_PROMPT
492
+ if self.restrict_to_known_types and self._allowed_types:
493
+ allowed_block = "\n".join(f"- {t}" for t in self._allowed_types)
494
+ system = (
495
+ system
496
+ + "\n\nIMPORTANT CONSTRAINT:\n"
497
+ + "Choose ONLY from the following valid ontology types (do not invent new labels):\n"
498
+ + allowed_block
499
+ )
898
500
 
899
- Returns
900
- -------
901
- str
902
- Lowercased, trimmed and single-spaced term.
903
- """
904
- return " ".join(str(term).strip().split()).lower()
501
+ convo: List[Dict[str, str]] = [{"role": "system", "content": system}]
905
502
 
906
- def _format_fewshot_block(
907
- self,
908
- system_prompt: str,
909
- fewshot_examples: List[Tuple[str, str, List[str]]],
910
- *,
911
- key: str,
912
- k: int = 6,
913
- ) -> str:
914
- """
915
- Render a few-shot block like:
503
+ for ex in fewshot:
504
+ convo += [
505
+ {"role": "user", "content": f"Term: {ex.get('term','')}"},
506
+ {
507
+ "role": "assistant",
508
+ "content": json.dumps({"types": ex.get("types", [])}, ensure_ascii=False),
509
+ },
510
+ ]
916
511
 
917
- <SYSTEM PROMPT>
512
+ convo.append({"role": "user", "content": f"Term: {term}"})
918
513
 
919
- ### Example
920
- User:
921
- Title: ...
922
- <text>
923
- Assistant:
924
- {"terms": [...]} or {"types": [...]}
514
+ prompt = self._apply_chat_template(convo)
515
+ gen = self._generate(prompt)
516
+ parsed = self._extract_first_json_obj(gen) or {}
517
+ return self._dedup_clean(parsed.get("types", []))
925
518
 
926
- Parameters
927
- ----------
928
- system_prompt : str
929
- Instructional system text to prepend.
930
- fewshot_examples : List[Tuple[str, str, List[str]]]
931
- Examples as (title, text, labels_list).
932
- key : str
933
- Either "terms" or "types" depending on the task.
934
- k : int
935
- Number of examples to include.
936
-
937
- Returns
938
- -------
939
- str
940
- Formatted few-shot block text.
519
+ def _text2onto(self, data: Any, test: bool = False) -> Optional[Any]:
941
520
  """
942
- lines: List[str] = [system_prompt.strip(), ""]
943
- for example_title, example_text, gold_list in fewshot_examples[:k]:
944
- lines.append("### Example")
945
- lines.append(f"User:\nTitle: {example_title}\n{example_text}")
946
- lines.append(
947
- f'Assistant:\n{{"{key}": '
948
- + json.dumps(gold_list, ensure_ascii=False)
949
- + "}"
950
- )
951
- return "\n".join(lines)
521
+ Train or predict for task="text2onto".
952
522
 
953
- def _format_user_block(self, title: str, text: str) -> str:
523
+ Returns:
524
+ - training: None
525
+ - prediction: {"terms": [...], "types": [...]}
954
526
  """
955
- Format the 'Task' block for the current document.
527
+ if not self._loaded:
528
+ self.load(model_id=self.llm_model_id, device=self.device)
956
529
 
957
- Parameters
958
- ----------
959
- title : str
960
- Document title.
961
- text : str
962
- Document text (single string).
963
-
964
- Returns
965
- -------
966
- str
967
- Formatted user block.
968
- """
969
- return f"### Task\nUser:\nTitle: {title}\n{text}"
530
+ if not isinstance(data, dict):
531
+ raise ValueError("text2onto expects a dict with documents + mappings.")
970
532
 
971
- def _parse_json_list(self, generated_text: str, *, key: str) -> List[str]:
972
- """
973
- Extract a list from model output, trying:
974
- 1) JSON object with the key ({"terms":[...]} or {"types":[...]}).
975
- 2) Any top-level JSON array.
976
- 3) Fallback: comma-split.
533
+ docs = self._extract_documents(data)
977
534
 
978
- Parameters
979
- ----------
980
- generated_text : str
981
- Raw generation text to parse.
982
- key : str
983
- "terms" or "types".
984
-
985
- Returns
986
- -------
987
- List[str]
988
- Parsed strings (best-effort).
989
- """
990
- # 1) Try a JSON object and read key
991
- try:
992
- object_match = self._json_object_regex.search(generated_text)
993
- if object_match:
994
- json_obj = json.loads(object_match.group(0))
995
- json_array = json_obj.get(key)
996
- if isinstance(json_array, list):
997
- return [value for value in json_array if isinstance(value, str)]
998
- except Exception:
999
- pass
1000
-
1001
- # 2) Any JSON array
1002
- try:
1003
- array_match = self._json_array_regex.search(generated_text)
1004
- if array_match:
1005
- json_array = json.loads(array_match.group(0))
1006
- if isinstance(json_array, list):
1007
- return [value for value in json_array if isinstance(value, str)]
1008
- except Exception:
1009
- pass
1010
-
1011
- # 3) Fallback: comma-split (last resort)
1012
- if "," in generated_text:
1013
- return [
1014
- part.strip().strip('"').strip("'")
1015
- for part in generated_text.split(",")
1016
- if part.strip()
1017
- ]
1018
- return []
535
+ raw_terms2docs = data.get("terms2docs") or data.get("term2docs") or {}
536
+ terms2types = data.get("terms2types") or data.get("term2types") or {}
1019
537
 
1020
- def _apply_chat_template_safe_types(
1021
- self, tokenizer: AutoTokenizer, messages: List[Dict[str, str]]
1022
- ) -> str:
1023
- """
1024
- Safely build a prompt string for chat models. Uses the model's chat template
1025
- when available; otherwise falls back to a simple concatenation.
1026
- """
1027
- try:
1028
- return tokenizer.apply_chat_template(
1029
- messages, add_generation_prompt=True, tokenize=False
1030
- )
1031
- except Exception:
1032
- system_text = next(
1033
- (m["content"] for m in messages if m.get("role") == "system"), ""
1034
- )
1035
- last_user_text = next(
1036
- (m["content"] for m in reversed(messages) if m.get("role") == "user"),
1037
- "",
1038
- )
1039
- return f"{system_text}\n\nUser:\n{last_user_text}\n\nAssistant:"
538
+ terms2docs = self._normalize_terms2docs(raw_terms2docs, docs)
1040
539
 
1041
- def _build_conv_for_type_infer(
1042
- self,
1043
- term: str,
1044
- few_shot_examples: Optional[List[Dict]] = None,
1045
- random_k: Optional[int] = None,
1046
- ) -> List[Dict[str, str]]:
1047
- """
1048
- Create a chat-style conversation for a single term→types query,
1049
- optionally prepending few-shot examples.
1050
- """
1051
- messages: List[Dict[str, str]] = [
1052
- {"role": "system", "content": self._system_prompt_term_to_types}
1053
- ]
1054
- examples = list(few_shot_examples or [])
1055
- if random_k and len(examples) > random_k:
1056
- import random as _rnd
1057
-
1058
- examples = _rnd.sample(examples, random_k)
1059
- for exemplar in examples:
1060
- example_term = exemplar.get("term", "")
1061
- example_types = exemplar.get("types", [])
1062
- messages.append({"role": "user", "content": f"Term: {example_term}"})
1063
- messages.append(
540
+ if not test:
541
+ self._allowed_types = sorted(
1064
542
  {
1065
- "role": "assistant",
1066
- "content": json.dumps({"types": example_types}, ensure_ascii=False),
543
+ ty.strip()
544
+ for tys in (terms2types or {}).values()
545
+ for ty in (tys or [])
546
+ if isinstance(ty, str) and ty.strip()
1067
547
  }
1068
548
  )
1069
- messages.append({"role": "user", "content": f"Term: {term}"})
1070
- return messages
1071
549
 
1072
- def _extract_types_from_text(self, generated_text: str) -> List[str]:
1073
- """
1074
- Parse {"types":[...]} from a free-form generation.
1075
- """
1076
- try:
1077
- object_match = re.search(r'\{[^}]*"types"[^}]*\}', generated_text)
1078
- if object_match:
1079
- json_obj = json.loads(object_match.group(0))
1080
- types_array = json_obj.get("types", [])
1081
- return [
1082
- type_label
1083
- for type_label in types_array
1084
- if isinstance(type_label, str)
1085
- ]
1086
- except Exception:
1087
- pass
1088
- return []
1089
-
1090
- def _load_llm_for_types(
1091
- self, model_id: str
1092
- ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
1093
- """
1094
- Load a *separate* small chat model for Term→Types (keeps LocalAutoLLM untouched).
1095
- """
1096
- tokenizer = AutoTokenizer.from_pretrained(model_id)
1097
- if tokenizer.pad_token is None:
1098
- tokenizer.pad_token = tokenizer.eos_token
1099
- model = AutoModelForCausalLM.from_pretrained(
1100
- model_id,
1101
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
1102
- device_map="auto" if torch.cuda.is_available() else None,
1103
- )
1104
- return model, tokenizer
550
+ # build doc->terms from term->docs
551
+ doc2terms: Dict[str, List[str]] = defaultdict(list)
552
+ for term, doc_ids in (terms2docs or {}).items():
553
+ for did in (doc_ids or []):
554
+ doc2terms[str(did)].append(term)
555
+
556
+ # doc few-shot corpus: doc + gold OL terms
557
+ doc_examples: List[Dict[str, Any]] = []
558
+ for d in docs:
559
+ did = self._doc_id(d)
560
+ ex = dict(d)
561
+ ex["doc_id"] = did
562
+ ex["OL"] = self._dedup_clean(doc2terms.get(did, []))
563
+ doc_examples.append(ex)
564
+
565
+ # term few-shot corpus: term + gold types
566
+ term_examples = [
567
+ {"term": t, "types": self._dedup_clean(tys)}
568
+ for t, tys in (terms2types or {}).items()
569
+ ]
1105
570
 
1106
- def _load_doc_term_extractions(
1107
- self,
1108
- *,
1109
- results_json_path: Optional[str] = None,
1110
- in_memory_results: Optional[List[Dict]] = None,
1111
- ) -> List[Dict]:
1112
- """
1113
- Normalize document→terms outputs to a list of:
1114
- {"id": "<doc_id>", "extracted_terms": ["...", ...]}
1115
-
1116
- Accepts either:
1117
- - in_memory_results (list of dicts)
1118
- - results_json_path pointing to:
1119
- • a JSONL file with lines: {"id": "...", "terms": [...]}
1120
- • OR a JSON file with {"results":[{"id":..., "extracted_terms": [...]}, ...]}
1121
- • OR a JSON list of dicts
1122
- """
1123
- normalized_records: List[Dict] = []
571
+ # store as JSON strings so retrievers return parseable strings
572
+ self._doc_examples_json = [json.dumps(ex, ensure_ascii=False) for ex in doc_examples]
573
+ self._term_examples_json = [json.dumps(ex, ensure_ascii=False) for ex in term_examples]
1124
574
 
1125
- def _coerce_to_record(source_row: Dict) -> Optional[Dict]:
1126
- document_id = str(source_row.get("id", "")) or str(
1127
- source_row.get("doc_id", "")
1128
- )
1129
- if not document_id:
1130
- return None
1131
- terms = source_row.get("extracted_terms")
1132
- if terms is None:
1133
- terms = source_row.get("terms")
1134
- if (
1135
- terms is None
1136
- and "payload" in source_row
1137
- and isinstance(source_row["payload"], dict)
1138
- ):
1139
- terms = source_row["payload"].get("terms")
1140
- if not isinstance(terms, list):
1141
- terms = []
1142
- return {
1143
- "id": document_id,
1144
- "extracted_terms": [t for t in terms if isinstance(t, str)],
1145
- }
575
+ # index retrievers
576
+ self.doc_retriever.index(self._doc_examples_json)
577
+ self.term_retriever.index(self._term_examples_json)
578
+ return None
1146
579
 
1147
- if in_memory_results is not None:
1148
- for source_row in in_memory_results:
1149
- coerced_record = _coerce_to_record(source_row)
1150
- if coerced_record:
1151
- normalized_records.append(coerced_record)
1152
- return normalized_records
1153
-
1154
- if not results_json_path:
1155
- raise ValueError("Provide results_json_path or in_memory_results")
1156
-
1157
- # Detect JSON vs JSONL by extension (best-effort)
1158
- if results_json_path.endswith(".jsonl"):
1159
- with open(results_json_path, "r", encoding="utf-8") as file_in:
1160
- for raw_line in file_in:
1161
- raw_line = raw_line.strip()
1162
- if not raw_line:
1163
- continue
1164
- # Multiple concatenated objects per line? Iterate them all.
1165
- for json_obj in self._iter_json_objects(raw_line):
1166
- if isinstance(json_obj, dict):
1167
- coerced_record = _coerce_to_record(json_obj)
1168
- if coerced_record:
1169
- normalized_records.append(coerced_record)
1170
- else:
1171
- payload_obj = self._load_json(results_json_path)
1172
- if isinstance(payload_obj, dict) and "results" in payload_obj:
1173
- for source_row in payload_obj["results"]:
1174
- coerced_record = _coerce_to_record(source_row)
1175
- if coerced_record:
1176
- normalized_records.append(coerced_record)
1177
- elif isinstance(payload_obj, list):
1178
- for source_row in payload_obj:
1179
- if isinstance(source_row, dict):
1180
- coerced_record = _coerce_to_record(source_row)
1181
- if coerced_record:
1182
- normalized_records.append(coerced_record)
1183
-
1184
- return normalized_records
1185
-
1186
- def _collect_unique_terms_from_extractions(
1187
- self, doc_term_extractions: List[Dict]
1188
- ) -> List[str]:
1189
- """
1190
- Collect unique terms (original casing) from normalized document→terms results.
1191
- """
1192
- seen_normalized_terms: set = set()
1193
- ordered_unique_terms: List[str] = []
1194
- for record in doc_term_extractions:
1195
- for term_text in record.get("extracted_terms", []):
1196
- normalized = self._normalize_term(term_text)
1197
- if normalized and normalized not in seen_normalized_terms:
1198
- seen_normalized_terms.add(normalized)
1199
- ordered_unique_terms.append(term_text.strip())
1200
- return ordered_unique_terms
1201
-
1202
- def _build_term_to_doc_ids(
1203
- self, doc_term_extractions: List[Dict]
1204
- ) -> Dict[str, List[str]]:
1205
- """
1206
- Build lookup: normalized_term -> sorted unique list of doc_ids.
1207
- """
1208
- term_to_doc_set: Dict[str, set] = {}
1209
- for record in doc_term_extractions:
1210
- document_id = str(record.get("id", ""))
1211
- for term_text in record.get("extracted_terms", []):
1212
- normalized = self._normalize_term(term_text)
1213
- if not normalized or not document_id:
1214
- continue
1215
- term_to_doc_set.setdefault(normalized, set()).add(document_id)
1216
- return {
1217
- normalized_term: sorted(doc_ids)
1218
- for normalized_term, doc_ids in term_to_doc_set.items()
1219
- }
580
+ doc2terms_pred: Dict[str, List[str]] = {}
581
+ for d in docs:
582
+ did = self._doc_id(d)
583
+ doc2terms_pred[did] = self._doc_to_terms(d)
584
+
585
+ unique_terms = sorted({t for ts in doc2terms_pred.values() for t in ts})
586
+ term2types_pred: Dict[str, List[str]] = {t: self._term_to_types(t) for t in unique_terms}
587
+
588
+ doc2types_pred: Dict[str, List[str]] = {}
589
+ for did, terms in doc2terms_pred.items():
590
+ tys: List[str] = []
591
+ for t in terms:
592
+ tys.extend(term2types_pred.get(t, []))
593
+ doc2types_pred[did] = self._dedup_clean(tys)
594
+
595
+ pred_terms = [{"doc_id": did, "term": t} for did, ts in doc2terms_pred.items() for t in ts]
596
+ pred_types = [{"doc_id": did, "type": ty} for did, tys in doc2types_pred.items() for ty in tys]
597
+
598
+ return {"terms": pred_terms, "types": pred_types}