kubetorch-0.2.5-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. kubetorch/__init__.py +59 -0
  2. kubetorch/cli.py +1939 -0
  3. kubetorch/cli_utils.py +967 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +269 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +159 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +140 -0
  30. kubetorch/resources/callables/module.py +1315 -0
  31. kubetorch/resources/callables/utils.py +203 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +253 -0
  34. kubetorch/resources/compute/compute.py +2414 -0
  35. kubetorch/resources/compute/decorators.py +137 -0
  36. kubetorch/resources/compute/utils.py +1026 -0
  37. kubetorch/resources/compute/websocket.py +135 -0
  38. kubetorch/resources/images/__init__.py +1 -0
  39. kubetorch/resources/images/image.py +412 -0
  40. kubetorch/resources/images/images.py +64 -0
  41. kubetorch/resources/secrets/__init__.py +2 -0
  42. kubetorch/resources/secrets/kubernetes_secrets_client.py +377 -0
  43. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  44. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  45. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  46. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  47. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  48. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  49. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  50. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  51. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  52. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  53. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  54. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  55. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  56. kubetorch/resources/secrets/provider_secrets/providers.py +92 -0
  57. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  58. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  59. kubetorch/resources/secrets/secret.py +224 -0
  60. kubetorch/resources/secrets/secret_factory.py +64 -0
  61. kubetorch/resources/secrets/utils.py +222 -0
  62. kubetorch/resources/volumes/__init__.py +0 -0
  63. kubetorch/resources/volumes/volume.py +340 -0
  64. kubetorch/servers/__init__.py +0 -0
  65. kubetorch/servers/http/__init__.py +0 -0
  66. kubetorch/servers/http/distributed_utils.py +2968 -0
  67. kubetorch/servers/http/http_client.py +802 -0
  68. kubetorch/servers/http/http_server.py +1622 -0
  69. kubetorch/servers/http/server_metrics.py +255 -0
  70. kubetorch/servers/http/utils.py +722 -0
  71. kubetorch/serving/__init__.py +0 -0
  72. kubetorch/serving/autoscaling.py +153 -0
  73. kubetorch/serving/base_service_manager.py +344 -0
  74. kubetorch/serving/constants.py +77 -0
  75. kubetorch/serving/deployment_service_manager.py +431 -0
  76. kubetorch/serving/knative_service_manager.py +487 -0
  77. kubetorch/serving/raycluster_service_manager.py +526 -0
  78. kubetorch/serving/service_manager.py +18 -0
  79. kubetorch/serving/templates/deployment_template.yaml +17 -0
  80. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  81. kubetorch/serving/templates/kt_setup_template.sh.j2 +91 -0
  82. kubetorch/serving/templates/pod_template.yaml +198 -0
  83. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  84. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  85. kubetorch/serving/templates/service_template.yaml +21 -0
  86. kubetorch/serving/templates/workerset_template.yaml +36 -0
  87. kubetorch/serving/utils.py +344 -0
  88. kubetorch/utils.py +263 -0
  89. kubetorch-0.2.5.dist-info/METADATA +75 -0
  90. kubetorch-0.2.5.dist-info/RECORD +92 -0
  91. kubetorch-0.2.5.dist-info/WHEEL +4 -0
  92. kubetorch-0.2.5.dist-info/entry_points.txt +5 -0
--- /dev/null
+++ b/kubetorch/servers/http/utils.py
@@ -0,0 +1,722 @@
+ import asyncio
+ import atexit
+ import base64
+ import enum
+ import hashlib
+ import json
+ import os
+ import pickle
+ import re
+ import socket
+ import subprocess
+ import sys
+ import time
+ from contextvars import ContextVar
+ from typing import List
+
+ import httpx
+
+ import jinja2
+ import websockets
+ import yaml
+
+ from kubetorch.constants import LOCALHOST
+ from kubetorch.logger import get_logger
+ from kubetorch.serving.constants import DEFAULT_DEBUG_PORT
+ from kubetorch.utils import ServerLogsFormatter
+
+ logger = get_logger(__name__)
+
+ RSYNC_PORT = 873
+
+ DEFAULT_ALLOWED_SERIALIZATION = "json"
+
+ MAGIC_CALL_KWARGS = ["workers", "restart_procs"]
+
+ LOG_CONFIG = {
+     "version": 1,
+     "disable_existing_loggers": False,
+     "formatters": {},
+     "handlers": {},
+     "root": {"level": "INFO", "handlers": []},
+     "loggers": {
+         "uvicorn": {"level": "INFO", "handlers": [], "propagate": True},
+         "uvicorn.access": {"level": "INFO", "handlers": [], "propagate": True},
+         "kubetorch": {"level": "INFO", "handlers": [], "propagate": True},
+     },
+ }
+
+
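LOG_CONFIG above is a standard logging dictConfig schema with deliberately empty formatters and handlers, so applying it normalizes levels and propagation without attaching any output. A minimal sketch of how such a dict is consumed (standard library only, nothing kubetorch-specific):

    import logging.config

    logging.config.dictConfig(LOG_CONFIG)  # resets levels/propagation; handlers stay empty
    logging.getLogger("uvicorn.access").info("propagates to root")  # visible once root gains a handler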
+ def ensure_structured_logging():
+     """Add our structured JSON handler to all loggers without removing the user's handlers. We do this both when we
+     set up the HTTP server and also after re-importing user code, as their modules might include logging setup
+     of their own."""
+     import logging
+     import logging.handlers
+     import os
+     import sys
+
+     from pythonjsonlogger import jsonlogger
+
+     # First ensure logging is initialized - this is crucial!
+     # If no handlers exist, we need to initialize the logging system
+     root_logger = logging.getLogger()
+
+     # Create our JSON formatter
+     json_formatter = jsonlogger.JsonFormatter(
+         "%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(request_id)s - %(pod)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+
+     # Create our structured handler (we keep using sys.stdout so user and kt logs
+     # both appear in pod logs; our stdout wrapper will mirror to the original stream)
+     structured_handler = logging.StreamHandler(sys.stdout)
+     structured_handler.setFormatter(json_formatter)
+     structured_handler.name = "kubetorch_structured"  # Name it so we can identify it
+
+     # Set the root logger level based on KT_LOG_LEVEL if it's set
+     kt_log_level = os.getenv("KT_LOG_LEVEL")
+     if kt_log_level:
+         kt_log_level = kt_log_level.upper()
+         root_logger.setLevel(getattr(logging, kt_log_level, logging.INFO))
+
+     # Check if our handler is already there (to avoid adding duplicates)
+     existing_structured = None
+     for h in root_logger.handlers:
+         if getattr(h, "name", None) == "kubetorch_structured":
+             existing_structured = h
+             break
+
+     if not existing_structured:
+         # Add our structured handler alongside any user-installed handlers
+         # so both formats are emitted to pod logs
+         root_logger.addHandler(structured_handler)
+
+     # Ensure request context fields are attached to all records even if the user
+     # reconfigured logging and removed our filters. Do this idempotently.
+     class _ContextFieldsFilter(logging.Filter):
+         def filter(self, record):
+             if not hasattr(record, "request_id") or record.request_id in (None, "-"):
+                 try:
+                     record.request_id = request_id_ctx_var.get("-")
+                 except Exception:
+                     record.request_id = "-"
+             if not hasattr(record, "pod") or record.pod in (None, ""):
+                 record.pod = os.getenv("POD_NAME", "unknown-pod")
+             return True
+
+     # Attach the filter to root and all of its handlers (idempotent: duplicate adds are ignored)
+     context_filter = _ContextFieldsFilter()
+     try:
+         root_logger.addFilter(context_filter)
+     except Exception:
+         pass
+     for h in root_logger.handlers:
+         try:
+             h.addFilter(context_filter)
+         except Exception:
+             pass
+
+     # Ensure the print_redirect logger also has proper configuration
+     # This is important for the StreamToLogger output
+     print_logger = logging.getLogger("print_redirect")
+     print_logger.setLevel(logging.INFO)
+     # Ensure it propagates to root so the structured handler formats it
+     print_logger.propagate = True
+     try:
+         print_logger.addFilter(context_filter)
+     except Exception:
+         pass
+
+
+ request_id_ctx_var: ContextVar[str] = ContextVar("request_id", default="-")
+
+
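The %(request_id)s and %(pod)s fields in the JSON formatter are stamped onto records by _ContextFieldsFilter, which reads request_id_ctx_var. A hedged sketch of how a caller (e.g. hypothetical request middleware) would bind the context variable per request:

    token = request_id_ctx_var.set("req-abc123")  # bind for the current async context
    try:
        logger.info("handling call")  # filter attaches request_id="req-abc123" and the pod name
    finally:
        request_id_ctx_var.reset(token)  # restore the default "-"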
+ def collect_reload_modules(kt_home_dir_str: str) -> list:
+     """
+     Collect user modules from sys.modules that should be reloaded. Returns a sorted list of module names
+     (children before parents) whose files live under kt_home_dir, excluding kubetorch.servers.
+     """
+     modules_to_reload = []
+
+     for mod_name, mod in sys.modules.items():
+         # Exclude kubetorch.servers (framework module that shouldn't be reloaded)
+         if mod_name == "kubetorch.servers" or mod_name.startswith("kubetorch.servers."):
+             continue
+
+         if not hasattr(mod, "__file__") or not mod.__file__:
+             continue
+
+         try:
+             mod_file = os.path.abspath(mod.__file__)
+
+             # Only include modules under the kt home directory (synced files)
+             if mod_file.startswith(kt_home_dir_str):
+                 # Exclude modules from site-packages (Docker image files)
+                 if sys.prefix and mod_file.startswith(os.path.abspath(sys.prefix)):
+                     continue
+                 if sys.base_prefix and sys.base_prefix != sys.prefix:
+                     if mod_file.startswith(os.path.abspath(sys.base_prefix)):
+                         continue
+
+                 modules_to_reload.append(mod_name)
+         except (OSError, ValueError):
+             continue
+
+     # Sort from children to parents to ensure child modules are reloaded before their parent packages
+     modules_to_reload_sorted = sorted(modules_to_reload, key=lambda x: (x.count("."), len(x)), reverse=True)
+     return modules_to_reload_sorted
+
+
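The children-before-parents ordering matters because reloading a package re-binds its submodule attributes, so submodules must be fresh first. A minimal sketch of the consuming side (the loop is illustrative; "/kt/home" is a made-up synced directory):

    import importlib
    import sys

    for mod_name in collect_reload_modules("/kt/home"):
        importlib.reload(sys.modules[mod_name])  # children reload before their parent packages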
+ class StartupError(Exception):
+     pass
+
+
+ class PodTerminatedError(Exception):
+     def __init__(
+         self,
+         pod_name: str = "unknown",
+         reason: str = "Unknown",
+         status_code: int = 503,
+         events: List[dict] = None,
+     ):
+         """
+         events: List of dicts with keys:
+             - timestamp: datetime
+             - reason: str
+             - message: str
+
+         Sample event:
+         {
+             'timestamp': datetime.datetime(2025, 7, 13, 16, 45, 46, tzinfo=tzutc()),
+             'reason': 'Evicted',
+             'message': 'The node was low on resource: memory. Threshold quantity: 100Mi, available: 3404Ki.'
+         }
+         """
+         self.pod_name = pod_name
+         self.reason = reason
+         self.status_code = status_code
+         self.events = events or []
+         super().__init__(str(self))
+
+     def __getstate__(self):
+         """Serialize the exception state for transmission over HTTP."""
+         # Convert datetime objects to ISO format strings for JSON serialization
+         serialized_events = []
+         for event in self.events:
+             serialized_event = event.copy()
+             if "timestamp" in serialized_event:
+                 timestamp = serialized_event["timestamp"]
+                 # Convert datetime to string if needed
+                 if hasattr(timestamp, "isoformat"):
+                     serialized_event["timestamp"] = timestamp.isoformat()
+             serialized_events.append(serialized_event)
+
+         return {
+             "pod_name": self.pod_name,
+             "reason": self.reason,
+             "status_code": self.status_code,
+             "events": serialized_events,
+         }
+
+     def __setstate__(self, state):
+         """Reconstruct the exception from serialized state."""
+         self.pod_name = state["pod_name"]
+         self.reason = state["reason"]
+         self.status_code = state["status_code"]
+         self.events = state["events"]
+
+     @classmethod
+     def from_dict(cls, state):
+         """Reconstruct the exception from a dictionary state."""
+         return cls(
+             pod_name=state.get("pod_name", "unknown"),
+             reason=state.get("reason", "Unknown"),
+             status_code=state.get("status_code", 503),
+             events=state.get("events", []),
+         )
+
+     @property
+     def evicted(self) -> bool:
+         """True if the pod was evicted (e.g. node pressure, preemption)."""
+         return self.reason == "Evicted" or any("Evicted" in event["reason"] for event in self.events)
+
+     @property
+     def oom_killed(self) -> bool:
+         """True if the pod was killed due to OOM."""
+         return self.reason == "OOMKilled" or any("OOMKilled" in event["reason"] for event in self.events)
+
+     def __str__(self):
+         events_str = "\n".join(f"{e['timestamp']} {e['reason']}: {e['message']}" for e in self.events)
+         base_exc = f"\nPod Name: {self.pod_name}\nReason: {self.reason}\nStatus Code: {self.status_code}\n"
+         if self.events:
+             base_exc += f"Recent Events:\n{events_str}"
+         return base_exc
+
+
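Because __getstate__ and from_dict speak plain JSON-safe dicts, a PodTerminatedError can cross the HTTP boundary without pickle. A small illustrative round trip (the event values are fabricated):

    import datetime
    import json

    err = PodTerminatedError(
        pod_name="worker-0",
        reason="Evicted",
        events=[{"timestamp": datetime.datetime.now(datetime.timezone.utc),
                 "reason": "Evicted", "message": "node memory pressure"}],
    )
    payload = json.dumps(err.__getstate__())  # timestamps become ISO strings
    rebuilt = PodTerminatedError.from_dict(json.loads(payload))
    assert rebuilt.evicted and rebuilt.status_code == 503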
+ class WorkerMembershipChanged(Exception):
+     """Raised when worker pods are added or removed during distributed execution."""
+
+     def __init__(
+         self,
+         added_ips: set = None,
+         removed_ips: set = None,
+         previous_ips: set = None,
+         current_ips: set = None,
+         message: str = None,
+     ):
+         # Support both explicit construction and reconstruction from message
+         if message and not (added_ips or removed_ips):
+             import ast
+
+             # Reconstruct from the message (re is already imported at module level)
+             self.added_ips = set()
+             self.removed_ips = set()
+             self.previous_ips = set()
+             self.current_ips = set()
+
+             if "removed during execution:" in message:
+                 match = re.search(r"removed during execution: ({.*?})", message)
+                 if match:
+                     self.removed_ips = ast.literal_eval(match.group(1))
+             elif "added during execution:" in message:
+                 match = re.search(r"added during execution: ({.*?})", message)
+                 if match:
+                     self.added_ips = ast.literal_eval(match.group(1))
+         else:
+             # Normal construction
+             self.added_ips = added_ips or set()
+             self.removed_ips = removed_ips or set()
+             self.previous_ips = previous_ips or set()
+             self.current_ips = current_ips or set()
+
+             if removed_ips:
+                 message = f"Critical: {len(removed_ips)} worker(s) removed during execution: {removed_ips}"
+             elif added_ips:
+                 message = f"Warning: {len(added_ips)} worker(s) added during execution: {added_ips}"
+             else:
+                 message = "Worker membership changed"
+
+         super().__init__(message)
+
+     @property
+     def is_critical(self) -> bool:
+         """Returns True if workers were removed (critical for training)."""
+         return bool(self.removed_ips)
+
+     def __getstate__(self):
+         """Serialize the exception state."""
+         return {
+             "message": str(self),
+             "added_ips": list(self.added_ips),
+             "removed_ips": list(self.removed_ips),
+             "previous_ips": list(self.previous_ips),
+             "current_ips": list(self.current_ips),
+         }
+
+     @classmethod
+     def from_dict(cls, data):
+         """Reconstruct from serialized state."""
+         return cls(
+             added_ips=set(data.get("added_ips", [])),
+             removed_ips=set(data.get("removed_ips", [])),
+             previous_ips=set(data.get("previous_ips", [])),
+             current_ips=set(data.get("current_ips", [])),
+         )
+
+
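Note the asymmetric constructor: passing only message re-parses the IP sets out of the text via re and ast.literal_eval, which lets the exception survive a stringly-typed hop. A quick illustrative check:

    exc = WorkerMembershipChanged(removed_ips={"10.0.0.5"})
    rebuilt = WorkerMembershipChanged(message=str(exc))  # parses {'10.0.0.5'} back out
    assert rebuilt.removed_ips == {"10.0.0.5"} and rebuilt.is_critical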
+ class StreamType(str, enum.Enum):
+     CLI = "cli"
+     HTTP_CLIENT = "http_client"
+
+
+ def clean_and_validate_k8s_name(name: str, allow_full_length: bool = True) -> str:
+     """Clean and validate a name for K8s compatibility.
+
+     Args:
+         name: The name to clean and validate
+         allow_full_length: If True, allows and intelligently trims full pod names to 63 chars,
+             preserving k8s-generated portions.
+             If False, limits to 40 chars to leave room for k8s suffixes.
+     """
+     max_k8s_name_length = 63  # max length allowed by k8s
+     max_base_name_length = 40  # max module name length to account for added k8s suffixes
+     # Drop characters disallowed in k8s service names, then any leading/trailing hyphens
+     cleaned_name = re.sub(r"[^a-z0-9-]", "", name.lower()).strip("-")
+     if not cleaned_name:
+         raise ValueError("Name must contain at least one alphanumeric character.")
+
+     max_length = max_k8s_name_length if allow_full_length else max_base_name_length
+
+     if len(cleaned_name) > max_length:
+         if not allow_full_length:
+             # For a user-provided module name, raise an exception
+             error_msg = (
+                 f"Name length {len(cleaned_name)} exceeds {max_length} characters. "
+                 "Must leave room for Kubernetes-added suffixes."
+             )
+             raise ValueError(error_msg)
+
+         match = re.search(r"(-\d+)?-deployment-[a-z0-9]+-[a-z0-9]+", cleaned_name)
+         if match:
+             k8s_part = match.group(0)
+             k8s_start_idx = match.start()
+
+             prefix = cleaned_name[:k8s_start_idx]
+             suffix = cleaned_name[k8s_start_idx + len(k8s_part) :]
+
+             total_excess = len(cleaned_name) - max_length
+
+             # If we need to trim, handle each part
+             if total_excess > 0:
+                 # Handle prefix trimming
+                 if prefix:
+                     segments = prefix.split("-")
+                     while len("-".join(segments)) + len(k8s_part) + len(suffix) > max_length:
+                         if len(segments) > 1:
+                             segments.pop()
+                         else:
+                             segments[0] = segments[0][:-1]
+                     prefix = "-".join(segments)
+
+                 # Handle suffix trimming if still needed
+                 remaining_length = max_length - (len(prefix) + len(k8s_part))
+                 if remaining_length > 0:
+                     suffix_segments = suffix.split("-")
+                     clean_segments = []
+                     current_length = 0
+                     for seg in suffix_segments:
+                         # Only add a segment if it's at least 2 chars so the name doesn't look cut off
+                         if len(seg) >= 2 and current_length + len(seg) + 1 <= remaining_length:
+                             clean_segments.append(seg)
+                             current_length += len(seg) + 1
+                     suffix = "-".join(clean_segments)
+                 else:
+                     suffix = ""
+
+             cleaned_name = (prefix + "-" if prefix else "") + k8s_part + ("-" + suffix if suffix else "")
+
+     return cleaned_name
+
+
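Doctest-style examples of the cleaning behavior (names are made up): short names are lower-cased and stripped of disallowed characters, while over-long user-supplied names are rejected when room must be left for Kubernetes suffixes.

    assert clean_and_validate_k8s_name("My_Trainer!") == "mytrainer"
    # Raises ValueError (41 chars > 40 when allow_full_length=False):
    # clean_and_validate_k8s_name("a" * 41, allow_full_length=False)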
+ def is_running_in_kubernetes():
+     """
+     Determines if the current Python process is running inside a Kubernetes pod.
+
+     Returns:
+         bool: True if running in Kubernetes, False otherwise
+     """
+     # Method 1: Check for Kubernetes service environment variables
+     if os.environ.get("KUBERNETES_SERVICE_HOST") is not None:
+         return True
+
+     # Method 2: Check for the existence of the Kubernetes service account token file
+     if os.path.exists("/var/run/secrets/kubernetes.io/serviceaccount/token"):
+         return True
+
+     return False
+
+
+ def _get_rendered_template(template_file: str, template_dir: str, **template_vars) -> str:
+     """Helper function to set up and render a template."""
+     template_loader = jinja2.FileSystemLoader(searchpath=template_dir)
+     template_env = jinja2.Environment(
+         loader=template_loader,
+         keep_trailing_newline=True,
+         trim_blocks=True,
+         lstrip_blocks=True,
+         enable_async=False,
+         autoescape=False,
+     )
+     template = template_env.get_template(template_file)
+     return template.render(**template_vars)
+
+
+ def load_template(template_file: str, template_dir: str, **template_vars) -> dict:
+     """Load and render a single-document YAML template."""
+     rendered = _get_rendered_template(template_file, template_dir, **template_vars)
+     return yaml.safe_load(rendered)
+
+
+ def load_multi_yaml_template(template_file: str, template_dir: str, **template_vars) -> dict:
+     """Load and render a multi-document YAML template."""
+     rendered = _get_rendered_template(template_file, template_dir, **template_vars)
+     return {"items": list(yaml.safe_load_all(rendered))}
+
+
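These helpers back the Jinja templates under kubetorch/serving/templates listed above (pod_template.yaml and friends). A hedged sketch of the call shape, with made-up template variables rather than the templates' real ones:

    spec = load_template(
        "pod_template.yaml",
        template_dir="kubetorch/serving/templates",
        name="demo",          # illustrative variables; the real template
        image="python:3.11",  # defines its own required fields
    )
    assert isinstance(spec, dict)  # a single rendered YAML document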
+ def generate_unique_request_id(endpoint: str, timestamp: str) -> str:
+     """Generates a unique request id, based on the method/function endpoint and the call timestamp."""
+     raw = f"{endpoint}_{timestamp}"
+     unique_id = hashlib.sha256(raw.encode()).hexdigest()[:10]
+     return unique_id
+
+
+ def print_log_stream_client(message, last_timestamp, print_pod_name: bool = False):
+     formatter = ServerLogsFormatter()
+     if message.get("streams"):
+         for stream in message["streams"]:
+             pod_name = f'({stream.get("stream").get("pod")}) ' if print_pod_name else ""
+             for value in stream["values"]:
+                 # Skip if we've already seen this timestamp
+                 if last_timestamp is not None and value[0] <= last_timestamp:
+                     continue
+                 last_timestamp = value[0]
+
+                 log_line = json.loads(value[1])
+                 log_name = log_line.get("name")
+                 if log_name == "print_redirect":
+                     text = log_line.get("message")
+                     print(f"{pod_name}{formatter.start_color}{text}{formatter.reset_color}")
+                 elif log_name != "uvicorn.access":
+                     formatted_log = (
+                         f"{pod_name}{log_line.get('asctime')} | {log_line.get('levelname')} | {log_line.get('message')}"
+                     )
+                     print(f"{formatter.start_color}{formatted_log}{formatter.reset_color}")
+     return last_timestamp
+
+
+ def print_log_stream_cli(message, last_timestamp, print_pod_name: bool = False):
+     if message.get("streams"):
+         for stream in message["streams"]:
+             pod_name = f'({stream.get("stream").get("pod")}) ' if print_pod_name else ""
+             for value in stream["values"]:
+                 # Skip if we've already seen this timestamp
+                 if last_timestamp is not None and value[0] <= last_timestamp:
+                     continue
+                 last_timestamp = value[0]
+                 log_line = value[1]
+                 try:
+                     log_line = json.loads(log_line)
+                     log_name = log_line.get("name")
+                     if log_name == "print_redirect":
+                         # Print output is handled by print_log_stream_client. We need the
+                         # "print_redirect" log type only for log streaming in the http client,
+                         # so we can filter out the print outputs for a specific request ID.
+                         # For the CLI --follow option, we print all logs, so at the moment
+                         # we don't need to filter by request_id.
+                         continue
+                     elif log_name != "uvicorn.access":
+                         formatted_log = f"{pod_name}{log_line.get('asctime')} | {log_line.get('levelname')} | {log_line.get('message')}".strip()
+                         print(formatted_log)
+                 except json.JSONDecodeError:
+                     print(log_line.strip())
+
+     return last_timestamp
+
+
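Both printers consume Loki's tail-response shape: a "streams" list whose entries carry a label set ("stream") and "values" pairs of [timestamp_ns, line]. A fabricated payload makes the parsing concrete:

    loki_msg = {
        "streams": [{
            "stream": {"pod": "worker-0"},  # Loki labels
            "values": [["1720000000000000000", json.dumps(
                {"name": "kubetorch", "asctime": "2025-07-13 16:45:46",
                 "levelname": "INFO", "message": "step 10 complete"})]],
        }]
    }
    last_ts = print_log_stream_cli(loki_msg, last_timestamp=None, print_pod_name=True)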
+ async def stream_logs_websocket_helper(
+     uri,
+     stop_event,
+     stream_type: StreamType = StreamType.HTTP_CLIENT,
+     print_pod_name: bool = False,
+ ):
+     """Stream logs using Loki's websocket tail endpoint."""
+     websocket = None
+     try:
+         # Track the last timestamp we've seen to avoid duplicates
+         last_timestamp = None
+         # Track when we should stop
+         stop_time = None
+
+         # Add timeouts to prevent hanging connections
+         websocket = await websockets.connect(
+             uri,
+             close_timeout=10,  # Max time to wait for close handshake
+             ping_interval=20,  # Send ping every 20 seconds
+             ping_timeout=10,  # Wait 10 seconds for pong
+         )
+         try:
+             while True:
+                 # If the stop event is set, start counting down
+                 if stop_event.is_set() and stop_time is None:
+                     stop_time = time.time() + 2  # 2 seconds grace period
+
+                 # If we're past the grace period, exit
+                 if stop_time is not None and time.time() > stop_time:
+                     break
+
+                 try:
+                     # Use a shorter timeout during the grace period
+                     timeout = 0.1 if stop_time is not None else 1.0
+                     message = await asyncio.wait_for(websocket.recv(), timeout=timeout)
+                     try:
+                         message = json.loads(message)
+                     except json.JSONDecodeError:
+                         pass
+
+                     if stream_type == StreamType.HTTP_CLIENT:
+                         last_timestamp = print_log_stream_client(message, last_timestamp, print_pod_name)
+                     elif stream_type == StreamType.CLI:
+                         last_timestamp = print_log_stream_cli(message, last_timestamp, print_pod_name)
+                 except asyncio.TimeoutError:
+                     # Timeout is expected, just continue the loop
+                     continue
+                 except websockets.exceptions.ConnectionClosed as e:
+                     logger.debug(f"WebSocket connection closed: {str(e)}")
+                     break
+         finally:
+             if websocket:
+                 try:
+                     # Use wait_for to prevent hanging on close
+                     await asyncio.wait_for(websocket.close(), timeout=1.0)
+                 except Exception:
+                     pass
+     except Exception as e:
+         logger.error(f"Error in websocket stream: {e}")
+     finally:
+         # Ensure the websocket is closed even if we didn't enter the try block
+         if websocket:
+             try:
+                 # Use wait_for to prevent hanging on close
+                 await asyncio.wait_for(websocket.close(), timeout=1.0)
+             except Exception:
+                 pass
+
+
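A sketch of driving the tail from a caller: run the coroutine on an event loop and trip the asyncio.Event to end streaming after the two-second grace period. The Loki tail URI is illustrative:

    async def tail_for(seconds: float):
        stop = asyncio.Event()
        task = asyncio.create_task(stream_logs_websocket_helper(
            'ws://loki.example:3100/loki/api/v1/tail?query={pod="worker-0"}',
            stop, stream_type=StreamType.CLI))
        await asyncio.sleep(seconds)
        stop.set()  # the helper exits roughly 2 seconds later
        await task

    # asyncio.run(tail_for(30))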
+ def clear_debugging_sessions():
+     """Clear any existing debugging sessions when a module is redeployed or a pod is terminated."""
+     try:
+         import web_pdb
+
+         if web_pdb.WebPdb.active_instance is not None:
+             logger.info("Clearing existing debugging session")
+             try:
+                 web_pdb.WebPdb.active_instance.remove_trace()
+             except Exception as e:
+                 logger.warning(f"Error removing trace: {e}")
+             web_pdb.WebPdb.active_instance = None
+
+     except ImportError:
+         # web_pdb not installed, nothing to clean up
+         pass
+     except Exception as e:
+         logger.warning(f"Error clearing debugging session: {e}")
+
+
+ # Register the cleanup function to run at exit
+ atexit.register(clear_debugging_sessions)
+
+
+ def deep_breakpoint(debug_port: int = DEFAULT_DEBUG_PORT):
+     """
+     Similar to Python's built-in `breakpoint()`, but can be used deep inside distributed code. For SPMD-style
+     distributed code like PyTorch, be sure to only call this from one process (e.g. the rank 0 process) to avoid
+     blocking all processes in the distributed group.
+     """
+     # Check if web_pdb is installed; if not, install it
+     try:
+         import web_pdb
+     except ImportError:
+         install_cmd = "uv pip install --system web-pdb"
+
+         print("Pdb debugger not found, installing it...")
+         # Run the install command and propagate logs
+         subprocess.run(install_cmd, shell=True, check=True, text=True)
+         print("Pdb installed successfully.")
+
+     print("Distributed breakpoint activated. To attach a debugger, run the following command:")
+     print(f"kt debug {os.environ['POD_NAME']} --port {debug_port} --namespace {os.environ['POD_NAMESPACE']}")
+
+     import web_pdb
+
+     pdb = web_pdb.WebPdb.active_instance
+     try:
+         if pdb is None:
+             pdb = web_pdb.WebPdb(host="", port=debug_port, patch_stdstreams=False)
+         else:
+             # If the debugger is still attached, reset the trace to a new location
+             pdb.remove_trace()
+
+         # Set the frame to the caller's frame
+         pdb.set_trace(sys._getframe(1))  # pylint: disable=protected-access
+     except Exception as e:
+         # Only clean up if there was an error setting up the debugger
+         if pdb:
+             pdb.remove_trace()
+             web_pdb.WebPdb.active_instance = None
+         raise e
+
+
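Intended call pattern in distributed user code, per the docstring's rank-0 caveat; the rank check below is the usual torch.distributed idiom, not something this module enforces:

    import torch.distributed as dist

    if dist.get_rank() == 0:
        deep_breakpoint()  # pauses rank 0; attach via the printed `kt debug ...` command
    dist.barrier()  # other ranks wait while rank 0 is debugged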
+ def wait_for_app_start(port, health_check: str, process: subprocess.Popen, timeout: int = 60):
+     """
+     Wait until the app is ready. If health_check is provided, will send HTTP requests to check; otherwise
+     will wait until something is listening on the port.
+     """
+     host = LOCALHOST
+     port = int(port)
+     logger.debug(f"Trying to connect to http://{host}:{port}{health_check or ''}")
+     start_time = time.time()
+
+     if health_check:
+         if not health_check.startswith("/"):
+             health_check = f"/{health_check}"
+         url = f"http://{LOCALHOST}:{port}{health_check}"
+         while time.time() - start_time < timeout:
+             if process.poll() is not None and process.poll() != 0:
+                 raise RuntimeError(f"App exited with code {process.poll()}")
+             try:
+                 response = httpx.get(url)
+                 if response.status_code == 200:
+                     return True
+             except httpx.ConnectError:
+                 pass
+             time.sleep(0.5)
+         raise TimeoutError(f"App did not become healthy on {url} within {timeout} seconds")
+     else:
+         # Fall back to a socket check
+         while time.time() - start_time < timeout:
+             if process.poll() is not None and process.poll() != 0:
+                 raise RuntimeError(f"App exited with code {process.poll()}")
+             with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+                 sock.settimeout(1)
+                 try:
+                     sock.connect((host, port))
+                     return True
+                 except (ConnectionRefusedError, socket.timeout):
+                     time.sleep(0.5)
+         raise TimeoutError(f"Failed to detect an open port {port} within {timeout} seconds")
+
+
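Typical usage, sketched: launch the app as a subprocess, then block until its health endpoint answers. The uvicorn command and endpoint are illustrative:

    proc = subprocess.Popen(["uvicorn", "myapp:app", "--port", "8000"])  # hypothetical app
    wait_for_app_start(8000, health_check="/health", process=proc, timeout=120)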
+ def _serialize_body(body: dict, serialization: str):
+     if body is None:
+         return {}
+
+     # We only serialize args and kwargs; other settings like "workers" and "restart_procs" are needed inside
+     # the http_server, outside the serialization boundary (e.g. the distributed processes),
+     # so we break them out here as separate params.
+     for kwarg in MAGIC_CALL_KWARGS:
+         if kwarg in body.get("kwargs", {}):
+             body[kwarg] = body["kwargs"].pop(kwarg)
+
+     if serialization == "pickle":
+         args_data = {"args": body.pop("args", []), "kwargs": body.pop("kwargs", {})}
+         pickled_args = pickle.dumps(args_data)
+         encoded_args = base64.b64encode(pickled_args).decode("utf-8")
+         body["data"] = encoded_args
+     return body
+
+
+ def _deserialize_response(response, serialization: str):
+     if serialization == "pickle":
+         response_data = response.json()
+         if isinstance(response_data, list):
+             # If this is a response from an spmd call, it's a list of serialized dicts
+             unpickled_results = []
+             for resp in response_data:
+                 if "data" in resp:
+                     encoded_result = resp["data"]
+                     pickled_result = base64.b64decode(encoded_result.encode("utf-8"))
+                     resp = pickle.loads(pickled_result)
+                 unpickled_results.append(resp)
+             return unpickled_results
+         if "data" in response_data:
+             encoded_result = response_data["data"]
+             pickled_result = base64.b64decode(encoded_result.encode("utf-8"))
+             return pickle.loads(pickled_result)
+         return response_data
+     return response.json()
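To make the serialization boundary concrete, a sketch of the round trip with serialization="pickle" (Resp is a hypothetical stand-in for an httpx response; only .json() is used):

    body = _serialize_body({"args": [1, 2], "kwargs": {"x": 3, "workers": "all"}}, "pickle")
    assert body["workers"] == "all" and "data" in body  # magic kwargs stay outside the pickle

    class Resp:  # hypothetical response object
        def __init__(self, payload): self._p = payload
        def json(self): return self._p

    result = _deserialize_response(Resp({"data": body["data"]}), "pickle")
    assert result == {"args": [1, 2], "kwargs": {"x": 3}}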