kubetorch-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kubetorch has been flagged as possibly problematic.

Files changed (93)
  1. kubetorch/__init__.py +60 -0
  2. kubetorch/cli.py +1985 -0
  3. kubetorch/cli_utils.py +1025 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +285 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +157 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +133 -0
  30. kubetorch/resources/callables/module.py +1416 -0
  31. kubetorch/resources/callables/utils.py +174 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +261 -0
  34. kubetorch/resources/compute/compute.py +2596 -0
  35. kubetorch/resources/compute/decorators.py +139 -0
  36. kubetorch/resources/compute/rbac.py +74 -0
  37. kubetorch/resources/compute/utils.py +1114 -0
  38. kubetorch/resources/compute/websocket.py +137 -0
  39. kubetorch/resources/images/__init__.py +1 -0
  40. kubetorch/resources/images/image.py +414 -0
  41. kubetorch/resources/images/images.py +74 -0
  42. kubetorch/resources/secrets/__init__.py +2 -0
  43. kubetorch/resources/secrets/kubernetes_secrets_client.py +412 -0
  44. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  45. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  46. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  47. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  48. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  49. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  50. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  51. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  52. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  53. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  54. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  55. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  56. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  57. kubetorch/resources/secrets/provider_secrets/providers.py +93 -0
  58. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  59. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  60. kubetorch/resources/secrets/secret.py +238 -0
  61. kubetorch/resources/secrets/secret_factory.py +70 -0
  62. kubetorch/resources/secrets/utils.py +209 -0
  63. kubetorch/resources/volumes/__init__.py +0 -0
  64. kubetorch/resources/volumes/volume.py +365 -0
  65. kubetorch/servers/__init__.py +0 -0
  66. kubetorch/servers/http/__init__.py +0 -0
  67. kubetorch/servers/http/distributed_utils.py +3223 -0
  68. kubetorch/servers/http/http_client.py +730 -0
  69. kubetorch/servers/http/http_server.py +1788 -0
  70. kubetorch/servers/http/server_metrics.py +278 -0
  71. kubetorch/servers/http/utils.py +728 -0
  72. kubetorch/serving/__init__.py +0 -0
  73. kubetorch/serving/autoscaling.py +173 -0
  74. kubetorch/serving/base_service_manager.py +363 -0
  75. kubetorch/serving/constants.py +83 -0
  76. kubetorch/serving/deployment_service_manager.py +478 -0
  77. kubetorch/serving/knative_service_manager.py +519 -0
  78. kubetorch/serving/raycluster_service_manager.py +582 -0
  79. kubetorch/serving/service_manager.py +18 -0
  80. kubetorch/serving/templates/deployment_template.yaml +17 -0
  81. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  82. kubetorch/serving/templates/kt_setup_template.sh.j2 +81 -0
  83. kubetorch/serving/templates/pod_template.yaml +194 -0
  84. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  85. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  86. kubetorch/serving/templates/service_template.yaml +21 -0
  87. kubetorch/serving/templates/workerset_template.yaml +36 -0
  88. kubetorch/serving/utils.py +377 -0
  89. kubetorch/utils.py +284 -0
  90. kubetorch-0.2.0.dist-info/METADATA +121 -0
  91. kubetorch-0.2.0.dist-info/RECORD +93 -0
  92. kubetorch-0.2.0.dist-info/WHEEL +4 -0
  93. kubetorch-0.2.0.dist-info/entry_points.txt +5 -0
@@ -0,0 +1,2596 @@
+ import os
+ import re
+ import shlex
+ import subprocess
+ import time
+ from pathlib import Path
+ from typing import Dict, List, Optional, Union
+ from urllib.parse import urlparse
+
+ import yaml
+
+ from kubernetes import client, config
+ from kubernetes.client import V1ResourceRequirements
+
+ import kubetorch.constants as constants
+ import kubetorch.serving.constants as serving_constants
+
+ from kubetorch import globals
+
+ from kubetorch.logger import get_logger
+ from kubetorch.resources.callables.utils import find_locally_installed_version
+ from kubetorch.resources.compute.utils import (
+     _get_rsync_exclude_options,
+     _get_sync_package_paths,
+     _run_bash,
+     find_available_port,
+     RsyncError,
+ )
+ from kubetorch.resources.compute.websocket import WebSocketRsyncTunnel
+ from kubetorch.resources.images.image import Image, ImageSetupStepType
+ from kubetorch.resources.secrets.kubernetes_secrets_client import (
+     KubernetesSecretsClient,
+ )
+ from kubetorch.resources.volumes.volume import Volume
+ from kubetorch.servers.http.utils import is_running_in_kubernetes, load_template
+ from kubetorch.serving.autoscaling import AutoscalingConfig
+ from kubetorch.serving.service_manager import (
+     DeploymentServiceManager,
+     KnativeServiceManager,
+     RayClusterServiceManager,
+ )
+ from kubetorch.serving.utils import GPUConfig, pod_is_running, RequestedPodResources
+
+ from kubetorch.utils import extract_host_port, http_to_ws, load_kubeconfig
+
+ logger = get_logger(__name__)
+
+
+ class Compute:
+     def __init__(
+         self,
+         cpus: Union[str, int] = None,
+         memory: str = None,
+         disk_size: str = None,
+         gpus: Union[str, int] = None,
+         queue: str = None,
+         priority_class_name: str = None,
+         gpu_type: str = None,
+         gpu_memory: str = None,
+         namespace: str = None,
+         image: "Image" = None,
+         labels: Dict = None,
+         annotations: Dict = None,
+         volumes: List[Union[str, Volume]] = None,
+         node_selector: Dict = None,
+         service_template: Dict = None,
+         tolerations: List[Dict] = None,
+         env_vars: Dict = None,
+         secrets: List[Union[str, "Secret"]] = None,
+         freeze: bool = False,
+         kubeconfig_path: str = None,
+         service_account_name: str = None,
+         image_pull_policy: str = None,
+         inactivity_ttl: str = None,
+         gpu_anti_affinity: bool = None,
+         launch_timeout: int = None,
+         working_dir: str = None,
+         shared_memory_limit: str = None,
+         allowed_serialization: Optional[List[str]] = None,
+         replicas: int = 1,
+         _skip_template_init: bool = False,
+     ):
+         """Initialize the compute requirements for a Kubetorch service.
+
+         Args:
+             cpus (str, int, optional): CPU resource request. Can be specified in cores ("1.0") or millicores ("1000m").
+             memory (str, optional): Memory resource request. Can use binary (Ki, Mi, Gi) or decimal (K, M, G) units.
+             disk_size (str, optional): Ephemeral storage request. Uses the same format as memory.
+             gpus (str or int, optional): Number of GPUs to request. Fractional GPUs are not currently supported.
+             gpu_type (str, optional): GPU type to request. Corresponds to the "nvidia.com/gpu.product" label on the
+                 node (if GPU feature discovery is enabled), or a full string like "nvidia.com/gpu.product: L4" can be
+                 passed, which will be used to set a `nodeSelector` on the service. More info below.
+             gpu_memory (str, optional): GPU memory request (e.g., "4Gi"). Will still request a whole GPU but limit
+                 memory usage.
+             queue (str, optional): Name of the Kubernetes queue that will be responsible for placing the service's
+                 pods onto nodes. Controls how cluster resources are allocated and prioritized for this service.
+                 Pods will be scheduled according to the quota, priority, and limits configured for the queue.
+             priority_class_name (str, optional): Name of the Kubernetes priority class to use for the service. If
+                 not specified, the default priority class will be used.
+             namespace (str, optional): Kubernetes namespace. Defaults to the global config default, or "default".
+             image (Image, optional): Kubetorch image configuration. See :class:`Image` for more details.
+             labels (Dict, optional): Kubernetes labels to apply to the service.
+             annotations (Dict, optional): Kubernetes annotations to apply to the service.
+             volumes (List[Union[str, Volume]], optional): Volumes to attach to the service. Can be specified as a
+                 list of volume names (strings) or Volume objects. If using strings, they must be the names of existing
+                 PersistentVolumeClaims (PVCs) in the specified namespace.
+             node_selector (Dict, optional): Kubernetes node selector to constrain pods to specific nodes. Should be a
+                 dictionary of key-value pairs, e.g. `{"node.kubernetes.io/instance-type": "g4dn.xlarge"}`.
+             service_template (Dict, optional): Nested dictionary of service template arguments to apply to the service. E.g.
+                 ``{"spec": {"template": {"spec": {"nodeSelector": {"node.kubernetes.io/instance-type": "g4dn.xlarge"}}}}}``
+             tolerations (List[Dict], optional): Kubernetes tolerations to apply to the service. Each toleration should
+                 be a dictionary with keys like "key", "operator", "value", and "effect". More info
+                 `here <https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/>`__.
+             env_vars (Dict, optional): Environment variables to set in containers.
+             secrets (List[Union[str, Secret]], optional): Secrets to mount or expose.
+             freeze (bool, optional): Whether to freeze the compute configuration (e.g. for production).
+             kubeconfig_path (str, optional): Path to the local kubeconfig file used for cluster authentication.
+             service_account_name (str, optional): Kubernetes service account to use.
+             image_pull_policy (str, optional): Container image pull policy.
+                 More info `here <https://kubernetes.io/docs/concepts/containers/images/#image-pull-policy>`__.
+             inactivity_ttl (str, optional): Time-to-live after inactivity. Once hit, the service will be destroyed.
+                 Values below 1m may cause premature deletion.
+             gpu_anti_affinity (bool, optional): Whether to prevent scheduling the service on a GPU node when no GPUs
+                 are requested. Can also be controlled globally by setting the `KT_GPU_ANTI_AFFINITY` environment
+                 variable. (Default: ``False``)
+             launch_timeout (int, optional): Determines how long to wait for the service to become ready before giving
+                 up. If not specified, will wait {serving_constants.KT_LAUNCH_TIMEOUT} seconds.
+                 Note: you can also control this timeout globally by setting the `KT_LAUNCH_TIMEOUT` environment variable.
+             replicas (int, optional): Number of replicas to create for deployment-based services. Can also be set via
+                 the `.distribute(workers=N)` method for distributed training. (Default: 1)
+             working_dir (str, optional): Working directory to use inside the remote container. Must be an absolute path (e.g. `/kt`).
+             shared_memory_limit (str, optional): Maximum size of the shared memory filesystem (/dev/shm) available to
+                 each pod created by the service. Value should be a Kubernetes quantity string, for example: "512Mi",
+                 "2Gi", "1G", "1024Mi", "100M". If not provided, /dev/shm will default to the pod's memory limit (if set)
+                 or up to half the node's RAM.
+
+         Note:
+             **Resource Specification Formats:**
+
+             CPUs:
+                 - Decimal core count: "0.5", "1.0", "2.0"
+                 - Millicores: "500m", "1000m", "2000m"
+
+             Memory:
+                 - Bytes: "1000000"
+                 - Binary units: "1Ki", "1Mi", "1Gi", "1Ti"
+                 - Decimal units: "1K", "1M", "1G", "1T"
+
+             GPU Specifications:
+                 1. ``gpus`` for whole GPUs: "1", "2"
+                 2. ``gpu_memory``: "4Gi", "16Gi"
+
+             Disk Size:
+                 - Same format as memory
+
+         Note:
+             - Memory/disk values are case sensitive (Mi != mi)
+             - When using ``gpu_memory``, a whole GPU is still requested but memory is limited
+
+         Examples:
+
+         .. code-block:: python
+
+             import kubetorch as kt
+
+             # Basic CPU/Memory request
+             compute = kt.Compute(cpus="0.5", memory="2Gi")
+
+             # GPU request with memory limit
+             compute = kt.Compute(gpu_memory="4Gi", cpus="1.0")
+
+             # Multiple whole GPUs
+             compute = kt.Compute(gpus="2", memory="16Gi")
+         """
+         self.default_config = {}
+
+         self._endpoint = None
+         self._service_manager = None
+         self._autoscaling_config = None
+         self._kubeconfig_path = kubeconfig_path
+         self._namespace = namespace or globals.config.namespace
+
+         self._objects_api = None
+         self._core_api = None
+         self._apps_v1_api = None
+         self._node_v1_api = None
+
+         self._image = image
+         self._service_name = None
+         self._secrets = secrets
+         self._secrets_client = None
+         self._volumes = volumes
+         self._queue = queue
+
+         # service template args to store
+         self.replicas = replicas
+         self.labels = labels or {}
+         self.annotations = annotations or {}
+         self.service_template = service_template or {}
+         self._gpu_annotations = {}  # Will be populated during init or from_template
+
+         # Skip template initialization if loading from existing service
+         if _skip_template_init:
+             return
+
+         # determine pod template vars
+         server_port = serving_constants.DEFAULT_KT_SERVER_PORT
+         service_account_name = (
+             service_account_name or serving_constants.DEFAULT_SERVICE_ACCOUNT_NAME
+         )
+         otel_enabled = (
+             globals.config.cluster_config.get("otel_enabled", False)
+             if globals.config.cluster_config
+             else False
+         )
+         server_image = self._get_server_image(image, otel_enabled, inactivity_ttl)
+         gpus = None if gpus in (0, None) else gpus
+         gpu_config = self._load_gpu_config(gpus, gpu_memory, gpu_type)
+         self._gpu_annotations = self._get_gpu_annotations(gpu_config)
+         requested_resources = self._get_requested_resources(
+             cpus, memory, disk_size, gpu_config
+         )
+         secret_env_vars, secret_volumes = self._extract_secrets(secrets)
+         volume_mounts, volume_specs = self._volumes_for_pod_template(volumes)
+         scheduler_name = self._get_scheduler_name(queue)
+         node_selector = self._get_node_selector(
+             node_selector.copy() if node_selector else {}, gpu_type
+         )
+
+         env_vars = env_vars or {}
+         if os.getenv("KT_LOG_LEVEL") and not env_vars.get("KT_LOG_LEVEL"):
+             # If KT_LOG_LEVEL is set, add it to env vars so the log level is set on the server
+             env_vars["KT_LOG_LEVEL"] = os.getenv("KT_LOG_LEVEL")
+
+         template_vars = {
+             "server_image": server_image,
+             "server_port": server_port,
+             "env_vars": env_vars,
+             "resources": requested_resources,
+             "node_selector": node_selector,
+             "secret_env_vars": secret_env_vars or [],
+             "secret_volumes": secret_volumes or [],
+             "volume_mounts": volume_mounts,
+             "volume_specs": volume_specs,
+             "service_account_name": service_account_name,
+             "config_env_vars": self._get_config_env_vars(
+                 allowed_serialization or ["json"]
+             ),
+             "image_pull_policy": image_pull_policy,
+             "namespace": self._namespace,
+             "freeze": freeze,
+             "gpu_anti_affinity": gpu_anti_affinity,
+             "working_dir": working_dir,
+             "tolerations": tolerations,
+             "shm_size_limit": shared_memory_limit,
+             "priority_class_name": priority_class_name,
+             "launch_timeout": self._get_launch_timeout(launch_timeout),
+             "queue_name": self.queue_name(),
+             "scheduler_name": scheduler_name,
+             "inactivity_ttl": inactivity_ttl,
+             "otel_enabled": otel_enabled,
+             # launch time arguments
+             "raycluster": False,
+             "setup_script": "",
+         }
+
+         self.pod_template = load_template(
+             template_file=serving_constants.POD_TEMPLATE_FILE,
+             template_dir=os.path.join(
+                 os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
+                 "serving",
+                 "templates",
+             ),
+             **template_vars,
+         )
+
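To make the template plumbing above concrete, here is a brief usage sketch (illustrative only, not part of the diff; it assumes a kubetorch install with its bundled pod templates and a resolvable kubeconfig/config):

    import kubetorch as kt

    compute = kt.Compute(cpus="1.0", memory="2Gi", gpus="1", gpu_type="L4")
    # The constructor renders these into the pod template, so the values can be
    # read back off the template-backed properties defined further down:
    print(compute.gpu_type)   # "L4", stored under the nvidia.com/gpu.product nodeSelector
    print(compute.resources)  # requests/limits dict produced by _get_requested_resources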
+     @classmethod
+     def from_template(cls, service_info: dict):
+         """Create a Compute object from a deployed Kubernetes resource."""
+         if "resource" not in service_info:
+             raise ValueError("service_info missing required key: resource")
+
+         resource = service_info["resource"]
+         kind = resource.get("kind", "Unknown")
+
+         if kind == "RayCluster":
+             template_path = resource["spec"]["headGroupSpec"]["template"]
+         elif kind in ["Deployment", "Service"]:  # Deployment or Knative Service
+             template_path = resource["spec"]["template"]
+         else:
+             raise ValueError(
+                 f"Unsupported resource kind: '{kind}'. "
+                 f"Supported kinds are: Deployment, Service (Knative), RayCluster"
+             )
+
+         template_metadata = template_path["metadata"]
+         pod_spec = template_path["spec"]
+
+         annotations = template_metadata.get("annotations", {})
+
+         compute = cls(_skip_template_init=True)
+         compute.pod_template = pod_spec
+
+         # Set properties from manifest
+         compute._namespace = service_info["namespace"]
+         compute.replicas = resource["spec"].get("replicas")
+         compute.labels = template_metadata.get("labels", {})
+         compute.annotations = annotations
+         compute._autoscaling_config = annotations.get(
+             "autoscaling.knative.dev/config", {}
+         )
+         compute._queue = template_metadata.get("labels", {}).get("kai.scheduler/queue")
+         compute._kubeconfig_path = annotations.get(
+             serving_constants.KUBECONFIG_PATH_ANNOTATION
+         )
+
+         # Extract GPU annotations directly from template annotations
+         gpu_annotation_keys = ["gpu-memory", "gpu-fraction"]
+         compute._gpu_annotations = {
+             k: v for k, v in annotations.items() if k in gpu_annotation_keys
+         }
+
+         return compute
+
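A minimal round-trip sketch for from_template (the service_info dict below is hypothetical; in practice it comes from the Kubernetes API):

    service_info = {
        "namespace": "default",
        "resource": {
            "kind": "Deployment",
            "spec": {
                "replicas": 2,
                "template": {
                    "metadata": {"labels": {}, "annotations": {}},
                    "spec": {"containers": [{"name": "kt-server", "image": "my-img"}]},
                },
            },
        },
    }
    compute = Compute.from_template(service_info)
    assert compute.replicas == 2 and compute.namespace == "default"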
+     # ----------------- Properties ----------------- #
+     @property
+     def objects_api(self):
+         if self._objects_api is None:
+             self._objects_api = client.CustomObjectsApi()
+         return self._objects_api
+
+     @property
+     def core_api(self):
+         if self._core_api is None:
+             load_kubeconfig()
+             self._core_api = client.CoreV1Api()
+         return self._core_api
+
+     @property
+     def apps_v1_api(self):
+         if self._apps_v1_api is None:
+             self._apps_v1_api = client.AppsV1Api()
+         return self._apps_v1_api
+
+     @property
+     def node_v1_api(self):
+         if self._node_v1_api is None:
+             self._node_v1_api = client.NodeV1Api()
+         return self._node_v1_api
+
+     @property
+     def kubeconfig_path(self):
+         if self._kubeconfig_path is None:
+             self._kubeconfig_path = (
+                 os.getenv("KUBECONFIG") or constants.DEFAULT_KUBECONFIG_PATH
+             )
+         return str(Path(self._kubeconfig_path).expanduser())
+
+     @property
+     def service_manager(self):
+         if self._service_manager is None:
+             self._load_kube_config()
+             # Select appropriate service manager based on configuration
+             if self.deployment_mode == "knative":
+                 # Use KnativeServiceManager for autoscaling services
+                 self._service_manager = KnativeServiceManager(
+                     objects_api=self.objects_api,
+                     core_api=self.core_api,
+                     apps_v1_api=self.apps_v1_api,
+                     namespace=self.namespace,
+                 )
+             elif self.deployment_mode == "raycluster":
+                 # Use RayClusterServiceManager for Ray distributed workloads
+                 self._service_manager = RayClusterServiceManager(
+                     objects_api=self.objects_api,
+                     core_api=self.core_api,
+                     apps_v1_api=self.apps_v1_api,
+                     namespace=self.namespace,
+                 )
+             else:
+                 # Use DeploymentServiceManager for regular deployments
+                 self._service_manager = DeploymentServiceManager(
+                     objects_api=self.objects_api,
+                     core_api=self.core_api,
+                     apps_v1_api=self.apps_v1_api,
+                     namespace=self.namespace,
+                 )
+         return self._service_manager
+
+     @property
+     def secrets_client(self):
+         if not self.secrets:
+             # Skip creating secrets client if no secrets are provided
+             return None
+
+         if self._secrets_client is None:
+             self._secrets_client = KubernetesSecretsClient(
+                 namespace=self.namespace, kubeconfig_path=self.kubeconfig_path
+             )
+         return self._secrets_client
+
+     @property
+     def image(self):
+         return self._image
+
+     @image.setter
+     def image(self, value: "Image"):
+         self._image = value
+
+     @property
+     def endpoint(self):
+         if self._endpoint is None and self.service_name:
+             self._endpoint = self.service_manager.get_endpoint(self.service_name)
+         return self._endpoint
+
+     @endpoint.setter
+     def endpoint(self, endpoint: str):
+         self._endpoint = endpoint
+
+     def _container(self):
+         """Get the container from the pod template."""
+         if "containers" not in self.pod_template:
+             raise ValueError("pod_template missing 'containers' field.")
+         return self.pod_template["containers"][0]
+
+     def _container_env(self):
+         container = self._container()
+         if "env" not in container:
+             return []
+         return container["env"]
+
+     def _set_container_resource(self, resource_name: str, value: str):
+         container = self._container()
+
+         # Ensure resources dict exists
+         if "resources" not in container:
+             container["resources"] = {}
+
+         # Ensure requests dict exists
+         if "requests" not in container["resources"]:
+             container["resources"]["requests"] = {}
+
+         # Ensure limits dict exists
+         if "limits" not in container["resources"]:
+             container["resources"]["limits"] = {}
+
+         # Set both requests and limits to the same value
+         container["resources"]["requests"][resource_name] = value
+         container["resources"]["limits"][resource_name] = value
+
+     def _get_container_resource(self, resource_name: str) -> Optional[str]:
+         resources = self._container().get("resources", {})
+         requests = resources.get("requests", {})
+         return requests.get(resource_name)
+
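Behavior sketch for the two helpers above, using a hand-built pod template rather than the rendered one (assumes a default kubetorch config is loadable): _set_container_resource mirrors the value into both requests and limits, and _get_container_resource reads back from requests.

    compute = Compute(_skip_template_init=True)
    compute.pod_template = {"containers": [{"name": "kt-server"}]}
    compute._set_container_resource("memory", "4Gi")
    resources = compute._container()["resources"]
    assert resources["requests"]["memory"] == resources["limits"]["memory"] == "4Gi"
    assert compute._get_container_resource("memory") == "4Gi"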
+     # -------------- Properties From Template -------------- #
+     @property
+     def server_image(self):
+         return self._container().get("image")
+
+     @server_image.setter
+     def server_image(self, value: str):
+         """Set the server image in the pod template."""
+         self._container()["image"] = value
+
+     @property
+     def server_port(self):
+         return self._container()["ports"][0].get("containerPort")
+
+     @server_port.setter
+     def server_port(self, value: int):
+         """Set the server port in the pod template."""
+         self._container()["ports"][0]["containerPort"] = value
+
+     @property
+     def env_vars(self):
+         # extract user-defined environment variables from rendered pod template
+         kt_env_vars = [
+             "POD_NAME",
+             "POD_NAMESPACE",
+             "POD_IP",
+             "POD_UUID",
+             "MODULE_NAME",
+             "KUBETORCH_VERSION",
+             "UV_LINK_MODE",
+             "OTEL_SERVICE_NAME",
+             "OTEL_EXPORTER_OTLP_ENDPOINT",
+             "OTEL_EXPORTER_OTLP_PROTOCOL",
+             "OTEL_TRACES_EXPORTER",
+             "OTEL_PROPAGATORS",
+             "KT_SERVER_PORT",
+             "KT_FREEZE",
+             "KT_INACTIVITY_TTL",
+             "KT_ALLOWED_SERIALIZATION",
+             "KT_FILE_PATH",
+             "KT_MODULE_NAME",
+             "KT_CLS_OR_FN_NAME",
+             "KT_CALLABLE_TYPE",
+             "KT_LAUNCH_ID",
+             "KT_SERVICE_NAME",
+             "KT_SERVICE_DNS",
+         ]
+         user_env_vars = {}
+         for env_var in self._container_env():
+             # skip if it was set by kubetorch
+             if env_var["name"] not in kt_env_vars and "value" in env_var:
+                 user_env_vars[env_var["name"]] = env_var["value"]
+         return user_env_vars
+
+     @property
+     def resources(self):
+         return self._container().get("resources")
+
+     @property
+     def cpus(self):
+         return self._get_container_resource("cpu")
+
+     @cpus.setter
+     def cpus(self, value: str):
+         """
+         Args:
+             value: CPU value (e.g., "2", "1000m", "0.5")
+         """
+         self._set_container_resource("cpu", value)
+
+     @property
+     def memory(self):
+         return self._get_container_resource("memory")
+
+     @memory.setter
+     def memory(self, value: str):
+         """
+         Args:
+             value: Memory value (e.g., "4Gi", "2048Mi")
+         """
+         self._set_container_resource("memory", value)
+
+     @property
+     def disk_size(self):
+         return self._get_container_resource("ephemeral-storage")
+
+     @disk_size.setter
+     def disk_size(self, value: str):
+         """
+         Args:
+             value: Disk size (e.g., "10Gi", "5000Mi")
+         """
+         self._set_container_resource("ephemeral-storage", value)
+
+     @property
+     def gpus(self):
+         return self._get_container_resource("nvidia.com/gpu")
+
+     @gpus.setter
+     def gpus(self, value: Union[str, int]):
+         """
+         Args:
+             value: Number of GPUs (e.g., 1, "2")
+         """
+         self._set_container_resource("nvidia.com/gpu", str(value))
+
+     @property
+     def gpu_type(self):
+         node_selector = self.pod_template.get("nodeSelector")
+         if node_selector and "nvidia.com/gpu.product" in node_selector:
+             return node_selector["nvidia.com/gpu.product"]
+         return None
+
+     @gpu_type.setter
+     def gpu_type(self, value: str):
+         """
+         Args:
+             value: GPU product name (e.g., "L4", "V100", "A100", "T4")
+         """
+         if "nodeSelector" not in self.pod_template:
+             self.pod_template["nodeSelector"] = {}
+         self.pod_template["nodeSelector"]["nvidia.com/gpu.product"] = value
+
+     @property
+     def gpu_memory(self):
+         annotations = self.pod_template.get("annotations", {})
+         if "gpu-memory" in annotations:
+             return annotations["gpu-memory"]
+         return None
+
+     @gpu_memory.setter
+     def gpu_memory(self, value: str):
+         """
+         Args:
+             value: GPU memory in MiB (e.g., "4096", "8192", "16384")
+         """
+         if "annotations" not in self.pod_template:
+             self.pod_template["annotations"] = {}
+         self.pod_template["annotations"]["gpu-memory"] = value
+
+     @property
+     def volumes(self):
+         if not self._volumes:
+             volumes = []
+             if "volumes" in self.pod_template:
+                 for volume in self.pod_template["volumes"]:
+                     # Skip the shared memory volume
+                     if volume["name"] == "dshm":
+                         continue
+                     # Skip secret volumes
+                     if "secret" in volume:
+                         continue
+                     # Only include PVC volumes
+                     if "persistentVolumeClaim" in volume:
+                         volumes.append(volume["name"])
+             self._volumes = volumes
+         return self._volumes
+
+     @property
+     def shared_memory_limit(self):
+         if "volumes" not in self.pod_template:
+             return None
+
+         for volume in self.pod_template["volumes"]:
+             if volume.get("name") == "dshm" and "emptyDir" in volume:
+                 empty_dir = volume["emptyDir"]
+                 return empty_dir.get("sizeLimit")
+
+         return None
+
+     @shared_memory_limit.setter
+     def shared_memory_limit(self, value: str):
+         """
+         Args:
+             value: Size limit (e.g., "512Mi", "1Gi", "2G")
+         """
+         if "volumes" not in self.pod_template:
+             self.pod_template["volumes"] = []
+
+         # Find existing dshm volume and update it
+         for volume in self.pod_template["volumes"]:
+             if volume.get("name") == "dshm" and "emptyDir" in volume:
+                 volume["emptyDir"]["sizeLimit"] = value
+                 return
+
+         # Add new dshm volume if not found
+         self.pod_template["volumes"].append(
+             {"name": "dshm", "emptyDir": {"medium": "Memory", "sizeLimit": value}}
+         )
+
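For reference, after `compute.shared_memory_limit = "1Gi"` the pod template carries a memory-backed emptyDir for /dev/shm shaped like this (restating the append above):

    dshm_volume = {"name": "dshm", "emptyDir": {"medium": "Memory", "sizeLimit": "1Gi"}}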
+     # Alias for backward compatibility (deprecated)
+     @property
+     def shm_size_limit(self):
+         """Deprecated: Use shared_memory_limit instead."""
+         return self.shared_memory_limit
+
+     @shm_size_limit.setter
+     def shm_size_limit(self, value: str):
+         """Deprecated: Use shared_memory_limit instead."""
+         self.shared_memory_limit = value
+
+     @property
+     def node_selector(self):
+         return self.pod_template.get("nodeSelector")
+
+     @node_selector.setter
+     def node_selector(self, value: dict):
+         """
+         Args:
+             value: Label key-value pairs (e.g., {"node-type": "gpu"})
+         """
+         self.pod_template["nodeSelector"] = value
+
+     @property
+     def secret_env_vars(self):
+         secret_env_vars = []
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if "valueFrom" in env_var and "secretKeyRef" in env_var["valueFrom"]:
+                     secret_ref = env_var["valueFrom"]["secretKeyRef"]
+                     # Find existing secret or create new entry
+                     secret_name = secret_ref["name"]
+
+                     # Check if we already have this secret
+                     existing_secret = None
+                     for secret in secret_env_vars:
+                         if secret.get("secret_name") == secret_name:
+                             existing_secret = secret
+                             break
+
+                     if existing_secret:
+                         if "env_vars" not in existing_secret:
+                             existing_secret["env_vars"] = []
+                         if env_var["name"] not in existing_secret["env_vars"]:
+                             existing_secret["env_vars"].append(env_var["name"])
+                     else:
+                         secret_env_vars.append(
+                             {"secret_name": secret_name, "env_vars": [env_var["name"]]}
+                         )
+         return secret_env_vars
+
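Grouping sketch for the property above: two env vars drawn from the same Kubernetes Secret collapse into a single entry (names are hypothetical):

    env = [
        {"name": "AWS_ACCESS_KEY_ID",
         "valueFrom": {"secretKeyRef": {"name": "aws", "key": "access_key"}}},
        {"name": "AWS_SECRET_ACCESS_KEY",
         "valueFrom": {"secretKeyRef": {"name": "aws", "key": "secret_key"}}},
    ]
    # secret_env_vars would yield:
    # [{"secret_name": "aws", "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]}]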
+     @property
+     def secret_volumes(self):
+         secret_volumes = []
+         if "volumes" in self.pod_template:
+             for volume in self.pod_template["volumes"]:
+                 if "secret" in volume:
+                     secret_name = volume["secret"]["secretName"]
+                     # Find corresponding volume mount
+                     mount_path = None
+                     container = self._container()
+                     if "volumeMounts" in container:
+                         for mount in container["volumeMounts"]:
+                             if mount["name"] == volume["name"]:
+                                 mount_path = mount["mountPath"]
+                                 break
+
+                     secret_volumes.append(
+                         {
+                             "name": volume["name"],
+                             "secret_name": secret_name,
+                             "path": mount_path or f"/secrets/{volume['name']}",
+                         }
+                     )
+         return secret_volumes
+
+     @property
+     def volume_mounts(self):
+         volume_mounts = []
+         container = self._container()
+         if "volumeMounts" in container:
+             for mount in container["volumeMounts"]:
+                 # Skip the default dshm mount
+                 if mount["name"] != "dshm":
+                     volume_mounts.append(
+                         {"name": mount["name"], "mountPath": mount["mountPath"]}
+                     )
+         return volume_mounts
+
+     @property
+     def service_account_name(self):
+         return self.pod_template.get("serviceAccountName")
+
+     @service_account_name.setter
+     def service_account_name(self, value: str):
+         """Set service account name in the pod template."""
+         self.pod_template["serviceAccountName"] = value
+
+     @property
+     def config_env_vars(self):
+         from kubetorch.config import ENV_MAPPINGS
+
+         config_env_vars = {}
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 # Filter for config-related env vars (those that start with KT_ or are known config vars)
+                 if env_var["name"] in ENV_MAPPINGS.keys():
+                     if "value" in env_var and env_var["value"]:
+                         config_env_vars[env_var["name"]] = env_var["value"]
+         return config_env_vars
+
+     @property
+     def image_pull_policy(self):
+         return self._container().get("imagePullPolicy")
+
+     @image_pull_policy.setter
+     def image_pull_policy(self, value: str):
+         """Set image pull policy in the pod template."""
+         self._container()["imagePullPolicy"] = value
+
+     @property
+     def namespace(self):
+         return self._namespace
+
+     @namespace.setter
+     def namespace(self, value: str):
+         self._namespace = value
+
+     @property
+     def python_path(self):
+         if self.image and self.image.python_path:
+             return self.image.python_path
+
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if env_var["name"] == "KT_PYTHON_PATH" and "value" in env_var:
+                     return env_var["value"]
+         return None
+
+     @property
+     def freeze(self):
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if env_var["name"] == "KT_FREEZE" and "value" in env_var:
+                     return env_var["value"].lower() == "true"
+         return False
+
+     @property
+     def secrets(self):
+         if not self._secrets:
+             secrets = []
+
+             # Extract secrets from environment variables
+             container = self._container()
+             if "env" in container:
+                 for env_var in container["env"]:
+                     if (
+                         "valueFrom" in env_var
+                         and "secretKeyRef" in env_var["valueFrom"]
+                     ):
+                         secret_ref = env_var["valueFrom"]["secretKeyRef"]
+                         if secret_ref["name"] not in secrets:
+                             secrets.append(secret_ref["name"])
+
+             # Extract secrets from volumes
+             if "volumes" in self.pod_template:
+                 for volume in self.pod_template["volumes"]:
+                     if "secret" in volume:
+                         secret_name = volume["secret"]["secretName"]
+                         if secret_name not in secrets:
+                             secrets.append(secret_name)
+
+             self._secrets = secrets
+
+         return self._secrets
+
+     @property
+     def gpu_anti_affinity(self):
+         if (
+             "affinity" in self.pod_template
+             and "nodeAffinity" in self.pod_template["affinity"]
+         ):
+             node_affinity = self.pod_template["affinity"]["nodeAffinity"]
+             if "requiredDuringSchedulingIgnoredDuringExecution" in node_affinity:
+                 required = node_affinity[
+                     "requiredDuringSchedulingIgnoredDuringExecution"
+                 ]
+                 if "nodeSelectorTerms" in required:
+                     for term in required["nodeSelectorTerms"]:
+                         if "matchExpressions" in term:
+                             for expr in term["matchExpressions"]:
+                                 if (
+                                     expr.get("key") == "nvidia.com/gpu"
+                                     and expr.get("operator") == "DoesNotExist"
+                                 ):
+                                     return True
+         return False
+
+     @property
+     def concurrency(self):
+         return self.pod_template.get("containerConcurrency")
+
+     @concurrency.setter
+     def concurrency(self, value: int):
+         self.pod_template["containerConcurrency"] = value
+
+     @property
+     def working_dir(self):
+         return self._container().get("workingDir")
+
+     @working_dir.setter
+     def working_dir(self, value: str):
+         """Set working directory in the pod template."""
+         self._container()["workingDir"] = value
+
+     @property
+     def priority_class_name(self):
+         return self.pod_template.get("priorityClassName")
+
+     @priority_class_name.setter
+     def priority_class_name(self, value: str):
+         """Set priority class name in the pod template."""
+         self.pod_template["priorityClassName"] = value
+
+     @property
+     def tracing_enabled(self):
+         """Whether tracing is enabled based on global config."""
+         return globals.config.tracing_enabled
+
+     @property
+     def otel_enabled(self):
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if env_var["name"] == "KT_OTEL_ENABLED" and "value" in env_var:
+                     return env_var["value"].lower() == "true"
+         return False
+
+     @property
+     def launch_timeout(self):
+         container = self._container()
+         if "startupProbe" in container:
+             startup_probe = container["startupProbe"]
+             if "failureThreshold" in startup_probe:
+                 # Convert back from failure threshold (launch_timeout // 5)
+                 return startup_probe["failureThreshold"] * 5
+         return None
+
+     @launch_timeout.setter
+     def launch_timeout(self, value: int):
+         container = self._container()
+         if "startupProbe" not in container:
+             container["startupProbe"] = {}
+         # Convert timeout to failure threshold (launch_timeout // 5)
+         container["startupProbe"]["failureThreshold"] = value // 5
+
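Round-trip sketch for the launch_timeout pair above: the setter stores value // 5 as the startup probe's failureThreshold (the // 5 implies a 5-second probe period, presumably set elsewhere in the pod template), so values round down to a multiple of 5:

    compute.launch_timeout = 303
    assert compute._container()["startupProbe"]["failureThreshold"] == 60
    assert compute.launch_timeout == 300   # 303 rounds down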
+     def queue_name(self):
+         if self.queue is not None:
+             return self.queue
+
+         default_queue = globals.config.queue
+         if default_queue:
+             return default_queue
+
+     @property
+     def queue(self):
+         return self._queue
+
+     @queue.setter
+     def queue(self, value: str):
+         self._queue = value
+
+     @property
+     def scheduler_name(self):
+         return self._get_scheduler_name(self.queue_name())
+
+     @property
+     def inactivity_ttl(self):
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if env_var["name"] == "KT_INACTIVITY_TTL" and "value" in env_var:
+                     return env_var["value"] if not env_var["value"] == "None" else None
+         return None
+
+     @inactivity_ttl.setter
+     def inactivity_ttl(self, value: str):
+         if value and (
+             not isinstance(value, str) or not re.match(r"^\d+[smhd]$", value)
+         ):
+             raise ValueError("Inactivity TTL must be a string, e.g. '5m', '1h', '1d'")
+         if value and not self.otel_enabled:
+             logger.warning(
+                 "Inactivity TTL is only supported when OTEL is enabled, please update your Kubetorch Helm chart and restart the nginx proxy"
+             )
+
+         container = self._container()
+         if "env" not in container:
+             container["env"] = []
+
+         # Find existing KT_INACTIVITY_TTL env var and update it
+         for env_var in container["env"]:
+             if env_var["name"] == "KT_INACTIVITY_TTL":
+                 env_var["value"] = value if value is not None else "None"
+                 return
+
+         # Add new env var if not found
+         container["env"].append(
+             {
+                 "name": "KT_INACTIVITY_TTL",
+                 "value": value if value is not None else "None",
+             }
+         )
+
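The TTL format accepted by the setter's regex is an integer count followed by one unit character (s, m, h, d); fractional and compound values are rejected:

    import re
    pattern = r"^\d+[smhd]$"
    assert re.match(pattern, "90m") and re.match(pattern, "1d")
    assert not re.match(pattern, "1.5h") and not re.match(pattern, "1h30m")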
+     @property
+     def name(self):
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if env_var["name"] == "KT_SERVICE_NAME" and "value" in env_var:
+                     return env_var["value"] if not env_var["value"] == "None" else None
+         return None
+
+     @property
+     def raycluster(self):
+         container = self._container()
+         if "ports" in container:
+             for port in container["ports"]:
+                 if port.get("name") == "ray-gcs":
+                     return True
+         return False
+
+     @property
+     def autoscaling_config(self):
+         return self._autoscaling_config
+
+     @property
+     def distributed_config(self):
+         # First try to get from pod template
+         template_config = None
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if (
+                     env_var["name"] == "KT_DISTRIBUTED_CONFIG"
+                     and "value" in env_var
+                     and env_var["value"]
+                 ):
+                     import json
+
+                     try:
+                         template_config = json.loads(env_var["value"])
+                     except (json.JSONDecodeError, TypeError):
+                         template_config = env_var["value"]
+                     break
+
+         # Return template config if available, otherwise return stored config
+         return template_config
+
+     @distributed_config.setter
+     def distributed_config(self, config: dict):
+         # Update pod template with distributed config
+         container = self._container()
+         if "env" not in container:
+             container["env"] = []
+
+         # Update or add KT_SERVICE_DNS, KT_DISTRIBUTED_CONFIG env vars
+         import json
+
+         service_dns = None
+         if config and config.get("distribution_type") == "ray":
+             service_dns = "ray-head-svc"
+         elif config and config.get("distribution_type") == "pytorch":
+             service_dns = "rank0"
+
+         # Serialize the config to JSON, ensuring it's always a string
+         # Check for non-serializable values and raise an error with details
+         non_serializable_keys = []
+         for key, value in config.items():
+             try:
+                 json.dumps(value)
+             except (TypeError, ValueError) as e:
+                 non_serializable_keys.append(
+                     f"'{key}': {type(value).__name__} - {str(e)}"
+                 )
+
+         if non_serializable_keys:
+             raise ValueError(
+                 f"Distributed config contains non-serializable values: {', '.join(non_serializable_keys)}. "
+                 f"All values must be JSON serializable (strings, numbers, booleans, lists, dicts)."
+             )
+
+         service_dns_found, distributed_config_found = False, False
+         for env_var in self._container_env():
+             if env_var["name"] == "KT_SERVICE_DNS" and service_dns:
+                 env_var["value"] = service_dns
+                 service_dns_found = True
+             elif env_var["name"] == "KT_DISTRIBUTED_CONFIG":
+                 env_var["value"] = json.dumps(config)
+                 distributed_config_found = True
+
+         # Add any missing env vars
+         if service_dns and not service_dns_found:
+             container["env"].append({"name": "KT_SERVICE_DNS", "value": service_dns})
+         if not distributed_config_found:
+             container["env"].append(
+                 {"name": "KT_DISTRIBUTED_CONFIG", "value": json.dumps(config)}
+             )
+
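Effect sketch for the setter above with a hypothetical PyTorch config: the dict is JSON-serialized into KT_DISTRIBUTED_CONFIG, and KT_SERVICE_DNS is pointed at the rank-0 alias:

    compute.distributed_config = {"distribution_type": "pytorch", "workers": 4}
    env = {e["name"]: e["value"] for e in compute._container_env() if "value" in e}
    assert env["KT_SERVICE_DNS"] == "rank0"
    assert '"workers": 4' in env["KT_DISTRIBUTED_CONFIG"]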
+     @property
+     def deployment_mode(self):
+         # Determine deployment mode based on distributed config and autoscaling.
+         # For distributed workloads, always use the appropriate deployment mode
+         if self.distributed_config:
+             distribution_type = self.distributed_config.get("distribution_type")
+             if distribution_type == "pytorch":
+                 return "deployment"
+             elif distribution_type == "ray":
+                 return "raycluster"
+
+         # Use Knative for autoscaling services
+         if self.autoscaling_config:
+             return "knative"
+
+         # Default to deployment mode for simple workloads
+         return "deployment"
+
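The selection logic above reduces to a small mapping, e.g.:

    compute.distributed_config = {"distribution_type": "ray"}
    assert compute.deployment_mode == "raycluster"   # pytorch -> "deployment",
                                                     # autoscaling only -> "knative"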
+     # ----------------- Service Level Properties ----------------- #
+
+     @property
+     def service_name(self):
+         # Get service name from pod template if available, otherwise return stored service name
+         if not self._service_name:
+             for env_var in self._container_env():
+                 if env_var["name"] == "KT_SERVICE_NAME" and "value" in env_var:
+                     self._service_name = (
+                         env_var["value"] if not env_var["value"] == "None" else None
+                     )
+                     break
+         return self._service_name
+
+     @service_name.setter
+     def service_name(self, value: str):
+         """Set the service name."""
+         if self._service_name and not self._service_name == value:
+             raise ValueError("Service name cannot be changed after it has been set")
+         self._service_name = value
+
+     # ----------------- GPU Properties ----------------- #
+
+     @property
+     def tolerations(self):
+         # First try to get tolerations from pod template
+         template_tolerations = self.pod_template.get("tolerations", [])
+
+         # Note: if nodes are not tainted for GPUs, these tolerations are just ignored
+         required_gpu_tolerations = [
+             {
+                 "key": "nvidia.com/gpu",
+                 "operator": "Exists",
+                 "effect": "NoSchedule",
+             },
+             {
+                 "key": "dedicated",
+                 "operator": "Equal",
+                 "value": "gpu",
+                 "effect": "NoSchedule",
+             },
+         ]
+
+         # Use template tolerations if available, otherwise use stored tolerations
+         user_tolerations = template_tolerations if template_tolerations else []
+
+         # Only add required GPU tolerations if this is a GPU workload
+         if (
+             not self._load_gpu_config(self.gpus, self.gpu_memory, self.gpu_type)
+             and not self.gpu_type
+         ):
+             return user_tolerations if user_tolerations else None
+
+         all_tolerations = user_tolerations.copy()
+         for req_tol in required_gpu_tolerations:
+             # Use .get() for "value" since the nvidia.com/gpu toleration has no value key
+             if not any(
+                 t["key"] == req_tol["key"]
+                 and t.get("value") == req_tol.get("value")
+                 and t.get("effect") == req_tol["effect"]
+                 for t in all_tolerations
+             ):
+                 all_tolerations.append(req_tol)
+
+         return all_tolerations
+
+     @property
+     def gpu_annotations(self):
+         # GPU annotations for KAI scheduler
+         return self._gpu_annotations
+
+     # ----------------- Init Template Setup Helpers ----------------- #
+     def _get_server_image(self, image, otel_enabled, inactivity_ttl):
+         """Return the base server image."""
+         image = self.image.image_id if self.image and self.image.image_id else None
+
+         if not image or image == serving_constants.KUBETORCH_IMAGE_TRAPDOOR:
+             # No custom image or Trapdoor → pick OTEL or default
+             if self._server_should_enable_otel(otel_enabled, inactivity_ttl):
+                 return serving_constants.SERVER_IMAGE_WITH_OTEL
+             return serving_constants.SERVER_IMAGE_MINIMAL
+
+         return image
+
+     def _get_requested_resources(self, cpus, memory, disk_size, gpu_config):
+         """Return requested resources."""
+         requests = {}
+         limits = {}
+
+         # Add CPU if specified
+         if cpus:
+             requests["cpu"] = RequestedPodResources.cpu_for_resource_request(cpus)
+             limits["cpu"] = requests["cpu"]
+
+         # Add Memory if specified
+         if memory:
+             requests["memory"] = RequestedPodResources.memory_for_resource_request(
+                 memory
+             )
+             limits["memory"] = requests["memory"]
+
+         # Add Storage if specified
+         if disk_size:
+             requests["ephemeral-storage"] = disk_size
+             limits["ephemeral-storage"] = disk_size
+
+         # Add GPU if specified
+         gpu_config: dict = gpu_config
+         gpu_count = gpu_config.get("count", 1)
+         if gpu_config:
+             if gpu_config.get("sharing_type") == "memory":
+                 # TODO: not currently supported
+                 # For memory-sharing GPUs, we don't need to request any additional resources - the KAI scheduler
+                 # will handle it thru annotations
+                 return V1ResourceRequirements()
+             elif gpu_config.get("sharing_type") == "fraction":
+                 # For fractional GPUs, we still need to request the base GPU resource
+                 requests["nvidia.com/gpu"] = "1"
+                 limits["nvidia.com/gpu"] = "1"
+             elif not gpu_config.get("sharing_type"):
+                 # Whole GPUs
+                 requests["nvidia.com/gpu"] = str(gpu_count)
+                 limits["nvidia.com/gpu"] = str(gpu_count)
+
+         # Only include non-empty dicts
+         resources = {}
+         if requests:
+             resources["requests"] = requests
+         if limits:
+             resources["limits"] = limits
+
+         return V1ResourceRequirements(**resources).to_dict()
+
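Shape sketch for the return value with cpus="1", memory="2Gi", and two whole GPUs (assuming the RequestedPodResources helpers pass these strings through unchanged, and modulo any extra None-valued keys V1ResourceRequirements.to_dict() may include; requests always mirror limits):

    {
        "requests": {"cpu": "1", "memory": "2Gi", "nvidia.com/gpu": "2"},
        "limits": {"cpu": "1", "memory": "2Gi", "nvidia.com/gpu": "2"},
    }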
+     def _get_launch_timeout(self, launch_timeout):
+         if launch_timeout:
+             return int(launch_timeout)
+         default_launch_timeout = (
+             self.default_config["launch_timeout"]
+             if "launch_timeout" in self.default_config
+             else serving_constants.KT_LAUNCH_TIMEOUT
+         )
+         return int(os.getenv("KT_LAUNCH_TIMEOUT", default_launch_timeout))
+
+     def _get_scheduler_name(self, queue_name):
+         return serving_constants.KAI_SCHEDULER_NAME if queue_name else None
+
+     def _get_config_env_vars(self, allowed_serialization):
+         config_env_vars = globals.config._get_config_env_vars()
+         if allowed_serialization:
+             config_env_vars["KT_ALLOWED_SERIALIZATION"] = ",".join(
+                 allowed_serialization
+             )
+
+         return config_env_vars
+
+     def _server_should_enable_otel(self, otel_enabled, inactivity_ttl):
+         return self.tracing_enabled or (otel_enabled and inactivity_ttl)
+
+     def _should_install_otel_dependencies(
+         self, server_image, otel_enabled, inactivity_ttl
+     ):
+         return (
+             self._server_should_enable_otel(otel_enabled, inactivity_ttl)
+             and server_image != serving_constants.SERVER_IMAGE_WITH_OTEL
+         )
+
+     @property
+     def image_install_cmd(self):
+         return self.image.install_cmd if self.image and self.image.install_cmd else None
+
+     def client_port(self) -> int:
+         base_url = globals.service_url()
+         _, port = extract_host_port(base_url)
+         return port
+
+     # ----------------- GPU Init Template Setup Helpers ----------------- #
+
+     def _get_gpu_annotations(self, gpu_config):
+         # https://blog.devops.dev/struggling-with-gpu-waste-on-kubernetes-how-kai-schedulers-sharing-unlocks-efficiency-1029e9bd334b
+         if gpu_config is None:
+             return {}
+
+         if gpu_config.get("sharing_type") == "memory":
+             return {
+                 "gpu-memory": str(gpu_config["gpu_memory"]),
+             }
+         elif gpu_config.get("sharing_type") == "fraction":
+             return {
+                 "gpu-fraction": str(gpu_config["gpu_fraction"]),
+             }
+         else:
+             return {}
+
+     def _load_gpu_config(self, gpus, gpu_memory, gpu_type) -> dict:
+         if all(x is None for x in [gpus, gpu_memory, gpu_type]):
+             return {}
+
+         if gpus is not None:
+             if isinstance(gpus, (int, float)):
+                 if gpus <= 0:
+                     raise ValueError("GPU count must be greater than 0")
+                 if gpus < 1:
+                     raise ValueError(
+                         "Fractional GPUs are not currently supported. Please use whole GPUs."
+                     )
+             if not str(gpus).isdigit():
+                 raise ValueError(
+                     "Unexpected format for GPUs, expecting a numeric count"
+                 )
+
+         gpu_config = {
+             "count": int(gpus) if gpus else 1,
+             "sharing_type": None,
+             "gpu_memory": None,
+             "gpu_type": None,
+         }
+
+         # Handle memory specification
+         if gpu_memory is not None:
+             if not isinstance(gpu_memory, str):
+                 raise ValueError(
+                     "GPU memory must be a string with suffix Mi, Gi, or Ti"
+                 )
+
+             units = {"mi": 1, "gi": 1024, "ti": 1024 * 1024}
+             val = gpu_memory.lower()
+
+             for suffix, factor in units.items():
+                 if val.endswith(suffix):
+                     try:
+                         num = float(val[: -len(suffix)])
+                         mi_value = int(num * factor)
+                         gpu_config["sharing_type"] = "memory"
+                         gpu_config["gpu_memory"] = str(mi_value)
+                         break  # Successfully parsed, exit the loop
+                     except ValueError:
+                         raise ValueError("Invalid numeric value in GPU memory spec")
+             else:
+                 # Only raise error if no suffix matched
+                 raise ValueError("GPU memory must end with Mi, Gi, or Ti")
+
+         if gpu_type is not None:
+             gpu_config["gpu_type"] = gpu_type
+
+         return GPUConfig(**gpu_config).to_dict()
+
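Unit-conversion sketch mirroring the loop above: gpu_memory strings are normalized to MiB before being stored (as a string) in the config:

    units = {"mi": 1, "gi": 1024, "ti": 1024 * 1024}
    val = "4Gi".lower()                  # "4gi"
    num = float(val[: -len("gi")])       # 4.0
    assert int(num * units["gi"]) == 4096   # -> {"sharing_type": "memory", "gpu_memory": "4096"}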
1321
+ # ----------------- Generic Helpers ----------------- #
1322
+ def _load_kube_config(self):
1323
+ try:
1324
+ config.load_incluster_config()
1325
+ except config.config_exception.ConfigException:
1326
+ # Fall back to a local kubeconfig file
1327
+ if not Path(self.kubeconfig_path).exists():
1328
+ raise FileNotFoundError(
1329
+ f"Kubeconfig file not found: {self.kubeconfig_path}"
1330
+ )
1331
+ config.load_kube_config(config_file=self.kubeconfig_path)
1332
+
1333
+ # Reset the cached API clients so they'll be reinitialized with the loaded config
1334
+ self._objects_api = None
1335
+ self._core_api = None
1336
+ self._apps_v1_api = None
1337
+
1338
+ def _load_kubetorch_global_config(self):
1339
+ global_config = {}
1340
+ kubetorch_config = self.service_manager.fetch_kubetorch_config()
1341
+ if kubetorch_config:
1342
+ defaults_yaml = kubetorch_config.get("COMPUTE_DEFAULTS", "")
1343
+ if defaults_yaml:
1344
+ try:
1345
+ validated_config = {}
1346
+ config_dict = yaml.safe_load(defaults_yaml)
1347
+ for key, value in config_dict.items():
1348
+ # Check for values as dictionaries with keys 'key' and 'value'
1349
+ if (
1350
+ isinstance(value, list)
1351
+ and len(value) > 0
1352
+ and isinstance(value[0], dict)
1353
+ and "key" in value[0]
1354
+ and "value" in value[0]
1355
+ ):
1356
+ validated_config[key] = {
1357
+ item["key"]: item["value"] for item in value
1358
+ }
1359
+ elif value is not None:
1360
+ validated_config[key] = value
1361
+ global_config = validated_config
1362
+ except yaml.YAMLError as e:
1363
+ logger.error(f"Failed to load kubetorch global config: {str(e)}")
1364
+
1365
+ if global_config:
1366
+ for key in ["inactivity_ttl"]:
1367
+ # Set values from global config where the value is not already set
1368
+ if key in global_config and self.__getattribute__(key) is None:
1369
+ self.__setattr__(key, global_config[key])
1370
+ for key in ["labels", "annotations", "env_vars"]:
1371
+ # Merge global config with existing config for dictionary values
1372
+ if key in global_config and isinstance(global_config[key], dict):
1373
+ self.__setattr__(
1374
+ key, {**global_config[key], **self.__getattribute__(key)}
1375
+ )
1376
+ if "image_id" in global_config:
1377
+ if self.image is None:
1378
+ self.image = Image(image_id=global_config["image_id"])
1379
+ elif self.image.image_id is None:
1380
+ self.image.image_id = global_config["image_id"]
1381
+
1382
+ return global_config
1383
+
+ # ----------------- Launching a new service (Knative or StatefulSet) ----------------- #
+ def _launch(
+     self,
+     service_name: str,
+     install_url: str,
+     pointer_env_vars: Dict,
+     metadata_env_vars: Dict,
+     startup_rsync_command: Optional[str],
+     launch_id: Optional[str],
+     dryrun: bool = False,
+ ):
+     """Creates a new Kubernetes service for the provided module. If the service already exists,
+     it is updated with the latest copy of the code."""
+     # Finalize pod template with launch time env vars
+     self._update_launch_env_vars(
+         service_name, pointer_env_vars, metadata_env_vars, launch_id
+     )
+     self._upload_secrets_list()
+
+     setup_script = self._get_setup_script(install_url, startup_rsync_command)
+     self._container()["args"][0] = setup_script
+
+     # Handle service template creation
+     # Use the replicas property for deployment scaling
+     replicas = self.replicas
+
+     # Prepare annotations for service creation, including kubeconfig path if provided
+     if self._kubeconfig_path is not None:
+         self.annotations[
+             serving_constants.KUBECONFIG_PATH_ANNOTATION
+         ] = self._kubeconfig_path
+
+     # Create service using the appropriate service manager
+     # KnativeServiceManager will handle autoscaling config, inactivity_ttl, etc.
+     # ServiceManager will handle replicas for deployments and rayclusters
+     created_service = self.service_manager.create_or_update_service(
+         service_name=service_name,
+         module_name=pointer_env_vars["KT_MODULE_NAME"],
+         pod_template=self.pod_template,
+         replicas=replicas,
+         autoscaling_config=self.autoscaling_config,
+         gpu_annotations=self.gpu_annotations,
+         inactivity_ttl=self.inactivity_ttl,
+         custom_labels=self.labels,
+         custom_annotations=self.annotations,
+         custom_template=self.service_template,
+         deployment_mode=self.deployment_mode,
+         dryrun=dryrun,
+         scheduler_name=self.scheduler_name,
+         queue_name=self.queue_name(),
+     )
+
+     # Handle service creation result based on resource type
+     if isinstance(created_service, dict):
+         # For custom resources (RayCluster, Knative), created_service is a dictionary
+         service_name = created_service.get("metadata", {}).get("name")
+         kind = created_service.get("kind", "")
+
+         if kind == "RayCluster":
+             # RayCluster has headGroupSpec instead of template
+             service_template = {
+                 "metadata": {
+                     "name": service_name,
+                     "namespace": created_service.get("metadata", {}).get(
+                         "namespace"
+                     ),
+                 },
+                 "spec": {
+                     "template": created_service["spec"]["headGroupSpec"]["template"]
+                 },
+             }
+         else:
+             # For Knative services and other dict-based resources
+             service_template = created_service["spec"]["template"]
+     else:
+         # For Deployments, created_service is a V1Deployment object
+         service_name = created_service.metadata.name
+         # Return dict format for compatibility with tests and reload logic
+         service_template = {
+             "metadata": {
+                 "name": created_service.metadata.name,
+                 "namespace": created_service.metadata.namespace,
+             },
+             "spec": {"template": created_service.spec.template},
+         }
+
+     logger.debug(
+         f"Successfully deployed {self.deployment_mode} service {service_name}"
+     )
+
+     return service_template
+
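+ # For reference, _launch normalizes its return value to a plain dict of this
+ # shape for every resource type (Knative Service, RayCluster, or Deployment),
+ # which the reload logic and tests rely on:
+ #
+ #   {
+ #       "metadata": {"name": "<service>", "namespace": "<namespace>"},
+ #       "spec": {"template": <pod template>},
+ #   }
+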
+ async def _launch_async(
+     self,
+     service_name: str,
+     install_url: str,
+     pointer_env_vars: Dict,
+     metadata_env_vars: Dict,
+     startup_rsync_command: Optional[str],
+     launch_id: Optional[str],
+     dryrun: bool = False,
+ ):
+     """Async version of _launch. Creates a new Kubernetes service for the provided module.
+     If the service already exists, it is updated with the latest copy of the code."""
+
+     import asyncio
+
+     # get_running_loop() is preferred over get_event_loop() inside coroutines
+     loop = asyncio.get_running_loop()
+
+     service_template = await loop.run_in_executor(
+         None,
+         self._launch,
+         service_name,
+         install_url,
+         pointer_env_vars,
+         metadata_env_vars,
+         startup_rsync_command,
+         launch_id,
+         dryrun,
+     )
+
+     return service_template
+
+ def _update_launch_env_vars(
+     self, service_name, pointer_env_vars, metadata_env_vars, launch_id
+ ):
+     kt_env_vars = {
+         **pointer_env_vars,
+         **metadata_env_vars,
+         "KT_LAUNCH_ID": launch_id,
+         "KT_SERVICE_NAME": service_name,
+         "KT_SERVICE_DNS": (
+             f"{service_name}-headless.{self.namespace}.svc.cluster.local"
+             if self.distributed_config
+             else f"{service_name}.{self.namespace}.svc.cluster.local"
+         ),
+         "KT_DEPLOYMENT_MODE": self.deployment_mode,
+     }
+     if "OTEL_SERVICE_NAME" not in self.config_env_vars:
+         kt_env_vars["OTEL_SERVICE_NAME"] = service_name
+
+     # Ensure cluster config env vars are set
+     if globals.config.cluster_config:
+         if globals.config.cluster_config.get("otel_enabled"):
+             kt_env_vars["KT_OTEL_ENABLED"] = True
+
+     # Ensure all environment variable values are strings for Kubernetes compatibility
+     kt_env_vars = self._serialize_env_vars(kt_env_vars)
+
+     updated_env_vars = set()
+     for env_var in self._container_env():
+         if env_var["name"] in kt_env_vars:
+             env_var["value"] = kt_env_vars[env_var["name"]]
+             updated_env_vars.add(env_var["name"])
+     for key, val in kt_env_vars.items():
+         if key not in updated_env_vars:
+             self._container_env().append({"name": key, "value": val})
+
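+ # For example, for a service named "train" in namespace "default",
+ # KT_SERVICE_DNS is set to:
+ #   "train-headless.default.svc.cluster.local"  # distributed (headless service)
+ #   "train.default.svc.cluster.local"           # otherwise
+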
+ def _serialize_env_vars(self, env_vars: Dict) -> Dict:
+     import json
+
+     serialized_vars = {}
+     for key, value in env_vars.items():
+         if value is None:
+             serialized_vars[key] = "null"
+         elif isinstance(value, (dict, list)):
+             try:
+                 serialized_vars[key] = json.dumps(value)
+             except (TypeError, ValueError):
+                 serialized_vars[key] = str(value)
+         elif isinstance(value, (bool, int, float)):
+             serialized_vars[key] = str(value)
+         else:
+             serialized_vars[key] = value
+     return serialized_vars
+
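+ # A few input/output pairs, derived from the branches above:
+ #   None          -> "null"
+ #   {"a": 1}      -> '{"a": 1}'      (json.dumps)
+ #   [1, 2]        -> "[1, 2]"
+ #   True          -> "True"          (str)
+ #   "already-str" -> "already-str"   (unchanged)
+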
+ def _extract_secrets(self, secrets):
+     if is_running_in_kubernetes():
+         return [], []
+
+     secret_env_vars = []
+     secret_volumes = []
+     if secrets:
+         secrets_client = KubernetesSecretsClient(
+             namespace=self.namespace, kubeconfig_path=self.kubeconfig_path
+         )
+         secret_objects = secrets_client.convert_to_secret_objects(secrets=secrets)
+         (
+             secret_env_vars,
+             secret_volumes,
+         ) = secrets_client.extract_envs_and_volumes_from_secrets(secret_objects)
+
+     return secret_env_vars, secret_volumes
+
+ def _get_setup_script(self, install_url, startup_rsync_command):
+     # Load the setup script template
+     from kubetorch.servers.http.utils import _get_rendered_template
+
+     setup_script = _get_rendered_template(
+         serving_constants.KT_SETUP_TEMPLATE_FILE,
+         template_dir=os.path.join(
+             os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
+             "serving",
+             "templates",
+         ),
+         python_path=self.python_path,
+         freeze=self.freeze,
+         install_url=install_url or globals.config.install_url,
+         install_cmd=self.image_install_cmd,
+         install_otel=self._should_install_otel_dependencies(
+             self.server_image, self.otel_enabled, self.inactivity_ttl
+         ),
+         tracing_enabled=self.tracing_enabled,
+         server_image=self.server_image,
+         rsync_kt_editable_cmd=startup_rsync_command,
+         server_port=self.server_port,
+     )
+     return setup_script
+
+ def _upload_secrets_list(self):
+     """Upload secrets to Kubernetes. Called during launch time, not during init."""
+     if is_running_in_kubernetes():
+         return
+
+     if self.secrets:
+         logger.debug("Uploading secrets to Kubernetes")
+         self.secrets_client.upload_secrets_list(secrets=self.secrets)
+
+ def _get_node_selector(self, node_selector, gpu_type):
+     if gpu_type:
+         if ":" in gpu_type:
+             # Parse "key: value" format
+             key, value = gpu_type.split(":", 1)
+             node_selector[key.strip()] = value.strip()
+         else:
+             # Default to nvidia.com/gpu.product
+             node_selector["nvidia.com/gpu.product"] = gpu_type
+     return node_selector
+
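+ # Both accepted gpu_type forms, per the parsing above (values illustrative):
+ #   gpu_type="NVIDIA-A100-SXM4-80GB"
+ #       -> {"nvidia.com/gpu.product": "NVIDIA-A100-SXM4-80GB"}
+ #   gpu_type="cloud.google.com/gke-accelerator: nvidia-tesla-t4"
+ #       -> {"cloud.google.com/gke-accelerator": "nvidia-tesla-t4"}
+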
+ def pod_names(self):
+     """Returns the names of the service's running pods."""
+     pods = self.pods()
+     return [pod.metadata.name for pod in pods if pod_is_running(pod)]
+
+ def pods(self):
+     """Returns the pod objects backing this service."""
+     return self.service_manager.get_pods_for_service(self.service_name)
+
+ # ------------------------------- Volumes ------------------------------ #
+ def _process_volumes(self, volumes) -> Optional[List[Volume]]:
+     """Process volumes input into standardized format"""
+     if volumes is None:
+         volumes = globals.config.volumes
+
+     if volumes is None:
+         return None
+
+     if isinstance(volumes, list):
+         processed_volumes = []
+         for vol in volumes:
+             if isinstance(vol, str):
+                 # list of volume names (assume they exist)
+                 volume = Volume.from_name(
+                     vol, create_if_missing=False, core_v1=self.core_api
+                 )
+                 processed_volumes.append(volume)
+
+             elif isinstance(vol, Volume):
+                 # list of Volume objects (create them if they don't already exist)
+                 # Default the volume namespace to compute namespace if not provided
+                 if vol.namespace is None:
+                     vol.namespace = self.namespace
+                 vol.create()
+                 processed_volumes.append(vol)
+
+             else:
+                 raise ValueError(
+                     f"Volume list items must be strings or Volume objects, got {type(vol)}"
+                 )
+
+         return processed_volumes
+
+     else:
+         raise ValueError(f"Volumes must be a list, got {type(volumes)}")
+
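+ # Minimal usage sketch (volume names and Volume arguments are illustrative):
+ #   compute = kt.Compute(
+ #       cpus="1",
+ #       volumes=["existing-pvc", kt.Volume(name="scratch", mount_path="/scratch")],
+ #   )
+ # Strings must name existing volumes (create_if_missing=False); Volume objects
+ # are created if missing and default to the compute's namespace.
+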
+ def _volumes_for_pod_template(self, volumes):
+     """Convert processed volumes to template format"""
+     volume_mounts = []
+     volume_specs = []
+
+     if volumes:
+         for volume in volumes:
+             # Add volume mount
+             volume_mounts.append(
+                 {"name": volume.name, "mountPath": volume.mount_path}
+             )
+
+             # Add volume spec
+             volume_specs.append(volume.pod_template_spec())
+
+     return volume_mounts, volume_specs
+
+ # ----------------- Functions using K8s implementation ----------------- #
+ def _wait_for_endpoint(self):
+     retries = 20
+     for i in range(retries):
+         endpoint = self.endpoint
+         if endpoint:
+             return endpoint
+         else:
+             logger.info(f"Endpoint not available (attempt {i + 1}/{retries})")
+             time.sleep(2)
+
+     logger.error(f"Endpoint not available for {self.service_name}")
+     return None
+
+ def _status_condition_ready(self, status) -> bool:
+     """
+     Checks if the Knative Service status conditions include a 'Ready' condition with status 'True'.
+     This indicates that the service is ready to receive traffic.
+
+     Notes:
+         - This does not check pod status or readiness, only the Knative Service's own readiness condition.
+         - A service can be 'Ready' even if no pods are currently running (e.g., after scaling to zero).
+     """
+     for condition in status.get("conditions", []):
+         if condition.get("type") == "Ready" and condition.get("status") == "True":
+             logger.debug(f"Service {self.service_name} is ready")
+             return True
+     return False
+
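+ # A status payload that satisfies this check:
+ #   status = {"conditions": [{"type": "Ready", "status": "True"}]}
+ #   self._status_condition_ready(status)  # -> True
+ # Anything else (e.g. status "Unknown" during a rollout) returns False.
+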
+ def _check_service_ready(self):
+     """Checks if the service is ready to start serving requests.
+
+     Delegates to the appropriate service manager's check_service_ready method.
+     """
+     return self.service_manager.check_service_ready(
+         service_name=self.service_name,
+         launch_timeout=self.launch_timeout,
+         objects_api=self.objects_api,
+         core_api=self.core_api,
+         queue_name=self.queue_name(),
+         scheduler_name=self.scheduler_name,
+     )
+
+ async def _check_service_ready_async(self):
+     """Async version of _check_service_ready. Checks if the service is ready to start serving requests.
+
+     Delegates to the appropriate service manager's check_service_ready method.
+     """
+     import asyncio
+
+     # get_running_loop() is preferred over get_event_loop() inside coroutines
+     loop = asyncio.get_running_loop()
+
+     return await loop.run_in_executor(
+         None,
+         self._check_service_ready,
+     )
+
+ def is_up(self):
+     """Whether all of the service's pods are running."""
+     try:
+         pods = self.pods()
+         if not pods:
+             return False
+         for pod in pods:
+             if pod.status.phase != "Running":
+                 logger.info(
+                     f"Pod {pod.metadata.name} is not running. Status: {pod.status.phase}"
+                 )
+                 return False
+     except client.exceptions.ApiException:
+         return False
+     return True
+
+ def _base_rsync_url(self, local_port: int):
+     return (
+         f"rsync://localhost:{local_port}/data/{self.namespace}/{self.service_name}"
+     )
+
+ def _rsync_svc_url(self):
+     return f"rsync://kubetorch-rsync.{self.namespace}.svc.cluster.local:{serving_constants.REMOTE_RSYNC_PORT}/data/{self.namespace}/{self.service_name}/"
+
+ def ssh(self, pod_name: str = None):
+     """Open an interactive shell on a pod (the first running pod by default) via kubectl exec."""
+     pod_name = pod_name or self.pod_names()[0]
+     ssh_cmd = f"kubectl exec -it {pod_name} -n {self.namespace} -- /bin/bash"
+     subprocess.run(shlex.split(ssh_cmd), check=True)
+
+ def get_env_vars(self, keys: Union[List[str], str] = None):
+     keys = [keys] if isinstance(keys, str) else keys
+     env_vars = {}
+     for env_var in self._container_env():
+         # Guard on "value" first so entries populated via valueFrom don't raise a KeyError
+         if "value" in env_var and (not keys or env_var["name"] in keys):
+             env_vars[env_var["name"]] = env_var["value"]
+     return env_vars
+
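+ # Usage sketch, assuming a launched `compute` instance:
+ #   compute.get_env_vars()                   # all literal env vars on the container
+ #   compute.get_env_vars("KT_SERVICE_NAME")  # {"KT_SERVICE_NAME": "..."}
+ #   compute.get_env_vars(["KT_LAUNCH_ID", "OTEL_SERVICE_NAME"])
+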
+ # def update_env_vars(self, env_vars: Dict):
+ #     updated_env_vars = set()
+ #     for env_var in self._container_env():
+ #         if env_var["name"] in env_vars and env_var["name"]:
+ #             env_var["value"] = env_vars[env_var["name"]]
+ #             updated_env_vars.add(env_var["name"])
+ #     for key, val in env_vars.items():
+ #         if key not in updated_env_vars:
+ #             self._container_env().append({"name": key, "value": val})
+
+ # ----------------- Image Related Functionality ----------------- #
+
+ def pip_install(
+     self,
+     reqs: Union[List[str], str],
+     node: Optional[str] = None,
+     override_remote_version: bool = False,
+ ):
+     """Pip install reqs onto compute pod(s)."""
+     reqs = [reqs] if isinstance(reqs, str) else reqs
+     python_path = self.image.python_path or "python3"
+     pip_install_cmd = f"{python_path} -m pip install"
+     try:
+         result = self.run_bash(
+             "cat .kt/kt_pip_install_cmd 2>/dev/null || echo ''", node=node
+         )
+         if result and result[0][0] == 0 and result[0][1].strip():
+             pip_install_cmd = result[0][1].strip()
+     except Exception:
+         pass
+
+     for req in reqs:
+         base = self.working_dir or "."
+         remote_editable = (
+             self.run_bash(f"[ -d {base}/{req} ]", node=node)[0][0] == 0
+         )
+         if remote_editable:
+             req = f"{base}/{req}"
+         else:
+             local_version = find_locally_installed_version(req)
+             if local_version is not None:
+                 if not override_remote_version:
+                     installed_remotely = (
+                         self.run_bash(
+                             f"{python_path} -c \"import importlib.util; exit(0) if importlib.util.find_spec('{req}') else exit(1)\"",
+                             node=node,
+                         )[0][0]
+                         == 0
+                     )
+                     if installed_remotely:
+                         logger.info(f"{req} already installed. Skipping.")
+                         # Actually skip, as the log says, instead of re-running pip
+                         continue
+                     else:
+                         req = f"{req}=={local_version}"
+
+         logger.info(f"Pip installing {req} with: {pip_install_cmd} {req}")
+         self.run_bash(f"{pip_install_cmd} {req}", node=node)
+
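+ # Usage sketch, assuming a launched `compute` instance:
+ #   compute.pip_install("numpy")
+ #       # skipped if numpy is already importable on the pod; otherwise
+ #       # installed pinned to the local version, i.e. "numpy==<local>"
+ #   compute.pip_install("my_pkg")  # where <working_dir>/my_pkg exists remotely
+ #       # installed as a path install of that directory
+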
+ def sync_package(
+     self,
+     package: str,
+     node: Optional[str] = None,
+ ):
+     """Sync package (locally installed, or path to package) to compute pod(s)."""
+     full_path, dest_dir = _get_sync_package_paths(package)
+     logger.info(f"Syncing over package at {full_path} to {dest_dir}")
+     self.rsync(source=full_path, dest=dest_dir)
+
+ def run_bash(
+     self,
+     commands,
+     node: Union[str, List[str]] = None,
+     container: Optional[str] = None,
+ ):
+     """Run bash commands on the pod(s)."""
+     self._load_kube_config()
+
+     pod_names = (
+         self.pod_names()
+         if node in ["all", None]
+         else [node]
+         if isinstance(node, str)
+         else node
+     )
+
+     return _run_bash(
+         commands=commands,
+         core_api=self.core_api,
+         pod_names=pod_names,
+         namespace=self.namespace,
+         container=container,
+     )
+
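+ # Usage sketch: `node` may be a pod name, a list of pod names, "all", or None
+ # (both of the latter meaning all running pods). Results come back per pod as
+ # tuples whose first two elements are the exit code and stdout:
+ #   ret = compute.run_bash("python --version")
+ #   exit_code, output = ret[0][0], ret[0][1]
+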
+ def _create_rsync_target_dir(self):
+     """Create the subdirectory for this particular service in the rsync pod."""
+     subdir = f"/data/{self.namespace}/{self.service_name}"
+
+     label_selector = f"app={serving_constants.RSYNC_SERVICE_NAME}"
+     pod_name = (
+         self.core_api.list_namespaced_pod(
+             namespace=self.namespace, label_selector=label_selector
+         )
+         .items[0]
+         .metadata.name
+     )
+     subdir_cmd = f"kubectl exec {pod_name} -n {self.namespace} -- mkdir -p {subdir}"
+     logger.info(f"Creating directory on rsync pod with cmd: {subdir_cmd}")
+     subprocess.run(subdir_cmd, shell=True, check=True)
+
+ def _run_rsync_command(self, rsync_cmd, create_target_dir: bool = True):
+     backup_rsync_cmd = rsync_cmd
+     if "--mkpath" not in rsync_cmd and create_target_dir:
+         # Warning: --mkpath requires rsync 3.2.0+
+         # Note: --mkpath allows the rsync daemon to create all intermediate directories that may not exist
+         # https://download.samba.org/pub/rsync/rsync.1#opt--mkpath
+         rsync_cmd = rsync_cmd.replace("rsync ", "rsync --mkpath ", 1)
+         logger.debug(f"Rsync command: {rsync_cmd}")
+
+         resp = subprocess.run(
+             rsync_cmd,
+             shell=True,
+             capture_output=True,
+             text=True,
+         )
+         if resp.returncode != 0:
+             if (
+                 create_target_dir
+                 and (
+                     "rsync: --mkpath" in resp.stderr
+                     or "rsync: unrecognized option" in resp.stderr
+                 )
+                 and not is_running_in_kubernetes()
+             ):
+                 logger.warning(
+                     "Rsync failed: --mkpath is not supported, falling back to creating target dir. "
+                     "Please upgrade rsync to 3.2.0+ to improve performance."
+                 )
+                 self._create_rsync_target_dir()
+                 return self._run_rsync_command(
+                     backup_rsync_cmd, create_target_dir=False
+                 )
+
+             raise RsyncError(rsync_cmd, resp.returncode, resp.stdout, resp.stderr)
+     else:
+         import fcntl
+         import pty
+         import select
+
+         logger.debug(f"Rsync command: {rsync_cmd}")
+
+         leader, follower = pty.openpty()
+         proc = subprocess.Popen(
+             shlex.split(rsync_cmd),
+             stdout=follower,
+             stderr=follower,
+             text=True,
+             close_fds=True,
+         )
+         os.close(follower)
+
+         # Set to non-blocking mode
+         flags = fcntl.fcntl(leader, fcntl.F_GETFL)
+         fcntl.fcntl(leader, fcntl.F_SETFL, flags | os.O_NONBLOCK)
+         buffer = b""
+         transfer_completed = False
+         # Initialize so the exit-code branch below can't hit an unbound local
+         decoded_line = ""
+         error_patterns = [
+             r"rsync\(\d+\): error:",
+             r"rsync error:",
+             r"@ERROR:",
+         ]
+         error_regexes = [
+             re.compile(pattern, re.IGNORECASE) for pattern in error_patterns
+         ]
+
+         try:
+             with os.fdopen(leader, "rb", buffering=0) as stdout:
+                 while True:
+                     rlist, _, _ = select.select(
+                         [stdout], [], [], 0.1
+                     )  # 0.1 sec timeout for responsiveness
+                     if stdout in rlist:
+                         try:
+                             chunk = os.read(stdout.fileno(), 1024)
+                         except BlockingIOError:
+                             continue  # no data available, try again
+
+                         if not chunk:  # EOF
+                             break
+
+                         buffer += chunk
+                         while b"\n" in buffer:
+                             line, buffer = buffer.split(b"\n", 1)
+                             decoded_line = line.decode(errors="replace").strip()
+                             logger.debug(f"{decoded_line}")
+
+                             for error_regex in error_regexes:
+                                 if error_regex.search(decoded_line):
+                                     raise RsyncError(
+                                         rsync_cmd, 1, decoded_line, decoded_line
+                                     )
+
+                             if (
+                                 "total size is" in decoded_line
+                                 and "speedup is" in decoded_line
+                             ):
+                                 transfer_completed = True
+
+                     if transfer_completed:
+                         break
+
+                     exit_code = proc.poll()
+                     if exit_code is not None:
+                         if exit_code != 0:
+                             raise RsyncError(
+                                 rsync_cmd,
+                                 exit_code,
+                                 output=decoded_line,
+                                 stderr=decoded_line,
+                             )
+                         break
+
+             proc.terminate()
+         except Exception as e:
+             proc.terminate()
+             raise e
+
+         if not transfer_completed:
+             logger.error("Rsync process ended without completion message")
+             proc.terminate()
+             raise subprocess.CalledProcessError(
+                 1,
+                 rsync_cmd,
+                 output="",
+                 stderr="Rsync completed without success indication",
+             )
+
+         logger.info("Rsync operation completed successfully")
+
+ async def _run_rsync_command_async(
+     self, rsync_cmd: str, create_target_dir: bool = True
+ ):
+     """Async version of _run_rsync_command using asyncio.subprocess."""
+     import asyncio
+
+     if "--mkpath" not in rsync_cmd and create_target_dir:
+         # Warning: --mkpath requires rsync 3.2.0+
+         # Note: --mkpath allows the rsync daemon to create all intermediate directories that may not exist
+         # https://download.samba.org/pub/rsync/rsync.1#opt--mkpath
+         rsync_cmd = rsync_cmd.replace("rsync ", "rsync --mkpath ", 1)
+         logger.debug(f"Rsync command: {rsync_cmd}")
+
+     # Use asyncio.create_subprocess_shell for shell commands
+     proc = await asyncio.create_subprocess_shell(
+         rsync_cmd,
+         stdout=asyncio.subprocess.PIPE,
+         stderr=asyncio.subprocess.PIPE,
+     )
+
+     stdout_bytes, stderr_bytes = await proc.communicate()
+     stdout = (
+         stdout_bytes.decode("utf-8", errors="replace") if stdout_bytes else ""
+     )
+     stderr = (
+         stderr_bytes.decode("utf-8", errors="replace") if stderr_bytes else ""
+     )
+
+     # communicate() only returns after the process exits, so returncode is set here
+     if proc.returncode != 0:
+         if (
+             "rsync: --mkpath" in stderr
+             or "rsync: unrecognized option" in stderr
+         ):
+             error_msg = (
+                 "Rsync failed: --mkpath is not supported, please upgrade your rsync version to 3.2.0+ to "
+                 "improve performance (e.g. `brew install rsync`)"
+             )
+             raise RsyncError(rsync_cmd, proc.returncode, stdout, error_msg)
+
+         raise RsyncError(rsync_cmd, proc.returncode, stdout, stderr)
+
+ def _get_rsync_cmd(
+     self,
+     source: Union[str, List[str]],
+     dest: str,
+     rsync_local_port: int,
+     contents: bool = False,
+     filter_options: str = None,
+     force: bool = False,
+ ):
+     if dest:
+         # Handle tilde prefix - treat as relative to home/working directory
+         if dest.startswith("~/"):
+             # Strip ~/ prefix to make it relative
+             dest = dest[2:]
+
+         # Handle absolute vs relative paths
+         if dest.startswith("/"):
+             # For absolute paths, store under special __absolute__ subdirectory in the rsync pod
+             # to preserve the path structure
+             dest_for_rsync = f"__absolute__{dest}"
+         else:
+             # Relative paths work as before
+             dest_for_rsync = dest
+         remote_dest = f"{self._base_rsync_url(rsync_local_port)}/{dest_for_rsync}"
+     else:
+         remote_dest = self._base_rsync_url(rsync_local_port)
+
+     source = [source] if isinstance(source, str) else source
+
+     exclude_options = _get_rsync_exclude_options()
+
+     # Expand and validate each source path in a single pass (this also covers
+     # the existence check, so no separate validation loop is needed)
+     expanded_sources = []
+     for s in source:
+         path = Path(s).expanduser().absolute()
+         if not path.exists():
+             raise ValueError(f"Could not locate path to sync up: {s}")
+
+         path_str = str(path)
+         if contents and path.is_dir() and not str(s).endswith("/"):
+             path_str += "/"
+         expanded_sources.append(path_str)
+
+     source_str = " ".join(expanded_sources)
+
+     rsync_cmd = f"rsync -avL {exclude_options}"
+
+     if filter_options:
+         rsync_cmd += f" {filter_options}"
+
+     if force:
+         rsync_cmd += " --ignore-times"
+
+     rsync_cmd += f" {source_str} {remote_dest}"
+     return rsync_cmd
+
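+ # An example of the assembled command, for namespace "default", service
+ # "train", and a local tunnel on port 8080 (paths illustrative):
+ #   rsync -avL <exclude options> /abs/path/to/src/ \
+ #       rsync://localhost:8080/data/default/train/app
+ # (--mkpath is prepended later by _run_rsync_command when supported.)
+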
+ def _get_rsync_in_cluster_cmd(
+     self,
+     source: Union[str, List[str]],
+     dest: str,
+     contents: bool = False,
+     filter_options: str = None,
+     force: bool = False,
+ ):
+     """Generate the rsync command for in-cluster execution."""
+     # Handle tilde prefix in dest - treat as relative to home/working directory
+     if dest and dest.startswith("~/"):
+         dest = dest[2:]  # Strip ~/ prefix to make it relative
+
+     source = [source] if isinstance(source, str) else source
+     if self.working_dir:
+         source = [src.replace(self.working_dir, "") for src in source]
+
+     if contents:
+         if self.working_dir:
+             source = [
+                 s
+                 if s.endswith("/") or not Path(self.working_dir, s).is_dir()
+                 else s + "/"
+                 for s in source
+             ]
+         else:
+             source = [
+                 s if s.endswith("/") or not Path(s).is_dir() else s + "/"
+                 for s in source
+             ]
+
+     source_str = " ".join(source)
+
+     exclude_options = _get_rsync_exclude_options()
+
+     base_remote = self._rsync_svc_url()
+
+     if dest is None:
+         # no dest specified -> use base
+         remote = base_remote
+     elif dest.startswith("rsync://"):
+         # if full rsync:// URL -> use as-is
+         remote = dest
+     else:
+         # if relative subdir specified -> append to base
+         remote = base_remote + dest.lstrip("/")
+
+     # rsync wants the remote last; ensure it ends with '/' so we copy *into* the dir
+     if not remote.endswith("/"):
+         remote += "/"
+
+     rsync_command = f"rsync -av {exclude_options}"
+     if filter_options:
+         rsync_command += f" {filter_options}"
+     if force:
+         rsync_command += " --ignore-times"
+
+     rsync_command += f" {source_str} {remote}"
+     return rsync_command
+
+ def _rsync(
+     self,
+     source: Union[str, List[str]],
+     dest: str,
+     rsync_local_port: int,
+     contents: bool = False,
+     filter_options: str = None,
+     force: bool = False,
+ ):
+     rsync_cmd = self._get_rsync_cmd(
+         source, dest, rsync_local_port, contents, filter_options, force
+     )
+     self._run_rsync_command(rsync_cmd)
+
+ async def _rsync_async(
+     self,
+     source: Union[str, List[str]],
+     dest: str,
+     rsync_local_port: int,
+     contents: bool = False,
+     filter_options: str = None,
+     force: bool = False,
+ ):
+     """Async version of _rsync."""
+     rsync_cmd = self._get_rsync_cmd(
+         source, dest, rsync_local_port, contents, filter_options, force
+     )
+     await self._run_rsync_command_async(rsync_cmd)
+
+ def _get_websocket_info(self, local_port: int):
+     rsync_local_port = local_port or serving_constants.LOCAL_NGINX_PORT
+     base_url = globals.service_url()
+
+     # api_url = globals.config.api_url
+
+     # # Determine if we need port forwarding to reach nginx proxy
+     # should_port_forward = api_url is None
+
+     # if should_port_forward:
+     #     base_url = globals.service_url()
+     # else:
+     #     # Direct access to nginx proxy via ingress
+     #     base_url = api_url  # e.g. "https://your.ingress.domain"
+
+     ws_url = f"{http_to_ws(base_url)}/rsync/{self.namespace}/"
+     parsed_url = urlparse(base_url)
+
+     # choose a local ephemeral port for the tunnel
+     start_from = (parsed_url.port or rsync_local_port) + 1
+     websocket_port = find_available_port(start_from, max_tries=10)
+     return websocket_port, ws_url
+
+ def rsync(
+     self,
+     source: Union[str, List[str]],
+     dest: str = None,
+     local_port: int = None,
+     contents: bool = False,
+     filter_options: str = None,
+     force: bool = False,
+ ):
+     """Rsync from local to the rsync pod."""
+     # Note: use the nginx port by default since the rsync pod sits behind the nginx proxy
+     websocket_port, ws_url = self._get_websocket_info(local_port)
+
+     logger.debug(f"Opening WebSocket tunnel on port {websocket_port} to {ws_url}")
+     with WebSocketRsyncTunnel(websocket_port, ws_url) as tunnel:
+         self._rsync(
+             source, dest, tunnel.local_port, contents, filter_options, force
+         )
+
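+ # Usage sketch:
+ #   compute.rsync("./src", dest="app", contents=True)
+ # opens a WebSocket tunnel to the in-cluster rsync daemon and copies the
+ # *contents* of ./src into data/<namespace>/<service_name>/app on the rsync pod.
+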
+ async def rsync_async(
+     self,
+     source: Union[str, List[str]],
+     dest: str = None,
+     local_port: int = None,
+     contents: bool = False,
+     filter_options: str = None,
+     force: bool = False,
+ ):
+     """Async version of rsync. Rsync from local to the rsync pod."""
+     websocket_port, ws_url = self._get_websocket_info(local_port)
+
+     logger.debug(f"Opening WebSocket tunnel on port {websocket_port} to {ws_url}")
+     with WebSocketRsyncTunnel(websocket_port, ws_url) as tunnel:
+         await self._rsync_async(
+             source, dest, tunnel.local_port, contents, filter_options, force
+         )
+
+ def rsync_in_cluster(
+     self,
+     source: Union[str, List[str]],
+     dest: str = None,
+     contents: bool = False,
+     filter_options: str = None,
+     force: bool = False,
+ ):
+     """Rsync from inside the cluster to the rsync pod."""
+     rsync_command = self._get_rsync_in_cluster_cmd(
+         source, dest, contents, filter_options, force
+     )
+     self._run_rsync_command(rsync_command)
+
+ async def rsync_in_cluster_async(
+     self,
+     source: Union[str, List[str]],
+     dest: str = None,
+     contents: bool = False,
+     filter_options: str = None,
+     force: bool = False,
+ ):
+     """Async version of rsync_in_cluster. Rsync from inside the cluster to the rsync pod."""
+     rsync_command = self._get_rsync_in_cluster_cmd(
+         source, dest, contents, filter_options, force
+     )
+     await self._run_rsync_command_async(rsync_command)
+
+ def _image_setup_and_instructions(self, rsync: bool = True):
+     """
+     Return image instructions in Dockerfile format, and optionally rsync over content to the rsync pod.
+     """
+     instructions = ""
+
+     if not self.image:
+         return instructions
+
+     logger.debug("Writing out image instructions.")
+
+     if self.image.image_id:
+         # image_id is not written out as a step; it's used directly in server_image
+         instructions += f"FROM {self.server_image}\n"
+     if self.image.python_path:
+         instructions += f"ENV KT_PYTHON_PATH {self.image.python_path}\n"
+
+     for step in self.image.setup_steps:
+         if step.step_type == ImageSetupStepType.CMD_RUN:
+             commands = step.kwargs.get("command")
+             commands = [commands] if isinstance(commands, str) else commands
+             for i in range(len(commands)):
+                 if i != 0:
+                     instructions += "\n"
+                 instructions += f"RUN {commands[i]}"
+         elif step.step_type == ImageSetupStepType.PIP_INSTALL:
+             reqs = step.kwargs.get("reqs")
+             reqs = [reqs] if isinstance(reqs, str) else reqs
+             for i in range(len(reqs)):
+                 if i != 0:
+                     instructions += "\n"
+                 if self.image_install_cmd:
+                     install_cmd = self.image_install_cmd
+                 else:
+                     install_cmd = "$KT_PIP_INSTALL_CMD"
+
+                 # Pass through the requirement string directly without quoting
+                 # This allows users to pass any pip arguments they want
+                 # e.g., "--pre torchmonarch==0.1.0rc7" or "numpy>=1.20"
+                 instructions += f"RUN {install_cmd} {reqs[i]}"
+
+                 if step.kwargs.get("force"):
+                     instructions += " # force"
+         elif step.step_type == ImageSetupStepType.SYNC_PACKAGE:
+             # using package name instead of paths, since the folder path in the rsync pod will just be the package name
+             full_path, dest_dir = _get_sync_package_paths(
+                 step.kwargs.get("package")
+             )
+             if rsync:
+                 self.rsync(full_path, dest=dest_dir)
+             instructions += f"COPY {full_path} {dest_dir}"
+         elif step.step_type == ImageSetupStepType.RSYNC:
+             source_path = step.kwargs.get("source")
+             dest_dir = step.kwargs.get("dest")
+             contents = step.kwargs.get("contents")
+             filter_options = step.kwargs.get("filter_options")
+             force = step.kwargs.get("force")
+
+             if rsync:
+                 if is_running_in_kubernetes():
+                     self.rsync_in_cluster(
+                         source_path,
+                         dest=dest_dir,
+                         contents=contents,
+                         filter_options=filter_options,
+                         force=force,
+                     )
+                 else:
+                     self.rsync(
+                         source_path,
+                         dest=dest_dir,
+                         contents=contents,
+                         filter_options=filter_options,
+                         force=force,
+                     )
+             # Generate COPY instruction with explicit destination
+             if dest_dir:
+                 instructions += f"COPY {source_path} {dest_dir}"
+             else:
+                 # No dest specified - use basename of source as destination
+                 dest_name = Path(source_path).name
+                 instructions += f"COPY {source_path} {dest_name}"
+         elif step.step_type == ImageSetupStepType.SET_ENV_VARS:
+             for key, val in step.kwargs.get("env_vars").items():
+                 # single env var per line in the dockerfile
+                 instructions += f"ENV {key} {val}\n"
+         if (
+             step.kwargs.get("force")
+             and step.step_type != ImageSetupStepType.PIP_INSTALL
+         ):
+             instructions += " # force"
+         instructions += "\n"
+
+     return instructions
+
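+ # Example of the returned instruction string for a hypothetical image with one
+ # pip install step and one rsync step (the string is Dockerfile-formatted;
+ # how it is consumed downstream is outside this excerpt):
+ #
+ #   FROM nvcr.io/nvidia/pytorch:23.10-py3
+ #   RUN $KT_PIP_INSTALL_CMD transformers
+ #   COPY ./configs configs
+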
+ # ----------------- Copying over for now... TBD ----------------- #
+ def __getstate__(self):
+     """Remove local stateful values before pickle serialization."""
+     state = self.__dict__.copy()
+     # Remove local stateful values that shouldn't be serialized
+     state["_endpoint"] = None
+     state["_service_manager"] = None
+     state["_objects_api"] = None
+     state["_core_api"] = None
+     state["_apps_v1_api"] = None
+     state["_node_v1_api"] = None
+     state["_secrets_client"] = None
+     return state
+
+ def __setstate__(self, state):
+     """Restore state after pickle deserialization."""
+     self.__dict__.update(state)
+     # Reset local stateful values to None to ensure clean initialization
+     self._endpoint = None
+     self._service_manager = None
+     self._objects_api = None
+     self._core_api = None
+     self._apps_v1_api = None
+     self._node_v1_api = None
+     self._secrets_client = None
+
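+ # Sketch of the pickling behavior these hooks enable; lazy re-creation of the
+ # dropped clients on next use is assumed from the reset-to-None pattern above:
+ #
+ #   import pickle
+ #   restored = pickle.loads(pickle.dumps(compute))
+ #   restored.pods()  # cached K8s clients were dropped and are re-initialized
+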
+ # ------------ Distributed / Autoscaling Helpers -------- #
+ def distribute(
+     self,
+     distribution_type: str = None,
+     workers: int = None,
+     quorum_timeout: int = None,
+     quorum_workers: int = None,
+     monitor_members: bool = None,
+     **kwargs,
+ ):
+     """Configure the distributed worker compute needed by each service replica.
+
+     Args:
+         distribution_type (str): The type of distributed supervisor to create.
+             Options: ``"spmd"`` (the default if unspecified), ``"pytorch"``, ``"ray"``,
+             ``"monarch"``, ``"jax"``, or ``"tensorflow"``.
+         workers (int): Number of workers to create, each with compute resources identical
+             to the service compute.
+         quorum_timeout (int, optional): Timeout in seconds for workers to become ready and join the cluster.
+             Defaults to `launch_timeout` if not provided, for both SPMD frameworks and for Ray.
+             Increase this if workers need more time to start (e.g., during node autoscaling or when
+             downloading data during initialization).
+         quorum_workers (int, optional): Minimum number of workers required to form a quorum.
+             Defaults to ``workers``; set lower to proceed before all workers are ready.
+         monitor_members (bool, optional): Whether the supervisor should monitor cluster membership.
+             Enabled by default for SPMD; disabled by default for Ray, which manages its own membership.
+         **kwargs: Additional framework-specific parameters (e.g., num_proc, port).
+
+     Note:
+         List of ``<int, Compute>`` pairs, specifying the number of workers and the compute
+         resources for each worker StatefulSet, is not yet supported for workers.
+
+     Examples:
+
+     .. code-block:: python
+
+         import kubetorch as kt
+
+         remote_fn = kt.fn(simple_summer, service_name).to(
+             kt.Compute(
+                 cpus="2",
+                 memory="4Gi",
+                 image=kt.Image(image_id="rayproject/ray"),
+                 launch_timeout=300,
+             ).distribute("ray", workers=2)
+         )
+
+         gpus = kt.Compute(
+             gpus=1,
+             image=kt.Image(image_id="nvcr.io/nvidia/pytorch:23.10-py3"),
+             launch_timeout=600,
+             inactivity_ttl="4h",
+         ).distribute("pytorch", workers=4)
+     """
+     # Check for conflicting configuration
+     if self.autoscaling_config:
+         raise ValueError(
+             "Cannot use both .distribute() and .autoscale() on the same compute instance. "
+             "Use .distribute() for fixed replicas with distributed training, or .autoscale() for auto-scaling services."
+         )
+
+     # Configure distributed settings
+     # Note: We default to simple SPMD distribution ("spmd") if nothing specified and compute.workers > 1
+
+     # User can override quorum if they want to set a lower threshold
+     quorum_workers = quorum_workers or workers
+     distributed_config = {
+         "distribution_type": distribution_type or "spmd",
+         "quorum_timeout": quorum_timeout or self.launch_timeout,
+         "quorum_workers": quorum_workers,
+     }
+     if monitor_members is not None:
+         # Note: Ray manages its own membership, so it's disabled by default in the supervisor
+         # It's enabled by default for SPMD.
+         distributed_config["monitor_members"] = monitor_members
+     distributed_config.update(kwargs)
+
+     if workers:
+         if not isinstance(workers, int):
+             raise ValueError(
+                 "Workers must be an integer. List of <integer, Compute> pairs is not yet supported"
+             )
+         # Set replicas property instead of storing in distributed_config
+         self.replicas = workers
+
+     if distributed_config:
+         self.distributed_config = distributed_config
+         # Invalidate cached service manager so it gets recreated with the right type
+         self._service_manager = None
+
+     return self
+
+ def autoscale(self, **kwargs):
+     """Configure the service with the provided autoscaling parameters.
+
+     You can pass any of the following keyword arguments:
+
+     Args:
+         target (int): The concurrency/RPS/CPU/memory target per pod.
+         window (str): Time window for scaling decisions, e.g. "60s".
+         metric (str): Metric to scale on: "concurrency", "rps", "cpu", "memory" or custom.
+             Note: "cpu" and "memory" require autoscaler_class="hpa.autoscaling.knative.dev".
+         target_utilization (int): Utilization % to trigger scaling (1-100).
+         min_scale (int): Minimum number of replicas. 0 allows scale to zero.
+         max_scale (int): Maximum number of replicas.
+         initial_scale (int): Initial number of pods.
+         concurrency (int): Maximum concurrent requests per pod (containerConcurrency).
+             If not set, pods accept unlimited concurrent requests.
+         scale_to_zero_pod_retention_period (str): Time to keep the last pod before scaling
+             to zero, e.g. "30s", "1m5s".
+         scale_down_delay (str): Delay before scaling down, e.g. "15m". Only for KPA.
+         autoscaler_class (str): Autoscaler implementation:
+             - "kpa.autoscaling.knative.dev" (default, supports concurrency/rps)
+             - "hpa.autoscaling.knative.dev" (supports cpu/memory/custom metrics)
+         progress_deadline (str): Time to wait for the deployment to be ready, e.g. "10m".
+             Must be longer than the startup probe timeout.
+         **extra_annotations: Additional Knative autoscaling annotations.
+
+     Note:
+         The service will be deployed as a Knative service.
+
+         Timing-related defaults are applied if not explicitly set (for ML workloads):
+         - scale_down_delay="1m" (avoid rapid scaling cycles)
+         - scale_to_zero_pod_retention_period="10m" (keep last pod longer before scale to zero)
+         - progress_deadline="10m" or greater (ensures enough time for initialization, automatically adjusted based on launch_timeout)
+
+     Examples:
+
+     .. code-block:: python
+
+         import kubetorch as kt
+
+         remote_fn = kt.fn(my_fn_obj).to(
+             kt.Compute(
+                 cpus=".1",
+             ).autoscale(min_scale=1)
+         )
+
+         remote_fn = kt.fn(summer).to(
+             compute=kt.Compute(
+                 cpus=".01",
+             ).autoscale(min_scale=3, scale_to_zero_pod_retention_period="50s"),
+         )
+     """
+     # Check for conflicting configuration
+     if self.distributed_config:
+         raise ValueError(
+             "Cannot use both .distribute() and .autoscale() on the same compute instance. "
+             "Use .distribute() for fixed replicas with distributed training, or .autoscale() for auto-scaling services."
+         )
+
+     # Apply timing-related defaults for ML workloads to account for initialization overhead
+     # (heavy dependencies, model loading, etc. affect both CPU and GPU workloads)
+     if "scale_down_delay" not in kwargs:
+         kwargs["scale_down_delay"] = "1m"
+         logger.debug("Setting scale_down_delay=1m to avoid thrashing")
+
+     if "scale_to_zero_pod_retention_period" not in kwargs:
+         kwargs["scale_to_zero_pod_retention_period"] = "10m"
+         logger.debug(
+             "Setting scale_to_zero_pod_retention_period=10m to avoid thrashing"
+         )
+
+     if "progress_deadline" not in kwargs:
+         # Ensure progress_deadline is at least as long as launch_timeout
+         default_deadline = "10m"  # 600 seconds
+         if self.launch_timeout:
+             # Convert launch_timeout (seconds) to a duration string
+             # Add some buffer (20% or at least 60 seconds)
+             timeout_with_buffer = max(
+                 self.launch_timeout + 60, int(self.launch_timeout * 1.2)
+             )
+             if timeout_with_buffer > 600:  # If larger than default
+                 default_deadline = f"{timeout_with_buffer}s"
+         kwargs["progress_deadline"] = default_deadline
+         logger.debug(
+             f"Setting progress_deadline={default_deadline} to allow time for initialization"
+         )
+
+     autoscaling_config = AutoscalingConfig(**kwargs)
+     if autoscaling_config:
+         self._autoscaling_config = autoscaling_config
+         # Invalidate cached service manager so it gets recreated with KnativeServiceManager
+         self._service_manager = None
+
+     return self