skypilot-nightly 1.0.0.dev20250807__py3-none-any.whl → 1.0.0.dev20250812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of skypilot-nightly might be problematic.

Files changed (91):
  1. sky/__init__.py +2 -2
  2. sky/adaptors/kubernetes.py +5 -2
  3. sky/backends/backend_utils.py +57 -7
  4. sky/backends/cloud_vm_ray_backend.py +50 -8
  5. sky/client/cli/command.py +60 -26
  6. sky/client/sdk.py +132 -65
  7. sky/client/sdk_async.py +1 -1
  8. sky/core.py +10 -2
  9. sky/dashboard/out/404.html +1 -1
  10. sky/dashboard/out/_next/static/{YAirOGsV1z6B2RJ0VIUmD → Fuy7OzApYTUMz2QgoP7dP}/_buildManifest.js +1 -1
  11. sky/dashboard/out/_next/static/chunks/{6601-3e21152fe16da09c.js → 6601-06114c982db410b6.js} +1 -1
  12. sky/dashboard/out/_next/static/chunks/8056-5bdeda81199c0def.js +1 -0
  13. sky/dashboard/out/_next/static/chunks/{8969-318c3dca725e8e5d.js → 8969-c9686994ddafcf01.js} +1 -1
  14. sky/dashboard/out/_next/static/chunks/pages/{_app-1e6de35d15a8d432.js → _app-491a4d699d95e808.js} +1 -1
  15. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-078751bad714c017.js +11 -0
  16. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-da9cc0901349c2e9.js +1 -0
  17. sky/dashboard/out/_next/static/chunks/webpack-7fd0cf9dbecff10f.js +1 -0
  18. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  19. sky/dashboard/out/clusters/[cluster].html +1 -1
  20. sky/dashboard/out/clusters.html +1 -1
  21. sky/dashboard/out/config.html +1 -1
  22. sky/dashboard/out/index.html +1 -1
  23. sky/dashboard/out/infra/[context].html +1 -1
  24. sky/dashboard/out/infra.html +1 -1
  25. sky/dashboard/out/jobs/[job].html +1 -1
  26. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  27. sky/dashboard/out/jobs.html +1 -1
  28. sky/dashboard/out/users.html +1 -1
  29. sky/dashboard/out/volumes.html +1 -1
  30. sky/dashboard/out/workspace/new.html +1 -1
  31. sky/dashboard/out/workspaces/[name].html +1 -1
  32. sky/dashboard/out/workspaces.html +1 -1
  33. sky/execution.py +21 -4
  34. sky/global_user_state.py +110 -1
  35. sky/jobs/client/sdk.py +27 -20
  36. sky/jobs/controller.py +2 -1
  37. sky/jobs/recovery_strategy.py +3 -0
  38. sky/jobs/server/core.py +4 -0
  39. sky/jobs/utils.py +9 -2
  40. sky/provision/__init__.py +3 -2
  41. sky/provision/aws/instance.py +5 -4
  42. sky/provision/azure/instance.py +5 -4
  43. sky/provision/cudo/instance.py +5 -4
  44. sky/provision/do/instance.py +5 -4
  45. sky/provision/fluidstack/instance.py +5 -4
  46. sky/provision/gcp/instance.py +5 -4
  47. sky/provision/hyperbolic/instance.py +5 -4
  48. sky/provision/kubernetes/instance.py +36 -6
  49. sky/provision/lambda_cloud/instance.py +5 -4
  50. sky/provision/nebius/instance.py +5 -4
  51. sky/provision/oci/instance.py +5 -4
  52. sky/provision/paperspace/instance.py +5 -4
  53. sky/provision/provisioner.py +6 -0
  54. sky/provision/runpod/instance.py +5 -4
  55. sky/provision/scp/instance.py +5 -5
  56. sky/provision/vast/instance.py +5 -5
  57. sky/provision/vsphere/instance.py +5 -4
  58. sky/schemas/db/global_user_state/001_initial_schema.py +1 -1
  59. sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
  60. sky/schemas/db/global_user_state/004_is_managed.py +34 -0
  61. sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
  62. sky/schemas/db/serve_state/001_initial_schema.py +67 -0
  63. sky/schemas/db/spot_jobs/001_initial_schema.py +1 -1
  64. sky/serve/client/impl.py +11 -8
  65. sky/serve/client/sdk.py +7 -7
  66. sky/serve/serve_state.py +437 -340
  67. sky/serve/serve_utils.py +37 -3
  68. sky/serve/server/impl.py +2 -2
  69. sky/server/common.py +12 -8
  70. sky/server/constants.py +1 -1
  71. sky/setup_files/alembic.ini +4 -0
  72. sky/skypilot_config.py +4 -4
  73. sky/users/permission.py +1 -1
  74. sky/utils/cli_utils/status_utils.py +10 -1
  75. sky/utils/db/db_utils.py +53 -1
  76. sky/utils/db/migration_utils.py +5 -1
  77. sky/utils/kubernetes/deploy_remote_cluster.py +3 -1
  78. sky/utils/resource_checker.py +162 -21
  79. sky/volumes/client/sdk.py +4 -4
  80. sky/workspaces/core.py +210 -6
  81. {skypilot_nightly-1.0.0.dev20250807.dist-info → skypilot_nightly-1.0.0.dev20250812.dist-info}/METADATA +2 -2
  82. {skypilot_nightly-1.0.0.dev20250807.dist-info → skypilot_nightly-1.0.0.dev20250812.dist-info}/RECORD +87 -83
  83. sky/dashboard/out/_next/static/chunks/8056-019615038d6ce427.js +0 -1
  84. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-6fd1d2d8441aa54b.js +0 -11
  85. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-155d477a6c3e04e2.js +0 -1
  86. sky/dashboard/out/_next/static/chunks/webpack-76efbdad99742559.js +0 -1
  87. /sky/dashboard/out/_next/static/{YAirOGsV1z6B2RJ0VIUmD → Fuy7OzApYTUMz2QgoP7dP}/_ssgManifest.js +0 -0
  88. {skypilot_nightly-1.0.0.dev20250807.dist-info → skypilot_nightly-1.0.0.dev20250812.dist-info}/WHEEL +0 -0
  89. {skypilot_nightly-1.0.0.dev20250807.dist-info → skypilot_nightly-1.0.0.dev20250812.dist-info}/entry_points.txt +0 -0
  90. {skypilot_nightly-1.0.0.dev20250807.dist-info → skypilot_nightly-1.0.0.dev20250812.dist-info}/licenses/LICENSE +0 -0
  91. {skypilot_nightly-1.0.0.dev20250807.dist-info → skypilot_nightly-1.0.0.dev20250812.dist-info}/top_level.txt +0 -0
sky/serve/serve_state.py CHANGED
@@ -3,108 +3,141 @@ import collections
3
3
  import enum
4
4
  import functools
5
5
  import json
6
- import pathlib
7
6
  import pickle
8
- import sqlite3
9
7
  import threading
10
8
  import typing
11
- from typing import Any, Dict, List, Optional, Tuple
9
+ from typing import Any, Dict, List, Optional
12
10
  import uuid
13
11
 
14
12
  import colorama
13
+ import sqlalchemy
14
+ from sqlalchemy import exc as sqlalchemy_exc
15
+ from sqlalchemy import orm
16
+ from sqlalchemy.dialects import postgresql
17
+ from sqlalchemy.dialects import sqlite
18
+ from sqlalchemy.ext import declarative
15
19
 
16
20
  from sky.serve import constants
21
+ from sky.utils import common_utils
17
22
  from sky.utils.db import db_utils
23
+ from sky.utils.db import migration_utils
18
24
 
19
25
  if typing.TYPE_CHECKING:
26
+ from sqlalchemy.engine import row
27
+
20
28
  from sky.serve import replica_managers
21
29
  from sky.serve import service_spec
22
30
 
23
-
24
- def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None:
31
+ _SQLALCHEMY_ENGINE: Optional[sqlalchemy.engine.Engine] = None
32
+ _SQLALCHEMY_ENGINE_LOCK = threading.Lock()
33
+
34
+ Base = declarative.declarative_base()
35
+
36
+ # === Database schema ===
37
+ services_table = sqlalchemy.Table(
38
+ 'services',
39
+ Base.metadata,
40
+ sqlalchemy.Column('name', sqlalchemy.Text, primary_key=True),
41
+ sqlalchemy.Column('controller_job_id',
42
+ sqlalchemy.Integer,
43
+ server_default=None),
44
+ sqlalchemy.Column('controller_port',
45
+ sqlalchemy.Integer,
46
+ server_default=None),
47
+ sqlalchemy.Column('load_balancer_port',
48
+ sqlalchemy.Integer,
49
+ server_default=None),
50
+ sqlalchemy.Column('status', sqlalchemy.Text),
51
+ sqlalchemy.Column('uptime', sqlalchemy.Integer, server_default=None),
52
+ sqlalchemy.Column('policy', sqlalchemy.Text, server_default=None),
53
+ sqlalchemy.Column('auto_restart', sqlalchemy.Integer, server_default=None),
54
+ sqlalchemy.Column('requested_resources',
55
+ sqlalchemy.LargeBinary,
56
+ server_default=None),
57
+ sqlalchemy.Column('requested_resources_str', sqlalchemy.Text),
58
+ sqlalchemy.Column('current_version',
59
+ sqlalchemy.Integer,
60
+ server_default=str(constants.INITIAL_VERSION)),
61
+ sqlalchemy.Column('active_versions',
62
+ sqlalchemy.Text,
63
+ server_default=json.dumps([])),
64
+ sqlalchemy.Column('load_balancing_policy',
65
+ sqlalchemy.Text,
66
+ server_default=None),
67
+ sqlalchemy.Column('tls_encrypted', sqlalchemy.Integer, server_default='0'),
68
+ sqlalchemy.Column('pool', sqlalchemy.Integer, server_default='0'),
69
+ sqlalchemy.Column('controller_pid', sqlalchemy.Integer,
70
+ server_default=None),
71
+ sqlalchemy.Column('hash', sqlalchemy.Text, server_default=None),
72
+ sqlalchemy.Column('entrypoint', sqlalchemy.Text, server_default=None),
73
+ )
74
+
75
+ replicas_table = sqlalchemy.Table(
76
+ 'replicas',
77
+ Base.metadata,
78
+ sqlalchemy.Column('service_name', sqlalchemy.Text, primary_key=True),
79
+ sqlalchemy.Column('replica_id', sqlalchemy.Integer, primary_key=True),
80
+ sqlalchemy.Column('replica_info', sqlalchemy.LargeBinary),
81
+ )
82
+
83
+ version_specs_table = sqlalchemy.Table(
84
+ 'version_specs',
85
+ Base.metadata,
86
+ sqlalchemy.Column('service_name', sqlalchemy.Text, primary_key=True),
87
+ sqlalchemy.Column('version', sqlalchemy.Integer, primary_key=True),
88
+ sqlalchemy.Column('spec', sqlalchemy.LargeBinary),
89
+ )
90
+
91
+ serve_ha_recovery_script_table = sqlalchemy.Table(
92
+ 'serve_ha_recovery_script',
93
+ Base.metadata,
94
+ sqlalchemy.Column('service_name', sqlalchemy.Text, primary_key=True),
95
+ sqlalchemy.Column('script', sqlalchemy.Text),
96
+ )
97
+
98
+
99
+ def create_table(engine: sqlalchemy.engine.Engine):
25
100
  """Creates the service and replica tables if they do not exist."""
26
101
 
27
- # auto_restart and requested_resources column is deprecated.
28
- cursor.execute("""\
29
- CREATE TABLE IF NOT EXISTS services (
30
- name TEXT PRIMARY KEY,
31
- controller_job_id INTEGER DEFAULT NULL,
32
- controller_port INTEGER DEFAULT NULL,
33
- load_balancer_port INTEGER DEFAULT NULL,
34
- status TEXT,
35
- uptime INTEGER DEFAULT NULL,
36
- policy TEXT DEFAULT NULL,
37
- auto_restart INTEGER DEFAULT NULL,
38
- requested_resources BLOB DEFAULT NULL)""")
39
- cursor.execute("""\
40
- CREATE TABLE IF NOT EXISTS replicas (
41
- service_name TEXT,
42
- replica_id INTEGER,
43
- replica_info BLOB,
44
- PRIMARY KEY (service_name, replica_id))""")
45
- cursor.execute("""\
46
- CREATE TABLE IF NOT EXISTS version_specs (
47
- version INTEGER,
48
- service_name TEXT,
49
- spec BLOB,
50
- PRIMARY KEY (service_name, version))""")
51
- cursor.execute("""\
52
- CREATE TABLE IF NOT EXISTS ha_recovery_script (
53
- service_name TEXT PRIMARY KEY,
54
- script TEXT)""")
55
- conn.commit()
56
-
57
- # Backward compatibility.
58
- db_utils.add_column_to_table(cursor, conn, 'services',
59
- 'requested_resources_str', 'TEXT')
60
- # Deprecated: switched to `active_versions` below for the version
61
- # considered active by the load balancer. The
62
- # authscaler/replica_manager version can be found in the
63
- # version_specs table.
64
- db_utils.add_column_to_table(
65
- cursor, conn, 'services', 'current_version',
66
- f'INTEGER DEFAULT {constants.INITIAL_VERSION}')
67
- # The versions that is activated for the service. This is a list
68
- # of integers in json format.
69
- db_utils.add_column_to_table(cursor, conn, 'services', 'active_versions',
70
- f'TEXT DEFAULT {json.dumps([])!r}')
71
- db_utils.add_column_to_table(cursor, conn, 'services',
72
- 'load_balancing_policy', 'TEXT DEFAULT NULL')
73
- # Whether the service's load balancer is encrypted with TLS.
74
- db_utils.add_column_to_table(cursor, conn, 'services', 'tls_encrypted',
75
- 'INTEGER DEFAULT 0')
76
- # Whether the service is a cluster pool.
77
- db_utils.add_column_to_table(cursor, conn, 'services', 'pool',
78
- 'INTEGER DEFAULT 0')
79
- # Add controller_pid for status tracking.
80
- db_utils.add_column_to_table(cursor,
81
- conn,
82
- 'services',
83
- 'controller_pid',
84
- 'INTEGER DEFAULT NULL',
85
- value_to_replace_existing_entries=-1)
86
- # The service hash. Unique for each service, even if the service name is
87
- # the same.
88
- db_utils.add_column_to_table(cursor, conn, 'services', 'hash',
89
- 'TEXT DEFAULT NULL')
90
- # Entrypoint to launch the service.
91
- db_utils.add_column_to_table(cursor, conn, 'services', 'entrypoint',
92
- 'TEXT DEFAULT NULL')
93
- conn.commit()
94
-
95
-
96
- def _get_db_path() -> str:
97
- """Workaround to collapse multi-step Path ops for type checker.
98
- Ensures _DB_PATH is str, avoiding Union[Path, str] inference.
99
- """
100
- path = pathlib.Path(constants.SKYSERVE_METADATA_DIR) / 'services.db'
101
- path = path.expanduser().absolute()
102
- path.parents[0].mkdir(parents=True, exist_ok=True)
103
- return str(path)
102
+ # Enable WAL mode to avoid locking issues.
103
+ # See: issue #3863, #1441 and PR #1509
104
+ # https://github.com/microsoft/WSL/issues/2395
105
+ # TODO(romilb): We do not enable WAL for WSL because of known issue in WSL.
106
+ # This may cause the database locked problem from WSL issue #1441.
107
+ if (engine.dialect.name == db_utils.SQLAlchemyDialect.SQLITE.value and
108
+ not common_utils.is_wsl()):
109
+ try:
110
+ with orm.Session(engine) as session:
111
+ session.execute(sqlalchemy.text('PRAGMA journal_mode=WAL'))
112
+ session.commit()
113
+ except sqlalchemy_exc.OperationalError as e:
114
+ if 'database is locked' not in str(e):
115
+ raise
116
+ # If the database is locked, it is OK to continue, as the WAL mode
117
+ # is not critical and is likely to be enabled by other processes.
118
+
119
+ migration_utils.safe_alembic_upgrade(engine, migration_utils.SERVE_DB_NAME,
120
+ migration_utils.SERVE_VERSION)
121
+
122
+
123
+ def initialize_and_get_db() -> sqlalchemy.engine.Engine:
124
+ global _SQLALCHEMY_ENGINE
125
+
126
+ if _SQLALCHEMY_ENGINE is not None:
127
+ return _SQLALCHEMY_ENGINE
128
+
129
+ with _SQLALCHEMY_ENGINE_LOCK:
130
+ if _SQLALCHEMY_ENGINE is not None:
131
+ return _SQLALCHEMY_ENGINE
132
+ # get an engine to the db
133
+ engine = migration_utils.get_engine('serve/services')
104
134
 
135
+ # run migrations if needed
136
+ create_table(engine)
105
137
 
106
- _DB_PATH = None
107
- _db_init_lock = threading.Lock()
138
+ # return engine
139
+ _SQLALCHEMY_ENGINE = engine
140
+ return _SQLALCHEMY_ENGINE
108
141
 
109
142
 
110
143
  def init_db(func):
@@ -112,19 +145,18 @@ def init_db(func):
112
145
 
113
146
  @functools.wraps(func)
114
147
  def wrapper(*args, **kwargs):
115
- global _DB_PATH
116
- if _DB_PATH is not None:
117
- return func(*args, **kwargs)
118
- with _db_init_lock:
119
- if _DB_PATH is None:
120
- _DB_PATH = _get_db_path()
121
- db_utils.SQLiteConn(_DB_PATH, create_table)
148
+ initialize_and_get_db()
122
149
  return func(*args, **kwargs)
123
150
 
124
151
  return wrapper
125
152
 
126
153
 
127
- _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG = 'UNIQUE constraint failed: services.name'
154
+ _UNIQUE_CONSTRAINT_FAILED_ERROR_MSGS = [
155
+ # sqlite
156
+ 'UNIQUE constraint failed: services.name',
157
+ # postgres
158
+ 'duplicate key value violates unique constraint "services_pkey"',
159
+ ]
128
160
 
129
161
 
130
162
  # === Statuses ===
@@ -299,25 +331,38 @@ def add_service(name: str, controller_job_id: int, policy: str,
299
331
  True if the service is added successfully, False if the service already
300
332
  exists.
301
333
  """
302
- assert _DB_PATH is not None
334
+ assert _SQLALCHEMY_ENGINE is not None
303
335
  try:
304
- with db_utils.safe_cursor(_DB_PATH) as cursor:
305
- cursor.execute(
306
- """\
307
- INSERT INTO services
308
- (name, controller_job_id, status, policy,
309
- requested_resources_str, load_balancing_policy, tls_encrypted,
310
- pool, controller_pid, hash, entrypoint)
311
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
312
- (name, controller_job_id, status.value, policy,
313
- requested_resources_str, load_balancing_policy,
314
- int(tls_encrypted), int(pool), controller_pid, str(
315
- uuid.uuid4()), entrypoint))
316
-
317
- except sqlite3.IntegrityError as e:
318
- if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG:
319
- raise RuntimeError('Unexpected database error') from e
320
- return False
336
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
337
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
338
+ db_utils.SQLAlchemyDialect.SQLITE.value):
339
+ insert_func = sqlite.insert
340
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
341
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
342
+ insert_func = postgresql.insert
343
+ else:
344
+ raise ValueError('Unsupported database dialect')
345
+
346
+ insert_stmt = insert_func(services_table).values(
347
+ name=name,
348
+ controller_job_id=controller_job_id,
349
+ status=status.value,
350
+ policy=policy,
351
+ requested_resources_str=requested_resources_str,
352
+ load_balancing_policy=load_balancing_policy,
353
+ tls_encrypted=int(tls_encrypted),
354
+ pool=int(pool),
355
+ controller_pid=controller_pid,
356
+ hash=str(uuid.uuid4()),
357
+ entrypoint=entrypoint)
358
+ session.execute(insert_stmt)
359
+ session.commit()
360
+
361
+ except sqlalchemy_exc.IntegrityError as e:
362
+ for msg in _UNIQUE_CONSTRAINT_FAILED_ERROR_MSGS:
363
+ if msg in str(e):
364
+ return False
365
+ raise RuntimeError('Unexpected database error') from e
321
366
  return True
322
367
 
323
368
 
@@ -328,33 +373,34 @@ def update_service_controller_pid(service_name: str,
328
373
 
329
374
  This is used to update the controller pid of a service on ha recovery.
330
375
  """
331
- assert _DB_PATH is not None
332
- with db_utils.safe_cursor(_DB_PATH) as cursor:
333
- cursor.execute(
334
- """\
335
- UPDATE services SET
336
- controller_pid=(?) WHERE name=(?)""",
337
- (controller_pid, service_name))
376
+ assert _SQLALCHEMY_ENGINE is not None
377
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
378
+ session.query(services_table).filter(
379
+ services_table.c.name == service_name).update(
380
+ {services_table.c.controller_pid: controller_pid})
381
+ session.commit()
338
382
 
339
383
 
340
384
  @init_db
341
385
  def remove_service(service_name: str) -> None:
342
386
  """Removes a service from the database."""
343
- assert _DB_PATH is not None
344
- with db_utils.safe_cursor(_DB_PATH) as cursor:
345
- cursor.execute("""\
346
- DELETE FROM services WHERE name=(?)""", (service_name,))
387
+ assert _SQLALCHEMY_ENGINE is not None
388
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
389
+ session.execute(
390
+ sqlalchemy.delete(services_table).where(
391
+ services_table.c.name == service_name))
392
+ session.commit()
347
393
 
348
394
 
349
395
  @init_db
350
396
  def set_service_uptime(service_name: str, uptime: int) -> None:
351
397
  """Sets the uptime of a service."""
352
- assert _DB_PATH is not None
353
- with db_utils.safe_cursor(_DB_PATH) as cursor:
354
- cursor.execute(
355
- """\
356
- UPDATE services SET
357
- uptime=(?) WHERE name=(?)""", (uptime, service_name))
398
+ assert _SQLALCHEMY_ENGINE is not None
399
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
400
+ session.query(services_table).filter(
401
+ services_table.c.name == service_name).update(
402
+ {services_table.c.uptime: uptime})
403
+ session.commit()
358
404
 
359
405
 
360
406
  @init_db
@@ -363,74 +409,71 @@ def set_service_status_and_active_versions(
363
409
  status: ServiceStatus,
364
410
  active_versions: Optional[List[int]] = None) -> None:
365
411
  """Sets the service status."""
366
- assert _DB_PATH is not None
367
- vars_to_set = 'status=(?)'
368
- values: Tuple[str, ...] = (status.value, service_name)
412
+ assert _SQLALCHEMY_ENGINE is not None
413
+ update_dict = {services_table.c.status: status.value}
369
414
  if active_versions is not None:
370
- vars_to_set = 'status=(?), active_versions=(?)'
371
- values = (status.value, json.dumps(active_versions), service_name)
372
- with db_utils.safe_cursor(_DB_PATH) as cursor:
373
- cursor.execute(
374
- f"""\
375
- UPDATE services SET
376
- {vars_to_set} WHERE name=(?)""", values)
415
+ update_dict[services_table.c.active_versions] = json.dumps(
416
+ active_versions)
417
+
418
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
419
+ session.query(services_table).filter(
420
+ services_table.c.name == service_name).update(update_dict)
421
+ session.commit()
377
422
 
378
423
 
379
424
  @init_db
380
425
  def set_service_controller_port(service_name: str,
381
426
  controller_port: int) -> None:
382
427
  """Sets the controller port of a service."""
383
- assert _DB_PATH is not None
384
- with db_utils.safe_cursor(_DB_PATH) as cursor:
385
- cursor.execute(
386
- """\
387
- UPDATE services SET
388
- controller_port=(?) WHERE name=(?)""",
389
- (controller_port, service_name))
428
+ assert _SQLALCHEMY_ENGINE is not None
429
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
430
+ session.query(services_table).filter(
431
+ services_table.c.name == service_name).update(
432
+ {services_table.c.controller_port: controller_port})
433
+ session.commit()
390
434
 
391
435
 
392
436
  @init_db
393
437
  def set_service_load_balancer_port(service_name: str,
394
438
  load_balancer_port: int) -> None:
395
439
  """Sets the load balancer port of a service."""
396
- assert _DB_PATH is not None
397
- with db_utils.safe_cursor(_DB_PATH) as cursor:
398
- cursor.execute(
399
- """\
400
- UPDATE services SET
401
- load_balancer_port=(?) WHERE name=(?)""",
402
- (load_balancer_port, service_name))
403
-
404
-
405
- def _get_service_from_row(row) -> Dict[str, Any]:
406
- (current_version, name, controller_job_id, controller_port,
407
- load_balancer_port, status, uptime, policy, _, _, requested_resources_str,
408
- _, active_versions, load_balancing_policy, tls_encrypted, pool,
409
- controller_pid, svc_hash, entrypoint) = row[:19]
440
+ assert _SQLALCHEMY_ENGINE is not None
441
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
442
+ session.query(services_table).filter(
443
+ services_table.c.name == service_name).update(
444
+ {services_table.c.load_balancer_port: load_balancer_port})
445
+ session.commit()
446
+
447
+
448
+ def _get_service_from_row(r: 'row.RowMapping') -> Dict[str, Any]:
449
+ # Get the max_version from the first column (from the subquery)
450
+ current_version = r['max_version']
451
+
410
452
  record = {
411
- 'name': name,
412
- 'controller_job_id': controller_job_id,
413
- 'controller_port': controller_port,
414
- 'load_balancer_port': load_balancer_port,
415
- 'status': ServiceStatus[status],
416
- 'uptime': uptime,
417
- 'policy': policy,
453
+ 'name': r['name'],
454
+ 'controller_job_id': r['controller_job_id'],
455
+ 'controller_port': r['controller_port'],
456
+ 'load_balancer_port': r['load_balancer_port'],
457
+ 'status': ServiceStatus[r['status']],
458
+ 'uptime': r['uptime'],
459
+ 'policy': r['policy'],
418
460
  # The version of the autoscaler/replica manager are on. It can be larger
419
461
  # than the active versions as the load balancer may not consider the
420
462
  # latest version to be active for serving traffic.
421
463
  'version': current_version,
422
464
  # The versions that is active for the load balancer. This is a list of
423
465
  # integers in json format. This is mainly for display purpose.
424
- 'active_versions': json.loads(active_versions),
425
- 'requested_resources_str': requested_resources_str,
426
- 'load_balancing_policy': load_balancing_policy,
427
- 'tls_encrypted': bool(tls_encrypted),
428
- 'pool': bool(pool),
429
- 'controller_pid': controller_pid,
430
- 'hash': svc_hash,
431
- 'entrypoint': entrypoint,
466
+ 'active_versions': json.loads(r['active_versions'])
467
+ if r['active_versions'] else [],
468
+ 'requested_resources_str': r['requested_resources_str'],
469
+ 'load_balancing_policy': r['load_balancing_policy'],
470
+ 'tls_encrypted': bool(r['tls_encrypted']),
471
+ 'pool': bool(r['pool']),
472
+ 'controller_pid': r['controller_pid'],
473
+ 'hash': r['hash'],
474
+ 'entrypoint': r['entrypoint'],
432
475
  }
433
- latest_spec = get_spec(name, current_version)
476
+ latest_spec = get_spec(r['name'], current_version)
434
477
  if latest_spec is not None:
435
478
  record['policy'] = latest_spec.autoscaling_policy_str()
436
479
  record['load_balancing_policy'] = latest_spec.load_balancing_policy
@@ -440,57 +483,69 @@ def _get_service_from_row(row) -> Dict[str, Any]:
440
483
  @init_db
441
484
  def get_services() -> List[Dict[str, Any]]:
442
485
  """Get all existing service records."""
443
- assert _DB_PATH is not None
444
- with db_utils.safe_cursor(_DB_PATH) as cursor:
445
- rows = cursor.execute('SELECT v.max_version, s.* FROM services s '
446
- 'JOIN ('
447
- 'SELECT service_name, MAX(version) as max_version'
448
- ' FROM version_specs GROUP BY service_name) v '
449
- 'ON s.name=v.service_name').fetchall()
486
+ assert _SQLALCHEMY_ENGINE is not None
487
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
488
+ subquery = sqlalchemy.select(
489
+ version_specs_table.c.service_name,
490
+ sqlalchemy.func.max(
491
+ version_specs_table.c.version).label('max_version')).group_by(
492
+ version_specs_table.c.service_name).alias('v')
493
+
494
+ query = sqlalchemy.select(
495
+ subquery.c.max_version, services_table).select_from(
496
+ services_table.join(
497
+ subquery, services_table.c.name == subquery.c.service_name))
498
+ rows = session.execute(query).fetchall()
450
499
  records = []
451
500
  for row in rows:
452
- records.append(_get_service_from_row(row))
501
+ records.append(_get_service_from_row(row._mapping)) # pylint: disable=protected-access
453
502
  return records
454
503
 
455
504
 
456
505
  @init_db
457
506
  def get_service_from_name(service_name: str) -> Optional[Dict[str, Any]]:
458
507
  """Get all existing service records."""
459
- assert _DB_PATH is not None
460
- with db_utils.safe_cursor(_DB_PATH) as cursor:
461
- rows = cursor.execute(
462
- 'SELECT v.max_version, s.* FROM services s '
463
- 'JOIN ('
464
- 'SELECT service_name, MAX(version) as max_version '
465
- 'FROM version_specs WHERE service_name=(?)) v '
466
- 'ON s.name=v.service_name WHERE name=(?)',
467
- (service_name, service_name)).fetchall()
508
+ assert _SQLALCHEMY_ENGINE is not None
509
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
510
+ subquery = sqlalchemy.select(
511
+ version_specs_table.c.service_name,
512
+ sqlalchemy.func.max(
513
+ version_specs_table.c.version).label('max_version')
514
+ ).where(version_specs_table.c.service_name == service_name).group_by(
515
+ version_specs_table.c.service_name).alias('v')
516
+
517
+ query = sqlalchemy.select(
518
+ subquery.c.max_version, services_table).select_from(
519
+ services_table.join(
520
+ subquery,
521
+ services_table.c.name == subquery.c.service_name)).where(
522
+ services_table.c.name == service_name)
523
+
524
+ rows = session.execute(query).fetchall()
468
525
  for row in rows:
469
- return _get_service_from_row(row)
526
+ return _get_service_from_row(row._mapping) # pylint: disable=protected-access
470
527
  return None
471
528
 
472
529
 
473
530
  @init_db
474
531
  def get_service_hash(service_name: str) -> Optional[str]:
475
532
  """Get the hash of a service."""
476
- assert _DB_PATH is not None
477
- with db_utils.safe_cursor(_DB_PATH) as cursor:
478
- rows = cursor.execute('SELECT hash FROM services WHERE name=(?)',
479
- (service_name,)).fetchall()
480
- for row in rows:
481
- return row[0]
482
- return None
533
+ assert _SQLALCHEMY_ENGINE is not None
534
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
535
+ result = session.execute(
536
+ sqlalchemy.select(services_table.c.hash).where(
537
+ services_table.c.name == service_name)).fetchone()
538
+ return result[0] if result else None
483
539
 
484
540
 
485
541
  @init_db
486
542
  def get_service_versions(service_name: str) -> List[int]:
487
543
  """Gets all versions of a service."""
488
- assert _DB_PATH is not None
489
- with db_utils.safe_cursor(_DB_PATH) as cursor:
490
- rows = cursor.execute(
491
- """\
492
- SELECT DISTINCT version FROM version_specs
493
- WHERE service_name=(?)""", (service_name,)).fetchall()
544
+ assert _SQLALCHEMY_ENGINE is not None
545
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
546
+ rows = session.execute(
547
+ sqlalchemy.select(version_specs_table.c.version.distinct()).where(
548
+ version_specs_table.c.service_name == service_name)).fetchall()
494
549
  return [row[0] for row in rows]
495
550
 
496
551
 
@@ -506,17 +561,19 @@ def get_glob_service_names(
506
561
  Returns:
507
562
  A list of non-duplicated service names.
508
563
  """
509
- assert _DB_PATH is not None
510
- with db_utils.safe_cursor(_DB_PATH) as cursor:
564
+ assert _SQLALCHEMY_ENGINE is not None
565
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
511
566
  if service_names is None:
512
- rows = cursor.execute('SELECT name FROM services').fetchall()
567
+ rows = session.execute(sqlalchemy.select(
568
+ services_table.c.name)).fetchall()
513
569
  else:
514
570
  rows = []
515
571
  for service_name in service_names:
516
- rows.extend(
517
- cursor.execute(
518
- 'SELECT name FROM services WHERE name GLOB (?)',
519
- (service_name,)).fetchall())
572
+ pattern_rows = session.execute(
573
+ sqlalchemy.select(services_table.c.name).where(
574
+ services_table.c.name.like(
575
+ service_name.replace('*', '%')))).fetchall()
576
+ rows.extend(pattern_rows)
520
577
  return list({row[0] for row in rows})
521
578
 
522
579
 
@@ -525,26 +582,40 @@ def get_glob_service_names(
525
582
  def add_or_update_replica(service_name: str, replica_id: int,
526
583
  replica_info: 'replica_managers.ReplicaInfo') -> None:
527
584
  """Adds a replica to the database."""
528
- assert _DB_PATH is not None
529
- with db_utils.safe_cursor(_DB_PATH) as cursor:
530
- cursor.execute(
531
- """\
532
- INSERT OR REPLACE INTO replicas
533
- (service_name, replica_id, replica_info)
534
- VALUES (?, ?, ?)""",
535
- (service_name, replica_id, pickle.dumps(replica_info)))
585
+ assert _SQLALCHEMY_ENGINE is not None
586
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
587
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
588
+ db_utils.SQLAlchemyDialect.SQLITE.value):
589
+ insert_func = sqlite.insert
590
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
591
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
592
+ insert_func = postgresql.insert
593
+ else:
594
+ raise ValueError('Unsupported database dialect')
595
+
596
+ insert_stmt = insert_func(replicas_table).values(
597
+ service_name=service_name,
598
+ replica_id=replica_id,
599
+ replica_info=pickle.dumps(replica_info))
600
+
601
+ insert_stmt = insert_stmt.on_conflict_do_update(
602
+ index_elements=['service_name', 'replica_id'],
603
+ set_={'replica_info': insert_stmt.excluded.replica_info})
604
+
605
+ session.execute(insert_stmt)
606
+ session.commit()
536
607
 
537
608
 
538
609
  @init_db
539
610
  def remove_replica(service_name: str, replica_id: int) -> None:
540
611
  """Removes a replica from the database."""
541
- assert _DB_PATH is not None
542
- with db_utils.safe_cursor(_DB_PATH) as cursor:
543
- cursor.execute(
544
- """\
545
- DELETE FROM replicas
546
- WHERE service_name=(?)
547
- AND replica_id=(?)""", (service_name, replica_id))
612
+ assert _SQLALCHEMY_ENGINE is not None
613
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
614
+ session.execute(
615
+ sqlalchemy.delete(replicas_table).where(
616
+ sqlalchemy.and_(replicas_table.c.service_name == service_name,
617
+ replicas_table.c.replica_id == replica_id)))
618
+ session.commit()
548
619
 
549
620
 
550
621
  @init_db
@@ -552,37 +623,35 @@ def get_replica_info_from_id(
552
623
  service_name: str,
553
624
  replica_id: int) -> Optional['replica_managers.ReplicaInfo']:
554
625
  """Gets a replica info from the database."""
555
- assert _DB_PATH is not None
556
- with db_utils.safe_cursor(_DB_PATH) as cursor:
557
- rows = cursor.execute(
558
- """\
559
- SELECT replica_info FROM replicas
560
- WHERE service_name=(?)
561
- AND replica_id=(?)""", (service_name, replica_id)).fetchall()
562
- for row in rows:
563
- return pickle.loads(row[0])
564
- return None
626
+ assert _SQLALCHEMY_ENGINE is not None
627
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
628
+ result = session.execute(
629
+ sqlalchemy.select(replicas_table.c.replica_info).where(
630
+ sqlalchemy.and_(
631
+ replicas_table.c.service_name == service_name,
632
+ replicas_table.c.replica_id == replica_id))).fetchone()
633
+ return pickle.loads(result[0]) if result else None
565
634
 
566
635
 
567
636
  @init_db
568
637
  def get_replica_infos(
569
638
  service_name: str) -> List['replica_managers.ReplicaInfo']:
570
639
  """Gets all replica infos of a service."""
571
- assert _DB_PATH is not None
572
- with db_utils.safe_cursor(_DB_PATH) as cursor:
573
- rows = cursor.execute(
574
- """\
575
- SELECT replica_info FROM replicas
576
- WHERE service_name=(?)""", (service_name,)).fetchall()
640
+ assert _SQLALCHEMY_ENGINE is not None
641
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
642
+ rows = session.execute(
643
+ sqlalchemy.select(replicas_table.c.replica_info).where(
644
+ replicas_table.c.service_name == service_name)).fetchall()
577
645
  return [pickle.loads(row[0]) for row in rows]
578
646
 
579
647
 
580
648
  @init_db
581
649
  def total_number_provisioning_replicas() -> int:
582
650
  """Returns the total number of provisioning replicas."""
583
- assert _DB_PATH is not None
584
- with db_utils.safe_cursor(_DB_PATH) as cursor:
585
- rows = cursor.execute('SELECT replica_info FROM replicas').fetchall()
651
+ assert _SQLALCHEMY_ENGINE is not None
652
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
653
+ rows = session.execute(sqlalchemy.select(
654
+ replicas_table.c.replica_info)).fetchall()
586
655
  provisioning_count = 0
587
656
  for row in rows:
588
657
  replica_info: 'replica_managers.ReplicaInfo' = pickle.loads(row[0])
@@ -603,154 +672,182 @@ def get_replicas_at_status(
603
672
  @init_db
604
673
  def add_version(service_name: str) -> int:
605
674
  """Adds a version to the database."""
606
- assert _DB_PATH is not None
607
- with db_utils.safe_cursor(_DB_PATH) as cursor:
608
- cursor.execute(
609
- """\
610
- INSERT INTO version_specs
611
- (version, service_name, spec)
612
- VALUES (
613
- (SELECT COALESCE(MAX(version), 0) + 1 FROM
614
- version_specs WHERE service_name = ?), ?, ?)
615
- RETURNING version""",
616
- (service_name, service_name, pickle.dumps(None)))
617
-
618
- inserted_version = cursor.fetchone()[0]
619
-
620
- return inserted_version
675
+ assert _SQLALCHEMY_ENGINE is not None
676
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
677
+ # Insert new version with MAX(version) + 1 in a single atomic operation
678
+ max_version_subquery = sqlalchemy.select(
679
+ sqlalchemy.func.coalesce(
680
+ sqlalchemy.func.max(version_specs_table.c.version), 0) +
681
+ 1).where(version_specs_table.c.service_name ==
682
+ service_name).scalar_subquery()
683
+
684
+ # Use INSERT with subquery and RETURNING
685
+ insert_stmt = sqlalchemy.insert(version_specs_table).values(
686
+ service_name=service_name,
687
+ version=max_version_subquery,
688
+ spec=pickle.dumps(None)).returning(version_specs_table.c.version)
689
+
690
+ result = session.execute(insert_stmt)
691
+ new_version = result.scalar()
692
+ session.commit()
693
+ return new_version
621
694
 
622
695
 
623
696
  @init_db
624
697
  def add_or_update_version(service_name: str, version: int,
625
698
  spec: 'service_spec.SkyServiceSpec') -> None:
626
- assert _DB_PATH is not None
627
- with db_utils.safe_cursor(_DB_PATH) as cursor:
628
- cursor.execute(
629
- """\
630
- INSERT or REPLACE INTO version_specs
631
- (service_name, version, spec)
632
- VALUES (?, ?, ?)""", (service_name, version, pickle.dumps(spec)))
699
+ assert _SQLALCHEMY_ENGINE is not None
700
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
701
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
702
+ db_utils.SQLAlchemyDialect.SQLITE.value):
703
+ insert_func = sqlite.insert
704
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
705
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
706
+ insert_func = postgresql.insert
707
+ else:
708
+ raise ValueError('Unsupported database dialect')
709
+
710
+ insert_stmt = insert_func(version_specs_table).values(
711
+ service_name=service_name, version=version, spec=pickle.dumps(spec))
712
+
713
+ insert_stmt = insert_stmt.on_conflict_do_update(
714
+ index_elements=['service_name', 'version'],
715
+ set_={'spec': insert_stmt.excluded.spec})
716
+
717
+ session.execute(insert_stmt)
718
+ session.commit()
633
719
 
634
720
 
635
721
  @init_db
636
722
  def remove_service_versions(service_name: str) -> None:
637
723
  """Removes a replica from the database."""
638
- assert _DB_PATH is not None
639
- with db_utils.safe_cursor(_DB_PATH) as cursor:
640
- cursor.execute(
641
- """\
642
- DELETE FROM version_specs
643
- WHERE service_name=(?)""", (service_name,))
724
+ assert _SQLALCHEMY_ENGINE is not None
725
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
726
+ session.execute(
727
+ sqlalchemy.delete(version_specs_table).where(
728
+ version_specs_table.c.service_name == service_name))
729
+ session.commit()
644
730
 
645
731
 
646
732
  @init_db
647
733
  def get_spec(service_name: str,
648
734
  version: int) -> Optional['service_spec.SkyServiceSpec']:
649
735
  """Gets spec from the database."""
650
- assert _DB_PATH is not None
651
- with db_utils.safe_cursor(_DB_PATH) as cursor:
652
- rows = cursor.execute(
653
- """\
654
- SELECT spec FROM version_specs
655
- WHERE service_name=(?)
656
- AND version=(?)""", (service_name, version)).fetchall()
657
- for row in rows:
658
- return pickle.loads(row[0])
659
- return None
736
+ assert _SQLALCHEMY_ENGINE is not None
737
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
738
+ result = session.execute(
739
+ sqlalchemy.select(version_specs_table.c.spec).where(
740
+ sqlalchemy.and_(
741
+ version_specs_table.c.service_name == service_name,
742
+ version_specs_table.c.version == version))).fetchone()
743
+ return pickle.loads(result[0]) if result else None
660
744
 
661
745
 
662
746
  @init_db
663
747
  def delete_version(service_name: str, version: int) -> None:
664
748
  """Deletes a version from the database."""
665
- assert _DB_PATH is not None
666
- with db_utils.safe_cursor(_DB_PATH) as cursor:
667
- cursor.execute(
668
- """\
669
- DELETE FROM version_specs
670
- WHERE service_name=(?)
671
- AND version=(?)""", (service_name, version))
749
+ assert _SQLALCHEMY_ENGINE is not None
750
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
751
+ session.execute(
752
+ sqlalchemy.delete(version_specs_table).where(
753
+ sqlalchemy.and_(
754
+ version_specs_table.c.service_name == service_name,
755
+ version_specs_table.c.version == version)))
756
+ session.commit()
672
757
 
673
758
 
674
759
  @init_db
675
760
  def delete_all_versions(service_name: str) -> None:
676
761
  """Deletes all versions from the database."""
677
- assert _DB_PATH is not None
678
- with db_utils.safe_cursor(_DB_PATH) as cursor:
679
- cursor.execute(
680
- """\
681
- DELETE FROM version_specs
682
- WHERE service_name=(?)""", (service_name,))
762
+ assert _SQLALCHEMY_ENGINE is not None
763
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
764
+ session.execute(
765
+ sqlalchemy.delete(version_specs_table).where(
766
+ version_specs_table.c.service_name == service_name))
767
+ session.commit()
683
768
 
684
769
 
685
770
  @init_db
686
771
  def get_latest_version(service_name: str) -> Optional[int]:
687
- assert _DB_PATH is not None
688
- with db_utils.safe_cursor(_DB_PATH) as cursor:
689
- rows = cursor.execute(
690
- """\
691
- SELECT MAX(version) FROM version_specs
692
- WHERE service_name=(?)""", (service_name,)).fetchall()
693
- if not rows or rows[0][0] is None:
694
- return None
695
- return rows[0][0]
772
+ assert _SQLALCHEMY_ENGINE is not None
773
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
774
+ result = session.execute(
775
+ sqlalchemy.select(sqlalchemy.func.max(
776
+ version_specs_table.c.version)).where(
777
+ version_specs_table.c.service_name ==
778
+ service_name)).fetchone()
779
+ return result[0] if result else None
696
780
 
697
781
 
698
782
  @init_db
699
783
  def get_service_controller_port(service_name: str) -> int:
700
784
  """Gets the controller port of a service."""
701
- assert _DB_PATH is not None
702
- with db_utils.safe_cursor(_DB_PATH) as cursor:
703
- cursor.execute('SELECT controller_port FROM services WHERE name = ?',
704
- (service_name,))
705
- row = cursor.fetchone()
706
- if row is None:
785
+ assert _SQLALCHEMY_ENGINE is not None
786
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
787
+ result = session.execute(
788
+ sqlalchemy.select(services_table.c.controller_port).where(
789
+ services_table.c.name == service_name)).fetchone()
790
+ if result is None:
707
791
  raise ValueError(f'Service {service_name} does not exist.')
708
- return row[0]
792
+ return result[0]
709
793
 
710
794
 
711
795
  @init_db
712
796
  def get_service_load_balancer_port(service_name: str) -> int:
713
797
  """Gets the load balancer port of a service."""
714
- assert _DB_PATH is not None
715
- with db_utils.safe_cursor(_DB_PATH) as cursor:
716
- cursor.execute('SELECT load_balancer_port FROM services WHERE name = ?',
717
- (service_name,))
718
- row = cursor.fetchone()
719
- if row is None:
798
+ assert _SQLALCHEMY_ENGINE is not None
799
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
800
+ result = session.execute(
801
+ sqlalchemy.select(services_table.c.load_balancer_port).where(
802
+ services_table.c.name == service_name)).fetchone()
803
+ if result is None:
720
804
  raise ValueError(f'Service {service_name} does not exist.')
721
- return row[0]
805
+ return result[0]
722
806
 
723
807
 
724
808
  @init_db
725
809
  def get_ha_recovery_script(service_name: str) -> Optional[str]:
726
810
  """Gets the HA recovery script for a service."""
727
- assert _DB_PATH is not None
728
- with db_utils.safe_cursor(_DB_PATH) as cursor:
729
- cursor.execute(
730
- 'SELECT script FROM ha_recovery_script WHERE service_name = ?',
731
- (service_name,))
732
- row = cursor.fetchone()
733
- if row is None:
734
- return None
735
- return row[0]
811
+ assert _SQLALCHEMY_ENGINE is not None
812
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
813
+ result = session.execute(
814
+ sqlalchemy.select(serve_ha_recovery_script_table.c.script).where(
815
+ serve_ha_recovery_script_table.c.service_name ==
816
+ service_name)).fetchone()
817
+ return result[0] if result else None
736
818
 
737
819
 
738
820
  @init_db
739
821
  def set_ha_recovery_script(service_name: str, script: str) -> None:
740
822
  """Sets the HA recovery script for a service."""
741
- assert _DB_PATH is not None
742
- with db_utils.safe_cursor(_DB_PATH) as cursor:
743
- cursor.execute(
744
- """\
745
- INSERT OR REPLACE INTO ha_recovery_script
746
- (service_name, script)
747
- VALUES (?, ?)""", (service_name, script))
823
+ assert _SQLALCHEMY_ENGINE is not None
824
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
825
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
826
+ db_utils.SQLAlchemyDialect.SQLITE.value):
827
+ insert_func = sqlite.insert
828
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
829
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
830
+ insert_func = postgresql.insert
831
+ else:
832
+ raise ValueError('Unsupported database dialect')
833
+
834
+ insert_stmt = insert_func(serve_ha_recovery_script_table).values(
835
+ service_name=service_name, script=script)
836
+
837
+ insert_stmt = insert_stmt.on_conflict_do_update(
838
+ index_elements=['service_name'],
839
+ set_={'script': insert_stmt.excluded.script})
840
+
841
+ session.execute(insert_stmt)
842
+ session.commit()
748
843
 
749
844
 
750
845
  @init_db
751
846
  def remove_ha_recovery_script(service_name: str) -> None:
752
847
  """Removes the HA recovery script for a service."""
753
- assert _DB_PATH is not None
754
- with db_utils.safe_cursor(_DB_PATH) as cursor:
755
- cursor.execute('DELETE FROM ha_recovery_script WHERE service_name = ?',
756
- (service_name,))
848
+ assert _SQLALCHEMY_ENGINE is not None
849
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
850
+ session.execute(
851
+ sqlalchemy.delete(serve_ha_recovery_script_table).where(
852
+ serve_ha_recovery_script_table.c.service_name == service_name))
853
+ session.commit()