kubetorch-0.2.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. kubetorch/__init__.py +59 -0
  2. kubetorch/cli.py +1939 -0
  3. kubetorch/cli_utils.py +967 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +269 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +159 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +140 -0
  30. kubetorch/resources/callables/module.py +1315 -0
  31. kubetorch/resources/callables/utils.py +203 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +253 -0
  34. kubetorch/resources/compute/compute.py +2414 -0
  35. kubetorch/resources/compute/decorators.py +137 -0
  36. kubetorch/resources/compute/utils.py +1026 -0
  37. kubetorch/resources/compute/websocket.py +135 -0
  38. kubetorch/resources/images/__init__.py +1 -0
  39. kubetorch/resources/images/image.py +412 -0
  40. kubetorch/resources/images/images.py +64 -0
  41. kubetorch/resources/secrets/__init__.py +2 -0
  42. kubetorch/resources/secrets/kubernetes_secrets_client.py +377 -0
  43. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  44. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  45. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  46. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  47. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  48. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  49. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  50. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  51. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  52. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  53. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  54. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  55. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  56. kubetorch/resources/secrets/provider_secrets/providers.py +92 -0
  57. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  58. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  59. kubetorch/resources/secrets/secret.py +224 -0
  60. kubetorch/resources/secrets/secret_factory.py +64 -0
  61. kubetorch/resources/secrets/utils.py +222 -0
  62. kubetorch/resources/volumes/__init__.py +0 -0
  63. kubetorch/resources/volumes/volume.py +340 -0
  64. kubetorch/servers/__init__.py +0 -0
  65. kubetorch/servers/http/__init__.py +0 -0
  66. kubetorch/servers/http/distributed_utils.py +2968 -0
  67. kubetorch/servers/http/http_client.py +802 -0
  68. kubetorch/servers/http/http_server.py +1622 -0
  69. kubetorch/servers/http/server_metrics.py +255 -0
  70. kubetorch/servers/http/utils.py +722 -0
  71. kubetorch/serving/__init__.py +0 -0
  72. kubetorch/serving/autoscaling.py +153 -0
  73. kubetorch/serving/base_service_manager.py +344 -0
  74. kubetorch/serving/constants.py +77 -0
  75. kubetorch/serving/deployment_service_manager.py +431 -0
  76. kubetorch/serving/knative_service_manager.py +487 -0
  77. kubetorch/serving/raycluster_service_manager.py +526 -0
  78. kubetorch/serving/service_manager.py +18 -0
  79. kubetorch/serving/templates/deployment_template.yaml +17 -0
  80. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  81. kubetorch/serving/templates/kt_setup_template.sh.j2 +91 -0
  82. kubetorch/serving/templates/pod_template.yaml +198 -0
  83. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  84. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  85. kubetorch/serving/templates/service_template.yaml +21 -0
  86. kubetorch/serving/templates/workerset_template.yaml +36 -0
  87. kubetorch/serving/utils.py +344 -0
  88. kubetorch/utils.py +263 -0
  89. kubetorch-0.2.5.dist-info/METADATA +75 -0
  90. kubetorch-0.2.5.dist-info/RECORD +92 -0
  91. kubetorch-0.2.5.dist-info/WHEEL +4 -0
  92. kubetorch-0.2.5.dist-info/entry_points.txt +5 -0
kubetorch/resources/compute/compute.py
@@ -0,0 +1,2414 @@
+ import os
+ import re
+ import shlex
+ import subprocess
+ import time
+ from pathlib import Path
+ from typing import Dict, List, Optional, Union
+ from urllib.parse import urlparse
+
+ import yaml
+
+ from kubernetes import client, config
+ from kubernetes.client import V1ResourceRequirements
+
+ import kubetorch.constants as constants
+ import kubetorch.serving.constants as serving_constants
+
+ from kubetorch import globals
+
+ from kubetorch.logger import get_logger
+ from kubetorch.resources.callables.utils import find_locally_installed_version
+ from kubetorch.resources.compute.utils import (
+     _get_rsync_exclude_options,
+     _get_sync_package_paths,
+     _run_bash,
+     find_available_port,
+     RsyncError,
+ )
+ from kubetorch.resources.compute.websocket import WebSocketRsyncTunnel
+ from kubetorch.resources.images.image import Image, ImageSetupStepType
+ from kubetorch.resources.secrets.kubernetes_secrets_client import KubernetesSecretsClient
+ from kubetorch.resources.volumes.volume import Volume
+ from kubetorch.servers.http.utils import is_running_in_kubernetes, load_template
+ from kubetorch.serving.autoscaling import AutoscalingConfig
+ from kubetorch.serving.service_manager import DeploymentServiceManager, KnativeServiceManager, RayClusterServiceManager
+ from kubetorch.serving.utils import GPUConfig, pod_is_running, RequestedPodResources
+
+ from kubetorch.utils import extract_host_port, http_to_ws, load_kubeconfig
+
+ logger = get_logger(__name__)
+
+
+ class Compute:
+     def __init__(
+         self,
+         cpus: Union[str, int] = None,
+         memory: str = None,
+         disk_size: str = None,
+         gpus: Union[str, int] = None,
+         queue: str = None,
+         priority_class_name: str = None,
+         gpu_type: str = None,
+         gpu_memory: str = None,
+         namespace: str = None,
+         image: "Image" = None,
+         labels: Dict = None,
+         annotations: Dict = None,
+         volumes: List[Union[str, Volume]] = None,
+         node_selector: Dict = None,
+         service_template: Dict = None,
+         tolerations: List[Dict] = None,
+         env_vars: Dict = None,
+         secrets: List[Union[str, "Secret"]] = None,
+         freeze: bool = False,
+         kubeconfig_path: str = None,
+         service_account_name: str = None,
+         image_pull_policy: str = None,
+         inactivity_ttl: str = None,
+         gpu_anti_affinity: bool = None,
+         launch_timeout: int = None,
+         working_dir: str = None,
+         shared_memory_limit: str = None,
+         allowed_serialization: Optional[List[str]] = None,
+         replicas: int = 1,
+         _skip_template_init: bool = False,
+     ):
+         """Initialize the compute requirements for a Kubetorch service.
+
+         Args:
+             cpus (str or int, optional): CPU resource request. Can be specified in cores ("1.0") or millicores ("1000m").
+             memory (str, optional): Memory resource request. Can use binary (Ki, Mi, Gi) or decimal (K, M, G) units.
+             disk_size (str, optional): Ephemeral storage request. Uses the same format as memory.
+             gpus (str or int, optional): Number of GPUs to request. Fractional GPUs are not currently supported.
+             gpu_type (str, optional): GPU type to request. Corresponds to the "nvidia.com/gpu.product" label on the
+                 node (if GPU feature discovery is enabled), or a full string like "nvidia.com/gpu.product: L4" can be
+                 passed, which will be used to set a `nodeSelector` on the service. More info below.
+             gpu_memory (str, optional): GPU memory request (e.g., "4Gi"). Will still request a whole GPU but limit
+                 memory usage.
+             queue (str, optional): Name of the Kubernetes queue that will be responsible for placing the service's
+                 pods onto nodes. Controls how cluster resources are allocated and prioritized for this service.
+                 Pods will be scheduled according to the quota, priority, and limits configured for the queue.
+             priority_class_name (str, optional): Name of the Kubernetes priority class to use for the service. If
+                 not specified, the default priority class will be used.
+             namespace (str, optional): Kubernetes namespace. Defaults to the global config default, or "default".
+             image (Image, optional): Kubetorch image configuration. See :class:`Image` for more details.
+             labels (Dict, optional): Kubernetes labels to apply to the service.
+             annotations (Dict, optional): Kubernetes annotations to apply to the service.
+             volumes (List[Union[str, Volume]], optional): Volumes to attach to the service. Can be specified as a
+                 list of volume names (strings) or Volume objects. If using strings, they must be the names of existing
+                 PersistentVolumeClaims (PVCs) in the specified namespace.
+             node_selector (Dict, optional): Kubernetes node selector to constrain pods to specific nodes. Should be a
+                 dictionary of key-value pairs, e.g. `{"node.kubernetes.io/instance-type": "g4dn.xlarge"}`.
+             service_template (Dict, optional): Nested dictionary of service template arguments to apply to the service. E.g.
+                 ``{"spec": {"template": {"spec": {"nodeSelector": {"node.kubernetes.io/instance-type": "g4dn.xlarge"}}}}}``
+             tolerations (List[Dict], optional): Kubernetes tolerations to apply to the service. Each toleration should
+                 be a dictionary with keys like "key", "operator", "value", and "effect". More info
+                 `here <https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/>`__.
+             env_vars (Dict, optional): Environment variables to set in containers.
+             secrets (List[Union[str, Secret]], optional): Secrets to mount or expose.
+             freeze (bool, optional): Whether to freeze the compute configuration (e.g. for production).
+             kubeconfig_path (str, optional): Path to the local kubeconfig file used for cluster authentication.
+             service_account_name (str, optional): Kubernetes service account to use.
+             image_pull_policy (str, optional): Container image pull policy.
+                 More info `here <https://kubernetes.io/docs/concepts/containers/images/#image-pull-policy>`__.
+             inactivity_ttl (str, optional): Time-to-live after inactivity. Once hit, the service will be destroyed.
+                 Values below 1m may cause premature deletion.
+             gpu_anti_affinity (bool, optional): Whether to prevent scheduling the service on GPU nodes when no GPUs
+                 are requested. Can also be controlled globally by setting the `KT_GPU_ANTI_AFFINITY` environment
+                 variable. (Default: ``False``)
+             launch_timeout (int, optional): How long (in seconds) to wait for the service to be ready before giving up.
+                 If not specified, will wait {serving_constants.KT_LAUNCH_TIMEOUT} seconds.
+                 Note: you can also control this timeout globally by setting the `KT_LAUNCH_TIMEOUT` environment variable.
+             replicas (int, optional): Number of replicas to create for deployment-based services. Can also be set via
+                 the `.distribute(workers=N)` method for distributed training. (Default: 1)
+             working_dir (str, optional): Working directory to use inside the remote images. Must be an absolute path (e.g. `/kt`).
+             shared_memory_limit (str, optional): Maximum size of the shared memory filesystem (/dev/shm) available to
+                 each pod created by the service. Value should be a Kubernetes quantity string, for example: "512Mi",
+                 "2Gi", "1G", "1024Mi", "100M". If not provided, /dev/shm will default to the pod's memory limit (if set)
+                 or up to half the node's RAM.
+
+         Note:
+             **Resource Specification Formats:**
+
+             CPUs:
+                 - Decimal core count: "0.5", "1.0", "2.0"
+                 - Millicores: "500m", "1000m", "2000m"
+
+             Memory:
+                 - Bytes: "1000000"
+                 - Binary units: "1Ki", "1Mi", "1Gi", "1Ti"
+                 - Decimal units: "1K", "1M", "1G", "1T"
+
+             GPU Specifications:
+                 1. ``gpus`` for whole GPUs: "1", "2"
+                 2. ``gpu_memory``: "4Gi", "16Gi"
+
+             Disk Size:
+                 - Same format as memory
+
+         Note:
+             - Memory/disk values are case sensitive (Mi != mi)
+             - When using ``gpu_memory``, a whole GPU is still requested but memory is limited
+
+         Examples:
+
+             .. code-block:: python
+
+                 import kubetorch as kt
+
+                 # Basic CPU/Memory request
+                 compute = kt.Compute(cpus="0.5", memory="2Gi")
+
+                 # GPU request with memory limit
+                 compute = kt.Compute(gpu_memory="4Gi", cpus="1.0")
+
+                 # Multiple whole GPUs
+                 compute = kt.Compute(gpus="2", memory="16Gi")
+         """
+         self.default_config = {}
+
+         self._endpoint = None
+         self._service_manager = None
+         self._autoscaling_config = None
+         self._kubeconfig_path = kubeconfig_path
+         self._namespace = namespace or globals.config.namespace
+
+         self._objects_api = None
+         self._core_api = None
+         self._apps_v1_api = None
+         self._node_v1_api = None
+
+         self._image = image
+         self._service_name = None
+         self._secrets = secrets
+         self._secrets_client = None
+         self._volumes = volumes
+         self._queue = queue
+
+         # service template args to store
+         self.replicas = replicas
+         self.labels = labels or {}
+         self.annotations = annotations or {}
+         self.service_template = service_template or {}
+         self._gpu_annotations = {}  # Will be populated during init or from_template
+
+         # Skip template initialization if loading from existing service
+         if _skip_template_init:
+             return
+
+         # determine pod template vars
+         server_port = serving_constants.DEFAULT_KT_SERVER_PORT
+         service_account_name = service_account_name or serving_constants.DEFAULT_SERVICE_ACCOUNT_NAME
+         otel_enabled = (
+             globals.config.cluster_config.get("otel_enabled", False) if globals.config.cluster_config else False
+         )
+         server_image = self._get_server_image(image, otel_enabled, inactivity_ttl)
+         gpus = None if gpus in (0, None) else gpus
+         gpu_config = self._load_gpu_config(gpus, gpu_memory, gpu_type)
+         self._gpu_annotations = self._get_gpu_annotations(gpu_config)
+         requested_resources = self._get_requested_resources(cpus, memory, disk_size, gpu_config)
+         secret_env_vars, secret_volumes = self._extract_secrets(secrets)
+         volume_mounts, volume_specs = self._volumes_for_pod_template(volumes)
+         scheduler_name = self._get_scheduler_name(queue)
+         node_selector = self._get_node_selector(node_selector.copy() if node_selector else {}, gpu_type)
+         all_tolerations = self._get_tolerations(gpus, tolerations)
+
+         env_vars = env_vars or {}
+         if os.getenv("KT_LOG_LEVEL") and not env_vars.get("KT_LOG_LEVEL"):
+             # If KT_LOG_LEVEL is set, add it to env vars so the log level is set on the server
+             env_vars["KT_LOG_LEVEL"] = os.getenv("KT_LOG_LEVEL")
+
+         template_vars = {
+             "server_image": server_image,
+             "server_port": server_port,
+             "env_vars": env_vars,
+             "resources": requested_resources,
+             "node_selector": node_selector,
+             "secret_env_vars": secret_env_vars or [],
+             "secret_volumes": secret_volumes or [],
+             "volume_mounts": volume_mounts,
+             "volume_specs": volume_specs,
+             "service_account_name": service_account_name,
+             "config_env_vars": self._get_config_env_vars(allowed_serialization or ["json"]),
+             "image_pull_policy": image_pull_policy,
+             "namespace": self._namespace,
+             "freeze": freeze,
+             "gpu_anti_affinity": gpu_anti_affinity,
+             "working_dir": working_dir,
+             "tolerations": all_tolerations,
+             "shm_size_limit": shared_memory_limit,
+             "priority_class_name": priority_class_name,
+             "launch_timeout": self._get_launch_timeout(launch_timeout),
+             "queue_name": self.queue_name(),
+             "scheduler_name": scheduler_name,
+             "inactivity_ttl": inactivity_ttl,
+             "otel_enabled": otel_enabled,
+             # launch time arguments
+             "raycluster": False,
+             "setup_script": "",
+         }
+
+         self.pod_template = load_template(
+             template_file=serving_constants.POD_TEMPLATE_FILE,
+             template_dir=os.path.join(
+                 os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
+                 "serving",
+                 "templates",
+             ),
+             **template_vars,
+         )
+
+     @classmethod
+     def from_template(cls, service_info: dict):
+         """Create a Compute object from a deployed Kubernetes resource."""
+         if "resource" not in service_info:
+             raise ValueError("service_info missing required key: resource")
+
+         resource = service_info["resource"]
+         kind = resource.get("kind", "Unknown")
+
+         if kind == "RayCluster":
+             template_path = resource["spec"]["headGroupSpec"]["template"]
+         elif kind in ["Deployment", "Service"]:  # Deployment or Knative Service
+             template_path = resource["spec"]["template"]
+         else:
+             raise ValueError(
+                 f"Unsupported resource kind: '{kind}'. "
+                 f"Supported kinds are: Deployment, Service (Knative), RayCluster"
+             )
+
+         template_metadata = template_path["metadata"]
+         pod_spec = template_path["spec"]
+
+         annotations = template_metadata.get("annotations", {})
+
+         compute = cls(_skip_template_init=True)
+         compute.pod_template = pod_spec
+
+         # Set properties from manifest
+         compute._namespace = service_info["namespace"]
+         compute.replicas = resource["spec"].get("replicas")
+         compute.labels = template_metadata.get("labels", {})
+         compute.annotations = annotations
+         compute._autoscaling_config = annotations.get("autoscaling.knative.dev/config", {})
+         compute._queue = template_metadata.get("labels", {}).get("kai.scheduler/queue")
+         compute._kubeconfig_path = annotations.get(serving_constants.KUBECONFIG_PATH_ANNOTATION)
+
+         # Extract GPU annotations directly from template annotations
+         gpu_annotation_keys = ["gpu-memory", "gpu-fraction"]
+         compute._gpu_annotations = {k: v for k, v in annotations.items() if k in gpu_annotation_keys}
+
+         return compute
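Reviewer note: `from_template` reads only the `resource` manifest and `namespace` keys of `service_info`, so a deployed service can be rehydrated from a manifest fetched off the cluster. A minimal sketch (the Deployment dict here is hypothetical and trimmed to the keys the method actually touches):

    # Hypothetical manifest; only the fields from_template reads are shown.
    service_info = {
        "namespace": "default",
        "resource": {
            "kind": "Deployment",
            "spec": {
                "replicas": 2,
                "template": {"metadata": {"annotations": {}}, "spec": {"containers": []}},
            },
        },
    }
    compute = Compute.from_template(service_info)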
+
+     # ----------------- Properties ----------------- #
+     @property
+     def objects_api(self):
+         if self._objects_api is None:
+             self._objects_api = client.CustomObjectsApi()
+         return self._objects_api
+
+     @property
+     def core_api(self):
+         if self._core_api is None:
+             load_kubeconfig()
+             self._core_api = client.CoreV1Api()
+         return self._core_api
+
+     @property
+     def apps_v1_api(self):
+         if self._apps_v1_api is None:
+             self._apps_v1_api = client.AppsV1Api()
+         return self._apps_v1_api
+
+     @property
+     def node_v1_api(self):
+         if self._node_v1_api is None:
+             self._node_v1_api = client.NodeV1Api()
+         return self._node_v1_api
+
+     @property
+     def kubeconfig_path(self):
+         if self._kubeconfig_path is None:
+             self._kubeconfig_path = os.getenv("KUBECONFIG") or constants.DEFAULT_KUBECONFIG_PATH
+         return str(Path(self._kubeconfig_path).expanduser())
+
+     @property
+     def service_manager(self):
+         if self._service_manager is None:
+             self._load_kube_config()
+             # Select appropriate service manager based on configuration
+             if self.deployment_mode == "knative":
+                 # Use KnativeServiceManager for autoscaling services
+                 self._service_manager = KnativeServiceManager(
+                     objects_api=self.objects_api,
+                     core_api=self.core_api,
+                     apps_v1_api=self.apps_v1_api,
+                     namespace=self.namespace,
+                 )
+             elif self.deployment_mode == "raycluster":
+                 # Use RayClusterServiceManager for Ray distributed workloads
+                 self._service_manager = RayClusterServiceManager(
+                     objects_api=self.objects_api,
+                     core_api=self.core_api,
+                     apps_v1_api=self.apps_v1_api,
+                     namespace=self.namespace,
+                 )
+             else:
+                 # Use DeploymentServiceManager for regular deployments
+                 self._service_manager = DeploymentServiceManager(
+                     objects_api=self.objects_api,
+                     core_api=self.core_api,
+                     apps_v1_api=self.apps_v1_api,
+                     namespace=self.namespace,
+                 )
+         return self._service_manager
+
+     @property
+     def secrets_client(self):
+         if not self.secrets:
+             # Skip creating secrets client if no secrets are provided
+             return None
+
+         if self._secrets_client is None:
+             self._secrets_client = KubernetesSecretsClient(
+                 namespace=self.namespace, kubeconfig_path=self.kubeconfig_path
+             )
+         return self._secrets_client
+
+     @property
+     def image(self):
+         return self._image
+
+     @image.setter
+     def image(self, value: "Image"):
+         self._image = value
+
+     @property
+     def endpoint(self):
+         if self._endpoint is None and self.service_name:
+             self._endpoint = self.service_manager.get_endpoint(self.service_name)
+         return self._endpoint
+
+     @endpoint.setter
+     def endpoint(self, endpoint: str):
+         self._endpoint = endpoint
+
+     def _container(self):
+         """Get the container from the pod template."""
+         if "containers" not in self.pod_template:
+             raise ValueError("pod_template missing 'containers' field.")
+         return self.pod_template["containers"][0]
+
+     def _container_env(self):
+         container = self._container()
+         if "env" not in container:
+             return []
+         return container["env"]
+
+     def _set_container_resource(self, resource_name: str, value: str):
+         container = self._container()
+
+         # Ensure resources dict exists
+         if "resources" not in container:
+             container["resources"] = {}
+
+         # Ensure requests dict exists
+         if "requests" not in container["resources"]:
+             container["resources"]["requests"] = {}
+
+         # Ensure limits dict exists
+         if "limits" not in container["resources"]:
+             container["resources"]["limits"] = {}
+
+         # Set both requests and limits to the same value
+         container["resources"]["requests"][resource_name] = value
+         container["resources"]["limits"][resource_name] = value
+
+     def _get_container_resource(self, resource_name: str) -> Optional[str]:
+         resources = self._container().get("resources", {})
+         requests = resources.get("requests", {})
+         return requests.get(resource_name)
+
+     # -------------- Properties From Template -------------- #
+     @property
+     def server_image(self):
+         return self._container().get("image")
+
+     @server_image.setter
+     def server_image(self, value: str):
+         """Set the server image in the pod template."""
+         self._container()["image"] = value
+
+     @property
+     def server_port(self):
+         return self._container()["ports"][0].get("containerPort")
+
+     @server_port.setter
+     def server_port(self, value: int):
+         """Set the server port in the pod template."""
+         self._container()["ports"][0]["containerPort"] = value
+
+     @property
+     def env_vars(self):
+         # extract user-defined environment variables from rendered pod template
+         kt_env_vars = [
+             "POD_NAME",
+             "POD_NAMESPACE",
+             "POD_IP",
+             "POD_UUID",
+             "MODULE_NAME",
+             "KUBETORCH_VERSION",
+             "UV_LINK_MODE",
+             "OTEL_SERVICE_NAME",
+             "OTEL_EXPORTER_OTLP_ENDPOINT",
+             "OTEL_EXPORTER_OTLP_PROTOCOL",
+             "OTEL_TRACES_EXPORTER",
+             "OTEL_PROPAGATORS",
+             "KT_SERVER_PORT",
+             "KT_FREEZE",
+             "KT_INACTIVITY_TTL",
+             "KT_ALLOWED_SERIALIZATION",
+             "KT_FILE_PATH",
+             "KT_MODULE_NAME",
+             "KT_CLS_OR_FN_NAME",
+             "KT_CALLABLE_TYPE",
+             "KT_LAUNCH_ID",
+             "KT_SERVICE_NAME",
+             "KT_SERVICE_DNS",
+         ]
+         user_env_vars = {}
+         for env_var in self._container_env():
+             # skip if it was set by kubetorch
+             if env_var["name"] not in kt_env_vars and "value" in env_var:
+                 user_env_vars[env_var["name"]] = env_var["value"]
+         return user_env_vars
+
+     @property
+     def resources(self):
+         return self._container().get("resources")
+
+     @property
+     def cpus(self):
+         return self._get_container_resource("cpu")
+
+     @cpus.setter
+     def cpus(self, value: str):
+         """
+         Args:
+             value: CPU value (e.g., "2", "1000m", "0.5")
+         """
+         self._set_container_resource("cpu", value)
+
+     @property
+     def memory(self):
+         return self._get_container_resource("memory")
+
+     @memory.setter
+     def memory(self, value: str):
+         """
+         Args:
+             value: Memory value (e.g., "4Gi", "2048Mi")
+         """
+         self._set_container_resource("memory", value)
+
+     @property
+     def disk_size(self):
+         return self._get_container_resource("ephemeral-storage")
+
+     @disk_size.setter
+     def disk_size(self, value: str):
+         """
+         Args:
+             value: Disk size (e.g., "10Gi", "5000Mi")
+         """
+         self._set_container_resource("ephemeral-storage", value)
+
+     @property
+     def gpus(self):
+         return self._get_container_resource("nvidia.com/gpu")
+
+     @gpus.setter
+     def gpus(self, value: Union[str, int]):
+         """
+         Args:
+             value: Number of GPUs (e.g., 1, "2")
+         """
+         self._set_container_resource("nvidia.com/gpu", str(value))
+
+     @property
+     def gpu_type(self):
+         node_selector = self.pod_template.get("nodeSelector")
+         if node_selector and "nvidia.com/gpu.product" in node_selector:
+             return node_selector["nvidia.com/gpu.product"]
+         return None
+
+     @gpu_type.setter
+     def gpu_type(self, value: str):
+         """
+         Args:
+             value: GPU product name (e.g., "L4", "V100", "A100", "T4")
+         """
+         if "nodeSelector" not in self.pod_template:
+             self.pod_template["nodeSelector"] = {}
+         self.pod_template["nodeSelector"]["nvidia.com/gpu.product"] = value
+
+     @property
+     def gpu_memory(self):
+         annotations = self.pod_template.get("annotations", {})
+         if "gpu-memory" in annotations:
+             return annotations["gpu-memory"]
+         return None
+
+     @gpu_memory.setter
+     def gpu_memory(self, value: str):
+         """
+         Args:
+             value: GPU memory in MiB (e.g., "4096", "8192", "16384")
+         """
+         if "annotations" not in self.pod_template:
+             self.pod_template["annotations"] = {}
+         self.pod_template["annotations"]["gpu-memory"] = value
+
+     @property
+     def volumes(self):
+         if not self._volumes:
+             volumes = []
+             if "volumes" in self.pod_template:
+                 for volume in self.pod_template["volumes"]:
+                     # Skip the shared memory volume
+                     if volume["name"] == "dshm":
+                         continue
+                     # Skip secret volumes
+                     if "secret" in volume:
+                         continue
+                     # Only include PVC volumes
+                     if "persistentVolumeClaim" in volume:
+                         volumes.append(volume["name"])
+             self._volumes = volumes
+         return self._volumes
+
+     @property
+     def shared_memory_limit(self):
+         if "volumes" not in self.pod_template:
+             return None
+
+         for volume in self.pod_template["volumes"]:
+             if volume.get("name") == "dshm" and "emptyDir" in volume:
+                 empty_dir = volume["emptyDir"]
+                 return empty_dir.get("sizeLimit")
+
+         return None
+
+     @shared_memory_limit.setter
+     def shared_memory_limit(self, value: str):
+         """
+         Args:
+             value: Size limit (e.g., "512Mi", "1Gi", "2G")
+         """
+         if "volumes" not in self.pod_template:
+             self.pod_template["volumes"] = []
+
+         # Find existing dshm volume and update it
+         for volume in self.pod_template["volumes"]:
+             if volume.get("name") == "dshm" and "emptyDir" in volume:
+                 volume["emptyDir"]["sizeLimit"] = value
+                 return
+
+         # Add new dshm volume if not found
+         self.pod_template["volumes"].append({"name": "dshm", "emptyDir": {"medium": "Memory", "sizeLimit": value}})
+
+     # Alias for backward compatibility (deprecated)
+     @property
+     def shm_size_limit(self):
+         """Deprecated: Use shared_memory_limit instead."""
+         return self.shared_memory_limit
+
+     @shm_size_limit.setter
+     def shm_size_limit(self, value: str):
+         """Deprecated: Use shared_memory_limit instead."""
+         self.shared_memory_limit = value
+
+     @property
+     def node_selector(self):
+         return self.pod_template.get("nodeSelector")
+
+     @node_selector.setter
+     def node_selector(self, value: dict):
+         """
+         Args:
+             value: Label key-value pairs (e.g., {"node-type": "gpu"})
+         """
+         self.pod_template["nodeSelector"] = value
+
+     @property
+     def secret_env_vars(self):
+         secret_env_vars = []
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if "valueFrom" in env_var and "secretKeyRef" in env_var["valueFrom"]:
+                     secret_ref = env_var["valueFrom"]["secretKeyRef"]
+                     # Find existing secret or create new entry
+                     secret_name = secret_ref["name"]
+
+                     # Check if we already have this secret
+                     existing_secret = None
+                     for secret in secret_env_vars:
+                         if secret.get("secret_name") == secret_name:
+                             existing_secret = secret
+                             break
+
+                     if existing_secret:
+                         if "env_vars" not in existing_secret:
+                             existing_secret["env_vars"] = []
+                         if env_var["name"] not in existing_secret["env_vars"]:
+                             existing_secret["env_vars"].append(env_var["name"])
+                     else:
+                         secret_env_vars.append({"secret_name": secret_name, "env_vars": [env_var["name"]]})
+         return secret_env_vars
+
+     @property
+     def secret_volumes(self):
+         secret_volumes = []
+         if "volumes" in self.pod_template:
+             for volume in self.pod_template["volumes"]:
+                 if "secret" in volume:
+                     secret_name = volume["secret"]["secretName"]
+                     # Find corresponding volume mount
+                     mount_path = None
+                     container = self._container()
+                     if "volumeMounts" in container:
+                         for mount in container["volumeMounts"]:
+                             if mount["name"] == volume["name"]:
+                                 mount_path = mount["mountPath"]
+                                 break
+
+                     secret_volumes.append(
+                         {
+                             "name": volume["name"],
+                             "secret_name": secret_name,
+                             "path": mount_path or f"/secrets/{volume['name']}",
+                         }
+                     )
+         return secret_volumes
+
+     @property
+     def volume_mounts(self):
+         volume_mounts = []
+         container = self._container()
+         if "volumeMounts" in container:
+             for mount in container["volumeMounts"]:
+                 # Skip the default dshm mount
+                 if mount["name"] != "dshm":
+                     volume_mounts.append({"name": mount["name"], "mountPath": mount["mountPath"]})
+         return volume_mounts
+
+     @property
+     def service_account_name(self):
+         return self.pod_template.get("serviceAccountName")
+
+     @service_account_name.setter
+     def service_account_name(self, value: str):
+         """Set service account name in the pod template."""
+         self.pod_template["serviceAccountName"] = value
+
+     @property
+     def config_env_vars(self):
+         from kubetorch.config import ENV_MAPPINGS
+
+         config_env_vars = {}
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 # Filter for config-related env vars (those that start with KT_ or are known config vars)
+                 if env_var["name"] in ENV_MAPPINGS.keys():
+                     if "value" in env_var and env_var["value"]:
+                         config_env_vars[env_var["name"]] = env_var["value"]
+         return config_env_vars
+
+     @property
+     def image_pull_policy(self):
+         return self._container().get("imagePullPolicy")
+
+     @image_pull_policy.setter
+     def image_pull_policy(self, value: str):
+         """Set image pull policy in the pod template."""
+         self._container()["imagePullPolicy"] = value
+
+     @property
+     def namespace(self):
+         return self._namespace
+
+     @namespace.setter
+     def namespace(self, value: str):
+         self._namespace = value
+
+     @property
+     def python_path(self):
+         if self.image and self.image.python_path:
+             return self.image.python_path
+
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if env_var["name"] == "KT_PYTHON_PATH" and "value" in env_var:
+                     return env_var["value"]
+         return None
+
+     @property
+     def freeze(self):
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if env_var["name"] == "KT_FREEZE" and "value" in env_var:
+                     return env_var["value"].lower() == "true"
+         return False
+
+     @property
+     def secrets(self):
+         if not self._secrets:
+             secrets = []
+
+             # Extract secrets from environment variables
+             container = self._container()
+             if "env" in container:
+                 for env_var in container["env"]:
+                     if "valueFrom" in env_var and "secretKeyRef" in env_var["valueFrom"]:
+                         secret_ref = env_var["valueFrom"]["secretKeyRef"]
+                         if secret_ref["name"] not in secrets:
+                             secrets.append(secret_ref["name"])
+
+             # Extract secrets from volumes
+             if "volumes" in self.pod_template:
+                 for volume in self.pod_template["volumes"]:
+                     if "secret" in volume:
+                         secret_name = volume["secret"]["secretName"]
+                         if secret_name not in secrets:
+                             secrets.append(secret_name)
+
+             self._secrets = secrets
+
+         return self._secrets
+
+     @property
+     def gpu_anti_affinity(self):
+         if "affinity" in self.pod_template and "nodeAffinity" in self.pod_template["affinity"]:
+             node_affinity = self.pod_template["affinity"]["nodeAffinity"]
+             if "requiredDuringSchedulingIgnoredDuringExecution" in node_affinity:
+                 required = node_affinity["requiredDuringSchedulingIgnoredDuringExecution"]
+                 if "nodeSelectorTerms" in required:
+                     for term in required["nodeSelectorTerms"]:
+                         if "matchExpressions" in term:
+                             for expr in term["matchExpressions"]:
+                                 if expr.get("key") == "nvidia.com/gpu" and expr.get("operator") == "DoesNotExist":
+                                     return True
+         return False
+
+     @property
+     def concurrency(self):
+         return self.pod_template.get("containerConcurrency")
+
+     @concurrency.setter
+     def concurrency(self, value: int):
+         self.pod_template["containerConcurrency"] = value
+
+     @property
+     def working_dir(self):
+         return self._container().get("workingDir")
+
+     @working_dir.setter
+     def working_dir(self, value: str):
+         """Set working directory in the pod template."""
+         self._container()["workingDir"] = value
+
+     @property
+     def priority_class_name(self):
+         return self.pod_template.get("priorityClassName")
+
+     @priority_class_name.setter
+     def priority_class_name(self, value: str):
+         """Set priority class name in the pod template."""
+         self.pod_template["priorityClassName"] = value
+
+     @property
+     def otel_enabled(self):
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if env_var["name"] == "KT_OTEL_ENABLED" and "value" in env_var:
+                     return env_var["value"].lower() == "true"
+         return False
+
+     @property
+     def launch_timeout(self):
+         container = self._container()
+         if "startupProbe" in container:
+             startup_probe = container["startupProbe"]
+             if "failureThreshold" in startup_probe:
+                 # Convert back from failure threshold (launch_timeout // 5)
+                 return startup_probe["failureThreshold"] * 5
+         return None
+
+     @launch_timeout.setter
+     def launch_timeout(self, value: int):
+         container = self._container()
+         if "startupProbe" not in container:
+             container["startupProbe"] = {}
+         # Convert timeout to failure threshold (launch_timeout // 5)
+         container["startupProbe"]["failureThreshold"] = value // 5
+
+     def queue_name(self):
+         if self.queue is not None:
+             return self.queue
+
+         default_queue = globals.config.queue
+         if default_queue:
+             return default_queue
+
+     @property
+     def queue(self):
+         return self._queue
+
+     @queue.setter
+     def queue(self, value: str):
+         self._queue = value
+
+     @property
+     def scheduler_name(self):
+         return self._get_scheduler_name(self.queue_name())
+
+     @property
+     def inactivity_ttl(self):
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if env_var["name"] == "KT_INACTIVITY_TTL" and "value" in env_var:
+                     return env_var["value"] if not env_var["value"] == "None" else None
+         return None
+
+     @inactivity_ttl.setter
+     def inactivity_ttl(self, value: str):
+         if value and (not isinstance(value, str) or not re.match(r"^\d+[smhd]$", value)):
+             raise ValueError("Inactivity TTL must be a string, e.g. '5m', '1h', '1d'")
+         if value and not self.otel_enabled:
+             logger.warning(
+                 "Inactivity TTL is only supported when OTEL is enabled, please update your Kubetorch Helm chart and restart the nginx proxy"
+             )
+
+         container = self._container()
+         if "env" not in container:
+             container["env"] = []
+
+         # Find existing KT_INACTIVITY_TTL env var and update it
+         for env_var in container["env"]:
+             if env_var["name"] == "KT_INACTIVITY_TTL":
+                 env_var["value"] = value if value is not None else "None"
+                 return
+
+         # Add new env var if not found
+         container["env"].append(
+             {
+                 "name": "KT_INACTIVITY_TTL",
+                 "value": value if value is not None else "None",
+             }
+         )
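Reviewer note: the TTL format is a digit count plus one of s/m/h/d; anything else raises. A sanity sketch of the validation regex used above:

    import re
    pattern = re.compile(r"^\d+[smhd]$")
    assert pattern.match("5m") and pattern.match("1h") and pattern.match("30s")
    assert not pattern.match("1.5h") and not pattern.match("5")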
+
+     @property
+     def name(self):
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if env_var["name"] == "KT_SERVICE_NAME" and "value" in env_var:
+                     return env_var["value"] if not env_var["value"] == "None" else None
+         return None
+
+     @property
+     def raycluster(self):
+         container = self._container()
+         if "ports" in container:
+             for port in container["ports"]:
+                 if port.get("name") == "ray-gcs":
+                     return True
+         return False
+
+     @property
+     def autoscaling_config(self):
+         return self._autoscaling_config
+
+     @property
+     def distributed_config(self):
+         # First try to get from pod template
+         template_config = None
+         container = self._container()
+         if "env" in container:
+             for env_var in container["env"]:
+                 if env_var["name"] == "KT_DISTRIBUTED_CONFIG" and "value" in env_var and env_var["value"]:
+                     import json
+
+                     try:
+                         template_config = json.loads(env_var["value"])
+                     except (json.JSONDecodeError, TypeError):
+                         template_config = env_var["value"]
+                     break
+
+         # Return template config if available, otherwise return stored config
+         return template_config
+
+     @distributed_config.setter
+     def distributed_config(self, config: dict):
+         # Update pod template with distributed config
+         container = self._container()
+         if "env" not in container:
+             container["env"] = []
+
+         # Update or add KT_SERVICE_DNS, KT_DISTRIBUTED_CONFIG env vars
+         import json
+
+         service_dns = None
+         if config and config.get("distribution_type") == "ray":
+             service_dns = "ray-head-svc"
+         elif config and config.get("distribution_type") == "pytorch":
+             service_dns = "rank0"
+
+         # Serialize the config to JSON, ensuring it's always a string
+         # Check for non-serializable values and raise an error with details
+         non_serializable_keys = []
+         for key, value in config.items():
+             try:
+                 json.dumps(value)
+             except (TypeError, ValueError) as e:
+                 non_serializable_keys.append(f"'{key}': {type(value).__name__} - {str(e)}")
+
+         if non_serializable_keys:
+             raise ValueError(
+                 f"Distributed config contains non-serializable values: {', '.join(non_serializable_keys)}. "
+                 f"All values must be JSON serializable (strings, numbers, booleans, lists, dicts)."
+             )
+
+         service_dns_found, distributed_config_found = False, False
+         for env_var in self._container_env():
+             if env_var["name"] == "KT_SERVICE_DNS" and service_dns:
+                 env_var["value"] = service_dns
+                 service_dns_found = True
+             elif env_var["name"] == "KT_DISTRIBUTED_CONFIG":
+                 env_var["value"] = json.dumps(config)
+                 distributed_config_found = True
+
+         # Add any missing env vars
+         if service_dns and not service_dns_found:
+             container["env"].append({"name": "KT_SERVICE_DNS", "value": service_dns})
+         if not distributed_config_found:
+             container["env"].append({"name": "KT_DISTRIBUTED_CONFIG", "value": json.dumps(config)})
+
+     @property
+     def deployment_mode(self):
+         # Determine deployment mode based on distributed config and autoscaling.
+         # For distributed workloads, always use the appropriate deployment mode
+         if self.distributed_config:
+             distribution_type = self.distributed_config.get("distribution_type")
+             if distribution_type == "pytorch":
+                 return "deployment"
+             elif distribution_type == "ray":
+                 return "raycluster"
+
+         # Use Knative for autoscaling services
+         if self.autoscaling_config:
+             return "knative"
+
+         # Default to deployment mode for simple workloads
+         return "deployment"
+
+     # ----------------- Service Level Properties ----------------- #
+
+     @property
+     def service_name(self):
+         # Get service name from pod template if available, otherwise return stored service name
+         if not self._service_name:
+             for env_var in self._container_env():
+                 if env_var["name"] == "KT_SERVICE_NAME" and "value" in env_var:
+                     self._service_name = env_var["value"] if not env_var["value"] == "None" else None
+                     break
+         return self._service_name
+
+     @service_name.setter
+     def service_name(self, value: str):
+         """Set the service name."""
+         if self._service_name and not self._service_name == value:
+             raise ValueError("Service name cannot be changed after it has been set")
+         self._service_name = value
+
+     # ----------------- GPU Properties ----------------- #
+
+     @property
+     def tolerations(self):
+         return self.pod_template.get("tolerations", [])
+
+     @property
+     def gpu_annotations(self):
+         # GPU annotations for KAI scheduler
+         return self._gpu_annotations
+
+     # ----------------- Init Template Setup Helpers ----------------- #
+     def _get_server_image(self, image, otel_enabled, inactivity_ttl):
+         """Return the base server image."""
+         image = self.image.image_id if self.image and self.image.image_id else None
+
+         if not image or image == serving_constants.KUBETORCH_IMAGE_TRAPDOOR:
+             # No custom image or Trapdoor → pick OTEL or default
+             if self._server_should_enable_otel(otel_enabled, inactivity_ttl):
+                 return serving_constants.SERVER_IMAGE_WITH_OTEL
+             return serving_constants.SERVER_IMAGE_MINIMAL
+
+         return image
+
+     def _get_requested_resources(self, cpus, memory, disk_size, gpu_config):
+         """Return requested resources."""
+         requests = {}
+         limits = {}
+
+         # Add CPU if specified
+         if cpus:
+             requests["cpu"] = RequestedPodResources.cpu_for_resource_request(cpus)
+             limits["cpu"] = requests["cpu"]
+
+         # Add Memory if specified
+         if memory:
+             requests["memory"] = RequestedPodResources.memory_for_resource_request(memory)
+             limits["memory"] = requests["memory"]
+
+         # Add Storage if specified
+         if disk_size:
+             requests["ephemeral-storage"] = disk_size
+             limits["ephemeral-storage"] = disk_size
+
+         # Add GPU if specified
+         gpu_config: dict = gpu_config
+         gpu_count = gpu_config.get("count", 1)
+         if gpu_config:
+             if gpu_config.get("sharing_type") == "memory":
+                 # TODO: not currently supported
+                 # For memory-sharing GPUs, we don't need to request any additional resources - the KAI scheduler
+                 # will handle it through annotations
+                 return V1ResourceRequirements().to_dict()
+             elif gpu_config.get("sharing_type") == "fraction":
+                 # For fractional GPUs, we still need to request the base GPU resource
+                 requests["nvidia.com/gpu"] = "1"
+                 limits["nvidia.com/gpu"] = "1"
+             elif not gpu_config.get("sharing_type"):
+                 # Whole GPUs
+                 requests["nvidia.com/gpu"] = str(gpu_count)
+                 limits["nvidia.com/gpu"] = str(gpu_count)
+
+         # Only include non-empty dicts
+         resources = {}
+         if requests:
+             resources["requests"] = requests
+         if limits:
+             resources["limits"] = limits
+
+         return V1ResourceRequirements(**resources).to_dict()
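Reviewer note: since requests and limits are mirrored for every resource, pods that specify both CPU and memory land in Kubernetes' Guaranteed QoS class. As a sketch, `cpus="1", memory="2Gi"` with no GPU yields a resources dict shaped like:

    {
        "requests": {"cpu": "1", "memory": "2Gi"},
        "limits": {"cpu": "1", "memory": "2Gi"},
    }

(The exact strings depend on the RequestedPodResources normalization helpers.)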
+
+     def _get_launch_timeout(self, launch_timeout):
+         if launch_timeout:
+             return int(launch_timeout)
+         default_launch_timeout = (
+             self.default_config["launch_timeout"]
+             if "launch_timeout" in self.default_config
+             else serving_constants.KT_LAUNCH_TIMEOUT
+         )
+         return int(os.getenv("KT_LAUNCH_TIMEOUT", default_launch_timeout))
+
+     def _get_scheduler_name(self, queue_name):
+         return serving_constants.KAI_SCHEDULER_NAME if queue_name else None
+
+     def _get_config_env_vars(self, allowed_serialization):
+         config_env_vars = globals.config._get_config_env_vars()
+         if allowed_serialization:
+             config_env_vars["KT_ALLOWED_SERIALIZATION"] = ",".join(allowed_serialization)
+
+         return config_env_vars
+
+     def _server_should_enable_otel(self, otel_enabled, inactivity_ttl):
+         return otel_enabled and inactivity_ttl
+
+     def _should_install_otel_dependencies(self, server_image, otel_enabled, inactivity_ttl):
+         return (
+             self._server_should_enable_otel(otel_enabled, inactivity_ttl)
+             and server_image != serving_constants.SERVER_IMAGE_WITH_OTEL
+         )
+
+     @property
+     def image_install_cmd(self):
+         return self.image.install_cmd if self.image and self.image.install_cmd else None
+
+     def client_port(self) -> int:
+         base_url = globals.service_url()
+         _, port = extract_host_port(base_url)
+         return port
+
+     # ----------------- GPU Init Template Setup Helpers ----------------- #
+
+     def _get_tolerations(self, gpus, tolerations):
+         user_tolerations = tolerations if tolerations else []
+
+         # add required GPU tolerations for GPU workloads
+         if gpus:
+             required_gpu_tolerations = [
+                 {
+                     "key": "nvidia.com/gpu",
+                     "operator": "Exists",
+                     "effect": "NoSchedule",
+                 },
+                 {
+                     "key": "dedicated",
+                     "operator": "Equal",
+                     "value": "gpu",
+                     "effect": "NoSchedule",
+                 },
+             ]
+
+             all_tolerations = user_tolerations.copy()
+             for req_tol in required_gpu_tolerations:
+                 if not any(
+                     t["key"] == req_tol["key"]
+                     and t.get("operator") == req_tol.get("operator")
+                     and t.get("effect") == req_tol["effect"]
+                     and (req_tol.get("value") is None or t.get("value") == req_tol.get("value"))
+                     for t in all_tolerations
+                 ):
+                     all_tolerations.append(req_tol)
+             return all_tolerations
+
+         return user_tolerations if user_tolerations else None
+
+     def _get_gpu_annotations(self, gpu_config):
+         # https://blog.devops.dev/struggling-with-gpu-waste-on-kubernetes-how-kai-schedulers-sharing-unlocks-efficiency-1029e9bd334b
+         if gpu_config is None:
+             return {}
+
+         if gpu_config.get("sharing_type") == "memory":
+             return {
+                 "gpu-memory": str(gpu_config["gpu_memory"]),
+             }
+         elif gpu_config.get("sharing_type") == "fraction":
+             return {
+                 "gpu-fraction": str(gpu_config["gpu_fraction"]),
+             }
+         else:
+             return {}
+
+     def _load_gpu_config(self, gpus, gpu_memory, gpu_type) -> dict:
+         if all(x is None for x in [gpus, gpu_memory, gpu_type]):
+             return {}
+
+         if gpus is not None:
+             if isinstance(gpus, (int, float)):
+                 if gpus <= 0:
+                     raise ValueError("GPU count must be greater than 0")
+                 if gpus < 1:
+                     raise ValueError("Fractional GPUs are not currently supported. Please use whole GPUs.")
+             if not str(gpus).isdigit():
+                 raise ValueError("Unexpected format for GPUs, expecting a numeric count")
+
+         gpu_config = {
+             "count": int(gpus) if gpus else 1,
+             "sharing_type": None,
+             "gpu_memory": None,
+             "gpu_type": None,
+         }
+
+         # Handle memory specification
+         if gpu_memory is not None:
+             if not isinstance(gpu_memory, str):
+                 raise ValueError("GPU memory must be a string with suffix Mi, Gi, or Ti")
+
+             units = {"mi": 1, "gi": 1024, "ti": 1024 * 1024}
+             val = gpu_memory.lower()
+
+             for suffix, factor in units.items():
+                 if val.endswith(suffix):
+                     try:
+                         num = float(val[: -len(suffix)])
+                         mi_value = int(num * factor)
+                         gpu_config["sharing_type"] = "memory"
+                         gpu_config["gpu_memory"] = str(mi_value)
+                         break  # Successfully parsed, exit the loop
+                     except ValueError:
+                         raise ValueError("Invalid numeric value in GPU memory spec")
+             else:
+                 # Only raise error if no suffix matched
+                 raise ValueError("GPU memory must end with Mi, Gi, or Ti")
+
+         if gpu_type is not None:
+             gpu_config["gpu_type"] = gpu_type
+
+         return GPUConfig(**gpu_config).to_dict()
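Reviewer note: the memory parser normalizes everything to MiB, so `gpu_memory="4Gi"` becomes the string "4096" with `sharing_type="memory"`. The conversion in isolation:

    units = {"mi": 1, "gi": 1024, "ti": 1024 * 1024}
    val = "4gi"                 # "4Gi" lowercased
    num = float(val[:-2])       # 4.0
    mi_value = int(num * 1024)  # 4096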
+
+     # ----------------- Generic Helpers ----------------- #
+     def _load_kube_config(self):
+         try:
+             config.load_incluster_config()
+         except config.config_exception.ConfigException:
+             # Fall back to a local kubeconfig file
+             if not Path(self.kubeconfig_path).exists():
+                 raise FileNotFoundError(f"Kubeconfig file not found: {self.kubeconfig_path}")
+             config.load_kube_config(config_file=self.kubeconfig_path)
+
+         # Reset the cached API clients so they'll be reinitialized with the loaded config
+         self._objects_api = None
+         self._core_api = None
+         self._apps_v1_api = None
+
+     def _load_kubetorch_global_config(self):
+         global_config = {}
+         kubetorch_config = self.service_manager.fetch_kubetorch_config()
+         if kubetorch_config:
+             defaults_yaml = kubetorch_config.get("COMPUTE_DEFAULTS", "")
+             if defaults_yaml:
+                 try:
+                     validated_config = {}
+                     config_dict = yaml.safe_load(defaults_yaml)
+                     for key, value in config_dict.items():
+                         # Check for values as dictionaries with keys 'key' and 'value'
+                         if (
+                             isinstance(value, list)
+                             and len(value) > 0
+                             and isinstance(value[0], dict)
+                             and "key" in value[0]
+                             and "value" in value[0]
+                         ):
+                             validated_config[key] = {item["key"]: item["value"] for item in value}
+                         elif value is not None:
+                             validated_config[key] = value
+                     global_config = validated_config
+                 except yaml.YAMLError as e:
+                     logger.error(f"Failed to load kubetorch global config: {str(e)}")
+
+         if global_config:
+             for key in ["inactivity_ttl"]:
+                 # Set values from global config where the value is not already set
+                 if key in global_config and self.__getattribute__(key) is None:
+                     self.__setattr__(key, global_config[key])
+             for key in ["labels", "annotations", "env_vars"]:
+                 # Merge global config with existing config for dictionary values
+                 if key in global_config and isinstance(global_config[key], dict):
+                     self.__setattr__(key, {**global_config[key], **self.__getattribute__(key)})
+             if "image_id" in global_config:
+                 if self.image is None:
+                     self.image = Image(image_id=global_config["image_id"])
+                 elif self.image.image_id is None:
+                     self.image.image_id = global_config["image_id"]
+
+         return global_config
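Reviewer note: COMPUTE_DEFAULTS arrives as YAML, and list-of-{key, value} entries are flattened into plain mappings before merging. A hypothetical value that would pass the validation above (the keys and image are made up for illustration):

    import yaml

    defaults_yaml = "inactivity_ttl: 4h\nlabels:\n  - {key: team, value: ml-platform}\n"
    config_dict = yaml.safe_load(defaults_yaml)
    # config_dict["labels"] == [{"key": "team", "value": "ml-platform"}],
    # which the loop above flattens to {"team": "ml-platform"}.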
+
+     # ----------------- Launching a new service (Knative or StatefulSet) ----------------- #
+     def _launch(
+         self,
+         service_name: str,
+         install_url: str,
+         pointer_env_vars: Dict,
+         metadata_env_vars: Dict,
+         startup_rsync_command: Optional[str],
+         launch_id: Optional[str],
+         dryrun: bool = False,
+     ):
+         """Creates a new service on the compute for the provided service. If the service already exists,
+         it will update the service with the latest copy of the code."""
+         # Finalize pod template with launch time env vars
+         self._update_launch_env_vars(service_name, pointer_env_vars, metadata_env_vars, launch_id)
+         self._upload_secrets_list()
+
+         setup_script = self._get_setup_script(install_url, startup_rsync_command)
+         self._container()["args"][0] = setup_script
+
+         # Handle service template creation
+         # Use the replicas property for deployment scaling
+         replicas = self.replicas
+
+         # Prepare annotations for service creation, including kubeconfig path if provided
+         if self._kubeconfig_path is not None:
+             self.annotations[serving_constants.KUBECONFIG_PATH_ANNOTATION] = self._kubeconfig_path
+
+         # Create service using the appropriate service manager
+         # KnativeServiceManager will handle autoscaling config, inactivity_ttl, etc.
+         # ServiceManager will handle replicas for deployments and rayclusters
+         created_service = self.service_manager.create_or_update_service(
+             service_name=service_name,
+             module_name=pointer_env_vars["KT_MODULE_NAME"],
+             pod_template=self.pod_template,
+             replicas=replicas,
+             autoscaling_config=self.autoscaling_config,
+             gpu_annotations=self.gpu_annotations,
+             inactivity_ttl=self.inactivity_ttl,
+             custom_labels=self.labels,
+             custom_annotations=self.annotations,
+             custom_template=self.service_template,
+             deployment_mode=self.deployment_mode,
+             dryrun=dryrun,
+             scheduler_name=self.scheduler_name,
+             queue_name=self.queue_name(),
+         )
+
+         # Handle service creation result based on resource type
+         if isinstance(created_service, dict):
+             # For custom resources (RayCluster, Knative), created_service is a dictionary
+             service_name = created_service.get("metadata", {}).get("name")
+             kind = created_service.get("kind", "")
+
+             if kind == "RayCluster":
+                 # RayCluster has headGroupSpec instead of template
+                 service_template = {
+                     "metadata": {
+                         "name": service_name,
+                         "namespace": created_service.get("metadata", {}).get("namespace"),
+                     },
+                     "spec": {"template": created_service["spec"]["headGroupSpec"]["template"]},
+                 }
+             else:
+                 # For Knative services and other dict-based resources
+                 service_template = created_service["spec"]["template"]
+         else:
+             # For Deployments, created_service is a V1Deployment object
+             service_name = created_service.metadata.name
+             # Return dict format for compatibility with tests and reload logic
+             service_template = {
+                 "metadata": {
+                     "name": created_service.metadata.name,
+                     "namespace": created_service.metadata.namespace,
+                 },
+                 "spec": {"template": created_service.spec.template},
+             }
+
+         logger.debug(f"Successfully deployed {self.deployment_mode} service {service_name}")
+
+         return service_template
1385
+
1386
+ async def _launch_async(
1387
+ self,
1388
+ service_name: str,
1389
+ install_url: str,
1390
+ pointer_env_vars: Dict,
1391
+ metadata_env_vars: Dict,
1392
+ startup_rsync_command: Optional[str],
1393
+ launch_id: Optional[str],
1394
+ dryrun: bool = False,
1395
+ ):
1396
+ """Async version of _launch. Creates a new service on the compute for the provided service.
1397
+ If the service already exists, it will update the service with the latest copy of the code."""
1398
+
1399
+ import asyncio
1400
+
1401
+ loop = asyncio.get_event_loop()
1402
+
1403
+ service_template = await loop.run_in_executor(
1404
+ None,
1405
+ self._launch,
1406
+ service_name,
1407
+ install_url,
1408
+ pointer_env_vars,
1409
+ metadata_env_vars,
1410
+ startup_rsync_command,
1411
+ launch_id,
1412
+ dryrun,
1413
+ )
1414
+
1415
+ return service_template
1416
+
1417
+ def _update_launch_env_vars(self, service_name, pointer_env_vars, metadata_env_vars, launch_id):
1418
+ kt_env_vars = {
1419
+ **pointer_env_vars,
1420
+ **metadata_env_vars,
1421
+ "KT_LAUNCH_ID": launch_id,
1422
+ "KT_SERVICE_NAME": service_name,
1423
+ "KT_SERVICE_DNS": (
1424
+ f"{service_name}-headless.{self.namespace}.svc.cluster.local"
1425
+ if self.distributed_config
1426
+ else f"{service_name}.{self.namespace}.svc.cluster.local"
1427
+ ),
1428
+ "KT_DEPLOYMENT_MODE": self.deployment_mode,
1429
+ }
1430
+ if "OTEL_SERVICE_NAME" not in self.config_env_vars.keys():
1431
+ kt_env_vars["OTEL_SERVICE_NAME"] = service_name
1432
+
1433
+ # Ensure cluster config env vars are set
1434
+ if globals.config.cluster_config:
1435
+ if globals.config.cluster_config.get("otel_enabled"):
1436
+ kt_env_vars["KT_OTEL_ENABLED"] = True
1437
+
1438
+ # Ensure all environment variable values are strings for Kubernetes compatibility
1439
+ kt_env_vars = self._serialize_env_vars(kt_env_vars)
1440
+
1441
+ updated_env_vars = set()
1442
+ for env_var in self._container_env():
1443
+ if env_var["name"] in kt_env_vars:
1444
+ env_var["value"] = kt_env_vars[env_var["name"]]
1445
+ updated_env_vars.add(env_var["name"])
1446
+ for key, val in kt_env_vars.items():
1447
+ if key not in updated_env_vars:
1448
+ self._container_env().append({"name": key, "value": val})
1449
+
1450
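+    # A minimal sketch of the merge above (hypothetical values): existing container
+    # env vars are updated in place, and any remaining KT_* vars are appended.
+    #
+    #   container_env = [{"name": "KT_SERVICE_NAME", "value": "old"}]
+    #   kt_env_vars = {"KT_SERVICE_NAME": "my-svc", "KT_LAUNCH_ID": "abc123"}
+    #   # after the merge:
+    #   #   [{"name": "KT_SERVICE_NAME", "value": "my-svc"},
+    #   #    {"name": "KT_LAUNCH_ID", "value": "abc123"}]
+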
+    def _serialize_env_vars(self, env_vars: Dict) -> Dict:
+        """Coerce env var values to strings, since Kubernetes requires string values."""
+        import json
+
+        serialized_vars = {}
+        for key, value in env_vars.items():
+            if value is None:
+                serialized_vars[key] = "null"
+            elif isinstance(value, (dict, list)):
+                try:
+                    serialized_vars[key] = json.dumps(value)
+                except (TypeError, ValueError):
+                    serialized_vars[key] = str(value)
+            elif isinstance(value, (bool, int, float)):
+                serialized_vars[key] = str(value)
+            else:
+                serialized_vars[key] = value
+        return serialized_vars
+
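+    # For illustration, the coercion above maps (sketch, hypothetical inputs):
+    #
+    #   {"A": None, "B": {"x": 1}, "C": True, "D": 8080, "E": "text"}
+    #   -> {"A": "null", "B": '{"x": 1}', "C": "True", "D": "8080", "E": "text"}
+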
+    def _extract_secrets(self, secrets):
+        if is_running_in_kubernetes():
+            return [], []
+
+        secret_env_vars = []
+        secret_volumes = []
+        if secrets:
+            secrets_client = KubernetesSecretsClient(namespace=self.namespace, kubeconfig_path=self.kubeconfig_path)
+            secret_objects = secrets_client.convert_to_secret_objects(secrets=secrets)
+            (
+                secret_env_vars,
+                secret_volumes,
+            ) = secrets_client.extract_envs_and_volumes_from_secrets(secret_objects)
+
+        return secret_env_vars, secret_volumes
+
+    def _get_setup_script(self, install_url, startup_rsync_command):
+        # Load the setup script template
+        from kubetorch.servers.http.utils import _get_rendered_template
+
+        setup_script = _get_rendered_template(
+            serving_constants.KT_SETUP_TEMPLATE_FILE,
+            template_dir=os.path.join(
+                os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
+                "serving",
+                "templates",
+            ),
+            python_path=self.python_path,
+            freeze=self.freeze,
+            install_url=install_url or globals.config.install_url,
+            install_cmd=self.image_install_cmd,
+            install_otel=self._should_install_otel_dependencies(
+                self.server_image, self.otel_enabled, self.inactivity_ttl
+            ),
+            server_image=self.server_image,
+            rsync_kt_local_cmd=startup_rsync_command,
+            server_port=self.server_port,
+        )
+        return setup_script
+
+    def _upload_secrets_list(self):
+        """Upload secrets to Kubernetes. Called during launch time, not during init."""
+        if is_running_in_kubernetes():
+            return
+
+        if self.secrets:
+            logger.debug("Uploading secrets to Kubernetes")
+            self.secrets_client.upload_secrets_list(secrets=self.secrets)
+
+    def _get_node_selector(self, node_selector, gpu_type):
+        if gpu_type:
+            if ":" in gpu_type:
+                # Parse "key: value" format
+                key, value = gpu_type.split(":", 1)
+                node_selector[key.strip()] = value.strip()
+            else:
+                # Default to nvidia.com/gpu.product
+                node_selector["nvidia.com/gpu.product"] = gpu_type
+        return node_selector
+
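+    # Sketch of the gpu_type parsing above (hypothetical values):
+    #
+    #   _get_node_selector({}, "NVIDIA-A100-SXM4-80GB")
+    #   -> {"nvidia.com/gpu.product": "NVIDIA-A100-SXM4-80GB"}
+    #
+    #   _get_node_selector({}, "cloud.google.com/gke-accelerator: nvidia-tesla-t4")
+    #   -> {"cloud.google.com/gke-accelerator": "nvidia-tesla-t4"}
+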
+    def pod_names(self):
+        """Returns a list of running pod names."""
+        pods = self.pods()
+        return [pod.metadata.name for pod in pods if pod_is_running(pod)]
+
+    def pods(self):
+        return self.service_manager.get_pods_for_service(self.service_name)
+
+    # ------------------------------- Volumes ------------------------------ #
+    def _process_volumes(self, volumes) -> Optional[List[Volume]]:
+        """Process the volumes input into a standardized format."""
+        if volumes is None:
+            volumes = globals.config.volumes
+
+        if volumes is None:
+            return None
+
+        if isinstance(volumes, list):
+            processed_volumes = []
+            for vol in volumes:
+                if isinstance(vol, str):
+                    # A volume name (assume the volume exists)
+                    volume = Volume.from_name(vol, create_if_missing=False, core_v1=self.core_api)
+                    processed_volumes.append(volume)
+
+                elif isinstance(vol, Volume):
+                    # A Volume object (create it if it doesn't already exist).
+                    # Default the volume namespace to the compute namespace if not provided.
+                    if vol.namespace is None:
+                        vol.namespace = self.namespace
+                    vol.create()
+                    processed_volumes.append(vol)
+
+                else:
+                    raise ValueError(f"Volume list items must be strings or Volume objects, got {type(vol)}")
+
+            return processed_volumes
+
+        else:
+            raise ValueError(f"Volumes must be a list, got {type(volumes)}")
+
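+    # Sketch of the accepted volume inputs, per the branches above (hypothetical
+    # names; the Volume constructor arguments are assumed):
+    #
+    #   compute._process_volumes(["shared-data"])           # existing volume, by name
+    #   compute._process_volumes([Volume(name="scratch")])  # Volume object, created if missing
+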
+    def _volumes_for_pod_template(self, volumes):
+        """Convert processed volumes to the pod template format."""
+        volume_mounts = []
+        volume_specs = []
+
+        if volumes:
+            for volume in volumes:
+                # Add the volume mount
+                volume_mounts.append({"name": volume.name, "mountPath": volume.mount_path})
+
+                # Add the volume spec
+                volume_specs.append(volume.pod_template_spec())
+
+        return volume_mounts, volume_specs
+
+    # ----------------- Functions using K8s implementation ----------------- #
+    def _wait_for_endpoint(self):
+        retries = 20
+        for i in range(retries):
+            endpoint = self.endpoint
+            if endpoint:
+                return endpoint
+            else:
+                logger.info(f"Endpoint not available (attempt {i + 1}/{retries})")
+                time.sleep(2)
+
+        logger.error(f"Endpoint not available for {self.service_name}")
+        return None
+
+    def _status_condition_ready(self, status) -> bool:
+        """
+        Checks if the Knative Service status conditions include a 'Ready' condition with status 'True'.
+        This indicates that the service is ready to receive traffic.
+
+        Notes:
+            - This does not check pod status or readiness, only the Knative Service's own readiness condition.
+            - A service can be 'Ready' even if no pods are currently running (e.g., after scaling to zero).
+        """
+        for condition in status.get("conditions", []):
+            if condition.get("type") == "Ready" and condition.get("status") == "True":
+                logger.debug(f"Service {self.service_name} is ready")
+                return True
+        return False
+
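+    # For reference, a Knative Service status typically carries conditions like
+    # the following (abridged sketch):
+    #
+    #   status = {"conditions": [{"type": "Ready", "status": "True"},
+    #                            {"type": "RoutesReady", "status": "True"}]}
+    #   self._status_condition_ready(status)  # -> True
+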
+    def _check_service_ready(self):
+        """Checks if the service is ready to start serving requests.
+
+        Delegates to the appropriate service manager's check_service_ready method.
+        """
+        return self.service_manager.check_service_ready(
+            service_name=self.service_name,
+            launch_timeout=self.launch_timeout,
+            objects_api=self.objects_api,
+            core_api=self.core_api,
+            queue_name=self.queue_name(),
+            scheduler_name=self.scheduler_name,
+        )
+
+    async def _check_service_ready_async(self):
+        """Async version of _check_service_ready. Checks if the service is ready to start serving requests.
+
+        Delegates to the appropriate service manager's check_service_ready method.
+        """
+        import asyncio
+
+        loop = asyncio.get_event_loop()
+
+        return await loop.run_in_executor(
+            None,
+            self._check_service_ready,
+        )
+
+    def is_up(self):
+        """Whether all of the service's pods are running."""
+        try:
+            pods = self.pods()
+            if not pods:
+                return False
+            for pod in pods:
+                if pod.status.phase != "Running":
+                    logger.info(f"Pod {pod.metadata.name} is not running. Status: {pod.status.phase}")
+                    return False
+        except client.exceptions.ApiException:
+            return False
+        return True
+
+    def _base_rsync_url(self, local_port: int):
+        return f"rsync://localhost:{local_port}/data/{self.namespace}/{self.service_name}"
+
+    def _rsync_svc_url(self):
+        return f"rsync://kubetorch-rsync.{self.namespace}.svc.cluster.local:{serving_constants.REMOTE_RSYNC_PORT}/data/{self.namespace}/{self.service_name}/"
+
+    def ssh(self, pod_name: str = None):
+        """Open an interactive shell in a pod via kubectl exec."""
+        pod_name = pod_name or self.pod_names()[0]
+        ssh_cmd = f"kubectl exec -it {pod_name} -n {self.namespace} -- /bin/bash"
+        subprocess.run(shlex.split(ssh_cmd), check=True)
+
+    def get_env_vars(self, keys: Union[List[str], str] = None):
+        """Return the container's env vars as a dict, optionally filtered by key."""
+        keys = [keys] if isinstance(keys, str) else keys
+        env_vars = {}
+        for env_var in self._container_env():
+            # Only include vars with a literal value (skip valueFrom-style entries),
+            # whether or not a key filter was provided.
+            if (not keys or env_var["name"] in keys) and "value" in env_var:
+                env_vars[env_var["name"]] = env_var["value"]
+        return env_vars
+
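+    # Usage sketch (hypothetical service): fetch one or more env vars by key.
+    #
+    #   compute.get_env_vars("KT_SERVICE_NAME")   # {"KT_SERVICE_NAME": "my-svc"}
+    #   compute.get_env_vars(["KT_LAUNCH_ID", "KT_SERVICE_DNS"])
+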
+    # ----------------- Image Related Functionality ----------------- #
+
+    def pip_install(
+        self,
+        reqs: Union[List[str], str],
+        node: Optional[str] = None,
+        override_remote_version: bool = False,
+    ):
+        """Pip install reqs onto compute pod(s)."""
+        reqs = [reqs] if isinstance(reqs, str) else reqs
+        python_path = self.image.python_path if self.image else "python3"
+        pip_install_cmd = f"{python_path} -m pip install"
+        try:
+            result = self.run_bash("cat .kt/kt_pip_install_cmd 2>/dev/null || echo ''", node=node)
+            if result and result[0][0] == 0 and result[0][1].strip():
+                pip_install_cmd = result[0][1].strip()
+        except Exception:
+            pass
+
+        for req in reqs:
+            base = self.working_dir or "."
+            remote_editable = self.run_bash(f"[ -d {base}/{req} ]", node=node)[0][0] == 0
+            if remote_editable:
+                req = f"{base}/{req}"
+            else:
+                local_version = find_locally_installed_version(req)
+                if local_version is not None:
+                    if not override_remote_version:
+                        installed_remotely = (
+                            self.run_bash(
+                                f"{python_path} -c \"import importlib.util; exit(0) if importlib.util.find_spec('{req}') else exit(1)\"",
+                                node=node,
+                            )[0][0]
+                            == 0
+                        )
+                        if installed_remotely:
+                            logger.info(f"{req} already installed. Skipping.")
+                            continue
+                    else:
+                        req = f"{req}=={local_version}"
+
+            logger.info(f"Pip installing {req} with: {pip_install_cmd} {req}")
+            self.run_bash(f"{pip_install_cmd} {req}", node=node)
+
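+    # Usage sketch (hypothetical requirements): without override, packages already
+    # present remotely are skipped; with override, the locally installed version is pinned.
+    #
+    #   compute.pip_install("numpy")
+    #   compute.pip_install(["pandas", "requests"], override_remote_version=True)
+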
+    def sync_package(
+        self,
+        package: str,
+        node: Optional[str] = None,
+    ):
+        """Sync a package (locally installed, or a path to a package) to the compute pod(s)."""
+        full_path, dest_dir = _get_sync_package_paths(package)
+        logger.info(f"Syncing over package at {full_path} to {dest_dir}")
+        self.rsync(source=full_path, dest=dest_dir)
+
+    def run_bash(
+        self,
+        commands,
+        node: Union[str, List[str]] = None,
+        container: Optional[str] = None,
+    ):
+        """Run bash commands on the pod(s)."""
+        self._load_kube_config()
+
+        pod_names = self.pod_names() if node in ["all", None] else [node] if isinstance(node, str) else node
+
+        return _run_bash(
+            commands=commands,
+            core_api=self.core_api,
+            pod_names=pod_names,
+            namespace=self.namespace,
+            container=container,
+        )
+
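+    # Usage sketch: run_bash returns per-pod result tuples whose first element is
+    # the exit code and second the output (shape assumed from the call sites above,
+    # e.g. result[0][0] and result[0][1] in pip_install).
+    #
+    #   result = compute.run_bash("nvidia-smi", node="all")
+    #   succeeded = result[0][0] == 0
+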
+    def _create_rsync_target_dir(self):
+        """Create the subdirectory for this particular service in the rsync pod."""
+        subdir = f"/data/{self.namespace}/{self.service_name}"
+
+        label_selector = f"app={serving_constants.RSYNC_SERVICE_NAME}"
+        pod_name = (
+            self.core_api.list_namespaced_pod(namespace=self.namespace, label_selector=label_selector)
+            .items[0]
+            .metadata.name
+        )
+        subdir_cmd = f"kubectl exec {pod_name} -n {self.namespace} -- mkdir -p {subdir}"
+        logger.info(f"Creating directory on rsync pod with cmd: {subdir_cmd}")
+        subprocess.run(subdir_cmd, shell=True, check=True)
+
+    def _run_rsync_command(self, rsync_cmd, create_target_dir: bool = True):
+        backup_rsync_cmd = rsync_cmd
+        if "--mkpath" not in rsync_cmd and create_target_dir:
+            # Warning: --mkpath requires rsync 3.2.0+
+            # Note: --mkpath allows the rsync daemon to create all intermediate directories that may not exist
+            # https://download.samba.org/pub/rsync/rsync.1#opt--mkpath
+            rsync_cmd = rsync_cmd.replace("rsync ", "rsync --mkpath ", 1)
+            logger.debug(f"Rsync command: {rsync_cmd}")
+
+            resp = subprocess.run(
+                rsync_cmd,
+                shell=True,
+                capture_output=True,
+                text=True,
+            )
+            if resp.returncode != 0:
+                if (
+                    create_target_dir
+                    and ("rsync: --mkpath" in resp.stderr or "rsync: unrecognized option" in resp.stderr)
+                    and not is_running_in_kubernetes()
+                ):
+                    logger.warning(
+                        "Rsync failed: --mkpath is not supported, falling back to creating target dir. "
+                        "Please upgrade rsync to 3.2.0+ to improve performance."
+                    )
+                    self._create_rsync_target_dir()
+                    return self._run_rsync_command(backup_rsync_cmd, create_target_dir=False)
+
+                raise RsyncError(rsync_cmd, resp.returncode, resp.stdout, resp.stderr)
+        else:
+            import fcntl
+            import pty
+            import select
+
+            logger.debug(f"Rsync command: {rsync_cmd}")
+
+            leader, follower = pty.openpty()
+            proc = subprocess.Popen(
+                shlex.split(rsync_cmd),
+                stdout=follower,
+                stderr=follower,
+                text=True,
+                close_fds=True,
+            )
+            os.close(follower)
+
+            # Set the PTY leader to non-blocking mode
+            flags = fcntl.fcntl(leader, fcntl.F_GETFL)
+            fcntl.fcntl(leader, fcntl.F_SETFL, flags | os.O_NONBLOCK)
+            buffer = b""
+            decoded_line = ""  # guard against referencing it before any output arrives
+            transfer_completed = False
+            error_patterns = [
+                r"rsync\(\d+\): error:",
+                r"rsync error:",
+                r"@ERROR:",
+            ]
+            error_regexes = [re.compile(pattern, re.IGNORECASE) for pattern in error_patterns]
+
+            try:
+                with os.fdopen(leader, "rb", buffering=0) as stdout:
+                    while True:
+                        rlist, _, _ = select.select([stdout], [], [], 0.1)  # 0.1 sec timeout for responsiveness
+                        if stdout in rlist:
+                            try:
+                                chunk = os.read(stdout.fileno(), 1024)
+                            except BlockingIOError:
+                                continue  # no data available, try again
+
+                            if not chunk:  # EOF
+                                break
+
+                            buffer += chunk
+                            while b"\n" in buffer:
+                                line, buffer = buffer.split(b"\n", 1)
+                                decoded_line = line.decode(errors="replace").strip()
+                                logger.debug(decoded_line)
+
+                                for error_regex in error_regexes:
+                                    if error_regex.search(decoded_line):
+                                        raise RsyncError(rsync_cmd, 1, decoded_line, decoded_line)
+
+                                if "total size is" in decoded_line and "speedup is" in decoded_line:
+                                    transfer_completed = True
+
+                            if transfer_completed:
+                                break
+
+                        exit_code = proc.poll()
+                        if exit_code is not None:
+                            if exit_code != 0:
+                                raise RsyncError(
+                                    rsync_cmd,
+                                    exit_code,
+                                    output=decoded_line,
+                                    stderr=decoded_line,
+                                )
+                            break
+
+                proc.terminate()
+            except Exception:
+                proc.terminate()
+                raise
+
+            if not transfer_completed:
+                logger.error("Rsync process ended without completion message")
+                proc.terminate()
+                raise subprocess.CalledProcessError(
+                    1,
+                    rsync_cmd,
+                    output="",
+                    stderr="Rsync completed without success indication",
+                )
+
+            logger.info("Rsync operation completed successfully")
+
+    async def _run_rsync_command_async(self, rsync_cmd: str, create_target_dir: bool = True):
+        """Async version of _run_rsync_command using asyncio.subprocess."""
+        import asyncio
+
+        if "--mkpath" not in rsync_cmd and create_target_dir:
+            # Warning: --mkpath requires rsync 3.2.0+
+            # Note: --mkpath allows the rsync daemon to create all intermediate directories that may not exist
+            # https://download.samba.org/pub/rsync/rsync.1#opt--mkpath
+            rsync_cmd = rsync_cmd.replace("rsync ", "rsync --mkpath ", 1)
+        logger.debug(f"Rsync command: {rsync_cmd}")
+
+        # Use asyncio.create_subprocess_shell for shell commands
+        proc = await asyncio.create_subprocess_shell(
+            rsync_cmd,
+            stdout=asyncio.subprocess.PIPE,
+            stderr=asyncio.subprocess.PIPE,
+        )
+
+        stdout_bytes, stderr_bytes = await proc.communicate()
+        stdout = stdout_bytes.decode("utf-8", errors="replace") if stdout_bytes else ""
+        stderr = stderr_bytes.decode("utf-8", errors="replace") if stderr_bytes else ""
+
+        if proc.returncode != 0:
+            if proc.returncode is None:
+                proc.terminate()
+            if "rsync: --mkpath" in stderr or "rsync: unrecognized option" in stderr:
+                error_msg = (
+                    "Rsync failed: --mkpath is not supported, please upgrade your rsync version to 3.2.0+ to "
+                    "improve performance (e.g. `brew install rsync`)"
+                )
+                raise RsyncError(rsync_cmd, proc.returncode, stdout, error_msg)
+
+            raise RsyncError(rsync_cmd, proc.returncode, stdout, stderr)
+
+    def _get_rsync_cmd(
+        self,
+        source: Union[str, List[str]],
+        dest: str,
+        rsync_local_port: int,
+        contents: bool = False,
+        filter_options: str = None,
+        force: bool = False,
+    ):
+        if dest:
+            # Handle the tilde prefix - treat it as relative to the home/working directory
+            if dest.startswith("~/"):
+                # Strip the ~/ prefix to make it relative
+                dest = dest[2:]
+
+            # Handle absolute vs relative paths
+            if dest.startswith("/"):
+                # For absolute paths, store under a special __absolute__ subdirectory in the rsync pod
+                # to preserve the path structure
+                dest_for_rsync = f"__absolute__{dest}"
+            else:
+                # Relative paths work as before
+                dest_for_rsync = dest
+            remote_dest = f"{self._base_rsync_url(rsync_local_port)}/{dest_for_rsync}"
+        else:
+            remote_dest = self._base_rsync_url(rsync_local_port)
+
+        source = [source] if isinstance(source, str) else source
+
+        exclude_options = _get_rsync_exclude_options()
+
+        expanded_sources = []
+        for s in source:
+            path = Path(s).expanduser().absolute()
+            if not path.exists():
+                raise ValueError(f"Could not locate path to sync up: {s}")
+
+            path_str = str(path)
+            if contents and path.is_dir() and not str(s).endswith("/"):
+                path_str += "/"
+            expanded_sources.append(path_str)
+
+        source_str = " ".join(expanded_sources)
+
+        rsync_cmd = f"rsync -avL {exclude_options}"
+
+        if filter_options:
+            rsync_cmd += f" {filter_options}"
+
+        if force:
+            rsync_cmd += " --ignore-times"
+
+        rsync_cmd += f" {source_str} {remote_dest}"
+        return rsync_cmd
+
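+    # Sketch of a generated command (hypothetical inputs; exclude options elided):
+    #
+    #   _get_rsync_cmd("./src", dest="app", rsync_local_port=8873)
+    #   -> "rsync -avL ... /abs/path/to/src rsync://localhost:8873/data/<ns>/<svc>/app"
+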
+    def _get_rsync_in_cluster_cmd(
+        self,
+        source: Union[str, List[str]],
+        dest: str,
+        contents: bool = False,
+        filter_options: str = None,
+        force: bool = False,
+    ):
+        """Generate the rsync command for in-cluster execution."""
+        # Handle the tilde prefix in dest - treat it as relative to the home/working directory
+        if dest and dest.startswith("~/"):
+            dest = dest[2:]  # Strip the ~/ prefix to make it relative
+
+        source = [source] if isinstance(source, str) else source
+        if self.working_dir:
+            source = [src.replace(self.working_dir, "") for src in source]
+
+        if contents:
+            if self.working_dir:
+                source = [s if s.endswith("/") or not Path(self.working_dir, s).is_dir() else s + "/" for s in source]
+            else:
+                source = [s if s.endswith("/") or not Path(s).is_dir() else s + "/" for s in source]
+
+        source_str = " ".join(source)
+
+        exclude_options = _get_rsync_exclude_options()
+
+        base_remote = self._rsync_svc_url()
+
+        if dest is None:
+            # No dest specified -> use the base URL
+            remote = base_remote
+        elif dest.startswith("rsync://"):
+            # A full rsync:// URL -> use as-is
+            remote = dest
+        else:
+            # A relative subdir -> append it to the base URL
+            remote = base_remote + dest.lstrip("/")
+
+        # rsync wants the remote last; ensure it ends with '/' so we copy *into* the dir
+        if not remote.endswith("/"):
+            remote += "/"
+
+        rsync_command = f"rsync -av {exclude_options}"
+        if filter_options:
+            rsync_command += f" {filter_options}"
+        if force:
+            rsync_command += " --ignore-times"
+
+        rsync_command += f" {source_str} {remote}"
+        return rsync_command
+
+    def _rsync(
+        self,
+        source: Union[str, List[str]],
+        dest: str,
+        rsync_local_port: int,
+        contents: bool = False,
+        filter_options: str = None,
+        force: bool = False,
+    ):
+        rsync_cmd = self._get_rsync_cmd(source, dest, rsync_local_port, contents, filter_options, force)
+        self._run_rsync_command(rsync_cmd)
+
+    async def _rsync_async(
+        self,
+        source: Union[str, List[str]],
+        dest: str,
+        rsync_local_port: int,
+        contents: bool = False,
+        filter_options: str = None,
+        force: bool = False,
+    ):
+        """Async version of _rsync."""
+        rsync_cmd = self._get_rsync_cmd(source, dest, rsync_local_port, contents, filter_options, force)
+        await self._run_rsync_command_async(rsync_cmd)
+
+    def _get_websocket_info(self, local_port: int):
+        rsync_local_port = local_port or serving_constants.LOCAL_NGINX_PORT
+        base_url = globals.service_url()
+
+        ws_url = f"{http_to_ws(base_url)}/rsync/{self.namespace}/"
+        parsed_url = urlparse(base_url)
+
+        # Choose a local ephemeral port for the tunnel
+        start_from = (parsed_url.port or rsync_local_port) + 1
+        websocket_port = find_available_port(start_from, max_tries=10)
+        return websocket_port, ws_url
+
+    def rsync(
+        self,
+        source: Union[str, List[str]],
+        dest: str = None,
+        local_port: int = None,
+        contents: bool = False,
+        filter_options: str = None,
+        force: bool = False,
+    ):
+        """Rsync from local to the rsync pod."""
+        # Note: use the nginx port by default since the rsync pod sits behind the nginx proxy
+        websocket_port, ws_url = self._get_websocket_info(local_port)
+
+        logger.debug(f"Opening WebSocket tunnel on port {websocket_port} to {ws_url}")
+        with WebSocketRsyncTunnel(websocket_port, ws_url) as tunnel:
+            self._rsync(source, dest, tunnel.local_port, contents, filter_options, force)
+
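+    # Usage sketch (hypothetical paths): sync a local directory into the service's
+    # rsync area through the WebSocket tunnel opened above.
+    #
+    #   compute.rsync("./configs", dest="configs")
+    #   compute.rsync("./data", dest="datasets", contents=True, force=True)
+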
+    async def rsync_async(
+        self,
+        source: Union[str, List[str]],
+        dest: str = None,
+        local_port: int = None,
+        contents: bool = False,
+        filter_options: str = None,
+        force: bool = False,
+    ):
+        """Async version of rsync. Rsync from local to the rsync pod."""
+        websocket_port, ws_url = self._get_websocket_info(local_port)
+
+        logger.debug(f"Opening WebSocket tunnel on port {websocket_port} to {ws_url}")
+        with WebSocketRsyncTunnel(websocket_port, ws_url) as tunnel:
+            await self._rsync_async(source, dest, tunnel.local_port, contents, filter_options, force)
+
+    def rsync_in_cluster(
+        self,
+        source: Union[str, List[str]],
+        dest: str = None,
+        contents: bool = False,
+        filter_options: str = None,
+        force: bool = False,
+    ):
+        """Rsync from inside the cluster to the rsync pod."""
+        rsync_command = self._get_rsync_in_cluster_cmd(source, dest, contents, filter_options, force)
+        self._run_rsync_command(rsync_command)
+
+    async def rsync_in_cluster_async(
+        self,
+        source: Union[str, List[str]],
+        dest: str = None,
+        contents: bool = False,
+        filter_options: str = None,
+        force: bool = False,
+    ):
+        """Async version of rsync_in_cluster. Rsync from inside the cluster to the rsync pod."""
+        rsync_command = self._get_rsync_in_cluster_cmd(source, dest, contents, filter_options, force)
+        await self._run_rsync_command_async(rsync_command)
+
+    def _image_setup_and_instructions(self, rsync: bool = True):
+        """
+        Return image instructions in Dockerfile format, and optionally rsync over content to the rsync pod.
+        """
+        instructions = ""
+
+        if not self.image:
+            return instructions
+
+        logger.debug("Writing out image instructions.")
+
+        if self.image.image_id:
+            instructions += f"FROM {self.server_image}\n"
+        if self.image.python_path:
+            instructions += f"ENV KT_PYTHON_PATH {self.image.python_path}\n"
+
+        # image_id itself is not written as a step; it is used directly via server_image above
+        for step in self.image.setup_steps:
+            if step.step_type == ImageSetupStepType.CMD_RUN:
+                commands = step.kwargs.get("command")
+                commands = [commands] if isinstance(commands, str) else commands
+                for i in range(len(commands)):
+                    if i != 0:
+                        instructions += "\n"
+                    instructions += f"RUN {commands[i]}"
+            elif step.step_type == ImageSetupStepType.PIP_INSTALL:
+                reqs = step.kwargs.get("reqs")
+                reqs = [reqs] if isinstance(reqs, str) else reqs
+                for i in range(len(reqs)):
+                    if i != 0:
+                        instructions += "\n"
+                    if self.image_install_cmd:
+                        install_cmd = self.image_install_cmd
+                    else:
+                        install_cmd = "$KT_PIP_INSTALL_CMD"
+
+                    # Pass through the requirement string directly without quoting.
+                    # This allows users to pass any pip arguments they want,
+                    # e.g., "--pre torchmonarch==0.1.0rc7" or "numpy>=1.20".
+                    instructions += f"RUN {install_cmd} {reqs[i]}"
+
+                if step.kwargs.get("force"):
+                    instructions += " # force"
+            elif step.step_type == ImageSetupStepType.SYNC_PACKAGE:
+                # Use the package name instead of paths, since the folder path in the rsync pod will just be the package name
+                full_path, dest_dir = _get_sync_package_paths(step.kwargs.get("package"))
+                if rsync:
+                    self.rsync(full_path, dest=dest_dir)
+                instructions += f"COPY {full_path} {dest_dir}"
+            elif step.step_type == ImageSetupStepType.RSYNC:
+                source_path = step.kwargs.get("source")
+                dest_dir = step.kwargs.get("dest")
+                contents = step.kwargs.get("contents")
+                filter_options = step.kwargs.get("filter_options")
+                force = step.kwargs.get("force")
+
+                if rsync:
+                    if is_running_in_kubernetes():
+                        self.rsync_in_cluster(
+                            source_path,
+                            dest=dest_dir,
+                            contents=contents,
+                            filter_options=filter_options,
+                            force=force,
+                        )
+                    else:
+                        self.rsync(
+                            source_path,
+                            dest=dest_dir,
+                            contents=contents,
+                            filter_options=filter_options,
+                            force=force,
+                        )
+                # Generate a COPY instruction with an explicit destination
+                if dest_dir:
+                    instructions += f"COPY {source_path} {dest_dir}"
+                else:
+                    # No dest specified - use the basename of the source as the destination
+                    dest_name = Path(source_path).name
+                    instructions += f"COPY {source_path} {dest_name}"
+            elif step.step_type == ImageSetupStepType.SET_ENV_VARS:
+                for key, val in step.kwargs.get("env_vars").items():
+                    # Single env var per line in the dockerfile
+                    instructions += f"ENV {key} {val}\n"
+            if step.kwargs.get("force") and step.step_type != ImageSetupStepType.PIP_INSTALL:
+                instructions += " # force"
+            instructions += "\n"
+
+        return instructions
+
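+    # Sketch of the Dockerfile-format output (hypothetical image with one pip step):
+    #
+    #   FROM nvcr.io/nvidia/pytorch:23.10-py3
+    #   RUN $KT_PIP_INSTALL_CMD numpy>=1.20
+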
+    # ----------------- Copying over for now... TBD ----------------- #
+    def __getstate__(self):
+        """Remove local stateful values before pickle serialization."""
+        state = self.__dict__.copy()
+        # Remove local stateful values that shouldn't be serialized
+        state["_endpoint"] = None
+        state["_service_manager"] = None
+        state["_objects_api"] = None
+        state["_core_api"] = None
+        state["_apps_v1_api"] = None
+        state["_node_v1_api"] = None
+        state["_secrets_client"] = None
+        return state
+
+    def __setstate__(self, state):
+        """Restore state after pickle deserialization."""
+        self.__dict__.update(state)
+        # Reset local stateful values to None to ensure clean initialization
+        self._endpoint = None
+        self._service_manager = None
+        self._objects_api = None
+        self._core_api = None
+        self._apps_v1_api = None
+        self._node_v1_api = None
+        self._secrets_client = None
+
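+    # Pickling sketch: the API clients and service manager are dropped on serialize
+    # and re-created lazily after load (hypothetical round trip):
+    #
+    #   import pickle
+    #   restored = pickle.loads(pickle.dumps(compute))
+    #   restored._core_api is None  # -> True, rebuilt on next use
+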
+    # ------------ Distributed / Autoscaling Helpers -------- #
+    def distribute(
+        self,
+        distribution_type: str = None,
+        workers: int = None,
+        quorum_timeout: int = None,
+        quorum_workers: int = None,
+        monitor_members: bool = None,
+        **kwargs,
+    ):
+        """Configure the distributed worker compute needed by each service replica.
+
+        Args:
+            distribution_type (str): The type of distributed supervisor to create.
+                Options: ``"spmd"`` (the default, if unset), ``"pytorch"``, ``"ray"``, ``"monarch"``, ``"jax"``,
+                or ``"tensorflow"``.
+            workers (int): The number of workers to create, each with compute resources identical to the
+                service compute.
+            quorum_timeout (int, optional): Timeout in seconds for workers to become ready and join the cluster.
+                Defaults to ``launch_timeout`` if not provided, for both SPMD frameworks and for Ray.
+                Increase this if workers need more time to start (e.g., during node autoscaling or when
+                downloading data during initialization).
+            quorum_workers (int, optional): The number of ready workers required to form a quorum.
+                Defaults to ``workers``; set it lower to proceed before all workers are up.
+            monitor_members (bool, optional): Whether the supervisor monitors cluster membership.
+                Enabled by default for SPMD; disabled by default for Ray, which manages its own membership.
+            **kwargs: Additional framework-specific parameters (e.g., num_proc, port).
+
+        Note:
+            A list of ``(int, Compute)`` pairs for ``workers``, specifying the number of workers and the
+            compute resources for each worker StatefulSet, is not yet supported.
+
+        Examples:
+
+        .. code-block:: python
+
+            import kubetorch as kt
+
+            remote_fn = kt.fn(simple_summer, service_name).to(
+                kt.Compute(
+                    cpus="2",
+                    memory="4Gi",
+                    image=kt.Image(image_id="rayproject/ray"),
+                    launch_timeout=300,
+                ).distribute("ray", workers=2)
+            )
+
+            gpus = kt.Compute(
+                gpus=1,
+                image=kt.Image(image_id="nvcr.io/nvidia/pytorch:23.10-py3"),
+                launch_timeout=600,
+                inactivity_ttl="4h",
+            ).distribute("pytorch", workers=4)
+        """
+        # Check for conflicting configuration
+        if self.autoscaling_config:
+            raise ValueError(
+                "Cannot use both .distribute() and .autoscale() on the same compute instance. "
+                "Use .distribute() for fixed replicas with distributed training, or .autoscale() for auto-scaling services."
+            )
+
+        # Configure distributed settings.
+        # Note: We default to simple SPMD distribution ("spmd") if nothing is specified and compute.workers > 1.
+
+        # Users can override the quorum if they want to set a lower threshold
+        quorum_workers = quorum_workers or workers
+        distributed_config = {
+            "distribution_type": distribution_type or "spmd",
+            "quorum_timeout": quorum_timeout or self.launch_timeout,
+            "quorum_workers": quorum_workers,
+        }
+        if monitor_members is not None:
+            # Note: Ray manages its own membership, so monitoring is disabled by default in the supervisor.
+            # It's enabled by default for SPMD.
+            distributed_config["monitor_members"] = monitor_members
+        distributed_config.update(kwargs)
+
+        if workers:
+            if not isinstance(workers, int):
+                raise ValueError("Workers must be an integer. A list of (int, Compute) pairs is not yet supported.")
+            # Set the replicas property instead of storing it in distributed_config
+            self.replicas = workers
+
+        if distributed_config:
+            self.distributed_config = distributed_config
+            # Invalidate the cached service manager so it gets recreated with the right type
+            self._service_manager = None
+
+        return self
+
+    def autoscale(self, **kwargs):
+        """Configure the service with the provided autoscaling parameters.
+
+        You can pass any of the following keyword arguments:
+
+        Args:
+            target (int): The concurrency/RPS/CPU/memory target per pod.
+            window (str): Time window for scaling decisions, e.g. "60s".
+            metric (str): Metric to scale on: "concurrency", "rps", "cpu", "memory" or custom.
+                Note: "cpu" and "memory" require autoscaler_class="hpa.autoscaling.knative.dev".
+            target_utilization (int): Utilization % to trigger scaling (1-100).
+            min_scale (int): Minimum number of replicas. 0 allows scale to zero.
+            max_scale (int): Maximum number of replicas.
+            initial_scale (int): Initial number of pods.
+            concurrency (int): Maximum concurrent requests per pod (containerConcurrency).
+                If not set, pods accept unlimited concurrent requests.
+            scale_to_zero_pod_retention_period (str): Time to keep the last pod before scaling
+                to zero, e.g. "30s", "1m5s".
+            scale_down_delay (str): Delay before scaling down, e.g. "15m". Only for KPA.
+            autoscaler_class (str): Autoscaler implementation:
+                - "kpa.autoscaling.knative.dev" (default, supports concurrency/rps)
+                - "hpa.autoscaling.knative.dev" (supports cpu/memory/custom metrics)
+            progress_deadline (str): Time to wait for the deployment to become ready, e.g. "10m".
+                Must be longer than the startup probe timeout.
+            **extra_annotations: Additional Knative autoscaling annotations.
+
+        Note:
+            The service will be deployed as a Knative service.
+
+            Timing-related defaults are applied if not explicitly set (for ML workloads):
+
+            - scale_down_delay="1m" (avoid rapid scaling cycles)
+            - scale_to_zero_pod_retention_period="10m" (keep the last pod longer before scaling to zero)
+            - progress_deadline="10m" or greater (ensures enough time for initialization, automatically adjusted based on launch_timeout)
+
+        Examples:
+
+        .. code-block:: python
+
+            import kubetorch as kt
+
+            remote_fn = kt.fn(my_fn_obj).to(
+                kt.Compute(
+                    cpus=".1",
+                ).autoscale(min_scale=1)
+            )
+
+            remote_fn = kt.fn(summer).to(
+                compute=kt.Compute(
+                    cpus=".01",
+                ).autoscale(min_scale=3, scale_to_zero_pod_retention_period="50s"),
+            )
+        """
+        # Check for conflicting configuration
+        if self.distributed_config:
+            raise ValueError(
+                "Cannot use both .distribute() and .autoscale() on the same compute instance. "
+                "Use .distribute() for fixed replicas with distributed training, or .autoscale() for auto-scaling services."
+            )
+
+        # Apply timing-related defaults for ML workloads to account for initialization overhead
+        # (heavy dependencies, model loading, etc. affect both CPU and GPU workloads)
+        if "scale_down_delay" not in kwargs:
+            kwargs["scale_down_delay"] = "1m"
+            logger.debug("Setting scale_down_delay=1m to avoid thrashing")
+
+        if "scale_to_zero_pod_retention_period" not in kwargs:
+            kwargs["scale_to_zero_pod_retention_period"] = "10m"
+            logger.debug("Setting scale_to_zero_pod_retention_period=10m to avoid thrashing")
+
+        if "progress_deadline" not in kwargs:
+            # Ensure progress_deadline is at least as long as launch_timeout
+            default_deadline = "10m"  # 600 seconds
+            if self.launch_timeout:
+                # Convert launch_timeout (seconds) to a duration string,
+                # adding some buffer (20% or at least 60 seconds)
+                timeout_with_buffer = max(self.launch_timeout + 60, int(self.launch_timeout * 1.2))
+                if timeout_with_buffer > 600:  # If larger than the default
+                    default_deadline = f"{timeout_with_buffer}s"
+            kwargs["progress_deadline"] = default_deadline
+            logger.debug(f"Setting progress_deadline={default_deadline} to allow time for initialization")
+
+        autoscaling_config = AutoscalingConfig(**kwargs)
+        if autoscaling_config:
+            self._autoscaling_config = autoscaling_config
+            # Invalidate the cached service manager so it gets recreated with KnativeServiceManager
+            self._service_manager = None
+
+        return self