kubetorch 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (93)
  1. kubetorch/__init__.py +60 -0
  2. kubetorch/cli.py +1985 -0
  3. kubetorch/cli_utils.py +1025 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +285 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +157 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +133 -0
  30. kubetorch/resources/callables/module.py +1416 -0
  31. kubetorch/resources/callables/utils.py +174 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +261 -0
  34. kubetorch/resources/compute/compute.py +2596 -0
  35. kubetorch/resources/compute/decorators.py +139 -0
  36. kubetorch/resources/compute/rbac.py +74 -0
  37. kubetorch/resources/compute/utils.py +1114 -0
  38. kubetorch/resources/compute/websocket.py +137 -0
  39. kubetorch/resources/images/__init__.py +1 -0
  40. kubetorch/resources/images/image.py +414 -0
  41. kubetorch/resources/images/images.py +74 -0
  42. kubetorch/resources/secrets/__init__.py +2 -0
  43. kubetorch/resources/secrets/kubernetes_secrets_client.py +412 -0
  44. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  45. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  46. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  47. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  48. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  49. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  50. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  51. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  52. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  53. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  54. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  55. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  56. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  57. kubetorch/resources/secrets/provider_secrets/providers.py +93 -0
  58. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  59. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  60. kubetorch/resources/secrets/secret.py +238 -0
  61. kubetorch/resources/secrets/secret_factory.py +70 -0
  62. kubetorch/resources/secrets/utils.py +209 -0
  63. kubetorch/resources/volumes/__init__.py +0 -0
  64. kubetorch/resources/volumes/volume.py +365 -0
  65. kubetorch/servers/__init__.py +0 -0
  66. kubetorch/servers/http/__init__.py +0 -0
  67. kubetorch/servers/http/distributed_utils.py +3223 -0
  68. kubetorch/servers/http/http_client.py +730 -0
  69. kubetorch/servers/http/http_server.py +1788 -0
  70. kubetorch/servers/http/server_metrics.py +278 -0
  71. kubetorch/servers/http/utils.py +728 -0
  72. kubetorch/serving/__init__.py +0 -0
  73. kubetorch/serving/autoscaling.py +173 -0
  74. kubetorch/serving/base_service_manager.py +363 -0
  75. kubetorch/serving/constants.py +83 -0
  76. kubetorch/serving/deployment_service_manager.py +478 -0
  77. kubetorch/serving/knative_service_manager.py +519 -0
  78. kubetorch/serving/raycluster_service_manager.py +582 -0
  79. kubetorch/serving/service_manager.py +18 -0
  80. kubetorch/serving/templates/deployment_template.yaml +17 -0
  81. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  82. kubetorch/serving/templates/kt_setup_template.sh.j2 +81 -0
  83. kubetorch/serving/templates/pod_template.yaml +194 -0
  84. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  85. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  86. kubetorch/serving/templates/service_template.yaml +21 -0
  87. kubetorch/serving/templates/workerset_template.yaml +36 -0
  88. kubetorch/serving/utils.py +377 -0
  89. kubetorch/utils.py +284 -0
  90. kubetorch-0.2.0.dist-info/METADATA +121 -0
  91. kubetorch-0.2.0.dist-info/RECORD +93 -0
  92. kubetorch-0.2.0.dist-info/WHEEL +4 -0
  93. kubetorch-0.2.0.dist-info/entry_points.txt +5 -0
kubetorch/cli_utils.py ADDED
@@ -0,0 +1,1025 @@
+ import asyncio
+ import base64
+ import hashlib
+ import json
+ import os
+ import signal
+
+ import subprocess
+ import threading
+ import time
+ import urllib.parse
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError
+ from contextlib import contextmanager
+ from datetime import datetime, timedelta
+ from enum import Enum
+ from pathlib import Path
+ from typing import List, Optional
+
+ import httpx
+ import typer
+ import yaml
+ from kubernetes import client
+ from kubernetes.client.rest import ApiException
+ from pydantic import BaseModel
+ from rich import box
+ from rich.console import Console
+ from rich.style import Style
+ from rich.table import Table
+ from websocket import create_connection
+
+ import kubetorch.serving.constants as serving_constants
+
+ from kubetorch import globals
+ from kubetorch.config import KubetorchConfig
+ from kubetorch.constants import MAX_PORT_TRIES
+
+ from kubetorch.resources.compute.utils import is_port_available
+ from kubetorch.servers.http.utils import stream_logs_websocket_helper, StreamType
+ from kubetorch.serving.utils import wait_for_port_forward
+ from kubetorch.utils import load_kubeconfig
+
+ from .constants import BULLET_UNICODE, CPU_RATE, DOUBLE_SPACE_UNICODE, GPU_RATE
+
+ from .logger import get_logger
+
+ console = Console()
+
+ logger = get_logger(__name__)
+
+ OTEL_ERROR_MSG = (
+     "[red]Grafana setup failed. Is `kubetorch-otel` installed? See "
+     "https://www.run.house/kubetorch/advanced-installation/#kubetorch-telemetry-helm-chart for more info.[/red]"
+ )
+
+
+ # ------------------ Billing helpers --------------------
+ class UsageData(BaseModel):
+     date_start: str
+     date_end: str
+     cpu_hours: float
+     gpu_hours: float
+
+
+ class BillingTotals(BaseModel):
+     cpu: float
+     gpu: float
+
+
+ class BillingCosts(BaseModel):
+     cpu: float
+     gpu: float
+
+
+ class BillingRequest(BaseModel):
+     license_key: str
+     signature: str
+     file_name: str
+     username: Optional[str] = None
+     usage_data: UsageData
+     totals: BillingTotals
+     costs: BillingCosts
+
+
+ # ------------------ Generic helpers --------------------
+ class VolumeAction(str, Enum):
+     list = "list"
+     create = "create"
+     delete = "delete"
+     ssh = "ssh"
+
+
+ class SecretAction(str, Enum):
+     list = "list"
+     create = "create"
+     delete = "delete"
+     describe = "describe"
+
+
+ def default_typer_values(*args):
+     """Convert typer model arguments to their default values or types, so the CLI commands can also be
+     imported and called in Python if desired."""
+     new_args = []
+     for arg in args:
+         if isinstance(arg, (typer.models.OptionInfo, typer.models.ArgumentInfo)):
+             # Replace the typer model with its value
+             arg = arg.default if arg.default is not None else arg.type
+         new_args.append(arg)
+     return new_args
+
+
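+ # A minimal usage sketch (hypothetical command): normalize unset typer
+ # Option/Argument models so the decorated function can also be called from Python:
+ #
+ #     def ls(namespace=typer.Option("default")):
+ #         (namespace,) = default_typer_values(namespace)
+ #         ...
+
+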
+ def validate_config_key(key: str = None):
+     if key is None:
+         return
+
+     valid_keys = {
+         name
+         for name, attr in vars(KubetorchConfig).items()
+         if isinstance(attr, property)
+     }
+     if key not in valid_keys:
+         raise typer.BadParameter(f"Valid keys are: {', '.join(sorted(valid_keys))}")
+     return key
+
+
+ def get_pods_for_service_cli(name: str, namespace: str, core_api):
+     """Get pods for a service using unified label selector."""
+     # Use unified service label - works for all deployment modes
+     label_selector = f"kubetorch.com/service={name}"
+     return core_api.list_namespaced_pod(
+         namespace=namespace,
+         label_selector=label_selector,
+     )
+
+
+ def service_name_argument(*args, required: bool = True, **kwargs):
+     def _lowercase(value: str) -> str:
+         return value.lower() if value else value
+
+     default = ... if required else ""
+     return typer.Argument(default, callback=_lowercase, *args, **kwargs)
+
+
+ def get_deployment_mode(name: str, namespace: str, custom_api, apps_v1_api) -> tuple:
+     """Validate that the service exists and return its (possibly username-prefixed) name and deployment mode."""
+     try:
+         original_name = name
+         deployment_mode = detect_deployment_mode(
+             name, namespace, custom_api, apps_v1_api
+         )
+         # If service not found and not already prefixed with username, try with username prefix
+         if (
+             not deployment_mode
+             and globals.config.username
+             and not name.startswith(globals.config.username + "-")
+         ):
+             name = f"{globals.config.username}-{name}"
+             deployment_mode = detect_deployment_mode(
+                 name, namespace, custom_api, apps_v1_api
+             )
+
+         if not deployment_mode:
+             console.print(
+                 f"[red]Failed to load service [bold]{original_name}[/bold] in namespace {namespace}[/red]"
+             )
+             raise typer.Exit(1)
+         console.print(
+             f"Found [green]{deployment_mode}[/green] service [blue]{name}[/blue]"
+         )
+         return name, deployment_mode
+
+     except ApiException as e:
+         console.print(f"[red]Kubernetes API error: {e}[/red]")
+         raise typer.Exit(1)
+
+
+ def validate_pods_exist(name: str, namespace: str, core_api) -> list:
+     """Validate pods exist for service and return pod list."""
+     pods = get_pods_for_service_cli(name, namespace, core_api)
+     if not pods.items:
+         console.print(
+             f"\n[red]No pods found for service {name} in namespace {namespace}[/red]"
+         )
+         console.print(
+             f"You can view the service's status using:\n [yellow] kt status {name}[/yellow]"
+         )
+         raise typer.Exit(1)
+     return pods.items
+
+
+ @contextmanager
+ def port_forward_to_pod(
+     pod_name,
+     namespace: str = None,
+     local_port: int = 8080,
+     remote_port: int = serving_constants.DEFAULT_NGINX_PORT,
+     health_endpoint: str = None,
+ ):
+     load_kubeconfig()
+     for attempt in range(MAX_PORT_TRIES):
+         candidate_port = local_port + attempt
+         if not is_port_available(candidate_port):
+             logger.debug(
+                 f"Local port {candidate_port} is already in use, trying again..."
+             )
+             continue
+
+         cmd = [
+             "kubectl",
+             "port-forward",
+             f"pod/{pod_name}",
+             f"{candidate_port}:{remote_port}",
+             "--namespace",
+             namespace,
+         ]
+         logger.debug(f"Running port-forward command: {' '.join(cmd)}")
+
+         process = subprocess.Popen(
+             cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, start_new_session=True
+         )
+
+         try:
+             wait_for_port_forward(
+                 process,
+                 candidate_port,
+                 health_endpoint=health_endpoint,
+                 validate_kubetorch_versions=False,
+             )
+             time.sleep(2)
+             yield candidate_port
+             return
+
+         finally:
+             if process:
+                 try:
+                     os.killpg(os.getpgid(process.pid), signal.SIGTERM)
+                     process.wait()
+                 except (ProcessLookupError, OSError):
+                     # Process may have already terminated
+                     pass
+
+     raise RuntimeError(f"Could not find an available local port after {MAX_PORT_TRIES} attempts")
+
+
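+ # A minimal usage sketch (hypothetical pod name): the context manager yields the
+ # local port it managed to bind and tears down the kubectl process on exit:
+ #
+ #     with port_forward_to_pod("my-service-abc123", namespace="default") as port:
+ #         resp = httpx.get(f"http://localhost:{port}/health")
+
+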
+ def get_last_updated(pod):
+     conditions = pod["status"].get("conditions", [])
+     latest = max(
+         (
+             c.get("lastTransitionTime")
+             for c in conditions
+             if c.get("lastTransitionTime")
+         ),
+         default="",
+     )
+     return latest
+
+
+ # ------------------ Reporting helpers --------------------
+ def upload_report(
+     usage_data: dict,
+     signature: str,
+     costs: BillingCosts,
+     totals: BillingTotals,
+     file_name: str,
+     license_key: str,
+     username: str = None,
+ ):
+     billing_request = BillingRequest(
+         license_key=license_key,
+         signature=signature,
+         file_name=file_name,
+         username=username,
+         usage_data=UsageData(**usage_data),
+         costs=costs,
+         totals=totals,
+     )
+
+     url = "https://auth.run.house/v1/billing/report"
+     resp = httpx.post(url, json=billing_request.model_dump())
+     if resp.status_code != 200:
+         console.print("[red]Failed to send billing report[/red]")
+         raise typer.Exit(1)
+
+
+ def export_report_pdf(report_data, filename):
+     try:
+         from reportlab.lib import colors
+         from reportlab.lib.pagesizes import letter
+         from reportlab.pdfgen import canvas
+     except ImportError:
+         console.print(
+             "[red]ReportLab is required for downloading the report. Please install it "
+             "with `pip install reportlab`.[/red]"
+         )
+         raise typer.Exit(1)
+
+     usage_data: dict = report_data["usage_report"]
+     report_str = json.dumps(report_data, sort_keys=True)
+     signature = base64.b64encode(hashlib.sha256(report_str.encode()).digest()).decode()
+
+     c = canvas.Canvas(filename, pagesize=letter)
+     width, height = letter
+
+     # Sidebar
+     sidebar_color = colors.HexColor("#4B9CD3")
+     c.setFillColor(sidebar_color)
+     c.roundRect(0, 0, 18, height, 0, fill=1, stroke=0)
+     c.setFillColor(colors.black)
+
+     y = height - 60
+
+     # Header Title
+     c.setFont("Helvetica-Bold", 26)
+     c.setFillColor(sidebar_color)
+     c.drawCentredString(width / 2, y, "Kubetorch Usage Report")
+     c.setFillColor(colors.black)
+     y -= 30
+
+     # Header Bar
+     c.setStrokeColor(sidebar_color)
+     c.setLineWidth(2)
+     c.line(40, y, width - 40, y)
+     y -= 20
+
+     # Info Box
+     c.setFillColor(colors.whitesmoke)
+     c.roundRect(40, y - 60, width - 80, 60, 8, fill=1, stroke=0)
+     c.setFillColor(colors.black)
+     c.setFont("Helvetica-Bold", 12)
+     c.drawString(55, y - 20, "Username:")
+     c.drawString(55, y - 35, "Cluster:")
+     c.setFont("Helvetica", 12)
+     c.drawString(130, y - 20, report_data["username"])
+     c.drawString(130, y - 35, report_data.get("cluster_name", "N/A"))
+     y -= 100
+
+     # Usage Summary Section
+     c.setFont("Helvetica-Bold", 15)
+     c.setFillColor(sidebar_color)
+     c.drawString(40, y, "Usage Summary")
+     c.setFillColor(colors.black)
+     y -= 25
+
+     # Table Outline (dashed)
+     table_left = 40
+     table_width = width - 80
+     row_height = 18
+     num_rows = 2  # header + data
+     table_height = row_height * num_rows
+
+     # Table Header (centered text)
+     header_height = row_height
+     c.setFillColor(sidebar_color)
+     c.roundRect(
+         table_left, y - header_height, table_width, header_height, 4, fill=1, stroke=0
+     )
+     c.setFont("Helvetica-Bold", 11)
+     c.setFillColor(colors.white)
+     header_y = y - header_height + 5
+     c.drawString(table_left + 10, header_y, "Start Date")
+     c.drawString(table_left + 90, header_y, "End Date")
+     c.drawString(table_left + 200, header_y, "vCPU Hours")
+     c.drawString(table_left + 300, header_y, "GPU Hours")
+     c.setFillColor(colors.black)
+
+     # Dashed outline (starts at header, not above)
+     c.setStrokeColor(sidebar_color)
+     c.setDash(4, 4)
+     c.roundRect(
+         table_left, y - table_height, table_width, table_height, 6, fill=0, stroke=1
+     )
+     c.setDash()
+     y -= header_height
+
+     # Table Rows
+     c.setFont("Helvetica", 10)
+     y -= row_height
+     c.drawString(table_left + 10, y + 5, usage_data["date_start"])
+     c.drawString(table_left + 90, y + 5, usage_data["date_end"])
+     c.drawRightString(table_left + 270, y + 5, f"{usage_data['cpu_hours']:.2f}")
+     c.drawRightString(table_left + 370, y + 5, f"{usage_data['gpu_hours']:.2f}")
+
+     y -= 30
+
+     # Invoice Calculation
+     total_cpu = usage_data["cpu_hours"]
+     total_gpu = usage_data["gpu_hours"]
+     cpu_cost = total_cpu * CPU_RATE
+     gpu_cost = total_gpu * GPU_RATE
+     total_cost = cpu_cost + gpu_cost
+
+     y -= 20
+     c.setFont("Helvetica-Bold", 13)
+     c.setFillColor(sidebar_color)
+     c.drawString(40, y, "Invoice Summary")
+     c.setFillColor(colors.black)
+     y -= 18
+
+     c.setFont("Helvetica", 11)
+     c.drawString(50, y, f"Total vCPU Hours: {total_cpu:.2f} @ ${CPU_RATE:.2f}/hr")
+     c.drawRightString(width - 50, y, f"${cpu_cost:.2f}")
+     y -= 15
+     c.drawString(50, y, f"Total GPU Hours: {total_gpu:.2f} @ ${GPU_RATE:.2f}/hr")
+     c.drawRightString(width - 50, y, f"${gpu_cost:.2f}")
+     y -= 15
+
+     line_left = 50
+     line_right = width - 50
+     c.setStrokeColor(sidebar_color)
+     c.setLineWidth(1.5)
+     c.line(line_left, y, line_right, y)
+     y -= 15
+
+     c.setFont("Helvetica-Bold", 12)
+     c.drawString(50, y, "Total Due:")
+     c.setFillColor(colors.HexColor("#008000"))
+     c.drawRightString(width - 50, y, f"${total_cost:.2f}")
+     c.setFillColor(colors.black)
+     y -= 30
+
+     # Signature and footer
+     sig_y = 80
+     sig_val_y = sig_y - 15
+     footer_y = sig_val_y - 40
+
+     # Signature at the bottom
+     c.setFont("Helvetica-Bold", 12)
+     c.setFillColor(colors.black)
+     c.drawString(40, sig_y, "Signature:")
+     c.setFont("Courier-Oblique", 8)
+     c.setFillColor(colors.HexColor("#888888"))
+     c.drawString(40, sig_val_y, signature)
+     c.setFillColor(colors.black)
+
+     # Footer
+     c.setFont("Helvetica-Oblique", 8)
+     c.setFillColor(colors.HexColor("#888888"))
+     c.drawString(
+         40, footer_y, f"Generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
+     )
+     c.setFillColor(colors.black)
+
+     c.save()
+     return signature
+
+
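+ # A sketch of how the embedded signature could be verified out of band, assuming
+ # the verifier holds the same report_data dict that was used to render the PDF:
+ #
+ #     report_str = json.dumps(report_data, sort_keys=True)
+ #     expected = base64.b64encode(hashlib.sha256(report_str.encode()).digest()).decode()
+ #     assert expected == signature_from_pdf
+
+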
+ def print_usage_table(usage_data, cluster_name):
+     table = Table(title="Usage Summary")
+     table.add_column("Start Date")
+     table.add_column("End Date")
+     table.add_column("vCPU Hours")
+     table.add_column("GPU Hours")
+     table.add_row(
+         usage_data["date_start"],
+         usage_data["date_end"],
+         str(usage_data["cpu_hours"]),
+         str(usage_data["gpu_hours"]),
+     )
+     console.print(table)
+     console.print(f"[dim]Cluster: {str(cluster_name)}[/dim]")
+
+
+ def get_last_n_calendar_weeks(n_weeks):
+     """Return a list of (week_start, week_end) tuples for the last n full calendar weeks (Mon–Sun),
+     not including this week."""
+     today = datetime.utcnow().date()
+     # Find the Monday of the last full week (never this week's Monday, even if today is Monday)
+     last_monday = today - timedelta(days=today.weekday() + 7)
+
+     weeks = []
+     for i in range(n_weeks):
+         week_start = last_monday - timedelta(weeks=i)
+         week_end = week_start + timedelta(days=6)  # Monday + 6 = Sunday
+         weeks.append((week_start, week_end))
+
+     weeks.reverse()  # So oldest week is first
+     return weeks
+
+
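+ # Worked example: if today is Wednesday 2024-05-15, the last full week is
+ # Mon 2024-05-06 .. Sun 2024-05-12, so get_last_n_calendar_weeks(2) returns
+ # [(2024-04-29, 2024-05-05), (2024-05-06, 2024-05-12)] as date objects.
+
+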
+ def get_usage_data(prom, weeks):
+     days = weeks * 7
+     end_time = datetime.now()
+     start_time = end_time - timedelta(days=days)
+
+     # sum of CPU-seconds used by all cores for that container (ex: 2 cores for 1 second = 2 seconds)
+     cpu_query = (
+         f'increase(container_cpu_usage_seconds_total{{container="kubetorch"}}[{days}d])'
+     )
+     cpu_result = prom.custom_query(cpu_query)
+
+     total_cpu_seconds = 0
+     if cpu_result:
+         for series in cpu_result:
+             cpu_val = float(series["value"][1])
+             total_cpu_seconds += cpu_val
+
+     cpu_hours = round(total_cpu_seconds / 3600, 2)
+
+     # requested GPUs × time they were running
+     gpu_query = f'sum_over_time(kube_pod_container_resource_requests{{resource="nvidia_com_gpu", container="kubetorch"}}[{days}d])'
+     gpu_result = prom.custom_query(gpu_query)
+
+     # Convert to "GPU hours" over the period
+     total_gpu_seconds = sum(float(s["value"][1]) for s in gpu_result or [])
+     gpu_hours = total_gpu_seconds / 3600
+
+     usage = {
+         "date_start": start_time.strftime("%Y-%m-%d"),
+         "date_end": end_time.strftime("%Y-%m-%d"),
+         "cpu_hours": cpu_hours,
+         "gpu_hours": round(gpu_hours, 2),
+     }
+
+     return usage
+
+
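+ # Worked example of the CPU conversion above: a pod that runs 2 cores flat out
+ # for 90 minutes accumulates 2 * 90 * 60 = 10,800 CPU-seconds, i.e.
+ # 10800 / 3600 = 3.0 vCPU hours in the report.
+
+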
+ # ------------------ Monitoring helpers --------------------
+ def get_service_metrics(prom, pod_name: str, pod_node: str, running_on_gpu: bool):
+     """Get CPU, GPU (if relevant) and memory metrics for a pod"""
+
+     def extract_prometheus_metric_value(query) -> float:
+         result = prom.custom_query(query=query)
+         return float(result[0].get("value")[1]) if result else 0
+
+     # --- CPU metrics --- #
+     cpu_query_time_window = "30s"
+     used_cpu_query = (
+         f"sum(rate(container_cpu_usage_seconds_total{{pod='{pod_name}', "
+         f"container='kubetorch'}}[{cpu_query_time_window}]))"
+     )
+     requested_cpu_query = (
+         f"sum(kube_pod_container_resource_requests{{pod='{pod_name}', "
+         f"resource='cpu', container='kubetorch'}})"
+     )
+
+     used_cpu_result = extract_prometheus_metric_value(used_cpu_query)
+     requested_cpu_result = extract_prometheus_metric_value(requested_cpu_query)
+
+     cpu_util = (
+         round((100 * (used_cpu_result / requested_cpu_result)), 3)
+         if used_cpu_result and requested_cpu_result
+         else 0
+     )
+
+     memory_usage_query = f"container_memory_usage_bytes{{pod='{pod_name}', container='kubetorch'}} / 1073741824"
+     memory_usage = round(extract_prometheus_metric_value(memory_usage_query), 3)
+
+     machine_mem_query = (
+         f"machine_memory_bytes{{node='{pod_node}'}} / 1073741824"  # convert to GB
+     )
+     machine_mem_result = extract_prometheus_metric_value(machine_mem_query)
+
+     cpu_mem_percent = (
+         round((memory_usage / machine_mem_result) * 100, 3) if machine_mem_result else 0
+     )
+     collected_metrics = {
+         "cpu_util": cpu_util,
+         "used_cpu": round(used_cpu_result, 4),
+         "requested_cpu": round(requested_cpu_result, 4),
+         "cpu_memory_usage": memory_usage,
+         "cpu_memory_total": round(machine_mem_result, 3),
+         "cpu_memory_usage_percent": cpu_mem_percent,
+     }
+
+     # --- GPU metrics --- #
+     if running_on_gpu:
+         gpu_util_query = f"DCGM_FI_DEV_GPU_UTIL{{exported_pod='{pod_name}', exported_container='kubetorch'}}"
+         gpu_mem_used_query = (
+             f"DCGM_FI_DEV_FB_USED{{exported_pod='{pod_name}', "
+             f"exported_container='kubetorch'}} * 1.048576 / 1000"
+         )  # convert MiB to MB to GB
+         gpu_mem_free_query = (
+             f"DCGM_FI_DEV_FB_FREE{{exported_pod='{pod_name}', "
+             f"exported_container='kubetorch'}} * 1.048576 / 1000"
+         )  # convert MiB to MB to GB
+
+         gpu_util = extract_prometheus_metric_value(gpu_util_query)
+         gpu_mem_used = round(extract_prometheus_metric_value(gpu_mem_used_query), 3)
+         gpu_mem_free = extract_prometheus_metric_value(gpu_mem_free_query)
+         gpu_mem_total = gpu_mem_free + gpu_mem_used
+         gpu_mem_percent = (
+             round(100 * (gpu_mem_used / gpu_mem_total), 2) if gpu_mem_used else 0
+         )  # raw approximation, because total_allocated_gpu_memory is not collected
+
+         gpu_metrics = {
+             "gpu_util": gpu_util,
+             "gpu_memory_usage": gpu_mem_used,
+             "gpu_memory_total": round(gpu_mem_total, 3),
+             "gpu_memory_usage_percent": gpu_mem_percent,
+         }
+
+         collected_metrics.update(gpu_metrics)
+
+     return collected_metrics
+
+
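+ # Shape of the returned dict (illustrative values; GPU keys only when running_on_gpu):
+ #
+ #     {"cpu_util": 42.5, "used_cpu": 0.85, "requested_cpu": 2.0,
+ #      "cpu_memory_usage": 1.2, "cpu_memory_total": 15.6,
+ #      "cpu_memory_usage_percent": 7.692,
+ #      "gpu_util": 88.0, "gpu_memory_usage": 10.5, "gpu_memory_total": 16.0,
+ #      "gpu_memory_usage_percent": 65.63}
+
+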
+ def get_current_cluster_name():
+     try:
+         from kubernetes import config as k8s_config
+
+         k8s_config.load_incluster_config()
+         # In-cluster: return a generic name or the service host
+         return os.environ.get("CLUSTER_NAME", "in-cluster")
+     except Exception:
+         pass
+
+     # Fallback to kubeconfig file
+     kubeconfig_path = os.getenv("KUBECONFIG") or str(Path.home() / ".kube" / "config")
+     if not os.path.exists(kubeconfig_path):
+         return None
+
+     with open(kubeconfig_path, "r") as f:
+         kubeconfig = yaml.safe_load(f)
+         current_context = kubeconfig.get("current-context")
+         for context in kubeconfig.get("contexts", []):
+             if context["name"] == current_context:
+                 return context["context"]["cluster"]
+     return None
+
+
+ def print_pod_info(pod_name, pod_idx, is_gpu, metrics=None, queue_name=None):
+     """Print pod info with metrics if available"""
+     queue_msg = f" | [bold]Queue Name[/bold]: {queue_name}"
+     base_msg = (
+         f"{BULLET_UNICODE} [reset][bold cyan]{pod_name}[/bold cyan] (idx: {pod_idx})"
+     )
+     if queue_name:
+         base_msg += queue_msg
+     console.print(base_msg)
+     if metrics:
+         console.print(
+             f"{DOUBLE_SPACE_UNICODE}[bold]CPU[/bold]: [reset]{metrics['cpu_util']}% "
+             f"({metrics['used_cpu']} / {metrics['requested_cpu']}) | "
+             f"[bold]Memory[/bold]: {metrics['cpu_memory_usage']} / {metrics['cpu_memory_total']} "
+             f"[bold]GB[/bold] ({metrics['cpu_memory_usage_percent']}%)"
+         )
+         if is_gpu:
+             console.print(
+                 f"{DOUBLE_SPACE_UNICODE}GPU: [reset]{metrics['gpu_util']}% | "
+                 f"Memory: {metrics['gpu_memory_usage']} / {metrics['gpu_memory_total']} "
+                 f"GB ({metrics['gpu_memory_usage_percent']}%)"
+             )
+     else:
+         console.print(f"{DOUBLE_SPACE_UNICODE}[yellow]Metrics unavailable[/yellow]")
+
+
+ def _get_logs_from_loki_worker(uri: str, print_pod_name: bool):
+     """Worker function for getting logs from Loki - runs in a separate thread."""
+     ws = None
+     try:
+         ws = create_connection(uri)
+         message = ws.recv()
+         if not message:
+             return None
+         data = json.loads(message)
+         logs = []
+         if data.get("streams"):
+             for stream in data["streams"]:
+                 pod_name = (
+                     f'({stream.get("stream").get("pod")}) ' if print_pod_name else ""
+                 )
+                 for value in stream.get("values"):
+                     try:
+                         log_line = json.loads(value[1])
+                         log_name = log_line.get("name")
+                         if log_name == "print_redirect":
+                             logs.append(f'{pod_name}{log_line.get("message")}')
+                         elif log_name != "uvicorn.access":
+                             formatted_log = (
+                                 f"{pod_name}{log_line.get('asctime')} | {log_line.get('levelname')} | "
+                                 f"{log_line.get('message')}\n"
+                             )
+                             logs.append(formatted_log)
+                     except Exception:
+                         logs.append(value[1])
+         return logs
+     finally:
+         if ws:
+             try:
+                 ws.close()
+             except Exception:
+                 pass
+
+
+ def get_logs_from_loki(
+     query: str = None,
+     uri: str = None,
+     print_pod_name: bool = False,
+     timeout: float = 5.0,
+ ):
+     """Get logs from Loki with a fail-fast approach to avoid hanging."""
+     try:
+         # If a URI is provided, use it directly (skip cluster checks)
+         if uri:
+             return _get_logs_from_loki_worker(uri, print_pod_name)
+
+         # Now safe to proceed with service URL setup
+         from kubetorch.utils import http_to_ws
+
+         base_url = globals.service_url()
+         target_uri = f"{http_to_ws(base_url)}/loki/api/v1/tail?query={urllib.parse.quote_plus(query)}"
+
+         # Use a thread timeout for the websocket worker since websocket timeouts don't work reliably
+         executor = ThreadPoolExecutor(max_workers=1)
+         try:
+             future = executor.submit(
+                 _get_logs_from_loki_worker, target_uri, print_pod_name
+             )
+             try:
+                 result = future.result(timeout=timeout)
+                 return result
+             except TimeoutError:
+                 logger.debug(f"Loki websocket connection timed out after {timeout}s")
+                 return None
+             except Exception as e:
+                 logger.debug(f"Error in Loki websocket worker: {e}")
+                 return None
+         finally:
+             # Don't wait for stuck threads to complete
+             executor.shutdown(wait=False)
+
+     except Exception as e:
+         logger.debug(f"Error getting logs from Loki: {e}")
+         return None
+
+
+ def stream_logs_websocket(uri, stop_event, service_name, print_pod_name: bool = False):
+     """Stream logs using Loki's websocket tail endpoint"""
+
+     console.print(f"\nFollowing logs of [reset]{service_name}\n")
+
+     # Create a dedicated event loop for the websocket stream and run it to completion
+     loop = asyncio.new_event_loop()
+     asyncio.set_event_loop(loop)
+     try:
+         loop.run_until_complete(
+             stream_logs_websocket_helper(
+                 uri=uri,
+                 stop_event=stop_event,
+                 stream_type=StreamType.CLI,
+                 print_pod_name=print_pod_name,
+             )
+         )
+     finally:
+         loop.close()
+         # Signal the log thread to stop
+         stop_event.set()
+         # Don't wait for the log thread - it will handle its own cleanup
+
+
+ def get_logs_query(name: str, namespace: str, selected_pod: str, deployment_mode):
+     if not selected_pod:
+         if deployment_mode in ["knative", "deployment"]:
+             # we need to get the pod names first since Loki doesn't have a service_name label
+             core_api = client.CoreV1Api()
+             pods = validate_pods_exist(name, namespace, core_api)
+             pod_names = [pod.metadata.name for pod in pods]
+             return f'{{k8s_pod_name=~"{"|".join(pod_names)}",k8s_container_name="kubetorch"}} | json'
+         else:
+             console.print(
+                 f"[red]Log streaming does not support deployment mode: {deployment_mode}[/red]"
+             )
+             return None
+     else:
+         return (
+             f'{{k8s_pod_name=~"{selected_pod}",k8s_container_name="kubetorch"}} | json'
+         )
+
+
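+ # For two pods "svc-0" and "svc-1", the generated LogQL tail query would be:
+ #
+ #     {k8s_pod_name=~"svc-0|svc-1",k8s_container_name="kubetorch"} | json
+
+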
+ def follow_logs_in_cli(
+     name: str,
+     namespace: str,
+     selected_pod: str,
+     deployment_mode,
+     print_pod_name: bool = False,
+ ):
+     """Stream logs when triggered by the CLI command."""
+     from kubetorch.utils import http_to_ws
+
+     stop_event = threading.Event()
+
+     # Set up signal handler to cleanly stop on Ctrl+C
+     def signal_handler(signum, frame):
+         stop_event.set()
+         raise KeyboardInterrupt()
+
+     original_handler = signal.signal(signal.SIGINT, signal_handler)
+
+     # Set up the query
+     query = get_logs_query(name, namespace, selected_pod, deployment_mode)
+     if not query:
+         return
+     encoded_query = urllib.parse.quote_plus(query)
+
+     base_url = globals.service_url()
+     uri = f"{http_to_ws(base_url)}/loki/api/v1/tail?query={encoded_query}"
+
+     try:
+         stream_logs_websocket(
+             uri=uri,
+             stop_event=stop_event,
+             service_name=name,
+             print_pod_name=print_pod_name,
+         )
+     finally:
+         # Restore original signal handler
+         signal.signal(signal.SIGINT, original_handler)
+
+
+ def is_ingress_vpc_only(annotations: dict):
+     # Check for internal LoadBalancer annotations
+     internal_checks = [
+         annotations.get("service.beta.kubernetes.io/aws-load-balancer-internal")
+         == "true",
+         annotations.get("networking.gke.io/load-balancer-type") == "Internal",
+         annotations.get("service.beta.kubernetes.io/oci-load-balancer-internal")
+         == "true",
+     ]
+
+     vpc_only = any(internal_checks)
+     return vpc_only
+
+
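+ # Example: an AWS internal load balancer annotation is flagged as VPC-only:
+ #
+ #     is_ingress_vpc_only(
+ #         {"service.beta.kubernetes.io/aws-load-balancer-internal": "true"}
+ #     )  # -> True
+
+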
+ def load_ingress(namespace: str = globals.config.install_namespace):
+     """Return the kubetorch-proxy-ingress Ingress in the given namespace, or None if not found."""
+     networking_v1_api = client.NetworkingV1Api()
+     ingresses = networking_v1_api.list_namespaced_ingress(namespace=namespace)
+
+     for ingress in ingresses.items:
+         if ingress.metadata.name == "kubetorch-proxy-ingress":
+             return ingress
+
+
+ def get_ingress_host(ingress):
+     """Get the configured host from the kubetorch ingress."""
+     try:
+         return ingress.spec.rules[0].host
+     except Exception:
+         return None
+
+
+ def list_all_queues():
+     try:
+         custom_api = client.CustomObjectsApi()
+         queues = custom_api.list_cluster_custom_object(
+             group="scheduling.run.ai",
+             version="v2",
+             plural="queues",
+         )["items"]
+
+         if not queues:
+             console.print("[yellow]No queues found in the cluster[/yellow]")
+             return
+
+         # Insert "default" queue if missing
+         if not any(q["metadata"]["name"] == "default" for q in queues):
+             default_children = [
+                 q["metadata"]["name"]
+                 for q in queues
+                 if q.get("spec", {}).get("parentQueue") == "default"
+             ]
+             queues.insert(
+                 0,
+                 {
+                     "metadata": {"name": "default"},
+                     "spec": {
+                         "parentQueue": "-",
+                         "children": default_children,
+                         "resources": {
+                             "cpu": {"quota": "-", "overQuotaWeight": "-"},
+                             "gpu": {"quota": "-", "overQuotaWeight": "-"},
+                             "memory": {"quota": "-", "overQuotaWeight": "-"},
+                         },
+                         "priority": "-",
+                     },
+                 },
+             )
+
+         queue_table = Table(title="Available Queues", header_style=Style(bold=True))
+         queue_table.add_column("QUEUE NAME", style="cyan")
+         queue_table.add_column("PRIORITY", style="magenta")
+         queue_table.add_column("PARENT", style="green")
+         queue_table.add_column("CHILDREN", style="yellow")
+         queue_table.add_column("CPU QUOTA", style="white")
+         queue_table.add_column("GPU QUOTA", style="white")
+         queue_table.add_column("MEMORY QUOTA", style="white")
+         queue_table.add_column("OVERQUOTA WEIGHT", style="blue")
+
+         for q in queues:
+             spec = q.get("spec", {})
+             resources = spec.get("resources", {})
+             cpu = resources.get("cpu", {})
+             gpu = resources.get("gpu", {})
+             memory = resources.get("memory", {})
+
+             queue_table.add_row(
+                 q["metadata"]["name"],
+                 str(spec.get("priority", "-")),
+                 spec.get("parentQueue", "-"),
+                 ", ".join(spec.get("children", [])) or "-",
+                 str(cpu.get("quota", "-")),
+                 str(gpu.get("quota", "-")),
+                 str(memory.get("quota", "-")),
+                 str(
+                     cpu.get("overQuotaWeight", "-")
+                 ),  # use CPU's overQuotaWeight as example
+             )
+
+         console.print(queue_table)
+         return
+
+     except client.exceptions.ApiException as e:
+         console.print(f"[red]Failed to list queues: {e}[/red]")
+         raise typer.Exit(1)
+
+
+ def detect_deployment_mode(name: str, namespace: str, custom_api, apps_v1_api):
+     """Detect if a service is deployed as Knative, Deployment, or RayCluster."""
+     # First try Deployment
+     try:
+         apps_v1_api.read_namespaced_deployment(name=name, namespace=namespace)
+         return "deployment"
+     except ApiException:
+         pass
+
+     # Then try Knative
+     try:
+         custom_api.get_namespaced_custom_object(
+             group="serving.knative.dev",
+             version="v1",
+             namespace=namespace,
+             plural="services",
+             name=name,
+         )
+         return "knative"
+     except ApiException:
+         pass
+
+     # Then try RayCluster
+     try:
+         custom_api.get_namespaced_custom_object(
+             group="ray.io",
+             version="v1",
+             namespace=namespace,
+             plural="rayclusters",
+             name=name,
+         )
+         return "raycluster"
+     except ApiException:
+         pass
+
+     return None
+
+
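+ # A minimal usage sketch (hypothetical service name): the probe order is
+ # Deployment, then Knative Service, then RayCluster, with None if nothing matches:
+ #
+ #     mode = detect_deployment_mode(
+ #         "my-svc", "default", client.CustomObjectsApi(), client.AppsV1Api()
+ #     )  # -> "deployment" | "knative" | "raycluster" | None
+
+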
+ def validate_provided_pod(service_name, provided_pod, service_pods):
+     if provided_pod is None:
+         return provided_pod
+
+     if provided_pod.isnumeric():
+         pod = int(provided_pod)
+         if pod < 0 or pod >= len(service_pods):
+             console.print(f"[red]Pod index {pod} is out of range[/red]")
+             raise typer.Exit(1)
+         pod_name = service_pods[pod].metadata.name
+
+     else:
+         # case when the user provides pod name
+         pod_names = [pod.metadata.name for pod in service_pods]
+         if provided_pod not in pod_names:
+             console.print(
+                 f"[red]{service_name} does not have an associated pod called {provided_pod}[/red]"
+             )
+             raise typer.Exit(1)
+         else:
+             pod_name = provided_pod
+
+     return pod_name
+
+
+ def load_kubetorch_volumes_for_service(namespace, service_name, core_v1) -> List[str]:
+     """Extract volume information from service definition"""
+     volumes = []
+
+     try:
+         pods = core_v1.list_namespaced_pod(
+             namespace=namespace,
+             label_selector=f"kubetorch.com/service={service_name}",
+         )
+         if pods.items:
+             pod = pods.items[0]
+             for v in pod.spec.volumes or []:
+                 if v.persistent_volume_claim:
+                     volumes.append(v.name)
+         return volumes
+
+     except Exception as e:
+         logger.warning(f"Failed to extract volumes from service: {e}")
+
+     return volumes
+
+
+ def create_table_for_output(
+     columns: List[tuple], no_wrap_columns_names: list = None, header_style: dict = None
+ ):
+     no_wrap_columns_names = no_wrap_columns_names or []
+     table = Table(box=box.SQUARE, header_style=Style(**(header_style or {})))
+     for name, style in columns:
+         if name in no_wrap_columns_names:
+             # always make service name fully visible
+             table.add_column(name, style=style, no_wrap=True)
+         else:
+             table.add_column(name, style=style)
+
+     return table
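+
+
+ # A minimal usage sketch (hypothetical column styles) matching the (name, style)
+ # tuples the loop above unpacks:
+ #
+ #     table = create_table_for_output(
+ #         columns=[("SERVICE NAME", "cyan"), ("STATUS", "green")],
+ #         no_wrap_columns_names=["SERVICE NAME"],
+ #         header_style={"bold": True},
+ #     )
+ #     console.print(table)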