dstack 0.18.40rc1__py3-none-any.whl → 0.18.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. dstack/_internal/cli/commands/apply.py +8 -5
  2. dstack/_internal/cli/services/configurators/base.py +4 -2
  3. dstack/_internal/cli/services/configurators/fleet.py +21 -9
  4. dstack/_internal/cli/services/configurators/gateway.py +15 -0
  5. dstack/_internal/cli/services/configurators/run.py +6 -5
  6. dstack/_internal/cli/services/configurators/volume.py +15 -0
  7. dstack/_internal/cli/services/repos.py +3 -3
  8. dstack/_internal/cli/utils/fleet.py +44 -33
  9. dstack/_internal/cli/utils/run.py +27 -7
  10. dstack/_internal/cli/utils/volume.py +30 -9
  11. dstack/_internal/core/backends/aws/compute.py +94 -53
  12. dstack/_internal/core/backends/aws/resources.py +22 -12
  13. dstack/_internal/core/backends/azure/compute.py +2 -0
  14. dstack/_internal/core/backends/base/compute.py +20 -2
  15. dstack/_internal/core/backends/gcp/compute.py +32 -24
  16. dstack/_internal/core/backends/gcp/resources.py +0 -15
  17. dstack/_internal/core/backends/oci/compute.py +10 -5
  18. dstack/_internal/core/backends/oci/resources.py +23 -26
  19. dstack/_internal/core/backends/remote/provisioning.py +65 -27
  20. dstack/_internal/core/backends/runpod/compute.py +1 -0
  21. dstack/_internal/core/models/backends/azure.py +3 -1
  22. dstack/_internal/core/models/configurations.py +24 -1
  23. dstack/_internal/core/models/fleets.py +46 -0
  24. dstack/_internal/core/models/instances.py +5 -1
  25. dstack/_internal/core/models/pools.py +4 -1
  26. dstack/_internal/core/models/profiles.py +10 -4
  27. dstack/_internal/core/models/runs.py +23 -3
  28. dstack/_internal/core/models/volumes.py +26 -0
  29. dstack/_internal/core/services/ssh/attach.py +92 -53
  30. dstack/_internal/core/services/ssh/tunnel.py +58 -31
  31. dstack/_internal/proxy/gateway/routers/registry.py +2 -0
  32. dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
  33. dstack/_internal/proxy/gateway/services/registry.py +4 -0
  34. dstack/_internal/proxy/lib/models.py +3 -0
  35. dstack/_internal/proxy/lib/services/service_connection.py +8 -1
  36. dstack/_internal/server/background/tasks/process_instances.py +73 -35
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -9
  38. dstack/_internal/server/background/tasks/process_running_jobs.py +77 -26
  39. dstack/_internal/server/background/tasks/process_runs.py +2 -12
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +121 -49
  41. dstack/_internal/server/background/tasks/process_terminating_jobs.py +14 -3
  42. dstack/_internal/server/background/tasks/process_volumes.py +11 -1
  43. dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
  44. dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
  45. dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
  46. dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
  47. dstack/_internal/server/migrations/versions/a751ef183f27_move_attachment_data_to_volumes_.py +34 -0
  48. dstack/_internal/server/models.py +27 -23
  49. dstack/_internal/server/routers/runs.py +1 -0
  50. dstack/_internal/server/schemas/runner.py +1 -0
  51. dstack/_internal/server/services/backends/configurators/azure.py +34 -8
  52. dstack/_internal/server/services/config.py +9 -0
  53. dstack/_internal/server/services/fleets.py +32 -3
  54. dstack/_internal/server/services/gateways/client.py +9 -1
  55. dstack/_internal/server/services/jobs/__init__.py +217 -45
  56. dstack/_internal/server/services/jobs/configurators/base.py +47 -2
  57. dstack/_internal/server/services/offers.py +96 -10
  58. dstack/_internal/server/services/pools.py +98 -14
  59. dstack/_internal/server/services/proxy/repo.py +17 -3
  60. dstack/_internal/server/services/runner/client.py +9 -6
  61. dstack/_internal/server/services/runner/ssh.py +33 -5
  62. dstack/_internal/server/services/runs.py +48 -179
  63. dstack/_internal/server/services/services/__init__.py +9 -1
  64. dstack/_internal/server/services/volumes.py +68 -9
  65. dstack/_internal/server/statics/index.html +1 -1
  66. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
  67. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
  68. dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
  69. dstack/_internal/server/testing/common.py +130 -61
  70. dstack/_internal/utils/common.py +22 -8
  71. dstack/_internal/utils/env.py +14 -0
  72. dstack/_internal/utils/ssh.py +1 -1
  73. dstack/api/server/_fleets.py +25 -1
  74. dstack/api/server/_runs.py +23 -2
  75. dstack/api/server/_volumes.py +12 -1
  76. dstack/version.py +1 -1
  77. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/METADATA +1 -1
  78. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/RECORD +104 -93
  79. tests/_internal/cli/services/configurators/test_profile.py +3 -3
  80. tests/_internal/core/services/ssh/test_tunnel.py +56 -4
  81. tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
  82. tests/_internal/server/background/tasks/test_process_instances.py +138 -20
  83. tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
  84. tests/_internal/server/background/tasks/test_process_running_jobs.py +193 -0
  85. tests/_internal/server/background/tasks/test_process_runs.py +27 -3
  86. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +53 -6
  87. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +135 -17
  88. tests/_internal/server/routers/test_fleets.py +15 -2
  89. tests/_internal/server/routers/test_pools.py +6 -0
  90. tests/_internal/server/routers/test_runs.py +27 -0
  91. tests/_internal/server/routers/test_volumes.py +9 -2
  92. tests/_internal/server/services/jobs/__init__.py +0 -0
  93. tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
  94. tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
  95. tests/_internal/server/services/runner/test_client.py +22 -3
  96. tests/_internal/server/services/test_offers.py +167 -0
  97. tests/_internal/server/services/test_pools.py +109 -1
  98. tests/_internal/server/services/test_runs.py +5 -41
  99. tests/_internal/utils/test_common.py +21 -0
  100. tests/_internal/utils/test_env.py +38 -0
  101. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/LICENSE.md +0 -0
  102. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/WHEEL +0 -0
  103. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/entry_points.txt +0 -0
  104. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/top_level.txt +0 -0
dstack/_internal/core/backends/aws/compute.py
@@ -109,34 +109,21 @@ class AWSCompute(Compute):
             configurable_disk_size=CONFIGURABLE_DISK_SIZE,
             extra_filter=filter,
         )
-        regions = set(i.region for i in offers)
-
-        def get_quotas(client: botocore.client.BaseClient) -> Dict[str, int]:
-            region_quotas = {}
-            for page in client.get_paginator("list_service_quotas").paginate(ServiceCode="ec2"):
-                for q in page["Quotas"]:
-                    if "On-Demand" in q["QuotaName"]:
-                        region_quotas[q["UsageMetric"]["MetricDimensions"]["Class"]] = q["Value"]
-            return region_quotas
-
-        quotas = {}
-        with ThreadPoolExecutor(max_workers=8) as executor:
-            future_to_region = {}
-            for region in regions:
-                future = executor.submit(
-                    get_quotas, self.session.client("service-quotas", region_name=region)
-                )
-                future_to_region[future] = region
-            for future in as_completed(future_to_region):
-                quotas[future_to_region[future]] = future.result()
+        regions = list(set(i.region for i in offers))
+        regions_to_quotas = _get_regions_to_quotas(self.session, regions)
+        regions_to_zones = _get_regions_to_zones(self.session, regions)

         availability_offers = []
         for offer in offers:
             availability = InstanceAvailability.UNKNOWN
-            if not _has_quota(quotas[offer.region], offer.instance.name):
+            if not _has_quota(regions_to_quotas[offer.region], offer.instance.name):
                 availability = InstanceAvailability.NO_QUOTA
             availability_offers.append(
-                InstanceOfferWithAvailability(**offer.dict(), availability=availability)
+                InstanceOfferWithAvailability(
+                    **offer.dict(),
+                    availability=availability,
+                    availability_zones=regions_to_zones[offer.region],
+                )
             )
         return availability_offers

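Note: get_offers now gathers quotas and availability zones per region up front via the module-level helpers _get_regions_to_quotas and _get_regions_to_zones (added further down in this diff) and attaches availability_zones to each offer. A minimal sketch of the gating step, using hypothetical plain dicts in place of the helpers' real output and mirroring the visible branches of _has_quota:

from typing import Dict, List

# Hypothetical region maps standing in for the helpers' real output.
regions_to_quotas: Dict[str, Dict[str, int]] = {
    "us-east-1": {"P/OnDemand": 0, "Standard/OnDemand": 640},
}
regions_to_zones: Dict[str, List[str]] = {
    "us-east-1": ["us-east-1a", "us-east-1b"],
}

def has_quota(quotas: Dict[str, int], instance_name: str) -> bool:
    # Mirrors the visible branches of _has_quota: EC2 service quotas
    # are keyed by instance-family class, not by instance type.
    if instance_name.startswith("p"):
        return quotas.get("P/OnDemand", 0) > 0
    return quotas.get("Standard/OnDemand", 0) > 0

# A p4d offer in us-east-1 is marked NO_QUOTA; an m5 offer stays UNKNOWN.
# Both carry availability_zones=["us-east-1a", "us-east-1b"].
assert not has_quota(regions_to_quotas["us-east-1"], "p4d.24xlarge")
assert has_quota(regions_to_quotas["us-east-1"], "m5.xlarge")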
@@ -161,9 +148,9 @@ class AWSCompute(Compute):
         ec2_resource = self.session.resource("ec2", region_name=instance_offer.region)
         ec2_client = self.session.client("ec2", region_name=instance_offer.region)
         allocate_public_ip = self.config.allocate_public_ips
-        availability_zones = None
-        if instance_config.availability_zone is not None:
-            availability_zones = [instance_config.availability_zone]
+        zones = instance_offer.availability_zones
+        if zones is not None and len(zones) == 0:
+            raise NoCapacityError("No eligible availability zones")

         tags = {
             "Name": instance_config.instance_name,
@@ -174,7 +161,7 @@ class AWSCompute(Compute):
         tags = merge_tags(tags=tags, backend_tags=self.config.tags)

         disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
-        max_efa_interfaces = get_maximum_efa_interfaces(
+        max_efa_interfaces = _get_maximum_efa_interfaces(
             ec2_client=ec2_client, instance_type=instance_offer.instance.name
         )
         enable_efa = max_efa_interfaces > 0
@@ -185,7 +172,7 @@ class AWSCompute(Compute):
             config=self.config,
             region=instance_offer.region,
             allocate_public_ip=allocate_public_ip,
-            availability_zones=availability_zones,
+            availability_zones=zones,
         )
         subnet_id_to_az_map = aws_resources.get_subnets_availability_zones(
             ec2_client=ec2_client,
@@ -210,11 +197,11 @@ class AWSCompute(Compute):
         except botocore.exceptions.ClientError as e:
             logger.warning("Got botocore.exceptions.ClientError: %s", e)
             raise NoCapacityError()
-        tried_availability_zones = set()
+        tried_zones = set()
         for subnet_id, az in subnet_id_to_az_map.items():
-            if az in tried_availability_zones:
+            if az in tried_zones:
                 continue
-            tried_availability_zones.add(az)
+            tried_zones.add(az)
             try:
                 logger.debug("Trying provisioning %s in %s", instance_offer.instance.name, az)
                 image_id, username = aws_resources.get_image_id_and_username(
@@ -240,6 +227,7 @@ class AWSCompute(Compute):
                     allocate_public_ip=allocate_public_ip,
                     placement_group_name=instance_config.placement_group_name,
                     enable_efa=enable_efa,
+                    max_efa_interfaces=max_efa_interfaces,
                     reservation_id=instance_config.reservation,
                     is_capacity_block=is_capacity_block,
                 )
@@ -283,6 +271,7 @@ class AWSCompute(Compute):
         project_ssh_private_key: str,
         volumes: List[Volume],
     ) -> JobProvisioningData:
+        # TODO: run_job is the same for vm-based backends, refactor
         instance_config = InstanceConfiguration(
             project_name=run.project_name,
             instance_name=get_instance_name(run, job),  # TODO: generate name
@@ -290,15 +279,25 @@ class AWSCompute(Compute):
                 SSHKey(public=project_ssh_public_key.strip()),
             ],
             user=run.user,
+            volumes=volumes,
             reservation=run.run_spec.configuration.reservation,
         )
+        instance_offer = instance_offer.copy()
         if len(volumes) > 0:
             volume = volumes[0]
             if (
                 volume.provisioning_data is not None
                 and volume.provisioning_data.availability_zone is not None
             ):
-                instance_config.availability_zone = volume.provisioning_data.availability_zone
+                if instance_offer.availability_zones is None:
+                    instance_offer.availability_zones = [
+                        volume.provisioning_data.availability_zone
+                    ]
+                instance_offer.availability_zones = [
+                    z
+                    for z in instance_offer.availability_zones
+                    if z == volume.provisioning_data.availability_zone
+                ]
         return self.create_instance(instance_offer, instance_config)

     def create_placement_group(
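Note: instead of mutating instance_config.availability_zone, run_job now copies the offer and narrows its availability_zones to the first volume's zone. A minimal sketch of that narrowing, with plain lists standing in for the dstack models:

from typing import List, Optional

def restrict_zones(offer_zones: Optional[List[str]], volume_zone: str) -> List[str]:
    # Mirrors the diff: seed missing zone info with the volume's zone,
    # then keep only zones that match it.
    if offer_zones is None:
        offer_zones = [volume_zone]
    return [z for z in offer_zones if z == volume_zone]

print(restrict_zones(["us-east-1a", "us-east-1b"], "us-east-1b"))  # ['us-east-1b']
print(restrict_zones(["us-east-1a"], "us-east-1c"))  # [] -> create_instance raises NoCapacityError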
@@ -544,14 +543,16 @@ class AWSCompute(Compute):
         }
         tags = merge_tags(tags=tags, backend_tags=self.config.tags)

-        zone = aws_resources.get_availability_zone(
+        zones = aws_resources.get_availability_zones(
             ec2_client=ec2_client, region=volume.configuration.region
         )
-        if zone is None:
+        if volume.configuration.availability_zone is not None:
+            zones = [z for z in zones if z == volume.configuration.availability_zone]
+        if len(zones) == 0:
             raise ComputeError(
                 f"Failed to find availability zone in region {volume.configuration.region}"
             )
-
+        zone = zones[0]
         volume_type = "gp3"

         logger.debug("Creating EBS volume %s", volume.configuration.name)
@@ -570,7 +571,6 @@ class AWSCompute(Compute):

         size = response["Size"]
         iops = response["Iops"]
-
         return VolumeProvisioningData(
             backend=BackendType.AWS,
             volume_id=response["VolumeId"],
@@ -635,11 +635,12 @@ class AWSCompute(Compute):
         ec2_client = self.session.client("ec2", region_name=volume.configuration.region)

         logger.debug("Detaching EBS volume %s from instance %s", volume.volume_id, instance_id)
+        attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id))
         try:
             ec2_client.detach_volume(
                 VolumeId=volume.volume_id,
                 InstanceId=instance_id,
-                Device=get_or_error(volume.attachment_data).device_name,
+                Device=attachment_data.device_name,
                 Force=force,
             )
         except botocore.exceptions.ClientError as e:
@@ -672,23 +673,6 @@ class AWSCompute(Compute):
         return True


-def get_maximum_efa_interfaces(ec2_client: botocore.client.BaseClient, instance_type: str) -> int:
-    try:
-        response = ec2_client.describe_instance_types(
-            InstanceTypes=[instance_type],
-            Filters=[{"Name": "network-info.efa-supported", "Values": ["true"]}],
-        )
-    except botocore.exceptions.ClientError as e:
-        if e.response.get("Error", {}).get("Code") == "InvalidInstanceType":
-            # "The following supplied instance types do not exist: [<instance_type>]"
-            return 0
-        raise
-    instance_types = response["InstanceTypes"]
-    if not instance_types:
-        return 0
-    return instance_types[0]["NetworkInfo"]["EfaInfo"]["MaximumEfaInterfaces"]
-
-
 def get_vpc_id_subnet_id_or_error(
     ec2_client: botocore.client.BaseClient,
     config: AWSConfig,
@@ -770,6 +754,30 @@ def _get_vpc_id_subnet_id_by_vpc_name_or_error(
     )


+def _get_regions_to_quotas(
+    session: boto3.Session, regions: List[str]
+) -> Dict[str, Dict[str, int]]:
+    def get_region_quotas(client: botocore.client.BaseClient) -> Dict[str, int]:
+        region_quotas = {}
+        for page in client.get_paginator("list_service_quotas").paginate(ServiceCode="ec2"):
+            for q in page["Quotas"]:
+                if "On-Demand" in q["QuotaName"]:
+                    region_quotas[q["UsageMetric"]["MetricDimensions"]["Class"]] = q["Value"]
+        return region_quotas
+
+    regions_to_quotas = {}
+    with ThreadPoolExecutor(max_workers=8) as executor:
+        future_to_region = {}
+        for region in regions:
+            future = executor.submit(
+                get_region_quotas, session.client("service-quotas", region_name=region)
+            )
+            future_to_region[future] = region
+        for future in as_completed(future_to_region):
+            regions_to_quotas[future_to_region[future]] = future.result()
+    return regions_to_quotas
+
+
 def _has_quota(quotas: Dict[str, int], instance_name: str) -> bool:
     if instance_name.startswith("p"):
         return quotas.get("P/OnDemand", 0) > 0
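Note: _get_regions_to_quotas and the _get_regions_to_zones helper below share the same fan-out shape: one Service Quotas (or EC2) client call per region is submitted to a ThreadPoolExecutor capped at eight workers, and results are collected into a dict keyed by region as the futures complete.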
@@ -778,6 +786,22 @@ def _has_quota(quotas: Dict[str, int], instance_name: str) -> bool:
     return quotas.get("Standard/OnDemand", 0) > 0


+def _get_regions_to_zones(session: boto3.Session, regions: List[str]) -> Dict[str, List[str]]:
+    regions_to_zones = {}
+    with ThreadPoolExecutor(max_workers=8) as executor:
+        future_to_region = {}
+        for region in regions:
+            future = executor.submit(
+                aws_resources.get_availability_zones,
+                session.client("ec2", region_name=region),
+                region,
+            )
+            future_to_region[future] = region
+        for future in as_completed(future_to_region):
+            regions_to_zones[future_to_region[future]] = future.result()
+    return regions_to_zones
+
+
 def _supported_instances(offer: InstanceOffer) -> bool:
     for family in [
         "t2.small",
@@ -798,6 +822,23 @@ def _supported_instances(offer: InstanceOffer) -> bool:
     return False


+def _get_maximum_efa_interfaces(ec2_client: botocore.client.BaseClient, instance_type: str) -> int:
+    try:
+        response = ec2_client.describe_instance_types(
+            InstanceTypes=[instance_type],
+            Filters=[{"Name": "network-info.efa-supported", "Values": ["true"]}],
+        )
+    except botocore.exceptions.ClientError as e:
+        if e.response.get("Error", {}).get("Code") == "InvalidInstanceType":
+            # "The following supplied instance types do not exist: [<instance_type>]"
+            return 0
+        raise
+    instance_types = response["InstanceTypes"]
+    if not instance_types:
+        return 0
+    return instance_types[0]["NetworkInfo"]["EfaInfo"]["MaximumEfaInterfaces"]
+
+
 def _get_instance_ip(instance: Any, public_ip: bool) -> str:
     if public_ip:
         return instance.public_ip_address
dstack/_internal/core/backends/aws/resources.py
@@ -140,6 +140,7 @@ def create_instances_struct(
     allocate_public_ip: bool = True,
     placement_group_name: Optional[str] = None,
     enable_efa: bool = False,
+    max_efa_interfaces: int = 0,
     reservation_id: Optional[str] = None,
     is_capacity_block: bool = False,
 ) -> Dict[str, Any]:
@@ -183,7 +184,7 @@
     # AWS allows specifying either NetworkInterfaces for specific subnet_id
     # or instance-level SecurityGroupIds in case of no specific subnet_id, not both.
     if subnet_id is not None:
-        # Even if the instance type supports multiple cards, we always request only one interface
+        # If the instance type supports multiple cards, we request multiple interfaces only if not allocate_public_ip
         # due to the limitation: "AssociatePublicIpAddress [...] You cannot specify more than one
         # network interface in the request".
         # Error message: "(InvalidParameterCombination) when calling the RunInstances operation:
@@ -199,9 +200,28 @@
                 "DeviceIndex": 0,
                 "SubnetId": subnet_id,
                 "Groups": [security_group_id],
-                "InterfaceType": "efa" if enable_efa else "interface",
+                "InterfaceType": "efa" if max_efa_interfaces > 0 else "interface",
             },
         ]
+
+        if max_efa_interfaces > 1 and allocate_public_ip is False:
+            for i in range(1, max_efa_interfaces):
+                # Set to efa-only to use interfaces exclusively for GPU-to-GPU communication
+                interface_type = "efa-only"
+                if instance_type == "p5.48xlarge":
+                    # EFA configuration for P5 instances:
+                    # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/efa-acc-inst-types.html#efa-for-p5
+                    interface_type = "efa" if i % 4 == 0 else "efa-only"
+                struct["NetworkInterfaces"].append(
+                    {
+                        "AssociatePublicIpAddress": allocate_public_ip,
+                        "NetworkCardIndex": i,
+                        "DeviceIndex": 1,
+                        "SubnetId": subnet_id,
+                        "Groups": [security_group_id],
+                        "InterfaceType": interface_type,
+                    }
+                )
     else:
         struct["SecurityGroupIds"] = [security_group_id]

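Note: with max_efa_interfaces > 1 and no public IP, the struct now requests one interface per network card. A minimal sketch of the resulting layout for a p5.48xlarge, assuming the hypothetical value max_efa_interfaces=32:

def efa_layout(instance_type: str, max_efa_interfaces: int) -> list:
    # Card 0 is the primary, IP-capable interface ("efa").
    layout = [(0, "efa")]
    for i in range(1, max_efa_interfaces):
        interface_type = "efa-only"  # GPU-to-GPU traffic only
        if instance_type == "p5.48xlarge" and i % 4 == 0:
            interface_type = "efa"  # per the linked P5 EFA guidance
        layout.append((i, interface_type))
    return layout

layout = efa_layout("p5.48xlarge", 32)
print(sum(1 for _, t in layout if t == "efa"))       # 8
print(sum(1 for _, t in layout if t == "efa-only"))  # 24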
@@ -370,16 +390,6 @@ def get_subnets_ids_for_vpc(
     return subnets_ids


-def get_availability_zone(ec2_client: botocore.client.BaseClient, region: str) -> Optional[str]:
-    zone_names = get_availability_zones(
-        ec2_client=ec2_client,
-        region=region,
-    )
-    if len(zone_names) == 0:
-        return None
-    return zone_names[0]
-
-
 def get_availability_zones(ec2_client: botocore.client.BaseClient, region: str) -> List[str]:
     response = ec2_client.describe_availability_zones(
         Filters=[
dstack/_internal/core/backends/azure/compute.py
@@ -133,6 +133,8 @@ class AzureCompute(Compute):
         }
         tags = merge_tags(tags=tags, backend_tags=self.config.tags)

+        # TODO: Support custom availability_zones.
+        # Currently, VMs are regional, which means they don't have zone info.
         vm = _launch_instance(
             compute_client=self._compute_client,
             subscription_id=self.config.subscription_id,
dstack/_internal/core/backends/base/compute.py
@@ -371,7 +371,16 @@ def get_docker_commands(
         "rm -rf /run/sshd && mkdir -p /run/sshd && chown root:root /run/sshd",
         "rm -rf /var/empty && mkdir -p /var/empty && chown root:root /var/empty",
         # start sshd
-        f"/usr/sbin/sshd -p {DSTACK_RUNNER_SSH_PORT} -o PidFile=none -o PasswordAuthentication=no -o AllowTcpForwarding=yes -o PermitUserEnvironment=yes",
+        (
+            "/usr/sbin/sshd"
+            f" -p {DSTACK_RUNNER_SSH_PORT}"
+            " -o PidFile=none"
+            " -o PasswordAuthentication=no"
+            " -o AllowTcpForwarding=yes"
+            " -o PermitUserEnvironment=yes"
+            " -o ClientAliveInterval=30"
+            " -o ClientAliveCountMax=4"
+        ),
         # restore ld.so variables
         'if [ -n "$_LD_LIBRARY_PATH" ]; then export LD_LIBRARY_PATH="$_LD_LIBRARY_PATH"; fi',
         'if [ -n "$_LD_PRELOAD" ]; then export LD_PRELOAD="$_LD_PRELOAD"; fi',
@@ -381,7 +390,16 @@ def get_docker_commands(
     commands += [
         f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {DSTACK_RUNNER_BINARY_PATH} {url}",
         f"chmod +x {DSTACK_RUNNER_BINARY_PATH}",
-        f"{DSTACK_RUNNER_BINARY_PATH} --log-level 6 start --http-port {DSTACK_RUNNER_HTTP_PORT} --temp-dir /tmp/runner --home-dir /root --working-dir /workflow",
+        (
+            f"{DSTACK_RUNNER_BINARY_PATH}"
+            " --log-level 6"
+            " start"
+            f" --http-port {DSTACK_RUNNER_HTTP_PORT}"
+            f" --ssh-port {DSTACK_RUNNER_SSH_PORT}"
+            " --temp-dir /tmp/runner"
+            " --home-dir /root"
+            " --working-dir /workflow"
+        ),
     ]

     return commands
dstack/_internal/core/backends/gcp/compute.py
@@ -94,21 +94,25 @@ class GCPCompute(Compute):
             for quota in region.quotas:
                 quotas[region.name][quota.metric] = quota.limit - quota.usage

-        seen_region_offers = set()
+        offer_keys_to_offers = {}
         offers_with_availability = []
         for offer in offers:
             region = offer.region[:-2]  # strip zone
             key = (_unique_instance_name(offer.instance), region)
-            if key in seen_region_offers:
+            if key in offer_keys_to_offers:
+                offer_keys_to_offers[key].availability_zones.append(offer.region)
                 continue
-            seen_region_offers.add(key)
             availability = InstanceAvailability.NO_QUOTA
             if _has_gpu_quota(quotas[region], offer.instance.resources):
                 availability = InstanceAvailability.UNKNOWN
             # todo quotas: cpu, memory, global gpu, tpu
-            offers_with_availability.append(
-                InstanceOfferWithAvailability(**offer.dict(), availability=availability)
+            offer_with_availability = InstanceOfferWithAvailability(
+                **offer.dict(),
+                availability=availability,
+                availability_zones=[offer.region],
             )
+            offer_keys_to_offers[key] = offer_with_availability
+            offers_with_availability.append(offer_with_availability)
             offers_with_availability[-1].region = region

         return offers_with_availability
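Note: GCP catalog offers are zonal, so the rewritten loop keeps one offer per (instance, region) key and folds later zonal duplicates into that offer's availability_zones. A minimal sketch of the pattern, with hypothetical plain tuples in place of offer models:

from typing import Dict, List, Tuple

# Hypothetical zonal catalog entries: (instance_name, zone).
zonal_offers = [
    ("a2-highgpu-1g", "us-central1-a"),
    ("a2-highgpu-1g", "us-central1-b"),
    ("a2-highgpu-1g", "europe-west4-a"),
]

merged: Dict[Tuple[str, str], List[str]] = {}
for name, zone in zonal_offers:
    region = zone[:-2]  # strip the zone letter, as in the diff
    key = (name, region)
    if key in merged:
        merged[key].append(zone)  # fold the duplicate into the first offer
        continue
    merged[key] = [zone]

print(merged)
# {('a2-highgpu-1g', 'us-central1'): ['us-central1-a', 'us-central1-b'],
#  ('a2-highgpu-1g', 'europe-west4'): ['europe-west4-a']}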
@@ -156,10 +160,10 @@ class GCPCompute(Compute):
         )
         authorized_keys = instance_config.get_public_keys()

-        zones = _get_instance_zones(instance_offer)
-        if instance_config.availability_zone:
-            zones = [z for z in zones if z == instance_config.availability_zone]
-
+        # get_offers always fills instance_offer.availability_zones
+        zones = get_or_error(instance_offer.availability_zones)
+        if len(zones) == 0:
+            raise NoCapacityError("No eligible availability zones")
         # If a shared VPC is not used, we can create firewall rules for user
         if self.config.vpc_project_id is None:
             gcp_resources.create_runner_firewall_rules(
@@ -371,6 +375,7 @@ class GCPCompute(Compute):
         project_ssh_private_key: str,
         volumes: List[Volume],
     ) -> JobProvisioningData:
+        # TODO: run_job is the same for vm-based backends, refactor
         instance_config = InstanceConfiguration(
             project_name=run.project_name,
             instance_name=get_instance_name(run, job),  # TODO: generate name
@@ -379,14 +384,24 @@ class GCPCompute(Compute):
             ],
             user=run.user,
             volumes=volumes,
+            reservation=run.run_spec.configuration.reservation,
         )
+        instance_offer = instance_offer.copy()
         if len(volumes) > 0:
             volume = volumes[0]
             if (
                 volume.provisioning_data is not None
                 and volume.provisioning_data.availability_zone is not None
             ):
-                instance_config.availability_zone = volume.provisioning_data.availability_zone
+                if instance_offer.availability_zones is None:
+                    instance_offer.availability_zones = [
+                        volume.provisioning_data.availability_zone
+                    ]
+                instance_offer.availability_zones = [
+                    z
+                    for z in instance_offer.availability_zones
+                    if z == volume.provisioning_data.availability_zone
+                ]
         return self.create_instance(instance_offer, instance_config)

     def create_gateway(
@@ -497,15 +512,18 @@ class GCPCompute(Compute):
             raise ComputeError(f"Persistent disk {volume.configuration.volume_id} not found")

     def create_volume(self, volume: Volume) -> VolumeProvisioningData:
-        zone = gcp_resources.get_availability_zone(
+        zones = gcp_resources.get_availability_zones(
             regions_client=self.regions_client,
             project_id=self.config.project_id,
             region=volume.configuration.region,
         )
-        if zone is None:
+        if volume.configuration.availability_zone is not None:
+            zones = [z for z in zones if z == volume.configuration.availability_zone]
+        if len(zones) == 0:
             raise ComputeError(
                 f"Failed to find availability zone in region {volume.configuration.region}"
             )
+        zone = zones[0]

         labels = {
             "owner": "dstack",
@@ -648,6 +666,7 @@ class GCPCompute(Compute):
             instance_id,
         )
         zone = get_or_error(volume.provisioning_data).availability_zone
+        attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id))
         # This method has no information if the instance is a TPU or a VM,
         # so we first try to see if there is a TPU with such name
         try:
@@ -676,7 +695,7 @@ class GCPCompute(Compute):
                 project=self.config.project_id,
                 zone=get_or_error(volume.provisioning_data).availability_zone,
                 instance=instance_id,
-                device_name=get_or_error(volume.attachment_data).device_name,
+                device_name=attachment_data.device_name,
             )
             gcp_resources.wait_for_extended_operation(operation, "persistent disk detachment")
             logger.debug(
@@ -759,17 +778,6 @@ def _unique_instance_name(instance: InstanceType) -> str:
     return f"{name}-{gpu.name}-{gpu.memory_mib}"


-def _get_instance_zones(instance_offer: InstanceOffer) -> List[str]:
-    zones = []
-    for offer in get_catalog_offers(backend=BackendType.GCP):
-        if _unique_instance_name(instance_offer.instance) != _unique_instance_name(offer.instance):
-            continue
-        if offer.region[:-2] != instance_offer.region:
-            continue
-        zones.append(offer.region)
-    return zones
-
-
 def _get_tpu_startup_script(authorized_keys: List[str]) -> str:
     commands = get_shim_commands(
         authorized_keys=authorized_keys, is_privileged=True, pjrt_device="TPU"
dstack/_internal/core/backends/gcp/resources.py
@@ -31,21 +31,6 @@ supported_accelerators = [
 ]


-def get_availability_zone(
-    regions_client: compute_v1.RegionsClient,
-    project_id: str,
-    region: str,
-) -> Optional[str]:
-    zones = get_availability_zones(
-        regions_client=regions_client,
-        project_id=project_id,
-        region=region,
-    )
-    if len(zones) == 0:
-        return None
-    return zones[0]
-
-
 def get_availability_zones(
     regions_client: compute_v1.RegionsClient,
     project_id: str,
dstack/_internal/core/backends/oci/compute.py
@@ -76,7 +76,13 @@ class OCICompute(Compute):
             else:
                 availability = InstanceAvailability.NO_QUOTA
             offers_with_availability.append(
-                InstanceOfferWithAvailability(**offer.dict(), availability=availability)
+                InstanceOfferWithAvailability(
+                    **offer.dict(),
+                    availability=availability,
+                    availability_zones=shapes_availability[offer.region].get(
+                        offer.instance.name, []
+                    ),
+                )
             )

         return offers_with_availability
@@ -111,11 +117,9 @@
     ) -> JobProvisioningData:
         region = self.regions[instance_offer.region]

-        availability_domain = resources.choose_available_domain(
-            instance_offer.instance.name, self.shapes_quota, region, self.config.compartment_id
-        )
-        if availability_domain is None:
+        if not instance_offer.availability_zones:
             raise NoCapacityError("Shape unavailable in all availability domains")
+        availability_domain = instance_offer.availability_zones[0]

         listing, package = resources.get_marketplace_listing_and_package(
             cuda=len(instance_offer.instance.resources.gpus) > 0,
@@ -170,6 +174,7 @@
             hostname=None,
             internal_ip=None,
             region=instance_offer.region,
+            availability_zone=availability_domain,
             price=instance_offer.price,
             username="ubuntu",
             ssh_port=22,