skypilot-nightly 1.0.0.dev20250806__py3-none-any.whl → 1.0.0.dev20250808__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of skypilot-nightly might be problematic. See the original diff page for more details (the link was not preserved in this copy).

Files changed (137):
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +20 -1
  3. sky/backends/cloud_vm_ray_backend.py +42 -6
  4. sky/check.py +11 -1
  5. sky/client/cli/command.py +248 -119
  6. sky/client/sdk.py +146 -66
  7. sky/client/sdk_async.py +5 -1
  8. sky/core.py +5 -2
  9. sky/dashboard/out/404.html +1 -1
  10. sky/dashboard/out/_next/static/-DXZksWqf2waNHeU9YTQe/_buildManifest.js +1 -0
  11. sky/dashboard/out/_next/static/chunks/1141-a8a8f1adba34c892.js +11 -0
  12. sky/dashboard/out/_next/static/chunks/1871-980a395e92633a5c.js +6 -0
  13. sky/dashboard/out/_next/static/chunks/3785.6003d293cb83eab4.js +1 -0
  14. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
  15. sky/dashboard/out/_next/static/chunks/4725.29550342bd53afd8.js +1 -0
  16. sky/dashboard/out/_next/static/chunks/{4937.d6bf67771e353356.js → 4937.a2baa2df5572a276.js} +1 -1
  17. sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
  18. sky/dashboard/out/_next/static/chunks/6601-06114c982db410b6.js +1 -0
  19. sky/dashboard/out/_next/static/chunks/{691.6d99cbfba347cebf.js → 691.5eeedf82cc243343.js} +1 -1
  20. sky/dashboard/out/_next/static/chunks/6989-6129c1cfbcf51063.js +1 -0
  21. sky/dashboard/out/_next/static/chunks/6990-0f886f16e0d55ff8.js +1 -0
  22. sky/dashboard/out/_next/static/chunks/8056-34d27f51e6d1c631.js +1 -0
  23. sky/dashboard/out/_next/static/chunks/8252.62b0d23aed618bb2.js +16 -0
  24. sky/dashboard/out/_next/static/chunks/8969-c9686994ddafcf01.js +1 -0
  25. sky/dashboard/out/_next/static/chunks/9025.a1bef12d672bb66d.js +6 -0
  26. sky/dashboard/out/_next/static/chunks/9159-11421c0f2909236f.js +1 -0
  27. sky/dashboard/out/_next/static/chunks/9360.85b0b1b4054574dd.js +31 -0
  28. sky/dashboard/out/_next/static/chunks/9666.cd4273f2a5c5802c.js +1 -0
  29. sky/dashboard/out/_next/static/chunks/{9847.4c46c5e229c78704.js → 9847.757720f3b40c0aa5.js} +1 -1
  30. sky/dashboard/out/_next/static/chunks/pages/{_app-2a43ea3241bbdacd.js → _app-491a4d699d95e808.js} +1 -1
  31. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-ae17cec0fc6483d9.js +11 -0
  32. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-155d477a6c3e04e2.js +1 -0
  33. sky/dashboard/out/_next/static/chunks/pages/{clusters-47f1ddae13a2f8e4.js → clusters-b30460f683e6ba96.js} +1 -1
  34. sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
  35. sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-2a44e70b500b6b70.js → [context]-13d53fffc03ccb52.js} +1 -1
  36. sky/dashboard/out/_next/static/chunks/pages/{infra-22faac9325016d83.js → infra-fc9222e26c8e2f0d.js} +1 -1
  37. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-154f55cf8af55be5.js +11 -0
  38. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-f5ccf5d39d87aebe.js +21 -0
  39. sky/dashboard/out/_next/static/chunks/pages/jobs-cdc60fb5d371e16a.js +1 -0
  40. sky/dashboard/out/_next/static/chunks/pages/{users-b90c865a690bfe84.js → users-7ed36e44e779d5c7.js} +1 -1
  41. sky/dashboard/out/_next/static/chunks/pages/{volumes-7af733f5d7b6ed1c.js → volumes-c9695d657f78b5dc.js} +1 -1
  42. sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
  43. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-f72f73bcef9541dc.js +1 -0
  44. sky/dashboard/out/_next/static/chunks/pages/workspaces-8f67be60165724cc.js +1 -0
  45. sky/dashboard/out/_next/static/chunks/webpack-339efec49c0cc7d0.js +1 -0
  46. sky/dashboard/out/_next/static/css/4614e06482d7309e.css +3 -0
  47. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  48. sky/dashboard/out/clusters/[cluster].html +1 -1
  49. sky/dashboard/out/clusters.html +1 -1
  50. sky/dashboard/out/config.html +1 -1
  51. sky/dashboard/out/index.html +1 -1
  52. sky/dashboard/out/infra/[context].html +1 -1
  53. sky/dashboard/out/infra.html +1 -1
  54. sky/dashboard/out/jobs/[job].html +1 -1
  55. sky/dashboard/out/jobs/pools/[pool].html +1 -0
  56. sky/dashboard/out/jobs.html +1 -1
  57. sky/dashboard/out/users.html +1 -1
  58. sky/dashboard/out/volumes.html +1 -1
  59. sky/dashboard/out/workspace/new.html +1 -1
  60. sky/dashboard/out/workspaces/[name].html +1 -1
  61. sky/dashboard/out/workspaces.html +1 -1
  62. sky/execution.py +6 -4
  63. sky/global_user_state.py +22 -3
  64. sky/jobs/__init__.py +2 -0
  65. sky/jobs/client/sdk.py +67 -19
  66. sky/jobs/controller.py +2 -1
  67. sky/jobs/server/core.py +48 -1
  68. sky/jobs/server/server.py +52 -3
  69. sky/jobs/state.py +5 -1
  70. sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
  71. sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
  72. sky/schemas/db/global_user_state/004_is_managed.py +34 -0
  73. sky/schemas/db/serve_state/001_initial_schema.py +67 -0
  74. sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
  75. sky/serve/client/impl.py +93 -6
  76. sky/serve/client/sdk.py +22 -53
  77. sky/serve/constants.py +2 -1
  78. sky/serve/controller.py +4 -2
  79. sky/serve/serve_state.py +444 -324
  80. sky/serve/serve_utils.py +77 -46
  81. sky/serve/server/core.py +13 -197
  82. sky/serve/server/impl.py +239 -2
  83. sky/serve/service.py +8 -3
  84. sky/server/common.py +18 -7
  85. sky/server/constants.py +1 -1
  86. sky/server/requests/executor.py +5 -3
  87. sky/server/requests/payloads.py +19 -0
  88. sky/setup_files/alembic.ini +4 -0
  89. sky/task.py +18 -11
  90. sky/templates/kubernetes-ray.yml.j2 +5 -0
  91. sky/templates/sky-serve-controller.yaml.j2 +1 -0
  92. sky/usage/usage_lib.py +8 -6
  93. sky/utils/annotations.py +8 -3
  94. sky/utils/cli_utils/status_utils.py +1 -1
  95. sky/utils/common_utils.py +11 -1
  96. sky/utils/db/db_utils.py +31 -0
  97. sky/utils/db/migration_utils.py +6 -2
  98. sky/utils/kubernetes/deploy_remote_cluster.py +3 -1
  99. sky/utils/resource_checker.py +162 -21
  100. sky/volumes/client/sdk.py +4 -4
  101. sky/workspaces/core.py +210 -6
  102. {skypilot_nightly-1.0.0.dev20250806.dist-info → skypilot_nightly-1.0.0.dev20250808.dist-info}/METADATA +19 -14
  103. {skypilot_nightly-1.0.0.dev20250806.dist-info → skypilot_nightly-1.0.0.dev20250808.dist-info}/RECORD +109 -103
  104. sky/client/sdk.pyi +0 -301
  105. sky/dashboard/out/_next/static/Gelsd19kVxXcX7aQQGsGu/_buildManifest.js +0 -1
  106. sky/dashboard/out/_next/static/chunks/1043-75af48ca5d5aaf57.js +0 -1
  107. sky/dashboard/out/_next/static/chunks/1141-8678a9102cc5f67e.js +0 -11
  108. sky/dashboard/out/_next/static/chunks/1664-22b00e32c9ff96a4.js +0 -1
  109. sky/dashboard/out/_next/static/chunks/1871-ced1c14230cad6e1.js +0 -6
  110. sky/dashboard/out/_next/static/chunks/2003.f90b06bb1f914295.js +0 -1
  111. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +0 -1
  112. sky/dashboard/out/_next/static/chunks/2622-951867535095b0eb.js +0 -1
  113. sky/dashboard/out/_next/static/chunks/3785.0a173cd4393f0fef.js +0 -1
  114. sky/dashboard/out/_next/static/chunks/4725.42f21f250f91f65b.js +0 -1
  115. sky/dashboard/out/_next/static/chunks/4869.18e6a4361a380763.js +0 -16
  116. sky/dashboard/out/_next/static/chunks/5230-f3bb2663e442e86c.js +0 -1
  117. sky/dashboard/out/_next/static/chunks/6601-2109d22e7861861c.js +0 -1
  118. sky/dashboard/out/_next/static/chunks/6990-08b2a1cae076a943.js +0 -1
  119. sky/dashboard/out/_next/static/chunks/8969-9a8cca241b30db83.js +0 -1
  120. sky/dashboard/out/_next/static/chunks/9025.99f29acb7617963e.js +0 -6
  121. sky/dashboard/out/_next/static/chunks/938-bda2685db5eae6cf.js +0 -1
  122. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-7cb24da04ca00956.js +0 -11
  123. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-1e95993124dbfc57.js +0 -1
  124. sky/dashboard/out/_next/static/chunks/pages/config-d56e64f30db7b42e.js +0 -1
  125. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-90693cb88b5599a7.js +0 -11
  126. sky/dashboard/out/_next/static/chunks/pages/jobs-ab318e52eb4424a7.js +0 -1
  127. sky/dashboard/out/_next/static/chunks/pages/workspace/new-92f741084a89e27b.js +0 -1
  128. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-35e0de5bca55e594.js +0 -1
  129. sky/dashboard/out/_next/static/chunks/pages/workspaces-062525fb5462acb6.js +0 -1
  130. sky/dashboard/out/_next/static/chunks/webpack-387626669badf82e.js +0 -1
  131. sky/dashboard/out/_next/static/css/b3227360726f12eb.css +0 -3
  132. /sky/dashboard/out/_next/static/{Gelsd19kVxXcX7aQQGsGu → -DXZksWqf2waNHeU9YTQe}/_ssgManifest.js +0 -0
  133. /sky/dashboard/out/_next/static/chunks/{6135-2d7ed3350659d073.js → 6135-85426374db04811e.js} +0 -0
  134. {skypilot_nightly-1.0.0.dev20250806.dist-info → skypilot_nightly-1.0.0.dev20250808.dist-info}/WHEEL +0 -0
  135. {skypilot_nightly-1.0.0.dev20250806.dist-info → skypilot_nightly-1.0.0.dev20250808.dist-info}/entry_points.txt +0 -0
  136. {skypilot_nightly-1.0.0.dev20250806.dist-info → skypilot_nightly-1.0.0.dev20250808.dist-info}/licenses/LICENSE +0 -0
  137. {skypilot_nightly-1.0.0.dev20250806.dist-info → skypilot_nightly-1.0.0.dev20250808.dist-info}/top_level.txt +0 -0
sky/serve/serve_state.py CHANGED
@@ -3,100 +3,141 @@ import collections
3
3
  import enum
4
4
  import functools
5
5
  import json
6
- import pathlib
7
6
  import pickle
8
- import sqlite3
9
7
  import threading
10
8
  import typing
11
- from typing import Any, Dict, List, Optional, Tuple
9
+ from typing import Any, Dict, List, Optional
10
+ import uuid
12
11
 
13
12
  import colorama
13
+ import sqlalchemy
14
+ from sqlalchemy import exc as sqlalchemy_exc
15
+ from sqlalchemy import orm
16
+ from sqlalchemy.dialects import postgresql
17
+ from sqlalchemy.dialects import sqlite
18
+ from sqlalchemy.ext import declarative
14
19
 
15
20
  from sky.serve import constants
21
+ from sky.utils import common_utils
16
22
  from sky.utils.db import db_utils
23
+ from sky.utils.db import migration_utils
17
24
 
18
25
  if typing.TYPE_CHECKING:
26
+ from sqlalchemy.engine import row
27
+
19
28
  from sky.serve import replica_managers
20
29
  from sky.serve import service_spec
21
30
 
22
-
23
- def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None:
31
+ _SQLALCHEMY_ENGINE: Optional[sqlalchemy.engine.Engine] = None
32
+ _SQLALCHEMY_ENGINE_LOCK = threading.Lock()
33
+
34
+ Base = declarative.declarative_base()
35
+
36
+ # === Database schema ===
37
+ services_table = sqlalchemy.Table(
38
+ 'services',
39
+ Base.metadata,
40
+ sqlalchemy.Column('name', sqlalchemy.Text, primary_key=True),
41
+ sqlalchemy.Column('controller_job_id',
42
+ sqlalchemy.Integer,
43
+ server_default=None),
44
+ sqlalchemy.Column('controller_port',
45
+ sqlalchemy.Integer,
46
+ server_default=None),
47
+ sqlalchemy.Column('load_balancer_port',
48
+ sqlalchemy.Integer,
49
+ server_default=None),
50
+ sqlalchemy.Column('status', sqlalchemy.Text),
51
+ sqlalchemy.Column('uptime', sqlalchemy.Integer, server_default=None),
52
+ sqlalchemy.Column('policy', sqlalchemy.Text, server_default=None),
53
+ sqlalchemy.Column('auto_restart', sqlalchemy.Integer, server_default=None),
54
+ sqlalchemy.Column('requested_resources',
55
+ sqlalchemy.LargeBinary,
56
+ server_default=None),
57
+ sqlalchemy.Column('requested_resources_str', sqlalchemy.Text),
58
+ sqlalchemy.Column('current_version',
59
+ sqlalchemy.Integer,
60
+ server_default=str(constants.INITIAL_VERSION)),
61
+ sqlalchemy.Column('active_versions',
62
+ sqlalchemy.Text,
63
+ server_default=json.dumps([])),
64
+ sqlalchemy.Column('load_balancing_policy',
65
+ sqlalchemy.Text,
66
+ server_default=None),
67
+ sqlalchemy.Column('tls_encrypted', sqlalchemy.Integer, server_default='0'),
68
+ sqlalchemy.Column('pool', sqlalchemy.Integer, server_default='0'),
69
+ sqlalchemy.Column('controller_pid', sqlalchemy.Integer,
70
+ server_default=None),
71
+ sqlalchemy.Column('hash', sqlalchemy.Text, server_default=None),
72
+ sqlalchemy.Column('entrypoint', sqlalchemy.Text, server_default=None),
73
+ )
74
+
75
+ replicas_table = sqlalchemy.Table(
76
+ 'replicas',
77
+ Base.metadata,
78
+ sqlalchemy.Column('service_name', sqlalchemy.Text, primary_key=True),
79
+ sqlalchemy.Column('replica_id', sqlalchemy.Integer, primary_key=True),
80
+ sqlalchemy.Column('replica_info', sqlalchemy.LargeBinary),
81
+ )
82
+
83
+ version_specs_table = sqlalchemy.Table(
84
+ 'version_specs',
85
+ Base.metadata,
86
+ sqlalchemy.Column('service_name', sqlalchemy.Text, primary_key=True),
87
+ sqlalchemy.Column('version', sqlalchemy.Integer, primary_key=True),
88
+ sqlalchemy.Column('spec', sqlalchemy.LargeBinary),
89
+ )
90
+
91
+ serve_ha_recovery_script_table = sqlalchemy.Table(
92
+ 'serve_ha_recovery_script',
93
+ Base.metadata,
94
+ sqlalchemy.Column('service_name', sqlalchemy.Text, primary_key=True),
95
+ sqlalchemy.Column('script', sqlalchemy.Text),
96
+ )
97
+
98
+
99
+ def create_table(engine: sqlalchemy.engine.Engine):
24
100
  """Creates the service and replica tables if they do not exist."""
25
101
 
26
- # auto_restart and requested_resources column is deprecated.
27
- cursor.execute("""\
28
- CREATE TABLE IF NOT EXISTS services (
29
- name TEXT PRIMARY KEY,
30
- controller_job_id INTEGER DEFAULT NULL,
31
- controller_port INTEGER DEFAULT NULL,
32
- load_balancer_port INTEGER DEFAULT NULL,
33
- status TEXT,
34
- uptime INTEGER DEFAULT NULL,
35
- policy TEXT DEFAULT NULL,
36
- auto_restart INTEGER DEFAULT NULL,
37
- requested_resources BLOB DEFAULT NULL)""")
38
- cursor.execute("""\
39
- CREATE TABLE IF NOT EXISTS replicas (
40
- service_name TEXT,
41
- replica_id INTEGER,
42
- replica_info BLOB,
43
- PRIMARY KEY (service_name, replica_id))""")
44
- cursor.execute("""\
45
- CREATE TABLE IF NOT EXISTS version_specs (
46
- version INTEGER,
47
- service_name TEXT,
48
- spec BLOB,
49
- PRIMARY KEY (service_name, version))""")
50
- cursor.execute("""\
51
- CREATE TABLE IF NOT EXISTS ha_recovery_script (
52
- service_name TEXT PRIMARY KEY,
53
- script TEXT)""")
54
- conn.commit()
55
-
56
- # Backward compatibility.
57
- db_utils.add_column_to_table(cursor, conn, 'services',
58
- 'requested_resources_str', 'TEXT')
59
- # Deprecated: switched to `active_versions` below for the version
60
- # considered active by the load balancer. The
61
- # authscaler/replica_manager version can be found in the
62
- # version_specs table.
63
- db_utils.add_column_to_table(
64
- cursor, conn, 'services', 'current_version',
65
- f'INTEGER DEFAULT {constants.INITIAL_VERSION}')
66
- # The versions that is activated for the service. This is a list
67
- # of integers in json format.
68
- db_utils.add_column_to_table(cursor, conn, 'services', 'active_versions',
69
- f'TEXT DEFAULT {json.dumps([])!r}')
70
- db_utils.add_column_to_table(cursor, conn, 'services',
71
- 'load_balancing_policy', 'TEXT DEFAULT NULL')
72
- # Whether the service's load balancer is encrypted with TLS.
73
- db_utils.add_column_to_table(cursor, conn, 'services', 'tls_encrypted',
74
- 'INTEGER DEFAULT 0')
75
- # Whether the service is a cluster pool.
76
- db_utils.add_column_to_table(cursor, conn, 'services', 'pool',
77
- 'INTEGER DEFAULT 0')
78
- # Add controller_pid for status tracking.
79
- db_utils.add_column_to_table(cursor,
80
- conn,
81
- 'services',
82
- 'controller_pid',
83
- 'INTEGER DEFAULT NULL',
84
- value_to_replace_existing_entries=-1)
85
- conn.commit()
86
-
87
-
88
- def _get_db_path() -> str:
89
- """Workaround to collapse multi-step Path ops for type checker.
90
- Ensures _DB_PATH is str, avoiding Union[Path, str] inference.
91
- """
92
- path = pathlib.Path(constants.SKYSERVE_METADATA_DIR) / 'services.db'
93
- path = path.expanduser().absolute()
94
- path.parents[0].mkdir(parents=True, exist_ok=True)
95
- return str(path)
102
+ # Enable WAL mode to avoid locking issues.
103
+ # See: issue #3863, #1441 and PR #1509
104
+ # https://github.com/microsoft/WSL/issues/2395
105
+ # TODO(romilb): We do not enable WAL for WSL because of known issue in WSL.
106
+ # This may cause the database locked problem from WSL issue #1441.
107
+ if (engine.dialect.name == db_utils.SQLAlchemyDialect.SQLITE.value and
108
+ not common_utils.is_wsl()):
109
+ try:
110
+ with orm.Session(engine) as session:
111
+ session.execute(sqlalchemy.text('PRAGMA journal_mode=WAL'))
112
+ session.commit()
113
+ except sqlalchemy_exc.OperationalError as e:
114
+ if 'database is locked' not in str(e):
115
+ raise
116
+ # If the database is locked, it is OK to continue, as the WAL mode
117
+ # is not critical and is likely to be enabled by other processes.
118
+
119
+ migration_utils.safe_alembic_upgrade(engine, migration_utils.SERVE_DB_NAME,
120
+ migration_utils.SERVE_VERSION)
121
+
96
122
 
123
+ def initialize_and_get_db() -> sqlalchemy.engine.Engine:
124
+ global _SQLALCHEMY_ENGINE
97
125
 
98
- _DB_PATH = None
99
- _db_init_lock = threading.Lock()
126
+ if _SQLALCHEMY_ENGINE is not None:
127
+ return _SQLALCHEMY_ENGINE
128
+
129
+ with _SQLALCHEMY_ENGINE_LOCK:
130
+ if _SQLALCHEMY_ENGINE is not None:
131
+ return _SQLALCHEMY_ENGINE
132
+ # get an engine to the db
133
+ engine = migration_utils.get_engine('serve/services')
134
+
135
+ # run migrations if needed
136
+ create_table(engine)
137
+
138
+ # return engine
139
+ _SQLALCHEMY_ENGINE = engine
140
+ return _SQLALCHEMY_ENGINE
100
141
 
101
142
 
102
143
  def init_db(func):
@@ -104,19 +145,18 @@ def init_db(func):
104
145
 
105
146
  @functools.wraps(func)
106
147
  def wrapper(*args, **kwargs):
107
- global _DB_PATH
108
- if _DB_PATH is not None:
109
- return func(*args, **kwargs)
110
- with _db_init_lock:
111
- if _DB_PATH is None:
112
- _DB_PATH = _get_db_path()
113
- db_utils.SQLiteConn(_DB_PATH, create_table)
148
+ initialize_and_get_db()
114
149
  return func(*args, **kwargs)
115
150
 
116
151
  return wrapper
117
152
 
118
153
 
119
- _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG = 'UNIQUE constraint failed: services.name'
154
+ _UNIQUE_CONSTRAINT_FAILED_ERROR_MSGS = [
155
+ # sqlite
156
+ 'UNIQUE constraint failed: services.name',
157
+ # postgres
158
+ 'duplicate key value violates unique constraint "services_pkey"',
159
+ ]
120
160
 
121
161
 
122
162
  # === Statuses ===
@@ -284,31 +324,45 @@ _SERVICE_STATUS_TO_COLOR = {
284
324
  def add_service(name: str, controller_job_id: int, policy: str,
285
325
  requested_resources_str: str, load_balancing_policy: str,
286
326
  status: ServiceStatus, tls_encrypted: bool, pool: bool,
287
- controller_pid: int) -> bool:
327
+ controller_pid: int, entrypoint: str) -> bool:
288
328
  """Add a service in the database.
289
329
 
290
330
  Returns:
291
331
  True if the service is added successfully, False if the service already
292
332
  exists.
293
333
  """
294
- assert _DB_PATH is not None
334
+ assert _SQLALCHEMY_ENGINE is not None
295
335
  try:
296
- with db_utils.safe_cursor(_DB_PATH) as cursor:
297
- cursor.execute(
298
- """\
299
- INSERT INTO services
300
- (name, controller_job_id, status, policy,
301
- requested_resources_str, load_balancing_policy, tls_encrypted,
302
- pool, controller_pid)
303
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
304
- (name, controller_job_id, status.value, policy,
305
- requested_resources_str, load_balancing_policy,
306
- int(tls_encrypted), int(pool), controller_pid))
307
-
308
- except sqlite3.IntegrityError as e:
309
- if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG:
310
- raise RuntimeError('Unexpected database error') from e
311
- return False
336
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
337
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
338
+ db_utils.SQLAlchemyDialect.SQLITE.value):
339
+ insert_func = sqlite.insert
340
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
341
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
342
+ insert_func = postgresql.insert
343
+ else:
344
+ raise ValueError('Unsupported database dialect')
345
+
346
+ insert_stmt = insert_func(services_table).values(
347
+ name=name,
348
+ controller_job_id=controller_job_id,
349
+ status=status.value,
350
+ policy=policy,
351
+ requested_resources_str=requested_resources_str,
352
+ load_balancing_policy=load_balancing_policy,
353
+ tls_encrypted=int(tls_encrypted),
354
+ pool=int(pool),
355
+ controller_pid=controller_pid,
356
+ hash=str(uuid.uuid4()),
357
+ entrypoint=entrypoint)
358
+ session.execute(insert_stmt)
359
+ session.commit()
360
+
361
+ except sqlalchemy_exc.IntegrityError as e:
362
+ for msg in _UNIQUE_CONSTRAINT_FAILED_ERROR_MSGS:
363
+ if msg in str(e):
364
+ return False
365
+ raise RuntimeError('Unexpected database error') from e
312
366
  return True
313
367
 
314
368
 
@@ -319,33 +373,34 @@ def update_service_controller_pid(service_name: str,
319
373
 
320
374
  This is used to update the controller pid of a service on ha recovery.
321
375
  """
322
- assert _DB_PATH is not None
323
- with db_utils.safe_cursor(_DB_PATH) as cursor:
324
- cursor.execute(
325
- """\
326
- UPDATE services SET
327
- controller_pid=(?) WHERE name=(?)""",
328
- (controller_pid, service_name))
376
+ assert _SQLALCHEMY_ENGINE is not None
377
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
378
+ session.query(services_table).filter(
379
+ services_table.c.name == service_name).update(
380
+ {services_table.c.controller_pid: controller_pid})
381
+ session.commit()
329
382
 
330
383
 
331
384
  @init_db
332
385
  def remove_service(service_name: str) -> None:
333
386
  """Removes a service from the database."""
334
- assert _DB_PATH is not None
335
- with db_utils.safe_cursor(_DB_PATH) as cursor:
336
- cursor.execute("""\
337
- DELETE FROM services WHERE name=(?)""", (service_name,))
387
+ assert _SQLALCHEMY_ENGINE is not None
388
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
389
+ session.execute(
390
+ sqlalchemy.delete(services_table).where(
391
+ services_table.c.name == service_name))
392
+ session.commit()
338
393
 
339
394
 
340
395
  @init_db
341
396
  def set_service_uptime(service_name: str, uptime: int) -> None:
342
397
  """Sets the uptime of a service."""
343
- assert _DB_PATH is not None
344
- with db_utils.safe_cursor(_DB_PATH) as cursor:
345
- cursor.execute(
346
- """\
347
- UPDATE services SET
348
- uptime=(?) WHERE name=(?)""", (uptime, service_name))
398
+ assert _SQLALCHEMY_ENGINE is not None
399
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
400
+ session.query(services_table).filter(
401
+ services_table.c.name == service_name).update(
402
+ {services_table.c.uptime: uptime})
403
+ session.commit()
349
404
 
350
405
 
351
406
  @init_db
@@ -354,72 +409,71 @@ def set_service_status_and_active_versions(
354
409
  status: ServiceStatus,
355
410
  active_versions: Optional[List[int]] = None) -> None:
356
411
  """Sets the service status."""
357
- assert _DB_PATH is not None
358
- vars_to_set = 'status=(?)'
359
- values: Tuple[str, ...] = (status.value, service_name)
412
+ assert _SQLALCHEMY_ENGINE is not None
413
+ update_dict = {services_table.c.status: status.value}
360
414
  if active_versions is not None:
361
- vars_to_set = 'status=(?), active_versions=(?)'
362
- values = (status.value, json.dumps(active_versions), service_name)
363
- with db_utils.safe_cursor(_DB_PATH) as cursor:
364
- cursor.execute(
365
- f"""\
366
- UPDATE services SET
367
- {vars_to_set} WHERE name=(?)""", values)
415
+ update_dict[services_table.c.active_versions] = json.dumps(
416
+ active_versions)
417
+
418
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
419
+ session.query(services_table).filter(
420
+ services_table.c.name == service_name).update(update_dict)
421
+ session.commit()
368
422
 
369
423
 
370
424
  @init_db
371
425
  def set_service_controller_port(service_name: str,
372
426
  controller_port: int) -> None:
373
427
  """Sets the controller port of a service."""
374
- assert _DB_PATH is not None
375
- with db_utils.safe_cursor(_DB_PATH) as cursor:
376
- cursor.execute(
377
- """\
378
- UPDATE services SET
379
- controller_port=(?) WHERE name=(?)""",
380
- (controller_port, service_name))
428
+ assert _SQLALCHEMY_ENGINE is not None
429
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
430
+ session.query(services_table).filter(
431
+ services_table.c.name == service_name).update(
432
+ {services_table.c.controller_port: controller_port})
433
+ session.commit()
381
434
 
382
435
 
383
436
  @init_db
384
437
  def set_service_load_balancer_port(service_name: str,
385
438
  load_balancer_port: int) -> None:
386
439
  """Sets the load balancer port of a service."""
387
- assert _DB_PATH is not None
388
- with db_utils.safe_cursor(_DB_PATH) as cursor:
389
- cursor.execute(
390
- """\
391
- UPDATE services SET
392
- load_balancer_port=(?) WHERE name=(?)""",
393
- (load_balancer_port, service_name))
394
-
395
-
396
- def _get_service_from_row(row) -> Dict[str, Any]:
397
- (current_version, name, controller_job_id, controller_port,
398
- load_balancer_port, status, uptime, policy, _, _, requested_resources_str,
399
- _, active_versions, load_balancing_policy, tls_encrypted, pool,
400
- controller_pid) = row[:17]
440
+ assert _SQLALCHEMY_ENGINE is not None
441
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
442
+ session.query(services_table).filter(
443
+ services_table.c.name == service_name).update(
444
+ {services_table.c.load_balancer_port: load_balancer_port})
445
+ session.commit()
446
+
447
+
448
+ def _get_service_from_row(r: 'row.RowMapping') -> Dict[str, Any]:
449
+ # Get the max_version from the first column (from the subquery)
450
+ current_version = r['max_version']
451
+
401
452
  record = {
402
- 'name': name,
403
- 'controller_job_id': controller_job_id,
404
- 'controller_port': controller_port,
405
- 'load_balancer_port': load_balancer_port,
406
- 'status': ServiceStatus[status],
407
- 'uptime': uptime,
408
- 'policy': policy,
453
+ 'name': r['name'],
454
+ 'controller_job_id': r['controller_job_id'],
455
+ 'controller_port': r['controller_port'],
456
+ 'load_balancer_port': r['load_balancer_port'],
457
+ 'status': ServiceStatus[r['status']],
458
+ 'uptime': r['uptime'],
459
+ 'policy': r['policy'],
409
460
  # The version of the autoscaler/replica manager are on. It can be larger
410
461
  # than the active versions as the load balancer may not consider the
411
462
  # latest version to be active for serving traffic.
412
463
  'version': current_version,
413
464
  # The versions that is active for the load balancer. This is a list of
414
465
  # integers in json format. This is mainly for display purpose.
415
- 'active_versions': json.loads(active_versions),
416
- 'requested_resources_str': requested_resources_str,
417
- 'load_balancing_policy': load_balancing_policy,
418
- 'tls_encrypted': bool(tls_encrypted),
419
- 'pool': bool(pool),
420
- 'controller_pid': controller_pid,
466
+ 'active_versions': json.loads(r['active_versions'])
467
+ if r['active_versions'] else [],
468
+ 'requested_resources_str': r['requested_resources_str'],
469
+ 'load_balancing_policy': r['load_balancing_policy'],
470
+ 'tls_encrypted': bool(r['tls_encrypted']),
471
+ 'pool': bool(r['pool']),
472
+ 'controller_pid': r['controller_pid'],
473
+ 'hash': r['hash'],
474
+ 'entrypoint': r['entrypoint'],
421
475
  }
422
- latest_spec = get_spec(name, current_version)
476
+ latest_spec = get_spec(r['name'], current_version)
423
477
  if latest_spec is not None:
424
478
  record['policy'] = latest_spec.autoscaling_policy_str()
425
479
  record['load_balancing_policy'] = latest_spec.load_balancing_policy
@@ -429,45 +483,69 @@ def _get_service_from_row(row) -> Dict[str, Any]:
429
483
  @init_db
430
484
  def get_services() -> List[Dict[str, Any]]:
431
485
  """Get all existing service records."""
432
- assert _DB_PATH is not None
433
- with db_utils.safe_cursor(_DB_PATH) as cursor:
434
- rows = cursor.execute('SELECT v.max_version, s.* FROM services s '
435
- 'JOIN ('
436
- 'SELECT service_name, MAX(version) as max_version'
437
- ' FROM version_specs GROUP BY service_name) v '
438
- 'ON s.name=v.service_name').fetchall()
486
+ assert _SQLALCHEMY_ENGINE is not None
487
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
488
+ subquery = sqlalchemy.select(
489
+ version_specs_table.c.service_name,
490
+ sqlalchemy.func.max(
491
+ version_specs_table.c.version).label('max_version')).group_by(
492
+ version_specs_table.c.service_name).alias('v')
493
+
494
+ query = sqlalchemy.select(
495
+ subquery.c.max_version, services_table).select_from(
496
+ services_table.join(
497
+ subquery, services_table.c.name == subquery.c.service_name))
498
+ rows = session.execute(query).fetchall()
439
499
  records = []
440
500
  for row in rows:
441
- records.append(_get_service_from_row(row))
501
+ records.append(_get_service_from_row(row._mapping)) # pylint: disable=protected-access
442
502
  return records
443
503
 
444
504
 
445
505
  @init_db
446
506
  def get_service_from_name(service_name: str) -> Optional[Dict[str, Any]]:
447
507
  """Get all existing service records."""
448
- assert _DB_PATH is not None
449
- with db_utils.safe_cursor(_DB_PATH) as cursor:
450
- rows = cursor.execute(
451
- 'SELECT v.max_version, s.* FROM services s '
452
- 'JOIN ('
453
- 'SELECT service_name, MAX(version) as max_version '
454
- 'FROM version_specs WHERE service_name=(?)) v '
455
- 'ON s.name=v.service_name WHERE name=(?)',
456
- (service_name, service_name)).fetchall()
508
+ assert _SQLALCHEMY_ENGINE is not None
509
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
510
+ subquery = sqlalchemy.select(
511
+ version_specs_table.c.service_name,
512
+ sqlalchemy.func.max(
513
+ version_specs_table.c.version).label('max_version')
514
+ ).where(version_specs_table.c.service_name == service_name).group_by(
515
+ version_specs_table.c.service_name).alias('v')
516
+
517
+ query = sqlalchemy.select(
518
+ subquery.c.max_version, services_table).select_from(
519
+ services_table.join(
520
+ subquery,
521
+ services_table.c.name == subquery.c.service_name)).where(
522
+ services_table.c.name == service_name)
523
+
524
+ rows = session.execute(query).fetchall()
457
525
  for row in rows:
458
- return _get_service_from_row(row)
526
+ return _get_service_from_row(row._mapping) # pylint: disable=protected-access
459
527
  return None
460
528
 
461
529
 
530
+ @init_db
531
+ def get_service_hash(service_name: str) -> Optional[str]:
532
+ """Get the hash of a service."""
533
+ assert _SQLALCHEMY_ENGINE is not None
534
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
535
+ result = session.execute(
536
+ sqlalchemy.select(services_table.c.hash).where(
537
+ services_table.c.name == service_name)).fetchone()
538
+ return result[0] if result else None
539
+
540
+
462
541
  @init_db
463
542
  def get_service_versions(service_name: str) -> List[int]:
464
543
  """Gets all versions of a service."""
465
- assert _DB_PATH is not None
466
- with db_utils.safe_cursor(_DB_PATH) as cursor:
467
- rows = cursor.execute(
468
- """\
469
- SELECT DISTINCT version FROM version_specs
470
- WHERE service_name=(?)""", (service_name,)).fetchall()
544
+ assert _SQLALCHEMY_ENGINE is not None
545
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
546
+ rows = session.execute(
547
+ sqlalchemy.select(version_specs_table.c.version.distinct()).where(
548
+ version_specs_table.c.service_name == service_name)).fetchall()
471
549
  return [row[0] for row in rows]
472
550
 
473
551
 
@@ -483,17 +561,19 @@ def get_glob_service_names(
483
561
  Returns:
484
562
  A list of non-duplicated service names.
485
563
  """
486
- assert _DB_PATH is not None
487
- with db_utils.safe_cursor(_DB_PATH) as cursor:
564
+ assert _SQLALCHEMY_ENGINE is not None
565
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
488
566
  if service_names is None:
489
- rows = cursor.execute('SELECT name FROM services').fetchall()
567
+ rows = session.execute(sqlalchemy.select(
568
+ services_table.c.name)).fetchall()
490
569
  else:
491
570
  rows = []
492
571
  for service_name in service_names:
493
- rows.extend(
494
- cursor.execute(
495
- 'SELECT name FROM services WHERE name GLOB (?)',
496
- (service_name,)).fetchall())
572
+ pattern_rows = session.execute(
573
+ sqlalchemy.select(services_table.c.name).where(
574
+ services_table.c.name.like(
575
+ service_name.replace('*', '%')))).fetchall()
576
+ rows.extend(pattern_rows)
497
577
  return list({row[0] for row in rows})
498
578
 
499
579
 
@@ -502,26 +582,40 @@ def get_glob_service_names(
502
582
def add_or_update_replica(service_name: str, replica_id: int,
                          replica_info: 'replica_managers.ReplicaInfo') -> None:
    """Insert the pickled info of a replica, or overwrite it if present.

    Rows are keyed by ``(service_name, replica_id)``; on conflict only the
    ``replica_info`` column is replaced (dialect-specific
    ``ON CONFLICT ... DO UPDATE``).

    Raises:
        ValueError: if the engine dialect is neither SQLite nor PostgreSQL.
    """
    assert _SQLALCHEMY_ENGINE is not None
    # Only these two dialects provide an insert() with on_conflict_do_update.
    upsert_builders = {
        db_utils.SQLAlchemyDialect.SQLITE.value: sqlite.insert,
        db_utils.SQLAlchemyDialect.POSTGRESQL.value: postgresql.insert,
    }
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        build_insert = upsert_builders.get(_SQLALCHEMY_ENGINE.dialect.name)
        if build_insert is None:
            raise ValueError('Unsupported database dialect')

        upsert = build_insert(replicas_table).values(
            service_name=service_name,
            replica_id=replica_id,
            replica_info=pickle.dumps(replica_info))
        upsert = upsert.on_conflict_do_update(
            index_elements=['service_name', 'replica_id'],
            set_={'replica_info': upsert.excluded.replica_info})

        session.execute(upsert)
        session.commit()
513
607
 
514
608
 
515
609
@init_db
def remove_replica(service_name: str, replica_id: int) -> None:
    """Delete the ``replicas`` row for one replica of a service, if any."""
    assert _SQLALCHEMY_ENGINE is not None
    # Multiple criteria passed to where() are ANDed together.
    delete_stmt = sqlalchemy.delete(replicas_table).where(
        replicas_table.c.service_name == service_name,
        replicas_table.c.replica_id == replica_id)
    # session.begin() commits on successful exit and rolls back on error.
    with orm.Session(_SQLALCHEMY_ENGINE) as session, session.begin():
        session.execute(delete_stmt)
525
619
 
526
620
 
527
621
@init_db
def get_replica_info_from_id(
        service_name: str,
        replica_id: int) -> Optional['replica_managers.ReplicaInfo']:
    """Load and unpickle the info of one replica of a service.

    Returns:
        The unpickled ``replica_info`` for ``(service_name, replica_id)``,
        or None if no such row exists.
    """
    assert _SQLALCHEMY_ENGINE is not None
    lookup = sqlalchemy.select(replicas_table.c.replica_info).where(
        replicas_table.c.service_name == service_name,
        replicas_table.c.replica_id == replica_id)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        row = session.execute(lookup).first()
    if row is None:
        return None
    return pickle.loads(row[0])
542
634
 
543
635
 
544
636
@init_db
def get_replica_infos(
        service_name: str) -> List['replica_managers.ReplicaInfo']:
    """Return the unpickled info of every replica belonging to a service."""
    assert _SQLALCHEMY_ENGINE is not None
    infos_query = sqlalchemy.select(replicas_table.c.replica_info).where(
        replicas_table.c.service_name == service_name)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        pickled_infos = session.execute(infos_query).scalars().all()
    return [pickle.loads(blob) for blob in pickled_infos]
555
646
 
556
647
 
557
648
  @init_db
558
649
  def total_number_provisioning_replicas() -> int:
559
650
  """Returns the total number of provisioning replicas."""
560
- assert _DB_PATH is not None
561
- with db_utils.safe_cursor(_DB_PATH) as cursor:
562
- rows = cursor.execute('SELECT replica_info FROM replicas').fetchall()
651
+ assert _SQLALCHEMY_ENGINE is not None
652
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
653
+ rows = session.execute(sqlalchemy.select(
654
+ replicas_table.c.replica_info)).fetchall()
563
655
  provisioning_count = 0
564
656
  for row in rows:
565
657
  replica_info: 'replica_managers.ReplicaInfo' = pickle.loads(row[0])
@@ -580,154 +672,182 @@ def get_replicas_at_status(
580
672
@init_db
def add_version(service_name: str) -> int:
    """Append a new placeholder version for a service and return its number.

    The new number is ``COALESCE(MAX(version), 0) + 1`` over the service's
    existing rows, computed by a scalar subquery inside the INSERT itself and
    read back via RETURNING. The stored ``spec`` is a pickled ``None``.

    NOTE(review): SQLite serializes writers, so this is race-free there; on
    PostgreSQL two concurrent callers may compute the same MAX — confirm the
    table's unique key turns that into an error rather than a duplicate.
    """
    assert _SQLALCHEMY_ENGINE is not None
    highest_so_far = sqlalchemy.func.coalesce(
        sqlalchemy.func.max(version_specs_table.c.version), 0)
    next_version = sqlalchemy.select(highest_so_far + 1).where(
        version_specs_table.c.service_name == service_name).scalar_subquery()

    insert_stmt = sqlalchemy.insert(version_specs_table).values(
        service_name=service_name,
        version=next_version,
        spec=pickle.dumps(None)).returning(version_specs_table.c.version)

    with orm.Session(_SQLALCHEMY_ENGINE) as session, session.begin():
        # Read RETURNING before the transaction commits on block exit.
        created_version = session.execute(insert_stmt).scalar()
    return created_version
598
694
 
599
695
 
600
696
@init_db
def add_or_update_version(service_name: str, version: int,
                          spec: 'service_spec.SkyServiceSpec') -> None:
    """Store ``spec`` (pickled) as a given version of a service.

    Upserts on ``(service_name, version)``: an existing row only has its
    ``spec`` column overwritten.

    Raises:
        ValueError: if the engine dialect is neither SQLite nor PostgreSQL.
    """
    assert _SQLALCHEMY_ENGINE is not None
    dialect_name = _SQLALCHEMY_ENGINE.dialect.name
    if dialect_name == db_utils.SQLAlchemyDialect.SQLITE.value:
        make_insert = sqlite.insert
    elif dialect_name == db_utils.SQLAlchemyDialect.POSTGRESQL.value:
        make_insert = postgresql.insert
    else:
        raise ValueError('Unsupported database dialect')

    upsert = make_insert(version_specs_table).values(service_name=service_name,
                                                     version=version,
                                                     spec=pickle.dumps(spec))
    upsert = upsert.on_conflict_do_update(
        index_elements=['service_name', 'version'],
        set_={'spec': upsert.excluded.spec})

    with orm.Session(_SQLALCHEMY_ENGINE) as session, session.begin():
        session.execute(upsert)
610
719
 
611
720
 
612
721
@init_db
def remove_service_versions(service_name: str) -> None:
    """Removes all version specs of a service from the database.

    Deletes every ``version_specs`` row whose ``service_name`` matches. A
    service with no stored versions is a no-op.

    NOTE: this issues the same DELETE as ``delete_all_versions``.
    """
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        session.execute(
            sqlalchemy.delete(version_specs_table).where(
                version_specs_table.c.service_name == service_name))
        session.commit()
621
730
 
622
731
 
623
732
@init_db
def get_spec(service_name: str,
             version: int) -> Optional['service_spec.SkyServiceSpec']:
    """Load and unpickle the spec stored for one version of a service.

    Returns:
        The unpickled ``spec`` for ``(service_name, version)``, or None when
        that version is not recorded.
    """
    assert _SQLALCHEMY_ENGINE is not None
    spec_query = sqlalchemy.select(version_specs_table.c.spec).where(
        version_specs_table.c.service_name == service_name,
        version_specs_table.c.version == version)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        row = session.execute(spec_query).first()
    return None if row is None else pickle.loads(row[0])
637
744
 
638
745
 
639
746
@init_db
def delete_version(service_name: str, version: int) -> None:
    """Delete the ``version_specs`` row for one version of a service."""
    assert _SQLALCHEMY_ENGINE is not None
    delete_stmt = sqlalchemy.delete(version_specs_table).where(
        version_specs_table.c.service_name == service_name,
        version_specs_table.c.version == version)
    # Commits when the block exits normally; rolls back if execute raises.
    with orm.Session(_SQLALCHEMY_ENGINE) as session, session.begin():
        session.execute(delete_stmt)
649
757
 
650
758
 
651
759
@init_db
def delete_all_versions(service_name: str) -> None:
    """Delete every ``version_specs`` row belonging to a service."""
    assert _SQLALCHEMY_ENGINE is not None
    purge = version_specs_table.delete().where(
        version_specs_table.c.service_name == service_name)
    with orm.Session(_SQLALCHEMY_ENGINE) as session, session.begin():
        session.execute(purge)
660
768
 
661
769
 
662
770
@init_db
def get_latest_version(service_name: str) -> Optional[int]:
    """Return a service's highest version number, or None if it has none."""
    assert _SQLALCHEMY_ENGINE is not None
    highest = sqlalchemy.func.max(version_specs_table.c.version)
    max_query = sqlalchemy.select(highest).where(
        version_specs_table.c.service_name == service_name)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        # MAX() over zero rows is NULL, which scalar() surfaces as None.
        return session.execute(max_query).scalar()
673
780
 
674
781
 
675
782
@init_db
def get_service_controller_port(service_name: str) -> int:
    """Return the controller port recorded for a service.

    Raises:
        ValueError: if no service named ``service_name`` exists.
    """
    assert _SQLALCHEMY_ENGINE is not None
    port_query = sqlalchemy.select(services_table.c.controller_port).where(
        services_table.c.name == service_name)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        row = session.execute(port_query).first()
    # Distinguish "no such service" from a row whose column is NULL.
    if row is not None:
        return row[0]
    raise ValueError(f'Service {service_name} does not exist.')
686
793
 
687
794
 
688
795
@init_db
def get_service_load_balancer_port(service_name: str) -> int:
    """Return the load balancer port recorded for a service.

    Raises:
        ValueError: if no service named ``service_name`` exists.
    """
    assert _SQLALCHEMY_ENGINE is not None
    port_query = sqlalchemy.select(services_table.c.load_balancer_port).where(
        services_table.c.name == service_name)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        row = session.execute(port_query).first()
    # Distinguish "no such service" from a row whose column is NULL.
    if row is not None:
        return row[0]
    raise ValueError(f'Service {service_name} does not exist.')
699
806
 
700
807
 
701
808
@init_db
def get_ha_recovery_script(service_name: str) -> Optional[str]:
    """Return the HA recovery script stored for a service, or None."""
    assert _SQLALCHEMY_ENGINE is not None
    script_col = serve_ha_recovery_script_table.c.script
    owner_col = serve_ha_recovery_script_table.c.service_name
    script_query = sqlalchemy.select(script_col).where(
        owner_col == service_name)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        return session.execute(script_query).scalar()
713
818
 
714
819
 
715
820
@init_db
def set_ha_recovery_script(service_name: str, script: str) -> None:
    """Insert or overwrite the HA recovery script of a service.

    Upserts on ``service_name``; an existing row only has ``script`` replaced.

    Raises:
        ValueError: if the engine dialect is neither SQLite nor PostgreSQL.
    """
    assert _SQLALCHEMY_ENGINE is not None
    upsert_builders = {
        db_utils.SQLAlchemyDialect.SQLITE.value: sqlite.insert,
        db_utils.SQLAlchemyDialect.POSTGRESQL.value: postgresql.insert,
    }
    build_insert = upsert_builders.get(_SQLALCHEMY_ENGINE.dialect.name)
    if build_insert is None:
        raise ValueError('Unsupported database dialect')

    upsert = build_insert(serve_ha_recovery_script_table).values(
        service_name=service_name, script=script)
    upsert = upsert.on_conflict_do_update(
        index_elements=['service_name'],
        set_={'script': upsert.excluded.script})

    with orm.Session(_SQLALCHEMY_ENGINE) as session, session.begin():
        session.execute(upsert)
725
843
 
726
844
 
727
845
@init_db
def remove_ha_recovery_script(service_name: str) -> None:
    """Delete the stored HA recovery script of a service, if any."""
    assert _SQLALCHEMY_ENGINE is not None
    drop_script = serve_ha_recovery_script_table.delete().where(
        serve_ha_recovery_script_table.c.service_name == service_name)
    with orm.Session(_SQLALCHEMY_ENGINE) as session, session.begin():
        session.execute(drop_script)