daita-agents 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1159 @@
+ """
+ Snowflake plugin for Daita Agents.
+
+ Provides Snowflake data warehouse connection and querying capabilities.
+ Supports key-pair authentication, warehouse management, and stage operations.
+ """
+ import logging
+ import os
+ from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+ from .base_db import BaseDatabasePlugin
+
+ if TYPE_CHECKING:
+     from ..core.tools import AgentTool
+
+ logger = logging.getLogger(__name__)
+
+
+ class SnowflakePlugin(BaseDatabasePlugin):
+     """
+     Snowflake plugin for agents with warehouse management and stage operations.
+
+     Inherits common database functionality from BaseDatabasePlugin and adds
+     Snowflake-specific features like warehouse switching and stage data loading.
+
+     Supports both password and key-pair authentication.
+
+     Example:
+         from daita.plugins import snowflake
+
+         # Password authentication
+         db = snowflake(
+             account="xy12345",
+             warehouse="COMPUTE_WH",
+             database="MYDB",
+             user="myuser",
+             password="mypass"
+         )
+
+         # Key-pair authentication
+         db = snowflake(
+             account="xy12345",
+             warehouse="COMPUTE_WH",
+             database="MYDB",
+             user="myuser",
+             private_key_path="/path/to/key.p8"
+         )
+
+         # Use with agent
+         agent = SubstrateAgent(
+             name="Data Analyst",
+             tools=[db]
+         )
+     """
+
+     def __init__(
+         self,
+         account: Optional[str] = None,
+         warehouse: Optional[str] = None,
+         database: Optional[str] = None,
+         schema: Optional[str] = None,
+         user: Optional[str] = None,
+         password: Optional[str] = None,
+         role: Optional[str] = None,
+         private_key_path: Optional[str] = None,
+         private_key_passphrase: Optional[str] = None,
+         timeout: int = 300,
+         **kwargs
+     ):
+         """
+         Initialize Snowflake connection.
+
+         Args:
+             account: Snowflake account identifier (e.g., "xy12345")
+             warehouse: Compute warehouse name
+             database: Database name
+             schema: Schema name (defaults to "PUBLIC" when neither the argument nor SNOWFLAKE_SCHEMA is set)
+             user: Username for authentication
+             password: Password for authentication (optional if using key-pair)
+             role: Role to use for the session
+             private_key_path: Path to private key file for key-pair auth
+             private_key_passphrase: Passphrase for encrypted private key
+             timeout: Query timeout in seconds (default: 300)
+             **kwargs: Additional configuration options
+
+         Environment variables:
+             SNOWFLAKE_ACCOUNT: Account identifier
+             SNOWFLAKE_WAREHOUSE: Warehouse name
+             SNOWFLAKE_DATABASE: Database name
+             SNOWFLAKE_SCHEMA: Schema name
+             SNOWFLAKE_USER: Username
+             SNOWFLAKE_PASSWORD: Password
+             SNOWFLAKE_ROLE: Role name
+             SNOWFLAKE_PRIVATE_KEY_PATH: Path to private key
+             SNOWFLAKE_PRIVATE_KEY_PASSPHRASE: Private key passphrase
+         """
+         # Load from environment variables with fallbacks
+         self.account = account if account is not None else os.getenv("SNOWFLAKE_ACCOUNT")
+         self.warehouse = warehouse if warehouse is not None else os.getenv("SNOWFLAKE_WAREHOUSE")
+         self.database_name = database if database is not None else os.getenv("SNOWFLAKE_DATABASE")
+         self.schema = schema if schema is not None else os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
+         self.user = user if user is not None else os.getenv("SNOWFLAKE_USER")
+         self.password = password if password is not None else os.getenv("SNOWFLAKE_PASSWORD")
+         self.role = role if role is not None else os.getenv("SNOWFLAKE_ROLE")
+         self.private_key_path = private_key_path if private_key_path is not None else os.getenv("SNOWFLAKE_PRIVATE_KEY_PATH")
+         self.private_key_passphrase = private_key_passphrase if private_key_passphrase is not None else os.getenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE")
+         self.timeout = timeout
+
+         # Validate required parameters
+         if not self.account:
+             raise ValueError("Snowflake account is required. Provide 'account' parameter or set SNOWFLAKE_ACCOUNT environment variable.")
+         if not self.user:
+             raise ValueError("Snowflake user is required. Provide 'user' parameter or set SNOWFLAKE_USER environment variable.")
+         if not self.warehouse:
+             raise ValueError("Snowflake warehouse is required. Provide 'warehouse' parameter or set SNOWFLAKE_WAREHOUSE environment variable.")
+         if not self.database_name:
+             raise ValueError("Snowflake database is required. Provide 'database' parameter or set SNOWFLAKE_DATABASE environment variable.")
+
+         # Validate authentication credentials
+         if not self.password and not self.private_key_path:
+             raise ValueError("Authentication required: provide 'password' or 'private_key_path', or set SNOWFLAKE_PASSWORD / SNOWFLAKE_PRIVATE_KEY_PATH.")
+
+         # Build connection configuration
+         self.connection_config = {
+             'account': self.account,
+             'warehouse': self.warehouse,
+             'database': self.database_name,
+             'schema': self.schema,
+             'user': self.user,
+             'network_timeout': timeout,
+             'login_timeout': 60,
+         }
+
+         # Add role if specified
+         if self.role:
+             self.connection_config['role'] = self.role
+
+         # Add authentication (handled in connect method)
+         self._use_key_pair = bool(self.private_key_path)
+
+         # Call parent constructor
+         super().__init__(
+             account=self.account,
+             warehouse=self.warehouse,
+             database=self.database_name,
+             schema=self.schema,
+             user=self.user,
+             role=self.role,
+             **kwargs
+         )
+
+         logger.debug(f"Snowflake plugin configured for {self.account}/{self.database_name} (auth: {'key-pair' if self._use_key_pair else 'password'})")
+
+     def _load_private_key(self):
+         """Load and decode private key for key-pair authentication."""
+         try:
+             from cryptography.hazmat.backends import default_backend
+             from cryptography.hazmat.primitives import serialization
+
+             with open(self.private_key_path, 'rb') as key_file:
+                 private_key_data = key_file.read()
+
+             # Load private key with optional passphrase
+             passphrase = self.private_key_passphrase.encode() if self.private_key_passphrase else None
+
+             private_key = serialization.load_pem_private_key(
+                 private_key_data,
+                 password=passphrase,
+                 backend=default_backend()
+             )
+
+             # Get private key bytes in DER format
+             private_key_bytes = private_key.private_bytes(
+                 encoding=serialization.Encoding.DER,
+                 format=serialization.PrivateFormat.PKCS8,
+                 encryption_algorithm=serialization.NoEncryption()
+             )
+
+             return private_key_bytes
+
+         except ImportError:
+             raise ImportError(
+                 "cryptography library is required for key-pair authentication. "
+                 "Install with: pip install cryptography"
+             )
+         except FileNotFoundError:
+             raise FileNotFoundError(f"Private key file not found: {self.private_key_path}")
+         except Exception as e:
+             raise ValueError(f"Failed to load private key: {str(e)}") from e
+
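+     # A brief sketch of provisioning a key pair for the key-pair auth path
+     # above (commands, paths, and the user name are illustrative, not part of
+     # this package; this mirrors Snowflake's documented PKCS#8 procedure):
+     #
+     #     openssl genrsa 2048 | openssl pkcs8 -topk8 -inform PEM -out rsa_key.p8 -nocrypt
+     #     openssl rsa -in rsa_key.p8 -pubout -out rsa_key.pub
+     #
+     # The public key (without the PEM header/footer) is then attached to the
+     # Snowflake user, after which private_key_path="rsa_key.p8" works here:
+     #
+     #     ALTER USER myuser SET RSA_PUBLIC_KEY='MIIBIjANBgkq...';
+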
+     async def connect(self):
+         """
+         Connect to Snowflake.
+
+         Establishes connection using either password or key-pair authentication.
+         Connection is idempotent and won't create duplicate connections.
+         """
+         if self._connection is not None:
+             return  # Already connected
+
+         try:
+             import snowflake.connector
+         except ImportError:
+             # Report the missing driver separately so an ImportError raised
+             # later by _load_private_key (missing 'cryptography') is not
+             # mislabeled as a missing connector.
+             self._handle_connection_error(
+                 ImportError(
+                     "snowflake-connector-python not installed. "
+                     "Install with: pip install snowflake-connector-python"
+                 ),
+                 "connection"
+             )
+             return
+
+         try:
+             # Add authentication credentials
+             config = self.connection_config.copy()
+
+             if self._use_key_pair:
+                 # Use key-pair authentication
+                 private_key_bytes = self._load_private_key()
+                 config['private_key'] = private_key_bytes
+                 logger.debug("Using key-pair authentication")
+             else:
+                 # Use password authentication
+                 config['password'] = self.password
+                 logger.debug("Using password authentication")
+
+             # Create connection
+             self._connection = snowflake.connector.connect(**config)
+
+             logger.info(f"Connected to Snowflake: {self.account}/{self.database_name}.{self.schema} (warehouse: {self.warehouse})")
+
+         except Exception as e:
+             self._handle_connection_error(e, "connection")
+
+     async def disconnect(self):
+         """
+         Disconnect from Snowflake.
+
+         Closes the connection and releases resources.
+         """
+         if self._connection:
+             try:
+                 self._connection.close()
+                 self._connection = None
+                 logger.info("Disconnected from Snowflake")
+             except Exception as e:
+                 logger.warning(f"Error during disconnect: {str(e)}")
+                 self._connection = None
+
+     async def query(self, sql: str, params: Optional[Union[List, Dict]] = None) -> List[Dict[str, Any]]:
+         """
+         Run a SELECT query and return results.
+
+         Args:
+             sql: SQL query with %s or %(name)s placeholders
+             params: List or dict of parameters for the query
+
+         Returns:
+             List of rows as dictionaries
+
+         Example:
+             results = await db.query("SELECT * FROM users WHERE age > %s", [25])
+             results = await db.query("SELECT * FROM users WHERE age > %(min_age)s", {"min_age": 25})
+         """
+         if self._connection is None:
+             await self.connect()
+
+         cursor = self._connection.cursor()
+
+         try:
+             if params:
+                 cursor.execute(sql, params)
+             else:
+                 cursor.execute(sql)
+
+             rows = cursor.fetchall()
+
+             # Handle empty result set
+             if not rows:
+                 return []
+
+             # Convert to list of dictionaries
+             columns = [desc[0] for desc in cursor.description]
+             return [dict(zip(columns, row)) for row in rows]
+
+         finally:
+             cursor.close()
+
+     async def execute(self, sql: str, params: Optional[Union[List, Dict]] = None) -> int:
+         """
+         Execute INSERT/UPDATE/DELETE and return affected rows.
+
+         Args:
+             sql: SQL statement with %s or %(name)s placeholders
+             params: List or dict of parameters for the statement
+
+         Returns:
+             Number of affected rows
+
+         Example:
+             affected = await db.execute("INSERT INTO users (name, age) VALUES (%s, %s)", ["Alice", 30])
+         """
+         if self._connection is None:
+             await self.connect()
+
+         cursor = self._connection.cursor()
+
+         try:
+             if params:
+                 cursor.execute(sql, params)
+             else:
+                 cursor.execute(sql)
+
+             rowcount = cursor.rowcount
+             self._connection.commit()
+             return rowcount
+
+         finally:
+             cursor.close()
+
317
+
318
+ async def tables(self, schema: Optional[str] = None) -> List[str]:
319
+ """
320
+ List all tables in the database or specific schema.
321
+
322
+ Args:
323
+ schema: Schema name (defaults to current schema)
324
+
325
+ Returns:
326
+ List of table names
327
+ """
328
+ if self._connection is None:
329
+ await self.connect()
330
+
331
+ cursor = self._connection.cursor()
332
+
333
+ try:
334
+ if schema:
335
+ cursor.execute(f"SHOW TABLES IN SCHEMA {schema}")
336
+ else:
337
+ cursor.execute("SHOW TABLES")
338
+
339
+ rows = cursor.fetchall()
340
+ # SHOW TABLES returns multiple columns, table name is in 'name' column (index varies)
341
+ # Fetch as dictionaries to be safe
342
+ columns = [desc[0] for desc in cursor.description]
343
+ name_idx = columns.index('name')
344
+ return [row[name_idx] for row in rows]
345
+
346
+ finally:
347
+ cursor.close()
348
+
349
+ async def schemas(self) -> List[str]:
350
+ """
351
+ List all schemas in the database.
352
+
353
+ Returns:
354
+ List of schema names
355
+ """
356
+ if self._connection is None:
357
+ await self.connect()
358
+
359
+ cursor = self._connection.cursor()
360
+
361
+ try:
362
+ cursor.execute("SHOW SCHEMAS")
363
+ rows = cursor.fetchall()
364
+ columns = [desc[0] for desc in cursor.description]
365
+ name_idx = columns.index('name')
366
+ return [row[name_idx] for row in rows]
367
+
368
+ finally:
369
+ cursor.close()
370
+
371
+ async def describe(self, table: str) -> List[Dict[str, Any]]:
372
+ """
373
+ Get table column information.
374
+
375
+ Args:
376
+ table: Table name (optionally schema-qualified: schema.table)
377
+
378
+ Returns:
379
+ List of column details with name, type, nullable, default, etc.
380
+
381
+ Example:
382
+ columns = await db.describe("users")
383
+ for col in columns:
384
+ print(f"{col['name']}: {col['type']}")
385
+ """
386
+ if self._connection is None:
387
+ await self.connect()
388
+
389
+ cursor = self._connection.cursor()
390
+
391
+ try:
392
+ cursor.execute(f"DESCRIBE TABLE {table}")
393
+ rows = cursor.fetchall()
394
+ columns = [desc[0] for desc in cursor.description]
395
+ return [dict(zip(columns, row)) for row in rows]
396
+
397
+ finally:
398
+ cursor.close()
399
+
400
+ async def databases(self) -> List[str]:
401
+ """
402
+ List all accessible databases.
403
+
404
+ Returns:
405
+ List of database names
406
+ """
407
+ if self._connection is None:
408
+ await self.connect()
409
+
410
+ cursor = self._connection.cursor()
411
+
412
+ try:
413
+ cursor.execute("SHOW DATABASES")
414
+ rows = cursor.fetchall()
415
+ columns = [desc[0] for desc in cursor.description]
416
+ name_idx = columns.index('name')
417
+ return [row[name_idx] for row in rows]
418
+
419
+ finally:
420
+ cursor.close()
421
+
422
+ async def list_warehouses(self) -> List[Dict[str, Any]]:
423
+ """
424
+ List all available warehouses.
425
+
426
+ Returns:
427
+ List of warehouse details with name, state, size, etc.
428
+ """
429
+ if self._connection is None:
430
+ await self.connect()
431
+
432
+ cursor = self._connection.cursor()
433
+
434
+ try:
435
+ cursor.execute("SHOW WAREHOUSES")
436
+ rows = cursor.fetchall()
437
+ columns = [desc[0] for desc in cursor.description]
438
+ return [dict(zip(columns, row)) for row in rows]
439
+
440
+ finally:
441
+ cursor.close()
442
+
443
+ async def switch_warehouse(self, warehouse: str) -> None:
444
+ """
445
+ Switch to a different warehouse.
446
+
447
+ Args:
448
+ warehouse: Warehouse name to switch to
449
+ """
450
+ if self._connection is None:
451
+ await self.connect()
452
+
453
+ cursor = self._connection.cursor()
454
+
455
+ try:
456
+ cursor.execute(f"USE WAREHOUSE {warehouse}")
457
+ self.warehouse = warehouse
458
+ logger.info(f"Switched to warehouse: {warehouse}")
459
+
460
+ finally:
461
+ cursor.close()
462
+
463
+ async def get_current_warehouse(self) -> Dict[str, Any]:
464
+ """
465
+ Get current warehouse information.
466
+
467
+ Returns:
468
+ Dictionary with current warehouse details
469
+ """
470
+ result = await self.query("SELECT CURRENT_WAREHOUSE() as warehouse")
471
+ return result[0] if result else {"warehouse": None}
472
+
473
+ async def query_history(self, limit: int = 100) -> List[Dict[str, Any]]:
474
+ """
475
+ Get recent query history.
476
+
477
+ Args:
478
+ limit: Maximum number of queries to return (default: 100)
479
+
480
+ Returns:
481
+ List of query history records
482
+ """
483
+ sql = f"""
484
+ SELECT
485
+ query_id,
486
+ query_text,
487
+ database_name,
488
+ schema_name,
489
+ query_type,
490
+ warehouse_name,
491
+ user_name,
492
+ role_name,
493
+ execution_status,
494
+ error_message,
495
+ start_time,
496
+ end_time,
497
+ total_elapsed_time,
498
+ bytes_scanned,
499
+ rows_produced
500
+ FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY())
501
+ ORDER BY start_time DESC
502
+ LIMIT {limit}
503
+ """
504
+ return await self.query(sql)
505
+
506
+ async def get_warehouse_usage(self, warehouse: Optional[str] = None, days: int = 7) -> List[Dict[str, Any]]:
507
+ """
508
+ Get warehouse credit usage.
509
+
510
+ Args:
511
+ warehouse: Warehouse name (defaults to current warehouse)
512
+ days: Number of days to look back (default: 7)
513
+
514
+ Returns:
515
+ List of usage records
516
+ """
517
+ warehouse_filter = f"AND warehouse_name = '{warehouse}'" if warehouse else ""
518
+
519
+ sql = f"""
520
+ SELECT
521
+ warehouse_name,
522
+ DATE(start_time) as usage_date,
523
+ SUM(credits_used) as total_credits
524
+ FROM TABLE(INFORMATION_SCHEMA.WAREHOUSE_METERING_HISTORY(
525
+ DATE_RANGE_START => DATEADD(day, -{days}, CURRENT_DATE())
526
+ ))
527
+ WHERE 1=1 {warehouse_filter}
528
+ GROUP BY warehouse_name, usage_date
529
+ ORDER BY usage_date DESC, warehouse_name
530
+ """
531
+ return await self.query(sql)
532
+
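+     # For example, per-day credits for one warehouse over the past week
+     # (warehouse name is illustrative; result keys come back uppercased):
+     #
+     #     usage = await db.get_warehouse_usage(warehouse="COMPUTE_WH", days=7)
+     #     # -> [{"WAREHOUSE_NAME": "COMPUTE_WH", "USAGE_DATE": ..., "TOTAL_CREDITS": ...}, ...]
+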
+     async def list_stages(self) -> List[Dict[str, Any]]:
+         """
+         List all stages (internal and external).
+
+         Returns:
+             List of stage details
+         """
+         if self._connection is None:
+             await self.connect()
+
+         cursor = self._connection.cursor()
+
+         try:
+             cursor.execute("SHOW STAGES")
+             rows = cursor.fetchall()
+             columns = [desc[0] for desc in cursor.description]
+             return [dict(zip(columns, row)) for row in rows]
+
+         finally:
+             cursor.close()
+
+     async def put_file(self, local_path: str, stage_path: str, overwrite: bool = False) -> Dict[str, Any]:
+         """
+         Upload file to Snowflake stage.
+
+         Args:
+             local_path: Local file path
+             stage_path: Stage location (e.g., "@my_stage/path/")
+             overwrite: Whether to overwrite existing files
+
+         Returns:
+             Dictionary with upload results
+         """
+         if self._connection is None:
+             await self.connect()
+
+         cursor = self._connection.cursor()
+
+         try:
+             overwrite_str = "OVERWRITE = TRUE" if overwrite else ""
+             sql = f"PUT 'file://{local_path}' {stage_path} {overwrite_str}"
+             cursor.execute(sql)
+
+             rows = cursor.fetchall()
+             columns = [desc[0] for desc in cursor.description]
+             results = [dict(zip(columns, row)) for row in rows]
+
+             return {
+                 "success": True,
+                 "files_uploaded": len(results),
+                 "details": results
+             }
+
+         finally:
+             cursor.close()
+
+     async def get_file(self, stage_path: str, local_path: str) -> Dict[str, Any]:
+         """
+         Download file from Snowflake stage.
+
+         Args:
+             stage_path: Stage file location (e.g., "@my_stage/path/file.csv")
+             local_path: Local directory to download to
+
+         Returns:
+             Dictionary with download results
+         """
+         if self._connection is None:
+             await self.connect()
+
+         cursor = self._connection.cursor()
+
+         try:
+             sql = f"GET {stage_path} 'file://{local_path}'"
+             cursor.execute(sql)
+
+             rows = cursor.fetchall()
+             columns = [desc[0] for desc in cursor.description]
+             results = [dict(zip(columns, row)) for row in rows]
+
+             return {
+                 "success": True,
+                 "files_downloaded": len(results),
+                 "details": results
+             }
+
+         finally:
+             cursor.close()
+
+     async def load_from_stage(
+         self,
+         table: str,
+         stage: str,
+         file_format: str = "CSV",
+         pattern: Optional[str] = None,
+         on_error: str = "ABORT_STATEMENT"
+     ) -> Dict[str, Any]:
+         """
+         Load data from stage into table.
+
+         Args:
+             table: Target table name
+             stage: Stage location (e.g., "@my_stage/path/")
+             file_format: File format name or type (default: "CSV")
+             pattern: File pattern to match (e.g., r".*\.csv")
+             on_error: Error handling: ABORT_STATEMENT, CONTINUE, SKIP_FILE
+
+         Returns:
+             Dictionary with load results
+         """
+         if self._connection is None:
+             await self.connect()
+
+         cursor = self._connection.cursor()
+
+         try:
+             pattern_clause = f"PATTERN = '{pattern}'" if pattern else ""
+
+             sql = f"""
+                 COPY INTO {table}
+                 FROM {stage}
+                 FILE_FORMAT = (TYPE = {file_format})
+                 {pattern_clause}
+                 ON_ERROR = {on_error}
+             """
+
+             cursor.execute(sql)
+
+             rows = cursor.fetchall()
+             columns = [desc[0] for desc in cursor.description]
+             results = [dict(zip(columns, row)) for row in rows]
+
+             # Calculate summary
+             rows_loaded = sum(row.get('rows_loaded', 0) or 0 for row in results)
+
+             return {
+                 "success": True,
+                 "rows_loaded": rows_loaded,
+                 "files_processed": len(results),
+                 "details": results
+             }
+
+         finally:
+             cursor.close()
+
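+     # A hedged end-to-end sketch of the stage workflow built from the methods
+     # above and below (stage, file, and table names are illustrative):
+     #
+     #     await db.create_stage("my_stage")
+     #     await db.put_file("/tmp/users.csv", "@my_stage/", overwrite=True)
+     #     result = await db.load_from_stage("users", "@my_stage/", file_format="CSV")
+     #     logger.info("loaded %s rows", result["rows_loaded"])
+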
+     async def create_stage(
+         self,
+         name: str,
+         url: Optional[str] = None,
+         storage_integration: Optional[str] = None,
+         credentials: Optional[Dict[str, str]] = None
+     ) -> None:
+         """
+         Create a new stage (internal or external).
+
+         Args:
+             name: Stage name
+             url: External URL (e.g., "s3://bucket/path/") for external stages
+             storage_integration: Storage integration name for cloud storage
+             credentials: Cloud credentials (e.g., {"AWS_KEY_ID": "...", "AWS_SECRET_KEY": "..."})
+
+         Example:
+             # Internal stage
+             await db.create_stage("my_internal_stage")
+
+             # External S3 stage
+             await db.create_stage(
+                 "my_s3_stage",
+                 url="s3://my-bucket/data/",
+                 credentials={"AWS_KEY_ID": "...", "AWS_SECRET_KEY": "..."}
+             )
+         """
+         if self._connection is None:
+             await self.connect()
+
+         cursor = self._connection.cursor()
+
+         try:
+             if url:
+                 # External stage
+                 sql = f"CREATE STAGE IF NOT EXISTS {name} URL = '{url}'"
+
+                 if storage_integration:
+                     sql += f" STORAGE_INTEGRATION = {storage_integration}"
+                 elif credentials:
+                     creds_str = " ".join([f"{k} = '{v}'" for k, v in credentials.items()])
+                     sql += f" CREDENTIALS = ({creds_str})"
+             else:
+                 # Internal stage
+                 sql = f"CREATE STAGE IF NOT EXISTS {name}"
+
+             cursor.execute(sql)
+             logger.info(f"Created stage: {name}")
+
+         finally:
+             cursor.close()
+
729
+
730
+ def get_tools(self) -> List['AgentTool']:
731
+ """
732
+ Expose Snowflake operations as agent tools.
733
+
734
+ Returns:
735
+ List of AgentTool instances for Snowflake operations
736
+ """
737
+ from ..core.tools import AgentTool
738
+
739
+ return [
740
+ AgentTool(
741
+ name="query_database",
742
+ description="Execute a SQL SELECT query on Snowflake and return results as a list of dictionaries",
743
+ parameters={
744
+ "type": "object",
745
+ "properties": {
746
+ "sql": {
747
+ "type": "string",
748
+ "description": "SQL SELECT query with %s placeholders for parameters"
749
+ },
750
+ "params": {
751
+ "type": "array",
752
+ "description": "Optional list of parameter values for query placeholders",
753
+ "items": {"type": "string"}
754
+ }
755
+ },
756
+ "required": ["sql"]
757
+ },
758
+ handler=self._tool_query,
759
+ category="database",
760
+ source="plugin",
761
+ plugin_name="Snowflake",
762
+ timeout_seconds=60
763
+ ),
764
+ AgentTool(
765
+ name="execute_sql",
766
+ description="Execute a SQL INSERT, UPDATE, or DELETE statement on Snowflake and return the number of affected rows",
767
+ parameters={
768
+ "type": "object",
769
+ "properties": {
770
+ "sql": {
771
+ "type": "string",
772
+ "description": "SQL statement with %s placeholders for parameters"
773
+ },
774
+ "params": {
775
+ "type": "array",
776
+ "description": "Optional list of parameter values for statement placeholders",
777
+ "items": {"type": "string"}
778
+ }
779
+ },
780
+ "required": ["sql"]
781
+ },
782
+ handler=self._tool_execute,
783
+ category="database",
784
+ source="plugin",
785
+ plugin_name="Snowflake",
786
+ timeout_seconds=60
787
+ ),
788
+ AgentTool(
789
+ name="list_tables",
790
+ description="List all tables in the Snowflake database or a specific schema",
791
+ parameters={
792
+ "type": "object",
793
+ "properties": {
794
+ "schema": {
795
+ "type": "string",
796
+ "description": "Optional schema name to list tables from (defaults to current schema)"
797
+ }
798
+ },
799
+ "required": []
800
+ },
801
+ handler=self._tool_list_tables,
802
+ category="database",
803
+ source="plugin",
804
+ plugin_name="Snowflake",
805
+ timeout_seconds=30
806
+ ),
807
+ AgentTool(
808
+ name="list_schemas",
809
+ description="List all schemas in the Snowflake database",
810
+ parameters={
811
+ "type": "object",
812
+ "properties": {},
813
+ "required": []
814
+ },
815
+ handler=self._tool_list_schemas,
816
+ category="database",
817
+ source="plugin",
818
+ plugin_name="Snowflake",
819
+ timeout_seconds=30
820
+ ),
821
+ AgentTool(
822
+ name="get_table_schema",
823
+ description="Get detailed column information for a Snowflake table including column names, types, and constraints",
824
+ parameters={
825
+ "type": "object",
826
+ "properties": {
827
+ "table": {
828
+ "type": "string",
829
+ "description": "Table name (optionally schema-qualified like 'schema.table')"
830
+ }
831
+ },
832
+ "required": ["table"]
833
+ },
834
+ handler=self._tool_get_table_schema,
835
+ category="database",
836
+ source="plugin",
837
+ plugin_name="Snowflake",
838
+ timeout_seconds=30
839
+ ),
840
+ AgentTool(
841
+ name="list_warehouses",
842
+ description="List all available Snowflake compute warehouses with their status and configuration",
843
+ parameters={
844
+ "type": "object",
845
+ "properties": {},
846
+ "required": []
847
+ },
848
+ handler=self._tool_list_warehouses,
849
+ category="database",
850
+ source="plugin",
851
+ plugin_name="Snowflake",
852
+ timeout_seconds=30
853
+ ),
854
+ AgentTool(
855
+ name="switch_warehouse",
856
+ description="Switch to a different Snowflake compute warehouse for subsequent queries",
857
+ parameters={
858
+ "type": "object",
859
+ "properties": {
860
+ "warehouse": {
861
+ "type": "string",
862
+ "description": "Name of the warehouse to switch to"
863
+ }
864
+ },
865
+ "required": ["warehouse"]
866
+ },
867
+ handler=self._tool_switch_warehouse,
868
+ category="database",
869
+ source="plugin",
870
+ plugin_name="Snowflake",
871
+ timeout_seconds=30
872
+ ),
873
+ AgentTool(
874
+ name="get_query_history",
875
+ description="Get recent query history from Snowflake including execution status, timing, and resource usage",
876
+ parameters={
877
+ "type": "object",
878
+ "properties": {
879
+ "limit": {
880
+ "type": "integer",
881
+ "description": "Maximum number of queries to return (default: 100)",
882
+ "default": 100
883
+ }
884
+ },
885
+ "required": []
886
+ },
887
+ handler=self._tool_get_query_history,
888
+ category="database",
889
+ source="plugin",
890
+ plugin_name="Snowflake",
891
+ timeout_seconds=45
892
+ ),
893
+ AgentTool(
894
+ name="list_stages",
895
+ description="List all Snowflake stages (internal and external) for data loading",
896
+ parameters={
897
+ "type": "object",
898
+ "properties": {},
899
+ "required": []
900
+ },
901
+ handler=self._tool_list_stages,
902
+ category="database",
903
+ source="plugin",
904
+ plugin_name="Snowflake",
905
+ timeout_seconds=30
906
+ ),
907
+ AgentTool(
908
+ name="upload_to_stage",
909
+ description="Upload a local file to a Snowflake stage for data loading",
910
+ parameters={
911
+ "type": "object",
912
+ "properties": {
913
+ "local_path": {
914
+ "type": "string",
915
+ "description": "Local file path to upload"
916
+ },
917
+ "stage_path": {
918
+ "type": "string",
919
+ "description": "Stage location (e.g., '@my_stage/path/')"
920
+ },
921
+ "overwrite": {
922
+ "type": "boolean",
923
+ "description": "Whether to overwrite existing files (default: false)",
924
+ "default": False
925
+ }
926
+ },
927
+ "required": ["local_path", "stage_path"]
928
+ },
929
+ handler=self._tool_upload_to_stage,
930
+ category="database",
931
+ source="plugin",
932
+ plugin_name="Snowflake",
933
+ timeout_seconds=120
934
+ ),
935
+ AgentTool(
936
+ name="load_from_stage",
937
+ description="Load data from a Snowflake stage into a table using COPY INTO command",
938
+ parameters={
939
+ "type": "object",
940
+ "properties": {
941
+ "table": {
942
+ "type": "string",
943
+ "description": "Target table name"
944
+ },
945
+ "stage": {
946
+ "type": "string",
947
+ "description": "Stage location (e.g., '@my_stage/path/')"
948
+ },
949
+ "file_format": {
950
+ "type": "string",
951
+ "description": "File format type (default: CSV)",
952
+ "default": "CSV"
953
+ },
954
+ "pattern": {
955
+ "type": "string",
956
+ "description": "Optional file pattern to match (e.g., '.*\\.csv')"
957
+ }
958
+ },
959
+ "required": ["table", "stage"]
960
+ },
961
+ handler=self._tool_load_from_stage,
962
+ category="database",
963
+ source="plugin",
964
+ plugin_name="Snowflake",
965
+ timeout_seconds=180
966
+ ),
967
+ AgentTool(
968
+ name="create_stage",
969
+ description="Create a new Snowflake stage (internal or external) for data loading",
970
+ parameters={
971
+ "type": "object",
972
+ "properties": {
973
+ "name": {
974
+ "type": "string",
975
+ "description": "Stage name"
976
+ },
977
+ "url": {
978
+ "type": "string",
979
+ "description": "External URL for external stages (e.g., 's3://bucket/path/')"
980
+ },
981
+ "storage_integration": {
982
+ "type": "string",
983
+ "description": "Storage integration name for cloud storage"
984
+ }
985
+ },
986
+ "required": ["name"]
987
+ },
988
+ handler=self._tool_create_stage,
989
+ category="database",
990
+ source="plugin",
991
+ plugin_name="Snowflake",
992
+ timeout_seconds=30
993
+ ),
994
+ ]
995
+
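+     # A minimal sketch of exercising one of these tools outside an agent
+     # (the agent runtime normally dispatches them; the SQL is illustrative):
+     #
+     #     out = await db._tool_query({"sql": "SELECT 1 AS one"})
+     #     # -> {"success": True, "rows": [...], "row_count": 1}
+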
+     # Tool handler methods
+
+     async def _tool_query(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for query_database"""
+         sql = args.get("sql")
+         params = args.get("params")
+
+         results = await self.query(sql, params)
+
+         return {
+             "success": True,
+             "rows": results,
+             "row_count": len(results)
+         }
+
+     async def _tool_execute(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for execute_sql"""
+         sql = args.get("sql")
+         params = args.get("params")
+
+         affected_rows = await self.execute(sql, params)
+
+         return {
+             "success": True,
+             "affected_rows": affected_rows
+         }
+
+     async def _tool_list_tables(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for list_tables"""
+         schema = args.get("schema")
+
+         tables = await self.tables(schema)
+
+         return {
+             "success": True,
+             "tables": tables,
+             "count": len(tables)
+         }
+
+     async def _tool_list_schemas(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for list_schemas"""
+         schemas = await self.schemas()
+
+         return {
+             "success": True,
+             "schemas": schemas,
+             "count": len(schemas)
+         }
+
+     async def _tool_get_table_schema(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for get_table_schema"""
+         table = args.get("table")
+
+         columns = await self.describe(table)
+
+         return {
+             "success": True,
+             "table": table,
+             "columns": columns,
+             "column_count": len(columns)
+         }
+
+     async def _tool_list_warehouses(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for list_warehouses"""
+         warehouses = await self.list_warehouses()
+
+         return {
+             "success": True,
+             "warehouses": warehouses,
+             "count": len(warehouses)
+         }
+
+     async def _tool_switch_warehouse(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for switch_warehouse"""
+         warehouse = args.get("warehouse")
+
+         await self.switch_warehouse(warehouse)
+
+         return {
+             "success": True,
+             "message": f"Switched to warehouse: {warehouse}",
+             "warehouse": warehouse
+         }
+
+     async def _tool_get_query_history(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for get_query_history"""
+         limit = args.get("limit", 100)
+
+         history = await self.query_history(limit)
+
+         return {
+             "success": True,
+             "queries": history,
+             "count": len(history)
+         }
+
+     async def _tool_list_stages(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for list_stages"""
+         stages = await self.list_stages()
+
+         return {
+             "success": True,
+             "stages": stages,
+             "count": len(stages)
+         }
+
+     async def _tool_upload_to_stage(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for upload_to_stage"""
+         local_path = args.get("local_path")
+         stage_path = args.get("stage_path")
+         overwrite = args.get("overwrite", False)
+
+         result = await self.put_file(local_path, stage_path, overwrite)
+
+         return result
+
+     async def _tool_load_from_stage(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for load_from_stage"""
+         table = args.get("table")
+         stage = args.get("stage")
+         file_format = args.get("file_format", "CSV")
+         pattern = args.get("pattern")
+
+         result = await self.load_from_stage(table, stage, file_format, pattern)
+
+         return result
+
+     async def _tool_create_stage(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """Tool handler for create_stage"""
+         name = args.get("name")
+         url = args.get("url")
+         storage_integration = args.get("storage_integration")
+
+         await self.create_stage(name, url, storage_integration)
+
+         return {
+             "success": True,
+             "message": f"Created stage: {name}",
+             "stage": name
+         }
+
+
+ def snowflake(**kwargs) -> SnowflakePlugin:
+     """
+     Create a Snowflake plugin with a simplified interface.
+
+     Args:
+         **kwargs: Connection parameters (account, warehouse, database, etc.)
+
+     Returns:
+         SnowflakePlugin instance
+
+     Example:
+         from daita.plugins import snowflake
+
+         db = snowflake(
+             account="xy12345",
+             warehouse="COMPUTE_WH",
+             database="MYDB",
+             user="myuser",
+             password="mypass"
+         )
+     """
+     return SnowflakePlugin(**kwargs)
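+
+
+ # All public methods are coroutines, so standalone use needs an event loop.
+ # A minimal sketch (assumes the SNOWFLAKE_* environment variables documented
+ # in __init__ are set; the query is illustrative):
+ #
+ #     import asyncio
+ #
+ #     async def main():
+ #         db = snowflake()
+ #         try:
+ #             rows = await db.query("SELECT CURRENT_VERSION() AS v")
+ #             print(rows[0])
+ #         finally:
+ #             await db.disconnect()
+ #
+ #     asyncio.run(main())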