relationalai 0.11.3__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. relationalai/clients/config.py +7 -0
  2. relationalai/clients/direct_access_client.py +113 -0
  3. relationalai/clients/snowflake.py +41 -107
  4. relationalai/clients/use_index_poller.py +349 -188
  5. relationalai/early_access/dsl/bindings/csv.py +2 -2
  6. relationalai/early_access/metamodel/rewrite/__init__.py +5 -3
  7. relationalai/early_access/rel/rewrite/__init__.py +1 -1
  8. relationalai/errors.py +24 -3
  9. relationalai/semantics/internal/annotations.py +1 -0
  10. relationalai/semantics/internal/internal.py +22 -4
  11. relationalai/semantics/lqp/builtins.py +1 -0
  12. relationalai/semantics/lqp/executor.py +61 -12
  13. relationalai/semantics/lqp/intrinsics.py +23 -0
  14. relationalai/semantics/lqp/model2lqp.py +13 -4
  15. relationalai/semantics/lqp/passes.py +4 -6
  16. relationalai/semantics/lqp/primitives.py +12 -1
  17. relationalai/semantics/{rel → lqp}/rewrite/__init__.py +6 -0
  18. relationalai/semantics/lqp/rewrite/extract_common.py +362 -0
  19. relationalai/semantics/metamodel/builtins.py +20 -2
  20. relationalai/semantics/metamodel/factory.py +3 -2
  21. relationalai/semantics/metamodel/rewrite/__init__.py +3 -9
  22. relationalai/semantics/reasoners/graph/core.py +273 -71
  23. relationalai/semantics/reasoners/optimization/solvers_dev.py +20 -1
  24. relationalai/semantics/reasoners/optimization/solvers_pb.py +24 -3
  25. relationalai/semantics/rel/builtins.py +5 -1
  26. relationalai/semantics/rel/compiler.py +7 -19
  27. relationalai/semantics/rel/executor.py +2 -2
  28. relationalai/semantics/rel/rel.py +6 -0
  29. relationalai/semantics/rel/rel_utils.py +8 -1
  30. relationalai/semantics/sql/compiler.py +122 -42
  31. relationalai/semantics/sql/executor/duck_db.py +28 -3
  32. relationalai/semantics/sql/rewrite/denormalize.py +4 -6
  33. relationalai/semantics/sql/rewrite/recursive_union.py +23 -3
  34. relationalai/semantics/sql/sql.py +27 -0
  35. relationalai/semantics/std/__init__.py +2 -1
  36. relationalai/semantics/std/datetime.py +4 -0
  37. relationalai/semantics/std/re.py +83 -0
  38. relationalai/semantics/std/strings.py +1 -1
  39. relationalai/tools/cli.py +11 -4
  40. relationalai/tools/cli_controls.py +445 -60
  41. relationalai/util/format.py +78 -1
  42. {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/METADATA +7 -5
  43. {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/RECORD +51 -50
  44. relationalai/semantics/metamodel/rewrite/gc_nodes.py +0 -58
  45. relationalai/semantics/metamodel/rewrite/list_types.py +0 -109
  46. relationalai/semantics/rel/rewrite/extract_common.py +0 -451
  47. /relationalai/semantics/{rel → lqp}/rewrite/cdc.py +0 -0
  48. /relationalai/semantics/{metamodel → lqp}/rewrite/extract_keys.py +0 -0
  49. /relationalai/semantics/{metamodel → lqp}/rewrite/fd_constraints.py +0 -0
  50. /relationalai/semantics/{rel → lqp}/rewrite/quantify_vars.py +0 -0
  51. /relationalai/semantics/{metamodel → lqp}/rewrite/splinter.py +0 -0
  52. {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/WHEEL +0 -0
  53. {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/entry_points.txt +0 -0
  54. {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/licenses/LICENSE +0 -0
relationalai/clients/use_index_poller.py
@@ -1,5 +1,6 @@
  from typing import Iterable, Dict, Optional, List, cast, TYPE_CHECKING
  import json
+ import logging
  import uuid

  from relationalai import debugging
@@ -12,9 +13,29 @@ from relationalai.errors import (
      SnowflakeTableObjectsException,
      SnowflakeTableObject,
  )
- from relationalai.tools.cli_controls import DebuggingSpan, create_progress
+ from relationalai.tools.cli_controls import (
+     DebuggingSpan,
+     create_progress,
+     TASK_CATEGORY_INDEXING,
+     TASK_CATEGORY_PROVISIONING,
+     TASK_CATEGORY_CHANGE_TRACKING,
+     TASK_CATEGORY_CACHE,
+     TASK_CATEGORY_RELATIONS,
+     TASK_CATEGORY_STATUS,
+     TASK_CATEGORY_VALIDATION,
+ )
  from relationalai.tools.constants import WAIT_FOR_STREAM_SYNC, Generation

+ # Set up logger for this module
+ logger = logging.getLogger(__name__)
+
+ try:
+     from rich.console import Console
+     from rich.table import Table
+ except ImportError:
+     Console = None
+     Table = None
+
  if TYPE_CHECKING:
      from relationalai.clients.snowflake import Resources
      from relationalai.clients.snowflake import DirectAccessResources
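
Note: the guarded `rich` import added above leaves `Console` and `Table` set to `None` when the optional dependency is missing, so downstream code is expected to branch on that sentinel. A minimal sketch of the pattern (the `print_rows` helper is hypothetical, not part of the package):

    def print_rows(headers, rows):
        # Fall back to plain text when the optional 'rich' dependency is unavailable.
        if Console is None or Table is None:
            print(" | ".join(headers))
            for row in rows:
                print(" | ".join(str(v) for v in row))
            return
        table = Table(*headers)  # column headers passed positionally
        for row in rows:
            table.add_row(*[str(v) for v in row])
        Console().print(table)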
@@ -32,15 +53,73 @@ MAX_DATA_SOURCE_SUBTASKS = 10

  # How often to check ERP status (every N iterations)
  # To limit performance overhead, we only check ERP status periodically
- ERP_CHECK_FREQUENCY = 5
+ ERP_CHECK_FREQUENCY = 15
+
+ # Polling behavior constants
+ POLL_OVERHEAD_RATE = 0.1 # Overhead rate for exponential backoff
+ POLL_MAX_DELAY = 2.5 # Maximum delay between polls in seconds
+
+ # SQL query template for getting stream column hashes
+ # This query calculates a hash of column metadata (name, type, precision, scale, nullable)
+ # to detect if source table schema has changed since stream was created
+ STREAM_COLUMN_HASH_QUERY = """
+ SELECT
+     FQ_OBJECT_NAME,
+     SHA2(
+         LISTAGG(
+             value:name::VARCHAR ||
+             CASE
+                 WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL
+                 THEN CASE value:type::VARCHAR
+                         WHEN 'FIXED' THEN 'NUMBER'
+                         WHEN 'REAL' THEN 'FLOAT'
+                         WHEN 'TEXT' THEN 'TEXT'
+                         ELSE value:type::VARCHAR
+                      END || '(' || value:precision || ',' || value:scale || ')'
+                 WHEN value:precision IS NOT NULL AND value:scale IS NULL
+                 THEN CASE value:type::VARCHAR
+                         WHEN 'FIXED' THEN 'NUMBER'
+                         WHEN 'REAL' THEN 'FLOAT'
+                         WHEN 'TEXT' THEN 'TEXT'
+                         ELSE value:type::VARCHAR
+                      END || '(0,' || value:precision || ')'
+                 WHEN value:length IS NOT NULL
+                 THEN CASE value:type::VARCHAR
+                         WHEN 'FIXED' THEN 'NUMBER'
+                         WHEN 'REAL' THEN 'FLOAT'
+                         WHEN 'TEXT' THEN 'TEXT'
+                         ELSE value:type::VARCHAR
+                      END || '(' || value:length || ')'
+                 ELSE CASE value:type::VARCHAR
+                         WHEN 'FIXED' THEN 'NUMBER'
+                         WHEN 'REAL' THEN 'FLOAT'
+                         WHEN 'TEXT' THEN 'TEXT'
+                         ELSE value:type::VARCHAR
+                      END
+             END ||
+             CASE WHEN value:nullable::BOOLEAN THEN 'YES' ELSE 'NO' END,
+             ','
+         ) WITHIN GROUP (ORDER BY value:name::VARCHAR),
+         256
+     ) AS STREAM_HASH
+ FROM {app_name}.api.data_streams,
+     LATERAL FLATTEN(input => COLUMNS) f
+ WHERE RAI_DATABASE = '{rai_database}' AND FQ_OBJECT_NAME IN ({fqn_list})
+ GROUP BY FQ_OBJECT_NAME;
+ """
+

  class UseIndexPoller:
      """
      Encapsulates the polling logic for `use_index` streams.
      """

-     def _add_stream_subtask(self, progress, fq_name, status, batches_count):
-         """Add a stream subtask if we haven't reached the limit."""
+     def _add_stream_subtask(self, progress, fq_name: str, status: str, batches_count: int) -> bool:
+         """Add a stream subtask if we haven't reached the limit.
+
+         Returns:
+             True if subtask was added, False if limit reached
+         """
          if fq_name not in self.stream_task_ids and len(self.stream_task_ids) < MAX_DATA_SOURCE_SUBTASKS:
              # Get the position in the stream order (should already be there)
              if fq_name in self.stream_order:
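
Note: the hash computed by `STREAM_COLUMN_HASH_QUERY` above is a SHA-256 over a comma-separated, name-ordered list of per-column descriptors (name, normalized type with precision/scale or length, and a YES/NO nullability flag). A rough Python equivalent of the encoding, for reference only (the real computation runs server-side in Snowflake, and the column-metadata shape here is assumed):

    import hashlib

    def encode_column(col: dict) -> str:
        # Mirror the CASE logic: normalize Snowflake's internal type names, then append
        # (precision,scale), (0,precision), or (length), followed by the nullability flag.
        type_map = {"FIXED": "NUMBER", "REAL": "FLOAT", "TEXT": "TEXT"}
        t = type_map.get(col["type"], col["type"])
        if col.get("precision") is not None and col.get("scale") is not None:
            t += f"({col['precision']},{col['scale']})"
        elif col.get("precision") is not None:
            t += f"(0,{col['precision']})"
        elif col.get("length") is not None:
            t += f"({col['length']})"
        nullable = "YES" if col.get("nullable") else "NO"
        return f"{col['name']}{t}{nullable}"

    def columns_hash(columns: list) -> str:
        payload = ",".join(encode_column(c) for c in sorted(columns, key=lambda c: c["name"]))
        return hashlib.sha256(payload.encode()).hexdigest()

    # e.g. a NOT NULL column ID NUMBER(38,0) contributes "IDNUMBER(38,0)NO" to the payload

This is what lets `_filter_truly_stale_sources` (later in this diff) compare a stream's recorded hash against the current source table and recreate only the streams whose schemas actually changed.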
@@ -58,11 +137,11 @@ class UseIndexPoller:
              else:
                  initial_message = f"Syncing {fq_name} ({stream_position}/{self.total_streams})"

-             self.stream_task_ids[fq_name] = progress.add_sub_task(initial_message, task_id=fq_name)
+             self.stream_task_ids[fq_name] = progress.add_sub_task(initial_message, task_id=fq_name, category=TASK_CATEGORY_INDEXING)

-             # Complete immediately if already synced
+             # Complete immediately if already synced (without recording completion time)
              if status == "synced":
-                 progress.complete_sub_task(fq_name)
+                 progress.complete_sub_task(fq_name, record_time=False)

              return True
          return False
@@ -125,16 +204,24 @@ class UseIndexPoller:
          self.stream_position = 0
          self.stream_order = [] # Track the order of streams as they appear in data

+         # Timing will be tracked by TaskProgress
+
      def poll(self) -> None:
          """
          Standard stream-based polling for use_index.
          """
+         # Read show_duration_summary config flag (defaults to True for backward compatibility)
+         show_duration_summary = bool(self.res.config.get("show_duration_summary", True))
+
          with create_progress(
              description="Initializing data index",
-             success_message="Initialization complete",
+             success_message="", # We'll handle this in the context manager
              leading_newline=True,
              trailing_newline=True,
+             show_duration_summary=show_duration_summary,
          ) as progress:
+             # Set process start time
+             progress.set_process_start_time()
              progress.update_main_status("Validating data sources")
              self._maybe_delete_stale(progress)

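The new `show_duration_summary` flag and the `set_process_start_time` / `set_process_end_time` calls indicate that the progress wrapper now prints a timing summary when its context exits. The real `TaskProgress` lives in relationalai/tools/cli_controls.py and is not shown in this diff; a minimal sketch of a context manager with that shape, purely for orientation:

    import time
    from contextlib import AbstractContextManager

    class DurationSummaryProgress(AbstractContextManager):
        # Illustrative stand-in, not the package's TaskProgress.
        def __init__(self, show_duration_summary: bool = True):
            self.show_duration_summary = show_duration_summary
            self._start = None
            self._end = None

        def set_process_start_time(self):
            self._start = time.time()

        def set_process_end_time(self):
            self._end = time.time()

        def __exit__(self, exc_type, exc, tb):
            # Print the summary automatically on exit, as the comments above describe.
            if self.show_duration_summary and self._start is not None:
                end = self._end if self._end is not None else time.time()
                print(f"Total time: {end - self._start:.1f}s")
            return False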
@@ -145,123 +232,126 @@ class UseIndexPoller:
              self._poll_loop(progress)
              self._post_check(progress)

+             # Set process end time (summary will be automatically printed by __exit__)
+             progress.set_process_end_time()
+
      def _add_cache_subtask(self, progress) -> None:
          """Add a subtask showing cache usage information only when cache is used."""
          if self.cache.using_cache:
              # Cache was used - show how many sources were cached
              total_sources = len(self.cache.sources)
              cached_sources = total_sources - len(self.sources)
-             progress.add_sub_task(f"Using cached data for {cached_sources}/{total_sources} data streams", task_id="cache_usage")
+             progress.add_sub_task(f"Using cached data for {cached_sources}/{total_sources} data streams", task_id="cache_usage", category=TASK_CATEGORY_CACHE)
              # Complete the subtask immediately since it's just informational
              progress.complete_sub_task("cache_usage")

-     def _get_stream_column_hashes(self, sources: List[str]) -> Dict[str, str]:
+     def _get_stream_column_hashes(self, sources: List[str], progress) -> Dict[str, str]:
          """
          Query data_streams to get current column hashes for the given sources.

-         Returns a dict mapping FQN -> column hash.
+         Args:
+             sources: List of source FQNs to query
+             progress: TaskProgress instance for updating status on error
+
+         Returns:
+             Dict mapping FQN -> column hash
+
+         Raises:
+             ValueError: If the query fails (permissions, table doesn't exist, etc.)
          """
          from relationalai.clients.snowflake import PYREL_ROOT_DB

+         # Build FQN list for SQL IN clause
          fqn_list = ", ".join([f"'{source}'" for source in sources])

-         hash_query = f"""
-         SELECT
-             FQ_OBJECT_NAME,
-             SHA2(
-                 LISTAGG(
-                     value:name::VARCHAR ||
-                     CASE
-                         WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL
-                         THEN CASE value:type::VARCHAR
-                                 WHEN 'FIXED' THEN 'NUMBER'
-                                 WHEN 'REAL' THEN 'FLOAT'
-                                 WHEN 'TEXT' THEN 'TEXT'
-                                 ELSE value:type::VARCHAR
-                              END || '(' || value:precision || ',' || value:scale || ')'
-                         WHEN value:precision IS NOT NULL AND value:scale IS NULL
-                         THEN CASE value:type::VARCHAR
-                                 WHEN 'FIXED' THEN 'NUMBER'
-                                 WHEN 'REAL' THEN 'FLOAT'
-                                 WHEN 'TEXT' THEN 'TEXT'
-                                 ELSE value:type::VARCHAR
-                              END || '(0,' || value:precision || ')'
-                         WHEN value:length IS NOT NULL
-                         THEN CASE value:type::VARCHAR
-                                 WHEN 'FIXED' THEN 'NUMBER'
-                                 WHEN 'REAL' THEN 'FLOAT'
-                                 WHEN 'TEXT' THEN 'TEXT'
-                                 ELSE value:type::VARCHAR
-                              END || '(' || value:length || ')'
-                         ELSE CASE value:type::VARCHAR
-                                 WHEN 'FIXED' THEN 'NUMBER'
-                                 WHEN 'REAL' THEN 'FLOAT'
-                                 WHEN 'TEXT' THEN 'TEXT'
-                                 ELSE value:type::VARCHAR
-                              END
-                     END ||
-                     CASE WHEN value:nullable::BOOLEAN THEN 'YES' ELSE 'NO' END,
-                     ','
-                 ) WITHIN GROUP (ORDER BY value:name::VARCHAR),
-                 256
-             ) AS STREAM_HASH
-         FROM {self.app_name}.api.data_streams,
-             LATERAL FLATTEN(input => COLUMNS) f
-         WHERE RAI_DATABASE = '{PYREL_ROOT_DB}' AND FQ_OBJECT_NAME IN ({fqn_list})
-         GROUP BY FQ_OBJECT_NAME;
-         """
-
-         hash_results = self.res._exec(hash_query)
-         return {row["FQ_OBJECT_NAME"]: row["STREAM_HASH"] for row in hash_results}
+         # Format query template with actual values
+         hash_query = STREAM_COLUMN_HASH_QUERY.format(
+             app_name=self.app_name,
+             rai_database=PYREL_ROOT_DB,
+             fqn_list=fqn_list
+         )

-     def _filter_truly_stale_sources(self, stale_sources: List[str]) -> List[str]:
+         try:
+             hash_results = self.res._exec(hash_query)
+             return {row["FQ_OBJECT_NAME"]: row["STREAM_HASH"] for row in hash_results}
+
+         except Exception as e:
+             logger.error(f"Failed to query stream column hashes: {e}")
+             logger.error(f" Query: {hash_query[:200]}...")
+             logger.error(f" Sources: {sources}")
+             progress.update_main_status("❌ Failed to validate data stream metadata")
+             raise ValueError(
+                 f"Failed to validate stream column hashes. This may indicate a permissions "
+                 f"issue or missing data_streams table. Error: {e}"
+             ) from e
+
+     def _filter_truly_stale_sources(self, stale_sources: List[str], progress) -> List[str]:
          """
          Filter stale sources to only include those with mismatched column hashes.

+         Args:
+             stale_sources: List of source FQNs marked as stale
+             progress: TaskProgress instance for updating status on error
+
+         Returns:
+             List of truly stale sources that need to be deleted/recreated
+
          A source is truly stale if:
          - The stream doesn't exist (needs to be created), OR
          - The column hashes don't match (needs to be recreated)
          """
-         stream_hashes = self._get_stream_column_hashes(stale_sources)
+         stream_hashes = self._get_stream_column_hashes(stale_sources, progress)

          truly_stale = []
          for source in stale_sources:
              source_hash = self.source_info[source].get("columns_hash")
              stream_hash = stream_hashes.get(source)

-             # Debug prints to see hash comparison
-             # print(f"\n[DEBUG] Source: {source}")
-             # print(f" Source table hash: {source_hash}")
-             # print(f" Stream hash: {stream_hash}")
-             # print(f" Match: {source_hash == stream_hash}")
-             # print(f" Action: {'KEEP (valid)' if stream_hash is not None and source_hash == stream_hash else 'DELETE (stale)'}")
+             # Log hash comparison for debugging
+             logger.debug(f"Source: {source}")
+             logger.debug(f" Source table hash: {source_hash}")
+             logger.debug(f" Stream hash: {stream_hash}")
+             logger.debug(f" Match: {source_hash == stream_hash}")

              if stream_hash is None or source_hash != stream_hash:
+                 logger.debug(" Action: DELETE (stale)")
                  truly_stale.append(source)
+             else:
+                 logger.debug(" Action: KEEP (valid)")

-         # print(f"\n[DEBUG] Stale sources summary:")
-         # print(f" Total candidates: {len(stale_sources)}")
-         # print(f" Truly stale: {len(truly_stale)}")
-         # print(f" Skipped (valid): {len(stale_sources) - len(truly_stale)}\n")
+         logger.debug(f"Stale sources summary: {len(truly_stale)}/{len(stale_sources)} truly stale")

          return truly_stale

      def _add_deletion_subtasks(self, progress, sources: List[str]) -> None:
-         """Add progress subtasks for source deletion."""
+         """Add progress subtasks for source deletion.
+
+         Args:
+             progress: TaskProgress instance
+             sources: List of source FQNs to be deleted
+         """
          if len(sources) <= MAX_INDIVIDUAL_SUBTASKS:
              for i, source in enumerate(sources):
                  progress.add_sub_task(
                      f"Removing stale stream {source} ({i+1}/{len(sources)})",
-                     task_id=f"stale_source_{i}"
+                     task_id=f"stale_source_{i}",
+                     category=TASK_CATEGORY_VALIDATION
                  )
          else:
              progress.add_sub_task(
                  f"Removing {len(sources)} stale data sources",
-                 task_id="stale_sources_summary"
+                 task_id="stale_sources_summary",
+                 category=TASK_CATEGORY_VALIDATION
              )

      def _complete_deletion_subtasks(self, progress, sources: List[str], deleted_count: int) -> None:
-         """Complete progress subtasks for source deletion."""
+         """Complete progress subtasks for source deletion.
+
+         Args:
+             progress: TaskProgress instance
+             sources: List of source FQNs that were processed
+             deleted_count: Number of sources successfully deleted
+         """
          if len(sources) <= MAX_INDIVIDUAL_SUBTASKS:
              for i in range(len(sources)):
                  if f"stale_source_{i}" in progress._tasks:
@@ -277,7 +367,11 @@ class UseIndexPoller:
              progress.complete_sub_task("stale_sources_summary")

      def _maybe_delete_stale(self, progress) -> None:
-         """Check for and delete stale data streams that need recreation."""
+         """Check for and delete stale data streams that need recreation.
+
+         Args:
+             progress: TaskProgress instance for tracking deletion progress
+         """
          with debugging.span("check_sources"):
              stale_sources = [
                  source
@@ -291,7 +385,7 @@ class UseIndexPoller:
          with DebuggingSpan("validate_sources"):
              try:
                  # Validate which sources truly need deletion by comparing column hashes
-                 truly_stale = self._filter_truly_stale_sources(stale_sources)
+                 truly_stale = self._filter_truly_stale_sources(stale_sources, progress)

                  if not truly_stale:
                      return
@@ -330,6 +424,11 @@ class UseIndexPoller:
                  raise e from None

      def _poll_loop(self, progress) -> None:
+         """Main polling loop for use_index streams.
+
+         Args:
+             progress: TaskProgress instance for tracking polling progress
+         """
          source_references = self.res._get_source_references(self.source_info)
          sources_object_references_str = ", ".join(source_references)

@@ -341,7 +440,7 @@ class UseIndexPoller:
              with debugging.span("check_erp_status"):
                  # Add subtask for ERP status check
                  if self._erp_check_task_id is None:
-                     self._erp_check_task_id = progress.add_sub_task("Checking system status", task_id="erp_check")
+                     self._erp_check_task_id = progress.add_sub_task("Checking system status", task_id="erp_check", category=TASK_CATEGORY_STATUS)

                  if not self.res.is_erp_running(self.app_name):
                      progress.update_sub_task("erp_check", "❌ System status check failed")
@@ -376,11 +475,18 @@ class UseIndexPoller:
              use_index_json_str = results[0]["USE_INDEX"]

              # Parse the JSON string into a Python dictionary
-             use_index_data = json.loads(use_index_json_str)
+             try:
+                 use_index_data = json.loads(use_index_json_str)
+             except json.JSONDecodeError as e:
+                 logger.error(f"Invalid JSON from use_index API: {e}")
+                 logger.error(f"Raw response (first 500 chars): {use_index_json_str[:500]}")
+                 progress.update_main_status("❌ Received invalid response from server")
+                 raise ValueError(f"Invalid JSON response from use_index: {e}") from e
+
              span.update(use_index_data)

-             # Useful to see the full use_index_data on each poll loop
-             # print(f"\n\nuse_index_data: {json.dumps(use_index_data, indent=4)}\n\n")
+             # Log the full use_index_data for debugging if needed
+             logger.debug(f"use_index_data: {json.dumps(use_index_data, indent=4)}")

              all_data = use_index_data.get("data", [])
              ready = use_index_data.get("ready", False)
@@ -405,16 +511,17 @@ class UseIndexPoller:
              if not ready and all_data:
                  progress.update_main_status("Processing background tasks. This may take a while...")

-                 # Build complete stream order first
-                 for data in all_data:
-                     if data is None:
-                         continue
-                     fq_name = data.get("fq_object_name", "Unknown")
-                     if fq_name not in self.stream_order:
-                         self.stream_order.append(fq_name)
+                 # Build complete stream order first (only on first iteration with data)
+                 if self.total_streams == 0:
+                     for data in all_data:
+                         if data is None:
+                             continue
+                         fq_name = data.get("fq_object_name", "Unknown")
+                         if fq_name not in self.stream_order:
+                             self.stream_order.append(fq_name)

-                 # Set total streams count based on complete order
-                 self.total_streams = len(self.stream_order)
+                     # Set total streams count based on complete order (only once)
+                     self.total_streams = len(self.stream_order)

                  # Add new streams as subtasks if we haven't reached the limit
                  for data in all_data:
@@ -474,62 +581,69 @@ class UseIndexPoller:
              for engine in engines:
                  if not engine or not isinstance(engine, dict):
                      continue
-
-                 name = engine.get("name", "Unknown")
                  size = self.engine_size
-                 if name not in self.engine_task_ids:
-                     self.engine_task_ids[name] = progress.add_sub_task(f"Provisioning engine {name} ({size})", task_id=name)
-
-                 state = (engine.get("state") or "").lower()
+                 name = engine.get("name", "Unknown")
                  status = (engine.get("status") or "").lower()
+                 sub_task_id = self.engine_task_ids.get(name, None)
+                 sub_task_status_message = ""

-                 # Determine engine status message
-                 if state == "ready" or status == "ready":
-                     status_message = f"Engine {name} ({size}) ready"
-                     should_complete = True
-                 else:
+                 # Complete the sub task if it exists and the engine status is ready
+                 if sub_task_id and name in progress._tasks and not progress._tasks[name].completed and (status == "ready"):
+                     sub_task_status_message = f"Engine {name} ({size}) ready"
+                     progress.update_sub_task(name, sub_task_status_message)
+                     progress.complete_sub_task(name)
+
+                 # Add the sub task if it doesn't exist and the engine status is pending
+                 if not sub_task_id and status == "pending":
                      writer = engine.get("writer", False)
                      engine_type = "writer engine" if writer else "engine"
-                     status_message = f"Provisioning {engine_type} {name} ({size})"
-                     should_complete = False
-
-                 # Only update if the task isn't already completed
-                 if name in progress._tasks and not progress._tasks[name].completed:
-                     progress.update_sub_task(name, status_message)
-
-                 if should_complete:
-                     progress.complete_sub_task(name)
+                     sub_task_status_message = f"Provisioning {engine_type} {name} ({size})"
+                     self.engine_task_ids[name] = progress.add_sub_task(sub_task_status_message, task_id=name, category=TASK_CATEGORY_PROVISIONING)

              # Special handling for CDC_MANAGED_ENGINE - mark ready when any stream starts processing
-             if CDC_MANAGED_ENGINE in self.engine_task_ids:
+             cdc_task = progress._tasks.get(CDC_MANAGED_ENGINE) if CDC_MANAGED_ENGINE in progress._tasks else None
+             if CDC_MANAGED_ENGINE in self.engine_task_ids and cdc_task and not cdc_task.completed:
+
                  has_processing_streams = any(
                      stream.get("next_batch_status", "") == "processing"
                      for stream in all_data
                  )
-                 if has_processing_streams and CDC_MANAGED_ENGINE in progress._tasks and not progress._tasks[CDC_MANAGED_ENGINE].completed:
+                 if has_processing_streams and cdc_task and not cdc_task.completed:
                      progress.update_sub_task(CDC_MANAGED_ENGINE, f"Engine {CDC_MANAGED_ENGINE} ({self.engine_size}) ready")
                      progress.complete_sub_task(CDC_MANAGED_ENGINE)

              self.counter += 1

              # Handle relations data
-             if not ready and relations and isinstance(relations, dict):
+             if relations and isinstance(relations, dict):
                  txn = relations.get("txn", {}) or {}
                  txn_id = txn.get("id", None)

                  # Only show relations subtask if there is a valid txn object
                  if txn_id:
                      status = relations.get("status", "").upper()
-                     state = txn.get("state", "").upper()

                      # Create relations subtask if it doesn't exist
                      if self.relations_task_id is None:
-                         self.relations_task_id = progress.add_sub_task("Populating relations", task_id="relations")
+                         self.relations_task_id = progress.add_sub_task("Populating relations", task_id="relations", category=TASK_CATEGORY_RELATIONS)
+
+                     # Set the start time from the JSON if available (always update)
+                     start_time_ms = relations.get("start_time")
+                     if start_time_ms:
+                         start_time_seconds = start_time_ms / 1000.0
+                         progress._tasks["relations"].added_time = start_time_seconds

                      # Update relations status
-                     if state == "COMPLETED":
+                     if status == "COMPLETED":
                          progress.update_sub_task("relations", f"Relations populated (txn: {txn_id})")
-                         progress.complete_sub_task("relations")
+
+                         # Set the completion time from the JSON if available
+                         end_time_ms = relations.get("end_time")
+                         if end_time_ms:
+                             end_time_seconds = end_time_ms / 1000.0
+                             progress._tasks["relations"].completed_time = end_time_seconds
+
+                         progress.complete_sub_task("relations", record_time=False) # Don't record local time
                      else:
                          progress.update_sub_task("relations", f"Relations populating (txn: {txn_id})")

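The relations subtask above now takes its timing from the server payload instead of the local clock: `start_time` and `end_time` arrive as epoch milliseconds and are divided by 1000.0 so they line up with the seconds-based timestamps the progress tasks keep (which is also why `complete_sub_task` is called with `record_time=False`). For example (payload values are made up):

    from datetime import datetime, timezone

    relations = {"start_time": 1717430400123, "end_time": 1717430412456}  # epoch milliseconds

    start_s = relations["start_time"] / 1000.0  # comparable to time.time()
    end_s = relations["end_time"] / 1000.0
    print(f"relations populated in {end_s - start_s:.1f}s")
    print(datetime.fromtimestamp(start_s, tz=timezone.utc).isoformat())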
@@ -580,85 +694,128 @@ class UseIndexPoller:

              return break_loop

-         poll_with_specified_overhead(lambda: check_ready(progress), overhead_rate=0.1, max_delay=1)
+         poll_with_specified_overhead(lambda: check_ready(progress), overhead_rate=POLL_OVERHEAD_RATE, max_delay=POLL_MAX_DELAY)

      def _post_check(self, progress) -> None:
-         num_tables_altered = 0
+         """Run post-processing checks including change tracking enablement.

-         enabled_tables = []
-         if (
-             self.tables_with_not_enabled_change_tracking
-             and self.res.config.get("ensure_change_tracking", False)
-         ):
-             tables_to_process = self.tables_with_not_enabled_change_tracking
-
-             # Add subtasks based on count
-             if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
-                 # Add individual subtasks for each table
-                 for i, table in enumerate(tables_to_process):
-                     fqn, kind = table
-                     progress.add_sub_task(f"Enabling change tracking on {fqn} ({i+1}/{len(tables_to_process)})", task_id=f"change_tracking_{i}")
-             else:
-                 # Add single summary subtask for many tables
-                 progress.add_sub_task(f"Enabling change tracking on {len(tables_to_process)} tables", task_id="change_tracking_summary")
+         Args:
+             progress: TaskProgress instance for tracking progress

-             # Process tables
+         Raises:
+             SnowflakeChangeTrackingNotEnabledException: If change tracking cannot be enabled
+             SnowflakeTableObjectsException: If there are table-related errors
+             EngineProvisioningFailed: If engine provisioning fails
+         """
+         num_tables_altered = 0
+         failed_tables = [] # Track tables that failed to enable change tracking
+
+         enabled_tables = []
+         if (
+             self.tables_with_not_enabled_change_tracking
+             and self.res.config.get("ensure_change_tracking", False)
+         ):
+             tables_to_process = self.tables_with_not_enabled_change_tracking
+             # Track timing for change tracking
+
+             # Add subtasks based on count
+             if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
+                 # Add individual subtasks for each table
                  for i, table in enumerate(tables_to_process):
-                     try:
-                         fqn, kind = table
-                         self.res._exec(f"ALTER {kind} {fqn} SET CHANGE_TRACKING = TRUE;")
-                         enabled_tables.append(table)
-                         num_tables_altered += 1
-
-                         # Update progress based on subtask type
-                         if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
-                             # Complete individual table subtask
-                             progress.complete_sub_task(f"change_tracking_{i}")
-                         else:
-                             # Update summary subtask with progress
-                             progress.update_sub_task("change_tracking_summary",
-                                 f"Enabling change tracking on {len(tables_to_process)} tables... ({i+1}/{len(tables_to_process)})")
-                     except Exception:
-                         # Handle errors based on subtask type
+                     fqn, kind = table
+                     progress.add_sub_task(f"Enabling change tracking on {fqn} ({i+1}/{len(tables_to_process)})", task_id=f"change_tracking_{i}", category=TASK_CATEGORY_CHANGE_TRACKING)
+             else:
+                 # Add single summary subtask for many tables
+                 progress.add_sub_task(f"Enabling change tracking on {len(tables_to_process)} tables", task_id="change_tracking_summary", category=TASK_CATEGORY_CHANGE_TRACKING)
+
+             # Process tables
+             for i, table in enumerate(tables_to_process):
+                 fqn, kind = table # Unpack outside try block to ensure fqn is defined
+
+                 try:
+                     # Validate table_type to prevent SQL injection
+                     # Should only be TABLE or VIEW
+                     if kind not in ("TABLE", "VIEW"):
+                         logger.error(f"Invalid table kind '{kind}' for {fqn}, skipping")
+                         failed_tables.append((fqn, f"Invalid table kind: {kind}"))
+                         # Mark as failed in progress
                          if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
-                             # Complete the individual subtask even if it failed
                              if f"change_tracking_{i}" in progress._tasks:
+                                 progress.update_sub_task(f"change_tracking_{i}", f"❌ Invalid type: {fqn}")
                                  progress.complete_sub_task(f"change_tracking_{i}")
-                             pass
-
-         # Complete summary subtask if used
-         if len(tables_to_process) > MAX_INDIVIDUAL_SUBTASKS and "change_tracking_summary" in progress._tasks:
-             if num_tables_altered > 0:
-                 s = "s" if num_tables_altered > 1 else ""
-                 progress.update_sub_task("change_tracking_summary", f"Enabled change tracking on {num_tables_altered} table{s}")
-             progress.complete_sub_task("change_tracking_summary")
-
-         # Remove the tables that were successfully enabled from the list of not enabled tables
-         # so that we don't raise an exception for them later
-         self.tables_with_not_enabled_change_tracking = [
-             t for t in self.tables_with_not_enabled_change_tracking if t not in enabled_tables
-         ]
-
-         if self.tables_with_not_enabled_change_tracking:
-             progress.update_main_status("Errors found. See below for details.")
-             raise SnowflakeChangeTrackingNotEnabledException(
-                 self.tables_with_not_enabled_change_tracking
-             )
+                         continue
+
+                     # Execute ALTER statement
+                     # Note: fqn should already be properly quoted from source_info
+                     self.res._exec(f"ALTER {kind} {fqn} SET CHANGE_TRACKING = TRUE;")
+                     enabled_tables.append(table)
+                     num_tables_altered += 1
+
+                     # Update progress based on subtask type
+                     if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
+                         # Complete individual table subtask
+                         progress.complete_sub_task(f"change_tracking_{i}")
+                     else:
+                         # Update summary subtask with progress
+                         progress.update_sub_task("change_tracking_summary",
+                             f"Enabling change tracking on {len(tables_to_process)} tables... ({i+1}/{len(tables_to_process)})")
+                 except Exception as e:
+                     # Log the error for debugging
+                     logger.warning(f"Failed to enable change tracking on {fqn}: {e}")
+                     failed_tables.append((fqn, str(e)))
+
+                     # Handle errors based on subtask type
+                     if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
+                         # Mark the individual subtask as failed and complete it
+                         if f"change_tracking_{i}" in progress._tasks:
+                             progress.update_sub_task(f"change_tracking_{i}", f"❌ Failed: {fqn}")
+                             progress.complete_sub_task(f"change_tracking_{i}")
+                     # Continue processing other tables despite this failure
+
+             # Complete summary subtask if used
+             if len(tables_to_process) > MAX_INDIVIDUAL_SUBTASKS and "change_tracking_summary" in progress._tasks:
+                 if num_tables_altered > 0:
+                     s = "s" if num_tables_altered > 1 else ""
+                     success_msg = f"Enabled change tracking on {num_tables_altered} table{s}"
+                     if failed_tables:
+                         success_msg += f" ({len(failed_tables)} failed)"
+                     progress.update_sub_task("change_tracking_summary", success_msg)
+                 elif failed_tables:
+                     progress.update_sub_task("change_tracking_summary", f"❌ Failed on {len(failed_tables)} table(s)")
+                 progress.complete_sub_task("change_tracking_summary")
+
+             # Log summary of failed tables
+             if failed_tables:
+                 logger.warning(f"Failed to enable change tracking on {len(failed_tables)} table(s)")
+                 for fqn, error in failed_tables:
+                     logger.warning(f" {fqn}: {error}")
+
+             # Remove the tables that were successfully enabled from the list of not enabled tables
+             # so that we don't raise an exception for them later
+             self.tables_with_not_enabled_change_tracking = [
+                 t for t in self.tables_with_not_enabled_change_tracking if t not in enabled_tables
+             ]

-         if self.table_objects_with_other_errors:
-             progress.update_main_status("Errors found. See below for details.")
-             raise SnowflakeTableObjectsException(self.table_objects_with_other_errors)
-         if self.engine_errors:
-             progress.update_main_status("Errors found. See below for details.")
-             # if there is an engine error, probably auto create engine failed
-             # Create a synthetic exception from the first engine error
-             first_error = self.engine_errors[0]
-             error_message = first_error.get("message", "Unknown engine error")
-             synthetic_exception = Exception(f"Engine error: {error_message}")
-             raise EngineProvisioningFailed(self.engine_name, synthetic_exception)
-
-         if num_tables_altered > 0:
-             self._poll_loop(progress)
+         if self.tables_with_not_enabled_change_tracking:
+             progress.update_main_status("Errors found. See below for details.")
+             raise SnowflakeChangeTrackingNotEnabledException(
+                 self.tables_with_not_enabled_change_tracking
+             )
+
+         if self.table_objects_with_other_errors:
+             progress.update_main_status("Errors found. See below for details.")
+             raise SnowflakeTableObjectsException(self.table_objects_with_other_errors)
+         if self.engine_errors:
+             progress.update_main_status("Errors found. See below for details.")
+             # if there is an engine error, probably auto create engine failed
+             # Create a synthetic exception from the first engine error
+             first_error = self.engine_errors[0]
+             error_message = first_error.get("message", "Unknown engine error")
+             synthetic_exception = Exception(f"Engine error: {error_message}")
+             raise EngineProvisioningFailed(self.engine_name, synthetic_exception)
+
+         if num_tables_altered > 0:
+             self._poll_loop(progress)

  class DirectUseIndexPoller(UseIndexPoller):
      """
@@ -743,17 +900,21 @@ class DirectUseIndexPoller(UseIndexPoller):
                  attempt += 1
              return False

+         # Read show_duration_summary config flag (defaults to True for backward compatibility)
+         show_duration_summary = bool(self.res.config.get("show_duration_summary", True))
+
          with create_progress(
              description="Preparing your data...",
              success_message="Done",
              leading_newline=True,
              trailing_newline=True,
+             show_duration_summary=show_duration_summary,
          ) as progress:
              # Add cache usage subtask
              self._add_cache_subtask(progress)

              with debugging.span("poll_direct"):
-                 poll_with_specified_overhead(lambda: check_direct(progress), overhead_rate=0.1, max_delay=1)
+                 poll_with_specified_overhead(lambda: check_direct(progress), overhead_rate=POLL_OVERHEAD_RATE, max_delay=POLL_MAX_DELAY)

              # Run the same post-check logic as UseIndexPoller
              self._post_check(progress)
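
Both pollers now tune `poll_with_specified_overhead` with the shared module-level constants (overhead_rate=0.1, max_delay=2.5s) instead of hard-coded values. The helper itself is defined elsewhere in the package and does not appear in this diff; a rough sketch of an overhead-rate-bounded poll loop, under assumed semantics (sleep roughly `overhead_rate` times the elapsed time, capped at `max_delay`, until the check succeeds):

    import time

    def poll_with_overhead(check, overhead_rate: float = 0.1, max_delay: float = 2.5) -> None:
        # Illustrative stand-in for the package's poll_with_specified_overhead helper.
        start = time.monotonic()
        while not check():
            elapsed = time.monotonic() - start
            # Keep polling overhead near `overhead_rate` of the elapsed time, bounded by max_delay.
            delay = max(0.1, min(max_delay, overhead_rate * elapsed))
            time.sleep(delay)

Raising `max_delay` from 1s to 2.5s simply lets the wait between polls grow longer on long-running syncs, trading a little latency for fewer status queries.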