kubetorch 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kubetorch might be problematic. Click here for more details.
- kubetorch/__init__.py +60 -0
- kubetorch/cli.py +1985 -0
- kubetorch/cli_utils.py +1025 -0
- kubetorch/config.py +453 -0
- kubetorch/constants.py +18 -0
- kubetorch/docs/Makefile +18 -0
- kubetorch/docs/__init__.py +0 -0
- kubetorch/docs/_ext/json_globaltoc.py +42 -0
- kubetorch/docs/api/cli.rst +10 -0
- kubetorch/docs/api/python/app.rst +21 -0
- kubetorch/docs/api/python/cls.rst +19 -0
- kubetorch/docs/api/python/compute.rst +25 -0
- kubetorch/docs/api/python/config.rst +11 -0
- kubetorch/docs/api/python/fn.rst +19 -0
- kubetorch/docs/api/python/image.rst +14 -0
- kubetorch/docs/api/python/secret.rst +18 -0
- kubetorch/docs/api/python/volumes.rst +13 -0
- kubetorch/docs/api/python.rst +101 -0
- kubetorch/docs/conf.py +69 -0
- kubetorch/docs/index.rst +20 -0
- kubetorch/docs/requirements.txt +5 -0
- kubetorch/globals.py +285 -0
- kubetorch/logger.py +59 -0
- kubetorch/resources/__init__.py +0 -0
- kubetorch/resources/callables/__init__.py +0 -0
- kubetorch/resources/callables/cls/__init__.py +0 -0
- kubetorch/resources/callables/cls/cls.py +157 -0
- kubetorch/resources/callables/fn/__init__.py +0 -0
- kubetorch/resources/callables/fn/fn.py +133 -0
- kubetorch/resources/callables/module.py +1416 -0
- kubetorch/resources/callables/utils.py +174 -0
- kubetorch/resources/compute/__init__.py +0 -0
- kubetorch/resources/compute/app.py +261 -0
- kubetorch/resources/compute/compute.py +2596 -0
- kubetorch/resources/compute/decorators.py +139 -0
- kubetorch/resources/compute/rbac.py +74 -0
- kubetorch/resources/compute/utils.py +1114 -0
- kubetorch/resources/compute/websocket.py +137 -0
- kubetorch/resources/images/__init__.py +1 -0
- kubetorch/resources/images/image.py +414 -0
- kubetorch/resources/images/images.py +74 -0
- kubetorch/resources/secrets/__init__.py +2 -0
- kubetorch/resources/secrets/kubernetes_secrets_client.py +412 -0
- kubetorch/resources/secrets/provider_secrets/__init__.py +0 -0
- kubetorch/resources/secrets/provider_secrets/anthropic_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/aws_secret.py +16 -0
- kubetorch/resources/secrets/provider_secrets/azure_secret.py +14 -0
- kubetorch/resources/secrets/provider_secrets/cohere_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/gcp_secret.py +16 -0
- kubetorch/resources/secrets/provider_secrets/github_secret.py +13 -0
- kubetorch/resources/secrets/provider_secrets/huggingface_secret.py +20 -0
- kubetorch/resources/secrets/provider_secrets/kubeconfig_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/lambda_secret.py +13 -0
- kubetorch/resources/secrets/provider_secrets/langchain_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/openai_secret.py +11 -0
- kubetorch/resources/secrets/provider_secrets/pinecone_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/providers.py +93 -0
- kubetorch/resources/secrets/provider_secrets/ssh_secret.py +12 -0
- kubetorch/resources/secrets/provider_secrets/wandb_secret.py +11 -0
- kubetorch/resources/secrets/secret.py +238 -0
- kubetorch/resources/secrets/secret_factory.py +70 -0
- kubetorch/resources/secrets/utils.py +209 -0
- kubetorch/resources/volumes/__init__.py +0 -0
- kubetorch/resources/volumes/volume.py +365 -0
- kubetorch/servers/__init__.py +0 -0
- kubetorch/servers/http/__init__.py +0 -0
- kubetorch/servers/http/distributed_utils.py +3223 -0
- kubetorch/servers/http/http_client.py +730 -0
- kubetorch/servers/http/http_server.py +1788 -0
- kubetorch/servers/http/server_metrics.py +278 -0
- kubetorch/servers/http/utils.py +728 -0
- kubetorch/serving/__init__.py +0 -0
- kubetorch/serving/autoscaling.py +173 -0
- kubetorch/serving/base_service_manager.py +363 -0
- kubetorch/serving/constants.py +83 -0
- kubetorch/serving/deployment_service_manager.py +478 -0
- kubetorch/serving/knative_service_manager.py +519 -0
- kubetorch/serving/raycluster_service_manager.py +582 -0
- kubetorch/serving/service_manager.py +18 -0
- kubetorch/serving/templates/deployment_template.yaml +17 -0
- kubetorch/serving/templates/knative_service_template.yaml +19 -0
- kubetorch/serving/templates/kt_setup_template.sh.j2 +81 -0
- kubetorch/serving/templates/pod_template.yaml +194 -0
- kubetorch/serving/templates/raycluster_service_template.yaml +42 -0
- kubetorch/serving/templates/raycluster_template.yaml +35 -0
- kubetorch/serving/templates/service_template.yaml +21 -0
- kubetorch/serving/templates/workerset_template.yaml +36 -0
- kubetorch/serving/utils.py +377 -0
- kubetorch/utils.py +284 -0
- kubetorch-0.2.0.dist-info/METADATA +121 -0
- kubetorch-0.2.0.dist-info/RECORD +93 -0
- kubetorch-0.2.0.dist-info/WHEEL +4 -0
- kubetorch-0.2.0.dist-info/entry_points.txt +5 -0
|
@@ -0,0 +1,1788 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import importlib
|
|
3
|
+
import importlib.util
|
|
4
|
+
import inspect
|
|
5
|
+
import json
|
|
6
|
+
import logging.config
|
|
7
|
+
import os
|
|
8
|
+
import pickle
|
|
9
|
+
import random
|
|
10
|
+
import subprocess
|
|
11
|
+
import sys
|
|
12
|
+
import threading
|
|
13
|
+
import time
|
|
14
|
+
import traceback
|
|
15
|
+
from contextlib import asynccontextmanager
|
|
16
|
+
from datetime import datetime, timezone
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Awaitable, Callable, Dict, Optional, Union
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import httpx
|
|
22
|
+
except:
|
|
23
|
+
pass
|
|
24
|
+
|
|
25
|
+
from fastapi import Body, FastAPI, Header, HTTPException, Request
|
|
26
|
+
|
|
27
|
+
from fastapi.exceptions import RequestValidationError
|
|
28
|
+
from fastapi.responses import JSONResponse
|
|
29
|
+
|
|
30
|
+
from pydantic import BaseModel
|
|
31
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
32
|
+
|
|
33
|
+
try:
|
|
34
|
+
from server_metrics import (
|
|
35
|
+
get_inactivity_ttl_annotation,
|
|
36
|
+
HeartbeatManager,
|
|
37
|
+
setup_otel_metrics,
|
|
38
|
+
)
|
|
39
|
+
from utils import (
|
|
40
|
+
clear_debugging_sessions,
|
|
41
|
+
deep_breakpoint,
|
|
42
|
+
DEFAULT_ALLOWED_SERIALIZATION,
|
|
43
|
+
ensure_structured_logging,
|
|
44
|
+
is_running_in_kubernetes,
|
|
45
|
+
LOG_CONFIG,
|
|
46
|
+
request_id_ctx_var,
|
|
47
|
+
RSYNC_PORT,
|
|
48
|
+
wait_for_app_start,
|
|
49
|
+
)
|
|
50
|
+
except ImportError:
|
|
51
|
+
from .server_metrics import (
|
|
52
|
+
get_inactivity_ttl_annotation,
|
|
53
|
+
HeartbeatManager,
|
|
54
|
+
setup_otel_metrics,
|
|
55
|
+
)
|
|
56
|
+
from .utils import (
|
|
57
|
+
clear_debugging_sessions,
|
|
58
|
+
deep_breakpoint,
|
|
59
|
+
DEFAULT_ALLOWED_SERIALIZATION,
|
|
60
|
+
ensure_structured_logging,
|
|
61
|
+
is_running_in_kubernetes,
|
|
62
|
+
LOG_CONFIG,
|
|
63
|
+
request_id_ctx_var,
|
|
64
|
+
RSYNC_PORT,
|
|
65
|
+
wait_for_app_start,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
from starlette.background import BackgroundTask
|
|
69
|
+
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
70
|
+
from starlette.responses import StreamingResponse
|
|
71
|
+
|
|
72
|
+
# Apply the base logging configuration shipped in utils.LOG_CONFIG.
logging.config.dictConfig(LOG_CONFIG)

# Set up our structured JSON logging
ensure_structured_logging()

# Create the print logger AFTER ensure_structured_logging so it inherits handlers
print_logger = logging.getLogger("print_redirect")

logger = logging.getLogger(__name__)
# Set log level based on environment variable
# Don't default the log_level: when KT_LOG_LEVEL is unset the level from LOG_CONFIG is kept.
kt_log_level = os.getenv("KT_LOG_LEVEL")
if kt_log_level:
    kt_log_level = kt_log_level.upper()
    # Unknown level names fall back to INFO instead of raising.
    logger.setLevel(getattr(logging, kt_log_level, logging.INFO))

# Loaded user callables keyed by KT_CLS_OR_FN_NAME (read/written by load_callable).
_CACHED_CALLABLES = {}
# Timestamp of the deployment currently loaded; compared against deployed_as_of in
# should_reload(). A falsy value means nothing has been deployed into this process yet.
_LAST_DEPLOYED = 0
# Stripped image.dockerfile lines applied by the last cached_image_setup() run.
_CACHED_IMAGE = []
# Distributed supervisor instance, if one was loaded; its cleanup() is called before
# stale dependencies are evicted in cached_image_setup().
DISTRIBUTED_SUPERVISOR = None
# subprocess.Popen handle of the app's CMD process when KT_CALLABLE_TYPE == "app".
APP_PROCESS = None
_CALLABLE_LOAD_LOCK = threading.Lock()  # Lock for thread-safe callable loading
LOKI_HOST = os.environ.get("LOKI_HOST", "loki-gateway.kubetorch.svc.cluster.local")
LOKI_PORT = int(os.environ.get("LOKI_PORT", 80))  # Default Loki port
# Opt-in: only the (case-insensitive) string "true" enables OTel metrics.
KT_OTEL_ENABLED = os.environ.get("KT_OTEL_ENABLED", "False").lower() == "true"
# Opt-out: tracing stays on unless the variable is (case-insensitively) "false".
KT_TRACING_ENABLED = (
    os.environ.get("KT_TRACING_ENABLED", "").lower() != "false"
)  # Defaults to True

# Global termination event that can be checked by running requests
TERMINATION_EVENT = threading.Event()
# (The reverse-proxy HTTP client for app callables is created in the Proxy Helpers section below.)

# Set the python breakpoint to kt.deep_breakpoint. PYTHONBREAKPOINT is consulted when
# breakpoint() is called, so this affects this process and child processes inheriting the env.
os.environ["PYTHONBREAKPOINT"] = "kubetorch.deep_breakpoint"

# Default request id for log lines emitted from the main context (outside any request).
request_id_ctx_var.set(os.getenv("KT_LAUNCH_ID", "-"))
|
|
109
|
+
|
|
110
|
+
#####################################
|
|
111
|
+
######### Instrument Traces #########
|
|
112
|
+
#####################################
|
|
113
|
+
# Tracing is enabled unless KT_TRACING_ENABLED == "false"; it is also silently turned off
# when the OpenTelemetry packages are not importable.
instrument_traces = KT_TRACING_ENABLED
if instrument_traces:
    try:
        from opentelemetry import trace
        from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
            OTLPSpanExporter,
        )
        # NOTE(review): FastAPIInstrumentor is not used in this section; presumably it is
        # applied to the FastAPI app further down the file — confirm.
        from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
        from opentelemetry.instrumentation.logging import LoggingInstrumentor
        from opentelemetry.instrumentation.requests import RequestsInstrumentor
        from opentelemetry.sdk.resources import Resource
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.sdk.trace.export import BatchSpanProcessor
    except ImportError:
        instrument_traces = False

if instrument_traces:
    logger.info("Configuring OTLP exporter to instrument traces")
    # Identify spans by service name and pod name (both read from the environment; either
    # may be None if the variable is not set).
    trace.set_tracer_provider(
        TracerProvider(
            resource=Resource.create(
                {
                    "service.name": os.environ.get("OTEL_SERVICE_NAME"),
                    "service.instance.id": os.environ.get("POD_NAME"),
                }
            )
        )
    )
    # Batch-export spans over OTLP/gRPC without TLS (insecure=True) to the configured endpoint.
    span_processor = BatchSpanProcessor(
        OTLPSpanExporter(
            endpoint=os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"),
            insecure=True,
        )
    )
    trace.get_tracer_provider().add_span_processor(span_processor)
    # Trace outgoing `requests` calls and add trace context to log records.
    RequestsInstrumentor().instrument()
    LoggingInstrumentor().instrument()
|
|
150
|
+
|
|
151
|
+
#####################################
|
|
152
|
+
########### Proxy Helpers ###########
|
|
153
|
+
#####################################
|
|
154
|
+
# For "app" callables that expose a port, create a shared async HTTP client that the /http/*
# reverse proxy uses to forward requests to the app on localhost. Otherwise no proxy exists.
if os.getenv("KT_CALLABLE_TYPE") == "app" and os.getenv("KT_APP_PORT"):
    # httpx is imported optionally (failures swallowed) at the top of the module. Import it
    # again here so a missing dependency surfaces as a clear ImportError rather than a
    # NameError on `httpx`; when it is already installed this is a no-op rebinding.
    import httpx

    port = os.getenv("KT_APP_PORT")
    logger.info(f"Creating /http reverse proxy to: http://localhost:{port}/")
    # timeout=None: proxied app requests may legitimately run for a long time.
    proxy_client = httpx.AsyncClient(base_url=f"http://localhost:{port}/", timeout=None)
else:
    proxy_client = None
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
async def _http_reverse_proxy(request: Request):
    """Reverse-proxy a ``/http/{path}`` request to the user app through ``proxy_client``.

    The method, headers, query string and body are forwarded unchanged. The upstream
    response is streamed back with its status code and headers. The upstream response is
    closed once streaming finishes.

    Args:
        request: Incoming request; ``request.path_params["path"]`` holds everything
            after the ``/http/`` prefix.

    Returns:
        A ``StreamingResponse`` relaying the upstream response body.
    """
    endpoint_path = request.path_params["path"]

    # Relative URL, resolved against proxy_client's base_url (http://localhost:<port>/)
    url = httpx.URL(path=f"/{endpoint_path}", query=request.url.query.encode("utf-8"))

    rp_req = proxy_client.build_request(
        request.method, url, headers=request.headers.raw, content=await request.body()
    )
    rp_resp = await proxy_client.send(rp_req, stream=True)

    response = StreamingResponse(
        rp_resp.aiter_raw(),
        status_code=rp_resp.status_code,
        background=BackgroundTask(rp_resp.aclose),
    )
    # Copy headers from the raw list rather than passing rp_resp.headers as a mapping:
    # Starlette would call .items() on it, and httpx comma-joins repeated headers there,
    # which corrupts multiple Set-Cookie values. ASGI expects lowercase header names.
    response.raw_headers = [(key.lower(), value) for key, value in rp_resp.headers.raw]
    return response
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
#####################################
|
|
189
|
+
########### Cache Helpers ###########
|
|
190
|
+
#####################################
|
|
191
|
+
def clear_cache():
    """Empty the in-process cache of loaded callables (``_CACHED_CALLABLES``)."""
    logger.debug("Clearing callables cache.")
    # Mutating in place (no rebinding) means no `global` declaration is required.
    _CACHED_CALLABLES.clear()
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _image_cache_mismatch(lines, cached_lines):
    """Locate the first dockerfile line that must be (re-)executed.

    Args:
        lines: Stripped lines of the current ``image.dockerfile``.
        cached_lines: Stripped lines applied by the previous setup run.

    Returns:
        ``(index, cmd_mismatch)``.

        ``index`` is the first position whose line differs from the cache, contains
        ``# override``, or is a ``CMD`` instruction (``CMD`` is always re-run). If every
        compared line matches, ``index`` is the length of the shorter list. It therefore
        equals ``len(lines)`` exactly when there is nothing new to run.

        ``cmd_mismatch`` is True only when the scan stopped on a ``CMD`` line.
    """
    for i, (new_line, cached_line) in enumerate(zip(lines, cached_lines)):
        is_cmd = new_line.startswith("CMD")
        if is_cmd or new_line != cached_line or "# override" in new_line:
            return i, is_cmd
    return min(len(lines), len(cached_lines)), False


def _pip_freeze_lines():
    """Return ``pip freeze`` output split into lines, or ``None`` if pip could not be run."""
    try:
        res = subprocess.run(
            ["pip", "freeze"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing `pip` executable; treat it like a failed freeze.
        logger.error(f"Failed to run pip freeze: {e}")
        return None
    return res.stdout.splitlines()


def _verify_image_copy_line(line):
    """Check that the destination of a ``COPY <source> <dest>`` line exists.

    COPY instructions are essentially no-ops since rsync_file_updates() already placed
    the files in their correct locations. This only resolves and logs the absolute
    destination path.

    Raises:
        ValueError: if the line does not have exactly two arguments.
        FileNotFoundError: if the resolved destination does not exist.
    """
    _, source, dest = line.split()

    if dest.startswith("/"):
        dest_path = Path(dest)
    elif dest.startswith("~/"):
        # Tilde prefix is stripped and treated as relative to cwd (not expanded to $HOME)
        dest_path = Path.cwd() / dest[2:]
    else:
        # Relative to working directory (including explicit basenames)
        dest_path = Path.cwd() / dest

    if not dest_path.exists():
        raise FileNotFoundError(
            f"COPY {source} {dest} failed: destination {dest_path.absolute()} does not exist. "
            "This likely means the rsync operation failed to sync the files correctly."
        )
    logger.info(f"Copied {source} to {dest_path.absolute()}")


def _apply_image_env_line(line):
    """Apply an ``ENV KEY [VALUE]`` line to this process's ``os.environ``.

    ``$VAR`` / ``${VAR}`` references in the value are expanded with
    ``os.path.expandvars``; unknown variables are left as-is, and there is no
    default-value syntax. A line with only a key sets the variable to "". Kubetorch
    callable-metadata keys are never overwritten. Setting ``KT_LOG_LEVEL`` also
    updates this module's logger level.

    Raises:
        ValueError: if the line has no key.
    """
    global kt_log_level

    # maxsplit=2 keeps spaces inside the value intact and allows an empty value
    line_vals = line.split(" ", 2)
    if len(line_vals) < 2:  # ENV line must have at least key
        raise ValueError("ENV line cannot be empty")
    key = line_vals[1]
    val = line_vals[2] if len(line_vals) == 3 else ""
    expanded_val = os.path.expandvars(val)

    if key in (
        "KT_FILE_PATH",
        "KT_MODULE_NAME",
        "KT_CLS_OR_FN_NAME",
        "KT_INIT_ARGS",
        "KT_CALLABLE_TYPE",
        "KT_DISTRIBUTED_CONFIG",
    ):
        return

    logger.info(f"Setting env var {key}")
    os.environ[key] = expanded_val
    if key == "KT_LOG_LEVEL":
        kt_log_level = expanded_val.upper()
        try:
            logger.setLevel(kt_log_level)
        except ValueError:
            # An unknown level name would otherwise abort the whole image setup; fall back
            # to INFO like the module-level KT_LOG_LEVEL handling does.
            logger.warning(f"Unknown log level {kt_log_level!r}, falling back to INFO")
            logger.setLevel(logging.INFO)
        logger.info(f"Updated log level to {kt_log_level}")


def _run_image_setup_command(command, is_app_cmd):
    """Run one RUN/CMD shell command, re-logging its output through ``logger``.

    Behavior by case:

    - App CMD with ``KT_CALLABLE_TYPE == "app"``: a still-running previous
      ``APP_PROCESS`` is killed and the new process is stored in ``APP_PROCESS``.
    - App CMD with ``KT_APP_PORT`` set: wait until ``wait_for_app_start`` reports
      the app ready, instead of waiting for the process to exit.
    - Commands ending in ``&``: after 0.5 s, only check that they did not fail
      immediately.
    - Any other command: wait for it to exit.

    A non-zero exit of a non-app command is logged together with its stderr; it is
    not raised.
    """
    global APP_PROCESS

    if is_app_cmd:
        logger.info(f"Running app command: {command}")
    else:
        logger.info(f"Running image setup step: {command}")

    env = os.environ.copy()
    env["PYTHONUNBUFFERED"] = "1"

    manages_app_process = is_app_cmd and os.getenv("KT_CALLABLE_TYPE") == "app"
    if manages_app_process and APP_PROCESS and APP_PROCESS.poll() is None:
        APP_PROCESS.kill()

    process = subprocess.Popen(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
        bufsize=1,
        env=env,
    )
    if manages_app_process:
        APP_PROCESS = process

    # stderr of non-app commands is buffered so it can be re-logged as errors on failure
    stderr_lines = []
    stderr_lock = threading.Lock()
    error_prefix = "ERROR: "

    def stream_output(pipe, log_func, request_id, collect_stderr=False):
        # Route each output line through the structured logger (tagged with the caller's
        # request id) instead of letting it flow straight to this process's stdout/stderr.
        request_id_ctx_var.set(request_id)
        for out_line in iter(pipe.readline, ""):
            if out_line:
                stripped_line = out_line.rstrip()
                log_func(stripped_line)
                if collect_stderr:
                    # Remove a literal "ERROR: " prefix. str.lstrip would strip any of the
                    # characters E, R, O, ':' and ' ', mangling the message text.
                    if stripped_line.startswith(error_prefix):
                        stripped_line = stripped_line[len(error_prefix) :]
                    with stderr_lock:
                        stderr_lines.append(stripped_line)
        pipe.close()

    current_request_id = request_id_ctx_var.get("-")
    stderr_log_func = logger.error if is_app_cmd else logger.debug
    stdout_thread = threading.Thread(
        target=stream_output,
        args=(process.stdout, logger.info, current_request_id),
        daemon=True,
    )
    stderr_thread = threading.Thread(
        target=stream_output,
        args=(process.stderr, stderr_log_func, current_request_id, not is_app_cmd),
        daemon=True,
    )
    stdout_thread.start()
    stderr_thread.start()

    if is_app_cmd and os.getenv("KT_APP_PORT"):
        # wait for internal app to be healthy/ready if run port is provided
        port = os.getenv("KT_APP_PORT")
        try:
            logger.debug(f"Waiting for internal app on port {port} to start:")
            wait_for_app_start(
                port=port,
                health_check=os.getenv("KT_APP_HEALTHCHECK"),
                process=process,
            )
            logger.info(f"App on port {port} is ready.")
        except Exception as e:
            logger.error(f"Caught exception waiting for app to start: {e}")
        # The app keeps running; there is no exit code to inspect.
        return

    if command.rstrip().endswith("&"):
        # Background command: give it a moment and only check for an immediate failure
        time.sleep(0.5)
        poll_result = process.poll()
        if poll_result is not None and poll_result != 0:
            stdout_thread.join(timeout=1)
            stderr_thread.join(timeout=1)
            return_code = poll_result
        else:
            logger.info(f"Background process started successfully (PID: {process.pid})")
            return_code = 0
    else:
        return_code = process.wait()
        # Wait for streaming threads to finish
        stdout_thread.join()
        stderr_thread.join()

    if return_code != 0 and not is_app_cmd:
        with stderr_lock:
            if stderr_lines:
                logger.error(f"Failed to run command '{command}' with stderr:")
                for stderr_line in stderr_lines:
                    logger.error(stderr_line)


def _reload_changed_dependencies(start_deps):
    """Evict already-imported packages whose pinned version changed during setup.

    ``start_deps`` holds the ``pip freeze`` lines captured before setup and is compared
    with a fresh ``pip freeze``. Only ``name==version`` lines that existed before setup
    are considered, because newly installed packages cannot be stale in this process.

    Packages are matched by their pip name against ``sys.modules``. Distributions whose
    import name differs from their pip name are therefore not detected.
    """
    end_deps = _pip_freeze_lines()
    if end_deps is None:
        return

    end_dep_set = set(end_deps)
    changed_deps = [
        dep_line.split("==")[0]
        for dep_line in start_deps
        if "==" in dep_line and dep_line not in end_dep_set
    ]
    # Only reload deps which are already imported
    imported_changed_deps = [dep for dep in changed_deps if dep in sys.modules]
    if not imported_changed_deps:
        return

    logger.debug(f"New dependencies found: {imported_changed_deps}, forcing reload")

    # Don't clear the callable cache here - let load_callable_from_env handle it to preserve __kt_cached_state__
    if DISTRIBUTED_SUPERVISOR:
        DISTRIBUTED_SUPERVISOR.cleanup()

    # Remove changed packages (and their submodules) so the next import loads fresh code.
    # Iterate over a snapshot: an import on another thread would otherwise raise
    # "dictionary changed size during iteration".
    modules_to_remove = [
        module_name
        for module_name in list(sys.modules)
        if any(
            module_name == dep or module_name.startswith(dep + ".")
            for dep in imported_changed_deps
        )
    ]
    for module_name in modules_to_remove:
        try:
            del sys.modules[module_name]
            logger.debug(f"Removed module {module_name} from sys.modules")
        except KeyError:
            pass


def cached_image_setup():
    """Replay ``image.dockerfile`` incrementally, starting at the first changed line.

    Reads ``<kt_directory>/image.dockerfile`` and compares it with ``_CACHED_IMAGE``,
    the lines applied last time. Processing resumes at the first line that differs,
    contains ``# override``, or is a ``CMD``. Every line from that point on is handled:

    - ``RUN`` / ``CMD``: run as a shell command, with ``$KT_PIP_INSTALL_CMD``
      substituted.
    - ``COPY``: only verified, since files were already synced by
      ``rsync_file_updates()``.
    - ``ENV``: applied to ``os.environ``.
    - ``FROM`` and ``#`` comment lines: ignored.

    Afterwards, already-imported packages whose pinned version changed are evicted
    from ``sys.modules``.

    Raises:
        FileNotFoundError: the dockerfile or a COPY destination is missing.
        ValueError: an unrecognized instruction, an empty ENV line, or a
            ``$KT_PIP_INSTALL_CMD`` reference that cannot be resolved.
    """
    global _CACHED_IMAGE

    logger.debug("Starting cached image setup.")

    dockerfile_path = kt_directory() / "image.dockerfile"
    with open(dockerfile_path, "r") as file:
        lines = [line.strip() for line in file]

    # find first line where image differs from cache and update cache
    cache_mismatch_index, cmd_mismatch = _image_cache_mismatch(lines, _CACHED_IMAGE)
    _CACHED_IMAGE = lines

    if cache_mismatch_index == len(lines):
        return  # nothing new to run

    if cache_mismatch_index == len(lines) - 1 and cmd_mismatch:
        # Only the trailing CMD line needs re-running
        logger.debug("Skipping image setup steps, no changes detected.")
    else:
        logger.info("Running image setup.")

    # Snapshot installed packages to detect version changes made by the setup steps
    # (already-imported packages that change must be reloaded).
    start_deps = _pip_freeze_lines()

    kt_pip_cmd = None
    for line in lines[cache_mismatch_index:]:
        if line.startswith("#"):
            continue  # ignore comments (lines are already stripped)

        command = ""
        if line.startswith(("RUN", "CMD")):
            # "RUN " and "CMD " have the same length, so this strips either keyword
            command = line[len("RUN ") :]
            if command.startswith("$KT_PIP_INSTALL_CMD"):
                kt_pip_cmd = kt_pip_cmd or _get_kt_pip_install_cmd()
                if not kt_pip_cmd:
                    raise ValueError(
                        "Cannot expand $KT_PIP_INSTALL_CMD: KT_PIP_INSTALL_CMD is not set "
                        "and no kt_pip_install_cmd file was found in the kt directory."
                    )
                command = command.replace("$KT_PIP_INSTALL_CMD", kt_pip_cmd)
        elif line.startswith("COPY"):
            _verify_image_copy_line(line)
        elif line.startswith("ENV"):
            _apply_image_env_line(line)
        elif line.startswith("FROM"):
            continue
        elif line:
            raise ValueError(f"Unrecognized image setup instruction {line}")

        if command:
            _run_image_setup_command(command, is_app_cmd=line.startswith("CMD"))

    # Check if any dependencies changed and if so reload them inside the server process
    if start_deps:
        _reload_changed_dependencies(start_deps)
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def run_image_setup(deployed_time: Optional[float] = None):
    """Sync file updates and replay the image setup steps on this pod.

    This is a no-op when ``KT_FREEZE`` is ``"True"`` or when not running in Kubernetes.

    Args:
        deployed_time: Timestamp of the deployment that triggered this setup. On
            redeploys (``_LAST_DEPLOYED`` is set), the function waits up to about 5 s
            after this time for ``image.dockerfile`` to be rsynced with an mtime at or
            after it.

    Raises:
        FileNotFoundError: if ``<kt_directory>/image.dockerfile`` is missing after syncing.
    """
    # .get(): an unset KT_FREEZE means "not frozen" instead of raising KeyError
    if os.environ.get("KT_FREEZE") == "True" or not is_running_in_kubernetes():
        return

    rsync_file_updates()

    dockerfile_path = kt_directory() / "image.dockerfile"
    if not dockerfile_path.exists():
        raise FileNotFoundError(
            f"No image and metadata configuration found in path: {str(dockerfile_path)}"
        )

    # May need to give the dockerfile time to rsync over, so wait until the dockerfile
    # timestamp is later than when we started the deployment (recorded in .to and passed
    # here as deployed_time). Only wait if _LAST_DEPLOYED is non-zero: on the first deploy
    # the image is written before the server starts. Without a deployed_time there is
    # nothing to compare against, so don't wait (avoids a TypeError on None).
    if _LAST_DEPLOYED and deployed_time is not None:
        while (
            dockerfile_path.stat().st_mtime < deployed_time
            and datetime.now(timezone.utc).timestamp() - deployed_time < 5
        ):
            time.sleep(0.1)

    cached_image_setup()

    if os.getenv("KT_CALLABLE_TYPE") != "app":
        logger.debug("Completed cached image setup.")
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
#####################################
|
|
536
|
+
######## Generic Helpers ############
|
|
537
|
+
#####################################
|
|
538
|
+
class SerializationError(Exception):
    """Project-specific ``Exception`` subclass used to signal serialization errors.

    It adds no behavior beyond ``Exception``.
    """
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def kt_directory():
    """Return the kubetorch metadata directory as a ``Path``.

    Uses ``KT_DIRECTORY`` (with ``~`` expanded) whenever the variable is present in the
    environment, even if it is empty. Otherwise returns ``<cwd>/.kt``.
    """
    configured = os.environ.get("KT_DIRECTORY")
    if configured is None:
        return Path.cwd() / ".kt"
    return Path(configured).expanduser()
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
def _get_kt_pip_install_cmd() -> Optional[str]:
    """Resolve the value that ``$KT_PIP_INSTALL_CMD`` should expand to.

    Returns the ``KT_PIP_INSTALL_CMD`` environment variable when it is set and non-empty.
    Otherwise returns the stripped contents of ``<kt_directory>/kt_pip_install_cmd``, or
    ``None`` when that file does not exist.
    """
    env_value = os.getenv("KT_PIP_INSTALL_CMD")
    if env_value:
        return env_value

    # Fall back to the command persisted on disk in the kt directory.
    try:
        cmd_path = kt_directory() / "kt_pip_install_cmd"
        with open(cmd_path, "r") as cmd_file:
            file_value = cmd_file.read()
    except FileNotFoundError:
        return None
    return file_value.strip()
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
def is_running_in_container():
    """Return True if the ``/.dockerenv`` marker file is present.

    That file exists in Docker containers.
    """
    dockerenv_marker = Path("/") / ".dockerenv"
    return dockerenv_marker.exists()
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
async def run_in_executor_with_context(executor, func, *args):
    """Run ``func(*args)`` in ``executor`` inside a copy of the caller's contextvars context.

    Executor threads do not inherit the calling task's context. This wrapper snapshots
    the whole context, including ``request_id_ctx_var`` and any tracing/span context,
    and runs ``func`` inside that snapshot. ``asyncio.to_thread`` uses the same approach.
    Changes that ``func`` makes to context variables stay inside the snapshot, so they
    cannot leak into later tasks that reuse the same pooled thread.

    Args:
        executor: A ``concurrent.futures.Executor``, or ``None`` for the loop's default.
        func: Callable to run.
        *args: Positional arguments for ``func``.

    Returns:
        The value returned by ``func``.
    """
    import asyncio
    import contextvars

    ctx = contextvars.copy_context()
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, ctx.run, func, *args)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def should_reload(deployed_as_of: Optional[str] = None) -> bool:
    """Decide whether the callable must be reloaded for the given deployment timestamp.

    Args:
        deployed_as_of: ISO-8601 timestamp of the client's deployment. ``None``, ``"null"``
            and ``"None"`` all mean "no timestamp".

    Returns:
        False when no timestamp is given. Otherwise, True when the timestamp is newer than
        ``_LAST_DEPLOYED``. Also True, after logging an error, when the timestamp cannot
        be parsed.
    """
    no_timestamp_markers = (None, "null", "None")
    if deployed_as_of in no_timestamp_markers:
        return False

    try:
        requested_ts = datetime.fromisoformat(deployed_as_of).timestamp()
    except ValueError as e:
        logger.error(f"Invalid deployed_as_of format: {deployed_as_of}. Error: {e}")
        return True
    return requested_ts > _LAST_DEPLOYED
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def load_callable(
    deployed_as_of: Optional[str] = None,
    distributed_subprocess: bool = False,
    reload_cleanup_fn: Optional[Callable] = None,
):
    """Return the callable named by ``KT_CLS_OR_FN_NAME``, loading or reloading it if needed.

    A cached callable is returned without locking unless ``should_reload(deployed_as_of)``
    requests a reload. Otherwise the cache is re-checked under ``_CALLABLE_LOAD_LOCK``,
    since another thread may have loaded it in the meantime, before delegating to
    ``_load_callable_internal``.

    Args:
        deployed_as_of: ISO-8601 deployment timestamp from the client, if any.
        distributed_subprocess: Passed through to ``_load_callable_internal``.
        reload_cleanup_fn: Optional hook passed through to ``_load_callable_internal``.

    Returns:
        The cached or freshly loaded callable.
    """
    callable_name = os.environ["KT_CLS_OR_FN_NAME"]

    # Fast path: lock-free lookup. Use `is not None` rather than truthiness, because a
    # cached object that evaluates as falsy (e.g. defines __len__ returning 0) would
    # otherwise be reloaded on every call.
    callable_obj = _CACHED_CALLABLES.get(callable_name)
    if callable_obj is not None and not should_reload(deployed_as_of):
        # If the callable is cached and doesn't need reload, return it immediately
        logger.debug("Returning cached callable.")
        return callable_obj

    # Slow path: need to load or reload - use lock for thread safety
    with _CALLABLE_LOAD_LOCK:
        # Double-check within lock (another thread might have loaded it)
        callable_obj = _CACHED_CALLABLES.get(callable_name)
        if callable_obj is not None and not should_reload(deployed_as_of):
            logger.debug("Returning cached callable (found after acquiring lock).")
            return callable_obj
        # Proceed with loading/reloading
        return _load_callable_internal(
            deployed_as_of, distributed_subprocess, reload_cleanup_fn, callable_obj
        )
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def _load_callable_internal(
|
|
639
|
+
deployed_as_of: Optional[str] = None,
|
|
640
|
+
distributed_subprocess: bool = False,
|
|
641
|
+
reload_cleanup_fn: [Callable, None] = None,
|
|
642
|
+
callable_obj=None,
|
|
643
|
+
):
|
|
644
|
+
"""Internal callable loading logic - should be called within lock for thread safety."""
|
|
645
|
+
global _LAST_DEPLOYED
|
|
646
|
+
|
|
647
|
+
callable_name = os.environ["KT_CLS_OR_FN_NAME"]
|
|
648
|
+
|
|
649
|
+
if not callable_obj:
|
|
650
|
+
logger.debug("Callable not found in cache, loading from environment.")
|
|
651
|
+
else:
|
|
652
|
+
logger.debug(
|
|
653
|
+
f"Callable found in cache, but reloading because deployed_as_of {deployed_as_of} is newer than last deployed time {_LAST_DEPLOYED}"
|
|
654
|
+
)
|
|
655
|
+
|
|
656
|
+
# If not in cache or we have a more recent deployment timestamp, update metadata and reload
|
|
657
|
+
if reload_cleanup_fn and _LAST_DEPLOYED:
|
|
658
|
+
# If a reload cleanup function is provided and we've already deployed at least once, call it before
|
|
659
|
+
# reloading the callable
|
|
660
|
+
reload_cleanup_fn()
|
|
661
|
+
|
|
662
|
+
deployed_time = (
|
|
663
|
+
datetime.fromisoformat(deployed_as_of).timestamp()
|
|
664
|
+
if deployed_as_of
|
|
665
|
+
else datetime.now(timezone.utc).timestamp()
|
|
666
|
+
)
|
|
667
|
+
if not distributed_subprocess:
|
|
668
|
+
# We don't reload the image in distributed subprocess/es, as we already did it in the
|
|
669
|
+
# main process and we don't want to do it multiple times (in each subprocess).
|
|
670
|
+
if _LAST_DEPLOYED:
|
|
671
|
+
logger.info("Patching image and code updates and reloading callable.")
|
|
672
|
+
else:
|
|
673
|
+
logger.info("Setting up image and loading callable.")
|
|
674
|
+
run_image_setup(deployed_time)
|
|
675
|
+
|
|
676
|
+
distributed_config = os.environ["KT_DISTRIBUTED_CONFIG"]
|
|
677
|
+
if distributed_config not in ["null", "None"] and not distributed_subprocess:
|
|
678
|
+
logger.debug(f"Loading distributed supervisor: {distributed_config}")
|
|
679
|
+
callable_obj = load_distributed_supervisor(deployed_as_of=deployed_as_of)
|
|
680
|
+
logger.debug("Distributed supervisor loaded successfully.")
|
|
681
|
+
else:
|
|
682
|
+
logger.debug(f"Loading callable from environment: {callable_name}")
|
|
683
|
+
callable_obj = load_callable_from_env()
|
|
684
|
+
logger.debug("Callable loaded successfully.")
|
|
685
|
+
|
|
686
|
+
_LAST_DEPLOYED = deployed_time
|
|
687
|
+
_CACHED_CALLABLES[callable_name] = callable_obj
|
|
688
|
+
|
|
689
|
+
return callable_obj
|
|
690
|
+
|
|
691
|
+
|
|
692
|
+
def load_distributed_supervisor(deployed_as_of: Optional[str] = None):
    """Return the process-wide distributed supervisor, (re)building it when needed.

    A supervisor is built with ``distributed_supervisor_factory`` when none exists yet, or when the hash of
    the ``KT_DISTRIBUTED_CONFIG`` env var differs from the one recorded on the current supervisor. Any
    previous supervisor is cleaned up before the new one is created. ``KT_FILE_PATH`` is put on ``sys.path``
    first so the user's code can be imported.

    Args:
        deployed_as_of: Forwarded to the new supervisor's ``setup`` call.

    Returns:
        The current ``DISTRIBUTED_SUPERVISOR``.

    Raises:
        Exception: Whatever ``setup`` raised; the global is reset to ``None`` first.
    """
    global DISTRIBUTED_SUPERVISOR

    file_path = os.environ["KT_FILE_PATH"]
    if file_path not in sys.path:
        sys.path.insert(0, file_path)

    raw_config = os.environ["KT_DISTRIBUTED_CONFIG"]

    # Only the main process of a distributed call gets here; it never loads the user callable itself, it
    # keeps a supervisor, reusing the existing one while the config is unchanged.
    config_hash = hash(str(raw_config))
    if (
        DISTRIBUTED_SUPERVISOR is not None
        and config_hash == DISTRIBUTED_SUPERVISOR.config_hash
    ):
        return DISTRIBUTED_SUPERVISOR

    from .distributed_utils import distributed_supervisor_factory

    logger.info(f"Loading distributed supervisor with config: {raw_config}")
    parsed_config = json.loads(raw_config)

    # Tear down processes owned by an outdated supervisor before replacing it.
    if DISTRIBUTED_SUPERVISOR:
        DISTRIBUTED_SUPERVISOR.cleanup()

    DISTRIBUTED_SUPERVISOR = distributed_supervisor_factory(**parsed_config)
    DISTRIBUTED_SUPERVISOR.config_hash = config_hash

    try:
        DISTRIBUTED_SUPERVISOR.setup(deployed_as_of=deployed_as_of)
    except Exception as e:
        # Forget the half-initialised supervisor rather than keeping a broken one around (its cleanup()
        # in lifespan could fail as well), then surface the original error.
        logger.error(
            f"Failed to set up distributed supervisor with config {parsed_config}: {e}"
        )
        DISTRIBUTED_SUPERVISOR = None
        raise e

    return DISTRIBUTED_SUPERVISOR
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
def patch_sys_path():
    """Make ``KT_FILE_PATH`` importable in this process and in its child processes.

    The resolved absolute form of ``KT_FILE_PATH`` is prepended to ``sys.path`` (this interpreter) and to
    the ``PYTHONPATH`` environment variable (inherited by subprocesses, e.g. distributed workers). Both
    membership checks use that same resolved path, so calling this repeatedly (e.g. on every reload) does
    not accumulate duplicate entries.
    """
    abs_path = str(Path(os.environ["KT_FILE_PATH"]).expanduser().resolve())

    # Compare against the path we actually insert: a relative or "~" KT_FILE_PATH would never match the
    # resolved entry, which previously caused a new duplicate on every call.
    if abs_path not in sys.path:
        sys.path.insert(0, abs_path)
        logger.debug(f"Added {abs_path} to sys.path")

    # Needed for distributed subprocesses to find the callable's module.
    existing_path = os.environ.get("PYTHONPATH", "")
    # Match whole entries, not substrings: "/app" must not count as present just because "/app2" is.
    if abs_path not in existing_path.split(os.pathsep):
        os.environ["PYTHONPATH"] = (
            f"{abs_path}{os.pathsep}{existing_path}" if existing_path else abs_path
        )
        logger.debug(f"Set PYTHONPATH to {os.environ['PYTHONPATH']}")
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def load_callable_from_env():
    """Load and cache callable objects from env, preserving state if __kt_cached_state__ is available.

    Reads ``KT_CLS_OR_FN_NAME`` and ``KT_MODULE_NAME`` (plus ``KT_FILE_PATH`` for the fallback import and
    ``KT_INIT_ARGS`` for classes).

    Steps:
      1. If a callable is already cached and defines ``__kt_cached_state__``, call it to capture a state
         dict (non-dict results and failures are logged and ignored). Then evict the old callable from
         ``_CACHED_CALLABLES`` and run ``gc.collect()``.
      2. Patch ``sys.path`` and import the module: ``importlib.reload`` if it is already in
         ``sys.modules`` (clearing debugging sessions first), else ``importlib.import_module``. On
         ``ImportError``/``ValueError``, fall back to ``import_from_file``.
      3. Look up the attribute named ``KT_CLS_OR_FN_NAME`` and strip one ``__wrapped__`` layer (e.g. left
         by a kt deploy decorator).
      4. Classes are instantiated with the JSON ``KT_INIT_ARGS`` plus ``kt_cached_state`` when state was
         captured; functions are returned as-is.

    Returns:
        The function, or an instance of the class.

    Raises:
        ImportError | ValueError: The original package-import error when the file import also fails.
        HTTPException: 404 when the imported module has no attribute named ``KT_CLS_OR_FN_NAME``,
            regardless of which import strategy succeeded.
        ValueError: Cached state exists but the class ``__init__`` lacks a ``kt_cached_state`` parameter.
    """
    cls_or_fn_name = os.environ["KT_CLS_OR_FN_NAME"]
    module_name = os.environ["KT_MODULE_NAME"]

    # Check if we have an existing cached callable and extract state if available
    cached_state = None
    existing_callable = _CACHED_CALLABLES.get(cls_or_fn_name, None)

    if existing_callable and hasattr(existing_callable, "__kt_cached_state__"):
        try:
            logger.info(
                f"Extracting cached state from {cls_or_fn_name} via __kt_cached_state__"
            )
            cached_state = existing_callable.__kt_cached_state__()
            if cached_state is not None and not isinstance(cached_state, dict):
                logger.warning(
                    f"__kt_cached_state__ returned non-dict type: {type(cached_state)}. Ignoring cached state."
                )
                cached_state = None
        except Exception as e:
            # This could happen if modules were removed from sys.modules during image setup
            # and the callable's __kt_cached_state__ method depends on them
            logger.warning(
                f"Failed to extract cached state from {cls_or_fn_name} (possibly due to module reloading): {e}. "
                "Proceeding without cached state."
            )
            cached_state = None

    # Now that we have the state, clean up the old callable to free memory
    if existing_callable:
        logger.debug(f"Deleting existing callable: {cls_or_fn_name}")
        _CACHED_CALLABLES.pop(cls_or_fn_name, None)
        del existing_callable
        # Garbage collect to ensure everything cleaned up (especially GPU memory)
        import gc

        gc.collect()

    patch_sys_path()

    # If we're inside a distributed subprocess or the main process of a non-distributed call,
    # we load and instantiate the callable. Try a regular package import first.
    try:
        if module_name in sys.modules:
            # We make this logs to info because some imports are slow and we want the user to know that it's not our fault
            # and not hanging
            logger.info(f"Reimporting module {module_name}")
            # Clear any existing debugging sessions when reloading modules
            clear_debugging_sessions()
            module = importlib.reload(sys.modules[module_name])
        else:
            logger.debug(f"Importing module {module_name}")
            module = importlib.import_module(module_name)
        logger.debug(f"Module {module_name} loaded")
    except (ImportError, ValueError) as original_error:
        # Fall back to file-based import if package import fails
        try:
            module = import_from_file(os.environ["KT_FILE_PATH"], module_name)
        except (ImportError, ValueError):
            # Raise the original error if file import also fails, because the errors which are raised here are
            # more opaque and less useful than the original ImportError or ValueError.
            raise original_error

    # Ensure our structured logging is in place after user module import
    # (in case the user's module configured its own logging)
    ensure_structured_logging()

    # Resolve the attribute separately from the import, so that a missing callable maps to a 404 for BOTH
    # import strategies, and an AttributeError raised by the module's own import code is not misreported as
    # "callable not found".
    try:
        callable_obj = getattr(module, cls_or_fn_name)
    except AttributeError as e:
        raise HTTPException(
            status_code=404,
            detail=f"Callable '{cls_or_fn_name}' not found in module '{module_name}'",
        ) from e
    logger.debug(f"Callable {cls_or_fn_name} loaded")

    # Unwrap to remove any kt deploy decorators (e.g. @kt.compute)
    if hasattr(callable_obj, "__wrapped__"):
        callable_obj = callable_obj.__wrapped__

    if isinstance(callable_obj, type):
        # Prepare init arguments
        init_kwargs = {}

        # Add user-provided init_args
        if os.environ["KT_INIT_ARGS"] not in ["null", "None"]:
            init_kwargs = json.loads(os.environ["KT_INIT_ARGS"])
            logger.info(f"Setting init_args {init_kwargs}")

        # Add cached state if available
        # Allow user to manually set "kt_cached_state" to override/disable cache
        if cached_state is not None and "kt_cached_state" not in init_kwargs:
            # Check if the class's __init__ accepts kt_cached_state parameter
            sig = inspect.signature(callable_obj.__init__)
            if "kt_cached_state" in sig.parameters:
                logger.info(f"Passing cached state to {cls_or_fn_name}.__init__")
                init_kwargs["kt_cached_state"] = cached_state
            else:
                raise ValueError(
                    f"Class {cls_or_fn_name} has __kt_cached_state__ method but __init__ does not accept "
                    f"'kt_cached_state' parameter. Please add 'kt_cached_state=None' to __init__ signature."
                )

        # Instantiate with combined arguments
        if init_kwargs:
            callable_obj = callable_obj(**init_kwargs)
        else:
            callable_obj = callable_obj()

    return callable_obj
|
|
862
|
+
|
|
863
|
+
|
|
864
|
+
def import_from_file(file_path: str, module_name: str):
    """Import a module from file path.

    Before executing the file, the ancestor directory that acts as the package root for a dotted
    ``module_name`` (one directory up per dot, clamped to the filesystem root) is prepended to
    ``sys.path``, so the module's absolute imports of sibling modules can resolve.

    The new module is registered in ``sys.modules`` before it executes, as in the importlib "importing a
    source file directly" recipe. Code that looks the module up by name while or after it runs (e.g.
    ``dataclasses`` with string annotations, pickling, a later ``importlib.reload``) then works. If
    execution fails, the previous ``sys.modules`` entry (if any) is restored.

    Args:
        file_path: Path to the module's source file (may be relative or start with ``~``).
        module_name: Fully qualified module name to register the module under.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for ``file_path``.
        Exception: Anything raised while executing the module's code.
    """
    module_parts = module_name.split(".")
    depth = max(0, len(module_parts) - 1)

    # Convert file_path to absolute path if it's not already (note, .resolve will append the current working directory
    # if file_path is relative)
    abs_path = Path(file_path).expanduser().resolve()
    # Ensure depth doesn't exceed available parent directories
    max_available_depth = len(abs_path.parents) - 1

    if max_available_depth < 0:
        # File has no parent directories, use the file's directory itself
        parent_path = str(abs_path.parent)
    else:
        # Clamp depth to available range to avoid IndexError
        depth = min(depth, max_available_depth)
        parent_path = str(abs_path.parents[depth])

    if parent_path not in sys.path:
        sys.path.insert(0, parent_path)

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Could not load spec for module {module_name} from {file_path}"
        )

    module = importlib.util.module_from_spec(spec)

    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module registered under this name.
        if previous is not None:
            sys.modules[module_name] = previous
        else:
            sys.modules.pop(module_name, None)
        raise
    return module
|
|
895
|
+
|
|
896
|
+
|
|
897
|
+
#####################################
|
|
898
|
+
########## Rsync Helpers ############
|
|
899
|
+
#####################################
|
|
900
|
+
def generate_rsync_command(subdir: str = ".", exclude_absolute: bool = True):
    """Build (but do not run) the rsync command that pulls this service's code from the rsync pod.

    The source is ``rsync://kubetorch-rsync.<POD_NAMESPACE>.svc.cluster.local:<RSYNC_PORT>/data/
    <POD_NAMESPACE>/<KT_SERVICE_NAME>/``, with both names read from the environment.

    Args:
        subdir: Local destination directory (defaults to the current directory).
        exclude_absolute: When True, skip the ``__absolute__`` tree via ``--exclude='__absolute__*'``.

    Returns:
        The complete rsync command line as a single string.
    """
    namespace = os.getenv("POD_NAMESPACE")
    service_name = os.getenv("KT_SERVICE_NAME")

    flags = ["-av"]
    if exclude_absolute:
        flags.append("--exclude='__absolute__*'")

    logger.debug("Syncing code from rsync pod to local directory")
    source = (
        f"rsync://kubetorch-rsync.{namespace}.svc.cluster.local:{RSYNC_PORT}"
        f"/data/{namespace}/{service_name}/"
    )
    return f"rsync {' '.join(flags)} {source} {subdir}"
|
|
913
|
+
|
|
914
|
+
|
|
915
|
+
def rsync_file_updates():
    """Rsync files from the jump pod to the worker pod.

    Performs two rsync operations in parallel (two-worker thread pool):
    1. Regular files (excluding __absolute__*) to the working directory.
    2. Absolute path files (under __absolute__/) onto "/", so each lands at its original absolute path.
       A ``--list-only`` probe decides whether there is anything to sync.

    Every command, including the probe, is retried with exponential backoff plus jitter when stderr shows
    a known transient error (connection limit, DNS failure, refused/unroutable connection). A probe that
    keeps failing transiently is therefore reported as an error instead of being mistaken for "no
    absolute files".

    Raises:
        RuntimeError: If either transfer, or the probe (after exhausting retries on a transient error), fails.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    service_name = os.getenv("KT_SERVICE_NAME")
    namespace = os.getenv("POD_NAMESPACE")

    # Build base rsync URL
    rsync_base = f"rsync://kubetorch-rsync.{namespace}.svc.cluster.local:{RSYNC_PORT}/data/{namespace}/{service_name}/"

    max_retries = 5
    base_delay = 1  # seconds
    max_delay = 30  # seconds
    # stderr fragments that indicate a transient condition worth retrying
    retryable_errors = (
        "max connections",
        "Temporary failure in name resolution",
        "Name or service not known",
        "Connection refused",
        "No route to host",
    )

    def is_retryable(stderr):
        return any(error in stderr for error in retryable_errors)

    def run_with_retries(cmd, description):
        """Run ``cmd`` through the shell, retrying transient failures.

        Returns ``(completed_process, attempts_used)`` for the last attempt. Stops at the first success,
        the first non-retryable failure, or after ``max_retries`` attempts.
        """
        for attempt in range(1, max_retries + 1):
            resp = subprocess.run(cmd, shell=True, capture_output=True, text=True)
            if resp.returncode == 0 or not is_retryable(resp.stderr) or attempt == max_retries:
                return resp, attempt
            # Exponential backoff with jitter
            delay = min(base_delay * (2 ** (attempt - 1)) + random.uniform(0, 1), max_delay)
            logger.warning(
                f"Rsync {description} failed with retryable error: {resp.stderr.strip()}. "
                f"Retrying in {delay:.1f} seconds (attempt {attempt}/{max_retries})"
            )
            time.sleep(delay)
        raise RuntimeError(f"Rsync {description} was never attempted (max_retries={max_retries})")

    def run_rsync_with_retries(rsync_cmd, description):
        """Run an rsync transfer, raising RuntimeError if it ultimately fails."""
        resp, attempts = run_with_retries(rsync_cmd, description)
        if resp.returncode == 0:
            logger.debug(f"Successfully rsync'd {description}")
            return
        if attempts == max_retries:
            logger.error(
                f"Rsync {description} failed after {max_retries} attempts. Last error: {resp.stderr}"
            )
        raise RuntimeError(f"Rsync {description} failed with error: {resp.stderr}")

    def rsync_regular_files():
        """Rsync regular files (excluding __absolute__*) to working directory."""
        rsync_cmd_regular = f"rsync -avL --exclude='__absolute__*' {rsync_base} ."
        logger.debug(f"Rsyncing regular files with command: {rsync_cmd_regular}")
        run_rsync_with_retries(rsync_cmd_regular, "regular files")

    def rsync_absolute_files():
        """Rsync absolute path files to their absolute destinations."""
        # First, list the remote __absolute__ directory to see whether it exists and has content
        check_cmd = f"rsync --list-only {rsync_base}__absolute__/"
        check_resp, attempts = run_with_retries(check_cmd, "absolute path listing")

        if check_resp.returncode == 0 and check_resp.stdout.strip():
            # The trick is to sync from __absolute__/ to / which places files in their absolute paths
            rsync_cmd_absolute = f"rsync -avL {rsync_base}__absolute__/ /"
            logger.debug(
                f"Rsyncing absolute path files with command: {rsync_cmd_absolute}"
            )
            run_rsync_with_retries(rsync_cmd_absolute, "absolute path files")
        elif check_resp.returncode != 0 and is_retryable(check_resp.stderr):
            # A persistent transient failure (e.g. rsync "max connections") must not be read as "nothing to sync"
            raise RuntimeError(
                f"Rsync absolute path listing failed after {attempts} attempts. Last error: {check_resp.stderr}"
            )
        else:
            logger.debug("No absolute path files to sync")

    # Run both rsync operations in parallel
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = [
            executor.submit(rsync_regular_files),
            executor.submit(rsync_absolute_files),
        ]
        for future in as_completed(futures):
            try:
                future.result()  # Re-raises any exception from the worker
            except Exception:
                # Cancel whatever hasn't started yet, then propagate the failure
                for other in futures:
                    other.cancel()
                raise

    logger.debug("Completed rsync of all files")
|
|
1028
|
+
|
|
1029
|
+
|
|
1030
|
+
#####################################
|
|
1031
|
+
########### App setup ###############
|
|
1032
|
+
#####################################
|
|
1033
|
+
class HealthCheckFilter(logging.Filter):
    """Drop access-log records generated by health-check traffic.

    A record is suppressed when its ``args`` is a tuple of at least three items whose third item contains
    ``/health`` or is exactly ``/``. For uvicorn's access logger that item is presumably the request path
    (confirm against the uvicorn version in use). Every other record passes through.
    """

    def filter(self, record):
        args = record.args
        if not isinstance(args, tuple) or len(args) < 3:
            return True
        path = args[2]
        is_health_probe = "/health" in path or path == "/"
        return not is_health_probe
|
|
1040
|
+
|
|
1041
|
+
|
|
1042
|
+
class RequestContextFilter(logging.Filter):
    """Attach request, pod and (optionally) trace context to every log record; never drops a record.

    Always sets ``record.request_id`` (from ``request_id_ctx_var``, ``"-"`` if unset) and ``record.pod``
    (``POD_NAME`` env var, ``"unknown-pod"`` if unset). When ``instrument_traces`` is enabled it also sets
    ``record.trace_id`` / ``record.span_id`` from the current OpenTelemetry span for log/trace correlation,
    or ``"-"`` when there is no valid span.
    """

    def filter(self, record):
        record.request_id = request_id_ctx_var.get("-")
        record.pod = os.getenv("POD_NAME", "unknown-pod")

        if instrument_traces:
            from opentelemetry.trace import (
                format_span_id,
                format_trace_id,
                get_current_span,
            )

            current_span = get_current_span()
            span_context = current_span.get_span_context() if current_span else None
            if span_context is not None and span_context.is_valid:
                record.trace_id = format_trace_id(span_context.trace_id)
                # Span ids are 64-bit: format_span_id yields the 16-hex-char form used by tracing backends,
                # whereas format_trace_id would zero-pad it to 32 chars and break correlation.
                record.span_id = format_span_id(span_context.span_id)
            else:
                record.trace_id = "-"
                record.span_id = "-"

        return True
|
|
1064
|
+
|
|
1065
|
+
|
|
1066
|
+
class TerminationCheckMiddleware(BaseHTTPMiddleware):
    """Monitor for termination while request is running and return error if detected.

    ``/health``, ``/`` and ``/metrics`` pass straight through. Any other request runs as a background
    task. While it is pending, the middleware re-checks for termination (``TERMINATION_EVENT`` set, or a
    truthy ``app.state.terminating``) between waits of up to 0.5 s. If termination is seen first, the
    request task is cancelled and a packaged ``PodTerminatedError`` (status 503) is returned instead.
    """

    @staticmethod
    def _termination_requested(request: Request) -> bool:
        return TERMINATION_EVENT.is_set() or bool(
            getattr(request.app.state, "terminating", False)
        )

    @staticmethod
    def _terminated_response():
        from kubetorch import PodTerminatedError
        from kubetorch.servers.http.http_server import package_exception

        termination_event = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "reason": "Terminating",
            "message": "Pod received SIGTERM signal and is shutting down gracefully",
        }
        error = PodTerminatedError(
            pod_name=os.environ.get("POD_NAME", "unknown"),
            reason="SIGTERM",
            status_code=503,
            events=[termination_event],
        )
        return package_exception(error)

    async def dispatch(self, request: Request, call_next):
        if request.url.path in ("/health", "/", "/metrics"):
            return await call_next(request)

        import asyncio

        pending = asyncio.create_task(call_next(request))

        while not pending.done():
            if self._termination_requested(request):
                pending.cancel()
                return self._terminated_response()

            # shield() keeps the timeout from cancelling the request itself; on timeout we just re-check.
            try:
                return await asyncio.wait_for(asyncio.shield(pending), timeout=0.5)
            except asyncio.TimeoutError:
                continue

        # Request completed normally
        return await pending
|
|
1121
|
+
|
|
1122
|
+
|
|
1123
|
+
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Expose the ``X-Request-ID`` header through ``request_id_ctx_var`` while a request is handled.

    A missing header maps to ``"-"``. When tracing is enabled and a real request id is present, a short
    ``heartbeat.request`` span is started and flushed *before* the request is handled. If the pod crashes
    mid-request (e.g. OOM), there is still a recorded trace of the request having arrived. The context
    variable is reset after the downstream handler finishes, even if it raised.
    """

    @staticmethod
    def _emit_heartbeat_span(request: Request, request_id: str):
        attributes = {
            "request_id": request_id,
            "http.method": request.method,
            "http.url": str(request.url),
            "service.name": os.environ.get("OTEL_SERVICE_NAME"),
            "service.instance.id": os.environ.get("POD_NAME"),
        }
        tracer = trace.get_tracer("heartbeat")
        # Best effort: a tracing failure must never block the request itself.
        try:
            with tracer.start_as_current_span("heartbeat.request", attributes=attributes):
                provider = trace.get_tracer_provider()
                if isinstance(provider, TracerProvider):
                    provider.force_flush()
        except Exception as e:
            logger.warning(f"Heartbeat span flush failed: {e}")

    async def dispatch(self, request: Request, call_next):
        request_id = request.headers.get("X-Request-ID", "-")
        token = request_id_ctx_var.set(request_id)

        if instrument_traces and request_id != "-":
            self._emit_heartbeat_span(request, request_id)

        try:
            return await call_next(request)
        finally:
            # Restore the previous context value for this task
            request_id_ctx_var.reset(token)
|
|
1155
|
+
|
|
1156
|
+
|
|
1157
|
+
class TraceFlushMiddleware(BaseHTTPMiddleware):
    """Force an OpenTelemetry span flush once each HTTP request has produced its response.

    Flushing eagerly means trace data is not lost if the pod is killed shortly afterwards. A flush is
    only attempted when the global tracer provider is an SDK ``TracerProvider``.
    """

    @staticmethod
    def _flush_spans():
        provider = trace.get_tracer_provider()
        if isinstance(provider, TracerProvider):
            provider.force_flush()

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        self._flush_spans()
        return response
|
|
1166
|
+
|
|
1167
|
+
|
|
1168
|
+
class StreamToLogger:
    """File-like object that tees text to an underlying stream and to a logger.

    Every ``write`` is passed to ``original_stream`` (if any) and flushed immediately. The text is also
    buffered and sent to ``logger`` at ``log_level`` one complete (newline-terminated) line at a time.
    Empty lines are not logged, and a trailing partial line is held until more text or ``flush()``
    arrives. For the ``print_redirect`` logger, writes that originate inside the ``logging`` package only
    go to the original stream, so log output is not re-logged in an infinite loop.
    """

    def __init__(self, logger, log_level=logging.INFO, original_stream=None):
        self.logger = logger
        self.log_level = log_level
        self.original_stream = original_stream
        self.linebuf = ""

    def _is_from_logging(self):
        """Return True if any frame on the current call stack is in a module whose name starts with "logging"."""
        frame = sys._getframe()
        while frame:
            if frame.f_globals.get("__name__", "").startswith("logging"):
                return True
            frame = frame.f_back
        return False

    def write(self, buf):
        """Write ``buf`` to the original stream and log any lines it completes.

        Returns:
            ``len(buf)``, as ``io.TextIOBase.write`` does, so callers that use the return value work.
        """
        # Check if this is from logging system
        is_from_logging = self._is_from_logging()

        # Always write to original stream first
        if self.original_stream:
            self.original_stream.write(buf)
            self.original_stream.flush()

        # Skip logging if this is from the logging system to prevent infinite loops
        if self.logger.name == "print_redirect" and is_from_logging:
            return len(buf)

        # Everything before the last "\n" is a run of complete lines; the tail (possibly "") is the new
        # partial line. Carriage returns are kept as part of the line text.
        pieces = (self.linebuf + buf).split("\n")
        self.linebuf = pieces.pop()

        for line in pieces:
            if line:
                self.logger.log(self.log_level, line)
        return len(buf)

    def flush(self):
        """Flush the original stream and log any buffered partial line."""
        if self.original_stream:
            self.original_stream.flush()
        if self.linebuf != "":
            self.logger.log(self.log_level, self.linebuf)
            self.linebuf = ""

    def isatty(self):
        # Delegate to the original stream if it exists, else return False
        if self.original_stream and hasattr(self.original_stream, "isatty"):
            return self.original_stream.isatty()
        return False

    def fileno(self):
        if self.original_stream and hasattr(self.original_stream, "fileno"):
            return self.original_stream.fileno()
        raise OSError("Stream does not support fileno()")

    @property
    def encoding(self):
        # Return the encoding of the original stream if available, else UTF-8
        if self.original_stream and hasattr(self.original_stream, "encoding"):
            return self.original_stream.encoding
        return "utf-8"
|
|
1244
|
+
|
|
1245
|
+
|
|
1246
|
+
# Save original streams before redirection, so the StreamToLogger wrappers below can still echo
# everything to the real stdout/stderr.
_original_stdout = sys.stdout
_original_stderr = sys.stderr

# Redirect stdout and stderr to our logger while preserving original streams.
# Both wrappers log through print_logger: stdout lines at INFO, stderr lines at ERROR.
# NOTE(review): this runs at module import time, so it replaces sys.stdout/sys.stderr for the whole
# process, including any code imported afterwards.
sys.stdout = StreamToLogger(print_logger, logging.INFO, _original_stdout)
sys.stderr = StreamToLogger(print_logger, logging.ERROR, _original_stderr)
|
|
1253
|
+
|
|
1254
|
+
|
|
1255
|
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifecycle: startup, exactly one ``yield`` while serving, then shutdown.

    Startup:
      * In the main thread only, install a SIGTERM handler. It marks ``app.state.terminating``, sets
        ``TERMINATION_EVENT`` and cleans up ``DISTRIBUTED_SUPERVISOR`` (SIGTERM is not propagated to its
        child processes automatically). It then chains to any previously installed, non-default handler.
      * Start a ``HeartbeatManager`` when an inactivity TTL annotation exists and OTEL is enabled.
      * Run ``cached_image_setup()`` and, unless ``KT_CALLABLE_TYPE`` is ``"app"``, ``load_callable()``.
        Errors here are logged rather than raised, so the server stays up and the error surfaces when the
        service is called.

    Shutdown (after the ``yield``, including when the server exits with an error): flush OTEL spans, stop
    the heartbeat manager, clean up the distributed supervisor unless the SIGTERM handler already did, and
    clear debugging sessions.
    """
    import signal
    import threading

    # Only register signal handlers if we're in the main thread
    # This allows tests to run without signal handling
    if threading.current_thread() is threading.main_thread():
        # Save any existing SIGTERM handler
        original_sigterm_handler = signal.getsignal(signal.SIGTERM)

        def handle_sigterm(signum, frame):
            """Handle SIGTERM for graceful shutdown."""
            logger.info("Received SIGTERM, initiating graceful shutdown...")

            # Mark that we're terminating and interrupt existing requests IMMEDIATELY
            app.state.terminating = True
            TERMINATION_EVENT.set()

            # Clean up distributed supervisor to ensure child processes are terminated
            # This is important because SIGTERM is not propagated to child processes automatically
            # This runs synchronously and may take 1-2 seconds, but existing requests are already interrupted
            global DISTRIBUTED_SUPERVISOR
            if DISTRIBUTED_SUPERVISOR:
                logger.info("Cleaning up distributed supervisor and child processes...")
                try:
                    DISTRIBUTED_SUPERVISOR.cleanup()
                except Exception as e:
                    logger.error(f"Error cleaning up distributed supervisor: {e}")

            # Call the original handler if it exists and isn't the default
            if original_sigterm_handler and original_sigterm_handler not in (
                signal.SIG_DFL,
                signal.SIG_IGN,
            ):
                original_sigterm_handler(signum, frame)

        # Register SIGTERM handler
        signal.signal(signal.SIGTERM, handle_sigterm)
        app.state.terminating = False

    # Startup
    ttl = get_inactivity_ttl_annotation()
    if ttl and KT_OTEL_ENABLED is True:
        app.state.heartbeat_manager = HeartbeatManager(ttl_seconds=ttl)
        if app.state.heartbeat_manager:
            await app.state.heartbeat_manager.start()
            logger.debug(f"Heartbeat manager started with TTL={ttl}s")
    elif ttl:
        logger.warning(
            "TTL annotation found, but OTEL is not enabled, heartbeat disabled"
        )
    else:
        logger.debug("No TTL annotation found, heartbeat disabled")

    try:
        cached_image_setup()
        if not os.getenv("KT_CALLABLE_TYPE") == "app":
            load_callable()

        logger.info("Kubetorch Server started.")
    except Exception:
        # We don't want to raise errors like ImportError during startup, as it will cause the server to crash and the
        # user won't be able to see the error in the logs to debug (e.g. quickly add dependencies or reorganize
        # imports). Instead, we log it (and a stack trace) and continue, so it will be surfaced to the user when they
        # call the service.
        # NOTE(review): the intended "fail if the service is frozen" behaviour is not implemented here;
        # startup errors are always logged and swallowed.
        logger.error(traceback.format_exc())

    request_id_ctx_var.set("-")  # Reset request_id after launch sequence

    # The yield lives outside the startup try/except. An @asynccontextmanager generator must yield exactly
    # once: with the yield inside the try, an exception thrown in at shutdown would be caught by
    # `except Exception` and yield again, raising "generator didn't stop after athrow()" and hiding the
    # real error.
    try:
        yield
    finally:
        # Flush OpenTelemetry traces before shutdown
        if instrument_traces:
            from opentelemetry.sdk.trace import TracerProvider

            tracer_provider = trace.get_tracer_provider()
            if isinstance(tracer_provider, TracerProvider):
                logger.info("Forcing OpenTelemetry span flush before shutdown")
                tracer_provider.force_flush()

        # Shutdown
        manager = getattr(app.state, "heartbeat_manager", None)
        if manager:
            await manager.stop()
            logger.info("Heartbeat manager stopped")

        # Clean up during normal shutdown so we don't leave any hanging processes, which can cause pods to hang
        # indefinitely. Skip if already cleaned up by SIGTERM handler.
        if DISTRIBUTED_SUPERVISOR and not getattr(app.state, "terminating", False):
            DISTRIBUTED_SUPERVISOR.cleanup()

        # Clear any remaining debugging sessions
        clear_debugging_sessions()
|
|
1355
|
+
|
|
1356
|
+
|
|
1357
|
+
# Add the filter to uvicorn's access logger.
# HealthCheckFilter is defined elsewhere in this module; presumably it drops
# access-log lines for health-check requests (TODO confirm).
logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter())

# Attach a RequestContextFilter (presumably injects the current request id into
# log records — verify against its definition) in three places:
#   * the root logger itself — logger-level filters only see records logged
#     directly on that logger, not records propagated from child loggers;
#   * every handler currently attached to the root logger — handler filters see
#     all records that reach the handler, including propagated ones;
#   * print_logger.
# Handlers added to the root logger after this point do NOT get the filter.
root_logger = logging.getLogger()
root_logger.addFilter(RequestContextFilter())
for handler in root_logger.handlers:
    handler.addFilter(RequestContextFilter())
print_logger.addFilter(RequestContextFilter())
|
|
1364
|
+
|
|
1365
|
+
# The FastAPI application. `lifespan` (defined above) runs the startup sequence
# before the server yields and the shutdown/cleanup sequence in its `finally`.
app = FastAPI(lifespan=lifespan)
# NOTE(review): Starlette wraps middleware added later *around* middleware added
# earlier, so RequestIDMiddleware (added next) is the outer layer and runs
# before TerminationCheckMiddleware on each request. If the termination check
# must truly run first, it has to be added last — confirm the intended order.
app.add_middleware(TerminationCheckMiddleware)
app.add_middleware(RequestIDMiddleware)
|
|
1368
|
+
|
|
1369
|
+
# Configure the FastAPI app for metrics first.
# setup_otel_metrics(app) returns an (app, meter_provider) pair; when OTEL is not
# enabled the app is left unchanged and meter_provider is None.
app, meter_provider = (
    setup_otel_metrics(app) if KT_OTEL_ENABLED is True else (app, None)
)

# Now instrument for traces and metrics together.
# NOTE: OpenTelemetry splits `excluded_urls` on commas and treats each entry as a
# regular expression that is *searched* within the full request URL
# ("scheme://host[:port]/path"). Entries therefore match anywhere in the URL.
if instrument_traces:
    logger.info("Instrumenting FastAPI app for traces and metrics")
    FastAPIInstrumentor.instrument_app(
        app,
        meter_provider=meter_provider,
        excluded_urls="/metrics,/health",
    )
    logger.info("Adding TraceFlushMiddleware to flush traces")
    app.add_middleware(TraceFlushMiddleware)
elif meter_provider is not None:
    # Metrics-only path: trace instrumentation is disabled but a meter provider
    # was configured, so instrument the app for HTTP metrics only.
    try:
        from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor

        logger.info("Instrumenting FastAPI app for metrics only")
        FastAPIInstrumentor.instrument_app(
            app,
            meter_provider=meter_provider,
            # Exclude the root health route ("/"), /metrics and /health.
            # A bare "/" entry would match every URL (each contains "://") and
            # silently disable metrics for all requests, so the root route is
            # matched by anchoring an empty path right after the host.
            excluded_urls=r"://[^/]+/?$,/metrics,/health",
        )
    except ImportError:
        logger.info(
            "OpenTelemetry instrumentation not enabled, skipping metrics instrumentation"
        )
|
|
1400
|
+
|
|
1401
|
+
# For `app` callables with a configured KT_APP_PORT, expose every /http/<path>
# request through _http_reverse_proxy (presumably forwarding to the user's app
# server on that port — confirm against _http_reverse_proxy).
if os.getenv("KT_APP_PORT") and os.getenv("KT_CALLABLE_TYPE") == "app":
    logger.debug("Adding route for path /http")
    app.add_route(
        "/http/{path:path}",
        _http_reverse_proxy,
        methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    )
|
|
1409
|
+
|
|
1410
|
+
|
|
1411
|
+
#####################################
|
|
1412
|
+
########## Error Handling ###########
|
|
1413
|
+
#####################################
|
|
1414
|
+
class ErrorResponse(BaseModel):
    """JSON body returned when a request fails; built by ``package_exception``."""

    error_type: str  # exception class name, e.g. "ValueError"
    message: str  # str(exc)
    traceback: str  # fully formatted traceback text
    # Name of the pod that raised the error (from the POD_NAME env var). Optional
    # because POD_NAME is unset outside a pod (e.g. local development); a
    # required ``str`` would reject None and make the error handler itself fail.
    pod_name: Optional[str] = None
    state: Optional[dict] = None  # Optional serialized exception state
|
|
1420
|
+
|
|
1421
|
+
|
|
1422
|
+
# Factor out the exception packaging so we can use it in the handler below and also inside distributed subprocesses
|
|
1423
|
+
def package_exception(exc: Exception):
    """Package an exception as a JSON ``ErrorResponse``.

    The HTTP status code comes from ``exc.status_code`` when that attribute is
    an int; otherwise it is chosen from the exception type.

    Args:
        exc: The exception to package.

    Returns:
        A ``JSONResponse`` whose body is ``ErrorResponse.model_dump()``: the
        exception class name, ``str(exc)``, the formatted traceback, the
        ``POD_NAME`` env var and, when available and JSON-serializable, the
        dict returned by ``exc.__getstate__()`` (presumably so the client can
        rebuild the exception — confirm against the client side).
    """
    import asyncio

    # Import the submodule explicitly: a bare ``import concurrent`` does not
    # guarantee that ``concurrent.futures`` is loaded.
    import concurrent.futures

    error_type = exc.__class__.__name__
    # Named trace_text to avoid shadowing the module-level ``trace`` import.
    trace_text = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )

    explicit_status = getattr(exc, "status_code", None)
    if isinstance(explicit_status, int):
        # Exceptions that carry their own status (e.g. PodTerminatedError,
        # FastAPI/Starlette HTTPException). Non-int values such as None are
        # ignored because they cannot be used as a response status.
        status_code = explicit_status
    elif isinstance(exc, (RequestValidationError, TypeError, AssertionError)):
        status_code = 422
    elif isinstance(exc, (ValueError, UnicodeError, json.JSONDecodeError)):
        status_code = 400
    elif isinstance(exc, (KeyError, FileNotFoundError)):
        status_code = 404
    elif isinstance(exc, PermissionError):
        status_code = 403
    elif isinstance(exc, (asyncio.TimeoutError, concurrent.futures.TimeoutError)):
        # Must precede the OSError check: on Python 3.11+ both names alias the
        # builtin TimeoutError, which subclasses OSError, so checking OSError
        # first would report timeouts as 500.
        status_code = 504
    elif isinstance(exc, NotImplementedError):
        status_code = 501
    elif isinstance(exc, (MemoryError, OSError)):
        status_code = 500
    else:
        status_code = 500

    # Try to serialize exception state if it has __getstate__
    state = None
    if hasattr(exc, "__getstate__"):
        try:
            candidate = exc.__getstate__()
            if candidate is not None:
                if not isinstance(candidate, dict):
                    raise TypeError(
                        f"expected a dict, got {type(candidate).__name__}"
                    )
                # Starlette's JSONResponse renders with allow_nan=False; check up
                # front so unserializable state is dropped here instead of
                # crashing the error handler while rendering the response.
                json.dumps(candidate, allow_nan=False)
            state = candidate
        except Exception as e:
            logger.debug(f"Could not serialize exception state for {error_type}: {e}")

    error_response = ErrorResponse(
        error_type=error_type,
        message=str(exc),
        traceback=trace_text,
        pod_name=os.getenv("POD_NAME"),
        state=state,
    )

    return JSONResponse(status_code=status_code, content=error_response.model_dump())
|
|
1471
|
+
|
|
1472
|
+
|
|
1473
|
+
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception):
    """Turn any unhandled exception into a structured JSON error response.

    The request itself is not inspected; status-code mapping and body
    construction are delegated to ``package_exception``.
    """
    response = package_exception(exc)
    return response
|
|
1476
|
+
|
|
1477
|
+
|
|
1478
|
+
@app.post("/_reload_image", response_class=JSONResponse)
def _reload_image(
    request: Request,
    deployed_as_of: Optional[str] = Header(None, alias="X-Deployed-As-Of"),
):
    """
    Endpoint to reload the image and metadata configuration.

    Used when the callable is not invoked directly (e.g. kt.app and Ray
    workers). The optional ``X-Deployed-As-Of`` header carries an ISO-8601
    deployment time; without it the current UTC time is used.
    """
    global _LAST_DEPLOYED

    if deployed_as_of:
        # NOTE(review): a naive (offset-less) ISO string is interpreted in the
        # server's local time zone by datetime.timestamp().
        deployed_time = datetime.fromisoformat(deployed_as_of).timestamp()
    else:
        deployed_time = datetime.now(timezone.utc).timestamp()

    run_image_setup(deployed_time)
    # Record the deployment time only after the setup call returned.
    _LAST_DEPLOYED = deployed_time

    return JSONResponse(
        content={"message": "Image and metadata reloaded successfully."},
        status_code=200,
    )
|
|
1500
|
+
|
|
1501
|
+
|
|
1502
|
+
@app.post("/{cls_or_fn_name}", response_class=JSONResponse)
@app.post("/{cls_or_fn_name}/{method_name}", response_class=JSONResponse)
async def run_callable(
    request: Request,
    cls_or_fn_name: str,
    method_name: Optional[str] = None,
    distributed_subcall=False,
    debug_port: Optional[int] = None,
    params: Optional[Union[Dict, str]] = Body(default=None),
    deployed_as_of: Optional[str] = Header(None, alias="X-Deployed-As-Of"),
    serialization: str = Header("json", alias="X-Serialization"),
):
    """Invoke the served callable (optionally one of its methods) over HTTP.

    The path must name the callable configured in ``KT_CLS_OR_FN_NAME``
    (404 otherwise). ``load_callable(deployed_as_of)`` always runs first, in
    an executor. The call is then either handed to the distributed supervisor
    (when it intercepts it) or run locally via ``run_callable_internal``.
    """
    served_name = os.environ["KT_CLS_OR_FN_NAME"]
    if cls_or_fn_name != served_name:
        raise HTTPException(
            status_code=404,
            detail=f"Callable '{cls_or_fn_name}' not found in metadata configuration. Found '{served_name}' instead",
        )

    # NOTE: The distributed replica processes (e.g. PyTorchProcess:run) rely on
    # this running here even though they will reconstruct the callable
    # themselves, because they skip image reloading as a performance
    # optimization. load_callable may do file I/O and other blocking work, so
    # it runs in an executor.
    callable_obj = await run_in_executor_with_context(
        None, load_callable, deployed_as_of
    )

    supervisor = DISTRIBUTED_SUPERVISOR
    if not (supervisor and supervisor.intercept_call()):
        # Not a distributed call, or the distribution type needs no special
        # handling: run the callable directly.
        return await run_callable_internal(
            callable_obj=callable_obj,
            cls_or_fn_name=cls_or_fn_name,
            method_name=method_name,
            params=params,
            serialization=serialization,
            debug_port=debug_port,
        )

    # Distribution type with a special call method (e.g. SIMD): the supervisor
    # handles it. Its call blocks, so it runs in an executor to keep the event
    # loop free.
    result = await run_in_executor_with_context(
        None,
        supervisor.call_distributed,
        request,
        cls_or_fn_name,
        method_name,
        params,
        distributed_subcall,
        debug_port,
        deployed_as_of,
    )
    clear_debugging_sessions()
    return result
|
|
1557
|
+
|
|
1558
|
+
|
|
1559
|
+
async def run_callable_internal(
    callable_obj: Callable,
    cls_or_fn_name: str,
    method_name: Optional[str] = None,
    params: Optional[Union[Dict, str]] = None,
    serialization: str = "json",
    debug_port: Optional[int] = None,
):
    """Call the loaded callable (or one of its methods) and serialize the result.

    Args:
        callable_obj: The loaded function, or the object whose ``method_name``
            attribute is called.
        cls_or_fn_name: Name of the callable; used in errors and log messages.
        method_name: Optional attribute of ``callable_obj`` to call.
        params: Call arguments. For JSON, a dict with optional ``"args"`` and
            ``"kwargs"`` keys. For pickle, either a dict whose ``"data"`` entry
            is a base64-encoded pickle of such a dict, or that base64 string.
        serialization: Wire format; must be listed in
            ``KT_ALLOWED_SERIALIZATION`` (default
            ``DEFAULT_ALLOWED_SERIALIZATION``, comma-separated).
        debug_port: If set, ``deep_breakpoint(debug_port)`` runs before the call.

    Returns:
        For JSON, the result itself (after checking it is JSON-serializable);
        for pickle, ``{"data": <base64-encoded pickle of the result>}``.

    Raises:
        HTTPException: 400 for a disallowed serialization format or a
            non-dict payload; 404 if ``method_name`` is not found.
        SerializationError: If the result cannot be serialized.
    """
    import inspect

    # Check if serialization is allowed
    allowed_serialization = os.getenv(
        "KT_ALLOWED_SERIALIZATION", DEFAULT_ALLOWED_SERIALIZATION
    ).split(",")
    if serialization not in allowed_serialization:
        raise HTTPException(
            status_code=400,
            detail=f"Serialization format '{serialization}' not allowed. Allowed formats: {allowed_serialization}",
        )

    # Process the call
    args = []
    kwargs = {}
    if params:
        if serialization == "pickle":
            # SECURITY: pickle.loads can execute arbitrary code embedded in the
            # payload. The only guard is the allow-list check above, so "pickle"
            # must only be allowed for trusted clients.
            if isinstance(params, dict) and "data" in params:
                pickled_data = base64.b64decode(params["data"].encode("utf-8"))
                # data is unpickled in the format {"args": args, "kwargs": kwargs};
                # merge it over the remaining keys without mutating the caller's dict.
                params = {k: v for k, v in params.items() if k != "data"}
                params.update(pickle.loads(pickled_data))
            elif isinstance(params, str):
                # Fallback for direct string
                pickled_data = base64.b64decode(params.encode("utf-8"))
                params = pickle.loads(pickled_data)

        if not isinstance(params, dict):
            raise HTTPException(
                status_code=400,
                detail=f"Call params must be a dict with 'args'/'kwargs', got {type(params).__name__}",
            )
        args = params.get("args", [])
        kwargs = params.get("kwargs", {})

    if method_name:
        if not hasattr(callable_obj, method_name):
            raise HTTPException(
                status_code=404,
                detail=f"Method '{method_name}' not found in class '{cls_or_fn_name}'",
            )
        user_method = getattr(callable_obj, method_name)
    else:
        user_method = callable_obj

    # Check if the user method is async
    is_async_method = inspect.iscoroutinefunction(user_method)

    if debug_port:
        logger.info(
            f"Debugging remote callable {cls_or_fn_name}.{method_name} on port {debug_port}"
        )
        deep_breakpoint(debug_port)
    else:
        logger.debug(f"Calling remote callable {cls_or_fn_name}.{method_name}")

    # If using the debugger, step in here ("s") to enter your function/class method.
    if is_async_method:
        result = await user_method(*args, **kwargs)
    else:
        # Run sync method in thread pool to avoid blocking the event loop; the
        # lambda forwards both args and kwargs.
        result = await run_in_executor_with_context(
            None, lambda: user_method(*args, **kwargs)
        )

    # Handle case where sync method returns an awaitable (e.g., from an async framework)
    if isinstance(result, Awaitable):
        result = await result

    # Serialize response based on format
    if serialization == "pickle":
        try:
            pickled_result = pickle.dumps(result)
            encoded_result = base64.b64encode(pickled_result).decode("utf-8")
            result = {"data": encoded_result}
        except Exception as e:
            logger.error(f"Failed to pickle result: {str(e)}")
            raise SerializationError(
                f"Result could not be serialized with pickle: {str(e)}"
            ) from e
    else:
        # Default JSON serialization: only validate here; encoding happens in
        # the response layer.
        try:
            json.dumps(result)
        except (TypeError, ValueError) as e:
            logger.error(f"Result is not JSON serializable: {str(e)}")
            raise SerializationError(
                f"Result could not be serialized to JSON: {str(e)}"
            ) from e

    clear_debugging_sessions()

    return result
|
|
1667
|
+
|
|
1668
|
+
|
|
1669
|
+
def run_callable_internal_sync(
    callable_obj: Callable,
    cls_or_fn_name: str,
    method_name: Optional[str] = None,
    params: Optional[Union[Dict, str]] = None,
    serialization: str = "json",
    debug_port: Optional[int] = None,
):
    """Synchronous counterpart of ``run_callable_internal``, used by distributed subprocesses.

    Applies the same serialization allow-list, parameter decoding, method
    lookup and result serialization, but runs entirely in the calling thread.
    Async methods, and awaitables returned by sync methods, are driven to
    completion with ``asyncio.run``. That raises ``RuntimeError`` if the calling
    thread already has a running event loop.

    Args:
        callable_obj: The loaded function, or the object whose ``method_name``
            attribute is called.
        cls_or_fn_name: Name of the callable; used in errors and log messages.
        method_name: Optional attribute of ``callable_obj`` to call.
        params: Call arguments. For JSON, a dict with optional ``"args"`` and
            ``"kwargs"`` keys. For pickle, either a dict whose ``"data"`` entry
            is a base64-encoded pickle of such a dict, or that base64 string.
        serialization: Wire format; must be listed in
            ``KT_ALLOWED_SERIALIZATION`` (default
            ``DEFAULT_ALLOWED_SERIALIZATION``, comma-separated).
        debug_port: If set, ``deep_breakpoint(debug_port)`` runs before the call.

    Returns:
        For JSON, the result itself (after checking it is JSON-serializable);
        for pickle, ``{"data": <base64-encoded pickle of the result>}``.

    Raises:
        HTTPException: 400 for a disallowed serialization format or a
            non-dict payload; 404 if ``method_name`` is not found.
        SerializationError: If the result cannot be serialized.
    """
    import asyncio
    import inspect

    # Check if serialization is allowed
    allowed_serialization = os.getenv(
        "KT_ALLOWED_SERIALIZATION", DEFAULT_ALLOWED_SERIALIZATION
    ).split(",")
    if serialization not in allowed_serialization:
        raise HTTPException(
            status_code=400,
            detail=f"Serialization format '{serialization}' not allowed. Allowed formats: {allowed_serialization}",
        )

    # Process the call
    args = []
    kwargs = {}
    if params:
        if serialization == "pickle":
            # SECURITY: pickle.loads can execute arbitrary code embedded in the
            # payload. The only guard is the allow-list check above, so "pickle"
            # must only be allowed for trusted clients.
            if isinstance(params, dict) and "data" in params:
                pickled_data = base64.b64decode(params["data"].encode("utf-8"))
                # data is unpickled in the format {"args": args, "kwargs": kwargs};
                # merge it over the remaining keys without mutating the caller's dict.
                params = {k: v for k, v in params.items() if k != "data"}
                params.update(pickle.loads(pickled_data))
            elif isinstance(params, str):
                # Fallback for direct string
                pickled_data = base64.b64decode(params.encode("utf-8"))
                params = pickle.loads(pickled_data)

        if not isinstance(params, dict):
            raise HTTPException(
                status_code=400,
                detail=f"Call params must be a dict with 'args'/'kwargs', got {type(params).__name__}",
            )
        args = params.get("args", [])
        kwargs = params.get("kwargs", {})

    if method_name:
        if not hasattr(callable_obj, method_name):
            raise HTTPException(
                status_code=404,
                detail=f"Method '{method_name}' not found in class '{cls_or_fn_name}'",
            )
        user_method = getattr(callable_obj, method_name)
    else:
        user_method = callable_obj

    # Check if the user method is async
    is_async_method = inspect.iscoroutinefunction(user_method)

    if debug_port:
        logger.info(
            f"Debugging remote callable {cls_or_fn_name}.{method_name} on port {debug_port}"
        )
        deep_breakpoint(debug_port)
    else:
        logger.debug(f"Calling remote callable {cls_or_fn_name}.{method_name}")

    # If using the debugger, step in here ("s") to enter your function/class method.
    if is_async_method:
        # For async methods in sync context, we need to run them in a new event loop
        result = asyncio.run(user_method(*args, **kwargs))
    else:
        result = user_method(*args, **kwargs)

    # Handle case where sync method returns an awaitable
    if isinstance(result, Awaitable):
        if not asyncio.iscoroutine(result):
            # asyncio.run() only accepts coroutines, so wrap other awaitables
            # (e.g. objects that only implement __await__) in one.
            awaitable = result

            async def _await_result():
                return await awaitable

            result = _await_result()
        result = asyncio.run(result)

    # Serialize response based on format
    if serialization == "pickle":
        try:
            pickled_result = pickle.dumps(result)
            encoded_result = base64.b64encode(pickled_result).decode("utf-8")
            result = {"data": encoded_result}
        except Exception as e:
            logger.error(f"Failed to pickle result: {str(e)}")
            raise SerializationError(
                f"Result could not be serialized with pickle: {str(e)}"
            ) from e
    else:
        # Default JSON serialization: only validate here; the caller encodes.
        try:
            json.dumps(result)
        except (TypeError, ValueError) as e:
            logger.error(f"Result is not JSON serializable: {str(e)}")
            raise SerializationError(
                f"Result could not be serialized to JSON: {str(e)}"
            ) from e

    clear_debugging_sessions()

    return result
|
|
1772
|
+
|
|
1773
|
+
|
|
1774
|
+
@app.get("/health", include_in_schema=False)
@app.get("/", include_in_schema=False)
def health():
    """Report a static "healthy" status on both ``/`` and ``/health``.

    Both routes are hidden from the OpenAPI schema.
    """
    status = "healthy"
    return {"status": status}
|
|
1778
|
+
|
|
1779
|
+
|
|
1780
|
+
if __name__ == "__main__" and not is_running_in_container():
    # NOTE: this will only run in local development, otherwise we start the uvicorn server in the pod template setup
    import uvicorn
    from dotenv import load_dotenv

    # Load variables from a local .env file. This runs after the module-level
    # setup above, so only settings read from here on (e.g. KT_SERVER_PORT)
    # can come from .env.
    load_dotenv()

    # Environment variables are strings, and uvicorn binds the socket with the
    # port value as given, so convert it to an int explicitly.
    port = int(os.environ.get("KT_SERVER_PORT", 32300))

    logger.info("Starting HTTP server")
    uvicorn.run(app, host="0.0.0.0", port=port)
|