kubetorch 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92) hide show
  1. kubetorch/__init__.py +59 -0
  2. kubetorch/cli.py +1939 -0
  3. kubetorch/cli_utils.py +967 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +269 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +159 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +140 -0
  30. kubetorch/resources/callables/module.py +1315 -0
  31. kubetorch/resources/callables/utils.py +203 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +253 -0
  34. kubetorch/resources/compute/compute.py +2414 -0
  35. kubetorch/resources/compute/decorators.py +137 -0
  36. kubetorch/resources/compute/utils.py +1026 -0
  37. kubetorch/resources/compute/websocket.py +135 -0
  38. kubetorch/resources/images/__init__.py +1 -0
  39. kubetorch/resources/images/image.py +412 -0
  40. kubetorch/resources/images/images.py +64 -0
  41. kubetorch/resources/secrets/__init__.py +2 -0
  42. kubetorch/resources/secrets/kubernetes_secrets_client.py +377 -0
  43. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  44. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  45. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  46. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  47. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  48. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  49. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  50. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  51. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  52. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  53. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  54. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  55. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  56. kubetorch/resources/secrets/provider_secrets/providers.py +92 -0
  57. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  58. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  59. kubetorch/resources/secrets/secret.py +224 -0
  60. kubetorch/resources/secrets/secret_factory.py +64 -0
  61. kubetorch/resources/secrets/utils.py +222 -0
  62. kubetorch/resources/volumes/__init__.py +0 -0
  63. kubetorch/resources/volumes/volume.py +340 -0
  64. kubetorch/servers/__init__.py +0 -0
  65. kubetorch/servers/http/__init__.py +0 -0
  66. kubetorch/servers/http/distributed_utils.py +2968 -0
  67. kubetorch/servers/http/http_client.py +802 -0
  68. kubetorch/servers/http/http_server.py +1622 -0
  69. kubetorch/servers/http/server_metrics.py +255 -0
  70. kubetorch/servers/http/utils.py +722 -0
  71. kubetorch/serving/__init__.py +0 -0
  72. kubetorch/serving/autoscaling.py +153 -0
  73. kubetorch/serving/base_service_manager.py +344 -0
  74. kubetorch/serving/constants.py +77 -0
  75. kubetorch/serving/deployment_service_manager.py +431 -0
  76. kubetorch/serving/knative_service_manager.py +487 -0
  77. kubetorch/serving/raycluster_service_manager.py +526 -0
  78. kubetorch/serving/service_manager.py +18 -0
  79. kubetorch/serving/templates/deployment_template.yaml +17 -0
  80. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  81. kubetorch/serving/templates/kt_setup_template.sh.j2 +91 -0
  82. kubetorch/serving/templates/pod_template.yaml +198 -0
  83. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  84. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  85. kubetorch/serving/templates/service_template.yaml +21 -0
  86. kubetorch/serving/templates/workerset_template.yaml +36 -0
  87. kubetorch/serving/utils.py +344 -0
  88. kubetorch/utils.py +263 -0
  89. kubetorch-0.2.5.dist-info/METADATA +75 -0
  90. kubetorch-0.2.5.dist-info/RECORD +92 -0
  91. kubetorch-0.2.5.dist-info/WHEEL +4 -0
  92. kubetorch-0.2.5.dist-info/entry_points.txt +5 -0
@@ -0,0 +1,198 @@
1
+ {% if service_account_name is not none %}
2
+ serviceAccountName: {{ service_account_name }}
3
+ {% endif %}
4
+
5
+ {% if priority_class_name is not none %}
6
+ priorityClassName: {{ priority_class_name }}
7
+ {% endif %}
8
+
9
+ {% if queue_name is not none %}
+ {# NOTE(review): condition checks queue_name but renders scheduler_name — confirm scheduler_name is always set whenever queue_name is #}
+ schedulerName: {{ scheduler_name }}
11
+ {% endif %}
12
+
13
+ {% if gpu_anti_affinity is sameas true %}
14
+ affinity:
15
+ nodeAffinity:
16
+ requiredDuringSchedulingIgnoredDuringExecution:
17
+ nodeSelectorTerms:
18
+ - matchExpressions:
19
+ - key: nvidia.com/gpu
20
+ operator: DoesNotExist
21
+ - key: eks.amazonaws.com/instance-gpu-count
22
+ operator: DoesNotExist
23
+ - key: cloud.google.com/gke-accelerator
24
+ operator: DoesNotExist
25
+ {% endif %}
26
+
27
+ {% if node_selector is not none %}
28
+ nodeSelector:
29
+ {% for key, value in node_selector.items() %}
30
+ {{ key }}: {{ value }}
31
+ {% endfor %}
32
+ {% endif %}
33
+
34
+ {% if tolerations is not none and tolerations|length > 0 %}
35
+ tolerations:
36
+ {% for tol in tolerations %}
37
+ - key: "{{ tol.key }}"
38
+ operator: "{{ tol.operator }}"
39
+ value: "{{ tol.value }}"
40
+ effect: "{{ tol.effect }}"
41
+ {% endfor %}
42
+ {% endif %}
43
+
44
+ timeoutSeconds: {{ launch_timeout }}
45
+ containers:
46
+ - name: kubetorch
47
+ image: {{ server_image }}
48
+ {% if image_pull_policy is not none %}
49
+ imagePullPolicy: {{ image_pull_policy }}
50
+ {% endif %}
51
+ {% if working_dir is not none %}
52
+ workingDir: {{ working_dir }}
53
+ {% endif %}
54
+ ports:
55
+ - name: http1
56
+ containerPort: {{ server_port }}
57
+ command: ["/bin/bash", "-c"]
58
+ {% if not freeze %}
59
+ securityContext:
60
+ capabilities:
61
+ add:
62
+ - "SYS_PTRACE"
63
+ {% endif %}
64
+ args:
65
+ - |
66
+ {{ setup_script | indent(8, true) }}
67
+
68
+ env:
69
+ # Pod metadata available via the Kubernetes Downward API
70
+ - name: POD_NAME
71
+ valueFrom:
72
+ fieldRef:
73
+ fieldPath: metadata.name
74
+ - name: POD_NAMESPACE
75
+ valueFrom:
76
+ fieldRef:
77
+ fieldPath: metadata.namespace
78
+ - name: POD_IP
79
+ valueFrom:
80
+ fieldRef:
81
+ fieldPath: status.podIP
82
+ - name: POD_UUID
83
+ valueFrom:
84
+ fieldRef:
85
+ fieldPath: metadata.uid
86
+ - name: MODULE_NAME
87
+ valueFrom:
88
+ fieldRef:
89
+ fieldPath: metadata.labels['kubetorch.com/module']
90
+ - name: KUBETORCH_VERSION
91
+ valueFrom:
92
+ fieldRef:
93
+ fieldPath: metadata.labels['kubetorch.com/version']
94
+ - name: SERVICE_ACCOUNT_NAME
95
+ valueFrom:
96
+ fieldRef:
97
+ fieldPath: spec.serviceAccountName
98
+ - name: UV_LINK_MODE
99
+ value: "copy" # Suppress the hardlink warning
100
+ - name: OTEL_EXPORTER_OTLP_ENDPOINT
101
+ value: "kubetorch-otel-opentelemetry-collector.kubetorch-monitoring.svc.cluster.local:4317"
102
+ - name: OTEL_EXPORTER_OTLP_PROTOCOL
103
+ value: "grpc"
104
+ - name: OTEL_TRACES_EXPORTER
105
+ value: "otlp"
106
+ - name: OTEL_PROPAGATORS
107
+ value: "tracecontext,baggage"
108
+ - name: KT_OTEL_ENABLED
109
+ value: "{{ otel_enabled }}"
110
+ - name: KT_SERVER_PORT
111
+ value: "{{ server_port }}"
112
+ - name: KT_FREEZE
113
+ value: "{{ freeze }}"
114
+ {% if inactivity_ttl is not none %}
115
+ - name: KT_INACTIVITY_TTL
116
+ value: "{{ inactivity_ttl }}"
117
+ {% endif %}
118
+ {% for key, value in config_env_vars.items() %}
119
+ - name: {{ key }}
120
+ value: "{{ value }}"
121
+ {% endfor %}
122
+ {% if env_vars is not none and env_vars|length > 0 %}
123
+ {% for key, value in env_vars.items() %}
124
+ - name: {{ key }}
125
+ value: "{{ value }}"
126
+ {% endfor %}
127
+ {% endif %}
128
+ {% for secret in secret_env_vars %}
129
+ {% for key in secret.env_vars %}
130
+ - name: {{ key }}
131
+ valueFrom:
132
+ secretKeyRef:
133
+ name: {{ secret.secret_name }}
134
+ key: {{ key }}
135
+ {% endfor %}
136
+ {% endfor %}
137
+ volumeMounts:
138
+ - mountPath: /dev/shm
139
+ name: dshm
140
+ {% for secret in secret_volumes %}
141
+ - name: {{ secret.name }}
142
+ mountPath: {{ secret.path }}
143
+ readOnly: true
144
+ {% endfor %}
145
+ {% if volume_mounts is not none and volume_mounts|length > 0 %}
146
+ {% for mount in volume_mounts %}
147
+ - name: {{ mount.name }}
148
+ mountPath: {{ mount.mountPath }}
149
+ {% endfor %}
150
+ {% endif %}
151
+ resources:
152
+ {{ resources | tojson }}
153
+ # TODO: do we want these health checks?
154
+ # Note: Knative won't consider the service ready to receive traffic until the probe succeeds at least once
155
+ # Initial readiness check
156
+ startupProbe:
157
+ httpGet:
158
+ path: /health
159
+ port: {{ server_port }}
160
+ initialDelaySeconds: 0
161
+ periodSeconds: 5
162
+ timeoutSeconds: 2
163
+ failureThreshold: {{ launch_timeout // 5 }}
164
+ readinessProbe:
165
+ httpGet:
166
+ path: /health
167
+ port: {{ server_port }}
168
+ periodSeconds: 3
169
+ successThreshold: 1
170
+ failureThreshold: 5
171
+ # Ongoing health monitoring with less frequent checks
172
+ livenessProbe:
173
+ httpGet:
174
+ path: /health
175
+ port: {{ server_port }}
176
+ periodSeconds: 30
177
+ timeoutSeconds: 1
178
+ failureThreshold: 3
179
+
180
+ volumes:
181
+ - name: dshm
182
+ emptyDir:
183
+ medium: Memory
184
+ {% if shm_size_limit is not none %}
185
+ sizeLimit: {{ shm_size_limit }}
186
+ {% endif %}
187
+ {% for secret in secret_volumes %}
188
+ - name: {{ secret.name }}
189
+ secret:
190
+ secretName: {{ secret.secret_name }}
191
+ {% endfor %}
192
+ {% if volume_specs is not none and volume_specs|length > 0 %}
193
+ {% for spec in volume_specs %}
194
+ - name: {{ spec.name }}
195
+ persistentVolumeClaim:
196
+ claimName: {{ spec.persistentVolumeClaim.claimName }}
197
+ {% endfor %}
198
+ {% endif %}
@@ -0,0 +1,42 @@
1
+ apiVersion: v1
2
+ kind: Service
3
+ metadata:
4
+ name: {{ name }}
5
+ namespace: {{ namespace }}
6
+ annotations: {{ annotations | tojson }}
7
+ labels: {{ labels | tojson }}
8
+ spec:
9
+ {% if distributed %}
10
+ clusterIP: None # Headless service for Ray pod discovery
11
+ {% else %}
12
+ sessionAffinity: ClientIP # Ensure requests from same client go to same pod
13
+ {% endif %}
14
+ selector:
15
+ kubetorch.com/service: {{ deployment_name }}
16
+ kubetorch.com/module: {{ module_name }}
17
+ ray.io/node-type: head # Only select head node pods
18
+ ports:
19
+ - name: http
20
+ port: 80
21
+ targetPort: {{ server_port }}
22
+ protocol: TCP
23
+ - name: ray-gcs
24
+ port: 6379
25
+ targetPort: 6379
26
+ protocol: TCP
27
+ - name: ray-object-mgr
28
+ port: 8076
29
+ targetPort: 8076
30
+ protocol: TCP
31
+ - name: ray-node-mgr
32
+ port: 8077
33
+ targetPort: 8077
34
+ protocol: TCP
35
+ - name: ray-dashboard
36
+ port: 8265
37
+ targetPort: 8265
38
+ protocol: TCP
39
+ - name: ray-metrics
40
+ port: 8080
41
+ targetPort: 8080
42
+ protocol: TCP
@@ -0,0 +1,35 @@
1
+ apiVersion: ray.io/v1
2
+ kind: RayCluster
3
+ metadata:
4
+ name: {{ name }}
5
+ namespace: {{ namespace }}
6
+ annotations: {{ annotations | tojson }}
7
+ labels: {{ labels | tojson }}
8
+ spec:
9
+ rayVersion: "2.8.0"
10
+ enableInTreeAutoscaling: false
11
+ headGroupSpec:
12
+ rayStartParams:
13
+ dashboard-host: "0.0.0.0"
14
+ port: "6379"
15
+ object-manager-port: "8076"
16
+ node-manager-port: "8077"
17
+ dashboard-port: "8265"
18
+ metrics-export-port: "8080"
19
+ replicas: 1
20
+ template:
21
+ metadata:
22
+ annotations: {{ template_annotations | tojson }}
23
+ labels: {{ head_template_labels | tojson }}
24
+ spec: {{ pod_template | tojson }}
25
+ workerGroupSpecs:
26
+ - groupName: worker-group
27
+ rayStartParams: {}
28
+ minReplicas: 0
29
+ maxReplicas: {{ worker_replicas }}
30
+ replicas: {{ worker_replicas }}
31
+ template:
32
+ metadata:
33
+ annotations: {{ template_annotations | tojson }}
34
+ labels: {{ worker_template_labels | tojson }}
35
+ spec: {{ pod_template | tojson }}
@@ -0,0 +1,21 @@
1
+ apiVersion: v1
2
+ kind: Service
3
+ metadata:
4
+ name: {{ name }}
5
+ namespace: {{ namespace }}
6
+ annotations: {{ annotations | tojson }}
7
+ labels: {{ labels | tojson }}
8
+ spec:
9
+ {% if distributed %}
10
+ clusterIP: None # Headless service for distributed pod discovery
11
+ {% else %}
12
+ sessionAffinity: ClientIP # Ensure requests from same client go to same pod
13
+ {% endif %}
14
+ selector:
15
+ kubetorch.com/service: {{ deployment_name }}
16
+ kubetorch.com/module: {{ module_name }} # Only deployment pods have this set, so allows us to exclude the jump pod
17
+ ports:
18
+ - name: http
19
+ port: 80
20
+ targetPort: {{ server_port }}
21
+ protocol: TCP
@@ -0,0 +1,36 @@
1
+ # --- Headless Service selecting all pods with 'app=my-app-name' ---
2
+ apiVersion: v1
3
+ kind: Service
4
+ metadata:
5
+ name: {{ workerset_name }}
6
+ namespace: {{ namespace }}
7
+ annotations: {{ annotations | tojson }}
8
+ labels: {{ labels | tojson }}
9
+ spec:
10
+ clusterIP: None # Make it headless
11
+ selector:
12
+ app: {{ workerset_name_app }} # Selects pods from any associated StatefulSets
13
+
14
+ ---
15
+ # --- StatefulSet ---
16
+ apiVersion: apps/v1
17
+ kind: StatefulSet
18
+ metadata:
19
+ name: {{ worker_group_name }}
20
+ namespace: {{ namespace }}
21
+ ownerReferences:
22
+ - apiVersion: v1
23
+ kind: Pod
24
+ name: {{ service_pod_name }}
25
+ uid: {{ service_pod_uid }}
26
+ spec:
27
+ serviceName: {{ workerset_name }} # Must match the headless service name
28
+ replicas: {{ replicas }}
29
+ selector:
30
+ matchLabels:
31
+ app: {{ workerset_name_app }} # Matches the service selector
32
+ template:
33
+ metadata:
34
+ labels:
35
+ app: {{ workerset_name_app }} # Label for the service
36
+ spec: {{ pod_template | tojson }}
@@ -0,0 +1,344 @@
1
+ import os
2
+ import socket
3
+ import time
4
+ import warnings
5
+
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Literal, Optional, Union
9
+
10
+ import httpx
11
+ from kubernetes.client import ApiException, CoreV1Api, V1Pod
12
+ from kubernetes.utils import parse_quantity
13
+
14
+ from kubetorch import globals
15
+ from kubetorch.logger import get_logger
16
+ from kubetorch.servers.http.utils import is_running_in_kubernetes
17
+ from kubetorch.serving.constants import LOKI_GATEWAY_SERVICE_NAME, PROMETHEUS_SERVICE_NAME
18
+ from kubetorch.utils import load_kubeconfig
19
+
20
+ logger = get_logger(__name__)
21
+
22
+
23
+ @dataclass
24
+ class GPUConfig:
25
+ count: Optional[int] = None
26
+ memory: Optional[str] = None
27
+ sharing_type: Optional[Literal["memory", "fraction"]] = None
28
+ gpu_memory: Optional[str] = None
29
+ gpu_fraction: Optional[str] = None
30
+ gpu_type: Optional[str] = None
31
+
32
+ def __post_init__(self):
33
+ self.validate()
34
+
35
+ def validate(self) -> bool:
36
+ if self.count and not isinstance(self.count, int):
37
+ raise ValueError("GPU count must an int")
38
+
39
+ if self.sharing_type == "memory":
40
+ if not self.gpu_memory:
41
+ raise ValueError("GPU memory must be specified when using memory sharing")
42
+ elif self.sharing_type == "fraction":
43
+ if not self.gpu_fraction:
44
+ raise ValueError("GPU fraction must be specified when using fraction sharing")
45
+ try:
46
+ fraction = float(self.gpu_fraction)
47
+ if not 0 < fraction <= 1:
48
+ raise ValueError("GPU fraction must be between 0 and 1")
49
+ except ValueError:
50
+ raise ValueError("GPU fraction must be a valid float between 0 and 1")
51
+
52
+ return True
53
+
54
+ def to_dict(self) -> dict:
55
+ base_dict = {
56
+ "sharing_type": self.sharing_type,
57
+ "count": self.count,
58
+ }
59
+
60
+ if self.memory is not None:
61
+ base_dict["memory"] = self.memory
62
+
63
+ if self.sharing_type == "memory" and self.gpu_memory:
64
+ base_dict["gpu_memory"] = self.gpu_memory
65
+ if self.sharing_type == "fraction" and self.gpu_fraction:
66
+ # Convert to millicores format
67
+ fraction = float(self.gpu_fraction)
68
+ base_dict["gpu_fraction"] = f"{int(fraction * 1000)}m"
69
+ if self.gpu_type is not None:
70
+ base_dict["gpu_type"] = self.gpu_type
71
+
72
+ return base_dict
73
+
74
+
75
+ class RequestedPodResources:
76
+ """Resources requested in a Kubetorch cluster/compute object. Note these are the values we receive
77
+ from launcher the cluster via a Sky dryrun."""
78
+
79
+ # Default overhead percentages to account for filesystem overhead, OS files, logs, container runtime, etc.
80
+ MEMORY_OVERHEAD = 0.20
81
+ CPU_OVERHEAD = 0.10
82
+ DISK_OVERHEAD = 0.15
83
+ GPU_OVERHEAD = 0.0
84
+
85
+ MIN_MEMORY_GB = 0.1 # 100Mi minimum
86
+ MIN_CPU_CORES = 0.1 # 100m minimum
87
+
88
+ CPU_STEPS = [1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 192]
89
+ MEMORY_STEPS = [0.5, 1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 192, 256, 384, 512, 768]
90
+
91
+ def __init__(
92
+ self,
93
+ memory: Optional[Union[str, float]] = None,
94
+ cpus: Optional[Union[int, float]] = None,
95
+ disk_size: Optional[int] = None,
96
+ num_gpus: Optional[Union[int, dict]] = None,
97
+ ):
98
+
99
+ self.memory = max(float(memory), self.MIN_MEMORY_GB) if memory is not None else None
100
+ self.cpus = max(self.normalize_cpu_value(cpus), self.MIN_CPU_CORES) if cpus is not None else None
101
+ self.disk_size = disk_size
102
+ self.num_gpus = num_gpus
103
+
104
+ def __str__(self):
105
+ # Example: RequestedPodResources(memory=16.0, cpus=4.0, disk=NoneGB, gpus={'A10G': 1})"
106
+ disk_str = f"{self.disk_size}GB" if self.disk_size is not None else "None"
107
+ memory = f"{self.memory}GB" if self.memory is not None else "None"
108
+
109
+ return (
110
+ f"RequestedPodResources(memory={memory}, cpus={self.cpus}, disk_size={disk_str}, "
111
+ f"num_gpus={self.num_gpus})"
112
+ )
113
+
114
+ def __repr__(self):
115
+ return (
116
+ f"RequestedPodResources(memory={self.memory}, cpus={self.cpus}, "
117
+ f"disk_size={self.disk_size}, num_gpus={self.num_gpus})"
118
+ )
119
+
120
+ @classmethod
121
+ def cpu_for_resource_request(cls, cpu_val: int = None):
122
+ if cpu_val is None:
123
+ return None
124
+
125
+ # Ensure minimum CPU value
126
+ cpu_val = max(float(cpu_val), cls.MIN_CPU_CORES)
127
+
128
+ # Convert to millicores (ex: '4.0' -> 4000m)
129
+ return f"{int(float(cpu_val) * 1000)}m"
130
+
131
+ @classmethod
132
+ def memory_for_resource_request(cls, memory_val: Union[str, float, int] = None):
133
+ if memory_val is None:
134
+ return None
135
+
136
+ # If it's a number, treat as GB
137
+ if isinstance(memory_val, (int, float)):
138
+ gb_val = max(float(memory_val), cls.MIN_MEMORY_GB)
139
+ memory_val = f"{gb_val}Gi"
140
+
141
+ # Validate the string - if invalid will throw a ValueError
142
+ parse_quantity(str(memory_val))
143
+
144
+ return str(memory_val)
145
+
146
+ @classmethod
147
+ def normalize_cpu_value(cls, cpu_value: Optional[Union[int, str, float]]) -> Optional[float]:
148
+ """Convert CPU value to float, handling string values with '+' allowed by Sky and Kubetorch."""
149
+ if cpu_value is None:
150
+ return None
151
+
152
+ if isinstance(cpu_value, str):
153
+ # Strip the '+' if present and convert to float
154
+ return float(cpu_value.rstrip("+"))
155
+
156
+ return float(cpu_value)
157
+
158
+
159
+ class KubernetesCredentialsError(Exception):
160
+ pass
161
+
162
+
163
+ def has_k8s_credentials():
164
+ """
165
+ Fast check for K8s credentials - works both in-cluster and external.
166
+ No network calls, no imports needed.
167
+ """
168
+ # Check 1: In-cluster service account
169
+ if (
170
+ Path("/var/run/secrets/kubernetes.io/serviceaccount/token").exists()
171
+ and Path("/var/run/secrets/kubernetes.io/serviceaccount/ca.crt").exists()
172
+ ):
173
+ return True
174
+
175
+ # Check 2: Kubeconfig file
176
+ kubeconfig_path = os.environ.get("KUBECONFIG", os.path.expanduser("~/.kube/config"))
177
+ return Path(kubeconfig_path).exists()
178
+
179
+
180
+ def check_kubetorch_versions(response):
181
+ from kubetorch import __version__ as python_client_version, VersionMismatchError
182
+
183
+ try:
184
+ data = response.json()
185
+ except ValueError:
186
+ # older nginx proxy versions won't return a JSON
187
+ return
188
+
189
+ helm_installed_version = data.get("version")
190
+ if not helm_installed_version:
191
+ logger.debug("No 'version' found in health check response")
192
+ return
193
+
194
+ if python_client_version != helm_installed_version:
195
+ msg = (
196
+ f"client={python_client_version}, cluster={helm_installed_version}. "
197
+ "To suppress this error, set the environment variable "
198
+ "`KUBETORCH_IGNORE_VERSION_MISMATCH=1`."
199
+ )
200
+ if not os.getenv("KUBETORCH_IGNORE_VERSION_MISMATCH"):
201
+ raise VersionMismatchError(msg)
202
+
203
+ warnings.warn(f"Kubetorch version mismatch: {msg}")
204
+
205
+
206
+ def extract_config_from_nginx_health_check(response):
207
+ """Extract the config from the nginx health check response."""
208
+ try:
209
+ data = response.json()
210
+ except ValueError:
211
+ return
212
+ config = data.get("config", {})
213
+ return config
214
+
215
+
216
+ def wait_for_port_forward(
217
+ process,
218
+ local_port,
219
+ timeout=30,
220
+ health_endpoint: str = None,
221
+ validate_kubetorch_versions: bool = True,
222
+ ):
223
+ from kubetorch import VersionMismatchError
224
+
225
+ start_time = time.time()
226
+ while time.time() - start_time < timeout:
227
+ if process.poll() is not None:
228
+ stderr = process.stderr.read().decode()
229
+ raise Exception(f"Port forward failed: {stderr}")
230
+
231
+ try:
232
+ # Check if socket is open
233
+ with socket.create_connection(("localhost", local_port), timeout=1):
234
+ if not health_endpoint:
235
+ # If we are not checking HTTP (ex: rsync)
236
+ return True
237
+ except OSError:
238
+ time.sleep(0.2)
239
+ continue
240
+
241
+ if health_endpoint:
242
+ url = f"http://localhost:{local_port}" + health_endpoint
243
+ try:
244
+ # Check if HTTP endpoint is ready
245
+ resp = httpx.get(url, timeout=2)
246
+ if resp.status_code == 200:
247
+ if validate_kubetorch_versions:
248
+ check_kubetorch_versions(resp)
249
+ # Extract config to set outside of function scope
250
+ config = extract_config_from_nginx_health_check(resp)
251
+ return config
252
+ except VersionMismatchError as e:
253
+ raise e
254
+ except Exception as e:
255
+ logger.debug(f"Waiting for HTTP endpoint to be ready: {e}")
256
+
257
+ time.sleep(0.2)
258
+
259
+ raise TimeoutError("Timeout waiting for port forward to be ready")
260
+
261
+
262
+ def pod_is_running(pod: V1Pod):
263
+ return pod.status.phase == "Running" and pod.metadata.deletion_timestamp is None
264
+
265
+
266
+ def check_loki_enabled(core_api: CoreV1Api = None) -> bool:
267
+ """Check if loki is enabled"""
268
+ if core_api is None:
269
+ load_kubeconfig()
270
+ core_api = CoreV1Api()
271
+
272
+ kt_namespace = globals.config.install_namespace
273
+
274
+ try:
275
+ # Check if loki-gateway service exists in the namespace
276
+ core_api.read_namespaced_service(name=LOKI_GATEWAY_SERVICE_NAME, namespace=kt_namespace)
277
+ logger.debug(f"Loki gateway service found in namespace {kt_namespace}")
278
+ except ApiException as e:
279
+ if e.status == 404:
280
+ logger.debug(f"Loki gateway service not found in namespace {kt_namespace}")
281
+ return False
282
+
283
+ # Additional permission-proof check: try to ping the internal Loki gateway URL
284
+ # Needed if running in kubernetes without full kubeconfig permissions
285
+ if is_running_in_kubernetes():
286
+ try:
287
+ loki_url = f"http://loki-gateway.{kt_namespace}.svc.cluster.local/loki/api/v1/labels"
288
+ response = httpx.get(loki_url, timeout=2)
289
+ if response.status_code == 200:
290
+ logger.debug("Loki gateway is reachable")
291
+ else:
292
+ logger.debug(f"Loki gateway returned status {response.status_code}")
293
+ return False
294
+ except Exception as e:
295
+ logger.debug(f"Loki gateway is not reachable: {e}")
296
+ return False
297
+
298
+ return True
299
+
300
+
301
+ def check_prometheus_enabled(core_api: CoreV1Api = None) -> bool:
302
+ """Check if prometheus is enabled"""
303
+ if core_api is None:
304
+ load_kubeconfig()
305
+ core_api = CoreV1Api()
306
+
307
+ kt_namespace = globals.config.install_namespace
308
+
309
+ try:
310
+ # Check if prometheus service exists in the namespace
311
+ core_api.read_namespaced_service(name=PROMETHEUS_SERVICE_NAME, namespace=kt_namespace)
312
+ logger.debug(f"Metrics service found in namespace {kt_namespace}")
313
+ except ApiException as e:
314
+ if e.status == 404:
315
+ logger.debug(f"Metrics service not found in namespace {kt_namespace}")
316
+ return False
317
+
318
+ # If running inside the cluster, try hitting the service directly
319
+ if is_running_in_kubernetes():
320
+ try:
321
+ prom_url = f"http://{PROMETHEUS_SERVICE_NAME}.{kt_namespace}.svc.cluster.local/api/v1/labels"
322
+ response = httpx.get(prom_url, timeout=2)
323
+ if response.status_code == 200:
324
+ logger.debug("Metrics service is reachable")
325
+ else:
326
+ logger.debug(f"Metrics service returned status {response.status_code}")
327
+ return False
328
+ except Exception as e:
329
+ logger.debug(f"Metrics service is not reachable: {e}")
330
+ return False
331
+
332
+ return True
333
+
334
+
335
+ def nested_override(original_dict, override_dict):
336
+ for key, value in override_dict.items():
337
+ if key in original_dict:
338
+ if isinstance(original_dict[key], dict) and isinstance(value, dict):
339
+ # Recursively merge nested dictionaries
340
+ nested_override(original_dict[key], value)
341
+ else:
342
+ original_dict[key] = value # Custom wins
343
+ else:
344
+ original_dict[key] = value