kailash 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/mcp/server_new.py +6 -6
- kailash/nodes/data/__init__.py +1 -2
- kailash/nodes/data/sql.py +699 -256
- kailash/workflow/cycle_analyzer.py +346 -225
- kailash/workflow/cycle_builder.py +75 -69
- kailash/workflow/cycle_config.py +62 -46
- kailash/workflow/cycle_debugger.py +284 -184
- kailash/workflow/cycle_exceptions.py +111 -97
- kailash/workflow/cycle_profiler.py +272 -202
- kailash/workflow/migration.py +238 -197
- kailash/workflow/templates.py +124 -105
- kailash/workflow/validation.py +356 -298
- {kailash-0.2.0.dist-info → kailash-0.2.1.dist-info}/METADATA +4 -1
- {kailash-0.2.0.dist-info → kailash-0.2.1.dist-info}/RECORD +18 -18
- {kailash-0.2.0.dist-info → kailash-0.2.1.dist-info}/WHEEL +0 -0
- {kailash-0.2.0.dist-info → kailash-0.2.1.dist-info}/entry_points.txt +0 -0
- {kailash-0.2.0.dist-info → kailash-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.2.0.dist-info → kailash-0.2.1.dist-info}/top_level.txt +0 -0
kailash/nodes/data/sql.py
CHANGED
@@ -12,30 +12,148 @@ Design Philosophy:
     5. Transaction support
     """
 
-
+import os
+import threading
+import time
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Tuple
+
+import yaml
+from sqlalchemy import create_engine, text
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.pool import QueuePool
 
 from kailash.nodes.base import Node, NodeParameter, register_node
+from kailash.sdk_exceptions import NodeExecutionError
 
 
 @register_node()
 class SQLDatabaseNode(Node):
-
+
+    class _DatabaseConfigManager:
+        """Internal manager for database configurations from project settings."""
+
+        def __init__(self, project_config_path: str):
+            """Initialize with project configuration file path."""
+            self.config_path = project_config_path
+            self.config = self._load_project_config()
+
+        def _load_project_config(self) -> Dict[str, Any]:
+            """Load project configuration from YAML file."""
+            if not os.path.exists(self.config_path):
+                raise NodeExecutionError(
+                    f"Project configuration file not found: {self.config_path}"
+                )
+
+            try:
+                with open(self.config_path, "r") as f:
+                    config = yaml.safe_load(f)
+                return config or {}
+            except yaml.YAMLError as e:
+                raise NodeExecutionError(f"Invalid YAML in project configuration: {e}")
+            except Exception as e:
+                raise NodeExecutionError(f"Failed to load project configuration: {e}")
+
+        def get_database_config(
+            self, connection_name: str
+        ) -> Tuple[str, Dict[str, Any]]:
+            """Get database configuration by connection name.
+
+            Args:
+                connection_name: Name of the database connection from project config
+
+            Returns:
+                Tuple of (connection_string, db_config)
+
+            Raises:
+                NodeExecutionError: If connection not found in configuration
+            """
+            databases = self.config.get("databases", {})
+
+            if connection_name in databases:
+                db_config = databases[connection_name].copy()
+                connection_string = db_config.pop("url", None)
+
+                if not connection_string:
+                    raise NodeExecutionError(
+                        f"No 'url' specified for database connection '{connection_name}'"
+                    )
+
+                # Handle environment variable substitution
+                connection_string = self._substitute_env_vars(connection_string)
+
+                return connection_string, db_config
+
+            # Fall back to default configuration
+            if "default" in databases:
+                default_config = databases["default"].copy()
+                connection_string = default_config.pop("url", None)
+
+                if connection_string:
+                    connection_string = self._substitute_env_vars(connection_string)
+                    return connection_string, default_config
+
+            # Ultimate fallback
+            raise NodeExecutionError(
+                f"Database connection '{connection_name}' not found in project configuration. "
+                f"Available connections: {list(databases.keys())}"
+            )
+
+        def _substitute_env_vars(self, value: str) -> str:
+            """Substitute environment variables in configuration values."""
+            if (
+                isinstance(value, str)
+                and value.startswith("${")
+                and value.endswith("}")
+            ):
+                env_var = value[2:-1]
+                env_value = os.getenv(env_var)
+                if env_value is None:
+                    raise NodeExecutionError(
+                        f"Environment variable '{env_var}' not found"
+                    )
+                return env_value
+            return value
+
+        def validate_config(self) -> None:
+            """Validate the project configuration."""
+            databases = self.config.get("databases", {})
+
+            if not databases:
+                raise NodeExecutionError(
+                    "No databases configured in project configuration"
+                )
+
+            for name, config in databases.items():
+                if not isinstance(config, dict):
+                    raise NodeExecutionError(
+                        f"Database '{name}' configuration must be a dictionary"
+                    )
+
+                if "url" not in config and name != "default":
+                    raise NodeExecutionError(
+                        f"Database '{name}' missing required 'url' field"
+                    )
+
+    """Executes SQL queries against relational databases with shared connection pools.
 
     This node provides a unified interface for interacting with various RDBMS
     systems including PostgreSQL, MySQL, SQLite, and others. It handles
-    connection management, query execution, and result formatting
+    connection management, query execution, and result formatting using
+    shared connection pools for efficient resource utilization.
 
     Design Features:
-    1.
-    2.
+    1. Shared connection pools across all node instances
+    2. Project-level database configuration
     3. Parameterized queries to prevent SQL injection
     4. Flexible result formats (dict, list, raw)
     5. Transaction support with commit/rollback
     6. Query timeout handling
+    7. Connection pool monitoring and metrics
 
     Data Flow:
-    - Input: SQL query, parameters
-    - Processing: Execute query, format results
+    - Input: Connection name (from project config), SQL query, parameters
+    - Processing: Execute query using shared pools, format results
     - Output: Query results in specified format
 
     Common Usage Patterns:
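Note: a minimal sketch of the project configuration that _DatabaseConfigManager reads above — a top-level "databases" mapping whose entries carry a "url" (substituted from an environment variable only when the whole value is "${VAR}") plus any extra pool settings. The file name and all values here are illustrative, not from the package:

import os

# Hypothetical config in the shape the loader expects.
os.environ["CUSTOMER_DB_URL"] = "postgresql://app:secret@localhost/customers"

config_yaml = """
databases:
  customer_db:
    url: ${CUSTOMER_DB_URL}
    pool_size: 10
  default:
    url: sqlite:///fallback.db
"""

with open("kailash_project.yaml", "w") as f:
    f.write(config_yaml)

# SQLDatabaseNode.initialize("kailash_project.yaml") would validate this file;
# get_database_config("customer_db") resolves the env var and returns
# ("postgresql://app:secret@localhost/customers", {"pool_size": 10}), and an
# unknown connection name falls back to the "default" entry if one exists.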
@@ -45,57 +163,106 @@ class SQLDatabaseNode(Node):
     4. Report generation
     5. Data validation queries
 
-    Upstream Sources:
-    - User-defined queries
-    - Query builder nodes
-    - Template processors
-    - Previous query results
-
-    Downstream Consumers:
-    - Transform nodes: Process query results
-    - Writer nodes: Export to files
-    - Aggregator nodes: Summarize data
-    - Visualization nodes: Create charts
-
-    Error Handling:
-    - ConnectionError: Database connection issues
-    - QueryError: SQL syntax or execution errors
-    - TimeoutError: Query execution timeout
-    - PermissionError: Access denied
-
     Example:
-    >>> #
-    >>>
-
+    >>> # Initialize with project configuration
+    >>> SQLDatabaseNode.initialize('kailash_project.yaml')
+    >>>
+    >>> # Create node with database connection configuration
+    >>> sql_node = SQLDatabaseNode(connection='customer_db')
+    >>>
+    >>> # Execute multiple queries with the same node
+    >>> result1 = sql_node.run(
     ...     query='SELECT * FROM customers WHERE active = ?',
-    ...     parameters=[True]
-    ...
+    ...     parameters=[True]
+    ... )
+    >>> result2 = sql_node.run(
+    ...     query='SELECT COUNT(*) as total FROM orders'
     ... )
-    >>>
-    >>> # result['data'] = [
+    >>> # result1['data'] = [
     >>> #     {'id': 1, 'name': 'John', 'active': True},
     >>> #     {'id': 2, 'name': 'Jane', 'active': True}
     >>> # ]
     """
 
+    # Class-level shared resources for connection pooling
+    _shared_pools: Dict[Tuple[str, frozenset], Any] = {}
+    _pool_metrics: Dict[Tuple[str, frozenset], Dict[str, Any]] = {}
+    _pool_lock = threading.Lock()
+    _config_manager: Optional["SQLDatabaseNode._DatabaseConfigManager"] = None
+
+    # NOTE: This method is deprecated in favor of direct configuration in constructor
+    @classmethod
+    def initialize(cls, project_config_path: str) -> None:
+        """Initialize shared resources with project configuration.
+
+        DEPRECATED: Use direct configuration in constructor instead.
+
+        Args:
+            project_config_path: Path to the project configuration YAML file
+        """
+        with cls._pool_lock:
+            cls._config_manager = cls._DatabaseConfigManager(project_config_path)
+            cls._config_manager.validate_config()
+
+    def __init__(
+        self,
+        connection_string: str = None,
+        pool_size: int = 5,
+        max_overflow: int = 10,
+        pool_timeout: int = 30,
+        pool_recycle: int = 3600,
+        pool_pre_ping: bool = True,
+        echo: bool = False,
+        connect_args: dict = None,
+        **kwargs,
+    ):
+        """Initialize SQLDatabaseNode with direct database connection configuration.
+
+        Args:
+            connection_string: Database connection URL (e.g., "sqlite:///path/to/db.db")
+            pool_size: Number of connections in the pool (default: 5)
+            max_overflow: Maximum overflow connections (default: 10)
+            pool_timeout: Timeout in seconds to get connection from pool (default: 30)
+            pool_recycle: Time in seconds to recycle connections (default: 3600)
+            pool_pre_ping: Test connections before use (default: True)
+            echo: Enable SQLAlchemy query logging (default: False)
+            connect_args: Additional database-specific connection arguments
+            **kwargs: Additional node configuration parameters
+        """
+        if not connection_string:
+            raise NodeExecutionError("connection_string parameter is required")
+
+        # Store connection configuration
+        self.connection_string = connection_string
+        self.db_config = {
+            "pool_size": pool_size,
+            "max_overflow": max_overflow,
+            "pool_timeout": pool_timeout,
+            "pool_recycle": pool_recycle,
+            "pool_pre_ping": pool_pre_ping,
+            "echo": echo,
+        }
+
+        if connect_args:
+            self.db_config["connect_args"] = connect_args
+
+        # Add connection_string to kwargs for base class validation
+        kwargs["connection_string"] = connection_string
+
+        # Call parent constructor
+        super().__init__(**kwargs)
+
     def get_parameters(self) -> Dict[str, NodeParameter]:
         """Define input parameters for SQL execution.
 
-
-
+        Configuration parameters (provided to constructor):
+        1. connection_string: Database connection URL
+        2. pool_size, max_overflow, etc.: Connection pool configuration
 
-
-
-
-
-        4. result_format: Output structure preference
-        5. timeout: Query execution limit
-        6. transaction_mode: Transaction handling
-
-        Security considerations:
-        - Always use parameterized queries
-        - Connection strings should use environment variables
-        - Validate query permissions
+        Runtime parameters (passed to run() method):
+        3. query: SQL query to execute
+        4. parameters: Query parameters for safety
+        5. result_format: Output format
 
         Returns:
             Dictionary of parameter definitions
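A quick usage sketch against the new constructor signature. Note it takes connection_string directly, while the class docstring example above still shows the deprecated initialize()/connection flow; the SQLite URL and query here are illustrative:

# Minimal sketch, assuming the kailash package is installed.
node = SQLDatabaseNode(
    connection_string="sqlite:///example.db",
    pool_size=2,          # small pool for a local file database
    pool_pre_ping=True,   # default: validate connections before use
)
result = node.run(
    query="SELECT 1 AS answer",
    result_format="dict",
)
print(result["row_count"], result["data"])   # 1 [{'answer': 1}]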
@@ -105,13 +272,13 @@ class SQLDatabaseNode(Node):
                 name="connection_string",
                 type=str,
                 required=True,
-                description="Database connection
+                description="Database connection URL (e.g., 'sqlite:///path/to/db.db')",
             ),
             "query": NodeParameter(
                 name="query",
                 type=str,
-                required=
-                description="SQL query to execute (use ? for
+                required=False,  # Not required in constructor, provided at runtime
+                description="SQL query to execute (use ? for SQLite, $1 for PostgreSQL, %s for MySQL)",
             ),
             "parameters": NodeParameter(
                 name="parameters",
@@ -127,55 +294,54 @@ class SQLDatabaseNode(Node):
                 default="dict",
                 description="Result format: 'dict', 'list', or 'raw'",
             ),
-            "timeout": NodeParameter(
-                name="timeout",
-                type=int,
-                required=False,
-                default=30,
-                description="Query timeout in seconds",
-            ),
-            "transaction_mode": NodeParameter(
-                name="transaction_mode",
-                type=str,
-                required=False,
-                default="auto",
-                description="Transaction mode: 'auto', 'manual', or 'none'",
-            ),
         }
 
+    @staticmethod
+    def _make_hashable(obj):
+        """Convert nested dictionaries/lists to hashable tuples for cache keys."""
+        if isinstance(obj, dict):
+            return tuple(
+                sorted((k, SQLDatabaseNode._make_hashable(v)) for k, v in obj.items())
+            )
+        elif isinstance(obj, list):
+            return tuple(SQLDatabaseNode._make_hashable(item) for item in obj)
+        else:
+            return obj
+
+    def _get_shared_engine(self):
+        """Get or create shared engine for database connection."""
+        cache_key = (self.connection_string, self._make_hashable(self.db_config))
+
+        with self._pool_lock:
+            if cache_key not in self._shared_pools:
+                self.logger.info(
+                    f"Creating shared pool for {SQLDatabaseNode._mask_connection_password(self.connection_string)}"
+                )
+
+                # Apply configuration with sensible defaults
+                pool_config = {
+                    "poolclass": QueuePool,
+                    **self.db_config,  # Use the stored db_config
+                }
+
+                engine = create_engine(self.connection_string, **pool_config)
+
+                self._shared_pools[cache_key] = engine
+                self._pool_metrics[cache_key] = {
+                    "created_at": datetime.now(),
+                    "total_queries": 0,
+                }
+
+            return self._shared_pools[cache_key]
+
     def run(self, **kwargs) -> Dict[str, Any]:
-        """Execute SQL query
-
-        Performs database query execution with proper connection handling,
-        parameter binding, and result formatting.
-
-        Processing Steps:
-        1. Parse connection string
-        2. Establish database connection
-        3. Prepare parameterized query
-        4. Execute with timeout
-        5. Format results
-        6. Handle transactions
-        7. Close connection
-
-        Connection Management:
-        - Uses connection pooling when available
-        - Automatic retry on connection failure
-        - Proper cleanup on errors
-
-        Result Formatting:
-        - dict: List of dictionaries with column names
-        - list: List of lists (raw rows)
-        - raw: Database cursor object
+        """Execute SQL query using shared connection pool.
 
         Args:
             **kwargs: Validated parameters including:
-                - connection_string: Database URL
                 - query: SQL statement
-                - parameters: Query parameters
-                - result_format: Output format
-                - timeout: Execution timeout
-                - transaction_mode: Transaction handling
+                - parameters: Query parameters (optional)
+                - result_format: Output format (optional)
 
         Returns:
             Dictionary containing:
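The pool sharing above keys engines on (connection_string, hashable db_config), so node instances with identical settings reuse one SQLAlchemy engine. A sketch of the expected behavior, with an illustrative URL:

# Same URL and pool settings -> same cache key -> same shared engine.
a = SQLDatabaseNode(connection_string="sqlite:///shared.db")
b = SQLDatabaseNode(connection_string="sqlite:///shared.db")
assert a._get_shared_engine() is b._get_shared_engine()

# A different pool_size changes the cache key, so a separate engine is created.
c = SQLDatabaseNode(connection_string="sqlite:///shared.db", pool_size=7)
assert c._get_shared_engine() is not a._get_shared_engine()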
@@ -186,195 +352,472 @@ class SQLDatabaseNode(Node):
 
         Raises:
             NodeExecutionError: Connection or query errors
-            NodeValidationError: Invalid parameters
-            TimeoutError: Query timeout exceeded
         """
-
-        query = kwargs
-
+        # Extract validated inputs
+        query = kwargs.get("query")
+        parameters = kwargs.get("parameters", [])
         result_format = kwargs.get("result_format", "dict")
-        # timeout = kwargs.get("timeout", 30)  # TODO: Implement query timeout
-        # transaction_mode = kwargs.get("transaction_mode", "auto")  # TODO: Implement transaction handling
-
-        # This is a placeholder implementation
-        # In a real implementation, you would:
-        # 1. Use appropriate database driver (psycopg2, pymysql, sqlite3, etc.)
-        # 2. Implement connection pooling
-        # 3. Handle parameterized queries properly
-        # 4. Implement timeout handling
-        # 5. Format results according to result_format
-
-        self.logger.info(f"Executing SQL query on {connection_string}")
-
-        # Simulate query execution
-        # In real implementation, use actual database connection
-        if "SELECT" in query.upper():
-            # Simulate SELECT query results
-            data = [
-                {"id": 1, "name": "Sample1", "value": 100},
-                {"id": 2, "name": "Sample2", "value": 200},
-            ]
-            columns = ["id", "name", "value"]
-            row_count = len(data)
-        else:
-            # Simulate INSERT/UPDATE/DELETE
-            data = []
-            columns = []
-            row_count = 1  # Affected rows
 
-        #
-        if
-
-
-
-
-
+        # Validate required parameters
+        if not query:
+            raise NodeExecutionError("query parameter is required")
+
+        # Validate query safety
+        self._validate_query_safety(query)
+
+        # Mask password in connection string for logging
+        masked_connection = SQLDatabaseNode._mask_connection_password(
+            self.connection_string
+        )
+        self.logger.info(f"Executing SQL query on {masked_connection}")
+        self.logger.debug(f"Query: {query}")
+        self.logger.debug(f"Parameters: {parameters}")
+
+        # Get shared engine
+        engine = self._get_shared_engine()
+
+        # Track metrics - use same cache key generation logic
+        cache_key = (self.connection_string, self._make_hashable(self.db_config))
+        with self._pool_lock:
+            self._pool_metrics[cache_key]["total_queries"] += 1
+
+        # Execute query with shared connection pool
+        start_time = time.time()
+
+        try:
+            with engine.connect() as conn:
+                with conn.begin() as trans:
+                    try:
+                        # Handle parameterized queries
+                        # SQLAlchemy 2.0 with text() requires named parameters for positional values
+                        if parameters:
+                            if isinstance(parameters, dict):
+                                # Named parameters - use as-is
+                                result = conn.execute(text(query), parameters)
+                            elif isinstance(parameters, (list, tuple)):
+                                # Convert positional parameters to named parameters
+                                named_query, param_dict = (
+                                    self._convert_to_named_parameters(query, parameters)
+                                )
+                                result = conn.execute(text(named_query), param_dict)
+                            else:
+                                # Single parameter
+                                named_query, param_dict = (
+                                    self._convert_to_named_parameters(
+                                        query, [parameters]
+                                    )
+                                )
+                                result = conn.execute(text(named_query), param_dict)
+                        else:
+                            result = conn.execute(text(query))
+
+                        execution_time = time.time() - start_time
+
+                        # Process results
+                        if result.returns_rows:
+                            rows = result.fetchall()
+                            columns = list(result.keys()) if result.keys() else []
+                            row_count = len(rows)
+                            formatted_data = self._format_results(
+                                rows, columns, result_format
+                            )
+                        else:
+                            formatted_data = []
+                            columns = []
+                            row_count = result.rowcount if result.rowcount != -1 else 0
+
+                        trans.commit()
+
+                    except Exception:
+                        trans.rollback()
+                        raise
+
+        except SQLAlchemyError as e:
+            execution_time = time.time() - start_time
+            sanitized_error = self._sanitize_error_message(str(e))
+            error_msg = f"Database error: {sanitized_error}"
+            self.logger.error(error_msg)
+            raise NodeExecutionError(error_msg) from e
+
+        except Exception as e:
+            execution_time = time.time() - start_time
+            sanitized_error = self._sanitize_error_message(str(e))
+            error_msg = f"Unexpected error during query execution: {sanitized_error}"
+            self.logger.error(error_msg)
+            raise NodeExecutionError(error_msg) from e
+
+        self.logger.info(
+            f"Query executed successfully in {execution_time:.3f}s, {row_count} rows affected/returned"
+        )
 
         return {
             "data": formatted_data,
             "row_count": row_count,
             "columns": columns,
-            "execution_time":
+            "execution_time": execution_time,
         }
 
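run() accepts named, positional, or scalar parameters, converting the latter two via _convert_to_named_parameters(). A sketch using the node from the earlier example; the table and values are illustrative:

# Named parameters pass through to SQLAlchemy unchanged.
node.run(query="SELECT * FROM users WHERE id = :uid", parameters={"uid": 1})

# Positional placeholders are rewritten to :p0, :p1, ... before execution.
node.run(query="SELECT * FROM users WHERE active = ? AND country = ?",
         parameters=[True, "USA"])

# A bare scalar is wrapped in a one-element list first.
node.run(query="SELECT * FROM users WHERE id = ?", parameters=1)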
+    @classmethod
+    def get_pool_status(cls) -> Dict[str, Any]:
+        """Get status of all shared connection pools."""
+        with cls._pool_lock:
+            status = {}
+            for key, engine in cls._shared_pools.items():
+                pool = engine.pool
+                connection_string = key[0]
+                masked_string = SQLDatabaseNode._mask_connection_password(
+                    connection_string
+                )
+
+                status[masked_string] = {
+                    "pool_size": pool.size(),
+                    "checked_out": pool.checkedout(),
+                    "overflow": pool.overflow(),
+                    "total_capacity": pool.size() + pool.overflow(),
+                    "utilization": (
+                        pool.checkedout() / (pool.size() + pool.overflow())
+                        if (pool.size() + pool.overflow()) > 0
+                        else 0
+                    ),
+                    "metrics": cls._pool_metrics.get(key, {}),
+                }
+
+        return status
+
+    @classmethod
+    def cleanup_pools(cls):
+        """Clean up all shared connection pools."""
+        with cls._pool_lock:
+            for engine in cls._shared_pools.values():
+                engine.dispose()
+            cls._shared_pools.clear()
+            cls._pool_metrics.clear()
+
+    @staticmethod
+    def _mask_connection_password(connection_string: str) -> str:
+        """Mask password in connection string for secure logging."""
+        import re
+
+        pattern = r"(://[^:]+:)[^@]+(@)"
+        return re.sub(pattern, r"\1***\2", connection_string)
+
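get_pool_status() keys its report by the masked URL, so credentials never reach logs or metrics. A sketch of the masking and the shape of the status dict, with illustrative values:

masked = SQLDatabaseNode._mask_connection_password(
    "postgresql://app:secret@db.internal/sales"
)
print(masked)  # postgresql://app:***@db.internal/sales

status = SQLDatabaseNode.get_pool_status()
# e.g. {"postgresql://app:***@db.internal/sales": {
#          "pool_size": 5, "checked_out": 1, "overflow": 0,
#          "total_capacity": 5, "utilization": 0.2,
#          "metrics": {"created_at": ..., "total_queries": 42}}}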
+    def _validate_query_safety(self, query: str) -> None:
+        """Validate query for potential security issues.
 
-
-
-        """Builds SQL queries dynamically from components.
+        Args:
+            query: SQL query to validate
 
-
-
+        Raises:
+            NodeExecutionError: If query contains dangerous operations
+        """
+        if not query:
+            return
+
+        # Convert to uppercase for case-insensitive checks
+        query_upper = query.upper().strip()
+
+        # Check for dangerous SQL operations in dynamic queries
+        dangerous_keywords = [
+            "DROP",
+            "DELETE",
+            "TRUNCATE",
+            "ALTER",
+            "CREATE",
+            "GRANT",
+            "REVOKE",
+            "EXEC",
+            "EXECUTE",
+            "SHUTDOWN",
+            "BACKUP",
+            "RESTORE",
+        ]
+
+        # Only flag if these appear as standalone words (not within other words)
+        import re
+
+        for keyword in dangerous_keywords:
+            # Use word boundaries to match standalone keywords
+            pattern = r"\b" + re.escape(keyword) + r"\b"
+            if re.search(pattern, query_upper):
+                self.logger.warning(
+                    f"Query contains potentially dangerous keyword: {keyword}"
+                )
+                # Note: In production, you might want to block these entirely
+                # raise NodeExecutionError(f"Query contains forbidden keyword: {keyword}")
+
+    def _sanitize_identifier(self, identifier: str) -> str:
+        """Sanitize table/column names for dynamic SQL.
 
-
-        2. Automatic parameter binding
-        3. SQL injection prevention
-        4. Cross-database SQL generation
-        5. Query validation
+        Args:
+            identifier: Table or column name
 
-
-        2. Conditional filtering
-        3. Multi-table joins
-        4. Aggregation queries
+        Returns:
+            Sanitized identifier
 
-
-
-
-
-
-        ...     order_by=['name'],
-        ...     limit=100
-        ... )
-        >>> result = builder.execute()
-        >>> # result['query'] = 'SELECT name, email FROM customers WHERE active = ? AND country = ? ORDER BY name LIMIT 100'
-        >>> # result['parameters'] = [True, 'USA']
-        """
+        Raises:
+            NodeExecutionError: If identifier contains invalid characters
+        """
+        if not identifier:
+            return identifier
 
-
-
+        import re
+
+        # Allow only alphanumeric characters, underscores, and dots
+        if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_\.]*$", identifier):
+            raise NodeExecutionError(
+                f"Invalid identifier '{identifier}': must contain only letters, numbers, underscores, and dots"
+            )
+
+        # Check for SQL injection attempts
+        dangerous_patterns = [
+            r'[\'"`;]',  # Quotes and semicolons
+            r"--",  # SQL comments
+            r"/\*",  # Block comment start
+            r"\*/",  # Block comment end
+        ]
+
+        for pattern in dangerous_patterns:
+            if re.search(pattern, identifier):
+                raise NodeExecutionError(
+                    f"Invalid identifier '{identifier}': contains potentially dangerous characters"
+                )
+
+        return identifier
+
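Note that _validate_query_safety() only warns (the raise is commented out), while _sanitize_identifier() hard-fails. A sketch of both behaviors, with illustrative inputs (NodeExecutionError comes from kailash.sdk_exceptions):

node._validate_query_safety("DROP TABLE users")  # logs a warning, does not raise

node._sanitize_identifier("customers.name")      # ok: letters, digits, _, .
try:
    node._sanitize_identifier("name; DROP TABLE users")
except NodeExecutionError as e:
    print(e)  # Invalid identifier ...: must contain only letters, numbers, ...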
+    def _validate_connection_string(self, connection_string: str) -> None:
+        """Validate connection string format and security.
+
+        Args:
+            connection_string: Database connection URL
+
+        Raises:
+            NodeExecutionError: If connection string is invalid or insecure
+        """
+        if not connection_string:
+            raise NodeExecutionError("Connection string cannot be empty")
+
+        # Check for supported database types (including driver specifications)
+        supported_protocols = ["sqlite", "postgresql", "mysql"]
+        protocol = (
+            connection_string.split("://")[0].lower()
+            if "://" in connection_string
+            else ""
+        )
+
+        # Handle SQLAlchemy driver specifications (e.g., mysql+pymysql, postgresql+psycopg2)
+        base_protocol = protocol.split("+")[0] if "+" in protocol else protocol
+
+        if base_protocol not in supported_protocols:
+            raise NodeExecutionError(
+                f"Unsupported database protocol '{protocol}'. "
+                f"Supported protocols: {', '.join(supported_protocols)}"
+            )
+
+        # Check for SQL injection in connection string
+        if any(char in connection_string for char in ["'", '"', ";", "--"]):
+            raise NodeExecutionError(
+                "Connection string contains potentially dangerous characters"
+            )
+
+    def _implement_connection_retry(
+        self,
+        connection_string: str,
+        timeout: int,
+        db_config: dict = None,
+        max_retries: int = 3,
+    ):
+        """Implement connection retry logic with exponential backoff.
 
-
+        Args:
+            connection_string: Database connection URL
+            timeout: Connection timeout
+            db_config: Database configuration dictionary
+            max_retries: Maximum number of retry attempts
 
         Returns:
-
+            SQLAlchemy engine
+
+        Raises:
+            NodeExecutionError: If all connection attempts fail
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        import time
+
+        # Handle None db_config
+        if db_config is None:
+            db_config = {}
+
+        last_error = None
+
+        for attempt in range(max_retries + 1):
+            try:
+                # Build SQLAlchemy engine configuration with defaults and overrides
+                engine_config = {
+                    "poolclass": QueuePool,
+                    "pool_size": db_config.get("pool_size", 5),
+                    "max_overflow": db_config.get("max_overflow", 10),
+                    "pool_timeout": db_config.get("pool_timeout", timeout),
+                    "pool_recycle": db_config.get("pool_recycle", 3600),
+                    "echo": db_config.get("echo", False),
+                }
+
+                # Add isolation level if specified
+                if "isolation_level" in db_config:
+                    engine_config["isolation_level"] = db_config["isolation_level"]
+
+                # Add any additional SQLAlchemy engine parameters from db_config
+                for key, value in db_config.items():
+                    if key not in [
+                        "pool_size",
+                        "max_overflow",
+                        "pool_timeout",
+                        "pool_recycle",
+                        "echo",
+                        "isolation_level",
+                    ]:
+                        engine_config[key] = value
+
+                engine = create_engine(connection_string, **engine_config)
+
+                # Test the connection
+                with engine.connect() as conn:
+                    conn.execute(text("SELECT 1"))
+
+                if attempt > 0:
+                    self.logger.info(f"Connection established after {attempt} retries")
+
+                return engine
+
+            except Exception as e:
+                last_error = e
+                if attempt < max_retries:
+                    # Exponential backoff: 1s, 2s, 4s
+                    backoff_time = 2**attempt
+                    self.logger.warning(
+                        f"Connection attempt {attempt + 1} failed: {e}. "
+                        f"Retrying in {backoff_time}s..."
+                    )
+                    time.sleep(backoff_time)
+                else:
+                    self.logger.error(
+                        f"All connection attempts failed. Last error: {e}"
+                    )
+
+        raise NodeExecutionError(
+            f"Failed to establish database connection after {max_retries} retries: {last_error}"
+        )
+
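The backoff in _implement_connection_retry() is 2**attempt seconds, so with the default max_retries=3 a persistently failing connection waits 1s, 2s, then 4s (about 7s total) before NodeExecutionError. A sketch of a hypothetical call; note this helper is not wired into run() in this version, which builds engines via _get_shared_engine() instead:

# Illustrative URL; attempts = 1 try + 3 retries, backoff 1s/2s/4s.
engine = node._implement_connection_retry(
    "postgresql://app:secret@db.internal/sales",
    timeout=10,
    max_retries=3,
)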
+    def _sanitize_error_message(self, error_message: str) -> str:
+        """Sanitize error messages to prevent sensitive data exposure.
 
-
-
+        Args:
+            error_message: Original error message
 
-
+        Returns:
+            Sanitized error message
+        """
+        if not error_message:
+            return error_message
+
+        import re
+
+        # Remove potential passwords from error messages
+        patterns_to_mask = [
+            # Connection string passwords
+            (r"://[^:]+:[^@]+@", "://***:***@"),
+            # SQL query content (in some error messages)
+            (r"'[^']*'", "'***'"),
+            # Quoted strings that might contain sensitive data
+            (r'"[^"]*"', '"***"'),
+        ]
+
+        sanitized = error_message
+        for pattern, replacement in patterns_to_mask:
+            sanitized = re.sub(pattern, replacement, sanitized)
+
+        return sanitized
+
+    def _convert_to_named_parameters(self, query: str, parameters: List) -> tuple:
+        """Convert positional parameters to named parameters for SQLAlchemy 2.0.
 
         Args:
-
+            query: SQL query with positional placeholders (?, $1, %s)
+            parameters: List of parameter values
 
         Returns:
-
-            - query: Built SQL query with placeholders
-            - parameters: List of parameter values
+            Tuple of (modified_query, parameter_dict)
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        import re
+
+        # Create parameter dictionary
+        param_dict = {}
+        for i, value in enumerate(parameters):
+            param_dict[f"p{i}"] = value
+
+        # Replace different placeholder formats with named parameters
+        modified_query = query
+
+        # Handle SQLite-style ? placeholders
+        placeholder_count = 0
+
+        def replace_question_mark(match):
+            nonlocal placeholder_count
+            replacement = f":p{placeholder_count}"
+            placeholder_count += 1
+            return replacement
+
+        modified_query = re.sub(r"\?", replace_question_mark, modified_query)
+
+        # Handle PostgreSQL-style $1, $2, etc. placeholders
+        def replace_postgres_placeholder(match):
+            index = int(match.group(1)) - 1  # PostgreSQL uses 1-based indexing
+            return f":p{index}"
+
+        modified_query = re.sub(
+            r"\$(\d+)", replace_postgres_placeholder, modified_query
+        )
+
+        # Handle MySQL-style %s placeholders
+        placeholder_count = 0
+
+        def replace_mysql_placeholder(match):
+            nonlocal placeholder_count
+            replacement = f":p{placeholder_count}"
+            placeholder_count += 1
+            return replacement
+
+        modified_query = re.sub(r"%s", replace_mysql_placeholder, modified_query)
+
+        return modified_query, param_dict
+
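A sketch of what _convert_to_named_parameters() produces for each placeholder style (values illustrative):

q, p = node._convert_to_named_parameters(
    "SELECT * FROM t WHERE a = ? AND b = ?", [1, 2]
)
# q == "SELECT * FROM t WHERE a = :p0 AND b = :p1", p == {"p0": 1, "p1": 2}

q, p = node._convert_to_named_parameters(
    "SELECT * FROM t WHERE a = $2 AND b = $1", [1, 2]
)
# $n is 1-based, so q == "SELECT * FROM t WHERE a = :p1 AND b = :p0"

# Caveat: the ? rewrite is a bare regex, so a literal '?' inside a string
# literal in the SQL would also be rewritten.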
+    def _format_results(
+        self, rows: List, columns: List[str], result_format: str
+    ) -> List[Any]:
+        """Format query results according to specified format.
+
+        Args:
+            rows: Raw database rows
+            columns: Column names
+            result_format: Desired output format
+
+        Returns:
+            Formatted results
+        """
+        if result_format == "dict":
+            # List of dictionaries with column names as keys
+            # SQLAlchemy rows can be converted to dict using _asdict() or dict()
+            return [dict(row._mapping) for row in rows]
+
+        elif result_format == "list":
+            # List of lists (raw rows)
+            return [list(row) for row in rows]
+
+        elif result_format == "raw":
+            # Raw SQLAlchemy row objects (converted to list for JSON serialization)
+            return [list(row) for row in rows]
+
+        else:
+            # Default to dict format
+            self.logger.warning(
+                f"Unknown result_format '{result_format}', defaulting to 'dict'"
+            )
+            return [dict(zip(columns, row)) for row in rows]
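Finally, a sketch of the three result_format options against a two-column result, using real Row objects from an in-memory SQLite database (values illustrative, node as in the earlier sketch):

from sqlalchemy import create_engine, text

eng = create_engine("sqlite://")
with eng.connect() as conn:
    rows = conn.execute(
        text("SELECT 1 AS id, 'John' AS name UNION ALL SELECT 2, 'Jane'")
    ).fetchall()

node._format_results(rows, ["id", "name"], "dict")
# [{'id': 1, 'name': 'John'}, {'id': 2, 'name': 'Jane'}]
node._format_results(rows, ["id", "name"], "list")
# [[1, 'John'], [2, 'Jane']]
# "raw" currently returns the same list-of-lists as "list" (rows are
# converted for JSON serialization), despite the name.
node._format_results(rows, ["id", "name"], "raw")
# [[1, 'John'], [2, 'Jane']]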