praisonaiagents 0.0.96__tar.gz → 0.0.97__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/PKG-INFO +5 -1
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/__init__.py +2 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/agent/agent.py +46 -37
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/agents/agents.py +99 -6
- praisonaiagents-0.0.97/praisonaiagents/approval.py +263 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/knowledge/knowledge.py +27 -4
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/llm/llm.py +4 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/main.py +15 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/memory/memory.py +67 -6
- praisonaiagents-0.0.97/praisonaiagents/session.py +291 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/file_tools.py +5 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/python_tools.py +2 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/shell_tools.py +3 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents.egg-info/PKG-INFO +5 -1
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents.egg-info/SOURCES.txt +3 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents.egg-info/requires.txt +5 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/pyproject.toml +8 -1
- praisonaiagents-0.0.97/tests/test-graph-memory.py +135 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/README.md +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/agent/__init__.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/agent/image_agent.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/agents/__init__.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/agents/autoagents.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/knowledge/__init__.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/knowledge/chunking.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/llm/__init__.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/mcp/__init__.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/mcp/mcp.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/mcp/mcp_sse.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/process/__init__.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/process/process.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/task/__init__.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/task/task.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/__init__.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/arxiv_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/calculator_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/csv_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/duckdb_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/duckduckgo_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/excel_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/json_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/newspaper_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/pandas_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/spider_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/test.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/train/data/generatecot.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/wikipedia_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/xml_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/yaml_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents/tools/yfinance_tools.py +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents.egg-info/dependency_links.txt +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/praisonaiagents.egg-info/top_level.txt +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/setup.cfg +0 -0
- {praisonaiagents-0.0.96 → praisonaiagents-0.0.97}/tests/test.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: praisonaiagents
|
3
|
-
Version: 0.0.96
|
3
|
+
Version: 0.0.97
|
4
4
|
Summary: Praison AI agents for completing complex tasks with Self Reflection Agents
|
5
5
|
Author: Mervin Praison
|
6
6
|
Requires-Python: >=3.10
|
@@ -19,6 +19,9 @@ Requires-Dist: mem0ai>=0.1.0; extra == "knowledge"
|
|
19
19
|
Requires-Dist: chromadb>=1.0.0; extra == "knowledge"
|
20
20
|
Requires-Dist: markitdown[all]>=0.1.0; extra == "knowledge"
|
21
21
|
Requires-Dist: chonkie>=1.0.2; extra == "knowledge"
|
22
|
+
Provides-Extra: graph
|
23
|
+
Requires-Dist: mem0ai[graph]>=0.1.0; extra == "graph"
|
24
|
+
Requires-Dist: chromadb>=1.0.0; extra == "graph"
|
22
25
|
Provides-Extra: llm
|
23
26
|
Requires-Dist: litellm>=1.50.0; extra == "llm"
|
24
27
|
Requires-Dist: pydantic>=2.4.2; extra == "llm"
|
@@ -28,6 +31,7 @@ Requires-Dist: uvicorn>=0.34.0; extra == "api"
|
|
28
31
|
Provides-Extra: all
|
29
32
|
Requires-Dist: praisonaiagents[memory]; extra == "all"
|
30
33
|
Requires-Dist: praisonaiagents[knowledge]; extra == "all"
|
34
|
+
Requires-Dist: praisonaiagents[graph]; extra == "all"
|
31
35
|
Requires-Dist: praisonaiagents[llm]; extra == "all"
|
32
36
|
Requires-Dist: praisonaiagents[mcp]; extra == "all"
|
33
37
|
Requires-Dist: praisonaiagents[api]; extra == "all"
|
@@ -11,6 +11,7 @@ from .agents.autoagents import AutoAgents
|
|
11
11
|
from .knowledge.knowledge import Knowledge
|
12
12
|
from .knowledge.chunking import Chunking
|
13
13
|
from .mcp.mcp import MCP
|
14
|
+
from .session import Session
|
14
15
|
from .main import (
|
15
16
|
TaskOutput,
|
16
17
|
ReflectionOutput,
|
@@ -40,6 +41,7 @@ __all__ = [
|
|
40
41
|
'TaskOutput',
|
41
42
|
'ReflectionOutput',
|
42
43
|
'AutoAgents',
|
44
|
+
'Session',
|
43
45
|
'display_interaction',
|
44
46
|
'display_self_reflection',
|
45
47
|
'display_instruction',
|
@@ -16,7 +16,8 @@ from ..main import (
|
|
16
16
|
display_self_reflection,
|
17
17
|
ReflectionOutput,
|
18
18
|
client,
|
19
|
-
adisplay_instruction
|
19
|
+
adisplay_instruction,
|
20
|
+
approval_callback
|
20
21
|
)
|
21
22
|
import inspect
|
22
23
|
import uuid
|
@@ -570,6 +571,35 @@ Your Goal: {self.goal}
|
|
570
571
|
"""
|
571
572
|
logging.debug(f"{self.name} executing tool {function_name} with arguments: {arguments}")
|
572
573
|
|
574
|
+
# Check if approval is required for this tool
|
575
|
+
from ..approval import is_approval_required, console_approval_callback, get_risk_level, mark_approved, ApprovalDecision
|
576
|
+
if is_approval_required(function_name):
|
577
|
+
risk_level = get_risk_level(function_name)
|
578
|
+
logging.info(f"Tool {function_name} requires approval (risk level: {risk_level})")
|
579
|
+
|
580
|
+
# Use global approval callback or default console callback
|
581
|
+
callback = approval_callback or console_approval_callback
|
582
|
+
|
583
|
+
try:
|
584
|
+
decision = callback(function_name, arguments, risk_level)
|
585
|
+
if not decision.approved:
|
586
|
+
error_msg = f"Tool execution denied: {decision.reason}"
|
587
|
+
logging.warning(error_msg)
|
588
|
+
return {"error": error_msg, "approval_denied": True}
|
589
|
+
|
590
|
+
# Mark as approved in context to prevent double approval in decorator
|
591
|
+
mark_approved(function_name)
|
592
|
+
|
593
|
+
# Use modified arguments if provided
|
594
|
+
if decision.modified_args:
|
595
|
+
arguments = decision.modified_args
|
596
|
+
logging.info(f"Using modified arguments: {arguments}")
|
597
|
+
|
598
|
+
except Exception as e:
|
599
|
+
error_msg = f"Error during approval process: {str(e)}"
|
600
|
+
logging.error(error_msg)
|
601
|
+
return {"error": error_msg, "approval_error": True}
|
602
|
+
|
573
603
|
# Special handling for MCP tools
|
574
604
|
# Check if tools is an MCP instance with the requested function name
|
575
605
|
from ..mcp.mcp import MCP
|
@@ -982,43 +1012,7 @@ Your Goal: {self.goal}
|
|
982
1012
|
if not response:
|
983
1013
|
return None
|
984
1014
|
|
985
|
-
tool_calls = getattr(response.choices[0].message, 'tool_calls', None)
|
986
1015
|
response_text = response.choices[0].message.content.strip()
|
987
|
-
if tool_calls: ## TODO: Most likely this tool call is already called in _chat_completion, so maybe we can remove this.
|
988
|
-
messages.append({
|
989
|
-
"role": "assistant",
|
990
|
-
"content": response_text,
|
991
|
-
"tool_calls": tool_calls
|
992
|
-
})
|
993
|
-
|
994
|
-
for tool_call in tool_calls:
|
995
|
-
function_name = tool_call.function.name
|
996
|
-
arguments = json.loads(tool_call.function.arguments)
|
997
|
-
|
998
|
-
if self.verbose:
|
999
|
-
display_tool_call(f"Agent {self.name} is calling function '{function_name}' with arguments: {arguments}", console=self.console)
|
1000
|
-
|
1001
|
-
tool_result = self.execute_tool(function_name, arguments)
|
1002
|
-
|
1003
|
-
if tool_result:
|
1004
|
-
if self.verbose:
|
1005
|
-
display_tool_call(f"Function '{function_name}' returned: {tool_result}", console=self.console)
|
1006
|
-
messages.append({
|
1007
|
-
"role": "tool",
|
1008
|
-
"tool_call_id": tool_call.id,
|
1009
|
-
"content": json.dumps(tool_result)
|
1010
|
-
})
|
1011
|
-
else:
|
1012
|
-
messages.append({
|
1013
|
-
"role": "tool",
|
1014
|
-
"tool_call_id": tool_call.id,
|
1015
|
-
"content": "Function returned an empty output"
|
1016
|
-
})
|
1017
|
-
|
1018
|
-
response = self._chat_completion(messages, temperature=temperature, stream=stream)
|
1019
|
-
if not response:
|
1020
|
-
return None
|
1021
|
-
response_text = response.choices[0].message.content.strip()
|
1022
1016
|
|
1023
1017
|
# Handle output_json or output_pydantic if specified
|
1024
1018
|
if output_json or output_pydantic:
|
@@ -1418,6 +1412,21 @@ Your Goal: {self.goal}
|
|
1418
1412
|
"""Async version of execute_tool"""
|
1419
1413
|
try:
|
1420
1414
|
logging.info(f"Executing async tool: {function_name} with arguments: {arguments}")
|
1415
|
+
|
1416
|
+
# Check if approval is required for this tool
|
1417
|
+
from ..approval import is_approval_required, request_approval
|
1418
|
+
if is_approval_required(function_name):
|
1419
|
+
decision = await request_approval(function_name, arguments)
|
1420
|
+
if not decision.approved:
|
1421
|
+
error_msg = f"Tool execution denied: {decision.reason}"
|
1422
|
+
logging.warning(error_msg)
|
1423
|
+
return {"error": error_msg, "approval_denied": True}
|
1424
|
+
|
1425
|
+
# Use modified arguments if provided
|
1426
|
+
if decision.modified_args:
|
1427
|
+
arguments = decision.modified_args
|
1428
|
+
logging.info(f"Using modified arguments: {arguments}")
|
1429
|
+
|
1421
1430
|
# Try to find the function in the agent's tools list first
|
1422
1431
|
func = None
|
1423
1432
|
for tool in self.tools:
|
@@ -63,7 +63,6 @@ def process_task_context(context_item, verbose=0, user_id=None):
|
|
63
63
|
"""
|
64
64
|
Process a single context item for task execution.
|
65
65
|
This helper function avoids code duplication between async and sync execution methods.
|
66
|
-
|
67
66
|
Args:
|
68
67
|
context_item: The context item to process (can be string, list, task object, or dict)
|
69
68
|
verbose: Verbosity level for logging
|
@@ -203,7 +202,6 @@ class PraisonAIAgents:
|
|
203
202
|
mem_cfg = memory_config
|
204
203
|
if not mem_cfg:
|
205
204
|
mem_cfg = next((t.config.get('memory_config') for t in tasks if hasattr(t, 'config') and t.config), None)
|
206
|
-
|
207
205
|
# Set default memory config if none provided
|
208
206
|
if not mem_cfg:
|
209
207
|
mem_cfg = {
|
@@ -215,7 +213,6 @@ class PraisonAIAgents:
|
|
215
213
|
},
|
216
214
|
"rag_db_path": "./.praison/chroma_db"
|
217
215
|
}
|
218
|
-
|
219
216
|
# Add embedder config if provided
|
220
217
|
if embedder:
|
221
218
|
if isinstance(embedder, dict):
|
@@ -231,17 +228,14 @@ class PraisonAIAgents:
|
|
231
228
|
self.shared_memory = Memory(config=mem_cfg, verbose=verbose)
|
232
229
|
if verbose >= 5:
|
233
230
|
logger.info("Initialized shared memory for PraisonAIAgents")
|
234
|
-
|
235
231
|
# Distribute memory to tasks
|
236
232
|
for task in tasks:
|
237
233
|
if not task.memory:
|
238
234
|
task.memory = self.shared_memory
|
239
235
|
if verbose >= 5:
|
240
236
|
logger.info(f"Assigned shared memory to task {task.id}")
|
241
|
-
|
242
237
|
except Exception as e:
|
243
238
|
logger.error(f"Failed to initialize shared memory: {e}")
|
244
|
-
|
245
239
|
# Update tasks with shared memory
|
246
240
|
if self.shared_memory:
|
247
241
|
for task in tasks:
|
@@ -898,6 +892,105 @@ Context:
|
|
898
892
|
def clear_state(self) -> None:
    """Remove every entry from the shared state mapping.

    The mapping object is emptied in place rather than replaced, so any
    references to it held elsewhere observe the cleared state.
    """
    state = self._state
    state.clear()
|
895
|
+
|
896
|
+
# Convenience methods for enhanced state management
|
897
|
+
def has_state(self, key: str) -> bool:
    """Report whether ``key`` is currently stored in the shared state.

    Args:
        key: State key to look up.

    Returns:
        True if the key exists (even when its value is falsy), otherwise False.
    """
    state = self._state
    return key in state
|
900
|
+
|
901
|
+
def get_all_state(self) -> Dict[str, Any]:
    """Return a shallow snapshot of the shared state.

    Adding or removing keys on the returned mapping does not affect the
    stored state; nested mutable values are still shared with it.
    """
    snapshot = self._state.copy()
    return snapshot
|
904
|
+
|
905
|
+
def delete_state(self, key: str) -> bool:
    """Remove ``key`` from the shared state.

    Returns:
        True if the key was present and has been removed, False if it was absent.
    """
    if key not in self._state:
        return False
    del self._state[key]
    return True
|
911
|
+
|
912
|
+
def increment_state(self, key: str, amount: float = 1, default: float = 0) -> float:
    """Add ``amount`` to the numeric value stored under ``key``.

    A missing key starts from ``default``. The new value is written back
    into the state and returned.

    Raises:
        TypeError: If the current (or default) value is not an int or float.
    """
    current = self._state.get(key, default)
    if isinstance(current, (int, float)):
        updated = current + amount
        self._state[key] = updated
        return updated
    raise TypeError(f"Cannot increment non-numeric value at key '{key}': {type(current).__name__}")
|
920
|
+
|
921
|
+
def append_to_state(self, key: str, value: Any, max_length: Optional[int] = None) -> List[Any]:
    """Append ``value`` to the list stored under ``key`` in the shared state.

    A missing key is initialised to an empty list. If the key currently holds
    a non-list value, that value is wrapped into a one-element list before
    appending; it is never rejected.

    Args:
        key: State key.
        value: Value to append.
        max_length: Optional maximum length for the list. When it is a
            positive int, only the newest ``max_length`` items are kept.
            ``None``, ``0`` or a negative value means no limit.

    Returns:
        The updated list (the same list object that is stored in the state).
    """
    if key not in self._state:
        self._state[key] = []
    elif not isinstance(self._state[key], list):
        # Be explicit about type conversion: keep the old scalar as the first item.
        self._state[key] = [self._state[key]]

    items = self._state[key]
    items.append(value)

    if max_length is not None and max_length > 0 and len(items) > max_length:
        # Trim in place so anyone holding a reference to this list sees the
        # trimmed contents, instead of rebinding the key to a new list.
        del items[:-max_length]

    return items
|
949
|
+
|
950
|
+
def save_session_state(self, session_id: str, include_memory: bool = True) -> None:
    """Persist a snapshot of the current state to shared short-term memory.

    Nothing is stored when there is no shared memory or ``include_memory``
    is False.

    Args:
        session_id: Identifier the snapshot is filed under; it is written to
            the ``session_id`` metadata field that restore_session_state
            filters on.
        include_memory: Set to False to skip persistence entirely.
    """
    if not (self.shared_memory and include_memory):
        return

    state_data = {
        "session_id": session_id,
        "user_id": self.user_id,
        "run_id": self.run_id,
        # Store a snapshot, not ``self._state`` itself: otherwise later
        # mutations would silently rewrite what was "saved" for any backend
        # that keeps the metadata object around. Nested values are still shared.
        "state": dict(self._state),
        "agents": [agent.name for agent in self.agents],
        "process": self.process,
    }
    self.shared_memory.store_short_term(
        text=f"Session state for {session_id}",
        metadata={
            "type": "session_state",
            "session_id": session_id,
            "user_id": self.user_id,
            "state_data": state_data,
        },
    )
|
970
|
+
|
971
|
+
def restore_session_state(self, session_id: str) -> bool:
    """Merge a previously saved session state back into the shared state.

    Candidates are fetched from short-term memory and filtered on their
    metadata (``type == "session_state"`` and a matching ``session_id``).
    The first matching entry's saved ``state`` dict is merged into the
    current state; existing keys not present in the snapshot are kept.

    NOTE(review): the first match in search order wins. That is relevance
    order from the memory backend, not necessarily the most recent save —
    confirm that is acceptable when a session is saved more than once.

    Args:
        session_id: Identifier used when the state was saved.

    Returns:
        True if a matching snapshot was found and merged, otherwise False.
    """
    if not self.shared_memory:
        return False

    # Use metadata-based filtering for better SQLite compatibility; the query
    # only narrows the candidate set, the loop below does the real selection.
    results = self.shared_memory.search_short_term(
        query="type:session_state",
        limit=10,  # Get more results to filter by session_id
    )

    for result in results or []:
        if not isinstance(result, dict):
            continue
        metadata = result.get("metadata") or {}
        if metadata.get("type") != "session_state" or metadata.get("session_id") != session_id:
            continue
        state_data = metadata.get("state_data") or {}
        saved_state = state_data.get("state") if isinstance(state_data, dict) else None
        if isinstance(saved_state, dict):
            # Merge with existing state instead of replacing
            self._state.update(saved_state)
            return True

    return False
|
901
994
|
|
902
995
|
def launch(self, path: str = '/agents', port: int = 8000, host: str = '0.0.0.0', debug: bool = False, protocol: str = "http"):
|
903
996
|
"""
|
@@ -0,0 +1,263 @@
|
|
1
|
+
"""
|
2
|
+
Human Approval Framework for PraisonAI Agents
|
3
|
+
|
4
|
+
This module provides a minimal human-in-the-loop approval system for dangerous tool operations.
|
5
|
+
It extends the existing callback system to require human approval before executing high-risk tools.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import logging
|
9
|
+
import asyncio
|
10
|
+
from typing import Dict, Set, Optional, Callable, Any, Literal
|
11
|
+
from functools import wraps
|
12
|
+
from contextvars import ContextVar
|
13
|
+
from rich.console import Console
|
14
|
+
from rich.panel import Panel
|
15
|
+
from rich.text import Text
|
16
|
+
from rich.prompt import Confirm
|
17
|
+
|
18
|
+
# Global registries for approval requirements
# Names of tools that must be approved before they run.
APPROVAL_REQUIRED_TOOLS: Set[str] = set()
# Tool name -> risk level ("critical", "high", "medium" or "low").
TOOL_RISK_LEVELS: Dict[str, str] = {}

# Risk levels
RiskLevel = Literal["critical", "high", "medium", "low"]

# Global approval callback installed by set_approval_callback; when it is
# None, callers in this module fall back to console_approval_callback.
approval_callback: Optional[Callable] = None

# Context variable tracking which tools are already approved in the current
# execution context. The default is an immutable frozenset: a ContextVar's
# default object is shared by every context, so a mutable default could leak
# approvals between contexts if it were ever mutated in place.
_approved_context: ContextVar[frozenset[str]] = ContextVar('approved_context', default=frozenset())
|
30
|
+
|
31
|
+
class ApprovalDecision:
    """Result of an approval request.

    Attributes:
        approved: Whether the tool call may proceed.
        modified_args: Argument overrides to apply before running the tool.
            Always a dict; ``None`` (or an empty mapping) becomes a fresh ``{}``.
        reason: Human-readable explanation, used in denial messages.
    """

    def __init__(self, approved: bool, modified_args: Optional[Dict[str, Any]] = None, reason: str = ""):
        self.approved = approved
        # ``or {}`` also gives each decision its own dict, so no mutable default is shared.
        self.modified_args = modified_args or {}
        self.reason = reason

    def __repr__(self) -> str:
        # Makes decisions readable in logs and debugger output.
        return (
            f"{type(self).__name__}(approved={self.approved!r}, "
            f"modified_args={self.modified_args!r}, reason={self.reason!r})"
        )
|
37
|
+
|
38
|
+
def set_approval_callback(callback_fn: Callable):
    """Install ``callback_fn`` as this module's approval handler.

    The handler is called as ``callback_fn(function_name, arguments, risk_level)``
    and must return an ApprovalDecision. It replaces any previously installed
    handler in this module's ``approval_callback`` global.

    NOTE(review): praisonaiagents.main keeps its own separate
    ``approval_callback`` global (set by register_approval_callback); this
    function does not touch it.
    """
    global approval_callback
    approval_callback = callback_fn
|
45
|
+
|
46
|
+
def mark_approved(tool_name: str):
    """Mark a tool as approved in the current context.

    A new frozenset is stored instead of mutating the existing set. Copied
    contexts (for example those taken by asyncio tasks or
    ``contextvars.copy_context``) hold the same object by reference, so
    in-place mutation would leak the approval into every context sharing it.
    """
    approved = _approved_context.get(frozenset())
    _approved_context.set(frozenset(approved) | {tool_name})

def is_already_approved(tool_name: str) -> bool:
    """Check if a tool is already approved in the current context."""
    return tool_name in _approved_context.get(frozenset())

def clear_approval_context():
    """Clear the approvals recorded in the current context."""
    _approved_context.set(frozenset())
|
60
|
+
|
61
|
+
def require_approval(risk_level: RiskLevel = "high"):
    """Decorator to mark a tool as requiring human approval.

    At decoration time the function's ``__name__`` is registered in
    APPROVAL_REQUIRED_TOOLS and TOOL_RISK_LEVELS. Each call then:

    1. requests approval, unless the tool was already marked approved in the
       current context;
    2. raises PermissionError when the request is denied;
    3. applies any ``modified_args`` from the decision and calls the tool.

    Both positional and keyword arguments are shown to the approver, mapped
    to parameter names where the signature allows. Modified values for
    positional parameters replace those positional slots instead of being
    passed a second time as keywords (which raised "multiple values" errors).

    NOTE(review): approvals recorded via mark_approved are not cleared after
    the call, so later calls in the same context skip the prompt. Confirm
    this is intended.

    Args:
        risk_level: The risk level of the tool ("critical", "high", "medium", "low").
    """
    import inspect

    def decorator(func):
        tool_name = getattr(func, '__name__', str(func))
        APPROVAL_REQUIRED_TOOLS.add(tool_name)
        TOOL_RISK_LEVELS[tool_name] = risk_level

        # Names of the leading parameters that can be filled positionally.
        # Empty when the signature cannot be inspected (e.g. some builtins).
        try:
            parameters = list(inspect.signature(func).parameters.values())
        except (TypeError, ValueError):
            parameters = []
        positional_names = []
        for parameter in parameters:
            if parameter.kind not in (parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD):
                break
            positional_names.append(parameter.name)

        def arguments_for_review(args, kwargs):
            """Build the name -> value mapping shown to the approver, positional args included."""
            shown = {}
            for index, value in enumerate(args):
                name = positional_names[index] if index < len(positional_names) else f"arg{index}"
                shown[name] = value
            shown.update(kwargs)
            return shown

        def apply_modifications(args, kwargs, modified):
            """Apply approver overrides: positional slots are replaced in place, the rest become kwargs."""
            if not modified:
                return args, kwargs
            new_args = list(args)
            remaining = dict(modified)
            for index, name in enumerate(positional_names[:len(new_args)]):
                if name in remaining:
                    new_args[index] = remaining.pop(name)
            kwargs.update(remaining)
            return tuple(new_args), kwargs

        def sync_decision(review_args):
            """Obtain an approval decision from synchronous code."""
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No running loop in this thread: drive the async request directly.
                try:
                    return asyncio.run(request_approval(tool_name, review_args))
                except Exception as e:
                    logging.warning(f"Async approval failed, using sync fallback: {e}")
            # Inside a running loop asyncio.run() is not allowed (and would leave an
            # un-awaited coroutine behind), so ask the callback synchronously.
            callback = approval_callback or console_approval_callback
            return callback(tool_name, review_args, risk_level)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Skip approval if already approved in current context
            if is_already_approved(tool_name):
                return func(*args, **kwargs)

            decision = sync_decision(arguments_for_review(args, kwargs))
            if not decision.approved:
                raise PermissionError(f"Execution of {tool_name} denied: {decision.reason}")

            # Mark as approved and apply modified args
            mark_approved(tool_name)
            args, kwargs = apply_modifications(args, kwargs, decision.modified_args)
            return func(*args, **kwargs)

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            # Skip approval if already approved in current context
            if is_already_approved(tool_name):
                return await func(*args, **kwargs)

            decision = await request_approval(tool_name, arguments_for_review(args, kwargs))
            if not decision.approved:
                raise PermissionError(f"Execution of {tool_name} denied: {decision.reason}")

            # Mark as approved and apply modified args
            mark_approved(tool_name)
            args, kwargs = apply_modifications(args, kwargs, decision.modified_args)
            return await func(*args, **kwargs)

        # Return the appropriate wrapper based on function type
        return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper

    return decorator
|
127
|
+
|
128
|
+
def console_approval_callback(function_name: str, arguments: Dict[str, Any], risk_level: str) -> ApprovalDecision:
    """Default console-based approval callback.

    Displays the tool name, risk level and (truncated) arguments in a rich
    panel, then asks the user to confirm. The prompt defaults to "no".

    The tool name, risk level and argument text are escaped before being
    embedded in rich markup. Argument values can come from model output, and
    unescaped ``[...]`` sequences would be interpreted as markup. That could
    restyle or hide part of the command being approved, or raise a markup
    error before the prompt is shown.

    Args:
        function_name: Name of the tool awaiting approval.
        arguments: Arguments the tool will be called with.
        risk_level: Registered risk level; ``None`` is displayed as "unknown".

    Returns:
        ApprovalDecision: approved only when the user confirms. A "no" answer,
        Ctrl-C, or an error while prompting all produce a denial.
    """
    from rich.markup import escape

    console = Console()

    # Create risk level styling
    risk_colors = {
        "critical": "bold red",
        "high": "red",
        "medium": "yellow",
        "low": "blue"
    }
    risk_color = risk_colors.get(risk_level, "white")
    # get_risk_level() can return None; avoid calling .upper() on it.
    risk_text = "unknown" if risk_level is None else str(risk_level)

    # Display tool information
    tool_info = f"[bold]Function:[/] {escape(str(function_name))}\n"
    tool_info += f"[bold]Risk Level:[/] [{risk_color}]{escape(risk_text.upper())}[/{risk_color}]\n"
    tool_info += "[bold]Arguments:[/]\n"

    for key, value in (arguments or {}).items():
        # Truncate long values for display, then escape so any markup is shown literally
        str_value = str(value)
        if len(str_value) > 100:
            str_value = str_value[:97] + "..."
        tool_info += f" {escape(str(key))}: {escape(str_value)}\n"

    console.print(Panel(
        tool_info.strip(),
        title="🔒 Tool Approval Required",
        border_style=risk_color,
        title_align="left"
    ))

    # Get user approval
    try:
        approved = Confirm.ask(
            f"[{risk_color}]Do you want to execute this {escape(risk_text)} risk tool?[/{risk_color}]",
            default=False
        )

        if approved:
            console.print("[green]✅ Tool execution approved[/green]")
            return ApprovalDecision(approved=True, reason="User approved")
        else:
            console.print("[red]❌ Tool execution denied[/red]")
            return ApprovalDecision(approved=False, reason="User denied")

    except KeyboardInterrupt:
        console.print("\n[red]❌ Tool execution cancelled by user[/red]")
        return ApprovalDecision(approved=False, reason="User cancelled")
    except Exception as e:
        console.print(f"[red]Error during approval: {escape(str(e))}[/red]")
        return ApprovalDecision(approved=False, reason=f"Approval error: {e}")
|
183
|
+
|
184
|
+
async def request_approval(function_name: str, arguments: Dict[str, Any]) -> ApprovalDecision:
    """Request approval for a tool execution.

    Tools not listed in APPROVAL_REQUIRED_TOOLS are approved immediately.
    Otherwise the installed approval_callback (or console_approval_callback)
    is asked:

    - coroutine-function callbacks are awaited directly;
    - plain callables run in the default executor so a blocking prompt does
      not stall the event loop, and any awaitable they return is awaited.

    The request fails closed: errors raised by the callback, or a return
    value that is not an ApprovalDecision-like object (e.g. ``None``), yield
    a denial instead of propagating to the caller.

    Args:
        function_name: Name of the function to execute.
        arguments: Arguments to pass to the function.

    Returns:
        ApprovalDecision with approval status and any modifications.
    """
    import inspect

    # Check if approval is required
    if function_name not in APPROVAL_REQUIRED_TOOLS:
        return ApprovalDecision(approved=True, reason="No approval required")

    risk_level = TOOL_RISK_LEVELS.get(function_name, "medium")

    # Use custom callback if set, otherwise use console callback
    callback = approval_callback or console_approval_callback

    try:
        if asyncio.iscoroutinefunction(callback):
            decision = await callback(function_name, arguments, risk_level)
        else:
            # get_running_loop() is the documented call from inside a coroutine.
            loop = asyncio.get_running_loop()
            decision = await loop.run_in_executor(None, callback, function_name, arguments, risk_level)
            # A plain callable (e.g. an object with async __call__) may hand back an awaitable.
            if inspect.isawaitable(decision):
                decision = await decision
    except Exception as e:
        logging.error(f"Error in approval callback: {e}")
        return ApprovalDecision(approved=False, reason=f"Approval callback error: {e}")

    # Callers read .approved, .reason and .modified_args; anything lacking them is a denial.
    if not all(hasattr(decision, attr) for attr in ("approved", "reason", "modified_args")):
        logging.error(f"Approval callback returned an invalid decision: {decision!r}")
        return ApprovalDecision(approved=False, reason="Approval callback returned an invalid decision")

    return decision
|
217
|
+
|
218
|
+
# Tools that require approval out of the box, mapped to their risk level.
# The registries can still be changed at runtime with the helpers below.
DEFAULT_DANGEROUS_TOOLS = {
    # Critical risk tools
    "execute_command": "critical",
    "kill_process": "critical",
    "execute_code": "critical",

    # High risk tools
    "write_file": "high",
    "delete_file": "high",
    "move_file": "high",
    "copy_file": "high",
    "execute_query": "high",

    # Medium risk tools
    "evaluate": "medium",
    "crawl": "medium",
    "scrape_page": "medium",
}

def configure_default_approvals():
    """Register every entry of DEFAULT_DANGEROUS_TOOLS as requiring approval."""
    # Iterating the dict yields its keys, so the set receives the tool names.
    APPROVAL_REQUIRED_TOOLS.update(DEFAULT_DANGEROUS_TOOLS)
    TOOL_RISK_LEVELS.update(DEFAULT_DANGEROUS_TOOLS)

def add_approval_requirement(tool_name: str, risk_level: RiskLevel = "high"):
    """Require approval for ``tool_name`` from now on, at the given risk level."""
    TOOL_RISK_LEVELS[tool_name] = risk_level
    APPROVAL_REQUIRED_TOOLS.add(tool_name)

def remove_approval_requirement(tool_name: str):
    """Stop requiring approval for ``tool_name``; unknown names are ignored."""
    TOOL_RISK_LEVELS.pop(tool_name, None)
    APPROVAL_REQUIRED_TOOLS.discard(tool_name)

def is_approval_required(tool_name: str) -> bool:
    """Return True if ``tool_name`` is registered as requiring approval."""
    required = APPROVAL_REQUIRED_TOOLS
    return tool_name in required

def get_risk_level(tool_name: str) -> Optional[str]:
    """Return the registered risk level of ``tool_name``, or None if it has none."""
    return TOOL_RISK_LEVELS.get(tool_name, None)

# Populate the registries with the defaults at import time.
configure_default_approvals()
|
@@ -17,9 +17,13 @@ class CustomMemory:
|
|
17
17
|
}).from_config(config)
|
18
18
|
|
19
19
|
@staticmethod
|
20
|
-
def _add_to_vector_store(self, messages, metadata, filters, infer):
|
20
|
+
def _add_to_vector_store(self, messages, metadata=None, filters=None, infer=None):
|
21
21
|
# Custom implementation that doesn't use LLM
|
22
|
-
|
22
|
+
# Handle different message formats for backward compatibility
|
23
|
+
if isinstance(messages, list):
|
24
|
+
parsed_messages = "\n".join([msg.get("content", str(msg)) if isinstance(msg, dict) else str(msg) for msg in messages])
|
25
|
+
else:
|
26
|
+
parsed_messages = str(messages)
|
23
27
|
|
24
28
|
# Create a simple fact without using LLM
|
25
29
|
new_retrieved_facts = [parsed_messages]
|
@@ -34,7 +38,7 @@ class CustomMemory:
|
|
34
38
|
memory_id = self._create_memory(
|
35
39
|
data=parsed_messages,
|
36
40
|
existing_embeddings=new_message_embeddings,
|
37
|
-
metadata=metadata
|
41
|
+
metadata=metadata or {}
|
38
42
|
)
|
39
43
|
|
40
44
|
return [{
|
@@ -137,6 +141,10 @@ class Knowledge:
|
|
137
141
|
# Merge reranker config if provided
|
138
142
|
if "reranker" in self._config:
|
139
143
|
base_config["reranker"].update(self._config["reranker"])
|
144
|
+
|
145
|
+
# Merge graph_store config if provided (for graph memory support)
|
146
|
+
if "graph_store" in self._config:
|
147
|
+
base_config["graph_store"] = self._config["graph_store"]
|
140
148
|
return base_config
|
141
149
|
|
142
150
|
@cached_property
|
@@ -184,7 +192,22 @@ class Knowledge:
|
|
184
192
|
if not content:
|
185
193
|
return []
|
186
194
|
|
187
|
-
|
195
|
+
# Try new API format first, fall back to old format for backward compatibility
|
196
|
+
try:
|
197
|
+
# Convert content to messages format for mem0 API compatibility
|
198
|
+
if isinstance(content, str):
|
199
|
+
messages = [{"role": "user", "content": content}]
|
200
|
+
else:
|
201
|
+
messages = content if isinstance(content, list) else [{"role": "user", "content": str(content)}]
|
202
|
+
|
203
|
+
result = self.memory.add(messages=messages, user_id=user_id, agent_id=agent_id, run_id=run_id, metadata=metadata)
|
204
|
+
except TypeError as e:
|
205
|
+
# Fallback to old API format if messages parameter is not supported
|
206
|
+
if "unexpected keyword argument" in str(e) or "positional argument" in str(e):
|
207
|
+
self._log(f"Falling back to legacy API format due to: {e}")
|
208
|
+
result = self.memory.add(content, user_id=user_id, agent_id=agent_id, run_id=run_id, metadata=metadata)
|
209
|
+
else:
|
210
|
+
raise
|
188
211
|
self._log(f"Store operation result: {result}")
|
189
212
|
return result
|
190
213
|
except Exception as e:
|
@@ -1512,6 +1512,10 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
|
|
1512
1512
|
if self.stop_phrases:
|
1513
1513
|
params["stop"] = self.stop_phrases
|
1514
1514
|
|
1515
|
+
# Add extra settings for provider-specific parameters (e.g., num_ctx for Ollama)
|
1516
|
+
if self.extra_settings:
|
1517
|
+
params.update(self.extra_settings)
|
1518
|
+
|
1515
1519
|
# Override with any provided parameters
|
1516
1520
|
params.update(override_params)
|
1517
1521
|
|
@@ -43,12 +43,18 @@ error_logs = []
|
|
43
43
|
# Display callbacks keyed by display type. register_display_callback stores
# sync callbacks here (presumably async ones go in the second dict when
# is_async=True — confirm in the part of the function not shown).
sync_display_callbacks = {}
async_display_callbacks = {}
|
45
45
|
|
46
|
+
# Global approval callback registry
|
47
|
+
# NOTE(review): agent.py imports this name via `from ..main import approval_callback`,
# which copies the value at import time; rebinding it later (e.g. in
# register_approval_callback) is not visible through that import.
approval_callback = None
|
48
|
+
|
46
49
|
# Public names exported by this module.
# NOTE(review): 'approval_callback' is a rebindable global; star-importers get
# the value bound at import time, not later registrations.
__all__ = [
    'error_logs',
    'register_display_callback',
    'register_approval_callback',
    'sync_display_callbacks',
    'async_display_callbacks',
    'execute_callback',
    'approval_callback',
    # ... other exports
]
|
54
60
|
|
@@ -65,6 +71,15 @@ def register_display_callback(display_type: str, callback_fn, is_async: bool = F
|
|
65
71
|
else:
|
66
72
|
sync_display_callbacks[display_type] = callback_fn
|
67
73
|
|
74
|
+
def register_approval_callback(callback_fn):
    """Register a global approval callback function for dangerous tool operations.

    The callback is stored in this module's ``approval_callback`` global. It is
    also forwarded to ``approval.set_approval_callback``, because
    ``approval.request_approval`` (used by the async tool path and by the
    ``require_approval`` decorator) reads ``approval.approval_callback``, a
    separate global. Without the forwarding, a callback registered here was
    ignored on those paths.

    NOTE(review): modules that did ``from ..main import approval_callback``
    (agent.py does) keep the value bound at import time and will not see this
    rebinding. They need to read ``main.approval_callback`` at call time.

    Args:
        callback_fn: Function that takes (function_name, arguments, risk_level) and returns ApprovalDecision
    """
    global approval_callback
    approval_callback = callback_fn

    # Imported lazily to keep main.py free of an import-time dependency on approval.py.
    from .approval import set_approval_callback
    set_approval_callback(callback_fn)
|
82
|
+
|
68
83
|
async def execute_callback(display_type: str, **kwargs):
|
69
84
|
"""Execute both sync and async callbacks for a given display type.
|
70
85
|
|