daita-agents 0.2.0__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- daita/cli/__init__.py +2 -3
- daita/cli/core/deployments.py +43 -62
- daita/cli/core/logs.py +2 -2
- daita/cli/core/managed_deploy.py +1 -12
- daita/cli/core/run.py +133 -68
- daita/cli/core/status.py +54 -10
- daita/cli/main.py +16 -38
- daita/cli/utils.py +1 -1
- daita/plugins/__init__.py +14 -2
- daita/plugins/snowflake.py +1159 -0
- {daita_agents-0.2.0.dist-info → daita_agents-0.2.3.dist-info}/METADATA +5 -1
- {daita_agents-0.2.0.dist-info → daita_agents-0.2.3.dist-info}/RECORD +16 -15
- {daita_agents-0.2.0.dist-info → daita_agents-0.2.3.dist-info}/WHEEL +0 -0
- {daita_agents-0.2.0.dist-info → daita_agents-0.2.3.dist-info}/entry_points.txt +0 -0
- {daita_agents-0.2.0.dist-info → daita_agents-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {daita_agents-0.2.0.dist-info → daita_agents-0.2.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1159 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Snowflake plugin for Daita Agents.
|
|
3
|
+
|
|
4
|
+
Provides Snowflake data warehouse connection and querying capabilities.
|
|
5
|
+
Supports key-pair authentication, warehouse management, and stage operations.
|
|
6
|
+
"""
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
|
10
|
+
from .base_db import BaseDatabasePlugin
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from ..core.tools import AgentTool
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SnowflakePlugin(BaseDatabasePlugin):
|
|
19
|
+
"""
|
|
20
|
+
Snowflake plugin for agents with warehouse management and stage operations.
|
|
21
|
+
|
|
22
|
+
Inherits common database functionality from BaseDatabasePlugin and adds
|
|
23
|
+
Snowflake-specific features like warehouse switching and stage data loading.
|
|
24
|
+
|
|
25
|
+
Supports both password and key-pair authentication.
|
|
26
|
+
|
|
27
|
+
Example:
|
|
28
|
+
from daita.plugins import snowflake
|
|
29
|
+
|
|
30
|
+
# Password authentication
|
|
31
|
+
db = snowflake(
|
|
32
|
+
account="xy12345",
|
|
33
|
+
warehouse="COMPUTE_WH",
|
|
34
|
+
database="MYDB",
|
|
35
|
+
user="myuser",
|
|
36
|
+
password="mypass"
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# Key-pair authentication
|
|
40
|
+
db = snowflake(
|
|
41
|
+
account="xy12345",
|
|
42
|
+
warehouse="COMPUTE_WH",
|
|
43
|
+
database="MYDB",
|
|
44
|
+
user="myuser",
|
|
45
|
+
private_key_path="/path/to/key.p8"
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
# Use with agent
|
|
49
|
+
agent = SubstrateAgent(
|
|
50
|
+
name="Data Analyst",
|
|
51
|
+
tools=[db]
|
|
52
|
+
)
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
def __init__(
    self,
    account: Optional[str] = None,
    warehouse: Optional[str] = None,
    database: Optional[str] = None,
    schema: str = "PUBLIC",
    user: Optional[str] = None,
    password: Optional[str] = None,
    role: Optional[str] = None,
    private_key_path: Optional[str] = None,
    private_key_passphrase: Optional[str] = None,
    timeout: int = 300,
    **kwargs
):
    """
    Configure the Snowflake plugin (no network connection is opened here).

    Every setting except ``timeout`` can come from an explicit argument or,
    when the argument is left at its default, from an environment variable.

    Args:
        account: Snowflake account identifier (e.g., "xy12345")
        warehouse: Compute warehouse name
        database: Database name
        schema: Schema name (default: "PUBLIC"). An explicit "PUBLIC" is
            indistinguishable from the default, so SNOWFLAKE_SCHEMA still
            overrides it when set.
        user: Username for authentication
        password: Password for authentication (optional if using key-pair)
        role: Role to use for the session
        private_key_path: Path to private key file for key-pair auth
        private_key_passphrase: Passphrase for encrypted private key
        timeout: Value stored on the instance and passed as the
            connection's ``network_timeout`` (default: 300)
        **kwargs: Forwarded unchanged to the parent constructor

    Environment variables:
        SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE, SNOWFLAKE_DATABASE,
        SNOWFLAKE_SCHEMA, SNOWFLAKE_USER, SNOWFLAKE_PASSWORD,
        SNOWFLAKE_ROLE, SNOWFLAKE_PRIVATE_KEY_PATH,
        SNOWFLAKE_PRIVATE_KEY_PASSPHRASE

    Raises:
        ValueError: If account, user, warehouse or database is missing, or
            if neither a password nor a private key path is available.
    """
    def _resolve(explicit, env_name):
        # An explicit argument always wins; the environment is the fallback.
        if explicit is None:
            return os.getenv(env_name)
        return explicit

    self.account = _resolve(account, "SNOWFLAKE_ACCOUNT")
    self.warehouse = _resolve(warehouse, "SNOWFLAKE_WAREHOUSE")
    self.database_name = _resolve(database, "SNOWFLAKE_DATABASE")
    if schema == "PUBLIC":
        self.schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
    else:
        self.schema = schema
    self.user = _resolve(user, "SNOWFLAKE_USER")
    self.password = _resolve(password, "SNOWFLAKE_PASSWORD")
    self.role = _resolve(role, "SNOWFLAKE_ROLE")
    self.private_key_path = _resolve(private_key_path, "SNOWFLAKE_PRIVATE_KEY_PATH")
    self.private_key_passphrase = _resolve(private_key_passphrase, "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE")
    self.timeout = timeout

    # Required settings, checked in a fixed order so the first missing one
    # is the one reported.
    required_settings = (
        (self.account, "Snowflake account is required. Provide 'account' parameter or set SNOWFLAKE_ACCOUNT environment variable."),
        (self.user, "Snowflake user is required. Provide 'user' parameter or set SNOWFLAKE_USER environment variable."),
        (self.warehouse, "Snowflake warehouse is required. Provide 'warehouse' parameter or set SNOWFLAKE_WAREHOUSE environment variable."),
        (self.database_name, "Snowflake database is required. Provide 'database' parameter or set SNOWFLAKE_DATABASE environment variable."),
    )
    for setting, complaint in required_settings:
        if not setting:
            raise ValueError(complaint)

    # At least one credential source must be present.
    if not (self.password or self.private_key_path):
        raise ValueError("Authentication required: provide either 'password' or 'private_key_path' parameter.")

    # Base connection settings; the credential itself is added at connect time.
    base_config = {
        'account': self.account,
        'warehouse': self.warehouse,
        'database': self.database_name,
        'schema': self.schema,
        'user': self.user,
        'network_timeout': timeout,
        'login_timeout': 60,
    }
    if self.role:
        base_config['role'] = self.role
    self.connection_config = base_config

    # A configured key path selects key-pair auth, even if a password exists.
    self._use_key_pair = bool(self.private_key_path)

    super().__init__(
        account=self.account,
        warehouse=self.warehouse,
        database=self.database_name,
        schema=self.schema,
        user=self.user,
        role=self.role,
        **kwargs
    )

    auth_mode = 'key-pair' if self._use_key_pair else 'password'
    logger.debug(f"Snowflake plugin configured for {self.account}/{self.database_name} (auth: {auth_mode})")
|
|
152
|
+
|
|
153
|
+
def _load_private_key(self):
    """
    Load the PEM private key used for key-pair authentication.

    Reads the file at ``self.private_key_path``, decrypts it with
    ``self.private_key_passphrase`` when one is set, and re-serializes the
    key as unencrypted PKCS#8 DER bytes.

    Returns:
        The private key as DER-encoded bytes

    Raises:
        ImportError: If the ``cryptography`` package cannot be imported
        FileNotFoundError: If the key file does not exist
        ValueError: For any other failure (bad passphrase, malformed key,
            unreadable file, ...)

    Every re-raised exception is chained to its original cause so the
    underlying traceback is preserved.
    """
    try:
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import serialization

        with open(self.private_key_path, 'rb') as key_file:
            private_key_data = key_file.read()

        # The passphrase is optional; an unencrypted key is loaded with None.
        passphrase = self.private_key_passphrase.encode() if self.private_key_passphrase else None

        private_key = serialization.load_pem_private_key(
            private_key_data,
            password=passphrase,
            backend=default_backend()
        )

        # Re-encode as unencrypted PKCS#8 DER.
        return private_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )

    except ImportError as exc:
        raise ImportError(
            "cryptography library is required for key-pair authentication. "
            "Install with: pip install 'snowflake-connector-python[secure-local-storage]'"
        ) from exc
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"Private key file not found: {self.private_key_path}") from exc
    except Exception as exc:
        raise ValueError(f"Failed to load private key: {str(exc)}") from exc
|
|
189
|
+
|
|
190
|
+
async def connect(self):
    """
    Connect to Snowflake.

    Establishes a connection using key-pair authentication when
    ``self._use_key_pair`` is set, otherwise password authentication.
    The call is idempotent: if a connection already exists nothing happens.

    Failures are passed to ``self._handle_connection_error`` (defined
    outside this method) with the context string "connection". Only a
    failure to import the Snowflake driver itself is reported as
    "snowflake-connector-python not installed"; an ImportError raised
    later (for example by the private-key loader) is reported as-is.
    """
    if self._connection is not None:
        return  # Already connected

    # Keep the driver import in its own guard so unrelated ImportErrors
    # are not mislabelled as a missing connector.
    try:
        import snowflake.connector
    except ImportError:
        self._handle_connection_error(
            ImportError(
                "snowflake-connector-python not installed. "
                "Install with: pip install snowflake-connector-python"
            ),
            "connection"
        )
        return

    try:
        # Work on a copy so credentials never end up in the stored config.
        config = self.connection_config.copy()

        if self._use_key_pair:
            config['private_key'] = self._load_private_key()
            logger.debug("Using key-pair authentication")
        else:
            config['password'] = self.password
            logger.debug("Using password authentication")

        self._connection = snowflake.connector.connect(**config)

        logger.info(f"Connected to Snowflake: {self.account}/{self.database_name}.{self.schema} (warehouse: {self.warehouse})")

    except Exception as e:
        self._handle_connection_error(e, "connection")
|
|
231
|
+
|
|
232
|
+
async def disconnect(self):
    """
    Close the Snowflake connection, if there is one.

    The stored connection is always cleared. An error raised while
    closing is logged as a warning rather than propagated.
    """
    if not self._connection:
        return

    try:
        self._connection.close()
        self._connection = None
        logger.info("Disconnected from Snowflake")
    except Exception as exc:
        logger.warning(f"Error during disconnect: {str(exc)}")
        # Drop the handle even when close() failed.
        self._connection = None
|
|
246
|
+
|
|
247
|
+
async def query(self, sql: str, params: Optional[List] = None) -> List[Dict[str, Any]]:
    """
    Run a SELECT query and return its rows.

    Connects first if no connection exists yet. The cursor is always
    closed, whether or not the query succeeds.

    Args:
        sql: SQL query with %s or %(name)s placeholders
        params: List or dict of parameters for the query; omitted from the
            driver call when empty or None

    Returns:
        List of rows as dictionaries keyed by column name (an empty list
        when the query yields no rows)

    Example:
        results = await db.query("SELECT * FROM users WHERE age > %s", [25])
        results = await db.query("SELECT * FROM users WHERE age > %(min_age)s", {"min_age": 25})
    """
    if self._connection is None:
        await self.connect()

    cur = self._connection.cursor()
    try:
        call_args = (sql, params) if params else (sql,)
        cur.execute(*call_args)

        fetched = cur.fetchall()
        if not fetched:
            return []

        # Pair each value with its column name from the cursor description.
        names = [meta[0] for meta in cur.description]
        records = []
        for record in fetched:
            records.append(dict(zip(names, record)))
        return records
    finally:
        cur.close()
|
|
285
|
+
|
|
286
|
+
async def execute(self, sql: str, params: Optional[List] = None) -> int:
    """
    Execute INSERT/UPDATE/DELETE, commit, and report the affected rows.

    Connects first if no connection exists yet. The cursor is always
    closed, whether or not the statement succeeds.

    Args:
        sql: SQL statement with %s or %(name)s placeholders
        params: List or dict of parameters for the statement; omitted from
            the driver call when empty or None

    Returns:
        The cursor's row count, read before the commit

    Example:
        affected = await db.execute("INSERT INTO users (name, age) VALUES (%s, %s)", ["Alice", 30])
    """
    if self._connection is None:
        await self.connect()

    cur = self._connection.cursor()
    try:
        if not params:
            cur.execute(sql)
        else:
            cur.execute(sql, params)

        affected = cur.rowcount
        self._connection.commit()
        return affected
    finally:
        cur.close()
|
|
317
|
+
|
|
318
|
+
async def tables(self, schema: Optional[str] = None) -> List[str]:
    """
    List all tables in the current schema or in a specific schema.

    Args:
        schema: Schema name, optionally database-qualified (``db.schema``).
            When empty or None, a plain ``SHOW TABLES`` is issued.

    Returns:
        List of table names (the ``name`` column of the SHOW output)

    Raises:
        ValueError: If ``schema`` is not a valid Snowflake identifier.
            The name is interpolated into the statement, so anything other
            than unquoted or double-quoted identifier parts is rejected
            before any SQL is sent.
    """
    import re

    if schema:
        # One identifier part: unquoted (letters, digits, _, $) or
        # double-quoted with "" as the embedded-quote escape.
        part = r'(?:[A-Za-z_][A-Za-z0-9_$]*|"(?:[^"]|"")+")'
        if not re.fullmatch(part + r'(?:\.' + part + r')*', schema):
            raise ValueError(f"Invalid schema identifier: {schema!r}")

    if self._connection is None:
        await self.connect()

    cursor = self._connection.cursor()

    try:
        if schema:
            cursor.execute(f"SHOW TABLES IN SCHEMA {schema}")
        else:
            cursor.execute("SHOW TABLES")

        rows = cursor.fetchall()
        # SHOW TABLES returns several columns; locate 'name' by its position
        # in the cursor description rather than assuming a fixed index.
        columns = [desc[0] for desc in cursor.description]
        name_idx = columns.index('name')
        return [row[name_idx] for row in rows]

    finally:
        cursor.close()
|
|
348
|
+
|
|
349
|
+
async def schemas(self) -> List[str]:
    """
    List the schemas reported by ``SHOW SCHEMAS``.

    Connects first if no connection exists yet.

    Returns:
        List of schema names (the ``name`` column of the SHOW output)
    """
    if self._connection is None:
        await self.connect()

    cur = self._connection.cursor()
    try:
        cur.execute("SHOW SCHEMAS")
        fetched = cur.fetchall()
        # Locate the 'name' column by position in the cursor description.
        position = [meta[0] for meta in cur.description].index('name')
        names = []
        for record in fetched:
            names.append(record[position])
        return names
    finally:
        cur.close()
|
|
370
|
+
|
|
371
|
+
async def describe(self, table: str) -> List[Dict[str, Any]]:
    """
    Get table column information via ``DESCRIBE TABLE``.

    Args:
        table: Table name, optionally qualified (``schema.table`` or
            ``db.schema.table``)

    Returns:
        One dictionary per row of the DESCRIBE output, keyed by the column
        names the cursor reports

    Raises:
        ValueError: If ``table`` is not a valid Snowflake identifier.
            The name is interpolated into the statement, so anything other
            than unquoted or double-quoted identifier parts is rejected
            before any SQL is sent.

    Example:
        columns = await db.describe("users")
        for col in columns:
            print(f"{col['name']}: {col['type']}")
    """
    import re

    # One identifier part: unquoted (letters, digits, _, $) or
    # double-quoted with "" as the embedded-quote escape.
    part = r'(?:[A-Za-z_][A-Za-z0-9_$]*|"(?:[^"]|"")+")'
    if not isinstance(table, str) or not re.fullmatch(part + r'(?:\.' + part + r')*', table):
        raise ValueError(f"Invalid table identifier: {table!r}")

    if self._connection is None:
        await self.connect()

    cursor = self._connection.cursor()

    try:
        cursor.execute(f"DESCRIBE TABLE {table}")
        rows = cursor.fetchall()
        columns = [desc[0] for desc in cursor.description]
        return [dict(zip(columns, row)) for row in rows]

    finally:
        cursor.close()
|
|
399
|
+
|
|
400
|
+
async def databases(self) -> List[str]:
    """
    List the databases reported by ``SHOW DATABASES``.

    Connects first if no connection exists yet.

    Returns:
        List of database names (the ``name`` column of the SHOW output)
    """
    if self._connection is None:
        await self.connect()

    cur = self._connection.cursor()
    try:
        cur.execute("SHOW DATABASES")
        fetched = cur.fetchall()
        # Locate the 'name' column by position in the cursor description.
        position = [meta[0] for meta in cur.description].index('name')
        names = []
        for record in fetched:
            names.append(record[position])
        return names
    finally:
        cur.close()
|
|
421
|
+
|
|
422
|
+
async def list_warehouses(self) -> List[Dict[str, Any]]:
    """
    List the warehouses reported by ``SHOW WAREHOUSES``.

    Connects first if no connection exists yet.

    Returns:
        One dictionary per warehouse, keyed by the column names the cursor
        reports
    """
    if self._connection is None:
        await self.connect()

    cur = self._connection.cursor()
    try:
        cur.execute("SHOW WAREHOUSES")
        fetched = cur.fetchall()
        names = [meta[0] for meta in cur.description]
        details = []
        for record in fetched:
            details.append(dict(zip(names, record)))
        return details
    finally:
        cur.close()
|
|
442
|
+
|
|
443
|
+
async def switch_warehouse(self, warehouse: str) -> None:
    """
    Switch the session to a different warehouse.

    Issues ``USE WAREHOUSE`` on the current connection, then records the
    new name on both ``self.warehouse`` and ``self.connection_config`` so
    that a later reconnect uses the same warehouse the instance reports.

    Args:
        warehouse: Warehouse name to switch to

    Raises:
        ValueError: If ``warehouse`` is not a valid Snowflake identifier.
            The name is interpolated into the statement, so anything other
            than an unquoted or double-quoted identifier is rejected
            before any SQL is sent.
    """
    import re

    # Unquoted (letters, digits, _, $) or double-quoted identifier.
    ident = r'(?:[A-Za-z_][A-Za-z0-9_$]*|"(?:[^"]|"")+")'
    if not isinstance(warehouse, str) or not re.fullmatch(ident, warehouse):
        raise ValueError(f"Invalid warehouse identifier: {warehouse!r}")

    if self._connection is None:
        await self.connect()

    cursor = self._connection.cursor()

    try:
        cursor.execute(f"USE WAREHOUSE {warehouse}")
        self.warehouse = warehouse
        # Keep the stored connection settings in step with the live session.
        self.connection_config['warehouse'] = warehouse
        logger.info(f"Switched to warehouse: {warehouse}")

    finally:
        cursor.close()
|
|
462
|
+
|
|
463
|
+
async def get_current_warehouse(self) -> Dict[str, Any]:
    """
    Get the warehouse the session is currently using.

    Returns:
        The first row of ``SELECT CURRENT_WAREHOUSE() as warehouse`` as a
        dictionary, or ``{"warehouse": None}`` when the query yields no rows
    """
    found = await self.query("SELECT CURRENT_WAREHOUSE() as warehouse")
    if found:
        return found[0]
    return {"warehouse": None}
|
|
472
|
+
|
|
473
|
+
async def query_history(self, limit: int = 100) -> List[Dict[str, Any]]:
    """
    Get recent query history, newest first.

    Args:
        limit: Maximum number of queries to return (default: 100). The
            value is coerced with ``int()`` before it is placed in the
            statement, so numeric strings are accepted.

    Returns:
        List of query history records

    Raises:
        ValueError: If ``limit`` cannot be converted to an integer or is
            negative. This keeps arbitrary text out of the LIMIT clause.
    """
    try:
        row_limit = int(limit)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"limit must be an integer, got {limit!r}") from exc
    if row_limit < 0:
        raise ValueError(f"limit must be non-negative, got {row_limit}")

    sql = f"""
        SELECT
            query_id,
            query_text,
            database_name,
            schema_name,
            query_type,
            warehouse_name,
            user_name,
            role_name,
            execution_status,
            error_message,
            start_time,
            end_time,
            total_elapsed_time,
            bytes_scanned,
            rows_produced
        FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY())
        ORDER BY start_time DESC
        LIMIT {row_limit}
    """
    return await self.query(sql)
|
|
505
|
+
|
|
506
|
+
async def get_warehouse_usage(self, warehouse: Optional[str] = None, days: int = 7) -> List[Dict[str, Any]]:
    """
    Get daily warehouse credit usage.

    Args:
        warehouse: Warehouse name to filter on. When empty or None no
            filter is applied and every warehouse returned by the metering
            history is included. The name is sent as a bound ``%s``
            parameter, not spliced into the statement.
        days: Number of days to look back (default: 7). Coerced with
            ``int()`` before it is placed in the statement.

    Returns:
        List of usage records (warehouse_name, usage_date, total_credits)

    Raises:
        ValueError: If ``days`` cannot be converted to an integer or is
            negative. A negative value would turn ``-{days}`` into ``--``,
            which starts a SQL comment.
    """
    try:
        lookback_days = int(days)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"days must be an integer, got {days!r}") from exc
    if lookback_days < 0:
        raise ValueError(f"days must be non-negative, got {lookback_days}")

    bind_values = []
    warehouse_filter = ""
    if warehouse:
        warehouse_filter = "AND warehouse_name = %s"
        bind_values.append(warehouse)

    sql = f"""
        SELECT
            warehouse_name,
            DATE(start_time) as usage_date,
            SUM(credits_used) as total_credits
        FROM TABLE(INFORMATION_SCHEMA.WAREHOUSE_METERING_HISTORY(
            DATE_RANGE_START => DATEADD(day, -{lookback_days}, CURRENT_DATE())
        ))
        WHERE 1=1 {warehouse_filter}
        GROUP BY warehouse_name, usage_date
        ORDER BY usage_date DESC, warehouse_name
    """
    # Pass None rather than an empty list when nothing is bound.
    return await self.query(sql, bind_values or None)
|
|
532
|
+
|
|
533
|
+
async def list_stages(self) -> List[Dict[str, Any]]:
    """
    List the stages reported by ``SHOW STAGES``.

    Connects first if no connection exists yet.

    Returns:
        One dictionary per stage, keyed by the column names the cursor
        reports
    """
    if self._connection is None:
        await self.connect()

    cur = self._connection.cursor()
    try:
        cur.execute("SHOW STAGES")
        fetched = cur.fetchall()
        names = [meta[0] for meta in cur.description]
        details = []
        for record in fetched:
            details.append(dict(zip(names, record)))
        return details
    finally:
        cur.close()
|
|
553
|
+
|
|
554
|
+
async def put_file(self, local_path: str, stage_path: str, overwrite: bool = False) -> Dict[str, Any]:
    """
    Upload a local file to a Snowflake stage with a ``PUT`` statement.

    Connects first if no connection exists yet.

    Args:
        local_path: Local file path
        stage_path: Stage location (e.g., "@my_stage/path/")
        overwrite: Whether to add ``OVERWRITE = TRUE`` to the statement

    Returns:
        Dictionary with ``success`` (always True when no exception is
        raised), ``files_uploaded`` (number of result rows) and ``details``
        (the result rows as dictionaries)
    """
    if self._connection is None:
        await self.connect()

    cur = self._connection.cursor()
    try:
        overwrite_str = "" if not overwrite else "OVERWRITE = TRUE"
        cur.execute(f"PUT 'file://{local_path}' {stage_path} {overwrite_str}")

        fetched = cur.fetchall()
        names = [meta[0] for meta in cur.description]
        uploaded = []
        for record in fetched:
            uploaded.append(dict(zip(names, record)))

        return {
            "success": True,
            "files_uploaded": len(uploaded),
            "details": uploaded
        }
    finally:
        cur.close()
|
|
588
|
+
|
|
589
|
+
async def get_file(self, stage_path: str, local_path: str) -> Dict[str, Any]:
    """
    Download a file from a Snowflake stage with a ``GET`` statement.

    Connects first if no connection exists yet.

    Args:
        stage_path: Stage file location (e.g., "@my_stage/path/file.csv")
        local_path: Local directory to download to

    Returns:
        Dictionary with ``success`` (always True when no exception is
        raised), ``files_downloaded`` (number of result rows) and
        ``details`` (the result rows as dictionaries)
    """
    if self._connection is None:
        await self.connect()

    cur = self._connection.cursor()
    try:
        cur.execute(f"GET {stage_path} 'file://{local_path}'")

        fetched = cur.fetchall()
        names = [meta[0] for meta in cur.description]
        downloaded = []
        for record in fetched:
            downloaded.append(dict(zip(names, record)))

        return {
            "success": True,
            "files_downloaded": len(downloaded),
            "details": downloaded
        }
    finally:
        cur.close()
|
|
621
|
+
|
|
622
|
+
async def load_from_stage(
    self,
    table: str,
    stage: str,
    file_format: str = "CSV",
    pattern: Optional[str] = None,
    on_error: str = "ABORT_STATEMENT"
) -> Dict[str, Any]:
    """
    Load data from a stage into a table with ``COPY INTO``.

    Args:
        table: Target table name
        stage: Stage location (e.g., "@my_stage/path/")
        file_format: File format type or named file format (default: "CSV").
            A value whose first word is a built-in type (CSV, JSON, AVRO,
            ORC, PARQUET, XML; case-insensitive) is emitted verbatim as
            ``TYPE = ...``. Any other value is treated as the name of an
            existing file format and emitted as ``FORMAT_NAME = '...'``.
        pattern: File pattern to match (e.g., r".*\\.csv")
        on_error: Error handling: ABORT_STATEMENT, CONTINUE, SKIP_FILE

    Returns:
        Dictionary with ``success`` (always True when no exception is
        raised), ``rows_loaded`` (sum of the ``rows_loaded`` column over
        the result rows, counting missing/None as 0), ``files_processed``
        (number of result rows) and ``details`` (the result rows)
    """
    builtin_types = {"CSV", "JSON", "AVRO", "ORC", "PARQUET", "XML"}
    leading = file_format.split()
    if leading and leading[0].upper() in builtin_types:
        # Built-in type: keep the value verbatim, as before.
        format_clause = f"TYPE = {file_format}"
    else:
        # Named file format: TYPE = <name> is not valid COPY syntax.
        escaped_name = file_format.replace("'", "''")
        format_clause = f"FORMAT_NAME = '{escaped_name}'"

    if self._connection is None:
        await self.connect()

    cursor = self._connection.cursor()

    try:
        pattern_clause = f"PATTERN = '{pattern}'" if pattern else ""

        sql = f"""
            COPY INTO {table}
            FROM {stage}
            FILE_FORMAT = ({format_clause})
            {pattern_clause}
            ON_ERROR = {on_error}
        """

        cursor.execute(sql)

        rows = cursor.fetchall()
        columns = [desc[0] for desc in cursor.description]
        results = [dict(zip(columns, row)) for row in rows]

        # A per-file rows_loaded of None (or a missing column) counts as 0.
        rows_loaded = sum(row.get('rows_loaded', 0) or 0 for row in results)

        return {
            "success": True,
            "rows_loaded": rows_loaded,
            "files_processed": len(results),
            "details": results
        }

    finally:
        cursor.close()
|
|
677
|
+
|
|
678
|
+
async def create_stage(
    self,
    name: str,
    url: Optional[str] = None,
    storage_integration: Optional[str] = None,
    credentials: Optional[Dict[str, str]] = None
) -> None:
    """
    Create a stage if it does not already exist.

    Without ``url`` an internal stage is created and the other two
    arguments are ignored. With ``url`` an external stage is created;
    ``storage_integration`` takes precedence over ``credentials`` when
    both are given.

    Args:
        name: Stage name
        url: External URL (e.g., "s3://bucket/path/") for external stages
        storage_integration: Storage integration name for cloud storage
        credentials: Cloud credentials (e.g., {"AWS_KEY_ID": "...", "AWS_SECRET_KEY": "..."})

    Example:
        # Internal stage
        await db.create_stage("my_internal_stage")

        # External S3 stage
        await db.create_stage(
            "my_s3_stage",
            url="s3://my-bucket/data/",
            credentials={"AWS_KEY_ID": "...", "AWS_SECRET_KEY": "..."}
        )
    """
    if self._connection is None:
        await self.connect()

    cur = self._connection.cursor()
    try:
        # Assemble the statement from space-separated fragments.
        fragments = [f"CREATE STAGE IF NOT EXISTS {name}"]
        if url:
            fragments.append(f"URL = '{url}'")
            if storage_integration:
                fragments.append(f"STORAGE_INTEGRATION = {storage_integration}")
            elif credentials:
                pairs = []
                for key, value in credentials.items():
                    pairs.append(f"{key} = '{value}'")
                fragments.append(f"CREDENTIALS = ({' '.join(pairs)})")

        cur.execute(" ".join(fragments))
        logger.info(f"Created stage: {name}")
    finally:
        cur.close()
|
|
729
|
+
|
|
730
|
+
def get_tools(self) -> List['AgentTool']:
    """
    Build the agent-facing tool definitions for this Snowflake plugin.

    Every tool shares the same category ("database"), source ("plugin")
    and plugin name ("Snowflake"); only the name, description, JSON-schema
    parameters, handler and timeout differ, so they are declared through
    small local factories.

    Returns:
        List of AgentTool instances, one per exposed Snowflake operation,
        in a fixed order: query, execute, table/schema listing, warehouse
        management, query history, then stage operations.
    """
    from ..core.tools import AgentTool

    def text(description, **extra):
        # JSON-schema entry for a string-typed argument.
        return {"type": "string", "description": description, **extra}

    def values(description):
        # JSON-schema entry for the positional SQL parameter list.
        return {
            "type": "array",
            "description": description,
            "items": {"type": "string"},
        }

    def tool(name, description, handler, timeout_seconds, properties=None, required=None):
        # Fresh containers per tool so no two schemas share a dict or list.
        schema = {
            "type": "object",
            "properties": {} if properties is None else properties,
            "required": [] if required is None else required,
        }
        return AgentTool(
            name=name,
            description=description,
            parameters=schema,
            handler=handler,
            category="database",
            source="plugin",
            plugin_name="Snowflake",
            timeout_seconds=timeout_seconds,
        )

    return [
        tool(
            "query_database",
            "Execute a SQL SELECT query on Snowflake and return results as a list of dictionaries",
            self._tool_query,
            60,
            properties={
                "sql": text("SQL SELECT query with %s placeholders for parameters"),
                "params": values("Optional list of parameter values for query placeholders"),
            },
            required=["sql"],
        ),
        tool(
            "execute_sql",
            "Execute a SQL INSERT, UPDATE, or DELETE statement on Snowflake and return the number of affected rows",
            self._tool_execute,
            60,
            properties={
                "sql": text("SQL statement with %s placeholders for parameters"),
                "params": values("Optional list of parameter values for statement placeholders"),
            },
            required=["sql"],
        ),
        tool(
            "list_tables",
            "List all tables in the Snowflake database or a specific schema",
            self._tool_list_tables,
            30,
            properties={
                "schema": text("Optional schema name to list tables from (defaults to current schema)"),
            },
        ),
        tool(
            "list_schemas",
            "List all schemas in the Snowflake database",
            self._tool_list_schemas,
            30,
        ),
        tool(
            "get_table_schema",
            "Get detailed column information for a Snowflake table including column names, types, and constraints",
            self._tool_get_table_schema,
            30,
            properties={
                "table": text("Table name (optionally schema-qualified like 'schema.table')"),
            },
            required=["table"],
        ),
        tool(
            "list_warehouses",
            "List all available Snowflake compute warehouses with their status and configuration",
            self._tool_list_warehouses,
            30,
        ),
        tool(
            "switch_warehouse",
            "Switch to a different Snowflake compute warehouse for subsequent queries",
            self._tool_switch_warehouse,
            30,
            properties={
                "warehouse": text("Name of the warehouse to switch to"),
            },
            required=["warehouse"],
        ),
        tool(
            "get_query_history",
            "Get recent query history from Snowflake including execution status, timing, and resource usage",
            self._tool_get_query_history,
            45,
            properties={
                "limit": {
                    "type": "integer",
                    "description": "Maximum number of queries to return (default: 100)",
                    "default": 100,
                },
            },
        ),
        tool(
            "list_stages",
            "List all Snowflake stages (internal and external) for data loading",
            self._tool_list_stages,
            30,
        ),
        tool(
            "upload_to_stage",
            "Upload a local file to a Snowflake stage for data loading",
            self._tool_upload_to_stage,
            120,
            properties={
                "local_path": text("Local file path to upload"),
                "stage_path": text("Stage location (e.g., '@my_stage/path/')"),
                "overwrite": {
                    "type": "boolean",
                    "description": "Whether to overwrite existing files (default: false)",
                    "default": False,
                },
            },
            required=["local_path", "stage_path"],
        ),
        tool(
            "load_from_stage",
            "Load data from a Snowflake stage into a table using COPY INTO command",
            self._tool_load_from_stage,
            180,
            properties={
                "table": text("Target table name"),
                "stage": text("Stage location (e.g., '@my_stage/path/')"),
                "file_format": text("File format type (default: CSV)", default="CSV"),
                "pattern": text("Optional file pattern to match (e.g., '.*\\.csv')"),
            },
            required=["table", "stage"],
        ),
        tool(
            "create_stage",
            "Create a new Snowflake stage (internal or external) for data loading",
            self._tool_create_stage,
            30,
            properties={
                "name": text("Stage name"),
                "url": text("External URL for external stages (e.g., 's3://bucket/path/')"),
                "storage_integration": text("Storage integration name for cloud storage"),
            },
            required=["name"],
        ),
    ]
|
|
995
|
+
|
|
996
|
+
# Tool handler methods
|
|
997
|
+
|
|
998
|
+
async def _tool_query(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle a ``query_database`` tool call.

    Args:
        args: Tool arguments. ``sql`` and ``params`` are read with
            ``dict.get``, so a missing key is forwarded as ``None``.

    Returns:
        Dict with ``success`` (True), ``rows`` (the value returned by
        ``self.query``) and ``row_count`` (its length).
    """
    rows = await self.query(args.get("sql"), args.get("params"))
    return {"success": True, "rows": rows, "row_count": len(rows)}
|
|
1010
|
+
|
|
1011
|
+
async def _tool_execute(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle an ``execute_sql`` tool call.

    Args:
        args: Tool arguments. ``sql`` and ``params`` are read with
            ``dict.get``, so a missing key is forwarded as ``None``.

    Returns:
        Dict with ``success`` (True) and ``affected_rows`` (the value
        returned by ``self.execute``).
    """
    statement, bindings = args.get("sql"), args.get("params")
    changed = await self.execute(statement, bindings)
    return {"success": True, "affected_rows": changed}
|
|
1022
|
+
|
|
1023
|
+
async def _tool_list_tables(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle a ``list_tables`` tool call.

    Args:
        args: Tool arguments. The optional ``schema`` entry is passed to
            ``self.tables``; ``None`` is passed when it is absent.

    Returns:
        Dict with ``success`` (True), ``tables`` (the value returned by
        ``self.tables``) and ``count`` (its length).
    """
    found = await self.tables(args.get("schema"))
    return {"success": True, "tables": found, "count": len(found)}
|
|
1034
|
+
|
|
1035
|
+
async def _tool_list_schemas(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle a ``list_schemas`` tool call.

    Args:
        args: Tool arguments; not read by this handler.

    Returns:
        Dict with ``success`` (True), ``schemas`` (the value returned by
        ``self.schemas``) and ``count`` (its length).
    """
    names = await self.schemas()
    response: Dict[str, Any] = {"success": True}
    response["schemas"] = names
    response["count"] = len(names)
    return response
|
|
1044
|
+
|
|
1045
|
+
async def _tool_get_table_schema(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle a ``get_table_schema`` tool call.

    Args:
        args: Tool arguments. ``table`` is read with ``dict.get`` and
            passed to ``self.describe``.

    Returns:
        Dict with ``success`` (True), ``table`` (the requested name as
        given), ``columns`` (the value returned by ``self.describe``)
        and ``column_count`` (its length).
    """
    requested = args.get("table")
    described = await self.describe(requested)
    return dict(
        success=True,
        table=requested,
        columns=described,
        column_count=len(described),
    )
|
|
1057
|
+
|
|
1058
|
+
async def _tool_list_warehouses(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle a ``list_warehouses`` tool call.

    Args:
        args: Tool arguments; not read by this handler.

    Returns:
        Dict with ``success`` (True), ``warehouses`` (the value returned
        by ``self.list_warehouses``) and ``count`` (its length).
    """
    available = await self.list_warehouses()
    return {"success": True, "warehouses": available, "count": len(available)}
|
|
1067
|
+
|
|
1068
|
+
async def _tool_switch_warehouse(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle a ``switch_warehouse`` tool call.

    Args:
        args: Tool arguments. ``warehouse`` is read with ``dict.get`` and
            passed to ``self.switch_warehouse``.

    Returns:
        Dict with ``success`` (True), a human-readable ``message`` and the
        ``warehouse`` name that was requested.
    """
    warehouse = args.get("warehouse")
    await self.switch_warehouse(warehouse)

    response: Dict[str, Any] = {"success": True}
    response["message"] = f"Switched to warehouse: {warehouse}"
    response["warehouse"] = warehouse
    return response
|
|
1079
|
+
|
|
1080
|
+
async def _tool_get_query_history(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle a ``get_query_history`` tool call.

    Args:
        args: Tool arguments. The optional ``limit`` entry caps the number
            of history rows requested from ``self.query_history``.

    Returns:
        Dict with ``success`` (True), ``queries`` (the value returned by
        ``self.query_history``) and ``count`` (its length).
    """
    limit = args.get("limit")
    if limit is None:
        # Callers may send an explicit null for an optional argument;
        # dict.get's default only covers a *missing* key, so apply the
        # documented default of 100 here as well.
        limit = 100

    history = await self.query_history(limit)

    return {
        "success": True,
        "queries": history,
        "count": len(history)
    }
|
|
1091
|
+
|
|
1092
|
+
async def _tool_list_stages(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle a ``list_stages`` tool call.

    Args:
        args: Tool arguments; not read by this handler.

    Returns:
        Dict with ``success`` (True), ``stages`` (the value returned by
        ``self.list_stages``) and ``count`` (its length).
    """
    known = await self.list_stages()
    return {"success": True, "stages": known, "count": len(known)}
|
|
1101
|
+
|
|
1102
|
+
async def _tool_upload_to_stage(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle an ``upload_to_stage`` tool call.

    Args:
        args: Tool arguments. ``local_path`` and ``stage_path`` are read
            with ``dict.get``; ``overwrite`` falls back to False when the
            key is absent.

    Returns:
        Whatever ``self.put_file`` returns, passed through unchanged.
    """
    return await self.put_file(
        args.get("local_path"),
        args.get("stage_path"),
        args.get("overwrite", False),
    )
|
|
1111
|
+
|
|
1112
|
+
async def _tool_load_from_stage(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle a ``load_from_stage`` tool call.

    Args:
        args: Tool arguments. ``table`` and ``stage`` are read with
            ``dict.get``; ``file_format`` defaults to ``"CSV"`` and
            ``pattern`` to ``None``.

    Returns:
        Whatever ``self.load_from_stage`` returns, passed through
        unchanged.
    """
    table = args.get("table")
    stage = args.get("stage")
    pattern = args.get("pattern")

    file_format = args.get("file_format")
    if file_format is None:
        # Callers may send an explicit null for an optional argument;
        # dict.get's default only covers a *missing* key, so apply the
        # documented "CSV" default here as well.
        file_format = "CSV"

    result = await self.load_from_stage(table, stage, file_format, pattern)

    return result
|
|
1122
|
+
|
|
1123
|
+
async def _tool_create_stage(self, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle a ``create_stage`` tool call.

    Args:
        args: Tool arguments. ``name``, ``url`` and ``storage_integration``
            are read with ``dict.get`` and passed positionally to
            ``self.create_stage``; no credentials are forwarded.

    Returns:
        Dict with ``success`` (True), a human-readable ``message`` and the
        ``stage`` name that was requested.
    """
    name = args.get("name")
    await self.create_stage(name, args.get("url"), args.get("storage_integration"))
    return {"success": True, "message": f"Created stage: {name}", "stage": name}
|
|
1136
|
+
|
|
1137
|
+
|
|
1138
|
+
def snowflake(**kwargs) -> SnowflakePlugin:
    """
    Convenience factory for a Snowflake plugin.

    All keyword arguments are forwarded unchanged to the SnowflakePlugin
    constructor.

    Args:
        **kwargs: Connection parameters (account, warehouse, database, etc.)

    Returns:
        The newly constructed SnowflakePlugin instance

    Example:
        from daita.plugins import snowflake

        db = snowflake(
            account="xy12345",
            warehouse="COMPUTE_WH",
            database="MYDB",
            user="myuser",
            password="mypass"
        )
    """
    plugin = SnowflakePlugin(**kwargs)
    return plugin
|