OntoLearner 1.4.10__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ontolearner/VERSION +1 -1
- ontolearner/base/learner.py +41 -18
- ontolearner/evaluation/metrics.py +72 -32
- ontolearner/learner/__init__.py +3 -2
- ontolearner/learner/label_mapper.py +5 -4
- ontolearner/learner/llm.py +257 -0
- ontolearner/learner/prompt.py +40 -5
- ontolearner/learner/rag/__init__.py +14 -0
- ontolearner/learner/{rag.py → rag/rag.py} +7 -2
- ontolearner/learner/retriever/__init__.py +1 -1
- ontolearner/learner/retriever/{llm_retriever.py → augmented_retriever.py} +48 -39
- ontolearner/learner/retriever/learner.py +3 -4
- ontolearner/learner/taxonomy_discovery/alexbek.py +632 -310
- ontolearner/learner/taxonomy_discovery/skhnlp.py +216 -156
- ontolearner/learner/text2onto/__init__.py +1 -1
- ontolearner/learner/text2onto/alexbek.py +484 -1105
- ontolearner/learner/text2onto/sbunlp.py +498 -493
- ontolearner/ontology/biology.py +2 -3
- ontolearner/ontology/chemistry.py +16 -18
- ontolearner/ontology/ecology_environment.py +2 -3
- ontolearner/ontology/general.py +4 -6
- ontolearner/ontology/material_science_engineering.py +64 -45
- ontolearner/ontology/medicine.py +2 -3
- ontolearner/ontology/scholarly_knowledge.py +6 -9
- ontolearner/processor.py +3 -3
- ontolearner/text2onto/splitter.py +69 -6
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/METADATA +2 -2
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/RECORD +30 -29
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/WHEEL +1 -1
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -4,7 +4,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#
+# https://opensource.org/licenses/MIT
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,587 +12,592 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
-import random
-import re
 import ast
 import gc
-
+import random
+import re
 from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Optional, Set
 
 import torch
-from transformers import
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
-from ...base import AutoLearner
+from ...base import AutoLearner
 
-
-# -----------------------------------------------------------------------------
-# Concrete AutoLLM: local HF wrapper that follows the AutoLLM interface
-# -----------------------------------------------------------------------------
-class LocalAutoLLM(AutoLLM):
+class SBUNLPFewShotLearner(AutoLearner):
     """
-
-
+    Public API expected by the pipeline:
+    - `load(model_id=...)`
+    - `fit(train_data, task=..., ontologizer=...)`
+    - `predict(test_data, task=..., ontologizer=...)`
+
+    Expected input bundle format (train/test):
+    - "documents": list of dicts, each with keys: {"id", "title", "text"}
+    - "terms2docs": dict mapping term -> list of doc_ids
+    - "terms2types": optional dict mapping term -> list of types
+
+    Prediction output payload (pipeline wraps this):
+    - {"terms": [{"doc_id": str, "term": str}, ...],
+       "types": [{"doc_id": str, "type": str}, ...]}
     """
 
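As orientation for the rest of this diff, here is a minimal usage sketch of the new public API described in the docstring above. The import path is inferred from this file's location, and the model id and bundle contents are hypothetical; only `load`/`fit`/`predict`, the bundle keys, and the `sample_size`/`seed` kwargs come from the new code itself.

```python
# Sketch only: exercises load()/fit()/predict() as the class docstring describes.
from ontolearner.learner.text2onto.sbunlp import SBUNLPFewShotLearner  # assumed path

train_bundle = {
    "documents": [{"id": "d1", "title": "Alloys", "text": "Steel is an alloy of iron and carbon."}],
    "terms2docs": {"steel": ["d1"], "alloy": ["d1"]},
    "terms2types": {"steel": ["Material"]},  # optional key
}
test_bundle = {"documents": [{"id": "d2", "title": "Polymers", "text": "Nylon is a synthetic polymer."}]}

learner = SBUNLPFewShotLearner(llm_model_id="Qwen/Qwen2.5-0.5B-Instruct", device="cpu")  # hypothetical model id
learner.load()                                    # falls back to llm_model_id from __init__
learner.fit(train_bundle, task="text2onto", sample_size=4, seed=123)
payload = learner.predict(test_bundle, task="text2onto")
# payload == {"terms": [{"doc_id": ..., "term": ...}, ...], "types": [{"doc_id": ..., "type": ...}, ...]}
```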
     def __init__(
-        self, label_mapper: Any = None, device: str = "cpu", token: str = ""
-    ) -> None:
-        super().__init__(label_mapper=label_mapper, device=device, token=token)
-        self.model = None
-        self.tokenizer = None
-
-    def load(
         self,
-
+        llm_model_id: Optional[str] = None,
+        device: str = "cpu",
         load_in_4bit: bool = False,
-
+        max_new_tokens: int = 256,
         trust_remote_code: bool = True,
-    ):
-        """
+    ) -> None:
+        """
+        Initialize the few-shot learner.
+
+        Args:
+            llm_model_id: Default HF model id to load if `load()` is called without one.
+            device: "cpu" or a CUDA device identifier (e.g. "cuda").
+            load_in_4bit: Whether to attempt 4-bit quantized loading (bitsandbytes).
+            max_new_tokens: Maximum tokens to generate per prompt.
+            retriever_model_id: Unused (kept for compatibility).
+            top_k: Unused (kept for compatibility).
+            trust_remote_code: Forwarded to HF loaders (use with caution).
+        """
+        super().__init__()
+        self.device = device
+        self.max_new_tokens = int(max_new_tokens)
 
-
-
+        self._default_model_id = llm_model_id
+        self._load_in_4bit_default = bool(load_in_4bit)
+        self._trust_remote_code_default = bool(trust_remote_code)
 
-        #
-        self.
-
-
-
-
+        # HF objects
+        self.model: Optional[AutoModelForCausalLM] = None
+        self.tokenizer: Optional[AutoTokenizer] = None
+
+        self._is_loaded = False
+        self._loaded_model_id: Optional[str] = None
 
-
+        # Cached few-shot example blocks built during `fit()`
+        self.few_shot_terms_block: str = ""
+        self.few_shot_types_block: str = ""
+
+    def load(self, model_id: Optional[str] = None, **kwargs: Any) -> None:
+        """
+        Load the underlying HF causal LM and tokenizer.
+
+        LearnerPipeline typically calls: `learner.load(model_id=llm_id)`.
+
+        Args:
+            model_id: HF model id. If None, uses `llm_model_id` from __init__.
+            **kwargs:
+                load_in_4bit: override default 4-bit loading.
+                trust_remote_code: override default trust_remote_code.
+        """
+        resolved_model_id = model_id or self._default_model_id
+        if not resolved_model_id:
+            raise ValueError(
+                f"No model_id provided to {self.__class__.__name__}.load() and no llm_model_id in __init__."
+            )
+
+        load_in_4bit = bool(kwargs.get("load_in_4bit", self._load_in_4bit_default))
+        trust_remote_code = bool(kwargs.get("trust_remote_code", self._trust_remote_code_default))
+
+        # Avoid re-loading same model
+        if self._is_loaded and self._loaded_model_id == resolved_model_id:
+            return
+
+        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+        tokenizer = AutoTokenizer.from_pretrained(resolved_model_id, trust_remote_code=trust_remote_code)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        self.tokenizer = tokenizer
+
+        quantization_config = None
         if load_in_4bit:
-
-            quant_config = BitsAndBytesConfig(
+            quantization_config = BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_compute_dtype=torch.float16,
                 bnb_4bit_use_double_quant=True,
                 bnb_4bit_quant_type="nf4",
             )
-
-            torch_dtype_val = torch.float16
+            torch_dtype = torch.float16
 
-        # Set device mapping (auto for multi-GPU or single GPU, explicit CPU otherwise)
         device_map = "auto" if (self.device != "cpu") else {"": "cpu"}
 
-
-
-            model_id,
+        model = AutoModelForCausalLM.from_pretrained(
+            resolved_model_id,
             device_map=device_map,
-            torch_dtype=
-            quantization_config=
+            torch_dtype=torch_dtype,
+            quantization_config=quantization_config,
            trust_remote_code=trust_remote_code,
         )
 
-        # Ensure model is on the correct device (redundant if device_map="auto" but safe)
         if self.device == "cpu":
-
+            model.to("cpu")
 
-
-        self
-
-
-
-
-
-        """Generate continuations for a list of prompts, returning only the generated part."""
-        if self.model is None or self.tokenizer is None:
-            raise RuntimeError("Model/tokenizer not loaded. Call .load() first.")
+        self.model = model
+        self._is_loaded = True
+        self._loaded_model_id = resolved_model_id
+
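For context, the 4-bit branch above is the standard bitsandbytes-backed quantization path in transformers. A standalone sketch of the same configuration follows; the model id is a hypothetical choice, and running it requires a CUDA device with the bitsandbytes package installed:

```python
# Minimal sketch of the 4-bit loading path taken when load_in_4bit=True.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # hypothetical model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # same fallback as load() above

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # compute in fp16 on top of nf4 weights
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)
```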
+    def _invert_terms_to_docs_mapping(self, terms_to_documents: Dict[str, List[str]]) -> Dict[str, List[str]]:
+        """
+        Convert term->docs mapping to doc->terms mapping.
 
-
-
-        enc = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True)
-        input_ids = enc["input_ids"]
-        attention_mask = enc["attention_mask"]
+        Args:
+            terms_to_documents: Mapping from term to list of document IDs.
 
-
-
-
-
+        Returns:
+            Mapping from document ID to list of terms associated with it.
+        """
+        document_to_terms: DefaultDict[str, List[str]] = defaultdict(list)
+        for term, document_ids in (terms_to_documents or {}).items():
+            for document_id in document_ids or []:
+                document_to_terms[str(document_id)].append(str(term))
+        return dict(document_to_terms)
 
-
-
-
-
-
-
-
-            temperature > 0.0
-        ), # Use greedy decoding if temperature is 0.0
-            temperature=temperature,
-            top_p=top_p,
-            pad_token_id=self.tokenizer.eos_token_id,
-        )
+    def _derive_document_to_types(
+        self,
+        terms_to_documents: Dict[str, List[str]],
+        terms_to_types: Optional[Dict[str, List[str]]],
+    ) -> Dict[str, List[str]]:
+        """
+        Derive doc->types mapping using (terms->docs) and (terms->types).
 
-
-
-
-        full_decoded_text = self.tokenizer.decode(
-            output_ids, skip_special_tokens=True
-        )
-        prompt_text = self.tokenizer.decode(input_ids[i], skip_special_tokens=True)
+        Args:
+            terms_to_documents: term -> [doc_id...]
+            terms_to_types: term -> [type...]
 
-
-
-
-
-
-        prompt_len = input_ids.shape[1]
-        generated_tail = self.tokenizer.decode(
-            output_ids[prompt_len:], skip_special_tokens=True
-        ).strip()
-        decoded_outputs.append(generated_tail)
+        Returns:
+            doc_id -> sorted list of unique types.
+        """
+        if not terms_to_types:
+            return {}
 
-
+        document_to_types: DefaultDict[str, Set[str]] = defaultdict(set)
 
+        for term, document_ids in (terms_to_documents or {}).items():
+            candidate_types = terms_to_types.get(term, []) or []
+            for document_id in document_ids or []:
+                for candidate_type in candidate_types:
+                    if isinstance(candidate_type, str) and candidate_type.strip():
+                        document_to_types[str(document_id)].add(candidate_type.strip())
 
-
-# Main Learner: SBUNLPFewShotLearner (Task A Text2Onto)
-# -----------------------------------------------------------------------------
-class SBUNLPFewShotLearner(AutoLearner):
-    """
-    Concrete learner implementing the Task A Text2Onto pipeline (Term and Type Extraction).
-    It uses Few-Shot prompts generated from training data for inference.
-    """
+        return {doc_id: sorted(list(type_set)) for doc_id, type_set in document_to_types.items()}
 
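The two new helpers are pure dict transforms, so their behavior can be pinned down on toy data. A sketch (import path assumed; values illustrative, no model load required for these calls):

```python
from ontolearner.learner.text2onto.sbunlp import SBUNLPFewShotLearner  # assumed path

learner = SBUNLPFewShotLearner()  # defaults are enough for the mapping helpers

learner._invert_terms_to_docs_mapping({"steel": ["d1", "d2"], "alloy": ["d1"]})
# -> {"d1": ["steel", "alloy"], "d2": ["steel"]}

learner._derive_document_to_types(
    {"steel": ["d1", "d2"], "alloy": ["d1"]},
    {"steel": ["Material"], "alloy": ["Material", "Concept"]},
)
# -> {"d1": ["Concept", "Material"], "d2": ["Material"]}  # types deduplicated, sorted per doc
```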
-    def
-
-
-
-
-
-
-
+    def _truncate_text(self, text: str, max_chars: int) -> str:
+        """
+        Truncate text to a maximum number of characters (adds an ellipsis when truncated).
+
+        Args:
+            text: Input text.
+            max_chars: Maximum characters to keep. If <= 0, returns the original text.
+
+        Returns:
+            Truncated or original text.
+        """
+        if not max_chars or max_chars <= 0 or not text:
+            return text or ""
+        return (text[:max_chars] + "…") if len(text) > max_chars else text
 
-
-    def build_stratified_fewshot_prompt(
+    def build_few_shot_terms_block(
         self,
-
-
+        documents: List[Dict[str, Any]],
+        terms_to_documents: Dict[str, List[str]],
         sample_size: int = 28,
         seed: int = 123,
         max_chars_per_text: int = 1200,
     ) -> str:
         """
-
+        Build and cache the few-shot block for term extraction.
+
+        Strategy:
+        - Create strata by associated terms (doc -> associated term list).
+        - Sample proportionally across strata.
+        - Deduplicate by document id and top up from remaining docs if needed.
+
+        Args:
+            documents: Documents with keys: {"id","title","text"}.
+            terms_to_documents: Mapping term -> list of doc IDs.
+            sample_size: Desired number of examples in the block.
+            seed: RNG seed (local to this call).
+            max_chars_per_text: Text truncation limit per example.
+
+        Returns:
+            The formatted few-shot example block string.
         """
-        random.
-
-
-
-
-
-
-
-
-
-        num_sample_docs = min(sample_size, num_total_docs)
-
-        # Load the map of term -> [list of document IDs]
-        with open(terms_path, "r", encoding="utf-8") as file_handle:
-            term_to_doc_map = json.load(file_handle)
-
-        # Invert map: document ID -> [list of terms]
-        doc_id_to_terms_map = defaultdict(list)
-        for term, doc_ids in term_to_doc_map.items():
-            for doc_id in doc_ids:
-                doc_id_to_terms_map[doc_id].append(term)
-
-        # Define strata (groups of documents associated with specific terms)
-        strata_map = defaultdict(list)
-        for doc in corpus_documents:
-            doc_id = doc.get("id", "")
-            associated_terms = doc_id_to_terms_map.get(doc_id, ["no_term"])
+        rng = random.Random(seed)
+
+        document_to_terms = self._invert_terms_to_docs_mapping(terms_to_documents)
+        total_documents = len(documents)
+        target_sample_count = min(int(sample_size), total_documents)
+
+        strata: DefaultDict[str, List[Dict[str, Any]]] = defaultdict(list)
+        for document in documents:
+            document_id = str(document.get("id", ""))
+            associated_terms = document_to_terms.get(document_id, ["no_term"])
             for term in associated_terms:
-
+                strata[str(term)].append(document)
 
-        # Perform proportional sampling across strata
         sampled_documents: List[Dict[str, Any]] = []
-        for
-
-            if num_stratum_docs == 0:
+        for docs_in_stratum in strata.values():
+            if not docs_in_stratum:
                 continue
-
-
-
-
-
-
-
-
-
-
+            proportion = len(docs_in_stratum) / max(1, total_documents)
+            stratum_quota = int(target_sample_count * proportion)
+            if stratum_quota > 0:
+                sampled_documents.extend(rng.sample(docs_in_stratum, min(stratum_quota, len(docs_in_stratum))))
+
+        sampled_by_id = {str(d.get("id", "")): d for d in sampled_documents if d.get("id", "")}
+        final_documents = list(sampled_by_id.values())
+
+        if len(final_documents) > target_sample_count:
+            final_documents = rng.sample(final_documents, target_sample_count)
+        elif len(final_documents) < target_sample_count:
+            remaining_documents = [d for d in documents if str(d.get("id", "")) not in sampled_by_id]
+            additional_needed = min(target_sample_count - len(final_documents), len(remaining_documents))
+            if additional_needed > 0:
+                final_documents.extend(rng.sample(remaining_documents, additional_needed))
+
+        lines: List[str] = []
+        for document in final_documents:
+            document_id = str(document.get("id", ""))
+            title = str(document.get("title", ""))
+            text = self._truncate_text(str(document.get("text", "")), max_chars_per_text)
+            associated_terms = document_to_terms.get(document_id, [])
+
+            lines.append(
+                "Document ID: {doc_id}\n"
+                "Title: {title}\n"
+                "Text: {text}\n"
+                "Associated Terms: {terms}\n"
+                "----------------------------------------".format(
+                    doc_id=document_id,
+                    title=title,
+                    text=text,
+                    terms=associated_terms,
                 )
-
-        # Deduplicate sampled documents by ID and adjust count to exactly 'sample_size'
-        unique_docs_by_id = {}
-        for doc in sampled_documents:
-            unique_docs_by_id[doc.get("id", "")] = doc
-
-        final_sample_docs = list(unique_docs_by_id.values())
-
-        if len(final_sample_docs) > num_sample_docs:
-            final_sample_docs = random.sample(final_sample_docs, num_sample_docs)
-        elif len(final_sample_docs) < num_sample_docs:
-            remaining_docs = [
-                d for d in corpus_documents if d.get("id", "") not in unique_docs_by_id
-            ]
-            needed_count = min(
-                num_sample_docs - len(final_sample_docs), len(remaining_docs)
-            )
-            final_sample_docs.extend(random.sample(remaining_docs, needed_count))
-
-        # Format the few-shot exemplar text block
-        prompt_lines: List[str] = []
-        for doc in final_sample_docs:
-            doc_id = doc.get("id", "")
-            title = doc.get("title", "")
-            text = doc.get("text", "")
-
-            # Truncate text if it exceeds the maximum character limit
-            if max_chars_per_text and len(text) > max_chars_per_text:
-                text = text[:max_chars_per_text] + "…"
-
-            associated_terms = doc_id_to_terms_map.get(doc_id, [])
-            prompt_lines.append(
-                f"Document ID: {doc_id}\nTitle: {title}\nText: {text}\nAssociated Terms: {associated_terms}\n----------------------------------------"
             )
 
-
-        self.
-        return prompt_block
+        self.few_shot_terms_block = "\n".join(lines)
+        return self.few_shot_terms_block
 
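The sampling step above gives each stratum a quota proportional to its share of the corpus; because `int()` truncates, tiny strata can receive zero and the dedup-and-top-up pass then refills the sample. A worked instance of the arithmetic (numbers illustrative):

```python
# Quota arithmetic from build_few_shot_terms_block, on made-up sizes.
total_documents = 100
target_sample_count = min(28, total_documents)         # 28
stratum_size = 40                                      # documents sharing one term
proportion = stratum_size / max(1, total_documents)    # 0.4
stratum_quota = int(target_sample_count * proportion)  # int(11.2) == 11
assert stratum_quota == 11

small_quota = int(target_sample_count * (3 / 100))     # int(0.84) == 0
assert small_quota == 0  # contributes nothing up front; the top-up pass compensates
```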
-
-    def build_types_fewshot_block(
+    def build_few_shot_types_block(
         self,
-
-
-
-
-
+        documents: List[Dict[str, Any]],
+        terms_to_documents: Dict[str, List[str]],
+        terms_to_types: Optional[Dict[str, List[str]]] = None,
+        sample_size: int = 28,
+        seed: int = 123,
         max_chars_per_text: int = 800,
     ) -> str:
         """
-
-
+        Build and cache the few-shot block for type (class) extraction.
+
+        Prefers doc->types derived from `terms_to_types`; if absent, falls back to treating
+        associated terms as "types" for stratification (behavior-preserving fallback).
+
+        Args:
+            documents: Documents with keys: {"id","title","text"}.
+            terms_to_documents: Mapping term -> list of doc IDs.
+            terms_to_types: Optional mapping term -> list of types.
+            sample_size: Desired number of examples in the block.
+            seed: RNG seed (local to this call).
+            max_chars_per_text: Text truncation limit per example.
+
+        Returns:
+            The formatted few-shot example block string.
         """
-
-        docs_by_id = {}
-        with open(docs_jsonl, "r", encoding="utf-8") as file_handle:
-            for line in file_handle:
-                line_stripped = line.strip()
-                if line_stripped:
-                    try:
-                        doc = json.loads(line_stripped)
-                        doc_id = doc.get("id", "")
-                        if doc_id:
-                            docs_by_id[doc_id] = doc
-                    except json.JSONDecodeError:
-                        continue
-
-        # Load term -> [doc_id,...] map
-        with open(terms2doc_json, "r", encoding="utf-8") as file_handle:
-            term_to_doc_map = json.load(file_handle)
-
-        flags = 0 if case_sensitive else re.IGNORECASE
-        prompt_lines: List[str] = []
-
-        # Iterate over terms (which act as types in this context)
-        for term, doc_ids in term_to_doc_map.items():
-            escaped_term = re.escape(term)
-            # Create regex pattern for matching the term in the text
-            pattern = rf"\b{escaped_term}\b" if full_word else escaped_term
-            term_regex = re.compile(pattern, flags=flags)
-
-            picked_count = 0
-            for doc_id in doc_ids:
-                doc = docs_by_id.get(doc_id)
-                if not doc:
-                    continue
-
-                title = doc.get("title", "")
-                text = doc.get("text", "")
-
-                # Check if the term/type is actually present in the document text/title
-                if term_regex.search(f"{title} {text}"):
-                    text_content = text
-
-                    # Truncate text if necessary
-                    if max_chars_per_text and len(text_content) > max_chars_per_text:
-                        text_content = text_content[:max_chars_per_text] + "…"
-
-                    # Escape single quotes in the term for Python list formatting in the prompt
-                    term_for_prompt = term.replace("'", "\\'")
-
-                    prompt_lines.append(
-                        f"Document ID: {doc_id}\nTitle: {title}\nText: {text_content}\nAssociated Types: ['{term_for_prompt}']\n----------------------------------------"
-                    )
-                    picked_count += 1
-
-                    if picked_count >= sample_per_term:
-                        break  # Move to the next term
-
-        prompt_block = "\n".join(prompt_lines)
-        self.fewshot_types_block = prompt_block
-        return prompt_block
+        rng = random.Random(seed)
 
-
-
-
-
-
-
-
+        documents_by_id = {str(d.get("id", "")): d for d in documents if d.get("id", "")}
+
+        document_to_types = self._derive_document_to_types(terms_to_documents, terms_to_types)
+        if not document_to_types:
+            document_to_types = self._invert_terms_to_docs_mapping(terms_to_documents)
+
+        type_to_documents: DefaultDict[str, List[Dict[str, Any]]] = defaultdict(list)
+        for document_id, candidate_types in document_to_types.items():
+            document = documents_by_id.get(document_id)
+            if not document:
+                continue
+            for candidate_type in candidate_types:
+                type_to_documents[str(candidate_type)].append(document)
+
+        total_documents = len(documents)
+        target_sample_count = min(int(sample_size), total_documents)
+
+        sampled_documents: List[Dict[str, Any]] = []
+        for docs_in_stratum in type_to_documents.values():
+            if not docs_in_stratum:
+                continue
+            proportion = len(docs_in_stratum) / max(1, total_documents)
+            stratum_quota = int(target_sample_count * proportion)
+            if stratum_quota > 0:
+                sampled_documents.extend(rng.sample(docs_in_stratum, min(stratum_quota, len(docs_in_stratum))))
+
+        sampled_by_id = {str(d.get("id", "")): d for d in sampled_documents if d.get("id", "")}
+        final_documents = list(sampled_by_id.values())
+
+        if len(final_documents) > target_sample_count:
+            final_documents = rng.sample(final_documents, target_sample_count)
+        elif len(final_documents) < target_sample_count:
+            remaining_documents = [d for d in documents if str(d.get("id", "")) not in sampled_by_id]
+            additional_needed = min(target_sample_count - len(final_documents), len(remaining_documents))
+            if additional_needed > 0:
+                final_documents.extend(rng.sample(remaining_documents, additional_needed))
+
+        lines: List[str] = []
+        for document in final_documents:
+            document_id = str(document.get("id", ""))
+            title = str(document.get("title", ""))
+            text = self._truncate_text(str(document.get("text", "")), max_chars_per_text)
+
+            associated_types = document_to_types.get(document_id, [])
+            associated_types_escaped = [t.replace("'", "\\'") for t in associated_types]
+
+            lines.append(
+                "Document ID: {doc_id}\n"
+                "Title: {title}\n"
+                "Text: {text}\n"
+                "Associated Types: {types}\n"
+                "----------------------------------------".format(
+                    doc_id=document_id,
+                    title=title,
+                    text=text,
+                    types=associated_types_escaped,
+                )
+            )
+
+        self.few_shot_types_block = "\n".join(lines)
+        return self.few_shot_types_block
+
+    def _format_term_prompt(self, example_block: str, title: str, text: str) -> str:
         """
-
-
+        Format a prompt for term extraction.
+
+        Args:
+            example_block: Few-shot examples block.
+            title: Document title.
+            text: Document text.
+
+        Returns:
+            Prompt string.
         """
-
-
-
+        return (
+            f"{example_block}\n"
+            "[var]\n"
+            f"Title: {title}\n"
+            f"Text: {text}\n"
+            "[var]\n"
+            "Extract all relevant terms that could form the basis of an ontology from the above document.\n"
+            "Return ONLY a Python list like ['term1', 'term2', ...] and nothing else.\n"
+            "If no terms are found, return [].\n"
         )
-
-
-
+
+    def _format_type_prompt(self, example_block: str, title: str, text: str) -> str:
+        """
+        Format a prompt for type (class) extraction.
+
+        Args:
+            example_block: Few-shot examples block.
+            title: Document title.
+            text: Document text.
+
+        Returns:
+            Prompt string.
+        """
+        return (
+            f"{example_block}\n"
+            "[var]\n"
+            f"Title: {title}\n"
+            f"Text: {text}\n"
+            "[var]\n"
+            "Extract all relevant TYPES mentioned in the above document that could serve as ontology classes.\n"
+            "Only consider content inside the [var] ... [var] block.\n"
+            "Return ONLY a valid Python list like ['type1', 'type2'] and nothing else. If none, return [].\n"
        )
 
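To make the `[var]` sentinel format concrete, here is what `_format_term_prompt` renders for one document. A sketch: the import path is assumed, and the example block is shortened to a single stub entry:

```python
from ontolearner.learner.text2onto.sbunlp import SBUNLPFewShotLearner  # assumed path

learner = SBUNLPFewShotLearner()
prompt = learner._format_term_prompt(
    example_block="Document ID: d1\nTitle: Alloys\nText: Steel is an alloy.\nAssociated Terms: ['steel']\n----------------------------------------",
    title="Polymers",
    text="Nylon is a synthetic polymer.",
)
print(prompt)
# Document ID: d1
# Title: Alloys
# Text: Steel is an alloy.
# Associated Terms: ['steel']
# ----------------------------------------
# [var]
# Title: Polymers
# Text: Nylon is a synthetic polymer.
# [var]
# Extract all relevant terms that could form the basis of an ontology from the above document.
# Return ONLY a Python list like ['term1', 'term2', ...] and nothing else.
# If no terms are found, return [].
```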
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        return f"""{example_block}
-[var]
-Title: {title}
-Text: {text}
-[var]
-Extract all relevant TYPES mentioned in the above document that could serve as ontology classes.
-Only consider content inside the [var] ... [var] block.
-Return ONLY a valid Python list like ['type1', 'type2'] and nothing else. If none, return [].
-"""
-
-    def _parse_list_like(self, raw_string: str) -> List[str]:
-        """Try to extract a Python list of strings from model output robustly."""
-        processed_string = raw_string.strip()
-        if processed_string in ("[]", ""):
+    def _parse_python_list_of_strings(self, raw_text: str) -> List[str]:
+        """
+        Parse an LLM response intended to be a Python list of strings.
+
+        This parser is intentionally tolerant:
+        1) Try literal_eval on the full string
+        2) Else extract the first [...] block and literal_eval it
+        3) Else fallback to extracting quoted strings
+
+        Args:
+            raw_text: Model output.
+
+        Returns:
+            List of strings (possibly empty).
+        """
+        stripped = (raw_text or "").strip()
+        if stripped in ("", "[]"):
             return []
 
-        # 1. Try direct evaluation
         try:
-
-            if isinstance(
-
-            return [item for item in parsed_value if isinstance(item, str)]
+            parsed = ast.literal_eval(stripped)
+            if isinstance(parsed, list):
+                return [item for item in parsed if isinstance(item, str)]
         except Exception:
             pass
 
-
-
-        if bracket_match:
+        match = re.search(r"\[[\s\S]*?\]", stripped)
+        if match:
             try:
-
-                if isinstance(
-                return [item for item in
+                parsed = ast.literal_eval(match.group(0))
+                if isinstance(parsed, list):
+                    return [item for item in parsed if isinstance(item, str)]
            except Exception:
                pass
 
-
-
-        quoted_matches = re.findall(r"'([^']+)'|\"([^\"]+)\"", processed_string)
-        flattened_list = [a_match or b_match for a_match, b_match in quoted_matches]
-        return flattened_list
-
-    def _call_model_one(self, prompt: str, max_new_tokens: int = 120) -> str:
-        """Calls the underlying LocalAutoLLM for a single prompt. Returns the raw tail output."""
-        # self.model is an instance of LocalAutoLLM
-        model_output = self.model.generate(
-            [prompt], max_new_tokens=max_new_tokens, temperature=0.0, top_p=1.0
-        )
-        return model_output[0] if model_output else ""
+        quoted = re.findall(r"'([^']+)'|\"([^\"]+)\"", stripped)
+        return [a or b for a, b in quoted]
 
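Each fallback stage of the tolerant parser rescues a progressively messier reply. A sketch of all three stages on illustrative outputs (import path assumed):

```python
from ontolearner.learner.text2onto.sbunlp import SBUNLPFewShotLearner  # assumed path

parser = SBUNLPFewShotLearner()._parse_python_list_of_strings

parser("['steel', 'alloy']")                                    # stage 1: literal_eval on the full reply
# -> ['steel', 'alloy']
parser("Sure! Here you go: ['steel', 'alloy'] Hope it helps.")  # stage 2: first [...] block
# -> ['steel', 'alloy']
parser("terms: 'steel', \"alloy\"")                             # stage 3: quoted-string fallback
# -> ['steel', 'alloy']
parser("no list here")
# -> []
```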
-    def
-        self,
-        docs_test_jsonl: str,
-        out_jsonl: str,
-        max_lines: int = -1,
-        max_new_tokens: int = 120,
-    ) -> int:
+    def _generate_completion(self, prompt_text: str) -> str:
         """
-
-
+        Generate a completion for a single prompt (deterministic decoding).
+
+        Args:
+            prompt_text: Full prompt to send to the model.
+
+        Returns:
+            The generated completion text (prompt stripped where possible).
         """
-        if
-            raise RuntimeError("
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-                    num_written_terms += 1
-
-            # Lightweight memory management for long runs
-            if line_index % 50 == 0:
-                gc.collect()
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-
-        return num_written_terms
-
-    def predict_types(
+        if self.model is None or self.tokenizer is None:
+            raise RuntimeError("Model/tokenizer not loaded. Call .load() first.")
+
+        encoded = self.tokenizer([prompt_text], return_tensors="pt", padding=True, truncation=True)
+        input_ids = encoded["input_ids"]
+        attention_mask = encoded["attention_mask"]
+
+        model_device = next(self.model.parameters()).device
+        input_ids = input_ids.to(model_device)
+        attention_mask = attention_mask.to(model_device)
+
+        with torch.no_grad():
+            output_ids = self.model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=self.max_new_tokens,
+                do_sample=False,
+                temperature=0.0,
+                top_p=1.0,
+                pad_token_id=self.tokenizer.eos_token_id,
+            )[0]
+
+        decoded_full = self.tokenizer.decode(output_ids, skip_special_tokens=True)
+        decoded_prompt = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
+        if decoded_full.startswith(decoded_prompt):
+            return decoded_full[len(decoded_prompt) :].strip()
+
+        prompt_token_count = int(attention_mask[0].sum().item())
+        return self.tokenizer.decode(output_ids[prompt_token_count:], skip_special_tokens=True).strip()
+
+    def fit(
         self,
-
-
-
-
-    ) ->
-        """
-        Runs Type Extraction on the test documents and saves results to a JSONL file.
-        Returns: The count of individual types written.
-        """
-        if not self.fewshot_types_block:
-            raise RuntimeError("Few-shot block for types is empty. Call fit() first.")
-
-        num_written_types = 0
-        with (
-            open(docs_test_jsonl, "r", encoding="utf-8") as file_in,
-            open(out_jsonl, "w", encoding="utf-8") as file_out,
-        ):
-            for line_index, line in enumerate(file_in, start=1):
-                if 0 < max_lines < line_index:
-                    break
-
-                try:
-                    document = json.loads(line.strip())
-                except Exception:
-                    continue  # Skip malformed JSON lines
-
-                doc_id = document.get("id", "unknown")
-                title = document.get("title", "")
-                text = document.get("text", "")
-
-                # Construct and call model using the dedicated type prompt block
-                prompt = self._build_type_prompt(self.fewshot_types_block, title, text)
-                raw_output = self._call_model_one(prompt, max_new_tokens=max_new_tokens)
-                predicted_types = self._parse_list_like(raw_output)
-
-                # Write extracted types
-                for term_or_type in predicted_types:
-                    if isinstance(term_or_type, str) and term_or_type.strip():
-                        file_out.write(
-                            json.dumps({"doc_id": doc_id, "type": term_or_type.strip()})
-                            + "\n"
-                        )
-                        num_written_types += 1
-
-                if line_index % 50 == 0:
-                    gc.collect()
-                    if torch.cuda.is_available():
-                        torch.cuda.empty_cache()
-
-        return num_written_types
-
-    # --- Evaluation utilities (unchanged from prior definition, added docstrings) ---
-    def load_gold_pairs(self, terms2doc_path: str) -> Set[Tuple[str, str]]:
-        """Convert terms2docs JSON into a set of unique (doc_id, term) pairs, lowercased."""
-        gold_pairs = set()
-        with open(terms2doc_path, "r", encoding="utf-8") as file_handle:
-            term_to_doc_map = json.load(file_handle)
-
-        for term, doc_ids in term_to_doc_map.items():
-            clean_term = term.strip().lower()
-            for doc_id in doc_ids:
-                gold_pairs.add((doc_id, clean_term))
-        return gold_pairs
-
-    def load_predicted_pairs(
-        self, predicted_jsonl_path: str, key: str = "term"
-    ) -> Set[Tuple[str, str]]:
-        """Load predicted (doc_id, term/type) pairs from a JSONL file, lowercased."""
-        predicted_pairs = set()
-        with open(predicted_jsonl_path, "r", encoding="utf-8") as file_handle:
-            for line in file_handle:
-                try:
-                    entry = json.loads(line.strip())
-                except Exception:
-                    continue
-                doc_id = entry.get("doc_id")
-                value = entry.get(key)
-                if doc_id and value:
-                    predicted_pairs.add((doc_id, value.strip().lower()))
-        return predicted_pairs
-
-    def evaluate_extraction_f1(
-        self, terms2doc_path: str, predicted_jsonl: str, key: str = "term"
-    ) -> float:
+        train_data: Any,
+        task: str = "text2onto",
+        ontologizer: bool = False,
+        **kwargs: Any,
+    ) -> None:
         """
-
+        Build and cache few-shot blocks from the training split.
+
+        Args:
+            train_data: A split bundle dict. Must contain "documents" and "terms2docs".
+            task: Must be "text2onto".
+            ontologizer: Unused here (kept for signature compatibility).
+            **kwargs:
+                sample_size: Few-shot sample size per block (default 28).
+                seed: RNG seed (default 123).
         """
-
-
-        predicted_set = self.load_predicted_pairs(predicted_jsonl, key=key)
+        if task != "text2onto":
+            raise ValueError(f"{self.__class__.__name__} only supports task='text2onto' (got {task!r}).")
 
-
-
+        if not self._is_loaded:
+            self.load(model_id=self._default_model_id)
 
-
-
-
+        documents: List[Dict[str, Any]] = train_data.get("documents", []) or []
+        terms_to_documents: Dict[str, List[str]] = train_data.get("terms2docs", {}) or {}
+        terms_to_types: Optional[Dict[str, List[str]]] = train_data.get("terms2types", None)
 
-
-
+        sample_size = int(kwargs.get("sample_size", 28))
+        seed = int(kwargs.get("seed", 123))
 
-
-
+        self.build_few_shot_terms_block(
+            documents=documents,
+            terms_to_documents=terms_to_documents,
+            sample_size=sample_size,
+            seed=seed,
+        )
+        self.build_few_shot_types_block(
+            documents=documents,
+            terms_to_documents=terms_to_documents,
+            terms_to_types=terms_to_types,
+            sample_size=sample_size,
+            seed=seed,
        )
 
-
-
+    def predict(
+        self,
+        test_data: Any,
+        task: str = "text2onto",
+        ontologizer: bool = False,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """
+        Run term/type extraction over test documents.
 
-
-
-
-
-
-
-        print(f" 🎯 True Positives: {num_true_positives}")
+        Args:
+            test_data: A split bundle dict. Must contain "documents".
+            task: Must be "text2onto".
+            ontologizer: Unused here (kept for signature compatibility).
+            **kwargs:
+                max_docs: If > 0, limit number of docs processed.
 
-
+        Returns:
+            Prediction payload dict: {"terms": [...], "types": [...]}.
+        """
+        if task != "text2onto":
+            raise ValueError(f"{self.__class__.__name__} only supports task='text2onto' (got {task!r}).")
+
+        if not self.few_shot_terms_block or not self.few_shot_types_block:
+            raise RuntimeError("Few-shot blocks are empty. Pipeline should call fit() before predict().")
+
+        max_docs = int(kwargs.get("max_docs", -1))
+        documents: List[Dict[str, Any]] = test_data.get("documents", []) or []
+        if max_docs > 0:
+            documents = documents[:max_docs]
+
+        term_predictions: List[Dict[str, str]] = []
+        type_predictions: List[Dict[str, str]] = []
+
+        for doc_index, document in enumerate(documents, start=1):
+            document_id = str(document.get("id", "unknown"))
+            title = str(document.get("title", ""))
+            text = str(document.get("text", ""))
+
+            term_prompt = self._format_term_prompt(self.few_shot_terms_block, title, text)
+            extracted_terms = self._parse_python_list_of_strings(self._generate_completion(term_prompt))
+            for term in extracted_terms:
+                normalized_term = (term or "").strip()
+                if normalized_term:
+                    term_predictions.append({"doc_id": document_id, "term": normalized_term})
+
+            type_prompt = self._format_type_prompt(self.few_shot_types_block, title, text)
+            extracted_types = self._parse_python_list_of_strings(self._generate_completion(type_prompt))
+            for extracted_type in extracted_types:
+                normalized_type = (extracted_type or "").strip()
+                if normalized_type:
+                    type_predictions.append({"doc_id": document_id, "type": normalized_type})
+
+            if doc_index % 50 == 0:
+                gc.collect()
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+        # IMPORTANT: return only the prediction payload; LearnerPipeline wraps it.
+        return {"terms": term_predictions, "types": type_predictions}
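For reference, the in-memory payload `predict()` now returns has this shape (values hypothetical); LearnerPipeline is responsible for wrapping it, replacing the old JSONL-file outputs of `predict_terms`/`predict_types`:

```python
# Shape of the prediction payload returned by predict() (illustrative values).
payload = {
    "terms": [
        {"doc_id": "d2", "term": "nylon"},
        {"doc_id": "d2", "term": "synthetic polymer"},
    ],
    "types": [
        {"doc_id": "d2", "type": "polymer"},
    ],
}
```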