camel-ai 0.2.70__py3-none-any.whl → 0.2.71a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of camel-ai might be problematic.

camel/societies/workforce/workforce.py CHANGED
@@ -43,7 +43,12 @@ from camel.societies.workforce.utils import (
      check_if_running,
  )
  from camel.societies.workforce.worker import Worker
- from camel.tasks.task import Task, TaskState, validate_task_content
+ from camel.tasks.task import (
+     Task,
+     TaskState,
+     is_task_result_insufficient,
+     validate_task_content,
+ )
  from camel.toolkits import (
      CodeExecutionToolkit,
      SearchToolkit,
@@ -57,6 +62,12 @@ from .workforce_logger import WorkforceLogger

  logger = get_logger(__name__)

+ # Constants for configuration values
+ MAX_TASK_RETRIES = 3
+ MAX_PENDING_TASKS_LIMIT = 20
+ TASK_TIMEOUT_SECONDS = 180.0
+ DEFAULT_WORKER_POOL_SIZE = 10
+

  class WorkforceState(Enum):
      r"""Workforce execution state for human intervention support."""
@@ -216,6 +227,7 @@ class Workforce(BaseNode):
          self._completed_tasks: List[Task] = []
          self._loop: Optional[asyncio.AbstractEventLoop] = None
          self._main_task_future: Optional[asyncio.Future] = None
+         self._cleanup_task: Optional[asyncio.Task] = None
          # Snapshot throttle support
          self._last_snapshot_time: float = 0.0
          # Minimum seconds between automatic snapshots
@@ -383,6 +395,40 @@ class Workforce(BaseNode):
              "better context continuity during task handoffs."
          )

+     # ------------------------------------------------------------------
+     # Helper for propagating pause control to externally supplied agents
+     # ------------------------------------------------------------------
+
+     def _attach_pause_event_to_agent(self, agent: ChatAgent) -> None:
+         r"""Ensure the given ChatAgent shares this workforce's pause_event.
+
+         If the agent already has a different pause_event we overwrite it and
+         emit a debug log (it is unlikely an agent needs multiple independent
+         pause controls once managed by this workforce)."""
+         try:
+             existing_pause_event = getattr(agent, "pause_event", None)
+             if existing_pause_event is not self._pause_event:
+                 if existing_pause_event is not None:
+                     logger.debug(
+                         f"Overriding pause_event for agent {agent.agent_id} "
+                         f"(had different pause_event: "
+                         f"{id(existing_pause_event)} "
+                         f"-> {id(self._pause_event)})"
+                     )
+                 agent.pause_event = self._pause_event
+         except AttributeError:
+             # Should not happen, but guard against unexpected objects
+             logger.warning(
+                 f"Cannot attach pause_event to object {type(agent)} - "
+                 f"missing pause_event attribute"
+             )
+
+     def _ensure_pause_event_in_kwargs(self, kwargs: Optional[Dict]) -> Dict:
+         r"""Insert pause_event into kwargs dict for ChatAgent construction."""
+         new_kwargs = dict(kwargs) if kwargs else {}
+         new_kwargs.setdefault("pause_event", self._pause_event)
+         return new_kwargs
+
      def __repr__(self):
          return (
              f"Workforce {self.node_id} ({self.description}) - "
@@ -517,6 +563,15 @@
          except Exception as e:
              logger.warning(f"Error synchronizing shared memory: {e}")

+     def _cleanup_task_tracking(self, task_id: str) -> None:
+         r"""Clean up tracking data for a task to prevent memory leaks.
+
+         Args:
+             task_id (str): The ID of the task to clean up.
+         """
+         if task_id in self._task_start_times:
+             del self._task_start_times[task_id]
+
      def _decompose_task(self, task: Task) -> List[Task]:
          r"""Decompose the task into subtasks. This method will also set the
          relationship between the task and its subtasks.
@@ -1104,7 +1159,7 @@
          self,
          description: str,
          worker: ChatAgent,
-         pool_max_size: int = 10,
+         pool_max_size: int = DEFAULT_WORKER_POOL_SIZE,
      ) -> Workforce:
          r"""Add a worker node to the workforce that uses a single agent.

@@ -1117,6 +1172,9 @@
          Returns:
              Workforce: The workforce node itself.
          """
+         # Ensure the worker agent shares this workforce's pause control
+         self._attach_pause_event_to_agent(worker)
+
          worker_node = SingleAgentWorker(
              description=description,
              worker=worker,
@@ -1163,6 +1221,18 @@
          Returns:
              Workforce: The workforce node itself.
          """
+         # Ensure provided kwargs carry pause_event so that internally created
+         # ChatAgents (assistant/user/summarizer) inherit it.
+         assistant_agent_kwargs = self._ensure_pause_event_in_kwargs(
+             assistant_agent_kwargs
+         )
+         user_agent_kwargs = self._ensure_pause_event_in_kwargs(
+             user_agent_kwargs
+         )
+         summarize_agent_kwargs = self._ensure_pause_event_in_kwargs(
+             summarize_agent_kwargs
+         )
+
          worker_node = RolePlayingWorker(
              description=description,
              assistant_role_name=assistant_role_name,
@@ -1191,6 +1261,9 @@
          Returns:
              Workforce: The workforce node itself.
          """
+         # Align child workforce's pause_event with this one for unified
+         # control of worker agents only.
+         workforce._pause_event = self._pause_event
          self._children.append(workforce)
          return self

@@ -1224,16 +1297,19 @@
          # Handle asyncio.Event in a thread-safe way
          if self._loop and not self._loop.is_closed():
              # If we have a loop, use it to set the event safely
-             asyncio.run_coroutine_threadsafe(
-                 self._async_reset(), self._loop
-             ).result()
-         else:
              try:
-                 self._reset_task = asyncio.create_task(self._async_reset())
-             except RuntimeError:
-                 asyncio.run(self._async_reset())
+                 asyncio.run_coroutine_threadsafe(
+                     self._async_reset(), self._loop
+                 ).result()
+             except RuntimeError as e:
+                 logger.warning(f"Failed to reset via existing loop: {e}")
+                 # Fallback to direct event manipulation
+                 self._pause_event.set()
+         else:
+             # No active loop, directly set the event
+             self._pause_event.set()

-         if hasattr(self, 'logger') and self.metrics_logger is not None:
+         if hasattr(self, 'metrics_logger') and self.metrics_logger is not None:
              self.metrics_logger.reset_task_data()
          else:
              self.metrics_logger = WorkforceLogger(workforce_id=self.node_id)
@@ -1516,12 +1592,26 @@
          # Record the start time when a task is posted
          self._task_start_times[task.id] = time.time()

+         task.assigned_worker_id = assignee_id
+
          if self.metrics_logger:
              self.metrics_logger.log_task_started(
                  task_id=task.id, worker_id=assignee_id
              )
-         self._in_flight_tasks += 1
-         await self._channel.post_task(task, self.node_id, assignee_id)
+
+         try:
+             self._in_flight_tasks += 1
+             await self._channel.post_task(task, self.node_id, assignee_id)
+             logger.debug(
+                 f"Posted task {task.id} to {assignee_id}. "
+                 f"In-flight tasks: {self._in_flight_tasks}"
+             )
+         except Exception as e:
+             # Decrement counter if posting failed
+             self._in_flight_tasks -= 1
+             logger.error(
+                 f"Failed to post task {task.id} to {assignee_id}: {e}"
+             )

      async def _post_dependency(self, dependency: Task) -> None:
          await self._channel.post_dependency(dependency, self.node_id)
@@ -1580,7 +1670,7 @@
          new_node = SingleAgentWorker(
              description=new_node_conf.description,
              worker=new_agent,
-             pool_max_size=10,  # TODO: make this configurable
+             pool_max_size=DEFAULT_WORKER_POOL_SIZE,
          )
          new_node.set_channel(self._channel)

@@ -1621,9 +1711,14 @@
              model_config_dict={"temperature": 0},
          )

-         return ChatAgent(worker_sys_msg, model=model, tools=function_list)  # type: ignore[arg-type]
+         return ChatAgent(
+             worker_sys_msg,
+             model=model,
+             tools=function_list,  # type: ignore[arg-type]
+             pause_event=self._pause_event,
+         )

-     async def _get_returned_task(self) -> Task:
+     async def _get_returned_task(self) -> Optional[Task]:
          r"""Get the task that's published by this node and just get returned
          from the assignee. Includes timeout handling to prevent indefinite
          waiting.
@@ -1632,17 +1727,28 @@
              # Add timeout to prevent indefinite waiting
              return await asyncio.wait_for(
                  self._channel.get_returned_task_by_publisher(self.node_id),
-                 timeout=180.0,  # 3 minute timeout
+                 timeout=TASK_TIMEOUT_SECONDS,
              )
-         except asyncio.TimeoutError:
-             logger.warning(
-                 f"Timeout waiting for returned task in "
+         except Exception as e:
+             # Decrement in-flight counter to prevent hanging
+             if self._in_flight_tasks > 0:
+                 self._in_flight_tasks -= 1
+
+             error_msg = (
+                 f"Error getting returned task {e} in "
                  f"workforce {self.node_id}. "
-                 f"This may indicate an issue with async tool execution. "
                  f"Current pending tasks: {len(self._pending_tasks)}, "
                  f"In-flight tasks: {self._in_flight_tasks}"
              )
-             raise
+             logger.warning(error_msg)
+
+             if self._pending_tasks and self._assignees:
+                 for task in self._pending_tasks:
+                     if task.id in self._assignees:
+                         # Mark this real task as failed
+                         task.set_state(TaskState.FAILED)
+                         return task
+             return None

      async def _post_ready_tasks(self) -> None:
          r"""Checks for unassigned tasks, assigns them, and then posts any
@@ -1682,6 +1788,9 @@
          # Step 2: Iterate through all pending tasks and post those that are
          # ready
          posted_tasks = []
+         # Pre-compute completed task IDs set for O(1) lookups
+         completed_task_ids = {t.id for t in self._completed_tasks}
+
          for task in self._pending_tasks:
              # A task must be assigned to be considered for posting
              if task.id in self._task_dependencies:
@@ -1689,8 +1798,7 @@
                  # Check if all dependencies for this task are in the completed
                  # set
                  if all(
-                     dep_id in {t.id for t in self._completed_tasks}
-                     for dep_id in dependencies
+                     dep_id in completed_task_ids for dep_id in dependencies
                  ):
                      assignee_id = self._assignees[task.id]
                      logger.debug(
@@ -1712,17 +1820,67 @@
      async def _handle_failed_task(self, task: Task) -> bool:
          task.failure_count += 1

+         # Determine detailed failure information
+         if is_task_result_insufficient(task):
+             failure_reason = "Worker returned unhelpful "
+             f"response: {task.result[:100] if task.result else ''}..."
+         else:
+             failure_reason = "Task marked as failed despite "
+             f"having result: {(task.result or '')[:100]}..."
+
+         # Add context about the worker and task
+         worker_id = task.assigned_worker_id or "unknown"
+         worker_info = f" (assigned to worker: {worker_id})"
+
+         detailed_error = f"{failure_reason}{worker_info}"
+
+         logger.error(
+             f"Task {task.id} failed (attempt "
+             f"{task.failure_count}/3): {detailed_error}"
+         )
+
          if self.metrics_logger:
-             worker_id = self._assignees.get(task.id)
              self.metrics_logger.log_task_failed(
                  task_id=task.id,
                  worker_id=worker_id,
-                 error_message=task.result or "Task execution failed",
+                 error_message=detailed_error,
                  error_type="TaskFailure",
-                 metadata={'failure_count': task.failure_count},
+                 metadata={
+                     'failure_count': task.failure_count,
+                     'task_content': task.content,
+                     'result_length': len(task.result) if task.result else 0,
+                 },
+             )
+
+         # Check for immediate halt conditions - return immediately if we
+         # should halt
+         if task.failure_count >= MAX_TASK_RETRIES:
+             logger.error(
+                 f"Task {task.id} has exceeded maximum retry attempts "
+                 f"({MAX_TASK_RETRIES}). Final failure "
+                 f"reason: {detailed_error}. "
+                 f"Task content: '{task.content[:100]}...'"
              )
+             self._cleanup_task_tracking(task.id)
+             # Mark task as completed for dependency tracking before halting
+             self._completed_tasks.append(task)
+             if task.id in self._assignees:
+                 await self._channel.archive_task(task.id)
+             return True

-         if task.failure_count > 3:
+         # If too many tasks are failing rapidly, also halt to prevent infinite
+         # loops
+         if len(self._pending_tasks) > MAX_PENDING_TASKS_LIMIT:
+             logger.error(
+                 f"Too many pending tasks ({len(self._pending_tasks)} > "
+                 f"{MAX_PENDING_TASKS_LIMIT}). Halting to prevent task "
+                 f"explosion. Last failed task: {task.id}"
+             )
+             self._cleanup_task_tracking(task.id)
+             # Mark task as completed for dependency tracking before halting
+             self._completed_tasks.append(task)
+             if task.id in self._assignees:
+                 await self._channel.archive_task(task.id)
              return True

          if task.get_depth() > 3:
@@ -1777,8 +1935,6 @@
          # Mark task as completed for dependency tracking
          self._completed_tasks.append(task)

-         # Post next ready tasks
-
          # Sync shared memory after task completion to share knowledge
          if self.share_memory:
              logger.info(
@@ -1792,7 +1948,7 @@

      async def _handle_completed_task(self, task: Task) -> None:
          if self.metrics_logger:
-             worker_id = self._assignees.get(task.id, "unknown")
+             worker_id = task.assigned_worker_id or "unknown"
              processing_time_seconds = None
              token_usage = None

@@ -1801,7 +1957,7 @@
                  processing_time_seconds = (
                      time.time() - self._task_start_times[task.id]
                  )
-                 del self._task_start_times[task.id]  # Prevent memory leaks
+                 self._cleanup_task_tracking(task.id)
              elif (
                  task.additional_info is not None
                  and 'processing_time_seconds' in task.additional_info
@@ -1995,8 +2151,19 @@
                  )
                  self._last_snapshot_time = time.time()

-                 # Get returned task (this may block until a task is returned)
+                 # Get returned task
                  returned_task = await self._get_returned_task()
+
+                 # If no task was returned, continue
+                 if returned_task is None:
+                     logger.debug(
+                         f"No task returned in workforce {self.node_id}. "
+                         f"Pending: {len(self._pending_tasks)}, "
+                         f"In-flight: {self._in_flight_tasks}"
+                     )
+                     await self._post_ready_tasks()
+                     continue
+
                  self._in_flight_tasks -= 1

                  # Check for stop request after getting task
@@ -2006,22 +2173,72 @@

                  # Process the returned task based on its state
                  if returned_task.state == TaskState.DONE:
-                     print(
-                         f"{Fore.CYAN}🎯 Task {returned_task.id} completed "
-                         f"successfully.{Fore.RESET}"
-                     )
-                     await self._handle_completed_task(returned_task)
+                     # Check if the "completed" task actually failed to provide
+                     # useful results
+                     if is_task_result_insufficient(returned_task):
+                         result_preview = (
+                             returned_task.result[:100] + "..."
+                             if returned_task.result
+                             else "No result"
+                         )
+                         logger.warning(
+                             f"Task {returned_task.id} marked as DONE but "
+                             f"result is insufficient. "
+                             f"Treating as failed. Result: '{result_preview}'"
+                         )
+                         returned_task.state = TaskState.FAILED
+                         try:
+                             halt = await self._handle_failed_task(
+                                 returned_task
+                             )
+                             if not halt:
+                                 continue
+                             print(
+                                 f"{Fore.RED}Task {returned_task.id} has "
+                                 f"failed for {MAX_TASK_RETRIES} times after "
+                                 f"insufficient results, halting the "
+                                 f"workforce. Final error: "
+                                 f"{returned_task.result or 'Unknown error'}"
+                                 f"{Fore.RESET}"
+                             )
+                             await self._graceful_shutdown(returned_task)
+                             break
+                         except Exception as e:
+                             logger.error(
+                                 f"Error handling insufficient task result "
+                                 f"{returned_task.id}: {e}",
+                                 exc_info=True,
+                             )
+                             continue
+                     else:
+                         print(
+                             f"{Fore.CYAN}🎯 Task {returned_task.id} completed "
+                             f"successfully.{Fore.RESET}"
+                         )
+                         await self._handle_completed_task(returned_task)
                  elif returned_task.state == TaskState.FAILED:
-                     halt = await self._handle_failed_task(returned_task)
-                     if not halt:
+                     try:
+                         halt = await self._handle_failed_task(returned_task)
+                         if not halt:
+                             continue
+                         print(
+                             f"{Fore.RED}Task {returned_task.id} has failed "
+                             f"for {MAX_TASK_RETRIES} times, halting "
+                             f"the workforce. Final error: "
+                             f"{returned_task.result or 'Unknown error'}"
+                             f"{Fore.RESET}"
+                         )
+                         # Graceful shutdown instead of immediate break
+                         await self._graceful_shutdown(returned_task)
+                         break
+                     except Exception as e:
+                         logger.error(
+                             f"Error handling failed task "
+                             f"{returned_task.id}: {e}",
+                             exc_info=True,
+                         )
+                         # Continue to prevent hanging
                          continue
-                     print(
-                         f"{Fore.RED}Task {returned_task.id} has failed "
-                         f"for 3 times, halting the workforce.{Fore.RESET}"
-                     )
-                     # Graceful shutdown instead of immediate break
-                     await self._graceful_shutdown(returned_task)
-                     break
                  elif returned_task.state == TaskState.OPEN:
                      # TODO: Add logic for OPEN
                      pass
@@ -2031,7 +2248,18 @@
                      )

              except Exception as e:
-                 logger.error(f"Error processing task: {e}")
+                 # Decrement in-flight counter to prevent hanging
+                 if self._in_flight_tasks > 0:
+                     self._in_flight_tasks -= 1
+
+                 logger.error(
+                     f"Error processing task in workforce {self.node_id}: {e}"
+                     f"Workforce state - Pending tasks: "
+                     f"{len(self._pending_tasks)}, "
+                     f"In-flight tasks: {self._in_flight_tasks}, "
+                     f"Completed tasks: {len(self._completed_tasks)}"
+                 )
+
                  if self._stop_requested:
                      break
                  # Continue with next iteration unless stop is requested
@@ -2085,11 +2313,38 @@
          r"""Stop all the child nodes under it. The node itself will be stopped
          by its parent node.
          """
+         # Stop all child nodes first
          for child in self._children:
              if child._running:
                  child.stop()
-         for child_task in self._child_listening_tasks:
-             child_task.cancel()
+
+         # Cancel child listening tasks
+         if self._child_listening_tasks:
+             try:
+                 loop = asyncio.get_running_loop()
+                 if loop and not loop.is_closed():
+                     # Create graceful cleanup task
+                     async def cleanup():
+                         await asyncio.sleep(0.1)  # Brief grace period
+                         for task in self._child_listening_tasks:
+                             if not task.done():
+                                 task.cancel()
+                         await asyncio.gather(
+                             *self._child_listening_tasks,
+                             return_exceptions=True,
+                         )
+
+                     self._cleanup_task = loop.create_task(cleanup())
+                 else:
+                     # No active loop, cancel immediately
+                     for task in self._child_listening_tasks:
+                         task.cancel()
+             except (RuntimeError, Exception) as e:
+                 # Fallback: cancel immediately
+                 logger.debug(f"Exception during task cleanup: {e}")
+                 for task in self._child_listening_tasks:
+                     task.cancel()
+
          self._running = False

      def clone(self, with_memory: bool = False) -> 'Workforce':
camel/societies/workforce/workforce_logger.py CHANGED
@@ -488,7 +488,6 @@ class WorkforceLogger:
              'worker_utilization': {},
              'current_pending_tasks': 0,
              'total_workforce_running_time_seconds': 0.0,
-             'avg_task_queue_time_seconds': 0.0,
          }

          task_start_times: Dict[str, float] = {}
camel/tasks/task.py CHANGED
@@ -46,19 +46,35 @@ from .task_prompt import (
  logger = get_logger(__name__)


+ class TaskValidationMode(Enum):
+     r"""Validation modes for different use cases."""
+
+     INPUT = "input"  # For validating task content before processing
+     OUTPUT = "output"  # For validating task results after completion
+
+
  def validate_task_content(
-     content: str, task_id: str = "unknown", min_length: int = 10
+     content: str,
+     task_id: str = "unknown",
+     min_length: int = 1,
+     mode: TaskValidationMode = TaskValidationMode.INPUT,
+     check_failure_patterns: bool = True,
  ) -> bool:
-     r"""Validates task result content to avoid silent failures.
-     It performs basic checks to ensure the content meets minimum
-     quality standards.
+     r"""Unified validation for task content and results to avoid silent
+     failures. Performs comprehensive checks to ensure content meets quality
+     standards.

      Args:
-         content (str): The task result content to validate.
+         content (str): The task content or result to validate.
          task_id (str): Task ID for logging purposes.
              (default: :obj:`"unknown"`)
          min_length (int): Minimum content length after stripping whitespace.
-             (default: :obj:`10`)
+             (default: :obj:`1`)
+         mode (TaskValidationMode): Validation mode - INPUT for task content,
+             OUTPUT for task results. (default: :obj:`TaskValidationMode.INPUT`)
+         check_failure_patterns (bool): Whether to check for failure indicators
+             in the content. Only effective in OUTPUT mode.
+             (default: :obj:`True`)

      Returns:
          bool: True if content passes validation, False otherwise.
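
Two behavioral changes in this hunk are worth flagging: the default `min_length` drops from 10 to 1, and validation is now parameterized by a `TaskValidationMode`. A hedged sketch of calling the new signature (the example strings are illustrative; the OUTPUT-mode failure-pattern checks it relies on appear in the next hunk):

```python
from camel.tasks.task import TaskValidationMode, validate_task_content

# Under the new default min_length=1, short content such as "ok" is no
# longer rejected by the length check (the old default of 10 would fail it).
validate_task_content("ok", task_id="t1")

# OUTPUT mode additionally scans results for failure phrases and refusals.
validate_task_content(
    "I cannot complete this task.",
    task_id="t1",
    mode=TaskValidationMode.OUTPUT,
)  # returns False: the result matches a failure indicator
```
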
@@ -85,14 +101,70 @@ def validate_task_content(
          )
          return False

+     # 4: For OUTPUT mode, check for failure patterns if enabled
+     if mode == TaskValidationMode.OUTPUT and check_failure_patterns:
+         content_lower = stripped_content.lower()
+
+         # Check for explicit failure indicators
+         failure_indicators = [
+             "i cannot complete",
+             "i cannot do",
+             "task failed",
+             "unable to complete",
+             "cannot be completed",
+             "failed to complete",
+             "i cannot",
+             "not possible",
+             "impossible to",
+             "cannot perform",
+         ]
+
+         if any(indicator in content_lower for indicator in failure_indicators):
+             logger.warning(
+                 f"Task {task_id}: Failure indicator detected in result. "
+                 f"Content preview: '{stripped_content[:100]}...'"
+             )
+             return False
+
+         # Check for responses that are just error messages or refusals
+         if content_lower.startswith(("error", "failed", "cannot", "unable")):
+             logger.warning(
+                 f"Task {task_id}: Error/refusal pattern detected at start. "
+                 f"Content preview: '{stripped_content[:100]}...'"
+             )
+             return False
+
      # All validation checks passed
      logger.debug(
-         f"Task {task_id}: Content validation passed "
+         f"Task {task_id}: {mode.value} validation passed "
          f"({len(stripped_content)} chars)"
      )
      return True


+ def is_task_result_insufficient(task: "Task") -> bool:
+     r"""Check if a task result is insufficient and should be treated as failed.
+
+     This is a convenience wrapper around validate_task_content for backward
+     compatibility and semantic clarity when checking task results.
+
+     Args:
+         task (Task): The task to check.
+
+     Returns:
+         bool: True if the result is insufficient, False otherwise.
+     """
+     if not hasattr(task, 'result') or task.result is None:
+         return True
+
+     return not validate_task_content(
+         content=task.result,
+         task_id=task.id,
+         mode=TaskValidationMode.OUTPUT,
+         check_failure_patterns=True,
+     )
+
+
  def parse_response(
      response: str, task_id: Optional[str] = None
  ) -> List["Task"]:
@@ -157,6 +229,8 @@ class Task(BaseModel):
              (default: :obj:`""`)
          failure_count (int): The failure count for the task.
              (default: :obj:`0`)
+         assigned_worker_id (Optional[str]): The ID of the worker assigned to
+             this task. (default: :obj:`None`)
          additional_info (Optional[Dict[str, Any]]): Additional information for
              the task. (default: :obj:`None`)
          image_list (Optional[List[Image.Image]]): Optional list of PIL Image
@@ -187,6 +261,8 @@

      failure_count: int = 0

+     assigned_worker_id: Optional[str] = None
+
      additional_info: Optional[Dict[str, Any]] = None

      image_list: Optional[List[Image.Image]] = None