kubetorch-0.2.0-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.

Potentially problematic release.

Files changed (93)
  1. kubetorch/__init__.py +60 -0
  2. kubetorch/cli.py +1985 -0
  3. kubetorch/cli_utils.py +1025 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +285 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +157 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +133 -0
  30. kubetorch/resources/callables/module.py +1416 -0
  31. kubetorch/resources/callables/utils.py +174 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +261 -0
  34. kubetorch/resources/compute/compute.py +2596 -0
  35. kubetorch/resources/compute/decorators.py +139 -0
  36. kubetorch/resources/compute/rbac.py +74 -0
  37. kubetorch/resources/compute/utils.py +1114 -0
  38. kubetorch/resources/compute/websocket.py +137 -0
  39. kubetorch/resources/images/__init__.py +1 -0
  40. kubetorch/resources/images/image.py +414 -0
  41. kubetorch/resources/images/images.py +74 -0
  42. kubetorch/resources/secrets/__init__.py +2 -0
  43. kubetorch/resources/secrets/kubernetes_secrets_client.py +412 -0
  44. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  45. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  46. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  47. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  48. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  49. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  50. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  51. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  52. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  53. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  54. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  55. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  56. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  57. kubetorch/resources/secrets/provider_secrets/providers.py +93 -0
  58. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  59. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  60. kubetorch/resources/secrets/secret.py +238 -0
  61. kubetorch/resources/secrets/secret_factory.py +70 -0
  62. kubetorch/resources/secrets/utils.py +209 -0
  63. kubetorch/resources/volumes/__init__.py +0 -0
  64. kubetorch/resources/volumes/volume.py +365 -0
  65. kubetorch/servers/__init__.py +0 -0
  66. kubetorch/servers/http/__init__.py +0 -0
  67. kubetorch/servers/http/distributed_utils.py +3223 -0
  68. kubetorch/servers/http/http_client.py +730 -0
  69. kubetorch/servers/http/http_server.py +1788 -0
  70. kubetorch/servers/http/server_metrics.py +278 -0
  71. kubetorch/servers/http/utils.py +728 -0
  72. kubetorch/serving/__init__.py +0 -0
  73. kubetorch/serving/autoscaling.py +173 -0
  74. kubetorch/serving/base_service_manager.py +363 -0
  75. kubetorch/serving/constants.py +83 -0
  76. kubetorch/serving/deployment_service_manager.py +478 -0
  77. kubetorch/serving/knative_service_manager.py +519 -0
  78. kubetorch/serving/raycluster_service_manager.py +582 -0
  79. kubetorch/serving/service_manager.py +18 -0
  80. kubetorch/serving/templates/deployment_template.yaml +17 -0
  81. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  82. kubetorch/serving/templates/kt_setup_template.sh.j2 +81 -0
  83. kubetorch/serving/templates/pod_template.yaml +194 -0
  84. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  85. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  86. kubetorch/serving/templates/service_template.yaml +21 -0
  87. kubetorch/serving/templates/workerset_template.yaml +36 -0
  88. kubetorch/serving/utils.py +377 -0
  89. kubetorch/utils.py +284 -0
  90. kubetorch-0.2.0.dist-info/METADATA +121 -0
  91. kubetorch-0.2.0.dist-info/RECORD +93 -0
  92. kubetorch-0.2.0.dist-info/WHEEL +4 -0
  93. kubetorch-0.2.0.dist-info/entry_points.txt +5 -0
kubetorch/servers/http/http_client.py
@@ -0,0 +1,730 @@
+ import asyncio
+ import json
+ import threading
+ import time
+ import urllib.parse
+ from datetime import datetime, timezone
+ from typing import Optional, Union
+
+ import httpx
+ import websockets
+
+ from kubernetes import client
+ from kubernetes.client.rest import ApiException
+
+ from kubetorch.globals import config, service_url
+ from kubetorch.logger import get_logger
+
+ from kubetorch.servers.http.utils import (
+     _deserialize_response,
+     _serialize_body,
+     generate_unique_request_id,
+     PodTerminatedError,
+     request_id_ctx_var,
+ )
+
+ from kubetorch.serving.constants import (
+     DEFAULT_DEBUG_PORT,
+     DEFAULT_NGINX_PORT,
+     KT_TERMINATION_REASONS,
+ )
+ from kubetorch.utils import extract_host_port, ServerLogsFormatter
+
+ logger = get_logger(__name__)
+
+
+ class CustomResponse(httpx.Response):
+     def raise_for_status(self):
+         """Raise reconstructed server errors, or an httpx.HTTPStatusError for other error status codes"""
+         if not 400 <= self.status_code < 600:
+             return
+
+         if "application/json" in self.headers.get("Content-Type", ""):
+             try:
+                 error_data = self.json()
+                 if all(
+                     k in error_data
+                     for k in ["error_type", "message", "traceback", "pod_name"]
+                 ):
+                     error_type = error_data["error_type"]
+                     message = error_data.get("message", "")
+                     traceback = error_data["traceback"]
+                     pod_name = error_data["pod_name"]
+                     error_state = error_data.get("state", {})  # Optional serialized state
+
+                     # Try to use the actual exception class if it exists
+                     exc = None
+                     error_class = None
+
+                     # Import the exception registry
+                     try:
+                         from kubetorch import EXCEPTION_REGISTRY
+                     except ImportError:
+                         EXCEPTION_REGISTRY = {}
+
+                     # First check if it's a Python builtin exception
+                     import builtins
+
+                     if hasattr(builtins, error_type):
+                         error_class = getattr(builtins, error_type)
+                     # Otherwise try to use from the kubetorch registry
+                     elif error_type in EXCEPTION_REGISTRY:
+                         error_class = EXCEPTION_REGISTRY[error_type]
+
+                     if error_class:
+                         try:
+                             # First try to reconstruct from state if available
+                             if error_state and hasattr(error_class, "from_dict"):
+                                 exc = error_class.from_dict(error_state)
+                             # Otherwise try simple construction with message
+                             else:
+                                 exc = error_class(message)
+                         except Exception as e:
+                             logger.debug(
+                                 f"Could not reconstruct {error_type}: {e}, will use dynamic type"
+                             )
+                             # Fall back to dynamic creation below
+
+                     # If we couldn't create the actual exception, fall back to dynamic type creation
+                     if not exc:
+
+                         def create_str_method(remote_traceback):
+                             def __str__(self):
+                                 cleaned_traceback = remote_traceback.encode().decode(
+                                     "unicode_escape"
+                                 )
+                                 return f"{self.args[0]}\n\n{cleaned_traceback}"
+
+                             return __str__
+
+                         # Create the exception class with the custom __str__
+                         error_class = type(
+                             error_type,
+                             (Exception,),
+                             {"__str__": create_str_method(traceback)},
+                         )
+
+                         exc = error_class(message)
+
+                     # Always add remote_traceback and pod_name
+                     exc.remote_traceback = traceback
+                     exc.pod_name = pod_name
+
+                     # Wrap the exception to display the remote traceback:
+                     # create a new class that inherits from the original exception
+                     # and overrides __str__ to include the remote traceback
+                     class RemoteException(exc.__class__):
+                         def __str__(self):
+                             # Get the original message
+                             original_msg = super().__str__()
+                             # Clean up the traceback
+                             cleaned_traceback = self.remote_traceback.encode().decode(
+                                 "unicode_escape"
+                             )
+                             return f"{original_msg}\n\n{cleaned_traceback}"
+
+                     # Create wrapped instance without calling __init__
+                     wrapped_exc = RemoteException.__new__(RemoteException)
+                     # Copy all attributes from the original exception
+                     wrapped_exc.__dict__.update(exc.__dict__)
+                     # Set the exception args for proper display
+                     wrapped_exc.args = (str(exc),)
+                     raise wrapped_exc
+
+             except Exception as e:
+                 # Catchall for errors during the exception handling above. RemoteException
+                 # may be unbound here if the failure happened before it was defined, so
+                 # match the wrapped exception by name instead of with isinstance.
+                 if type(e).__name__ == "RemoteException":
+                     # If we caught a RemoteException, it was packaged properly; re-raise
+                     raise
+                 raise httpx.HTTPStatusError(
+                     f"{self.status_code} {self.text}",
+                     request=self.request,
+                     response=self,
+                 )
+         else:
+             logger.debug(f"Non-JSON error body: {self.text[:100]}")
+             super().raise_for_status()
+
+
+ class CustomSession(httpx.Client):
+     def __init__(self):
+         limits = httpx.Limits(max_connections=None, max_keepalive_connections=None)
+         super().__init__(timeout=None, limits=limits)
+
+     def __del__(self):
+         self.close()
+
+     def request(self, *args, **kwargs):
+         response = super().request(*args, **kwargs)
+         response.__class__ = CustomResponse
+         return response
+
+
+ class CustomAsyncClient(httpx.AsyncClient):
+     def __init__(self):
+         limits = httpx.Limits(max_connections=None, max_keepalive_connections=None)
+         super().__init__(timeout=None, limits=limits)
+
+     async def request(self, *args, **kwargs):
+         response = await super().request(*args, **kwargs)
+         response.__class__ = CustomResponse
+         return response
+
+
+ class HTTPClient:
+     """Client for making HTTP requests to a remote service. Port forwards are shared between client
+     instances. Each port forward instance is cleaned up when the last reference is closed."""
+
+     def __init__(self, base_url, compute, service_name):
+         self._core_api = None
+         self._objects_api = None
+         self._tracing_enabled = config.tracing_enabled
+
+         self.compute = compute
+         self.service_name = service_name
+         self.base_url = base_url.rstrip("/")
+         self.session = CustomSession()
+         self._async_client = None
+
+     def __del__(self):
+         self.close()
+
+     def close(self):
+         """Close the sync and async HTTP clients to prevent resource leaks."""
+         if self._async_client:
+             try:
+                 # Close the async client if it's still open
+                 if not self._async_client.is_closed:
+                     # Use asyncio.run if we're in a sync context, otherwise schedule the close
+                     try:
+                         loop = asyncio.get_event_loop()
+                         if loop.is_running():
+                             # If we're in an async context, schedule the close
+                             loop.create_task(self._async_client.aclose())
+                         else:
+                             # If we're in a sync context, run the close
+                             asyncio.run(self._async_client.aclose())
+                     except RuntimeError:
+                         # No event loop available, try to create one
+                         try:
+                             asyncio.run(self._async_client.aclose())
+                         except Exception:
+                             pass
+             except Exception as e:
+                 logger.debug(f"Error closing async client: {e}")
+             finally:
+                 self._async_client = None
+
+         # Close the session as well
+         if self.session:
+             try:
+                 self.session.close()
+             except Exception as e:
+                 logger.debug(f"Error closing session: {e}")
+             finally:
+                 self.session = None
+
+     @property
+     def core_api(self):
+         if self._core_api is None:
+             self._core_api = client.CoreV1Api()
+         return self._core_api
+
+     @property
+     def objects_api(self):
+         if self._objects_api is None:
+             self._objects_api = client.CustomObjectsApi()
+         return self._objects_api
+
+     @property
+     def local_port(self):
+         """Local port used to open the port-forward connection to the proxy service. This should match
+         the client port used to set the URL of the service in the Compute class."""
+         if self.compute:
+             return self.compute.client_port()
+         return DEFAULT_NGINX_PORT
+
+     @property
+     def async_session(self):
+         """Get or create the async HTTP client."""
+         if self._async_client is None:
+             self._async_client = CustomAsyncClient()
+         return self._async_client
+
+     # ----------------- Error Handling ----------------- #
+     def _handle_response_errors(self, response):
+         """If we didn't get JSON back, the pod may already be dead and Knative returned the 500
+         because the service was mid-termination."""
+         status_code = response.status_code
+         if status_code >= 500 and self._tracing_enabled:
+             request_id = response.request.headers.get("X-Request-ID")
+             pod_name, start_time = self._load_pod_metadata_from_tempo(
+                 request_id=request_id
+             )
+             if pod_name:
+                 self._handle_500x_error(pod_name, status_code, start_time)
+             else:
+                 logger.debug(f"No pod name found for request {request_id}")
+
+     def _handle_500x_error(
+         self, pod_name: str, status_code: int, start_time: float
+     ) -> None:
+         """Handle 500x errors by surfacing container status and Kubernetes events."""
+         termination_reason = None
+         try:
+             pod = self.core_api.read_namespaced_pod(
+                 name=pod_name, namespace=self.compute.namespace
+             )
+
+             # Check container termination states for a better reason
+             for container_status in pod.status.container_statuses or []:
+                 state = container_status.state
+                 # Note: if the pod was killed abruptly, kubelet might not have had time to set the container state
+                 if state.terminated:
+                     termination_reason = state.terminated.reason
+                     error_code = state.terminated.exit_code
+
+                     if (
+                         termination_reason not in KT_TERMINATION_REASONS
+                         and error_code == 137
+                     ):
+                         termination_reason = "OOMKilled"
+                         logger.warning(
+                             "OOM suspected: pod exited with code 137 but no termination reason "
+                             "found in Kubernetes events"
+                         )
+
+                     logger.debug(
+                         f"Pod {pod_name} terminated with reason: {termination_reason}"
+                     )
+                     break
+
+             if termination_reason is None:
+                 termination_reason = pod.status.reason
+
+         except ApiException as e:
+             termination_reason = e.reason if e.reason == "Not Found" else None
+
+         # We only raise here if the pod was indeed not found / terminated
+         if termination_reason in KT_TERMINATION_REASONS:
+             # Convert the start_time float to a timezone-aware datetime for comparison
+             start_dt = datetime.fromtimestamp(start_time, tz=timezone.utc)
+
+             # Fetch pod events since the request started
+             events = self._get_pod_events_since(pod_name, start_time=start_dt)
+
+             raise PodTerminatedError(
+                 pod_name=pod_name,
+                 reason=termination_reason,
+                 status_code=status_code,
+                 events=events,
+             )
+
+     def _get_pod_events_since(self, pod_name: str, start_time: datetime) -> list[dict]:
+         """Fetch all events for a pod since the given start time."""
+         try:
+             events = self.core_api.list_namespaced_event(
+                 namespace=self.compute.namespace,
+                 field_selector=f"involvedObject.name={pod_name}",
+             )
+             filtered_events = []
+             for event in events.items:
+                 if event.first_timestamp and event.first_timestamp > start_time:
+                     filtered_events.append(
+                         {
+                             "timestamp": event.first_timestamp,
+                             "reason": event.reason,
+                             "message": event.message,
+                         }
+                     )
+             return filtered_events
+         except Exception as e:
+             logger.warning(f"Failed to fetch pod events for {pod_name}: {e}")
+             return []
+
+     def _query_tempo_internal(
+         self, tempo_url: str, request_id: str, retries=5, delay=2.0
+     ) -> tuple[Optional[str], Optional[float]]:
+         """
+         Query Tempo for the trace with the given request_id and return:
+         - the pod name (`service.instance.id`)
+         - the trace start time in epoch seconds
+
+         Note: retries are used to handle Tempo not being ready yet.
+         """
+         for attempt in range(retries):
+             try:
+                 search_url = f"{tempo_url}/api/search"
+                 params = {"tags": f"request_id={request_id}"}
+                 response = httpx.get(search_url, params=params, timeout=2)
+                 response.raise_for_status()
+                 traces = response.json().get("traces", [])
+                 if traces:
+                     trace_info = traces[0]
+                     trace_id = trace_info["traceID"]
+                     start_time_ns = int(trace_info["startTimeUnixNano"])
+                     start_time_sec = start_time_ns / 1_000_000_000
+
+                     # Fetch the pod name from the full trace detail
+                     detail_url = f"{tempo_url}/api/traces/{trace_id}"
+                     detail_response = httpx.get(detail_url, timeout=2)
+                     detail_response.raise_for_status()
+                     trace_data = detail_response.json()
+
+                     for batch in trace_data.get("batches", []):
+                         for attr in batch.get("resource", {}).get("attributes", []):
+                             if attr["key"] == "service.instance.id":
+                                 pod_name = attr["value"].get("stringValue")
+                                 return pod_name, start_time_sec
+
+             except Exception as e:
+                 logger.debug(f"Attempt {attempt + 1} failed to query Tempo: {e}")
+
+             if attempt < retries - 1:
+                 logger.debug(f"Retrying loading traces in {delay} seconds...")
+                 time.sleep(delay)
+
+         logger.warning("Failed to load pod metadata from Tempo")
+         return None, None
+
+     def _load_pod_metadata_from_tempo(self, request_id: str):
+         """Query Tempo for the trace with the given request_id and return the pod name if available.
+         Note there are a few reasons why we may fail to load the data from Tempo:
+         (1) Flush failure: the pod was killed (OOM, preempted, etc.) before the OTEL exporter flushed to Tempo
+         (2) Tempo ingestion errors
+         """
+         base_url = service_url()
+         tempo_url = f"{base_url}/tempo"
+         return self._query_tempo_internal(tempo_url, request_id)
+
+     def _prepare_request(
+         self,
+         endpoint: str,
+         stream_logs: Union[bool, None],
+         headers: dict,
+         pdb: Union[bool, int],
+         serialization: str,
+     ):
+         if stream_logs is None:
+             stream_logs = config.stream_logs or False
+
+         if pdb:
+             debug_port = DEFAULT_DEBUG_PORT if isinstance(pdb, bool) else pdb
+             endpoint += f"?debug_port={debug_port}"
+
+         request_id = request_id_ctx_var.get("-")
+         if request_id == "-":
+             timestamp = str(time.time())
+             request_id = generate_unique_request_id(
+                 endpoint=endpoint, timestamp=timestamp
+             )
+
+         headers = headers or {}
+         headers.update({"X-Request-ID": request_id, "X-Serialization": serialization})
+
+         stop_event = threading.Event()
+         log_thread = None
+         if stream_logs:
+             log_thread = threading.Thread(
+                 target=self.stream_logs, args=(request_id, stop_event)
+             )
+             log_thread.daemon = True
+             log_thread.start()
+
+         return endpoint, headers, stop_event, log_thread, request_id
+
+     def _prepare_request_async(
+         self,
+         endpoint: str,
+         stream_logs: Union[bool, None],
+         headers: dict,
+         pdb: Union[bool, int],
+         serialization: str,
+     ):
+         """Async version of _prepare_request that uses asyncio.Event and tasks instead of threads"""
+         if stream_logs is None:
+             stream_logs = config.stream_logs or False
+
+         if pdb:
+             debug_port = DEFAULT_DEBUG_PORT if isinstance(pdb, bool) else pdb
+             endpoint += f"?debug_port={debug_port}"
+
+         request_id = request_id_ctx_var.get("-")
+         if request_id == "-":
+             timestamp = str(time.time())
+             request_id = generate_unique_request_id(
+                 endpoint=endpoint, timestamp=timestamp
+             )
+
+         headers = headers or {}
+         headers.update({"X-Request-ID": request_id, "X-Serialization": serialization})
+
+         stop_event = asyncio.Event()
+         log_task = None
+         if stream_logs:
+             log_task = asyncio.create_task(
+                 self.stream_logs_async(request_id, stop_event)
+             )
+
+         return endpoint, headers, stop_event, log_task, request_id
+
+     def _make_request(self, method, endpoint, **kwargs):
+         response: httpx.Response = getattr(self.session, method)(endpoint, **kwargs)
+         self._handle_response_errors(response)
+         response.raise_for_status()
+         return response
+
+     async def _make_request_async(self, method, endpoint, **kwargs):
+         """Async version of _make_request."""
+         response = await getattr(self.async_session, method)(endpoint, **kwargs)
+         self._handle_response_errors(response)
+         response.raise_for_status()
+         return response
+
+     # ----------------- Stream Helpers ----------------- #
+     async def _stream_logs_websocket(
+         self,
+         request_id,
+         stop_event: Union[threading.Event, asyncio.Event],
+         port: int,
+         host: str = "localhost",
+     ):
+         """Stream logs using Loki's websocket tail endpoint"""
+         formatter = ServerLogsFormatter()
+         websocket = None
+         try:
+             query = (
+                 f'{{k8s_container_name="kubetorch"}} | json | request_id="{request_id}"'
+             )
+             encoded_query = urllib.parse.quote_plus(query)
+             uri = f"ws://{host}:{port}/loki/api/v1/tail?query={encoded_query}"
+             # Track the last timestamp we've seen to avoid duplicates
+             last_timestamp = None
+             # Track when we should stop
+             stop_time = None
+
+             # Add timeouts to prevent hanging connections
+             logger.debug(f"Streaming logs with tail query {uri}")
+             websocket = await websockets.connect(
+                 uri,
+                 close_timeout=10,  # Max time to wait for close handshake
+                 ping_interval=20,  # Send ping every 20 seconds
+                 ping_timeout=10,  # Wait 10 seconds for pong
+             )
+             try:
+                 while True:
+                     # If the stop event is set, start counting down; both threading.Event
+                     # and asyncio.Event expose is_set()
+                     if stop_event.is_set() and stop_time is None:
+                         stop_time = time.time() + 2  # 2 seconds grace period
+
+                     # If we're past the grace period, exit
+                     if stop_time is not None and time.time() > stop_time:
+                         break
+
+                     try:
+                         # Use a shorter timeout during the grace period
+                         timeout = 0.1 if stop_time is not None else 1.0
+                         message = await asyncio.wait_for(
+                             websocket.recv(), timeout=timeout
+                         )
+                         data = json.loads(message)
+
+                         if data.get("streams"):
+                             for stream in data["streams"]:
+                                 labels = stream["stream"]
+                                 service_name = labels.get("kubetorch_com_service")
+
+                                 # Determine if this is a Knative service by checking for Knative-specific labels
+                                 is_knative = (
+                                     labels.get("serving_knative_dev_configuration")
+                                     is not None
+                                 )
+
+                                 for value in stream["values"]:
+                                     log_line = json.loads(value[1])
+                                     log_name = log_line.get("name")
+                                     log_message = log_line.get("message")
+                                     current_timestamp = value[0]
+                                     # Skip if we've already seen this timestamp
+                                     if (
+                                         last_timestamp is not None
+                                         and current_timestamp <= last_timestamp
+                                     ):
+                                         continue
+                                     last_timestamp = current_timestamp
+
+                                     # Choose the appropriate identifier for the log prefix
+                                     if is_knative:
+                                         log_prefix = service_name
+                                     else:
+                                         # For deployments, use the pod name from the structured log
+                                         log_prefix = log_line.get("pod", service_name)
+
+                                     if log_name == "print_redirect":
+                                         print(
+                                             f"{formatter.start_color}({log_prefix}) {log_message}{formatter.reset_color}"
+                                         )
+                                     elif log_name != "uvicorn.access":
+                                         formatted_log = f"({log_prefix}) {log_line.get('asctime')} | {log_line.get('levelname')} | {log_message}"
+                                         print(
+                                             f"{formatter.start_color}{formatted_log}{formatter.reset_color}"
+                                         )
+                     except asyncio.TimeoutError:
+                         # Timeout is expected, just continue the loop
+                         continue
+                     except websockets.exceptions.ConnectionClosed as e:
+                         logger.debug(f"WebSocket connection closed: {str(e)}")
+                         break
+             finally:
+                 if websocket:
+                     try:
+                         # Use wait_for to prevent hanging on close
+                         await asyncio.wait_for(websocket.close(), timeout=1.0)
+                     except Exception:
+                         pass
+         except Exception as e:
+             logger.error(f"Error in websocket stream: {e}")
+         finally:
+             # Ensure the websocket is closed even if we never entered the inner block
+             if websocket:
+                 try:
+                     # Use wait_for to prevent hanging on close
+                     await asyncio.wait_for(websocket.close(), timeout=1.0)
+                 except Exception:
+                     pass
+
+     def _run_log_stream(self, request_id, stop_event, host, port):
+         """Helper to run log streaming in its own event loop"""
+         loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(loop)
+         try:
+             loop.run_until_complete(
+                 self._stream_logs_websocket(
+                     request_id, stop_event, host=host, port=port
+                 )
+             )
+         finally:
+             loop.close()
+
+     # ----------------- Core APIs ----------------- #
+     def stream_logs(self, request_id, stop_event):
+         """Start websocket log streaming in a separate thread"""
+         logger.debug(
+             f"Streaming logs for service {self.service_name} (request_id: {request_id})"
+         )
+
+         base_url = service_url()
+         base_host, base_port = extract_host_port(base_url)
+         self._run_log_stream(request_id, stop_event, base_host, base_port)
+
+     async def stream_logs_async(self, request_id, stop_event):
+         """Async version of stream_logs. Start websocket log streaming as an async task"""
+         logger.debug(
+             f"Streaming logs for service {self.service_name} (request_id: {request_id})"
+         )
+
+         base_url = service_url()
+         base_host, base_port = extract_host_port(base_url)
+         await self._stream_logs_websocket(
+             request_id, stop_event, host=base_host, port=base_port
+         )
+
+     def call_method(
+         self,
+         endpoint: str,
+         stream_logs: Union[bool, None] = None,
+         body: dict = None,
+         headers: dict = None,
+         pdb: Union[bool, int] = None,
+         serialization: str = "json",
+     ):
+         endpoint, headers, stop_event, log_thread, _ = self._prepare_request(
+             endpoint, stream_logs, headers, pdb, serialization
+         )
+         try:
+             json_data = _serialize_body(body, serialization)
+             response = self.post(endpoint=endpoint, json=json_data, headers=headers)
+             response.raise_for_status()
+             return _deserialize_response(response, serialization)
+         finally:
+             stop_event.set()
+
+     async def call_method_async(
+         self,
+         endpoint: str,
+         stream_logs: Union[bool, None] = None,
+         body: dict = None,
+         headers: dict = None,
+         pdb: Union[bool, int] = None,
+         serialization: str = "json",
+     ):
+         """Async version of call_method."""
+         endpoint, headers, stop_event, log_task, _ = self._prepare_request_async(
+             endpoint, stream_logs, headers, pdb, serialization
+         )
+         try:
+             json_data = _serialize_body(body, serialization)
+             response = await self.post_async(
+                 endpoint=endpoint, json=json_data, headers=headers
+             )
+             response.raise_for_status()
+             result = _deserialize_response(response, serialization)
+
+             if stream_logs and log_task:
+                 await asyncio.sleep(0.5)
+
+             return result
+         finally:
+             stop_event.set()
+             if log_task:
+                 try:
+                     await asyncio.wait_for(log_task, timeout=0.5)
+                 except asyncio.TimeoutError:
+                     log_task.cancel()
+                     try:
+                         await log_task
+                     except asyncio.CancelledError:
+                         pass
+
+     def post(self, endpoint, json=None, headers=None):
+         return self._make_request("post", endpoint, json=json, headers=headers)
+
+     def put(self, endpoint, json=None, headers=None):
+         return self._make_request("put", endpoint, json=json, headers=headers)
+
+     def delete(self, endpoint, json=None, headers=None):
+         return self._make_request("delete", endpoint, json=json, headers=headers)
+
+     def get(self, endpoint, headers=None):
+         return self._make_request("get", endpoint, headers=headers)
+
+     async def post_async(self, endpoint, json=None, headers=None):
+         return await self._make_request_async(
+             "post", endpoint, json=json, headers=headers
+         )
+
+     async def put_async(self, endpoint, json=None, headers=None):
+         return await self._make_request_async(
+             "put", endpoint, json=json, headers=headers
+         )
+
+     async def delete_async(self, endpoint, json=None, headers=None):
+         return await self._make_request_async(
+             "delete", endpoint, json=json, headers=headers
+         )
+
+     async def get_async(self, endpoint, headers=None):
+         return await self._make_request_async("get", endpoint, headers=headers)
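
For orientation, the sketch below shows how this client could be driven end to end, based only on the signatures in the diff above. The base URL, service name, and endpoint path are hypothetical placeholders (kubetorch constructs HTTPClient internally from its Compute objects), so treat this as an illustration of the call flow rather than a documented API.

# Hypothetical usage sketch; the URL, service name, and endpoint below are
# placeholders inferred from the signatures above, not documented kubetorch APIs.
import asyncio

from kubetorch.servers.http.http_client import HTTPClient

client = HTTPClient(
    base_url="http://localhost:32300",  # assumed port-forwarded proxy URL
    compute=None,  # with no compute, local_port falls back to DEFAULT_NGINX_PORT
    service_name="my-service",
)

# Sync call: when stream_logs=True, Loki logs are tailed in a daemon thread while
# the request runs, and server-side failures surface as reconstructed exceptions
# carrying the remote traceback and pod name.
result = client.call_method(
    endpoint="http://localhost:32300/my_fn",  # hypothetical endpoint
    body={"args": [], "kwargs": {}},
    stream_logs=True,
)

# Async variant: identical flow, but log streaming runs as an asyncio task
# instead of a thread, and is awaited briefly before being cancelled.
async def main():
    return await client.call_method_async(
        endpoint="http://localhost:32300/my_fn",
        body={"args": [], "kwargs": {}},
    )

asyncio.run(main())
client.close()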