camel-ai 0.2.33__py3-none-any.whl → 0.2.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai has been flagged as potentially problematic; consult the registry's advisory page for details.

@@ -0,0 +1,261 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import asyncio
16
+ from datetime import datetime
17
+ from typing import List
18
+
19
+ from pydantic import ValidationError
20
+
21
+ from camel.agents import ChatAgent
22
+ from camel.logger import get_logger
23
+ from camel.models.base_model import BaseModelBackend
24
+ from camel.verifiers import BaseVerifier
25
+ from camel.verifiers.models import VerifierInput
26
+
27
+ from .base_generator import BaseGenerator
28
+ from .models import DataPoint
29
+ from .static_dataset import StaticDataset
30
+
31
# Module-level logger, namespaced to this module.
logger = get_logger(__name__)

# System prompt for the generating agent: instructs the LLM to emit
# (question, rationale, final answer) triples where the rationale is
# executable code ending in `print(final_answer)`, so a verifier can run
# it to obtain a pseudo ground truth.
# NOTE(review): the code fence opened with ``` under "Output Format" is
# never closed before the closing instruction — confirm whether the
# closing fence was dropped intentionally.
SYSTEM_PROMPT = """**You are an advanced data generation assistant.**
Your goal is to generate high-quality synthetic data points based on
provided examples. Your output must be well-structured,
logically sound, and formatted correctly.

**Instructions:**
1. **Follow the Structure**
Each data point must include:
- **Question**: A clear, well-formed query.
- **Rationale**: A step-by-step, executable reasoning process ending
with `print(final_answer)`.
- **Final Answer**: The correct, concise result.

2. **Ensure Logical Consistency**
- The `rationale` must be code that runs correctly.
- The `final_answer` should match the printed output.

3. **Output Format (Strict)**
```
Question: [Generated question]
Rationale: [Code that solves the question, ending in a print statement,
outputting the answer.]
Final Answer: [The Final Answer]

**Now, generate a new data point based on the given examples.**
"""
59
+
60
+
61
class FewShotGenerator(BaseGenerator):
    r"""A generator for creating synthetic datapoints using few-shot learning.

    This class leverages a seed dataset, an agent, and a verifier to generate
    new synthetic datapoints on demand through few-shot prompting.
    """

    def __init__(
        self,
        seed_dataset: StaticDataset,
        verifier: BaseVerifier,
        model: BaseModelBackend,
        seed: int = 42,
        **kwargs,
    ):
        r"""Initialize the few-shot generator.

        Args:
            seed_dataset (StaticDataset): Validated static dataset to
                use for examples.
            verifier (BaseVerifier): Verifier to validate generated content.
            model (BaseModelBackend): The underlying LLM that the generating
                agent will be initiated with.
            seed (int): Random seed for reproducibility. (default: :obj:`42`)
            **kwargs: Additional generator parameters.

        Raises:
            RuntimeError: If the seed dataset fails validation.
        """
        super().__init__(seed=seed, **kwargs)
        self.seed_dataset = seed_dataset
        try:
            self._validate_seed_dataset()
        except Exception as e:
            # Chain the cause so the original validation failure stays
            # visible in the traceback instead of being swallowed.
            raise RuntimeError(
                "Seed Data does not follow Datapoint format"
            ) from e
        self.verifier = verifier
        self.agent = ChatAgent(system_message=SYSTEM_PROMPT, model=model)
        # Instance-level lock guarding `self._data`. A lock created per
        # call (the previous behavior) synchronizes nothing because each
        # caller acquires a different lock object.
        self._data_lock = asyncio.Lock()

    # TODO: Validate that seed dataset contains rationale
    def _validate_seed_dataset(self) -> None:
        pass

    def _construct_prompt(self, examples: List[DataPoint]) -> str:
        r"""Construct a prompt for generating new datapoints
        using a fixed sample of examples from the seed dataset.

        Args:
            examples (List[DataPoint]): Examples to include in the prompt.

        Returns:
            str: Formatted prompt with examples.
        """
        prompt = (
            "Generate a new datapoint similar to the following examples:\n\n"
        )
        for i, example in enumerate(examples, 1):
            prompt += f"Example {i}:\n"
            prompt += f"Question: {example.question}\n"
            if example.rationale is not None:
                prompt += f"Rationale: {example.rationale}\n"
            else:
                prompt += "Rationale: None\n"
            prompt += f"Final Answer: {example.final_answer}\n\n"
        prompt += "New datapoint:"
        return prompt

    async def generate_new(
        self,
        n: int,
        max_retries: int = 10,
        num_examples: int = 3,
        **kwargs,
    ) -> List[DataPoint]:
        r"""Generates and validates `n` new datapoints through
        few-shot prompting, with a retry limit.

        Steps:
            1. Samples examples from the seed dataset.
            2. Constructs a prompt using the selected examples.
            3. Uses an agent to generate a new datapoint,
               consisting of a question and code to solve the question.
            4. Executes code using a verifier to get pseudo ground truth.
            5. Stores valid datapoints in memory.

        Args:
            n (int): Number of valid datapoints to generate.
            max_retries (int): Maximum number of retries before stopping.
                (default: :obj:`10`)
            num_examples (int): Number of examples to sample from the
                seed dataset for few shot prompting.
                (default: :obj:`3`)
            **kwargs: Additional generation parameters.

        Returns:
            List[DataPoint]: A list of newly generated valid datapoints.

        Raises:
            TypeError: If the agent's output is not a dictionary (or does not
                match the expected format).
            KeyError: If required keys are missing from the response.
            AttributeError: If the verifier response lacks attributes.
            ValidationError: If a datapoint fails schema validation.
            RuntimeError: If retries are exhausted before `n` valid datapoints
                are generated.

        Notes:
            - Retries on validation failures until `n` valid datapoints exist
              or `max_retries` is reached, whichever comes first.
            - If retries are exhausted before reaching `n`, a `RuntimeError`
              is raised.
            - Metadata includes a timestamp for tracking datapoint creation.
        """
        valid_data_points: List[DataPoint] = []
        retries = 0

        while len(valid_data_points) < n and retries < max_retries:
            try:
                examples = [
                    self.seed_dataset.sample() for _ in range(num_examples)
                ]
                prompt = self._construct_prompt(examples)

                try:
                    agent_output = (
                        self.agent.step(prompt, response_format=DataPoint)
                        .msgs[0]
                        .parsed
                    )
                    # Explicit runtime check instead of `assert`, which is
                    # stripped when Python runs with `-O`.
                    if not isinstance(agent_output, DataPoint):
                        raise TypeError(
                            "Agent did not return a DataPoint, got "
                            f"{type(agent_output).__name__}"
                        )
                except (TypeError, KeyError) as e:
                    logger.warning(
                        f"Agent output issue: {e}, retrying... "
                        f"({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue
                finally:
                    # Reset even on failure so stale conversation state
                    # does not leak into the next attempt.
                    self.agent.reset()

                rationale = agent_output.rationale

                if not isinstance(rationale, str):
                    raise TypeError(f"Rationale {rationale} is not a string.")

                try:
                    verifier_response = await self.verifier.verify(
                        VerifierInput(
                            llm_response=rationale,
                            ground_truth=None,
                        )
                    )
                    # An empty/falsy result means the rationale did not
                    # execute successfully — treat as a failed attempt.
                    if not verifier_response or not verifier_response.result:
                        raise ValueError(
                            "Verifier unsuccessful, response: "
                            f"{verifier_response}"
                        )
                except (ValueError, AttributeError) as e:
                    logger.warning(
                        f"Verifier issue: {e}, "
                        f"retrying... ({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue

                try:
                    # The verifier's output is used as pseudo ground truth.
                    new_datapoint = DataPoint(
                        question=agent_output.question,
                        rationale=rationale,
                        final_answer=verifier_response.result,
                        metadata={
                            "synthetic": str(True),
                            "created": datetime.now().isoformat(),
                            "generator": "few_shot",
                        },
                    )
                except ValidationError as e:
                    logger.warning(
                        f"Datapoint validation failed: {e}, "
                        f"retrying... ({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue

                valid_data_points.append(new_datapoint)

            except Exception as e:
                logger.warning(
                    f"Unexpected error: {e}, retrying..."
                    f" ({retries + 1}/{max_retries})"
                )
                retries += 1

        if len(valid_data_points) < n:
            raise RuntimeError(
                f"Failed to generate {n} valid datapoints "
                f"after {max_retries} retries."
            )

        # Extend the shared data list under the instance-level lock so
        # concurrent generate_new() calls do not interleave.
        async with self._data_lock:
            self._data.extend(valid_data_points)
        return valid_data_points
@@ -60,7 +60,7 @@ class StaticDataset(Dataset):
60
60
  Input data, which can be one of the following:
61
61
  - A Hugging Face Dataset (:obj:`HFDataset`).
62
62
  - A PyTorch Dataset (:obj:`torch.utils.data.Dataset`).
63
- - A :obj:`Path` object representing a JSON file.
63
+ - A :obj:`Path` object representing a JSON or JSONL file.
64
64
  - A list of dictionaries with :obj:`DataPoint`-compatible
65
65
  fields.
66
66
  seed (int): Random seed for reproducibility.
@@ -112,6 +112,7 @@ class StaticDataset(Dataset):
112
112
 
113
113
  Raises:
114
114
  TypeError: If the input data type is unsupported.
115
+ ValueError: If the Path has an unsupported file extension.
115
116
  """
116
117
 
117
118
  if isinstance(data, HFDataset):
@@ -119,7 +120,16 @@ class StaticDataset(Dataset):
119
120
  elif isinstance(data, Dataset):
120
121
  raw_data = self._init_from_pytorch_dataset(data)
121
122
  elif isinstance(data, Path):
122
- raw_data = self._init_from_json_path(data)
123
+ if data.suffix == ".jsonl":
124
+ raw_data = self._init_from_jsonl_path(data)
125
+ elif data.suffix == ".json":
126
+ raw_data = self._init_from_json_path(data)
127
+ else:
128
+ raise ValueError(
129
+ f"Unsupported file extension: {data.suffix}."
130
+ " Please enter a .json or .jsonl object."
131
+ )
132
+
123
133
  elif isinstance(data, list):
124
134
  raw_data = self._init_from_list(data)
125
135
  else:
@@ -322,6 +332,48 @@ class StaticDataset(Dataset):
322
332
  )
323
333
  return loaded_data
324
334
 
335
def _init_from_jsonl_path(self, data: Path) -> List[Dict[str, Any]]:
    r"""Load and parse a dataset from a JSONL file.

    Args:
        data (Path): Path to the JSONL file.

    Returns:
        List[Dict[str, Any]]: A list of datapoint dictionaries.

    Raises:
        FileNotFoundError: If the specified JSONL file does not exist.
        ValueError: If a line in the file contains invalid JSON or
            is not a dictionary.
    """
    if not data.exists():
        raise FileNotFoundError(f"JSONL file not found: {data}")

    raw_data: List[Dict[str, Any]] = []
    logger.debug(f"Loading JSONL from {data}")
    with data.open('r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # Skip blank lines if any exist.
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Invalid JSON on line {line_number} in file "
                    f"{data}: {e}"
                ) from e
            # Validate here, while the true file line number is known.
            # The previous post-loop check reported the record index as
            # a line number, which is wrong whenever blank lines were
            # skipped during reading.
            if not isinstance(record, dict):
                raise ValueError(
                    f"Expected a dictionary on line {line_number} in "
                    f"file {data}, got {type(record).__name__}"
                )
            raw_data.append(record)
    logger.info(f"Successfully loaded {len(raw_data)} items from {data}")
    return raw_data
376
+
325
377
  def _init_from_list(
326
378
  self, data: List[Dict[str, Any]]
327
379
  ) -> List[Dict[str, Any]]:
@@ -33,6 +33,8 @@ class ChatHistoryMemory(AgentMemory):
33
33
  window_size (int, optional): The number of recent chat messages to
34
34
  retrieve. If not provided, the entire chat history will be
35
35
  retrieved. (default: :obj:`None`)
36
+ agent_id (str, optional): The ID of the agent associated with the chat
37
+ history.
36
38
  """
37
39
 
38
40
  def __init__(
@@ -40,6 +42,7 @@ class ChatHistoryMemory(AgentMemory):
40
42
  context_creator: BaseContextCreator,
41
43
  storage: Optional[BaseKeyValueStorage] = None,
42
44
  window_size: Optional[int] = None,
45
+ agent_id: Optional[str] = None,
43
46
  ) -> None:
44
47
  if window_size is not None and not isinstance(window_size, int):
45
48
  raise TypeError("`window_size` must be an integer or None.")
@@ -48,6 +51,15 @@ class ChatHistoryMemory(AgentMemory):
48
51
  self._context_creator = context_creator
49
52
  self._window_size = window_size
50
53
  self._chat_history_block = ChatHistoryBlock(storage=storage)
54
+ self._agent_id = agent_id
55
+
56
+ @property
57
+ def agent_id(self) -> Optional[str]:
58
+ return self._agent_id
59
+
60
+ @agent_id.setter
61
+ def agent_id(self, val: Optional[str]) -> None:
62
+ self._agent_id = val
51
63
 
52
64
  def retrieve(self) -> List[ContextRecord]:
53
65
  records = self._chat_history_block.retrieve(self._window_size)
@@ -63,6 +75,10 @@ class ChatHistoryMemory(AgentMemory):
63
75
  return records
64
76
 
65
77
  def write_records(self, records: List[MemoryRecord]) -> None:
78
+ for record in records:
79
+ # assign the agent_id to the record
80
+ if record.agent_id == "" and self.agent_id is not None:
81
+ record.agent_id = self.agent_id
66
82
  self._chat_history_block.write_records(records)
67
83
 
68
84
  def get_context_creator(self) -> BaseContextCreator:
@@ -84,6 +100,8 @@ class VectorDBMemory(AgentMemory):
84
100
  (default: :obj:`None`)
85
101
  retrieve_limit (int, optional): The maximum number of messages
86
102
  to be added into the context. (default: :obj:`3`)
103
+ agent_id (str, optional): The ID of the agent associated with
104
+ the messages stored in the vector database.
87
105
  """
88
106
 
89
107
  def __init__(
@@ -91,13 +109,23 @@ class VectorDBMemory(AgentMemory):
91
109
  context_creator: BaseContextCreator,
92
110
  storage: Optional[BaseVectorStorage] = None,
93
111
  retrieve_limit: int = 3,
112
+ agent_id: Optional[str] = None,
94
113
  ) -> None:
95
114
  self._context_creator = context_creator
96
115
  self._retrieve_limit = retrieve_limit
97
116
  self._vectordb_block = VectorDBBlock(storage=storage)
117
+ self._agent_id = agent_id
98
118
 
99
119
  self._current_topic: str = ""
100
120
 
121
+ @property
122
+ def agent_id(self) -> Optional[str]:
123
+ return self._agent_id
124
+
125
+ @agent_id.setter
126
+ def agent_id(self, val: Optional[str]) -> None:
127
+ self._agent_id = val
128
+
101
129
  def retrieve(self) -> List[ContextRecord]:
102
130
  return self._vectordb_block.retrieve(
103
131
  self._current_topic,
@@ -109,6 +137,11 @@ class VectorDBMemory(AgentMemory):
109
137
  for record in records:
110
138
  if record.role_at_backend == OpenAIBackendRole.USER:
111
139
  self._current_topic = record.message.content
140
+
141
+ # assign the agent_id to the record
142
+ if record.agent_id == "" and self.agent_id is not None:
143
+ record.agent_id = self.agent_id
144
+
112
145
  self._vectordb_block.write_records(records)
113
146
 
114
147
  def get_context_creator(self) -> BaseContextCreator:
@@ -133,6 +166,8 @@ class LongtermAgentMemory(AgentMemory):
133
166
  (default: :obj:`None`)
134
167
  retrieve_limit (int, optional): The maximum number of messages
135
168
  to be added into the context. (default: :obj:`3`)
169
+ agent_id (str, optional): The ID of the agent associated with the chat
170
+ history and the messages stored in the vector database.
136
171
  """
137
172
 
138
173
  def __init__(
@@ -141,12 +176,22 @@ class LongtermAgentMemory(AgentMemory):
141
176
  chat_history_block: Optional[ChatHistoryBlock] = None,
142
177
  vector_db_block: Optional[VectorDBBlock] = None,
143
178
  retrieve_limit: int = 3,
179
+ agent_id: Optional[str] = None,
144
180
  ) -> None:
145
181
  self.chat_history_block = chat_history_block or ChatHistoryBlock()
146
182
  self.vector_db_block = vector_db_block or VectorDBBlock()
147
183
  self.retrieve_limit = retrieve_limit
148
184
  self._context_creator = context_creator
149
185
  self._current_topic: str = ""
186
+ self._agent_id = agent_id
187
+
188
+ @property
189
+ def agent_id(self) -> Optional[str]:
190
+ return self._agent_id
191
+
192
+ @agent_id.setter
193
+ def agent_id(self, val: Optional[str]) -> None:
194
+ self._agent_id = val
150
195
 
151
196
  def get_context_creator(self) -> BaseContextCreator:
152
197
  r"""Returns the context creator used by the memory.
@@ -166,7 +211,8 @@ class LongtermAgentMemory(AgentMemory):
166
211
  """
167
212
  chat_history = self.chat_history_block.retrieve()
168
213
  vector_db_retrieve = self.vector_db_block.retrieve(
169
- self._current_topic, self.retrieve_limit
214
+ self._current_topic,
215
+ self.retrieve_limit,
170
216
  )
171
217
  return chat_history[:1] + vector_db_retrieve + chat_history[1:]
172
218
 
camel/memories/base.py CHANGED
@@ -13,7 +13,7 @@
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
 
15
15
  from abc import ABC, abstractmethod
16
- from typing import List, Tuple
16
+ from typing import List, Optional, Tuple
17
17
 
18
18
  from camel.memories.records import ContextRecord, MemoryRecord
19
19
  from camel.messages import OpenAIMessage
@@ -112,6 +112,16 @@ class AgentMemory(MemoryBlock, ABC):
112
112
  the memory records stored within the AgentMemory.
113
113
  """
114
114
 
115
+ @property
116
+ @abstractmethod
117
+ def agent_id(self) -> Optional[str]:
118
+ pass
119
+
120
+ @agent_id.setter
121
+ @abstractmethod
122
+ def agent_id(self, val: Optional[str]) -> None:
123
+ pass
124
+
115
125
  @abstractmethod
116
126
  def retrieve(self) -> List[ContextRecord]:
117
127
  r"""Get a record list from the memory for creating model context.
@@ -138,3 +148,15 @@ class AgentMemory(MemoryBlock, ABC):
138
148
  context in OpenAIMessage format and the total token count.
139
149
  """
140
150
  return self.get_context_creator().create_context(self.retrieve())
151
+
152
+ def __repr__(self) -> str:
153
+ r"""Returns a string representation of the AgentMemory.
154
+
155
+ Returns:
156
+ str: A string in the format 'ClassName(agent_id=<id>)'
157
+ if agent_id exists, otherwise just 'ClassName()'.
158
+ """
159
+ agent_id = getattr(self, '_agent_id', None)
160
+ if agent_id:
161
+ return f"{self.__class__.__name__}(agent_id='{agent_id}')"
162
+ return f"{self.__class__.__name__}()"
camel/memories/records.py CHANGED
@@ -39,6 +39,8 @@ class MemoryRecord(BaseModel):
39
39
  key-value pairs that provide more information. If not given, it
40
40
  will be an empty `Dict`.
41
41
  timestamp (float, optional): The timestamp when the record was created.
42
+ agent_id (str): The identifier of the agent associated with this
43
+ memory.
42
44
  """
43
45
 
44
46
  model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -50,6 +52,7 @@ class MemoryRecord(BaseModel):
50
52
  timestamp: float = Field(
51
53
  default_factory=lambda: datetime.now(timezone.utc).timestamp()
52
54
  )
55
+ agent_id: str = Field(default="")
53
56
 
54
57
  _MESSAGE_TYPES: ClassVar[dict] = {
55
58
  "BaseMessage": BaseMessage,
@@ -73,6 +76,7 @@ class MemoryRecord(BaseModel):
73
76
  role_at_backend=record_dict["role_at_backend"],
74
77
  extra_info=record_dict["extra_info"],
75
78
  timestamp=record_dict["timestamp"],
79
+ agent_id=record_dict["agent_id"],
76
80
  )
77
81
 
78
82
  def to_dict(self) -> Dict[str, Any]:
@@ -88,6 +92,7 @@ class MemoryRecord(BaseModel):
88
92
  "role_at_backend": self.role_at_backend,
89
93
  "extra_info": self.extra_info,
90
94
  "timestamp": self.timestamp,
95
+ "agent_id": self.agent_id,
91
96
  }
92
97
 
93
98
  def to_openai_message(self) -> OpenAIMessage:
@@ -56,10 +56,8 @@ class OpenAICompatibleModel(BaseModelBackend):
56
56
  url: Optional[str] = None,
57
57
  token_counter: Optional[BaseTokenCounter] = None,
58
58
  ) -> None:
59
- self.api_key = api_key or os.environ.get(
60
- "OPENAI_COMPATIBILITY_API_KEY"
61
- )
62
- self.url = url or os.environ.get("OPENAI_COMPATIBILITY_API_BASE_URL")
59
+ api_key = api_key or os.environ.get("OPENAI_COMPATIBILITY_API_KEY")
60
+ url = url or os.environ.get("OPENAI_COMPATIBILITY_API_BASE_URL")
63
61
  super().__init__(
64
62
  model_type, model_config_dict, api_key, url, token_counter
65
63
  )
@@ -324,7 +324,10 @@ def _kill_process_tree(
324
324
 
325
325
  # Sometime processes cannot be killed with SIGKILL
326
326
  # so we send an additional signal to kill them.
327
- itself.send_signal(signal.SIGQUIT)
327
+ if hasattr(signal, "SIGQUIT"):
328
+ itself.send_signal(signal.SIGQUIT)
329
+ else:
330
+ itself.send_signal(signal.SIGTERM)
328
331
  except psutil.NoSuchProcess:
329
332
  pass
330
333
 
@@ -44,6 +44,31 @@ class StubTokenCounter(BaseTokenCounter):
44
44
  """
45
45
  return 10
46
46
 
47
+ def encode(self, text: str) -> List[int]:
48
+ r"""Encode text into token IDs for STUB models.
49
+
50
+ Args:
51
+ text (str): The text to encode.
52
+
53
+ Returns:
54
+ List[int]: List of token IDs.
55
+ """
56
+ # For stub models, just return a list of 0s with length proportional
57
+ # to text length
58
+ return [0] * (len(text) // 4 + 1) # Simple approximation
59
+
60
def decode(self, token_ids: List[int]) -> str:
    r"""Decode token IDs back to text for STUB models.

    Args:
        token_ids (List[int]): List of token IDs to decode.

    Returns:
        str: Decoded text.
    """
    # Stub models cannot reconstruct text; a fixed placeholder is
    # returned regardless of the input token IDs.
    placeholder = "[Stub decoded text]"
    return placeholder
71
+
47
72
 
48
73
  class StubModel(BaseModelBackend):
49
74
  r"""A dummy model used for unit tests."""
@@ -27,6 +27,7 @@ from camel.storages import (
27
27
  VectorRecord,
28
28
  )
29
29
  from camel.utils import Constants
30
+ from camel.utils.chunker import BaseChunker, UnstructuredIOChunker
30
31
 
31
32
  if TYPE_CHECKING:
32
33
  from unstructured.documents.elements import Element
@@ -78,6 +79,7 @@ class VectorRetriever(BaseRetriever):
78
79
  should_chunk: bool = True,
79
80
  extra_info: Optional[dict] = None,
80
81
  metadata_filename: Optional[str] = None,
82
+ chunker: Optional[BaseChunker] = None,
81
83
  **kwargs: Any,
82
84
  ) -> None:
83
85
  r"""Processes content from local file path, remote URL, string
@@ -101,6 +103,12 @@ class VectorRetriever(BaseRetriever):
101
103
  used for storing metadata. Defaults to None.
102
104
  **kwargs (Any): Additional keyword arguments for content parsing.
103
105
  """
106
+ if chunker is None:
107
+ chunker = UnstructuredIOChunker(
108
+ chunk_type=chunk_type,
109
+ max_characters=max_characters,
110
+ metadata_filename=metadata_filename,
111
+ )
104
112
  from unstructured.documents.elements import Element
105
113
 
106
114
  if isinstance(content, Element):
@@ -140,13 +148,7 @@ class VectorRetriever(BaseRetriever):
140
148
  else:
141
149
  # Chunk the content if required
142
150
  chunks = (
143
- self.uio.chunk_elements(
144
- chunk_type=chunk_type,
145
- elements=elements,
146
- max_characters=max_characters,
147
- )
148
- if should_chunk
149
- else elements
151
+ chunker.chunk(content=elements) if should_chunk else (elements)
150
152
  )
151
153
 
152
154
  # Process chunks in batches and store embeddings
@@ -157,6 +159,7 @@ class VectorRetriever(BaseRetriever):
157
159
  )
158
160
 
159
161
  records = []
162
+ offset = 0
160
163
  # Prepare the payload for each vector record, includes the
161
164
  # content path, chunk metadata, and chunk text
162
165
  for vector, chunk in zip(batch_vectors, batch_chunks):
@@ -178,6 +181,7 @@ class VectorRetriever(BaseRetriever):
178
181
  chunk_metadata["metadata"].pop("orig_elements", "")
179
182
  chunk_metadata["extra_info"] = extra_info or {}
180
183
  chunk_text = {"text": str(chunk)}
184
+ chunk_metadata["metadata"]["piece_num"] = i + offset + 1
181
185
  combined_dict = {
182
186
  **content_path_info,
183
187
  **chunk_metadata,
@@ -187,6 +191,7 @@ class VectorRetriever(BaseRetriever):
187
191
  records.append(
188
192
  VectorRecord(vector=vector, payload=combined_dict)
189
193
  )
194
+ offset += 1
190
195
 
191
196
  self.storage.add(records=records)
192
197
 
@@ -14,7 +14,7 @@
14
14
 
15
15
  from .base import BaseKeyValueStorage
16
16
  from .in_memory import InMemoryKeyValueStorage
17
- from .json import JsonStorage
17
+ from .json import CamelJSONEncoder, JsonStorage
18
18
  from .redis import RedisStorage
19
19
 
20
20
  __all__ = [
@@ -22,4 +22,5 @@ __all__ = [
22
22
  'InMemoryKeyValueStorage',
23
23
  'JsonStorage',
24
24
  'RedisStorage',
25
+ 'CamelJSONEncoder',
25
26
  ]