camel-ai 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic. Click here for more details.

Files changed (47)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +107 -22
  3. camel/configs/__init__.py +6 -0
  4. camel/configs/base_config.py +21 -0
  5. camel/configs/gemini_config.py +17 -9
  6. camel/configs/qwen_config.py +91 -0
  7. camel/configs/yi_config.py +58 -0
  8. camel/generators.py +93 -0
  9. camel/interpreters/docker_interpreter.py +5 -0
  10. camel/interpreters/ipython_interpreter.py +2 -1
  11. camel/loaders/__init__.py +2 -0
  12. camel/loaders/apify_reader.py +223 -0
  13. camel/memories/agent_memories.py +24 -1
  14. camel/messages/base.py +38 -0
  15. camel/models/__init__.py +4 -0
  16. camel/models/model_factory.py +6 -0
  17. camel/models/qwen_model.py +139 -0
  18. camel/models/yi_model.py +138 -0
  19. camel/prompts/image_craft.py +8 -0
  20. camel/prompts/video_description_prompt.py +8 -0
  21. camel/retrievers/vector_retriever.py +5 -1
  22. camel/societies/role_playing.py +29 -18
  23. camel/societies/workforce/base.py +7 -1
  24. camel/societies/workforce/task_channel.py +10 -0
  25. camel/societies/workforce/utils.py +6 -0
  26. camel/societies/workforce/worker.py +2 -0
  27. camel/storages/vectordb_storages/qdrant.py +147 -24
  28. camel/tasks/task.py +15 -0
  29. camel/terminators/base.py +4 -0
  30. camel/terminators/response_terminator.py +1 -0
  31. camel/terminators/token_limit_terminator.py +1 -0
  32. camel/toolkits/__init__.py +4 -1
  33. camel/toolkits/base.py +9 -0
  34. camel/toolkits/data_commons_toolkit.py +360 -0
  35. camel/toolkits/function_tool.py +174 -7
  36. camel/toolkits/github_toolkit.py +175 -176
  37. camel/toolkits/google_scholar_toolkit.py +36 -7
  38. camel/toolkits/notion_toolkit.py +279 -0
  39. camel/toolkits/search_toolkit.py +164 -36
  40. camel/types/enums.py +88 -0
  41. camel/types/unified_model_type.py +10 -0
  42. camel/utils/commons.py +2 -1
  43. camel/utils/constants.py +2 -0
  44. {camel_ai-0.2.6.dist-info → camel_ai-0.2.7.dist-info}/METADATA +129 -79
  45. {camel_ai-0.2.6.dist-info → camel_ai-0.2.7.dist-info}/RECORD +47 -40
  46. {camel_ai-0.2.6.dist-info → camel_ai-0.2.7.dist-info}/LICENSE +0 -0
  47. {camel_ai-0.2.6.dist-info → camel_ai-0.2.7.dist-info}/WHEEL +0 -0
camel/prompts/image_craft.py CHANGED
@@ -18,6 +18,14 @@ from camel.types import RoleType
18
18
 
19
19
 
20
20
  class ImageCraftPromptTemplateDict(TextPromptDict):
21
+ r"""A dictionary containing :obj:`TextPrompt` used in the `ImageCraft`
22
+ task.
23
+
24
+ Attributes:
25
+ ASSISTANT_PROMPT (TextPrompt): A prompt for the AI assistant to create
26
+ an original image based on the provided descriptive captions.
27
+ """
28
+
21
29
  ASSISTANT_PROMPT = TextPrompt(
22
30
  """You are tasked with creating an original image based on
23
31
  the provided descriptive captions. Use your imagination
camel/prompts/video_description_prompt.py CHANGED
@@ -19,6 +19,14 @@ from camel.types import RoleType
19
19
 
20
20
  # flake8: noqa :E501
21
21
  class VideoDescriptionPromptTemplateDict(TextPromptDict):
22
+ r"""A dictionary containing :obj:`TextPrompt` used in the `VideoDescription`
23
+ task.
24
+
25
+ Attributes:
26
+ ASSISTANT_PROMPT (TextPrompt): A prompt for the AI assistant to
27
+ provide a shot description of the content of the current video.
28
+ """
29
+
22
30
  ASSISTANT_PROMPT = TextPrompt(
23
31
  """You are a master of video analysis.
24
32
  Please provide a shot description of the content of the current video."""
camel/retrievers/vector_retriever.py CHANGED
@@ -76,6 +76,7 @@ class VectorRetriever(BaseRetriever):
76
76
  max_characters: int = 500,
77
77
  embed_batch: int = 50,
78
78
  should_chunk: bool = True,
79
+ extra_info: Optional[dict] = None,
79
80
  **kwargs: Any,
80
81
  ) -> None:
81
82
  r"""Processes content from local file path, remote URL, string
@@ -93,6 +94,8 @@ class VectorRetriever(BaseRetriever):
93
94
  embed_batch (int): Size of batch for embeddings. Defaults to `50`.
94
95
  should_chunk (bool): If True, divide the content into chunks,
95
96
  otherwise skip chunking. Defaults to True.
97
+ extra_info (Optional[dict]): Extra information to be added
98
+ to the payload. Defaults to None.
96
99
  **kwargs (Any): Additional keyword arguments for content parsing.
97
100
  """
98
101
  from unstructured.documents.elements import Element
@@ -153,12 +156,13 @@ class VectorRetriever(BaseRetriever):
153
156
  chunk_metadata = {"metadata": chunk.metadata.to_dict()}
154
157
  # Remove the 'orig_elements' key if it exists
155
158
  chunk_metadata["metadata"].pop("orig_elements", "")
156
-
159
+ extra_info = extra_info or {}
157
160
  chunk_text = {"text": str(chunk)}
158
161
  combined_dict = {
159
162
  **content_path_info,
160
163
  **chunk_metadata,
161
164
  **chunk_text,
165
+ **extra_info,
162
166
  }
163
167
 
164
168
  records.append(
camel/societies/role_playing.py CHANGED
@@ -23,10 +23,10 @@ from camel.agents import (
23
23
  from camel.generators import SystemMessageGenerator
24
24
  from camel.human import Human
25
25
  from camel.messages import BaseMessage
26
- from camel.models import BaseModelBackend
26
+ from camel.models import BaseModelBackend, ModelFactory
27
27
  from camel.prompts import TextPrompt
28
28
  from camel.responses import ChatAgentResponse
29
- from camel.types import RoleType, TaskType
29
+ from camel.types import ModelPlatformType, ModelType, RoleType, TaskType
30
30
 
31
31
  logger = logging.getLogger(__name__)
32
32
  logger.setLevel(logging.WARNING)
@@ -55,7 +55,8 @@ class RolePlaying:
55
55
  If not specified, set the criteria to improve task performance.
56
56
  model (BaseModelBackend, optional): The model backend to use for
57
57
  generating responses. If specified, it will override the model in
58
- all agents. (default: :obj:`None`)
58
+ all agents if not specified in agent-specific kwargs. (default:
59
+ :obj:`OpenAIModel` with `GPT_4O_MINI`)
59
60
  task_type (TaskType, optional): The type of task to perform.
60
61
  (default: :obj:`TaskType.AI_SOCIETY`)
61
62
  assistant_agent_kwargs (Dict, optional): Additional arguments to pass
@@ -103,16 +104,21 @@ class RolePlaying:
103
104
  ) -> None:
104
105
  if model is not None:
105
106
  logger.warning(
106
- "The provided model will override the model settings in "
107
- "all agents, including any configurations passed "
108
- "through assistant_agent_kwargs, user_agent_kwargs, and "
109
- "other agent-specific kwargs."
107
+ "Model provided globally is set for all agents if not"
108
+ " already specified in agent_kwargs."
110
109
  )
111
110
 
112
111
  self.with_task_specify = with_task_specify
113
112
  self.with_task_planner = with_task_planner
114
113
  self.with_critic_in_the_loop = with_critic_in_the_loop
115
- self.model = model
114
+ self.model: BaseModelBackend = (
115
+ model
116
+ if model is not None
117
+ else ModelFactory.create(
118
+ model_platform=ModelPlatformType.DEFAULT,
119
+ model_type=ModelType.DEFAULT,
120
+ )
121
+ )
116
122
  self.task_type = task_type
117
123
  self.task_prompt = task_prompt
118
124
 
@@ -204,8 +210,9 @@ class RolePlaying:
204
210
  task_specify_meta_dict.update(extend_task_specify_meta_dict or {})
205
211
  if self.model is not None:
206
212
  if task_specify_agent_kwargs is None:
207
- task_specify_agent_kwargs = {}
208
- task_specify_agent_kwargs.update(dict(model=self.model))
213
+ task_specify_agent_kwargs = {'model': self.model}
214
+ elif 'model' not in task_specify_agent_kwargs:
215
+ task_specify_agent_kwargs.update(dict(model=self.model))
209
216
  task_specify_agent = TaskSpecifyAgent(
210
217
  task_type=self.task_type,
211
218
  output_language=output_language,
@@ -237,8 +244,9 @@ class RolePlaying:
237
244
  if self.with_task_planner:
238
245
  if self.model is not None:
239
246
  if task_planner_agent_kwargs is None:
240
- task_planner_agent_kwargs = {}
241
- task_planner_agent_kwargs.update(dict(model=self.model))
247
+ task_planner_agent_kwargs = {'model': self.model}
248
+ elif 'model' not in task_planner_agent_kwargs:
249
+ task_planner_agent_kwargs.update(dict(model=self.model))
242
250
  task_planner_agent = TaskPlannerAgent(
243
251
  output_language=output_language,
244
252
  **(task_planner_agent_kwargs or {}),
@@ -332,11 +340,13 @@ class RolePlaying:
332
340
  """
333
341
  if self.model is not None:
334
342
  if assistant_agent_kwargs is None:
335
- assistant_agent_kwargs = {}
336
- assistant_agent_kwargs.update(dict(model=self.model))
343
+ assistant_agent_kwargs = {'model': self.model}
344
+ elif 'model' not in assistant_agent_kwargs:
345
+ assistant_agent_kwargs.update(dict(model=self.model))
337
346
  if user_agent_kwargs is None:
338
- user_agent_kwargs = {}
339
- user_agent_kwargs.update(dict(model=self.model))
347
+ user_agent_kwargs = {'model': self.model}
348
+ elif 'model' not in user_agent_kwargs:
349
+ user_agent_kwargs.update(dict(model=self.model))
340
350
 
341
351
  self.assistant_agent = ChatAgent(
342
352
  init_assistant_sys_msg,
@@ -394,8 +404,9 @@ class RolePlaying:
394
404
  )
395
405
  if self.model is not None:
396
406
  if critic_kwargs is None:
397
- critic_kwargs = {}
398
- critic_kwargs.update(dict(model=self.model))
407
+ critic_kwargs = {'model': self.model}
408
+ elif 'model' not in critic_kwargs:
409
+ critic_kwargs.update(dict(model=self.model))
399
410
  self.critic = CriticAgent(
400
411
  self.critic_sys_msg,
401
412
  **(critic_kwargs or {}),
camel/societies/workforce/base.py CHANGED
@@ -19,6 +19,12 @@ from camel.societies.workforce.utils import check_if_running
19
19
 
20
20
 
21
21
  class BaseNode(ABC):
22
+ r"""Base class for all nodes in the workforce.
23
+
24
+ Args:
25
+ description (str): Description of the node.
26
+ """
27
+
22
28
  def __init__(self, description: str) -> None:
23
29
  self.node_id = str(id(self))
24
30
  self.description = description
@@ -27,7 +33,7 @@ class BaseNode(ABC):
27
33
 
28
34
  @check_if_running(False)
29
35
  def reset(self, *args: Any, **kwargs: Any) -> Any:
30
- """Resets the node to its initial state."""
36
+ r"""Resets the node to its initial state."""
31
37
  self._channel = TaskChannel()
32
38
  self._running = False
33
39
 
camel/societies/workforce/task_channel.py CHANGED
@@ -84,6 +84,9 @@ class TaskChannel:
84
84
  self._task_dict: Dict[str, Packet] = {}
85
85
 
86
86
  async def get_returned_task_by_publisher(self, publisher_id: str) -> Task:
87
+ r"""Get a task from the channel that has been returned by the
88
+ publisher.
89
+ """
87
90
  async with self._condition:
88
91
  while True:
89
92
  for task_id in self._task_id_list:
@@ -96,6 +99,9 @@ class TaskChannel:
96
99
  await self._condition.wait()
97
100
 
98
101
  async def get_assigned_task_by_assignee(self, assignee_id: str) -> Task:
102
+ r"""Get a task from the channel that has been assigned to the
103
+ assignee.
104
+ """
99
105
  async with self._condition:
100
106
  while True:
101
107
  for task_id in self._task_id_list:
@@ -147,12 +153,14 @@ class TaskChannel:
147
153
  self._condition.notify_all()
148
154
 
149
155
  async def remove_task(self, task_id: str) -> None:
156
+ r"""Remove a task from the channel."""
150
157
  async with self._condition:
151
158
  self._task_id_list.remove(task_id)
152
159
  self._task_dict.pop(task_id)
153
160
  self._condition.notify_all()
154
161
 
155
162
  async def get_dependency_ids(self) -> List[str]:
163
+ r"""Get the IDs of all dependencies in the channel."""
156
164
  async with self._condition:
157
165
  dependency_ids = []
158
166
  for task_id in self._task_id_list:
@@ -162,11 +170,13 @@ class TaskChannel:
162
170
  return dependency_ids
163
171
 
164
172
  async def get_task_by_id(self, task_id: str) -> Task:
173
+ r"""Get a task from the channel by its ID."""
165
174
  async with self._condition:
166
175
  if task_id not in self._task_id_list:
167
176
  raise ValueError(f"Task {task_id} not found.")
168
177
  return self._task_dict[task_id].task
169
178
 
170
179
  async def get_channel_debug_info(self) -> str:
180
+ r"""Get the debug information of the channel."""
171
181
  async with self._condition:
172
182
  return str(self._task_dict) + '\n' + str(self._task_id_list)
camel/societies/workforce/utils.py CHANGED
@@ -18,6 +18,8 @@ from pydantic import BaseModel, Field
18
18
 
19
19
 
20
20
  class WorkerConf(BaseModel):
21
+ r"""The configuration of a worker."""
22
+
21
23
  role: str = Field(
22
24
  description="The role of the agent working in the work node."
23
25
  )
@@ -31,6 +33,8 @@ class WorkerConf(BaseModel):
31
33
 
32
34
 
33
35
  class TaskResult(BaseModel):
36
+ r"""The result of a task."""
37
+
34
38
  content: str = Field(description="The result of the task.")
35
39
  failed: bool = Field(
36
40
  description="Flag indicating whether the task processing failed."
@@ -38,6 +42,8 @@ class TaskResult(BaseModel):
38
42
 
39
43
 
40
44
  class TaskAssignResult(BaseModel):
45
+ r"""The result of task assignment."""
46
+
41
47
  assignee_id: str = Field(
42
48
  description="The ID of the workforce that is assigned to the task."
43
49
  )
camel/societies/workforce/worker.py CHANGED
@@ -110,9 +110,11 @@ class Worker(BaseNode, ABC):
110
110
 
111
111
  @check_if_running(False)
112
112
  async def start(self):
113
+ r"""Start the worker."""
113
114
  await self._listen_to_channel()
114
115
 
115
116
  @check_if_running(True)
116
117
  def stop(self):
118
+ r"""Stop the worker."""
117
119
  self._running = False
118
120
  return
camel/storages/vectordb_storages/qdrant.py CHANGED
@@ -11,8 +11,12 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+ import logging
14
15
  from datetime import datetime
15
- from typing import Any, Dict, List, Optional, Tuple, Union, cast
16
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
17
+
18
+ if TYPE_CHECKING:
19
+ from qdrant_client import QdrantClient
16
20
 
17
21
  from camel.storages.vectordb_storages import (
18
22
  BaseVectorStorage,
@@ -25,6 +29,7 @@ from camel.types import VectorDistance
25
29
  from camel.utils import dependencies_required
26
30
 
27
31
  _qdrant_local_client_map: Dict[str, Tuple[Any, int]] = {}
32
+ logger = logging.getLogger(__name__)
28
33
 
29
34
 
30
35
  class QdrantStorage(BaseVectorStorage):
@@ -107,7 +112,13 @@ class QdrantStorage(BaseVectorStorage):
107
112
  hasattr(self, "delete_collection_on_del")
108
113
  and self.delete_collection_on_del
109
114
  ):
110
- self._delete_collection(self.collection_name)
115
+ try:
116
+ self._delete_collection(self.collection_name)
117
+ except RuntimeError as e:
118
+ logger.error(
119
+ f"Failed to delete collection"
120
+ f" '{self.collection_name}': {e}"
121
+ )
111
122
 
112
123
  def _create_client(
113
124
  self,
@@ -244,6 +255,10 @@ class QdrantStorage(BaseVectorStorage):
244
255
  "config": collection_info.config,
245
256
  }
246
257
 
258
+ def close_client(self, **kwargs):
259
+ r"""Closes the client connection to the Qdrant storage."""
260
+ self._client.close(**kwargs)
261
+
247
262
  def add(
248
263
  self,
249
264
  records: List[VectorRecord],
@@ -273,36 +288,124 @@ class QdrantStorage(BaseVectorStorage):
273
288
  f"{op_info}."
274
289
  )
275
290
 
276
- def delete(
277
- self,
278
- ids: List[str],
279
- **kwargs: Any,
291
+ def update_payload(
292
+ self, ids: List[str], payload: Dict[str, Any], **kwargs: Any
280
293
  ) -> None:
281
- r"""Deletes a list of vectors identified by their IDs from the storage.
294
+ r"""Updates the payload of the vectors identified by their IDs.
282
295
 
283
296
  Args:
284
297
  ids (List[str]): List of unique identifiers for the vectors to be
285
- deleted.
298
+ updated.
299
+ payload (Dict[str, Any]): List of payloads to be updated.
286
300
  **kwargs (Any): Additional keyword arguments.
287
301
 
288
302
  Raises:
289
- RuntimeError: If there is an error during the deletion process.
303
+ RuntimeError: If there is an error during the update process.
290
304
  """
291
305
  from qdrant_client.http.models import PointIdsList, UpdateStatus
292
306
 
293
307
  points = cast(List[Union[str, int]], ids)
294
- op_info = self._client.delete(
308
+
309
+ op_info = self._client.set_payload(
295
310
  collection_name=self.collection_name,
296
- points_selector=PointIdsList(points=points),
297
- wait=True,
311
+ payload=payload,
312
+ points=PointIdsList(points=points),
298
313
  **kwargs,
299
314
  )
300
315
  if op_info.status != UpdateStatus.COMPLETED:
301
316
  raise RuntimeError(
302
- "Failed to delete vectors in Qdrant, operation info: "
317
+ "Failed to update payload in Qdrant, operation info: "
303
318
  f"{op_info}"
304
319
  )
305
320
 
321
+ def delete_collection(self) -> None:
322
+ r"""Deletes the entire collection in the Qdrant storage."""
323
+ self._delete_collection(self.collection_name)
324
+
325
+ def delete(
326
+ self,
327
+ ids: Optional[List[str]] = None,
328
+ payload_filter: Optional[Dict[str, Any]] = None,
329
+ **kwargs: Any,
330
+ ) -> None:
331
+ r"""Deletes points from the collection based on either IDs or payload
332
+ filters.
333
+
334
+ Args:
335
+ ids (Optional[List[str]], optional): List of unique identifiers
336
+ for the vectors to be deleted.
337
+ payload_filter (Optional[Dict[str, Any]], optional): A filter for
338
+ the payload to delete points matching specific conditions. If
339
+ `ids` is provided, `payload_filter` will be ignored unless both
340
+ are combined explicitly.
341
+ **kwargs (Any): Additional keyword arguments pass to `QdrantClient.
342
+ delete`.
343
+
344
+ Examples:
345
+ >>> # Delete points with IDs "1", "2", and "3"
346
+ >>> storage.delete(ids=["1", "2", "3"])
347
+ >>> # Delete points with payload filter
348
+ >>> storage.delete(payload_filter={"name": "Alice"})
349
+
350
+ Raises:
351
+ ValueError: If neither `ids` nor `payload_filter` is provided.
352
+ RuntimeError: If there is an error during the deletion process.
353
+
354
+ Notes:
355
+ - If `ids` is provided, the points with these IDs will be deleted
356
+ directly, and the `payload_filter` will be ignored.
357
+ - If `ids` is not provided but `payload_filter` is, then points
358
+ matching the `payload_filter` will be deleted.
359
+ """
360
+ from qdrant_client.http.models import (
361
+ Condition,
362
+ FieldCondition,
363
+ Filter,
364
+ MatchValue,
365
+ PointIdsList,
366
+ UpdateStatus,
367
+ )
368
+
369
+ if not ids and not payload_filter:
370
+ raise ValueError(
371
+ "You must provide either `ids` or `payload_filter` to delete "
372
+ "points."
373
+ )
374
+
375
+ if ids:
376
+ op_info = self._client.delete(
377
+ collection_name=self.collection_name,
378
+ points_selector=PointIdsList(
379
+ points=cast(List[Union[int, str]], ids)
380
+ ),
381
+ **kwargs,
382
+ )
383
+ if op_info.status != UpdateStatus.COMPLETED:
384
+ raise RuntimeError(
385
+ "Failed to delete vectors in Qdrant, operation info: "
386
+ f"{op_info}"
387
+ )
388
+
389
+ if payload_filter:
390
+ filter_conditions = [
391
+ FieldCondition(key=key, match=MatchValue(value=value))
392
+ for key, value in payload_filter.items()
393
+ ]
394
+
395
+ op_info = self._client.delete(
396
+ collection_name=self.collection_name,
397
+ points_selector=Filter(
398
+ must=cast(List[Condition], filter_conditions)
399
+ ),
400
+ **kwargs,
401
+ )
402
+
403
+ if op_info.status != UpdateStatus.COMPLETED:
404
+ raise RuntimeError(
405
+ "Failed to delete vectors in Qdrant, operation info: "
406
+ f"{op_info}"
407
+ )
408
+
306
409
  def status(self) -> VectorDBStatus:
307
410
  status = self._get_collection_info(self.collection_name)
308
411
  return VectorDBStatus(
@@ -313,6 +416,7 @@ class QdrantStorage(BaseVectorStorage):
313
416
  def query(
314
417
  self,
315
418
  query: VectorDBQuery,
419
+ filter_conditions: Optional[Dict[str, Any]] = None,
316
420
  **kwargs: Any,
317
421
  ) -> List[VectorDBQueryResult]:
318
422
  r"""Searches for similar vectors in the storage based on the provided
@@ -321,31 +425,50 @@ class QdrantStorage(BaseVectorStorage):
321
425
  Args:
322
426
  query (VectorDBQuery): The query object containing the search
323
427
  vector and the number of top similar vectors to retrieve.
428
+ filter_conditions (Optional[Dict[str, Any]], optional): A
429
+ dictionary specifying conditions to filter the query results.
324
430
  **kwargs (Any): Additional keyword arguments.
325
431
 
326
432
  Returns:
327
433
  List[VectorDBQueryResult]: A list of vectors retrieved from the
328
434
  storage based on similarity to the query vector.
329
435
  """
330
- # TODO: filter
436
+ from qdrant_client.http.models import (
437
+ Condition,
438
+ FieldCondition,
439
+ Filter,
440
+ MatchValue,
441
+ )
442
+
443
+ # Construct filter if filter_conditions is provided
444
+ search_filter = None
445
+ if filter_conditions:
446
+ must_conditions = [
447
+ FieldCondition(key=key, match=MatchValue(value=value))
448
+ for key, value in filter_conditions.items()
449
+ ]
450
+ search_filter = Filter(must=cast(List[Condition], must_conditions))
451
+
452
+ # Execute the search with optional filter
331
453
  search_result = self._client.search(
332
454
  collection_name=self.collection_name,
333
455
  query_vector=query.query_vector,
334
456
  with_payload=True,
335
457
  with_vectors=True,
336
458
  limit=query.top_k,
459
+ query_filter=search_filter,
337
460
  **kwargs,
338
461
  )
339
- query_results = []
340
- for point in search_result:
341
- query_results.append(
342
- VectorDBQueryResult.create(
343
- similarity=point.score,
344
- id=str(point.id),
345
- payload=point.payload,
346
- vector=point.vector, # type: ignore[arg-type]
347
- )
462
+
463
+ query_results = [
464
+ VectorDBQueryResult.create(
465
+ similarity=point.score,
466
+ id=str(point.id),
467
+ payload=point.payload,
468
+ vector=point.vector, # type: ignore[arg-type]
348
469
  )
470
+ for point in search_result
471
+ ]
349
472
 
350
473
  return query_results
351
474
 
@@ -363,6 +486,6 @@ class QdrantStorage(BaseVectorStorage):
363
486
  pass
364
487
 
365
488
  @property
366
- def client(self) -> Any:
489
+ def client(self) -> "QdrantClient":
367
490
  r"""Provides access to the underlying vector database client."""
368
491
  return self._client
camel/tasks/task.py CHANGED
@@ -130,6 +130,11 @@ class Task(BaseModel):
130
130
  self.set_state(TaskState.DONE)
131
131
 
132
132
  def set_id(self, id: str):
133
+ r"""Set the id of the task.
134
+
135
+ Args:
136
+ id (str): The id of the task.
137
+ """
133
138
  self.id = id
134
139
 
135
140
  def set_state(self, state: TaskState):
@@ -147,10 +152,20 @@ class Task(BaseModel):
147
152
  self.parent.set_state(state)
148
153
 
149
154
  def add_subtask(self, task: "Task"):
155
+ r"""Add a subtask to the current task.
156
+
157
+ Args:
158
+ task (Task): The subtask to be added.
159
+ """
150
160
  task.parent = self
151
161
  self.subtasks.append(task)
152
162
 
153
163
  def remove_subtask(self, id: str):
164
+ r"""Remove a subtask from the current task.
165
+
166
+ Args:
167
+ id (str): The id of the subtask to be removed.
168
+ """
154
169
  self.subtasks = [task for task in self.subtasks if task.id != id]
155
170
 
156
171
  def get_running_task(self) -> Optional["Task"]:
camel/terminators/base.py CHANGED
@@ -18,6 +18,8 @@ from camel.messages import BaseMessage
18
18
 
19
19
 
20
20
  class BaseTerminator(ABC):
21
+ r"""Base class for terminators."""
22
+
21
23
  def __init__(self, *args, **kwargs) -> None:
22
24
  self._terminated: bool = False
23
25
  self._termination_reason: Optional[str] = None
@@ -32,6 +34,8 @@ class BaseTerminator(ABC):
32
34
 
33
35
 
34
36
  class ResponseTerminator(BaseTerminator):
37
+ r"""A terminator that terminates the conversation based on the response."""
38
+
35
39
  @abstractmethod
36
40
  def is_terminated(
37
41
  self, messages: List[BaseMessage]
camel/terminators/response_terminator.py CHANGED
@@ -122,6 +122,7 @@ class ResponseWordsTerminator(ResponseTerminator):
122
122
  return self._terminated, self._termination_reason
123
123
 
124
124
  def reset(self):
125
+ r"""Reset the terminator."""
125
126
  self._terminated = False
126
127
  self._termination_reason = None
127
128
  self._word_count_dict = defaultdict(int)
camel/terminators/token_limit_terminator.py CHANGED
@@ -53,5 +53,6 @@ class TokenLimitTerminator(BaseTerminator):
53
53
  return self._terminated, self._termination_reason
54
54
 
55
55
  def reset(self):
56
+ r"""Reset the terminator."""
56
57
  self._terminated = False
57
58
  self._termination_reason = None
camel/toolkits/__init__.py CHANGED
@@ -17,10 +17,10 @@ from .function_tool import (
17
17
  OpenAIFunction,
18
18
  get_openai_function_schema,
19
19
  get_openai_tool_schema,
20
+ generate_docstring,
20
21
  )
21
22
  from .open_api_specs.security_config import openapi_security_config
22
23
 
23
-
24
24
  from .math_toolkit import MathToolkit, MATH_FUNCS
25
25
  from .search_toolkit import SearchToolkit, SEARCH_FUNCS
26
26
  from .weather_toolkit import WeatherToolkit, WEATHER_FUNCS
@@ -39,6 +39,7 @@ from .slack_toolkit import SlackToolkit
39
39
  from .twitter_toolkit import TwitterToolkit, TWITTER_FUNCS
40
40
  from .open_api_toolkit import OpenAPIToolkit
41
41
  from .retrieval_toolkit import RetrievalToolkit
42
+ from .notion_toolkit import NotionToolkit
42
43
 
43
44
  __all__ = [
44
45
  'BaseToolkit',
@@ -46,6 +47,7 @@ __all__ = [
46
47
  'OpenAIFunction',
47
48
  'get_openai_function_schema',
48
49
  'get_openai_tool_schema',
50
+ "generate_docstring",
49
51
  'openapi_security_config',
50
52
  'GithubToolkit',
51
53
  'MathToolkit',
@@ -63,6 +65,7 @@ __all__ = [
63
65
  'AskNewsToolkit',
64
66
  'AsyncAskNewsToolkit',
65
67
  'GoogleScholarToolkit',
68
+ 'NotionToolkit',
66
69
  'ArxivToolkit',
67
70
  'MATH_FUNCS',
68
71
  'SEARCH_FUNCS',
camel/toolkits/base.py CHANGED
@@ -19,5 +19,14 @@ from camel.utils import AgentOpsMeta
19
19
 
20
20
 
21
21
  class BaseToolkit(metaclass=AgentOpsMeta):
22
+ r"""Base class for toolkits."""
23
+
22
24
  def get_tools(self) -> List[FunctionTool]:
25
+ r"""Returns a list of FunctionTool objects representing the
26
+ functions in the toolkit.
27
+
28
+ Returns:
29
+ List[FunctionTool]: A list of FunctionTool objects
30
+ representing the functions in the toolkit.
31
+ """
23
32
  raise NotImplementedError("Subclasses must implement this method.")