skypilot-nightly 1.0.0.dev20250607__py3-none-any.whl → 1.0.0.dev20250610__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. sky/__init__.py +2 -2
  2. sky/admin_policy.py +3 -0
  3. sky/authentication.py +1 -7
  4. sky/backends/backend_utils.py +18 -2
  5. sky/backends/cloud_vm_ray_backend.py +9 -20
  6. sky/check.py +4 -3
  7. sky/cli.py +6 -9
  8. sky/client/cli.py +6 -9
  9. sky/client/sdk.py +49 -4
  10. sky/clouds/kubernetes.py +15 -24
  11. sky/core.py +3 -2
  12. sky/dashboard/out/404.html +1 -1
  13. sky/dashboard/out/_next/static/4lwUJxN6KwBqUxqO1VccB/_buildManifest.js +1 -0
  14. sky/dashboard/out/_next/static/chunks/211.692afc57e812ae1a.js +1 -0
  15. sky/dashboard/out/_next/static/chunks/350.9e123a4551f68b0d.js +1 -0
  16. sky/dashboard/out/_next/static/chunks/37-d8aebf1683522a0b.js +6 -0
  17. sky/dashboard/out/_next/static/chunks/42.d39e24467181b06b.js +6 -0
  18. sky/dashboard/out/_next/static/chunks/443.b2242d0efcdf5f47.js +1 -0
  19. sky/dashboard/out/_next/static/chunks/470-4d1a5dbe58a8a2b9.js +1 -0
  20. sky/dashboard/out/_next/static/chunks/{121-865d2bf8a3b84c6a.js → 491.b3d264269613fe09.js} +3 -3
  21. sky/dashboard/out/_next/static/chunks/513.211357a2914a34b2.js +1 -0
  22. sky/dashboard/out/_next/static/chunks/600.9cc76ec442b22e10.js +16 -0
  23. sky/dashboard/out/_next/static/chunks/616-d6128fa9e7cae6e6.js +39 -0
  24. sky/dashboard/out/_next/static/chunks/664-047bc03493fda379.js +1 -0
  25. sky/dashboard/out/_next/static/chunks/682.4dd5dc116f740b5f.js +6 -0
  26. sky/dashboard/out/_next/static/chunks/760-a89d354797ce7af5.js +1 -0
  27. sky/dashboard/out/_next/static/chunks/799-3625946b2ec2eb30.js +8 -0
  28. sky/dashboard/out/_next/static/chunks/804-4c9fc53aa74bc191.js +21 -0
  29. sky/dashboard/out/_next/static/chunks/843-6fcc4bf91ac45b39.js +11 -0
  30. sky/dashboard/out/_next/static/chunks/856-0776dc6ed6000c39.js +1 -0
  31. sky/dashboard/out/_next/static/chunks/901-b424d293275e1fd7.js +1 -0
  32. sky/dashboard/out/_next/static/chunks/938-a75b7712639298b7.js +1 -0
  33. sky/dashboard/out/_next/static/chunks/947-6620842ef80ae879.js +35 -0
  34. sky/dashboard/out/_next/static/chunks/969-20d54a9d998dc102.js +1 -0
  35. sky/dashboard/out/_next/static/chunks/973-c807fc34f09c7df3.js +1 -0
  36. sky/dashboard/out/_next/static/chunks/pages/_app-4768de0aede04dc9.js +20 -0
  37. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-89216c616dbaa9c5.js +6 -0
  38. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-451a14e7e755ebbc.js +6 -0
  39. sky/dashboard/out/_next/static/chunks/pages/clusters-e56b17fd85d0ba58.js +1 -0
  40. sky/dashboard/out/_next/static/chunks/pages/config-497a35a7ed49734a.js +1 -0
  41. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-d2910be98e9227cb.js +1 -0
  42. sky/dashboard/out/_next/static/chunks/pages/infra-780860bcc1103945.js +1 -0
  43. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-b3dbf38b51cb29be.js +16 -0
  44. sky/dashboard/out/_next/static/chunks/pages/jobs-fe233baf3d073491.js +1 -0
  45. sky/dashboard/out/_next/static/chunks/pages/users-c69ffcab9d6e5269.js +1 -0
  46. sky/dashboard/out/_next/static/chunks/pages/workspace/new-31aa8bdcb7592635.js +1 -0
  47. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-c8c2191328532b7d.js +1 -0
  48. sky/dashboard/out/_next/static/chunks/pages/workspaces-82e6601baa5dd280.js +1 -0
  49. sky/dashboard/out/_next/static/chunks/webpack-0574a5a4ba3cf0ac.js +1 -0
  50. sky/dashboard/out/_next/static/css/8b1c8321d4c02372.css +3 -0
  51. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  52. sky/dashboard/out/clusters/[cluster].html +1 -1
  53. sky/dashboard/out/clusters.html +1 -1
  54. sky/dashboard/out/config.html +1 -1
  55. sky/dashboard/out/index.html +1 -1
  56. sky/dashboard/out/infra/[context].html +1 -1
  57. sky/dashboard/out/infra.html +1 -1
  58. sky/dashboard/out/jobs/[job].html +1 -1
  59. sky/dashboard/out/jobs.html +1 -1
  60. sky/dashboard/out/users.html +1 -1
  61. sky/dashboard/out/workspace/new.html +1 -1
  62. sky/dashboard/out/workspaces/[name].html +1 -1
  63. sky/dashboard/out/workspaces.html +1 -1
  64. sky/exceptions.py +23 -0
  65. sky/global_user_state.py +192 -80
  66. sky/jobs/client/sdk.py +29 -21
  67. sky/jobs/server/core.py +9 -1
  68. sky/jobs/server/server.py +0 -95
  69. sky/jobs/utils.py +2 -1
  70. sky/models.py +18 -0
  71. sky/provision/kubernetes/constants.py +9 -0
  72. sky/provision/kubernetes/utils.py +106 -7
  73. sky/serve/client/sdk.py +56 -45
  74. sky/serve/server/core.py +1 -1
  75. sky/server/common.py +5 -7
  76. sky/server/constants.py +0 -2
  77. sky/server/requests/executor.py +60 -22
  78. sky/server/requests/payloads.py +3 -0
  79. sky/server/requests/process.py +69 -29
  80. sky/server/requests/requests.py +4 -3
  81. sky/server/server.py +23 -5
  82. sky/server/stream_utils.py +111 -55
  83. sky/skylet/constants.py +4 -2
  84. sky/skylet/job_lib.py +2 -1
  85. sky/skypilot_config.py +108 -25
  86. sky/users/model.conf +1 -1
  87. sky/users/permission.py +149 -32
  88. sky/users/rbac.py +26 -0
  89. sky/users/server.py +14 -13
  90. sky/utils/admin_policy_utils.py +9 -3
  91. sky/utils/common.py +6 -1
  92. sky/utils/common_utils.py +21 -3
  93. sky/utils/context.py +21 -1
  94. sky/utils/controller_utils.py +16 -1
  95. sky/utils/kubernetes/exec_kubeconfig_converter.py +19 -47
  96. sky/utils/schemas.py +9 -0
  97. sky/workspaces/core.py +100 -8
  98. sky/workspaces/server.py +15 -2
  99. sky/workspaces/utils.py +56 -0
  100. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250610.dist-info}/METADATA +1 -1
  101. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250610.dist-info}/RECORD +106 -94
  102. sky/dashboard/out/_next/static/1qG0HTmVilJPxQdBk0fX5/_buildManifest.js +0 -1
  103. sky/dashboard/out/_next/static/chunks/236-619ed0248fb6fdd9.js +0 -6
  104. sky/dashboard/out/_next/static/chunks/293-351268365226d251.js +0 -1
  105. sky/dashboard/out/_next/static/chunks/37-600191c5804dcae2.js +0 -6
  106. sky/dashboard/out/_next/static/chunks/470-ad1e0db3afcbd9c9.js +0 -1
  107. sky/dashboard/out/_next/static/chunks/614-635a84e87800f99e.js +0 -66
  108. sky/dashboard/out/_next/static/chunks/682-b60cfdacc15202e8.js +0 -6
  109. sky/dashboard/out/_next/static/chunks/843-c296541442d4af88.js +0 -11
  110. sky/dashboard/out/_next/static/chunks/856-3a32da4b84176f6d.js +0 -1
  111. sky/dashboard/out/_next/static/chunks/969-2c584e28e6b4b106.js +0 -1
  112. sky/dashboard/out/_next/static/chunks/973-6d78a0814682d771.js +0 -1
  113. sky/dashboard/out/_next/static/chunks/pages/_app-cb81dc4d27f4d009.js +0 -1
  114. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-18aed9b56247d074.js +0 -6
  115. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-b919a73aecdfa78f.js +0 -6
  116. sky/dashboard/out/_next/static/chunks/pages/clusters-4f6b9dd9abcb33ad.js +0 -1
  117. sky/dashboard/out/_next/static/chunks/pages/config-fe375a56342cf609.js +0 -6
  118. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-3a18d0eeb5119fe4.js +0 -1
  119. sky/dashboard/out/_next/static/chunks/pages/infra-a1a6abeeb58c1051.js +0 -1
  120. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-1354e28c81eeb686.js +0 -16
  121. sky/dashboard/out/_next/static/chunks/pages/jobs-23bfc8bf373423db.js +0 -1
  122. sky/dashboard/out/_next/static/chunks/pages/users-5800045bd04e69c2.js +0 -16
  123. sky/dashboard/out/_next/static/chunks/pages/workspace/new-e1f9c0c3ff7ac4bd.js +0 -1
  124. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-686590e0ee4b2412.js +0 -1
  125. sky/dashboard/out/_next/static/chunks/pages/workspaces-76b07aa5da91b0df.js +0 -1
  126. sky/dashboard/out/_next/static/chunks/webpack-65d465f948974c0d.js +0 -1
  127. sky/dashboard/out/_next/static/css/667d941a2888ce6e.css +0 -3
  128. /sky/dashboard/out/_next/static/{1qG0HTmVilJPxQdBk0fX5 → 4lwUJxN6KwBqUxqO1VccB}/_ssgManifest.js +0 -0
  129. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250610.dist-info}/WHEEL +0 -0
  130. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250610.dist-info}/entry_points.txt +0 -0
  131. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250610.dist-info}/licenses/LICENSE +0 -0
  132. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250610.dist-info}/top_level.txt +0 -0
sky/global_user_state.py CHANGED
@@ -6,11 +6,13 @@ Concepts:
6
6
  - Cluster handle: (non-user facing) an opaque backend handle for us to
7
7
  interact with a cluster.
8
8
  """
9
+ import functools
9
10
  import json
10
11
  import os
11
12
  import pathlib
12
13
  import pickle
13
14
  import re
15
+ import threading
14
16
  import time
15
17
  import typing
16
18
  from typing import Any, Dict, List, Optional, Set, Tuple
@@ -44,6 +46,9 @@ logger = sky_logging.init_logger(__name__)
44
46
 
45
47
  _ENABLED_CLOUDS_KEY_PREFIX = 'enabled_clouds_'
46
48
 
49
+ _SQLALCHEMY_ENGINE: Optional[sqlalchemy.engine.Engine] = None
50
+ _DB_INIT_LOCK = threading.Lock()
51
+
47
52
  Base = declarative.declarative_base()
48
53
 
49
54
  config_table = sqlalchemy.Table(
@@ -171,11 +176,11 @@ def create_table():
171
176
  # https://github.com/microsoft/WSL/issues/2395
172
177
  # TODO(romilb): We do not enable WAL for WSL because of known issue in WSL.
173
178
  # This may cause the database locked problem from WSL issue #1441.
174
- if (SQLALCHEMY_ENGINE.dialect.name
179
+ if (_SQLALCHEMY_ENGINE.dialect.name
175
180
  == db_utils.SQLAlchemyDialect.SQLITE.value and
176
181
  not common_utils.is_wsl()):
177
182
  try:
178
- with orm.Session(SQLALCHEMY_ENGINE) as session:
183
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
179
184
  session.execute(sqlalchemy.text('PRAGMA journal_mode=WAL'))
180
185
  session.commit()
181
186
  except sqlalchemy_exc.OperationalError as e:
@@ -185,12 +190,12 @@ def create_table():
185
190
  # is not critical and is likely to be enabled by other processes.
186
191
 
187
192
  # Create tables if they don't exist
188
- Base.metadata.create_all(bind=SQLALCHEMY_ENGINE)
193
+ Base.metadata.create_all(bind=_SQLALCHEMY_ENGINE)
189
194
 
190
195
  # For backward compatibility.
191
196
  # TODO(zhwu): Remove this function after all users have migrated to
192
197
  # the latest version of SkyPilot.
193
- with orm.Session(SQLALCHEMY_ENGINE) as session:
198
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
194
199
  # Add autostop column to clusters table
195
200
  db_utils.add_column_to_table_sqlalchemy(session,
196
201
  'clusters',
@@ -258,7 +263,8 @@ def create_table():
258
263
  'user_hash',
259
264
  sqlalchemy.Text(),
260
265
  default_statement='DEFAULT NULL',
261
- value_to_replace_existing_entries=common_utils.get_user_hash())
266
+ value_to_replace_existing_entries=common_utils.get_current_user(
267
+ ).id)
262
268
  db_utils.add_column_to_table_sqlalchemy(
263
269
  session,
264
270
  'clusters',
@@ -297,30 +303,53 @@ def create_table():
297
303
  session.commit()
298
304
 
299
305
 
300
- conn_string = None
301
- if os.environ.get(constants.ENV_VAR_IS_SKYPILOT_SERVER) is not None:
302
- conn_string = skypilot_config.get_nested(('db',), None)
303
- if conn_string:
304
- logger.debug(f'using db URI from {conn_string}')
305
- SQLALCHEMY_ENGINE = sqlalchemy.create_engine(conn_string)
306
- else:
307
- _DB_PATH = os.path.expanduser('~/.sky/state.db')
308
- pathlib.Path(_DB_PATH).parents[0].mkdir(parents=True, exist_ok=True)
309
- SQLALCHEMY_ENGINE = sqlalchemy.create_engine('sqlite:///' + _DB_PATH)
310
- create_table()
311
-
312
-
306
+ def initialize_and_get_db() -> sqlalchemy.engine.Engine:
307
+ global _SQLALCHEMY_ENGINE
308
+ if _SQLALCHEMY_ENGINE is not None:
309
+ return _SQLALCHEMY_ENGINE
310
+ with _DB_INIT_LOCK:
311
+ if _SQLALCHEMY_ENGINE is None:
312
+ conn_string = None
313
+ if os.environ.get(constants.ENV_VAR_IS_SKYPILOT_SERVER) is not None:
314
+ conn_string = skypilot_config.get_nested(('db',), None)
315
+ if conn_string:
316
+ logger.debug(f'using db URI from {conn_string}')
317
+ _SQLALCHEMY_ENGINE = sqlalchemy.create_engine(conn_string)
318
+ else:
319
+ db_path = os.path.expanduser('~/.sky/state.db')
320
+ pathlib.Path(db_path).parents[0].mkdir(parents=True,
321
+ exist_ok=True)
322
+ _SQLALCHEMY_ENGINE = sqlalchemy.create_engine('sqlite:///' +
323
+ db_path)
324
+ create_table()
325
+ return _SQLALCHEMY_ENGINE
326
+
327
+
328
+ def _init_db(func):
329
+ """Initialize the database."""
330
+
331
+ @functools.wraps(func)
332
+ def wrapper(*args, **kwargs):
333
+ initialize_and_get_db()
334
+ return func(*args, **kwargs)
335
+
336
+ return wrapper
337
+
338
+
339
+ @_init_db
313
340
  def add_or_update_user(user: models.User) -> bool:
314
341
  """Store the mapping from user hash to user name for display purposes.
315
342
 
316
343
  Returns:
317
344
  Boolean: whether the user is newly added
318
345
  """
346
+ assert _SQLALCHEMY_ENGINE is not None
347
+
319
348
  if user.name is None:
320
349
  return False
321
350
 
322
- with orm.Session(SQLALCHEMY_ENGINE) as session:
323
- if (SQLALCHEMY_ENGINE.dialect.name ==
351
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
352
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
324
353
  db_utils.SQLAlchemyDialect.SQLITE.value):
325
354
  # For SQLite, use INSERT OR IGNORE followed by UPDATE to detect new
326
355
  # vs existing
@@ -342,7 +371,7 @@ def add_or_update_user(user: models.User) -> bool:
342
371
  session.commit()
343
372
  return was_inserted
344
373
 
345
- elif (SQLALCHEMY_ENGINE.dialect.name ==
374
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
346
375
  db_utils.SQLAlchemyDialect.POSTGRESQL.value):
347
376
  # For PostgreSQL, use INSERT ... ON CONFLICT with RETURNING to
348
377
  # detect insert vs update
@@ -370,20 +399,25 @@ def add_or_update_user(user: models.User) -> bool:
370
399
  raise ValueError('Unsupported database dialect')
371
400
 
372
401
 
373
- def get_user(user_id: str) -> models.User:
374
- with orm.Session(SQLALCHEMY_ENGINE) as session:
402
+ @_init_db
403
+ def get_user(user_id: str) -> Optional[models.User]:
404
+ assert _SQLALCHEMY_ENGINE is not None
405
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
375
406
  row = session.query(user_table).filter_by(id=user_id).first()
376
407
  if row is None:
377
- return models.User(id=user_id)
408
+ return None
378
409
  return models.User(id=row.id, name=row.name)
379
410
 
380
411
 
412
+ @_init_db
381
413
  def get_all_users() -> List[models.User]:
382
- with orm.Session(SQLALCHEMY_ENGINE) as session:
414
+ assert _SQLALCHEMY_ENGINE is not None
415
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
383
416
  rows = session.query(user_table).all()
384
417
  return [models.User(id=row.id, name=row.name) for row in rows]
385
418
 
386
419
 
420
+ @_init_db
387
421
  def add_or_update_cluster(cluster_name: str,
388
422
  cluster_handle: 'backends.ResourceHandle',
389
423
  requested_resources: Optional[Set[Any]],
@@ -404,6 +438,7 @@ def add_or_update_cluster(cluster_name: str,
404
438
  config_hash: Configuration hash for the cluster.
405
439
  task_config: The config of the task being launched.
406
440
  """
441
+ assert _SQLALCHEMY_ENGINE is not None
407
442
  # FIXME: launched_at will be changed when `sky launch -c` is called.
408
443
  handle = pickle.dumps(cluster_handle)
409
444
  cluster_launched_at = int(time.time()) if is_launch else None
@@ -436,7 +471,7 @@ def add_or_update_cluster(cluster_name: str,
436
471
  cluster_launched_at = int(time.time())
437
472
  usage_intervals.append((cluster_launched_at, None))
438
473
 
439
- user_hash = common_utils.get_user_hash()
474
+ user_hash = common_utils.get_current_user().id
440
475
  active_workspace = skypilot_config.get_active_workspace()
441
476
 
442
477
  conditional_values = {}
@@ -456,7 +491,7 @@ def add_or_update_cluster(cluster_name: str,
456
491
  'config_hash': config_hash,
457
492
  })
458
493
 
459
- with orm.Session(SQLALCHEMY_ENGINE) as session:
494
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
460
495
  # with_for_update() locks the row until commit() or rollback()
461
496
  # is called, or until the code escapes the with block.
462
497
  cluster_row = session.query(cluster_table).filter_by(
@@ -483,10 +518,10 @@ def add_or_update_cluster(cluster_name: str,
483
518
  'last_creation_command': last_use,
484
519
  })
485
520
 
486
- if (SQLALCHEMY_ENGINE.dialect.name ==
521
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
487
522
  db_utils.SQLAlchemyDialect.SQLITE.value):
488
523
  insert_func = sqlite.insert
489
- elif (SQLALCHEMY_ENGINE.dialect.name ==
524
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
490
525
  db_utils.SQLAlchemyDialect.POSTGRESQL.value):
491
526
  insert_func = postgresql.insert
492
527
  else:
@@ -561,29 +596,34 @@ def _get_user_hash_or_current_user(user_hash: Optional[str]) -> str:
561
596
  return common_utils.get_user_hash()
562
597
 
563
598
 
599
+ @_init_db
564
600
  def update_cluster_handle(cluster_name: str,
565
601
  cluster_handle: 'backends.ResourceHandle'):
602
+ assert _SQLALCHEMY_ENGINE is not None
566
603
  handle = pickle.dumps(cluster_handle)
567
- with orm.Session(SQLALCHEMY_ENGINE) as session:
604
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
568
605
  session.query(cluster_table).filter_by(name=cluster_name).update(
569
606
  {cluster_table.c.handle: handle})
570
607
  session.commit()
571
608
 
572
609
 
610
+ @_init_db
573
611
  def update_last_use(cluster_name: str):
574
612
  """Updates the last used command for the cluster."""
575
- with orm.Session(SQLALCHEMY_ENGINE) as session:
613
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
576
614
  session.query(cluster_table).filter_by(name=cluster_name).update(
577
615
  {cluster_table.c.last_use: common_utils.get_current_command()})
578
616
  session.commit()
579
617
 
580
618
 
619
+ @_init_db
581
620
  def remove_cluster(cluster_name: str, terminate: bool) -> None:
582
621
  """Removes cluster_name mapping."""
622
+ assert _SQLALCHEMY_ENGINE is not None
583
623
  cluster_hash = _get_hash_for_existing_cluster(cluster_name)
584
624
  usage_intervals = _get_cluster_usage_intervals(cluster_hash)
585
625
 
586
- with orm.Session(SQLALCHEMY_ENGINE) as session:
626
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
587
627
  # usage_intervals is not None and not empty
588
628
  if usage_intervals:
589
629
  assert cluster_hash is not None, cluster_name
@@ -613,24 +653,28 @@ def remove_cluster(cluster_name: str, terminate: bool) -> None:
613
653
  session.commit()
614
654
 
615
655
 
656
+ @_init_db
616
657
  def get_handle_from_cluster_name(
617
658
  cluster_name: str) -> Optional['backends.ResourceHandle']:
659
+ assert _SQLALCHEMY_ENGINE is not None
618
660
  assert cluster_name is not None, 'cluster_name cannot be None'
619
- with orm.Session(SQLALCHEMY_ENGINE) as session:
661
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
620
662
  row = session.query(cluster_table).filter_by(name=cluster_name).first()
621
663
  if row is None:
622
664
  return None
623
665
  return pickle.loads(row.handle)
624
666
 
625
667
 
668
+ @_init_db
626
669
  def get_glob_cluster_names(cluster_name: str) -> List[str]:
670
+ assert _SQLALCHEMY_ENGINE is not None
627
671
  assert cluster_name is not None, 'cluster_name cannot be None'
628
- with orm.Session(SQLALCHEMY_ENGINE) as session:
629
- if (SQLALCHEMY_ENGINE.dialect.name ==
672
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
673
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
630
674
  db_utils.SQLAlchemyDialect.SQLITE.value):
631
675
  rows = session.query(cluster_table).filter(
632
676
  cluster_table.c.name.op('GLOB')(cluster_name)).all()
633
- elif (SQLALCHEMY_ENGINE.dialect.name ==
677
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
634
678
  db_utils.SQLAlchemyDialect.POSTGRESQL.value):
635
679
  rows = session.query(cluster_table).filter(
636
680
  cluster_table.c.name.op('SIMILAR TO')(
@@ -640,10 +684,12 @@ def get_glob_cluster_names(cluster_name: str) -> List[str]:
640
684
  return [row.name for row in rows]
641
685
 
642
686
 
687
+ @_init_db
643
688
  def set_cluster_status(cluster_name: str,
644
689
  status: status_lib.ClusterStatus) -> None:
690
+ assert _SQLALCHEMY_ENGINE is not None
645
691
  current_time = int(time.time())
646
- with orm.Session(SQLALCHEMY_ENGINE) as session:
692
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
647
693
  count = session.query(cluster_table).filter_by(
648
694
  name=cluster_name).update({
649
695
  cluster_table.c.status: status.value,
@@ -655,9 +701,11 @@ def set_cluster_status(cluster_name: str,
655
701
  raise ValueError(f'Cluster {cluster_name} not found.')
656
702
 
657
703
 
704
+ @_init_db
658
705
  def set_cluster_autostop_value(cluster_name: str, idle_minutes: int,
659
706
  to_down: bool) -> None:
660
- with orm.Session(SQLALCHEMY_ENGINE) as session:
707
+ assert _SQLALCHEMY_ENGINE is not None
708
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
661
709
  count = session.query(cluster_table).filter_by(
662
710
  name=cluster_name).update({
663
711
  cluster_table.c.autostop: idle_minutes,
@@ -669,24 +717,30 @@ def set_cluster_autostop_value(cluster_name: str, idle_minutes: int,
669
717
  raise ValueError(f'Cluster {cluster_name} not found.')
670
718
 
671
719
 
720
+ @_init_db
672
721
  def get_cluster_launch_time(cluster_name: str) -> Optional[int]:
673
- with orm.Session(SQLALCHEMY_ENGINE) as session:
722
+ assert _SQLALCHEMY_ENGINE is not None
723
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
674
724
  row = session.query(cluster_table).filter_by(name=cluster_name).first()
675
725
  if row is None or row.launched_at is None:
676
726
  return None
677
727
  return int(row.launched_at)
678
728
 
679
729
 
730
+ @_init_db
680
731
  def get_cluster_info(cluster_name: str) -> Optional[Dict[str, Any]]:
681
- with orm.Session(SQLALCHEMY_ENGINE) as session:
732
+ assert _SQLALCHEMY_ENGINE is not None
733
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
682
734
  row = session.query(cluster_table).filter_by(name=cluster_name).first()
683
735
  if row is None or row.metadata is None:
684
736
  return None
685
737
  return json.loads(row.metadata)
686
738
 
687
739
 
740
+ @_init_db
688
741
  def set_cluster_info(cluster_name: str, metadata: Dict[str, Any]) -> None:
689
- with orm.Session(SQLALCHEMY_ENGINE) as session:
742
+ assert _SQLALCHEMY_ENGINE is not None
743
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
690
744
  count = session.query(cluster_table).filter_by(
691
745
  name=cluster_name).update(
692
746
  {cluster_table.c.metadata: json.dumps(metadata)})
@@ -696,18 +750,22 @@ def set_cluster_info(cluster_name: str, metadata: Dict[str, Any]) -> None:
696
750
  raise ValueError(f'Cluster {cluster_name} not found.')
697
751
 
698
752
 
753
+ @_init_db
699
754
  def get_cluster_storage_mounts_metadata(
700
755
  cluster_name: str) -> Optional[Dict[str, Any]]:
701
- with orm.Session(SQLALCHEMY_ENGINE) as session:
756
+ assert _SQLALCHEMY_ENGINE is not None
757
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
702
758
  row = session.query(cluster_table).filter_by(name=cluster_name).first()
703
759
  if row is None or row.storage_mounts_metadata is None:
704
760
  return None
705
761
  return pickle.loads(row.storage_mounts_metadata)
706
762
 
707
763
 
764
+ @_init_db
708
765
  def set_cluster_storage_mounts_metadata(
709
766
  cluster_name: str, storage_mounts_metadata: Dict[str, Any]) -> None:
710
- with orm.Session(SQLALCHEMY_ENGINE) as session:
767
+ assert _SQLALCHEMY_ENGINE is not None
768
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
711
769
  count = session.query(cluster_table).filter_by(
712
770
  name=cluster_name).update({
713
771
  cluster_table.c.storage_mounts_metadata:
@@ -719,12 +777,14 @@ def set_cluster_storage_mounts_metadata(
719
777
  raise ValueError(f'Cluster {cluster_name} not found.')
720
778
 
721
779
 
780
+ @_init_db
722
781
  def _get_cluster_usage_intervals(
723
782
  cluster_hash: Optional[str]
724
783
  ) -> Optional[List[Tuple[int, Optional[int]]]]:
784
+ assert _SQLALCHEMY_ENGINE is not None
725
785
  if cluster_hash is None:
726
786
  return None
727
- with orm.Session(SQLALCHEMY_ENGINE) as session:
787
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
728
788
  row = session.query(cluster_history_table).filter_by(
729
789
  cluster_hash=cluster_hash).first()
730
790
  if row is None or row.usage_intervals is None:
@@ -758,10 +818,12 @@ def _get_cluster_duration(cluster_hash: str) -> int:
758
818
  return total_duration
759
819
 
760
820
 
821
+ @_init_db
761
822
  def _set_cluster_usage_intervals(
762
823
  cluster_hash: str, usage_intervals: List[Tuple[int,
763
824
  Optional[int]]]) -> None:
764
- with orm.Session(SQLALCHEMY_ENGINE) as session:
825
+ assert _SQLALCHEMY_ENGINE is not None
826
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
765
827
  count = session.query(cluster_history_table).filter_by(
766
828
  cluster_hash=cluster_hash).update({
767
829
  cluster_history_table.c.usage_intervals:
@@ -773,12 +835,14 @@ def _set_cluster_usage_intervals(
773
835
  raise ValueError(f'Cluster hash {cluster_hash} not found.')
774
836
 
775
837
 
838
+ @_init_db
776
839
  def set_owner_identity_for_cluster(cluster_name: str,
777
840
  owner_identity: Optional[List[str]]) -> None:
841
+ assert _SQLALCHEMY_ENGINE is not None
778
842
  if owner_identity is None:
779
843
  return
780
844
  owner_identity_str = json.dumps(owner_identity)
781
- with orm.Session(SQLALCHEMY_ENGINE) as session:
845
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
782
846
  count = session.query(cluster_table).filter_by(
783
847
  name=cluster_name).update(
784
848
  {cluster_table.c.owner: owner_identity_str})
@@ -788,17 +852,21 @@ def set_owner_identity_for_cluster(cluster_name: str,
788
852
  raise ValueError(f'Cluster {cluster_name} not found.')
789
853
 
790
854
 
855
+ @_init_db
791
856
  def _get_hash_for_existing_cluster(cluster_name: str) -> Optional[str]:
792
- with orm.Session(SQLALCHEMY_ENGINE) as session:
857
+ assert _SQLALCHEMY_ENGINE is not None
858
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
793
859
  row = session.query(cluster_table).filter_by(name=cluster_name).first()
794
860
  if row is None or row.cluster_hash is None:
795
861
  return None
796
862
  return row.cluster_hash
797
863
 
798
864
 
865
+ @_init_db
799
866
  def get_launched_resources_from_cluster_hash(
800
867
  cluster_hash: str) -> Optional[Tuple[int, Any]]:
801
- with orm.Session(SQLALCHEMY_ENGINE) as session:
868
+ assert _SQLALCHEMY_ENGINE is not None
869
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
802
870
  row = session.query(cluster_history_table).filter_by(
803
871
  cluster_hash=cluster_hash).first()
804
872
  if row is None:
@@ -840,14 +908,18 @@ def _load_storage_mounts_metadata(
840
908
  return pickle.loads(record_storage_mounts_metadata)
841
909
 
842
910
 
911
+ @_init_db
843
912
  @context_utils.cancellation_guard
844
913
  def get_cluster_from_name(
845
914
  cluster_name: Optional[str]) -> Optional[Dict[str, Any]]:
846
- with orm.Session(SQLALCHEMY_ENGINE) as session:
915
+ assert _SQLALCHEMY_ENGINE is not None
916
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
847
917
  row = session.query(cluster_table).filter_by(name=cluster_name).first()
848
918
  if row is None:
849
919
  return None
850
920
  user_hash = _get_user_hash_or_current_user(row.user_hash)
921
+ user = get_user(user_hash)
922
+ user_name = user.name if user is not None else None
851
923
  # TODO: use namedtuple instead of dict
852
924
  record = {
853
925
  'name': row.name,
@@ -865,7 +937,7 @@ def get_cluster_from_name(
865
937
  'cluster_ever_up': bool(row.cluster_ever_up),
866
938
  'status_updated_at': row.status_updated_at,
867
939
  'user_hash': user_hash,
868
- 'user_name': get_user(user_hash).name,
940
+ 'user_name': user_name,
869
941
  'config_hash': row.config_hash,
870
942
  'workspace': row.workspace,
871
943
  'last_creation_yaml': row.last_creation_yaml,
@@ -875,13 +947,17 @@ def get_cluster_from_name(
875
947
  return record
876
948
 
877
949
 
950
+ @_init_db
878
951
  def get_clusters() -> List[Dict[str, Any]]:
879
- with orm.Session(SQLALCHEMY_ENGINE) as session:
952
+ assert _SQLALCHEMY_ENGINE is not None
953
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
880
954
  rows = session.query(cluster_table).order_by(
881
955
  sqlalchemy.desc(cluster_table.c.launched_at)).all()
882
956
  records = []
883
957
  for row in rows:
884
958
  user_hash = _get_user_hash_or_current_user(row.user_hash)
959
+ user = get_user(user_hash)
960
+ user_name = user.name if user is not None else None
885
961
  # TODO: use namedtuple instead of dict
886
962
  record = {
887
963
  'name': row.name,
@@ -899,7 +975,7 @@ def get_clusters() -> List[Dict[str, Any]]:
899
975
  'cluster_ever_up': bool(row.cluster_ever_up),
900
976
  'status_updated_at': row.status_updated_at,
901
977
  'user_hash': user_hash,
902
- 'user_name': get_user(user_hash).name,
978
+ 'user_name': user_name,
903
979
  'config_hash': row.config_hash,
904
980
  'workspace': row.workspace,
905
981
  'last_creation_yaml': row.last_creation_yaml,
@@ -910,8 +986,10 @@ def get_clusters() -> List[Dict[str, Any]]:
910
986
  return records
911
987
 
912
988
 
989
+ @_init_db
913
990
  def get_clusters_from_history() -> List[Dict[str, Any]]:
914
- with orm.Session(SQLALCHEMY_ENGINE) as session:
991
+ assert _SQLALCHEMY_ENGINE is not None
992
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
915
993
  rows = session.query(
916
994
  cluster_history_table.join(cluster_table,
917
995
  cluster_history_table.c.cluster_hash ==
@@ -946,16 +1024,20 @@ def get_clusters_from_history() -> List[Dict[str, Any]]:
946
1024
  return records
947
1025
 
948
1026
 
1027
+ @_init_db
949
1028
  def get_cluster_names_start_with(starts_with: str) -> List[str]:
950
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1029
+ assert _SQLALCHEMY_ENGINE is not None
1030
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
951
1031
  rows = session.query(cluster_table).filter(
952
1032
  cluster_table.c.name.like(f'{starts_with}%')).all()
953
1033
  return [row.name for row in rows]
954
1034
 
955
1035
 
1036
+ @_init_db
956
1037
  def get_cached_enabled_clouds(cloud_capability: 'cloud.CloudCapability',
957
1038
  workspace: str) -> List['clouds.Cloud']:
958
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1039
+ assert _SQLALCHEMY_ENGINE is not None
1040
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
959
1041
  row = session.query(config_table).filter_by(
960
1042
  key=_get_enabled_clouds_key(cloud_capability, workspace)).first()
961
1043
  ret = []
@@ -976,14 +1058,16 @@ def get_cached_enabled_clouds(cloud_capability: 'cloud.CloudCapability',
976
1058
  return enabled_clouds
977
1059
 
978
1060
 
1061
+ @_init_db
979
1062
  def set_enabled_clouds(enabled_clouds: List[str],
980
1063
  cloud_capability: 'cloud.CloudCapability',
981
1064
  workspace: str) -> None:
982
- with orm.Session(SQLALCHEMY_ENGINE) as session:
983
- if (SQLALCHEMY_ENGINE.dialect.name ==
1065
+ assert _SQLALCHEMY_ENGINE is not None
1066
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1067
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
984
1068
  db_utils.SQLAlchemyDialect.SQLITE.value):
985
1069
  insert_func = sqlite.insert
986
- elif (SQLALCHEMY_ENGINE.dialect.name ==
1070
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
987
1071
  db_utils.SQLAlchemyDialect.POSTGRESQL.value):
988
1072
  insert_func = postgresql.insert
989
1073
  else:
@@ -1003,9 +1087,11 @@ def _get_enabled_clouds_key(cloud_capability: 'cloud.CloudCapability',
1003
1087
  return _ENABLED_CLOUDS_KEY_PREFIX + workspace + '_' + cloud_capability.value
1004
1088
 
1005
1089
 
1090
+ @_init_db
1006
1091
  def add_or_update_storage(storage_name: str,
1007
1092
  storage_handle: 'Storage.StorageMetadata',
1008
1093
  storage_status: status_lib.StorageStatus):
1094
+ assert _SQLALCHEMY_ENGINE is not None
1009
1095
  storage_launched_at = int(time.time())
1010
1096
  handle = pickle.dumps(storage_handle)
1011
1097
  last_use = common_utils.get_current_command()
@@ -1016,11 +1102,11 @@ def add_or_update_storage(storage_name: str,
1016
1102
  if not status_check(storage_status):
1017
1103
  raise ValueError(f'Error in updating global state. Storage Status '
1018
1104
  f'{storage_status} is passed in incorrectly')
1019
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1020
- if (SQLALCHEMY_ENGINE.dialect.name ==
1105
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1106
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
1021
1107
  db_utils.SQLAlchemyDialect.SQLITE.value):
1022
1108
  insert_func = sqlite.insert
1023
- elif (SQLALCHEMY_ENGINE.dialect.name ==
1109
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
1024
1110
  db_utils.SQLAlchemyDialect.POSTGRESQL.value):
1025
1111
  insert_func = postgresql.insert
1026
1112
  else:
@@ -1043,16 +1129,20 @@ def add_or_update_storage(storage_name: str,
1043
1129
  session.commit()
1044
1130
 
1045
1131
 
1132
+ @_init_db
1046
1133
  def remove_storage(storage_name: str):
1047
1134
  """Removes Storage from Database"""
1048
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1135
+ assert _SQLALCHEMY_ENGINE is not None
1136
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1049
1137
  session.query(storage_table).filter_by(name=storage_name).delete()
1050
1138
  session.commit()
1051
1139
 
1052
1140
 
1141
+ @_init_db
1053
1142
  def set_storage_status(storage_name: str,
1054
1143
  status: status_lib.StorageStatus) -> None:
1055
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1144
+ assert _SQLALCHEMY_ENGINE is not None
1145
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1056
1146
  count = session.query(storage_table).filter_by(
1057
1147
  name=storage_name).update({storage_table.c.status: status.value})
1058
1148
  session.commit()
@@ -1061,18 +1151,22 @@ def set_storage_status(storage_name: str,
1061
1151
  raise ValueError(f'Storage {storage_name} not found.')
1062
1152
 
1063
1153
 
1154
+ @_init_db
1064
1155
  def get_storage_status(storage_name: str) -> Optional[status_lib.StorageStatus]:
1156
+ assert _SQLALCHEMY_ENGINE is not None
1065
1157
  assert storage_name is not None, 'storage_name cannot be None'
1066
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1158
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1067
1159
  row = session.query(storage_table).filter_by(name=storage_name).first()
1068
1160
  if row:
1069
1161
  return status_lib.StorageStatus[row.status]
1070
1162
  return None
1071
1163
 
1072
1164
 
1165
+ @_init_db
1073
1166
  def set_storage_handle(storage_name: str,
1074
1167
  handle: 'Storage.StorageMetadata') -> None:
1075
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1168
+ assert _SQLALCHEMY_ENGINE is not None
1169
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1076
1170
  count = session.query(storage_table).filter_by(
1077
1171
  name=storage_name).update(
1078
1172
  {storage_table.c.handle: pickle.dumps(handle)})
@@ -1082,25 +1176,29 @@ def set_storage_handle(storage_name: str,
1082
1176
  raise ValueError(f'Storage{storage_name} not found.')
1083
1177
 
1084
1178
 
1179
+ @_init_db
1085
1180
  def get_handle_from_storage_name(
1086
1181
  storage_name: Optional[str]) -> Optional['Storage.StorageMetadata']:
1182
+ assert _SQLALCHEMY_ENGINE is not None
1087
1183
  if storage_name is None:
1088
1184
  return None
1089
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1185
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1090
1186
  row = session.query(storage_table).filter_by(name=storage_name).first()
1091
1187
  if row:
1092
1188
  return pickle.loads(row.handle)
1093
1189
  return None
1094
1190
 
1095
1191
 
1192
+ @_init_db
1096
1193
  def get_glob_storage_name(storage_name: str) -> List[str]:
1194
+ assert _SQLALCHEMY_ENGINE is not None
1097
1195
  assert storage_name is not None, 'storage_name cannot be None'
1098
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1099
- if (SQLALCHEMY_ENGINE.dialect.name ==
1196
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1197
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
1100
1198
  db_utils.SQLAlchemyDialect.SQLITE.value):
1101
1199
  rows = session.query(storage_table).filter(
1102
1200
  storage_table.c.name.op('GLOB')(storage_name)).all()
1103
- elif (SQLALCHEMY_ENGINE.dialect.name ==
1201
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
1104
1202
  db_utils.SQLAlchemyDialect.POSTGRESQL.value):
1105
1203
  rows = session.query(storage_table).filter(
1106
1204
  storage_table.c.name.op('SIMILAR TO')(
@@ -1110,15 +1208,19 @@ def get_glob_storage_name(storage_name: str) -> List[str]:
1110
1208
  return [row.name for row in rows]
1111
1209
 
1112
1210
 
1211
+ @_init_db
1113
1212
  def get_storage_names_start_with(starts_with: str) -> List[str]:
1114
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1213
+ assert _SQLALCHEMY_ENGINE is not None
1214
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1115
1215
  rows = session.query(storage_table).filter(
1116
1216
  storage_table.c.name.like(f'{starts_with}%')).all()
1117
1217
  return [row.name for row in rows]
1118
1218
 
1119
1219
 
1220
+ @_init_db
1120
1221
  def get_storage() -> List[Dict[str, Any]]:
1121
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1222
+ assert _SQLALCHEMY_ENGINE is not None
1223
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1122
1224
  rows = session.query(storage_table).all()
1123
1225
  records = []
1124
1226
  for row in rows:
@@ -1133,8 +1235,10 @@ def get_storage() -> List[Dict[str, Any]]:
1133
1235
  return records
1134
1236
 
1135
1237
 
1238
+ @_init_db
1136
1239
  def get_ssh_keys(user_hash: str) -> Tuple[str, str, bool]:
1137
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1240
+ assert _SQLALCHEMY_ENGINE is not None
1241
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1138
1242
  row = session.query(ssh_key_table).filter_by(
1139
1243
  user_hash=user_hash).first()
1140
1244
  if row:
@@ -1142,12 +1246,14 @@ def get_ssh_keys(user_hash: str) -> Tuple[str, str, bool]:
1142
1246
  return '', '', False
1143
1247
 
1144
1248
 
1249
+ @_init_db
1145
1250
  def set_ssh_keys(user_hash: str, ssh_public_key: str, ssh_private_key: str):
1146
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1147
- if (SQLALCHEMY_ENGINE.dialect.name ==
1251
+ assert _SQLALCHEMY_ENGINE is not None
1252
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1253
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
1148
1254
  db_utils.SQLAlchemyDialect.SQLITE.value):
1149
1255
  insert_func = sqlite.insert
1150
- elif (SQLALCHEMY_ENGINE.dialect.name ==
1256
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
1151
1257
  db_utils.SQLAlchemyDialect.POSTGRESQL.value):
1152
1258
  insert_func = postgresql.insert
1153
1259
  else:
@@ -1166,6 +1272,7 @@ def set_ssh_keys(user_hash: str, ssh_public_key: str, ssh_private_key: str):
1166
1272
  session.commit()
1167
1273
 
1168
1274
 
1275
+ @_init_db
1169
1276
  def get_cluster_yaml_str(cluster_yaml_path: Optional[str]) -> Optional[str]:
1170
1277
  """Get the cluster yaml from the database or the local file system.
1171
1278
  If the cluster yaml is not in the database, check if it exists on the
@@ -1173,11 +1280,12 @@ def get_cluster_yaml_str(cluster_yaml_path: Optional[str]) -> Optional[str]:
1173
1280
 
1174
1281
  It is assumed that the cluster yaml file is named as <cluster_name>.yml.
1175
1282
  """
1283
+ assert _SQLALCHEMY_ENGINE is not None
1176
1284
  if cluster_yaml_path is None:
1177
1285
  raise ValueError('Attempted to read a None YAML.')
1178
1286
  cluster_file_name = os.path.basename(cluster_yaml_path)
1179
1287
  cluster_name, _ = os.path.splitext(cluster_file_name)
1180
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1288
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1181
1289
  row = session.query(cluster_yaml_table).filter_by(
1182
1290
  cluster_name=cluster_name).first()
1183
1291
  if row is None:
@@ -1205,13 +1313,15 @@ def get_cluster_yaml_dict(cluster_yaml_path: Optional[str]) -> Dict[str, Any]:
1205
1313
  return yaml.safe_load(yaml_str)
1206
1314
 
1207
1315
 
1316
+ @_init_db
1208
1317
  def set_cluster_yaml(cluster_name: str, yaml_str: str) -> None:
1209
1318
  """Set the cluster yaml in the database."""
1210
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1211
- if (SQLALCHEMY_ENGINE.dialect.name ==
1319
+ assert _SQLALCHEMY_ENGINE is not None
1320
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1321
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
1212
1322
  db_utils.SQLAlchemyDialect.SQLITE.value):
1213
1323
  insert_func = sqlite.insert
1214
- elif (SQLALCHEMY_ENGINE.dialect.name ==
1324
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
1215
1325
  db_utils.SQLAlchemyDialect.POSTGRESQL.value):
1216
1326
  insert_func = postgresql.insert
1217
1327
  else:
@@ -1225,8 +1335,10 @@ def set_cluster_yaml(cluster_name: str, yaml_str: str) -> None:
1225
1335
  session.commit()
1226
1336
 
1227
1337
 
1338
+ @_init_db
1228
1339
  def remove_cluster_yaml(cluster_name: str):
1229
- with orm.Session(SQLALCHEMY_ENGINE) as session:
1340
+ assert _SQLALCHEMY_ENGINE is not None
1341
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1230
1342
  session.query(cluster_yaml_table).filter_by(
1231
1343
  cluster_name=cluster_name).delete()
1232
1344
  session.commit()