skypilot-nightly 1.0.0.dev20250909__py3-none-any.whl → 1.0.0.dev20250912__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of skypilot-nightly might be problematic. See the release advisory for more details.

Files changed (97)
  1. sky/__init__.py +2 -2
  2. sky/authentication.py +19 -4
  3. sky/backends/backend_utils.py +160 -23
  4. sky/backends/cloud_vm_ray_backend.py +226 -74
  5. sky/catalog/__init__.py +7 -0
  6. sky/catalog/aws_catalog.py +4 -0
  7. sky/catalog/common.py +18 -0
  8. sky/catalog/data_fetchers/fetch_aws.py +13 -1
  9. sky/client/cli/command.py +2 -71
  10. sky/client/sdk.py +20 -0
  11. sky/client/sdk_async.py +23 -18
  12. sky/clouds/aws.py +26 -6
  13. sky/clouds/cloud.py +8 -0
  14. sky/dashboard/out/404.html +1 -1
  15. sky/dashboard/out/_next/static/chunks/3294.ba6586f9755b0edb.js +6 -0
  16. sky/dashboard/out/_next/static/chunks/{webpack-d4fabc08788e14af.js → webpack-e8a0c4c3c6f408fb.js} +1 -1
  17. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  18. sky/dashboard/out/clusters/[cluster].html +1 -1
  19. sky/dashboard/out/clusters.html +1 -1
  20. sky/dashboard/out/config.html +1 -1
  21. sky/dashboard/out/index.html +1 -1
  22. sky/dashboard/out/infra/[context].html +1 -1
  23. sky/dashboard/out/infra.html +1 -1
  24. sky/dashboard/out/jobs/[job].html +1 -1
  25. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  26. sky/dashboard/out/jobs.html +1 -1
  27. sky/dashboard/out/users.html +1 -1
  28. sky/dashboard/out/volumes.html +1 -1
  29. sky/dashboard/out/workspace/new.html +1 -1
  30. sky/dashboard/out/workspaces/[name].html +1 -1
  31. sky/dashboard/out/workspaces.html +1 -1
  32. sky/data/storage.py +5 -1
  33. sky/execution.py +21 -14
  34. sky/global_user_state.py +34 -0
  35. sky/jobs/client/sdk_async.py +4 -2
  36. sky/jobs/constants.py +3 -0
  37. sky/jobs/controller.py +734 -310
  38. sky/jobs/recovery_strategy.py +251 -129
  39. sky/jobs/scheduler.py +247 -174
  40. sky/jobs/server/core.py +20 -4
  41. sky/jobs/server/utils.py +2 -2
  42. sky/jobs/state.py +709 -508
  43. sky/jobs/utils.py +90 -40
  44. sky/logs/agent.py +10 -2
  45. sky/provision/aws/config.py +4 -1
  46. sky/provision/gcp/config.py +6 -1
  47. sky/provision/kubernetes/config.py +7 -2
  48. sky/provision/kubernetes/instance.py +84 -41
  49. sky/provision/kubernetes/utils.py +17 -8
  50. sky/provision/provisioner.py +1 -0
  51. sky/provision/vast/instance.py +1 -1
  52. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  53. sky/serve/replica_managers.py +0 -7
  54. sky/serve/serve_utils.py +5 -0
  55. sky/serve/server/impl.py +1 -2
  56. sky/serve/service.py +0 -2
  57. sky/server/common.py +8 -3
  58. sky/server/config.py +55 -27
  59. sky/server/constants.py +1 -0
  60. sky/server/daemons.py +7 -11
  61. sky/server/metrics.py +41 -8
  62. sky/server/requests/executor.py +41 -4
  63. sky/server/requests/serializers/encoders.py +1 -1
  64. sky/server/server.py +9 -1
  65. sky/server/uvicorn.py +11 -5
  66. sky/setup_files/dependencies.py +4 -2
  67. sky/skylet/attempt_skylet.py +1 -0
  68. sky/skylet/constants.py +14 -7
  69. sky/skylet/events.py +2 -10
  70. sky/skylet/log_lib.py +11 -0
  71. sky/skylet/log_lib.pyi +9 -0
  72. sky/task.py +62 -0
  73. sky/templates/kubernetes-ray.yml.j2 +120 -3
  74. sky/utils/accelerator_registry.py +3 -1
  75. sky/utils/command_runner.py +35 -11
  76. sky/utils/command_runner.pyi +25 -3
  77. sky/utils/common_utils.py +11 -1
  78. sky/utils/context_utils.py +15 -2
  79. sky/utils/controller_utils.py +5 -0
  80. sky/utils/db/db_utils.py +31 -2
  81. sky/utils/db/migration_utils.py +1 -1
  82. sky/utils/git.py +559 -1
  83. sky/utils/resource_checker.py +8 -7
  84. sky/utils/rich_utils.py +3 -1
  85. sky/utils/subprocess_utils.py +9 -0
  86. sky/volumes/volume.py +2 -0
  87. sky/workspaces/core.py +57 -21
  88. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/METADATA +38 -36
  89. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/RECORD +95 -95
  90. sky/client/cli/git.py +0 -549
  91. sky/dashboard/out/_next/static/chunks/3294.c80326aec9bfed40.js +0 -6
  92. /sky/dashboard/out/_next/static/{eWytLgin5zvayQw3Xk46m → DAiq7V2xJnO1LSfmunZl6}/_buildManifest.js +0 -0
  93. /sky/dashboard/out/_next/static/{eWytLgin5zvayQw3Xk46m → DAiq7V2xJnO1LSfmunZl6}/_ssgManifest.js +0 -0
  94. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/WHEEL +0 -0
  95. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/entry_points.txt +0 -0
  96. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/licenses/LICENSE +0 -0
  97. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/top_level.txt +0 -0
sky/jobs/state.py CHANGED
@@ -1,13 +1,17 @@
1
1
  """The database for managed jobs status."""
2
2
  # TODO(zhwu): maybe use file based status instead of database, so
3
3
  # that we can easily switch to a s3-based storage.
4
+ import asyncio
4
5
  import enum
5
6
  import functools
7
+ import ipaddress
6
8
  import json
9
+ import sqlite3
7
10
  import threading
8
11
  import time
9
12
  import typing
10
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
14
+ import urllib.parse
11
15
 
12
16
  import colorama
13
17
  import sqlalchemy
@@ -15,27 +19,34 @@ from sqlalchemy import exc as sqlalchemy_exc
15
19
  from sqlalchemy import orm
16
20
  from sqlalchemy.dialects import postgresql
17
21
  from sqlalchemy.dialects import sqlite
22
+ from sqlalchemy.ext import asyncio as sql_async
18
23
  from sqlalchemy.ext import declarative
19
24
 
20
25
  from sky import exceptions
21
26
  from sky import sky_logging
27
+ from sky import skypilot_config
22
28
  from sky.skylet import constants
23
29
  from sky.utils import common_utils
30
+ from sky.utils import context_utils
24
31
  from sky.utils.db import db_utils
25
32
  from sky.utils.db import migration_utils
26
33
 
27
34
  if typing.TYPE_CHECKING:
28
35
  from sqlalchemy.engine import row
29
36
 
30
- import sky
31
-
32
- CallbackType = Callable[[str], None]
37
+ # Separate callback types for sync and async contexts
38
+ SyncCallbackType = Callable[[str], None]
39
+ AsyncCallbackType = Callable[[str], Awaitable[Any]]
40
+ CallbackType = Union[SyncCallbackType, AsyncCallbackType]
33
41
 
34
42
  logger = sky_logging.init_logger(__name__)
35
43
 
36
44
  _SQLALCHEMY_ENGINE: Optional[sqlalchemy.engine.Engine] = None
45
+ _SQLALCHEMY_ENGINE_ASYNC: Optional[sql_async.AsyncEngine] = None
37
46
  _SQLALCHEMY_ENGINE_LOCK = threading.Lock()
38
47
 
48
+ _DB_RETRY_TIMES = 30
49
+
39
50
  Base = declarative.declarative_base()
40
51
 
41
52
  # === Database schema ===
@@ -70,7 +81,7 @@ spot_table = sqlalchemy.Table(
70
81
  sqlalchemy.Column('recovery_count', sqlalchemy.Integer, server_default='0'),
71
82
  sqlalchemy.Column('job_duration', sqlalchemy.Float, server_default='0'),
72
83
  sqlalchemy.Column('failure_reason', sqlalchemy.Text),
73
- sqlalchemy.Column('spot_job_id', sqlalchemy.Integer),
84
+ sqlalchemy.Column('spot_job_id', sqlalchemy.Integer, index=True),
74
85
  sqlalchemy.Column('task_id', sqlalchemy.Integer, server_default='0'),
75
86
  sqlalchemy.Column('task_name', sqlalchemy.Text),
76
87
  sqlalchemy.Column('specs', sqlalchemy.Text),
@@ -129,6 +140,7 @@ def create_table(engine: sqlalchemy.engine.Engine):
129
140
  try:
130
141
  with orm.Session(engine) as session:
131
142
  session.execute(sqlalchemy.text('PRAGMA journal_mode=WAL'))
143
+ session.execute(sqlalchemy.text('PRAGMA synchronous=1'))
132
144
  session.commit()
133
145
  except sqlalchemy_exc.OperationalError as e:
134
146
  if 'database is locked' not in str(e):
@@ -141,6 +153,43 @@ def create_table(engine: sqlalchemy.engine.Engine):
141
153
  migration_utils.SPOT_JOBS_VERSION)
142
154
 
143
155
 
156
+ def force_no_postgres() -> bool:
157
+ """Force no postgres.
158
+
159
+ If the db is localhost on the api server, and we are not in consolidation
160
+ mode, we must force using sqlite and not using the api server on the jobs
161
+ controller.
162
+ """
163
+ conn_string = skypilot_config.get_nested(('db',), None)
164
+
165
+ if conn_string:
166
+ parsed = urllib.parse.urlparse(conn_string)
167
+ # it freezes if we use the normal get_consolidation_mode function
168
+ consolidation_mode = skypilot_config.get_nested(
169
+ ('jobs', 'controller', 'consolidation_mode'), default_value=False)
170
+ if ((parsed.hostname == 'localhost' or
171
+ ipaddress.ip_address(parsed.hostname).is_loopback) and
172
+ not consolidation_mode):
173
+ return True
174
+ return False
175
+
176
+
177
+ def initialize_and_get_db_async() -> sql_async.AsyncEngine:
178
+ global _SQLALCHEMY_ENGINE_ASYNC
179
+ if _SQLALCHEMY_ENGINE_ASYNC is not None:
180
+ return _SQLALCHEMY_ENGINE_ASYNC
181
+ with _SQLALCHEMY_ENGINE_LOCK:
182
+ if _SQLALCHEMY_ENGINE_ASYNC is not None:
183
+ return _SQLALCHEMY_ENGINE_ASYNC
184
+
185
+ _SQLALCHEMY_ENGINE_ASYNC = db_utils.get_engine('spot_jobs',
186
+ async_engine=True)
187
+
188
+ # to create the table in case an async function gets called first
189
+ initialize_and_get_db()
190
+ return _SQLALCHEMY_ENGINE_ASYNC
191
+
192
+
144
193
  # We wrap the sqlalchemy engine initialization in a thread
145
194
  # lock to ensure that multiple threads do not initialize the
146
195
  # engine which could result in a rare race condition where
@@ -149,7 +198,6 @@ def create_table(engine: sqlalchemy.engine.Engine):
149
198
  # which could result in e1 being garbage collected unexpectedly.
150
199
  def initialize_and_get_db() -> sqlalchemy.engine.Engine:
151
200
  global _SQLALCHEMY_ENGINE
152
-
153
201
  if _SQLALCHEMY_ENGINE is not None:
154
202
  return _SQLALCHEMY_ENGINE
155
203
 
@@ -167,13 +215,60 @@ def initialize_and_get_db() -> sqlalchemy.engine.Engine:
167
215
  return _SQLALCHEMY_ENGINE
168
216
 
169
217
 
218
+ def _init_db_async(func):
219
+ """Initialize the async database. Add backoff to the function call."""
220
+
221
+ @functools.wraps(func)
222
+ async def wrapper(*args, **kwargs):
223
+ if _SQLALCHEMY_ENGINE_ASYNC is None:
224
+ # this may happen multiple times since there is no locking
225
+ # here but thats fine, this is just a short circuit for the
226
+ # common case.
227
+ await context_utils.to_thread(initialize_and_get_db_async)
228
+
229
+ backoff = common_utils.Backoff(initial_backoff=1, max_backoff_factor=5)
230
+ last_exc = None
231
+ for _ in range(_DB_RETRY_TIMES):
232
+ try:
233
+ return await func(*args, **kwargs)
234
+ except (sqlalchemy_exc.OperationalError,
235
+ asyncio.exceptions.TimeoutError, OSError,
236
+ sqlalchemy_exc.TimeoutError, sqlite3.OperationalError,
237
+ sqlalchemy_exc.InterfaceError, sqlite3.InterfaceError) as e:
238
+ last_exc = e
239
+ logger.debug(f'DB error: {last_exc}')
240
+ await asyncio.sleep(backoff.current_backoff())
241
+ assert last_exc is not None
242
+ raise last_exc
243
+
244
+ return wrapper
245
+
246
+
170
247
  def _init_db(func):
171
- """Initialize the database."""
248
+ """Initialize the database. Add backoff to the function call."""
172
249
 
173
250
  @functools.wraps(func)
174
251
  def wrapper(*args, **kwargs):
175
- initialize_and_get_db()
176
- return func(*args, **kwargs)
252
+ if _SQLALCHEMY_ENGINE is None:
253
+ # this may happen multiple times since there is no locking
254
+ # here but thats fine, this is just a short circuit for the
255
+ # common case.
256
+ initialize_and_get_db()
257
+
258
+ backoff = common_utils.Backoff(initial_backoff=1, max_backoff_factor=10)
259
+ last_exc = None
260
+ for _ in range(_DB_RETRY_TIMES):
261
+ try:
262
+ return func(*args, **kwargs)
263
+ except (sqlalchemy_exc.OperationalError,
264
+ asyncio.exceptions.TimeoutError, OSError,
265
+ sqlalchemy_exc.TimeoutError, sqlite3.OperationalError,
266
+ sqlalchemy_exc.InterfaceError, sqlite3.InterfaceError) as e:
267
+ last_exc = e
268
+ logger.debug(f'DB error: {last_exc}')
269
+ time.sleep(backoff.current_backoff())
270
+ assert last_exc is not None
271
+ raise last_exc
177
272
 
178
273
  return wrapper
179
274
 
@@ -416,6 +511,10 @@ class ManagedJobScheduleState(enum.Enum):
416
511
  # This job may have been created before scheduler was introduced in #4458.
417
512
  # This state is not used by scheduler but just for backward compatibility.
418
513
  # TODO(cooperc): remove this in v0.11.0
514
+ # TODO(luca): the only states we need are INACTIVE, WAITING, ALIVE, and
515
+ # DONE. ALIVE = old LAUNCHING + ALIVE + ALIVE_BACKOFF + ALIVE_WAITING and
516
+ # will represent jobs that are claimed by a controller. Delete the rest
517
+ # in v0.13.0
419
518
  INVALID = None
420
519
  # The job should be ignored by the scheduler.
421
520
  INACTIVE = 'INACTIVE'
@@ -440,32 +539,6 @@ class ManagedJobScheduleState(enum.Enum):
440
539
 
441
540
 
442
541
  # === Status transition functions ===
443
- @_init_db
444
- def set_job_info(job_id: int, name: str, workspace: str, entrypoint: str,
445
- pool: Optional[str], pool_hash: Optional[str]):
446
- assert _SQLALCHEMY_ENGINE is not None
447
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
448
- if (_SQLALCHEMY_ENGINE.dialect.name ==
449
- db_utils.SQLAlchemyDialect.SQLITE.value):
450
- insert_func = sqlite.insert
451
- elif (_SQLALCHEMY_ENGINE.dialect.name ==
452
- db_utils.SQLAlchemyDialect.POSTGRESQL.value):
453
- insert_func = postgresql.insert
454
- else:
455
- raise ValueError('Unsupported database dialect')
456
- insert_stmt = insert_func(job_info_table).values(
457
- spot_job_id=job_id,
458
- name=name,
459
- schedule_state=ManagedJobScheduleState.INACTIVE.value,
460
- workspace=workspace,
461
- entrypoint=entrypoint,
462
- pool=pool,
463
- pool_hash=pool_hash,
464
- )
465
- session.execute(insert_stmt)
466
- session.commit()
467
-
468
-
469
542
  @_init_db
470
543
  def set_job_info_without_job_id(name: str, workspace: str, entrypoint: str,
471
544
  pool: Optional[str],
@@ -517,6 +590,7 @@ def set_pending(
517
590
  ):
518
591
  """Set the task to pending state."""
519
592
  assert _SQLALCHEMY_ENGINE is not None
593
+
520
594
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
521
595
  session.execute(
522
596
  sqlalchemy.insert(spot_table).values(
@@ -530,76 +604,28 @@ def set_pending(
530
604
  session.commit()
531
605
 
532
606
 
533
- @_init_db
534
- def set_starting(job_id: int, task_id: int, run_timestamp: str,
535
- submit_time: float, resources_str: str,
536
- specs: Dict[str, Union[str,
537
- int]], callback_func: CallbackType):
538
- """Set the task to starting state.
539
-
540
- Args:
541
- job_id: The managed job ID.
542
- task_id: The task ID.
543
- run_timestamp: The run_timestamp of the run. This will be used to
544
- determine the log directory of the managed task.
545
- submit_time: The time when the managed task is submitted.
546
- resources_str: The resources string of the managed task.
547
- specs: The specs of the managed task.
548
- callback_func: The callback function.
549
- """
550
- assert _SQLALCHEMY_ENGINE is not None
551
- # Use the timestamp in the `run_timestamp` ('sky-2022-10...'), to make
552
- # the log directory and submission time align with each other, so as to
553
- # make it easier to find them based on one of the values.
554
- # Also, using the earlier timestamp should be closer to the term
555
- # `submit_at`, which represents the time the managed task is submitted.
556
- logger.info('Launching the spot cluster...')
557
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
558
- count = session.query(spot_table).filter(
559
- sqlalchemy.and_(
560
- spot_table.c.spot_job_id == job_id,
561
- spot_table.c.task_id == task_id,
562
- spot_table.c.status == ManagedJobStatus.PENDING.value,
563
- spot_table.c.end_at.is_(None),
564
- )).update({
565
- spot_table.c.resources: resources_str,
566
- spot_table.c.submitted_at: submit_time,
567
- spot_table.c.status: ManagedJobStatus.STARTING.value,
568
- spot_table.c.run_timestamp: run_timestamp,
569
- spot_table.c.specs: json.dumps(specs),
570
- })
571
- session.commit()
572
- if count != 1:
573
- raise exceptions.ManagedJobStatusError(
574
- 'Failed to set the task to starting. '
575
- f'({count} rows updated)')
576
- # SUBMITTED is no longer used, but we keep it for backward compatibility.
577
- # TODO(cooperc): remove this in v0.12.0
578
- callback_func('SUBMITTED')
579
- callback_func('STARTING')
580
-
581
-
582
- @_init_db
583
- def set_backoff_pending(job_id: int, task_id: int):
607
+ @_init_db_async
608
+ async def set_backoff_pending_async(job_id: int, task_id: int):
584
609
  """Set the task to PENDING state if it is in backoff.
585
610
 
586
611
  This should only be used to transition from STARTING or RECOVERING back to
587
612
  PENDING.
588
613
  """
589
- assert _SQLALCHEMY_ENGINE is not None
590
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
591
- count = session.query(spot_table).filter(
592
- sqlalchemy.and_(
593
- spot_table.c.spot_job_id == job_id,
594
- spot_table.c.task_id == task_id,
595
- spot_table.c.status.in_([
596
- ManagedJobStatus.STARTING.value,
597
- ManagedJobStatus.RECOVERING.value
598
- ]),
599
- spot_table.c.end_at.is_(None),
600
- )).update({spot_table.c.status: ManagedJobStatus.PENDING.value})
601
- session.commit()
602
- logger.debug('back to PENDING')
614
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
615
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
616
+ count = await session.execute(
617
+ sqlalchemy.update(spot_table).where(
618
+ sqlalchemy.and_(
619
+ spot_table.c.spot_job_id == job_id,
620
+ spot_table.c.task_id == task_id,
621
+ spot_table.c.status.in_([
622
+ ManagedJobStatus.STARTING.value,
623
+ ManagedJobStatus.RECOVERING.value
624
+ ]),
625
+ spot_table.c.end_at.is_(None),
626
+ )).values({spot_table.c.status: ManagedJobStatus.PENDING.value})
627
+ )
628
+ await session.commit()
603
629
  if count != 1:
604
630
  raise exceptions.ManagedJobStatusError(
605
631
  'Failed to set the task back to pending. '
@@ -608,7 +634,7 @@ def set_backoff_pending(job_id: int, task_id: int):
608
634
 
609
635
 
610
636
  @_init_db
611
- def set_restarting(job_id: int, task_id: int, recovering: bool):
637
+ async def set_restarting_async(job_id: int, task_id: int, recovering: bool):
612
638
  """Set the task back to STARTING or RECOVERING from PENDING.
613
639
 
614
640
  This should not be used for the initial transition from PENDING to STARTING.
@@ -616,19 +642,20 @@ def set_restarting(job_id: int, task_id: int, recovering: bool):
616
642
  after using set_backoff_pending to transition back to PENDING during
617
643
  launch retry backoff.
618
644
  """
619
- assert _SQLALCHEMY_ENGINE is not None
645
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
620
646
  target_status = ManagedJobStatus.STARTING.value
621
647
  if recovering:
622
648
  target_status = ManagedJobStatus.RECOVERING.value
623
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
624
- count = session.query(spot_table).filter(
625
- sqlalchemy.and_(
626
- spot_table.c.spot_job_id == job_id,
627
- spot_table.c.task_id == task_id,
628
- spot_table.c.status == ManagedJobStatus.PENDING.value,
629
- spot_table.c.end_at.is_(None),
630
- )).update({spot_table.c.status: target_status})
631
- session.commit()
649
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
650
+ result = await session.execute(
651
+ sqlalchemy.update(spot_table).where(
652
+ sqlalchemy.and_(
653
+ spot_table.c.spot_job_id == job_id,
654
+ spot_table.c.task_id == task_id,
655
+ spot_table.c.end_at.is_(None),
656
+ )).values({spot_table.c.status: target_status}))
657
+ count = result.rowcount
658
+ await session.commit()
632
659
  logger.debug(f'back to {target_status}')
633
660
  if count != 1:
634
661
  raise exceptions.ManagedJobStatusError(
@@ -638,137 +665,6 @@ def set_restarting(job_id: int, task_id: int, recovering: bool):
638
665
  # initial (pre-`set_backoff_pending`) transition to STARTING or RECOVERING.
639
666
 
640
667
 
641
- @_init_db
642
- def set_started(job_id: int, task_id: int, start_time: float,
643
- callback_func: CallbackType):
644
- """Set the task to started state."""
645
- assert _SQLALCHEMY_ENGINE is not None
646
- logger.info('Job started.')
647
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
648
- count = session.query(spot_table).filter(
649
- sqlalchemy.and_(
650
- spot_table.c.spot_job_id == job_id,
651
- spot_table.c.task_id == task_id,
652
- spot_table.c.status.in_([
653
- ManagedJobStatus.STARTING.value,
654
- # If the task is empty, we will jump straight
655
- # from PENDING to RUNNING
656
- ManagedJobStatus.PENDING.value
657
- ]),
658
- spot_table.c.end_at.is_(None),
659
- )).update({
660
- spot_table.c.status: ManagedJobStatus.RUNNING.value,
661
- spot_table.c.start_at: start_time,
662
- spot_table.c.last_recovered_at: start_time,
663
- })
664
- session.commit()
665
- if count != 1:
666
- raise exceptions.ManagedJobStatusError(
667
- f'Failed to set the task to started. '
668
- f'({count} rows updated)')
669
- callback_func('STARTED')
670
-
671
-
672
- @_init_db
673
- def set_recovering(job_id: int, task_id: int, force_transit_to_recovering: bool,
674
- callback_func: CallbackType):
675
- """Set the task to recovering state, and update the job duration."""
676
- assert _SQLALCHEMY_ENGINE is not None
677
- logger.info('=== Recovering... ===')
678
- # NOTE: if we are resuming from a controller failure and the previous status
679
- # is STARTING, the initial value of `last_recovered_at` might not be set
680
- # yet (default value -1). In this case, we should not add current timestamp.
681
- # Otherwise, the job duration will be incorrect (~55 years from 1970).
682
- current_time = time.time()
683
-
684
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
685
- if force_transit_to_recovering:
686
- # For the HA job controller, it is possible that the jobs came from
687
- # any processing status to recovering. But it should not be any
688
- # terminal status as such jobs will not be recovered; and it should
689
- # not be CANCELLING as we will directly trigger a cleanup.
690
- status_condition = spot_table.c.status.in_(
691
- [s.value for s in ManagedJobStatus.processing_statuses()])
692
- else:
693
- status_condition = (
694
- spot_table.c.status == ManagedJobStatus.RUNNING.value)
695
-
696
- count = session.query(spot_table).filter(
697
- sqlalchemy.and_(
698
- spot_table.c.spot_job_id == job_id,
699
- spot_table.c.task_id == task_id,
700
- status_condition,
701
- spot_table.c.end_at.is_(None),
702
- )).update({
703
- spot_table.c.status: ManagedJobStatus.RECOVERING.value,
704
- spot_table.c.job_duration: sqlalchemy.case(
705
- (spot_table.c.last_recovered_at >= 0,
706
- spot_table.c.job_duration + current_time -
707
- spot_table.c.last_recovered_at),
708
- else_=spot_table.c.job_duration),
709
- spot_table.c.last_recovered_at: sqlalchemy.case(
710
- (spot_table.c.last_recovered_at < 0, current_time),
711
- else_=spot_table.c.last_recovered_at),
712
- })
713
- session.commit()
714
- if count != 1:
715
- raise exceptions.ManagedJobStatusError(
716
- f'Failed to set the task to recovering. '
717
- f'({count} rows updated)')
718
- callback_func('RECOVERING')
719
-
720
-
721
- @_init_db
722
- def set_recovered(job_id: int, task_id: int, recovered_time: float,
723
- callback_func: CallbackType):
724
- """Set the task to recovered."""
725
- assert _SQLALCHEMY_ENGINE is not None
726
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
727
- count = session.query(spot_table).filter(
728
- sqlalchemy.and_(
729
- spot_table.c.spot_job_id == job_id,
730
- spot_table.c.task_id == task_id,
731
- spot_table.c.status == ManagedJobStatus.RECOVERING.value,
732
- spot_table.c.end_at.is_(None),
733
- )).update({
734
- spot_table.c.status: ManagedJobStatus.RUNNING.value,
735
- spot_table.c.last_recovered_at: recovered_time,
736
- spot_table.c.recovery_count: spot_table.c.recovery_count + 1,
737
- })
738
- session.commit()
739
- if count != 1:
740
- raise exceptions.ManagedJobStatusError(
741
- f'Failed to set the task to recovered. '
742
- f'({count} rows updated)')
743
- logger.info('==== Recovered. ====')
744
- callback_func('RECOVERED')
745
-
746
-
747
- @_init_db
748
- def set_succeeded(job_id: int, task_id: int, end_time: float,
749
- callback_func: CallbackType):
750
- """Set the task to succeeded, if it is in a non-terminal state."""
751
- assert _SQLALCHEMY_ENGINE is not None
752
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
753
- count = session.query(spot_table).filter(
754
- sqlalchemy.and_(
755
- spot_table.c.spot_job_id == job_id,
756
- spot_table.c.task_id == task_id,
757
- spot_table.c.status == ManagedJobStatus.RUNNING.value,
758
- spot_table.c.end_at.is_(None),
759
- )).update({
760
- spot_table.c.status: ManagedJobStatus.SUCCEEDED.value,
761
- spot_table.c.end_at: end_time,
762
- })
763
- session.commit()
764
- if count != 1:
765
- raise exceptions.ManagedJobStatusError(
766
- f'Failed to set the task to succeeded. '
767
- f'({count} rows updated)')
768
- callback_func('SUCCEEDED')
769
- logger.info('Job succeeded.')
770
-
771
-
772
668
  @_init_db
773
669
  def set_failed(
774
670
  job_id: int,
@@ -834,51 +730,35 @@ def set_failed(
834
730
 
835
731
 
836
732
  @_init_db
837
- def set_cancelling(job_id: int, callback_func: CallbackType):
838
- """Set tasks in the job as cancelling, if they are in non-terminal states.
839
-
840
- task_id is not needed, because we expect the job should be cancelled
841
- as a whole, and we should not cancel a single task.
842
- """
733
+ def set_pending_cancelled(job_id: int):
734
+ """Set the job as pending cancelled, if it is in non-terminal states."""
843
735
  assert _SQLALCHEMY_ENGINE is not None
844
736
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
845
- count = session.query(spot_table).filter(
846
- sqlalchemy.and_(
847
- spot_table.c.spot_job_id == job_id,
848
- spot_table.c.end_at.is_(None),
849
- )).update({spot_table.c.status: ManagedJobStatus.CANCELLING.value})
850
- session.commit()
851
- updated = count > 0
852
- if updated:
853
- logger.info('Cancelling the job...')
854
- callback_func('CANCELLING')
855
- else:
856
- logger.info('Cancellation skipped, job is already terminal')
857
-
858
-
859
- @_init_db
860
- def set_cancelled(job_id: int, callback_func: CallbackType):
861
- """Set tasks in the job as cancelled, if they are in CANCELLING state.
737
+ # Subquery to get the spot_job_ids that match the joined condition
738
+ subquery = session.query(spot_table.c.job_id).join(
739
+ job_info_table,
740
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id
741
+ ).filter(
742
+ spot_table.c.spot_job_id == job_id,
743
+ spot_table.c.status == ManagedJobStatus.PENDING.value,
744
+ # Note: it's possible that a WAITING job actually needs to be
745
+ # cleaned up, if we are in the middle of an upgrade/recovery and
746
+ # the job is waiting to be reclaimed by a new controller. But,
747
+ # in this case the status will not be PENDING.
748
+ sqlalchemy.or_(
749
+ job_info_table.c.schedule_state ==
750
+ ManagedJobScheduleState.WAITING.value,
751
+ job_info_table.c.schedule_state ==
752
+ ManagedJobScheduleState.INACTIVE.value,
753
+ ),
754
+ ).subquery()
862
755
 
863
- The set_cancelling should be called before this function.
864
- """
865
- assert _SQLALCHEMY_ENGINE is not None
866
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
867
756
  count = session.query(spot_table).filter(
868
- sqlalchemy.and_(
869
- spot_table.c.spot_job_id == job_id,
870
- spot_table.c.status == ManagedJobStatus.CANCELLING.value,
871
- )).update({
872
- spot_table.c.status: ManagedJobStatus.CANCELLED.value,
873
- spot_table.c.end_at: time.time(),
874
- })
757
+ spot_table.c.job_id.in_(subquery)).update(
758
+ {spot_table.c.status: ManagedJobStatus.CANCELLED.value},
759
+ synchronize_session=False)
875
760
  session.commit()
876
- updated = count > 0
877
- if updated:
878
- logger.info('Job cancelled.')
879
- callback_func('CANCELLED')
880
- else:
881
- logger.info('Cancellation skipped, job is not CANCELLING')
761
+ return count > 0
882
762
 
883
763
 
884
764
  @_init_db
@@ -936,45 +816,6 @@ def get_nonterminal_job_ids_by_name(name: Optional[str],
936
816
  return job_ids
937
817
 
938
818
 
939
- @_init_db
940
- def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]:
941
- """Get jobs from the database that have a live schedule_state.
942
-
943
- This should return job(s) that are not INACTIVE, WAITING, or DONE. So a
944
- returned job should correspond to a live job controller process, with one
945
- exception: the job may have just transitioned from WAITING to LAUNCHING, but
946
- the controller process has not yet started.
947
- """
948
- assert _SQLALCHEMY_ENGINE is not None
949
-
950
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
951
- query = sqlalchemy.select(
952
- job_info_table.c.spot_job_id,
953
- job_info_table.c.schedule_state,
954
- job_info_table.c.controller_pid,
955
- ).where(~job_info_table.c.schedule_state.in_([
956
- ManagedJobScheduleState.INACTIVE.value,
957
- ManagedJobScheduleState.WAITING.value,
958
- ManagedJobScheduleState.DONE.value,
959
- ]))
960
-
961
- if job_id is not None:
962
- query = query.where(job_info_table.c.spot_job_id == job_id)
963
-
964
- query = query.order_by(job_info_table.c.spot_job_id.desc())
965
-
966
- rows = session.execute(query).fetchall()
967
- jobs = []
968
- for row in rows:
969
- job_dict = {
970
- 'job_id': row[0],
971
- 'schedule_state': ManagedJobScheduleState(row[1]),
972
- 'controller_pid': row[2],
973
- }
974
- jobs.append(job_dict)
975
- return jobs
976
-
977
-
978
819
  @_init_db
979
820
  def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
980
821
  """Get jobs that need controller process checking.
@@ -1035,32 +876,6 @@ def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
1035
876
  return [row[0] for row in rows if row[0] is not None]
1036
877
 
1037
878
 
1038
- @_init_db
1039
- def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
1040
- """Get all job ids by name."""
1041
- assert _SQLALCHEMY_ENGINE is not None
1042
-
1043
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1044
- query = sqlalchemy.select(
1045
- spot_table.c.spot_job_id.distinct()).select_from(
1046
- spot_table.outerjoin(
1047
- job_info_table,
1048
- spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
1049
- if name is not None:
1050
- # We match the job name from `job_info` for the jobs submitted after
1051
- # #1982, and from `spot` for the jobs submitted before #1982, whose
1052
- # job_info is not available.
1053
- name_condition = sqlalchemy.or_(
1054
- job_info_table.c.name == name,
1055
- sqlalchemy.and_(job_info_table.c.name.is_(None),
1056
- spot_table.c.task_name == name))
1057
- query = query.where(name_condition)
1058
- query = query.order_by(spot_table.c.spot_job_id.desc())
1059
- rows = session.execute(query).fetchall()
1060
- job_ids = [row[0] for row in rows if row[0] is not None]
1061
- return job_ids
1062
-
1063
-
1064
879
  @_init_db
1065
880
  def _get_all_task_ids_statuses(
1066
881
  job_id: int) -> List[Tuple[int, ManagedJobStatus]]:
@@ -1092,18 +907,6 @@ def get_all_task_ids_names_statuses_logs(
1092
907
  for row in id_names]
1093
908
 
1094
909
 
1095
- @_init_db
1096
- def get_job_status_with_task_id(job_id: int,
1097
- task_id: int) -> Optional[ManagedJobStatus]:
1098
- assert _SQLALCHEMY_ENGINE is not None
1099
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1100
- status = session.execute(
1101
- sqlalchemy.select(spot_table.c.status).where(
1102
- sqlalchemy.and_(spot_table.c.spot_job_id == job_id,
1103
- spot_table.c.task_id == task_id))).fetchone()
1104
- return ManagedJobStatus(status[0]) if status else None
1105
-
1106
-
1107
910
  def get_num_tasks(job_id: int) -> int:
1108
911
  return len(_get_all_task_ids_statuses(job_id))
1109
912
 
@@ -1131,6 +934,16 @@ def get_latest_task_id_status(
1131
934
  return task_id, status
1132
935
 
1133
936
 
937
+ @_init_db
938
+ def get_job_controller_pid(job_id: int) -> Optional[int]:
939
+ assert _SQLALCHEMY_ENGINE is not None
940
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
941
+ pid = session.execute(
942
+ sqlalchemy.select(job_info_table.c.controller_pid).where(
943
+ job_info_table.c.spot_job_id == job_id)).fetchone()
944
+ return pid[0] if pid else None
945
+
946
+
1134
947
  def get_status(job_id: int) -> Optional[ManagedJobStatus]:
1135
948
  _, status = get_latest_task_id_status(job_id)
1136
949
  return status
@@ -1242,30 +1055,10 @@ def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
1242
1055
  return json.loads(task_specs[0])
1243
1056
 
1244
1057
 
1245
- @_init_db
1246
- def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
1247
- """Get the local log directory for a job."""
1248
- assert _SQLALCHEMY_ENGINE is not None
1249
-
1250
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1251
- where_conditions = [spot_table.c.spot_job_id == job_id]
1252
- if task_id is not None:
1253
- where_conditions.append(spot_table.c.task_id == task_id)
1254
- local_log_file = session.execute(
1255
- sqlalchemy.select(spot_table.c.local_log_file).where(
1256
- sqlalchemy.and_(*where_conditions))).fetchone()
1257
- return local_log_file[-1] if local_log_file else None
1258
-
1259
-
1260
- # === Scheduler state functions ===
1261
- # Only the scheduler should call these functions. They may require holding the
1262
- # scheduler lock to work correctly.
1263
-
1264
-
1265
1058
  @_init_db
1266
1059
  def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
1267
1060
  original_user_yaml_path: str, env_file_path: str,
1268
- user_hash: str, priority: int) -> bool:
1061
+ user_hash: str, priority: int):
1269
1062
  """Do not call without holding the scheduler lock.
1270
1063
 
1271
1064
  Returns: Whether this is a recovery run or not.
@@ -1277,11 +1070,7 @@ def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
1277
1070
  assert _SQLALCHEMY_ENGINE is not None
1278
1071
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
1279
1072
  updated_count = session.query(job_info_table).filter(
1280
- sqlalchemy.and_(
1281
- job_info_table.c.spot_job_id == job_id,
1282
- job_info_table.c.schedule_state ==
1283
- ManagedJobScheduleState.INACTIVE.value,
1284
- )
1073
+ sqlalchemy.and_(job_info_table.c.spot_job_id == job_id,)
1285
1074
  ).update({
1286
1075
  job_info_table.c.schedule_state:
1287
1076
  ManagedJobScheduleState.WAITING.value,
@@ -1292,9 +1081,7 @@ def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
1292
1081
  job_info_table.c.priority: priority,
1293
1082
  })
1294
1083
  session.commit()
1295
- # For a recovery run, the job may already be in the WAITING state.
1296
1084
  assert updated_count <= 1, (job_id, updated_count)
1297
- return updated_count == 0
1298
1085
 
1299
1086
 
1300
1087
  @_init_db
@@ -1319,17 +1106,18 @@ def set_current_cluster_name(job_id: int, current_cluster_name: str) -> None:
1319
1106
  session.commit()
1320
1107
 
1321
1108
 
1322
- @_init_db
1323
- def set_job_id_on_pool_cluster(job_id: int,
1324
- job_id_on_pool_cluster: int) -> None:
1109
+ @_init_db_async
1110
+ async def set_job_id_on_pool_cluster_async(job_id: int,
1111
+ job_id_on_pool_cluster: int) -> None:
1325
1112
  """Set the job id on the pool cluster for a job."""
1326
- assert _SQLALCHEMY_ENGINE is not None
1327
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1328
- session.query(job_info_table).filter(
1329
- job_info_table.c.spot_job_id == job_id).update({
1113
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1114
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1115
+ await session.execute(
1116
+ sqlalchemy.update(job_info_table).
1117
+ where(job_info_table.c.spot_job_id == job_id).values({
1330
1118
  job_info_table.c.job_id_on_pool_cluster: job_id_on_pool_cluster
1331
- })
1332
- session.commit()
1119
+ }))
1120
+ await session.commit()
1333
1121
 
1334
1122
 
1335
1123
  @_init_db
@@ -1347,77 +1135,54 @@ def get_pool_submit_info(job_id: int) -> Tuple[Optional[str], Optional[int]]:
1347
1135
  return info[0], info[1]
1348
1136
 
1349
1137
 
1350
- @_init_db
1351
- def scheduler_set_launching(job_id: int,
1352
- current_state: ManagedJobScheduleState) -> None:
1353
- """Do not call without holding the scheduler lock."""
1354
- assert _SQLALCHEMY_ENGINE is not None
1355
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1356
- updated_count = session.query(job_info_table).filter(
1357
- sqlalchemy.and_(
1358
- job_info_table.c.spot_job_id == job_id,
1359
- job_info_table.c.schedule_state == current_state.value,
1360
- )).update({
1361
- job_info_table.c.schedule_state:
1362
- ManagedJobScheduleState.LAUNCHING.value
1363
- })
1364
- session.commit()
1365
- assert updated_count == 1, (job_id, updated_count)
1366
-
1367
-
1368
- @_init_db
1369
- def scheduler_set_alive(job_id: int) -> None:
1370
- """Do not call without holding the scheduler lock."""
1371
- assert _SQLALCHEMY_ENGINE is not None
1372
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1373
- updated_count = session.query(job_info_table).filter(
1374
- sqlalchemy.and_(
1375
- job_info_table.c.spot_job_id == job_id,
1376
- job_info_table.c.schedule_state ==
1377
- ManagedJobScheduleState.LAUNCHING.value,
1378
- )).update({
1379
- job_info_table.c.schedule_state:
1380
- ManagedJobScheduleState.ALIVE.value
1381
- })
1382
- session.commit()
1383
- assert updated_count == 1, (job_id, updated_count)
1138
+ @_init_db_async
1139
+ async def get_pool_submit_info_async(
1140
+ job_id: int) -> Tuple[Optional[str], Optional[int]]:
1141
+ """Get the cluster name and job id on the pool from the managed job id."""
1142
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1143
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1144
+ result = await session.execute(
1145
+ sqlalchemy.select(job_info_table.c.current_cluster_name,
1146
+ job_info_table.c.job_id_on_pool_cluster).where(
1147
+ job_info_table.c.spot_job_id == job_id))
1148
+ info = result.fetchone()
1149
+ if info is None:
1150
+ return None, None
1151
+ return info[0], info[1]
1384
1152
 
1385
1153
 
1386
- @_init_db
1387
- def scheduler_set_alive_backoff(job_id: int) -> None:
1388
- """Do not call without holding the scheduler lock."""
1389
- assert _SQLALCHEMY_ENGINE is not None
1390
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1391
- updated_count = session.query(job_info_table).filter(
1392
- sqlalchemy.and_(
1393
- job_info_table.c.spot_job_id == job_id,
1394
- job_info_table.c.schedule_state ==
1395
- ManagedJobScheduleState.LAUNCHING.value,
1396
- )).update({
1397
- job_info_table.c.schedule_state:
1398
- ManagedJobScheduleState.ALIVE_BACKOFF.value
1399
- })
1400
- session.commit()
1401
- assert updated_count == 1, (job_id, updated_count)
1154
+ @_init_db_async
1155
+ async def scheduler_set_launching_async(job_id: int):
1156
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1157
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1158
+ await session.execute(
1159
+ sqlalchemy.update(job_info_table).where(
1160
+ sqlalchemy.and_(job_info_table.c.spot_job_id == job_id)).values(
1161
+ {
1162
+ job_info_table.c.schedule_state:
1163
+ ManagedJobScheduleState.LAUNCHING.value
1164
+ }))
1165
+ await session.commit()
1402
1166
 
1403
1167
 
1404
- @_init_db
1405
- def scheduler_set_alive_waiting(job_id: int) -> None:
1168
+ @_init_db_async
1169
+ async def scheduler_set_alive_async(job_id: int) -> None:
1406
1170
  """Do not call without holding the scheduler lock."""
1407
- assert _SQLALCHEMY_ENGINE is not None
1408
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1409
- updated_count = session.query(job_info_table).filter(
1410
- sqlalchemy.and_(
1411
- job_info_table.c.spot_job_id == job_id,
1412
- job_info_table.c.schedule_state.in_([
1413
- ManagedJobScheduleState.ALIVE.value,
1414
- ManagedJobScheduleState.ALIVE_BACKOFF.value,
1415
- ]))).update({
1171
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1172
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1173
+ result = await session.execute(
1174
+ sqlalchemy.update(job_info_table).where(
1175
+ sqlalchemy.and_(
1176
+ job_info_table.c.spot_job_id == job_id,
1177
+ job_info_table.c.schedule_state ==
1178
+ ManagedJobScheduleState.LAUNCHING.value,
1179
+ )).values({
1416
1180
  job_info_table.c.schedule_state:
1417
- ManagedJobScheduleState.ALIVE_WAITING.value
1418
- })
1419
- session.commit()
1420
- assert updated_count == 1, (job_id, updated_count)
1181
+ ManagedJobScheduleState.ALIVE.value
1182
+ }))
1183
+ changes = result.rowcount
1184
+ await session.commit()
1185
+ assert changes == 1, (job_id, changes)
1421
1186
 
1422
1187
 
1423
1188
  @_init_db
@@ -1439,16 +1204,6 @@ def scheduler_set_done(job_id: int, idempotent: bool = False) -> None:
1439
1204
  assert updated_count == 1, (job_id, updated_count)
1440
1205
 
1441
1206
 
1442
- @_init_db
1443
- def set_job_controller_pid(job_id: int, pid: int):
1444
- assert _SQLALCHEMY_ENGINE is not None
1445
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1446
- updated_count = session.query(job_info_table).filter_by(
1447
- spot_job_id=job_id).update({job_info_table.c.controller_pid: pid})
1448
- session.commit()
1449
- assert updated_count == 1, (job_id, updated_count)
1450
-
1451
-
1452
1207
  @_init_db
1453
1208
  def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState:
1454
1209
  assert _SQLALCHEMY_ENGINE is not None
@@ -1527,58 +1282,78 @@ def get_nonterminal_job_ids_by_pool(pool: str,
1527
1282
  return job_ids
1528
1283
 
1529
1284
 
1530
- @_init_db
1531
- def get_waiting_job() -> Optional[Dict[str, Any]]:
1285
+ @_init_db_async
1286
+ async def get_waiting_job_async(pid: int) -> Optional[Dict[str, Any]]:
1532
1287
  """Get the next job that should transition to LAUNCHING.
1533
1288
 
1534
- Selects the highest-priority WAITING or ALIVE_WAITING job, provided its
1535
- priority is greater than or equal to any currently LAUNCHING or
1536
- ALIVE_BACKOFF job.
1289
+ Selects the highest-priority WAITING or ALIVE_WAITING job and atomically
1290
+ transitions it to LAUNCHING state to prevent race conditions.
1291
+
1292
+ Returns the job information if a job was successfully transitioned to
1293
+ LAUNCHING, or None if no suitable job was found.
1537
1294
 
1538
1295
  Backwards compatibility note: jobs submitted before #4485 will have no
1539
1296
  schedule_state and will be ignored by this SQL query.
1540
1297
  """
1541
- assert _SQLALCHEMY_ENGINE is not None
1542
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1543
- # Get the highest-priority WAITING or ALIVE_WAITING job whose priority
1544
- # is greater than or equal to the highest priority LAUNCHING or
1545
- # ALIVE_BACKOFF job's priority.
1546
- # First, get the max priority of LAUNCHING or ALIVE_BACKOFF jobs
1547
- max_priority_subquery = sqlalchemy.select(
1548
- sqlalchemy.func.max(job_info_table.c.priority)).where(
1549
- job_info_table.c.schedule_state.in_([
1550
- ManagedJobScheduleState.LAUNCHING.value,
1551
- ManagedJobScheduleState.ALIVE_BACKOFF.value,
1552
- ])).scalar_subquery()
1553
- # Main query for waiting jobs
1554
- select_conds = [
1555
- job_info_table.c.schedule_state.in_([
1556
- ManagedJobScheduleState.WAITING.value,
1557
- ManagedJobScheduleState.ALIVE_WAITING.value,
1558
- ]),
1559
- job_info_table.c.priority >= sqlalchemy.func.coalesce(
1560
- max_priority_subquery, 0),
1561
- ]
1562
- query = sqlalchemy.select(
1298
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1299
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1300
+ # Select the highest priority waiting job for update (locks the row)
1301
+ select_query = sqlalchemy.select(
1563
1302
  job_info_table.c.spot_job_id,
1564
1303
  job_info_table.c.schedule_state,
1565
1304
  job_info_table.c.dag_yaml_path,
1566
1305
  job_info_table.c.env_file_path,
1306
+ job_info_table.c.controller_pid,
1567
1307
  job_info_table.c.pool,
1568
- ).where(sqlalchemy.and_(*select_conds)).order_by(
1569
- job_info_table.c.priority.desc(),
1570
- job_info_table.c.spot_job_id.asc(),
1571
- ).limit(1)
1572
- waiting_job_row = session.execute(query).fetchone()
1308
+ ).where(
1309
+ job_info_table.c.schedule_state.in_([
1310
+ ManagedJobScheduleState.WAITING.value,
1311
+ ])).order_by(
1312
+ job_info_table.c.priority.desc(),
1313
+ job_info_table.c.spot_job_id.asc(),
1314
+ ).limit(1).with_for_update()
1315
+
1316
+ # Execute the select with row locking
1317
+ result = await session.execute(select_query)
1318
+ waiting_job_row = result.fetchone()
1319
+
1573
1320
  if waiting_job_row is None:
1574
1321
  return None
1575
1322
 
1323
+ job_id = waiting_job_row[0]
1324
+ current_state = ManagedJobScheduleState(waiting_job_row[1])
1325
+ dag_yaml_path = waiting_job_row[2]
1326
+ env_file_path = waiting_job_row[3]
1327
+ controller_pid = waiting_job_row[4]
1328
+ pool = waiting_job_row[5]
1329
+
1330
+ # Update the job state to LAUNCHING
1331
+ update_result = await session.execute(
1332
+ sqlalchemy.update(job_info_table).where(
1333
+ sqlalchemy.and_(
1334
+ job_info_table.c.spot_job_id == job_id,
1335
+ job_info_table.c.schedule_state == current_state.value,
1336
+ )).values({
1337
+ job_info_table.c.schedule_state:
1338
+ ManagedJobScheduleState.LAUNCHING.value,
1339
+ job_info_table.c.controller_pid: pid,
1340
+ }))
1341
+
1342
+ if update_result.rowcount != 1:
1343
+ # Update failed, rollback and return None
1344
+ await session.rollback()
1345
+ return None
1346
+
1347
+ # Commit the transaction
1348
+ await session.commit()
1349
+
1576
1350
  return {
1577
- 'job_id': waiting_job_row[0],
1578
- 'schedule_state': ManagedJobScheduleState(waiting_job_row[1]),
1579
- 'dag_yaml_path': waiting_job_row[2],
1580
- 'env_file_path': waiting_job_row[3],
1581
- 'pool': waiting_job_row[4],
1351
+ 'job_id': job_id,
1352
+ 'schedule_state': current_state,
1353
+ 'dag_yaml_path': dag_yaml_path,
1354
+ 'env_file_path': env_file_path,
1355
+ 'old_pid': controller_pid,
1356
+ 'pool': pool,
1582
1357
  }
1583
1358
 
1584
1359
 
@@ -1641,3 +1416,429 @@ def remove_ha_recovery_script(job_id: int) -> None:
1641
1416
  session.query(ha_recovery_script_table).filter_by(
1642
1417
  job_id=job_id).delete()
1643
1418
  session.commit()
1419
+
1420
+
1421
+ @_init_db_async
1422
+ async def get_latest_task_id_status_async(
1423
+ job_id: int) -> Union[Tuple[int, ManagedJobStatus], Tuple[None, None]]:
1424
+ """Returns the (task id, status) of the latest task of a job."""
1425
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1426
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1427
+ result = await session.execute(
1428
+ sqlalchemy.select(
1429
+ spot_table.c.task_id,
1430
+ spot_table.c.status,
1431
+ ).where(spot_table.c.spot_job_id == job_id).order_by(
1432
+ spot_table.c.task_id.asc()))
1433
+ id_statuses = [
1434
+ (row[0], ManagedJobStatus(row[1])) for row in result.fetchall()
1435
+ ]
1436
+
1437
+ if not id_statuses:
1438
+ return None, None
1439
+ task_id, status = next(
1440
+ ((tid, st) for tid, st in id_statuses if not st.is_terminal()),
1441
+ id_statuses[-1],
1442
+ )
1443
+ return task_id, status
1444
+
1445
+
1446
+ @_init_db_async
1447
+ async def set_starting_async(job_id: int, task_id: int, run_timestamp: str,
1448
+ submit_time: float, resources_str: str,
1449
+ specs: Dict[str, Union[str, int]],
1450
+ callback_func: AsyncCallbackType):
1451
+ """Set the task to starting state."""
1452
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1453
+ logger.info('Launching the spot cluster...')
1454
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1455
+ result = await session.execute(
1456
+ sqlalchemy.update(spot_table).where(
1457
+ sqlalchemy.and_(
1458
+ spot_table.c.spot_job_id == job_id,
1459
+ spot_table.c.task_id == task_id,
1460
+ spot_table.c.status == ManagedJobStatus.PENDING.value,
1461
+ spot_table.c.end_at.is_(None),
1462
+ )).values({
1463
+ spot_table.c.resources: resources_str,
1464
+ spot_table.c.submitted_at: submit_time,
1465
+ spot_table.c.status: ManagedJobStatus.STARTING.value,
1466
+ spot_table.c.run_timestamp: run_timestamp,
1467
+ spot_table.c.specs: json.dumps(specs),
1468
+ }))
1469
+ count = result.rowcount
1470
+ await session.commit()
1471
+ if count != 1:
1472
+ raise exceptions.ManagedJobStatusError(
1473
+ 'Failed to set the task to starting. '
1474
+ f'({count} rows updated)')
1475
+ await callback_func('SUBMITTED')
1476
+ await callback_func('STARTING')
1477
+
1478
+
1479
+ @_init_db_async
1480
+ async def set_started_async(job_id: int, task_id: int, start_time: float,
1481
+ callback_func: AsyncCallbackType):
1482
+ """Set the task to started state."""
1483
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1484
+ logger.info('Job started.')
1485
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1486
+ result = await session.execute(
1487
+ sqlalchemy.update(spot_table).where(
1488
+ sqlalchemy.and_(
1489
+ spot_table.c.spot_job_id == job_id,
1490
+ spot_table.c.task_id == task_id,
1491
+ spot_table.c.status.in_([
1492
+ ManagedJobStatus.STARTING.value,
1493
+ ManagedJobStatus.PENDING.value
1494
+ ]),
1495
+ spot_table.c.end_at.is_(None),
1496
+ )).values({
1497
+ spot_table.c.status: ManagedJobStatus.RUNNING.value,
1498
+ spot_table.c.start_at: start_time,
1499
+ spot_table.c.last_recovered_at: start_time,
1500
+ }))
1501
+ count = result.rowcount
1502
+ await session.commit()
1503
+ if count != 1:
1504
+ raise exceptions.ManagedJobStatusError(
1505
+ f'Failed to set the task to started. '
1506
+ f'({count} rows updated)')
1507
+ await callback_func('STARTED')
1508
+
1509
+
1510
+ @_init_db_async
1511
+ async def get_job_status_with_task_id_async(
1512
+ job_id: int, task_id: int) -> Optional[ManagedJobStatus]:
1513
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1514
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1515
+ result = await session.execute(
1516
+ sqlalchemy.select(spot_table.c.status).where(
1517
+ sqlalchemy.and_(spot_table.c.spot_job_id == job_id,
1518
+ spot_table.c.task_id == task_id)))
1519
+ status = result.fetchone()
1520
+ return ManagedJobStatus(status[0]) if status else None
1521
+
1522
+
1523
+ @_init_db_async
1524
+ async def set_recovering_async(job_id: int, task_id: int,
1525
+ force_transit_to_recovering: bool,
1526
+ callback_func: AsyncCallbackType):
1527
+ """Set the task to recovering state, and update the job duration."""
1528
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1529
+ logger.info('=== Recovering... ===')
1530
+ current_time = time.time()
1531
+
1532
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1533
+ if force_transit_to_recovering:
1534
+ status_condition = spot_table.c.status.in_(
1535
+ [s.value for s in ManagedJobStatus.processing_statuses()])
1536
+ else:
1537
+ status_condition = (
1538
+ spot_table.c.status == ManagedJobStatus.RUNNING.value)
1539
+
1540
+ result = await session.execute(
1541
+ sqlalchemy.update(spot_table).where(
1542
+ sqlalchemy.and_(
1543
+ spot_table.c.spot_job_id == job_id,
1544
+ spot_table.c.task_id == task_id,
1545
+ status_condition,
1546
+ spot_table.c.end_at.is_(None),
1547
+ )).values({
1548
+ spot_table.c.status: ManagedJobStatus.RECOVERING.value,
1549
+ spot_table.c.job_duration: sqlalchemy.case(
1550
+ (spot_table.c.last_recovered_at >= 0,
1551
+ spot_table.c.job_duration + current_time -
1552
+ spot_table.c.last_recovered_at),
1553
+ else_=spot_table.c.job_duration),
1554
+ spot_table.c.last_recovered_at: sqlalchemy.case(
1555
+ (spot_table.c.last_recovered_at < 0, current_time),
1556
+ else_=spot_table.c.last_recovered_at),
1557
+ }))
1558
+ count = result.rowcount
1559
+ await session.commit()
1560
+ if count != 1:
1561
+ raise exceptions.ManagedJobStatusError(
1562
+ f'Failed to set the task to recovering. '
1563
+ f'({count} rows updated)')
1564
+ await callback_func('RECOVERING')
1565
+
1566
+
1567
+ @_init_db_async
1568
+ async def set_recovered_async(job_id: int, task_id: int, recovered_time: float,
1569
+ callback_func: AsyncCallbackType):
1570
+ """Set the task to recovered."""
1571
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1572
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1573
+ result = await session.execute(
1574
+ sqlalchemy.update(spot_table).where(
1575
+ sqlalchemy.and_(
1576
+ spot_table.c.spot_job_id == job_id,
1577
+ spot_table.c.task_id == task_id,
1578
+ spot_table.c.status == ManagedJobStatus.RECOVERING.value,
1579
+ spot_table.c.end_at.is_(None),
1580
+ )).values({
1581
+ spot_table.c.status: ManagedJobStatus.RUNNING.value,
1582
+ spot_table.c.last_recovered_at: recovered_time,
1583
+ spot_table.c.recovery_count: spot_table.c.recovery_count +
1584
+ 1,
1585
+ }))
1586
+ count = result.rowcount
1587
+ await session.commit()
1588
+ if count != 1:
1589
+ raise exceptions.ManagedJobStatusError(
1590
+ f'Failed to set the task to recovered. '
1591
+ f'({count} rows updated)')
1592
+ logger.info('==== Recovered. ====')
1593
+ await callback_func('RECOVERED')
1594
+
1595
+
1596
+ @_init_db_async
1597
+ async def set_succeeded_async(job_id: int, task_id: int, end_time: float,
1598
+ callback_func: AsyncCallbackType):
1599
+ """Set the task to succeeded, if it is in a non-terminal state."""
1600
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1601
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1602
+ result = await session.execute(
1603
+ sqlalchemy.update(spot_table).where(
1604
+ sqlalchemy.and_(
1605
+ spot_table.c.spot_job_id == job_id,
1606
+ spot_table.c.task_id == task_id,
1607
+ spot_table.c.status == ManagedJobStatus.RUNNING.value,
1608
+ spot_table.c.end_at.is_(None),
1609
+ )).values({
1610
+ spot_table.c.status: ManagedJobStatus.SUCCEEDED.value,
1611
+ spot_table.c.end_at: end_time,
1612
+ }))
1613
+ count = result.rowcount
1614
+ await session.commit()
1615
+ if count != 1:
1616
+ raise exceptions.ManagedJobStatusError(
1617
+ f'Failed to set the task to succeeded. '
1618
+ f'({count} rows updated)')
1619
+ await callback_func('SUCCEEDED')
1620
+ logger.info('Job succeeded.')
1621
+
1622
+
1623
+ @_init_db_async
1624
+ async def set_failed_async(
1625
+ job_id: int,
1626
+ task_id: Optional[int],
1627
+ failure_type: ManagedJobStatus,
1628
+ failure_reason: str,
1629
+ callback_func: Optional[AsyncCallbackType] = None,
1630
+ end_time: Optional[float] = None,
1631
+ override_terminal: bool = False,
1632
+ ):
1633
+ """Set an entire job or task to failed."""
1634
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1635
+ assert failure_type.is_failed(), failure_type
1636
+ end_time = time.time() if end_time is None else end_time
1637
+
1638
+ fields_to_set: Dict[str, Any] = {
1639
+ spot_table.c.status: failure_type.value,
1640
+ spot_table.c.failure_reason: failure_reason,
1641
+ }
1642
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1643
+ # Get previous status
1644
+ result = await session.execute(
1645
+ sqlalchemy.select(
1646
+ spot_table.c.status).where(spot_table.c.spot_job_id == job_id))
1647
+ previous_status_row = result.fetchone()
1648
+ previous_status = ManagedJobStatus(previous_status_row[0])
1649
+ if previous_status == ManagedJobStatus.RECOVERING:
1650
+ fields_to_set[spot_table.c.last_recovered_at] = end_time
1651
+ where_conditions = [spot_table.c.spot_job_id == job_id]
1652
+ if task_id is not None:
1653
+ where_conditions.append(spot_table.c.task_id == task_id)
1654
+ if override_terminal:
1655
+ fields_to_set[spot_table.c.end_at] = sqlalchemy.func.coalesce(
1656
+ spot_table.c.end_at, end_time)
1657
+ else:
1658
+ fields_to_set[spot_table.c.end_at] = end_time
1659
+ where_conditions.append(spot_table.c.end_at.is_(None))
1660
+ result = await session.execute(
1661
+ sqlalchemy.update(spot_table).where(
1662
+ sqlalchemy.and_(*where_conditions)).values(fields_to_set))
1663
+ count = result.rowcount
1664
+ await session.commit()
1665
+ updated = count > 0
1666
+ if callback_func and updated:
1667
+ await callback_func('FAILED')
1668
+ logger.info(failure_reason)
1669
+
1670
+
1671
+ @_init_db_async
1672
+ async def set_cancelling_async(job_id: int, callback_func: AsyncCallbackType):
1673
+ """Set tasks in the job as cancelling, if they are in non-terminal
1674
+ states."""
1675
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1676
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1677
+ result = await session.execute(
1678
+ sqlalchemy.update(spot_table).where(
1679
+ sqlalchemy.and_(
1680
+ spot_table.c.spot_job_id == job_id,
1681
+ spot_table.c.end_at.is_(None),
1682
+ )).values(
1683
+ {spot_table.c.status: ManagedJobStatus.CANCELLING.value}))
1684
+ count = result.rowcount
1685
+ await session.commit()
1686
+ updated = count > 0
1687
+ if updated:
1688
+ logger.info('Cancelling the job...')
1689
+ await callback_func('CANCELLING')
1690
+ else:
1691
+ logger.info('Cancellation skipped, job is already terminal')
1692
+
1693
+
1694
+ @_init_db_async
1695
+ async def set_cancelled_async(job_id: int, callback_func: AsyncCallbackType):
1696
+ """Set tasks in the job as cancelled, if they are in CANCELLING state."""
1697
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1698
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1699
+ result = await session.execute(
1700
+ sqlalchemy.update(spot_table).where(
1701
+ sqlalchemy.and_(
1702
+ spot_table.c.spot_job_id == job_id,
1703
+ spot_table.c.status == ManagedJobStatus.CANCELLING.value,
1704
+ )).values({
1705
+ spot_table.c.status: ManagedJobStatus.CANCELLED.value,
1706
+ spot_table.c.end_at: time.time(),
1707
+ }))
1708
+ count = result.rowcount
1709
+ await session.commit()
1710
+ updated = count > 0
1711
+ if updated:
1712
+ logger.info('Job cancelled.')
1713
+ await callback_func('CANCELLED')
1714
+ else:
1715
+ logger.info('Cancellation skipped, job is not CANCELLING')
1716
+
1717
+
1718
+ @_init_db_async
1719
+ async def remove_ha_recovery_script_async(job_id: int) -> None:
1720
+ """Remove the HA recovery script for a job."""
1721
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1722
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1723
+ await session.execute(
1724
+ sqlalchemy.delete(ha_recovery_script_table).where(
1725
+ ha_recovery_script_table.c.job_id == job_id))
1726
+ await session.commit()
1727
+
1728
+
1729
+ async def get_status_async(job_id: int) -> Optional[ManagedJobStatus]:
1730
+ _, status = await get_latest_task_id_status_async(job_id)
1731
+ return status
1732
+
1733
+
1734
+ @_init_db_async
1735
+ async def get_job_schedule_state_async(job_id: int) -> ManagedJobScheduleState:
1736
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1737
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1738
+ result = await session.execute(
1739
+ sqlalchemy.select(job_info_table.c.schedule_state).where(
1740
+ job_info_table.c.spot_job_id == job_id))
1741
+ state = result.fetchone()[0]
1742
+ return ManagedJobScheduleState(state)
1743
+
1744
+
1745
+ @_init_db_async
1746
+ async def scheduler_set_done_async(job_id: int,
1747
+ idempotent: bool = False) -> None:
1748
+ """Do not call without holding the scheduler lock."""
1749
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1750
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1751
+ result = await session.execute(
1752
+ sqlalchemy.update(job_info_table).where(
1753
+ sqlalchemy.and_(
1754
+ job_info_table.c.spot_job_id == job_id,
1755
+ job_info_table.c.schedule_state !=
1756
+ ManagedJobScheduleState.DONE.value,
1757
+ )).values({
1758
+ job_info_table.c.schedule_state:
1759
+ ManagedJobScheduleState.DONE.value
1760
+ }))
1761
+ updated_count = result.rowcount
1762
+ await session.commit()
1763
+ if not idempotent:
1764
+ assert updated_count == 1, (job_id, updated_count)
1765
+
1766
+
1767
+ # ==== needed for codegen ====
1768
+ # functions have no use outside of codegen, remove at your own peril
1769
+
1770
+
1771
+ @_init_db
1772
+ def set_job_info(job_id: int, name: str, workspace: str, entrypoint: str,
1773
+ pool: Optional[str], pool_hash: Optional[str]):
1774
+ assert _SQLALCHEMY_ENGINE is not None
1775
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1776
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
1777
+ db_utils.SQLAlchemyDialect.SQLITE.value):
1778
+ insert_func = sqlite.insert
1779
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
1780
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
1781
+ insert_func = postgresql.insert
1782
+ else:
1783
+ raise ValueError('Unsupported database dialect')
1784
+ insert_stmt = insert_func(job_info_table).values(
1785
+ spot_job_id=job_id,
1786
+ name=name,
1787
+ schedule_state=ManagedJobScheduleState.INACTIVE.value,
1788
+ workspace=workspace,
1789
+ entrypoint=entrypoint,
1790
+ pool=pool,
1791
+ pool_hash=pool_hash,
1792
+ )
1793
+ session.execute(insert_stmt)
1794
+ session.commit()
1795
+
1796
+
1797
+ @_init_db
1798
+ def reset_jobs_for_recovery() -> None:
1799
+ """Remove controller PIDs for live jobs, allowing them to be recovered."""
1800
+ assert _SQLALCHEMY_ENGINE is not None
1801
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1802
+ session.query(job_info_table).filter(
1803
+ # PID should be set.
1804
+ job_info_table.c.controller_pid.isnot(None),
1805
+ # Schedule state should be alive.
1806
+ job_info_table.c.schedule_state.isnot(None),
1807
+ (job_info_table.c.schedule_state !=
1808
+ ManagedJobScheduleState.INVALID.value),
1809
+ (job_info_table.c.schedule_state !=
1810
+ ManagedJobScheduleState.WAITING.value),
1811
+ (job_info_table.c.schedule_state !=
1812
+ ManagedJobScheduleState.DONE.value),
1813
+ ).update({
1814
+ job_info_table.c.controller_pid: None,
1815
+ job_info_table.c.schedule_state:
1816
+ (ManagedJobScheduleState.WAITING.value)
1817
+ })
1818
+ session.commit()
1819
+
1820
+
1821
+ @_init_db
1822
+ def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
1823
+ """Get all job ids by name."""
1824
+ assert _SQLALCHEMY_ENGINE is not None
1825
+
1826
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1827
+ query = sqlalchemy.select(
1828
+ spot_table.c.spot_job_id.distinct()).select_from(
1829
+ spot_table.outerjoin(
1830
+ job_info_table,
1831
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
1832
+ if name is not None:
1833
+ # We match the job name from `job_info` for the jobs submitted after
1834
+ # #1982, and from `spot` for the jobs submitted before #1982, whose
1835
+ # job_info is not available.
1836
+ name_condition = sqlalchemy.or_(
1837
+ job_info_table.c.name == name,
1838
+ sqlalchemy.and_(job_info_table.c.name.is_(None),
1839
+ spot_table.c.task_name == name))
1840
+ query = query.where(name_condition)
1841
+ query = query.order_by(spot_table.c.spot_job_id.desc())
1842
+ rows = session.execute(query).fetchall()
1843
+ job_ids = [row[0] for row in rows if row[0] is not None]
1844
+ return job_ids