OntoLearner 1.4.9__py3-none-any.whl → 1.4.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ontolearner/VERSION CHANGED
@@ -1 +1 @@
1
- 1.4.9
1
+ 1.4.11
@@ -18,6 +18,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
18
18
  import torch
19
19
  import torch.nn.functional as F
20
20
  from sentence_transformers import SentenceTransformer
21
+ from collections import defaultdict
21
22
 
22
23
  class AutoLearner(ABC):
23
24
  """
@@ -70,6 +71,7 @@ class AutoLearner(ABC):
70
71
  - "term-typing": Predict semantic types for terms
71
72
  - "taxonomy-discovery": Identify hierarchical relationships
72
73
  - "non-taxonomic-re": Identify non-hierarchical relationships
74
+ - "text2onto" : Extract ontology terms and their semantic types from documents
73
75
 
74
76
  Raises:
75
77
  NotImplementedError: If not implemented by concrete class.
@@ -81,6 +83,8 @@ class AutoLearner(ABC):
81
83
  self._taxonomy_discovery(train_data, test=False)
82
84
  elif task == 'non-taxonomic-re':
83
85
  self._non_taxonomic_re(train_data, test=False)
86
+ elif task == 'text2onto':
87
+ self._text2onto(train_data, test=False)
84
88
  else:
85
89
  raise ValueError(f"{task} is not a valid task.")
86
90
 
@@ -103,6 +107,7 @@ class AutoLearner(ABC):
103
107
  - term-typing: List of predicted types for each term
104
108
  - taxonomy-discovery: Boolean predictions for relationships
105
109
  - non-taxonomic-re: Predicted relation types
110
+ - text2onto: Extracted ontology terms and their semantic types for each document
106
111
 
107
112
  Raises:
108
113
  NotImplementedError: If not implemented by concrete class.
@@ -115,6 +120,8 @@ class AutoLearner(ABC):
115
120
  return self._taxonomy_discovery(eval_data, test=True)
116
121
  elif task == 'non-taxonomic-re':
117
122
  return self._non_taxonomic_re(eval_data, test=True)
123
+ elif task == 'text2onto':
124
+ return self._text2onto(eval_data, test=True)
118
125
  else:
119
126
  raise ValueError(f"{task} is not a valid task.")
120
127
 
@@ -147,6 +154,9 @@ class AutoLearner(ABC):
147
154
  def _non_taxonomic_re(self, data: Any, test: bool = False) -> Optional[Any]:
148
155
  pass
149
156
 
157
+ def _text2onto(self, data: Any, test: bool = False) -> Optional[Any]:
158
+ pass
159
+
150
160
  def tasks_data_former(self, data: Any, task: str, test: bool = False) -> List[str | Dict[str, str]]:
151
161
  formatted_data = []
152
162
  if task == "term-typing":
@@ -171,6 +181,7 @@ class AutoLearner(ABC):
171
181
  non_taxonomic_types = list(set(non_taxonomic_types))
172
182
  non_taxonomic_res = list(set(non_taxonomic_res))
173
183
  formatted_data = {"types": non_taxonomic_types, "relations": non_taxonomic_res}
184
+
174
185
  return formatted_data
175
186
 
176
187
  def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, str]]:
@@ -186,6 +197,26 @@ class AutoLearner(ABC):
186
197
  formatted_data.append({"head": non_taxonomic_triplets.head,
187
198
  "tail": non_taxonomic_triplets.tail,
188
199
  "relation": non_taxonomic_triplets.relation})
200
+ if task == "text2onto":
201
+ terms2docs = data.get("terms2docs", {}) or {}
202
+ terms2types = data.get("terms2types", {}) or {}
203
+
204
+ # gold doc→terms
205
+ gold_terms = []
206
+ for term, doc_ids in terms2docs.items():
207
+ for doc_id in doc_ids or []:
208
+ gold_terms.append({"doc_id": doc_id, "term": term})
209
+
210
+ # gold doc→types derived via doc→terms + term→types
211
+ doc2types = defaultdict(set)
212
+ for term, doc_ids in terms2docs.items():
213
+ for doc_id in doc_ids or []:
214
+ for ty in (terms2types.get(term, []) or []):
215
+ if isinstance(ty, str) and ty.strip():
216
+ doc2types[doc_id].add(ty.strip())
217
+ gold_types = [{"doc_id": doc_id, "type": ty} for doc_id, tys in doc2types.items() for ty in tys]
218
+ return {"terms": gold_terms, "types": gold_types}
219
+
189
220
  return formatted_data
190
221
 
191
222
  class AutoLLM(ABC):
@@ -201,7 +232,7 @@ class AutoLLM(ABC):
201
232
  tokenizer: The tokenizer associated with the model.
202
233
  """
203
234
 
204
- def __init__(self, label_mapper: Any, device: str='cpu', token: str="") -> None:
235
+ def __init__(self, label_mapper: Any, device: str='cpu', token: str="", max_length: int = 256) -> None:
205
236
  """
206
237
  Initialize the LLM component.
207
238
 
@@ -213,6 +244,7 @@ class AutoLLM(ABC):
213
244
  self.device=device
214
245
  self.model: Optional[Any] = None
215
246
  self.tokenizer: Optional[Any] = None
247
+ self.max_length = max_length
216
248
 
217
249
 
218
250
  def load(self, model_id: str) -> None:
@@ -236,10 +268,8 @@ class AutoLLM(ABC):
236
268
  self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=self.token)
237
269
  self.tokenizer.pad_token = self.tokenizer.eos_token
238
270
  if self.device == "cpu":
239
- # device_map = "cpu"
240
271
  self.model = AutoModelForCausalLM.from_pretrained(
241
272
  model_id,
242
- # device_map=device_map,
243
273
  torch_dtype=torch.bfloat16,
244
274
  token=self.token
245
275
  )
@@ -248,8 +278,8 @@ class AutoLLM(ABC):
248
278
  self.model = AutoModelForCausalLM.from_pretrained(
249
279
  model_id,
250
280
  device_map=device_map,
251
- torch_dtype=torch.bfloat16,
252
- token=self.token
281
+ token=self.token,
282
+ trust_remote_code=True,
253
283
  )
254
284
  self.label_mapper.fit()
255
285
 
@@ -276,29 +306,20 @@ class AutoLLM(ABC):
276
306
  List of generated text responses, one for each input prompt.
277
307
  Responses include the original input plus generated continuation.
278
308
  """
279
- # Tokenize inputs and move to device
280
309
  encoded_inputs = self.tokenizer(inputs,
281
310
  return_tensors="pt",
282
- padding=True,
311
+ max_length=self.max_length,
283
312
  truncation=True).to(self.model.device)
284
313
  input_ids = encoded_inputs["input_ids"]
285
314
  input_length = input_ids.shape[1]
286
-
287
- # Generate output
288
315
  outputs = self.model.generate(
289
316
  **encoded_inputs,
290
317
  max_new_tokens=max_new_tokens,
291
- pad_token_id=self.tokenizer.eos_token_id
318
+ pad_token_id=self.tokenizer.eos_token_id,
319
+ eos_token_id=self.tokenizer.eos_token_id
292
320
  )
293
-
294
- # Extract only the newly generated tokens (excluding prompt)
295
321
  generated_tokens = outputs[:, input_length:]
296
-
297
- # Decode only the generated part
298
322
  decoded_outputs = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_tokens]
299
- print(decoded_outputs)
300
- print(self.label_mapper.predict(decoded_outputs))
301
- # Map the decoded text to labels
302
323
  return self.label_mapper.predict(decoded_outputs)
303
324
 
304
325
  class AutoRetriever(ABC):
@@ -372,7 +372,7 @@ class BaseOntology(ABC):
372
372
  # Save updated metrics
373
373
  df.to_excel(metrics_file_path, index=False)
374
374
 
375
- def is_valid_label(label: str) -> Any:
375
+ def is_valid_label(self, label: str) -> Any:
376
376
  invalids = ['root', 'thing']
377
377
  if label.lower() in invalids:
378
378
  return None
@@ -522,7 +522,7 @@ class BaseOntology(ABC):
522
522
  return True
523
523
  return False
524
524
 
525
- def _is_anonymous_id(label: str) -> bool:
525
+ def _is_anonymous_id(self, label: str) -> bool:
526
526
  """Check if a label represents an anonymous class identifier."""
527
527
  if not label:
528
528
  return True
@@ -11,44 +11,84 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- from typing import List, Dict, Tuple, Set
14
+ from typing import List, Dict, Tuple, Set, Any, Union
15
15
 
16
16
  SYMMETRIC_RELATIONS = {"equivalentclass", "sameas", "disjointwith"}
17
17
 
18
def text2onto_metrics(
    y_true: Dict[str, Any],
    y_pred: Dict[str, Any],
    similarity_threshold: float = 0.8
) -> Dict[str, Any]:
    """
    Score text2onto predictions against gold annotations.

    Expects:
        y_true = {"terms": [{"doc_id": str, "term": str}, ...],
                  "types": [{"doc_id": str, "type": str}, ...]}
        y_pred = same shape

    Each row is turned into the lower-cased token set of "<doc_id> <value>".
    A prediction matches a gold row when both belong to the same document
    and the Jaccard similarity of their token sets reaches
    ``similarity_threshold``. Matching is greedy: predictions are visited in
    order and each gold row can be matched at most once.

    Rows with a missing/blank ``doc_id`` or value are ignored. Non-string
    ids or values (e.g. integer doc ids) are converted with ``str()``.

    Args:
        y_true: Gold annotations (``None`` is treated as empty).
        y_pred: Predicted annotations (``None`` is treated as empty).
        similarity_threshold: Minimum Jaccard similarity for a match.

    Returns:
        {"terms": {...}, "types": {...}} where each inner dict holds
        "f1_score", "precision", "recall", "total_correct",
        "total_predicted" and "total_ground_truth".
    """

    def jaccard_similarity(tokens_a: Set[str], tokens_b: Set[str]) -> float:
        if not tokens_a and not tokens_b:
            return 1.0
        return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)

    def clean(raw: Any) -> str:
        # Tolerate None and non-string ids/values instead of crashing on .strip().
        return "" if raw is None else str(raw).strip()

    def rows_to_items(rows: List[Dict[str, Any]], value_key: str) -> List[Tuple[Tuple[str, ...], Set[str]]]:
        items: List[Tuple[Tuple[str, ...], Set[str]]] = []
        for row in rows or []:
            doc_id = clean(row.get("doc_id"))
            value = clean(row.get(value_key))
            if doc_id and value:
                # The doc id tokens stay part of the Jaccard token set, and the
                # document key is kept separately so that the association can
                # be enforced exactly. Tokenising once here avoids re-splitting
                # for every pair.
                doc_key = tuple(doc_id.lower().split())
                tokens = set(f"{doc_id} {value}".lower().split())
                items.append((doc_key, tokens))
        return items

    def score_list(
        ground_truth_items: List[Tuple[Tuple[str, ...], Set[str]]],
        predicted_items: List[Tuple[Tuple[str, ...], Set[str]]],
    ) -> Dict[str, Union[float, int]]:
        matched_ground_truth_indices: Set[int] = set()
        matched_predicted_indices: Set[int] = set()

        for predicted_index, (predicted_doc, predicted_tokens) in enumerate(predicted_items):
            for ground_truth_index, (gold_doc, gold_tokens) in enumerate(ground_truth_items):
                if ground_truth_index in matched_ground_truth_indices:
                    continue
                # Never match across documents: without this guard a long
                # value could clear the threshold even though the doc ids differ.
                if predicted_doc != gold_doc:
                    continue
                if jaccard_similarity(predicted_tokens, gold_tokens) >= similarity_threshold:
                    matched_predicted_indices.add(predicted_index)
                    matched_ground_truth_indices.add(ground_truth_index)
                    break

        total_correct = len(matched_predicted_indices)
        total_predicted = len(predicted_items)
        total_ground_truth = len(ground_truth_items)

        precision = total_correct / total_predicted if total_predicted else 0.0
        recall = total_correct / total_ground_truth if total_ground_truth else 0.0
        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0

        return {
            "f1_score": f1,
            "precision": precision,
            "recall": recall,
            "total_correct": total_correct,
            "total_predicted": total_predicted,
            "total_ground_truth": total_ground_truth,
        }

    y_true = y_true or {}
    y_pred = y_pred or {}

    terms_metrics = score_list(
        rows_to_items(y_true.get("terms", []), "term"),
        rows_to_items(y_pred.get("terms", []), "term"),
    )
    types_metrics = score_list(
        rows_to_items(y_true.get("types", []), "type"),
        rows_to_items(y_pred.get("types", []), "type"),
    )

    return {
        "terms": terms_metrics,
        "types": types_metrics,
    }
53
93
 
54
94
  def term_typing_metrics(y_true: List[Dict[str, List[str]]], y_pred: List[Dict[str, List[str]]]) -> Dict[str, float | int]:
@@ -14,6 +14,6 @@
14
14
 
15
15
  from .llm import AutoLLMLearner, FalconLLM, MistralLLM
16
16
  from .retriever import AutoRetrieverLearner, LLMAugmentedRetrieverLearner
17
- from .rag import AutoRAGLearner
17
+ from .rag import AutoRAGLearner, LLMAugmentedRAGLearner
18
18
  from .prompt import StandardizedPrompting
19
19
  from .label_mapper import LabelMapper
@@ -31,7 +31,7 @@ class LabelMapper:
31
31
  ngram_range: Tuple=(1, 1),
32
32
  label_dict: Dict[str, List[str]]=None,
33
33
  analyzer: str = 'word',
34
- iterator_no: int = 100):
34
+ iterator_no: int = 1000):
35
35
  """
36
36
  Initializes the TFIDFLabelMapper with a specified classifier and TF-IDF configuration.
37
37
 
@@ -17,15 +17,50 @@ from ..base import AutoPrompt
17
17
  class StandardizedPrompting(AutoPrompt):
18
18
  def __init__(self, task: str = None):
19
19
  if task == "term-typing":
20
- prompt_template = """Determine whether the given term can be categorized as an instance of the specified high-level type. Answer with `yes` if it is otherwise answer with `no`. Do not explain.
20
+ prompt_template = """You are performing term typing.
21
+
22
+ Determine whether the given term is a clear and unambiguous instance of the specified high-level type.
23
+
24
+ Rules:
25
+ - Answer "yes" only if the term commonly and directly belongs to the type.
26
+ - Answer "no" if the term does not belong to the type, is ambiguous, or only weakly related.
27
+ - Use the most common meaning of the term.
28
+ - Do not explain your answer.
29
+
21
30
  Term: {term}
22
31
  Type: {type}
23
- Answer: """
32
+ Answer (yes or no):"""
24
33
  elif task == "taxonomy-discovery":
25
- prompt_template = """Is {parent} a direct or indirect superclass (or parent concept) of {child} in a conceptual hierarchy? Answer with yes or no.
26
- Answer: """
34
+ prompt_template = """You are identifying taxonomic (is-a) relationships.
35
+
36
+ Question:
37
+ Is "{parent}" a superclass (direct or indirect) of "{child}" in a standard conceptual or ontological hierarchy?
38
+
39
+ Rules:
40
+ - A superclass means: "{child}" is a type or instance of "{parent}".
41
+ - Answer "yes" only if the relationship is a true is-a relationship.
42
+ - Answer "no" for part-of, related-to, or associative relationships.
43
+ - Use general world knowledge.
44
+ - Do not explain.
45
+
46
+ Parent: {parent}
47
+ Child: {child}
48
+ Answer (yes or no):"""
27
49
  elif task == "non-taxonomic-re":
28
- prompt_template = """Given the conceptual types `{head}` and `{tail}`, does a `{relation}` relation exist between them? Respond with "yes" if it does, otherwise respond with "no"."""
50
+ prompt_template = """You are identifying non-taxonomic conceptual relationships.
51
+
52
+ Given two conceptual types, determine whether the specified relation typically holds between them.
53
+
54
+ Rules:
55
+ - Answer "yes" only if the relation commonly and meaningfully applies.
56
+ - Answer "no" if the relation is rare, indirect, or context-dependent.
57
+ - Do not infer relations that require specific situations.
58
+ - Do not explain.
59
+
60
+ Head type: {head}
61
+ Tail type: {tail}
62
+ Relation: {relation}
63
+ Answer (yes or no):"""
29
64
  else:
30
65
  raise ValueError("Unknown task! Current tasks are: 'term-typing', 'taxonomy-discovery', 'non-taxonomic-re'")
31
66
  super().__init__(prompt_template)
@@ -0,0 +1,14 @@
1
+ # Copyright (c) 2025 SciKnowOrg
2
+ #
3
+ # Licensed under the MIT License (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://opensource.org/licenses/MIT
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from .rag import AutoRAGLearner, LLMAugmentedRAGLearner
@@ -14,8 +14,7 @@
14
14
 
15
15
  import warnings
16
16
  from typing import Any
17
- from ..base import AutoLearner
18
-
17
+ from ...base import AutoLearner
19
18
 
20
19
  class AutoRAGLearner(AutoLearner):
21
20
  def __init__(self,
@@ -87,3 +86,9 @@ class AutoRAGLearner(AutoLearner):
87
86
  return self.llm._non_taxonomic_re_predict(dataset=dataset)
88
87
  else:
89
88
  warnings.warn("No requirement for fiting the non-taxonomic-re model, the predict module will use the input data to do the fit as well.")
89
+
90
+
91
class LLMAugmentedRAGLearner(AutoRAGLearner):
    """RAG learner variant whose retriever can be given an augmenter."""

    def set_augmenter(self, augmenter):
        """
        Hand the augmenter over to the underlying retriever.

        Args:
            augmenter: Object forwarded unchanged to
                ``self.retriever.set_augmenter``.
        """
        retriever = self.retriever
        retriever.set_augmenter(augmenter=augmenter)
@@ -16,4 +16,4 @@ from .crossencoder import CrossEncoderRetriever
16
16
  from .embedding import GloveRetriever, Word2VecRetriever
17
17
  from .ngram import NgramRetriever
18
18
  from .learner import AutoRetrieverLearner, LLMAugmentedRetrieverLearner
19
- from .llm_retriever import LLMAugmenterGenerator, LLMAugmenter, LLMAugmentedRetriever
19
+ from .augmented_retriever import LLMAugmenterGenerator, LLMAugmenter, LLMAugmentedRetriever
@@ -17,6 +17,8 @@ from typing import Any, List, Dict
17
17
  from openai import OpenAI
18
18
  import time
19
19
  from tqdm import tqdm
20
+ import torch
21
+ import torch.nn.functional as F
20
22
 
21
23
  from ...base import AutoRetriever
22
24
  from ...utils import load_json
@@ -125,7 +127,6 @@ class LLMAugmenterGenerator(ABC):
125
127
  except Exception:
126
128
  print("sleep for 5 seconds")
127
129
  time.sleep(5)
128
-
129
130
  return inference
130
131
 
131
132
  def tasks_data_former(self, data: Any, task: str) -> List[str] | Dict[str, List[str]]:
@@ -298,21 +299,12 @@ class LLMAugmentedRetriever(AutoRetriever):
298
299
  Attributes:
299
300
  augmenter: An augmenter instance that provides transform() and top_n_candidate.
300
301
  """
301
-
302
def __init__(self, threshold: float = 0.0, cutoff_rate: float = 100.0) -> None:
    """
    Initialize the augmented retriever with no augmenter attached.

    Args:
        threshold: Minimum similarity score a retrieved document must reach
            to be kept by ``augmented_retrieve``.
        cutoff_rate: Stored on the instance as ``self.cutoff_rate``.
            NOTE(review): not read by any code visible here — confirm its
            intended use.
    """
    super().__init__()
    # Start with an explicit "no augmenter" state so the attribute always
    # exists; set_augmenter() replaces it later.
    self.augmenter = None
    self.threshold = threshold
    self.cutoff_rate = cutoff_rate
308
306
 
309
307
  def set_augmenter(self, augmenter):
310
- """
311
- Attach an augmenter instance.
312
-
313
- Args:
314
- augmenter: An object providing `transform(query, task)` and `top_n_candidate`.
315
- """
316
308
  self.augmenter = augmenter
317
309
 
318
310
  def retrieve(self, query: List[str], top_k: int = 5, batch_size: int = -1, task: str = None) -> List[List[str]]:
@@ -328,29 +320,46 @@ class LLMAugmentedRetriever(AutoRetriever):
328
320
  Returns:
329
321
  list[list[str]]: A list of document lists, one per input query.
330
322
  """
331
- parent_retrieve = super(LLMAugmentedRetriever, self).retrieve
332
-
333
- if task == 'taxonomy-discovery':
334
- query_sets = []
335
- for idx in range(self.augmenter.top_n_candidate):
336
- query_set = []
337
- for qu in query:
338
- query_set.append(self.augmenter.transform(qu, task=task)[idx])
339
- query_sets.append(query_set)
340
-
341
- retrieves = [
342
- parent_retrieve(query=query_set, top_k=top_k, batch_size=batch_size)
343
- for query_set in query_sets
344
- ]
345
-
346
- results = []
347
- for qu_idx, qu in enumerate(query):
348
- qu_result = []
349
- for top_idx in range(self.augmenter.top_n_candidate):
350
- qu_result += retrieves[top_idx][qu_idx]
351
- results.append(list(set(qu_result)))
352
-
353
- return results
354
-
355
- else:
356
- return parent_retrieve(query=query, top_k=top_k, batch_size=batch_size)
323
+ if task != 'taxonomy-discovery':
324
+ return super().retrieve(query=query, top_k=top_k, batch_size=batch_size)
325
+ return self.augmented_retrieve(query, top_k=top_k, batch_size=batch_size, task=task)
326
+
327
def augmented_retrieve(self, query: List[str], top_k: int = 5, batch_size: int = -1, task: str = None):
    """
    Retrieve documents for each query using its augmented variants.

    Every query is expanded through ``self.augmenter.transform(query, task=task)``.
    Each variant is embedded and compared to the indexed document embeddings
    by cosine similarity (both sides are L2-normalized). For each variant the
    ``top_k`` most similar documents are taken. Those scoring at least
    ``self.threshold`` have their scores summed per original query, and the
    documents are returned in descending order of that summed score.

    Args:
        query: Original query strings.
        top_k: Number of documents taken per augmented variant (capped at
            the number of indexed documents).
        batch_size: Number of augmented variants encoded at once; ``-1`` (or
            any non-positive value) encodes them all in a single batch.
        task: Task name forwarded to the augmenter.

    Returns:
        list[list[str]]: One ranked document list per input query. A query
        with no augmented variants, or with no document above the threshold,
        gets an empty list.

    Raises:
        RuntimeError: If no documents have been indexed or no augmenter has
            been attached.
    """
    if self.embeddings is None:
        raise RuntimeError("Retriever model must index documents before prediction.")

    augmenter = getattr(self, "augmenter", None)
    if augmenter is None:
        raise RuntimeError("An augmenter must be attached via set_augmenter() before augmented retrieval.")

    # Flatten all augmented variants, remembering which query each came from.
    augmented_queries, index_map = [], []
    for qu_idx, qu in enumerate(query):
        for aug in augmenter.transform(qu, task=task):
            augmented_queries.append(aug)
            index_map.append(qu_idx)

    # Nothing to encode: without this guard batch_size would become 0 and
    # range() would raise "arg 3 must not be zero".
    if not augmented_queries:
        return [[] for _ in query]

    if batch_size is None or batch_size <= 0:
        batch_size = len(augmented_queries)

    doc_norm = F.normalize(self.embeddings, p=2, dim=1)
    results = [dict() for _ in query]
    # Loop-invariant; the clamp keeps a negative top_k from reaching torch.topk.
    current_top_k = max(0, min(top_k, len(self.documents)))

    for start in range(0, len(augmented_queries), batch_size):
        batch_aug = augmented_queries[start:start + batch_size]
        batch_embeddings = self.embedding_model.encode(batch_aug, convert_to_tensor=True)
        batch_norm = F.normalize(batch_embeddings, p=2, dim=1)
        similarity_matrix = torch.matmul(batch_norm, doc_norm.T)
        topk_similarities, topk_indices = torch.topk(similarity_matrix, k=current_top_k, dim=1)

        for offset, (doc_indices, sim_scores) in enumerate(zip(topk_indices, topk_similarities)):
            doc_scores = results[index_map[start + offset]]
            for doc_idx, score in zip(doc_indices.tolist(), sim_scores.tolist()):
                if score >= self.threshold:
                    doc = self.documents[doc_idx]
                    # Accumulate so documents hit by several variants rank higher.
                    doc_scores[doc] = doc_scores.get(doc, 0.0) + score

    final_results = []
    for doc_score_map in results:
        sorted_docs = sorted(doc_score_map.items(), key=lambda item: item[1], reverse=True)
        final_results.append([doc for doc, _ in sorted_docs])
    return final_results
@@ -122,7 +122,6 @@ class AutoRetrieverLearner(AutoLearner):
122
122
  warnings.warn("No requirement for fiting the non-taxonomic RE model, the predict module will use the input data to do the fit as well..")
123
123
 
124
124
 
125
-
126
125
  class LLMAugmentedRetrieverLearner(AutoRetrieverLearner):
127
126
 
128
127
  def set_augmenter(self, augmenter):
@@ -160,9 +159,9 @@ class LLMAugmentedRetrieverLearner(AutoRetrieverLearner):
160
159
  taxonomic_pairs = [{"parent": candidate, "child": query}
161
160
  for query, candidates in zip(data, candidates_lst)
162
161
  for candidate in candidates if candidate.lower() != query.lower()]
163
- taxonomic_pairs += [{"parent": query, "child": candidate}
164
- for query, candidates in zip(data, candidates_lst)
165
- for candidate in candidates if candidate.lower() != query.lower()]
162
+ # taxonomic_pairs += [{"parent": query, "child": candidate}
163
+ # for query, candidates in zip(data, candidates_lst)
164
+ # for candidate in candidates if candidate.lower() != query.lower()]
166
165
  unique_taxonomic_pairs, seen = [], set()
167
166
  for pair in taxonomic_pairs:
168
167
  key = (pair["parent"].lower(), pair["child"].lower()) # Directional key (parent, child)
@@ -12,5 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .alexbek import AlexbekFewShotLearner
15
+ from .alexbek import AlexbekRAGFewShotLearner
16
16
  from .sbunlp import SBUNLPFewShotLearner