relationalai 1.0.0a2__py3-none-any.whl → 1.0.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. relationalai/config/shims.py +1 -0
  2. relationalai/semantics/__init__.py +7 -1
  3. relationalai/semantics/frontend/base.py +19 -13
  4. relationalai/semantics/frontend/core.py +30 -2
  5. relationalai/semantics/frontend/front_compiler.py +38 -11
  6. relationalai/semantics/frontend/pprint.py +1 -1
  7. relationalai/semantics/metamodel/rewriter.py +6 -2
  8. relationalai/semantics/metamodel/typer.py +70 -26
  9. relationalai/semantics/reasoners/__init__.py +11 -0
  10. relationalai/semantics/reasoners/graph/__init__.py +38 -0
  11. relationalai/semantics/reasoners/graph/core.py +9015 -0
  12. relationalai/shims/executor.py +4 -1
  13. relationalai/shims/hoister.py +9 -0
  14. relationalai/shims/mm2v0.py +47 -34
  15. relationalai/tools/cli/cli.py +138 -0
  16. relationalai/tools/cli/docs.py +394 -0
  17. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/METADATA +5 -3
  18. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/RECORD +57 -43
  19. v0/relationalai/__init__.py +69 -22
  20. v0/relationalai/clients/__init__.py +15 -2
  21. v0/relationalai/clients/client.py +4 -4
  22. v0/relationalai/clients/exec_txn_poller.py +91 -0
  23. v0/relationalai/clients/local.py +5 -5
  24. v0/relationalai/clients/resources/__init__.py +8 -0
  25. v0/relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
  26. v0/relationalai/clients/resources/snowflake/__init__.py +20 -0
  27. v0/relationalai/clients/resources/snowflake/cli_resources.py +87 -0
  28. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +717 -0
  29. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
  30. v0/relationalai/clients/resources/snowflake/error_handlers.py +199 -0
  31. v0/relationalai/clients/resources/snowflake/resources_factory.py +99 -0
  32. v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +642 -1399
  33. v0/relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +51 -12
  34. v0/relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
  35. v0/relationalai/clients/resources/snowflake/util.py +387 -0
  36. v0/relationalai/early_access/dsl/ir/executor.py +4 -4
  37. v0/relationalai/early_access/dsl/snow/api.py +2 -1
  38. v0/relationalai/errors.py +18 -0
  39. v0/relationalai/experimental/solvers.py +7 -7
  40. v0/relationalai/semantics/devtools/benchmark_lqp.py +4 -5
  41. v0/relationalai/semantics/devtools/extract_lqp.py +1 -1
  42. v0/relationalai/semantics/internal/snowflake.py +1 -1
  43. v0/relationalai/semantics/lqp/executor.py +7 -12
  44. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
  45. v0/relationalai/semantics/metamodel/util.py +6 -5
  46. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +335 -84
  47. v0/relationalai/semantics/rel/executor.py +14 -11
  48. v0/relationalai/semantics/sql/executor/snowflake.py +9 -5
  49. v0/relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
  50. v0/relationalai/tools/cli.py +26 -30
  51. v0/relationalai/tools/cli_helpers.py +10 -2
  52. v0/relationalai/util/otel_configuration.py +2 -1
  53. v0/relationalai/util/otel_handler.py +1 -1
  54. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/WHEEL +0 -0
  55. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/entry_points.txt +0 -0
  56. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/top_level.txt +0 -0
  57. /v0/relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
@@ -0,0 +1,387 @@
1
+ """
2
+ Utility functions for Snowflake resources.
3
+ """
4
+ from __future__ import annotations
5
+ import re
6
+ import decimal
7
+ import base64
8
+ from numbers import Number
9
+ from datetime import datetime, date
10
+ from typing import List, Any, Dict, cast
11
+
12
+ from .... import dsl
13
+ from ....environments import runtime_env, SnowbookEnvironment
14
+
15
+ # warehouse-based snowflake notebooks currently don't have hazmat
16
+ crypto_disabled = False
17
+ try:
18
+ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
19
+ from cryptography.hazmat.backends import default_backend
20
+ from cryptography.hazmat.primitives import padding
21
+ except (ModuleNotFoundError, ImportError):
22
+ crypto_disabled = True
23
+
24
# Constants used by helper functions
# Substrings (matched case-insensitively by is_engine_issue) that indicate an
# engine problem in a response message.
ENGINE_ERRORS = ("engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted")
# Substrings indicating the engine exists but is not yet ready to serve requests.
ENGINE_NOT_READY_MSGS = ("engine is in pending", "engine is provisioning")
# Substrings (matched by is_database_issue) that indicate a missing database.
DATABASE_ERRORS = ("database not found",)

# Constants for import/export and transaction processing
# States accepted by is_valid_import_state.
VALID_IMPORT_STATES = ("PENDING", "PROCESSING", "QUARANTINED", "LOADED")
# Column names read from import-stream result rows; lowercased by imports_to_dicts.
IMPORT_STREAM_FIELDS = (
    "ID", "CREATED_AT", "CREATED_BY", "STATUS", "REFERENCE_NAME", "REFERENCE_ALIAS",
    "FQ_OBJECT_NAME", "RAI_DATABASE", "RAI_RELATION", "DATA_SYNC_STATUS",
    "PENDING_BATCHES_COUNT", "NEXT_BATCH_STATUS", "NEXT_BATCH_UNLOADED_TIMESTAMP",
    "NEXT_BATCH_DETAILS", "LAST_BATCH_DETAILS", "LAST_BATCH_UNLOADED_TIMESTAMP", "CDC_STATUS"
)
# Maps lowercased transaction field names to the key names emitted by
# txn_list_to_dicts in its output dictionaries.
FIELD_MAP = {
    "database_name": "database",
    "engine_name": "engine",
}
41
+
42
+
43
def process_jinja_template(template: str, indent_spaces: int = 0, **substitutions: Any) -> str:
    """Process a Jinja-like template.

    Supports:
    - Variable substitution {{ var }}
    - Conditional blocks {% if condition %} ... {% endif %}
    - For loops {% for item in items %} ... {% endfor %}
    - Comments {# ... #}
    - Whitespace control with {%- and -%}

    Block tags ({% if %}, {% for %}, and their closers) are matched line by
    line, so each block tag is expected to sit on its own line.

    Args:
        template: The template string
        indent_spaces: Number of spaces to indent the result
        **substitutions: Variable substitutions

    Returns:
        The rendered template as a single string.
    """

    def evaluate_condition(condition: str, context: dict) -> bool:
        """Safely evaluate a condition string using the context."""
        # Replace variables with their values.
        # NOTE(review): plain substring replacement — a context key that is a
        # substring of another identifier in the condition would also be
        # replaced; assumes keys do not overlap that way — confirm.
        for k, v in context.items():
            if isinstance(v, str):
                condition = condition.replace(k, f"'{v}'")
            else:
                condition = condition.replace(k, str(v))
        try:
            # eval() with empty __builtins__ restricts (but does not fully
            # sandbox) what the expression can do; templates are assumed to be
            # trusted, package-internal strings — never feed user input here.
            return bool(eval(condition, {"__builtins__": {}}, {}))
        except Exception:
            # Any malformed or unevaluable condition renders as false.
            return False

    def process_expression(expr: str, context: dict) -> str:
        """Process a {{ expression }} block."""
        expr = expr.strip()
        if expr in context:
            return str(context[expr])
        # Unknown variables render as the empty string.
        return ""

    def process_block(lines: List[str], context: dict, indent: int = 0) -> List[str]:
        """Process a block of template lines recursively."""
        result = []
        i = 0
        while i < len(lines):
            line = lines[i]

            # Handle comments
            line = re.sub(r'{#.*?#}', '', line)

            # Handle if blocks
            if_match = re.search(r'{%\s*if\s+(.+?)\s*%}', line)
            if if_match:
                condition = if_match.group(1)
                if_block = []
                else_block = []
                i += 1
                # Scan forward to the matching endif, tracking nesting depth so
                # nested if-blocks are collected verbatim and handled when we
                # recurse into process_block below.
                nesting = 1
                in_else_block = False
                while i < len(lines) and nesting > 0:
                    if re.search(r'{%\s*if\s+', lines[i]):
                        nesting += 1
                    elif re.search(r'{%\s*endif\s*%}', lines[i]):
                        nesting -= 1
                    elif nesting == 1 and re.search(r'{%\s*else\s*%}', lines[i]):
                        # Only an else at this block's own level splits it.
                        in_else_block = True
                        i += 1
                        continue

                    if nesting > 0:
                        if in_else_block:
                            else_block.append(lines[i])
                        else:
                            if_block.append(lines[i])
                    i += 1
                if evaluate_condition(condition, context):
                    result.extend(process_block(if_block, context, indent))
                else:
                    result.extend(process_block(else_block, context, indent))
                continue

            # Handle for loops
            for_match = re.search(r'{%\s*for\s+(\w+)\s+in\s+(\w+)\s*%}', line)
            if for_match:
                var_name, iterable_name = for_match.groups()
                for_block = []
                i += 1
                nesting = 1
                while i < len(lines) and nesting > 0:
                    if re.search(r'{%\s*for\s+', lines[i]):
                        nesting += 1
                    elif re.search(r'{%\s*endfor\s*%}', lines[i]):
                        nesting -= 1
                    if nesting > 0:
                        for_block.append(lines[i])
                    i += 1
                # A missing iterable, or one that is not a list/tuple, renders
                # as nothing (the loop body is silently skipped).
                if iterable_name in context and isinstance(context[iterable_name], (list, tuple)):
                    for item in context[iterable_name]:
                        # Shadow the loop variable in a copied context so the
                        # outer scope is untouched after the loop.
                        loop_context = dict(context)
                        loop_context[var_name] = item
                        result.extend(process_block(for_block, loop_context, indent))
                continue

            # Handle variable substitution
            line = re.sub(r'{{\s*(\w+)\s*}}', lambda m: process_expression(m.group(1), context), line)

            # Handle whitespace control
            # NOTE(review): this only normalizes the markers ({%- -> {%); it
            # does not actually trim surrounding whitespace — confirm intent.
            line = re.sub(r'{%-', '{%', line)
            line = re.sub(r'-%}', '%}', line)

            # Add line with proper indentation, preserving blank lines
            if line.strip():
                result.append(" " * (indent_spaces + indent) + line)
            else:
                result.append("")

            i += 1

        return result

    # Split template into lines and process
    lines = template.split('\n')
    processed_lines = process_block(lines, substitutions)

    return '\n'.join(processed_lines)
164
+
165
+
166
def type_to_sql(type_obj: Any) -> str:
    """Map a Python type object to the equivalent Snowflake SQL type name.

    Raises:
        ValueError: If the type has no known SQL equivalent.
    """
    sql_names = {
        str: "VARCHAR",
        int: "NUMBER",
        Number: "DECIMAL(38, 15)",
        float: "FLOAT",
        decimal.Decimal: "DECIMAL(38, 15)",
        bool: "BOOLEAN",
        dict: "VARIANT",
        list: "ARRAY",
        bytes: "BINARY",
        datetime: "TIMESTAMP",
        date: "DATE",
    }
    try:
        return sql_names[type_obj]
    except (KeyError, TypeError):
        # Not one of the directly-mapped types (or unhashable): fall through.
        pass
    if isinstance(type_obj, dsl.Type):
        return "VARCHAR"
    raise ValueError(f"Unknown type {type_obj}")
192
+
193
+
194
def type_to_snowpark(type_obj: Any) -> str:
    """Map a Python type object to the Snowpark type constructor expression.

    Raises:
        ValueError: If the type has no known Snowpark equivalent.
    """
    snowpark_names = {
        str: "StringType()",
        int: "IntegerType()",
        float: "FloatType()",
        Number: "DecimalType(38, 15)",
        decimal.Decimal: "DecimalType(38, 15)",
        bool: "BooleanType()",
        dict: "MapType()",
        list: "ArrayType()",
        bytes: "BinaryType()",
        datetime: "TimestampType()",
        date: "DateType()",
    }
    try:
        return snowpark_names[type_obj]
    except (KeyError, TypeError):
        # Not one of the directly-mapped types (or unhashable): fall through.
        pass
    if isinstance(type_obj, dsl.Type):
        return "StringType()"
    raise ValueError(f"Unknown type {type_obj}")
220
+
221
+
222
def sanitize_user_name(user: str) -> str:
    """Sanitize a user name: keep the part before '@', replace invalid chars with '_'."""
    # Everything after the first '@' (e.g. an email domain) is discarded.
    local_part, _, _ = user.partition('@')
    # Any character outside [a-zA-Z0-9_] becomes an underscore.
    return re.sub(r'[^a-zA-Z0-9_]', '_', local_part)
229
+
230
+
231
def is_engine_issue(response_message: str) -> bool:
    """Check if a response message indicates an engine issue (error or not-ready)."""
    lowered = response_message.lower()
    return any(marker in lowered for marker in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS)
234
+
235
+
236
def is_database_issue(response_message: str) -> bool:
    """Check if a response message indicates a database issue."""
    lowered = response_message.lower()
    return any(marker in lowered for marker in DATABASE_ERRORS)
239
+
240
+
241
def collect_error_messages(e: Exception) -> list[str]:
    """Collect all lowercased error messages from an exception and its chain.

    Extracts messages from:
    - str(e)
    - e.message (if present, e.g., RAIException)
    - e.args (string arguments)
    - e.__cause__
    - e.__context__
    - Nested JavaScript execution errors
    """
    messages = [str(e).lower()]

    # A `message` attribute (RAIException has this) is included when present.
    msg_attr = getattr(e, 'message', None)
    if isinstance(msg_attr, str):
        messages.append(msg_attr.lower())

    # String args are included individually.
    for arg in getattr(e, 'args', ()) or ():
        if isinstance(arg, str):
            messages.append(arg.lower())

    # Chained exceptions: explicit cause first, then implicit context.
    for linked in (e.__cause__, e.__context__):
        if linked:
            messages.append(str(linked).lower())

    # Pull nested "message" fields out of JavaScript execution errors.
    # Iterate a snapshot so extending `messages` below is safe.
    for msg in list(messages):
        if re.search(r"javascript execution error", msg):
            found = re.findall(r'"message"\s*:\s*"([^"]+)"', msg, re.IGNORECASE)
            messages.extend(m.lower() for m in found)

    return messages
279
+
280
+
281
+ #--------------------------------------------------
282
+ # Parameter and Data Transformation Utilities
283
+ #--------------------------------------------------
284
+
285
def normalize_params(params: List[Any] | Any | None) -> List[Any] | None:
    """Normalize parameters to a list format.

    None and lists pass through unchanged; any other value is wrapped
    in a single-element list.
    """
    if params is None or isinstance(params, list):
        return params
    return cast(List[Any], [params])
290
+
291
+
292
def format_sproc_name(name: str, type_obj: Any) -> str:
    """Format a stored-procedure parameter name based on its type.

    datetime parameters get a UTC ISO-format conversion expression appended;
    all other types pass the name through unchanged.
    """
    if type_obj is not datetime:
        return name
    return f"{name}.astimezone(ZoneInfo('UTC')).isoformat(timespec='milliseconds')"
297
+
298
+
299
def is_azure_url(url: str) -> bool:
    """Check if a URL is an Azure blob storage URL."""
    return url.find("blob.core.windows.net") != -1
302
+
303
+
304
def is_container_runtime() -> bool:
    """Check if running in a container runtime environment."""
    env = runtime_env
    if not isinstance(env, SnowbookEnvironment):
        return False
    return env.runner == "container"
307
+
308
+
309
+ #--------------------------------------------------
310
+ # Import/Export Utilities
311
+ #--------------------------------------------------
312
+
313
def is_valid_import_state(state: str) -> bool:
    """Check if an import state is one of the recognized states."""
    return any(state == valid for valid in VALID_IMPORT_STATES)
316
+
317
+
318
def imports_to_dicts(results: List[Any]) -> List[Dict[str, Any]]:
    """Convert import results to dictionaries with lowercase keys.

    Each row is indexed by the fields in IMPORT_STREAM_FIELDS; the output
    dictionaries use the lowercased field names as keys.
    """
    converted: List[Dict[str, Any]] = []
    for row in results:
        converted.append({field.lower(): row[field] for field in IMPORT_STREAM_FIELDS})
    return converted
325
+
326
+
327
+ #--------------------------------------------------
328
+ # Transaction Utilities
329
+ #--------------------------------------------------
330
+
331
def txn_list_to_dicts(transactions: List[Any]) -> List[Dict[str, Any]]:
    """Convert transaction rows to dictionaries with field mapping.

    Each transaction is converted via its ``asDict()`` method; keys are
    lowercased, and keys found in FIELD_MAP (e.g. ``database_name``) are
    renamed to their mapped form (e.g. ``database``).

    Args:
        transactions: Rows exposing an ``asDict()`` method (Snowpark rows).

    Returns:
        One dictionary per transaction with normalized keys.
    """
    dicts: List[Dict[str, Any]] = []
    for txn in transactions:
        txn_dict = txn.asDict()
        # Do not shadow the builtin `dict`; build a fresh mapping per row.
        converted: Dict[str, Any] = {}
        for key, value in txn_dict.items():
            lowered = key.lower()
            # Rename known fields via FIELD_MAP; otherwise keep the lowercased key.
            converted[FIELD_MAP.get(lowered, lowered)] = value
        dicts.append(converted)
    return dicts
345
+
346
+
347
+ #--------------------------------------------------
348
+ # Encryption Utilities
349
+ #--------------------------------------------------
350
+
351
def decrypt_stream(key: bytes, iv: bytes, src: bytes) -> bytes:
    """Decrypt the provided stream with PKCS#5 padding handling.

    Args:
        key: AES key bytes.
        iv: CBC initialization vector.
        src: Ciphertext bytes to decrypt.

    Returns:
        The decrypted, unpadded plaintext bytes.

    Raises:
        Exception: If the `cryptography` package is unavailable in this
            environment (see the module-level `crypto_disabled` flag).
    """
    if crypto_disabled:
        # Warehouse-based Snowflake notebooks can add `cryptography` via the
        # Packages dropdown; other environments must install it themselves.
        if isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "warehouse":
            raise Exception("Please open the navigation-bar dropdown labeled *Packages* and select `cryptography` under the *Anaconda Packages* section, and then re-run your query.")
        else:
            raise Exception("library `cryptography.hazmat` missing; please install")

    # `type:ignore`s are because of the conditional import, which
    # we have because warehouse-based snowflake notebooks don't support
    # the crypto library we're using.
    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) # type: ignore
    decryptor = cipher.decryptor()

    # Decrypt the data
    decrypted_padded_data = decryptor.update(src) + decryptor.finalize()

    # Unpad the decrypted data using PKCS#5 (PKCS7 with AES's 128-bit block)
    unpadder = padding.PKCS7(128).unpadder() # type: ignore # Use 128 directly for AES
    unpadded_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()

    return unpadded_data
373
+
374
+
375
def decrypt_artifact(data: bytes, encryption_material: str) -> bytes:
    """Decrypts the artifact data using provided encryption material.

    The material is a pipe-separated triple: algorithm|key_b64|iv_b64.
    Only AES_128_CBC is supported.
    """
    parts = encryption_material.split("|")
    assert len(parts) == 3, "Invalid encryption material"

    algorithm, key_base64, iv_base64 = parts
    assert algorithm == "AES_128_CBC", f"Unsupported encryption algorithm {algorithm}"

    return decrypt_stream(
        base64.standard_b64decode(key_base64),
        base64.standard_b64decode(iv_base64),
        data,
    )
387
+
@@ -6,10 +6,9 @@ from collections import defaultdict
6
6
  from typing import Any, List, Optional
7
7
 
8
8
  from pandas import DataFrame
9
- import v0.relationalai as rai
10
9
  from v0.relationalai import debugging
11
10
  from v0.relationalai.clients import result_helpers
12
- from v0.relationalai.clients.snowflake import APP_NAME
11
+ from v0.relationalai.clients.resources.snowflake import APP_NAME
13
12
  from v0.relationalai.early_access.dsl.ir.compiler import Compiler
14
13
  from v0.relationalai.early_access.dsl.ontologies.models import Model
15
14
  from v0.relationalai.semantics.metamodel import ir
@@ -37,7 +36,8 @@ class RelExecutor:
37
36
  if not self._resources:
38
37
  with debugging.span("create_session"):
39
38
  self.dry_run |= bool(self.config.get("compiler.dry_run", False))
40
- self._resources = rai.clients.snowflake.Resources(
39
+ from v0.relationalai.clients.resources.snowflake import Resources
40
+ self._resources = Resources(
41
41
  dry_run=self.dry_run,
42
42
  config=self.config,
43
43
  generation=Generation.QB,
@@ -257,4 +257,4 @@ class RelExecutor:
257
257
  if raw:
258
258
  dataframe, errors = result_helpers.format_results(raw, None, result_cols)
259
259
  self.report_errors(errors)
260
- return DataFrame()
260
+ return DataFrame()
@@ -2,13 +2,14 @@ from typing import cast, Optional
2
2
 
3
3
  import v0.relationalai as rai
4
4
  from v0.relationalai import Config
5
+ from v0.relationalai.clients.resources.snowflake import Provider
5
6
  from v0.relationalai.early_access.dsl.snow.common import TabularMetadata, ColumnMetadata, SchemaMetadata, \
6
7
  ForeignKey, ColumnRef
7
8
 
8
9
 
9
10
  class Executor:
10
11
  def __init__(self, config: Optional[Config] = None):
11
- self._provider = cast(rai.clients.snowflake.Provider, rai.Provider(config=config))
12
+ self._provider = cast(Provider, rai.Provider(config=config))
12
13
  self._table_meta_cache = {}
13
14
  self._schema_fk_cache = {}
14
15
 
v0/relationalai/errors.py CHANGED
@@ -2436,6 +2436,24 @@ class QueryTimeoutExceededException(RAIException):
2436
2436
  Consider increasing the 'query_timeout_mins' parameter in your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} to allow more time for query execution.
2437
2437
  """)
2438
2438
 
2439
+ class GuardRailsException(RAIException):
2440
+ def __init__(self, progress: dict[str, Any]={}):
2441
+ self.name = "Guard Rails Violation"
2442
+ self.message = "Transaction aborted due to guard rails violation."
2443
+ self.progress = progress
2444
+ self.content = self.format_message()
2445
+ super().__init__(self.message, self.name, self.content)
2446
+
2447
+ def format_message(self):
2448
+ messages = [] if self.progress else [self.message]
2449
+ for task in self.progress.get("tasks", {}).values():
2450
+ for warning_type, warning_data in task.get("warnings", {}).items():
2451
+ messages.append(textwrap.dedent(f"""
2452
+ Relation Name: [yellow]{task["task_name"]}[/yellow]
2453
+ Warning: {warning_type}
2454
+ Message: {warning_data["message"]}
2455
+ """))
2456
+ return "\n".join(messages)
2439
2457
 
2440
2458
  #--------------------------------------------------
2441
2459
  # Azure Exceptions
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
  import time
3
- from typing import Any, List, Optional
3
+ from typing import Any, List, Optional, cast
4
4
  from dataclasses import dataclass
5
5
  import textwrap
6
6
  from .. import dsl, std
@@ -14,14 +14,13 @@ import uuid
14
14
  import v0.relationalai
15
15
  import json
16
16
  from ..clients.util import poll_with_specified_overhead
17
- from ..clients.snowflake import Resources as SnowflakeResources
18
- from ..clients.snowflake import DirectAccessClient, DirectAccessResources
17
+ from ..clients.resources.snowflake import Resources as SnowflakeResources, APP_NAME
18
+ from ..clients.resources.snowflake import DirectAccessResources
19
+ from ..clients.direct_access_client import DirectAccessClient
19
20
  from ..util.timeout import calc_remaining_timeout_minutes
20
21
 
21
22
  rel_sv = rel._tagged(Builtins.SingleValued)
22
23
 
23
- APP_NAME = v0.relationalai.clients.snowflake.APP_NAME
24
-
25
24
  ENGINE_TYPE_SOLVER = "SOLVER"
26
25
  # TODO (dba) The ERP still uses `worker` instead of `engine`. Change
27
26
  # this once we fix this in the ERP.
@@ -741,10 +740,11 @@ class Provider:
741
740
  def __init__(self, resources=None):
742
741
  if not resources:
743
742
  resources = v0.relationalai.Resources()
744
- if not isinstance(resources, v0.relationalai.clients.snowflake.Resources):
743
+ if not isinstance(resources, SnowflakeResources):
745
744
  raise Exception("Solvers are only supported on SPCS.")
746
745
 
747
- self.resources = resources
746
+ # Type narrowing: resources is confirmed to be SnowflakeResources
747
+ self.resources: SnowflakeResources = cast(SnowflakeResources, resources)
748
748
  self.direct_access_client: Optional[DirectAccessClient] = None
749
749
 
750
750
  if isinstance(self.resources, DirectAccessResources):
@@ -5,12 +5,11 @@ import argparse
5
5
  import os
6
6
  import json
7
7
 
8
- from v0.relationalai.clients.snowflake import Resources as snowflake_api
8
+ from v0.relationalai.clients.resources.snowflake import Resources as snowflake_api, APP_NAME
9
9
  from v0.relationalai.semantics.lqp.executor import LQPExecutor
10
10
  from v0.relationalai.semantics.internal import internal
11
- from v0.relationalai.clients.use_index_poller import UseIndexPoller as index_poller
11
+ from v0.relationalai.clients.resources.snowflake.use_index_poller import UseIndexPoller as index_poller
12
12
  from snowflake.connector.cursor import DictCursor
13
- from v0.relationalai.clients import snowflake
14
13
 
15
14
  from enum import Enum
16
15
 
@@ -172,7 +171,7 @@ def _exec_snowflake_override(bench_ctx, old_func, marker):
172
171
  def new_func(self, code, params, raw=False):
173
172
  cur = self._session.connection.cursor(DictCursor)
174
173
  try:
175
- cur.execute(code.replace(snowflake.APP_NAME, self.get_app_name()), params)
174
+ cur.execute(code.replace(APP_NAME, self.get_app_name()), params)
176
175
  rows = cur.fetchall()
177
176
  qid = str(getattr(cur, "sfqid", None))
178
177
  assert qid is not None, "Snowflake query ID was not available"
@@ -398,7 +397,7 @@ def get_sf_query_info(bench_ctx):
398
397
  return result
399
398
 
400
399
  def _get_query_info(qids):
401
- from v0.relationalai.clients.snowflake import Resources as snowflake_client
400
+ from v0.relationalai.clients.resources.snowflake import Resources as snowflake_client
402
401
  client = snowflake_client()
403
402
 
404
403
  qids_str = "','".join(qids)
@@ -7,7 +7,7 @@ import os
7
7
  import json
8
8
  from contextlib import contextmanager
9
9
 
10
- from v0.relationalai.clients.snowflake import Resources as snowflake_api
10
+ from v0.relationalai.clients.resources.snowflake import Resources as snowflake_api
11
11
  from v0.relationalai.semantics.internal import internal
12
12
  from typing import Dict, Optional
13
13
 
@@ -61,7 +61,7 @@ def get_session():
61
61
  _session = get_active_session()
62
62
  except Exception:
63
63
  from v0.relationalai import Resources
64
- from v0.relationalai.clients.snowflake import Resources as SnowflakeResources
64
+ from v0.relationalai.clients.resources.snowflake import Resources as SnowflakeResources
65
65
  # TODO: we need a better way to handle global config
66
66
 
67
67
  # using the resource constructor to differentiate between direct access and
@@ -7,7 +7,6 @@ import re
7
7
  from pandas import DataFrame
8
8
  from typing import Any, Optional, Literal, TYPE_CHECKING
9
9
  from snowflake.snowpark import Session
10
- import v0.relationalai as rai
11
10
 
12
11
  from v0.relationalai import debugging
13
12
  from v0.relationalai.errors import NonDefaultLQPSemanticsVersionWarning
@@ -21,10 +20,10 @@ from lqp import print as lqp_print, ir as lqp_ir
21
20
  from lqp.parser import construct_configure
22
21
  from v0.relationalai.semantics.lqp.ir import convert_transaction, validate_lqp
23
22
  from v0.relationalai.clients.config import Config
24
- from v0.relationalai.clients.snowflake import APP_NAME
23
+ from v0.relationalai.clients.resources.snowflake import APP_NAME, create_resources_instance
25
24
  from v0.relationalai.clients.types import TransactionAsyncResponse
26
25
  from v0.relationalai.clients.util import IdentityParser, escape_for_f_string
27
- from v0.relationalai.tools.constants import USE_DIRECT_ACCESS, QUERY_ATTRIBUTES_HEADER
26
+ from v0.relationalai.tools.constants import QUERY_ATTRIBUTES_HEADER
28
27
  from v0.relationalai.tools.query_utils import prepare_metadata_for_headers
29
28
 
30
29
  if TYPE_CHECKING:
@@ -32,7 +31,9 @@ if TYPE_CHECKING:
32
31
 
33
32
  # Whenever the logic engine introduces a breaking change in behaviour, we bump this version
34
33
  # once the client is ready to handle it.
35
- DEFAULT_LQP_SEMANTICS_VERSION = "0"
34
+ #
35
+ # [2026-01-09] bumping to 1 to opt-into hard validation errors from the engine
36
+ DEFAULT_LQP_SEMANTICS_VERSION = "1"
36
37
 
37
38
  class LQPExecutor(e.Executor):
38
39
  """Executes LQP using the RAI client."""
@@ -67,17 +68,11 @@ class LQPExecutor(e.Executor):
67
68
  if not self._resources:
68
69
  with debugging.span("create_session"):
69
70
  self.dry_run |= bool(self.config.get("compiler.dry_run", False))
70
- resource_class = rai.clients.snowflake.Resources
71
- if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
72
- resource_class = rai.clients.snowflake.DirectAccessResources
73
- if self.config.get("platform", "") == "local":
74
- resource_class = rai.clients.local.LocalResources
75
71
  # NOTE: language="lqp" is not strictly required for LQP execution, but it
76
72
  # will significantly improve performance.
77
- self._resources = resource_class(
78
- dry_run=self.dry_run,
73
+ self._resources = create_resources_instance(
79
74
  config=self.config,
80
- generation=rai.Generation.QB,
75
+ dry_run=self.dry_run,
81
76
  connection=self.connection,
82
77
  language="lqp",
83
78
  )
@@ -118,6 +118,17 @@ class ExtractKeys(Pass):
118
118
  the same here).
119
119
  """
120
120
  class ExtractKeysRewriter(Rewriter):
121
+ def __init__(self):
122
+ super().__init__()
123
+ self.compound_keys: dict[Any, ir.Var] = {}
124
+
125
+ def _get_compound_key(self, orig_keys: Iterable[ir.Var]) -> ir.Var:
126
+ if orig_keys in self.compound_keys:
127
+ return self.compound_keys[orig_keys]
128
+ compound_key = f.var("compound_key", types.Hash)
129
+ self.compound_keys[orig_keys] = compound_key
130
+ return compound_key
131
+
121
132
  def handle_logical(self, node: ir.Logical, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Logical:
122
133
  outputs = [x for x in node.body if isinstance(x, ir.Output) and x.keys]
123
134
  # We are not in a logical with an output at this level.
@@ -170,7 +181,7 @@ class ExtractKeysRewriter(Rewriter):
170
181
  annos = list(output.annotations)
171
182
  annos.append(f.annotation(builtins.output_keys, tuple(output_keys)))
172
183
  # Create a compound key that will be used in place of the original keys.
173
- compound_key = f.var("compound_key", types.Hash)
184
+ compound_key = self._get_compound_key(output_keys)
174
185
 
175
186
  for key_combination in combinations:
176
187
  missing_keys = OrderedSet.from_iterable(output_keys)
@@ -192,8 +203,13 @@ class ExtractKeysRewriter(Rewriter):
192
203
  # handle the construct node in each clone
193
204
  values: list[ir.Value] = [compound_key.type]
194
205
  for key in output_keys:
195
- assert isinstance(key.type, ir.ScalarType)
196
- values.append(ir.Literal(types.String, key.type.name))
206
+ if isinstance(key.type, ir.UnionType):
207
+ # the typer can derive union types when multiple distinct entities flow
208
+ # into a relation's field, so use AnyEntity as the type marker
209
+ values.append(ir.Literal(types.String, "AnyEntity"))
210
+ else:
211
+ assert isinstance(key.type, ir.ScalarType)
212
+ values.append(ir.Literal(types.String, key.type.name))
197
213
  if key in key_combination:
198
214
  values.append(key)
199
215
  body.add(ir.Construct(None, tuple(values), compound_key, OrderedSet().frozen()))
@@ -408,6 +424,12 @@ class ExtractKeysRewriter(Rewriter):
408
424
  for arg in args[:-1]:
409
425
  extended_vars.add(arg)
410
426
  there_is_progress = True
427
+ elif isinstance(task, ir.Not):
428
+ if isinstance(task.task, ir.Logical):
429
+ hoisted = helpers.hoisted_vars(task.task.hoisted)
430
+ if var in hoisted:
431
+ partitions[var].add(task)
432
+ there_is_progress = True
411
433
  else:
412
434
  assert False, f"invalid node kind {type(task)}"
413
435