dstack 0.19.6rc1__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (69)
  1. dstack/_internal/cli/services/args.py +2 -2
  2. dstack/_internal/cli/services/configurators/fleet.py +3 -2
  3. dstack/_internal/cli/services/configurators/run.py +50 -4
  4. dstack/_internal/cli/utils/fleet.py +3 -1
  5. dstack/_internal/cli/utils/run.py +25 -28
  6. dstack/_internal/core/backends/aws/compute.py +13 -1
  7. dstack/_internal/core/backends/azure/compute.py +42 -13
  8. dstack/_internal/core/backends/azure/configurator.py +21 -0
  9. dstack/_internal/core/backends/azure/models.py +9 -0
  10. dstack/_internal/core/backends/base/compute.py +101 -27
  11. dstack/_internal/core/backends/base/offers.py +13 -3
  12. dstack/_internal/core/backends/cudo/compute.py +2 -0
  13. dstack/_internal/core/backends/datacrunch/compute.py +2 -0
  14. dstack/_internal/core/backends/gcp/auth.py +1 -1
  15. dstack/_internal/core/backends/gcp/compute.py +51 -35
  16. dstack/_internal/core/backends/gcp/resources.py +6 -1
  17. dstack/_internal/core/backends/lambdalabs/compute.py +20 -8
  18. dstack/_internal/core/backends/local/compute.py +2 -0
  19. dstack/_internal/core/backends/nebius/compute.py +95 -1
  20. dstack/_internal/core/backends/nebius/configurator.py +11 -0
  21. dstack/_internal/core/backends/nebius/fabrics.py +47 -0
  22. dstack/_internal/core/backends/nebius/models.py +8 -0
  23. dstack/_internal/core/backends/nebius/resources.py +29 -0
  24. dstack/_internal/core/backends/oci/compute.py +2 -0
  25. dstack/_internal/core/backends/remote/provisioning.py +27 -2
  26. dstack/_internal/core/backends/template/compute.py.jinja +2 -0
  27. dstack/_internal/core/backends/tensordock/compute.py +2 -0
  28. dstack/_internal/core/backends/vastai/compute.py +2 -1
  29. dstack/_internal/core/backends/vultr/compute.py +5 -1
  30. dstack/_internal/core/errors.py +4 -0
  31. dstack/_internal/core/models/fleets.py +2 -0
  32. dstack/_internal/core/models/instances.py +4 -3
  33. dstack/_internal/core/models/resources.py +80 -3
  34. dstack/_internal/core/models/runs.py +10 -3
  35. dstack/_internal/core/models/volumes.py +1 -1
  36. dstack/_internal/server/background/tasks/process_fleets.py +4 -13
  37. dstack/_internal/server/background/tasks/process_instances.py +176 -55
  38. dstack/_internal/server/background/tasks/process_placement_groups.py +1 -1
  39. dstack/_internal/server/background/tasks/process_prometheus_metrics.py +5 -2
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +1 -1
  41. dstack/_internal/server/models.py +1 -0
  42. dstack/_internal/server/routers/gateways.py +2 -1
  43. dstack/_internal/server/services/config.py +7 -2
  44. dstack/_internal/server/services/fleets.py +24 -26
  45. dstack/_internal/server/services/gateways/__init__.py +17 -2
  46. dstack/_internal/server/services/instances.py +0 -2
  47. dstack/_internal/server/services/offers.py +15 -0
  48. dstack/_internal/server/services/placement.py +27 -6
  49. dstack/_internal/server/services/plugins.py +77 -0
  50. dstack/_internal/server/services/resources.py +21 -0
  51. dstack/_internal/server/services/runs.py +41 -17
  52. dstack/_internal/server/services/volumes.py +10 -1
  53. dstack/_internal/server/testing/common.py +35 -26
  54. dstack/_internal/utils/common.py +22 -9
  55. dstack/_internal/utils/json_schema.py +6 -3
  56. dstack/api/__init__.py +1 -0
  57. dstack/api/server/__init__.py +8 -1
  58. dstack/api/server/_fleets.py +16 -0
  59. dstack/api/server/_runs.py +44 -3
  60. dstack/plugins/__init__.py +8 -0
  61. dstack/plugins/_base.py +72 -0
  62. dstack/plugins/_models.py +8 -0
  63. dstack/plugins/_utils.py +19 -0
  64. dstack/version.py +1 -1
  65. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/METADATA +14 -2
  66. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/RECORD +69 -62
  67. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/WHEEL +0 -0
  68. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/entry_points.txt +0 -0
  69. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/server/services/gateways/__init__.py CHANGED
@@ -31,13 +31,19 @@ from dstack._internal.core.models.gateways import (
     Gateway,
     GatewayComputeConfiguration,
     GatewayConfiguration,
+    GatewaySpec,
     GatewayStatus,
     LetsEncryptGatewayCertificate,
 )
 from dstack._internal.core.services import validate_dstack_resource_name
 from dstack._internal.server import settings
 from dstack._internal.server.db import get_db
-from dstack._internal.server.models import GatewayComputeModel, GatewayModel, ProjectModel
+from dstack._internal.server.models import (
+    GatewayComputeModel,
+    GatewayModel,
+    ProjectModel,
+    UserModel,
+)
 from dstack._internal.server.services.backends import (
     check_backend_type_available,
     get_project_backend_by_type_or_error,
@@ -50,6 +56,7 @@ from dstack._internal.server.services.locking import (
     get_locker,
     string_to_lock_id,
 )
+from dstack._internal.server.services.plugins import apply_plugin_policies
 from dstack._internal.server.utils.common import gather_map_async
 from dstack._internal.utils.common import get_current_datetime, run_async
 from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
@@ -129,9 +136,17 @@ async def create_gateway_compute(

 async def create_gateway(
     session: AsyncSession,
+    user: UserModel,
     project: ProjectModel,
     configuration: GatewayConfiguration,
 ) -> Gateway:
+    spec = apply_plugin_policies(
+        user=user.name,
+        project=project.name,
+        # Create pseudo spec until the gateway API is updated to accept spec
+        spec=GatewaySpec(configuration=configuration),
+    )
+    configuration = spec.configuration
     _validate_gateway_configuration(configuration)

     backend_model, _ = await get_project_backend_with_model_by_type_or_error(
@@ -140,7 +155,7 @@ async def create_gateway(

     lock_namespace = f"gateway_names_{project.name}"
     if get_db().dialect_name == "sqlite":
-        # Start new transaction to see commited changes after lock
+        # Start new transaction to see committed changes after lock
         await session.commit()
     elif get_db().dialect_name == "postgresql":
         await session.execute(
dstack/_internal/server/services/instances.py CHANGED
@@ -408,7 +408,6 @@ async def create_instance_model(
     requirements: Requirements,
     instance_name: str,
     instance_num: int,
-    placement_group_name: Optional[str],
     reservation: Optional[str],
     blocks: Union[Literal["auto"], int],
     tags: Optional[Dict[str, str]],
@@ -427,7 +426,6 @@ async def create_instance_model(
         user=user.name,
         ssh_keys=[project_ssh_key],
         instance_id=str(instance_id),
-        placement_group_name=placement_group_name,
         reservation=reservation,
         tags=tags,
     )
dstack/_internal/server/services/offers.py CHANGED
@@ -8,12 +8,14 @@ from dstack._internal.core.backends import (
     BACKENDS_WITH_RESERVATION_SUPPORT,
 )
 from dstack._internal.core.backends.base.backend import Backend
+from dstack._internal.core.backends.base.compute import ComputeWithPlacementGroupSupport
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.instances import (
     InstanceOfferWithAvailability,
     InstanceType,
     Resources,
 )
+from dstack._internal.core.models.placement import PlacementGroup
 from dstack._internal.core.models.profiles import Profile
 from dstack._internal.core.models.runs import JobProvisioningData, Requirements
 from dstack._internal.core.models.volumes import Volume
@@ -31,6 +33,7 @@ async def get_offers_by_requirements(
     volumes: Optional[List[List[Volume]]] = None,
     privileged: bool = False,
     instance_mounts: bool = False,
+    placement_group: Optional[PlacementGroup] = None,
     blocks: Union[int, Literal["auto"]] = 1,
 ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]:
     backends: List[Backend] = await backends_services.get_project_backends(project=project)
@@ -116,6 +119,18 @@
             new_offers.append((b, new_offer))
         offers = new_offers

+    if placement_group is not None:
+        new_offers = []
+        for b, o in offers:
+            for backend in backends:
+                compute = backend.compute()
+                if isinstance(
+                    compute, ComputeWithPlacementGroupSupport
+                ) and compute.is_suitable_placement_group(placement_group, o):
+                    new_offers.append((b, o))
+                    break
+        offers = new_offers
+
     if profile.instance_types is not None:
         instance_types = [i.lower() for i in profile.instance_types]
         offers = [(b, o) for b, o in offers if o.instance.name.lower() in instance_types]
dstack/_internal/server/services/placement.py CHANGED
@@ -1,8 +1,9 @@
+from collections.abc import Iterable
 from typing import Optional
 from uuid import UUID

 from git import List
-from sqlalchemy import select
+from sqlalchemy import and_, select, update
 from sqlalchemy.ext.asyncio import AsyncSession

 from dstack._internal.core.models.placement import (
@@ -13,15 +14,35 @@ from dstack._internal.core.models.placement import (
 from dstack._internal.server.models import PlacementGroupModel


-async def get_fleet_placement_groups(
+async def get_fleet_placement_group_models(
     session: AsyncSession,
     fleet_id: UUID,
-) -> List[PlacementGroup]:
+) -> List[PlacementGroupModel]:
     res = await session.execute(
-        select(PlacementGroupModel).where(PlacementGroupModel.fleet_id == fleet_id)
+        select(PlacementGroupModel).where(
+            and_(
+                PlacementGroupModel.fleet_id == fleet_id,
+                PlacementGroupModel.deleted == False,
+                PlacementGroupModel.fleet_deleted == False,
+            )
+        )
+    )
+    return list(res.scalars().all())
+
+
+async def schedule_fleet_placement_groups_deletion(
+    session: AsyncSession, fleet_id: UUID, except_placement_group_ids: Iterable[UUID] = ()
+) -> None:
+    await session.execute(
+        update(PlacementGroupModel)
+        .where(
+            and_(
+                PlacementGroupModel.fleet_id == fleet_id,
+                PlacementGroupModel.id.not_in(except_placement_group_ids),
+            )
+        )
+        .values(fleet_deleted=True)  # TODO: rename `fleet_deleted` -> `to_be_deleted`
     )
-    placement_groups = res.scalars().all()
-    return [placement_group_model_to_placement_group(pg) for pg in placement_groups]


 def placement_group_model_to_placement_group(
dstack/_internal/server/services/plugins.py ADDED
@@ -0,0 +1,77 @@
+import itertools
+from importlib import import_module
+
+from backports.entry_points_selectable import entry_points  # backport for Python 3.9
+
+from dstack._internal.core.errors import ServerClientError
+from dstack._internal.utils.logging import get_logger
+from dstack.plugins import ApplyPolicy, ApplySpec, Plugin
+
+logger = get_logger(__name__)
+
+
+_PLUGINS: list[Plugin] = []
+
+
+def load_plugins(enabled_plugins: list[str]):
+    _PLUGINS.clear()
+    plugins_entrypoints = entry_points(group="dstack.plugins")
+    plugins_to_load = enabled_plugins.copy()
+    for entrypoint in plugins_entrypoints:
+        if entrypoint.name not in enabled_plugins:
+            logger.info(
+                ("Found not enabled plugin %s. Plugin will not be loaded."),
+                entrypoint.name,
+            )
+            continue
+        try:
+            module_path, _, class_name = entrypoint.value.partition(":")
+            module = import_module(module_path)
+        except ImportError:
+            logger.warning(
+                (
+                    "Failed to load plugin %s when importing %s."
+                    " Ensure the module is on the import path."
+                ),
+                entrypoint.name,
+                entrypoint.value,
+            )
+            continue
+        plugin_class = getattr(module, class_name, None)
+        if plugin_class is None:
+            logger.warning(
+                ("Failed to load plugin %s: plugin class %s not found in module %s."),
+                entrypoint.name,
+                class_name,
+                module_path,
+            )
+            continue
+        if not issubclass(plugin_class, Plugin):
+            logger.warning(
+                ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
+                entrypoint.name,
+                class_name,
+            )
+            continue
+        plugins_to_load.remove(entrypoint.name)
+        _PLUGINS.append(plugin_class())
+        logger.info("Loaded plugin %s", entrypoint.name)
+    if plugins_to_load:
+        logger.warning("Enabled plugins not found: %s", plugins_to_load)
+
+
+def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec:
+    policies = _get_apply_policies()
+    for policy in policies:
+        try:
+            spec = policy.on_apply(user=user, project=project, spec=spec)
+        except ValueError as e:
+            msg = None
+            if len(e.args) > 0:
+                msg = e.args[0]
+            raise ServerClientError(msg)
+    return spec
+
+
+def _get_apply_policies() -> list[ApplyPolicy]:
+    return list(itertools.chain(*[p.get_apply_policies() for p in _PLUGINS]))
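
The loader above resolves entry points in the dstack.plugins group, instantiates each enabled Plugin, and chains the ApplyPolicy objects returned by get_apply_policies(); apply_plugin_policies() then calls each policy's on_apply() hook, and a raised ValueError is reported to the client as a ServerClientError. Below is a minimal sketch of a third-party plugin written against that contract; the class names and the policy rule are hypothetical, and any further members of the Plugin/ApplyPolicy base classes (defined in dstack/plugins/_base.py, not shown in this diff) are assumptions.

    # Hypothetical plugin: only get_apply_policies() and on_apply() are exercised
    # by the loader and apply_plugin_policies() shown above.
    from dstack.plugins import ApplyPolicy, ApplySpec, Plugin


    class DenySandboxPolicy(ApplyPolicy):
        def on_apply(self, user: str, project: str, spec: ApplySpec) -> ApplySpec:
            # Raising ValueError here surfaces to the client as a ServerClientError.
            if project == "sandbox":
                raise ValueError(f"user {user} may not apply specs in project {project}")
            return spec


    class ExamplePlugin(Plugin):
        def get_apply_policies(self) -> list[ApplyPolicy]:
            return [DenySandboxPolicy()]

The server would discover such a class through an entry point registered in the dstack.plugins group (the entry-point value is split on ":" into a module path and a class name) and load it only if its entry-point name appears in the enabled plugins list.
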
dstack/_internal/server/services/resources.py ADDED
@@ -0,0 +1,21 @@
+import gpuhunt
+from pydantic import parse_obj_as
+
+from dstack._internal.core.models.resources import CPUSpec, ResourcesSpec
+
+
+def set_resources_defaults(resources: ResourcesSpec) -> None:
+    # TODO: Remove in 0.20. Use resources.cpu directly
+    cpu = parse_obj_as(CPUSpec, resources.cpu)
+    if cpu.arch is None:
+        gpu = resources.gpu
+        if (
+            gpu is not None
+            and gpu.vendor in [None, gpuhunt.AcceleratorVendor.NVIDIA]
+            and gpu.name
+            and any(map(gpuhunt.is_nvidia_superchip, gpu.name))
+        ):
+            cpu.arch = gpuhunt.CPUArchitecture.ARM
+        else:
+            cpu.arch = gpuhunt.CPUArchitecture.X86
+    resources.cpu = cpu
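
A quick illustration of the new default: for a spec without an explicit CPU architecture, the architecture falls back to x86 unless the requested GPU is one that gpuhunt classifies as an NVIDIA superchip, in which case ARM is assumed. The snippet below is illustrative and reuses only calls that appear elsewhere in this diff (CPUSpec.parse and ResourcesSpec(cpu=...)).

    import gpuhunt

    from dstack._internal.core.models.resources import CPUSpec, ResourcesSpec
    from dstack._internal.server.services.resources import set_resources_defaults

    spec = ResourcesSpec(cpu=CPUSpec.parse("4"))  # no GPU requested, arch not set
    set_resources_defaults(spec)
    assert spec.cpu.arch == gpuhunt.CPUArchitecture.X86  # ARM only for NVIDIA superchips
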
dstack/_internal/server/services/runs.py CHANGED
@@ -79,7 +79,9 @@ from dstack._internal.server.services.jobs import (
 from dstack._internal.server.services.locking import get_locker, string_to_lock_id
 from dstack._internal.server.services.logging import fmt
 from dstack._internal.server.services.offers import get_offers_by_requirements
+from dstack._internal.server.services.plugins import apply_plugin_policies
 from dstack._internal.server.services.projects import list_project_models, list_user_project_models
+from dstack._internal.server.services.resources import set_resources_defaults
 from dstack._internal.server.services.users import get_user_model_by_name
 from dstack._internal.utils.logging import get_logger
 from dstack._internal.utils.random_names import generate_name
@@ -279,7 +281,14 @@ async def get_plan(
     run_spec: RunSpec,
     max_offers: Optional[int],
 ) -> RunPlan:
+    # Spec must be copied by parsing to calculate merged_profile
     effective_run_spec = RunSpec.parse_obj(run_spec.dict())
+    effective_run_spec = apply_plugin_policies(
+        user=user.name,
+        project=project.name,
+        spec=effective_run_spec,
+    )
+    effective_run_spec = RunSpec.parse_obj(effective_run_spec.dict())
     _validate_run_spec_and_set_defaults(effective_run_spec)

     profile = effective_run_spec.merged_profile
@@ -293,12 +302,14 @@
         project=project,
         run_name=effective_run_spec.run_name,
     )
-    if (
-        current_resource is not None
-        and not current_resource.status.is_finished()
-        and _can_update_run_spec(current_resource.run_spec, effective_run_spec)
-    ):
-        action = ApplyAction.UPDATE
+    if current_resource is not None:
+        # For backward compatibility (current_resource may has been submitted before
+        # some fields, e.g., CPUSpec.arch, were added)
+        set_resources_defaults(current_resource.run_spec.configuration.resources)
+        if not current_resource.status.is_finished() and _can_update_run_spec(
+            current_resource.run_spec, effective_run_spec
+        ):
+            action = ApplyAction.UPDATE

     jobs = await get_jobs_from_run_spec(effective_run_spec, replica_num=0)

@@ -370,34 +381,48 @@ async def apply_plan(
     plan: ApplyRunPlanInput,
     force: bool,
 ) -> Run:
-    _validate_run_spec_and_set_defaults(plan.run_spec)
-    if plan.run_spec.run_name is None:
+    run_spec = plan.run_spec
+    run_spec = apply_plugin_policies(
+        user=user.name,
+        project=project.name,
+        spec=run_spec,
+    )
+    # Spec must be copied by parsing to calculate merged_profile
+    run_spec = RunSpec.parse_obj(run_spec.dict())
+    _validate_run_spec_and_set_defaults(run_spec)
+    if run_spec.run_name is None:
         return await submit_run(
             session=session,
             user=user,
             project=project,
-            run_spec=plan.run_spec,
+            run_spec=run_spec,
         )
     current_resource = await get_run_by_name(
         session=session,
         project=project,
-        run_name=plan.run_spec.run_name,
+        run_name=run_spec.run_name,
     )
     if current_resource is None or current_resource.status.is_finished():
         return await submit_run(
             session=session,
             user=user,
             project=project,
-            run_spec=plan.run_spec,
+            run_spec=run_spec,
         )
+
+    # For backward compatibility (current_resource may has been submitted before
+    # some fields, e.g., CPUSpec.arch, were added)
+    set_resources_defaults(current_resource.run_spec.configuration.resources)
     try:
-        _check_can_update_run_spec(current_resource.run_spec, plan.run_spec)
+        _check_can_update_run_spec(current_resource.run_spec, run_spec)
     except ServerClientError:
         # The except is only needed to raise an appropriate error if run is active
         if not current_resource.status.is_finished():
             raise ServerClientError("Cannot override active run. Stop the run first.")
         raise
     if not force:
+        if plan.current_resource is not None:
+            set_resources_defaults(plan.current_resource.run_spec.configuration.resources)
         if (
             plan.current_resource is None
             or plan.current_resource.id != current_resource.id
@@ -409,14 +434,12 @@
     # FIXME: potentially long write transaction
     # Avoid getting run_model after update
     await session.execute(
-        update(RunModel)
-        .where(RunModel.id == current_resource.id)
-        .values(run_spec=plan.run_spec.json())
+        update(RunModel).where(RunModel.id == current_resource.id).values(run_spec=run_spec.json())
     )
     run = await get_run_by_name(
         session=session,
         project=project,
-        run_name=plan.run_spec.run_name,
+        run_name=run_spec.run_name,
     )
     return common_utils.get_or_error(run)

@@ -436,7 +459,7 @@ async def submit_run(

     lock_namespace = f"run_names_{project.name}"
     if get_db().dialect_name == "sqlite":
-        # Start new transaction to see commited changes after lock
+        # Start new transaction to see committed changes after lock
         await session.commit()
     elif get_db().dialect_name == "postgresql":
         await session.execute(
@@ -852,6 +875,7 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec):
         raise ServerClientError(
             f"Maximum utilization_policy.time_window is {settings.SERVER_METRICS_TTL_SECONDS}s"
         )
+    set_resources_defaults(run_spec.configuration.resources)


 _UPDATABLE_SPEC_FIELDS = ["repo_code_hash", "configuration"]
dstack/_internal/server/services/volumes.py CHANGED
@@ -21,6 +21,7 @@ from dstack._internal.core.models.volumes import (
     VolumeConfiguration,
     VolumeInstance,
     VolumeProvisioningData,
+    VolumeSpec,
     VolumeStatus,
 )
 from dstack._internal.core.services import validate_dstack_resource_name
@@ -38,6 +39,7 @@ from dstack._internal.server.services.locking import (
     get_locker,
     string_to_lock_id,
 )
+from dstack._internal.server.services.plugins import apply_plugin_policies
 from dstack._internal.server.services.projects import list_project_models, list_user_project_models
 from dstack._internal.utils import common, random_names
 from dstack._internal.utils.logging import get_logger
@@ -203,11 +205,18 @@ async def create_volume(
     user: UserModel,
     configuration: VolumeConfiguration,
 ) -> Volume:
+    spec = apply_plugin_policies(
+        user=user.name,
+        project=project.name,
+        # Create pseudo spec until the volume API is updated to accept spec
+        spec=VolumeSpec(configuration=configuration),
+    )
+    configuration = spec.configuration
     _validate_volume_configuration(configuration)

     lock_namespace = f"volume_names_{project.name}"
     if get_db().dialect_name == "sqlite":
-        # Start new transaction to see commited changes after lock
+        # Start new transaction to see committed changes after lock
         await session.commit()
     elif get_db().dialect_name == "postgresql":
         await session.execute(
dstack/_internal/server/testing/common.py CHANGED
@@ -2,7 +2,7 @@ import json
 import uuid
 from contextlib import contextmanager
 from datetime import datetime, timezone
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Literal, Optional, Union
 from uuid import UUID

 import gpuhunt
@@ -25,7 +25,12 @@ from dstack._internal.core.models.configurations import (
     DevEnvironmentConfiguration,
 )
 from dstack._internal.core.models.envs import Env
-from dstack._internal.core.models.fleets import FleetConfiguration, FleetSpec, FleetStatus
+from dstack._internal.core.models.fleets import (
+    FleetConfiguration,
+    FleetSpec,
+    FleetStatus,
+    InstanceGroupPlacement,
+)
 from dstack._internal.core.models.gateways import GatewayComputeConfiguration, GatewayStatus
 from dstack._internal.core.models.instances import (
     Disk,
@@ -51,7 +56,7 @@ from dstack._internal.core.models.profiles import (
 )
 from dstack._internal.core.models.repos.base import RepoType
 from dstack._internal.core.models.repos.local import LocalRunRepoData
-from dstack._internal.core.models.resources import Memory, Range, ResourcesSpec
+from dstack._internal.core.models.resources import CPUSpec, Memory, Range, ResourcesSpec
 from dstack._internal.core.models.runs import (
     JobProvisioningData,
     JobRuntimeData,
@@ -497,10 +502,12 @@ def get_fleet_spec(conf: Optional[FleetConfiguration] = None) -> FleetSpec:
 def get_fleet_configuration(
     name: str = "test-fleet",
     nodes: Range[int] = Range(min=1, max=1),
+    placement: Optional[InstanceGroupPlacement] = None,
 ) -> FleetConfiguration:
     return FleetConfiguration(
         name=name,
         nodes=nodes,
+        placement=placement,
     )


@@ -519,13 +526,13 @@ async def create_instance(
     instance_id: Optional[UUID] = None,
     job: Optional[JobModel] = None,
     instance_num: int = 0,
-    backend: BackendType = BackendType.DATACRUNCH,
+    backend: Optional[BackendType] = BackendType.DATACRUNCH,
     termination_policy: Optional[TerminationPolicy] = None,
     termination_idle_time: int = DEFAULT_FLEET_TERMINATION_IDLE_TIME,
-    region: str = "eu-west",
+    region: Optional[str] = "eu-west",
     remote_connection_info: Optional[RemoteConnectionInfo] = None,
-    offer: Optional[InstanceOfferWithAvailability] = None,
-    job_provisioning_data: Optional[JobProvisioningData] = None,
+    offer: Optional[Union[InstanceOfferWithAvailability, Literal["auto"]]] = "auto",
+    job_provisioning_data: Optional[Union[JobProvisioningData, Literal["auto"]]] = "auto",
     total_blocks: Optional[int] = 1,
     busy_blocks: int = 0,
     name: str = "test_instance",
@@ -534,7 +541,7 @@
 ) -> InstanceModel:
     if instance_id is None:
         instance_id = uuid.uuid4()
-    if job_provisioning_data is None:
+    if job_provisioning_data == "auto":
         job_provisioning_data = get_job_provisioning_data(
             dockerized=True,
             backend=backend,
@@ -543,13 +550,13 @@
             hostname="running_instance.ip",
             internal_ip=None,
         )
-    if offer is None:
+    if offer == "auto":
         offer = get_instance_offer_with_availability(backend=backend, region=region, spot=spot)
     if profile is None:
         profile = Profile(name="test_name")

     if requirements is None:
-        requirements = Requirements(resources=ResourcesSpec(cpu=1))
+        requirements = Requirements(resources=ResourcesSpec(cpu=CPUSpec.parse("1")))

     if instance_configuration is None:
         instance_configuration = get_instance_configuration()
@@ -571,8 +578,8 @@
         created_at=created_at,
         started_at=created_at,
         finished_at=finished_at,
-        job_provisioning_data=job_provisioning_data.json(),
-        offer=offer.json(),
+        job_provisioning_data=job_provisioning_data.json() if job_provisioning_data else None,
+        offer=offer.json() if offer else None,
         price=price,
         region=region,
         backend=backend,
@@ -659,20 +666,7 @@ def get_remote_connection_info(
     env: Optional[Union[Env, dict]] = None,
 ):
     if ssh_keys is None:
-        ssh_keys = [
-            SSHKey(
-                public="ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO6mJxVbNtm0zXgMLvByrhXJCmJRveSrJxLB5/OzcyCk",
-                private="""
------BEGIN OPENSSH PRIVATE KEY-----
-b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
-QyNTUxOQAAACDupicVWzbZtM14DC7wcq4VyQpiUb3kqycSwefzs3MgpAAAAJCiWa5Volmu
-VQAAAAtzc2gtZWQyNTUxOQAAACDupicVWzbZtM14DC7wcq4VyQpiUb3kqycSwefzs3MgpA
-AAAEAncHi4AhS6XdMp5Gzd+IMse/4ekyQ54UngByf0Sp0uH+6mJxVbNtm0zXgMLvByrhXJ
-CmJRveSrJxLB5/OzcyCkAAAACWRlZkBkZWZwYwECAwQ=
------END OPENSSH PRIVATE KEY-----
-""",
-            )
-        ]
+        ssh_keys = [get_ssh_key()]
     if env is None:
         env = Env()
     elif isinstance(env, dict):
@@ -686,6 +680,21 @@
     )


+def get_ssh_key() -> SSHKey:
+    return SSHKey(
+        public="ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIO6mJxVbNtm0zXgMLvByrhXJCmJRveSrJxLB5/OzcyCk",
+        private="""
+-----BEGIN OPENSSH PRIVATE KEY-----
+b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
+QyNTUxOQAAACDupicVWzbZtM14DC7wcq4VyQpiUb3kqycSwefzs3MgpAAAAJCiWa5Volmu
+VQAAAAtzc2gtZWQyNTUxOQAAACDupicVWzbZtM14DC7wcq4VyQpiUb3kqycSwefzs3MgpA
+AAAEAncHi4AhS6XdMp5Gzd+IMse/4ekyQ54UngByf0Sp0uH+6mJxVbNtm0zXgMLvByrhXJ
+CmJRveSrJxLB5/OzcyCkAAAACWRlZkBkZWZwYwECAwQ=
+-----END OPENSSH PRIVATE KEY-----
+""",
+    )
+
+
 async def create_volume(
     session: AsyncSession,
     project: ProjectModel,
dstack/_internal/utils/common.py CHANGED
@@ -1,4 +1,5 @@
 import asyncio
+import enum
 import itertools
 import re
 import time
@@ -83,6 +84,8 @@ def pretty_date(time: datetime) -> str:


 def pretty_resources(
+    *,
+    cpu_arch: Optional[Any] = None,
     cpus: Optional[Any] = None,
     memory: Optional[Any] = None,
     gpu_count: Optional[Any] = None,
@@ -110,25 +113,35 @@ def pretty_resources(
     """
     parts = []
     if cpus is not None:
-        parts.append(f"{cpus}xCPU")
+        cpu_arch_lower: Optional[str] = None
+        if isinstance(cpu_arch, enum.Enum):
+            cpu_arch_lower = str(cpu_arch.value).lower()
+        elif isinstance(cpu_arch, str):
+            cpu_arch_lower = cpu_arch.lower()
+        if cpu_arch_lower == "arm":
+            cpu_arch_prefix = "arm:"
+        else:
+            cpu_arch_prefix = ""
+        parts.append(f"cpu={cpu_arch_prefix}{cpus}")
     if memory is not None:
-        parts.append(f"{memory}")
+        parts.append(f"mem={memory}")
+    if disk_size:
+        parts.append(f"disk={disk_size}")
     if gpu_count:
         gpu_parts = []
+        gpu_parts.append(f"{gpu_name or 'gpu'}")
         if gpu_memory is not None:
             gpu_parts.append(f"{gpu_memory}")
+        if gpu_count is not None:
+            gpu_parts.append(f"{gpu_count}")
         if total_gpu_memory is not None:
-            gpu_parts.append(f"total {total_gpu_memory}")
+            gpu_parts.append(f"{total_gpu_memory}")
         if compute_capability is not None:
            gpu_parts.append(f"{compute_capability}")

-        gpu = f"{gpu_count}x{gpu_name or 'GPU'}"
-        if gpu_parts:
-            gpu += f" ({', '.join(gpu_parts)})"
+        gpu = ":".join(gpu_parts)
         parts.append(gpu)
-    if disk_size:
-        parts.append(f"{disk_size} (disk)")
-    return ", ".join(parts)
+    return " ".join(parts)


 def since(timestamp: str) -> datetime:
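
The rewrite changes the human-readable resource string from the old "8xCPU, 16GB, 8xH100 (80GB), 100GB (disk)" style to a compact key=value form, with an "arm:" prefix on the CPU count for ARM hosts. Tracing the new code, an illustrative call (the argument values are made up) would render as follows:

    from dstack._internal.utils.common import pretty_resources

    pretty_resources(
        cpu_arch="arm", cpus=8, memory="16GB", disk_size="100GB",
        gpu_name="H100", gpu_count=8, gpu_memory="80GB",
    )
    # -> 'cpu=arm:8 mem=16GB disk=100GB H100:80GB:8'

Note that all arguments are now keyword-only because of the added "*" in the signature.
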
dstack/_internal/utils/json_schema.py CHANGED
@@ -1,6 +1,9 @@
 def add_extra_schema_types(schema_property: dict, extra_types: list[dict]):
     if "allOf" in schema_property:
-        ref = schema_property.pop("allOf")[0]
+        refs = [schema_property.pop("allOf")[0]]
+    elif "anyOf" in schema_property:
+        refs = schema_property.pop("anyOf")
     else:
-        ref = {"type": schema_property.pop("type")}
-    schema_property["anyOf"] = [ref, *extra_types]
+        refs = [{"type": schema_property.pop("type")}]
+    refs.extend(extra_types)
+    schema_property["anyOf"] = refs
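
With the anyOf branch added, properties that are already rendered as a union keep their existing alternatives and simply gain the extra ones. For example, with an illustrative schema fragment:

    from dstack._internal.utils.json_schema import add_extra_schema_types

    prop = {"anyOf": [{"$ref": "#/definitions/CPUSpec"}, {"type": "integer"}]}
    add_extra_schema_types(prop, extra_types=[{"type": "string"}])
    # prop == {"anyOf": [{"$ref": "#/definitions/CPUSpec"},
    #                    {"type": "integer"}, {"type": "string"}]}
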
dstack/api/__init__.py CHANGED
@@ -14,6 +14,7 @@ from dstack._internal.core.models.repos.local import LocalRepo
 from dstack._internal.core.models.repos.remote import RemoteRepo
 from dstack._internal.core.models.repos.virtual import VirtualRepo
 from dstack._internal.core.models.resources import ComputeCapability, Memory, Range
+from dstack._internal.core.models.resources import CPUSpec as CPU
 from dstack._internal.core.models.resources import DiskSpec as Disk
 from dstack._internal.core.models.resources import GPUSpec as GPU
 from dstack._internal.core.models.resources import ResourcesSpec as Resources
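
The public API now re-exports CPUSpec as dstack.api.CPU, next to the existing Disk, GPU, and Resources aliases. A minimal sketch, limited to calls that appear elsewhere in this diff (CPUSpec.parse and ResourcesSpec(cpu=...)); whether CPU also accepts an architecture-prefixed string such as "arm:4" is not shown here:

    from dstack.api import CPU, Resources

    resources = Resources(cpu=CPU.parse("4"))
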