kubetorch 0.2.5__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. kubetorch/__init__.py +59 -0
  2. kubetorch/cli.py +1939 -0
  3. kubetorch/cli_utils.py +967 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +269 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +159 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +140 -0
  30. kubetorch/resources/callables/module.py +1315 -0
  31. kubetorch/resources/callables/utils.py +203 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +253 -0
  34. kubetorch/resources/compute/compute.py +2414 -0
  35. kubetorch/resources/compute/decorators.py +137 -0
  36. kubetorch/resources/compute/utils.py +1026 -0
  37. kubetorch/resources/compute/websocket.py +135 -0
  38. kubetorch/resources/images/__init__.py +1 -0
  39. kubetorch/resources/images/image.py +412 -0
  40. kubetorch/resources/images/images.py +64 -0
  41. kubetorch/resources/secrets/__init__.py +2 -0
  42. kubetorch/resources/secrets/kubernetes_secrets_client.py +377 -0
  43. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  44. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  45. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  46. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  47. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  48. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  49. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  50. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  51. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  52. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  53. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  54. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  55. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  56. kubetorch/resources/secrets/provider_secrets/providers.py +92 -0
  57. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  58. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  59. kubetorch/resources/secrets/secret.py +224 -0
  60. kubetorch/resources/secrets/secret_factory.py +64 -0
  61. kubetorch/resources/secrets/utils.py +222 -0
  62. kubetorch/resources/volumes/__init__.py +0 -0
  63. kubetorch/resources/volumes/volume.py +340 -0
  64. kubetorch/servers/__init__.py +0 -0
  65. kubetorch/servers/http/__init__.py +0 -0
  66. kubetorch/servers/http/distributed_utils.py +2968 -0
  67. kubetorch/servers/http/http_client.py +802 -0
  68. kubetorch/servers/http/http_server.py +1622 -0
  69. kubetorch/servers/http/server_metrics.py +255 -0
  70. kubetorch/servers/http/utils.py +722 -0
  71. kubetorch/serving/__init__.py +0 -0
  72. kubetorch/serving/autoscaling.py +153 -0
  73. kubetorch/serving/base_service_manager.py +344 -0
  74. kubetorch/serving/constants.py +77 -0
  75. kubetorch/serving/deployment_service_manager.py +431 -0
  76. kubetorch/serving/knative_service_manager.py +487 -0
  77. kubetorch/serving/raycluster_service_manager.py +526 -0
  78. kubetorch/serving/service_manager.py +18 -0
  79. kubetorch/serving/templates/deployment_template.yaml +17 -0
  80. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  81. kubetorch/serving/templates/kt_setup_template.sh.j2 +91 -0
  82. kubetorch/serving/templates/pod_template.yaml +198 -0
  83. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  84. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  85. kubetorch/serving/templates/service_template.yaml +21 -0
  86. kubetorch/serving/templates/workerset_template.yaml +36 -0
  87. kubetorch/serving/utils.py +344 -0
  88. kubetorch/utils.py +263 -0
  89. kubetorch-0.2.5.dist-info/METADATA +75 -0
  90. kubetorch-0.2.5.dist-info/RECORD +92 -0
  91. kubetorch-0.2.5.dist-info/WHEEL +4 -0
  92. kubetorch-0.2.5.dist-info/entry_points.txt +5 -0
kubetorch/servers/http/http_client.py (new file)
@@ -0,0 +1,802 @@
+ import asyncio
+ import json
+ import threading
+ import time
+ import urllib.parse
+ from collections import defaultdict
+ from datetime import datetime
+ from typing import Literal, Union
+
+ import httpx
+ import requests
+ import websockets
+
+ from kubernetes import client
+
+ from kubetorch.globals import config, MetricsConfig, service_url
+ from kubetorch.logger import get_logger
+
+ from kubetorch.servers.http.utils import (
+     _deserialize_response,
+     _serialize_body,
+     generate_unique_request_id,
+     request_id_ctx_var,
+ )
+
+ from kubetorch.serving.constants import DEFAULT_DEBUG_PORT, DEFAULT_NGINX_PORT
+ from kubetorch.utils import extract_host_port, ServerLogsFormatter
+
+ logger = get_logger(__name__)
+
+
+ class CustomResponse(httpx.Response):
+     def raise_for_status(self):
+         """Raises parsed server errors or HTTPError for other status codes"""
+         if not 400 <= self.status_code < 600:
+             return
+
+         if "application/json" in self.headers.get("Content-Type", ""):
+             try:
+                 error_data = self.json()
+                 if all(k in error_data for k in ["error_type", "message", "traceback", "pod_name"]):
+                     error_type = error_data["error_type"]
+                     message = error_data.get("message", "")
+                     traceback = error_data["traceback"]
+                     pod_name = error_data["pod_name"]
+                     error_state = error_data.get("state", {})  # Optional serialized state
+
+                     # Try to use the actual exception class if it exists
+                     exc = None
+                     error_class = None
+
+                     # Import the exception registry
+                     try:
+                         from kubetorch import EXCEPTION_REGISTRY
+                     except ImportError:
+                         EXCEPTION_REGISTRY = {}
+
+                     # First check if it's a Python builtin exception
+                     import builtins
+
+                     if hasattr(builtins, error_type):
+                         error_class = getattr(builtins, error_type)
+                     # Otherwise try to use from the kubetorch registry
+                     elif error_type in EXCEPTION_REGISTRY:
+                         error_class = EXCEPTION_REGISTRY[error_type]
+
+                     if error_class:
+                         try:
+                             # First try to reconstruct from state if available
+                             if error_state and hasattr(error_class, "from_dict"):
+                                 exc = error_class.from_dict(error_state)
+                             # Otherwise try simple construction with message
+                             else:
+                                 exc = error_class(message)
+                         except Exception as e:
+                             # Fall back to dynamic creation below
+                             logger.debug(f"Could not reconstruct {error_type}: {e}, will use dynamic type")
+
+                     # If we couldn't create the actual exception, fall back to dynamic type creation
+                     if not exc:
+
+                         def create_str_method(remote_traceback):
+                             def __str__(self):
+                                 cleaned_traceback = remote_traceback.encode().decode("unicode_escape")
+                                 return f"{self.args[0]}\n\n{cleaned_traceback}"
+
+                             return __str__
+
+                         # Create the exception class with the custom __str__
+                         error_class = type(
+                             error_type,
+                             (Exception,),
+                             {"__str__": create_str_method(traceback)},
+                         )
+
+                         exc = error_class(message)
+
+                     # Always add remote_traceback and pod_name
+                     exc.remote_traceback = traceback
+                     exc.pod_name = pod_name
+
+                     # Wrap the exception to display the remote traceback:
+                     # create a new class that inherits from the original exception
+                     # and overrides __str__ to include the remote traceback
+                     class RemoteException(exc.__class__):
+                         def __str__(self):
+                             # Get the original message
+                             original_msg = super().__str__()
+                             # Clean up the traceback
+                             cleaned_traceback = self.remote_traceback.encode().decode("unicode_escape")
+                             return f"{original_msg}\n\n{cleaned_traceback}"
+
+                     # Create wrapped instance without calling __init__
+                     wrapped_exc = RemoteException.__new__(RemoteException)
+                     # Copy all attributes from the original exception
+                     wrapped_exc.__dict__.update(exc.__dict__)
+                     # Set the exception args for proper display
+                     wrapped_exc.args = (str(exc),)
+                     raise wrapped_exc
+
+             except Exception as e:
+                 # Catchall for errors during exception handling above. Check for the
+                 # remote_traceback attribute rather than isinstance, since RemoteException
+                 # is only defined if we got far enough to build the wrapped exception.
+                 if hasattr(e, "remote_traceback"):
+                     # The exception was packaged properly above
+                     raise
+                 raise httpx.HTTPStatusError(
+                     f"{self.status_code} {self.text}",
+                     request=self.request,
+                     response=self,
+                 )
+         else:
+             logger.debug(f"Non-JSON error body: {self.text[:100]}")
+             super().raise_for_status()
+
+
+ class CustomSession(httpx.Client):
+     def __init__(self):
+         limits = httpx.Limits(max_connections=None, max_keepalive_connections=None)
+         super().__init__(timeout=None, limits=limits)
+
+     def __del__(self):
+         self.close()
+
+     def request(self, *args, **kwargs):
+         response = super().request(*args, **kwargs)
+         response.__class__ = CustomResponse
+         return response
+
+
+ class CustomAsyncClient(httpx.AsyncClient):
+     def __init__(self):
+         limits = httpx.Limits(max_connections=None, max_keepalive_connections=None)
+         super().__init__(timeout=None, limits=limits)
+
+     async def request(self, *args, **kwargs):
+         response = await super().request(*args, **kwargs)
+         response.__class__ = CustomResponse
+         return response
+
+
+ class HTTPClient:
+     """Client for making HTTP requests to a remote service. Port forwards are shared between client
+     instances. Each port forward instance is cleaned up when the last reference is closed."""
+
+     def __init__(self, base_url, compute, service_name):
+         self._core_api = None
+         self._objects_api = None
+
+         self.compute = compute
+         self.service_name = service_name
+         self.base_url = base_url.rstrip("/")
+         self.session = CustomSession()
+         self._async_client = None
+
+     def __del__(self):
+         self.close()
+
+     def close(self):
+         """Close the sync and async HTTP clients to prevent resource leaks."""
+         if self._async_client:
+             try:
+                 # Close the async client if it's still open
+                 if not self._async_client.is_closed:
+                     # Use asyncio.run if we're in a sync context, otherwise schedule the close
+                     try:
+                         loop = asyncio.get_event_loop()
+                         if loop.is_running():
+                             # If we're in an async context, schedule the close
+                             loop.create_task(self._async_client.aclose())
+                         else:
+                             # If we're in a sync context, run the close
+                             asyncio.run(self._async_client.aclose())
+                     except RuntimeError:
+                         # No event loop available, try to create one
+                         try:
+                             asyncio.run(self._async_client.aclose())
+                         except Exception:
+                             pass
+             except Exception as e:
+                 logger.debug(f"Error closing async client: {e}")
+             finally:
+                 self._async_client = None
+
+         # Close the sync session as well
+         if self.session:
+             try:
+                 self.session.close()
+             except Exception as e:
+                 logger.debug(f"Error closing session: {e}")
+             finally:
+                 self.session = None
+
+     @property
+     def core_api(self):
+         if self._core_api is None:
+             self._core_api = client.CoreV1Api()
+         return self._core_api
+
+     @property
+     def objects_api(self):
+         if self._objects_api is None:
+             self._objects_api = client.CustomObjectsApi()
+         return self._objects_api
+
+     @property
+     def local_port(self):
+         """Local port to open the port forward connection with the proxy service. This should match
+         the client port used to set the URL of the service in the Compute class."""
+         if self.compute:
+             return self.compute.client_port()
+         return DEFAULT_NGINX_PORT
+
+     @property
+     def async_session(self):
+         """Get or create async HTTP client."""
+         if self._async_client is None:
+             self._async_client = CustomAsyncClient()
+         return self._async_client
+
+     def _prepare_request(
+         self,
+         endpoint: str,
+         stream_logs: Union[bool, None],
+         stream_metrics: Union[bool, MetricsConfig, None],
+         headers: dict,
+         pdb: Union[bool, int],
+         serialization: str,
+     ):
+         if stream_logs is None:
+             stream_logs = config.stream_logs or False
+
+         metrics_config = MetricsConfig()
+         if isinstance(stream_metrics, MetricsConfig):
+             metrics_config = stream_metrics
+             stream_metrics = True
+         elif stream_metrics is None:
+             stream_metrics = config.stream_metrics or False
+
+         if pdb:
+             debug_port = DEFAULT_DEBUG_PORT if isinstance(pdb, bool) else pdb
+             endpoint += f"?debug_port={debug_port}"
+
+         request_id = request_id_ctx_var.get("-")
+         if request_id == "-":
+             timestamp = str(time.time())
+             request_id = generate_unique_request_id(endpoint=endpoint, timestamp=timestamp)
+
+         headers = headers or {}
+         headers.update({"X-Request-ID": request_id, "X-Serialization": serialization})
+
+         stop_event = threading.Event()
+         log_thread = None
+         if stream_logs:
+             log_thread = threading.Thread(target=self.stream_logs, args=(request_id, stop_event))
+             log_thread.daemon = True
+             log_thread.start()
+
+         metrics_thread = None
+         if stream_metrics:
+             metrics_thread = threading.Thread(target=self.stream_metrics, args=(stop_event, metrics_config))
+             metrics_thread.daemon = True
+             metrics_thread.start()
+
+         return endpoint, headers, stop_event, log_thread, metrics_thread, request_id
+
+     def _prepare_request_async(
+         self,
+         endpoint: str,
+         stream_logs: Union[bool, None],
+         stream_metrics: Union[bool, MetricsConfig, None],
+         headers: dict,
+         pdb: Union[bool, int],
+         serialization: str,
+     ):
+         """Async version of _prepare_request that uses asyncio.Event and tasks instead of threads"""
+         if stream_logs is None:
+             stream_logs = config.stream_logs or False
+
+         # Default to a MetricsConfig, matching _prepare_request, so the metrics
+         # collector never receives None when stream_metrics is a plain bool
+         metrics_config = MetricsConfig()
+         if isinstance(stream_metrics, MetricsConfig):
+             metrics_config = stream_metrics
+             stream_metrics = True
+         elif stream_metrics is None:
+             stream_metrics = config.stream_metrics or False
+
+         if pdb:
+             debug_port = DEFAULT_DEBUG_PORT if isinstance(pdb, bool) else pdb
+             endpoint += f"?debug_port={debug_port}"
+
+         request_id = request_id_ctx_var.get("-")
+         if request_id == "-":
+             timestamp = str(time.time())
+             request_id = generate_unique_request_id(endpoint=endpoint, timestamp=timestamp)
+
+         headers = headers or {}
+         headers.update({"X-Request-ID": request_id, "X-Serialization": serialization})
+
+         stop_event = asyncio.Event()
+         log_task = None
+         if stream_logs:
+             log_task = asyncio.create_task(self.stream_logs_async(request_id, stop_event))
+
+         metrics_task = None
+         if stream_metrics:
+             metrics_task = asyncio.create_task(self.stream_metrics_async(request_id, stop_event, metrics_config))
+
+         return endpoint, headers, stop_event, log_task, metrics_task, request_id
+
+     def _make_request(self, method, endpoint, **kwargs):
+         response: httpx.Response = getattr(self.session, method)(endpoint, **kwargs)
+         response.raise_for_status()
+         return response
+
+     async def _make_request_async(self, method, endpoint, **kwargs):
+         """Async version of _make_request."""
+         response = await getattr(self.async_session, method)(endpoint, **kwargs)
+         response.raise_for_status()
+         return response
+
+     # ----------------- Stream Helpers ----------------- #
+     async def _stream_logs_websocket(
+         self,
+         request_id,
+         stop_event: Union[threading.Event, asyncio.Event],
+         port: int,
+         host: str = "localhost",
+     ):
+         """Stream logs using Loki's websocket tail endpoint"""
+         formatter = ServerLogsFormatter()
+         websocket = None
+         try:
+             query = f'{{k8s_container_name="kubetorch"}} | json | request_id="{request_id}"'
+             encoded_query = urllib.parse.quote_plus(query)
+             uri = f"ws://{host}:{port}/loki/api/v1/tail?query={encoded_query}"
+             # Track the last timestamp we've seen to avoid duplicates
+             last_timestamp = None
+             # Track when we should stop
+             stop_time = None
+
+             # Add timeouts to prevent hanging connections
+             logger.debug(f"Streaming logs with tail query {uri}")
+             websocket = await websockets.connect(
+                 uri,
+                 close_timeout=10,  # Max time to wait for close handshake
+                 ping_interval=20,  # Send ping every 20 seconds
+                 ping_timeout=10,  # Wait 10 seconds for pong
+             )
+             try:
+                 while True:
+                     # If the stop event is set, start counting down the grace period.
+                     # is_set() works for both threading.Event and asyncio.Event.
+                     if stop_event.is_set() and stop_time is None:
+                         stop_time = time.time() + 2  # 2 seconds grace period
+
+                     # If we're past the grace period, exit
+                     if stop_time is not None and time.time() > stop_time:
+                         break
+
+                     try:
+                         # Use shorter timeout during grace period
+                         timeout = 0.1 if stop_time is not None else 1.0
+                         message = await asyncio.wait_for(websocket.recv(), timeout=timeout)
+                         data = json.loads(message)
+
+                         if data.get("streams"):
+                             for stream in data["streams"]:
+                                 labels = stream["stream"]
+                                 service_name = labels.get("kubetorch_com_service")
+
+                                 # Determine if this is a Knative service by checking for Knative-specific labels
+                                 is_knative = labels.get("serving_knative_dev_configuration") is not None
+
+                                 for value in stream["values"]:
+                                     log_line = json.loads(value[1])
+                                     log_name = log_line.get("name")
+                                     log_message = log_line.get("message")
+                                     # Skip if we've already seen this timestamp
+                                     current_timestamp = value[0]
+                                     if last_timestamp is not None and current_timestamp <= last_timestamp:
+                                         continue
+                                     last_timestamp = current_timestamp
+
+                                     # Choose the appropriate identifier for the log prefix
+                                     if is_knative:
+                                         log_prefix = service_name
+                                     else:
+                                         # For deployments, use the pod name from the structured log
+                                         log_prefix = log_line.get("pod", service_name)
+
+                                     if log_name == "print_redirect":
+                                         print(
+                                             f"{formatter.start_color}({log_prefix}) {log_message}{formatter.reset_color}"
+                                         )
+                                     elif log_name != "uvicorn.access":
+                                         formatted_log = f"({log_prefix}) {log_line.get('asctime')} | {log_line.get('levelname')} | {log_message}"
+                                         print(f"{formatter.start_color}{formatted_log}{formatter.reset_color}")
+                     except asyncio.TimeoutError:
+                         # Timeout is expected, just continue the loop
+                         continue
+                     except websockets.exceptions.ConnectionClosed as e:
+                         logger.debug(f"WebSocket connection closed: {str(e)}")
+                         break
+             finally:
+                 if websocket:
+                     try:
+                         # Use wait_for to prevent hanging on close
+                         await asyncio.wait_for(websocket.close(), timeout=1.0)
+                     except Exception:
+                         pass
+         except Exception as e:
+             logger.error(f"Error in websocket stream: {e}")
+         finally:
+             # Ensure the websocket is closed even if the connection failed partway
+             if websocket:
+                 try:
+                     # Use wait_for to prevent hanging on close
+                     await asyncio.wait_for(websocket.close(), timeout=1.0)
+                 except Exception:
+                     pass
+
+     def _run_log_stream(self, request_id, stop_event, host, port):
+         """Helper to run log streaming in an event loop"""
+         loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(loop)
+         try:
+             loop.run_until_complete(self._stream_logs_websocket(request_id, stop_event, host=host, port=port))
+         finally:
+             loop.close()
+
+     # ----------------- Metrics Helpers ----------------- #
+
+     def _get_stream_metrics_queries(self, scope: Literal["pod", "resource"], interval: int):
+         # Lookback window for each Prometheus query:
+         # for fast polling (interval < 60s), look back at most 2 min;
+         # for slow polling (interval >= 60s), allow up to 5 min.
+         effective_window = min(max(30, interval * 3), 120 if interval < 60 else 300)
+         metric_queries = {}
+         if scope == "pod":
+             active_pods = self.compute.pod_names()
+             if not active_pods:
+                 logger.warning("No active pods found for service, skipping metrics collection")
+                 # Return an empty dict rather than None so callers can iterate safely
+                 return metric_queries
+
+             pod_regex = "|".join(active_pods)
+             metric_queries = {
+                 # CPU: seconds of CPU used per second (i.e. cores used)
+                 # Note: using irate ensures we always capture at least 2 samples in the window
+                 # https://prometheus.io/docs/prometheus/latest/querying/functions/#irate
+                 "CPU": f'sum by (pod) (irate(container_cpu_usage_seconds_total{{container!="",pod=~"{pod_regex}"}}[{effective_window}s]))',
+                 # Memory: working set in MiB
+                 "Mem": f'last_over_time(container_memory_working_set_bytes{{container!="",pod=~"{pod_regex}"}}[{effective_window}s]) / 1024 / 1024',
+                 # GPU metrics from DCGM
+                 "GPU_SM": f'avg by (pod) (last_over_time(DCGM_FI_DEV_GPU_UTIL{{pod=~"{pod_regex}"}}[{effective_window}s]))',
+                 "GPUMiB": f'avg by (pod) (last_over_time(DCGM_FI_DEV_FB_USED{{pod=~"{pod_regex}"}}[{effective_window}s]))',
+             }
+
+         elif scope == "resource":
+             service_name_regex = f"{self.compute.service_name}.+"
+             metric_queries = {
+                 # CPU: rate of CPU seconds, i.e. cores utilized
+                 "CPU": f'avg((irate(container_cpu_usage_seconds_total{{container!="",pod=~"{service_name_regex}"}}[{effective_window}s])))',
+                 # Memory: working set in MiB
+                 "Mem": f'avg(last_over_time(container_memory_working_set_bytes{{container!="",pod=~"{service_name_regex}"}}[{effective_window}s]) / 1024 / 1024)',
+                 # GPU metrics from DCGM
+                 "GPU_SM": f'avg(last_over_time(DCGM_FI_DEV_GPU_UTIL{{pod=~"{service_name_regex}"}}[{effective_window}s]))',
+                 "GPUMiB": f'avg(last_over_time(DCGM_FI_DEV_FB_USED{{pod=~"{service_name_regex}"}}[{effective_window}s]))',
+             }
+
+         return metric_queries
+
+     def _collect_metrics_common(
+         self,
+         stop_event,
+         http_getter,
+         sleeper,
+         metrics_config: MetricsConfig,
+         is_async: bool = False,
+     ):
+         """
+         Internal shared implementation for collecting and printing live resource metrics
+         (CPU, memory, and GPU) for all active pods in the service.
+
+         This function drives both the synchronous (`_collect_metrics`) and asynchronous
+         (`_collect_metrics_async`) metric collectors. It repeatedly queries Prometheus for
+         metrics related to the service's pods until the given `stop_event` is set.
+
+         Args:
+             stop_event (threading.Event or asyncio.Event): Event used to stop collection.
+             http_getter (Callable): Callable that fetches Prometheus data, either sync
+                 (`requests.get`) or async (`httpx.AsyncClient.get`).
+             sleeper (Callable): Callable that sleeps between metric polls, either
+                 `time.sleep` or `asyncio.sleep`.
+             metrics_config (MetricsConfig): User-provided configuration controlling metrics collection behavior.
+             is_async (bool): If ``True``, runs in async mode (awaits HTTP + sleep calls).
+                 If ``False``, runs in blocking sync mode.
+
+         Behavior:
+             - Polls Prometheus for CPU, memory, and GPU metrics, sleeping at least
+               `metrics_config.interval` seconds between polls and backing off gradually
+               (up to 60s) for long-running requests.
+             - Prints a formatted line per pod to stdout.
+             - Automatically adapts between synchronous and asynchronous execution modes.
+
+         Note:
+             - This function should not be called directly; use `_collect_metrics` or
+               `_collect_metrics_async` instead.
+             - Stops automatically when `stop_event.set()` is triggered.
+         """
+
+         async def maybe_await(obj):
+             if is_async and asyncio.iscoroutine(obj):
+                 return await obj
+             return obj
+
+         async def run():
+             interval = int(metrics_config.interval)
+             metric_queries = self._get_stream_metrics_queries(scope=metrics_config.scope, interval=interval)
+             show_gpu = True
+             prom_url = f"{service_url()}/prometheus/api/v1/query"
+
+             start_time = time.time()
+
+             while not stop_event.is_set():
+                 await maybe_await(sleeper(interval))
+                 pod_data = defaultdict(dict)
+                 gpu_values = []
+
+                 for name, query in metric_queries.items():
+                     try:
+                         data = await maybe_await(
+                             http_getter(
+                                 prom_url,
+                                 params={
+                                     "query": query,
+                                     "lookback_delta": interval,
+                                 },
+                             )
+                         )
+                         if data.get("status") != "success":
+                             continue
+                         for result in data["data"]["result"]:
+                             m = result["metric"]
+                             ts, val = result["value"]
+                             pod = m.get("pod", "unknown")
+                             val_f = float(val)
+                             pod_data[pod][name] = val_f
+                             # Match the GPU query keys defined in _get_stream_metrics_queries
+                             if name in ("GPU_SM", "GPUMiB"):
+                                 gpu_values.append(val_f)
+                     except Exception as e:
+                         logger.error(f"Error loading metrics: {e}")
+                         continue
+
+                 if not gpu_values:
+                     show_gpu = False
+
+                 if pod_data:
+                     for pod, vals in sorted(pod_data.items()):
+                         mem = vals.get("Mem", 0.0)
+                         cpu_cores = vals.get("CPU", 0.0)
+                         gpu = vals.get("GPU_SM", 0.0)
+                         gpumem = vals.get("GPUMiB", 0.0)
+                         now_ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+                         pod_info = f"| pod: {pod} " if metrics_config.scope == "pod" else ""
+                         line = f"[METRICS] {now_ts} {pod_info}| CPU: {cpu_cores:.2f} | Memory: {mem:.3f}MiB"
+                         if show_gpu:
+                             line += f" | GPU SM: {gpu:.2f}% | GPU Memory: {gpumem:.3f}MiB"
+
+                         print(f"{line}", flush=True)
+
+                 elapsed = time.time() - start_time
+                 sleep_interval = max(interval, int(min(60, 1 + elapsed / 30)))
+                 await maybe_await(sleeper(sleep_interval))
+
+         # Run sync or async depending on mode
+         if is_async:
+             return run()
+         else:
+             asyncio.run(run())
+
+     def _collect_metrics(self, stop_event, http_getter, sleeper, metrics_config):
+         """
+         Synchronous metrics collector.
+
+         Invokes `_collect_metrics_common` in blocking mode to stream metrics. Designed for use in
+         background threads where the event loop is *not* running (e.g. standard Python threads).
+
+         Args:
+             stop_event: threading.Event to signal termination of metric collection.
+             http_getter: Synchronous callable that fetches Prometheus query results.
+             sleeper: Blocking sleep callable.
+             metrics_config: User-provided configuration controlling metrics collection behavior.
+
+         Notes:
+             - Runs until `stop_event` is set.
+             - Safe to use in multi-threaded environments.
+             - Should not be invoked from within an asyncio event loop.
+         """
+         self._collect_metrics_common(stop_event, http_getter, sleeper, metrics_config=metrics_config, is_async=False)
+
+     async def _collect_metrics_async(self, stop_event, http_getter, sleeper, metrics_config):
+         """
+         Asynchronous metrics collector.
+
+         Invokes `_collect_metrics_common` in fully async mode. Designed for use when the caller is
+         already inside an active asyncio event loop.
+
+         Args:
+             stop_event: asyncio.Event to signal termination of metric collection.
+             http_getter: Asynchronous callable that fetches Prometheus query results.
+             sleeper: Async sleep callable.
+             metrics_config: User-provided configuration controlling metrics collection behavior.
+
+         Note:
+             - Should only be called from within an asyncio context.
+             - Automatically terminates once `stop_event` is set.
+             - Prints formatted metrics continuously until stopped.
+         """
+         await self._collect_metrics_common(
+             stop_event, http_getter, sleeper, metrics_config=metrics_config, is_async=True
+         )
+
+     # ----------------- Core APIs ----------------- #
+     def stream_logs(self, request_id, stop_event):
+         """Start websocket log streaming in a separate thread"""
+         logger.debug(f"Streaming logs for service {self.service_name} (request_id: {request_id})")
+
+         base_url = service_url()
+         base_host, base_port = extract_host_port(base_url)
+         self._run_log_stream(request_id, stop_event, base_host, base_port)
+
+     async def stream_logs_async(self, request_id, stop_event):
+         """Async version of stream_logs. Start websocket log streaming as an async task"""
+         logger.debug(f"Streaming logs for service {self.service_name} (request_id: {request_id})")
+
+         base_url = service_url()
+         base_host, base_port = extract_host_port(base_url)
+         await self._stream_logs_websocket(request_id, stop_event, host=base_host, port=base_port)
+
+     async def stream_metrics_async(self, request_id, stop_event, metrics_config):
+         """Async GPU/CPU metrics streaming (uses httpx.AsyncClient)."""
+         logger.debug(f"Starting async metrics for {self.service_name} (request_id={request_id})")
+
+         async def async_http_get(url, params):
+             try:
+                 # Named http_client to avoid shadowing the kubernetes `client` module import
+                 async with httpx.AsyncClient(timeout=5.0) as http_client:
+                     resp = await http_client.get(url, params=params)
+                     resp.raise_for_status()
+                     try:
+                         return resp.json()
+                     except json.JSONDecodeError:
+                         logger.debug(f"Non-JSON response from {url}: {resp.text[:100]}")
+                         return {}
+             except Exception as e:
+                 logger.debug(f"Async metrics request failed for {url} ({params}): {e}")
+                 return {}
+
+         async def async_sleep(seconds):
+             await asyncio.sleep(seconds)
+
+         await self._collect_metrics_async(stop_event, async_http_get, async_sleep, metrics_config)
+         logger.debug(f"Stopped async metrics for {request_id}")
+
+     def stream_metrics(self, stop_event, metrics_config: MetricsConfig = None):
+         """Synchronous GPU/CPU metrics streaming (uses requests)."""
+         logger.debug(f"Streaming metrics for {self.service_name}")
+         logger.debug(f"Using metrics config: {metrics_config}")
+
+         def sync_http_get(url, params):
+             try:
+                 resp = requests.get(url, params=params, timeout=5.0)
+                 resp.raise_for_status()
+                 try:
+                     return resp.json()
+                 except json.JSONDecodeError:
+                     logger.debug(f"Non-JSON response from {url}: {resp.text[:100]}")
+                     return {}
+             except Exception as e:
+                 logger.debug(f"Sync metrics request failed for {url} ({params}): {e}")
+                 return {}
+
+         def sync_sleep(seconds):
+             time.sleep(seconds)
+
+         self._collect_metrics(stop_event, sync_http_get, sync_sleep, metrics_config)
+
+     def call_method(
+         self,
+         endpoint: str,
+         stream_logs: Union[bool, None] = None,
+         stream_metrics: Union[bool, MetricsConfig, None] = None,
+         body: dict = None,
+         headers: dict = None,
+         pdb: Union[bool, int] = None,
+         serialization: str = "json",
+     ):
+         (
+             endpoint,
+             headers,
+             stop_event,
+             log_thread,
+             metrics_thread,
+             _,
+         ) = self._prepare_request(endpoint, stream_logs, stream_metrics, headers, pdb, serialization)
+         try:
+             json_data = _serialize_body(body, serialization)
+             response = self.post(endpoint=endpoint, json=json_data, headers=headers)
+             response.raise_for_status()
+             return _deserialize_response(response, serialization)
+         finally:
+             stop_event.set()
+
+     async def call_method_async(
+         self,
+         endpoint: str,
+         stream_logs: Union[bool, None] = None,
+         stream_metrics: Union[bool, MetricsConfig, None] = None,
+         body: dict = None,
+         headers: dict = None,
+         pdb: Union[bool, int] = None,
+         serialization: str = "json",
+     ):
+         """Async version of call_method."""
+         (
+             endpoint,
+             headers,
+             stop_event,
+             log_task,
+             metrics_task,
+             _,
+         ) = self._prepare_request_async(endpoint, stream_logs, stream_metrics, headers, pdb, serialization)
+         try:
+             json_data = _serialize_body(body, serialization)
+             response = await self.post_async(endpoint=endpoint, json=json_data, headers=headers)
+             response.raise_for_status()
+             result = _deserialize_response(response, serialization)
+
+             # Give the log stream a moment to flush before tearing it down
+             if stream_logs and log_task:
+                 await asyncio.sleep(0.5)
+
+             return result
+         finally:
+             stop_event.set()
+             if log_task:
+                 try:
+                     await asyncio.wait_for(log_task, timeout=0.5)
+                 except asyncio.TimeoutError:
+                     log_task.cancel()
+                     try:
+                         await log_task
+                     except asyncio.CancelledError:
+                         pass
+
+     def post(self, endpoint, json=None, headers=None):
+         return self._make_request("post", endpoint, json=json, headers=headers)
+
+     def put(self, endpoint, json=None, headers=None):
+         return self._make_request("put", endpoint, json=json, headers=headers)
+
+     def delete(self, endpoint, json=None, headers=None):
+         return self._make_request("delete", endpoint, json=json, headers=headers)
+
+     def get(self, endpoint, headers=None):
+         return self._make_request("get", endpoint, headers=headers)
+
+     async def post_async(self, endpoint, json=None, headers=None):
+         return await self._make_request_async("post", endpoint, json=json, headers=headers)
+
+     async def put_async(self, endpoint, json=None, headers=None):
+         return await self._make_request_async("put", endpoint, json=json, headers=headers)
+
+     async def delete_async(self, endpoint, json=None, headers=None):
+         return await self._make_request_async("delete", endpoint, json=json, headers=headers)
+
+     async def get_async(self, endpoint, headers=None):
+         return await self._make_request_async("get", endpoint, headers=headers)
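
For orientation, here is a minimal usage sketch assembled only from the signatures in the diff above. The constructor arguments, endpoint path, and body shape are hypothetical placeholders; in the package these are wired up by kubetorch's Compute and service machinery rather than constructed by hand.

    from kubetorch.servers.http.http_client import HTTPClient

    # Hypothetical wiring; real values come from the Compute layer.
    http = HTTPClient(
        base_url="http://localhost:8080",  # placeholder proxy URL
        compute=None,                      # local_port then falls back to DEFAULT_NGINX_PORT
        service_name="my-service",         # placeholder service name
    )

    # call_method serializes the body, POSTs it to the endpoint, surfaces any
    # packaged remote error via CustomResponse.raise_for_status (reconstructing
    # the original exception type and attaching remote_traceback/pod_name),
    # and deserializes the response. stream_logs=True would additionally tail
    # the request's logs from Loki in a daemon thread until the call returns.
    result = http.call_method(
        endpoint=f"{http.base_url}/my_method",  # hypothetical endpoint
        body={"args": [1, 2], "kwargs": {}},    # hypothetical body shape
        stream_logs=False,
    )
    http.close()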