dstack 0.19.20__py3-none-any.whl → 0.19.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic.

Files changed (93)
  1. dstack/_internal/cli/commands/apply.py +8 -3
  2. dstack/_internal/cli/services/configurators/__init__.py +8 -0
  3. dstack/_internal/cli/services/configurators/fleet.py +1 -1
  4. dstack/_internal/cli/services/configurators/gateway.py +1 -1
  5. dstack/_internal/cli/services/configurators/run.py +11 -1
  6. dstack/_internal/cli/services/configurators/volume.py +1 -1
  7. dstack/_internal/cli/utils/common.py +48 -5
  8. dstack/_internal/cli/utils/fleet.py +5 -5
  9. dstack/_internal/cli/utils/run.py +32 -0
  10. dstack/_internal/core/backends/__init__.py +0 -65
  11. dstack/_internal/core/backends/configurators.py +9 -0
  12. dstack/_internal/core/backends/features.py +64 -0
  13. dstack/_internal/core/backends/hotaisle/__init__.py +1 -0
  14. dstack/_internal/core/backends/hotaisle/api_client.py +109 -0
  15. dstack/_internal/core/backends/hotaisle/backend.py +16 -0
  16. dstack/_internal/core/backends/hotaisle/compute.py +225 -0
  17. dstack/_internal/core/backends/hotaisle/configurator.py +60 -0
  18. dstack/_internal/core/backends/hotaisle/models.py +45 -0
  19. dstack/_internal/core/backends/lambdalabs/compute.py +2 -1
  20. dstack/_internal/core/backends/models.py +8 -0
  21. dstack/_internal/core/compatibility/fleets.py +2 -0
  22. dstack/_internal/core/compatibility/runs.py +12 -0
  23. dstack/_internal/core/models/backends/base.py +2 -0
  24. dstack/_internal/core/models/configurations.py +139 -1
  25. dstack/_internal/core/models/health.py +28 -0
  26. dstack/_internal/core/models/instances.py +2 -0
  27. dstack/_internal/core/models/logs.py +2 -1
  28. dstack/_internal/core/models/profiles.py +37 -0
  29. dstack/_internal/core/models/runs.py +21 -1
  30. dstack/_internal/core/services/ssh/tunnel.py +7 -0
  31. dstack/_internal/server/app.py +26 -10
  32. dstack/_internal/server/background/__init__.py +9 -6
  33. dstack/_internal/server/background/tasks/process_fleets.py +52 -38
  34. dstack/_internal/server/background/tasks/process_gateways.py +2 -2
  35. dstack/_internal/server/background/tasks/process_idle_volumes.py +5 -4
  36. dstack/_internal/server/background/tasks/process_instances.py +168 -103
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -2
  38. dstack/_internal/server/background/tasks/process_placement_groups.py +2 -0
  39. dstack/_internal/server/background/tasks/process_probes.py +164 -0
  40. dstack/_internal/server/background/tasks/process_prometheus_metrics.py +14 -2
  41. dstack/_internal/server/background/tasks/process_running_jobs.py +142 -124
  42. dstack/_internal/server/background/tasks/process_runs.py +84 -34
  43. dstack/_internal/server/background/tasks/process_submitted_jobs.py +12 -10
  44. dstack/_internal/server/background/tasks/process_terminating_jobs.py +12 -4
  45. dstack/_internal/server/background/tasks/process_volumes.py +4 -1
  46. dstack/_internal/server/migrations/versions/25479f540245_add_probes.py +43 -0
  47. dstack/_internal/server/migrations/versions/50dd7ea98639_index_status_columns.py +55 -0
  48. dstack/_internal/server/migrations/versions/728b1488b1b4_add_instance_health.py +50 -0
  49. dstack/_internal/server/migrations/versions/ec02a26a256c_add_runmodel_next_triggered_at.py +38 -0
  50. dstack/_internal/server/models.py +57 -16
  51. dstack/_internal/server/routers/instances.py +33 -5
  52. dstack/_internal/server/schemas/health/dcgm.py +56 -0
  53. dstack/_internal/server/schemas/instances.py +32 -0
  54. dstack/_internal/server/schemas/runner.py +5 -0
  55. dstack/_internal/server/services/fleets.py +19 -10
  56. dstack/_internal/server/services/gateways/__init__.py +17 -17
  57. dstack/_internal/server/services/instances.py +113 -15
  58. dstack/_internal/server/services/jobs/__init__.py +18 -13
  59. dstack/_internal/server/services/jobs/configurators/base.py +26 -0
  60. dstack/_internal/server/services/logging.py +4 -2
  61. dstack/_internal/server/services/logs/aws.py +13 -1
  62. dstack/_internal/server/services/logs/gcp.py +16 -1
  63. dstack/_internal/server/services/offers.py +3 -3
  64. dstack/_internal/server/services/probes.py +6 -0
  65. dstack/_internal/server/services/projects.py +51 -19
  66. dstack/_internal/server/services/prometheus/client_metrics.py +3 -0
  67. dstack/_internal/server/services/prometheus/custom_metrics.py +2 -3
  68. dstack/_internal/server/services/runner/client.py +52 -20
  69. dstack/_internal/server/services/runner/ssh.py +4 -4
  70. dstack/_internal/server/services/runs.py +115 -39
  71. dstack/_internal/server/services/services/__init__.py +4 -1
  72. dstack/_internal/server/services/ssh.py +66 -0
  73. dstack/_internal/server/services/users.py +2 -3
  74. dstack/_internal/server/services/volumes.py +11 -11
  75. dstack/_internal/server/settings.py +16 -0
  76. dstack/_internal/server/statics/index.html +1 -1
  77. dstack/_internal/server/statics/{main-8f9ee218d3eb45989682.css → main-03e818b110e1d5705378.css} +1 -1
  78. dstack/_internal/server/statics/{main-39a767528976f8078166.js → main-cc067b7fd1a8f33f97da.js} +26 -15
  79. dstack/_internal/server/statics/{main-39a767528976f8078166.js.map → main-cc067b7fd1a8f33f97da.js.map} +1 -1
  80. dstack/_internal/server/testing/common.py +51 -0
  81. dstack/_internal/{core/backends/remote → server/utils}/provisioning.py +22 -17
  82. dstack/_internal/server/utils/sentry_utils.py +12 -0
  83. dstack/_internal/settings.py +3 -0
  84. dstack/_internal/utils/common.py +15 -0
  85. dstack/_internal/utils/cron.py +5 -0
  86. dstack/api/server/__init__.py +1 -1
  87. dstack/version.py +1 -1
  88. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/METADATA +13 -22
  89. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/RECORD +93 -75
  90. /dstack/_internal/{core/backends/remote → server/schemas/health}/__init__.py +0 -0
  91. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/WHEEL +0 -0
  92. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/entry_points.txt +0 -0
  93. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/server/background/tasks/process_instances.py

@@ -1,21 +1,18 @@
 import asyncio
 import datetime
+import logging
 from datetime import timedelta
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, cast

 import requests
 from paramiko.pkey import PKey
 from paramiko.ssh_exception import PasswordRequiredException
 from pydantic import ValidationError
-from sqlalchemy import select
+from sqlalchemy import delete, func, select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload, lazyload
+from sqlalchemy.orm import joinedload

 from dstack._internal import settings
-from dstack._internal.core.backends import (
-    BACKENDS_WITH_CREATE_INSTANCE_SUPPORT,
-    BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT,
-)
 from dstack._internal.core.backends.base.compute import (
     ComputeWithCreateInstanceSupport,
     ComputeWithPlacementGroupSupport,
@@ -27,17 +24,9 @@ from dstack._internal.core.backends.base.compute import (
     get_shim_env,
     get_shim_pre_start_commands,
 )
-from dstack._internal.core.backends.remote.provisioning import (
-    detect_cpu_arch,
-    get_host_info,
-    get_paramiko_connection,
-    get_shim_healthcheck,
-    host_info_to_instance_type,
-    remove_dstack_runner_if_exists,
-    remove_host_info_if_exists,
-    run_pre_start_commands,
-    run_shim_as_systemd_service,
-    upload_envs,
+from dstack._internal.core.backends.features import (
+    BACKENDS_WITH_CREATE_INSTANCE_SUPPORT,
+    BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT,
 )
 from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT

@@ -77,11 +66,14 @@ from dstack._internal.server.background.tasks.common import get_provisioning_tim
 from dstack._internal.server.db import get_db, get_session_ctx
 from dstack._internal.server.models import (
     FleetModel,
+    InstanceHealthCheckModel,
     InstanceModel,
+    JobModel,
     PlacementGroupModel,
     ProjectModel,
 )
-from dstack._internal.server.schemas.runner import HealthcheckResponse
+from dstack._internal.server.schemas.instances import InstanceCheck
+from dstack._internal.server.schemas.runner import HealthcheckResponse, InstanceHealthResponse
 from dstack._internal.server.services import backends as backends_services
 from dstack._internal.server.services.fleets import (
     fleet_model_to_fleet,
@@ -102,9 +94,25 @@ from dstack._internal.server.services.placement import (
     schedule_fleet_placement_groups_deletion,
 )
 from dstack._internal.server.services.runner import client as runner_client
-from dstack._internal.server.services.runner.client import HealthStatus
 from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
-from dstack._internal.utils.common import get_current_datetime, run_async
+from dstack._internal.server.utils import sentry_utils
+from dstack._internal.server.utils.provisioning import (
+    detect_cpu_arch,
+    get_host_info,
+    get_paramiko_connection,
+    get_shim_healthcheck,
+    host_info_to_instance_type,
+    remove_dstack_runner_if_exists,
+    remove_host_info_if_exists,
+    run_pre_start_commands,
+    run_shim_as_systemd_service,
+    upload_envs,
+)
+from dstack._internal.utils.common import (
+    get_current_datetime,
+    get_or_error,
+    run_async,
+)
 from dstack._internal.utils.logging import get_logger
 from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses
 from dstack._internal.utils.ssh import (
@@ -131,6 +139,18 @@ async def process_instances(batch_size: int = 1):
     await asyncio.gather(*tasks)


+@sentry_utils.instrument_background_task
+async def delete_instance_health_checks():
+    now = get_current_datetime()
+    cutoff = now - timedelta(seconds=server_settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS)
+    async with get_session_ctx() as session:
+        await session.execute(
+            delete(InstanceHealthCheckModel).where(InstanceHealthCheckModel.collected_at < cutoff)
+        )
+        await session.commit()
+
+
+@sentry_utils.instrument_background_task
 async def _process_next_instance():
     lock, lockset = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__)
     async with get_session_ctx() as session:
@@ -149,12 +169,13 @@ async def _process_next_instance():
                     ),
                     InstanceModel.id.not_in(lockset),
                     InstanceModel.last_processed_at
-                    < get_current_datetime().replace(tzinfo=None) - MIN_PROCESSING_INTERVAL,
+                    < get_current_datetime() - MIN_PROCESSING_INTERVAL,
                 )
-                .options(lazyload(InstanceModel.jobs))
+                .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
+                .options(joinedload(InstanceModel.project).load_only(ProjectModel.ssh_private_key))
                 .order_by(InstanceModel.last_processed_at.asc())
                 .limit(1)
-                .with_for_update(skip_locked=True, key_share=True)
+                .with_for_update(skip_locked=True, key_share=True, of=InstanceModel)
             )
             instance = res.scalar()
             if instance is None:
@@ -168,23 +189,22 @@


 async def _process_instance(session: AsyncSession, instance: InstanceModel):
-    # Refetch to load related attributes.
-    # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
-    res = await session.execute(
-        select(InstanceModel)
-        .where(InstanceModel.id == instance.id)
-        .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
-        .options(joinedload(InstanceModel.jobs))
-        .options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
-        .execution_options(populate_existing=True)
-    )
-    instance = res.unique().scalar_one()
-    if (
-        instance.status == InstanceStatus.IDLE
-        and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
-        and not instance.jobs
+    if instance.status in (
+        InstanceStatus.PENDING,
+        InstanceStatus.TERMINATING,
     ):
-        await _mark_terminating_if_idle_duration_expired(instance)
+        # Refetch to load related attributes.
+        # Load related attributes only for statuses that always need them.
+        res = await session.execute(
+            select(InstanceModel)
+            .where(InstanceModel.id == instance.id)
+            .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
+            .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
+            .options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
+            .execution_options(populate_existing=True)
+        )
+        instance = res.unique().scalar_one()
+
     if instance.status == InstanceStatus.PENDING:
         if instance.remote_connection_info is not None:
             await _add_remote(instance)
@@ -198,7 +218,9 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
         InstanceStatus.IDLE,
         InstanceStatus.BUSY,
     ):
-        await _check_instance(instance)
+        idle_duration_expired = _check_and_mark_terminating_if_idle_duration_expired(instance)
+        if not idle_duration_expired:
+            await _check_instance(session, instance)
     elif instance.status == InstanceStatus.TERMINATING:
         await _terminate(instance)

@@ -206,7 +228,13 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
     await session.commit()


-async def _mark_terminating_if_idle_duration_expired(instance: InstanceModel):
+def _check_and_mark_terminating_if_idle_duration_expired(instance: InstanceModel):
+    if not (
+        instance.status == InstanceStatus.IDLE
+        and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
+        and not instance.jobs
+    ):
+        return False
     idle_duration = _get_instance_idle_duration(instance)
     idle_seconds = instance.termination_idle_time
     delta = datetime.timedelta(seconds=idle_seconds)
@@ -222,6 +250,8 @@ async def _mark_terminating_if_idle_duration_expired(instance: InstanceModel):
                 "instance_status": instance.status.value,
             },
         )
+        return True
+    return False


 async def _add_remote(instance: InstanceModel) -> None:
@@ -398,10 +428,10 @@

 def _deploy_instance(
     remote_details: RemoteConnectionInfo,
-    pkeys: List[PKey],
+    pkeys: list[PKey],
     ssh_proxy_pkeys: Optional[list[PKey]],
-    authorized_keys: List[str],
-) -> Tuple[HealthStatus, Dict[str, Any], GoArchType]:
+    authorized_keys: list[str],
+) -> tuple[InstanceCheck, dict[str, Any], GoArchType]:
     with get_paramiko_connection(
         remote_details.ssh_user,
         remote_details.host,
@@ -449,19 +479,19 @@ def _deploy_instance(
         host_info = get_host_info(client, dstack_working_dir)
         logger.debug("Received a host_info %s", host_info)

-        raw_health = get_shim_healthcheck(client)
+        healthcheck_out = get_shim_healthcheck(client)
         try:
-            health_response = HealthcheckResponse.__response__.parse_raw(raw_health)
+            healthcheck = HealthcheckResponse.__response__.parse_raw(healthcheck_out)
         except ValueError as e:
-            raise ProvisioningError("Cannot read HealthcheckResponse") from e
-        health = runner_client.health_response_to_health_status(health_response)
+            raise ProvisioningError(f"Cannot parse HealthcheckResponse: {e}") from e
+        instance_check = runner_client.healthcheck_response_to_instance_check(healthcheck)

-        return health, host_info, arch
+        return instance_check, host_info, arch


 async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
     if instance.last_retry_at is not None:
-        last_retry = instance.last_retry_at.replace(tzinfo=datetime.timezone.utc)
+        last_retry = instance.last_retry_at
         if get_current_datetime() < last_retry + timedelta(minutes=1):
             return

@@ -700,7 +730,7 @@ def _mark_terminated(instance: InstanceModel, termination_reason: str) -> None:
     )


-async def _check_instance(instance: InstanceModel) -> None:
+async def _check_instance(session: AsyncSession, instance: InstanceModel) -> None:
     if (
         instance.status == InstanceStatus.BUSY
         and instance.jobs
@@ -719,12 +749,16 @@ async def _check_instance(instance: InstanceModel) -> None:
         )
         return

-    job_provisioning_data = JobProvisioningData.__response__.parse_raw(
-        instance.job_provisioning_data
-    )
+    job_provisioning_data = get_or_error(get_instance_provisioning_data(instance))
     if job_provisioning_data.hostname is None:
+        res = await session.execute(
+            select(ProjectModel)
+            .where(ProjectModel.id == instance.project_id)
+            .options(joinedload(ProjectModel.backends))
+        )
+        project = res.unique().scalar_one()
         await _wait_for_instance_provisioning_data(
-            project=instance.project,
+            project=project,
             instance=instance,
             job_provisioning_data=job_provisioning_data,
         )
@@ -737,29 +771,65 @@

     ssh_private_keys = get_instance_ssh_private_keys(instance)

+    health_check_cutoff = get_current_datetime() - timedelta(
+        seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS
+    )
+    res = await session.execute(
+        select(func.count(1)).where(
+            InstanceHealthCheckModel.instance_id == instance.id,
+            InstanceHealthCheckModel.collected_at > health_check_cutoff,
+        )
+    )
+    check_instance_health = res.scalar_one() == 0
+
     # May return False if fails to establish ssh connection
-    health_status_response = await run_async(
-        _instance_healthcheck,
+    instance_check = await run_async(
+        _check_instance_inner,
         ssh_private_keys,
         job_provisioning_data,
         None,
+        check_instance_health=check_instance_health,
     )
-    if isinstance(health_status_response, bool) or health_status_response is None:
-        health_status = HealthStatus(healthy=False, reason="SSH or tunnel error")
-    else:
-        health_status = health_status_response
+    if instance_check is False:
+        instance_check = InstanceCheck(reachable=False, message="SSH or tunnel error")

-    logger.debug(
-        "Check instance %s status. shim health: %s",
+    if instance_check.reachable and check_instance_health:
+        health_status = instance_check.get_health_status()
+    else:
+        # Keep previous health status
+        health_status = instance.health
+
+    loglevel = logging.DEBUG
+    if not instance_check.reachable and instance.status.is_available():
+        loglevel = logging.WARNING
+    elif check_instance_health and not health_status.is_healthy():
+        loglevel = logging.WARNING
+    logger.log(
+        loglevel,
+        "Instance %s check: reachable=%s health_status=%s message=%r",
         instance.name,
-        health_status,
-        extra={"instance_name": instance.name, "shim_health": health_status},
+        instance_check.reachable,
+        health_status.name,
+        instance_check.message,
+        extra={"instance_name": instance.name, "health_status": health_status},
     )

-    if health_status.healthy:
+    if instance_check.has_health_checks():
+        # ensured by has_health_checks()
+        assert instance_check.health_response is not None
+        health_check_model = InstanceHealthCheckModel(
+            instance_id=instance.id,
+            collected_at=get_current_datetime(),
+            status=health_status,
+            response=instance_check.health_response.json(),
+        )
+        session.add(health_check_model)
+
+    instance.health = health_status
+    instance.unreachable = not instance_check.reachable
+
+    if instance_check.reachable:
         instance.termination_deadline = None
-        instance.health_status = None
-        instance.unreachable = False

         if instance.status == InstanceStatus.PROVISIONING:
             instance.status = InstanceStatus.IDLE if not instance.jobs else InstanceStatus.BUSY
@@ -777,9 +847,6 @@ async def _check_instance(instance: InstanceModel) -> None:
         if instance.termination_deadline is None:
             instance.termination_deadline = get_current_datetime() + TERMINATION_DEADLINE_OFFSET

-        instance.health_status = health_status.reason
-        instance.unreachable = True
-
         if instance.status == InstanceStatus.PROVISIONING and instance.started_at is not None:
             provisioning_deadline = _get_provisioning_deadline(
                 instance=instance,
@@ -795,13 +862,8 @@ async def _check_instance(instance: InstanceModel) -> None:
                     "instance_status": InstanceStatus.TERMINATING.value,
                 },
             )
-        elif instance.status in (InstanceStatus.IDLE, InstanceStatus.BUSY):
-            logger.warning(
-                "Instance %s shim is not available",
-                instance.name,
-                extra={"instance_name": instance.name},
-            )
-            deadline = instance.termination_deadline.replace(tzinfo=datetime.timezone.utc)
+        elif instance.status.is_available():
+            deadline = instance.termination_deadline
             if get_current_datetime() > deadline:
                 instance.status = InstanceStatus.TERMINATING
                 instance.termination_reason = "Termination deadline"
@@ -871,20 +933,30 @@ async def _wait_for_instance_provisioning_data(


 @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
-def _instance_healthcheck(ports: Dict[int, int]) -> HealthStatus:
+def _check_instance_inner(
+    ports: Dict[int, int], *, check_instance_health: bool = False
+) -> InstanceCheck:
+    instance_health_response: Optional[InstanceHealthResponse] = None
     shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
+    method = shim_client.healthcheck
     try:
-        resp = shim_client.healthcheck(unmask_exeptions=True)
-        if resp is None:
-            return HealthStatus(healthy=False, reason="Unknown reason")
-        return runner_client.health_response_to_health_status(resp)
+        healthcheck_response = method(unmask_exceptions=True)
+        if check_instance_health:
+            method = shim_client.get_instance_health
+            instance_health_response = method()
     except requests.RequestException as e:
-        return HealthStatus(healthy=False, reason=f"Can't request shim: {e}")
+        template = "shim.%s(): request error: %s"
+        args = (method.__func__.__name__, e)
+        logger.debug(template, *args)
+        return InstanceCheck(reachable=False, message=template % args)
    except Exception as e:
-        logger.exception("Unknown exception from shim.healthcheck: %s", e)
-        return HealthStatus(
-            healthy=False, reason=f"Unknown exception ({e.__class__.__name__}): {e}"
-        )
+        template = "shim.%s(): unexpected exception %s: %s"
+        args = (method.__func__.__name__, e.__class__.__name__, e)
+        logger.exception(template, *args)
+        return InstanceCheck(reachable=False, message=template % args)
+    return runner_client.healthcheck_response_to_instance_check(
+        healthcheck_response, instance_health_response
+    )


 async def _terminate(instance: InstanceModel) -> None:
@@ -956,18 +1028,12 @@ async def _terminate(instance: InstanceModel) -> None:

 def _next_termination_retry_at(instance: InstanceModel) -> datetime.datetime:
     assert instance.last_termination_retry_at is not None
-    return (
-        instance.last_termination_retry_at.replace(tzinfo=datetime.timezone.utc)
-        + TERMINATION_RETRY_TIMEOUT
-    )
+    return instance.last_termination_retry_at + TERMINATION_RETRY_TIMEOUT


 def _get_termination_deadline(instance: InstanceModel) -> datetime.datetime:
     assert instance.first_termination_retry_at is not None
-    return (
-        instance.first_termination_retry_at.replace(tzinfo=datetime.timezone.utc)
-        + TERMINATION_RETRY_MAX_DURATION
-    )
+    return instance.first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION


 def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
@@ -1102,27 +1168,26 @@ async def _create_placement_group(


 def _get_instance_idle_duration(instance: InstanceModel) -> datetime.timedelta:
-    last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)
+    last_time = instance.created_at
     if instance.last_job_processed_at is not None:
-        last_time = instance.last_job_processed_at.replace(tzinfo=datetime.timezone.utc)
+        last_time = instance.last_job_processed_at
     return get_current_datetime() - last_time


 def _get_retry_duration_deadline(instance: InstanceModel, retry: Retry) -> datetime.datetime:
-    return instance.created_at.replace(tzinfo=datetime.timezone.utc) + timedelta(
-        seconds=retry.duration
-    )
+    return instance.created_at + timedelta(seconds=retry.duration)


 def _get_provisioning_deadline(
     instance: InstanceModel,
     job_provisioning_data: JobProvisioningData,
 ) -> datetime.datetime:
+    assert instance.started_at is not None
     timeout_interval = get_provisioning_timeout(
         backend_type=job_provisioning_data.get_base_backend(),
         instance_type_name=job_provisioning_data.instance_type.name,
     )
-    return instance.started_at.replace(tzinfo=datetime.timezone.utc) + timeout_interval
+    return instance.started_at + timeout_interval


 def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]:

dstack/_internal/server/background/tasks/process_metrics.py

@@ -9,12 +9,13 @@ from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT
 from dstack._internal.core.models.runs import JobStatus
 from dstack._internal.server import settings
 from dstack._internal.server.db import get_session_ctx
-from dstack._internal.server.models import InstanceModel, JobMetricsPoint, JobModel
+from dstack._internal.server.models import InstanceModel, JobMetricsPoint, JobModel, ProjectModel
 from dstack._internal.server.schemas.runner import MetricsResponse
 from dstack._internal.server.services.instances import get_instance_ssh_private_keys
 from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data
 from dstack._internal.server.services.runner import client
 from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
+from dstack._internal.server.utils import sentry_utils
 from dstack._internal.utils.common import batched, get_current_datetime, get_or_error, run_async
 from dstack._internal.utils.logging import get_logger

@@ -26,12 +27,17 @@ BATCH_SIZE = 10
 MIN_COLLECT_INTERVAL_SECONDS = 9


+@sentry_utils.instrument_background_task
 async def collect_metrics():
     async with get_session_ctx() as session:
         res = await session.execute(
             select(JobModel)
             .where(JobModel.status.in_([JobStatus.RUNNING]))
-            .options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
+            .options(
+                joinedload(JobModel.instance)
+                .joinedload(InstanceModel.project)
+                .load_only(ProjectModel.ssh_private_key)
+            )
             .order_by(JobModel.last_processed_at.asc())
             .limit(MAX_JOBS_FETCHED)
         )
@@ -41,6 +47,7 @@ async def collect_metrics():
         await _collect_jobs_metrics(batch)


+@sentry_utils.instrument_background_task
 async def delete_metrics():
     now_timestamp_micro = int(get_current_datetime().timestamp() * 1_000_000)
     running_timestamp_micro_cutoff = (

dstack/_internal/server/background/tasks/process_placement_groups.py

@@ -12,12 +12,14 @@ from dstack._internal.server.models import PlacementGroupModel, ProjectModel
 from dstack._internal.server.services import backends as backends_services
 from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.placement import placement_group_model_to_placement_group
+from dstack._internal.server.utils import sentry_utils
 from dstack._internal.utils.common import get_current_datetime, run_async
 from dstack._internal.utils.logging import get_logger

 logger = get_logger(__name__)


+@sentry_utils.instrument_background_task
 async def process_placement_groups():
     lock, lockset = get_locker(get_db().dialect_name).get_lockset(
         PlacementGroupModel.__tablename__

dstack/_internal/server/background/tasks/process_probes.py (new file)

@@ -0,0 +1,164 @@
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+from datetime import timedelta
+from functools import partial
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import httpx
+from apscheduler.schedulers.asyncio import AsyncIOScheduler
+from httpx import AsyncClient, AsyncHTTPTransport
+from sqlalchemy import select, update
+from sqlalchemy.orm import joinedload
+
+from dstack._internal.core.errors import SSHError
+from dstack._internal.core.models.runs import JobSpec, JobStatus, ProbeSpec
+from dstack._internal.core.services.ssh.tunnel import (
+    SSH_DEFAULT_OPTIONS,
+    IPSocket,
+    SocketPair,
+    UnixSocket,
+)
+from dstack._internal.server.db import get_db, get_session_ctx
+from dstack._internal.server.models import InstanceModel, JobModel, ProbeModel
+from dstack._internal.server.services.locking import get_locker
+from dstack._internal.server.services.logging import fmt
+from dstack._internal.server.services.ssh import container_ssh_tunnel
+from dstack._internal.utils.common import get_current_datetime, get_or_error
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+BATCH_SIZE = 100
+SSH_CONNECT_TIMEOUT = timedelta(seconds=10)
+PROCESSING_OVERHEAD_TIMEOUT = timedelta(minutes=1)
+PROBES_SCHEDULER = AsyncIOScheduler()
+
+
+async def process_probes():
+    probe_lock, probe_lockset = get_locker(get_db().dialect_name).get_lockset(
+        ProbeModel.__tablename__
+    )
+    async with get_session_ctx() as session:
+        async with probe_lock:
+            res = await session.execute(
+                select(ProbeModel.id)
+                .where(ProbeModel.id.not_in(probe_lockset))
+                .where(ProbeModel.active == True)
+                .where(ProbeModel.due <= get_current_datetime())
+                .order_by(ProbeModel.due.asc())
+                .limit(BATCH_SIZE)
+                .with_for_update(skip_locked=True, key_share=True)
+            )
+            probe_ids = res.unique().scalars().all()
+            probe_lockset.update(probe_ids)
+
+        try:
+            # Refetch to load all attributes.
+            # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
+            res = await session.execute(
+                select(ProbeModel)
+                .where(ProbeModel.id.in_(probe_ids))
+                .options(
+                    joinedload(ProbeModel.job)
+                    .joinedload(JobModel.instance)
+                    .joinedload(InstanceModel.project)
+                )
+                .options(joinedload(ProbeModel.job))
+                .execution_options(populate_existing=True)
+            )
+            probes = res.unique().scalars().all()
+            for probe in probes:
+                if probe.job.status != JobStatus.RUNNING:
+                    probe.active = False
+                else:
+                    job_spec: JobSpec = JobSpec.__response__.parse_raw(probe.job.job_spec_data)
+                    probe_spec = job_spec.probes[probe.probe_num]
+                    # Schedule the next probe execution in case this execution is interrupted
+                    probe.due = get_current_datetime() + _get_probe_async_processing_timeout(
+                        probe_spec
+                    )
+                    # Execute the probe asynchronously outside of the DB session
+                    PROBES_SCHEDULER.add_job(partial(_process_probe_async, probe, probe_spec))
+            await session.commit()
+        finally:
+            probe_lockset.difference_update(probe_ids)
+
+
+async def _process_probe_async(probe: ProbeModel, probe_spec: ProbeSpec) -> None:
+    start = get_current_datetime()
+    logger.debug("%s: processing probe", fmt(probe))
+    success = await _execute_probe(probe, probe_spec)
+
+    async with get_session_ctx() as session:
+        async with get_locker(get_db().dialect_name).lock_ctx(
+            ProbeModel.__tablename__, [probe.id]
+        ):
+            await session.execute(
+                update(ProbeModel)
+                .where(ProbeModel.id == probe.id)
+                .values(
+                    success_streak=0 if not success else ProbeModel.success_streak + 1,
+                    due=get_current_datetime() + timedelta(seconds=probe_spec.interval),
+                )
+            )
+    logger.debug(
+        "%s: probe processing took %ss",
+        fmt(probe),
+        (get_current_datetime() - start).total_seconds(),
+    )
+
+
+async def _execute_probe(probe: ProbeModel, probe_spec: ProbeSpec) -> bool:
+    """
+    Returns:
+        Whether probe execution was successful.
+    """
+
+    try:
+        async with _get_service_replica_client(probe.job) as client:
+            resp = await client.request(
+                method=probe_spec.method,
+                url="http://dstack" + probe_spec.url,
+                headers=[(h.name, h.value) for h in probe_spec.headers],
+                data=probe_spec.body,
+                timeout=probe_spec.timeout,
+                follow_redirects=False,
+            )
+        logger.debug("%s: probe status code: %s", fmt(probe), resp.status_code)
+        return resp.is_success
+    except (SSHError, httpx.RequestError) as e:
+        logger.debug("%s: probe failed: %r", fmt(probe), e)
+        return False
+
+
+def _get_probe_async_processing_timeout(probe_spec: ProbeSpec) -> timedelta:
+    return (
+        timedelta(seconds=probe_spec.timeout)
+        + SSH_CONNECT_TIMEOUT
+        + PROCESSING_OVERHEAD_TIMEOUT  # slow db queries and other unforeseen conditions
+    )
+
+
+@asynccontextmanager
+async def _get_service_replica_client(job: JobModel) -> AsyncGenerator[AsyncClient, None]:
+    options = {
+        **SSH_DEFAULT_OPTIONS,
+        "ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())),
+    }
+    job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data)
+    with TemporaryDirectory() as temp_dir:
+        app_socket_path = (Path(temp_dir) / "replica.sock").absolute()
+        async with container_ssh_tunnel(
+            job=job,
+            forwarded_sockets=[
+                SocketPair(
+                    remote=IPSocket("localhost", get_or_error(job_spec.service_port)),
+                    local=UnixSocket(app_socket_path),
+                ),
+            ],
+            options=options,
+        ):
+            async with AsyncClient(
+                transport=AsyncHTTPTransport(uds=str(app_socket_path))
+            ) as client:
+                yield client