agno 2.0.0rc2__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +6009 -2874
- agno/api/api.py +2 -0
- agno/api/os.py +1 -1
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +385 -6
- agno/db/dynamo/dynamo.py +388 -81
- agno/db/dynamo/schemas.py +47 -10
- agno/db/dynamo/utils.py +63 -4
- agno/db/firestore/firestore.py +435 -64
- agno/db/firestore/schemas.py +11 -0
- agno/db/firestore/utils.py +102 -4
- agno/db/gcs_json/gcs_json_db.py +384 -42
- agno/db/gcs_json/utils.py +60 -26
- agno/db/in_memory/in_memory_db.py +351 -66
- agno/db/in_memory/utils.py +60 -2
- agno/db/json/json_db.py +339 -48
- agno/db/json/utils.py +60 -26
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/v1_to_v2.py +510 -37
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +938 -0
- agno/db/mongo/__init__.py +15 -1
- agno/db/mongo/async_mongo.py +2036 -0
- agno/db/mongo/mongo.py +653 -76
- agno/db/mongo/schemas.py +13 -0
- agno/db/mongo/utils.py +80 -8
- agno/db/mysql/mysql.py +687 -25
- agno/db/mysql/schemas.py +61 -37
- agno/db/mysql/utils.py +60 -2
- agno/db/postgres/__init__.py +2 -1
- agno/db/postgres/async_postgres.py +2001 -0
- agno/db/postgres/postgres.py +676 -57
- agno/db/postgres/schemas.py +43 -18
- agno/db/postgres/utils.py +164 -2
- agno/db/redis/redis.py +344 -38
- agno/db/redis/schemas.py +18 -0
- agno/db/redis/utils.py +60 -2
- agno/db/schemas/__init__.py +2 -1
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/memory.py +13 -0
- agno/db/singlestore/schemas.py +26 -1
- agno/db/singlestore/singlestore.py +687 -53
- agno/db/singlestore/utils.py +60 -2
- agno/db/sqlite/__init__.py +2 -1
- agno/db/sqlite/async_sqlite.py +2371 -0
- agno/db/sqlite/schemas.py +24 -0
- agno/db/sqlite/sqlite.py +774 -85
- agno/db/sqlite/utils.py +168 -5
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +309 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1361 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +50 -22
- agno/eval/accuracy.py +50 -43
- agno/eval/performance.py +6 -3
- agno/eval/reliability.py +6 -3
- agno/eval/utils.py +33 -16
- agno/exceptions.py +68 -1
- agno/filters.py +354 -0
- agno/guardrails/__init__.py +6 -0
- agno/guardrails/base.py +19 -0
- agno/guardrails/openai.py +144 -0
- agno/guardrails/pii.py +94 -0
- agno/guardrails/prompt_injection.py +52 -0
- agno/integrations/discord/client.py +1 -0
- agno/knowledge/chunking/agentic.py +13 -10
- agno/knowledge/chunking/fixed.py +1 -1
- agno/knowledge/chunking/semantic.py +40 -8
- agno/knowledge/chunking/strategy.py +59 -15
- agno/knowledge/embedder/aws_bedrock.py +9 -4
- agno/knowledge/embedder/azure_openai.py +54 -0
- agno/knowledge/embedder/base.py +2 -0
- agno/knowledge/embedder/cohere.py +184 -5
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/embedder/google.py +79 -1
- agno/knowledge/embedder/huggingface.py +9 -4
- agno/knowledge/embedder/jina.py +63 -0
- agno/knowledge/embedder/mistral.py +78 -11
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/ollama.py +13 -0
- agno/knowledge/embedder/openai.py +37 -65
- agno/knowledge/embedder/sentence_transformer.py +8 -4
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/embedder/voyageai.py +69 -16
- agno/knowledge/knowledge.py +595 -187
- agno/knowledge/reader/base.py +9 -2
- agno/knowledge/reader/csv_reader.py +8 -10
- agno/knowledge/reader/docx_reader.py +5 -6
- agno/knowledge/reader/field_labeled_csv_reader.py +290 -0
- agno/knowledge/reader/json_reader.py +6 -5
- agno/knowledge/reader/markdown_reader.py +13 -13
- agno/knowledge/reader/pdf_reader.py +43 -68
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +51 -6
- agno/knowledge/reader/s3_reader.py +3 -15
- agno/knowledge/reader/tavily_reader.py +194 -0
- agno/knowledge/reader/text_reader.py +13 -13
- agno/knowledge/reader/web_search_reader.py +2 -43
- agno/knowledge/reader/website_reader.py +43 -25
- agno/knowledge/reranker/__init__.py +3 -0
- agno/knowledge/types.py +9 -0
- agno/knowledge/utils.py +20 -0
- agno/media.py +339 -266
- agno/memory/manager.py +336 -82
- agno/models/aimlapi/aimlapi.py +2 -2
- agno/models/anthropic/claude.py +183 -37
- agno/models/aws/bedrock.py +52 -112
- agno/models/aws/claude.py +33 -1
- agno/models/azure/ai_foundry.py +33 -15
- agno/models/azure/openai_chat.py +25 -8
- agno/models/base.py +1011 -566
- agno/models/cerebras/cerebras.py +19 -13
- agno/models/cerebras/cerebras_openai.py +8 -5
- agno/models/cohere/chat.py +27 -1
- agno/models/cometapi/__init__.py +5 -0
- agno/models/cometapi/cometapi.py +57 -0
- agno/models/dashscope/dashscope.py +1 -0
- agno/models/deepinfra/deepinfra.py +2 -2
- agno/models/deepseek/deepseek.py +2 -2
- agno/models/fireworks/fireworks.py +2 -2
- agno/models/google/gemini.py +110 -37
- agno/models/groq/groq.py +28 -11
- agno/models/huggingface/huggingface.py +2 -1
- agno/models/internlm/internlm.py +2 -2
- agno/models/langdb/langdb.py +4 -4
- agno/models/litellm/chat.py +18 -1
- agno/models/litellm/litellm_openai.py +2 -2
- agno/models/llama_cpp/__init__.py +5 -0
- agno/models/llama_cpp/llama_cpp.py +22 -0
- agno/models/message.py +143 -4
- agno/models/meta/llama.py +27 -10
- agno/models/meta/llama_openai.py +5 -17
- agno/models/nebius/nebius.py +6 -6
- agno/models/nexus/__init__.py +3 -0
- agno/models/nexus/nexus.py +22 -0
- agno/models/nvidia/nvidia.py +2 -2
- agno/models/ollama/chat.py +60 -6
- agno/models/openai/chat.py +102 -43
- agno/models/openai/responses.py +103 -106
- agno/models/openrouter/openrouter.py +41 -3
- agno/models/perplexity/perplexity.py +4 -5
- agno/models/portkey/portkey.py +3 -3
- agno/models/requesty/__init__.py +5 -0
- agno/models/requesty/requesty.py +52 -0
- agno/models/response.py +81 -5
- agno/models/sambanova/sambanova.py +2 -2
- agno/models/siliconflow/__init__.py +5 -0
- agno/models/siliconflow/siliconflow.py +25 -0
- agno/models/together/together.py +2 -2
- agno/models/utils.py +254 -8
- agno/models/vercel/v0.py +2 -2
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +96 -0
- agno/models/vllm/vllm.py +1 -0
- agno/models/xai/xai.py +3 -2
- agno/os/app.py +543 -175
- agno/os/auth.py +24 -14
- agno/os/config.py +1 -0
- agno/os/interfaces/__init__.py +1 -0
- agno/os/interfaces/a2a/__init__.py +3 -0
- agno/os/interfaces/a2a/a2a.py +42 -0
- agno/os/interfaces/a2a/router.py +250 -0
- agno/os/interfaces/a2a/utils.py +924 -0
- agno/os/interfaces/agui/agui.py +23 -7
- agno/os/interfaces/agui/router.py +27 -3
- agno/os/interfaces/agui/utils.py +242 -142
- agno/os/interfaces/base.py +6 -2
- agno/os/interfaces/slack/router.py +81 -23
- agno/os/interfaces/slack/slack.py +29 -14
- agno/os/interfaces/whatsapp/router.py +11 -4
- agno/os/interfaces/whatsapp/whatsapp.py +14 -7
- agno/os/mcp.py +111 -54
- agno/os/middleware/__init__.py +7 -0
- agno/os/middleware/jwt.py +233 -0
- agno/os/router.py +556 -139
- agno/os/routers/evals/evals.py +71 -34
- agno/os/routers/evals/schemas.py +31 -31
- agno/os/routers/evals/utils.py +6 -5
- agno/os/routers/health.py +31 -0
- agno/os/routers/home.py +52 -0
- agno/os/routers/knowledge/knowledge.py +185 -38
- agno/os/routers/knowledge/schemas.py +82 -22
- agno/os/routers/memory/memory.py +158 -53
- agno/os/routers/memory/schemas.py +20 -16
- agno/os/routers/metrics/metrics.py +20 -8
- agno/os/routers/metrics/schemas.py +16 -16
- agno/os/routers/session/session.py +499 -38
- agno/os/schema.py +308 -198
- agno/os/utils.py +401 -41
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/azure_ai_foundry.py +2 -2
- agno/reasoning/deepseek.py +2 -2
- agno/reasoning/default.py +3 -1
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/groq.py +2 -2
- agno/reasoning/ollama.py +2 -2
- agno/reasoning/openai.py +7 -2
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +266 -112
- agno/run/base.py +53 -24
- agno/run/team.py +252 -111
- agno/run/workflow.py +156 -45
- agno/session/agent.py +105 -89
- agno/session/summary.py +65 -25
- agno/session/team.py +176 -96
- agno/session/workflow.py +406 -40
- agno/team/team.py +3854 -1692
- agno/tools/brightdata.py +3 -3
- agno/tools/cartesia.py +3 -5
- agno/tools/dalle.py +9 -8
- agno/tools/decorator.py +4 -2
- agno/tools/desi_vocal.py +2 -2
- agno/tools/duckduckgo.py +15 -11
- agno/tools/e2b.py +20 -13
- agno/tools/eleven_labs.py +26 -28
- agno/tools/exa.py +21 -16
- agno/tools/fal.py +4 -4
- agno/tools/file.py +153 -23
- agno/tools/file_generation.py +350 -0
- agno/tools/firecrawl.py +4 -4
- agno/tools/function.py +257 -37
- agno/tools/giphy.py +2 -2
- agno/tools/gmail.py +238 -14
- agno/tools/google_drive.py +270 -0
- agno/tools/googlecalendar.py +36 -8
- agno/tools/googlesheets.py +20 -5
- agno/tools/jira.py +20 -0
- agno/tools/knowledge.py +3 -3
- agno/tools/lumalab.py +3 -3
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +284 -0
- agno/tools/mem0.py +11 -17
- agno/tools/memori.py +1 -53
- agno/tools/memory.py +419 -0
- agno/tools/models/azure_openai.py +2 -2
- agno/tools/models/gemini.py +3 -3
- agno/tools/models/groq.py +3 -5
- agno/tools/models/nebius.py +7 -7
- agno/tools/models_labs.py +25 -15
- agno/tools/notion.py +204 -0
- agno/tools/openai.py +4 -9
- agno/tools/opencv.py +3 -3
- agno/tools/parallel.py +314 -0
- agno/tools/replicate.py +7 -7
- agno/tools/scrapegraph.py +58 -31
- agno/tools/searxng.py +2 -2
- agno/tools/serper.py +2 -2
- agno/tools/slack.py +18 -3
- agno/tools/spider.py +2 -2
- agno/tools/tavily.py +146 -0
- agno/tools/whatsapp.py +1 -1
- agno/tools/workflow.py +278 -0
- agno/tools/yfinance.py +12 -11
- agno/utils/agent.py +820 -0
- agno/utils/audio.py +27 -0
- agno/utils/common.py +90 -1
- agno/utils/events.py +222 -7
- agno/utils/gemini.py +181 -23
- agno/utils/hooks.py +57 -0
- agno/utils/http.py +111 -0
- agno/utils/knowledge.py +12 -5
- agno/utils/log.py +1 -0
- agno/utils/mcp.py +95 -5
- agno/utils/media.py +188 -10
- agno/utils/merge_dict.py +22 -1
- agno/utils/message.py +60 -0
- agno/utils/models/claude.py +40 -11
- agno/utils/models/cohere.py +1 -1
- agno/utils/models/watsonx.py +1 -1
- agno/utils/openai.py +1 -1
- agno/utils/print_response/agent.py +105 -21
- agno/utils/print_response/team.py +103 -38
- agno/utils/print_response/workflow.py +251 -34
- agno/utils/reasoning.py +22 -1
- agno/utils/serialize.py +32 -0
- agno/utils/streamlit.py +16 -10
- agno/utils/string.py +41 -0
- agno/utils/team.py +98 -9
- agno/utils/tools.py +1 -1
- agno/vectordb/base.py +23 -4
- agno/vectordb/cassandra/cassandra.py +65 -9
- agno/vectordb/chroma/chromadb.py +182 -38
- agno/vectordb/clickhouse/clickhousedb.py +64 -11
- agno/vectordb/couchbase/couchbase.py +105 -10
- agno/vectordb/lancedb/lance_db.py +183 -135
- agno/vectordb/langchaindb/langchaindb.py +25 -7
- agno/vectordb/lightrag/lightrag.py +17 -3
- agno/vectordb/llamaindex/__init__.py +3 -0
- agno/vectordb/llamaindex/llamaindexdb.py +46 -7
- agno/vectordb/milvus/milvus.py +126 -9
- agno/vectordb/mongodb/__init__.py +7 -1
- agno/vectordb/mongodb/mongodb.py +112 -7
- agno/vectordb/pgvector/pgvector.py +142 -21
- agno/vectordb/pineconedb/pineconedb.py +80 -8
- agno/vectordb/qdrant/qdrant.py +125 -39
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +694 -0
- agno/vectordb/singlestore/singlestore.py +111 -25
- agno/vectordb/surrealdb/surrealdb.py +31 -5
- agno/vectordb/upstashdb/upstashdb.py +76 -8
- agno/vectordb/weaviate/weaviate.py +86 -15
- agno/workflow/__init__.py +2 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +112 -18
- agno/workflow/loop.py +69 -10
- agno/workflow/parallel.py +266 -118
- agno/workflow/router.py +110 -17
- agno/workflow/step.py +645 -136
- agno/workflow/steps.py +65 -6
- agno/workflow/types.py +71 -33
- agno/workflow/workflow.py +2113 -300
- agno-2.3.0.dist-info/METADATA +618 -0
- agno-2.3.0.dist-info/RECORD +577 -0
- agno-2.3.0.dist-info/licenses/LICENSE +201 -0
- agno/knowledge/reader/url_reader.py +0 -128
- agno/tools/googlesearch.py +0 -98
- agno/tools/mcp.py +0 -610
- agno/utils/models/aws_claude.py +0 -170
- agno-2.0.0rc2.dist-info/METADATA +0 -355
- agno-2.0.0rc2.dist-info/RECORD +0 -515
- agno-2.0.0rc2.dist-info/licenses/LICENSE +0 -375
- {agno-2.0.0rc2.dist-info → agno-2.3.0.dist-info}/WHEEL +0 -0
- {agno-2.0.0rc2.dist-info → agno-2.3.0.dist-info}/top_level.txt +0 -0
agno/workflow/parallel.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
+
import warnings
|
|
2
3
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
4
|
+
from copy import deepcopy
|
|
3
5
|
from dataclasses import dataclass
|
|
4
6
|
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Optional, Union
|
|
5
7
|
from uuid import uuid4
|
|
6
8
|
|
|
7
9
|
from agno.models.metrics import Metrics
|
|
8
10
|
from agno.run.agent import RunOutputEvent
|
|
11
|
+
from agno.run.base import RunContext
|
|
9
12
|
from agno.run.team import TeamRunOutputEvent
|
|
10
13
|
from agno.run.workflow import (
|
|
11
14
|
ParallelExecutionCompletedEvent,
|
|
@@ -13,7 +16,9 @@ from agno.run.workflow import (
|
|
|
13
16
|
WorkflowRunOutput,
|
|
14
17
|
WorkflowRunOutputEvent,
|
|
15
18
|
)
|
|
19
|
+
from agno.session.workflow import WorkflowSession
|
|
16
20
|
from agno.utils.log import log_debug, logger
|
|
21
|
+
from agno.utils.merge_dict import merge_parallel_session_states
|
|
17
22
|
from agno.workflow.condition import Condition
|
|
18
23
|
from agno.workflow.step import Step
|
|
19
24
|
from agno.workflow.types import StepInput, StepOutput, StepType
|
|
@@ -197,7 +202,11 @@ class Parallel:
|
|
|
197
202
|
user_id: Optional[str] = None,
|
|
198
203
|
workflow_run_response: Optional[WorkflowRunOutput] = None,
|
|
199
204
|
store_executor_outputs: bool = True,
|
|
205
|
+
run_context: Optional[RunContext] = None,
|
|
200
206
|
session_state: Optional[Dict[str, Any]] = None,
|
|
207
|
+
workflow_session: Optional[WorkflowSession] = None,
|
|
208
|
+
add_workflow_history_to_steps: Optional[bool] = False,
|
|
209
|
+
num_history_runs: int = 3,
|
|
201
210
|
) -> StepOutput:
|
|
202
211
|
"""Execute all steps in parallel and return aggregated result"""
|
|
203
212
|
# Use workflow logger for parallel orchestration
|
|
@@ -205,9 +214,24 @@ class Parallel:
|
|
|
205
214
|
|
|
206
215
|
self._prepare_steps()
|
|
207
216
|
|
|
217
|
+
# Create individual session_state copies for each step to prevent race conditions
|
|
218
|
+
session_state_copies = []
|
|
219
|
+
for _ in range(len(self.steps)):
|
|
220
|
+
# If using run context, no need to deepcopy the state. We want the direct reference.
|
|
221
|
+
if run_context is not None and run_context.session_state is not None:
|
|
222
|
+
session_state_copies.append(run_context.session_state)
|
|
223
|
+
else:
|
|
224
|
+
if session_state is not None:
|
|
225
|
+
session_state_copies.append(deepcopy(session_state))
|
|
226
|
+
else:
|
|
227
|
+
session_state_copies.append({})
|
|
228
|
+
|
|
208
229
|
def execute_step_with_index(step_with_index):
|
|
209
230
|
"""Execute a single step and preserve its original index"""
|
|
210
231
|
idx, step = step_with_index
|
|
232
|
+
# Use the individual session_state copy for this step
|
|
233
|
+
step_session_state = session_state_copies[idx]
|
|
234
|
+
|
|
211
235
|
try:
|
|
212
236
|
step_result = step.execute(
|
|
213
237
|
step_input,
|
|
@@ -215,9 +239,13 @@ class Parallel:
|
|
|
215
239
|
user_id=user_id,
|
|
216
240
|
workflow_run_response=workflow_run_response,
|
|
217
241
|
store_executor_outputs=store_executor_outputs,
|
|
218
|
-
|
|
242
|
+
workflow_session=workflow_session,
|
|
243
|
+
add_workflow_history_to_steps=add_workflow_history_to_steps,
|
|
244
|
+
num_history_runs=num_history_runs,
|
|
245
|
+
run_context=run_context,
|
|
246
|
+
session_state=step_session_state,
|
|
219
247
|
) # type: ignore[union-attr]
|
|
220
|
-
return idx, step_result
|
|
248
|
+
return idx, step_result, step_session_state
|
|
221
249
|
except Exception as exc:
|
|
222
250
|
parallel_step_name = getattr(step, "name", f"step_{idx}")
|
|
223
251
|
logger.error(f"Parallel step {parallel_step_name} failed: {exc}")
|
|
@@ -229,6 +257,7 @@ class Parallel:
|
|
|
229
257
|
success=False,
|
|
230
258
|
error=str(exc),
|
|
231
259
|
),
|
|
260
|
+
step_session_state,
|
|
232
261
|
)
|
|
233
262
|
|
|
234
263
|
# Use index to preserve order
|
|
@@ -241,12 +270,14 @@ class Parallel:
|
|
|
241
270
|
for indexed_step in indexed_steps
|
|
242
271
|
}
|
|
243
272
|
|
|
244
|
-
# Collect results
|
|
273
|
+
# Collect results and modified session_state copies
|
|
245
274
|
results_with_indices = []
|
|
275
|
+
modified_session_states = []
|
|
246
276
|
for future in as_completed(future_to_index):
|
|
247
277
|
try:
|
|
248
|
-
index, result = future.result()
|
|
278
|
+
index, result, modified_session_state = future.result()
|
|
249
279
|
results_with_indices.append((index, result))
|
|
280
|
+
modified_session_states.append(modified_session_state)
|
|
250
281
|
step_name = getattr(self.steps[index], "name", f"step_{index}")
|
|
251
282
|
log_debug(f"Parallel step {step_name} completed")
|
|
252
283
|
except Exception as e:
|
|
@@ -265,6 +296,9 @@ class Parallel:
|
|
|
265
296
|
)
|
|
266
297
|
)
|
|
267
298
|
|
|
299
|
+
if run_context is None and session_state is not None:
|
|
300
|
+
merge_parallel_session_states(session_state, modified_session_states)
|
|
301
|
+
|
|
268
302
|
# Sort by original index to preserve order
|
|
269
303
|
results_with_indices.sort(key=lambda x: x[0])
|
|
270
304
|
results = [result for _, result in results_with_indices]
|
|
@@ -290,12 +324,18 @@ class Parallel:
|
|
|
290
324
|
step_input: StepInput,
|
|
291
325
|
session_id: Optional[str] = None,
|
|
292
326
|
user_id: Optional[str] = None,
|
|
327
|
+
stream_events: bool = False,
|
|
293
328
|
stream_intermediate_steps: bool = False,
|
|
329
|
+
stream_executor_events: bool = True,
|
|
294
330
|
workflow_run_response: Optional[WorkflowRunOutput] = None,
|
|
295
331
|
step_index: Optional[Union[int, tuple]] = None,
|
|
296
332
|
store_executor_outputs: bool = True,
|
|
333
|
+
run_context: Optional[RunContext] = None,
|
|
297
334
|
session_state: Optional[Dict[str, Any]] = None,
|
|
298
335
|
parent_step_id: Optional[str] = None,
|
|
336
|
+
workflow_session: Optional[WorkflowSession] = None,
|
|
337
|
+
add_workflow_history_to_steps: Optional[bool] = False,
|
|
338
|
+
num_history_runs: int = 3,
|
|
299
339
|
) -> Iterator[Union[WorkflowRunOutputEvent, StepOutput]]:
|
|
300
340
|
"""Execute all steps in parallel with streaming support"""
|
|
301
341
|
log_debug(f"Parallel Start: {self.name} ({len(self.steps)} steps)", center=True, symbol="=")
|
|
@@ -304,7 +344,28 @@ class Parallel:
|
|
|
304
344
|
|
|
305
345
|
self._prepare_steps()
|
|
306
346
|
|
|
307
|
-
|
|
347
|
+
# Create individual session_state copies for each step to prevent race conditions
|
|
348
|
+
session_state_copies = []
|
|
349
|
+
for _ in range(len(self.steps)):
|
|
350
|
+
# If using run context, no need to deepcopy the state. We want the direct reference.
|
|
351
|
+
if run_context is not None and run_context.session_state is not None:
|
|
352
|
+
session_state_copies.append(run_context.session_state)
|
|
353
|
+
else:
|
|
354
|
+
if session_state is not None:
|
|
355
|
+
session_state_copies.append(deepcopy(session_state))
|
|
356
|
+
else:
|
|
357
|
+
session_state_copies.append({})
|
|
358
|
+
|
|
359
|
+
# Considering both stream_events and stream_intermediate_steps (deprecated)
|
|
360
|
+
if stream_intermediate_steps is not None:
|
|
361
|
+
warnings.warn(
|
|
362
|
+
"The 'stream_intermediate_steps' parameter is deprecated and will be removed in future versions. Use 'stream_events' instead.",
|
|
363
|
+
DeprecationWarning,
|
|
364
|
+
stacklevel=2,
|
|
365
|
+
)
|
|
366
|
+
stream_events = stream_events or stream_intermediate_steps
|
|
367
|
+
|
|
368
|
+
if stream_events and workflow_run_response:
|
|
308
369
|
# Yield parallel step started event
|
|
309
370
|
yield ParallelExecutionStartedEvent(
|
|
310
371
|
run_id=workflow_run_response.run_id or "",
|
|
@@ -318,11 +379,20 @@ class Parallel:
|
|
|
318
379
|
parent_step_id=parent_step_id,
|
|
319
380
|
)
|
|
320
381
|
|
|
382
|
+
import queue
|
|
383
|
+
|
|
384
|
+
event_queue = queue.Queue() # type: ignore
|
|
385
|
+
step_results = []
|
|
386
|
+
modified_session_states = []
|
|
387
|
+
|
|
321
388
|
def execute_step_stream_with_index(step_with_index):
|
|
322
|
-
"""Execute a single step with streaming and
|
|
389
|
+
"""Execute a single step with streaming and put events in queue immediately"""
|
|
323
390
|
idx, step = step_with_index
|
|
391
|
+
# Use the individual session_state copy for this step
|
|
392
|
+
step_session_state = session_state_copies[idx]
|
|
393
|
+
|
|
324
394
|
try:
|
|
325
|
-
|
|
395
|
+
step_outputs = []
|
|
326
396
|
|
|
327
397
|
# If step_index is None or integer (main step): create (step_index, sub_index)
|
|
328
398
|
# If step_index is tuple (child step): all parallel sub-steps get same index
|
|
@@ -338,77 +408,87 @@ class Parallel:
|
|
|
338
408
|
step_input,
|
|
339
409
|
session_id=session_id,
|
|
340
410
|
user_id=user_id,
|
|
341
|
-
|
|
411
|
+
stream_events=stream_events,
|
|
412
|
+
stream_executor_events=stream_executor_events,
|
|
342
413
|
workflow_run_response=workflow_run_response,
|
|
343
414
|
step_index=sub_step_index,
|
|
344
415
|
store_executor_outputs=store_executor_outputs,
|
|
345
|
-
session_state=
|
|
416
|
+
session_state=step_session_state,
|
|
346
417
|
parent_step_id=parallel_step_id,
|
|
418
|
+
workflow_session=workflow_session,
|
|
419
|
+
add_workflow_history_to_steps=add_workflow_history_to_steps,
|
|
420
|
+
num_history_runs=num_history_runs,
|
|
347
421
|
):
|
|
348
|
-
|
|
349
|
-
|
|
422
|
+
# Put event immediately in queue
|
|
423
|
+
event_queue.put(("event", idx, event))
|
|
424
|
+
if isinstance(event, StepOutput):
|
|
425
|
+
step_outputs.append(event)
|
|
426
|
+
|
|
427
|
+
# Signal completion for this step
|
|
428
|
+
event_queue.put(("complete", idx, step_outputs, step_session_state))
|
|
429
|
+
return idx, step_outputs, step_session_state
|
|
350
430
|
except Exception as exc:
|
|
351
431
|
parallel_step_name = getattr(step, "name", f"step_{idx}")
|
|
352
432
|
logger.error(f"Parallel step {parallel_step_name} streaming failed: {exc}")
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
content=f"Step {parallel_step_name} failed: {str(exc)}",
|
|
359
|
-
success=False,
|
|
360
|
-
error=str(exc),
|
|
361
|
-
)
|
|
362
|
-
],
|
|
433
|
+
error_event = StepOutput(
|
|
434
|
+
step_name=parallel_step_name,
|
|
435
|
+
content=f"Step {parallel_step_name} failed: {str(exc)}",
|
|
436
|
+
success=False,
|
|
437
|
+
error=str(exc),
|
|
363
438
|
)
|
|
439
|
+
event_queue.put(("event", idx, error_event))
|
|
440
|
+
event_queue.put(("complete", idx, [error_event], step_session_state))
|
|
441
|
+
return idx, [error_event], step_session_state
|
|
364
442
|
|
|
365
|
-
#
|
|
443
|
+
# Submit all parallel tasks
|
|
366
444
|
indexed_steps = list(enumerate(self.steps))
|
|
367
|
-
all_events_with_indices = []
|
|
368
|
-
step_results = []
|
|
369
445
|
|
|
370
446
|
with ThreadPoolExecutor(max_workers=len(self.steps)) as executor:
|
|
371
|
-
# Submit all tasks
|
|
372
|
-
|
|
373
|
-
executor.submit(execute_step_stream_with_index, indexed_step): indexed_step[0]
|
|
374
|
-
for indexed_step in indexed_steps
|
|
375
|
-
}
|
|
447
|
+
# Submit all tasks
|
|
448
|
+
futures = [executor.submit(execute_step_stream_with_index, indexed_step) for indexed_step in indexed_steps]
|
|
376
449
|
|
|
377
|
-
#
|
|
378
|
-
|
|
450
|
+
# Process events from queue as they arrive
|
|
451
|
+
completed_steps = 0
|
|
452
|
+
total_steps = len(self.steps)
|
|
453
|
+
|
|
454
|
+
while completed_steps < total_steps:
|
|
379
455
|
try:
|
|
380
|
-
|
|
381
|
-
all_events_with_indices.append((index, events))
|
|
456
|
+
message_type, step_idx, *data = event_queue.get(timeout=1.0)
|
|
382
457
|
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
458
|
+
if message_type == "event":
|
|
459
|
+
event = data[0]
|
|
460
|
+
# Yield events immediately as they arrive (except StepOutputs)
|
|
461
|
+
if not isinstance(event, StepOutput):
|
|
462
|
+
yield event
|
|
387
463
|
|
|
388
|
-
|
|
389
|
-
|
|
464
|
+
elif message_type == "complete":
|
|
465
|
+
step_outputs, step_session_state = data
|
|
466
|
+
step_results.extend(step_outputs)
|
|
467
|
+
modified_session_states.append(step_session_state)
|
|
468
|
+
completed_steps += 1
|
|
469
|
+
|
|
470
|
+
step_name = getattr(self.steps[step_idx], "name", f"step_{step_idx}")
|
|
471
|
+
log_debug(f"Parallel step {step_name} streaming completed")
|
|
472
|
+
|
|
473
|
+
except queue.Empty:
|
|
474
|
+
for i, future in enumerate(futures):
|
|
475
|
+
if future.done() and future.exception():
|
|
476
|
+
logger.error(f"Parallel step {i} failed: {future.exception()}")
|
|
477
|
+
if completed_steps < total_steps:
|
|
478
|
+
completed_steps += 1
|
|
390
479
|
except Exception as e:
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
logger.error(f"Parallel step {step_name} streaming failed: {e}")
|
|
394
|
-
error_event = StepOutput(
|
|
395
|
-
step_name=step_name,
|
|
396
|
-
content=f"Step {step_name} failed: {str(e)}",
|
|
397
|
-
success=False,
|
|
398
|
-
error=str(e),
|
|
399
|
-
)
|
|
400
|
-
all_events_with_indices.append((index, [error_event]))
|
|
401
|
-
step_results.append(error_event)
|
|
480
|
+
logger.error(f"Error processing parallel step events: {e}")
|
|
481
|
+
completed_steps += 1
|
|
402
482
|
|
|
403
|
-
|
|
404
|
-
|
|
483
|
+
for future in futures:
|
|
484
|
+
try:
|
|
485
|
+
future.result()
|
|
486
|
+
except Exception as e:
|
|
487
|
+
logger.error(f"Future completion error: {e}")
|
|
405
488
|
|
|
406
|
-
#
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
# Only yield non-StepOutput events during streaming to avoid duplication
|
|
410
|
-
if not isinstance(event, StepOutput):
|
|
411
|
-
yield event
|
|
489
|
+
# Merge all session_state changes back into the original session_state
|
|
490
|
+
if run_context is None and session_state is not None:
|
|
491
|
+
merge_parallel_session_states(session_state, modified_session_states)
|
|
412
492
|
|
|
413
493
|
# Flatten step_results - handle steps that return List[StepOutput] (like Condition/Loop)
|
|
414
494
|
flattened_step_results: List[StepOutput] = []
|
|
@@ -426,7 +506,7 @@ class Parallel:
|
|
|
426
506
|
|
|
427
507
|
log_debug(f"Parallel End: {self.name} ({len(self.steps)} steps)", center=True, symbol="=")
|
|
428
508
|
|
|
429
|
-
if
|
|
509
|
+
if stream_events and workflow_run_response:
|
|
430
510
|
# Yield parallel step completed event
|
|
431
511
|
yield ParallelExecutionCompletedEvent(
|
|
432
512
|
run_id=workflow_run_response.run_id or "",
|
|
@@ -436,7 +516,7 @@ class Parallel:
|
|
|
436
516
|
step_name=self.name,
|
|
437
517
|
step_index=step_index,
|
|
438
518
|
parallel_step_count=len(self.steps),
|
|
439
|
-
step_results=
|
|
519
|
+
step_results=flattened_step_results,
|
|
440
520
|
step_id=parallel_step_id,
|
|
441
521
|
parent_step_id=parent_step_id,
|
|
442
522
|
)
|
|
@@ -448,7 +528,11 @@ class Parallel:
|
|
|
448
528
|
user_id: Optional[str] = None,
|
|
449
529
|
workflow_run_response: Optional[WorkflowRunOutput] = None,
|
|
450
530
|
store_executor_outputs: bool = True,
|
|
531
|
+
run_context: Optional[RunContext] = None,
|
|
451
532
|
session_state: Optional[Dict[str, Any]] = None,
|
|
533
|
+
workflow_session: Optional[WorkflowSession] = None,
|
|
534
|
+
add_workflow_history_to_steps: Optional[bool] = False,
|
|
535
|
+
num_history_runs: int = 3,
|
|
452
536
|
) -> StepOutput:
|
|
453
537
|
"""Execute all steps in parallel using asyncio and return aggregated result"""
|
|
454
538
|
# Use workflow logger for async parallel orchestration
|
|
@@ -456,9 +540,24 @@ class Parallel:
|
|
|
456
540
|
|
|
457
541
|
self._prepare_steps()
|
|
458
542
|
|
|
543
|
+
# Create individual session_state copies for each step to prevent race conditions
|
|
544
|
+
session_state_copies = []
|
|
545
|
+
for _ in range(len(self.steps)):
|
|
546
|
+
# If using run context, no need to deepcopy the state. We want the direct reference.
|
|
547
|
+
if run_context is not None and run_context.session_state is not None:
|
|
548
|
+
session_state_copies.append(run_context.session_state)
|
|
549
|
+
else:
|
|
550
|
+
if session_state is not None:
|
|
551
|
+
session_state_copies.append(deepcopy(session_state))
|
|
552
|
+
else:
|
|
553
|
+
session_state_copies.append({})
|
|
554
|
+
|
|
459
555
|
async def execute_step_async_with_index(step_with_index):
|
|
460
556
|
"""Execute a single step asynchronously and preserve its original index"""
|
|
461
557
|
idx, step = step_with_index
|
|
558
|
+
# Use the individual session_state copy for this step
|
|
559
|
+
step_session_state = session_state_copies[idx]
|
|
560
|
+
|
|
462
561
|
try:
|
|
463
562
|
inner_step_result = await step.aexecute(
|
|
464
563
|
step_input,
|
|
@@ -466,9 +565,12 @@ class Parallel:
|
|
|
466
565
|
user_id=user_id,
|
|
467
566
|
workflow_run_response=workflow_run_response,
|
|
468
567
|
store_executor_outputs=store_executor_outputs,
|
|
469
|
-
|
|
568
|
+
workflow_session=workflow_session,
|
|
569
|
+
add_workflow_history_to_steps=add_workflow_history_to_steps,
|
|
570
|
+
num_history_runs=num_history_runs,
|
|
571
|
+
session_state=step_session_state,
|
|
470
572
|
) # type: ignore[union-attr]
|
|
471
|
-
return idx, inner_step_result
|
|
573
|
+
return idx, inner_step_result, step_session_state
|
|
472
574
|
except Exception as exc:
|
|
473
575
|
parallel_step_name = getattr(step, "name", f"step_{idx}")
|
|
474
576
|
logger.error(f"Parallel step {parallel_step_name} failed: {exc}")
|
|
@@ -480,6 +582,7 @@ class Parallel:
|
|
|
480
582
|
success=False,
|
|
481
583
|
error=str(exc),
|
|
482
584
|
),
|
|
585
|
+
step_session_state,
|
|
483
586
|
)
|
|
484
587
|
|
|
485
588
|
# Use index to preserve order
|
|
@@ -493,6 +596,7 @@ class Parallel:
|
|
|
493
596
|
|
|
494
597
|
# Process results and handle exceptions, preserving order
|
|
495
598
|
processed_results_with_indices = []
|
|
599
|
+
modified_session_states = []
|
|
496
600
|
for i, result in enumerate(results_with_indices):
|
|
497
601
|
if isinstance(result, Exception):
|
|
498
602
|
step_name = getattr(self.steps[i], "name", f"step_{i}")
|
|
@@ -508,12 +612,19 @@ class Parallel:
|
|
|
508
612
|
),
|
|
509
613
|
)
|
|
510
614
|
)
|
|
615
|
+
# Still collect the session state copy for failed steps
|
|
616
|
+
modified_session_states.append(session_state_copies[i])
|
|
511
617
|
else:
|
|
512
|
-
index, step_result = result # type: ignore[misc]
|
|
618
|
+
index, step_result, modified_session_state = result # type: ignore[misc]
|
|
513
619
|
processed_results_with_indices.append((index, step_result))
|
|
620
|
+
modified_session_states.append(modified_session_state)
|
|
514
621
|
step_name = getattr(self.steps[index], "name", f"step_{index}")
|
|
515
622
|
log_debug(f"Parallel step {step_name} completed")
|
|
516
623
|
|
|
624
|
+
# Smart merge all session_state changes back into the original session_state
|
|
625
|
+
if run_context is None and session_state is not None:
|
|
626
|
+
merge_parallel_session_states(session_state, modified_session_states)
|
|
627
|
+
|
|
517
628
|
# Sort by original index to preserve order
|
|
518
629
|
processed_results_with_indices.sort(key=lambda x: x[0])
|
|
519
630
|
results = [result for _, result in processed_results_with_indices]
|
|
@@ -539,12 +650,18 @@ class Parallel:
|
|
|
539
650
|
step_input: StepInput,
|
|
540
651
|
session_id: Optional[str] = None,
|
|
541
652
|
user_id: Optional[str] = None,
|
|
653
|
+
stream_events: bool = False,
|
|
542
654
|
stream_intermediate_steps: bool = False,
|
|
655
|
+
stream_executor_events: bool = True,
|
|
543
656
|
workflow_run_response: Optional[WorkflowRunOutput] = None,
|
|
544
657
|
step_index: Optional[Union[int, tuple]] = None,
|
|
545
658
|
store_executor_outputs: bool = True,
|
|
659
|
+
run_context: Optional[RunContext] = None,
|
|
546
660
|
session_state: Optional[Dict[str, Any]] = None,
|
|
547
661
|
parent_step_id: Optional[str] = None,
|
|
662
|
+
workflow_session: Optional[WorkflowSession] = None,
|
|
663
|
+
add_workflow_history_to_steps: Optional[bool] = False,
|
|
664
|
+
num_history_runs: int = 3,
|
|
548
665
|
) -> AsyncIterator[Union[WorkflowRunOutputEvent, TeamRunOutputEvent, RunOutputEvent, StepOutput]]:
|
|
549
666
|
"""Execute all steps in parallel with async streaming support"""
|
|
550
667
|
log_debug(f"Parallel Start: {self.name} ({len(self.steps)} steps)", center=True, symbol="=")
|
|
@@ -553,7 +670,28 @@ class Parallel:
|
|
|
553
670
|
|
|
554
671
|
self._prepare_steps()
|
|
555
672
|
|
|
556
|
-
|
|
673
|
+
# Create individual session_state copies for each step to prevent race conditions
|
|
674
|
+
session_state_copies = []
|
|
675
|
+
for _ in range(len(self.steps)):
|
|
676
|
+
# If using run context, no need to deepcopy the state. We want the direct reference.
|
|
677
|
+
if run_context is not None and run_context.session_state is not None:
|
|
678
|
+
session_state_copies.append(run_context.session_state)
|
|
679
|
+
else:
|
|
680
|
+
if session_state is not None:
|
|
681
|
+
session_state_copies.append(deepcopy(session_state))
|
|
682
|
+
else:
|
|
683
|
+
session_state_copies.append({})
|
|
684
|
+
|
|
685
|
+
# Considering both stream_events and stream_intermediate_steps (deprecated)
|
|
686
|
+
if stream_intermediate_steps is not None:
|
|
687
|
+
warnings.warn(
|
|
688
|
+
"The 'stream_intermediate_steps' parameter is deprecated and will be removed in future versions. Use 'stream_events' instead.",
|
|
689
|
+
DeprecationWarning,
|
|
690
|
+
stacklevel=2,
|
|
691
|
+
)
|
|
692
|
+
stream_events = stream_events or stream_intermediate_steps
|
|
693
|
+
|
|
694
|
+
if stream_events and workflow_run_response:
|
|
557
695
|
# Yield parallel step started event
|
|
558
696
|
yield ParallelExecutionStartedEvent(
|
|
559
697
|
run_id=workflow_run_response.run_id or "",
|
|
@@ -567,11 +705,20 @@ class Parallel:
|
|
|
567
705
|
parent_step_id=parent_step_id,
|
|
568
706
|
)
|
|
569
707
|
|
|
708
|
+
import asyncio
|
|
709
|
+
|
|
710
|
+
event_queue = asyncio.Queue() # type: ignore
|
|
711
|
+
step_results = []
|
|
712
|
+
modified_session_states = []
|
|
713
|
+
|
|
570
714
|
async def execute_step_stream_async_with_index(step_with_index):
|
|
571
|
-
"""Execute a single step with async streaming and
|
|
715
|
+
"""Execute a single step with async streaming and yield events immediately"""
|
|
572
716
|
idx, step = step_with_index
|
|
717
|
+
# Use the individual session_state copy for this step
|
|
718
|
+
step_session_state = session_state_copies[idx]
|
|
719
|
+
|
|
573
720
|
try:
|
|
574
|
-
|
|
721
|
+
step_outputs = []
|
|
575
722
|
|
|
576
723
|
# If step_index is None or integer (main step): create (step_index, sub_index)
|
|
577
724
|
# If step_index is tuple (child step): all parallel sub-steps get same index
|
|
@@ -587,75 +734,76 @@ class Parallel:
|
|
|
587
734
|
step_input,
|
|
588
735
|
session_id=session_id,
|
|
589
736
|
user_id=user_id,
|
|
590
|
-
|
|
737
|
+
stream_events=stream_events,
|
|
738
|
+
stream_executor_events=stream_executor_events,
|
|
591
739
|
workflow_run_response=workflow_run_response,
|
|
592
740
|
step_index=sub_step_index,
|
|
593
741
|
store_executor_outputs=store_executor_outputs,
|
|
594
|
-
session_state=
|
|
742
|
+
session_state=step_session_state,
|
|
743
|
+
run_context=run_context,
|
|
595
744
|
parent_step_id=parallel_step_id,
|
|
745
|
+
workflow_session=workflow_session,
|
|
746
|
+
add_workflow_history_to_steps=add_workflow_history_to_steps,
|
|
747
|
+
num_history_runs=num_history_runs,
|
|
596
748
|
): # type: ignore[union-attr]
|
|
597
|
-
|
|
598
|
-
|
|
749
|
+
# Yield events immediately to the queue
|
|
750
|
+
await event_queue.put(("event", idx, event))
|
|
751
|
+
if isinstance(event, StepOutput):
|
|
752
|
+
step_outputs.append(event)
|
|
753
|
+
|
|
754
|
+
# Signal completion for this step
|
|
755
|
+
await event_queue.put(("complete", idx, step_outputs, step_session_state))
|
|
756
|
+
return idx, step_outputs, step_session_state
|
|
599
757
|
except Exception as e:
|
|
600
758
|
parallel_step_name = getattr(step, "name", f"step_{idx}")
|
|
601
759
|
logger.error(f"Parallel step {parallel_step_name} async streaming failed: {e}")
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
content=f"Step {parallel_step_name} failed: {str(e)}",
|
|
608
|
-
success=False,
|
|
609
|
-
error=str(e),
|
|
610
|
-
)
|
|
611
|
-
],
|
|
760
|
+
error_event = StepOutput(
|
|
761
|
+
step_name=parallel_step_name,
|
|
762
|
+
content=f"Step {parallel_step_name} failed: {str(e)}",
|
|
763
|
+
success=False,
|
|
764
|
+
error=str(e),
|
|
612
765
|
)
|
|
766
|
+
await event_queue.put(("event", idx, error_event))
|
|
767
|
+
await event_queue.put(("complete", idx, [error_event], step_session_state))
|
|
768
|
+
return idx, [error_event], step_session_state
|
|
613
769
|
|
|
614
|
-
#
|
|
770
|
+
# Start all parallel tasks
|
|
615
771
|
indexed_steps = list(enumerate(self.steps))
|
|
616
|
-
|
|
617
|
-
|
|
772
|
+
tasks = [
|
|
773
|
+
asyncio.create_task(execute_step_stream_async_with_index(indexed_step)) for indexed_step in indexed_steps
|
|
774
|
+
]
|
|
618
775
|
|
|
619
|
-
#
|
|
620
|
-
|
|
776
|
+
# Process events as they arrive and track completion
|
|
777
|
+
completed_steps = 0
|
|
778
|
+
total_steps = len(self.steps)
|
|
621
779
|
|
|
622
|
-
|
|
623
|
-
|
|
780
|
+
while completed_steps < total_steps:
|
|
781
|
+
try:
|
|
782
|
+
message_type, step_idx, *data = await event_queue.get()
|
|
624
783
|
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
logger.error(f"Parallel step {step_name} async streaming failed: {result}")
|
|
630
|
-
error_event = StepOutput(
|
|
631
|
-
step_name=step_name,
|
|
632
|
-
content=f"Step {step_name} failed: {str(result)}",
|
|
633
|
-
success=False,
|
|
634
|
-
error=str(result),
|
|
635
|
-
)
|
|
636
|
-
all_events_with_indices.append((i, [error_event]))
|
|
637
|
-
step_results.append(error_event)
|
|
638
|
-
else:
|
|
639
|
-
index, events = result # type: ignore[misc]
|
|
640
|
-
all_events_with_indices.append((index, events))
|
|
784
|
+
if message_type == "event":
|
|
785
|
+
event = data[0]
|
|
786
|
+
if not isinstance(event, StepOutput):
|
|
787
|
+
yield event
|
|
641
788
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
if step_outputs:
|
|
789
|
+
elif message_type == "complete":
|
|
790
|
+
step_outputs, step_session_state = data
|
|
645
791
|
step_results.extend(step_outputs)
|
|
792
|
+
modified_session_states.append(step_session_state)
|
|
793
|
+
completed_steps += 1
|
|
646
794
|
|
|
647
|
-
|
|
648
|
-
|
|
795
|
+
step_name = getattr(self.steps[step_idx], "name", f"step_{step_idx}")
|
|
796
|
+
log_debug(f"Parallel step {step_name} async streaming completed")
|
|
797
|
+
|
|
798
|
+
except Exception as e:
|
|
799
|
+
logger.error(f"Error processing parallel step events: {e}")
|
|
800
|
+
completed_steps += 1
|
|
649
801
|
|
|
650
|
-
|
|
651
|
-
all_events_with_indices.sort(key=lambda x: x[0])
|
|
802
|
+
await asyncio.gather(*tasks, return_exceptions=True)
|
|
652
803
|
|
|
653
|
-
#
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
# Only yield non-StepOutput events during streaming to avoid duplication
|
|
657
|
-
if not isinstance(event, StepOutput):
|
|
658
|
-
yield event
|
|
804
|
+
# Merge all session_state changes back into the original session_state
|
|
805
|
+
if run_context is None and session_state is not None:
|
|
806
|
+
merge_parallel_session_states(session_state, modified_session_states)
|
|
659
807
|
|
|
660
808
|
# Flatten step_results - handle steps that return List[StepOutput] (like Condition/Loop)
|
|
661
809
|
flattened_step_results: List[StepOutput] = []
|
|
@@ -673,7 +821,7 @@ class Parallel:
|
|
|
673
821
|
|
|
674
822
|
log_debug(f"Parallel End: {self.name} ({len(self.steps)} steps)", center=True, symbol="=")
|
|
675
823
|
|
|
676
|
-
if
|
|
824
|
+
if stream_events and workflow_run_response:
|
|
677
825
|
# Yield parallel step completed event
|
|
678
826
|
yield ParallelExecutionCompletedEvent(
|
|
679
827
|
run_id=workflow_run_response.run_id or "",
|
|
@@ -683,7 +831,7 @@ class Parallel:
|
|
|
683
831
|
step_name=self.name,
|
|
684
832
|
step_index=step_index,
|
|
685
833
|
parallel_step_count=len(self.steps),
|
|
686
|
-
step_results=
|
|
834
|
+
step_results=flattened_step_results,
|
|
687
835
|
step_id=parallel_step_id,
|
|
688
836
|
parent_step_id=parent_step_id,
|
|
689
837
|
)
|