dstack 0.19.20__py3-none-any.whl → 0.19.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dstack might be problematic.
Files changed (93)
  1. dstack/_internal/cli/commands/apply.py +8 -3
  2. dstack/_internal/cli/services/configurators/__init__.py +8 -0
  3. dstack/_internal/cli/services/configurators/fleet.py +1 -1
  4. dstack/_internal/cli/services/configurators/gateway.py +1 -1
  5. dstack/_internal/cli/services/configurators/run.py +11 -1
  6. dstack/_internal/cli/services/configurators/volume.py +1 -1
  7. dstack/_internal/cli/utils/common.py +48 -5
  8. dstack/_internal/cli/utils/fleet.py +5 -5
  9. dstack/_internal/cli/utils/run.py +32 -0
  10. dstack/_internal/core/backends/__init__.py +0 -65
  11. dstack/_internal/core/backends/configurators.py +9 -0
  12. dstack/_internal/core/backends/features.py +64 -0
  13. dstack/_internal/core/backends/hotaisle/__init__.py +1 -0
  14. dstack/_internal/core/backends/hotaisle/api_client.py +109 -0
  15. dstack/_internal/core/backends/hotaisle/backend.py +16 -0
  16. dstack/_internal/core/backends/hotaisle/compute.py +225 -0
  17. dstack/_internal/core/backends/hotaisle/configurator.py +60 -0
  18. dstack/_internal/core/backends/hotaisle/models.py +45 -0
  19. dstack/_internal/core/backends/lambdalabs/compute.py +2 -1
  20. dstack/_internal/core/backends/models.py +8 -0
  21. dstack/_internal/core/compatibility/fleets.py +2 -0
  22. dstack/_internal/core/compatibility/runs.py +12 -0
  23. dstack/_internal/core/models/backends/base.py +2 -0
  24. dstack/_internal/core/models/configurations.py +139 -1
  25. dstack/_internal/core/models/health.py +28 -0
  26. dstack/_internal/core/models/instances.py +2 -0
  27. dstack/_internal/core/models/logs.py +2 -1
  28. dstack/_internal/core/models/profiles.py +37 -0
  29. dstack/_internal/core/models/runs.py +21 -1
  30. dstack/_internal/core/services/ssh/tunnel.py +7 -0
  31. dstack/_internal/server/app.py +26 -10
  32. dstack/_internal/server/background/__init__.py +9 -6
  33. dstack/_internal/server/background/tasks/process_fleets.py +52 -38
  34. dstack/_internal/server/background/tasks/process_gateways.py +2 -2
  35. dstack/_internal/server/background/tasks/process_idle_volumes.py +5 -4
  36. dstack/_internal/server/background/tasks/process_instances.py +168 -103
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -2
  38. dstack/_internal/server/background/tasks/process_placement_groups.py +2 -0
  39. dstack/_internal/server/background/tasks/process_probes.py +164 -0
  40. dstack/_internal/server/background/tasks/process_prometheus_metrics.py +14 -2
  41. dstack/_internal/server/background/tasks/process_running_jobs.py +142 -124
  42. dstack/_internal/server/background/tasks/process_runs.py +84 -34
  43. dstack/_internal/server/background/tasks/process_submitted_jobs.py +12 -10
  44. dstack/_internal/server/background/tasks/process_terminating_jobs.py +12 -4
  45. dstack/_internal/server/background/tasks/process_volumes.py +4 -1
  46. dstack/_internal/server/migrations/versions/25479f540245_add_probes.py +43 -0
  47. dstack/_internal/server/migrations/versions/50dd7ea98639_index_status_columns.py +55 -0
  48. dstack/_internal/server/migrations/versions/728b1488b1b4_add_instance_health.py +50 -0
  49. dstack/_internal/server/migrations/versions/ec02a26a256c_add_runmodel_next_triggered_at.py +38 -0
  50. dstack/_internal/server/models.py +57 -16
  51. dstack/_internal/server/routers/instances.py +33 -5
  52. dstack/_internal/server/schemas/health/dcgm.py +56 -0
  53. dstack/_internal/server/schemas/instances.py +32 -0
  54. dstack/_internal/server/schemas/runner.py +5 -0
  55. dstack/_internal/server/services/fleets.py +19 -10
  56. dstack/_internal/server/services/gateways/__init__.py +17 -17
  57. dstack/_internal/server/services/instances.py +113 -15
  58. dstack/_internal/server/services/jobs/__init__.py +18 -13
  59. dstack/_internal/server/services/jobs/configurators/base.py +26 -0
  60. dstack/_internal/server/services/logging.py +4 -2
  61. dstack/_internal/server/services/logs/aws.py +13 -1
  62. dstack/_internal/server/services/logs/gcp.py +16 -1
  63. dstack/_internal/server/services/offers.py +3 -3
  64. dstack/_internal/server/services/probes.py +6 -0
  65. dstack/_internal/server/services/projects.py +51 -19
  66. dstack/_internal/server/services/prometheus/client_metrics.py +3 -0
  67. dstack/_internal/server/services/prometheus/custom_metrics.py +2 -3
  68. dstack/_internal/server/services/runner/client.py +52 -20
  69. dstack/_internal/server/services/runner/ssh.py +4 -4
  70. dstack/_internal/server/services/runs.py +115 -39
  71. dstack/_internal/server/services/services/__init__.py +4 -1
  72. dstack/_internal/server/services/ssh.py +66 -0
  73. dstack/_internal/server/services/users.py +2 -3
  74. dstack/_internal/server/services/volumes.py +11 -11
  75. dstack/_internal/server/settings.py +16 -0
  76. dstack/_internal/server/statics/index.html +1 -1
  77. dstack/_internal/server/statics/{main-8f9ee218d3eb45989682.css → main-03e818b110e1d5705378.css} +1 -1
  78. dstack/_internal/server/statics/{main-39a767528976f8078166.js → main-cc067b7fd1a8f33f97da.js} +26 -15
  79. dstack/_internal/server/statics/{main-39a767528976f8078166.js.map → main-cc067b7fd1a8f33f97da.js.map} +1 -1
  80. dstack/_internal/server/testing/common.py +51 -0
  81. dstack/_internal/{core/backends/remote → server/utils}/provisioning.py +22 -17
  82. dstack/_internal/server/utils/sentry_utils.py +12 -0
  83. dstack/_internal/settings.py +3 -0
  84. dstack/_internal/utils/common.py +15 -0
  85. dstack/_internal/utils/cron.py +5 -0
  86. dstack/api/server/__init__.py +1 -1
  87. dstack/version.py +1 -1
  88. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/METADATA +13 -22
  89. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/RECORD +93 -75
  90. /dstack/_internal/{core/backends/remote → server/schemas/health}/__init__.py +0 -0
  91. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/WHEEL +0 -0
  92. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/entry_points.txt +0 -0
  93. {dstack-0.19.20.dist-info → dstack-0.19.22.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/server/services/logs/aws.py

@@ -1,5 +1,7 @@
 import itertools
 import operator
+import urllib
+import urllib.parse
 from contextlib import contextmanager
 from datetime import datetime, timedelta, timezone
 from typing import Iterator, List, Optional, Set, Tuple, TypedDict
@@ -64,6 +66,7 @@ class CloudWatchLogStorage(LogStorage):
         self._client = session.client("logs")
         self._check_group_exists(group)
         self._group = group
+        self._region = self._client.meta.region_name
         # Stores names of already created streams.
         # XXX: This set acts as an unbound cache. If this becomes a problem (in case of _very_ long
         # running server and/or lots of jobs, consider replacing it with an LRU cache, e.g.,
@@ -103,7 +106,11 @@ class CloudWatchLogStorage(LogStorage):
             )
             for cw_event in cw_events
         ]
-        return JobSubmissionLogs(logs=logs, next_token=next_token)
+        return JobSubmissionLogs(
+            logs=logs,
+            external_url=self._get_stream_external_url(stream),
+            next_token=next_token,
+        )
 
     def _get_log_events_with_retry(
         self, stream: str, request: PollLogsRequest
@@ -181,6 +188,11 @@ class CloudWatchLogStorage(LogStorage):
 
         return events, next_token
 
+    def _get_stream_external_url(self, stream: str) -> str:
+        quoted_group = urllib.parse.quote(self._group, safe="")
+        quoted_stream = urllib.parse.quote(stream, safe="")
+        return f"https://console.aws.amazon.com/cloudwatch/home?region={self._region}#logsV2:log-groups/log-group/{quoted_group}/log-events/{quoted_stream}"
+
     def write_logs(
         self,
         project: ProjectModel,
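
The new `_get_stream_external_url` helper turns the configured log group and the per-job stream name into a CloudWatch console deep link, returned as `external_url` alongside the log events. A minimal standalone sketch of the same construction (the region, group, and stream values below are made up for illustration):

```python
import urllib.parse


def cloudwatch_stream_url(region: str, group: str, stream: str) -> str:
    # Both path components are percent-encoded with safe="" so that "/" in the
    # group or stream name does not break the console URL fragment.
    quoted_group = urllib.parse.quote(group, safe="")
    quoted_stream = urllib.parse.quote(stream, safe="")
    return (
        "https://console.aws.amazon.com/cloudwatch/home"
        f"?region={region}"
        f"#logsV2:log-groups/log-group/{quoted_group}/log-events/{quoted_stream}"
    )


# Hypothetical values, for illustration only.
print(cloudwatch_stream_url("us-east-1", "/dstack/jobs", "my-project/run-1/runner"))
```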
dstack/_internal/server/services/logs/gcp.py

@@ -1,3 +1,4 @@
+import urllib.parse
 from typing import List
 from uuid import UUID
 
@@ -48,6 +49,7 @@ class GCPLogStorage(LogStorage):
     # (https://cloud.google.com/logging/docs/analyze/custom-index).
 
     def __init__(self, project_id: str):
+        self.project_id = project_id
        try:
            self.client = logging_v2.Client(project=project_id)
            self.logger = self.client.logger(name=self.LOG_NAME)
@@ -106,7 +108,11 @@ class GCPLogStorage(LogStorage):
                "GCP Logging read request limit exceeded."
                " It's recommended to increase default entries.list request quota from 60 per minute."
            )
-        return JobSubmissionLogs(logs=logs, next_token=next_token if len(logs) > 0 else None)
+        return JobSubmissionLogs(
+            logs=logs,
+            external_url=self._get_stream_extrnal_url(stream_name),
+            next_token=next_token if len(logs) > 0 else None,
+        )
 
     def write_logs(
         self,
@@ -162,3 +168,12 @@ class GCPLogStorage(LogStorage):
         self, project_name: str, run_name: str, job_submission_id: UUID, producer: LogProducer
     ) -> str:
         return f"{project_name}-{run_name}-{job_submission_id}-{producer.value}"
+
+    def _get_stream_extrnal_url(self, stream_name: str) -> str:
+        log_name_resource_name = self._get_log_name_resource_name()
+        query = f'logName="{log_name_resource_name}" AND labels.stream="{stream_name}"'
+        quoted_query = urllib.parse.quote(query, safe="")
+        return f"https://console.cloud.google.com/logs/query;query={quoted_query}?project={self.project_id}"
+
+    def _get_log_name_resource_name(self) -> str:
+        return f"projects/{self.project_id}/logs/{self.LOG_NAME}"
dstack/_internal/server/services/offers.py

@@ -2,13 +2,13 @@ from typing import List, Literal, Optional, Tuple, Union
 
 import gpuhunt
 
-from dstack._internal.core.backends import (
+from dstack._internal.core.backends.base.backend import Backend
+from dstack._internal.core.backends.base.compute import ComputeWithPlacementGroupSupport
+from dstack._internal.core.backends.features import (
     BACKENDS_WITH_CREATE_INSTANCE_SUPPORT,
     BACKENDS_WITH_MULTINODE_SUPPORT,
     BACKENDS_WITH_RESERVATION_SUPPORT,
 )
-from dstack._internal.core.backends.base.backend import Backend
-from dstack._internal.core.backends.base.compute import ComputeWithPlacementGroupSupport
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.instances import (
     InstanceOfferWithAvailability,
dstack/_internal/server/services/probes.py

@@ -0,0 +1,6 @@
+from dstack._internal.core.models.runs import Probe
+from dstack._internal.server.models import ProbeModel
+
+
+def probe_model_to_probe(probe_model: ProbeModel) -> Probe:
+    return Probe(success_streak=probe_model.success_streak)
dstack/_internal/server/services/projects.py

@@ -1,11 +1,10 @@
 import uuid
-from datetime import timezone
 from typing import Awaitable, Callable, List, Optional, Tuple
 
 from sqlalchemy import delete, select, update
 from sqlalchemy import func as safunc
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import QueryableAttribute, joinedload, load_only
 
 from dstack._internal.core.backends.configurators import get_configurator
 from dstack._internal.core.backends.dstack.models import (
@@ -54,13 +53,12 @@ async def list_user_projects(
     user: UserModel,
 ) -> List[Project]:
     """
-    Returns projects where the user is a member.
+    Returns projects where the user is a member or all projects for global admins.
     """
-    if user.global_role == GlobalRole.ADMIN:
-        projects = await list_project_models(session=session)
-    else:
-        projects = await list_user_project_models(session=session, user=user)
-
+    projects = await list_user_project_models(
+        session=session,
+        user=user,
+    )
     projects = sorted(projects, key=lambda p: p.created_at)
     return [
         project_model_to_project(p, include_backends=False, include_members=False)
@@ -80,7 +78,7 @@ async def list_user_accessible_projects(
     if user.global_role == GlobalRole.ADMIN:
         projects = await list_project_models(session=session)
     else:
-        member_projects = await list_user_project_models(session=session, user=user)
+        member_projects = await list_member_project_models(session=session, user=user)
         public_projects = await list_public_non_member_project_models(session=session, user=user)
         projects = member_projects + public_projects
 
@@ -167,7 +165,7 @@ async def delete_projects(
     projects_names: List[str],
 ):
     if user.global_role != GlobalRole.ADMIN:
-        user_projects = await list_user_project_models(
+        user_projects = await list_member_project_models(
            session=session, user=user, include_members=True
        )
        user_project_names = [p.name for p in user_projects]
@@ -199,6 +197,10 @@ async def set_project_members(
     project: ProjectModel,
     members: List[MemberSetting],
 ):
+    usernames = {m.username for m in members}
+    if len(usernames) != len(members):
+        raise ServerClientError("Cannot add same user multiple times")
+
     project = await get_project_model_by_name_or_error(
         session=session,
         project_name=project.name,
@@ -247,6 +249,10 @@ async def add_project_members(
     members: List[MemberSetting],
 ):
     """Add multiple members to a project."""
+    usernames = {m.username for m in members}
+    if len(usernames) != len(members):
+        raise ServerClientError("Cannot add same user multiple times")
+
     project = await get_project_model_by_name_or_error(
         session=session,
         project_name=project.name,
@@ -261,7 +267,10 @@ async def add_project_members(
     )
 
     if not is_self_join_to_public:
-        if requesting_user_role not in [ProjectRole.ADMIN, ProjectRole.MANAGER]:
+        if user.global_role != GlobalRole.ADMIN and requesting_user_role not in [
+            ProjectRole.ADMIN,
+            ProjectRole.MANAGER,
+        ]:
             raise ForbiddenError("Access denied: insufficient permissions to add members")
 
     if user.global_role != GlobalRole.ADMIN and requesting_user_role == ProjectRole.MANAGER:
@@ -274,8 +283,6 @@ async def add_project_members(
         if members[0].project_role != ProjectRole.USER:
             raise ForbiddenError("Access denied: can only join public projects as user role")
 
-    usernames = [member.username for member in members]
-
     res = await session.execute(
         select(UserModel).where((UserModel.name.in_(usernames)) | (UserModel.email.in_(usernames)))
     )
@@ -339,9 +346,25 @@ async def clear_project_members(
 
 
 async def list_user_project_models(
+    session: AsyncSession,
+    user: UserModel,
+    only_names: bool = False,
+) -> List[ProjectModel]:
+    load_only_attrs = []
+    if only_names:
+        load_only_attrs += [ProjectModel.id, ProjectModel.name]
+    if user.global_role == GlobalRole.ADMIN:
+        return await list_project_models(session=session, load_only_attrs=load_only_attrs)
+    return await list_member_project_models(
+        session=session, user=user, load_only_attrs=load_only_attrs
+    )
+
+
+async def list_member_project_models(
     session: AsyncSession,
     user: UserModel,
     include_members: bool = False,
+    load_only_attrs: Optional[List[QueryableAttribute]] = None,
 ) -> List[ProjectModel]:
     """
     List project models for a user where they are a member.
@@ -349,6 +372,8 @@ async def list_user_project_models(
     options = []
     if include_members:
         options.append(joinedload(ProjectModel.members))
+    if load_only_attrs:
+        options.append(load_only(*load_only_attrs))
     res = await session.execute(
         select(ProjectModel)
         .where(
@@ -395,13 +420,20 @@ async def list_user_owned_project_models(
 
 
 async def list_project_models(
     session: AsyncSession,
+    load_only_attrs: Optional[List[QueryableAttribute]] = None,
 ) -> List[ProjectModel]:
+    options = []
+    if load_only_attrs:
+        options.append(load_only(*load_only_attrs))
     res = await session.execute(
-        select(ProjectModel).where(ProjectModel.deleted == False),
+        select(ProjectModel).where(ProjectModel.deleted == False).options(*options)
     )
     return list(res.scalars().all())
 
 
+# TODO: Do not load ProjectModel.backends and ProjectModel.members by default when getting project
+
+
 async def get_project_model_by_name(
     session: AsyncSession, project_name: str, ignore_case: bool = True
 ) -> Optional[ProjectModel]:
@@ -415,7 +447,6 @@ async def get_project_model_by_name(
         .where(*filters)
         .options(joinedload(ProjectModel.backends))
         .options(joinedload(ProjectModel.members))
-        .options(joinedload(ProjectModel.default_gateway))
     )
     return res.unique().scalar()
 
@@ -432,7 +463,6 @@ async def get_project_model_by_name_or_error(
         )
         .options(joinedload(ProjectModel.backends))
         .options(joinedload(ProjectModel.members))
-        .options(joinedload(ProjectModel.default_gateway))
     )
     return res.unique().scalar_one()
 
@@ -449,7 +479,6 @@ async def get_project_model_by_id_or_error(
         )
         .options(joinedload(ProjectModel.backends))
         .options(joinedload(ProjectModel.members))
-        .options(joinedload(ProjectModel.default_gateway))
     )
     return res.unique().scalar_one()
 
@@ -537,7 +566,7 @@ def project_model_to_project(
         project_id=project_model.id,
         project_name=project_model.name,
         owner=users.user_model_to_user(project_model.owner),
-        created_at=project_model.created_at.replace(tzinfo=timezone.utc),
+        created_at=project_model.created_at,
         backends=backends,
         members=members,
         is_public=project_model.is_public,
@@ -608,7 +637,10 @@ async def remove_project_members(
     )
 
     if not is_self_leave:
-        if requesting_user_role not in [ProjectRole.ADMIN, ProjectRole.MANAGER]:
+        if user.global_role != GlobalRole.ADMIN and requesting_user_role not in [
+            ProjectRole.ADMIN,
+            ProjectRole.MANAGER,
+        ]:
             raise ForbiddenError("Access denied: insufficient permissions to remove members")
 
     res = await session.execute(
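
The `load_only_attrs` parameter threaded through `list_project_models` and the new `list_member_project_models` lets callers that only need project names avoid loading every mapped column. A self-contained sketch of the same SQLAlchemy `load_only` pattern, using a stand-in model rather than dstack's actual `ProjectModel`:

```python
from typing import List, Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    QueryableAttribute,
    load_only,
    mapped_column,
)


class Base(DeclarativeBase):
    pass


class ProjectModel(Base):
    # Minimal stand-in: only the columns this sketch needs.
    __tablename__ = "projects"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column()
    deleted: Mapped[bool] = mapped_column(default=False)


async def list_project_models(
    session: AsyncSession,
    load_only_attrs: Optional[List[QueryableAttribute]] = None,
) -> List[ProjectModel]:
    # Same shape as the updated helper: a caller that only needs names passes
    # [ProjectModel.id, ProjectModel.name] and the query skips the other columns.
    options = []
    if load_only_attrs:
        options.append(load_only(*load_only_attrs))
    res = await session.execute(
        select(ProjectModel).where(ProjectModel.deleted == False).options(*options)  # noqa: E712
    )
    return list(res.scalars().all())
```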
dstack/_internal/server/services/prometheus/client_metrics.py

@@ -5,6 +5,9 @@ class RunMetrics:
     """Wrapper class for run-related Prometheus metrics."""
 
     def __init__(self):
+        # submit_to_provision_duration reflects real provisioning time
+        # but does not reflect how quickly provisioning processing works
+        # since it includes scheduling time, retrying, etc.
         self._submit_to_provision_duration = Histogram(
             "dstack_submit_to_provision_duration_seconds",
             "Time from when a run has been submitted and first job provisioning",
dstack/_internal/server/services/prometheus/custom_metrics.py

@@ -2,7 +2,6 @@ import itertools
 import json
 from collections import defaultdict
 from collections.abc import Generator, Iterable
-from datetime import timezone
 from typing import ClassVar
 from uuid import UUID
 
@@ -80,7 +79,7 @@ async def get_instance_metrics(session: AsyncSession) -> Iterable[Metric]:
             "dstack_backend": instance.backend.value if instance.backend is not None else "",
             "dstack_gpu": gpu,
         }
-        duration = (now - instance.created_at.replace(tzinfo=timezone.utc)).total_seconds()
+        duration = (now - instance.created_at).total_seconds()
         metrics.add_sample(_INSTANCE_DURATION, labels, duration)
         metrics.add_sample(_INSTANCE_PRICE, labels, instance.price or 0.0)
         metrics.add_sample(_INSTANCE_GPU_COUNT, labels, gpu_count)
@@ -167,7 +166,7 @@ async def get_job_metrics(session: AsyncSession) -> Iterable[Metric]:
             "dstack_backend": jpd.get_base_backend().value,
             "dstack_gpu": gpus[0].name if gpus else "",
         }
-        duration = (now - job.submitted_at.replace(tzinfo=timezone.utc)).total_seconds()
+        duration = (now - job.submitted_at).total_seconds()
         metrics.add_sample(_JOB_DURATION, labels, duration)
         metrics.add_sample(_JOB_PRICE, labels, price)
         metrics.add_sample(_JOB_GPU_COUNT, labels, len(gpus))
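
Dropping `.replace(tzinfo=timezone.utc)` here (and on `created_at` in `projects.py` above) suggests the model columns now come back timezone-aware, so the patch-up is no longer needed. A small illustration of the naive-versus-aware distinction the old code worked around:

```python
from datetime import datetime, timezone

now = datetime.now(timezone.utc)          # timezone-aware
stored = datetime(2025, 1, 1, 12, 0, 0)   # naive, as some DB drivers return timestamps

# Subtracting a naive datetime from an aware one raises TypeError, hence the old
# .replace(tzinfo=timezone.utc) patch-up before computing durations.
try:
    now - stored
except TypeError as e:
    print(f"naive vs aware: {e}")

# Once the stored value is already aware, plain subtraction works.
duration = (now - stored.replace(tzinfo=timezone.utc)).total_seconds()
print(duration)
```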
dstack/_internal/server/services/runner/client.py

@@ -1,7 +1,6 @@
 import uuid
-from dataclasses import dataclass
 from http import HTTPStatus
-from typing import BinaryIO, Dict, List, Optional, TypeVar, Union
+from typing import BinaryIO, Dict, List, Literal, Optional, TypeVar, Union, overload
 
 import packaging.version
 import requests
@@ -14,9 +13,11 @@ from dstack._internal.core.models.repos.remote import RemoteRepoCreds
 from dstack._internal.core.models.resources import Memory
 from dstack._internal.core.models.runs import ClusterInfo, Job, Run
 from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint
+from dstack._internal.server.schemas.instances import InstanceCheck
 from dstack._internal.server.schemas.runner import (
     GPUDevice,
     HealthcheckResponse,
+    InstanceHealthResponse,
     LegacyPullResponse,
     LegacyStopBody,
     LegacySubmitBody,
@@ -37,15 +38,6 @@ UPLOAD_CODE_REQUEST_TIMEOUT = 60
 logger = get_logger(__name__)
 
 
-@dataclass
-class HealthStatus:
-    healthy: bool
-    reason: str
-
-    def __str__(self) -> str:
-        return self.reason
-
-
 class RunnerClient:
     def __init__(
         self,
@@ -193,6 +185,9 @@ class ShimClient:
     # API v1 (a.k.a. Legacy API) — `/api/{submit,pull,stop}`
     _API_V2_MIN_SHIM_VERSION = (0, 18, 34)
 
+    # `/api/instance/health`
+    _INSTANCE_HEALTH_MIN_SHIM_VERSION = (0, 19, 22)
+
     _shim_version: Optional["_Version"]
     _api_version: int
     _negotiated: bool = False
@@ -212,11 +207,25 @@ class ShimClient:
             self._negotiate()
         return self._api_version == 2
 
-    def healthcheck(self, unmask_exeptions: bool = False) -> Optional[HealthcheckResponse]:
+    def is_instance_health_supported(self) -> bool:
+        if not self._negotiated:
+            self._negotiate()
+        return (
+            self._shim_version is None
+            or self._shim_version >= self._INSTANCE_HEALTH_MIN_SHIM_VERSION
+        )
+
+    @overload
+    def healthcheck(self) -> Optional[HealthcheckResponse]: ...
+
+    @overload
+    def healthcheck(self, unmask_exceptions: Literal[True]) -> HealthcheckResponse: ...
+
+    def healthcheck(self, unmask_exceptions: bool = False) -> Optional[HealthcheckResponse]:
         try:
             resp = self._request("GET", "/api/healthcheck", raise_for_status=True)
         except requests.exceptions.RequestException:
-            if unmask_exeptions:
+            if unmask_exceptions:
                 raise
             return None
         if not self._negotiated:
@@ -225,6 +234,17 @@ class ShimClient:
 
     # API v2 methods
 
+    def get_instance_health(self) -> Optional[InstanceHealthResponse]:
+        if not self.is_instance_health_supported():
+            logger.debug("instance health is not supported: %s", self._shim_version)
+            return None
+        resp = self._request("GET", "/api/instance/health")
+        if resp.status_code == HTTPStatus.NOT_FOUND:
+            logger.warning("instance health: %s", resp.text)
+            return None
+        self._raise_for_status(resp)
+        return self._response(InstanceHealthResponse, resp)
+
     def get_task(self, task_id: "_TaskID") -> TaskInfoResponse:
         if not self.is_api_v2_supported():
             raise ShimAPIVersionError()
@@ -418,14 +438,26 @@ class ShimClient:
         self._negotiated = True
 
 
-def health_response_to_health_status(data: HealthcheckResponse) -> HealthStatus:
-    if data.service == "dstack-shim":
-        return HealthStatus(healthy=True, reason="Service is OK")
-    else:
-        return HealthStatus(
-            healthy=False,
-            reason=f"Service name is {data.service}, service version: {data.version}",
+def healthcheck_response_to_instance_check(
+    response: HealthcheckResponse,
+    instance_health_response: Optional[InstanceHealthResponse] = None,
+) -> InstanceCheck:
+    if response.service == "dstack-shim":
+        message: Optional[str] = None
+        if (
+            instance_health_response is not None
+            and instance_health_response.dcgm is not None
+            and instance_health_response.dcgm.incidents
+        ):
+            message = instance_health_response.dcgm.incidents[0].error_message
+        return InstanceCheck(
+            reachable=True, health_response=instance_health_response, message=message
        )
+    return InstanceCheck(
+        reachable=False,
+        message=f"unexpected service: {response.service} version: {response.version}",
+        health_response=instance_health_response,
+    )
 
 
 def _volume_to_shim_volume_info(volume: Volume, instance_id: str) -> ShimVolumeInfo:
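
The `@overload` signatures replacing the old boolean-only `healthcheck` let type checkers know that calling it with `unmask_exceptions=True` never yields `None`: errors propagate as exceptions instead of being swallowed. A minimal sketch of the pattern outside `ShimClient` (the `Response` class and the function body are stand-ins, not dstack's API):

```python
from typing import Literal, Optional, overload


class Response:
    # Stand-in for a response model such as HealthcheckResponse.
    service = "dstack-shim"
    version = "0.19.22"


@overload
def healthcheck() -> Optional[Response]: ...


@overload
def healthcheck(unmask_exceptions: Literal[True]) -> Response: ...


def healthcheck(unmask_exceptions: bool = False) -> Optional[Response]:
    # The implementation still returns Optional[Response]; the overloads encode the
    # guarantee that with unmask_exceptions=True a failed request raises rather than
    # returning None, so a returned value is always a Response.
    try:
        resp = Response()  # stand-in for an HTTP request that may raise
    except OSError:
        if unmask_exceptions:
            raise
        return None
    return resp


checked: Response = healthcheck(True)       # type checker: never None
maybe: Optional[Response] = healthcheck()   # may be None if the request failed
```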
dstack/_internal/server/services/runner/ssh.py

@@ -2,7 +2,7 @@ import functools
 import socket
 import time
 from collections.abc import Iterable
-from typing import Callable, Dict, List, Optional, TypeVar, Union
+from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union
 
 import requests
 from typing_extensions import Concatenate, ParamSpec
@@ -27,7 +27,7 @@ def runner_ssh_tunnel(
     [Callable[Concatenate[Dict[int, int], P], R]],
     Callable[
         Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P],
-        Union[bool, R],
+        Union[Literal[False], R],
     ],
 ]:
     """
@@ -42,7 +42,7 @@ def runner_ssh_tunnel(
         func: Callable[Concatenate[Dict[int, int], P], R],
     ) -> Callable[
         Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P],
-        Union[bool, R],
+        Union[Literal[False], R],
     ]:
         @functools.wraps(func)
         def wrapper(
@@ -51,7 +51,7 @@ def runner_ssh_tunnel(
            job_runtime_data: Optional[JobRuntimeData],
            *args: P.args,
            **kwargs: P.kwargs,
-        ) -> Union[bool, R]:
+        ) -> Union[Literal[False], R]:
            """
            Returns:
                is successful