dstack 0.19.19__py3-none-any.whl → 0.19.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (54)
  1. dstack/_internal/core/backends/__init__.py +0 -65
  2. dstack/_internal/core/backends/cloudrift/api_client.py +13 -1
  3. dstack/_internal/core/backends/features.py +64 -0
  4. dstack/_internal/core/backends/oci/resources.py +5 -5
  5. dstack/_internal/core/compatibility/fleets.py +2 -0
  6. dstack/_internal/core/compatibility/runs.py +4 -0
  7. dstack/_internal/core/models/profiles.py +37 -0
  8. dstack/_internal/server/app.py +22 -10
  9. dstack/_internal/server/background/__init__.py +5 -6
  10. dstack/_internal/server/background/tasks/process_fleets.py +52 -38
  11. dstack/_internal/server/background/tasks/process_gateways.py +2 -2
  12. dstack/_internal/server/background/tasks/process_idle_volumes.py +5 -4
  13. dstack/_internal/server/background/tasks/process_instances.py +62 -48
  14. dstack/_internal/server/background/tasks/process_metrics.py +9 -2
  15. dstack/_internal/server/background/tasks/process_placement_groups.py +2 -0
  16. dstack/_internal/server/background/tasks/process_prometheus_metrics.py +14 -2
  17. dstack/_internal/server/background/tasks/process_running_jobs.py +129 -124
  18. dstack/_internal/server/background/tasks/process_runs.py +63 -20
  19. dstack/_internal/server/background/tasks/process_submitted_jobs.py +12 -10
  20. dstack/_internal/server/background/tasks/process_terminating_jobs.py +12 -4
  21. dstack/_internal/server/background/tasks/process_volumes.py +4 -1
  22. dstack/_internal/server/migrations/versions/50dd7ea98639_index_status_columns.py +55 -0
  23. dstack/_internal/server/migrations/versions/ec02a26a256c_add_runmodel_next_triggered_at.py +38 -0
  24. dstack/_internal/server/models.py +16 -16
  25. dstack/_internal/server/schemas/logs.py +1 -9
  26. dstack/_internal/server/services/fleets.py +19 -10
  27. dstack/_internal/server/services/gateways/__init__.py +17 -17
  28. dstack/_internal/server/services/instances.py +10 -14
  29. dstack/_internal/server/services/jobs/__init__.py +10 -12
  30. dstack/_internal/server/services/logs/aws.py +45 -3
  31. dstack/_internal/server/services/logs/filelog.py +121 -11
  32. dstack/_internal/server/services/offers.py +3 -3
  33. dstack/_internal/server/services/projects.py +35 -15
  34. dstack/_internal/server/services/prometheus/client_metrics.py +3 -0
  35. dstack/_internal/server/services/prometheus/custom_metrics.py +22 -3
  36. dstack/_internal/server/services/runs.py +74 -34
  37. dstack/_internal/server/services/services/__init__.py +4 -1
  38. dstack/_internal/server/services/users.py +2 -3
  39. dstack/_internal/server/services/volumes.py +11 -11
  40. dstack/_internal/server/settings.py +3 -0
  41. dstack/_internal/server/statics/index.html +1 -1
  42. dstack/_internal/server/statics/{main-64f8273740c4b52c18f5.js → main-39a767528976f8078166.js} +7 -26
  43. dstack/_internal/server/statics/{main-64f8273740c4b52c18f5.js.map → main-39a767528976f8078166.js.map} +1 -1
  44. dstack/_internal/server/statics/{main-d58fc0460cb0eae7cb5c.css → main-8f9ee218d3eb45989682.css} +2 -2
  45. dstack/_internal/server/testing/common.py +7 -0
  46. dstack/_internal/server/utils/sentry_utils.py +12 -0
  47. dstack/_internal/utils/common.py +10 -21
  48. dstack/_internal/utils/cron.py +5 -0
  49. dstack/version.py +1 -1
  50. {dstack-0.19.19.dist-info → dstack-0.19.21.dist-info}/METADATA +2 -11
  51. {dstack-0.19.19.dist-info → dstack-0.19.21.dist-info}/RECORD +54 -49
  52. {dstack-0.19.19.dist-info → dstack-0.19.21.dist-info}/WHEEL +0 -0
  53. {dstack-0.19.19.dist-info → dstack-0.19.21.dist-info}/entry_points.txt +0 -0
  54. {dstack-0.19.19.dist-info → dstack-0.19.21.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/server/background/tasks/process_instances.py

@@ -9,13 +9,9 @@ from paramiko.ssh_exception import PasswordRequiredException
 from pydantic import ValidationError
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload, lazyload
+from sqlalchemy.orm import joinedload
 
 from dstack._internal import settings
-from dstack._internal.core.backends import (
-    BACKENDS_WITH_CREATE_INSTANCE_SUPPORT,
-    BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT,
-)
 from dstack._internal.core.backends.base.compute import (
     ComputeWithCreateInstanceSupport,
     ComputeWithPlacementGroupSupport,
@@ -27,6 +23,10 @@ from dstack._internal.core.backends.base.compute import (
     get_shim_env,
     get_shim_pre_start_commands,
 )
+from dstack._internal.core.backends.features import (
+    BACKENDS_WITH_CREATE_INSTANCE_SUPPORT,
+    BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT,
+)
 from dstack._internal.core.backends.remote.provisioning import (
     detect_cpu_arch,
     get_host_info,
@@ -78,6 +78,7 @@ from dstack._internal.server.db import get_db, get_session_ctx
 from dstack._internal.server.models import (
     FleetModel,
     InstanceModel,
+    JobModel,
     PlacementGroupModel,
     ProjectModel,
 )
@@ -104,7 +105,12 @@ from dstack._internal.server.services.placement import (
 from dstack._internal.server.services.runner import client as runner_client
 from dstack._internal.server.services.runner.client import HealthStatus
 from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
-from dstack._internal.utils.common import get_current_datetime, run_async
+from dstack._internal.server.utils import sentry_utils
+from dstack._internal.utils.common import (
+    get_current_datetime,
+    get_or_error,
+    run_async,
+)
 from dstack._internal.utils.logging import get_logger
 from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses
 from dstack._internal.utils.ssh import (
@@ -131,6 +137,7 @@ async def process_instances(batch_size: int = 1):
     await asyncio.gather(*tasks)
 
 
+@sentry_utils.instrument_background_task
 async def _process_next_instance():
     lock, lockset = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__)
     async with get_session_ctx() as session:
@@ -149,12 +156,13 @@ async def _process_next_instance():
                 ),
                 InstanceModel.id.not_in(lockset),
                 InstanceModel.last_processed_at
-                < get_current_datetime().replace(tzinfo=None) - MIN_PROCESSING_INTERVAL,
+                < get_current_datetime() - MIN_PROCESSING_INTERVAL,
             )
-            .options(lazyload(InstanceModel.jobs))
+            .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
+            .options(joinedload(InstanceModel.project).load_only(ProjectModel.ssh_private_key))
            .order_by(InstanceModel.last_processed_at.asc())
            .limit(1)
-            .with_for_update(skip_locked=True, key_share=True)
+            .with_for_update(skip_locked=True, key_share=True, of=InstanceModel)
        )
        instance = res.scalar()
        if instance is None:
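The locking query above is the most interesting change in this file: 0.19.19 used lazyload and a plain FOR UPDATE, while 0.19.21 eager-loads the jobs and project with joinedload and passes of=InstanceModel. In PostgreSQL, FOR UPDATE cannot be applied to the nullable side of the LEFT OUTER JOIN that joinedload() emits, so scoping the lock with OF is what makes the combination legal. A minimal standalone sketch of the same pattern (the Parent/Child models are illustrative, not from dstack):

from __future__ import annotations

import datetime
import uuid

from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    joinedload,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class Parent(Base):
    __tablename__ = "parents"
    id: Mapped[uuid.UUID] = mapped_column(primary_key=True)
    updated_at: Mapped[datetime.datetime]
    children: Mapped[list["Child"]] = relationship(back_populates="parent")


class Child(Base):
    __tablename__ = "children"
    id: Mapped[uuid.UUID] = mapped_column(primary_key=True)
    status: Mapped[str]
    parent_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("parents.id"))
    parent: Mapped[Parent] = relationship(back_populates="children")


# FOR UPDATE cannot lock the nullable side of the LEFT OUTER JOIN emitted by
# joinedload(), so of=Parent scopes the row lock to the parents table only.
# skip_locked lets concurrent workers pick different rows instead of blocking;
# key_share emits FOR NO KEY UPDATE on PostgreSQL, which still permits
# foreign-key references to the locked row from other tables.
stmt = (
    select(Parent)
    .options(joinedload(Parent.children).load_only(Child.id, Child.status))
    .order_by(Parent.updated_at.asc())
    .limit(1)
    .with_for_update(skip_locked=True, key_share=True, of=Parent)
)

This also explains why the old refetch comment ("joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE") disappears in the next hunk: with the lock scoped via OF, the eager loads can live directly in the locking query.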
dstack/_internal/server/background/tasks/process_instances.py

@@ -168,23 +176,22 @@ async def _process_next_instance():
 
 
 async def _process_instance(session: AsyncSession, instance: InstanceModel):
-    # Refetch to load related attributes.
-    # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
-    res = await session.execute(
-        select(InstanceModel)
-        .where(InstanceModel.id == instance.id)
-        .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
-        .options(joinedload(InstanceModel.jobs))
-        .options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
-        .execution_options(populate_existing=True)
-    )
-    instance = res.unique().scalar_one()
-    if (
-        instance.status == InstanceStatus.IDLE
-        and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
-        and not instance.jobs
+    if instance.status in (
+        InstanceStatus.PENDING,
+        InstanceStatus.TERMINATING,
     ):
-        await _mark_terminating_if_idle_duration_expired(instance)
+        # Refetch to load related attributes.
+        # Load related attributes only for statuses that always need them.
+        res = await session.execute(
+            select(InstanceModel)
+            .where(InstanceModel.id == instance.id)
+            .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
+            .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
+            .options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
+            .execution_options(populate_existing=True)
+        )
+        instance = res.unique().scalar_one()
+
     if instance.status == InstanceStatus.PENDING:
         if instance.remote_connection_info is not None:
             await _add_remote(instance)
@@ -198,7 +205,9 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
         InstanceStatus.IDLE,
         InstanceStatus.BUSY,
     ):
-        await _check_instance(instance)
+        idle_duration_expired = _check_and_mark_terminating_if_idle_duration_expired(instance)
+        if not idle_duration_expired:
+            await _check_instance(session, instance)
     elif instance.status == InstanceStatus.TERMINATING:
         await _terminate(instance)
 
@@ -206,7 +215,13 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
     await session.commit()
 
 
-async def _mark_terminating_if_idle_duration_expired(instance: InstanceModel):
+def _check_and_mark_terminating_if_idle_duration_expired(instance: InstanceModel):
+    if not (
+        instance.status == InstanceStatus.IDLE
+        and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
+        and not instance.jobs
+    ):
+        return False
     idle_duration = _get_instance_idle_duration(instance)
     idle_seconds = instance.termination_idle_time
     delta = datetime.timedelta(seconds=idle_seconds)
@@ -222,6 +237,8 @@ async def _mark_terminating_if_idle_duration_expired(instance: InstanceModel):
                 "instance_status": instance.status.value,
             },
         )
+        return True
+    return False
 
 
 async def _add_remote(instance: InstanceModel) -> None:
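The status-gated refetch above leans on populate_existing=True: the instance handed to _process_instance is already in the session's identity map from the locking query, and without that option a second SELECT would hand back the cached object unchanged. Condensed from the hunk, as a reminder of the idiom (the surrounding session and instance come from the task's context):

from sqlalchemy import select
from sqlalchemy.orm import joinedload

from dstack._internal.server.models import FleetModel, InstanceModel

# "instance" is already identity-mapped from the earlier FOR UPDATE query.
# populate_existing=True makes the new SELECT overwrite its state in place,
# so the eagerly loaded relationships land on the same object that is
# holding the row lock.
res = await session.execute(
    select(InstanceModel)
    .where(InstanceModel.id == instance.id)
    .options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
    .execution_options(populate_existing=True)
)
instance = res.unique().scalar_one()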
dstack/_internal/server/background/tasks/process_instances.py

@@ -461,7 +478,7 @@ def _deploy_instance(
 
 
 async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
     if instance.last_retry_at is not None:
-        last_retry = instance.last_retry_at.replace(tzinfo=datetime.timezone.utc)
+        last_retry = instance.last_retry_at
         if get_current_datetime() < last_retry + timedelta(minutes=1):
             return
 
@@ -700,7 +717,7 @@ def _mark_terminated(instance: InstanceModel, termination_reason: str) -> None:
     )
 
 
-async def _check_instance(instance: InstanceModel) -> None:
+async def _check_instance(session: AsyncSession, instance: InstanceModel) -> None:
     if (
         instance.status == InstanceStatus.BUSY
         and instance.jobs
@@ -719,12 +736,16 @@ async def _check_instance(instance: InstanceModel) -> None:
         )
         return
 
-    job_provisioning_data = JobProvisioningData.__response__.parse_raw(
-        instance.job_provisioning_data
-    )
+    job_provisioning_data = get_or_error(get_instance_provisioning_data(instance))
     if job_provisioning_data.hostname is None:
+        res = await session.execute(
+            select(ProjectModel)
+            .where(ProjectModel.id == instance.project_id)
+            .options(joinedload(ProjectModel.backends))
+        )
+        project = res.unique().scalar_one()
         await _wait_for_instance_provisioning_data(
-            project=instance.project,
+            project=project,
             instance=instance,
             job_provisioning_data=job_provisioning_data,
         )
@@ -801,7 +822,7 @@ async def _check_instance(instance: InstanceModel) -> None:
             instance.name,
             extra={"instance_name": instance.name},
         )
-        deadline = instance.termination_deadline.replace(tzinfo=datetime.timezone.utc)
+        deadline = instance.termination_deadline
         if get_current_datetime() > deadline:
             instance.status = InstanceStatus.TERMINATING
             instance.termination_reason = "Termination deadline"
@@ -956,18 +977,12 @@ async def _terminate(instance: InstanceModel) -> None:
 
 def _next_termination_retry_at(instance: InstanceModel) -> datetime.datetime:
     assert instance.last_termination_retry_at is not None
-    return (
-        instance.last_termination_retry_at.replace(tzinfo=datetime.timezone.utc)
-        + TERMINATION_RETRY_TIMEOUT
-    )
+    return instance.last_termination_retry_at + TERMINATION_RETRY_TIMEOUT
 
 
 def _get_termination_deadline(instance: InstanceModel) -> datetime.datetime:
     assert instance.first_termination_retry_at is not None
-    return (
-        instance.first_termination_retry_at.replace(tzinfo=datetime.timezone.utc)
-        + TERMINATION_RETRY_MAX_DURATION
-    )
+    return instance.first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION
 
 
 def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
@@ -1102,27 +1117,26 @@ async def _create_placement_group(
 
 
 def _get_instance_idle_duration(instance: InstanceModel) -> datetime.timedelta:
-    last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)
+    last_time = instance.created_at
     if instance.last_job_processed_at is not None:
-        last_time = instance.last_job_processed_at.replace(tzinfo=datetime.timezone.utc)
+        last_time = instance.last_job_processed_at
     return get_current_datetime() - last_time
 
 
 def _get_retry_duration_deadline(instance: InstanceModel, retry: Retry) -> datetime.datetime:
-    return instance.created_at.replace(tzinfo=datetime.timezone.utc) + timedelta(
-        seconds=retry.duration
-    )
+    return instance.created_at + timedelta(seconds=retry.duration)
 
 
 def _get_provisioning_deadline(
     instance: InstanceModel,
     job_provisioning_data: JobProvisioningData,
 ) -> datetime.datetime:
+    assert instance.started_at is not None
     timeout_interval = get_provisioning_timeout(
         backend_type=job_provisioning_data.get_base_backend(),
         instance_type_name=job_provisioning_data.instance_type.name,
     )
-    return instance.started_at.replace(tzinfo=datetime.timezone.utc) + timeout_interval
+    return instance.started_at + timeout_interval
 
 
 def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]:
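A recurring pattern in the hunks above is the deletion of .replace(tzinfo=datetime.timezone.utc) patching. The column definitions are not shown here, but models.py (+16 -16) and utils/common.py (+10 -21) change elsewhere in this diff, which is consistent with the datetime columns and get_current_datetime now being timezone-aware end to end, so comparisons and arithmetic work without re-tagging. A sketch of the invariant, with an assumed aware value standing in for a column:

import datetime


def get_current_datetime() -> datetime.datetime:
    # Aware "now"; comparing it against a naive datetime raises TypeError,
    # which is why naive column values previously had to be patched with
    # .replace(tzinfo=...) before every comparison.
    return datetime.datetime.now(datetime.timezone.utc)


# Stand-in for a timezone-aware column value such as instance.created_at.
created_at = datetime.datetime(2025, 1, 1, tzinfo=datetime.timezone.utc)

idle_for = get_current_datetime() - created_at  # no .replace(tzinfo=...) needed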
dstack/_internal/server/background/tasks/process_metrics.py

@@ -9,12 +9,13 @@ from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT
 from dstack._internal.core.models.runs import JobStatus
 from dstack._internal.server import settings
 from dstack._internal.server.db import get_session_ctx
-from dstack._internal.server.models import InstanceModel, JobMetricsPoint, JobModel
+from dstack._internal.server.models import InstanceModel, JobMetricsPoint, JobModel, ProjectModel
 from dstack._internal.server.schemas.runner import MetricsResponse
 from dstack._internal.server.services.instances import get_instance_ssh_private_keys
 from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data
 from dstack._internal.server.services.runner import client
 from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
+from dstack._internal.server.utils import sentry_utils
 from dstack._internal.utils.common import batched, get_current_datetime, get_or_error, run_async
 from dstack._internal.utils.logging import get_logger
 
@@ -26,12 +27,17 @@ BATCH_SIZE = 10
 MIN_COLLECT_INTERVAL_SECONDS = 9
 
 
+@sentry_utils.instrument_background_task
 async def collect_metrics():
     async with get_session_ctx() as session:
         res = await session.execute(
             select(JobModel)
             .where(JobModel.status.in_([JobStatus.RUNNING]))
-            .options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
+            .options(
+                joinedload(JobModel.instance)
+                .joinedload(InstanceModel.project)
+                .load_only(ProjectModel.ssh_private_key)
+            )
             .order_by(JobModel.last_processed_at.asc())
             .limit(MAX_JOBS_FETCHED)
         )
@@ -41,6 +47,7 @@ async def collect_metrics():
         await _collect_jobs_metrics(batch)
 
 
+@sentry_utils.instrument_background_task
 async def delete_metrics():
     now_timestamp_micro = int(get_current_datetime().timestamp() * 1_000_000)
     running_timestamp_micro_cutoff = (
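Every background task entry point in this release gains @sentry_utils.instrument_background_task. The new dstack/_internal/server/utils/sentry_utils.py adds 12 lines but its body is not included in this diff, so the following is only a plausible shape for such a decorator, not the actual implementation:

import functools

import sentry_sdk


def instrument_background_task(func):
    # Hypothetical sketch: trace each periodic run as its own Sentry
    # transaction, named after the wrapped coroutine, so failures and
    # latency of background tasks show up individually in Sentry.
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        with sentry_sdk.start_transaction(op="background_task", name=func.__qualname__):
            return await func(*args, **kwargs)

    return wrapper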
dstack/_internal/server/background/tasks/process_placement_groups.py

@@ -12,12 +12,14 @@ from dstack._internal.server.models import PlacementGroupModel, ProjectModel
 from dstack._internal.server.services import backends as backends_services
 from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.placement import placement_group_model_to_placement_group
+from dstack._internal.server.utils import sentry_utils
 from dstack._internal.utils.common import get_current_datetime, run_async
 from dstack._internal.utils.logging import get_logger
 
 logger = get_logger(__name__)
 
 
+@sentry_utils.instrument_background_task
 async def process_placement_groups():
     lock, lockset = get_locker(get_db().dialect_name).get_lockset(
         PlacementGroupModel.__tablename__
dstack/_internal/server/background/tasks/process_prometheus_metrics.py

@@ -9,11 +9,17 @@ from sqlalchemy.orm import joinedload
 from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT
 from dstack._internal.core.models.runs import JobStatus
 from dstack._internal.server.db import get_session_ctx
-from dstack._internal.server.models import InstanceModel, JobModel, JobPrometheusMetrics
+from dstack._internal.server.models import (
+    InstanceModel,
+    JobModel,
+    JobPrometheusMetrics,
+    ProjectModel,
+)
 from dstack._internal.server.services.instances import get_instance_ssh_private_keys
 from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data
 from dstack._internal.server.services.runner import client
 from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
+from dstack._internal.server.utils import sentry_utils
 from dstack._internal.server.utils.common import gather_map_async
 from dstack._internal.utils.common import batched, get_current_datetime, get_or_error, run_async
 from dstack._internal.utils.logging import get_logger
@@ -29,6 +35,7 @@ MIN_COLLECT_INTERVAL_SECONDS = 9
 METRICS_TTL_SECONDS = 600
 
 
+@sentry_utils.instrument_background_task
 async def collect_prometheus_metrics():
     now = get_current_datetime()
     cutoff = now - timedelta(seconds=MIN_COLLECT_INTERVAL_SECONDS)
@@ -43,7 +50,11 @@ async def collect_prometheus_metrics():
                 JobPrometheusMetrics.collected_at < cutoff,
             ),
         )
-        .options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
+        .options(
+            joinedload(JobModel.instance)
+            .joinedload(InstanceModel.project)
+            .load_only(ProjectModel.ssh_private_key)
+        )
         .order_by(JobModel.last_processed_at.asc())
         .limit(MAX_JOBS_FETCHED)
     )
@@ -52,6 +63,7 @@ async def collect_prometheus_metrics():
         await _collect_jobs_metrics(batch, now)
 
 
+@sentry_utils.instrument_background_task
 async def delete_prometheus_metrics():
     now = get_current_datetime()
     cutoff = now - timedelta(seconds=METRICS_TTL_SECONDS)
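Both metrics collectors narrow their eager loads the same way: load_only() chained onto a joinedload() path applies to the last entity in the chain, here ProjectModel, so only ssh_private_key (plus primary-key columns, which are always included) comes back with the join. One caveat: with an AsyncSession, later touching a column excluded by load_only() triggers a lazy load, which an async session cannot perform implicitly, so the narrowed set must cover everything the task reads. Restated from the hunk above:

from sqlalchemy import select
from sqlalchemy.orm import joinedload

from dstack._internal.server.models import InstanceModel, JobModel, ProjectModel

# load_only() scopes to the entity of the loader it is chained from, i.e.
# ProjectModel; JobModel and InstanceModel columns are unaffected unless
# they get their own load_only().
stmt = select(JobModel).options(
    joinedload(JobModel.instance)
    .joinedload(InstanceModel.project)
    .load_only(ProjectModel.ssh_private_key)
)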
dstack/_internal/server/background/tasks/process_running_jobs.py

@@ -2,12 +2,12 @@ import asyncio
 import re
 import uuid
 from collections.abc import Iterable
-from datetime import timedelta, timezone
+from datetime import timedelta
 from typing import Dict, List, Optional
 
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import joinedload, load_only
 
 from dstack._internal import settings
 from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT
@@ -73,6 +73,7 @@ from dstack._internal.server.services.runs import (
 )
 from dstack._internal.server.services.secrets import get_project_secrets_mapping
 from dstack._internal.server.services.storage import get_default_storage
+from dstack._internal.server.utils import sentry_utils
 from dstack._internal.utils import common as common_utils
 from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
 from dstack._internal.utils.logging import get_logger
@@ -94,6 +95,7 @@ async def process_running_jobs(batch_size: int = 1):
     await asyncio.gather(*tasks)
 
 
+@sentry_utils.instrument_background_task
 async def _process_next_running_job():
     lock, lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__)
     async with get_session_ctx() as session:
@@ -108,9 +110,9 @@ async def _process_next_running_job():
                 RunModel.status.not_in([RunStatus.TERMINATING]),
                 JobModel.id.not_in(lockset),
                 JobModel.last_processed_at
-                < common_utils.get_current_datetime().replace(tzinfo=None)
-                - MIN_PROCESSING_INTERVAL,
+                < common_utils.get_current_datetime() - MIN_PROCESSING_INTERVAL,
             )
+            .options(load_only(JobModel.id))
            .order_by(JobModel.last_processed_at.asc())
            .limit(1)
            .with_for_update(
@@ -133,7 +135,6 @@ async def _process_next_running_job():
 
 async def _process_running_job(session: AsyncSession, job_model: JobModel):
     # Refetch to load related attributes.
-    # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
     res = await session.execute(
         select(JobModel)
         .where(JobModel.id == job_model.id)
@@ -144,7 +145,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
     res = await session.execute(
         select(RunModel)
         .where(RunModel.id == job_model.run_id)
-        .options(joinedload(RunModel.project).joinedload(ProjectModel.backends))
+        .options(joinedload(RunModel.project))
         .options(joinedload(RunModel.user))
         .options(joinedload(RunModel.repo))
         .options(joinedload(RunModel.jobs))
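Here the lock-acquisition SELECT is trimmed to the one column it actually needs: _process_running_job immediately refetches the full row with its relationships, so the first query only has to pick and lock an id. A condensed view of the two-step pattern (the exact with_for_update flags are elided in the hunk above, so they are shown here as assumed):

# Step 1: pick one due row and lock it, fetching only the primary key.
res = await session.execute(
    select(JobModel)
    .where(JobModel.id.not_in(lockset))
    .options(load_only(JobModel.id))
    .order_by(JobModel.last_processed_at.asc())
    .limit(1)
    .with_for_update(skip_locked=True, key_share=True)
)
job_model = res.scalar()
assert job_model is not None  # sketch only; the real code returns if None

# Step 2: refetch the same row with relationships for actual processing.
res = await session.execute(
    select(JobModel)
    .where(JobModel.id == job_model.id)
    .options(joinedload(JobModel.instance))
    .execution_options(populate_existing=True)
)
job_model = res.unique().scalar_one()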
dstack/_internal/server/background/tasks/process_running_jobs.py

@@ -160,143 +161,147 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
         job_model.status = JobStatus.TERMINATING
         job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
         job_model.last_processed_at = common_utils.get_current_datetime()
+        await session.commit()
         return
 
     job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
 
-    # Wait until all other jobs in the replica are provisioned
-    for other_job in run.jobs:
-        if (
-            other_job.job_spec.replica_num == job.job_spec.replica_num
-            and other_job.job_submissions[-1].status == JobStatus.SUBMITTED
-        ):
+    initial_status = job_model.status
+    if initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]:
+        # Wait until all other jobs in the replica are provisioned
+        for other_job in run.jobs:
+            if (
+                other_job.job_spec.replica_num == job.job_spec.replica_num
+                and other_job.job_submissions[-1].status == JobStatus.SUBMITTED
+            ):
+                job_model.last_processed_at = common_utils.get_current_datetime()
+                await session.commit()
+                return
+
+        cluster_info = _get_cluster_info(
+            jobs=run.jobs,
+            replica_num=job.job_spec.replica_num,
+            job_provisioning_data=job_provisioning_data,
+            job_runtime_data=job_submission.job_runtime_data,
+        )
+
+        volumes = await get_job_attached_volumes(
+            session=session,
+            project=project,
+            run_spec=run.run_spec,
+            job_num=job.job_spec.job_num,
+            job_provisioning_data=job_provisioning_data,
+        )
+
+        repo_creds_model = await get_repo_creds(
+            session=session, repo=repo_model, user=run_model.user
+        )
+        repo_creds = repo_model_to_repo_head_with_creds(repo_model, repo_creds_model).repo_creds
+
+        secrets = await get_project_secrets_mapping(session=session, project=project)
+        try:
+            _interpolate_secrets(secrets, job.job_spec)
+        except InterpolatorError as e:
+            logger.info("%s: terminating due to secrets interpolation error", fmt(job_model))
+            job_model.status = JobStatus.TERMINATING
+            job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
+            job_model.termination_reason_message = e.args[0]
             job_model.last_processed_at = common_utils.get_current_datetime()
             await session.commit()
             return
 
-    cluster_info = _get_cluster_info(
-        jobs=run.jobs,
-        replica_num=job.job_spec.replica_num,
-        job_provisioning_data=job_provisioning_data,
-        job_runtime_data=job_submission.job_runtime_data,
-    )
-
-    volumes = await get_job_attached_volumes(
-        session=session,
-        project=project,
-        run_spec=run.run_spec,
-        job_num=job.job_spec.job_num,
-        job_provisioning_data=job_provisioning_data,
-    )
-
     server_ssh_private_keys = get_instance_ssh_private_keys(
         common_utils.get_or_error(job_model.instance)
     )
 
-    secrets = await get_project_secrets_mapping(session=session, project=project)
-
-    try:
-        _interpolate_secrets(secrets, job.job_spec)
-    except InterpolatorError as e:
-        logger.info("%s: terminating due to secrets interpolation error", fmt(job_model))
-        job_model.status = JobStatus.TERMINATING
-        job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
-        job_model.termination_reason_message = e.args[0]
-        job_model.last_processed_at = common_utils.get_current_datetime()
-        return
-
-    repo_creds_model = await get_repo_creds(session=session, repo=repo_model, user=run_model.user)
-    repo_creds = repo_model_to_repo_head_with_creds(repo_model, repo_creds_model).repo_creds
-
-    initial_status = job_model.status
     if initial_status == JobStatus.PROVISIONING:
         if job_provisioning_data.hostname is None:
             await _wait_for_instance_provisioning_data(job_model=job_model)
+            job_model.last_processed_at = common_utils.get_current_datetime()
+            await session.commit()
+            return
+        if _should_wait_for_other_nodes(run, job, job_model):
+            job_model.last_processed_at = common_utils.get_current_datetime()
+            await session.commit()
+            return
+
+        # fails are acceptable until timeout is exceeded
+        if job_provisioning_data.dockerized:
+            logger.debug(
+                "%s: process provisioning job with shim, age=%s",
+                fmt(job_model),
+                job_submission.age,
+            )
+            ssh_user = job_provisioning_data.username
+            user_ssh_key = run.run_spec.ssh_key_pub.strip()
+            public_keys = [project.ssh_public_key.strip(), user_ssh_key]
+            if job_provisioning_data.backend == BackendType.LOCAL:
+                # No need to update ~/.ssh/authorized_keys when running shim locally
+                user_ssh_key = ""
+            success = await common_utils.run_async(
+                _process_provisioning_with_shim,
+                server_ssh_private_keys,
+                job_provisioning_data,
+                None,
+                run,
+                job_model,
+                job_provisioning_data,
+                volumes,
+                job.job_spec.registry_auth,
+                public_keys,
+                ssh_user,
+                user_ssh_key,
+            )
         else:
-            if _should_wait_for_other_nodes(run, job, job_model):
-                job_model.last_processed_at = common_utils.get_current_datetime()
-                await session.commit()
-                return
+            logger.debug(
+                "%s: process provisioning job without shim, age=%s",
+                fmt(job_model),
+                job_submission.age,
+            )
+            # FIXME: downloading file archives and code here is a waste of time if
+            # the runner is not ready yet
+            file_archives = await _get_job_file_archives(
+                session=session,
+                archive_mappings=job.job_spec.file_archives,
+                user=run_model.user,
+            )
+            code = await _get_job_code(
+                session=session,
+                project=project,
+                repo=repo_model,
+                code_hash=_get_repo_code_hash(run, job),
+            )
 
-            # fails are acceptable until timeout is exceeded
-            if job_provisioning_data.dockerized:
-                logger.debug(
-                    "%s: process provisioning job with shim, age=%s",
-                    fmt(job_model),
-                    job_submission.age,
-                )
-                ssh_user = job_provisioning_data.username
-                user_ssh_key = run.run_spec.ssh_key_pub.strip()
-                public_keys = [project.ssh_public_key.strip(), user_ssh_key]
-                if job_provisioning_data.backend == BackendType.LOCAL:
-                    # No need to update ~/.ssh/authorized_keys when running shim locally
-                    user_ssh_key = ""
-                success = await common_utils.run_async(
-                    _process_provisioning_with_shim,
-                    server_ssh_private_keys,
-                    job_provisioning_data,
-                    None,
-                    run,
-                    job_model,
-                    job_provisioning_data,
-                    volumes,
-                    job.job_spec.registry_auth,
-                    public_keys,
-                    ssh_user,
-                    user_ssh_key,
-                )
-            else:
-                logger.debug(
-                    "%s: process provisioning job without shim, age=%s",
+            success = await common_utils.run_async(
+                _submit_job_to_runner,
+                server_ssh_private_keys,
+                job_provisioning_data,
+                None,
+                run,
+                job_model,
+                job,
+                cluster_info,
+                code,
+                file_archives,
+                secrets,
+                repo_creds,
+                success_if_not_available=False,
+            )
+
+        if not success:
+            # check timeout
+            if job_submission.age > get_provisioning_timeout(
+                backend_type=job_provisioning_data.get_base_backend(),
+                instance_type_name=job_provisioning_data.instance_type.name,
+            ):
+                logger.warning(
+                    "%s: failed because runner has not become available in time, age=%s",
                     fmt(job_model),
                     job_submission.age,
                 )
-                # FIXME: downloading file archives and code here is a waste of time if
-                # the runner is not ready yet
-                file_archives = await _get_job_file_archives(
-                    session=session,
-                    archive_mappings=job.job_spec.file_archives,
-                    user=run_model.user,
-                )
-                code = await _get_job_code(
-                    session=session,
-                    project=project,
-                    repo=repo_model,
-                    code_hash=_get_repo_code_hash(run, job),
-                )
-
-                success = await common_utils.run_async(
-                    _submit_job_to_runner,
-                    server_ssh_private_keys,
-                    job_provisioning_data,
-                    None,
-                    run,
-                    job_model,
-                    job,
-                    cluster_info,
-                    code,
-                    file_archives,
-                    secrets,
-                    repo_creds,
-                    success_if_not_available=False,
-                )
-
-                if not success:
-                    # check timeout
-                    if job_submission.age > get_provisioning_timeout(
-                        backend_type=job_provisioning_data.get_base_backend(),
-                        instance_type_name=job_provisioning_data.instance_type.name,
-                    ):
-                        logger.warning(
-                            "%s: failed because runner has not become available in time, age=%s",
-                            fmt(job_model),
-                            job_submission.age,
-                        )
-                        job_model.status = JobStatus.TERMINATING
-                        job_model.termination_reason = (
-                            JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED
-                        )
-                        # instance will be emptied by process_terminating_jobs
+                job_model.status = JobStatus.TERMINATING
+                job_model.termination_reason = JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED
+                # instance will be emptied by process_terminating_jobs
 
     else:  # fails are not acceptable
         if initial_status == JobStatus.PULLING:
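Beyond the restructuring, the dominant behavioral change in this function is that every early return is now preceded by await session.commit(). In 0.19.19 at least one early exit (the secrets interpolation error path) set last_processed_at and returned without committing, leaving the update and the row lock pending until a later commit. Distilled, the new pattern looks like this (should_wait and do_work are illustrative stand-ins, not dstack functions):

async def _process_one(session: AsyncSession, job_model: JobModel) -> None:
    if should_wait(job_model):
        # Persist progress before bailing out, so the scheduler does not
        # immediately re-select this row, and release the FOR UPDATE lock.
        job_model.last_processed_at = common_utils.get_current_datetime()
        await session.commit()
        return
    do_work(job_model)
    job_model.last_processed_at = common_utils.get_current_datetime()
    await session.commit()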
dstack/_internal/server/background/tasks/process_running_jobs.py

@@ -801,7 +806,7 @@ def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool:
         return False
     return (
         common_utils.get_current_datetime()
-        > job_model.disconnected_at.replace(tzinfo=timezone.utc) + JOB_DISCONNECTED_RETRY_TIMEOUT
+        > job_model.disconnected_at + JOB_DISCONNECTED_RETRY_TIMEOUT
     )