genxai-framework 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +3 -0
- cli/commands/__init__.py +6 -0
- cli/commands/approval.py +85 -0
- cli/commands/audit.py +127 -0
- cli/commands/metrics.py +25 -0
- cli/commands/tool.py +389 -0
- cli/main.py +32 -0
- genxai/__init__.py +81 -0
- genxai/api/__init__.py +5 -0
- genxai/api/app.py +21 -0
- genxai/config/__init__.py +5 -0
- genxai/config/settings.py +37 -0
- genxai/connectors/__init__.py +19 -0
- genxai/connectors/base.py +122 -0
- genxai/connectors/kafka.py +92 -0
- genxai/connectors/postgres_cdc.py +95 -0
- genxai/connectors/registry.py +44 -0
- genxai/connectors/sqs.py +94 -0
- genxai/connectors/webhook.py +73 -0
- genxai/core/__init__.py +37 -0
- genxai/core/agent/__init__.py +32 -0
- genxai/core/agent/base.py +206 -0
- genxai/core/agent/config_io.py +59 -0
- genxai/core/agent/registry.py +98 -0
- genxai/core/agent/runtime.py +970 -0
- genxai/core/communication/__init__.py +6 -0
- genxai/core/communication/collaboration.py +44 -0
- genxai/core/communication/message_bus.py +192 -0
- genxai/core/communication/protocols.py +35 -0
- genxai/core/execution/__init__.py +22 -0
- genxai/core/execution/metadata.py +181 -0
- genxai/core/execution/queue.py +201 -0
- genxai/core/graph/__init__.py +30 -0
- genxai/core/graph/checkpoints.py +77 -0
- genxai/core/graph/edges.py +131 -0
- genxai/core/graph/engine.py +813 -0
- genxai/core/graph/executor.py +516 -0
- genxai/core/graph/nodes.py +161 -0
- genxai/core/graph/trigger_runner.py +40 -0
- genxai/core/memory/__init__.py +19 -0
- genxai/core/memory/base.py +72 -0
- genxai/core/memory/embedding.py +327 -0
- genxai/core/memory/episodic.py +448 -0
- genxai/core/memory/long_term.py +467 -0
- genxai/core/memory/manager.py +543 -0
- genxai/core/memory/persistence.py +297 -0
- genxai/core/memory/procedural.py +461 -0
- genxai/core/memory/semantic.py +526 -0
- genxai/core/memory/shared.py +62 -0
- genxai/core/memory/short_term.py +303 -0
- genxai/core/memory/vector_store.py +508 -0
- genxai/core/memory/working.py +211 -0
- genxai/core/state/__init__.py +6 -0
- genxai/core/state/manager.py +293 -0
- genxai/core/state/schema.py +115 -0
- genxai/llm/__init__.py +14 -0
- genxai/llm/base.py +150 -0
- genxai/llm/factory.py +329 -0
- genxai/llm/providers/__init__.py +1 -0
- genxai/llm/providers/anthropic.py +249 -0
- genxai/llm/providers/cohere.py +274 -0
- genxai/llm/providers/google.py +334 -0
- genxai/llm/providers/ollama.py +147 -0
- genxai/llm/providers/openai.py +257 -0
- genxai/llm/routing.py +83 -0
- genxai/observability/__init__.py +6 -0
- genxai/observability/logging.py +327 -0
- genxai/observability/metrics.py +494 -0
- genxai/observability/tracing.py +372 -0
- genxai/performance/__init__.py +39 -0
- genxai/performance/cache.py +256 -0
- genxai/performance/pooling.py +289 -0
- genxai/security/audit.py +304 -0
- genxai/security/auth.py +315 -0
- genxai/security/cost_control.py +528 -0
- genxai/security/default_policies.py +44 -0
- genxai/security/jwt.py +142 -0
- genxai/security/oauth.py +226 -0
- genxai/security/pii.py +366 -0
- genxai/security/policy_engine.py +82 -0
- genxai/security/rate_limit.py +341 -0
- genxai/security/rbac.py +247 -0
- genxai/security/validation.py +218 -0
- genxai/tools/__init__.py +21 -0
- genxai/tools/base.py +383 -0
- genxai/tools/builtin/__init__.py +131 -0
- genxai/tools/builtin/communication/__init__.py +15 -0
- genxai/tools/builtin/communication/email_sender.py +159 -0
- genxai/tools/builtin/communication/notification_manager.py +167 -0
- genxai/tools/builtin/communication/slack_notifier.py +118 -0
- genxai/tools/builtin/communication/sms_sender.py +118 -0
- genxai/tools/builtin/communication/webhook_caller.py +136 -0
- genxai/tools/builtin/computation/__init__.py +15 -0
- genxai/tools/builtin/computation/calculator.py +101 -0
- genxai/tools/builtin/computation/code_executor.py +183 -0
- genxai/tools/builtin/computation/data_validator.py +259 -0
- genxai/tools/builtin/computation/hash_generator.py +129 -0
- genxai/tools/builtin/computation/regex_matcher.py +201 -0
- genxai/tools/builtin/data/__init__.py +15 -0
- genxai/tools/builtin/data/csv_processor.py +213 -0
- genxai/tools/builtin/data/data_transformer.py +299 -0
- genxai/tools/builtin/data/json_processor.py +233 -0
- genxai/tools/builtin/data/text_analyzer.py +288 -0
- genxai/tools/builtin/data/xml_processor.py +175 -0
- genxai/tools/builtin/database/__init__.py +15 -0
- genxai/tools/builtin/database/database_inspector.py +157 -0
- genxai/tools/builtin/database/mongodb_query.py +196 -0
- genxai/tools/builtin/database/redis_cache.py +167 -0
- genxai/tools/builtin/database/sql_query.py +145 -0
- genxai/tools/builtin/database/vector_search.py +163 -0
- genxai/tools/builtin/file/__init__.py +17 -0
- genxai/tools/builtin/file/directory_scanner.py +214 -0
- genxai/tools/builtin/file/file_compressor.py +237 -0
- genxai/tools/builtin/file/file_reader.py +102 -0
- genxai/tools/builtin/file/file_writer.py +122 -0
- genxai/tools/builtin/file/image_processor.py +186 -0
- genxai/tools/builtin/file/pdf_parser.py +144 -0
- genxai/tools/builtin/test/__init__.py +15 -0
- genxai/tools/builtin/test/async_simulator.py +62 -0
- genxai/tools/builtin/test/data_transformer.py +99 -0
- genxai/tools/builtin/test/error_generator.py +82 -0
- genxai/tools/builtin/test/simple_math.py +94 -0
- genxai/tools/builtin/test/string_processor.py +72 -0
- genxai/tools/builtin/web/__init__.py +15 -0
- genxai/tools/builtin/web/api_caller.py +161 -0
- genxai/tools/builtin/web/html_parser.py +330 -0
- genxai/tools/builtin/web/http_client.py +187 -0
- genxai/tools/builtin/web/url_validator.py +162 -0
- genxai/tools/builtin/web/web_scraper.py +170 -0
- genxai/tools/custom/my_test_tool_2.py +9 -0
- genxai/tools/dynamic.py +105 -0
- genxai/tools/mcp_server.py +167 -0
- genxai/tools/persistence/__init__.py +6 -0
- genxai/tools/persistence/models.py +55 -0
- genxai/tools/persistence/service.py +322 -0
- genxai/tools/registry.py +227 -0
- genxai/tools/security/__init__.py +11 -0
- genxai/tools/security/limits.py +214 -0
- genxai/tools/security/policy.py +20 -0
- genxai/tools/security/sandbox.py +248 -0
- genxai/tools/templates.py +435 -0
- genxai/triggers/__init__.py +19 -0
- genxai/triggers/base.py +104 -0
- genxai/triggers/file_watcher.py +75 -0
- genxai/triggers/queue.py +68 -0
- genxai/triggers/registry.py +82 -0
- genxai/triggers/schedule.py +66 -0
- genxai/triggers/webhook.py +68 -0
- genxai/utils/__init__.py +1 -0
- genxai/utils/tokens.py +295 -0
- genxai_framework-0.1.0.dist-info/METADATA +495 -0
- genxai_framework-0.1.0.dist-info/RECORD +156 -0
- genxai_framework-0.1.0.dist-info/WHEEL +5 -0
- genxai_framework-0.1.0.dist-info/entry_points.txt +2 -0
- genxai_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
- genxai_framework-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
"""Input validation and sanitization for GenXAI."""
|
|
2
|
+
|
|
3
|
+
import html
import ipaddress
import re
import socket
from typing import Any, Dict, Optional
from urllib.parse import urlsplit

from pydantic import BaseModel, Field, field_validator, validator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AgentExecutionRequest(BaseModel):
    """Validate agent execution request.

    ``task`` (1-10000 characters) is additionally screened by
    :meth:`validate_task` for a few SQL-injection-looking patterns.
    ``agent_id`` must match ``[A-Za-z0-9_-]+``, and ``timeout`` must lie in
    [1, 3600].
    """

    task: str = Field(..., min_length=1, max_length=10000)
    agent_id: str = Field(..., pattern=r'^[a-zA-Z0-9_-]+$')
    context: Dict[str, Any] = Field(default_factory=dict)
    timeout: int = Field(default=300, ge=1, le=3600)

    # ``field_validator`` is the pydantic v2 API. This model already relies on
    # the v2-only ``Field(pattern=...)``, and the v1-style ``@validator`` is
    # deprecated under v2.
    @field_validator('task')
    @classmethod
    def validate_task(cls, v: str) -> str:
        """Reject tasks that contain SQL-injection-like patterns.

        Args:
            v: The task text (already length-checked by the field constraints).

        Returns:
            The unchanged task text.

        Raises:
            ValueError: If any pattern matches (case-insensitive).
        """
        dangerous_patterns = [
            r'(DROP|DELETE|INSERT|UPDATE|ALTER|CREATE)\s+(TABLE|DATABASE|INDEX)',
            r';\s*(DROP|DELETE|INSERT|UPDATE)',
            # No re.MULTILINE: ``$`` anchors only at the end of the task
            # (or just before a final newline).
            r'--\s*$',
            # No re.DOTALL: a block comment must open and close on one line.
            r'/\*.*\*/',
        ]

        for pattern in dangerous_patterns:
            if re.search(pattern, v, re.IGNORECASE):
                raise ValueError("Potential SQL injection detected")

        return v
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class WorkflowExecutionRequest(BaseModel):
    """Validate workflow execution request.

    Pydantic checks the fields on construction. ``workflow_id`` must be a
    non-empty run of ASCII letters, digits, ``_`` or ``-``, and ``timeout``
    must lie in [1, 7200].
    """

    # Required identifier restricted to [A-Za-z0-9_-].
    workflow_id: str = Field(default=..., pattern=r'^[a-zA-Z0-9_-]+$')
    # Free-form workflow inputs; each instance gets its own empty dict.
    inputs: Dict[str, Any] = Field(default_factory=dict)
    # Bounded timeout (units not stated here; presumably seconds).
    timeout: int = Field(default=600, ge=1, le=7200)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ToolExecutionRequest(BaseModel):
    """Validate tool execution request.

    Pydantic checks the fields on construction. ``tool_name`` must be a
    non-empty run of ASCII letters, digits, ``_`` or ``-``, and ``timeout``
    must lie in [1, 600].
    """

    # Required tool name restricted to [A-Za-z0-9_-].
    tool_name: str = Field(default=..., pattern=r'^[a-zA-Z0-9_-]+$')
    # Keyword arguments forwarded to the tool; each instance gets its own dict.
    parameters: Dict[str, Any] = Field(default_factory=dict)
    # Bounded timeout (units not stated here; presumably seconds).
    timeout: int = Field(default=60, ge=1, le=600)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def sanitize_sql(query: str) -> str:
    """Best-effort scrubbing of a raw SQL string.

    The steps run in this order:

    1. Strip ``--`` line comments and ``/* ... */`` block comments.
    2. Delete the whole-word keywords DROP, DELETE, INSERT, UPDATE, ALTER,
       CREATE, EXEC and EXECUTE (any case).
    3. Double every single quote.
    4. Trim surrounding whitespace.

    Args:
        query: SQL query string

    Returns:
        Sanitized query

    Note:
        This is a basic sanitizer. Always use parameterized queries in production.
    """
    without_line_comments = re.sub(r'--.*$', '', query, flags=re.MULTILINE)
    cleaned = re.sub(r'/\*.*?\*/', '', without_line_comments, flags=re.DOTALL)

    blocked_words = ('DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'CREATE', 'EXEC', 'EXECUTE')
    for word in blocked_words:
        cleaned = re.sub(rf'\b{word}\b', '', cleaned, flags=re.IGNORECASE)

    return cleaned.replace("'", "''").strip()
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def sanitize_html(text: str) -> str:
    """Sanitize HTML to prevent XSS.

    Markup-level stripping happens on the raw text, because once the text has
    been escaped, tags can no longer be recognised. The steps are:

    1. Remove ``<script>...</script>`` blocks, including their content.
    2. Remove inline event-handler attributes (``onclick=...``) from inside
       raw tags only, so prose such as ``"condition=true"`` is left intact.
    3. HTML-escape the remainder (``&``, ``<``, ``>``, ``"`` and ``'``).
    4. Remove ``javascript:`` repeatedly until none remain, so nested inputs
       like ``"javajavascript:script:"`` cannot reassemble the scheme.

    Args:
        text: HTML text

    Returns:
        Sanitized text
    """
    text = re.sub(
        r'<script\b[^>]*>.*?</script\s*>', '', text, flags=re.DOTALL | re.IGNORECASE
    )

    # Group 1 keeps the tag prefix up to and including the whitespace before
    # the handler name, and the attribute itself is dropped. Each pass removes
    # at most one handler per tag, so loop until nothing changes. Every
    # substitution shortens the text, so the loop terminates.
    handler_re = re.compile(
        r'(<[^>]*?\s)on\w+\s*=\s*(?:"[^"]*"|\'[^\']*\'|[^\s>]*)',
        re.IGNORECASE,
    )
    while True:
        stripped = handler_re.sub(r'\1', text)
        if stripped == text:
            break
        text = stripped

    text = html.escape(text)

    js_re = re.compile(r'javascript:', re.IGNORECASE)
    while js_re.search(text):
        text = js_re.sub('', text)

    return text
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def sanitize_command(cmd: str) -> str:
    """Sanitize shell command.

    The command is rejected if it contains any blocked pattern, or if its
    first whitespace-separated token is not on the allow-list.

    Args:
        cmd: Shell command

    Returns:
        Sanitized command (unchanged when it passes all checks)

    Raises:
        ValueError: If command contains dangerous patterns
    """
    # Patterns are checked in order, and the first match determines the error
    # message. The last two entries close bypasses of the earlier list:
    # newlines act as command separators in POSIX shells, and bare redirection
    # would let an allowed command such as ``echo`` overwrite arbitrary files.
    dangerous_patterns = [
        r'[;&|`$]',  # Command chaining
        r'\$\(',  # Command substitution
        r'>\s*/dev/',  # Device access
        r'<\s*/dev/',
        r'/etc/passwd',  # Sensitive files
        r'/etc/shadow',
        r'rm\s+-rf',  # Dangerous commands
        r'dd\s+if=',
        r'[\r\n]',  # Newline-separated commands
        r'[<>]',  # Any other I/O redirection
    ]

    for pattern in dangerous_patterns:
        if re.search(pattern, cmd):
            raise ValueError(f"Dangerous command pattern detected: {pattern}")

    # Whitelist allowed commands
    allowed_commands = ('ls', 'cat', 'echo', 'pwd', 'date', 'whoami')
    tokens = cmd.split()
    cmd_name = tokens[0] if tokens else ''

    if cmd_name not in allowed_commands:
        raise ValueError(f"Command not in whitelist: {cmd_name}")

    return cmd
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def validate_file_path(path: str) -> str:
    """Validate file path to prevent directory traversal.

    The path is normalized *before* it is inspected, so that Windows-style
    separators cannot smuggle an absolute path past the checks. For example,
    ``"\\etc\\passwd"`` must be rejected rather than returned as
    ``"/etc/passwd"``.

    Args:
        path: File path

    Returns:
        Validated path, with backslashes converted to ``/`` and runs of
        separators collapsed to one.

    Raises:
        ValueError: If path contains dangerous patterns (``..``, an absolute
            POSIX/UNC/drive-letter path, or a null byte).
    """
    # Normalize first so every check below sees the form that is returned.
    normalized = re.sub(r'/+', '/', path.replace('\\', '/'))

    # Check for directory traversal
    if '..' in normalized:
        raise ValueError("Directory traversal detected")

    # Check for absolute paths: POSIX roots, UNC shares (collapsed to "/..."),
    # and Windows drive paths such as "C:/x" or drive-relative "C:x".
    if normalized.startswith('/') or re.match(r'[A-Za-z]:', normalized):
        raise ValueError("Absolute paths not allowed")

    # Check for null bytes
    if '\x00' in normalized:
        raise ValueError("Null byte detected")

    return normalized
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
# Legacy substring patterns, kept so that hostnames they already rejected
# (e.g. "127.0.0.1.nip.io" or "localhost.localdomain") stay rejected.
_INTERNAL_HOST_PATTERNS = (
    r'localhost',
    r'127\.0\.0\.1',
    r'0\.0\.0\.0',
    r'10\.\d+\.\d+\.\d+',
    r'172\.(1[6-9]|2[0-9]|3[0-1])\.\d+\.\d+',
    r'192\.168\.\d+\.\d+',
)


def _matches_internal_pattern(text: str) -> bool:
    """Return True if *text* contains any legacy internal-host pattern."""
    return any(re.search(pattern, text, re.IGNORECASE) for pattern in _INTERNAL_HOST_PATTERNS)


def _url_hostname(url: str) -> Optional[str]:
    """Return the lower-cased host of *url*, or None if it has none.

    Scheme-less input such as ``"localhost:8080/x"`` is re-parsed as a
    network-path reference (``"//localhost:8080/x"``) so its host is still
    recognised.

    Raises:
        ValueError: If ``urlsplit`` rejects the URL (e.g. a malformed IPv6
            literal).
    """
    candidate = url.strip()
    parts = urlsplit(candidate)
    if not parts.netloc and '://' not in candidate:
        parts = urlsplit('//' + candidate)
    return parts.hostname


def _parse_ip(host: str):
    """Parse *host* as an IP address, or return None if it is not one.

    In addition to canonical IPv4/IPv6, the legacy IPv4 spellings accepted by
    ``inet_aton`` are recognised ("127.1", "0x7f000001", "2130706433").
    Those are exactly the forms that let a substring check be bypassed.
    """
    bare = host.split('%', 1)[0].rstrip('.')  # drop IPv6 zone id / trailing dot
    try:
        return ipaddress.ip_address(bare)
    except ValueError:
        pass
    try:
        return ipaddress.IPv4Address(socket.inet_aton(bare))
    except (OSError, ValueError):
        return None


def _is_internal_ip(ip) -> bool:
    """Return True for loopback, private, link-local or otherwise non-public IPs."""
    mapped = getattr(ip, 'ipv4_mapped', None)  # ::ffff:a.b.c.d -> a.b.c.d
    if mapped is not None:
        ip = mapped
    return (
        ip.is_private
        or ip.is_loopback
        or ip.is_link_local  # includes 169.254.169.254 (cloud metadata)
        or ip.is_reserved
        or ip.is_unspecified
        or ip.is_multicast
        or not ip.is_global
    )


def validate_url(url: str) -> str:
    """Validate URL to prevent SSRF.

    Host checks are made against the parsed host rather than the whole
    string. A URL is therefore no longer rejected just because "localhost"
    appears in its path or query. Private, loopback, link-local and other
    non-public IP literals are rejected in any notation.

    Note:
        Hostnames that merely *resolve* to internal addresses (DNS rebinding)
        cannot be detected without resolving them at request time.

    Args:
        url: URL string

    Returns:
        Validated URL (unchanged)

    Raises:
        ValueError: If URL is dangerous or cannot be parsed
    """
    # Check for dangerous protocols. Leading whitespace is ignored, so it
    # cannot be used to hide the scheme.
    lowered = url.strip().lower()
    for protocol in ('file://', 'ftp://', 'gopher://', 'dict://'):
        if lowered.startswith(protocol):
            raise ValueError(f"Dangerous protocol: {protocol}")

    host = _url_hostname(url)
    if host is None:
        # No recognisable host: fall back to the legacy whole-string scan.
        if _matches_internal_pattern(url):
            raise ValueError("Internal/localhost URLs not allowed")
        return url

    ip = _parse_ip(host)
    if _matches_internal_pattern(host) or (ip is not None and _is_internal_ip(ip)):
        raise ValueError("Internal/localhost URLs not allowed")

    return url
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def sanitize_json(data: Any) -> Any:
    """Recursively apply :func:`sanitize_html` to every string in *data*.

    Dicts and lists are rebuilt with sanitized contents; dict keys are left
    untouched. Any other value (numbers, booleans, None, tuples, ...) is
    returned as-is.

    Args:
        data: JSON data

    Returns:
        Sanitized data
    """
    if isinstance(data, str):
        return sanitize_html(data)
    if isinstance(data, dict):
        return {key: sanitize_json(value) for key, value in data.items()}
    if isinstance(data, list):
        return [sanitize_json(element) for element in data]
    return data
|
genxai/tools/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""GenXAI Tools module."""
|
|
2
|
+
|
|
3
|
+
from genxai.tools.base import (
|
|
4
|
+
Tool,
|
|
5
|
+
ToolMetadata,
|
|
6
|
+
ToolParameter,
|
|
7
|
+
ToolCategory,
|
|
8
|
+
ToolResult,
|
|
9
|
+
)
|
|
10
|
+
from genxai.tools.registry import ToolRegistry
|
|
11
|
+
from genxai.tools.dynamic import DynamicTool
|
|
12
|
+
|
|
13
|
+
# Public API of ``genxai.tools``: the names re-exported from the submodules
# imported above.
__all__ = [
    # Core tool abstractions (genxai.tools.base)
    "Tool", "ToolMetadata", "ToolParameter", "ToolCategory", "ToolResult",
    # Registry and runtime-defined tools
    "ToolRegistry", "DynamicTool",
]
|
genxai/tools/base.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
"""Base tool classes for GenXAI."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
from pydantic import BaseModel, Field, ConfigDict
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
import time
|
|
8
|
+
import logging
|
|
9
|
+
|
|
10
|
+
from genxai.observability.metrics import record_tool_execution
|
|
11
|
+
from genxai.observability.tracing import span, record_exception
|
|
12
|
+
from genxai.tools.security.policy import is_tool_allowed
|
|
13
|
+
from genxai.security.rbac import get_current_user, Permission
|
|
14
|
+
from genxai.security.policy_engine import get_policy_engine
|
|
15
|
+
from genxai.security.audit import get_audit_log, AuditEvent
|
|
16
|
+
|
|
17
|
+
# Module-level logger, named after this module within the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ToolCategory(str, Enum):
    """Closed set of categories used to organise tools.

    Members also subclass ``str``, so each one compares equal to its raw
    value (``ToolCategory.WEB == "web"``).
    """

    # Integration domains
    WEB = "web"
    DATABASE = "database"
    FILE = "file"
    COMPUTATION = "computation"
    COMMUNICATION = "communication"
    AI = "ai"

    # NOTE: DATA is its own member (value "data"), not an Enum alias of
    # DATA_PROCESSING. It is kept for backwards compatibility with unit
    # tests that expect category == "data".
    DATA = "data"
    DATA_PROCESSING = "data_processing"

    # General-purpose buckets
    SYSTEM = "system"
    CUSTOM = "custom"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ToolParameter(BaseModel):
    """Declarative description of one input parameter accepted by a tool."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Keyword-argument name the tool expects.
    name: str
    # Type tag: string, number, boolean, array, object
    type: str
    description: str
    required: bool = Field(default=True)
    default: Optional[Any] = Field(default=None)
    # Allowed values, if the parameter is restricted to a fixed set.
    enum: Optional[List[Any]] = Field(default=None)
    # Inclusive numeric bounds.
    min_value: Optional[float] = Field(default=None)
    max_value: Optional[float] = Field(default=None)
    # Regex pattern for strings
    pattern: Optional[str] = Field(default=None)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ToolMetadata(BaseModel):
    """Descriptive information about a tool: identity, category and provenance."""

    name: str
    description: str
    category: ToolCategory
    # Free-form labels; each instance gets its own list.
    tags: List[str] = Field(default_factory=list)
    version: str = Field(default="1.0.0")
    author: str = Field(default="GenXAI")
    license: str = Field(default="MIT")
    documentation_url: Optional[str] = Field(default=None)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class ToolResult(BaseModel):
    """Outcome of a single tool invocation."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Whether the invocation succeeded.
    success: bool
    # Payload produced by the tool (required; may be None).
    data: Any
    # Human-readable error message when the invocation failed.
    error: Optional[str] = Field(default=None)
    # Extra information about the run; each instance gets its own dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Duration of the invocation (units set by the producer; presumably seconds).
    execution_time: float = Field(default=0.0)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class Tool(ABC):
    """Base class for all tools.

    Subclasses implement :meth:`_execute`. Callers go through
    :meth:`execute`, which:

    - runs access checks;
    - validates inputs against ``self.parameters``;
    - normalizes whatever ``_execute`` returns into a :class:`ToolResult`;
    - keeps per-instance counters in step with ``record_tool_execution``.

    Internal helpers use name-mangled ``__`` names so they cannot collide
    with, or be accidentally overridden by, methods defined on subclasses.
    """

    def __init__(self, metadata: ToolMetadata, parameters: List[ToolParameter]):
        """Initialize tool.

        Args:
            metadata: Tool metadata
            parameters: Tool parameters
        """
        self.metadata = metadata
        self.parameters = parameters
        self._execution_count = 0
        self._total_execution_time = 0.0
        self._success_count = 0
        self._failure_count = 0

    async def execute(self, **kwargs: Any) -> ToolResult:
        """Execute tool with validation, consistent success semantics, and error handling.

        Every call is counted exactly once, both in the per-instance counters
        and via ``record_tool_execution``. This includes calls rejected by
        the tool policy or by input validation, which are recorded as
        failures with error types "PolicyDenied" and "ValidationError".

        Args:
            **kwargs: Tool parameters

        Returns:
            Tool execution result. Failures are returned with
            ``success=False`` rather than raised; this covers policy denial,
            invalid input, tool-reported failure, and exceptions.
        """
        # perf_counter is monotonic, so durations are immune to wall-clock jumps.
        start_time = time.perf_counter()
        try:
            with span("genxai.tool.execute", {"tool_name": self.metadata.name}):
                self.__check_access()

                allowed, reason = is_tool_allowed(self.metadata.name)
                if not allowed:
                    return self.__reject(
                        reason or "Tool execution denied by policy", "PolicyDenied", start_time
                    )

                if not self.validate_input(**kwargs):
                    return self.__reject("Invalid input parameters", "ValidationError", start_time)

                raw_result = await self._execute(**kwargs)
                return self.__normalize_result(raw_result, start_time)

        except Exception as e:
            execution_time = time.perf_counter() - start_time
            logger.error("Tool %s failed: %s", self.metadata.name, e)
            record_exception(e)
            self.__record_outcome(False, execution_time, type(e).__name__)
            return ToolResult(
                success=False, data=None, error=str(e), execution_time=execution_time
            )

    def __check_access(self) -> None:
        """Run the policy-engine check and audit the access for the current user, if any."""
        user = get_current_user()
        if user is None:
            return
        resource_id = f"tool:{self.metadata.name}"
        # The return value of check() is not inspected here. A denial is
        # presumably signalled by an exception, which execute() turns into a
        # failed ToolResult. TODO confirm against genxai.security.policy_engine.
        get_policy_engine().check(user, resource_id, Permission.TOOL_EXECUTE)
        get_audit_log().record(
            AuditEvent(
                action="tool.execute",
                actor_id=user.user_id,
                resource_id=resource_id,
                status="allowed",
            )
        )

    def __result_metadata(self) -> Dict[str, Any]:
        """Return the default metadata attached to results built by this class."""
        return {"tool": self.metadata.name, "version": self.metadata.version}

    def __record_outcome(
        self, success: bool, execution_time: float, error_type: Optional[str] = None
    ) -> None:
        """Update the per-instance counters and emit the execution metric.

        ``error_type`` should be a short category (an exception class name or
        a fixed tag), never a free-form message. Free-form text would make
        the value unbounded, which is problematic if it is used as a metrics
        label.
        """
        self._execution_count += 1
        self._total_execution_time += execution_time
        if success:
            self._success_count += 1
            record_tool_execution(
                tool_name=self.metadata.name,
                duration=execution_time,
                status="success",
            )
        else:
            self._failure_count += 1
            record_tool_execution(
                tool_name=self.metadata.name,
                duration=execution_time,
                status="error",
                error_type=error_type,
            )

    def __reject(self, error: str, error_type: str, start_time: float) -> ToolResult:
        """Build and record a failed result for a call rejected before ``_execute`` ran."""
        execution_time = time.perf_counter() - start_time
        self.__record_outcome(False, execution_time, error_type)
        return ToolResult(
            success=False,
            data=None,
            error=error,
            execution_time=execution_time,
        )

    def __normalize_result(self, raw_result: Any, start_time: float) -> ToolResult:
        """Map the value returned by ``_execute`` onto a ToolResult and record it.

        - A ToolResult is respected; empty metadata and the execution time
          are filled in.
        - A dict whose "success" value is the boolean False signals failure.
          Its "error" entry is propagated, and the full dict is kept as
          ``data`` to aid debugging.
        - Anything else, including a dict with ``success`` True, is
          successful data.
        """
        execution_time = time.perf_counter() - start_time
        name = self.metadata.name

        if isinstance(raw_result, ToolResult):
            if not raw_result.metadata:
                raw_result.metadata = self.__result_metadata()
            raw_result.execution_time = execution_time
            self.__record_outcome(
                raw_result.success,
                execution_time,
                None if raw_result.success else "ToolError",
            )
            return raw_result

        if isinstance(raw_result, dict) and raw_result.get("success") is False:
            tool_error = raw_result.get("error")
            logger.warning(
                "Tool %s reported failure in %.2fs: %s", name, execution_time, tool_error
            )
            # ToolResult.error is Optional[str]; coerce so that a non-string
            # "error" payload cannot make result construction itself fail.
            result = ToolResult(
                success=False,
                data=raw_result,
                error=str(tool_error) if tool_error else "Tool reported failure",
                execution_time=execution_time,
                metadata=self.__result_metadata(),
            )
            self.__record_outcome(False, execution_time, "ToolError")
            return result

        logger.info("Tool %s executed successfully in %.2fs", name, execution_time)
        result = ToolResult(
            success=True,
            data=raw_result,
            execution_time=execution_time,
            metadata=self.__result_metadata(),
        )
        self.__record_outcome(True, execution_time)
        return result

    @abstractmethod
    async def _execute(self, **kwargs: Any) -> Any:
        """Implement tool-specific logic.

        Args:
            **kwargs: Tool parameters

        Returns:
            Tool result data. This may be a ToolResult, a dict with a boolean
            "success" key, or any other value, which is treated as successful
            data.
        """
        pass

    def validate_input(self, **kwargs: Any) -> bool:
        """Validate input parameters against schema.

        For each declared parameter:

        - a required parameter must be present;
        - a present value must match the declared ``type`` (only "string",
          "number" and "boolean" are checked, and a bool is not accepted as
          a number);
        - it must lie within ``min_value``/``max_value`` when these are set;
        - it must belong to ``enum`` when that is set.

        The first failure is logged.

        Args:
            **kwargs: Input parameters

        Returns:
            True if valid, False otherwise. A value that cannot be compared
            with the range bounds (e.g. None or a string) is reported as
            invalid instead of raising TypeError.
        """
        type_checks = {
            "string": lambda v: isinstance(v, str),
            # bool subclasses int, so it must be excluded explicitly.
            "number": lambda v: isinstance(v, (int, float)) and not isinstance(v, bool),
            "boolean": lambda v: isinstance(v, bool),
        }

        for param in self.parameters:
            if param.name not in kwargs:
                if param.required:
                    logger.error("Missing required parameter: %s", param.name)
                    return False
                continue

            value = kwargs[param.name]

            check = type_checks.get(param.type)
            if check is not None and not check(value):
                logger.error("Parameter %s must be %s", param.name, param.type)
                return False

            try:
                if param.min_value is not None and value < param.min_value:
                    logger.error("Parameter %s must be >= %s", param.name, param.min_value)
                    return False
                if param.max_value is not None and value > param.max_value:
                    logger.error("Parameter %s must be <= %s", param.name, param.max_value)
                    return False
            except TypeError:
                logger.error(
                    "Parameter %s cannot be compared with its range bounds", param.name
                )
                return False

            if param.enum and value not in param.enum:
                logger.error("Parameter %s must be one of %s", param.name, param.enum)
                return False

        return True

    def get_schema(self) -> Dict[str, Any]:
        """Generate OpenAPI-style schema.

        Returns:
            Tool schema dictionary
        """
        def _build_param_schema(param: ToolParameter) -> Dict[str, Any]:
            schema: Dict[str, Any] = {
                "type": param.type,
                "description": param.description,
            }

            if param.enum:
                schema["enum"] = param.enum
            if param.default is not None:
                schema["default"] = param.default
            if param.pattern:
                schema["pattern"] = param.pattern

            # JSON Schema applies minimum/maximum to both numeric types.
            if param.type in ("number", "integer"):
                if param.min_value is not None:
                    schema["minimum"] = param.min_value
                if param.max_value is not None:
                    schema["maximum"] = param.max_value

            return schema

        return {
            "name": self.metadata.name,
            "description": self.metadata.description,
            "category": self.metadata.category.value,
            "parameters": {
                "type": "object",
                "properties": {
                    param.name: _build_param_schema(param)
                    for param in self.parameters
                },
                "required": [p.name for p in self.parameters if p.required],
            },
        }

    def get_metrics(self) -> Dict[str, Any]:
        """Get tool execution metrics.

        Returns:
            Metrics dictionary
        """
        count = self._execution_count
        return {
            "execution_count": count,
            "success_count": self._success_count,
            "failure_count": self._failure_count,
            "success_rate": self._success_count / count if count > 0 else 0.0,
            "total_execution_time": self._total_execution_time,
            "average_execution_time": (
                self._total_execution_time / count if count > 0 else 0.0
            ),
        }

    def reset_metrics(self) -> None:
        """Reset tool metrics."""
        self._execution_count = 0
        self._total_execution_time = 0.0
        self._success_count = 0
        self._failure_count = 0

    def __repr__(self) -> str:
        """String representation."""
        # Use .value explicitly: f-string formatting of str-mixin Enum members
        # differs across Python versions ("web" vs "ToolCategory.WEB").
        return f"Tool(name={self.metadata.name}, category={self.metadata.category.value})"
|