relationalai 1.0.0a1__py3-none-any.whl → 1.0.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/semantics/frontend/base.py +3 -0
- relationalai/semantics/frontend/front_compiler.py +5 -2
- relationalai/semantics/metamodel/builtins.py +2 -1
- relationalai/semantics/metamodel/metamodel.py +32 -4
- relationalai/semantics/metamodel/pprint.py +5 -3
- relationalai/semantics/metamodel/typer.py +324 -297
- relationalai/semantics/std/aggregates.py +0 -1
- relationalai/semantics/std/datetime.py +4 -1
- relationalai/shims/executor.py +26 -5
- relationalai/shims/mm2v0.py +119 -44
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/METADATA +1 -1
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/RECORD +57 -48
- v0/relationalai/__init__.py +69 -22
- v0/relationalai/clients/__init__.py +15 -2
- v0/relationalai/clients/client.py +4 -4
- v0/relationalai/clients/local.py +5 -5
- v0/relationalai/clients/resources/__init__.py +8 -0
- v0/relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
- v0/relationalai/clients/resources/snowflake/__init__.py +20 -0
- v0/relationalai/clients/resources/snowflake/cli_resources.py +87 -0
- v0/relationalai/clients/resources/snowflake/direct_access_resources.py +711 -0
- v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
- v0/relationalai/clients/resources/snowflake/error_handlers.py +199 -0
- v0/relationalai/clients/resources/snowflake/resources_factory.py +99 -0
- v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +606 -1392
- v0/relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +43 -12
- v0/relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
- v0/relationalai/clients/resources/snowflake/util.py +387 -0
- v0/relationalai/early_access/dsl/ir/executor.py +4 -4
- v0/relationalai/early_access/dsl/snow/api.py +2 -1
- v0/relationalai/errors.py +23 -0
- v0/relationalai/experimental/solvers.py +7 -7
- v0/relationalai/semantics/devtools/benchmark_lqp.py +4 -5
- v0/relationalai/semantics/devtools/extract_lqp.py +1 -1
- v0/relationalai/semantics/internal/internal.py +4 -4
- v0/relationalai/semantics/internal/snowflake.py +3 -2
- v0/relationalai/semantics/lqp/executor.py +20 -22
- v0/relationalai/semantics/lqp/model2lqp.py +42 -4
- v0/relationalai/semantics/lqp/passes.py +1 -1
- v0/relationalai/semantics/lqp/rewrite/cdc.py +1 -1
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +53 -12
- v0/relationalai/semantics/metamodel/builtins.py +8 -6
- v0/relationalai/semantics/metamodel/rewrite/flatten.py +9 -4
- v0/relationalai/semantics/metamodel/util.py +6 -5
- v0/relationalai/semantics/reasoners/graph/core.py +8 -9
- v0/relationalai/semantics/rel/executor.py +14 -11
- v0/relationalai/semantics/sql/compiler.py +2 -2
- v0/relationalai/semantics/sql/executor/snowflake.py +9 -5
- v0/relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
- v0/relationalai/tools/cli.py +26 -30
- v0/relationalai/tools/cli_helpers.py +10 -2
- v0/relationalai/util/otel_configuration.py +2 -1
- v0/relationalai/util/otel_handler.py +1 -1
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/WHEEL +0 -0
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/top_level.txt +0 -0
- /v0/relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions for Snowflake resources.
|
|
3
|
+
"""
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
import re
|
|
6
|
+
import decimal
|
|
7
|
+
import base64
|
|
8
|
+
from numbers import Number
|
|
9
|
+
from datetime import datetime, date
|
|
10
|
+
from typing import List, Any, Dict, cast
|
|
11
|
+
|
|
12
|
+
from .... import dsl
|
|
13
|
+
from ....environments import runtime_env, SnowbookEnvironment
|
|
14
|
+
|
|
15
|
+
# warehouse-based snowflake notebooks currently don't have hazmat
|
|
16
|
+
crypto_disabled = False
|
|
17
|
+
try:
|
|
18
|
+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
19
|
+
from cryptography.hazmat.backends import default_backend
|
|
20
|
+
from cryptography.hazmat.primitives import padding
|
|
21
|
+
except (ModuleNotFoundError, ImportError):
|
|
22
|
+
crypto_disabled = True
|
|
23
|
+
|
|
24
|
+
# Lower-case fragments searched for in response messages by the
# is_engine_issue() / is_database_issue() helpers in this module.
ENGINE_ERRORS = (
    "engine is suspended",
    "create/resume",
    "engine not found",
    "no engines found",
    "engine was deleted",
)
ENGINE_NOT_READY_MSGS = (
    "engine is in pending",
    "engine is provisioning",
)
DATABASE_ERRORS = ("database not found",)

# Import states accepted by is_valid_import_state().
VALID_IMPORT_STATES = ("PENDING", "PROCESSING", "QUARANTINED", "LOADED")

# Columns copied (with lower-cased keys) from each row by imports_to_dicts().
IMPORT_STREAM_FIELDS = (
    "ID",
    "CREATED_AT",
    "CREATED_BY",
    "STATUS",
    "REFERENCE_NAME",
    "REFERENCE_ALIAS",
    "FQ_OBJECT_NAME",
    "RAI_DATABASE",
    "RAI_RELATION",
    "DATA_SYNC_STATUS",
    "PENDING_BATCHES_COUNT",
    "NEXT_BATCH_STATUS",
    "NEXT_BATCH_UNLOADED_TIMESTAMP",
    "NEXT_BATCH_DETAILS",
    "LAST_BATCH_DETAILS",
    "LAST_BATCH_UNLOADED_TIMESTAMP",
    "CDC_STATUS",
)

# Key renames applied by txn_list_to_dicts(): lower-cased source key -> output key.
FIELD_MAP = {
    "database_name": "database",
    "engine_name": "engine",
}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def process_jinja_template(template: str, indent_spaces: int = 0, **substitutions: Any) -> str:
    """Render a small, line-oriented Jinja-like template.

    Supported syntax:
    - Variable substitution ``{{ var }}`` (names missing from the
      substitutions render as an empty string)
    - Conditional blocks ``{% if condition %} ... {% else %} ... {% endif %}``
    - For loops ``{% for item in items %} ... {% endfor %}`` over a list or
      tuple found in the substitutions
    - Comments ``{# ... #}``
    - The whitespace-control spellings ``{%-`` and ``-%}`` are accepted and
      treated exactly like ``{%`` and ``%}``

    Tags are recognised per line: a line carrying an ``if``/``else``/``endif``
    or ``for``/``endfor`` tag is consumed and produces no output itself.

    Args:
        template: The template string.
        indent_spaces: Number of spaces prepended to every non-blank output line.
        **substitutions: Names available to ``{{ }}``, ``if`` and ``for``.

    Returns:
        The rendered text, lines joined with ``\\n``.
    """

    def strip_whitespace_markers(text: str) -> str:
        """Rewrite ``{%-`` / ``-%}`` to the plain tag delimiters."""
        return text.replace('{%-', '{%').replace('-%}', '%}')

    def evaluate_condition(condition: str, context: dict) -> bool:
        """Evaluate an ``if`` condition against the context; any error yields False."""
        # The context is handed to eval() as its local namespace instead of
        # being pasted into the expression text. Textual pasting corrupted the
        # expression whenever a key was a substring of another token or of a
        # string literal, or when a value contained a quote character.
        # NOTE(review): eval() with empty builtins is not a sandbox; templates
        # must only come from trusted code.
        try:
            return bool(eval(condition, {"__builtins__": {}}, dict(context)))
        except Exception:
            # Undefined names and malformed conditions select the else branch.
            return False

    def process_expression(expr: str, context: dict) -> str:
        """Return the string form of a ``{{ name }}`` lookup, or "" if unknown."""
        expr = expr.strip()
        if expr in context:
            return str(context[expr])
        return ""

    def process_block(lines: List[str], context: dict, indent: int = 0) -> List[str]:
        """Process a block of template lines recursively."""
        result = []
        i = 0
        while i < len(lines):
            line = lines[i]

            # Handle comments
            line = re.sub(r'{#.*?#}', '', line)

            # Handle if blocks
            if_match = re.search(r'{%\s*if\s+(.+?)\s*%}', line)
            if if_match:
                condition = if_match.group(1)
                if_block = []
                else_block = []
                i += 1
                nesting = 1
                in_else_block = False
                # Collect raw lines up to the matching endif; nested ifs are
                # kept verbatim and handled by the recursive call.
                while i < len(lines) and nesting > 0:
                    if re.search(r'{%\s*if\s+', lines[i]):
                        nesting += 1
                    elif re.search(r'{%\s*endif\s*%}', lines[i]):
                        nesting -= 1
                    elif nesting == 1 and re.search(r'{%\s*else\s*%}', lines[i]):
                        in_else_block = True
                        i += 1
                        continue

                    if nesting > 0:
                        if in_else_block:
                            else_block.append(lines[i])
                        else:
                            if_block.append(lines[i])
                    i += 1
                if evaluate_condition(condition, context):
                    result.extend(process_block(if_block, context, indent))
                else:
                    result.extend(process_block(else_block, context, indent))
                continue

            # Handle for loops
            for_match = re.search(r'{%\s*for\s+(\w+)\s+in\s+(\w+)\s*%}', line)
            if for_match:
                var_name, iterable_name = for_match.groups()
                for_block = []
                i += 1
                nesting = 1
                while i < len(lines) and nesting > 0:
                    if re.search(r'{%\s*for\s+', lines[i]):
                        nesting += 1
                    elif re.search(r'{%\s*endfor\s*%}', lines[i]):
                        nesting -= 1
                    if nesting > 0:
                        for_block.append(lines[i])
                    i += 1
                # Anything other than a list/tuple silently renders nothing.
                if iterable_name in context and isinstance(context[iterable_name], (list, tuple)):
                    for item in context[iterable_name]:
                        loop_context = dict(context)
                        loop_context[var_name] = item
                        result.extend(process_block(for_block, loop_context, indent))
                continue

            # Handle variable substitution
            line = re.sub(r'{{\s*(\w+)\s*}}', lambda m: process_expression(m.group(1), context), line)

            # Kept from the original so that markers introduced by a
            # substituted value are normalised the same way as before.
            line = strip_whitespace_markers(line)

            # Add line with proper indentation, preserving blank lines
            if line.strip():
                result.append(" " * (indent_spaces + indent) + line)
            else:
                result.append("")

            i += 1

        return result

    # Normalise the whitespace-control markers before any tag matching so that
    # "{%- if x -%}" is recognised as a tag instead of being emitted as text.
    lines = [strip_whitespace_markers(raw) for raw in template.split('\n')]
    processed_lines = process_block(lines, substitutions)

    return '\n'.join(processed_lines)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def type_to_sql(type_obj: Any) -> str:
    """Return the SQL type name used for a Python type object.

    Matching is by identity against a fixed set of Python types; an instance
    of ``dsl.Type`` maps to ``VARCHAR``.

    Raises:
        ValueError: If ``type_obj`` is none of the recognised types.
    """
    sql_names = (
        (str, "VARCHAR"),
        (int, "NUMBER"),
        (Number, "DECIMAL(38, 15)"),
        (float, "FLOAT"),
        (decimal.Decimal, "DECIMAL(38, 15)"),
        (bool, "BOOLEAN"),
        (dict, "VARIANT"),
        (list, "ARRAY"),
        (bytes, "BINARY"),
        (datetime, "TIMESTAMP"),
        (date, "DATE"),
    )
    for python_type, sql_name in sql_names:
        # Identity, not issubclass: bool must not be picked up by the int row.
        if type_obj is python_type:
            return sql_name
    if isinstance(type_obj, dsl.Type):
        return "VARCHAR"
    raise ValueError(f"Unknown type {type_obj}")
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def type_to_snowpark(type_obj: Any) -> str:
    """Return the Snowpark type constructor, as source text, for a Python type object.

    Matching is by identity against a fixed set of Python types; an instance
    of ``dsl.Type`` maps to ``StringType()``.

    Raises:
        ValueError: If ``type_obj`` is none of the recognised types.
    """
    snowpark_names = (
        (str, "StringType()"),
        (int, "IntegerType()"),
        (float, "FloatType()"),
        (Number, "DecimalType(38, 15)"),
        (decimal.Decimal, "DecimalType(38, 15)"),
        (bool, "BooleanType()"),
        (dict, "MapType()"),
        (list, "ArrayType()"),
        (bytes, "BinaryType()"),
        (datetime, "TimestampType()"),
        (date, "DateType()"),
    )
    for python_type, snowpark_name in snowpark_names:
        # Identity, not issubclass: bool must not be picked up by the int row.
        if type_obj is python_type:
            return snowpark_name
    if isinstance(type_obj, dsl.Type):
        return "StringType()"
    raise ValueError(f"Unknown type {type_obj}")
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def sanitize_user_name(user: str) -> str:
    """Return the text before the first '@' with every character outside [a-zA-Z0-9_] replaced by '_'."""
    local_part, _, _ = user.partition('@')
    return re.sub(r'[^a-zA-Z0-9_]', '_', local_part)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def is_engine_issue(response_message: str) -> bool:
    """Report whether the message contains any engine-error or engine-not-ready fragment (case-insensitive)."""
    lowered = response_message.lower()
    for fragment in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS:
        if fragment in lowered:
            return True
    return False
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def is_database_issue(response_message: str) -> bool:
    """Report whether the message contains any database-error fragment (case-insensitive)."""
    lowered = response_message.lower()
    for fragment in DATABASE_ERRORS:
        if fragment in lowered:
            return True
    return False
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def collect_error_messages(e: Exception) -> list[str]:
    """Gather lower-cased message strings describing an exception.

    The returned list holds, in this order:
    - ``str(e)``
    - ``e.message`` when that attribute exists and is a string
    - every string found in ``e.args``
    - ``str(e.__cause__)`` and ``str(e.__context__)`` when set
    - for each of the above that mentions "javascript execution error",
      every value of an embedded ``"message": "..."`` pair

    Only the direct cause/context are inspected, not the whole chain, and
    duplicates are not removed.
    """
    collected = [str(e).lower()]

    # RAIException-style errors expose their text on a `message` attribute.
    message_attr = getattr(e, 'message', None)
    if isinstance(message_attr, str):
        collected.append(message_attr.lower())

    # String positional arguments of the exception.
    for arg in getattr(e, 'args', None) or ():
        if isinstance(arg, str):
            collected.append(arg.lower())

    # Direct cause first, then direct context.
    for linked in (getattr(e, '__cause__', None), getattr(e, '__context__', None)):
        if linked:
            collected.append(str(linked).lower())

    # Pull nested messages out of JavaScript execution errors. The list is
    # computed from a snapshot, so newly found entries are not re-scanned.
    nested = [
        found.lower()
        for text in list(collected)
        if re.search(r"javascript execution error", text)
        for found in re.findall(r'"message"\s*:\s*"([^"]+)"', text, re.IGNORECASE)
    ]
    collected.extend(nested)

    return collected
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
#--------------------------------------------------
|
|
282
|
+
# Parameter and Data Transformation Utilities
|
|
283
|
+
#--------------------------------------------------
|
|
284
|
+
|
|
285
|
+
def normalize_params(params: List[Any] | Any | None) -> List[Any] | None:
|
|
286
|
+
"""Normalize parameters to a list format."""
|
|
287
|
+
if params is not None and not isinstance(params, list):
|
|
288
|
+
return cast(List[Any], [params])
|
|
289
|
+
return params
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def format_sproc_name(name: str, type_obj: Any) -> str:
    """Return the expression text for a stored-procedure parameter.

    For ``datetime`` parameters the name is wrapped in a UTC ISO-8601
    conversion expression; every other type yields the name unchanged.
    """
    is_timestamp = type_obj is datetime
    return (
        f"{name}.astimezone(ZoneInfo('UTC')).isoformat(timespec='milliseconds')"
        if is_timestamp
        else name
    )
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def is_azure_url(url: str) -> bool:
    """Report whether the URL text contains the Azure blob storage host suffix."""
    azure_blob_host = "blob.core.windows.net"
    return url.find(azure_blob_host) != -1
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def is_container_runtime() -> bool:
    """Report whether the runtime environment is a Snowbook environment whose runner is "container"."""
    if not isinstance(runtime_env, SnowbookEnvironment):
        return False
    return runtime_env.runner == "container"
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
#--------------------------------------------------
|
|
310
|
+
# Import/Export Utilities
|
|
311
|
+
#--------------------------------------------------
|
|
312
|
+
|
|
313
|
+
def is_valid_import_state(state: str) -> bool:
    """Report whether ``state`` is one of the entries of VALID_IMPORT_STATES (exact, case-sensitive match)."""
    recognised = state in VALID_IMPORT_STATES
    return recognised
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def imports_to_dicts(results: List[Any]) -> List[Dict[str, Any]]:
    """Build one dict per result row, keyed by the lower-cased IMPORT_STREAM_FIELDS names.

    Each row is indexed by the upper-case field name, so a row lacking one of
    the fields raises whatever its ``__getitem__`` raises.
    """
    converted: List[Dict[str, Any]] = []
    for row in results:
        entry: Dict[str, Any] = {}
        for field in IMPORT_STREAM_FIELDS:
            entry[field.lower()] = row[field]
        converted.append(entry)
    return converted
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
#--------------------------------------------------
|
|
328
|
+
# Transaction Utilities
|
|
329
|
+
#--------------------------------------------------
|
|
330
|
+
|
|
331
|
+
def txn_list_to_dicts(transactions: List[Any]) -> List[Dict[str, Any]]:
    """Convert transaction rows to dicts with normalised keys.

    Each transaction must provide ``asDict()``. Every key is lower-cased and,
    when the lower-cased key appears in FIELD_MAP, renamed to the mapped name.

    Args:
        transactions: Row-like objects exposing ``asDict()``
            (presumably Snowpark rows; confirm against callers).

    Returns:
        One dict per transaction, in input order.
    """
    rows: List[Dict[str, Any]] = []
    for txn in transactions:
        # Named `row` rather than `dict` so the builtin is not shadowed.
        row: Dict[str, Any] = {}
        for key, value in txn.asDict().items():
            lowered = key.lower()
            # A missing or empty mapping falls back to the lower-cased key.
            row[FIELD_MAP.get(lowered) or lowered] = value
        rows.append(row)
    return rows
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
#--------------------------------------------------
|
|
348
|
+
# Encryption Utilities
|
|
349
|
+
#--------------------------------------------------
|
|
350
|
+
|
|
351
|
+
def decrypt_stream(key: bytes, iv: bytes, src: bytes) -> bytes:
    """Decrypt ``src`` with AES in CBC mode and remove its PKCS7 (128-bit block) padding.

    Args:
        key: AES key bytes.
        iv: CBC initialisation vector.
        src: Ciphertext to decrypt.

    Returns:
        The unpadded plaintext.

    Raises:
        Exception: If the ``cryptography`` import failed at module load; the
            message depends on whether this is a warehouse-run Snowbook.
    """
    if crypto_disabled:
        in_warehouse_notebook = (
            isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "warehouse"
        )
        if in_warehouse_notebook:
            raise Exception("Please open the navigation-bar dropdown labeled *Packages* and select `cryptography` under the *Anaconda Packages* section, and then re-run your query.")
        raise Exception("library `cryptography.hazmat` missing; please install")

    # The `type: ignore` markers are needed because the crypto names come from
    # a conditional import (warehouse-based snowflake notebooks do not ship
    # the library).
    aes_cbc = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())  # type: ignore
    decryptor = aes_cbc.decryptor()
    padded_plaintext = decryptor.update(src) + decryptor.finalize()

    # 128 is the AES block size in bits.
    unpadder = padding.PKCS7(128).unpadder()  # type: ignore
    return unpadder.update(padded_plaintext) + unpadder.finalize()
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def decrypt_artifact(data: bytes, encryption_material: str) -> bytes:
    """Decrypt artifact bytes using ``"ALGORITHM|base64-key|base64-iv"`` material.

    Args:
        data: Encrypted artifact bytes.
        encryption_material: Three ``|``-separated parts: the algorithm name
            (only ``AES_128_CBC`` is accepted), the base64 key, the base64 IV.

    Returns:
        The decrypted bytes as produced by ``decrypt_stream``.

    Raises:
        AssertionError: If the material does not have exactly three parts or
            names an unsupported algorithm.
    """
    encryption_material_parts = encryption_material.split("|")
    # Raised explicitly rather than via `assert` so the validation is not
    # stripped under `python -O`; the exception type and message are the same
    # as before.
    if len(encryption_material_parts) != 3:
        raise AssertionError("Invalid encryption material")

    algorithm, key_base64, iv_base64 = encryption_material_parts
    if algorithm != "AES_128_CBC":
        raise AssertionError(f"Unsupported encryption algorithm {algorithm}")

    key = base64.standard_b64decode(key_base64)
    iv = base64.standard_b64decode(iv_base64)

    return decrypt_stream(key, iv, data)
|
|
387
|
+
|
|
@@ -6,10 +6,9 @@ from collections import defaultdict
|
|
|
6
6
|
from typing import Any, List, Optional
|
|
7
7
|
|
|
8
8
|
from pandas import DataFrame
|
|
9
|
-
import v0.relationalai as rai
|
|
10
9
|
from v0.relationalai import debugging
|
|
11
10
|
from v0.relationalai.clients import result_helpers
|
|
12
|
-
from v0.relationalai.clients.snowflake import APP_NAME
|
|
11
|
+
from v0.relationalai.clients.resources.snowflake import APP_NAME
|
|
13
12
|
from v0.relationalai.early_access.dsl.ir.compiler import Compiler
|
|
14
13
|
from v0.relationalai.early_access.dsl.ontologies.models import Model
|
|
15
14
|
from v0.relationalai.semantics.metamodel import ir
|
|
@@ -37,7 +36,8 @@ class RelExecutor:
|
|
|
37
36
|
if not self._resources:
|
|
38
37
|
with debugging.span("create_session"):
|
|
39
38
|
self.dry_run |= bool(self.config.get("compiler.dry_run", False))
|
|
40
|
-
|
|
39
|
+
from v0.relationalai.clients.resources.snowflake import Resources
|
|
40
|
+
self._resources = Resources(
|
|
41
41
|
dry_run=self.dry_run,
|
|
42
42
|
config=self.config,
|
|
43
43
|
generation=Generation.QB,
|
|
@@ -257,4 +257,4 @@ class RelExecutor:
|
|
|
257
257
|
if raw:
|
|
258
258
|
dataframe, errors = result_helpers.format_results(raw, None, result_cols)
|
|
259
259
|
self.report_errors(errors)
|
|
260
|
-
return DataFrame()
|
|
260
|
+
return DataFrame()
|
|
@@ -2,13 +2,14 @@ from typing import cast, Optional
|
|
|
2
2
|
|
|
3
3
|
import v0.relationalai as rai
|
|
4
4
|
from v0.relationalai import Config
|
|
5
|
+
from v0.relationalai.clients.resources.snowflake import Provider
|
|
5
6
|
from v0.relationalai.early_access.dsl.snow.common import TabularMetadata, ColumnMetadata, SchemaMetadata, \
|
|
6
7
|
ForeignKey, ColumnRef
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
class Executor:
|
|
10
11
|
def __init__(self, config: Optional[Config] = None):
|
|
11
|
-
self._provider = cast(
|
|
12
|
+
self._provider = cast(Provider, rai.Provider(config=config))
|
|
12
13
|
self._table_meta_cache = {}
|
|
13
14
|
self._schema_fk_cache = {}
|
|
14
15
|
|
v0/relationalai/errors.py
CHANGED
|
@@ -2168,6 +2168,29 @@ class DirectAccessInvalidAuthWarning(RAIWarning):
|
|
|
2168
2168
|
Direct access requires one of the following authenticators: {valid_authenticators_str}
|
|
2169
2169
|
""")
|
|
2170
2170
|
|
|
2171
|
+
class NonDefaultLQPSemanticsVersionWarning(RAIWarning):
|
|
2172
|
+
def __init__(self, current_version: str, default_version: str):
|
|
2173
|
+
self.current_version = current_version
|
|
2174
|
+
self.default_version = default_version
|
|
2175
|
+
self.name = "Non-default LQP Semantics Version"
|
|
2176
|
+
self.message = f"Using non-default LQP semantics version {current_version}. Default is {default_version}."
|
|
2177
|
+
self.content = self.format_message()
|
|
2178
|
+
super().__init__(self.message, self.name, self.content)
|
|
2179
|
+
|
|
2180
|
+
def format_message(self):
|
|
2181
|
+
return textwrap.dedent(f"""
|
|
2182
|
+
{self.message}
|
|
2183
|
+
|
|
2184
|
+
You are using a non-default LQP semantics version, likely to avoid a change in
|
|
2185
|
+
behaviour that broke one of your models. This is a reminder to ensure you switch
|
|
2186
|
+
back to the default version once any blocking issues have been resolved.
|
|
2187
|
+
|
|
2188
|
+
To do so you need to remove the following section from your raiconfig.toml:
|
|
2189
|
+
|
|
2190
|
+
[reasoner.rule]
|
|
2191
|
+
lqp.semantics_version = {self.current_version}
|
|
2192
|
+
""")
|
|
2193
|
+
|
|
2171
2194
|
class InsecureKeychainWarning(RAIWarning):
|
|
2172
2195
|
def __init__(self):
|
|
2173
2196
|
self.message = "Insecure keyring detected. Please use a secure keyring backend."
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
import time
|
|
3
|
-
from typing import Any, List, Optional
|
|
3
|
+
from typing import Any, List, Optional, cast
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
import textwrap
|
|
6
6
|
from .. import dsl, std
|
|
@@ -14,14 +14,13 @@ import uuid
|
|
|
14
14
|
import v0.relationalai
|
|
15
15
|
import json
|
|
16
16
|
from ..clients.util import poll_with_specified_overhead
|
|
17
|
-
from ..clients.snowflake import Resources as SnowflakeResources
|
|
18
|
-
from ..clients.snowflake import
|
|
17
|
+
from ..clients.resources.snowflake import Resources as SnowflakeResources, APP_NAME
|
|
18
|
+
from ..clients.resources.snowflake import DirectAccessResources
|
|
19
|
+
from ..clients.direct_access_client import DirectAccessClient
|
|
19
20
|
from ..util.timeout import calc_remaining_timeout_minutes
|
|
20
21
|
|
|
21
22
|
rel_sv = rel._tagged(Builtins.SingleValued)
|
|
22
23
|
|
|
23
|
-
APP_NAME = v0.relationalai.clients.snowflake.APP_NAME
|
|
24
|
-
|
|
25
24
|
ENGINE_TYPE_SOLVER = "SOLVER"
|
|
26
25
|
# TODO (dba) The ERP still uses `worker` instead of `engine`. Change
|
|
27
26
|
# this once we fix this in the ERP.
|
|
@@ -741,10 +740,11 @@ class Provider:
|
|
|
741
740
|
def __init__(self, resources=None):
|
|
742
741
|
if not resources:
|
|
743
742
|
resources = v0.relationalai.Resources()
|
|
744
|
-
if not isinstance(resources,
|
|
743
|
+
if not isinstance(resources, SnowflakeResources):
|
|
745
744
|
raise Exception("Solvers are only supported on SPCS.")
|
|
746
745
|
|
|
747
|
-
|
|
746
|
+
# Type narrowing: resources is confirmed to be SnowflakeResources
|
|
747
|
+
self.resources: SnowflakeResources = cast(SnowflakeResources, resources)
|
|
748
748
|
self.direct_access_client: Optional[DirectAccessClient] = None
|
|
749
749
|
|
|
750
750
|
if isinstance(self.resources, DirectAccessResources):
|
|
@@ -5,12 +5,11 @@ import argparse
|
|
|
5
5
|
import os
|
|
6
6
|
import json
|
|
7
7
|
|
|
8
|
-
from v0.relationalai.clients.snowflake import Resources as snowflake_api
|
|
8
|
+
from v0.relationalai.clients.resources.snowflake import Resources as snowflake_api, APP_NAME
|
|
9
9
|
from v0.relationalai.semantics.lqp.executor import LQPExecutor
|
|
10
10
|
from v0.relationalai.semantics.internal import internal
|
|
11
|
-
from v0.relationalai.clients.use_index_poller import UseIndexPoller as index_poller
|
|
11
|
+
from v0.relationalai.clients.resources.snowflake.use_index_poller import UseIndexPoller as index_poller
|
|
12
12
|
from snowflake.connector.cursor import DictCursor
|
|
13
|
-
from v0.relationalai.clients import snowflake
|
|
14
13
|
|
|
15
14
|
from enum import Enum
|
|
16
15
|
|
|
@@ -172,7 +171,7 @@ def _exec_snowflake_override(bench_ctx, old_func, marker):
|
|
|
172
171
|
def new_func(self, code, params, raw=False):
|
|
173
172
|
cur = self._session.connection.cursor(DictCursor)
|
|
174
173
|
try:
|
|
175
|
-
cur.execute(code.replace(
|
|
174
|
+
cur.execute(code.replace(APP_NAME, self.get_app_name()), params)
|
|
176
175
|
rows = cur.fetchall()
|
|
177
176
|
qid = str(getattr(cur, "sfqid", None))
|
|
178
177
|
assert qid is not None, "Snowflake query ID was not available"
|
|
@@ -398,7 +397,7 @@ def get_sf_query_info(bench_ctx):
|
|
|
398
397
|
return result
|
|
399
398
|
|
|
400
399
|
def _get_query_info(qids):
|
|
401
|
-
from v0.relationalai.clients.snowflake import Resources as snowflake_client
|
|
400
|
+
from v0.relationalai.clients.resources.snowflake import Resources as snowflake_client
|
|
402
401
|
client = snowflake_client()
|
|
403
402
|
|
|
404
403
|
qids_str = "','".join(qids)
|
|
@@ -7,7 +7,7 @@ import os
|
|
|
7
7
|
import json
|
|
8
8
|
from contextlib import contextmanager
|
|
9
9
|
|
|
10
|
-
from v0.relationalai.clients.snowflake import Resources as snowflake_api
|
|
10
|
+
from v0.relationalai.clients.resources.snowflake import Resources as snowflake_api
|
|
11
11
|
from v0.relationalai.semantics.internal import internal
|
|
12
12
|
from typing import Dict, Optional
|
|
13
13
|
|
|
@@ -884,12 +884,12 @@ class Concept(Producer):
|
|
|
884
884
|
raise ValueError("Concept names cannot start with '_'")
|
|
885
885
|
|
|
886
886
|
# Check if it matches either allowed format
|
|
887
|
-
pattern_a = r'^[a-zA-Z0-9_.]+$'
|
|
888
|
-
pattern_b = r'^[a-zA-Z0-9_.]+\([0-9]+,[0-9]+\)$'
|
|
887
|
+
pattern_a = r'^[a-zA-Z0-9_."-]+$'
|
|
888
|
+
pattern_b = r'^[a-zA-Z0-9_."-]+\([0-9]+,[0-9]+\)$'
|
|
889
889
|
|
|
890
890
|
if not (re.match(pattern_a, name) or re.match(pattern_b, name)):
|
|
891
891
|
raise ValueError(f"Concept name '{name}' contains invalid characters. "
|
|
892
|
-
f"Names must contain only letters, digits, dots, and underscores, "
|
|
892
|
+
f"Names must contain only letters, digits, dots, double quotes, hyphens, and underscores, "
|
|
893
893
|
f"optionally followed by precision/scale in parentheses like 'Decimal(38,14)'")
|
|
894
894
|
|
|
895
895
|
def __init__(self, name:str, extends:list[Any] = [], model:Model|None=None, identify_by:dict[str, Any]={}):
|
|
@@ -1365,7 +1365,7 @@ class Relationship(Producer):
|
|
|
1365
1365
|
|
|
1366
1366
|
def _parse_schema_format(self, format_string:str):
|
|
1367
1367
|
# Pattern to extract fields like {Type} or {name:Type}, where Type can have precision and scale, like Decimal(38,14)
|
|
1368
|
-
pattern = r'\{([a-zA-Z0-9_.]+(?:\([0-9]+,[0-9]+\))?)(?::([a-zA-Z0-9_.]+(?:\([0-9]+,[0-9]+\))?))?\}'
|
|
1368
|
+
pattern = r'\{([a-zA-Z0-9_."-]+(?:\([0-9]+,[0-9]+\))?)(?::([a-zA-Z0-9_."-]+(?:\([0-9]+,[0-9]+\))?))?\}'
|
|
1369
1369
|
matches = re.findall(pattern, format_string)
|
|
1370
1370
|
|
|
1371
1371
|
namer = NameCache()
|
|
@@ -61,7 +61,7 @@ def get_session():
|
|
|
61
61
|
_session = get_active_session()
|
|
62
62
|
except Exception:
|
|
63
63
|
from v0.relationalai import Resources
|
|
64
|
-
from v0.relationalai.clients.snowflake import Resources as SnowflakeResources
|
|
64
|
+
from v0.relationalai.clients.resources.snowflake import Resources as SnowflakeResources
|
|
65
65
|
# TODO: we need a better way to handle global config
|
|
66
66
|
|
|
67
67
|
# using the resource constructor to differentiate between direct access and
|
|
@@ -230,7 +230,8 @@ class Table():
|
|
|
230
230
|
if self._table not in schema_info.fetched:
|
|
231
231
|
schema_info.fetch()
|
|
232
232
|
table_info = schema_info.tables[self._table]
|
|
233
|
-
|
|
233
|
+
CDC_name = self._fqn.lower() if '"' not in self._fqn else self._fqn.replace("\"", "_")
|
|
234
|
+
self._rel = b.Relationship(CDC_name, fields=[b.Field(name="RowId", type_str=self._fqn, type=self._concept)] + table_info.fields)
|
|
234
235
|
self._rel.annotate(anns.external).annotate(anns.from_cdc)
|
|
235
236
|
|
|
236
237
|
def __getattr__(self, name: str):
|
|
@@ -7,9 +7,9 @@ import re
|
|
|
7
7
|
from pandas import DataFrame
|
|
8
8
|
from typing import Any, Optional, Literal, TYPE_CHECKING
|
|
9
9
|
from snowflake.snowpark import Session
|
|
10
|
-
import v0.relationalai as rai
|
|
11
10
|
|
|
12
11
|
from v0.relationalai import debugging
|
|
12
|
+
from v0.relationalai.errors import NonDefaultLQPSemanticsVersionWarning
|
|
13
13
|
from v0.relationalai.semantics.lqp import result_helpers
|
|
14
14
|
from v0.relationalai.semantics.metamodel import ir, factory as f, executor as e
|
|
15
15
|
from v0.relationalai.semantics.lqp.compiler import Compiler
|
|
@@ -20,15 +20,18 @@ from lqp import print as lqp_print, ir as lqp_ir
|
|
|
20
20
|
from lqp.parser import construct_configure
|
|
21
21
|
from v0.relationalai.semantics.lqp.ir import convert_transaction, validate_lqp
|
|
22
22
|
from v0.relationalai.clients.config import Config
|
|
23
|
-
from v0.relationalai.clients.snowflake import APP_NAME
|
|
23
|
+
from v0.relationalai.clients.resources.snowflake import APP_NAME, create_resources_instance
|
|
24
24
|
from v0.relationalai.clients.types import TransactionAsyncResponse
|
|
25
25
|
from v0.relationalai.clients.util import IdentityParser, escape_for_f_string
|
|
26
|
-
from v0.relationalai.tools.constants import
|
|
26
|
+
from v0.relationalai.tools.constants import QUERY_ATTRIBUTES_HEADER
|
|
27
27
|
from v0.relationalai.tools.query_utils import prepare_metadata_for_headers
|
|
28
28
|
|
|
29
29
|
if TYPE_CHECKING:
|
|
30
30
|
from v0.relationalai.semantics.snowflake import Table
|
|
31
31
|
|
|
32
|
+
# Whenever the logic engine introduces a breaking change in behaviour, we bump this version
|
|
33
|
+
# once the client is ready to handle it.
|
|
34
|
+
DEFAULT_LQP_SEMANTICS_VERSION = "0"
|
|
32
35
|
|
|
33
36
|
class LQPExecutor(e.Executor):
|
|
34
37
|
"""Executes LQP using the RAI client."""
|
|
@@ -63,17 +66,11 @@ class LQPExecutor(e.Executor):
|
|
|
63
66
|
if not self._resources:
|
|
64
67
|
with debugging.span("create_session"):
|
|
65
68
|
self.dry_run |= bool(self.config.get("compiler.dry_run", False))
|
|
66
|
-
resource_class = rai.clients.snowflake.Resources
|
|
67
|
-
if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
|
|
68
|
-
resource_class = rai.clients.snowflake.DirectAccessResources
|
|
69
|
-
if self.config.get("platform", "") == "local":
|
|
70
|
-
resource_class = rai.clients.local.LocalResources
|
|
71
69
|
# NOTE: language="lqp" is not strictly required for LQP execution, but it
|
|
72
70
|
# will significantly improve performance.
|
|
73
|
-
self._resources =
|
|
74
|
-
dry_run=self.dry_run,
|
|
71
|
+
self._resources = create_resources_instance(
|
|
75
72
|
config=self.config,
|
|
76
|
-
|
|
73
|
+
dry_run=self.dry_run,
|
|
77
74
|
connection=self.connection,
|
|
78
75
|
language="lqp",
|
|
79
76
|
)
|
|
@@ -147,17 +144,7 @@ class LQPExecutor(e.Executor):
|
|
|
147
144
|
elif code == "PYREL_ERROR":
|
|
148
145
|
pyrel_errors[problem["props"]["pyrel_id"]].append(problem)
|
|
149
146
|
elif abort_on_error:
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
if code == 'SYSTEM_INTERNAL':
|
|
153
|
-
supplementary_message = "Troubleshooting:\n" + \
|
|
154
|
-
" 1. Please retry with a new name for your model. This can work around state-related issues.\n" + \
|
|
155
|
-
" 2. If the error persists, please retry with the `use_lqp` flag set to `False`, for example:\n" + \
|
|
156
|
-
" `model = Model(..., use_lqp=False)`\n" + \
|
|
157
|
-
" This will switch the execution to the legacy backend, which may avoid the issue with some performance cost.\n"
|
|
158
|
-
|
|
159
|
-
e.content = f"{e.content}{supplementary_message}"
|
|
160
|
-
all_errors.append(e)
|
|
147
|
+
all_errors.append(errors.RelQueryError(problem, source))
|
|
161
148
|
else:
|
|
162
149
|
if code == "ARITY_MISMATCH":
|
|
163
150
|
errors.ArityMismatch(problem, source)
|
|
@@ -311,6 +298,17 @@ class LQPExecutor(e.Executor):
|
|
|
311
298
|
ivm_flag = self.config.get('reasoner.rule.incremental_maintenance', None)
|
|
312
299
|
if ivm_flag:
|
|
313
300
|
config_dict['ivm.maintenance_level'] = lqp_ir.Value(value=ivm_flag, meta=None)
|
|
301
|
+
|
|
302
|
+
# Set semantics_version from config, defaulting to 0
|
|
303
|
+
semantics_version: str | Any = self.config.get('reasoner.rule.lqp.semantics_version', DEFAULT_LQP_SEMANTICS_VERSION)
|
|
304
|
+
config_dict['semantics_version'] = lqp_ir.Value(value=int(semantics_version), meta=None)
|
|
305
|
+
|
|
306
|
+
# Warn if a non-default semantics version is used. Most likely, this is due to a
|
|
307
|
+
# user manually reverting to an older version. We want them to not get stuck on that
|
|
308
|
+
# version for longer than necessary.
|
|
309
|
+
if semantics_version != DEFAULT_LQP_SEMANTICS_VERSION:
|
|
310
|
+
debugging.warn(NonDefaultLQPSemanticsVersionWarning(semantics_version, DEFAULT_LQP_SEMANTICS_VERSION))
|
|
311
|
+
|
|
314
312
|
return construct_configure(config_dict, None)
|
|
315
313
|
|
|
316
314
|
def _should_sync(self, model) :
|