agno 2.0.3__py3-none-any.whl → 2.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. agno/agent/agent.py +229 -164
  2. agno/db/dynamo/dynamo.py +8 -0
  3. agno/db/firestore/firestore.py +8 -0
  4. agno/db/gcs_json/gcs_json_db.py +9 -0
  5. agno/db/json/json_db.py +8 -0
  6. agno/db/migrations/v1_to_v2.py +191 -23
  7. agno/db/mongo/mongo.py +68 -0
  8. agno/db/mysql/mysql.py +13 -3
  9. agno/db/mysql/schemas.py +27 -27
  10. agno/db/postgres/postgres.py +19 -11
  11. agno/db/redis/redis.py +6 -0
  12. agno/db/singlestore/schemas.py +1 -1
  13. agno/db/singlestore/singlestore.py +8 -1
  14. agno/db/sqlite/sqlite.py +12 -3
  15. agno/integrations/discord/client.py +1 -0
  16. agno/knowledge/knowledge.py +92 -66
  17. agno/knowledge/reader/reader_factory.py +7 -3
  18. agno/knowledge/reader/web_search_reader.py +12 -6
  19. agno/models/base.py +2 -2
  20. agno/models/message.py +109 -0
  21. agno/models/openai/chat.py +3 -0
  22. agno/models/openai/responses.py +12 -0
  23. agno/models/response.py +5 -0
  24. agno/models/siliconflow/__init__.py +5 -0
  25. agno/models/siliconflow/siliconflow.py +25 -0
  26. agno/os/app.py +164 -41
  27. agno/os/auth.py +24 -14
  28. agno/os/interfaces/agui/utils.py +98 -134
  29. agno/os/router.py +128 -55
  30. agno/os/routers/evals/utils.py +9 -9
  31. agno/os/routers/health.py +25 -0
  32. agno/os/routers/home.py +52 -0
  33. agno/os/routers/knowledge/knowledge.py +11 -11
  34. agno/os/routers/session/session.py +24 -8
  35. agno/os/schema.py +29 -2
  36. agno/os/utils.py +0 -8
  37. agno/run/agent.py +3 -3
  38. agno/run/team.py +3 -3
  39. agno/run/workflow.py +64 -10
  40. agno/session/team.py +1 -0
  41. agno/team/team.py +189 -94
  42. agno/tools/duckduckgo.py +15 -11
  43. agno/tools/googlesearch.py +1 -1
  44. agno/tools/mem0.py +11 -17
  45. agno/tools/memory.py +34 -6
  46. agno/utils/common.py +90 -1
  47. agno/utils/streamlit.py +14 -8
  48. agno/utils/string.py +32 -0
  49. agno/utils/tools.py +1 -1
  50. agno/vectordb/chroma/chromadb.py +8 -2
  51. agno/workflow/step.py +115 -16
  52. agno/workflow/workflow.py +16 -13
  53. {agno-2.0.3.dist-info → agno-2.0.5.dist-info}/METADATA +6 -5
  54. {agno-2.0.3.dist-info → agno-2.0.5.dist-info}/RECORD +57 -54
  55. agno/knowledge/reader/url_reader.py +0 -128
  56. {agno-2.0.3.dist-info → agno-2.0.5.dist-info}/WHEEL +0 -0
  57. {agno-2.0.3.dist-info → agno-2.0.5.dist-info}/licenses/LICENSE +0 -0
  58. {agno-2.0.3.dist-info → agno-2.0.5.dist-info}/top_level.txt +0 -0
agno/tools/mem0.py CHANGED
@@ -2,7 +2,6 @@ import json
2
2
  from os import getenv
3
3
  from typing import Any, Dict, List, Optional, Union
4
4
 
5
- from agno.agent import Agent
6
5
  from agno.tools import Toolkit
7
6
  from agno.utils.log import log_debug, log_error, log_warning
8
7
 
@@ -69,15 +68,13 @@ class Mem0Tools(Toolkit):
69
68
  def _get_user_id(
70
69
  self,
71
70
  method_name: str,
72
- agent: Optional[Agent] = None,
71
+ session_state: Dict[str, Any],
73
72
  ) -> str:
74
73
  """Resolve the user ID"""
75
74
  resolved_user_id = self.user_id
76
- if not resolved_user_id and agent is not None:
75
+ if not resolved_user_id:
77
76
  try:
78
- session_state = getattr(agent, "session_state", None)
79
- if isinstance(session_state, dict):
80
- resolved_user_id = session_state.get("current_user_id")
77
+ resolved_user_id = session_state.get("current_user_id")
81
78
  except Exception:
82
79
  pass
83
80
  if not resolved_user_id:
@@ -88,7 +85,7 @@ class Mem0Tools(Toolkit):
88
85
 
89
86
  def add_memory(
90
87
  self,
91
- agent: Agent,
88
+ session_state,
92
89
  content: Union[str, Dict[str, str]],
93
90
  ) -> str:
94
91
  """Add facts to the user's memory.
@@ -101,7 +98,7 @@ class Mem0Tools(Toolkit):
101
98
  str: JSON-encoded Mem0 response or an error message.
102
99
  """
103
100
 
104
- resolved_user_id = self._get_user_id("add_memory", agent=agent)
101
+ resolved_user_id = self._get_user_id("add_memory", session_state=session_state)
105
102
  if isinstance(resolved_user_id, str) and resolved_user_id.startswith("Error in add_memory:"):
106
103
  return resolved_user_id
107
104
  try:
@@ -116,7 +113,6 @@ class Mem0Tools(Toolkit):
116
113
  messages_list,
117
114
  user_id=resolved_user_id,
118
115
  infer=self.infer,
119
- output_format="v1.1",
120
116
  )
121
117
  return json.dumps(result)
122
118
  except Exception as e:
@@ -125,19 +121,18 @@ class Mem0Tools(Toolkit):
125
121
 
126
122
  def search_memory(
127
123
  self,
128
- agent: Agent,
124
+ session_state: Dict[str, Any],
129
125
  query: str,
130
126
  ) -> str:
131
127
  """Semantic search for *query* across the user's stored memories."""
132
128
 
133
- resolved_user_id = self._get_user_id("search_memory", agent=agent)
129
+ resolved_user_id = self._get_user_id("search_memory", session_state=session_state)
134
130
  if isinstance(resolved_user_id, str) and resolved_user_id.startswith("Error in search_memory:"):
135
131
  return resolved_user_id
136
132
  try:
137
133
  results = self.client.search(
138
134
  query=query,
139
135
  user_id=resolved_user_id,
140
- output_format="v1.1",
141
136
  )
142
137
 
143
138
  if isinstance(results, dict) and "results" in results:
@@ -156,16 +151,15 @@ class Mem0Tools(Toolkit):
156
151
  log_error(f"Error searching memory: {e}")
157
152
  return f"Error searching memory: {e}"
158
153
 
159
- def get_all_memories(self, agent: Agent) -> str:
154
+ def get_all_memories(self, session_state: Dict[str, Any]) -> str:
160
155
  """Return **all** memories for the current user as a JSON string."""
161
156
 
162
- resolved_user_id = self._get_user_id("get_all_memories", agent=agent)
157
+ resolved_user_id = self._get_user_id("get_all_memories", session_state=session_state)
163
158
  if isinstance(resolved_user_id, str) and resolved_user_id.startswith("Error in get_all_memories:"):
164
159
  return resolved_user_id
165
160
  try:
166
161
  results = self.client.get_all(
167
162
  user_id=resolved_user_id,
168
- output_format="v1.1",
169
163
  )
170
164
 
171
165
  if isinstance(results, dict) and "results" in results:
@@ -183,10 +177,10 @@ class Mem0Tools(Toolkit):
183
177
  log_error(f"Error getting all memories: {e}")
184
178
  return f"Error getting all memories: {e}"
185
179
 
186
- def delete_all_memories(self, agent: Agent) -> str:
180
+ def delete_all_memories(self, session_state: Dict[str, Any]) -> str:
187
181
  """Delete *all* memories associated with the current user"""
188
182
 
189
- resolved_user_id = self._get_user_id("delete_all_memories", agent=agent)
183
+ resolved_user_id = self._get_user_id("delete_all_memories", session_state=session_state)
190
184
  if isinstance(resolved_user_id, str) and resolved_user_id.startswith("Error in delete_all_memories:"):
191
185
  error_msg = resolved_user_id
192
186
  log_error(error_msg)
agno/tools/memory.py CHANGED
@@ -95,13 +95,28 @@ class MemoryTools(Toolkit):
95
95
 
96
96
  def get_memories(self, session_state: Dict[str, Any]) -> str:
97
97
  """
98
- Use this tool to get a list of memories from the database.
98
+ Use this tool to get a list of memories for the current user from the database.
99
99
  """
100
100
  try:
101
101
  # Get user info from session state
102
102
  user_id = session_state.get("current_user_id") if session_state else None
103
103
 
104
104
  memories = self.db.get_user_memories(user_id=user_id)
105
+
106
+ # Store the result in session state for analysis
107
+ if session_state is None:
108
+ session_state = {}
109
+ if "memory_operations" not in session_state:
110
+ session_state["memory_operations"] = []
111
+
112
+ operation_result = {
113
+ "operation": "get_memories",
114
+ "success": True,
115
+ "memories": [memory.to_dict() for memory in memories], # type: ignore
116
+ "error": None,
117
+ }
118
+ session_state["memory_operations"].append(operation_result)
119
+
105
120
  return json.dumps([memory.to_dict() for memory in memories], indent=2) # type: ignore
106
121
  except Exception as e:
107
122
  log_error(f"Error getting memories: {e}")
@@ -328,19 +343,23 @@ class MemoryTools(Toolkit):
328
343
  - Purpose: A scratchpad for planning memory operations, brainstorming memory content, and refining your approach. You never reveal your "Think" content to the user.
329
344
  - Usage: Call `think` whenever you need to figure out what memory operations to perform, analyze requirements, or decide on strategy.
330
345
 
331
- 2. **Add Memory**
346
+ 2. **Get Memories**
347
+ - Purpose: Retrieves a list of memories from the database for the current user.
348
+ - Usage: Call `get_memories` when you need to retrieve memories for the current user.
349
+
350
+ 3. **Add Memory**
332
351
  - Purpose: Creates new memories in the database with specified content and metadata.
333
352
  - Usage: Call `add_memory` with memory content and optional topics when you need to store new information.
334
353
 
335
- 3. **Update Memory**
354
+ 4. **Update Memory**
336
355
  - Purpose: Modifies existing memories in the database by memory ID.
337
356
  - Usage: Call `update_memory` with a memory ID and the fields you want to change. Only specify the fields that need updating.
338
357
 
339
- 4. **Delete Memory**
358
+ 5. **Delete Memory**
340
359
  - Purpose: Removes memories from the database by memory ID.
341
360
  - Usage: Call `delete_memory` with a memory ID when a memory is no longer needed or requested to be removed.
342
361
 
343
- 5. **Analyze**
362
+ 6. **Analyze**
344
363
  - Purpose: Evaluate whether the memory operations results are correct and sufficient. If not, go back to "Think" or use memory operations with refined parameters.
345
364
  - Usage: Call `analyze` after performing memory operations to verify:
346
365
  - Success: Did the operation complete successfully?
@@ -387,5 +406,14 @@ class MemoryTools(Toolkit):
387
406
  Delete Memory: memory_id="work_schedule_memory_id"
388
407
  Analyze: Successfully deleted the outdated work schedule memory. The old information won't interfere with future scheduling requests.
389
408
 
390
- Final Answer: I've removed your old work schedule information. Feel free to share your new schedule when you're ready, and I'll store the updated information.\
409
+ Final Answer: I've removed your old work schedule information. Feel free to share your new schedule when you're ready, and I'll store the updated information.
410
+
411
+ #### Example 4: Retrieving Memories
412
+
413
+ User: What have you remembered about me?
414
+ Think: The user wants to retrieve memories about themselves. I should use the get_memories tool to retrieve the memories.
415
+ Get Memories:
416
+ Analyze: Successfully retrieved the memories about the user. The memories are relevant to the user's preferences and activities.
417
+
418
+ Final Answer: I've retrieved the memories about you. You like to hike in the mountains on weekends and travel to new places and experience different cultures. You are planning to travel to Africa in December.\
391
419
  """)
agno/utils/common.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from dataclasses import asdict
2
- from typing import Any, List, Optional, Type
2
+ from typing import Any, List, Optional, Set, Type, Union, get_type_hints
3
3
 
4
4
 
5
5
  def isinstanceany(obj: Any, class_list: List[Type]) -> bool:
@@ -41,3 +41,92 @@ def nested_model_dump(value):
41
41
  elif isinstance(value, list):
42
42
  return [nested_model_dump(item) for item in value]
43
43
  return value
44
+
45
+
46
+ def is_typed_dict(cls: Type[Any]) -> bool:
47
+ """Check if a class is a TypedDict"""
48
+ return (
49
+ hasattr(cls, "__annotations__")
50
+ and hasattr(cls, "__total__")
51
+ and hasattr(cls, "__required_keys__")
52
+ and hasattr(cls, "__optional_keys__")
53
+ )
54
+
55
+
56
+ def check_type_compatibility(value: Any, expected_type: Type) -> bool:
57
+ """Basic type compatibility checking."""
58
+ from typing import get_args, get_origin
59
+
60
+ # Handle None/Optional types
61
+ if value is None:
62
+ return (
63
+ type(None) in get_args(expected_type) if hasattr(expected_type, "__args__") else expected_type is type(None)
64
+ )
65
+
66
+ # Handle Union types (including Optional)
67
+ origin = get_origin(expected_type)
68
+ if origin is Union:
69
+ return any(check_type_compatibility(value, arg) for arg in get_args(expected_type))
70
+
71
+ # Handle List types
72
+ if origin is list or expected_type is list:
73
+ if not isinstance(value, list):
74
+ return False
75
+ if origin is list and get_args(expected_type):
76
+ element_type = get_args(expected_type)[0]
77
+ return all(check_type_compatibility(item, element_type) for item in value)
78
+ return True
79
+
80
+ if expected_type in (str, int, float, bool):
81
+ return isinstance(value, expected_type)
82
+
83
+ if expected_type is Any:
84
+ return True
85
+
86
+ try:
87
+ return isinstance(value, expected_type)
88
+ except TypeError:
89
+ return True
90
+
91
+
92
+ def validate_typed_dict(data: dict, schema_cls) -> dict:
93
+ """Validate input data against a TypedDict schema."""
94
+ if not isinstance(data, dict):
95
+ raise ValueError(f"Expected dict for TypedDict {schema_cls.__name__}, got {type(data)}")
96
+
97
+ # Get type hints from the TypedDict
98
+ try:
99
+ type_hints = get_type_hints(schema_cls)
100
+ except Exception as e:
101
+ raise ValueError(f"Could not get type hints for TypedDict {schema_cls.__name__}: {e}")
102
+
103
+ # Get required and optional keys
104
+ required_keys: Set[str] = getattr(schema_cls, "__required_keys__", set())
105
+ optional_keys: Set[str] = getattr(schema_cls, "__optional_keys__", set())
106
+ all_keys = required_keys | optional_keys
107
+
108
+ # Check for missing required fields
109
+ missing_required = required_keys - set(data.keys())
110
+ if missing_required:
111
+ raise ValueError(f"Missing required fields in TypedDict {schema_cls.__name__}: {missing_required}")
112
+
113
+ # Check for unexpected fields
114
+ unexpected_fields = set(data.keys()) - all_keys
115
+ if unexpected_fields:
116
+ raise ValueError(f"Unexpected fields in TypedDict {schema_cls.__name__}: {unexpected_fields}")
117
+
118
+ # Basic type checking for provided fields
119
+ validated_data = {}
120
+ for field_name, value in data.items():
121
+ if field_name in type_hints:
122
+ expected_type = type_hints[field_name]
123
+
124
+ # Handle simple type checking
125
+ if not check_type_compatibility(value, expected_type):
126
+ raise ValueError(
127
+ f"Field '{field_name}' expected type {expected_type}, got {type(value)} with value {value}"
128
+ )
129
+
130
+ validated_data[field_name] = value
131
+
132
+ return validated_data
agno/utils/streamlit.py CHANGED
@@ -1,14 +1,20 @@
1
1
  from datetime import datetime
2
2
  from typing import Any, Callable, Dict, List, Optional
3
3
 
4
- import streamlit as st
5
-
6
- from agno.agent import Agent
7
- from agno.db.base import SessionType
8
- from agno.models.anthropic import Claude
9
- from agno.models.google import Gemini
10
- from agno.models.openai import OpenAIChat
11
- from agno.utils.log import logger
4
+ try:
5
+ from agno.agent import Agent
6
+ from agno.db.base import SessionType
7
+ from agno.models.anthropic import Claude
8
+ from agno.models.google import Gemini
9
+ from agno.models.openai import OpenAIChat
10
+ from agno.utils.log import logger
11
+ except ImportError:
12
+ raise ImportError("`agno` not installed. Please install using `pip install agno`")
13
+
14
+ try:
15
+ import streamlit as st
16
+ except ImportError:
17
+ raise ImportError("`streamlit` not installed. Please install using `pip install streamlit`")
12
18
 
13
19
 
14
20
  def add_message(role: str, content: str, tool_calls: Optional[List[Dict[str, Any]]] = None) -> None:
agno/utils/string.py CHANGED
@@ -1,7 +1,9 @@
1
1
  import hashlib
2
2
  import json
3
3
  import re
4
+ import uuid
4
5
  from typing import Optional, Type
6
+ from uuid import uuid4
5
7
 
6
8
  from pydantic import BaseModel, ValidationError
7
9
 
@@ -188,3 +190,33 @@ def parse_response_model_str(content: str, output_schema: Type[BaseModel]) -> Op
188
190
  logger.warning("All parsing attempts failed.")
189
191
 
190
192
  return structured_output
193
+
194
+
195
+ def generate_id(seed: Optional[str] = None) -> str:
196
+ """
197
+ Generate a deterministic UUID5 based on a seed string.
198
+ If no seed is provided, generate a random UUID4.
199
+
200
+ Args:
201
+ seed (str): The seed string to generate the UUID from.
202
+
203
+ Returns:
204
+ str: A deterministic UUID5 string.
205
+ """
206
+ if seed is None:
207
+ return str(uuid4())
208
+ return str(uuid.uuid5(uuid.NAMESPACE_DNS, seed))
209
+
210
+
211
+ def generate_id_from_name(name: Optional[str] = None) -> str:
212
+ """
213
+ Generate a deterministic ID from a name string.
214
+ If no name is provided, generate a random UUID4.
215
+
216
+ Args:
217
+ name (str): The name string to generate the ID from.
218
+ """
219
+ if name:
220
+ return name.lower().replace(" ", "-").replace("_", "-")
221
+ else:
222
+ return str(uuid4())
agno/utils/tools.py CHANGED
@@ -13,7 +13,7 @@ def get_function_call_for_tool_call(
13
13
  _tool_call_function = tool_call.get("function")
14
14
  if _tool_call_function is not None:
15
15
  _tool_call_function_name = _tool_call_function.get("name")
16
- _tool_call_function_arguments_str = _tool_call_function.get("arguments")
16
+ _tool_call_function_arguments_str = _tool_call_function.get("arguments") or "{}"
17
17
  if _tool_call_function_name is not None:
18
18
  return get_function_call(
19
19
  name=_tool_call_function_name,
@@ -766,9 +766,15 @@ class ChromaDb(VectorDb):
766
766
  updated_metadatas.append(updated_meta)
767
767
 
768
768
  # Update the documents
769
+ # Filter out None values from metadata as ChromaDB doesn't accept them
770
+ cleaned_metadatas = []
771
+ for meta in updated_metadatas:
772
+ cleaned_meta = {k: v for k, v in meta.items() if v is not None}
773
+ cleaned_metadatas.append(cleaned_meta)
774
+
769
775
  # Convert to the expected type for ChromaDB
770
- chroma_metadatas = cast(List[Mapping[str, Union[str, int, float, bool, None]]], updated_metadatas)
771
- collection.update(ids=ids, metadatas=chroma_metadatas)
776
+ chroma_metadatas = cast(List[Mapping[str, Union[str, int, float, bool]]], cleaned_metadatas)
777
+ collection.update(ids=ids, metadatas=chroma_metadatas) # type: ignore
772
778
  logger.debug(f"Updated metadata for {len(ids)} documents with content_id: {content_id}")
773
779
 
774
780
  except TypeError as te:
agno/workflow/step.py CHANGED
@@ -164,6 +164,38 @@ class Step:
164
164
  return response.metrics
165
165
  return None
166
166
 
167
+ def _call_custom_function(
168
+ self,
169
+ func: Callable,
170
+ step_input: StepInput,
171
+ session_state: Optional[Dict[str, Any]] = None,
172
+ ) -> Any:
173
+ """Call custom function with session_state support if the function accepts it"""
174
+ if session_state is not None and self._function_has_session_state_param():
175
+ return func(step_input, session_state)
176
+ else:
177
+ return func(step_input)
178
+
179
+ async def _acall_custom_function(
180
+ self,
181
+ func: Callable,
182
+ step_input: StepInput,
183
+ session_state: Optional[Dict[str, Any]] = None,
184
+ ) -> Any:
185
+ """Call custom async function with session_state support if the function accepts it"""
186
+ import inspect
187
+
188
+ if inspect.isasyncgenfunction(func):
189
+ if session_state is not None and self._function_has_session_state_param():
190
+ return func(step_input, session_state)
191
+ else:
192
+ return func(step_input)
193
+ else:
194
+ if session_state is not None and self._function_has_session_state_param():
195
+ return await func(step_input, session_state)
196
+ else:
197
+ return await func(step_input)
198
+
167
199
  def execute(
168
200
  self,
169
201
  step_input: StepInput,
@@ -191,8 +223,11 @@ class Step:
191
223
  if inspect.isgeneratorfunction(self.active_executor):
192
224
  content = ""
193
225
  final_response = None
226
+ session_state_copy = copy(session_state) if session_state else None
194
227
  try:
195
- for chunk in self.active_executor(step_input): # type: ignore
228
+ for chunk in self._call_custom_function(
229
+ self.active_executor, step_input, session_state_copy
230
+ ): # type: ignore
196
231
  if (
197
232
  hasattr(chunk, "content")
198
233
  and chunk.content is not None
@@ -208,13 +243,22 @@ class Step:
208
243
  if hasattr(e, "value") and isinstance(e.value, StepOutput):
209
244
  final_response = e.value
210
245
 
246
+ # Merge session_state changes back
247
+ if session_state_copy and session_state:
248
+ merge_dictionaries(session_state, session_state_copy)
249
+
211
250
  if final_response is not None:
212
251
  response = final_response
213
252
  else:
214
253
  response = StepOutput(content=content)
215
254
  else:
216
- # Execute function directly with StepInput
217
- result = self.active_executor(step_input) # type: ignore
255
+ # Execute function with signature inspection for session_state support
256
+ session_state_copy = copy(session_state) if session_state else None
257
+ result = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
258
+
259
+ # Merge session_state changes back
260
+ if session_state_copy and session_state:
261
+ merge_dictionaries(session_state, session_state_copy)
218
262
 
219
263
  # If function returns StepOutput, use it directly
220
264
  if isinstance(result, StepOutput):
@@ -291,6 +335,19 @@ class Step:
291
335
 
292
336
  return StepOutput(content=f"Step {self.name} failed but skipped", success=False)
293
337
 
338
+ def _function_has_session_state_param(self) -> bool:
339
+ """Check if the custom function has a session_state parameter"""
340
+ if self._executor_type != "function":
341
+ return False
342
+
343
+ try:
344
+ from inspect import signature
345
+
346
+ sig = signature(self.active_executor) # type: ignore
347
+ return "session_state" in sig.parameters
348
+ except Exception:
349
+ return False
350
+
294
351
  def execute_stream(
295
352
  self,
296
353
  step_input: StepInput,
@@ -338,8 +395,10 @@ class Step:
338
395
  if inspect.isgeneratorfunction(self.active_executor):
339
396
  log_debug("Function returned iterable, streaming events")
340
397
  content = ""
398
+ session_state_copy = copy(session_state) if session_state else None
341
399
  try:
342
- for event in self.active_executor(step_input): # type: ignore
400
+ iterator = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
401
+ for event in iterator: # type: ignore
343
402
  if (
344
403
  hasattr(event, "content")
345
404
  and event.content is not None
@@ -353,6 +412,11 @@ class Step:
353
412
  break
354
413
  else:
355
414
  yield event # type: ignore[misc]
415
+
416
+ # Merge session_state changes back
417
+ if session_state_copy and session_state:
418
+ merge_dictionaries(session_state, session_state_copy)
419
+
356
420
  if not final_response:
357
421
  final_response = StepOutput(content=content)
358
422
  except StopIteration as e:
@@ -360,7 +424,13 @@ class Step:
360
424
  final_response = e.value
361
425
 
362
426
  else:
363
- result = self.active_executor(step_input) # type: ignore
427
+ session_state_copy = copy(session_state) if session_state else None
428
+ result = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
429
+
430
+ # Merge session_state changes back
431
+ if session_state_copy and session_state:
432
+ merge_dictionaries(session_state, session_state_copy)
433
+
364
434
  if isinstance(result, StepOutput):
365
435
  final_response = result
366
436
  else:
@@ -429,7 +499,7 @@ class Step:
429
499
  if store_executor_outputs and workflow_run_response is not None:
430
500
  self._store_executor_response(workflow_run_response, active_executor_run_response) # type: ignore
431
501
 
432
- final_response = self._process_step_output(active_executor_run_response) # type: ignore
502
+ final_response = active_executor_run_response # type: ignore
433
503
 
434
504
  else:
435
505
  raise ValueError(f"Unsupported executor type: {self._executor_type}")
@@ -443,6 +513,7 @@ class Step:
443
513
  use_workflow_logger()
444
514
 
445
515
  # Yield the step output
516
+ final_response = self._process_step_output(final_response)
446
517
  yield final_response
447
518
 
448
519
  # Emit StepCompletedEvent
@@ -505,9 +576,13 @@ class Step:
505
576
  ):
506
577
  content = ""
507
578
  final_response = None
579
+ session_state_copy = copy(session_state) if session_state else None
508
580
  try:
509
581
  if inspect.isgeneratorfunction(self.active_executor):
510
- for chunk in self.active_executor(step_input): # type: ignore
582
+ iterator = self._call_custom_function(
583
+ self.active_executor, step_input, session_state_copy
584
+ ) # type: ignore
585
+ for chunk in iterator: # type: ignore
511
586
  if (
512
587
  hasattr(chunk, "content")
513
588
  and chunk.content is not None
@@ -520,7 +595,10 @@ class Step:
520
595
  final_response = chunk
521
596
  else:
522
597
  if inspect.isasyncgenfunction(self.active_executor):
523
- async for chunk in self.active_executor(step_input): # type: ignore
598
+ iterator = await self._acall_custom_function(
599
+ self.active_executor, step_input, session_state_copy
600
+ ) # type: ignore
601
+ async for chunk in iterator: # type: ignore
524
602
  if (
525
603
  hasattr(chunk, "content")
526
604
  and chunk.content is not None
@@ -536,15 +614,26 @@ class Step:
536
614
  if hasattr(e, "value") and isinstance(e.value, StepOutput):
537
615
  final_response = e.value
538
616
 
617
+ # Merge session_state changes back
618
+ if session_state_copy and session_state:
619
+ merge_dictionaries(session_state, session_state_copy)
620
+
539
621
  if final_response is not None:
540
622
  response = final_response
541
623
  else:
542
624
  response = StepOutput(content=content)
543
625
  else:
626
+ session_state_copy = copy(session_state) if session_state else None
544
627
  if inspect.iscoroutinefunction(self.active_executor):
545
- result = await self.active_executor(step_input) # type: ignore
628
+ result = await self._acall_custom_function(
629
+ self.active_executor, step_input, session_state_copy
630
+ ) # type: ignore
546
631
  else:
547
- result = self.active_executor(step_input) # type: ignore
632
+ result = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
633
+
634
+ # Merge session_state changes back
635
+ if session_state_copy and session_state:
636
+ merge_dictionaries(session_state, session_state_copy)
548
637
 
549
638
  # If function returns StepOutput, use it directly
550
639
  if isinstance(result, StepOutput):
@@ -662,11 +751,16 @@ class Step:
662
751
  log_debug(f"Executing async function executor for step: {self.name}")
663
752
  import inspect
664
753
 
754
+ session_state_copy = copy(session_state) if session_state else None
755
+
665
756
  # Check if the function is an async generator
666
757
  if inspect.isasyncgenfunction(self.active_executor):
667
758
  content = ""
668
759
  # It's an async generator - iterate over it
669
- async for event in self.active_executor(step_input): # type: ignore
760
+ iterator = await self._acall_custom_function(
761
+ self.active_executor, step_input, session_state_copy
762
+ ) # type: ignore
763
+ async for event in iterator: # type: ignore
670
764
  if (
671
765
  hasattr(event, "content")
672
766
  and event.content is not None
@@ -684,7 +778,7 @@ class Step:
684
778
  final_response = StepOutput(content=content)
685
779
  elif inspect.iscoroutinefunction(self.active_executor):
686
780
  # It's a regular async function - await it
687
- result = await self.active_executor(step_input) # type: ignore
781
+ result = await self._acall_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
688
782
  if isinstance(result, StepOutput):
689
783
  final_response = result
690
784
  else:
@@ -692,7 +786,8 @@ class Step:
692
786
  elif inspect.isgeneratorfunction(self.active_executor):
693
787
  content = ""
694
788
  # It's a regular generator function - iterate over it
695
- for event in self.active_executor(step_input): # type: ignore
789
+ iterator = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
790
+ for event in iterator: # type: ignore
696
791
  if (
697
792
  hasattr(event, "content")
698
793
  and event.content is not None
@@ -710,11 +805,15 @@ class Step:
710
805
  final_response = StepOutput(content=content)
711
806
  else:
712
807
  # It's a regular function - call it directly
713
- result = self.active_executor(step_input) # type: ignore
808
+ result = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
714
809
  if isinstance(result, StepOutput):
715
810
  final_response = result
716
811
  else:
717
812
  final_response = StepOutput(content=str(result))
813
+
814
+ # Merge session_state changes back
815
+ if session_state_copy and session_state:
816
+ merge_dictionaries(session_state, session_state_copy)
718
817
  else:
719
818
  # For agents and teams, prepare message with context
720
819
  message = self._prepare_message(
@@ -767,7 +866,6 @@ class Step:
767
866
 
768
867
  active_executor_run_response = None
769
868
  async for event in response_stream:
770
- log_debug(f"Received async event from agent: {type(event).__name__}")
771
869
  if isinstance(event, RunOutput) or isinstance(event, TeamRunOutput):
772
870
  active_executor_run_response = event
773
871
  break
@@ -779,7 +877,7 @@ class Step:
779
877
  if store_executor_outputs and workflow_run_response is not None:
780
878
  self._store_executor_response(workflow_run_response, active_executor_run_response) # type: ignore
781
879
 
782
- final_response = self._process_step_output(active_executor_run_response) # type: ignore
880
+ final_response = active_executor_run_response # type: ignore
783
881
  else:
784
882
  raise ValueError(f"Unsupported executor type: {self._executor_type}")
785
883
 
@@ -791,6 +889,7 @@ class Step:
791
889
  use_workflow_logger()
792
890
 
793
891
  # Yield the final response
892
+ final_response = self._process_step_output(final_response)
794
893
  yield final_response
795
894
 
796
895
  if stream_intermediate_steps and workflow_run_response: