agno 2.1.1__py3-none-any.whl → 2.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. agno/agent/agent.py +12 -0
  2. agno/db/base.py +8 -4
  3. agno/db/dynamo/dynamo.py +69 -17
  4. agno/db/firestore/firestore.py +68 -29
  5. agno/db/gcs_json/gcs_json_db.py +68 -17
  6. agno/db/in_memory/in_memory_db.py +83 -14
  7. agno/db/json/json_db.py +79 -15
  8. agno/db/mongo/mongo.py +27 -8
  9. agno/db/mysql/mysql.py +17 -3
  10. agno/db/postgres/postgres.py +21 -3
  11. agno/db/redis/redis.py +38 -11
  12. agno/db/singlestore/singlestore.py +14 -3
  13. agno/db/sqlite/sqlite.py +34 -46
  14. agno/knowledge/reader/field_labeled_csv_reader.py +294 -0
  15. agno/knowledge/reader/pdf_reader.py +28 -52
  16. agno/knowledge/reader/reader_factory.py +12 -0
  17. agno/memory/manager.py +12 -4
  18. agno/models/anthropic/claude.py +4 -1
  19. agno/models/aws/bedrock.py +52 -112
  20. agno/os/app.py +24 -30
  21. agno/os/interfaces/a2a/__init__.py +3 -0
  22. agno/os/interfaces/a2a/a2a.py +42 -0
  23. agno/os/interfaces/a2a/router.py +252 -0
  24. agno/os/interfaces/a2a/utils.py +924 -0
  25. agno/os/interfaces/agui/router.py +12 -0
  26. agno/os/router.py +38 -8
  27. agno/os/routers/memory/memory.py +5 -3
  28. agno/os/routers/memory/schemas.py +1 -0
  29. agno/os/utils.py +36 -10
  30. agno/team/team.py +12 -0
  31. agno/tools/mcp.py +46 -1
  32. agno/utils/merge_dict.py +22 -1
  33. agno/utils/streamlit.py +1 -1
  34. agno/workflow/parallel.py +90 -14
  35. agno/workflow/step.py +30 -27
  36. agno/workflow/workflow.py +5 -3
  37. {agno-2.1.1.dist-info → agno-2.1.2.dist-info}/METADATA +16 -14
  38. {agno-2.1.1.dist-info → agno-2.1.2.dist-info}/RECORD +41 -36
  39. {agno-2.1.1.dist-info → agno-2.1.2.dist-info}/WHEEL +0 -0
  40. {agno-2.1.1.dist-info → agno-2.1.2.dist-info}/licenses/LICENSE +0 -0
  41. {agno-2.1.1.dist-info → agno-2.1.2.dist-info}/top_level.txt +0 -0
agno/os/interfaces/agui/router.py CHANGED
@@ -34,12 +34,18 @@ async def run_agent(agent: Agent, run_input: RunAgentInput) -> AsyncIterator[Bas
     messages = convert_agui_messages_to_agno_messages(run_input.messages or [])
     yield RunStartedEvent(type=EventType.RUN_STARTED, thread_id=run_input.thread_id, run_id=run_id)
 
+    # Look for user_id in run_input.forwarded_props
+    user_id = None
+    if run_input.forwarded_props and isinstance(run_input.forwarded_props, dict):
+        user_id = run_input.forwarded_props.get("user_id")
+
     # Request streaming response from agent
     response_stream = agent.arun(
         input=messages,
         session_id=run_input.thread_id,
         stream=True,
         stream_intermediate_steps=True,
+        user_id=user_id,
     )
 
     # Stream the response content in AG-UI format
@@ -64,12 +70,18 @@ async def run_team(team: Team, input: RunAgentInput) -> AsyncIterator[BaseEvent]
     messages = convert_agui_messages_to_agno_messages(input.messages or [])
     yield RunStartedEvent(type=EventType.RUN_STARTED, thread_id=input.thread_id, run_id=run_id)
 
+    # Look for user_id in input.forwarded_props
+    user_id = None
+    if input.forwarded_props and isinstance(input.forwarded_props, dict):
+        user_id = input.forwarded_props.get("user_id")
+
     # Request streaming response from team
     response_stream = team.arun(
         input=messages,
         session_id=input.thread_id,
         stream=True,
         stream_intermediate_steps=True,
+        user_id=user_id,
    )
 
     # Stream the response content in AG-UI format
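Both hunks make the AG-UI bridge user-aware: a client can attribute a run to a user by putting `user_id` inside `forwardedProps`, and the router only trusts the value after confirming the payload is a plain dict. A minimal sketch of the guard, using a hypothetical stand-in for the AG-UI `RunAgentInput` type:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class FakeRunInput:
    """Hypothetical stand-in for ag_ui's RunAgentInput; only the fields used here."""

    thread_id: str
    forwarded_props: Optional[Dict[str, Any]] = None


def resolve_user_id(run_input: FakeRunInput) -> Optional[str]:
    # Same guard as the diff: forwarded_props is client-controlled,
    # so only read user_id out of a plain dict payload
    if run_input.forwarded_props and isinstance(run_input.forwarded_props, dict):
        return run_input.forwarded_props.get("user_id")
    return None


assert resolve_user_id(FakeRunInput("t1", {"user_id": "u_42"})) == "u_42"
assert resolve_user_id(FakeRunInput("t1")) is None
```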
agno/os/router.py CHANGED
@@ -1,4 +1,5 @@
 import json
+from itertools import chain
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Union, cast
 from uuid import uuid4
 
@@ -643,7 +644,7 @@ def get_base_router(
             os_id=os.id or "Unnamed OS",
             description=os.description,
             available_models=os.config.available_models if os.config else [],
-            databases=[db.id for db in os.dbs.values()],
+            databases=list({db.id for db in chain(os.dbs.values(), os.knowledge_dbs.values())}),
             chat=os.config.chat if os.config else None,
             session=os._get_session_config(),
             memory=os._get_memory_config(),
@@ -784,19 +785,39 @@ def get_base_router(
 
         if files:
             for file in files:
-                if file.content_type in ["image/png", "image/jpeg", "image/jpg", "image/webp"]:
+                if file.content_type in [
+                    "image/png",
+                    "image/jpeg",
+                    "image/jpg",
+                    "image/gif",
+                    "image/webp",
+                    "image/bmp",
+                    "image/tiff",
+                    "image/tif",
+                    "image/avif",
+                ]:
                     try:
                         base64_image = process_image(file)
                         base64_images.append(base64_image)
                     except Exception as e:
                         log_error(f"Error processing image {file.filename}: {e}")
                         continue
-                elif file.content_type in ["audio/wav", "audio/mp3", "audio/mpeg"]:
+                elif file.content_type in [
+                    "audio/wav",
+                    "audio/wave",
+                    "audio/mp3",
+                    "audio/mpeg",
+                    "audio/ogg",
+                    "audio/mp4",
+                    "audio/m4a",
+                    "audio/aac",
+                    "audio/flac",
+                ]:
                     try:
-                        base64_audio = process_audio(file)
-                        base64_audios.append(base64_audio)
+                        audio = process_audio(file)
+                        base64_audios.append(audio)
                     except Exception as e:
-                        log_error(f"Error processing audio {file.filename}: {e}")
+                        log_error(f"Error processing audio {file.filename} with content type {file.content_type}: {e}")
                         continue
                 elif file.content_type in [
                     "video/x-flv",
@@ -819,10 +840,19 @@ def get_base_router(
                         continue
                 elif file.content_type in [
                     "application/pdf",
-                    "text/csv",
+                    "application/json",
+                    "application/x-javascript",
                     "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+                    "text/javascript",
+                    "application/x-python",
+                    "text/x-python",
                     "text/plain",
-                    "application/json",
+                    "text/html",
+                    "text/css",
+                    "text/md",
+                    "text/csv",
+                    "text/xml",
+                    "text/rtf",
                 ]:
                     # Process document files
                     try:
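Besides the broadened image, audio, and document MIME allow-lists, the `databases` field now aggregates ids across both `os.dbs` and `os.knowledge_dbs`; the set comprehension deduplicates ids when knowledge shares a database with sessions and memory. A standalone sketch of the dedup with hypothetical `Db` stubs (the resulting list order is unspecified, since it comes from a set):

```python
from itertools import chain


class Db:
    """Hypothetical stub; only the id attribute matters here."""

    def __init__(self, id: str) -> None:
        self.id = id


dbs = {"main": Db("pg-1")}
knowledge_dbs = {"kb": Db("pg-1"), "vectors": Db("pg-2")}

# chain() iterates both mappings; the set drops the shared "pg-1"
databases = list({db.id for db in chain(dbs.values(), knowledge_dbs.values())})
assert sorted(databases) == ["pg-1", "pg-2"]
```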
agno/os/routers/memory/memory.py CHANGED
@@ -120,10 +120,11 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
     )
     async def delete_memory(
         memory_id: str = Path(description="Memory ID to delete"),
+        user_id: Optional[str] = Query(default=None, description="User ID to delete memory for"),
         db_id: Optional[str] = Query(default=None, description="Database ID to use for deletion"),
     ) -> None:
         db = get_db(dbs, db_id)
-        db.delete_user_memory(memory_id=memory_id)
+        db.delete_user_memory(memory_id=memory_id, user_id=user_id)
 
     @router.delete(
         "/memories",
@@ -145,7 +146,7 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
         db_id: Optional[str] = Query(default=None, description="Database ID to use for deletion"),
     ) -> None:
         db = get_db(dbs, db_id)
-        db.delete_user_memories(memory_ids=request.memory_ids)
+        db.delete_user_memories(memory_ids=request.memory_ids, user_id=request.user_id)
 
     @router.get(
         "/memories",
@@ -249,10 +250,11 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
     )
     async def get_memory(
         memory_id: str = Path(description="Memory ID to retrieve"),
+        user_id: Optional[str] = Query(default=None, description="User ID to query memory for"),
         db_id: Optional[str] = Query(default=None, description="Database ID to query memory from"),
     ) -> UserMemorySchema:
         db = get_db(dbs, db_id)
-        user_memory = db.get_user_memory(memory_id=memory_id, deserialize=False)
+        user_memory = db.get_user_memory(memory_id=memory_id, user_id=user_id, deserialize=False)
         if not user_memory:
             raise HTTPException(status_code=404, detail=f"Memory with ID {memory_id} not found")
 
agno/os/routers/memory/schemas.py CHANGED
@@ -6,6 +6,7 @@ from pydantic import BaseModel
 
 class DeleteMemoriesRequest(BaseModel):
     memory_ids: List[str]
+    user_id: Optional[str] = None
 
 
 class UserMemorySchema(BaseModel):
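End to end, the new `user_id` travels as a query parameter on the single-memory routes and as a body field (via `DeleteMemoriesRequest`) on bulk deletion. A hypothetical client call sketching both shapes; the host, port, and exact route paths are assumptions, not taken from the diff:

```python
import httpx

BASE = "http://localhost:7777"  # assumed AgentOS address

# Single delete: user_id rides as a query parameter
# (route path assumed from the delete_memory handler signature)
httpx.delete(f"{BASE}/memories/mem_123", params={"user_id": "u_42"}).raise_for_status()

# Bulk delete: user_id rides inside the DeleteMemoriesRequest body
httpx.request(
    "DELETE",
    f"{BASE}/memories",
    json={"memory_ids": ["mem_123", "mem_456"], "user_id": "u_42"},
).raise_for_status()
```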
agno/os/utils.py CHANGED
@@ -93,18 +93,40 @@ def get_session_name(session: Dict[str, Any]) -> str:
 
     # For teams, identify the first Team run and avoid using the first member's run
     if session.get("session_type") == "team":
-        run = runs[0] if not runs[0].get("agent_id") else runs[1]
+        run = None
+        for r in runs:
+            # If agent_id is not present, it's a team run
+            if not r.get("agent_id"):
+                run = r
+                break
+        # Fallback to first run if no team run found
+        if run is None and runs:
+            run = runs[0]
 
-    # For workflows, pass along the first step_executor_run
     elif session.get("session_type") == "workflow":
         try:
-            run = session["runs"][0]["step_executor_runs"][0]
+            workflow_run = runs[0]
+            workflow_input = workflow_run.get("input")
+            if isinstance(workflow_input, str):
+                return workflow_input
+            elif isinstance(workflow_input, dict):
+                try:
+                    import json
+                    return json.dumps(workflow_input)
+                except (TypeError, ValueError):
+                    pass
+
+            workflow_name = session.get("workflow_data", {}).get("name")
+            return f"New {workflow_name} Session" if workflow_name else ""
         except (KeyError, IndexError, TypeError):
             return ""
 
     # For agents, use the first run
     else:
-        run = runs[0]
+        run = runs[0] if runs else None
+
+    if run is None:
+        return ""
 
     if not isinstance(run, dict):
         run = run.to_dict()
@@ -150,13 +172,17 @@ def process_document(file: UploadFile) -> Optional[FileMedia]:
     return None
 
 
-def extract_format(file: UploadFile):
-    _format = None
+def extract_format(file: UploadFile) -> Optional[str]:
+    """Extract the File format from file name or content_type."""
+    # Get the format from the filename
     if file.filename and "." in file.filename:
-        _format = file.filename.split(".")[-1].lower()
-    elif file.content_type:
-        _format = file.content_type.split("/")[-1]
-    return _format
+        return file.filename.split(".")[-1].lower()
+
+    # Fallback to the file content_type
+    if file.content_type:
+        return file.content_type.strip().split("/")[-1]
+
+    return None
 
 
 def format_tools(agent_tools: List[Union[Dict[str, Any], Toolkit, Function, Callable]]):
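`extract_format` now has an explicit `Optional[str]` contract: the filename extension wins, `content_type` is the fallback, and `None` signals no recoverable format. A minimal sketch of that precedence, using a hypothetical stand-in for FastAPI's `UploadFile`:

```python
from typing import Optional


class FakeUpload:
    """Hypothetical stand-in for FastAPI's UploadFile; just the two fields used."""

    def __init__(self, filename: Optional[str], content_type: Optional[str]) -> None:
        self.filename = filename
        self.content_type = content_type


def extract_format(file: FakeUpload) -> Optional[str]:
    # Filename extension takes precedence over the declared content type
    if file.filename and "." in file.filename:
        return file.filename.split(".")[-1].lower()
    if file.content_type:
        return file.content_type.strip().split("/")[-1]
    return None


assert extract_format(FakeUpload("Report.PDF", "application/pdf")) == "pdf"
assert extract_format(FakeUpload("notes", "text/markdown")) == "markdown"
assert extract_format(FakeUpload(None, None)) is None
```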
agno/team/team.py CHANGED
@@ -4035,6 +4035,12 @@ class Team:
                 log_warning("Reasoning error. Reasoning response is empty, continuing regular session...")
                 break
 
+            if isinstance(reasoning_agent_response.content, str):
+                log_warning(
+                    "Reasoning error. Content is a string, not structured output. Continuing regular session..."
+                )
+                break
+
             if reasoning_agent_response.content.reasoning_steps is None:
                 log_warning("Reasoning error. Reasoning steps are empty, continuing regular session...")
                 break
@@ -4261,6 +4267,12 @@
                 log_warning("Reasoning error. Reasoning response is empty, continuing regular session...")
                 break
 
+            if isinstance(reasoning_agent_response.content, str):
+                log_warning(
+                    "Reasoning error. Content is a string, not structured output. Continuing regular session..."
+                )
+                break
+
             if reasoning_agent_response.content.reasoning_steps is None:
                 log_warning("Reasoning error. Reasoning steps are empty, continuing regular session...")
                 break
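The guard exists because a reasoning model can fall back to plain text instead of the expected structured output, and the very next check (`content.reasoning_steps`) would then raise `AttributeError` on a `str`. A minimal sketch with stand-in types:

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class ReasoningContent:
    """Hypothetical stand-in for the structured reasoning output."""

    reasoning_steps: Optional[List[str]] = None


def safe_reasoning_steps(content: Union[str, ReasoningContent, None]) -> Optional[List[str]]:
    # Mirrors the new guard: bail out before touching .reasoning_steps on a str
    if content is None or isinstance(content, str):
        return None
    return content.reasoning_steps


assert safe_reasoning_steps("plain text answer") is None
assert safe_reasoning_steps(ReasoningContent(["step 1"])) == ["step 1"]
```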
agno/tools/mcp.py CHANGED
@@ -22,6 +22,8 @@ except (ImportError, ModuleNotFoundError):
 
 def _prepare_command(command: str) -> list[str]:
     """Sanitize a command and split it into parts before using it to run a MCP server."""
+    import os
+    import shutil
     from shlex import split
 
     # Block dangerous characters
@@ -55,10 +57,53 @@ def _prepare_command(command: str) -> list[str]:
     }
 
     executable = parts[0].split("/")[-1]
+
+    # Check if it's a relative path starting with ./ or ../
+    if executable.startswith("./") or executable.startswith("../"):
+        # Allow relative paths to binaries
+        return parts
+
+    # Check if it's an absolute path to a binary
+    if executable.startswith("/") and os.path.isfile(executable):
+        # Allow absolute paths to existing files
+        return parts
+
+    # Check if it's a binary in current directory without ./
+    if "/" not in executable and os.path.isfile(executable):
+        # Allow binaries in current directory
+        return parts
+
+    # Check if it's a binary in PATH
+    if shutil.which(executable):
+        return parts
+
     if executable not in ALLOWED_COMMANDS:
         raise ValueError(f"MCP command needs to use one of the following executables: {ALLOWED_COMMANDS}")
 
-    return parts
+    first_part = parts[0]
+    executable = first_part.split("/")[-1]
+
+    # Allow known commands
+    if executable in ALLOWED_COMMANDS:
+        return parts
+
+    # Allow relative paths to custom binaries
+    if first_part.startswith(("./", "../")):
+        return parts
+
+    # Allow absolute paths to existing files
+    if first_part.startswith("/") and os.path.isfile(first_part):
+        return parts
+
+    # Allow binaries in current directory without ./
+    if "/" not in first_part and os.path.isfile(first_part):
+        return parts
+
+    # Allow binaries in PATH
+    if shutil.which(first_part):
+        return parts
+
+    raise ValueError(f"MCP command needs to use one of the following executables: {ALLOWED_COMMANDS}")
 
 
 @dataclass
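Net effect: `_prepare_command` no longer restricts MCP servers to the hard-coded allow-list; relative paths, absolute paths to existing files, binaries in the working directory, and anything resolvable on `PATH` now pass. Note that the first ladder tests the basename (`executable`), so its slash-based checks can never match, which is presumably why the second ladder repeats them against the full `parts[0]`. A standalone sketch of the decision ladder (`ALLOWED_COMMANDS` below is an illustrative subset, not the library's actual list):

```python
import os
import shutil
from shlex import split

ALLOWED_COMMANDS = {"npx", "uvx", "python", "node"}  # illustrative subset


def is_allowed(command: str) -> bool:
    # Order mirrors the diff: known command, relative path, absolute path
    # to an existing file, file in the CWD, then a PATH lookup
    first = split(command)[0]
    return (
        first.split("/")[-1] in ALLOWED_COMMANDS
        or first.startswith(("./", "../"))
        or (first.startswith("/") and os.path.isfile(first))
        or ("/" not in first and os.path.isfile(first))
        or shutil.which(first) is not None
    )


assert is_allowed("python -m my_mcp_server")
assert is_allowed("./custom-mcp-server --stdio")  # relative paths pass without an existence check
```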
agno/utils/merge_dict.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 
 def merge_dictionaries(a: Dict[str, Any], b: Dict[str, Any]) -> None:
@@ -18,3 +18,24 @@ def merge_dictionaries(a: Dict[str, Any], b: Dict[str, Any]) -> None:
             merge_dictionaries(a[key], b[key])
         else:
             a[key] = b[key]
+
+
+def merge_parallel_session_states(original_state: Dict[str, Any], modified_states: List[Dict[str, Any]]) -> None:
+    """
+    Smart merge for parallel session states that only applies actual changes.
+    This prevents parallel steps from overwriting each other's changes.
+    """
+    if not original_state or not modified_states:
+        return
+
+    # Collect all actual changes (keys where value differs from original)
+    all_changes = {}
+    for modified_state in modified_states:
+        if modified_state:
+            for key, value in modified_state.items():
+                if key not in original_state or original_state[key] != value:
+                    all_changes[key] = value
+
+    # Apply all collected changes to the original state
+    for key, value in all_changes.items():
+        original_state[key] = value
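A usage sketch for the new helper: only keys whose values actually differ from the original are written back, so concurrent branches don't clobber each other's untouched keys. When two branches change the same key, the last one iterated wins:

```python
from agno.utils.merge_dict import merge_parallel_session_states

original = {"count": 0, "owner": "a"}
branch_1 = {"count": 1, "owner": "a"}  # changed count only
branch_2 = {"count": 0, "owner": "b"}  # changed owner only

merge_parallel_session_states(original, [branch_1, branch_2])
assert original == {"count": 1, "owner": "b"}
```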
agno/utils/streamlit.py CHANGED
@@ -452,7 +452,7 @@ MODELS = [
     "gpt-4o",
     "o3-mini",
     "gpt-5",
-    "claude-4-sonnet",
+    "claude-sonnet-4-5-20250929",
     "gemini-2.5-pro",
 ]
 
agno/workflow/parallel.py CHANGED
@@ -1,5 +1,6 @@
 import asyncio
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Optional, Union
 from uuid import uuid4
@@ -13,6 +14,7 @@ from agno.run.workflow import (
     WorkflowRunOutput,
     WorkflowRunOutputEvent,
 )
+from agno.utils.merge_dict import merge_parallel_session_states
 from agno.utils.log import log_debug, logger
 from agno.workflow.condition import Condition
 from agno.workflow.step import Step
@@ -205,9 +207,20 @@
 
         self._prepare_steps()
 
+        # Create individual session_state copies for each step to prevent race conditions
+        session_state_copies = []
+        for _ in range(len(self.steps)):
+            if session_state is not None:
+                session_state_copies.append(deepcopy(session_state))
+            else:
+                session_state_copies.append({})
+
         def execute_step_with_index(step_with_index):
             """Execute a single step and preserve its original index"""
             idx, step = step_with_index
+            # Use the individual session_state copy for this step
+            step_session_state = session_state_copies[idx]
+
             try:
                 step_result = step.execute(
                     step_input,
@@ -215,9 +228,9 @@
                     user_id=user_id,
                     workflow_run_response=workflow_run_response,
                     store_executor_outputs=store_executor_outputs,
-                    session_state=session_state,
+                    session_state=step_session_state,
                 )  # type: ignore[union-attr]
-                return idx, step_result
+                return idx, step_result, step_session_state
             except Exception as exc:
                 parallel_step_name = getattr(step, "name", f"step_{idx}")
                 logger.error(f"Parallel step {parallel_step_name} failed: {exc}")
@@ -229,6 +242,7 @@
                         success=False,
                         error=str(exc),
                     ),
+                    step_session_state,
                 )
 
         # Use index to preserve order
@@ -241,12 +255,14 @@
                 for indexed_step in indexed_steps
             }
 
-            # Collect results
+            # Collect results and modified session_state copies
             results_with_indices = []
+            modified_session_states = []
            for future in as_completed(future_to_index):
                 try:
-                    index, result = future.result()
+                    index, result, modified_session_state = future.result()
                     results_with_indices.append((index, result))
+                    modified_session_states.append(modified_session_state)
                     step_name = getattr(self.steps[index], "name", f"step_{index}")
                     log_debug(f"Parallel step {step_name} completed")
                 except Exception as e:
@@ -265,6 +281,9 @@
                         )
                     )
 
+        if session_state is not None:
+            merge_parallel_session_states(session_state, modified_session_states)
+
         # Sort by original index to preserve order
         results_with_indices.sort(key=lambda x: x[0])
         results = [result for _, result in results_with_indices]
@@ -304,6 +323,14 @@
 
         self._prepare_steps()
 
+        # Create individual session_state copies for each step to prevent race conditions
+        session_state_copies = []
+        for _ in range(len(self.steps)):
+            if session_state is not None:
+                session_state_copies.append(deepcopy(session_state))
+            else:
+                session_state_copies.append({})
+
         if stream_intermediate_steps and workflow_run_response:
             # Yield parallel step started event
             yield ParallelExecutionStartedEvent(
@@ -321,6 +348,9 @@
         def execute_step_stream_with_index(step_with_index):
             """Execute a single step with streaming and preserve its original index"""
             idx, step = step_with_index
+            # Use the individual session_state copy for this step
+            step_session_state = session_state_copies[idx]
+
             try:
                 step_events = []
 
@@ -342,11 +372,11 @@
                     workflow_run_response=workflow_run_response,
                     step_index=sub_step_index,
                     store_executor_outputs=store_executor_outputs,
-                    session_state=session_state,
+                    session_state=step_session_state,
                     parent_step_id=parallel_step_id,
                 ):
                     step_events.append(event)
-                return idx, step_events
+                return idx, step_events, step_session_state
             except Exception as exc:
                 parallel_step_name = getattr(step, "name", f"step_{idx}")
                 logger.error(f"Parallel step {parallel_step_name} streaming failed: {exc}")
@@ -360,12 +390,14 @@
                         error=str(exc),
                     )
                 ],
+                step_session_state,
             )
 
         # Use index to preserve order
         indexed_steps = list(enumerate(self.steps))
         all_events_with_indices = []
         step_results = []
+        modified_session_states = []
 
         with ThreadPoolExecutor(max_workers=len(self.steps)) as executor:
             # Submit all tasks with their original indices
@@ -374,11 +406,12 @@
                 for indexed_step in indexed_steps
             }
 
-            # Collect results as they complete
+            # Collect results and modified session_state copies
            for future in as_completed(future_to_index):
                 try:
-                    index, events = future.result()
+                    index, events, modified_session_state = future.result()
                     all_events_with_indices.append((index, events))
+                    modified_session_states.append(modified_session_state)
 
                     # Extract StepOutput from events for the final result
                     step_outputs = [event for event in events if isinstance(event, StepOutput)]
@@ -400,6 +433,10 @@
                     all_events_with_indices.append((index, [error_event]))
                     step_results.append(error_event)
 
+        # Merge all session_state changes back into the original session_state
+        if session_state is not None:
+            merge_parallel_session_states(session_state, modified_session_states)
+
         # Sort events by original index to preserve order
         all_events_with_indices.sort(key=lambda x: x[0])
 
@@ -456,9 +493,20 @@
 
         self._prepare_steps()
 
+        # Create individual session_state copies for each step to prevent race conditions
+        session_state_copies = []
+        for _ in range(len(self.steps)):
+            if session_state is not None:
+                session_state_copies.append(deepcopy(session_state))
+            else:
+                session_state_copies.append({})
+
         async def execute_step_async_with_index(step_with_index):
             """Execute a single step asynchronously and preserve its original index"""
             idx, step = step_with_index
+            # Use the individual session_state copy for this step
+            step_session_state = session_state_copies[idx]
+
             try:
                 inner_step_result = await step.aexecute(
                     step_input,
@@ -466,9 +514,9 @@
                     user_id=user_id,
                     workflow_run_response=workflow_run_response,
                     store_executor_outputs=store_executor_outputs,
-                    session_state=session_state,
+                    session_state=step_session_state,
                 )  # type: ignore[union-attr]
-                return idx, inner_step_result
+                return idx, inner_step_result, step_session_state
             except Exception as exc:
                 parallel_step_name = getattr(step, "name", f"step_{idx}")
                 logger.error(f"Parallel step {parallel_step_name} failed: {exc}")
@@ -480,6 +528,7 @@
                         success=False,
                         error=str(exc),
                     ),
+                    step_session_state,
                 )
 
         # Use index to preserve order
@@ -493,6 +542,7 @@
 
         # Process results and handle exceptions, preserving order
         processed_results_with_indices = []
+        modified_session_states = []
         for i, result in enumerate(results_with_indices):
             if isinstance(result, Exception):
                 step_name = getattr(self.steps[i], "name", f"step_{i}")
@@ -508,12 +558,19 @@
                         ),
                     )
                 )
+                # Still collect the session state copy for failed steps
+                modified_session_states.append(session_state_copies[i])
             else:
-                index, step_result = result  # type: ignore[misc]
+                index, step_result, modified_session_state = result  # type: ignore[misc]
                 processed_results_with_indices.append((index, step_result))
+                modified_session_states.append(modified_session_state)
                 step_name = getattr(self.steps[index], "name", f"step_{index}")
                 log_debug(f"Parallel step {step_name} completed")
 
+        # Smart merge all session_state changes back into the original session_state
+        if session_state is not None:
+            merge_parallel_session_states(session_state, modified_session_states)
+
         # Sort by original index to preserve order
         processed_results_with_indices.sort(key=lambda x: x[0])
         results = [result for _, result in processed_results_with_indices]
@@ -553,6 +610,14 @@
 
         self._prepare_steps()
 
+        # Create individual session_state copies for each step to prevent race conditions
+        session_state_copies = []
+        for _ in range(len(self.steps)):
+            if session_state is not None:
+                session_state_copies.append(deepcopy(session_state))
+            else:
+                session_state_copies.append({})
+
         if stream_intermediate_steps and workflow_run_response:
             # Yield parallel step started event
             yield ParallelExecutionStartedEvent(
@@ -570,6 +635,9 @@
         async def execute_step_stream_async_with_index(step_with_index):
             """Execute a single step with async streaming and preserve its original index"""
             idx, step = step_with_index
+            # Use the individual session_state copy for this step
+            step_session_state = session_state_copies[idx]
+
             try:
                 step_events = []
 
@@ -591,11 +659,11 @@
                     workflow_run_response=workflow_run_response,
                     step_index=sub_step_index,
                     store_executor_outputs=store_executor_outputs,
-                    session_state=session_state,
+                    session_state=step_session_state,
                     parent_step_id=parallel_step_id,
                 ):  # type: ignore[union-attr]
                     step_events.append(event)
-                return idx, step_events
+                return idx, step_events, step_session_state
             except Exception as e:
                 parallel_step_name = getattr(step, "name", f"step_{idx}")
                 logger.error(f"Parallel step {parallel_step_name} async streaming failed: {e}")
@@ -609,12 +677,14 @@
                         error=str(e),
                     )
                 ],
+                step_session_state,
             )
 
         # Use index to preserve order
         indexed_steps = list(enumerate(self.steps))
         all_events_with_indices = []
         step_results = []
+        modified_session_states = []
 
         # Create tasks for all steps with their indices
         tasks = [execute_step_stream_async_with_index(indexed_step) for indexed_step in indexed_steps]
@@ -635,9 +705,11 @@
                 )
                 all_events_with_indices.append((i, [error_event]))
                 step_results.append(error_event)
+                modified_session_states.append(session_state_copies[i])
             else:
-                index, events = result  # type: ignore[misc]
+                index, events, modified_session_state = result  # type: ignore[misc]
                 all_events_with_indices.append((index, events))
+                modified_session_states.append(modified_session_state)
 
                 # Extract StepOutput from events for the final result
                 step_outputs = [event for event in events if isinstance(event, StepOutput)]
@@ -647,6 +719,10 @@
                     step_name = getattr(self.steps[index], "name", f"step_{index}")
                     log_debug(f"Parallel step {step_name} async streaming completed")
 
+        # Merge all session_state changes back into the original session_state
+        if session_state is not None:
+            merge_parallel_session_states(session_state, modified_session_states)
+
         # Sort events by original index to preserve order
         all_events_with_indices.sort(key=lambda x: x[0])
 
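All four execution paths (sync, sync streaming, async, async streaming) now follow the same pattern: deep-copy the shared `session_state` once per branch, run the branches concurrently against their private copies, then merge only the deltas back. A condensed sketch of that pattern; `run_parallel` here is illustrative, not the library API:

```python
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Any, Callable, Dict, List

from agno.utils.merge_dict import merge_parallel_session_states


def run_parallel(steps: List[Callable[[Dict[str, Any]], None]], session_state: Dict[str, Any]) -> None:
    # One private deep copy per step prevents races on the shared dict
    copies = [deepcopy(session_state) for _ in steps]
    with ThreadPoolExecutor(max_workers=len(steps)) as pool:
        futures = [pool.submit(step, state) for step, state in zip(steps, copies)]
        for future in futures:
            future.result()
    # Write back only the keys each branch actually changed
    merge_parallel_session_states(session_state, copies)


state = {"a": 0, "b": 0}
run_parallel([lambda s: s.update(a=1), lambda s: s.update(b=2)], state)
assert state == {"a": 1, "b": 2}
```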