relationalai 1.0.0a2__py3-none-any.whl → 1.0.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/shims/executor.py +4 -1
- relationalai/shims/mm2v0.py +15 -10
- {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a3.dist-info}/METADATA +1 -1
- {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a3.dist-info}/RECORD +39 -30
- v0/relationalai/__init__.py +69 -22
- v0/relationalai/clients/__init__.py +15 -2
- v0/relationalai/clients/client.py +4 -4
- v0/relationalai/clients/local.py +5 -5
- v0/relationalai/clients/resources/__init__.py +8 -0
- v0/relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
- v0/relationalai/clients/resources/snowflake/__init__.py +20 -0
- v0/relationalai/clients/resources/snowflake/cli_resources.py +87 -0
- v0/relationalai/clients/resources/snowflake/direct_access_resources.py +711 -0
- v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
- v0/relationalai/clients/resources/snowflake/error_handlers.py +199 -0
- v0/relationalai/clients/resources/snowflake/resources_factory.py +99 -0
- v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +606 -1392
- v0/relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +43 -12
- v0/relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
- v0/relationalai/clients/resources/snowflake/util.py +387 -0
- v0/relationalai/early_access/dsl/ir/executor.py +4 -4
- v0/relationalai/early_access/dsl/snow/api.py +2 -1
- v0/relationalai/experimental/solvers.py +7 -7
- v0/relationalai/semantics/devtools/benchmark_lqp.py +4 -5
- v0/relationalai/semantics/devtools/extract_lqp.py +1 -1
- v0/relationalai/semantics/internal/snowflake.py +1 -1
- v0/relationalai/semantics/lqp/executor.py +4 -11
- v0/relationalai/semantics/metamodel/util.py +6 -5
- v0/relationalai/semantics/rel/executor.py +14 -11
- v0/relationalai/semantics/sql/executor/snowflake.py +9 -5
- v0/relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
- v0/relationalai/tools/cli.py +26 -30
- v0/relationalai/tools/cli_helpers.py +10 -2
- v0/relationalai/util/otel_configuration.py +2 -1
- v0/relationalai/util/otel_handler.py +1 -1
- {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a3.dist-info}/WHEEL +0 -0
- {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a3.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a3.dist-info}/top_level.txt +0 -0
- /v0/relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
--- /dev/null
+++ v0/relationalai/clients/resources/snowflake/util.py
@@ -0,0 +1,387 @@
+"""
+Utility functions for Snowflake resources.
+"""
+from __future__ import annotations
+import re
+import decimal
+import base64
+from numbers import Number
+from datetime import datetime, date
+from typing import List, Any, Dict, cast
+
+from .... import dsl
+from ....environments import runtime_env, SnowbookEnvironment
+
+# warehouse-based snowflake notebooks currently don't have hazmat
+crypto_disabled = False
+try:
+    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+    from cryptography.hazmat.backends import default_backend
+    from cryptography.hazmat.primitives import padding
+except (ModuleNotFoundError, ImportError):
+    crypto_disabled = True
+
+# Constants used by helper functions
+ENGINE_ERRORS = ("engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted")
+ENGINE_NOT_READY_MSGS = ("engine is in pending", "engine is provisioning")
+DATABASE_ERRORS = ("database not found",)
+
+# Constants for import/export and transaction processing
+VALID_IMPORT_STATES = ("PENDING", "PROCESSING", "QUARANTINED", "LOADED")
+IMPORT_STREAM_FIELDS = (
+    "ID", "CREATED_AT", "CREATED_BY", "STATUS", "REFERENCE_NAME", "REFERENCE_ALIAS",
+    "FQ_OBJECT_NAME", "RAI_DATABASE", "RAI_RELATION", "DATA_SYNC_STATUS",
+    "PENDING_BATCHES_COUNT", "NEXT_BATCH_STATUS", "NEXT_BATCH_UNLOADED_TIMESTAMP",
+    "NEXT_BATCH_DETAILS", "LAST_BATCH_DETAILS", "LAST_BATCH_UNLOADED_TIMESTAMP", "CDC_STATUS"
+)
+FIELD_MAP = {
+    "database_name": "database",
+    "engine_name": "engine",
+}
+
+
+def process_jinja_template(template: str, indent_spaces: int = 0, **substitutions: Any) -> str:
+    """Process a Jinja-like template.
+
+    Supports:
+    - Variable substitution {{ var }}
+    - Conditional blocks {% if condition %} ... {% endif %}
+    - For loops {% for item in items %} ... {% endfor %}
+    - Comments {# ... #}
+    - Whitespace control with {%- and -%}
+
+    Args:
+        template: The template string
+        indent_spaces: Number of spaces to indent the result
+        **substitutions: Variable substitutions
+    """
+
+    def evaluate_condition(condition: str, context: dict) -> bool:
+        """Safely evaluate a condition string using the context."""
+        # Replace variables with their values
+        for k, v in context.items():
+            if isinstance(v, str):
+                condition = condition.replace(k, f"'{v}'")
+            else:
+                condition = condition.replace(k, str(v))
+        try:
+            return bool(eval(condition, {"__builtins__": {}}, {}))
+        except Exception:
+            return False
+
+    def process_expression(expr: str, context: dict) -> str:
+        """Process a {{ expression }} block."""
+        expr = expr.strip()
+        if expr in context:
+            return str(context[expr])
+        return ""
+
+    def process_block(lines: List[str], context: dict, indent: int = 0) -> List[str]:
+        """Process a block of template lines recursively."""
+        result = []
+        i = 0
+        while i < len(lines):
+            line = lines[i]
+
+            # Handle comments
+            line = re.sub(r'{#.*?#}', '', line)
+
+            # Handle if blocks
+            if_match = re.search(r'{%\s*if\s+(.+?)\s*%}', line)
+            if if_match:
+                condition = if_match.group(1)
+                if_block = []
+                else_block = []
+                i += 1
+                nesting = 1
+                in_else_block = False
+                while i < len(lines) and nesting > 0:
+                    if re.search(r'{%\s*if\s+', lines[i]):
+                        nesting += 1
+                    elif re.search(r'{%\s*endif\s*%}', lines[i]):
+                        nesting -= 1
+                    elif nesting == 1 and re.search(r'{%\s*else\s*%}', lines[i]):
+                        in_else_block = True
+                        i += 1
+                        continue
+
+                    if nesting > 0:
+                        if in_else_block:
+                            else_block.append(lines[i])
+                        else:
+                            if_block.append(lines[i])
+                    i += 1
+                if evaluate_condition(condition, context):
+                    result.extend(process_block(if_block, context, indent))
+                else:
+                    result.extend(process_block(else_block, context, indent))
+                continue
+
+            # Handle for loops
+            for_match = re.search(r'{%\s*for\s+(\w+)\s+in\s+(\w+)\s*%}', line)
+            if for_match:
+                var_name, iterable_name = for_match.groups()
+                for_block = []
+                i += 1
+                nesting = 1
+                while i < len(lines) and nesting > 0:
+                    if re.search(r'{%\s*for\s+', lines[i]):
+                        nesting += 1
+                    elif re.search(r'{%\s*endfor\s*%}', lines[i]):
+                        nesting -= 1
+                    if nesting > 0:
+                        for_block.append(lines[i])
+                    i += 1
+                if iterable_name in context and isinstance(context[iterable_name], (list, tuple)):
+                    for item in context[iterable_name]:
+                        loop_context = dict(context)
+                        loop_context[var_name] = item
+                        result.extend(process_block(for_block, loop_context, indent))
+                continue
+
+            # Handle variable substitution
+            line = re.sub(r'{{\s*(\w+)\s*}}', lambda m: process_expression(m.group(1), context), line)
+
+            # Handle whitespace control
+            line = re.sub(r'{%-', '{%', line)
+            line = re.sub(r'-%}', '%}', line)
+
+            # Add line with proper indentation, preserving blank lines
+            if line.strip():
+                result.append(" " * (indent_spaces + indent) + line)
+            else:
+                result.append("")
+
+            i += 1
+
+        return result
+
+    # Split template into lines and process
+    lines = template.split('\n')
+    processed_lines = process_block(lines, substitutions)
+
+    return '\n'.join(processed_lines)
+
+
+def type_to_sql(type_obj: Any) -> str:
+    if type_obj is str:
+        return "VARCHAR"
+    if type_obj is int:
+        return "NUMBER"
+    if type_obj is Number:
+        return "DECIMAL(38, 15)"
+    if type_obj is float:
+        return "FLOAT"
+    if type_obj is decimal.Decimal:
+        return "DECIMAL(38, 15)"
+    if type_obj is bool:
+        return "BOOLEAN"
+    if type_obj is dict:
+        return "VARIANT"
+    if type_obj is list:
+        return "ARRAY"
+    if type_obj is bytes:
+        return "BINARY"
+    if type_obj is datetime:
+        return "TIMESTAMP"
+    if type_obj is date:
+        return "DATE"
+    if isinstance(type_obj, dsl.Type):
+        return "VARCHAR"
+    raise ValueError(f"Unknown type {type_obj}")
+
+
+def type_to_snowpark(type_obj: Any) -> str:
+    if type_obj is str:
+        return "StringType()"
+    if type_obj is int:
+        return "IntegerType()"
+    if type_obj is float:
+        return "FloatType()"
+    if type_obj is Number:
+        return "DecimalType(38, 15)"
+    if type_obj is decimal.Decimal:
+        return "DecimalType(38, 15)"
+    if type_obj is bool:
+        return "BooleanType()"
+    if type_obj is dict:
+        return "MapType()"
+    if type_obj is list:
+        return "ArrayType()"
+    if type_obj is bytes:
+        return "BinaryType()"
+    if type_obj is datetime:
+        return "TimestampType()"
+    if type_obj is date:
+        return "DateType()"
+    if isinstance(type_obj, dsl.Type):
+        return "StringType()"
+    raise ValueError(f"Unknown type {type_obj}")
+
+
+def sanitize_user_name(user: str) -> str:
+    """Sanitize a user name by extracting the part before '@' and replacing invalid characters."""
+    # Extract the part before the '@'
+    sanitized_user = user.split('@')[0]
+    # Replace any character that is not a letter, number, or underscore with '_'
+    sanitized_user = re.sub(r'[^a-zA-Z0-9_]', '_', sanitized_user)
+    return sanitized_user
+
+
+def is_engine_issue(response_message: str) -> bool:
+    """Check if a response message indicates an engine issue."""
+    return any(kw in response_message.lower() for kw in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS)
+
+
+def is_database_issue(response_message: str) -> bool:
+    """Check if a response message indicates a database issue."""
+    return any(kw in response_message.lower() for kw in DATABASE_ERRORS)
+
+
+def collect_error_messages(e: Exception) -> list[str]:
+    """Collect all error messages from an exception and its chain.
+
+    Extracts messages from:
+    - str(e)
+    - e.message (if present, e.g., RAIException)
+    - e.args (string arguments)
+    - e.__cause__
+    - e.__context__
+    - Nested JavaScript execution errors
+    """
+    messages = [str(e).lower()]
+
+    # Check message attribute (RAIException has this)
+    if hasattr(e, 'message'):
+        msg = getattr(e, 'message', None)
+        if isinstance(msg, str):
+            messages.append(msg.lower())
+
+    # Check args
+    if hasattr(e, 'args') and e.args:
+        for arg in e.args:
+            if isinstance(arg, str):
+                messages.append(arg.lower())
+
+    # Check cause and context
+    if hasattr(e, '__cause__') and e.__cause__:
+        messages.append(str(e.__cause__).lower())
+    if hasattr(e, '__context__') and e.__context__:
+        messages.append(str(e.__context__).lower())
+
+    # Extract nested messages from JavaScript execution errors
+    for msg in messages[:]:  # Copy to avoid modification during iteration
+        if re.search(r"javascript execution error", msg):
+            matches = re.findall(r'"message"\s*:\s*"([^"]+)"', msg, re.IGNORECASE)
+            messages.extend([m.lower() for m in matches])
+
+    return messages
+
+
+#--------------------------------------------------
+# Parameter and Data Transformation Utilities
+#--------------------------------------------------
+
+def normalize_params(params: List[Any] | Any | None) -> List[Any] | None:
+    """Normalize parameters to a list format."""
+    if params is not None and not isinstance(params, list):
+        return cast(List[Any], [params])
+    return params
+
+
+def format_sproc_name(name: str, type_obj: Any) -> str:
+    """Format stored procedure parameter name based on type."""
+    if type_obj is datetime:
+        return f"{name}.astimezone(ZoneInfo('UTC')).isoformat(timespec='milliseconds')"
+    return name
+
+
+def is_azure_url(url: str) -> bool:
+    """Check if a URL is an Azure blob storage URL."""
+    return "blob.core.windows.net" in url
+
+
+def is_container_runtime() -> bool:
+    """Check if running in a container runtime environment."""
+    return isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "container"
+
+
+#--------------------------------------------------
+# Import/Export Utilities
+#--------------------------------------------------
+
+def is_valid_import_state(state: str) -> bool:
+    """Check if an import state is valid."""
+    return state in VALID_IMPORT_STATES
+
+
+def imports_to_dicts(results: List[Any]) -> List[Dict[str, Any]]:
+    """Convert import results to dictionaries with lowercase keys."""
+    parsed_results = [
+        {field.lower(): row[field] for field in IMPORT_STREAM_FIELDS}
+        for row in results
+    ]
+    return parsed_results
+
+
+#--------------------------------------------------
+# Transaction Utilities
+#--------------------------------------------------
+
+def txn_list_to_dicts(transactions: List[Any]) -> List[Dict[str, Any]]:
+    """Convert transaction list to dictionaries with field mapping."""
+    dicts = []
+    for txn in transactions:
+        dict = {}
+        txn_dict = txn.asDict()
+        for key in txn_dict:
+            mapValue = FIELD_MAP.get(key.lower())
+            if mapValue:
+                dict[mapValue] = txn_dict[key]
+            else:
+                dict[key.lower()] = txn_dict[key]
+        dicts.append(dict)
+    return dicts
+
+
+#--------------------------------------------------
+# Encryption Utilities
+#--------------------------------------------------
+
+def decrypt_stream(key: bytes, iv: bytes, src: bytes) -> bytes:
+    """Decrypt the provided stream with PKCS#5 padding handling."""
+    if crypto_disabled:
+        if isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "warehouse":
+            raise Exception("Please open the navigation-bar dropdown labeled *Packages* and select `cryptography` under the *Anaconda Packages* section, and then re-run your query.")
+        else:
+            raise Exception("library `cryptography.hazmat` missing; please install")
+
+    # `type:ignore`s are because of the conditional import, which
+    # we have because warehouse-based snowflake notebooks don't support
+    # the crypto library we're using.
+    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) # type: ignore
+    decryptor = cipher.decryptor()
+
+    # Decrypt the data
+    decrypted_padded_data = decryptor.update(src) + decryptor.finalize()
+
+    # Unpad the decrypted data using PKCS#5
+    unpadder = padding.PKCS7(128).unpadder() # type: ignore # Use 128 directly for AES
+    unpadded_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
+
+    return unpadded_data
+
+
+def decrypt_artifact(data: bytes, encryption_material: str) -> bytes:
+    """Decrypts the artifact data using provided encryption material."""
+    encryption_material_parts = encryption_material.split("|")
+    assert len(encryption_material_parts) == 3, "Invalid encryption material"
+
+    algorithm, key_base64, iv_base64 = encryption_material_parts
+    assert algorithm == "AES_128_CBC", f"Unsupported encryption algorithm {algorithm}"
+
+    key = base64.standard_b64decode(key_base64)
+    iv = base64.standard_b64decode(iv_base64)
+
+    return decrypt_stream(key, iv, data)
+
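The new `util.py` gathers helpers apparently factored out of the old `clients/snowflake.py` (note its -1392 lines in the summary above). Below is a minimal usage sketch of three of them; it assumes the module is importable at the new file path shown in this diff, and whether these names are also re-exported from the package `__init__` is not visible here.

```python
from datetime import datetime

# Assumed import path, mirroring the new file location in this diff.
from v0.relationalai.clients.resources.snowflake.util import (
    process_jinja_template, type_to_sql, sanitize_user_name,
)

template = """
{% if use_schema %}
CREATE SCHEMA IF NOT EXISTS {{ schema }};
{% endif %}
{% for col in columns %}
-- column: {{ col }}
{% endfor %}
"""

# Note: condition evaluation is textual substitution, so pass the condition
# variable (use_schema) before variables whose names are substrings of it (schema).
sql = process_jinja_template(template, use_schema=True, schema="RAI", columns=["id", "name"])
print(sql)                                       # CREATE SCHEMA line plus one comment per column

print(type_to_sql(datetime))                     # TIMESTAMP
print(sanitize_user_name("a.user@example.com"))  # a_user
```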
--- v0/relationalai/early_access/dsl/ir/executor.py
+++ v0/relationalai/early_access/dsl/ir/executor.py
@@ -6,10 +6,9 @@ from collections import defaultdict
 from typing import Any, List, Optional
 
 from pandas import DataFrame
-import v0.relationalai as rai
 from v0.relationalai import debugging
 from v0.relationalai.clients import result_helpers
-from v0.relationalai.clients.snowflake import APP_NAME
+from v0.relationalai.clients.resources.snowflake import APP_NAME
 from v0.relationalai.early_access.dsl.ir.compiler import Compiler
 from v0.relationalai.early_access.dsl.ontologies.models import Model
 from v0.relationalai.semantics.metamodel import ir
@@ -37,7 +36,8 @@ class RelExecutor:
         if not self._resources:
             with debugging.span("create_session"):
                 self.dry_run |= bool(self.config.get("compiler.dry_run", False))
-
+                from v0.relationalai.clients.resources.snowflake import Resources
+                self._resources = Resources(
                     dry_run=self.dry_run,
                     config=self.config,
                     generation=Generation.QB,
@@ -257,4 +257,4 @@ class RelExecutor:
         if raw:
             dataframe, errors = result_helpers.format_results(raw, None, result_cols)
             self.report_errors(errors)
-        return DataFrame()
+        return DataFrame()
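Nearly every remaining hunk applies the same mechanical migration: the `relationalai.clients.snowflake` module moved under `relationalai.clients.resources.snowflake`. A before/after sketch for downstream code, using only names this diff itself imports:

```python
# 1.0.0a2
# from v0.relationalai.clients.snowflake import APP_NAME, Resources

# 1.0.0a3
from v0.relationalai.clients.resources.snowflake import APP_NAME, Resources
```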
--- v0/relationalai/early_access/dsl/snow/api.py
+++ v0/relationalai/early_access/dsl/snow/api.py
@@ -2,13 +2,14 @@ from typing import cast, Optional
 
 import v0.relationalai as rai
 from v0.relationalai import Config
+from v0.relationalai.clients.resources.snowflake import Provider
 from v0.relationalai.early_access.dsl.snow.common import TabularMetadata, ColumnMetadata, SchemaMetadata, \
     ForeignKey, ColumnRef
 
 
 class Executor:
     def __init__(self, config: Optional[Config] = None):
-        self._provider = cast(
+        self._provider = cast(Provider, rai.Provider(config=config))
         self._table_meta_cache = {}
         self._schema_fk_cache = {}
 
--- v0/relationalai/experimental/solvers.py
+++ v0/relationalai/experimental/solvers.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import time
-from typing import Any, List, Optional
+from typing import Any, List, Optional, cast
 from dataclasses import dataclass
 import textwrap
 from .. import dsl, std
@@ -14,14 +14,13 @@ import uuid
 import v0.relationalai
 import json
 from ..clients.util import poll_with_specified_overhead
-from ..clients.snowflake import Resources as SnowflakeResources
-from ..clients.snowflake import
+from ..clients.resources.snowflake import Resources as SnowflakeResources, APP_NAME
+from ..clients.resources.snowflake import DirectAccessResources
+from ..clients.direct_access_client import DirectAccessClient
 from ..util.timeout import calc_remaining_timeout_minutes
 
 rel_sv = rel._tagged(Builtins.SingleValued)
 
-APP_NAME = v0.relationalai.clients.snowflake.APP_NAME
-
 ENGINE_TYPE_SOLVER = "SOLVER"
 # TODO (dba) The ERP still uses `worker` instead of `engine`. Change
 # this once we fix this in the ERP.
@@ -741,10 +740,11 @@ class Provider:
     def __init__(self, resources=None):
         if not resources:
             resources = v0.relationalai.Resources()
-        if not isinstance(resources,
+        if not isinstance(resources, SnowflakeResources):
             raise Exception("Solvers are only supported on SPCS.")
 
-
+        # Type narrowing: resources is confirmed to be SnowflakeResources
+        self.resources: SnowflakeResources = cast(SnowflakeResources, resources)
         self.direct_access_client: Optional[DirectAccessClient] = None
 
         if isinstance(self.resources, DirectAccessResources):
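`Provider.__init__` now pairs the runtime `isinstance` guard with a typed attribute and an explicit `cast`. A self-contained sketch of the pattern (class names here are illustrative stand-ins, not the package's):

```python
from typing import cast

class Resources: ...
class SnowflakeResources(Resources): ...

class Provider:
    def __init__(self, resources: Resources):
        if not isinstance(resources, SnowflakeResources):
            raise Exception("Solvers are only supported on SPCS.")
        # The guard proves the type at runtime; the annotated attribute plus
        # cast() records the narrowing for static type checkers.
        self.resources: SnowflakeResources = cast(SnowflakeResources, resources)
```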
--- v0/relationalai/semantics/devtools/benchmark_lqp.py
+++ v0/relationalai/semantics/devtools/benchmark_lqp.py
@@ -5,12 +5,11 @@ import argparse
 import os
 import json
 
-from v0.relationalai.clients.snowflake import Resources as snowflake_api
+from v0.relationalai.clients.resources.snowflake import Resources as snowflake_api, APP_NAME
 from v0.relationalai.semantics.lqp.executor import LQPExecutor
 from v0.relationalai.semantics.internal import internal
-from v0.relationalai.clients.use_index_poller import UseIndexPoller as index_poller
+from v0.relationalai.clients.resources.snowflake.use_index_poller import UseIndexPoller as index_poller
 from snowflake.connector.cursor import DictCursor
-from v0.relationalai.clients import snowflake
 
 from enum import Enum
 
@@ -172,7 +171,7 @@ def _exec_snowflake_override(bench_ctx, old_func, marker):
     def new_func(self, code, params, raw=False):
         cur = self._session.connection.cursor(DictCursor)
         try:
-            cur.execute(code.replace(
+            cur.execute(code.replace(APP_NAME, self.get_app_name()), params)
             rows = cur.fetchall()
             qid = str(getattr(cur, "sfqid", None))
             assert qid is not None, "Snowflake query ID was not available"
@@ -398,7 +397,7 @@ def get_sf_query_info(bench_ctx):
     return result
 
 def _get_query_info(qids):
-    from v0.relationalai.clients.snowflake import Resources as snowflake_client
+    from v0.relationalai.clients.resources.snowflake import Resources as snowflake_client
     client = snowflake_client()
 
     qids_str = "','".join(qids)
--- v0/relationalai/semantics/devtools/extract_lqp.py
+++ v0/relationalai/semantics/devtools/extract_lqp.py
@@ -7,7 +7,7 @@ import os
 import json
 from contextlib import contextmanager
 
-from v0.relationalai.clients.snowflake import Resources as snowflake_api
+from v0.relationalai.clients.resources.snowflake import Resources as snowflake_api
 from v0.relationalai.semantics.internal import internal
 from typing import Dict, Optional
 
--- v0/relationalai/semantics/internal/snowflake.py
+++ v0/relationalai/semantics/internal/snowflake.py
@@ -61,7 +61,7 @@ def get_session():
         _session = get_active_session()
     except Exception:
         from v0.relationalai import Resources
-        from v0.relationalai.clients.snowflake import Resources as SnowflakeResources
+        from v0.relationalai.clients.resources.snowflake import Resources as SnowflakeResources
         # TODO: we need a better way to handle global config
 
         # using the resource constructor to differentiate between direct access and
--- v0/relationalai/semantics/lqp/executor.py
+++ v0/relationalai/semantics/lqp/executor.py
@@ -7,7 +7,6 @@ import re
 from pandas import DataFrame
 from typing import Any, Optional, Literal, TYPE_CHECKING
 from snowflake.snowpark import Session
-import v0.relationalai as rai
 
 from v0.relationalai import debugging
 from v0.relationalai.errors import NonDefaultLQPSemanticsVersionWarning
@@ -21,10 +20,10 @@ from lqp import print as lqp_print, ir as lqp_ir
 from lqp.parser import construct_configure
 from v0.relationalai.semantics.lqp.ir import convert_transaction, validate_lqp
 from v0.relationalai.clients.config import Config
-from v0.relationalai.clients.snowflake import APP_NAME
+from v0.relationalai.clients.resources.snowflake import APP_NAME, create_resources_instance
 from v0.relationalai.clients.types import TransactionAsyncResponse
 from v0.relationalai.clients.util import IdentityParser, escape_for_f_string
-from v0.relationalai.tools.constants import
+from v0.relationalai.tools.constants import QUERY_ATTRIBUTES_HEADER
 from v0.relationalai.tools.query_utils import prepare_metadata_for_headers
 
 if TYPE_CHECKING:
@@ -67,17 +66,11 @@ class LQPExecutor(e.Executor):
         if not self._resources:
             with debugging.span("create_session"):
                 self.dry_run |= bool(self.config.get("compiler.dry_run", False))
-                resource_class = rai.clients.snowflake.Resources
-                if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
-                    resource_class = rai.clients.snowflake.DirectAccessResources
-                if self.config.get("platform", "") == "local":
-                    resource_class = rai.clients.local.LocalResources
                 # NOTE: language="lqp" is not strictly required for LQP execution, but it
                 # will significantly improve performance.
-                self._resources =
-                    dry_run=self.dry_run,
+                self._resources = create_resources_instance(
                     config=self.config,
-
+                    dry_run=self.dry_run,
                     connection=self.connection,
                     language="lqp",
                 )
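Both `lqp/executor.py` and `rel/executor.py` (below) replace inline resource-class selection with `create_resources_instance`. The factory itself lives in the new `resources_factory.py` (+99 lines, not shown in this diff), so the following reconstruction is only a guess assembled from the selection logic the executors deleted; the config keys and defaults are assumptions:

```python
def create_resources_instance(config, dry_run=False, connection=None, language="rel", **kwargs):
    # Hypothetical sketch: mirrors the branching removed from the executors;
    # the real factory may differ.
    from v0.relationalai.clients.resources.snowflake import Resources, DirectAccessResources
    from v0.relationalai.clients.local import LocalResources

    resource_class = Resources
    if config.get("use_direct_access", False):  # the old code used a USE_DIRECT_ACCESS default
        resource_class = DirectAccessResources
    if config.get("platform", "") == "local":
        resource_class = LocalResources
    return resource_class(config=config, dry_run=dry_run,
                          connection=connection, language=language, **kwargs)
```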
--- v0/relationalai/semantics/metamodel/util.py
+++ v0/relationalai/semantics/metamodel/util.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
-from typing import Callable, Generator, Generic, IO, Iterable, Optional, Sequence, Tuple, TypeVar, cast, Hashable
+from typing import Callable, Generator, Generic, IO, Iterable, Optional, Sequence, Tuple, TypeVar, cast, Hashable, Union
+import types
 from dataclasses import dataclass, field
 
 #--------------------------------------------------
@@ -345,7 +346,7 @@ def rewrite_set(t: type[T], f: Callable[[T], T], items: FrozenOrderedSet[T]) ->
         return items
     return ordered_set(*new_items).frozen()
 
-def rewrite_list(t: type[T], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple[T, ...]:
+def rewrite_list(t: Union[type[T], types.UnionType], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple[T, ...]:
     """ Map a function over a list, returning a new list with the results. Avoid allocating a new list if the function is the identity. """
     new_items: Optional[list[T]] = None
     for i in range(len(items)):
@@ -359,15 +360,15 @@ def rewrite_list(t: type[T], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple
         return items
     return tuple(new_items)
 
-def flatten_iter(items: Iterable[object], t: type[T]) -> Generator[T, None, None]:
+def flatten_iter(items: Iterable[object], t: Union[type[T], types.UnionType]) -> Generator[T, None, None]:
     """Yield items from a nested iterable structure one at a time."""
     for item in items:
         if isinstance(item, (list, tuple, OrderedSet)):
             yield from flatten_iter(item, t)
         elif isinstance(item, t):
-            yield item
+            yield cast(T, item)
 
-def flatten_tuple(items: Iterable[object], t: type[T]) -> tuple[T, ...]:
+def flatten_tuple(items: Iterable[object], t: Union[type[T], types.UnionType]) -> tuple[T, ...]:
     """ Flatten the nested iterable structure into a tuple."""
     return tuple(flatten_iter(items, t))
 
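The widened signatures (`Union[type[T], types.UnionType]`) let callers pass PEP 604 unions of IR node classes to the rewrite/flatten helpers; `isinstance` accepts such unions since Python 3.10, and their runtime type is `types.UnionType`. A standalone illustration (not the package's code):

```python
import types

def flatten(items, t):
    # t may be a single class or a PEP 604 union such as `int | str`
    for item in items:
        if isinstance(item, (list, tuple)):
            yield from flatten(item, t)
        elif isinstance(item, t):
            yield item

print(isinstance(int | str, types.UnionType))           # True
print(tuple(flatten([1, ["a", 2.0, [2]]], int | str)))  # (1, 'a', 2)
```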
--- v0/relationalai/semantics/rel/executor.py
+++ v0/relationalai/semantics/rel/executor.py
@@ -9,16 +9,15 @@ import uuid
 from pandas import DataFrame
 from typing import Any, Optional, Literal, TYPE_CHECKING
 from snowflake.snowpark import Session
-import v0.relationalai as rai
 
 from v0.relationalai import debugging
 from v0.relationalai.clients import result_helpers
 from v0.relationalai.clients.util import IdentityParser, escape_for_f_string
-from v0.relationalai.clients.snowflake import APP_NAME
+from v0.relationalai.clients.resources.snowflake import APP_NAME, create_resources_instance
 from v0.relationalai.semantics.metamodel import ir, executor as e, factory as f
 from v0.relationalai.semantics.rel import Compiler
 from v0.relationalai.clients.config import Config
-from v0.relationalai.tools.constants import
+from v0.relationalai.tools.constants import Generation, QUERY_ATTRIBUTES_HEADER
 from v0.relationalai.tools.query_utils import prepare_metadata_for_headers
 
 if TYPE_CHECKING:
@@ -53,15 +52,11 @@ class RelExecutor(e.Executor):
         if not self._resources:
             with debugging.span("create_session"):
                 self.dry_run |= bool(self.config.get("compiler.dry_run", False))
-                resource_class = rai.clients.snowflake.Resources
-                if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
-                    resource_class = rai.clients.snowflake.DirectAccessResources
                 # NOTE: language="rel" is required for Rel execution. It is the default, but
                 # we set it explicitly here to be sure.
-                self._resources =
-                    dry_run=self.dry_run,
+                self._resources = create_resources_instance(
                     config=self.config,
-
+                    dry_run=self.dry_run,
                     connection=self.connection,
                     language="rel",
                 )
@@ -163,13 +158,20 @@ class RelExecutor(e.Executor):
             raise errors.RAIExceptionSet(all_errors)
 
     def _export(self, raw_code: str, dest: Table, actual_cols: list[str], declared_cols: list[str], update:bool, headers: dict[str, Any] | None = None):
+        # _export is Snowflake-specific and requires Snowflake Resources
+        # It calls Snowflake stored procedures (APP_NAME.api.exec_into_table, etc.)
+        # LocalResources doesn't support this functionality
+        from v0.relationalai.clients.local import LocalResources
+        if isinstance(self.resources, LocalResources):
+            raise NotImplementedError("Export functionality is not supported in local mode. Use Snowflake Resources instead.")
+
         _exec = self.resources._exec
         output_table = "out" + str(uuid.uuid4()).replace("-", "_")
         txn_id = None
         artifacts = None
         dest_database, dest_schema, dest_table, _ = IdentityParser(dest._fqn, require_all_parts=True).to_list()
         dest_fqn = dest._fqn
-        assert self.resources._session
+        assert self.resources._session # All Snowflake Resources have _session
         with debugging.span("transaction"):
             try:
                 with debugging.span("exec_format") as span:
@@ -258,7 +260,7 @@ class RelExecutor(e.Executor):
                 SELECT 1
                 FROM {dest_database}.INFORMATION_SCHEMA.TABLES
                 WHERE table_schema = '{dest_schema}'
-
+                AND table_name = '{dest_table}'
             )) THEN
                 EXECUTE IMMEDIATE 'TRUNCATE TABLE {dest_fqn}';
             END IF;
@@ -267,6 +269,7 @@ class RelExecutor(e.Executor):
             else:
                 raise e
             if txn_id:
+                # These methods are available on all Snowflake Resources
                 artifact_info = self.resources._list_exec_async_artifacts(txn_id, headers=headers)
                 with debugging.span("fetch"):
                     artifacts = self.resources._download_results(artifact_info, txn_id, "ABORTED")