swarms 7.7.5__py3-none-any.whl → 7.7.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- swarms/agents/__init__.py +0 -1
- swarms/communication/duckdb_wrap.py +15 -6
- swarms/prompts/max_loop_prompt.py +48 -0
- swarms/prompts/react_base_prompt.py +41 -0
- swarms/structs/agent.py +70 -109
- swarms/structs/concurrent_workflow.py +329 -214
- swarms/structs/conversation.py +123 -6
- swarms/structs/groupchat.py +0 -12
- swarms/structs/multi_model_gpu_manager.py +0 -1
- {swarms-7.7.5.dist-info → swarms-7.7.7.dist-info}/METADATA +1 -1
- {swarms-7.7.5.dist-info → swarms-7.7.7.dist-info}/RECORD +14 -12
- {swarms-7.7.5.dist-info → swarms-7.7.7.dist-info}/LICENSE +0 -0
- {swarms-7.7.5.dist-info → swarms-7.7.7.dist-info}/WHEEL +0 -0
- {swarms-7.7.5.dist-info → swarms-7.7.7.dist-info}/entry_points.txt +0 -0
@@ -1,70 +1,32 @@
|
|
1
1
|
import os
|
2
|
-
import
|
2
|
+
import time
|
3
3
|
from concurrent.futures import ThreadPoolExecutor
|
4
|
-
from
|
4
|
+
from functools import lru_cache
|
5
5
|
from typing import Any, Callable, Dict, List, Optional, Union
|
6
6
|
|
7
|
-
from
|
7
|
+
from tqdm import tqdm
|
8
8
|
|
9
9
|
from swarms.structs.agent import Agent
|
10
10
|
from swarms.structs.base_swarm import BaseSwarm
|
11
|
-
from swarms.utils.file_processing import create_file_in_folder
|
12
|
-
from swarms.utils.loguru_logger import initialize_logger
|
13
11
|
from swarms.structs.conversation import Conversation
|
14
|
-
from swarms.
|
15
|
-
from swarms.
|
12
|
+
from swarms.utils.formatter import formatter
|
13
|
+
from swarms.utils.history_output_formatter import (
|
14
|
+
history_output_formatter,
|
15
|
+
)
|
16
|
+
from swarms.utils.loguru_logger import initialize_logger
|
16
17
|
|
17
18
|
logger = initialize_logger(log_folder="concurrent_workflow")
|
18
19
|
|
19
20
|
|
20
|
-
class AgentOutputSchema(BaseModel):
|
21
|
-
run_id: Optional[str] = Field(
|
22
|
-
..., description="Unique ID for the run"
|
23
|
-
)
|
24
|
-
agent_name: Optional[str] = Field(
|
25
|
-
..., description="Name of the agent"
|
26
|
-
)
|
27
|
-
task: Optional[str] = Field(
|
28
|
-
..., description="Task or query given to the agent"
|
29
|
-
)
|
30
|
-
output: Optional[str] = Field(
|
31
|
-
..., description="Output generated by the agent"
|
32
|
-
)
|
33
|
-
start_time: Optional[datetime] = Field(
|
34
|
-
..., description="Start time of the task"
|
35
|
-
)
|
36
|
-
end_time: Optional[datetime] = Field(
|
37
|
-
..., description="End time of the task"
|
38
|
-
)
|
39
|
-
duration: Optional[float] = Field(
|
40
|
-
...,
|
41
|
-
description="Duration taken to complete the task (in seconds)",
|
42
|
-
)
|
43
|
-
|
44
|
-
|
45
|
-
class MetadataSchema(BaseModel):
|
46
|
-
swarm_id: Optional[str] = Field(
|
47
|
-
generate_swarm_id(), description="Unique ID for the run"
|
48
|
-
)
|
49
|
-
task: Optional[str] = Field(
|
50
|
-
..., description="Task or query given to all agents"
|
51
|
-
)
|
52
|
-
description: Optional[str] = Field(
|
53
|
-
"Concurrent execution of multiple agents",
|
54
|
-
description="Description of the workflow",
|
55
|
-
)
|
56
|
-
agents: Optional[List[AgentOutputSchema]] = Field(
|
57
|
-
..., description="List of agent outputs and metadata"
|
58
|
-
)
|
59
|
-
timestamp: Optional[datetime] = Field(
|
60
|
-
default_factory=datetime.now,
|
61
|
-
description="Timestamp of the workflow execution",
|
62
|
-
)
|
63
|
-
|
64
|
-
|
65
21
|
class ConcurrentWorkflow(BaseSwarm):
|
66
22
|
"""
|
67
23
|
Represents a concurrent workflow that executes multiple agents concurrently in a production-grade manner.
|
24
|
+
Features include:
|
25
|
+
- Interactive model support
|
26
|
+
- Caching for repeated prompts
|
27
|
+
- Optional progress tracking
|
28
|
+
- Enhanced error handling and retries
|
29
|
+
- Input validation
|
68
30
|
|
69
31
|
Args:
|
70
32
|
name (str): The name of the workflow. Defaults to "ConcurrentWorkflow".
|
@@ -72,11 +34,16 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
72
34
|
agents (List[Agent]): The list of agents to be executed concurrently. Defaults to an empty list.
|
73
35
|
metadata_output_path (str): The path to save the metadata output. Defaults to "agent_metadata.json".
|
74
36
|
auto_save (bool): Flag indicating whether to automatically save the metadata. Defaults to False.
|
75
|
-
|
37
|
+
output_type (str): The type of output format. Defaults to "dict".
|
76
38
|
max_loops (int): The maximum number of loops for each agent. Defaults to 1.
|
77
39
|
return_str_on (bool): Flag indicating whether to return the output as a string. Defaults to False.
|
78
|
-
agent_responses (list): The list of agent responses. Defaults to an empty list.
|
79
40
|
auto_generate_prompts (bool): Flag indicating whether to auto-generate prompts for agents. Defaults to False.
|
41
|
+
return_entire_history (bool): Flag indicating whether to return the entire conversation history. Defaults to False.
|
42
|
+
interactive (bool): Flag indicating whether to enable interactive mode. Defaults to False.
|
43
|
+
cache_size (int): The size of the cache. Defaults to 100.
|
44
|
+
max_retries (int): The maximum number of retry attempts. Defaults to 3.
|
45
|
+
retry_delay (float): The delay between retry attempts in seconds. Defaults to 1.0.
|
46
|
+
show_progress (bool): Flag indicating whether to show progress. Defaults to False.
|
80
47
|
|
81
48
|
Raises:
|
82
49
|
ValueError: If the list of agents is empty or if the description is empty.
|
@@ -87,13 +54,18 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
87
54
|
agents (List[Agent]): The list of agents to be executed concurrently.
|
88
55
|
metadata_output_path (str): The path to save the metadata output.
|
89
56
|
auto_save (bool): Flag indicating whether to automatically save the metadata.
|
90
|
-
|
57
|
+
output_type (str): The type of output format.
|
91
58
|
max_loops (int): The maximum number of loops for each agent.
|
92
59
|
return_str_on (bool): Flag indicating whether to return the output as a string.
|
93
|
-
agent_responses (list): The list of agent responses.
|
94
60
|
auto_generate_prompts (bool): Flag indicating whether to auto-generate prompts for agents.
|
95
|
-
|
96
|
-
|
61
|
+
return_entire_history (bool): Flag indicating whether to return the entire conversation history.
|
62
|
+
interactive (bool): Flag indicating whether to enable interactive mode.
|
63
|
+
cache_size (int): The size of the cache.
|
64
|
+
max_retries (int): The maximum number of retry attempts.
|
65
|
+
retry_delay (float): The delay between retry attempts in seconds.
|
66
|
+
show_progress (bool): Flag indicating whether to show progress.
|
67
|
+
_cache (dict): The cache for storing agent outputs.
|
68
|
+
_progress_bar (tqdm): The progress bar for tracking execution.
|
97
69
|
"""
|
98
70
|
|
99
71
|
def __init__(
|
@@ -103,13 +75,16 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
103
75
|
agents: List[Union[Agent, Callable]] = [],
|
104
76
|
metadata_output_path: str = "agent_metadata.json",
|
105
77
|
auto_save: bool = True,
|
106
|
-
|
78
|
+
output_type: str = "dict-all-except-first",
|
107
79
|
max_loops: int = 1,
|
108
80
|
return_str_on: bool = False,
|
109
|
-
agent_responses: list = [],
|
110
81
|
auto_generate_prompts: bool = False,
|
111
|
-
output_type: OutputType = "dict",
|
112
82
|
return_entire_history: bool = False,
|
83
|
+
interactive: bool = False,
|
84
|
+
cache_size: int = 100,
|
85
|
+
max_retries: int = 3,
|
86
|
+
retry_delay: float = 1.0,
|
87
|
+
show_progress: bool = False,
|
113
88
|
*args,
|
114
89
|
**kwargs,
|
115
90
|
):
|
@@ -125,18 +100,22 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
125
100
|
self.agents = agents
|
126
101
|
self.metadata_output_path = metadata_output_path
|
127
102
|
self.auto_save = auto_save
|
128
|
-
self.output_schema = output_schema
|
129
103
|
self.max_loops = max_loops
|
130
104
|
self.return_str_on = return_str_on
|
131
|
-
self.agent_responses = agent_responses
|
132
105
|
self.auto_generate_prompts = auto_generate_prompts
|
133
106
|
self.max_workers = os.cpu_count()
|
134
107
|
self.output_type = output_type
|
135
108
|
self.return_entire_history = return_entire_history
|
136
109
|
self.tasks = [] # Initialize tasks list
|
110
|
+
self.interactive = interactive
|
111
|
+
self.cache_size = cache_size
|
112
|
+
self.max_retries = max_retries
|
113
|
+
self.retry_delay = retry_delay
|
114
|
+
self.show_progress = show_progress
|
115
|
+
self._cache = {}
|
116
|
+
self._progress_bar = None
|
137
117
|
|
138
118
|
self.reliability_check()
|
139
|
-
|
140
119
|
self.conversation = Conversation()
|
141
120
|
|
142
121
|
def disable_agent_prints(self):
|
@@ -145,29 +124,47 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
145
124
|
|
146
125
|
def reliability_check(self):
|
147
126
|
try:
|
148
|
-
|
127
|
+
formatter.print_panel(
|
128
|
+
content=f"\n 🏷️ Name: {self.name}\n 📝 Description: {self.description}\n 🤖 Agents: {len(self.agents)}\n 🔄 Max Loops: {self.max_loops}\n ",
|
129
|
+
title="⚙️ Concurrent Workflow Settings",
|
130
|
+
style="bold blue",
|
131
|
+
)
|
132
|
+
formatter.print_panel(
|
133
|
+
content="🔍 Starting reliability checks",
|
134
|
+
title="🔒 Reliability Checks",
|
135
|
+
style="bold blue",
|
136
|
+
)
|
149
137
|
|
150
138
|
if self.name is None:
|
151
|
-
logger.error("A name is required for the swarm")
|
152
|
-
raise ValueError(
|
139
|
+
logger.error("❌ A name is required for the swarm")
|
140
|
+
raise ValueError(
|
141
|
+
"❌ A name is required for the swarm"
|
142
|
+
)
|
153
143
|
|
154
|
-
if not self.agents:
|
155
|
-
logger.error(
|
144
|
+
if not self.agents or len(self.agents) <= 1:
|
145
|
+
logger.error(
|
146
|
+
"❌ The list of agents must not be empty."
|
147
|
+
)
|
156
148
|
raise ValueError(
|
157
|
-
"The list of agents must not be empty."
|
149
|
+
"❌ The list of agents must not be empty."
|
158
150
|
)
|
159
151
|
|
160
152
|
if not self.description:
|
161
|
-
logger.error("A description is required.")
|
162
|
-
raise ValueError("A description is required.")
|
153
|
+
logger.error("❌ A description is required.")
|
154
|
+
raise ValueError("❌ A description is required.")
|
155
|
+
|
156
|
+
formatter.print_panel(
|
157
|
+
content="✅ Reliability checks completed successfully",
|
158
|
+
title="🎉 Reliability Checks",
|
159
|
+
style="bold green",
|
160
|
+
)
|
163
161
|
|
164
|
-
logger.info("Reliability checks completed successfully")
|
165
162
|
except ValueError as e:
|
166
|
-
logger.error(f"Reliability check failed: {e}")
|
163
|
+
logger.error(f"❌ Reliability check failed: {e}")
|
167
164
|
raise
|
168
165
|
except Exception as e:
|
169
166
|
logger.error(
|
170
|
-
f"An unexpected error occurred during reliability checks: {e}"
|
167
|
+
f"💥 An unexpected error occurred during reliability checks: {e}"
|
171
168
|
)
|
172
169
|
raise
|
173
170
|
|
@@ -184,147 +181,179 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
184
181
|
for agent in self.agents:
|
185
182
|
agent.auto_generate_prompt = True
|
186
183
|
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
self
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
if self.
|
224
|
-
|
225
|
-
|
184
|
+
@lru_cache(maxsize=100)
|
185
|
+
def _cached_run(self, task: str, agent_id: int) -> Any:
|
186
|
+
"""Cached version of agent execution to avoid redundant computations"""
|
187
|
+
return self.agents[agent_id].run(task=task)
|
188
|
+
|
189
|
+
def enable_progress_bar(self):
|
190
|
+
"""Enable progress bar display"""
|
191
|
+
self.show_progress = True
|
192
|
+
|
193
|
+
def disable_progress_bar(self):
|
194
|
+
"""Disable progress bar display"""
|
195
|
+
if self._progress_bar:
|
196
|
+
self._progress_bar.close()
|
197
|
+
self._progress_bar = None
|
198
|
+
self.show_progress = False
|
199
|
+
|
200
|
+
def _create_progress_bar(self, total: int):
|
201
|
+
"""Create a progress bar for tracking execution"""
|
202
|
+
if self.show_progress:
|
203
|
+
try:
|
204
|
+
self._progress_bar = tqdm(
|
205
|
+
total=total,
|
206
|
+
desc="Processing tasks",
|
207
|
+
unit="task",
|
208
|
+
disable=not self.show_progress,
|
209
|
+
ncols=100,
|
210
|
+
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
|
211
|
+
)
|
212
|
+
except Exception as e:
|
213
|
+
logger.warning(f"Failed to create progress bar: {e}")
|
214
|
+
self.show_progress = False
|
215
|
+
self._progress_bar = None
|
216
|
+
return self._progress_bar
|
217
|
+
|
218
|
+
def _update_progress(self, increment: int = 1):
|
219
|
+
"""Update the progress bar"""
|
220
|
+
if self._progress_bar and self.show_progress:
|
221
|
+
try:
|
222
|
+
self._progress_bar.update(increment)
|
223
|
+
except Exception as e:
|
224
|
+
logger.warning(f"Failed to update progress bar: {e}")
|
225
|
+
self.disable_progress_bar()
|
226
|
+
|
227
|
+
def _validate_input(self, task: str) -> bool:
|
228
|
+
"""Validate input task"""
|
229
|
+
if not isinstance(task, str):
|
230
|
+
raise ValueError("Task must be a string")
|
231
|
+
if not task.strip():
|
232
|
+
raise ValueError("Task cannot be empty")
|
233
|
+
return True
|
234
|
+
|
235
|
+
def _handle_interactive(self, task: str) -> str:
|
236
|
+
"""Handle interactive mode for task input"""
|
237
|
+
if self.interactive:
|
238
|
+
from swarms.utils.formatter import formatter
|
239
|
+
|
240
|
+
# Display current task in a panel
|
241
|
+
formatter.print_panel(
|
242
|
+
content=f"Current task: {task}",
|
243
|
+
title="Task Status",
|
244
|
+
style="bold blue",
|
226
245
|
)
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
246
|
+
|
247
|
+
# Get user input with formatted prompt
|
248
|
+
formatter.print_panel(
|
249
|
+
content="Do you want to modify this task? (y/n/q to quit): ",
|
250
|
+
title="User Input",
|
251
|
+
style="bold green",
|
231
252
|
)
|
253
|
+
response = input().lower()
|
254
|
+
|
255
|
+
if response == "q":
|
256
|
+
return None
|
257
|
+
elif response == "y":
|
258
|
+
formatter.print_panel(
|
259
|
+
content="Enter new task: ",
|
260
|
+
title="New Task Input",
|
261
|
+
style="bold yellow",
|
262
|
+
)
|
263
|
+
new_task = input()
|
264
|
+
return new_task
|
265
|
+
return task
|
266
|
+
|
267
|
+
def _run_with_retry(
|
268
|
+
self, agent: Agent, task: str, img: str = None
|
269
|
+
) -> Any:
|
270
|
+
"""Run agent with retry mechanism"""
|
271
|
+
for attempt in range(self.max_retries):
|
272
|
+
try:
|
273
|
+
output = agent.run(task=task, img=img)
|
274
|
+
self.conversation.add(agent.agent_name, output)
|
275
|
+
return output
|
276
|
+
except Exception as e:
|
277
|
+
if attempt == self.max_retries - 1:
|
278
|
+
logger.error(
|
279
|
+
f"Error running agent {agent.agent_name} after {self.max_retries} attempts: {e}"
|
280
|
+
)
|
281
|
+
raise
|
282
|
+
logger.warning(
|
283
|
+
f"Attempt {attempt + 1} failed for agent {agent.agent_name}: {e}"
|
284
|
+
)
|
285
|
+
time.sleep(
|
286
|
+
self.retry_delay * (attempt + 1)
|
287
|
+
) # Exponential backoff
|
232
288
|
|
233
289
|
def _run(
|
234
290
|
self, task: str, img: str = None, *args, **kwargs
|
235
291
|
) -> Union[Dict[str, Any], str]:
|
236
292
|
"""
|
237
|
-
|
293
|
+
Enhanced run method with caching, progress tracking, and better error handling
|
294
|
+
"""
|
238
295
|
|
239
|
-
|
240
|
-
|
241
|
-
|
296
|
+
# Validate and potentially modify task
|
297
|
+
self._validate_input(task)
|
298
|
+
task = self._handle_interactive(task)
|
242
299
|
|
243
|
-
|
244
|
-
|
245
|
-
str: The final metadata as a string if return_str_on is True.
|
300
|
+
# Add task to conversation
|
301
|
+
self.conversation.add("User", task)
|
246
302
|
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
"""
|
251
|
-
logger.info(
|
252
|
-
f"Running concurrent workflow with {len(self.agents)} agents."
|
253
|
-
)
|
254
|
-
|
255
|
-
self.conversation.add(
|
256
|
-
"User",
|
257
|
-
task,
|
258
|
-
)
|
303
|
+
# Create progress bar if enabled
|
304
|
+
if self.show_progress:
|
305
|
+
self._create_progress_bar(len(self.agents))
|
259
306
|
|
260
307
|
def run_agent(
|
261
308
|
agent: Agent, task: str, img: str = None
|
262
|
-
) ->
|
263
|
-
start_time = datetime.now()
|
309
|
+
) -> Any:
|
264
310
|
try:
|
265
|
-
|
266
|
-
|
267
|
-
self.
|
268
|
-
|
269
|
-
|
270
|
-
|
311
|
+
# Check cache first
|
312
|
+
cache_key = f"{task}_{agent.agent_name}"
|
313
|
+
if cache_key in self._cache:
|
314
|
+
output = self._cache[cache_key]
|
315
|
+
else:
|
316
|
+
output = self._run_with_retry(agent, task, img)
|
317
|
+
# Update cache
|
318
|
+
if len(self._cache) >= self.cache_size:
|
319
|
+
self._cache.pop(next(iter(self._cache)))
|
320
|
+
self._cache[cache_key] = output
|
321
|
+
|
322
|
+
self._update_progress()
|
323
|
+
return output
|
271
324
|
except Exception as e:
|
272
325
|
logger.error(
|
273
326
|
f"Error running agent {agent.agent_name}: {e}"
|
274
327
|
)
|
328
|
+
self._update_progress()
|
275
329
|
raise
|
276
330
|
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
end_time=end_time,
|
287
|
-
duration=duration,
|
288
|
-
)
|
289
|
-
|
290
|
-
logger.info(
|
291
|
-
f"Agent {agent.agent_name} completed task: {task} in {duration:.2f} seconds."
|
292
|
-
)
|
293
|
-
|
294
|
-
return agent_output
|
295
|
-
|
296
|
-
with ThreadPoolExecutor(
|
297
|
-
max_workers=os.cpu_count()
|
298
|
-
) as executor:
|
299
|
-
agent_outputs = list(
|
300
|
-
executor.map(
|
301
|
-
lambda agent: run_agent(agent, task), self.agents
|
331
|
+
try:
|
332
|
+
with ThreadPoolExecutor(
|
333
|
+
max_workers=self.max_workers
|
334
|
+
) as executor:
|
335
|
+
list(
|
336
|
+
executor.map(
|
337
|
+
lambda agent: run_agent(agent, task),
|
338
|
+
self.agents,
|
339
|
+
)
|
302
340
|
)
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
341
|
+
finally:
|
342
|
+
if self._progress_bar and self.show_progress:
|
343
|
+
try:
|
344
|
+
self._progress_bar.close()
|
345
|
+
except Exception as e:
|
346
|
+
logger.warning(
|
347
|
+
f"Failed to close progress bar: {e}"
|
348
|
+
)
|
349
|
+
finally:
|
350
|
+
self._progress_bar = None
|
351
|
+
|
352
|
+
return history_output_formatter(
|
353
|
+
self.conversation,
|
354
|
+
type=self.output_type,
|
310
355
|
)
|
311
356
|
|
312
|
-
self.save_metadata()
|
313
|
-
|
314
|
-
if self.return_str_on:
|
315
|
-
return self.transform_metadata_schema_to_str(
|
316
|
-
self.output_schema
|
317
|
-
)
|
318
|
-
|
319
|
-
elif self.return_entire_history:
|
320
|
-
return self.conversation.return_history_as_string()
|
321
|
-
elif self.output_type == "list":
|
322
|
-
return self.conversation.return_messages_as_list()
|
323
|
-
elif self.output_type == "dict":
|
324
|
-
return self.conversation.return_messages_as_dictionary()
|
325
|
-
else:
|
326
|
-
return self.output_schema.model_dump_json(indent=4)
|
327
|
-
|
328
357
|
def run(
|
329
358
|
self,
|
330
359
|
task: Optional[str] = None,
|
@@ -333,9 +362,11 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
333
362
|
**kwargs,
|
334
363
|
) -> Any:
|
335
364
|
"""
|
336
|
-
Executes the agent's run method on a specified device.
|
365
|
+
Executes the agent's run method on a specified device with optional interactive mode.
|
337
366
|
|
338
|
-
This method attempts to execute the agent's run method on a specified device, either CPU or GPU.
|
367
|
+
This method attempts to execute the agent's run method on a specified device, either CPU or GPU.
|
368
|
+
It supports both standard execution and interactive mode where users can modify tasks and continue
|
369
|
+
the workflow interactively.
|
339
370
|
|
340
371
|
Args:
|
341
372
|
task (Optional[str], optional): The task to be executed. Defaults to None.
|
@@ -359,8 +390,73 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
359
390
|
self.tasks.append(task)
|
360
391
|
|
361
392
|
try:
|
362
|
-
|
363
|
-
|
393
|
+
# Handle interactive mode
|
394
|
+
if self.interactive:
|
395
|
+
current_task = task
|
396
|
+
loop_count = 0
|
397
|
+
|
398
|
+
while loop_count < self.max_loops:
|
399
|
+
if (
|
400
|
+
self.max_loops is not None
|
401
|
+
and loop_count >= self.max_loops
|
402
|
+
):
|
403
|
+
formatter.print_panel(
|
404
|
+
content=f"Maximum number of loops ({self.max_loops}) reached.",
|
405
|
+
title="Session Complete",
|
406
|
+
style="bold red",
|
407
|
+
)
|
408
|
+
break
|
409
|
+
|
410
|
+
if current_task is None:
|
411
|
+
formatter.print_panel(
|
412
|
+
content="Enter your task (or 'q' to quit): ",
|
413
|
+
title="Task Input",
|
414
|
+
style="bold blue",
|
415
|
+
)
|
416
|
+
current_task = input()
|
417
|
+
if current_task.lower() == "q":
|
418
|
+
break
|
419
|
+
|
420
|
+
# Run the workflow with the current task
|
421
|
+
try:
|
422
|
+
outputs = self._run(
|
423
|
+
current_task, img, *args, **kwargs
|
424
|
+
)
|
425
|
+
formatter.print_panel(
|
426
|
+
content=str(outputs),
|
427
|
+
title="Workflow Result",
|
428
|
+
style="bold green",
|
429
|
+
)
|
430
|
+
except Exception as e:
|
431
|
+
formatter.print_panel(
|
432
|
+
content=f"Error: {str(e)}",
|
433
|
+
title="Error",
|
434
|
+
style="bold red",
|
435
|
+
)
|
436
|
+
|
437
|
+
# Ask if user wants to continue
|
438
|
+
formatter.print_panel(
|
439
|
+
content="Do you want to continue with a new task? (y/n): ",
|
440
|
+
title="Continue Session",
|
441
|
+
style="bold yellow",
|
442
|
+
)
|
443
|
+
if input().lower() != "y":
|
444
|
+
break
|
445
|
+
|
446
|
+
current_task = None
|
447
|
+
loop_count += 1
|
448
|
+
|
449
|
+
formatter.print_panel(
|
450
|
+
content="Interactive session ended.",
|
451
|
+
title="Session Complete",
|
452
|
+
style="bold blue",
|
453
|
+
)
|
454
|
+
return outputs
|
455
|
+
else:
|
456
|
+
# Standard non-interactive execution
|
457
|
+
outputs = self._run(task, img, *args, **kwargs)
|
458
|
+
return outputs
|
459
|
+
|
364
460
|
except ValueError as e:
|
365
461
|
logger.error(f"Invalid device specified: {e}")
|
366
462
|
raise e
|
@@ -368,29 +464,48 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
368
464
|
logger.error(f"An error occurred during execution: {e}")
|
369
465
|
raise e
|
370
466
|
|
371
|
-
def run_batched(
|
372
|
-
self, tasks: List[str]
|
373
|
-
) -> List[Union[Dict[str, Any], str]]:
|
467
|
+
def run_batched(self, tasks: List[str]) -> Any:
|
374
468
|
"""
|
375
|
-
|
469
|
+
Enhanced batched execution with progress tracking
|
470
|
+
"""
|
471
|
+
if not tasks:
|
472
|
+
raise ValueError("Tasks list cannot be empty")
|
376
473
|
|
377
|
-
|
378
|
-
tasks (List[str]): A list of tasks or queries to give to all agents.
|
474
|
+
results = []
|
379
475
|
|
380
|
-
|
381
|
-
|
476
|
+
# Create progress bar if enabled
|
477
|
+
if self.show_progress:
|
478
|
+
self._create_progress_bar(len(tasks))
|
479
|
+
|
480
|
+
try:
|
481
|
+
for task in tasks:
|
482
|
+
result = self.run(task)
|
483
|
+
results.append(result)
|
484
|
+
self._update_progress()
|
485
|
+
finally:
|
486
|
+
if self._progress_bar and self.show_progress:
|
487
|
+
try:
|
488
|
+
self._progress_bar.close()
|
489
|
+
except Exception as e:
|
490
|
+
logger.warning(
|
491
|
+
f"Failed to close progress bar: {e}"
|
492
|
+
)
|
493
|
+
finally:
|
494
|
+
self._progress_bar = None
|
382
495
|
|
383
|
-
Example:
|
384
|
-
>>> tasks = ["Task 1", "Task 2"]
|
385
|
-
>>> results = workflow.run_batched(tasks)
|
386
|
-
>>> print(results)
|
387
|
-
"""
|
388
|
-
results = []
|
389
|
-
for task in tasks:
|
390
|
-
result = self.run(task)
|
391
|
-
results.append(result)
|
392
496
|
return results
|
393
497
|
|
498
|
+
def clear_cache(self):
|
499
|
+
"""Clear the task cache"""
|
500
|
+
self._cache.clear()
|
501
|
+
|
502
|
+
def get_cache_stats(self) -> Dict[str, int]:
|
503
|
+
"""Get cache statistics"""
|
504
|
+
return {
|
505
|
+
"cache_size": len(self._cache),
|
506
|
+
"max_cache_size": self.cache_size,
|
507
|
+
}
|
508
|
+
|
394
509
|
|
395
510
|
# if __name__ == "__main__":
|
396
511
|
# # Assuming you've already initialized some agents outside of this class
|