OntoLearner 1.4.10-py3-none-any.whl → 1.4.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ontolearner/VERSION +1 -1
- ontolearner/base/learner.py +38 -17
- ontolearner/evaluation/metrics.py +72 -32
- ontolearner/learner/__init__.py +1 -1
- ontolearner/learner/label_mapper.py +1 -1
- ontolearner/learner/prompt.py +40 -5
- ontolearner/learner/rag/__init__.py +14 -0
- ontolearner/learner/{rag.py → rag/rag.py} +7 -2
- ontolearner/learner/retriever/__init__.py +1 -1
- ontolearner/learner/retriever/{llm_retriever.py → augmented_retriever.py} +48 -39
- ontolearner/learner/retriever/learner.py +3 -4
- ontolearner/learner/text2onto/__init__.py +1 -1
- ontolearner/learner/text2onto/alexbek.py +484 -1105
- ontolearner/learner/text2onto/sbunlp.py +498 -493
- ontolearner/text2onto/splitter.py +69 -6
- {ontolearner-1.4.10.dist-info → ontolearner-1.4.11.dist-info}/METADATA +2 -2
- {ontolearner-1.4.10.dist-info → ontolearner-1.4.11.dist-info}/RECORD +19 -18
- {ontolearner-1.4.10.dist-info → ontolearner-1.4.11.dist-info}/WHEEL +0 -0
- {ontolearner-1.4.10.dist-info → ontolearner-1.4.11.dist-info}/licenses/LICENSE +0 -0
ontolearner/VERSION
CHANGED
@@ -1 +1 @@
-1.4.10
+1.4.11
ontolearner/base/learner.py
CHANGED
@@ -18,6 +18,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import torch.nn.functional as F
 from sentence_transformers import SentenceTransformer
+from collections import defaultdict
 
 class AutoLearner(ABC):
     """
@@ -70,6 +71,7 @@ class AutoLearner(ABC):
                - "term-typing": Predict semantic types for terms
                - "taxonomy-discovery": Identify hierarchical relationships
                - "non-taxonomy-discovery": Identify non-hierarchical relationships
+               - "text2onto" : Extract ontology terms and their semantic types from documents
 
         Raises:
             NotImplementedError: If not implemented by concrete class.
@@ -81,6 +83,8 @@
             self._taxonomy_discovery(train_data, test=False)
         elif task == 'non-taxonomic-re':
             self._non_taxonomic_re(train_data, test=False)
+        elif task == 'text2onto':
+            self._text2onto(train_data, test=False)
         else:
             raise ValueError(f"{task} is not a valid task.")
 
@@ -103,6 +107,7 @@
                - term-typing: List of predicted types for each term
                - taxonomy-discovery: Boolean predictions for relationships
                - non-taxonomy-discovery: Predicted relation types
+               - text2onto : Extract ontology terms and their semantic types from documents
 
         Raises:
             NotImplementedError: If not implemented by concrete class.
@@ -115,6 +120,8 @@
             return self._taxonomy_discovery(eval_data, test=True)
         elif task == 'non-taxonomic-re':
             return self._non_taxonomic_re(eval_data, test=True)
+        elif task == 'text2onto':
+            return self._text2onto(eval_data, test=True)
         else:
             raise ValueError(f"{task} is not a valid task.")
 
@@ -147,6 +154,9 @@
     def _non_taxonomic_re(self, data: Any, test: bool = False) -> Optional[Any]:
         pass
 
+    def _text2onto(self, data: Any, test: bool = False) -> Optional[Any]:
+        pass
+
     def tasks_data_former(self, data: Any, task: str, test: bool = False) -> List[str | Dict[str, str]]:
         formatted_data = []
         if task == "term-typing":
@@ -171,6 +181,7 @@
             non_taxonomic_types = list(set(non_taxonomic_types))
             non_taxonomic_res = list(set(non_taxonomic_res))
             formatted_data = {"types": non_taxonomic_types, "relations": non_taxonomic_res}
+
         return formatted_data
 
     def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, str]]:
@@ -186,6 +197,26 @@
                 formatted_data.append({"head": non_taxonomic_triplets.head,
                                        "tail": non_taxonomic_triplets.tail,
                                        "relation": non_taxonomic_triplets.relation})
+        if task == "text2onto":
+            terms2docs = data.get("terms2docs", {}) or {}
+            terms2types = data.get("terms2types", {}) or {}
+
+            # gold doc→terms
+            gold_terms = []
+            for term, doc_ids in terms2docs.items():
+                for doc_id in doc_ids or []:
+                    gold_terms.append({"doc_id": doc_id, "term": term})
+
+            # gold doc→types derived via doc→terms + term→types
+            doc2types = defaultdict(set)
+            for term, doc_ids in terms2docs.items():
+                for doc_id in doc_ids or []:
+                    for ty in (terms2types.get(term, []) or []):
+                        if isinstance(ty, str) and ty.strip():
+                            doc2types[doc_id].add(ty.strip())
+            gold_types = [{"doc_id": doc_id, "type": ty} for doc_id, tys in doc2types.items() for ty in tys]
+            return {"terms": gold_terms, "types": gold_types}
+
         return formatted_data
 
 class AutoLLM(ABC):
@@ -201,7 +232,7 @@ class AutoLLM(ABC):
         tokenizer: The tokenizer associated with the model.
     """
 
-    def __init__(self, label_mapper: Any, device: str='cpu', token: str="") -> None:
+    def __init__(self, label_mapper: Any, device: str='cpu', token: str="", max_length: int = 256) -> None:
         """
         Initialize the LLM component.
 
@@ -213,6 +244,7 @@
         self.device=device
         self.model: Optional[Any] = None
         self.tokenizer: Optional[Any] = None
+        self.max_length = max_length
 
 
     def load(self, model_id: str) -> None:
@@ -236,10 +268,8 @@
         self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=self.token)
         self.tokenizer.pad_token = self.tokenizer.eos_token
         if self.device == "cpu":
-            # device_map = "cpu"
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
-                # device_map=device_map,
                 torch_dtype=torch.bfloat16,
                 token=self.token
             )
@@ -248,8 +278,8 @@
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 device_map=device_map,
-
-
+                token=self.token,
+                trust_remote_code=True,
             )
         self.label_mapper.fit()
 
@@ -276,29 +306,20 @@
             List of generated text responses, one for each input prompt.
             Responses include the original input plus generated continuation.
         """
-        # Tokenize inputs and move to device
         encoded_inputs = self.tokenizer(inputs,
                                         return_tensors="pt",
-
+                                        max_length=self.max_length,
                                         truncation=True).to(self.model.device)
         input_ids = encoded_inputs["input_ids"]
         input_length = input_ids.shape[1]
-
-        # Generate output
         outputs = self.model.generate(
             **encoded_inputs,
             max_new_tokens=max_new_tokens,
-            pad_token_id=self.tokenizer.eos_token_id
+            pad_token_id=self.tokenizer.eos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id
        )
-
-        # Extract only the newly generated tokens (excluding prompt)
         generated_tokens = outputs[:, input_length:]
-
-        # Decode only the generated part
         decoded_outputs = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_tokens]
-        print(decoded_outputs)
-        print(self.label_mapper.predict(decoded_outputs))
-        # Map the decoded text to labels
         return self.label_mapper.predict(decoded_outputs)
 
 class AutoRetriever(ABC):
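The new text2onto branch of tasks_ground_truth_former (shown above) joins the dataset's terms2docs and terms2types maps into gold doc-level term and type pairs. A minimal standalone sketch of that join, using hypothetical toy data rather than an OntoLearner dataset object:

```python
from collections import defaultdict

# Hypothetical toy inputs mirroring the structures the text2onto branch expects.
terms2docs = {"gene": ["doc1", "doc2"], "protein": ["doc2"]}
terms2types = {"gene": ["BiologicalEntity"], "protein": ["BiologicalEntity", "Molecule"]}

# gold doc -> term pairs
gold_terms = [{"doc_id": d, "term": t} for t, docs in terms2docs.items() for d in docs or []]

# gold doc -> type pairs, derived by joining doc -> terms with term -> types
doc2types = defaultdict(set)
for term, doc_ids in terms2docs.items():
    for doc_id in doc_ids or []:
        for ty in terms2types.get(term, []) or []:
            if isinstance(ty, str) and ty.strip():
                doc2types[doc_id].add(ty.strip())

gold_types = [{"doc_id": d, "type": ty} for d, tys in doc2types.items() for ty in tys]
print(gold_terms)
print(gold_types)
```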
ontolearner/evaluation/metrics.py
CHANGED
@@ -11,44 +11,84 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Dict, Tuple, Set
+from typing import List, Dict, Tuple, Set, Any, Union
 
 SYMMETRIC_RELATIONS = {"equivalentclass", "sameas", "disjointwith"}
 
-def text2onto_metrics(
-
-
-
-
+def text2onto_metrics(
+    y_true: Dict[str, Any],
+    y_pred: Dict[str, Any],
+    similarity_threshold: float = 0.8
+) -> Dict[str, Any]:
+    """
+    Expects:
+        y_true = {"terms": [{"doc_id": str, "term": str}, ...],
+                  "types": [{"doc_id": str, "type": str}, ...]}
+        y_pred = same shape
+
+    Returns:
+        {"terms": {...}, "types": {...}}
+    """
+
+    def jaccard_similarity(text_a: str, text_b: str) -> float:
+        tokens_a = set(text_a.lower().split())
+        tokens_b = set(text_b.lower().split())
+        if not tokens_a and not tokens_b:
             return 1.0
-            return len(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)
+
+    def pairs_to_strings(rows: List[Dict[str, str]], value_key: str) -> List[str]:
+        paired_strings: List[str] = []
+        for row in rows or []:
+            doc_id = (row.get("doc_id") or "").strip()
+            value = (row.get(value_key) or "").strip()
+            if doc_id and value:
+                # keep doc association + allow token Jaccard
+                paired_strings.append(f"{doc_id} {value}")
+        return paired_strings
+
+    def score_list(ground_truth_items: List[str], predicted_items: List[str]) -> Dict[str, Union[float, int]]:
+        matched_ground_truth_indices: Set[int] = set()
+        matched_predicted_indices: Set[int] = set()
+
+        for predicted_index, predicted_item in enumerate(predicted_items):
+            for ground_truth_index, ground_truth_item in enumerate(ground_truth_items):
+                if ground_truth_index in matched_ground_truth_indices:
+                    continue
+
+                if jaccard_similarity(predicted_item, ground_truth_item) >= similarity_threshold:
+                    matched_predicted_indices.add(predicted_index)
+                    matched_ground_truth_indices.add(ground_truth_index)
+                    break
+
+        total_correct = len(matched_predicted_indices)
+        total_predicted = len(predicted_items)
+        total_ground_truth = len(ground_truth_items)
+
+        precision = total_correct / total_predicted if total_predicted else 0.0
+        recall = total_correct / total_ground_truth if total_ground_truth else 0.0
+        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
+
+        return {
+            "f1_score": f1,
+            "precision": precision,
+            "recall": recall,
+            "total_correct": total_correct,
+            "total_predicted": total_predicted,
+            "total_ground_truth": total_ground_truth,
+        }
+
+    ground_truth_terms = pairs_to_strings(y_true.get("terms", []), "term")
+    predicted_terms = pairs_to_strings(y_pred.get("terms", []), "term")
+    ground_truth_types = pairs_to_strings(y_true.get("types", []), "type")
+    predicted_types = pairs_to_strings(y_pred.get("types", []), "type")
+
+    terms_metrics = score_list(ground_truth_terms, predicted_terms)
+    types_metrics = score_list(ground_truth_types, predicted_types)
 
-    precision = total_correct / total_predicted if total_predicted > 0 else 0
-    recall = total_correct / total_ground_truth if total_ground_truth > 0 else 0
-    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
     return {
-        "
-        "
-        "recall": recall,
-        "total_correct": total_correct,
-        "total_predicted": total_predicted,
-        "total_ground_truth": total_ground_truth
+        "terms": terms_metrics,
+        "types": types_metrics,
    }
 
 def term_typing_metrics(y_true: List[Dict[str, List[str]]], y_pred: List[Dict[str, List[str]]]) -> Dict[str, float | int]:
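The rewritten text2onto_metrics greedily matches each predicted "doc_id value" string to at most one gold string using token-level Jaccard similarity, then reports precision, recall, and F1 separately for terms and types. A self-contained sketch of the scoring step (a simplified reimplementation for illustration, not the packaged function):

```python
from typing import Dict, List

def jaccard(a: str, b: str) -> float:
    ta, tb = set(a.lower().split()), set(b.lower().split())
    if not ta and not tb:
        return 1.0
    return len(ta & tb) / len(ta | tb)

def score(gold: List[str], pred: List[str], threshold: float = 0.8) -> Dict[str, float]:
    matched_gold, matched_pred = set(), set()
    for pi, p in enumerate(pred):
        for gi, g in enumerate(gold):
            if gi in matched_gold:
                continue
            if jaccard(p, g) >= threshold:
                matched_pred.add(pi)
                matched_gold.add(gi)
                break
    correct = len(matched_pred)
    precision = correct / len(pred) if pred else 0.0
    recall = correct / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"f1_score": f1, "precision": precision, "recall": recall}

# "doc_id term" strings keep the document association during matching.
print(score(["doc1 gene expression", "doc2 protein"], ["doc1 gene expression", "doc2 enzyme"]))
```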
ontolearner/learner/__init__.py
CHANGED
@@ -14,6 +14,6 @@
 
 from .llm import AutoLLMLearner, FalconLLM, MistralLLM
 from .retriever import AutoRetrieverLearner, LLMAugmentedRetrieverLearner
-from .rag import AutoRAGLearner
+from .rag import AutoRAGLearner, LLMAugmentedRAGLearner
 from .prompt import StandardizedPrompting
 from .label_mapper import LabelMapper

ontolearner/learner/label_mapper.py
CHANGED
@@ -31,7 +31,7 @@ class LabelMapper:
                  ngram_range: Tuple=(1, 1),
                  label_dict: Dict[str, List[str]]=None,
                  analyzer: str = 'word',
-                 iterator_no: int =
+                 iterator_no: int = 1000):
         """
         Initializes the TFIDFLabelMapper with a specified classifier and TF-IDF configuration.
 
ontolearner/learner/prompt.py
CHANGED
@@ -17,15 +17,50 @@ from ..base import AutoPrompt
 class StandardizedPrompting(AutoPrompt):
     def __init__(self, task: str = None):
         if task == "term-typing":
-            prompt_template = """
+            prompt_template = """You are performing term typing.
+
+Determine whether the given term is a clear and unambiguous instance of the specified high-level type.
+
+Rules:
+- Answer "yes" only if the term commonly and directly belongs to the type.
+- Answer "no" if the term does not belong to the type, is ambiguous, or only weakly related.
+- Use the most common meaning of the term.
+- Do not explain your answer.
+
 Term: {term}
 Type: {type}
-Answer:
+Answer (yes or no):"""
         elif task == "taxonomy-discovery":
-            prompt_template =
-
+            prompt_template = """You are identifying taxonomic (is-a) relationships.
+
+Question:
+Is "{parent}" a superclass (direct or indirect) of "{child}" in a standard conceptual or ontological hierarchy?
+
+Rules:
+- A superclass means: "{child}" is a type or instance of "{parent}".
+- Answer "yes" only if the relationship is a true is-a relationship.
+- Answer "no" for part-of, related-to, or associative relationships.
+- Use general world knowledge.
+- Do not explain.
+
+Parent: {parent}
+Child: {child}
+Answer (yes or no):"""
         elif task == "non-taxonomic-re":
-            prompt_template = """
+            prompt_template = """You are identifying non-taxonomic conceptual relationships.
+
+Given two conceptual types, determine whether the specified relation typically holds between them.
+
+Rules:
+- Answer "yes" only if the relation commonly and meaningfully applies.
+- Answer "no" if the relation is rare, indirect, or context-dependent.
+- Do not infer relations that require specific situations.
+- Do not explain.
+
+Head type: {head}
+Tail type: {tail}
+Relation: {relation}
+Answer (yes or no):"""
         else:
             raise ValueError("Unknown task! Current tasks are: 'term-typing', 'taxonomy-discovery', 'non-taxonomic-re'")
         super().__init__(prompt_template)
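Each rewritten template ends with an explicit "Answer (yes or no):" cue so the downstream label mapper only has to classify a short yes/no continuation. A quick illustration of filling the term-typing placeholders with str.format (the template below is abridged from the diff above; how AutoPrompt renders templates internally is not shown in this diff):

```python
term_typing_template = """You are performing term typing.

Determine whether the given term is a clear and unambiguous instance of the specified high-level type.

Term: {term}
Type: {type}
Answer (yes or no):"""

# Hypothetical inputs; any term/type pair from a term-typing dataset would work.
print(term_typing_template.format(term="aspirin", type="drug"))
```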
ontolearner/learner/rag/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright (c) 2025 SciKnowOrg
+#
+# Licensed under the MIT License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/MIT
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .rag import AutoRAGLearner, LLMAugmentedRAGLearner
ontolearner/learner/{rag.py → rag/rag.py}
CHANGED
@@ -14,8 +14,7 @@
 
 import warnings
 from typing import Any
-from
-
+from ...base import AutoLearner
 
 class AutoRAGLearner(AutoLearner):
     def __init__(self,
@@ -87,3 +86,9 @@ class AutoRAGLearner(AutoLearner):
             return self.llm._non_taxonomic_re_predict(dataset=dataset)
         else:
             warnings.warn("No requirement for fiting the non-taxonomic-re model, the predict module will use the input data to do the fit as well.")
+
+
+class LLMAugmentedRAGLearner(AutoRAGLearner):
+
+    def set_augmenter(self, augmenter):
+        self.retriever.set_augmenter(augmenter=augmenter)
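LLMAugmentedRAGLearner only adds a set_augmenter hook that forwards the augmenter to the learner's underlying retriever. A minimal standalone sketch of that delegation pattern (toy classes, not the OntoLearner constructors, whose signatures are not shown in this diff):

```python
class ToyRetriever:
    def __init__(self):
        self.augmenter = None

    def set_augmenter(self, augmenter):
        self.augmenter = augmenter


class ToyRAGLearner:
    def __init__(self, retriever):
        self.retriever = retriever


class ToyAugmentedRAGLearner(ToyRAGLearner):
    # Mirrors LLMAugmentedRAGLearner: the learner exposes set_augmenter,
    # but the augmenter actually lives on the retriever.
    def set_augmenter(self, augmenter):
        self.retriever.set_augmenter(augmenter=augmenter)


learner = ToyAugmentedRAGLearner(ToyRetriever())
learner.set_augmenter(augmenter="llm-augmenter-placeholder")
print(learner.retriever.augmenter)  # -> "llm-augmenter-placeholder"
```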
ontolearner/learner/retriever/__init__.py
CHANGED
@@ -16,4 +16,4 @@ from .crossencoder import CrossEncoderRetriever
 from .embedding import GloveRetriever, Word2VecRetriever
 from .ngram import NgramRetriever
 from .learner import AutoRetrieverLearner, LLMAugmentedRetrieverLearner
-from .
+from .augmented_retriever import LLMAugmenterGenerator, LLMAugmenter, LLMAugmentedRetriever
ontolearner/learner/retriever/{llm_retriever.py → augmented_retriever.py}
CHANGED
@@ -17,6 +17,8 @@ from typing import Any, List, Dict
 from openai import OpenAI
 import time
 from tqdm import tqdm
+import torch
+import torch.nn.functional as F
 
 from ...base import AutoRetriever
 from ...utils import load_json
@@ -125,7 +127,6 @@ class LLMAugmenterGenerator(ABC):
             except Exception:
                 print("sleep for 5 seconds")
                 time.sleep(5)
-
         return inference
 
     def tasks_data_former(self, data: Any, task: str) -> List[str] | Dict[str, List[str]]:
@@ -298,21 +299,12 @@ class LLMAugmentedRetriever(AutoRetriever):
     Attributes:
         augmenter: An augmenter instance that provides transform() and top_n_candidate.
     """
-
-    def __init__(self) -> None:
-        """
-        Initialize the augmented retriever with no augmenter attached.
-        """
+    def __init__(self, threshold: float = 0.0, cutoff_rate: float = 100.0) -> None:
         super().__init__()
-        self.
+        self.threshold = threshold
+        self.cutoff_rate = cutoff_rate
 
     def set_augmenter(self, augmenter):
-        """
-        Attach an augmenter instance.
-
-        Args:
-            augmenter: An object providing `transform(query, task)` and `top_n_candidate`.
-        """
         self.augmenter = augmenter
 
     def retrieve(self, query: List[str], top_k: int = 5, batch_size: int = -1, task: str = None) -> List[List[str]]:
@@ -328,29 +320,46 @@
         Returns:
             list[list[str]]: A list of document lists, one per input query.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if task != 'taxonomy-discovery':
+            return super().retrieve(query=query, top_k=top_k, batch_size=batch_size)
+        return self.augmented_retrieve(query, top_k=top_k, batch_size=batch_size, task=task)
+
+    def augmented_retrieve(self, query: List[str], top_k: int = 5, batch_size: int = -1, task: str = None):
+        if self.embeddings is None:
+            raise RuntimeError("Retriever model must index documents before prediction.")
+
+        augmented_queries, index_map = [], []
+        for qu_idx, qu in enumerate(query):
+            augmented = self.augmenter.transform(qu, task=task)
+            for aug in augmented:
+                augmented_queries.append(aug)
+                index_map.append(qu_idx)
+
+        doc_norm = F.normalize(self.embeddings, p=2, dim=1)
+        results = [dict() for _ in range(len(query))]
+
+        if batch_size == -1:
+            batch_size = len(augmented_queries)
+
+        for start in range(0, len(augmented_queries), batch_size):
+            batch_aug = augmented_queries[start:start + batch_size]
+            batch_embeddings = self.embedding_model.encode(batch_aug, convert_to_tensor=True)
+            batch_norm = F.normalize(batch_embeddings, p=2, dim=1)
+            similarity_matrix = torch.matmul(batch_norm, doc_norm.T)
+            current_top_k = min(top_k, len(self.documents))
+            topk_similarities, topk_indices = torch.topk(similarity_matrix, k=current_top_k, dim=1)
+
+            for i, (doc_indices, sim_scores) in enumerate(zip(topk_indices, topk_similarities)):
+                original_query_idx = index_map[start + i]
+
+                for doc_idx, score in zip(doc_indices.tolist(), sim_scores.tolist()):
+                    if score >= self.threshold:
+                        doc = self.documents[doc_idx]
+                        prev = results[original_query_idx].get(doc, 0.0)
+                        results[original_query_idx][doc] = prev + score
+
+        final_results = []
+        for doc_score_map in results:
+            sorted_docs = sorted(doc_score_map.items(), key=lambda x: x[1], reverse=True)
+            final_results.append([doc for doc, _ in sorted_docs])
+        return final_results
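augmented_retrieve expands each query through the augmenter, scores every expansion against L2-normalized document embeddings by cosine similarity, and sums the per-document scores across expansions before ranking. A compact sketch of that aggregation with stand-in tensors (random embeddings here; the real class encodes text with its embedding model and also handles batching):

```python
import torch
import torch.nn.functional as F

# Stand-in corpus and random embeddings; the real class uses encode(..., convert_to_tensor=True).
documents = ["mammal", "vehicle", "bird"]
doc_norm = F.normalize(torch.randn(3, 8), p=2, dim=1)

# Two augmented variants of one original query, already embedded.
query_norm = F.normalize(torch.randn(2, 8), p=2, dim=1)

similarity = query_norm @ doc_norm.T                      # cosine similarities, shape (2, 3)
topk_scores, topk_idx = torch.topk(similarity, k=2, dim=1)

threshold = 0.0                                           # mirrors the new `threshold` parameter
scores = {}
for row_scores, row_idx in zip(topk_scores, topk_idx):
    for score, doc_i in zip(row_scores.tolist(), row_idx.tolist()):
        if score >= threshold:
            doc = documents[doc_i]
            scores[doc] = scores.get(doc, 0.0) + score    # sum evidence across augmented queries

ranked = [doc for doc, _ in sorted(scores.items(), key=lambda kv: kv[1], reverse=True)]
print(ranked)
```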
ontolearner/learner/retriever/learner.py
CHANGED
@@ -122,7 +122,6 @@ class AutoRetrieverLearner(AutoLearner):
             warnings.warn("No requirement for fiting the non-taxonomic RE model, the predict module will use the input data to do the fit as well..")
 
 
-
 class LLMAugmentedRetrieverLearner(AutoRetrieverLearner):
 
     def set_augmenter(self, augmenter):
@@ -160,9 +159,9 @@ class LLMAugmentedRetrieverLearner(AutoRetrieverLearner):
         taxonomic_pairs = [{"parent": candidate, "child": query}
                            for query, candidates in zip(data, candidates_lst)
                            for candidate in candidates if candidate.lower() != query.lower()]
-        taxonomic_pairs += [{"parent": query, "child": candidate}
-
-
+        # taxonomic_pairs += [{"parent": query, "child": candidate}
+        #                     for query, candidates in zip(data, candidates_lst)
+        #                     for candidate in candidates if candidate.lower() != query.lower()]
         unique_taxonomic_pairs, seen = [], set()
         for pair in taxonomic_pairs:
             key = (pair["parent"].lower(), pair["child"].lower())  # Directional key (parent, child)
|