camel-ai 0.2.34__py3-none-any.whl → 0.2.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/_types.py +1 -1
- camel/agents/_utils.py +4 -4
- camel/agents/chat_agent.py +174 -29
- camel/configs/openai_config.py +20 -16
- camel/datasets/base_generator.py +188 -27
- camel/memories/agent_memories.py +47 -1
- camel/memories/base.py +23 -1
- camel/memories/records.py +5 -0
- camel/models/stub_model.py +25 -0
- camel/retrievers/vector_retriever.py +12 -7
- camel/storages/key_value_storages/__init__.py +2 -1
- camel/storages/key_value_storages/json.py +3 -7
- camel/storages/vectordb_storages/base.py +5 -1
- camel/toolkits/__init__.py +2 -1
- camel/toolkits/memory_toolkit.py +129 -0
- camel/utils/chunker/__init__.py +22 -0
- camel/utils/chunker/base.py +24 -0
- camel/utils/chunker/code_chunker.py +193 -0
- camel/utils/chunker/uio_chunker.py +66 -0
- camel/utils/token_counting.py +133 -0
- {camel_ai-0.2.34.dist-info → camel_ai-0.2.35.dist-info}/METADATA +1 -1
- {camel_ai-0.2.34.dist-info → camel_ai-0.2.35.dist-info}/RECORD +25 -20
- {camel_ai-0.2.34.dist-info → camel_ai-0.2.35.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.34.dist-info → camel_ai-0.2.35.dist-info}/licenses/LICENSE +0 -0
camel/datasets/base_generator.py
CHANGED
|
@@ -13,10 +13,14 @@
|
|
|
13
13
|
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
14
|
|
|
15
15
|
import abc
|
|
16
|
+
import asyncio
|
|
16
17
|
import json
|
|
17
18
|
import random
|
|
18
19
|
from pathlib import Path
|
|
19
|
-
from typing import List, Union
|
|
20
|
+
from typing import Any, Dict, List, Union
|
|
21
|
+
|
|
22
|
+
from pydantic import ValidationError
|
|
23
|
+
from torch.utils.data import IterableDataset
|
|
20
24
|
|
|
21
25
|
from camel.logger import get_logger
|
|
22
26
|
|
|
@@ -25,23 +29,47 @@ from .models import DataPoint
|
|
|
25
29
|
logger = get_logger(__name__)
|
|
26
30
|
|
|
27
31
|
|
|
28
|
-
class BaseGenerator(abc.ABC):
|
|
32
|
+
class BaseGenerator(abc.ABC, IterableDataset):
|
|
29
33
|
r"""Abstract base class for data generators.
|
|
30
34
|
|
|
31
35
|
This class defines the interface for generating synthetic datapoints.
|
|
32
36
|
Concrete implementations should provide specific generation strategies.
|
|
33
37
|
"""
|
|
34
38
|
|
|
35
|
-
def __init__(
|
|
39
|
+
def __init__(
|
|
40
|
+
self,
|
|
41
|
+
seed: int = 42,
|
|
42
|
+
cache: Union[str, Path, None] = None,
|
|
43
|
+
data_path: Union[str, Path, None] = None,
|
|
44
|
+
**kwargs,
|
|
45
|
+
):
|
|
36
46
|
r"""Initialize the base generator.
|
|
37
47
|
|
|
38
48
|
Args:
|
|
39
49
|
seed (int): Random seed for reproducibility. (default: :obj:`42`)
|
|
50
|
+
cache (Union[str, Path, None]): Optional path to save generated
|
|
51
|
+
datapoints during iteration. If None is provided, datapoints
|
|
52
|
+
will be discarded every 100 generations.
|
|
53
|
+
data_path (Union[str, Path, None]): Optional path to a JSONL file
|
|
54
|
+
to initialize the dataset from.
|
|
40
55
|
**kwargs: Additional generator parameters.
|
|
41
56
|
"""
|
|
42
57
|
self._rng = random.Random(seed)
|
|
58
|
+
self.cache = Path(cache) if cache else None
|
|
43
59
|
|
|
44
60
|
self._data: List[DataPoint] = []
|
|
61
|
+
self._batch_to_save: List[DataPoint] = []
|
|
62
|
+
|
|
63
|
+
if data_path:
|
|
64
|
+
file_path = Path(data_path)
|
|
65
|
+
raw_data = self._init_from_jsonl(file_path)
|
|
66
|
+
try:
|
|
67
|
+
data_points = [DataPoint(**item) for item in raw_data]
|
|
68
|
+
self._data.extend(data_points)
|
|
69
|
+
except ValidationError as e:
|
|
70
|
+
raise ValueError(
|
|
71
|
+
f"Failed to create DataPoint from JSONL data: {e}"
|
|
72
|
+
)
|
|
45
73
|
|
|
46
74
|
@abc.abstractmethod
|
|
47
75
|
async def generate_new(self, n: int, **kwargs) -> List[DataPoint]:
|
|
@@ -56,34 +84,112 @@ class BaseGenerator(abc.ABC):
|
|
|
56
84
|
"""
|
|
57
85
|
pass
|
|
58
86
|
|
|
59
|
-
def
|
|
60
|
-
r"""
|
|
61
|
-
return len(self._data)
|
|
87
|
+
def __aiter__(self):
|
|
88
|
+
r"""Async iterator that yields datapoints dynamically.
|
|
62
89
|
|
|
63
|
-
|
|
64
|
-
|
|
90
|
+
If a `data_path` was provided during initialization, those datapoints
|
|
91
|
+
are yielded first. When self._data is empty, 20 new datapoints
|
|
92
|
+
are generated. Every 100 yields, the batch is appended to the
|
|
93
|
+
JSONL file or discarded if `cache` is None.
|
|
65
94
|
|
|
66
|
-
|
|
67
|
-
|
|
95
|
+
Yields:
|
|
96
|
+
DataPoint: A single datapoint.
|
|
97
|
+
"""
|
|
68
98
|
|
|
69
|
-
|
|
70
|
-
|
|
99
|
+
async def generator():
|
|
100
|
+
while True:
|
|
101
|
+
if not self._data:
|
|
102
|
+
new_datapoints = await self.generate_new(20)
|
|
103
|
+
self._data.extend(new_datapoints)
|
|
104
|
+
datapoint = self._data.pop(0)
|
|
105
|
+
yield datapoint
|
|
106
|
+
self._batch_to_save.append(datapoint)
|
|
107
|
+
if len(self._batch_to_save) == 100:
|
|
108
|
+
if self.cache:
|
|
109
|
+
with self.cache.open("a", encoding="utf-8") as f:
|
|
110
|
+
for dp in self._batch_to_save:
|
|
111
|
+
json.dump(dp.to_dict(), f, ensure_ascii=False)
|
|
112
|
+
f.write("\n")
|
|
113
|
+
self._batch_to_save = []
|
|
71
114
|
|
|
72
|
-
|
|
73
|
-
|
|
115
|
+
return generator()
|
|
116
|
+
|
|
117
|
+
def __iter__(self):
|
|
118
|
+
r"""Synchronous iterator for PyTorch IterableDataset compatibility.
|
|
119
|
+
|
|
120
|
+
If a `data_path` was provided during initialization, those datapoints
|
|
121
|
+
are yielded first. When self._data is empty, 20 new datapoints
|
|
122
|
+
are generated. Every 100 yields, the batch is appended to the
|
|
123
|
+
JSONL file or discarded if `cache` is None.
|
|
124
|
+
|
|
125
|
+
Yields:
|
|
126
|
+
DataPoint: A single datapoint.
|
|
74
127
|
"""
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
128
|
+
try:
|
|
129
|
+
if asyncio.get_event_loop().is_running():
|
|
130
|
+
raise RuntimeError(
|
|
131
|
+
"Cannot use synchronous iteration (__iter__) in an async "
|
|
132
|
+
"context; use 'async for' with __aiter__ instead"
|
|
133
|
+
)
|
|
134
|
+
except RuntimeError as e:
|
|
135
|
+
if "no running event loop" not in str(e):
|
|
136
|
+
raise
|
|
137
|
+
|
|
138
|
+
while True:
|
|
139
|
+
if not self._data:
|
|
140
|
+
new_datapoints = asyncio.run(self.generate_new(20))
|
|
141
|
+
self._data.extend(new_datapoints)
|
|
142
|
+
datapoint = self._data.pop(0)
|
|
143
|
+
yield datapoint
|
|
144
|
+
self._batch_to_save.append(datapoint)
|
|
145
|
+
if len(self._batch_to_save) == 100:
|
|
146
|
+
if self.cache:
|
|
147
|
+
with self.cache.open("a", encoding="utf-8") as f:
|
|
148
|
+
for dp in self._batch_to_save:
|
|
149
|
+
json.dump(dp.to_dict(), f, ensure_ascii=False)
|
|
150
|
+
f.write("\n")
|
|
151
|
+
self._batch_to_save = []
|
|
81
152
|
|
|
82
153
|
def sample(self) -> DataPoint:
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
154
|
+
r"""Returns the next datapoint from the current dataset
|
|
155
|
+
synchronously.
|
|
156
|
+
|
|
157
|
+
Raises:
|
|
158
|
+
RuntimeError: If called in an async context.
|
|
159
|
+
|
|
160
|
+
Returns:
|
|
161
|
+
DataPoint: The next DataPoint.
|
|
162
|
+
|
|
163
|
+
Note:
|
|
164
|
+
This method is intended for synchronous contexts.
|
|
165
|
+
Use 'async_sample' in asynchronous contexts to
|
|
166
|
+
avoid blocking or runtime errors.
|
|
167
|
+
"""
|
|
168
|
+
try:
|
|
169
|
+
if asyncio.get_event_loop().is_running():
|
|
170
|
+
raise RuntimeError(
|
|
171
|
+
"Cannot use synchronous sampling (sample) "
|
|
172
|
+
"in an async context; use async_sample instead"
|
|
173
|
+
)
|
|
174
|
+
except RuntimeError as e:
|
|
175
|
+
if "no running event loop" not in str(e):
|
|
176
|
+
raise
|
|
177
|
+
|
|
178
|
+
return next(iter(self))
|
|
179
|
+
|
|
180
|
+
async def async_sample(self) -> DataPoint:
|
|
181
|
+
r"""Returns the next datapoint from the current dataset asynchronously.
|
|
182
|
+
|
|
183
|
+
Returns:
|
|
184
|
+
DataPoint: The next datapoint.
|
|
185
|
+
|
|
186
|
+
Note:
|
|
187
|
+
This method is intended for asynchronous contexts. Use 'sample'
|
|
188
|
+
in synchronous contexts.
|
|
189
|
+
"""
|
|
190
|
+
|
|
191
|
+
async_iter = self.__aiter__()
|
|
192
|
+
return await async_iter.__anext__()
|
|
87
193
|
|
|
88
194
|
def save_to_jsonl(self, file_path: Union[str, Path]) -> None:
|
|
89
195
|
r"""Saves the generated datapoints to a JSONL (JSON Lines) file.
|
|
@@ -99,7 +205,7 @@ class BaseGenerator(abc.ABC):
|
|
|
99
205
|
|
|
100
206
|
Notes:
|
|
101
207
|
- Uses `self._data`, which contains the generated datapoints.
|
|
102
|
-
-
|
|
208
|
+
- Appends to the file if it already exists.
|
|
103
209
|
- Ensures compatibility with large datasets by using JSONL format.
|
|
104
210
|
"""
|
|
105
211
|
if not self._data:
|
|
@@ -108,11 +214,66 @@ class BaseGenerator(abc.ABC):
|
|
|
108
214
|
file_path = Path(file_path)
|
|
109
215
|
|
|
110
216
|
try:
|
|
111
|
-
with file_path.open("
|
|
217
|
+
with file_path.open("a", encoding="utf-8") as f:
|
|
112
218
|
for datapoint in self._data:
|
|
113
219
|
json.dump(datapoint.to_dict(), f, ensure_ascii=False)
|
|
114
|
-
f.write("\n")
|
|
220
|
+
f.write("\n")
|
|
115
221
|
logger.info(f"Dataset saved successfully to {file_path}")
|
|
116
222
|
except IOError as e:
|
|
117
223
|
logger.error(f"Error writing to file {file_path}: {e}")
|
|
118
224
|
raise
|
|
225
|
+
|
|
226
|
+
def flush(self, file_path: Union[str, Path]) -> None:
|
|
227
|
+
r"""Flush the current data to a JSONL file and clear the data.
|
|
228
|
+
|
|
229
|
+
Args:
|
|
230
|
+
file_path (Union[str, Path]): Path to save the JSONL file.
|
|
231
|
+
|
|
232
|
+
Notes:
|
|
233
|
+
- Uses `save_to_jsonl` to save `self._data`.
|
|
234
|
+
"""
|
|
235
|
+
|
|
236
|
+
self.save_to_jsonl(file_path)
|
|
237
|
+
self._data = []
|
|
238
|
+
logger.info(f"Data flushed to {file_path} and cleared from the memory")
|
|
239
|
+
|
|
240
|
+
def _init_from_jsonl(self, file_path: Path) -> List[Dict[str, Any]]:
|
|
241
|
+
r"""Load and parse a dataset from a JSONL file.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
file_path (Path): Path to the JSONL file.
|
|
245
|
+
|
|
246
|
+
Returns:
|
|
247
|
+
List[Dict[str, Any]]: A list of datapoint dictionaries.
|
|
248
|
+
|
|
249
|
+
Raises:
|
|
250
|
+
FileNotFoundError: If the specified JSONL file does not exist.
|
|
251
|
+
ValueError: If a line contains invalid JSON or is not a dictionary.
|
|
252
|
+
"""
|
|
253
|
+
if not file_path.exists():
|
|
254
|
+
raise FileNotFoundError(f"JSONL file not found: {file_path}")
|
|
255
|
+
|
|
256
|
+
raw_data = []
|
|
257
|
+
logger.debug(f"Loading JSONL from {file_path}")
|
|
258
|
+
with file_path.open('r', encoding='utf-8') as f:
|
|
259
|
+
for line_number, line in enumerate(f, start=1):
|
|
260
|
+
line = line.strip()
|
|
261
|
+
if not line:
|
|
262
|
+
continue # Skip blank lines
|
|
263
|
+
try:
|
|
264
|
+
record = json.loads(line)
|
|
265
|
+
except json.JSONDecodeError as e:
|
|
266
|
+
raise ValueError(
|
|
267
|
+
f"Invalid JSON on line {line_number} "
|
|
268
|
+
f"in file {file_path}: {e}"
|
|
269
|
+
)
|
|
270
|
+
if not isinstance(record, dict):
|
|
271
|
+
raise ValueError(
|
|
272
|
+
f"Expected a dictionary at line {line_number}, "
|
|
273
|
+
f"got {type(record).__name__}"
|
|
274
|
+
)
|
|
275
|
+
raw_data.append(record)
|
|
276
|
+
logger.info(
|
|
277
|
+
f"Successfully loaded {len(raw_data)} items from {file_path}"
|
|
278
|
+
)
|
|
279
|
+
return raw_data
|
camel/memories/agent_memories.py
CHANGED
|
@@ -33,6 +33,8 @@ class ChatHistoryMemory(AgentMemory):
|
|
|
33
33
|
window_size (int, optional): The number of recent chat messages to
|
|
34
34
|
retrieve. If not provided, the entire chat history will be
|
|
35
35
|
retrieved. (default: :obj:`None`)
|
|
36
|
+
agent_id (str, optional): The ID of the agent associated with the chat
|
|
37
|
+
history.
|
|
36
38
|
"""
|
|
37
39
|
|
|
38
40
|
def __init__(
|
|
@@ -40,6 +42,7 @@ class ChatHistoryMemory(AgentMemory):
|
|
|
40
42
|
context_creator: BaseContextCreator,
|
|
41
43
|
storage: Optional[BaseKeyValueStorage] = None,
|
|
42
44
|
window_size: Optional[int] = None,
|
|
45
|
+
agent_id: Optional[str] = None,
|
|
43
46
|
) -> None:
|
|
44
47
|
if window_size is not None and not isinstance(window_size, int):
|
|
45
48
|
raise TypeError("`window_size` must be an integer or None.")
|
|
@@ -48,6 +51,15 @@ class ChatHistoryMemory(AgentMemory):
|
|
|
48
51
|
self._context_creator = context_creator
|
|
49
52
|
self._window_size = window_size
|
|
50
53
|
self._chat_history_block = ChatHistoryBlock(storage=storage)
|
|
54
|
+
self._agent_id = agent_id
|
|
55
|
+
|
|
56
|
+
@property
|
|
57
|
+
def agent_id(self) -> Optional[str]:
|
|
58
|
+
return self._agent_id
|
|
59
|
+
|
|
60
|
+
@agent_id.setter
|
|
61
|
+
def agent_id(self, val: Optional[str]) -> None:
|
|
62
|
+
self._agent_id = val
|
|
51
63
|
|
|
52
64
|
def retrieve(self) -> List[ContextRecord]:
|
|
53
65
|
records = self._chat_history_block.retrieve(self._window_size)
|
|
@@ -63,6 +75,10 @@ class ChatHistoryMemory(AgentMemory):
|
|
|
63
75
|
return records
|
|
64
76
|
|
|
65
77
|
def write_records(self, records: List[MemoryRecord]) -> None:
|
|
78
|
+
for record in records:
|
|
79
|
+
# assign the agent_id to the record
|
|
80
|
+
if record.agent_id == "" and self.agent_id is not None:
|
|
81
|
+
record.agent_id = self.agent_id
|
|
66
82
|
self._chat_history_block.write_records(records)
|
|
67
83
|
|
|
68
84
|
def get_context_creator(self) -> BaseContextCreator:
|
|
@@ -84,6 +100,8 @@ class VectorDBMemory(AgentMemory):
|
|
|
84
100
|
(default: :obj:`None`)
|
|
85
101
|
retrieve_limit (int, optional): The maximum number of messages
|
|
86
102
|
to be added into the context. (default: :obj:`3`)
|
|
103
|
+
agent_id (str, optional): The ID of the agent associated with
|
|
104
|
+
the messages stored in the vector database.
|
|
87
105
|
"""
|
|
88
106
|
|
|
89
107
|
def __init__(
|
|
@@ -91,13 +109,23 @@ class VectorDBMemory(AgentMemory):
|
|
|
91
109
|
context_creator: BaseContextCreator,
|
|
92
110
|
storage: Optional[BaseVectorStorage] = None,
|
|
93
111
|
retrieve_limit: int = 3,
|
|
112
|
+
agent_id: Optional[str] = None,
|
|
94
113
|
) -> None:
|
|
95
114
|
self._context_creator = context_creator
|
|
96
115
|
self._retrieve_limit = retrieve_limit
|
|
97
116
|
self._vectordb_block = VectorDBBlock(storage=storage)
|
|
117
|
+
self._agent_id = agent_id
|
|
98
118
|
|
|
99
119
|
self._current_topic: str = ""
|
|
100
120
|
|
|
121
|
+
@property
|
|
122
|
+
def agent_id(self) -> Optional[str]:
|
|
123
|
+
return self._agent_id
|
|
124
|
+
|
|
125
|
+
@agent_id.setter
|
|
126
|
+
def agent_id(self, val: Optional[str]) -> None:
|
|
127
|
+
self._agent_id = val
|
|
128
|
+
|
|
101
129
|
def retrieve(self) -> List[ContextRecord]:
|
|
102
130
|
return self._vectordb_block.retrieve(
|
|
103
131
|
self._current_topic,
|
|
@@ -109,6 +137,11 @@ class VectorDBMemory(AgentMemory):
|
|
|
109
137
|
for record in records:
|
|
110
138
|
if record.role_at_backend == OpenAIBackendRole.USER:
|
|
111
139
|
self._current_topic = record.message.content
|
|
140
|
+
|
|
141
|
+
# assign the agent_id to the record
|
|
142
|
+
if record.agent_id == "" and self.agent_id is not None:
|
|
143
|
+
record.agent_id = self.agent_id
|
|
144
|
+
|
|
112
145
|
self._vectordb_block.write_records(records)
|
|
113
146
|
|
|
114
147
|
def get_context_creator(self) -> BaseContextCreator:
|
|
@@ -133,6 +166,8 @@ class LongtermAgentMemory(AgentMemory):
|
|
|
133
166
|
(default: :obj:`None`)
|
|
134
167
|
retrieve_limit (int, optional): The maximum number of messages
|
|
135
168
|
to be added into the context. (default: :obj:`3`)
|
|
169
|
+
agent_id (str, optional): The ID of the agent associated with the chat
|
|
170
|
+
history and the messages stored in the vector database.
|
|
136
171
|
"""
|
|
137
172
|
|
|
138
173
|
def __init__(
|
|
@@ -141,12 +176,22 @@ class LongtermAgentMemory(AgentMemory):
|
|
|
141
176
|
chat_history_block: Optional[ChatHistoryBlock] = None,
|
|
142
177
|
vector_db_block: Optional[VectorDBBlock] = None,
|
|
143
178
|
retrieve_limit: int = 3,
|
|
179
|
+
agent_id: Optional[str] = None,
|
|
144
180
|
) -> None:
|
|
145
181
|
self.chat_history_block = chat_history_block or ChatHistoryBlock()
|
|
146
182
|
self.vector_db_block = vector_db_block or VectorDBBlock()
|
|
147
183
|
self.retrieve_limit = retrieve_limit
|
|
148
184
|
self._context_creator = context_creator
|
|
149
185
|
self._current_topic: str = ""
|
|
186
|
+
self._agent_id = agent_id
|
|
187
|
+
|
|
188
|
+
@property
|
|
189
|
+
def agent_id(self) -> Optional[str]:
|
|
190
|
+
return self._agent_id
|
|
191
|
+
|
|
192
|
+
@agent_id.setter
|
|
193
|
+
def agent_id(self, val: Optional[str]) -> None:
|
|
194
|
+
self._agent_id = val
|
|
150
195
|
|
|
151
196
|
def get_context_creator(self) -> BaseContextCreator:
|
|
152
197
|
r"""Returns the context creator used by the memory.
|
|
@@ -166,7 +211,8 @@ class LongtermAgentMemory(AgentMemory):
|
|
|
166
211
|
"""
|
|
167
212
|
chat_history = self.chat_history_block.retrieve()
|
|
168
213
|
vector_db_retrieve = self.vector_db_block.retrieve(
|
|
169
|
-
self._current_topic,
|
|
214
|
+
self._current_topic,
|
|
215
|
+
self.retrieve_limit,
|
|
170
216
|
)
|
|
171
217
|
return chat_history[:1] + vector_db_retrieve + chat_history[1:]
|
|
172
218
|
|
camel/memories/base.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
14
|
|
|
15
15
|
from abc import ABC, abstractmethod
|
|
16
|
-
from typing import List, Tuple
|
|
16
|
+
from typing import List, Optional, Tuple
|
|
17
17
|
|
|
18
18
|
from camel.memories.records import ContextRecord, MemoryRecord
|
|
19
19
|
from camel.messages import OpenAIMessage
|
|
@@ -112,6 +112,16 @@ class AgentMemory(MemoryBlock, ABC):
|
|
|
112
112
|
the memory records stored within the AgentMemory.
|
|
113
113
|
"""
|
|
114
114
|
|
|
115
|
+
@property
|
|
116
|
+
@abstractmethod
|
|
117
|
+
def agent_id(self) -> Optional[str]:
|
|
118
|
+
pass
|
|
119
|
+
|
|
120
|
+
@agent_id.setter
|
|
121
|
+
@abstractmethod
|
|
122
|
+
def agent_id(self, val: Optional[str]) -> None:
|
|
123
|
+
pass
|
|
124
|
+
|
|
115
125
|
@abstractmethod
|
|
116
126
|
def retrieve(self) -> List[ContextRecord]:
|
|
117
127
|
r"""Get a record list from the memory for creating model context.
|
|
@@ -138,3 +148,15 @@ class AgentMemory(MemoryBlock, ABC):
|
|
|
138
148
|
context in OpenAIMessage format and the total token count.
|
|
139
149
|
"""
|
|
140
150
|
return self.get_context_creator().create_context(self.retrieve())
|
|
151
|
+
|
|
152
|
+
def __repr__(self) -> str:
|
|
153
|
+
r"""Returns a string representation of the AgentMemory.
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
str: A string in the format 'ClassName(agent_id=<id>)'
|
|
157
|
+
if agent_id exists, otherwise just 'ClassName()'.
|
|
158
|
+
"""
|
|
159
|
+
agent_id = getattr(self, '_agent_id', None)
|
|
160
|
+
if agent_id:
|
|
161
|
+
return f"{self.__class__.__name__}(agent_id='{agent_id}')"
|
|
162
|
+
return f"{self.__class__.__name__}()"
|
camel/memories/records.py
CHANGED
|
@@ -39,6 +39,8 @@ class MemoryRecord(BaseModel):
|
|
|
39
39
|
key-value pairs that provide more information. If not given, it
|
|
40
40
|
will be an empty `Dict`.
|
|
41
41
|
timestamp (float, optional): The timestamp when the record was created.
|
|
42
|
+
agent_id (str): The identifier of the agent associated with this
|
|
43
|
+
memory.
|
|
42
44
|
"""
|
|
43
45
|
|
|
44
46
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
@@ -50,6 +52,7 @@ class MemoryRecord(BaseModel):
|
|
|
50
52
|
timestamp: float = Field(
|
|
51
53
|
default_factory=lambda: datetime.now(timezone.utc).timestamp()
|
|
52
54
|
)
|
|
55
|
+
agent_id: str = Field(default="")
|
|
53
56
|
|
|
54
57
|
_MESSAGE_TYPES: ClassVar[dict] = {
|
|
55
58
|
"BaseMessage": BaseMessage,
|
|
@@ -73,6 +76,7 @@ class MemoryRecord(BaseModel):
|
|
|
73
76
|
role_at_backend=record_dict["role_at_backend"],
|
|
74
77
|
extra_info=record_dict["extra_info"],
|
|
75
78
|
timestamp=record_dict["timestamp"],
|
|
79
|
+
agent_id=record_dict["agent_id"],
|
|
76
80
|
)
|
|
77
81
|
|
|
78
82
|
def to_dict(self) -> Dict[str, Any]:
|
|
@@ -88,6 +92,7 @@ class MemoryRecord(BaseModel):
|
|
|
88
92
|
"role_at_backend": self.role_at_backend,
|
|
89
93
|
"extra_info": self.extra_info,
|
|
90
94
|
"timestamp": self.timestamp,
|
|
95
|
+
"agent_id": self.agent_id,
|
|
91
96
|
}
|
|
92
97
|
|
|
93
98
|
def to_openai_message(self) -> OpenAIMessage:
|
camel/models/stub_model.py
CHANGED
|
@@ -44,6 +44,31 @@ class StubTokenCounter(BaseTokenCounter):
|
|
|
44
44
|
"""
|
|
45
45
|
return 10
|
|
46
46
|
|
|
47
|
+
def encode(self, text: str) -> List[int]:
|
|
48
|
+
r"""Encode text into token IDs for STUB models.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
text (str): The text to encode.
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
List[int]: List of token IDs.
|
|
55
|
+
"""
|
|
56
|
+
# For stub models, just return a list of 0s with length proportional
|
|
57
|
+
# to text length
|
|
58
|
+
return [0] * (len(text) // 4 + 1) # Simple approximation
|
|
59
|
+
|
|
60
|
+
def decode(self, token_ids: List[int]) -> str:
|
|
61
|
+
r"""Decode token IDs back to text for STUB models.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
token_ids (List[int]): List of token IDs to decode.
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
str: Decoded text.
|
|
68
|
+
"""
|
|
69
|
+
# For stub models, return a placeholder string
|
|
70
|
+
return "[Stub decoded text]"
|
|
71
|
+
|
|
47
72
|
|
|
48
73
|
class StubModel(BaseModelBackend):
|
|
49
74
|
r"""A dummy model used for unit tests."""
|
|
@@ -27,6 +27,7 @@ from camel.storages import (
|
|
|
27
27
|
VectorRecord,
|
|
28
28
|
)
|
|
29
29
|
from camel.utils import Constants
|
|
30
|
+
from camel.utils.chunker import BaseChunker, UnstructuredIOChunker
|
|
30
31
|
|
|
31
32
|
if TYPE_CHECKING:
|
|
32
33
|
from unstructured.documents.elements import Element
|
|
@@ -78,6 +79,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
78
79
|
should_chunk: bool = True,
|
|
79
80
|
extra_info: Optional[dict] = None,
|
|
80
81
|
metadata_filename: Optional[str] = None,
|
|
82
|
+
chunker: Optional[BaseChunker] = None,
|
|
81
83
|
**kwargs: Any,
|
|
82
84
|
) -> None:
|
|
83
85
|
r"""Processes content from local file path, remote URL, string
|
|
@@ -101,6 +103,12 @@ class VectorRetriever(BaseRetriever):
|
|
|
101
103
|
used for storing metadata. Defaults to None.
|
|
102
104
|
**kwargs (Any): Additional keyword arguments for content parsing.
|
|
103
105
|
"""
|
|
106
|
+
if chunker is None:
|
|
107
|
+
chunker = UnstructuredIOChunker(
|
|
108
|
+
chunk_type=chunk_type,
|
|
109
|
+
max_characters=max_characters,
|
|
110
|
+
metadata_filename=metadata_filename,
|
|
111
|
+
)
|
|
104
112
|
from unstructured.documents.elements import Element
|
|
105
113
|
|
|
106
114
|
if isinstance(content, Element):
|
|
@@ -140,13 +148,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
140
148
|
else:
|
|
141
149
|
# Chunk the content if required
|
|
142
150
|
chunks = (
|
|
143
|
-
|
|
144
|
-
chunk_type=chunk_type,
|
|
145
|
-
elements=elements,
|
|
146
|
-
max_characters=max_characters,
|
|
147
|
-
)
|
|
148
|
-
if should_chunk
|
|
149
|
-
else elements
|
|
151
|
+
chunker.chunk(content=elements) if should_chunk else (elements)
|
|
150
152
|
)
|
|
151
153
|
|
|
152
154
|
# Process chunks in batches and store embeddings
|
|
@@ -157,6 +159,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
157
159
|
)
|
|
158
160
|
|
|
159
161
|
records = []
|
|
162
|
+
offset = 0
|
|
160
163
|
# Prepare the payload for each vector record, includes the
|
|
161
164
|
# content path, chunk metadata, and chunk text
|
|
162
165
|
for vector, chunk in zip(batch_vectors, batch_chunks):
|
|
@@ -178,6 +181,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
178
181
|
chunk_metadata["metadata"].pop("orig_elements", "")
|
|
179
182
|
chunk_metadata["extra_info"] = extra_info or {}
|
|
180
183
|
chunk_text = {"text": str(chunk)}
|
|
184
|
+
chunk_metadata["metadata"]["piece_num"] = i + offset + 1
|
|
181
185
|
combined_dict = {
|
|
182
186
|
**content_path_info,
|
|
183
187
|
**chunk_metadata,
|
|
@@ -187,6 +191,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
187
191
|
records.append(
|
|
188
192
|
VectorRecord(vector=vector, payload=combined_dict)
|
|
189
193
|
)
|
|
194
|
+
offset += 1
|
|
190
195
|
|
|
191
196
|
self.storage.add(records=records)
|
|
192
197
|
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
from .base import BaseKeyValueStorage
|
|
16
16
|
from .in_memory import InMemoryKeyValueStorage
|
|
17
|
-
from .json import JsonStorage
|
|
17
|
+
from .json import CamelJSONEncoder, JsonStorage
|
|
18
18
|
from .redis import RedisStorage
|
|
19
19
|
|
|
20
20
|
__all__ = [
|
|
@@ -22,4 +22,5 @@ __all__ = [
|
|
|
22
22
|
'InMemoryKeyValueStorage',
|
|
23
23
|
'JsonStorage',
|
|
24
24
|
'RedisStorage',
|
|
25
|
+
'CamelJSONEncoder',
|
|
25
26
|
]
|
|
@@ -26,7 +26,7 @@ from camel.types import (
|
|
|
26
26
|
)
|
|
27
27
|
|
|
28
28
|
|
|
29
|
-
class
|
|
29
|
+
class CamelJSONEncoder(json.JSONEncoder):
|
|
30
30
|
r"""A custom JSON encoder for serializing specifically enumerated types.
|
|
31
31
|
Ensures enumerated types can be stored in and retrieved from JSON format.
|
|
32
32
|
"""
|
|
@@ -62,7 +62,7 @@ class JsonStorage(BaseKeyValueStorage):
|
|
|
62
62
|
def _json_object_hook(self, d) -> Any:
|
|
63
63
|
if "__enum__" in d:
|
|
64
64
|
name, member = d["__enum__"].split(".")
|
|
65
|
-
return getattr(
|
|
65
|
+
return getattr(CamelJSONEncoder.CAMEL_ENUMS[name], member)
|
|
66
66
|
else:
|
|
67
67
|
return d
|
|
68
68
|
|
|
@@ -75,11 +75,7 @@ class JsonStorage(BaseKeyValueStorage):
|
|
|
75
75
|
"""
|
|
76
76
|
with self.json_path.open("a") as f:
|
|
77
77
|
f.writelines(
|
|
78
|
-
[
|
|
79
|
-
json.dumps(r, cls=_CamelJSONEncoder, ensure_ascii=False)
|
|
80
|
-
+ "\n"
|
|
81
|
-
for r in records
|
|
82
|
-
]
|
|
78
|
+
[json.dumps(r, cls=CamelJSONEncoder) + "\n" for r in records]
|
|
83
79
|
)
|
|
84
80
|
|
|
85
81
|
def load(self) -> List[Dict[str, Any]]:
|
|
@@ -87,7 +87,11 @@ class VectorDBQueryResult(BaseModel):
|
|
|
87
87
|
) -> "VectorDBQueryResult":
|
|
88
88
|
r"""A class method to construct a `VectorDBQueryResult` instance."""
|
|
89
89
|
return cls(
|
|
90
|
-
record=VectorRecord(
|
|
90
|
+
record=VectorRecord(
|
|
91
|
+
vector=vector,
|
|
92
|
+
id=id,
|
|
93
|
+
payload=payload,
|
|
94
|
+
),
|
|
91
95
|
similarity=similarity,
|
|
92
96
|
)
|
|
93
97
|
|
camel/toolkits/__init__.py
CHANGED
|
@@ -50,6 +50,7 @@ from .semantic_scholar_toolkit import SemanticScholarToolkit
|
|
|
50
50
|
from .zapier_toolkit import ZapierToolkit
|
|
51
51
|
from .sympy_toolkit import SymPyToolkit
|
|
52
52
|
from .mineru_toolkit import MinerUToolkit
|
|
53
|
+
from .memory_toolkit import MemoryToolkit
|
|
53
54
|
from .audio_analysis_toolkit import AudioAnalysisToolkit
|
|
54
55
|
from .excel_toolkit import ExcelToolkit
|
|
55
56
|
from .video_analysis_toolkit import VideoAnalysisToolkit
|
|
@@ -60,7 +61,6 @@ from .file_write_toolkit import FileWriteToolkit
|
|
|
60
61
|
from .terminal_toolkit import TerminalToolkit
|
|
61
62
|
from .pubmed_toolkit import PubMedToolkit
|
|
62
63
|
|
|
63
|
-
|
|
64
64
|
__all__ = [
|
|
65
65
|
'BaseToolkit',
|
|
66
66
|
'FunctionTool',
|
|
@@ -97,6 +97,7 @@ __all__ = [
|
|
|
97
97
|
'ZapierToolkit',
|
|
98
98
|
'SymPyToolkit',
|
|
99
99
|
'MinerUToolkit',
|
|
100
|
+
'MemoryToolkit',
|
|
100
101
|
'MCPToolkit',
|
|
101
102
|
'MCPToolkitManager',
|
|
102
103
|
'AudioAnalysisToolkit',
|