kubetorch-0.2.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. kubetorch/__init__.py +59 -0
  2. kubetorch/cli.py +1939 -0
  3. kubetorch/cli_utils.py +967 -0
  4. kubetorch/config.py +453 -0
  5. kubetorch/constants.py +18 -0
  6. kubetorch/docs/Makefile +18 -0
  7. kubetorch/docs/__init__.py +0 -0
  8. kubetorch/docs/_ext/json_globaltoc.py +42 -0
  9. kubetorch/docs/api/cli.rst +10 -0
  10. kubetorch/docs/api/python/app.rst +21 -0
  11. kubetorch/docs/api/python/cls.rst +19 -0
  12. kubetorch/docs/api/python/compute.rst +25 -0
  13. kubetorch/docs/api/python/config.rst +11 -0
  14. kubetorch/docs/api/python/fn.rst +19 -0
  15. kubetorch/docs/api/python/image.rst +14 -0
  16. kubetorch/docs/api/python/secret.rst +18 -0
  17. kubetorch/docs/api/python/volumes.rst +13 -0
  18. kubetorch/docs/api/python.rst +101 -0
  19. kubetorch/docs/conf.py +69 -0
  20. kubetorch/docs/index.rst +20 -0
  21. kubetorch/docs/requirements.txt +5 -0
  22. kubetorch/globals.py +269 -0
  23. kubetorch/logger.py +59 -0
  24. kubetorch/resources/__init__.py +0 -0
  25. kubetorch/resources/callables/__init__.py +0 -0
  26. kubetorch/resources/callables/cls/__init__.py +0 -0
  27. kubetorch/resources/callables/cls/cls.py +159 -0
  28. kubetorch/resources/callables/fn/__init__.py +0 -0
  29. kubetorch/resources/callables/fn/fn.py +140 -0
  30. kubetorch/resources/callables/module.py +1315 -0
  31. kubetorch/resources/callables/utils.py +203 -0
  32. kubetorch/resources/compute/__init__.py +0 -0
  33. kubetorch/resources/compute/app.py +253 -0
  34. kubetorch/resources/compute/compute.py +2414 -0
  35. kubetorch/resources/compute/decorators.py +137 -0
  36. kubetorch/resources/compute/utils.py +1026 -0
  37. kubetorch/resources/compute/websocket.py +135 -0
  38. kubetorch/resources/images/__init__.py +1 -0
  39. kubetorch/resources/images/image.py +412 -0
  40. kubetorch/resources/images/images.py +64 -0
  41. kubetorch/resources/secrets/__init__.py +2 -0
  42. kubetorch/resources/secrets/kubernetes_secrets_client.py +377 -0
  43. kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
  44. kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
  45. kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
  46. kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
  47. kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
  48. kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
  49. kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
  50. kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
  51. kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
  52. kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
  53. kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
  54. kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
  55. kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
  56. kubetorch/resources/secrets/provider_secrets/providers.py +92 -0
  57. kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
  58. kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
  59. kubetorch/resources/secrets/secret.py +224 -0
  60. kubetorch/resources/secrets/secret_factory.py +64 -0
  61. kubetorch/resources/secrets/utils.py +222 -0
  62. kubetorch/resources/volumes/__init__.py +0 -0
  63. kubetorch/resources/volumes/volume.py +340 -0
  64. kubetorch/servers/__init__.py +0 -0
  65. kubetorch/servers/http/__init__.py +0 -0
  66. kubetorch/servers/http/distributed_utils.py +2968 -0
  67. kubetorch/servers/http/http_client.py +802 -0
  68. kubetorch/servers/http/http_server.py +1622 -0
  69. kubetorch/servers/http/server_metrics.py +255 -0
  70. kubetorch/servers/http/utils.py +722 -0
  71. kubetorch/serving/__init__.py +0 -0
  72. kubetorch/serving/autoscaling.py +153 -0
  73. kubetorch/serving/base_service_manager.py +344 -0
  74. kubetorch/serving/constants.py +77 -0
  75. kubetorch/serving/deployment_service_manager.py +431 -0
  76. kubetorch/serving/knative_service_manager.py +487 -0
  77. kubetorch/serving/raycluster_service_manager.py +526 -0
  78. kubetorch/serving/service_manager.py +18 -0
  79. kubetorch/serving/templates/deployment_template.yaml +17 -0
  80. kubetorch/serving/templates/knative_service_template.yaml +19 -0
  81. kubetorch/serving/templates/kt_setup_template.sh.j2 +91 -0
  82. kubetorch/serving/templates/pod_template.yaml +198 -0
  83. kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
  84. kubetorch/serving/templates/raycluster_template.yaml +35 -0
  85. kubetorch/serving/templates/service_template.yaml +21 -0
  86. kubetorch/serving/templates/workerset_template.yaml +36 -0
  87. kubetorch/serving/utils.py +344 -0
  88. kubetorch/utils.py +263 -0
  89. kubetorch-0.2.5.dist-info/METADATA +75 -0
  90. kubetorch-0.2.5.dist-info/RECORD +92 -0
  91. kubetorch-0.2.5.dist-info/WHEEL +4 -0
  92. kubetorch-0.2.5.dist-info/entry_points.txt +5 -0
kubetorch/servers/http/http_server.py
@@ -0,0 +1,1622 @@
import base64
import importlib
import importlib.util
import inspect
import json
import logging.config
import os
import pickle
import random
import subprocess
import sys
import threading
import time
import traceback
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Awaitable, Callable, Dict, Optional, Union

try:
    import httpx
except ImportError:
    # httpx is only required for the /http reverse proxy (app mode)
    pass

from fastapi import Body, FastAPI, Header, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware

try:
    from server_metrics import get_inactivity_ttl_annotation, HeartbeatManager, setup_otel_metrics
    from utils import (
        clear_debugging_sessions,
        collect_reload_modules,
        deep_breakpoint,
        DEFAULT_ALLOWED_SERIALIZATION,
        ensure_structured_logging,
        is_running_in_kubernetes,
        LOG_CONFIG,
        request_id_ctx_var,
        RSYNC_PORT,
        wait_for_app_start,
    )
except ImportError:
    from .server_metrics import get_inactivity_ttl_annotation, HeartbeatManager, setup_otel_metrics
    from .utils import (
        clear_debugging_sessions,
        collect_reload_modules,
        deep_breakpoint,
        DEFAULT_ALLOWED_SERIALIZATION,
        ensure_structured_logging,
        is_running_in_kubernetes,
        LOG_CONFIG,
        request_id_ctx_var,
        RSYNC_PORT,
        wait_for_app_start,
    )

from starlette.background import BackgroundTask
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import StreamingResponse

logging.config.dictConfig(LOG_CONFIG)

# Set up our structured JSON logging
ensure_structured_logging()

# Create the print logger AFTER ensure_structured_logging so it inherits its handlers
print_logger = logging.getLogger("print_redirect")

logger = logging.getLogger(__name__)
# Set the log level based on the KT_LOG_LEVEL environment variable; don't apply a default
# if it's unset
kt_log_level = os.getenv("KT_LOG_LEVEL")
if kt_log_level:
    kt_log_level = kt_log_level.upper()
    logger.setLevel(getattr(logging, kt_log_level, logging.INFO))

_CACHED_CALLABLES = {}
_LAST_DEPLOYED = 0
_CACHED_IMAGE = []
DISTRIBUTED_SUPERVISOR = None
APP_PROCESS = None
_CALLABLE_LOAD_LOCK = threading.Lock()  # Lock for thread-safe callable loading
LOKI_HOST = os.environ.get("LOKI_HOST", "loki-gateway.kubetorch.svc.cluster.local")
LOKI_PORT = int(os.environ.get("LOKI_PORT", 80))  # Default Loki port
KT_OTEL_ENABLED = os.environ.get("KT_OTEL_ENABLED", "False").lower() == "true"

# Global termination event that can be checked by running requests
TERMINATION_EVENT = threading.Event()

# Set the python breakpoint hook to kt.deep_breakpoint
os.environ["PYTHONBREAKPOINT"] = "kubetorch.deep_breakpoint"

request_id_ctx_var.set(os.getenv("KT_LAUNCH_ID", "-"))

#####################################
########### Proxy Helpers ###########
#####################################
if os.getenv("KT_CALLABLE_TYPE") == "app" and os.getenv("KT_APP_PORT"):
    port = os.getenv("KT_APP_PORT")
    logger.info(f"Creating /http reverse proxy to: http://localhost:{port}/")
    proxy_client = httpx.AsyncClient(base_url=f"http://localhost:{port}/", timeout=None)
else:
    proxy_client = None


async def _http_reverse_proxy(request: Request):
    """Reverse proxy for /http/* routes to the FastAPI service on its port."""
    # Extract the endpoint name from the path;
    # request.path_params["path"] contains everything after /http/
    endpoint_path = request.path_params["path"]

    # Build the URL for the FastAPI service
    url = httpx.URL(path=f"/{endpoint_path}", query=request.url.query.encode("utf-8"))

    # Build the request to forward to FastAPI
    rp_req = proxy_client.build_request(
        request.method, url, headers=request.headers.raw, content=await request.body()
    )

    # Send the request and get a streaming response
    rp_resp = await proxy_client.send(rp_req, stream=True)

    # Return the streaming response, closing the upstream response when done
    return StreamingResponse(
        rp_resp.aiter_raw(),
        status_code=rp_resp.status_code,
        headers=rp_resp.headers,
        background=BackgroundTask(rp_resp.aclose),
    )
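
# Editor's illustration (not part of the released file): with KT_APP_PORT=8080,
# a request hitting this server at /http/v1/predict?x=1 is rebuilt and streamed
# through to http://localhost:8080/v1/predict?x=1, preserving method, headers,
# query string, and body. For example, from any HTTP client:
#
#     import httpx
#     # hypothetical service hostname
#     httpx.post("http://my-service/http/v1/predict", json={"x": 1})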

#####################################
########### Cache Helpers ###########
#####################################
def clear_cache():
    global _CACHED_CALLABLES

    logger.debug("Clearing callables cache.")
    _CACHED_CALLABLES.clear()


def cached_image_setup():
    logger.debug("Starting cached image setup.")
    global _CACHED_IMAGE
    global APP_PROCESS

    dockerfile_path = kt_directory() / "image.dockerfile"
    with open(dockerfile_path, "r") as file:
        lines = file.readlines()
    lines = [line.strip() for line in lines]

    # Find the first line where the image differs from the cache, and update the cache
    cache_mismatch_index = -1
    cmd_mismatch = False
    for i, (new_line, cached_line) in enumerate(zip(lines, _CACHED_IMAGE)):
        if new_line.startswith("CMD"):
            cmd_mismatch = True

        if new_line != cached_line or "# override" in new_line or cmd_mismatch:
            cache_mismatch_index = i
            break
    if cache_mismatch_index == -1:
        if len(lines) != len(_CACHED_IMAGE):
            cache_mismatch_index = min(len(lines), len(_CACHED_IMAGE))
        else:
            cache_mismatch_index = len(lines)
    _CACHED_IMAGE = lines

    if cache_mismatch_index == len(lines):
        return

    if not (cache_mismatch_index == len(lines) - 1 and cmd_mismatch):
        logger.info("Running image setup.")
    else:
        logger.debug("Skipping image setup steps, no changes detected.")

    # Grab the current list of installed dependencies with pip freeze to check if anything
    # changes (we need to send a SIGHUP to restart the server if so)
    start_deps = None
    try:
        res = subprocess.run(
            ["pip", "freeze"],
            capture_output=True,
            text=True,
            check=True,
        )
        start_deps = res.stdout.splitlines()
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to run pip freeze: {e}")

    # Only run image setup steps starting from the cache mismatch point
    kt_pip_cmd = None
    for line in lines[cache_mismatch_index:]:
        command = ""
        if line.strip().startswith("#"):
            continue  # ignore comments
        if line.startswith("RUN") or line.startswith("CMD"):
            command = line[len("RUN ") :]  # "RUN " and "CMD " have the same length

            if command.startswith("$KT_PIP_INSTALL_CMD"):
                kt_pip_cmd = kt_pip_cmd or _get_kt_pip_install_cmd()
                command = command.replace("$KT_PIP_INSTALL_CMD", kt_pip_cmd)
        elif line.startswith("COPY"):
            _, source, dest = line.split()
            # COPY instructions are essentially no-ops since rsync_file_updates()
            # already placed files in their correct locations,
            # but we verify the files exist and log the absolute paths for clarity.

            # Determine the actual absolute destination path
            if dest and dest.startswith("/"):
                # Already absolute
                dest_path = Path(dest)
            elif dest and dest.startswith("~/"):
                # Tilde prefix - strip it and treat as relative to cwd
                dest_path = Path.cwd() / dest[2:]
            else:
                # Relative to the working directory (including explicit basenames)
                dest_path = Path.cwd() / dest

            # Verify the destination exists (it should have been rsync'd)
            if dest_path.exists():
                logger.info(f"Copied {source} to {dest_path.absolute()}")
            else:
                raise FileNotFoundError(
                    f"COPY {source} {dest} failed: destination {dest_path.absolute()} does not exist. "
                    f"This likely means the rsync operation failed to sync the files correctly."
                )
        elif line.startswith("ENV"):
            # Need to handle the case where the env var is being set to "" (empty string)
            line_vals = line.split(" ", 2)
            if len(line_vals) < 2:  # ENV line must have at least a key
                raise ValueError("ENV line cannot be empty")
            if len(line_vals) == 2:  # ENV line with just a key
                key = line_vals[1]
                val = ""
            else:  # ENV line with key and value
                key, val = line_vals[1], line_vals[2]

            # Expand environment variables in the value.
            # This supports patterns like $VAR, ${VAR}, and $VAR:default_value
            expanded_val = os.path.expandvars(val)

            if key not in [
                "KT_FILE_PATH",
                "KT_MODULE_NAME",
                "KT_CLS_OR_FN_NAME",
                "KT_INIT_ARGS",
                "KT_CALLABLE_TYPE",
                "KT_DISTRIBUTED_CONFIG",
            ]:
                logger.info(f"Setting env var {key}")
            os.environ[key] = expanded_val
            # If the env var is specifically KT_LOG_LEVEL, we need to update the logger level
            if key == "KT_LOG_LEVEL":
                global kt_log_level
                kt_log_level = expanded_val.upper()
                logger.setLevel(kt_log_level)
                logger.info(f"Updated log level to {kt_log_level}")
        elif line.startswith("FROM"):
            continue
        elif line:
            raise ValueError(f"Unrecognized image setup instruction {line}")

        if command:
            is_app_cmd = line.startswith("CMD")
            if is_app_cmd:
                logger.info(f"Running app command: {command}")
            else:
                logger.info(f"Running image setup step: {command}")

            try:
                # Use subprocess.Popen to capture output and redirect it through the loggers
                env = os.environ.copy()
                env["PYTHONUNBUFFERED"] = "1"

                if is_app_cmd and os.getenv("KT_CALLABLE_TYPE") == "app":
                    if APP_PROCESS and APP_PROCESS.poll() is None:
                        APP_PROCESS.kill()

                process = subprocess.Popen(
                    command,
                    shell=True,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    universal_newlines=True,
                    bufsize=1,
                    env=env,
                )

                if is_app_cmd and os.getenv("KT_CALLABLE_TYPE") == "app":
                    APP_PROCESS = process

                # Collect stderr for potential error logging
                stderr_lines = []
                stderr_lock = threading.Lock()

                # Stream stdout and stderr in real time. We need to do all this so that stdout
                # and stderr are printed with the correct formatting for our log queries;
                # without it they flow straight to the system stdout and stderr without any
                # structured formatting.
                def stream_output(pipe, log_func, request_id, collect_stderr=False):
                    request_id_ctx_var.set(request_id)
                    for out_line in iter(pipe.readline, ""):
                        if out_line:
                            stripped_line = out_line.rstrip()
                            log_func(stripped_line)

                            # Collect stderr lines for potential error logging.
                            # removeprefix, not lstrip: lstrip("ERROR: ") would strip any of
                            # the characters "ERO: " from the left, not the literal prefix.
                            if collect_stderr:
                                with stderr_lock:
                                    stderr_lines.append(stripped_line.removeprefix("ERROR: "))
                    pipe.close()

                # Start the streaming threads
                current_request_id = request_id_ctx_var.get("-")

                stderr_log_func = logger.error if is_app_cmd else logger.debug
                stdout_thread = threading.Thread(
                    target=stream_output,
                    args=(process.stdout, logger.info, current_request_id),
                )
                stderr_thread = threading.Thread(
                    target=stream_output,
                    args=(
                        process.stderr,
                        stderr_log_func,
                        current_request_id,
                        not is_app_cmd,
                    ),
                )

                stdout_thread.daemon = True
                stderr_thread.daemon = True
                stdout_thread.start()
                stderr_thread.start()

                if is_app_cmd and os.getenv("KT_APP_PORT"):
                    # Wait for the internal app to be healthy/ready if a run port is provided
                    try:
                        port = os.getenv("KT_APP_PORT")
                        logger.debug(f"Waiting for internal app on port {port} to start:")
                        wait_for_app_start(
                            port=port,
                            health_check=os.getenv("KT_APP_HEALTHCHECK"),
                            process=process,
                        )
                        logger.info(f"App on port {port} is ready.")
                    except Exception as e:
                        logger.error(f"Caught exception waiting for app to start: {e}")
                else:
                    # Check if this is a background command (ends with &)
                    is_background = command.rstrip().endswith("&")

                    if is_background:
                        # For background processes, give it a moment to start and check for
                        # immediate failures
                        time.sleep(0.5)  # Brief pause to catch immediate errors

                        # Check if the process failed immediately
                        poll_result = process.poll()
                        if poll_result is not None and poll_result != 0:
                            # The process exited with an error
                            stdout_thread.join(timeout=1)
                            stderr_thread.join(timeout=1)
                            return_code = poll_result
                        else:
                            # The process is running in the background successfully
                            logger.info(f"Background process started successfully (PID: {process.pid})")
                            return_code = 0  # Indicate success for background start
                    else:
                        # Wait for the process to complete
                        return_code = process.wait()

                        # Wait for the streaming threads to finish
                        stdout_thread.join()
                        stderr_thread.join()

                    if return_code != 0 and not is_app_cmd:
                        with stderr_lock:
                            if stderr_lines:
                                logger.error(f"Failed to run command '{command}' with stderr:")
                                for stderr_line in stderr_lines:
                                    logger.error(stderr_line)
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to run command '{command}' with error: {e}")
                with stderr_lock:
                    if stderr_lines:
                        logger.error("Stderr:")
                        for stderr_line in stderr_lines:
                            logger.error(stderr_line)

    # Check if any dependencies changed, and if so reload them inside the server process
    if start_deps:
        try:
            # Run pip freeze and capture the output
            res = subprocess.run(
                ["pip", "freeze"],
                capture_output=True,
                text=True,
                check=True,
            )
            end_deps = res.stdout.splitlines()
            # We only need to look at the deps which were already installed (i.e. lines in
            # start_deps); new ones can't be "stale" inside the current server process.
            # We also only use lines with exact pypi versions (containing "=="), no editables.
            changed_deps = [line.split("==")[0] for line in start_deps if "==" in line and line not in end_deps]
            imported_changed_deps = [
                dep for dep in changed_deps if dep in sys.modules
            ]  # Only reload deps which are already imported
            if imported_changed_deps:
                logger.debug(f"New dependencies found: {imported_changed_deps}, forcing reload")

                # Don't clear the callable cache here - let load_callable_from_env handle it
                # to preserve __kt_cached_state__
                if DISTRIBUTED_SUPERVISOR:
                    DISTRIBUTED_SUPERVISOR.cleanup()

                # Remove changed modules from sys.modules to force fresh imports
                modules_to_remove = []
                for module_name in sys.modules:
                    for dep in imported_changed_deps:
                        if module_name == dep or module_name.startswith(dep + "."):
                            modules_to_remove.append(module_name)
                            break

                for module_name in modules_to_remove:
                    try:
                        del sys.modules[module_name]
                        logger.debug(f"Removed module {module_name} from sys.modules")
                    except KeyError:
                        pass
        except subprocess.CalledProcessError as e:
            logger.error(f"Failed to run pip freeze: {e}")
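
# Editor's illustration (not part of the released file): cached_image_setup()
# consumes a Dockerfile-like spec at .kt/image.dockerfile and replays only the
# instructions after the first cached mismatch. A hypothetical input:
#
#     FROM python:3.10-slim
#     ENV KT_LOG_LEVEL debug
#     COPY requirements.txt requirements.txt
#     RUN $KT_PIP_INSTALL_CMD -r requirements.txt
#     CMD uvicorn my_app:app --port 8080
#
# FROM lines are skipped (the base image is fixed once the pod exists), COPY is
# verified rather than executed (rsync already delivered the files), ENV mutates
# os.environ, and RUN/CMD lines are executed via subprocess with streamed logs.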


def run_image_setup(deployed_time: Optional[float] = None):
    if os.environ["KT_FREEZE"] == "True" or not is_running_in_kubernetes():
        return

    rsync_file_updates()

    dockerfile_path = kt_directory() / "image.dockerfile"
    if not dockerfile_path.exists():
        raise FileNotFoundError(f"No image and metadata configuration found in path: {str(dockerfile_path)}")
    while (
        # May need to give the dockerfile time to rsync over, so wait until the dockerfile
        # timestamp is later than when we started the deployment (recorded in .to() and passed
        # here as deployed_time). We should also only wait if _LAST_DEPLOYED is nonzero, as the
        # first time the server is deployed the image is written before the server starts, so
        # we don't need to wait.
        _LAST_DEPLOYED
        and dockerfile_path.stat().st_mtime < deployed_time
        and datetime.now(timezone.utc).timestamp() - deployed_time < 5
    ):
        time.sleep(0.1)

    cached_image_setup()

    if os.getenv("KT_CALLABLE_TYPE") != "app":
        logger.debug("Completed cached image setup.")

#####################################
######### Generic Helpers ###########
#####################################
class SerializationError(Exception):
    pass


def kt_directory():
    if "KT_DIRECTORY" in os.environ:
        return Path(os.environ["KT_DIRECTORY"]).expanduser()
    else:
        return Path.cwd() / ".kt"


def _get_kt_pip_install_cmd() -> Optional[str]:
    """Get the actual KT_PIP_INSTALL_CMD value for command expansion."""
    kt_pip_cmd = os.getenv("KT_PIP_INSTALL_CMD")
    if not kt_pip_cmd:  # Fall back to reading from file
        try:
            with open(kt_directory() / "kt_pip_install_cmd", "r") as f:
                return f.read().strip()
        except FileNotFoundError:
            return None
    return kt_pip_cmd


def is_running_in_container():
    # Check for the .dockerenv file, which exists in Docker containers
    return Path("/.dockerenv").exists()


async def run_in_executor_with_context(executor, func, *args):
    """
    Helper to run a function in an executor while preserving the request_id context.

    This wrapper captures the current request_id from the context before running
    the function in a thread pool executor, then sets it in the new thread.
    """
    import asyncio

    # Capture the current request_id before switching threads
    current_request_id = request_id_ctx_var.get("-")

    def wrapper(*args):
        # Set the request_id in the executor thread
        token = None
        if current_request_id != "-":
            token = request_id_ctx_var.set(current_request_id)
        try:
            return func(*args)
        finally:
            # Clean up the context to avoid leaking between requests
            if token is not None:
                request_id_ctx_var.reset(token)

    return await asyncio.get_event_loop().run_in_executor(executor, wrapper, *args)
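
# Editor's sketch (not in the released file): typical usage from an async endpoint,
# offloading a blocking call while keeping the request_id visible to log lines
# emitted inside the worker thread. Passing None as the executor uses the event
# loop's default ThreadPoolExecutor:
#
#     callable_obj = await run_in_executor_with_context(None, load_callable, deployed_as_of)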


def should_reload(deployed_as_of: Optional[str] = None) -> bool:
    """
    Determine whether the server should reload based on the deployment timestamp.
    If deployed_as_of is provided, check it against the last deployed time.
    If not provided, default to False (no reload).
    """
    if deployed_as_of in [None, "null", "None"]:
        return False

    try:
        deployed_time = datetime.fromisoformat(deployed_as_of).timestamp()
        return deployed_time > _LAST_DEPLOYED
    except ValueError as e:
        logger.error(f"Invalid deployed_as_of format: {deployed_as_of}. Error: {e}")
        return True
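
# Editor's note: the X-Deployed-As-Of header carries an ISO-8601 timestamp, e.g.
#
#     >>> datetime.fromisoformat("2024-01-01T00:00:00+00:00").timestamp()
#     1704067200.0
#
# A reload is triggered whenever that timestamp is newer than _LAST_DEPLOYED; an
# unparseable value is logged and conservatively treated as "reload".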


def load_callable(
    deployed_as_of: Optional[str] = None,
    distributed_subprocess: bool = False,
    reload_cleanup_fn: Optional[Callable] = None,
):
    callable_name = os.environ["KT_CLS_OR_FN_NAME"]

    callable_obj = _CACHED_CALLABLES.get(callable_name, None)
    if callable_obj and not should_reload(deployed_as_of):
        # If the callable is cached and doesn't need a reload, return it immediately
        logger.debug("Returning cached callable.")
        return callable_obj

    # Slow path: need to load or reload - use the lock for thread safety
    with _CALLABLE_LOAD_LOCK:
        # Double-check within the lock (another thread might have loaded it)
        callable_obj = _CACHED_CALLABLES.get(callable_name, None)
        if callable_obj and not should_reload(deployed_as_of):
            logger.debug("Returning cached callable (found after acquiring lock).")
            return callable_obj
        # Proceed with loading/reloading
        return _load_callable_internal(deployed_as_of, distributed_subprocess, reload_cleanup_fn, callable_obj)


def _load_callable_internal(
    deployed_as_of: Optional[str] = None,
    distributed_subprocess: bool = False,
    reload_cleanup_fn: Optional[Callable] = None,
    callable_obj=None,
):
    """Internal callable loading logic - should be called within the lock for thread safety."""
    global _LAST_DEPLOYED

    callable_name = os.environ["KT_CLS_OR_FN_NAME"]

    if not callable_obj:
        logger.debug("Callable not found in cache, loading from environment.")
    else:
        logger.debug(
            f"Callable found in cache, but reloading because deployed_as_of {deployed_as_of} "
            f"is newer than last deployed time {_LAST_DEPLOYED}"
        )

    # If not in cache or we have a more recent deployment timestamp, update metadata and reload
    if reload_cleanup_fn and _LAST_DEPLOYED:
        # If a reload cleanup function is provided and we've already deployed at least once,
        # call it before reloading the callable
        reload_cleanup_fn()

    deployed_time = (
        datetime.fromisoformat(deployed_as_of).timestamp() if deployed_as_of else datetime.now(timezone.utc).timestamp()
    )
    if not distributed_subprocess:
        # We don't reload the image in distributed subprocesses, as we already did it in the
        # main process and we don't want to do it multiple times (once per subprocess).
        if _LAST_DEPLOYED:
            logger.info("Patching image and code updates and reloading callable.")
        else:
            logger.info("Setting up image and loading callable.")
        run_image_setup(deployed_time)

    distributed_config = os.environ["KT_DISTRIBUTED_CONFIG"]
    if distributed_config not in ["null", "None"] and not distributed_subprocess:
        logger.debug(f"Loading distributed supervisor: {distributed_config}")
        callable_obj = load_distributed_supervisor(deployed_as_of=deployed_as_of)
        logger.debug("Distributed supervisor loaded successfully.")
    else:
        logger.debug(f"Loading callable from environment: {callable_name}")
        callable_obj = load_callable_from_env()
        logger.debug("Callable loaded successfully.")

    _LAST_DEPLOYED = deployed_time
    _CACHED_CALLABLES[callable_name] = callable_obj

    return callable_obj


def load_distributed_supervisor(deployed_as_of: Optional[str] = None):
    global DISTRIBUTED_SUPERVISOR

    if os.environ["KT_FILE_PATH"] not in sys.path:
        sys.path.insert(0, os.environ["KT_FILE_PATH"])

    distributed_config = os.environ["KT_DISTRIBUTED_CONFIG"]

    # If this is the main process of a distributed call, we don't load the callable directly;
    # we create a new supervisor if one doesn't exist or if the config has changed.
    # We don't create a supervisor if this is a distributed subprocess.
    config_hash = hash(str(distributed_config))
    if DISTRIBUTED_SUPERVISOR is None or config_hash != DISTRIBUTED_SUPERVISOR.config_hash:
        from .distributed_utils import distributed_supervisor_factory

        logger.info(f"Loading distributed supervisor with config: {distributed_config}")
        distributed_config = json.loads(distributed_config)
        # If we already have some distributed processes, we need to clean them up before
        # creating a new supervisor.
        if DISTRIBUTED_SUPERVISOR:
            DISTRIBUTED_SUPERVISOR.cleanup()
        DISTRIBUTED_SUPERVISOR = distributed_supervisor_factory(**distributed_config)
        DISTRIBUTED_SUPERVISOR.config_hash = config_hash
        try:
            # If there are any errors during setup, we catch and log them, and then undo the
            # setup so the distributed supervisor is not left in a broken state (which could
            # otherwise still fail when we call DISTRIBUTED_SUPERVISOR.cleanup() in lifespan).
            DISTRIBUTED_SUPERVISOR.setup(deployed_as_of=deployed_as_of)
        except Exception as e:
            logger.error(f"Failed to set up distributed supervisor with config {distributed_config}: {e}")
            DISTRIBUTED_SUPERVISOR = None
            raise e
    return DISTRIBUTED_SUPERVISOR


def patch_sys_path():
    abs_path = str(Path(os.environ["KT_FILE_PATH"]).expanduser().resolve())
    if abs_path not in sys.path:
        sys.path.insert(0, abs_path)
        logger.debug(f"Added {abs_path} to sys.path")

    # Maybe needed for subprocesses (e.g. distributed) to find the callable's module,
    # and needed for distributed subprocesses to find the file path
    existing_path = os.environ.get("PYTHONPATH", "")
    if abs_path not in existing_path:
        os.environ["PYTHONPATH"] = f"{abs_path}{os.pathsep}{existing_path}" if existing_path else abs_path
        logger.debug(f"Set PYTHONPATH to {os.environ['PYTHONPATH']}")


def load_callable_from_env():
    """Load and cache callable objects from env, preserving state if __kt_cached_state__ is available."""
    cls_or_fn_name = os.environ["KT_CLS_OR_FN_NAME"]
    module_name = os.environ["KT_MODULE_NAME"]

    # Check if we have an existing cached callable, and extract its state if available
    cached_state = None
    existing_callable = _CACHED_CALLABLES.get(cls_or_fn_name, None)

    if existing_callable and hasattr(existing_callable, "__kt_cached_state__"):
        try:
            logger.info(f"Extracting cached state from {cls_or_fn_name} via __kt_cached_state__")
            cached_state = existing_callable.__kt_cached_state__()
            if cached_state is not None and not isinstance(cached_state, dict):
                logger.warning(
                    f"__kt_cached_state__ returned non-dict type: {type(cached_state)}. Ignoring cached state."
                )
                cached_state = None
        except Exception as e:
            # This could happen if modules were removed from sys.modules during image setup
            # and the callable's __kt_cached_state__ method depends on them
            logger.warning(
                f"Failed to extract cached state from {cls_or_fn_name} (possibly due to module reloading): {e}. "
                f"Proceeding without cached state."
            )
            cached_state = None

    # Now that we have the state, clean up the old callable to free memory
    if existing_callable:
        logger.debug(f"Deleting existing callable: {cls_or_fn_name}")
        _CACHED_CALLABLES.pop(cls_or_fn_name, None)
        del existing_callable
        # Garbage collect to ensure everything is cleaned up (especially GPU memory)
        import gc

        gc.collect()

    patch_sys_path()

    # If we're inside a distributed subprocess or the main process of a non-distributed call,
    # we load and instantiate the callable.
    try:
        # Try a regular package import first
        if module_name in sys.modules:
            # Clear any existing debugging sessions when reloading modules
            clear_debugging_sessions()

            # Find all user modules under kt_home_dir to reload
            kt_home_dir = (
                Path(os.environ.get("KT_DIRECTORY")).expanduser().resolve()
                if "KT_DIRECTORY" in os.environ
                else Path.cwd()
            )
            kt_home_dir_str = str(kt_home_dir) + os.sep

            # Collect all user modules and reload them in order from children to parents
            modules_to_reload_sorted = collect_reload_modules(kt_home_dir_str)
            logger.info(f"Reloading imported user modules under {kt_home_dir_str}")
            logger.debug(f"Modules to reload: {modules_to_reload_sorted}")
            for name in modules_to_reload_sorted:
                if name in sys.modules:
                    try:
                        importlib.reload(sys.modules[name])
                    except (ImportError, ValueError, TypeError) as e:
                        logger.debug(f"Could not reload module {name}: {e}. Deleting and reimporting instead.")
                        try:
                            del sys.modules[name]
                            importlib.import_module(name)
                        except Exception as reimport_error:
                            logger.error(f"Failed to reimport module {name} after reload failure: {reimport_error}")

            module = sys.modules.get(module_name)
            if module is None:
                logger.debug(f"Module {module_name} not found after reload, reimporting")
                module = importlib.import_module(module_name)
        else:
            logger.debug(f"Importing module {module_name}")
            module = importlib.import_module(module_name)
            logger.debug(f"Module {module_name} loaded")

        # Ensure our structured logging is in place after the user module import
        # (in case the user's module configured its own logging)
        ensure_structured_logging()

        callable_obj = getattr(module, cls_or_fn_name)
        logger.debug(f"Callable {cls_or_fn_name} loaded")
    except (ImportError, ValueError) as original_error:
        # Fall back to a file-based import if the package import fails
        try:
            module = import_from_file(os.environ["KT_FILE_PATH"], module_name)
            # Ensure structured logging after the file-based import
            ensure_structured_logging()
            callable_obj = getattr(module, cls_or_fn_name)
        except (ImportError, ValueError):
            # Raise the original error if the file import also fails, because the errors
            # raised here are more opaque and less useful than the original ImportError
            # or ValueError.
            raise original_error
    except AttributeError as e:
        # If the callable is not found in the module, raise an error
        raise HTTPException(
            status_code=404,
            detail=f"Callable '{cls_or_fn_name}' not found in module '{module_name}'",
        ) from e

    # Unwrap to remove any kt deploy decorators (e.g. @kt.compute)
    if hasattr(callable_obj, "__wrapped__"):
        callable_obj = callable_obj.__wrapped__

    if isinstance(callable_obj, type):
        # Prepare the init arguments
        init_kwargs = {}

        # Add user-provided init_args
        if os.environ["KT_INIT_ARGS"] not in ["null", "None"]:
            init_kwargs = json.loads(os.environ["KT_INIT_ARGS"])
            logger.info(f"Setting init_args {init_kwargs}")

        # Add cached state if available.
        # Allow the user to manually set "kt_cached_state" to override/disable the cache.
        if cached_state is not None and "kt_cached_state" not in init_kwargs:
            # Check whether the class's __init__ accepts a kt_cached_state parameter
            sig = inspect.signature(callable_obj.__init__)
            if "kt_cached_state" in sig.parameters:
                logger.info(f"Passing cached state to {cls_or_fn_name}.__init__")
                init_kwargs["kt_cached_state"] = cached_state
            else:
                raise ValueError(
                    f"Class {cls_or_fn_name} has a __kt_cached_state__ method but __init__ does not accept "
                    f"a 'kt_cached_state' parameter. Please add 'kt_cached_state=None' to the __init__ signature."
                )

        # Instantiate with the combined arguments
        if init_kwargs:
            callable_obj = callable_obj(**init_kwargs)
        else:
            callable_obj = callable_obj()

    return callable_obj
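
# Editor's sketch of the __kt_cached_state__ contract enforced above (hypothetical
# user code, not part of this file): a class can carry expensive state, e.g. model
# weights, across hot reloads by returning a dict of state and accepting it back
# in __init__:
#
#     class Embedder:
#         def __init__(self, kt_cached_state=None):
#             state = kt_cached_state or {}
#             self.model = state.get("model") or load_model()  # load_model: user code
#
#         def __kt_cached_state__(self):
#             return {"model": self.model}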


def import_from_file(file_path: str, module_name: str):
    """Import a module from a file path."""
    module_parts = module_name.split(".")
    depth = max(0, len(module_parts) - 1)

    # Convert file_path to an absolute path if it isn't already (note: .resolve() will prepend
    # the current working directory if file_path is relative)
    abs_path = Path(file_path).expanduser().resolve()
    # Ensure depth doesn't exceed the available parent directories
    max_available_depth = len(abs_path.parents) - 1

    if max_available_depth < 0:
        # The file has no parent directories, so use the file's directory itself
        parent_path = str(abs_path.parent)
    else:
        # Clamp depth to the available range to avoid an IndexError
        depth = min(depth, max_available_depth)
        parent_path = str(abs_path.parents[depth])

    if parent_path not in sys.path:
        sys.path.insert(0, parent_path)

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load spec for module {module_name} from {file_path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
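
# For example (editor's note): for file_path "/srv/app/pkg/sub/mod.py" and
# module_name "pkg.sub.mod", depth is 2 and abs_path.parents[2] is "/srv/app",
# so the package root lands on sys.path and "pkg.sub.mod" imports cleanly.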

#####################################
########## Rsync Helpers ############
#####################################
def generate_rsync_command(subdir: str = ".", exclude_absolute: bool = True):
    """Generate the rsync command for syncing from the jump pod.

    Args:
        subdir: Directory to sync to (defaults to the current directory)
        exclude_absolute: Whether to exclude the __absolute__ directory (default True)
    """
    service_name = os.getenv("KT_SERVICE_NAME")
    namespace = os.getenv("POD_NAMESPACE")

    exclude_opt = "--exclude='__absolute__*' " if exclude_absolute else ""
    logger.debug("Syncing code from rsync pod to local directory")
    return (
        f"rsync -av {exclude_opt}"
        f"rsync://kubetorch-rsync.{namespace}.svc.cluster.local:{RSYNC_PORT}/data/{namespace}/{service_name}/ {subdir}"
    )
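
# Editor's illustration: with POD_NAMESPACE=team-a and KT_SERVICE_NAME=my-fn
# (hypothetical values), the generated command looks like:
#
#     rsync -av --exclude='__absolute__*' \
#         rsync://kubetorch-rsync.team-a.svc.cluster.local:<RSYNC_PORT>/data/team-a/my-fn/ .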


def rsync_file_updates():
    """Rsync files from the jump pod to the worker pod.

    Performs two rsync operations in parallel:
    1. Regular files (excluding __absolute__*) to the working directory
    2. Absolute path files (under __absolute__/) to their absolute destinations
    """
    import concurrent.futures
    from concurrent.futures import ThreadPoolExecutor

    service_name = os.getenv("KT_SERVICE_NAME")
    namespace = os.getenv("POD_NAMESPACE")

    # Build the base rsync URL
    rsync_base = f"rsync://kubetorch-rsync.{namespace}.svc.cluster.local:{RSYNC_PORT}/data/{namespace}/{service_name}/"

    max_retries = 5
    base_delay = 1  # seconds
    max_delay = 30  # seconds

    def run_rsync_with_retries(rsync_cmd, description):
        """Helper to run rsync with exponential backoff retries."""
        for attempt in range(max_retries):
            resp = subprocess.run(
                rsync_cmd,
                shell=True,
                capture_output=True,
                text=True,
            )

            if resp.returncode == 0:
                logger.debug(f"Successfully rsync'd {description}")
                return  # Success!

            # Check if it's a retryable error
            retryable_errors = [
                "max connections",
                "Temporary failure in name resolution",
                "Name or service not known",
                "Connection refused",
                "No route to host",
            ]

            is_retryable = any(error in resp.stderr for error in retryable_errors)

            if is_retryable and attempt < max_retries - 1:
                # Calculate exponential backoff with jitter
                delay = min(base_delay * (2**attempt) + random.uniform(0, 1), max_delay)
                logger.warning(
                    f"Rsync {description} failed with retryable error: {resp.stderr.strip()}. "
                    f"Retrying in {delay:.1f} seconds (attempt {attempt + 1}/{max_retries})"
                )
                time.sleep(delay)
            else:
                # For non-retryable errors or the final attempt, raise immediately
                if attempt == max_retries - 1:
                    logger.error(f"Rsync {description} failed after {max_retries} attempts. Last error: {resp.stderr}")
                raise RuntimeError(f"Rsync {description} failed with error: {resp.stderr}")

        # If we exhausted all retries
        raise RuntimeError(f"Rsync {description} failed after {max_retries} attempts. Last error: {resp.stderr}")

    def rsync_regular_files():
        """Rsync regular files (excluding __absolute__*) to the working directory."""
        rsync_cmd_regular = f"rsync -avL --exclude='__absolute__*' {rsync_base} ."
        logger.debug(f"Rsyncing regular files with command: {rsync_cmd_regular}")
        run_rsync_with_retries(rsync_cmd_regular, "regular files")

    def rsync_absolute_files():
        """Rsync absolute path files to their absolute destinations."""
        # First, do a listing to see if the __absolute__ directory exists
        check_cmd = f"rsync --list-only {rsync_base}__absolute__/"
        check_resp = subprocess.run(check_cmd, shell=True, capture_output=True, text=True)

        if check_resp.returncode == 0 and check_resp.stdout.strip():
            # The __absolute__ directory exists, so sync its contents to root.
            # The trick is to sync from __absolute__/ to /, which places files in their
            # absolute paths.
            rsync_cmd_absolute = f"rsync -avL {rsync_base}__absolute__/ /"
            logger.debug(f"Rsyncing absolute path files with command: {rsync_cmd_absolute}")
            run_rsync_with_retries(rsync_cmd_absolute, "absolute path files")
        else:
            logger.debug("No absolute path files to sync")

    # Run both rsync operations in parallel
    with ThreadPoolExecutor(max_workers=2) as executor:
        # Submit both tasks
        regular_future = executor.submit(rsync_regular_files)
        absolute_future = executor.submit(rsync_absolute_files)

        # Wait for both to complete and handle any exceptions
        futures = [regular_future, absolute_future]
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()  # This will raise any exception that occurred
            except Exception as e:
                # Cancel the remaining futures if one fails
                for f in futures:
                    f.cancel()
                raise e

    logger.debug("Completed rsync of all files")
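
# Editor's note: with base_delay=1 and max_delay=30, the backoff schedule before
# jitter is 1, 2, 4, 8 seconds across the 5 attempts (min(1 * 2**attempt, 30)),
# plus up to 1 second of uniform jitter per retry.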

#####################################
########### App setup ###############
#####################################
class HealthCheckFilter(logging.Filter):
    def filter(self, record):
        return not (
            isinstance(record.args, tuple)
            and len(record.args) >= 3
            and ("/health" in record.args[2] or record.args[2] == "/")
        )


class RequestContextFilter(logging.Filter):
    def filter(self, record):
        record.request_id = request_id_ctx_var.get("-")
        record.pod = os.getenv("POD_NAME", "unknown-pod")
        return True


class TerminationCheckMiddleware(BaseHTTPMiddleware):
    """Monitor for termination while a request is running and return an error if detected."""

    async def dispatch(self, request: Request, call_next):
        # Skip health checks and metrics endpoints
        if request.url.path in ["/health", "/", "/metrics"]:
            return await call_next(request)

        # Run the actual request in the background
        import asyncio

        request_task = asyncio.create_task(call_next(request))

        # Monitor for termination while the request is running
        while not request_task.done():
            # Check if we're terminating
            if TERMINATION_EVENT.is_set() or (
                hasattr(request.app.state, "terminating") and request.app.state.terminating
            ):
                # Cancel the request task
                request_task.cancel()

                # Return a PodTerminatedError (package_exception is defined below in this module)
                from kubetorch import PodTerminatedError

                pod_name = os.environ.get("POD_NAME", "unknown")
                exc = PodTerminatedError(
                    pod_name=pod_name,
                    reason="SIGTERM",
                    status_code=503,
                    events=[
                        {
                            "timestamp": datetime.now(timezone.utc).isoformat(),
                            "reason": "Terminating",
                            "message": "Pod received SIGTERM signal and is shutting down gracefully",
                        }
                    ],
                )

                return package_exception(exc)

            # Wait a bit before checking again or for the request to complete
            try:
                result = await asyncio.wait_for(asyncio.shield(request_task), timeout=0.5)
                return result
            except asyncio.TimeoutError:
                # Request still running after 0.5s; continue the loop to check termination again
                continue

        # Request completed normally
        return await request_task
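
# Editor's note on the loop above: asyncio.shield() protects the in-flight request
# from being cancelled by the 0.5s wait_for timeout; the timeout only wakes the
# loop to re-check TERMINATION_EVENT, and only the explicit request_task.cancel()
# on termination actually stops the request:
#
#     await asyncio.wait_for(asyncio.shield(request_task), timeout=0.5)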


class RequestIDMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        request_id = request.headers.get("X-Request-ID", "-")
        token = request_id_ctx_var.set(request_id)

        try:
            response = await call_next(request)
            return response
        finally:
            # Reset the context variable to its default value
            request_id_ctx_var.reset(token)


class StreamToLogger:
    def __init__(self, logger, log_level=logging.INFO, original_stream=None):
        self.logger = logger
        self.log_level = log_level
        self.original_stream = original_stream
        self.linebuf = ""

    def _is_from_logging(self):
        """Check if the current write call is coming from the logging system."""
        frame = sys._getframe()
        while frame:
            if frame.f_globals.get("__name__", "").startswith("logging"):
                return True
            frame = frame.f_back
        return False

    def write(self, buf):
        # Check if this write is from the logging system
        is_from_logging = self._is_from_logging()

        # Always write to the original stream first
        if self.original_stream:
            self.original_stream.write(buf)
            self.original_stream.flush()

        # Skip logging if this write is from the logging system, to prevent infinite loops
        if self.logger.name == "print_redirect" and is_from_logging:
            return

        # Buffer and log complete lines
        temp_linebuf = self.linebuf + buf
        self.linebuf = ""

        # Split on newlines but keep carriage returns
        lines = []
        current_line = ""
        for char in temp_linebuf:
            if char == "\n":
                lines.append(current_line)
                current_line = ""
            else:
                current_line += char

        # Add any remaining content to linebuf
        if current_line:
            self.linebuf = current_line

        # Log complete lines
        for line in lines:
            if line:
                self.logger.log(self.log_level, line)

    def flush(self):
        if self.original_stream:
            self.original_stream.flush()
        if self.linebuf != "":
            self.logger.log(self.log_level, self.linebuf)
            self.linebuf = ""

    def isatty(self):
        # Delegate to the original stream if it exists, else return False
        if self.original_stream and hasattr(self.original_stream, "isatty"):
            return self.original_stream.isatty()
        return False

    def fileno(self):
        if self.original_stream and hasattr(self.original_stream, "fileno"):
            return self.original_stream.fileno()
        raise OSError("Stream does not support fileno()")

    @property
    def encoding(self):
        # Return the encoding of the original stream if available, else UTF-8
        if self.original_stream and hasattr(self.original_stream, "encoding"):
            return self.original_stream.encoding
        return "utf-8"


# Save the original streams before redirection
_original_stdout = sys.stdout
_original_stderr = sys.stderr

# Redirect stdout and stderr to our logger while preserving the original streams
sys.stdout = StreamToLogger(print_logger, logging.INFO, _original_stdout)
sys.stderr = StreamToLogger(print_logger, logging.ERROR, _original_stderr)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the application lifecycle."""
    import signal

    # Only register signal handlers if we're in the main thread.
    # This allows tests to run without signal handling.
    if threading.current_thread() is threading.main_thread():
        # Save any existing SIGTERM handler
        original_sigterm_handler = signal.getsignal(signal.SIGTERM)

        def handle_sigterm(signum, frame):
            """Handle SIGTERM for graceful shutdown."""
            logger.info("Received SIGTERM, initiating graceful shutdown...")

            # Mark that we're terminating and interrupt existing requests IMMEDIATELY
            app.state.terminating = True
            TERMINATION_EVENT.set()

            # Clean up the distributed supervisor to ensure child processes are terminated.
            # This is important because SIGTERM is not propagated to child processes
            # automatically. This runs synchronously and may take 1-2 seconds, but existing
            # requests are already interrupted.
            global DISTRIBUTED_SUPERVISOR
            if DISTRIBUTED_SUPERVISOR:
                logger.info("Cleaning up distributed supervisor and child processes...")
                try:
                    DISTRIBUTED_SUPERVISOR.cleanup()
                except Exception as e:
                    logger.error(f"Error cleaning up distributed supervisor: {e}")

            # Call the original handler if it exists and isn't the default
            if original_sigterm_handler and original_sigterm_handler not in (
                signal.SIG_DFL,
                signal.SIG_IGN,
            ):
                original_sigterm_handler(signum, frame)

        # Register the SIGTERM handler
        signal.signal(signal.SIGTERM, handle_sigterm)
    app.state.terminating = False

    # Startup
    ttl = get_inactivity_ttl_annotation()
    if ttl and KT_OTEL_ENABLED:
        app.state.heartbeat_manager = HeartbeatManager(ttl_seconds=ttl)
        if app.state.heartbeat_manager:
            await app.state.heartbeat_manager.start()
            logger.debug(f"Heartbeat manager started with TTL={ttl}s")
    elif ttl:
        logger.warning("TTL annotation found, but OTEL is not enabled, heartbeat disabled")
    else:
        logger.debug("No TTL annotation found, heartbeat disabled")

    try:
        if os.getenv("KT_CALLABLE_TYPE") == "app":
            run_image_setup()
        else:
            load_callable()

        logger.info("Kubetorch Server started.")
        request_id_ctx_var.set("-")  # Reset request_id after the launch sequence
        yield

    except Exception:
        # We don't want to raise errors like ImportError during startup, as that would crash
        # the server and the user wouldn't be able to see the error in the logs to debug it
        # (e.g. quickly add dependencies or reorganize imports). Instead, we log the error
        # (with a stack trace) and continue, so it will be surfaced to the user when they
        # call the service.
        # However, if this service is frozen, it should just fail, because the user isn't
        # debugging the service and there is no way for dependencies to be added at runtime.
        logger.error(traceback.format_exc())
        request_id_ctx_var.set("-")
        yield

    finally:
        # Shutdown
        manager = getattr(app.state, "heartbeat_manager", None)
        if manager:
            await manager.stop()
            logger.info("Heartbeat manager stopped")

        # Clean up during normal shutdown so we don't leave any hanging processes, which can
        # cause pods to hang indefinitely. Skip if already cleaned up by the SIGTERM handler.
        if DISTRIBUTED_SUPERVISOR and not getattr(app.state, "terminating", False):
            DISTRIBUTED_SUPERVISOR.cleanup()

        # Clear any remaining debugging sessions
        clear_debugging_sessions()


# Add the filter to uvicorn's access logger
logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter())
root_logger = logging.getLogger()
root_logger.addFilter(RequestContextFilter())
for handler in root_logger.handlers:
    handler.addFilter(RequestContextFilter())
print_logger.addFilter(RequestContextFilter())

app = FastAPI(lifespan=lifespan)
app.add_middleware(TerminationCheckMiddleware)  # Check termination first
app.add_middleware(RequestIDMiddleware)

# Configure the FastAPI app for metrics first.
# setup_otel_metrics returns None for meter_provider if OTEL is not enabled.
app, meter_provider = setup_otel_metrics(app) if KT_OTEL_ENABLED else (app, None)

# Instrument metrics
if meter_provider is not None:
    try:
        from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor

        logger.info("Instrumenting FastAPI app for metrics only")
        FastAPIInstrumentor.instrument_app(
            app,
            meter_provider=meter_provider,
            excluded_urls="/,/metrics,/health",
        )
    except ImportError:
        logger.info("OpenTelemetry instrumentation not installed, skipping metrics instrumentation")

# Add the reverse-proxy route for an internal FastAPI app
if os.getenv("KT_CALLABLE_TYPE") == "app" and os.getenv("KT_APP_PORT"):
    logger.debug("Adding route for path /http")
    app.add_route(
        "/http/{path:path}",
        _http_reverse_proxy,
        ["GET", "POST", "PUT", "DELETE", "PATCH"],
    )


#####################################
########## Error Handling ###########
#####################################
class ErrorResponse(BaseModel):
    error_type: str
    message: str
    traceback: str
    pod_name: str
    state: Optional[dict] = None  # Optional serialized exception state


# Factor out the exception packaging so we can use it in the handler below and also inside
# distributed subprocesses
def package_exception(exc: Exception):
    import asyncio
    import concurrent.futures

    error_type = exc.__class__.__name__
    trace = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))

    # Check if the exception has a status_code attribute (e.g. PodTerminatedError)
    if hasattr(exc, "status_code"):
        status_code = exc.status_code
    elif isinstance(exc, (RequestValidationError, TypeError, AssertionError)):
        status_code = 422
    elif isinstance(exc, (ValueError, UnicodeError, json.JSONDecodeError)):
        status_code = 400
    elif isinstance(exc, (KeyError, FileNotFoundError)):
        status_code = 404
    elif isinstance(exc, PermissionError):
        status_code = 403
    elif isinstance(exc, (StarletteHTTPException, HTTPException)):
        status_code = exc.status_code
    elif isinstance(exc, (MemoryError, OSError)):
        status_code = 500
    elif isinstance(exc, NotImplementedError):
        status_code = 501
    elif isinstance(exc, (asyncio.TimeoutError, concurrent.futures.TimeoutError)):
        status_code = 504
    else:
        status_code = 500

    # Try to serialize the exception state if it has __getstate__
    state = None
    if hasattr(exc, "__getstate__"):
        try:
            state = exc.__getstate__()
        except Exception as e:
            logger.debug(f"Could not serialize exception state for {error_type}: {e}")

    error_response = ErrorResponse(
        error_type=error_type,
        message=str(exc),
        traceback=trace,
        pod_name=os.getenv("POD_NAME", "unknown"),
        state=state,
    )

    return JSONResponse(status_code=status_code, content=error_response.model_dump())
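
# Editor's illustration: the mapping above means a KeyError raised inside a user
# function comes back to the client as HTTP 404 with an ErrorResponse JSON body:
#
#     {"error_type": "KeyError", "message": "'missing'", "traceback": "...",
#      "pod_name": "<pod>", "state": null}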
+ @app.exception_handler(Exception)
1330
+ async def generic_exception_handler(request: Request, exc: Exception):
1331
+ return package_exception(exc)
1332
+
+
+ @app.post("/_reload_image", response_class=JSONResponse)
+ def _reload_image(
+     request: Request,
+     deployed_as_of: Optional[str] = Header(None, alias="X-Deployed-As-Of"),
+ ):
+     """
+     Endpoint to reload the image and metadata configuration.
+
+     Used to reload the image in cases where the callable is not invoked directly,
+     e.g. kt.app and Ray workers.
+     """
+     global _LAST_DEPLOYED
+     deployed_time = (
+         datetime.fromisoformat(deployed_as_of).timestamp()
+         if deployed_as_of
+         else datetime.now(timezone.utc).timestamp()
+     )
+     run_image_setup(deployed_time)
+     _LAST_DEPLOYED = deployed_time
+     return JSONResponse(
+         status_code=200,
+         content={"message": "Image and metadata reloaded successfully."},
+     )
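A minimal sketch of invoking the reload endpoint, assuming the server runs at localhost:32300 (a placeholder):

    import requests
    from datetime import datetime, timezone

    resp = requests.post(
        "http://localhost:32300/_reload_image",
        headers={"X-Deployed-As-Of": datetime.now(timezone.utc).isoformat()},
    )
    print(resp.json())  # {"message": "Image and metadata reloaded successfully."}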
+
+
+ @app.post("/{cls_or_fn_name}", response_class=JSONResponse)
+ @app.post("/{cls_or_fn_name}/{method_name}", response_class=JSONResponse)
+ async def run_callable(
+     request: Request,
+     cls_or_fn_name: str,
+     method_name: Optional[str] = None,
+     distributed_subcall: bool = False,
+     debug_port: Optional[int] = None,
+     params: Optional[Union[Dict, str]] = Body(default=None),
+     deployed_as_of: Optional[str] = Header(None, alias="X-Deployed-As-Of"),
+     serialization: str = Header("json", alias="X-Serialization"),
+ ):
+     if cls_or_fn_name != os.environ["KT_CLS_OR_FN_NAME"]:
+         raise HTTPException(
+             status_code=404,
+             detail=f"Callable '{cls_or_fn_name}' not found in metadata configuration. Found '{os.environ['KT_CLS_OR_FN_NAME']}' instead",
+         )
+
+     # NOTE: The distributed replica processes (e.g. PyTorchProcess:run) rely on this running here even though
+     # they will reconstruct the callable themselves, because they skip image reloading as a performance optimization.
+     # Run load_callable in an executor since it may do file I/O and other blocking operations.
+     callable_obj = await run_in_executor_with_context(None, load_callable, deployed_as_of)
+
+     # If this is a distributed call (and not a subcall from a different distributed replica)
+     # and the distribution type requires a special call method (e.g. SIMD), use the
+     # distributed supervisor to handle the call
+     if DISTRIBUTED_SUPERVISOR and DISTRIBUTED_SUPERVISOR.intercept_call():
+         # Run the blocking distributed call in an executor to avoid blocking the event loop
+         result = await run_in_executor_with_context(
+             None,
+             DISTRIBUTED_SUPERVISOR.call_distributed,
+             request,
+             cls_or_fn_name,
+             method_name,
+             params,
+             distributed_subcall,
+             debug_port,
+             deployed_as_of,
+         )
+         clear_debugging_sessions()
+         return result
+
+     # If this is not a distributed call, or the distribution type does not require special handling,
+     # run the callable directly
+     result = await run_callable_internal(
+         callable_obj=callable_obj,
+         cls_or_fn_name=cls_or_fn_name,
+         method_name=method_name,
+         params=params,
+         serialization=serialization,
+         debug_port=debug_port,
+     )
+     return result
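The JSON calling convention is a request body of the form {"args": [...], "kwargs": {...}}. A minimal client sketch, assuming a deployed function named preprocess and a server at localhost:32300 (both placeholders):

    import requests

    resp = requests.post(
        "http://localhost:32300/preprocess",
        json={"args": [[1, 2, 3]], "kwargs": {"normalize": True}},
        headers={"X-Serialization": "json"},
    )
    print(resp.json())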
+
+
+ async def run_callable_internal(
+     callable_obj: Callable,
+     cls_or_fn_name: str,
+     method_name: Optional[str] = None,
+     params: Optional[Union[Dict, str]] = None,
+     serialization: str = "json",
+     debug_port: Optional[int] = None,
+ ):
+     # Check if serialization is allowed
+     allowed_serialization = os.getenv("KT_ALLOWED_SERIALIZATION", DEFAULT_ALLOWED_SERIALIZATION).split(",")
+     if serialization not in allowed_serialization:
+         raise HTTPException(
+             status_code=400,
+             detail=f"Serialization format '{serialization}' not allowed. Allowed formats: {allowed_serialization}",
+         )
+
+     # Process the call
+     args = []
+     kwargs = {}
+     if params:
+         if serialization == "pickle":
+             # Handle pickle serialization - extract data from the dictionary wrapper
+             if isinstance(params, dict) and "data" in params:
+                 encoded_data = params.pop("data")
+                 pickled_data = base64.b64decode(encoded_data.encode("utf-8"))
+                 param_args = pickle.loads(pickled_data)
+                 # data is unpickled in the format {"args": args, "kwargs": kwargs}
+                 params.update(param_args)
+             elif isinstance(params, str):
+                 # Fallback for a direct base64 string
+                 pickled_data = base64.b64decode(params.encode("utf-8"))
+                 params = pickle.loads(pickled_data)
+
+         # Extract args and kwargs (params is a dict at this point for both JSON and pickle)
+         args = params.get("args", [])
+         kwargs = params.get("kwargs", {})
+
+     if method_name:
+         if not hasattr(callable_obj, method_name):
+             raise HTTPException(
+                 status_code=404,
+                 detail=f"Method '{method_name}' not found in class '{cls_or_fn_name}'",
+             )
+         user_method = getattr(callable_obj, method_name)
+     else:
+         user_method = callable_obj
+
+     import inspect
+
+     # Check if the user method is async
+     is_async_method = inspect.iscoroutinefunction(user_method)
+
+     callable_name = f"{cls_or_fn_name}.{method_name}" if method_name else cls_or_fn_name
+     if debug_port:
+         logger.info(f"Debugging remote callable {callable_name} on port {debug_port}")
+         deep_breakpoint(debug_port)
+         # If using the debugger, step in here ("s") to enter your function/class method.
+         if is_async_method:
+             result = await user_method(*args, **kwargs)
+         else:
+             # Run the sync method in a thread pool to avoid blocking;
+             # use a lambda to pass both args and kwargs
+             result = await run_in_executor_with_context(None, lambda: user_method(*args, **kwargs))
+     else:
+         logger.debug(f"Calling remote callable {callable_name}")
+         if is_async_method:
+             result = await user_method(*args, **kwargs)
+         else:
+             # Run the sync method in a thread pool to avoid blocking;
+             # use a lambda to pass both args and kwargs
+             result = await run_in_executor_with_context(None, lambda: user_method(*args, **kwargs))
+
+     # Handle the case where a sync method returns an awaitable (e.g. from an async framework).
+     # This is less common but can happen with some async libraries.
+     if isinstance(result, Awaitable):
+         result = await result
+
+     # Serialize the response based on the requested format
+     if serialization == "pickle":
+         try:
+             pickled_result = pickle.dumps(result)
+             encoded_result = base64.b64encode(pickled_result).decode("utf-8")
+             result = {"data": encoded_result}
+         except Exception as e:
+             logger.error(f"Failed to pickle result: {str(e)}")
+             raise SerializationError(f"Result could not be serialized with pickle: {str(e)}")
+     else:
+         # Default JSON serialization - verify the result is JSON-serializable
+         try:
+             json.dumps(result)
+         except (TypeError, ValueError) as e:
+             logger.error(f"Result is not JSON serializable: {str(e)}")
+             raise SerializationError(f"Result could not be serialized to JSON: {str(e)}")
+
+     clear_debugging_sessions()
+
+     return result
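For pickle serialization, both the request params and the result travel as base64-encoded pickles wrapped in a {"data": ...} dict. A sketch of the matching client round trip; the URL and callable name are placeholders, and pickle should only be exchanged with a trusted server:

    import base64
    import pickle

    import requests

    payload = base64.b64encode(pickle.dumps({"args": [1, 2], "kwargs": {"scale": 0.5}})).decode("utf-8")
    resp = requests.post(
        "http://localhost:32300/preprocess",
        json={"data": payload},
        headers={"X-Serialization": "pickle"},
    )
    result = pickle.loads(base64.b64decode(resp.json()["data"]))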
+
+
+ def run_callable_internal_sync(
+     callable_obj: Callable,
+     cls_or_fn_name: str,
+     method_name: Optional[str] = None,
+     params: Optional[Union[Dict, str]] = None,
+     serialization: str = "json",
+     debug_port: Optional[int] = None,
+ ):
+     """Synchronous wrapper for run_callable_internal, used by distributed subprocesses."""
+     import asyncio
+     import inspect
+
+     # Check if serialization is allowed
+     allowed_serialization = os.getenv("KT_ALLOWED_SERIALIZATION", DEFAULT_ALLOWED_SERIALIZATION).split(",")
+     if serialization not in allowed_serialization:
+         raise HTTPException(
+             status_code=400,
+             detail=f"Serialization format '{serialization}' not allowed. Allowed formats: {allowed_serialization}",
+         )
+
+     # Process the call
+     args = []
+     kwargs = {}
+     if params:
+         if serialization == "pickle":
+             # Handle pickle serialization - extract data from the dictionary wrapper
+             if isinstance(params, dict) and "data" in params:
+                 encoded_data = params.pop("data")
+                 pickled_data = base64.b64decode(encoded_data.encode("utf-8"))
+                 param_args = pickle.loads(pickled_data)
+                 # data is unpickled in the format {"args": args, "kwargs": kwargs}
+                 params.update(param_args)
+             elif isinstance(params, str):
+                 # Fallback for a direct base64 string
+                 pickled_data = base64.b64decode(params.encode("utf-8"))
+                 params = pickle.loads(pickled_data)
+
+         # Extract args and kwargs (params is a dict at this point for both JSON and pickle)
+         args = params.get("args", [])
+         kwargs = params.get("kwargs", {})
+
+     if method_name:
+         if not hasattr(callable_obj, method_name):
+             raise HTTPException(
+                 status_code=404,
+                 detail=f"Method '{method_name}' not found in class '{cls_or_fn_name}'",
+             )
+         user_method = getattr(callable_obj, method_name)
+     else:
+         user_method = callable_obj
+
+     # Check if the user method is async
+     is_async_method = inspect.iscoroutinefunction(user_method)
+
+     callable_name = f"{cls_or_fn_name}.{method_name}" if method_name else cls_or_fn_name
+     if debug_port:
+         logger.info(f"Debugging remote callable {callable_name} on port {debug_port}")
+         deep_breakpoint(debug_port)
+         # If using the debugger, step in here ("s") to enter your function/class method.
+         if is_async_method:
+             # For async methods in a sync context, run them in a new event loop
+             result = asyncio.run(user_method(*args, **kwargs))
+         else:
+             result = user_method(*args, **kwargs)
+     else:
+         logger.debug(f"Calling remote callable {callable_name}")
+         if is_async_method:
+             # For async methods in a sync context, run them in a new event loop
+             result = asyncio.run(user_method(*args, **kwargs))
+         else:
+             result = user_method(*args, **kwargs)
+
+     # Handle the case where a sync method returns an awaitable
+     if isinstance(result, Awaitable):
+         result = asyncio.run(result)
+
+     # Serialize the response based on the requested format
+     if serialization == "pickle":
+         try:
+             pickled_result = pickle.dumps(result)
+             encoded_result = base64.b64encode(pickled_result).decode("utf-8")
+             result = {"data": encoded_result}
+         except Exception as e:
+             logger.error(f"Failed to pickle result: {str(e)}")
+             raise SerializationError(f"Result could not be serialized with pickle: {str(e)}")
+     else:
+         # Default JSON serialization - verify the result is JSON-serializable
+         try:
+             json.dumps(result)
+         except (TypeError, ValueError) as e:
+             logger.error(f"Result is not JSON serializable: {str(e)}")
+             raise SerializationError(f"Result could not be serialized to JSON: {str(e)}")
+
+     clear_debugging_sessions()
+
+     return result
+
+
+ @app.get("/health", include_in_schema=False)
+ @app.get("/", include_in_schema=False)
+ def health():
+     return {"status": "healthy"}
+
+
+ if __name__ == "__main__" and not is_running_in_container():
+     # NOTE: this only runs in local development; otherwise the uvicorn server is started in the pod template setup
+     import uvicorn
+     from dotenv import load_dotenv
+
+     load_dotenv()
+
+     logger.info("Starting HTTP server")
+     # Cast the port: environment variables are strings, but uvicorn expects an int
+     uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("KT_SERVER_PORT", 32300)))
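Once the server is running locally, a quick smoke test against the health route, assuming the default host and port above:

    import requests

    assert requests.get("http://localhost:32300/health").json() == {"status": "healthy"}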