OntoLearner 1.4.7__tar.gz → 1.4.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. {ontolearner-1.4.7 → ontolearner-1.4.9}/PKG-INFO +16 -12
  2. {ontolearner-1.4.7 → ontolearner-1.4.9}/README.md +11 -11
  3. ontolearner-1.4.9/ontolearner/VERSION +1 -0
  4. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/base/learner.py +15 -12
  5. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/learner/__init__.py +1 -1
  6. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/learner/label_mapper.py +1 -1
  7. ontolearner-1.4.9/ontolearner/learner/retriever/__init__.py +19 -0
  8. ontolearner-1.4.9/ontolearner/learner/retriever/crossencoder.py +129 -0
  9. ontolearner-1.4.9/ontolearner/learner/retriever/embedding.py +229 -0
  10. ontolearner-1.4.9/ontolearner/learner/retriever/learner.py +217 -0
  11. ontolearner-1.4.9/ontolearner/learner/retriever/llm_retriever.py +356 -0
  12. ontolearner-1.4.9/ontolearner/learner/retriever/ngram.py +123 -0
  13. ontolearner-1.4.9/ontolearner/learner/taxonomy_discovery/__init__.py +18 -0
  14. ontolearner-1.4.9/ontolearner/learner/taxonomy_discovery/alexbek.py +500 -0
  15. ontolearner-1.4.9/ontolearner/learner/taxonomy_discovery/rwthdbis.py +1082 -0
  16. ontolearner-1.4.9/ontolearner/learner/taxonomy_discovery/sbunlp.py +402 -0
  17. ontolearner-1.4.9/ontolearner/learner/taxonomy_discovery/skhnlp.py +1138 -0
  18. ontolearner-1.4.9/ontolearner/learner/term_typing/__init__.py +17 -0
  19. ontolearner-1.4.9/ontolearner/learner/term_typing/alexbek.py +1262 -0
  20. ontolearner-1.4.9/ontolearner/learner/term_typing/rwthdbis.py +379 -0
  21. ontolearner-1.4.9/ontolearner/learner/term_typing/sbunlp.py +478 -0
  22. ontolearner-1.4.9/ontolearner/learner/text2onto/__init__.py +16 -0
  23. ontolearner-1.4.9/ontolearner/learner/text2onto/alexbek.py +1219 -0
  24. ontolearner-1.4.9/ontolearner/learner/text2onto/sbunlp.py +598 -0
  25. {ontolearner-1.4.7 → ontolearner-1.4.9}/pyproject.toml +5 -1
  26. ontolearner-1.4.7/ontolearner/VERSION +0 -1
  27. ontolearner-1.4.7/ontolearner/learner/retriever.py +0 -101
  28. {ontolearner-1.4.7 → ontolearner-1.4.9}/LICENSE +0 -0
  29. {ontolearner-1.4.7 → ontolearner-1.4.9}/images/logo.png +0 -0
  30. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/__init__.py +0 -0
  31. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/_learner.py +0 -0
  32. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/_ontology.py +0 -0
  33. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/base/__init__.py +0 -0
  34. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/base/ontology.py +0 -0
  35. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/base/text2onto.py +0 -0
  36. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/data_structure/__init__.py +0 -0
  37. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/data_structure/data.py +0 -0
  38. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/data_structure/metric.py +0 -0
  39. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/evaluation/__init__.py +0 -0
  40. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/evaluation/evaluate.py +0 -0
  41. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/evaluation/metrics.py +0 -0
  42. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/learner/llm.py +0 -0
  43. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/learner/prompt.py +0 -0
  44. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/learner/rag.py +0 -0
  45. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/__init__.py +0 -0
  46. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/agriculture.py +0 -0
  47. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/arts_humanities.py +0 -0
  48. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/biology.py +0 -0
  49. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/chemistry.py +0 -0
  50. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/ecology_environment.py +0 -0
  51. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/education.py +0 -0
  52. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/events.py +0 -0
  53. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/finance.py +0 -0
  54. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/food_beverage.py +0 -0
  55. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/general.py +0 -0
  56. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/geography.py +0 -0
  57. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/industry.py +0 -0
  58. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/law.py +0 -0
  59. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/library_cultural_heritage.py +0 -0
  60. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/material_science_engineering.py +0 -0
  61. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/medicine.py +0 -0
  62. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/news_media.py +0 -0
  63. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/scholarly_knowledge.py +0 -0
  64. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/social_sciences.py +0 -0
  65. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/units_measurements.py +0 -0
  66. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/upper_ontologies.py +0 -0
  67. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/ontology/web.py +0 -0
  68. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/processor.py +0 -0
  69. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/text2onto/__init__.py +0 -0
  70. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/text2onto/batchifier.py +0 -0
  71. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/text2onto/general.py +0 -0
  72. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/text2onto/splitter.py +0 -0
  73. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/text2onto/synthesizer.py +0 -0
  74. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/tools/__init__.py +0 -0
  75. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/tools/analyzer.py +0 -0
  76. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/tools/visualizer.py +0 -0
  77. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/utils/__init__.py +0 -0
  78. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/utils/io.py +0 -0
  79. {ontolearner-1.4.7 → ontolearner-1.4.9}/ontolearner/utils/train_test_split.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: OntoLearner
3
- Version: 1.4.7
3
+ Version: 1.4.9
4
4
  Summary: OntoLearner: A Modular Python Library for Ontology Learning with LLMs.
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -13,8 +13,11 @@ Classifier: Programming Language :: Python :: 3.10
13
13
  Classifier: Programming Language :: Python :: 3.11
14
14
  Classifier: Programming Language :: Python :: 3.12
15
15
  Classifier: Programming Language :: Python :: 3.13
16
+ Requires-Dist: Levenshtein
16
17
  Requires-Dist: bitsandbytes (>=0.45.1,<0.46.0)
17
18
  Requires-Dist: dspy (>=2.6.14,<3.0.0)
19
+ Requires-Dist: g4f
20
+ Requires-Dist: gensim
18
21
  Requires-Dist: huggingface-hub (>=0.34.4,<0.35.0)
19
22
  Requires-Dist: matplotlib
20
23
  Requires-Dist: mistral-common[sentencepiece] (>=1.8.5,<2.0.0)
@@ -23,6 +26,7 @@ Requires-Dist: numpy
23
26
  Requires-Dist: openpyxl
24
27
  Requires-Dist: pandas
25
28
  Requires-Dist: pathlib (==1.0.1)
29
+ Requires-Dist: protobuf (<5)
26
30
  Requires-Dist: pydantic (==2.11.3)
27
31
  Requires-Dist: python-dotenv
28
32
  Requires-Dist: rdflib (==7.1.1)
@@ -77,16 +81,16 @@ Please refer to [Installation](https://ontolearner.readthedocs.io/installation.h
77
81
 
78
82
  ## 🔗 Essential Resources
79
83
 
80
- | Resource | Info |
81
- |:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
82
- | **[📚 OntoLearner Documentation](https://ontolearner.readthedocs.io/)** | OntoLearner's extensive documentation website. |
83
- | **[🤗 Datasets on Hugging Face](https://huggingface.co/collections/SciKnowOrg/ontolearner-benchmarking-6823bcd051300c210b7ef68a)** | Access curated, machine-readable ontologies. |
84
- | **Quick Tour on OntoLearner** [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DuElAyEFzd1vtqTjDEXWcc0zCbiV2Yee?usp=sharing) ``version=1.2.1`` | OntoLearner hands-on Colab tutorials. |
85
- | **[🚀 Quickstart](https://ontolearner.readthedocs.io/quickstart.html)** | Get started quickly with OntoLearner’s main features and workflow. |
86
- | **[🕸️ Learning Tasks](https://ontolearner.readthedocs.io/learning_tasks/learning_tasks.html)** | Explore supported ontology learning tasks like LLMs4OL Paradigm tasks and Text2Onto. | |
87
- | **[🧠 Learner Models](https://ontolearner.readthedocs.io/learners/llm.html)** | Browse and configure various learner models, including LLMs, Retrieval, or RAG approaches. |
88
- | **[📚 Ontologies Documentations](https://ontolearner.readthedocs.io/benchmarking/benchmark.html)** | Review benchmark ontologies and datasets used for evaluation and training. |
89
- | **[🧩 How to work with Ontologizer?](https://ontolearner.readthedocs.io/ontologizer/ontology_modularization.html)** | Learn how to modularize and preprocess ontologies using the Ontologizer module. |
84
+ | Resource | Info |
85
+ |:-----------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------|
86
+ | **[📚 OntoLearner Documentation](https://ontolearner.readthedocs.io/)** | OntoLearner's extensive documentation website. |
87
+ | **[🤗 Datasets on Hugging Face](https://huggingface.co/collections/SciKnowOrg/ontolearner-benchmarking-6823bcd051300c210b7ef68a)** | Access curated, machine-readable ontologies. |
88
+ | **[🚀 Quickstart](https://ontolearner.readthedocs.io/quickstart.html)** | Get started quickly with OntoLearner’s main features and workflow. |
89
+ | **[🕸️ Learning Tasks](https://ontolearner.readthedocs.io/learning_tasks/learning_tasks.html)**                                     | Explore supported ontology learning tasks like LLMs4OL Paradigm tasks and Text2Onto.        |
90
+ | **[🧠 Learner Models](https://ontolearner.readthedocs.io/learners/llm.html)** | Browse and configure various learner models, including LLMs, Retrieval, or RAG approaches. |
91
+ | **[📚 Ontologies Documentations](https://ontolearner.readthedocs.io/benchmarking/benchmark.html)** | Review benchmark ontologies and datasets used for evaluation and training. |
92
+ | **[🧩 How to work with Ontologizer?](https://ontolearner.readthedocs.io/ontologizer/ontology_modularization.html)** | Learn how to modularize and preprocess ontologies using the Ontologizer module. |
93
+ | **[🤗 Ontology Metrics Dashboard](https://huggingface.co/spaces/SciKnowOrg/OntoLearner-Benchmark-Metrics)** | Benchmark ontologies with their metrics and complexity scores. |
90
94
 
91
95
  ## 🚀 Quick Tour
92
96
  Get started with OntoLearner in just a few lines of code. This guide demonstrates how to initialize ontologies, load datasets, and train an LLM-assisted learner for ontology engineering tasks.
@@ -132,7 +136,7 @@ task = 'non-taxonomic-re'
132
136
  ret_learner = AutoRetrieverLearner(top_k=5)
133
137
  ret_learner.load(model_id='sentence-transformers/all-MiniLM-L6-v2')
134
138
 
135
- # 5. Fit the model to training data and do the predict
139
+ # 5. Fit the model to training data and then predict over the test data
136
140
  ret_learner.fit(train_data, task=task)
137
141
  predicts = ret_learner.predict(test_data, task=task)
138
142
 
@@ -39,16 +39,16 @@ Please refer to [Installation](https://ontolearner.readthedocs.io/installation.h
39
39
 
40
40
  ## 🔗 Essential Resources
41
41
 
42
- | Resource | Info |
43
- |:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
44
- | **[📚 OntoLearner Documentation](https://ontolearner.readthedocs.io/)** | OntoLearner's extensive documentation website. |
45
- | **[🤗 Datasets on Hugging Face](https://huggingface.co/collections/SciKnowOrg/ontolearner-benchmarking-6823bcd051300c210b7ef68a)** | Access curated, machine-readable ontologies. |
46
- | **Quick Tour on OntoLearner** [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DuElAyEFzd1vtqTjDEXWcc0zCbiV2Yee?usp=sharing) ``version=1.2.1`` | OntoLearner hands-on Colab tutorials. |
47
- | **[🚀 Quickstart](https://ontolearner.readthedocs.io/quickstart.html)** | Get started quickly with OntoLearner’s main features and workflow. |
48
- | **[🕸️ Learning Tasks](https://ontolearner.readthedocs.io/learning_tasks/learning_tasks.html)** | Explore supported ontology learning tasks like LLMs4OL Paradigm tasks and Text2Onto. | |
49
- | **[🧠 Learner Models](https://ontolearner.readthedocs.io/learners/llm.html)** | Browse and configure various learner models, including LLMs, Retrieval, or RAG approaches. |
50
- | **[📚 Ontologies Documentations](https://ontolearner.readthedocs.io/benchmarking/benchmark.html)** | Review benchmark ontologies and datasets used for evaluation and training. |
51
- | **[🧩 How to work with Ontologizer?](https://ontolearner.readthedocs.io/ontologizer/ontology_modularization.html)** | Learn how to modularize and preprocess ontologies using the Ontologizer module. |
42
+ | Resource | Info |
43
+ |:-----------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------|
44
+ | **[📚 OntoLearner Documentation](https://ontolearner.readthedocs.io/)** | OntoLearner's extensive documentation website. |
45
+ | **[🤗 Datasets on Hugging Face](https://huggingface.co/collections/SciKnowOrg/ontolearner-benchmarking-6823bcd051300c210b7ef68a)** | Access curated, machine-readable ontologies. |
46
+ | **[🚀 Quickstart](https://ontolearner.readthedocs.io/quickstart.html)** | Get started quickly with OntoLearner’s main features and workflow. |
47
+ | **[🕸️ Learning Tasks](https://ontolearner.readthedocs.io/learning_tasks/learning_tasks.html)**                                     | Explore supported ontology learning tasks like LLMs4OL Paradigm tasks and Text2Onto.        |
48
+ | **[🧠 Learner Models](https://ontolearner.readthedocs.io/learners/llm.html)** | Browse and configure various learner models, including LLMs, Retrieval, or RAG approaches. |
49
+ | **[📚 Ontologies Documentations](https://ontolearner.readthedocs.io/benchmarking/benchmark.html)** | Review benchmark ontologies and datasets used for evaluation and training. |
50
+ | **[🧩 How to work with Ontologizer?](https://ontolearner.readthedocs.io/ontologizer/ontology_modularization.html)** | Learn how to modularize and preprocess ontologies using the Ontologizer module. |
51
+ | **[🤗 Ontology Metrics Dashboard](https://huggingface.co/spaces/SciKnowOrg/OntoLearner-Benchmark-Metrics)** | Benchmark ontologies with their metrics and complexity scores. |
52
52
 
53
53
  ## 🚀 Quick Tour
54
54
  Get started with OntoLearner in just a few lines of code. This guide demonstrates how to initialize ontologies, load datasets, and train an LLM-assisted learner for ontology engineering tasks.
@@ -94,7 +94,7 @@ task = 'non-taxonomic-re'
94
94
  ret_learner = AutoRetrieverLearner(top_k=5)
95
95
  ret_learner.load(model_id='sentence-transformers/all-MiniLM-L6-v2')
96
96
 
97
- # 5. Fit the model to training data and do the predict
97
+ # 5. Fit the model to training data and then predict over the test data
98
98
  ret_learner.fit(train_data, task=task)
99
99
  predicts = ret_learner.predict(test_data, task=task)
100
100
 
@@ -0,0 +1 @@
1
+ 1.4.9
@@ -236,15 +236,21 @@ class AutoLLM(ABC):
236
236
  self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=self.token)
237
237
  self.tokenizer.pad_token = self.tokenizer.eos_token
238
238
  if self.device == "cpu":
239
- device_map = "cpu"
239
+ # device_map = "cpu"
240
+ self.model = AutoModelForCausalLM.from_pretrained(
241
+ model_id,
242
+ # device_map=device_map,
243
+ torch_dtype=torch.bfloat16,
244
+ token=self.token
245
+ )
240
246
  else:
241
247
  device_map = "balanced"
242
- self.model = AutoModelForCausalLM.from_pretrained(
243
- model_id,
244
- device_map=device_map,
245
- torch_dtype=torch.bfloat16,
246
- token=self.token
247
- )
248
+ self.model = AutoModelForCausalLM.from_pretrained(
249
+ model_id,
250
+ device_map=device_map,
251
+ torch_dtype=torch.bfloat16,
252
+ token=self.token
253
+ )
248
254
  self.label_mapper.fit()
249
255
 
250
256
  def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
@@ -290,7 +296,8 @@ class AutoLLM(ABC):
290
296
 
291
297
  # Decode only the generated part
292
298
  decoded_outputs = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_tokens]
293
-
299
+ print(decoded_outputs)
300
+ print(self.label_mapper.predict(decoded_outputs))
294
301
  # Map the decoded text to labels
295
302
  return self.label_mapper.predict(decoded_outputs)
296
303
 
@@ -301,9 +308,6 @@ class AutoRetriever(ABC):
301
308
  This class defines the interface for retrieval components used in ontology learning.
302
309
  Retrievers are responsible for finding semantically similar examples from training
303
310
  data to provide context for language models or to make direct predictions.
304
-
305
- Attributes:
306
- model: The loaded retrieval/embedding model instance.
307
311
  """
308
312
 
309
313
  def __init__(self) -> None:
@@ -313,7 +317,6 @@ class AutoRetriever(ABC):
313
317
  Sets up the basic structure with a model attribute that will be
314
318
  populated when load() is called.
315
319
  """
316
- self.model: Optional[Any] = None
317
320
  self.embedding_model = None
318
321
  self.documents = []
319
322
  self.embeddings = None
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from .llm import AutoLLMLearner, FalconLLM, MistralLLM
16
- from .retriever import AutoRetrieverLearner
16
+ from .retriever import AutoRetrieverLearner, LLMAugmentedRetrieverLearner
17
17
  from .rag import AutoRAGLearner
18
18
  from .prompt import StandardizedPrompting
19
19
  from .label_mapper import LabelMapper
@@ -85,6 +85,6 @@ class LabelMapper:
85
85
  Returns:
86
86
  List[str]: Predicted labels.
87
87
  """
88
- predictions = list(self.model.predict(X))
88
+ predictions = self.model.predict(X).tolist()
89
89
  self.validate_predicts(predictions)
90
90
  return predictions
@@ -0,0 +1,19 @@
1
+ # Copyright (c) 2025 SciKnowOrg
2
+ #
3
+ # Licensed under the MIT License (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://opensource.org/licenses/MIT
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .crossencoder import CrossEncoderRetriever
16
+ from .embedding import GloveRetriever, Word2VecRetriever
17
+ from .ngram import NgramRetriever
18
+ from .learner import AutoRetrieverLearner, LLMAugmentedRetrieverLearner
19
+ from .llm_retriever import LLMAugmenterGenerator, LLMAugmenter, LLMAugmentedRetriever
@@ -0,0 +1,129 @@
1
+ # Copyright (c) 2025 SciKnowOrg
2
+ #
3
+ # Licensed under the MIT License (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://opensource.org/licenses/MIT
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ from typing import List
16
+ from sentence_transformers import CrossEncoder, SentenceTransformer, util
17
+ from tqdm import tqdm
18
+ import numpy as np
19
+
20
+ from ...base import AutoRetriever
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class CrossEncoderRetriever(AutoRetriever):
26
+ """
27
+ A hybrid dense retriever that combines a BiEncoder for fast candidate
28
+ retrieval and a CrossEncoder for accurate reranking.
29
+
30
+ This retriever follows a two-stage retrieval process:
31
+
32
+ 1. **BiEncoder retrieval**:
33
+ Encodes all documents and queries into embeddings.
34
+ Computes approximate nearest neighbors to obtain a set of top-k candidates.
35
+
36
+ 2. **CrossEncoder reranking**:
37
+ Evaluates each (query, document) pair for semantic relevance.
38
+ Reranks the initial candidates and outputs the final top results.
39
+
40
+ This provides an efficient and accurate alternative to pure CrossEncoder
41
+ or pure BiEncoder approaches.
42
+ """
43
+
44
+ def __init__(self, bi_encoder_model_id: str = None) -> None:
45
+ """
46
+ Initialize the retriever.
47
+
48
+ Args:
49
+ bi_encoder_model_id (str, optional):
50
+ Model ID for the BiEncoder used in the first-stage retrieval.
51
+ If not provided, the CrossEncoder model_id passed to `load()`
52
+ will also be used as the BiEncoder.
53
+ """
54
+ super().__init__()
55
+ self.bi_encoder_model_id = bi_encoder_model_id
56
+
57
+ def load(self, model_id: str):
58
+ """
59
+ Load both the BiEncoder and CrossEncoder models.
60
+
61
+ Args:
62
+ model_id (str):
63
+ Model ID for the CrossEncoder (reranking model). If no explicit
64
+ BiEncoder ID was given at initialization, this ID is also used
65
+ for the BiEncoder.
66
+
67
+ Notes:
68
+ - BiEncoder is used for fast vector similarity search.
69
+ - CrossEncoder is used for slow but accurate reranking.
70
+ """
71
+ if not self.bi_encoder_model_id:
72
+ self.bi_encoder_model_id = model_id
73
+ self.bi_encoder = SentenceTransformer(self.bi_encoder_model_id)
74
+ self.cross_encoder = CrossEncoder(model_id)
75
+
76
+ def index(self, inputs: List[str]):
77
+ """
78
+ Pre-encode all documents using the BiEncoder to support efficient
79
+ semantic search.
80
+
81
+ Args:
82
+ inputs (List[str]):
83
+ List of documents to index.
84
+
85
+ Stores:
86
+ - `self.documents`: Raw input documents.
87
+ - `self.document_embeddings`: Tensor of BiEncoder embeddings.
88
+ """
89
+ self.documents = inputs
90
+ self.document_embeddings = self.bi_encoder.encode(inputs, convert_to_tensor=True, show_progress_bar=True)
91
+
92
+ def retrieve(self, query: List[str], top_k: int = 5, rerank_k: int = 100, batch_size: int = 32) -> List[List[str]]:
93
+ """
94
+ Retrieve top-k most relevant documents per query using a two-stage process.
95
+
96
+ Stage 1: Retrieve top `rerank_k` documents using BiEncoder embeddings.
97
+ Stage 2: Rerank those candidates using the CrossEncoder, returning `top_k`.
98
+
99
+ Args:
100
+ query (List[str]):
101
+ List of user query strings.
102
+ top_k (int):
103
+ Number of final documents to return after reranking.
104
+ rerank_k (int):
105
+ Number of candidates to retrieve before reranking.
106
+ batch_size (int):
107
+ Batch size for CrossEncoder inference.
108
+
109
+ Returns:
110
+ List[List[str]]:
111
+ For each query, a list of top-k reranked documents.
112
+ """
113
+ results = []
114
+ # Step 1: Encode queries with the BiEncoder
115
+ query_embeddings = self.bi_encoder.encode(
116
+ query, convert_to_tensor=True, show_progress_bar=True
117
+ )
118
+ # Step 2: Retrieve candidate documents
119
+ hits_batch = util.semantic_search(query_embeddings, self.document_embeddings, top_k=rerank_k)
120
+ # Step 3: Rerank using CrossEncoder
121
+ for i, hits in enumerate(tqdm(hits_batch, desc="Reranking")):
122
+ candidates = [self.documents[hit["corpus_id"]] for hit in hits]
123
+ pairs = [(query[i], doc) for doc in candidates]
124
+ scores = self.cross_encoder.predict(pairs, batch_size=batch_size, show_progress_bar=False)
125
+ reranked_idx = np.argsort(scores)[::-1][:top_k]
126
+ top_docs = [candidates[j] for j in reranked_idx]
127
+ results.append(top_docs)
128
+
129
+ return results
@@ -0,0 +1,229 @@
1
+ # Copyright (c) 2025 SciKnowOrg
2
+ #
3
+ # Licensed under the MIT License (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://opensource.org/licenses/MIT
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import numpy as np
18
+
19
+ from tqdm import tqdm
20
+ from typing import List, Optional
21
+ from sklearn.metrics.pairwise import cosine_similarity
22
+ from gensim.models import KeyedVectors
23
+ from gensim.utils import simple_preprocess
24
+
25
+ from ...base import AutoRetriever
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ class Word2VecRetriever(AutoRetriever):
31
+ """
32
+ Retriever that encodes each document by averaging its Word2Vec-style
33
+ word embeddings. Retrieval is performed by cosine similarity between
34
+ averaged document vectors and averaged query vectors.
35
+ """
36
+
37
+ def __init__(self) -> None:
38
+ """
39
+ Initialize an empty Word2VecRetriever. The model must be loaded using
40
+ :meth:`load` before indexing or retrieval.
41
+ """
42
+ super().__init__()
43
+ self.embedding_model: Optional[KeyedVectors] = None
44
+ self.documents: List[str] = []
45
+ self.embeddings: Optional[torch.Tensor] = None
46
+
47
+ def load(self, model_id: str) -> None:
48
+ """
49
+ Load a pre-trained Word2Vec KeyedVectors model.
50
+
51
+ Args:
52
+ model_id (str):
53
+ Path to a Word2Vec `.bin` or `.txt` vector file.
54
+ """
55
+ self.embedding_model = KeyedVectors.load_word2vec_format(model_id, binary=True)
56
+
57
+ def _encode_text(self, text: str) -> np.ndarray:
58
+ """
59
+ Encode text by averaging embeddings for all in-vocabulary words.
60
+
61
+ Args:
62
+ text (str): Input text string.
63
+
64
+ Returns:
65
+ np.ndarray: Averaged embedding vector. If no word is in the vocabulary,
66
+ a zero vector of appropriate dimensionality is returned.
67
+ """
68
+ if self.embedding_model is None:
69
+ raise RuntimeError("Word2Vec model must be loaded before encoding.")
70
+
71
+ words = simple_preprocess(text)
72
+ valid_vectors = [self.embedding_model[word] for word in words if word in self.embedding_model]
73
+
74
+ if not valid_vectors:
75
+ return np.zeros(self.embedding_model.vector_size)
76
+
77
+ return np.mean(valid_vectors, axis=0)
78
+
79
+ def index(self, inputs: List[str]) -> None:
80
+ """
81
+ Encode and index a list of documents.
82
+
83
+ Args:
84
+ inputs (List[str]): Documents to index.
85
+
86
+ Stores:
87
+ - self.documents: The input documents.
88
+ - self.embeddings: L2-normalized document embeddings.
89
+ """
90
+ self.documents = inputs
91
+ embeddings = [self._encode_text(doc) for doc in tqdm(inputs)]
92
+ self.embeddings = F.normalize(torch.tensor(np.stack(embeddings)), p=2, dim=1)
93
+
94
+ def retrieve(self, query: List[str], top_k: int = 5, batch_size: int = -1) -> List[List[str]]:
95
+ """
96
+ Retrieve the top-k most similar documents for each query.
97
+
98
+ Args:
99
+ query (List[str]): Query texts.
100
+ top_k (int): Number of results to return per query.
101
+ batch_size (int): Batch size for processing queries. -1 means all at once.
102
+
103
+ Returns:
104
+ List[List[str]]: One list per query containing top-k matching documents.
105
+ """
106
+ if self.embeddings is None:
107
+ raise RuntimeError("Documents must be indexed before retrieval.")
108
+
109
+ query_vec = [self._encode_text(q) for q in query]
110
+ query_vec = F.normalize(torch.tensor(np.stack(query_vec)), p=2, dim=1)
111
+
112
+ if batch_size == -1:
113
+ batch_size = len(query)
114
+
115
+ results = []
116
+ for i in tqdm(range(0, len(query), batch_size)):
117
+ q_batch = query_vec[i:i + batch_size]
118
+ sim = cosine_similarity(q_batch, self.embeddings)
119
+
120
+ topk_idx = np.argsort(sim, axis=1)[:, ::-1][:, :top_k]
121
+
122
+ for row in topk_idx:
123
+ results.append([self.documents[j] for j in row])
124
+
125
+ return results
126
+
127
+
128
+ class GloveRetriever(AutoRetriever):
129
+ """
130
+ Retriever that uses GloVe embedding vectors. Each document is encoded
131
+ by averaging the embeddings of all words that exist in the GloVe vocabulary.
132
+ """
133
+
134
+ def __init__(self) -> None:
135
+ """
136
+ Initialize an empty GloveRetriever. Model must be loaded before use.
137
+ """
138
+ super().__init__()
139
+ self.embedding_model: Optional[dict] = None
140
+ self.documents: List[str] = []
141
+ self.embeddings: Optional[torch.Tensor] = None
142
+
143
+ def load(self, model_id: str) -> None:
144
+ """
145
+ Load GloVe embeddings from a text file.
146
+
147
+ Args:
148
+ model_id (str):
149
+ Path to GloVe `.txt` file, e.g. `glove.6B.300d.txt`.
150
+ """
151
+ logger.info(f"Loading GloVe embeddings from {model_id} ...")
152
+ self.embedding_model = {}
153
+
154
+ with open(model_id, "r", encoding="utf8") as f:
155
+ for line in f:
156
+ values = line.split()
157
+ word = values[0]
158
+ vec = [float(v) for v in values[1:]]
159
+ self.embedding_model[word] = vec
160
+
161
+ logger.info(f"Loaded {len(self.embedding_model)} GloVe words.")
162
+
163
+ def _encode_text(self, text: str) -> np.ndarray:
164
+ """
165
+ Encode text by averaging GloVe embeddings.
166
+
167
+ Args:
168
+ text (str): Input text.
169
+
170
+ Returns:
171
+ np.ndarray: Averaged embedding vector. Returns zero vector if no words match.
172
+ """
173
+ if self.embedding_model is None:
174
+ raise RuntimeError("GloVe model must be loaded before encoding.")
175
+
176
+ words = text.lower().split()
177
+ vecs = [self.embedding_model[w] for w in words if w in self.embedding_model]
178
+
179
+ if not vecs:
180
+ dim = len(next(iter(self.embedding_model.values())))
181
+ return np.zeros(dim)
182
+
183
+ return np.mean(vecs, axis=0)
184
+
185
+ def index(self, inputs: List[str]) -> None:
186
+ """
187
+ Index a list of documents by encoding and normalizing them.
188
+
189
+ Args:
190
+ inputs (List[str]): Documents to index.
191
+ """
192
+ if self.embedding_model is None:
193
+ raise RuntimeError("You must load a GloVe model before indexing.")
194
+
195
+ self.documents = inputs
196
+ embeddings = [self._encode_text(doc) for doc in tqdm(inputs)]
197
+ self.embeddings = F.normalize(torch.tensor(np.stack(embeddings)), p=2, dim=1)
198
+
199
+ def retrieve(self, query: List[str], top_k: int = 5, batch_size: int = -1) -> List[List[str]]:
200
+ """
201
+ Retrieve top-k most similar documents.
202
+
203
+ Args:
204
+ query (List[str]): Query texts.
205
+ top_k (int): Number of results per query.
206
+ batch_size (int): Batch size for query computation.
207
+
208
+ Returns:
209
+ List[List[str]]: Each entry is a list of top-k matching documents.
210
+ """
211
+ if self.embeddings is None:
212
+ raise RuntimeError("Documents must be indexed before retrieval.")
213
+
214
+ query_vec = [self._encode_text(q) for q in query]
215
+ query_vec = F.normalize(torch.tensor(np.stack(query_vec)), p=2, dim=1)
216
+
217
+ if batch_size == -1:
218
+ batch_size = len(query)
219
+
220
+ results = []
221
+ for i in tqdm(range(0, len(query), batch_size)):
222
+ q_batch = query_vec[i:i + batch_size]
223
+ sim = cosine_similarity(q_batch, self.embeddings)
224
+ topk_idx = np.argsort(sim, axis=1)[:, ::-1][:, :top_k]
225
+
226
+ for row in topk_idx:
227
+ results.append([self.documents[j] for j in row])
228
+
229
+ return results