dstack 0.19.21__py3-none-any.whl → 0.19.23rc1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of dstack might be problematic.

Files changed (71)
  1. dstack/_internal/cli/commands/apply.py +8 -3
  2. dstack/_internal/cli/services/configurators/__init__.py +8 -0
  3. dstack/_internal/cli/services/configurators/fleet.py +1 -1
  4. dstack/_internal/cli/services/configurators/gateway.py +1 -1
  5. dstack/_internal/cli/services/configurators/run.py +11 -1
  6. dstack/_internal/cli/services/configurators/volume.py +1 -1
  7. dstack/_internal/cli/utils/common.py +48 -5
  8. dstack/_internal/cli/utils/fleet.py +5 -5
  9. dstack/_internal/cli/utils/run.py +32 -0
  10. dstack/_internal/core/backends/configurators.py +9 -0
  11. dstack/_internal/core/backends/hotaisle/__init__.py +1 -0
  12. dstack/_internal/core/backends/hotaisle/api_client.py +109 -0
  13. dstack/_internal/core/backends/hotaisle/backend.py +16 -0
  14. dstack/_internal/core/backends/hotaisle/compute.py +225 -0
  15. dstack/_internal/core/backends/hotaisle/configurator.py +60 -0
  16. dstack/_internal/core/backends/hotaisle/models.py +45 -0
  17. dstack/_internal/core/backends/lambdalabs/compute.py +2 -1
  18. dstack/_internal/core/backends/models.py +8 -0
  19. dstack/_internal/core/backends/nebius/compute.py +8 -2
  20. dstack/_internal/core/backends/nebius/fabrics.py +1 -0
  21. dstack/_internal/core/backends/nebius/resources.py +9 -0
  22. dstack/_internal/core/compatibility/runs.py +8 -0
  23. dstack/_internal/core/models/backends/base.py +2 -0
  24. dstack/_internal/core/models/configurations.py +139 -1
  25. dstack/_internal/core/models/health.py +28 -0
  26. dstack/_internal/core/models/instances.py +2 -0
  27. dstack/_internal/core/models/logs.py +2 -1
  28. dstack/_internal/core/models/runs.py +21 -1
  29. dstack/_internal/core/services/ssh/tunnel.py +7 -0
  30. dstack/_internal/server/app.py +4 -0
  31. dstack/_internal/server/background/__init__.py +4 -0
  32. dstack/_internal/server/background/tasks/process_instances.py +107 -56
  33. dstack/_internal/server/background/tasks/process_probes.py +164 -0
  34. dstack/_internal/server/background/tasks/process_running_jobs.py +13 -0
  35. dstack/_internal/server/background/tasks/process_runs.py +21 -14
  36. dstack/_internal/server/migrations/versions/25479f540245_add_probes.py +43 -0
  37. dstack/_internal/server/migrations/versions/728b1488b1b4_add_instance_health.py +50 -0
  38. dstack/_internal/server/models.py +41 -0
  39. dstack/_internal/server/routers/instances.py +33 -5
  40. dstack/_internal/server/schemas/health/dcgm.py +56 -0
  41. dstack/_internal/server/schemas/instances.py +32 -0
  42. dstack/_internal/server/schemas/runner.py +5 -0
  43. dstack/_internal/server/services/instances.py +103 -1
  44. dstack/_internal/server/services/jobs/__init__.py +8 -1
  45. dstack/_internal/server/services/jobs/configurators/base.py +26 -0
  46. dstack/_internal/server/services/logging.py +4 -2
  47. dstack/_internal/server/services/logs/aws.py +13 -1
  48. dstack/_internal/server/services/logs/gcp.py +16 -1
  49. dstack/_internal/server/services/probes.py +6 -0
  50. dstack/_internal/server/services/projects.py +16 -4
  51. dstack/_internal/server/services/runner/client.py +52 -20
  52. dstack/_internal/server/services/runner/ssh.py +4 -4
  53. dstack/_internal/server/services/runs.py +49 -13
  54. dstack/_internal/server/services/ssh.py +66 -0
  55. dstack/_internal/server/settings.py +13 -0
  56. dstack/_internal/server/statics/index.html +1 -1
  57. dstack/_internal/server/statics/{main-8f9ee218d3eb45989682.css → main-03e818b110e1d5705378.css} +1 -1
  58. dstack/_internal/server/statics/{main-39a767528976f8078166.js → main-cc067b7fd1a8f33f97da.js} +26 -15
  59. dstack/_internal/server/statics/{main-39a767528976f8078166.js.map → main-cc067b7fd1a8f33f97da.js.map} +1 -1
  60. dstack/_internal/server/testing/common.py +44 -0
  61. dstack/_internal/{core/backends/remote → server/utils}/provisioning.py +22 -17
  62. dstack/_internal/settings.py +3 -0
  63. dstack/_internal/utils/common.py +15 -0
  64. dstack/api/server/__init__.py +1 -1
  65. dstack/version.py +1 -1
  66. {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/METADATA +14 -14
  67. {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/RECORD +71 -58
  68. /dstack/_internal/{core/backends/remote → server/schemas/health}/__init__.py +0 -0
  69. {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/WHEEL +0 -0
  70. {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/entry_points.txt +0 -0
  71. {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/licenses/LICENSE.md +0 -0

--- a/dstack/_internal/server/background/tasks/process_instances.py
+++ b/dstack/_internal/server/background/tasks/process_instances.py
@@ -1,13 +1,14 @@
 import asyncio
 import datetime
+import logging
 from datetime import timedelta
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, cast
 
 import requests
 from paramiko.pkey import PKey
 from paramiko.ssh_exception import PasswordRequiredException
 from pydantic import ValidationError
-from sqlalchemy import select
+from sqlalchemy import delete, func, select
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload
 
@@ -27,18 +28,6 @@ from dstack._internal.core.backends.features import (
     BACKENDS_WITH_CREATE_INSTANCE_SUPPORT,
     BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT,
 )
-from dstack._internal.core.backends.remote.provisioning import (
-    detect_cpu_arch,
-    get_host_info,
-    get_paramiko_connection,
-    get_shim_healthcheck,
-    host_info_to_instance_type,
-    remove_dstack_runner_if_exists,
-    remove_host_info_if_exists,
-    run_pre_start_commands,
-    run_shim_as_systemd_service,
-    upload_envs,
-)
 from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT
 
 # FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute
@@ -77,12 +66,14 @@ from dstack._internal.server.background.tasks.common import get_provisioning_timeout
 from dstack._internal.server.db import get_db, get_session_ctx
 from dstack._internal.server.models import (
     FleetModel,
+    InstanceHealthCheckModel,
     InstanceModel,
     JobModel,
     PlacementGroupModel,
     ProjectModel,
 )
-from dstack._internal.server.schemas.runner import HealthcheckResponse
+from dstack._internal.server.schemas.instances import InstanceCheck
+from dstack._internal.server.schemas.runner import HealthcheckResponse, InstanceHealthResponse
 from dstack._internal.server.services import backends as backends_services
 from dstack._internal.server.services.fleets import (
     fleet_model_to_fleet,
@@ -103,9 +94,20 @@ from dstack._internal.server.services.placement import (
     schedule_fleet_placement_groups_deletion,
 )
 from dstack._internal.server.services.runner import client as runner_client
-from dstack._internal.server.services.runner.client import HealthStatus
 from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
 from dstack._internal.server.utils import sentry_utils
+from dstack._internal.server.utils.provisioning import (
+    detect_cpu_arch,
+    get_host_info,
+    get_paramiko_connection,
+    get_shim_healthcheck,
+    host_info_to_instance_type,
+    remove_dstack_runner_if_exists,
+    remove_host_info_if_exists,
+    run_pre_start_commands,
+    run_shim_as_systemd_service,
+    upload_envs,
+)
 from dstack._internal.utils.common import (
     get_current_datetime,
     get_or_error,
@@ -137,6 +139,17 @@ async def process_instances(batch_size: int = 1):
     await asyncio.gather(*tasks)
 
 
+@sentry_utils.instrument_background_task
+async def delete_instance_health_checks():
+    now = get_current_datetime()
+    cutoff = now - timedelta(seconds=server_settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS)
+    async with get_session_ctx() as session:
+        await session.execute(
+            delete(InstanceHealthCheckModel).where(InstanceHealthCheckModel.collected_at < cutoff)
+        )
+        await session.commit()
+
+
 @sentry_utils.instrument_background_task
 async def _process_next_instance():
     lock, lockset = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__)
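
The retention window above comes from `SERVER_INSTANCE_HEALTH_TTL_SECONDS`, and the collection throttle used later in `_check_instance` comes from `SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS`; both are among the +13 lines added to `dstack/_internal/server/settings.py` (file 55 above). A hypothetical sketch of how such knobs are usually declared there; the environment variable names and defaults below are assumptions, not the shipped values:

```python
# Sketch only: the actual declarations live in dstack/_internal/server/settings.py.
import os

# How long collected health checks are kept before delete_instance_health_checks
# purges them (assumed default: 7 days).
SERVER_INSTANCE_HEALTH_TTL_SECONDS = int(
    os.getenv("DSTACK_SERVER_INSTANCE_HEALTH_TTL_SECONDS", str(7 * 24 * 3600))
)
# Minimum gap between two health collections for the same instance
# (assumed default: 60 seconds).
SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS = int(
    os.getenv("DSTACK_SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS", "60")
)
```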
@@ -415,10 +428,10 @@ async def _add_remote(instance: InstanceModel) -> None:
 
 def _deploy_instance(
     remote_details: RemoteConnectionInfo,
-    pkeys: List[PKey],
+    pkeys: list[PKey],
     ssh_proxy_pkeys: Optional[list[PKey]],
-    authorized_keys: List[str],
-) -> Tuple[HealthStatus, Dict[str, Any], GoArchType]:
+    authorized_keys: list[str],
+) -> tuple[InstanceCheck, dict[str, Any], GoArchType]:
     with get_paramiko_connection(
         remote_details.ssh_user,
         remote_details.host,
@@ -466,14 +479,14 @@ def _deploy_instance(
         host_info = get_host_info(client, dstack_working_dir)
         logger.debug("Received a host_info %s", host_info)
 
-        raw_health = get_shim_healthcheck(client)
+        healthcheck_out = get_shim_healthcheck(client)
         try:
-            health_response = HealthcheckResponse.__response__.parse_raw(raw_health)
+            healthcheck = HealthcheckResponse.__response__.parse_raw(healthcheck_out)
         except ValueError as e:
-            raise ProvisioningError("Cannot read HealthcheckResponse") from e
-        health = runner_client.health_response_to_health_status(health_response)
+            raise ProvisioningError(f"Cannot parse HealthcheckResponse: {e}") from e
+        instance_check = runner_client.healthcheck_response_to_instance_check(healthcheck)
 
-        return health, host_info, arch
+        return instance_check, host_info, arch
 
 
 async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
758
771
 
759
772
  ssh_private_keys = get_instance_ssh_private_keys(instance)
760
773
 
774
+ health_check_cutoff = get_current_datetime() - timedelta(
775
+ seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS
776
+ )
777
+ res = await session.execute(
778
+ select(func.count(1)).where(
779
+ InstanceHealthCheckModel.instance_id == instance.id,
780
+ InstanceHealthCheckModel.collected_at > health_check_cutoff,
781
+ )
782
+ )
783
+ check_instance_health = res.scalar_one() == 0
784
+
761
785
  # May return False if fails to establish ssh connection
762
- health_status_response = await run_async(
763
- _instance_healthcheck,
786
+ instance_check = await run_async(
787
+ _check_instance_inner,
764
788
  ssh_private_keys,
765
789
  job_provisioning_data,
766
790
  None,
791
+ check_instance_health=check_instance_health,
767
792
  )
768
- if isinstance(health_status_response, bool) or health_status_response is None:
769
- health_status = HealthStatus(healthy=False, reason="SSH or tunnel error")
770
- else:
771
- health_status = health_status_response
793
+ if instance_check is False:
794
+ instance_check = InstanceCheck(reachable=False, message="SSH or tunnel error")
772
795
 
773
- logger.debug(
774
- "Check instance %s status. shim health: %s",
796
+ if instance_check.reachable and check_instance_health:
797
+ health_status = instance_check.get_health_status()
798
+ else:
799
+ # Keep previous health status
800
+ health_status = instance.health
801
+
802
+ loglevel = logging.DEBUG
803
+ if not instance_check.reachable and instance.status.is_available():
804
+ loglevel = logging.WARNING
805
+ elif check_instance_health and not health_status.is_healthy():
806
+ loglevel = logging.WARNING
807
+ logger.log(
808
+ loglevel,
809
+ "Instance %s check: reachable=%s health_status=%s message=%r",
775
810
  instance.name,
776
- health_status,
777
- extra={"instance_name": instance.name, "shim_health": health_status},
811
+ instance_check.reachable,
812
+ health_status.name,
813
+ instance_check.message,
814
+ extra={"instance_name": instance.name, "health_status": health_status},
778
815
  )
779
816
 
780
- if health_status.healthy:
817
+ if instance_check.has_health_checks():
818
+ # ensured by has_health_checks()
819
+ assert instance_check.health_response is not None
820
+ health_check_model = InstanceHealthCheckModel(
821
+ instance_id=instance.id,
822
+ collected_at=get_current_datetime(),
823
+ status=health_status,
824
+ response=instance_check.health_response.json(),
825
+ )
826
+ session.add(health_check_model)
827
+
828
+ instance.health = health_status
829
+ instance.unreachable = not instance_check.reachable
830
+
831
+ if instance_check.reachable:
781
832
  instance.termination_deadline = None
782
- instance.health_status = None
783
- instance.unreachable = False
784
833
 
785
834
  if instance.status == InstanceStatus.PROVISIONING:
786
835
  instance.status = InstanceStatus.IDLE if not instance.jobs else InstanceStatus.BUSY
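
The rewritten `_check_instance` separates reachability (can the server reach the shim over SSH?) from health (what the shim reports about the hardware), persisting each collected report as an `InstanceHealthCheckModel` row. The `InstanceCheck` schema it leans on is added in `dstack/_internal/server/schemas/instances.py` (+32). A standalone sketch inferred purely from the usage above; any field, method body, or status value not exercised by this hunk is an assumption:

```python
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel


class HealthStatus(str, Enum):
    # The migration below backfills instances with 'HEALTHY', so at least this
    # member must exist; other members are guesses.
    HEALTHY = "HEALTHY"
    UNHEALTHY = "UNHEALTHY"

    def is_healthy(self) -> bool:
        return self == HealthStatus.HEALTHY


class InstanceCheck(BaseModel):
    reachable: bool
    message: Optional[str] = None
    # In the real schema this is an InstanceHealthResponse pydantic model
    # (hence the .json() call above); a plain dict keeps this sketch standalone.
    health_response: Optional[Dict[str, Any]] = None

    def has_health_checks(self) -> bool:
        return self.health_response is not None

    def get_health_status(self) -> HealthStatus:
        # Invented logic: the real method presumably derives the status from
        # the shim's DCGM report (see schemas/health/dcgm.py in this release).
        if self.health_response and self.health_response.get("failures"):
            return HealthStatus.UNHEALTHY
        return HealthStatus.HEALTHY
```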
@@ -798,9 +847,6 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> None:
         if instance.termination_deadline is None:
             instance.termination_deadline = get_current_datetime() + TERMINATION_DEADLINE_OFFSET
 
-        instance.health_status = health_status.reason
-        instance.unreachable = True
-
         if instance.status == InstanceStatus.PROVISIONING and instance.started_at is not None:
             provisioning_deadline = _get_provisioning_deadline(
                 instance=instance,
@@ -816,12 +862,7 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> None:
                     "instance_status": InstanceStatus.TERMINATING.value,
                 },
             )
-        elif instance.status in (InstanceStatus.IDLE, InstanceStatus.BUSY):
-            logger.warning(
-                "Instance %s shim is not available",
-                instance.name,
-                extra={"instance_name": instance.name},
-            )
+        elif instance.status.is_available():
             deadline = instance.termination_deadline
             if get_current_datetime() > deadline:
                 instance.status = InstanceStatus.TERMINATING
@@ -892,20 +933,30 @@ async def _wait_for_instance_provisioning_data(
 
 
 @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
-def _instance_healthcheck(ports: Dict[int, int]) -> HealthStatus:
+def _check_instance_inner(
+    ports: Dict[int, int], *, check_instance_health: bool = False
+) -> InstanceCheck:
+    instance_health_response: Optional[InstanceHealthResponse] = None
     shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
+    method = shim_client.healthcheck
     try:
-        resp = shim_client.healthcheck(unmask_exeptions=True)
-        if resp is None:
-            return HealthStatus(healthy=False, reason="Unknown reason")
-        return runner_client.health_response_to_health_status(resp)
+        healthcheck_response = method(unmask_exceptions=True)
+        if check_instance_health:
+            method = shim_client.get_instance_health
+            instance_health_response = method()
     except requests.RequestException as e:
-        return HealthStatus(healthy=False, reason=f"Can't request shim: {e}")
+        template = "shim.%s(): request error: %s"
+        args = (method.__func__.__name__, e)
+        logger.debug(template, *args)
+        return InstanceCheck(reachable=False, message=template % args)
     except Exception as e:
-        logger.exception("Unknown exception from shim.healthcheck: %s", e)
-        return HealthStatus(
-            healthy=False, reason=f"Unknown exception ({e.__class__.__name__}): {e}"
-        )
+        template = "shim.%s(): unexpected exception %s: %s"
+        args = (method.__func__.__name__, e.__class__.__name__, e)
+        logger.exception(template, *args)
+        return InstanceCheck(reachable=False, message=template % args)
+    return runner_client.healthcheck_response_to_instance_check(
+        healthcheck_response, instance_health_response
+    )
 
 
 async def _terminate(instance: InstanceModel) -> None:
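
One detail in `_check_instance_inner` worth calling out: `method` is rebound between the two shim calls so that the shared `except` blocks can report whichever call actually failed, and `method.__func__.__name__` unwraps the bound method to get at its name. The mechanism in isolation:

```python
class ShimClient:
    def healthcheck(self):
        ...

    def get_instance_health(self):
        ...


client = ShimClient()
method = client.healthcheck          # a bound method
print(method.__func__.__name__)      # -> healthcheck
method = client.get_instance_health
print(method.__func__.__name__)      # -> get_instance_health
```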

--- /dev/null
+++ b/dstack/_internal/server/background/tasks/process_probes.py
@@ -0,0 +1,164 @@
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+from datetime import timedelta
+from functools import partial
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import httpx
+from apscheduler.schedulers.asyncio import AsyncIOScheduler
+from httpx import AsyncClient, AsyncHTTPTransport
+from sqlalchemy import select, update
+from sqlalchemy.orm import joinedload
+
+from dstack._internal.core.errors import SSHError
+from dstack._internal.core.models.runs import JobSpec, JobStatus, ProbeSpec
+from dstack._internal.core.services.ssh.tunnel import (
+    SSH_DEFAULT_OPTIONS,
+    IPSocket,
+    SocketPair,
+    UnixSocket,
+)
+from dstack._internal.server.db import get_db, get_session_ctx
+from dstack._internal.server.models import InstanceModel, JobModel, ProbeModel
+from dstack._internal.server.services.locking import get_locker
+from dstack._internal.server.services.logging import fmt
+from dstack._internal.server.services.ssh import container_ssh_tunnel
+from dstack._internal.utils.common import get_current_datetime, get_or_error
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+BATCH_SIZE = 100
+SSH_CONNECT_TIMEOUT = timedelta(seconds=10)
+PROCESSING_OVERHEAD_TIMEOUT = timedelta(minutes=1)
+PROBES_SCHEDULER = AsyncIOScheduler()
+
+
+async def process_probes():
+    probe_lock, probe_lockset = get_locker(get_db().dialect_name).get_lockset(
+        ProbeModel.__tablename__
+    )
+    async with get_session_ctx() as session:
+        async with probe_lock:
+            res = await session.execute(
+                select(ProbeModel.id)
+                .where(ProbeModel.id.not_in(probe_lockset))
+                .where(ProbeModel.active == True)
+                .where(ProbeModel.due <= get_current_datetime())
+                .order_by(ProbeModel.due.asc())
+                .limit(BATCH_SIZE)
+                .with_for_update(skip_locked=True, key_share=True)
+            )
+            probe_ids = res.unique().scalars().all()
+            probe_lockset.update(probe_ids)
+
+        try:
+            # Refetch to load all attributes.
+            # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
+            res = await session.execute(
+                select(ProbeModel)
+                .where(ProbeModel.id.in_(probe_ids))
+                .options(
+                    joinedload(ProbeModel.job)
+                    .joinedload(JobModel.instance)
+                    .joinedload(InstanceModel.project)
+                )
+                .options(joinedload(ProbeModel.job))
+                .execution_options(populate_existing=True)
+            )
+            probes = res.unique().scalars().all()
+            for probe in probes:
+                if probe.job.status != JobStatus.RUNNING:
+                    probe.active = False
+                else:
+                    job_spec: JobSpec = JobSpec.__response__.parse_raw(probe.job.job_spec_data)
+                    probe_spec = job_spec.probes[probe.probe_num]
+                    # Schedule the next probe execution in case this execution is interrupted
+                    probe.due = get_current_datetime() + _get_probe_async_processing_timeout(
+                        probe_spec
+                    )
+                    # Execute the probe asynchronously outside of the DB session
+                    PROBES_SCHEDULER.add_job(partial(_process_probe_async, probe, probe_spec))
+            await session.commit()
+        finally:
+            probe_lockset.difference_update(probe_ids)
+
+
+async def _process_probe_async(probe: ProbeModel, probe_spec: ProbeSpec) -> None:
+    start = get_current_datetime()
+    logger.debug("%s: processing probe", fmt(probe))
+    success = await _execute_probe(probe, probe_spec)
+
+    async with get_session_ctx() as session:
+        async with get_locker(get_db().dialect_name).lock_ctx(
+            ProbeModel.__tablename__, [probe.id]
+        ):
+            await session.execute(
+                update(ProbeModel)
+                .where(ProbeModel.id == probe.id)
+                .values(
+                    success_streak=0 if not success else ProbeModel.success_streak + 1,
+                    due=get_current_datetime() + timedelta(seconds=probe_spec.interval),
+                )
+            )
+    logger.debug(
+        "%s: probe processing took %ss",
+        fmt(probe),
+        (get_current_datetime() - start).total_seconds(),
+    )
+
+
+async def _execute_probe(probe: ProbeModel, probe_spec: ProbeSpec) -> bool:
+    """
+    Returns:
+        Whether probe execution was successful.
+    """
+
+    try:
+        async with _get_service_replica_client(probe.job) as client:
+            resp = await client.request(
+                method=probe_spec.method,
+                url="http://dstack" + probe_spec.url,
+                headers=[(h.name, h.value) for h in probe_spec.headers],
+                data=probe_spec.body,
+                timeout=probe_spec.timeout,
+                follow_redirects=False,
+            )
+            logger.debug("%s: probe status code: %s", fmt(probe), resp.status_code)
+            return resp.is_success
+    except (SSHError, httpx.RequestError) as e:
+        logger.debug("%s: probe failed: %r", fmt(probe), e)
+        return False
+
+
+def _get_probe_async_processing_timeout(probe_spec: ProbeSpec) -> timedelta:
+    return (
+        timedelta(seconds=probe_spec.timeout)
+        + SSH_CONNECT_TIMEOUT
+        + PROCESSING_OVERHEAD_TIMEOUT  # slow db queries and other unforeseen conditions
+    )
+
+
+@asynccontextmanager
+async def _get_service_replica_client(job: JobModel) -> AsyncGenerator[AsyncClient, None]:
+    options = {
+        **SSH_DEFAULT_OPTIONS,
+        "ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())),
+    }
+    job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data)
+    with TemporaryDirectory() as temp_dir:
+        app_socket_path = (Path(temp_dir) / "replica.sock").absolute()
+        async with container_ssh_tunnel(
+            job=job,
+            forwarded_sockets=[
+                SocketPair(
+                    remote=IPSocket("localhost", get_or_error(job_spec.service_port)),
+                    local=UnixSocket(app_socket_path),
+                ),
+            ],
+            options=options,
+        ):
+            async with AsyncClient(
+                transport=AsyncHTTPTransport(uds=str(app_socket_path))
+            ) as client:
+                yield client
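
The transport trick at the heart of `_get_service_replica_client`: the probe's HTTP request travels over a Unix domain socket that is the local end of an SSH tunnel into the service container, so nothing is published on the network and the URL's host part is irrelevant. httpx supports this directly through `AsyncHTTPTransport(uds=...)`. A self-contained reduction, with a placeholder socket path standing in for the tunnel:

```python
import asyncio

import httpx


async def probe_once(socket_path: str) -> bool:
    # Route all requests through the Unix socket instead of TCP; the host in
    # the URL ("dstack" above, "replica" here) is arbitrary.
    transport = httpx.AsyncHTTPTransport(uds=socket_path)
    async with httpx.AsyncClient(transport=transport) as client:
        resp = await client.get("http://replica/health", timeout=10.0)
        return resp.is_success


if __name__ == "__main__":
    # Requires something listening on the socket, e.g. the forwarded service.
    print(asyncio.run(probe_once("/tmp/replica.sock")))
```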

--- a/dstack/_internal/server/background/tasks/process_running_jobs.py
+++ b/dstack/_internal/server/background/tasks/process_running_jobs.py
@@ -42,6 +42,7 @@ from dstack._internal.server.db import get_db, get_session_ctx
 from dstack._internal.server.models import (
     InstanceModel,
     JobModel,
+    ProbeModel,
     ProjectModel,
     RepoModel,
     RunModel,
@@ -414,6 +415,18 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
             )
             job_model.status = JobStatus.TERMINATING
             job_model.termination_reason = JobTerminationReason.GATEWAY_ERROR
+        else:
+            for probe_num in range(len(job.job_spec.probes)):
+                session.add(
+                    ProbeModel(
+                        name=f"{job_model.job_name}-{probe_num}",
+                        job=job_model,
+                        probe_num=probe_num,
+                        due=common_utils.get_current_datetime(),
+                        success_streak=0,
+                        active=True,
+                    )
+                )
 
     if job_model.status == JobStatus.RUNNING:
         await _check_gpu_utilization(session, job_model, job)
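
Probes are created here from `job.job_spec.probes` once the job is running; `process_probes.py` then reads `method`, `url`, `headers`, `body`, `timeout`, and `interval` off each `ProbeSpec` (added among the +21 lines in `dstack/_internal/core/models/runs.py`). A sketch of the shape those accesses imply; the types and defaults are guesses, not the shipped schema:

```python
from typing import List, Optional

from pydantic import BaseModel


class ProbeHeader(BaseModel):
    name: str
    value: str


class ProbeSpec(BaseModel):
    method: str = "GET"              # HTTP method used by the probe request
    url: str = "/"                   # path appended to the tunneled base URL
    headers: List[ProbeHeader] = []  # pydantic deep-copies mutable defaults
    body: Optional[str] = None
    timeout: int = 10                # per-request timeout, seconds (assumed default)
    interval: int = 30               # delay between executions, seconds (assumed default)
```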

--- a/dstack/_internal/server/background/tasks/process_runs.py
+++ b/dstack/_internal/server/background/tasks/process_runs.py
@@ -23,6 +23,7 @@ from dstack._internal.server.db import get_db, get_session_ctx
 from dstack._internal.server.models import (
     InstanceModel,
     JobModel,
+    ProbeModel,
     ProjectModel,
     RunModel,
     UserModel,
@@ -36,6 +37,7 @@ from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.prometheus.client_metrics import run_metrics
 from dstack._internal.server.services.runs import (
     fmt,
+    is_replica_ready,
     process_terminating_run,
     retry_run_replica_jobs,
     run_model_to_run,
@@ -149,6 +151,11 @@ async def _process_run(session: AsyncSession, run_model: RunModel):
                .joinedload(JobModel.instance)
                .load_only(InstanceModel.fleet_id)
            )
+            .options(
+                selectinload(RunModel.jobs)
+                .joinedload(JobModel.probes)
+                .load_only(ProbeModel.success_streak)
+            )
            .execution_options(populate_existing=True)
        )
        run_model = res.unique().scalar_one()
@@ -472,22 +479,22 @@ async def _handle_run_replicas(
     )
 
     replicas_to_stop_count = 0
-    # stop any out-of-date replicas that are not running
-    replicas_to_stop_count += len(
-        {
-            j.replica_num
-            for j in run_model.jobs
-            if j.status
-            not in [JobStatus.RUNNING, JobStatus.TERMINATING] + JobStatus.finished_statuses()
-            and j.deployment_num < run_model.deployment_num
-        }
+    # stop any out-of-date replicas that are not ready
+    replicas_to_stop_count += sum(
+        any(j.deployment_num < run_model.deployment_num for j in jobs)
+        and any(
+            j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses()
+            for j in jobs
+        )
+        and not is_replica_ready(jobs)
+        for _, jobs in group_jobs_by_replica_latest(run_model.jobs)
     )
-    running_replica_count = len(
-        {j.replica_num for j in run_model.jobs if j.status == JobStatus.RUNNING}
+    ready_replica_count = sum(
+        is_replica_ready(jobs) for _, jobs in group_jobs_by_replica_latest(run_model.jobs)
     )
-    if running_replica_count > run_model.desired_replica_count:
-        # stop excessive running out-of-date replicas
-        replicas_to_stop_count += running_replica_count - run_model.desired_replica_count
+    if ready_replica_count > run_model.desired_replica_count:
+        # stop excessive ready out-of-date replicas
+        replicas_to_stop_count += ready_replica_count - run_model.desired_replica_count
     if replicas_to_stop_count:
         await scale_run_replicas(
             session,
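
The scaling logic above stops counting "running" replicas and starts counting "ready" ones: a replica only counts toward `desired_replica_count` once `is_replica_ready` (added to `dstack/_internal/server/services/runs.py`, +49) approves it, which with this release can factor in probe success streaks. A toy model of that accounting, with simplified stand-ins for the real job and probe fields:

```python
from dataclasses import dataclass, field
from itertools import groupby


@dataclass
class Job:
    replica_num: int
    status: str
    probe_streaks: list = field(default_factory=list)


def is_replica_ready(jobs) -> bool:
    # Stand-in for services.runs.is_replica_ready: every job running and
    # every probe with a positive success streak.
    return all(j.status == "running" for j in jobs) and all(
        streak > 0 for j in jobs for streak in j.probe_streaks
    )


jobs = [
    Job(0, "running", [3]),
    Job(1, "running", [0]),  # probe has not passed yet: replica 1 is not ready
]
grouped = groupby(sorted(jobs, key=lambda j: j.replica_num), key=lambda j: j.replica_num)
print(sum(is_replica_ready(list(g)) for _, g in grouped))  # -> 1
```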

--- /dev/null
+++ b/dstack/_internal/server/migrations/versions/25479f540245_add_probes.py
@@ -0,0 +1,43 @@
+"""Add probes
+
+Revision ID: 25479f540245
+Revises: 50dd7ea98639
+Create Date: 2025-08-03 19:51:07.722217
+
+"""
+
+import sqlalchemy as sa
+import sqlalchemy_utils
+from alembic import op
+
+import dstack._internal.server.models
+
+# revision identifiers, used by Alembic.
+revision = "25479f540245"
+down_revision = "50dd7ea98639"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table(
+        "probes",
+        sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False),
+        sa.Column("name", sa.String(length=100), nullable=False),
+        sa.Column("job_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False),
+        sa.Column("probe_num", sa.Integer(), nullable=False),
+        sa.Column("due", dstack._internal.server.models.NaiveDateTime(), nullable=False),
+        sa.Column("success_streak", sa.BigInteger(), nullable=False),
+        sa.Column("active", sa.Boolean(), nullable=False),
+        sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], name=op.f("fk_probes_job_id_jobs")),
+        sa.PrimaryKeyConstraint("id", "job_id", name=op.f("pk_probes")),
+        sa.UniqueConstraint("job_id", "probe_num", name="uq_probes_job_id_probe_num"),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table("probes")
+    # ### end Alembic commands ###

--- /dev/null
+++ b/dstack/_internal/server/migrations/versions/728b1488b1b4_add_instance_health.py
@@ -0,0 +1,50 @@
+"""Add instance health
+
+Revision ID: 728b1488b1b4
+Revises: 25479f540245
+Create Date: 2025-08-01 14:56:20.466990
+
+"""
+
+import sqlalchemy as sa
+import sqlalchemy_utils
+from alembic import op
+
+import dstack._internal.server.models
+
+# revision identifiers, used by Alembic.
+revision = "728b1488b1b4"
+down_revision = "25479f540245"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "instance_health_checks",
+        sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False),
+        sa.Column(
+            "instance_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False
+        ),
+        sa.Column("collected_at", dstack._internal.server.models.NaiveDateTime(), nullable=False),
+        sa.Column("status", sa.VARCHAR(length=100), nullable=False),
+        sa.Column("response", sa.Text(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["instance_id"],
+            ["instances.id"],
+            name=op.f("fk_instance_health_checks_instance_id_instances"),
+        ),
+        sa.PrimaryKeyConstraint("id", name=op.f("pk_instance_health_checks")),
+    )
+    with op.batch_alter_table("instances", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("health", sa.VARCHAR(length=100), nullable=True))
+    op.execute("UPDATE instances SET health = 'HEALTHY'")
+    with op.batch_alter_table("instances", schema=None) as batch_op:
+        batch_op.alter_column("health", existing_type=sa.VARCHAR(length=100), nullable=False)
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("instances", schema=None) as batch_op:
+        batch_op.drop_column("health")
+
+    op.drop_table("instance_health_checks")
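
Note the three-step dance in `upgrade()`: `health` is added as a nullable column, backfilled with `'HEALTHY'`, and only then flipped to `NOT NULL`. SQLite cannot add a `NOT NULL` column without a server default to a populated table, and Alembic's `batch_alter_table` performs the copy-and-recreate that SQLite needs for `ALTER COLUMN`, so the same migration runs on both the SQLite and PostgreSQL backends that the dstack server supports.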