relationalai 0.11.4__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff compares the contents of publicly available package versions as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the two releases as they appear in that registry.
- relationalai/clients/config.py +7 -0
- relationalai/clients/direct_access_client.py +113 -0
- relationalai/clients/snowflake.py +263 -189
- relationalai/clients/types.py +4 -1
- relationalai/clients/use_index_poller.py +72 -48
- relationalai/clients/util.py +9 -0
- relationalai/dsl.py +1 -2
- relationalai/early_access/metamodel/rewrite/__init__.py +5 -3
- relationalai/early_access/rel/rewrite/__init__.py +1 -1
- relationalai/environments/snowbook.py +10 -1
- relationalai/errors.py +24 -3
- relationalai/semantics/internal/annotations.py +1 -0
- relationalai/semantics/internal/internal.py +22 -3
- relationalai/semantics/lqp/builtins.py +1 -0
- relationalai/semantics/lqp/executor.py +12 -4
- relationalai/semantics/lqp/model2lqp.py +1 -0
- relationalai/semantics/lqp/passes.py +3 -4
- relationalai/semantics/{rel → lqp}/rewrite/__init__.py +6 -0
- relationalai/semantics/metamodel/builtins.py +12 -1
- relationalai/semantics/metamodel/executor.py +2 -1
- relationalai/semantics/metamodel/rewrite/__init__.py +3 -9
- relationalai/semantics/metamodel/rewrite/flatten.py +8 -7
- relationalai/semantics/reasoners/graph/core.py +1356 -258
- relationalai/semantics/rel/builtins.py +5 -1
- relationalai/semantics/rel/compiler.py +3 -3
- relationalai/semantics/rel/executor.py +20 -11
- relationalai/semantics/sql/compiler.py +2 -3
- relationalai/semantics/sql/executor/duck_db.py +8 -4
- relationalai/semantics/sql/executor/snowflake.py +1 -1
- relationalai/tools/cli.py +17 -6
- relationalai/tools/cli_controls.py +334 -352
- relationalai/tools/constants.py +1 -0
- relationalai/tools/query_utils.py +27 -0
- relationalai/util/otel_configuration.py +1 -1
- {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/METADATA +5 -4
- {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/RECORD +45 -45
- relationalai/semantics/metamodel/rewrite/gc_nodes.py +0 -58
- relationalai/semantics/metamodel/rewrite/list_types.py +0 -109
- /relationalai/semantics/{rel → lqp}/rewrite/cdc.py +0 -0
- /relationalai/semantics/{rel → lqp}/rewrite/extract_common.py +0 -0
- /relationalai/semantics/{metamodel → lqp}/rewrite/extract_keys.py +0 -0
- /relationalai/semantics/{metamodel → lqp}/rewrite/fd_constraints.py +0 -0
- /relationalai/semantics/{rel → lqp}/rewrite/quantify_vars.py +0 -0
- /relationalai/semantics/{metamodel → lqp}/rewrite/splinter.py +0 -0
- {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/WHEEL +0 -0
- {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/entry_points.txt +0 -0
- {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/licenses/LICENSE +0 -0

relationalai/clients/use_index_poller.py
CHANGED

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable, Dict, Optional, List, cast, TYPE_CHECKING
 import json
 import logging
@@ -5,7 +7,11 @@ import uuid
 
 from relationalai import debugging
 from relationalai.clients.cache_store import GraphIndexCache
-from relationalai.clients.util import
+from relationalai.clients.util import (
+    get_pyrel_version,
+    normalize_datetime,
+    poll_with_specified_overhead,
+)
 from relationalai.errors import (
     ERPNotRunningError,
     EngineProvisioningFailed,
@@ -29,6 +35,7 @@ from relationalai.tools.constants import WAIT_FOR_STREAM_SYNC, Generation
 # Set up logger for this module
 logger = logging.getLogger(__name__)
 
+
 try:
     from rich.console import Console
     from rich.table import Table
@@ -63,49 +70,49 @@ POLL_MAX_DELAY = 2.5  # Maximum delay between polls in seconds
 # This query calculates a hash of column metadata (name, type, precision, scale, nullable)
 # to detect if source table schema has changed since stream was created
 STREAM_COLUMN_HASH_QUERY = """
-
-
-
-
-            value:name::VARCHAR
+WITH stream_columns AS (
+    SELECT
+        fq_object_name,
+        HASH(
+            value:name::VARCHAR,
             CASE
-                WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL
-                    THEN
-
-
-
-
-
-
-                    THEN
-
-
-
-
-
-
-                    THEN
-
-
-                    WHEN 'TEXT' THEN 'TEXT'
-                    ELSE value:type::VARCHAR
-                END || '(' || value:length || ')'
+                WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL THEN CASE value:type::VARCHAR
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END || '(' || value:precision || ',' || value:scale || ')'
+                WHEN value:precision IS NOT NULL AND value:scale IS NULL THEN CASE value:type::VARCHAR
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END || '(0,' || value:precision || ')'
+                WHEN value:length IS NOT NULL THEN CASE value:type::VARCHAR
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END || '(' || value:length || ')'
                 ELSE CASE value:type::VARCHAR
-
-
-
-
-
-                END
-
-
-
-
-
-
-
-
-
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END
+            END,
+            IFF(value:nullable::BOOLEAN, 'YES', 'NO')
+        ) AS column_signature
+    FROM {app_name}.api.data_streams,
+        LATERAL FLATTEN(input => columns)
+    WHERE rai_database = '{rai_database}'
+        AND fq_object_name IN ({fqn_list})
+)
+SELECT
+    fq_object_name AS FQ_OBJECT_NAME,
+    HEX_ENCODE(HASH_AGG(column_signature)) AS STREAM_HASH
+FROM stream_columns
+GROUP BY fq_object_name;
 """
 
 
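The rewritten query builds one signature per column (name, normalized type with precision/scale or length, nullability) and folds them per table with HASH_AGG, so schema drift surfaces as a changed STREAM_HASH. A rough Python sketch of the same idea, useful for reasoning about what the query compares (hypothetical helpers; not the library's actual hashing scheme):

    import hashlib

    def column_signature(name, type_, precision=None, scale=None, length=None, nullable=True):
        # Normalize the Snowflake type name, then append precision/scale or length,
        # mirroring the CASE expression in STREAM_COLUMN_HASH_QUERY.
        base = {"FIXED": "NUMBER", "REAL": "FLOAT", "TEXT": "TEXT"}.get(type_, type_)
        if precision is not None and scale is not None:
            rendered = f"{base}({precision},{scale})"
        elif precision is not None:
            rendered = f"{base}(0,{precision})"
        elif length is not None:
            rendered = f"{base}({length})"
        else:
            rendered = base
        return f"{name}|{rendered}|{'YES' if nullable else 'NO'}"

    def table_hash(columns):
        # Order-insensitive aggregate over column signatures, analogous to HASH_AGG.
        digests = sorted(hashlib.sha256(column_signature(**c).encode()).hexdigest() for c in columns)
        return hashlib.sha256("".join(digests).encode()).hexdigest()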
@@ -296,9 +303,10 @@ class UseIndexPoller:
         Returns:
             List of truly stale sources that need to be deleted/recreated
 
-        A source is truly stale if:
-        - The stream doesn't exist (needs to be created)
-        - The
+        A source is truly stale if any of the following apply:
+        - The stream doesn't exist (needs to be created)
+        - The source table was recreated after the stream (table creation timestamp is newer)
+        - The column hashes don't match (schema drift needs cleanup)
         """
         stream_hashes = self._get_stream_column_hashes(stale_sources, progress)
 
@@ -306,14 +314,30 @@ class UseIndexPoller:
         for source in stale_sources:
             source_hash = self.source_info[source].get("columns_hash")
             stream_hash = stream_hashes.get(source)
+            table_created_at_raw = self.source_info[source].get("table_created_at")
+            stream_created_at_raw = self.source_info[source].get("stream_created_at")
+
+            table_created_at = normalize_datetime(table_created_at_raw)
+            stream_created_at = normalize_datetime(stream_created_at_raw)
+
+            recreated_table = False
+            if table_created_at is not None and stream_created_at is not None:
+                # If the source table was recreated (new creation timestamp) but kept
+                # the same column definitions, we still need to recycle the stream so
+                # that Snowflake picks up the new table instance.
+                recreated_table = table_created_at > stream_created_at
 
             # Log hash comparison for debugging
             logger.debug(f"Source: {source}")
             logger.debug(f"  Source table hash: {source_hash}")
             logger.debug(f"  Stream hash: {stream_hash}")
             logger.debug(f"  Match: {source_hash == stream_hash}")
+            if recreated_table:
+                logger.debug("  Table appears to have been recreated (table_created_at > stream_created_at)")
+                logger.debug(f"  table_created_at: {table_created_at}")
+                logger.debug(f"  stream_created_at: {stream_created_at}")
 
-            if stream_hash is None or source_hash != stream_hash:
+            if stream_hash is None or source_hash != stream_hash or recreated_table:
                 logger.debug("  Action: DELETE (stale)")
                 truly_stale.append(source)
             else:
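Combined with the docstring above, the staleness check reduces to a three-way predicate. A condensed sketch (hypothetical standalone form; it assumes both timestamps were already normalized to UTC, e.g. via normalize_datetime below):

    from datetime import datetime

    def is_truly_stale(info: dict, stream_hash: str | None) -> bool:
        """True if the stream must be deleted and recreated."""
        table_ts = info.get("table_created_at")
        stream_ts = info.get("stream_created_at")
        recreated = (
            isinstance(table_ts, datetime)
            and isinstance(stream_ts, datetime)
            and table_ts > stream_ts  # table recreated after the stream
        )
        return (
            stream_hash is None                         # stream does not exist yet
            or info.get("columns_hash") != stream_hash  # schema drift
            or recreated                                # table instance replaced
        )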
@@ -376,7 +400,7 @@ class UseIndexPoller:
         stale_sources = [
             source
             for source, info in self.source_info.items()
-            if info
+            if info.get("state") == "STALE"
         ]
 
         if not stale_sources:
@@ -763,7 +787,7 @@ class UseIndexPoller:
                 # Log the error for debugging
                 logger.warning(f"Failed to enable change tracking on {fqn}: {e}")
                 failed_tables.append((fqn, str(e)))
-
+
             # Handle errors based on subtask type
             if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
                 # Mark the individual subtask as failed and complete it
relationalai/clients/util.py
CHANGED

@@ -80,6 +80,15 @@ def escape_for_f_string(code: str) -> str:
 def escape_for_sproc(code: str) -> str:
     return code.replace("$$", "\\$\\$")
 
+
+def normalize_datetime(value: object) -> datetime | None:
+    """Return a timezone-aware UTC datetime or None."""
+    if not isinstance(value, datetime):
+        return None
+    if value.tzinfo is None:
+        return value.replace(tzinfo=timezone.utc)
+    return value.astimezone(timezone.utc)
+
 # @NOTE: `overhead_rate` should fall between 0.05 and 0.5 depending on how time sensitive / expensive the operation in question is.
 def poll_with_specified_overhead(
     f,
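A quick illustration of the intended behavior of the new helper (illustrative values; assumes it is imported from relationalai.clients.util as added above):

    from datetime import datetime, timedelta, timezone
    from relationalai.clients.util import normalize_datetime

    # Naive timestamps are treated as UTC and tagged accordingly.
    normalize_datetime(datetime(2024, 1, 1, 12, 0))
    # -> 2024-01-01 12:00:00+00:00

    # Aware timestamps are converted to UTC.
    normalize_datetime(datetime(2024, 1, 1, 7, 0, tzinfo=timezone(timedelta(hours=-5))))
    # -> 2024-01-01 12:00:00+00:00

    # Non-datetime values yield None.
    normalize_datetime("2024-01-01")
    # -> None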
relationalai/dsl.py
CHANGED

@@ -22,6 +22,7 @@ import sys
 from pandas import DataFrame
 
 from relationalai.environments import runtime_env, SnowbookEnvironment
+from relationalai.tools.constants import QUERY_ATTRIBUTES_HEADER
 
 from .clients.client import Client
 
@@ -34,9 +35,7 @@ from .errors import FilterAsValue, Errors, InvalidPropertySetException, Multiple
 #--------------------------------------------------
 
 RESERVED_PROPS = ["add", "set", "persist", "unpersist"]
-
 MAX_QUERY_ATTRIBUTE_LENGTH = 255
-QUERY_ATTRIBUTES_HEADER = "X-Query-Attributes"
 
 Value = Union[
     "Expression",
relationalai/early_access/metamodel/rewrite/__init__.py
CHANGED

@@ -1,5 +1,7 @@
-from relationalai.semantics.metamodel.rewrite import
-    DNFUnionSplitter,
+from relationalai.semantics.metamodel.rewrite import Flatten, \
+    DNFUnionSplitter, ExtractNestedLogicals, flatten
+from relationalai.semantics.lqp.rewrite import Splinter, \
+    ExtractKeys, FDConstraints
 
-__all__ = ["Splinter", "
+__all__ = ["Splinter", "Flatten", "DNFUnionSplitter", "ExtractKeys",
            "ExtractNestedLogicals", "FDConstraints", "flatten"]
relationalai/environments/snowbook.py
CHANGED

@@ -20,7 +20,16 @@ class SnowbookEnvironment(NotebookRuntimeEnvironment, SessionEnvironment):
 
     def __init__(self):
         super().__init__()
-
+        # Detect runner type based on module presence:
+        # - Warehouse runtime has '_snowflake' module
+        # - Container runtime has 'snowflake._legacy' module
+        if "_snowflake" in sys.modules:
+            self.runner = "warehouse"
+        elif "snowflake._legacy" in sys.modules:
+            self.runner = "container"
+        else:
+            # Fallback to original check
+            self.runner = "container" if "snowflake.connector.auth" in sys.modules else "warehouse"
 
     @classmethod
     def detect(cls):
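Since the runner detection is just a probe of sys.modules, it can be exercised without a Snowflake notebook runtime. A hypothetical test sketch (module names taken from the code above):

    import sys
    import types

    def detect_runner() -> str:
        # Same probe order as SnowbookEnvironment.__init__ above.
        if "_snowflake" in sys.modules:
            return "warehouse"
        if "snowflake._legacy" in sys.modules:
            return "container"
        return "container" if "snowflake.connector.auth" in sys.modules else "warehouse"

    sys.modules.setdefault("_snowflake", types.ModuleType("_snowflake"))
    assert detect_runner() == "warehouse"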
relationalai/errors.py
CHANGED

@@ -2397,17 +2397,18 @@ class UnsupportedColumnTypesWarning(RAIWarning):
         """)
 
 class QueryTimeoutExceededException(RAIException):
-    def __init__(self, timeout_mins: int, config_file_path: str | None = None):
+    def __init__(self, timeout_mins: int, query_id: str | None = None, config_file_path: str | None = None):
         self.timeout_mins = timeout_mins
-        self.message = f"Query execution time exceeded the specified timeout of {timeout_mins} minutes."
         self.name = "Query Timeout Exceeded"
+        self.message = f"Query execution time exceeded the specified timeout of {self.timeout_mins} minutes."
+        self.query_id = query_id or ""
         self.config_file_path = config_file_path or ""
         self.content = self.format_message()
         super().__init__(self.message, self.name, self.content)
 
     def format_message(self):
         return textwrap.dedent(f"""
-            {self.
+            Query execution time exceeded the specified timeout of {self.timeout_mins} minutes{f' for query with ID: {self.query_id}' if self.query_id else ''}.
 
             Consider increasing the 'query_timeout_mins' parameter in your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} to allow more time for query execution.
         """)

@@ -2432,3 +2433,23 @@ class AzureUnsupportedQueryTimeoutException(RAIException):
             Please remove the 'query_timeout_mins' from your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} when running on platform Azure.
         """)
 
+class AzureLegacyDependencyMissingException(RAIException):
+    def __init__(self):
+        self.message = "The Azure platform requires the 'legacy' extras to be installed."
+        self.name = "Azure Legacy Dependency Missing"
+        self.content = self.format_message()
+        super().__init__(self.message, self.name, self.content)
+
+    def format_message(self):
+        return textwrap.dedent("""
+            The Azure platform requires the 'rai-sdk' package, which is not installed.
+
+            To use the Azure platform, please install the legacy extras:
+
+                pip install relationalai[legacy]
+
+            Or if upgrading an existing installation:
+
+                pip install --upgrade relationalai[legacy]
+            """)
+
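For orientation, the new optional query_id only changes the formatted content, not the short message. A sketch of the difference (hypothetical values; assumes the class is imported from relationalai.errors):

    from relationalai.errors import QueryTimeoutExceededException

    err = QueryTimeoutExceededException(timeout_mins=30, query_id="01b2c3d4")
    print(err.message)
    # Query execution time exceeded the specified timeout of 30 minutes.
    print("01b2c3d4" in err.content)
    # True: the query ID appears in the detailed content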
relationalai/semantics/internal/internal.py
CHANGED

@@ -2338,6 +2338,7 @@ class Fragment():
         self._define.extend(parent._define)
         self._order_by.extend(parent._order_by)
         self._limit = parent._limit
+        self._meta.update(parent._meta)
 
     def _add_items(self, items:PySequence[Any], to_attr:list[Any]):
         # TODO: ensure that you are _either_ a select, require, or then

@@ -2416,9 +2417,26 @@ class Fragment():
         return f
 
     def meta(self, **kwargs: Any) -> Fragment:
+        """Add metadata to the query.
+
+        Metadata can be used for debugging and observability purposes.
+
+        Args:
+            **kwargs: Metadata key-value pairs
+
+        Returns:
+            Fragment: Returns self for method chaining
+
+        Example:
+            select(Person.name).meta(workload_name="test", priority=1, enabled=True)
+        """
+        if not kwargs:
+            raise ValueError("meta() requires at least one argument")
+
         self._meta.update(kwargs)
         return self
 
+
     def annotate(self, *annos:Expression|Relationship|ir.Annotation) -> Fragment:
         self._annotations.extend(annos)
         return self

@@ -2497,7 +2515,7 @@ class Fragment():
         # @TODO for now we set tag to None but we need to work out how to properly propagate user-provided tag here
         with debugging.span("query", tag=None, dsl=str(self), **with_source(self), meta=self._meta) as query_span:
             query_task = qb_model._compiler.fragment(self)
-            results = qb_model._to_executor().execute(ir_model, query_task)
+            results = qb_model._to_executor().execute(ir_model, query_task, meta=self._meta)
             query_span["results"] = results
             # For local debugging mostly
             dry_run = qb_model._dry_run or bool(qb_model._config.get("compiler.dry_run", False))

@@ -2524,7 +2542,7 @@ class Fragment():
         # @TODO for now we set tag to None but we need to work out how to properly propagate user-provided tag here
         with debugging.span("query", tag=None, dsl=str(clone), **with_source(clone), meta=clone._meta) as query_span:
             query_task = qb_model._compiler.fragment(clone)
-            results = qb_model._to_executor().execute(ir_model, query_task, format="snowpark")
+            results = qb_model._to_executor().execute(ir_model, query_task, format="snowpark", meta=clone._meta)
             query_span["alt_format_results"] = results
             return results
 

@@ -2541,7 +2559,8 @@ class Fragment():
         clone._source = runtime_env.get_source_pos()
         with debugging.span("query", dsl=str(clone), **with_source(clone), meta=clone._meta):
             query_task = qb_model._compiler.fragment(clone)
-            qb_model._to_executor().execute(ir_model, query_task, result_cols=result_cols, export_to=table._fqn, update=update)
+            qb_model._to_executor().execute(ir_model, query_task, result_cols=result_cols, export_to=table._fqn, update=update, meta=clone._meta)
+
 
 #--------------------------------------------------
 # Select / Where
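Taken together with the _copy_from and execute changes, query metadata now flows from the DSL surface down to every executor call. A minimal stand-in sketch of the chaining and validation behavior added to meta() (a toy class, not the real Fragment):

    from typing import Any

    class FragmentSketch:
        """Toy stand-in for Fragment's meta() chaining; the real class lives in internal.py."""
        def __init__(self) -> None:
            self._meta: dict[str, Any] = {}

        def meta(self, **kwargs: Any) -> "FragmentSketch":
            if not kwargs:
                raise ValueError("meta() requires at least one argument")
            self._meta.update(kwargs)
            return self

    frag = FragmentSketch().meta(workload_name="nightly_report", priority=1)
    assert frag._meta == {"workload_name": "nightly_report", "priority": 1}
    # The executor then receives this mapping: execute(ir_model, query_task, meta=frag._meta)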
relationalai/semantics/lqp/executor.py
CHANGED

@@ -21,8 +21,8 @@ from relationalai.clients.config import Config
 from relationalai.clients.snowflake import APP_NAME
 from relationalai.clients.types import TransactionAsyncResponse
 from relationalai.clients.util import IdentityParser
-from relationalai.tools.constants import USE_DIRECT_ACCESS
-
+from relationalai.tools.constants import USE_DIRECT_ACCESS, QUERY_ATTRIBUTES_HEADER
+from relationalai.tools.query_utils import prepare_metadata_for_headers
 
 class LQPExecutor(e.Executor):
     """Executes LQP using the RAI client."""

@@ -267,7 +267,7 @@ class LQPExecutor(e.Executor):
         if ivm_flag:
             config_dict['ivm.maintenance_level'] = lqp_ir.Value(value=ivm_flag, meta=None)
         return construct_configure(config_dict, None)
-
+
     def _compile_intrinsics(self) -> lqp_ir.Epoch:
         """Construct an epoch that defines a number of built-in definitions used by the
         emitter."""

@@ -344,6 +344,10 @@ class LQPExecutor(e.Executor):
         df, errs = result_helpers.format_results(raw_results, cols)
         self.report_errors(errs)
 
+        # Rename columns if wide outputs is enabled
+        if self.wide_outputs and len(cols) - len(extra_cols) == len(df.columns):
+            df.columns = cols[: len(df.columns)]
+
         # Process exports
         if export_to and not self.dry_run:
             assert cols, "No columns found in the output"

@@ -362,7 +366,7 @@ class LQPExecutor(e.Executor):
 
     def execute(self, model: ir.Model, task: ir.Task, format: Literal["pandas", "snowpark"] = "pandas",
                 result_cols: Optional[list[str]] = None, export_to: Optional[str] = None,
-                update: bool = False) -> DataFrame:
+                update: bool = False, meta: dict[str, Any] | None = None) -> DataFrame:
         self.prepare_data()
         previous_model = self._last_model
 

@@ -374,6 +378,9 @@ class LQPExecutor(e.Executor):
         if format != "pandas":
             raise ValueError(f"Unsupported format: {format}")
 
+        # Format meta as headers
+        json_meta = prepare_metadata_for_headers(meta)
+        headers = {QUERY_ATTRIBUTES_HEADER: json_meta} if json_meta else {}
         raw_results = self.resources.exec_lqp(
             self.database,
             self.engine,

@@ -383,6 +390,7 @@ class LQPExecutor(e.Executor):
             # transactions are serialized.
             readonly=False,
             nowait_durable=True,
+            headers=headers,
         )
         assert isinstance(raw_results, TransactionAsyncResponse)
 
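The executor forwards the query metadata as an HTTP-style header. prepare_metadata_for_headers lives in the new relationalai/tools/query_utils.py, whose body is not shown in this diff; a hedged sketch assuming it JSON-encodes the mapping:

    import json

    QUERY_ATTRIBUTES_HEADER = "X-Query-Attributes"  # value previously defined in dsl.py, now in tools.constants

    def prepare_metadata_for_headers_sketch(meta: dict | None) -> str | None:
        # Hypothetical stand-in for relationalai.tools.query_utils.prepare_metadata_for_headers.
        if not meta:
            return None
        return json.dumps(meta, default=str)

    json_meta = prepare_metadata_for_headers_sketch({"workload_name": "nightly_report", "priority": 1})
    headers = {QUERY_ATTRIBUTES_HEADER: json_meta} if json_meta else {}
    # -> {'X-Query-Attributes': '{"workload_name": "nightly_report", "priority": 1}'}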
relationalai/semantics/lqp/model2lqp.py
CHANGED

@@ -102,6 +102,7 @@ def _get_export_reads(export_ids: list[tuple[lqp.RelationId, int, lqp.Type]]) ->
         data_columns=csv_columns,
         compression="gzip",
         partition_size=200,
+        syntax_escapechar='"',  # To follow Snowflake's expected format
         meta=None,
     )
     reads.append(lqp.Read(read_type=lqp.Export(config=export_csv_config, meta=None), meta=None))
relationalai/semantics/lqp/passes.py
CHANGED

@@ -2,13 +2,12 @@ from relationalai.semantics.metamodel.compiler import Pass
 from relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
 from relationalai.semantics.metamodel.typer import Checker, InferTypes, typer
 from relationalai.semantics.metamodel import helpers, types
-from relationalai.semantics.metamodel.rewrite import (Splinter, ExtractNestedLogicals, ExtractKeys, FDConstraints,
-                                                      DNFUnionSplitter, DischargeConstraints)
 from relationalai.semantics.metamodel.util import FrozenOrderedSet
 
 from relationalai.semantics.metamodel.rewrite import Flatten
-
-from
+
+from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals
+from .rewrite import CDC, ExtractCommon, ExtractKeys, FDConstraints, QuantifyVars, Splinter
 
 from relationalai.semantics.lqp.utils import output_names
 
relationalai/semantics/{rel → lqp}/rewrite/__init__.py
CHANGED

@@ -1,9 +1,15 @@
 from .cdc import CDC
 from .extract_common import ExtractCommon
+from .extract_keys import ExtractKeys
+from .fd_constraints import FDConstraints
 from .quantify_vars import QuantifyVars
+from .splinter import Splinter
 
 __all__ = [
     "CDC",
     "ExtractCommon",
+    "ExtractKeys",
+    "FDConstraints",
     "QuantifyVars",
+    "Splinter",
 ]
relationalai/semantics/metamodel/builtins.py
CHANGED

@@ -496,6 +496,17 @@ function = f.relation("function", [f.input_field("code", types.Symbol)])
 function_checked_annotation = f.annotation(function, [f.lit("checked")])
 function_annotation = f.annotation(function, [])
 
+# Indicates this relation should be tracked in telemetry. Only supported for Relationships.
+# `RAI_BackIR.with_relation_tracking` produces log messages at the start and end of each
+# SCC evaluation, if any declarations bear the `track` annotation.
+track = f.relation("track", [
+    # BackIR evaluation expects 2 parameters on the track annotation: the tracking
+    # library name and tracking relation name, which appear as log metadata fields.
+    f.input_field("library", types.Symbol),
+    f.input_field("relation", types.Symbol)
+])
+track_annotation = f.annotation(track, [])
+
 # All ir nodes marked by this annotation will be removed from the final metamodel before compilation.
 # Specifically it happens in `Flatten` pass when rewrites for `require` happen
 discharged = f.relation("discharged", [])

@@ -672,7 +683,7 @@ def _compute_builtin_overloads() -> list[ir.Relation]:
     return overloads
 
 # manually maintain the list of relations that are actually annotations
-builtin_annotations = [external, export, concept_population, from_cdc, from_cast]
+builtin_annotations = [external, export, concept_population, from_cdc, from_cast, track]
 builtin_annotations_by_name = dict((r.name, r) for r in builtin_annotations)
 
 builtin_relations = _compute_builtin_relations()
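The surrounding code in this file shows the usage pattern: f.relation defines an annotation's relation and f.annotation wraps it, optionally with literal arguments (compare function_checked_annotation above). A hedged sketch of attaching track with its two symbol parameters (the call site is illustrative, not part of this diff):

    # Hypothetical: annotate a declaration so BackIR logs its SCC evaluation.
    my_track_annotation = f.annotation(track, [f.lit("my_library"), f.lit("my_relation")])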
relationalai/semantics/metamodel/rewrite/__init__.py
CHANGED

@@ -1,12 +1,6 @@
-from .
-from .list_types import RewriteListTypes
-from .gc_nodes import GarbageCollectNodes
-from .flatten import Flatten
+from .discharge_constraints import DischargeConstraints
 from .dnf_union_splitter import DNFUnionSplitter
-from .extract_keys import ExtractKeys
 from .extract_nested_logicals import ExtractNestedLogicals
-from .
-from .discharge_constraints import DischargeConstraints
+from .flatten import Flatten
 
-__all__ = ["
-           "ExtractNestedLogicals", "FDConstraints", "DischargeConstraints"]
+__all__ = ["DischargeConstraints", "DNFUnionSplitter", "ExtractNestedLogicals", "Flatten"]
relationalai/semantics/metamodel/rewrite/flatten.py
CHANGED

@@ -558,22 +558,23 @@ class Flatten(Pass):
     def rewrite_wide_output(self, output: ir.Output):
         assert(output.keys)
 
-        # only
-
+        # only append keys that are not already in the output
+        suffix_keys = []
         for key in output.keys:
             if all([val is not key for _, val in output.aliases]):
-
+                suffix_keys.append(key)
 
         aliases: OrderedSet[Tuple[str, ir.Value]] = ordered_set()
-        # add the keys to the output
-        for key in prefix_keys:
-            aliases.add((key.name, key))
 
         # add the remaining args, unless it is already a key
         for name, val in output.aliases:
-            if not isinstance(val, ir.Var) or val not in
+            if not isinstance(val, ir.Var) or val not in suffix_keys:
                 aliases.add((name, val))
 
+        # add the keys to the output
+        for key in suffix_keys:
+            aliases.add((key.name, key))
+
         # TODO - we are assuming that the Rel compiler will translate nullable lookups
         # properly, returning a `Missing` if necessary, like this:
         # (nested_192(_adult, _adult_name) or (not nested_192(_adult, _) and _adult_name = Missing)) and
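The behavioral change in rewrite_wide_output is ordering: output keys that are not already projected were previously prepended (the old prefix_keys loop) and are now appended after the other columns (suffix_keys). A toy sketch of the new ordering on plain tuples (illustrative names, not the IR types):

    def order_wide_output(aliases, keys):
        # Keys not already present among the output values get appended at the end.
        suffix_keys = [k for k in keys if all(v != k for _, v in aliases)]
        ordered = [(name, val) for name, val in aliases if val not in suffix_keys]
        ordered += [(k, k) for k in suffix_keys]
        return ordered

    order_wide_output([("name", "_name"), ("age", "_age")], keys=["_person", "_name"])
    # -> [('name', '_name'), ('age', '_age'), ('_person', '_person')]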