camel-ai 0.2.66__py3-none-any.whl → 0.2.68__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/configs/__init__.py +3 -0
- camel/configs/qianfan_config.py +85 -0
- camel/environments/__init__.py +12 -0
- camel/environments/rlcards_env.py +860 -0
- camel/interpreters/docker/Dockerfile +2 -5
- camel/loaders/firecrawl_reader.py +4 -4
- camel/memories/blocks/vectordb_block.py +8 -1
- camel/memories/context_creators/score_based.py +123 -19
- camel/models/__init__.py +2 -0
- camel/models/aiml_model.py +8 -0
- camel/models/anthropic_model.py +122 -2
- camel/models/aws_bedrock_model.py +8 -0
- camel/models/azure_openai_model.py +14 -5
- camel/models/base_model.py +4 -0
- camel/models/cohere_model.py +9 -2
- camel/models/crynux_model.py +8 -0
- camel/models/deepseek_model.py +8 -0
- camel/models/gemini_model.py +8 -0
- camel/models/groq_model.py +8 -0
- camel/models/internlm_model.py +8 -0
- camel/models/litellm_model.py +5 -0
- camel/models/lmstudio_model.py +14 -1
- camel/models/mistral_model.py +15 -1
- camel/models/model_factory.py +6 -0
- camel/models/modelscope_model.py +8 -0
- camel/models/moonshot_model.py +8 -0
- camel/models/nemotron_model.py +17 -2
- camel/models/netmind_model.py +8 -0
- camel/models/novita_model.py +8 -0
- camel/models/nvidia_model.py +8 -0
- camel/models/ollama_model.py +8 -0
- camel/models/openai_compatible_model.py +23 -5
- camel/models/openai_model.py +21 -4
- camel/models/openrouter_model.py +8 -0
- camel/models/ppio_model.py +8 -0
- camel/models/qianfan_model.py +104 -0
- camel/models/qwen_model.py +8 -0
- camel/models/reka_model.py +18 -3
- camel/models/samba_model.py +17 -3
- camel/models/sglang_model.py +20 -5
- camel/models/siliconflow_model.py +8 -0
- camel/models/stub_model.py +8 -1
- camel/models/togetherai_model.py +8 -0
- camel/models/vllm_model.py +7 -0
- camel/models/volcano_model.py +14 -1
- camel/models/watsonx_model.py +4 -1
- camel/models/yi_model.py +8 -0
- camel/models/zhipuai_model.py +8 -0
- camel/societies/workforce/prompts.py +71 -22
- camel/societies/workforce/role_playing_worker.py +3 -8
- camel/societies/workforce/single_agent_worker.py +37 -9
- camel/societies/workforce/task_channel.py +25 -20
- camel/societies/workforce/utils.py +104 -14
- camel/societies/workforce/worker.py +98 -16
- camel/societies/workforce/workforce.py +1289 -101
- camel/societies/workforce/workforce_logger.py +613 -0
- camel/tasks/task.py +16 -5
- camel/toolkits/__init__.py +2 -0
- camel/toolkits/code_execution.py +1 -1
- camel/toolkits/playwright_mcp_toolkit.py +2 -1
- camel/toolkits/pptx_toolkit.py +4 -4
- camel/types/enums.py +32 -0
- camel/types/unified_model_type.py +5 -0
- {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/METADATA +4 -3
- {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/RECORD +68 -64
- {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/licenses/LICENSE +0 -0
|
@@ -28,7 +28,6 @@ from camel.societies.workforce.prompts import (
|
|
|
28
28
|
from camel.societies.workforce.utils import TaskResult
|
|
29
29
|
from camel.societies.workforce.worker import Worker
|
|
30
30
|
from camel.tasks.task import Task, TaskState, validate_task_content
|
|
31
|
-
from camel.utils import print_text_animated
|
|
32
31
|
|
|
33
32
|
|
|
34
33
|
class RolePlayingWorker(Worker):
|
|
@@ -141,24 +140,20 @@ class RolePlayingWorker(Worker):
|
|
|
141
140
|
)
|
|
142
141
|
break
|
|
143
142
|
|
|
144
|
-
|
|
143
|
+
print(
|
|
145
144
|
f"{Fore.BLUE}AI User:\n\n{user_response.msg.content}"
|
|
146
145
|
f"{Fore.RESET}\n",
|
|
147
|
-
delay=0.005,
|
|
148
146
|
)
|
|
149
147
|
chat_history.append(f"AI User: {user_response.msg.content}")
|
|
150
148
|
|
|
151
|
-
|
|
152
|
-
f"{Fore.GREEN}AI Assistant:{Fore.RESET}", delay=0.005
|
|
153
|
-
)
|
|
149
|
+
print(f"{Fore.GREEN}AI Assistant:{Fore.RESET}")
|
|
154
150
|
|
|
155
151
|
for func_record in assistant_response.info['tool_calls']:
|
|
156
152
|
print(func_record)
|
|
157
153
|
|
|
158
|
-
|
|
154
|
+
print(
|
|
159
155
|
f"\n{Fore.GREEN}{assistant_response.msg.content}"
|
|
160
156
|
f"{Fore.RESET}\n",
|
|
161
|
-
delay=0.005,
|
|
162
157
|
)
|
|
163
158
|
chat_history.append(
|
|
164
159
|
f"AI Assistant: {assistant_response.msg.content}"
|
|
@@ -24,7 +24,6 @@ from camel.societies.workforce.prompts import PROCESS_TASK_PROMPT
|
|
|
24
24
|
from camel.societies.workforce.utils import TaskResult
|
|
25
25
|
from camel.societies.workforce.worker import Worker
|
|
26
26
|
from camel.tasks.task import Task, TaskState, validate_task_content
|
|
27
|
-
from camel.utils import print_text_animated
|
|
28
27
|
|
|
29
28
|
|
|
30
29
|
class SingleAgentWorker(Worker):
|
|
@@ -33,15 +32,22 @@ class SingleAgentWorker(Worker):
|
|
|
33
32
|
Args:
|
|
34
33
|
description (str): Description of the node.
|
|
35
34
|
worker (ChatAgent): Worker of the node. A single agent.
|
|
35
|
+
max_concurrent_tasks (int): Maximum number of tasks this worker can
|
|
36
|
+
process concurrently. (default: :obj:`10`)
|
|
36
37
|
"""
|
|
37
38
|
|
|
38
39
|
def __init__(
|
|
39
40
|
self,
|
|
40
41
|
description: str,
|
|
41
42
|
worker: ChatAgent,
|
|
43
|
+
max_concurrent_tasks: int = 10,
|
|
42
44
|
) -> None:
|
|
43
45
|
node_id = worker.agent_id
|
|
44
|
-
super().__init__(description, node_id=node_id)
|
|
46
|
+
super().__init__(
|
|
47
|
+
description,
|
|
48
|
+
node_id=node_id,
|
|
49
|
+
max_concurrent_tasks=max_concurrent_tasks,
|
|
50
|
+
)
|
|
45
51
|
self.worker = worker
|
|
46
52
|
|
|
47
53
|
def reset(self) -> Any:
|
|
@@ -52,11 +58,12 @@ class SingleAgentWorker(Worker):
|
|
|
52
58
|
async def _process_task(
|
|
53
59
|
self, task: Task, dependencies: List[Task]
|
|
54
60
|
) -> TaskState:
|
|
55
|
-
r"""Processes a task with its dependencies.
|
|
61
|
+
r"""Processes a task with its dependencies using a cloned agent.
|
|
56
62
|
|
|
57
63
|
This method asynchronously processes a given task, considering its
|
|
58
|
-
dependencies, by sending a generated prompt to a worker.
|
|
59
|
-
|
|
64
|
+
dependencies, by sending a generated prompt to a cloned worker agent.
|
|
65
|
+
Using a cloned agent ensures that concurrent tasks don't interfere
|
|
66
|
+
with each other's state.
|
|
60
67
|
|
|
61
68
|
Args:
|
|
62
69
|
task (Task): The task to process, which includes necessary details
|
|
@@ -67,6 +74,10 @@ class SingleAgentWorker(Worker):
|
|
|
67
74
|
TaskState: `TaskState.DONE` if processed successfully, otherwise
|
|
68
75
|
`TaskState.FAILED`.
|
|
69
76
|
"""
|
|
77
|
+
# Clone the agent for this specific task to avoid state conflicts
|
|
78
|
+
# when processing multiple tasks concurrently
|
|
79
|
+
worker_agent = self.worker.clone(with_memory=False)
|
|
80
|
+
|
|
70
81
|
dependency_tasks_info = self._get_dep_tasks_info(dependencies)
|
|
71
82
|
prompt = PROCESS_TASK_PROMPT.format(
|
|
72
83
|
content=task.content,
|
|
@@ -74,7 +85,7 @@ class SingleAgentWorker(Worker):
|
|
|
74
85
|
additional_info=task.additional_info,
|
|
75
86
|
)
|
|
76
87
|
try:
|
|
77
|
-
response = await self.worker.astep(
|
|
88
|
+
response = await worker_agent.astep(
|
|
78
89
|
prompt, response_format=TaskResult
|
|
79
90
|
)
|
|
80
91
|
except Exception as e:
|
|
@@ -84,6 +95,13 @@ class SingleAgentWorker(Worker):
|
|
|
84
95
|
)
|
|
85
96
|
return TaskState.FAILED
|
|
86
97
|
|
|
98
|
+
# Get actual token usage from the cloned agent that processed this task
|
|
99
|
+
try:
|
|
100
|
+
_, total_token_count = worker_agent.memory.get_context()
|
|
101
|
+
except Exception:
|
|
102
|
+
# Fallback if memory context unavailable
|
|
103
|
+
total_token_count = 0
|
|
104
|
+
|
|
87
105
|
# Populate additional_info with worker attempt details
|
|
88
106
|
if task.additional_info is None:
|
|
89
107
|
task.additional_info = {}
|
|
@@ -91,14 +109,20 @@ class SingleAgentWorker(Worker):
|
|
|
91
109
|
# Create worker attempt details with descriptive keys
|
|
92
110
|
worker_attempt_details = {
|
|
93
111
|
"agent_id": getattr(
|
|
112
|
+
worker_agent, "agent_id", worker_agent.role_name
|
|
113
|
+
),
|
|
114
|
+
"original_worker_id": getattr(
|
|
94
115
|
self.worker, "agent_id", self.worker.role_name
|
|
95
116
|
),
|
|
96
117
|
"timestamp": str(datetime.datetime.now()),
|
|
97
118
|
"description": f"Attempt by "
|
|
98
|
-
f"{getattr(self.worker, 'agent_id', self.worker.role_name)} "
|
|
119
|
+
f"{getattr(worker_agent, 'agent_id', worker_agent.role_name)} "
|
|
120
|
+
f"(cloned from "
|
|
121
|
+
f"{getattr(self.worker, 'agent_id', self.worker.role_name)}) "
|
|
99
122
|
f"to process task {task.content}",
|
|
100
123
|
"response_content": response.msg.content,
|
|
101
124
|
"tool_calls": response.info["tool_calls"],
|
|
125
|
+
"total_token_count": total_token_count,
|
|
102
126
|
}
|
|
103
127
|
|
|
104
128
|
# Store the worker attempt in additional_info
|
|
@@ -106,15 +130,19 @@ class SingleAgentWorker(Worker):
|
|
|
106
130
|
task.additional_info["worker_attempts"] = []
|
|
107
131
|
task.additional_info["worker_attempts"].append(worker_attempt_details)
|
|
108
132
|
|
|
133
|
+
# Store the actual token usage for this specific task
|
|
134
|
+
task.additional_info["token_usage"] = {
|
|
135
|
+
"total_tokens": total_token_count
|
|
136
|
+
}
|
|
137
|
+
|
|
109
138
|
print(f"======\n{Fore.GREEN}Reply from {self}:{Fore.RESET}")
|
|
110
139
|
|
|
111
140
|
result_dict = json.loads(response.msg.content)
|
|
112
141
|
task_result = TaskResult(**result_dict)
|
|
113
142
|
|
|
114
143
|
color = Fore.RED if task_result.failed else Fore.GREEN
|
|
115
|
-
|
|
144
|
+
print(
|
|
116
145
|
f"\n{color}{task_result.content}{Fore.RESET}\n======",
|
|
117
|
-
delay=0.005,
|
|
118
146
|
)
|
|
119
147
|
|
|
120
148
|
if task_result.failed:
|
|
@@ -23,6 +23,8 @@ class PacketStatus(Enum):
|
|
|
23
23
|
states:
|
|
24
24
|
|
|
25
25
|
- ``SENT``: The packet has been sent to a worker.
|
|
26
|
+
- ``PROCESSING``: The packet has been claimed by a worker and is being
|
|
27
|
+
processed.
|
|
26
28
|
- ``RETURNED``: The packet has been returned by the worker, meaning that
|
|
27
29
|
the status of the task inside has been updated.
|
|
28
30
|
- ``ARCHIVED``: The packet has been archived, meaning that the content of
|
|
@@ -31,6 +33,7 @@ class PacketStatus(Enum):
|
|
|
31
33
|
"""
|
|
32
34
|
|
|
33
35
|
SENT = "SENT"
|
|
36
|
+
PROCESSING = "PROCESSING"
|
|
34
37
|
RETURNED = "RETURNED"
|
|
35
38
|
ARCHIVED = "ARCHIVED"
|
|
36
39
|
|
|
@@ -79,7 +82,6 @@ class TaskChannel:
|
|
|
79
82
|
r"""An internal class used by Workforce to manage tasks."""
|
|
80
83
|
|
|
81
84
|
def __init__(self) -> None:
|
|
82
|
-
self._task_id_list: List[str] = []
|
|
83
85
|
self._condition = asyncio.Condition()
|
|
84
86
|
self._task_dict: Dict[str, Packet] = {}
|
|
85
87
|
|
|
@@ -89,8 +91,7 @@ class TaskChannel:
|
|
|
89
91
|
"""
|
|
90
92
|
async with self._condition:
|
|
91
93
|
while True:
|
|
92
|
-
for task_id in self._task_id_list:
|
|
93
|
-
packet = self._task_dict[task_id]
|
|
94
|
+
for packet in self._task_dict.values():
|
|
94
95
|
if packet.publisher_id != publisher_id:
|
|
95
96
|
continue
|
|
96
97
|
if packet.status != PacketStatus.RETURNED:
|
|
@@ -99,17 +100,20 @@ class TaskChannel:
|
|
|
99
100
|
await self._condition.wait()
|
|
100
101
|
|
|
101
102
|
async def get_assigned_task_by_assignee(self, assignee_id: str) -> Task:
|
|
102
|
-
r"""
|
|
103
|
-
assignee.
|
|
103
|
+
r"""Atomically get and claim a task from the channel that has been
|
|
104
|
+
assigned to the assignee. This prevents race conditions where multiple
|
|
105
|
+
concurrent calls might retrieve the same task.
|
|
104
106
|
"""
|
|
105
107
|
async with self._condition:
|
|
106
108
|
while True:
|
|
107
|
-
for task_id in self._task_id_list:
|
|
108
|
-
packet = self._task_dict[task_id]
|
|
109
|
+
for packet in self._task_dict.values():
|
|
109
110
|
if (
|
|
110
111
|
packet.status == PacketStatus.SENT
|
|
111
112
|
and packet.assignee_id == assignee_id
|
|
112
113
|
):
|
|
114
|
+
# Atomically claim the task by changing its status
|
|
115
|
+
packet.status = PacketStatus.PROCESSING
|
|
116
|
+
self._condition.notify_all()
|
|
113
117
|
return packet.task
|
|
114
118
|
await self._condition.wait()
|
|
115
119
|
|
|
@@ -119,7 +123,6 @@ class TaskChannel:
|
|
|
119
123
|
r"""Send a task to the channel with specified publisher and assignee,
|
|
120
124
|
along with the dependency of the task."""
|
|
121
125
|
async with self._condition:
|
|
122
|
-
self._task_id_list.append(task.id)
|
|
123
126
|
packet = Packet(task, publisher_id, assignee_id)
|
|
124
127
|
self._task_dict[packet.task.id] = packet
|
|
125
128
|
self._condition.notify_all()
|
|
@@ -130,7 +133,6 @@ class TaskChannel:
|
|
|
130
133
|
r"""Post a dependency to the channel. A dependency is a task that is
|
|
131
134
|
archived, and will be referenced by other tasks."""
|
|
132
135
|
async with self._condition:
|
|
133
|
-
self._task_id_list.append(dependency.id)
|
|
134
136
|
packet = Packet(
|
|
135
137
|
dependency, publisher_id, status=PacketStatus.ARCHIVED
|
|
136
138
|
)
|
|
@@ -141,30 +143,32 @@ class TaskChannel:
|
|
|
141
143
|
r"""Return a task to the sender, indicating that the task has been
|
|
142
144
|
processed by the worker."""
|
|
143
145
|
async with self._condition:
|
|
144
|
-
|
|
145
|
-
|
|
146
|
+
if task_id in self._task_dict:
|
|
147
|
+
packet = self._task_dict[task_id]
|
|
148
|
+
packet.status = PacketStatus.RETURNED
|
|
146
149
|
self._condition.notify_all()
|
|
147
150
|
|
|
148
151
|
async def archive_task(self, task_id: str) -> None:
|
|
149
152
|
r"""Archive a task in channel, making it to become a dependency."""
|
|
150
153
|
async with self._condition:
|
|
151
|
-
|
|
152
|
-
|
|
154
|
+
if task_id in self._task_dict:
|
|
155
|
+
packet = self._task_dict[task_id]
|
|
156
|
+
packet.status = PacketStatus.ARCHIVED
|
|
153
157
|
self._condition.notify_all()
|
|
154
158
|
|
|
155
159
|
async def remove_task(self, task_id: str) -> None:
|
|
156
160
|
r"""Remove a task from the channel."""
|
|
157
161
|
async with self._condition:
|
|
158
|
-
|
|
159
|
-
self._task_dict
|
|
162
|
+
# Check if task ID exists before removing
|
|
163
|
+
if task_id in self._task_dict:
|
|
164
|
+
del self._task_dict[task_id]
|
|
160
165
|
self._condition.notify_all()
|
|
161
166
|
|
|
162
167
|
async def get_dependency_ids(self) -> List[str]:
|
|
163
168
|
r"""Get the IDs of all dependencies in the channel."""
|
|
164
169
|
async with self._condition:
|
|
165
170
|
dependency_ids = []
|
|
166
|
-
for task_id in self._task_id_list:
|
|
167
|
-
packet = self._task_dict[task_id]
|
|
171
|
+
for task_id, packet in self._task_dict.items():
|
|
168
172
|
if packet.status == PacketStatus.ARCHIVED:
|
|
169
173
|
dependency_ids.append(task_id)
|
|
170
174
|
return dependency_ids
|
|
@@ -172,11 +176,12 @@ class TaskChannel:
|
|
|
172
176
|
async def get_task_by_id(self, task_id: str) -> Task:
|
|
173
177
|
r"""Get a task from the channel by its ID."""
|
|
174
178
|
async with self._condition:
|
|
175
|
-
|
|
179
|
+
packet = self._task_dict.get(task_id)
|
|
180
|
+
if packet is None:
|
|
176
181
|
raise ValueError(f"Task {task_id} not found.")
|
|
177
|
-
return self._task_dict[task_id].task
|
|
182
|
+
return packet.task
|
|
178
183
|
|
|
179
184
|
async def get_channel_debug_info(self) -> str:
|
|
180
185
|
r"""Get the debug information of the channel."""
|
|
181
186
|
async with self._condition:
|
|
182
|
-
return str(self._task_dict)
|
|
187
|
+
return str(self._task_dict)
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
14
|
from functools import wraps
|
|
15
|
-
from typing import Callable
|
|
15
|
+
from typing import Callable, List
|
|
16
16
|
|
|
17
17
|
from pydantic import BaseModel, Field
|
|
18
18
|
|
|
@@ -41,32 +41,122 @@ class TaskResult(BaseModel):
|
|
|
41
41
|
)
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
class TaskAssignResult(BaseModel):
|
|
45
|
-
r"""
|
|
44
|
+
class TaskAssignment(BaseModel):
|
|
45
|
+
r"""An individual task assignment within a batch."""
|
|
46
46
|
|
|
47
|
+
task_id: str = Field(description="The ID of the task to be assigned.")
|
|
47
48
|
assignee_id: str = Field(
|
|
48
|
-
description="The ID of the workforce that is assigned to the task."
|
|
49
|
+
description="The ID of the worker/workforce to assign the task to."
|
|
50
|
+
)
|
|
51
|
+
dependencies: List[str] = Field(
|
|
52
|
+
default_factory=list,
|
|
53
|
+
description="List of task IDs that must complete before this task.",
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class TaskAssignResult(BaseModel):
|
|
58
|
+
r"""The result of task assignment for both single and batch assignments."""
|
|
59
|
+
|
|
60
|
+
assignments: List[TaskAssignment] = Field(
|
|
61
|
+
description="List of task assignments."
|
|
49
62
|
)
|
|
50
63
|
|
|
51
64
|
|
|
52
|
-
def check_if_running(running: bool) -> Callable:
|
|
53
|
-
|
|
54
|
-
|
|
65
|
+
def check_if_running(
|
|
66
|
+
running: bool,
|
|
67
|
+
max_retries: int = 3,
|
|
68
|
+
retry_delay: float = 1.0,
|
|
69
|
+
handle_exceptions: bool = False,
|
|
70
|
+
) -> Callable:
|
|
71
|
+
r"""Check if the workforce is (not) running, specified by the boolean
|
|
72
|
+
value. Provides fault tolerance through automatic retries and exception
|
|
73
|
+
handling.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
running (bool): Expected running state (True or False).
|
|
77
|
+
max_retries (int, optional): Maximum number of retry attempts if the
|
|
78
|
+
operation fails. Set to 0 to disable retries. (default: :obj:`3`)
|
|
79
|
+
retry_delay (float, optional): Delay in seconds between retry attempts.
|
|
80
|
+
(default: :obj:`1.0`)
|
|
81
|
+
handle_exceptions (bool, optional): If True, catch and log exceptions
|
|
82
|
+
instead of propagating them. (default: :obj:`False`)
|
|
55
83
|
|
|
56
84
|
Raises:
|
|
57
|
-
RuntimeError: If the workforce is not in the expected status.
|
|
85
|
+
RuntimeError: If the workforce is not in the expected status and
|
|
86
|
+
retries are exhausted or disabled.
|
|
87
|
+
Exception: Any exception raised by the decorated function if
|
|
88
|
+
handle_exceptions is False and retries are exhausted.
|
|
58
89
|
"""
|
|
90
|
+
import logging
|
|
91
|
+
import time
|
|
92
|
+
|
|
93
|
+
logger = logging.getLogger(__name__)
|
|
59
94
|
|
|
60
95
|
def decorator(func):
|
|
61
96
|
@wraps(func)
|
|
62
97
|
def wrapper(self, *args, **kwargs):
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
98
|
+
retries = 0
|
|
99
|
+
last_exception = None
|
|
100
|
+
|
|
101
|
+
while retries <= max_retries:
|
|
102
|
+
try:
|
|
103
|
+
# Check running state
|
|
104
|
+
if self._running != running:
|
|
105
|
+
status = "not running" if running else "running"
|
|
106
|
+
error_msg = (
|
|
107
|
+
f"The workforce is {status}. Cannot perform the "
|
|
108
|
+
f"operation {func.__name__}."
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
# If we have retries left, wait and try again
|
|
112
|
+
if retries < max_retries:
|
|
113
|
+
logger.warning(
|
|
114
|
+
f"{error_msg} Retrying in {retry_delay}s... "
|
|
115
|
+
f"(Attempt {retries+1}/{max_retries})"
|
|
116
|
+
)
|
|
117
|
+
time.sleep(retry_delay)
|
|
118
|
+
retries += 1
|
|
119
|
+
continue
|
|
120
|
+
else:
|
|
121
|
+
raise RuntimeError(error_msg)
|
|
122
|
+
|
|
123
|
+
return func(self, *args, **kwargs)
|
|
124
|
+
|
|
125
|
+
except Exception as e:
|
|
126
|
+
last_exception = e
|
|
127
|
+
|
|
128
|
+
if isinstance(e, RuntimeError) and "workforce is" in str(
|
|
129
|
+
e
|
|
130
|
+
):
|
|
131
|
+
raise
|
|
132
|
+
|
|
133
|
+
if retries < max_retries:
|
|
134
|
+
logger.warning(
|
|
135
|
+
f"Exception in {func.__name__}: {e}. "
|
|
136
|
+
f"Retrying in {retry_delay}s... "
|
|
137
|
+
f"(Attempt {retries+1}/{max_retries})"
|
|
138
|
+
)
|
|
139
|
+
time.sleep(retry_delay)
|
|
140
|
+
retries += 1
|
|
141
|
+
else:
|
|
142
|
+
if handle_exceptions:
|
|
143
|
+
logger.error(
|
|
144
|
+
f"Failed to execute {func.__name__} after "
|
|
145
|
+
f"{max_retries} retries: {e}"
|
|
146
|
+
)
|
|
147
|
+
return None
|
|
148
|
+
else:
|
|
149
|
+
# Re-raise the exception
|
|
150
|
+
raise
|
|
151
|
+
|
|
152
|
+
# This should not be reached, but just in case
|
|
153
|
+
if handle_exceptions:
|
|
154
|
+
logger.error(
|
|
155
|
+
f"Unexpected failure in {func.__name__}: {last_exception}"
|
|
68
156
|
)
|
|
69
|
-
|
|
157
|
+
return None
|
|
158
|
+
else:
|
|
159
|
+
raise last_exception
|
|
70
160
|
|
|
71
161
|
return wrapper
|
|
72
162
|
|
|
@@ -13,9 +13,10 @@
|
|
|
13
13
|
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
14
|
from __future__ import annotations
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import logging
|
|
17
18
|
from abc import ABC, abstractmethod
|
|
18
|
-
from typing import List, Optional
|
|
19
|
+
from typing import List, Optional, Set
|
|
19
20
|
|
|
20
21
|
from colorama import Fore
|
|
21
22
|
|
|
@@ -35,14 +36,19 @@ class Worker(BaseNode, ABC):
|
|
|
35
36
|
description (str): Description of the node.
|
|
36
37
|
node_id (Optional[str]): ID of the node. If not provided, it will
|
|
37
38
|
be generated automatically. (default: :obj:`None`)
|
|
39
|
+
max_concurrent_tasks (int): Maximum number of tasks this worker can
|
|
40
|
+
process concurrently. (default: :obj:`10`)
|
|
38
41
|
"""
|
|
39
42
|
|
|
40
43
|
def __init__(
|
|
41
44
|
self,
|
|
42
45
|
description: str,
|
|
43
46
|
node_id: Optional[str] = None,
|
|
47
|
+
max_concurrent_tasks: int = 10,
|
|
44
48
|
) -> None:
|
|
45
49
|
super().__init__(description, node_id=node_id)
|
|
50
|
+
self.max_concurrent_tasks = max_concurrent_tasks
|
|
51
|
+
self._active_task_ids: Set[str] = set()
|
|
46
52
|
|
|
47
53
|
def __repr__(self):
|
|
48
54
|
return f"Worker node {self.node_id} ({self.description})"
|
|
@@ -60,7 +66,7 @@ class Worker(BaseNode, ABC):
|
|
|
60
66
|
pass
|
|
61
67
|
|
|
62
68
|
async def _get_assigned_task(self) -> Task:
|
|
63
|
-
r"""Get
|
|
69
|
+
r"""Get a task assigned to this node from the channel."""
|
|
64
70
|
return await self._channel.get_assigned_task_by_assignee(self.node_id)
|
|
65
71
|
|
|
66
72
|
@staticmethod
|
|
@@ -77,20 +83,10 @@ class Worker(BaseNode, ABC):
|
|
|
77
83
|
def set_channel(self, channel: TaskChannel):
|
|
78
84
|
self._channel = channel
|
|
79
85
|
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
This method should be run in an event loop, as it will run
|
|
86
|
-
indefinitely.
|
|
87
|
-
"""
|
|
88
|
-
self._running = True
|
|
89
|
-
logger.info(f"{self} started.")
|
|
90
|
-
|
|
91
|
-
while True:
|
|
92
|
-
# Get the earliest task assigned to this node
|
|
93
|
-
task = await self._get_assigned_task()
|
|
86
|
+
async def _process_single_task(self, task: Task) -> None:
|
|
87
|
+
r"""Process a single task and handle its completion/failure."""
|
|
88
|
+
try:
|
|
89
|
+
self._active_task_ids.add(task.id)
|
|
94
90
|
print(
|
|
95
91
|
f"{Fore.YELLOW}{self} get task {task.id}: {task.content}"
|
|
96
92
|
f"{Fore.RESET}"
|
|
@@ -109,6 +105,92 @@ class Worker(BaseNode, ABC):
|
|
|
109
105
|
task.set_state(task_state)
|
|
110
106
|
|
|
111
107
|
await self._channel.return_task(task.id)
|
|
108
|
+
except Exception as e:
|
|
109
|
+
logger.error(f"Error processing task {task.id}: {e}")
|
|
110
|
+
task.set_state(TaskState.FAILED)
|
|
111
|
+
await self._channel.return_task(task.id)
|
|
112
|
+
finally:
|
|
113
|
+
self._active_task_ids.discard(task.id)
|
|
114
|
+
|
|
115
|
+
@check_if_running(False)
|
|
116
|
+
async def _listen_to_channel(self):
|
|
117
|
+
r"""Continuously listen to the channel, process tasks that are
|
|
118
|
+
assigned to this node concurrently up to max_concurrent_tasks limit.
|
|
119
|
+
|
|
120
|
+
This method supports parallel task execution when multiple tasks
|
|
121
|
+
are assigned to the same worker.
|
|
122
|
+
"""
|
|
123
|
+
self._running = True
|
|
124
|
+
logger.info(
|
|
125
|
+
f"{self} started with max {self.max_concurrent_tasks} "
|
|
126
|
+
f"concurrent tasks."
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
# Keep track of running task coroutines
|
|
130
|
+
running_tasks: Set[asyncio.Task] = set()
|
|
131
|
+
|
|
132
|
+
while self._running:
|
|
133
|
+
try:
|
|
134
|
+
# Clean up completed tasks
|
|
135
|
+
completed_tasks = [t for t in running_tasks if t.done()]
|
|
136
|
+
for completed_task in completed_tasks:
|
|
137
|
+
running_tasks.remove(completed_task)
|
|
138
|
+
# Check for exceptions in completed tasks
|
|
139
|
+
try:
|
|
140
|
+
await completed_task
|
|
141
|
+
except Exception as e:
|
|
142
|
+
logger.error(f"Task processing failed: {e}")
|
|
143
|
+
|
|
144
|
+
# Check if we can accept more tasks
|
|
145
|
+
if len(running_tasks) < self.max_concurrent_tasks:
|
|
146
|
+
try:
|
|
147
|
+
# Try to get a new task (with short timeout to avoid
|
|
148
|
+
# blocking)
|
|
149
|
+
task = await asyncio.wait_for(
|
|
150
|
+
self._get_assigned_task(), timeout=1.0
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
# Create and start processing task
|
|
154
|
+
task_coroutine = asyncio.create_task(
|
|
155
|
+
self._process_single_task(task)
|
|
156
|
+
)
|
|
157
|
+
running_tasks.add(task_coroutine)
|
|
158
|
+
|
|
159
|
+
except asyncio.TimeoutError:
|
|
160
|
+
# No tasks available, continue loop
|
|
161
|
+
if not running_tasks:
|
|
162
|
+
# No tasks running and none available, short sleep
|
|
163
|
+
await asyncio.sleep(0.1)
|
|
164
|
+
continue
|
|
165
|
+
else:
|
|
166
|
+
# At max capacity, wait for at least one task to complete
|
|
167
|
+
if running_tasks:
|
|
168
|
+
done, running_tasks = await asyncio.wait(
|
|
169
|
+
running_tasks, return_when=asyncio.FIRST_COMPLETED
|
|
170
|
+
)
|
|
171
|
+
# Process completed tasks
|
|
172
|
+
for completed_task in done:
|
|
173
|
+
try:
|
|
174
|
+
await completed_task
|
|
175
|
+
except Exception as e:
|
|
176
|
+
logger.error(f"Task processing failed: {e}")
|
|
177
|
+
|
|
178
|
+
except Exception as e:
|
|
179
|
+
logger.error(
|
|
180
|
+
f"Error in worker {self.node_id} listen loop: {e}"
|
|
181
|
+
)
|
|
182
|
+
await asyncio.sleep(0.1)
|
|
183
|
+
continue
|
|
184
|
+
|
|
185
|
+
# Wait for all remaining tasks to complete when stopping
|
|
186
|
+
if running_tasks:
|
|
187
|
+
logger.info(
|
|
188
|
+
f"{self} stopping, waiting for {len(running_tasks)} "
|
|
189
|
+
f"tasks to complete..."
|
|
190
|
+
)
|
|
191
|
+
await asyncio.gather(*running_tasks, return_exceptions=True)
|
|
192
|
+
|
|
193
|
+
logger.info(f"{self} stopped.")
|
|
112
194
|
|
|
113
195
|
@check_if_running(False)
|
|
114
196
|
async def start(self):
|