camel-ai 0.2.33__py3-none-any.whl → 0.2.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/_types.py +1 -1
- camel/agents/_utils.py +4 -4
- camel/agents/chat_agent.py +174 -29
- camel/agents/knowledge_graph_agent.py +5 -0
- camel/configs/openai_config.py +20 -16
- camel/datasets/__init__.py +2 -4
- camel/datasets/base_generator.py +170 -226
- camel/datasets/few_shot_generator.py +261 -0
- camel/datasets/static_dataset.py +54 -2
- camel/memories/agent_memories.py +47 -1
- camel/memories/base.py +23 -1
- camel/memories/records.py +5 -0
- camel/models/openai_compatible_model.py +2 -4
- camel/models/sglang_model.py +4 -1
- camel/models/stub_model.py +25 -0
- camel/retrievers/vector_retriever.py +12 -7
- camel/storages/key_value_storages/__init__.py +2 -1
- camel/storages/key_value_storages/json.py +3 -7
- camel/storages/vectordb_storages/base.py +5 -1
- camel/toolkits/__init__.py +2 -1
- camel/toolkits/file_write_toolkit.py +24 -2
- camel/toolkits/github_toolkit.py +15 -3
- camel/toolkits/memory_toolkit.py +129 -0
- camel/utils/chunker/__init__.py +22 -0
- camel/utils/chunker/base.py +24 -0
- camel/utils/chunker/code_chunker.py +193 -0
- camel/utils/chunker/uio_chunker.py +66 -0
- camel/utils/token_counting.py +133 -0
- {camel_ai-0.2.33.dist-info → camel_ai-0.2.35.dist-info}/METADATA +3 -3
- {camel_ai-0.2.33.dist-info → camel_ai-0.2.35.dist-info}/RECORD +33 -27
- {camel_ai-0.2.33.dist-info → camel_ai-0.2.35.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.33.dist-info → camel_ai-0.2.35.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
from typing import List
|
|
18
|
+
|
|
19
|
+
from pydantic import ValidationError
|
|
20
|
+
|
|
21
|
+
from camel.agents import ChatAgent
|
|
22
|
+
from camel.logger import get_logger
|
|
23
|
+
from camel.models.base_model import BaseModelBackend
|
|
24
|
+
from camel.verifiers import BaseVerifier
|
|
25
|
+
from camel.verifiers.models import VerifierInput
|
|
26
|
+
|
|
27
|
+
from .base_generator import BaseGenerator
|
|
28
|
+
from .models import DataPoint
|
|
29
|
+
from .static_dataset import StaticDataset
|
|
30
|
+
|
|
31
|
+
logger = get_logger(__name__)
|
|
32
|
+
|
|
33
|
+
SYSTEM_PROMPT = """**You are an advanced data generation assistant.**
|
|
34
|
+
Your goal is to generate high-quality synthetic data points based on
|
|
35
|
+
provided examples. Your output must be well-structured,
|
|
36
|
+
logically sound, and formatted correctly.
|
|
37
|
+
|
|
38
|
+
**Instructions:**
|
|
39
|
+
1. **Follow the Structure**
|
|
40
|
+
Each data point must include:
|
|
41
|
+
- **Question**: A clear, well-formed query.
|
|
42
|
+
- **Rationale**: A step-by-step, executable reasoning process ending
|
|
43
|
+
with `print(final_answer)`.
|
|
44
|
+
- **Final Answer**: The correct, concise result.
|
|
45
|
+
|
|
46
|
+
2. **Ensure Logical Consistency**
|
|
47
|
+
- The `rationale` must be code that runs correctly.
|
|
48
|
+
- The `final_answer` should match the printed output.
|
|
49
|
+
|
|
50
|
+
3. **Output Format (Strict)**
|
|
51
|
+
```
|
|
52
|
+
Question: [Generated question]
|
|
53
|
+
Rationale: [Code that solves the question, ending in a print statement,
|
|
54
|
+
outputting the answer.]
|
|
55
|
+
Final Answer: [The Final Answer]
|
|
56
|
+
|
|
57
|
+
**Now, generate a new data point based on the given examples.**
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class FewShotGenerator(BaseGenerator):
    r"""A generator for creating synthetic datapoints using few-shot learning.

    This class leverages a seed dataset, an agent, and a verifier to generate
    new synthetic datapoints on demand through few-shot prompting.
    """

    def __init__(
        self,
        seed_dataset: StaticDataset,
        verifier: BaseVerifier,
        model: BaseModelBackend,
        seed: int = 42,
        **kwargs,
    ):
        r"""Initialize the few-shot generator.

        Args:
            seed_dataset (StaticDataset): Validated static dataset to
                use for examples.
            verifier (BaseVerifier): Verifier to validate generated content.
            model (BaseModelBackend): The underlying LLM that the generating
                agent will be initiated with.
            seed (int): Random seed for reproducibility. (default: :obj:`42`)
            **kwargs: Additional generator parameters.

        Raises:
            RuntimeError: If the seed dataset does not follow the
                DataPoint format.
        """
        super().__init__(seed=seed, **kwargs)
        self.seed_dataset = seed_dataset
        try:
            self._validate_seed_dataset()
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(
                "Seed Data does not follow Datapoint format"
            ) from e
        self.verifier = verifier
        self.agent = ChatAgent(system_message=SYSTEM_PROMPT, model=model)
        # Shared lock guarding `self._data`. A lock created inside
        # `generate_new` would be a fresh object on every call and would
        # therefore never serialize concurrent callers.
        self._data_lock = asyncio.Lock()

    # TODO: Validate that seed dataset contains rationale
    def _validate_seed_dataset(self) -> None:
        pass

    def _construct_prompt(self, examples: List[DataPoint]) -> str:
        r"""Construct a prompt for generating new datapoints
        using a fixed sample of examples from the seed dataset.

        Args:
            examples (List[DataPoint]): Examples to include in the prompt.

        Returns:
            str: Formatted prompt with examples.
        """
        prompt = (
            "Generate a new datapoint similar to the following examples:\n\n"
        )
        for i, example in enumerate(examples, 1):
            prompt += f"Example {i}:\n"
            prompt += f"Question: {example.question}\n"
            if example.rationale is not None:
                prompt += f"Rationale: {example.rationale}\n"
            else:
                prompt += "Rationale: None\n"
            prompt += f"Final Answer: {example.final_answer}\n\n"
        prompt += "New datapoint:"
        return prompt

    async def generate_new(
        self,
        n: int,
        max_retries: int = 10,
        num_examples: int = 3,
        **kwargs,
    ) -> List[DataPoint]:
        r"""Generates and validates `n` new datapoints through
        few-shot prompting, with a retry limit.

        Steps:
            1. Samples examples from the seed dataset.
            2. Constructs a prompt using the selected examples.
            3. Uses an agent to generate a new datapoint,
               consisting of a question and code to solve the question.
            4. Executes code using a verifier to get pseudo ground truth.
            5. Stores valid datapoints in memory.

        Args:
            n (int): Number of valid datapoints to generate.
            max_retries (int): Maximum number of retries before stopping.
                (default: :obj:`10`)
            num_examples (int): Number of examples to sample from the
                seed dataset for few shot prompting.
                (default: :obj:`3`)
            **kwargs: Additional generation parameters.

        Returns:
            List[DataPoint]: A list of newly generated valid datapoints.

        Raises:
            TypeError: If the agent's output does not match the expected
                format.
            KeyError: If required keys are missing from the response.
            AttributeError: If the verifier response lacks attributes.
            ValidationError: If a datapoint fails schema validation.
            RuntimeError: If retries are exhausted before `n` valid datapoints
                are generated.

        Notes:
            - Retries on validation failures until `n` valid datapoints exist
              or `max_retries` is reached, whichever comes first.
            - If retries are exhausted before reaching `n`, a `RuntimeError`
              is raised.
            - Metadata includes a timestamp for tracking datapoint creation.
        """
        valid_data_points: List[DataPoint] = []
        retries = 0

        while len(valid_data_points) < n and retries < max_retries:
            try:
                examples = [
                    self.seed_dataset.sample() for _ in range(num_examples)
                ]
                prompt = self._construct_prompt(examples)

                try:
                    agent_output = (
                        self.agent.step(prompt, response_format=DataPoint)
                        .msgs[0]
                        .parsed
                    )
                    if not isinstance(agent_output, DataPoint):
                        # Raise instead of `assert`: asserts are stripped
                        # when Python runs with -O, which would let a bad
                        # response slip through silently.
                        raise TypeError(
                            "Agent output is not a DataPoint: "
                            f"{agent_output!r}"
                        )
                except (TypeError, KeyError) as e:
                    logger.warning(
                        f"Agent output issue: {e}, retrying... "
                        f"({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue
                finally:
                    # Reset on every attempt (not only on success) so a
                    # failed step/parse does not leak conversation history
                    # into the next attempt.
                    self.agent.reset()

                rationale = agent_output.rationale

                if not isinstance(rationale, str):
                    # Caught by the outer handler below, which retries.
                    raise TypeError(f"Rationale {rationale} is not a string.")

                try:
                    verifier_response = await self.verifier.verify(
                        VerifierInput(
                            llm_response=rationale,
                            ground_truth=None,
                        )
                    )
                    if not verifier_response or not verifier_response.result:
                        raise ValueError(
                            "Verifier unsuccessful, response: "
                            f"{verifier_response}"
                        )
                except (ValueError, AttributeError) as e:
                    logger.warning(
                        f"Verifier issue: {e}, "
                        f"retrying... ({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue

                try:
                    # The verifier's output (the executed rationale's printed
                    # result) is used as the pseudo ground truth.
                    new_datapoint = DataPoint(
                        question=agent_output.question,
                        rationale=rationale,
                        final_answer=verifier_response.result,
                        metadata={
                            "synthetic": str(True),
                            "created": datetime.now().isoformat(),
                            "generator": "few_shot",
                        },
                    )
                except ValidationError as e:
                    logger.warning(
                        f"Datapoint validation failed: {e}, "
                        f"retrying... ({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue

                valid_data_points.append(new_datapoint)

            except Exception as e:
                logger.warning(
                    f"Unexpected error: {e}, retrying..."
                    f" ({retries + 1}/{max_retries})"
                )
                retries += 1

        if len(valid_data_points) < n:
            raise RuntimeError(
                f"Failed to generate {n} valid datapoints "
                f"after {max_retries} retries."
            )

        # Guard shared state against concurrent `generate_new` calls using
        # the lock created once in __init__.
        async with self._data_lock:
            self._data.extend(valid_data_points)
        return valid_data_points
|
camel/datasets/static_dataset.py
CHANGED
|
@@ -60,7 +60,7 @@ class StaticDataset(Dataset):
|
|
|
60
60
|
Input data, which can be one of the following:
|
|
61
61
|
- A Hugging Face Dataset (:obj:`HFDataset`).
|
|
62
62
|
- A PyTorch Dataset (:obj:`torch.utils.data.Dataset`).
|
|
63
|
-
- A :obj:`Path` object representing a JSON file.
|
|
63
|
+
- A :obj:`Path` object representing a JSON or JSONL file.
|
|
64
64
|
- A list of dictionaries with :obj:`DataPoint`-compatible
|
|
65
65
|
fields.
|
|
66
66
|
seed (int): Random seed for reproducibility.
|
|
@@ -112,6 +112,7 @@ class StaticDataset(Dataset):
|
|
|
112
112
|
|
|
113
113
|
Raises:
|
|
114
114
|
TypeError: If the input data type is unsupported.
|
|
115
|
+
ValueError: If the Path has an unsupported file extension.
|
|
115
116
|
"""
|
|
116
117
|
|
|
117
118
|
if isinstance(data, HFDataset):
|
|
@@ -119,7 +120,16 @@ class StaticDataset(Dataset):
|
|
|
119
120
|
elif isinstance(data, Dataset):
|
|
120
121
|
raw_data = self._init_from_pytorch_dataset(data)
|
|
121
122
|
elif isinstance(data, Path):
|
|
122
|
-
|
|
123
|
+
if data.suffix == ".jsonl":
|
|
124
|
+
raw_data = self._init_from_jsonl_path(data)
|
|
125
|
+
elif data.suffix == ".json":
|
|
126
|
+
raw_data = self._init_from_json_path(data)
|
|
127
|
+
else:
|
|
128
|
+
raise ValueError(
|
|
129
|
+
f"Unsupported file extension: {data.suffix}."
|
|
130
|
+
" Please enter a .json or .jsonl object."
|
|
131
|
+
)
|
|
132
|
+
|
|
123
133
|
elif isinstance(data, list):
|
|
124
134
|
raw_data = self._init_from_list(data)
|
|
125
135
|
else:
|
|
@@ -322,6 +332,48 @@ class StaticDataset(Dataset):
|
|
|
322
332
|
)
|
|
323
333
|
return loaded_data
|
|
324
334
|
|
|
335
|
+
def _init_from_jsonl_path(self, data: Path) -> List[Dict[str, Any]]:
    r"""Load and parse a dataset from a JSONL file.

    Args:
        data (Path): Path to the JSONL file.

    Returns:
        List[Dict[str, Any]]: A list of datapoint dictionaries.

    Raises:
        FileNotFoundError: If the specified JSONL file does not exist.
        ValueError: If a line in the file contains invalid JSON or
            is not a dictionary.
    """
    if not data.exists():
        raise FileNotFoundError(f"JSONL file not found: {data}")

    raw_data: List[Dict[str, Any]] = []
    logger.debug(f"Loading JSONL from {data}")
    with data.open('r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # Skip blank lines if any exist.
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Invalid JSON on line {line_number} in file "
                    f"{data}: {e}"
                ) from e
            # Validate here, while the true line number is still known.
            # A post-hoc pass over `raw_data` would mis-report line
            # numbers whenever blank lines were skipped above.
            if not isinstance(record, dict):
                raise ValueError(
                    f"Expected a dictionary on line {line_number} in "
                    f"file {data}, got {type(record).__name__}"
                )
            raw_data.append(record)
    logger.info(f"Successfully loaded {len(raw_data)} items from {data}")
    return raw_data
|
|
376
|
+
|
|
325
377
|
def _init_from_list(
|
|
326
378
|
self, data: List[Dict[str, Any]]
|
|
327
379
|
) -> List[Dict[str, Any]]:
|
camel/memories/agent_memories.py
CHANGED
|
@@ -33,6 +33,8 @@ class ChatHistoryMemory(AgentMemory):
|
|
|
33
33
|
window_size (int, optional): The number of recent chat messages to
|
|
34
34
|
retrieve. If not provided, the entire chat history will be
|
|
35
35
|
retrieved. (default: :obj:`None`)
|
|
36
|
+
agent_id (str, optional): The ID of the agent associated with the chat
|
|
37
|
+
history.
|
|
36
38
|
"""
|
|
37
39
|
|
|
38
40
|
def __init__(
|
|
@@ -40,6 +42,7 @@ class ChatHistoryMemory(AgentMemory):
|
|
|
40
42
|
context_creator: BaseContextCreator,
|
|
41
43
|
storage: Optional[BaseKeyValueStorage] = None,
|
|
42
44
|
window_size: Optional[int] = None,
|
|
45
|
+
agent_id: Optional[str] = None,
|
|
43
46
|
) -> None:
|
|
44
47
|
if window_size is not None and not isinstance(window_size, int):
|
|
45
48
|
raise TypeError("`window_size` must be an integer or None.")
|
|
@@ -48,6 +51,15 @@ class ChatHistoryMemory(AgentMemory):
|
|
|
48
51
|
self._context_creator = context_creator
|
|
49
52
|
self._window_size = window_size
|
|
50
53
|
self._chat_history_block = ChatHistoryBlock(storage=storage)
|
|
54
|
+
self._agent_id = agent_id
|
|
55
|
+
|
|
56
|
+
@property
|
|
57
|
+
def agent_id(self) -> Optional[str]:
|
|
58
|
+
return self._agent_id
|
|
59
|
+
|
|
60
|
+
@agent_id.setter
|
|
61
|
+
def agent_id(self, val: Optional[str]) -> None:
|
|
62
|
+
self._agent_id = val
|
|
51
63
|
|
|
52
64
|
def retrieve(self) -> List[ContextRecord]:
|
|
53
65
|
records = self._chat_history_block.retrieve(self._window_size)
|
|
@@ -63,6 +75,10 @@ class ChatHistoryMemory(AgentMemory):
|
|
|
63
75
|
return records
|
|
64
76
|
|
|
65
77
|
def write_records(self, records: List[MemoryRecord]) -> None:
|
|
78
|
+
for record in records:
|
|
79
|
+
# assign the agent_id to the record
|
|
80
|
+
if record.agent_id == "" and self.agent_id is not None:
|
|
81
|
+
record.agent_id = self.agent_id
|
|
66
82
|
self._chat_history_block.write_records(records)
|
|
67
83
|
|
|
68
84
|
def get_context_creator(self) -> BaseContextCreator:
|
|
@@ -84,6 +100,8 @@ class VectorDBMemory(AgentMemory):
|
|
|
84
100
|
(default: :obj:`None`)
|
|
85
101
|
retrieve_limit (int, optional): The maximum number of messages
|
|
86
102
|
to be added into the context. (default: :obj:`3`)
|
|
103
|
+
agent_id (str, optional): The ID of the agent associated with
|
|
104
|
+
the messages stored in the vector database.
|
|
87
105
|
"""
|
|
88
106
|
|
|
89
107
|
def __init__(
|
|
@@ -91,13 +109,23 @@ class VectorDBMemory(AgentMemory):
|
|
|
91
109
|
context_creator: BaseContextCreator,
|
|
92
110
|
storage: Optional[BaseVectorStorage] = None,
|
|
93
111
|
retrieve_limit: int = 3,
|
|
112
|
+
agent_id: Optional[str] = None,
|
|
94
113
|
) -> None:
|
|
95
114
|
self._context_creator = context_creator
|
|
96
115
|
self._retrieve_limit = retrieve_limit
|
|
97
116
|
self._vectordb_block = VectorDBBlock(storage=storage)
|
|
117
|
+
self._agent_id = agent_id
|
|
98
118
|
|
|
99
119
|
self._current_topic: str = ""
|
|
100
120
|
|
|
121
|
+
@property
|
|
122
|
+
def agent_id(self) -> Optional[str]:
|
|
123
|
+
return self._agent_id
|
|
124
|
+
|
|
125
|
+
@agent_id.setter
|
|
126
|
+
def agent_id(self, val: Optional[str]) -> None:
|
|
127
|
+
self._agent_id = val
|
|
128
|
+
|
|
101
129
|
def retrieve(self) -> List[ContextRecord]:
|
|
102
130
|
return self._vectordb_block.retrieve(
|
|
103
131
|
self._current_topic,
|
|
@@ -109,6 +137,11 @@ class VectorDBMemory(AgentMemory):
|
|
|
109
137
|
for record in records:
|
|
110
138
|
if record.role_at_backend == OpenAIBackendRole.USER:
|
|
111
139
|
self._current_topic = record.message.content
|
|
140
|
+
|
|
141
|
+
# assign the agent_id to the record
|
|
142
|
+
if record.agent_id == "" and self.agent_id is not None:
|
|
143
|
+
record.agent_id = self.agent_id
|
|
144
|
+
|
|
112
145
|
self._vectordb_block.write_records(records)
|
|
113
146
|
|
|
114
147
|
def get_context_creator(self) -> BaseContextCreator:
|
|
@@ -133,6 +166,8 @@ class LongtermAgentMemory(AgentMemory):
|
|
|
133
166
|
(default: :obj:`None`)
|
|
134
167
|
retrieve_limit (int, optional): The maximum number of messages
|
|
135
168
|
to be added into the context. (default: :obj:`3`)
|
|
169
|
+
agent_id (str, optional): The ID of the agent associated with the chat
|
|
170
|
+
history and the messages stored in the vector database.
|
|
136
171
|
"""
|
|
137
172
|
|
|
138
173
|
def __init__(
|
|
@@ -141,12 +176,22 @@ class LongtermAgentMemory(AgentMemory):
|
|
|
141
176
|
chat_history_block: Optional[ChatHistoryBlock] = None,
|
|
142
177
|
vector_db_block: Optional[VectorDBBlock] = None,
|
|
143
178
|
retrieve_limit: int = 3,
|
|
179
|
+
agent_id: Optional[str] = None,
|
|
144
180
|
) -> None:
|
|
145
181
|
self.chat_history_block = chat_history_block or ChatHistoryBlock()
|
|
146
182
|
self.vector_db_block = vector_db_block or VectorDBBlock()
|
|
147
183
|
self.retrieve_limit = retrieve_limit
|
|
148
184
|
self._context_creator = context_creator
|
|
149
185
|
self._current_topic: str = ""
|
|
186
|
+
self._agent_id = agent_id
|
|
187
|
+
|
|
188
|
+
@property
|
|
189
|
+
def agent_id(self) -> Optional[str]:
|
|
190
|
+
return self._agent_id
|
|
191
|
+
|
|
192
|
+
@agent_id.setter
|
|
193
|
+
def agent_id(self, val: Optional[str]) -> None:
|
|
194
|
+
self._agent_id = val
|
|
150
195
|
|
|
151
196
|
def get_context_creator(self) -> BaseContextCreator:
|
|
152
197
|
r"""Returns the context creator used by the memory.
|
|
@@ -166,7 +211,8 @@ class LongtermAgentMemory(AgentMemory):
|
|
|
166
211
|
"""
|
|
167
212
|
chat_history = self.chat_history_block.retrieve()
|
|
168
213
|
vector_db_retrieve = self.vector_db_block.retrieve(
|
|
169
|
-
self._current_topic,
|
|
214
|
+
self._current_topic,
|
|
215
|
+
self.retrieve_limit,
|
|
170
216
|
)
|
|
171
217
|
return chat_history[:1] + vector_db_retrieve + chat_history[1:]
|
|
172
218
|
|
camel/memories/base.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
14
|
|
|
15
15
|
from abc import ABC, abstractmethod
|
|
16
|
-
from typing import List, Tuple
|
|
16
|
+
from typing import List, Optional, Tuple
|
|
17
17
|
|
|
18
18
|
from camel.memories.records import ContextRecord, MemoryRecord
|
|
19
19
|
from camel.messages import OpenAIMessage
|
|
@@ -112,6 +112,16 @@ class AgentMemory(MemoryBlock, ABC):
|
|
|
112
112
|
the memory records stored within the AgentMemory.
|
|
113
113
|
"""
|
|
114
114
|
|
|
115
|
+
@property
|
|
116
|
+
@abstractmethod
|
|
117
|
+
def agent_id(self) -> Optional[str]:
|
|
118
|
+
pass
|
|
119
|
+
|
|
120
|
+
@agent_id.setter
|
|
121
|
+
@abstractmethod
|
|
122
|
+
def agent_id(self, val: Optional[str]) -> None:
|
|
123
|
+
pass
|
|
124
|
+
|
|
115
125
|
@abstractmethod
|
|
116
126
|
def retrieve(self) -> List[ContextRecord]:
|
|
117
127
|
r"""Get a record list from the memory for creating model context.
|
|
@@ -138,3 +148,15 @@ class AgentMemory(MemoryBlock, ABC):
|
|
|
138
148
|
context in OpenAIMessage format and the total token count.
|
|
139
149
|
"""
|
|
140
150
|
return self.get_context_creator().create_context(self.retrieve())
|
|
151
|
+
|
|
152
|
+
def __repr__(self) -> str:
|
|
153
|
+
r"""Returns a string representation of the AgentMemory.
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
str: A string in the format 'ClassName(agent_id=<id>)'
|
|
157
|
+
if agent_id exists, otherwise just 'ClassName()'.
|
|
158
|
+
"""
|
|
159
|
+
agent_id = getattr(self, '_agent_id', None)
|
|
160
|
+
if agent_id:
|
|
161
|
+
return f"{self.__class__.__name__}(agent_id='{agent_id}')"
|
|
162
|
+
return f"{self.__class__.__name__}()"
|
camel/memories/records.py
CHANGED
|
@@ -39,6 +39,8 @@ class MemoryRecord(BaseModel):
|
|
|
39
39
|
key-value pairs that provide more information. If not given, it
|
|
40
40
|
will be an empty `Dict`.
|
|
41
41
|
timestamp (float, optional): The timestamp when the record was created.
|
|
42
|
+
agent_id (str): The identifier of the agent associated with this
|
|
43
|
+
memory.
|
|
42
44
|
"""
|
|
43
45
|
|
|
44
46
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
@@ -50,6 +52,7 @@ class MemoryRecord(BaseModel):
|
|
|
50
52
|
timestamp: float = Field(
|
|
51
53
|
default_factory=lambda: datetime.now(timezone.utc).timestamp()
|
|
52
54
|
)
|
|
55
|
+
agent_id: str = Field(default="")
|
|
53
56
|
|
|
54
57
|
_MESSAGE_TYPES: ClassVar[dict] = {
|
|
55
58
|
"BaseMessage": BaseMessage,
|
|
@@ -73,6 +76,7 @@ class MemoryRecord(BaseModel):
|
|
|
73
76
|
role_at_backend=record_dict["role_at_backend"],
|
|
74
77
|
extra_info=record_dict["extra_info"],
|
|
75
78
|
timestamp=record_dict["timestamp"],
|
|
79
|
+
agent_id=record_dict["agent_id"],
|
|
76
80
|
)
|
|
77
81
|
|
|
78
82
|
def to_dict(self) -> Dict[str, Any]:
|
|
@@ -88,6 +92,7 @@ class MemoryRecord(BaseModel):
|
|
|
88
92
|
"role_at_backend": self.role_at_backend,
|
|
89
93
|
"extra_info": self.extra_info,
|
|
90
94
|
"timestamp": self.timestamp,
|
|
95
|
+
"agent_id": self.agent_id,
|
|
91
96
|
}
|
|
92
97
|
|
|
93
98
|
def to_openai_message(self) -> OpenAIMessage:
|
|
@@ -56,10 +56,8 @@ class OpenAICompatibleModel(BaseModelBackend):
|
|
|
56
56
|
url: Optional[str] = None,
|
|
57
57
|
token_counter: Optional[BaseTokenCounter] = None,
|
|
58
58
|
) -> None:
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
)
|
|
62
|
-
self.url = url or os.environ.get("OPENAI_COMPATIBILITY_API_BASE_URL")
|
|
59
|
+
api_key = api_key or os.environ.get("OPENAI_COMPATIBILITY_API_KEY")
|
|
60
|
+
url = url or os.environ.get("OPENAI_COMPATIBILITY_API_BASE_URL")
|
|
63
61
|
super().__init__(
|
|
64
62
|
model_type, model_config_dict, api_key, url, token_counter
|
|
65
63
|
)
|
camel/models/sglang_model.py
CHANGED
|
@@ -324,7 +324,10 @@ def _kill_process_tree(
|
|
|
324
324
|
|
|
325
325
|
# Sometime processes cannot be killed with SIGKILL
|
|
326
326
|
# so we send an additional signal to kill them.
|
|
327
|
-
|
|
327
|
+
if hasattr(signal, "SIGQUIT"):
|
|
328
|
+
itself.send_signal(signal.SIGQUIT)
|
|
329
|
+
else:
|
|
330
|
+
itself.send_signal(signal.SIGTERM)
|
|
328
331
|
except psutil.NoSuchProcess:
|
|
329
332
|
pass
|
|
330
333
|
|
camel/models/stub_model.py
CHANGED
|
@@ -44,6 +44,31 @@ class StubTokenCounter(BaseTokenCounter):
|
|
|
44
44
|
"""
|
|
45
45
|
return 10
|
|
46
46
|
|
|
47
|
+
def encode(self, text: str) -> List[int]:
    r"""Encode text into token IDs for STUB models.

    Args:
        text (str): The text to encode.

    Returns:
        List[int]: List of token IDs.
    """
    # Stub backend: emit a dummy token (id 0) for roughly every four
    # characters of input, always producing at least one token.
    approx_token_count = len(text) // 4 + 1
    return [0 for _ in range(approx_token_count)]
|
|
59
|
+
|
|
60
|
+
def decode(self, token_ids: List[int]) -> str:
    r"""Decode token IDs back to text for STUB models.

    Args:
        token_ids (List[int]): List of token IDs to decode.

    Returns:
        str: Decoded text.
    """
    # The stub backend cannot reconstruct text, so a fixed placeholder
    # is returned regardless of the supplied IDs.
    placeholder = "[Stub decoded text]"
    return placeholder
|
|
71
|
+
|
|
47
72
|
|
|
48
73
|
class StubModel(BaseModelBackend):
|
|
49
74
|
r"""A dummy model used for unit tests."""
|
|
@@ -27,6 +27,7 @@ from camel.storages import (
|
|
|
27
27
|
VectorRecord,
|
|
28
28
|
)
|
|
29
29
|
from camel.utils import Constants
|
|
30
|
+
from camel.utils.chunker import BaseChunker, UnstructuredIOChunker
|
|
30
31
|
|
|
31
32
|
if TYPE_CHECKING:
|
|
32
33
|
from unstructured.documents.elements import Element
|
|
@@ -78,6 +79,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
78
79
|
should_chunk: bool = True,
|
|
79
80
|
extra_info: Optional[dict] = None,
|
|
80
81
|
metadata_filename: Optional[str] = None,
|
|
82
|
+
chunker: Optional[BaseChunker] = None,
|
|
81
83
|
**kwargs: Any,
|
|
82
84
|
) -> None:
|
|
83
85
|
r"""Processes content from local file path, remote URL, string
|
|
@@ -101,6 +103,12 @@ class VectorRetriever(BaseRetriever):
|
|
|
101
103
|
used for storing metadata. Defaults to None.
|
|
102
104
|
**kwargs (Any): Additional keyword arguments for content parsing.
|
|
103
105
|
"""
|
|
106
|
+
if chunker is None:
|
|
107
|
+
chunker = UnstructuredIOChunker(
|
|
108
|
+
chunk_type=chunk_type,
|
|
109
|
+
max_characters=max_characters,
|
|
110
|
+
metadata_filename=metadata_filename,
|
|
111
|
+
)
|
|
104
112
|
from unstructured.documents.elements import Element
|
|
105
113
|
|
|
106
114
|
if isinstance(content, Element):
|
|
@@ -140,13 +148,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
140
148
|
else:
|
|
141
149
|
# Chunk the content if required
|
|
142
150
|
chunks = (
|
|
143
|
-
|
|
144
|
-
chunk_type=chunk_type,
|
|
145
|
-
elements=elements,
|
|
146
|
-
max_characters=max_characters,
|
|
147
|
-
)
|
|
148
|
-
if should_chunk
|
|
149
|
-
else elements
|
|
151
|
+
chunker.chunk(content=elements) if should_chunk else (elements)
|
|
150
152
|
)
|
|
151
153
|
|
|
152
154
|
# Process chunks in batches and store embeddings
|
|
@@ -157,6 +159,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
157
159
|
)
|
|
158
160
|
|
|
159
161
|
records = []
|
|
162
|
+
offset = 0
|
|
160
163
|
# Prepare the payload for each vector record, includes the
|
|
161
164
|
# content path, chunk metadata, and chunk text
|
|
162
165
|
for vector, chunk in zip(batch_vectors, batch_chunks):
|
|
@@ -178,6 +181,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
178
181
|
chunk_metadata["metadata"].pop("orig_elements", "")
|
|
179
182
|
chunk_metadata["extra_info"] = extra_info or {}
|
|
180
183
|
chunk_text = {"text": str(chunk)}
|
|
184
|
+
chunk_metadata["metadata"]["piece_num"] = i + offset + 1
|
|
181
185
|
combined_dict = {
|
|
182
186
|
**content_path_info,
|
|
183
187
|
**chunk_metadata,
|
|
@@ -187,6 +191,7 @@ class VectorRetriever(BaseRetriever):
|
|
|
187
191
|
records.append(
|
|
188
192
|
VectorRecord(vector=vector, payload=combined_dict)
|
|
189
193
|
)
|
|
194
|
+
offset += 1
|
|
190
195
|
|
|
191
196
|
self.storage.add(records=records)
|
|
192
197
|
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
from .base import BaseKeyValueStorage
|
|
16
16
|
from .in_memory import InMemoryKeyValueStorage
|
|
17
|
-
from .json import JsonStorage
|
|
17
|
+
from .json import CamelJSONEncoder, JsonStorage
|
|
18
18
|
from .redis import RedisStorage
|
|
19
19
|
|
|
20
20
|
__all__ = [
|
|
@@ -22,4 +22,5 @@ __all__ = [
|
|
|
22
22
|
'InMemoryKeyValueStorage',
|
|
23
23
|
'JsonStorage',
|
|
24
24
|
'RedisStorage',
|
|
25
|
+
'CamelJSONEncoder',
|
|
25
26
|
]
|