kubetorch 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kubetorch/__init__.py +60 -0
- kubetorch/cli.py +1985 -0
- kubetorch/cli_utils.py +1025 -0
- kubetorch/config.py +453 -0
- kubetorch/constants.py +18 -0
- kubetorch/docs/Makefile +18 -0
- kubetorch/docs/__init__.py +0 -0
- kubetorch/docs/_ext/json_globaltoc.py +42 -0
- kubetorch/docs/api/cli.rst +10 -0
- kubetorch/docs/api/python/app.rst +21 -0
- kubetorch/docs/api/python/cls.rst +19 -0
- kubetorch/docs/api/python/compute.rst +25 -0
- kubetorch/docs/api/python/config.rst +11 -0
- kubetorch/docs/api/python/fn.rst +19 -0
- kubetorch/docs/api/python/image.rst +14 -0
- kubetorch/docs/api/python/secret.rst +18 -0
- kubetorch/docs/api/python/volumes.rst +13 -0
- kubetorch/docs/api/python.rst +101 -0
- kubetorch/docs/conf.py +69 -0
- kubetorch/docs/index.rst +20 -0
- kubetorch/docs/requirements.txt +5 -0
- kubetorch/globals.py +285 -0
- kubetorch/logger.py +59 -0
- kubetorch/resources/__init__.py +0 -0
- kubetorch/resources/callables/__init__.py +0 -0
- kubetorch/resources/callables/cls/__init__.py +0 -0
- kubetorch/resources/callables/cls/cls.py +157 -0
- kubetorch/resources/callables/fn/__init__.py +0 -0
- kubetorch/resources/callables/fn/fn.py +133 -0
- kubetorch/resources/callables/module.py +1416 -0
- kubetorch/resources/callables/utils.py +174 -0
- kubetorch/resources/compute/__init__.py +0 -0
- kubetorch/resources/compute/app.py +261 -0
- kubetorch/resources/compute/compute.py +2596 -0
- kubetorch/resources/compute/decorators.py +139 -0
- kubetorch/resources/compute/rbac.py +74 -0
- kubetorch/resources/compute/utils.py +1114 -0
- kubetorch/resources/compute/websocket.py +137 -0
- kubetorch/resources/images/__init__.py +1 -0
- kubetorch/resources/images/image.py +414 -0
- kubetorch/resources/images/images.py +74 -0
- kubetorch/resources/secrets/__init__.py +2 -0
- kubetorch/resources/secrets/kubernetes_secrets_client.py +412 -0
- kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
- kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
- kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
- kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
- kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
- kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
- kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
- kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
- kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/providers.py +93 -0
- kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
- kubetorch/resources/secrets/secret.py +238 -0
- kubetorch/resources/secrets/secret_factory.py +70 -0
- kubetorch/resources/secrets/utils.py +209 -0
- kubetorch/resources/volumes/__init__.py +0 -0
- kubetorch/resources/volumes/volume.py +365 -0
- kubetorch/servers/__init__.py +0 -0
- kubetorch/servers/http/__init__.py +0 -0
- kubetorch/servers/http/distributed_utils.py +3223 -0
- kubetorch/servers/http/http_client.py +730 -0
- kubetorch/servers/http/http_server.py +1788 -0
- kubetorch/servers/http/server_metrics.py +278 -0
- kubetorch/servers/http/utils.py +728 -0
- kubetorch/serving/__init__.py +0 -0
- kubetorch/serving/autoscaling.py +173 -0
- kubetorch/serving/base_service_manager.py +363 -0
- kubetorch/serving/constants.py +83 -0
- kubetorch/serving/deployment_service_manager.py +478 -0
- kubetorch/serving/knative_service_manager.py +519 -0
- kubetorch/serving/raycluster_service_manager.py +582 -0
- kubetorch/serving/service_manager.py +18 -0
- kubetorch/serving/templates/deployment_template.yaml +17 -0
- kubetorch/serving/templates/knative_service_template.yaml +19 -0
- kubetorch/serving/templates/kt_setup_template.sh.j2 +81 -0
- kubetorch/serving/templates/pod_template.yaml +194 -0
- kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
- kubetorch/serving/templates/raycluster_template.yaml +35 -0
- kubetorch/serving/templates/service_template.yaml +21 -0
- kubetorch/serving/templates/workerset_template.yaml +36 -0
- kubetorch/serving/utils.py +377 -0
- kubetorch/utils.py +284 -0
- kubetorch-0.2.0.dist-info/METADATA +121 -0
- kubetorch-0.2.0.dist-info/RECORD +93 -0
- kubetorch-0.2.0.dist-info/WHEEL +4 -0
- kubetorch-0.2.0.dist-info/entry_points.txt +5 -0
kubetorch/servers/http/utils.py
@@ -0,0 +1,728 @@
import asyncio
import atexit
import base64
import enum
import hashlib
import json
import os
import pickle
import re
import socket
import subprocess
import sys
import time
from contextvars import ContextVar
from typing import List

import httpx

import jinja2
import websockets
import yaml

from kubetorch.constants import LOCALHOST
from kubetorch.logger import get_logger
from kubetorch.serving.constants import DEFAULT_DEBUG_PORT
from kubetorch.utils import ServerLogsFormatter

logger = get_logger(__name__)

RSYNC_PORT = 873

DEFAULT_ALLOWED_SERIALIZATION = "json"

MAGIC_CALL_KWARGS = ["workers", "restart_procs"]

LOG_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {},
    "handlers": {},
    "root": {"level": "INFO", "handlers": []},
    "loggers": {
        "uvicorn": {"level": "INFO", "handlers": [], "propagate": True},
        "uvicorn.access": {"level": "INFO", "handlers": [], "propagate": True},
        "kubetorch": {"level": "INFO", "handlers": [], "propagate": True},
    },
}
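
# Illustrative sketch (editor's note, not part of the released file): LOG_CONFIG
# above is a standard `logging.config.dictConfig` schema, so a server entrypoint
# could apply it before layering on the structured handler defined below:
#
#     import logging.config
#     logging.config.dictConfig(LOG_CONFIG)
#     ensure_structured_logging()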


def ensure_structured_logging():
    """Add our structured JSON handler to all loggers without removing user's handlers. We do this both when we
    set up the HTTP server and also after re-importing user code, as their modules might include logging setup
    of their own."""
    import logging
    import logging.handlers
    import os
    import sys

    from pythonjsonlogger import jsonlogger

    # First ensure logging is initialized - this is crucial!
    # If no handlers exist, we need to initialize the logging system
    root_logger = logging.getLogger()

    # Create our JSON formatter
    json_formatter = jsonlogger.JsonFormatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(request_id)s - %(pod)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Create our structured handler (we keep using sys.stdout so user and kt logs
    # both appear in pod logs; our stdout wrapper will mirror to the original stream)
    structured_handler = logging.StreamHandler(sys.stdout)
    structured_handler.setFormatter(json_formatter)
    structured_handler.name = "kubetorch_structured"  # Name it so we can identify it

    # Set root logger level based on KT_LOG_LEVEL if it's set
    kt_log_level = os.getenv("KT_LOG_LEVEL")
    if kt_log_level:
        kt_log_level = kt_log_level.upper()
        root_logger.setLevel(getattr(logging, kt_log_level, logging.INFO))

    # Check if our handler is already there (to avoid adding duplicates)
    existing_structured = None
    for h in root_logger.handlers:
        if getattr(h, "name", None) == "kubetorch_structured":
            existing_structured = h
            break

    if not existing_structured:
        # Add our structured handler alongside any user-installed handlers
        # so both formats are emitted to pod logs
        root_logger.addHandler(structured_handler)

    # Ensure request context fields are attached to all records even if the user
    # reconfigured logging and removed our filters. Do this idempotently.
    class _ContextFieldsFilter(logging.Filter):
        def filter(self, record):
            if not hasattr(record, "request_id") or record.request_id in (None, "-"):
                try:
                    record.request_id = request_id_ctx_var.get("-")
                except Exception:
                    record.request_id = "-"
            if not hasattr(record, "pod") or record.pod in (None, ""):
                record.pod = os.getenv("POD_NAME", "unknown-pod")
            return True

    # Attach the filter to root and all of its handlers (idempotent: duplicate adds are ignored)
    context_filter = _ContextFieldsFilter()
    try:
        root_logger.addFilter(context_filter)
    except Exception:
        pass
    for h in root_logger.handlers:
        try:
            h.addFilter(context_filter)
        except Exception:
            pass

    # Ensure print_redirect logger also has proper configuration
    # This is important for the StreamToLogger output
    print_logger = logging.getLogger("print_redirect")
    print_logger.setLevel(logging.INFO)
    # Ensure it propagates to root so the structured handler formats it
    print_logger.propagate = True
    try:
        print_logger.addFilter(context_filter)
    except Exception:
        pass
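
# Illustrative sketch (editor's assumption, not from the wheel): with the format
# string above, python-json-logger emits one JSON object per line, roughly:
#
#     {"asctime": "2025-07-13 16:45:46", "name": "kubetorch", "levelname": "INFO",
#      "message": "...", "request_id": "ab12cd34ef", "pod": "my-pod-0"}
#
# which is the shape the log-stream printers below parse back out of Loki.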


request_id_ctx_var: ContextVar[str] = ContextVar("request_id", default="-")


class StartupError(Exception):
    pass


class PodTerminatedError(Exception):
    def __init__(
        self,
        pod_name: str = "unknown",
        reason: str = "Unknown",
        status_code: int = 503,
        events: List[dict] = None,
    ):
        """
        events: List of dicts with keys:
            - timestamp: datetime
            - reason: str
            - message: str

        sample event:
        {
            'timestamp': datetime.datetime(2025, 7, 13, 16, 45, 46, tzinfo=tzutc()),
            'reason': 'Evicted',
            'message': 'The node was low on resource: memory. Threshold quantity: 100Mi, available: 3404Ki.'
        }
        """
        self.pod_name = pod_name
        self.reason = reason
        self.status_code = status_code
        self.events = events or []
        super().__init__(str(self))

    def __getstate__(self):
        """Serialize the exception state for transmission over HTTP."""
        # Convert datetime objects to ISO format strings for JSON serialization
        serialized_events = []
        for event in self.events:
            serialized_event = event.copy()
            if "timestamp" in serialized_event:
                timestamp = serialized_event["timestamp"]
                # Convert datetime to string if needed
                if hasattr(timestamp, "isoformat"):
                    serialized_event["timestamp"] = timestamp.isoformat()
            serialized_events.append(serialized_event)

        return {
            "pod_name": self.pod_name,
            "reason": self.reason,
            "status_code": self.status_code,
            "events": serialized_events,
        }

    def __setstate__(self, state):
        """Reconstruct the exception from serialized state."""
        self.pod_name = state["pod_name"]
        self.reason = state["reason"]
        self.status_code = state["status_code"]
        self.events = state["events"]

    @classmethod
    def from_dict(cls, state):
        """Reconstruct the exception from a dictionary state."""
        return cls(
            pod_name=state.get("pod_name", "unknown"),
            reason=state.get("reason", "Unknown"),
            status_code=state.get("status_code", 503),
            events=state.get("events", []),
        )

    @property
    def evicted(self) -> bool:
        """True if pod was evicted (ex: node pressure, preemption)."""
        return self.reason == "Evicted" or any(
            "Evicted" in event["reason"] for event in self.events
        )

    @property
    def oom_killed(self) -> bool:
        """True if pod was killed due to OOM."""
        return self.reason == "OOMKilled" or any(
            "OOMKilled" in event["reason"] for event in self.events
        )

    def __str__(self):
        events_str = "\n".join(
            f"{e['timestamp']} {e['reason']}: {e['message']}" for e in self.events
        )
        base_exc = (
            f"\nPod Name: {self.pod_name}\n"
            f"Reason: {self.reason}\n"
            f"Status Code: {self.status_code}\n"
        )
        if self.events:
            base_exc += f"Recent Events:\n{events_str}"
        return base_exc
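
# Illustrative sketch (editor's assumption, not from the wheel): the
# __getstate__/from_dict pair lets the server ship this exception as JSON and
# the client rebuild it; "train-0" is a hypothetical pod name.
#
#     state = PodTerminatedError(pod_name="train-0", reason="Evicted").__getstate__()
#     err = PodTerminatedError.from_dict(state)
#     err.evicted  # -> True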


class WorkerMembershipChanged(Exception):
    """Raised when worker pods are added or removed during distributed execution."""

    def __init__(
        self,
        added_ips: set = None,
        removed_ips: set = None,
        previous_ips: set = None,
        current_ips: set = None,
        message: str = None,
    ):
        # Support both explicit construction and reconstruction from message
        if message and not (added_ips or removed_ips):
            import ast

            # Reconstruct from message
            import re

            self.added_ips = set()
            self.removed_ips = set()
            self.previous_ips = set()
            self.current_ips = set()

            if "removed during execution:" in message:
                match = re.search(r"removed during execution: ({.*?})", message)
                if match:
                    self.removed_ips = ast.literal_eval(match.group(1))
            elif "added during execution:" in message:
                match = re.search(r"added during execution: ({.*?})", message)
                if match:
                    self.added_ips = ast.literal_eval(match.group(1))
        else:
            # Normal construction
            self.added_ips = added_ips or set()
            self.removed_ips = removed_ips or set()
            self.previous_ips = previous_ips or set()
            self.current_ips = current_ips or set()

            if removed_ips:
                message = f"Critical: {len(removed_ips)} worker(s) removed during execution: {removed_ips}"
            elif added_ips:
                message = f"Warning: {len(added_ips)} worker(s) added during execution: {added_ips}"
            else:
                message = "Worker membership changed"

        super().__init__(message)

    @property
    def is_critical(self) -> bool:
        """Returns True if workers were removed (critical for training)."""
        return bool(self.removed_ips)

    def __getstate__(self):
        """Serialize the exception state."""
        return {
            "message": str(self),
            "added_ips": list(self.added_ips),
            "removed_ips": list(self.removed_ips),
            "previous_ips": list(self.previous_ips),
            "current_ips": list(self.current_ips),
        }

    @classmethod
    def from_dict(cls, data):
        """Reconstruct from serialized state."""
        return cls(
            added_ips=set(data.get("added_ips", [])),
            removed_ips=set(data.get("removed_ips", [])),
            previous_ips=set(data.get("previous_ips", [])),
            current_ips=set(data.get("current_ips", [])),
        )
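
# Illustrative sketch (editor's assumption, not from the wheel): the message
# parser lets a client rebuild the exception from its plain-string form; the
# IP below is hypothetical.
#
#     e = WorkerMembershipChanged(
#         message="Critical: 1 worker(s) removed during execution: {'10.0.0.7'}"
#     )
#     e.removed_ips   # -> {'10.0.0.7'}
#     e.is_critical   # -> True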


class StreamType(str, enum.Enum):
    CLI = "cli"
    HTTP_CLIENT = "http_client"


def clean_and_validate_k8s_name(name: str, allow_full_length: bool = True) -> str:
    """Clean and validate a name for K8s compatibility.

    Args:
        name: The name to clean and validate
        allow_full_length: If True, allows and intelligently trims full pod names to 63 chars,
            preserving k8s-generated portions.
            If False, limits to 40 chars to leave room for k8s suffixes.
    """
    max_k8s_name_length = 63  # max length allowed by k8s
    max_base_name_length = (
        40  # max module name length to account for added k8s suffixes
    )
    # Regex to comply with k8s service name requirements
    cleaned_name = re.sub(r"[^a-z0-9-]|^[-]|[-]$", "", name.lower())
    if not cleaned_name:
        raise ValueError("Name must contain at least one alphanumeric character.")

    max_length = max_k8s_name_length if allow_full_length else max_base_name_length

    if len(cleaned_name) > max_length:
        if not allow_full_length:
            # For a user provided module name, raise an exception
            error_msg = (
                f"Name length {len(cleaned_name)} exceeds {max_length} characters. "
                "Must leave room for Kubernetes-added suffixes."
            )
            raise ValueError(error_msg)

        match = re.search(r"(-\d+)?-deployment-[a-z0-9]+-[a-z0-9]+", cleaned_name)
        if match:
            k8s_part = match.group(0)
            k8s_start_idx = match.start()

            prefix = cleaned_name[:k8s_start_idx]
            suffix = cleaned_name[k8s_start_idx + len(k8s_part) :]

            total_excess = len(cleaned_name) - max_length

            # If we need to trim, handle each part
            if total_excess > 0:
                # Handle prefix trimming
                if prefix:
                    segments = prefix.split("-")
                    while (
                        len("-".join(segments)) + len(k8s_part) + len(suffix)
                        > max_length
                    ):
                        if len(segments) > 1:
                            segments.pop()
                        else:
                            segments[0] = segments[0][:-1]
                    prefix = "-".join(segments)

                # Handle suffix trimming if still needed
                remaining_length = max_length - (len(prefix) + len(k8s_part))
                if remaining_length > 0:
                    suffix_segments = suffix.split("-")
                    clean_segments = []
                    current_length = 0
                    for seg in suffix_segments:
                        # Only add segment if it's at least 2 chars so the name doesn't look cut off
                        if (
                            len(seg) >= 2
                            and current_length + len(seg) + 1 <= remaining_length
                        ):
                            clean_segments.append(seg)
                            current_length += len(seg) + 1
                    suffix = "-".join(clean_segments)
                else:
                    suffix = ""

            cleaned_name = (
                (prefix + "-" if prefix else "")
                + k8s_part
                + ("-" + suffix if suffix else "")
            )

    return cleaned_name
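
# Illustrative sketch (editor's note, not part of the released file): the regex
# lowercases and strips disallowed characters, with the 40-char budget enforced
# for user-supplied names.
#
#     clean_and_validate_k8s_name("My_Training.Job!", allow_full_length=False)
#     # -> "mytrainingjob"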


def is_running_in_kubernetes():
    """
    Determines if the current Python process is running inside a Kubernetes pod.

    Returns:
        bool: True if running in Kubernetes, False otherwise
    """
    # Method 1: Check for Kubernetes service environment variables
    if os.environ.get("KUBERNETES_SERVICE_HOST") is not None:
        return True

    # Method 2: Check for the existence of the Kubernetes service account token file
    if os.path.exists("/var/run/secrets/kubernetes.io/serviceaccount/token"):
        return True

    return False


def _get_rendered_template(
    template_file: str, template_dir: str, **template_vars
) -> str:
    """Helper function to set up and render a template."""
    template_loader = jinja2.FileSystemLoader(searchpath=template_dir)
    template_env = jinja2.Environment(
        loader=template_loader,
        keep_trailing_newline=True,
        trim_blocks=True,
        lstrip_blocks=True,
        enable_async=False,
        autoescape=False,
    )
    template = template_env.get_template(template_file)
    return template.render(**template_vars)


def load_template(template_file: str, template_dir: str, **template_vars) -> dict:
    """Load and render a single YAML document template."""
    rendered = _get_rendered_template(template_file, template_dir, **template_vars)
    return yaml.safe_load(rendered)


def load_multi_yaml_template(
    template_file: str, template_dir: str, **template_vars
) -> dict:
    """Load and render a multi-document YAML template."""
    rendered = _get_rendered_template(template_file, template_dir, **template_vars)
    return {"items": list(yaml.safe_load_all(rendered))}


def generate_unique_request_id(endpoint: str, timestamp: str) -> str:
    """Generates a unique request id, based on the method/function endpoint and the call timestamp"""
    raw = f"{endpoint}_{timestamp}"
    unique_id = hashlib.sha256(raw.encode()).hexdigest()[:10]
    return unique_id
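
# Illustrative sketch (editor's assumption, not from the wheel): rendering one
# of the bundled serving templates; the template variables here are hypothetical
# and depend on what the template actually expects.
#
#     manifest = load_template(
#         "service_template.yaml",
#         template_dir="kubetorch/serving/templates",
#         name="my-svc",
#         namespace="default",
#     )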


def print_log_stream_client(message, last_timestamp, print_pod_name: bool = False):
    formatter = ServerLogsFormatter()
    if message.get("streams"):
        for stream in message["streams"]:
            pod_name = f'({stream.get("stream").get("pod")}) ' if print_pod_name else ""
            for value in stream["values"]:
                # Skip if we've already seen this timestamp
                if last_timestamp is not None and value[0] <= last_timestamp:
                    continue
                last_timestamp = value[0]

                log_line = json.loads(value[1])
                log_name = log_line.get("name")
                if log_name == "print_redirect":
                    message = log_line.get("message")
                    print(
                        f"{pod_name}{formatter.start_color}{message}{formatter.reset_color}"
                    )
                elif log_name != "uvicorn.access":
                    formatted_log = f"{pod_name}{log_line.get('asctime')} | {log_line.get('levelname')} | {log_line.get('message')}"
                    print(
                        f"{formatter.start_color}{formatted_log}{formatter.reset_color}"
                    )
    return last_timestamp


def print_log_stream_cli(message, last_timestamp, print_pod_name: bool = False):
    if message.get("streams"):
        for stream in message["streams"]:
            pod_name = f'({stream.get("stream").get("pod")}) ' if print_pod_name else ""
            for value in stream["values"]:
                # Skip if we've already seen this timestamp
                if last_timestamp is not None and value[0] <= last_timestamp:
                    continue
                last_timestamp = value[0]
                log_line = value[1]
                try:
                    log_line = json.loads(log_line)
                    log_name = log_line.get("name")
                    if log_name == "print_redirect":
                        continue
                    # Print output is already emitted by the HTTP client stream handler
                    # above. We need the "print_redirect" log type only for log streaming
                    # in the http client, so we can filter the print outputs for a specific
                    # request ID. For the CLI --follow option, we print all logs, so at the
                    # moment we don't need to filter by request_id.
                    elif log_name != "uvicorn.access":
                        formatted_log = f"{pod_name}{log_line.get('asctime')} | {log_line.get('levelname')} | {log_line.get('message')}".strip()
                        print(formatted_log)
                except json.JSONDecodeError:
                    print(log_line.strip())

    return last_timestamp
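
# Illustrative sketch (editor's assumption, not from the wheel): both printers
# consume Loki's tail payload shape, roughly:
#
#     {"streams": [{"stream": {"pod": "my-pod-0"},
#                   "values": [["1752423946000000000",
#                               "{\"name\": \"kubetorch\", \"message\": \"...\"}"]]}]}
#
# where each entry in "values" is a [timestamp, log-line] pair.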


async def stream_logs_websocket_helper(
    uri,
    stop_event,
    stream_type: StreamType = StreamType.HTTP_CLIENT,
    print_pod_name: bool = False,
):
    """Stream logs using Loki's websocket tail endpoint"""
    websocket = None
    try:
        # Track the last timestamp we've seen to avoid duplicates
        last_timestamp = None
        # Track when we should stop
        stop_time = None

        # Add timeout to prevent hanging connections
        websocket = await websockets.connect(
            uri,
            close_timeout=10,  # Max time to wait for close handshake
            ping_interval=20,  # Send ping every 20 seconds
            ping_timeout=10,  # Wait 10 seconds for pong
        )
        try:
            while True:
                # If stop event is set, start counting down
                if stop_event.is_set() and stop_time is None:
                    stop_time = time.time() + 2  # 2 seconds grace period

                # If we're past the grace period, exit
                if stop_time is not None and time.time() > stop_time:
                    break

                try:
                    # Use shorter timeout during grace period
                    timeout = 0.1 if stop_time is not None else 1.0
                    message = await asyncio.wait_for(websocket.recv(), timeout=timeout)
                    try:
                        message = json.loads(message)
                    except json.JSONDecodeError:
                        pass  # Leave the raw message as-is if it isn't JSON

                    if stream_type == StreamType.HTTP_CLIENT:
                        last_timestamp = print_log_stream_client(
                            message, last_timestamp, print_pod_name
                        )
                    elif stream_type == StreamType.CLI:
                        last_timestamp = print_log_stream_cli(
                            message, last_timestamp, print_pod_name
                        )
                except asyncio.TimeoutError:
                    # Timeout is expected, just continue the loop
                    continue
                except websockets.exceptions.ConnectionClosed as e:
                    logger.debug(f"WebSocket connection closed: {str(e)}")
                    break
        finally:
            if websocket:
                try:
                    # Use wait_for to prevent hanging on close
                    await asyncio.wait_for(websocket.close(), timeout=1.0)
                except Exception:
                    pass
    except Exception as e:
        logger.error(f"Error in websocket stream: {e}")
    finally:
        # Ensure websocket is closed even if we didn't enter the try block
        if websocket:
            try:
                # Use wait_for to prevent hanging on close
                await asyncio.wait_for(websocket.close(), timeout=1.0)
            except Exception:
                pass
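
# Illustrative sketch (editor's assumption, not from the wheel): driving the
# helper from a synchronous caller; the Loki tail URI is hypothetical and
# `threading` would need to be imported.
#
#     stop = threading.Event()
#     asyncio.run(stream_logs_websocket_helper(
#         'ws://localhost:3100/loki/api/v1/tail?query={pod="my-pod-0"}',
#         stop,
#         stream_type=StreamType.CLI,
#     ))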


def clear_debugging_sessions():
    """Clear any existing debugging sessions when a module is redeployed or pod is terminated."""
    try:
        import web_pdb

        if web_pdb.WebPdb.active_instance is not None:
            logger.info("Clearing existing debugging session")
            try:
                web_pdb.WebPdb.active_instance.remove_trace()
            except Exception as e:
                logger.warning(f"Error removing trace: {e}")
            web_pdb.WebPdb.active_instance = None

    except ImportError:
        # web_pdb not installed, nothing to clean up
        pass
    except Exception as e:
        logger.warning(f"Error clearing debugging session: {e}")


# Register cleanup function to run at exit
atexit.register(clear_debugging_sessions)


def deep_breakpoint(debug_port: int = DEFAULT_DEBUG_PORT):
    """
    Similar to Python's built-in `breakpoint()`, but can be used deep inside distributed code. For SPMD-style
    distributed code like PyTorch, be sure to only call this from one process (e.g. the rank 0 process) to avoid
    blocking all processes in the distributed group.
    """
    # Check if web_pdb is installed; if not, install it
    try:
        import web_pdb
    except ImportError:
        install_cmd = "uv pip install --system web-pdb"
        import subprocess

        print("Pdb debugger not found, installing it...")
        # Run the install command and propagate logs
        subprocess.run(install_cmd, shell=True, check=True, text=True)
        print("Pdb installed successfully.")

    print(
        "Distributed breakpoint activated. To attach a debugger, run the following command:"
    )
    print(
        f"kt debug {os.environ['POD_NAME']} --port {debug_port} --namespace {os.environ['POD_NAMESPACE']}"
    )

    import web_pdb

    pdb = web_pdb.WebPdb.active_instance
    try:
        if pdb is None:
            pdb = web_pdb.WebPdb(host="", port=debug_port, patch_stdstreams=False)
        else:
            # If the debugger is still attached, reset the trace to a new location
            pdb.remove_trace()

        # Set the frame to the caller's frame
        pdb.set_trace(sys._getframe(1))  # pylint: disable=protected-access
    except Exception as e:
        # Only clean up if there was an error setting up the debugger
        if pdb:
            pdb.remove_trace()
        web_pdb.WebPdb.active_instance = None
        raise e


def wait_for_app_start(
    port, health_check: str, process: subprocess.Popen, timeout: int = 60
):
    """
    Wait until the app is ready. If health_check is provided, will send HTTP requests to check; otherwise
    will wait until something is listening on the port.
    """
    host = LOCALHOST
    port = int(port)
    logger.debug(f"Trying to connect to http://{host}:{port}{health_check or ''}")
    start_time = time.time()

    if health_check:
        if not health_check.startswith("/"):
            health_check = f"/{health_check}"
        url = f"http://{LOCALHOST}:{port}{health_check}"
        while time.time() - start_time < timeout:
            if process.poll() is not None and process.poll() != 0:
                raise RuntimeError(f"App exited with code {process.poll()}")
            try:
                response = httpx.get(url)
                if response.status_code == 200:
                    return True
            except httpx.ConnectError:
                pass
            time.sleep(0.5)
        raise TimeoutError(
            f"App did not become healthy on {url} within {timeout} seconds"
        )
    else:
        # Fallback to socket check
        while time.time() - start_time < timeout:
            if process.poll() is not None and process.poll() != 0:
                raise RuntimeError(f"App exited with code {process.poll()}")
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(1)
                try:
                    sock.connect((host, port))
                    return True
                except (ConnectionRefusedError, socket.timeout):
                    time.sleep(0.5)
        raise TimeoutError(
            f"Failed to detect open port {port} on {host} within {timeout} seconds"
        )


def _serialize_body(body: dict, serialization: str):
    if body is None:
        return {}

    # We only serialize args and kwargs; other settings like "workers" and "restart_procs" are needed inside
    # the http_server, outside the serialization boundary (e.g. the distributed processes).
    # We break them out here as separate params.
    body = body or {}

    for kwarg in MAGIC_CALL_KWARGS:
        if kwarg in body.get("kwargs", {}):
            body[kwarg] = body["kwargs"].pop(kwarg)

    if serialization == "pickle":
        args_data = {"args": body.pop("args"), "kwargs": body.pop("kwargs")}
        pickled_args = pickle.dumps(args_data or {})
        encoded_args = base64.b64encode(pickled_args).decode("utf-8")
        body["data"] = encoded_args
        return body
    return body or {}


def _deserialize_response(response, serialization: str):
    if serialization == "pickle":
        response_data = response.json()
        if isinstance(response_data, list):
            # If this is a response from an SPMD call, it's a list of serialized dicts
            unpickled_results = []
            for resp in response_data:
                if "data" in resp:
                    encoded_result = resp["data"]
                    pickled_result = base64.b64decode(encoded_result.encode("utf-8"))
                    resp = pickle.loads(pickled_result)
                unpickled_results.append(resp)
            return unpickled_results
        if "data" in response_data:
            encoded_result = response_data["data"]
            pickled_result = base64.b64decode(encoded_result.encode("utf-8"))
            return pickle.loads(pickled_result)
        return response_data
    return response.json()