starrocks-br 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
- from typing import Literal, List, Tuple
2
- from . import logger
1
+ from typing import Literal
2
+
3
+ from . import logger, utils
3
4
 
4
5
 
5
6
  def reserve_job_slot(db, scope: str, label: str) -> None:
@@ -9,23 +10,23 @@ def reserve_job_slot(db, scope: str, label: str) -> None:
9
10
  However, we implement self-healing logic to automatically clean up stale locks.
10
11
  """
11
12
  active_jobs = _get_active_jobs_for_scope(db, scope)
12
-
13
+
13
14
  if not active_jobs:
14
15
  _insert_new_job(db, scope, label)
15
16
  return
16
-
17
+
17
18
  _handle_active_job_conflicts(db, scope, active_jobs)
18
-
19
+
19
20
  _insert_new_job(db, scope, label)
20
21
 
21
22
 
22
- def _get_active_jobs_for_scope(db, scope: str) -> List[Tuple[str, str, str]]:
23
+ def _get_active_jobs_for_scope(db, scope: str) -> list[tuple[str, str, str]]:
23
24
  """Get all active jobs for the given scope."""
24
25
  rows = db.query("SELECT scope, label, state FROM ops.run_status WHERE state = 'ACTIVE'")
25
26
  return [row for row in rows if row[0] == scope]
26
27
 
27
28
 
28
- def _handle_active_job_conflicts(db, scope: str, active_jobs: List[Tuple[str, str, str]]) -> None:
29
+ def _handle_active_job_conflicts(db, scope: str, active_jobs: list[tuple[str, str, str]]) -> None:
29
30
  """Handle conflicts with active jobs, cleaning up stale ones where possible."""
30
31
  for active_scope, active_label, _ in active_jobs:
31
32
  if _can_heal_stale_job(active_scope, active_label, db):
@@ -37,17 +38,17 @@ def _handle_active_job_conflicts(db, scope: str, active_jobs: List[Tuple[str, st
37
38
 
38
39
  def _can_heal_stale_job(scope: str, label: str, db) -> bool:
39
40
  """Check if a stale job can be healed (only for backup jobs)."""
40
- if scope != 'backup':
41
+ if scope != "backup":
41
42
  return False
42
-
43
+
43
44
  return _is_backup_job_stale(db, label)
44
45
 
45
46
 
46
- def _raise_concurrency_conflict(scope: str, active_jobs: List[Tuple[str, str, str]]) -> None:
47
+ def _raise_concurrency_conflict(scope: str, active_jobs: list[tuple[str, str, str]]) -> None:
47
48
  """Raise a concurrency conflict error with helpful message."""
48
49
  active_job_strings = [f"{job[0]}:{job[1]}" for job in active_jobs]
49
50
  active_labels = [job[1] for job in active_jobs]
50
-
51
+
51
52
  raise RuntimeError(
52
53
  f"Concurrency conflict: Another '{scope}' job is already ACTIVE: {', '.join(active_job_strings)}. "
53
54
  f"Wait for it to complete or cancel it via: UPDATE ops.run_status SET state='CANCELLED' "
@@ -57,48 +58,48 @@ def _raise_concurrency_conflict(scope: str, active_jobs: List[Tuple[str, str, st
57
58
 
58
59
  def _insert_new_job(db, scope: str, label: str) -> None:
59
60
  """Insert a new active job record."""
60
- sql = (
61
- "INSERT INTO ops.run_status (scope, label, state, started_at) "
62
- "VALUES ('%s','%s','ACTIVE', NOW())" % (scope, label)
63
- )
61
+ sql = f"""
62
+ INSERT INTO ops.run_status (scope, label, state, started_at)
63
+ VALUES ({utils.quote_value(scope)}, {utils.quote_value(label)}, 'ACTIVE', NOW())
64
+ """
64
65
  db.execute(sql)
65
66
 
66
67
 
67
68
  def _is_backup_job_stale(db, label: str) -> bool:
68
69
  """Check if a backup job is stale by querying StarRocks SHOW BACKUP.
69
-
70
+
70
71
  Returns True if the job is stale (not actually running), False if it's still active.
71
72
  """
72
73
  try:
73
74
  user_databases = _get_user_databases(db)
74
-
75
+
75
76
  for database_name in user_databases:
76
77
  job_status = _check_backup_job_in_database(db, database_name, label)
77
-
78
+
78
79
  if job_status is None:
79
80
  continue
80
-
81
+
81
82
  if job_status == "active":
82
83
  return False
83
84
  elif job_status == "stale":
84
85
  return True
85
-
86
+
86
87
  return True
87
-
88
+
88
89
  except Exception as e:
89
90
  logger.error(f"Error checking backup job status: {e}")
90
91
  return False
91
92
 
92
93
 
93
- def _get_user_databases(db) -> List[str]:
94
+ def _get_user_databases(db) -> list[str]:
94
95
  """Get list of user databases (excluding system databases)."""
95
- SYSTEM_DATABASES = {'information_schema', 'mysql', 'sys', 'ops'}
96
-
96
+ system_databases = {"information_schema", "mysql", "sys", "ops"}
97
+
97
98
  databases = db.query("SHOW DATABASES")
98
99
  return [
99
- _extract_database_name(db_row)
100
- for db_row in databases
101
- if _extract_database_name(db_row) not in SYSTEM_DATABASES
100
+ _extract_database_name(db_row)
101
+ for db_row in databases
102
+ if _extract_database_name(db_row) not in system_databases
102
103
  ]
103
104
 
104
105
 
@@ -106,40 +107,40 @@ def _extract_database_name(db_row) -> str:
106
107
  """Extract database name from database query result."""
107
108
  if isinstance(db_row, (list, tuple)):
108
109
  return db_row[0]
109
- return db_row.get('Database', '')
110
+ return db_row.get("Database", "")
110
111
 
111
112
 
112
113
  def _check_backup_job_in_database(db, database_name: str, label: str) -> str:
113
114
  """Check if backup job exists in specific database and return its status.
114
-
115
+
115
116
  Returns:
116
117
  'active' if job is still running
117
- 'stale' if job is in terminal state
118
+ 'stale' if job is in terminal state
118
119
  None if job not found in this database
119
120
  """
120
121
  try:
121
- show_backup_query = f"SHOW BACKUP FROM {database_name}"
122
+ show_backup_query = f"SHOW BACKUP FROM {utils.quote_identifier(database_name)}"
122
123
  backup_rows = db.query(show_backup_query)
123
-
124
+
124
125
  if not backup_rows:
125
126
  return None
126
-
127
+
127
128
  result = backup_rows[0]
128
129
  snapshot_name, state = _extract_backup_info(result)
129
-
130
+
130
131
  if snapshot_name != label:
131
132
  return None
132
-
133
+
133
134
  if state in ["FINISHED", "CANCELLED", "FAILED"]:
134
135
  return "stale"
135
136
  else:
136
137
  return "active"
137
-
138
+
138
139
  except Exception:
139
140
  return None
140
141
 
141
142
 
142
- def _extract_backup_info(result) -> Tuple[str, str]:
143
+ def _extract_backup_info(result) -> tuple[str, str]:
143
144
  """Extract snapshot name and state from SHOW BACKUP result."""
144
145
  if isinstance(result, dict):
145
146
  snapshot_name = result.get("SnapshotName", "")
@@ -147,31 +148,30 @@ def _extract_backup_info(result) -> Tuple[str, str]:
147
148
  else:
148
149
  snapshot_name = result[1] if len(result) > 1 else ""
149
150
  state = result[3] if len(result) > 3 else "UNKNOWN"
150
-
151
+
151
152
  return snapshot_name, state
152
153
 
153
154
 
154
155
  def _cleanup_stale_job(db, scope: str, label: str) -> None:
155
156
  """Clean up a stale job by updating its state to CANCELLED."""
156
- sql = (
157
- "UPDATE ops.run_status SET state='CANCELLED', finished_at=NOW() "
158
- "WHERE scope='%s' AND label='%s' AND state='ACTIVE'" % (scope, label)
159
- )
157
+ sql = f"""
158
+ UPDATE ops.run_status
159
+ SET state='CANCELLED', finished_at=NOW()
160
+ WHERE scope={utils.quote_value(scope)} AND label={utils.quote_value(label)} AND state='ACTIVE'
161
+ """
160
162
  db.execute(sql)
161
163
 
162
164
 
163
165
  def complete_job_slot(
164
- db,
165
- scope: str,
166
- label: str,
167
- final_state: Literal['FINISHED', 'FAILED', 'CANCELLED']
166
+ db, scope: str, label: str, final_state: Literal["FINISHED", "FAILED", "CANCELLED"]
168
167
  ) -> None:
169
168
  """Complete job slot and persist final state.
170
169
 
171
170
  Simple approach: update the same row by scope/label.
172
171
  """
173
- sql = (
174
- "UPDATE ops.run_status SET state='%s', finished_at=NOW() WHERE scope='%s' AND label='%s'"
175
- % (final_state, scope, label)
176
- )
172
+ sql = f"""
173
+ UPDATE ops.run_status
174
+ SET state={utils.quote_value(final_state)}, finished_at=NOW()
175
+ WHERE scope={utils.quote_value(scope)} AND label={utils.quote_value(label)}
176
+ """
177
177
  db.execute(sql)
starrocks_br/config.py CHANGED
@@ -1,45 +1,46 @@
1
+ from typing import Any
2
+
1
3
  import yaml
2
- from typing import Any, Dict, Optional
3
4
 
4
5
 
5
- def load_config(config_path: str) -> Dict[str, Any]:
6
+ def load_config(config_path: str) -> dict[str, Any]:
6
7
  """Load and parse YAML configuration file.
7
-
8
+
8
9
  Args:
9
10
  config_path: Path to the YAML config file
10
-
11
+
11
12
  Returns:
12
13
  Dictionary containing configuration
13
-
14
+
14
15
  Raises:
15
16
  FileNotFoundError: If config file doesn't exist
16
17
  yaml.YAMLError: If config file is not valid YAML
17
18
  """
18
- with open(config_path, 'r') as f:
19
+ with open(config_path) as f:
19
20
  config = yaml.safe_load(f)
20
-
21
+
21
22
  if not isinstance(config, dict):
22
23
  raise ValueError("Config must be a dictionary")
23
-
24
+
24
25
  return config
25
26
 
26
27
 
27
- def validate_config(config: Dict[str, Any]) -> None:
28
+ def validate_config(config: dict[str, Any]) -> None:
28
29
  """Validate that config contains required fields.
29
-
30
+
30
31
  Args:
31
32
  config: Configuration dictionary
32
-
33
+
33
34
  Raises:
34
35
  ValueError: If required fields are missing
35
36
  """
36
- required_fields = ['host', 'port', 'user', 'database', 'repository']
37
-
37
+ required_fields = ["host", "port", "user", "database", "repository"]
38
+
38
39
  for field in required_fields:
39
40
  if field not in config:
40
41
  raise ValueError(f"Missing required config field: {field}")
41
42
 
42
- _validate_tls_section(config.get('tls'))
43
+ _validate_tls_section(config.get("tls"))
43
44
 
44
45
 
45
46
  def _validate_tls_section(tls_config) -> None:
@@ -49,16 +50,23 @@ def _validate_tls_section(tls_config) -> None:
49
50
  if not isinstance(tls_config, dict):
50
51
  raise ValueError("TLS configuration must be a dictionary")
51
52
 
52
- enabled = bool(tls_config.get('enabled', False))
53
+ enabled = bool(tls_config.get("enabled", False))
53
54
 
54
- if enabled and not tls_config.get('ca_cert'):
55
+ if enabled and not tls_config.get("ca_cert"):
55
56
  raise ValueError("TLS configuration requires 'ca_cert' when 'enabled' is true")
56
57
 
57
- if 'verify_server_cert' in tls_config and not isinstance(tls_config['verify_server_cert'], bool):
58
- raise ValueError("TLS configuration field 'verify_server_cert' must be a boolean if provided")
59
-
60
- if 'tls_versions' in tls_config:
61
- tls_versions = tls_config['tls_versions']
62
- if not isinstance(tls_versions, list) or not all(isinstance(version, str) for version in tls_versions):
63
- raise ValueError("TLS configuration field 'tls_versions' must be a list of strings if provided")
58
+ if "verify_server_cert" in tls_config and not isinstance(
59
+ tls_config["verify_server_cert"], bool
60
+ ):
61
+ raise ValueError(
62
+ "TLS configuration field 'verify_server_cert' must be a boolean if provided"
63
+ )
64
64
 
65
+ if "tls_versions" in tls_config:
66
+ tls_versions = tls_config["tls_versions"]
67
+ if not isinstance(tls_versions, list) or not all(
68
+ isinstance(version, str) for version in tls_versions
69
+ ):
70
+ raise ValueError(
71
+ "TLS configuration field 'tls_versions' must be a list of strings if provided"
72
+ )
starrocks_br/db.py CHANGED
@@ -1,10 +1,11 @@
1
+ from typing import Any, Optional
2
+
1
3
  import mysql.connector
2
- from typing import Any, Dict, List, Optional
3
4
 
4
5
 
5
6
  class StarRocksDB:
6
7
  """Database connection wrapper for StarRocks."""
7
-
8
+
8
9
  def __init__(
9
10
  self,
10
11
  host: str,
@@ -12,10 +13,10 @@ class StarRocksDB:
12
13
  user: str,
13
14
  password: str,
14
15
  database: str,
15
- tls_config: Optional[Dict[str, Any]] = None,
16
+ tls_config: Optional[dict[str, Any]] = None,
16
17
  ):
17
18
  """Initialize database connection.
18
-
19
+
19
20
  Args:
20
21
  host: Database host
21
22
  port: Database port
@@ -31,35 +32,35 @@ class StarRocksDB:
31
32
  self._connection = None
32
33
  self.tls_config = tls_config or {}
33
34
  self._timezone: Optional[str] = None
34
-
35
+
35
36
  def connect(self) -> None:
36
37
  """Establish database connection."""
37
- conn_args: Dict[str, Any] = {
38
- 'host': self.host,
39
- 'port': self.port,
40
- 'user': self.user,
41
- 'password': self.password,
42
- 'database': self.database,
38
+ conn_args: dict[str, Any] = {
39
+ "host": self.host,
40
+ "port": self.port,
41
+ "user": self.user,
42
+ "password": self.password,
43
+ "database": self.database,
43
44
  }
44
45
 
45
- if self.tls_config.get('enabled'):
46
- ssl_args: Dict[str, Any] = {
47
- 'ssl_ca': self.tls_config.get('ca_cert'),
48
- 'ssl_cert': self.tls_config.get('client_cert'),
49
- 'ssl_key': self.tls_config.get('client_key'),
50
- 'ssl_verify_cert': self.tls_config.get('verify_server_cert', True),
46
+ if self.tls_config.get("enabled"):
47
+ ssl_args: dict[str, Any] = {
48
+ "ssl_ca": self.tls_config.get("ca_cert"),
49
+ "ssl_cert": self.tls_config.get("client_cert"),
50
+ "ssl_key": self.tls_config.get("client_key"),
51
+ "ssl_verify_cert": self.tls_config.get("verify_server_cert", True),
51
52
  }
52
53
 
53
- tls_versions = self.tls_config.get('tls_versions', ['TLSv1.2', 'TLSv1.3'])
54
+ tls_versions = self.tls_config.get("tls_versions", ["TLSv1.2", "TLSv1.3"])
54
55
  if tls_versions:
55
- ssl_args['tls_versions'] = tls_versions
56
+ ssl_args["tls_versions"] = tls_versions
56
57
 
57
58
  conn_args.update({key: value for key, value in ssl_args.items() if value is not None})
58
59
 
59
60
  try:
60
61
  self._connection = mysql.connector.connect(**conn_args)
61
62
  except mysql.connector.Error as e:
62
- if self.tls_config.get('enabled') and "SSL is required" in str(e):
63
+ if self.tls_config.get("enabled") and "SSL is required" in str(e):
63
64
  raise mysql.connector.Error(
64
65
  f"TLS is enabled in configuration but StarRocks server doesn't support it. "
65
66
  f"Error: {e}. "
@@ -67,42 +68,42 @@ class StarRocksDB:
67
68
  f"Alternatively, set 'enabled: false' in the tls section of your config file."
68
69
  ) from e
69
70
  raise
70
-
71
+
71
72
  def close(self) -> None:
72
73
  """Close database connection."""
73
74
  if self._connection:
74
75
  self._connection.close()
75
76
  self._connection = None
76
-
77
+
77
78
  def execute(self, sql: str) -> None:
78
79
  """Execute a SQL statement that doesn't return results.
79
-
80
+
80
81
  Args:
81
82
  sql: SQL statement to execute
82
83
  """
83
84
  if not self._connection:
84
85
  self.connect()
85
-
86
+
86
87
  cursor = self._connection.cursor()
87
88
  try:
88
89
  cursor.execute(sql)
89
90
  self._connection.commit()
90
91
  finally:
91
92
  cursor.close()
92
-
93
- def query(self, sql: str, params: tuple = None) -> List[tuple]:
93
+
94
+ def query(self, sql: str, params: tuple = None) -> list[tuple]:
94
95
  """Execute a SQL query and return results.
95
-
96
+
96
97
  Args:
97
98
  sql: SQL query to execute
98
99
  params: Optional tuple of parameters for parameterized queries
99
-
100
+
100
101
  Returns:
101
102
  List of tuples containing query results
102
103
  """
103
104
  if not self._connection:
104
105
  self.connect()
105
-
106
+
106
107
  cursor = self._connection.cursor()
107
108
  try:
108
109
  if params:
@@ -112,24 +113,24 @@ class StarRocksDB:
112
113
  return cursor.fetchall()
113
114
  finally:
114
115
  cursor.close()
115
-
116
+
116
117
  def __enter__(self):
117
118
  """Context manager entry."""
118
119
  self.connect()
119
120
  return self
120
-
121
+
121
122
  def __exit__(self, exc_type, exc_val, exc_tb):
122
123
  """Context manager exit."""
123
124
  self.close()
124
-
125
+
125
126
  @property
126
127
  def timezone(self) -> str:
127
128
  """Get the StarRocks cluster timezone.
128
-
129
+
129
130
  Queries the cluster timezone on first access and caches it for subsequent use.
130
131
  If the query fails (e.g., database unavailable, connection error, permissions),
131
132
  defaults to 'UTC' to ensure the property always returns a valid timezone string.
132
-
133
+
133
134
  Returns:
134
135
  Timezone string (e.g., 'Asia/Shanghai', 'UTC', '+08:00')
135
136
  Defaults to 'UTC' if query fails or returns no results.
@@ -138,7 +139,7 @@ class StarRocksDB:
138
139
  try:
139
140
  query = "SHOW VARIABLES LIKE 'time_zone'"
140
141
  rows = self.query(query)
141
-
142
+
142
143
  if not rows:
143
144
  self._timezone = "UTC"
144
145
  else:
@@ -149,6 +150,5 @@ class StarRocksDB:
149
150
  self._timezone = row[1] if len(row) > 1 else "UTC"
150
151
  except Exception:
151
152
  self._timezone = "UTC"
152
-
153
- return self._timezone
154
153
 
154
+ return self._timezone