kubetorch-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kubetorch might be problematic.
- kubetorch/__init__.py +60 -0
- kubetorch/cli.py +1985 -0
- kubetorch/cli_utils.py +1025 -0
- kubetorch/config.py +453 -0
- kubetorch/constants.py +18 -0
- kubetorch/docs/Makefile +18 -0
- kubetorch/docs/__init__.py +0 -0
- kubetorch/docs/_ext/json_globaltoc.py +42 -0
- kubetorch/docs/api/cli.rst +10 -0
- kubetorch/docs/api/python/app.rst +21 -0
- kubetorch/docs/api/python/cls.rst +19 -0
- kubetorch/docs/api/python/compute.rst +25 -0
- kubetorch/docs/api/python/config.rst +11 -0
- kubetorch/docs/api/python/fn.rst +19 -0
- kubetorch/docs/api/python/image.rst +14 -0
- kubetorch/docs/api/python/secret.rst +18 -0
- kubetorch/docs/api/python/volumes.rst +13 -0
- kubetorch/docs/api/python.rst +101 -0
- kubetorch/docs/conf.py +69 -0
- kubetorch/docs/index.rst +20 -0
- kubetorch/docs/requirements.txt +5 -0
- kubetorch/globals.py +285 -0
- kubetorch/logger.py +59 -0
- kubetorch/resources/__init__.py +0 -0
- kubetorch/resources/callables/__init__.py +0 -0
- kubetorch/resources/callables/cls/__init__.py +0 -0
- kubetorch/resources/callables/cls/cls.py +157 -0
- kubetorch/resources/callables/fn/__init__.py +0 -0
- kubetorch/resources/callables/fn/fn.py +133 -0
- kubetorch/resources/callables/module.py +1416 -0
- kubetorch/resources/callables/utils.py +174 -0
- kubetorch/resources/compute/__init__.py +0 -0
- kubetorch/resources/compute/app.py +261 -0
- kubetorch/resources/compute/compute.py +2596 -0
- kubetorch/resources/compute/decorators.py +139 -0
- kubetorch/resources/compute/rbac.py +74 -0
- kubetorch/resources/compute/utils.py +1114 -0
- kubetorch/resources/compute/websocket.py +137 -0
- kubetorch/resources/images/__init__.py +1 -0
- kubetorch/resources/images/image.py +414 -0
- kubetorch/resources/images/images.py +74 -0
- kubetorch/resources/secrets/__init__.py +2 -0
- kubetorch/resources/secrets/kubernetes_secrets_client.py +412 -0
- kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
- kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
- kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
- kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
- kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
- kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
- kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
- kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
- kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/providers.py +93 -0
- kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
- kubetorch/resources/secrets/secret.py +238 -0
- kubetorch/resources/secrets/secret_factory.py +70 -0
- kubetorch/resources/secrets/utils.py +209 -0
- kubetorch/resources/volumes/__init__.py +0 -0
- kubetorch/resources/volumes/volume.py +365 -0
- kubetorch/servers/__init__.py +0 -0
- kubetorch/servers/http/__init__.py +0 -0
- kubetorch/servers/http/distributed_utils.py +3223 -0
- kubetorch/servers/http/http_client.py +730 -0
- kubetorch/servers/http/http_server.py +1788 -0
- kubetorch/servers/http/server_metrics.py +278 -0
- kubetorch/servers/http/utils.py +728 -0
- kubetorch/serving/__init__.py +0 -0
- kubetorch/serving/autoscaling.py +173 -0
- kubetorch/serving/base_service_manager.py +363 -0
- kubetorch/serving/constants.py +83 -0
- kubetorch/serving/deployment_service_manager.py +478 -0
- kubetorch/serving/knative_service_manager.py +519 -0
- kubetorch/serving/raycluster_service_manager.py +582 -0
- kubetorch/serving/service_manager.py +18 -0
- kubetorch/serving/templates/deployment_template.yaml +17 -0
- kubetorch/serving/templates/knative_service_template.yaml +19 -0
- kubetorch/serving/templates/kt_setup_template.sh.j2 +81 -0
- kubetorch/serving/templates/pod_template.yaml +194 -0
- kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
- kubetorch/serving/templates/raycluster_template.yaml +35 -0
- kubetorch/serving/templates/service_template.yaml +21 -0
- kubetorch/serving/templates/workerset_template.yaml +36 -0
- kubetorch/serving/utils.py +377 -0
- kubetorch/utils.py +284 -0
- kubetorch-0.2.0.dist-info/METADATA +121 -0
- kubetorch-0.2.0.dist-info/RECORD +93 -0
- kubetorch-0.2.0.dist-info/WHEEL +4 -0
- kubetorch-0.2.0.dist-info/entry_points.txt +5 -0
kubetorch/servers/http/http_client.py
@@ -0,0 +1,730 @@
import asyncio
import json
import threading
import time
import urllib.parse
from datetime import datetime, timezone
from typing import Optional, Union

import httpx
import websockets

from kubernetes import client
from kubernetes.client.rest import ApiException

from kubetorch.globals import config, service_url
from kubetorch.logger import get_logger

from kubetorch.servers.http.utils import (
    _deserialize_response,
    _serialize_body,
    generate_unique_request_id,
    PodTerminatedError,
    request_id_ctx_var,
)

from kubetorch.serving.constants import (
    DEFAULT_DEBUG_PORT,
    DEFAULT_NGINX_PORT,
    KT_TERMINATION_REASONS,
)
from kubetorch.utils import extract_host_port, ServerLogsFormatter

logger = get_logger(__name__)


class CustomResponse(httpx.Response):
    def raise_for_status(self):
        """Raises parsed server errors or HTTPError for other status codes"""
        if not 400 <= self.status_code < 600:
            return

        if "application/json" in self.headers.get("Content-Type", ""):
            try:
                error_data = self.json()
                if all(
                    k in error_data
                    for k in ["error_type", "message", "traceback", "pod_name"]
                ):
                    error_type = error_data["error_type"]
                    message = error_data.get("message", "")
                    traceback = error_data["traceback"]
                    pod_name = error_data["pod_name"]
                    error_state = error_data.get(
                        "state", {}
                    )  # Optional serialized state

                    # Try to use the actual exception class if it exists
                    exc = None
                    error_class = None

                    # Import the exception registry
                    try:
                        from kubetorch import EXCEPTION_REGISTRY
                    except ImportError:
                        EXCEPTION_REGISTRY = {}

                    # First check if it's a Python builtin exception
                    import builtins

                    if hasattr(builtins, error_type):
                        error_class = getattr(builtins, error_type)
                    # Otherwise try to use from the kubetorch registry
                    elif error_type in EXCEPTION_REGISTRY:
                        error_class = EXCEPTION_REGISTRY[error_type]

                    if error_class:
                        try:
                            # First try to reconstruct from state if available
                            if error_state and hasattr(error_class, "from_dict"):
                                exc = error_class.from_dict(error_state)
                            # Otherwise try simple construction with message
                            else:
                                exc = error_class(message)
                        except Exception as e:
                            logger.debug(
                                f"Could not reconstruct {error_type}: {e}, will use dynamic type"
                            )
                            # Fall back to dynamic creation
                            pass

                    # If we couldn't create the actual exception, fall back to dynamic type creation
                    if not exc:

                        def create_str_method(remote_traceback):
                            def __str__(self):
                                cleaned_traceback = remote_traceback.encode().decode(
                                    "unicode_escape"
                                )
                                return f"{self.args[0]}\n\n{cleaned_traceback}"

                            return __str__

                        # Create the exception class with the custom __str__
                        error_class = type(
                            error_type,
                            (Exception,),
                            {"__str__": create_str_method(traceback)},
                        )

                        exc = error_class(message)

                    # Always add remote_traceback and pod_name
                    exc.remote_traceback = traceback
                    exc.pod_name = pod_name

                    # Wrap the exception to display remote traceback
                    # Create a new class that inherits from the original exception
                    # and overrides __str__ to include the remote traceback
                    class RemoteException(exc.__class__):
                        def __str__(self):
                            # Get the original message
                            original_msg = super().__str__()
                            # Clean up the traceback
                            cleaned_traceback = self.remote_traceback.encode().decode(
                                "unicode_escape"
                            )
                            return f"{original_msg}\n\n{cleaned_traceback}"

                    # Create wrapped instance without calling __init__
                    wrapped_exc = RemoteException.__new__(RemoteException)
                    # Copy all attributes from the original exception
                    wrapped_exc.__dict__.update(exc.__dict__)
                    # Set the exception args for proper display
                    wrapped_exc.args = (str(exc),)
                    raise wrapped_exc

            except Exception as e:
                # Catchall for errors during exception handling above
                if isinstance(e, RemoteException):
                    # If we caught a RemoteException, it was packaged properly
                    raise
                import httpx

                raise httpx.HTTPStatusError(
                    f"{self.status_code} {self.text}",
                    request=self.request,
                    response=self,
                )
        else:
            logger.debug(f"Non-JSON error body: {self.text[:100]}")
            super().raise_for_status()


class CustomSession(httpx.Client):
    def __init__(self):
        limits = httpx.Limits(max_connections=None, max_keepalive_connections=None)
        super().__init__(timeout=None, limits=limits)

    def __del__(self):
        self.close()

    def request(self, *args, **kwargs):
        response = super().request(*args, **kwargs)
        response.__class__ = CustomResponse
        return response


class CustomAsyncClient(httpx.AsyncClient):
    def __init__(self):
        limits = httpx.Limits(max_connections=None, max_keepalive_connections=None)
        super().__init__(timeout=None, limits=limits)

    async def request(self, *args, **kwargs):
        response = await super().request(*args, **kwargs)
        response.__class__ = CustomResponse
        return response


class HTTPClient:
    """Client for making HTTP requests to a remote service. Port forwards are shared between client
    instances. Each port forward instance is cleaned up when the last reference is closed."""

    def __init__(self, base_url, compute, service_name):
        self._core_api = None
        self._objects_api = None
        self._tracing_enabled = config.tracing_enabled

        self.compute = compute
        self.service_name = service_name
        self.base_url = base_url.rstrip("/")
        self.session = CustomSession()
        self._async_client = None

    def __del__(self):
        self.close()

    def close(self):
        """Close the async HTTP client to prevent resource leaks."""
        if self._async_client:
            try:
                # Close the async client if it's still open
                if not self._async_client.is_closed:
                    # Use asyncio.run if we're in a sync context, otherwise schedule the close
                    try:
                        import asyncio

                        loop = asyncio.get_event_loop()
                        if loop.is_running():
                            # If we're in an async context, schedule the close
                            loop.create_task(self._async_client.aclose())
                        else:
                            # If we're in a sync context, run the close
                            asyncio.run(self._async_client.aclose())
                    except RuntimeError:
                        # No event loop available, try to create one
                        try:
                            asyncio.run(self._async_client.aclose())
                        except Exception:
                            pass
            except Exception as e:
                logger.debug(f"Error closing async client: {e}")
            finally:
                self._async_client = None

        # Close the session as well
        if self.session:
            try:
                self.session.close()
            except Exception as e:
                logger.debug(f"Error closing session: {e}")
            finally:
                self.session = None

    @property
    def core_api(self):
        if self._core_api is None:
            self._core_api = client.CoreV1Api()
        return self._core_api

    @property
    def objects_api(self):
        if self._objects_api is None:
            self._objects_api = client.CustomObjectsApi()
        return self._objects_api

    @property
    def local_port(self):
        """Local port to open the port forward connection with the proxy service. This should match the client port used
        to set the URL of the service in the Compute class."""
        if self.compute:
            return self.compute.client_port()
        return DEFAULT_NGINX_PORT

    @property
    def async_session(self):
        """Get or create async HTTP client."""
        if self._async_client is None:
            self._async_client = CustomAsyncClient()
        return self._async_client

    # ----------------- Error Handling ----------------- #
    def _handle_response_errors(self, response):
        """If we didn't get json back, it could be that the pod is already dead but Knative returns the 500 because
        the service was mid-termination."""
        status_code = response.status_code
        if status_code >= 500 and self._tracing_enabled:
            request_id = response.request.headers.get("X-Request-ID")
            pod_name, start_time = self._load_pod_metadata_from_tempo(
                request_id=request_id
            )
            if pod_name:
                self._handle_500x_error(pod_name, status_code, start_time)
            else:
                logger.debug(f"No pod name found for request {request_id}")

    def _handle_500x_error(
        self, pod_name: str, status_code: int, start_time: float
    ) -> None:
        """Handle 500x errors by surfacing container status and kubernetes events."""
        termination_reason = None
        try:
            pod = self.core_api.read_namespaced_pod(
                name=pod_name, namespace=self.compute.namespace
            )

            # Check container termination states for better reason
            for container_status in pod.status.container_statuses or []:
                state = container_status.state
                # Note: if the pod was killed abruptly, kubelet might not have time to set the container state
                if state.terminated:
                    termination_reason = state.terminated.reason
                    error_code = state.terminated.exit_code

                    if (
                        termination_reason not in KT_TERMINATION_REASONS
                        and error_code == 137
                    ):
                        termination_reason = "OOMKilled"
                        logger.warning(
                            "OOM suspected: pod exited with code 137 but no termination reason "
                            "found in Kubernetes events"
                        )

                    logger.debug(
                        f"Pod {pod_name} terminated with reason: {termination_reason}"
                    )
                    break

            if termination_reason is None:
                termination_reason = pod.status.reason

        except ApiException as e:
            termination_reason = e.reason if e.reason == "Not Found" else None

        # We update termination_reason only if the pod was indeed not found / terminated.
        if termination_reason in KT_TERMINATION_REASONS:
            # Convert start_time float to ISO8601 format for comparison
            start_dt = datetime.fromtimestamp(start_time, tz=timezone.utc)

            # Fetch pod events since request started
            events = self._get_pod_events_since(pod_name, start_time=start_dt)

            raise PodTerminatedError(
                pod_name=pod_name,
                reason=termination_reason,
                status_code=status_code,
                events=events,
            )

    def _get_pod_events_since(self, pod_name: str, start_time: datetime) -> list[dict]:
        """Fetch all events for a pod since the given start time."""
        try:
            events = self.core_api.list_namespaced_event(
                namespace=self.compute.namespace,
                field_selector=f"involvedObject.name={pod_name}",
            )
            filtered_events = []
            for event in events.items:
                if event.first_timestamp and event.first_timestamp > start_time:
                    filtered_events.append(
                        {
                            "timestamp": event.first_timestamp,
                            "reason": event.reason,
                            "message": event.message,
                        }
                    )
            return filtered_events
        except Exception as e:
            logger.warning(f"Failed to fetch pod events for {pod_name}: {e}")
            return []

    def _query_tempo_internal(
        self, tempo_url: str, request_id: str, retries=5, delay=2.0
    ) -> Optional[tuple[str, float]]:
        """
        Query Tempo for the trace with the given request_id and return:
        - the pod name (`service.instance.id`)
        - the trace start time in epoch seconds

        Note: retries are used to handle Tempo not being ready yet.
        """
        for attempt in range(retries):
            try:
                search_url = f"{tempo_url}/api/search"
                params = {"tags": f"request_id={request_id}"}
                response = httpx.get(search_url, params=params, timeout=2)
                response.raise_for_status()
                traces = response.json().get("traces", [])
                if traces:
                    trace_info = traces[0]
                    trace_id = trace_info["traceID"]
                    start_time_ns = int(trace_info["startTimeUnixNano"])
                    start_time_sec = start_time_ns / 1_000_000_000

                    # Fetch pod name from full trace detail
                    detail_url = f"{tempo_url}/api/traces/{trace_id}"
                    detail_response = httpx.get(detail_url, timeout=2)
                    detail_response.raise_for_status()
                    trace_data = detail_response.json()

                    for batch in trace_data.get("batches", []):
                        for attr in batch.get("resource", {}).get("attributes", []):
                            if attr["key"] == "service.instance.id":
                                pod_name = attr["value"].get("stringValue")
                                return pod_name, start_time_sec

            except Exception as e:
                logger.debug(f"Attempt {attempt + 1} failed to query Tempo: {e}")

            if attempt < retries - 1:
                logger.debug(f"Retrying loading traces in {delay} seconds...")
                time.sleep(delay)

        logger.warning("Failed to load pod metadata from Tempo")
        return None, None

    def _load_pod_metadata_from_tempo(self, request_id: str):
        """Query Tempo for the trace with the given request_id and return the pod name if available.
        Note there are a few reasons why we may fail to load the data from Tempo:
        (1) Flush failure: pod was killed (OOM, preempted, etc.) before the OTEL exporter flushed to Tempo
        (2) Tempo ingestion errors
        """
        base_url = service_url()
        tempo_url = f"{base_url}/tempo"
        return self._query_tempo_internal(tempo_url, request_id)

    def _prepare_request(
        self,
        endpoint: str,
        stream_logs: Union[bool, None],
        headers: dict,
        pdb: Union[bool, int],
        serialization: str,
    ):
        if stream_logs is None:
            stream_logs = config.stream_logs or False

        if pdb:
            debug_port = DEFAULT_DEBUG_PORT if isinstance(pdb, bool) else pdb
            endpoint += f"?debug_port={debug_port}"

        request_id = request_id_ctx_var.get("-")
        if request_id == "-":
            timestamp = str(time.time())
            request_id = generate_unique_request_id(
                endpoint=endpoint, timestamp=timestamp
            )

        headers = headers or {}
        headers.update({"X-Request-ID": request_id, "X-Serialization": serialization})

        stop_event = threading.Event()
        log_thread = None
        if stream_logs:
            log_thread = threading.Thread(
                target=self.stream_logs, args=(request_id, stop_event)
            )
            log_thread.daemon = True
            log_thread.start()

        return endpoint, headers, stop_event, log_thread, request_id

    def _prepare_request_async(
        self,
        endpoint: str,
        stream_logs: Union[bool, None],
        headers: dict,
        pdb: Union[bool, int],
        serialization: str,
    ):
        """Async version of _prepare_request that uses asyncio.Event and tasks instead of threads"""
        if stream_logs is None:
            stream_logs = config.stream_logs or False

        if pdb:
            debug_port = DEFAULT_DEBUG_PORT if isinstance(pdb, bool) else pdb
            endpoint += f"?debug_port={debug_port}"

        request_id = request_id_ctx_var.get("-")
        if request_id == "-":
            timestamp = str(time.time())
            request_id = generate_unique_request_id(
                endpoint=endpoint, timestamp=timestamp
            )

        headers = headers or {}
        headers.update({"X-Request-ID": request_id, "X-Serialization": serialization})

        stop_event = asyncio.Event()
        log_task = None
        if stream_logs:
            log_task = asyncio.create_task(
                self.stream_logs_async(request_id, stop_event)
            )

        return endpoint, headers, stop_event, log_task, request_id

    def _make_request(self, method, endpoint, **kwargs):
        response: httpx.Response = getattr(self.session, method)(endpoint, **kwargs)
        self._handle_response_errors(response)
        response.raise_for_status()
        return response

    async def _make_request_async(self, method, endpoint, **kwargs):
        """Async version of _make_request."""
        response = await getattr(self.async_session, method)(endpoint, **kwargs)
        self._handle_response_errors(response)
        response.raise_for_status()
        return response

    # ----------------- Stream Helpers ----------------- #
    async def _stream_logs_websocket(
        self,
        request_id,
        stop_event: Union[threading.Event, asyncio.Event],
        port: int,
        host: str = "localhost",
    ):
        """Stream logs using Loki's websocket tail endpoint"""
        formatter = ServerLogsFormatter()
        websocket = None
        try:
            query = (
                f'{{k8s_container_name="kubetorch"}} | json | request_id="{request_id}"'
            )
            encoded_query = urllib.parse.quote_plus(query)
            uri = f"ws://{host}:{port}/loki/api/v1/tail?query={encoded_query}"
            # Track the last timestamp we've seen to avoid duplicates
            last_timestamp = None
            # Track when we should stop
            stop_time = None

            # Add timeout to prevent hanging connections
            logger.debug(f"Streaming logs with tail query {uri}")
            websocket = await websockets.connect(
                uri,
                close_timeout=10,  # Max time to wait for close handshake
                ping_interval=20,  # Send ping every 20 seconds
                ping_timeout=10,  # Wait 10 seconds for pong
            )
            try:
                while True:
                    # If stop event is set, start counting down
                    # Handle both threading.Event and asyncio.Event
                    is_stop_set = (
                        stop_event.is_set()
                        if hasattr(stop_event, "is_set")
                        else stop_event.is_set()
                    )
                    if is_stop_set and stop_time is None:
                        stop_time = time.time() + 2  # 2 seconds grace period

                    # If we're past the grace period, exit
                    if stop_time is not None and time.time() > stop_time:
                        break

                    try:
                        # Use shorter timeout during grace period
                        timeout = 0.1 if stop_time is not None else 1.0
                        message = await asyncio.wait_for(
                            websocket.recv(), timeout=timeout
                        )
                        data = json.loads(message)

                        if data.get("streams"):
                            for stream in data["streams"]:
                                labels = stream["stream"]
                                service_name = labels.get("kubetorch_com_service")

                                # Determine if this is a Knative service by checking for Knative-specific labels
                                is_knative = (
                                    labels.get("serving_knative_dev_configuration")
                                    is not None
                                )

                                for value in stream["values"]:
                                    # Skip if we've already seen this timestamp
                                    log_line = json.loads(value[1])
                                    log_name = log_line.get("name")
                                    log_message = log_line.get("message")
                                    current_timestamp = value[0]
                                    if (
                                        last_timestamp is not None
                                        and current_timestamp <= last_timestamp
                                    ):
                                        continue
                                    last_timestamp = value[0]

                                    # Choose the appropriate identifier for the log prefix
                                    if is_knative:
                                        log_prefix = service_name
                                    else:
                                        # For deployments, use the pod name from the structured log
                                        log_prefix = log_line.get("pod", service_name)

                                    if log_name == "print_redirect":
                                        print(
                                            f"{formatter.start_color}({log_prefix}) {log_message}{formatter.reset_color}"
                                        )
                                    elif log_name != "uvicorn.access":
                                        formatted_log = f"({log_prefix}) {log_line.get('asctime')} | {log_line.get('levelname')} | {log_message}"
                                        print(
                                            f"{formatter.start_color}{formatted_log}{formatter.reset_color}"
                                        )
                    except asyncio.TimeoutError:
                        # Timeout is expected, just continue the loop
                        continue
                    except websockets.exceptions.ConnectionClosed as e:
                        logger.debug(f"WebSocket connection closed: {str(e)}")
                        break
            finally:
                if websocket:
                    try:
                        # Use wait_for to prevent hanging on close
                        await asyncio.wait_for(websocket.close(), timeout=1.0)
                    except (asyncio.TimeoutError, Exception):
                        pass
        except Exception as e:
            logger.error(f"Error in websocket stream: {e}")
        finally:
            # Ensure websocket is closed even if we didn't enter the context
            if websocket:
                try:
                    # Use wait_for to prevent hanging on close
                    await asyncio.wait_for(websocket.close(), timeout=1.0)
                except (asyncio.TimeoutError, Exception):
                    pass

    def _run_log_stream(self, request_id, stop_event, host, port):
        """Helper to run log streaming in an event loop"""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(
                self._stream_logs_websocket(
                    request_id, stop_event, host=host, port=port
                )
            )
        finally:
            loop.close()

    # ----------------- Core APIs ----------------- #
    def stream_logs(self, request_id, stop_event):
        """Start websocket log streaming in a separate thread"""
        logger.debug(
            f"Streaming logs for service {self.service_name} (request_id: {request_id})"
        )

        base_url = service_url()
        base_host, base_port = extract_host_port(base_url)
        self._run_log_stream(request_id, stop_event, base_host, base_port)

    async def stream_logs_async(self, request_id, stop_event):
        """Async version of stream_logs. Start websocket log streaming as an async task"""
        logger.debug(
            f"Streaming logs for service {self.service_name} (request_id: {request_id})"
        )

        base_url = service_url()
        base_host, base_port = extract_host_port(base_url)
        await self._stream_logs_websocket(
            request_id, stop_event, host=base_host, port=base_port
        )

    def call_method(
        self,
        endpoint: str,
        stream_logs: Union[bool, None] = None,
        body: dict = None,
        headers: dict = None,
        pdb: Union[bool, int] = None,
        serialization: str = "json",
    ):
        endpoint, headers, stop_event, log_thread, _ = self._prepare_request(
            endpoint, stream_logs, headers, pdb, serialization
        )
        try:
            json_data = _serialize_body(body, serialization)
            response = self.post(endpoint=endpoint, json=json_data, headers=headers)
            response.raise_for_status()
            return _deserialize_response(response, serialization)
        finally:
            stop_event.set()

    async def call_method_async(
        self,
        endpoint: str,
        stream_logs: Union[bool, None] = None,
        body: dict = None,
        headers: dict = None,
        pdb: Union[bool, int] = None,
        serialization: str = "json",
    ):
        """Async version of call_method."""
        endpoint, headers, stop_event, log_task, _ = self._prepare_request_async(
            endpoint, stream_logs, headers, pdb, serialization
        )
        try:
            json_data = _serialize_body(body, serialization)
            response = await self.post_async(
                endpoint=endpoint, json=json_data, headers=headers
            )
            response.raise_for_status()
            result = _deserialize_response(response, serialization)

            if stream_logs and log_task:
                await asyncio.sleep(0.5)

            return result
        finally:
            stop_event.set()
            if log_task:
                try:
                    await asyncio.wait_for(log_task, timeout=0.5)
                except asyncio.TimeoutError:
                    log_task.cancel()
                    try:
                        await log_task
                    except asyncio.CancelledError:
                        pass

    def post(self, endpoint, json=None, headers=None):
        return self._make_request("post", endpoint, json=json, headers=headers)

    def put(self, endpoint, json=None, headers=None):
        return self._make_request("put", endpoint, json=json, headers=headers)

    def delete(self, endpoint, json=None, headers=None):
        return self._make_request("delete", endpoint, json=json, headers=headers)

    def get(self, endpoint, headers=None):
        return self._make_request("get", endpoint, headers=headers)

    async def post_async(self, endpoint, json=None, headers=None):
        return await self._make_request_async(
            "post", endpoint, json=json, headers=headers
        )

    async def put_async(self, endpoint, json=None, headers=None):
        return await self._make_request_async(
            "put", endpoint, json=json, headers=headers
        )

    async def delete_async(self, endpoint, json=None, headers=None):
        return await self._make_request_async(
            "delete", endpoint, json=json, headers=headers
        )

    async def get_async(self, endpoint, headers=None):
        return await self._make_request_async("get", endpoint, headers=headers)