kailash 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kailash/nodes/data/sql.py CHANGED
@@ -12,30 +12,148 @@ Design Philosophy:
12
12
  5. Transaction support
13
13
  """
14
14
 
15
- from typing import Any, Dict
15
+ import os
16
+ import threading
17
+ import time
18
+ from datetime import datetime
19
+ from typing import Any, Dict, List, Optional, Tuple
20
+
21
+ import yaml
22
+ from sqlalchemy import create_engine, text
23
+ from sqlalchemy.exc import SQLAlchemyError
24
+ from sqlalchemy.pool import QueuePool
16
25
 
17
26
  from kailash.nodes.base import Node, NodeParameter, register_node
27
+ from kailash.sdk_exceptions import NodeExecutionError
18
28
 
19
29
 
20
30
  @register_node()
21
31
  class SQLDatabaseNode(Node):
22
- """Executes SQL queries against relational databases.
32
+
33
class _DatabaseConfigManager:
    """Internal manager for database configurations from project settings.

    The project file is a YAML document that is loaded once, at construction
    time. Named entries under its top-level ``databases`` key are resolved
    into ``(connection_string, options)`` pairs.
    """

    def __init__(self, project_config_path: str):
        """Load the project configuration stored at ``project_config_path``.

        Raises:
            NodeExecutionError: If the file is missing, cannot be read or
                parsed, or does not contain a mapping at the top level.
        """
        self.config_path = project_config_path
        self.config = self._load_project_config()

    def _load_project_config(self) -> Dict[str, Any]:
        """Load project configuration from YAML file.

        Returns:
            The parsed mapping, or an empty dict for an empty document.

        Raises:
            NodeExecutionError: If the file is missing, unreadable, invalid
                YAML, or its top level is not a mapping.
        """
        if not os.path.exists(self.config_path):
            raise NodeExecutionError(
                f"Project configuration file not found: {self.config_path}"
            )

        try:
            # Decode explicitly instead of relying on the platform default.
            with open(self.config_path, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise NodeExecutionError(
                f"Invalid YAML in project configuration: {e}"
            ) from e
        except Exception as e:
            raise NodeExecutionError(
                f"Failed to load project configuration: {e}"
            ) from e

        if not config:
            return {}
        if not isinstance(config, dict):
            # A top-level list or scalar would otherwise fail later on the
            # first ``.get`` call with an unrelated AttributeError.
            raise NodeExecutionError(
                "Project configuration must be a mapping at the top level, "
                f"got {type(config).__name__}"
            )
        return config

    def _get_databases(self) -> Dict[str, Any]:
        """Return the ``databases`` section, treating a null value as empty."""
        databases = self.config.get("databases") or {}
        if not isinstance(databases, dict):
            raise NodeExecutionError(
                "'databases' section of project configuration must be a dictionary"
            )
        return databases

    @staticmethod
    def _copy_entry(name: str, entry: Any) -> Dict[str, Any]:
        """Return a shallow copy of one database entry, rejecting non-dicts."""
        if not isinstance(entry, dict):
            raise NodeExecutionError(
                f"Database '{name}' configuration must be a dictionary"
            )
        return entry.copy()

    def get_database_config(
        self, connection_name: str
    ) -> Tuple[str, Dict[str, Any]]:
        """Get database configuration by connection name.

        If ``connection_name`` is not configured, the ``default`` entry is
        used instead, provided it defines a ``url``.

        Args:
            connection_name: Name of the database connection from project config

        Returns:
            Tuple of (connection_string, db_config); ``db_config`` is a copy
            of the entry without its ``url`` key.

        Raises:
            NodeExecutionError: If no usable connection is found, the entry
                has no ``url``, the entry is not a dictionary, or a
                referenced environment variable is unset.
        """
        databases = self._get_databases()

        if connection_name in databases:
            db_config = self._copy_entry(
                connection_name, databases[connection_name]
            )
            connection_string = db_config.pop("url", None)

            if not connection_string:
                raise NodeExecutionError(
                    f"No 'url' specified for database connection '{connection_name}'"
                )

            # Handle environment variable substitution
            connection_string = self._substitute_env_vars(connection_string)
            return connection_string, db_config

        # Fall back to default configuration
        if "default" in databases:
            default_config = self._copy_entry("default", databases["default"])
            connection_string = default_config.pop("url", None)

            if connection_string:
                connection_string = self._substitute_env_vars(connection_string)
                return connection_string, default_config

        # Ultimate fallback
        raise NodeExecutionError(
            f"Database connection '{connection_name}' not found in project configuration. "
            f"Available connections: {list(databases.keys())}"
        )

    def _substitute_env_vars(self, value: str) -> str:
        """Substitute environment variables in configuration values.

        Only a value that consists entirely of one ``${NAME}`` reference is
        replaced; any other value is returned unchanged.

        Raises:
            NodeExecutionError: If the referenced variable is not set.
        """
        if (
            isinstance(value, str)
            and value.startswith("${")
            and value.endswith("}")
        ):
            env_var = value[2:-1]
            env_value = os.getenv(env_var)
            if env_value is None:
                raise NodeExecutionError(
                    f"Environment variable '{env_var}' not found"
                )
            return env_value
        return value

    def validate_config(self) -> None:
        """Validate the project configuration.

        Raises:
            NodeExecutionError: If no databases are configured, an entry is
                not a dictionary, or a non-default entry lacks ``url``.
        """
        databases = self._get_databases()

        if not databases:
            raise NodeExecutionError(
                "No databases configured in project configuration"
            )

        for name, config in databases.items():
            if not isinstance(config, dict):
                raise NodeExecutionError(
                    f"Database '{name}' configuration must be a dictionary"
                )

            # "default" may omit the URL; it is then ignored as a fallback.
            if "url" not in config and name != "default":
                raise NodeExecutionError(
                    f"Database '{name}' missing required 'url' field"
                )
137
+
138
+ """Executes SQL queries against relational databases with shared connection pools.
23
139
 
24
140
  This node provides a unified interface for interacting with various RDBMS
25
141
  systems including PostgreSQL, MySQL, SQLite, and others. It handles
26
- connection management, query execution, and result formatting.
142
+ connection management, query execution, and result formatting using
143
+ shared connection pools for efficient resource utilization.
27
144
 
28
145
  Design Features:
29
- 1. Database adapter pattern for multiple RDBMS support
30
- 2. Connection pooling for efficient resource usage
146
+ 1. Shared connection pools across all node instances
147
+ 2. Project-level database configuration
31
148
  3. Parameterized queries to prevent SQL injection
32
149
  4. Flexible result formats (dict, list, raw)
33
150
  5. Transaction support with commit/rollback
34
151
  6. Query timeout handling
152
+ 7. Connection pool monitoring and metrics
35
153
 
36
154
  Data Flow:
37
- - Input: SQL query, parameters, connection config
38
- - Processing: Execute query, format results
155
+ - Input: Connection name (from project config), SQL query, parameters
156
+ - Processing: Execute query using shared pools, format results
39
157
  - Output: Query results in specified format
40
158
 
41
159
  Common Usage Patterns:
@@ -45,57 +163,106 @@ class SQLDatabaseNode(Node):
45
163
  4. Report generation
46
164
  5. Data validation queries
47
165
 
48
- Upstream Sources:
49
- - User-defined queries
50
- - Query builder nodes
51
- - Template processors
52
- - Previous query results
53
-
54
- Downstream Consumers:
55
- - Transform nodes: Process query results
56
- - Writer nodes: Export to files
57
- - Aggregator nodes: Summarize data
58
- - Visualization nodes: Create charts
59
-
60
- Error Handling:
61
- - ConnectionError: Database connection issues
62
- - QueryError: SQL syntax or execution errors
63
- - TimeoutError: Query execution timeout
64
- - PermissionError: Access denied
65
-
66
166
  Example:
67
- >>> # Query customer data
68
- >>> sql_node = SQLDatabaseNode(
69
- ... connection_string='postgresql://user:pass@host/db',
167
+ >>> # Initialize with project configuration
168
+ >>> SQLDatabaseNode.initialize('kailash_project.yaml')
169
+ >>>
170
+ >>> # Create node with database connection configuration
171
+ >>> sql_node = SQLDatabaseNode(connection='customer_db')
172
+ >>>
173
+ >>> # Execute multiple queries with the same node
174
+ >>> result1 = sql_node.run(
70
175
  ... query='SELECT * FROM customers WHERE active = ?',
71
- ... parameters=[True],
72
- ... result_format='dict'
176
+ ... parameters=[True]
177
+ ... )
178
+ >>> result2 = sql_node.run(
179
+ ... query='SELECT COUNT(*) as total FROM orders'
73
180
  ... )
74
- >>> result = sql_node.execute()
75
- >>> # result['data'] = [
181
+ >>> # result1['data'] = [
76
182
  >>> # {'id': 1, 'name': 'John', 'active': True},
77
183
  >>> # {'id': 2, 'name': 'Jane', 'active': True}
78
184
  >>> # ]
79
185
  """
80
186
 
187
+ # Class-level shared resources for connection pooling
188
+ _shared_pools: Dict[Tuple[str, frozenset], Any] = {}
189
+ _pool_metrics: Dict[Tuple[str, frozenset], Dict[str, Any]] = {}
190
+ _pool_lock = threading.Lock()
191
+ _config_manager: Optional["SQLDatabaseNode._DatabaseConfigManager"] = None
192
+
193
+ # NOTE: This method is deprecated in favor of direct configuration in constructor
194
@classmethod
def initialize(cls, project_config_path: str) -> None:
    """Initialize shared resources with project configuration.

    DEPRECATED: Use direct configuration in constructor instead.

    The configuration is loaded and validated before it is published on the
    class, so a failed validation leaves any previously installed manager
    untouched.

    Args:
        project_config_path: Path to the project configuration YAML file

    Raises:
        NodeExecutionError: If the file cannot be loaded or fails validation.
    """
    manager = cls._DatabaseConfigManager(project_config_path)
    manager.validate_config()

    # Only the assignment needs the lock; file I/O happens outside it.
    with cls._pool_lock:
        cls._config_manager = manager
206
+
207
def __init__(
    self,
    connection_string: str = None,
    pool_size: int = 5,
    max_overflow: int = 10,
    pool_timeout: int = 30,
    pool_recycle: int = 3600,
    pool_pre_ping: bool = True,
    echo: bool = False,
    connect_args: dict = None,
    **kwargs,
):
    """Create a node bound to one database URL and one set of pool options.

    Args:
        connection_string: Database connection URL (e.g., "sqlite:///path/to/db.db")
        pool_size: Number of connections in the pool (default: 5)
        max_overflow: Maximum overflow connections (default: 10)
        pool_timeout: Timeout in seconds to get connection from pool (default: 30)
        pool_recycle: Time in seconds to recycle connections (default: 3600)
        pool_pre_ping: Test connections before use (default: True)
        echo: Enable SQLAlchemy query logging (default: False)
        connect_args: Additional database-specific connection arguments;
            stored only when non-empty
        **kwargs: Additional node configuration parameters

    Raises:
        NodeExecutionError: If ``connection_string`` is missing or empty.
    """
    if not connection_string:
        raise NodeExecutionError("connection_string parameter is required")

    engine_options = dict(
        pool_size=pool_size,
        max_overflow=max_overflow,
        pool_timeout=pool_timeout,
        pool_recycle=pool_recycle,
        pool_pre_ping=pool_pre_ping,
        echo=echo,
    )
    if connect_args:
        engine_options["connect_args"] = connect_args

    # Attributes are assigned before the parent constructor runs.
    self.connection_string = connection_string
    self.db_config = engine_options

    # The URL is forwarded as well — presumably so the base Node can
    # validate it as a declared parameter; confirm against Node.__init__.
    super().__init__(**{**kwargs, "connection_string": connection_string})
254
+
81
255
  def get_parameters(self) -> Dict[str, NodeParameter]:
82
256
  """Define input parameters for SQL execution.
83
257
 
84
- Comprehensive parameters supporting various database operations
85
- and configuration options.
258
+ Configuration parameters (provided to constructor):
259
+ 1. connection_string: Database connection URL
260
+ 2. pool_size, max_overflow, etc.: Connection pool configuration
86
261
 
87
- Parameter Design:
88
- 1. connection_string: Database connection details
89
- 2. query: SQL query to execute
90
- 3. parameters: Query parameters for safety
91
- 4. result_format: Output structure preference
92
- 5. timeout: Query execution limit
93
- 6. transaction_mode: Transaction handling
94
-
95
- Security considerations:
96
- - Always use parameterized queries
97
- - Connection strings should use environment variables
98
- - Validate query permissions
262
+ Runtime parameters (passed to run() method):
263
+ 3. query: SQL query to execute
264
+ 4. parameters: Query parameters for safety
265
+ 5. result_format: Output format
99
266
 
100
267
  Returns:
101
268
  Dictionary of parameter definitions
@@ -105,13 +272,13 @@ class SQLDatabaseNode(Node):
105
272
  name="connection_string",
106
273
  type=str,
107
274
  required=True,
108
- description="Database connection string (e.g., 'postgresql://user:pass@host/db')",
275
+ description="Database connection URL (e.g., 'sqlite:///path/to/db.db')",
109
276
  ),
110
277
  "query": NodeParameter(
111
278
  name="query",
112
279
  type=str,
113
- required=True,
114
- description="SQL query to execute (use ? for parameters)",
280
+ required=False, # Not required in constructor, provided at runtime
281
+ description="SQL query to execute (use ? for SQLite, $1 for PostgreSQL, %s for MySQL)",
115
282
  ),
116
283
  "parameters": NodeParameter(
117
284
  name="parameters",
@@ -127,55 +294,54 @@ class SQLDatabaseNode(Node):
127
294
  default="dict",
128
295
  description="Result format: 'dict', 'list', or 'raw'",
129
296
  ),
130
- "timeout": NodeParameter(
131
- name="timeout",
132
- type=int,
133
- required=False,
134
- default=30,
135
- description="Query timeout in seconds",
136
- ),
137
- "transaction_mode": NodeParameter(
138
- name="transaction_mode",
139
- type=str,
140
- required=False,
141
- default="auto",
142
- description="Transaction mode: 'auto', 'manual', or 'none'",
143
- ),
144
297
  }
145
298
 
299
+ @staticmethod
300
+ def _make_hashable(obj):
301
+ """Convert nested dictionaries/lists to hashable tuples for cache keys."""
302
+ if isinstance(obj, dict):
303
+ return tuple(
304
+ sorted((k, SQLDatabaseNode._make_hashable(v)) for k, v in obj.items())
305
+ )
306
+ elif isinstance(obj, list):
307
+ return tuple(SQLDatabaseNode._make_hashable(item) for item in obj)
308
+ else:
309
+ return obj
310
+
311
+ def _get_shared_engine(self):
312
+ """Get or create shared engine for database connection."""
313
+ cache_key = (self.connection_string, self._make_hashable(self.db_config))
314
+
315
+ with self._pool_lock:
316
+ if cache_key not in self._shared_pools:
317
+ self.logger.info(
318
+ f"Creating shared pool for {SQLDatabaseNode._mask_connection_password(self.connection_string)}"
319
+ )
320
+
321
+ # Apply configuration with sensible defaults
322
+ pool_config = {
323
+ "poolclass": QueuePool,
324
+ **self.db_config, # Use the stored db_config
325
+ }
326
+
327
+ engine = create_engine(self.connection_string, **pool_config)
328
+
329
+ self._shared_pools[cache_key] = engine
330
+ self._pool_metrics[cache_key] = {
331
+ "created_at": datetime.now(),
332
+ "total_queries": 0,
333
+ }
334
+
335
+ return self._shared_pools[cache_key]
336
+
146
337
  def run(self, **kwargs) -> Dict[str, Any]:
147
- """Execute SQL query against database.
148
-
149
- Performs database query execution with proper connection handling,
150
- parameter binding, and result formatting.
151
-
152
- Processing Steps:
153
- 1. Parse connection string
154
- 2. Establish database connection
155
- 3. Prepare parameterized query
156
- 4. Execute with timeout
157
- 5. Format results
158
- 6. Handle transactions
159
- 7. Close connection
160
-
161
- Connection Management:
162
- - Uses connection pooling when available
163
- - Automatic retry on connection failure
164
- - Proper cleanup on errors
165
-
166
- Result Formatting:
167
- - dict: List of dictionaries with column names
168
- - list: List of lists (raw rows)
169
- - raw: Database cursor object
338
+ """Execute SQL query using shared connection pool.
170
339
 
171
340
  Args:
172
341
  **kwargs: Validated parameters including:
173
- - connection_string: Database URL
174
342
  - query: SQL statement
175
- - parameters: Query parameters
176
- - result_format: Output format
177
- - timeout: Execution timeout
178
- - transaction_mode: Transaction handling
343
+ - parameters: Query parameters (optional)
344
+ - result_format: Output format (optional)
179
345
 
180
346
  Returns:
181
347
  Dictionary containing:
@@ -186,195 +352,472 @@ class SQLDatabaseNode(Node):
186
352
 
187
353
  Raises:
188
354
  NodeExecutionError: Connection or query errors
189
- NodeValidationError: Invalid parameters
190
- TimeoutError: Query timeout exceeded
191
355
  """
192
- connection_string = kwargs["connection_string"]
193
- query = kwargs["query"]
194
- # parameters = kwargs.get("parameters", []) # TODO: Implement parameterized queries
356
+ # Extract validated inputs
357
+ query = kwargs.get("query")
358
+ parameters = kwargs.get("parameters", [])
195
359
  result_format = kwargs.get("result_format", "dict")
196
- # timeout = kwargs.get("timeout", 30) # TODO: Implement query timeout
197
- # transaction_mode = kwargs.get("transaction_mode", "auto") # TODO: Implement transaction handling
198
-
199
- # This is a placeholder implementation
200
- # In a real implementation, you would:
201
- # 1. Use appropriate database driver (psycopg2, pymysql, sqlite3, etc.)
202
- # 2. Implement connection pooling
203
- # 3. Handle parameterized queries properly
204
- # 4. Implement timeout handling
205
- # 5. Format results according to result_format
206
-
207
- self.logger.info(f"Executing SQL query on {connection_string}")
208
-
209
- # Simulate query execution
210
- # In real implementation, use actual database connection
211
- if "SELECT" in query.upper():
212
- # Simulate SELECT query results
213
- data = [
214
- {"id": 1, "name": "Sample1", "value": 100},
215
- {"id": 2, "name": "Sample2", "value": 200},
216
- ]
217
- columns = ["id", "name", "value"]
218
- row_count = len(data)
219
- else:
220
- # Simulate INSERT/UPDATE/DELETE
221
- data = []
222
- columns = []
223
- row_count = 1 # Affected rows
224
360
 
225
- # Format results based on result_format
226
- if result_format == "dict":
227
- formatted_data = data
228
- elif result_format == "list":
229
- formatted_data = [[row[col] for col in columns] for row in data]
230
- else: # raw
231
- formatted_data = data
361
+ # Validate required parameters
362
+ if not query:
363
+ raise NodeExecutionError("query parameter is required")
364
+
365
+ # Validate query safety
366
+ self._validate_query_safety(query)
367
+
368
+ # Mask password in connection string for logging
369
+ masked_connection = SQLDatabaseNode._mask_connection_password(
370
+ self.connection_string
371
+ )
372
+ self.logger.info(f"Executing SQL query on {masked_connection}")
373
+ self.logger.debug(f"Query: {query}")
374
+ self.logger.debug(f"Parameters: {parameters}")
375
+
376
+ # Get shared engine
377
+ engine = self._get_shared_engine()
378
+
379
+ # Track metrics - use same cache key generation logic
380
+ cache_key = (self.connection_string, self._make_hashable(self.db_config))
381
+ with self._pool_lock:
382
+ self._pool_metrics[cache_key]["total_queries"] += 1
383
+
384
+ # Execute query with shared connection pool
385
+ start_time = time.time()
386
+
387
+ try:
388
+ with engine.connect() as conn:
389
+ with conn.begin() as trans:
390
+ try:
391
+ # Handle parameterized queries
392
+ # SQLAlchemy 2.0 with text() requires named parameters for positional values
393
+ if parameters:
394
+ if isinstance(parameters, dict):
395
+ # Named parameters - use as-is
396
+ result = conn.execute(text(query), parameters)
397
+ elif isinstance(parameters, (list, tuple)):
398
+ # Convert positional parameters to named parameters
399
+ named_query, param_dict = (
400
+ self._convert_to_named_parameters(query, parameters)
401
+ )
402
+ result = conn.execute(text(named_query), param_dict)
403
+ else:
404
+ # Single parameter
405
+ named_query, param_dict = (
406
+ self._convert_to_named_parameters(
407
+ query, [parameters]
408
+ )
409
+ )
410
+ result = conn.execute(text(named_query), param_dict)
411
+ else:
412
+ result = conn.execute(text(query))
413
+
414
+ execution_time = time.time() - start_time
415
+
416
+ # Process results
417
+ if result.returns_rows:
418
+ rows = result.fetchall()
419
+ columns = list(result.keys()) if result.keys() else []
420
+ row_count = len(rows)
421
+ formatted_data = self._format_results(
422
+ rows, columns, result_format
423
+ )
424
+ else:
425
+ formatted_data = []
426
+ columns = []
427
+ row_count = result.rowcount if result.rowcount != -1 else 0
428
+
429
+ trans.commit()
430
+
431
+ except Exception:
432
+ trans.rollback()
433
+ raise
434
+
435
+ except SQLAlchemyError as e:
436
+ execution_time = time.time() - start_time
437
+ sanitized_error = self._sanitize_error_message(str(e))
438
+ error_msg = f"Database error: {sanitized_error}"
439
+ self.logger.error(error_msg)
440
+ raise NodeExecutionError(error_msg) from e
441
+
442
+ except Exception as e:
443
+ execution_time = time.time() - start_time
444
+ sanitized_error = self._sanitize_error_message(str(e))
445
+ error_msg = f"Unexpected error during query execution: {sanitized_error}"
446
+ self.logger.error(error_msg)
447
+ raise NodeExecutionError(error_msg) from e
448
+
449
+ self.logger.info(
450
+ f"Query executed successfully in {execution_time:.3f}s, {row_count} rows affected/returned"
451
+ )
232
452
 
233
453
  return {
234
454
  "data": formatted_data,
235
455
  "row_count": row_count,
236
456
  "columns": columns,
237
- "execution_time": 0.125, # Simulated execution time
457
+ "execution_time": execution_time,
238
458
  }
239
459
 
460
@classmethod
def get_pool_status(cls) -> Dict[str, Any]:
    """Get status of all shared connection pools.

    Returns:
        Mapping from the password-masked connection string to a dict with
        ``pool_size``, ``checked_out``, ``overflow`` (the pool's current
        overflow counter), ``total_capacity``, ``utilization`` and
        ``metrics``. When several pools share a connection string but
        differ in options, later ones get a `` [n]`` suffix so that no
        entry is overwritten.
    """
    with cls._pool_lock:
        status = {}
        for key, engine in cls._shared_pools.items():
            pool = engine.pool
            label = SQLDatabaseNode._mask_connection_password(key[0])

            # Keys differ by pool options too, so the masked URL alone is
            # not guaranteed to be unique.
            if label in status:
                suffix = 2
                while f"{label} [{suffix}]" in status:
                    suffix += 1
                label = f"{label} [{suffix}]"

            size = pool.size()
            checked_out = pool.checkedout()

            # pool.overflow() is the current overflow counter, not the
            # limit. The configured limit is recovered from the option
            # pairs stored in the cache key.
            try:
                configured = dict(key[1])
            except (TypeError, ValueError):
                configured = {}
            max_overflow = configured.get("max_overflow")

            if isinstance(max_overflow, int) and max_overflow >= 0:
                total_capacity = size + max_overflow
            else:
                # Unlimited or unknown overflow: no finite ceiling exists,
                # so report the larger of base size and current usage.
                total_capacity = max(size, checked_out)

            status[label] = {
                "pool_size": size,
                "checked_out": checked_out,
                "overflow": pool.overflow(),
                "total_capacity": total_capacity,
                "utilization": (
                    checked_out / total_capacity if total_capacity > 0 else 0
                ),
                "metrics": cls._pool_metrics.get(key, {}),
            }

        return status
486
+
487
@classmethod
def cleanup_pools(cls):
    """Clean up all shared connection pools.

    Every registered engine is disposed, even if disposing an earlier one
    fails, and both registries are always emptied.

    Raises:
        Exception: The first error raised by an ``engine.dispose()`` call,
            re-raised after all engines were processed.
    """
    first_error = None

    with cls._pool_lock:
        try:
            for engine in cls._shared_pools.values():
                try:
                    engine.dispose()
                except Exception as exc:
                    # Keep going so one bad engine cannot strand the rest.
                    if first_error is None:
                        first_error = exc
        finally:
            cls._shared_pools.clear()
            cls._pool_metrics.clear()

    if first_error is not None:
        raise first_error
495
+
496
+ @staticmethod
497
+ def _mask_connection_password(connection_string: str) -> str:
498
+ """Mask password in connection string for secure logging."""
499
+ import re
500
+
501
+ pattern = r"(://[^:]+:)[^@]+(@)"
502
+ return re.sub(pattern, r"\1***\2", connection_string)
503
+
504
+ def _validate_query_safety(self, query: str) -> None:
505
+ """Validate query for potential security issues.
240
506
 
241
- @register_node()
242
- class SQLQueryBuilderNode(Node):
243
- """Builds SQL queries dynamically from components.
507
+ Args:
508
+ query: SQL query to validate
244
509
 
245
- This node constructs SQL queries programmatically, providing a safe
246
- and flexible way to build complex queries without string concatenation.
510
+ Raises:
511
+ NodeExecutionError: If query contains dangerous operations
512
+ """
513
+ if not query:
514
+ return
515
+
516
+ # Convert to uppercase for case-insensitive checks
517
+ query_upper = query.upper().strip()
518
+
519
+ # Check for dangerous SQL operations in dynamic queries
520
+ dangerous_keywords = [
521
+ "DROP",
522
+ "DELETE",
523
+ "TRUNCATE",
524
+ "ALTER",
525
+ "CREATE",
526
+ "GRANT",
527
+ "REVOKE",
528
+ "EXEC",
529
+ "EXECUTE",
530
+ "SHUTDOWN",
531
+ "BACKUP",
532
+ "RESTORE",
533
+ ]
534
+
535
+ # Only flag if these appear as standalone words (not within other words)
536
+ import re
537
+
538
+ for keyword in dangerous_keywords:
539
+ # Use word boundaries to match standalone keywords
540
+ pattern = r"\b" + re.escape(keyword) + r"\b"
541
+ if re.search(pattern, query_upper):
542
+ self.logger.warning(
543
+ f"Query contains potentially dangerous keyword: {keyword}"
544
+ )
545
+ # Note: In production, you might want to block these entirely
546
+ # raise NodeExecutionError(f"Query contains forbidden keyword: {keyword}")
547
+
548
+ def _sanitize_identifier(self, identifier: str) -> str:
549
+ """Sanitize table/column names for dynamic SQL.
247
550
 
248
- Design Features:
249
- 1. Fluent interface for query building
250
- 2. Automatic parameter binding
251
- 3. SQL injection prevention
252
- 4. Cross-database SQL generation
253
- 5. Query validation
551
+ Args:
552
+ identifier: Table or column name
254
553
 
255
- Common Usage Patterns:
256
- 1. Dynamic report queries
257
- 2. Conditional filtering
258
- 3. Multi-table joins
259
- 4. Aggregation queries
554
+ Returns:
555
+ Sanitized identifier
260
556
 
261
- Example:
262
- >>> builder = SQLQueryBuilderNode(
263
- ... table='customers',
264
- ... select=['name', 'email'],
265
- ... where={'active': True, 'country': 'USA'},
266
- ... order_by=['name'],
267
- ... limit=100
268
- ... )
269
- >>> result = builder.execute()
270
- >>> # result['query'] = 'SELECT name, email FROM customers WHERE active = ? AND country = ? ORDER BY name LIMIT 100'
271
- >>> # result['parameters'] = [True, 'USA']
272
- """
557
+ Raises:
558
+ NodeExecutionError: If identifier contains invalid characters
559
+ """
560
+ if not identifier:
561
+ return identifier
273
562
 
274
- def get_parameters(self) -> Dict[str, NodeParameter]:
275
- """Define input parameters for query building.
563
+ import re
564
+
565
+ # Allow only alphanumeric characters, underscores, and dots
566
+ if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_\.]*$", identifier):
567
+ raise NodeExecutionError(
568
+ f"Invalid identifier '{identifier}': must contain only letters, numbers, underscores, and dots"
569
+ )
570
+
571
+ # Check for SQL injection attempts
572
+ dangerous_patterns = [
573
+ r'[\'"`;]', # Quotes and semicolons
574
+ r"--", # SQL comments
575
+ r"/\*", # Block comment start
576
+ r"\*/", # Block comment end
577
+ ]
578
+
579
+ for pattern in dangerous_patterns:
580
+ if re.search(pattern, identifier):
581
+ raise NodeExecutionError(
582
+ f"Invalid identifier '{identifier}': contains potentially dangerous characters"
583
+ )
584
+
585
+ return identifier
586
+
587
+ def _validate_connection_string(self, connection_string: str) -> None:
588
+ """Validate connection string format and security.
589
+
590
+ Args:
591
+ connection_string: Database connection URL
592
+
593
+ Raises:
594
+ NodeExecutionError: If connection string is invalid or insecure
595
+ """
596
+ if not connection_string:
597
+ raise NodeExecutionError("Connection string cannot be empty")
598
+
599
+ # Check for supported database types (including driver specifications)
600
+ supported_protocols = ["sqlite", "postgresql", "mysql"]
601
+ protocol = (
602
+ connection_string.split("://")[0].lower()
603
+ if "://" in connection_string
604
+ else ""
605
+ )
606
+
607
+ # Handle SQLAlchemy driver specifications (e.g., mysql+pymysql, postgresql+psycopg2)
608
+ base_protocol = protocol.split("+")[0] if "+" in protocol else protocol
609
+
610
+ if base_protocol not in supported_protocols:
611
+ raise NodeExecutionError(
612
+ f"Unsupported database protocol '{protocol}'. "
613
+ f"Supported protocols: {', '.join(supported_protocols)}"
614
+ )
615
+
616
+ # Check for SQL injection in connection string
617
+ if any(char in connection_string for char in ["'", '"', ";", "--"]):
618
+ raise NodeExecutionError(
619
+ "Connection string contains potentially dangerous characters"
620
+ )
621
+
622
def _implement_connection_retry(
    self,
    connection_string: str,
    timeout: int,
    db_config: dict = None,
    max_retries: int = 3,
):
    """Implement connection retry logic with exponential backoff.

    Each attempt creates an engine and probes it with ``SELECT 1``. An
    engine whose probe fails is disposed before the next attempt so that
    failed attempts do not accumulate pools.

    Args:
        connection_string: Database connection URL
        timeout: Used as ``pool_timeout`` unless ``db_config`` overrides it
        db_config: Engine options; every entry overrides the defaults and
            is passed through to ``create_engine``
        max_retries: Maximum number of retry attempts after the first try

    Returns:
        SQLAlchemy engine

    Raises:
        NodeExecutionError: If all connection attempts fail
    """
    import time

    # Handle None db_config
    if db_config is None:
        db_config = {}

    # Defaults first, then caller-supplied options on top. The mapping does
    # not change between attempts, so it is built once.
    engine_config = {
        "poolclass": QueuePool,
        "pool_size": 5,
        "max_overflow": 10,
        "pool_timeout": timeout,
        "pool_recycle": 3600,
        "echo": False,
        **db_config,
    }

    last_error = None

    for attempt in range(max_retries + 1):
        engine = None
        try:
            engine = create_engine(connection_string, **engine_config)

            # Test the connection
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))

            if attempt > 0:
                self.logger.info(f"Connection established after {attempt} retries")

            return engine

        except Exception as e:
            last_error = e

            # Release whatever the failed engine may have opened.
            if engine is not None:
                engine.dispose()

            if attempt < max_retries:
                # Exponential backoff: 1s, 2s, 4s
                backoff_time = 2**attempt
                self.logger.warning(
                    f"Connection attempt {attempt + 1} failed: {e}. "
                    f"Retrying in {backoff_time}s..."
                )
                time.sleep(backoff_time)
            else:
                self.logger.error(
                    f"All connection attempts failed. Last error: {e}"
                )

    raise NodeExecutionError(
        f"Failed to establish database connection after {max_retries} retries: {last_error}"
    ) from last_error
708
+
709
+ def _sanitize_error_message(self, error_message: str) -> str:
710
+ """Sanitize error messages to prevent sensitive data exposure.
329
711
 
330
- def run(self, **kwargs) -> Dict[str, Any]:
331
- """Build SQL query from components.
712
+ Args:
713
+ error_message: Original error message
332
714
 
333
- Constructs a parameterized SQL query from the provided components.
715
+ Returns:
716
+ Sanitized error message
717
+ """
718
+ if not error_message:
719
+ return error_message
720
+
721
+ import re
722
+
723
+ # Remove potential passwords from error messages
724
+ patterns_to_mask = [
725
+ # Connection string passwords
726
+ (r"://[^:]+:[^@]+@", "://***:***@"),
727
+ # SQL query content (in some error messages)
728
+ (r"'[^']*'", "'***'"),
729
+ # Quoted strings that might contain sensitive data
730
+ (r'"[^"]*"', '"***"'),
731
+ ]
732
+
733
+ sanitized = error_message
734
+ for pattern, replacement in patterns_to_mask:
735
+ sanitized = re.sub(pattern, replacement, sanitized)
736
+
737
+ return sanitized
738
+
739
+ def _convert_to_named_parameters(self, query: str, parameters: List) -> tuple:
740
+ """Convert positional parameters to named parameters for SQLAlchemy 2.0.
334
741
 
335
742
  Args:
336
- **kwargs: Query components
743
+ query: SQL query with positional placeholders (?, $1, %s)
744
+ parameters: List of parameter values
337
745
 
338
746
  Returns:
339
- Dictionary containing:
340
- - query: Built SQL query with placeholders
341
- - parameters: List of parameter values
747
+ Tuple of (modified_query, parameter_dict)
342
748
  """
343
- table = kwargs["table"]
344
- select = kwargs.get("select", ["*"])
345
- where = kwargs.get("where", {})
346
- join = kwargs.get("join", [])
347
- order_by = kwargs.get("order_by", [])
348
- limit = kwargs.get("limit")
349
- offset = kwargs.get("offset")
350
-
351
- # Build SELECT clause
352
- select_clause = ", ".join(select)
353
- query_parts = [f"SELECT {select_clause}", f"FROM {table}"]
354
- parameters = []
355
-
356
- # Build JOIN clauses
357
- for join_spec in join:
358
- query_parts.append(f"JOIN {join_spec}")
359
-
360
- # Build WHERE clause
361
- if where:
362
- conditions = []
363
- for key, value in where.items():
364
- conditions.append(f"{key} = ?")
365
- parameters.append(value)
366
- query_parts.append(f"WHERE {' AND '.join(conditions)}")
367
-
368
- # Build ORDER BY clause
369
- if order_by:
370
- query_parts.append(f"ORDER BY {', '.join(order_by)}")
371
-
372
- # Build LIMIT/OFFSET
373
- if limit is not None:
374
- query_parts.append(f"LIMIT {limit}")
375
- if offset is not None:
376
- query_parts.append(f"OFFSET {offset}")
377
-
378
- query = " ".join(query_parts)
379
-
380
- return {"query": query, "parameters": parameters}
749
+ import re
750
+
751
+ # Create parameter dictionary
752
+ param_dict = {}
753
+ for i, value in enumerate(parameters):
754
+ param_dict[f"p{i}"] = value
755
+
756
+ # Replace different placeholder formats with named parameters
757
+ modified_query = query
758
+
759
+ # Handle SQLite-style ? placeholders
760
+ placeholder_count = 0
761
+
762
+ def replace_question_mark(match):
763
+ nonlocal placeholder_count
764
+ replacement = f":p{placeholder_count}"
765
+ placeholder_count += 1
766
+ return replacement
767
+
768
+ modified_query = re.sub(r"\?", replace_question_mark, modified_query)
769
+
770
+ # Handle PostgreSQL-style $1, $2, etc. placeholders
771
+ def replace_postgres_placeholder(match):
772
+ index = int(match.group(1)) - 1 # PostgreSQL uses 1-based indexing
773
+ return f":p{index}"
774
+
775
+ modified_query = re.sub(
776
+ r"\$(\d+)", replace_postgres_placeholder, modified_query
777
+ )
778
+
779
+ # Handle MySQL-style %s placeholders
780
+ placeholder_count = 0
781
+
782
+ def replace_mysql_placeholder(match):
783
+ nonlocal placeholder_count
784
+ replacement = f":p{placeholder_count}"
785
+ placeholder_count += 1
786
+ return replacement
787
+
788
+ modified_query = re.sub(r"%s", replace_mysql_placeholder, modified_query)
789
+
790
+ return modified_query, param_dict
791
+
792
+ def _format_results(
793
+ self, rows: List, columns: List[str], result_format: str
794
+ ) -> List[Any]:
795
+ """Format query results according to specified format.
796
+
797
+ Args:
798
+ rows: Raw database rows
799
+ columns: Column names
800
+ result_format: Desired output format
801
+
802
+ Returns:
803
+ Formatted results
804
+ """
805
+ if result_format == "dict":
806
+ # List of dictionaries with column names as keys
807
+ # SQLAlchemy rows can be converted to dict using _asdict() or dict()
808
+ return [dict(row._mapping) for row in rows]
809
+
810
+ elif result_format == "list":
811
+ # List of lists (raw rows)
812
+ return [list(row) for row in rows]
813
+
814
+ elif result_format == "raw":
815
+ # Raw SQLAlchemy row objects (converted to list for JSON serialization)
816
+ return [list(row) for row in rows]
817
+
818
+ else:
819
+ # Default to dict format
820
+ self.logger.warning(
821
+ f"Unknown result_format '{result_format}', defaulting to 'dict'"
822
+ )
823
+ return [dict(zip(columns, row)) for row in rows]