relationalai 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable, Dict, Optional, List, cast, TYPE_CHECKING
 import json
 import logging
@@ -5,7 +7,11 @@ import uuid
 
 from relationalai import debugging
 from relationalai.clients.cache_store import GraphIndexCache
-from relationalai.clients.util import get_pyrel_version, poll_with_specified_overhead
+from relationalai.clients.util import (
+    get_pyrel_version,
+    normalize_datetime,
+    poll_with_specified_overhead,
+)
 from relationalai.errors import (
     ERPNotRunningError,
     EngineProvisioningFailed,
@@ -29,6 +35,7 @@ from relationalai.tools.constants import WAIT_FOR_STREAM_SYNC, Generation
 # Set up logger for this module
 logger = logging.getLogger(__name__)
 
+
 try:
     from rich.console import Console
     from rich.table import Table
@@ -63,49 +70,49 @@ POLL_MAX_DELAY = 2.5 # Maximum delay between polls in seconds
 # This query calculates a hash of column metadata (name, type, precision, scale, nullable)
 # to detect if source table schema has changed since stream was created
 STREAM_COLUMN_HASH_QUERY = """
-SELECT
-    FQ_OBJECT_NAME,
-    SHA2(
-        LISTAGG(
-            value:name::VARCHAR ||
+WITH stream_columns AS (
+    SELECT
+        fq_object_name,
+        HASH(
+            value:name::VARCHAR,
             CASE
-                WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL
-                THEN CASE value:type::VARCHAR
-                    WHEN 'FIXED' THEN 'NUMBER'
-                    WHEN 'REAL' THEN 'FLOAT'
-                    WHEN 'TEXT' THEN 'TEXT'
-                    ELSE value:type::VARCHAR
-                END || '(' || value:precision || ',' || value:scale || ')'
-                WHEN value:precision IS NOT NULL AND value:scale IS NULL
-                THEN CASE value:type::VARCHAR
-                    WHEN 'FIXED' THEN 'NUMBER'
-                    WHEN 'REAL' THEN 'FLOAT'
-                    WHEN 'TEXT' THEN 'TEXT'
-                    ELSE value:type::VARCHAR
-                END || '(0,' || value:precision || ')'
-                WHEN value:length IS NOT NULL
-                THEN CASE value:type::VARCHAR
-                    WHEN 'FIXED' THEN 'NUMBER'
-                    WHEN 'REAL' THEN 'FLOAT'
-                    WHEN 'TEXT' THEN 'TEXT'
-                    ELSE value:type::VARCHAR
-                END || '(' || value:length || ')'
+                WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL THEN CASE value:type::VARCHAR
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END || '(' || value:precision || ',' || value:scale || ')'
+                WHEN value:precision IS NOT NULL AND value:scale IS NULL THEN CASE value:type::VARCHAR
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END || '(0,' || value:precision || ')'
+                WHEN value:length IS NOT NULL THEN CASE value:type::VARCHAR
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END || '(' || value:length || ')'
                 ELSE CASE value:type::VARCHAR
-                    WHEN 'FIXED' THEN 'NUMBER'
-                    WHEN 'REAL' THEN 'FLOAT'
-                    WHEN 'TEXT' THEN 'TEXT'
-                    ELSE value:type::VARCHAR
-                END
-            END ||
-            CASE WHEN value:nullable::BOOLEAN THEN 'YES' ELSE 'NO' END,
-            ','
-        ) WITHIN GROUP (ORDER BY value:name::VARCHAR),
-        256
-    ) AS STREAM_HASH
-FROM {app_name}.api.data_streams,
-    LATERAL FLATTEN(input => COLUMNS) f
-WHERE RAI_DATABASE = '{rai_database}' AND FQ_OBJECT_NAME IN ({fqn_list})
-GROUP BY FQ_OBJECT_NAME;
+                    WHEN 'FIXED' THEN 'NUMBER'
+                    WHEN 'REAL' THEN 'FLOAT'
+                    WHEN 'TEXT' THEN 'TEXT'
+                    ELSE value:type::VARCHAR
+                END
+            END,
+            IFF(value:nullable::BOOLEAN, 'YES', 'NO')
+        ) AS column_signature
+    FROM {app_name}.api.data_streams,
+        LATERAL FLATTEN(input => columns)
+    WHERE rai_database = '{rai_database}'
+        AND fq_object_name IN ({fqn_list})
+)
+SELECT
+    fq_object_name AS FQ_OBJECT_NAME,
+    HEX_ENCODE(HASH_AGG(column_signature)) AS STREAM_HASH
+FROM stream_columns
+GROUP BY fq_object_name;
 """
 
 
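Note on the rewrite above (annotation, not part of the diff): the 0.12.0 query built one canonical string per table with LISTAGG(...) WITHIN GROUP (ORDER BY ...) and hashed it with SHA2, while 0.12.2 hashes each column's signature with HASH() and folds the rows together with HASH_AGG(), which is order-insensitive, so the explicit ORDER BY is no longer needed. A minimal Python sketch of the same idea, using SHA-256 digests combined with XOR as a stand-in for Snowflake's HASH/HASH_AGG (all names here are illustrative):

```python
import hashlib

def column_signature(col: dict) -> str:
    """Canonical NAME|TYPE(P,S)|NULLABLE string, echoing the SQL CASE ladder."""
    base = {"FIXED": "NUMBER", "REAL": "FLOAT"}.get(col["type"], col["type"])
    if col.get("precision") is not None and col.get("scale") is not None:
        typ = f"{base}({col['precision']},{col['scale']})"
    elif col.get("precision") is not None:
        typ = f"{base}(0,{col['precision']})"
    elif col.get("length") is not None:
        typ = f"{base}({col['length']})"
    else:
        typ = base
    return f"{col['name']}|{typ}|{'YES' if col.get('nullable') else 'NO'}"

def schema_fingerprint(columns: list) -> int:
    # XOR of per-column digests is commutative, so column order cannot change
    # the result -- the same property HASH_AGG gives the SQL version.
    fp = 0
    for col in columns:
        digest = hashlib.sha256(column_signature(col).encode()).digest()
        fp ^= int.from_bytes(digest[:8], "big")
    return fp

cols = [
    {"name": "ID", "type": "FIXED", "precision": 38, "scale": 0, "nullable": False},
    {"name": "NAME", "type": "TEXT", "length": 16, "nullable": True},
]
assert schema_fingerprint(cols) == schema_fingerprint(list(reversed(cols)))
```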
@@ -154,8 +161,9 @@ class UseIndexPoller:
         model: str,
         engine_name: str,
         engine_size: Optional[str],
-        program_span_id: Optional[str],
-        headers: Optional[Dict],
+        language: str = "rel",
+        program_span_id: Optional[str] = None,
+        headers: Optional[Dict] = None,
         generation: Optional[Generation] = None,
     ):
         self.res = resource
@@ -164,6 +172,7 @@ class UseIndexPoller:
         self.model = model
         self.engine_name = engine_name
         self.engine_size = engine_size or self.res.config.get_default_engine_size()
+        self.language = language
         self.program_span_id = program_span_id
         self.headers = headers or {}
         self.counter = 1
@@ -183,8 +192,8 @@ class UseIndexPoller:
         )
         current_user = self.res.get_sf_session().get_current_user()
         assert current_user is not None, "current_user must be set"
-        data_freshness = self.res.config.get_data_freshness_mins()
-        self.cache = GraphIndexCache(current_user, model, data_freshness, self.sources)
+        self.data_freshness = self.res.config.get_data_freshness_mins()
+        self.cache = GraphIndexCache(current_user, model, self.data_freshness, self.sources)
         self.sources = self.cache.choose_sources()
         # execution_id is allowed to group use_index call, which belongs to the same loop iteration
         self.execution_id = str(uuid.uuid4())
@@ -296,9 +305,10 @@ class UseIndexPoller:
         Returns:
             List of truly stale sources that need to be deleted/recreated
 
-        A source is truly stale if:
-        - The stream doesn't exist (needs to be created), OR
-        - The column hashes don't match (needs to be recreated)
+        A source is truly stale if any of the following apply:
+        - The stream doesn't exist (needs to be created)
+        - The source table was recreated after the stream (table creation timestamp is newer)
+        - The column hashes don't match (schema drift needs cleanup)
         """
         stream_hashes = self._get_stream_column_hashes(stale_sources, progress)
 
@@ -306,14 +316,30 @@ class UseIndexPoller:
         for source in stale_sources:
             source_hash = self.source_info[source].get("columns_hash")
             stream_hash = stream_hashes.get(source)
+            table_created_at_raw = self.source_info[source].get("table_created_at")
+            stream_created_at_raw = self.source_info[source].get("stream_created_at")
+
+            table_created_at = normalize_datetime(table_created_at_raw)
+            stream_created_at = normalize_datetime(stream_created_at_raw)
+
+            recreated_table = False
+            if table_created_at is not None and stream_created_at is not None:
+                # If the source table was recreated (new creation timestamp) but kept
+                # the same column definitions, we still need to recycle the stream so
+                # that Snowflake picks up the new table instance.
+                recreated_table = table_created_at > stream_created_at
 
             # Log hash comparison for debugging
             logger.debug(f"Source: {source}")
             logger.debug(f"  Source table hash: {source_hash}")
             logger.debug(f"  Stream hash: {stream_hash}")
             logger.debug(f"  Match: {source_hash == stream_hash}")
+            if recreated_table:
+                logger.debug("  Table appears to have been recreated (table_created_at > stream_created_at)")
+                logger.debug(f"  table_created_at: {table_created_at}")
+                logger.debug(f"  stream_created_at: {stream_created_at}")
 
-            if stream_hash is None or source_hash != stream_hash:
+            if stream_hash is None or source_hash != stream_hash or recreated_table:
                 logger.debug("  Action: DELETE (stale)")
                 truly_stale.append(source)
             else:
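Distilled, the staleness test above reduces to a three-clause predicate. A self-contained restatement (hypothetical helper, not code from the package):

```python
from datetime import datetime
from typing import Optional

def is_truly_stale(
    stream_hash: Optional[str],
    source_hash: Optional[str],
    table_created_at: Optional[datetime],
    stream_created_at: Optional[datetime],
) -> bool:
    # 1. No stream yet: it must be created.
    if stream_hash is None:
        return True
    # 2. Column metadata drifted: the stream must be recreated.
    if source_hash != stream_hash:
        return True
    # 3. Table recreated after the stream (both timestamps known): same
    #    columns, but a new table instance, so the stream is recycled too.
    if table_created_at is not None and stream_created_at is not None:
        return table_created_at > stream_created_at
    return False
```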
@@ -376,7 +402,7 @@ class UseIndexPoller:
         stale_sources = [
             source
             for source, info in self.source_info.items()
-            if info["state"] == "STALE"
+            if info.get("state") == "STALE"
         ]
 
         if not stale_sources:
@@ -462,6 +488,8 @@ class UseIndexPoller:
             "wait_for_stream_sync": self.wait_for_stream_sync,
             "should_check_cdc": self.should_check_cdc,
             "init_engine_async": self.init_engine_async,
+            "language": self.language,
+            "data_freshness_mins": self.data_freshness,
         })
 
         request_headers = debugging.add_current_propagation_headers(self.headers)
@@ -763,7 +791,7 @@ class UseIndexPoller:
                 # Log the error for debugging
                 logger.warning(f"Failed to enable change tracking on {fqn}: {e}")
                 failed_tables.append((fqn, str(e)))
-
+
         # Handle errors based on subtask type
         if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
             # Mark the individual subtask as failed and complete it
@@ -829,11 +857,23 @@ class DirectUseIndexPoller(UseIndexPoller):
         model: str,
         engine_name: str,
         engine_size: Optional[str],
-        program_span_id: Optional[str],
-        headers: Optional[Dict],
+        language: str = "rel",
+        program_span_id: Optional[str] = None,
+        headers: Optional[Dict] = None,
         generation: Optional[Generation] = None,
     ):
-        super().__init__(resource, app_name, sources, model, engine_name, engine_size, program_span_id, headers, generation)
+        super().__init__(
+            resource=resource,
+            app_name=app_name,
+            sources=sources,
+            model=model,
+            engine_name=engine_name,
+            engine_size=engine_size,
+            language=language,
+            program_span_id=program_span_id,
+            headers=headers,
+            generation=generation,
+        )
         from relationalai.clients.snowflake import DirectAccessResources
         self.res: DirectAccessResources = cast(DirectAccessResources, self.res)
 
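Worth noting: switching the `super().__init__(...)` call from positional to keyword arguments is what makes the new parameter safe. Because `language` was inserted before `program_span_id` in the base signature, the old positional call would have silently passed `program_span_id` as `language` and shifted every later argument by one.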
@@ -854,6 +894,7 @@ class DirectUseIndexPoller(UseIndexPoller):
             model=self.model,
             engine_name=self.engine_name,
             engine_size=self.engine_size,
+            language=self.language,
             rai_relations=[],
             pyrel_program_id=self.program_span_id,
             skip_pull_relations=True,
@@ -80,6 +80,15 @@ def escape_for_f_string(code: str) -> str:
 def escape_for_sproc(code: str) -> str:
     return code.replace("$$", "\\$\\$")
 
+
+def normalize_datetime(value: object) -> datetime | None:
+    """Return a timezone-aware UTC datetime or None."""
+    if not isinstance(value, datetime):
+        return None
+    if value.tzinfo is None:
+        return value.replace(tzinfo=timezone.utc)
+    return value.astimezone(timezone.utc)
+
 # @NOTE: `overhead_rate` should fall between 0.05 and 0.5 depending on how time sensitive / expensive the operation in question is.
 def poll_with_specified_overhead(
     f,
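For context on why `normalize_datetime` exists: Python raises `TypeError` when ordering naive and aware datetimes, and the table/stream creation timestamps compared in the poller can arrive as either kind. A quick illustration, assuming the function is imported as in the hunk at the top of this diff:

```python
from datetime import datetime, timezone
from relationalai.clients.util import normalize_datetime

naive = datetime(2024, 5, 1, 12, 0)                       # e.g. a TIMESTAMP_NTZ value arrives naive
aware = datetime(2024, 5, 1, 10, 0, tzinfo=timezone.utc)  # e.g. a TIMESTAMP_TZ value arrives aware

# naive > aware  # TypeError: can't compare offset-naive and offset-aware datetimes
assert normalize_datetime(naive) > normalize_datetime(aware)  # both UTC-aware now
assert normalize_datetime("2024-05-01") is None  # non-datetime inputs map to None
```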
relationalai/dsl.py CHANGED
@@ -22,6 +22,7 @@ import sys
 from pandas import DataFrame
 
 from relationalai.environments import runtime_env, SnowbookEnvironment
+from relationalai.tools.constants import QUERY_ATTRIBUTES_HEADER
 
 from .clients.client import Client
 
@@ -34,9 +35,7 @@ from .errors import FilterAsValue, Errors, InvalidPropertySetException, Multiple
 #--------------------------------------------------
 
 RESERVED_PROPS = ["add", "set", "persist", "unpersist"]
-
 MAX_QUERY_ATTRIBUTE_LENGTH = 255
-QUERY_ATTRIBUTES_HEADER = "X-Query-Attributes"
 
 Value = Union[
     "Expression",
@@ -20,7 +20,16 @@ class SnowbookEnvironment(NotebookRuntimeEnvironment, SessionEnvironment):
 
     def __init__(self):
         super().__init__()
-        self.runner = "container" if "snowflake.connector.auth" in sys.modules else "warehouse"
+        # Detect runner type based on module presence:
+        # - Warehouse runtime has '_snowflake' module
+        # - Container runtime has 'snowflake._legacy' module
+        if "_snowflake" in sys.modules:
+            self.runner = "warehouse"
+        elif "snowflake._legacy" in sys.modules:
+            self.runner = "container"
+        else:
+            # Fallback to original check
+            self.runner = "container" if "snowflake.connector.auth" in sys.modules else "warehouse"
 
     @classmethod
     def detect(cls):
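Annotation: the detection order in this hunk matters. The warehouse marker `_snowflake` is checked before the container marker `snowflake._legacy`, with the 0.12.0 heuristic kept as a last resort. A rough standalone sketch of the same precedence (hypothetical `detect_runner` helper, not the package's API):

```python
import sys

def detect_runner(modules=sys.modules) -> str:
    # Mirrors the diff: warehouse marker first, then container marker,
    # then the original snowflake.connector.auth heuristic as fallback.
    if "_snowflake" in modules:
        return "warehouse"
    if "snowflake._legacy" in modules:
        return "container"
    return "container" if "snowflake.connector.auth" in modules else "warehouse"

assert detect_runner({"_snowflake": object()}) == "warehouse"
assert detect_runner({"snowflake._legacy": object()}) == "container"
assert detect_runner({"snowflake.connector.auth": object()}) == "container"
assert detect_runner({}) == "warehouse"  # fallback default
```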