relationalai 0.11.2__py3-none-any.whl → 0.11.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. relationalai/clients/snowflake.py +44 -15
  2. relationalai/clients/types.py +1 -0
  3. relationalai/clients/use_index_poller.py +446 -178
  4. relationalai/early_access/builder/std/__init__.py +1 -1
  5. relationalai/early_access/dsl/bindings/csv.py +4 -4
  6. relationalai/semantics/internal/internal.py +22 -4
  7. relationalai/semantics/lqp/executor.py +69 -18
  8. relationalai/semantics/lqp/intrinsics.py +23 -0
  9. relationalai/semantics/lqp/model2lqp.py +16 -6
  10. relationalai/semantics/lqp/passes.py +3 -4
  11. relationalai/semantics/lqp/primitives.py +38 -14
  12. relationalai/semantics/metamodel/builtins.py +152 -11
  13. relationalai/semantics/metamodel/factory.py +3 -2
  14. relationalai/semantics/metamodel/helpers.py +78 -2
  15. relationalai/semantics/reasoners/graph/core.py +343 -40
  16. relationalai/semantics/reasoners/optimization/solvers_dev.py +20 -1
  17. relationalai/semantics/reasoners/optimization/solvers_pb.py +24 -3
  18. relationalai/semantics/rel/compiler.py +5 -17
  19. relationalai/semantics/rel/executor.py +2 -2
  20. relationalai/semantics/rel/rel.py +6 -0
  21. relationalai/semantics/rel/rel_utils.py +37 -1
  22. relationalai/semantics/rel/rewrite/extract_common.py +153 -242
  23. relationalai/semantics/sql/compiler.py +540 -202
  24. relationalai/semantics/sql/executor/duck_db.py +21 -0
  25. relationalai/semantics/sql/executor/result_helpers.py +7 -0
  26. relationalai/semantics/sql/executor/snowflake.py +9 -2
  27. relationalai/semantics/sql/rewrite/denormalize.py +4 -6
  28. relationalai/semantics/sql/rewrite/recursive_union.py +23 -3
  29. relationalai/semantics/sql/sql.py +120 -46
  30. relationalai/semantics/std/__init__.py +9 -4
  31. relationalai/semantics/std/datetime.py +363 -0
  32. relationalai/semantics/std/math.py +77 -0
  33. relationalai/semantics/std/re.py +83 -0
  34. relationalai/semantics/std/strings.py +1 -1
  35. relationalai/tools/cli_controls.py +445 -60
  36. relationalai/util/format.py +78 -1
  37. {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/METADATA +3 -2
  38. {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/RECORD +41 -39
  39. relationalai/semantics/std/dates.py +0 -213
  40. {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/WHEEL +0 -0
  41. {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/entry_points.txt +0 -0
  42. {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,6 @@
  from typing import Iterable, Dict, Optional, List, cast, TYPE_CHECKING
  import json
+ import logging
  import uuid

  from relationalai import debugging
@@ -12,9 +13,29 @@ from relationalai.errors import (
  SnowflakeTableObjectsException,
  SnowflakeTableObject,
  )
- from relationalai.tools.cli_controls import DebuggingSpan, create_progress
+ from relationalai.tools.cli_controls import (
+ DebuggingSpan,
+ create_progress,
+ TASK_CATEGORY_INDEXING,
+ TASK_CATEGORY_PROVISIONING,
+ TASK_CATEGORY_CHANGE_TRACKING,
+ TASK_CATEGORY_CACHE,
+ TASK_CATEGORY_RELATIONS,
+ TASK_CATEGORY_STATUS,
+ TASK_CATEGORY_VALIDATION,
+ )
  from relationalai.tools.constants import WAIT_FOR_STREAM_SYNC, Generation

+ # Set up logger for this module
+ logger = logging.getLogger(__name__)
+
+ try:
+ from rich.console import Console
+ from rich.table import Table
+ except ImportError:
+ Console = None
+ Table = None
+
  if TYPE_CHECKING:
  from relationalai.clients.snowflake import Resources
  from relationalai.clients.snowflake import DirectAccessResources
@@ -32,15 +53,73 @@ MAX_DATA_SOURCE_SUBTASKS = 10

  # How often to check ERP status (every N iterations)
  # To limit performance overhead, we only check ERP status periodically
- ERP_CHECK_FREQUENCY = 5
+ ERP_CHECK_FREQUENCY = 15
+
+ # Polling behavior constants
+ POLL_OVERHEAD_RATE = 0.1 # Overhead rate for exponential backoff
+ POLL_MAX_DELAY = 2.5 # Maximum delay between polls in seconds
+
+ # SQL query template for getting stream column hashes
+ # This query calculates a hash of column metadata (name, type, precision, scale, nullable)
+ # to detect if source table schema has changed since stream was created
+ STREAM_COLUMN_HASH_QUERY = """
+ SELECT
+ FQ_OBJECT_NAME,
+ SHA2(
+ LISTAGG(
+ value:name::VARCHAR ||
+ CASE
+ WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL
+ THEN CASE value:type::VARCHAR
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END || '(' || value:precision || ',' || value:scale || ')'
+ WHEN value:precision IS NOT NULL AND value:scale IS NULL
+ THEN CASE value:type::VARCHAR
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END || '(0,' || value:precision || ')'
+ WHEN value:length IS NOT NULL
+ THEN CASE value:type::VARCHAR
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END || '(' || value:length || ')'
+ ELSE CASE value:type::VARCHAR
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END
+ END ||
+ CASE WHEN value:nullable::BOOLEAN THEN 'YES' ELSE 'NO' END,
+ ','
+ ) WITHIN GROUP (ORDER BY value:name::VARCHAR),
+ 256
+ ) AS STREAM_HASH
+ FROM {app_name}.api.data_streams,
+ LATERAL FLATTEN(input => COLUMNS) f
+ WHERE RAI_DATABASE = '{rai_database}' AND FQ_OBJECT_NAME IN ({fqn_list})
+ GROUP BY FQ_OBJECT_NAME;
+ """
+

  class UseIndexPoller:
  """
  Encapsulates the polling logic for `use_index` streams.
  """

- def _add_stream_subtask(self, progress, fq_name, status, batches_count):
- """Add a stream subtask if we haven't reached the limit."""
+ def _add_stream_subtask(self, progress, fq_name: str, status: str, batches_count: int) -> bool:
+ """Add a stream subtask if we haven't reached the limit.
+
+ Returns:
+ True if subtask was added, False if limit reached
+ """
  if fq_name not in self.stream_task_ids and len(self.stream_task_ids) < MAX_DATA_SOURCE_SUBTASKS:
  # Get the position in the stream order (should already be there)
  if fq_name in self.stream_order:
@@ -58,11 +137,11 @@ class UseIndexPoller:
  else:
  initial_message = f"Syncing {fq_name} ({stream_position}/{self.total_streams})"

- self.stream_task_ids[fq_name] = progress.add_sub_task(initial_message, task_id=fq_name)
+ self.stream_task_ids[fq_name] = progress.add_sub_task(initial_message, task_id=fq_name, category=TASK_CATEGORY_INDEXING)

- # Complete immediately if already synced
+ # Complete immediately if already synced (without recording completion time)
  if status == "synced":
- progress.complete_sub_task(fq_name)
+ progress.complete_sub_task(fq_name, record_time=False)

  return True
  return False
@@ -125,16 +204,24 @@ class UseIndexPoller:
  self.stream_position = 0
  self.stream_order = [] # Track the order of streams as they appear in data

+ # Timing will be tracked by TaskProgress
+
  def poll(self) -> None:
  """
  Standard stream-based polling for use_index.
  """
+ # Read show_duration_summary config flag (defaults to True for backward compatibility)
+ show_duration_summary = bool(self.res.config.get("show_duration_summary", True))
+
  with create_progress(
  description="Initializing data index",
- success_message="Initialization complete",
+ success_message="", # We'll handle this in the context manager
  leading_newline=True,
  trailing_newline=True,
+ show_duration_summary=show_duration_summary,
  ) as progress:
+ # Set process start time
+ progress.set_process_start_time()
  progress.update_main_status("Validating data sources")
  self._maybe_delete_stale(progress)

@@ -145,84 +232,203 @@ class UseIndexPoller:
  self._poll_loop(progress)
  self._post_check(progress)

+ # Set process end time (summary will be automatically printed by __exit__)
+ progress.set_process_end_time()
+
  def _add_cache_subtask(self, progress) -> None:
  """Add a subtask showing cache usage information only when cache is used."""
  if self.cache.using_cache:
  # Cache was used - show how many sources were cached
  total_sources = len(self.cache.sources)
  cached_sources = total_sources - len(self.sources)
- progress.add_sub_task(f"Using cached data for {cached_sources}/{total_sources} data streams", task_id="cache_usage")
+ progress.add_sub_task(f"Using cached data for {cached_sources}/{total_sources} data streams", task_id="cache_usage", category=TASK_CATEGORY_CACHE)
  # Complete the subtask immediately since it's just informational
  progress.complete_sub_task("cache_usage")

+ def _get_stream_column_hashes(self, sources: List[str], progress) -> Dict[str, str]:
+ """
+ Query data_streams to get current column hashes for the given sources.
+
+ Args:
+ sources: List of source FQNs to query
+ progress: TaskProgress instance for updating status on error
+
+ Returns:
+ Dict mapping FQN -> column hash
+
+ Raises:
+ ValueError: If the query fails (permissions, table doesn't exist, etc.)
+ """
+ from relationalai.clients.snowflake import PYREL_ROOT_DB
+
+ # Build FQN list for SQL IN clause
+ fqn_list = ", ".join([f"'{source}'" for source in sources])
+
+ # Format query template with actual values
+ hash_query = STREAM_COLUMN_HASH_QUERY.format(
+ app_name=self.app_name,
+ rai_database=PYREL_ROOT_DB,
+ fqn_list=fqn_list
+ )
+
+ try:
+ hash_results = self.res._exec(hash_query)
+ return {row["FQ_OBJECT_NAME"]: row["STREAM_HASH"] for row in hash_results}
+
+ except Exception as e:
+ logger.error(f"Failed to query stream column hashes: {e}")
+ logger.error(f" Query: {hash_query[:200]}...")
+ logger.error(f" Sources: {sources}")
+ progress.update_main_status("❌ Failed to validate data stream metadata")
+ raise ValueError(
+ f"Failed to validate stream column hashes. This may indicate a permissions "
+ f"issue or missing data_streams table. Error: {e}"
+ ) from e
+
+ def _filter_truly_stale_sources(self, stale_sources: List[str], progress) -> List[str]:
+ """
+ Filter stale sources to only include those with mismatched column hashes.
+
+ Args:
+ stale_sources: List of source FQNs marked as stale
+ progress: TaskProgress instance for updating status on error
+
+ Returns:
+ List of truly stale sources that need to be deleted/recreated
+
+ A source is truly stale if:
+ - The stream doesn't exist (needs to be created), OR
+ - The column hashes don't match (needs to be recreated)
+ """
+ stream_hashes = self._get_stream_column_hashes(stale_sources, progress)
+
+ truly_stale = []
+ for source in stale_sources:
+ source_hash = self.source_info[source].get("columns_hash")
+ stream_hash = stream_hashes.get(source)
+
+ # Log hash comparison for debugging
+ logger.debug(f"Source: {source}")
+ logger.debug(f" Source table hash: {source_hash}")
+ logger.debug(f" Stream hash: {stream_hash}")
+ logger.debug(f" Match: {source_hash == stream_hash}")
+
+ if stream_hash is None or source_hash != stream_hash:
+ logger.debug(" Action: DELETE (stale)")
+ truly_stale.append(source)
+ else:
+ logger.debug(" Action: KEEP (valid)")
+
+ logger.debug(f"Stale sources summary: {len(truly_stale)}/{len(stale_sources)} truly stale")
+
+ return truly_stale
+
+ def _add_deletion_subtasks(self, progress, sources: List[str]) -> None:
+ """Add progress subtasks for source deletion.
+
+ Args:
+ progress: TaskProgress instance
+ sources: List of source FQNs to be deleted
+ """
+ if len(sources) <= MAX_INDIVIDUAL_SUBTASKS:
+ for i, source in enumerate(sources):
+ progress.add_sub_task(
+ f"Removing stale stream {source} ({i+1}/{len(sources)})",
+ task_id=f"stale_source_{i}",
+ category=TASK_CATEGORY_VALIDATION
+ )
+ else:
+ progress.add_sub_task(
+ f"Removing {len(sources)} stale data sources",
+ task_id="stale_sources_summary",
+ category=TASK_CATEGORY_VALIDATION
+ )
+
+ def _complete_deletion_subtasks(self, progress, sources: List[str], deleted_count: int) -> None:
+ """Complete progress subtasks for source deletion.
+
+ Args:
+ progress: TaskProgress instance
+ sources: List of source FQNs that were processed
+ deleted_count: Number of sources successfully deleted
+ """
+ if len(sources) <= MAX_INDIVIDUAL_SUBTASKS:
+ for i in range(len(sources)):
+ if f"stale_source_{i}" in progress._tasks:
+ progress.complete_sub_task(f"stale_source_{i}")
+ else:
+ if "stale_sources_summary" in progress._tasks:
+ if deleted_count > 0:
+ s = "s" if deleted_count > 1 else ""
+ progress.update_sub_task(
+ "stale_sources_summary",
+ f"Removed {deleted_count} stale data source{s}"
+ )
+ progress.complete_sub_task("stale_sources_summary")
+
  def _maybe_delete_stale(self, progress) -> None:
+ """Check for and delete stale data streams that need recreation.
+
+ Args:
+ progress: TaskProgress instance for tracking deletion progress
+ """
  with debugging.span("check_sources"):
- # Source tables that have been altered/changed since the last stream creation
  stale_sources = [
  source
  for source, info in self.source_info.items()
  if info["state"] == "STALE"
  ]
- if stale_sources:
- with DebuggingSpan("validate_sources"):
- try:
- # Delete all stale streams, so use_index could recreate them again
- from relationalai.clients.snowflake import PYREL_ROOT_DB
- query = f"CALL {self.app_name}.api.delete_data_streams({stale_sources}, '{PYREL_ROOT_DB}');"
-
- # Add subtasks based on count
- if len(stale_sources) <= MAX_INDIVIDUAL_SUBTASKS:
- # Add individual subtasks for each stale source
- for i, source in enumerate(stale_sources):
- progress.add_sub_task(f"Removing stale stream {source} ({i+1}/{len(stale_sources)})", task_id=f"stale_source_{i}")
- else:
- # Add single summary subtask for many sources
- progress.add_sub_task(f"Removing {len(stale_sources)} stale data sources", task_id="stale_sources_summary")
-
- delete_response = self.res._exec(query)
- delete_json_str = delete_response[0]["DELETE_DATA_STREAMS"].lower()
- delete_data = json.loads(delete_json_str)
- deleted_count = delete_data.get("deleted", 0)
- diff = len(stale_sources) - deleted_count
-
- # Complete subtasks
- if len(stale_sources) <= MAX_INDIVIDUAL_SUBTASKS:
- # Complete all individual subtasks
- for i in range(len(stale_sources)):
- if f"stale_source_{i}" in progress._tasks:
- progress.complete_sub_task(f"stale_source_{i}")
- else:
- # Complete summary subtask
- if "stale_sources_summary" in progress._tasks:
- if deleted_count > 0:
- is_many = deleted_count > 1
- s = "s" if is_many else ""
- progress.update_sub_task("stale_sources_summary", f"Removed {deleted_count} stale data source{s}")
- progress.complete_sub_task("stale_sources_summary")
-
- if diff > 0:
- errors = delete_data.get("errors", None)
- if errors:
- raise Exception(f"Error(s) deleting streams with modified sources: {errors}")
- except Exception as e:
- # Complete any remaining subtasks
- if len(stale_sources) <= MAX_INDIVIDUAL_SUBTASKS:
- for i in range(len(stale_sources)):
- if f"stale_source_{i}" in progress._tasks:
- progress.complete_sub_task(f"stale_source_{i}")
- else:
- if "stale_sources_summary" in progress._tasks:
- progress.update_sub_task("stale_sources_summary", f"❌ Failed to remove stale sources: {str(e)}")
- progress.complete_sub_task("stale_sources_summary")
-
- # The delete_data_streams procedure will raise an exception if the streams do not exist
- if "data streams do not exist" in str(e).lower():
- # Don't raise an error if streams don't exist - this is expected
- pass
- else:
- raise e from None
+
+ if not stale_sources:
+ return
+
+ with DebuggingSpan("validate_sources"):
+ try:
+ # Validate which sources truly need deletion by comparing column hashes
+ truly_stale = self._filter_truly_stale_sources(stale_sources, progress)
+
+ if not truly_stale:
+ return
+
+ # Delete truly stale streams
+ from relationalai.clients.snowflake import PYREL_ROOT_DB
+ query = f"CALL {self.app_name}.api.delete_data_streams({truly_stale}, '{PYREL_ROOT_DB}');"
+
+ self._add_deletion_subtasks(progress, truly_stale)
+
+ delete_response = self.res._exec(query)
+ delete_json_str = delete_response[0]["DELETE_DATA_STREAMS"].lower()
+ delete_data = json.loads(delete_json_str)
+ deleted_count = delete_data.get("deleted", 0)
+
+ self._complete_deletion_subtasks(progress, truly_stale, deleted_count)
+
+ # Check for errors
+ diff = len(truly_stale) - deleted_count
+ if diff > 0:
+ errors = delete_data.get("errors", None)
+ if errors:
+ raise Exception(f"Error(s) deleting streams with modified sources: {errors}")
+
+ except Exception as e:
+ # Complete any remaining subtasks
+ self._complete_deletion_subtasks(progress, stale_sources, 0)
+ if "stale_sources_summary" in progress._tasks:
+ progress.update_sub_task(
+ "stale_sources_summary",
+ f"❌ Failed to remove stale sources: {str(e)}"
+ )
+
+ # Don't raise if streams don't exist - this is expected
+ if "data streams do not exist" not in str(e).lower():
+ raise e from None

  def _poll_loop(self, progress) -> None:
+ """Main polling loop for use_index streams.
+
+ Args:
+ progress: TaskProgress instance for tracking polling progress
+ """
  source_references = self.res._get_source_references(self.source_info)
  sources_object_references_str = ", ".join(source_references)

@@ -234,7 +440,7 @@ class UseIndexPoller:
  with debugging.span("check_erp_status"):
  # Add subtask for ERP status check
  if self._erp_check_task_id is None:
- self._erp_check_task_id = progress.add_sub_task("Checking system status", task_id="erp_check")
+ self._erp_check_task_id = progress.add_sub_task("Checking system status", task_id="erp_check", category=TASK_CATEGORY_STATUS)

  if not self.res.is_erp_running(self.app_name):
  progress.update_sub_task("erp_check", "❌ System status check failed")
@@ -269,11 +475,18 @@ class UseIndexPoller:
  use_index_json_str = results[0]["USE_INDEX"]

  # Parse the JSON string into a Python dictionary
- use_index_data = json.loads(use_index_json_str)
+ try:
+ use_index_data = json.loads(use_index_json_str)
+ except json.JSONDecodeError as e:
+ logger.error(f"Invalid JSON from use_index API: {e}")
+ logger.error(f"Raw response (first 500 chars): {use_index_json_str[:500]}")
+ progress.update_main_status("❌ Received invalid response from server")
+ raise ValueError(f"Invalid JSON response from use_index: {e}") from e
+
  span.update(use_index_data)

- # Useful to see the full use_index_data on each poll loop
- # print(f"\n\nuse_index_data: {json.dumps(use_index_data, indent=4)}\n\n")
+ # Log the full use_index_data for debugging if needed
+ logger.debug(f"use_index_data: {json.dumps(use_index_data, indent=4)}")

  all_data = use_index_data.get("data", [])
  ready = use_index_data.get("ready", False)
@@ -298,16 +511,17 @@ class UseIndexPoller:
  if not ready and all_data:
  progress.update_main_status("Processing background tasks. This may take a while...")

- # Build complete stream order first
- for data in all_data:
- if data is None:
- continue
- fq_name = data.get("fq_object_name", "Unknown")
- if fq_name not in self.stream_order:
- self.stream_order.append(fq_name)
+ # Build complete stream order first (only on first iteration with data)
+ if self.total_streams == 0:
+ for data in all_data:
+ if data is None:
+ continue
+ fq_name = data.get("fq_object_name", "Unknown")
+ if fq_name not in self.stream_order:
+ self.stream_order.append(fq_name)

- # Set total streams count based on complete order
- self.total_streams = len(self.stream_order)
+ # Set total streams count based on complete order (only once)
+ self.total_streams = len(self.stream_order)

  # Add new streams as subtasks if we haven't reached the limit
  for data in all_data:
@@ -367,62 +581,69 @@ class UseIndexPoller:
  for engine in engines:
  if not engine or not isinstance(engine, dict):
  continue
-
- name = engine.get("name", "Unknown")
  size = self.engine_size
- if name not in self.engine_task_ids:
- self.engine_task_ids[name] = progress.add_sub_task(f"Provisioning engine {name} ({size})", task_id=name)
-
- state = (engine.get("state") or "").lower()
+ name = engine.get("name", "Unknown")
  status = (engine.get("status") or "").lower()
+ sub_task_id = self.engine_task_ids.get(name, None)
+ sub_task_status_message = ""

- # Determine engine status message
- if state == "ready" or status == "ready":
- status_message = f"Engine {name} ({size}) ready"
- should_complete = True
- else:
+ # Complete the sub task if it exists and the engine status is ready
+ if sub_task_id and name in progress._tasks and not progress._tasks[name].completed and (status == "ready"):
+ sub_task_status_message = f"Engine {name} ({size}) ready"
+ progress.update_sub_task(name, sub_task_status_message)
+ progress.complete_sub_task(name)
+
+ # Add the sub task if it doesn't exist and the engine status is pending
+ if not sub_task_id and status == "pending":
  writer = engine.get("writer", False)
  engine_type = "writer engine" if writer else "engine"
- status_message = f"Provisioning {engine_type} {name} ({size})"
- should_complete = False
-
- # Only update if the task isn't already completed
- if name in progress._tasks and not progress._tasks[name].completed:
- progress.update_sub_task(name, status_message)
-
- if should_complete:
- progress.complete_sub_task(name)
+ sub_task_status_message = f"Provisioning {engine_type} {name} ({size})"
+ self.engine_task_ids[name] = progress.add_sub_task(sub_task_status_message, task_id=name, category=TASK_CATEGORY_PROVISIONING)

  # Special handling for CDC_MANAGED_ENGINE - mark ready when any stream starts processing
- if CDC_MANAGED_ENGINE in self.engine_task_ids:
+ cdc_task = progress._tasks.get(CDC_MANAGED_ENGINE) if CDC_MANAGED_ENGINE in progress._tasks else None
+ if CDC_MANAGED_ENGINE in self.engine_task_ids and cdc_task and not cdc_task.completed:
+
  has_processing_streams = any(
  stream.get("next_batch_status", "") == "processing"
  for stream in all_data
  )
- if has_processing_streams and CDC_MANAGED_ENGINE in progress._tasks and not progress._tasks[CDC_MANAGED_ENGINE].completed:
+ if has_processing_streams and cdc_task and not cdc_task.completed:
  progress.update_sub_task(CDC_MANAGED_ENGINE, f"Engine {CDC_MANAGED_ENGINE} ({self.engine_size}) ready")
  progress.complete_sub_task(CDC_MANAGED_ENGINE)

  self.counter += 1

  # Handle relations data
- if not ready and relations and isinstance(relations, dict):
+ if relations and isinstance(relations, dict):
  txn = relations.get("txn", {}) or {}
  txn_id = txn.get("id", None)

  # Only show relations subtask if there is a valid txn object
  if txn_id:
  status = relations.get("status", "").upper()
- state = txn.get("state", "").upper()

  # Create relations subtask if it doesn't exist
  if self.relations_task_id is None:
- self.relations_task_id = progress.add_sub_task("Populating relations", task_id="relations")
+ self.relations_task_id = progress.add_sub_task("Populating relations", task_id="relations", category=TASK_CATEGORY_RELATIONS)
+
+ # Set the start time from the JSON if available (always update)
+ start_time_ms = relations.get("start_time")
+ if start_time_ms:
+ start_time_seconds = start_time_ms / 1000.0
+ progress._tasks["relations"].added_time = start_time_seconds

  # Update relations status
- if state == "COMPLETED":
+ if status == "COMPLETED":
  progress.update_sub_task("relations", f"Relations populated (txn: {txn_id})")
- progress.complete_sub_task("relations")
+
+ # Set the completion time from the JSON if available
+ end_time_ms = relations.get("end_time")
+ if end_time_ms:
+ end_time_seconds = end_time_ms / 1000.0
+ progress._tasks["relations"].completed_time = end_time_seconds
+
+ progress.complete_sub_task("relations", record_time=False) # Don't record local time
  else:
  progress.update_sub_task("relations", f"Relations populating (txn: {txn_id})")

@@ -473,85 +694,128 @@ class UseIndexPoller:

  return break_loop

- poll_with_specified_overhead(lambda: check_ready(progress), overhead_rate=0.1, max_delay=1)
+ poll_with_specified_overhead(lambda: check_ready(progress), overhead_rate=POLL_OVERHEAD_RATE, max_delay=POLL_MAX_DELAY)

  def _post_check(self, progress) -> None:
- num_tables_altered = 0
+ """Run post-processing checks including change tracking enablement.

- enabled_tables = []
- if (
- self.tables_with_not_enabled_change_tracking
- and self.res.config.get("ensure_change_tracking", False)
- ):
- tables_to_process = self.tables_with_not_enabled_change_tracking
-
- # Add subtasks based on count
- if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
- # Add individual subtasks for each table
- for i, table in enumerate(tables_to_process):
- fqn, kind = table
- progress.add_sub_task(f"Enabling change tracking on {fqn} ({i+1}/{len(tables_to_process)})", task_id=f"change_tracking_{i}")
- else:
- # Add single summary subtask for many tables
- progress.add_sub_task(f"Enabling change tracking on {len(tables_to_process)} tables", task_id="change_tracking_summary")
+ Args:
+ progress: TaskProgress instance for tracking progress

- # Process tables
+ Raises:
+ SnowflakeChangeTrackingNotEnabledException: If change tracking cannot be enabled
+ SnowflakeTableObjectsException: If there are table-related errors
+ EngineProvisioningFailed: If engine provisioning fails
+ """
+ num_tables_altered = 0
+ failed_tables = [] # Track tables that failed to enable change tracking
+
+ enabled_tables = []
+ if (
+ self.tables_with_not_enabled_change_tracking
+ and self.res.config.get("ensure_change_tracking", False)
+ ):
+ tables_to_process = self.tables_with_not_enabled_change_tracking
+ # Track timing for change tracking
+
+ # Add subtasks based on count
+ if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
+ # Add individual subtasks for each table
  for i, table in enumerate(tables_to_process):
- try:
- fqn, kind = table
- self.res._exec(f"ALTER {kind} {fqn} SET CHANGE_TRACKING = TRUE;")
- enabled_tables.append(table)
- num_tables_altered += 1
+ fqn, kind = table
+ progress.add_sub_task(f"Enabling change tracking on {fqn} ({i+1}/{len(tables_to_process)})", task_id=f"change_tracking_{i}", category=TASK_CATEGORY_CHANGE_TRACKING)
+ else:
+ # Add single summary subtask for many tables
+ progress.add_sub_task(f"Enabling change tracking on {len(tables_to_process)} tables", task_id="change_tracking_summary", category=TASK_CATEGORY_CHANGE_TRACKING)

- # Update progress based on subtask type
- if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
- # Complete individual table subtask
- progress.complete_sub_task(f"change_tracking_{i}")
- else:
- # Update summary subtask with progress
- progress.update_sub_task("change_tracking_summary",
- f"Enabling change tracking on {len(tables_to_process)} tables... ({i+1}/{len(tables_to_process)})")
- except Exception:
- # Handle errors based on subtask type
+ # Process tables
+ for i, table in enumerate(tables_to_process):
+ fqn, kind = table # Unpack outside try block to ensure fqn is defined
+
+ try:
+ # Validate table_type to prevent SQL injection
+ # Should only be TABLE or VIEW
+ if kind not in ("TABLE", "VIEW"):
+ logger.error(f"Invalid table kind '{kind}' for {fqn}, skipping")
+ failed_tables.append((fqn, f"Invalid table kind: {kind}"))
+ # Mark as failed in progress
  if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
- # Complete the individual subtask even if it failed
  if f"change_tracking_{i}" in progress._tasks:
+ progress.update_sub_task(f"change_tracking_{i}", f"❌ Invalid type: {fqn}")
  progress.complete_sub_task(f"change_tracking_{i}")
- pass
-
- # Complete summary subtask if used
- if len(tables_to_process) > MAX_INDIVIDUAL_SUBTASKS and "change_tracking_summary" in progress._tasks:
- if num_tables_altered > 0:
- s = "s" if num_tables_altered > 1 else ""
- progress.update_sub_task("change_tracking_summary", f"Enabled change tracking on {num_tables_altered} table{s}")
- progress.complete_sub_task("change_tracking_summary")
-
- # Remove the tables that were successfully enabled from the list of not enabled tables
- # so that we don't raise an exception for them later
- self.tables_with_not_enabled_change_tracking = [
- t for t in self.tables_with_not_enabled_change_tracking if t not in enabled_tables
- ]
-
- if self.tables_with_not_enabled_change_tracking:
- progress.update_main_status("Errors found. See below for details.")
- raise SnowflakeChangeTrackingNotEnabledException(
- self.tables_with_not_enabled_change_tracking
- )
+ continue
+
+ # Execute ALTER statement
+ # Note: fqn should already be properly quoted from source_info
+ self.res._exec(f"ALTER {kind} {fqn} SET CHANGE_TRACKING = TRUE;")
+ enabled_tables.append(table)
+ num_tables_altered += 1
+
+ # Update progress based on subtask type
+ if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
+ # Complete individual table subtask
+ progress.complete_sub_task(f"change_tracking_{i}")
+ else:
+ # Update summary subtask with progress
+ progress.update_sub_task("change_tracking_summary",
+ f"Enabling change tracking on {len(tables_to_process)} tables... ({i+1}/{len(tables_to_process)})")
+ except Exception as e:
+ # Log the error for debugging
+ logger.warning(f"Failed to enable change tracking on {fqn}: {e}")
+ failed_tables.append((fqn, str(e)))
+
+ # Handle errors based on subtask type
+ if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
+ # Mark the individual subtask as failed and complete it
+ if f"change_tracking_{i}" in progress._tasks:
+ progress.update_sub_task(f"change_tracking_{i}", f"❌ Failed: {fqn}")
+ progress.complete_sub_task(f"change_tracking_{i}")
+ # Continue processing other tables despite this failure
+
+ # Complete summary subtask if used
+ if len(tables_to_process) > MAX_INDIVIDUAL_SUBTASKS and "change_tracking_summary" in progress._tasks:
+ if num_tables_altered > 0:
+ s = "s" if num_tables_altered > 1 else ""
+ success_msg = f"Enabled change tracking on {num_tables_altered} table{s}"
+ if failed_tables:
+ success_msg += f" ({len(failed_tables)} failed)"
+ progress.update_sub_task("change_tracking_summary", success_msg)
+ elif failed_tables:
+ progress.update_sub_task("change_tracking_summary", f"❌ Failed on {len(failed_tables)} table(s)")
+ progress.complete_sub_task("change_tracking_summary")
+
+ # Log summary of failed tables
+ if failed_tables:
+ logger.warning(f"Failed to enable change tracking on {len(failed_tables)} table(s)")
+ for fqn, error in failed_tables:
+ logger.warning(f" {fqn}: {error}")
+
+ # Remove the tables that were successfully enabled from the list of not enabled tables
+ # so that we don't raise an exception for them later
+ self.tables_with_not_enabled_change_tracking = [
+ t for t in self.tables_with_not_enabled_change_tracking if t not in enabled_tables
+ ]

- if self.table_objects_with_other_errors:
- progress.update_main_status("Errors found. See below for details.")
- raise SnowflakeTableObjectsException(self.table_objects_with_other_errors)
- if self.engine_errors:
- progress.update_main_status("Errors found. See below for details.")
- # if there is an engine error, probably auto create engine failed
- # Create a synthetic exception from the first engine error
- first_error = self.engine_errors[0]
- error_message = first_error.get("message", "Unknown engine error")
- synthetic_exception = Exception(f"Engine error: {error_message}")
- raise EngineProvisioningFailed(self.engine_name, synthetic_exception)
-
- if num_tables_altered > 0:
- self._poll_loop(progress)
+ if self.tables_with_not_enabled_change_tracking:
+ progress.update_main_status("Errors found. See below for details.")
+ raise SnowflakeChangeTrackingNotEnabledException(
+ self.tables_with_not_enabled_change_tracking
+ )
+
+ if self.table_objects_with_other_errors:
+ progress.update_main_status("Errors found. See below for details.")
+ raise SnowflakeTableObjectsException(self.table_objects_with_other_errors)
+ if self.engine_errors:
+ progress.update_main_status("Errors found. See below for details.")
+ # if there is an engine error, probably auto create engine failed
+ # Create a synthetic exception from the first engine error
+ first_error = self.engine_errors[0]
+ error_message = first_error.get("message", "Unknown engine error")
+ synthetic_exception = Exception(f"Engine error: {error_message}")
+ raise EngineProvisioningFailed(self.engine_name, synthetic_exception)
+
+ if num_tables_altered > 0:
+ self._poll_loop(progress)

  class DirectUseIndexPoller(UseIndexPoller):
  """
@@ -636,17 +900,21 @@ class DirectUseIndexPoller(UseIndexPoller):
  attempt += 1
  return False

+ # Read show_duration_summary config flag (defaults to True for backward compatibility)
+ show_duration_summary = bool(self.res.config.get("show_duration_summary", True))
+
  with create_progress(
  description="Preparing your data...",
  success_message="Done",
  leading_newline=True,
  trailing_newline=True,
+ show_duration_summary=show_duration_summary,
  ) as progress:
  # Add cache usage subtask
  self._add_cache_subtask(progress)

  with debugging.span("poll_direct"):
- poll_with_specified_overhead(lambda: check_direct(progress), overhead_rate=0.1, max_delay=1)
+ poll_with_specified_overhead(lambda: check_direct(progress), overhead_rate=POLL_OVERHEAD_RATE, max_delay=POLL_MAX_DELAY)

  # Run the same post-check logic as UseIndexPoller
  self._post_check(progress)
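
Note on the staleness check introduced in this release: a data stream is now deleted and recreated only when the hash of its source table's column metadata no longer matches the hash recorded for the stream (the package computes this hash in SQL with SHA2(LISTAGG(...), 256), as shown in STREAM_COLUMN_HASH_QUERY above). The sketch below is an illustrative plain-Python rendering of that idea, not the package's actual code; the Column type and helper names here are hypothetical.

# Illustrative sketch only -- not part of the relationalai package.
import hashlib
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class Column:
    # Hypothetical column descriptor; the real query reads these fields
    # from {app_name}.api.data_streams via LATERAL FLATTEN(input => COLUMNS).
    name: str
    type_: str      # e.g. "NUMBER(38,0)" or "TEXT(16777216)"
    nullable: bool

def columns_hash(columns: List[Column]) -> str:
    # Mirrors the SQL shape: concatenate name || type || nullable flag for
    # every column, ordered by column name, then SHA-256 the joined string
    # (LISTAGG ... WITHIN GROUP (ORDER BY name) + SHA2(..., 256)).
    parts = [
        f"{c.name}{c.type_}{'YES' if c.nullable else 'NO'}"
        for c in sorted(columns, key=lambda c: c.name)
    ]
    return hashlib.sha256(",".join(parts).encode("utf-8")).hexdigest()

def filter_truly_stale(
    source_hashes: Dict[str, str],            # FQN -> hash of the current source table
    stream_hashes: Dict[str, Optional[str]],  # FQN -> hash recorded on the stream, if any
) -> List[str]:
    # A source is "truly stale" only when its stream is missing or the two
    # hashes differ; streams with matching hashes are kept as-is.
    return [
        fqn
        for fqn, src_hash in source_hashes.items()
        if stream_hashes.get(fqn) is None or stream_hashes.get(fqn) != src_hash
    ]

Keeping streams whose hashes still match is what lets the new _maybe_delete_stale path avoid tearing down and re-syncing streams whose source schema has not actually changed.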