dstack 0.19.7__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (52)
  1. dstack/_internal/cli/services/args.py +2 -2
  2. dstack/_internal/cli/services/configurators/run.py +38 -2
  3. dstack/_internal/cli/utils/run.py +3 -3
  4. dstack/_internal/core/backends/aws/compute.py +13 -1
  5. dstack/_internal/core/backends/azure/compute.py +42 -13
  6. dstack/_internal/core/backends/azure/configurator.py +21 -0
  7. dstack/_internal/core/backends/azure/models.py +9 -0
  8. dstack/_internal/core/backends/base/compute.py +101 -27
  9. dstack/_internal/core/backends/base/offers.py +13 -3
  10. dstack/_internal/core/backends/cudo/compute.py +2 -0
  11. dstack/_internal/core/backends/datacrunch/compute.py +2 -0
  12. dstack/_internal/core/backends/gcp/auth.py +1 -1
  13. dstack/_internal/core/backends/gcp/compute.py +51 -35
  14. dstack/_internal/core/backends/lambdalabs/compute.py +20 -8
  15. dstack/_internal/core/backends/local/compute.py +2 -0
  16. dstack/_internal/core/backends/nebius/compute.py +95 -1
  17. dstack/_internal/core/backends/nebius/configurator.py +11 -0
  18. dstack/_internal/core/backends/nebius/fabrics.py +47 -0
  19. dstack/_internal/core/backends/nebius/models.py +8 -0
  20. dstack/_internal/core/backends/nebius/resources.py +29 -0
  21. dstack/_internal/core/backends/oci/compute.py +2 -0
  22. dstack/_internal/core/backends/remote/provisioning.py +27 -2
  23. dstack/_internal/core/backends/template/compute.py.jinja +2 -0
  24. dstack/_internal/core/backends/tensordock/compute.py +2 -0
  25. dstack/_internal/core/backends/vultr/compute.py +5 -1
  26. dstack/_internal/core/models/instances.py +2 -1
  27. dstack/_internal/core/models/resources.py +78 -3
  28. dstack/_internal/core/models/runs.py +7 -2
  29. dstack/_internal/core/models/volumes.py +1 -1
  30. dstack/_internal/server/background/tasks/process_fleets.py +4 -13
  31. dstack/_internal/server/background/tasks/process_instances.py +176 -55
  32. dstack/_internal/server/background/tasks/process_placement_groups.py +1 -1
  33. dstack/_internal/server/background/tasks/process_prometheus_metrics.py +5 -2
  34. dstack/_internal/server/models.py +1 -0
  35. dstack/_internal/server/services/fleets.py +9 -26
  36. dstack/_internal/server/services/instances.py +0 -2
  37. dstack/_internal/server/services/offers.py +15 -0
  38. dstack/_internal/server/services/placement.py +27 -6
  39. dstack/_internal/server/services/resources.py +21 -0
  40. dstack/_internal/server/services/runs.py +16 -6
  41. dstack/_internal/server/testing/common.py +35 -26
  42. dstack/_internal/utils/common.py +13 -1
  43. dstack/_internal/utils/json_schema.py +6 -3
  44. dstack/api/__init__.py +1 -0
  45. dstack/api/server/_fleets.py +16 -0
  46. dstack/api/server/_runs.py +44 -3
  47. dstack/version.py +1 -1
  48. {dstack-0.19.7.dist-info → dstack-0.19.8.dist-info}/METADATA +3 -1
  49. {dstack-0.19.7.dist-info → dstack-0.19.8.dist-info}/RECORD +52 -50
  50. {dstack-0.19.7.dist-info → dstack-0.19.8.dist-info}/WHEEL +0 -0
  51. {dstack-0.19.7.dist-info → dstack-0.19.8.dist-info}/entry_points.txt +0 -0
  52. {dstack-0.19.7.dist-info → dstack-0.19.8.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/core/models/resources.py
@@ -1,8 +1,9 @@
 import math
+from collections.abc import Mapping
 from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union

 import gpuhunt
-from pydantic import Field, root_validator, validator
+from pydantic import Field, parse_obj_as, root_validator, validator
 from pydantic.generics import GenericModel
 from typing_extensions import Annotated

@@ -128,6 +129,67 @@ DEFAULT_MEMORY_SIZE = Range[Memory](min=Memory.parse("8GB"))
 DEFAULT_GPU_COUNT = Range[int](min=1, max=1)


+class CPUSpec(CoreModel):
+    class Config:
+        @staticmethod
+        def schema_extra(schema: Dict[str, Any]):
+            add_extra_schema_types(
+                schema["properties"]["count"],
+                extra_types=[{"type": "integer"}, {"type": "string"}],
+            )
+
+    arch: Annotated[
+        Optional[gpuhunt.CPUArchitecture],
+        Field(description="The CPU architecture, one of: `x86`, `arm`"),
+    ] = None
+    count: Annotated[Range[int], Field(description="The number of CPU cores")] = DEFAULT_CPU_COUNT
+
+    @classmethod
+    def __get_validators__(cls):
+        yield cls.parse
+        yield cls.validate
+
+    @classmethod
+    def parse(cls, v: Any) -> Any:
+        if isinstance(v, int):
+            v = str(v)
+        if isinstance(v, str):
+            tokens = v.replace(" ", "").split(":")
+            spec = {}
+            for token in tokens:
+                if not token:
+                    raise ValueError(f"CPU spec contains empty token: {v}")
+                if ".." in token or token.isdigit():
+                    if "count" in spec:
+                        raise ValueError(f"CPU spec count conflict: {v}")
+                    spec["count"] = token
+                else:
+                    try:
+                        arch = gpuhunt.CPUArchitecture.cast(token)
+                    except ValueError:
+                        raise ValueError(f"Invalid CPU architecture: {v}")
+                    if "arch" in spec:
+                        raise ValueError(f"CPU spec arch conflict: {v}")
+                    spec["arch"] = arch
+            return spec
+        # Range and min/max dict - for backward compatibility
+        if isinstance(v, Range):
+            return {"arch": None, "count": v}
+        if isinstance(v, Mapping) and v.keys() == {"min", "max"}:
+            return {"arch": None, "count": v}
+        return v
+
+    @validator("arch", pre=True)
+    def _validate_arch(cls, v: Any) -> Any:
+        if v is None:
+            return None
+        if isinstance(v, gpuhunt.CPUArchitecture):
+            return v
+        if isinstance(v, str):
+            return gpuhunt.CPUArchitecture.cast(v)
+        return v
+
+
 class GPUSpec(CoreModel):
     class Config:
         @staticmethod
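
CPUSpec accepts the same colon-separated shorthand as GPUSpec: tokens may appear in any order, a digit or ".." range token becomes the count, and any other token must cast to a CPU architecture. A few inputs traced by hand through the parse method above, for illustration only (count strings are normalized to Range[int] by the subsequent validate step):

    # Traced from CPUSpec.parse above; illustrative, not from the package.
    CPUSpec.parse("arm:4..8")  # {"arch": gpuhunt.CPUArchitecture.ARM, "count": "4..8"}
    CPUSpec.parse(4)           # {"count": "4"} - a bare int is coerced to a count
    CPUSpec.parse("x86:arm")   # ValueError: CPU spec arch conflict: x86:arm
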
@@ -302,7 +364,10 @@ class ResourcesSpec(CoreModel):
             extra_types=[{"type": "integer"}, {"type": "string"}],
         )

-    cpu: Annotated[Range[int], Field(description="The number of CPU cores")] = DEFAULT_CPU_COUNT
+    # TODO: Remove Range[int] in 0.20. Range[int] for backward compatibility only.
+    cpu: Annotated[Union[CPUSpec, Range[int]], Field(description="The CPU requirements")] = (
+        CPUSpec()
+    )
     memory: Annotated[Range[Memory], Field(description="The RAM size (e.g., `8GB`)")] = (
         DEFAULT_MEMORY_SIZE
     )
@@ -317,8 +382,18 @@ class ResourcesSpec(CoreModel):
     gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None
     disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK

+    # TODO: Remove in 0.20. Added for backward compatibility.
+    @root_validator
+    def _post_validate(cls, values):
+        cpu = values.get("cpu")
+        if isinstance(cpu, CPUSpec) and cpu.arch in [None, gpuhunt.CPUArchitecture.X86]:
+            values["cpu"] = cpu.count
+        return values
+
     def pretty_format(self) -> str:
-        resources: Dict[str, Any] = dict(cpus=self.cpu, memory=self.memory)
+        # TODO: Remove in 0.20. Use self.cpu directly
+        cpu = parse_obj_as(CPUSpec, self.cpu)
+        resources: Dict[str, Any] = dict(cpu_arch=cpu.arch, cpus=cpu.count, memory=self.memory)
         if self.gpu:
             gpu = self.gpu
             resources.update(
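
Because cpu is now a Union and _post_validate folds count-only (and explicit x86) specs back into a plain Range[int], existing clients see the old shape unchanged. A minimal sketch, assuming ResourcesSpec is imported from this module:

    from dstack._internal.core.models.resources import ResourcesSpec

    # Old-style count-only value: parses via CPUSpec, then _post_validate
    # folds it back to Range[int] since no non-x86 architecture was requested.
    spec = ResourcesSpec.parse_obj({"cpu": "4..8"})

    # An explicit non-x86 architecture survives as a CPUSpec.
    arm_spec = ResourcesSpec.parse_obj({"cpu": "arm:8.."})
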

dstack/_internal/core/models/runs.py
@@ -439,9 +439,14 @@ class Run(CoreModel):

     @root_validator
     def _error(cls, values) -> Dict:
+        try:
+            termination_reason = values["termination_reason"]
+            jobs = values["jobs"]
+        except KeyError:
+            return values
         values["error"] = _get_run_error(
-            run_termination_reason=values["termination_reason"],
-            run_jobs=values["jobs"],
+            run_termination_reason=termination_reason,
+            run_jobs=jobs,
         )
         return values

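The guard reflects pydantic v1 semantics: post root validators run even when an earlier field failed validation, in which case values lacks that field's key, so indexing it would raise a KeyError that masks the real validation error. A standalone sketch of the pattern (the model and field names here are illustrative, not from dstack):

    from pydantic import BaseModel, root_validator

    class Demo(BaseModel):
        a: int
        b: int = 0

        @root_validator
        def _derive(cls, values):
            try:
                a = values["a"]  # absent if field "a" failed validation
            except KeyError:
                return values
            values["b"] = a + 1
            return values
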
dstack/_internal/core/models/volumes.py
@@ -159,7 +159,7 @@ class VolumeMountPoint(CoreModel):
             description=(
                 "The network volume name or the list of network volume names to mount."
                 " If a list is specified, one of the volumes in the list will be mounted."
-                " Specify volumes from different backends/regions to increase availability."
+                " Specify volumes from different backends/regions to increase availability"
             )
         ),
     ]

dstack/_internal/server/background/tasks/process_fleets.py
@@ -1,15 +1,16 @@
-from sqlalchemy import select, update
+from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload

 from dstack._internal.core.models.fleets import FleetStatus
 from dstack._internal.server.db import get_session_ctx
-from dstack._internal.server.models import FleetModel, PlacementGroupModel
+from dstack._internal.server.models import FleetModel
 from dstack._internal.server.services.fleets import (
     is_fleet_empty,
     is_fleet_in_use,
 )
 from dstack._internal.server.services.locking import get_locker
+from dstack._internal.server.services.placement import schedule_fleet_placement_groups_deletion
 from dstack._internal.utils.common import get_current_datetime
 from dstack._internal.utils.logging import get_logger

@@ -68,16 +69,6 @@ async def _autodelete_fleet(session: AsyncSession, fleet_model: FleetModel):
     fleet_model.status = FleetStatus.TERMINATED
     fleet_model.deleted = True
     fleet_model.last_processed_at = get_current_datetime()
-    await _mark_placement_groups_as_ready_for_deletion(session=session, fleet_model=fleet_model)
+    await schedule_fleet_placement_groups_deletion(session=session, fleet_id=fleet_model.id)
     await session.commit()
     logger.info("Fleet %s deleted", fleet_model.name)
-
-
-async def _mark_placement_groups_as_ready_for_deletion(
-    session: AsyncSession, fleet_model: FleetModel
-):
-    await session.execute(
-        update(PlacementGroupModel)
-        .where(PlacementGroupModel.fleet_id == fleet_model.id)
-        .values(fleet_deleted=True)
-    )

dstack/_internal/server/background/tasks/process_instances.py
@@ -19,6 +19,8 @@ from dstack._internal.core.backends import (
 from dstack._internal.core.backends.base.compute import (
     ComputeWithCreateInstanceSupport,
     ComputeWithPlacementGroupSupport,
+    GoArchType,
+    generate_unique_placement_group_name,
     get_dstack_runner_binary_path,
     get_dstack_shim_binary_path,
     get_dstack_working_dir,
@@ -26,6 +28,7 @@ from dstack._internal.core.backends.base.compute import (
     get_shim_pre_start_commands,
 )
 from dstack._internal.core.backends.remote.provisioning import (
+    detect_cpu_arch,
     get_host_info,
     get_paramiko_connection,
     get_shim_healthcheck,
@@ -39,11 +42,16 @@ from dstack._internal.core.backends.remote.provisioning import (
 from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT

 # FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute
-from dstack._internal.core.errors import BackendError, NotYetTerminated, ProvisioningError
+from dstack._internal.core.errors import (
+    BackendError,
+    NotYetTerminated,
+    ProvisioningError,
+)
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.fleets import InstanceGroupPlacement
 from dstack._internal.core.models.instances import (
     InstanceAvailability,
+    InstanceOffer,
     InstanceOfferWithAvailability,
     InstanceRuntime,
     InstanceStatus,
@@ -51,7 +59,6 @@ from dstack._internal.core.models.instances import (
     SSHKey,
 )
 from dstack._internal.core.models.placement import (
-    PlacementGroup,
     PlacementGroupConfiguration,
     PlacementStrategy,
 )
@@ -89,8 +96,9 @@ from dstack._internal.server.services.instances import (
 from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.offers import is_divisible_into_blocks
 from dstack._internal.server.services.placement import (
-    get_fleet_placement_groups,
+    get_fleet_placement_group_models,
     placement_group_model_to_placement_group,
+    schedule_fleet_placement_groups_deletion,
 )
 from dstack._internal.server.services.runner import client as runner_client
 from dstack._internal.server.services.runner.client import HealthStatus
@@ -264,7 +272,7 @@ async def _add_remote(instance: InstanceModel) -> None:
         )
         deploy_timeout = 20 * 60  # 20 minutes
         result = await asyncio.wait_for(future, timeout=deploy_timeout)
-        health, host_info = result
+        health, host_info, cpu_arch = result
     except (asyncio.TimeoutError, TimeoutError) as e:
         raise ProvisioningError(f"Deploy timeout: {e}") from e
     except Exception as e:
@@ -285,7 +293,7 @@ async def _add_remote(instance: InstanceModel) -> None:
         instance.last_retry_at = get_current_datetime()
         return

-    instance_type = host_info_to_instance_type(host_info)
+    instance_type = host_info_to_instance_type(host_info, cpu_arch)
     instance_network = None
     internal_ip = None
     try:
@@ -388,7 +396,7 @@ def _deploy_instance(
     pkeys: List[PKey],
     ssh_proxy_pkeys: Optional[list[PKey]],
     authorized_keys: List[str],
-) -> Tuple[HealthStatus, Dict[str, Any]]:
+) -> Tuple[HealthStatus, Dict[str, Any], GoArchType]:
     with get_paramiko_connection(
         remote_details.ssh_user,
         remote_details.host,
@@ -399,13 +407,16 @@ def _deploy_instance(
     ) as client:
         logger.info(f"Connected to {remote_details.ssh_user} {remote_details.host}")

+        arch = detect_cpu_arch(client)
+        logger.info("%s: CPU arch is %s", remote_details.host, arch)
+
         # Execute pre start commands
-        shim_pre_start_commands = get_shim_pre_start_commands()
+        shim_pre_start_commands = get_shim_pre_start_commands(arch=arch)
         run_pre_start_commands(client, shim_pre_start_commands, authorized_keys)
         logger.debug("The script for installing dstack has been executed")

         # Upload envs
-        shim_envs = get_shim_env(authorized_keys)
+        shim_envs = get_shim_env(authorized_keys, arch=arch)
         try:
             fleet_configuration_envs = remote_details.env.as_dict()
         except ValueError as e:
@@ -440,7 +451,7 @@ def _deploy_instance(
             raise ProvisioningError("Cannot read HealthcheckResponse") from e
         health = runner_client.health_response_to_health_status(health_response)

-        return health, host_info
+        return health, host_info, arch


 async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
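
detect_cpu_arch comes from dstack/_internal/core/backends/remote/provisioning.py (+27 -2), whose hunks are not shown here. A plausible sketch, assuming it runs uname -m over the existing paramiko connection and maps the result to Go-style architecture names (the GoArchType return annotation suggests amd64/arm64):

    # Hypothetical sketch; the real implementation lives in
    # dstack/_internal/core/backends/remote/provisioning.py and is not shown in this diff.
    def detect_cpu_arch(client: paramiko.SSHClient) -> GoArchType:
        _, stdout, _ = client.exec_command("uname -m")
        machine = stdout.read().decode().strip()
        if machine in ("x86_64", "amd64"):
            return "amd64"
        if machine in ("aarch64", "arm64"):
            return "arm64"
        raise ProvisioningError(f"Unsupported CPU architecture: {machine}")
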
@@ -509,11 +520,39 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
         )
         return

+    placement_group_models = []
+    placement_group_model = None
+    if instance.fleet_id:
+        placement_group_models = await get_fleet_placement_group_models(
+            session=session,
+            fleet_id=instance.fleet_id,
+        )
+        # The placement group is determined when provisioning the master instance
+        # and used for all other instances in the fleet.
+        if not _is_fleet_master_instance(instance):
+            if placement_group_models:
+                placement_group_model = placement_group_models[0]
+            if len(placement_group_models) > 1:
+                logger.error(
+                    (
+                        "Expected 0 or 1 placement groups associated with fleet %s, found %s."
+                        " An incorrect placement group might have been selected for instance %s"
+                    ),
+                    instance.fleet_id,
+                    len(placement_group_models),
+                    instance.name,
+                )
+
     offers = await get_create_instance_offers(
         project=instance.project,
         profile=profile,
         requirements=requirements,
         fleet_model=instance.fleet,
+        placement_group=(
+            placement_group_model_to_placement_group(placement_group_model)
+            if placement_group_model
+            else None
+        ),
         blocks="auto" if instance.total_blocks is None else instance.total_blocks,
         exclude_not_available=True,
     )
@@ -527,12 +566,6 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
         )
         return

-    placement_groups = []
-    if instance.fleet_id:
-        placement_groups = await get_fleet_placement_groups(
-            session=session, fleet_id=instance.fleet_id
-        )
-
     # Limit number of offers tried to prevent long-running processing
     # in case all offers fail.
     for backend, instance_offer in offers[: server_settings.MAX_OFFERS_TRIED]:
@@ -542,25 +575,28 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
         assert isinstance(compute, ComputeWithCreateInstanceSupport)
         instance_offer = _get_instance_offer_for_instance(instance_offer, instance)
         if (
-            instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT
+            _is_fleet_master_instance(instance)
+            and instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT
             and instance.fleet
-            and instance_configuration.placement_group_name
+            and _is_cloud_cluster(instance.fleet)
         ):
             assert isinstance(compute, ComputeWithPlacementGroupSupport)
-            placement_group_model = _create_placement_group_if_does_not_exist(
-                session=session,
-                fleet_model=instance.fleet,
-                placement_groups=placement_groups,
-                name=instance_configuration.placement_group_name,
-                backend=instance_offer.backend,
-                region=instance_offer.region,
+            placement_group_model = _find_suitable_placement_group(
+                placement_groups=placement_group_models,
+                instance_offer=instance_offer,
+                compute=compute,
             )
-            if placement_group_model is not None:
-                placement_group = placement_group_model_to_placement_group(placement_group_model)
-                pgpd = await run_async(compute.create_placement_group, placement_group)
-                placement_group_model.provisioning_data = pgpd.json()
+            if placement_group_model is None:
+                placement_group_model = await _create_placement_group(
+                    fleet_model=instance.fleet,
+                    master_instance_offer=instance_offer,
+                    compute=compute,
+                )
+                if placement_group_model is None:  # error occurred
+                    continue
                 session.add(placement_group_model)
-                placement_groups.append(placement_group)
+                await session.flush()
+                placement_group_models.append(placement_group_model)
         logger.debug(
             "Trying %s in %s/%s for $%0.4f per hour",
             instance_offer.instance.name,
@@ -573,6 +609,11 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
                 compute.create_instance,
                 instance_offer,
                 instance_configuration,
+                (
+                    placement_group_model_to_placement_group(placement_group_model)
+                    if placement_group_model
+                    else None
+                ),
             )
         except BackendError as e:
             logger.warning(
@@ -612,22 +653,46 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
                 "instance_status": InstanceStatus.PROVISIONING.value,
             },
         )
+        if instance.fleet_id and _is_fleet_master_instance(instance):
+            # Clean up placement groups that did not end up being used
+            await schedule_fleet_placement_groups_deletion(
+                session=session,
+                fleet_id=instance.fleet_id,
+                except_placement_group_ids=(
+                    [placement_group_model.id] if placement_group_model is not None else []
+                ),
+            )
         return

     instance.last_retry_at = get_current_datetime()

     if not should_retry:
-        instance.status = InstanceStatus.TERMINATED
-        instance.termination_reason = "All offers failed" if offers else "No offers found"
-        logger.info(
-            "Terminated instance %s: %s",
-            instance.name,
-            instance.termination_reason,
-            extra={
-                "instance_name": instance.name,
-                "instance_status": InstanceStatus.TERMINATED.value,
-            },
-        )
+        _mark_terminated(instance, "All offers failed" if offers else "No offers found")
+        if (
+            instance.fleet
+            and _is_fleet_master_instance(instance)
+            and _is_cloud_cluster(instance.fleet)
+        ):
+            # Do not attempt to deploy other instances, as they won't determine the correct cluster
+            # backend, region, and placement group without a successfully deployed master instance
+            for sibling_instance in instance.fleet.instances:
+                if sibling_instance.id == instance.id:
+                    continue
+                _mark_terminated(sibling_instance, "Master instance failed to start")
+
+
+def _mark_terminated(instance: InstanceModel, termination_reason: str) -> None:
+    instance.status = InstanceStatus.TERMINATED
+    instance.termination_reason = termination_reason
+    logger.info(
+        "Terminated instance %s: %s",
+        instance.name,
+        instance.termination_reason,
+        extra={
+            "instance_name": instance.name,
+            "instance_status": InstanceStatus.TERMINATED.value,
+        },
+    )


 async def _check_instance(instance: InstanceModel) -> None:
@@ -906,12 +971,20 @@ def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
     if instance.fleet is None:
         return False
     if (
-        instance.id == instance.fleet.instances[0].id
+        _is_fleet_master_instance(instance)
         or instance.fleet.instances[0].job_provisioning_data is not None
         or instance.fleet.instances[0].status == InstanceStatus.TERMINATED
     ):
         return False
-    fleet = fleet_model_to_fleet(instance.fleet)
+    return _is_cloud_cluster(instance.fleet)
+
+
+def _is_fleet_master_instance(instance: InstanceModel) -> bool:
+    return instance.fleet is not None and instance.id == instance.fleet.instances[0].id
+
+
+def _is_cloud_cluster(fleet_model: FleetModel) -> bool:
+    fleet = fleet_model_to_fleet(fleet_model)
     return (
         fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER
         and fleet.spec.configuration.ssh_config is None
@@ -944,28 +1017,76 @@ def _get_instance_offer_for_instance(
     return instance_offer


-def _create_placement_group_if_does_not_exist(
-    session: AsyncSession,
-    fleet_model: FleetModel,
-    placement_groups: List[PlacementGroup],
-    name: str,
-    backend: BackendType,
-    region: str,
+def _find_suitable_placement_group(
+    placement_groups: List[PlacementGroupModel],
+    instance_offer: InstanceOffer,
+    compute: ComputeWithPlacementGroupSupport,
 ) -> Optional[PlacementGroupModel]:
     for pg in placement_groups:
-        if pg.configuration.backend == backend and pg.configuration.region == region:
-            return None
+        if compute.is_suitable_placement_group(
+            placement_group_model_to_placement_group(pg), instance_offer
+        ):
+            return pg
+    return None
+
+
+async def _create_placement_group(
+    fleet_model: FleetModel,
+    master_instance_offer: InstanceOffer,
+    compute: ComputeWithPlacementGroupSupport,
+) -> Optional[PlacementGroupModel]:
     placement_group_model = PlacementGroupModel(
-        name=name,
+        # TODO: generate the name in Compute.create_placement_group to allow
+        # backend-specific name length limits
+        name=generate_unique_placement_group_name(
+            project_name=fleet_model.project.name,
+            fleet_name=fleet_model.name,
+        ),
         project=fleet_model.project,
         fleet=fleet_model,
         configuration=PlacementGroupConfiguration(
-            backend=backend,
-            region=region,
+            backend=master_instance_offer.backend,
+            region=master_instance_offer.region,
             placement_strategy=PlacementStrategy.CLUSTER,
         ).json(),
     )
-    session.add(placement_group_model)
+    placement_group = placement_group_model_to_placement_group(placement_group_model)
+    logger.debug(
+        "Creating placement group %s in %s/%s",
+        placement_group.name,
+        placement_group.configuration.backend.value,
+        placement_group.configuration.region,
+    )
+    try:
+        pgpd = await run_async(
+            compute.create_placement_group,
+            placement_group_model_to_placement_group(placement_group_model),
+            master_instance_offer,
+        )
+    except BackendError as e:
+        logger.warning(
+            "Failed to create placement group %s in %s/%s: %r",
+            placement_group.name,
+            placement_group.configuration.backend.value,
+            placement_group.configuration.region,
+            e,
+        )
+        return None
+    except Exception:
+        logger.exception(
+            "Got exception when creating placement group %s in %s/%s",
+            placement_group.name,
+            placement_group.configuration.backend.value,
+            placement_group.configuration.region,
+        )
+        return None
+    logger.info(
+        "Created placement group %s in %s/%s",
+        placement_group.name,
+        placement_group.configuration.backend.value,
+        placement_group.configuration.region,
+    )
+    placement_group_model.provisioning_data = pgpd.json()
     return placement_group_model

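_find_suitable_placement_group delegates the matching to the backend via the new Compute.is_suitable_placement_group hook (declared in base/compute.py, +101 -27, not shown in this diff). Judging by the inline check it replaces, a minimal implementation would compare backend and region; the sketch below is an assumption for illustration, not the actual base-class code:

    # Hypothetical sketch of the new hook; the actual logic is in
    # dstack/_internal/core/backends/base/compute.py and may be backend-specific.
    def is_suitable_placement_group(
        self,
        placement_group: PlacementGroup,
        instance_offer: InstanceOffer,
    ) -> bool:
        return (
            placement_group.configuration.backend == instance_offer.backend
            and placement_group.configuration.region == instance_offer.region
        )
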
dstack/_internal/server/background/tasks/process_placement_groups.py
@@ -66,7 +66,7 @@ async def _delete_placement_groups(


 async def _delete_placement_group(placement_group_model: PlacementGroupModel):
-    logger.info("Deleting placement group %s", placement_group_model.name)
+    logger.debug("Deleting placement group %s", placement_group_model.name)
     placement_group = placement_group_model_to_placement_group(placement_group_model)
     if placement_group.provisioning_data is None:
         logger.error(

dstack/_internal/server/background/tasks/process_prometheus_metrics.py
@@ -99,11 +99,14 @@ async def _collect_jobs_metrics(job_models: list[JobModel], collected_at: datetime):


 async def _collect_job_metrics(job_model: JobModel) -> Optional[str]:
-    ssh_private_keys = get_instance_ssh_private_keys(get_or_error(job_model.instance))
     jpd = get_job_provisioning_data(job_model)
-    jrd = get_job_runtime_data(job_model)
     if jpd is None:
         return None
+    if not jpd.dockerized:
+        # Container-based backend, no shim
+        return None
+    ssh_private_keys = get_instance_ssh_private_keys(get_or_error(job_model.instance))
+    jrd = get_job_runtime_data(job_model)
     try:
         res = await run_async(
             _pull_job_metrics,

dstack/_internal/server/models.py
@@ -659,6 +659,7 @@ class PlacementGroupModel(BaseModel):

     fleet_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("fleets.id"))
     fleet: Mapped["FleetModel"] = relationship(foreign_keys=[fleet_id])
+    # TODO: rename `fleet_deleted` -> `to_be_deleted`
     fleet_deleted: Mapped[bool] = mapped_column(Boolean, default=False)

     created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime)