agno 2.1.1__py3-none-any.whl → 2.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +12 -0
- agno/db/base.py +8 -4
- agno/db/dynamo/dynamo.py +69 -17
- agno/db/firestore/firestore.py +68 -29
- agno/db/gcs_json/gcs_json_db.py +68 -17
- agno/db/in_memory/in_memory_db.py +83 -14
- agno/db/json/json_db.py +79 -15
- agno/db/mongo/mongo.py +27 -8
- agno/db/mysql/mysql.py +17 -3
- agno/db/postgres/postgres.py +21 -3
- agno/db/redis/redis.py +38 -11
- agno/db/singlestore/singlestore.py +14 -3
- agno/db/sqlite/sqlite.py +34 -46
- agno/knowledge/reader/field_labeled_csv_reader.py +294 -0
- agno/knowledge/reader/pdf_reader.py +28 -52
- agno/knowledge/reader/reader_factory.py +12 -0
- agno/memory/manager.py +12 -4
- agno/models/anthropic/claude.py +4 -1
- agno/models/aws/bedrock.py +52 -112
- agno/os/app.py +24 -30
- agno/os/interfaces/a2a/__init__.py +3 -0
- agno/os/interfaces/a2a/a2a.py +42 -0
- agno/os/interfaces/a2a/router.py +252 -0
- agno/os/interfaces/a2a/utils.py +924 -0
- agno/os/interfaces/agui/router.py +12 -0
- agno/os/router.py +38 -8
- agno/os/routers/memory/memory.py +5 -3
- agno/os/routers/memory/schemas.py +1 -0
- agno/os/utils.py +36 -10
- agno/team/team.py +12 -0
- agno/tools/mcp.py +46 -1
- agno/utils/merge_dict.py +22 -1
- agno/utils/streamlit.py +1 -1
- agno/workflow/parallel.py +90 -14
- agno/workflow/step.py +30 -27
- agno/workflow/workflow.py +5 -3
- {agno-2.1.1.dist-info → agno-2.1.2.dist-info}/METADATA +16 -14
- {agno-2.1.1.dist-info → agno-2.1.2.dist-info}/RECORD +41 -36
- {agno-2.1.1.dist-info → agno-2.1.2.dist-info}/WHEEL +0 -0
- {agno-2.1.1.dist-info → agno-2.1.2.dist-info}/licenses/LICENSE +0 -0
- {agno-2.1.1.dist-info → agno-2.1.2.dist-info}/top_level.txt +0 -0
|
@@ -34,12 +34,18 @@ async def run_agent(agent: Agent, run_input: RunAgentInput) -> AsyncIterator[Bas
|
|
|
34
34
|
messages = convert_agui_messages_to_agno_messages(run_input.messages or [])
|
|
35
35
|
yield RunStartedEvent(type=EventType.RUN_STARTED, thread_id=run_input.thread_id, run_id=run_id)
|
|
36
36
|
|
|
37
|
+
# Look for user_id in run_input.forwarded_props
|
|
38
|
+
user_id = None
|
|
39
|
+
if run_input.forwarded_props and isinstance(run_input.forwarded_props, dict):
|
|
40
|
+
user_id = run_input.forwarded_props.get("user_id")
|
|
41
|
+
|
|
37
42
|
# Request streaming response from agent
|
|
38
43
|
response_stream = agent.arun(
|
|
39
44
|
input=messages,
|
|
40
45
|
session_id=run_input.thread_id,
|
|
41
46
|
stream=True,
|
|
42
47
|
stream_intermediate_steps=True,
|
|
48
|
+
user_id=user_id,
|
|
43
49
|
)
|
|
44
50
|
|
|
45
51
|
# Stream the response content in AG-UI format
|
|
@@ -64,12 +70,18 @@ async def run_team(team: Team, input: RunAgentInput) -> AsyncIterator[BaseEvent]
|
|
|
64
70
|
messages = convert_agui_messages_to_agno_messages(input.messages or [])
|
|
65
71
|
yield RunStartedEvent(type=EventType.RUN_STARTED, thread_id=input.thread_id, run_id=run_id)
|
|
66
72
|
|
|
73
|
+
# Look for user_id in input.forwarded_props
|
|
74
|
+
user_id = None
|
|
75
|
+
if input.forwarded_props and isinstance(input.forwarded_props, dict):
|
|
76
|
+
user_id = input.forwarded_props.get("user_id")
|
|
77
|
+
|
|
67
78
|
# Request streaming response from team
|
|
68
79
|
response_stream = team.arun(
|
|
69
80
|
input=messages,
|
|
70
81
|
session_id=input.thread_id,
|
|
71
82
|
stream=True,
|
|
72
83
|
stream_intermediate_steps=True,
|
|
84
|
+
user_id=user_id,
|
|
73
85
|
)
|
|
74
86
|
|
|
75
87
|
# Stream the response content in AG-UI format
|
agno/os/router.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import json
|
|
2
|
+
from itertools import chain
|
|
2
3
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Union, cast
|
|
3
4
|
from uuid import uuid4
|
|
4
5
|
|
|
@@ -643,7 +644,7 @@ def get_base_router(
|
|
|
643
644
|
os_id=os.id or "Unnamed OS",
|
|
644
645
|
description=os.description,
|
|
645
646
|
available_models=os.config.available_models if os.config else [],
|
|
646
|
-
databases=
|
|
647
|
+
databases=list({db.id for db in chain(os.dbs.values(), os.knowledge_dbs.values())}),
|
|
647
648
|
chat=os.config.chat if os.config else None,
|
|
648
649
|
session=os._get_session_config(),
|
|
649
650
|
memory=os._get_memory_config(),
|
|
@@ -784,19 +785,39 @@ def get_base_router(
|
|
|
784
785
|
|
|
785
786
|
if files:
|
|
786
787
|
for file in files:
|
|
787
|
-
if file.content_type in [
|
|
788
|
+
if file.content_type in [
|
|
789
|
+
"image/png",
|
|
790
|
+
"image/jpeg",
|
|
791
|
+
"image/jpg",
|
|
792
|
+
"image/gif",
|
|
793
|
+
"image/webp",
|
|
794
|
+
"image/bmp",
|
|
795
|
+
"image/tiff",
|
|
796
|
+
"image/tif",
|
|
797
|
+
"image/avif",
|
|
798
|
+
]:
|
|
788
799
|
try:
|
|
789
800
|
base64_image = process_image(file)
|
|
790
801
|
base64_images.append(base64_image)
|
|
791
802
|
except Exception as e:
|
|
792
803
|
log_error(f"Error processing image {file.filename}: {e}")
|
|
793
804
|
continue
|
|
794
|
-
elif file.content_type in [
|
|
805
|
+
elif file.content_type in [
|
|
806
|
+
"audio/wav",
|
|
807
|
+
"audio/wave",
|
|
808
|
+
"audio/mp3",
|
|
809
|
+
"audio/mpeg",
|
|
810
|
+
"audio/ogg",
|
|
811
|
+
"audio/mp4",
|
|
812
|
+
"audio/m4a",
|
|
813
|
+
"audio/aac",
|
|
814
|
+
"audio/flac",
|
|
815
|
+
]:
|
|
795
816
|
try:
|
|
796
|
-
|
|
797
|
-
base64_audios.append(
|
|
817
|
+
audio = process_audio(file)
|
|
818
|
+
base64_audios.append(audio)
|
|
798
819
|
except Exception as e:
|
|
799
|
-
log_error(f"Error processing audio {file.filename}: {e}")
|
|
820
|
+
log_error(f"Error processing audio {file.filename} with content type {file.content_type}: {e}")
|
|
800
821
|
continue
|
|
801
822
|
elif file.content_type in [
|
|
802
823
|
"video/x-flv",
|
|
@@ -819,10 +840,19 @@ def get_base_router(
|
|
|
819
840
|
continue
|
|
820
841
|
elif file.content_type in [
|
|
821
842
|
"application/pdf",
|
|
822
|
-
"
|
|
843
|
+
"application/json",
|
|
844
|
+
"application/x-javascript",
|
|
823
845
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
|
846
|
+
"text/javascript",
|
|
847
|
+
"application/x-python",
|
|
848
|
+
"text/x-python",
|
|
824
849
|
"text/plain",
|
|
825
|
-
"
|
|
850
|
+
"text/html",
|
|
851
|
+
"text/css",
|
|
852
|
+
"text/md",
|
|
853
|
+
"text/csv",
|
|
854
|
+
"text/xml",
|
|
855
|
+
"text/rtf",
|
|
826
856
|
]:
|
|
827
857
|
# Process document files
|
|
828
858
|
try:
|
agno/os/routers/memory/memory.py
CHANGED
|
@@ -120,10 +120,11 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
|
|
|
120
120
|
)
|
|
121
121
|
async def delete_memory(
|
|
122
122
|
memory_id: str = Path(description="Memory ID to delete"),
|
|
123
|
+
user_id: Optional[str] = Query(default=None, description="User ID to delete memory for"),
|
|
123
124
|
db_id: Optional[str] = Query(default=None, description="Database ID to use for deletion"),
|
|
124
125
|
) -> None:
|
|
125
126
|
db = get_db(dbs, db_id)
|
|
126
|
-
db.delete_user_memory(memory_id=memory_id)
|
|
127
|
+
db.delete_user_memory(memory_id=memory_id, user_id=user_id)
|
|
127
128
|
|
|
128
129
|
@router.delete(
|
|
129
130
|
"/memories",
|
|
@@ -145,7 +146,7 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
|
|
|
145
146
|
db_id: Optional[str] = Query(default=None, description="Database ID to use for deletion"),
|
|
146
147
|
) -> None:
|
|
147
148
|
db = get_db(dbs, db_id)
|
|
148
|
-
db.delete_user_memories(memory_ids=request.memory_ids)
|
|
149
|
+
db.delete_user_memories(memory_ids=request.memory_ids, user_id=request.user_id)
|
|
149
150
|
|
|
150
151
|
@router.get(
|
|
151
152
|
"/memories",
|
|
@@ -249,10 +250,11 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
|
|
|
249
250
|
)
|
|
250
251
|
async def get_memory(
|
|
251
252
|
memory_id: str = Path(description="Memory ID to retrieve"),
|
|
253
|
+
user_id: Optional[str] = Query(default=None, description="User ID to query memory for"),
|
|
252
254
|
db_id: Optional[str] = Query(default=None, description="Database ID to query memory from"),
|
|
253
255
|
) -> UserMemorySchema:
|
|
254
256
|
db = get_db(dbs, db_id)
|
|
255
|
-
user_memory = db.get_user_memory(memory_id=memory_id, deserialize=False)
|
|
257
|
+
user_memory = db.get_user_memory(memory_id=memory_id, user_id=user_id, deserialize=False)
|
|
256
258
|
if not user_memory:
|
|
257
259
|
raise HTTPException(status_code=404, detail=f"Memory with ID {memory_id} not found")
|
|
258
260
|
|
agno/os/utils.py
CHANGED
|
@@ -93,18 +93,40 @@ def get_session_name(session: Dict[str, Any]) -> str:
|
|
|
93
93
|
|
|
94
94
|
# For teams, identify the first Team run and avoid using the first member's run
|
|
95
95
|
if session.get("session_type") == "team":
|
|
96
|
-
run =
|
|
96
|
+
run = None
|
|
97
|
+
for r in runs:
|
|
98
|
+
# If agent_id is not present, it's a team run
|
|
99
|
+
if not r.get("agent_id"):
|
|
100
|
+
run = r
|
|
101
|
+
break
|
|
102
|
+
# Fallback to first run if no team run found
|
|
103
|
+
if run is None and runs:
|
|
104
|
+
run = runs[0]
|
|
97
105
|
|
|
98
|
-
# For workflows, pass along the first step_executor_run
|
|
99
106
|
elif session.get("session_type") == "workflow":
|
|
100
107
|
try:
|
|
101
|
-
|
|
108
|
+
workflow_run = runs[0]
|
|
109
|
+
workflow_input = workflow_run.get("input")
|
|
110
|
+
if isinstance(workflow_input, str):
|
|
111
|
+
return workflow_input
|
|
112
|
+
elif isinstance(workflow_input, dict):
|
|
113
|
+
try:
|
|
114
|
+
import json
|
|
115
|
+
return json.dumps(workflow_input)
|
|
116
|
+
except (TypeError, ValueError):
|
|
117
|
+
pass
|
|
118
|
+
|
|
119
|
+
workflow_name = session.get("workflow_data", {}).get("name")
|
|
120
|
+
return f"New {workflow_name} Session" if workflow_name else ""
|
|
102
121
|
except (KeyError, IndexError, TypeError):
|
|
103
122
|
return ""
|
|
104
123
|
|
|
105
124
|
# For agents, use the first run
|
|
106
125
|
else:
|
|
107
|
-
run = runs[0]
|
|
126
|
+
run = runs[0] if runs else None
|
|
127
|
+
|
|
128
|
+
if run is None:
|
|
129
|
+
return ""
|
|
108
130
|
|
|
109
131
|
if not isinstance(run, dict):
|
|
110
132
|
run = run.to_dict()
|
|
@@ -150,13 +172,17 @@ def process_document(file: UploadFile) -> Optional[FileMedia]:
|
|
|
150
172
|
return None
|
|
151
173
|
|
|
152
174
|
|
|
153
|
-
def extract_format(file: UploadFile):
|
|
154
|
-
|
|
175
|
+
def extract_format(file: UploadFile) -> Optional[str]:
|
|
176
|
+
"""Extract the File format from file name or content_type."""
|
|
177
|
+
# Get the format from the filename
|
|
155
178
|
if file.filename and "." in file.filename:
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
179
|
+
return file.filename.split(".")[-1].lower()
|
|
180
|
+
|
|
181
|
+
# Fallback to the file content_type
|
|
182
|
+
if file.content_type:
|
|
183
|
+
return file.content_type.strip().split("/")[-1]
|
|
184
|
+
|
|
185
|
+
return None
|
|
160
186
|
|
|
161
187
|
|
|
162
188
|
def format_tools(agent_tools: List[Union[Dict[str, Any], Toolkit, Function, Callable]]):
|
agno/team/team.py
CHANGED
|
@@ -4035,6 +4035,12 @@ class Team:
|
|
|
4035
4035
|
log_warning("Reasoning error. Reasoning response is empty, continuing regular session...")
|
|
4036
4036
|
break
|
|
4037
4037
|
|
|
4038
|
+
if isinstance(reasoning_agent_response.content, str):
|
|
4039
|
+
log_warning(
|
|
4040
|
+
"Reasoning error. Content is a string, not structured output. Continuing regular session..."
|
|
4041
|
+
)
|
|
4042
|
+
break
|
|
4043
|
+
|
|
4038
4044
|
if reasoning_agent_response.content.reasoning_steps is None:
|
|
4039
4045
|
log_warning("Reasoning error. Reasoning steps are empty, continuing regular session...")
|
|
4040
4046
|
break
|
|
@@ -4261,6 +4267,12 @@ class Team:
|
|
|
4261
4267
|
log_warning("Reasoning error. Reasoning response is empty, continuing regular session...")
|
|
4262
4268
|
break
|
|
4263
4269
|
|
|
4270
|
+
if isinstance(reasoning_agent_response.content, str):
|
|
4271
|
+
log_warning(
|
|
4272
|
+
"Reasoning error. Content is a string, not structured output. Continuing regular session..."
|
|
4273
|
+
)
|
|
4274
|
+
break
|
|
4275
|
+
|
|
4264
4276
|
if reasoning_agent_response.content.reasoning_steps is None:
|
|
4265
4277
|
log_warning("Reasoning error. Reasoning steps are empty, continuing regular session...")
|
|
4266
4278
|
break
|
agno/tools/mcp.py
CHANGED
|
@@ -22,6 +22,8 @@ except (ImportError, ModuleNotFoundError):
|
|
|
22
22
|
|
|
23
23
|
def _prepare_command(command: str) -> list[str]:
|
|
24
24
|
"""Sanitize a command and split it into parts before using it to run a MCP server."""
|
|
25
|
+
import os
|
|
26
|
+
import shutil
|
|
25
27
|
from shlex import split
|
|
26
28
|
|
|
27
29
|
# Block dangerous characters
|
|
@@ -55,10 +57,53 @@ def _prepare_command(command: str) -> list[str]:
|
|
|
55
57
|
}
|
|
56
58
|
|
|
57
59
|
executable = parts[0].split("/")[-1]
|
|
60
|
+
|
|
61
|
+
# Check if it's a relative path starting with ./ or ../
|
|
62
|
+
if executable.startswith("./") or executable.startswith("../"):
|
|
63
|
+
# Allow relative paths to binaries
|
|
64
|
+
return parts
|
|
65
|
+
|
|
66
|
+
# Check if it's an absolute path to a binary
|
|
67
|
+
if executable.startswith("/") and os.path.isfile(executable):
|
|
68
|
+
# Allow absolute paths to existing files
|
|
69
|
+
return parts
|
|
70
|
+
|
|
71
|
+
# Check if it's a binary in current directory without ./
|
|
72
|
+
if "/" not in executable and os.path.isfile(executable):
|
|
73
|
+
# Allow binaries in current directory
|
|
74
|
+
return parts
|
|
75
|
+
|
|
76
|
+
# Check if it's a binary in PATH
|
|
77
|
+
if shutil.which(executable):
|
|
78
|
+
return parts
|
|
79
|
+
|
|
58
80
|
if executable not in ALLOWED_COMMANDS:
|
|
59
81
|
raise ValueError(f"MCP command needs to use one of the following executables: {ALLOWED_COMMANDS}")
|
|
60
82
|
|
|
61
|
-
|
|
83
|
+
first_part = parts[0]
|
|
84
|
+
executable = first_part.split("/")[-1]
|
|
85
|
+
|
|
86
|
+
# Allow known commands
|
|
87
|
+
if executable in ALLOWED_COMMANDS:
|
|
88
|
+
return parts
|
|
89
|
+
|
|
90
|
+
# Allow relative paths to custom binaries
|
|
91
|
+
if first_part.startswith(("./", "../")):
|
|
92
|
+
return parts
|
|
93
|
+
|
|
94
|
+
# Allow absolute paths to existing files
|
|
95
|
+
if first_part.startswith("/") and os.path.isfile(first_part):
|
|
96
|
+
return parts
|
|
97
|
+
|
|
98
|
+
# Allow binaries in current directory without ./
|
|
99
|
+
if "/" not in first_part and os.path.isfile(first_part):
|
|
100
|
+
return parts
|
|
101
|
+
|
|
102
|
+
# Allow binaries in PATH
|
|
103
|
+
if shutil.which(first_part):
|
|
104
|
+
return parts
|
|
105
|
+
|
|
106
|
+
raise ValueError(f"MCP command needs to use one of the following executables: {ALLOWED_COMMANDS}")
|
|
62
107
|
|
|
63
108
|
|
|
64
109
|
@dataclass
|
agno/utils/merge_dict.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Dict
|
|
1
|
+
from typing import Any, Dict, List
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
def merge_dictionaries(a: Dict[str, Any], b: Dict[str, Any]) -> None:
|
|
@@ -18,3 +18,24 @@ def merge_dictionaries(a: Dict[str, Any], b: Dict[str, Any]) -> None:
|
|
|
18
18
|
merge_dictionaries(a[key], b[key])
|
|
19
19
|
else:
|
|
20
20
|
a[key] = b[key]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def merge_parallel_session_states(original_state: Dict[str, Any], modified_states: List[Dict[str, Any]]) -> None:
|
|
24
|
+
"""
|
|
25
|
+
Smart merge for parallel session states that only applies actual changes.
|
|
26
|
+
This prevents parallel steps from overwriting each other's changes.
|
|
27
|
+
"""
|
|
28
|
+
if not original_state or not modified_states:
|
|
29
|
+
return
|
|
30
|
+
|
|
31
|
+
# Collect all actual changes (keys where value differs from original)
|
|
32
|
+
all_changes = {}
|
|
33
|
+
for modified_state in modified_states:
|
|
34
|
+
if modified_state:
|
|
35
|
+
for key, value in modified_state.items():
|
|
36
|
+
if key not in original_state or original_state[key] != value:
|
|
37
|
+
all_changes[key] = value
|
|
38
|
+
|
|
39
|
+
# Apply all collected changes to the original state
|
|
40
|
+
for key, value in all_changes.items():
|
|
41
|
+
original_state[key] = value
|
agno/utils/streamlit.py
CHANGED
agno/workflow/parallel.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
3
|
+
from copy import deepcopy
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Optional, Union
|
|
5
6
|
from uuid import uuid4
|
|
@@ -13,6 +14,7 @@ from agno.run.workflow import (
|
|
|
13
14
|
WorkflowRunOutput,
|
|
14
15
|
WorkflowRunOutputEvent,
|
|
15
16
|
)
|
|
17
|
+
from agno.utils.merge_dict import merge_parallel_session_states
|
|
16
18
|
from agno.utils.log import log_debug, logger
|
|
17
19
|
from agno.workflow.condition import Condition
|
|
18
20
|
from agno.workflow.step import Step
|
|
@@ -205,9 +207,20 @@ class Parallel:
|
|
|
205
207
|
|
|
206
208
|
self._prepare_steps()
|
|
207
209
|
|
|
210
|
+
# Create individual session_state copies for each step to prevent race conditions
|
|
211
|
+
session_state_copies = []
|
|
212
|
+
for _ in range(len(self.steps)):
|
|
213
|
+
if session_state is not None:
|
|
214
|
+
session_state_copies.append(deepcopy(session_state))
|
|
215
|
+
else:
|
|
216
|
+
session_state_copies.append({})
|
|
217
|
+
|
|
208
218
|
def execute_step_with_index(step_with_index):
|
|
209
219
|
"""Execute a single step and preserve its original index"""
|
|
210
220
|
idx, step = step_with_index
|
|
221
|
+
# Use the individual session_state copy for this step
|
|
222
|
+
step_session_state = session_state_copies[idx]
|
|
223
|
+
|
|
211
224
|
try:
|
|
212
225
|
step_result = step.execute(
|
|
213
226
|
step_input,
|
|
@@ -215,9 +228,9 @@ class Parallel:
|
|
|
215
228
|
user_id=user_id,
|
|
216
229
|
workflow_run_response=workflow_run_response,
|
|
217
230
|
store_executor_outputs=store_executor_outputs,
|
|
218
|
-
session_state=
|
|
231
|
+
session_state=step_session_state,
|
|
219
232
|
) # type: ignore[union-attr]
|
|
220
|
-
return idx, step_result
|
|
233
|
+
return idx, step_result, step_session_state
|
|
221
234
|
except Exception as exc:
|
|
222
235
|
parallel_step_name = getattr(step, "name", f"step_{idx}")
|
|
223
236
|
logger.error(f"Parallel step {parallel_step_name} failed: {exc}")
|
|
@@ -229,6 +242,7 @@ class Parallel:
|
|
|
229
242
|
success=False,
|
|
230
243
|
error=str(exc),
|
|
231
244
|
),
|
|
245
|
+
step_session_state,
|
|
232
246
|
)
|
|
233
247
|
|
|
234
248
|
# Use index to preserve order
|
|
@@ -241,12 +255,14 @@ class Parallel:
|
|
|
241
255
|
for indexed_step in indexed_steps
|
|
242
256
|
}
|
|
243
257
|
|
|
244
|
-
# Collect results
|
|
258
|
+
# Collect results and modified session_state copies
|
|
245
259
|
results_with_indices = []
|
|
260
|
+
modified_session_states = []
|
|
246
261
|
for future in as_completed(future_to_index):
|
|
247
262
|
try:
|
|
248
|
-
index, result = future.result()
|
|
263
|
+
index, result, modified_session_state = future.result()
|
|
249
264
|
results_with_indices.append((index, result))
|
|
265
|
+
modified_session_states.append(modified_session_state)
|
|
250
266
|
step_name = getattr(self.steps[index], "name", f"step_{index}")
|
|
251
267
|
log_debug(f"Parallel step {step_name} completed")
|
|
252
268
|
except Exception as e:
|
|
@@ -265,6 +281,9 @@ class Parallel:
|
|
|
265
281
|
)
|
|
266
282
|
)
|
|
267
283
|
|
|
284
|
+
if session_state is not None:
|
|
285
|
+
merge_parallel_session_states(session_state, modified_session_states)
|
|
286
|
+
|
|
268
287
|
# Sort by original index to preserve order
|
|
269
288
|
results_with_indices.sort(key=lambda x: x[0])
|
|
270
289
|
results = [result for _, result in results_with_indices]
|
|
@@ -304,6 +323,14 @@ class Parallel:
|
|
|
304
323
|
|
|
305
324
|
self._prepare_steps()
|
|
306
325
|
|
|
326
|
+
# Create individual session_state copies for each step to prevent race conditions
|
|
327
|
+
session_state_copies = []
|
|
328
|
+
for _ in range(len(self.steps)):
|
|
329
|
+
if session_state is not None:
|
|
330
|
+
session_state_copies.append(deepcopy(session_state))
|
|
331
|
+
else:
|
|
332
|
+
session_state_copies.append({})
|
|
333
|
+
|
|
307
334
|
if stream_intermediate_steps and workflow_run_response:
|
|
308
335
|
# Yield parallel step started event
|
|
309
336
|
yield ParallelExecutionStartedEvent(
|
|
@@ -321,6 +348,9 @@ class Parallel:
|
|
|
321
348
|
def execute_step_stream_with_index(step_with_index):
|
|
322
349
|
"""Execute a single step with streaming and preserve its original index"""
|
|
323
350
|
idx, step = step_with_index
|
|
351
|
+
# Use the individual session_state copy for this step
|
|
352
|
+
step_session_state = session_state_copies[idx]
|
|
353
|
+
|
|
324
354
|
try:
|
|
325
355
|
step_events = []
|
|
326
356
|
|
|
@@ -342,11 +372,11 @@ class Parallel:
|
|
|
342
372
|
workflow_run_response=workflow_run_response,
|
|
343
373
|
step_index=sub_step_index,
|
|
344
374
|
store_executor_outputs=store_executor_outputs,
|
|
345
|
-
session_state=
|
|
375
|
+
session_state=step_session_state,
|
|
346
376
|
parent_step_id=parallel_step_id,
|
|
347
377
|
):
|
|
348
378
|
step_events.append(event)
|
|
349
|
-
return idx, step_events
|
|
379
|
+
return idx, step_events, step_session_state
|
|
350
380
|
except Exception as exc:
|
|
351
381
|
parallel_step_name = getattr(step, "name", f"step_{idx}")
|
|
352
382
|
logger.error(f"Parallel step {parallel_step_name} streaming failed: {exc}")
|
|
@@ -360,12 +390,14 @@ class Parallel:
|
|
|
360
390
|
error=str(exc),
|
|
361
391
|
)
|
|
362
392
|
],
|
|
393
|
+
step_session_state,
|
|
363
394
|
)
|
|
364
395
|
|
|
365
396
|
# Use index to preserve order
|
|
366
397
|
indexed_steps = list(enumerate(self.steps))
|
|
367
398
|
all_events_with_indices = []
|
|
368
399
|
step_results = []
|
|
400
|
+
modified_session_states = []
|
|
369
401
|
|
|
370
402
|
with ThreadPoolExecutor(max_workers=len(self.steps)) as executor:
|
|
371
403
|
# Submit all tasks with their original indices
|
|
@@ -374,11 +406,12 @@ class Parallel:
|
|
|
374
406
|
for indexed_step in indexed_steps
|
|
375
407
|
}
|
|
376
408
|
|
|
377
|
-
# Collect results
|
|
409
|
+
# Collect results and modified session_state copies
|
|
378
410
|
for future in as_completed(future_to_index):
|
|
379
411
|
try:
|
|
380
|
-
index, events = future.result()
|
|
412
|
+
index, events, modified_session_state = future.result()
|
|
381
413
|
all_events_with_indices.append((index, events))
|
|
414
|
+
modified_session_states.append(modified_session_state)
|
|
382
415
|
|
|
383
416
|
# Extract StepOutput from events for the final result
|
|
384
417
|
step_outputs = [event for event in events if isinstance(event, StepOutput)]
|
|
@@ -400,6 +433,10 @@ class Parallel:
|
|
|
400
433
|
all_events_with_indices.append((index, [error_event]))
|
|
401
434
|
step_results.append(error_event)
|
|
402
435
|
|
|
436
|
+
# Merge all session_state changes back into the original session_state
|
|
437
|
+
if session_state is not None:
|
|
438
|
+
merge_parallel_session_states(session_state, modified_session_states)
|
|
439
|
+
|
|
403
440
|
# Sort events by original index to preserve order
|
|
404
441
|
all_events_with_indices.sort(key=lambda x: x[0])
|
|
405
442
|
|
|
@@ -456,9 +493,20 @@ class Parallel:
|
|
|
456
493
|
|
|
457
494
|
self._prepare_steps()
|
|
458
495
|
|
|
496
|
+
# Create individual session_state copies for each step to prevent race conditions
|
|
497
|
+
session_state_copies = []
|
|
498
|
+
for _ in range(len(self.steps)):
|
|
499
|
+
if session_state is not None:
|
|
500
|
+
session_state_copies.append(deepcopy(session_state))
|
|
501
|
+
else:
|
|
502
|
+
session_state_copies.append({})
|
|
503
|
+
|
|
459
504
|
async def execute_step_async_with_index(step_with_index):
|
|
460
505
|
"""Execute a single step asynchronously and preserve its original index"""
|
|
461
506
|
idx, step = step_with_index
|
|
507
|
+
# Use the individual session_state copy for this step
|
|
508
|
+
step_session_state = session_state_copies[idx]
|
|
509
|
+
|
|
462
510
|
try:
|
|
463
511
|
inner_step_result = await step.aexecute(
|
|
464
512
|
step_input,
|
|
@@ -466,9 +514,9 @@ class Parallel:
|
|
|
466
514
|
user_id=user_id,
|
|
467
515
|
workflow_run_response=workflow_run_response,
|
|
468
516
|
store_executor_outputs=store_executor_outputs,
|
|
469
|
-
session_state=
|
|
517
|
+
session_state=step_session_state,
|
|
470
518
|
) # type: ignore[union-attr]
|
|
471
|
-
return idx, inner_step_result
|
|
519
|
+
return idx, inner_step_result, step_session_state
|
|
472
520
|
except Exception as exc:
|
|
473
521
|
parallel_step_name = getattr(step, "name", f"step_{idx}")
|
|
474
522
|
logger.error(f"Parallel step {parallel_step_name} failed: {exc}")
|
|
@@ -480,6 +528,7 @@ class Parallel:
|
|
|
480
528
|
success=False,
|
|
481
529
|
error=str(exc),
|
|
482
530
|
),
|
|
531
|
+
step_session_state,
|
|
483
532
|
)
|
|
484
533
|
|
|
485
534
|
# Use index to preserve order
|
|
@@ -493,6 +542,7 @@ class Parallel:
|
|
|
493
542
|
|
|
494
543
|
# Process results and handle exceptions, preserving order
|
|
495
544
|
processed_results_with_indices = []
|
|
545
|
+
modified_session_states = []
|
|
496
546
|
for i, result in enumerate(results_with_indices):
|
|
497
547
|
if isinstance(result, Exception):
|
|
498
548
|
step_name = getattr(self.steps[i], "name", f"step_{i}")
|
|
@@ -508,12 +558,19 @@ class Parallel:
|
|
|
508
558
|
),
|
|
509
559
|
)
|
|
510
560
|
)
|
|
561
|
+
# Still collect the session state copy for failed steps
|
|
562
|
+
modified_session_states.append(session_state_copies[i])
|
|
511
563
|
else:
|
|
512
|
-
index, step_result = result # type: ignore[misc]
|
|
564
|
+
index, step_result, modified_session_state = result # type: ignore[misc]
|
|
513
565
|
processed_results_with_indices.append((index, step_result))
|
|
566
|
+
modified_session_states.append(modified_session_state)
|
|
514
567
|
step_name = getattr(self.steps[index], "name", f"step_{index}")
|
|
515
568
|
log_debug(f"Parallel step {step_name} completed")
|
|
516
569
|
|
|
570
|
+
# Smart merge all session_state changes back into the original session_state
|
|
571
|
+
if session_state is not None:
|
|
572
|
+
merge_parallel_session_states(session_state, modified_session_states)
|
|
573
|
+
|
|
517
574
|
# Sort by original index to preserve order
|
|
518
575
|
processed_results_with_indices.sort(key=lambda x: x[0])
|
|
519
576
|
results = [result for _, result in processed_results_with_indices]
|
|
@@ -553,6 +610,14 @@ class Parallel:
|
|
|
553
610
|
|
|
554
611
|
self._prepare_steps()
|
|
555
612
|
|
|
613
|
+
# Create individual session_state copies for each step to prevent race conditions
|
|
614
|
+
session_state_copies = []
|
|
615
|
+
for _ in range(len(self.steps)):
|
|
616
|
+
if session_state is not None:
|
|
617
|
+
session_state_copies.append(deepcopy(session_state))
|
|
618
|
+
else:
|
|
619
|
+
session_state_copies.append({})
|
|
620
|
+
|
|
556
621
|
if stream_intermediate_steps and workflow_run_response:
|
|
557
622
|
# Yield parallel step started event
|
|
558
623
|
yield ParallelExecutionStartedEvent(
|
|
@@ -570,6 +635,9 @@ class Parallel:
|
|
|
570
635
|
async def execute_step_stream_async_with_index(step_with_index):
|
|
571
636
|
"""Execute a single step with async streaming and preserve its original index"""
|
|
572
637
|
idx, step = step_with_index
|
|
638
|
+
# Use the individual session_state copy for this step
|
|
639
|
+
step_session_state = session_state_copies[idx]
|
|
640
|
+
|
|
573
641
|
try:
|
|
574
642
|
step_events = []
|
|
575
643
|
|
|
@@ -591,11 +659,11 @@ class Parallel:
|
|
|
591
659
|
workflow_run_response=workflow_run_response,
|
|
592
660
|
step_index=sub_step_index,
|
|
593
661
|
store_executor_outputs=store_executor_outputs,
|
|
594
|
-
session_state=
|
|
662
|
+
session_state=step_session_state,
|
|
595
663
|
parent_step_id=parallel_step_id,
|
|
596
664
|
): # type: ignore[union-attr]
|
|
597
665
|
step_events.append(event)
|
|
598
|
-
return idx, step_events
|
|
666
|
+
return idx, step_events, step_session_state
|
|
599
667
|
except Exception as e:
|
|
600
668
|
parallel_step_name = getattr(step, "name", f"step_{idx}")
|
|
601
669
|
logger.error(f"Parallel step {parallel_step_name} async streaming failed: {e}")
|
|
@@ -609,12 +677,14 @@ class Parallel:
|
|
|
609
677
|
error=str(e),
|
|
610
678
|
)
|
|
611
679
|
],
|
|
680
|
+
step_session_state,
|
|
612
681
|
)
|
|
613
682
|
|
|
614
683
|
# Use index to preserve order
|
|
615
684
|
indexed_steps = list(enumerate(self.steps))
|
|
616
685
|
all_events_with_indices = []
|
|
617
686
|
step_results = []
|
|
687
|
+
modified_session_states = []
|
|
618
688
|
|
|
619
689
|
# Create tasks for all steps with their indices
|
|
620
690
|
tasks = [execute_step_stream_async_with_index(indexed_step) for indexed_step in indexed_steps]
|
|
@@ -635,9 +705,11 @@ class Parallel:
|
|
|
635
705
|
)
|
|
636
706
|
all_events_with_indices.append((i, [error_event]))
|
|
637
707
|
step_results.append(error_event)
|
|
708
|
+
modified_session_states.append(session_state_copies[i])
|
|
638
709
|
else:
|
|
639
|
-
index, events = result # type: ignore[misc]
|
|
710
|
+
index, events, modified_session_state = result # type: ignore[misc]
|
|
640
711
|
all_events_with_indices.append((index, events))
|
|
712
|
+
modified_session_states.append(modified_session_state)
|
|
641
713
|
|
|
642
714
|
# Extract StepOutput from events for the final result
|
|
643
715
|
step_outputs = [event for event in events if isinstance(event, StepOutput)]
|
|
@@ -647,6 +719,10 @@ class Parallel:
|
|
|
647
719
|
step_name = getattr(self.steps[index], "name", f"step_{index}")
|
|
648
720
|
log_debug(f"Parallel step {step_name} async streaming completed")
|
|
649
721
|
|
|
722
|
+
# Merge all session_state changes back into the original session_state
|
|
723
|
+
if session_state is not None:
|
|
724
|
+
merge_parallel_session_states(session_state, modified_session_states)
|
|
725
|
+
|
|
650
726
|
# Sort events by original index to preserve order
|
|
651
727
|
all_events_with_indices.sort(key=lambda x: x[0])
|
|
652
728
|
|