OntoLearner 1.4.6__py3-none-any.whl → 1.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ontolearner/VERSION +1 -1
- ontolearner/base/learner.py +20 -14
- ontolearner/learner/__init__.py +1 -1
- ontolearner/learner/label_mapper.py +1 -1
- ontolearner/learner/llm.py +73 -3
- ontolearner/learner/retriever.py +24 -3
- ontolearner/learner/taxonomy_discovery/__init__.py +18 -0
- ontolearner/learner/taxonomy_discovery/alexbek.py +500 -0
- ontolearner/learner/taxonomy_discovery/rwthdbis.py +1082 -0
- ontolearner/learner/taxonomy_discovery/sbunlp.py +402 -0
- ontolearner/learner/taxonomy_discovery/skhnlp.py +1138 -0
- ontolearner/learner/term_typing/__init__.py +17 -0
- ontolearner/learner/term_typing/alexbek.py +1262 -0
- ontolearner/learner/term_typing/rwthdbis.py +379 -0
- ontolearner/learner/term_typing/sbunlp.py +478 -0
- ontolearner/learner/text2onto/__init__.py +16 -0
- ontolearner/learner/text2onto/alexbek.py +1219 -0
- ontolearner/learner/text2onto/sbunlp.py +598 -0
- {ontolearner-1.4.6.dist-info → ontolearner-1.4.8.dist-info}/METADATA +5 -1
- {ontolearner-1.4.6.dist-info → ontolearner-1.4.8.dist-info}/RECORD +22 -10
- {ontolearner-1.4.6.dist-info → ontolearner-1.4.8.dist-info}/WHEEL +0 -0
- {ontolearner-1.4.6.dist-info → ontolearner-1.4.8.dist-info}/licenses/LICENSE +0 -0
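The bulk of this release is the new LLMS4OL-style learner packages (taxonomy_discovery, term_typing, text2onto). The largest single addition, the text2onto learner in ontolearner/learner/text2onto/alexbek.py, is reproduced in full below. As orientation before the diff, here is a minimal usage sketch of its public API; the class and method signatures come from the code below, while the import path and all file paths are assumptions for illustration only:

from ontolearner.learner.text2onto.alexbek import (  # hypothetical import path
    LocalAutoLLM,
    AlexbekFewShotLearner,
)

# Load a local HF causal model with deterministic (greedy) decoding defaults.
llm = LocalAutoLLM(device="cuda")
llm.load("Qwen/Qwen2.5-1.5B-Instruct", load_in_4bit=False)

learner = AlexbekFewShotLearner(model=llm, device="cuda")

# Build few-shot exemplars from a labeled training split (paths are hypothetical).
learner.fit(
    train_docs_jsonl="data/train_docs.jsonl",
    terms2doc_json="data/terms2docs.json",
    sample_size=24,
    seed=42,
)

# Extract terms per test document, then micro-F1 over gold (doc_id, term) pairs.
num_docs = learner.predict_terms(
    docs_test_jsonl="data/test_docs.jsonl",
    out_jsonl="out/terms_pred.jsonl",
    max_new_tokens=128,
    few_shot_k=6,
)
term_f1 = learner.evaluate_extraction_f1(
    "data/gold_terms2docs.json", "out/terms_pred.jsonl", key="term"
)

The full diff of ontolearner/learner/text2onto/alexbek.py follows.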
ontolearner/learner/text2onto/alexbek.py (new file)
@@ -0,0 +1,1219 @@
+# Copyright (c) 2025 SciKnowOrg
+#
+# Licensed under the MIT License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/MIT
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Iterable
+import json
+from json.decoder import JSONDecodeError
+import os
+import random
+import re
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from ...base import AutoLearner, AutoLLM
+
+try:
+    from outlines.models import Transformers as OutlinesTFModel
+    from outlines.generate import json as outlines_generate_json
+    from pydantic import BaseModel
+
+    class _PredictedTypesSchema(BaseModel):
+        """Schema used when generating structured JSON { "types": [...] }."""
+
+        types: List[str]
+
+    OUTLINES_AVAILABLE: bool = True
+except Exception:
+    # If outlines is unavailable, we will fall back to greedy decoding + regex parsing.
+    OUTLINES_AVAILABLE = False
+    _PredictedTypesSchema = None
+    OutlinesTFModel = None
+    outlines_generate_json = None
+
+
+class LocalAutoLLM(AutoLLM):
+    """
+    Minimal local LLM helper.
+
+    - Inherits AutoLLM but overrides load/generate to avoid label_mapper.
+    - Optional 4-bit loading with `load_in_4bit=True` in .load().
+    - Greedy decoding by default (deterministic).
+    """
+
+    def __init__(self, device: str = "cpu", token: str = "") -> None:
+        """
+        Initialize the local LLM holder.
+
+        Parameters
+        ----------
+        device : str
+            Execution device: "cpu" or "cuda".
+        token : str
+            Optional auth token for private model hubs.
+        """
+        super().__init__(label_mapper=None, device=device, token=token)
+        self.model: Optional[AutoModelForCausalLM] = None
+        self.tokenizer: Optional[AutoTokenizer] = None
+
+    def load(self, model_id: str, *, load_in_4bit: bool = False) -> None:
+        """
+        Load a Hugging Face causal model + tokenizer and set deterministic
+        generation defaults.
+
+        Parameters
+        ----------
+        model_id : str
+            Model identifier resolvable by HF `from_pretrained`.
+        load_in_4bit : bool
+            If True and bitsandbytes is available, load using 4-bit quantization.
+        """
+        # Tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_id, padding_side="left", token=self.token
+        )
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Model (optionally quantized)
+        if load_in_4bit:
+            from transformers import BitsAndBytesConfig
+
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                device_map="auto",
+                quantization_config=quantization_config,
+                token=self.token,
+            )
+        else:
+            device_map = (
+                "auto" if (self.device != "cpu" and torch.cuda.is_available()) else None
+            )
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                device_map=device_map,
+                torch_dtype=torch.bfloat16
+                if torch.cuda.is_available()
+                else torch.float32,
+                token=self.token,
+            )
+
+        # Deterministic generation defaults
+        generation_cfg = self.model.generation_config
+        generation_cfg.do_sample = False
+        generation_cfg.temperature = None
+        generation_cfg.top_k = None
+        generation_cfg.top_p = None
+        generation_cfg.num_beams = 1
+
+    def generate(self, prompts: List[str], max_new_tokens: int = 128) -> List[str]:
+        """
+        Greedy-generate continuations for a list of prompts.
+
+        Parameters
+        ----------
+        prompts : List[str]
+            Prompts to generate for (batched).
+        max_new_tokens : int
+            Maximum number of new tokens per continuation.
+
+        Returns
+        -------
+        List[str]
+            Decoded new-token texts (no special tokens, stripped).
+        """
+        if self.model is None or self.tokenizer is None:
+            raise RuntimeError(
+                "Call .load(model_id) on LocalAutoLLM before generate()."
+            )
+
+        tokenized_batch = self.tokenizer(
+            prompts, return_tensors="pt", padding=True, truncation=True
+        )
+        input_seq_len = tokenized_batch["input_ids"].shape[1]
+        tokenized_batch = {
+            k: v.to(self.model.device) for k, v in tokenized_batch.items()
+        }
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **tokenized_batch,
+                max_new_tokens=max_new_tokens,
+                pad_token_id=self.tokenizer.eos_token_id,
+                do_sample=False,
+                num_beams=1,
+            )
+
+        # Only return the newly generated part for each row in the batch
+        continuation_token_ids = outputs[:, input_seq_len:]
+        return [
+            self.tokenizer.decode(row, skip_special_tokens=True).strip()
+            for row in continuation_token_ids
+        ]
+
+
+class AlexbekFewShotLearner(AutoLearner):
+    """
+    Text2Onto learner for LLMS4OL Task A (term & type extraction).
+
+    Public API (A1 + convenience):
+    - fit(train_docs_jsonl, terms2doc_json, sample_size=24, seed=42)
+    - predict_terms(docs_test_jsonl, out_jsonl, max_new_tokens=128, few_shot_k=6) -> int
+    - predict_types(docs_test_jsonl, out_jsonl, max_new_tokens=128, few_shot_k=6) -> int
+    - evaluate_extraction_f1(gold_item2docs_json, preds_jsonl, key="term"|"type") -> float
+
+    Option A (A2, term→types) bridge:
+    - predict_types_from_terms(...)
+      Reads your A1 results (docs→terms), predicts types for each term, and
+      writes two files: terms2types_pred.json + types2docs_pred.json
+    """
+
+    def __init__(self, model: LocalAutoLLM, device: str = "cpu", **_: Any) -> None:
+        """
+        Initialize learner state and canned prompts.
+
+        Parameters
+        ----------
+        model : LocalAutoLLM
+            Loaded local LLM helper instance.
+        device : str
+            Device name ("cpu" or "cuda").
+        """
+        super().__init__(**_)
+        self.model = model
+        self.device = device
+
+        # Few-shot exemplars for A1 (Docs→Terms) and for Docs→Types:
+        # Each exemplar is a tuple: (title, text, gold_list)
+        self._fewshot_terms_docs: List[Tuple[str, str, List[str]]] = []
+        self._fewshot_types_docs: List[Tuple[str, str, List[str]]] = []
+
+        # System prompts
+        self._system_prompt_terms = (
+            "You are an expert in ontology term extraction.\n"
+            "Extract only terms that explicitly appear in the document.\n"
+            'Answer strictly as JSON: {"terms": ["..."]}\n'
+        )
+        self._system_prompt_types = (
+            "You are an expert in ontology type classification.\n"
+            "List ontology *types* that characterize the document’s terminology.\n"
+            'Answer strictly as JSON: {"types": ["..."]}\n'
+        )
+
+        # Compiled regex for robust JSON extraction from LLM outputs
+        self._json_object_regex = re.compile(r"\{[^{}]*\}", re.S)
+        self._json_array_regex = re.compile(r"\[[^\]]*\]", re.S)
+
+        # Term→Types (Option A) specific prompt
+        self._system_prompt_term_to_types = (
+            "You are an expert in ontology and semantic type classification.\n"
+            "Given a term, predict its semantic types from the domain-specific ontology.\n"
+            'Answer strictly as JSON:\n{"types": ["type1", "type2", "..."]}'
+        )
+
+    def fit(
+        self,
+        *,
+        train_docs_jsonl: str,
+        terms2doc_json: str,
+        sample_size: int = 24,
+        seed: int = 42,
+    ) -> None:
+        """
+        Build internal few-shot exemplars from a labeled training split.
+
+        Parameters
+        ----------
+        train_docs_jsonl : str
+            Path to JSONL (or tolerant JSON/JSONL) with train documents.
+        terms2doc_json : str
+            JSON mapping item -> [doc_id,...]; "item" can be a term or type.
+        sample_size : int
+            Number of exemplar documents to keep for few-shot prompting.
+        seed : int
+            RNG seed for reproducible sampling.
+        """
+        rng = random.Random(seed)
+
+        # Load documents and map doc_id -> row
+        document_map = self._load_documents_jsonl(train_docs_jsonl)
+        if not document_map:
+            raise FileNotFoundError(f"No documents found in: {train_docs_jsonl}")
+
+        # Load item -> [doc_ids]
+        item_to_docs_map = self._load_json(terms2doc_json)
+        if not isinstance(item_to_docs_map, dict):
+            raise ValueError(
+                f"{terms2doc_json} must be a JSON dict mapping item -> [doc_ids]"
+            )
+
+        # Reverse mapping: doc_id -> [items]
+        doc_id_to_items_map: Dict[str, List[str]] = {}
+        for item_label, doc_id_list in item_to_docs_map.items():
+            for doc_id in doc_id_list:
+                doc_id_to_items_map.setdefault(doc_id, []).append(item_label)
+
+        # Build candidate exemplars (title, text, gold_list)
+        exemplar_candidates: List[Tuple[str, str, List[str]]] = []
+        for doc_id, labeled_items in doc_id_to_items_map.items():
+            doc_row = document_map.get(doc_id)
+            if not doc_row:
+                continue
+            doc_title = str(doc_row.get("title", ""))  # be defensive (may be None)
+            doc_text = self._to_text(
+                doc_row.get("text", "")
+            )  # string-ify list if needed
+            if not doc_text:
+                continue
+            gold_items = self._unique_preserve(
+                [s for s in labeled_items if isinstance(s, str)]
+            )
+            if gold_items:
+                exemplar_candidates.append((doc_title, doc_text, gold_items))
+
+        if not exemplar_candidates:
+            raise RuntimeError(
+                "No candidate docs with items found to build few-shot exemplars."
+            )
+
+        chosen_exemplars = rng.sample(
+            exemplar_candidates, k=min(sample_size, len(exemplar_candidates))
+        )
+        # Reuse exemplars for both docs→terms and docs→types prompting
+        self._fewshot_terms_docs = chosen_exemplars
+        self._fewshot_types_docs = chosen_exemplars
+
+    def predict_terms(
+        self,
+        *,
+        docs_test_jsonl: str,
+        out_jsonl: str,
+        max_new_tokens: int = 128,
+        few_shot_k: int = 6,
+    ) -> int:
+        """
+        Extract terms that explicitly appear in each document.
+
+        Writes one JSON object per line:
+            {"id": "<doc_id>", "terms": ["...", "...", ...]}
+
+        Parameters
+        ----------
+        docs_test_jsonl : str
+            Path to test/dev documents in JSONL or tolerant JSON/JSONL.
+        out_jsonl : str
+            Output JSONL path where predictions are written (one line per doc).
+        max_new_tokens : int
+            Max generation length.
+        few_shot_k : int
+            Number of few-shot exemplars to prepend per prompt.
+
+        Returns
+        -------
+        int
+            Number of lines written (i.e., number of processed documents).
+        """
+        if self.model is None or self.model.model is None:
+            raise RuntimeError("Load a model first: learner.model.load(MODEL_ID, ...)")
+
+        test_documents = self._load_documents_jsonl(docs_test_jsonl)
+        prompts: List[str] = []
+        document_order: List[str] = []
+
+        for document_id, document_row in test_documents.items():
+            title = str(document_row.get("title", ""))
+            text = self._to_text(document_row.get("text", ""))
+
+            fewshot_block = self._format_fewshot_block(
+                self._system_prompt_terms,
+                self._fewshot_terms_docs,
+                key="terms",
+                k=few_shot_k,
+            )
+            user_block = self._format_user_block(title, text)
+
+            prompts.append(f"{fewshot_block}\n{user_block}\nAssistant:")
+            document_order.append(document_id)
+
+        generations = self.model.generate(prompts, max_new_tokens=max_new_tokens)
+        parsed_term_lists = [
+            self._parse_json_list(generated, key="terms") for generated in generations
+        ]
+
+        os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
+        lines_written = 0
+        with open(out_jsonl, "w", encoding="utf-8") as fp_out:
+            for document_id, term_list in zip(document_order, parsed_term_lists):
+                payload = {"id": document_id, "terms": self._unique_preserve(term_list)}
+                fp_out.write(json.dumps(payload, ensure_ascii=False) + "\n")
+                lines_written += 1
+        return lines_written
+
+    def predict_types(
+        self,
+        *,
+        docs_test_jsonl: str,
+        out_jsonl: str,
+        max_new_tokens: int = 128,
+        few_shot_k: int = 6,
+    ) -> int:
+        """
+        Predict ontology types that characterize each document’s terminology.
+
+        Writes one JSON object per line:
+            {"id": "<doc_id>", "types": ["...", "...", ...]}
+
+        Parameters
+        ----------
+        docs_test_jsonl : str
+            Path to test/dev documents in JSONL or tolerant JSON/JSONL.
+        out_jsonl : str
+            Output JSONL path where predictions are written (one line per doc).
+        max_new_tokens : int
+            Max generation length.
+        few_shot_k : int
+            Number of few-shot exemplars to prepend per prompt.
+
+        Returns
+        -------
+        int
+            Number of lines written (i.e., number of processed documents).
+        """
+        if self.model is None or self.model.model is None:
+            raise RuntimeError("Load a model first: learner.model.load(MODEL_ID, ...)")
+
+        test_documents = self._load_documents_jsonl(docs_test_jsonl)
+        prompts: List[str] = []
+        document_order: List[str] = []
+
+        for document_id, document_row in test_documents.items():
+            title = str(document_row.get("title", ""))
+            text = self._to_text(document_row.get("text", ""))
+
+            fewshot_block = self._format_fewshot_block(
+                self._system_prompt_types,
+                self._fewshot_types_docs,
+                key="types",
+                k=few_shot_k,
+            )
+            user_block = self._format_user_block(title, text)
+
+            prompts.append(f"{fewshot_block}\n{user_block}\nAssistant:")
+            document_order.append(document_id)
+
+        generations = self.model.generate(prompts, max_new_tokens=max_new_tokens)
+        parsed_type_lists = [
+            self._parse_json_list(generated, key="types") for generated in generations
+        ]
+
+        os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
+        lines_written = 0
+        with open(out_jsonl, "w", encoding="utf-8") as fp_out:
+            for document_id, type_list in zip(document_order, parsed_type_lists):
+                payload = {"id": document_id, "types": self._unique_preserve(type_list)}
+                fp_out.write(json.dumps(payload, ensure_ascii=False) + "\n")
+                lines_written += 1
+        return lines_written
+
+    def evaluate_extraction_f1(
+        self,
+        gold_item2docs_json: str,
+        preds_jsonl: str,
+        *,
+        key: str = "term",
+    ) -> float:
+        """
+        Compute micro-F1 over (doc_id, item) pairs.
+
+        Parameters
+        ----------
+        gold_item2docs_json : str
+            JSON mapping item -> [doc_ids].
+        preds_jsonl : str
+            JSONL lines like {"id": "...", "terms":[...]} or {"id":"...","types":[...]}.
+        key : str
+            "term" or "type" depending on what you are evaluating.
+
+        Returns
+        -------
+        float
+            Micro-averaged F1 score.
+        """
+        item_to_doc_ids: Dict[str, List[str]] = self._load_json(gold_item2docs_json)
+
+        # Build gold: doc_id -> set(items)
+        gold_doc_to_items: Dict[str, set] = {}
+        for item_label, doc_id_list in item_to_doc_ids.items():
+            for document_id in doc_id_list:
+                gold_doc_to_items.setdefault(document_id, set()).add(
+                    self._norm(item_label)
+                )
+
+        # Build predictions: doc_id -> set(items)
+        pred_doc_to_items: Dict[str, set] = {}
+        with open(preds_jsonl, "r", encoding="utf-8") as fp_in:
+            for line in fp_in:
+                row = json.loads(line.strip())
+                document_id = str(row.get("id", ""))
+                items_list = row.get("terms" if key == "term" else "types", [])
+                pred_doc_to_items[document_id] = {
+                    self._norm(x) for x in items_list if isinstance(x, str)
+                }
+
+        # Micro counts
+        true_positive = false_positive = false_negative = 0
+        all_document_ids = set(gold_doc_to_items.keys()) | set(pred_doc_to_items.keys())
+        for document_id in all_document_ids:
+            gold_set = gold_doc_to_items.get(document_id, set())
+            pred_set = pred_doc_to_items.get(document_id, set())
+            true_positive += len(gold_set & pred_set)
+            false_positive += len(pred_set - gold_set)
+            false_negative += len(gold_set - pred_set)
+
+        precision = (
+            true_positive / (true_positive + false_positive)
+            if (true_positive + false_positive)
+            else 0.0
+        )
+        recall = (
+            true_positive / (true_positive + false_negative)
+            if (true_positive + false_negative)
+            else 0.0
+        )
+        f1 = (
+            2 * precision * recall / (precision + recall)
+            if (precision + recall)
+            else 0.0
+        )
+        return f1
+
+    def predict_types_from_terms(
+        self,
+        *,
+        doc_terms_jsonl: Optional[str] = None,  # formerly a1_results_jsonl
+        doc_terms_list: Optional[List[Dict]] = None,  # formerly a1_results_list
+        few_shot_jsonl: Optional[
+            str
+        ] = None,  # JSONL lines: {"term":"...", "types":[...]}
+        rag_terms_json: Optional[
+            str
+        ] = None,  # JSON list; items may contain "term" and "RAG":[...]
+        random_few_shot: Optional[int] = 3,
+        model_id: str = "Qwen/Qwen2.5-1.5B-Instruct",
+        use_structured_output: bool = True,
+        seed: int = 42,
+        out_terms2types: str = "terms2types_pred.json",
+        out_types2docs: str = "types2docs_pred.json",
+    ) -> Dict[str, Any]:
+        """
+        Predict types for each unique term extracted per document and derive a types→docs map.
+
+        Parameters
+        ----------
+        doc_terms_jsonl : Optional[str]
+            Path to JSONL with lines like {"id": "...", "terms": [...]} or a JSON with {"results":[...]}.
+        doc_terms_list : Optional[List[Dict]]
+            In-memory results like [{"id":"...","extracted_terms":[...]}] or {"id":"...","terms":[...]}.
+        few_shot_jsonl : Optional[str]
+            Global few-shot exemplars: one JSON object per line with {"term": "...", "types":[...]}.
+        rag_terms_json : Optional[str]
+            Optional per-term RAG exemplars: a JSON list of {"term": "...", "RAG":[{"term": "...", "types":[...]}]}.
+        random_few_shot : Optional[int]
+            If provided, randomly select up to this many few-shot examples for each prediction.
+        model_id : str
+            HF model id used specifically for term→types predictions.
+        use_structured_output : bool
+            If True and outlines is available, enforce structured {"types":[...]} output.
+        seed : int
+            Random seed for reproducibility.
+        out_terms2types : str
+            Output JSON path for list of {"term": "...", "predicted_types":[...]}.
+        out_types2docs : str
+            Output JSON path for dict {"TYPE":[doc_ids,...], ...}.
+
+        Returns
+        -------
+        Dict[str, Any]
+            Summary with predictions and counts.
+        """
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
+
+        # Load normalized document→terms results
+        doc_term_extractions = self._load_doc_term_extractions(
+            results_json_path=doc_terms_jsonl,
+            in_memory_results=doc_terms_list,
+        )
+        if not doc_term_extractions:
+            raise ValueError(
+                "No document→terms results provided (doc_terms_jsonl/doc_terms_list)."
+            )
+
+        # Prepare unique term list and term→doc occurrences
+        unique_terms = self._collect_unique_terms_from_extractions(doc_term_extractions)
+        term_to_doc_ids_map = self._build_term_to_doc_ids(doc_term_extractions)
+
+        # Load optional global few-shot examples
+        global_few_shot_examples: List[Dict] = []
+        if few_shot_jsonl and os.path.exists(few_shot_jsonl):
+            with open(few_shot_jsonl, "r", encoding="utf-8") as few_shot_file:
+                for raw_line in few_shot_file:
+                    raw_line = raw_line.strip()
+                    if not raw_line:
+                        continue
+                    try:
+                        json_obj = json.loads(raw_line)
+                    except Exception:
+                        continue
+                    if (
+                        isinstance(json_obj, dict)
+                        and "term" in json_obj
+                        and "types" in json_obj
+                    ):
+                        global_few_shot_examples.append(json_obj)
+
+        # Optional per-term RAG examples: {normalized_term -> [examples]}
+        rag_examples_lookup: Dict[str, List[Dict]] = {}
+        if rag_terms_json and os.path.exists(rag_terms_json):
+            try:
+                rag_payload = self._load_json(rag_terms_json)
+                if isinstance(rag_payload, list):
+                    for rag_item in rag_payload:
+                        if isinstance(rag_item, dict):
+                            normalized_term = self._normalize_term(
+                                rag_item.get("term", "")
+                            )
+                            rag_examples_lookup[normalized_term] = rag_item.get(
+                                "RAG", []
+                            )
+            except Exception:
+                pass
+
+        # Load a small chat LLM dedicated to Term→Types
+        typing_model, typing_tokenizer = self._load_llm_for_types(model_id)
+
+        # Predict types per term
+        term_to_predicted_types_list: List[Dict] = []
+        for term_text in unique_terms:
+            normalized_term = self._normalize_term(term_text)
+
+            # Prefer per-term RAG for this term, else use global few-shot
+            few_shot_examples_for_term = (
+                rag_examples_lookup.get(normalized_term, None)
+                or global_few_shot_examples
+            )
+
+            # Build conversation and prompt
+            conversation_messages = self._build_conv_for_type_infer(
+                term=term_text,
+                few_shot_examples=few_shot_examples_for_term,
+                random_k=random_few_shot,
+            )
+            typing_prompt_string = self._apply_chat_template_safe_types(
+                typing_tokenizer, conversation_messages
+            )
+
+            predicted_types: List[str] = []
+            raw_generation_text: str = ""
+
+            # Structured JSON path (if requested and available)
+            if (
+                use_structured_output
+                and OUTLINES_AVAILABLE
+                and _PredictedTypesSchema is not None
+            ):
+                try:
+                    outlines_model = OutlinesTFModel(typing_model, typing_tokenizer)  # type: ignore
+                    generator = outlines_generate_json(
+                        outlines_model, _PredictedTypesSchema
+                    )  # type: ignore
+                    structured = generator(typing_prompt_string, max_tokens=512)
+                    predicted_types = [
+                        label for label in structured.types if isinstance(label, str)
+                    ]
+                    raw_generation_text = json.dumps(
+                        {"types": predicted_types}, ensure_ascii=False
+                    )
+                except Exception:
+                    # Fall back to greedy decoding
+                    use_structured_output = False
+
+            # Greedy decode fallback
+            if (
+                not use_structured_output
+                or not OUTLINES_AVAILABLE
+                or _PredictedTypesSchema is None
+            ):
+                tokenized_prompt = typing_tokenizer(
+                    typing_prompt_string,
+                    return_tensors="pt",
+                    truncation=True,
+                    max_length=2048,
+                )
+                if torch.cuda.is_available():
+                    tokenized_prompt = {
+                        name: tensor.cuda() for name, tensor in tokenized_prompt.items()
+                    }
+                with torch.no_grad():
+                    output_ids = typing_model.generate(
+                        **tokenized_prompt,
+                        max_new_tokens=256,
+                        do_sample=False,
+                        num_beams=1,
+                        pad_token_id=typing_tokenizer.eos_token_id,
+                    )
+                new_token_span = output_ids[0][tokenized_prompt["input_ids"].shape[1] :]
+                raw_generation_text = typing_tokenizer.decode(
+                    new_token_span, skip_special_tokens=True
+                )
+                predicted_types = self._extract_types_from_text(raw_generation_text)
+
+            term_to_predicted_types_list.append(
+                {
+                    "term": term_text,
+                    "predicted_types": sorted(set(predicted_types)),
+                }
+            )
+
+        # 7) Build types→docs from (term→types) and (term→docs)
+        types_to_doc_id_set: Dict[str, set] = {}
+        for term_prediction in term_to_predicted_types_list:
+            normalized_term = self._normalize_term(term_prediction["term"])
+            doc_ids_for_term = term_to_doc_ids_map.get(normalized_term, [])
+            for type_label in term_prediction.get("predicted_types", []):
+                types_to_doc_id_set.setdefault(type_label, set()).update(
+                    doc_ids_for_term
+                )
+
+        types_to_doc_ids: Dict[str, List[str]] = {
+            type_label: sorted(doc_id_set)
+            for type_label, doc_id_set in types_to_doc_id_set.items()
+        }
+
+        # 8) Save outputs
+        os.makedirs(os.path.dirname(out_terms2types) or ".", exist_ok=True)
+        with open(out_terms2types, "w", encoding="utf-8") as fp_terms2types:
+            json.dump(
+                term_to_predicted_types_list,
+                fp_terms2types,
+                ensure_ascii=False,
+                indent=2,
+            )
+
+        os.makedirs(os.path.dirname(out_types2docs) or ".", exist_ok=True)
+        with open(out_types2docs, "w", encoding="utf-8") as fp_types2docs:
+            json.dump(types_to_doc_ids, fp_types2docs, ensure_ascii=False, indent=2)
+
+        # Cleanup VRAM if any
+        del typing_model, typing_tokenizer
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return {
+            "terms2types_pred": term_to_predicted_types_list,
+            "types2docs_pred": types_to_doc_ids,
+            "unique_terms": len(unique_terms),
+            "types_count": len(types_to_doc_ids),
+        }
+
+    def _load_json(self, path: str) -> Dict[str, Any]:
+        """Load a JSON file from disk and return its parsed object."""
+        with open(path, "r", encoding="utf-8") as file_obj:
+            return json.load(file_obj)
+
+    def _iter_json_objects(self, blob: str) -> Iterable[Dict[str, Any]]:
+        """
+        Iterate over *all* JSON objects found inside a string.
+
+        Supports cases where multiple JSON objects are concatenated back-to-back
+        in a single line. It skips stray commas/whitespace between objects.
+
+        Parameters
+        ----------
+        blob : str
+            A string that may contain one or more JSON objects.
+
+        Yields
+        ------
+        Dict[str, Any]
+            Each parsed JSON object.
+        """
+        json_decoder = json.JSONDecoder()
+        cursor_index, text_length = 0, len(blob)
+        while cursor_index < text_length:
+            # Skip whitespace/commas between objects
+            while cursor_index < text_length and blob[cursor_index] in " \t\r\n,":
+                cursor_index += 1
+            if cursor_index >= text_length:
+                break
+            try:
+                json_obj, end_index = json_decoder.raw_decode(blob, idx=cursor_index)
+            except JSONDecodeError:
+                # Can't decode from this position; stop scanning this chunk
+                break
+            yield json_obj
+            cursor_index = end_index
+
+    def _load_documents_jsonl(self, path: str) -> Dict[str, Dict[str, Any]]:
+        """
+        Robust reader that supports:
+          • True JSONL (one object per line)
+          • Lines with multiple concatenated JSON objects
+          • Whole file as a JSON array
+
+        Returns
+        -------
+        Dict[str, Dict[str, Any]]
+            Mapping doc_id -> full document row.
+        """
+        documents_by_id: Dict[str, Dict[str, Any]] = {}
+
+        with open(path, "r", encoding="utf-8") as file_obj:
+            content = file_obj.read().strip()
+
+        # Case A: whole-file JSON array
+        if content.startswith("["):
+            try:
+                json_array = json.loads(content)
+                if isinstance(json_array, list):
+                    for record in json_array:
+                        if not isinstance(record, dict):
+                            continue
+                        document_id = str(
+                            record.get("id")
+                            or record.get("doc_id")
+                            or (record.get("doc") or {}).get("id")
+                            or ""
+                        )
+                        if document_id:
+                            documents_by_id[document_id] = record
+                    return documents_by_id
+            except Exception:
+                # Fall back to line-wise handling if array parsing fails
+                pass
+
+        # Case B: treat as JSONL-ish; parse *all* objects per line
+        for raw_line in content.splitlines():
+            line = raw_line.strip()
+            if not line:
+                continue
+            for record in self._iter_json_objects(line):
+                if not isinstance(record, dict):
+                    continue
+                document_id = str(
+                    record.get("id")
+                    or record.get("doc_id")
+                    or (record.get("doc") or {}).get("id")
+                    or ""
+                )
+                if document_id:
+                    documents_by_id[document_id] = record
+
+        return documents_by_id
+
+    def _to_text(self, text_field: Any) -> str:
+        """
+        Convert a 'text' field into a single string (handles list-of-strings).
+
+        Parameters
+        ----------
+        text_field : Any
+            The value found under "text" in the dataset row.
+
+        Returns
+        -------
+        str
+            A single-string representation of the text.
+        """
+        if isinstance(text_field, str):
+            return text_field
+        if isinstance(text_field, list):
+            return " ".join(str(part) for part in text_field)
+        return str(text_field) if text_field is not None else ""
+
+    def _unique_preserve(self, values: List[str]) -> List[str]:
+        """
+        Deduplicate values while preserving the original order.
+
+        Parameters
+        ----------
+        values : List[str]
+            Sequence possibly containing duplicates.
+
+        Returns
+        -------
+        List[str]
+            Sequence without duplicates, order preserved.
+        """
+        seen_values: set = set()
+        ordered_values: List[str] = []
+        for candidate in values:
+            if candidate not in seen_values:
+                seen_values.add(candidate)
+                ordered_values.append(candidate)
+        return ordered_values
+
+    def _norm(self, text: str) -> str:
+        """
+        Lowercased, single-spaced normalization (for comparisons).
+
+        Parameters
+        ----------
+        text : str
+            Input string.
+
+        Returns
+        -------
+        str
+            Normalized string.
+        """
+        return " ".join(text.lower().split())
+
+    def _normalize_term(self, term: str) -> str:
+        """
+        Normalization tailored for term keys / lookups.
+
+        Parameters
+        ----------
+        term : str
+            Term to normalize.
+
+        Returns
+        -------
+        str
+            Lowercased, trimmed and single-spaced term.
+        """
+        return " ".join(str(term).strip().split()).lower()
+
+    def _format_fewshot_block(
+        self,
+        system_prompt: str,
+        fewshot_examples: List[Tuple[str, str, List[str]]],
+        *,
+        key: str,
+        k: int = 6,
+    ) -> str:
+        """
+        Render a few-shot block like:
+
+        <SYSTEM PROMPT>
+
+        ### Example
+        User:
+        Title: ...
+        <text>
+        Assistant:
+        {"terms": [...]} or {"types": [...]}
+
+        Parameters
+        ----------
+        system_prompt : str
+            Instructional system text to prepend.
+        fewshot_examples : List[Tuple[str, str, List[str]]]
+            Examples as (title, text, labels_list).
+        key : str
+            Either "terms" or "types" depending on the task.
+        k : int
+            Number of examples to include.
+
+        Returns
+        -------
+        str
+            Formatted few-shot block text.
+        """
+        lines: List[str] = [system_prompt.strip(), ""]
+        for example_title, example_text, gold_list in fewshot_examples[:k]:
+            lines.append("### Example")
+            lines.append(f"User:\nTitle: {example_title}\n{example_text}")
+            lines.append(
+                f'Assistant:\n{{"{key}": '
+                + json.dumps(gold_list, ensure_ascii=False)
+                + "}"
+            )
+        return "\n".join(lines)
+
+    def _format_user_block(self, title: str, text: str) -> str:
+        """
+        Format the 'Task' block for the current document.
+
+        Parameters
+        ----------
+        title : str
+            Document title.
+        text : str
+            Document text (single string).
+
+        Returns
+        -------
+        str
+            Formatted user block.
+        """
+        return f"### Task\nUser:\nTitle: {title}\n{text}"
+
+    def _parse_json_list(self, generated_text: str, *, key: str) -> List[str]:
+        """
+        Extract a list from model output, trying:
+          1) JSON object with the key ({"terms":[...]} or {"types":[...]}).
+          2) Any top-level JSON array.
+          3) Fallback: comma-split.
+
+        Parameters
+        ----------
+        generated_text : str
+            Raw generation text to parse.
+        key : str
+            "terms" or "types".
+
+        Returns
+        -------
+        List[str]
+            Parsed strings (best-effort).
+        """
+        # 1) Try a JSON object and read key
+        try:
+            object_match = self._json_object_regex.search(generated_text)
+            if object_match:
+                json_obj = json.loads(object_match.group(0))
+                json_array = json_obj.get(key)
+                if isinstance(json_array, list):
+                    return [value for value in json_array if isinstance(value, str)]
+        except Exception:
+            pass
+
+        # 2) Any JSON array
+        try:
+            array_match = self._json_array_regex.search(generated_text)
+            if array_match:
+                json_array = json.loads(array_match.group(0))
+                if isinstance(json_array, list):
+                    return [value for value in json_array if isinstance(value, str)]
+        except Exception:
+            pass
+
+        # 3) Fallback: comma-split (last resort)
+        if "," in generated_text:
+            return [
+                part.strip().strip('"').strip("'")
+                for part in generated_text.split(",")
+                if part.strip()
+            ]
+        return []
+
+    def _apply_chat_template_safe_types(
+        self, tokenizer: AutoTokenizer, messages: List[Dict[str, str]]
+    ) -> str:
+        """
+        Safely build a prompt string for chat models. Uses the model's chat template
+        when available; otherwise falls back to a simple concatenation.
+        """
+        try:
+            return tokenizer.apply_chat_template(
+                messages, add_generation_prompt=True, tokenize=False
+            )
+        except Exception:
+            system_text = next(
+                (m["content"] for m in messages if m.get("role") == "system"), ""
+            )
+            last_user_text = next(
+                (m["content"] for m in reversed(messages) if m.get("role") == "user"),
+                "",
+            )
+            return f"{system_text}\n\nUser:\n{last_user_text}\n\nAssistant:"
+
+    def _build_conv_for_type_infer(
+        self,
+        term: str,
+        few_shot_examples: Optional[List[Dict]] = None,
+        random_k: Optional[int] = None,
+    ) -> List[Dict[str, str]]:
+        """
+        Create a chat-style conversation for a single term→types query,
+        optionally prepending few-shot examples.
+        """
+        messages: List[Dict[str, str]] = [
+            {"role": "system", "content": self._system_prompt_term_to_types}
+        ]
+        examples = list(few_shot_examples or [])
+        if random_k and len(examples) > random_k:
+            import random as _rnd
+
+            examples = _rnd.sample(examples, random_k)
+        for exemplar in examples:
+            example_term = exemplar.get("term", "")
+            example_types = exemplar.get("types", [])
+            messages.append({"role": "user", "content": f"Term: {example_term}"})
+            messages.append(
+                {
+                    "role": "assistant",
+                    "content": json.dumps({"types": example_types}, ensure_ascii=False),
+                }
+            )
+        messages.append({"role": "user", "content": f"Term: {term}"})
+        return messages
+
+    def _extract_types_from_text(self, generated_text: str) -> List[str]:
+        """
+        Parse {"types":[...]} from a free-form generation.
+        """
+        try:
+            object_match = re.search(r'\{[^}]*"types"[^}]*\}', generated_text)
+            if object_match:
+                json_obj = json.loads(object_match.group(0))
+                types_array = json_obj.get("types", [])
+                return [
+                    type_label
+                    for type_label in types_array
+                    if isinstance(type_label, str)
+                ]
+        except Exception:
+            pass
+        return []
+
+    def _load_llm_for_types(
+        self, model_id: str
+    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
+        """
+        Load a *separate* small chat model for Term→Types (keeps LocalAutoLLM untouched).
+        """
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+            device_map="auto" if torch.cuda.is_available() else None,
+        )
+        return model, tokenizer
+
+    def _load_doc_term_extractions(
+        self,
+        *,
+        results_json_path: Optional[str] = None,
+        in_memory_results: Optional[List[Dict]] = None,
+    ) -> List[Dict]:
+        """
+        Normalize document→terms outputs to a list of:
+            {"id": "<doc_id>", "extracted_terms": ["...", ...]}
+
+        Accepts either:
+          - in_memory_results (list of dicts)
+          - results_json_path pointing to:
+              • a JSONL file with lines: {"id": "...", "terms": [...]}
+              • OR a JSON file with {"results":[{"id":..., "extracted_terms": [...]}, ...]}
+              • OR a JSON list of dicts
+        """
+        normalized_records: List[Dict] = []
+
+        def _coerce_to_record(source_row: Dict) -> Optional[Dict]:
+            document_id = str(source_row.get("id", "")) or str(
+                source_row.get("doc_id", "")
+            )
+            if not document_id:
+                return None
+            terms = source_row.get("extracted_terms")
+            if terms is None:
+                terms = source_row.get("terms")
+            if (
+                terms is None
+                and "payload" in source_row
+                and isinstance(source_row["payload"], dict)
+            ):
+                terms = source_row["payload"].get("terms")
+            if not isinstance(terms, list):
+                terms = []
+            return {
+                "id": document_id,
+                "extracted_terms": [t for t in terms if isinstance(t, str)],
+            }
+
+        if in_memory_results is not None:
+            for source_row in in_memory_results:
+                coerced_record = _coerce_to_record(source_row)
+                if coerced_record:
+                    normalized_records.append(coerced_record)
+            return normalized_records
+
+        if not results_json_path:
+            raise ValueError("Provide results_json_path or in_memory_results")
+
+        # Detect JSON vs JSONL by extension (best-effort)
+        if results_json_path.endswith(".jsonl"):
+            with open(results_json_path, "r", encoding="utf-8") as file_in:
+                for raw_line in file_in:
+                    raw_line = raw_line.strip()
+                    if not raw_line:
+                        continue
+                    # Multiple concatenated objects per line? Iterate them all.
+                    for json_obj in self._iter_json_objects(raw_line):
+                        if isinstance(json_obj, dict):
+                            coerced_record = _coerce_to_record(json_obj)
+                            if coerced_record:
+                                normalized_records.append(coerced_record)
+        else:
+            payload_obj = self._load_json(results_json_path)
+            if isinstance(payload_obj, dict) and "results" in payload_obj:
+                for source_row in payload_obj["results"]:
+                    coerced_record = _coerce_to_record(source_row)
+                    if coerced_record:
+                        normalized_records.append(coerced_record)
+            elif isinstance(payload_obj, list):
+                for source_row in payload_obj:
+                    if isinstance(source_row, dict):
+                        coerced_record = _coerce_to_record(source_row)
+                        if coerced_record:
+                            normalized_records.append(coerced_record)
+
+        return normalized_records
+
+    def _collect_unique_terms_from_extractions(
+        self, doc_term_extractions: List[Dict]
+    ) -> List[str]:
+        """
+        Collect unique terms (original casing) from normalized document→terms results.
+        """
+        seen_normalized_terms: set = set()
+        ordered_unique_terms: List[str] = []
+        for record in doc_term_extractions:
+            for term_text in record.get("extracted_terms", []):
+                normalized = self._normalize_term(term_text)
+                if normalized and normalized not in seen_normalized_terms:
+                    seen_normalized_terms.add(normalized)
+                    ordered_unique_terms.append(term_text.strip())
+        return ordered_unique_terms
+
+    def _build_term_to_doc_ids(
+        self, doc_term_extractions: List[Dict]
+    ) -> Dict[str, List[str]]:
+        """
+        Build lookup: normalized_term -> sorted unique list of doc_ids.
+        """
+        term_to_doc_set: Dict[str, set] = {}
+        for record in doc_term_extractions:
+            document_id = str(record.get("id", ""))
+            for term_text in record.get("extracted_terms", []):
+                normalized = self._normalize_term(term_text)
+                if not normalized or not document_id:
+                    continue
+                term_to_doc_set.setdefault(normalized, set()).add(document_id)
+        return {
+            normalized_term: sorted(doc_ids)
+            for normalized_term, doc_ids in term_to_doc_set.items()
+        }