camel-ai 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic. Click here for more details.

Files changed (45)
  1. camel/__init__.py +1 -1
  2. camel/agents/__init__.py +2 -0
  3. camel/agents/chat_agent.py +40 -53
  4. camel/agents/knowledge_graph_agent.py +221 -0
  5. camel/configs/__init__.py +29 -0
  6. camel/configs/anthropic_config.py +73 -0
  7. camel/configs/base_config.py +22 -0
  8. camel/configs/openai_config.py +132 -0
  9. camel/embeddings/openai_embedding.py +7 -2
  10. camel/functions/__init__.py +13 -8
  11. camel/functions/open_api_function.py +380 -0
  12. camel/functions/open_api_specs/coursera/__init__.py +13 -0
  13. camel/functions/open_api_specs/coursera/openapi.yaml +82 -0
  14. camel/functions/open_api_specs/klarna/__init__.py +13 -0
  15. camel/functions/open_api_specs/klarna/openapi.yaml +87 -0
  16. camel/functions/open_api_specs/speak/__init__.py +13 -0
  17. camel/functions/open_api_specs/speak/openapi.yaml +151 -0
  18. camel/functions/openai_function.py +3 -1
  19. camel/functions/retrieval_functions.py +61 -0
  20. camel/functions/slack_functions.py +275 -0
  21. camel/models/__init__.py +2 -0
  22. camel/models/anthropic_model.py +16 -2
  23. camel/models/base_model.py +8 -2
  24. camel/models/model_factory.py +7 -3
  25. camel/models/openai_audio_models.py +251 -0
  26. camel/models/openai_model.py +12 -4
  27. camel/models/stub_model.py +5 -1
  28. camel/retrievers/__init__.py +2 -0
  29. camel/retrievers/auto_retriever.py +47 -36
  30. camel/retrievers/base.py +42 -37
  31. camel/retrievers/bm25_retriever.py +10 -19
  32. camel/retrievers/cohere_rerank_retriever.py +108 -0
  33. camel/retrievers/vector_retriever.py +43 -26
  34. camel/storages/vectordb_storages/qdrant.py +3 -1
  35. camel/toolkits/__init__.py +21 -0
  36. camel/toolkits/base.py +22 -0
  37. camel/toolkits/github_toolkit.py +245 -0
  38. camel/types/__init__.py +6 -0
  39. camel/types/enums.py +44 -3
  40. camel/utils/__init__.py +4 -2
  41. camel/utils/commons.py +97 -173
  42. {camel_ai-0.1.3.dist-info → camel_ai-0.1.4.dist-info}/METADATA +9 -3
  43. {camel_ai-0.1.3.dist-info → camel_ai-0.1.4.dist-info}/RECORD +44 -26
  44. camel/configs.py +0 -271
  45. {camel_ai-0.1.3.dist-info → camel_ai-0.1.4.dist-info}/WHEEL +0 -0
@@ -12,7 +12,7 @@
12
12
  # limitations under the License.
13
13
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
14
  from abc import ABC, abstractmethod
15
- from typing import Any, Dict, List, Union
15
+ from typing import Any, Dict, List, Optional, Union
16
16
 
17
17
  from openai import Stream
18
18
 
@@ -27,17 +27,23 @@ class BaseModelBackend(ABC):
27
27
  """
28
28
 
29
29
  def __init__(
30
- self, model_type: ModelType, model_config_dict: Dict[str, Any]
30
+ self,
31
+ model_type: ModelType,
32
+ model_config_dict: Dict[str, Any],
33
+ api_key: Optional[str] = None,
31
34
  ) -> None:
32
35
  r"""Constructor for the model backend.
33
36
 
34
37
  Args:
35
38
  model_type (ModelType): Model for which a backend is created.
36
39
  model_config_dict (Dict[str, Any]): A config dictionary.
40
+ api_key (Optional[str]): The API key for authenticating with the
41
+ LLM service.
37
42
  """
38
43
  self.model_type = model_type
39
44
 
40
45
  self.model_config_dict = model_config_dict
46
+ self._api_key = api_key
41
47
  self.check_model_config()
42
48
 
43
49
  @property
@@ -11,7 +11,7 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
- from typing import Any, Dict
14
+ from typing import Any, Dict, Optional
15
15
 
16
16
  from camel.models.anthropic_model import AnthropicModel
17
17
  from camel.models.base_model import BaseModelBackend
@@ -30,7 +30,9 @@ class ModelFactory:
30
30
 
31
31
  @staticmethod
32
32
  def create(
33
- model_type: ModelType, model_config_dict: Dict
33
+ model_type: ModelType,
34
+ model_config_dict: Dict,
35
+ api_key: Optional[str] = None,
34
36
  ) -> BaseModelBackend:
35
37
  r"""Creates an instance of `BaseModelBackend` of the specified type.
36
38
 
@@ -38,6 +40,8 @@ class ModelFactory:
38
40
  model_type (ModelType): Model for which a backend is created.
39
41
  model_config_dict (Dict): A dictionary that will be fed into
40
42
  the backend constructor.
43
+ api_key (Optional[str]): The API key for authenticating with the
44
+ LLM service.
41
45
 
42
46
  Raises:
43
47
  ValueError: If there is no backend for the model.
@@ -57,5 +61,5 @@ class ModelFactory:
57
61
  else:
58
62
  raise ValueError(f"Unknown model type `{model_type}` is input")
59
63
 
60
- inst = model_class(model_type, model_config_dict)
64
+ inst = model_class(model_type, model_config_dict, api_key)
61
65
  return inst
@@ -0,0 +1,251 @@
1
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ import os
15
+ from typing import Any, List, Optional, Union
16
+
17
+ from openai import OpenAI, _legacy_response
18
+
19
+ from camel.types import AudioModelType, VoiceType
20
+
21
+
22
+ class OpenAIAudioModels:
23
+ r"""Provides access to OpenAI's Text-to-Speech (TTS) and Speech-to-Text
24
+ (STT) models."""
25
+
26
+ def __init__(
27
+ self,
28
+ ) -> None:
29
+ r"""Initialize an instance of OpenAI."""
30
+ url = os.environ.get('OPENAI_API_BASE_URL')
31
+ self._client = OpenAI(timeout=120, max_retries=3, base_url=url)
32
+
33
+ def text_to_speech(
34
+ self,
35
+ input: str,
36
+ model_type: AudioModelType = AudioModelType.TTS_1,
37
+ voice: VoiceType = VoiceType.ALLOY,
38
+ storage_path: Optional[str] = None,
39
+ **kwargs: Any,
40
+ ) -> Union[
41
+ List[_legacy_response.HttpxBinaryResponseContent],
42
+ _legacy_response.HttpxBinaryResponseContent,
43
+ ]:
44
+ r"""Convert text to speech using OpenAI's TTS model. This method
45
+ converts the given input text to speech using the specified model and
46
+ voice.
47
+
48
+ Args:
49
+ input (str): The text to be converted to speech.
50
+ model_type (AudioModelType, optional): The TTS model to use.
51
+ Defaults to `AudioModelType.TTS_1`.
52
+ voice (VoiceType, optional): The voice to be used for generating
53
+ speech. Defaults to `VoiceType.ALLOY`.
54
+ storage_path (str, optional): The local path to store the
55
+ generated speech file if provided, defaults to `None`.
56
+ **kwargs (Any): Extra kwargs passed to the TTS API.
57
+
58
+ Returns:
59
+ Union[List[_legacy_response.HttpxBinaryResponseContent],
60
+ _legacy_response.HttpxBinaryResponseContent]: List of response
61
+ content object from OpenAI if input characters more than 4096,
62
+ single response content if input characters less than 4096.
63
+
64
+ Raises:
65
+ Exception: If there's an error during the TTS API call.
66
+ """
67
+ try:
68
+ # Model only support at most 4096 characters one time.
69
+ max_chunk_size = 4095
70
+ audio_chunks = []
71
+ chunk_index = 0
72
+ if len(input) > max_chunk_size:
73
+ while input:
74
+ if len(input) <= max_chunk_size:
75
+ chunk = input
76
+ input = ''
77
+ else:
78
+ # Find the nearest period before the chunk size limit
79
+ while input[max_chunk_size - 1] != '.':
80
+ max_chunk_size -= 1
81
+
82
+ chunk = input[:max_chunk_size]
83
+ input = input[max_chunk_size:].lstrip()
84
+
85
+ response = self._client.audio.speech.create(
86
+ model=model_type.value,
87
+ voice=voice.value,
88
+ input=chunk,
89
+ **kwargs,
90
+ )
91
+ if storage_path:
92
+ try:
93
+ # Create a new storage path for each chunk
94
+ file_name, file_extension = os.path.splitext(
95
+ storage_path
96
+ )
97
+ new_storage_path = (
98
+ f"{file_name}_{chunk_index}{file_extension}"
99
+ )
100
+ response.write_to_file(new_storage_path)
101
+ chunk_index += 1
102
+ except Exception as e:
103
+ raise Exception(
104
+ "Error during writing the file"
105
+ ) from e
106
+
107
+ audio_chunks.append(response)
108
+ return audio_chunks
109
+
110
+ else:
111
+ response = self._client.audio.speech.create(
112
+ model=model_type.value,
113
+ voice=voice.value,
114
+ input=input,
115
+ **kwargs,
116
+ )
117
+
118
+ if storage_path:
119
+ try:
120
+ response.write_to_file(storage_path)
121
+ except Exception as e:
122
+ raise Exception("Error during write the file") from e
123
+
124
+ return response
125
+
126
+ except Exception as e:
127
+ raise Exception("Error during TTS API call") from e
128
+
129
+ def _split_audio(
130
+ self, audio_file_path: str, chunk_size_mb: int = 24
131
+ ) -> list:
132
+ r"""Split the audio file into smaller chunks. Since the Whisper API
133
+ only supports files that are less than 25 MB.
134
+
135
+ Args:
136
+ audio_file_path (str): Path to the input audio file.
137
+ chunk_size_mb (int, optional): Size of each chunk in megabytes.
138
+ Defaults to `24`.
139
+
140
+ Returns:
141
+ list: List of paths to the split audio files.
142
+ """
143
+ from pydub import AudioSegment
144
+
145
+ audio = AudioSegment.from_file(audio_file_path)
146
+ audio_format = os.path.splitext(audio_file_path)[1][1:].lower()
147
+
148
+ # Calculate chunk size in bytes
149
+ chunk_size_bytes = chunk_size_mb * 1024 * 1024
150
+
151
+ # Number of chunks needed
152
+ num_chunks = os.path.getsize(audio_file_path) // chunk_size_bytes + 1
153
+
154
+ # Create a directory to store the chunks
155
+ output_dir = os.path.splitext(audio_file_path)[0] + "_chunks"
156
+ os.makedirs(output_dir, exist_ok=True)
157
+
158
+ # Get audio chunk len in milliseconds
159
+ chunk_size_milliseconds = len(audio) // (num_chunks)
160
+
161
+ # Split the audio into chunks
162
+ split_files = []
163
+ for i in range(num_chunks):
164
+ start = i * chunk_size_milliseconds
165
+ end = (i + 1) * chunk_size_milliseconds
166
+ if i + 1 == num_chunks:
167
+ chunk = audio[start:]
168
+ else:
169
+ chunk = audio[start:end]
170
+ # Create new chunk path
171
+ chunk_path = os.path.join(output_dir, f"chunk_{i}.{audio_format}")
172
+ chunk.export(chunk_path, format=audio_format)
173
+ split_files.append(chunk_path)
174
+ return split_files
175
+
176
+ def speech_to_text(
177
+ self,
178
+ audio_file_path: str,
179
+ translate_into_english: bool = False,
180
+ **kwargs: Any,
181
+ ) -> str:
182
+ r"""Convert speech audio to text.
183
+
184
+ Args:
185
+ audio_file_path (str): The audio file path, supporting one of
186
+ these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or
187
+ webm.
188
+ translate_into_english (bool, optional): Whether to translate the
189
+ speech into English. Defaults to `False`.
190
+ **kwargs (Any): Extra keyword arguments passed to the
191
+ Speech-to-Text (STT) API.
192
+
193
+ Returns:
194
+ str: The output text.
195
+
196
+ Raises:
197
+ ValueError: If the audio file format is not supported.
198
+ Exception: If there's an error during the STT API call.
199
+ """
200
+ supported_formats = [
201
+ "flac",
202
+ "mp3",
203
+ "mp4",
204
+ "mpeg",
205
+ "mpga",
206
+ "m4a",
207
+ "ogg",
208
+ "wav",
209
+ "webm",
210
+ ]
211
+ file_format = audio_file_path.split(".")[-1].lower()
212
+
213
+ if file_format not in supported_formats:
214
+ raise ValueError(f"Unsupported audio file format: {file_format}")
215
+ try:
216
+ if os.path.getsize(audio_file_path) > 24 * 1024 * 1024:
217
+ # Split audio into chunks
218
+ audio_chunks = self._split_audio(audio_file_path)
219
+ texts = []
220
+ for chunk_path in audio_chunks:
221
+ audio_data = open(chunk_path, "rb")
222
+ if translate_into_english:
223
+ translation = self._client.audio.translations.create(
224
+ model="whisper-1", file=audio_data, **kwargs
225
+ )
226
+ texts.append(translation.text)
227
+ else:
228
+ transcription = (
229
+ self._client.audio.transcriptions.create(
230
+ model="whisper-1", file=audio_data, **kwargs
231
+ )
232
+ )
233
+ texts.append(transcription.text)
234
+ os.remove(chunk_path) # Delete temporary chunk file
235
+ return " ".join(texts)
236
+ else:
237
+ # Process the entire audio file
238
+ audio_data = open(audio_file_path, "rb")
239
+
240
+ if translate_into_english:
241
+ translation = self._client.audio.translations.create(
242
+ model="whisper-1", file=audio_data, **kwargs
243
+ )
244
+ return translation.text
245
+ else:
246
+ transcription = self._client.audio.transcriptions.create(
247
+ model="whisper-1", file=audio_data, **kwargs
248
+ )
249
+ return transcription.text
250
+ except Exception as e:
251
+ raise Exception("Error during STT API call") from e
@@ -16,7 +16,7 @@ from typing import Any, Dict, List, Optional, Union
16
16
 
17
17
  from openai import OpenAI, Stream
18
18
 
19
- from camel.configs import OPENAI_API_PARAMS_WITH_FUNCTIONS
19
+ from camel.configs import OPENAI_API_PARAMS
20
20
  from camel.messages import OpenAIMessage
21
21
  from camel.models import BaseModelBackend
22
22
  from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
@@ -27,7 +27,10 @@ class OpenAIModel(BaseModelBackend):
27
27
  r"""OpenAI API in a unified BaseModelBackend interface."""
28
28
 
29
29
  def __init__(
30
- self, model_type: ModelType, model_config_dict: Dict[str, Any]
30
+ self,
31
+ model_type: ModelType,
32
+ model_config_dict: Dict[str, Any],
33
+ api_key: Optional[str] = None,
31
34
  ) -> None:
32
35
  r"""Constructor for OpenAI backend.
33
36
 
@@ -36,10 +39,15 @@ class OpenAIModel(BaseModelBackend):
36
39
  one of GPT_* series.
37
40
  model_config_dict (Dict[str, Any]): A dictionary that will
38
41
  be fed into openai.ChatCompletion.create().
42
+ api_key (Optional[str]): The API key for authenticating with the
43
+ OpenAI service. (default: :obj:`None`)
39
44
  """
40
45
  super().__init__(model_type, model_config_dict)
41
46
  url = os.environ.get('OPENAI_API_BASE_URL', None)
42
- self._client = OpenAI(timeout=60, max_retries=3, base_url=url)
47
+ self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
48
+ self._client = OpenAI(
49
+ timeout=60, max_retries=3, base_url=url, api_key=self._api_key
50
+ )
43
51
  self._token_counter: Optional[BaseTokenCounter] = None
44
52
 
45
53
  @property
@@ -86,7 +94,7 @@ class OpenAIModel(BaseModelBackend):
86
94
  unexpected arguments to OpenAI API.
87
95
  """
88
96
  for param in self.model_config_dict:
89
- if param not in OPENAI_API_PARAMS_WITH_FUNCTIONS:
97
+ if param not in OPENAI_API_PARAMS:
90
98
  raise ValueError(
91
99
  f"Unexpected argument `{param}` is "
92
100
  "input into OpenAI model backend."
@@ -50,11 +50,15 @@ class StubModel(BaseModelBackend):
50
50
  model_type = ModelType.STUB
51
51
 
52
52
  def __init__(
53
- self, model_type: ModelType, model_config_dict: Dict[str, Any]
53
+ self,
54
+ model_type: ModelType,
55
+ model_config_dict: Dict[str, Any],
56
+ api_key: Optional[str] = None,
54
57
  ) -> None:
55
58
  r"""All arguments are unused for the dummy model."""
56
59
  super().__init__(model_type, model_config_dict)
57
60
  self._token_counter: Optional[BaseTokenCounter] = None
61
+ self._api_key = api_key
58
62
 
59
63
  @property
60
64
  def token_counter(self) -> BaseTokenCounter:
@@ -14,6 +14,7 @@
14
14
  from .auto_retriever import AutoRetriever
15
15
  from .base import BaseRetriever
16
16
  from .bm25_retriever import BM25Retriever
17
+ from .cohere_rerank_retriever import CohereRerankRetriever
17
18
  from .vector_retriever import VectorRetriever
18
19
 
19
20
  __all__ = [
@@ -21,4 +22,5 @@ __all__ = [
21
22
  'VectorRetriever',
22
23
  'AutoRetriever',
23
24
  'BM25Retriever',
25
+ 'CohereRerankRetriever',
24
26
  ]
@@ -63,7 +63,8 @@ class AutoRetriever:
63
63
  self,
64
64
  collection_name: Optional[str] = None,
65
65
  ) -> BaseVectorStorage:
66
- r"""Sets up and returns a vector storage instance with specified parameters.
66
+ r"""Sets up and returns a vector storage instance with specified
67
+ parameters.
67
68
 
68
69
  Args:
69
70
  collection_name (Optional[str]): Name of the collection in the
@@ -195,7 +196,8 @@ class AutoRetriever:
195
196
  similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
196
197
  return_detailed_info: bool = False,
197
198
  ) -> str:
198
- r"""Executes the automatic vector retriever process using vector storage.
199
+ r"""Executes the automatic vector retriever process using vector
200
+ storage.
199
201
 
200
202
  Args:
201
203
  query (str): Query string for information retriever.
@@ -233,9 +235,7 @@ class AutoRetriever:
233
235
 
234
236
  vr = VectorRetriever()
235
237
 
236
- retrieved_infos = ""
237
- retrieved_infos_text = ""
238
-
238
+ all_retrieved_info = []
239
239
  for content_input_path in content_input_paths:
240
240
  # Generate a valid collection name
241
241
  collection_name = self._collection_name_generator(
@@ -278,42 +278,53 @@ class AutoRetriever:
278
278
  # Clear the vector storage
279
279
  vector_storage_instance.clear()
280
280
  # Process and store the content to the vector storage
281
- vr.process(content_input_path, vector_storage_instance)
281
+ vr = VectorRetriever(
282
+ storage=vector_storage_instance,
283
+ similarity_threshold=similarity_threshold,
284
+ )
285
+ vr.process(content_input_path)
286
+ else:
287
+ vr = VectorRetriever(
288
+ storage=vector_storage_instance,
289
+ similarity_threshold=similarity_threshold,
290
+ )
282
291
  # Retrieve info by given query from the vector storage
283
- retrieved_info = vr.query(
284
- query, vector_storage_instance, top_k, similarity_threshold
285
- )
286
- # Reorganize the retrieved info with original query
287
- for info in retrieved_info:
288
- retrieved_infos += "\n" + str(info)
289
- retrieved_infos_text += "\n" + str(info['text'])
290
- output = (
291
- "Original Query:"
292
- + "\n"
293
- + "{"
294
- + query
295
- + "}"
296
- + "\n"
297
- + "Retrieved Context:"
298
- + retrieved_infos
299
- )
300
- output_text = (
301
- "Original Query:"
302
- + "\n"
303
- + "{"
304
- + query
305
- + "}"
306
- + "\n"
307
- + "Retrieved Context:"
308
- + retrieved_infos_text
309
- )
310
-
292
+ retrieved_info = vr.query(query, top_k)
293
+ all_retrieved_info.extend(retrieved_info)
311
294
  except Exception as e:
312
295
  raise RuntimeError(
313
296
  f"Error in auto vector retriever processing: {e!s}"
314
297
  ) from e
315
298
 
299
+ # Split records into those with and without a 'similarity_score'
300
+ # Records with 'similarity_score' lower than 'similarity_threshold'
301
+ # will not have a 'similarity_score' in the output content
302
+ with_score = [
303
+ info for info in all_retrieved_info if 'similarity score' in info
304
+ ]
305
+ without_score = [
306
+ info
307
+ for info in all_retrieved_info
308
+ if 'similarity score' not in info
309
+ ]
310
+ # Sort only the list with scores
311
+ with_score_sorted = sorted(
312
+ with_score, key=lambda x: x['similarity score'], reverse=True
313
+ )
314
+ # Merge back the sorted scored items with the non-scored items
315
+ all_retrieved_info_sorted = with_score_sorted + without_score
316
+ # Select the 'top_k' results
317
+ all_retrieved_info = all_retrieved_info_sorted[:top_k]
318
+
319
+ retrieved_infos = "\n".join(str(info) for info in all_retrieved_info)
320
+ retrieved_infos_text = "\n".join(
321
+ info['text'] for info in all_retrieved_info if 'text' in info
322
+ )
323
+
324
+ detailed_info = f"Original Query:\n{{ {query} }}\nRetrieved Context:\n{retrieved_infos}"
325
+ text_info = f"Original Query:\n{{ {query} }}\nRetrieved Context:\n{retrieved_infos_text}"
326
+
316
327
  if return_detailed_info:
317
- return output
328
+ return detailed_info
318
329
  else:
319
- return output_text
330
+ return text_info
camel/retrievers/base.py CHANGED
@@ -12,53 +12,58 @@
12
12
  # limitations under the License.
13
13
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
14
  from abc import ABC, abstractmethod
15
- from typing import Any, Dict, List
15
+ from typing import Any, Callable
16
16
 
17
17
  DEFAULT_TOP_K_RESULTS = 1
18
18
 
19
19
 
20
- class BaseRetriever(ABC):
21
- r"""Abstract base class for implementing various types of information
22
- retrievers.
20
+ def _query_unimplemented(self, *input: Any) -> None:
21
+ r"""Defines the query behavior performed at every call.
22
+
23
+ Query the results. Subclasses should implement this
24
+ method according to their specific needs.
25
+
26
+ It should be overridden by all subclasses.
27
+
28
+ .. note::
29
+ Although the recipe for forward pass needs to be defined within
30
+ this function, one should call the :class:`BaseRetriever` instance
31
+ afterwards instead of this since the former takes care of running the
32
+ registered hooks while the latter silently ignores them.
23
33
  """
34
+ raise NotImplementedError(
35
+ f"Retriever [{type(self).__name__}] is missing the required \"query\" function"
36
+ )
24
37
 
25
- @abstractmethod
26
- def __init__(self) -> None:
27
- pass
28
38
 
29
- @abstractmethod
30
- def process(
31
- self,
32
- content_input_path: str,
33
- chunk_type: str = "chunk_by_title",
34
- **kwargs: Any,
35
- ) -> None:
36
- r"""Processes content from a file or URL, divides it into chunks by
39
+ def _process_unimplemented(self, *input: Any) -> None:
40
+ r"""Defines the process behavior performed at every call.
41
+
42
+ Processes content from a file or URL, divides it into chunks by
37
43
  using `Unstructured IO`,then stored internally. This method must be
38
44
  called before executing queries with the retriever.
39
45
 
40
- Args:
41
- content_input_path (str): File path or URL of the content to be
42
- processed.
43
- chunk_type (str): Type of chunking going to apply. Defaults to
44
- "chunk_by_title".
45
- **kwargs (Any): Additional keyword arguments for content parsing.
46
- """
47
- pass
46
+ Should be overridden by all subclasses.
48
47
 
49
- @abstractmethod
50
- def query(
51
- self, query: str, top_k: int = DEFAULT_TOP_K_RESULTS, **kwargs: Any
52
- ) -> List[Dict[str, Any]]:
53
- r"""Query the results. Subclasses should implement this
54
- method according to their specific needs.
48
+ .. note::
49
+ Although the recipe for forward pass needs to be defined within
50
+ this function, one should call the :class:`BaseRetriever` instance
51
+ afterwards instead of this since the former takes care of running the
52
+ registered hooks while the latter silently ignores them.
53
+ """
54
+ raise NotImplementedError(
55
+ f"Retriever [{type(self).__name__}] is missing the required \"process\" function"
56
+ )
57
+
58
+
59
+ class BaseRetriever(ABC):
60
+ r"""Abstract base class for implementing various types of information
61
+ retrievers.
62
+ """
55
63
 
56
- Args:
57
- query (str): Query string for information retriever.
58
- top_k (int, optional): The number of top results to return during
59
- retriever. Must be a positive integer. Defaults to
60
- `DEFAULT_TOP_K_RESULTS`.
61
- **kwargs (Any): Flexible keyword arguments for additional
62
- parameters, like `similarity_threshold`.
63
- """
64
+ @abstractmethod
65
+ def __init__(self) -> None:
64
66
  pass
67
+
68
+ process: Callable[..., Any] = _process_unimplemented
69
+ query: Callable[..., Any] = _query_unimplemented