agno 2.0.3__py3-none-any.whl → 2.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +162 -86
- agno/db/dynamo/dynamo.py +8 -0
- agno/db/firestore/firestore.py +8 -1
- agno/db/gcs_json/gcs_json_db.py +9 -0
- agno/db/json/json_db.py +8 -0
- agno/db/mongo/mongo.py +10 -1
- agno/db/mysql/mysql.py +10 -0
- agno/db/postgres/postgres.py +16 -8
- agno/db/redis/redis.py +6 -0
- agno/db/singlestore/schemas.py +1 -1
- agno/db/singlestore/singlestore.py +8 -1
- agno/db/sqlite/sqlite.py +9 -1
- agno/db/utils.py +14 -0
- agno/knowledge/knowledge.py +91 -65
- agno/models/base.py +2 -2
- agno/models/openai/chat.py +3 -0
- agno/models/openai/responses.py +6 -0
- agno/models/response.py +5 -0
- agno/models/siliconflow/__init__.py +5 -0
- agno/models/siliconflow/siliconflow.py +25 -0
- agno/os/app.py +4 -1
- agno/os/auth.py +24 -14
- agno/os/router.py +128 -55
- agno/os/routers/evals/utils.py +9 -9
- agno/os/routers/health.py +26 -0
- agno/os/routers/knowledge/knowledge.py +11 -11
- agno/os/routers/session/session.py +24 -8
- agno/os/schema.py +8 -2
- agno/run/workflow.py +64 -10
- agno/session/team.py +1 -0
- agno/team/team.py +192 -92
- agno/tools/mem0.py +11 -17
- agno/tools/memory.py +34 -6
- agno/utils/common.py +90 -1
- agno/utils/streamlit.py +14 -8
- agno/vectordb/chroma/chromadb.py +8 -2
- agno/workflow/step.py +111 -13
- agno/workflow/workflow.py +16 -13
- {agno-2.0.3.dist-info → agno-2.0.4.dist-info}/METADATA +1 -1
- {agno-2.0.3.dist-info → agno-2.0.4.dist-info}/RECORD +43 -40
- {agno-2.0.3.dist-info → agno-2.0.4.dist-info}/WHEEL +0 -0
- {agno-2.0.3.dist-info → agno-2.0.4.dist-info}/licenses/LICENSE +0 -0
- {agno-2.0.3.dist-info → agno-2.0.4.dist-info}/top_level.txt +0 -0
agno/tools/memory.py
CHANGED
|
@@ -95,13 +95,28 @@ class MemoryTools(Toolkit):
|
|
|
95
95
|
|
|
96
96
|
def get_memories(self, session_state: Dict[str, Any]) -> str:
|
|
97
97
|
"""
|
|
98
|
-
Use this tool to get a list of memories from the database.
|
|
98
|
+
Use this tool to get a list of memories for the current user from the database.
|
|
99
99
|
"""
|
|
100
100
|
try:
|
|
101
101
|
# Get user info from session state
|
|
102
102
|
user_id = session_state.get("current_user_id") if session_state else None
|
|
103
103
|
|
|
104
104
|
memories = self.db.get_user_memories(user_id=user_id)
|
|
105
|
+
|
|
106
|
+
# Store the result in session state for analysis
|
|
107
|
+
if session_state is None:
|
|
108
|
+
session_state = {}
|
|
109
|
+
if "memory_operations" not in session_state:
|
|
110
|
+
session_state["memory_operations"] = []
|
|
111
|
+
|
|
112
|
+
operation_result = {
|
|
113
|
+
"operation": "get_memories",
|
|
114
|
+
"success": True,
|
|
115
|
+
"memories": [memory.to_dict() for memory in memories], # type: ignore
|
|
116
|
+
"error": None,
|
|
117
|
+
}
|
|
118
|
+
session_state["memory_operations"].append(operation_result)
|
|
119
|
+
|
|
105
120
|
return json.dumps([memory.to_dict() for memory in memories], indent=2) # type: ignore
|
|
106
121
|
except Exception as e:
|
|
107
122
|
log_error(f"Error getting memories: {e}")
|
|
@@ -328,19 +343,23 @@ class MemoryTools(Toolkit):
|
|
|
328
343
|
- Purpose: A scratchpad for planning memory operations, brainstorming memory content, and refining your approach. You never reveal your "Think" content to the user.
|
|
329
344
|
- Usage: Call `think` whenever you need to figure out what memory operations to perform, analyze requirements, or decide on strategy.
|
|
330
345
|
|
|
331
|
-
2. **
|
|
346
|
+
2. **Get Memories**
|
|
347
|
+
- Purpose: Retrieves a list of memories from the database for the current user.
|
|
348
|
+
- Usage: Call `get_memories` when you need to retrieve memories for the current user.
|
|
349
|
+
|
|
350
|
+
3. **Add Memory**
|
|
332
351
|
- Purpose: Creates new memories in the database with specified content and metadata.
|
|
333
352
|
- Usage: Call `add_memory` with memory content and optional topics when you need to store new information.
|
|
334
353
|
|
|
335
|
-
|
|
354
|
+
4. **Update Memory**
|
|
336
355
|
- Purpose: Modifies existing memories in the database by memory ID.
|
|
337
356
|
- Usage: Call `update_memory` with a memory ID and the fields you want to change. Only specify the fields that need updating.
|
|
338
357
|
|
|
339
|
-
|
|
358
|
+
5. **Delete Memory**
|
|
340
359
|
- Purpose: Removes memories from the database by memory ID.
|
|
341
360
|
- Usage: Call `delete_memory` with a memory ID when a memory is no longer needed or requested to be removed.
|
|
342
361
|
|
|
343
|
-
|
|
362
|
+
6. **Analyze**
|
|
344
363
|
- Purpose: Evaluate whether the memory operations results are correct and sufficient. If not, go back to "Think" or use memory operations with refined parameters.
|
|
345
364
|
- Usage: Call `analyze` after performing memory operations to verify:
|
|
346
365
|
- Success: Did the operation complete successfully?
|
|
@@ -387,5 +406,14 @@ class MemoryTools(Toolkit):
|
|
|
387
406
|
Delete Memory: memory_id="work_schedule_memory_id"
|
|
388
407
|
Analyze: Successfully deleted the outdated work schedule memory. The old information won't interfere with future scheduling requests.
|
|
389
408
|
|
|
390
|
-
Final Answer: I've removed your old work schedule information. Feel free to share your new schedule when you're ready, and I'll store the updated information
|
|
409
|
+
Final Answer: I've removed your old work schedule information. Feel free to share your new schedule when you're ready, and I'll store the updated information.
|
|
410
|
+
|
|
411
|
+
#### Example 4: Retrieving Memories
|
|
412
|
+
|
|
413
|
+
User: What have you remembered about me?
|
|
414
|
+
Think: The user wants to retrieve memories about themselves. I should use the get_memories tool to retrieve the memories.
|
|
415
|
+
Get Memories:
|
|
416
|
+
Analyze: Successfully retrieved the memories about the user. The memories are relevant to the user's preferences and activities.
|
|
417
|
+
|
|
418
|
+
Final Answer: I've retrieved the memories about you. You like to hike in the mountains on weekends and travel to new places and experience different cultures. You are planning to travel to Africa in December.\
|
|
391
419
|
""")
|
agno/utils/common.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from dataclasses import asdict
|
|
2
|
-
from typing import Any, List, Optional, Type
|
|
2
|
+
from typing import Any, List, Optional, Set, Type, Union, get_type_hints
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
def isinstanceany(obj: Any, class_list: List[Type]) -> bool:
|
|
@@ -41,3 +41,92 @@ def nested_model_dump(value):
|
|
|
41
41
|
elif isinstance(value, list):
|
|
42
42
|
return [nested_model_dump(item) for item in value]
|
|
43
43
|
return value
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def is_typed_dict(cls: Type[Any]) -> bool:
    """Return ``True`` when ``cls`` structurally looks like a ``TypedDict`` class.

    The check is duck-typed: ``cls`` must expose all of ``__annotations__``,
    ``__total__``, ``__required_keys__`` and ``__optional_keys__``.
    """
    marker_attributes = ("__annotations__", "__total__", "__required_keys__", "__optional_keys__")
    return all(hasattr(cls, marker) for marker in marker_attributes)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def check_type_compatibility(value: Any, expected_type: Type) -> bool:
    """Perform a lightweight check that ``value`` conforms to ``expected_type``.

    Rules, applied in order:
      * ``Any`` accepts every value, including ``None``.
      * ``Union[...]`` / ``Optional[...]`` / PEP 604 ``X | Y``: compatible if
        ``value`` is compatible with at least one member (checked recursively).
      * ``None`` values: compatible only with ``None`` / ``NoneType``.
      * ``list`` / ``List[T]``: ``value`` must be a list; when ``T`` is given,
        every element is checked recursively against ``T``.
      * ``float``: accepts ``float`` and ``int`` (but not ``bool``), following the
        PEP 484 numeric tower. This also matches how ``json.loads`` decodes whole
        numbers.
      * ``str`` / ``int`` / ``bool``: plain ``isinstance`` check.
      * anything else: ``isinstance`` when supported. Types that reject instance
        checks (parameterized generics such as ``Dict[str, int]``, ``Literal``,
        ``TypedDict`` ...) are accepted leniently.

    Args:
        value: The value to check.
        expected_type: The annotation to check against (as returned by
            ``typing.get_type_hints``).

    Returns:
        ``True`` if ``value`` is considered compatible, ``False`` otherwise.
    """
    import types
    from typing import get_args, get_origin

    # ``Any`` must be handled before the ``None`` branch, otherwise ``None`` is
    # wrongly rejected for ``Any`` fields.
    if expected_type is Any:
        return True

    origin = get_origin(expected_type)

    # ``types.UnionType`` (PEP 604 ``X | Y``) only exists on Python 3.10+.
    pep604_union = getattr(types, "UnionType", None)
    if origin is Union or (pep604_union is not None and origin is pep604_union):
        return any(check_type_compatibility(value, arg) for arg in get_args(expected_type))

    if value is None:
        return expected_type is None or expected_type is type(None)

    if origin is list or expected_type is list:
        if not isinstance(value, list):
            return False
        element_args = get_args(expected_type)
        if origin is list and element_args:
            element_type = element_args[0]
            return all(check_type_compatibility(item, element_type) for item in value)
        return True

    if expected_type is float:
        # ``bool`` subclasses ``int`` but is not a meaningful float value.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if expected_type in (str, int, bool):
        return isinstance(value, expected_type)

    try:
        return isinstance(value, expected_type)
    except TypeError:
        # Parameterized generics and similar special forms cannot be used with
        # isinstance(); this basic checker accepts them rather than failing.
        return True
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def validate_typed_dict(data: dict, schema_cls) -> dict:
    """Validate ``data`` against the ``TypedDict`` class ``schema_cls``.

    Checks performed, in order:
      1. ``data`` must be a ``dict``.
      2. The schema's type hints must be resolvable with ``typing.get_type_hints``.
      3. Every key in ``schema_cls.__required_keys__`` must be present.
      4. No key outside ``__required_keys__ | __optional_keys__`` may be present.
      5. Each provided value must pass ``check_type_compatibility`` against its hint.

    Args:
        data: The candidate input mapping.
        schema_cls: A ``TypedDict`` class.

    Returns:
        A new dict containing the validated key/value pairs. Values are not
        converted or coerced.

    Raises:
        ValueError: If any check fails. When the type hints cannot be resolved,
            the underlying exception is chained as ``__cause__``.
    """
    if not isinstance(data, dict):
        raise ValueError(f"Expected dict for TypedDict {schema_cls.__name__}, got {type(data)}")

    # Get type hints from the TypedDict
    try:
        type_hints = get_type_hints(schema_cls)
    except Exception as e:
        # Chain the original error (e.g. a NameError from an unresolved forward
        # reference) so the root cause stays visible in tracebacks.
        raise ValueError(f"Could not get type hints for TypedDict {schema_cls.__name__}: {e}") from e

    # Get required and optional keys
    required_keys: Set[str] = getattr(schema_cls, "__required_keys__", set())
    optional_keys: Set[str] = getattr(schema_cls, "__optional_keys__", set())
    all_keys = required_keys | optional_keys

    # Check for missing required fields
    missing_required = required_keys - set(data.keys())
    if missing_required:
        raise ValueError(f"Missing required fields in TypedDict {schema_cls.__name__}: {missing_required}")

    # Check for unexpected fields
    unexpected_fields = set(data.keys()) - all_keys
    if unexpected_fields:
        raise ValueError(f"Unexpected fields in TypedDict {schema_cls.__name__}: {unexpected_fields}")

    # Basic type checking for provided fields
    validated_data = {}
    for field_name, value in data.items():
        if field_name in type_hints:
            expected_type = type_hints[field_name]

            # Handle simple type checking
            if not check_type_compatibility(value, expected_type):
                raise ValueError(
                    f"Field '{field_name}' expected type {expected_type}, got {type(value)} with value {value}"
                )

        validated_data[field_name] = value

    return validated_data
|
agno/utils/streamlit.py
CHANGED
|
@@ -1,14 +1,20 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
2
|
from typing import Any, Callable, Dict, List, Optional
|
|
3
3
|
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
from agno.
|
|
7
|
-
from agno.
|
|
8
|
-
from agno.models.
|
|
9
|
-
from agno.models.
|
|
10
|
-
from agno.
|
|
11
|
-
|
|
4
|
+
try:
|
|
5
|
+
from agno.agent import Agent
|
|
6
|
+
from agno.db.base import SessionType
|
|
7
|
+
from agno.models.anthropic import Claude
|
|
8
|
+
from agno.models.google import Gemini
|
|
9
|
+
from agno.models.openai import OpenAIChat
|
|
10
|
+
from agno.utils.log import logger
|
|
11
|
+
except ImportError:
|
|
12
|
+
raise ImportError("`agno` not installed. Please install using `pip install agno`")
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import streamlit as st
|
|
16
|
+
except ImportError:
|
|
17
|
+
raise ImportError("`streamlit` not installed. Please install using `pip install streamlit`")
|
|
12
18
|
|
|
13
19
|
|
|
14
20
|
def add_message(role: str, content: str, tool_calls: Optional[List[Dict[str, Any]]] = None) -> None:
|
agno/vectordb/chroma/chromadb.py
CHANGED
|
@@ -766,9 +766,15 @@ class ChromaDb(VectorDb):
|
|
|
766
766
|
updated_metadatas.append(updated_meta)
|
|
767
767
|
|
|
768
768
|
# Update the documents
|
|
769
|
+
# Filter out None values from metadata as ChromaDB doesn't accept them
|
|
770
|
+
cleaned_metadatas = []
|
|
771
|
+
for meta in updated_metadatas:
|
|
772
|
+
cleaned_meta = {k: v for k, v in meta.items() if v is not None}
|
|
773
|
+
cleaned_metadatas.append(cleaned_meta)
|
|
774
|
+
|
|
769
775
|
# Convert to the expected type for ChromaDB
|
|
770
|
-
chroma_metadatas = cast(List[Mapping[str, Union[str, int, float, bool
|
|
771
|
-
collection.update(ids=ids, metadatas=chroma_metadatas)
|
|
776
|
+
chroma_metadatas = cast(List[Mapping[str, Union[str, int, float, bool]]], cleaned_metadatas)
|
|
777
|
+
collection.update(ids=ids, metadatas=chroma_metadatas) # type: ignore
|
|
772
778
|
logger.debug(f"Updated metadata for {len(ids)} documents with content_id: {content_id}")
|
|
773
779
|
|
|
774
780
|
except TypeError as te:
|
agno/workflow/step.py
CHANGED
|
@@ -164,6 +164,38 @@ class Step:
|
|
|
164
164
|
return response.metrics
|
|
165
165
|
return None
|
|
166
166
|
|
|
167
|
+
def _call_custom_function(
    self,
    func: Callable,
    step_input: StepInput,
    session_state: Optional[Dict[str, Any]] = None,
) -> Any:
    """Invoke a custom step function, forwarding ``session_state`` when it is accepted.

    ``session_state`` is forwarded only when it is not ``None`` and
    ``self._function_has_session_state_param()`` reports that the step's executor
    declares a parameter *named* ``session_state``. Because detection is by name,
    the value is passed by keyword so it binds to that parameter wherever it sits
    in the signature, including keyword-only parameters. Passing it positionally
    would bind it to the second positional parameter instead: for
    ``def f(step_input, extra=None, session_state=None)`` it would land in ``extra``.

    Args:
        func: The callable to invoke (callers pass ``self.active_executor``).
        step_input: The input for this step, always passed as the first argument.
        session_state: Optional session state to forward.

    Returns:
        Whatever ``func`` returns. For generator functions this is the generator
        object.
    """
    if session_state is not None and self._function_has_session_state_param():
        return func(step_input, session_state=session_state)
    return func(step_input)
|
|
178
|
+
|
|
179
|
+
async def _acall_custom_function(
    self,
    func: Callable,
    step_input: StepInput,
    session_state: Optional[Dict[str, Any]] = None,
) -> Any:
    """Async counterpart of ``_call_custom_function``.

    ``session_state`` is forwarded, by keyword, only when it is not ``None`` and
    ``self._function_has_session_state_param()`` reports a parameter named
    ``session_state``. Passing it by keyword keeps the binding consistent with
    that name-based detection.

    If ``func`` is an async generator function, the async generator object is
    returned without awaiting it, so the caller can iterate it with
    ``async for``. Otherwise ``func`` is called and its awaitable result is
    awaited.

    Args:
        func: The callable to invoke (callers pass ``self.active_executor``).
        step_input: The input for this step, always passed as the first argument.
        session_state: Optional session state to forward.

    Returns:
        The async generator object, or the awaited result of ``func``.
    """
    import inspect

    forwarded: Dict[str, Any] = {}
    if session_state is not None and self._function_has_session_state_param():
        forwarded["session_state"] = session_state

    if inspect.isasyncgenfunction(func):
        # Calling an async generator function yields an async iterator; it must not be awaited.
        return func(step_input, **forwarded)
    return await func(step_input, **forwarded)
|
|
198
|
+
|
|
167
199
|
def execute(
|
|
168
200
|
self,
|
|
169
201
|
step_input: StepInput,
|
|
@@ -191,8 +223,11 @@ class Step:
|
|
|
191
223
|
if inspect.isgeneratorfunction(self.active_executor):
|
|
192
224
|
content = ""
|
|
193
225
|
final_response = None
|
|
226
|
+
session_state_copy = copy(session_state) if session_state else None
|
|
194
227
|
try:
|
|
195
|
-
for chunk in self.
|
|
228
|
+
for chunk in self._call_custom_function(
|
|
229
|
+
self.active_executor, step_input, session_state_copy
|
|
230
|
+
): # type: ignore
|
|
196
231
|
if (
|
|
197
232
|
hasattr(chunk, "content")
|
|
198
233
|
and chunk.content is not None
|
|
@@ -208,13 +243,22 @@ class Step:
|
|
|
208
243
|
if hasattr(e, "value") and isinstance(e.value, StepOutput):
|
|
209
244
|
final_response = e.value
|
|
210
245
|
|
|
246
|
+
# Merge session_state changes back
|
|
247
|
+
if session_state_copy and session_state:
|
|
248
|
+
merge_dictionaries(session_state, session_state_copy)
|
|
249
|
+
|
|
211
250
|
if final_response is not None:
|
|
212
251
|
response = final_response
|
|
213
252
|
else:
|
|
214
253
|
response = StepOutput(content=content)
|
|
215
254
|
else:
|
|
216
|
-
# Execute function
|
|
217
|
-
|
|
255
|
+
# Execute function with signature inspection for session_state support
|
|
256
|
+
session_state_copy = copy(session_state) if session_state else None
|
|
257
|
+
result = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
|
|
258
|
+
|
|
259
|
+
# Merge session_state changes back
|
|
260
|
+
if session_state_copy and session_state:
|
|
261
|
+
merge_dictionaries(session_state, session_state_copy)
|
|
218
262
|
|
|
219
263
|
# If function returns StepOutput, use it directly
|
|
220
264
|
if isinstance(result, StepOutput):
|
|
@@ -291,6 +335,19 @@ class Step:
|
|
|
291
335
|
|
|
292
336
|
return StepOutput(content=f"Step {self.name} failed but skipped", success=False)
|
|
293
337
|
|
|
338
|
+
def _function_has_session_state_param(self) -> bool:
    """Report whether this step's custom function declares a ``session_state`` parameter.

    Only steps whose ``_executor_type`` equals ``"function"`` can qualify; the
    parameter is looked up by name in ``self.active_executor``'s signature.
    Any error raised while introspecting the signature (for example, for an
    object without an inspectable signature) is treated as "not declared".
    """
    is_function_executor = self._executor_type == "function"
    if not is_function_executor:
        return False

    try:
        from inspect import signature

        parameter_names = signature(self.active_executor).parameters  # type: ignore
        declared = "session_state" in parameter_names
    except Exception:
        declared = False
    return declared
|
|
350
|
+
|
|
294
351
|
def execute_stream(
|
|
295
352
|
self,
|
|
296
353
|
step_input: StepInput,
|
|
@@ -338,8 +395,10 @@ class Step:
|
|
|
338
395
|
if inspect.isgeneratorfunction(self.active_executor):
|
|
339
396
|
log_debug("Function returned iterable, streaming events")
|
|
340
397
|
content = ""
|
|
398
|
+
session_state_copy = copy(session_state) if session_state else None
|
|
341
399
|
try:
|
|
342
|
-
|
|
400
|
+
iterator = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
|
|
401
|
+
for event in iterator: # type: ignore
|
|
343
402
|
if (
|
|
344
403
|
hasattr(event, "content")
|
|
345
404
|
and event.content is not None
|
|
@@ -353,6 +412,11 @@ class Step:
|
|
|
353
412
|
break
|
|
354
413
|
else:
|
|
355
414
|
yield event # type: ignore[misc]
|
|
415
|
+
|
|
416
|
+
# Merge session_state changes back
|
|
417
|
+
if session_state_copy and session_state:
|
|
418
|
+
merge_dictionaries(session_state, session_state_copy)
|
|
419
|
+
|
|
356
420
|
if not final_response:
|
|
357
421
|
final_response = StepOutput(content=content)
|
|
358
422
|
except StopIteration as e:
|
|
@@ -360,7 +424,13 @@ class Step:
|
|
|
360
424
|
final_response = e.value
|
|
361
425
|
|
|
362
426
|
else:
|
|
363
|
-
|
|
427
|
+
session_state_copy = copy(session_state) if session_state else None
|
|
428
|
+
result = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
|
|
429
|
+
|
|
430
|
+
# Merge session_state changes back
|
|
431
|
+
if session_state_copy and session_state:
|
|
432
|
+
merge_dictionaries(session_state, session_state_copy)
|
|
433
|
+
|
|
364
434
|
if isinstance(result, StepOutput):
|
|
365
435
|
final_response = result
|
|
366
436
|
else:
|
|
@@ -505,9 +575,13 @@ class Step:
|
|
|
505
575
|
):
|
|
506
576
|
content = ""
|
|
507
577
|
final_response = None
|
|
578
|
+
session_state_copy = copy(session_state) if session_state else None
|
|
508
579
|
try:
|
|
509
580
|
if inspect.isgeneratorfunction(self.active_executor):
|
|
510
|
-
|
|
581
|
+
iterator = self._call_custom_function(
|
|
582
|
+
self.active_executor, step_input, session_state_copy
|
|
583
|
+
) # type: ignore
|
|
584
|
+
for chunk in iterator: # type: ignore
|
|
511
585
|
if (
|
|
512
586
|
hasattr(chunk, "content")
|
|
513
587
|
and chunk.content is not None
|
|
@@ -520,7 +594,10 @@ class Step:
|
|
|
520
594
|
final_response = chunk
|
|
521
595
|
else:
|
|
522
596
|
if inspect.isasyncgenfunction(self.active_executor):
|
|
523
|
-
|
|
597
|
+
iterator = await self._acall_custom_function(
|
|
598
|
+
self.active_executor, step_input, session_state_copy
|
|
599
|
+
) # type: ignore
|
|
600
|
+
async for chunk in iterator: # type: ignore
|
|
524
601
|
if (
|
|
525
602
|
hasattr(chunk, "content")
|
|
526
603
|
and chunk.content is not None
|
|
@@ -536,15 +613,26 @@ class Step:
|
|
|
536
613
|
if hasattr(e, "value") and isinstance(e.value, StepOutput):
|
|
537
614
|
final_response = e.value
|
|
538
615
|
|
|
616
|
+
# Merge session_state changes back
|
|
617
|
+
if session_state_copy and session_state:
|
|
618
|
+
merge_dictionaries(session_state, session_state_copy)
|
|
619
|
+
|
|
539
620
|
if final_response is not None:
|
|
540
621
|
response = final_response
|
|
541
622
|
else:
|
|
542
623
|
response = StepOutput(content=content)
|
|
543
624
|
else:
|
|
625
|
+
session_state_copy = copy(session_state) if session_state else None
|
|
544
626
|
if inspect.iscoroutinefunction(self.active_executor):
|
|
545
|
-
result = await self.
|
|
627
|
+
result = await self._acall_custom_function(
|
|
628
|
+
self.active_executor, step_input, session_state_copy
|
|
629
|
+
) # type: ignore
|
|
546
630
|
else:
|
|
547
|
-
result = self.active_executor
|
|
631
|
+
result = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
|
|
632
|
+
|
|
633
|
+
# Merge session_state changes back
|
|
634
|
+
if session_state_copy and session_state:
|
|
635
|
+
merge_dictionaries(session_state, session_state_copy)
|
|
548
636
|
|
|
549
637
|
# If function returns StepOutput, use it directly
|
|
550
638
|
if isinstance(result, StepOutput):
|
|
@@ -662,11 +750,16 @@ class Step:
|
|
|
662
750
|
log_debug(f"Executing async function executor for step: {self.name}")
|
|
663
751
|
import inspect
|
|
664
752
|
|
|
753
|
+
session_state_copy = copy(session_state) if session_state else None
|
|
754
|
+
|
|
665
755
|
# Check if the function is an async generator
|
|
666
756
|
if inspect.isasyncgenfunction(self.active_executor):
|
|
667
757
|
content = ""
|
|
668
758
|
# It's an async generator - iterate over it
|
|
669
|
-
|
|
759
|
+
iterator = await self._acall_custom_function(
|
|
760
|
+
self.active_executor, step_input, session_state_copy
|
|
761
|
+
) # type: ignore
|
|
762
|
+
async for event in iterator: # type: ignore
|
|
670
763
|
if (
|
|
671
764
|
hasattr(event, "content")
|
|
672
765
|
and event.content is not None
|
|
@@ -684,7 +777,7 @@ class Step:
|
|
|
684
777
|
final_response = StepOutput(content=content)
|
|
685
778
|
elif inspect.iscoroutinefunction(self.active_executor):
|
|
686
779
|
# It's a regular async function - await it
|
|
687
|
-
result = await self.active_executor
|
|
780
|
+
result = await self._acall_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
|
|
688
781
|
if isinstance(result, StepOutput):
|
|
689
782
|
final_response = result
|
|
690
783
|
else:
|
|
@@ -692,7 +785,8 @@ class Step:
|
|
|
692
785
|
elif inspect.isgeneratorfunction(self.active_executor):
|
|
693
786
|
content = ""
|
|
694
787
|
# It's a regular generator function - iterate over it
|
|
695
|
-
|
|
788
|
+
iterator = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
|
|
789
|
+
for event in iterator: # type: ignore
|
|
696
790
|
if (
|
|
697
791
|
hasattr(event, "content")
|
|
698
792
|
and event.content is not None
|
|
@@ -710,11 +804,15 @@ class Step:
|
|
|
710
804
|
final_response = StepOutput(content=content)
|
|
711
805
|
else:
|
|
712
806
|
# It's a regular function - call it directly
|
|
713
|
-
result = self.active_executor
|
|
807
|
+
result = self._call_custom_function(self.active_executor, step_input, session_state_copy) # type: ignore
|
|
714
808
|
if isinstance(result, StepOutput):
|
|
715
809
|
final_response = result
|
|
716
810
|
else:
|
|
717
811
|
final_response = StepOutput(content=str(result))
|
|
812
|
+
|
|
813
|
+
# Merge session_state changes back
|
|
814
|
+
if session_state_copy and session_state:
|
|
815
|
+
merge_dictionaries(session_state, session_state_copy)
|
|
718
816
|
else:
|
|
719
817
|
# For agents and teams, prepare message with context
|
|
720
818
|
message = self._prepare_message(
|
agno/workflow/workflow.py
CHANGED
|
@@ -51,6 +51,7 @@ from agno.run.workflow import (
|
|
|
51
51
|
)
|
|
52
52
|
from agno.session.workflow import WorkflowSession
|
|
53
53
|
from agno.team.team import Team
|
|
54
|
+
from agno.utils.common import is_typed_dict, validate_typed_dict
|
|
54
55
|
from agno.utils.log import (
|
|
55
56
|
log_debug,
|
|
56
57
|
log_warning,
|
|
@@ -217,14 +218,18 @@ class Workflow:
|
|
|
217
218
|
|
|
218
219
|
def _validate_input(
|
|
219
220
|
self, input: Optional[Union[str, Dict[str, Any], List[Any], BaseModel, List[Message]]]
|
|
220
|
-
) -> Optional[BaseModel]:
|
|
221
|
+
) -> Optional[Union[str, List, Dict, Message, BaseModel]]:
|
|
221
222
|
"""Parse and validate input against input_schema if provided"""
|
|
222
223
|
if self.input_schema is None:
|
|
223
|
-
return
|
|
224
|
+
return input # Return input unchanged if no schema is set
|
|
224
225
|
|
|
225
226
|
if input is None:
|
|
226
227
|
raise ValueError("Input required when input_schema is set")
|
|
227
228
|
|
|
229
|
+
# Handle Message objects - extract content
|
|
230
|
+
if isinstance(input, Message):
|
|
231
|
+
input = input.content # type: ignore
|
|
232
|
+
|
|
228
233
|
# If input is a string, convert it to a dict
|
|
229
234
|
if isinstance(input, str):
|
|
230
235
|
import json
|
|
@@ -238,8 +243,6 @@ class Workflow:
|
|
|
238
243
|
if isinstance(input, BaseModel):
|
|
239
244
|
if isinstance(input, self.input_schema):
|
|
240
245
|
try:
|
|
241
|
-
# Re-validate to catch any field validation errors
|
|
242
|
-
input.model_validate(input.model_dump())
|
|
243
246
|
return input
|
|
244
247
|
except Exception as e:
|
|
245
248
|
raise ValueError(f"BaseModel validation failed: {str(e)}")
|
|
@@ -250,8 +253,13 @@ class Workflow:
|
|
|
250
253
|
# Case 2: Message is a dict
|
|
251
254
|
elif isinstance(input, dict):
|
|
252
255
|
try:
|
|
253
|
-
|
|
254
|
-
|
|
256
|
+
# Check if the schema is a TypedDict
|
|
257
|
+
if is_typed_dict(self.input_schema):
|
|
258
|
+
validated_dict = validate_typed_dict(input, self.input_schema)
|
|
259
|
+
return validated_dict
|
|
260
|
+
else:
|
|
261
|
+
validated_model = self.input_schema(**input)
|
|
262
|
+
return validated_model
|
|
255
263
|
except Exception as e:
|
|
256
264
|
raise ValueError(f"Failed to parse dict into {self.input_schema.__name__}: {str(e)}")
|
|
257
265
|
|
|
@@ -1924,10 +1932,7 @@ class Workflow:
|
|
|
1924
1932
|
) -> Union[WorkflowRunOutput, Iterator[WorkflowRunOutputEvent]]:
|
|
1925
1933
|
"""Execute the workflow synchronously with optional streaming"""
|
|
1926
1934
|
|
|
1927
|
-
|
|
1928
|
-
if validated_input is not None:
|
|
1929
|
-
input = validated_input
|
|
1930
|
-
|
|
1935
|
+
input = self._validate_input(input)
|
|
1931
1936
|
if background:
|
|
1932
1937
|
raise RuntimeError("Background execution is not supported for sync run()")
|
|
1933
1938
|
|
|
@@ -2059,9 +2064,7 @@ class Workflow:
|
|
|
2059
2064
|
) -> Union[WorkflowRunOutput, AsyncIterator[WorkflowRunOutputEvent]]:
|
|
2060
2065
|
"""Execute the workflow synchronously with optional streaming"""
|
|
2061
2066
|
|
|
2062
|
-
|
|
2063
|
-
if validated_input is not None:
|
|
2064
|
-
input = validated_input
|
|
2067
|
+
input = self._validate_input(input)
|
|
2065
2068
|
|
|
2066
2069
|
websocket_handler = None
|
|
2067
2070
|
if websocket:
|