kubetorch 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kubetorch might be problematic.

Files changed (93)
  1. kubetorch/__init__.py +60 -0
  2. kubetorch/cli.py +1985 -0
  3. kubetorch/cli_utils.py +1025 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +285 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +157 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +133 -0
  30. kubetorch/resources/callables/module.py +1416 -0
  31. kubetorch/resources/callables/utils.py +174 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +261 -0
  34. kubetorch/resources/compute/compute.py +2596 -0
  35. kubetorch/resources/compute/decorators.py +139 -0
  36. kubetorch/resources/compute/rbac.py +74 -0
  37. kubetorch/resources/compute/utils.py +1114 -0
  38. kubetorch/resources/compute/websocket.py +137 -0
  39. kubetorch/resources/images/__init__.py +1 -0
  40. kubetorch/resources/images/image.py +414 -0
  41. kubetorch/resources/images/images.py +74 -0
  42. kubetorch/resources/secrets/__init__.py +2 -0
  43. kubetorch/resources/secrets/kubernetes_secrets_client.py +412 -0
  44. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  45. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  46. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  47. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  48. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  49. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  50. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  51. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  52. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  53. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  54. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  55. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  56. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  57. kubetorch/resources/secrets/provider_secrets/providers.py +93 -0
  58. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  59. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  60. kubetorch/resources/secrets/secret.py +238 -0
  61. kubetorch/resources/secrets/secret_factory.py +70 -0
  62. kubetorch/resources/secrets/utils.py +209 -0
  63. kubetorch/resources/volumes/__init__.py +0 -0
  64. kubetorch/resources/volumes/volume.py +365 -0
  65. kubetorch/servers/__init__.py +0 -0
  66. kubetorch/servers/http/__init__.py +0 -0
  67. kubetorch/servers/http/distributed_utils.py +3223 -0
  68. kubetorch/servers/http/http_client.py +730 -0
  69. kubetorch/servers/http/http_server.py +1788 -0
  70. kubetorch/servers/http/server_metrics.py +278 -0
  71. kubetorch/servers/http/utils.py +728 -0
  72. kubetorch/serving/__init__.py +0 -0
  73. kubetorch/serving/autoscaling.py +173 -0
  74. kubetorch/serving/base_service_manager.py +363 -0
  75. kubetorch/serving/constants.py +83 -0
  76. kubetorch/serving/deployment_service_manager.py +478 -0
  77. kubetorch/serving/knative_service_manager.py +519 -0
  78. kubetorch/serving/raycluster_service_manager.py +582 -0
  79. kubetorch/serving/service_manager.py +18 -0
  80. kubetorch/serving/templates/deployment_template.yaml +17 -0
  81. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  82. kubetorch/serving/templates/kt_setup_template.sh.j2 +81 -0
  83. kubetorch/serving/templates/pod_template.yaml +194 -0
  84. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  85. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  86. kubetorch/serving/templates/service_template.yaml +21 -0
  87. kubetorch/serving/templates/workerset_template.yaml +36 -0
  88. kubetorch/serving/utils.py +377 -0
  89. kubetorch/utils.py +284 -0
  90. kubetorch-0.2.0.dist-info/METADATA +121 -0
  91. kubetorch-0.2.0.dist-info/RECORD +93 -0
  92. kubetorch-0.2.0.dist-info/WHEEL +4 -0
  93. kubetorch-0.2.0.dist-info/entry_points.txt +5 -0
kubetorch/serving/utils.py ADDED
@@ -0,0 +1,377 @@
+ import os
+ import socket
+ import time
+ import warnings
+
+ from dataclasses import dataclass
+ from typing import Literal, Optional, Union
+
+ import httpx
+ from kubernetes.client import ApiException, CoreV1Api, V1Pod
+ from kubernetes.utils import parse_quantity
+
+ from kubetorch import globals
+ from kubetorch.logger import get_logger
+ from kubetorch.servers.http.utils import is_running_in_kubernetes
+ from kubetorch.serving.constants import (
+     KUBETORCH_MONITORING_NAMESPACE,
+     LOKI_GATEWAY_SERVICE_NAME,
+     PROMETHEUS_SERVICE_NAME,
+     PROMETHEUS_URL,
+ )
+ from kubetorch.utils import load_kubeconfig
+
+ logger = get_logger(__name__)
+
+
+ @dataclass
+ class GPUConfig:
+     count: Optional[int] = None
+     memory: Optional[str] = None
+     sharing_type: Optional[Literal["memory", "fraction"]] = None
+     gpu_memory: Optional[str] = None
+     gpu_fraction: Optional[str] = None
+     gpu_type: Optional[str] = None
+
+     def __post_init__(self):
+         self.validate()
+
+     def validate(self) -> bool:
+         if self.count and not isinstance(self.count, int):
+             raise ValueError("GPU count must be an int")
+
+         if self.sharing_type == "memory":
+             if not self.gpu_memory:
+                 raise ValueError(
+                     "GPU memory must be specified when using memory sharing"
+                 )
+         elif self.sharing_type == "fraction":
+             if not self.gpu_fraction:
+                 raise ValueError(
+                     "GPU fraction must be specified when using fraction sharing"
+                 )
+             try:
+                 fraction = float(self.gpu_fraction)
+                 if not 0 < fraction <= 1:
+                     raise ValueError("GPU fraction must be between 0 and 1")
+             except ValueError:
+                 raise ValueError("GPU fraction must be a valid float between 0 and 1")
+
+         return True
+
+     def to_dict(self) -> dict:
+         base_dict = {
+             "sharing_type": self.sharing_type,
+             "count": self.count,
+         }
+
+         if self.memory is not None:
+             base_dict["memory"] = self.memory
+
+         if self.sharing_type == "memory" and self.gpu_memory:
+             base_dict["gpu_memory"] = self.gpu_memory
+         if self.sharing_type == "fraction" and self.gpu_fraction:
+             # Convert to millicores format
+             fraction = float(self.gpu_fraction)
+             base_dict["gpu_fraction"] = f"{int(fraction * 1000)}m"
+         if self.gpu_type is not None:
+             base_dict["gpu_type"] = self.gpu_type
+
+         return base_dict
+
+
+ class RequestedPodResources:
+     """Resources requested in a Kubetorch cluster/compute object. Note these are the values we receive
+     from launching the cluster via a Sky dryrun."""
+
+     # Default overhead percentages to account for filesystem overhead, OS files, logs, container runtime, etc.
+     MEMORY_OVERHEAD = 0.20
+     CPU_OVERHEAD = 0.10
+     DISK_OVERHEAD = 0.15
+     GPU_OVERHEAD = 0.0
+
+     MIN_MEMORY_GB = 0.1  # 100Mi minimum
+     MIN_CPU_CORES = 0.1  # 100m minimum
+
+     CPU_STEPS = [1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 192]
+     MEMORY_STEPS = [0.5, 1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 192, 256, 384, 512, 768]
+
+     def __init__(
+         self,
+         memory: Optional[Union[str, float]] = None,
+         cpus: Optional[Union[int, float]] = None,
+         disk_size: Optional[int] = None,
+         num_gpus: Optional[Union[int, dict]] = None,
+     ):
+
+         self.memory = (
+             max(float(memory), self.MIN_MEMORY_GB) if memory is not None else None
+         )
+         self.cpus = (
+             max(self.normalize_cpu_value(cpus), self.MIN_CPU_CORES)
+             if cpus is not None
+             else None
+         )
+         self.disk_size = disk_size
+         self.num_gpus = num_gpus
+
+     def __str__(self):
+         # Example: RequestedPodResources(memory=16.0GB, cpus=4.0, disk_size=None, num_gpus={'A10G': 1})
+         disk_str = f"{self.disk_size}GB" if self.disk_size is not None else "None"
+         memory = f"{self.memory}GB" if self.memory is not None else "None"
+
+         return (
+             f"RequestedPodResources(memory={memory}, cpus={self.cpus}, disk_size={disk_str}, "
+             f"num_gpus={self.num_gpus})"
+         )
+
+     def __repr__(self):
+         return (
+             f"RequestedPodResources(memory={self.memory}, cpus={self.cpus}, "
+             f"disk_size={self.disk_size}, num_gpus={self.num_gpus})"
+         )
+
+     @classmethod
+     def cpu_for_resource_request(cls, cpu_val: int = None):
+         if cpu_val is None:
+             return None
+
+         # Ensure minimum CPU value
+         cpu_val = max(float(cpu_val), cls.MIN_CPU_CORES)
+
+         # Convert to millicores (ex: '4.0' -> 4000m)
+         return f"{int(float(cpu_val) * 1000)}m"
+
+     @classmethod
+     def memory_for_resource_request(cls, memory_val: Union[str, float, int] = None):
+         if memory_val is None:
+             return None
+
+         # If it's a number, treat as GB
+         if isinstance(memory_val, (int, float)):
+             gb_val = max(float(memory_val), cls.MIN_MEMORY_GB)
+             memory_val = f"{gb_val}Gi"
+
+         # Validate the string - if invalid will throw a ValueError
+         parse_quantity(str(memory_val))
+
+         return str(memory_val)
+
+     @classmethod
+     def normalize_cpu_value(
+         cls, cpu_value: Optional[Union[int, str, float]]
+     ) -> Optional[float]:
+         """Convert a CPU value to float, handling string values with a trailing '+' as allowed by Sky and Kubetorch."""
+         if cpu_value is None:
+             return None
+
+         if isinstance(cpu_value, str):
+             # Strip the '+' if present and convert to float
+             return float(cpu_value.rstrip("+"))
+
+         return float(cpu_value)
+
+
+ def check_kubetorch_versions(response):
+     from kubetorch import __version__ as python_client_version, VersionMismatchError
+
+     try:
+         data = response.json()
+     except ValueError:
+         # older nginx proxy versions won't return JSON
+         return
+
+     helm_installed_version = data.get("version")
+     if not helm_installed_version:
+         logger.debug("No 'version' found in health check response")
+         return
+
+     if python_client_version != helm_installed_version:
+         msg = (
+             f"client={python_client_version}, cluster={helm_installed_version}. "
+             "To suppress this error, set the environment variable "
+             "`KUBETORCH_IGNORE_VERSION_MISMATCH=1`."
+         )
+         if not os.getenv("KUBETORCH_IGNORE_VERSION_MISMATCH"):
+             raise VersionMismatchError(msg)
+
+         warnings.warn(f"Kubetorch version mismatch: {msg}")
+
+
+ def extract_config_from_nginx_health_check(response):
+     """Extract the config from the nginx health check response."""
+     try:
+         data = response.json()
+     except ValueError:
+         return
+     config = data.get("config", {})
+     return config
+
+
+ def wait_for_port_forward(
+     process,
+     local_port,
+     timeout=30,
+     health_endpoint: str = None,
+     validate_kubetorch_versions: bool = True,
+ ):
+     from kubetorch import VersionMismatchError
+
+     start_time = time.time()
+     while time.time() - start_time < timeout:
+         if process.poll() is not None:
+             stderr = process.stderr.read().decode()
+             raise Exception(f"Port forward failed: {stderr}")
+
+         try:
+             # Check if socket is open
+             with socket.create_connection(("localhost", local_port), timeout=1):
+                 if not health_endpoint:
+                     # If we are not checking HTTP (ex: rsync)
+                     return True
+         except OSError:
+             time.sleep(0.2)
+             continue
+
+         if health_endpoint:
+             url = f"http://localhost:{local_port}" + health_endpoint
+             try:
+                 # Check if HTTP endpoint is ready
+                 resp = httpx.get(url, timeout=2)
+                 if resp.status_code == 200:
+                     if validate_kubetorch_versions:
+                         check_kubetorch_versions(resp)
+                     # Extract config to set outside of function scope
+                     config = extract_config_from_nginx_health_check(resp)
+                     return config
+             except VersionMismatchError as e:
+                 raise e
+             except Exception as e:
+                 logger.debug(f"Waiting for HTTP endpoint to be ready: {e}")
+
+         time.sleep(0.2)
+
+     raise TimeoutError("Timeout waiting for port forward to be ready")
+
+
+ def pod_is_running(pod: V1Pod):
+     return pod.status.phase == "Running" and pod.metadata.deletion_timestamp is None
+
+
+ def check_loki_enabled(core_api: CoreV1Api = None) -> bool:
+     """Check if Loki is enabled."""
+     if core_api is None:
+         load_kubeconfig()
+         core_api = CoreV1Api()
+
+     kt_namespace = globals.config.install_namespace
+
+     try:
+         # Check if the loki-gateway service exists in the namespace
+         core_api.read_namespaced_service(
+             name=LOKI_GATEWAY_SERVICE_NAME, namespace=kt_namespace
+         )
+         logger.debug(f"Loki gateway service found in namespace {kt_namespace}")
+     except ApiException as e:
+         if e.status == 404:
+             logger.debug(f"Loki gateway service not found in namespace {kt_namespace}")
+             return False
+
+     # Additional permission-proof check: try to ping the internal Loki gateway URL.
+     # Needed if running in Kubernetes without full kubeconfig permissions.
+     if is_running_in_kubernetes():
+         try:
+             loki_url = f"http://loki-gateway.{kt_namespace}.svc.cluster.local/loki/api/v1/labels"
+             response = httpx.get(loki_url, timeout=2)
+             if response.status_code == 200:
+                 logger.debug("Loki gateway is reachable")
+             else:
+                 logger.debug(f"Loki gateway returned status {response.status_code}")
+                 return False
+         except Exception as e:
+             logger.debug(f"Loki gateway is not reachable: {e}")
+             return False
+
+     return True
+
+
+ def check_prometheus_enabled(
+     prometheus_url: str, namespace: str, core_api: CoreV1Api = None
+ ) -> bool:
+     """Check if Prometheus is enabled and reachable."""
+     if prometheus_url and prometheus_url != PROMETHEUS_URL:
+         return True
+
+     if core_api is None:
+         load_kubeconfig()
+         core_api = CoreV1Api()
+
+     is_in_kubernetes = is_running_in_kubernetes()
+
+     # Check namespace exists
+     try:
+         core_api.read_namespace(name=namespace)
+     except ApiException as e:
+         if e.status == 404:
+             logger.debug(f"Prometheus namespace not found: {namespace}")
+             return False
+
+     # Check Prometheus service exists
+     try:
+         core_api.read_namespaced_service(
+             name=PROMETHEUS_SERVICE_NAME, namespace=namespace
+         )
+         logger.debug(f"Prometheus service found in namespace {namespace}")
+     except ApiException as e:
+         if e.status == 404:
+             logger.debug(f"Prometheus service not found: {PROMETHEUS_SERVICE_NAME}")
+             return False
+
+     # If running inside the cluster, try hitting the service directly
+     if is_in_kubernetes:
+         try:
+             response = httpx.get(PROMETHEUS_URL, timeout=2)
+             if response.status_code == 200:
+                 logger.debug("Prometheus is reachable and healthy")
+             else:
+                 logger.debug(f"Prometheus returned status {response.status_code}")
+                 return False
+         except Exception as e:
+             logger.debug(f"Prometheus is not reachable: {e}")
+             return False
+
+     return True
+
+
+ def check_tempo_enabled(core_api: CoreV1Api = None) -> bool:
+     if core_api is None:
+         load_kubeconfig()
+         core_api = CoreV1Api()
+
+     try:
+         otel = core_api.read_namespaced_service(
+             name="kubetorch-otel-opentelemetry-collector",
+             namespace=KUBETORCH_MONITORING_NAMESPACE,
+         )
+         tempo = core_api.read_namespaced_service(
+             name="kubetorch-otel-tempo-distributor",
+             namespace=KUBETORCH_MONITORING_NAMESPACE,
+         )
+         return otel is not None and tempo is not None
+
+     except ApiException as e:
+         if e.status == 404:
+             return False
+         raise
+
+
+ def nested_override(original_dict, override_dict):
+     for key, value in override_dict.items():
+         if key in original_dict:
+             if isinstance(original_dict[key], dict) and isinstance(value, dict):
+                 # Recursively merge nested dictionaries
+                 nested_override(original_dict[key], value)
+             else:
+                 original_dict[key] = value  # Custom wins
+         else:
+             original_dict[key] = value
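
For orientation, a minimal usage sketch of two helpers from the file above, GPUConfig.to_dict() and nested_override(). It assumes the module is importable as kubetorch.serving.utils (inferred from the +377 entry in the file list); the values are illustrative, not taken from the package.

# Hypothetical usage sketch -- import path assumed, not confirmed by the package metadata
from kubetorch.serving.utils import GPUConfig, nested_override

# Fractional GPU sharing: validate() runs in __post_init__, and to_dict()
# converts the fraction to the millicores-style string used in pod specs.
gpu = GPUConfig(count=1, sharing_type="fraction", gpu_fraction="0.5")
print(gpu.to_dict())  # {'sharing_type': 'fraction', 'count': 1, 'gpu_fraction': '500m'}

# nested_override() merges dicts recursively, mutating the first argument in place;
# values from the override dict win on conflicts.
base = {"resources": {"limits": {"cpu": "1"}}, "labels": {"app": "demo"}}
nested_override(base, {"resources": {"limits": {"cpu": "2"}}})
print(base)  # {'resources': {'limits': {'cpu': '2'}}, 'labels': {'app': 'demo'}}
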
kubetorch/utils.py ADDED
@@ -0,0 +1,284 @@
+ import enum
+ import importlib.util
+ import json
+ import os
+ import re
+ import subprocess
+ import sys
+ from datetime import datetime, timezone
+ from io import StringIO
+ from pathlib import Path
+ from urllib.parse import urlparse
+
+ from kubernetes import client, config
+
+ from kubetorch.constants import DEFAULT_KUBECONFIG_PATH, MAX_USERNAME_LENGTH
+ from kubetorch.globals import config as kt_config
+ from kubetorch.logger import get_logger
+ from kubetorch.resources.callables.utils import get_local_install_path
+
+ logger = get_logger(__name__)
+
+
+ def extract_host_port(url: str):
+     """Extract host and port when needed separately from a URL."""
+     p = urlparse(url)
+     return p.hostname, (p.port or (443 if p.scheme == "https" else 80))
+
+
+ def http_to_ws(url: str) -> str:
+     """Convert HTTP/HTTPS URLs to WebSocket URLs, or return as-is if already WS."""
+     if url.startswith("https://"):
+         return "wss://" + url[len("https://") :]
+     if url.startswith("http://"):
+         return "ws://" + url[len("http://") :]
+     if url.startswith(("ws://", "wss://")):
+         return url  # already WebSocket URL
+     # Default to ws:// for bare hostnames
+     return "ws://" + url
+
+
+ def validate_username(username):
+     if username is None:  # will be used in case we run kt config user username
+         return username
+     # Kubernetes requires service names to follow DNS-1035 label standards
+     original_username = username  # if an exception is raised because the username is invalid, we want to print the original provided name
+     username = username.lower().replace("_", "-").replace("/", "-")
+     # Make sure the first character is a letter
+     if not re.match(r"^[a-z]", username):
+         # Strip out all the characters before the first letter with a regex
+         username = re.sub(r"^[^a-z]*", "", username)
+     username = username[:MAX_USERNAME_LENGTH]
+     # Make sure username doesn't end or start with a hyphen
+     if username.startswith("-") or username.endswith("-"):
+         username = username.strip("-")
+     reserved = ["kt", "kubetorch", "knative"]
+     if username in reserved:
+         raise ValueError(
+             f"{original_username} is one of the reserved names: {', '.join(reserved)}"
+         )
+     if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", username):
+         raise ValueError(f"{original_username} must be a valid k8s name")
+     return username
+
+
+ def load_kubeconfig():
+     try:
+         config.load_incluster_config()
+     except config.config_exception.ConfigException:
+         kubeconfig_path = os.getenv("KUBECONFIG") or DEFAULT_KUBECONFIG_PATH
+         abs_path = Path(kubeconfig_path).expanduser()
+         if not abs_path.exists():
+             raise FileNotFoundError(f"Kubeconfig file not found in path: {abs_path}")
+         config.load_kube_config(str(abs_path))
+
+
+ def current_git_branch():
+     try:
+         # For CI env
+         branch = (
+             os.environ.get("GITHUB_HEAD_REF")  # PRs: source branch name
+             or os.environ.get("GITHUB_REF_NAME")  # Pushes: branch name
+             or os.environ.get("CI_COMMIT_REF_NAME")  # GitLab
+             or os.environ.get("CIRCLE_BRANCH")  # CircleCI
+         )
+         if not branch:
+             branch = (
+                 subprocess.check_output(
+                     ["git", "rev-parse", "--abbrev-ref", "HEAD"],
+                     stderr=subprocess.DEVNULL,
+                 )
+                 .decode("utf-8")
+                 .strip()
+             )
+         return branch
+     except Exception as e:
+         logger.debug(f"Failed to load current git branch: {e}")
+         return None
+
+
+ def iso_timestamp_to_nanoseconds(timestamp):
+     if timestamp is None:
+         dt = datetime.now(timezone.utc)
+     elif isinstance(timestamp, datetime):
+         dt = timestamp
+         if dt.tzinfo is None:
+             dt = dt.replace(tzinfo=timezone.utc)
+     elif isinstance(timestamp, str):
+         dt = datetime.fromisoformat(timestamp)
+         if dt.tzinfo is None:
+             dt = dt.replace(tzinfo=timezone.utc)
+     else:
+         raise ValueError(f"Unsupported timestamp type: {type(timestamp)}")
+     return int(dt.timestamp() * 1e9)
+
+
+ def get_kt_install_url(freeze: bool = False):
+     # Returns:
+     #   str: kubetorch install url
+     #   bool: whether to install in editable mode
+     if kt_config.install_url or freeze:
+         return kt_config.install_url, False
+     local_kt_path = get_local_install_path("kubetorch")
+     if local_kt_path and (Path(local_kt_path) / "pyproject.toml").exists():
+         return local_kt_path, True
+     else:
+         # If the user is using uv, sometimes the pip freeze command won't return the full "kubetorch @ ..." install
+         # URL. However, running `uv pip freeze` will return the full URL, even when the user isn't actually in a uv
+         # venv. So we default to using `uv pip freeze` if uv is installed, otherwise we use the regular `pip freeze`.
+         uv_installed = importlib.util.find_spec("uv") is not None
+         # check if installed from presigned url
+         freeze_cmd = (
+             "pip freeze | grep kubetorch"
+             if not uv_installed
+             else "uv pip freeze | grep kubetorch"
+         )
+         output = subprocess.run(
+             freeze_cmd,
+             shell=True,
+             capture_output=True,
+             text=True,
+         ).stdout
+         if not output.startswith("kubetorch @ "):
+             raise Exception(
+                 "Could not find kubetorch version to install on the pod. You must either set the "
+                 "``install_url`` in the config or ``KT_INSTALL_URL`` env var, have kubetorch installed "
+                 "locally, or set ``compute.freeze`` to ``True``."
+             )
+         install_url = output[len("kubetorch @ ") :].strip()
+         return install_url, False
+
+
+ class LogVerbosity(str, enum.Enum):
+     DEBUG = "debug"
+     INFO = "info"
+     CRITICAL = "critical"
+
+
+ ####################################################################################################
+ # Logging redirection
+ ####################################################################################################
+ class StreamTee(object):
+     def __init__(self, instream, outstreams):
+         self.instream = instream
+         self.outstreams = outstreams
+
+     def write(self, message):
+         self.instream.write(message)
+         for stream in self.outstreams:
+             if message:
+                 stream.write(message)
+                 # We flush here to ensure that the logs are written to the file immediately
+                 # see https://github.com/run-house/runhouse/pull/724
+                 stream.flush()
+
+     def writelines(self, lines):
+         self.instream.writelines(lines)
+         for stream in self.outstreams:
+             stream.writelines(lines)
+             stream.flush()
+
+     def flush(self):
+         self.instream.flush()
+         for stream in self.outstreams:
+             stream.flush()
+
+     def __getattr__(self, item):
+         # Needed in case someone calls a method on instream, such as Ray calling sys.stdout.istty()
+         return getattr(self.instream, item)
+
+
+ class capture_stdout:
+     """Context manager for capturing stdout to a file, list, or stream, while still printing to stdout."""
+
+     def __init__(self, output=None):
+         self.output = output
+         self._stream = None
+
+     def __enter__(self):
+         if self.output is None:
+             self.output = StringIO()
+
+         if isinstance(self.output, str):
+             self._stream = open(self.output, "w")
+         else:
+             self._stream = self.output
+         sys.stdout = StreamTee(sys.stdout, [self])
+         sys.stderr = StreamTee(sys.stderr, [self])
+         return self
+
+     def write(self, message):
+         self._stream.write(message)
+
+     def flush(self):
+         self._stream.flush()
+
+     @property
+     def stream(self):
+         if isinstance(self.output, str):
+             return open(self.output, "r")
+         return self._stream
+
+     def list(self):
+         if isinstance(self.output, str):
+             return self.stream.readlines()
+         return (self.stream.getvalue() or "").splitlines()
+
+     def __str__(self):
+         return self.stream.getvalue()
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         if hasattr(sys.stdout, "instream"):
+             sys.stdout = sys.stdout.instream
+         if hasattr(sys.stderr, "instream"):
+             sys.stderr = sys.stderr.instream
+         self._stream.close()
+         return False
+
+
+ ####################################################################################################
+ # Logging formatting
+ ####################################################################################################
+ class ColoredFormatter:
+     COLORS = {
+         "black": "\u001b[30m",
+         "red": "\u001b[31m",
+         "green": "\u001b[32m",
+         "yellow": "\u001b[33m",
+         "blue": "\u001b[34m",
+         "magenta": "\u001b[35m",
+         "cyan": "\u001b[36m",
+         "white": "\u001b[37m",
+         "bold": "\u001b[1m",
+         "italic": "\u001b[3m",
+         "reset": "\u001b[0m",
+     }
+
+     @classmethod
+     def get_color(cls, color: str):
+         return cls.COLORS.get(color, "")
+
+
+ class ServerLogsFormatter:
+     def __init__(self, name: str = None):
+         self.name = name
+         self.start_color = ColoredFormatter.get_color("cyan")
+         self.reset_color = ColoredFormatter.get_color("reset")
+
+
+ def initialize_k8s_clients():
+     """Initialize Kubernetes API clients."""
+     load_kubeconfig()
+     return (
+         client.CoreV1Api(),
+         client.CustomObjectsApi(),
+         client.AppsV1Api(),
+     )
+
+
+ def string_to_dict(value):
+     try:
+         result = json.loads(value or "{}")
+         return result if isinstance(result, dict) else {}
+     except (json.JSONDecodeError, TypeError):
+         return {}
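
Similarly, a hedged sketch of two of the string helpers in kubetorch/utils.py above, http_to_ws() and validate_username(); the outputs shown in comments follow from the code in the diff, not from running the package.

# Hypothetical usage sketch for the kubetorch/utils.py helpers shown above
from kubetorch.utils import http_to_ws, validate_username

print(http_to_ws("https://example.com/stream"))  # wss://example.com/stream
print(http_to_ws("my-service:8080"))             # ws://my-service:8080 (bare hosts default to ws://)

# Lowercases and replaces '_' and '/' with '-' to satisfy DNS-1035 naming,
# then truncates to MAX_USERNAME_LENGTH (a constant defined in kubetorch.constants).
print(validate_username("My_User/Name"))  # my-user-name (assuming it fits within MAX_USERNAME_LENGTH)
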