camel-ai 0.2.77__py3-none-any.whl → 0.2.79a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic. Click here for more details.

@@ -62,6 +62,7 @@ class BaseGenerator(abc.ABC, IterableDataset):
62
62
  self._buffer = buffer
63
63
  self._data: List[DataPoint] = []
64
64
  self._batch_to_save: List[DataPoint] = []
65
+ self._iter_position: int = 0
65
66
 
66
67
  if data_path:
67
68
  file_path = Path(data_path)
@@ -103,9 +104,9 @@ class BaseGenerator(abc.ABC, IterableDataset):
103
104
  r"""Async iterator that yields datapoints dynamically.
104
105
 
105
106
  If a `data_path` was provided during initialization, those datapoints
106
- are yielded first. When self._data is empty, 20 new datapoints
107
- are generated. Every 100 yields, the batch is appended to the
108
- JSONL file or discarded if `cache` is None.
107
+ are yielded first. When self._iter_position reaches the end of _data,
108
+ new datapoints are generated. Every 100 yields, the batch is appended
109
+ to the JSONL file or discarded if `cache` is None.
109
110
 
110
111
  Yields:
111
112
  DataPoint: A single datapoint.
@@ -113,9 +114,10 @@ class BaseGenerator(abc.ABC, IterableDataset):
113
114
 
114
115
  async def generator():
115
116
  while True:
116
- if not self._data:
117
+ if self._iter_position >= len(self._data):
117
118
  await self.generate_new(self._buffer)
118
- datapoint = self._data.pop(0)
119
+ datapoint = self._data[self._iter_position]
120
+ self._iter_position += 1
119
121
  yield datapoint
120
122
  self._batch_to_save.append(datapoint)
121
123
  if len(self._batch_to_save) == 100:
@@ -132,9 +134,9 @@ class BaseGenerator(abc.ABC, IterableDataset):
132
134
  r"""Synchronous iterator for PyTorch IterableDataset compatibility.
133
135
 
134
136
  If a `data_path` was provided during initialization, those datapoints
135
- are yielded first. When self._data is empty, 20 new datapoints
136
- are generated. Every 100 yields, the batch is appended to the
137
- JSONL file or discarded if `cache` is None.
137
+ are yielded first. When self._iter_position reaches the end of _data,
138
+ new datapoints are generated. Every 100 yields, the batch is appended
139
+ to the JSONL file or discarded if `cache` is None.
138
140
 
139
141
  Yields:
140
142
  DataPoint: A single datapoint.
@@ -150,9 +152,10 @@ class BaseGenerator(abc.ABC, IterableDataset):
150
152
  raise
151
153
 
152
154
  while True:
153
- if not self._data:
155
+ if self._iter_position >= len(self._data):
154
156
  asyncio.run(self.generate_new(self._buffer))
155
- datapoint = self._data.pop(0)
157
+ datapoint = self._data[self._iter_position]
158
+ self._iter_position += 1
156
159
  yield datapoint
157
160
  self._batch_to_save.append(datapoint)
158
161
  if len(self._batch_to_save) == 100:
@@ -248,6 +251,7 @@ class BaseGenerator(abc.ABC, IterableDataset):
248
251
 
249
252
  self.save_to_jsonl(file_path)
250
253
  self._data = []
254
+ self._iter_position = 0
251
255
  logger.info(f"Data flushed to {file_path} and cleared from the memory")
252
256
 
253
257
  def _init_from_jsonl(self, file_path: Path) -> List[Dict[str, Any]]:
@@ -290,3 +294,28 @@ class BaseGenerator(abc.ABC, IterableDataset):
290
294
  f"Successfully loaded {len(raw_data)} items from {file_path}"
291
295
  )
292
296
  return raw_data
297
+
298
+ def __getitem__(self, index: int) -> DataPoint:
299
+ r"""Get a datapoint by index without removing the datapoint from _data.
300
+
301
+ Args:
302
+ index (int): Index of the datapoint to retrieve.
303
+
304
+ Returns:
305
+ DataPoint: The datapoint at the specified index.
306
+
307
+ Raises:
308
+ IndexError: If the index is out of range.
309
+ """
310
+ if index < 0 or index >= len(self._data):
311
+ raise IndexError(f"Index {index} is out of range")
312
+
313
+ return self._data[index]
314
+
315
+ def __len__(self) -> int:
316
+ r"""Get the number of datapoints in the dataset.
317
+
318
+ Returns:
319
+ int: The number of datapoints.
320
+ """
321
+ return len(self._data)
@@ -218,9 +218,34 @@ class SingleStepEnv:
218
218
  return observations[0] if batch_size == 1 else observations
219
219
 
220
220
  elif isinstance(self.dataset, BaseGenerator):
221
- self._states = [
222
- await self.dataset.async_sample() for _ in range(batch_size)
223
- ]
221
+ # Generate more data if needed
222
+ if batch_size > len(self.dataset):
223
+ new_datapoints_needed = batch_size - len(self.dataset)
224
+ await self.dataset.generate_new(n=new_datapoints_needed)
225
+
226
+ # Verify that enough data was generated
227
+ if len(self.dataset) < batch_size:
228
+ raise RuntimeError(
229
+ f"Failed to generate enough datapoints. "
230
+ f"Requested {batch_size}, but only "
231
+ f"{len(self.dataset)} available after generation."
232
+ )
233
+
234
+ # Choose sampling strategy based on whether seed is provided
235
+ if seed is not None:
236
+ # Deterministic random sampling when seed is provided
237
+ random_indices = rng.sample(
238
+ range(len(self.dataset)), batch_size
239
+ )
240
+ self._states = [self.dataset[ind] for ind in random_indices]
241
+ else:
242
+ # Sequential sampling when no seed (backward compatible)
243
+ # Use async_sample to maintain sequential behavior
244
+ self._states = [
245
+ await self.dataset.async_sample()
246
+ for _ in range(batch_size)
247
+ ]
248
+
224
249
  self.current_batch_size = batch_size
225
250
  self._states_done = [False] * batch_size
226
251
 
@@ -18,7 +18,7 @@ from .agent_memories import (
18
18
  VectorDBMemory,
19
19
  )
20
20
  from .base import AgentMemory, BaseContextCreator, MemoryBlock
21
- from .blocks.chat_history_block import ChatHistoryBlock, EmptyMemoryWarning
21
+ from .blocks.chat_history_block import ChatHistoryBlock
22
22
  from .blocks.vectordb_block import VectorDBBlock
23
23
  from .context_creators.score_based import ScoreBasedContextCreator
24
24
  from .records import ContextRecord, MemoryRecord
@@ -35,5 +35,4 @@ __all__ = [
35
35
  'ChatHistoryBlock',
36
36
  'VectorDBBlock',
37
37
  'LongtermAgentMemory',
38
- 'EmptyMemoryWarning',
39
38
  ]
@@ -11,7 +11,6 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
- import warnings
15
14
  from typing import List, Optional
16
15
 
17
16
  from camel.memories.base import MemoryBlock
@@ -21,17 +20,6 @@ from camel.storages.key_value_storages.in_memory import InMemoryKeyValueStorage
21
20
  from camel.types import OpenAIBackendRole
22
21
 
23
22
 
24
- class EmptyMemoryWarning(UserWarning):
25
- """Warning raised when attempting to access an empty memory.
26
-
27
- This warning is raised when operations are performed on memory
28
- that contains no records. It can be safely caught and suppressed
29
- in contexts where empty memory is expected.
30
- """
31
-
32
- pass
33
-
34
-
35
23
  class ChatHistoryBlock(MemoryBlock):
36
24
  r"""An implementation of the :obj:`MemoryBlock` abstract base class for
37
25
  maintaining a record of chat histories.
@@ -81,11 +69,8 @@ class ChatHistoryBlock(MemoryBlock):
81
69
  """
82
70
  record_dicts = self.storage.load()
83
71
  if len(record_dicts) == 0:
84
- warnings.warn(
85
- "The `ChatHistoryMemory` is empty.",
86
- EmptyMemoryWarning,
87
- stacklevel=1,
88
- )
72
+ # Empty memory is a valid state (e.g., during initialization).
73
+ # Users can check if memory is empty by checking the returned list.
89
74
  return list()
90
75
 
91
76
  if window_size is not None and window_size >= 0:
@@ -13,17 +13,11 @@
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
 
15
15
  import os
16
- from typing import Any, Dict, List, Optional, Type, Union
17
-
18
- from openai import AsyncStream
19
- from pydantic import BaseModel
16
+ from typing import Any, Dict, Optional, Union
20
17
 
21
18
  from camel.configs import BedrockConfig
22
- from camel.messages import OpenAIMessage
23
19
  from camel.models.openai_compatible_model import OpenAICompatibleModel
24
20
  from camel.types import (
25
- ChatCompletion,
26
- ChatCompletionChunk,
27
21
  ModelType,
28
22
  )
29
23
  from camel.utils import BaseTokenCounter, api_keys_required
@@ -93,13 +87,3 @@ class AWSBedrockModel(OpenAICompatibleModel):
93
87
  max_retries=max_retries,
94
88
  **kwargs,
95
89
  )
96
-
97
- async def _arun(
98
- self,
99
- messages: List[OpenAIMessage],
100
- response_format: Optional[Type[BaseModel]] = None,
101
- tools: Optional[List[Dict[str, Any]]] = None,
102
- ) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
103
- raise NotImplementedError(
104
- "AWS Bedrock does not support async inference."
105
- )
@@ -12,6 +12,7 @@
12
12
  # limitations under the License.
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
 
15
+ import copy
15
16
  import os
16
17
  from typing import Any, Dict, List, Optional, Type, Union
17
18
 
@@ -19,6 +20,7 @@ from openai import AsyncStream
19
20
  from pydantic import BaseModel
20
21
 
21
22
  from camel.configs import MoonshotConfig
23
+ from camel.logger import get_logger
22
24
  from camel.messages import OpenAIMessage
23
25
  from camel.models._utils import try_modify_message_with_format
24
26
  from camel.models.openai_compatible_model import OpenAICompatibleModel
@@ -34,6 +36,8 @@ from camel.utils import (
34
36
  update_langfuse_trace,
35
37
  )
36
38
 
39
+ logger = get_logger(__name__)
40
+
37
41
  if os.environ.get("LANGFUSE_ENABLED", "False").lower() == "true":
38
42
  try:
39
43
  from langfuse.decorators import observe
@@ -84,7 +88,7 @@ class MoonshotModel(OpenAICompatibleModel):
84
88
  model_type: Union[ModelType, str],
85
89
  model_config_dict: Optional[Dict[str, Any]] = None,
86
90
  api_key: Optional[str] = None,
87
- url: Optional[str] = "https://api.moonshot.ai/v1",
91
+ url: Optional[str] = None,
88
92
  token_counter: Optional[BaseTokenCounter] = None,
89
93
  timeout: Optional[float] = None,
90
94
  max_retries: int = 3,
@@ -93,7 +97,12 @@ class MoonshotModel(OpenAICompatibleModel):
93
97
  if model_config_dict is None:
94
98
  model_config_dict = MoonshotConfig().as_dict()
95
99
  api_key = api_key or os.environ.get("MOONSHOT_API_KEY")
96
- url = url or os.environ.get("MOONSHOT_API_BASE_URL")
100
+ # Preserve default URL if not provided
101
+ if url is None:
102
+ url = (
103
+ os.environ.get("MOONSHOT_API_BASE_URL")
104
+ or "https://api.moonshot.ai/v1"
105
+ )
97
106
  timeout = timeout or float(os.environ.get("MODEL_TIMEOUT", 180))
98
107
  super().__init__(
99
108
  model_type=model_type,
@@ -125,12 +134,12 @@ class MoonshotModel(OpenAICompatibleModel):
125
134
  Returns:
126
135
  Dict[str, Any]: The prepared request configuration.
127
136
  """
128
- import copy
129
-
130
137
  request_config = copy.deepcopy(self.model_config_dict)
131
138
 
132
139
  if tools:
133
- request_config["tools"] = tools
140
+ # Clean tools to remove null types (Moonshot API incompatibility)
141
+ cleaned_tools = self._clean_tool_schemas(tools)
142
+ request_config["tools"] = cleaned_tools
134
143
  elif response_format:
135
144
  # Use the same approach as DeepSeek for structured output
136
145
  try_modify_message_with_format(messages[-1], response_format)
@@ -138,6 +147,94 @@ class MoonshotModel(OpenAICompatibleModel):
138
147
 
139
148
  return request_config
140
149
 
150
+ def _clean_tool_schemas(
151
+ self, tools: List[Dict[str, Any]]
152
+ ) -> List[Dict[str, Any]]:
153
+ r"""Clean tool schemas to remove null types for Moonshot compatibility.
154
+
155
+ Moonshot API doesn't accept {"type": "null"} in anyOf schemas.
156
+ This method removes null type definitions from parameters.
157
+
158
+ Args:
159
+ tools (List[Dict[str, Any]]): Original tool schemas.
160
+
161
+ Returns:
162
+ List[Dict[str, Any]]: Cleaned tool schemas.
163
+ """
164
+
165
+ def remove_null_from_schema(schema: Any) -> Any:
166
+ """Recursively remove null types from schema."""
167
+ if isinstance(schema, dict):
168
+ # Create a copy to avoid modifying the original
169
+ result = {}
170
+
171
+ for key, value in schema.items():
172
+ if key == 'type' and isinstance(value, list):
173
+ # Handle type arrays like ["string", "null"]
174
+ filtered_types = [t for t in value if t != 'null']
175
+ if len(filtered_types) == 1:
176
+ # Single type remains, convert to string
177
+ result[key] = filtered_types[0]
178
+ elif len(filtered_types) > 1:
179
+ # Multiple types remain, keep as array
180
+ result[key] = filtered_types
181
+ else:
182
+ # All were null, use string as fallback
183
+ logger.warning(
184
+ "All types in tool schema type array "
185
+ "were null, falling back to 'string' "
186
+ "type for Moonshot API compatibility. "
187
+ "Original tool schema may need review."
188
+ )
189
+ result[key] = 'string'
190
+ elif key == 'anyOf':
191
+ # Handle anyOf with null types
192
+ filtered = [
193
+ item
194
+ for item in value
195
+ if not (
196
+ isinstance(item, dict)
197
+ and item.get('type') == 'null'
198
+ )
199
+ ]
200
+ if len(filtered) == 1:
201
+ # If only one type remains, flatten it
202
+ return remove_null_from_schema(filtered[0])
203
+ elif len(filtered) > 1:
204
+ result[key] = [
205
+ remove_null_from_schema(item)
206
+ for item in filtered
207
+ ]
208
+ else:
209
+ # All were null, return string type as fallback
210
+ logger.warning(
211
+ "All types in tool schema anyOf were null, "
212
+ "falling back to 'string' type for "
213
+ "Moonshot API compatibility. Original "
214
+ "tool schema may need review."
215
+ )
216
+ return {"type": "string"}
217
+ else:
218
+ # Recursively process other values
219
+ result[key] = remove_null_from_schema(value)
220
+
221
+ return result
222
+ elif isinstance(schema, list):
223
+ return [remove_null_from_schema(item) for item in schema]
224
+ else:
225
+ return schema
226
+
227
+ cleaned_tools = copy.deepcopy(tools)
228
+ for tool in cleaned_tools:
229
+ if 'function' in tool and 'parameters' in tool['function']:
230
+ params = tool['function']['parameters']
231
+ if 'properties' in params:
232
+ params['properties'] = remove_null_from_schema(
233
+ params['properties']
234
+ )
235
+
236
+ return cleaned_tools
237
+
141
238
  @observe()
142
239
  async def _arun(
143
240
  self,
@@ -0,0 +1,122 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from __future__ import annotations
15
+
16
+ from datetime import datetime, timezone
17
+ from typing import Any, Dict, List, Literal, Optional, Union
18
+
19
+ from pydantic import BaseModel, ConfigDict, Field
20
+
21
+
22
# Common base for all workforce event models in this module.
# `#` comments are used instead of a class docstring on purpose: pydantic
# copies class docstrings into the generated JSON schema description.
class WorkforceEventBase(BaseModel):
    # frozen=True makes instances immutable after creation;
    # extra='forbid' rejects fields that are not declared on the model.
    model_config = ConfigDict(frozen=True, extra='forbid')

    # Tag identifying the kind of event; each subclass narrows this to a
    # single literal value and provides it as the default.
    event_type: Literal[
        "task_decomposed",
        "task_created",
        "task_assigned",
        "task_started",
        "task_completed",
        "task_failed",
        "worker_created",
        "worker_deleted",
        "queue_status",
        "all_tasks_completed",
    ]

    # Optional free-form payload attached to the event.
    metadata: Optional[Dict[str, Any]] = None

    # Defaults to the current UTC time (timezone-aware) when the event
    # object is created.
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc)
    )
40
+
41
+
42
# Event tagged "worker_created"; all three worker fields are required.
class WorkerCreatedEvent(WorkforceEventBase):
    event_type: Literal["worker_created"] = "worker_created"

    worker_id: str
    worker_type: str
    role: str
47
+
48
+
49
# Event tagged "worker_deleted"; `reason` may be omitted.
class WorkerDeletedEvent(WorkforceEventBase):
    event_type: Literal["worker_deleted"] = "worker_deleted"

    worker_id: str
    reason: Optional[str] = None
53
+
54
+
55
# Event tagged "task_decomposed"; links a parent task id to the ids of
# its subtasks.
class TaskDecomposedEvent(WorkforceEventBase):
    event_type: Literal["task_decomposed"] = "task_decomposed"

    parent_task_id: str
    subtask_ids: List[str]
59
+
60
+
61
# Event tagged "task_created"; `task_id` and `description` are required,
# the parent id and task type are optional.
class TaskCreatedEvent(WorkforceEventBase):
    event_type: Literal["task_created"] = "task_created"

    task_id: str
    description: str
    parent_task_id: Optional[str] = None
    task_type: Optional[str] = None
67
+
68
+
69
# Event tagged "task_assigned"; pairs a task id with a worker id.
# Queue time and dependency ids are optional extras.
class TaskAssignedEvent(WorkforceEventBase):
    event_type: Literal["task_assigned"] = "task_assigned"

    task_id: str
    worker_id: str
    queue_time_seconds: Optional[float] = None
    dependencies: Optional[List[str]] = None
75
+
76
+
77
# Event tagged "task_started"; both ids are required.
class TaskStartedEvent(WorkforceEventBase):
    event_type: Literal["task_started"] = "task_started"

    task_id: str
    worker_id: str
81
+
82
+
83
# Event tagged "task_completed"; the summary, processing time and token
# usage mapping are all optional.
class TaskCompletedEvent(WorkforceEventBase):
    event_type: Literal["task_completed"] = "task_completed"

    task_id: str
    worker_id: str
    result_summary: Optional[str] = None
    processing_time_seconds: Optional[float] = None
    token_usage: Optional[Dict[str, int]] = None
90
+
91
+
92
# Event tagged "task_failed"; unlike the other task events, `worker_id`
# is optional here.
class TaskFailedEvent(WorkforceEventBase):
    event_type: Literal["task_failed"] = "task_failed"

    task_id: str
    error_message: str
    worker_id: Optional[str] = None
97
+
98
+
99
# Event tagged "all_tasks_completed"; adds no fields beyond the base.
class AllTasksCompletedEvent(WorkforceEventBase):
    event_type: Literal["all_tasks_completed"] = "all_tasks_completed"
101
+
102
+
103
# Event tagged "queue_status"; reports a queue name and its length, with
# an optional list of pending task ids.
class QueueStatusEvent(WorkforceEventBase):
    event_type: Literal["queue_status"] = "queue_status"

    queue_name: str
    length: int
    pending_task_ids: Optional[List[str]] = None
    # NOTE(review): re-declared with the same type and default as on
    # WorkforceEventBase; looks redundant - confirm before removing.
    metadata: Optional[Dict[str, Any]] = None
109
+
110
+
111
# Union of every concrete event model defined in this module, for use in
# annotations that accept any workforce event.
WorkforceEvent = Union[
    TaskDecomposedEvent,
    TaskCreatedEvent,
    TaskAssignedEvent,
    TaskStartedEvent,
    TaskCompletedEvent,
    TaskFailedEvent,
    WorkerCreatedEvent,
    WorkerDeletedEvent,
    AllTasksCompletedEvent,
    QueueStatusEvent,
]