OntoLearner 1.4.4-py3-none-any.whl → 1.4.5-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in their respective public registries.
ontolearner/VERSION CHANGED
@@ -1 +1 @@
- 1.4.4
+ 1.4.5
ontolearner/base/learner.py CHANGED
@@ -350,7 +350,7 @@ class AutoRetriever(ABC):
  self.documents = inputs
  self.embeddings = self.embedding_model.encode(inputs, convert_to_tensor=True)

- def retrieve(self, query: List[str], top_k: int = 5) -> List[List[str]]:
+ def retrieve(self, query: List[str], top_k: int = 5, batch_size: int = -1) -> List[List[str]]:
  """
  Retrieve the top-k most similar examples for each query in a list of queries.

@@ -363,33 +363,37 @@ class AutoRetriever(ABC):
  """
  if self.embeddings is None:
  raise RuntimeError("Retriever model must index documents before prediction.")
-
- # Encode all queries at once
  query_embeddings = self.embedding_model.encode(query, convert_to_tensor=True) # shape: [num_queries, dim]
-
  if query_embeddings.shape[-1] != self.embeddings.shape[-1]:
  raise ValueError(
  f"Embedding dimension mismatch: query embedding dim={query_embeddings.shape[-1]}, "
  f"document embedding dim={self.embeddings.shape[-1]}"
  )
-
- # Normalize embeddings for cosine similarity
- query_norm = F.normalize(query_embeddings, p=2, dim=1)
  doc_norm = F.normalize(self.embeddings, p=2, dim=1)
+ if batch_size == -1:
+ results = self._retrieve(query_embeddings=query_embeddings, doc_norm=doc_norm, top_k=top_k)
+ else:
+ results = self._batch_retrieve(query_embeddings=query_embeddings, doc_norm=doc_norm, top_k=top_k, batch_size=batch_size)
+ return results

- # Compute cosine similarity: [num_queries, num_docs]
- similarity_matrix = torch.matmul(query_norm, doc_norm.T)
-
- # Get top-k indices for each query
- top_k = min(top_k, len(self.documents))
- topk_similarities, topk_indices = torch.topk(similarity_matrix, k=top_k, dim=1)

- # Retrieve documents for each query
+ def _retrieve(self, query_embeddings, doc_norm, top_k: int = 5) -> List[List[str]]:
+ query_norm = F.normalize(query_embeddings, p=2, dim=1)
+ similarity_matrix = torch.matmul(query_norm, doc_norm.T)
+ current_top_k = min(top_k, len(self.documents))
+ topk_similarities, topk_indices = torch.topk(similarity_matrix, k=current_top_k, dim=1)
  results = [[self.documents[i] for i in indices] for indices in topk_indices]
-
  return results


+ def _batch_retrieve(self, query_embeddings, doc_norm, top_k: int = 5, batch_size: int = 1024) -> List[List[str]]:
+ results = []
+ for i in range(0, query_embeddings.size(0), batch_size):
+ batch_queries = query_embeddings[i:i + batch_size]
+ batch_results = self._retrieve(batch_queries, doc_norm, top_k=top_k)
+ results.extend(batch_results)
+ return results
+
  class AutoPrompt(ABC):
  """
  Abstract base class for prompt formatting components.
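
In 1.4.5, AutoRetriever.retrieve() splits into two paths: with the default batch_size=-1 it scores all queries in a single similarity matrix (the pre-1.4.5 behaviour, now factored into _retrieve), while a positive batch_size routes through _batch_retrieve, which slices the query embeddings into chunks and reuses _retrieve per chunk, so peak memory scales with batch_size × num_documents instead of num_queries × num_documents. The following standalone sketch reproduces the same pattern outside the library; the function name batched_top_k and its argument names are illustrative, and only torch is assumed.

import torch
import torch.nn.functional as F

def batched_top_k(query_embeddings: torch.Tensor,
                  doc_embeddings: torch.Tensor,
                  documents: list,
                  top_k: int = 5,
                  batch_size: int = 1024) -> list:
    # Normalize documents once; cosine similarity then reduces to a matrix product.
    doc_norm = F.normalize(doc_embeddings, p=2, dim=1)
    k = min(top_k, len(documents))
    results = []
    # Score queries chunk by chunk so each similarity matrix is [batch_size, num_docs]
    # rather than [num_queries, num_docs].
    for start in range(0, query_embeddings.size(0), batch_size):
        chunk = F.normalize(query_embeddings[start:start + batch_size], p=2, dim=1)
        similarity = chunk @ doc_norm.T
        _, indices = torch.topk(similarity, k=k, dim=1)
        results.extend([[documents[i] for i in row] for row in indices])
    return results

Because -1 remains the default, existing callers see no behaviour change; chunking only applies when a positive batch_size is passed.
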
ontolearner/learner/retriever.py CHANGED
@@ -17,12 +17,13 @@ from typing import Any, Optional
  import warnings

  class AutoRetrieverLearner(AutoLearner):
- def __init__(self, base_retriever: Any = AutoRetriever(), top_k: int = 5):
+ def __init__(self, base_retriever: Any = AutoRetriever(), top_k: int = 5, batch_size: int = -1):
  super().__init__()
  self.retriever = base_retriever
  self.top_k = top_k
  self._is_term_typing_fit = False
  self._is_taxonomy_discovery_fit = False
+ self._batch_size = batch_size

  def load(self, model_id: str = "sentence-transformers/all-MiniLM-L6-v2"):
  self.retriever.load(model_id=model_id)
@@ -35,7 +36,7 @@ class AutoRetrieverLearner(AutoLearner):

  def _retriever_predict(self, data:Any, top_k: int) -> Any:
  if isinstance(data, list):
- return self.retriever.retrieve(query=data, top_k=top_k)
+ return self.retriever.retrieve(query=data, top_k=top_k, batch_size=self._batch_size)
  if isinstance(data, str):
  return self.retriever.retrieve(query=[data], top_k=top_k)
  raise TypeError(f"Unsupported data type {type(data)}. You should pass a List[str] or a str.")
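
AutoRetrieverLearner forwards the new setting: the constructor stores batch_size (default -1) and _retriever_predict passes it to retrieve() for list inputs, while single-string inputs still use the unbatched default. A minimal usage sketch follows, assuming the import paths shown in the comments (they do not appear in this diff); only the constructor arguments and the load() call are taken from the code above.

from ontolearner.base.learner import AutoRetriever               # assumed import path
from ontolearner.learner.retriever import AutoRetrieverLearner   # assumed import path

# Default: batch_size=-1 keeps the pre-1.4.5 single-pass retrieval.
learner = AutoRetrieverLearner(base_retriever=AutoRetriever(), top_k=5)
learner.load(model_id="sentence-transformers/all-MiniLM-L6-v2")

# Memory-bounded variant: list queries are retrieved in chunks of 512.
batched_learner = AutoRetrieverLearner(
    base_retriever=AutoRetriever(),
    top_k=5,
    batch_size=512,
)
batched_learner.load(model_id="sentence-transformers/all-MiniLM-L6-v2")
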
@@ -1,8 +1,9 @@
- Metadata-Version: 2.3
+ Metadata-Version: 2.4
  Name: OntoLearner
- Version: 1.4.4
+ Version: 1.4.5
  Summary: OntoLearner: A Modular Python Library for Ontology Learning with LLMs.
  License: MIT
+ License-File: LICENSE
  Author: Hamed Babaei Giglou
  Author-email: hamedbabaeigiglou@gmail.com
  Requires-Python: >=3.10,<3.14.0
@@ -1,9 +1,9 @@
- ontolearner/VERSION,sha256=0D2LJotBcTKKwkWM6MqC1pof3jhGnU5PquT4-o2KXjU,6
+ ontolearner/VERSION,sha256=9DoCksZ0VCotZ2oPT1iKf_vhy6wlCQ3zRLEv8SefeCA,6
  ontolearner/__init__.py,sha256=E4yukFv2PV4uyztTPDWljCySY9AVDcDDzabuvxfabYE,1889
  ontolearner/_learner.py,sha256=2CRQvpsz8akIOdxTs2-KLJ-MssULrjpK-QDD3QXUJXI,5297
  ontolearner/_ontology.py,sha256=W1mp195SImqLKwaj4ueEaBWuLJg2jUdx1JT20Ds3fmQ,6950
  ontolearner/base/__init__.py,sha256=5pf-ltxzGp32xhEcPdbtm11wXJrYJMUeWG-mbcAYD8Q,705
- ontolearner/base/learner.py,sha256=DVWp7OHlhTYU3Es7Q6CWCOeL7Y5LTbjWilTri_DNExs,17897
+ ontolearner/base/learner.py,sha256=6iQlrfffK6RtEZKA2NvXExQEA7hVvlyDg4zWM0fkaCQ,18497
  ontolearner/base/ontology.py,sha256=JbMJ1-WUyHWQiNJL-DeaqcriUimLdqN3_ESROgqOPTQ,24772
  ontolearner/base/text2onto.py,sha256=iUXYZoqnwgebQuQzM-XSGTVRfHLlhjUK_z5XUvhRICc,5388
  ontolearner/data_structure/__init__.py,sha256=1HiKvk8FKjhYeI92RHnJXxyQbUJBi3JFytjQjthsY_s,599
@@ -17,7 +17,7 @@ ontolearner/learner/label_mapper.py,sha256=-XW8MHafm4ix3e9u-RRwDePJ71D804DNuKzdf
  ontolearner/learner/llm.py,sha256=bwCoeR7z3YgYrkKyjDM-MRHZAuDzpUt8f-A0bDUbtGM,7151
  ontolearner/learner/prompt.py,sha256=0ckH7xphIDKczPe7G-rwiOxFGZ7RsLnpPlNW92b-31U,1574
  ontolearner/learner/rag.py,sha256=eysB2RvcWkVo53s8-kSbZtJv904YVTmdtxplM4ukUKM,4283
- ontolearner/learner/retriever.py,sha256=FIsvutDXvrr9N6AMu35TNJHdiQGbmRQ4TTGfRRdHdYo,4931
+ ontolearner/learner/retriever.py,sha256=8X4fVD-OJCMLPH5Tl3nUOuXVhPLApzLAJcgIY5p9gEI,5020
  ontolearner/ontology/__init__.py,sha256=F9Ta1qCX9mOxIK5CPRypEoglQNkpJ6SJpqziz73xKQE,1328
  ontolearner/ontology/agriculture.py,sha256=ZaXHNEFjbtsMH8M7HQ8ypnfJS4TUQy_as16fwv-kOKA,5903
  ontolearner/ontology/arts_humanities.py,sha256=K4ceDJL6PfIfSJZ86uQUkUXOVoiERG6ItgvVE2lhLKk,3996
@@ -53,7 +53,7 @@ ontolearner/tools/visualizer.py,sha256=cwijl4yYaS1SCLM5wbvRTEcbQj9Bjo4fHzZR6q6o8
  ontolearner/utils/__init__.py,sha256=pSEyU3dlPMADBqygqaaid44RdWf0Lo3Fvz-K_rQ7_Bw,733
  ontolearner/utils/io.py,sha256=3DqGK2p7c0onKi0Xxs16WB08uHfHUId3bW0dDKwyS0g,2110
  ontolearner/utils/train_test_split.py,sha256=Zlm42eT6QGWwlySyomCPIiTGmGqeN_h4z4xBY2EAOR8,11530
- ontolearner-1.4.4.dist-info/LICENSE,sha256=krXMLuMKgzX-UgaufgfJdm9ojIloZot7ZdvJUnNxl4I,1067
- ontolearner-1.4.4.dist-info/METADATA,sha256=Uf9twY6zgfxNZbgmtCqIhYJ1nnzzwe-TzC5_bztEV_U,13999
- ontolearner-1.4.4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- ontolearner-1.4.4.dist-info/RECORD,,
+ ontolearner-1.4.5.dist-info/METADATA,sha256=89wwW4YjejZ9_TdDjn5JwcCQLK-Khb_DvDrBdRGFeJ8,14021
+ ontolearner-1.4.5.dist-info/WHEEL,sha256=M5asmiAlL6HEcOq52Yi5mmk9KmTVjY2RDPtO4p9DMrc,88
+ ontolearner-1.4.5.dist-info/licenses/LICENSE,sha256=krXMLuMKgzX-UgaufgfJdm9ojIloZot7ZdvJUnNxl4I,1067
+ ontolearner-1.4.5.dist-info/RECORD,,
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: poetry-core 2.1.3
+ Generator: poetry-core 2.2.0
  Root-Is-Purelib: true
  Tag: py3-none-any