camel-ai 0.2.34__py3-none-any.whl → 0.2.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of camel-ai might be problematic.

@@ -13,10 +13,14 @@
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

  import abc
+ import asyncio
  import json
  import random
  from pathlib import Path
- from typing import List, Union
+ from typing import Any, Dict, List, Union
+
+ from pydantic import ValidationError
+ from torch.utils.data import IterableDataset

  from camel.logger import get_logger

@@ -25,23 +29,47 @@ from .models import DataPoint
  logger = get_logger(__name__)


- class BaseGenerator(abc.ABC):
+ class BaseGenerator(abc.ABC, IterableDataset):
      r"""Abstract base class for data generators.

      This class defines the interface for generating synthetic datapoints.
      Concrete implementations should provide specific generation strategies.
      """

-     def __init__(self, seed: int = 42, **kwargs):
+     def __init__(
+         self,
+         seed: int = 42,
+         cache: Union[str, Path, None] = None,
+         data_path: Union[str, Path, None] = None,
+         **kwargs,
+     ):
          r"""Initialize the base generator.

          Args:
              seed (int): Random seed for reproducibility. (default: :obj:`42`)
+             cache (Union[str, Path, None]): Optional path to save generated
+                 datapoints during iteration. If None is provided, datapoints
+                 will be discarded every 100 generations.
+             data_path (Union[str, Path, None]): Optional path to a JSONL file
+                 to initialize the dataset from.
              **kwargs: Additional generator parameters.
          """
          self._rng = random.Random(seed)
+         self.cache = Path(cache) if cache else None

          self._data: List[DataPoint] = []
+         self._batch_to_save: List[DataPoint] = []
+
+         if data_path:
+             file_path = Path(data_path)
+             raw_data = self._init_from_jsonl(file_path)
+             try:
+                 data_points = [DataPoint(**item) for item in raw_data]
+                 self._data.extend(data_points)
+             except ValidationError as e:
+                 raise ValueError(
+                     f"Failed to create DataPoint from JSONL data: {e}"
+                 )

      @abc.abstractmethod
      async def generate_new(self, n: int, **kwargs) -> List[DataPoint]:
@@ -56,34 +84,112 @@ class BaseGenerator(abc.ABC):
          """
          pass

-     def __len__(self) -> int:
-         r"""Return the size of the generated dataset."""
-         return len(self._data)
+     def __aiter__(self):
+         r"""Async iterator that yields datapoints dynamically.

-     def __getitem__(self, idx: int) -> DataPoint:
-         r"""Retrieve a datapoint by index.
+         If a `data_path` was provided during initialization, those datapoints
+         are yielded first. When self._data is empty, 20 new datapoints
+         are generated. Every 100 yields, the batch is appended to the
+         JSONL file or discarded if `cache` is None.

-         Args:
-             idx (int): Index of the datapoint.
+         Yields:
+             DataPoint: A single datapoint.
+         """

-         Returns:
-             DataPoint: The datapoint corresponding to the given index.
+         async def generator():
+             while True:
+                 if not self._data:
+                     new_datapoints = await self.generate_new(20)
+                     self._data.extend(new_datapoints)
+                 datapoint = self._data.pop(0)
+                 yield datapoint
+                 self._batch_to_save.append(datapoint)
+                 if len(self._batch_to_save) == 100:
+                     if self.cache:
+                         with self.cache.open("a", encoding="utf-8") as f:
+                             for dp in self._batch_to_save:
+                                 json.dump(dp.to_dict(), f, ensure_ascii=False)
+                                 f.write("\n")
+                     self._batch_to_save = []

-         Raises:
-             IndexError: If idx is out of bounds.
+         return generator()
+
+     def __iter__(self):
+         r"""Synchronous iterator for PyTorch IterableDataset compatibility.
+
+         If a `data_path` was provided during initialization, those datapoints
+         are yielded first. When self._data is empty, 20 new datapoints
+         are generated. Every 100 yields, the batch is appended to the
+         JSONL file or discarded if `cache` is None.
+
+         Yields:
+             DataPoint: A single datapoint.
          """
-         if idx < 0 or idx >= len(self._data):
-             raise IndexError(
-                 f"Index {idx} out of bounds for dataset of "
-                 f"size {len(self._data)}"
-             )
-         return self._data[idx]
+         try:
+             if asyncio.get_event_loop().is_running():
+                 raise RuntimeError(
+                     "Cannot use synchronous iteration (__iter__) in an async "
+                     "context; use 'async for' with __aiter__ instead"
+                 )
+         except RuntimeError as e:
+             if "no running event loop" not in str(e):
+                 raise
+
+         while True:
+             if not self._data:
+                 new_datapoints = asyncio.run(self.generate_new(20))
+                 self._data.extend(new_datapoints)
+             datapoint = self._data.pop(0)
+             yield datapoint
+             self._batch_to_save.append(datapoint)
+             if len(self._batch_to_save) == 100:
+                 if self.cache:
+                     with self.cache.open("a", encoding="utf-8") as f:
+                         for dp in self._batch_to_save:
+                             json.dump(dp.to_dict(), f, ensure_ascii=False)
+                             f.write("\n")
+                 self._batch_to_save = []

      def sample(self) -> DataPoint:
-         if len(self._data) == 0:
-             raise RuntimeError("Dataset is empty, cannot sample.")
-         idx = self._rng.randint(0, len(self._data) - 1)
-         return self[idx]
+         r"""Returns the next datapoint from the current dataset
+         synchronously.
+
+         Raises:
+             RuntimeError: If called in an async context.
+
+         Returns:
+             DataPoint: The next DataPoint.
+
+         Note:
+             This method is intended for synchronous contexts.
+             Use 'async_sample' in asynchronous contexts to
+             avoid blocking or runtime errors.
+         """
+         try:
+             if asyncio.get_event_loop().is_running():
+                 raise RuntimeError(
+                     "Cannot use synchronous sampling (sample) "
+                     "in an async context; use async_sample instead"
+                 )
+         except RuntimeError as e:
+             if "no running event loop" not in str(e):
+                 raise
+
+         return next(iter(self))
+
+     async def async_sample(self) -> DataPoint:
+         r"""Returns the next datapoint from the current dataset asynchronously.
+
+         Returns:
+             DataPoint: The next datapoint.
+
+         Note:
+             This method is intended for asynchronous contexts. Use 'sample'
+             in synchronous contexts.
+         """
+
+         async_iter = self.__aiter__()
+         return await async_iter.__anext__()

      def save_to_jsonl(self, file_path: Union[str, Path]) -> None:
          r"""Saves the generated datapoints to a JSONL (JSON Lines) file.
@@ -99,7 +205,7 @@ class BaseGenerator(abc.ABC):

          Notes:
              - Uses `self._data`, which contains the generated datapoints.
-             - Overwrites the file if it already exists.
+             - Appends to the file if it already exists.
              - Ensures compatibility with large datasets by using JSONL format.
          """
          if not self._data:
@@ -108,11 +214,66 @@ class BaseGenerator(abc.ABC):
          file_path = Path(file_path)

          try:
-             with file_path.open("w", encoding="utf-8") as f:
+             with file_path.open("a", encoding="utf-8") as f:
                  for datapoint in self._data:
                      json.dump(datapoint.to_dict(), f, ensure_ascii=False)
-                     f.write("\n")  # Ensure each entry is on a new line
+                     f.write("\n")
              logger.info(f"Dataset saved successfully to {file_path}")
          except IOError as e:
              logger.error(f"Error writing to file {file_path}: {e}")
              raise
+
+     def flush(self, file_path: Union[str, Path]) -> None:
+         r"""Flush the current data to a JSONL file and clear the data.
+
+         Args:
+             file_path (Union[str, Path]): Path to save the JSONL file.
+
+         Notes:
+             - Uses `save_to_jsonl` to save `self._data`.
+         """
+
+         self.save_to_jsonl(file_path)
+         self._data = []
+         logger.info(f"Data flushed to {file_path} and cleared from the memory")
+
+     def _init_from_jsonl(self, file_path: Path) -> List[Dict[str, Any]]:
+         r"""Load and parse a dataset from a JSONL file.
+
+         Args:
+             file_path (Path): Path to the JSONL file.
+
+         Returns:
+             List[Dict[str, Any]]: A list of datapoint dictionaries.
+
+         Raises:
+             FileNotFoundError: If the specified JSONL file does not exist.
+             ValueError: If a line contains invalid JSON or is not a dictionary.
+         """
+         if not file_path.exists():
+             raise FileNotFoundError(f"JSONL file not found: {file_path}")
+
+         raw_data = []
+         logger.debug(f"Loading JSONL from {file_path}")
+         with file_path.open('r', encoding='utf-8') as f:
+             for line_number, line in enumerate(f, start=1):
+                 line = line.strip()
+                 if not line:
+                     continue  # Skip blank lines
+                 try:
+                     record = json.loads(line)
+                 except json.JSONDecodeError as e:
+                     raise ValueError(
+                         f"Invalid JSON on line {line_number} "
+                         f"in file {file_path}: {e}"
+                     )
+                 if not isinstance(record, dict):
+                     raise ValueError(
+                         f"Expected a dictionary at line {line_number}, "
+                         f"got {type(record).__name__}"
+                     )
+                 raw_data.append(record)
+         logger.info(
+             f"Successfully loaded {len(raw_data)} items from {file_path}"
+         )
+         return raw_data
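
Below is a minimal usage sketch of the streaming API introduced above. The `MyGenerator` subclass, the `DataPoint` field names, and the `camel.datasets` import path are illustrative assumptions, not something this diff guarantees.

import asyncio

from camel.datasets import BaseGenerator, DataPoint  # assumed import path


class MyGenerator(BaseGenerator):  # hypothetical subclass
    async def generate_new(self, n: int, **kwargs):
        # Assumed DataPoint fields; substitute the real schema.
        return [
            DataPoint(question=f"q{i}", rationale="because", final_answer="42")
            for i in range(n)
        ]


gen = MyGenerator(seed=0, cache="cache.jsonl")

# Synchronous, IterableDataset-style iteration: every 100 yielded datapoints
# are appended to cache.jsonl, or discarded when cache is None.
for _, dp in zip(range(3), gen):
    print(dp)


async def main():
    # Inside a running event loop, __iter__ and sample() raise RuntimeError;
    # use the async API instead.
    print(await gen.async_sample())


asyncio.run(main())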
@@ -33,6 +33,8 @@ class ChatHistoryMemory(AgentMemory):
          window_size (int, optional): The number of recent chat messages to
              retrieve. If not provided, the entire chat history will be
              retrieved. (default: :obj:`None`)
+         agent_id (str, optional): The ID of the agent associated with the chat
+             history.
      """

      def __init__(
@@ -40,6 +42,7 @@ class ChatHistoryMemory(AgentMemory):
          context_creator: BaseContextCreator,
          storage: Optional[BaseKeyValueStorage] = None,
          window_size: Optional[int] = None,
+         agent_id: Optional[str] = None,
      ) -> None:
          if window_size is not None and not isinstance(window_size, int):
              raise TypeError("`window_size` must be an integer or None.")
@@ -48,6 +51,15 @@ class ChatHistoryMemory(AgentMemory):
          self._context_creator = context_creator
          self._window_size = window_size
          self._chat_history_block = ChatHistoryBlock(storage=storage)
+         self._agent_id = agent_id
+
+     @property
+     def agent_id(self) -> Optional[str]:
+         return self._agent_id
+
+     @agent_id.setter
+     def agent_id(self, val: Optional[str]) -> None:
+         self._agent_id = val

      def retrieve(self) -> List[ContextRecord]:
          records = self._chat_history_block.retrieve(self._window_size)
@@ -63,6 +75,10 @@ class ChatHistoryMemory(AgentMemory):
          return records

      def write_records(self, records: List[MemoryRecord]) -> None:
+         for record in records:
+             # assign the agent_id to the record
+             if record.agent_id == "" and self.agent_id is not None:
+                 record.agent_id = self.agent_id
          self._chat_history_block.write_records(records)

      def get_context_creator(self) -> BaseContextCreator:
@@ -84,6 +100,8 @@ class VectorDBMemory(AgentMemory):
              (default: :obj:`None`)
          retrieve_limit (int, optional): The maximum number of messages
              to be added into the context. (default: :obj:`3`)
+         agent_id (str, optional): The ID of the agent associated with
+             the messages stored in the vector database.
      """

      def __init__(
@@ -91,13 +109,23 @@ class VectorDBMemory(AgentMemory):
          context_creator: BaseContextCreator,
          storage: Optional[BaseVectorStorage] = None,
          retrieve_limit: int = 3,
+         agent_id: Optional[str] = None,
      ) -> None:
          self._context_creator = context_creator
          self._retrieve_limit = retrieve_limit
          self._vectordb_block = VectorDBBlock(storage=storage)
+         self._agent_id = agent_id

          self._current_topic: str = ""

+     @property
+     def agent_id(self) -> Optional[str]:
+         return self._agent_id
+
+     @agent_id.setter
+     def agent_id(self, val: Optional[str]) -> None:
+         self._agent_id = val
+
      def retrieve(self) -> List[ContextRecord]:
          return self._vectordb_block.retrieve(
              self._current_topic,
@@ -109,6 +137,11 @@ class VectorDBMemory(AgentMemory):
          for record in records:
              if record.role_at_backend == OpenAIBackendRole.USER:
                  self._current_topic = record.message.content
+
+             # assign the agent_id to the record
+             if record.agent_id == "" and self.agent_id is not None:
+                 record.agent_id = self.agent_id
+
          self._vectordb_block.write_records(records)

      def get_context_creator(self) -> BaseContextCreator:
@@ -133,6 +166,8 @@ class LongtermAgentMemory(AgentMemory):
              (default: :obj:`None`)
          retrieve_limit (int, optional): The maximum number of messages
              to be added into the context. (default: :obj:`3`)
+         agent_id (str, optional): The ID of the agent associated with the chat
+             history and the messages stored in the vector database.
      """

      def __init__(
@@ -141,12 +176,22 @@ class LongtermAgentMemory(AgentMemory):
          chat_history_block: Optional[ChatHistoryBlock] = None,
          vector_db_block: Optional[VectorDBBlock] = None,
          retrieve_limit: int = 3,
+         agent_id: Optional[str] = None,
      ) -> None:
          self.chat_history_block = chat_history_block or ChatHistoryBlock()
          self.vector_db_block = vector_db_block or VectorDBBlock()
          self.retrieve_limit = retrieve_limit
          self._context_creator = context_creator
          self._current_topic: str = ""
+         self._agent_id = agent_id
+
+     @property
+     def agent_id(self) -> Optional[str]:
+         return self._agent_id
+
+     @agent_id.setter
+     def agent_id(self, val: Optional[str]) -> None:
+         self._agent_id = val

      def get_context_creator(self) -> BaseContextCreator:
          r"""Returns the context creator used by the memory.
@@ -166,7 +211,8 @@ class LongtermAgentMemory(AgentMemory):
          """
          chat_history = self.chat_history_block.retrieve()
          vector_db_retrieve = self.vector_db_block.retrieve(
-             self._current_topic, self.retrieve_limit
+             self._current_topic,
+             self.retrieve_limit,
          )
          return chat_history[:1] + vector_db_retrieve + chat_history[1:]

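A short sketch of how the new `agent_id` plumbing behaves once a memory is constructed with it. The context-creator and token-counter setup shown here is one plausible configuration, not something mandated by this diff.

from camel.memories import (
    ChatHistoryMemory,
    MemoryRecord,
    ScoreBasedContextCreator,
)
from camel.messages import BaseMessage
from camel.types import ModelType, OpenAIBackendRole
from camel.utils import OpenAITokenCounter

memory = ChatHistoryMemory(
    context_creator=ScoreBasedContextCreator(
        token_counter=OpenAITokenCounter(ModelType.GPT_4O_MINI),
        token_limit=1024,
    ),
    agent_id="agent_007",
)

record = MemoryRecord(
    message=BaseMessage.make_user_message(role_name="User", content="Hello"),
    role_at_backend=OpenAIBackendRole.USER,
)
memory.write_records([record])

# write_records fills in the memory's agent_id for records that have none.
print(record.agent_id)  # agent_007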
 
camel/memories/base.py CHANGED
@@ -13,7 +13,7 @@
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

  from abc import ABC, abstractmethod
- from typing import List, Tuple
+ from typing import List, Optional, Tuple

  from camel.memories.records import ContextRecord, MemoryRecord
  from camel.messages import OpenAIMessage
@@ -112,6 +112,16 @@ class AgentMemory(MemoryBlock, ABC):
      the memory records stored within the AgentMemory.
      """

+     @property
+     @abstractmethod
+     def agent_id(self) -> Optional[str]:
+         pass
+
+     @agent_id.setter
+     @abstractmethod
+     def agent_id(self, val: Optional[str]) -> None:
+         pass
+
      @abstractmethod
      def retrieve(self) -> List[ContextRecord]:
          r"""Get a record list from the memory for creating model context.
@@ -138,3 +148,15 @@ class AgentMemory(MemoryBlock, ABC):
              context in OpenAIMessage format and the total token count.
          """
          return self.get_context_creator().create_context(self.retrieve())
+
+     def __repr__(self) -> str:
+         r"""Returns a string representation of the AgentMemory.
+
+         Returns:
+             str: A string in the format 'ClassName(agent_id=<id>)'
+                 if agent_id exists, otherwise just 'ClassName()'.
+         """
+         agent_id = getattr(self, '_agent_id', None)
+         if agent_id:
+             return f"{self.__class__.__name__}(agent_id='{agent_id}')"
+         return f"{self.__class__.__name__}()"
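
Because `agent_id` is now an abstract property on `AgentMemory`, custom memory implementations must provide a getter and setter alongside the pre-existing abstract methods. A rough sketch of the minimum a hypothetical subclass now needs:

from typing import List, Optional

from camel.memories import (
    AgentMemory,
    BaseContextCreator,
    ContextRecord,
    MemoryRecord,
)


class ListBackedMemory(AgentMemory):  # hypothetical subclass
    def __init__(
        self,
        context_creator: BaseContextCreator,
        agent_id: Optional[str] = None,
    ) -> None:
        self._context_creator = context_creator
        self._records: List[MemoryRecord] = []
        self._agent_id = agent_id

    @property
    def agent_id(self) -> Optional[str]:
        return self._agent_id

    @agent_id.setter
    def agent_id(self, val: Optional[str]) -> None:
        self._agent_id = val

    # The remaining abstract methods were already required before this change.
    def retrieve(self) -> List[ContextRecord]:
        return [ContextRecord(memory_record=r, score=1.0) for r in self._records]

    def write_records(self, records: List[MemoryRecord]) -> None:
        self._records.extend(records)

    def get_context_creator(self) -> BaseContextCreator:
        return self._context_creator

    def clear(self) -> None:
        self._records = []

# The new __repr__ then reports the id, e.g. ListBackedMemory(agent_id='a-1').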
camel/memories/records.py CHANGED
@@ -39,6 +39,8 @@ class MemoryRecord(BaseModel):
              key-value pairs that provide more information. If not given, it
              will be an empty `Dict`.
          timestamp (float, optional): The timestamp when the record was created.
+         agent_id (str): The identifier of the agent associated with this
+             memory.
      """

      model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -50,6 +52,7 @@ class MemoryRecord(BaseModel):
      timestamp: float = Field(
          default_factory=lambda: datetime.now(timezone.utc).timestamp()
      )
+     agent_id: str = Field(default="")

      _MESSAGE_TYPES: ClassVar[dict] = {
          "BaseMessage": BaseMessage,
@@ -73,6 +76,7 @@ class MemoryRecord(BaseModel):
              role_at_backend=record_dict["role_at_backend"],
              extra_info=record_dict["extra_info"],
              timestamp=record_dict["timestamp"],
+             agent_id=record_dict["agent_id"],
          )

      def to_dict(self) -> Dict[str, Any]:
@@ -88,6 +92,7 @@ class MemoryRecord(BaseModel):
              "role_at_backend": self.role_at_backend,
              "extra_info": self.extra_info,
              "timestamp": self.timestamp,
+             "agent_id": self.agent_id,
          }

      def to_openai_message(self) -> OpenAIMessage:
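
A round-trip sketch for the new `agent_id` field. The `from_dict` name and the `BaseMessage` helper are the usual camel APIs, assumed here since the classmethod name is not visible in this hunk.

from camel.memories import MemoryRecord
from camel.messages import BaseMessage
from camel.types import OpenAIBackendRole

record = MemoryRecord(
    message=BaseMessage.make_assistant_message(
        role_name="Assistant", content="The answer is 42."
    ),
    role_at_backend=OpenAIBackendRole.ASSISTANT,
    agent_id="agent_007",
)

data = record.to_dict()
assert data["agent_id"] == "agent_007"

# Note that reconstruction now expects the agent_id key to be present.
restored = MemoryRecord.from_dict(data)
assert restored.agent_id == "agent_007"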
@@ -44,6 +44,31 @@ class StubTokenCounter(BaseTokenCounter):
          """
          return 10

+     def encode(self, text: str) -> List[int]:
+         r"""Encode text into token IDs for STUB models.
+
+         Args:
+             text (str): The text to encode.
+
+         Returns:
+             List[int]: List of token IDs.
+         """
+         # For stub models, just return a list of 0s with length proportional
+         # to text length
+         return [0] * (len(text) // 4 + 1)  # Simple approximation
+
+     def decode(self, token_ids: List[int]) -> str:
+         r"""Decode token IDs back to text for STUB models.
+
+         Args:
+             token_ids (List[int]): List of token IDs to decode.
+
+         Returns:
+             str: Decoded text.
+         """
+         # For stub models, return a placeholder string
+         return "[Stub decoded text]"
+

  class StubModel(BaseModelBackend):
      r"""A dummy model used for unit tests."""
@@ -27,6 +27,7 @@ from camel.storages import (
      VectorRecord,
  )
  from camel.utils import Constants
+ from camel.utils.chunker import BaseChunker, UnstructuredIOChunker

  if TYPE_CHECKING:
      from unstructured.documents.elements import Element
@@ -78,6 +79,7 @@ class VectorRetriever(BaseRetriever):
          should_chunk: bool = True,
          extra_info: Optional[dict] = None,
          metadata_filename: Optional[str] = None,
+         chunker: Optional[BaseChunker] = None,
          **kwargs: Any,
      ) -> None:
          r"""Processes content from local file path, remote URL, string
@@ -101,6 +103,12 @@ class VectorRetriever(BaseRetriever):
                  used for storing metadata. Defaults to None.
              **kwargs (Any): Additional keyword arguments for content parsing.
          """
+         if chunker is None:
+             chunker = UnstructuredIOChunker(
+                 chunk_type=chunk_type,
+                 max_characters=max_characters,
+                 metadata_filename=metadata_filename,
+             )
          from unstructured.documents.elements import Element

          if isinstance(content, Element):
@@ -140,13 +148,7 @@ class VectorRetriever(BaseRetriever):
          else:
              # Chunk the content if required
              chunks = (
-                 self.uio.chunk_elements(
-                     chunk_type=chunk_type,
-                     elements=elements,
-                     max_characters=max_characters,
-                 )
-                 if should_chunk
-                 else elements
+                 chunker.chunk(content=elements) if should_chunk else (elements)
              )

              # Process chunks in batches and store embeddings
@@ -157,6 +159,7 @@ class VectorRetriever(BaseRetriever):
                  )

                  records = []
+                 offset = 0
                  # Prepare the payload for each vector record, includes the
                  # content path, chunk metadata, and chunk text
                  for vector, chunk in zip(batch_vectors, batch_chunks):
@@ -178,6 +181,7 @@ class VectorRetriever(BaseRetriever):
                      chunk_metadata["metadata"].pop("orig_elements", "")
                      chunk_metadata["extra_info"] = extra_info or {}
                      chunk_text = {"text": str(chunk)}
+                     chunk_metadata["metadata"]["piece_num"] = i + offset + 1
                      combined_dict = {
                          **content_path_info,
                          **chunk_metadata,
@@ -187,6 +191,7 @@ class VectorRetriever(BaseRetriever):
                      records.append(
                          VectorRecord(vector=vector, payload=combined_dict)
                      )
+                     offset += 1

                  self.storage.add(records=records)

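A sketch of injecting a custom chunker into `process`, assuming the retriever's default embedding model and vector storage; the URL and chunking parameters are illustrative.

from camel.retrievers import VectorRetriever
from camel.utils.chunker import UnstructuredIOChunker

retriever = VectorRetriever()  # default embedding model and vector storage

# Any BaseChunker works here; when chunker is omitted, process() now builds
# an UnstructuredIOChunker from chunk_type / max_characters, as shown above.
chunker = UnstructuredIOChunker(chunk_type="chunk_by_title", max_characters=300)

retriever.process(
    content="https://example.com/article.html",
    chunker=chunker,
)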
 
@@ -14,7 +14,7 @@

  from .base import BaseKeyValueStorage
  from .in_memory import InMemoryKeyValueStorage
- from .json import JsonStorage
+ from .json import CamelJSONEncoder, JsonStorage
  from .redis import RedisStorage

  __all__ = [
@@ -22,4 +22,5 @@ __all__ = [
      'InMemoryKeyValueStorage',
      'JsonStorage',
      'RedisStorage',
+     'CamelJSONEncoder',
  ]
@@ -26,7 +26,7 @@ from camel.types import (
  )


- class _CamelJSONEncoder(json.JSONEncoder):
+ class CamelJSONEncoder(json.JSONEncoder):
      r"""A custom JSON encoder for serializing specifically enumerated types.
      Ensures enumerated types can be stored in and retrieved from JSON format.
      """
@@ -62,7 +62,7 @@ class JsonStorage(BaseKeyValueStorage):
      def _json_object_hook(self, d) -> Any:
          if "__enum__" in d:
              name, member = d["__enum__"].split(".")
-             return getattr(_CamelJSONEncoder.CAMEL_ENUMS[name], member)
+             return getattr(CamelJSONEncoder.CAMEL_ENUMS[name], member)
          else:
              return d

@@ -75,11 +75,7 @@ class JsonStorage(BaseKeyValueStorage):
          """
          with self.json_path.open("a") as f:
              f.writelines(
-                 [
-                     json.dumps(r, cls=_CamelJSONEncoder, ensure_ascii=False)
-                     + "\n"
-                     for r in records
-                 ]
+                 [json.dumps(r, cls=CamelJSONEncoder) + "\n" for r in records]
              )

      def load(self) -> List[Dict[str, Any]]:
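
With the encoder renamed and re-exported, callers can reuse it directly. This sketch assumes `OpenAIBackendRole` is among the enums registered in `CAMEL_ENUMS`.

import json

from camel.storages import CamelJSONEncoder
from camel.types import OpenAIBackendRole

payload = {"role_at_backend": OpenAIBackendRole.USER}

# Enum members are serialized via the encoder's __enum__ convention, which
# JsonStorage._json_object_hook can later turn back into enum members.
print(json.dumps(payload, cls=CamelJSONEncoder))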
@@ -87,7 +87,11 @@ class VectorDBQueryResult(BaseModel):
      ) -> "VectorDBQueryResult":
          r"""A class method to construct a `VectorDBQueryResult` instance."""
          return cls(
-             record=VectorRecord(vector=vector, id=id, payload=payload),
+             record=VectorRecord(
+                 vector=vector,
+                 id=id,
+                 payload=payload,
+             ),
              similarity=similarity,
          )

@@ -50,6 +50,7 @@ from .semantic_scholar_toolkit import SemanticScholarToolkit
  from .zapier_toolkit import ZapierToolkit
  from .sympy_toolkit import SymPyToolkit
  from .mineru_toolkit import MinerUToolkit
+ from .memory_toolkit import MemoryToolkit
  from .audio_analysis_toolkit import AudioAnalysisToolkit
  from .excel_toolkit import ExcelToolkit
  from .video_analysis_toolkit import VideoAnalysisToolkit
@@ -60,7 +61,6 @@ from .file_write_toolkit import FileWriteToolkit
  from .terminal_toolkit import TerminalToolkit
  from .pubmed_toolkit import PubMedToolkit

-
  __all__ = [
      'BaseToolkit',
      'FunctionTool',
@@ -97,6 +97,7 @@ __all__ = [
      'ZapierToolkit',
      'SymPyToolkit',
      'MinerUToolkit',
+     'MemoryToolkit',
      'MCPToolkit',
      'MCPToolkitManager',
      'AudioAnalysisToolkit',