skypilot-nightly 1.0.0.dev20250909__py3-none-any.whl → 1.0.0.dev20250910__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of skypilot-nightly might be problematic.

Files changed (67)
  1. sky/__init__.py +2 -2
  2. sky/authentication.py +19 -4
  3. sky/backends/backend_utils.py +35 -1
  4. sky/backends/cloud_vm_ray_backend.py +2 -2
  5. sky/client/sdk.py +20 -0
  6. sky/client/sdk_async.py +18 -16
  7. sky/clouds/aws.py +3 -1
  8. sky/dashboard/out/404.html +1 -1
  9. sky/dashboard/out/_next/static/chunks/{webpack-d4fabc08788e14af.js → webpack-1d7e11230da3ca89.js} +1 -1
  10. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  11. sky/dashboard/out/clusters/[cluster].html +1 -1
  12. sky/dashboard/out/clusters.html +1 -1
  13. sky/dashboard/out/config.html +1 -1
  14. sky/dashboard/out/index.html +1 -1
  15. sky/dashboard/out/infra/[context].html +1 -1
  16. sky/dashboard/out/infra.html +1 -1
  17. sky/dashboard/out/jobs/[job].html +1 -1
  18. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  19. sky/dashboard/out/jobs.html +1 -1
  20. sky/dashboard/out/users.html +1 -1
  21. sky/dashboard/out/volumes.html +1 -1
  22. sky/dashboard/out/workspace/new.html +1 -1
  23. sky/dashboard/out/workspaces/[name].html +1 -1
  24. sky/dashboard/out/workspaces.html +1 -1
  25. sky/data/storage.py +5 -1
  26. sky/execution.py +21 -14
  27. sky/jobs/constants.py +3 -0
  28. sky/jobs/controller.py +732 -310
  29. sky/jobs/recovery_strategy.py +251 -129
  30. sky/jobs/scheduler.py +247 -174
  31. sky/jobs/server/core.py +20 -4
  32. sky/jobs/server/utils.py +2 -2
  33. sky/jobs/state.py +702 -511
  34. sky/jobs/utils.py +94 -39
  35. sky/provision/aws/config.py +4 -1
  36. sky/provision/gcp/config.py +6 -1
  37. sky/provision/kubernetes/utils.py +17 -8
  38. sky/provision/provisioner.py +1 -0
  39. sky/serve/replica_managers.py +0 -7
  40. sky/serve/serve_utils.py +5 -0
  41. sky/serve/server/impl.py +1 -2
  42. sky/serve/service.py +0 -2
  43. sky/server/common.py +8 -3
  44. sky/server/config.py +43 -24
  45. sky/server/constants.py +1 -0
  46. sky/server/daemons.py +7 -11
  47. sky/server/requests/serializers/encoders.py +1 -1
  48. sky/server/server.py +8 -1
  49. sky/setup_files/dependencies.py +4 -2
  50. sky/skylet/attempt_skylet.py +1 -0
  51. sky/skylet/constants.py +3 -1
  52. sky/skylet/events.py +2 -10
  53. sky/utils/command_runner.pyi +3 -3
  54. sky/utils/common_utils.py +11 -1
  55. sky/utils/controller_utils.py +5 -0
  56. sky/utils/db/db_utils.py +31 -2
  57. sky/utils/rich_utils.py +3 -1
  58. sky/utils/subprocess_utils.py +9 -0
  59. sky/volumes/volume.py +2 -0
  60. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/METADATA +39 -37
  61. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/RECORD +67 -67
  62. /sky/dashboard/out/_next/static/{eWytLgin5zvayQw3Xk46m → 3SYxqNGnvvPS8h3gdD2T7}/_buildManifest.js +0 -0
  63. /sky/dashboard/out/_next/static/{eWytLgin5zvayQw3Xk46m → 3SYxqNGnvvPS8h3gdD2T7}/_ssgManifest.js +0 -0
  64. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/WHEEL +0 -0
  65. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/entry_points.txt +0 -0
  66. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/licenses/LICENSE +0 -0
  67. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/top_level.txt +0 -0
sky/jobs/controller.py CHANGED
@@ -1,31 +1,26 @@
1
- """Controller: handles the life cycle of a managed job.
2
-
3
- TODO(cooperc): Document lifecycle, and multiprocess layout.
1
+ """Controller: handles scheduling and the life cycle of a managed job.
4
2
  """
5
- import argparse
6
- import multiprocessing
3
+ import asyncio
4
+ import logging
7
5
  import os
8
- import pathlib
6
+ import resource
9
7
  import shutil
8
+ import sys
10
9
  import time
11
10
  import traceback
12
11
  import typing
13
- from typing import Optional, Tuple
12
+ from typing import Dict, Optional, Set, Tuple
14
13
 
15
- import filelock
14
+ import dotenv
16
15
 
17
- # This import ensures backward compatibility. Controller processes may not have
18
- # imported this module initially, but will attempt to import it during job
19
- # termination on the fly. If a job was launched with an old SkyPilot runtime
20
- # and a new job is launched with a newer runtime, the old job's termination
21
- # will try to import code from a different SkyPilot runtime, causing exceptions.
22
- # pylint: disable=unused-import
16
+ import sky
23
17
  from sky import core
24
18
  from sky import exceptions
25
19
  from sky import sky_logging
26
20
  from sky.backends import backend_utils
27
21
  from sky.backends import cloud_vm_ray_backend
28
22
  from sky.data import data_utils
23
+ from sky.jobs import constants as jobs_constants
29
24
  from sky.jobs import recovery_strategy
30
25
  from sky.jobs import scheduler
31
26
  from sky.jobs import state as managed_job_state
@@ -35,20 +30,34 @@ from sky.skylet import job_lib
35
30
  from sky.usage import usage_lib
36
31
  from sky.utils import common
37
32
  from sky.utils import common_utils
33
+ from sky.utils import context
34
+ from sky.utils import context_utils
38
35
  from sky.utils import controller_utils
39
36
  from sky.utils import dag_utils
40
37
  from sky.utils import status_lib
41
- from sky.utils import subprocess_utils
42
38
  from sky.utils import ux_utils
43
39
 
44
- if typing.TYPE_CHECKING:
45
- import sky
46
-
47
- # Use the explicit logger name so that the logger is under the
48
- # `sky.jobs.controller` namespace when executed directly, so as
49
- # to inherit the setup from the `sky` logger.
50
40
  logger = sky_logging.init_logger('sky.jobs.controller')
51
41
 
42
+ _background_tasks: Set[asyncio.Task] = set()
43
+ _background_tasks_lock: asyncio.Lock = asyncio.Lock()
44
+
45
+
46
+ async def create_background_task(coro: typing.Coroutine) -> None:
47
+ """Create a background task and add it to the set of background tasks.
48
+
49
+ The main reason we do this is that tasks are only held as a weak reference in
50
+ the executor, so we need to keep a strong reference to the task to avoid it
51
+ being garbage collected.
52
+
53
+ Args:
54
+ coro: The coroutine to create a task for.
55
+ """
56
+ async with _background_tasks_lock:
57
+ task = asyncio.create_task(coro)
58
+ _background_tasks.add(task)
59
+ task.add_done_callback(_background_tasks.discard)
60
+
52
61
 
53
62
  def _get_dag_and_name(dag_yaml: str) -> Tuple['sky.Dag', str]:
54
63
  dag = dag_utils.load_chain_dag_from_yaml(dag_yaml)
@@ -58,13 +67,79 @@ def _get_dag_and_name(dag_yaml: str) -> Tuple['sky.Dag', str]:
58
67
 
59
68
 
60
69
  class JobsController:
61
- """Each jobs controller manages the life cycle of one managed job."""
70
+ """Controls the lifecycle of a single managed job.
71
+
72
+ This controller executes a chain DAG defined in ``dag_yaml`` by:
73
+ - Loading the DAG and preparing per-task environment variables so each task
74
+ has a stable global job identifier across recoveries.
75
+ - Launching the task on the configured backend (``CloudVmRayBackend``),
76
+ optionally via a cluster pool.
77
+ - Persisting state transitions to the managed jobs state store
78
+ (e.g., STARTING → RUNNING → SUCCEEDED/FAILED/CANCELLED).
79
+ - Monitoring execution, downloading/streaming logs, detecting failures or
80
+ preemptions, and invoking recovery through
81
+ ``recovery_strategy.StrategyExecutor``.
82
+ - Cleaning up clusters and ephemeral resources when tasks finish.
83
+
84
+ Concurrency and coordination:
85
+ - Runs inside an ``asyncio`` event loop.
86
+ - Shares a ``starting`` set, guarded by ``starting_lock`` and signaled via
87
+ ``starting_signal``, to throttle concurrent launches across jobs that the
88
+ top-level ``Controller`` manages.
89
+
90
+ Key attributes:
91
+ - ``_job_id``: Integer identifier of this managed job.
92
+ - ``_dag_yaml`` / ``_dag`` / ``_dag_name``: The job definition and metadata.
93
+ - ``_backend``: Backend used to launch and manage clusters.
94
+ - ``_pool``: Optional pool name if using a cluster pool.
95
+ - ``_logger``: Job-scoped logger for progress and diagnostics.
96
+ - ``starting`` / ``starting_lock`` / ``starting_signal``: Shared scheduler
97
+ coordination primitives. ``starting_lock`` must be held when accessing
98
+ ``starting_signal`` and ``starting``.
99
+ - ``_strategy_executor``: Recovery/launch strategy executor (created per
100
+ task).
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ job_id: int,
106
+ dag_yaml: str,
107
+ job_logger: logging.Logger,
108
+ starting: Set[int],
109
+ starting_lock: asyncio.Lock,
110
+ starting_signal: asyncio.Condition,
111
+ pool: Optional[str] = None,
112
+ ) -> None:
113
+ """Initialize a ``JobsController``.
114
+
115
+ Args:
116
+ job_id: Integer ID of the managed job.
117
+ dag_yaml: Path to the YAML file containing the chain DAG to run.
118
+ job_logger: Logger instance dedicated to this job.
119
+ starting: Shared set of job IDs currently in the STARTING phase,
120
+ used to limit concurrent launches.
121
+ starting_lock: ``asyncio.Lock`` guarding access to the shared
122
+ scheduler state (e.g., the ``starting`` set).
123
+ starting_signal: ``asyncio.Condition`` used to notify when a job
124
+ exits STARTING so more jobs can be admitted.
125
+ pool: Optional cluster pool name. When provided, the job is
126
+ submitted to the pool rather than launching a dedicated
127
+ cluster.
128
+ """
129
+
130
+ self.starting = starting
131
+ self.starting_lock = starting_lock
132
+ self.starting_signal = starting_signal
133
+
134
+ self._logger = job_logger
135
+ self._logger.info(f'Initializing JobsController for job_id={job_id}, '
136
+ f'dag_yaml={dag_yaml}')
62
137
 
63
- def __init__(self, job_id: int, dag_yaml: str, pool: Optional[str]) -> None:
64
138
  self._job_id = job_id
139
+ self._dag_yaml = dag_yaml
65
140
  self._dag, self._dag_name = _get_dag_and_name(dag_yaml)
66
- logger.info(self._dag)
67
- # TODO(zhwu): this assumes the specific backend.
141
+ self._logger.info(f'Loaded DAG: {self._dag}')
142
+
68
143
  self._backend = cloud_vm_ray_backend.CloudVmRayBackend()
69
144
  self._pool = pool
70
145
 
@@ -84,6 +159,7 @@ class JobsController:
84
159
  # dag_utils.maybe_infer_and_fill_dag_and_task_names.
85
160
  assert task_name is not None, self._dag
86
161
  task_name = f'{self._dag_name}_{task_name}'
162
+
87
163
  job_id_env_var = common_utils.get_global_job_id(
88
164
  self._backend.run_timestamp,
89
165
  f'{task_name}',
@@ -102,7 +178,7 @@ class JobsController:
102
178
  def _download_log_and_stream(
103
179
  self,
104
180
  task_id: Optional[int],
105
- handle: Optional[cloud_vm_ray_backend.CloudVmRayResourceHandle],
181
+ handle: Optional['cloud_vm_ray_backend.CloudVmRayResourceHandle'],
106
182
  job_id_on_pool_cluster: Optional[int],
107
183
  ) -> None:
108
184
  """Downloads and streams the logs of the current job with given task ID.
@@ -112,9 +188,10 @@ class JobsController:
112
188
  preemptions or ssh disconnection during the streaming.
113
189
  """
114
190
  if handle is None:
115
- logger.info(f'Cluster for job {self._job_id} is not found. '
116
- 'Skipping downloading and streaming the logs.')
191
+ self._logger.info(f'Cluster for job {self._job_id} is not found. '
192
+ 'Skipping downloading and streaming the logs.')
117
193
  return
194
+
118
195
  managed_job_logs_dir = os.path.join(constants.SKY_LOGS_DIRECTORY,
119
196
  'managed_jobs',
120
197
  f'job-id-{self._job_id}')
@@ -125,19 +202,25 @@ class JobsController:
125
202
  job_ids=[str(job_id_on_pool_cluster)]
126
203
  if job_id_on_pool_cluster is not None else None)
127
204
  if log_file is not None:
128
- # Set the path of the log file for the current task, so it can be
129
- # accessed even after the job is finished
205
+ # Set the path of the log file for the current task, so it can
206
+ # be accessed even after the job is finished
130
207
  managed_job_state.set_local_log_file(self._job_id, task_id,
131
208
  log_file)
132
- logger.info(f'\n== End of logs (ID: {self._job_id}) ==')
209
+ else:
210
+ self._logger.warning(
211
+ f'No log file was downloaded for job {self._job_id}, '
212
+ f'task {task_id}')
213
+
214
+ self._logger.info(f'\n== End of logs (ID: {self._job_id}) ==')
133
215
 
134
- def _cleanup_cluster(self, cluster_name: Optional[str]) -> None:
216
+ async def _cleanup_cluster(self, cluster_name: Optional[str]) -> None:
135
217
  if cluster_name is None:
136
218
  return
137
219
  if self._pool is None:
138
- managed_job_utils.terminate_cluster(cluster_name)
220
+ await context_utils.to_thread(managed_job_utils.terminate_cluster,
221
+ cluster_name)
139
222
 
140
- def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
223
+ async def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
141
224
  """Busy loop monitoring cluster status and handling recovery.
142
225
 
143
226
  When the task is successfully completed, this function returns True,
@@ -172,38 +255,52 @@ class JobsController:
172
255
  3. Any unexpected error happens during the `sky.launch`.
173
256
  Other exceptions may be raised depending on the backend.
174
257
  """
258
+ task_start_time = time.time()
259
+ self._logger.info(
260
+ f'Starting task {task_id} ({task.name}) for job {self._job_id}')
175
261
 
176
262
  latest_task_id, last_task_prev_status = (
177
- managed_job_state.get_latest_task_id_status(self._job_id))
263
+ await
264
+ managed_job_state.get_latest_task_id_status_async(self._job_id))
265
+
178
266
  is_resume = False
179
267
  if (latest_task_id is not None and last_task_prev_status !=
180
268
  managed_job_state.ManagedJobStatus.PENDING):
181
269
  assert latest_task_id >= task_id, (latest_task_id, task_id)
182
270
  if latest_task_id > task_id:
183
- logger.info(f'Task {task_id} ({task.name}) has already '
184
- 'been executed. Skipping...')
271
+ self._logger.info(f'Task {task_id} ({task.name}) has already '
272
+ 'been executed. Skipping...')
185
273
  return True
186
274
  if latest_task_id == task_id:
187
275
  # Start recovery.
188
276
  is_resume = True
277
+ self._logger.info(
278
+ f'Resuming task {task_id} from previous execution')
189
279
 
190
280
  callback_func = managed_job_utils.event_callback_func(
191
281
  job_id=self._job_id, task_id=task_id, task=task)
282
+
192
283
  if task.run is None:
193
- logger.info(f'Skip running task {task_id} ({task.name}) due to its '
194
- 'run commands being empty.')
284
+ self._logger.info(
285
+ f'Skip running task {task_id} ({task.name}) due to its '
286
+ 'run commands being empty.')
195
287
  # Call set_started first to initialize columns in the state table,
196
288
  # including start_at and last_recovery_at to avoid issues for
197
289
  # uninitialized columns.
198
- managed_job_state.set_started(job_id=self._job_id,
199
- task_id=task_id,
200
- start_time=time.time(),
201
- callback_func=callback_func)
202
- managed_job_state.set_succeeded(job_id=self._job_id,
203
- task_id=task_id,
204
- end_time=time.time(),
205
- callback_func=callback_func)
290
+ await managed_job_state.set_started_async(
291
+ job_id=self._job_id,
292
+ task_id=task_id,
293
+ start_time=time.time(),
294
+ callback_func=callback_func)
295
+ await managed_job_state.set_succeeded_async(
296
+ job_id=self._job_id,
297
+ task_id=task_id,
298
+ end_time=time.time(),
299
+ callback_func=callback_func)
300
+ self._logger.info(
301
+ f'Empty task {task_id} marked as succeeded immediately')
206
302
  return True
303
+
207
304
  usage_lib.messages.usage.update_task_id(task_id)
208
305
  task_id_env_var = task.envs[constants.TASK_ID_ENV_VAR]
209
306
  assert task.name is not None, task
@@ -214,61 +311,97 @@ class JobsController:
214
311
  task.name, self._job_id) if self._pool is None else None
215
312
  self._strategy_executor = recovery_strategy.StrategyExecutor.make(
216
313
  cluster_name, self._backend, task, self._job_id, task_id,
217
- self._pool)
314
+ self._logger, self._pool, self.starting, self.starting_lock,
315
+ self.starting_signal)
218
316
  if not is_resume:
219
317
  submitted_at = time.time()
220
318
  if task_id == 0:
221
319
  submitted_at = backend_utils.get_timestamp_from_run_timestamp(
222
320
  self._backend.run_timestamp)
223
- managed_job_state.set_starting(
321
+
322
+ resources_str = backend_utils.get_task_resources_str(
323
+ task, is_managed_job=True)
324
+
325
+ await managed_job_state.set_starting_async(
224
326
  self._job_id,
225
327
  task_id,
226
328
  self._backend.run_timestamp,
227
329
  submitted_at,
228
- resources_str=backend_utils.get_task_resources_str(
229
- task, is_managed_job=True),
330
+ resources_str=resources_str,
230
331
  specs={
231
332
  'max_restarts_on_errors':
232
333
  self._strategy_executor.max_restarts_on_errors
233
334
  },
234
335
  callback_func=callback_func)
235
- logger.info(f'Submitted managed job {self._job_id} '
236
- f'(task: {task_id}, name: {task.name!r}); '
237
- f'{constants.TASK_ID_ENV_VAR}: {task_id_env_var}')
336
+ self._logger.info(f'Submitted managed job {self._job_id} '
337
+ f'(task: {task_id}, name: {task.name!r}); '
338
+ f'{constants.TASK_ID_ENV_VAR}: {task_id_env_var}')
238
339
 
239
- logger.info('Started monitoring.')
340
+ self._logger.info('Started monitoring.')
240
341
 
241
342
  # Only do the initial cluster launch if not resuming from a controller
242
343
  # failure. Otherwise, we will transit to recovering immediately.
243
344
  remote_job_submitted_at = time.time()
244
345
  if not is_resume:
245
- remote_job_submitted_at = self._strategy_executor.launch()
346
+ launch_start = time.time()
347
+
348
+ # Run the launch in a separate thread to avoid blocking the event
349
+ # loop. The scheduler functions used internally already have their
350
+ # own file locks.
351
+ remote_job_submitted_at = await self._strategy_executor.launch()
352
+
353
+ launch_time = time.time() - launch_start
354
+ self._logger.info(f'Cluster launch completed in {launch_time:.2f}s')
246
355
  assert remote_job_submitted_at is not None, remote_job_submitted_at
247
356
  if self._pool is None:
248
357
  job_id_on_pool_cluster = None
249
358
  else:
250
359
  # Update the cluster name when using cluster pool.
251
360
  cluster_name, job_id_on_pool_cluster = (
252
- managed_job_state.get_pool_submit_info(self._job_id))
361
+ await
362
+ managed_job_state.get_pool_submit_info_async(self._job_id))
253
363
  assert cluster_name is not None, (cluster_name, job_id_on_pool_cluster)
254
364
 
255
365
  if not is_resume:
256
- managed_job_state.set_started(job_id=self._job_id,
257
- task_id=task_id,
258
- start_time=remote_job_submitted_at,
259
- callback_func=callback_func)
366
+ await managed_job_state.set_started_async(
367
+ job_id=self._job_id,
368
+ task_id=task_id,
369
+ start_time=remote_job_submitted_at,
370
+ callback_func=callback_func)
371
+
372
+ monitoring_start_time = time.time()
373
+ status_check_count = 0
374
+
375
+ async with self.starting_lock:
376
+ try:
377
+ self.starting.remove(self._job_id)
378
+ # It's fine if we notify again: better to wake someone up
379
+ # and have them go back to sleep than to have someone stuck
380
+ # sleeping.
381
+ # P.S. this shouldn't actually happen, because if it's been
382
+ # removed from the set, then we would get a KeyError.
383
+ self.starting_signal.notify()
384
+ except KeyError:
385
+ pass
260
386
 
261
387
  while True:
388
+ status_check_count += 1
389
+
262
390
  # NOTE: if we are resuming from a controller failure, we only keep
263
391
  # monitoring if the job is in RUNNING state. For all other cases,
264
392
  # we will directly transit to recovering since we have no idea what
265
393
  # the cluster status is.
266
394
  force_transit_to_recovering = False
267
395
  if is_resume:
268
- prev_status = managed_job_state.get_job_status_with_task_id(
269
- job_id=self._job_id, task_id=task_id)
396
+ prev_status = await (
397
+ managed_job_state.get_job_status_with_task_id_async(
398
+ job_id=self._job_id, task_id=task_id))
399
+
270
400
  if prev_status is not None:
271
401
  if prev_status.is_terminal():
402
+ self._logger.info(
403
+ f'Task {task_id} already in terminal state: '
404
+ f'{prev_status}')
272
405
  return (prev_status ==
273
406
  managed_job_state.ManagedJobStatus.SUCCEEDED)
274
407
  if (prev_status ==
@@ -276,23 +409,26 @@ class JobsController:
276
409
  # If the controller is down when cancelling the job,
277
410
  # we re-raise the error to run the `_cleanup` function
278
411
  # again to clean up any remaining resources.
279
- raise exceptions.ManagedJobUserCancelledError(
280
- 'Recovering cancel signal.')
412
+ self._logger.info(
413
+ f'Task {task_id} was being cancelled, '
414
+ 're-raising cancellation')
415
+ raise asyncio.CancelledError()
281
416
  if prev_status != managed_job_state.ManagedJobStatus.RUNNING:
282
417
  force_transit_to_recovering = True
283
418
  # This resume logic should only be triggered once.
284
419
  is_resume = False
285
420
 
286
- time.sleep(managed_job_utils.JOB_STATUS_CHECK_GAP_SECONDS)
421
+ await asyncio.sleep(managed_job_utils.JOB_STATUS_CHECK_GAP_SECONDS)
287
422
 
288
423
  # Check the network connection to avoid false alarm for job failure.
289
424
  # Network glitch was observed even in the VM.
290
425
  try:
291
- backend_utils.check_network_connection()
426
+ await backend_utils.async_check_network_connection()
292
427
  except exceptions.NetworkError:
293
- logger.info('Network is not available. Retrying again in '
294
- f'{managed_job_utils.JOB_STATUS_CHECK_GAP_SECONDS} '
295
- 'seconds.')
428
+ self._logger.info(
429
+ 'Network is not available. Retrying again in '
430
+ f'{managed_job_utils.JOB_STATUS_CHECK_GAP_SECONDS} '
431
+ 'seconds.')
296
432
  continue
297
433
 
298
434
  # NOTE: we do not check cluster status first because race condition
@@ -303,32 +439,47 @@ class JobsController:
303
439
  job_status = None
304
440
  if not force_transit_to_recovering:
305
441
  try:
306
- job_status = managed_job_utils.get_job_status(
442
+ job_status = await managed_job_utils.get_job_status(
307
443
  self._backend,
308
444
  cluster_name,
309
- job_id=job_id_on_pool_cluster)
445
+ job_id=job_id_on_pool_cluster,
446
+ job_logger=self._logger,
447
+ )
310
448
  except exceptions.FetchClusterInfoError as fetch_e:
311
- logger.info(
449
+ self._logger.info(
312
450
  'Failed to fetch the job status. Start recovery.\n'
313
451
  f'Exception: {common_utils.format_exception(fetch_e)}\n'
314
452
  f'Traceback: {traceback.format_exc()}')
315
453
 
316
454
  if job_status == job_lib.JobStatus.SUCCEEDED:
317
- success_end_time = managed_job_utils.try_to_get_job_end_time(
318
- self._backend, cluster_name, job_id_on_pool_cluster)
455
+ self._logger.info(f'Task {task_id} succeeded! '
456
+ 'Getting end time and cleaning up')
457
+ try:
458
+ success_end_time = await context_utils.to_thread(
459
+ managed_job_utils.try_to_get_job_end_time,
460
+ self._backend, cluster_name, job_id_on_pool_cluster)
461
+ except Exception as e: # pylint: disable=broad-except
462
+ self._logger.warning(
463
+ f'Failed to get job end time: '
464
+ f'{common_utils.format_exception(e)}',
465
+ exc_info=True)
466
+ success_end_time = 0
467
+
319
468
  # The job is done. Set the job to SUCCEEDED first before start
320
469
  # downloading and streaming the logs to make it more responsive.
321
- managed_job_state.set_succeeded(self._job_id,
322
- task_id,
323
- end_time=success_end_time,
324
- callback_func=callback_func)
325
- logger.info(
470
+ await managed_job_state.set_succeeded_async(
471
+ self._job_id,
472
+ task_id,
473
+ end_time=success_end_time,
474
+ callback_func=callback_func)
475
+ self._logger.info(
326
476
  f'Managed job {self._job_id} (task: {task_id}) SUCCEEDED. '
327
477
  f'Cleaning up the cluster {cluster_name}.')
328
478
  try:
329
479
  logger.info(f'Downloading logs on cluster {cluster_name} '
330
480
  f'and job id {job_id_on_pool_cluster}.')
331
- clusters = backend_utils.get_clusters(
481
+ clusters = await context_utils.to_thread(
482
+ backend_utils.get_clusters,
332
483
  cluster_names=[cluster_name],
333
484
  refresh=common.StatusRefreshMode.NONE,
334
485
  all_users=True,
@@ -337,17 +488,25 @@ class JobsController:
337
488
  assert len(clusters) == 1, (clusters, cluster_name)
338
489
  handle = clusters[0].get('handle')
339
490
  # Best effort to download and stream the logs.
340
- self._download_log_and_stream(task_id, handle,
341
- job_id_on_pool_cluster)
491
+ await context_utils.to_thread(
492
+ self._download_log_and_stream, task_id, handle,
493
+ job_id_on_pool_cluster)
342
494
  except Exception as e: # pylint: disable=broad-except
343
495
  # We don't want to crash here, so just log and continue.
344
- logger.warning(
496
+ self._logger.warning(
345
497
  f'Failed to download and stream logs: '
346
498
  f'{common_utils.format_exception(e)}',
347
499
  exc_info=True)
348
500
  # Only clean up the cluster, not the storages, because tasks may
349
501
  # share storages.
350
- self._cleanup_cluster(cluster_name)
502
+ await self._cleanup_cluster(cluster_name)
503
+
504
+ task_total_time = time.time() - task_start_time
505
+ monitoring_time = time.time() - monitoring_start_time
506
+ self._logger.info(f'Task {task_id} completed successfully in '
507
+ f'{task_total_time:.2f}s '
508
+ f'(monitoring time: {monitoring_time:.2f}s, '
509
+ f'status checks: {status_check_count})')
351
510
  return True
352
511
 
353
512
  # For single-node jobs, non-terminated job_status indicates a
@@ -363,7 +522,7 @@ class JobsController:
363
522
  if job_status in job_lib.JobStatus.user_code_failure_states():
364
523
  # Add a grace period before the check of preemption to avoid
365
524
  # false alarm for job failure.
366
- time.sleep(5)
525
+ await asyncio.sleep(5)
367
526
 
368
527
  # Pull the actual cluster status from the cloud provider to
369
528
  # determine whether the cluster is preempted or failed.
@@ -383,7 +542,7 @@ class JobsController:
383
542
  # code).
384
543
  cluster_status_str = ('' if cluster_status is None else
385
544
  f' (status: {cluster_status.value})')
386
- logger.info(
545
+ self._logger.info(
387
546
  f'Cluster is preempted or failed{cluster_status_str}. '
388
547
  'Recovering...')
389
548
  else:
@@ -394,14 +553,18 @@ class JobsController:
394
553
  in job_lib.JobStatus.user_code_failure_states() or
395
554
  job_status == job_lib.JobStatus.FAILED_DRIVER):
396
555
  # The user code has probably crashed, fail immediately.
397
- end_time = managed_job_utils.try_to_get_job_end_time(
556
+ self._logger.info(
557
+ f'Task {task_id} failed with status: {job_status}')
558
+ end_time = await context_utils.to_thread(
559
+ managed_job_utils.try_to_get_job_end_time,
398
560
  self._backend, cluster_name, job_id_on_pool_cluster)
399
- logger.info(
561
+ self._logger.info(
400
562
  f'The user job failed ({job_status}). Please check the '
401
563
  'logs below.\n'
402
564
  f'== Logs of the user job (ID: {self._job_id}) ==\n')
403
565
 
404
- self._download_log_and_stream(task_id, handle,
566
+ await context_utils.to_thread(self._download_log_and_stream,
567
+ task_id, handle,
405
568
  job_id_on_pool_cluster)
406
569
 
407
570
  failure_reason = (
@@ -430,7 +593,7 @@ class JobsController:
430
593
  if should_restart_on_failure:
431
594
  max_restarts = (
432
595
  self._strategy_executor.max_restarts_on_errors)
433
- logger.info(
596
+ self._logger.info(
434
597
  f'User program crashed '
435
598
  f'({managed_job_status.value}). '
436
599
  f'Retry the job as max_restarts_on_errors is '
@@ -438,7 +601,9 @@ class JobsController:
438
601
  f'[{self._strategy_executor.restart_cnt_on_failure}'
439
602
  f'/{max_restarts}]')
440
603
  else:
441
- managed_job_state.set_failed(
604
+ self._logger.info(
605
+ f'Task {task_id} failed and will not be retried')
606
+ await managed_job_state.set_failed_async(
442
607
  self._job_id,
443
608
  task_id,
444
609
  failure_type=managed_job_status,
@@ -449,11 +614,11 @@ class JobsController:
449
614
  elif job_status is not None:
450
615
  # Either the job is cancelled (should not happen) or in some
451
616
  # unknown new state that we do not handle.
452
- logger.error(f'Unknown job status: {job_status}')
617
+ self._logger.error(f'Unknown job status: {job_status}')
453
618
  failure_reason = (
454
619
  f'Unknown job status {job_status}. To see the details, '
455
620
  f'run: sky jobs logs --controller {self._job_id}')
456
- managed_job_state.set_failed(
621
+ await managed_job_state.set_failed_async(
457
622
  self._job_id,
458
623
  task_id,
459
624
  failure_type=managed_job_state.ManagedJobStatus.
@@ -466,9 +631,10 @@ class JobsController:
466
631
  # job status. Try to recover the job (will not restart the
467
632
  # cluster, if the cluster is healthy).
468
633
  assert job_status is None, job_status
469
- logger.info('Failed to fetch the job status while the '
470
- 'cluster is healthy. Try to recover the job '
471
- '(the cluster will not be restarted).')
634
+ self._logger.info(
635
+ 'Failed to fetch the job status while the '
636
+ 'cluster is healthy. Try to recover the job '
637
+ '(the cluster will not be restarted).')
472
638
  # When the handle is None, the cluster should be cleaned up already.
473
639
  if handle is not None:
474
640
  resources = handle.launched_resources
@@ -487,86 +653,121 @@ class JobsController:
487
653
  # Some spot resource (e.g., Spot TPU VM) may need to be
488
654
  # cleaned up after preemption, as running launch again on
489
655
  # those clusters again may fail.
490
- logger.info('Cleaning up the preempted or failed cluster'
491
- '...')
492
- self._cleanup_cluster(cluster_name)
656
+ self._logger.info(
657
+ 'Cleaning up the preempted or failed cluster'
658
+ '...')
659
+ await self._cleanup_cluster(cluster_name)
493
660
 
494
661
  # Try to recover the managed jobs, when the cluster is preempted or
495
662
  # failed or the job status is failed to be fetched.
496
- managed_job_state.set_recovering(
663
+ self._logger.info(f'Starting recovery for task {task_id}, '
664
+ f'it is currently {job_status}')
665
+ await managed_job_state.set_recovering_async(
497
666
  job_id=self._job_id,
498
667
  task_id=task_id,
499
668
  force_transit_to_recovering=force_transit_to_recovering,
500
669
  callback_func=callback_func)
501
- recovered_time = self._strategy_executor.recover()
670
+
671
+ recovered_time = await self._strategy_executor.recover()
672
+
502
673
  if self._pool is not None:
503
674
  cluster_name, job_id_on_pool_cluster = (
504
- managed_job_state.get_pool_submit_info(self._job_id))
675
+ await
676
+ managed_job_state.get_pool_submit_info_async(self._job_id))
505
677
  assert cluster_name is not None
506
- managed_job_state.set_recovered(self._job_id,
507
- task_id,
508
- recovered_time=recovered_time,
509
- callback_func=callback_func)
678
+ await managed_job_state.set_recovered_async(
679
+ self._job_id,
680
+ task_id,
681
+ recovered_time=recovered_time,
682
+ callback_func=callback_func)
510
683
 
511
- def run(self):
684
+ async def run(self):
512
685
  """Run controller logic and handle exceptions."""
686
+ self._logger.info(f'Starting JobsController run for job {self._job_id}')
513
687
  task_id = 0
688
+ cancelled = False
689
+
514
690
  try:
515
691
  succeeded = True
516
692
  # We support chain DAGs only for now.
517
693
  for task_id, task in enumerate(self._dag.tasks):
518
- succeeded = self._run_one_task(task_id, task)
694
+ self._logger.info(
695
+ f'Processing task {task_id}/{len(self._dag.tasks)-1}: '
696
+ f'{task.name}')
697
+ task_start = time.time()
698
+ succeeded = await self._run_one_task(task_id, task)
699
+ task_time = time.time() - task_start
700
+ self._logger.info(
701
+ f'Task {task_id} completed in {task_time:.2f}s '
702
+ f'with success={succeeded}')
703
+
519
704
  if not succeeded:
705
+ self._logger.info(
706
+ f'Task {task_id} failed, stopping execution')
520
707
  break
708
+
521
709
  except exceptions.ProvisionPrechecksError as e:
522
710
  # Please refer to the docstring of self._run for the cases when
523
711
  # this exception can occur.
712
+ self._logger.error(f'Provision prechecks failed for task {task_id}')
524
713
  failure_reason = ('; '.join(
525
714
  common_utils.format_exception(reason, use_bracket=True)
526
715
  for reason in e.reasons))
527
- logger.error(failure_reason)
528
- self._update_failed_task_state(
716
+ self._logger.error(failure_reason)
717
+ await self._update_failed_task_state(
529
718
  task_id, managed_job_state.ManagedJobStatus.FAILED_PRECHECKS,
530
719
  failure_reason)
531
720
  except exceptions.ManagedJobReachedMaxRetriesError as e:
532
721
  # Please refer to the docstring of self._run for the cases when
533
722
  # this exception can occur.
723
+ self._logger.error(
724
+ f'Managed job reached max retries for task {task_id}')
534
725
  failure_reason = common_utils.format_exception(e)
535
- logger.error(failure_reason)
726
+ self._logger.error(failure_reason)
536
727
  # The managed job should be marked as FAILED_NO_RESOURCE, as the
537
728
  # managed job may be able to launch next time.
538
- self._update_failed_task_state(
729
+ await self._update_failed_task_state(
539
730
  task_id, managed_job_state.ManagedJobStatus.FAILED_NO_RESOURCE,
540
731
  failure_reason)
732
+ except asyncio.CancelledError: # pylint: disable=try-except-raise
733
+ # have this here to avoid getting caught by the general except block
734
+ # below.
735
+ cancelled = True
736
+ raise
541
737
  except (Exception, SystemExit) as e: # pylint: disable=broad-except
738
+ self._logger.error(
739
+ f'Unexpected error in JobsController run for task {task_id}')
542
740
  with ux_utils.enable_traceback():
543
- logger.error(traceback.format_exc())
741
+ self._logger.error(traceback.format_exc())
544
742
  msg = ('Unexpected error occurred: ' +
545
743
  common_utils.format_exception(e, use_bracket=True))
546
- logger.error(msg)
547
- self._update_failed_task_state(
744
+ self._logger.error(msg)
745
+ await self._update_failed_task_state(
548
746
  task_id, managed_job_state.ManagedJobStatus.FAILED_CONTROLLER,
549
747
  msg)
550
748
  finally:
551
- # This will set all unfinished tasks to CANCELLING, and will not
552
- # affect the jobs in terminal states.
553
- # We need to call set_cancelling before set_cancelled to make sure
554
- # the table entries are correctly set.
555
749
  callback_func = managed_job_utils.event_callback_func(
556
750
  job_id=self._job_id,
557
751
  task_id=task_id,
558
752
  task=self._dag.tasks[task_id])
559
- managed_job_state.set_cancelling(job_id=self._job_id,
560
- callback_func=callback_func)
561
- managed_job_state.set_cancelled(job_id=self._job_id,
562
- callback_func=callback_func)
753
+ await managed_job_state.set_cancelling_async(
754
+ job_id=self._job_id, callback_func=callback_func)
755
+ if not cancelled:
756
+ # the others haven't been run yet so we can set them to
757
+ # cancelled immediately (no resources to clean up).
758
+ # if we are running and get cancelled, we need to clean up the
759
+ # resources first so this will be done later.
760
+ await managed_job_state.set_cancelled_async(
761
+ job_id=self._job_id, callback_func=callback_func)
563
762
 
564
- def _update_failed_task_state(
763
+ async def _update_failed_task_state(
565
764
  self, task_id: int,
566
765
  failure_type: managed_job_state.ManagedJobStatus,
567
766
  failure_reason: str):
568
767
  """Update the state of the failed task."""
569
- managed_job_state.set_failed(
768
+ self._logger.info(f'Updating failed task state: task_id={task_id}, '
769
+ f'failure_type={failure_type}')
770
+ await managed_job_state.set_failed_async(
570
771
  self._job_id,
571
772
  task_id=task_id,
572
773
  failure_type=failure_type,
@@ -577,199 +778,420 @@ class JobsController:
577
778
  task=self._dag.tasks[task_id]))
578
779
 
579
780
 
580
- def _run_controller(job_id: int, dag_yaml: str, pool: Optional[str]):
581
- """Runs the controller in a remote process for interruption."""
582
- # The controller needs to be instantiated in the remote process, since
583
- # the controller is not serializable.
584
- jobs_controller = JobsController(job_id, dag_yaml, pool)
585
- jobs_controller.run()
586
-
587
-
588
- def _handle_signal(job_id):
589
- """Handle the signal if the user sent it."""
590
- signal_file = pathlib.Path(
591
- managed_job_utils.SIGNAL_FILE_PREFIX.format(job_id))
592
- user_signal = None
593
- if signal_file.exists():
594
- # Filelock is needed to prevent race condition with concurrent
595
- # signal writing.
596
- with filelock.FileLock(str(signal_file) + '.lock'):
597
- with signal_file.open(mode='r', encoding='utf-8') as f:
598
- user_signal = f.read().strip()
599
- try:
600
- user_signal = managed_job_utils.UserSignal(user_signal)
601
- except ValueError:
602
- logger.warning(
603
- f'Unknown signal received: {user_signal}. Ignoring.')
604
- user_signal = None
605
- # Remove the signal file, after reading the signal.
606
- signal_file.unlink()
607
- if user_signal is None:
608
- # None or empty string.
609
- return
610
- assert user_signal == managed_job_utils.UserSignal.CANCEL, (
611
- f'Only cancel signal is supported, but {user_signal} got.')
612
- raise exceptions.ManagedJobUserCancelledError(
613
- f'User sent {user_signal.value} signal.')
614
-
615
-
616
- def _cleanup(job_id: int, dag_yaml: str, pool: Optional[str]):
617
- """Clean up the cluster(s) and storages.
618
-
619
- (1) Clean up the succeeded task(s)' ephemeral storage. The storage has
620
- to be cleaned up after the whole job is finished, as the tasks
621
- may share the same storage.
622
- (2) Clean up the cluster(s) that are not cleaned up yet, which can happen
623
- when the task failed or cancelled. At most one cluster should be left
624
- when reaching here, as we currently only support chain DAGs, and only
625
- task is executed at a time.
626
- """
627
- # Cleanup the HA recovery script first as it is possible that some error
628
- # was raised when we construct the task object (e.g.,
629
- # sky.exceptions.ResourcesUnavailableError).
630
- managed_job_state.remove_ha_recovery_script(job_id)
631
- dag, _ = _get_dag_and_name(dag_yaml)
632
- for task in dag.tasks:
633
- assert task.name is not None, task
634
- if pool is None:
635
- cluster_name = managed_job_utils.generate_managed_job_cluster_name(
636
- task.name, job_id)
637
- managed_job_utils.terminate_cluster(cluster_name)
638
- else:
639
- cluster_name, job_id_on_pool_cluster = (
640
- managed_job_state.get_pool_submit_info(job_id))
641
- if cluster_name is not None:
642
- if job_id_on_pool_cluster is not None:
643
- core.cancel(cluster_name=cluster_name,
644
- job_ids=[job_id_on_pool_cluster],
645
- _try_cancel_if_cluster_is_init=True)
646
-
647
- # Clean up Storages with persistent=False.
648
- # TODO(zhwu): this assumes the specific backend.
649
- backend = cloud_vm_ray_backend.CloudVmRayBackend()
650
- # Need to re-construct storage object in the controller process
651
- # because when SkyPilot API server machine sends the yaml config to the
652
- # controller machine, only storage metadata is sent, not the storage
653
- # object itself.
654
- for storage in task.storage_mounts.values():
655
- storage.construct()
656
- backend.teardown_ephemeral_storage(task)
657
-
658
- # Clean up any files mounted from the local disk, such as two-hop file
659
- # mounts.
660
- for file_mount in (task.file_mounts or {}).values():
781
+ class Controller:
782
+ """Controller for managing jobs."""
783
+
784
+ def __init__(self):
785
+ # Global state for active jobs
786
+ self.job_tasks: Dict[int, asyncio.Task] = {}
787
+ self.starting: Set[int] = set()
788
+
789
+ # Lock for synchronizing access to global state dictionary
790
+ # Must always hold _job_tasks_lock when accessing the _starting_signal.
791
+ self._job_tasks_lock = asyncio.Lock()
792
+ # We signal whenever a job leaves the api server launching state. Feel
793
+ # free to signal as much as you want to be safe from leaks (if you
794
+ # do not signal enough there may be some jobs forever waiting to
795
+ # launch).
796
+ self._starting_signal = asyncio.Condition(lock=self._job_tasks_lock)
797
+
798
+ async def _cleanup(self,
799
+ job_id: int,
800
+ dag_yaml: str,
801
+ job_logger: logging.Logger,
802
+ pool: Optional[str] = None):
803
+ """Clean up the cluster(s) and storages.
804
+
805
+ (1) Clean up the succeeded task(s)' ephemeral storage. The storage has
806
+ to be cleaned up after the whole job is finished, as the tasks
807
+ may share the same storage.
808
+ (2) Clean up the cluster(s) that are not cleaned up yet, which can
809
+ happen when the task failed or cancelled. At most one cluster
810
+ should be left when reaching here, as we currently only support
811
+ chain DAGs, and only one task is executed at a time.
812
+ """
813
+ # Cleanup the HA recovery script first as it is possible that some error
814
+ # was raised when we construct the task object (e.g.,
815
+ # sky.exceptions.ResourcesUnavailableError).
816
+ await managed_job_state.remove_ha_recovery_script_async(job_id)
817
+
818
+ def task_cleanup(task: 'sky.Task', job_id: int):
819
+ assert task.name is not None, task
820
+ error = None
821
+
661
822
  try:
662
- # For consolidation mode, there is no two-hop file mounts
663
- # and the file path here represents the real user data.
664
- # We skip the cleanup for consolidation mode.
665
- if (not data_utils.is_cloud_store_url(file_mount) and
666
- not managed_job_utils.is_consolidation_mode()):
667
- path = os.path.expanduser(file_mount)
668
- if os.path.isdir(path):
669
- shutil.rmtree(path)
670
- else:
671
- os.remove(path)
823
+ if pool is None:
824
+ cluster_name = (
825
+ managed_job_utils.generate_managed_job_cluster_name(
826
+ task.name, job_id))
827
+ managed_job_utils.terminate_cluster(cluster_name,
828
+ _logger=job_logger)
829
+ status = core.status(cluster_names=[cluster_name],
830
+ all_users=True)
831
+ assert (len(status) == 0 or
832
+ status[0]['status'] == sky.ClusterStatus.STOPPED), (
833
+ f'{cluster_name} is not down: {status}')
834
+ job_logger.info(f'{cluster_name} is down')
835
+ else:
836
+ cluster_name, job_id_on_pool_cluster = (
837
+ managed_job_state.get_pool_submit_info(job_id))
838
+ if cluster_name is not None:
839
+ if job_id_on_pool_cluster is not None:
840
+ core.cancel(cluster_name=cluster_name,
841
+ job_ids=[job_id_on_pool_cluster],
842
+ _try_cancel_if_cluster_is_init=True)
672
843
  except Exception as e: # pylint: disable=broad-except
673
- logger.warning(
674
- f'Failed to clean up file mount {file_mount}: {e}')
844
+ error = e
845
+ job_logger.warning(
846
+ f'Failed to terminate cluster {cluster_name}: {e}')
847
+ # we continue to try cleaning up whatever else we can.
848
+ # Clean up Storages with persistent=False.
849
+ # TODO(zhwu): this assumes the specific backend.
850
+ backend = cloud_vm_ray_backend.CloudVmRayBackend()
851
+ # Need to re-construct storage object in the controller process
852
+ # because when SkyPilot API server machine sends the yaml config to
853
+ # the controller machine, only storage metadata is sent, not the
854
+ # storage object itself.
855
+ for storage in task.storage_mounts.values():
856
+ storage.construct()
857
+ try:
858
+ backend.teardown_ephemeral_storage(task)
859
+ except Exception as e: # pylint: disable=broad-except
860
+ error = e
861
+ job_logger.warning(f'Failed to teardown ephemeral storage: {e}')
862
+ # we continue to try cleaning up whatever else we can.
675
863
 
864
+ # Clean up any files mounted from the local disk, such as two-hop
865
+ # file mounts.
866
+ for file_mount in (task.file_mounts or {}).values():
867
+ try:
868
+ # For consolidation mode, there is no two-hop file mounts
869
+ # and the file path here represents the real user data.
870
+ # We skip the cleanup for consolidation mode.
871
+ if (not data_utils.is_cloud_store_url(file_mount) and
872
+ not managed_job_utils.is_consolidation_mode()):
873
+ path = os.path.expanduser(file_mount)
874
+ if os.path.isdir(path):
875
+ shutil.rmtree(path)
876
+ else:
877
+ os.remove(path)
878
+ except Exception as e: # pylint: disable=broad-except
879
+ job_logger.warning(
880
+ f'Failed to clean up file mount {file_mount}: {e}')
881
+
882
+ if error is not None:
883
+ raise error
676
884
 
677
- def start(job_id, dag_yaml, pool):
678
- """Start the controller."""
679
- controller_process = None
680
- cancelling = False
681
- task_id = None
682
- try:
683
- _handle_signal(job_id)
684
- # TODO(suquark): In theory, we should make controller process a
685
- # daemon process so it will be killed after this process exits,
686
- # however daemon process cannot launch subprocesses, explained here:
687
- # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Process.daemon # pylint: disable=line-too-long
688
- # So we can only enable daemon after we no longer need to
689
- # start daemon processes like Ray.
690
- controller_process = multiprocessing.Process(target=_run_controller,
691
- args=(job_id, dag_yaml,
692
- pool))
693
- controller_process.start()
694
- while controller_process.is_alive():
695
- _handle_signal(job_id)
696
- time.sleep(1)
697
- except exceptions.ManagedJobUserCancelledError:
698
885
  dag, _ = _get_dag_and_name(dag_yaml)
699
- task_id, _ = managed_job_state.get_latest_task_id_status(job_id)
700
- assert task_id is not None, job_id
701
- logger.info(
702
- f'Cancelling managed job, job_id: {job_id}, task_id: {task_id}')
703
- managed_job_state.set_cancelling(
704
- job_id=job_id,
705
- callback_func=managed_job_utils.event_callback_func(
706
- job_id=job_id, task_id=task_id, task=dag.tasks[task_id]))
707
- cancelling = True
708
- finally:
709
- if controller_process is not None:
710
- logger.info(f'Killing controller process {controller_process.pid}.')
711
- # NOTE: it is ok to kill or join a killed process.
712
- # Kill the controller process first; if its child process is
713
- # killed first, then the controller process will raise errors.
714
- # Kill any possible remaining children processes recursively.
715
- subprocess_utils.kill_children_processes(
716
- parent_pids=[controller_process.pid], force=True)
717
- controller_process.join()
718
- logger.info(f'Controller process {controller_process.pid} killed.')
719
-
720
- logger.info(f'Cleaning up any cluster for job {job_id}.')
721
- # NOTE: Originally, we send an interruption signal to the controller
722
- # process and the controller process handles cleanup. However, we
723
- # figure out the behavior differs from cloud to cloud
724
- # (e.g., GCP ignores 'SIGINT'). A possible explanation is
725
- # https://unix.stackexchange.com/questions/356408/strange-problem-with-trap-and-sigint
726
- # But anyway, a clean solution is killing the controller process
727
- # directly, and then cleanup the cluster job_state.
728
- _cleanup(job_id, dag_yaml=dag_yaml, pool=pool)
729
- logger.info(f'Cluster of managed job {job_id} has been cleaned up.')
730
-
731
- if cancelling:
732
- assert task_id is not None, job_id # Since it's set with cancelling
733
- managed_job_state.set_cancelled(
886
+ error = None
887
+ for task in dag.tasks:
888
+ # most things in this function are blocking
889
+ try:
890
+ await context_utils.to_thread(task_cleanup, task, job_id)
891
+ except Exception as e: # pylint: disable=broad-except
892
+ error = e
893
+
894
+ if error is not None:
895
+ # We only raise the last error that occurred, but it's fine to lose
896
+ # some data here.
897
+ raise error
898
+
899
+ async def run_job_loop(self,
900
+ job_id: int,
901
+ dag_yaml: str,
902
+ job_logger: logging.Logger,
903
+ log_file: str,
904
+ env_file_path: Optional[str] = None,
905
+ pool: Optional[str] = None):
906
+ """Background task that runs the job loop."""
907
+ # Replace os.environ with ContextualEnviron to enable per-job
908
+ # environment isolation. This allows each job to have its own
909
+ # environment variables without affecting other jobs or the main
910
+ # process.
911
+ context.initialize()
912
+ ctx = context.get()
913
+ ctx.redirect_log(log_file) # type: ignore
914
+
915
+ # Load and apply environment variables from the job's environment file
916
+ if env_file_path and os.path.exists(env_file_path):
917
+ try:
918
+ # Load environment variables from the file
919
+ env_vars = dotenv.dotenv_values(env_file_path)
920
+ job_logger.info(f'Loading environment from {env_file_path}: '
921
+ f'{list(env_vars.keys())}')
922
+
923
+ # Apply environment variables to the job's context
924
+ ctx = context.get()
925
+ if ctx is not None:
926
+ for key, value in env_vars.items():
927
+ if value is not None:
928
+ ctx.override_envs({key: value})
929
+ job_logger.debug(
930
+ f'Set environment variable: {key}={value}')
931
+ else:
932
+ job_logger.error(
933
+ 'Context is None, cannot set environment variables')
934
+ except Exception as e: # pylint: disable=broad-except
935
+ job_logger.error(
936
+ f'Failed to load environment file {env_file_path}: {e}')
937
+ elif env_file_path:
938
+ job_logger.error(f'Environment file not found: {env_file_path}')
939
+
940
+ cancelling = False
941
+ try:
942
+ job_logger.info(f'Starting job loop for {job_id}')
943
+
944
+ controller = JobsController(job_id, dag_yaml, job_logger,
945
+ self.starting, self._job_tasks_lock,
946
+ self._starting_signal, pool)
947
+
948
+ async with self._job_tasks_lock:
949
+ if job_id in self.job_tasks:
950
+ job_logger.error(
951
+ f'Job {job_id} already exists in job_tasks')
952
+ raise ValueError(f'Job {job_id} already exists')
953
+
954
+ # Create the task and store it
955
+ # This function should return instantly and run the job loop in
956
+ # the background.
957
+ task = asyncio.create_task(controller.run())
958
+ self.job_tasks[job_id] = task
959
+ await task
960
+ except asyncio.CancelledError:
961
+ job_logger.info(f'Job {job_id} was cancelled')
962
+ dag, _ = _get_dag_and_name(dag_yaml)
963
+ task_id, _ = await (
964
+ managed_job_state.get_latest_task_id_status_async(job_id))
965
+ assert task_id is not None, job_id
966
+ job_logger.info(f'Cancelling managed job, job_id: {job_id}, '
967
+ f'task_id: {task_id}')
968
+ await managed_job_state.set_cancelling_async(
734
969
  job_id=job_id,
735
970
  callback_func=managed_job_utils.event_callback_func(
736
971
  job_id=job_id, task_id=task_id, task=dag.tasks[task_id]))
972
+ cancelling = True
973
+ raise
974
+ except Exception as e:
975
+ job_logger.error(f'Unexpected error in job loop for {job_id}: '
976
+ f'{common_utils.format_exception(e)}')
977
+ raise
978
+ finally:
979
+ try:
980
+ await self._cleanup(job_id,
981
+ dag_yaml=dag_yaml,
982
+ job_logger=job_logger,
983
+ pool=pool)
984
+ job_logger.info(
985
+ f'Cluster of managed job {job_id} has been cleaned up.')
986
+ except Exception as e: # pylint: disable=broad-except
987
+ await managed_job_state.set_failed_async(
988
+ job_id,
989
+ task_id=None,
990
+ failure_type=managed_job_state.ManagedJobStatus.
991
+ FAILED_CONTROLLER,
992
+ failure_reason=e,
993
+ override_terminal=True)
994
+
995
+ if cancelling:
996
+ # Since it's set with cancelling
997
+ assert task_id is not None, job_id
998
+ await managed_job_state.set_cancelled_async(
999
+ job_id=job_id,
1000
+ callback_func=managed_job_utils.event_callback_func(
1001
+ job_id=job_id, task_id=task_id,
1002
+ task=dag.tasks[task_id]))
1003
+
1004
+ # We should check job status after 'set_cancelled', otherwise
1005
+ # the job status is not terminal.
1006
+ job_status = await managed_job_state.get_status_async(job_id)
1007
+ assert job_status is not None
1008
+ # The job can be non-terminal if the controller exited abnormally,
1009
+ # e.g. failed to launch cluster after reaching the MAX_RETRY.
1010
+ if not job_status.is_terminal():
1011
+ job_logger.info(f'Previous job status: {job_status.value}')
1012
+ await managed_job_state.set_failed_async(
1013
+ job_id,
1014
+ task_id=None,
1015
+ failure_type=managed_job_state.ManagedJobStatus.
1016
+ FAILED_CONTROLLER,
1017
+ failure_reason=(
1018
+ 'Unexpected error occurred. For details, '
1019
+ f'run: sky jobs logs --controller {job_id}'))
1020
+
1021
+ await scheduler.job_done_async(job_id)
1022
+
1023
+ async with self._job_tasks_lock:
1024
+ try:
1025
+ # just in case we were cancelled or some other error
1026
+ # occurred during launch
1027
+ self.starting.remove(job_id)
1028
+ # its fine if we notify again, better to wake someone up
1029
+ # and have them go to sleep again, then have some stuck
1030
+ # sleeping.
1031
+ self._starting_signal.notify()
1032
+ except KeyError:
1033
+ pass
1034
+
1035
+ # Remove the job from the job_tasks dictionary.
1036
+ async with self._job_tasks_lock:
1037
+ if job_id in self.job_tasks:
1038
+ del self.job_tasks[job_id]
1039
+
1040
+ async def start_job(
1041
+ self,
1042
+ job_id: int,
1043
+ dag_yaml: str,
1044
+ env_file_path: Optional[str] = None,
1045
+ pool: Optional[str] = None,
1046
+ ):
1047
+ """Start a new job.
1048
+
1049
+ Args:
1050
+ job_id: The ID of the job to start.
1051
+ dag_yaml: Path to the YAML file containing the DAG definition.
1052
+ env_file_path: Optional path to environment file for the job.
1053
+ """
1054
+ # Create a job-specific logger
1055
+ log_dir = os.path.expanduser(jobs_constants.JOBS_CONTROLLER_LOGS_DIR)
1056
+ os.makedirs(log_dir, exist_ok=True)
1057
+ log_file = os.path.join(log_dir, f'{job_id}.log')
1058
+
1059
+ job_logger = logging.getLogger(f'sky.jobs.{job_id}')
1060
+ job_logger.setLevel(logging.DEBUG)
1061
+
1062
+ # Create file handler
1063
+ file_handler = logging.FileHandler(log_file)
1064
+ file_handler.setLevel(logging.DEBUG)
1065
+
1066
+ # Use Sky's standard formatter
1067
+ file_handler.setFormatter(sky_logging.FORMATTER)
737
1068
 
738
- # We should check job status after 'set_cancelled', otherwise
739
- # the job status is not terminal.
740
- job_status = managed_job_state.get_status(job_id)
741
- assert job_status is not None
742
- # The job can be non-terminal if the controller exited abnormally,
743
- # e.g. failed to launch cluster after reaching the MAX_RETRY.
744
- if not job_status.is_terminal():
745
- logger.info(f'Previous job status: {job_status.value}')
746
- managed_job_state.set_failed(
747
- job_id,
748
- task_id=None,
749
- failure_type=managed_job_state.ManagedJobStatus.
750
- FAILED_CONTROLLER,
751
- failure_reason=('Unexpected error occurred. For details, '
752
- f'run: sky jobs logs --controller {job_id}'))
753
-
754
- scheduler.job_done(job_id)
1069
+ # Add the handler to the logger
1070
+ job_logger.addHandler(file_handler)
1071
+
1072
+ # Prevent log propagation to avoid duplicate logs
1073
+ job_logger.propagate = False
1074
+
1075
+ job_logger.info(f'Starting job {job_id} with dag_yaml={dag_yaml}, '
1076
+ f'env_file_path={env_file_path}')
1077
+
1078
+ async with self._job_tasks_lock:
1079
+ self.starting.add(job_id)
1080
+ await create_background_task(
1081
+ self.run_job_loop(job_id, dag_yaml, job_logger, log_file,
1082
+ env_file_path, pool))
1083
+
1084
+ job_logger.info(f'Job {job_id} started successfully')
1085
+
1086
+ async def cancel_job(self):
1087
+ """Cancel an existing job."""
1088
+ while True:
1089
+ cancels = os.listdir(jobs_constants.CONSOLIDATED_SIGNAL_PATH)
1090
+ for cancel in cancels:
1091
+ async with self._job_tasks_lock:
1092
+ job_id = int(cancel)
1093
+ if job_id in self.job_tasks:
1094
+ logger.info(f'Cancelling job {job_id}')
1095
+
1096
+ task = self.job_tasks[job_id]
1097
+
1098
+ # Run the cancellation in the background, so we can
1099
+ # return immediately.
1100
+ task.cancel()
1101
+ logger.info(f'Job {job_id} cancelled successfully')
1102
+
1103
+ os.remove(f'{jobs_constants.CONSOLIDATED_SIGNAL_PATH}/'
1104
+ f'{job_id}')
1105
+ await asyncio.sleep(15)
1106
+
1107
+ async def monitor_loop(self):
1108
+ """Monitor the job loop."""
1109
+ logger.info(f'Starting monitor loop for pid {os.getpid()}...')
1110
+
1111
+ while True:
1112
+ async with self._job_tasks_lock:
1113
+ running_tasks = [
1114
+ task for task in self.job_tasks.values() if not task.done()
1115
+ ]
1116
+
1117
+ async with self._job_tasks_lock:
1118
+ starting_count = len(self.starting)
1119
+
1120
+ if starting_count >= scheduler.LAUNCHES_PER_WORKER:
1121
+ # Launching a job takes around 1 minute, so let's wait half that
1122
+ # time
1123
+ await asyncio.sleep(30)
1124
+ continue
1125
+
1126
+ if len(running_tasks) >= scheduler.JOBS_PER_WORKER:
1127
+ await asyncio.sleep(60)
1128
+ continue
1129
+
1130
+ # Check if there are any jobs that are waiting to launch
1131
+ try:
1132
+ waiting_job = await managed_job_state.get_waiting_job_async(
1133
+ pid=-os.getpid())
1134
+ except Exception as e: # pylint: disable=broad-except
1135
+ logger.error(f'Failed to get waiting job: {e}')
1136
+ await asyncio.sleep(5)
1137
+ continue
1138
+
1139
+ if waiting_job is None:
1140
+ await asyncio.sleep(10)
1141
+ continue
1142
+
1143
+ job_id = waiting_job['job_id']
1144
+ dag_yaml_path = waiting_job['dag_yaml_path']
1145
+ env_file_path = waiting_job.get('env_file_path')
1146
+ pool = waiting_job.get('pool', None)
1147
+
1148
+ cancels = os.listdir(jobs_constants.CONSOLIDATED_SIGNAL_PATH)
1149
+ if str(job_id) in cancels:
1150
+ status = await managed_job_state.get_status_async(job_id)
1151
+ if status == managed_job_state.ManagedJobStatus.PENDING:
1152
+ logger.info(f'Job {job_id} cancelled')
1153
+ os.remove(f'{jobs_constants.CONSOLIDATED_SIGNAL_PATH}/'
1154
+ f'{job_id}')
1155
+ await managed_job_state.set_cancelling_async(
1156
+ job_id=job_id,
1157
+ callback_func=managed_job_utils.event_callback_func(
1158
+ job_id=job_id, task_id=None, task=None))
1159
+ await managed_job_state.set_cancelled_async(
1160
+ job_id=job_id,
1161
+ callback_func=managed_job_utils.event_callback_func(
1162
+ job_id=job_id, task_id=None, task=None))
1163
+ continue
1164
+
1165
+ await self.start_job(job_id, dag_yaml_path, env_file_path, pool)
1166
+
1167
+
1168
+ async def main():
1169
+ context_utils.hijack_sys_attrs()
1170
+
1171
+ controller = Controller()
1172
+
1173
+ # This may run multiple times; that is harmless since exist_ok=True.
1174
+ os.makedirs(jobs_constants.CONSOLIDATED_SIGNAL_PATH, exist_ok=True)
1175
+
1176
+ # Increase number of files we can open
1177
+ soft = None
1178
+ try:
1179
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
1180
+ resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
1181
+ except OSError as e:
1182
+ logger.warning(f'Failed to increase number of files we can open: {e}\n'
1183
+ f'Current soft limit: {soft}, hard limit: {hard}')
1184
+
1185
+ # Will loop forever, do it in the background
1186
+ cancel_job_task = asyncio.create_task(controller.cancel_job())
1187
+ monitor_loop_task = asyncio.create_task(controller.monitor_loop())
1188
+
1189
+ try:
1190
+ await asyncio.gather(cancel_job_task, monitor_loop_task)
1191
+ except Exception as e: # pylint: disable=broad-except
1192
+ logger.error(f'Controller server crashed: {e}')
1193
+ sys.exit(1)
755
1194
 
756
1195
 
757
1196
  if __name__ == '__main__':
758
- parser = argparse.ArgumentParser()
759
- parser.add_argument('--job-id',
760
- required=True,
761
- type=int,
762
- help='Job id for the controller job.')
763
- parser.add_argument('dag_yaml',
764
- type=str,
765
- help='The path to the user job yaml file.')
766
- parser.add_argument('--pool',
767
- required=False,
768
- default=None,
769
- type=str,
770
- help='The pool to use for the controller job.')
771
- args = parser.parse_args()
772
- # We start process with 'spawn', because 'fork' could result in weird
773
- # behaviors; 'spawn' is also cross-platform.
774
- multiprocessing.set_start_method('spawn', force=True)
775
- start(args.job_id, args.dag_yaml, args.pool)
1197
+ asyncio.run(main())
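
Note on the asyncio patterns introduced in this diff: the new create_background_task helper keeps strong references to fire-and-forget tasks (its docstring notes that tasks are otherwise only held as weak references), and the starting set plus starting_signal condition throttle how many jobs may be launching at once. The sketch below is illustrative only and is not code from the wheel; it loosely mirrors those two patterns under simplified assumptions, and the names MAX_STARTING, LaunchThrottler, and fake_launch are invented for the example.

import asyncio
from typing import Coroutine, Set

MAX_STARTING = 2  # hypothetical cap on concurrently "starting" jobs

_background_tasks: Set[asyncio.Task] = set()


def create_background_task(coro: Coroutine) -> None:
    """Schedule coro and hold a strong reference until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    # Drop the strong reference once the task finishes.
    task.add_done_callback(_background_tasks.discard)


class LaunchThrottler:
    """Admit at most MAX_STARTING jobs into the starting phase at a time."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._signal = asyncio.Condition(lock=self._lock)
        self._starting: Set[int] = set()

    async def run_job(self, job_id: int) -> None:
        async with self._signal:
            # Wait until there is room to start another job.
            await self._signal.wait_for(
                lambda: len(self._starting) < MAX_STARTING)
            self._starting.add(job_id)
        try:
            await fake_launch(job_id)
        finally:
            async with self._signal:
                self._starting.discard(job_id)
                # Extra notifications are harmless; missing one could leave a
                # waiter stuck, so always notify on the way out.
                self._signal.notify()


async def fake_launch(job_id: int) -> None:
    """Stand-in for the real (slow) cluster launch."""
    print(f'job {job_id}: launching')
    await asyncio.sleep(0.1)
    print(f'job {job_id}: launched')


async def main() -> None:
    throttler = LaunchThrottler()
    for job_id in range(5):
        create_background_task(throttler.run_job(job_id))
    # Give the background tasks time to finish in this toy example.
    await asyncio.sleep(1)


if __name__ == '__main__':
    asyncio.run(main())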