camel-ai 0.2.67__py3-none-any.whl → 0.2.68__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic.

@@ -55,11 +55,8 @@ RUN curl -fsSL https://install.python-poetry.org | python3.10 - && \
  # Upgrade pip and install base Python packages
  RUN python3.10 -m pip install --upgrade pip setuptools wheel

- # Install uv
- RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
- mv /root/.local/bin/uv /usr/local/bin/uv && \
- mv /root/.local/bin/uvx /usr/local/bin/uvx && \
- chmod +x /usr/local/bin/uv /usr/local/bin/uvx
+ # Install uv using pip instead of the shell script
+ RUN pip install uv

  # Setup working directory
  WORKDIR /workspace
@@ -98,8 +98,8 @@ class Firecrawl:
  def scrape(
  self,
  url: str,
- params: Optional[Dict[str, Any]] = None,
- ) -> Dict:
+ params: Optional[Dict[str, str]] = None,
+ ) -> Dict[str, str]:
  r"""To scrape a single URL. This function supports advanced scraping
  by setting different parameters and returns the full scraped data as a
  dictionary.
@@ -108,11 +108,11 @@ class Firecrawl:

  Args:
  url (str): The URL to read.
- params (Optional[Dict[str, Any]]): Additional parameters for the
+ params (Optional[Dict[str, str]]): Additional parameters for the
  scrape request.

  Returns:
- Dict: The scraped data.
+ Dict[str, str]: The scraped data.

  Raises:
  RuntimeError: If the scrape process fails.
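For readers skimming the diff, a hedged usage sketch of the narrowed `scrape` signature. The `camel.loaders` import path and the `FIRECRAWL_API_KEY` environment variable reflect typical Firecrawl setup rather than anything in this diff, and the example `params` key/value is purely illustrative:

```python
import os
from typing import Dict, Optional

from camel.loaders import Firecrawl  # import path assumed

# Assumes credentials are provided via the environment.
os.environ.setdefault("FIRECRAWL_API_KEY", "<your-key>")

firecrawl = Firecrawl()

# Both the params argument and the return value are now typed Dict[str, str].
params: Optional[Dict[str, str]] = {"onlyMainContent": "true"}  # illustrative only
result: Dict[str, str] = firecrawl.scrape(
    url="https://www.camel-ai.org/",
    params=params,
)
print(result.get("markdown", ""))
```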
@@ -89,13 +89,20 @@ class VectorDBBlock(MemoryBlock):
  records (List[MemoryRecord]): Memory records to be added to the
  memory.
  """
+ # Filter out records with empty message content
+ valid_records = [
+ record
+ for record in records
+ if record.message.content and record.message.content.strip()
+ ]
+
  v_records = [
  VectorRecord(
  vector=self.embedding.embed(record.message.content),
  payload=record.to_dict(),
  id=str(record.uuid),
  )
- for record in records
+ for record in valid_records
  ]
  self.storage.add(v_records)

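Stripped of the surrounding class, the new guard is just a list comprehension that drops records whose message content is empty or whitespace-only before anything is embedded. A minimal standalone sketch of the same pattern, using a hypothetical `Record` dataclass in place of `MemoryRecord`:

```python
from dataclasses import dataclass


@dataclass
class Record:
    """Hypothetical stand-in for MemoryRecord; only the content field matters here."""

    content: str


records = [Record("hello"), Record(""), Record("   "), Record("world")]

# Keep only records whose content is non-empty after stripping whitespace,
# mirroring the filter added to VectorDBBlock before embedding.
valid_records = [r for r in records if r.content and r.content.strip()]

assert [r.content for r in valid_records] == ["hello", "world"]
```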
@@ -11,7 +11,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
- from typing import List, Optional, Tuple
+ from typing import Dict, List, Optional, Tuple

  from pydantic import BaseModel

@@ -74,6 +74,8 @@ class ScoreBasedContextCreator(BaseContextCreator):
  3. Final output maintains chronological order and in history memory,
  the score of each message decreases according to keep_rate. The
  newer the message, the higher the score.
+ 4. Tool calls and their responses are kept together to maintain
+ API compatibility

  Args:
  records (List[ContextRecord]): List of context records with scores
@@ -126,12 +128,17 @@ class ScoreBasedContextCreator(BaseContextCreator):
  )

  # ======================
- # 3. Token Calculation
+ # 3. Tool Call Relationship Mapping
+ # ======================
+ tool_call_groups = self._group_tool_calls_and_responses(regular_units)
+
+ # ======================
+ # 4. Token Calculation
  # ======================
  total_tokens = system_tokens + sum(u.num_tokens for u in regular_units)

  # ======================
- # 4. Early Return if Within Limit
+ # 5. Early Return if Within Limit
  # ======================
  if total_tokens <= self.token_limit:
  sorted_units = sorted(
@@ -140,7 +147,7 @@ class ScoreBasedContextCreator(BaseContextCreator):
  return self._assemble_output(sorted_units, system_unit)

  # ======================
- # 5. Truncation Logic
+ # 6. Truncation Logic with Tool Call Awareness
  # ======================
  logger.warning(
  f"Context truncation required "
@@ -148,24 +155,12 @@ class ScoreBasedContextCreator(BaseContextCreator):
  f"pruning low-score messages."
  )

- # Sort for truncation: high scores first, older messages first at same
- # score
- sorted_for_truncation = sorted(
- regular_units, key=self._truncation_sort_key
+ remaining_units = self._truncate_with_tool_call_awareness(
+ regular_units, tool_call_groups, system_tokens
  )

- # Reverse to process from lowest score (end of sorted list)
- remaining_units = []
- current_total = system_tokens
-
- for unit in sorted_for_truncation:
- potential_total = current_total + unit.num_tokens
- if potential_total <= self.token_limit:
- remaining_units.append(unit)
- current_total = potential_total
-
  # ======================
- # 6. Output Assembly
+ # 7. Output Assembly
  # ======================

  # In case system message is the only message in memory when sorted
@@ -180,6 +175,115 @@ class ScoreBasedContextCreator(BaseContextCreator):
  final_units = sorted(remaining_units, key=self._conversation_sort_key)
  return self._assemble_output(final_units, system_unit)

+ def _group_tool_calls_and_responses(
+ self, units: List[_ContextUnit]
+ ) -> Dict[str, List[_ContextUnit]]:
+ r"""Groups tool calls with their corresponding responses.
+
+ Args:
+ units (List[_ContextUnit]): List of context units to analyze
+
+ Returns:
+ Dict[str, List[_ContextUnit]]: Mapping from tool_call_id to list of
+ related units (tool call + responses)
+ """
+ tool_call_groups: Dict[str, List[_ContextUnit]] = {}
+
+ for unit in units:
+ message = unit.record.memory_record.message
+ backend_role = unit.record.memory_record.role_at_backend
+
+ # Check if this is a tool call message
+ if hasattr(message, 'func_name') and hasattr(
+ message, 'tool_call_id'
+ ):
+ tool_call_id = getattr(message, 'tool_call_id', None)
+ if tool_call_id:
+ if tool_call_id not in tool_call_groups:
+ tool_call_groups[tool_call_id] = []
+ tool_call_groups[tool_call_id].append(unit)
+
+ # Check if this is a tool response message
+ elif backend_role == OpenAIBackendRole.FUNCTION:
+ tool_call_id = None
+ if hasattr(message, 'tool_call_id'):
+ tool_call_id = getattr(message, 'tool_call_id', None)
+ elif hasattr(message, 'result') and hasattr(
+ message, 'tool_call_id'
+ ):
+ tool_call_id = getattr(message, 'tool_call_id', None)
+
+ if tool_call_id:
+ if tool_call_id not in tool_call_groups:
+ tool_call_groups[tool_call_id] = []
+ tool_call_groups[tool_call_id].append(unit)
+
+ return tool_call_groups
+
+ def _truncate_with_tool_call_awareness(
+ self,
+ regular_units: List[_ContextUnit],
+ tool_call_groups: Dict[str, List[_ContextUnit]],
+ system_tokens: int,
+ ) -> List[_ContextUnit]:
+ r"""Truncates messages while preserving tool call-response pairs.
+
+ Args:
+ regular_units (List[_ContextUnit]): All regular message units
+ tool_call_groups (Dict[str, List[_ContextUnit]]): Grouped tool
+ calls
+ system_tokens (int): Tokens used by system message
+
+ Returns:
+ List[_ContextUnit]: Units that fit within token limit
+ """
+ # Create sets for quick lookup of tool call related units
+ tool_call_unit_ids = set()
+ for group in tool_call_groups.values():
+ for unit in group:
+ tool_call_unit_ids.add(unit.record.memory_record.uuid)
+
+ # Separate tool call groups and standalone units
+ standalone_units = [
+ u
+ for u in regular_units
+ if u.record.memory_record.uuid not in tool_call_unit_ids
+ ]
+
+ # Sort standalone units for truncation (high scores first)
+ standalone_units.sort(key=self._truncation_sort_key)
+
+ # Sort tool call groups by their best (highest) score
+ sorted_tool_groups = []
+ for _tool_call_id, group in tool_call_groups.items():
+ # Use the highest score in the group as the group's score
+ best_score = max(unit.record.score for unit in group)
+ latest_timestamp = max(unit.record.timestamp for unit in group)
+ group_tokens = sum(unit.num_tokens for unit in group)
+ sorted_tool_groups.append(
+ ((-best_score, -latest_timestamp), group, group_tokens)
+ )
+
+ sorted_tool_groups.sort(key=lambda x: x[0])
+
+ # Greedy selection to fit within token limit
+ remaining_units = []
+ current_tokens = system_tokens
+
+ # First, try to include complete tool call groups
+ for _, group, group_tokens in sorted_tool_groups:
+ if current_tokens + group_tokens <= self.token_limit:
+ remaining_units.extend(group)
+ current_tokens += group_tokens
+
+ # Then, include standalone units
+ for unit in standalone_units:
+ if current_tokens + unit.num_tokens <= self.token_limit:
+ remaining_units.append(unit)
+ current_tokens += unit.num_tokens
+
+ return remaining_units
+
  def _extract_system_message(
  self, records: List[ContextRecord]
  ) -> Tuple[Optional[_ContextUnit], List[_ContextUnit]]:
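The heart of `_truncate_with_tool_call_awareness` is a greedy, knapsack-style pass that treats a tool call and its responses as one indivisible block: the whole group either fits under the token budget or is dropped together, so the API never sees a tool response without its originating call. A simplified standalone sketch of that selection logic, with a hypothetical `Unit` tuple standing in for `_ContextUnit`:

```python
from typing import Dict, List, NamedTuple


class Unit(NamedTuple):
    """Hypothetical simplified stand-in for _ContextUnit."""

    uid: str
    tokens: int
    score: float
    tool_call_id: str = ""  # empty string means a standalone message


def truncate(units: List[Unit], token_limit: int, system_tokens: int) -> List[Unit]:
    # Group units that share a tool_call_id so a call and its responses
    # are kept or dropped as one block.
    groups: Dict[str, List[Unit]] = {}
    standalone: List[Unit] = []
    for unit in units:
        if unit.tool_call_id:
            groups.setdefault(unit.tool_call_id, []).append(unit)
        else:
            standalone.append(unit)

    kept: List[Unit] = []
    budget = token_limit - system_tokens

    # Greedily take whole groups, highest best-score first.
    for group in sorted(groups.values(), key=lambda g: -max(u.score for u in g)):
        cost = sum(u.tokens for u in group)
        if cost <= budget:
            kept.extend(group)
            budget -= cost

    # Then fill the remaining budget with standalone messages.
    for unit in sorted(standalone, key=lambda u: -u.score):
        if unit.tokens <= budget:
            kept.append(unit)
            budget -= unit.tokens

    return kept


# The tool response survives only because its whole group fits the budget.
units = [
    Unit("call", 40, 0.9, "c1"),
    Unit("resp", 40, 0.2, "c1"),
    Unit("chat", 30, 0.5),
]
print(truncate(units, token_limit=120, system_tokens=10))
```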
@@ -12,11 +12,14 @@
  # limitations under the License.
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
  import os
- from typing import Any, Dict, Optional, Union
+ from typing import Any, Dict, List, Optional, Union
+
+ from openai import AsyncStream, Stream

  from camel.configs import ANTHROPIC_API_PARAMS, AnthropicConfig
+ from camel.messages import OpenAIMessage
  from camel.models.openai_compatible_model import OpenAICompatibleModel
- from camel.types import ModelType
+ from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
  from camel.utils import (
  AnthropicTokenCounter,
  BaseTokenCounter,
@@ -25,6 +28,47 @@ from camel.utils import (
  )


+ def strip_trailing_whitespace_from_messages(
+ messages: List[OpenAIMessage],
+ ) -> List[OpenAIMessage]:
+ r"""Strip trailing whitespace from all message contents in a list of
+ messages. This is necessary because the Anthropic API doesn't allow
+ trailing whitespace in message content.
+
+ Args:
+ messages (List[OpenAIMessage]): List of messages to process
+
+ Returns:
+ List[OpenAIMessage]: The processed messages with trailing whitespace
+ removed
+ """
+ if not messages:
+ return messages
+
+ # Create a deep copy to avoid modifying the original messages
+ processed_messages = [dict(msg) for msg in messages]
+
+ # Process each message
+ for msg in processed_messages:
+ if "content" in msg and msg["content"] is not None:
+ if isinstance(msg["content"], str):
+ msg["content"] = msg["content"].rstrip()
+ elif isinstance(msg["content"], list):
+ # Handle content that's a list of content parts (e.g., for
+ # multimodal content)
+ for i, part in enumerate(msg["content"]):
+ if (
+ isinstance(part, dict)
+ and "text" in part
+ and isinstance(part["text"], str)
+ ):
+ part["text"] = part["text"].rstrip()
+ elif isinstance(part, str):
+ msg["content"][i] = part.rstrip()
+
+ return processed_messages # type: ignore[return-value]
+
+
  class AnthropicModel(OpenAICompatibleModel):
  r"""Anthropic API in a unified OpenAICompatibleModel interface.

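The helper's effect is easiest to see on a single message: string contents are `rstrip()`-ed and, for multimodal content lists, only the `"text"` fields of text parts are touched. A tiny standalone demonstration of the same normalization (the message dict below is a generic OpenAI-style example, not taken from the diff):

```python
# A hypothetical message with trailing whitespace that the Anthropic API rejects.
msg = {"role": "user", "content": "Summarize this document.  \n"}

# Copy first, then strip, matching the helper's non-mutating behaviour.
cleaned = dict(msg)
if isinstance(cleaned["content"], str):
    cleaned["content"] = cleaned["content"].rstrip()

assert cleaned["content"] == "Summarize this document."
assert msg["content"].endswith("\n")  # the original message is left untouched
```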
@@ -89,6 +133,9 @@ class AnthropicModel(OpenAICompatibleModel):
  **kwargs,
  )

+ # Monkey patch the AnthropicTokenCounter to handle trailing whitespace
+ self._patch_anthropic_token_counter()
+
  @property
  def token_counter(self) -> BaseTokenCounter:
  r"""Initialize the token counter for the model backend.
@@ -115,3 +162,68 @@ class AnthropicModel(OpenAICompatibleModel):
  f"Unexpected argument `{param}` is "
  "input into Anthropic model backend."
  )
+
+ def _request_chat_completion(
+ self,
+ messages: List[OpenAIMessage],
+ tools: Optional[List[Dict[str, Any]]] = None,
+ ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
+ # Strip trailing whitespace from all message contents to prevent
+ # Anthropic API errors
+ processed_messages = strip_trailing_whitespace_from_messages(messages)
+
+ # Call the parent class method
+ return super()._request_chat_completion(processed_messages, tools)
+
+ async def _arequest_chat_completion(
+ self,
+ messages: List[OpenAIMessage],
+ tools: Optional[List[Dict[str, Any]]] = None,
+ ) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
+ # Strip trailing whitespace from all message contents to prevent
+ # Anthropic API errors
+ processed_messages = strip_trailing_whitespace_from_messages(messages)
+
+ # Call the parent class method
+ return await super()._arequest_chat_completion(
+ processed_messages, tools
+ )
+
+ def _patch_anthropic_token_counter(self):
+ r"""Monkey patch the AnthropicTokenCounter class to handle trailing
+ whitespace.
+
+ This patches the count_tokens_from_messages method to strip trailing
+ whitespace from message content before sending to the Anthropic API.
+ """
+ import functools
+
+ from anthropic.types import MessageParam
+
+ from camel.utils import AnthropicTokenCounter
+
+ original_count_tokens = (
+ AnthropicTokenCounter.count_tokens_from_messages
+ )
+
+ @functools.wraps(original_count_tokens)
+ def patched_count_tokens(self, messages):
+ # Process messages to remove trailing whitespace
+ processed_messages = strip_trailing_whitespace_from_messages(
+ messages
+ )
+
+ # Use the processed messages with the original method
+ return self.client.messages.count_tokens(
+ messages=[
+ MessageParam(
+ content=str(msg["content"]),
+ role="user" if msg["role"] == "user" else "assistant",
+ )
+ for msg in processed_messages
+ ],
+ model=self.model,
+ ).input_tokens
+
+ # Apply the monkey patch
+ AnthropicTokenCounter.count_tokens_from_messages = patched_count_tokens
@@ -88,16 +88,17 @@ Do not include any other text, explanations, justifications, or conversational f

  PROCESS_TASK_PROMPT = TextPrompt(
  """You need to process one given task.
- Here are results of some prerequisite tasks that you can refer to:
+
+ Please keep in mind the task you are going to process, the content of the task that you need to do is:

  ==============================
- {dependency_tasks_info}
+ {content}
  ==============================

- The content of the task that you need to do is:
+ Here are results of some prerequisite tasks that you can refer to:

  ==============================
- {content}
+ {dependency_tasks_info}
  ==============================

  Here are some additional information about the task:
@@ -182,11 +183,43 @@ Now you should summarize the scenario and return the result of the task.
  """
  )

- WF_TASK_DECOMPOSE_PROMPT = r"""You need to decompose the given task into subtasks according to the workers available in the group, following these important principles:
+ WF_TASK_DECOMPOSE_PROMPT = r"""You need to decompose the given task into subtasks according to the workers available in the group, following these important principles to maximize efficiency and parallelism:
+
+ 1. **Strategic Grouping for Sequential Work**:
+ * If a series of steps must be done in order *and* can be handled by the same worker type, group them into a single subtask to maintain flow and minimize handoffs.
+
+ 2. **Aggressive Parallelization**:
+ * **Across Different Worker Specializations**: If distinct phases of the overall task require different types of workers (e.g., research by a 'SearchAgent', then content creation by a 'DocumentAgent'), define these as separate subtasks.
+ * **Within a Single Phase (Data/Task Parallelism)**: If a phase involves repetitive operations on multiple items (e.g., processing 10 documents, fetching 5 web pages, analyzing 3 datasets):
+ * Decompose this into parallel subtasks, one for each item or a small batch of items.
+ * This applies even if the same type of worker handles these parallel subtasks. The goal is to leverage multiple available workers or allow concurrent processing.
+
+ 3. **Subtask Design for Efficiency**:
+ * **Actionable and Well-Defined**: Each subtask should have a clear, achievable goal.
+ * **Balanced Granularity**: Make subtasks large enough to be meaningful but small enough to enable parallelism and quick feedback. Avoid overly large subtasks that hide parallel opportunities.
+ * **Consider Dependencies**: While you list tasks sequentially, think about the true dependencies. The workforce manager will handle execution based on these implied dependencies and worker availability.
+
+ These principles aim to reduce overall completion time by maximizing concurrent work and effectively utilizing all available worker capabilities.
+
+ **EXAMPLE FORMAT ONLY** (DO NOT use this example content for actual task decomposition):
+
+ If given a hypothetical task requiring research, analysis, and reporting with multiple items to process, you should decompose it to maximize parallelism:
+
+ * Poor decomposition (monolithic):
+ `<tasks><task>Do all research, analysis, and write final report.</task></tasks>`
+
+ * Better decomposition (parallel structure):
+ ```
+ <tasks>
+ <task>Subtask 1 (ResearchAgent): Gather initial data and resources.</task>
+ <task>Subtask 2.1 (AnalysisAgent): Analyze Item A from Subtask 1 results.</task>
+ <task>Subtask 2.2 (AnalysisAgent): Analyze Item B from Subtask 1 results.</task>
+ <task>Subtask 2.N (AnalysisAgent): Analyze Item N from Subtask 1 results.</task>
+ <task>Subtask 3 (ReportAgent): Compile all analyses into final report.</task>
+ </tasks>
+ ```

- 1. Keep tasks that are sequential and require the same type of worker together in one subtask
- 2. Only decompose tasks that can be handled in parallel and require different types of workers
- 3. This ensures efficient execution by minimizing context switching between workers
+ **END OF FORMAT EXAMPLE** - Now apply this structure to your actual task below.

  The content of the task is:

@@ -207,7 +240,7 @@ Following are the available workers, given in the format <ID>: <description>.
  {child_nodes_info}
  ==============================

- You must return the subtasks in the format of a numbered list within <tasks> tags, as shown below:
+ You must return the subtasks as a list of individual subtasks within <tasks> tags. If your decomposition, following the principles and detailed example above (e.g., for summarizing multiple papers), results in several parallelizable actions, EACH of those actions must be represented as a separate <task> entry. For instance, the general format is:

  <tasks>
  <task>Subtask 1</task>
@@ -32,15 +32,22 @@ class SingleAgentWorker(Worker):
  Args:
  description (str): Description of the node.
  worker (ChatAgent): Worker of the node. A single agent.
+ max_concurrent_tasks (int): Maximum number of tasks this worker can
+ process concurrently. (default: :obj:`10`)
  """

  def __init__(
  self,
  description: str,
  worker: ChatAgent,
+ max_concurrent_tasks: int = 10,
  ) -> None:
  node_id = worker.agent_id
- super().__init__(description, node_id=node_id)
+ super().__init__(
+ description,
+ node_id=node_id,
+ max_concurrent_tasks=max_concurrent_tasks,
+ )
  self.worker = worker

  def reset(self) -> Any:
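A hedged construction example for the new keyword; only `max_concurrent_tasks` comes from this diff, while the `ChatAgent` setup and the `camel.societies.workforce` import path are assumed placeholders:

```python
from camel.agents import ChatAgent
from camel.societies.workforce import SingleAgentWorker  # import path assumed

# Placeholder agent; any configured ChatAgent would do here.
agent = ChatAgent("You are a helpful data analyst.")

# The new keyword caps how many tasks this node may process concurrently
# (it defaults to 10 when omitted).
worker = SingleAgentWorker(
    description="Analyzes datasets handed off by the coordinator.",
    worker=agent,
    max_concurrent_tasks=4,
)
```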
@@ -51,11 +58,12 @@ class SingleAgentWorker(Worker):
  async def _process_task(
  self, task: Task, dependencies: List[Task]
  ) -> TaskState:
- r"""Processes a task with its dependencies.
+ r"""Processes a task with its dependencies using a cloned agent.

  This method asynchronously processes a given task, considering its
- dependencies, by sending a generated prompt to a worker. It updates
- the task's result based on the agent's response.
+ dependencies, by sending a generated prompt to a cloned worker agent.
+ Using a cloned agent ensures that concurrent tasks don't interfere
+ with each other's state.

  Args:
  task (Task): The task to process, which includes necessary details
@@ -66,6 +74,10 @@ class SingleAgentWorker(Worker):
  TaskState: `TaskState.DONE` if processed successfully, otherwise
  `TaskState.FAILED`.
  """
+ # Clone the agent for this specific task to avoid state conflicts
+ # when processing multiple tasks concurrently
+ worker_agent = self.worker.clone(with_memory=False)
+
  dependency_tasks_info = self._get_dep_tasks_info(dependencies)
  prompt = PROCESS_TASK_PROMPT.format(
  content=task.content,
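The reason for cloning is that a single ChatAgent carries conversation memory, so two tasks stepping the same instance concurrently would interleave their histories. A hedged sketch of the clone-per-task pattern with a hypothetical `FakeAgent` in place of the real ChatAgent:

```python
import asyncio
from typing import List


class FakeAgent:
    """Hypothetical stand-in for ChatAgent: keeps per-conversation history."""

    def __init__(self) -> None:
        self.history: List[str] = []

    def clone(self, with_memory: bool = False) -> "FakeAgent":
        clone = FakeAgent()
        if with_memory:
            clone.history = list(self.history)
        return clone

    async def astep(self, prompt: str) -> str:
        self.history.append(prompt)
        await asyncio.sleep(0)
        return f"answer to: {prompt}"


async def process(shared: FakeAgent, task: str) -> str:
    # Each task gets its own clone, so concurrent tasks never share history.
    agent = shared.clone(with_memory=False)
    return await agent.astep(task)


async def main() -> None:
    shared = FakeAgent()
    results = await asyncio.gather(*(process(shared, t) for t in ["t1", "t2", "t3"]))
    print(results)         # three independent answers
    print(shared.history)  # [] -- the shared agent's memory stays untouched


asyncio.run(main())
```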
@@ -73,7 +85,7 @@ class SingleAgentWorker(Worker):
  additional_info=task.additional_info,
  )
  try:
- response = await self.worker.astep(
+ response = await worker_agent.astep(
  prompt, response_format=TaskResult
  )
  except Exception as e:
@@ -83,6 +95,13 @@ class SingleAgentWorker(Worker):
  )
  return TaskState.FAILED

+ # Get actual token usage from the cloned agent that processed this task
+ try:
+ _, total_token_count = worker_agent.memory.get_context()
+ except Exception:
+ # Fallback if memory context unavailable
+ total_token_count = 0
+
  # Populate additional_info with worker attempt details
  if task.additional_info is None:
  task.additional_info = {}
@@ -90,14 +109,20 @@ class SingleAgentWorker(Worker):

  # Create worker attempt details with descriptive keys
  worker_attempt_details = {
  "agent_id": getattr(
+ worker_agent, "agent_id", worker_agent.role_name
+ ),
+ "original_worker_id": getattr(
  self.worker, "agent_id", self.worker.role_name
  ),
  "timestamp": str(datetime.datetime.now()),
  "description": f"Attempt by "
- f"{getattr(self.worker, 'agent_id', self.worker.role_name)} "
+ f"{getattr(worker_agent, 'agent_id', worker_agent.role_name)} "
+ f"(cloned from "
+ f"{getattr(self.worker, 'agent_id', self.worker.role_name)}) "
  f"to process task {task.content}",
  "response_content": response.msg.content,
  "tool_calls": response.info["tool_calls"],
+ "total_token_count": total_token_count,
  }

@@ -105,6 +130,11 @@ class SingleAgentWorker(Worker):
  task.additional_info["worker_attempts"] = []
  task.additional_info["worker_attempts"].append(worker_attempt_details)

+ # Store the actual token usage for this specific task
+ task.additional_info["token_usage"] = {
+ "total_tokens": total_token_count
+ }
+
  print(f"======\n{Fore.GREEN}Reply from {self}:{Fore.RESET}")

  result_dict = json.loads(response.msg.content)
@@ -23,6 +23,8 @@ class PacketStatus(Enum):
  states:

  - ``SENT``: The packet has been sent to a worker.
+ - ``PROCESSING``: The packet has been claimed by a worker and is being
+ processed.
  - ``RETURNED``: The packet has been returned by the worker, meaning that
  the status of the task inside has been updated.
  - ``ARCHIVED``: The packet has been archived, meaning that the content of
@@ -31,6 +33,7 @@ class PacketStatus(Enum):
  """

  SENT = "SENT"
+ PROCESSING = "PROCESSING"
  RETURNED = "RETURNED"
  ARCHIVED = "ARCHIVED"

@@ -97,8 +100,9 @@ class TaskChannel:
  await self._condition.wait()

  async def get_assigned_task_by_assignee(self, assignee_id: str) -> Task:
- r"""Get a task from the channel that has been assigned to the
- assignee.
+ r"""Atomically get and claim a task from the channel that has been
+ assigned to the assignee. This prevents race conditions where multiple
+ concurrent calls might retrieve the same task.
  """
  async with self._condition:
  while True:
@@ -107,6 +111,9 @@ class TaskChannel:
  packet.status == PacketStatus.SENT
  and packet.assignee_id == assignee_id
  ):
+ # Atomically claim the task by changing its status
+ packet.status = PacketStatus.PROCESSING
+ self._condition.notify_all()
  return packet.task
  await self._condition.wait()
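The race the new PROCESSING status closes: with only SENT, two workers awaiting the same condition could both observe the same packet before either marked it returned. Flipping the status while still holding the condition's lock makes the claim atomic. A simplified sketch of the claim loop under an `asyncio.Condition`, with a hypothetical `Packet`/`MiniChannel` reduction of the real classes:

```python
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import List


class PacketStatus(Enum):
    SENT = "SENT"
    PROCESSING = "PROCESSING"
    RETURNED = "RETURNED"


@dataclass
class Packet:
    """Hypothetical reduction of the channel's internal packet."""

    task_id: str
    assignee_id: str
    status: PacketStatus = PacketStatus.SENT


class MiniChannel:
    def __init__(self) -> None:
        self._packets: List[Packet] = []
        self._condition = asyncio.Condition()

    async def post(self, packet: Packet) -> None:
        async with self._condition:
            self._packets.append(packet)
            self._condition.notify_all()

    async def claim(self, assignee_id: str) -> Packet:
        async with self._condition:
            while True:
                for packet in self._packets:
                    if (
                        packet.status is PacketStatus.SENT
                        and packet.assignee_id == assignee_id
                    ):
                        # Flip the status while still holding the condition's
                        # lock, so a concurrent claim can no longer see SENT.
                        packet.status = PacketStatus.PROCESSING
                        self._condition.notify_all()
                        return packet
                await self._condition.wait()


async def main() -> None:
    channel = MiniChannel()
    await channel.post(Packet("task-1", "worker-a"))
    results = await asyncio.gather(
        channel.claim("worker-a"),
        channel.post(Packet("task-2", "worker-a")),
        channel.claim("worker-a"),
    )
    claimed = [r.task_id for r in results if isinstance(r, Packet)]
    print(claimed)  # each task is claimed exactly once


asyncio.run(main())
```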