OntoLearner 1.4.10__py3-none-any.whl → 1.4.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,7 +4,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#      https://opensource.org/licenses/MIT
+# https://opensource.org/licenses/MIT
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,587 +12,592 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import json
-import random
-import re
 import ast
 import gc
-from typing import Any, Dict, List, Optional, Set, Tuple
+import random
+import re
 from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Optional, Set

 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

-from ...base import AutoLearner, AutoLLM
+from ...base import AutoLearner

-
-# -----------------------------------------------------------------------------
-# Concrete AutoLLM: local HF wrapper that follows the AutoLLM interface
-# -----------------------------------------------------------------------------
-class LocalAutoLLM(AutoLLM):
+class SBUNLPFewShotLearner(AutoLearner):
     """
-    Handles loading and generation for a Hugging Face Causal Language Model (Qwen/TinyLlama).
-    Uses 4-bit quantization for efficiency and greedy decoding by default.
+    Public API expected by the pipeline:
+      - `load(model_id=...)`
+      - `fit(train_data, task=..., ontologizer=...)`
+      - `predict(test_data, task=..., ontologizer=...)`
+
+    Expected input bundle format (train/test):
+      - "documents": list of dicts, each with keys: {"id", "title", "text"}
+      - "terms2docs": dict mapping term -> list of doc_ids
+      - "terms2types": optional dict mapping term -> list of types
+
+    Prediction output payload (pipeline wraps this):
+      - {"terms": [{"doc_id": str, "term": str}, ...],
+         "types": [{"doc_id": str, "type": str}, ...]}
     """

     def __init__(
-        self, label_mapper: Any = None, device: str = "cpu", token: str = ""
-    ) -> None:
-        super().__init__(label_mapper=label_mapper, device=device, token=token)
-        self.model = None
-        self.tokenizer = None
-
-    def load(
         self,
-        model_id: str,
+        llm_model_id: Optional[str] = None,
+        device: str = "cpu",
         load_in_4bit: bool = False,
-        dtype: str = "auto",
+        max_new_tokens: int = 256,
         trust_remote_code: bool = True,
-    ):
-        """Load tokenizer + model, applying 4-bit quantization if specified and possible."""
+    ) -> None:
+        """
+        Initialize the few-shot learner.
+
+        Args:
+            llm_model_id: Default HF model id to load if `load()` is called without one.
+            device: "cpu" or a CUDA device identifier (e.g. "cuda").
+            load_in_4bit: Whether to attempt 4-bit quantized loading (bitsandbytes).
+            max_new_tokens: Maximum tokens to generate per prompt.
+            retriever_model_id: Unused (kept for compatibility).
+            top_k: Unused (kept for compatibility).
+            trust_remote_code: Forwarded to HF loaders (use with caution).
+        """
+        super().__init__()
+        self.device = device
+        self.max_new_tokens = int(max_new_tokens)

-        # Determine the target data type (default to float32 for CPU, float16 for GPU)
-        torch_dtype_val = torch.float16 if torch.cuda.is_available() else torch.float32
+        self._default_model_id = llm_model_id
+        self._load_in_4bit_default = bool(load_in_4bit)
+        self._trust_remote_code_default = bool(trust_remote_code)

-        # Load the tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_id, trust_remote_code=trust_remote_code
-        )
-        if self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
+        # HF objects
+        self.model: Optional[AutoModelForCausalLM] = None
+        self.tokenizer: Optional[AutoTokenizer] = None
+
+        self._is_loaded = False
+        self._loaded_model_id: Optional[str] = None

-        quant_config = None
+        # Cached few-shot example blocks built during `fit()`
+        self.few_shot_terms_block: str = ""
+        self.few_shot_types_block: str = ""
+
+    def load(self, model_id: Optional[str] = None, **kwargs: Any) -> None:
+        """
+        Load the underlying HF causal LM and tokenizer.
+
+        LearnerPipeline typically calls: `learner.load(model_id=llm_id)`.
+
+        Args:
+            model_id: HF model id. If None, uses `llm_model_id` from __init__.
+            **kwargs:
+                load_in_4bit: override default 4-bit loading.
+                trust_remote_code: override default trust_remote_code.
+        """
+        resolved_model_id = model_id or self._default_model_id
+        if not resolved_model_id:
+            raise ValueError(
+                f"No model_id provided to {self.__class__.__name__}.load() and no llm_model_id in __init__."
+            )
+
+        load_in_4bit = bool(kwargs.get("load_in_4bit", self._load_in_4bit_default))
+        trust_remote_code = bool(kwargs.get("trust_remote_code", self._trust_remote_code_default))
+
+        # Avoid re-loading same model
+        if self._is_loaded and self._loaded_model_id == resolved_model_id:
+            return
+
+        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+        tokenizer = AutoTokenizer.from_pretrained(resolved_model_id, trust_remote_code=trust_remote_code)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        self.tokenizer = tokenizer
+
+        quantization_config = None
         if load_in_4bit:
-            # Configure BitsAndBytes for 4-bit loading
-            quant_config = BitsAndBytesConfig(
+            quantization_config = BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_compute_dtype=torch.float16,
                 bnb_4bit_use_double_quant=True,
                 bnb_4bit_quant_type="nf4",
             )
-            if torch_dtype_val is None:
-                torch_dtype_val = torch.float16
+            torch_dtype = torch.float16

-        # Set device mapping (auto for multi-GPU or single GPU, explicit CPU otherwise)
         device_map = "auto" if (self.device != "cpu") else {"": "cpu"}

-        # Load the Causal Language Model
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_id,
+        model = AutoModelForCausalLM.from_pretrained(
+            resolved_model_id,
             device_map=device_map,
-            torch_dtype=torch_dtype_val,
-            quantization_config=quant_config,
+            torch_dtype=torch_dtype,
+            quantization_config=quantization_config,
             trust_remote_code=trust_remote_code,
         )

-        # Ensure model is on the correct device (redundant if device_map="auto" but safe)
         if self.device == "cpu":
-            self.model.to("cpu")
+            model.to("cpu")

-    def generate(
-        self,
-        inputs: List[str],
-        max_new_tokens: int = 64,
-        temperature: float = 0.0,
-        top_p: float = 1.0,
-    ) -> List[str]:
-        """Generate continuations for a list of prompts, returning only the generated part."""
-        if self.model is None or self.tokenizer is None:
-            raise RuntimeError("Model/tokenizer not loaded. Call .load() first.")
+        self.model = model
+        self._is_loaded = True
+        self._loaded_model_id = resolved_model_id
+
+    def _invert_terms_to_docs_mapping(self, terms_to_documents: Dict[str, List[str]]) -> Dict[str, List[str]]:
+        """
+        Convert term->docs mapping to doc->terms mapping.

-        # --- Generation Setup ---
-        # Tokenize batch (padding is essential for batch inference)
-        enc = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True)
-        input_ids = enc["input_ids"]
-        attention_mask = enc["attention_mask"]
+        Args:
+            terms_to_documents: Mapping from term to list of document IDs.

-        # Move tensors to the model's device (e.g., cuda:0)
-        model_device = next(self.model.parameters()).device
-        input_ids = input_ids.to(model_device)
-        attention_mask = attention_mask.to(model_device)
+        Returns:
+            Mapping from document ID to list of terms associated with it.
+        """
+        document_to_terms: DefaultDict[str, List[str]] = defaultdict(list)
+        for term, document_ids in (terms_to_documents or {}).items():
+            for document_id in document_ids or []:
+                document_to_terms[str(document_id)].append(str(term))
+        return dict(document_to_terms)

-        # --- Generate ---
-        with torch.no_grad():
-            outputs = self.model.generate(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                max_new_tokens=max_new_tokens,
-                do_sample=(
-                    temperature > 0.0
-                ),  # Use greedy decoding if temperature is 0.0
-                temperature=temperature,
-                top_p=top_p,
-                pad_token_id=self.tokenizer.eos_token_id,
-            )
+    def _derive_document_to_types(
+        self,
+        terms_to_documents: Dict[str, List[str]],
+        terms_to_types: Optional[Dict[str, List[str]]],
+    ) -> Dict[str, List[str]]:
+        """
+        Derive doc->types mapping using (terms->docs) and (terms->types).

-        # --- Post-processing: Extract only the generated tail ---
-        decoded_outputs: List[str] = []
-        for i, output_ids in enumerate(outputs):
-            full_decoded_text = self.tokenizer.decode(
-                output_ids, skip_special_tokens=True
-            )
-            prompt_text = self.tokenizer.decode(input_ids[i], skip_special_tokens=True)
+        Args:
+            terms_to_documents: term -> [doc_id...]
+            terms_to_types: term -> [type...]

-            # Safely strip the prompt text from the full output
-            if full_decoded_text.startswith(prompt_text):
-                generated_tail = full_decoded_text[len(prompt_text) :].strip()
-            else:
-                # Fallback extraction (less robust if padding affects token indices)
-                prompt_len = input_ids.shape[1]
-                generated_tail = self.tokenizer.decode(
-                    output_ids[prompt_len:], skip_special_tokens=True
-                ).strip()
-            decoded_outputs.append(generated_tail)
+        Returns:
+            doc_id -> sorted list of unique types.
+        """
+        if not terms_to_types:
+            return {}

-        return decoded_outputs
+        document_to_types: DefaultDict[str, Set[str]] = defaultdict(set)

+        for term, document_ids in (terms_to_documents or {}).items():
+            candidate_types = terms_to_types.get(term, []) or []
+            for document_id in document_ids or []:
+                for candidate_type in candidate_types:
+                    if isinstance(candidate_type, str) and candidate_type.strip():
+                        document_to_types[str(document_id)].add(candidate_type.strip())

-# -----------------------------------------------------------------------------
-# Main Learner: SBUNLPFewShotLearner (Task A Text2Onto)
-# -----------------------------------------------------------------------------
-class SBUNLPFewShotLearner(AutoLearner):
-    """
-    Concrete learner implementing the Task A Text2Onto pipeline (Term and Type Extraction).
-    It uses Few-Shot prompts generated from training data for inference.
-    """
+        return {doc_id: sorted(list(type_set)) for doc_id, type_set in document_to_types.items()}

-    def __init__(self, model: Optional[AutoLLM] = None, device: str = "cpu"):
-        super().__init__()
-        # self.model is an instance of LocalAutoLLM
-        self.model = model or LocalAutoLLM(device=device)
-        self.device = device
-        # Cached in-memory prompt blocks built during the fit phase
-        self.fewshot_terms_block: str = ""
-        self.fewshot_types_block: str = ""
+    def _truncate_text(self, text: str, max_chars: int) -> str:
+        """
+        Truncate text to a maximum number of characters (adds an ellipsis when truncated).
+
+        Args:
+            text: Input text.
+            max_chars: Maximum characters to keep. If <= 0, returns the original text.
+
+        Returns:
+            Truncated or original text.
+        """
+        if not max_chars or max_chars <= 0 or not text:
+            return text or ""
+        return (text[:max_chars] + "…") if len(text) > max_chars else text

-    # --- Few-shot construction (terms) ---
-    def build_stratified_fewshot_prompt(
+    def build_few_shot_terms_block(
         self,
-        documents_path: str,
-        terms_path: str,
+        documents: List[Dict[str, Any]],
+        terms_to_documents: Dict[str, List[str]],
         sample_size: int = 28,
         seed: int = 123,
         max_chars_per_text: int = 1200,
     ) -> str:
         """
-        Builds the few-shot exemplar block for Term Extraction using stratified sampling.
+        Build and cache the few-shot block for term extraction.
+
+        Strategy:
+          - Create strata by associated terms (doc -> associated term list).
+          - Sample proportionally across strata.
+          - Deduplicate by document id and top up from remaining docs if needed.
+
+        Args:
+            documents: Documents with keys: {"id","title","text"}.
+            terms_to_documents: Mapping term -> list of doc IDs.
+            sample_size: Desired number of examples in the block.
+            seed: RNG seed (local to this call).
+            max_chars_per_text: Text truncation limit per example.
+
+        Returns:
+            The formatted few-shot example block string.
         """
-        random.seed(seed)
-
-        # Read documents (JSONL) into a list
-        corpus_documents: List[Dict[str, Any]] = []
-        with open(documents_path, "r", encoding="utf-8") as file_handle:
-            for line in file_handle:
-                if line.strip():
-                    corpus_documents.append(json.loads(line))
-
-        num_total_docs = len(corpus_documents)
-        num_sample_docs = min(sample_size, num_total_docs)
-
-        # Load the map of term -> [list of document IDs]
-        with open(terms_path, "r", encoding="utf-8") as file_handle:
-            term_to_doc_map = json.load(file_handle)
-
-        # Invert map: document ID -> [list of terms]
-        doc_id_to_terms_map = defaultdict(list)
-        for term, doc_ids in term_to_doc_map.items():
-            for doc_id in doc_ids:
-                doc_id_to_terms_map[doc_id].append(term)
-
-        # Define strata (groups of documents associated with specific terms)
-        strata_map = defaultdict(list)
-        for doc in corpus_documents:
-            doc_id = doc.get("id", "")
-            associated_terms = doc_id_to_terms_map.get(doc_id, ["no_term"])
+        rng = random.Random(seed)
+
+        document_to_terms = self._invert_terms_to_docs_mapping(terms_to_documents)
+        total_documents = len(documents)
+        target_sample_count = min(int(sample_size), total_documents)
+
+        strata: DefaultDict[str, List[Dict[str, Any]]] = defaultdict(list)
+        for document in documents:
+            document_id = str(document.get("id", ""))
+            associated_terms = document_to_terms.get(document_id, ["no_term"])
             for term in associated_terms:
-                strata_map[term].append(doc)
+                strata[str(term)].append(document)

-        # Perform proportional sampling across strata
         sampled_documents: List[Dict[str, Any]] = []
-        for term_str, stratum_docs in strata_map.items():
-            num_stratum_docs = len(stratum_docs)
-            if num_stratum_docs == 0:
+        for docs_in_stratum in strata.values():
+            if not docs_in_stratum:
                 continue
-
-            # Calculate proportional sample size
-            proportion = num_stratum_docs / num_total_docs
-            num_to_sample_from_stratum = int(num_sample_docs * proportion)
-
-            if num_to_sample_from_stratum > 0:
-                sampled_documents.extend(
-                    random.sample(
-                        stratum_docs, min(num_to_sample_from_stratum, num_stratum_docs)
-                    )
+            proportion = len(docs_in_stratum) / max(1, total_documents)
+            stratum_quota = int(target_sample_count * proportion)
+            if stratum_quota > 0:
+                sampled_documents.extend(rng.sample(docs_in_stratum, min(stratum_quota, len(docs_in_stratum))))
+
+        sampled_by_id = {str(d.get("id", "")): d for d in sampled_documents if d.get("id", "")}
+        final_documents = list(sampled_by_id.values())
+
+        if len(final_documents) > target_sample_count:
+            final_documents = rng.sample(final_documents, target_sample_count)
+        elif len(final_documents) < target_sample_count:
+            remaining_documents = [d for d in documents if str(d.get("id", "")) not in sampled_by_id]
+            additional_needed = min(target_sample_count - len(final_documents), len(remaining_documents))
+            if additional_needed > 0:
+                final_documents.extend(rng.sample(remaining_documents, additional_needed))
+
+        lines: List[str] = []
+        for document in final_documents:
+            document_id = str(document.get("id", ""))
+            title = str(document.get("title", ""))
+            text = self._truncate_text(str(document.get("text", "")), max_chars_per_text)
+            associated_terms = document_to_terms.get(document_id, [])
+
+            lines.append(
+                "Document ID: {doc_id}\n"
+                "Title: {title}\n"
+                "Text: {text}\n"
+                "Associated Terms: {terms}\n"
+                "----------------------------------------".format(
+                    doc_id=document_id,
+                    title=title,
+                    text=text,
+                    terms=associated_terms,
                 )
-
-        # Deduplicate sampled documents by ID and adjust count to exactly 'sample_size'
-        unique_docs_by_id = {}
-        for doc in sampled_documents:
-            unique_docs_by_id[doc.get("id", "")] = doc
-
-        final_sample_docs = list(unique_docs_by_id.values())
-
-        if len(final_sample_docs) > num_sample_docs:
-            final_sample_docs = random.sample(final_sample_docs, num_sample_docs)
-        elif len(final_sample_docs) < num_sample_docs:
-            remaining_docs = [
-                d for d in corpus_documents if d.get("id", "") not in unique_docs_by_id
-            ]
-            needed_count = min(
-                num_sample_docs - len(final_sample_docs), len(remaining_docs)
-            )
-            final_sample_docs.extend(random.sample(remaining_docs, needed_count))
-
-        # Format the few-shot exemplar text block
-        prompt_lines: List[str] = []
-        for doc in final_sample_docs:
-            doc_id = doc.get("id", "")
-            title = doc.get("title", "")
-            text = doc.get("text", "")
-
-            # Truncate text if it exceeds the maximum character limit
-            if max_chars_per_text and len(text) > max_chars_per_text:
-                text = text[:max_chars_per_text] + "…"
-
-            associated_terms = doc_id_to_terms_map.get(doc_id, [])
-            prompt_lines.append(
-                f"Document ID: {doc_id}\nTitle: {title}\nText: {text}\nAssociated Terms: {associated_terms}\n----------------------------------------"
             )

-        prompt_block = "\n".join(prompt_lines)
-        self.fewshot_terms_block = prompt_block
-        return prompt_block
+        self.few_shot_terms_block = "\n".join(lines)
+        return self.few_shot_terms_block

-    # --- Few-shot construction (types) ---
-    def build_types_fewshot_block(
+    def build_few_shot_types_block(
         self,
-        docs_jsonl: str,
-        terms2doc_json: str,
-        sample_per_term: int = 1,
-        full_word: bool = True,
-        case_sensitive: bool = True,
+        documents: List[Dict[str, Any]],
+        terms_to_documents: Dict[str, List[str]],
+        terms_to_types: Optional[Dict[str, List[str]]] = None,
+        sample_size: int = 28,
+        seed: int = 123,
         max_chars_per_text: int = 800,
     ) -> str:
         """
-        Builds the few-shot block for Type Extraction.
-        This method samples documents based on finding an associated term/type within the text.
+        Build and cache the few-shot block for type (class) extraction.
+
+        Prefers doc->types derived from `terms_to_types`; if absent, falls back to treating
+        associated terms as "types" for stratification (behavior-preserving fallback).
+
+        Args:
+            documents: Documents with keys: {"id","title","text"}.
+            terms_to_documents: Mapping term -> list of doc IDs.
+            terms_to_types: Optional mapping term -> list of types.
+            sample_size: Desired number of examples in the block.
+            seed: RNG seed (local to this call).
+            max_chars_per_text: Text truncation limit per example.
+
+        Returns:
+            The formatted few-shot example block string.
         """
-        # Load documents into dict by ID
-        docs_by_id = {}
-        with open(docs_jsonl, "r", encoding="utf-8") as file_handle:
-            for line in file_handle:
-                line_stripped = line.strip()
-                if line_stripped:
-                    try:
-                        doc = json.loads(line_stripped)
-                        doc_id = doc.get("id", "")
-                        if doc_id:
-                            docs_by_id[doc_id] = doc
-                    except json.JSONDecodeError:
-                        continue
-
-        # Load term -> [doc_id,...] map
-        with open(terms2doc_json, "r", encoding="utf-8") as file_handle:
-            term_to_doc_map = json.load(file_handle)
-
-        flags = 0 if case_sensitive else re.IGNORECASE
-        prompt_lines: List[str] = []
-
-        # Iterate over terms (which act as types in this context)
-        for term, doc_ids in term_to_doc_map.items():
-            escaped_term = re.escape(term)
-            # Create regex pattern for matching the term in the text
-            pattern = rf"\b{escaped_term}\b" if full_word else escaped_term
-            term_regex = re.compile(pattern, flags=flags)
-
-            picked_count = 0
-            for doc_id in doc_ids:
-                doc = docs_by_id.get(doc_id)
-                if not doc:
-                    continue
-
-                title = doc.get("title", "")
-                text = doc.get("text", "")
-
-                # Check if the term/type is actually present in the document text/title
-                if term_regex.search(f"{title} {text}"):
-                    text_content = text
-
-                    # Truncate text if necessary
-                    if max_chars_per_text and len(text_content) > max_chars_per_text:
-                        text_content = text_content[:max_chars_per_text] + "…"
-
-                    # Escape single quotes in the term for Python list formatting in the prompt
-                    term_for_prompt = term.replace("'", "\\'")
-
-                    prompt_lines.append(
-                        f"Document ID: {doc_id}\nTitle: {title}\nText: {text_content}\nAssociated Types: ['{term_for_prompt}']\n----------------------------------------"
-                    )
-                    picked_count += 1
-
-                if picked_count >= sample_per_term:
-                    break  # Move to the next term
-
-        prompt_block = "\n".join(prompt_lines)
-        self.fewshot_types_block = prompt_block
-        return prompt_block
+        rng = random.Random(seed)

-    def fit(
-        self,
-        train_docs_jsonl: str,
-        terms2doc_json: str,
-        sample_size: int = 28,
-        seed: int = 123,
-    ) -> None:
+        documents_by_id = {str(d.get("id", "")): d for d in documents if d.get("id", "")}
+
+        document_to_types = self._derive_document_to_types(terms_to_documents, terms_to_types)
+        if not document_to_types:
+            document_to_types = self._invert_terms_to_docs_mapping(terms_to_documents)
+
+        type_to_documents: DefaultDict[str, List[Dict[str, Any]]] = defaultdict(list)
+        for document_id, candidate_types in document_to_types.items():
+            document = documents_by_id.get(document_id)
+            if not document:
+                continue
+            for candidate_type in candidate_types:
+                type_to_documents[str(candidate_type)].append(document)
+
+        total_documents = len(documents)
+        target_sample_count = min(int(sample_size), total_documents)
+
+        sampled_documents: List[Dict[str, Any]] = []
+        for docs_in_stratum in type_to_documents.values():
+            if not docs_in_stratum:
+                continue
+            proportion = len(docs_in_stratum) / max(1, total_documents)
+            stratum_quota = int(target_sample_count * proportion)
+            if stratum_quota > 0:
+                sampled_documents.extend(rng.sample(docs_in_stratum, min(stratum_quota, len(docs_in_stratum))))
+
+        sampled_by_id = {str(d.get("id", "")): d for d in sampled_documents if d.get("id", "")}
+        final_documents = list(sampled_by_id.values())
+
+        if len(final_documents) > target_sample_count:
+            final_documents = rng.sample(final_documents, target_sample_count)
+        elif len(final_documents) < target_sample_count:
+            remaining_documents = [d for d in documents if str(d.get("id", "")) not in sampled_by_id]
+            additional_needed = min(target_sample_count - len(final_documents), len(remaining_documents))
+            if additional_needed > 0:
+                final_documents.extend(rng.sample(remaining_documents, additional_needed))
+
+        lines: List[str] = []
+        for document in final_documents:
+            document_id = str(document.get("id", ""))
+            title = str(document.get("title", ""))
+            text = self._truncate_text(str(document.get("text", "")), max_chars_per_text)
+
+            associated_types = document_to_types.get(document_id, [])
+            associated_types_escaped = [t.replace("'", "\\'") for t in associated_types]
+
+            lines.append(
+                "Document ID: {doc_id}\n"
+                "Title: {title}\n"
+                "Text: {text}\n"
+                "Associated Types: {types}\n"
+                "----------------------------------------".format(
+                    doc_id=document_id,
+                    title=title,
+                    text=text,
+                    types=associated_types_escaped,
+                )
+            )
+
+        self.few_shot_types_block = "\n".join(lines)
+        return self.few_shot_types_block
+
+    def _format_term_prompt(self, example_block: str, title: str, text: str) -> str:
         """
-        Fit phase: Builds and caches the few-shot prompt blocks from the training files.
-        No model training occurs (Few-Shot/In-Context Learning).
+        Format a prompt for term extraction.
+
+        Args:
+            example_block: Few-shot examples block.
+            title: Document title.
+            text: Document text.
+
+        Returns:
+            Prompt string.
         """
-        # Build prompt block for Term extraction
-        _ = self.build_stratified_fewshot_prompt(
-            train_docs_jsonl, terms2doc_json, sample_size=sample_size, seed=seed
+        return (
+            f"{example_block}\n"
+            "[var]\n"
+            f"Title: {title}\n"
+            f"Text: {text}\n"
+            "[var]\n"
+            "Extract all relevant terms that could form the basis of an ontology from the above document.\n"
+            "Return ONLY a Python list like ['term1', 'term2', ...] and nothing else.\n"
+            "If no terms are found, return [].\n"
         )
-        # Build prompt block for Type extraction
-        _ = self.build_types_fewshot_block(
-            train_docs_jsonl, terms2doc_json, sample_per_term=1
+
+    def _format_type_prompt(self, example_block: str, title: str, text: str) -> str:
+        """
+        Format a prompt for type (class) extraction.
+
+        Args:
+            example_block: Few-shot examples block.
+            title: Document title.
+            text: Document text.
+
+        Returns:
+            Prompt string.
+        """
+        return (
+            f"{example_block}\n"
+            "[var]\n"
+            f"Title: {title}\n"
+            f"Text: {text}\n"
+            "[var]\n"
+            "Extract all relevant TYPES mentioned in the above document that could serve as ontology classes.\n"
+            "Only consider content inside the [var] ... [var] block.\n"
+            "Return ONLY a valid Python list like ['type1', 'type2'] and nothing else. If none, return [].\n"
        )

-    # -------------------------
-    # Inference helpers (prompt construction and output parsing)
-    # -------------------------
-    def _build_term_prompt(self, example_block: str, title: str, text: str) -> str:
-        """Constructs the full prompt for Term Extraction."""
-        return f"""{example_block}
-[var]
-Title: {title}
-Text: {text}
-[var]
-Extract all relevant terms that could form the basis of an ontology from the above document.
-Return ONLY a Python list like ['term1', 'term2', ...] and nothing else.
-If no terms are found, return [].
-"""
-
-    def _build_type_prompt(self, example_block: str, title: str, text: str) -> str:
-        """Constructs the full prompt for Type Extraction."""
-        return f"""{example_block}
-[var]
-Title: {title}
-Text: {text}
-[var]
-Extract all relevant TYPES mentioned in the above document that could serve as ontology classes.
-Only consider content inside the [var] ... [var] block.
-Return ONLY a valid Python list like ['type1', 'type2'] and nothing else. If none, return [].
-"""
-
-    def _parse_list_like(self, raw_string: str) -> List[str]:
-        """Try to extract a Python list of strings from model output robustly."""
-        processed_string = raw_string.strip()
-        if processed_string in ("[]", ""):
+    def _parse_python_list_of_strings(self, raw_text: str) -> List[str]:
+        """
+        Parse an LLM response intended to be a Python list of strings.
+
+        This parser is intentionally tolerant:
+          1) Try literal_eval on the full string
+          2) Else extract the first [...] block and literal_eval it
+          3) Else fallback to extracting quoted strings
+
+        Args:
+            raw_text: Model output.
+
+        Returns:
+            List of strings (possibly empty).
+        """
+        stripped = (raw_text or "").strip()
+        if stripped in ("", "[]"):
             return []

-        # 1. Try direct evaluation
         try:
-            parsed_value = ast.literal_eval(processed_string)
-            if isinstance(parsed_value, list):
-                # Filter to ensure only strings are returned
-                return [item for item in parsed_value if isinstance(item, str)]
+            parsed = ast.literal_eval(stripped)
+            if isinstance(parsed, list):
+                return [item for item in parsed if isinstance(item, str)]
         except Exception:
             pass

-        # 2. Try finding and evaluating text within outermost brackets [ ... ]
-        bracket_match = re.search(r"\[[\s\S]*?\]", processed_string)
-        if bracket_match:
+        match = re.search(r"\[[\s\S]*?\]", stripped)
+        if match:
             try:
-                parsed_value = ast.literal_eval(bracket_match.group(0))
-                if isinstance(parsed_value, list):
-                    return [item for item in parsed_value if isinstance(item, str)]
+                parsed = ast.literal_eval(match.group(0))
+                if isinstance(parsed, list):
+                    return [item for item in parsed if isinstance(item, str)]
             except Exception:
                 pass

-        # 3. Fallback: Find comma-separated quoted substrings (less robust, but catches errors)
-        # Finds content inside either single quotes ('...') or double quotes ("...")
-        quoted_matches = re.findall(r"'([^']+)'|\"([^\"]+)\"", processed_string)
-        flattened_list = [a_match or b_match for a_match, b_match in quoted_matches]
-        return flattened_list
-
-    def _call_model_one(self, prompt: str, max_new_tokens: int = 120) -> str:
-        """Calls the underlying LocalAutoLLM for a single prompt. Returns the raw tail output."""
-        # self.model is an instance of LocalAutoLLM
-        model_output = self.model.generate(
-            [prompt], max_new_tokens=max_new_tokens, temperature=0.0, top_p=1.0
-        )
-        return model_output[0] if model_output else ""
+        quoted = re.findall(r"'([^']+)'|\"([^\"]+)\"", stripped)
+        return [a or b for a, b in quoted]

-    def predict_terms(
-        self,
-        docs_test_jsonl: str,
-        out_jsonl: str,
-        max_lines: int = -1,
-        max_new_tokens: int = 120,
-    ) -> int:
+    def _generate_completion(self, prompt_text: str) -> str:
         """
-        Runs Term Extraction on the test documents and saves results to a JSONL file.
-        Returns: The count of individual terms written.
+        Generate a completion for a single prompt (deterministic decoding).
+
+        Args:
+            prompt_text: Full prompt to send to the model.
+
+        Returns:
+            The generated completion text (prompt stripped where possible).
         """
-        if not self.fewshot_terms_block:
-            raise RuntimeError("Few-shot block for terms is empty. Call fit() first.")
-
-        num_written_terms = 0
-        with (
-            open(docs_test_jsonl, "r", encoding="utf-8") as file_in,
-            open(out_jsonl, "w", encoding="utf-8") as file_out,
-        ):
-            for line_index, line in enumerate(file_in, start=1):
-                if 0 < max_lines < line_index:
-                    break
-
-                try:
-                    document = json.loads(line.strip())
-                except Exception:
-                    continue  # Skip malformed JSON lines
-
-                doc_id = document.get("id", "unknown")
-                title = document.get("title", "")
-                text = document.get("text", "")
-
-                # Construct and call model
-                prompt = self._build_term_prompt(self.fewshot_terms_block, title, text)
-                raw_output = self._call_model_one(prompt, max_new_tokens=max_new_tokens)
-                predicted_terms = self._parse_list_like(raw_output)
-
-                # Write extracted terms
-                for term_or_type in predicted_terms:
-                    if isinstance(term_or_type, str) and term_or_type.strip():
-                        file_out.write(
-                            json.dumps({"doc_id": doc_id, "term": term_or_type.strip()})
-                            + "\n"
-                        )
-                        num_written_terms += 1
-
-                # Lightweight memory management for long runs
-                if line_index % 50 == 0:
-                    gc.collect()
-                    if torch.cuda.is_available():
-                        torch.cuda.empty_cache()
-
-        return num_written_terms
-
-    def predict_types(
+        if self.model is None or self.tokenizer is None:
+            raise RuntimeError("Model/tokenizer not loaded. Call .load() first.")
+
+        encoded = self.tokenizer([prompt_text], return_tensors="pt", padding=True, truncation=True)
+        input_ids = encoded["input_ids"]
+        attention_mask = encoded["attention_mask"]
+
+        model_device = next(self.model.parameters()).device
+        input_ids = input_ids.to(model_device)
+        attention_mask = attention_mask.to(model_device)
+
+        with torch.no_grad():
+            output_ids = self.model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=self.max_new_tokens,
+                do_sample=False,
+                temperature=0.0,
+                top_p=1.0,
+                pad_token_id=self.tokenizer.eos_token_id,
+            )[0]
+
+        decoded_full = self.tokenizer.decode(output_ids, skip_special_tokens=True)
+        decoded_prompt = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
+        if decoded_full.startswith(decoded_prompt):
+            return decoded_full[len(decoded_prompt) :].strip()
+
+        prompt_token_count = int(attention_mask[0].sum().item())
+        return self.tokenizer.decode(output_ids[prompt_token_count:], skip_special_tokens=True).strip()
+
+    def fit(
         self,
-        docs_test_jsonl: str,
-        out_jsonl: str,
-        max_lines: int = -1,
-        max_new_tokens: int = 120,
-    ) -> int:
-        """
-        Runs Type Extraction on the test documents and saves results to a JSONL file.
-        Returns: The count of individual types written.
-        """
-        if not self.fewshot_types_block:
-            raise RuntimeError("Few-shot block for types is empty. Call fit() first.")
-
-        num_written_types = 0
-        with (
-            open(docs_test_jsonl, "r", encoding="utf-8") as file_in,
-            open(out_jsonl, "w", encoding="utf-8") as file_out,
-        ):
-            for line_index, line in enumerate(file_in, start=1):
-                if 0 < max_lines < line_index:
-                    break
-
-                try:
-                    document = json.loads(line.strip())
-                except Exception:
-                    continue  # Skip malformed JSON lines
-
-                doc_id = document.get("id", "unknown")
-                title = document.get("title", "")
-                text = document.get("text", "")
-
-                # Construct and call model using the dedicated type prompt block
-                prompt = self._build_type_prompt(self.fewshot_types_block, title, text)
-                raw_output = self._call_model_one(prompt, max_new_tokens=max_new_tokens)
-                predicted_types = self._parse_list_like(raw_output)
-
-                # Write extracted types
-                for term_or_type in predicted_types:
-                    if isinstance(term_or_type, str) and term_or_type.strip():
-                        file_out.write(
-                            json.dumps({"doc_id": doc_id, "type": term_or_type.strip()})
-                            + "\n"
-                        )
-                        num_written_types += 1
-
-                if line_index % 50 == 0:
-                    gc.collect()
-                    if torch.cuda.is_available():
-                        torch.cuda.empty_cache()
-
-        return num_written_types
-
-    # --- Evaluation utilities (unchanged from prior definition, added docstrings) ---
-    def load_gold_pairs(self, terms2doc_path: str) -> Set[Tuple[str, str]]:
-        """Convert terms2docs JSON into a set of unique (doc_id, term) pairs, lowercased."""
-        gold_pairs = set()
-        with open(terms2doc_path, "r", encoding="utf-8") as file_handle:
-            term_to_doc_map = json.load(file_handle)
-
-        for term, doc_ids in term_to_doc_map.items():
-            clean_term = term.strip().lower()
-            for doc_id in doc_ids:
-                gold_pairs.add((doc_id, clean_term))
-        return gold_pairs
-
-    def load_predicted_pairs(
-        self, predicted_jsonl_path: str, key: str = "term"
-    ) -> Set[Tuple[str, str]]:
-        """Load predicted (doc_id, term/type) pairs from a JSONL file, lowercased."""
-        predicted_pairs = set()
-        with open(predicted_jsonl_path, "r", encoding="utf-8") as file_handle:
-            for line in file_handle:
-                try:
-                    entry = json.loads(line.strip())
-                except Exception:
-                    continue
-                doc_id = entry.get("doc_id")
-                value = entry.get(key)
-                if doc_id and value:
-                    predicted_pairs.add((doc_id, value.strip().lower()))
-        return predicted_pairs
-
-    def evaluate_extraction_f1(
-        self, terms2doc_path: str, predicted_jsonl: str, key: str = "term"
-    ) -> float:
+        train_data: Any,
+        task: str = "text2onto",
+        ontologizer: bool = False,
+        **kwargs: Any,
+    ) -> None:
         """
-        Computes set-based binary Precision, Recall, and F1 score against the gold pairs.
+        Build and cache few-shot blocks from the training split.
+
+        Args:
+            train_data: A split bundle dict. Must contain "documents" and "terms2docs".
+            task: Must be "text2onto".
+            ontologizer: Unused here (kept for signature compatibility).
+            **kwargs:
+                sample_size: Few-shot sample size per block (default 28).
+                seed: RNG seed (default 123).
         """
-        # Load the ground truth and predictions
-        gold_set = self.load_gold_pairs(terms2doc_path)
-        predicted_set = self.load_predicted_pairs(predicted_jsonl, key=key)
+        if task != "text2onto":
+            raise ValueError(f"{self.__class__.__name__} only supports task='text2onto' (got {task!r}).")

-        # Build combined universe of all pairs for score calculation
-        all_pairs = sorted(gold_set | predicted_set)
+        if not self._is_loaded:
+            self.load(model_id=self._default_model_id)

-        # Create binary labels (1=present, 0=absent)
-        y_true = [1 if pair in gold_set else 0 for pair in all_pairs]
-        y_pred = [1 if pair in predicted_set else 0 for pair in all_pairs]
+        documents: List[Dict[str, Any]] = train_data.get("documents", []) or []
+        terms_to_documents: Dict[str, List[str]] = train_data.get("terms2docs", {}) or {}
+        terms_to_types: Optional[Dict[str, List[str]]] = train_data.get("terms2types", None)

-        # Use scikit-learn for metric calculation
-        from sklearn.metrics import precision_recall_fscore_support
+        sample_size = int(kwargs.get("sample_size", 28))
+        seed = int(kwargs.get("seed", 123))

-        precision, recall, f1, _ = precision_recall_fscore_support(
-            y_true, y_pred, average="binary", zero_division=0
+        self.build_few_shot_terms_block(
+            documents=documents,
+            terms_to_documents=terms_to_documents,
+            sample_size=sample_size,
+            seed=seed,
+        )
+        self.build_few_shot_types_block(
+            documents=documents,
+            terms_to_documents=terms_to_documents,
+            terms_to_types=terms_to_types,
+            sample_size=sample_size,
+            seed=seed,
        )

-        # Display results
-        num_true_positives = len(gold_set & predicted_set)
+    def predict(
+        self,
+        test_data: Any,
+        task: str = "text2onto",
+        ontologizer: bool = False,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """
+        Run term/type extraction over test documents.

-        print("\n📊 Evaluation Results:")
-        print(f" ✅ Precision: {precision:.4f}")
-        print(f" ✅ Recall: {recall:.4f}")
-        print(f" ✅ F1 Score: {f1:.4f}")
-        print(f" 📌 Gold pairs: {len(gold_set)}")
-        print(f" 📌 Predicted pairs:{len(predicted_set)}")
-        print(f" 🎯 True Positives: {num_true_positives}")
+        Args:
+            test_data: A split bundle dict. Must contain "documents".
+            task: Must be "text2onto".
+            ontologizer: Unused here (kept for signature compatibility).
+            **kwargs:
+                max_docs: If > 0, limit number of docs processed.

-        return float(f1)
+        Returns:
+            Prediction payload dict: {"terms": [...], "types": [...]}.
+        """
+        if task != "text2onto":
+            raise ValueError(f"{self.__class__.__name__} only supports task='text2onto' (got {task!r}).")
+
+        if not self.few_shot_terms_block or not self.few_shot_types_block:
+            raise RuntimeError("Few-shot blocks are empty. Pipeline should call fit() before predict().")
+
+        max_docs = int(kwargs.get("max_docs", -1))
+        documents: List[Dict[str, Any]] = test_data.get("documents", []) or []
+        if max_docs > 0:
+            documents = documents[:max_docs]
+
+        term_predictions: List[Dict[str, str]] = []
+        type_predictions: List[Dict[str, str]] = []
+
+        for doc_index, document in enumerate(documents, start=1):
+            document_id = str(document.get("id", "unknown"))
+            title = str(document.get("title", ""))
+            text = str(document.get("text", ""))
+
+            term_prompt = self._format_term_prompt(self.few_shot_terms_block, title, text)
+            extracted_terms = self._parse_python_list_of_strings(self._generate_completion(term_prompt))
+            for term in extracted_terms:
+                normalized_term = (term or "").strip()
+                if normalized_term:
+                    term_predictions.append({"doc_id": document_id, "term": normalized_term})
+
+            type_prompt = self._format_type_prompt(self.few_shot_types_block, title, text)
+            extracted_types = self._parse_python_list_of_strings(self._generate_completion(type_prompt))
+            for extracted_type in extracted_types:
+                normalized_type = (extracted_type or "").strip()
+                if normalized_type:
+                    type_predictions.append({"doc_id": document_id, "type": normalized_type})
+
+            if doc_index % 50 == 0:
+                gc.collect()
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+        # IMPORTANT: return only the prediction payload; LearnerPipeline wraps it.
+        return {"terms": term_predictions, "types": type_predictions}