dstack 0.18.40rc1__py3-none-any.whl → 0.18.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. dstack/_internal/cli/commands/apply.py +8 -5
  2. dstack/_internal/cli/services/configurators/base.py +4 -2
  3. dstack/_internal/cli/services/configurators/fleet.py +21 -9
  4. dstack/_internal/cli/services/configurators/gateway.py +15 -0
  5. dstack/_internal/cli/services/configurators/run.py +6 -5
  6. dstack/_internal/cli/services/configurators/volume.py +15 -0
  7. dstack/_internal/cli/services/repos.py +3 -3
  8. dstack/_internal/cli/utils/fleet.py +44 -33
  9. dstack/_internal/cli/utils/run.py +27 -7
  10. dstack/_internal/cli/utils/volume.py +21 -9
  11. dstack/_internal/core/backends/aws/compute.py +92 -52
  12. dstack/_internal/core/backends/aws/resources.py +22 -12
  13. dstack/_internal/core/backends/azure/compute.py +2 -0
  14. dstack/_internal/core/backends/base/compute.py +20 -2
  15. dstack/_internal/core/backends/gcp/compute.py +30 -23
  16. dstack/_internal/core/backends/gcp/resources.py +0 -15
  17. dstack/_internal/core/backends/oci/compute.py +10 -5
  18. dstack/_internal/core/backends/oci/resources.py +23 -26
  19. dstack/_internal/core/backends/remote/provisioning.py +65 -27
  20. dstack/_internal/core/backends/runpod/compute.py +1 -0
  21. dstack/_internal/core/models/backends/azure.py +3 -1
  22. dstack/_internal/core/models/configurations.py +24 -1
  23. dstack/_internal/core/models/fleets.py +46 -0
  24. dstack/_internal/core/models/instances.py +5 -1
  25. dstack/_internal/core/models/pools.py +4 -1
  26. dstack/_internal/core/models/profiles.py +10 -4
  27. dstack/_internal/core/models/runs.py +20 -0
  28. dstack/_internal/core/models/volumes.py +3 -0
  29. dstack/_internal/core/services/ssh/attach.py +92 -53
  30. dstack/_internal/core/services/ssh/tunnel.py +58 -31
  31. dstack/_internal/proxy/gateway/routers/registry.py +2 -0
  32. dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
  33. dstack/_internal/proxy/gateway/services/registry.py +4 -0
  34. dstack/_internal/proxy/lib/models.py +3 -0
  35. dstack/_internal/proxy/lib/services/service_connection.py +8 -1
  36. dstack/_internal/server/background/tasks/process_instances.py +72 -33
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -9
  38. dstack/_internal/server/background/tasks/process_running_jobs.py +73 -26
  39. dstack/_internal/server/background/tasks/process_runs.py +2 -12
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +109 -42
  41. dstack/_internal/server/background/tasks/process_terminating_jobs.py +1 -1
  42. dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
  43. dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
  44. dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
  45. dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
  46. dstack/_internal/server/models.py +10 -4
  47. dstack/_internal/server/routers/runs.py +1 -0
  48. dstack/_internal/server/schemas/runner.py +1 -0
  49. dstack/_internal/server/services/backends/configurators/azure.py +34 -8
  50. dstack/_internal/server/services/config.py +9 -0
  51. dstack/_internal/server/services/fleets.py +27 -2
  52. dstack/_internal/server/services/gateways/client.py +9 -1
  53. dstack/_internal/server/services/jobs/__init__.py +215 -43
  54. dstack/_internal/server/services/jobs/configurators/base.py +47 -2
  55. dstack/_internal/server/services/offers.py +91 -5
  56. dstack/_internal/server/services/pools.py +95 -11
  57. dstack/_internal/server/services/proxy/repo.py +17 -3
  58. dstack/_internal/server/services/runner/client.py +1 -1
  59. dstack/_internal/server/services/runner/ssh.py +33 -5
  60. dstack/_internal/server/services/runs.py +48 -179
  61. dstack/_internal/server/services/services/__init__.py +9 -1
  62. dstack/_internal/server/statics/index.html +1 -1
  63. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
  64. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
  65. dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
  66. dstack/_internal/server/testing/common.py +117 -52
  67. dstack/_internal/utils/common.py +22 -8
  68. dstack/_internal/utils/env.py +14 -0
  69. dstack/_internal/utils/ssh.py +1 -1
  70. dstack/api/server/_fleets.py +25 -1
  71. dstack/api/server/_runs.py +23 -2
  72. dstack/api/server/_volumes.py +12 -1
  73. dstack/version.py +1 -1
  74. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/METADATA +1 -1
  75. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/RECORD +98 -89
  76. tests/_internal/cli/services/configurators/test_profile.py +3 -3
  77. tests/_internal/core/services/ssh/test_tunnel.py +56 -4
  78. tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
  79. tests/_internal/server/background/tasks/test_process_instances.py +138 -20
  80. tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
  81. tests/_internal/server/background/tasks/test_process_running_jobs.py +192 -0
  82. tests/_internal/server/background/tasks/test_process_runs.py +27 -3
  83. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +48 -3
  84. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +126 -13
  85. tests/_internal/server/routers/test_fleets.py +15 -2
  86. tests/_internal/server/routers/test_pools.py +6 -0
  87. tests/_internal/server/routers/test_runs.py +27 -0
  88. tests/_internal/server/services/jobs/__init__.py +0 -0
  89. tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
  90. tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
  91. tests/_internal/server/services/test_pools.py +4 -0
  92. tests/_internal/server/services/test_runs.py +5 -41
  93. tests/_internal/utils/test_common.py +21 -0
  94. tests/_internal/utils/test_env.py +38 -0
  95. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/LICENSE.md +0 -0
  96. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/WHEEL +0 -0
  97. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/entry_points.txt +0 -0
  98. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/top_level.txt +0 -0
@@ -109,34 +109,21 @@ class AWSCompute(Compute):
109
109
  configurable_disk_size=CONFIGURABLE_DISK_SIZE,
110
110
  extra_filter=filter,
111
111
  )
112
- regions = set(i.region for i in offers)
113
-
114
- def get_quotas(client: botocore.client.BaseClient) -> Dict[str, int]:
115
- region_quotas = {}
116
- for page in client.get_paginator("list_service_quotas").paginate(ServiceCode="ec2"):
117
- for q in page["Quotas"]:
118
- if "On-Demand" in q["QuotaName"]:
119
- region_quotas[q["UsageMetric"]["MetricDimensions"]["Class"]] = q["Value"]
120
- return region_quotas
121
-
122
- quotas = {}
123
- with ThreadPoolExecutor(max_workers=8) as executor:
124
- future_to_region = {}
125
- for region in regions:
126
- future = executor.submit(
127
- get_quotas, self.session.client("service-quotas", region_name=region)
128
- )
129
- future_to_region[future] = region
130
- for future in as_completed(future_to_region):
131
- quotas[future_to_region[future]] = future.result()
112
+ regions = list(set(i.region for i in offers))
113
+ regions_to_quotas = _get_regions_to_quotas(self.session, regions)
114
+ regions_to_zones = _get_regions_to_zones(self.session, regions)
132
115
 
133
116
  availability_offers = []
134
117
  for offer in offers:
135
118
  availability = InstanceAvailability.UNKNOWN
136
- if not _has_quota(quotas[offer.region], offer.instance.name):
119
+ if not _has_quota(regions_to_quotas[offer.region], offer.instance.name):
137
120
  availability = InstanceAvailability.NO_QUOTA
138
121
  availability_offers.append(
139
- InstanceOfferWithAvailability(**offer.dict(), availability=availability)
122
+ InstanceOfferWithAvailability(
123
+ **offer.dict(),
124
+ availability=availability,
125
+ availability_zones=regions_to_zones[offer.region],
126
+ )
140
127
  )
141
128
  return availability_offers
142
129
 
@@ -161,9 +148,9 @@ class AWSCompute(Compute):
161
148
  ec2_resource = self.session.resource("ec2", region_name=instance_offer.region)
162
149
  ec2_client = self.session.client("ec2", region_name=instance_offer.region)
163
150
  allocate_public_ip = self.config.allocate_public_ips
164
- availability_zones = None
165
- if instance_config.availability_zone is not None:
166
- availability_zones = [instance_config.availability_zone]
151
+ zones = instance_offer.availability_zones
152
+ if zones is not None and len(zones) == 0:
153
+ raise NoCapacityError("No eligible availability zones")
167
154
 
168
155
  tags = {
169
156
  "Name": instance_config.instance_name,
@@ -174,7 +161,7 @@ class AWSCompute(Compute):
174
161
  tags = merge_tags(tags=tags, backend_tags=self.config.tags)
175
162
 
176
163
  disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
177
- max_efa_interfaces = get_maximum_efa_interfaces(
164
+ max_efa_interfaces = _get_maximum_efa_interfaces(
178
165
  ec2_client=ec2_client, instance_type=instance_offer.instance.name
179
166
  )
180
167
  enable_efa = max_efa_interfaces > 0
@@ -185,7 +172,7 @@ class AWSCompute(Compute):
185
172
  config=self.config,
186
173
  region=instance_offer.region,
187
174
  allocate_public_ip=allocate_public_ip,
188
- availability_zones=availability_zones,
175
+ availability_zones=zones,
189
176
  )
190
177
  subnet_id_to_az_map = aws_resources.get_subnets_availability_zones(
191
178
  ec2_client=ec2_client,
@@ -210,11 +197,11 @@ class AWSCompute(Compute):
210
197
  except botocore.exceptions.ClientError as e:
211
198
  logger.warning("Got botocore.exceptions.ClientError: %s", e)
212
199
  raise NoCapacityError()
213
- tried_availability_zones = set()
200
+ tried_zones = set()
214
201
  for subnet_id, az in subnet_id_to_az_map.items():
215
- if az in tried_availability_zones:
202
+ if az in tried_zones:
216
203
  continue
217
- tried_availability_zones.add(az)
204
+ tried_zones.add(az)
218
205
  try:
219
206
  logger.debug("Trying provisioning %s in %s", instance_offer.instance.name, az)
220
207
  image_id, username = aws_resources.get_image_id_and_username(
@@ -240,6 +227,7 @@ class AWSCompute(Compute):
240
227
  allocate_public_ip=allocate_public_ip,
241
228
  placement_group_name=instance_config.placement_group_name,
242
229
  enable_efa=enable_efa,
230
+ max_efa_interfaces=max_efa_interfaces,
243
231
  reservation_id=instance_config.reservation,
244
232
  is_capacity_block=is_capacity_block,
245
233
  )
@@ -283,6 +271,7 @@ class AWSCompute(Compute):
283
271
  project_ssh_private_key: str,
284
272
  volumes: List[Volume],
285
273
  ) -> JobProvisioningData:
274
+ # TODO: run_job is the same for vm-based backends, refactor
286
275
  instance_config = InstanceConfiguration(
287
276
  project_name=run.project_name,
288
277
  instance_name=get_instance_name(run, job), # TODO: generate name
@@ -290,15 +279,25 @@ class AWSCompute(Compute):
290
279
  SSHKey(public=project_ssh_public_key.strip()),
291
280
  ],
292
281
  user=run.user,
282
+ volumes=volumes,
293
283
  reservation=run.run_spec.configuration.reservation,
294
284
  )
285
+ instance_offer = instance_offer.copy()
295
286
  if len(volumes) > 0:
296
287
  volume = volumes[0]
297
288
  if (
298
289
  volume.provisioning_data is not None
299
290
  and volume.provisioning_data.availability_zone is not None
300
291
  ):
301
- instance_config.availability_zone = volume.provisioning_data.availability_zone
292
+ if instance_offer.availability_zones is None:
293
+ instance_offer.availability_zones = [
294
+ volume.provisioning_data.availability_zone
295
+ ]
296
+ instance_offer.availability_zones = [
297
+ z
298
+ for z in instance_offer.availability_zones
299
+ if z == volume.provisioning_data.availability_zone
300
+ ]
302
301
  return self.create_instance(instance_offer, instance_config)
303
302
 
304
303
  def create_placement_group(
@@ -544,14 +543,16 @@ class AWSCompute(Compute):
544
543
  }
545
544
  tags = merge_tags(tags=tags, backend_tags=self.config.tags)
546
545
 
547
- zone = aws_resources.get_availability_zone(
546
+ zones = aws_resources.get_availability_zones(
548
547
  ec2_client=ec2_client, region=volume.configuration.region
549
548
  )
550
- if zone is None:
549
+ if volume.configuration.availability_zone is not None:
550
+ zones = [z for z in zones if z == volume.configuration.availability_zone]
551
+ if len(zones) == 0:
551
552
  raise ComputeError(
552
553
  f"Failed to find availability zone in region {volume.configuration.region}"
553
554
  )
554
-
555
+ zone = zones[0]
555
556
  volume_type = "gp3"
556
557
 
557
558
  logger.debug("Creating EBS volume %s", volume.configuration.name)
@@ -570,7 +571,6 @@ class AWSCompute(Compute):
570
571
 
571
572
  size = response["Size"]
572
573
  iops = response["Iops"]
573
-
574
574
  return VolumeProvisioningData(
575
575
  backend=BackendType.AWS,
576
576
  volume_id=response["VolumeId"],
@@ -672,23 +672,6 @@ class AWSCompute(Compute):
672
672
  return True
673
673
 
674
674
 
675
- def get_maximum_efa_interfaces(ec2_client: botocore.client.BaseClient, instance_type: str) -> int:
676
- try:
677
- response = ec2_client.describe_instance_types(
678
- InstanceTypes=[instance_type],
679
- Filters=[{"Name": "network-info.efa-supported", "Values": ["true"]}],
680
- )
681
- except botocore.exceptions.ClientError as e:
682
- if e.response.get("Error", {}).get("Code") == "InvalidInstanceType":
683
- # "The following supplied instance types do not exist: [<instance_type>]"
684
- return 0
685
- raise
686
- instance_types = response["InstanceTypes"]
687
- if not instance_types:
688
- return 0
689
- return instance_types[0]["NetworkInfo"]["EfaInfo"]["MaximumEfaInterfaces"]
690
-
691
-
692
675
  def get_vpc_id_subnet_id_or_error(
693
676
  ec2_client: botocore.client.BaseClient,
694
677
  config: AWSConfig,
@@ -770,6 +753,30 @@ def _get_vpc_id_subnet_id_by_vpc_name_or_error(
770
753
  )
771
754
 
772
755
 
756
+ def _get_regions_to_quotas(
757
+ session: boto3.Session, regions: List[str]
758
+ ) -> Dict[str, Dict[str, int]]:
759
+ def get_region_quotas(client: botocore.client.BaseClient) -> Dict[str, int]:
760
+ region_quotas = {}
761
+ for page in client.get_paginator("list_service_quotas").paginate(ServiceCode="ec2"):
762
+ for q in page["Quotas"]:
763
+ if "On-Demand" in q["QuotaName"]:
764
+ region_quotas[q["UsageMetric"]["MetricDimensions"]["Class"]] = q["Value"]
765
+ return region_quotas
766
+
767
+ regions_to_quotas = {}
768
+ with ThreadPoolExecutor(max_workers=8) as executor:
769
+ future_to_region = {}
770
+ for region in regions:
771
+ future = executor.submit(
772
+ get_region_quotas, session.client("service-quotas", region_name=region)
773
+ )
774
+ future_to_region[future] = region
775
+ for future in as_completed(future_to_region):
776
+ regions_to_quotas[future_to_region[future]] = future.result()
777
+ return regions_to_quotas
778
+
779
+
773
780
  def _has_quota(quotas: Dict[str, int], instance_name: str) -> bool:
774
781
  if instance_name.startswith("p"):
775
782
  return quotas.get("P/OnDemand", 0) > 0
@@ -778,6 +785,22 @@ def _has_quota(quotas: Dict[str, int], instance_name: str) -> bool:
778
785
  return quotas.get("Standard/OnDemand", 0) > 0
779
786
 
780
787
 
788
+ def _get_regions_to_zones(session: boto3.Session, regions: List[str]) -> Dict[str, List[str]]:
789
+ regions_to_zones = {}
790
+ with ThreadPoolExecutor(max_workers=8) as executor:
791
+ future_to_region = {}
792
+ for region in regions:
793
+ future = executor.submit(
794
+ aws_resources.get_availability_zones,
795
+ session.client("ec2", region_name=region),
796
+ region,
797
+ )
798
+ future_to_region[future] = region
799
+ for future in as_completed(future_to_region):
800
+ regions_to_zones[future_to_region[future]] = future.result()
801
+ return regions_to_zones
802
+
803
+
781
804
  def _supported_instances(offer: InstanceOffer) -> bool:
782
805
  for family in [
783
806
  "t2.small",
@@ -798,6 +821,23 @@ def _supported_instances(offer: InstanceOffer) -> bool:
798
821
  return False
799
822
 
800
823
 
824
+ def _get_maximum_efa_interfaces(ec2_client: botocore.client.BaseClient, instance_type: str) -> int:
825
+ try:
826
+ response = ec2_client.describe_instance_types(
827
+ InstanceTypes=[instance_type],
828
+ Filters=[{"Name": "network-info.efa-supported", "Values": ["true"]}],
829
+ )
830
+ except botocore.exceptions.ClientError as e:
831
+ if e.response.get("Error", {}).get("Code") == "InvalidInstanceType":
832
+ # "The following supplied instance types do not exist: [<instance_type>]"
833
+ return 0
834
+ raise
835
+ instance_types = response["InstanceTypes"]
836
+ if not instance_types:
837
+ return 0
838
+ return instance_types[0]["NetworkInfo"]["EfaInfo"]["MaximumEfaInterfaces"]
839
+
840
+
801
841
  def _get_instance_ip(instance: Any, public_ip: bool) -> str:
802
842
  if public_ip:
803
843
  return instance.public_ip_address
@@ -140,6 +140,7 @@ def create_instances_struct(
140
140
  allocate_public_ip: bool = True,
141
141
  placement_group_name: Optional[str] = None,
142
142
  enable_efa: bool = False,
143
+ max_efa_interfaces: int = 0,
143
144
  reservation_id: Optional[str] = None,
144
145
  is_capacity_block: bool = False,
145
146
  ) -> Dict[str, Any]:
@@ -183,7 +184,7 @@ def create_instances_struct(
183
184
  # AWS allows specifying either NetworkInterfaces for specific subnet_id
184
185
  # or instance-level SecurityGroupIds in case of no specific subnet_id, not both.
185
186
  if subnet_id is not None:
186
- # Even if the instance type supports multiple cards, we always request only one interface
187
+ # If the instance type supports multiple cards, we request multiple interfaces only if not allocate_public_ip
187
188
  # due to the limitation: "AssociatePublicIpAddress [...] You cannot specify more than one
188
189
  # network interface in the request".
189
190
  # Error message: "(InvalidParameterCombination) when calling the RunInstances operation:
@@ -199,9 +200,28 @@ def create_instances_struct(
199
200
  "DeviceIndex": 0,
200
201
  "SubnetId": subnet_id,
201
202
  "Groups": [security_group_id],
202
- "InterfaceType": "efa" if enable_efa else "interface",
203
+ "InterfaceType": "efa" if max_efa_interfaces > 0 else "interface",
203
204
  },
204
205
  ]
206
+
207
+ if max_efa_interfaces > 1 and allocate_public_ip is False:
208
+ for i in range(1, max_efa_interfaces):
209
+ # Set to efa-only to use interfaces exclusively for GPU-to-GPU communication
210
+ interface_type = "efa-only"
211
+ if instance_type == "p5.48xlarge":
212
+ # EFA configuration for P5 instances:
213
+ # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/efa-acc-inst-types.html#efa-for-p5
214
+ interface_type = "efa" if i % 4 == 0 else "efa-only"
215
+ struct["NetworkInterfaces"].append(
216
+ {
217
+ "AssociatePublicIpAddress": allocate_public_ip,
218
+ "NetworkCardIndex": i,
219
+ "DeviceIndex": 1,
220
+ "SubnetId": subnet_id,
221
+ "Groups": [security_group_id],
222
+ "InterfaceType": interface_type,
223
+ }
224
+ )
205
225
  else:
206
226
  struct["SecurityGroupIds"] = [security_group_id]
207
227
 
@@ -370,16 +390,6 @@ def get_subnets_ids_for_vpc(
370
390
  return subnets_ids
371
391
 
372
392
 
373
- def get_availability_zone(ec2_client: botocore.client.BaseClient, region: str) -> Optional[str]:
374
- zone_names = get_availability_zones(
375
- ec2_client=ec2_client,
376
- region=region,
377
- )
378
- if len(zone_names) == 0:
379
- return None
380
- return zone_names[0]
381
-
382
-
383
393
  def get_availability_zones(ec2_client: botocore.client.BaseClient, region: str) -> List[str]:
384
394
  response = ec2_client.describe_availability_zones(
385
395
  Filters=[
@@ -133,6 +133,8 @@ class AzureCompute(Compute):
133
133
  }
134
134
  tags = merge_tags(tags=tags, backend_tags=self.config.tags)
135
135
 
136
+ # TODO: Support custom availability_zones.
137
+ # Currently, VMs are regional, which means they don't have zone info.
136
138
  vm = _launch_instance(
137
139
  compute_client=self._compute_client,
138
140
  subscription_id=self.config.subscription_id,
@@ -371,7 +371,16 @@ def get_docker_commands(
371
371
  "rm -rf /run/sshd && mkdir -p /run/sshd && chown root:root /run/sshd",
372
372
  "rm -rf /var/empty && mkdir -p /var/empty && chown root:root /var/empty",
373
373
  # start sshd
374
- f"/usr/sbin/sshd -p {DSTACK_RUNNER_SSH_PORT} -o PidFile=none -o PasswordAuthentication=no -o AllowTcpForwarding=yes -o PermitUserEnvironment=yes",
374
+ (
375
+ "/usr/sbin/sshd"
376
+ f" -p {DSTACK_RUNNER_SSH_PORT}"
377
+ " -o PidFile=none"
378
+ " -o PasswordAuthentication=no"
379
+ " -o AllowTcpForwarding=yes"
380
+ " -o PermitUserEnvironment=yes"
381
+ " -o ClientAliveInterval=30"
382
+ " -o ClientAliveCountMax=4"
383
+ ),
375
384
  # restore ld.so variables
376
385
  'if [ -n "$_LD_LIBRARY_PATH" ]; then export LD_LIBRARY_PATH="$_LD_LIBRARY_PATH"; fi',
377
386
  'if [ -n "$_LD_PRELOAD" ]; then export LD_PRELOAD="$_LD_PRELOAD"; fi',
@@ -381,7 +390,16 @@ def get_docker_commands(
381
390
  commands += [
382
391
  f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {DSTACK_RUNNER_BINARY_PATH} {url}",
383
392
  f"chmod +x {DSTACK_RUNNER_BINARY_PATH}",
384
- f"{DSTACK_RUNNER_BINARY_PATH} --log-level 6 start --http-port {DSTACK_RUNNER_HTTP_PORT} --temp-dir /tmp/runner --home-dir /root --working-dir /workflow",
393
+ (
394
+ f"{DSTACK_RUNNER_BINARY_PATH}"
395
+ " --log-level 6"
396
+ " start"
397
+ f" --http-port {DSTACK_RUNNER_HTTP_PORT}"
398
+ f" --ssh-port {DSTACK_RUNNER_SSH_PORT}"
399
+ " --temp-dir /tmp/runner"
400
+ " --home-dir /root"
401
+ " --working-dir /workflow"
402
+ ),
385
403
  ]
386
404
 
387
405
  return commands
@@ -94,21 +94,25 @@ class GCPCompute(Compute):
94
94
  for quota in region.quotas:
95
95
  quotas[region.name][quota.metric] = quota.limit - quota.usage
96
96
 
97
- seen_region_offers = set()
97
+ offer_keys_to_offers = {}
98
98
  offers_with_availability = []
99
99
  for offer in offers:
100
100
  region = offer.region[:-2] # strip zone
101
101
  key = (_unique_instance_name(offer.instance), region)
102
- if key in seen_region_offers:
102
+ if key in offer_keys_to_offers:
103
+ offer_keys_to_offers[key].availability_zones.append(offer.region)
103
104
  continue
104
- seen_region_offers.add(key)
105
105
  availability = InstanceAvailability.NO_QUOTA
106
106
  if _has_gpu_quota(quotas[region], offer.instance.resources):
107
107
  availability = InstanceAvailability.UNKNOWN
108
108
  # todo quotas: cpu, memory, global gpu, tpu
109
- offers_with_availability.append(
110
- InstanceOfferWithAvailability(**offer.dict(), availability=availability)
109
+ offer_with_availability = InstanceOfferWithAvailability(
110
+ **offer.dict(),
111
+ availability=availability,
112
+ availability_zones=[offer.region],
111
113
  )
114
+ offer_keys_to_offers[key] = offer_with_availability
115
+ offers_with_availability.append(offer_with_availability)
112
116
  offers_with_availability[-1].region = region
113
117
 
114
118
  return offers_with_availability
@@ -156,10 +160,10 @@ class GCPCompute(Compute):
156
160
  )
157
161
  authorized_keys = instance_config.get_public_keys()
158
162
 
159
- zones = _get_instance_zones(instance_offer)
160
- if instance_config.availability_zone:
161
- zones = [z for z in zones if z == instance_config.availability_zone]
162
-
163
+ # get_offers always fills instance_offer.availability_zones
164
+ zones = get_or_error(instance_offer.availability_zones)
165
+ if len(zones) == 0:
166
+ raise NoCapacityError("No eligible availability zones")
163
167
  # If a shared VPC is not used, we can create firewall rules for user
164
168
  if self.config.vpc_project_id is None:
165
169
  gcp_resources.create_runner_firewall_rules(
@@ -371,6 +375,7 @@ class GCPCompute(Compute):
371
375
  project_ssh_private_key: str,
372
376
  volumes: List[Volume],
373
377
  ) -> JobProvisioningData:
378
+ # TODO: run_job is the same for vm-based backends, refactor
374
379
  instance_config = InstanceConfiguration(
375
380
  project_name=run.project_name,
376
381
  instance_name=get_instance_name(run, job), # TODO: generate name
@@ -379,14 +384,24 @@ class GCPCompute(Compute):
379
384
  ],
380
385
  user=run.user,
381
386
  volumes=volumes,
387
+ reservation=run.run_spec.configuration.reservation,
382
388
  )
389
+ instance_offer = instance_offer.copy()
383
390
  if len(volumes) > 0:
384
391
  volume = volumes[0]
385
392
  if (
386
393
  volume.provisioning_data is not None
387
394
  and volume.provisioning_data.availability_zone is not None
388
395
  ):
389
- instance_config.availability_zone = volume.provisioning_data.availability_zone
396
+ if instance_offer.availability_zones is None:
397
+ instance_offer.availability_zones = [
398
+ volume.provisioning_data.availability_zone
399
+ ]
400
+ instance_offer.availability_zones = [
401
+ z
402
+ for z in instance_offer.availability_zones
403
+ if z == volume.provisioning_data.availability_zone
404
+ ]
390
405
  return self.create_instance(instance_offer, instance_config)
391
406
 
392
407
  def create_gateway(
@@ -497,15 +512,18 @@ class GCPCompute(Compute):
497
512
  raise ComputeError(f"Persistent disk {volume.configuration.volume_id} not found")
498
513
 
499
514
  def create_volume(self, volume: Volume) -> VolumeProvisioningData:
500
- zone = gcp_resources.get_availability_zone(
515
+ zones = gcp_resources.get_availability_zones(
501
516
  regions_client=self.regions_client,
502
517
  project_id=self.config.project_id,
503
518
  region=volume.configuration.region,
504
519
  )
505
- if zone is None:
520
+ if volume.configuration.availability_zone is not None:
521
+ zones = [z for z in zones if z == volume.configuration.availability_zone]
522
+ if len(zones) == 0:
506
523
  raise ComputeError(
507
524
  f"Failed to find availability zone in region {volume.configuration.region}"
508
525
  )
526
+ zone = zones[0]
509
527
 
510
528
  labels = {
511
529
  "owner": "dstack",
@@ -759,17 +777,6 @@ def _unique_instance_name(instance: InstanceType) -> str:
759
777
  return f"{name}-{gpu.name}-{gpu.memory_mib}"
760
778
 
761
779
 
762
- def _get_instance_zones(instance_offer: InstanceOffer) -> List[str]:
763
- zones = []
764
- for offer in get_catalog_offers(backend=BackendType.GCP):
765
- if _unique_instance_name(instance_offer.instance) != _unique_instance_name(offer.instance):
766
- continue
767
- if offer.region[:-2] != instance_offer.region:
768
- continue
769
- zones.append(offer.region)
770
- return zones
771
-
772
-
773
780
  def _get_tpu_startup_script(authorized_keys: List[str]) -> str:
774
781
  commands = get_shim_commands(
775
782
  authorized_keys=authorized_keys, is_privileged=True, pjrt_device="TPU"
@@ -31,21 +31,6 @@ supported_accelerators = [
31
31
  ]
32
32
 
33
33
 
34
- def get_availability_zone(
35
- regions_client: compute_v1.RegionsClient,
36
- project_id: str,
37
- region: str,
38
- ) -> Optional[str]:
39
- zones = get_availability_zones(
40
- regions_client=regions_client,
41
- project_id=project_id,
42
- region=region,
43
- )
44
- if len(zones) == 0:
45
- return None
46
- return zones[0]
47
-
48
-
49
34
  def get_availability_zones(
50
35
  regions_client: compute_v1.RegionsClient,
51
36
  project_id: str,
@@ -76,7 +76,13 @@ class OCICompute(Compute):
76
76
  else:
77
77
  availability = InstanceAvailability.NO_QUOTA
78
78
  offers_with_availability.append(
79
- InstanceOfferWithAvailability(**offer.dict(), availability=availability)
79
+ InstanceOfferWithAvailability(
80
+ **offer.dict(),
81
+ availability=availability,
82
+ availability_zones=shapes_availability[offer.region].get(
83
+ offer.instance.name, []
84
+ ),
85
+ )
80
86
  )
81
87
 
82
88
  return offers_with_availability
@@ -111,11 +117,9 @@ class OCICompute(Compute):
111
117
  ) -> JobProvisioningData:
112
118
  region = self.regions[instance_offer.region]
113
119
 
114
- availability_domain = resources.choose_available_domain(
115
- instance_offer.instance.name, self.shapes_quota, region, self.config.compartment_id
116
- )
117
- if availability_domain is None:
120
+ if not instance_offer.availability_zones:
118
121
  raise NoCapacityError("Shape unavailable in all availability domains")
122
+ availability_domain = instance_offer.availability_zones[0]
119
123
 
120
124
  listing, package = resources.get_marketplace_listing_and_package(
121
125
  cuda=len(instance_offer.instance.resources.gpus) > 0,
@@ -170,6 +174,7 @@ class OCICompute(Compute):
170
174
  hostname=None,
171
175
  internal_ip=None,
172
176
  region=instance_offer.region,
177
+ availability_zone=availability_domain,
173
178
  price=instance_offer.price,
174
179
  username="ubuntu",
175
180
  ssh_port=22,
@@ -203,34 +203,29 @@ def check_availability_in_domain(
203
203
  return available
204
204
 
205
205
 
206
- def check_availability_in_region(
206
+ def check_availability_per_domain(
207
207
  shape_names: Iterable[str],
208
208
  shapes_quota: ShapesQuota,
209
209
  region: OCIRegionClient,
210
210
  compartment_id: str,
211
- ) -> Set[str]:
212
- """
213
- Returns a subset of `shape_names` with only the shapes available in at least
214
- one availability domain within `region`.
215
- """
216
-
211
+ ) -> Dict[str, Set[str]]:
217
212
  all_shapes = set(shape_names)
218
- available_shapes = set()
213
+ available_shapes_per_domain = {}
219
214
 
220
215
  for availability_domain in region.availability_domains:
221
216
  shapes_to_check = {
222
217
  shape
223
- for shape in all_shapes - available_shapes
218
+ for shape in all_shapes
224
219
  if shapes_quota.is_within_domain_quota(shape, availability_domain.name)
225
220
  }
226
- available_shapes |= check_availability_in_domain(
221
+ available_shapes_per_domain[availability_domain.name] = check_availability_in_domain(
227
222
  shape_names=shapes_to_check,
228
223
  availability_domain_name=availability_domain.name,
229
224
  client=region.compute_client,
230
225
  compartment_id=compartment_id,
231
226
  )
232
227
 
233
- return available_shapes
228
+ return available_shapes_per_domain
234
229
 
235
230
 
236
231
  def get_shapes_availability(
@@ -239,12 +234,11 @@ def get_shapes_availability(
239
234
  regions: Mapping[str, OCIRegionClient],
240
235
  compartment_id: str,
241
236
  executor: Executor,
242
- ) -> Dict[str, Set[str]]:
237
+ ) -> Dict[str, Dict[str, List[str]]]:
243
238
  """
244
- Returns a mapping of region names to sets of shape names available in these
245
- regions. Only shapes from `offers` are checked.
239
+ Returns availability domains where shapes are available as regions->shapes->availability_domains mapping.
240
+ Only shapes from `offers` are checked.
246
241
  """
247
-
248
242
  shape_names_per_region = {region: set() for region in regions}
249
243
  for offer in offers:
250
244
  if shapes_quota.is_within_region_quota(offer.instance.name, offer.region):
@@ -253,7 +247,7 @@ def get_shapes_availability(
253
247
  future_to_region_name = {}
254
248
  for region_name, shape_names in shape_names_per_region.items():
255
249
  future = executor.submit(
256
- check_availability_in_region,
250
+ check_availability_per_domain,
257
251
  shape_names,
258
252
  shapes_quota,
259
253
  regions[region_name],
@@ -263,29 +257,32 @@ def get_shapes_availability(
263
257
 
264
258
  result = {}
265
259
  for future in as_completed(future_to_region_name):
266
- region_name = future_to_region_name[future]
267
- result[region_name] = future.result()
260
+ domains_to_shape_names = future.result()
261
+ shape_names_to_domains = {}
262
+ for domain, shape_names in domains_to_shape_names.items():
263
+ for shape_name in shape_names:
264
+ shape_names_to_domains.setdefault(shape_name, []).append(domain)
265
+ result[future_to_region_name[future]] = shape_names_to_domains
268
266
 
269
267
  return result
270
268
 
271
269
 
272
- def choose_available_domain(
270
+ def get_available_domains(
273
271
  shape_name: str, shapes_quota: ShapesQuota, region: OCIRegionClient, compartment_id: str
274
- ) -> Optional[str]:
272
+ ) -> List[str]:
275
273
  """
276
- Returns the name of any availability domain within `region` in which
277
- `shape_name` is available. None if the shape is unavailable or not within
278
- `shapes_quota` in all domains.
274
+ Returns the names of all availability domains in `region` in which
275
+ `shape_name` is available and within `shapes_quota`.
279
276
  """
280
-
277
+ domains = []
281
278
  for domain in region.availability_domains:
282
279
  if shapes_quota.is_within_domain_quota(
283
280
  shape_name, domain.name
284
281
  ) and check_availability_in_domain(
285
282
  {shape_name}, domain.name, region.compute_client, compartment_id
286
283
  ):
287
- return domain.name
288
- return None
284
+ domains.append(domain.name)
285
+ return domains
289
286
 
290
287
 
291
288
  def get_instance_vnic(