relationalai 1.0.0a2__py3-none-any.whl → 1.0.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39):
  1. relationalai/shims/executor.py +4 -1
  2. relationalai/shims/mm2v0.py +15 -10
  3. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a3.dist-info}/METADATA +1 -1
  4. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a3.dist-info}/RECORD +39 -30
  5. v0/relationalai/__init__.py +69 -22
  6. v0/relationalai/clients/__init__.py +15 -2
  7. v0/relationalai/clients/client.py +4 -4
  8. v0/relationalai/clients/local.py +5 -5
  9. v0/relationalai/clients/resources/__init__.py +8 -0
  10. v0/relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
  11. v0/relationalai/clients/resources/snowflake/__init__.py +20 -0
  12. v0/relationalai/clients/resources/snowflake/cli_resources.py +87 -0
  13. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +711 -0
  14. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
  15. v0/relationalai/clients/resources/snowflake/error_handlers.py +199 -0
  16. v0/relationalai/clients/resources/snowflake/resources_factory.py +99 -0
  17. v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +606 -1392
  18. v0/relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +43 -12
  19. v0/relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
  20. v0/relationalai/clients/resources/snowflake/util.py +387 -0
  21. v0/relationalai/early_access/dsl/ir/executor.py +4 -4
  22. v0/relationalai/early_access/dsl/snow/api.py +2 -1
  23. v0/relationalai/experimental/solvers.py +7 -7
  24. v0/relationalai/semantics/devtools/benchmark_lqp.py +4 -5
  25. v0/relationalai/semantics/devtools/extract_lqp.py +1 -1
  26. v0/relationalai/semantics/internal/snowflake.py +1 -1
  27. v0/relationalai/semantics/lqp/executor.py +4 -11
  28. v0/relationalai/semantics/metamodel/util.py +6 -5
  29. v0/relationalai/semantics/rel/executor.py +14 -11
  30. v0/relationalai/semantics/sql/executor/snowflake.py +9 -5
  31. v0/relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
  32. v0/relationalai/tools/cli.py +26 -30
  33. v0/relationalai/tools/cli_helpers.py +10 -2
  34. v0/relationalai/util/otel_configuration.py +2 -1
  35. v0/relationalai/util/otel_handler.py +1 -1
  36. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a3.dist-info}/WHEEL +0 -0
  37. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a3.dist-info}/entry_points.txt +0 -0
  38. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a3.dist-info}/top_level.txt +0 -0
  39. /v0/relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
@@ -0,0 +1,387 @@
1
+ """
2
+ Utility functions for Snowflake resources.
3
+ """
4
+ from __future__ import annotations
5
+ import re
6
+ import decimal
7
+ import base64
8
+ from numbers import Number
9
+ from datetime import datetime, date
10
+ from typing import List, Any, Dict, cast
11
+
12
+ from .... import dsl
13
+ from ....environments import runtime_env, SnowbookEnvironment
14
+
15
# warehouse-based snowflake notebooks currently don't have hazmat
# crypto_disabled is flipped to True below when `cryptography.hazmat` cannot
# be imported; decrypt_stream() checks it before attempting any decryption.
crypto_disabled = False
try:
    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import padding
except (ModuleNotFoundError, ImportError):
    crypto_disabled = True

# Constants used by helper functions
# Lowercase substrings that identify engine-related failures in response
# messages (matched case-insensitively by is_engine_issue()).
ENGINE_ERRORS = ("engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted")
# Lowercase substrings indicating the engine exists but is not yet usable.
ENGINE_NOT_READY_MSGS = ("engine is in pending", "engine is provisioning")
# Lowercase substrings that identify database-related failures (matched by
# is_database_issue()).
DATABASE_ERRORS = ("database not found",)

# Constants for import/export and transaction processing
# Import-stream states accepted by is_valid_import_state().
VALID_IMPORT_STATES = ("PENDING", "PROCESSING", "QUARANTINED", "LOADED")
# Column names projected out of import-stream result rows by imports_to_dicts().
IMPORT_STREAM_FIELDS = (
    "ID", "CREATED_AT", "CREATED_BY", "STATUS", "REFERENCE_NAME", "REFERENCE_ALIAS",
    "FQ_OBJECT_NAME", "RAI_DATABASE", "RAI_RELATION", "DATA_SYNC_STATUS",
    "PENDING_BATCHES_COUNT", "NEXT_BATCH_STATUS", "NEXT_BATCH_UNLOADED_TIMESTAMP",
    "NEXT_BATCH_DETAILS", "LAST_BATCH_DETAILS", "LAST_BATCH_UNLOADED_TIMESTAMP", "CDC_STATUS"
)
# Key renames applied to transaction-row keys by txn_list_to_dicts()
# (lowercased source key -> canonical output key).
FIELD_MAP = {
    "database_name": "database",
    "engine_name": "engine",
}
41
+
42
+
43
def process_jinja_template(template: str, indent_spaces: int = 0, **substitutions: Any) -> str:
    """Process a Jinja-like template.

    Supports:
    - Variable substitution {{ var }}
    - Conditional blocks {% if condition %} ... {% endif %}
    - For loops {% for item in items %} ... {% endfor %}
    - Comments {# ... #}
    - Whitespace control with {%- and -%}

    Args:
        template: The template string
        indent_spaces: Number of spaces to indent the result
        **substitutions: Variable substitutions

    Returns:
        The rendered template as a single newline-joined string.
    """

    def evaluate_condition(condition: str, context: dict) -> bool:
        """Safely evaluate a condition string using the context."""
        # Replace variables with their values
        # NOTE(review): this is a plain substring replace — a variable name that
        # is a substring of another (e.g. `x` and `max`) would be rewritten
        # inside the longer name. Presumably template authors avoid such names;
        # confirm before relying on overlapping identifiers.
        for k, v in context.items():
            if isinstance(v, str):
                condition = condition.replace(k, f"'{v}'")
            else:
                condition = condition.replace(k, str(v))
        try:
            # SECURITY: eval() with empty builtins limits but does not eliminate
            # risk — only use with trusted templates and substitution values.
            return bool(eval(condition, {"__builtins__": {}}, {}))
        except Exception:
            # Any evaluation failure is treated as a falsy condition.
            return False

    def process_expression(expr: str, context: dict) -> str:
        """Process a {{ expression }} block."""
        expr = expr.strip()
        if expr in context:
            return str(context[expr])
        # Unknown variables render as the empty string.
        return ""

    def process_block(lines: List[str], context: dict, indent: int = 0) -> List[str]:
        """Process a block of template lines recursively."""
        result = []
        i = 0
        while i < len(lines):
            line = lines[i]

            # Handle comments
            line = re.sub(r'{#.*?#}', '', line)

            # Handle if blocks
            if_match = re.search(r'{%\s*if\s+(.+?)\s*%}', line)
            if if_match:
                condition = if_match.group(1)
                if_block = []
                else_block = []
                i += 1
                # Track nesting so inner {% if %}/{% endif %} pairs are
                # collected verbatim and handled by the recursive call.
                nesting = 1
                in_else_block = False
                while i < len(lines) and nesting > 0:
                    if re.search(r'{%\s*if\s+', lines[i]):
                        nesting += 1
                    elif re.search(r'{%\s*endif\s*%}', lines[i]):
                        nesting -= 1
                    elif nesting == 1 and re.search(r'{%\s*else\s*%}', lines[i]):
                        # Only a top-level {% else %} flips collection target;
                        # nested else lines belong to the collected sub-block.
                        in_else_block = True
                        i += 1
                        continue

                    if nesting > 0:
                        if in_else_block:
                            else_block.append(lines[i])
                        else:
                            if_block.append(lines[i])
                    i += 1
                if evaluate_condition(condition, context):
                    result.extend(process_block(if_block, context, indent))
                else:
                    result.extend(process_block(else_block, context, indent))
                continue

            # Handle for loops
            for_match = re.search(r'{%\s*for\s+(\w+)\s+in\s+(\w+)\s*%}', line)
            if for_match:
                var_name, iterable_name = for_match.groups()
                for_block = []
                i += 1
                nesting = 1
                while i < len(lines) and nesting > 0:
                    if re.search(r'{%\s*for\s+', lines[i]):
                        nesting += 1
                    elif re.search(r'{%\s*endfor\s*%}', lines[i]):
                        nesting -= 1
                    if nesting > 0:
                        for_block.append(lines[i])
                    i += 1
                # Loops over unknown or non-sequence values render nothing.
                if iterable_name in context and isinstance(context[iterable_name], (list, tuple)):
                    for item in context[iterable_name]:
                        # Shallow-copy the context so the loop variable does not
                        # leak into sibling blocks.
                        loop_context = dict(context)
                        loop_context[var_name] = item
                        result.extend(process_block(for_block, loop_context, indent))
                continue

            # Handle variable substitution
            line = re.sub(r'{{\s*(\w+)\s*}}', lambda m: process_expression(m.group(1), context), line)

            # Handle whitespace control
            line = re.sub(r'{%-', '{%', line)
            line = re.sub(r'-%}', '%}', line)

            # Add line with proper indentation, preserving blank lines
            if line.strip():
                result.append(" " * (indent_spaces + indent) + line)
            else:
                result.append("")

            i += 1

        return result

    # Split template into lines and process
    lines = template.split('\n')
    processed_lines = process_block(lines, substitutions)

    return '\n'.join(processed_lines)
164
+
165
+
166
def type_to_sql(type_obj: Any) -> str:
    """Return the Snowflake SQL type name for a Python type object.

    Args:
        type_obj: A Python built-in type, `numbers.Number`, `decimal.Decimal`,
            `datetime`/`date`, or a `dsl.Type` instance.

    Returns:
        The SQL type name as a string.

    Raises:
        ValueError: If the type has no known SQL equivalent.
    """
    sql_names = {
        str: "VARCHAR",
        int: "NUMBER",
        Number: "DECIMAL(38, 15)",
        float: "FLOAT",
        decimal.Decimal: "DECIMAL(38, 15)",
        bool: "BOOLEAN",
        dict: "VARIANT",
        list: "ARRAY",
        bytes: "BINARY",
        datetime: "TIMESTAMP",
        date: "DATE",
    }
    try:
        return sql_names[type_obj]
    except (KeyError, TypeError):
        # Not one of the directly mapped types (or unhashable); fall through.
        pass
    # DSL model types are stored as strings.
    if isinstance(type_obj, dsl.Type):
        return "VARCHAR"
    raise ValueError(f"Unknown type {type_obj}")
192
+
193
+
194
def type_to_snowpark(type_obj: Any) -> str:
    """Return the Snowpark type-constructor expression for a Python type object.

    Args:
        type_obj: A Python built-in type, `numbers.Number`, `decimal.Decimal`,
            `datetime`/`date`, or a `dsl.Type` instance.

    Returns:
        A string such as ``"StringType()"`` naming the Snowpark type call.

    Raises:
        ValueError: If the type has no known Snowpark equivalent.
    """
    snowpark_names = {
        str: "StringType()",
        int: "IntegerType()",
        float: "FloatType()",
        Number: "DecimalType(38, 15)",
        decimal.Decimal: "DecimalType(38, 15)",
        bool: "BooleanType()",
        dict: "MapType()",
        list: "ArrayType()",
        bytes: "BinaryType()",
        datetime: "TimestampType()",
        date: "DateType()",
    }
    try:
        return snowpark_names[type_obj]
    except (KeyError, TypeError):
        # Not one of the directly mapped types (or unhashable); fall through.
        pass
    # DSL model types are represented as strings.
    if isinstance(type_obj, dsl.Type):
        return "StringType()"
    raise ValueError(f"Unknown type {type_obj}")
220
+
221
+
222
def sanitize_user_name(user: str) -> str:
    """Return *user* with any email domain dropped and unsafe characters replaced.

    Everything from the first ``@`` onward is discarded, then every character
    that is not a letter, digit, or underscore becomes ``_``.
    """
    local_part, _, _ = user.partition('@')
    return re.sub(r'[^a-zA-Z0-9_]', '_', local_part)
229
+
230
+
231
def is_engine_issue(response_message: str) -> bool:
    """Return True when *response_message* mentions a known engine problem.

    Matches case-insensitively against ENGINE_ERRORS and ENGINE_NOT_READY_MSGS.
    """
    lowered = response_message.lower()
    engine_keywords = ENGINE_ERRORS + ENGINE_NOT_READY_MSGS
    return any(keyword in lowered for keyword in engine_keywords)
234
+
235
+
236
def is_database_issue(response_message: str) -> bool:
    """Return True when *response_message* mentions a known database problem.

    Matches case-insensitively against DATABASE_ERRORS.
    """
    lowered = response_message.lower()
    return any(keyword in lowered for keyword in DATABASE_ERRORS)
239
+
240
+
241
def collect_error_messages(e: Exception) -> list[str]:
    """Gather every lowercased message reachable from *e*.

    Messages are collected, in order, from:
    - str(e)
    - a ``message`` attribute when present (e.g. RAIException)
    - string entries of ``e.args``
    - ``e.__cause__`` and ``e.__context__``
    - ``"message": "..."`` payloads embedded in any collected JavaScript
      execution error text
    """
    collected = [str(e).lower()]

    # RAIException-style exceptions expose a `message` attribute.
    message_attr = getattr(e, 'message', None)
    if isinstance(message_attr, str):
        collected.append(message_attr.lower())

    # String positional arguments.
    for arg in getattr(e, 'args', ()) or ():
        if isinstance(arg, str):
            collected.append(arg.lower())

    # Explicit (`raise ... from`) and implicit exception chaining.
    cause = getattr(e, '__cause__', None)
    if cause:
        collected.append(str(cause).lower())
    context = getattr(e, '__context__', None)
    if context:
        collected.append(str(context).lower())

    # Pull nested payloads out of JavaScript execution errors. Iterate over a
    # snapshot so the appends below do not affect the scan.
    for candidate in list(collected):
        if re.search(r"javascript execution error", candidate):
            nested = re.findall(r'"message"\s*:\s*"([^"]+)"', candidate, re.IGNORECASE)
            collected.extend(m.lower() for m in nested)

    return collected
279
+
280
+
281
+ #--------------------------------------------------
282
+ # Parameter and Data Transformation Utilities
283
+ #--------------------------------------------------
284
+
285
def normalize_params(params: List[Any] | Any | None) -> List[Any] | None:
    """Normalize *params* to a list.

    None and lists pass through unchanged; any other single value is wrapped
    in a one-element list.
    """
    if params is None or isinstance(params, list):
        return params
    return cast(List[Any], [params])
290
+
291
+
292
def format_sproc_name(name: str, type_obj: Any) -> str:
    """Return the stored-procedure argument expression for *name*.

    datetime parameters are rewritten into a UTC ISO-8601 conversion
    expression; every other type passes through unchanged.
    """
    if type_obj is not datetime:
        return name
    return f"{name}.astimezone(ZoneInfo('UTC')).isoformat(timespec='milliseconds')"
297
+
298
+
299
def is_azure_url(url: str) -> bool:
    """Return True when *url* points at Azure blob storage."""
    return url.find("blob.core.windows.net") != -1
302
+
303
+
304
def is_container_runtime() -> bool:
    """Return True when running in a container-based Snowflake notebook runtime."""
    if not isinstance(runtime_env, SnowbookEnvironment):
        return False
    return runtime_env.runner == "container"
307
+
308
+
309
+ #--------------------------------------------------
310
+ # Import/Export Utilities
311
+ #--------------------------------------------------
312
+
313
def is_valid_import_state(state: str) -> bool:
    """Return True when *state* is one of the recognized import states."""
    return any(state == candidate for candidate in VALID_IMPORT_STATES)
316
+
317
+
318
def imports_to_dicts(results: List[Any]) -> List[Dict[str, Any]]:
    """Project each result row onto IMPORT_STREAM_FIELDS with lowercase keys.

    Args:
        results: rows supporting ``row[field]`` access for each field name.

    Returns:
        One dict per row, keyed by the lowercased field names.
    """
    converted: List[Dict[str, Any]] = []
    for row in results:
        converted.append({name.lower(): row[name] for name in IMPORT_STREAM_FIELDS})
    return converted
325
+
326
+
327
+ #--------------------------------------------------
328
+ # Transaction Utilities
329
+ #--------------------------------------------------
330
+
331
def txn_list_to_dicts(transactions: List[Any]) -> List[Dict[str, Any]]:
    """Convert transaction rows to plain dicts with lowercased, remapped keys.

    Keys are lowercased, and keys present in FIELD_MAP (e.g. ``database_name``)
    are renamed to their canonical form (e.g. ``database``).

    Args:
        transactions: rows exposing ``asDict()`` (e.g. snowflake.snowpark.Row).

    Returns:
        One dict per transaction.
    """
    dicts: List[Dict[str, Any]] = []
    for txn in transactions:
        txn_dict = txn.asDict()
        # Renamed from `dict` to avoid shadowing the builtin.
        row: Dict[str, Any] = {}
        for key, value in txn_dict.items():
            lowered = key.lower()
            # Rename e.g. "database_name" -> "database"; otherwise keep the
            # lowercased key as-is.
            row[FIELD_MAP.get(lowered, lowered)] = value
        dicts.append(row)
    return dicts
345
+
346
+
347
+ #--------------------------------------------------
348
+ # Encryption Utilities
349
+ #--------------------------------------------------
350
+
351
def decrypt_stream(key: bytes, iv: bytes, src: bytes) -> bytes:
    """Decrypt the provided stream with PKCS#5 padding handling.

    Args:
        key: AES key bytes.
        iv: CBC initialization vector.
        src: the encrypted payload.

    Returns:
        The decrypted, unpadded bytes.

    Raises:
        Exception: if the `cryptography` package is unavailable in this
            runtime (with install guidance specific to warehouse notebooks).
    """
    if crypto_disabled:
        # Warehouse notebooks can enable the package via the UI; elsewhere it
        # must be pip-installed.
        if isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "warehouse":
            raise Exception("Please open the navigation-bar dropdown labeled *Packages* and select `cryptography` under the *Anaconda Packages* section, and then re-run your query.")
        else:
            raise Exception("library `cryptography.hazmat` missing; please install")

    # `type:ignore`s are because of the conditional import, which
    # we have because warehouse-based snowflake notebooks don't support
    # the crypto library we're using.
    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) # type: ignore
    decryptor = cipher.decryptor()

    # Decrypt the data
    decrypted_padded_data = decryptor.update(src) + decryptor.finalize()

    # Unpad the decrypted data using PKCS#5
    unpadder = padding.PKCS7(128).unpadder() # type: ignore # Use 128 directly for AES
    unpadded_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()

    return unpadded_data
373
+
374
+
375
def decrypt_artifact(data: bytes, encryption_material: str) -> bytes:
    """Decrypt artifact *data* using pipe-delimited encryption material.

    Args:
        data: the encrypted artifact bytes.
        encryption_material: ``"<algorithm>|<key base64>|<iv base64>"``;
            only ``AES_128_CBC`` is supported.

    Returns:
        The decrypted artifact bytes.

    Raises:
        ValueError: if the material is malformed or names an unsupported
            algorithm.
    """
    encryption_material_parts = encryption_material.split("|")
    # Raise (not assert) so validation survives running under `python -O`.
    if len(encryption_material_parts) != 3:
        raise ValueError("Invalid encryption material")

    algorithm, key_base64, iv_base64 = encryption_material_parts
    if algorithm != "AES_128_CBC":
        raise ValueError(f"Unsupported encryption algorithm {algorithm}")

    key = base64.standard_b64decode(key_base64)
    iv = base64.standard_b64decode(iv_base64)

    return decrypt_stream(key, iv, data)
387
+
@@ -6,10 +6,9 @@ from collections import defaultdict
6
6
  from typing import Any, List, Optional
7
7
 
8
8
  from pandas import DataFrame
9
- import v0.relationalai as rai
10
9
  from v0.relationalai import debugging
11
10
  from v0.relationalai.clients import result_helpers
12
- from v0.relationalai.clients.snowflake import APP_NAME
11
+ from v0.relationalai.clients.resources.snowflake import APP_NAME
13
12
  from v0.relationalai.early_access.dsl.ir.compiler import Compiler
14
13
  from v0.relationalai.early_access.dsl.ontologies.models import Model
15
14
  from v0.relationalai.semantics.metamodel import ir
@@ -37,7 +36,8 @@ class RelExecutor:
37
36
  if not self._resources:
38
37
  with debugging.span("create_session"):
39
38
  self.dry_run |= bool(self.config.get("compiler.dry_run", False))
40
- self._resources = rai.clients.snowflake.Resources(
39
+ from v0.relationalai.clients.resources.snowflake import Resources
40
+ self._resources = Resources(
41
41
  dry_run=self.dry_run,
42
42
  config=self.config,
43
43
  generation=Generation.QB,
@@ -257,4 +257,4 @@ class RelExecutor:
257
257
  if raw:
258
258
  dataframe, errors = result_helpers.format_results(raw, None, result_cols)
259
259
  self.report_errors(errors)
260
- return DataFrame()
260
+ return DataFrame()
@@ -2,13 +2,14 @@ from typing import cast, Optional
2
2
 
3
3
  import v0.relationalai as rai
4
4
  from v0.relationalai import Config
5
+ from v0.relationalai.clients.resources.snowflake import Provider
5
6
  from v0.relationalai.early_access.dsl.snow.common import TabularMetadata, ColumnMetadata, SchemaMetadata, \
6
7
  ForeignKey, ColumnRef
7
8
 
8
9
 
9
10
  class Executor:
10
11
  def __init__(self, config: Optional[Config] = None):
11
- self._provider = cast(rai.clients.snowflake.Provider, rai.Provider(config=config))
12
+ self._provider = cast(Provider, rai.Provider(config=config))
12
13
  self._table_meta_cache = {}
13
14
  self._schema_fk_cache = {}
14
15
 
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
  import time
3
- from typing import Any, List, Optional
3
+ from typing import Any, List, Optional, cast
4
4
  from dataclasses import dataclass
5
5
  import textwrap
6
6
  from .. import dsl, std
@@ -14,14 +14,13 @@ import uuid
14
14
  import v0.relationalai
15
15
  import json
16
16
  from ..clients.util import poll_with_specified_overhead
17
- from ..clients.snowflake import Resources as SnowflakeResources
18
- from ..clients.snowflake import DirectAccessClient, DirectAccessResources
17
+ from ..clients.resources.snowflake import Resources as SnowflakeResources, APP_NAME
18
+ from ..clients.resources.snowflake import DirectAccessResources
19
+ from ..clients.direct_access_client import DirectAccessClient
19
20
  from ..util.timeout import calc_remaining_timeout_minutes
20
21
 
21
22
  rel_sv = rel._tagged(Builtins.SingleValued)
22
23
 
23
- APP_NAME = v0.relationalai.clients.snowflake.APP_NAME
24
-
25
24
  ENGINE_TYPE_SOLVER = "SOLVER"
26
25
  # TODO (dba) The ERP still uses `worker` instead of `engine`. Change
27
26
  # this once we fix this in the ERP.
@@ -741,10 +740,11 @@ class Provider:
741
740
  def __init__(self, resources=None):
742
741
  if not resources:
743
742
  resources = v0.relationalai.Resources()
744
- if not isinstance(resources, v0.relationalai.clients.snowflake.Resources):
743
+ if not isinstance(resources, SnowflakeResources):
745
744
  raise Exception("Solvers are only supported on SPCS.")
746
745
 
747
- self.resources = resources
746
+ # Type narrowing: resources is confirmed to be SnowflakeResources
747
+ self.resources: SnowflakeResources = cast(SnowflakeResources, resources)
748
748
  self.direct_access_client: Optional[DirectAccessClient] = None
749
749
 
750
750
  if isinstance(self.resources, DirectAccessResources):
@@ -5,12 +5,11 @@ import argparse
5
5
  import os
6
6
  import json
7
7
 
8
- from v0.relationalai.clients.snowflake import Resources as snowflake_api
8
+ from v0.relationalai.clients.resources.snowflake import Resources as snowflake_api, APP_NAME
9
9
  from v0.relationalai.semantics.lqp.executor import LQPExecutor
10
10
  from v0.relationalai.semantics.internal import internal
11
- from v0.relationalai.clients.use_index_poller import UseIndexPoller as index_poller
11
+ from v0.relationalai.clients.resources.snowflake.use_index_poller import UseIndexPoller as index_poller
12
12
  from snowflake.connector.cursor import DictCursor
13
- from v0.relationalai.clients import snowflake
14
13
 
15
14
  from enum import Enum
16
15
 
@@ -172,7 +171,7 @@ def _exec_snowflake_override(bench_ctx, old_func, marker):
172
171
  def new_func(self, code, params, raw=False):
173
172
  cur = self._session.connection.cursor(DictCursor)
174
173
  try:
175
- cur.execute(code.replace(snowflake.APP_NAME, self.get_app_name()), params)
174
+ cur.execute(code.replace(APP_NAME, self.get_app_name()), params)
176
175
  rows = cur.fetchall()
177
176
  qid = str(getattr(cur, "sfqid", None))
178
177
  assert qid is not None, "Snowflake query ID was not available"
@@ -398,7 +397,7 @@ def get_sf_query_info(bench_ctx):
398
397
  return result
399
398
 
400
399
  def _get_query_info(qids):
401
- from v0.relationalai.clients.snowflake import Resources as snowflake_client
400
+ from v0.relationalai.clients.resources.snowflake import Resources as snowflake_client
402
401
  client = snowflake_client()
403
402
 
404
403
  qids_str = "','".join(qids)
@@ -7,7 +7,7 @@ import os
7
7
  import json
8
8
  from contextlib import contextmanager
9
9
 
10
- from v0.relationalai.clients.snowflake import Resources as snowflake_api
10
+ from v0.relationalai.clients.resources.snowflake import Resources as snowflake_api
11
11
  from v0.relationalai.semantics.internal import internal
12
12
  from typing import Dict, Optional
13
13
 
@@ -61,7 +61,7 @@ def get_session():
61
61
  _session = get_active_session()
62
62
  except Exception:
63
63
  from v0.relationalai import Resources
64
- from v0.relationalai.clients.snowflake import Resources as SnowflakeResources
64
+ from v0.relationalai.clients.resources.snowflake import Resources as SnowflakeResources
65
65
  # TODO: we need a better way to handle global config
66
66
 
67
67
  # using the resource constructor to differentiate between direct access and
@@ -7,7 +7,6 @@ import re
7
7
  from pandas import DataFrame
8
8
  from typing import Any, Optional, Literal, TYPE_CHECKING
9
9
  from snowflake.snowpark import Session
10
- import v0.relationalai as rai
11
10
 
12
11
  from v0.relationalai import debugging
13
12
  from v0.relationalai.errors import NonDefaultLQPSemanticsVersionWarning
@@ -21,10 +20,10 @@ from lqp import print as lqp_print, ir as lqp_ir
21
20
  from lqp.parser import construct_configure
22
21
  from v0.relationalai.semantics.lqp.ir import convert_transaction, validate_lqp
23
22
  from v0.relationalai.clients.config import Config
24
- from v0.relationalai.clients.snowflake import APP_NAME
23
+ from v0.relationalai.clients.resources.snowflake import APP_NAME, create_resources_instance
25
24
  from v0.relationalai.clients.types import TransactionAsyncResponse
26
25
  from v0.relationalai.clients.util import IdentityParser, escape_for_f_string
27
- from v0.relationalai.tools.constants import USE_DIRECT_ACCESS, QUERY_ATTRIBUTES_HEADER
26
+ from v0.relationalai.tools.constants import QUERY_ATTRIBUTES_HEADER
28
27
  from v0.relationalai.tools.query_utils import prepare_metadata_for_headers
29
28
 
30
29
  if TYPE_CHECKING:
@@ -67,17 +66,11 @@ class LQPExecutor(e.Executor):
67
66
  if not self._resources:
68
67
  with debugging.span("create_session"):
69
68
  self.dry_run |= bool(self.config.get("compiler.dry_run", False))
70
- resource_class = rai.clients.snowflake.Resources
71
- if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
72
- resource_class = rai.clients.snowflake.DirectAccessResources
73
- if self.config.get("platform", "") == "local":
74
- resource_class = rai.clients.local.LocalResources
75
69
  # NOTE: language="lqp" is not strictly required for LQP execution, but it
76
70
  # will significantly improve performance.
77
- self._resources = resource_class(
78
- dry_run=self.dry_run,
71
+ self._resources = create_resources_instance(
79
72
  config=self.config,
80
- generation=rai.Generation.QB,
73
+ dry_run=self.dry_run,
81
74
  connection=self.connection,
82
75
  language="lqp",
83
76
  )
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
- from typing import Callable, Generator, Generic, IO, Iterable, Optional, Sequence, Tuple, TypeVar, cast, Hashable
2
+ from typing import Callable, Generator, Generic, IO, Iterable, Optional, Sequence, Tuple, TypeVar, cast, Hashable, Union
3
+ import types
3
4
  from dataclasses import dataclass, field
4
5
 
5
6
  #--------------------------------------------------
@@ -345,7 +346,7 @@ def rewrite_set(t: type[T], f: Callable[[T], T], items: FrozenOrderedSet[T]) ->
345
346
  return items
346
347
  return ordered_set(*new_items).frozen()
347
348
 
348
- def rewrite_list(t: type[T], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple[T, ...]:
349
+ def rewrite_list(t: Union[type[T], types.UnionType], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple[T, ...]:
349
350
  """ Map a function over a list, returning a new list with the results. Avoid allocating a new list if the function is the identity. """
350
351
  new_items: Optional[list[T]] = None
351
352
  for i in range(len(items)):
@@ -359,15 +360,15 @@ def rewrite_list(t: type[T], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple
359
360
  return items
360
361
  return tuple(new_items)
361
362
 
362
- def flatten_iter(items: Iterable[object], t: type[T]) -> Generator[T, None, None]:
363
+ def flatten_iter(items: Iterable[object], t: Union[type[T], types.UnionType]) -> Generator[T, None, None]:
363
364
  """Yield items from a nested iterable structure one at a time."""
364
365
  for item in items:
365
366
  if isinstance(item, (list, tuple, OrderedSet)):
366
367
  yield from flatten_iter(item, t)
367
368
  elif isinstance(item, t):
368
- yield item
369
+ yield cast(T, item)
369
370
 
370
- def flatten_tuple(items: Iterable[object], t: type[T]) -> tuple[T, ...]:
371
+ def flatten_tuple(items: Iterable[object], t: Union[type[T], types.UnionType]) -> tuple[T, ...]:
371
372
  """ Flatten the nested iterable structure into a tuple."""
372
373
  return tuple(flatten_iter(items, t))
373
374
 
@@ -9,16 +9,15 @@ import uuid
9
9
  from pandas import DataFrame
10
10
  from typing import Any, Optional, Literal, TYPE_CHECKING
11
11
  from snowflake.snowpark import Session
12
- import v0.relationalai as rai
13
12
 
14
13
  from v0.relationalai import debugging
15
14
  from v0.relationalai.clients import result_helpers
16
15
  from v0.relationalai.clients.util import IdentityParser, escape_for_f_string
17
- from v0.relationalai.clients.snowflake import APP_NAME
16
+ from v0.relationalai.clients.resources.snowflake import APP_NAME, create_resources_instance
18
17
  from v0.relationalai.semantics.metamodel import ir, executor as e, factory as f
19
18
  from v0.relationalai.semantics.rel import Compiler
20
19
  from v0.relationalai.clients.config import Config
21
- from v0.relationalai.tools.constants import USE_DIRECT_ACCESS, Generation, QUERY_ATTRIBUTES_HEADER
20
+ from v0.relationalai.tools.constants import Generation, QUERY_ATTRIBUTES_HEADER
22
21
  from v0.relationalai.tools.query_utils import prepare_metadata_for_headers
23
22
 
24
23
  if TYPE_CHECKING:
@@ -53,15 +52,11 @@ class RelExecutor(e.Executor):
53
52
  if not self._resources:
54
53
  with debugging.span("create_session"):
55
54
  self.dry_run |= bool(self.config.get("compiler.dry_run", False))
56
- resource_class = rai.clients.snowflake.Resources
57
- if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
58
- resource_class = rai.clients.snowflake.DirectAccessResources
59
55
  # NOTE: language="rel" is required for Rel execution. It is the default, but
60
56
  # we set it explicitly here to be sure.
61
- self._resources = resource_class(
62
- dry_run=self.dry_run,
57
+ self._resources = create_resources_instance(
63
58
  config=self.config,
64
- generation=rai.Generation.QB,
59
+ dry_run=self.dry_run,
65
60
  connection=self.connection,
66
61
  language="rel",
67
62
  )
@@ -163,13 +158,20 @@ class RelExecutor(e.Executor):
163
158
  raise errors.RAIExceptionSet(all_errors)
164
159
 
165
160
  def _export(self, raw_code: str, dest: Table, actual_cols: list[str], declared_cols: list[str], update:bool, headers: dict[str, Any] | None = None):
161
+ # _export is Snowflake-specific and requires Snowflake Resources
162
+ # It calls Snowflake stored procedures (APP_NAME.api.exec_into_table, etc.)
163
+ # LocalResources doesn't support this functionality
164
+ from v0.relationalai.clients.local import LocalResources
165
+ if isinstance(self.resources, LocalResources):
166
+ raise NotImplementedError("Export functionality is not supported in local mode. Use Snowflake Resources instead.")
167
+
166
168
  _exec = self.resources._exec
167
169
  output_table = "out" + str(uuid.uuid4()).replace("-", "_")
168
170
  txn_id = None
169
171
  artifacts = None
170
172
  dest_database, dest_schema, dest_table, _ = IdentityParser(dest._fqn, require_all_parts=True).to_list()
171
173
  dest_fqn = dest._fqn
172
- assert self.resources._session
174
+ assert self.resources._session # All Snowflake Resources have _session
173
175
  with debugging.span("transaction"):
174
176
  try:
175
177
  with debugging.span("exec_format") as span:
@@ -258,7 +260,7 @@ class RelExecutor(e.Executor):
258
260
  SELECT 1
259
261
  FROM {dest_database}.INFORMATION_SCHEMA.TABLES
260
262
  WHERE table_schema = '{dest_schema}'
261
- AND table_name = '{dest_table}'
263
+ AND table_name = '{dest_table}'
262
264
  )) THEN
263
265
  EXECUTE IMMEDIATE 'TRUNCATE TABLE {dest_fqn}';
264
266
  END IF;
@@ -267,6 +269,7 @@ class RelExecutor(e.Executor):
267
269
  else:
268
270
  raise e
269
271
  if txn_id:
272
+ # These methods are available on all Snowflake Resources
270
273
  artifact_info = self.resources._list_exec_async_artifacts(txn_id, headers=headers)
271
274
  with debugging.span("fetch"):
272
275
  artifacts = self.resources._download_results(artifact_info, txn_id, "ABORTED")