starrocks-br 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
starrocks_br/restore.py CHANGED
@@ -1,41 +1,56 @@
1
- import time
2
1
  import datetime
3
- from typing import Dict, List, Optional
4
- from . import history, concurrency, logger
2
+ import time
3
+ from typing import Optional
4
+
5
+ from . import concurrency, history, logger, timezone, utils
6
+
7
+ MAX_POLLS = 86400 # 1 day
8
+
9
+
10
+ def _calculate_next_interval(current_interval: float, max_interval: float) -> float:
11
+ """Calculate the next polling interval using exponential backoff.
12
+
13
+ Args:
14
+ current_interval: Current polling interval in seconds
15
+ max_interval: Maximum allowed interval in seconds
16
+
17
+ Returns:
18
+ Next interval (min of doubled current interval and max_interval)
19
+ """
20
+ return min(current_interval * 2, max_interval)
5
21
 
6
- MAX_POLLS = 86400 # 1 day
7
22
 
8
23
def get_snapshot_timestamp(db, repo_name: str, snapshot_name: str) -> str:
    """Look up the backup timestamp recorded for a snapshot in a repository.

    Runs ``SHOW SNAPSHOT`` filtered to ``snapshot_name`` and reads the
    timestamp from the first returned row. A row may be a dict keyed by
    column name (``"Timestamp"``) or a sequence whose second element is the
    timestamp (Snapshot, Timestamp, Status, ...).

    Args:
        db: Database connection exposing ``query``
        repo_name: Repository name
        snapshot_name: Snapshot name to look up

    Returns:
        The backup timestamp string

    Raises:
        ValueError: If the snapshot is not in the repository, or no timestamp
            can be extracted from the result row
    """
    repo_ident = utils.quote_identifier(repo_name)
    snapshot_literal = utils.quote_value(snapshot_name)
    rows = db.query(f"SHOW SNAPSHOT ON {repo_ident} WHERE Snapshot = {snapshot_literal}")
    if not rows:
        raise ValueError(f"Snapshot '{snapshot_name}' not found in repository '{repo_name}'")

    first_row = rows[0]
    if isinstance(first_row, dict):
        timestamp = first_row.get("Timestamp")
    elif len(first_row) > 1:
        timestamp = first_row[1]
    else:
        timestamp = None

    if not timestamp:
        raise ValueError(f"Could not extract timestamp for snapshot '{snapshot_name}'")
    return timestamp
40
55
 
41
56
 
@@ -48,10 +63,10 @@ def build_partition_restore_command(
48
63
  backup_timestamp: str,
49
64
  ) -> str:
50
65
  """Build RESTORE command for single partition recovery."""
51
- return f"""RESTORE SNAPSHOT {backup_label}
52
- FROM {repository}
53
- DATABASE {database}
54
- ON (TABLE {table} PARTITION ({partition}))
66
+ return f"""RESTORE SNAPSHOT {utils.quote_identifier(backup_label)}
67
+ FROM {utils.quote_identifier(repository)}
68
+ DATABASE {utils.quote_identifier(database)}
69
+ ON (TABLE {utils.quote_identifier(table)} PARTITION ({utils.quote_identifier(partition)}))
55
70
  PROPERTIES ("backup_timestamp" = "{backup_timestamp}")"""
56
71
 
57
72
 
@@ -63,10 +78,10 @@ def build_table_restore_command(
63
78
  backup_timestamp: str,
64
79
  ) -> str:
65
80
  """Build RESTORE command for full table recovery."""
66
- return f"""RESTORE SNAPSHOT {backup_label}
67
- FROM {repository}
68
- DATABASE {database}
69
- ON (TABLE {table})
81
+ return f"""RESTORE SNAPSHOT {utils.quote_identifier(backup_label)}
82
+ FROM {utils.quote_identifier(repository)}
83
+ DATABASE {utils.quote_identifier(database)}
84
+ ON (TABLE {utils.quote_identifier(table)})
70
85
  PROPERTIES ("backup_timestamp" = "{backup_timestamp}")"""
71
86
 
72
87
 
@@ -77,47 +92,57 @@ def build_database_restore_command(
77
92
  backup_timestamp: str,
78
93
  ) -> str:
79
94
  """Build RESTORE command for full database recovery."""
80
- return f"""RESTORE SNAPSHOT {backup_label}
81
- FROM {repository}
82
- DATABASE {database}
95
+ return f"""RESTORE SNAPSHOT {utils.quote_identifier(backup_label)}
96
+ FROM {utils.quote_identifier(repository)}
97
+ DATABASE {utils.quote_identifier(database)}
83
98
  PROPERTIES ("backup_timestamp" = "{backup_timestamp}")"""
84
99
 
85
100
 
86
- def poll_restore_status(db, label: str, database: str, max_polls: int = MAX_POLLS, poll_interval: float = 1.0) -> Dict[str, str]:
101
+ def poll_restore_status(
102
+ db,
103
+ label: str,
104
+ database: str,
105
+ max_polls: int = MAX_POLLS,
106
+ poll_interval: float = 1.0,
107
+ max_poll_interval: float = 60.0,
108
+ ) -> dict[str, str]:
87
109
  """Poll restore status until completion or timeout.
88
-
110
+
89
111
  Note: SHOW RESTORE only returns the LAST restore in a database.
90
112
  We verify that the Label matches our expected label.
91
-
113
+
92
114
  Important: If we see a different label, it means another restore
93
115
  operation overwrote ours and we've lost tracking (race condition).
94
-
116
+
95
117
  Args:
96
118
  db: Database connection
97
119
  label: Expected snapshot label to monitor
98
120
  database: Database name where restore was submitted
99
121
  max_polls: Maximum number of polling attempts
100
- poll_interval: Seconds to wait between polls
101
-
122
+ poll_interval: Initial seconds to wait between polls (exponentially increases)
123
+ max_poll_interval: Maximum interval between polls (default 60 seconds)
124
+
102
125
  Returns dictionary with keys: state, label
103
126
  Possible states: FINISHED, CANCELLED, TIMEOUT, ERROR, LOST
104
127
  """
105
- query = f"SHOW RESTORE FROM {database}"
128
+ query = f"SHOW RESTORE FROM {utils.quote_identifier(database)}"
106
129
  first_poll = True
107
130
  last_state = None
108
131
  poll_count = 0
109
-
132
+ current_interval = poll_interval
133
+
110
134
  for _ in range(max_polls):
111
135
  poll_count += 1
112
136
  try:
113
137
  rows = db.query(query)
114
-
138
+
115
139
  if not rows:
116
- time.sleep(poll_interval)
140
+ time.sleep(current_interval)
141
+ current_interval = _calculate_next_interval(current_interval, max_poll_interval)
117
142
  continue
118
-
143
+
119
144
  result = rows[0]
120
-
145
+
121
146
  if isinstance(result, dict):
122
147
  snapshot_label = result.get("Label", "")
123
148
  state = result.get("State", "UNKNOWN")
@@ -125,29 +150,31 @@ def poll_restore_status(db, label: str, database: str, max_polls: int = MAX_POLL
125
150
  # Tuple format: JobId, Label, Timestamp, DbName, State, ...
126
151
  snapshot_label = result[1] if len(result) > 1 else ""
127
152
  state = result[4] if len(result) > 4 else "UNKNOWN"
128
-
153
+
129
154
  if snapshot_label != label and snapshot_label:
130
155
  if first_poll:
131
156
  first_poll = False
132
- time.sleep(poll_interval)
157
+ time.sleep(current_interval)
158
+ current_interval = _calculate_next_interval(current_interval, max_poll_interval)
133
159
  continue
134
160
  else:
135
161
  return {"state": "LOST", "label": label}
136
-
162
+
137
163
  first_poll = False
138
-
164
+
139
165
  if state != last_state or poll_count % 10 == 0:
140
166
  logger.progress(f"Restore status: {state} (poll {poll_count}/{max_polls})")
141
167
  last_state = state
142
-
168
+
143
169
  if state in ["FINISHED", "CANCELLED", "UNKNOWN"]:
144
170
  return {"state": state, "label": label}
145
-
146
- time.sleep(poll_interval)
147
-
171
+
172
+ time.sleep(current_interval)
173
+ current_interval = _calculate_next_interval(current_interval, max_poll_interval)
174
+
148
175
  except Exception:
149
176
  return {"state": "ERROR", "label": label}
150
-
177
+
151
178
  return {"state": "TIMEOUT", "label": label}
152
179
 
153
180
 
@@ -161,13 +188,14 @@ def execute_restore(
161
188
  max_polls: int = MAX_POLLS,
162
189
  poll_interval: float = 1.0,
163
190
  scope: str = "restore",
164
- ) -> Dict:
191
+ ) -> dict:
165
192
  """Execute a complete restore workflow: submit command and monitor progress.
166
-
193
+
167
194
  Returns dictionary with keys: success, final_status, error_message
168
195
  """
169
- started_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
170
-
196
+ cluster_tz = db.timezone
197
+ started_at = timezone.get_current_time_in_cluster_tz(cluster_tz)
198
+
171
199
  try:
172
200
  db.execute(restore_command.strip())
173
201
  except Exception as e:
@@ -175,17 +203,17 @@ def execute_restore(
175
203
  return {
176
204
  "success": False,
177
205
  "final_status": None,
178
- "error_message": f"Failed to submit restore command: {str(e)}"
206
+ "error_message": f"Failed to submit restore command: {str(e)}",
179
207
  }
180
-
208
+
181
209
  label = backup_label
182
-
210
+
183
211
  try:
184
212
  final_status = poll_restore_status(db, label, database, max_polls, poll_interval)
185
-
213
+
186
214
  success = final_status["state"] == "FINISHED"
187
- finished_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
188
-
215
+ finished_at = timezone.get_current_time_in_cluster_tz(cluster_tz)
216
+
189
217
  try:
190
218
  history.log_restore(
191
219
  db,
@@ -202,147 +230,155 @@ def execute_restore(
202
230
  )
203
231
  except Exception as e:
204
232
  logger.error(f"Failed to log restore history: {str(e)}")
205
-
233
+
206
234
  try:
207
- concurrency.complete_job_slot(db, scope=scope, label=label, final_state=final_status["state"])
235
+ concurrency.complete_job_slot(
236
+ db, scope=scope, label=label, final_state=final_status["state"]
237
+ )
208
238
  except Exception as e:
209
239
  logger.error(f"Failed to complete job slot: {str(e)}")
210
-
240
+
211
241
  return {
212
242
  "success": success,
213
243
  "final_status": final_status,
214
- "error_message": None if success else f"Restore failed with state: {final_status['state']}"
244
+ "error_message": None
245
+ if success
246
+ else f"Restore failed with state: {final_status['state']}",
215
247
  }
216
-
248
+
217
249
  except Exception as e:
218
250
  logger.error(f"Restore execution failed: {str(e)}")
219
- return {
220
- "success": False,
221
- "final_status": None,
222
- "error_message": str(e)
223
- }
251
+ return {"success": False, "final_status": None, "error_message": str(e)}
224
252
 
225
253
 
226
def find_restore_pair(db, target_label: str) -> list[str]:
    """Find the correct sequence of backups needed for restore.

    A full backup is restored on its own. An incremental backup is paired with
    the most recent successful full backup whose label carries the same
    database prefix and that finished before the incremental.

    Args:
        db: Database connection
        target_label: The backup label to restore to

    Returns:
        List of backup labels in restore order [base_full_backup, target_label]
        or [target_label] if target is a full backup

    Raises:
        ValueError: If target label not found, its backup type is unknown, or an
            incremental has no preceding full backup
    """
    query = f"""
    SELECT label, backup_type, finished_at
    FROM ops.backup_history
    WHERE label = {utils.quote_value(target_label)}
    AND status = 'FINISHED'
    """

    rows = db.query(query)
    if not rows:
        raise ValueError(f"Backup label '{target_label}' not found or not successful")

    target_info = {"label": rows[0][0], "backup_type": rows[0][1], "finished_at": rows[0][2]}

    if target_info["backup_type"] == "full":
        return [target_label]

    if target_info["backup_type"] == "incremental":
        # NOTE(review): the database is inferred from the label text before the
        # first "_"; this assumes labels look like "<database>_..." and that
        # database names contain no "_". Confirm against the label generator.
        database_name = target_label.split("_")[0]
        label_prefix = f"{database_name}_"

        # In SQL LIKE, "_" is a single-character wildcard, so this pattern also
        # matches labels of other databases (e.g. "salesdb_..." for "sales").
        # It is only a coarse pre-filter; the exact prefix is checked in Python
        # below, which is why no LIMIT is applied to the query.
        like_literal = utils.quote_value(f"{database_name}_%")
        finished_before_literal = utils.quote_value(target_info["finished_at"])
        full_backup_query = f"""
        SELECT label, backup_type, finished_at
        FROM ops.backup_history
        WHERE backup_type = 'full'
        AND status = 'FINISHED'
        AND label LIKE {like_literal}
        AND finished_at < {finished_before_literal}
        ORDER BY finished_at DESC
        """

        full_rows = db.query(full_backup_query) or []
        # Rows are ordered newest first, so the first exact-prefix match is the
        # latest full backup that precedes the incremental.
        base_full_backup = next(
            (row[0] for row in full_rows if str(row[0]).startswith(label_prefix)),
            None,
        )
        if base_full_backup is None:
            raise ValueError(f"No successful full backup found before incremental '{target_label}'")

        return [base_full_backup, target_label]

    raise ValueError(
        f"Unknown backup type '{target_info['backup_type']}' for label '{target_label}'"
    )
283
308
 
284
- def get_tables_from_backup(db, label: str, group: Optional[str] = None, table: Optional[str] = None, database: Optional[str] = None) -> List[str]:
309
+
310
+ def get_tables_from_backup(
311
+ db,
312
+ label: str,
313
+ group: Optional[str] = None,
314
+ table: Optional[str] = None,
315
+ database: Optional[str] = None,
316
+ ) -> list[str]:
285
317
  """Get list of tables to restore from backup manifest.
286
-
318
+
287
319
  Args:
288
320
  db: Database connection
289
321
  label: Backup label
290
322
  group: Optional inventory group to filter tables
291
323
  table: Optional table name to filter (single table, database comes from database parameter)
292
324
  database: Database name (required if table is specified)
293
-
325
+
294
326
  Returns:
295
327
  List of table names to restore (format: database.table)
296
-
328
+
297
329
  Raises:
298
330
  ValueError: If both group and table are specified
299
331
  ValueError: If table is specified but database is not provided
300
332
  ValueError: If table is specified but not found in backup
301
333
  """
302
334
  if group and table:
303
- raise ValueError("Cannot specify both --group and --table. Use --table for single table restore or --group for inventory group restore.")
304
-
335
+ raise ValueError(
336
+ "Cannot specify both --group and --table. Use --table for single table restore or --group for inventory group restore."
337
+ )
338
+
305
339
  if table and not database:
306
340
  raise ValueError("database parameter is required when table is specified")
307
-
341
+
308
342
  query = f"""
309
343
  SELECT DISTINCT database_name, table_name
310
344
  FROM ops.backup_partitions
311
- WHERE label = '{label}'
345
+ WHERE label = {utils.quote_value(label)}
312
346
  ORDER BY database_name, table_name
313
347
  """
314
-
348
+
315
349
  rows = db.query(query)
316
350
  if not rows:
317
351
  return []
318
-
352
+
319
353
  tables = [f"{row[0]}.{row[1]}" for row in rows]
320
-
354
+
321
355
  if table:
322
356
  target_table = f"{database}.{table}"
323
357
  filtered_tables = [t for t in tables if t == target_table]
324
-
358
+
325
359
  if not filtered_tables:
326
- raise ValueError(f"Table '{table}' not found in backup '{label}' for database '{database}'")
327
-
360
+ raise ValueError(
361
+ f"Table '{table}' not found in backup '{label}' for database '{database}'"
362
+ )
363
+
328
364
  return filtered_tables
329
-
365
+
330
366
  if group:
331
367
  group_query = f"""
332
368
  SELECT database_name, table_name
333
369
  FROM ops.table_inventory
334
- WHERE inventory_group = '{group}'
370
+ WHERE inventory_group = {utils.quote_value(group)}
335
371
  """
336
-
372
+
337
373
  group_rows = db.query(group_query)
338
374
  if not group_rows:
339
375
  return []
340
-
376
+
341
377
  group_tables = set()
342
378
  for row in group_rows:
343
379
  database_name, table_name = row[0], row[1]
344
- if table_name == '*':
345
- show_tables_query = f"SHOW TABLES FROM {database_name}"
380
+ if table_name == "*":
381
+ show_tables_query = f"SHOW TABLES FROM {utils.quote_identifier(database_name)}"
346
382
  try:
347
383
  tables_rows = db.query(show_tables_query)
348
384
  for table_row in tables_rows:
@@ -351,15 +387,22 @@ def get_tables_from_backup(db, label: str, group: Optional[str] = None, table: O
351
387
  continue
352
388
  else:
353
389
  group_tables.add(f"{database_name}.{table_name}")
354
-
390
+
355
391
  tables = [table for table in tables if table in group_tables]
356
-
392
+
357
393
  return tables
358
394
 
359
395
 
360
- def execute_restore_flow(db, repo_name: str, restore_pair: List[str], tables_to_restore: List[str], rename_suffix: str = "_restored", skip_confirmation: bool = False) -> Dict:
396
+ def execute_restore_flow(
397
+ db,
398
+ repo_name: str,
399
+ restore_pair: list[str],
400
+ tables_to_restore: list[str],
401
+ rename_suffix: str = "_restored",
402
+ skip_confirmation: bool = False,
403
+ ) -> dict:
361
404
  """Execute the complete restore flow with safety measures.
362
-
405
+
363
406
  Args:
364
407
  db: Database connection
365
408
  repo_name: Repository name
@@ -367,22 +410,16 @@ def execute_restore_flow(db, repo_name: str, restore_pair: List[str], tables_to_
367
410
  tables_to_restore: List of tables to restore (format: database.table)
368
411
  rename_suffix: Suffix for temporary tables
369
412
  skip_confirmation: If True, skip interactive confirmation prompt
370
-
413
+
371
414
  Returns:
372
415
  Dictionary with success status and details
373
416
  """
374
417
  if not restore_pair:
375
- return {
376
- "success": False,
377
- "error_message": "No restore pair provided"
378
- }
379
-
418
+ return {"success": False, "error_message": "No restore pair provided"}
419
+
380
420
  if not tables_to_restore:
381
- return {
382
- "success": False,
383
- "error_message": "No tables to restore"
384
- }
385
-
421
+ return {"success": False, "error_message": "No tables to restore"}
422
+
386
423
  logger.info("")
387
424
  logger.info("=== RESTORE PLAN ===")
388
425
  logger.info(f"Repository: {repo_name}")
@@ -392,128 +429,143 @@ def execute_restore_flow(db, repo_name: str, restore_pair: List[str], tables_to_
392
429
  logger.info("")
393
430
  logger.info("This will restore data to temporary tables and then perform atomic rename.")
394
431
  logger.warning("WARNING: This operation will replace existing tables!")
395
-
432
+
396
433
  if not skip_confirmation:
397
434
  confirmation = input("\nDo you want to proceed? [Y/n]: ").strip()
398
- if confirmation.lower() != 'y':
399
- return {
400
- "success": False,
401
- "error_message": "Restore operation cancelled by user"
402
- }
435
+ if confirmation.lower() != "y":
436
+ return {"success": False, "error_message": "Restore operation cancelled by user"}
403
437
  else:
404
438
  logger.info("Proceeding automatically (--yes flag provided)")
405
-
439
+
406
440
  try:
407
- database_name = tables_to_restore[0].split('.')[0]
408
-
441
+ database_name = tables_to_restore[0].split(".")[0]
442
+
409
443
  base_label = restore_pair[0]
410
444
  logger.info("")
411
445
  logger.info(f"Step 1: Restoring base backup '{base_label}'...")
412
-
446
+
413
447
  base_timestamp = get_snapshot_timestamp(db, repo_name, base_label)
414
-
448
+
415
449
  base_restore_command = _build_restore_command_with_rename(
416
450
  base_label, repo_name, tables_to_restore, rename_suffix, database_name, base_timestamp
417
451
  )
418
-
452
+
419
453
  base_result = execute_restore(
420
454
  db, base_restore_command, base_label, "full", repo_name, database_name, scope="restore"
421
455
  )
422
-
456
+
423
457
  if not base_result["success"]:
424
458
  return {
425
459
  "success": False,
426
- "error_message": f"Base restore failed: {base_result['error_message']}"
460
+ "error_message": f"Base restore failed: {base_result['error_message']}",
427
461
  }
428
-
462
+
429
463
  logger.success("Base restore completed successfully")
430
-
464
+
431
465
  if len(restore_pair) > 1:
432
466
  incremental_label = restore_pair[1]
433
467
  logger.info("")
434
468
  logger.info(f"Step 2: Applying incremental backup '{incremental_label}'...")
435
-
469
+
436
470
  incremental_timestamp = get_snapshot_timestamp(db, repo_name, incremental_label)
437
-
471
+
438
472
  incremental_restore_command = _build_restore_command_without_rename(
439
- incremental_label, repo_name, tables_to_restore, database_name, incremental_timestamp
473
+ incremental_label,
474
+ repo_name,
475
+ tables_to_restore,
476
+ database_name,
477
+ incremental_timestamp,
440
478
  )
441
-
479
+
442
480
  incremental_result = execute_restore(
443
- db, incremental_restore_command, incremental_label, "incremental", repo_name, database_name, scope="restore"
481
+ db,
482
+ incremental_restore_command,
483
+ incremental_label,
484
+ "incremental",
485
+ repo_name,
486
+ database_name,
487
+ scope="restore",
444
488
  )
445
-
489
+
446
490
  if not incremental_result["success"]:
447
491
  return {
448
492
  "success": False,
449
- "error_message": f"Incremental restore failed: {incremental_result['error_message']}"
493
+ "error_message": f"Incremental restore failed: {incremental_result['error_message']}",
450
494
  }
451
-
495
+
452
496
  logger.success("Incremental restore completed successfully")
453
-
497
+
454
498
  logger.info("")
455
499
  logger.info("Step 3: Performing atomic rename...")
456
500
  rename_result = _perform_atomic_rename(db, tables_to_restore, rename_suffix)
457
-
501
+
458
502
  if not rename_result["success"]:
459
503
  return {
460
504
  "success": False,
461
- "error_message": f"Atomic rename failed: {rename_result['error_message']}"
505
+ "error_message": f"Atomic rename failed: {rename_result['error_message']}",
462
506
  }
463
-
507
+
464
508
  logger.success("Atomic rename completed successfully")
465
-
509
+
466
510
  return {
467
511
  "success": True,
468
- "message": f"Restore completed successfully. Restored {len(tables_to_restore)} tables."
512
+ "message": f"Restore completed successfully. Restored {len(tables_to_restore)} tables.",
469
513
  }
470
-
514
+
471
515
  except Exception as e:
472
- return {
473
- "success": False,
474
- "error_message": f"Restore flow failed: {str(e)}"
475
- }
516
+ return {"success": False, "error_message": f"Restore flow failed: {str(e)}"}
476
517
 
477
518
 
478
def _build_restore_command_with_rename(
    backup_label: str,
    repo_name: str,
    tables: list[str],
    rename_suffix: str,
    database: str,
    backup_timestamp: str,
) -> str:
    """Build a RESTORE statement that lands every table under a temporary name.

    Each ``database.table`` entry in ``tables`` becomes
    ``TABLE <table> AS <table><rename_suffix>``. Only the table part of each
    entry is used; the DATABASE clause always comes from ``database``.

    Args:
        backup_label: Snapshot label to restore
        repo_name: Repository holding the snapshot
        tables: Tables to restore, each formatted as ``database.table``
        rename_suffix: Suffix appended to form the temporary table names
        database: Database named in the DATABASE clause
        backup_timestamp: Value for the ``backup_timestamp`` property

    Returns:
        The RESTORE SNAPSHOT statement text
    """
    clauses = []
    for qualified_name in tables:
        _db_part, source_name = qualified_name.split(".", 1)
        target_name = f"{source_name}{rename_suffix}"
        source_ident = utils.quote_identifier(source_name)
        target_ident = utils.quote_identifier(target_name)
        clauses.append(f"TABLE {source_ident} AS {target_ident}")

    joined_clauses = ",\n ".join(clauses)

    return f"""RESTORE SNAPSHOT {utils.quote_identifier(backup_label)}
    FROM {utils.quote_identifier(repo_name)}
    DATABASE {utils.quote_identifier(database)}
    ON ({joined_clauses})
    PROPERTIES ("backup_timestamp" = "{backup_timestamp}")"""
493
543
 
494
544
 
495
def _build_restore_command_without_rename(
    backup_label: str, repo_name: str, tables: list[str], database: str, backup_timestamp: str
) -> str:
    """Build a RESTORE statement that keeps the snapshot's own table names.

    Used for the incremental step, which is applied on top of tables the base
    step already restored. Each ``database.table`` entry becomes
    ``TABLE <table>``; only the table part is used, and the DATABASE clause
    always comes from ``database``.

    Args:
        backup_label: Snapshot label to restore
        repo_name: Repository holding the snapshot
        tables: Tables to restore, each formatted as ``database.table``
        database: Database named in the DATABASE clause
        backup_timestamp: Value for the ``backup_timestamp`` property

    Returns:
        The RESTORE SNAPSHOT statement text
    """
    clauses = []
    for qualified_name in tables:
        _db_part, bare_name = qualified_name.split(".", 1)
        clauses.append("TABLE " + utils.quote_identifier(bare_name))

    joined_clauses = ",\n ".join(clauses)

    return f"""RESTORE SNAPSHOT {utils.quote_identifier(backup_label)}
    FROM {utils.quote_identifier(repo_name)}
    DATABASE {utils.quote_identifier(database)}
    ON ({joined_clauses})
    PROPERTIES ("backup_timestamp" = "{backup_timestamp}")"""
509
561
 
510
562
 
511
563
  def _generate_timestamped_backup_name(table_name: str) -> str:
512
564
  """Generate a timestamped backup table name.
513
-
565
+
514
566
  Args:
515
567
  table_name: Original table name
516
-
568
+
517
569
  Returns:
518
570
  Timestamped backup name in format: {table_name}_backup_YYYYMMDD_HHMMSS
519
571
  """
@@ -521,25 +573,26 @@ def _generate_timestamped_backup_name(table_name: str) -> str:
521
573
  return f"{table_name}_backup_{timestamp}"
522
574
 
523
575
 
524
def _rename_statements(database: str, old_name: str, new_name: str) -> tuple[str, str]:
    """Return the (rename, undo-rename) ALTER TABLE statements for one table move."""
    forward = (
        f"ALTER TABLE {utils.build_qualified_table_name(database, old_name)} "
        f"RENAME {utils.quote_identifier(new_name)}"
    )
    undo = (
        f"ALTER TABLE {utils.build_qualified_table_name(database, new_name)} "
        f"RENAME {utils.quote_identifier(old_name)}"
    )
    return forward, undo


def _undo_renames(db, undo_statements: list[str]) -> None:
    """Best-effort revert of already-applied renames, newest first; failures are logged."""
    for statement in reversed(undo_statements):
        try:
            db.execute(statement)
        except Exception as undo_error:
            logger.error(f"Failed to roll back rename with '{statement}': {str(undo_error)}")


def _perform_atomic_rename(db, tables: list[str], rename_suffix: str) -> dict:
    """Swap restored temporary tables into place, keeping the old tables as backups.

    For every ``database.table`` entry, two renames are issued in order:
    the live table is renamed to a timestamped backup name, then the
    temporary ``<table><rename_suffix>`` table is renamed to the live name.

    The statements run one after another, so the group is not atomic at the
    database level. If any statement fails, the renames already applied are
    reverted in reverse order (best-effort, failures logged), so a live table
    is not left renamed away with nothing in its place.

    Args:
        db: Database connection exposing ``execute``
        tables: Tables to swap, each formatted as ``database.table``
        rename_suffix: Suffix that was appended to the temporary table names

    Returns:
        ``{"success": True}`` on success, otherwise
        ``{"success": False, "error_message": ...}``
    """
    # Build every statement before touching the database, so malformed input
    # fails without any rename having been applied.
    try:
        planned = []
        for table in tables:
            database, table_name = table.split(".", 1)
            temp_table_name = f"{table_name}{rename_suffix}"
            backup_table_name = _generate_timestamped_backup_name(table_name)

            # Move the live table aside first, then promote the restored copy.
            planned.append(_rename_statements(database, table_name, backup_table_name))
            planned.append(_rename_statements(database, temp_table_name, table_name))
    except Exception as e:
        return {"success": False, "error_message": f"Failed to perform atomic rename: {str(e)}"}

    applied_undos = []
    try:
        for statement, undo in planned:
            db.execute(statement)
            applied_undos.append(undo)
    except Exception as e:
        if applied_undos:
            logger.warning(f"Rename failed; rolling back {len(applied_undos)} applied rename(s)")
            _undo_renames(db, applied_undos)
        return {"success": False, "error_message": f"Failed to perform atomic rename: {str(e)}"}

    return {"success": True}