relationalai 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- relationalai/clients/direct_access_client.py +5 -0
- relationalai/clients/snowflake.py +259 -91
- relationalai/clients/types.py +4 -1
- relationalai/clients/use_index_poller.py +96 -55
- relationalai/clients/util.py +9 -0
- relationalai/dsl.py +1 -2
- relationalai/environments/snowbook.py +10 -1
- relationalai/experimental/solvers.py +283 -79
- relationalai/semantics/internal/internal.py +24 -5
- relationalai/semantics/lqp/executor.py +22 -6
- relationalai/semantics/lqp/model2lqp.py +4 -2
- relationalai/semantics/metamodel/executor.py +2 -1
- relationalai/semantics/metamodel/rewrite/flatten.py +8 -7
- relationalai/semantics/reasoners/graph/core.py +1174 -226
- relationalai/semantics/rel/executor.py +30 -12
- relationalai/semantics/sql/executor/snowflake.py +1 -1
- relationalai/tools/cli.py +6 -2
- relationalai/tools/cli_controls.py +334 -352
- relationalai/tools/constants.py +1 -0
- relationalai/tools/query_utils.py +27 -0
- relationalai/util/otel_configuration.py +1 -1
- {relationalai-0.12.0.dist-info → relationalai-0.12.2.dist-info}/METADATA +1 -1
- {relationalai-0.12.0.dist-info → relationalai-0.12.2.dist-info}/RECORD +26 -25
- {relationalai-0.12.0.dist-info → relationalai-0.12.2.dist-info}/WHEEL +0 -0
- {relationalai-0.12.0.dist-info → relationalai-0.12.2.dist-info}/entry_points.txt +0 -0
- {relationalai-0.12.0.dist-info → relationalai-0.12.2.dist-info}/licenses/LICENSE +0 -0
relationalai/clients/use_index_poller.py
CHANGED

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable, Dict, Optional, List, cast, TYPE_CHECKING
 import json
 import logging
@@ -5,7 +7,11 @@ import uuid

 from relationalai import debugging
 from relationalai.clients.cache_store import GraphIndexCache
-from relationalai.clients.util import
+from relationalai.clients.util import (
+    get_pyrel_version,
+    normalize_datetime,
+    poll_with_specified_overhead,
+)
 from relationalai.errors import (
     ERPNotRunningError,
     EngineProvisioningFailed,
@@ -29,6 +35,7 @@ from relationalai.tools.constants import WAIT_FOR_STREAM_SYNC, Generation
 # Set up logger for this module
 logger = logging.getLogger(__name__)

+
 try:
     from rich.console import Console
     from rich.table import Table
@@ -63,49 +70,49 @@ POLL_MAX_DELAY = 2.5 # Maximum delay between polls in seconds
 # This query calculates a hash of column metadata (name, type, precision, scale, nullable)
 # to detect if source table schema has changed since stream was created
 STREAM_COLUMN_HASH_QUERY = """
-
-
-
-
-        value:name::VARCHAR
+WITH stream_columns AS (
+    SELECT
+        fq_object_name,
+        HASH(
+            value:name::VARCHAR,
             CASE
-            WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL
-                THEN
-
-
-
-
-
-                THEN
-
-
-
-
-
-                THEN
-
-
-                    WHEN 'TEXT' THEN 'TEXT'
-                    ELSE value:type::VARCHAR
-                END || '(' || value:length || ')'
+                WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL THEN CASE value:type::VARCHAR
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END || '(' || value:precision || ',' || value:scale || ')'
+                WHEN value:precision IS NOT NULL AND value:scale IS NULL THEN CASE value:type::VARCHAR
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END || '(0,' || value:precision || ')'
+                WHEN value:length IS NOT NULL THEN CASE value:type::VARCHAR
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END || '(' || value:length || ')'
             ELSE CASE value:type::VARCHAR
-
-
-
-
-            END
-
-
-
-
-
-
-
-
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END
+            END,
+            IFF(value:nullable::BOOLEAN, 'YES', 'NO')
+        ) AS column_signature
+    FROM {app_name}.api.data_streams,
+        LATERAL FLATTEN(input => columns)
+    WHERE rai_database = '{rai_database}'
+        AND fq_object_name IN ({fqn_list})
+)
+SELECT
+    fq_object_name AS FQ_OBJECT_NAME,
+    HEX_ENCODE(HASH_AGG(column_signature)) AS STREAM_HASH
+FROM stream_columns
+GROUP BY fq_object_name;
 """
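The rewritten STREAM_COLUMN_HASH_QUERY builds one signature per column (name, a rendered type such as NUMBER(p,s), FLOAT, or TEXT(n), and nullability), hashes each signature, and folds the per-column hashes into a single per-table value with the order-insensitive HASH_AGG. A minimal Python sketch of the same idea, assuming a simplified column-metadata dict; the names here are illustrative, not the package's API:

    import hashlib

    # Subset of the internal-type mapping from the SQL CASE expressions.
    TYPE_MAP = {"FIXED": "NUMBER", "REAL": "FLOAT", "TEXT": "TEXT"}

    def column_signature(col: dict) -> str:
        # Mirror the SQL: render the type with precision/scale or length,
        # then append nullability.
        t = TYPE_MAP.get(col["type"], col["type"])
        if col.get("precision") is not None and col.get("scale") is not None:
            rendered = f"{t}({col['precision']},{col['scale']})"
        elif col.get("precision") is not None:
            rendered = f"{t}(0,{col['precision']})"
        elif col.get("length") is not None:
            rendered = f"{t}({col['length']})"
        else:
            rendered = t
        nullable = "YES" if col.get("nullable") else "NO"
        return f"{col['name']}|{rendered}|{nullable}"

    def schema_hash(columns: list[dict]) -> str:
        # Sorting makes the result independent of column order, like HASH_AGG.
        sigs = sorted(column_signature(c) for c in columns)
        return hashlib.sha256("\n".join(sigs).encode()).hexdigest()

Two schemas produce the same hash exactly when every column's name, rendered type, and nullability agree, which is the property the poller relies on to decide whether an existing stream can be reused.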
@@ -154,8 +161,9 @@ class UseIndexPoller:
         model: str,
         engine_name: str,
         engine_size: Optional[str],
-
-
+        language: str = "rel",
+        program_span_id: Optional[str] = None,
+        headers: Optional[Dict] = None,
         generation: Optional[Generation] = None,
     ):
         self.res = resource
@@ -164,6 +172,7 @@ class UseIndexPoller:
         self.model = model
         self.engine_name = engine_name
         self.engine_size = engine_size or self.res.config.get_default_engine_size()
+        self.language = language
         self.program_span_id = program_span_id
         self.headers = headers or {}
         self.counter = 1
@@ -183,8 +192,8 @@ class UseIndexPoller:
         )
         current_user = self.res.get_sf_session().get_current_user()
         assert current_user is not None, "current_user must be set"
-        data_freshness = self.res.config.get_data_freshness_mins()
-        self.cache = GraphIndexCache(current_user, model, data_freshness, self.sources)
+        self.data_freshness = self.res.config.get_data_freshness_mins()
+        self.cache = GraphIndexCache(current_user, model, self.data_freshness, self.sources)
         self.sources = self.cache.choose_sources()
         # execution_id is allowed to group use_index call, which belongs to the same loop iteration
         self.execution_id = str(uuid.uuid4())
@@ -296,9 +305,10 @@ class UseIndexPoller:
         Returns:
             List of truly stale sources that need to be deleted/recreated

-        A source is truly stale if:
-        - The stream doesn't exist (needs to be created)
-        - The
+        A source is truly stale if any of the following apply:
+        - The stream doesn't exist (needs to be created)
+        - The source table was recreated after the stream (table creation timestamp is newer)
+        - The column hashes don't match (schema drift needs cleanup)
         """
         stream_hashes = self._get_stream_column_hashes(stale_sources, progress)

@@ -306,14 +316,30 @@ class UseIndexPoller:
         for source in stale_sources:
             source_hash = self.source_info[source].get("columns_hash")
             stream_hash = stream_hashes.get(source)
+            table_created_at_raw = self.source_info[source].get("table_created_at")
+            stream_created_at_raw = self.source_info[source].get("stream_created_at")
+
+            table_created_at = normalize_datetime(table_created_at_raw)
+            stream_created_at = normalize_datetime(stream_created_at_raw)
+
+            recreated_table = False
+            if table_created_at is not None and stream_created_at is not None:
+                # If the source table was recreated (new creation timestamp) but kept
+                # the same column definitions, we still need to recycle the stream so
+                # that Snowflake picks up the new table instance.
+                recreated_table = table_created_at > stream_created_at

             # Log hash comparison for debugging
             logger.debug(f"Source: {source}")
             logger.debug(f"  Source table hash: {source_hash}")
             logger.debug(f"  Stream hash: {stream_hash}")
             logger.debug(f"  Match: {source_hash == stream_hash}")
+            if recreated_table:
+                logger.debug("  Table appears to have been recreated (table_created_at > stream_created_at)")
+                logger.debug(f"  table_created_at: {table_created_at}")
+                logger.debug(f"  stream_created_at: {stream_created_at}")

-            if stream_hash is None or source_hash != stream_hash:
+            if stream_hash is None or source_hash != stream_hash or recreated_table:
                 logger.debug("  Action: DELETE (stale)")
                 truly_stale.append(source)
             else:
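Taken together, the 0.12.2 staleness check deletes and recreates a stream when it is missing, when the column hash drifted, or when the source table was recreated after the stream. A condensed restatement of that predicate as a pure function (an illustration that follows the diff, not the package's actual code):

    from datetime import datetime
    from typing import Optional

    def is_truly_stale(
        source_hash: Optional[str],
        stream_hash: Optional[str],
        table_created_at: Optional[datetime],
        stream_created_at: Optional[datetime],
    ) -> bool:
        # Missing stream: it has to be created from scratch.
        if stream_hash is None:
            return True
        # Schema drift: the column-metadata hashes no longer match.
        if source_hash != stream_hash:
            return True
        # Recreated table: same columns, but a newer table instance.
        if table_created_at is not None and stream_created_at is not None:
            return table_created_at > stream_created_at
        return False

Normalizing both timestamps to UTC first (via normalize_datetime) is what makes the `>` comparison safe; comparing a naive datetime against an aware one raises a TypeError.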
@@ -376,7 +402,7 @@ class UseIndexPoller:
         stale_sources = [
             source
             for source, info in self.source_info.items()
-            if info
+            if info.get("state") == "STALE"
         ]

         if not stale_sources:
@@ -462,6 +488,8 @@ class UseIndexPoller:
             "wait_for_stream_sync": self.wait_for_stream_sync,
             "should_check_cdc": self.should_check_cdc,
             "init_engine_async": self.init_engine_async,
+            "language": self.language,
+            "data_freshness_mins": self.data_freshness,
         })

         request_headers = debugging.add_current_propagation_headers(self.headers)
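With these two keys added, the use_index request body carries both the query language and the caller's freshness budget alongside the existing flags. The relevant slice of the payload might look like this (values are illustrative, not taken from a real request):

    payload.update({
        "wait_for_stream_sync": True,       # existing flag
        "should_check_cdc": True,           # existing flag
        "init_engine_async": False,         # existing flag
        "language": "rel",                  # new in 0.12.2; defaults to "rel"
        "data_freshness_mins": 60,          # new in 0.12.2; from config
    })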
@@ -763,7 +791,7 @@ class UseIndexPoller:
                 # Log the error for debugging
                 logger.warning(f"Failed to enable change tracking on {fqn}: {e}")
                 failed_tables.append((fqn, str(e)))
-
+
                 # Handle errors based on subtask type
                 if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
                     # Mark the individual subtask as failed and complete it
@@ -829,11 +857,23 @@ class DirectUseIndexPoller(UseIndexPoller):
         model: str,
         engine_name: str,
         engine_size: Optional[str],
-
-
+        language: str = "rel",
+        program_span_id: Optional[str] = None,
+        headers: Optional[Dict] = None,
         generation: Optional[Generation] = None,
     ):
-        super().__init__(
+        super().__init__(
+            resource=resource,
+            app_name=app_name,
+            sources=sources,
+            model=model,
+            engine_name=engine_name,
+            engine_size=engine_size,
+            language=language,
+            program_span_id=program_span_id,
+            headers=headers,
+            generation=generation,
+        )
         from relationalai.clients.snowflake import DirectAccessResources
         self.res: DirectAccessResources = cast(DirectAccessResources, self.res)

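The override now forwards every argument to the base class by keyword instead of positionally. A small sketch of why that matters when a base __init__ grows a new defaulted parameter (hypothetical classes, not the package's):

    class Base:
        def __init__(self, model: str, language: str = "rel", headers: dict | None = None):
            self.model = model
            self.language = language
            self.headers = headers or {}

    class Direct(Base):
        def __init__(self, model: str, headers: dict | None = None):
            # Positional forwarding, super().__init__(model, headers), would bind
            # `headers` to the newly inserted `language` parameter; keyword
            # forwarding cannot drift when parameters are added or reordered.
            super().__init__(model=model, headers=headers)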
@@ -854,6 +894,7 @@ class DirectUseIndexPoller(UseIndexPoller):
             model=self.model,
             engine_name=self.engine_name,
             engine_size=self.engine_size,
+            language=self.language,
             rai_relations=[],
             pyrel_program_id=self.program_span_id,
             skip_pull_relations=True,
relationalai/clients/util.py
CHANGED

@@ -80,6 +80,15 @@ def escape_for_f_string(code: str) -> str:
 def escape_for_sproc(code: str) -> str:
     return code.replace("$$", "\\$\\$")

+
+def normalize_datetime(value: object) -> datetime | None:
+    """Return a timezone-aware UTC datetime or None."""
+    if not isinstance(value, datetime):
+        return None
+    if value.tzinfo is None:
+        return value.replace(tzinfo=timezone.utc)
+    return value.astimezone(timezone.utc)
+
 # @NOTE: `overhead_rate` should fall between 0.05 and 0.5 depending on how time sensitive / expensive the operation in question is.
 def poll_with_specified_overhead(
     f,
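normalize_datetime has three behaviors worth calling out: non-datetime inputs collapse to None, naive datetimes are assumed to already be in UTC (tagged, not shifted), and aware datetimes are converted to UTC. A quick usage sketch:

    from datetime import datetime, timedelta, timezone
    from relationalai.clients.util import normalize_datetime

    # Naive input: tagged as UTC without shifting the wall-clock value.
    assert normalize_datetime(datetime(2024, 1, 1, 12, 0)) == \
        datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)

    # Aware input: converted, so 12:00 at UTC+2 becomes 10:00 UTC.
    assert normalize_datetime(datetime(2024, 1, 1, 12, 0, tzinfo=timezone(timedelta(hours=2)))) == \
        datetime(2024, 1, 1, 10, 0, tzinfo=timezone.utc)

    # Anything else, e.g. a string timestamp, collapses to None.
    assert normalize_datetime("2024-01-01T12:00:00Z") is None

This is what lets the poller compare table and stream creation timestamps that may arrive naive from one metadata source and timezone-aware from another.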
relationalai/dsl.py
CHANGED

@@ -22,6 +22,7 @@ import sys
 from pandas import DataFrame

 from relationalai.environments import runtime_env, SnowbookEnvironment
+from relationalai.tools.constants import QUERY_ATTRIBUTES_HEADER

 from .clients.client import Client

@@ -34,9 +35,7 @@ from .errors import FilterAsValue, Errors, InvalidPropertySetException, Multiple
 #--------------------------------------------------

 RESERVED_PROPS = ["add", "set", "persist", "unpersist"]
-
 MAX_QUERY_ATTRIBUTE_LENGTH = 255
-QUERY_ATTRIBUTES_HEADER = "X-Query-Attributes"

 Value = Union[
     "Expression",
relationalai/environments/snowbook.py
CHANGED

@@ -20,7 +20,16 @@ class SnowbookEnvironment(NotebookRuntimeEnvironment, SessionEnvironment):

     def __init__(self):
         super().__init__()
-
+        # Detect runner type based on module presence:
+        # - Warehouse runtime has '_snowflake' module
+        # - Container runtime has 'snowflake._legacy' module
+        if "_snowflake" in sys.modules:
+            self.runner = "warehouse"
+        elif "snowflake._legacy" in sys.modules:
+            self.runner = "container"
+        else:
+            # Fallback to original check
+            self.runner = "container" if "snowflake.connector.auth" in sys.modules else "warehouse"

     @classmethod
     def detect(cls):
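Detection order matters here: the dedicated runtime markers win, and the pre-0.12.2 heuristic only runs as a fallback. A testable restatement (detect_runner and its explicit modules parameter are illustrative, not the package's API):

    import sys
    from typing import Container

    def detect_runner(modules: Container[str] = sys.modules) -> str:
        # Explicit runtime markers first.
        if "_snowflake" in modules:
            return "warehouse"   # warehouse runtime exposes '_snowflake'
        if "snowflake._legacy" in modules:
            return "container"   # container runtime exposes 'snowflake._legacy'
        # Pre-0.12.2 heuristic, kept as the fallback.
        return "container" if "snowflake.connector.auth" in modules else "warehouse"

    assert detect_runner({"snowflake._legacy"}) == "container"
    assert detect_runner(set()) == "warehouse"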