OntoLearner 1.4.4__py3-none-any.whl → 1.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ontolearner/VERSION CHANGED
@@ -1 +1 @@
1
- 1.4.4
1
+ 1.4.6
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from abc import ABC
16
- from typing import Any, List, Optional
16
+ from typing import Any, List, Optional, Dict
17
17
  from transformers import AutoModelForCausalLM, AutoTokenizer
18
18
  import torch
19
19
  import torch.nn.functional as F
@@ -147,7 +147,7 @@ class AutoLearner(ABC):
147
147
  def _non_taxonomic_re(self, data: Any, test: bool = False) -> Optional[Any]:
148
148
  pass
149
149
 
150
- def tasks_data_former(self, data: Any, task: str, test: bool = False) -> Any:
150
+ def tasks_data_former(self, data: Any, task: str, test: bool = False) -> List[str | Dict[str, str]]:
151
151
  formatted_data = []
152
152
  if task == "term-typing":
153
153
  for typing in data.term_typings:
@@ -173,7 +173,7 @@ class AutoLearner(ABC):
173
173
  formatted_data = {"types": non_taxonomic_types, "relations": non_taxonomic_res}
174
174
  return formatted_data
175
175
 
176
- def tasks_ground_truth_former(self, data: Any, task: str) -> Any:
176
+ def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, str]]:
177
177
  formatted_data = []
178
178
  if task == "term-typing":
179
179
  for typing in data.term_typings:
@@ -350,7 +350,7 @@ class AutoRetriever(ABC):
350
350
  self.documents = inputs
351
351
  self.embeddings = self.embedding_model.encode(inputs, convert_to_tensor=True)
352
352
 
353
- def retrieve(self, query: List[str], top_k: int = 5) -> List[List[str]]:
353
+ def retrieve(self, query: List[str], top_k: int = 5, batch_size: int = -1) -> List[List[str]]:
354
354
  """
355
355
  Retrieve the top-k most similar examples for each query in a list of queries.
356
356
 
@@ -363,33 +363,37 @@ class AutoRetriever(ABC):
363
363
  """
364
364
  if self.embeddings is None:
365
365
  raise RuntimeError("Retriever model must index documents before prediction.")
366
-
367
- # Encode all queries at once
368
366
  query_embeddings = self.embedding_model.encode(query, convert_to_tensor=True) # shape: [num_queries, dim]
369
-
370
367
  if query_embeddings.shape[-1] != self.embeddings.shape[-1]:
371
368
  raise ValueError(
372
369
  f"Embedding dimension mismatch: query embedding dim={query_embeddings.shape[-1]}, "
373
370
  f"document embedding dim={self.embeddings.shape[-1]}"
374
371
  )
375
-
376
- # Normalize embeddings for cosine similarity
377
- query_norm = F.normalize(query_embeddings, p=2, dim=1)
378
372
  doc_norm = F.normalize(self.embeddings, p=2, dim=1)
373
+ if batch_size == -1:
374
+ results = self._retrieve(query_embeddings=query_embeddings, doc_norm=doc_norm, top_k=top_k)
375
+ else:
376
+ results = self._batch_retrieve(query_embeddings=query_embeddings, doc_norm=doc_norm, top_k=top_k, batch_size=batch_size)
377
+ return results
379
378
 
380
- # Compute cosine similarity: [num_queries, num_docs]
381
- similarity_matrix = torch.matmul(query_norm, doc_norm.T)
382
-
383
- # Get top-k indices for each query
384
- top_k = min(top_k, len(self.documents))
385
- topk_similarities, topk_indices = torch.topk(similarity_matrix, k=top_k, dim=1)
386
379
 
387
- # Retrieve documents for each query
380
+ def _retrieve(self, query_embeddings, doc_norm, top_k: int = 5) -> List[List[str]]:
381
+ query_norm = F.normalize(query_embeddings, p=2, dim=1)
382
+ similarity_matrix = torch.matmul(query_norm, doc_norm.T)
383
+ current_top_k = min(top_k, len(self.documents))
384
+ topk_similarities, topk_indices = torch.topk(similarity_matrix, k=current_top_k, dim=1)
388
385
  results = [[self.documents[i] for i in indices] for indices in topk_indices]
389
-
390
386
  return results
391
387
 
392
388
 
389
+ def _batch_retrieve(self, query_embeddings, doc_norm, top_k: int = 5, batch_size: int = 1024) -> List[List[str]]:
390
+ results = []
391
+ for i in range(0, query_embeddings.size(0), batch_size):
392
+ batch_queries = query_embeddings[i:i + batch_size]
393
+ batch_results = self._retrieve(batch_queries, doc_norm, top_k=top_k)
394
+ results.extend(batch_results)
395
+ return results
396
+
393
397
  class AutoPrompt(ABC):
394
398
  """
395
399
  Abstract base class for prompt formatting components.
@@ -11,13 +11,12 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
15
- from typing import Dict
14
+ from typing import List, Dict, Tuple, Set
16
15
 
17
16
  SYMMETRIC_RELATIONS = {"equivalentclass", "sameas", "disjointwith"}
18
17
 
19
- def text2onto_metrics(y_true, y_pred, similarity_threshold: float = 0.8) -> Dict:
20
- def jaccard_similarity(a, b):
18
+ def text2onto_metrics(y_true: List[str], y_pred: List[str], similarity_threshold: float = 0.8) -> Dict[str, float | int]:
19
+ def jaccard_similarity(a: str, b: str) -> float:
21
20
  set_a = set(a.lower().split())
22
21
  set_b = set(b.lower().split())
23
22
  if not set_a and not set_b:
@@ -46,10 +45,13 @@ def text2onto_metrics(y_true, y_pred, similarity_threshold: float = 0.8) -> Dict
46
45
  return {
47
46
  "f1_score": f1_score,
48
47
  "precision": precision,
49
- "recall": recall
48
+ "recall": recall,
49
+ "total_correct": total_correct,
50
+ "total_predicted": total_predicted,
51
+ "total_ground_truth": total_ground_truth
50
52
  }
51
53
 
52
- def term_typing_metrics(y_true, y_pred) -> Dict:
54
+ def term_typing_metrics(y_true: List[Dict[str, List[str]]], y_pred: List[Dict[str, List[str]]]) -> Dict[str, float | int]:
53
55
  """
54
56
  Compute precision, recall, and F1-score for term typing
55
57
  using (term, type) pair-level matching instead of ID-based lookups.
@@ -77,13 +79,17 @@ def term_typing_metrics(y_true, y_pred) -> Dict:
77
79
  precision = total_correct / total_predicted if total_predicted > 0 else 0.0
78
80
  recall = total_correct / total_ground_truth if total_ground_truth > 0 else 0.0
79
81
  f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
82
+
80
83
  return {
81
84
  "f1_score": f1_score,
82
85
  "precision": precision,
83
- "recall": recall
86
+ "recall": recall,
87
+ "total_correct": total_correct,
88
+ "total_predicted": total_predicted,
89
+ "total_ground_truth": total_ground_truth
84
90
  }
85
91
 
86
- def taxonomy_discovery_metrics(y_true, y_pred) -> Dict:
92
+ def taxonomy_discovery_metrics(y_true: List[Dict[str, str]], y_pred: List[Dict[str, str]]) -> Dict[str, float | int]:
87
93
  total_predicted = len(y_pred)
88
94
  total_ground_truth = len(y_true)
89
95
  # Convert ground truth and predictions to sets of tuples for easy comparison
@@ -102,18 +108,22 @@ def taxonomy_discovery_metrics(y_true, y_pred) -> Dict:
102
108
  return {
103
109
  "f1_score": f1_score,
104
110
  "precision": precision,
105
- "recall": recall
111
+ "recall": recall,
112
+ "total_correct": total_correct,
113
+ "total_predicted": total_predicted,
114
+ "total_ground_truth": total_ground_truth
106
115
  }
107
116
 
108
- def non_taxonomic_re_metrics(y_true, y_pred) -> Dict:
109
- def normalize_triple(item):
117
+
118
+ def non_taxonomic_re_metrics(y_true: List[Dict[str, str]], y_pred: List[Dict[str, str]]) -> Dict[str, float | int]:
119
+ def normalize_triple(item: Dict[str, str]) -> Tuple[str, str, str]:
110
120
  return (
111
121
  item["head"].strip().lower(),
112
122
  item["relation"].strip().lower(),
113
123
  item["tail"].strip().lower()
114
124
  )
115
125
 
116
- def expand_symmetric(triples):
126
+ def expand_symmetric(triples: Set[Tuple[str, str, str]]) -> Set[Tuple[str, str, str]]:
117
127
  expanded = set()
118
128
  for h, r, t in triples:
119
129
  expanded.add((h, r, t))
@@ -136,5 +146,8 @@ def non_taxonomic_re_metrics(y_true, y_pred) -> Dict:
136
146
  return {
137
147
  "f1_score": f1_score,
138
148
  "precision": precision,
139
- "recall": recall
149
+ "recall": recall,
150
+ "total_correct": total_correct,
151
+ "total_predicted": total_predicted,
152
+ "total_ground_truth": total_ground_truth
140
153
  }
@@ -17,12 +17,12 @@ from typing import Any, Optional
17
17
  import warnings
18
18
 
19
19
  class AutoRetrieverLearner(AutoLearner):
20
- def __init__(self, base_retriever: Any = AutoRetriever(), top_k: int = 5):
20
+ def __init__(self, base_retriever: Any = AutoRetriever(), top_k: int = 5, batch_size: int = -1):
21
21
  super().__init__()
22
22
  self.retriever = base_retriever
23
23
  self.top_k = top_k
24
24
  self._is_term_typing_fit = False
25
- self._is_taxonomy_discovery_fit = False
25
+ self._batch_size = batch_size
26
26
 
27
27
  def load(self, model_id: str = "sentence-transformers/all-MiniLM-L6-v2"):
28
28
  self.retriever.load(model_id=model_id)
@@ -35,7 +35,7 @@ class AutoRetrieverLearner(AutoLearner):
35
35
 
36
36
  def _retriever_predict(self, data:Any, top_k: int) -> Any:
37
37
  if isinstance(data, list):
38
- return self.retriever.retrieve(query=data, top_k=top_k)
38
+ return self.retriever.retrieve(query=data, top_k=top_k, batch_size=self._batch_size)
39
39
  if isinstance(data, str):
40
40
  return self.retriever.retrieve(query=[data], top_k=top_k)
41
41
  raise TypeError(f"Unsupported data type {type(data)}. You should pass a List[str] or a str.")
@@ -63,9 +63,9 @@ class AutoRetrieverLearner(AutoLearner):
63
63
  if test:
64
64
  self._retriever_fit(data=data)
65
65
  candidates_lst = self._retriever_predict(data=data, top_k=self.top_k + 1)
66
- taxonomic_pairs = [{"parent": query, "child": candidate}
66
+ taxonomic_pairs = [{"parent": candidate, "child": query}
67
67
  for query, candidates in zip(data, candidates_lst)
68
- for candidate in candidates if candidate != query]
68
+ for candidate in candidates if candidate.lower() != query.lower()]
69
69
  return taxonomic_pairs
70
70
  else:
71
71
  warnings.warn("No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well.")
@@ -1,8 +1,9 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: OntoLearner
3
- Version: 1.4.4
3
+ Version: 1.4.6
4
4
  Summary: OntoLearner: A Modular Python Library for Ontology Learning with LLMs.
5
5
  License: MIT
6
+ License-File: LICENSE
6
7
  Author: Hamed Babaei Giglou
7
8
  Author-email: hamedbabaeigiglou@gmail.com
8
9
  Requires-Python: >=3.10,<3.14.0
@@ -1,9 +1,9 @@
1
- ontolearner/VERSION,sha256=0D2LJotBcTKKwkWM6MqC1pof3jhGnU5PquT4-o2KXjU,6
1
+ ontolearner/VERSION,sha256=Vj1cgMVOd-L4RU7NXBQ4j5qzr3ftxztvRzatpkUnlSw,6
2
2
  ontolearner/__init__.py,sha256=E4yukFv2PV4uyztTPDWljCySY9AVDcDDzabuvxfabYE,1889
3
3
  ontolearner/_learner.py,sha256=2CRQvpsz8akIOdxTs2-KLJ-MssULrjpK-QDD3QXUJXI,5297
4
4
  ontolearner/_ontology.py,sha256=W1mp195SImqLKwaj4ueEaBWuLJg2jUdx1JT20Ds3fmQ,6950
5
5
  ontolearner/base/__init__.py,sha256=5pf-ltxzGp32xhEcPdbtm11wXJrYJMUeWG-mbcAYD8Q,705
6
- ontolearner/base/learner.py,sha256=DVWp7OHlhTYU3Es7Q6CWCOeL7Y5LTbjWilTri_DNExs,17897
6
+ ontolearner/base/learner.py,sha256=J9-Oi2P_UA5Jdbh8muBN0VgH8HGi1uyhEi2LZmCv_rk,18543
7
7
  ontolearner/base/ontology.py,sha256=JbMJ1-WUyHWQiNJL-DeaqcriUimLdqN3_ESROgqOPTQ,24772
8
8
  ontolearner/base/text2onto.py,sha256=iUXYZoqnwgebQuQzM-XSGTVRfHLlhjUK_z5XUvhRICc,5388
9
9
  ontolearner/data_structure/__init__.py,sha256=1HiKvk8FKjhYeI92RHnJXxyQbUJBi3JFytjQjthsY_s,599
@@ -11,13 +11,13 @@ ontolearner/data_structure/data.py,sha256=jUUDfqsOZcEqIR83SRboiKibPdA_JquI1uOEiQ
11
11
  ontolearner/data_structure/metric.py,sha256=4QKkZ5L1YK6hDTU-N5Z9I9Ha99DVHmGfYxK7N2qdhfc,7589
12
12
  ontolearner/evaluation/__init__.py,sha256=4BZr3BUXjQDTj4Aqlqy4THa80lZPsMuh1EBTCyi9Wig,842
13
13
  ontolearner/evaluation/evaluate.py,sha256=NYCVcmPqpyIxYZrMAim37gL-erdh698RD3t3eNTTgZc,1163
14
- ontolearner/evaluation/metrics.py,sha256=jk-80kQZfWldYV9Lzhq3lZvWE8YT5ywqtzhIfmTm664,5378
14
+ ontolearner/evaluation/metrics.py,sha256=3Aw6ycJ3_Q6xfj4tMBJP6QcexUei0G16H0ZQWt87aRU,6286
15
15
  ontolearner/learner/__init__.py,sha256=ZS816XCPb2K7azTlK2032A6ozZNoijlPLDOwcgu3-8g,745
16
16
  ontolearner/learner/label_mapper.py,sha256=-XW8MHafm4ix3e9u-RRwDePJ71D804DNuKzdf1zudtk,3789
17
17
  ontolearner/learner/llm.py,sha256=bwCoeR7z3YgYrkKyjDM-MRHZAuDzpUt8f-A0bDUbtGM,7151
18
18
  ontolearner/learner/prompt.py,sha256=0ckH7xphIDKczPe7G-rwiOxFGZ7RsLnpPlNW92b-31U,1574
19
19
  ontolearner/learner/rag.py,sha256=eysB2RvcWkVo53s8-kSbZtJv904YVTmdtxplM4ukUKM,4283
20
- ontolearner/learner/retriever.py,sha256=FIsvutDXvrr9N6AMu35TNJHdiQGbmRQ4TTGfRRdHdYo,4931
20
+ ontolearner/learner/retriever.py,sha256=GDXr6l0m_prxnctxQzBpm75xL4jW2Q4b91iyePFcDAs,4988
21
21
  ontolearner/ontology/__init__.py,sha256=F9Ta1qCX9mOxIK5CPRypEoglQNkpJ6SJpqziz73xKQE,1328
22
22
  ontolearner/ontology/agriculture.py,sha256=ZaXHNEFjbtsMH8M7HQ8ypnfJS4TUQy_as16fwv-kOKA,5903
23
23
  ontolearner/ontology/arts_humanities.py,sha256=K4ceDJL6PfIfSJZ86uQUkUXOVoiERG6ItgvVE2lhLKk,3996
@@ -53,7 +53,7 @@ ontolearner/tools/visualizer.py,sha256=cwijl4yYaS1SCLM5wbvRTEcbQj9Bjo4fHzZR6q6o8
53
53
  ontolearner/utils/__init__.py,sha256=pSEyU3dlPMADBqygqaaid44RdWf0Lo3Fvz-K_rQ7_Bw,733
54
54
  ontolearner/utils/io.py,sha256=3DqGK2p7c0onKi0Xxs16WB08uHfHUId3bW0dDKwyS0g,2110
55
55
  ontolearner/utils/train_test_split.py,sha256=Zlm42eT6QGWwlySyomCPIiTGmGqeN_h4z4xBY2EAOR8,11530
56
- ontolearner-1.4.4.dist-info/LICENSE,sha256=krXMLuMKgzX-UgaufgfJdm9ojIloZot7ZdvJUnNxl4I,1067
57
- ontolearner-1.4.4.dist-info/METADATA,sha256=Uf9twY6zgfxNZbgmtCqIhYJ1nnzzwe-TzC5_bztEV_U,13999
58
- ontolearner-1.4.4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
59
- ontolearner-1.4.4.dist-info/RECORD,,
56
+ ontolearner-1.4.6.dist-info/METADATA,sha256=S756f3Kes6TKDwU59ft3TU3GSkffQh6sjptJswr5orw,14021
57
+ ontolearner-1.4.6.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
58
+ ontolearner-1.4.6.dist-info/licenses/LICENSE,sha256=krXMLuMKgzX-UgaufgfJdm9ojIloZot7ZdvJUnNxl4I,1067
59
+ ontolearner-1.4.6.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.1.3
2
+ Generator: poetry-core 2.2.1
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any