relationalai 0.12.0__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
@@ -41,7 +41,7 @@ from ..clients.types import AvailableModel, EngineState, Import, ImportSource, I
  from ..clients.config import Config, ConfigStore, ENDPOINT_FILE
  from ..clients.client import Client, ExportParams, ProviderBase, ResourcesBase
  from ..clients.direct_access_client import DirectAccessClient
- from ..clients.util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, ms_to_timestamp
+ from ..clients.util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, ms_to_timestamp, normalize_datetime
  from ..environments import runtime_env, HexEnvironment, SnowbookEnvironment
  from .. import dsl, rel, metamodel as m
  from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
@@ -1867,7 +1867,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
  except Exception as e:
  err_message = str(e).lower()
  if _is_engine_issue(err_message):
- self.auto_create_engine(engine)
+ self.auto_create_engine(engine, headers=headers)
  self._exec_async_v2(
  database, engine, raw_code_b64, inputs, readonly, nowait_durable,
  headers=headers, bypass_index=bypass_index, language='lqp',
@@ -1907,7 +1907,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
  except Exception as e:
  err_message = str(e).lower()
  if _is_engine_issue(err_message):
- self.auto_create_engine(engine)
+ self.auto_create_engine(engine, headers=headers)
  return self._exec_async_v2(
  database,
  engine,
@@ -1970,9 +1970,9 @@ Otherwise, remove it from your '{profile}' configuration profile.
  if use_graph_index:
  # we do not provide a default value for query_timeout_mins so that we can control the default on app level
  if query_timeout_mins is not None:
- res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, ?, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
+ res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
  else:
- res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
+ res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
  txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
  rejected_rows = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows", [])
  rejected_rows_count = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows_count", 0)
@@ -2047,9 +2047,10 @@ Otherwise, remove it from your '{profile}' configuration profile.
  app_name = self.get_app_name()

  source_types = dict[str, SourceInfo]()
- partitioned_sources: dict[str, dict[str, list[str]]] = defaultdict(
+ partitioned_sources: dict[str, dict[str, list[dict[str, str]]]] = defaultdict(
  lambda: defaultdict(list)
  )
+ fqn_to_parts: dict[str, tuple[str, str, str]] = {}

  for source in sources:
  parser = IdentityParser(source, True)
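For orientation, the loop over `sources` (continued in the next hunk) fills these new structures roughly as follows; a sketch with made-up table names, not output from the package:

    # Hypothetical contents after parsing two sources; all names are illustrative.
    partitioned_sources = {
        "SALES_DB": {
            "PUBLIC": [
                {"entity": "ORDERS", "identity": "SALES_DB.PUBLIC.ORDERS"},
                {"entity": "CUSTOMERS", "identity": "SALES_DB.PUBLIC.CUSTOMERS"},
            ],
        },
    }
    fqn_to_parts = {
        "SALES_DB.PUBLIC.ORDERS": ("SALES_DB", "PUBLIC", "ORDERS"),
        "SALES_DB.PUBLIC.CUSTOMERS": ("SALES_DB", "PUBLIC", "CUSTOMERS"),
    }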
@@ -2057,82 +2058,219 @@ Otherwise, remove it from your '{profile}' configuration profile.
  assert len(parsed) == 4, f"Invalid source: {source}"
  db, schema, entity, identity = parsed
  assert db and schema and entity and identity, f"Invalid source: {source}"
- source_types[identity] = cast(SourceInfo, {"type": None, "state": "", "columns_hash": None})
- partitioned_sources[db][schema].append(entity)
-
- # TODO: Move to NA layer
- query = (
- " UNION ALL ".join(
- f"""SELECT
- inf.FQN,
- inf.KIND,
- inf.COLUMNS_HASH,
- IFF(DATEDIFF(second, ds.created_at::TIMESTAMP, inf.LAST_DDL::TIMESTAMP) > 0, 'STALE', 'CURRENT') AS STATE
- FROM (
- SELECT (SELECT {app_name}.api.normalize_fq_ids(ARRAY_CONSTRUCT(FQ_OBJECT_NAME))[0]:identifier::string) as FQ_OBJECT_NAME,
- CREATED_AT FROM {app_name}.api.data_streams
- WHERE RAI_DATABASE = '{PYREL_ROOT_DB}'
- ) ds
- RIGHT JOIN (
+ source_types[identity] = cast(
+ SourceInfo,
+ {
+ "type": None,
+ "state": "",
+ "columns_hash": None,
+ "table_created_at": None,
+ "stream_created_at": None,
+ "last_ddl": None,
+ },
+ )
+ partitioned_sources[db][schema].append({"entity": entity, "identity": identity})
+ fqn_to_parts[identity] = (db, schema, entity)
+
+ if not partitioned_sources:
+ return source_types
+
+ state_queries: list[str] = []
+ for db, schemas in partitioned_sources.items():
+ select_rows: list[str] = []
+ for schema, tables in schemas.items():
+ for table_info in tables:
+ select_rows.append(
+ "SELECT "
+ f"{IdentityParser.to_sql_value(db)} AS catalog_name, "
+ f"{IdentityParser.to_sql_value(schema)} AS schema_name, "
+ f"{IdentityParser.to_sql_value(table_info['entity'])} AS table_name"
+ )
+
+ if not select_rows:
+ continue
+
+ target_entities_clause = "\n UNION ALL\n ".join(select_rows)
+ # Main query:
+ # 1. Enumerate the target tables via target_entities.
+ # 2. Pull their metadata (last_altered, type) from INFORMATION_SCHEMA.TABLES.
+ # 3. Look up the most recent stream activity for those FQNs only.
+ # 4. Capture creation timestamps and use last_ddl vs created_at to classify each target,
+ # so we mark tables as stale when they were recreated even if column hashes still match.
+ state_queries.append(
+ f"""WITH target_entities AS (
+ {target_entities_clause}
+ ),
+ table_info AS (
+ SELECT
+ {app_name}.api.normalize_fq_ids(
+ ARRAY_CONSTRUCT(
+ CASE
+ WHEN t.table_catalog = UPPER(t.table_catalog) THEN t.table_catalog
+ ELSE '"' || t.table_catalog || '"'
+ END || '.' ||
+ CASE
+ WHEN t.table_schema = UPPER(t.table_schema) THEN t.table_schema
+ ELSE '"' || t.table_schema || '"'
+ END || '.' ||
+ CASE
+ WHEN t.table_name = UPPER(t.table_name) THEN t.table_name
+ ELSE '"' || t.table_name || '"'
+ END
+ )
+ )[0]:identifier::string AS fqn,
+ CONVERT_TIMEZONE('UTC', t.last_altered) AS last_ddl,
+ CONVERT_TIMEZONE('UTC', t.created) AS table_created_at,
+ t.table_type AS kind
+ FROM {db}.INFORMATION_SCHEMA.tables t
+ JOIN target_entities te
+ ON t.table_catalog = te.catalog_name
+ AND t.table_schema = te.schema_name
+ AND t.table_name = te.table_name
+ ),
+ stream_activity AS (
+ SELECT
+ sa.fqn,
+ MAX(sa.created_at) AS created_at
+ FROM (
+ SELECT
+ {app_name}.api.normalize_fq_ids(ARRAY_CONSTRUCT(fq_object_name))[0]:identifier::string AS fqn,
+ created_at
+ FROM {app_name}.api.data_streams
+ WHERE rai_database = '{PYREL_ROOT_DB}'
+ ) sa
+ JOIN table_info ti
+ ON sa.fqn = ti.fqn
+ GROUP BY sa.fqn
+ )
  SELECT
- (SELECT {app_name}.api.normalize_fq_ids(
- ARRAY_CONSTRUCT(
- CASE
- WHEN t.TABLE_CATALOG = UPPER(t.TABLE_CATALOG) THEN t.TABLE_CATALOG
- ELSE '"' || t.TABLE_CATALOG || '"'
- END || '.' ||
- CASE
- WHEN t.TABLE_SCHEMA = UPPER(t.TABLE_SCHEMA) THEN t.TABLE_SCHEMA
- ELSE '"' || t.TABLE_SCHEMA || '"'
- END || '.' ||
- CASE
- WHEN t.TABLE_NAME = UPPER(t.TABLE_NAME) THEN t.TABLE_NAME
- ELSE '"' || t.TABLE_NAME || '"'
- END
- )
- )[0]:identifier::string) as FQN,
- CONVERT_TIMEZONE('UTC', LAST_DDL) AS LAST_DDL,
- TABLE_TYPE as KIND,
- SHA2(LISTAGG(
- COLUMN_NAME ||
- CASE
- WHEN c.NUMERIC_PRECISION IS NOT NULL AND c.NUMERIC_SCALE IS NOT NULL
- THEN c.DATA_TYPE || '(' || c.NUMERIC_PRECISION || ',' || c.NUMERIC_SCALE || ')'
- WHEN c.DATETIME_PRECISION IS NOT NULL
- THEN c.DATA_TYPE || '(0,' || c.DATETIME_PRECISION || ')'
- WHEN c.CHARACTER_MAXIMUM_LENGTH IS NOT NULL
- THEN c.DATA_TYPE || '(' || c.CHARACTER_MAXIMUM_LENGTH || ')'
- ELSE c.DATA_TYPE
- END ||
- IS_NULLABLE,
- ','
- ) WITHIN GROUP (ORDER BY COLUMN_NAME), 256) as COLUMNS_HASH
- FROM {db}.INFORMATION_SCHEMA.TABLES t
- JOIN {db}.INFORMATION_SCHEMA.COLUMNS c
- ON t.TABLE_CATALOG = c.TABLE_CATALOG
- AND t.TABLE_SCHEMA = c.TABLE_SCHEMA
- AND t.TABLE_NAME = c.TABLE_NAME
- WHERE t.TABLE_CATALOG = {IdentityParser.to_sql_value(db)} AND ({" OR ".join(
- f"(t.TABLE_SCHEMA = {IdentityParser.to_sql_value(schema)} AND t.TABLE_NAME IN ({','.join(f'{IdentityParser.to_sql_value(table)}' for table in tables)}))"
- for schema, tables in schemas.items()
- )})
- GROUP BY t.TABLE_CATALOG, t.TABLE_SCHEMA, t.TABLE_NAME, t.LAST_DDL, t.TABLE_TYPE
- ) inf on inf.FQN = ds.FQ_OBJECT_NAME
- """
- for db, schemas in partitioned_sources.items()
+ ti.fqn,
+ ti.kind,
+ ti.last_ddl,
+ ti.table_created_at,
+ sa.created_at AS stream_created_at,
+ IFF(
+ DATEDIFF(second, sa.created_at::timestamp, ti.last_ddl::timestamp) > 0,
+ 'STALE',
+ 'CURRENT'
+ ) AS state
+ FROM table_info ti
+ LEFT JOIN stream_activity sa
+ ON sa.fqn = ti.fqn
+ """
  )
- + ";"
+
+ stale_fqns: list[str] = []
+ for state_query in state_queries:
+ for row in self._exec(state_query):
+ row_dict = row.as_dict() if hasattr(row, "as_dict") else dict(row)
+ row_fqn = row_dict["FQN"]
+ parser = IdentityParser(row_fqn, True)
+ fqn = parser.identity
+ assert fqn, f"Error parsing returned FQN: {row_fqn}"
+
+ source_types[fqn]["type"] = (
+ "TABLE" if row_dict["KIND"] == "BASE TABLE" else row_dict["KIND"]
+ )
+ source_types[fqn]["state"] = row_dict["STATE"]
+ source_types[fqn]["last_ddl"] = normalize_datetime(row_dict.get("LAST_DDL"))
+ source_types[fqn]["table_created_at"] = normalize_datetime(row_dict.get("TABLE_CREATED_AT"))
+ source_types[fqn]["stream_created_at"] = normalize_datetime(row_dict.get("STREAM_CREATED_AT"))
+ if row_dict["STATE"] == "STALE":
+ stale_fqns.append(fqn)
+
+ if not stale_fqns:
+ return source_types
+
+ # We batch stale tables by database/schema so each Snowflake query can hash
+ # multiple objects at once instead of issuing one statement per table.
+ stale_partitioned: dict[str, dict[str, list[dict[str, str]]]] = defaultdict(
+ lambda: defaultdict(list)
  )
+ for fqn in stale_fqns:
+ db, schema, table = fqn_to_parts[fqn]
+ stale_partitioned[db][schema].append({"table": table, "identity": fqn})
+
+ # Build one hash query per database, grouping schemas/tables inside so we submit
+ # at most a handful of set-based statements to Snowflake.
+ for db, schemas in stale_partitioned.items():
+ column_select_rows: list[str] = []
+ for schema, tables in schemas.items():
+ for table_info in tables:
+ # Build the literal rows for this db/schema so we can join back
+ # against INFORMATION_SCHEMA.COLUMNS in a single statement.
+ column_select_rows.append(
+ "SELECT "
+ f"{IdentityParser.to_sql_value(db)} AS catalog_name, "
+ f"{IdentityParser.to_sql_value(schema)} AS schema_name, "
+ f"{IdentityParser.to_sql_value(table_info['table'])} AS table_name"
+ )

- for row in self._exec(query):
- row_fqn = row["FQN"]
- parser = IdentityParser(row_fqn, True)
- fqn = parser.identity
- assert fqn, f"Error parsing returned FQN: {row_fqn}"
+ if not column_select_rows:
+ continue

- source_types[fqn]["type"] = "TABLE" if row["KIND"] == "BASE TABLE" else row["KIND"]
- source_types[fqn]["columns_hash"] = row["COLUMNS_HASH"]
- source_types[fqn]["state"] = row["STATE"]
+ target_entities_clause = "\n UNION ALL\n ".join(column_select_rows)
+ # Main query: compute deterministic column hashes for every stale table
+ # in this database/schema batch so we can compare schemas without a round trip per table.
+ column_query = f"""WITH target_entities AS (
+ {target_entities_clause}
+ ),
+ column_info AS (
+ SELECT
+ {app_name}.api.normalize_fq_ids(
+ ARRAY_CONSTRUCT(
+ CASE
+ WHEN c.table_catalog = UPPER(c.table_catalog) THEN c.table_catalog
+ ELSE '"' || c.table_catalog || '"'
+ END || '.' ||
+ CASE
+ WHEN c.table_schema = UPPER(c.table_schema) THEN c.table_schema
+ ELSE '"' || c.table_schema || '"'
+ END || '.' ||
+ CASE
+ WHEN c.table_name = UPPER(c.table_name) THEN c.table_name
+ ELSE '"' || c.table_name || '"'
+ END
+ )
+ )[0]:identifier::string AS fqn,
+ c.column_name,
+ CASE
+ WHEN c.numeric_precision IS NOT NULL AND c.numeric_scale IS NOT NULL
+ THEN c.data_type || '(' || c.numeric_precision || ',' || c.numeric_scale || ')'
+ WHEN c.datetime_precision IS NOT NULL
+ THEN c.data_type || '(0,' || c.datetime_precision || ')'
+ WHEN c.character_maximum_length IS NOT NULL
+ THEN c.data_type || '(' || c.character_maximum_length || ')'
+ ELSE c.data_type
+ END AS type_signature,
+ IFF(c.is_nullable = 'YES', 'YES', 'NO') AS nullable_flag
+ FROM {db}.INFORMATION_SCHEMA.COLUMNS c
+ JOIN target_entities te
+ ON c.table_catalog = te.catalog_name
+ AND c.table_schema = te.schema_name
+ AND c.table_name = te.table_name
+ )
+ SELECT
+ fqn,
+ HEX_ENCODE(
+ HASH_AGG(
+ HASH(
+ column_name,
+ type_signature,
+ nullable_flag
+ )
+ )
+ ) AS columns_hash
+ FROM column_info
+ GROUP BY fqn
+ """
+
+ for row in self._exec(column_query):
+ row_fqn = row["FQN"]
+ parser = IdentityParser(row_fqn, True)
+ fqn = parser.identity
+ assert fqn, f"Error parsing returned FQN: {row_fqn}"
+ source_types[fqn]["columns_hash"] = row["COLUMNS_HASH"]

  return source_types

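The rewritten logic now works in two passes: a metadata query classifies each source as STALE or CURRENT by comparing the stream's creation time with the table's last DDL, and only the STALE sources get a second, batched column-hash query. When a matching stream exists, the IFF/DATEDIFF classification boils down to a plain timestamp comparison; roughly (an illustrative sketch, not code from the package):

    from datetime import datetime, timezone

    # Approximate Python equivalent of
    # IFF(DATEDIFF(second, stream_created_at, last_ddl) > 0, 'STALE', 'CURRENT')
    def state_for(stream_created_at: datetime, last_ddl: datetime) -> str:
        return "STALE" if (last_ddl - stream_created_at).total_seconds() > 0 else "CURRENT"

    assert state_for(
        datetime(2024, 5, 1, 9, 0, tzinfo=timezone.utc),   # stream created
        datetime(2024, 5, 1, 9, 5, tzinfo=timezone.utc),   # table altered afterwards
    ) == "STALE"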
@@ -2142,12 +2280,13 @@ Otherwise, remove it from your '{profile}' configuration profile.
  invalid_sources = {}
  source_references = []
  for source, info in source_info.items():
- if info['type'] is None:
+ source_type = info.get("type")
+ if source_type is None:
  missing_sources.append(source)
- elif info['type'] not in ("TABLE", "VIEW"):
- invalid_sources[source] = info['type']
+ elif source_type not in ("TABLE", "VIEW"):
+ invalid_sources[source] = source_type
  else:
- source_references.append(f"{app_name}.api.object_reference('{info['type']}', '{source}')")
+ source_references.append(f"{app_name}.api.object_reference('{source_type}', '{source}')")

  if missing_sources:
  current_role = self.get_sf_session().get_current_role()
@@ -3045,6 +3184,7 @@ class DirectAccessResources(Resources):
  headers: Dict[str, str] | None = None,
  path_params: Dict[str, str] | None = None,
  query_params: Dict[str, str] | None = None,
+ skip_auto_create: bool = False,
  ) -> requests.Response:
  with debugging.span("direct_access_request"):
  def _send_request():
@@ -3066,7 +3206,8 @@ class DirectAccessResources(Resources):
  )

  # fix engine on engine error and retry
- if _is_engine_issue(message):
+ # Skip auto-retry if skip_auto_create is True to avoid recursion
+ if _is_engine_issue(message) and not skip_auto_create:
  engine = payload.get("engine_name", "") if payload else ""
  self.auto_create_engine(engine)
  response = _send_request()
@@ -3431,7 +3572,7 @@ class DirectAccessResources(Resources):
  return sorted(engines, key=lambda x: x["name"])

  def get_engine(self, name: str):
- response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"})
+ response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"}, skip_auto_create=True)
  if response.status_code == 404: # engine not found return 404
  return None
  elif response.status_code != 200:
@@ -3478,6 +3619,7 @@ class DirectAccessResources(Resources):
  payload=payload,
  path_params={"engine_type": "logic"},
  headers=headers,
+ skip_auto_create=True,
  )
  if response.status_code != 200:
  raise ResponseStatusException(
@@ -3489,6 +3631,7 @@ class DirectAccessResources(Resources):
  "delete_engine",
  path_params={"engine_name": name, "engine_type": "logic"},
  headers=headers,
+ skip_auto_create=True,
  )
  if response.status_code != 200:
  raise ResponseStatusException(
@@ -3499,6 +3642,7 @@ class DirectAccessResources(Resources):
  response = self.request(
  "suspend_engine",
  path_params={"engine_name": name, "engine_type": "logic"},
+ skip_auto_create=True,
  )
  if response.status_code != 200:
  raise ResponseStatusException(
@@ -3510,6 +3654,7 @@ class DirectAccessResources(Resources):
  "resume_engine",
  path_params={"engine_name": name, "engine_type": "logic"},
  headers=headers,
+ skip_auto_create=True,
  )
  if response.status_code != 200:
  raise ResponseStatusException(
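All of the engine-management calls above opt out of the automatic engine-recovery retry. The guard works roughly as sketched below; this is illustrative only, and `_send` and `_is_engine_issue` stand in for internals the diff shows only partially:

    # Simplified shape of DirectAccessResources.request with the new flag.
    def request(self, name, payload=None, skip_auto_create=False, **kwargs):
        response = self._send(name, payload=payload, **kwargs)          # hypothetical helper
        if self._is_engine_issue(response) and not skip_auto_create:    # hypothetical helper
            # Ordinary requests may transparently (re)create the engine and retry,
            # but get/create/delete/suspend/resume pass skip_auto_create=True so an
            # engine error raised while managing engines cannot recurse back here.
            self.auto_create_engine(payload.get("engine_name", "") if payload else "")
            response = self._send(name, payload=payload, **kwargs)
        return response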
@@ -38,10 +38,13 @@ class EngineState(TypedDict):
  auto_suspend: int|None
  suspends_at: datetime|None

- class SourceInfo(TypedDict):
+ class SourceInfo(TypedDict, total=False):
  type: str|None
  state: str
  columns_hash: str|None
+ table_created_at: datetime|None
+ stream_created_at: datetime|None
+ last_ddl: datetime|None
  source: str


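With total=False every key becomes optional, so dictionaries that predate the new timestamp fields still type-check, and consumers read the optional keys with .get() (as the UseIndexPoller hunks below now do). A minimal sketch of the relaxed shape:

    from __future__ import annotations
    from datetime import datetime
    from typing import TypedDict

    # Field set mirrors the diff; total=False makes every key optional.
    class SourceInfo(TypedDict, total=False):
        type: str | None
        state: str
        columns_hash: str | None
        table_created_at: datetime | None
        stream_created_at: datetime | None
        last_ddl: datetime | None
        source: str

    info: SourceInfo = {"type": "TABLE", "state": "CURRENT"}  # missing keys are allowed
    assert info.get("stream_created_at") is None              # optional keys read via .get()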
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  from typing import Iterable, Dict, Optional, List, cast, TYPE_CHECKING
  import json
  import logging
@@ -5,7 +7,11 @@ import uuid

  from relationalai import debugging
  from relationalai.clients.cache_store import GraphIndexCache
- from relationalai.clients.util import get_pyrel_version, poll_with_specified_overhead
+ from relationalai.clients.util import (
+ get_pyrel_version,
+ normalize_datetime,
+ poll_with_specified_overhead,
+ )
  from relationalai.errors import (
  ERPNotRunningError,
  EngineProvisioningFailed,
@@ -29,6 +35,7 @@ from relationalai.tools.constants import WAIT_FOR_STREAM_SYNC, Generation
  # Set up logger for this module
  logger = logging.getLogger(__name__)

+
  try:
  from rich.console import Console
  from rich.table import Table
@@ -63,49 +70,49 @@ POLL_MAX_DELAY = 2.5 # Maximum delay between polls in seconds
  # This query calculates a hash of column metadata (name, type, precision, scale, nullable)
  # to detect if source table schema has changed since stream was created
  STREAM_COLUMN_HASH_QUERY = """
- SELECT
- FQ_OBJECT_NAME,
- SHA2(
- LISTAGG(
- value:name::VARCHAR ||
+ WITH stream_columns AS (
+ SELECT
+ fq_object_name,
+ HASH(
+ value:name::VARCHAR,
  CASE
- WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL
- THEN CASE value:type::VARCHAR
- WHEN 'FIXED' THEN 'NUMBER'
- WHEN 'REAL' THEN 'FLOAT'
- WHEN 'TEXT' THEN 'TEXT'
- ELSE value:type::VARCHAR
- END || '(' || value:precision || ',' || value:scale || ')'
- WHEN value:precision IS NOT NULL AND value:scale IS NULL
- THEN CASE value:type::VARCHAR
- WHEN 'FIXED' THEN 'NUMBER'
- WHEN 'REAL' THEN 'FLOAT'
- WHEN 'TEXT' THEN 'TEXT'
- ELSE value:type::VARCHAR
- END || '(0,' || value:precision || ')'
- WHEN value:length IS NOT NULL
- THEN CASE value:type::VARCHAR
- WHEN 'FIXED' THEN 'NUMBER'
- WHEN 'REAL' THEN 'FLOAT'
- WHEN 'TEXT' THEN 'TEXT'
- ELSE value:type::VARCHAR
- END || '(' || value:length || ')'
+ WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL THEN CASE value:type::VARCHAR
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END || '(' || value:precision || ',' || value:scale || ')'
+ WHEN value:precision IS NOT NULL AND value:scale IS NULL THEN CASE value:type::VARCHAR
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END || '(0,' || value:precision || ')'
+ WHEN value:length IS NOT NULL THEN CASE value:type::VARCHAR
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END || '(' || value:length || ')'
  ELSE CASE value:type::VARCHAR
- WHEN 'FIXED' THEN 'NUMBER'
- WHEN 'REAL' THEN 'FLOAT'
- WHEN 'TEXT' THEN 'TEXT'
- ELSE value:type::VARCHAR
- END
- END ||
- CASE WHEN value:nullable::BOOLEAN THEN 'YES' ELSE 'NO' END,
- ','
- ) WITHIN GROUP (ORDER BY value:name::VARCHAR),
- 256
- ) AS STREAM_HASH
- FROM {app_name}.api.data_streams,
- LATERAL FLATTEN(input => COLUMNS) f
- WHERE RAI_DATABASE = '{rai_database}' AND FQ_OBJECT_NAME IN ({fqn_list})
- GROUP BY FQ_OBJECT_NAME;
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END
+ END,
+ IFF(value:nullable::BOOLEAN, 'YES', 'NO')
+ ) AS column_signature
+ FROM {app_name}.api.data_streams,
+ LATERAL FLATTEN(input => columns)
+ WHERE rai_database = '{rai_database}'
+ AND fq_object_name IN ({fqn_list})
+ )
+ SELECT
+ fq_object_name AS FQ_OBJECT_NAME,
+ HEX_ENCODE(HASH_AGG(column_signature)) AS STREAM_HASH
+ FROM stream_columns
+ GROUP BY fq_object_name;
  """


@@ -296,9 +303,10 @@ class UseIndexPoller:
  Returns:
  List of truly stale sources that need to be deleted/recreated

- A source is truly stale if:
- - The stream doesn't exist (needs to be created), OR
- - The column hashes don't match (needs to be recreated)
+ A source is truly stale if any of the following apply:
+ - The stream doesn't exist (needs to be created)
+ - The source table was recreated after the stream (table creation timestamp is newer)
+ - The column hashes don't match (schema drift needs cleanup)
  """
  stream_hashes = self._get_stream_column_hashes(stale_sources, progress)

@@ -306,14 +314,30 @@ class UseIndexPoller:
  for source in stale_sources:
  source_hash = self.source_info[source].get("columns_hash")
  stream_hash = stream_hashes.get(source)
+ table_created_at_raw = self.source_info[source].get("table_created_at")
+ stream_created_at_raw = self.source_info[source].get("stream_created_at")
+
+ table_created_at = normalize_datetime(table_created_at_raw)
+ stream_created_at = normalize_datetime(stream_created_at_raw)
+
+ recreated_table = False
+ if table_created_at is not None and stream_created_at is not None:
+ # If the source table was recreated (new creation timestamp) but kept
+ # the same column definitions, we still need to recycle the stream so
+ # that Snowflake picks up the new table instance.
+ recreated_table = table_created_at > stream_created_at

  # Log hash comparison for debugging
  logger.debug(f"Source: {source}")
  logger.debug(f" Source table hash: {source_hash}")
  logger.debug(f" Stream hash: {stream_hash}")
  logger.debug(f" Match: {source_hash == stream_hash}")
+ if recreated_table:
+ logger.debug(" Table appears to have been recreated (table_created_at > stream_created_at)")
+ logger.debug(f" table_created_at: {table_created_at}")
+ logger.debug(f" stream_created_at: {stream_created_at}")

- if stream_hash is None or source_hash != stream_hash:
+ if stream_hash is None or source_hash != stream_hash or recreated_table:
  logger.debug(" Action: DELETE (stale)")
  truly_stale.append(source)
  else:
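A quick worked example of the new third condition, assuming both timestamps have already been normalized to UTC (values are illustrative only):

    from datetime import datetime, timezone

    # Column hashes still match, but the table was dropped and recreated after the
    # stream was set up, so the source is treated as stale and the stream recycled.
    stream_created_at = datetime(2024, 5, 1, 9, 0, tzinfo=timezone.utc)
    table_created_at = datetime(2024, 5, 2, 10, 0, tzinfo=timezone.utc)
    source_hash = stream_hash = "ab12cd34"

    recreated_table = table_created_at > stream_created_at
    assert stream_hash is None or source_hash != stream_hash or recreated_table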
@@ -376,7 +400,7 @@ class UseIndexPoller:
  stale_sources = [
  source
  for source, info in self.source_info.items()
- if info["state"] == "STALE"
+ if info.get("state") == "STALE"
  ]

  if not stale_sources:
@@ -763,7 +787,7 @@ class UseIndexPoller:
  # Log the error for debugging
  logger.warning(f"Failed to enable change tracking on {fqn}: {e}")
  failed_tables.append((fqn, str(e)))
-
+
  # Handle errors based on subtask type
  if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
  # Mark the individual subtask as failed and complete it
@@ -80,6 +80,15 @@ def escape_for_f_string(code: str) -> str:
  def escape_for_sproc(code: str) -> str:
  return code.replace("$$", "\\$\\$")

+
+ def normalize_datetime(value: object) -> datetime | None:
+ """Return a timezone-aware UTC datetime or None."""
+ if not isinstance(value, datetime):
+ return None
+ if value.tzinfo is None:
+ return value.replace(tzinfo=timezone.utc)
+ return value.astimezone(timezone.utc)
+
  # @NOTE: `overhead_rate` should fall between 0.05 and 0.5 depending on how time sensitive / expensive the operation in question is.
  def poll_with_specified_overhead(
  f,
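The helper's behaviour follows directly from the body shown above: naive datetimes are assumed to already be UTC, aware ones are converted to UTC, and anything that is not a datetime collapses to None. A small usage sketch:

    from datetime import datetime, timedelta, timezone

    from relationalai.clients.util import normalize_datetime

    assert normalize_datetime("2024-05-01") is None                # not a datetime
    naive = datetime(2024, 5, 1, 12, 0)                            # no tzinfo: treated as UTC
    assert normalize_datetime(naive) == datetime(2024, 5, 1, 12, 0, tzinfo=timezone.utc)
    aware = datetime(2024, 5, 1, 14, 0, tzinfo=timezone(timedelta(hours=2)))
    assert normalize_datetime(aware) == datetime(2024, 5, 1, 12, 0, tzinfo=timezone.utc)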
relationalai/dsl.py CHANGED
@@ -22,6 +22,7 @@ import sys
  from pandas import DataFrame

  from relationalai.environments import runtime_env, SnowbookEnvironment
+ from relationalai.tools.constants import QUERY_ATTRIBUTES_HEADER

  from .clients.client import Client

@@ -34,9 +35,7 @@ from .errors import FilterAsValue, Errors, InvalidPropertySetException, Multiple
  #--------------------------------------------------

  RESERVED_PROPS = ["add", "set", "persist", "unpersist"]
-
  MAX_QUERY_ATTRIBUTE_LENGTH = 255
- QUERY_ATTRIBUTES_HEADER = "X-Query-Attributes"

  Value = Union[
  "Expression",