OntoLearner 1.4.10-py3-none-any.whl → 1.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ontolearner/VERSION +1 -1
- ontolearner/base/learner.py +41 -18
- ontolearner/evaluation/metrics.py +72 -32
- ontolearner/learner/__init__.py +3 -2
- ontolearner/learner/label_mapper.py +5 -4
- ontolearner/learner/llm.py +257 -0
- ontolearner/learner/prompt.py +40 -5
- ontolearner/learner/rag/__init__.py +14 -0
- ontolearner/learner/{rag.py → rag/rag.py} +7 -2
- ontolearner/learner/retriever/__init__.py +1 -1
- ontolearner/learner/retriever/{llm_retriever.py → augmented_retriever.py} +48 -39
- ontolearner/learner/retriever/learner.py +3 -4
- ontolearner/learner/taxonomy_discovery/alexbek.py +632 -310
- ontolearner/learner/taxonomy_discovery/skhnlp.py +216 -156
- ontolearner/learner/text2onto/__init__.py +1 -1
- ontolearner/learner/text2onto/alexbek.py +484 -1105
- ontolearner/learner/text2onto/sbunlp.py +498 -493
- ontolearner/ontology/biology.py +2 -3
- ontolearner/ontology/chemistry.py +16 -18
- ontolearner/ontology/ecology_environment.py +2 -3
- ontolearner/ontology/general.py +4 -6
- ontolearner/ontology/material_science_engineering.py +64 -45
- ontolearner/ontology/medicine.py +2 -3
- ontolearner/ontology/scholarly_knowledge.py +6 -9
- ontolearner/processor.py +3 -3
- ontolearner/text2onto/splitter.py +69 -6
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/METADATA +2 -2
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/RECORD +30 -29
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/WHEEL +1 -1
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/licenses/LICENSE +0 -0
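The bulk of this release is the rewrite of `ontolearner/learner/text2onto/alexbek.py`, shown in the diff below: the old file-based `LocalAutoLLM` pipeline is replaced by a retrieval-augmented few-shot learner (`AlexbekRAGFewShotLearner`). For orientation, this is the input shape the new class documents for `task="text2onto"`. The field names come from the class docstring in the diff; the concrete ids, terms, and types are illustrative:

```python
# Illustrative payload for task="text2onto" (shape taken from the new class
# docstring in the diff below; the concrete values are made up).
data = {
    "documents": [
        {"doc_id": "d1", "title": "Perovskite solar cells", "text": "..."},
        {"doc_id": "d2", "title": "Band gap tuning", "text": "..."},
    ],
    # term -> ids of the documents it occurs in
    "terms2docs": {"perovskite": ["d1"], "band gap": ["d1", "d2"]},
    # term -> gold ontology types
    "terms2types": {"perovskite": ["Material"], "band gap": ["Property"]},
}
```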
ontolearner/learner/text2onto/alexbek.py (reconstructed below; removed lines whose text the diff renderer left blank appear as a bare `-`):

```diff
@@ -12,1208 +12,587 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, List, Optional, Tuple, Iterable
 import json
-from json.decoder import JSONDecodeError
-import os
-import random
 import re
+from typing import Any, Dict, List, Optional
+from collections import defaultdict

 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM

-from ...base import AutoLearner,
+from ...base import AutoLearner, AutoRetriever

-
-    from outlines.models import Transformers as OutlinesTFModel
-    from outlines.generate import json as outlines_generate_json
-    from pydantic import BaseModel
-
-    class _PredictedTypesSchema(BaseModel):
-        """Schema used when generating structured JSON { "types": [...] }."""
-
-        types: List[str]
-
-    OUTLINES_AVAILABLE: bool = True
-except Exception:
-    # If outlines is unavailable, we will fall back to greedy decoding + regex parsing.
-    OUTLINES_AVAILABLE = False
-    _PredictedTypesSchema = None
-    OutlinesTFModel = None
-    outlines_generate_json = None
-
-
-class LocalAutoLLM(AutoLLM):
+class AlexbekRAGFewShotLearner(AutoLearner):
     """
-
-
-
-
-
+    What it does (2-stage):
+      1) doc -> terms
+         - retrieve top-k similar TRAIN documents (each has gold OL terms)
+         - build a few-shot chat prompt: (doc -> {"terms":[...]}) examples + target doc
+         - generate JSON {"terms":[...]} and parse it
+
+      2) term -> types
+         - retrieve top-k similar TRAIN terms (each has gold types)
+         - build a few-shot chat prompt: (term -> {"types":[...]}) examples + target term
+         - generate JSON {"types":[...]} and parse it
+
+    Training behavior (fit):
+      - builds two retrieval indices:
+          * doc_retriever index over JSON strings of train docs (with "OL" field = gold terms)
+          * term_retriever index over JSON strings of train term->types examples
+
+    Prediction behavior (predict):
+      - returns a dict compatible with OntoLearner evaluation_report:
+        {
+            "terms": [{"doc_id": "...", "term": "..."}, ...],
+            "types": [{"doc_id": "...", "type": "..."}, ...],
+        }
+
+    Expected data format for task="text2onto":
+        data = {
+            "documents": [ {"id"/"doc_id": str, "title": str, "text": str, ...}, ... ],
+            "terms2docs": { term(str): [doc_id(str), ...], ... }
+            "terms2types": { term(str): [type(str), ...], ... }
+        }
+
+    IMPORTANT:
+      - LearnerPipeline calls learner.load(model_id=llm_id). We accept that and override llm_model_id.
+      - We override tasks_data_former() so AutoLearner.fit/predict does NOT rewrite text2onto dicts.
+      - Device placement: we put the model exactly on the device string the user provides
+        ("cpu", "cuda", "cuda:0", "cuda:1", ...). No device_map="auto".
     """

-
+    TERM2TYPES_SYSTEM_PROMPT = (
+        "You are an expert in ontology and semantic type classification. Your task is to predict "
+        "the semantic types for given terms based on their context and similar examples.\n\n"
+        "Given a term, you should predict its semantic types from the domain-specific ontology. "
+        "Use the provided examples to understand the patterns and relationships between terms and their types.\n\n"
+        "Output your response as a JSON object with the following structure:\n"
+        '{\n "types": ["type1", "type2", ...]\n}\n\n'
+        "The types should be relevant semantic categories that best describe the given term."
+    )
+
+    DOC2TERMS_SYSTEM_PROMPT = (
+        "You are an expert in ontology term extraction.\n\n"
+        "TASK: Extract specific, relevant ontology terms from scientific documents.\n\n"
+        "INSTRUCTIONS:\n"
+        "- The following conversation contains few-shot examples showing correct term extraction patterns\n"
+        "- Study these examples carefully to understand the extraction style and approach\n"
+        "- Follow the EXACT same pattern and style demonstrated in the examples\n"
+        "- Extract only terms that actually appear in the document text\n"
+        "- Focus on domain-specific terminology, concepts, and technical terms\n\n"
+        "- The first three user-assistant conversation pairs serve as few-shot examples\n"
+        "- Each example shows: user provides a document, assistant extracts relevant terms\n"
+        "- Pay attention to the extraction patterns and term selection criteria in these examples\n\n"
+        "DO:\n"
+        "- Extract terms that are EXPLICITLY mentioned in the LAST document\n"
+        "- Follow the SAME extraction pattern as shown in examples\n"
+        "- Return unique terms without duplicates\n"
+        "- Use the same JSON format as demonstrated\n\n"
+        "DON'T:\n"
+        "- Hallucinate or invent terms not present in last the document\n"
+        "- Repeat the same term multiple times\n"
+        "- Deviate from the extraction style shown in examples\n\n"
+        "OUTPUT FORMAT: Return a JSON object with a single field 'terms' containing a list of extracted terms."
+    )
+
+    def __init__(
+        self,
+        llm_model_id: str,
+        retriever_model_id: str = "sentence-transformers/all-MiniLM-L6-v2",
+        device: str = "cpu",
+        top_k: int = 3,
+        max_new_tokens: int = 256,
+        max_input_length: int = 2048,
+        use_tfidf: bool = False,
+        seed: int = 42,
+        restrict_to_known_types: bool = True,
+        hf_token: str = "",
+        local_files_only: bool = False,
+        **kwargs: Any,
+    ):
         """
-        Initialize the local LLM holder.
-
         Parameters
         ----------
-
-
-
-
-
-
+        llm_model_id:
+            HuggingFace model id OR local path to a downloaded model directory.
+        retriever_model_id:
+            SentenceTransformer model id OR local path to a downloaded SBERT directory.
+        device:
+            Exact device string to place model on ("cpu", "cuda", "cuda:0", ...).
+        top_k:
+            Number of retrieved examples for few-shot prompting in each stage.
+        max_new_tokens:
+            Max tokens to generate for each prompt.
+        max_input_length:
+            Max prompt length before truncation.
+        use_tfidf:
+            If docs include TF-IDF suggestions (key "TF-IDF" or "tfidf_terms"), include them in prompts.
+        seed:
+            Seed for reproducibility.
+        restrict_to_known_types:
+            If True, append allowed type label list (from training) to system prompt in term->types stage.
+            This helps exact-match evaluation by discouraging invented labels.
+        hf_token:
+            HuggingFace token for gated models (optional).
+        local_files_only:
+            If True, Transformers will not try to reach the internet (requires local cache / local path).
+        """
+        super().__init__(**kwargs)
+
+        self.llm_model_id: str = llm_model_id
+        self.retriever_model_id: str = retriever_model_id
+        self.device: str = device
+        self.top_k: int = int(top_k)
+        self.max_new_tokens: int = int(max_new_tokens)
+        self.max_input_length: int = int(max_input_length)
+        self.use_tfidf: bool = bool(use_tfidf)
+        self.seed: int = int(seed)
+        self.restrict_to_known_types: bool = bool(restrict_to_known_types)
+        self.hf_token: str = hf_token or ""
+        self.local_files_only: bool = bool(local_files_only)
+
         self.model: Optional[AutoModelForCausalLM] = None
         self.tokenizer: Optional[AutoTokenizer] = None
+        self._loaded: bool = False

-
-
-
-        generation defaults.
-
-        Parameters
-        ----------
-        model_id : str
-            Model identifier resolvable by HF `from_pretrained`.
-        load_in_4bit : bool
-            If True and bitsandbytes is available, load using 4-bit quantization.
-        """
-        # Tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_id, padding_side="left", token=self.token
-        )
-        if self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        # Model (optionally quantized)
-        if load_in_4bit:
-            from transformers import BitsAndBytesConfig
+        # Internal retrievers (always used in method-1, even in "llm-only" pipeline mode)
+        self.doc_retriever = AutoRetriever()
+        self.term_retriever = AutoRetriever()

-
-
-
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_compute_dtype=torch.bfloat16,
-            )
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                device_map="auto",
-                quantization_config=quantization_config,
-                token=self.token,
-            )
-        else:
-            device_map = (
-                "auto" if (self.device != "cpu" and torch.cuda.is_available()) else None
-            )
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                device_map=device_map,
-                torch_dtype=torch.bfloat16
-                if torch.cuda.is_available()
-                else torch.float32,
-                token=self.token,
-            )
+        # Indexed corpora as JSON strings
+        self._doc_examples_json: List[str] = []
+        self._term_examples_json: List[str] = []

-        #
-
-        generation_cfg.do_sample = False
-        generation_cfg.temperature = None
-        generation_cfg.top_k = None
-        generation_cfg.top_p = None
-        generation_cfg.num_beams = 1
+        # Cached allowed type labels (for optional restriction)
+        self._allowed_types: List[str] = []

-    def
+    def tasks_data_former(self, data: Any, task: str, test: bool = False):
         """
-
+        Override base formatter: for task='text2onto' return data unchanged.
+        """
+        if task == "text2onto":
+            return data
+        return super().tasks_data_former(data=data, task=task, test=test)

-
-        ----------
-        prompts : List[str]
-            Prompts to generate for (batched).
-        max_new_tokens : int
-            Maximum number of new tokens per continuation.
-
-        Returns
-        -------
-        List[str]
-            Decoded new-token texts (no special tokens, stripped).
+    def load(self, **kwargs: Any):
         """
-
-        raise RuntimeError(
-            "Call .load(model_id) on LocalAutoLLM before generate()."
-        )
+        Called by LearnerPipeline as: learner.load(model_id=llm_id)

-
-
-
-
-
-
-
+        We accept overrides via kwargs:
+          - model_id / llm_model_id
+          - device, top_k, max_new_tokens, max_input_length, use_tfidf, seed, restrict_to_known_types
+          - hf_token, local_files_only
+        """
+        model_id = kwargs.get("model_id") or kwargs.get("llm_model_id")
+        if model_id:
+            self.llm_model_id = str(model_id)

-
-
-
-
-
-
-
-
+        for k in [
+            "device",
+            "top_k",
+            "max_new_tokens",
+            "max_input_length",
+            "use_tfidf",
+            "seed",
+            "restrict_to_known_types",
+            "hf_token",
+            "local_files_only",
+            "retriever_model_id",
+        ]:
+            if k in kwargs:
+                setattr(self, k, kwargs[k])

-
-
-        return [
-            self.tokenizer.decode(row, skip_special_tokens=True).strip()
-            for row in continuation_token_ids
-        ]
+        if self._loaded:
+            return

+        torch.manual_seed(self.seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(self.seed)

-
-
-
-
-        Public API (A1 + convenience):
-          - fit(train_docs_jsonl, terms2doc_json, sample_size=24, seed=42)
-          - predict_terms(docs_test_jsonl, out_jsonl, max_new_tokens=128, few_shot_k=6) -> int
-          - predict_types(docs_test_jsonl, out_jsonl, max_new_tokens=128, few_shot_k=6) -> int
-          - evaluate_extraction_f1(gold_item2docs_json, preds_jsonl, key="term"|"type") -> float
-
-        Option A (A2, term→types) bridge:
-          - predict_types_from_terms_option_a(...)
-            Reads your A1 results (docs→terms), predicts types for each term, and
-            writes two files: terms2types_pred.json + types2docs_pred.json
-        """
+        dev = str(self.device).strip()
+        if dev.startswith("cuda") and not torch.cuda.is_available():
+            raise RuntimeError(f"Device was set to '{dev}', but CUDA is not available.")

-
-        """
-        Initialize learner state and canned prompts.
+        dtype = torch.bfloat16 if dev.startswith("cuda") else torch.float32

-
-
-
-
-
-
-
-
-
-
-
-        # Few-shot exemplars for A1 (Docs→Terms) and for Docs→Types:
-        # Each exemplar is a tuple: (title, text, gold_list)
-        self._fewshot_terms_docs: List[Tuple[str, str, List[str]]] = []
-        self._fewshot_types_docs: List[Tuple[str, str, List[str]]] = []
-
-        # System prompts
-        self._system_prompt_terms = (
-            "You are an expert in ontology term extraction.\n"
-            "Extract only terms that explicitly appear in the document.\n"
-            'Answer strictly as JSON: {"terms": ["..."]}\n'
-        )
-        self._system_prompt_types = (
-            "You are an expert in ontology type classification.\n"
-            "List ontology *types* that characterize the document’s terminology.\n"
-            'Answer strictly as JSON: {"types": ["..."]}\n'
-        )
+        tok_kwargs: Dict[str, Any] = {"local_files_only": self.local_files_only}
+        if self.hf_token:
+            tok_kwargs["token"] = self.hf_token
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_id, **tok_kwargs)
+        except TypeError:
+            tok_kwargs.pop("token", None)
+            if self.hf_token:
+                tok_kwargs["use_auth_token"] = self.hf_token
+            self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_id, **tok_kwargs)

-
-
-        self._json_array_regex = re.compile(r"\[[^\]]*\]", re.S)
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token

-        # Term→Types (Option A) specific prompt
-        self._system_prompt_term_to_types = (
-            "You are an expert in ontology and semantic type classification.\n"
-            "Given a term, predict its semantic types from the domain-specific ontology.\n"
-            'Answer strictly as JSON:\n{"types": ["type1", "type2", "..."]}'
-        )
+        model_kwargs: Dict[str, Any] = {"local_files_only": self.local_files_only}
+        if self.hf_token:
+            model_kwargs["token"] = self.hf_token

-
-        self
-
-        train_docs_jsonl: str,
-        terms2doc_json: str,
-        sample_size: int = 24,
-        seed: int = 42,
-    ) -> None:
-        """
-        Build internal few-shot exemplars from a labeled training split.
-
-
-
-
-
-
-            JSON mapping item -> [doc_id,...]; "item" can be a term or type.
-        sample_size : int
-            Number of exemplar documents to keep for few-shot prompting.
-        seed : int
-            RNG seed for reproducible sampling.
-        """
-        rng = random.Random(seed)
-
-        # Load documents and map doc_id -> row
-        document_map = self._load_documents_jsonl(train_docs_jsonl)
-        if not document_map:
-            raise FileNotFoundError(f"No documents found in: {train_docs_jsonl}")
-
-        # Load item -> [doc_ids]
-        item_to_docs_map = self._load_json(terms2doc_json)
-        if not isinstance(item_to_docs_map, dict):
-            raise ValueError(
-                f"{terms2doc_json} must be a JSON dict mapping item -> [doc_ids]"
+        try:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.llm_model_id,
+                dtype=dtype,
+                **model_kwargs,
             )
-
-
-
-
-
-
-
-        exemplar_candidates: List[Tuple[str, str, List[str]]] = []
-        for doc_id, labeled_items in doc_id_to_items_map.items():
-            doc_row = document_map.get(doc_id)
-            if not doc_row:
-                continue
-            doc_title = str(doc_row.get("title", ""))  # be defensive (may be None)
-            doc_text = self._to_text(
-                doc_row.get("text", "")
-            )  # string-ify list if needed
-            if not doc_text:
-                continue
-            gold_items = self._unique_preserve(
-                [s for s in labeled_items if isinstance(s, str)]
+        except TypeError:
+            model_kwargs.pop("token", None)
+            if self.hf_token:
+                model_kwargs["use_auth_token"] = self.hf_token
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.llm_model_id,
+                torch_dtype=dtype,
+                **model_kwargs,
             )
-            if gold_items:
-                exemplar_candidates.append((doc_title, doc_text, gold_items))

-
-            raise RuntimeError(
-                "No candidate docs with items found to build few-shot exemplars."
-            )
+        self.model = self.model.to(dev)

-
-
-        )
-        # Reuse exemplars for both docs→terms and docs→types prompting
-        self._fewshot_terms_docs = chosen_exemplars
-        self._fewshot_types_docs = chosen_exemplars
+        self.doc_retriever.load(self.retriever_model_id)
+        self.term_retriever.load(self.retriever_model_id)

-
-        self,
-        *,
-        docs_test_jsonl: str,
-        out_jsonl: str,
-        max_new_tokens: int = 128,
-        few_shot_k: int = 6,
-    ) -> int:
-        """
-        Extract terms that explicitly appear in each document.
-
-        Writes one JSON object per line:
-            {"id": "<doc_id>", "terms": ["...", "...", ...]}
-
-
-        ----------
-        docs_test_jsonl : str
-            Path to test/dev documents in JSONL or tolerant JSON/JSONL.
-        out_jsonl : str
-            Output JSONL path where predictions are written (one line per doc).
-        max_new_tokens : int
-            Max generation length.
-        few_shot_k : int
-            Number of few-shot exemplars to prepend per prompt.
-
-        Returns
-        -------
-        int
-            Number of lines written (i.e., number of processed documents).
-        """
-
-        raise RuntimeError("Load a model first: learner.model.load(MODEL_ID, ...)")
-
-        test_documents = self._load_documents_jsonl(docs_test_jsonl)
-        prompts: List[str] = []
-        document_order: List[str] = []
-
-        for document_id, document_row in test_documents.items():
-            title = str(document_row.get("title", ""))
-            text = self._to_text(document_row.get("text", ""))
-
-            fewshot_block = self._format_fewshot_block(
-                self._system_prompt_terms,
-                self._fewshot_terms_docs,
-                key="terms",
-                k=few_shot_k,
-            )
-            user_block = self._format_user_block(title, text)
-
-            prompts.append(f"{fewshot_block}\n{user_block}\nAssistant:")
-            document_order.append(document_id)
-
-        generations = self.model.generate(prompts, max_new_tokens=max_new_tokens)
-        parsed_term_lists = [
-            self._parse_json_list(generated, key="terms") for generated in generations
-        ]
-
-        os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
-        lines_written = 0
-        with open(out_jsonl, "w", encoding="utf-8") as fp_out:
-            for document_id, term_list in zip(document_order, parsed_term_lists):
-                payload = {"id": document_id, "terms": self._unique_preserve(term_list)}
-                fp_out.write(json.dumps(payload, ensure_ascii=False) + "\n")
-                lines_written += 1
-        return lines_written
-
-    def predict_types(
-        self,
-        *,
-        docs_test_jsonl: str,
-        out_jsonl: str,
-        max_new_tokens: int = 128,
-        few_shot_k: int = 6,
-    ) -> int:
-        """
-
-
-
-
-
-        ----------
-        docs_test_jsonl : str
-            Path to test/dev documents in JSONL or tolerant JSON/JSONL.
-        out_jsonl : str
-            Output JSONL path where predictions are written (one line per doc).
-        max_new_tokens : int
-            Max generation length.
-        few_shot_k : int
-            Number of few-shot exemplars to prepend per prompt.
-
-        Returns
-        -------
-        int
-            Number of lines written (i.e., number of processed documents).
-        """
-
-        raise RuntimeError("Load a model first: learner.model.load(MODEL_ID, ...)")
-
-        test_documents = self._load_documents_jsonl(docs_test_jsonl)
-        prompts: List[str] = []
-        document_order: List[str] = []
-
-        for document_id, document_row in test_documents.items():
-            title = str(document_row.get("title", ""))
-            text = self._to_text(document_row.get("text", ""))
-
-            fewshot_block = self._format_fewshot_block(
-                self._system_prompt_types,
-                self._fewshot_types_docs,
-                key="types",
-                k=few_shot_k,
-            )
-            user_block = self._format_user_block(title, text)
-
-            prompts.append(f"{fewshot_block}\n{user_block}\nAssistant:")
-            document_order.append(document_id)
-
-        generations = self.model.generate(prompts, max_new_tokens=max_new_tokens)
-        parsed_type_lists = [
-            self._parse_json_list(generated, key="types") for generated in generations
-        ]
-
-        os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
-        lines_written = 0
-        with open(out_jsonl, "w", encoding="utf-8") as fp_out:
-            for document_id, type_list in zip(document_order, parsed_type_lists):
-                payload = {"id": document_id, "types": self._unique_preserve(type_list)}
-                fp_out.write(json.dumps(payload, ensure_ascii=False) + "\n")
-                lines_written += 1
-        return lines_written
-
-    def evaluate_extraction_f1(
-        self,
-        gold_item2docs_json: str,
-        preds_jsonl: str,
-        *,
-        key: str = "term",
-    ) -> float:
-        """
-
-
-
-
-
-            JSON mapping item -> [doc_ids].
-        preds_jsonl : str
-            JSONL lines like {"id": "...", "terms":[...]} or {"id":"...","types":[...]}.
-        key : str
-            "term" or "type" depending on what you are evaluating.
-
-        Returns
-        -------
-        float
-            Micro-averaged F1 score.
-        """
-        item_to_doc_ids: Dict[str, List[str]] = self._load_json(gold_item2docs_json)
-
-        # Build gold: doc_id -> set(items)
-        gold_doc_to_items: Dict[str, set] = {}
-        for item_label, doc_id_list in item_to_doc_ids.items():
-            for document_id in doc_id_list:
-                gold_doc_to_items.setdefault(document_id, set()).add(
-                    self._norm(item_label)
-                )
-
-        # Build predictions: doc_id -> set(items)
-        pred_doc_to_items: Dict[str, set] = {}
-        with open(preds_jsonl, "r", encoding="utf-8") as fp_in:
-            for line in fp_in:
-                row = json.loads(line.strip())
-                document_id = str(row.get("id", ""))
-                items_list = row.get("terms" if key == "term" else "types", [])
-                pred_doc_to_items[document_id] = {
-                    self._norm(x) for x in items_list if isinstance(x, str)
-                }
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self._loaded = True
+
+    def _format_doc(self, title: str, text: str, tfidf: Optional[List[str]] = None) -> str:
+        """
+        Format doc as the retriever query and as the user prompt content.
+        """
+        s = f"Title: {title}\n\nText:\n{text}"
+        if tfidf:
+            s += f"\n\nTF-IDF based suggestions: {tfidf}"
+        return s
+
+    def _apply_chat_template(self, conversation: List[Dict[str, str]]) -> str:
+        """
+        Convert conversation into a single prompt string using the tokenizer's chat template if available.
+        """
+        assert self.tokenizer is not None
+        if hasattr(self.tokenizer, "apply_chat_template"):
+            return self.tokenizer.apply_chat_template(
+                conversation, add_generation_prompt=True, tokenize=False
+            )
+
+        parts = []
+        for t in conversation:
+            parts.append(f"{t['role'].upper()}:\n{t['content']}\n")
+        parts.append("ASSISTANT:\n")
+        return "\n".join(parts)
+
+    def _extract_first_json_obj(self, text: str) -> Optional[dict]:
+        """
+        Extract the first valid JSON object from generated text by scanning balanced {...}.
+        """
+        starts = [i for i, ch in enumerate(text) if ch == "{"]
+
+        for s in starts:
+            depth = 0
+            for e in range(s, len(text)):
+                if text[e] == "{":
+                    depth += 1
+                elif text[e] == "}":
+                    depth -= 1
+                    if depth == 0:
+                        candidate = text[s : e + 1].strip().replace("\n", " ")
+                        try:
+                            return json.loads(candidate)
+                        except Exception:
+                            try:
+                                candidate2 = re.sub(r"'", '"', candidate)
+                                return json.loads(candidate2)
+                            except Exception:
+                                pass
+                        break
+        return None
+
+    def _dedup_clean(self, items: List[str]) -> List[str]:
+        """
+        Normalize and deduplicate strings (case-insensitive).
+        """
+        out: List[str] = []
+        seen = set()
+        for x in items or []:
+            if not isinstance(x, str):
+                continue
+            x2 = re.sub(r"\s+", " ", x.strip())
+            if not x2:
+                continue
+            k = x2.lower()
+            if k in seen:
+                continue
+            seen.add(k)
+            out.append(x2)
+        return out

-    def
-
-
-        doc_terms_jsonl: Optional[str] = None,  # formerly a1_results_jsonl
-        doc_terms_list: Optional[List[Dict]] = None,  # formerly a1_results_list
-        few_shot_jsonl: Optional[
-            str
-        ] = None,  # JSONL lines: {"term":"...", "types":[...]}
-        rag_terms_json: Optional[
-            str
-        ] = None,  # JSON list; items may contain "term" and "RAG":[...]
-        random_few_shot: Optional[int] = 3,
-        model_id: str = "Qwen/Qwen2.5-1.5B-Instruct",
-        use_structured_output: bool = True,
-        seed: int = 42,
-        out_terms2types: str = "terms2types_pred.json",
-        out_types2docs: str = "types2docs_pred.json",
-    ) -> Dict[str, Any]:
+    def _doc_id(self, d: Dict[str, Any]) -> str:
+        """
+        Extract doc_id from common keys: doc_id, id, docid.
         """
-
+        return str(d.get("doc_id") or d.get("id") or d.get("docid") or "")

-
-        ----------
-        doc_terms_jsonl : Optional[str]
-            Path to JSONL with lines like {"id": "...", "terms": [...]} or a JSON with {"results":[...]}.
-        doc_terms_list : Optional[List[Dict]]
-            In-memory results like [{"id":"...","extracted_terms":[...]}] or {"id":"...","terms":[...]}.
-        few_shot_jsonl : Optional[str]
-            Global few-shot exemplars: one JSON object per line with {"term": "...", "types":[...]}.
-        rag_terms_json : Optional[str]
-            Optional per-term RAG exemplars: a JSON list of {"term": "...", "RAG":[{"term": "...", "types":[...]}]}.
-        random_few_shot : Optional[int]
-            If provided, randomly select up to this many few-shot examples for each prediction.
-        model_id : str
-            HF model id used specifically for term→types predictions.
-        use_structured_output : bool
-            If True and outlines is available, enforce structured {"types":[...]} output.
-        seed : int
-            Random seed for reproducibility.
-        out_terms2types : str
-            Output JSON path for list of {"term": "...", "predicted_types":[...]}.
-        out_types2docs : str
-            Output JSON path for dict {"TYPE":[doc_ids,...], ...}.
-
-        Returns
-        -------
-        Dict[str, Any]
-            Summary with predictions and counts.
+    def _extract_documents(self, data: Any) -> List[Dict[str, Any]]:
         """
-
-
-
+        Accept list-of-docs OR dict with 'documents'/'docs'.
+        """
+        if isinstance(data, list):
+            return data
+        if isinstance(data, dict):
+            if isinstance(data.get("documents"), list):
+                return data["documents"]
+            if isinstance(data.get("docs"), list):
+                return data["docs"]
+        raise ValueError("Expected dict with 'documents' (or 'docs'), or a list of docs.")

-
-
-
-            in_memory_results=doc_terms_list,
-        )
-        if not doc_term_extractions:
-            raise ValueError(
-                "No document→terms results provided (doc_terms_jsonl/doc_terms_list)."
-            )
+    def _normalize_terms2docs(self, raw_terms2docs: Any, docs: List[Dict[str, Any]]) -> Dict[str, List[str]]:
+        """
+        Normalize mapping to: term -> [doc_id, ...].

-
-
-
-
-
-        global_few_shot_examples: List[Dict] = []
-        if few_shot_jsonl and os.path.exists(few_shot_jsonl):
-            with open(few_shot_jsonl, "r", encoding="utf-8") as few_shot_file:
-                for raw_line in few_shot_file:
-                    raw_line = raw_line.strip()
-                    if not raw_line:
-                        continue
-                    try:
-                        json_obj = json.loads(raw_line)
-                    except Exception:
-                        continue
-                    if (
-                        isinstance(json_obj, dict)
-                        and "term" in json_obj
-                        and "types" in json_obj
-                    ):
-                        global_few_shot_examples.append(json_obj)
-
-        # Optional per-term RAG examples: {normalized_term -> [examples]}
-        rag_examples_lookup: Dict[str, List[Dict]] = {}
-        if rag_terms_json and os.path.exists(rag_terms_json):
-            try:
-                rag_payload = self._load_json(rag_terms_json)
-                if isinstance(rag_payload, list):
-                    for rag_item in rag_payload:
-                        if isinstance(rag_item, dict):
-                            normalized_term = self._normalize_term(
-                                rag_item.get("term", "")
-                            )
-                            rag_examples_lookup[normalized_term] = rag_item.get(
-                                "RAG", []
-                            )
-            except Exception:
-                pass
+        If caller accidentally provides inverted mapping: doc_id -> [term, ...],
+        we detect it (keys mostly match doc_ids) and invert it.
+        """
+        if not isinstance(raw_terms2docs, dict) or not raw_terms2docs:
+            return {}

-
-
+        doc_ids = {self._doc_id(d) for d in docs}
+        doc_ids.discard("")

-
-
-        for
-            normalized_term = self._normalize_term(term_text)
+        keys = list(raw_terms2docs.keys())
+        sample = keys[:200]
+        hits = sum(1 for k in sample if str(k) in doc_ids)

-
-
-
-
-
+        if sample and hits >= int(0.6 * len(sample)):
+            term2docs: Dict[str, List[str]] = defaultdict(list)
+            for did, terms in raw_terms2docs.items():
+                did = str(did)
+                if did not in doc_ids:
+                    continue
+                for t in (terms or []):
+                    if isinstance(t, str) and t.strip():
+                        term2docs[t.strip()].append(did)
+            return {t: self._dedup_clean(ds) for t, ds in term2docs.items()}
+
+        norm: Dict[str, List[str]] = {}
+        for term, doc_list in raw_terms2docs.items():
+            if not isinstance(term, str) or not term.strip():
+                continue
+            docs_norm = [str(d) for d in (doc_list or []) if str(d)]
+            if docs_norm:
+                norm[term.strip()] = self._dedup_clean(docs_norm)
+        return norm

-
-
-
-
-
-
-            typing_prompt_string = self._apply_chat_template_safe_types(
-                typing_tokenizer, conversation_messages
-            )
+    def _generate(self, prompt: str) -> str:
+        """
+        Deterministic single-prompt generation (no sampling).
+        Returns decoded completion only.
+        """
+        assert self.model is not None and self.tokenizer is not None

-
-
-
-
-
-
-
-                and _PredictedTypesSchema is not None
-            ):
-                try:
-                    outlines_model = OutlinesTFModel(typing_model, typing_tokenizer)  # type: ignore
-                    generator = outlines_generate_json(
-                        outlines_model, _PredictedTypesSchema
-                    )  # type: ignore
-                    structured = generator(typing_prompt_string, max_tokens=512)
-                    predicted_types = [
-                        label for label in structured.types if isinstance(label, str)
-                    ]
-                    raw_generation_text = json.dumps(
-                        {"types": predicted_types}, ensure_ascii=False
-                    )
-                except Exception:
-                    # Fall back to greedy decoding
-                    use_structured_output = False
-
-            # Greedy decode fallback
-            if (
-                not use_structured_output
-                or not OUTLINES_AVAILABLE
-                or _PredictedTypesSchema is None
-            ):
-                tokenized_prompt = typing_tokenizer(
-                    typing_prompt_string,
-                    return_tensors="pt",
-                    truncation=True,
-                    max_length=2048,
-                )
-                if torch.cuda.is_available():
-                    tokenized_prompt = {
-                        name: tensor.cuda() for name, tensor in tokenized_prompt.items()
-                    }
-                with torch.no_grad():
-                    output_ids = typing_model.generate(
-                        **tokenized_prompt,
-                        max_new_tokens=256,
-                        do_sample=False,
-                        num_beams=1,
-                        pad_token_id=typing_tokenizer.eos_token_id,
-                    )
-                new_token_span = output_ids[0][tokenized_prompt["input_ids"].shape[1] :]
-                raw_generation_text = typing_tokenizer.decode(
-                    new_token_span, skip_special_tokens=True
-                )
-                predicted_types = self._extract_types_from_text(raw_generation_text)
-
-            term_to_predicted_types_list.append(
-                {
-                    "term": term_text,
-                    "predicted_types": sorted(set(predicted_types)),
-                }
-            )
+        enc = self.tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=self.max_input_length,
+        )
+        enc = {k: v.to(self.model.device) for k, v in enc.items()}

-
-
-
-
-
-
-
-                doc_ids_for_term
-            )
-
-        types_to_doc_ids: Dict[str, List[str]] = {
-            type_label: sorted(doc_id_set)
-            for type_label, doc_id_set in types_to_doc_id_set.items()
-        }
-
-        # 8) Save outputs
-        os.makedirs(os.path.dirname(out_terms2types) or ".", exist_ok=True)
-        with open(out_terms2types, "w", encoding="utf-8") as fp_terms2types:
-            json.dump(
-                term_to_predicted_types_list,
-                fp_terms2types,
-                ensure_ascii=False,
-                indent=2,
+        with torch.no_grad():
+            out = self.model.generate(
+                **enc,
+                max_new_tokens=self.max_new_tokens,
+                do_sample=False,
+                num_beams=1,
+                pad_token_id=self.tokenizer.eos_token_id,
             )

-
-
-            json.dump(types_to_doc_ids, fp_types2docs, ensure_ascii=False, indent=2)
+        gen_tokens = out[0][enc["input_ids"].shape[1] :]
+        return self.tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

-
-        del typing_model, typing_tokenizer
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        return {
-            "terms2types_pred": term_to_predicted_types_list,
-            "types2docs_pred": types_to_doc_ids,
-            "unique_terms": len(unique_terms),
-            "types_count": len(types_to_doc_ids),
-        }
-
-    def _load_json(self, path: str) -> Dict[str, Any]:
-        """Load a JSON file from disk and return its parsed object."""
-        with open(path, "r", encoding="utf-8") as file_obj:
-            return json.load(file_obj)
-
-    def _iter_json_objects(self, blob: str) -> Iterable[Dict[str, Any]]:
+    def _retrieve_doc_fewshot(self, doc: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
-
-
-        Supports cases where multiple JSON objects are concatenated back-to-back
-        in a single line. It skips stray commas/whitespace between objects.
-
-        Parameters
-        ----------
-        blob : str
-            A string that may contain one or more JSON objects.
-
-        Yields
-        ------
-        Dict[str, Any]
-            Each parsed JSON object.
+        Retrieve top-k doc examples (JSON dicts) for few-shot doc->terms prompting.
         """
-
-
-
-
-
-            cursor_index += 1
-            if cursor_index >= text_length:
-                break
-            try:
-
-            except
-
-
-
-
-    def _load_documents_jsonl(self, path: str) -> Dict[str, Dict[str, Any]]:
+        q = self._format_doc(doc.get("title", ""), doc.get("text", ""))
+        hits = self.doc_retriever.retrieve([q], top_k=self.top_k)[0]
+
+        out: List[Dict[str, Any]] = []
+        for h in hits:
+            try:
+                out.append(json.loads(h))
+            except Exception:
+                continue
+        return out
+
+    def _retrieve_term_fewshot(self, term: str) -> List[Dict[str, Any]]:
         """
-
-        • True JSONL (one object per line)
-        • Lines with multiple concatenated JSON objects
-        • Whole file as a JSON array
-
-        Returns
-        -------
-        Dict[str, Dict[str, Any]]
-            Mapping doc_id -> full document row.
+        Retrieve top-k term examples (JSON dicts) for few-shot term->types prompting.
         """
-
-
-
-
-
-        # Case A: whole-file JSON array
-        if content.startswith("["):
-            try:
-
-                if isinstance(json_array, list):
-                    for record in json_array:
-                        if not isinstance(record, dict):
-                            continue
-                        document_id = str(
-                            record.get("id")
-                            or record.get("doc_id")
-                            or (record.get("doc") or {}).get("id")
-                            or ""
-                        )
-                        if document_id:
-                            documents_by_id[document_id] = record
-                return documents_by_id
-            except Exception:
-                # Fall back to line-wise handling if array parsing fails
-                pass
-
-        # Case B: treat as JSONL-ish; parse *all* objects per line
-        for raw_line in content.splitlines():
-            line = raw_line.strip()
-            if not line:
-                continue
-
-            if not isinstance(record, dict):
-                continue
-            document_id = str(
-                record.get("id")
-                or record.get("doc_id")
-                or (record.get("doc") or {}).get("id")
-                or ""
-            )
-            if document_id:
-                documents_by_id[document_id] = record
-
-        return documents_by_id
-
-    def _to_text(self, text_field: Any) -> str:
-        """
-        Convert a 'text' field into a single string (handles list-of-strings).
-
-        Parameters
-        ----------
-        text_field : Any
-            The value found under "text" in the dataset row.
-
-
-        -------
-        str
-            A single-string representation of the text.
-        """
-
-            return text_field
-        if isinstance(text_field, list):
-            return " ".join(str(part) for part in text_field)
-        return str(text_field) if text_field is not None else ""
-
-    def _unique_preserve(self, values: List[str]) -> List[str]:
-        """
-
-
-
-        ----------
-        values : List[str]
-            Sequence possibly containing duplicates.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        """
-        return " ".join(text.lower().split())
+        hits = self.term_retriever.retrieve([term], top_k=self.top_k)[0]
+
+        out: List[Dict[str, Any]] = []
+        for h in hits:
+            try:
+                out.append(json.loads(h))
+            except Exception:
+                continue
+        return out
+
+    def _doc_to_terms(self, doc: Dict[str, Any]) -> List[str]:
+        """
+        Predict terms for a document using few-shot prompting + doc retrieval.
+        """
+        fewshot = self._retrieve_doc_fewshot(doc)
+
+        convo: List[Dict[str, str]] = [{"role": "system", "content": self.DOC2TERMS_SYSTEM_PROMPT}]
+
+        for ex in fewshot:
+            ex_tfidf = ex.get("TF-IDF") or ex.get("tfidf_terms") or []
+            convo += [
+                {
+                    "role": "user",
+                    "content": self._format_doc(
+                        ex.get("title", ""),
+                        ex.get("text", ""),
+                        ex_tfidf if self.use_tfidf else None,
+                    ),
+                },
+                {
+                    "role": "assistant",
+                    "content": json.dumps({"terms": ex.get("OL", [])}, ensure_ascii=False),
+                },
+            ]
+
+        tfidf = doc.get("TF-IDF") or doc.get("tfidf_terms") or []
+        convo.append(
+            {
+                "role": "user",
+                "content": self._format_doc(
+                    doc.get("title", ""),
+                    doc.get("text", ""),
+                    tfidf if self.use_tfidf else None,
+                ),
+            }
+        )
+
+        prompt = self._apply_chat_template(convo)
+        gen = self._generate(prompt)
+        parsed = self._extract_first_json_obj(gen) or {}
+        return self._dedup_clean(parsed.get("terms", []))

-    def
+    def _term_to_types(self, term: str) -> List[str]:
+        """
+        Predict types for a term using few-shot prompting + term retrieval.
         """
-
+        fewshot = self._retrieve_term_fewshot(term)

-
-
-
-
-
-
-        -------
-        str
-            Lowercased, trimmed and single-spaced term.
-        """
-        return " ".join(str(term).strip().split()).lower()
+        system = self.TERM2TYPES_SYSTEM_PROMPT
+        if self.restrict_to_known_types and self._allowed_types:
+            allowed_block = "\n".join(f"- {t}" for t in self._allowed_types)
+            system = (
+                system
+                + "\n\nIMPORTANT CONSTRAINT:\n"
+                + "Choose ONLY from the following valid ontology types (do not invent new labels):\n"
+                + allowed_block
+            )

-
-
-
-
-
-
-
-
-        """
-        Render a few-shot block like:
-
-
-
-
-
-
-        Assistant:
-        {"terms": [...]} or {"types": [...]}
-
-
-        ----------
-        system_prompt : str
-            Instructional system text to prepend.
-        fewshot_examples : List[Tuple[str, str, List[str]]]
-            Examples as (title, text, labels_list).
-        key : str
-            Either "terms" or "types" depending on the task.
-        k : int
-            Number of examples to include.
-
-        Returns
-        -------
-        str
-            Formatted few-shot block text.
-        """
-
-        for example_title, example_text, gold_list in fewshot_examples[:k]:
-            lines.append("### Example")
-            lines.append(f"User:\nTitle: {example_title}\n{example_text}")
-            lines.append(
-                f'Assistant:\n{{"{key}": '
-                + json.dumps(gold_list, ensure_ascii=False)
-                + "}"
-            )
-        return "\n".join(lines)
+        convo: List[Dict[str, str]] = [{"role": "system", "content": system}]
+
+        for ex in fewshot:
+            convo += [
+                {"role": "user", "content": f"Term: {ex.get('term','')}"},
+                {
+                    "role": "assistant",
+                    "content": json.dumps({"types": ex.get("types", [])}, ensure_ascii=False),
+                },
+            ]
+
+        convo.append({"role": "user", "content": f"Term: {term}"})
+
+        prompt = self._apply_chat_template(convo)
+        gen = self._generate(prompt)
+        parsed = self._extract_first_json_obj(gen) or {}
+        return self._dedup_clean(parsed.get("types", []))

-
-        """
-
-
-
-
-        title : str
-            Document title.
-        text : str
-            Document text (single string).
-
-        Returns
-        -------
-        str
-            Formatted user block.
-        """
-        return f"### Task\nUser:\nTitle: {title}\n{text}"
-
-
-        """
-        Extract a list from model output, trying:
-        1) JSON object with the key ({"terms":[...]} or {"types":[...]}).
-        2) Any top-level JSON array.
-        3) Fallback: comma-split.
-
-
-
-        generated_text : str
-            Raw generation text to parse.
-        key : str
-            "terms" or "types".
-
-        Returns
-        -------
-        List[str]
-            Parsed strings (best-effort).
-        """
-        # 1) Try a JSON object and read key
-        try:
-            object_match = self._json_object_regex.search(generated_text)
-            if object_match:
-                json_obj = json.loads(object_match.group(0))
-                json_array = json_obj.get(key)
-                if isinstance(json_array, list):
-                    return [value for value in json_array if isinstance(value, str)]
-        except Exception:
-            pass
-
-        # 2) Any JSON array
-        try:
-            array_match = self._json_array_regex.search(generated_text)
-            if array_match:
-                json_array = json.loads(array_match.group(0))
-                if isinstance(json_array, list):
-                    return [value for value in json_array if isinstance(value, str)]
-        except Exception:
-            pass
-
-        # 3) Fallback: comma-split (last resort)
-        if "," in generated_text:
-            return [
-                part.strip().strip('"').strip("'")
-                for part in generated_text.split(",")
-                if part.strip()
-            ]
-        return []
-
-
-        self, tokenizer: AutoTokenizer, messages: List[Dict[str, str]]
-    ) -> str:
-        """
-        Safely build a prompt string for chat models. Uses the model's chat template
-        when available; otherwise falls back to a simple concatenation.
-        """
-        try:
-            return tokenizer.apply_chat_template(
-                messages, add_generation_prompt=True, tokenize=False
-            )
-        except Exception:
-            system_text = next(
-                (m["content"] for m in messages if m.get("role") == "system"), ""
-            )
-            last_user_text = next(
-                (m["content"] for m in reversed(messages) if m.get("role") == "user"),
-                "",
-            )
-            return f"{system_text}\n\nUser:\n{last_user_text}\n\nAssistant:"
-
-
-
-        term: str,
-        few_shot_examples: Optional[List[Dict]] = None,
-        random_k: Optional[int] = None,
-    ) -> List[Dict[str, str]]:
-        """
-        Create a chat-style conversation for a single term→types query,
-        optionally prepending few-shot examples.
-        """
-        messages: List[Dict[str, str]] = [
-            {"role": "system", "content": self._system_prompt_term_to_types}
-        ]
-        examples = list(few_shot_examples or [])
-        if random_k and len(examples) > random_k:
-            import random as _rnd
-
-            examples = _rnd.sample(examples, random_k)
-        for exemplar in examples:
-            example_term = exemplar.get("term", "")
-            example_types = exemplar.get("types", [])
-            messages.append({"role": "user", "content": f"Term: {example_term}"})
-            messages.append(
-                {
-
-
-                }
-            )
-        messages.append({"role": "user", "content": f"Term: {term}"})
-        return messages
+    def _text2onto(self, data: Any, test: bool = False) -> Optional[Any]:
+        """
+        Train or predict for task="text2onto".
+
+        Returns:
+          - training: None
+          - prediction: {"terms": [...], "types": [...]}
+        """
+        if not self._loaded:
+            self.load(model_id=self.llm_model_id, device=self.device)
+
+        if not isinstance(data, dict):
+            raise ValueError("text2onto expects a dict with documents + mappings.")
+
+        docs = self._extract_documents(data)
+
+        raw_terms2docs = data.get("terms2docs") or data.get("term2docs") or {}
+        terms2types = data.get("terms2types") or data.get("term2types") or {}
+
+        terms2docs = self._normalize_terms2docs(raw_terms2docs, docs)
+
+        if not test:
+            self._allowed_types = sorted(
+                {
+                    ty.strip()
+                    for tys in (terms2types or {}).values()
+                    for ty in (tys or [])
+                    if isinstance(ty, str) and ty.strip()
+                }
+            )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
-        """
-        Load a *separate* small chat model for Term→Types (keeps LocalAutoLLM untouched).
-        """
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-            device_map="auto" if torch.cuda.is_available() else None,
-        )
-        return model, tokenizer
+            # build doc->terms from term->docs
+            doc2terms: Dict[str, List[str]] = defaultdict(list)
+            for term, doc_ids in (terms2docs or {}).items():
+                for did in (doc_ids or []):
+                    doc2terms[str(did)].append(term)
+
+            # doc few-shot corpus: doc + gold OL terms
+            doc_examples: List[Dict[str, Any]] = []
+            for d in docs:
+                did = self._doc_id(d)
+                ex = dict(d)
+                ex["doc_id"] = did
+                ex["OL"] = self._dedup_clean(doc2terms.get(did, []))
+                doc_examples.append(ex)
+
+            # term few-shot corpus: term + gold types
+            term_examples = [
+                {"term": t, "types": self._dedup_clean(tys)}
+                for t, tys in (terms2types or {}).items()
+            ]

-
-
-
-        results_json_path: Optional[str] = None,
-        in_memory_results: Optional[List[Dict]] = None,
-    ) -> List[Dict]:
-        """
-        Normalize document→terms outputs to a list of:
-            {"id": "<doc_id>", "extracted_terms": ["...", ...]}
-
-        Accepts either:
-        - in_memory_results (list of dicts)
-        - results_json_path pointing to:
-            • a JSONL file with lines: {"id": "...", "terms": [...]}
-            • OR a JSON file with {"results":[{"id":..., "extracted_terms": [...]}, ...]}
-            • OR a JSON list of dicts
-        """
-        normalized_records: List[Dict] = []
+            # store as JSON strings so retrievers return parseable strings
+            self._doc_examples_json = [json.dumps(ex, ensure_ascii=False) for ex in doc_examples]
+            self._term_examples_json = [json.dumps(ex, ensure_ascii=False) for ex in term_examples]

-
-
-
-
-            if not document_id:
-                return None
-            terms = source_row.get("extracted_terms")
-            if terms is None:
-                terms = source_row.get("terms")
-            if (
-                terms is None
-                and "payload" in source_row
-                and isinstance(source_row["payload"], dict)
-            ):
-                terms = source_row["payload"].get("terms")
-            if not isinstance(terms, list):
-                terms = []
-            return {
-                "id": document_id,
-                "extracted_terms": [t for t in terms if isinstance(t, str)],
-            }
+            # index retrievers
+            self.doc_retriever.index(self._doc_examples_json)
+            self.term_retriever.index(self._term_examples_json)
+            return None

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if isinstance(json_obj, dict):
-            coerced_record = _coerce_to_record(json_obj)
-            if coerced_record:
-                normalized_records.append(coerced_record)
-        else:
-            payload_obj = self._load_json(results_json_path)
-            if isinstance(payload_obj, dict) and "results" in payload_obj:
-                for source_row in payload_obj["results"]:
-                    coerced_record = _coerce_to_record(source_row)
-                    if coerced_record:
-                        normalized_records.append(coerced_record)
-            elif isinstance(payload_obj, list):
-                for source_row in payload_obj:
-                    if isinstance(source_row, dict):
-                        coerced_record = _coerce_to_record(source_row)
-                        if coerced_record:
-                            normalized_records.append(coerced_record)
-
-        return normalized_records
-
-    def _collect_unique_terms_from_extractions(
-        self, doc_term_extractions: List[Dict]
-    ) -> List[str]:
-        """
-        Collect unique terms (original casing) from normalized document→terms results.
-        """
-        seen_normalized_terms: set = set()
-        ordered_unique_terms: List[str] = []
-        for record in doc_term_extractions:
-            for term_text in record.get("extracted_terms", []):
-                normalized = self._normalize_term(term_text)
-                if normalized and normalized not in seen_normalized_terms:
-                    seen_normalized_terms.add(normalized)
-                    ordered_unique_terms.append(term_text.strip())
-        return ordered_unique_terms
-
-    def _build_term_to_doc_ids(
-        self, doc_term_extractions: List[Dict]
-    ) -> Dict[str, List[str]]:
-        """
-        Build lookup: normalized_term -> sorted unique list of doc_ids.
-        """
-        term_to_doc_set: Dict[str, set] = {}
-        for record in doc_term_extractions:
-            document_id = str(record.get("id", ""))
-            for term_text in record.get("extracted_terms", []):
-                normalized = self._normalize_term(term_text)
-                if not normalized or not document_id:
-                    continue
-                term_to_doc_set.setdefault(normalized, set()).add(document_id)
-        return {
-            normalized_term: sorted(doc_ids)
-            for normalized_term, doc_ids in term_to_doc_set.items()
-        }
+        doc2terms_pred: Dict[str, List[str]] = {}
+        for d in docs:
+            did = self._doc_id(d)
+            doc2terms_pred[did] = self._doc_to_terms(d)
+
+        unique_terms = sorted({t for ts in doc2terms_pred.values() for t in ts})
+        term2types_pred: Dict[str, List[str]] = {t: self._term_to_types(t) for t in unique_terms}
+
+        doc2types_pred: Dict[str, List[str]] = {}
+        for did, terms in doc2terms_pred.items():
+            tys: List[str] = []
+            for t in terms:
+                tys.extend(term2types_pred.get(t, []))
+            doc2types_pred[did] = self._dedup_clean(tys)
+
+        pred_terms = [{"doc_id": did, "term": t} for did, ts in doc2terms_pred.items() for t in ts]
+        pred_types = [{"doc_id": did, "type": ty} for did, tys in doc2types_pred.items() for ty in tys]
+
+        return {"terms": pred_terms, "types": pred_types}
```