kubetorch 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kubetorch might be problematic. Click here for more details.

Files changed (93):
  1. kubetorch/__init__.py +60 -0
  2. kubetorch/cli.py +1985 -0
  3. kubetorch/cli_utils.py +1025 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +285 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +157 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +133 -0
  30. kubetorch/resources/callables/module.py +1416 -0
  31. kubetorch/resources/callables/utils.py +174 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +261 -0
  34. kubetorch/resources/compute/compute.py +2596 -0
  35. kubetorch/resources/compute/decorators.py +139 -0
  36. kubetorch/resources/compute/rbac.py +74 -0
  37. kubetorch/resources/compute/utils.py +1114 -0
  38. kubetorch/resources/compute/websocket.py +137 -0
  39. kubetorch/resources/images/__init__.py +1 -0
  40. kubetorch/resources/images/image.py +414 -0
  41. kubetorch/resources/images/images.py +74 -0
  42. kubetorch/resources/secrets/__init__.py +2 -0
  43. kubetorch/resources/secrets/kubernetes_secrets_client.py +412 -0
  44. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  45. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  46. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  47. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  48. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  49. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  50. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  51. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  52. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  53. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  54. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  55. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  56. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  57. kubetorch/resources/secrets/provider_secrets/providers.py +93 -0
  58. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  59. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  60. kubetorch/resources/secrets/secret.py +238 -0
  61. kubetorch/resources/secrets/secret_factory.py +70 -0
  62. kubetorch/resources/secrets/utils.py +209 -0
  63. kubetorch/resources/volumes/__init__.py +0 -0
  64. kubetorch/resources/volumes/volume.py +365 -0
  65. kubetorch/servers/__init__.py +0 -0
  66. kubetorch/servers/http/__init__.py +0 -0
  67. kubetorch/servers/http/distributed_utils.py +3223 -0
  68. kubetorch/servers/http/http_client.py +730 -0
  69. kubetorch/servers/http/http_server.py +1788 -0
  70. kubetorch/servers/http/server_metrics.py +278 -0
  71. kubetorch/servers/http/utils.py +728 -0
  72. kubetorch/serving/__init__.py +0 -0
  73. kubetorch/serving/autoscaling.py +173 -0
  74. kubetorch/serving/base_service_manager.py +363 -0
  75. kubetorch/serving/constants.py +83 -0
  76. kubetorch/serving/deployment_service_manager.py +478 -0
  77. kubetorch/serving/knative_service_manager.py +519 -0
  78. kubetorch/serving/raycluster_service_manager.py +582 -0
  79. kubetorch/serving/service_manager.py +18 -0
  80. kubetorch/serving/templates/deployment_template.yaml +17 -0
  81. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  82. kubetorch/serving/templates/kt_setup_template.sh.j2 +81 -0
  83. kubetorch/serving/templates/pod_template.yaml +194 -0
  84. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  85. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  86. kubetorch/serving/templates/service_template.yaml +21 -0
  87. kubetorch/serving/templates/workerset_template.yaml +36 -0
  88. kubetorch/serving/utils.py +377 -0
  89. kubetorch/utils.py +284 -0
  90. kubetorch-0.2.0.dist-info/METADATA +121 -0
  91. kubetorch-0.2.0.dist-info/RECORD +93 -0
  92. kubetorch-0.2.0.dist-info/WHEEL +4 -0
  93. kubetorch-0.2.0.dist-info/entry_points.txt +5 -0
@@ -0,0 +1,728 @@
1
+ import asyncio
2
+ import atexit
3
+ import base64
4
+ import enum
5
+ import hashlib
6
+ import json
7
+ import os
8
+ import pickle
9
+ import re
10
+ import socket
11
+ import subprocess
12
+ import sys
13
+ import time
14
+ from contextvars import ContextVar
15
+ from typing import List
16
+
17
+ import httpx
18
+
19
+ import jinja2
20
+ import websockets
21
+ import yaml
22
+
23
+ from kubetorch.constants import LOCALHOST
24
+ from kubetorch.logger import get_logger
25
+ from kubetorch.serving.constants import DEFAULT_DEBUG_PORT
26
+ from kubetorch.utils import ServerLogsFormatter
27
+
28
logger = get_logger(__name__)

# Standard rsync daemon port, used when syncing code into pods.
RSYNC_PORT = 873

# Serialization format accepted by default for call payloads.
DEFAULT_ALLOWED_SERIALIZATION = "json"

# Call kwargs that are handled by the HTTP server itself rather than being
# serialized and forwarded to the user's function (see _serialize_body).
MAGIC_CALL_KWARGS = ["workers", "restart_procs"]

# Base dictConfig-style logging configuration; formatters/handlers are filled
# in at runtime, and the known loggers propagate to root at INFO level.
LOG_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {},
    "handlers": {},
    "root": {"level": "INFO", "handlers": []},
    "loggers": {
        "uvicorn": {"level": "INFO", "handlers": [], "propagate": True},
        "uvicorn.access": {"level": "INFO", "handlers": [], "propagate": True},
        "kubetorch": {"level": "INFO", "handlers": [], "propagate": True},
    },
}
+
49
+
50
def ensure_structured_logging():
    """Add our structured JSON handler to all loggers without removing user's handlers. We do this both when we
    set up the HTTP server and also after re-importing user code, as their modules might include logging setup
    of their own."""
    import logging
    import logging.handlers
    import os
    import sys

    from pythonjsonlogger import jsonlogger

    # First ensure logging is initialized - this is crucial!
    # If no handlers exist, we need to initialize the logging system
    root_logger = logging.getLogger()

    # Create our JSON formatter; request_id and pod are injected by the
    # _ContextFieldsFilter below so every record can be formatted.
    json_formatter = jsonlogger.JsonFormatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(request_id)s - %(pod)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Create our structured handler (we keep using sys.stdout so user and kt logs
    # both appear in pod logs; our stdout wrapper will mirror to the original stream)
    structured_handler = logging.StreamHandler(sys.stdout)
    structured_handler.setFormatter(json_formatter)
    structured_handler.name = "kubetorch_structured"  # Name it so we can identify it

    # Set root logger level based on KT_LOG_LEVEL if it's set
    kt_log_level = os.getenv("KT_LOG_LEVEL")
    if kt_log_level:
        kt_log_level = kt_log_level.upper()
        # Falls back to INFO if KT_LOG_LEVEL isn't a valid level name
        root_logger.setLevel(getattr(logging, kt_log_level, logging.INFO))

    # Check if our handler is already there (to avoid adding duplicates)
    existing_structured = None
    for h in root_logger.handlers:
        if getattr(h, "name", None) == "kubetorch_structured":
            existing_structured = h
            break

    if not existing_structured:
        # Add our structured handler alongside any user-installed handlers
        # so both formats are emitted to pod logs
        root_logger.addHandler(structured_handler)

    # Ensure request context fields are attached to all records even if the user
    # reconfigured logging and removed our filters. Do this idempotently.
    class _ContextFieldsFilter(logging.Filter):
        # Backfills request_id (from the request contextvar) and pod name on
        # records that don't already carry them, so formatting never fails.
        def filter(self, record):
            if not hasattr(record, "request_id") or record.request_id in (None, "-"):
                try:
                    record.request_id = request_id_ctx_var.get("-")
                except Exception:
                    record.request_id = "-"
            if not hasattr(record, "pod") or record.pod in (None, ""):
                record.pod = os.getenv("POD_NAME", "unknown-pod")
            return True

    # Attach the filter to root and all of its handlers (idempotent: duplicate adds are ignored)
    context_filter = _ContextFieldsFilter()
    try:
        root_logger.addFilter(context_filter)
    except Exception:
        pass
    for h in root_logger.handlers:
        try:
            h.addFilter(context_filter)
        except Exception:
            pass

    # Ensure print_redirect logger also has proper configuration
    # This is important for the StreamToLogger output
    print_logger = logging.getLogger("print_redirect")
    print_logger.setLevel(logging.INFO)
    # Ensure it propagates to root so the structured handler formats it
    print_logger.propagate = True
    try:
        print_logger.addFilter(context_filter)
    except Exception:
        pass
+
131
+
132
# Per-request correlation id; "-" is the default when no request is in flight.
# Read by the structured-logging filter in ensure_structured_logging.
request_id_ctx_var: ContextVar[str] = ContextVar("request_id", default="-")
+
134
+
135
class StartupError(Exception):
    """Raised for failures during startup."""

    pass
+
138
+
139
class PodTerminatedError(Exception):
    """Raised when the pod backing a call is terminated (e.g. evicted or OOM-killed).

    Carries the pod name, termination reason, an HTTP-style status code, and any
    recent Kubernetes events, and supports pickle (__getstate__/__setstate__) and
    dict (from_dict) round-trips so it can be transmitted over HTTP and
    reconstructed on the client side.
    """

    def __init__(
        self,
        pod_name: str = "unknown",
        reason: str = "Unknown",
        status_code: int = 503,
        events: List[dict] = None,
    ):
        """
        Args:
            pod_name: Name of the terminated pod.
            reason: Termination reason (e.g. "Evicted", "OOMKilled").
            status_code: HTTP status code to surface to the caller.
            events: List of dicts with keys:
                - timestamp: datetime
                - reason: str
                - message: str

            sample event:
            {
                'timestamp': datetime.datetime(2025, 7, 13, 16, 45, 46, tzinfo=tzutc()),
                'reason': 'Evicted',
                'message': 'The node was low on resource: memory. Threshold quantity: 100Mi, available: 3404Ki.'
            }
        """
        self.pod_name = pod_name
        self.reason = reason
        self.status_code = status_code
        self.events = events or []
        super().__init__(str(self))

    def __getstate__(self):
        """Serialize the exception state for transmission over HTTP."""
        # Convert datetime objects to ISO format strings for JSON serialization
        serialized_events = []
        for event in self.events:
            serialized_event = event.copy()
            timestamp = serialized_event.get("timestamp")
            # Convert datetime to string if needed
            if hasattr(timestamp, "isoformat"):
                serialized_event["timestamp"] = timestamp.isoformat()
            serialized_events.append(serialized_event)

        return {
            "pod_name": self.pod_name,
            "reason": self.reason,
            "status_code": self.status_code,
            "events": serialized_events,
        }

    def __setstate__(self, state):
        """Reconstruct the exception from serialized state."""
        self.pod_name = state["pod_name"]
        self.reason = state["reason"]
        self.status_code = state["status_code"]
        self.events = state["events"]

    @classmethod
    def from_dict(cls, state):
        """Reconstruct the exception from a dictionary state."""
        return cls(
            pod_name=state.get("pod_name", "unknown"),
            reason=state.get("reason", "Unknown"),
            status_code=state.get("status_code", 503),
            events=state.get("events", []),
        )

    @property
    def evicted(self) -> bool:
        """True if pod was evicted (ex: node pressure, preemption)."""
        # Use .get() so malformed events (missing "reason") don't raise KeyError.
        return self.reason == "Evicted" or any(
            "Evicted" in event.get("reason", "") for event in self.events
        )

    @property
    def oom_killed(self) -> bool:
        """True if pod was evicted due to OOM."""
        return self.reason == "OOMKilled" or any(
            "OOMKilled" in event.get("reason", "") for event in self.events
        )

    def __str__(self):
        # Tolerate events missing keys: this runs during __init__ (via
        # super().__init__(str(self))), so a KeyError here would mask the
        # original termination information entirely.
        events_str = "\n".join(
            f"{e.get('timestamp', '')} {e.get('reason', '')}: {e.get('message', '')}"
            for e in self.events
        )
        base_exc = (
            f"\nPod Name: {self.pod_name}\n"
            f"Reason: {self.reason}\n"
            f"Status Code: {self.status_code}\n"
        )
        if self.events:
            base_exc += f"Recent Events:\n{events_str}"
        return base_exc
+
230
+
231
class WorkerMembershipChanged(Exception):
    """Raised when worker pods are added or removed during distributed execution."""

    def __init__(
        self,
        added_ips: set = None,
        removed_ips: set = None,
        previous_ips: set = None,
        current_ips: set = None,
        message: str = None,
    ):
        """
        Args:
            added_ips: IPs of workers that joined mid-execution.
            removed_ips: IPs of workers that left mid-execution.
            previous_ips: Full worker membership before the change.
            current_ips: Full worker membership after the change.
            message: If provided (and no explicit added/removed sets), the
                added/removed sets are parsed back out of the message text -
                used when reconstructing the exception from serialized form.
        """
        # Support both explicit construction and reconstruction from message
        if message and not (added_ips or removed_ips):
            import ast

            # Reconstruct from message (module-level `re` is used for parsing)
            self.added_ips = set()
            self.removed_ips = set()
            # Fix: preserve explicitly-passed membership snapshots instead of
            # silently dropping them during reconstruction.
            self.previous_ips = previous_ips or set()
            self.current_ips = current_ips or set()

            if "removed during execution:" in message:
                match = re.search(r"removed during execution: ({.*?})", message)
                if match:
                    # literal_eval safely parses the repr of the original set
                    self.removed_ips = ast.literal_eval(match.group(1))
            elif "added during execution:" in message:
                match = re.search(r"added during execution: ({.*?})", message)
                if match:
                    self.added_ips = ast.literal_eval(match.group(1))
        else:
            # Normal construction
            self.added_ips = added_ips or set()
            self.removed_ips = removed_ips or set()
            self.previous_ips = previous_ips or set()
            self.current_ips = current_ips or set()

            if removed_ips:
                message = f"Critical: {len(removed_ips)} worker(s) removed during execution: {removed_ips}"
            elif added_ips:
                message = f"Warning: {len(added_ips)} worker(s) added during execution: {added_ips}"
            else:
                message = "Worker membership changed"

        super().__init__(message)

    @property
    def is_critical(self) -> bool:
        """Returns True if workers were removed (critical for training)."""
        return bool(self.removed_ips)

    def __getstate__(self):
        """Serialize the exception state."""
        return {
            "message": str(self),
            "added_ips": list(self.added_ips),
            "removed_ips": list(self.removed_ips),
            "previous_ips": list(self.previous_ips),
            "current_ips": list(self.current_ips),
        }

    @classmethod
    def from_dict(cls, data):
        """Reconstruct from serialized state."""
        return cls(
            added_ips=set(data.get("added_ips", [])),
            removed_ips=set(data.get("removed_ips", [])),
            previous_ips=set(data.get("previous_ips", [])),
            current_ips=set(data.get("current_ips", [])),
        )
+
303
+
304
class StreamType(str, enum.Enum):
    """Destination of a log stream: the CLI follower or the Python HTTP client."""

    CLI = "cli"
    HTTP_CLIENT = "http_client"
+
308
+
309
def clean_and_validate_k8s_name(name: str, allow_full_length: bool = True) -> str:
    """Clean and validate a name for K8s compatibility.

    Args:
        name: The name to clean and validate
        allow_full_length: If True, allows and intelligently trims full pod names to 63 chars,
            preserving k8s-generated portions.
            If False, limits to 40 chars to leave room for k8s suffixes.

    Returns:
        The cleaned, lowercase, hyphen-trimmed name.

    Raises:
        ValueError: If no valid characters remain, or the name exceeds the
            allowed length while allow_full_length is False.
    """
    max_k8s_name_length = 63  # max length allowed by k8s
    max_base_name_length = (
        40  # max module name length to account for added k8s suffixes
    )
    # Keep only lowercase alphanumerics and hyphens, then strip *all* leading
    # and trailing hyphens. (The previous regex `^[-]|[-]$` removed only a
    # single leading/trailing hyphen, so inputs like "--abc" could yield the
    # invalid name "-abc".)
    cleaned_name = re.sub(r"[^a-z0-9-]", "", name.lower()).strip("-")
    if not cleaned_name:
        raise ValueError("Name must contain at least one alphanumeric character.")

    max_length = max_k8s_name_length if allow_full_length else max_base_name_length

    if len(cleaned_name) > max_length:
        if not allow_full_length:
            # For a user provided module name, raise an exception
            error_msg = (
                f"Name length {len(cleaned_name)} exceeds {max_length} characters. "
                "Must leave room for Kubernetes-added suffixes."
            )
            raise ValueError(error_msg)

        # Look for a k8s-generated portion (e.g. "-0-deployment-abc12-xyz34")
        # so we can preserve it and trim the user-provided parts around it.
        match = re.search(r"(-\d+)?-deployment-[a-z0-9]+-[a-z0-9]+", cleaned_name)
        if match:
            k8s_part = match.group(0)
            k8s_start_idx = match.start()

            prefix = cleaned_name[:k8s_start_idx]
            suffix = cleaned_name[k8s_start_idx + len(k8s_part) :]

            total_excess = len(cleaned_name) - max_length

            # If we need to trim, handle each part
            if total_excess > 0:
                # Handle prefix trimming: drop whole segments first, then
                # shave characters off the last remaining segment
                if prefix:
                    segments = prefix.split("-")
                    while (
                        len("-".join(segments)) + len(k8s_part) + len(suffix)
                        > max_length
                    ):
                        if len(segments) > 1:
                            segments.pop()
                        else:
                            segments[0] = segments[0][:-1]
                    prefix = "-".join(segments)

                # Handle suffix trimming if still needed
                remaining_length = max_length - (len(prefix) + len(k8s_part))
                if remaining_length > 0:
                    suffix_segments = suffix.split("-")
                    clean_segments = []
                    current_length = 0
                    for seg in suffix_segments:
                        # Only add segment if it's at least 2 chars so the name doesn't look cut off
                        if (
                            len(seg) >= 2
                            and current_length + len(seg) + 1 <= remaining_length
                        ):
                            clean_segments.append(seg)
                            current_length += len(seg) + 1
                    suffix = "-".join(clean_segments)
                else:
                    suffix = ""

            cleaned_name = (
                (prefix + "-" if prefix else "")
                + k8s_part
                + ("-" + suffix if suffix else "")
            )

    return cleaned_name
+
389
+
390
def is_running_in_kubernetes():
    """Return True when this process appears to be running inside a Kubernetes pod.

    Two signals are checked: the service-host env var Kubernetes injects into
    every pod, and the mounted service-account token file.

    Returns:
        bool: True if running in Kubernetes, False otherwise
    """
    # Signal 1: Kubernetes service environment variable
    if "KUBERNETES_SERVICE_HOST" in os.environ:
        return True

    # Signal 2: the mounted service account token file
    token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"
    return os.path.exists(token_path)
+
407
+
408
def _get_rendered_template(
    template_file: str, template_dir: str, **template_vars
) -> str:
    """Render *template_file* from *template_dir* with the given variables.

    Uses a synchronous Jinja2 environment with block trimming/stripping enabled
    and autoescaping disabled (templates produce YAML/shell, not HTML), keeping
    the trailing newline so rendered files end cleanly.
    """
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(searchpath=template_dir),
        keep_trailing_newline=True,
        trim_blocks=True,
        lstrip_blocks=True,
        enable_async=False,
        autoescape=False,
    )
    return env.get_template(template_file).render(**template_vars)
+
424
+
425
def load_template(template_file: str, template_dir: str, **template_vars) -> dict:
    """Render a single-document YAML template and parse it into a dict."""
    return yaml.safe_load(
        _get_rendered_template(template_file, template_dir, **template_vars)
    )
+
430
+
431
def load_multi_yaml_template(
    template_file: str, template_dir: str, **template_vars
) -> dict:
    """Render a multi-document YAML template; return the docs under "items"."""
    documents = yaml.safe_load_all(
        _get_rendered_template(template_file, template_dir, **template_vars)
    )
    return {"items": list(documents)}
+
438
+
439
def generate_unique_request_id(endpoint: str, timestamp: str) -> str:
    """Generates a unique request id, based on the method/function endpoint and the call timestamp"""
    digest = hashlib.sha256(f"{endpoint}_{timestamp}".encode()).hexdigest()
    # First 10 hex chars are enough to correlate a request in the logs
    return digest[:10]
+
445
+
446
def print_log_stream_client(message, last_timestamp, print_pod_name: bool = False):
    """Print Loki log-stream entries for the HTTP client, de-duplicating by timestamp.

    Returns the timestamp of the last entry printed (or the incoming
    last_timestamp when nothing new is seen) so the caller can pass it back on
    the next poll.
    """
    formatter = ServerLogsFormatter()
    if not message.get("streams"):
        return last_timestamp

    for stream in message["streams"]:
        pod_prefix = f'({stream.get("stream").get("pod")}) ' if print_pod_name else ""
        for entry in stream["values"]:
            # Entries at or before the last seen timestamp were already printed
            if last_timestamp is not None and entry[0] <= last_timestamp:
                continue
            last_timestamp = entry[0]

            log_line = json.loads(entry[1])
            log_name = log_line.get("name")
            if log_name == "print_redirect":
                # Raw user print() output redirected through the server
                text = log_line.get("message")
                print(
                    f"{pod_prefix}{formatter.start_color}{text}{formatter.reset_color}"
                )
            elif log_name != "uvicorn.access":
                formatted_log = f"{pod_prefix}{log_line.get('asctime')} | {log_line.get('levelname')} | {log_line.get('message')}"
                print(
                    f"{formatter.start_color}{formatted_log}{formatter.reset_color}"
                )
    return last_timestamp
+
471
+
472
def print_log_stream_cli(message, last_timestamp, print_pod_name: bool = False):
    """Print Loki log-stream entries for the CLI `--follow` path.

    Args:
        message: Parsed Loki tail payload (dict with a "streams" list).
        last_timestamp: Timestamp of the last entry already printed; entries at
            or before it are skipped.
        print_pod_name: Prefix each line with "(<pod>) ".

    Returns:
        The timestamp of the last entry printed (or the incoming value).
    """
    if message.get("streams"):
        for stream in message["streams"]:
            pod_name = f'({stream.get("stream").get("pod")}) ' if print_pod_name else ""
            for value in stream["values"]:
                # Skip if we've already seen this timestamp
                if last_timestamp is not None and value[0] <= last_timestamp:
                    continue
                last_timestamp = value[0]
                log_line = value[1]
                try:
                    log_line = json.loads(log_line)
                    log_name = log_line.get("name")
                    if log_name == "print_redirect":
                        continue
                    # The print output is surfaced elsewhere. We need the "print_redirect"
                    # log type only for log streaming in the http client, so we could filter out
                    # the print outputs for a specific request ID. For the CLI --follow option, we
                    # print all logs, so at the moment we don't need to filter by request_id.
                    elif log_name != "uvicorn.access":
                        # Fix: dropped the stray leading "(" that doubled the pod-name
                        # parenthesis - pod_name already carries its own "(...)".
                        formatted_log = f"{pod_name}{log_line.get('asctime')} | {log_line.get('levelname')} | {log_line.get('message')}".strip()
                        print(formatted_log)
                except json.JSONDecodeError:
                    # Not JSON - print the raw line as-is
                    print(log_line.strip())

    return last_timestamp
+
499
+
500
async def stream_logs_websocket_helper(
    uri,
    stop_event,
    stream_type: StreamType = StreamType.HTTP_CLIENT,
    print_pod_name: bool = False,
):
    """Stream logs using Loki's websocket tail endpoint

    Args:
        uri: Loki websocket tail URI to connect to.
        stop_event: Event-like object; once set, streaming drains for a ~2s
            grace period and then exits.
        stream_type: Selects the output formatter (HTTP client vs CLI).
        print_pod_name: Prefix each printed line with the originating pod name.
    """
    websocket = None
    try:
        # Track the last timestamp we've seen to avoid duplicates
        last_timestamp = None
        # Track when we should stop
        stop_time = None

        # Add timeout to prevent hanging connections
        websocket = await websockets.connect(
            uri,
            close_timeout=10,  # Max time to wait for close handshake
            ping_interval=20,  # Send ping every 20 seconds
            ping_timeout=10,  # Wait 10 seconds for pong
        )
        try:
            while True:
                # If stop event is set, start counting down
                if stop_event.is_set() and stop_time is None:
                    stop_time = time.time() + 2  # 2 seconds grace period

                # If we're past the grace period, exit
                if stop_time is not None and time.time() > stop_time:
                    break

                try:
                    # Use shorter timeout during grace period
                    timeout = 0.1 if stop_time is not None else 1.0
                    message = await asyncio.wait_for(websocket.recv(), timeout=timeout)
                    try:
                        message = json.loads(message)
                    except json.JSONDecodeError:
                        # Keep the raw message; the printers handle dicts only,
                        # so non-JSON frames fall through harmlessly
                        message = message

                    if stream_type == StreamType.HTTP_CLIENT:
                        last_timestamp = print_log_stream_client(
                            message, last_timestamp, print_pod_name
                        )
                    elif stream_type == StreamType.CLI:
                        last_timestamp = print_log_stream_cli(
                            message, last_timestamp, print_pod_name
                        )
                except asyncio.TimeoutError:
                    # Timeout is expected, just continue the loop
                    continue
                except websockets.exceptions.ConnectionClosed as e:
                    logger.debug(f"WebSocket connection closed: {str(e)}")
                    break
        finally:
            if websocket:
                try:
                    # Use wait_for to prevent hanging on close
                    await asyncio.wait_for(websocket.close(), timeout=1.0)
                except (asyncio.TimeoutError, Exception):
                    pass
    except Exception as e:
        logger.error(f"Error in websocket stream: {e}")
    finally:
        # Ensure websocket is closed even if we didn't enter the try block
        if websocket:
            try:
                # Use wait_for to prevent hanging on close
                await asyncio.wait_for(websocket.close(), timeout=1.0)
            except (asyncio.TimeoutError, Exception):
                pass
+
572
+
573
def clear_debugging_sessions():
    """Clear any existing debugging sessions when a module is redeployed or pod is terminated."""
    try:
        import web_pdb
    except ImportError:
        # web_pdb not installed, nothing to clean up
        return

    try:
        active = web_pdb.WebPdb.active_instance
        if active is not None:
            logger.info("Clearing existing debugging session")
            try:
                active.remove_trace()
            except Exception as e:
                logger.warning(f"Error removing trace: {e}")
            # Drop the singleton so a fresh session can be created later
            web_pdb.WebPdb.active_instance = None
    except Exception as e:
        logger.warning(f"Error clearing debugging session: {e}")
+
592
+
593
# Register cleanup function to run at exit, so any open web_pdb debugging
# session is detached when the process shuts down.
atexit.register(clear_debugging_sessions)
+
596
+
597
def deep_breakpoint(debug_port: int = DEFAULT_DEBUG_PORT):
    """
    Similar to Python's built-in `breakpoint()`, but can be used deep inside distributed code. For SPMD-style
    distributed code like PyTorch, be sure to only call this from one process (e.g. the rank 0 process) to avoid
    blocking all processes in the distributed group.

    Args:
        debug_port: Port the web-pdb UI listens on inside the pod; surfaced to
            the user via the printed `kt debug` command.
    """
    # Check if web_pdb is installed, if not, install it
    try:
        import web_pdb
    except ImportError:
        install_cmd = "uv pip install --system web-pdb"
        import subprocess

        print("Pdb debugger not found, installing it...")
        # Run the install command and propagate logs
        subprocess.run(install_cmd, shell=True, check=True, text=True)
        print("Pdb installed successfully.")

    print(
        "Distributed breakpoint activated. To attach a debugger, run the following command:"
    )
    # POD_NAME / POD_NAMESPACE are expected to be injected into the pod env
    print(
        f"kt debug {os.environ['POD_NAME']} --port {debug_port} --namespace {os.environ['POD_NAMESPACE']}"
    )

    import web_pdb

    pdb = web_pdb.WebPdb.active_instance
    try:
        if pdb is None:
            # host="" binds all interfaces so the debug port can be forwarded
            pdb = web_pdb.WebPdb(host="", port=debug_port, patch_stdstreams=False)
        else:
            # If the debugger is still attached reset trace to a new location
            pdb.remove_trace()

        # Set the frame to the caller's frame
        pdb.set_trace(sys._getframe(1))  # pylint: disable=protected-access
    except Exception as e:
        # Only clean up if there was an error setting up the debugger
        if pdb:
            pdb.remove_trace()
        web_pdb.WebPdb.active_instance = None
        raise e
+
641
+
642
def wait_for_app_start(
    port, health_check: str, process: subprocess.Popen, timeout: int = 60
):
    """
    Wait until the app is ready. If health_check is provided, will send HTTP requests to check, otherwise
    will wait until something is listening on the port.

    Args:
        port: Port the app should be serving on (int or numeric str).
        health_check: Optional health-check path (e.g. "health"); a leading
            "/" is added if missing.
        process: The app's subprocess; polled so we fail fast if it exits.
        timeout: Max seconds to wait before raising TimeoutError.

    Returns:
        True once the app is reachable.

    Raises:
        RuntimeError: If the app process exits with a nonzero code.
        TimeoutError: If the app is not reachable within `timeout` seconds.
    """
    host = LOCALHOST
    port = int(port)
    logger.debug(f"Trying to connect to http://{host}:{port}{health_check or ''}")
    start_time = time.time()

    if health_check:
        if not health_check.startswith("/"):
            health_check = f"/{health_check}"
        url = f"http://{LOCALHOST}:{port}{health_check}"
        while time.time() - start_time < timeout:
            if process.poll() is not None and process.poll() != 0:
                raise RuntimeError(f"App exited with code {process.poll()}")
            try:
                response = httpx.get(url)
                if response.status_code == 200:
                    return True
            except httpx.ConnectError:
                # Server not accepting connections yet; retry below
                pass
            time.sleep(0.5)
        raise TimeoutError(
            f"App did not become healthy on {url} within {timeout} seconds"
        )
    else:
        # Fallback to socket check
        while time.time() - start_time < timeout:
            if process.poll() is not None and process.poll() != 0:
                raise RuntimeError(f"App exited with code {process.poll()}")
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(1)
                try:
                    sock.connect((host, port))
                    return True
                except (ConnectionRefusedError, socket.timeout):
                    time.sleep(0.5)
        # Fix: `url` is only defined in the health_check branch; referencing it
        # here previously raised NameError instead of the intended TimeoutError.
        raise TimeoutError(
            f"Failed to detect open port {port} on {host} within {timeout} seconds"
        )
+
687
+
688
def _serialize_body(body: dict, serialization: str):
    """Prepare a call body for sending, extracting server-level kwargs.

    We only serialize args and kwargs; other settings like "workers" and
    "restart_procs" are needed inside the http_server, outside the
    serialization boundary (e.g. the distributed processes), so they are
    broken out here as separate top-level params.

    Args:
        body: Call body with optional "args"/"kwargs" entries; may be None.
        serialization: "pickle" pickles+base64-encodes args/kwargs under
            "data"; any other value passes the body through unchanged.

    Returns:
        The (possibly mutated) body dict, or {} when body is None/empty.
    """
    if body is None:
        return {}

    body = body or {}

    for kwarg in MAGIC_CALL_KWARGS:
        if kwarg in body.get("kwargs", {}):
            body[kwarg] = body["kwargs"].pop(kwarg)

    if serialization == "pickle":
        # Fix: use pop defaults so a body without "args"/"kwargs" (e.g. a
        # no-argument call) no longer raises KeyError.
        args_data = {
            "args": body.pop("args", []),
            "kwargs": body.pop("kwargs", {}),
        }
        pickled_args = pickle.dumps(args_data)
        encoded_args = base64.b64encode(pickled_args).decode("utf-8")
        body["data"] = encoded_args
        return body
    return body or {}
+
709
+
710
+ def _deserialize_response(response, serialization: str):
711
+ if serialization == "pickle":
712
+ response_data = response.json()
713
+ if isinstance(response_data, list):
714
+ # If this is a response from an spmd call, it's a list of serialized dicts
715
+ unpickled_results = []
716
+ for resp in response_data:
717
+ if "data" in resp:
718
+ encoded_result = resp["data"]
719
+ pickled_result = base64.b64decode(encoded_result.encode("utf-8"))
720
+ resp = pickle.loads(pickled_result)
721
+ unpickled_results.append(resp)
722
+ return unpickled_results
723
+ if "data" in response_data:
724
+ encoded_result = response_data["data"]
725
+ pickled_result = base64.b64decode(encoded_result.encode("utf-8"))
726
+ return pickle.loads(pickled_result)
727
+ return response_data
728
+ return response.json()