dstack 0.19.30rc1__py3-none-any.whl → 0.19.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (54)
  1. dstack/_internal/cli/commands/__init__.py +8 -0
  2. dstack/_internal/cli/commands/project.py +27 -20
  3. dstack/_internal/cli/commands/server.py +5 -0
  4. dstack/_internal/cli/services/configurators/fleet.py +20 -6
  5. dstack/_internal/cli/utils/gpu.py +2 -2
  6. dstack/_internal/core/backends/aws/compute.py +13 -5
  7. dstack/_internal/core/backends/aws/resources.py +11 -6
  8. dstack/_internal/core/backends/azure/compute.py +17 -6
  9. dstack/_internal/core/backends/base/compute.py +57 -9
  10. dstack/_internal/core/backends/base/offers.py +1 -0
  11. dstack/_internal/core/backends/cloudrift/compute.py +2 -0
  12. dstack/_internal/core/backends/cudo/compute.py +2 -0
  13. dstack/_internal/core/backends/datacrunch/compute.py +2 -0
  14. dstack/_internal/core/backends/digitalocean_base/compute.py +2 -0
  15. dstack/_internal/core/backends/features.py +5 -0
  16. dstack/_internal/core/backends/gcp/compute.py +87 -38
  17. dstack/_internal/core/backends/gcp/configurator.py +1 -1
  18. dstack/_internal/core/backends/gcp/models.py +14 -1
  19. dstack/_internal/core/backends/gcp/resources.py +35 -12
  20. dstack/_internal/core/backends/hotaisle/compute.py +22 -0
  21. dstack/_internal/core/backends/kubernetes/compute.py +531 -215
  22. dstack/_internal/core/backends/kubernetes/models.py +13 -16
  23. dstack/_internal/core/backends/kubernetes/utils.py +145 -8
  24. dstack/_internal/core/backends/lambdalabs/compute.py +2 -0
  25. dstack/_internal/core/backends/local/compute.py +2 -0
  26. dstack/_internal/core/backends/nebius/compute.py +17 -0
  27. dstack/_internal/core/backends/nebius/configurator.py +15 -0
  28. dstack/_internal/core/backends/nebius/models.py +57 -5
  29. dstack/_internal/core/backends/nebius/resources.py +45 -2
  30. dstack/_internal/core/backends/oci/compute.py +7 -1
  31. dstack/_internal/core/backends/oci/resources.py +8 -3
  32. dstack/_internal/core/backends/template/compute.py.jinja +2 -0
  33. dstack/_internal/core/backends/tensordock/compute.py +2 -0
  34. dstack/_internal/core/backends/vultr/compute.py +2 -0
  35. dstack/_internal/core/compatibility/runs.py +8 -0
  36. dstack/_internal/core/consts.py +2 -0
  37. dstack/_internal/core/models/profiles.py +11 -4
  38. dstack/_internal/core/services/repos.py +101 -11
  39. dstack/_internal/server/background/tasks/common.py +2 -0
  40. dstack/_internal/server/background/tasks/process_fleets.py +75 -17
  41. dstack/_internal/server/background/tasks/process_instances.py +3 -5
  42. dstack/_internal/server/background/tasks/process_running_jobs.py +1 -1
  43. dstack/_internal/server/background/tasks/process_runs.py +27 -23
  44. dstack/_internal/server/background/tasks/process_submitted_jobs.py +107 -54
  45. dstack/_internal/server/services/offers.py +7 -1
  46. dstack/_internal/server/testing/common.py +2 -0
  47. dstack/_internal/server/utils/provisioning.py +3 -10
  48. dstack/_internal/utils/ssh.py +22 -2
  49. dstack/version.py +2 -2
  50. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/METADATA +20 -18
  51. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/RECORD +54 -54
  52. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/WHEEL +0 -0
  53. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/entry_points.txt +0 -0
  54. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/licenses/LICENSE.md +0 -0
@@ -2,7 +2,8 @@ import subprocess
  import tempfile
  import threading
  import time
- from typing import Dict, List, Optional, Tuple
+ from enum import Enum
+ from typing import List, Optional, Tuple

  from gpuhunt import KNOWN_NVIDIA_GPUS, AcceleratorVendor
  from kubernetes import client
@@ -11,19 +12,24 @@ from dstack._internal.core.backends.base.compute import (
      Compute,
      ComputeWithFilteredOffersCached,
      ComputeWithGatewaySupport,
+     ComputeWithMultinodeSupport,
+     ComputeWithPrivilegedSupport,
      generate_unique_gateway_instance_name,
      generate_unique_instance_name_for_job,
      get_docker_commands,
      get_dstack_gateway_commands,
+     normalize_arch,
  )
  from dstack._internal.core.backends.base.offers import filter_offers_by_requirements
  from dstack._internal.core.backends.kubernetes.models import (
      KubernetesConfig,
-     KubernetesNetworkingConfig,
+     KubernetesProxyJumpConfig,
  )
  from dstack._internal.core.backends.kubernetes.utils import (
+     call_api_method,
      get_api_from_config_data,
      get_cluster_public_ip,
+     get_value,
  )
  from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
  from dstack._internal.core.errors import ComputeError
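
The two new Compute* names above come from dstack/_internal/core/backends/base/compute.py, whose diff (+57 -9) is collapsed on this page. Judging by how they are used below, they are marker mixins that the server checks with isinstance() to gate per-backend features; a minimal sketch of what they might look like (hypothetical, not the package's actual definitions):

    # Hypothetical sketch; the real definitions live in base/compute.py (diff not shown here).
    class ComputeWithPrivilegedSupport:
        """Marks a Compute backend as able to run privileged containers."""


    class ComputeWithMultinodeSupport:
        """Marks a Compute backend as able to provision instances that can reach each other."""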
@@ -44,6 +50,7 @@ from dstack._internal.core.models.instances import (
      Resources,
      SSHConnectionParams,
  )
+ from dstack._internal.core.models.resources import CPUSpec, Memory
  from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
  from dstack._internal.core.models.volumes import Volume
  from dstack._internal.utils.common import parse_memory
@@ -52,52 +59,92 @@ from dstack._internal.utils.logging import get_logger
  logger = get_logger(__name__)

  JUMP_POD_SSH_PORT = 22
- DEFAULT_NAMESPACE = "default"

  NVIDIA_GPU_NAME_TO_GPU_INFO = {gpu.name: gpu for gpu in KNOWN_NVIDIA_GPUS}
  NVIDIA_GPU_NAMES = NVIDIA_GPU_NAME_TO_GPU_INFO.keys()

+ NVIDIA_GPU_RESOURCE = "nvidia.com/gpu"
+ NVIDIA_GPU_COUNT_LABEL = f"{NVIDIA_GPU_RESOURCE}.count"
+ NVIDIA_GPU_PRODUCT_LABEL = f"{NVIDIA_GPU_RESOURCE}.product"
+ NVIDIA_GPU_NODE_TAINT = NVIDIA_GPU_RESOURCE
+
+ # Taints we know and tolerate when creating our objects, e.g., the jump pod.
+ TOLERATED_NODE_TAINTS = (NVIDIA_GPU_NODE_TAINT,)
+
+ DUMMY_REGION = "-"
+
+
+ class Operator(str, Enum):
+     EXISTS = "Exists"
+     IN = "In"
+
+
+ class TaintEffect(str, Enum):
+     NO_EXECUTE = "NoExecute"
+     NO_SCHEDULE = "NoSchedule"
+     PREFER_NO_SCHEDULE = "PreferNoSchedule"
+

  class KubernetesCompute(
      ComputeWithFilteredOffersCached,
+     ComputeWithPrivilegedSupport,
      ComputeWithGatewaySupport,
+     ComputeWithMultinodeSupport,
      Compute,
  ):
      def __init__(self, config: KubernetesConfig):
          super().__init__()
          self.config = config.copy()
-         networking_config = self.config.networking
-         if networking_config is None:
-             networking_config = KubernetesNetworkingConfig()
-         self.networking_config = networking_config
+         proxy_jump = self.config.proxy_jump
+         if proxy_jump is None:
+             proxy_jump = KubernetesProxyJumpConfig()
+         self.proxy_jump = proxy_jump
          self.api = get_api_from_config_data(config.kubeconfig.data)

      def get_offers_by_requirements(
          self, requirements: Requirements
      ) -> List[InstanceOfferWithAvailability]:
-         nodes = self.api.list_node()
-         instance_offers = []
-         for node in nodes.items:
+         instance_offers: list[InstanceOfferWithAvailability] = []
+         node_list = call_api_method(
+             self.api.list_node,
+             client.V1NodeList,
+         )
+         nodes = get_value(node_list, ".items", list[client.V1Node], required=True)
+         for node in nodes:
+             try:
+                 labels = get_value(node, ".metadata.labels", dict[str, str]) or {}
+                 name = get_value(node, ".metadata.name", str, required=True)
+                 cpus = _parse_cpu(
+                     get_value(node, ".status.allocatable['cpu']", str, required=True)
+                 )
+                 cpu_arch = normalize_arch(
+                     get_value(node, ".status.node_info.architecture", str)
+                 ).to_cpu_architecture()
+                 memory_mib = _parse_memory(
+                     get_value(node, ".status.allocatable['memory']", str, required=True)
+                 )
+                 gpus, _ = _get_gpus_from_node_labels(labels)
+                 disk_size_mib = _parse_memory(
+                     get_value(node, ".status.allocatable['ephemeral-storage']", str, required=True)
+                 )
+             except (AttributeError, KeyError, ValueError) as e:
+                 logger.exception("Failed to process node: %s: %s", type(e).__name__, e)
+                 continue
              instance_offer = InstanceOfferWithAvailability(
                  backend=BackendType.KUBERNETES,
                  instance=InstanceType(
-                     name=node.metadata.name,
+                     name=name,
                      resources=Resources(
-                         cpus=node.status.capacity["cpu"],
-                         memory_mib=int(parse_memory(node.status.capacity["memory"], as_untis="M")),
-                         gpus=_get_gpus_from_node_labels(node.metadata.labels),
+                         cpus=cpus,
+                         cpu_arch=cpu_arch,
+                         memory_mib=memory_mib,
+                         gpus=gpus,
                          spot=False,
-                         disk=Disk(
-                             size_mib=int(
-                                 parse_memory(
-                                     node.status.capacity["ephemeral-storage"], as_untis="M"
-                                 )
-                             )
-                         ),
+                         disk=Disk(size_mib=disk_size_mib),
                      ),
                  ),
                  price=0,
-                 region="-",
+                 region=DUMMY_REGION,
                  availability=InstanceAvailability.AVAILABLE,
                  instance_runtime=InstanceRuntime.RUNNER,
              )
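
call_api_method and get_value come from dstack/_internal/core/backends/kubernetes/utils.py, whose diff (+145 -8) is collapsed on this page. From the call sites, call_api_method invokes a client method, type-checks the result, and can swallow one expected error status, while get_value safely walks a dotted/indexed path into the response. A rough sketch under those assumptions (not the package's actual implementation):

    # Hypothetical sketch of the kubernetes/utils.py helpers; signatures inferred
    # from the call sites in this diff.
    import re
    from typing import Any, Callable, Optional, get_origin

    from kubernetes import client


    def call_api_method(
        method: Callable, expected_type: type, *, expected: Optional[int] = None, **kwargs
    ) -> Any:
        # Call an API method; return None when it fails with the expected HTTP status.
        try:
            result = method(**kwargs)
        except client.ApiException as e:
            if expected is not None and e.status == expected:
                return None
            raise
        origin = get_origin(expected_type) or expected_type
        if not isinstance(result, origin):
            raise ValueError(f"{method.__name__} returned {type(result).__name__}")
        return result


    _TOKEN = re.compile(r"\.(\w+)|\[(\d+)\]|\['([^']+)'\]")


    def get_value(obj: Any, path: str, typ: type, required: bool = False) -> Any:
        # Walk ".attr", "[index]", and "['key']" path components,
        # e.g. ".spec.ports[0].node_port" or ".status.allocatable['cpu']".
        value = obj
        for attr, index, key in _TOKEN.findall(path):
            if value is None:
                break
            if attr:
                value = getattr(value, attr, None)
            elif index:
                i = int(index)
                value = value[i] if i < len(value) else None
            else:
                value = value.get(key)
        if value is None:
            if required:
                raise ValueError(f"required value missing at {path!r}")
            return None
        origin = get_origin(typ) or typ
        if not isinstance(value, origin):
            raise ValueError(f"unexpected type at {path!r}: {type(value).__name__}")
        return value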
@@ -122,7 +169,7 @@ class KubernetesCompute(
          # as an ssh proxy jump to connect to all other services in Kubernetes.
          # Setup jump pod in a separate thread to avoid long-running run_job.
          # In case the thread fails, the job will be failed and resubmitted.
-         jump_pod_hostname = self.networking_config.ssh_host
+         jump_pod_hostname = self.proxy_jump.hostname
          if jump_pod_hostname is None:
              jump_pod_hostname = get_cluster_public_ip(self.api)
              if jump_pod_hostname is None:
@@ -132,15 +179,17 @@
                  )
          jump_pod_port, created = _create_jump_pod_service_if_not_exists(
              api=self.api,
+             namespace=self.config.namespace,
              project_name=run.project_name,
              ssh_public_keys=[project_ssh_public_key.strip(), run.run_spec.ssh_key_pub.strip()],
-             jump_pod_port=self.networking_config.ssh_port,
+             jump_pod_port=self.proxy_jump.port,
          )
          if not created:
              threading.Thread(
                  target=_continue_setup_jump_pod,
                  kwargs={
                      "api": self.api,
+                     "namespace": self.config.namespace,
                      "project_name": run.project_name,
                      "project_ssh_private_key": project_ssh_private_key.strip(),
                      "user_ssh_public_key": run.run_spec.ssh_key_pub.strip(),
@@ -148,41 +197,155 @@ class KubernetesCompute(
                      "jump_pod_port": jump_pod_port,
                  },
              ).start()
-         self.api.create_namespaced_pod(
-             namespace=DEFAULT_NAMESPACE,
-             body=client.V1Pod(
-                 metadata=client.V1ObjectMeta(
-                     name=instance_name,
-                     labels={"app.kubernetes.io/name": instance_name},
-                 ),
-                 spec=client.V1PodSpec(
-                     containers=[
-                         client.V1Container(
-                             name=f"{instance_name}-container",
-                             image=job.job_spec.image_name,
-                             command=["/bin/sh"],
-                             args=["-c", " && ".join(commands)],
-                             ports=[
-                                 client.V1ContainerPort(
-                                     container_port=DSTACK_RUNNER_SSH_PORT,
-                                 )
+
+         resources_requests: dict[str, str] = {}
+         resources_limits: dict[str, str] = {}
+         node_affinity: Optional[client.V1NodeAffinity] = None
+         tolerations: list[client.V1Toleration] = []
+         volumes_: list[client.V1Volume] = []
+         volume_mounts: list[client.V1VolumeMount] = []
+
+         resources_spec = job.job_spec.requirements.resources
+         assert isinstance(resources_spec.cpu, CPUSpec)
+         if (cpu_min := resources_spec.cpu.count.min) is not None:
+             resources_requests["cpu"] = str(cpu_min)
+         if (gpu_spec := resources_spec.gpu) is not None:
+             gpu_min = gpu_spec.count.min
+             if gpu_min is not None and gpu_min > 0:
+                 if not (offer_gpus := instance_offer.instance.resources.gpus):
+                     raise ComputeError(
+                         "GPU is requested but the offer has no GPUs:"
+                         f" {gpu_spec=} {instance_offer=}",
+                     )
+                 offer_gpu = offer_gpus[0]
+                 matching_gpu_label_values: set[str] = set()
+                 # We cannot generate an expected GPU label value from the Gpu model instance
+                 # as the actual values may have additional components (socket, memory type, etc.)
+                 # that we don't preserve in the Gpu model, e.g., "NVIDIA-H100-80GB-HBM3".
+                 # Moreover, a single Gpu may match multiple label values.
+                 # As a workaround, we iterate and process all node labels once again (we already
+                 # processed them in `get_offers_by_requirements()`).
+                 node_list = call_api_method(
+                     self.api.list_node,
+                     client.V1NodeList,
+                 )
+                 nodes = get_value(node_list, ".items", list[client.V1Node], required=True)
+                 for node in nodes:
+                     labels = get_value(node, ".metadata.labels", dict[str, str])
+                     if not labels:
+                         continue
+                     gpus, gpu_label_value = _get_gpus_from_node_labels(labels)
+                     if not gpus or gpu_label_value is None:
+                         continue
+                     if gpus[0] == offer_gpu:
+                         matching_gpu_label_values.add(gpu_label_value)
+                 if not matching_gpu_label_values:
+                     raise ComputeError(
+                         f"GPU is requested but no matching GPU labels found: {gpu_spec=}"
+                     )
+                 logger.debug(
+                     "Requesting %d GPU(s), node labels: %s", gpu_min, matching_gpu_label_values
+                 )
+                 # TODO: support other GPU vendors
+                 resources_requests[NVIDIA_GPU_RESOURCE] = str(gpu_min)
+                 resources_limits[NVIDIA_GPU_RESOURCE] = str(gpu_min)
+                 node_affinity = client.V1NodeAffinity(
+                     required_during_scheduling_ignored_during_execution=[
+                         client.V1NodeSelectorTerm(
+                             match_expressions=[
+                                 client.V1NodeSelectorRequirement(
+                                     key=NVIDIA_GPU_PRODUCT_LABEL,
+                                     operator=Operator.IN,
+                                     values=list(matching_gpu_label_values),
+                                 ),
                              ],
-                             security_context=client.V1SecurityContext(
-                                 # TODO(#1535): support non-root images properly
-                                 run_as_user=0,
-                                 run_as_group=0,
-                             ),
-                             # TODO: Pass cpu, memory, gpu as requests.
-                             # Beware that node capacity != allocatable, so
-                             # if the node has 2xCPU – then cpu=2 request will probably fail.
-                             resources=client.V1ResourceRequirements(requests={}),
+                         ),
+                     ],
+                 )
+                 # It should be NoSchedule, but we also add NoExecute toleration just in case.
+                 for effect in [TaintEffect.NO_SCHEDULE, TaintEffect.NO_EXECUTE]:
+                     tolerations.append(
+                         client.V1Toleration(
+                             key=NVIDIA_GPU_NODE_TAINT, operator=Operator.EXISTS, effect=effect
                          )
-                     ]
-                 ),
+                     )
+
+         if (memory_min := resources_spec.memory.min) is not None:
+             resources_requests["memory"] = _render_memory(memory_min)
+         if (
+             resources_spec.disk is not None
+             and (disk_min := resources_spec.disk.size.min) is not None
+         ):
+             resources_requests["ephemeral-storage"] = _render_memory(disk_min)
+         if (shm_size := resources_spec.shm_size) is not None:
+             shm_volume_name = "dev-shm"
+             volumes_.append(
+                 client.V1Volume(
+                     name=shm_volume_name,
+                     empty_dir=client.V1EmptyDirVolumeSource(
+                         medium="Memory",
+                         size_limit=_render_memory(shm_size),
+                     ),
+                 )
+             )
+             volume_mounts.append(
+                 client.V1VolumeMount(
+                     name=shm_volume_name,
+                     mount_path="/dev/shm",
+                 )
+             )
+
+         pod = client.V1Pod(
+             metadata=client.V1ObjectMeta(
+                 name=instance_name,
+                 labels={"app.kubernetes.io/name": instance_name},
+             ),
+             spec=client.V1PodSpec(
+                 containers=[
+                     client.V1Container(
+                         name=f"{instance_name}-container",
+                         image=job.job_spec.image_name,
+                         command=["/bin/sh"],
+                         args=["-c", " && ".join(commands)],
+                         ports=[
+                             client.V1ContainerPort(
+                                 container_port=DSTACK_RUNNER_SSH_PORT,
+                             )
+                         ],
+                         security_context=client.V1SecurityContext(
+                             # TODO(#1535): support non-root images properly
+                             run_as_user=0,
+                             run_as_group=0,
+                             privileged=job.job_spec.privileged,
+                             capabilities=client.V1Capabilities(
+                                 add=[
+                                     # Allow to increase hard resource limits, see getrlimit(2)
+                                     "SYS_RESOURCE",
+                                 ],
+                             ),
+                         ),
+                         resources=client.V1ResourceRequirements(
+                             requests=resources_requests,
+                             limits=resources_limits,
+                         ),
+                         volume_mounts=volume_mounts,
+                     )
+                 ],
+                 affinity=node_affinity,
+                 tolerations=tolerations,
+                 volumes=volumes_,
              ),
          )
-         service_response = self.api.create_namespaced_service(
-             namespace=DEFAULT_NAMESPACE,
+         call_api_method(
+             self.api.create_namespaced_pod,
+             client.V1Pod,
+             namespace=self.config.namespace,
+             body=pod,
+         )
+         call_api_method(
+             self.api.create_namespaced_service,
+             client.V1Service,
+             namespace=self.config.namespace,
              body=client.V1Service(
                  metadata=client.V1ObjectMeta(name=_get_pod_service_name(instance_name)),
                  spec=client.V1ServiceSpec(
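
For a job requesting, say, 4 CPUs, 2 GPUs, and 32GB of memory, the logic above ends up with request/limit dictionaries along these lines (values hypothetical):

    resources_requests = {"cpu": "4", "memory": "32.0Gi", "nvidia.com/gpu": "2"}
    resources_limits = {"nvidia.com/gpu": "2"}

The GPU count appears in both requests and limits because Kubernetes extended resources cannot be overcommitted: a container may not request nvidia.com/gpu without an equal limit. CPU and memory are set as requests only, leaving the pod burstable.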
@@ -192,14 +355,16 @@ class KubernetesCompute(
                  ),
              ),
          )
-         service_ip = service_response.spec.cluster_ip
          return JobProvisioningData(
              backend=instance_offer.backend,
              instance_type=instance_offer.instance,
              instance_id=instance_name,
-             hostname=service_ip,
+             # Although we can already get Service's ClusterIP from the `V1Service` object returned
+             # by the `create_namespaced_service` method, we still need PodIP for multinode runs.
+             # We'll update both hostname and internal_ip once the pod is assigned to the node.
+             hostname=None,
              internal_ip=None,
-             region="local",
+             region=instance_offer.region,
              price=instance_offer.price,
              username="root",
              ssh_port=DSTACK_RUNNER_SSH_PORT,
@@ -212,25 +377,49 @@ class KubernetesCompute(
              backend_data=None,
          )

+     def update_provisioning_data(
+         self,
+         provisioning_data: JobProvisioningData,
+         project_ssh_public_key: str,
+         project_ssh_private_key: str,
+     ):
+         pod = call_api_method(
+             self.api.read_namespaced_pod,
+             client.V1Pod,
+             name=provisioning_data.instance_id,
+             namespace=self.config.namespace,
+         )
+         pod_ip = get_value(pod, ".status.pod_ip", str)
+         if not pod_ip:
+             return
+         provisioning_data.internal_ip = pod_ip
+         service = call_api_method(
+             self.api.read_namespaced_service,
+             client.V1Service,
+             name=_get_pod_service_name(provisioning_data.instance_id),
+             namespace=self.config.namespace,
+         )
+         provisioning_data.hostname = get_value(service, ".spec.cluster_ip", str, required=True)
+
      def terminate_instance(
          self, instance_id: str, region: str, backend_data: Optional[str] = None
      ):
-         try:
-             self.api.delete_namespaced_service(
-                 name=_get_pod_service_name(instance_id),
-                 namespace=DEFAULT_NAMESPACE,
-                 body=client.V1DeleteOptions(),
-             )
-         except client.ApiException as e:
-             if e.status != 404:
-                 raise
-         try:
-             self.api.delete_namespaced_pod(
-                 name=instance_id, namespace=DEFAULT_NAMESPACE, body=client.V1DeleteOptions()
-             )
-         except client.ApiException as e:
-             if e.status != 404:
-                 raise
+         call_api_method(
+             self.api.delete_namespaced_service,
+             client.V1Service,
+             expected=404,
+             name=_get_pod_service_name(instance_id),
+             namespace=self.config.namespace,
+             body=client.V1DeleteOptions(),
+         )
+         call_api_method(
+             self.api.delete_namespaced_pod,
+             client.V1Pod,
+             expected=404,
+             name=instance_id,
+             namespace=self.config.namespace,
+             body=client.V1DeleteOptions(),
+         )

      def create_gateway(
          self,
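
The expected=404 argument replaces the old try/except pattern: when the object is already gone, the helper returns None instead of raising, which keeps terminate_instance() safe to retry. Illustrative usage, assuming the call_api_method semantics sketched earlier:

    # Deleting a pod that may no longer exist is treated as success:
    result = call_api_method(
        api.delete_namespaced_pod,
        client.V1Pod,
        expected=404,  # ApiException(status=404) -> return None instead of raising
        name="dstack-runner-abc",  # hypothetical pod name
        namespace="dstack",
    )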
@@ -247,70 +436,79 @@ class KubernetesCompute(
          # https://docs.aws.amazon.com/eks/latest/userguide/network-load-balancing.html
          instance_name = generate_unique_gateway_instance_name(configuration)
          commands = _get_gateway_commands(authorized_keys=[configuration.ssh_key_pub])
-         self.api.create_namespaced_pod(
-             namespace=DEFAULT_NAMESPACE,
-             body=client.V1Pod(
-                 metadata=client.V1ObjectMeta(
-                     name=instance_name,
-                     labels={"app.kubernetes.io/name": instance_name},
-                 ),
-                 spec=client.V1PodSpec(
-                     containers=[
-                         client.V1Container(
-                             name=f"{instance_name}-container",
-                             image="ubuntu:22.04",
-                             command=["/bin/sh"],
-                             args=["-c", " && ".join(commands)],
-                             ports=[
-                                 client.V1ContainerPort(
-                                     container_port=22,
-                                 ),
-                                 client.V1ContainerPort(
-                                     container_port=80,
-                                 ),
-                                 client.V1ContainerPort(
-                                     container_port=443,
-                                 ),
-                             ],
-                         )
-                     ]
-                 ),
+         pod = client.V1Pod(
+             metadata=client.V1ObjectMeta(
+                 name=instance_name,
+                 labels={"app.kubernetes.io/name": instance_name},
+             ),
+             spec=client.V1PodSpec(
+                 containers=[
+                     client.V1Container(
+                         name=f"{instance_name}-container",
+                         image="ubuntu:22.04",
+                         command=["/bin/sh"],
+                         args=["-c", " && ".join(commands)],
+                         ports=[
+                             client.V1ContainerPort(
+                                 container_port=22,
+                             ),
+                             client.V1ContainerPort(
+                                 container_port=80,
+                             ),
+                             client.V1ContainerPort(
+                                 container_port=443,
+                             ),
+                         ],
+                     )
+                 ]
              ),
          )
-         self.api.create_namespaced_service(
-             namespace=DEFAULT_NAMESPACE,
-             body=client.V1Service(
-                 metadata=client.V1ObjectMeta(
-                     name=_get_pod_service_name(instance_name),
-                 ),
-                 spec=client.V1ServiceSpec(
-                     type="LoadBalancer",
-                     selector={"app.kubernetes.io/name": instance_name},
-                     ports=[
-                         client.V1ServicePort(
-                             name="ssh",
-                             port=22,
-                             target_port=22,
-                         ),
-                         client.V1ServicePort(
-                             name="http",
-                             port=80,
-                             target_port=80,
-                         ),
-                         client.V1ServicePort(
-                             name="https",
-                             port=443,
-                             target_port=443,
-                         ),
-                     ],
-                 ),
+         call_api_method(
+             self.api.create_namespaced_pod,
+             client.V1Pod,
+             namespace=self.config.namespace,
+             body=pod,
+         )
+         service = client.V1Service(
+             metadata=client.V1ObjectMeta(
+                 name=_get_pod_service_name(instance_name),
              ),
+             spec=client.V1ServiceSpec(
+                 type="LoadBalancer",
+                 selector={"app.kubernetes.io/name": instance_name},
+                 ports=[
+                     client.V1ServicePort(
+                         name="ssh",
+                         port=22,
+                         target_port=22,
+                     ),
+                     client.V1ServicePort(
+                         name="http",
+                         port=80,
+                         target_port=80,
+                     ),
+                     client.V1ServicePort(
+                         name="https",
+                         port=443,
+                         target_port=443,
+                     ),
+                 ],
+             ),
+         )
+         call_api_method(
+             self.api.create_namespaced_service,
+             client.V1Service,
+             namespace=self.config.namespace,
+             body=service,
          )
          hostname = _wait_for_load_balancer_hostname(
-             api=self.api, service_name=_get_pod_service_name(instance_name)
+             api=self.api,
+             namespace=self.config.namespace,
+             service_name=_get_pod_service_name(instance_name),
          )
+         region = DUMMY_REGION
          if hostname is None:
-             self.terminate_instance(instance_name, region="-")
+             self.terminate_instance(instance_name, region=region)
              raise ComputeError(
                  "Failed to get gateway hostname. "
                  "Ensure the Kubernetes cluster supports Load Balancer services."
@@ -318,7 +516,7 @@ class KubernetesCompute(
          return GatewayProvisioningData(
              instance_id=instance_name,
              ip_address=hostname,
-             region="-",
+             region=region,
          )

      def terminate_gateway(
@@ -334,15 +532,34 @@ class KubernetesCompute(
          )


- def _get_gpus_from_node_labels(labels: Dict) -> List[Gpu]:
-     # We rely on https://github.com/NVIDIA/gpu-feature-discovery to detect gpus.
-     # Note that "nvidia.com/gpu.product" is not a short gpu name like "T4" or "A100" but a product name
-     # from nvidia-smi like "Tesla-T4" or "A100-SXM4-40GB".
+ def _parse_cpu(cpu: str) -> int:
+     if cpu.endswith("m"):
+         # "m" means millicpu (1/1000 CPU), e.g., 7900m -> 7.9 -> 7
+         return int(float(cpu[:-1]) / 1000)
+     return int(cpu)
+
+
+ def _parse_memory(memory: str) -> int:
+     if memory.isdigit():
+         # no suffix means that the value is in bytes
+         return int(memory) // 2**20
+     return int(parse_memory(memory, as_untis="M"))
+
+
+ def _render_memory(memory: Memory) -> str:
+     return f"{float(memory)}Gi"
+
+
+ def _get_gpus_from_node_labels(labels: dict[str, str]) -> tuple[list[Gpu], Optional[str]]:
+     # We rely on https://github.com/NVIDIA/k8s-device-plugin/tree/main/docs/gpu-feature-discovery
+     # to detect gpus. Note that "nvidia.com/gpu.product" is not a short gpu name like "T4" or
+     # "A100" but a product name like "Tesla-T4" or "A100-SXM4-40GB".
      # Thus, we convert the product name to a known gpu name.
-     gpu_count = labels.get("nvidia.com/gpu.count")
-     gpu_product = labels.get("nvidia.com/gpu.product")
+     # TODO: support other GPU vendors
+     gpu_count = labels.get(NVIDIA_GPU_COUNT_LABEL)
+     gpu_product = labels.get(NVIDIA_GPU_PRODUCT_LABEL)
      if gpu_count is None or gpu_product is None:
-         return []
+         return [], None
      gpu_count = int(gpu_count)
      gpu_name = None
      for known_gpu_name in NVIDIA_GPU_NAMES:
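
Kubernetes reports allocatable quantities in mixed forms: bare numbers are bytes (or whole CPUs), "m" means millicpu, and suffixed values ("Ki"/"Mi"/"Gi") go through dstack's parse_memory(). A few illustrative values for the helpers above (not from the package's test suite):

    assert _parse_cpu("8") == 8
    assert _parse_cpu("7900m") == 7          # 7.9 CPUs, truncated to whole CPUs
    assert _parse_memory("16777216") == 16   # bare bytes -> MiB
    assert _parse_memory("32Gi") == 32 * 1024
    # _render_memory() goes the other way; assuming Memory is float-like, as
    # f"{float(memory)}Gi" suggests: _render_memory(Memory(0.5)) == "0.5Gi"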
@@ -350,20 +567,22 @@ def _get_gpus_from_node_labels(labels: Dict) -> List[Gpu]:
              gpu_name = known_gpu_name
              break
      if gpu_name is None:
-         return []
+         return [], None
      gpu_info = NVIDIA_GPU_NAME_TO_GPU_INFO[gpu_name]
      gpu_memory = gpu_info.memory * 1024
      # A100 may come in two variants
      if "40GB" in gpu_product:
          gpu_memory = 40 * 1024
-     return [
+     gpus = [
          Gpu(vendor=AcceleratorVendor.NVIDIA, name=gpu_name, memory_mib=gpu_memory)
          for _ in range(gpu_count)
      ]
+     return gpus, gpu_product


  def _continue_setup_jump_pod(
      api: client.CoreV1Api,
+     namespace: str,
      project_name: str,
      project_ssh_private_key: str,
      user_ssh_public_key: str,
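
With GPU Feature Discovery labels present on a node, the lookup resolves the product string to a known gpuhunt entry and now also returns the raw label value, which the node-affinity code in run_job() matches against. For example (hypothetical labels):

    labels = {
        "nvidia.com/gpu.count": "4",
        "nvidia.com/gpu.product": "Tesla-T4",
    }
    gpus, product = _get_gpus_from_node_labels(labels)
    # gpus    -> four Gpu(vendor=NVIDIA, name="T4", memory_mib=16384), per gpuhunt's catalog
    # product -> "Tesla-T4", usable as a nvidia.com/gpu.product affinity value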
@@ -372,6 +591,7 @@ def _continue_setup_jump_pod(
  ):
      _wait_for_pod_ready(
          api=api,
+         namespace=namespace,
          pod_name=_get_jump_pod_name(project_name),
      )
      _add_authorized_key_to_jump_pod(
@@ -384,82 +604,169 @@ def _continue_setup_jump_pod(


  def _create_jump_pod_service_if_not_exists(
      api: client.CoreV1Api,
+     namespace: str,
      project_name: str,
      ssh_public_keys: List[str],
      jump_pod_port: Optional[int],
  ) -> Tuple[int, bool]:
      created = False
-     try:
-         service = api.read_namespaced_service(
+     service: Optional[client.V1Service] = None
+     pod: Optional[client.V1Pod] = None
+     _namespace = call_api_method(
+         api.read_namespace,
+         client.V1Namespace,
+         expected=404,
+         name=namespace,
+     )
+     if _namespace is None:
+         _namespace = client.V1Namespace(
+             metadata=client.V1ObjectMeta(
+                 name=namespace,
+                 labels={"app.kubernetes.io/name": namespace},
+             ),
+         )
+         call_api_method(
+             api.create_namespace,
+             client.V1Namespace,
+             body=_namespace,
+         )
+     else:
+         service = call_api_method(
+             api.read_namespaced_service,
+             client.V1Service,
+             expected=404,
              name=_get_jump_pod_service_name(project_name),
-             namespace=DEFAULT_NAMESPACE,
+             namespace=namespace,
          )
-     except client.ApiException as e:
-         if e.status == 404:
-             service = _create_jump_pod_service(
-                 api=api,
-                 project_name=project_name,
-                 ssh_public_keys=ssh_public_keys,
-                 jump_pod_port=jump_pod_port,
-             )
-             created = True
-         else:
-             raise
-     return service.spec.ports[0].node_port, created
+         pod = call_api_method(
+             api.read_namespaced_pod,
+             client.V1Pod,
+             expected=404,
+             name=_get_jump_pod_name(project_name),
+             namespace=namespace,
+         )
+     # The service may exist without the pod if the node on which the jump pod was running
+     # has been deleted.
+     if service is None or pod is None:
+         service = _create_jump_pod_service(
+             api=api,
+             namespace=namespace,
+             project_name=project_name,
+             ssh_public_keys=ssh_public_keys,
+             jump_pod_port=jump_pod_port,
+         )
+         created = True
+     port = get_value(service, ".spec.ports[0].node_port", int, required=True)
+     return port, created


  def _create_jump_pod_service(
      api: client.CoreV1Api,
+     namespace: str,
      project_name: str,
      ssh_public_keys: List[str],
      jump_pod_port: Optional[int],
  ) -> client.V1Service:
      # TODO use restricted ssh-forwarding-only user for jump pod instead of root.
-     commands = _get_jump_pod_commands(authorized_keys=ssh_public_keys)
      pod_name = _get_jump_pod_name(project_name)
-     api.create_namespaced_pod(
-         namespace=DEFAULT_NAMESPACE,
-         body=client.V1Pod(
-             metadata=client.V1ObjectMeta(
-                 name=pod_name,
-                 labels={"app.kubernetes.io/name": pod_name},
-             ),
-             spec=client.V1PodSpec(
-                 containers=[
-                     client.V1Container(
-                         name=f"{pod_name}-container",
-                         # TODO: Choose appropriate image for jump pod
-                         image="dstackai/base:py3.11-0.4rc4",
-                         command=["/bin/sh"],
-                         args=["-c", " && ".join(commands)],
-                         ports=[
-                             client.V1ContainerPort(
-                                 container_port=JUMP_POD_SSH_PORT,
-                             )
-                         ],
-                     )
-                 ]
-             ),
+     call_api_method(
+         api.delete_namespaced_pod,
+         client.V1Pod,
+         expected=404,
+         namespace=namespace,
+         name=pod_name,
+     )
+
+     node_list = call_api_method(api.list_node, client.V1NodeList)
+     nodes = get_value(node_list, ".items", list[client.V1Node], required=True)
+     # False if we found at least one node without any "hard" taint, that is, if we don't need to
+     # specify the toleration.
+     toleration_required = True
+     # (key, effect) pairs.
+     tolerated_taints: set[tuple[str, str]] = set()
+     for node in nodes:
+         # True if the node has at least one NoExecute or NoSchedule taint.
+         has_hard_taint = False
+         taints = get_value(node, ".spec.taints", list[client.V1Taint]) or []
+         for taint in taints:
+             effect = get_value(taint, ".effect", str, required=True)
+             # A "soft" taint, ignore.
+             if effect == TaintEffect.PREFER_NO_SCHEDULE:
+                 continue
+             has_hard_taint = True
+             key = get_value(taint, ".key", str, required=True)
+             if key in TOLERATED_NODE_TAINTS:
+                 tolerated_taints.add((key, effect))
+         if not has_hard_taint:
+             toleration_required = False
+             break
+     tolerations: list[client.V1Toleration] = []
+     if toleration_required:
+         for key, effect in tolerated_taints:
+             tolerations.append(
+                 client.V1Toleration(key=key, operator=Operator.EXISTS, effect=effect)
+             )
+         if not tolerations:
+             logger.warning("No appropriate node found, the jump pod may never be scheduled")
+
+     commands = _get_jump_pod_commands(authorized_keys=ssh_public_keys)
+     pod = client.V1Pod(
+         metadata=client.V1ObjectMeta(
+             name=pod_name,
+             labels={"app.kubernetes.io/name": pod_name},
+         ),
+         spec=client.V1PodSpec(
+             containers=[
+                 client.V1Container(
+                     name=f"{pod_name}-container",
+                     # TODO: Choose appropriate image for jump pod
+                     image="dstackai/base:py3.11-0.4rc4",
+                     command=["/bin/sh"],
+                     args=["-c", " && ".join(commands)],
+                     ports=[
+                         client.V1ContainerPort(
+                             container_port=JUMP_POD_SSH_PORT,
+                         )
+                     ],
+                 )
+             ],
+             tolerations=tolerations,
          ),
      )
-     service_response = api.create_namespaced_service(
-         namespace=DEFAULT_NAMESPACE,
-         body=client.V1Service(
-             metadata=client.V1ObjectMeta(name=_get_jump_pod_service_name(project_name)),
-             spec=client.V1ServiceSpec(
-                 type="NodePort",
-                 selector={"app.kubernetes.io/name": pod_name},
-                 ports=[
-                     client.V1ServicePort(
-                         port=JUMP_POD_SSH_PORT,
-                         target_port=JUMP_POD_SSH_PORT,
-                         node_port=jump_pod_port,
-                     )
-                 ],
-             ),
+     call_api_method(
+         api.create_namespaced_pod,
+         client.V1Pod,
+         namespace=namespace,
+         body=pod,
+     )
+     service_name = _get_jump_pod_service_name(project_name)
+     call_api_method(
+         api.delete_namespaced_service,
+         client.V1Service,
+         expected=404,
+         namespace=namespace,
+         name=service_name,
+     )
+     service = client.V1Service(
+         metadata=client.V1ObjectMeta(name=service_name),
+         spec=client.V1ServiceSpec(
+             type="NodePort",
+             selector={"app.kubernetes.io/name": pod_name},
+             ports=[
+                 client.V1ServicePort(
+                     port=JUMP_POD_SSH_PORT,
+                     target_port=JUMP_POD_SSH_PORT,
+                     node_port=jump_pod_port,
+                 )
+             ],
          ),
      )
-     return service_response
+     return call_api_method(
+         api.create_namespaced_service,
+         client.V1Service,
+         namespace=namespace,
+         body=service,
+     )


  def _get_jump_pod_commands(authorized_keys: List[str]) -> List[str]:
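
The taint scan above prefers leaving the jump pod without tolerations: they are added only when every node carries a hard (NoSchedule/NoExecute) taint, and then only for taints dstack knows about. A pure-Python restatement of that decision, for illustration only:

    def needs_tolerations(node_taints: list[list[tuple[str, str]]]) -> bool:
        # node_taints: per node, a list of (taint_key, effect) pairs.
        # True only if every node has at least one non-PreferNoSchedule taint.
        return all(
            any(effect != "PreferNoSchedule" for _key, effect in taints)
            for taints in node_taints
        )

    assert needs_tolerations([[("nvidia.com/gpu", "NoSchedule")]])          # all nodes tainted
    assert not needs_tolerations([[("nvidia.com/gpu", "NoSchedule")], []])  # one untainted node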
@@ -484,20 +791,25 @@ def _get_jump_pod_commands(authorized_keys: List[str]) -> List[str]:


  def _wait_for_pod_ready(
      api: client.CoreV1Api,
+     namespace: str,
      pod_name: str,
      timeout_seconds: int = 300,
  ):
      start_time = time.time()
      while True:
-         try:
-             pod = api.read_namespaced_pod(name=pod_name, namespace=DEFAULT_NAMESPACE)
-         except client.ApiException as e:
-             if e.status != 404:
-                 raise
-         else:
-             if pod.status.phase == "Running" and all(
-                 container_status.ready for container_status in pod.status.container_statuses
-             ):
+         pod = call_api_method(
+             api.read_namespaced_pod,
+             client.V1Pod,
+             expected=404,
+             name=pod_name,
+             namespace=namespace,
+         )
+         if pod is not None:
+             phase = get_value(pod, ".status.phase", str, required=True)
+             container_statuses = get_value(
+                 pod, ".status.container_statuses", list[client.V1ContainerStatus], required=True
+             )
+             if phase == "Running" and all(status.ready for status in container_statuses):
                  return True
          elapsed_time = time.time() - start_time
          if elapsed_time >= timeout_seconds:
@@ -508,19 +820,23 @@ def _wait_for_pod_ready(


  def _wait_for_load_balancer_hostname(
      api: client.CoreV1Api,
+     namespace: str,
      service_name: str,
      timeout_seconds: int = 120,
  ) -> Optional[str]:
      start_time = time.time()
      while True:
-         try:
-             service = api.read_namespaced_service(name=service_name, namespace=DEFAULT_NAMESPACE)
-         except client.ApiException as e:
-             if e.status != 404:
-                 raise
-         else:
-             if service.status.load_balancer.ingress is not None:
-                 return service.status.load_balancer.ingress[0].hostname
+         service = call_api_method(
+             api.read_namespaced_service,
+             client.V1Service,
+             expected=404,
+             name=service_name,
+             namespace=namespace,
+         )
+         if service is not None:
+             hostname = get_value(service, ".status.load_balancer.ingress[0].hostname", str)
+             if hostname is not None:
+                 return hostname
          elapsed_time = time.time() - start_time
          if elapsed_time >= timeout_seconds:
              logger.warning("Timeout waiting for load balancer %s to get ip", service_name)
@@ -607,11 +923,11 @@ def _run_ssh_command(hostname: str, port: int, ssh_private_key: str, command: st


  def _get_jump_pod_name(project_name: str) -> str:
-     return f"{project_name}-ssh-jump-pod"
+     return f"dstack-{project_name}-ssh-jump-pod"


  def _get_jump_pod_service_name(project_name: str) -> str:
-     return f"{project_name}-ssh-jump-pod-service"
+     return f"dstack-{project_name}-ssh-jump-pod-service"


  def _get_pod_service_name(pod_name: str) -> str: