dstack 0.19.34__py3-none-any.whl → 0.19.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release has been flagged as potentially problematic; see the registry's advisory page for details.

Files changed (41):
  1. dstack/_internal/cli/services/configurators/run.py +1 -1
  2. dstack/_internal/core/backends/base/compute.py +20 -1
  3. dstack/_internal/core/backends/base/models.py +10 -0
  4. dstack/_internal/core/backends/base/offers.py +1 -0
  5. dstack/_internal/core/backends/features.py +5 -0
  6. dstack/_internal/core/backends/nebius/compute.py +28 -16
  7. dstack/_internal/core/backends/nebius/configurator.py +1 -1
  8. dstack/_internal/core/backends/nebius/models.py +4 -0
  9. dstack/_internal/core/backends/nebius/resources.py +41 -20
  10. dstack/_internal/core/backends/runpod/api_client.py +245 -59
  11. dstack/_internal/core/backends/runpod/compute.py +157 -13
  12. dstack/_internal/core/models/compute_groups.py +39 -0
  13. dstack/_internal/core/models/fleets.py +6 -1
  14. dstack/_internal/core/models/profiles.py +3 -1
  15. dstack/_internal/core/models/runs.py +3 -0
  16. dstack/_internal/server/app.py +14 -2
  17. dstack/_internal/server/background/__init__.py +7 -0
  18. dstack/_internal/server/background/tasks/process_compute_groups.py +164 -0
  19. dstack/_internal/server/background/tasks/process_instances.py +81 -49
  20. dstack/_internal/server/background/tasks/process_submitted_jobs.py +179 -84
  21. dstack/_internal/server/migrations/env.py +20 -2
  22. dstack/_internal/server/migrations/versions/7d1ec2b920ac_add_computegroupmodel.py +93 -0
  23. dstack/_internal/server/models.py +39 -0
  24. dstack/_internal/server/routers/runs.py +15 -6
  25. dstack/_internal/server/services/compute_groups.py +22 -0
  26. dstack/_internal/server/services/fleets.py +1 -0
  27. dstack/_internal/server/services/jobs/__init__.py +13 -0
  28. dstack/_internal/server/services/jobs/configurators/base.py +3 -2
  29. dstack/_internal/server/services/requirements/combine.py +1 -0
  30. dstack/_internal/server/services/runs.py +17 -3
  31. dstack/_internal/server/testing/common.py +51 -0
  32. dstack/_internal/server/utils/routers.py +18 -20
  33. dstack/_internal/settings.py +4 -1
  34. dstack/_internal/utils/version.py +22 -0
  35. dstack/version.py +1 -1
  36. {dstack-0.19.34.dist-info → dstack-0.19.35.dist-info}/METADATA +3 -3
  37. {dstack-0.19.34.dist-info → dstack-0.19.35.dist-info}/RECORD +40 -36
  38. dstack/_internal/core/backends/nebius/fabrics.py +0 -49
  39. {dstack-0.19.34.dist-info → dstack-0.19.35.dist-info}/WHEEL +0 -0
  40. {dstack-0.19.34.dist-info → dstack-0.19.35.dist-info}/entry_points.txt +0 -0
  41. {dstack-0.19.34.dist-info → dstack-0.19.35.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,22 @@
1
+ from dstack._internal.core.models.compute_groups import ComputeGroup, ComputeGroupProvisioningData
2
+ from dstack._internal.server.models import ComputeGroupModel
3
+
4
+
5
+ def compute_group_model_to_compute_group(compute_group_model: ComputeGroupModel) -> ComputeGroup:
6
+ provisioning_data = get_compute_group_provisioning_data(compute_group_model)
7
+ return ComputeGroup(
8
+ id=compute_group_model.id,
9
+ project_name=compute_group_model.project.name,
10
+ status=compute_group_model.status,
11
+ name=provisioning_data.compute_group_name,
12
+ created_at=compute_group_model.created_at,
13
+ provisioning_data=provisioning_data,
14
+ )
15
+
16
+
17
+ def get_compute_group_provisioning_data(
18
+ compute_group_model: ComputeGroupModel,
19
+ ) -> ComputeGroupProvisioningData:
20
+ return ComputeGroupProvisioningData.__response__.parse_raw(
21
+ compute_group_model.provisioning_data
22
+ )
@@ -650,6 +650,7 @@ def get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements:
650
650
  max_price=profile.max_price,
651
651
  spot=get_policy_map(profile.spot_policy, default=SpotPolicy.ONDEMAND),
652
652
  reservation=fleet_spec.configuration.reservation,
653
+ multinode=fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER,
653
654
  )
654
655
  return requirements
655
656
 
@@ -96,6 +96,19 @@ def find_job(jobs: List[Job], replica_num: int, job_num: int) -> Job:
96
96
  )
97
97
 
98
98
 
99
+ def find_jobs(
100
+ jobs: List[Job],
101
+ replica_num: Optional[int] = None,
102
+ job_num: Optional[int] = None,
103
+ ) -> list[Job]:
104
+ res = jobs
105
+ if replica_num is not None:
106
+ res = [j for j in res if j.job_spec.replica_num == replica_num]
107
+ if job_num is not None:
108
+ res = [j for j in res if j.job_spec.job_num == job_num]
109
+ return res
110
+
111
+
99
112
  async def get_run_job_model(
100
113
  session: AsyncSession,
101
114
  project: ProjectModel,
@@ -161,7 +161,7 @@ class JobConfigurator(ABC):
161
161
  stop_duration=self._stop_duration(),
162
162
  utilization_policy=self._utilization_policy(),
163
163
  registry_auth=self._registry_auth(),
164
- requirements=self._requirements(),
164
+ requirements=self._requirements(jobs_per_replica),
165
165
  retry=self._retry(),
166
166
  working_dir=self._working_dir(),
167
167
  volumes=self._volumes(job_num),
@@ -295,13 +295,14 @@ class JobConfigurator(ABC):
295
295
  def _registry_auth(self) -> Optional[RegistryAuth]:
296
296
  return self.run_spec.configuration.registry_auth
297
297
 
298
- def _requirements(self) -> Requirements:
298
+ def _requirements(self, jobs_per_replica: int) -> Requirements:
299
299
  spot_policy = self._spot_policy()
300
300
  return Requirements(
301
301
  resources=self.run_spec.configuration.resources,
302
302
  max_price=self.run_spec.merged_profile.max_price,
303
303
  spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT),
304
304
  reservation=self.run_spec.merged_profile.reservation,
305
+ multinode=jobs_per_replica > 1,
305
306
  )
306
307
 
307
308
  def _retry(self) -> Optional[Retry]:
@@ -63,6 +63,7 @@ def combine_fleet_and_run_requirements(
63
63
  reservation=_get_single_value_optional(
64
64
  fleet_requirements.reservation, run_requirements.reservation
65
65
  ),
66
+ multinode=fleet_requirements.multinode or run_requirements.multinode,
66
67
  )
67
68
  except CombineError:
68
69
  return None
@@ -34,6 +34,7 @@ from dstack._internal.core.models.profiles import (
34
34
  )
35
35
  from dstack._internal.core.models.repos.virtual import DEFAULT_VIRTUAL_REPO_ID, VirtualRunRepoData
36
36
  from dstack._internal.core.models.runs import (
37
+ LEGACY_REPO_DIR,
37
38
  ApplyRunPlanInput,
38
39
  Job,
39
40
  JobPlan,
@@ -308,6 +309,7 @@ async def get_plan(
308
309
  user: UserModel,
309
310
  run_spec: RunSpec,
310
311
  max_offers: Optional[int],
312
+ legacy_default_working_dir: bool = False,
311
313
  ) -> RunPlan:
312
314
  # Spec must be copied by parsing to calculate merged_profile
313
315
  effective_run_spec = RunSpec.parse_obj(run_spec.dict())
@@ -317,7 +319,11 @@ async def get_plan(
317
319
  spec=effective_run_spec,
318
320
  )
319
321
  effective_run_spec = RunSpec.parse_obj(effective_run_spec.dict())
320
- _validate_run_spec_and_set_defaults(user, effective_run_spec)
322
+ _validate_run_spec_and_set_defaults(
323
+ user=user,
324
+ run_spec=effective_run_spec,
325
+ legacy_default_working_dir=legacy_default_working_dir,
326
+ )
321
327
 
322
328
  profile = effective_run_spec.merged_profile
323
329
  creation_policy = profile.creation_policy
@@ -413,6 +419,7 @@ async def apply_plan(
413
419
  project: ProjectModel,
414
420
  plan: ApplyRunPlanInput,
415
421
  force: bool,
422
+ legacy_default_working_dir: bool = False,
416
423
  ) -> Run:
417
424
  run_spec = plan.run_spec
418
425
  run_spec = await apply_plugin_policies(
@@ -422,7 +429,9 @@ async def apply_plan(
422
429
  )
423
430
  # Spec must be copied by parsing to calculate merged_profile
424
431
  run_spec = RunSpec.parse_obj(run_spec.dict())
425
- _validate_run_spec_and_set_defaults(user, run_spec)
432
+ _validate_run_spec_and_set_defaults(
433
+ user=user, run_spec=run_spec, legacy_default_working_dir=legacy_default_working_dir
434
+ )
426
435
  if run_spec.run_name is None:
427
436
  return await submit_run(
428
437
  session=session,
@@ -600,6 +609,7 @@ def create_job_model_for_new_submission(
600
609
  job_spec_data=job.job_spec.json(),
601
610
  job_provisioning_data=None,
602
611
  probes=[],
612
+ waiting_master_job=job.job_spec.job_num != 0,
603
613
  )
604
614
 
605
615
 
@@ -985,7 +995,9 @@ def _get_job_submission_cost(job_submission: JobSubmission) -> float:
985
995
  return job_submission.job_provisioning_data.price * duration_hours
986
996
 
987
997
 
988
- def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec):
998
+ def _validate_run_spec_and_set_defaults(
999
+ user: UserModel, run_spec: RunSpec, legacy_default_working_dir: bool = False
1000
+ ):
989
1001
  # This function may set defaults for null run_spec values,
990
1002
  # although most defaults are resolved when building job_spec
991
1003
  # so that we can keep both the original user-supplied value (null in run_spec)
@@ -1040,6 +1052,8 @@ def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec):
1040
1052
  run_spec.ssh_key_pub = user.ssh_public_key
1041
1053
  else:
1042
1054
  raise ServerClientError("ssh_key_pub must be set if the user has no ssh_public_key")
1055
+ if run_spec.configuration.working_dir is None and legacy_default_working_dir:
1056
+ run_spec.configuration.working_dir = LEGACY_REPO_DIR
1043
1057
 
1044
1058
 
1045
1059
  _UPDATABLE_SPEC_FIELDS = ["configuration_path", "configuration"]
@@ -13,6 +13,7 @@ from dstack._internal.core.backends.base.compute import (
13
13
  Compute,
14
14
  ComputeWithCreateInstanceSupport,
15
15
  ComputeWithGatewaySupport,
16
+ ComputeWithGroupProvisioningSupport,
16
17
  ComputeWithMultinodeSupport,
17
18
  ComputeWithPlacementGroupSupport,
18
19
  ComputeWithPrivateGatewaySupport,
@@ -22,6 +23,10 @@ from dstack._internal.core.backends.base.compute import (
22
23
  )
23
24
  from dstack._internal.core.models.backends.base import BackendType
24
25
  from dstack._internal.core.models.common import NetworkMode
26
+ from dstack._internal.core.models.compute_groups import (
27
+ ComputeGroupProvisioningData,
28
+ ComputeGroupStatus,
29
+ )
25
30
  from dstack._internal.core.models.configurations import (
26
31
  AnyRunConfiguration,
27
32
  DevEnvironmentConfiguration,
@@ -83,6 +88,7 @@ from dstack._internal.core.models.volumes import (
83
88
  )
84
89
  from dstack._internal.server.models import (
85
90
  BackendModel,
91
+ ComputeGroupModel,
86
92
  DecryptedString,
87
93
  FileArchiveModel,
88
94
  FleetModel,
@@ -353,6 +359,7 @@ async def create_job(
353
359
  instance_assigned: bool = False,
354
360
  disconnected_at: Optional[datetime] = None,
355
361
  registered: bool = False,
362
+ waiting_master_job: Optional[bool] = None,
356
363
  ) -> JobModel:
357
364
  if deployment_num is None:
358
365
  deployment_num = run.deployment_num
@@ -384,6 +391,7 @@ async def create_job(
384
391
  disconnected_at=disconnected_at,
385
392
  probes=[],
386
393
  registered=registered,
394
+ waiting_master_job=waiting_master_job,
387
395
  )
388
396
  session.add(job)
389
397
  await session.commit()
@@ -455,6 +463,48 @@ def get_job_runtime_data(
455
463
  )
456
464
 
457
465
 
466
+ def get_compute_group_provisioning_data(
467
+ compute_group_id: str = "test_compute_group",
468
+ compute_group_name: str = "test_compute_group",
469
+ backend: BackendType = BackendType.RUNPOD,
470
+ region: str = "US",
471
+ job_provisioning_datas: Optional[list[JobProvisioningData]] = None,
472
+ backend_data: Optional[str] = None,
473
+ ) -> ComputeGroupProvisioningData:
474
+ if job_provisioning_datas is None:
475
+ job_provisioning_datas = []
476
+ return ComputeGroupProvisioningData(
477
+ compute_group_id=compute_group_id,
478
+ compute_group_name=compute_group_name,
479
+ backend=backend,
480
+ region=region,
481
+ job_provisioning_datas=job_provisioning_datas,
482
+ backend_data=backend_data,
483
+ )
484
+
485
+
486
+ async def create_compute_group(
487
+ session: AsyncSession,
488
+ project: ProjectModel,
489
+ fleet: FleetModel,
490
+ status: ComputeGroupStatus = ComputeGroupStatus.RUNNING,
491
+ provisioning_data: Optional[ComputeGroupProvisioningData] = None,
492
+ last_processed_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
493
+ ):
494
+ if provisioning_data is None:
495
+ provisioning_data = get_compute_group_provisioning_data()
496
+ compute_group = ComputeGroupModel(
497
+ project=project,
498
+ fleet=fleet,
499
+ status=status,
500
+ provisioning_data=provisioning_data.json(),
501
+ last_processed_at=last_processed_at,
502
+ )
503
+ session.add(compute_group)
504
+ await session.commit()
505
+ return compute_group
506
+
507
+
458
508
  async def create_probe(
459
509
  session: AsyncSession,
460
510
  job: JobModel,
@@ -1136,6 +1186,7 @@ class AsyncContextManager:
1136
1186
  class ComputeMockSpec(
1137
1187
  Compute,
1138
1188
  ComputeWithCreateInstanceSupport,
1189
+ ComputeWithGroupProvisioningSupport,
1139
1190
  ComputeWithPrivilegedSupport,
1140
1191
  ComputeWithMultinodeSupport,
1141
1192
  ComputeWithReservationSupport,
@@ -1,12 +1,13 @@
1
1
  from typing import Any, Dict, List, Optional
2
2
 
3
3
  import orjson
4
+ import packaging.version
4
5
  from fastapi import HTTPException, Request, Response, status
5
- from packaging import version
6
6
 
7
7
  from dstack._internal.core.errors import ServerClientError, ServerClientErrorCode
8
8
  from dstack._internal.core.models.common import CoreModel
9
9
  from dstack._internal.utils.json_utils import get_orjson_default_options, orjson_default
10
+ from dstack._internal.utils.version import parse_version
10
11
 
11
12
 
12
13
  class CustomORJSONResponse(Response):
@@ -122,8 +123,15 @@ def get_request_size(request: Request) -> int:
122
123
  return int(request.headers["content-length"])
123
124
 
124
125
 
126
+ def get_client_version(request: Request) -> Optional[packaging.version.Version]:
127
+ version = request.headers.get("x-api-version")
128
+ if version is None:
129
+ return None
130
+ return parse_version(version)
131
+
132
+
125
133
  def check_client_server_compatibility(
126
- client_version: Optional[str],
134
+ client_version: Optional[packaging.version.Version],
127
135
  server_version: Optional[str],
128
136
  ) -> Optional[CustomORJSONResponse]:
129
137
  """
@@ -132,28 +140,18 @@ def check_client_server_compatibility(
132
140
  """
133
141
  if client_version is None or server_version is None:
134
142
  return None
135
- parsed_server_version = version.parse(server_version)
136
- # latest allows client to bypass compatibility check (e.g. frontend)
137
- if client_version == "latest":
143
+ parsed_server_version = parse_version(server_version)
144
+ if parsed_server_version is None:
138
145
  return None
139
- try:
140
- parsed_client_version = version.parse(client_version)
141
- except version.InvalidVersion:
142
- return CustomORJSONResponse(
143
- status_code=status.HTTP_400_BAD_REQUEST,
144
- content={
145
- "detail": get_server_client_error_details(
146
- ServerClientError("Bad API version specified")
147
- )
148
- },
149
- )
150
146
  # We preserve full client backward compatibility across patch releases.
151
147
  # Server is always partially backward-compatible (so no check).
152
- if parsed_client_version > parsed_server_version and (
153
- parsed_client_version.major > parsed_server_version.major
154
- or parsed_client_version.minor > parsed_server_version.minor
148
+ if client_version > parsed_server_version and (
149
+ client_version.major > parsed_server_version.major
150
+ or client_version.minor > parsed_server_version.minor
155
151
  ):
156
- return error_incompatible_versions(client_version, server_version, ask_cli_update=False)
152
+ return error_incompatible_versions(
153
+ str(client_version), server_version, ask_cli_update=False
154
+ )
157
155
  return None
158
156
 
159
157
 
@@ -1,9 +1,10 @@
1
1
  import os
2
2
 
3
3
  from dstack import version
4
+ from dstack._internal.utils.version import parse_version
4
5
 
5
6
  DSTACK_VERSION = os.getenv("DSTACK_VERSION", version.__version__)
6
- if DSTACK_VERSION == "0.0.0":
7
+ if parse_version(DSTACK_VERSION) is None:
7
8
  # The build backend (hatching) requires not None for versions,
8
9
  # but the code currently treats None as dev version.
9
10
  # TODO: update the code to treat 0.0.0 as dev version.
@@ -33,3 +34,5 @@ class FeatureFlags:
33
34
  large features. This class may be empty if there are no such features in
34
35
  development. Feature flags are environment variables of the form DSTACK_FF_*
35
36
  """
37
+
38
+ pass
@@ -0,0 +1,22 @@
1
+ from typing import Optional
2
+
3
+ import packaging.version
4
+
5
+
6
+ def parse_version(version_string: str) -> Optional[packaging.version.Version]:
7
+ """
8
+ Returns a `packaging.version.Version` instance or `None` if the version is dev/latest.
9
+
10
+ Values parsed as the dev/latest version:
11
+ * the "latest" literal
12
+ * any "0.0.0" release, e.g., "0.0.0", "0.0.0a1", "0.0.0.dev0"
13
+ """
14
+ if version_string == "latest":
15
+ return None
16
+ try:
17
+ version = packaging.version.parse(version_string)
18
+ except packaging.version.InvalidVersion as e:
19
+ raise ValueError(f"Invalid version: {version_string}") from e
20
+ if version.release == (0, 0, 0):
21
+ return None
22
+ return version
dstack/version.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.19.34"
1
+ __version__ = "0.19.35"
2
2
  __is_release__ = True
3
3
  base_image = "0.11"
4
4
  base_image_ubuntu_version = "22.04"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dstack
3
- Version: 0.19.34
3
+ Version: 0.19.35
4
4
  Summary: dstack is an open-source orchestration engine for running AI workloads on any cloud or on-premises.
5
5
  Project-URL: Homepage, https://dstack.ai
6
6
  Project-URL: Source, https://github.com/dstackai/dstack
@@ -73,7 +73,7 @@ Requires-Dist: grpcio>=1.50; extra == 'all'
73
73
  Requires-Dist: httpx; extra == 'all'
74
74
  Requires-Dist: jinja2; extra == 'all'
75
75
  Requires-Dist: kubernetes; extra == 'all'
76
- Requires-Dist: nebius<=0.2.72,>=0.2.40; (python_version >= '3.10') and extra == 'all'
76
+ Requires-Dist: nebius<0.4,>=0.3.4; (python_version >= '3.10') and extra == 'all'
77
77
  Requires-Dist: oci>=2.150.0; extra == 'all'
78
78
  Requires-Dist: prometheus-client; extra == 'all'
79
79
  Requires-Dist: pyopenssl>=23.2.0; extra == 'all'
@@ -259,7 +259,7 @@ Requires-Dist: fastapi; extra == 'nebius'
259
259
  Requires-Dist: grpcio>=1.50; extra == 'nebius'
260
260
  Requires-Dist: httpx; extra == 'nebius'
261
261
  Requires-Dist: jinja2; extra == 'nebius'
262
- Requires-Dist: nebius<=0.2.72,>=0.2.40; (python_version >= '3.10') and extra == 'nebius'
262
+ Requires-Dist: nebius<0.4,>=0.3.4; (python_version >= '3.10') and extra == 'nebius'
263
263
  Requires-Dist: prometheus-client; extra == 'nebius'
264
264
  Requires-Dist: python-dxf==12.1.0; extra == 'nebius'
265
265
  Requires-Dist: python-json-logger>=3.1.0; extra == 'nebius'