camel-ai 0.1.6.5__py3-none-any.whl → 0.1.6.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of camel-ai might be problematic.

@@ -0,0 +1,291 @@
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ import os
+ from typing import Any, Dict, List, Optional, Union
+
+ from openai import OpenAI, Stream
+
+ from camel.configs import SAMBA_API_PARAMS
+ from camel.messages import OpenAIMessage
+ from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
+ from camel.utils import (
+     BaseTokenCounter,
+     OpenAITokenCounter,
+     api_keys_required,
+ )
+
+
+ class SambaModel:
+     r"""SambaNova service interface."""
+
+     def __init__(
+         self,
+         model_type: ModelType,
+         model_config_dict: Dict[str, Any],
+         api_key: Optional[str] = None,
+         url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
+     ) -> None:
+         r"""Constructor for SambaNova backend.
+
+         Args:
+             model_type (ModelType): Model for which a SambaNova backend is
+                 created.
+             model_config_dict (Dict[str, Any]): A dictionary that will
+                 be fed into the API request.
+             api_key (Optional[str]): The API key for authenticating with the
+                 SambaNova service. (default: :obj:`None`)
+             url (Optional[str]): The url to the SambaNova service. (default:
+                 :obj:`"https://fast-api.snova.ai/v1/chat/completions"`)
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `OpenAITokenCounter(ModelType.
+                 GPT_4O_MINI)` will be used.
+         """
+         self.model_type = model_type
+         self._api_key = api_key or os.environ.get("SAMBA_API_KEY")
+         self._url = url or os.environ.get("SAMBA_API_BASE_URL")
+         self._token_counter = token_counter
+         self.model_config_dict = model_config_dict
+         self.check_model_config()
+
+     @property
+     def token_counter(self) -> BaseTokenCounter:
+         r"""Initialize the token counter for the model backend.
+
+         Returns:
+             BaseTokenCounter: The token counter following the model's
+                 tokenization style.
+         """
+         if not self._token_counter:
+             self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
+         return self._token_counter
+
+     def check_model_config(self):
+         r"""Check whether the model configuration contains any
+         unexpected arguments to SambaNova API.
+
+         Raises:
+             ValueError: If the model configuration dictionary contains any
+                 unexpected arguments to SambaNova API.
+         """
+         for param in self.model_config_dict:
+             if param not in SAMBA_API_PARAMS:
+                 raise ValueError(
+                     f"Unexpected argument `{param}` is "
+                     "input into SambaNova model backend."
+                 )
+
+     @api_keys_required("SAMBA_API_KEY")
+     def run(  # type: ignore[misc]
+         self, messages: List[OpenAIMessage]
+     ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+         r"""Runs SambaNova's FastAPI service.
+
+         Args:
+             messages (List[OpenAIMessage]): Message list with the chat history
+                 in OpenAI API format.
+
+         Returns:
+             Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+                 `ChatCompletion` in the non-stream mode, or
+                 `Stream[ChatCompletionChunk]` in the stream mode.
+         """
+
+         if self.model_config_dict.get("stream") is True:
+             return self._run_streaming(messages)
+         else:
+             return self._run_non_streaming(messages)
+
+     def _run_streaming(  # type: ignore[misc]
+         self, messages: List[OpenAIMessage]
+     ) -> Stream[ChatCompletionChunk]:
+         r"""Handles streaming inference with SambaNova FastAPI.
+
+         Args:
+             messages (List[OpenAIMessage]): A list of messages representing
+                 the chat history in OpenAI API format.
+
+         Returns:
+             Stream[ChatCompletionChunk]: A generator yielding
+                 `ChatCompletionChunk` objects as they are received from the
+                 API.
+
+         Raises:
+             RuntimeError: If the HTTP request fails.
+         """
+
+         import httpx
+
+         headers = {
+             "Authorization": f"Basic {self._api_key}",
+             "Content-Type": "application/json",
+         }
+
+         data = {
+             "messages": messages,
+             "max_tokens": self.token_limit,
+             "stop": self.model_config_dict.get("stop"),
+             "model": self.model_type.value,
+             "stream": True,
+             "stream_options": self.model_config_dict.get("stream_options"),
+         }
+
+         try:
+             with httpx.stream(
+                 "POST",
+                 self._url or "https://fast-api.snova.ai/v1/chat/completions",
+                 headers=headers,
+                 json=data,
+             ) as api_response:
+                 stream = Stream[ChatCompletionChunk](
+                     cast_to=ChatCompletionChunk,
+                     response=api_response,
+                     client=OpenAI(),
+                 )
+                 for chunk in stream:
+                     yield chunk
+         except httpx.HTTPError as e:
+             raise RuntimeError(f"HTTP request failed: {e!s}")
+
+     def _run_non_streaming(
+         self, messages: List[OpenAIMessage]
+     ) -> ChatCompletion:
+         r"""Handles non-streaming inference with SambaNova FastAPI.
+
+         Args:
+             messages (List[OpenAIMessage]): A list of messages representing
+                 the chat history in OpenAI API format.
+
+         Returns:
+             ChatCompletion: A `ChatCompletion` object containing the complete
+                 response from the API.
+
+         Raises:
+             RuntimeError: If the HTTP request fails.
+             ValueError: If the JSON response cannot be decoded or is missing
+                 expected data.
+         """
+
+         import json
+
+         import httpx
+
+         headers = {
+             "Authorization": f"Basic {self._api_key}",
+             "Content-Type": "application/json",
+         }
+
+         # The request is still made with `stream=True`; the streamed chunks
+         # are collected and merged into one completion below.
+         data = {
+             "messages": messages,
+             "max_tokens": self.token_limit,
+             "stop": self.model_config_dict.get("stop"),
+             "model": self.model_type.value,
+             "stream": True,
+             "stream_options": self.model_config_dict.get("stream_options"),
+         }
+
+         try:
+             with httpx.stream(
+                 "POST",
+                 self._url or "https://fast-api.snova.ai/v1/chat/completions",
+                 headers=headers,
+                 json=data,
+             ) as api_response:
+                 samba_response = []
+                 for chunk in api_response.iter_text():
+                     # Server-sent events: each payload is prefixed with
+                     # "data: " and the stream ends with a "[DONE]" sentinel.
+                     if chunk.startswith('data: '):
+                         chunk = chunk[6:]
+                     if '[DONE]' in chunk:
+                         break
+                     json_data = json.loads(chunk)
+                     samba_response.append(json_data)
+                 return self._to_openai_response(samba_response)
+         except httpx.HTTPError as e:
+             raise RuntimeError(f"HTTP request failed: {e!s}")
+         except json.JSONDecodeError as e:
+             raise ValueError(f"Failed to decode JSON response: {e!s}")
+
+     def _to_openai_response(
+         self, samba_response: List[Dict[str, Any]]
+     ) -> ChatCompletion:
+         r"""Converts SambaNova response chunks into an OpenAI-compatible
+         response.
+
+         Args:
+             samba_response (List[Dict[str, Any]]): A list of dictionaries
+                 representing partial responses from the SambaNova API.
+
+         Returns:
+             ChatCompletion: A `ChatCompletion` object constructed from the
+                 aggregated response data.
+
+         Raises:
+             ValueError: If the response data is invalid or incomplete.
+         """
+         # Step 1: Combine the content from each chunk
+         full_content = ""
+         for chunk in samba_response:
+             if chunk['choices']:
+                 for choice in chunk['choices']:
+                     delta_content = choice['delta'].get('content', '')
+                     full_content += delta_content
+
+         # Step 2: Create the ChatCompletion object
+         # Extract relevant information from the first chunk
+         first_chunk = samba_response[0]
+
+         choices = [
+             dict(
+                 index=0,  # type: ignore[index]
+                 message={
+                     "role": 'assistant',
+                     "content": full_content.strip(),
+                 },
+                 finish_reason=samba_response[-1]['choices'][0]['finish_reason']
+                 or None,
+             )
+         ]
+
+         obj = ChatCompletion.construct(
+             id=first_chunk['id'],
+             choices=choices,
+             created=first_chunk['created'],
+             model=first_chunk['model'],
+             object="chat.completion",
+             usage=None,
+         )
+
+         return obj
+
+     @property
+     def token_limit(self) -> int:
+         r"""Returns the maximum token limit for a given model.
+
+         Returns:
+             int: The maximum token limit for the given model.
+         """
+         return (
+             self.model_config_dict.get("max_tokens")
+             or self.model_type.token_limit
+         )
+
+     @property
+     def stream(self) -> bool:
+         r"""Returns whether the model is in stream mode, which sends partial
+         results each time.
+
+         Returns:
+             bool: Whether the model is in stream mode.
+         """
+         return self.model_config_dict.get('stream', False)
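For orientation, a minimal usage sketch of the new SambaNova backend. This is illustrative, not the package's documented example: it assumes `SambaModel` is importable from `camel.models.samba_model`, that the chosen `ModelType` member and the `stream` key are accepted by `SAMBA_API_PARAMS`, and that `SAMBA_API_KEY` is set.

    import os

    from camel.models.samba_model import SambaModel  # import path assumed
    from camel.types import ModelType

    os.environ.setdefault("SAMBA_API_KEY", "<your-samba-key>")

    # Placeholder ModelType member; substitute one supported by SambaNova.
    model = SambaModel(
        model_type=ModelType.GPT_4O_MINI,
        model_config_dict={"stream": False},
    )
    # Non-stream mode: the backend still consumes the SSE stream internally
    # and merges the chunks into one ChatCompletion via _to_openai_response.
    completion = model.run([{"role": "user", "content": "Hello!"}])
    print(completion)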
@@ -0,0 +1,148 @@
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+
+ import os
+ from typing import Any, Dict, List, Optional, Union
+
+ from openai import OpenAI, Stream
+
+ from camel.configs import TOGETHERAI_API_PARAMS
+ from camel.messages import OpenAIMessage
+ from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
+ from camel.utils import (
+     BaseTokenCounter,
+     OpenAITokenCounter,
+     api_keys_required,
+ )
+
+
+ class TogetherAIModel:
+     r"""Together AI service interface with OpenAI compatibility.
+     TODO: Add function calling support.
+     """
+
+     def __init__(
+         self,
+         model_type: str,
+         model_config_dict: Dict[str, Any],
+         api_key: Optional[str] = None,
+         url: Optional[str] = None,
+         token_counter: Optional[BaseTokenCounter] = None,
+     ) -> None:
+         r"""Constructor for TogetherAI backend.
+
+         Args:
+             model_type (str): Model for which a backend is created; supported
+                 models are listed at https://docs.together.ai/docs/chat-models
+             model_config_dict (Dict[str, Any]): A dictionary that will
+                 be fed into openai.ChatCompletion.create().
+             api_key (Optional[str]): The API key for authenticating with the
+                 Together service. (default: :obj:`None`)
+             url (Optional[str]): The url to the Together AI service. (default:
+                 :obj:`"https://api.together.xyz/v1"`)
+             token_counter (Optional[BaseTokenCounter]): Token counter to use
+                 for the model. If not provided, `OpenAITokenCounter(ModelType.
+                 GPT_4O_MINI)` will be used.
+         """
+         self.model_type = model_type
+         self.model_config_dict = model_config_dict
+         self._token_counter = token_counter
+         self._api_key = api_key or os.environ.get("TOGETHER_API_KEY")
+         self._url = url or os.environ.get("TOGETHER_API_BASE_URL")
+
+         self._client = OpenAI(
+             timeout=60,
+             max_retries=3,
+             api_key=self._api_key,
+             base_url=self._url or "https://api.together.xyz/v1",
+         )
+
+     @api_keys_required("TOGETHER_API_KEY")
+     def run(
+         self,
+         messages: List[OpenAIMessage],
+     ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+         r"""Runs inference of OpenAI chat completion.
+
+         Args:
+             messages (List[OpenAIMessage]): Message list with the chat history
+                 in OpenAI API format.
+
+         Returns:
+             Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+                 `ChatCompletion` in the non-stream mode, or
+                 `Stream[ChatCompletionChunk]` in the stream mode.
+         """
+         # Use the OpenAI client as the interface to call Together AI
+         # Reference: https://docs.together.ai/docs/openai-api-compatibility
+         response = self._client.chat.completions.create(
+             messages=messages,
+             model=self.model_type,
+             **self.model_config_dict,
+         )
+         return response
+
+     @property
+     def token_counter(self) -> BaseTokenCounter:
+         r"""Initialize the token counter for the model backend.
+
+         Returns:
+             BaseTokenCounter: The token counter following the model's
+                 tokenization style.
+         """
+
+         if not self._token_counter:
+             self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
+         return self._token_counter
+
+     def check_model_config(self):
+         r"""Check whether the model configuration contains any
+         unexpected arguments to TogetherAI API.
+
+         Raises:
+             ValueError: If the model configuration dictionary contains any
+                 unexpected arguments to TogetherAI API.
+         """
+         for param in self.model_config_dict:
+             if param not in TOGETHERAI_API_PARAMS:
+                 raise ValueError(
+                     f"Unexpected argument `{param}` is "
+                     "input into TogetherAI model backend."
+                 )
+
+     @property
+     def stream(self) -> bool:
+         r"""Returns whether the model is in stream mode, which sends partial
+         results each time.
+
+         Returns:
+             bool: Whether the model is in stream mode.
+         """
+         return self.model_config_dict.get('stream', False)
+
+     @property
+     def token_limit(self) -> int:
+         r"""Returns the maximum token limit for the given model.
+
+         Returns:
+             int: The maximum token limit for the given model.
+         """
+         max_tokens = self.model_config_dict.get("max_tokens")
+         if isinstance(max_tokens, int):
+             return max_tokens
+         print(
+             "Must set `max_tokens` as an integer in `model_config_dict` when"
+             " setting up the model. Using 4096 as default value."
+         )
+         return 4096
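A matching sketch for the Together AI backend. The model id and import path are assumptions; any chat model from the Together catalog should work the same way since `run()` forwards straight to the OpenAI-compatible endpoint.

    import os

    from camel.models.together_model import TogetherAIModel  # import path assumed

    os.environ.setdefault("TOGETHER_API_KEY", "<your-together-key>")

    model = TogetherAIModel(
        model_type="meta-llama/Llama-3-8b-chat-hf",  # illustrative model id
        model_config_dict={"max_tokens": 256},
    )
    # The result is a regular OpenAI ChatCompletion object.
    completion = model.run([{"role": "user", "content": "Hello!"}])
    print(completion.choices[0].message.content)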
@@ -11,6 +11,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+ import os
  from typing import Any, Dict, List, Optional, Union
 
  from openai import OpenAI, Stream
@@ -42,20 +43,21 @@ class VLLMModel:
              model_config_dict (Dict[str, Any]): A dictionary that will
                  be fed into openai.ChatCompletion.create().
              url (Optional[str]): The url to the model service. (default:
-                 :obj:`None`)
+                 :obj:`"http://localhost:8000/v1"`)
              api_key (Optional[str]): The API key for authenticating with the
                  model service.
              token_counter (Optional[BaseTokenCounter]): Token counter to use
                  for the model. If not provided, `OpenAITokenCounter(ModelType.
-                 GPT_3_5_TURBO)` will be used.
+                 GPT_4O_MINI)` will be used.
          """
          self.model_type = model_type
          self.model_config_dict = model_config_dict
+         self._url = url or os.environ.get("VLLM_BASE_URL")
          # Use the OpenAI client as the interface to call vLLM
          self._client = OpenAI(
              timeout=60,
              max_retries=3,
-             base_url=url,
+             base_url=self._url or "http://localhost:8000/v1",
              api_key=api_key,
          )
          self._token_counter = token_counter
@@ -70,7 +72,7 @@ class VLLMModel:
                  tokenization style.
          """
          if not self._token_counter:
-             self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
+             self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
          return self._token_counter
 
      def check_model_config(self):
@@ -113,7 +115,7 @@ class VLLMModel:
 
      @property
      def token_limit(self) -> int:
-         """Returns the maximum token limit for the given model.
+         r"""Returns the maximum token limit for the given model.
 
          Returns:
              int: The maximum token limit for the given model.
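The practical effect of the vLLM change is a three-step base-URL resolution: explicit `url` argument first, then the `VLLM_BASE_URL` environment variable, then the `http://localhost:8000/v1` default. A standalone sketch of that precedence (the helper name is mine, not the package's):

    import os
    from typing import Optional

    def resolve_vllm_base_url(url: Optional[str] = None) -> str:
        # Mirrors the constructor: argument first, then env var, then default.
        return url or os.environ.get("VLLM_BASE_URL") or "http://localhost:8000/v1"

    os.environ["VLLM_BASE_URL"] = "http://gpu-box:8000/v1"
    assert resolve_vllm_base_url() == "http://gpu-box:8000/v1"
    assert resolve_vllm_base_url("http://override:9000/v1") == "http://override:9000/v1"
    del os.environ["VLLM_BASE_URL"]
    assert resolve_vllm_base_url() == "http://localhost:8000/v1"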
@@ -52,7 +52,7 @@ class ZhipuAIModel(BaseModelBackend):
                  :obj:`None`)
              token_counter (Optional[BaseTokenCounter]): Token counter to use
                  for the model. If not provided, `OpenAITokenCounter(ModelType.
-                 GPT_3_5_TURBO)` will be used.
+                 GPT_4O_MINI)` will be used.
          """
          super().__init__(
              model_type, model_config_dict, api_key, url, token_counter
@@ -105,7 +105,7 @@ class ZhipuAIModel(BaseModelBackend):
          """
 
          if not self._token_counter:
-             self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
+             self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
          return self._token_counter
 
      def check_model_config(self):
@@ -14,9 +14,7 @@
  import datetime
  import os
  import re
- from pathlib import Path
  from typing import Collection, List, Optional, Sequence, Tuple, Union
- from urllib.parse import urlparse
 
  from camel.embeddings import BaseEmbedding, OpenAIEmbedding
  from camel.retrievers.vector_retriever import VectorRetriever
@@ -28,6 +26,11 @@ from camel.storages import (
  )
  from camel.types import StorageType
 
+ try:
+     from unstructured.documents.elements import Element
+ except ImportError:
+     Element = None
+
  DEFAULT_TOP_K_RESULTS = 1
  DEFAULT_SIMILARITY_THRESHOLD = 0.75
 
@@ -97,41 +100,22 @@ class AutoRetriever:
                  f"Unsupported vector storage type: {self.storage_type}"
              )
 
-     def _collection_name_generator(self, content: str) -> str:
+     def _collection_name_generator(self, content: Union[str, Element]) -> str:
          r"""Generates a valid collection name from a given file path or URL.
 
          Args:
-             contents (str): Local file path, remote URL or string content.
+             content (Union[str, Element]): Local file path, remote URL,
+                 string content or Element object.
 
          Returns:
              str: A sanitized, valid collection name suitable for use.
          """
-         # Check if the content is URL
-         parsed_url = urlparse(content)
-         is_url = all([parsed_url.scheme, parsed_url.netloc])
-
-         # Convert given path into a collection name, ensuring it only
-         # contains numbers, letters, and underscores
-         if is_url:
-             # For URLs, remove https://, replace /, and any characters not
-             # allowed by Milvus with _
-             collection_name = re.sub(
-                 r'[^0-9a-zA-Z]+',
-                 '_',
-                 content.replace("https://", ""),
-             )
-         elif os.path.exists(content):
-             # For file paths, get the stem and replace spaces with _, also
-             # ensuring only allowed characters are present
-             collection_name = re.sub(r'[^0-9a-zA-Z]+', '_', Path(content).stem)
-         else:
-             # the content is string input
-             collection_name = content[:10]
 
-         # Ensure the collection name does not start or end with underscore
-         collection_name = collection_name.strip("_")
-         # Limit the maximum length of the collection name to 30 characters
-         collection_name = collection_name[:30]
+         if isinstance(content, Element):
+             content = content.metadata.file_directory
+
+         collection_name = re.sub(r'[^a-zA-Z0-9]', '', content)[:20]
+
          return collection_name
 
      def _get_file_modified_date_from_file(
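The new `_collection_name_generator` is much stricter than the old URL/path-aware logic: every non-alphanumeric character is dropped and the result is capped at 20 characters. A worked example of the new behavior (standalone helper, same expression as the diff):

    import re

    def collection_name(content: str) -> str:
        # Same sanitization as the new implementation.
        return re.sub(r'[^a-zA-Z0-9]', '', content)[:20]

    print(collection_name("https://docs.camel-ai.org/retrievers"))  # 'httpsdocscamelaiorgr'
    print(collection_name("local notes.md"))                        # 'localnotesmd'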
@@ -193,7 +177,7 @@ class AutoRetriever:
      def run_vector_retriever(
          self,
          query: str,
-         contents: Union[str, List[str]],
+         contents: Union[str, List[str], Element, List[Element]],
          top_k: int = DEFAULT_TOP_K_RESULTS,
          similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
          return_detailed_info: bool = False,
@@ -203,8 +187,8 @@ class AutoRetriever:
 
          Args:
              query (str): Query string for information retriever.
-             contents (Union[str, List[str]]): Local file paths, remote URLs or
-                 string contents.
+             contents (Union[str, List[str], Element, List[Element]]): Local
+                 file paths, remote URLs, string contents or Element objects.
              top_k (int, optional): The number of top results to return during
                  retrieve. Must be a positive integer. Defaults to
                  `DEFAULT_TOP_K_RESULTS`.
@@ -230,7 +214,9 @@ class AutoRetriever:
          if not contents:
              raise ValueError("content cannot be empty.")
 
-         contents = [contents] if isinstance(contents, str) else contents
+         contents = (
+             [contents] if isinstance(contents, (str, Element)) else contents
+         )
 
          all_retrieved_info = []
          for content in contents:
@@ -246,6 +232,7 @@ class AutoRetriever:
              file_is_modified = False  # initialize with a default value
              if (
                  vector_storage_instance.status().vector_count != 0
+                 and isinstance(content, str)
                  and os.path.exists(content)
              ):
                  # Get original modified date from file
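Together, these AutoRetriever changes let callers hand over pre-parsed `unstructured` elements instead of paths or URLs. A hedged sketch (constructor arguments and the storage default are assumptions; note the collection name is derived from the element's `metadata.file_directory`, so that field should be set):

    from unstructured.documents.elements import ElementMetadata, Text

    from camel.retrievers import AutoRetriever
    from camel.types import StorageType

    retriever = AutoRetriever(storage_type=StorageType.QDRANT)  # args assumed
    element = Text(
        text="CAMEL is a multi-agent framework.",
        metadata=ElementMetadata(file_directory="inline_notes"),
    )
    # A single Element is wrapped into [element] internally; the
    # os.path.exists freshness check is skipped since content is not a str.
    info = retriever.run_vector_retriever(query="What is CAMEL?", contents=element)
    print(info)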
@@ -13,7 +13,7 @@
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
  import os
  import warnings
- from typing import Any, Dict, List, Optional
+ from typing import Any, Dict, List, Optional, Union
  from urllib.parse import urlparse
 
  from camel.embeddings import BaseEmbedding, OpenAIEmbedding
@@ -26,6 +26,11 @@ from camel.storages import (
      VectorRecord,
  )
 
+ try:
+     from unstructured.documents.elements import Element
+ except ImportError:
+     Element = None
+
  DEFAULT_TOP_K_RESULTS = 1
  DEFAULT_SIMILARITY_THRESHOLD = 0.75
 
@@ -69,7 +74,7 @@ class VectorRetriever(BaseRetriever):
 
      def process(
          self,
-         content: str,
+         content: Union[str, Element],
          chunk_type: str = "chunk_by_title",
          **kwargs: Any,
      ) -> None:
@@ -78,18 +83,22 @@ class VectorRetriever(BaseRetriever):
          vector storage.
 
          Args:
-             contents (str): Local file path, remote URL or string content.
+             content (Union[str, Element]): Local file path, remote URL,
+                 string content or Element object.
              chunk_type (str): Type of chunking going to apply. Defaults to
                  "chunk_by_title".
              **kwargs (Any): Additional keyword arguments for content parsing.
          """
-         # Check if the content is URL
-         parsed_url = urlparse(content)
-         is_url = all([parsed_url.scheme, parsed_url.netloc])
-         if is_url or os.path.exists(content):
-             elements = self.uio.parse_file_or_url(content, **kwargs)
+         if isinstance(content, Element):
+             elements = [content]
          else:
-             elements = [self.uio.create_element_from_text(text=content)]
+             # Check if the content is URL
+             parsed_url = urlparse(content)
+             is_url = all([parsed_url.scheme, parsed_url.netloc])
+             if is_url or os.path.exists(content):
+                 elements = self.uio.parse_file_or_url(content, **kwargs) or []
+             else:
+                 elements = [self.uio.create_element_from_text(text=content)]
          if elements:
              chunks = self.uio.chunk_elements(
                  chunk_type=chunk_type, elements=elements
@@ -110,7 +119,12 @@ class VectorRetriever(BaseRetriever):
          # Prepare the payload for each vector record, includes the content
          # path, chunk metadata, and chunk text
          for vector, chunk in zip(batch_vectors, batch_chunks):
-             content_path_info = {"content path": content}
+             if isinstance(content, str):
+                 content_path_info = {"content path": content}
+             elif isinstance(content, Element):
+                 content_path_info = {
+                     "content path": content.metadata.file_directory
+                 }
              chunk_metadata = {"metadata": chunk.metadata.to_dict()}
              chunk_text = {"text": str(chunk)}
              combined_dict = {
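The VectorRetriever side mirrors this: an `Element` bypasses URL/path detection entirely and is chunked as-is, with the payload's "content path" taken from `metadata.file_directory`. A minimal sketch under stated assumptions (default embedding and storage constructed by the constructor, and a `query` method with this signature):

    from unstructured.documents.elements import ElementMetadata, Text

    from camel.retrievers import VectorRetriever

    vr = VectorRetriever()  # defaults assumed: OpenAIEmbedding plus a Qdrant store
    element = Text(
        text="Vector retrievers embed chunks and index them for similarity search.",
        metadata=ElementMetadata(file_directory="inline_notes"),
    )
    vr.process(content=element)  # elements = [content]; no parsing round-trip
    print(vr.query(query="What do vector retrievers do?", top_k=1))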