dstack-0.19.20-py3-none-any.whl → dstack-0.19.22-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dstack might be problematic.

Files changed (93)
  1. dstack/_internal/cli/commands/apply.py +8 -3
  2. dstack/_internal/cli/services/configurators/__init__.py +8 -0
  3. dstack/_internal/cli/services/configurators/fleet.py +1 -1
  4. dstack/_internal/cli/services/configurators/gateway.py +1 -1
  5. dstack/_internal/cli/services/configurators/run.py +11 -1
  6. dstack/_internal/cli/services/configurators/volume.py +1 -1
  7. dstack/_internal/cli/utils/common.py +48 -5
  8. dstack/_internal/cli/utils/fleet.py +5 -5
  9. dstack/_internal/cli/utils/run.py +32 -0
  10. dstack/_internal/core/backends/__init__.py +0 -65
  11. dstack/_internal/core/backends/configurators.py +9 -0
  12. dstack/_internal/core/backends/features.py +64 -0
  13. dstack/_internal/core/backends/hotaisle/__init__.py +1 -0
  14. dstack/_internal/core/backends/hotaisle/api_client.py +109 -0
  15. dstack/_internal/core/backends/hotaisle/backend.py +16 -0
  16. dstack/_internal/core/backends/hotaisle/compute.py +225 -0
  17. dstack/_internal/core/backends/hotaisle/configurator.py +60 -0
  18. dstack/_internal/core/backends/hotaisle/models.py +45 -0
  19. dstack/_internal/core/backends/lambdalabs/compute.py +2 -1
  20. dstack/_internal/core/backends/models.py +8 -0
  21. dstack/_internal/core/compatibility/fleets.py +2 -0
  22. dstack/_internal/core/compatibility/runs.py +12 -0
  23. dstack/_internal/core/models/backends/base.py +2 -0
  24. dstack/_internal/core/models/configurations.py +139 -1
  25. dstack/_internal/core/models/health.py +28 -0
  26. dstack/_internal/core/models/instances.py +2 -0
  27. dstack/_internal/core/models/logs.py +2 -1
  28. dstack/_internal/core/models/profiles.py +37 -0
  29. dstack/_internal/core/models/runs.py +21 -1
  30. dstack/_internal/core/services/ssh/tunnel.py +7 -0
  31. dstack/_internal/server/app.py +26 -10
  32. dstack/_internal/server/background/__init__.py +9 -6
  33. dstack/_internal/server/background/tasks/process_fleets.py +52 -38
  34. dstack/_internal/server/background/tasks/process_gateways.py +2 -2
  35. dstack/_internal/server/background/tasks/process_idle_volumes.py +5 -4
  36. dstack/_internal/server/background/tasks/process_instances.py +168 -103
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -2
  38. dstack/_internal/server/background/tasks/process_placement_groups.py +2 -0
  39. dstack/_internal/server/background/tasks/process_probes.py +164 -0
  40. dstack/_internal/server/background/tasks/process_prometheus_metrics.py +14 -2
  41. dstack/_internal/server/background/tasks/process_running_jobs.py +142 -124
  42. dstack/_internal/server/background/tasks/process_runs.py +84 -34
  43. dstack/_internal/server/background/tasks/process_submitted_jobs.py +12 -10
  44. dstack/_internal/server/background/tasks/process_terminating_jobs.py +12 -4
  45. dstack/_internal/server/background/tasks/process_volumes.py +4 -1
  46. dstack/_internal/server/migrations/versions/25479f540245_add_probes.py +43 -0
  47. dstack/_internal/server/migrations/versions/50dd7ea98639_index_status_columns.py +55 -0
  48. dstack/_internal/server/migrations/versions/728b1488b1b4_add_instance_health.py +50 -0
  49. dstack/_internal/server/migrations/versions/ec02a26a256c_add_runmodel_next_triggered_at.py +38 -0
  50. dstack/_internal/server/models.py +57 -16
  51. dstack/_internal/server/routers/instances.py +33 -5
  52. dstack/_internal/server/schemas/health/dcgm.py +56 -0
  53. dstack/_internal/server/schemas/instances.py +32 -0
  54. dstack/_internal/server/schemas/runner.py +5 -0
  55. dstack/_internal/server/services/fleets.py +19 -10
  56. dstack/_internal/server/services/gateways/__init__.py +17 -17
  57. dstack/_internal/server/services/instances.py +113 -15
  58. dstack/_internal/server/services/jobs/__init__.py +18 -13
  59. dstack/_internal/server/services/jobs/configurators/base.py +26 -0
  60. dstack/_internal/server/services/logging.py +4 -2
  61. dstack/_internal/server/services/logs/aws.py +13 -1
  62. dstack/_internal/server/services/logs/gcp.py +16 -1
  63. dstack/_internal/server/services/offers.py +3 -3
  64. dstack/_internal/server/services/probes.py +6 -0
  65. dstack/_internal/server/services/projects.py +51 -19
  66. dstack/_internal/server/services/prometheus/client_metrics.py +3 -0
  67. dstack/_internal/server/services/prometheus/custom_metrics.py +2 -3
  68. dstack/_internal/server/services/runner/client.py +52 -20
  69. dstack/_internal/server/services/runner/ssh.py +4 -4
  70. dstack/_internal/server/services/runs.py +115 -39
  71. dstack/_internal/server/services/services/__init__.py +4 -1
  72. dstack/_internal/server/services/ssh.py +66 -0
  73. dstack/_internal/server/services/users.py +2 -3
  74. dstack/_internal/server/services/volumes.py +11 -11
  75. dstack/_internal/server/settings.py +16 -0
  76. dstack/_internal/server/statics/index.html +1 -1
  77. dstack/_internal/server/statics/{main-8f9ee218d3eb45989682.css → main-03e818b110e1d5705378.css} +1 -1
  78. dstack/_internal/server/statics/{main-39a767528976f8078166.js → main-cc067b7fd1a8f33f97da.js} +26 -15
  79. dstack/_internal/server/statics/{main-39a767528976f8078166.js.map → main-cc067b7fd1a8f33f97da.js.map} +1 -1
  80. dstack/_internal/server/testing/common.py +51 -0
  81. dstack/_internal/{core/backends/remote → server/utils}/provisioning.py +22 -17
  82. dstack/_internal/server/utils/sentry_utils.py +12 -0
  83. dstack/_internal/settings.py +3 -0
  84. dstack/_internal/utils/common.py +15 -0
  85. dstack/_internal/utils/cron.py +5 -0
  86. dstack/api/server/__init__.py +1 -1
  87. dstack/version.py +1 -1
  88. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/METADATA +13 -22
  89. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/RECORD +93 -75
  90. /dstack/_internal/{core/backends/remote → server/schemas/health}/__init__.py +0 -0
  91. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/WHEEL +0 -0
  92. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/entry_points.txt +0 -0
  93. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/server/background/tasks/process_prometheus_metrics.py

@@ -9,11 +9,17 @@ from sqlalchemy.orm import joinedload
 from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT
 from dstack._internal.core.models.runs import JobStatus
 from dstack._internal.server.db import get_session_ctx
-from dstack._internal.server.models import InstanceModel, JobModel, JobPrometheusMetrics
+from dstack._internal.server.models import (
+    InstanceModel,
+    JobModel,
+    JobPrometheusMetrics,
+    ProjectModel,
+)
 from dstack._internal.server.services.instances import get_instance_ssh_private_keys
 from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data
 from dstack._internal.server.services.runner import client
 from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
+from dstack._internal.server.utils import sentry_utils
 from dstack._internal.server.utils.common import gather_map_async
 from dstack._internal.utils.common import batched, get_current_datetime, get_or_error, run_async
 from dstack._internal.utils.logging import get_logger
@@ -29,6 +35,7 @@ MIN_COLLECT_INTERVAL_SECONDS = 9
 METRICS_TTL_SECONDS = 600


+@sentry_utils.instrument_background_task
 async def collect_prometheus_metrics():
     now = get_current_datetime()
     cutoff = now - timedelta(seconds=MIN_COLLECT_INTERVAL_SECONDS)
@@ -43,7 +50,11 @@ async def collect_prometheus_metrics():
                 JobPrometheusMetrics.collected_at < cutoff,
             ),
         )
-        .options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
+        .options(
+            joinedload(JobModel.instance)
+            .joinedload(InstanceModel.project)
+            .load_only(ProjectModel.ssh_private_key)
+        )
         .order_by(JobModel.last_processed_at.asc())
         .limit(MAX_JOBS_FETCHED)
     )
@@ -52,6 +63,7 @@ async def collect_prometheus_metrics():
         await _collect_jobs_metrics(batch, now)


+@sentry_utils.instrument_background_task
 async def delete_prometheus_metrics():
     now = get_current_datetime()
     cutoff = now - timedelta(seconds=METRICS_TTL_SECONDS)
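Note: the body of `instrument_background_task` is not shown in this diff (`dstack/_internal/server/utils/sentry_utils.py` gains 12 lines, file 82 in the list above). A minimal sketch of what such a decorator typically looks like, assuming the stock `sentry_sdk` transaction API; the names and details below are illustrative, not dstack's actual implementation:

```python
import functools

import sentry_sdk


def instrument_background_task(func):
    """Report each task invocation as a Sentry transaction (hypothetical sketch)."""

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        # One transaction per invocation, named after the wrapped coroutine,
        # so periodic tasks show up individually in Sentry's performance view.
        with sentry_sdk.start_transaction(op="background_task", name=func.__qualname__):
            return await func(*args, **kwargs)

    return wrapper
```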
dstack/_internal/server/background/tasks/process_running_jobs.py

@@ -2,12 +2,12 @@ import asyncio
 import re
 import uuid
 from collections.abc import Iterable
-from datetime import timedelta, timezone
+from datetime import timedelta
 from typing import Dict, List, Optional

 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import joinedload, load_only

 from dstack._internal import settings
 from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT
@@ -42,6 +42,7 @@ from dstack._internal.server.db import get_db, get_session_ctx
 from dstack._internal.server.models import (
     InstanceModel,
     JobModel,
+    ProbeModel,
     ProjectModel,
     RepoModel,
     RunModel,
@@ -73,6 +74,7 @@ from dstack._internal.server.services.runs import (
 )
 from dstack._internal.server.services.secrets import get_project_secrets_mapping
 from dstack._internal.server.services.storage import get_default_storage
+from dstack._internal.server.utils import sentry_utils
 from dstack._internal.utils import common as common_utils
 from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
 from dstack._internal.utils.logging import get_logger
@@ -94,6 +96,7 @@ async def process_running_jobs(batch_size: int = 1):
     await asyncio.gather(*tasks)


+@sentry_utils.instrument_background_task
 async def _process_next_running_job():
     lock, lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__)
     async with get_session_ctx() as session:
@@ -108,9 +111,9 @@ async def _process_next_running_job():
                 RunModel.status.not_in([RunStatus.TERMINATING]),
                 JobModel.id.not_in(lockset),
                 JobModel.last_processed_at
-                < common_utils.get_current_datetime().replace(tzinfo=None)
-                - MIN_PROCESSING_INTERVAL,
+                < common_utils.get_current_datetime() - MIN_PROCESSING_INTERVAL,
             )
+            .options(load_only(JobModel.id))
            .order_by(JobModel.last_processed_at.asc())
            .limit(1)
            .with_for_update(
@@ -133,7 +136,6 @@

 async def _process_running_job(session: AsyncSession, job_model: JobModel):
     # Refetch to load related attributes.
-    # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
     res = await session.execute(
         select(JobModel)
         .where(JobModel.id == job_model.id)
@@ -144,7 +146,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
     res = await session.execute(
         select(RunModel)
         .where(RunModel.id == job_model.run_id)
-        .options(joinedload(RunModel.project).joinedload(ProjectModel.backends))
+        .options(joinedload(RunModel.project))
         .options(joinedload(RunModel.user))
         .options(joinedload(RunModel.repo))
         .options(joinedload(RunModel.jobs))
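Note: the recurring theme in these hunks is `load_only()`: queries that exist only to pick and lock a row stop fetching wide columns (specs, provisioning data) they never read. A self-contained sketch of the pattern with a stand-in model (not dstack's actual `JobModel`):

```python
from datetime import datetime

from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, load_only, mapped_column


class Base(DeclarativeBase):
    pass


class Job(Base):  # stand-in model, illustration only
    __tablename__ = "jobs"
    id: Mapped[int] = mapped_column(primary_key=True)
    job_spec_data: Mapped[str]  # wide column we don't want to fetch
    last_processed_at: Mapped[datetime]


# load_only() narrows the SELECT list to the primary key, so the
# lock-acquisition query stays cheap; the full row is refetched afterwards.
stmt = (
    select(Job)
    .options(load_only(Job.id))
    .order_by(Job.last_processed_at.asc())
    .limit(1)
    .with_for_update(skip_locked=True, key_share=True)
)
print(stmt)  # selects only jobs.id; on PostgreSQL the lock renders as FOR NO KEY UPDATE SKIP LOCKED
```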
@@ -160,143 +162,147 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
         job_model.status = JobStatus.TERMINATING
         job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
         job_model.last_processed_at = common_utils.get_current_datetime()
+        await session.commit()
         return

     job = find_job(run.jobs, job_model.replica_num, job_model.job_num)

-    # Wait until all other jobs in the replica are provisioned
-    for other_job in run.jobs:
-        if (
-            other_job.job_spec.replica_num == job.job_spec.replica_num
-            and other_job.job_submissions[-1].status == JobStatus.SUBMITTED
-        ):
+    initial_status = job_model.status
+    if initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]:
+        # Wait until all other jobs in the replica are provisioned
+        for other_job in run.jobs:
+            if (
+                other_job.job_spec.replica_num == job.job_spec.replica_num
+                and other_job.job_submissions[-1].status == JobStatus.SUBMITTED
+            ):
+                job_model.last_processed_at = common_utils.get_current_datetime()
+                await session.commit()
+                return
+
+    cluster_info = _get_cluster_info(
+        jobs=run.jobs,
+        replica_num=job.job_spec.replica_num,
+        job_provisioning_data=job_provisioning_data,
+        job_runtime_data=job_submission.job_runtime_data,
+    )
+
+    volumes = await get_job_attached_volumes(
+        session=session,
+        project=project,
+        run_spec=run.run_spec,
+        job_num=job.job_spec.job_num,
+        job_provisioning_data=job_provisioning_data,
+    )
+
+    repo_creds_model = await get_repo_creds(
+        session=session, repo=repo_model, user=run_model.user
+    )
+    repo_creds = repo_model_to_repo_head_with_creds(repo_model, repo_creds_model).repo_creds
+
+    secrets = await get_project_secrets_mapping(session=session, project=project)
+    try:
+        _interpolate_secrets(secrets, job.job_spec)
+    except InterpolatorError as e:
+        logger.info("%s: terminating due to secrets interpolation error", fmt(job_model))
+        job_model.status = JobStatus.TERMINATING
+        job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
+        job_model.termination_reason_message = e.args[0]
         job_model.last_processed_at = common_utils.get_current_datetime()
         await session.commit()
         return

-    cluster_info = _get_cluster_info(
-        jobs=run.jobs,
-        replica_num=job.job_spec.replica_num,
-        job_provisioning_data=job_provisioning_data,
-        job_runtime_data=job_submission.job_runtime_data,
-    )
-
-    volumes = await get_job_attached_volumes(
-        session=session,
-        project=project,
-        run_spec=run.run_spec,
-        job_num=job.job_spec.job_num,
-        job_provisioning_data=job_provisioning_data,
-    )
-
     server_ssh_private_keys = get_instance_ssh_private_keys(
         common_utils.get_or_error(job_model.instance)
     )

-    secrets = await get_project_secrets_mapping(session=session, project=project)
-
-    try:
-        _interpolate_secrets(secrets, job.job_spec)
-    except InterpolatorError as e:
-        logger.info("%s: terminating due to secrets interpolation error", fmt(job_model))
-        job_model.status = JobStatus.TERMINATING
-        job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
-        job_model.termination_reason_message = e.args[0]
-        job_model.last_processed_at = common_utils.get_current_datetime()
-        return
-
-    repo_creds_model = await get_repo_creds(session=session, repo=repo_model, user=run_model.user)
-    repo_creds = repo_model_to_repo_head_with_creds(repo_model, repo_creds_model).repo_creds
-
-    initial_status = job_model.status
     if initial_status == JobStatus.PROVISIONING:
         if job_provisioning_data.hostname is None:
             await _wait_for_instance_provisioning_data(job_model=job_model)
+            job_model.last_processed_at = common_utils.get_current_datetime()
+            await session.commit()
+            return
+        if _should_wait_for_other_nodes(run, job, job_model):
+            job_model.last_processed_at = common_utils.get_current_datetime()
+            await session.commit()
+            return
+
+        # fails are acceptable until timeout is exceeded
+        if job_provisioning_data.dockerized:
+            logger.debug(
+                "%s: process provisioning job with shim, age=%s",
+                fmt(job_model),
+                job_submission.age,
+            )
+            ssh_user = job_provisioning_data.username
+            user_ssh_key = run.run_spec.ssh_key_pub.strip()
+            public_keys = [project.ssh_public_key.strip(), user_ssh_key]
+            if job_provisioning_data.backend == BackendType.LOCAL:
+                # No need to update ~/.ssh/authorized_keys when running shim locally
+                user_ssh_key = ""
+            success = await common_utils.run_async(
+                _process_provisioning_with_shim,
+                server_ssh_private_keys,
+                job_provisioning_data,
+                None,
+                run,
+                job_model,
+                job_provisioning_data,
+                volumes,
+                job.job_spec.registry_auth,
+                public_keys,
+                ssh_user,
+                user_ssh_key,
+            )
         else:
-            if _should_wait_for_other_nodes(run, job, job_model):
-                job_model.last_processed_at = common_utils.get_current_datetime()
-                await session.commit()
-                return
+            logger.debug(
+                "%s: process provisioning job without shim, age=%s",
+                fmt(job_model),
+                job_submission.age,
+            )
+            # FIXME: downloading file archives and code here is a waste of time if
+            # the runner is not ready yet
+            file_archives = await _get_job_file_archives(
+                session=session,
+                archive_mappings=job.job_spec.file_archives,
+                user=run_model.user,
+            )
+            code = await _get_job_code(
+                session=session,
+                project=project,
+                repo=repo_model,
+                code_hash=_get_repo_code_hash(run, job),
+            )

-            # fails are acceptable until timeout is exceeded
-            if job_provisioning_data.dockerized:
-                logger.debug(
-                    "%s: process provisioning job with shim, age=%s",
-                    fmt(job_model),
-                    job_submission.age,
-                )
-                ssh_user = job_provisioning_data.username
-                user_ssh_key = run.run_spec.ssh_key_pub.strip()
-                public_keys = [project.ssh_public_key.strip(), user_ssh_key]
-                if job_provisioning_data.backend == BackendType.LOCAL:
-                    # No need to update ~/.ssh/authorized_keys when running shim locally
-                    user_ssh_key = ""
-                success = await common_utils.run_async(
-                    _process_provisioning_with_shim,
-                    server_ssh_private_keys,
-                    job_provisioning_data,
-                    None,
-                    run,
-                    job_model,
-                    job_provisioning_data,
-                    volumes,
-                    job.job_spec.registry_auth,
-                    public_keys,
-                    ssh_user,
-                    user_ssh_key,
-                )
-            else:
-                logger.debug(
-                    "%s: process provisioning job without shim, age=%s",
+            success = await common_utils.run_async(
+                _submit_job_to_runner,
+                server_ssh_private_keys,
+                job_provisioning_data,
+                None,
+                run,
+                job_model,
+                job,
+                cluster_info,
+                code,
+                file_archives,
+                secrets,
+                repo_creds,
+                success_if_not_available=False,
+            )
+
+        if not success:
+            # check timeout
+            if job_submission.age > get_provisioning_timeout(
+                backend_type=job_provisioning_data.get_base_backend(),
+                instance_type_name=job_provisioning_data.instance_type.name,
+            ):
+                logger.warning(
+                    "%s: failed because runner has not become available in time, age=%s",
                     fmt(job_model),
                     job_submission.age,
                 )
-                # FIXME: downloading file archives and code here is a waste of time if
-                # the runner is not ready yet
-                file_archives = await _get_job_file_archives(
-                    session=session,
-                    archive_mappings=job.job_spec.file_archives,
-                    user=run_model.user,
-                )
-                code = await _get_job_code(
-                    session=session,
-                    project=project,
-                    repo=repo_model,
-                    code_hash=_get_repo_code_hash(run, job),
-                )
-
-                success = await common_utils.run_async(
-                    _submit_job_to_runner,
-                    server_ssh_private_keys,
-                    job_provisioning_data,
-                    None,
-                    run,
-                    job_model,
-                    job,
-                    cluster_info,
-                    code,
-                    file_archives,
-                    secrets,
-                    repo_creds,
-                    success_if_not_available=False,
-                )
-
-            if not success:
-                # check timeout
-                if job_submission.age > get_provisioning_timeout(
-                    backend_type=job_provisioning_data.get_base_backend(),
-                    instance_type_name=job_provisioning_data.instance_type.name,
-                ):
-                    logger.warning(
-                        "%s: failed because runner has not become available in time, age=%s",
-                        fmt(job_model),
-                        job_submission.age,
-                    )
-                    job_model.status = JobStatus.TERMINATING
-                    job_model.termination_reason = (
-                        JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED
-                    )
-                    # instance will be emptied by process_terminating_jobs
+                job_model.status = JobStatus.TERMINATING
+                job_model.termination_reason = JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED
+                # instance will be emptied by process_terminating_jobs

     else:  # fails are not acceptable
         if initial_status == JobStatus.PULLING:
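Note: both provisioning paths above hand blocking SSH work (`_process_provisioning_with_shim`, `_submit_job_to_runner`) to `common_utils.run_async`. Its body is not part of this diff; a plausible sketch, assuming it simply pushes a synchronous callable onto the default thread-pool executor so the event loop keeps serving other tasks:

```python
import asyncio
import functools


async def run_async(func, *args, **kwargs):
    """Run a blocking function in a worker thread (sketch of the assumed helper)."""
    loop = asyncio.get_running_loop()
    # run_in_executor(None, ...) uses the loop's default ThreadPoolExecutor.
    return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
```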
@@ -409,6 +415,18 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
             )
             job_model.status = JobStatus.TERMINATING
             job_model.termination_reason = JobTerminationReason.GATEWAY_ERROR
+        else:
+            for probe_num in range(len(job.job_spec.probes)):
+                session.add(
+                    ProbeModel(
+                        name=f"{job_model.job_name}-{probe_num}",
+                        job=job_model,
+                        probe_num=probe_num,
+                        due=common_utils.get_current_datetime(),
+                        success_streak=0,
+                        active=True,
+                    )
+                )

     if job_model.status == JobStatus.RUNNING:
         await _check_gpu_utilization(session, job_model, job)
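Note: this is the seeding half of the new probes feature: one `ProbeModel` row per probe declared in the job spec, due immediately, with an empty success streak. The polling half lives in the new `process_probes.py` task (file 39 above), which is not shown here. A hypothetical sketch of the bookkeeping it implies, with a dataclass standing in for the ORM model:

```python
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone


@dataclass
class Probe:  # field names mirror the ProbeModel columns seeded above
    name: str
    probe_num: int
    due: datetime
    success_streak: int
    active: bool


# Assumed interval; the real schedule lives in process_probes.py.
PROBE_INTERVAL = timedelta(seconds=10)


def record_probe_result(probe: Probe, ok: bool) -> None:
    """Advance the streak on success, reset it on failure, and reschedule."""
    probe.success_streak = probe.success_streak + 1 if ok else 0
    probe.due = datetime.now(timezone.utc) + PROBE_INTERVAL
```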
@@ -801,7 +819,7 @@ def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool:
         return False
     return (
         common_utils.get_current_datetime()
-        > job_model.disconnected_at.replace(tzinfo=timezone.utc) + JOB_DISCONNECTED_RETRY_TIMEOUT
+        > job_model.disconnected_at + JOB_DISCONNECTED_RETRY_TIMEOUT
     )
dstack/_internal/server/background/tasks/process_runs.py

@@ -2,9 +2,9 @@ import asyncio
 import datetime
 from typing import List, Optional, Set, Tuple

-from sqlalchemy import select
+from sqlalchemy import and_, or_, select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload, selectinload
+from sqlalchemy.orm import joinedload, load_only, selectinload

 import dstack._internal.server.services.services.autoscalers as autoscalers
 from dstack._internal.core.errors import ServerError
@@ -20,7 +20,14 @@ from dstack._internal.core.models.runs import (
     RunTerminationReason,
 )
 from dstack._internal.server.db import get_db, get_session_ctx
-from dstack._internal.server.models import JobModel, ProjectModel, RunModel
+from dstack._internal.server.models import (
+    InstanceModel,
+    JobModel,
+    ProbeModel,
+    ProjectModel,
+    RunModel,
+    UserModel,
+)
 from dstack._internal.server.services.jobs import (
     find_job,
     get_job_specs_from_run_spec,
@@ -30,6 +37,7 @@ from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.prometheus.client_metrics import run_metrics
 from dstack._internal.server.services.runs import (
     fmt,
+    is_replica_ready,
     process_terminating_run,
     retry_run_replica_jobs,
     run_model_to_run,
@@ -37,6 +45,7 @@ from dstack._internal.server.services.runs import (
 )
 from dstack._internal.server.services.secrets import get_project_secrets_mapping
 from dstack._internal.server.services.services import update_service_desired_replica_count
+from dstack._internal.server.utils import sentry_utils
 from dstack._internal.utils import common
 from dstack._internal.utils.logging import get_logger

@@ -53,22 +62,54 @@ async def process_runs(batch_size: int = 1):
     await asyncio.gather(*tasks)


+@sentry_utils.instrument_background_task
 async def _process_next_run():
     run_lock, run_lockset = get_locker(get_db().dialect_name).get_lockset(RunModel.__tablename__)
     job_lock, job_lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__)
+    now = common.get_current_datetime()
     async with get_session_ctx() as session:
         async with run_lock, job_lock:
             res = await session.execute(
                 select(RunModel)
                 .where(
-                    RunModel.status.not_in(RunStatus.finished_statuses()),
                     RunModel.id.not_in(run_lockset),
-                    RunModel.last_processed_at
-                    < common.get_current_datetime().replace(tzinfo=None) - MIN_PROCESSING_INTERVAL,
+                    RunModel.last_processed_at < now - MIN_PROCESSING_INTERVAL,
+                    # Filter out runs that don't need to be processed.
+                    # This is only to reduce unnecessary commits.
+                    # Otherwise, we could fetch all active runs and filter them when processing.
+                    or_(
+                        # Active non-pending runs:
+                        RunModel.status.not_in(
+                            RunStatus.finished_statuses() + [RunStatus.PENDING]
+                        ),
+                        # Retrying runs:
+                        and_(
+                            RunModel.status == RunStatus.PENDING,
+                            RunModel.resubmission_attempt > 0,
+                        ),
+                        # Scheduled ready runs:
+                        and_(
+                            RunModel.status == RunStatus.PENDING,
+                            RunModel.resubmission_attempt == 0,
+                            RunModel.next_triggered_at.is_not(None),
+                            RunModel.next_triggered_at < now,
+                        ),
+                        # Scaled-to-zero runs:
+                        # Such runs cannot be scheduled, so we check next_triggered_at.
+                        # If we allowed scheduled services with downscaling to zero,
+                        # this check wouldn't pass.
+                        and_(
+                            RunModel.status == RunStatus.PENDING,
+                            RunModel.resubmission_attempt == 0,
+                            RunModel.next_triggered_at.is_(None),
+                        ),
+                    ),
                 )
+                .options(joinedload(RunModel.jobs).load_only(JobModel.id))
+                .options(load_only(RunModel.id))
                 .order_by(RunModel.last_processed_at.asc())
                 .limit(1)
-                .with_for_update(skip_locked=True, key_share=True)
+                .with_for_update(skip_locked=True, key_share=True, of=RunModel)
             )
             run_model = res.scalar()
             if run_model is None:
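Note: `of=RunModel` is what makes the `joinedload(RunModel.jobs)` above legal under `FOR UPDATE`: PostgreSQL refuses to lock the nullable side of an outer join, so the lock has to be narrowed to the runs table (this also explains the removed "joinedload produces LEFT OUTER JOIN" comments elsewhere in the diff). A self-contained sketch with stand-in models:

```python
from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    joinedload,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class Run(Base):  # stand-in for RunModel
    __tablename__ = "runs"
    id: Mapped[int] = mapped_column(primary_key=True)
    jobs: Mapped[list["Job"]] = relationship(back_populates="run")


class Job(Base):  # stand-in for JobModel
    __tablename__ = "jobs"
    id: Mapped[int] = mapped_column(primary_key=True)
    run_id: Mapped[int] = mapped_column(ForeignKey("runs.id"))
    run: Mapped["Run"] = relationship(back_populates="jobs")


# joinedload() emits a LEFT OUTER JOIN; without of=Run, PostgreSQL raises
# "FOR UPDATE cannot be applied to the nullable side of an outer join".
stmt = (
    select(Run)
    .options(joinedload(Run.jobs))
    .with_for_update(skip_locked=True, key_share=True, of=Run)
)
```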
@@ -98,20 +139,27 @@ async def _process_next_run():


 async def _process_run(session: AsyncSession, run_model: RunModel):
-    logger.debug("%s: processing run", fmt(run_model))
     # Refetch to load related attributes.
-    # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
     res = await session.execute(
         select(RunModel)
         .where(RunModel.id == run_model.id)
         .execution_options(populate_existing=True)
-        .options(joinedload(RunModel.project).joinedload(ProjectModel.backends))
-        .options(joinedload(RunModel.user))
-        .options(joinedload(RunModel.repo))
-        .options(selectinload(RunModel.jobs).joinedload(JobModel.instance))
+        .options(joinedload(RunModel.project).load_only(ProjectModel.id, ProjectModel.name))
+        .options(joinedload(RunModel.user).load_only(UserModel.name))
+        .options(
+            selectinload(RunModel.jobs)
+            .joinedload(JobModel.instance)
+            .load_only(InstanceModel.fleet_id)
+        )
+        .options(
+            selectinload(RunModel.jobs)
+            .joinedload(JobModel.probes)
+            .load_only(ProbeModel.success_streak)
+        )
         .execution_options(populate_existing=True)
     )
     run_model = res.unique().scalar_one()
+    logger.debug("%s: processing run", fmt(run_model))
     try:
         if run_model.status == RunStatus.PENDING:
             await _process_pending_run(session, run_model)
@@ -135,8 +183,12 @@ async def _process_run(session: AsyncSession, run_model: RunModel):
 async def _process_pending_run(session: AsyncSession, run_model: RunModel):
     """Jobs are not created yet"""
     run = run_model_to_run(run_model)
-    if not _pending_run_ready_for_resubmission(run_model, run):
-        logger.debug("%s: pending run is not yet ready for resubmission", fmt(run_model))
+
+    # TODO: Do not select such runs in the first place to avoid redundant processing
+    if run_model.resubmission_attempt > 0 and not _retrying_run_ready_for_resubmission(
+        run_model, run
+    ):
+        logger.debug("%s: retrying run is not yet ready for resubmission", fmt(run_model))
         return

     run_model.desired_replica_count = 1
@@ -160,7 +212,7 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
     logger.info("%s: run status has changed PENDING -> SUBMITTED", fmt(run_model))


-def _pending_run_ready_for_resubmission(run_model: RunModel, run: Run) -> bool:
+def _retrying_run_ready_for_resubmission(run_model: RunModel, run: Run) -> bool:
     if run.latest_job_submission is None:
         # Should not be possible
         return True
@@ -197,7 +249,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
     We handle fails, scaling, and status changes.
     """
     run = run_model_to_run(run_model)
-    run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
+    run_spec = run.run_spec
     retry_single_job = _can_retry_single_job(run_spec)

     run_statuses: Set[RunStatus] = set()
@@ -337,9 +389,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
     )
     if run_model.status == RunStatus.SUBMITTED and new_status == RunStatus.PROVISIONING:
         current_time = common.get_current_datetime()
-        submit_to_provision_duration = (
-            current_time - run_model.submitted_at.replace(tzinfo=datetime.timezone.utc)
-        ).total_seconds()
+        submit_to_provision_duration = (current_time - run_model.submitted_at).total_seconds()
         logger.info(
             "%s: run took %.2f seconds from submission to provisioning.",
             fmt(run_model),
@@ -429,22 +479,22 @@ async def _handle_run_replicas(
     )

     replicas_to_stop_count = 0
-    # stop any out-of-date replicas that are not running
-    replicas_to_stop_count += len(
-        {
-            j.replica_num
-            for j in run_model.jobs
-            if j.status
-            not in [JobStatus.RUNNING, JobStatus.TERMINATING] + JobStatus.finished_statuses()
-            and j.deployment_num < run_model.deployment_num
-        }
+    # stop any out-of-date replicas that are not ready
+    replicas_to_stop_count += sum(
+        any(j.deployment_num < run_model.deployment_num for j in jobs)
+        and any(
+            j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses()
+            for j in jobs
+        )
+        and not is_replica_ready(jobs)
+        for _, jobs in group_jobs_by_replica_latest(run_model.jobs)
     )
-    running_replica_count = len(
-        {j.replica_num for j in run_model.jobs if j.status == JobStatus.RUNNING}
+    ready_replica_count = sum(
+        is_replica_ready(jobs) for _, jobs in group_jobs_by_replica_latest(run_model.jobs)
     )
-    if running_replica_count > run_model.desired_replica_count:
-        # stop excessive running out-of-date replicas
-        replicas_to_stop_count += running_replica_count - run_model.desired_replica_count
+    if ready_replica_count > run_model.desired_replica_count:
+        # stop excessive ready out-of-date replicas
+        replicas_to_stop_count += ready_replica_count - run_model.desired_replica_count
     if replicas_to_stop_count:
         await scale_run_replicas(
             session,
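Note: `is_replica_ready` and `group_jobs_by_replica_latest` are defined in `dstack/_internal/server/services/runs.py`, whose body is not shown in this diff. Given that `_process_run` now loads `ProbeModel.success_streak` for every job, a plausible (unverified) reading of the new readiness notion is "every job in the replica is running and its probes have a positive success streak", roughly:

```python
# Hypothetical sketch only; the real predicate lives in services/runs.py
# and may use a different threshold or per-probe readiness settings.
def is_replica_ready(jobs) -> bool:
    return all(
        job.status == JobStatus.RUNNING
        and all(probe.success_streak >= 1 for probe in job.probes)
        for job in jobs
    )
```

Under that reading, rolling deployments stop counting a replica toward `desired_replica_count` until its probes pass, instead of the pure `RUNNING`-status check used before.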