camel-ai 0.2.66__py3-none-any.whl → 0.2.68__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (68)
  1. camel/__init__.py +1 -1
  2. camel/configs/__init__.py +3 -0
  3. camel/configs/qianfan_config.py +85 -0
  4. camel/environments/__init__.py +12 -0
  5. camel/environments/rlcards_env.py +860 -0
  6. camel/interpreters/docker/Dockerfile +2 -5
  7. camel/loaders/firecrawl_reader.py +4 -4
  8. camel/memories/blocks/vectordb_block.py +8 -1
  9. camel/memories/context_creators/score_based.py +123 -19
  10. camel/models/__init__.py +2 -0
  11. camel/models/aiml_model.py +8 -0
  12. camel/models/anthropic_model.py +122 -2
  13. camel/models/aws_bedrock_model.py +8 -0
  14. camel/models/azure_openai_model.py +14 -5
  15. camel/models/base_model.py +4 -0
  16. camel/models/cohere_model.py +9 -2
  17. camel/models/crynux_model.py +8 -0
  18. camel/models/deepseek_model.py +8 -0
  19. camel/models/gemini_model.py +8 -0
  20. camel/models/groq_model.py +8 -0
  21. camel/models/internlm_model.py +8 -0
  22. camel/models/litellm_model.py +5 -0
  23. camel/models/lmstudio_model.py +14 -1
  24. camel/models/mistral_model.py +15 -1
  25. camel/models/model_factory.py +6 -0
  26. camel/models/modelscope_model.py +8 -0
  27. camel/models/moonshot_model.py +8 -0
  28. camel/models/nemotron_model.py +17 -2
  29. camel/models/netmind_model.py +8 -0
  30. camel/models/novita_model.py +8 -0
  31. camel/models/nvidia_model.py +8 -0
  32. camel/models/ollama_model.py +8 -0
  33. camel/models/openai_compatible_model.py +23 -5
  34. camel/models/openai_model.py +21 -4
  35. camel/models/openrouter_model.py +8 -0
  36. camel/models/ppio_model.py +8 -0
  37. camel/models/qianfan_model.py +104 -0
  38. camel/models/qwen_model.py +8 -0
  39. camel/models/reka_model.py +18 -3
  40. camel/models/samba_model.py +17 -3
  41. camel/models/sglang_model.py +20 -5
  42. camel/models/siliconflow_model.py +8 -0
  43. camel/models/stub_model.py +8 -1
  44. camel/models/togetherai_model.py +8 -0
  45. camel/models/vllm_model.py +7 -0
  46. camel/models/volcano_model.py +14 -1
  47. camel/models/watsonx_model.py +4 -1
  48. camel/models/yi_model.py +8 -0
  49. camel/models/zhipuai_model.py +8 -0
  50. camel/societies/workforce/prompts.py +71 -22
  51. camel/societies/workforce/role_playing_worker.py +3 -8
  52. camel/societies/workforce/single_agent_worker.py +37 -9
  53. camel/societies/workforce/task_channel.py +25 -20
  54. camel/societies/workforce/utils.py +104 -14
  55. camel/societies/workforce/worker.py +98 -16
  56. camel/societies/workforce/workforce.py +1289 -101
  57. camel/societies/workforce/workforce_logger.py +613 -0
  58. camel/tasks/task.py +16 -5
  59. camel/toolkits/__init__.py +2 -0
  60. camel/toolkits/code_execution.py +1 -1
  61. camel/toolkits/playwright_mcp_toolkit.py +2 -1
  62. camel/toolkits/pptx_toolkit.py +4 -4
  63. camel/types/enums.py +32 -0
  64. camel/types/unified_model_type.py +5 -0
  65. {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/METADATA +4 -3
  66. {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/RECORD +68 -64
  67. {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/WHEEL +0 -0
  68. {camel_ai-0.2.66.dist-info → camel_ai-0.2.68.dist-info}/licenses/LICENSE +0 -0
camel/societies/workforce/role_playing_worker.py

@@ -28,7 +28,6 @@ from camel.societies.workforce.prompts import (
 from camel.societies.workforce.utils import TaskResult
 from camel.societies.workforce.worker import Worker
 from camel.tasks.task import Task, TaskState, validate_task_content
-from camel.utils import print_text_animated


 class RolePlayingWorker(Worker):
@@ -141,24 +140,20 @@ class RolePlayingWorker(Worker):
                 )
                 break

-            print_text_animated(
+            print(
                 f"{Fore.BLUE}AI User:\n\n{user_response.msg.content}"
                 f"{Fore.RESET}\n",
-                delay=0.005,
             )
             chat_history.append(f"AI User: {user_response.msg.content}")

-            print_text_animated(
-                f"{Fore.GREEN}AI Assistant:{Fore.RESET}", delay=0.005
-            )
+            print(f"{Fore.GREEN}AI Assistant:{Fore.RESET}")

             for func_record in assistant_response.info['tool_calls']:
                 print(func_record)

-            print_text_animated(
+            print(
                 f"\n{Fore.GREEN}{assistant_response.msg.content}"
                 f"{Fore.RESET}\n",
-                delay=0.005,
             )
             chat_history.append(
                 f"AI Assistant: {assistant_response.msg.content}"
camel/societies/workforce/single_agent_worker.py

@@ -24,7 +24,6 @@ from camel.societies.workforce.prompts import PROCESS_TASK_PROMPT
 from camel.societies.workforce.utils import TaskResult
 from camel.societies.workforce.worker import Worker
 from camel.tasks.task import Task, TaskState, validate_task_content
-from camel.utils import print_text_animated


 class SingleAgentWorker(Worker):
@@ -33,15 +32,22 @@ class SingleAgentWorker(Worker):
     Args:
         description (str): Description of the node.
         worker (ChatAgent): Worker of the node. A single agent.
+        max_concurrent_tasks (int): Maximum number of tasks this worker can
+            process concurrently. (default: :obj:`10`)
     """

     def __init__(
         self,
         description: str,
         worker: ChatAgent,
+        max_concurrent_tasks: int = 10,
     ) -> None:
         node_id = worker.agent_id
-        super().__init__(description, node_id=node_id)
+        super().__init__(
+            description,
+            node_id=node_id,
+            max_concurrent_tasks=max_concurrent_tasks,
+        )
         self.worker = worker

     def reset(self) -> Any:
@@ -52,11 +58,12 @@ class SingleAgentWorker(Worker):
     async def _process_task(
         self, task: Task, dependencies: List[Task]
     ) -> TaskState:
-        r"""Processes a task with its dependencies.
+        r"""Processes a task with its dependencies using a cloned agent.

         This method asynchronously processes a given task, considering its
-        dependencies, by sending a generated prompt to a worker. It updates
-        the task's result based on the agent's response.
+        dependencies, by sending a generated prompt to a cloned worker agent.
+        Using a cloned agent ensures that concurrent tasks don't interfere
+        with each other's state.

         Args:
             task (Task): The task to process, which includes necessary details
@@ -67,6 +74,10 @@ class SingleAgentWorker(Worker):
             TaskState: `TaskState.DONE` if processed successfully, otherwise
                 `TaskState.FAILED`.
         """
+        # Clone the agent for this specific task to avoid state conflicts
+        # when processing multiple tasks concurrently
+        worker_agent = self.worker.clone(with_memory=False)
+
         dependency_tasks_info = self._get_dep_tasks_info(dependencies)
         prompt = PROCESS_TASK_PROMPT.format(
             content=task.content,
@@ -74,7 +85,7 @@ class SingleAgentWorker(Worker):
             additional_info=task.additional_info,
         )
         try:
-            response = await self.worker.astep(
+            response = await worker_agent.astep(
                 prompt, response_format=TaskResult
             )
         except Exception as e:
@@ -84,6 +95,13 @@ class SingleAgentWorker(Worker):
             )
             return TaskState.FAILED

+        # Get actual token usage from the cloned agent that processed this task
+        try:
+            _, total_token_count = worker_agent.memory.get_context()
+        except Exception:
+            # Fallback if memory context unavailable
+            total_token_count = 0
+
         # Populate additional_info with worker attempt details
         if task.additional_info is None:
             task.additional_info = {}
@@ -91,14 +109,20 @@ class SingleAgentWorker(Worker):
         # Create worker attempt details with descriptive keys
         worker_attempt_details = {
             "agent_id": getattr(
+                worker_agent, "agent_id", worker_agent.role_name
+            ),
+            "original_worker_id": getattr(
                 self.worker, "agent_id", self.worker.role_name
             ),
             "timestamp": str(datetime.datetime.now()),
             "description": f"Attempt by "
-            f"{getattr(self.worker, 'agent_id', self.worker.role_name)} "
+            f"{getattr(worker_agent, 'agent_id', worker_agent.role_name)} "
+            f"(cloned from "
+            f"{getattr(self.worker, 'agent_id', self.worker.role_name)}) "
            f"to process task {task.content}",
             "response_content": response.msg.content,
             "tool_calls": response.info["tool_calls"],
+            "total_token_count": total_token_count,
         }

         # Store the worker attempt in additional_info
@@ -106,15 +130,19 @@ class SingleAgentWorker(Worker):
             task.additional_info["worker_attempts"] = []
         task.additional_info["worker_attempts"].append(worker_attempt_details)

+        # Store the actual token usage for this specific task
+        task.additional_info["token_usage"] = {
+            "total_tokens": total_token_count
+        }
+
         print(f"======\n{Fore.GREEN}Reply from {self}:{Fore.RESET}")

         result_dict = json.loads(response.msg.content)
         task_result = TaskResult(**result_dict)

         color = Fore.RED if task_result.failed else Fore.GREEN
-        print_text_animated(
+        print(
             f"\n{color}{task_result.content}{Fore.RESET}\n======",
-            delay=0.005,
         )

         if task_result.failed:
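The key change above is that SingleAgentWorker now clones its agent once per task instead of reusing the shared instance, so concurrent tasks cannot corrupt each other's conversation memory. A minimal, self-contained sketch of that clone-per-task pattern is shown below; EchoAgent is a hypothetical toy class standing in for the real ChatAgent, not the camel API.

import asyncio
from dataclasses import dataclass, field
from typing import List


@dataclass
class EchoAgent:
    """Toy stand-in for a stateful chat agent (hypothetical, not camel's ChatAgent)."""

    history: List[str] = field(default_factory=list)

    def clone(self, with_memory: bool = False) -> "EchoAgent":
        # Fresh copy; optionally carry over the accumulated history.
        return EchoAgent(history=list(self.history) if with_memory else [])

    async def astep(self, prompt: str) -> str:
        self.history.append(prompt)   # per-task state stays inside the clone
        await asyncio.sleep(0.01)     # simulate model latency
        return f"done: {prompt} (history={len(self.history)})"


async def process_task(shared_agent: EchoAgent, task: str) -> str:
    # Clone per task so concurrent tasks never touch the same history.
    worker_agent = shared_agent.clone(with_memory=False)
    return await worker_agent.astep(task)


async def main() -> None:
    agent = EchoAgent()
    results = await asyncio.gather(
        *(process_task(agent, f"task-{i}") for i in range(3))
    )
    print(results)        # each result reports history=1: no cross-task interference
    print(agent.history)  # the shared agent itself is untouched: []


asyncio.run(main())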
camel/societies/workforce/task_channel.py

@@ -23,6 +23,8 @@ class PacketStatus(Enum):
     states:

     - ``SENT``: The packet has been sent to a worker.
+    - ``PROCESSING``: The packet has been claimed by a worker and is being
+        processed.
     - ``RETURNED``: The packet has been returned by the worker, meaning that
         the status of the task inside has been updated.
     - ``ARCHIVED``: The packet has been archived, meaning that the content of
@@ -31,6 +33,7 @@ class PacketStatus(Enum):
     """

     SENT = "SENT"
+    PROCESSING = "PROCESSING"
     RETURNED = "RETURNED"
     ARCHIVED = "ARCHIVED"

@@ -79,7 +82,6 @@ class TaskChannel:
     r"""An internal class used by Workforce to manage tasks."""

     def __init__(self) -> None:
-        self._task_id_list: List[str] = []
         self._condition = asyncio.Condition()
         self._task_dict: Dict[str, Packet] = {}

@@ -89,8 +91,7 @@
         """
         async with self._condition:
             while True:
-                for task_id in self._task_id_list:
-                    packet = self._task_dict[task_id]
+                for packet in self._task_dict.values():
                     if packet.publisher_id != publisher_id:
                         continue
                     if packet.status != PacketStatus.RETURNED:
@@ -99,17 +100,20 @@
                 await self._condition.wait()

     async def get_assigned_task_by_assignee(self, assignee_id: str) -> Task:
-        r"""Get a task from the channel that has been assigned to the
-        assignee.
+        r"""Atomically get and claim a task from the channel that has been
+        assigned to the assignee. This prevents race conditions where multiple
+        concurrent calls might retrieve the same task.
         """
         async with self._condition:
             while True:
-                for task_id in self._task_id_list:
-                    packet = self._task_dict[task_id]
+                for packet in self._task_dict.values():
                     if (
                         packet.status == PacketStatus.SENT
                         and packet.assignee_id == assignee_id
                     ):
+                        # Atomically claim the task by changing its status
+                        packet.status = PacketStatus.PROCESSING
+                        self._condition.notify_all()
                         return packet.task
                 await self._condition.wait()

@@ -119,7 +123,6 @@
         r"""Send a task to the channel with specified publisher and assignee,
         along with the dependency of the task."""
         async with self._condition:
-            self._task_id_list.append(task.id)
             packet = Packet(task, publisher_id, assignee_id)
             self._task_dict[packet.task.id] = packet
             self._condition.notify_all()
@@ -130,7 +133,6 @@
         r"""Post a dependency to the channel. A dependency is a task that is
         archived, and will be referenced by other tasks."""
         async with self._condition:
-            self._task_id_list.append(dependency.id)
             packet = Packet(
                 dependency, publisher_id, status=PacketStatus.ARCHIVED
             )
@@ -141,30 +143,32 @@
         r"""Return a task to the sender, indicating that the task has been
         processed by the worker."""
         async with self._condition:
-            packet = self._task_dict[task_id]
-            packet.status = PacketStatus.RETURNED
+            if task_id in self._task_dict:
+                packet = self._task_dict[task_id]
+                packet.status = PacketStatus.RETURNED
             self._condition.notify_all()

     async def archive_task(self, task_id: str) -> None:
         r"""Archive a task in channel, making it to become a dependency."""
         async with self._condition:
-            packet = self._task_dict[task_id]
-            packet.status = PacketStatus.ARCHIVED
+            if task_id in self._task_dict:
+                packet = self._task_dict[task_id]
+                packet.status = PacketStatus.ARCHIVED
             self._condition.notify_all()

     async def remove_task(self, task_id: str) -> None:
         r"""Remove a task from the channel."""
         async with self._condition:
-            self._task_id_list.remove(task_id)
-            self._task_dict.pop(task_id)
+            # Check if task ID exists before removing
+            if task_id in self._task_dict:
+                del self._task_dict[task_id]
             self._condition.notify_all()

     async def get_dependency_ids(self) -> List[str]:
         r"""Get the IDs of all dependencies in the channel."""
         async with self._condition:
             dependency_ids = []
-            for task_id in self._task_id_list:
-                packet = self._task_dict[task_id]
+            for task_id, packet in self._task_dict.items():
                 if packet.status == PacketStatus.ARCHIVED:
                     dependency_ids.append(task_id)
             return dependency_ids
@@ -172,11 +176,12 @@
     async def get_task_by_id(self, task_id: str) -> Task:
         r"""Get a task from the channel by its ID."""
         async with self._condition:
-            if task_id not in self._task_id_list:
+            packet = self._task_dict.get(task_id)
+            if packet is None:
                 raise ValueError(f"Task {task_id} not found.")
-            return self._task_dict[task_id].task
+            return packet.task

     async def get_channel_debug_info(self) -> str:
         r"""Get the debug information of the channel."""
         async with self._condition:
-            return str(self._task_dict) + '\n' + str(self._task_id_list)
+            return str(self._task_dict)
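The new PROCESSING status is what makes task pickup safe under concurrency: the SENT-to-PROCESSING flip in get_assigned_task_by_assignee happens while the asyncio.Condition lock is held, so two workers polling at the same time cannot both receive the same packet. A stripped-down, self-contained sketch of that claim pattern (simplified Packet and status names, not the full TaskChannel API):

import asyncio
from dataclasses import dataclass


@dataclass
class Packet:
    task_id: str
    assignee_id: str
    status: str = "SENT"  # SENT -> PROCESSING -> RETURNED/ARCHIVED


class MiniChannel:
    def __init__(self) -> None:
        self._condition = asyncio.Condition()
        self._task_dict: dict[str, Packet] = {}

    async def post(self, packet: Packet) -> None:
        async with self._condition:
            self._task_dict[packet.task_id] = packet
            self._condition.notify_all()

    async def claim(self, assignee_id: str) -> Packet:
        # The status check and the flip to PROCESSING happen under the same
        # lock, so concurrent claimers can never return the same packet.
        async with self._condition:
            while True:
                for packet in self._task_dict.values():
                    if packet.status == "SENT" and packet.assignee_id == assignee_id:
                        packet.status = "PROCESSING"
                        self._condition.notify_all()
                        return packet
                await self._condition.wait()


async def main() -> None:
    channel = MiniChannel()
    await channel.post(Packet("t1", "worker-a"))
    await channel.post(Packet("t2", "worker-a"))
    # Two concurrent claimers for the same assignee get distinct tasks.
    p1, p2 = await asyncio.gather(
        channel.claim("worker-a"), channel.claim("worker-a")
    )
    print(sorted([p1.task_id, p2.task_id]))  # ['t1', 't2']


asyncio.run(main())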
camel/societies/workforce/utils.py

@@ -12,7 +12,7 @@
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 from functools import wraps
-from typing import Callable
+from typing import Callable, List

 from pydantic import BaseModel, Field

@@ -41,32 +41,122 @@ class TaskResult(BaseModel):
     )


-class TaskAssignResult(BaseModel):
-    r"""The result of task assignment."""
+class TaskAssignment(BaseModel):
+    r"""An individual task assignment within a batch."""

+    task_id: str = Field(description="The ID of the task to be assigned.")
     assignee_id: str = Field(
-        description="The ID of the workforce that is assigned to the task."
+        description="The ID of the worker/workforce to assign the task to."
+    )
+    dependencies: List[str] = Field(
+        default_factory=list,
+        description="List of task IDs that must complete before this task.",
+    )
+
+
+class TaskAssignResult(BaseModel):
+    r"""The result of task assignment for both single and batch assignments."""
+
+    assignments: List[TaskAssignment] = Field(
+        description="List of task assignments."
     )


-def check_if_running(running: bool) -> Callable:
-    r"""Check if the workforce is (not) running, specified the boolean value.
-    If the workforce is not in the expected status, raise an exception.
+def check_if_running(
+    running: bool,
+    max_retries: int = 3,
+    retry_delay: float = 1.0,
+    handle_exceptions: bool = False,
+) -> Callable:
+    r"""Check if the workforce is (not) running, specified by the boolean
+    value. Provides fault tolerance through automatic retries and exception
+    handling.
+
+    Args:
+        running (bool): Expected running state (True or False).
+        max_retries (int, optional): Maximum number of retry attempts if the
+            operation fails. Set to 0 to disable retries. (default: :obj:`3`)
+        retry_delay (float, optional): Delay in seconds between retry attempts.
+            (default: :obj:`1.0`)
+        handle_exceptions (bool, optional): If True, catch and log exceptions
+            instead of propagating them. (default: :obj:`False`)

     Raises:
-        RuntimeError: If the workforce is not in the expected status.
+        RuntimeError: If the workforce is not in the expected status and
+            retries are exhausted or disabled.
+        Exception: Any exception raised by the decorated function if
+            handle_exceptions is False and retries are exhausted.
     """
+    import logging
+    import time
+
+    logger = logging.getLogger(__name__)

     def decorator(func):
         @wraps(func)
         def wrapper(self, *args, **kwargs):
-            if self._running != running:
-                status = "not running" if running else "running"
-                raise RuntimeError(
-                    f"The workforce is {status}. Cannot perform the "
-                    f"operation {func.__name__}."
+            retries = 0
+            last_exception = None
+
+            while retries <= max_retries:
+                try:
+                    # Check running state
+                    if self._running != running:
+                        status = "not running" if running else "running"
+                        error_msg = (
+                            f"The workforce is {status}. Cannot perform the "
+                            f"operation {func.__name__}."
+                        )
+
+                        # If we have retries left, wait and try again
+                        if retries < max_retries:
+                            logger.warning(
+                                f"{error_msg} Retrying in {retry_delay}s... "
+                                f"(Attempt {retries+1}/{max_retries})"
+                            )
+                            time.sleep(retry_delay)
+                            retries += 1
+                            continue
+                        else:
+                            raise RuntimeError(error_msg)
+
+                    return func(self, *args, **kwargs)
+
+                except Exception as e:
+                    last_exception = e
+
+                    if isinstance(e, RuntimeError) and "workforce is" in str(
+                        e
+                    ):
+                        raise
+
+                    if retries < max_retries:
+                        logger.warning(
+                            f"Exception in {func.__name__}: {e}. "
+                            f"Retrying in {retry_delay}s... "
+                            f"(Attempt {retries+1}/{max_retries})"
+                        )
+                        time.sleep(retry_delay)
+                        retries += 1
+                    else:
+                        if handle_exceptions:
+                            logger.error(
+                                f"Failed to execute {func.__name__} after "
+                                f"{max_retries} retries: {e}"
+                            )
+                            return None
+                        else:
+                            # Re-raise the exception
+                            raise
+
+            # This should not be reached, but just in case
+            if handle_exceptions:
+                logger.error(
+                    f"Unexpected failure in {func.__name__}: {last_exception}"
                )
-            return func(self, *args, **kwargs)
+                return None
+            else:
+                raise last_exception

         return wrapper

camel/societies/workforce/worker.py

@@ -13,9 +13,10 @@
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 from __future__ import annotations

+import asyncio
 import logging
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Set

 from colorama import Fore

@@ -35,14 +36,19 @@ class Worker(BaseNode, ABC):
         description (str): Description of the node.
         node_id (Optional[str]): ID of the node. If not provided, it will
             be generated automatically. (default: :obj:`None`)
+        max_concurrent_tasks (int): Maximum number of tasks this worker can
+            process concurrently. (default: :obj:`10`)
     """

     def __init__(
         self,
         description: str,
         node_id: Optional[str] = None,
+        max_concurrent_tasks: int = 10,
     ) -> None:
         super().__init__(description, node_id=node_id)
+        self.max_concurrent_tasks = max_concurrent_tasks
+        self._active_task_ids: Set[str] = set()

     def __repr__(self):
         return f"Worker node {self.node_id} ({self.description})"
@@ -60,7 +66,7 @@ class Worker(BaseNode, ABC):
         pass

     async def _get_assigned_task(self) -> Task:
-        r"""Get the task assigned to this node from the channel."""
+        r"""Get a task assigned to this node from the channel."""
         return await self._channel.get_assigned_task_by_assignee(self.node_id)

     @staticmethod
@@ -77,20 +83,10 @@ class Worker(BaseNode, ABC):
     def set_channel(self, channel: TaskChannel):
         self._channel = channel

-    @check_if_running(False)
-    async def _listen_to_channel(self):
-        """Continuously listen to the channel, process the task that are
-        assigned to this node, and update the result and status of the task.
-
-        This method should be run in an event loop, as it will run
-        indefinitely.
-        """
-        self._running = True
-        logger.info(f"{self} started.")
-
-        while True:
-            # Get the earliest task assigned to this node
-            task = await self._get_assigned_task()
+    async def _process_single_task(self, task: Task) -> None:
+        r"""Process a single task and handle its completion/failure."""
+        try:
+            self._active_task_ids.add(task.id)
             print(
                 f"{Fore.YELLOW}{self} get task {task.id}: {task.content}"
                 f"{Fore.RESET}"
@@ -109,6 +105,92 @@ class Worker(BaseNode, ABC):
             task.set_state(task_state)

             await self._channel.return_task(task.id)
+        except Exception as e:
+            logger.error(f"Error processing task {task.id}: {e}")
+            task.set_state(TaskState.FAILED)
+            await self._channel.return_task(task.id)
+        finally:
+            self._active_task_ids.discard(task.id)
+
+    @check_if_running(False)
+    async def _listen_to_channel(self):
+        r"""Continuously listen to the channel, process tasks that are
+        assigned to this node concurrently up to max_concurrent_tasks limit.
+
+        This method supports parallel task execution when multiple tasks
+        are assigned to the same worker.
+        """
+        self._running = True
+        logger.info(
+            f"{self} started with max {self.max_concurrent_tasks} "
+            f"concurrent tasks."
+        )
+
+        # Keep track of running task coroutines
+        running_tasks: Set[asyncio.Task] = set()
+
+        while self._running:
+            try:
+                # Clean up completed tasks
+                completed_tasks = [t for t in running_tasks if t.done()]
+                for completed_task in completed_tasks:
+                    running_tasks.remove(completed_task)
+                    # Check for exceptions in completed tasks
+                    try:
+                        await completed_task
+                    except Exception as e:
+                        logger.error(f"Task processing failed: {e}")
+
+                # Check if we can accept more tasks
+                if len(running_tasks) < self.max_concurrent_tasks:
+                    try:
+                        # Try to get a new task (with short timeout to avoid
+                        # blocking)
+                        task = await asyncio.wait_for(
+                            self._get_assigned_task(), timeout=1.0
+                        )
+
+                        # Create and start processing task
+                        task_coroutine = asyncio.create_task(
+                            self._process_single_task(task)
+                        )
+                        running_tasks.add(task_coroutine)
+
+                    except asyncio.TimeoutError:
+                        # No tasks available, continue loop
+                        if not running_tasks:
+                            # No tasks running and none available, short sleep
+                            await asyncio.sleep(0.1)
+                        continue
+                else:
+                    # At max capacity, wait for at least one task to complete
+                    if running_tasks:
+                        done, running_tasks = await asyncio.wait(
+                            running_tasks, return_when=asyncio.FIRST_COMPLETED
+                        )
+                        # Process completed tasks
+                        for completed_task in done:
+                            try:
+                                await completed_task
+                            except Exception as e:
+                                logger.error(f"Task processing failed: {e}")
+
+            except Exception as e:
+                logger.error(
+                    f"Error in worker {self.node_id} listen loop: {e}"
+                )
+                await asyncio.sleep(0.1)
+                continue
+
+        # Wait for all remaining tasks to complete when stopping
+        if running_tasks:
+            logger.info(
+                f"{self} stopping, waiting for {len(running_tasks)} "
+                f"tasks to complete..."
+            )
+            await asyncio.gather(*running_tasks, return_exceptions=True)
+
+        logger.info(f"{self} stopped.")

     @check_if_running(False)
     async def start(self):
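The rewritten listen loop replaces the old one-task-at-a-time while loop with a bounded-concurrency pattern: poll for work with a short timeout, fan tasks out with asyncio.create_task, and block on asyncio.wait(..., return_when=FIRST_COMPLETED) once max_concurrent_tasks are in flight. A self-contained sketch of the same pattern follows; an asyncio.Queue stands in for the task channel and all names are illustrative, not the camel API.

import asyncio
from typing import Set

MAX_CONCURRENT_TASKS = 3


async def process(item: int) -> None:
    # Stand-in for Worker._process_single_task.
    await asyncio.sleep(0.2)
    print(f"finished task {item}")


async def listen(queue: asyncio.Queue) -> None:
    running: Set[asyncio.Task] = set()
    stopping = False

    while not stopping:
        # Drop references to finished coroutines, surfacing their exceptions.
        for done in [t for t in running if t.done()]:
            running.remove(done)
            done.result()

        if len(running) < MAX_CONCURRENT_TASKS:
            try:
                # Poll with a short timeout so the loop never blocks forever.
                item = await asyncio.wait_for(queue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                if not running:
                    stopping = True  # nothing queued and nothing in flight
                continue
            running.add(asyncio.create_task(process(item)))
        else:
            # At capacity: wait until at least one in-flight task completes.
            _, running = await asyncio.wait(
                running, return_when=asyncio.FIRST_COMPLETED
            )

    # Drain whatever is still in flight before shutting down.
    if running:
        await asyncio.gather(*running)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for i in range(8):
        queue.put_nowait(i)
    await listen(queue)


asyncio.run(main())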