OntoLearner 1.4.6-py3-none-any.whl → 1.4.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,402 @@
+# Copyright (c) 2025 SciKnowOrg
+# License: MIT
+
+import os
+import re
+import json
+from typing import Any, Dict, List, Optional
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from ...base import AutoLearner
+
+
+class SBUNLPFewShotLearner(AutoLearner):
+    """
+    Few-shot taxonomy discovery via N×M batch prompting.
+
+    This learner:
+      - Caches & cleans gold parent–child pairs during `fit`.
+      - Splits (train pairs × test terms) into a grid of chunks.
+      - Builds an instruction prompt per grid cell with few-shot JSON examples.
+      - Generates and parses model outputs as JSON relations.
+      - Merges & deduplicates all predicted edges.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "Qwen/Qwen2.5-0.5B-Instruct",
+        try_4bit: bool = True,
+        device: str = "cpu",
+        num_train_chunks: int = 7,
+        num_test_chunks: int = 7,
+        max_new_tokens: int = 140,
+        max_input_tokens: int = 1500,
+        temperature: float = 0.0,
+        top_p: float = 1.0,
+        limit_num_prompts: Optional[int] = None,
+        output_dir: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Initialize the learner and core generation / batching settings.
+
+        Args:
+            model_name: HF id/path of the causal LLM (e.g., Qwen Instruct).
+            try_4bit: If True and on CUDA, load with 4-bit NF4 quantization.
+            device: "cpu" or "cuda" for model execution.
+            num_train_chunks: Number of chunks for the gold (parent, child) bank.
+            num_test_chunks: Number of chunks for the test term list.
+            max_new_tokens: Max new tokens to generate per prompt call.
+            max_input_tokens: Clip the *input* prompt to this many tokens (tail kept).
+            temperature: Sampling temperature; 0.0 uses greedy decoding.
+            top_p: Nucleus sampling parameter (used when temperature > 0).
+            limit_num_prompts: Optional hard cap on prompts issued (debug/cost).
+            output_dir: Optional directory to save per-batch JSON predictions.
+            **kwargs: Forwarded to the base class.
+        """
+        super().__init__(**kwargs)
+        self.model_name = model_name
+        self.try_4bit = try_4bit
+        self.device = device
+
+        self.num_train_chunks = num_train_chunks
+        self.num_test_chunks = num_test_chunks
+        self.max_new_tokens = max_new_tokens
+        self.max_input_tokens = max_input_tokens
+        self.temperature = temperature
+        self.top_p = top_p
+        self.limit_num_prompts = limit_num_prompts
+        self.output_dir = output_dir
+
+        self.tokenizer: Optional[AutoTokenizer] = None
+        self.model: Optional[AutoModelForCausalLM] = None
+        self.train_pairs_clean: List[Dict[str, str]] = []
+
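
As a hypothetical illustration of the constructor knobs (the instance name and output path below are invented), a small CPU-only configuration that caps prompt volume during a dry run might look like:

    learner = SBUNLPFewShotLearner(
        model_name="Qwen/Qwen2.5-0.5B-Instruct",
        device="cpu",               # no 4-bit path on CPU; model loads in float32
        num_train_chunks=2,         # 2 x 2 grid -> at most 4 prompts
        num_test_chunks=2,
        limit_num_prompts=4,        # hard cap, useful when debugging cost
        output_dir="predictions",   # per-cell JSON dumps land here (made-up path)
    )
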
+    def _clean_pairs(self, pair_rows: List[Dict[str, str]]) -> List[Dict[str, str]]:
+        """
+        Normalize, filter, and deduplicate relation pairs.
+
+        Operations:
+          - Cast 'parent'/'child' to strings and strip whitespace.
+          - Drop rows with empty values.
+          - Drop self-relations (case-insensitive parent == child).
+          - Deduplicate by lowercase (parent, child).
+
+        Args:
+            pair_rows: Raw list of dicts with at least 'parent' and 'child'.
+
+        Returns:
+            Cleaned list of {'parent','child'} dicts.
+        """
+        cleaned, seen = [], set()
+        for rec in pair_rows or []:
+            if not isinstance(rec, dict):
+                continue
+            p = str(rec.get("parent", "")).strip()
+            c = str(rec.get("child", "")).strip()
+            if not p or not c:
+                continue
+            key = (p.lower(), c.lower())
+            if key[0] == key[1] or key in seen:
+                continue
+            seen.add(key)
+            cleaned.append({"parent": p, "child": c})
+        return cleaned
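
To make the cleaning rules concrete, a small illustrative call might look like this (the pairs are invented and `learner` stands for any configured instance):

    raw = [
        {"parent": "animal", "child": "dog"},
        {"parent": "Animal", "child": "Dog"},   # duplicate after lowercasing
        {"parent": "plant", "child": "plant"},  # self-relation -> dropped
        {"parent": "vehicle", "child": ""},     # empty child -> dropped
    ]
    learner._clean_pairs(raw)
    # -> [{'parent': 'animal', 'child': 'dog'}]
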
+
+    def _chunk_list(self, items: List[Any], num_chunks: int) -> List[List[Any]]:
+        """
+        Split a list into `num_chunks` near-equal contiguous parts.
+
+        Args:
+            items: Sequence to split.
+            num_chunks: Number of chunks to produce; if <= 0, returns [items].
+
+        Returns:
+            List of chunks (some may be empty if len(items) < num_chunks).
+        """
+        if num_chunks <= 0:
+            return [items]
+        n = len(items)
+        base, rem = divmod(n, num_chunks)
+        out, start = [], 0
+        for i in range(num_chunks):
+            size = base + (1 if i < rem else 0)
+            out.append(items[start : start + size])
+            start += size
+        return out
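
A worked example of the split (inputs invented): `divmod` gives every chunk `base` items and the first `rem` chunks one extra, so ten items over three chunks come out as 4/3/3:

    learner._chunk_list(list(range(10)), 3)
    # divmod(10, 3) == (3, 1): chunk sizes are 4, 3, 3
    # -> [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
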
+
+    def _ensure_dir(self, path: Optional[str]) -> None:
+        """
+        Create a directory if `path` is a non-empty string.
+
+        Args:
+            path: Directory to create (recursively). Ignored if falsy.
+        """
+        if path:
+            os.makedirs(path, exist_ok=True)
+
+    def load(self, **_: Any) -> None:
+        """
+        Load tokenizer and model; optionally enable 4-bit quantization.
+
+        Assumes bitsandbytes is available if `try_4bit=True` on CUDA.
+        Sets tokenizer pad token if missing. Places model on GPU (device_map='auto')
+        when `device='cuda'`, otherwise on CPU.
+
+        Args:
+            **_: Unused kwargs for interface compatibility.
+        """
+        quant_config = None
+        if self.try_4bit and self.device == "cuda":
+            quant_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+            )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        if getattr(self.tokenizer, "pad_token_id", None) is None:
+            if getattr(self.tokenizer, "eos_token", None) is not None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            elif getattr(self.tokenizer, "unk_token", None) is not None:
+                self.tokenizer.pad_token = self.tokenizer.unk_token
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=("auto" if self.device == "cuda" else None),
+            torch_dtype=(torch.float16 if self.device == "cuda" else torch.float32),
+            quantization_config=quant_config,
+        )
+        if self.device == "cpu":
+            self.model.to("cpu")
+
+    def _format_chat(self, user_text: str) -> str:
+        """
+        Wrap plain text with the model's chat template, if provided.
+
+        Many instruction-tuned models expose `tokenizer.chat_template`.
+        If available, use it to construct a proper chat prompt; otherwise,
+        return the text unchanged.
+
+        Args:
+            user_text: Content of the user message.
+
+        Returns:
+            A generation-ready prompt string.
+        """
+        if hasattr(self.tokenizer, "apply_chat_template") and getattr(
+            self.tokenizer, "chat_template", None
+        ):
+            return self.tokenizer.apply_chat_template(
+                [{"role": "user", "content": user_text}],
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+        return user_text
+
+    @torch.no_grad()
+    def _generate(self, prompt_text: str) -> str:
+        """
+        Generate text for a single prompt, guarding input length.
+
+        Steps:
+          1) Format prompt via chat template (if present).
+          2) Tokenize and clip the *input* to `max_input_tokens` (tail kept).
+          3) Call `model.generate` with configured decoding params.
+          4) Strip the echoed prompt from the decoded output (if present).
+
+        Args:
+            prompt_text: Textual prompt to feed the model.
+
+        Returns:
+            Model continuation string (prompt-echo stripped when applicable).
+        """
+        formatted = self._format_chat(prompt_text)
+        ids = self.tokenizer(formatted, add_special_tokens=False, return_tensors=None)[
+            "input_ids"
+        ]
+        if len(ids) > self.max_input_tokens:
+            ids = ids[-self.max_input_tokens :]
+        device = next(self.model.parameters()).device
+        input_ids = torch.tensor([ids], device=device)
+
+        out = self.model.generate(
+            input_ids=input_ids,
+            max_new_tokens=self.max_new_tokens,
+            do_sample=(self.temperature > 0.0),
+            temperature=self.temperature,
+            top_p=self.top_p,
+            pad_token_id=self.tokenizer.pad_token_id,
+            eos_token_id=getattr(self.tokenizer, "eos_token_id", None),
+            use_cache=True,
+        )
+
+        decoded_full = self.tokenizer.decode(out[0], skip_special_tokens=True)
+        decoded_prompt = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+        return (
+            decoded_full[len(decoded_prompt) :].strip()
+            if decoded_full.startswith(decoded_prompt)
+            else decoded_full.strip()
+        )
+
+    def _build_prompt(
+        self,
+        train_pairs_chunk: List[Dict[str, str]],
+        test_terms_chunk: List[str],
+    ) -> str:
+        """
+        Construct a few-shot prompt with JSON examples and test terms.
+
+        The prompt:
+          - Shows several gold (parent, child) examples in JSON.
+          - Lists the test terms (one per line) between [PAIR] tags.
+          - Instructs to return ONLY a JSON array of {'parent','child'}.
+
+        Args:
+            train_pairs_chunk: Cleaned training relations for examples.
+            test_terms_chunk: The current chunk of test terms.
+
+        Returns:
+            The fully formatted prompt string.
+        """
+        examples_json = json.dumps(train_pairs_chunk, ensure_ascii=False, indent=2)
+        test_block = "\n".join(test_terms_chunk)
+        prompt = (
+            "From this file, extract all parent–child relations like in the examples.\n"
+            "Return ONLY a JSON array of objects with keys 'parent' and 'child'.\n"
+            "Output format:\n"
+            "[\n"
+            ' {"parent": "parent1", "child": "child1"},\n'
+            ' {"parent": "parent2", "child": "child2"}\n'
+            "]\n\n"
+            "EXAMPLES (JSON):\n"
+            f"{examples_json}\n\n"
+            "TEST TYPES (between [PAIR] tags):\n"
+            "[PAIR]\n"
+            f"{test_block}\n"
+            "[PAIR]\n"
+            "Return only JSON."
+        )
+        return prompt
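
For a sense of what each grid cell sends to the model, a call with one invented gold pair and two invented test terms:

    learner._build_prompt(
        [{"parent": "animal", "child": "dog"}],   # one few-shot example
        ["cat", "oak"],                           # current test-term chunk
    )

renders as:

    From this file, extract all parent–child relations like in the examples.
    Return ONLY a JSON array of objects with keys 'parent' and 'child'.
    Output format:
    [
     {"parent": "parent1", "child": "child1"},
     {"parent": "parent2", "child": "child2"}
    ]

    EXAMPLES (JSON):
    [
      {
        "parent": "animal",
        "child": "dog"
      }
    ]

    TEST TYPES (between [PAIR] tags):
    [PAIR]
    cat
    oak
    [PAIR]
    Return only JSON.
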
+
+    def _parse_pairs(self, text: str) -> List[Dict[str, str]]:
+        """
+        Parse a generation string into a list of relation dicts.
+
+        Parsing strategy:
+          1) Try to parse the entire string as JSON; expect a list.
+          2) Else, regex-extract the outermost JSON-like array and parse that.
+          3) On failure, return an empty list.
+
+        Args:
+            text: Raw model output.
+
+        Returns:
+            Cleaned list of {'parent','child'} dicts (possibly empty).
+        """
+        text = text.strip()
+        try:
+            obj = json.loads(text)
+            if isinstance(obj, list):
+                return self._clean_pairs(obj)
+        except Exception:
+            pass
+        m = re.search(r"\[\s*(?:\{[\s\S]*?\}\s*,?\s*)*\]", text)
+        if m:
+            try:
+                obj = json.loads(m.group(0))
+                if isinstance(obj, list):
+                    return self._clean_pairs(obj)
+            except Exception:
+                pass
+        return []
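
An illustrative parse of a noisy completion (text invented): whole-string `json.loads` fails because of the surrounding prose, so the regex fallback pulls out the array:

    noisy = (
        "Sure, here are the relations:\n"
        '[\n {"parent": "animal", "child": "cat"},\n'
        ' {"parent": "plant", "child": "oak"}\n]\n'
        "Hope this helps!"
    )
    learner._parse_pairs(noisy)
    # -> [{'parent': 'animal', 'child': 'cat'}, {'parent': 'plant', 'child': 'oak'}]
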
+
+    def fit(self, train_data: Any, task: str, ontologizer: bool = True):
+        """
+        Cache and clean gold relations for few-shot prompting.
+
+        For `task == "taxonomy-discovery"`:
+          - If `ontologizer=True`, convert ontology-like input into
+            a list of {'parent','child'} via the base helper.
+          - Otherwise, accept a user-provided list directly.
+          - Store a cleaned, deduplicated bank in `self.train_pairs_clean`.
+
+        Args:
+            train_data: Ontology-like object or list of relation dicts.
+            task: Task selector (expects "taxonomy-discovery").
+            ontologizer: Whether to transform ontology inputs.
+
+        Returns:
+            None for "taxonomy-discovery" (state is stored on the instance);
+            other tasks are delegated to the base class.
+        """
+        if task != "taxonomy-discovery":
+            return super().fit(train_data, task, ontologizer)
+        if ontologizer:
+            gold = self.tasks_ground_truth_former(train_data, task="taxonomy-discovery")
+            self.train_pairs_clean = self._clean_pairs(gold)
+        else:
+            self.train_pairs_clean = self._clean_pairs(train_data)
+
+    def _taxonomy_discovery(
+        self, data: Any, test: bool = False
+    ) -> Optional[List[Dict[str, str]]]:
+        """
+        Run few-shot inference (test=True) or no-op during training.
+
+        Inference steps:
+          - Ensure tokenizer/model are loaded.
+          - Normalize `data` to a list of test terms (via base helper if needed).
+          - Create the N×M grid across (train_pairs_chunk × test_terms_chunk).
+          - For each cell: build prompt → generate → parse → (optionally) save.
+          - Merge and deduplicate all predicted pairs before returning.
+
+        Args:
+            data: Test input (ontology-like, list of strings, or mixed).
+            test: If True, perform prediction; otherwise return None.
+
+        Returns:
+            On `test=True`: deduplicated list of {'parent','child'}.
+            On `test=False`: None.
+        """
+        if not test:
+            return None
+        if self.model is None or self.tokenizer is None:
+            self.load()
+
+        if isinstance(data, list) and (len(data) == 0 or isinstance(data[0], str)):
+            test_terms: List[str] = data
+        else:
+            test_terms = super().tasks_data_former(
+                data=data, task="taxonomy-discovery", test=True
+            )
+
+        train_chunks = self._chunk_list(self.train_pairs_clean, self.num_train_chunks)
+        test_chunks = self._chunk_list(test_terms, self.num_test_chunks)
+
+        self._ensure_dir(self.output_dir)
+
+        merged: List[Dict[str, str]] = []
+        issued = 0
+
+        for ti, tr in enumerate(train_chunks, 1):
+            for si, ts in enumerate(test_chunks, 1):
+                issued += 1
+                if self.limit_num_prompts and issued > self.limit_num_prompts:
+                    break
+                prompt = self._build_prompt(tr, ts)
+                resp = self._generate(prompt)
+                pairs = self._parse_pairs(resp)
+
+                if self.output_dir:
+                    path = os.path.join(self.output_dir, f"pairs_T{ti}_S{si}.json")
+                    with open(path, "w", encoding="utf-8") as f:
+                        json.dump(pairs, f, ensure_ascii=False, indent=2)
+
+                merged.extend(pairs)
+
+            if self.limit_num_prompts and issued >= (self.limit_num_prompts or 0):
+                break
+
+        return self._clean_pairs(merged)
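
Taken together, a hedged end-to-end sketch of how this class might be driven (the gold pairs, test terms, and instance name are invented; the public entry point on `AutoLearner` is not shown in this diff, so the sketch calls the task method directly and feeds `fit` a raw pair list with `ontologizer=False`):

    learner = SBUNLPFewShotLearner(device="cpu", limit_num_prompts=10)
    learner.load()  # optional; _taxonomy_discovery loads lazily if skipped

    # Gold parent-child bank used as few-shot examples.
    learner.fit(
        [{"parent": "animal", "child": "dog"}, {"parent": "animal", "child": "cat"}],
        task="taxonomy-discovery",
        ontologizer=False,
    )

    # With the defaults (7 x 7 chunks) up to 49 prompts would be issued;
    # limit_num_prompts=10 stops the grid walk early.
    predicted = learner._taxonomy_discovery(["oak", "pine", "sparrow"], test=True)
    # -> deduplicated list of {"parent": ..., "child": ...} dicts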