OntoLearner 1.4.7__py3-none-any.whl → 1.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ontolearner/VERSION +1 -1
- ontolearner/base/learner.py +15 -12
- ontolearner/learner/__init__.py +1 -1
- ontolearner/learner/label_mapper.py +1 -1
- ontolearner/learner/retriever/__init__.py +19 -0
- ontolearner/learner/retriever/crossencoder.py +129 -0
- ontolearner/learner/retriever/embedding.py +229 -0
- ontolearner/learner/retriever/learner.py +217 -0
- ontolearner/learner/retriever/llm_retriever.py +356 -0
- ontolearner/learner/retriever/ngram.py +123 -0
- ontolearner/learner/taxonomy_discovery/__init__.py +18 -0
- ontolearner/learner/taxonomy_discovery/alexbek.py +500 -0
- ontolearner/learner/taxonomy_discovery/rwthdbis.py +1082 -0
- ontolearner/learner/taxonomy_discovery/sbunlp.py +402 -0
- ontolearner/learner/taxonomy_discovery/skhnlp.py +1138 -0
- ontolearner/learner/term_typing/__init__.py +17 -0
- ontolearner/learner/term_typing/alexbek.py +1262 -0
- ontolearner/learner/term_typing/rwthdbis.py +379 -0
- ontolearner/learner/term_typing/sbunlp.py +478 -0
- ontolearner/learner/text2onto/__init__.py +16 -0
- ontolearner/learner/text2onto/alexbek.py +1219 -0
- ontolearner/learner/text2onto/sbunlp.py +598 -0
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.9.dist-info}/METADATA +16 -12
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.9.dist-info}/RECORD +26 -9
- ontolearner/learner/retriever.py +0 -101
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.9.dist-info}/WHEEL +0 -0
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.9.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
# Copyright (c) 2025 SciKnowOrg
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the MIT License (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://opensource.org/licenses/MIT
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from abc import ABC
|
|
16
|
+
from typing import Any, List, Dict
|
|
17
|
+
from openai import OpenAI
|
|
18
|
+
import time
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
|
|
21
|
+
from ...base import AutoRetriever
|
|
22
|
+
from ...utils import load_json
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class LLMAugmenterGenerator(ABC):
|
|
26
|
+
"""
|
|
27
|
+
A generator class responsible for creating augmented query candidates using LLMs
|
|
28
|
+
such as GPT-4 and GPT-3.5. This class provides augmentation support for
|
|
29
|
+
three ontology-learning tasks:
|
|
30
|
+
|
|
31
|
+
- term-typing
|
|
32
|
+
- taxonomy-discovery
|
|
33
|
+
- non-taxonomic relation extraction
|
|
34
|
+
|
|
35
|
+
For taxonomy discovery, it invokes a function-calling LLM that returns
|
|
36
|
+
candidate parent classes for each query term.
|
|
37
|
+
|
|
38
|
+
Attributes:
|
|
39
|
+
client (OpenAI): OpenAI API client used for LLM inference.
|
|
40
|
+
model_id (str): The LLM model identifier.
|
|
41
|
+
term_typing_function (list): Function call schema for term typing (currently unused).
|
|
42
|
+
taxonomy_discovery_function (list): Function call schema for taxonomy discovery.
|
|
43
|
+
non_taxonomic_re_function (list): Function call schema for non-taxonomic relation extraction.
|
|
44
|
+
top_n_candidate (int): Number of augmented candidates to generate per query.
|
|
45
|
+
term_typing_prompt (str): Prompt template used for term typing tasks.
|
|
46
|
+
taxonomy_discovery_prompt (str): Prompt template used for taxonomy discovery.
|
|
47
|
+
non_taxonomic_re_prompt (str): Prompt template for non-taxonomic RE.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
def __init__(self, model_id: str = 'gpt-4.1-mini', token: str = '', top_n_candidate: int = 5) -> None:
|
|
51
|
+
"""
|
|
52
|
+
Initialize the LLM augmenter generator.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
model_id (str): Name of the OpenAI model to use.
|
|
56
|
+
token (str): API key for authentication.
|
|
57
|
+
top_n_candidate (int): Number of generated candidate parents per query.
|
|
58
|
+
"""
|
|
59
|
+
self.client = OpenAI(api_key=token)
|
|
60
|
+
|
|
61
|
+
self.model_id = model_id
|
|
62
|
+
|
|
63
|
+
self.term_typing_function = []
|
|
64
|
+
self.taxonomy_discovery_function = [
|
|
65
|
+
{
|
|
66
|
+
"name": "discover_taxonomy_parents",
|
|
67
|
+
"description": "Given a specific type or class (the query), identify potential parent classes that form valid hierarchical (is-a) relationships within a taxonomy.",
|
|
68
|
+
"parameters": {
|
|
69
|
+
"type": "object",
|
|
70
|
+
"properties": {
|
|
71
|
+
"candidate_parents": {
|
|
72
|
+
"type": "array",
|
|
73
|
+
"items": {"type": "string"},
|
|
74
|
+
"description": "A ranked list of candidate parent classes representing higher-level categories."
|
|
75
|
+
}
|
|
76
|
+
},
|
|
77
|
+
"required": ["candidate_parents"]
|
|
78
|
+
}
|
|
79
|
+
}
|
|
80
|
+
]
|
|
81
|
+
|
|
82
|
+
self.non_taxonomic_re_function = []
|
|
83
|
+
self.top_n_candidate = top_n_candidate
|
|
84
|
+
|
|
85
|
+
self.term_typing_prompt = ""
|
|
86
|
+
self.taxonomy_discovery_prompt = (
|
|
87
|
+
"Given a type (or class) {query}, generate a list of the top {top_n_candidate} candidate classes "
|
|
88
|
+
"that can form hierarchical (is-a) relationships, where each of these classes is a parent of {query}."
|
|
89
|
+
)
|
|
90
|
+
self.non_taxonomic_re_prompt = ""
|
|
91
|
+
|
|
92
|
+
def get_config(self) -> Dict[str, Any]:
|
|
93
|
+
"""
|
|
94
|
+
Get augmenter configuration metadata.
|
|
95
|
+
|
|
96
|
+
Returns:
|
|
97
|
+
dict: Dictionary containing the augmentation configuration.
|
|
98
|
+
"""
|
|
99
|
+
return {
|
|
100
|
+
"top_n_candidate": self.top_n_candidate,
|
|
101
|
+
"augmenter_model": self.model_id
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
def generate(self, conversation, function):
|
|
105
|
+
"""
|
|
106
|
+
Call an LLM to produce augmented candidates using function-calling.
|
|
107
|
+
|
|
108
|
+
Args:
|
|
109
|
+
conversation (list): Dialogue messages to send to the LLM.
|
|
110
|
+
function (list): Function schemas supplied to the model.
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
list[str]: A list of top-k generated candidates.
|
|
114
|
+
"""
|
|
115
|
+
while True:
|
|
116
|
+
try:
|
|
117
|
+
completion = self.client.chat.completions.create(
|
|
118
|
+
model=self.model_id,
|
|
119
|
+
messages=conversation,
|
|
120
|
+
functions=function
|
|
121
|
+
)
|
|
122
|
+
inference = eval(completion.choices[0].message.function_call.arguments)['candidate_parents'][:self.top_n_candidate]
|
|
123
|
+
assert len(inference) == self.top_n_candidate
|
|
124
|
+
break
|
|
125
|
+
except Exception:
|
|
126
|
+
print("sleep for 5 seconds")
|
|
127
|
+
time.sleep(5)
|
|
128
|
+
|
|
129
|
+
return inference
|
|
130
|
+
|
|
131
|
+
def tasks_data_former(self, data: Any, task: str) -> List[str] | Dict[str, List[str]]:
|
|
132
|
+
"""
|
|
133
|
+
Convert raw dataset input into query lists depending on the ontology-learning task.
|
|
134
|
+
|
|
135
|
+
Args:
|
|
136
|
+
data (Any): Input dataset object.
|
|
137
|
+
task (str): One of {'term-typing', 'taxonomy-discovery', 'non-taxonomic-re'}.
|
|
138
|
+
|
|
139
|
+
Returns:
|
|
140
|
+
List[str] or Dict[str, List[str]]: Formatted query inputs.
|
|
141
|
+
"""
|
|
142
|
+
formatted_data = []
|
|
143
|
+
if task == "term-typing":
|
|
144
|
+
for typing in data.term_typings:
|
|
145
|
+
formatted_data.append(typing.term)
|
|
146
|
+
formatted_data = list(set(formatted_data))
|
|
147
|
+
|
|
148
|
+
if task == "taxonomy-discovery":
|
|
149
|
+
for taxonomic_pairs in data.type_taxonomies.taxonomies:
|
|
150
|
+
formatted_data.append(taxonomic_pairs.parent)
|
|
151
|
+
formatted_data.append(taxonomic_pairs.child)
|
|
152
|
+
formatted_data = list(set(formatted_data))
|
|
153
|
+
|
|
154
|
+
if task == "non-taxonomic-re":
|
|
155
|
+
non_taxonomic_types = []
|
|
156
|
+
non_taxonomic_res = []
|
|
157
|
+
for triplet in data.type_non_taxonomic_relations.non_taxonomies:
|
|
158
|
+
non_taxonomic_types.extend([triplet.head, triplet.tail])
|
|
159
|
+
non_taxonomic_res.append(triplet.relation)
|
|
160
|
+
formatted_data = {"types": list(set(non_taxonomic_types)), "relations": list(set(non_taxonomic_res))}
|
|
161
|
+
|
|
162
|
+
return formatted_data
|
|
163
|
+
|
|
164
|
+
def _augment(self, query, conversations, function):
|
|
165
|
+
"""
|
|
166
|
+
Internal helper to generate augmented candidates for a batch of queries.
|
|
167
|
+
|
|
168
|
+
Args:
|
|
169
|
+
query (list[str]): Input query terms.
|
|
170
|
+
conversations (list): LLM conversation blocks for each query.
|
|
171
|
+
function (list): Function-calling schemas.
|
|
172
|
+
|
|
173
|
+
Returns:
|
|
174
|
+
dict[str, list[str]]: Mapping from query → list of augmented candidates.
|
|
175
|
+
"""
|
|
176
|
+
results = {}
|
|
177
|
+
for qu, conversation in tqdm(zip(query, conversations)):
|
|
178
|
+
results[qu] = self.generate(conversation=conversation, function=function)
|
|
179
|
+
return results
|
|
180
|
+
|
|
181
|
+
def augment_term_typing(self, query: List[str]) -> List[str]:
|
|
182
|
+
"""
|
|
183
|
+
Augment term-typing queries.
|
|
184
|
+
|
|
185
|
+
Currently a passthrough: no augmentation is performed.
|
|
186
|
+
|
|
187
|
+
Args:
|
|
188
|
+
query (list[str]): Query terms.
|
|
189
|
+
|
|
190
|
+
Returns:
|
|
191
|
+
list[str]: Unmodified query terms.
|
|
192
|
+
"""
|
|
193
|
+
return query
|
|
194
|
+
|
|
195
|
+
def augment_non_taxonomic_re(self, query: List[str]) -> List[str]:
|
|
196
|
+
"""
|
|
197
|
+
Augment non-taxonomic relation extraction queries.
|
|
198
|
+
|
|
199
|
+
Currently a passthrough.
|
|
200
|
+
|
|
201
|
+
Args:
|
|
202
|
+
query (list[str]): Query terms.
|
|
203
|
+
|
|
204
|
+
Returns:
|
|
205
|
+
list[str]: Unmodified query terms.
|
|
206
|
+
"""
|
|
207
|
+
return query
|
|
208
|
+
|
|
209
|
+
def augment_taxonomy_discovery(self, query: List[str]) -> Dict[str, List[str]]:
|
|
210
|
+
"""
|
|
211
|
+
Generate augmented candidates for taxonomy discovery.
|
|
212
|
+
|
|
213
|
+
Args:
|
|
214
|
+
query (list[str]): List of type/class names to augment.
|
|
215
|
+
|
|
216
|
+
Returns:
|
|
217
|
+
dict[str, list[str]]: Mapping of original query → list of candidate parents.
|
|
218
|
+
"""
|
|
219
|
+
conversations = []
|
|
220
|
+
for qu in query:
|
|
221
|
+
prompt = self.taxonomy_discovery_prompt.format(query=qu, top_n_candidate=self.top_n_candidate)
|
|
222
|
+
conversation = [
|
|
223
|
+
{"role": "system", "content": "Discover possible taxonomy parents."},
|
|
224
|
+
{"role": "user", "content": prompt}
|
|
225
|
+
]
|
|
226
|
+
conversations.append(conversation)
|
|
227
|
+
|
|
228
|
+
return self._augment(query=query, conversations=conversations, function=self.taxonomy_discovery_function)
|
|
229
|
+
|
|
230
|
+
def augment(self, data: Any, task: str):
|
|
231
|
+
"""
|
|
232
|
+
Main entry point for all augmentation modes.
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
data (Any): Dataset object to format and augment.
|
|
236
|
+
task (str): Task type.
|
|
237
|
+
|
|
238
|
+
Returns:
|
|
239
|
+
Any: Augmented output suitable for a retriever.
|
|
240
|
+
|
|
241
|
+
Raises:
|
|
242
|
+
ValueError: If an invalid task type is given.
|
|
243
|
+
"""
|
|
244
|
+
data = self.tasks_data_former(data=data, task=task)
|
|
245
|
+
if task == 'term-typing':
|
|
246
|
+
return self.augment_term_typing(data)
|
|
247
|
+
elif task == 'taxonomy-discovery':
|
|
248
|
+
return self.augment_taxonomy_discovery(data)
|
|
249
|
+
elif task == 'non-taxonomic-re':
|
|
250
|
+
return self.augment_non_taxonomic_re(data)
|
|
251
|
+
else:
|
|
252
|
+
raise ValueError(f"{task} is not a valid task.")
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class LLMAugmenter:
    """
    A lightweight augmenter that serves precomputed augmentation data from disk.

    Attributes:
        augments (dict): Augmentation data loaded from a JSON file.
        top_n_candidate (int): Number of augmentation candidates per query.
    """

    def __init__(self, path: str) -> None:
        """
        Load offline augmentation data.

        Args:
            path (str): Path to a JSON file containing saved augmentations.
        """
        self.augments = load_json(path)
        self.top_n_candidate = self.augments['config']['top_n_candidate']

    def transform(self, query: str, task: str) -> List[str]:
        """
        Look up the augmented variants of a query term for a given task.

        Args:
            query (str): Input query term.
            task (str): Task identifier.

        Returns:
            list[str]: Augmented candidates; for unsupported tasks (or an
            unknown query) the query itself is returned as the only candidate.
        """
        if task != 'taxonomy-discovery':
            return [query]
        return self.augments[task].get(query, [query])
289
|
+
|
|
290
|
+
|
|
291
|
+
class LLMAugmentedRetriever(AutoRetriever):
    """
    A retriever that enhances queries using LLM-based augmentation before retrieving documents.

    Supports special augmentation logic for taxonomy discovery where each input query
    is expanded into several augmented variants; the per-variant retrieval
    results are unioned per original query.

    Attributes:
        augmenter: An augmenter instance that provides transform() and top_n_candidate.
    """

    def __init__(self) -> None:
        """
        Initialize the augmented retriever with no augmenter attached.
        """
        super().__init__()
        self.augmenter = None

    def set_augmenter(self, augmenter):
        """
        Attach an augmenter instance.

        Args:
            augmenter: An object providing `transform(query, task)` and `top_n_candidate`.
        """
        self.augmenter = augmenter

    def retrieve(self, query: List[str], top_k: int = 5, batch_size: int = -1, task: str = None) -> List[List[str]]:
        """
        Retrieve documents for a batch of queries, optionally using query augmentation.

        Args:
            query (list[str]): List of input query terms.
            top_k (int): Number of documents to retrieve.
            batch_size (int): Batch size for retrieval.
            task (str): Optional task identifier that determines augmentation behavior.

        Returns:
            list[list[str]]: A list of document lists, one per input query.
        """
        parent_retrieve = super(LLMAugmentedRetriever, self).retrieve

        if task != 'taxonomy-discovery':
            return parent_retrieve(query=query, top_k=top_k, batch_size=batch_size)

        # Build one query set per candidate slot. An augmenter may return
        # fewer than `top_n_candidate` variants (e.g. LLMAugmenter falls
        # back to [query] for unknown terms); clamp the slot index so short
        # candidate lists do not raise IndexError — the previous code
        # crashed in that case.
        query_sets = []
        for idx in range(self.augmenter.top_n_candidate):
            query_set = []
            for qu in query:
                candidates = self.augmenter.transform(qu, task=task)
                query_set.append(candidates[min(idx, len(candidates) - 1)])
            query_sets.append(query_set)

        retrieves = [
            parent_retrieve(query=query_set, top_k=top_k, batch_size=batch_size)
            for query_set in query_sets
        ]

        # Union (deduplicate) the retrieved documents across candidate
        # slots, one merged list per original query.
        results = []
        for qu_idx in range(len(query)):
            qu_result = []
            for top_idx in range(self.augmenter.top_n_candidate):
                qu_result += retrieves[top_idx][qu_idx]
            results.append(list(set(qu_result)))

        return results
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# Copyright (c) 2025 SciKnowOrg
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the MIT License (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://opensource.org/licenses/MIT
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import logging
|
|
15
|
+
import numpy as np
|
|
16
|
+
from typing import List
|
|
17
|
+
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
|
|
18
|
+
from sklearn.metrics.pairwise import cosine_similarity
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
|
|
21
|
+
from ...base import AutoRetriever
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class NgramRetriever(AutoRetriever):
    """
    A retriever based on traditional n-gram vectorization methods such as TF-IDF
    and CountVectorizer.

    This retriever converts documents and queries into sparse bag-of-ngrams
    vectors and ranks documents using cosine similarity. It is simple,
    interpretable, and suitable for small-scale baselines or non-semantic
    text matching.
    """

    def __init__(self, **vectorizer_kwargs) -> None:
        """
        Initialize the n-gram retriever.

        Args:
            **vectorizer_kwargs: Additional keyword arguments passed directly
                to the scikit-learn vectorizer (e.g., ngram_range, stop_words).
        """
        super().__init__()
        self.vectorizer_kwargs = vectorizer_kwargs
        self.vectorizer = None
        self.embeddings = None
        # Set by `index()`; declared here so the attribute always exists.
        self.documents = None

    def load(self, model_id) -> None:
        """
        Load and initialize the vectorizer based on `model_id`.

        Args:
            model_id (str): Either `"tfidf"` for TF-IDF or `"count"` for
                CountVectorizer.

        Raises:
            ValueError: If the model_id is not one of the supported options.
        """
        if model_id == "tfidf":
            self.vectorizer = TfidfVectorizer(**self.vectorizer_kwargs)
        elif model_id == "count":
            self.vectorizer = CountVectorizer(**self.vectorizer_kwargs)
        else:
            raise ValueError(f"Invalid mode '{model_id}'. Choose from ['tfidf', 'count'].")

    def index(self, inputs: List[str]) -> None:
        """
        Fit the vectorizer and index (vectorize) the input documents.

        Args:
            inputs (List[str]): List of text documents to index.

        Notes:
            This method must be run before calling `retrieve()`. It creates the
            document embedding matrix used for similarity search.
        """
        if self.vectorizer is None:
            # Default to TF-IDF if the user never called `load()`
            self.load(model_id="tfidf")

        self.documents = inputs
        logger.info("Fitting vectorizer and transforming documents...")
        self.embeddings = self.vectorizer.fit_transform(inputs)
        logger.info(f"Document embeddings created with shape: {self.embeddings.shape}")

    def retrieve(self, query: List[str], top_k: int = 5, batch_size: int = -1) -> List[List[str]]:
        """
        Retrieve the most similar documents for each query string.

        Args:
            query (List[str]): A list of query strings.
            top_k (int): Number of most similar documents to return per query.
            batch_size (int): Number of queries to process at once.
                Use `-1` to process all queries in a single batch.

        Returns:
            List[List[str]]: For each query, a list containing the top-k
                matching documents.

        Raises:
            RuntimeError: If retrieval is attempted before indexing.
        """
        if self.embeddings is None:
            raise RuntimeError("Retriever must index documents before calling `retrieve()`.")

        # Guard the empty-query case: `batch_size = len(query)` would be 0
        # and `range(0, 0, 0)` raises ValueError.
        if not query:
            return []

        logger.info("Vectorizing query text...")
        query_vec = self.vectorizer.transform(query)
        logger.info(f"Query vectors created with shape: {query_vec.shape}")

        results = []
        if batch_size == -1:
            batch_size = len(query)

        for i in tqdm(range(0, len(query), batch_size)):
            q_batch = query_vec[i : i + batch_size]
            sim = cosine_similarity(q_batch, self.embeddings)
            # Sort each row descending by similarity and keep the top_k indices.
            topk_idx = np.argsort(sim, axis=1)[:, ::-1][:, :top_k]
            for row_indices in topk_idx:
                results.append([self.documents[j] for j in row_indices])

        return results
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2025 SciKnowOrg
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the MIT License (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://opensource.org/licenses/MIT
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .alexbek import AlexbekCrossAttnLearner
|
|
16
|
+
from .rwthdbis import RWTHDBISSFTLearner
|
|
17
|
+
from .sbunlp import SBUNLPFewShotLearner
|
|
18
|
+
from .skhnlp import SKHNLPSequentialFTLearner, SKHNLPZSLearner
|