kubetorch 0.2.0__py3-none-any.whl

Files changed (93)
  1. kubetorch/__init__.py +60 -0
  2. kubetorch/cli.py +1985 -0
  3. kubetorch/cli_utils.py +1025 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +285 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +157 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +133 -0
  30. kubetorch/resources/callables/module.py +1416 -0
  31. kubetorch/resources/callables/utils.py +174 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +261 -0
  34. kubetorch/resources/compute/compute.py +2596 -0
  35. kubetorch/resources/compute/decorators.py +139 -0
  36. kubetorch/resources/compute/rbac.py +74 -0
  37. kubetorch/resources/compute/utils.py +1114 -0
  38. kubetorch/resources/compute/websocket.py +137 -0
  39. kubetorch/resources/images/__init__.py +1 -0
  40. kubetorch/resources/images/image.py +414 -0
  41. kubetorch/resources/images/images.py +74 -0
  42. kubetorch/resources/secrets/__init__.py +2 -0
  43. kubetorch/resources/secrets/kubernetes_secrets_client.py +412 -0
  44. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  45. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  46. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  47. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  48. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  49. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  50. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  51. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  52. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  53. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  54. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  55. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  56. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  57. kubetorch/resources/secrets/provider_secrets/providers.py +93 -0
  58. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  59. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  60. kubetorch/resources/secrets/secret.py +238 -0
  61. kubetorch/resources/secrets/secret_factory.py +70 -0
  62. kubetorch/resources/secrets/utils.py +209 -0
  63. kubetorch/resources/volumes/__init__.py +0 -0
  64. kubetorch/resources/volumes/volume.py +365 -0
  65. kubetorch/servers/__init__.py +0 -0
  66. kubetorch/servers/http/__init__.py +0 -0
  67. kubetorch/servers/http/distributed_utils.py +3223 -0
  68. kubetorch/servers/http/http_client.py +730 -0
  69. kubetorch/servers/http/http_server.py +1788 -0
  70. kubetorch/servers/http/server_metrics.py +278 -0
  71. kubetorch/servers/http/utils.py +728 -0
  72. kubetorch/serving/__init__.py +0 -0
  73. kubetorch/serving/autoscaling.py +173 -0
  74. kubetorch/serving/base_service_manager.py +363 -0
  75. kubetorch/serving/constants.py +83 -0
  76. kubetorch/serving/deployment_service_manager.py +478 -0
  77. kubetorch/serving/knative_service_manager.py +519 -0
  78. kubetorch/serving/raycluster_service_manager.py +582 -0
  79. kubetorch/serving/service_manager.py +18 -0
  80. kubetorch/serving/templates/deployment_template.yaml +17 -0
  81. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  82. kubetorch/serving/templates/kt_setup_template.sh.j2 +81 -0
  83. kubetorch/serving/templates/pod_template.yaml +194 -0
  84. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  85. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  86. kubetorch/serving/templates/service_template.yaml +21 -0
  87. kubetorch/serving/templates/workerset_template.yaml +36 -0
  88. kubetorch/serving/utils.py +377 -0
  89. kubetorch/utils.py +284 -0
  90. kubetorch-0.2.0.dist-info/METADATA +121 -0
  91. kubetorch-0.2.0.dist-info/RECORD +93 -0
  92. kubetorch-0.2.0.dist-info/WHEEL +4 -0
  93. kubetorch-0.2.0.dist-info/entry_points.txt +5 -0
kubetorch/servers/http/http_server.py
@@ -0,0 +1,1788 @@
1
+ import base64
2
+ import importlib
3
+ import importlib.util
4
+ import inspect
5
+ import json
6
+ import logging.config
7
+ import os
8
+ import pickle
9
+ import random
10
+ import subprocess
11
+ import sys
12
+ import threading
13
+ import time
14
+ import traceback
15
+ from contextlib import asynccontextmanager
16
+ from datetime import datetime, timezone
17
+ from pathlib import Path
18
+ from typing import Awaitable, Callable, Dict, Optional, Union
19
+
20
+ try:
21
+ import httpx
22
+ except ImportError:  # httpx is optional; only needed for the /http reverse proxy
23
+ pass
24
+
25
+ from fastapi import Body, FastAPI, Header, HTTPException, Request
26
+
27
+ from fastapi.exceptions import RequestValidationError
28
+ from fastapi.responses import JSONResponse
29
+
30
+ from pydantic import BaseModel
31
+ from starlette.middleware.base import BaseHTTPMiddleware
32
+
33
+ try:
34
+ from server_metrics import (
35
+ get_inactivity_ttl_annotation,
36
+ HeartbeatManager,
37
+ setup_otel_metrics,
38
+ )
39
+ from utils import (
40
+ clear_debugging_sessions,
41
+ deep_breakpoint,
42
+ DEFAULT_ALLOWED_SERIALIZATION,
43
+ ensure_structured_logging,
44
+ is_running_in_kubernetes,
45
+ LOG_CONFIG,
46
+ request_id_ctx_var,
47
+ RSYNC_PORT,
48
+ wait_for_app_start,
49
+ )
50
+ except ImportError:
51
+ from .server_metrics import (
52
+ get_inactivity_ttl_annotation,
53
+ HeartbeatManager,
54
+ setup_otel_metrics,
55
+ )
56
+ from .utils import (
57
+ clear_debugging_sessions,
58
+ deep_breakpoint,
59
+ DEFAULT_ALLOWED_SERIALIZATION,
60
+ ensure_structured_logging,
61
+ is_running_in_kubernetes,
62
+ LOG_CONFIG,
63
+ request_id_ctx_var,
64
+ RSYNC_PORT,
65
+ wait_for_app_start,
66
+ )
67
+
68
+ from starlette.background import BackgroundTask
69
+ from starlette.exceptions import HTTPException as StarletteHTTPException
70
+ from starlette.responses import StreamingResponse
71
+
72
+ logging.config.dictConfig(LOG_CONFIG)
73
+
74
+ # Set up our structured JSON logging
75
+ ensure_structured_logging()
76
+
77
+ # Create the print logger AFTER ensure_structured_logging so it inherits handlers
78
+ print_logger = logging.getLogger("print_redirect")
79
+
80
+ logger = logging.getLogger(__name__)
81
+ # Set the log level from the KT_LOG_LEVEL environment variable if provided.
82
+ # Don't apply a default, so the root logging config is preserved otherwise.
83
+ kt_log_level = os.getenv("KT_LOG_LEVEL")
84
+ if kt_log_level:
85
+ kt_log_level = kt_log_level.upper()
86
+ logger.setLevel(getattr(logging, kt_log_level, logging.INFO))
87
+
88
+ _CACHED_CALLABLES = {}
89
+ _LAST_DEPLOYED = 0
90
+ _CACHED_IMAGE = []
91
+ DISTRIBUTED_SUPERVISOR = None
92
+ APP_PROCESS = None
93
+ _CALLABLE_LOAD_LOCK = threading.Lock() # Lock for thread-safe callable loading
94
+ LOKI_HOST = os.environ.get("LOKI_HOST", "loki-gateway.kubetorch.svc.cluster.local")
95
+ LOKI_PORT = int(os.environ.get("LOKI_PORT", 80)) # Default Loki port
96
+ KT_OTEL_ENABLED = os.environ.get("KT_OTEL_ENABLED", "False").lower() == "true"
97
+ KT_TRACING_ENABLED = (
98
+ os.environ.get("KT_TRACING_ENABLED", "").lower() != "false"
99
+ ) # Defaults to True
100
+
101
+ # Global termination event that can be checked by running requests
102
+ TERMINATION_EVENT = threading.Event()
103
+ # Create a client for FastAPI service
104
+
105
+ # Set the python breakpoint to kt.deep_breakpoint
106
+ os.environ["PYTHONBREAKPOINT"] = "kubetorch.deep_breakpoint"
107
+
108
+ request_id_ctx_var.set(os.getenv("KT_LAUNCH_ID", "-"))
109
+
110
+ #####################################
111
+ ######### Instrument Traces #########
112
+ #####################################
113
+ instrument_traces = KT_TRACING_ENABLED
114
+ if instrument_traces:
115
+ try:
116
+ from opentelemetry import trace
117
+ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
118
+ OTLPSpanExporter,
119
+ )
120
+ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
121
+ from opentelemetry.instrumentation.logging import LoggingInstrumentor
122
+ from opentelemetry.instrumentation.requests import RequestsInstrumentor
123
+ from opentelemetry.sdk.resources import Resource
124
+ from opentelemetry.sdk.trace import TracerProvider
125
+ from opentelemetry.sdk.trace.export import BatchSpanProcessor
126
+ except ImportError:
127
+ instrument_traces = False
128
+
129
+ if instrument_traces:
130
+ logger.info("Configuring OTLP exporter to instrument traces")
131
+ trace.set_tracer_provider(
132
+ TracerProvider(
133
+ resource=Resource.create(
134
+ {
135
+ "service.name": os.environ.get("OTEL_SERVICE_NAME"),
136
+ "service.instance.id": os.environ.get("POD_NAME"),
137
+ }
138
+ )
139
+ )
140
+ )
141
+ span_processor = BatchSpanProcessor(
142
+ OTLPSpanExporter(
143
+ endpoint=os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"),
144
+ insecure=True,
145
+ )
146
+ )
147
+ trace.get_tracer_provider().add_span_processor(span_processor)
148
+ RequestsInstrumentor().instrument()
149
+ LoggingInstrumentor().instrument()
150
+
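With the provider configured this way, any code in the process can emit spans through the standard OpenTelemetry API. A minimal, self-contained sketch (the span and attribute names are illustrative):

from opentelemetry import trace

tracer = trace.get_tracer("example")
with tracer.start_as_current_span("example-span") as span:
    # Attributes become searchable fields on the exported span
    span.set_attribute("request_id", "req-123")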
151
+ #####################################
152
+ ########### Proxy Helpers ###########
153
+ #####################################
154
+ if os.getenv("KT_CALLABLE_TYPE") == "app" and os.getenv("KT_APP_PORT"):
155
+ port = os.getenv("KT_APP_PORT")
156
+ logger.info(f"Creating /http reverse proxy to: http://localhost:{port}/")
157
+ proxy_client = httpx.AsyncClient(base_url=f"http://localhost:{port}/", timeout=None)
158
+ else:
159
+ proxy_client = None
160
+
161
+
162
+ async def _http_reverse_proxy(request: Request):
163
+ """Reverse proxy for /http/* routes to FastAPI service on its port"""
164
+ # Extract the endpoint name from the path
165
+ # request.path_params["path"] will contain everything after /http/
166
+ endpoint_path = request.path_params["path"]
167
+
168
+ # Build the URL for the FastAPI service
169
+ url = httpx.URL(path=f"/{endpoint_path}", query=request.url.query.encode("utf-8"))
170
+
171
+ # Build the request to forward to FastAPI
172
+ rp_req = proxy_client.build_request(
173
+ request.method, url, headers=request.headers.raw, content=await request.body()
174
+ )
175
+
176
+ # Send the request and get streaming response
177
+ rp_resp = await proxy_client.send(rp_req, stream=True)
178
+
179
+ # Return streaming response
180
+ return StreamingResponse(
181
+ rp_resp.aiter_raw(),
182
+ status_code=rp_resp.status_code,
183
+ headers=rp_resp.headers,
184
+ background=BackgroundTask(rp_resp.aclose),
185
+ )
186
+
187
+
188
+ #####################################
189
+ ########### Cache Helpers ###########
190
+ #####################################
191
+ def clear_cache():
192
+ global _CACHED_CALLABLES
193
+
194
+ logger.debug("Clearing callables cache.")
195
+ _CACHED_CALLABLES.clear()
196
+
197
+
198
+ def cached_image_setup():
199
+ logger.debug("Starting cached image setup.")
200
+ global _CACHED_IMAGE
201
+ global APP_PROCESS
202
+
203
+ dockerfile_path = kt_directory() / "image.dockerfile"
204
+ with open(dockerfile_path, "r") as file:
205
+ lines = file.readlines()
206
+ lines = [line.strip() for line in lines]
207
+
208
+ # find first line where image differs from cache and update cache
209
+ cache_mismatch_index = -1
210
+ cmd_mismatch = False
211
+ for i, (new_line, cached_line) in enumerate(zip(lines, _CACHED_IMAGE)):
212
+ if new_line.startswith("CMD"):
213
+ cmd_mismatch = True
214
+
215
+ if new_line != cached_line or "# override" in new_line or cmd_mismatch:
216
+ cache_mismatch_index = i
217
+ break
218
+ if cache_mismatch_index == -1:
219
+ if len(lines) != len(_CACHED_IMAGE):
220
+ cache_mismatch_index = min(len(lines), len(_CACHED_IMAGE))
221
+ else:
222
+ cache_mismatch_index = len(lines)
223
+ _CACHED_IMAGE = lines
224
+
225
+ if cache_mismatch_index == len(lines):
226
+ return
227
+
228
+ if not (cache_mismatch_index == len(lines) - 1 and cmd_mismatch):
229
+ logger.info("Running image setup.")
230
+ else:
231
+ logger.debug("Skipping image setup steps, no changes detected.")
232
+
233
+ # Grab the current list of installed dependencies with pip freeze to check if anything changes (we need to send a
234
+ # SIGHUP to restart the server if so)
235
+ start_deps = None
236
+ import subprocess
237
+
238
+ try:
239
+ res = subprocess.run(
240
+ ["pip", "freeze"],
241
+ capture_output=True,
242
+ text=True,
243
+ check=True,
244
+ )
245
+ start_deps = res.stdout.splitlines()
246
+ except subprocess.CalledProcessError as e:
247
+ logger.error(f"Failed to run pip freeze: {e}")
248
+
249
+ # only run image setup steps starting from cache mismatch point
250
+ kt_pip_cmd = None
251
+ for line in lines[cache_mismatch_index:]:
252
+ command = ""
253
+ if line.strip().startswith("#"):
254
+ continue # ignore comments
255
+ if line.startswith("RUN") or line.startswith("CMD"):
256
+ command = line[len("RUN ") :]  # "RUN " and "CMD " have the same length
257
+
258
+ if command.startswith("$KT_PIP_INSTALL_CMD"):
259
+ kt_pip_cmd = kt_pip_cmd or _get_kt_pip_install_cmd()
260
+ command = command.replace("$KT_PIP_INSTALL_CMD", kt_pip_cmd)
261
+ elif line.startswith("COPY"):
262
+ _, source, dest = line.split()
263
+ # COPY instructions are essentially no-ops since rsync_file_updates()
264
+ # already placed files in their correct locations.
265
+ # But we verify the files exist and log the absolute paths for clarity.
266
+
267
+ # Determine the actual absolute destination path
268
+ if dest and dest.startswith("/"):
269
+ # Already absolute
270
+ dest_path = Path(dest)
271
+ elif dest and dest.startswith("~/"):
272
+ # Tilde prefix - strip it and treat as relative to cwd
273
+ dest_path = Path.cwd() / dest[2:]
274
+ else:
275
+ # Relative to working directory (including explicit basenames)
276
+ dest_path = Path.cwd() / dest
277
+
278
+ # Verify the destination exists (it should have been rsync'd)
279
+ if dest_path.exists():
280
+ logger.info(f"Copied {source} to {dest_path.absolute()}")
281
+ else:
282
+ raise FileNotFoundError(
283
+ f"COPY {source} {dest} failed: destination {dest_path.absolute()} does not exist. "
284
+ f"This likely means the rsync operation failed to sync the files correctly."
285
+ )
286
+ elif line.startswith("ENV"):
287
+ # Need to handle the case where the env var is being set to "" (empty string)
288
+ line_vals = line.split(" ", 2)
289
+ if len(line_vals) < 2: # ENV line must have at least key
290
+ raise ValueError("ENV line cannot be empty")
291
+ if len(line_vals) == 2: # ENV line with just key
292
+ key = line_vals[1]
293
+ val = ""
294
+ elif len(line_vals) == 3: # ENV line with key and value
295
+ key, val = line_vals[1], line_vals[2]
296
+
297
+ # Expand environment variables in the value
298
+ # This supports patterns like $VAR, ${VAR}, and $VAR:default_value
299
+ expanded_val = os.path.expandvars(val)
300
+
301
+ if key not in [
302
+ "KT_FILE_PATH",
303
+ "KT_MODULE_NAME",
304
+ "KT_CLS_OR_FN_NAME",
305
+ "KT_INIT_ARGS",
306
+ "KT_CALLABLE_TYPE",
307
+ "KT_DISTRIBUTED_CONFIG",
308
+ ]:
309
+ logger.info(f"Setting env var {key}")
310
+ os.environ[key] = expanded_val
311
+ # If the env var is specifically KT_LOG_LEVEL, we need to update the logger level
312
+ if key == "KT_LOG_LEVEL":
313
+ global kt_log_level
314
+ kt_log_level = expanded_val.upper()
315
+ logger.setLevel(kt_log_level)
316
+ logger.info(f"Updated log level to {kt_log_level}")
317
+ elif line.startswith("FROM"):
318
+ continue
319
+ elif line:
320
+ raise ValueError(f"Unrecognized image setup instruction {line}")
321
+
322
+ if command:
323
+ is_app_cmd = line.startswith("CMD")
324
+ if is_app_cmd:
325
+ logger.info(f"Running app command: {command}")
326
+ else:
327
+ logger.info(f"Running image setup step: {command}")
328
+
329
+ try:
330
+ # Use subprocess.Popen to capture output and redirect through StreamToLogger
331
+ env = os.environ.copy()
332
+ env["PYTHONUNBUFFERED"] = "1"
333
+
334
+ if is_app_cmd and os.getenv("KT_CALLABLE_TYPE") == "app":
335
+ if APP_PROCESS and APP_PROCESS.poll() is None:
336
+ APP_PROCESS.kill()
337
+
338
+ process = subprocess.Popen(
339
+ command,
340
+ shell=True,
341
+ stdout=subprocess.PIPE,
342
+ stderr=subprocess.PIPE,
343
+ universal_newlines=True,
344
+ bufsize=1,
345
+ env=env,
346
+ )
347
+
348
+ if is_app_cmd and os.getenv("KT_CALLABLE_TYPE") == "app":
349
+ APP_PROCESS = process
350
+
351
+ # Collect stderr for potential error logging
352
+ import threading
353
+
354
+ stderr_lines = []
355
+ stderr_lock = threading.Lock()
356
+
357
+ # Stream stdout and stderr in real-time
358
+ # We need to do all this so stdout and stderr lines are logged with the correct structured
359
+ # formatting for our queries. Without it they flow straight to system stdout and stderr unformatted.
360
+
361
+ def stream_output(pipe, log_func, request_id, collect_stderr=False):
362
+ request_id_ctx_var.set(request_id)
363
+ for line in iter(pipe.readline, ""):
364
+ if line:
365
+ stripped_line = line.rstrip()
366
+ log_func(stripped_line)
367
+
368
+ # Collect stderr lines for potential error logging
369
+ if collect_stderr:
370
+ with stderr_lock:
371
+ stderr_lines.append(stripped_line.removeprefix("ERROR: "))
372
+ pipe.close()
373
+
374
+ # Start streaming threads
375
+ current_request_id = request_id_ctx_var.get("-")
376
+
377
+ stderr_log_func = logger.error if is_app_cmd else logger.debug
378
+ stdout_thread = threading.Thread(
379
+ target=stream_output,
380
+ args=(process.stdout, logger.info, current_request_id),
381
+ )
382
+ stderr_thread = threading.Thread(
383
+ target=stream_output,
384
+ args=(
385
+ process.stderr,
386
+ stderr_log_func,
387
+ current_request_id,
388
+ not is_app_cmd,
389
+ ),
390
+ )
391
+
392
+ stdout_thread.daemon = True
393
+ stderr_thread.daemon = True
394
+ stdout_thread.start()
395
+ stderr_thread.start()
396
+
397
+ if is_app_cmd and os.getenv("KT_APP_PORT"):
398
+ # wait for internal app to be healthy/ready if run port is provided
399
+ try:
400
+ port = os.getenv("KT_APP_PORT")
401
+ logger.debug(
402
+ f"Waiting for internal app on port {port} to start:"
403
+ )
404
+ wait_for_app_start(
405
+ port=port,
406
+ health_check=os.getenv("KT_APP_HEALTHCHECK"),
407
+ process=process,
408
+ )
409
+ logger.info(f"App on port {port} is ready.")
410
+ except Exception as e:
411
+ logger.error(f"Caught exception waiting for app to start: {e}")
412
+ else:
413
+ # Check if this is a background command (ends with &)
414
+ is_background = command.rstrip().endswith("&")
415
+
416
+ if is_background:
417
+ # For background processes, give it a moment to start and check for immediate failures
418
+ import time
419
+
420
+ time.sleep(0.5) # Brief pause to catch immediate errors
421
+
422
+ # Check if process failed immediately
423
+ poll_result = process.poll()
424
+ if poll_result is not None and poll_result != 0:
425
+ # Process exited with error
426
+ stdout_thread.join(timeout=1)
427
+ stderr_thread.join(timeout=1)
428
+ return_code = poll_result
429
+ else:
430
+ # Process is running in background successfully
431
+ logger.info(
432
+ f"Background process started successfully (PID: {process.pid})"
433
+ )
434
+ return_code = 0 # Indicate success for background start
435
+ else:
436
+ # Wait for process to complete
437
+ return_code = process.wait()
438
+
439
+ # Wait for streaming threads to finish
440
+ stdout_thread.join()
441
+ stderr_thread.join()
442
+
443
+ if return_code != 0 and not is_app_cmd:
444
+ with stderr_lock:
445
+ if stderr_lines:
446
+ logger.error(
447
+ f"Failed to run command '{command}' with stderr:"
448
+ )
449
+ for stderr_line in stderr_lines:
450
+ logger.error(stderr_line)
451
+ except subprocess.CalledProcessError as e:
452
+ logger.error(f"Failed to run command '{command}' with error: {e}")
453
+ with stderr_lock:
454
+ if stderr_lines:
455
+ logger.error("Stderr:")
456
+ for stderr_line in stderr_lines:
457
+ logger.error(stderr_line)
458
+ # Check if any dependencies changed and if so reload them inside the server process
459
+ if start_deps:
460
+ try:
461
+ # Run pip freeze and capture the output
462
+ res = subprocess.run(
463
+ ["pip", "freeze"],
464
+ capture_output=True,
465
+ text=True,
466
+ check=True,
467
+ )
468
+ end_deps = res.stdout.splitlines()
469
+ # We only need to look at the deps which were already installed (i.e. lines in start_deps),
470
+ # new ones can't be "stale" inside the current server process
471
+ # We also only use lines with exact pypi versions (has "=="), no editable
472
+ changed_deps = [
473
+ line.split("==")[0]
474
+ for line in start_deps
475
+ if "==" in line and line not in end_deps
476
+ ]
477
+ imported_changed_deps = [
478
+ dep for dep in changed_deps if dep in sys.modules
479
+ ] # Only reload deps which are already imported
480
+ if imported_changed_deps:
481
+ logger.debug(
482
+ f"New dependencies found: {imported_changed_deps}, forcing reload"
483
+ )
484
+
485
+ # Don't clear the callable cache here - let load_callable_from_env handle it to preserve __kt_cached_state__
486
+ if DISTRIBUTED_SUPERVISOR:
487
+ DISTRIBUTED_SUPERVISOR.cleanup()
488
+
489
+ # Remove changed modules from sys.modules to force fresh imports on next use
490
+ modules_to_remove = []
491
+ for module_name in sys.modules:
492
+ for dep in imported_changed_deps:
493
+ if module_name == dep or module_name.startswith(dep + "."):
494
+ modules_to_remove.append(module_name)
495
+ break
496
+
497
+ for module_name in modules_to_remove:
498
+ try:
499
+ del sys.modules[module_name]
500
+ logger.debug(f"Removed module {module_name} from sys.modules")
501
+ except KeyError:
502
+ pass
503
+ except subprocess.CalledProcessError as e:
504
+ logger.error(f"Failed to run pip freeze: {e}")
505
+
506
+
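The cache comparison above reduces to finding the first index where the new dockerfile instruction list diverges from the cached one. A standalone sketch of that core idea (simplified: it ignores the CMD and "# override" special cases):

def first_mismatch_index(new_lines, cached_lines):
    # The first differing position is where setup must re-run from
    for i, (new, cached) in enumerate(zip(new_lines, cached_lines)):
        if new != cached:
            return i
    if len(new_lines) != len(cached_lines):
        # One list is a prefix of the other: re-run from the shorter length
        return min(len(new_lines), len(cached_lines))
    return len(new_lines)  # identical: nothing to re-run

assert first_mismatch_index(["FROM base", "RUN a"], ["FROM base", "RUN b"]) == 1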
507
+ def run_image_setup(deployed_time: Optional[float] = None):
508
+ if os.environ["KT_FREEZE"] == "True" or not is_running_in_kubernetes():
509
+ return
510
+
511
+ rsync_file_updates()
512
+
513
+ dockerfile_path = kt_directory() / "image.dockerfile"
514
+ if not dockerfile_path.exists():
515
+ raise FileNotFoundError(
516
+ f"No image and metadata configuration found in path: {str(dockerfile_path)}"
517
+ )
518
+ while (
519
+ # May need to give the dockerfile time to rsync over, so wait until the dockerfile timestamp is later than
520
+ # when we started the deployment (recorded in .to and passed here as deployed_time). We also should only
521
+ # wait if _LAST_DEPLOYED is not zero, as the first time the server is deployed the image is written before
522
+ # the server starts so we don't need to wait.
523
+ _LAST_DEPLOYED
524
+ and dockerfile_path.stat().st_mtime < deployed_time
525
+ and datetime.now(timezone.utc).timestamp() - deployed_time < 5
526
+ ):
527
+ time.sleep(0.1)
528
+
529
+ cached_image_setup()
530
+
531
+ if os.getenv("KT_CALLABLE_TYPE") != "app":
532
+ logger.debug("Completed cached image setup.")
533
+
534
+
535
+ #####################################
536
+ ######## Generic Helpers ############
537
+ #####################################
538
+ class SerializationError(Exception):
539
+ pass
540
+
541
+
542
+ def kt_directory():
543
+ if "KT_DIRECTORY" in os.environ:
544
+ return Path(os.environ["KT_DIRECTORY"]).expanduser()
545
+ else:
546
+ return Path.cwd() / ".kt"
547
+
548
+
549
+ def _get_kt_pip_install_cmd() -> Optional[str]:
550
+ """Get the actual KT_PIP_INSTALL_CMD value for command expansion."""
551
+ kt_pip_cmd = os.getenv("KT_PIP_INSTALL_CMD")
552
+ if not kt_pip_cmd: # Fallback to reading from file
553
+ try:
554
+ with open(kt_directory() / "kt_pip_install_cmd", "r") as f:
555
+ return f.read().strip()
556
+ except FileNotFoundError:
557
+ return None
558
+ return kt_pip_cmd
559
+
560
+
561
+ def is_running_in_container():
562
+ # Check for .dockerenv file which exists in Docker containers
563
+ return Path("/.dockerenv").exists()
564
+
565
+
566
+ async def run_in_executor_with_context(executor, func, *args):
567
+ """
568
+ Helper to run a function in an executor while preserving the request_id context.
569
+
570
+ This wrapper captures the current request_id from the context before running
571
+ the function in a thread pool executor, then sets it in the new thread.
572
+ """
573
+ import asyncio
574
+
575
+ # Capture the current request_id before switching threads
576
+ current_request_id = request_id_ctx_var.get("-")
577
+
578
+ def wrapper(*args):
579
+ # Set the request_id in the executor thread
580
+ token = None
581
+ if current_request_id != "-":
582
+ token = request_id_ctx_var.set(current_request_id)
583
+ try:
584
+ return func(*args)
585
+ finally:
586
+ # Clean up the context to avoid leaking between requests
587
+ if token is not None:
588
+ request_id_ctx_var.reset(token)
589
+
590
+ return await asyncio.get_event_loop().run_in_executor(executor, wrapper, *args)
591
+
592
+
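The same pattern in isolation, with a plain ContextVar (the names here are illustrative, not part of kubetorch's API): capture the value on the event-loop thread, set it inside the worker thread, and reset it when done.

import asyncio
import contextvars
from concurrent.futures import ThreadPoolExecutor

ctx_var = contextvars.ContextVar("request_id", default="-")

async def main():
    ctx_var.set("req-123")
    current = ctx_var.get()  # captured before switching threads

    def work():
        token = ctx_var.set(current)  # re-set inside the executor thread
        try:
            return ctx_var.get()
        finally:
            ctx_var.reset(token)  # avoid leaking between requests

    with ThreadPoolExecutor() as pool:
        print(await asyncio.get_running_loop().run_in_executor(pool, work))

asyncio.run(main())  # prints: req-123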
593
+ def should_reload(deployed_as_of: Optional[str] = None) -> bool:
594
+ """
595
+ Determine if the server should reload based on the deployment timestamp.
596
+ If deployed_as_of is provided, it checks against the last deployed time.
597
+ If not provided, it defaults to False.
598
+ """
599
+ if deployed_as_of in [None, "null", "None"]:
600
+ return False
601
+
602
+ try:
603
+ deployed_time = datetime.fromisoformat(deployed_as_of).timestamp()
604
+ return deployed_time > _LAST_DEPLOYED
605
+ except ValueError as e:
606
+ logger.error(f"Invalid deployed_as_of format: {deployed_as_of}. Error: {e}")
607
+ return True
608
+
609
+
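For example, comparing an ISO-8601 deploy timestamp against a stored epoch value works like this:

from datetime import datetime, timezone

last_deployed = datetime(2024, 1, 1, tzinfo=timezone.utc).timestamp()
deployed_as_of = "2024-06-01T00:00:00+00:00"
print(datetime.fromisoformat(deployed_as_of).timestamp() > last_deployed)  # True -> reload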
610
+ def load_callable(
611
+ deployed_as_of: Optional[str] = None,
612
+ distributed_subprocess: bool = False,
613
+ reload_cleanup_fn: Optional[Callable] = None,
614
+ ):
615
+ global _LAST_DEPLOYED
616
+
617
+ callable_name = os.environ["KT_CLS_OR_FN_NAME"]
618
+
619
+ callable_obj = _CACHED_CALLABLES.get(callable_name, None)
620
+ if callable_obj and not should_reload(deployed_as_of):
621
+ # If the callable is cached and doesn't need reload, return it immediately
622
+ logger.debug("Returning cached callable.")
623
+ return callable_obj
624
+
625
+ # Slow path: need to load or reload - use lock for thread safety
626
+ with _CALLABLE_LOAD_LOCK:
627
+ # Double-check within lock (another thread might have loaded it)
628
+ callable_obj = _CACHED_CALLABLES.get(callable_name, None)
629
+ if callable_obj and not should_reload(deployed_as_of):
630
+ logger.debug("Returning cached callable (found after acquiring lock).")
631
+ return callable_obj
632
+ # Proceed with loading/reloading
633
+ return _load_callable_internal(
634
+ deployed_as_of, distributed_subprocess, reload_cleanup_fn, callable_obj
635
+ )
636
+
637
+
638
+ def _load_callable_internal(
639
+ deployed_as_of: Optional[str] = None,
640
+ distributed_subprocess: bool = False,
641
+ reload_cleanup_fn: Optional[Callable] = None,
642
+ callable_obj=None,
643
+ ):
644
+ """Internal callable loading logic - should be called within lock for thread safety."""
645
+ global _LAST_DEPLOYED
646
+
647
+ callable_name = os.environ["KT_CLS_OR_FN_NAME"]
648
+
649
+ if not callable_obj:
650
+ logger.debug("Callable not found in cache, loading from environment.")
651
+ else:
652
+ logger.debug(
653
+ f"Callable found in cache, but reloading because deployed_as_of {deployed_as_of} is newer than last deployed time {_LAST_DEPLOYED}"
654
+ )
655
+
656
+ # If not in cache or we have a more recent deployment timestamp, update metadata and reload
657
+ if reload_cleanup_fn and _LAST_DEPLOYED:
658
+ # If a reload cleanup function is provided and we've already deployed at least once, call it before
659
+ # reloading the callable
660
+ reload_cleanup_fn()
661
+
662
+ deployed_time = (
663
+ datetime.fromisoformat(deployed_as_of).timestamp()
664
+ if deployed_as_of
665
+ else datetime.now(timezone.utc).timestamp()
666
+ )
667
+ if not distributed_subprocess:
668
+ # We don't reload the image in distributed subprocess/es, as we already did it in the
669
+ # main process and we don't want to do it multiple times (in each subprocess).
670
+ if _LAST_DEPLOYED:
671
+ logger.info("Patching image and code updates and reloading callable.")
672
+ else:
673
+ logger.info("Setting up image and loading callable.")
674
+ run_image_setup(deployed_time)
675
+
676
+ distributed_config = os.environ["KT_DISTRIBUTED_CONFIG"]
677
+ if distributed_config not in ["null", "None"] and not distributed_subprocess:
678
+ logger.debug(f"Loading distributed supervisor: {distributed_config}")
679
+ callable_obj = load_distributed_supervisor(deployed_as_of=deployed_as_of)
680
+ logger.debug("Distributed supervisor loaded successfully.")
681
+ else:
682
+ logger.debug(f"Loading callable from environment: {callable_name}")
683
+ callable_obj = load_callable_from_env()
684
+ logger.debug("Callable loaded successfully.")
685
+
686
+ _LAST_DEPLOYED = deployed_time
687
+ _CACHED_CALLABLES[callable_name] = callable_obj
688
+
689
+ return callable_obj
690
+
691
+
692
+ def load_distributed_supervisor(deployed_as_of: Optional[str] = None):
693
+ global DISTRIBUTED_SUPERVISOR
694
+
695
+ if os.environ["KT_FILE_PATH"] not in sys.path:
696
+ sys.path.insert(0, os.environ["KT_FILE_PATH"])
697
+
698
+ distributed_config = os.environ["KT_DISTRIBUTED_CONFIG"]
699
+
700
+ # If this is the main process of a distributed call, we don't load the callable directly,
701
+ # we create a new supervisor if it doesn't exist or if the config has changed.
702
+ # We don't create a supervisor if this is a distributed subprocess.
703
+ config_hash = hash(str(distributed_config))
704
+ if (
705
+ DISTRIBUTED_SUPERVISOR is None
706
+ or config_hash != DISTRIBUTED_SUPERVISOR.config_hash
707
+ ):
708
+ from .distributed_utils import distributed_supervisor_factory
709
+
710
+ logger.info(f"Loading distributed supervisor with config: {distributed_config}")
711
+ distributed_config = json.loads(distributed_config)
712
+ # If we already have some distributed processes, we need to clean them up before creating a new supervisor.
713
+ if DISTRIBUTED_SUPERVISOR:
714
+ DISTRIBUTED_SUPERVISOR.cleanup()
715
+ DISTRIBUTED_SUPERVISOR = distributed_supervisor_factory(**distributed_config)
716
+ DISTRIBUTED_SUPERVISOR.config_hash = config_hash
717
+ try:
718
+ # If there are any errors during setup, we catch and log them, and then undo the setup
719
+ # so that the distributed supervisor is not left in a broken state (and otherwise can still fail
720
+ # when we call DISTRIBUTED_SUPERVISOR.cleanup() in lifespan).
721
+ DISTRIBUTED_SUPERVISOR.setup(deployed_as_of=deployed_as_of)
722
+ except Exception as e:
723
+ logger.error(
724
+ f"Failed to set up distributed supervisor with config {distributed_config}: {e}"
725
+ )
726
+ DISTRIBUTED_SUPERVISOR = None
727
+ raise e
728
+ return DISTRIBUTED_SUPERVISOR
729
+
730
+
731
+ def patch_sys_path():
732
+ abs_path = str(Path(os.environ["KT_FILE_PATH"]).expanduser().resolve())
733
+ if os.environ["KT_FILE_PATH"] not in sys.path:
734
+ sys.path.insert(0, abs_path)
735
+ logger.debug(f"Added {abs_path} to sys.path")
736
+
737
+ # Maybe needed for subprocesses (e.g. distributed) to find the callable's module
738
+ # Needed for distributed subprocesses to find the file path
739
+ existing_path = os.environ.get("PYTHONPATH", "")
740
+ if os.environ["KT_FILE_PATH"] not in existing_path:
741
+ os.environ["PYTHONPATH"] = (
742
+ f"{abs_path}{os.pathsep}{existing_path}" if existing_path else abs_path
743
+ )
744
+ logger.debug(f"Set PYTHONPATH to {os.environ['PYTHONPATH']}")
745
+
746
+
747
+ def load_callable_from_env():
748
+ """Load and cache callable objects from env, preserving state if __kt_cached_state__ is available."""
749
+ cls_or_fn_name = os.environ["KT_CLS_OR_FN_NAME"]
750
+ module_name = os.environ["KT_MODULE_NAME"]
751
+
752
+ # Check if we have an existing cached callable and extract state if available
753
+ cached_state = None
754
+ existing_callable = _CACHED_CALLABLES.get(cls_or_fn_name, None)
755
+
756
+ if existing_callable and hasattr(existing_callable, "__kt_cached_state__"):
757
+ try:
758
+ logger.info(
759
+ f"Extracting cached state from {cls_or_fn_name} via __kt_cached_state__"
760
+ )
761
+ cached_state = existing_callable.__kt_cached_state__()
762
+ if cached_state is not None and not isinstance(cached_state, dict):
763
+ logger.warning(
764
+ f"__kt_cached_state__ returned non-dict type: {type(cached_state)}. Ignoring cached state."
765
+ )
766
+ cached_state = None
767
+ except Exception as e:
768
+ # This could happen if modules were removed from sys.modules during image setup
769
+ # and the callable's __kt_cached_state__ method depends on them
770
+ logger.warning(
771
+ f"Failed to extract cached state from {cls_or_fn_name} (possibly due to module reloading): {e}. "
772
+ f"Proceeding without cached state."
773
+ )
774
+ cached_state = None
775
+
776
+ # Now that we have the state, clean up the old callable to free memory
777
+ if existing_callable:
778
+ logger.debug(f"Deleting existing callable: {cls_or_fn_name}")
779
+ _CACHED_CALLABLES.pop(cls_or_fn_name, None)
780
+ del existing_callable
781
+ # Garbage collect to ensure everything cleaned up (especially GPU memory)
782
+ import gc
783
+
784
+ gc.collect()
785
+
786
+ patch_sys_path()
787
+
788
+ # If we're inside a distributed subprocess or the main process of a non-distributed call,
789
+ # we load and instantiate the callable.
790
+ try:
791
+ # Try regular package import first
792
+ if module_name in sys.modules:
793
+ # We log this at info level because some imports are slow, and we want the user to see that
794
+ # the server is reimporting rather than hanging
795
+ logger.info(f"Reimporting module {module_name}")
796
+ # Clear any existing debugging sessions when reloading modules
797
+ clear_debugging_sessions()
798
+ module = importlib.reload(sys.modules[module_name])
799
+ else:
800
+ logger.debug(f"Importing module {module_name}")
801
+ module = importlib.import_module(module_name)
802
+ logger.debug(f"Module {module_name} loaded")
803
+
804
+ # Ensure our structured logging is in place after user module import
805
+ # (in case the user's module configured its own logging)
806
+ ensure_structured_logging()
807
+
808
+ callable_obj = getattr(module, cls_or_fn_name)
809
+ logger.debug(f"Callable {cls_or_fn_name} loaded")
810
+ except (ImportError, ValueError) as original_error:
811
+ # Fall back to file-based import if package import fails
812
+ try:
813
+ module = import_from_file(os.environ["KT_FILE_PATH"], module_name)
814
+ # Ensure structured logging after file-based import
815
+ ensure_structured_logging()
816
+ callable_obj = getattr(module, cls_or_fn_name)
817
+ except (ImportError, ValueError):
818
+ # Raise the original error if file import also fails, because the errors which are raised here are
819
+ # more opaque and less useful than the original ImportError or ValueError.
820
+ raise original_error
821
+ except AttributeError as e:
822
+ # If the callable is not found in the module, raise an error
823
+ raise HTTPException(
824
+ status_code=404,
825
+ detail=f"Callable '{cls_or_fn_name}' not found in module '{module_name}'",
826
+ ) from e
827
+
828
+ # Unwrap to remove any kt deploy decorators (e.g. @kt.compute)
829
+ if hasattr(callable_obj, "__wrapped__"):
830
+ callable_obj = callable_obj.__wrapped__
831
+
832
+ if isinstance(callable_obj, type):
833
+ # Prepare init arguments
834
+ init_kwargs = {}
835
+
836
+ # Add user-provided init_args
837
+ if os.environ["KT_INIT_ARGS"] not in ["null", "None"]:
838
+ init_kwargs = json.loads(os.environ["KT_INIT_ARGS"])
839
+ logger.info(f"Setting init_args {init_kwargs}")
840
+
841
+ # Add cached state if available
842
+ # Allow user to manually set "kt_cached_state" to override/disable cache
843
+ if cached_state is not None and "kt_cached_state" not in init_kwargs:
844
+ # Check if the class's __init__ accepts kt_cached_state parameter
845
+ sig = inspect.signature(callable_obj.__init__)
846
+ if "kt_cached_state" in sig.parameters:
847
+ logger.info(f"Passing cached state to {cls_or_fn_name}.__init__")
848
+ init_kwargs["kt_cached_state"] = cached_state
849
+ else:
850
+ raise ValueError(
851
+ f"Class {cls_or_fn_name} has __kt_cached_state__ method but __init__ does not accept "
852
+ f"'kt_cached_state' parameter. Please add 'kt_cached_state=None' to __init__ signature."
853
+ )
854
+
855
+ # Instantiate with combined arguments
856
+ if init_kwargs:
857
+ callable_obj = callable_obj(**init_kwargs)
858
+ else:
859
+ callable_obj = callable_obj()
860
+
861
+ return callable_obj
862
+
863
+
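A user class opting into the state-preservation contract enforced above might look like the following sketch (the class and its fields are hypothetical; only the __kt_cached_state__ method and the kt_cached_state init parameter come from the contract in this file):

class Embedder:
    def __init__(self, model_name: str = "small", kt_cached_state: dict = None):
        self.model_name = model_name
        # Reuse expensive state (e.g., loaded weights) across hot reloads
        self.weights = (kt_cached_state or {}).get("weights") or self._load_weights()

    def _load_weights(self):
        return {"w": [0.0] * 10}  # stand-in for a slow load

    def __kt_cached_state__(self) -> dict:
        # Must return a dict; non-dict returns are ignored by the server
        return {"weights": self.weights}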
864
+ def import_from_file(file_path: str, module_name: str):
865
+ """Import a module from file path."""
866
+ module_parts = module_name.split(".")
867
+ depth = max(0, len(module_parts) - 1)
868
+
869
+ # Convert file_path to absolute path if it's not already (note, .resolve will append the current working directory
870
+ # if file_path is relative)
871
+ abs_path = Path(file_path).expanduser().resolve()
872
+ # Ensure depth doesn't exceed available parent directories
873
+ max_available_depth = len(abs_path.parents) - 1
874
+
875
+ if max_available_depth < 0:
876
+ # File has no parent directories, use the file's directory itself
877
+ parent_path = str(abs_path.parent)
878
+ else:
879
+ # Clamp depth to available range to avoid IndexError
880
+ depth = min(depth, max_available_depth)
881
+ parent_path = str(abs_path.parents[depth])
882
+
883
+ if parent_path not in sys.path:
884
+ sys.path.insert(0, parent_path)
885
+
886
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
887
+ if spec is None or spec.loader is None:
888
+ raise ImportError(
889
+ f"Could not load spec for module {module_name} from {file_path}"
890
+ )
891
+
892
+ module = importlib.util.module_from_spec(spec)
893
+ spec.loader.exec_module(module)
894
+ return module
895
+
896
+
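The core of this helper is importlib's file-based loading. A self-contained demo of just that mechanism:

import importlib.util
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as d:
    src = Path(d) / "mymod.py"
    src.write_text("def hello():\n    return 'hi'\n")
    spec = importlib.util.spec_from_file_location("mymod", src)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    print(module.hello())  # hi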
897
+ #####################################
898
+ ########## Rsync Helpers ############
899
+ #####################################
900
+ def generate_rsync_command(subdir: str = ".", exclude_absolute: bool = True):
901
+ """Generate rsync command for syncing from jump pod.
902
+
903
+ Args:
904
+ subdir: Directory to sync to (default current directory)
905
+ exclude_absolute: Whether to exclude __absolute__ directory (default True)
906
+ """
907
+ service_name = os.getenv("KT_SERVICE_NAME")
908
+ namespace = os.getenv("POD_NAMESPACE")
909
+
910
+ exclude_opt = "--exclude='__absolute__*' " if exclude_absolute else ""
911
+ logger.debug("Syncing code from rsync pod to local directory")
912
+ return f"rsync -av {exclude_opt}rsync://kubetorch-rsync.{namespace}.svc.cluster.local:{RSYNC_PORT}/data/{namespace}/{service_name}/ {subdir}"
913
+
914
+
915
+ def rsync_file_updates():
916
+ """Rsync files from the jump pod to the worker pod.
917
+
918
+ Performs two rsync operations in parallel:
919
+ 1. Regular files (excluding __absolute__*) to the working directory
920
+ 2. Absolute path files (under __absolute__/) to their absolute destinations
921
+ """
922
+ import concurrent.futures
923
+ from concurrent.futures import ThreadPoolExecutor
924
+
925
+ service_name = os.getenv("KT_SERVICE_NAME")
926
+ namespace = os.getenv("POD_NAMESPACE")
927
+
928
+ # Build base rsync URL
929
+ rsync_base = f"rsync://kubetorch-rsync.{namespace}.svc.cluster.local:{RSYNC_PORT}/data/{namespace}/{service_name}/"
930
+
931
+ max_retries = 5
932
+ base_delay = 1 # seconds
933
+ max_delay = 30 # seconds
934
+
935
+ def run_rsync_with_retries(rsync_cmd, description):
936
+ """Helper to run rsync with exponential backoff retries."""
937
+ for attempt in range(max_retries):
938
+ resp = subprocess.run(
939
+ rsync_cmd,
940
+ shell=True,
941
+ capture_output=True,
942
+ text=True,
943
+ )
944
+
945
+ if resp.returncode == 0:
946
+ logger.debug(f"Successfully rsync'd {description}")
947
+ return # Success!
948
+
949
+ # Check if it's a retryable error
950
+ retryable_errors = [
951
+ "max connections",
952
+ "Temporary failure in name resolution",
953
+ "Name or service not known",
954
+ "Connection refused",
955
+ "No route to host",
956
+ ]
957
+
958
+ is_retryable = any(error in resp.stderr for error in retryable_errors)
959
+
960
+ if is_retryable and attempt < max_retries - 1:
961
+ # Calculate exponential backoff with jitter
962
+ delay = min(
963
+ base_delay * (2**attempt) + random.uniform(0, 1), max_delay
964
+ )
965
+ logger.warning(
966
+ f"Rsync {description} failed with retryable error: {resp.stderr.strip()}. "
967
+ f"Retrying in {delay:.1f} seconds (attempt {attempt + 1}/{max_retries})"
968
+ )
969
+ time.sleep(delay)
970
+ else:
971
+ # For non-retryable errors or final attempt, raise immediately
972
+ if attempt == max_retries - 1:
973
+ logger.error(
974
+ f"Rsync {description} failed after {max_retries} attempts. Last error: {resp.stderr}"
975
+ )
976
+ raise RuntimeError(
977
+ f"Rsync {description} failed with error: {resp.stderr}"
978
+ )
979
+
980
+ # If we exhausted all retries
981
+ raise RuntimeError(
982
+ f"Rsync {description} failed after {max_retries} attempts. Last error: {resp.stderr}"
983
+ )
984
+
985
+ def rsync_regular_files():
986
+ """Rsync regular files (excluding __absolute__*) to working directory."""
987
+ rsync_cmd_regular = f"rsync -avL --exclude='__absolute__*' {rsync_base} ."
988
+ logger.debug(f"Rsyncing regular files with command: {rsync_cmd_regular}")
989
+ run_rsync_with_retries(rsync_cmd_regular, "regular files")
990
+
991
+ def rsync_absolute_files():
992
+ """Rsync absolute path files to their absolute destinations."""
993
+ # First, do a dry-run to see if __absolute__ directory exists
994
+ check_cmd = f"rsync --list-only {rsync_base}__absolute__/"
995
+ check_resp = subprocess.run(
996
+ check_cmd, shell=True, capture_output=True, text=True
997
+ )
998
+
999
+ if check_resp.returncode == 0 and check_resp.stdout.strip():
1000
+ # __absolute__ directory exists, sync its contents to root
1001
+ # The trick is to sync from __absolute__/ to / which places files in their absolute paths
1002
+ rsync_cmd_absolute = f"rsync -avL {rsync_base}__absolute__/ /"
1003
+ logger.debug(
1004
+ f"Rsyncing absolute path files with command: {rsync_cmd_absolute}"
1005
+ )
1006
+ run_rsync_with_retries(rsync_cmd_absolute, "absolute path files")
1007
+ else:
1008
+ logger.debug("No absolute path files to sync")
1009
+
1010
+ # Run both rsync operations in parallel
1011
+ with ThreadPoolExecutor(max_workers=2) as executor:
1012
+ # Submit both tasks
1013
+ regular_future = executor.submit(rsync_regular_files)
1014
+ absolute_future = executor.submit(rsync_absolute_files)
1015
+
1016
+ # Wait for both to complete and handle any exceptions
1017
+ futures = [regular_future, absolute_future]
1018
+ for future in concurrent.futures.as_completed(futures):
1019
+ try:
1020
+ future.result() # This will raise any exception that occurred
1021
+ except Exception as e:
1022
+ # Cancel remaining futures if one fails
1023
+ for f in futures:
1024
+ f.cancel()
1025
+ raise e
1026
+
1027
+ logger.debug("Completed rsync of all files")
1028
+
1029
+
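The retry policy used above is standard exponential backoff with jitter. Stripped of the rsync specifics, it looks like this generic sketch:

import random
import time

def backoff_delays(max_retries=5, base=1.0, cap=30.0):
    # Delay before retry i+1: min(base * 2**i + jitter, cap)
    return [min(base * (2 ** i) + random.uniform(0, 1), cap) for i in range(max_retries - 1)]

def with_retries(fn, max_retries=5):
    for delay in [*backoff_delays(max_retries), None]:
        try:
            return fn()
        except ConnectionError:
            if delay is None:
                raise  # exhausted all retries
            time.sleep(delay)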
1030
+ #####################################
1031
+ ########### App setup ###############
1032
+ #####################################
1033
+ class HealthCheckFilter(logging.Filter):
1034
+ def filter(self, record):
1035
+ return not (
1036
+ isinstance(record.args, tuple)
1037
+ and len(record.args) >= 3
1038
+ and ("/health" in record.args[2] or record.args[2] == "/")
1039
+ )
1040
+
1041
+
1042
+ class RequestContextFilter(logging.Filter):
1043
+ def filter(self, record):
1044
+ record.request_id = request_id_ctx_var.get("-")
1045
+ record.pod = os.getenv("POD_NAME", "unknown-pod")
1046
+
1047
+ if instrument_traces:
1048
+ from opentelemetry.trace import format_span_id, format_trace_id, get_current_span
1049
+
1050
+ # Add trace_id and span_id for log correlation
1051
+ current_span = get_current_span()
1052
+ if current_span and current_span.get_span_context().is_valid:
1053
+ record.trace_id = format_trace_id(
1054
+ current_span.get_span_context().trace_id
1055
+ )
1056
+ record.span_id = format_span_id(
1057
+ current_span.get_span_context().span_id
1058
+ )
1059
+ else:
1060
+ record.trace_id = "-"
1061
+ record.span_id = "-"
1062
+
1063
+ return True
1064
+
1065
+
1066
+ class TerminationCheckMiddleware(BaseHTTPMiddleware):
1067
+ """Monitor for termination while request is running and return error if detected."""
1068
+
1069
+ async def dispatch(self, request: Request, call_next):
1070
+ # Skip health checks and metrics endpoints
1071
+ if request.url.path in ["/health", "/", "/metrics"]:
1072
+ return await call_next(request)
1073
+
1074
+ # Run the actual request in the background
1075
+ import asyncio
1076
+
1077
+ request_task = asyncio.create_task(call_next(request))
1078
+
1079
+ # Monitor for termination while request is running
1080
+ while not request_task.done():
1081
+ # Check if we're terminating
1082
+ if TERMINATION_EVENT.is_set() or (
1083
+ hasattr(request.app.state, "terminating")
1084
+ and request.app.state.terminating
1085
+ ):
1086
+ # Cancel the request task
1087
+ request_task.cancel()
1088
+
1089
+ # Return PodTerminatedError
1090
+ from kubetorch import PodTerminatedError
1091
+ from kubetorch.servers.http.http_server import package_exception
1092
+
1093
+ pod_name = os.environ.get("POD_NAME", "unknown")
1094
+ exc = PodTerminatedError(
1095
+ pod_name=pod_name,
1096
+ reason="SIGTERM",
1097
+ status_code=503,
1098
+ events=[
1099
+ {
1100
+ "timestamp": datetime.now(timezone.utc).isoformat(),
1101
+ "reason": "Terminating",
1102
+ "message": "Pod received SIGTERM signal and is shutting down gracefully",
1103
+ }
1104
+ ],
1105
+ )
1106
+
1107
+ return package_exception(exc)
1108
+
1109
+ # Wait a bit before checking again or for request to complete
1110
+ try:
1111
+ result = await asyncio.wait_for(
1112
+ asyncio.shield(request_task), timeout=0.5
1113
+ )
1114
+ return result
1115
+ except asyncio.TimeoutError:
1116
+ # Request still running after 0.5s, continue loop to check termination again
1117
+ continue
1118
+
1119
+ # Request completed normally
1120
+ return await request_task
1121
+
1122
+
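The dispatch loop above depends on asyncio.shield: wait_for's 0.5s timeout cancels only the shield wrapper, never the in-flight request, so the request survives each polling round. The bare pattern, with a generic stop event:

import asyncio

async def run_with_termination_check(coro, stop_event: asyncio.Event):
    task = asyncio.create_task(coro)
    while not task.done():
        if stop_event.is_set():
            task.cancel()  # only a detected termination cancels the task
            raise RuntimeError("terminated during request")
        try:
            # shield() absorbs the timeout cancellation; the task keeps running
            return await asyncio.wait_for(asyncio.shield(task), timeout=0.5)
        except asyncio.TimeoutError:
            continue
    return await task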
1123
+ class RequestIDMiddleware(BaseHTTPMiddleware):
1124
+ async def dispatch(self, request: Request, call_next):
1125
+ request_id = request.headers.get("X-Request-ID", "-")
1126
+ token = request_id_ctx_var.set(request_id)
1127
+
1128
+ if instrument_traces and request_id != "-":
1129
+ span_attributes = {
1130
+ "request_id": request_id,
1131
+ "http.method": request.method,
1132
+ "http.url": str(request.url),
1133
+ "service.name": os.environ.get("OTEL_SERVICE_NAME"),
1134
+ "service.instance.id": os.environ.get("POD_NAME"),
1135
+ }
1136
+ # If the pod crashes (e.g., due to OOM) during execution of run_callable, we'll still have at least
1137
+ # this heartbeat span recorded
1138
+ tracer = trace.get_tracer("heartbeat")
1139
+ try:
1140
+ with tracer.start_as_current_span(
1141
+ "heartbeat.request", attributes=span_attributes
1142
+ ):
1143
+ tracer_provider = trace.get_tracer_provider()
1144
+ if isinstance(tracer_provider, TracerProvider):
1145
+ tracer_provider.force_flush()
1146
+ except Exception as e:
1147
+ logger.warning(f"Heartbeat span flush failed: {e}")
1148
+
1149
+ try:
1150
+ response = await call_next(request)
1151
+ return response
1152
+ finally:
1153
+ # Reset the context variable to its default value
1154
+ request_id_ctx_var.reset(token)
1155
+
1156
+
1157
+ class TraceFlushMiddleware(BaseHTTPMiddleware):
1158
+ """Flush traces after each HTTP Request so we don't lose trace data if the pod is killed"""
1159
+
1160
+ async def dispatch(self, request: Request, call_next):
1161
+ response = await call_next(request)
1162
+ tracer_provider = trace.get_tracer_provider()
1163
+ if isinstance(tracer_provider, TracerProvider):
1164
+ tracer_provider.force_flush()
1165
+ return response
1166
+
1167
+
1168
+ class StreamToLogger:
1169
+ def __init__(self, logger, log_level=logging.INFO, original_stream=None):
1170
+ self.logger = logger
1171
+ self.log_level = log_level
1172
+ self.original_stream = original_stream
1173
+ self.linebuf = ""
1174
+
1175
+ def _is_from_logging(self):
1176
+ """Check if the current write call is coming from the logging system"""
1177
+ frame = sys._getframe()
1178
+ while frame:
1179
+ if frame.f_globals.get("__name__", "").startswith("logging"):
1180
+ return True
1181
+ frame = frame.f_back
1182
+ return False
1183
+
1184
+ def write(self, buf):
1185
+ # Check if this is from logging system
1186
+ is_from_logging = self._is_from_logging()
1187
+
1188
+ # Always write to original stream first
1189
+ if self.original_stream:
1190
+ self.original_stream.write(buf)
1191
+ self.original_stream.flush()
1192
+
1193
+ # Skip logging if this is from the logging system to prevent infinite loops
1194
+ if self.logger.name == "print_redirect" and is_from_logging:
1195
+ return
1196
+
1197
+ # Buffer and log complete lines
1198
+ temp_linebuf = self.linebuf + buf
1199
+ self.linebuf = ""
1200
+
1201
+ # Split on newlines but keep carriage returns
1202
+ lines = []
1203
+ current_line = ""
1204
+ for char in temp_linebuf:
1205
+ if char == "\n":
1206
+ lines.append(current_line)
1207
+ current_line = ""
1208
+ else:
1209
+ current_line += char
1210
+
1211
+ # Add any remaining content to linebuf
1212
+ if current_line:
1213
+ self.linebuf = current_line
1214
+
1215
+ # Log complete lines
1216
+ for line in lines:
1217
+ if line:
1218
+ self.logger.log(self.log_level, line)
1219
+
1220
+ def flush(self):
1221
+ if self.original_stream:
1222
+ self.original_stream.flush()
1223
+ if self.linebuf != "":
1224
+ self.logger.log(self.log_level, self.linebuf)
1225
+ self.linebuf = ""
1226
+
1227
+ def isatty(self):
1228
+ # Delegate to the original stream if it exists, else return False
1229
+ if self.original_stream and hasattr(self.original_stream, "isatty"):
1230
+ return self.original_stream.isatty()
1231
+ return False
1232
+
1233
+ def fileno(self):
1234
+ if self.original_stream and hasattr(self.original_stream, "fileno"):
1235
+ return self.original_stream.fileno()
1236
+ raise OSError("Stream does not support fileno()")
1237
+
1238
+ @property
1239
+ def encoding(self):
1240
+ # Return the encoding of the original stream if available, else UTF-8
1241
+ if self.original_stream and hasattr(self.original_stream, "encoding"):
1242
+ return self.original_stream.encoding
1243
+ return "utf-8"
1244
+
1245
+
1246
+ # Save original streams before redirection
1247
+ _original_stdout = sys.stdout
1248
+ _original_stderr = sys.stderr
1249
+
1250
+ # Redirect stdout and stderr to our logger while preserving original streams
1251
+ sys.stdout = StreamToLogger(print_logger, logging.INFO, _original_stdout)
1252
+ sys.stderr = StreamToLogger(print_logger, logging.ERROR, _original_stderr)
1253
+
1254
+
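With this redirection in place, plain print() calls become structured log records while still reaching the original stream. A standalone miniature of the same trick (simplified: no logging-reentrancy guard):

import logging
import sys

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger("print_redirect_demo")

class LineLogger:
    def __init__(self, logger, level):
        self.logger, self.level, self.buf = logger, level, ""

    def write(self, s):
        self.buf += s
        while "\n" in self.buf:  # log only complete lines
            line, self.buf = self.buf.split("\n", 1)
            if line:
                self.logger.log(self.level, line)

    def flush(self):
        if self.buf:
            self.logger.log(self.level, self.buf)
            self.buf = ""

sys.stdout = LineLogger(log, logging.INFO)  # the handler writes to stderr, so no loop
print("hello")  # emitted as: INFO hello
sys.stdout = sys.__stdout__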
1255
+ @asynccontextmanager
1256
+ async def lifespan(app: FastAPI):
1257
+ """Manage application lifecycle"""
1258
+ import signal
1259
+ import threading
1260
+
1261
+ # Only register signal handlers if we're in the main thread
1262
+ # This allows tests to run without signal handling
1263
+ if threading.current_thread() is threading.main_thread():
1264
+ # Save any existing SIGTERM handler
1265
+ original_sigterm_handler = signal.getsignal(signal.SIGTERM)
1266
+
1267
+ def handle_sigterm(signum, frame):
1268
+ """Handle SIGTERM for graceful shutdown."""
1269
+ logger.info("Received SIGTERM, initiating graceful shutdown...")
1270
+
1271
+ # Mark that we're terminating and interrupt existing requests IMMEDIATELY
1272
+ app.state.terminating = True
1273
+ TERMINATION_EVENT.set()
1274
+
1275
+ # Clean up distributed supervisor to ensure child processes are terminated
1276
+ # This is important because SIGTERM is not propagated to child processes automatically
1277
+ # This runs synchronously and may take 1-2 seconds, but existing requests are already interrupted
1278
+ global DISTRIBUTED_SUPERVISOR
1279
+ if DISTRIBUTED_SUPERVISOR:
1280
+ logger.info("Cleaning up distributed supervisor and child processes...")
1281
+ try:
1282
+ DISTRIBUTED_SUPERVISOR.cleanup()
1283
+ except Exception as e:
1284
+ logger.error(f"Error cleaning up distributed supervisor: {e}")
1285
+
1286
+ # Call the original handler if it exists and isn't the default
1287
+ if original_sigterm_handler and original_sigterm_handler not in (
1288
+ signal.SIG_DFL,
1289
+ signal.SIG_IGN,
1290
+ ):
1291
+ original_sigterm_handler(signum, frame)
1292
+
1293
+ # Register SIGTERM handler
1294
+ signal.signal(signal.SIGTERM, handle_sigterm)
1295
+ app.state.terminating = False
1296
+
1297
+ # Startup
1298
+ ttl = get_inactivity_ttl_annotation()
1299
+ if ttl and KT_OTEL_ENABLED is True:
1300
+ app.state.heartbeat_manager = HeartbeatManager(ttl_seconds=ttl)
1301
+ if app.state.heartbeat_manager:
1302
+ await app.state.heartbeat_manager.start()
1303
+ logger.debug(f"Heartbeat manager started with TTL={ttl}s")
1304
+ elif ttl:
1305
+ logger.warning(
1306
+ "TTL annotation found, but OTEL is not enabled, heartbeat disabled"
1307
+ )
1308
+ else:
1309
+ logger.debug("No TTL annotation found, heartbeat disabled")
1310
+
1311
+ try:
1312
+ cached_image_setup()
1313
+ if os.getenv("KT_CALLABLE_TYPE") != "app":
1314
+ load_callable()
1315
+
1316
+ logger.info("Kubetorch Server started.")
1317
+ request_id_ctx_var.set("-") # Reset request_id after launch sequence
1318
+ yield
1319
+
1320
+ except Exception:
1321
+ # We don't want to raise errors like ImportError during startup, as it will cause the server to crash and the
1322
+ # user won't be able to see the error in the logs to debug (e.g. quickly add dependencies or reorganize
1323
+ # imports). Instead, we log it (and a stack trace) and continue, so it will be surfaced to the user when they
1324
+ # call the service.
1325
+
1326
+ # However if this service is frozen, it should just fail because the user isn't debugging the service and there is no
1327
+ # way for the dependencies to be added at runtime.
1328
+ logger.error(traceback.format_exc())
1329
+ request_id_ctx_var.set("-")
1330
+ yield
1331
+
1332
+ finally:
1333
+ # Flush OpenTelemetry traces before shutdown
1334
+ if instrument_traces:
1335
+ from opentelemetry.sdk.trace import TracerProvider
1336
+
1337
+ tracer_provider = trace.get_tracer_provider()
1338
+ if isinstance(tracer_provider, TracerProvider):
1339
+ logger.info("Forcing OpenTelemetry span flush before shutdown")
1340
+ tracer_provider.force_flush()
1341
+
1342
+ # Shutdown
1343
+ manager = getattr(app.state, "heartbeat_manager", None)
1344
+ if manager:
1345
+ await manager.stop()
1346
+ logger.info("Heartbeat manager stopped")
1347
+
1348
+ # Clean up during normal shutdown so we don't leave any hanging processes, which can cause pods to hang
1349
+ # indefinitely. Skip if already cleaned up by SIGTERM handler.
1350
+ if DISTRIBUTED_SUPERVISOR and not getattr(app.state, "terminating", False):
1351
+ DISTRIBUTED_SUPERVISOR.cleanup()
1352
+
1353
+ # Clear any remaining debugging sessions
1354
+ clear_debugging_sessions()
1355
+
1356
+
1357
+ # Add the filter to uvicorn's access logger
1358
+ logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter())
1359
+ root_logger = logging.getLogger()
1360
+ root_logger.addFilter(RequestContextFilter())
1361
+ for handler in root_logger.handlers:
1362
+ handler.addFilter(RequestContextFilter())
1363
+ print_logger.addFilter(RequestContextFilter())
1364
+
+app = FastAPI(lifespan=lifespan)
+app.add_middleware(TerminationCheckMiddleware)  # Check termination first
+app.add_middleware(RequestIDMiddleware)
+
+# Configure the FastAPI app for metrics first;
+# setup_otel_metrics returns None for meter_provider if OTEL is not enabled
+app, meter_provider = (
+    setup_otel_metrics(app) if KT_OTEL_ENABLED is True else (app, None)
+)
+
+# Now instrument for traces and metrics together
+if instrument_traces:
+    logger.info("Instrumenting FastAPI app for traces and metrics")
+    FastAPIInstrumentor.instrument_app(
+        app,
+        meter_provider=meter_provider,
+        excluded_urls="/metrics,/health",
+    )
+    logger.info("Adding TraceFlushMiddleware to flush traces")
+    app.add_middleware(TraceFlushMiddleware)
+elif meter_provider is not None:
+    try:
+        # Skipped when instrument_traces is False; needs reimplementation if we want metrics-only support
+        from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
+
+        logger.info("Instrumenting FastAPI app for metrics only")
+        FastAPIInstrumentor.instrument_app(
+            app,
+            meter_provider=meter_provider,
+            excluded_urls="/,/metrics,/health",
+        )
+    except ImportError:
+        logger.info(
+            "OpenTelemetry instrumentation not installed, skipping metrics instrumentation"
+        )
+
+# Add a catch-all route for proxying to the user's FastAPI app
+if os.getenv("KT_CALLABLE_TYPE") == "app" and os.getenv("KT_APP_PORT"):
+    logger.debug("Adding route for path /http")
+    app.add_route(
+        "/http/{path:path}",
+        _http_reverse_proxy,
+        ["GET", "POST", "PUT", "DELETE", "PATCH"],
+    )
+
+
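`_http_reverse_proxy` lives elsewhere in this module. A hedged sketch of what such a handler might do, assuming it forwards the wrapped request to the user app listening on `KT_APP_PORT` and mirrors the upstream response back (all names here are illustrative, not the package's implementation):

```python
import os
import httpx
from starlette.requests import Request
from starlette.responses import Response

async def http_reverse_proxy_sketch(request: Request) -> Response:
    port = os.environ["KT_APP_PORT"]
    path = request.path_params["path"]  # the {path:path} segment from the route
    async with httpx.AsyncClient() as client:
        upstream = await client.request(
            request.method,
            f"http://127.0.0.1:{port}/{path}",
            headers=dict(request.headers),
            content=await request.body(),
        )
    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        headers=dict(upstream.headers),
    )
```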
+#####################################
+########## Error Handling ###########
+#####################################
+class ErrorResponse(BaseModel):
+    error_type: str
+    message: str
+    traceback: str
+    pod_name: str
+    state: Optional[dict] = None  # Optional serialized exception state
+
+
1422
+ # Factor out the exception packaging so we can use it in the handler below and also inside distributed subprocesses
1423
+ def package_exception(exc: Exception):
1424
+ import asyncio
1425
+ import concurrent
1426
+
1427
+ error_type = exc.__class__.__name__
1428
+ trace = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
1429
+
1430
+ # Check if the exception has a status_code attribute (e.g. PodTerminatedError)
1431
+ if hasattr(exc, "status_code"):
1432
+ status_code = exc.status_code
1433
+ elif isinstance(exc, (RequestValidationError, TypeError, AssertionError)):
1434
+ status_code = 422
1435
+ elif isinstance(exc, (ValueError, UnicodeError, json.JSONDecodeError)):
1436
+ status_code = 400
1437
+ elif isinstance(exc, (KeyError, FileNotFoundError)):
1438
+ status_code = 404
1439
+ elif isinstance(exc, PermissionError):
1440
+ status_code = 403
1441
+ elif isinstance(exc, (StarletteHTTPException, HTTPException)):
1442
+ status_code = exc.status_code
1443
+ elif isinstance(exc, (MemoryError, OSError)):
1444
+ status_code = 500
1445
+ elif isinstance(exc, NotImplementedError):
1446
+ status_code = 501
1447
+ elif isinstance(exc, asyncio.TimeoutError):
1448
+ status_code = 504
1449
+ elif isinstance(exc, concurrent.futures.TimeoutError):
1450
+ status_code = 504
1451
+ else:
1452
+ status_code = 500
1453
+
1454
+ # Try to serialize exception state if it has __getstate__
1455
+ state = None
1456
+ if hasattr(exc, "__getstate__"):
1457
+ try:
1458
+ state = exc.__getstate__()
1459
+ except Exception as e:
1460
+ logger.debug(f"Could not serialize exception state for {error_type}: {e}")
1461
+
1462
+ error_response = ErrorResponse(
1463
+ error_type=error_type,
1464
+ message=str(exc),
1465
+ traceback=trace,
1466
+ pod_name=os.getenv("POD_NAME"),
1467
+ state=state,
1468
+ )
1469
+
1470
+ return JSONResponse(status_code=status_code, content=error_response.model_dump())
1471
+
1472
+
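A quick illustration of the mapping: a `KeyError` raised by user code packages into a 404 whose JSON body follows the `ErrorResponse` schema (the `"unknown"` pod name assumes `POD_NAME` is unset, as it would be locally):

```python
try:
    {}["missing"]  # user code raising KeyError
except Exception as exc:
    response = package_exception(exc)
    assert response.status_code == 404
    # body: {"error_type": "KeyError", "message": "'missing'",
    #        "traceback": "...", "pod_name": "unknown", "state": ...}
```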
+@app.exception_handler(Exception)
+async def generic_exception_handler(request: Request, exc: Exception):
+    return package_exception(exc)
+
+
+@app.post("/_reload_image", response_class=JSONResponse)
+def _reload_image(
+    request: Request,
+    deployed_as_of: Optional[str] = Header(None, alias="X-Deployed-As-Of"),
+):
+    """
+    Endpoint to reload the image and metadata configuration.
+    This is used to reload the image in cases where we're not calling the callable directly,
+    e.g. kt.app and Ray workers.
+    """
+    global _LAST_DEPLOYED
+    deployed_time = (
+        datetime.fromisoformat(deployed_as_of).timestamp()
+        if deployed_as_of
+        else datetime.now(timezone.utc).timestamp()
+    )
+    run_image_setup(deployed_time)
+    _LAST_DEPLOYED = deployed_time
+    return JSONResponse(
+        status_code=200,
+        content={"message": "Image and metadata reloaded successfully."},
+    )
+
+
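Client-side, triggering a reload is a bare POST; the deploy timestamp travels in the `X-Deployed-As-Of` header and must be ISO 8601, since the endpoint parses it with `datetime.fromisoformat`. A sketch with an illustrative base URL (32300 is the local default port from the `__main__` block below):

```python
from datetime import datetime, timezone
import requests

resp = requests.post(
    "http://localhost:32300/_reload_image",
    headers={"X-Deployed-As-Of": datetime.now(timezone.utc).isoformat()},
)
resp.raise_for_status()  # 200 => "Image and metadata reloaded successfully."
```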
+@app.post("/{cls_or_fn_name}", response_class=JSONResponse)
+@app.post("/{cls_or_fn_name}/{method_name}", response_class=JSONResponse)
+async def run_callable(
+    request: Request,
+    cls_or_fn_name: str,
+    method_name: Optional[str] = None,
+    distributed_subcall: bool = False,
+    debug_port: Optional[int] = None,
+    params: Optional[Union[Dict, str]] = Body(default=None),
+    deployed_as_of: Optional[str] = Header(None, alias="X-Deployed-As-Of"),
+    serialization: str = Header("json", alias="X-Serialization"),
+):
+    if cls_or_fn_name != os.environ["KT_CLS_OR_FN_NAME"]:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Callable '{cls_or_fn_name}' not found in metadata configuration. Found '{os.environ['KT_CLS_OR_FN_NAME']}' instead",
+        )
+
+    # NOTE: The distributed replica processes (e.g. PyTorchProcess:run) rely on this running here, even though
+    # they will reconstruct the callable themselves, because they skip image reloading as a performance optimization.
+    # Run load_callable in an executor since it may do file I/O and other blocking operations.
+    callable_obj = await run_in_executor_with_context(
+        None, load_callable, deployed_as_of
+    )
+
+    # If this is a distributed call (and not a subcall from a different distributed replica),
+    # and the type of distribution requires a special call method (e.g. SIMD), use the
+    # distributed supervisor to handle the call
+    if DISTRIBUTED_SUPERVISOR and DISTRIBUTED_SUPERVISOR.intercept_call():
+        # Run the blocking distributed call in an executor to avoid blocking the event loop
+        result = await run_in_executor_with_context(
+            None,
+            DISTRIBUTED_SUPERVISOR.call_distributed,
+            request,
+            cls_or_fn_name,
+            method_name,
+            params,
+            distributed_subcall,
+            debug_port,
+            deployed_as_of,
+        )
+        clear_debugging_sessions()
+        return result
+
+    # If this is not a distributed call, or the distribution type does not require special handling,
+    # run the callable directly
+    result = await run_callable_internal(
+        callable_obj=callable_obj,
+        cls_or_fn_name=cls_or_fn_name,
+        method_name=method_name,
+        params=params,
+        serialization=serialization,
+        debug_port=debug_port,
+    )
+    return result
+
+
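Client-side, a call to this route is a plain POST with the positional and keyword arguments wrapped in a JSON body as `{"args": [...], "kwargs": {...}}` (that shape is what the params handling below extracts). A sketch with an illustrative host, function name, and arguments:

```python
import requests

resp = requests.post(
    "http://localhost:32300/my_fn",  # or "/MyClass/my_method" for a class method
    json={"args": [1, 2], "kwargs": {"scale": 3}},
    headers={"X-Serialization": "json"},
)
print(resp.json())  # the callable's JSON-serializable return value
```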
+async def run_callable_internal(
+    callable_obj: Callable,
+    cls_or_fn_name: str,
+    method_name: Optional[str] = None,
+    params: Optional[Union[Dict, str]] = None,
+    serialization: str = "json",
+    debug_port: Optional[int] = None,
+):
+    # Check if the serialization format is allowed
+    allowed_serialization = os.getenv(
+        "KT_ALLOWED_SERIALIZATION", DEFAULT_ALLOWED_SERIALIZATION
+    ).split(",")
+    if serialization not in allowed_serialization:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Serialization format '{serialization}' not allowed. Allowed formats: {allowed_serialization}",
+        )
+
+    # Process the call
+    args = []
+    kwargs = {}
+    if params:
+        if serialization == "pickle":
+            # Handle pickle serialization - extract data from the dictionary wrapper
+            if isinstance(params, dict) and "data" in params:
+                encoded_data = params.pop("data")
+                pickled_data = base64.b64decode(encoded_data.encode("utf-8"))
+                param_args = pickle.loads(pickled_data)
+                # data is unpickled in the format {"args": args, "kwargs": kwargs}
+                params.update(param_args)
+            elif isinstance(params, str):
+                # Fallback for a direct string payload
+                pickled_data = base64.b64decode(params.encode("utf-8"))
+                params = pickle.loads(pickled_data)
+
+        # Default JSON handling
+        args = params.get("args", [])
+        kwargs = params.get("kwargs", {})
+
+    if method_name:
+        if not hasattr(callable_obj, method_name):
+            raise HTTPException(
+                status_code=404,
+                detail=f"Method '{method_name}' not found in class '{cls_or_fn_name}'",
+            )
+        user_method = getattr(callable_obj, method_name)
+    else:
+        user_method = callable_obj
+
+    import inspect
+
+    # Check if the user method is async
+    is_async_method = inspect.iscoroutinefunction(user_method)
+
+    if debug_port:
+        logger.info(
+            f"Debugging remote callable {cls_or_fn_name}.{method_name} on port {debug_port}"
+        )
+        deep_breakpoint(debug_port)
+        # If using the debugger, step in here ("s") to enter your function/class method.
+        if is_async_method:
+            result = await user_method(*args, **kwargs)
+        else:
+            # Run the sync method in a thread pool to avoid blocking the event loop;
+            # use a lambda to properly pass both args and kwargs
+            result = await run_in_executor_with_context(
+                None, lambda: user_method(*args, **kwargs)
+            )
+    else:
+        logger.debug(f"Calling remote callable {cls_or_fn_name}.{method_name}")
+        if is_async_method:
+            result = await user_method(*args, **kwargs)
+        else:
+            # Run the sync method in a thread pool to avoid blocking the event loop;
+            # use a lambda to properly pass both args and kwargs
+            result = await run_in_executor_with_context(
+                None, lambda: user_method(*args, **kwargs)
+            )
+
+    # Handle the case where a sync method returns an awaitable (e.g. from an async framework);
+    # this is less common but can happen with some async libraries
+    if isinstance(result, Awaitable):
+        result = await result
+
+    # Serialize the response based on the requested format
+    if serialization == "pickle":
+        try:
+            pickled_result = pickle.dumps(result)
+            encoded_result = base64.b64encode(pickled_result).decode("utf-8")
+            result = {"data": encoded_result}
+        except Exception as e:
+            logger.error(f"Failed to pickle result: {str(e)}")
+            raise SerializationError(
+                f"Result could not be serialized with pickle: {str(e)}"
+            )
+    else:
+        # Default JSON serialization
+        try:
+            json.dumps(result)
+        except (TypeError, ValueError) as e:
+            logger.error(f"Result is not JSON serializable: {str(e)}")
+            raise SerializationError(
+                f"Result could not be serialized to JSON: {str(e)}"
+            )
+
+    clear_debugging_sessions()
+
+    return result
+
+
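The pickle wire format handled above is symmetric: the client pickles `{"args": ..., "kwargs": ...}`, base64-encodes it under a `"data"` key, and decodes the `{"data": ...}` envelope in the response the same way. A client-side sketch (illustrative host and function name; the server must also list `pickle` in `KT_ALLOWED_SERIALIZATION`):

```python
import base64
import pickle
import requests

# Encode the call arguments the way the server's pickle branch expects
payload = base64.b64encode(
    pickle.dumps({"args": [1, 2], "kwargs": {}})
).decode("utf-8")

resp = requests.post(
    "http://localhost:32300/my_fn",
    json={"data": payload},
    headers={"X-Serialization": "pickle"},
)
# The response is the same envelope in reverse
result = pickle.loads(base64.b64decode(resp.json()["data"]))
```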
+def run_callable_internal_sync(
+    callable_obj: Callable,
+    cls_or_fn_name: str,
+    method_name: Optional[str] = None,
+    params: Optional[Union[Dict, str]] = None,
+    serialization: str = "json",
+    debug_port: Optional[int] = None,
+):
+    """Synchronous wrapper for run_callable_internal, used by distributed subprocesses."""
+    import asyncio
+    import inspect
+
+    # Check if the serialization format is allowed
+    allowed_serialization = os.getenv(
+        "KT_ALLOWED_SERIALIZATION", DEFAULT_ALLOWED_SERIALIZATION
+    ).split(",")
+    if serialization not in allowed_serialization:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Serialization format '{serialization}' not allowed. Allowed formats: {allowed_serialization}",
+        )
+
+    # Process the call
+    args = []
+    kwargs = {}
+    if params:
+        if serialization == "pickle":
+            # Handle pickle serialization - extract data from the dictionary wrapper
+            if isinstance(params, dict) and "data" in params:
+                encoded_data = params.pop("data")
+                pickled_data = base64.b64decode(encoded_data.encode("utf-8"))
+                param_args = pickle.loads(pickled_data)
+                # data is unpickled in the format {"args": args, "kwargs": kwargs}
+                params.update(param_args)
+            elif isinstance(params, str):
+                # Fallback for a direct string payload
+                pickled_data = base64.b64decode(params.encode("utf-8"))
+                params = pickle.loads(pickled_data)
+
+        # Default JSON handling
+        args = params.get("args", [])
+        kwargs = params.get("kwargs", {})
+
+    if method_name:
+        if not hasattr(callable_obj, method_name):
+            raise HTTPException(
+                status_code=404,
+                detail=f"Method '{method_name}' not found in class '{cls_or_fn_name}'",
+            )
+        user_method = getattr(callable_obj, method_name)
+    else:
+        user_method = callable_obj
+
+    # Check if the user method is async
+    is_async_method = inspect.iscoroutinefunction(user_method)
+
+    if debug_port:
+        logger.info(
+            f"Debugging remote callable {cls_or_fn_name}.{method_name} on port {debug_port}"
+        )
+        deep_breakpoint(debug_port)
+        # If using the debugger, step in here ("s") to enter your function/class method.
+        if is_async_method:
+            # For async methods in a sync context, run them in a new event loop
+            result = asyncio.run(user_method(*args, **kwargs))
+        else:
+            result = user_method(*args, **kwargs)
+    else:
+        logger.debug(f"Calling remote callable {cls_or_fn_name}.{method_name}")
+        if is_async_method:
+            # For async methods in a sync context, run them in a new event loop
+            result = asyncio.run(user_method(*args, **kwargs))
+        else:
+            result = user_method(*args, **kwargs)
+
+    # Handle the case where a sync method returns an awaitable
+    if isinstance(result, Awaitable):
+        result = asyncio.run(result)
+
+    # Serialize the response based on the requested format
+    if serialization == "pickle":
+        try:
+            pickled_result = pickle.dumps(result)
+            encoded_result = base64.b64encode(pickled_result).decode("utf-8")
+            result = {"data": encoded_result}
+        except Exception as e:
+            logger.error(f"Failed to pickle result: {str(e)}")
+            raise SerializationError(
+                f"Result could not be serialized with pickle: {str(e)}"
+            )
+    else:
+        # Default JSON serialization
+        try:
+            json.dumps(result)
+        except (TypeError, ValueError) as e:
+            logger.error(f"Result is not JSON serializable: {str(e)}")
+            raise SerializationError(
+                f"Result could not be serialized to JSON: {str(e)}"
+            )
+
+    clear_debugging_sessions()
+
+    return result
+
+
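The sync wrapper can't `await`, so it drives async user callables with `asyncio.run`, which creates and tears down a fresh event loop per call. That is safe in a distributed subprocess, which has no loop of its own, but would raise if a loop were already running. In miniature:

```python
import asyncio

async def user_fn(x):
    await asyncio.sleep(0)
    return x * 2

# No running loop required; asyncio.run builds one, runs the coroutine, closes it
result = asyncio.run(user_fn(21))  # == 42
```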
+@app.get("/health", include_in_schema=False)
+@app.get("/", include_in_schema=False)
+def health():
+    return {"status": "healthy"}
+
+
+if __name__ == "__main__" and not is_running_in_container():
+    # NOTE: this only runs in local development; otherwise we start the uvicorn server in the pod template setup
+    import uvicorn
+    from dotenv import load_dotenv
+
+    load_dotenv()
+
+    logger.info("Starting HTTP server")
+    # Environment variables are strings, so cast the port to int for uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("KT_SERVER_PORT", 32300)))