relationalai 0.11.4__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff compares the contents of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (47)
  1. relationalai/clients/config.py +7 -0
  2. relationalai/clients/direct_access_client.py +113 -0
  3. relationalai/clients/snowflake.py +263 -189
  4. relationalai/clients/types.py +4 -1
  5. relationalai/clients/use_index_poller.py +72 -48
  6. relationalai/clients/util.py +9 -0
  7. relationalai/dsl.py +1 -2
  8. relationalai/early_access/metamodel/rewrite/__init__.py +5 -3
  9. relationalai/early_access/rel/rewrite/__init__.py +1 -1
  10. relationalai/environments/snowbook.py +10 -1
  11. relationalai/errors.py +24 -3
  12. relationalai/semantics/internal/annotations.py +1 -0
  13. relationalai/semantics/internal/internal.py +22 -3
  14. relationalai/semantics/lqp/builtins.py +1 -0
  15. relationalai/semantics/lqp/executor.py +12 -4
  16. relationalai/semantics/lqp/model2lqp.py +1 -0
  17. relationalai/semantics/lqp/passes.py +3 -4
  18. relationalai/semantics/{rel → lqp}/rewrite/__init__.py +6 -0
  19. relationalai/semantics/metamodel/builtins.py +12 -1
  20. relationalai/semantics/metamodel/executor.py +2 -1
  21. relationalai/semantics/metamodel/rewrite/__init__.py +3 -9
  22. relationalai/semantics/metamodel/rewrite/flatten.py +8 -7
  23. relationalai/semantics/reasoners/graph/core.py +1356 -258
  24. relationalai/semantics/rel/builtins.py +5 -1
  25. relationalai/semantics/rel/compiler.py +3 -3
  26. relationalai/semantics/rel/executor.py +20 -11
  27. relationalai/semantics/sql/compiler.py +2 -3
  28. relationalai/semantics/sql/executor/duck_db.py +8 -4
  29. relationalai/semantics/sql/executor/snowflake.py +1 -1
  30. relationalai/tools/cli.py +17 -6
  31. relationalai/tools/cli_controls.py +334 -352
  32. relationalai/tools/constants.py +1 -0
  33. relationalai/tools/query_utils.py +27 -0
  34. relationalai/util/otel_configuration.py +1 -1
  35. {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/METADATA +5 -4
  36. {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/RECORD +45 -45
  37. relationalai/semantics/metamodel/rewrite/gc_nodes.py +0 -58
  38. relationalai/semantics/metamodel/rewrite/list_types.py +0 -109
  39. /relationalai/semantics/{rel → lqp}/rewrite/cdc.py +0 -0
  40. /relationalai/semantics/{rel → lqp}/rewrite/extract_common.py +0 -0
  41. /relationalai/semantics/{metamodel → lqp}/rewrite/extract_keys.py +0 -0
  42. /relationalai/semantics/{metamodel → lqp}/rewrite/fd_constraints.py +0 -0
  43. /relationalai/semantics/{rel → lqp}/rewrite/quantify_vars.py +0 -0
  44. /relationalai/semantics/{metamodel → lqp}/rewrite/splinter.py +0 -0
  45. {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/WHEEL +0 -0
  46. {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/entry_points.txt +0 -0
  47. {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/licenses/LICENSE +0 -0
@@ -28,15 +28,11 @@ import requests
  import snowflake.connector
  import pyarrow as pa

- from dataclasses import dataclass
  from snowflake.snowpark import Session
  from snowflake.snowpark.context import get_active_session
  from . import result_helpers
  from .. import debugging
  from typing import Any, Dict, Iterable, Optional, Tuple, List, Literal, Union, cast
- from urllib.parse import urlencode, quote
- from requests.adapters import HTTPAdapter
- from urllib3.util.retry import Retry

  from pandas import DataFrame

@@ -44,10 +40,11 @@ from ..tools.cli_controls import Spinner
  from ..clients.types import AvailableModel, EngineState, Import, ImportSource, ImportSourceTable, ImportsStatus, SourceInfo, TransactionAsyncResponse
  from ..clients.config import Config, ConfigStore, ENDPOINT_FILE
  from ..clients.client import Client, ExportParams, ProviderBase, ResourcesBase
- from ..clients.util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, ms_to_timestamp
+ from ..clients.direct_access_client import DirectAccessClient
+ from ..clients.util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, ms_to_timestamp, normalize_datetime
  from ..environments import runtime_env, HexEnvironment, SnowbookEnvironment
  from .. import dsl, rel, metamodel as m
- from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning
+ from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
  from concurrent.futures import ThreadPoolExecutor
  from datetime import datetime, date, timedelta
  from snowflake.snowpark.types import StringType, StructField, StructType
@@ -92,6 +89,8 @@ TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]

  DUO_TEXT = "duo security"

+ TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
+
  #--------------------------------------------------
  # Helpers
  #--------------------------------------------------
@@ -1307,7 +1306,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
  response = self._exec(f"CALL {APP_NAME}.api.get_transaction('{txn_id}',{headers});")
  assert response, f"No results from get_transaction('{txn_id}')"

- response_row = next(iter(response))
+ response_row = next(iter(response)).asDict()
  status: str = response_row['STATE']

  # remove the transaction from the pending list if it's completed or aborted
@@ -1315,6 +1314,16 @@ Otherwise, remove it from your '{profile}' configuration profile.
  if txn_id in self._pending_transactions:
  self._pending_transactions.remove(txn_id)

+ if status == "ABORTED" and response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
+ config_file_path = getattr(self.config, 'file_path', None)
+ # todo: use the timeout returned alongside the transaction as soon as it's exposed
+ timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
+ raise QueryTimeoutExceededException(
+ timeout_mins=timeout_mins,
+ query_id=txn_id,
+ config_file_path=config_file_path,
+ )
+
  # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
  return status == "COMPLETED" or status == "ABORTED"

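The abort-reason branch added above surfaces server-side transaction timeouts as a dedicated exception instead of a generic aborted transaction. A minimal sketch of how calling code might react to it, assuming QueryTimeoutExceededException is importable from relationalai.errors (as the import change in this diff suggests) and exposes its constructor keywords as attributes (an assumption, not confirmed by the diff):

```python
from relationalai.errors import QueryTimeoutExceededException  # import path shown in this diff


def run_with_timeout_report(execute_query):
    # execute_query is any zero-argument callable that drives a query (hypothetical helper)
    try:
        return execute_query()
    except QueryTimeoutExceededException as exc:
        # attribute names assumed to mirror the kwargs used in the raise above
        print(
            f"transaction {exc.query_id} hit the {exc.timeout_mins}-minute limit; "
            f"consider raising query_timeout_mins in {exc.config_file_path}"
        )
        raise
```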
@@ -1858,7 +1867,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
  except Exception as e:
  err_message = str(e).lower()
  if _is_engine_issue(err_message):
- self.auto_create_engine(engine)
+ self.auto_create_engine(engine, headers=headers)
  self._exec_async_v2(
  database, engine, raw_code_b64, inputs, readonly, nowait_durable,
  headers=headers, bypass_index=bypass_index, language='lqp',
@@ -1898,7 +1907,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
  except Exception as e:
  err_message = str(e).lower()
  if _is_engine_issue(err_message):
- self.auto_create_engine(engine)
+ self.auto_create_engine(engine, headers=headers)
  return self._exec_async_v2(
  database,
  engine,
@@ -1961,9 +1970,9 @@ Otherwise, remove it from your '{profile}' configuration profile.
  if use_graph_index:
  # we do not provide a default value for query_timeout_mins so that we can control the default on app level
  if query_timeout_mins is not None:
- res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, ?, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
+ res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
  else:
- res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
+ res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
  txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
  rejected_rows = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows", [])
  rejected_rows_count = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows_count", 0)
@@ -2038,9 +2047,10 @@ Otherwise, remove it from your '{profile}' configuration profile.
  app_name = self.get_app_name()

  source_types = dict[str, SourceInfo]()
- partitioned_sources: dict[str, dict[str, list[str]]] = defaultdict(
+ partitioned_sources: dict[str, dict[str, list[dict[str, str]]]] = defaultdict(
  lambda: defaultdict(list)
  )
+ fqn_to_parts: dict[str, tuple[str, str, str]] = {}

  for source in sources:
  parser = IdentityParser(source, True)
@@ -2048,82 +2058,219 @@ Otherwise, remove it from your '{profile}' configuration profile.
  assert len(parsed) == 4, f"Invalid source: {source}"
  db, schema, entity, identity = parsed
  assert db and schema and entity and identity, f"Invalid source: {source}"
- source_types[identity] = cast(SourceInfo, {"type": None, "state": "", "columns_hash": None})
- partitioned_sources[db][schema].append(entity)
-
- # TODO: Move to NA layer
- query = (
- " UNION ALL ".join(
- f"""SELECT
- inf.FQN,
- inf.KIND,
- inf.COLUMNS_HASH,
- IFF(DATEDIFF(second, ds.created_at::TIMESTAMP, inf.LAST_DDL::TIMESTAMP) > 0, 'STALE', 'CURRENT') AS STATE
- FROM (
- SELECT (SELECT {app_name}.api.normalize_fq_ids(ARRAY_CONSTRUCT(FQ_OBJECT_NAME))[0]:identifier::string) as FQ_OBJECT_NAME,
- CREATED_AT FROM {app_name}.api.data_streams
- WHERE RAI_DATABASE = '{PYREL_ROOT_DB}'
- ) ds
- RIGHT JOIN (
+ source_types[identity] = cast(
+ SourceInfo,
+ {
+ "type": None,
+ "state": "",
+ "columns_hash": None,
+ "table_created_at": None,
+ "stream_created_at": None,
+ "last_ddl": None,
+ },
+ )
+ partitioned_sources[db][schema].append({"entity": entity, "identity": identity})
+ fqn_to_parts[identity] = (db, schema, entity)
+
+ if not partitioned_sources:
+ return source_types
+
+ state_queries: list[str] = []
+ for db, schemas in partitioned_sources.items():
+ select_rows: list[str] = []
+ for schema, tables in schemas.items():
+ for table_info in tables:
+ select_rows.append(
+ "SELECT "
+ f"{IdentityParser.to_sql_value(db)} AS catalog_name, "
+ f"{IdentityParser.to_sql_value(schema)} AS schema_name, "
+ f"{IdentityParser.to_sql_value(table_info['entity'])} AS table_name"
+ )
+
+ if not select_rows:
+ continue
+
+ target_entities_clause = "\n UNION ALL\n ".join(select_rows)
+ # Main query:
+ # 1. Enumerate the target tables via target_entities.
+ # 2. Pull their metadata (last_altered, type) from INFORMATION_SCHEMA.TABLES.
+ # 3. Look up the most recent stream activity for those FQNs only.
+ # 4. Capture creation timestamps and use last_ddl vs created_at to classify each target,
+ # so we mark tables as stale when they were recreated even if column hashes still match.
+ state_queries.append(
+ f"""WITH target_entities AS (
+ {target_entities_clause}
+ ),
+ table_info AS (
+ SELECT
+ {app_name}.api.normalize_fq_ids(
+ ARRAY_CONSTRUCT(
+ CASE
+ WHEN t.table_catalog = UPPER(t.table_catalog) THEN t.table_catalog
+ ELSE '"' || t.table_catalog || '"'
+ END || '.' ||
+ CASE
+ WHEN t.table_schema = UPPER(t.table_schema) THEN t.table_schema
+ ELSE '"' || t.table_schema || '"'
+ END || '.' ||
+ CASE
+ WHEN t.table_name = UPPER(t.table_name) THEN t.table_name
+ ELSE '"' || t.table_name || '"'
+ END
+ )
+ )[0]:identifier::string AS fqn,
+ CONVERT_TIMEZONE('UTC', t.last_altered) AS last_ddl,
+ CONVERT_TIMEZONE('UTC', t.created) AS table_created_at,
+ t.table_type AS kind
+ FROM {db}.INFORMATION_SCHEMA.tables t
+ JOIN target_entities te
+ ON t.table_catalog = te.catalog_name
+ AND t.table_schema = te.schema_name
+ AND t.table_name = te.table_name
+ ),
+ stream_activity AS (
+ SELECT
+ sa.fqn,
+ MAX(sa.created_at) AS created_at
+ FROM (
+ SELECT
+ {app_name}.api.normalize_fq_ids(ARRAY_CONSTRUCT(fq_object_name))[0]:identifier::string AS fqn,
+ created_at
+ FROM {app_name}.api.data_streams
+ WHERE rai_database = '{PYREL_ROOT_DB}'
+ ) sa
+ JOIN table_info ti
+ ON sa.fqn = ti.fqn
+ GROUP BY sa.fqn
+ )
  SELECT
- (SELECT {app_name}.api.normalize_fq_ids(
- ARRAY_CONSTRUCT(
- CASE
- WHEN t.TABLE_CATALOG = UPPER(t.TABLE_CATALOG) THEN t.TABLE_CATALOG
- ELSE '"' || t.TABLE_CATALOG || '"'
- END || '.' ||
- CASE
- WHEN t.TABLE_SCHEMA = UPPER(t.TABLE_SCHEMA) THEN t.TABLE_SCHEMA
- ELSE '"' || t.TABLE_SCHEMA || '"'
- END || '.' ||
- CASE
- WHEN t.TABLE_NAME = UPPER(t.TABLE_NAME) THEN t.TABLE_NAME
- ELSE '"' || t.TABLE_NAME || '"'
- END
- )
- )[0]:identifier::string) as FQN,
- CONVERT_TIMEZONE('UTC', LAST_DDL) AS LAST_DDL,
- TABLE_TYPE as KIND,
- SHA2(LISTAGG(
- COLUMN_NAME ||
- CASE
- WHEN c.NUMERIC_PRECISION IS NOT NULL AND c.NUMERIC_SCALE IS NOT NULL
- THEN c.DATA_TYPE || '(' || c.NUMERIC_PRECISION || ',' || c.NUMERIC_SCALE || ')'
- WHEN c.DATETIME_PRECISION IS NOT NULL
- THEN c.DATA_TYPE || '(0,' || c.DATETIME_PRECISION || ')'
- WHEN c.CHARACTER_MAXIMUM_LENGTH IS NOT NULL
- THEN c.DATA_TYPE || '(' || c.CHARACTER_MAXIMUM_LENGTH || ')'
- ELSE c.DATA_TYPE
- END ||
- IS_NULLABLE,
- ','
- ) WITHIN GROUP (ORDER BY COLUMN_NAME), 256) as COLUMNS_HASH
- FROM {db}.INFORMATION_SCHEMA.TABLES t
- JOIN {db}.INFORMATION_SCHEMA.COLUMNS c
- ON t.TABLE_CATALOG = c.TABLE_CATALOG
- AND t.TABLE_SCHEMA = c.TABLE_SCHEMA
- AND t.TABLE_NAME = c.TABLE_NAME
- WHERE t.TABLE_CATALOG = {IdentityParser.to_sql_value(db)} AND ({" OR ".join(
- f"(t.TABLE_SCHEMA = {IdentityParser.to_sql_value(schema)} AND t.TABLE_NAME IN ({','.join(f'{IdentityParser.to_sql_value(table)}' for table in tables)}))"
- for schema, tables in schemas.items()
- )})
- GROUP BY t.TABLE_CATALOG, t.TABLE_SCHEMA, t.TABLE_NAME, t.LAST_DDL, t.TABLE_TYPE
- ) inf on inf.FQN = ds.FQ_OBJECT_NAME
- """
- for db, schemas in partitioned_sources.items()
+ ti.fqn,
+ ti.kind,
+ ti.last_ddl,
+ ti.table_created_at,
+ sa.created_at AS stream_created_at,
+ IFF(
+ DATEDIFF(second, sa.created_at::timestamp, ti.last_ddl::timestamp) > 0,
+ 'STALE',
+ 'CURRENT'
+ ) AS state
+ FROM table_info ti
+ LEFT JOIN stream_activity sa
+ ON sa.fqn = ti.fqn
+ """
  )
- + ";"
+
+ stale_fqns: list[str] = []
+ for state_query in state_queries:
+ for row in self._exec(state_query):
+ row_dict = row.as_dict() if hasattr(row, "as_dict") else dict(row)
+ row_fqn = row_dict["FQN"]
+ parser = IdentityParser(row_fqn, True)
+ fqn = parser.identity
+ assert fqn, f"Error parsing returned FQN: {row_fqn}"
+
+ source_types[fqn]["type"] = (
+ "TABLE" if row_dict["KIND"] == "BASE TABLE" else row_dict["KIND"]
+ )
+ source_types[fqn]["state"] = row_dict["STATE"]
+ source_types[fqn]["last_ddl"] = normalize_datetime(row_dict.get("LAST_DDL"))
+ source_types[fqn]["table_created_at"] = normalize_datetime(row_dict.get("TABLE_CREATED_AT"))
+ source_types[fqn]["stream_created_at"] = normalize_datetime(row_dict.get("STREAM_CREATED_AT"))
+ if row_dict["STATE"] == "STALE":
+ stale_fqns.append(fqn)
+
+ if not stale_fqns:
+ return source_types
+
+ # We batch stale tables by database/schema so each Snowflake query can hash
+ # multiple objects at once instead of issuing one statement per table.
+ stale_partitioned: dict[str, dict[str, list[dict[str, str]]]] = defaultdict(
+ lambda: defaultdict(list)
  )
+ for fqn in stale_fqns:
+ db, schema, table = fqn_to_parts[fqn]
+ stale_partitioned[db][schema].append({"table": table, "identity": fqn})
+
+ # Build one hash query per database, grouping schemas/tables inside so we submit
+ # at most a handful of set-based statements to Snowflake.
+ for db, schemas in stale_partitioned.items():
+ column_select_rows: list[str] = []
+ for schema, tables in schemas.items():
+ for table_info in tables:
+ # Build the literal rows for this db/schema so we can join back
+ # against INFORMATION_SCHEMA.COLUMNS in a single statement.
+ column_select_rows.append(
+ "SELECT "
+ f"{IdentityParser.to_sql_value(db)} AS catalog_name, "
+ f"{IdentityParser.to_sql_value(schema)} AS schema_name, "
+ f"{IdentityParser.to_sql_value(table_info['table'])} AS table_name"
+ )

- for row in self._exec(query):
- row_fqn = row["FQN"]
- parser = IdentityParser(row_fqn, True)
- fqn = parser.identity
- assert fqn, f"Error parsing returned FQN: {row_fqn}"
+ if not column_select_rows:
+ continue

- source_types[fqn]["type"] = "TABLE" if row["KIND"] == "BASE TABLE" else row["KIND"]
- source_types[fqn]["columns_hash"] = row["COLUMNS_HASH"]
- source_types[fqn]["state"] = row["STATE"]
+ target_entities_clause = "\n UNION ALL\n ".join(column_select_rows)
+ # Main query: compute deterministic column hashes for every stale table
+ # in this database/schema batch so we can compare schemas without a round trip per table.
+ column_query = f"""WITH target_entities AS (
+ {target_entities_clause}
+ ),
+ column_info AS (
+ SELECT
+ {app_name}.api.normalize_fq_ids(
+ ARRAY_CONSTRUCT(
+ CASE
+ WHEN c.table_catalog = UPPER(c.table_catalog) THEN c.table_catalog
+ ELSE '"' || c.table_catalog || '"'
+ END || '.' ||
+ CASE
+ WHEN c.table_schema = UPPER(c.table_schema) THEN c.table_schema
+ ELSE '"' || c.table_schema || '"'
+ END || '.' ||
+ CASE
+ WHEN c.table_name = UPPER(c.table_name) THEN c.table_name
+ ELSE '"' || c.table_name || '"'
+ END
+ )
+ )[0]:identifier::string AS fqn,
+ c.column_name,
+ CASE
+ WHEN c.numeric_precision IS NOT NULL AND c.numeric_scale IS NOT NULL
+ THEN c.data_type || '(' || c.numeric_precision || ',' || c.numeric_scale || ')'
+ WHEN c.datetime_precision IS NOT NULL
+ THEN c.data_type || '(0,' || c.datetime_precision || ')'
+ WHEN c.character_maximum_length IS NOT NULL
+ THEN c.data_type || '(' || c.character_maximum_length || ')'
+ ELSE c.data_type
+ END AS type_signature,
+ IFF(c.is_nullable = 'YES', 'YES', 'NO') AS nullable_flag
+ FROM {db}.INFORMATION_SCHEMA.COLUMNS c
+ JOIN target_entities te
+ ON c.table_catalog = te.catalog_name
+ AND c.table_schema = te.schema_name
+ AND c.table_name = te.table_name
+ )
+ SELECT
+ fqn,
+ HEX_ENCODE(
+ HASH_AGG(
+ HASH(
+ column_name,
+ type_signature,
+ nullable_flag
+ )
+ )
+ ) AS columns_hash
+ FROM column_info
+ GROUP BY fqn
+ """
+
+ for row in self._exec(column_query):
+ row_fqn = row["FQN"]
+ parser = IdentityParser(row_fqn, True)
+ fqn = parser.identity
+ assert fqn, f"Error parsing returned FQN: {row_fqn}"
+ source_types[fqn]["columns_hash"] = row["COLUMNS_HASH"]

  return source_types

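For readers skimming the rework above: lookups are now batched per database rather than issued per table, by grouping fully qualified names into a literal target_entities list that feeds each CTE. A standalone Python sketch of that batching shape (sample table names, not the package's code):

```python
from collections import defaultdict

# Group (db, schema, table) triples so each database needs only one
# set-based INFORMATION_SCHEMA query instead of one statement per table.
tables = [("DB1", "PUBLIC", "ORDERS"), ("DB1", "PUBLIC", "ITEMS"), ("DB2", "SALES", "LEADS")]
partitioned: dict[str, dict[str, list[str]]] = defaultdict(lambda: defaultdict(list))
for db, schema, table in tables:
    partitioned[db][schema].append(table)

for db, schemas in partitioned.items():
    rows = [
        f"SELECT '{db}' AS catalog_name, '{schema}' AS schema_name, '{table}' AS table_name"
        for schema, names in schemas.items()
        for table in names
    ]
    target_entities_clause = "\n UNION ALL\n ".join(rows)  # feeds the WITH target_entities CTE
    print(f"-- one query for {db}\n{target_entities_clause}")
```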
@@ -2133,12 +2280,13 @@ Otherwise, remove it from your '{profile}' configuration profile.
  invalid_sources = {}
  source_references = []
  for source, info in source_info.items():
- if info['type'] is None:
+ source_type = info.get("type")
+ if source_type is None:
  missing_sources.append(source)
- elif info['type'] not in ("TABLE", "VIEW"):
- invalid_sources[source] = info['type']
+ elif source_type not in ("TABLE", "VIEW"):
+ invalid_sources[source] = source_type
  else:
- source_references.append(f"{app_name}.api.object_reference('{info['type']}', '{source}')")
+ source_references.append(f"{app_name}.api.object_reference('{source_type}', '{source}')")

  if missing_sources:
  current_role = self.get_sf_session().get_current_role()
@@ -2957,104 +3105,6 @@ def Graph(
  #--------------------------------------------------
  # Note: All direct access components should live in a separate file

- @dataclass
- class Endpoint:
- method: str
- endpoint: str
-
- class DirectAccessClient:
- """
- DirectAccessClient is a client for direct service access without service function calls.
- """
-
- def __init__(self, config: Config, token_handler: TokenHandler, service_endpoint: str, generation: Optional[Generation] = None):
- self._config: Config = config
- self._token_handler: TokenHandler = token_handler
- self.service_endpoint: str = service_endpoint
- self.generation: Optional[Generation] = generation
- self._is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
- self.endpoints: Dict[str, Endpoint] = {
- "create_txn": Endpoint(method="POST", endpoint="/v1alpha1/transactions"),
- "get_txn": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}"),
- "get_txn_artifacts": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/artifacts"),
- "get_txn_problems": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/problems"),
- "get_txn_events": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/events/{stream_name}"),
- "get_package_versions": Endpoint(method="GET", endpoint="/v1alpha1/databases/{db_name}/package_versions"),
- "get_model_package_versions": Endpoint(method="POST", endpoint="/v1alpha1/models/get_package_versions"),
- "create_db": Endpoint(method="POST", endpoint="/v1alpha1/databases"),
- "get_db": Endpoint(method="GET", endpoint="/v1alpha1/databases"),
- "delete_db": Endpoint(method="DELETE", endpoint="/v1alpha1/databases/{db_name}"),
- "release_index": Endpoint(method="POST", endpoint="/v1alpha1/index/release"),
- "list_engines": Endpoint(method="GET", endpoint="/v1alpha1/engines"),
- "get_engine": Endpoint(method="GET", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
- "create_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}"),
- "delete_engine": Endpoint(method="DELETE", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
- "suspend_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/suspend"),
- "resume_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/resume_async"),
- "prepare_index": Endpoint(method="POST", endpoint="/v1alpha1/index/prepare"),
- }
- self.http_session = self._create_retry_session()
-
- def _create_retry_session(self) -> requests.Session:
- http_session = requests.Session()
- retries = Retry(
- total=3,
- backoff_factor=0.3,
- status_forcelist=[500, 502, 503, 504],
- allowed_methods=frozenset({"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}),
- raise_on_status=False
- )
- adapter = HTTPAdapter(max_retries=retries)
- http_session.mount("http://", adapter)
- http_session.mount("https://", adapter)
- http_session.headers.update({"Connection": "keep-alive"})
- return http_session
-
- def request(
- self,
- endpoint: str,
- payload: Dict[str, Any] | None = None,
- headers: Dict[str, str] | None = None,
- path_params: Dict[str, str] | None = None,
- query_params: Dict[str, str] | None = None,
- ) -> requests.Response:
- """
- Send a request to the service endpoint.
- """
- url, method = self._prepare_url(endpoint, path_params, query_params)
- request_headers = self._prepare_headers(headers)
- return self.http_session.request(method, url, json=payload, headers=request_headers)
-
- def _prepare_url(self, endpoint: str, path_params: Dict[str, str] | None = None, query_params: Dict[str, str] | None = None) -> Tuple[str, str]:
- try:
- ep = self.endpoints[endpoint]
- except KeyError:
- raise ValueError(f"Invalid endpoint: {endpoint}. Available endpoints: {list(self.endpoints.keys())}")
- url = f"{self.service_endpoint}{ep.endpoint}"
- if path_params:
- escaped_path_params = {k: quote(v, safe='') for k, v in path_params.items()}
- url = url.format(**escaped_path_params)
- if query_params:
- url += '?' + urlencode(query_params)
- return url, ep.method
-
- def _prepare_headers(self, headers: Dict[str, str] | None) -> Dict[str, str]:
- request_headers = {}
- if headers:
- request_headers.update(headers)
- # Authorization tokens are not needed in a snowflake notebook environment
- if not self._is_snowflake_notebook:
- request_headers["Authorization"] = f'Snowflake Token="{self._token_handler.get_ingress_token(self.service_endpoint)}"'
- # needed for oauth, does no harm for other authentication methods
- request_headers["X-SF-SPCS-Authentication-Method"] = 'OAUTH'
- request_headers["Content-Type"] = 'application/x-www-form-urlencoded'
- request_headers["Accept"] = "application/json"
-
- request_headers["user-agent"] = get_pyrel_version(self.generation)
- request_headers["pyrel_program_id"] = debugging.get_program_span_id() or ""
-
- return debugging.add_current_propagation_headers(request_headers)
-
  class DirectAccessResources(Resources):
  """
  Resources class for Direct Service Access avoiding Snowflake service functions.
@@ -3068,7 +3118,14 @@ class DirectAccessResources(Resources):
  reset_session: bool = False,
  generation: Optional[Generation] = None,
  ):
- super().__init__(generation=generation, profile=profile, config=config, connection=connection, dry_run=dry_run)
+ super().__init__(
+ generation=generation,
+ profile=profile,
+ config=config,
+ connection=connection,
+ reset_session=reset_session,
+ dry_run=dry_run,
+ )
  self._endpoint_info = ConfigStore(ENDPOINT_FILE)
  self._service_endpoint = ""
  self._direct_access_client = None
@@ -3127,6 +3184,7 @@ class DirectAccessResources(Resources):
  headers: Dict[str, str] | None = None,
  path_params: Dict[str, str] | None = None,
  query_params: Dict[str, str] | None = None,
+ skip_auto_create: bool = False,
  ) -> requests.Response:
  with debugging.span("direct_access_request"):
  def _send_request():
@@ -3148,7 +3206,8 @@ class DirectAccessResources(Resources):
  )

  # fix engine on engine error and retry
- if _is_engine_issue(message):
+ # Skip auto-retry if skip_auto_create is True to avoid recursion
+ if _is_engine_issue(message) and not skip_auto_create:
  engine = payload.get("engine_name", "") if payload else ""
  self.auto_create_engine(engine)
  response = _send_request()
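The skip_auto_create flag introduced above guards the "create the engine and retry" recovery path so that engine-management requests cannot recurse back into it. A self-contained toy sketch of the same pattern (illustrative names only, not the package's code):

```python
def send(endpoint: str, engine_ready: dict) -> str:
    # pretend backend: engine-dependent endpoints fail until an engine exists
    if endpoint == "create_engine":
        engine_ready["ok"] = True
        return "created"
    return "ok" if engine_ready.get("ok") else "engine not found"


def request(endpoint: str, engine_ready: dict, skip_auto_create: bool = False) -> str:
    response = send(endpoint, engine_ready)
    if "engine not found" in response and not skip_auto_create:
        # recovery path; engine-management calls pass skip_auto_create=True,
        # so this branch can never call back into itself
        request("create_engine", engine_ready, skip_auto_create=True)
        response = send(endpoint, engine_ready)
    return response


print(request("create_txn", {}))  # -> "ok" after one guarded auto-create
```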
@@ -3310,13 +3369,24 @@ class DirectAccessResources(Resources):
  assert response, f"No results from get_transaction('{txn_id}')"

  response_content = response.json()
- status: str = response_content["transaction"]['state']
+ transaction = response_content["transaction"]
+ status: str = transaction['state']

  # remove the transaction from the pending list if it's completed or aborted
  if status in ["COMPLETED", "ABORTED"]:
  if txn_id in self._pending_transactions:
  self._pending_transactions.remove(txn_id)

+ if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
+ config_file_path = getattr(self.config, 'file_path', None)
+ timeout_ms = int(transaction.get("timeout_ms", 0))
+ timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
+ raise QueryTimeoutExceededException(
+ timeout_mins=timeout_mins,
+ query_id=txn_id,
+ config_file_path=config_file_path,
+ )
+
  # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
  return status == "COMPLETED" or status == "ABORTED"

@@ -3502,7 +3572,7 @@ class DirectAccessResources(Resources):
  return sorted(engines, key=lambda x: x["name"])

  def get_engine(self, name: str):
- response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"})
+ response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"}, skip_auto_create=True)
  if response.status_code == 404: # engine not found return 404
  return None
  elif response.status_code != 200:
@@ -3549,6 +3619,7 @@ class DirectAccessResources(Resources):
  payload=payload,
  path_params={"engine_type": "logic"},
  headers=headers,
+ skip_auto_create=True,
  )
  if response.status_code != 200:
  raise ResponseStatusException(
@@ -3560,6 +3631,7 @@ class DirectAccessResources(Resources):
  "delete_engine",
  path_params={"engine_name": name, "engine_type": "logic"},
  headers=headers,
+ skip_auto_create=True,
  )
  if response.status_code != 200:
  raise ResponseStatusException(
@@ -3570,6 +3642,7 @@ class DirectAccessResources(Resources):
  response = self.request(
  "suspend_engine",
  path_params={"engine_name": name, "engine_type": "logic"},
+ skip_auto_create=True,
  )
  if response.status_code != 200:
  raise ResponseStatusException(
@@ -3581,6 +3654,7 @@ class DirectAccessResources(Resources):
  "resume_engine",
  path_params={"engine_name": name, "engine_type": "logic"},
  headers=headers,
+ skip_auto_create=True,
  )
  if response.status_code != 200:
  raise ResponseStatusException(
@@ -38,10 +38,13 @@ class EngineState(TypedDict):
  auto_suspend: int|None
  suspends_at: datetime|None

- class SourceInfo(TypedDict):
+ class SourceInfo(TypedDict, total=False):
  type: str|None
  state: str
  columns_hash: str|None
+ table_created_at: datetime|None
+ stream_created_at: datetime|None
+ last_ddl: datetime|None
  source: str

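The switch to total=False above means the new timestamp keys may simply be absent from a SourceInfo dict, which is why the call sites earlier in this diff move from info['type'] to info.get("type"). A minimal sketch of that behavior (illustrative class name, not the package's definition):

```python
from datetime import datetime
from typing import Optional, TypedDict


class SourceInfoSketch(TypedDict, total=False):
    type: Optional[str]
    state: str
    columns_hash: Optional[str]
    table_created_at: Optional[datetime]


info: SourceInfoSketch = {"state": "CURRENT"}   # other keys simply omitted
source_type = info.get("type")                  # None instead of a KeyError
created_at = info.get("table_created_at")       # safe even though the key is missing
print(source_type, created_at)
```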