starrocks-br 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
starrocks_br/restore.py CHANGED
@@ -1,41 +1,55 @@
1
- import time
2
1
  import datetime
3
- from typing import Dict, List, Optional
4
- from . import history, concurrency, logger
2
+ import time
3
+
4
+ from . import concurrency, exceptions, history, logger, timezone, utils
5
+
6
+ MAX_POLLS = 86400 # 1 day
7
+
8
+
9
+ def _calculate_next_interval(current_interval: float, max_interval: float) -> float:
10
+ """Calculate the next polling interval using exponential backoff.
11
+
12
+ Args:
13
+ current_interval: Current polling interval in seconds
14
+ max_interval: Maximum allowed interval in seconds
15
+
16
+ Returns:
17
+ Next interval (min of doubled current interval and max_interval)
18
+ """
19
+ return min(current_interval * 2, max_interval)
5
20
 
6
- MAX_POLLS = 86400 # 1 day
7
21
 
8
22
  def get_snapshot_timestamp(db, repo_name: str, snapshot_name: str) -> str:
9
23
  """Get the backup timestamp for a specific snapshot from the repository.
10
-
24
+
11
25
  Args:
12
26
  db: Database connection
13
27
  repo_name: Repository name
14
28
  snapshot_name: Snapshot name to look up
15
-
29
+
16
30
  Returns:
17
31
  The backup timestamp string
18
-
32
+
19
33
  Raises:
20
34
  ValueError: If snapshot is not found in the repository
21
35
  """
22
- query = f"SHOW SNAPSHOT ON {repo_name} WHERE Snapshot = '{snapshot_name}'"
23
-
36
+ query = f"SHOW SNAPSHOT ON {utils.quote_identifier(repo_name)} WHERE Snapshot = {utils.quote_value(snapshot_name)}"
37
+
24
38
  rows = db.query(query)
25
39
  if not rows:
26
- raise ValueError(f"Snapshot '{snapshot_name}' not found in repository '{repo_name}'")
27
-
40
+ raise exceptions.SnapshotNotFoundError(snapshot_name, repo_name)
41
+
28
42
  # The result should be a single row with columns: Snapshot, Timestamp, Status
29
43
  result = rows[0]
30
-
44
+
31
45
  if isinstance(result, dict):
32
46
  timestamp = result.get("Timestamp")
33
47
  else:
34
48
  timestamp = result[1] if len(result) > 1 else None
35
-
49
+
36
50
  if not timestamp:
37
51
  raise ValueError(f"Could not extract timestamp for snapshot '{snapshot_name}'")
38
-
52
+
39
53
  return timestamp
40
54
 
41
55
 
@@ -48,10 +62,10 @@ def build_partition_restore_command(
48
62
  backup_timestamp: str,
49
63
  ) -> str:
50
64
  """Build RESTORE command for single partition recovery."""
51
- return f"""RESTORE SNAPSHOT {backup_label}
52
- FROM {repository}
53
- DATABASE {database}
54
- ON (TABLE {table} PARTITION ({partition}))
65
+ return f"""RESTORE SNAPSHOT {utils.quote_identifier(backup_label)}
66
+ FROM {utils.quote_identifier(repository)}
67
+ DATABASE {utils.quote_identifier(database)}
68
+ ON (TABLE {utils.quote_identifier(table)} PARTITION ({utils.quote_identifier(partition)}))
55
69
  PROPERTIES ("backup_timestamp" = "{backup_timestamp}")"""
56
70
 
57
71
 
@@ -63,10 +77,10 @@ def build_table_restore_command(
63
77
  backup_timestamp: str,
64
78
  ) -> str:
65
79
  """Build RESTORE command for full table recovery."""
66
- return f"""RESTORE SNAPSHOT {backup_label}
67
- FROM {repository}
68
- DATABASE {database}
69
- ON (TABLE {table})
80
+ return f"""RESTORE SNAPSHOT {utils.quote_identifier(backup_label)}
81
+ FROM {utils.quote_identifier(repository)}
82
+ DATABASE {utils.quote_identifier(database)}
83
+ ON (TABLE {utils.quote_identifier(table)})
70
84
  PROPERTIES ("backup_timestamp" = "{backup_timestamp}")"""
71
85
 
72
86
 
@@ -77,47 +91,57 @@ def build_database_restore_command(
77
91
  backup_timestamp: str,
78
92
  ) -> str:
79
93
  """Build RESTORE command for full database recovery."""
80
- return f"""RESTORE SNAPSHOT {backup_label}
81
- FROM {repository}
82
- DATABASE {database}
94
+ return f"""RESTORE SNAPSHOT {utils.quote_identifier(backup_label)}
95
+ FROM {utils.quote_identifier(repository)}
96
+ DATABASE {utils.quote_identifier(database)}
83
97
  PROPERTIES ("backup_timestamp" = "{backup_timestamp}")"""
84
98
 
85
99
 
86
- def poll_restore_status(db, label: str, database: str, max_polls: int = MAX_POLLS, poll_interval: float = 1.0) -> Dict[str, str]:
100
+ def poll_restore_status(
101
+ db,
102
+ label: str,
103
+ database: str,
104
+ max_polls: int = MAX_POLLS,
105
+ poll_interval: float = 1.0,
106
+ max_poll_interval: float = 60.0,
107
+ ) -> dict[str, str]:
87
108
  """Poll restore status until completion or timeout.
88
-
109
+
89
110
  Note: SHOW RESTORE only returns the LAST restore in a database.
90
111
  We verify that the Label matches our expected label.
91
-
112
+
92
113
  Important: If we see a different label, it means another restore
93
114
  operation overwrote ours and we've lost tracking (race condition).
94
-
115
+
95
116
  Args:
96
117
  db: Database connection
97
118
  label: Expected snapshot label to monitor
98
119
  database: Database name where restore was submitted
99
120
  max_polls: Maximum number of polling attempts
100
- poll_interval: Seconds to wait between polls
101
-
121
+ poll_interval: Initial seconds to wait between polls (exponentially increases)
122
+ max_poll_interval: Maximum interval between polls (default 60 seconds)
123
+
102
124
  Returns dictionary with keys: state, label
103
125
  Possible states: FINISHED, CANCELLED, TIMEOUT, ERROR, LOST
104
126
  """
105
- query = f"SHOW RESTORE FROM {database}"
127
+ query = f"SHOW RESTORE FROM {utils.quote_identifier(database)}"
106
128
  first_poll = True
107
129
  last_state = None
108
130
  poll_count = 0
109
-
131
+ current_interval = poll_interval
132
+
110
133
  for _ in range(max_polls):
111
134
  poll_count += 1
112
135
  try:
113
136
  rows = db.query(query)
114
-
137
+
115
138
  if not rows:
116
- time.sleep(poll_interval)
139
+ time.sleep(current_interval)
140
+ current_interval = _calculate_next_interval(current_interval, max_poll_interval)
117
141
  continue
118
-
142
+
119
143
  result = rows[0]
120
-
144
+
121
145
  if isinstance(result, dict):
122
146
  snapshot_label = result.get("Label", "")
123
147
  state = result.get("State", "UNKNOWN")
@@ -125,29 +149,31 @@ def poll_restore_status(db, label: str, database: str, max_polls: int = MAX_POLL
125
149
  # Tuple format: JobId, Label, Timestamp, DbName, State, ...
126
150
  snapshot_label = result[1] if len(result) > 1 else ""
127
151
  state = result[4] if len(result) > 4 else "UNKNOWN"
128
-
152
+
129
153
  if snapshot_label != label and snapshot_label:
130
154
  if first_poll:
131
155
  first_poll = False
132
- time.sleep(poll_interval)
156
+ time.sleep(current_interval)
157
+ current_interval = _calculate_next_interval(current_interval, max_poll_interval)
133
158
  continue
134
159
  else:
135
160
  return {"state": "LOST", "label": label}
136
-
161
+
137
162
  first_poll = False
138
-
163
+
139
164
  if state != last_state or poll_count % 10 == 0:
140
165
  logger.progress(f"Restore status: {state} (poll {poll_count}/{max_polls})")
141
166
  last_state = state
142
-
167
+
143
168
  if state in ["FINISHED", "CANCELLED", "UNKNOWN"]:
144
169
  return {"state": state, "label": label}
145
-
146
- time.sleep(poll_interval)
147
-
170
+
171
+ time.sleep(current_interval)
172
+ current_interval = _calculate_next_interval(current_interval, max_poll_interval)
173
+
148
174
  except Exception:
149
175
  return {"state": "ERROR", "label": label}
150
-
176
+
151
177
  return {"state": "TIMEOUT", "label": label}
152
178
 
153
179
 
@@ -161,13 +187,14 @@ def execute_restore(
161
187
  max_polls: int = MAX_POLLS,
162
188
  poll_interval: float = 1.0,
163
189
  scope: str = "restore",
164
- ) -> Dict:
190
+ ) -> dict:
165
191
  """Execute a complete restore workflow: submit command and monitor progress.
166
-
192
+
167
193
  Returns dictionary with keys: success, final_status, error_message
168
194
  """
169
- started_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
170
-
195
+ cluster_tz = db.timezone
196
+ started_at = timezone.get_current_time_in_cluster_tz(cluster_tz)
197
+
171
198
  try:
172
199
  db.execute(restore_command.strip())
173
200
  except Exception as e:
@@ -175,17 +202,17 @@ def execute_restore(
175
202
  return {
176
203
  "success": False,
177
204
  "final_status": None,
178
- "error_message": f"Failed to submit restore command: {str(e)}"
205
+ "error_message": f"Failed to submit restore command: {str(e)}",
179
206
  }
180
-
207
+
181
208
  label = backup_label
182
-
209
+
183
210
  try:
184
211
  final_status = poll_restore_status(db, label, database, max_polls, poll_interval)
185
-
212
+
186
213
  success = final_status["state"] == "FINISHED"
187
- finished_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
188
-
214
+ finished_at = timezone.get_current_time_in_cluster_tz(cluster_tz)
215
+
189
216
  try:
190
217
  history.log_restore(
191
218
  db,
@@ -202,147 +229,153 @@ def execute_restore(
202
229
  )
203
230
  except Exception as e:
204
231
  logger.error(f"Failed to log restore history: {str(e)}")
205
-
232
+
206
233
  try:
207
- concurrency.complete_job_slot(db, scope=scope, label=label, final_state=final_status["state"])
234
+ concurrency.complete_job_slot(
235
+ db, scope=scope, label=label, final_state=final_status["state"]
236
+ )
208
237
  except Exception as e:
209
238
  logger.error(f"Failed to complete job slot: {str(e)}")
210
-
239
+
211
240
  return {
212
241
  "success": success,
213
242
  "final_status": final_status,
214
- "error_message": None if success else f"Restore failed with state: {final_status['state']}"
243
+ "error_message": None
244
+ if success
245
+ else f"Restore failed with state: {final_status['state']}",
215
246
  }
216
-
247
+
217
248
  except Exception as e:
218
249
  logger.error(f"Restore execution failed: {str(e)}")
219
- return {
220
- "success": False,
221
- "final_status": None,
222
- "error_message": str(e)
223
- }
250
+ return {"success": False, "final_status": None, "error_message": str(e)}
224
251
 
225
252
 
226
- def find_restore_pair(db, target_label: str) -> List[str]:
253
+ def find_restore_pair(db, target_label: str) -> list[str]:
227
254
  """Find the correct sequence of backups needed for restore.
228
-
255
+
229
256
  Args:
230
257
  db: Database connection
231
258
  target_label: The backup label to restore to
232
-
259
+
233
260
  Returns:
234
261
  List of backup labels in restore order [base_full_backup, target_label]
235
262
  or [target_label] if target is a full backup
236
-
263
+
237
264
  Raises:
238
265
  ValueError: If target label not found or incremental has no preceding full backup
239
266
  """
240
267
  query = f"""
241
268
  SELECT label, backup_type, finished_at
242
269
  FROM ops.backup_history
243
- WHERE label = '{target_label}'
270
+ WHERE label = {utils.quote_value(target_label)}
244
271
  AND status = 'FINISHED'
245
272
  """
246
-
273
+
247
274
  rows = db.query(query)
248
275
  if not rows:
249
- raise ValueError(f"Backup label '{target_label}' not found or not successful")
250
-
251
- target_info = {
252
- "label": rows[0][0],
253
- "backup_type": rows[0][1],
254
- "finished_at": rows[0][2]
255
- }
256
-
276
+ raise exceptions.BackupLabelNotFoundError(target_label)
277
+
278
+ target_info = {"label": rows[0][0], "backup_type": rows[0][1], "finished_at": rows[0][2]}
279
+
257
280
  if target_info["backup_type"] == "full":
258
281
  return [target_label]
259
-
282
+
260
283
  if target_info["backup_type"] == "incremental":
261
- database_name = target_label.split('_')[0]
262
-
284
+ database_name = target_label.split("_")[0]
285
+
263
286
  full_backup_query = f"""
264
287
  SELECT label, backup_type, finished_at
265
288
  FROM ops.backup_history
266
289
  WHERE backup_type = 'full'
267
290
  AND status = 'FINISHED'
268
- AND label LIKE '{database_name}_%'
269
- AND finished_at < '{target_info["finished_at"]}'
291
+ AND label LIKE {utils.quote_value(f"{database_name}_%")}
292
+ AND finished_at < {utils.quote_value(target_info["finished_at"])}
270
293
  ORDER BY finished_at DESC
271
294
  LIMIT 1
272
295
  """
273
-
296
+
274
297
  full_rows = db.query(full_backup_query)
275
298
  if not full_rows:
276
- raise ValueError(f"No successful full backup found before incremental '{target_label}'")
277
-
299
+ raise exceptions.NoSuccessfulFullBackupFoundError(target_label)
300
+
278
301
  base_full_backup = full_rows[0][0]
279
302
  return [base_full_backup, target_label]
280
-
281
- raise ValueError(f"Unknown backup type '{target_info['backup_type']}' for label '{target_label}'")
303
+
304
+ raise ValueError(
305
+ f"Unknown backup type '{target_info['backup_type']}' for label '{target_label}'"
306
+ )
282
307
 
283
308
 
284
- def get_tables_from_backup(db, label: str, group: Optional[str] = None, table: Optional[str] = None, database: Optional[str] = None) -> List[str]:
309
+ def get_tables_from_backup(
310
+ db,
311
+ label: str,
312
+ group: str | None = None,
313
+ table: str | None = None,
314
+ database: str | None = None,
315
+ ) -> list[str]:
285
316
  """Get list of tables to restore from backup manifest.
286
-
317
+
287
318
  Args:
288
319
  db: Database connection
289
320
  label: Backup label
290
321
  group: Optional inventory group to filter tables
291
322
  table: Optional table name to filter (single table, database comes from database parameter)
292
323
  database: Database name (required if table is specified)
293
-
324
+
294
325
  Returns:
295
326
  List of table names to restore (format: database.table)
296
-
327
+
297
328
  Raises:
298
329
  ValueError: If both group and table are specified
299
330
  ValueError: If table is specified but database is not provided
300
331
  ValueError: If table is specified but not found in backup
301
332
  """
302
333
  if group and table:
303
- raise ValueError("Cannot specify both --group and --table. Use --table for single table restore or --group for inventory group restore.")
304
-
334
+ raise exceptions.InvalidTableNameError(table, "Cannot specify both --group and --table")
335
+
305
336
  if table and not database:
306
- raise ValueError("database parameter is required when table is specified")
307
-
337
+ raise exceptions.InvalidTableNameError(
338
+ table, "database parameter is required when table is specified"
339
+ )
340
+
308
341
  query = f"""
309
342
  SELECT DISTINCT database_name, table_name
310
343
  FROM ops.backup_partitions
311
- WHERE label = '{label}'
344
+ WHERE label = {utils.quote_value(label)}
312
345
  ORDER BY database_name, table_name
313
346
  """
314
-
347
+
315
348
  rows = db.query(query)
316
349
  if not rows:
317
350
  return []
318
-
351
+
319
352
  tables = [f"{row[0]}.{row[1]}" for row in rows]
320
-
353
+
321
354
  if table:
322
355
  target_table = f"{database}.{table}"
323
356
  filtered_tables = [t for t in tables if t == target_table]
324
-
357
+
325
358
  if not filtered_tables:
326
- raise ValueError(f"Table '{table}' not found in backup '{label}' for database '{database}'")
327
-
359
+ raise exceptions.TableNotFoundInBackupError(table, label, database)
360
+
328
361
  return filtered_tables
329
-
362
+
330
363
  if group:
331
364
  group_query = f"""
332
365
  SELECT database_name, table_name
333
366
  FROM ops.table_inventory
334
- WHERE inventory_group = '{group}'
367
+ WHERE inventory_group = {utils.quote_value(group)}
335
368
  """
336
-
369
+
337
370
  group_rows = db.query(group_query)
338
371
  if not group_rows:
339
372
  return []
340
-
373
+
341
374
  group_tables = set()
342
375
  for row in group_rows:
343
376
  database_name, table_name = row[0], row[1]
344
- if table_name == '*':
345
- show_tables_query = f"SHOW TABLES FROM {database_name}"
377
+ if table_name == "*":
378
+ show_tables_query = f"SHOW TABLES FROM {utils.quote_identifier(database_name)}"
346
379
  try:
347
380
  tables_rows = db.query(show_tables_query)
348
381
  for table_row in tables_rows:
@@ -351,15 +384,22 @@ def get_tables_from_backup(db, label: str, group: Optional[str] = None, table: O
351
384
  continue
352
385
  else:
353
386
  group_tables.add(f"{database_name}.{table_name}")
354
-
387
+
355
388
  tables = [table for table in tables if table in group_tables]
356
-
389
+
357
390
  return tables
358
391
 
359
392
 
360
- def execute_restore_flow(db, repo_name: str, restore_pair: List[str], tables_to_restore: List[str], rename_suffix: str = "_restored", skip_confirmation: bool = False) -> Dict:
393
+ def execute_restore_flow(
394
+ db,
395
+ repo_name: str,
396
+ restore_pair: list[str],
397
+ tables_to_restore: list[str],
398
+ rename_suffix: str = "_restored",
399
+ skip_confirmation: bool = False,
400
+ ) -> dict:
361
401
  """Execute the complete restore flow with safety measures.
362
-
402
+
363
403
  Args:
364
404
  db: Database connection
365
405
  repo_name: Repository name
@@ -367,22 +407,16 @@ def execute_restore_flow(db, repo_name: str, restore_pair: List[str], tables_to_
367
407
  tables_to_restore: List of tables to restore (format: database.table)
368
408
  rename_suffix: Suffix for temporary tables
369
409
  skip_confirmation: If True, skip interactive confirmation prompt
370
-
410
+
371
411
  Returns:
372
412
  Dictionary with success status and details
373
413
  """
374
414
  if not restore_pair:
375
- return {
376
- "success": False,
377
- "error_message": "No restore pair provided"
378
- }
379
-
415
+ return {"success": False, "error_message": "No restore pair provided"}
416
+
380
417
  if not tables_to_restore:
381
- return {
382
- "success": False,
383
- "error_message": "No tables to restore"
384
- }
385
-
418
+ return {"success": False, "error_message": "No tables to restore"}
419
+
386
420
  logger.info("")
387
421
  logger.info("=== RESTORE PLAN ===")
388
422
  logger.info(f"Repository: {repo_name}")
@@ -392,128 +426,143 @@ def execute_restore_flow(db, repo_name: str, restore_pair: List[str], tables_to_
392
426
  logger.info("")
393
427
  logger.info("This will restore data to temporary tables and then perform atomic rename.")
394
428
  logger.warning("WARNING: This operation will replace existing tables!")
395
-
429
+
396
430
  if not skip_confirmation:
397
431
  confirmation = input("\nDo you want to proceed? [Y/n]: ").strip()
398
- if confirmation.lower() != 'y':
399
- return {
400
- "success": False,
401
- "error_message": "Restore operation cancelled by user"
402
- }
432
+ if confirmation.lower() != "y":
433
+ raise exceptions.RestoreOperationCancelledError()
403
434
  else:
404
435
  logger.info("Proceeding automatically (--yes flag provided)")
405
-
436
+
406
437
  try:
407
- database_name = tables_to_restore[0].split('.')[0]
408
-
438
+ database_name = tables_to_restore[0].split(".")[0]
439
+
409
440
  base_label = restore_pair[0]
410
441
  logger.info("")
411
442
  logger.info(f"Step 1: Restoring base backup '{base_label}'...")
412
-
443
+
413
444
  base_timestamp = get_snapshot_timestamp(db, repo_name, base_label)
414
-
445
+
415
446
  base_restore_command = _build_restore_command_with_rename(
416
447
  base_label, repo_name, tables_to_restore, rename_suffix, database_name, base_timestamp
417
448
  )
418
-
449
+
419
450
  base_result = execute_restore(
420
451
  db, base_restore_command, base_label, "full", repo_name, database_name, scope="restore"
421
452
  )
422
-
453
+
423
454
  if not base_result["success"]:
424
455
  return {
425
456
  "success": False,
426
- "error_message": f"Base restore failed: {base_result['error_message']}"
457
+ "error_message": f"Base restore failed: {base_result['error_message']}",
427
458
  }
428
-
459
+
429
460
  logger.success("Base restore completed successfully")
430
-
461
+
431
462
  if len(restore_pair) > 1:
432
463
  incremental_label = restore_pair[1]
433
464
  logger.info("")
434
465
  logger.info(f"Step 2: Applying incremental backup '{incremental_label}'...")
435
-
466
+
436
467
  incremental_timestamp = get_snapshot_timestamp(db, repo_name, incremental_label)
437
-
468
+
438
469
  incremental_restore_command = _build_restore_command_without_rename(
439
- incremental_label, repo_name, tables_to_restore, database_name, incremental_timestamp
470
+ incremental_label,
471
+ repo_name,
472
+ tables_to_restore,
473
+ database_name,
474
+ incremental_timestamp,
440
475
  )
441
-
476
+
442
477
  incremental_result = execute_restore(
443
- db, incremental_restore_command, incremental_label, "incremental", repo_name, database_name, scope="restore"
478
+ db,
479
+ incremental_restore_command,
480
+ incremental_label,
481
+ "incremental",
482
+ repo_name,
483
+ database_name,
484
+ scope="restore",
444
485
  )
445
-
486
+
446
487
  if not incremental_result["success"]:
447
488
  return {
448
489
  "success": False,
449
- "error_message": f"Incremental restore failed: {incremental_result['error_message']}"
490
+ "error_message": f"Incremental restore failed: {incremental_result['error_message']}",
450
491
  }
451
-
492
+
452
493
  logger.success("Incremental restore completed successfully")
453
-
494
+
454
495
  logger.info("")
455
496
  logger.info("Step 3: Performing atomic rename...")
456
497
  rename_result = _perform_atomic_rename(db, tables_to_restore, rename_suffix)
457
-
498
+
458
499
  if not rename_result["success"]:
459
500
  return {
460
501
  "success": False,
461
- "error_message": f"Atomic rename failed: {rename_result['error_message']}"
502
+ "error_message": f"Atomic rename failed: {rename_result['error_message']}",
462
503
  }
463
-
504
+
464
505
  logger.success("Atomic rename completed successfully")
465
-
506
+
466
507
  return {
467
508
  "success": True,
468
- "message": f"Restore completed successfully. Restored {len(tables_to_restore)} tables."
509
+ "message": f"Restore completed successfully. Restored {len(tables_to_restore)} tables.",
469
510
  }
470
-
511
+
471
512
  except Exception as e:
472
- return {
473
- "success": False,
474
- "error_message": f"Restore flow failed: {str(e)}"
475
- }
513
+ return {"success": False, "error_message": f"Restore flow failed: {str(e)}"}
476
514
 
477
515
 
478
- def _build_restore_command_with_rename(backup_label: str, repo_name: str, tables: List[str], rename_suffix: str, database: str, backup_timestamp: str) -> str:
516
+ def _build_restore_command_with_rename(
517
+ backup_label: str,
518
+ repo_name: str,
519
+ tables: list[str],
520
+ rename_suffix: str,
521
+ database: str,
522
+ backup_timestamp: str,
523
+ ) -> str:
479
524
  """Build restore command with AS clause for temporary table names."""
480
525
  table_clauses = []
481
526
  for table in tables:
482
- _, table_name = table.split('.', 1)
527
+ _, table_name = table.split(".", 1)
483
528
  temp_table_name = f"{table_name}{rename_suffix}"
484
- table_clauses.append(f"TABLE {table_name} AS {temp_table_name}")
485
-
529
+ table_clauses.append(
530
+ f"TABLE {utils.quote_identifier(table_name)} AS {utils.quote_identifier(temp_table_name)}"
531
+ )
532
+
486
533
  on_clause = ",\n ".join(table_clauses)
487
-
488
- return f"""RESTORE SNAPSHOT {backup_label}
489
- FROM {repo_name}
490
- DATABASE {database}
534
+
535
+ return f"""RESTORE SNAPSHOT {utils.quote_identifier(backup_label)}
536
+ FROM {utils.quote_identifier(repo_name)}
537
+ DATABASE {utils.quote_identifier(database)}
491
538
  ON ({on_clause})
492
539
  PROPERTIES ("backup_timestamp" = "{backup_timestamp}")"""
493
540
 
494
541
 
495
- def _build_restore_command_without_rename(backup_label: str, repo_name: str, tables: List[str], database: str, backup_timestamp: str) -> str:
542
+ def _build_restore_command_without_rename(
543
+ backup_label: str, repo_name: str, tables: list[str], database: str, backup_timestamp: str
544
+ ) -> str:
496
545
  """Build restore command without AS clause (for incremental restores to existing temp tables)."""
497
546
  table_clauses = []
498
547
  for table in tables:
499
- _, table_name = table.split('.', 1)
500
- table_clauses.append(f"TABLE {table_name}")
501
-
548
+ _, table_name = table.split(".", 1)
549
+ table_clauses.append(f"TABLE {utils.quote_identifier(table_name)}")
550
+
502
551
  on_clause = ",\n ".join(table_clauses)
503
-
504
- return f"""RESTORE SNAPSHOT {backup_label}
505
- FROM {repo_name}
506
- DATABASE {database}
552
+
553
+ return f"""RESTORE SNAPSHOT {utils.quote_identifier(backup_label)}
554
+ FROM {utils.quote_identifier(repo_name)}
555
+ DATABASE {utils.quote_identifier(database)}
507
556
  ON ({on_clause})
508
557
  PROPERTIES ("backup_timestamp" = "{backup_timestamp}")"""
509
558
 
510
559
 
511
560
  def _generate_timestamped_backup_name(table_name: str) -> str:
512
561
  """Generate a timestamped backup table name.
513
-
562
+
514
563
  Args:
515
564
  table_name: Original table name
516
-
565
+
517
566
  Returns:
518
567
  Timestamped backup name in format: {table_name}_backup_YYYYMMDD_HHMMSS
519
568
  """
@@ -521,25 +570,26 @@ def _generate_timestamped_backup_name(table_name: str) -> str:
521
570
  return f"{table_name}_backup_{timestamp}"
522
571
 
523
572
 
524
- def _perform_atomic_rename(db, tables: List[str], rename_suffix: str) -> Dict:
573
+ def _perform_atomic_rename(db, tables: list[str], rename_suffix: str) -> dict:
525
574
  """Perform atomic rename of temporary tables to make them live."""
526
575
  try:
527
576
  rename_statements = []
528
577
  for table in tables:
529
- database, table_name = table.split('.', 1)
578
+ database, table_name = table.split(".", 1)
530
579
  temp_table_name = f"{table_name}{rename_suffix}"
531
580
  backup_table_name = _generate_timestamped_backup_name(table_name)
532
-
533
- rename_statements.append(f"ALTER TABLE {database}.{table_name} RENAME {backup_table_name}")
534
- rename_statements.append(f"ALTER TABLE {database}.{temp_table_name} RENAME {table_name}")
535
-
581
+
582
+ rename_statements.append(
583
+ f"ALTER TABLE {utils.build_qualified_table_name(database, table_name)} RENAME {utils.quote_identifier(backup_table_name)}"
584
+ )
585
+ rename_statements.append(
586
+ f"ALTER TABLE {utils.build_qualified_table_name(database, temp_table_name)} RENAME {utils.quote_identifier(table_name)}"
587
+ )
588
+
536
589
  for statement in rename_statements:
537
590
  db.execute(statement)
538
-
591
+
539
592
  return {"success": True}
540
-
593
+
541
594
  except Exception as e:
542
- return {
543
- "success": False,
544
- "error_message": f"Failed to perform atomic rename: {str(e)}"
545
- }
595
+ return {"success": False, "error_message": f"Failed to perform atomic rename: {str(e)}"}