zenml-nightly 0.82.1.dev20250521__py3-none-any.whl → 0.82.1.dev20250524__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/client.py +6 -2
- zenml/config/build_configuration.py +7 -0
- zenml/config/docker_settings.py +25 -0
- zenml/config/server_config.py +7 -0
- zenml/constants.py +1 -0
- zenml/enums.py +1 -0
- zenml/environment.py +12 -0
- zenml/integrations/gcp/__init__.py +1 -1
- zenml/integrations/gcp/service_connectors/gcp_service_connector.py +12 -11
- zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py +1 -1
- zenml/integrations/kubernetes/orchestrators/kube_utils.py +16 -12
- zenml/materializers/built_in_materializer.py +9 -3
- zenml/zen_server/cloud_utils.py +45 -21
- zenml/zen_server/routers/actions_endpoints.py +6 -6
- zenml/zen_server/routers/artifact_endpoint.py +6 -6
- zenml/zen_server/routers/artifact_version_endpoints.py +11 -11
- zenml/zen_server/routers/auth_endpoints.py +4 -3
- zenml/zen_server/routers/code_repositories_endpoints.py +6 -6
- zenml/zen_server/routers/devices_endpoints.py +6 -6
- zenml/zen_server/routers/event_source_endpoints.py +6 -6
- zenml/zen_server/routers/flavors_endpoints.py +7 -7
- zenml/zen_server/routers/logs_endpoints.py +2 -2
- zenml/zen_server/routers/model_versions_endpoints.py +13 -13
- zenml/zen_server/routers/models_endpoints.py +6 -6
- zenml/zen_server/routers/pipeline_builds_endpoints.py +5 -5
- zenml/zen_server/routers/pipeline_deployments_endpoints.py +6 -6
- zenml/zen_server/routers/pipelines_endpoints.py +7 -7
- zenml/zen_server/routers/plugin_endpoints.py +3 -3
- zenml/zen_server/routers/projects_endpoints.py +7 -7
- zenml/zen_server/routers/run_metadata_endpoints.py +2 -2
- zenml/zen_server/routers/run_templates_endpoints.py +7 -7
- zenml/zen_server/routers/runs_endpoints.py +11 -11
- zenml/zen_server/routers/schedule_endpoints.py +6 -6
- zenml/zen_server/routers/secrets_endpoints.py +8 -8
- zenml/zen_server/routers/server_endpoints.py +13 -9
- zenml/zen_server/routers/service_accounts_endpoints.py +12 -12
- zenml/zen_server/routers/service_connectors_endpoints.py +13 -13
- zenml/zen_server/routers/service_endpoints.py +6 -6
- zenml/zen_server/routers/stack_components_endpoints.py +18 -9
- zenml/zen_server/routers/stack_deployment_endpoints.py +4 -4
- zenml/zen_server/routers/stacks_endpoints.py +6 -6
- zenml/zen_server/routers/steps_endpoints.py +8 -8
- zenml/zen_server/routers/tag_resource_endpoints.py +5 -5
- zenml/zen_server/routers/tags_endpoints.py +6 -6
- zenml/zen_server/routers/triggers_endpoints.py +9 -9
- zenml/zen_server/routers/users_endpoints.py +12 -12
- zenml/zen_server/routers/webhook_endpoints.py +2 -2
- zenml/zen_server/utils.py +72 -33
- zenml/zen_server/zen_server_api.py +211 -55
- zenml/zen_stores/rest_zen_store.py +40 -11
- zenml/zen_stores/sql_zen_store.py +79 -2
- {zenml_nightly-0.82.1.dev20250521.dist-info → zenml_nightly-0.82.1.dev20250524.dist-info}/METADATA +1 -1
- {zenml_nightly-0.82.1.dev20250521.dist-info → zenml_nightly-0.82.1.dev20250524.dist-info}/RECORD +57 -57
- {zenml_nightly-0.82.1.dev20250521.dist-info → zenml_nightly-0.82.1.dev20250524.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.82.1.dev20250521.dist-info → zenml_nightly-0.82.1.dev20250524.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.82.1.dev20250521.dist-info → zenml_nightly-0.82.1.dev20250524.dist-info}/entry_points.txt +0 -0
zenml/zen_server/utils.py
CHANGED
@@ -15,10 +15,12 @@
|
|
15
15
|
|
16
16
|
import inspect
|
17
17
|
import os
|
18
|
+
import threading
|
18
19
|
from functools import wraps
|
19
20
|
from typing import (
|
20
21
|
TYPE_CHECKING,
|
21
22
|
Any,
|
23
|
+
Awaitable,
|
22
24
|
Callable,
|
23
25
|
Dict,
|
24
26
|
List,
|
@@ -27,11 +29,11 @@ from typing import (
|
|
27
29
|
Type,
|
28
30
|
TypeVar,
|
29
31
|
Union,
|
30
|
-
cast,
|
31
32
|
)
|
32
33
|
from uuid import UUID
|
33
34
|
|
34
35
|
from pydantic import BaseModel, ValidationError
|
36
|
+
from typing_extensions import ParamSpec
|
35
37
|
|
36
38
|
from zenml import __version__ as zenml_version
|
37
39
|
from zenml.config.global_config import GlobalConfiguration
|
@@ -64,6 +66,11 @@ if TYPE_CHECKING:
|
|
64
66
|
BoundedThreadPoolExecutor,
|
65
67
|
)
|
66
68
|
|
69
|
+
|
70
|
+
P = ParamSpec("P")
|
71
|
+
R = TypeVar("R")
|
72
|
+
|
73
|
+
|
67
74
|
logger = get_logger(__name__)
|
68
75
|
|
69
76
|
_zen_store: Optional["SqlZenStore"] = None
|
@@ -298,11 +305,16 @@ def server_config() -> ServerConfiguration:
|
|
298
305
|
return _server_config
|
299
306
|
|
300
307
|
|
301
|
-
|
302
|
-
|
308
|
+
def async_fastapi_endpoint_wrapper(
|
309
|
+
func: Callable[P, R],
|
310
|
+
) -> Callable[P, Awaitable[Any]]:
|
311
|
+
"""Decorator for FastAPI endpoints.
|
303
312
|
|
304
|
-
|
305
|
-
|
313
|
+
This decorator for FastAPI endpoints does the following:
|
314
|
+
- Sets the auth_context context variable if the endpoint is authenticated.
|
315
|
+
- Converts exceptions to HTTPExceptions with the correct status code.
|
316
|
+
- Converts the sync endpoint function to an coroutine and runs the original
|
317
|
+
function in a worker threadpool. See below for more details.
|
306
318
|
|
307
319
|
Args:
|
308
320
|
func: Function to decorate.
|
@@ -311,41 +323,68 @@ def handle_exceptions(func: F) -> F:
|
|
311
323
|
Decorated function.
|
312
324
|
"""
|
313
325
|
|
326
|
+
# When having a sync FastAPI endpoint, it runs the endpoint function in
|
327
|
+
# a worker threadpool. If all threads are busy, it will queue the task.
|
328
|
+
# The problem is that after the endpoint code returns, FastAPI will queue
|
329
|
+
# another task in the same threadpool to serialize the response. If there
|
330
|
+
# are many tasks already in the queue, this means that the response
|
331
|
+
# serialization will wait for a long time instead of returning the response
|
332
|
+
# immediately. By making our endpoints async and then immediately
|
333
|
+
# dispatching them to the threadpool ourselves (which is essentially what
|
334
|
+
# FastAPI does when having a sync endpoint), we can avoid this problem.
|
335
|
+
# The serialization logic will now run on the event loop and not wait for
|
336
|
+
# a worker thread to become available.
|
337
|
+
# See: `fastapi.routing.serialize_response(...)` and
|
338
|
+
# https://github.com/fastapi/fastapi/pull/888 for more information.
|
314
339
|
@wraps(func)
|
315
|
-
def
|
316
|
-
|
317
|
-
# used by the CLI when installed without the `server` extra
|
318
|
-
from fastapi import HTTPException
|
319
|
-
from fastapi.responses import JSONResponse
|
340
|
+
async def async_decorated(*args: P.args, **kwargs: P.kwargs) -> Any:
|
341
|
+
from starlette.concurrency import run_in_threadpool
|
320
342
|
|
321
|
-
from zenml.zen_server.
|
343
|
+
from zenml.zen_server.zen_server_api import request_ids
|
322
344
|
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
345
|
+
request_id = request_ids.get()
|
346
|
+
|
347
|
+
@wraps(func)
|
348
|
+
def decorated(*args: P.args, **kwargs: P.kwargs) -> Any:
|
349
|
+
# These imports can't happen at module level as this module is also
|
350
|
+
# used by the CLI when installed without the `server` extra
|
351
|
+
from fastapi import HTTPException
|
352
|
+
from fastapi.responses import JSONResponse
|
353
|
+
|
354
|
+
from zenml.zen_server.auth import AuthContext, set_auth_context
|
355
|
+
|
356
|
+
if request_id:
|
357
|
+
# Change the name of the current thread to the request ID
|
358
|
+
threading.current_thread().name = request_id
|
359
|
+
|
360
|
+
for arg in args:
|
329
361
|
if isinstance(arg, AuthContext):
|
330
362
|
set_auth_context(arg)
|
331
363
|
break
|
364
|
+
else:
|
365
|
+
for _, arg in kwargs.items():
|
366
|
+
if isinstance(arg, AuthContext):
|
367
|
+
set_auth_context(arg)
|
368
|
+
break
|
369
|
+
|
370
|
+
try:
|
371
|
+
return func(*args, **kwargs)
|
372
|
+
except OAuthError as error:
|
373
|
+
# The OAuthError is special because it needs to have a JSON response
|
374
|
+
return JSONResponse(
|
375
|
+
status_code=error.status_code,
|
376
|
+
content=error.to_dict(),
|
377
|
+
)
|
378
|
+
except HTTPException:
|
379
|
+
raise
|
380
|
+
except Exception as error:
|
381
|
+
logger.exception("API error")
|
382
|
+
http_exception = http_exception_from_error(error)
|
383
|
+
raise http_exception
|
332
384
|
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
# The OAuthError is special because it needs to have a JSON response
|
337
|
-
return JSONResponse(
|
338
|
-
status_code=error.status_code,
|
339
|
-
content=error.to_dict(),
|
340
|
-
)
|
341
|
-
except HTTPException:
|
342
|
-
raise
|
343
|
-
except Exception as error:
|
344
|
-
logger.exception("API error")
|
345
|
-
http_exception = http_exception_from_error(error)
|
346
|
-
raise http_exception
|
347
|
-
|
348
|
-
return cast(F, decorated)
|
385
|
+
return await run_in_threadpool(decorated, *args, **kwargs)
|
386
|
+
|
387
|
+
return async_decorated
|
349
388
|
|
350
389
|
|
351
390
|
# Code from https://github.com/tiangolo/fastapi/issues/1474#issuecomment-1160633178
|
@@ -20,11 +20,17 @@ To run this file locally, execute:
|
|
20
20
|
```
|
21
21
|
"""
|
22
22
|
|
23
|
+
import logging
|
23
24
|
import os
|
25
|
+
import threading
|
26
|
+
import time
|
27
|
+
from asyncio import Lock, Semaphore, TimeoutError, wait_for
|
24
28
|
from asyncio.log import logger
|
29
|
+
from contextvars import ContextVar
|
25
30
|
from datetime import datetime, timedelta
|
26
31
|
from genericpath import isfile
|
27
|
-
from typing import Any, List, Set
|
32
|
+
from typing import Any, List, Optional, Set
|
33
|
+
from uuid import uuid4
|
28
34
|
|
29
35
|
from anyio import to_thread
|
30
36
|
from fastapi import FastAPI, HTTPException, Request
|
@@ -113,6 +119,10 @@ from zenml.zen_server.utils import (
|
|
113
119
|
|
114
120
|
DASHBOARD_DIRECTORY = "dashboard"
|
115
121
|
|
122
|
+
request_ids: ContextVar[Optional[str]] = ContextVar(
|
123
|
+
"request_ids", default=None
|
124
|
+
)
|
125
|
+
|
116
126
|
|
117
127
|
def relative_path(rel: str) -> str:
|
118
128
|
"""Get the absolute path of a path relative to the ZenML server module.
|
@@ -138,6 +148,7 @@ last_user_activity: datetime = utc_now()
|
|
138
148
|
last_user_activity_reported: datetime = last_user_activity + timedelta(
|
139
149
|
seconds=-DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS
|
140
150
|
)
|
151
|
+
last_user_activity_lock = Lock()
|
141
152
|
|
142
153
|
|
143
154
|
# Customize the default request validation handler that comes with FastAPI
|
@@ -247,51 +258,6 @@ class RestrictFileUploadsMiddleware(BaseHTTPMiddleware):
|
|
247
258
|
|
248
259
|
ALLOWED_FOR_FILE_UPLOAD: Set[str] = set()
|
249
260
|
|
250
|
-
app.add_middleware(
|
251
|
-
CORSMiddleware,
|
252
|
-
allow_origins=server_config().cors_allow_origins,
|
253
|
-
allow_credentials=True,
|
254
|
-
allow_methods=["*"],
|
255
|
-
allow_headers=["*"],
|
256
|
-
)
|
257
|
-
|
258
|
-
app.add_middleware(
|
259
|
-
RequestBodyLimit, max_bytes=server_config().max_request_body_size_in_bytes
|
260
|
-
)
|
261
|
-
app.add_middleware(
|
262
|
-
RestrictFileUploadsMiddleware, allowed_paths=ALLOWED_FOR_FILE_UPLOAD
|
263
|
-
)
|
264
|
-
|
265
|
-
|
266
|
-
@app.middleware("http")
|
267
|
-
async def set_secure_headers(request: Request, call_next: Any) -> Any:
|
268
|
-
"""Middleware to set secure headers.
|
269
|
-
|
270
|
-
Args:
|
271
|
-
request: The incoming request.
|
272
|
-
call_next: The next function to be called.
|
273
|
-
|
274
|
-
Returns:
|
275
|
-
The response with secure headers set.
|
276
|
-
"""
|
277
|
-
try:
|
278
|
-
response = await call_next(request)
|
279
|
-
except Exception:
|
280
|
-
logger.exception("An error occurred while processing the request")
|
281
|
-
response = JSONResponse(
|
282
|
-
status_code=500,
|
283
|
-
content={"detail": "An unexpected error occurred."},
|
284
|
-
)
|
285
|
-
|
286
|
-
# If the request is for the openAPI docs, don't set secure headers
|
287
|
-
if request.url.path.startswith("/docs") or request.url.path.startswith(
|
288
|
-
"/redoc"
|
289
|
-
):
|
290
|
-
return response
|
291
|
-
|
292
|
-
secure_headers().framework.fastapi(response)
|
293
|
-
return response
|
294
|
-
|
295
261
|
|
296
262
|
@app.middleware("http")
|
297
263
|
async def track_last_user_activity(request: Request, call_next: Any) -> Any:
|
@@ -310,24 +276,30 @@ async def track_last_user_activity(request: Request, call_next: Any) -> Any:
|
|
310
276
|
"""
|
311
277
|
global last_user_activity
|
312
278
|
global last_user_activity_reported
|
279
|
+
global last_user_activity_lock
|
313
280
|
|
314
281
|
now = utc_now()
|
315
282
|
|
316
283
|
try:
|
317
284
|
if is_user_request(request):
|
318
|
-
|
285
|
+
report_user_activity = False
|
286
|
+
async with last_user_activity_lock:
|
287
|
+
last_user_activity = now
|
288
|
+
if (
|
289
|
+
(now - last_user_activity_reported).total_seconds()
|
290
|
+
> DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS
|
291
|
+
):
|
292
|
+
last_user_activity_reported = now
|
293
|
+
report_user_activity = True
|
294
|
+
|
295
|
+
if report_user_activity:
|
296
|
+
zen_store()._update_last_user_activity_timestamp(
|
297
|
+
last_user_activity=last_user_activity
|
298
|
+
)
|
319
299
|
except Exception as e:
|
320
300
|
logger.debug(
|
321
301
|
f"An unexpected error occurred while checking user activity: {e}"
|
322
302
|
)
|
323
|
-
if (
|
324
|
-
(now - last_user_activity_reported).total_seconds()
|
325
|
-
> DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS
|
326
|
-
):
|
327
|
-
last_user_activity_reported = now
|
328
|
-
zen_store()._update_last_user_activity_timestamp(
|
329
|
-
last_user_activity=last_user_activity
|
330
|
-
)
|
331
303
|
|
332
304
|
try:
|
333
305
|
return await call_next(request)
|
@@ -378,6 +350,190 @@ async def infer_source_context(request: Request, call_next: Any) -> Any:
|
|
378
350
|
)
|
379
351
|
|
380
352
|
|
353
|
+
request_semaphore = Semaphore(server_config().thread_pool_size)
|
354
|
+
|
355
|
+
|
356
|
+
@app.middleware("http")
|
357
|
+
async def prevent_read_timeout(request: Request, call_next: Any) -> Any:
|
358
|
+
"""Prevent read timeout client errors.
|
359
|
+
|
360
|
+
Args:
|
361
|
+
request: The incoming request.
|
362
|
+
call_next: The next function to be called.
|
363
|
+
|
364
|
+
Returns:
|
365
|
+
The response to the request.
|
366
|
+
"""
|
367
|
+
# Only process the REST API requests because these are the ones that
|
368
|
+
# take the most time to complete.
|
369
|
+
if not request.url.path.startswith(API):
|
370
|
+
return await call_next(request)
|
371
|
+
|
372
|
+
server_request_timeout = server_config().server_request_timeout
|
373
|
+
|
374
|
+
active_threads = threading.active_count()
|
375
|
+
request_id = request_ids.get()
|
376
|
+
|
377
|
+
client_ip = request.client.host if request.client else "unknown"
|
378
|
+
method = request.method
|
379
|
+
url_path = request.url.path
|
380
|
+
|
381
|
+
logger.debug(
|
382
|
+
f"[{request_id}] API STATS - {method} {url_path} from {client_ip} "
|
383
|
+
f"QUEUED [ "
|
384
|
+
f"threads: {active_threads} "
|
385
|
+
f"]"
|
386
|
+
)
|
387
|
+
|
388
|
+
start_time = time.time()
|
389
|
+
|
390
|
+
try:
|
391
|
+
# Here we wait until a worker thread is available to process the
|
392
|
+
# request with a timeout value that is set to be lower than the
|
393
|
+
# what the client is willing to wait for (i.e. lower than the
|
394
|
+
# client's HTTP request timeout). The rationale is that we want to
|
395
|
+
# respond to the client before it times out and decides to retry the
|
396
|
+
# request (which would overwhelm the server).
|
397
|
+
await wait_for(
|
398
|
+
request_semaphore.acquire(),
|
399
|
+
timeout=server_request_timeout,
|
400
|
+
)
|
401
|
+
except TimeoutError:
|
402
|
+
end_time = time.time()
|
403
|
+
duration = (end_time - start_time) * 1000
|
404
|
+
active_threads = threading.active_count()
|
405
|
+
|
406
|
+
logger.debug(
|
407
|
+
f"[{request_id}] API STATS - {method} {url_path} from {client_ip} "
|
408
|
+
f"THROTTLED after {duration:.2f}ms [ "
|
409
|
+
f"threads: {active_threads} "
|
410
|
+
f"]"
|
411
|
+
)
|
412
|
+
|
413
|
+
# We return a 429 error, basically telling the client to slow down.
|
414
|
+
# For the client, the 429 error is more meaningful than a ReadTimeout
|
415
|
+
# error, because it also tells the client two additional things:
|
416
|
+
#
|
417
|
+
# 1. The server is alive.
|
418
|
+
# 2. The server hasn't processed the request, so even if the request
|
419
|
+
# is not idempotent, it's safe to retry it.
|
420
|
+
return JSONResponse(
|
421
|
+
{"error": "Server too busy. Please try again later."},
|
422
|
+
status_code=429,
|
423
|
+
)
|
424
|
+
|
425
|
+
duration = (time.time() - start_time) * 1000
|
426
|
+
active_threads = threading.active_count()
|
427
|
+
|
428
|
+
logger.debug(
|
429
|
+
f"[{request_id}] API STATS - {method} {url_path} from {client_ip} "
|
430
|
+
f"ACCEPTED after {duration:.2f}ms [ "
|
431
|
+
f"threads: {active_threads} "
|
432
|
+
f"]"
|
433
|
+
)
|
434
|
+
|
435
|
+
try:
|
436
|
+
return await call_next(request)
|
437
|
+
finally:
|
438
|
+
request_semaphore.release()
|
439
|
+
|
440
|
+
|
441
|
+
@app.middleware("http")
|
442
|
+
async def log_requests(request: Request, call_next: Any) -> Any:
|
443
|
+
"""Log requests to the ZenML server.
|
444
|
+
|
445
|
+
Args:
|
446
|
+
request: The incoming request object.
|
447
|
+
call_next: A function that will receive the request as a parameter and
|
448
|
+
pass it to the corresponding path operation.
|
449
|
+
|
450
|
+
Returns:
|
451
|
+
The response to the request.
|
452
|
+
"""
|
453
|
+
if not logger.isEnabledFor(logging.DEBUG):
|
454
|
+
return await call_next(request)
|
455
|
+
|
456
|
+
# Get active threads count
|
457
|
+
active_threads = threading.active_count()
|
458
|
+
|
459
|
+
request_id = request.headers.get("X-Request-ID", str(uuid4())[:8])
|
460
|
+
# Detect if the request comes from Python, Web UI or something else
|
461
|
+
if source := request.headers.get("User-Agent"):
|
462
|
+
source = source.split("/")[0]
|
463
|
+
request_id = f"{request_id}/{source}"
|
464
|
+
|
465
|
+
request_ids.set(request_id)
|
466
|
+
client_ip = request.client.host if request.client else "unknown"
|
467
|
+
method = request.method
|
468
|
+
url_path = request.url.path
|
469
|
+
|
470
|
+
logger.debug(
|
471
|
+
f"[{request_id}] API STATS - {method} {url_path} from {client_ip} "
|
472
|
+
f"RECEIVED [ "
|
473
|
+
f"threads: {active_threads} "
|
474
|
+
f"]"
|
475
|
+
)
|
476
|
+
|
477
|
+
start_time = time.time()
|
478
|
+
response = await call_next(request)
|
479
|
+
duration = (time.time() - start_time) * 1000
|
480
|
+
status_code = response.status_code
|
481
|
+
|
482
|
+
logger.debug(
|
483
|
+
f"[{request_id}] API STATS - {status_code} {method} {url_path} from "
|
484
|
+
f"{client_ip} took {duration:.2f}ms [ "
|
485
|
+
f"threads: {active_threads} "
|
486
|
+
f"]"
|
487
|
+
)
|
488
|
+
return response
|
489
|
+
|
490
|
+
|
491
|
+
app.add_middleware(
|
492
|
+
CORSMiddleware,
|
493
|
+
allow_origins=server_config().cors_allow_origins,
|
494
|
+
allow_credentials=True,
|
495
|
+
allow_methods=["*"],
|
496
|
+
allow_headers=["*"],
|
497
|
+
)
|
498
|
+
|
499
|
+
app.add_middleware(
|
500
|
+
RequestBodyLimit, max_bytes=server_config().max_request_body_size_in_bytes
|
501
|
+
)
|
502
|
+
app.add_middleware(
|
503
|
+
RestrictFileUploadsMiddleware, allowed_paths=ALLOWED_FOR_FILE_UPLOAD
|
504
|
+
)
|
505
|
+
|
506
|
+
|
507
|
+
@app.middleware("http")
|
508
|
+
async def set_secure_headers(request: Request, call_next: Any) -> Any:
|
509
|
+
"""Middleware to set secure headers.
|
510
|
+
|
511
|
+
Args:
|
512
|
+
request: The incoming request.
|
513
|
+
call_next: The next function to be called.
|
514
|
+
|
515
|
+
Returns:
|
516
|
+
The response with secure headers set.
|
517
|
+
"""
|
518
|
+
try:
|
519
|
+
response = await call_next(request)
|
520
|
+
except Exception:
|
521
|
+
logger.exception("An error occurred while processing the request")
|
522
|
+
response = JSONResponse(
|
523
|
+
status_code=500,
|
524
|
+
content={"detail": "An unexpected error occurred."},
|
525
|
+
)
|
526
|
+
|
527
|
+
# If the request is for the openAPI docs, don't set secure headers
|
528
|
+
if request.url.path.startswith("/docs") or request.url.path.startswith(
|
529
|
+
"/redoc"
|
530
|
+
):
|
531
|
+
return response
|
532
|
+
|
533
|
+
secure_headers().framework.fastapi(response)
|
534
|
+
return response
|
535
|
+
|
536
|
+
|
381
537
|
@app.on_event("startup")
|
382
538
|
def initialize() -> None:
|
383
539
|
"""Initialize the ZenML server."""
|
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
import os
|
17
17
|
import re
|
18
|
+
import time
|
18
19
|
from datetime import datetime
|
19
20
|
from pathlib import Path
|
20
21
|
from typing import (
|
@@ -30,7 +31,7 @@ from typing import (
|
|
30
31
|
Union,
|
31
32
|
)
|
32
33
|
from urllib.parse import urlparse
|
33
|
-
from uuid import UUID
|
34
|
+
from uuid import UUID, uuid4
|
34
35
|
|
35
36
|
import requests
|
36
37
|
import urllib3
|
@@ -4179,6 +4180,20 @@ class RestZenStore(BaseZenStore):
|
|
4179
4180
|
Returns:
|
4180
4181
|
A requests session.
|
4181
4182
|
"""
|
4183
|
+
|
4184
|
+
class AugmentedRetry(Retry):
|
4185
|
+
"""Augmented retry class that also retries on 429 status codes for POST requests."""
|
4186
|
+
|
4187
|
+
def is_retry(
|
4188
|
+
self,
|
4189
|
+
method: str,
|
4190
|
+
status_code: int,
|
4191
|
+
has_retry_after: bool = False,
|
4192
|
+
) -> bool:
|
4193
|
+
if status_code == 429:
|
4194
|
+
return True
|
4195
|
+
return super().is_retry(method, status_code, has_retry_after)
|
4196
|
+
|
4182
4197
|
if self._session is None:
|
4183
4198
|
# We only need to initialize the session once over the lifetime
|
4184
4199
|
# of the client. We can swap the token out when it expires.
|
@@ -4208,7 +4223,7 @@ class RestZenStore(BaseZenStore):
|
|
4208
4223
|
# the timeout period.
|
4209
4224
|
# Connection Refused: If the server refuses the connection.
|
4210
4225
|
#
|
4211
|
-
retries =
|
4226
|
+
retries = AugmentedRetry(
|
4212
4227
|
connect=5,
|
4213
4228
|
read=8,
|
4214
4229
|
redirect=3,
|
@@ -4222,7 +4237,7 @@ class RestZenStore(BaseZenStore):
|
|
4222
4237
|
504, # Gateway Timeout
|
4223
4238
|
],
|
4224
4239
|
other=3,
|
4225
|
-
backoff_factor=
|
4240
|
+
backoff_factor=1,
|
4226
4241
|
)
|
4227
4242
|
self._session.mount("https://", HTTPAdapter(max_retries=retries))
|
4228
4243
|
self._session.mount("http://", HTTPAdapter(max_retries=retries))
|
@@ -4360,6 +4375,14 @@ class RestZenStore(BaseZenStore):
|
|
4360
4375
|
self.session.headers.update(
|
4361
4376
|
{source_context.name: source_context.get().value}
|
4362
4377
|
)
|
4378
|
+
# Add a request ID to the request headers
|
4379
|
+
request_id = str(uuid4())[:8]
|
4380
|
+
self.session.headers.update({"X-Request-ID": request_id})
|
4381
|
+
path = url.removeprefix(self.url)
|
4382
|
+
start_time = time.time()
|
4383
|
+
logger.debug(
|
4384
|
+
f"Sending {method} request to {path} with request ID {request_id}..."
|
4385
|
+
)
|
4363
4386
|
|
4364
4387
|
# If the server replies with a credentials validation (401 Unauthorized)
|
4365
4388
|
# error, we (re-)authenticate and retry the request here in the
|
@@ -4401,7 +4424,8 @@ class RestZenStore(BaseZenStore):
|
|
4401
4424
|
# request again, this time with a valid API token in the
|
4402
4425
|
# header.
|
4403
4426
|
logger.debug(
|
4404
|
-
f"The last request was not
|
4427
|
+
f"The last request with ID {request_id} was not "
|
4428
|
+
f"authenticated: {e}\n"
|
4405
4429
|
"Re-authenticating and retrying..."
|
4406
4430
|
)
|
4407
4431
|
self.authenticate()
|
@@ -4428,8 +4452,9 @@ class RestZenStore(BaseZenStore):
|
|
4428
4452
|
# that was rejected by the server. We attempt a
|
4429
4453
|
# re-authentication here and then retry the request.
|
4430
4454
|
logger.debug(
|
4431
|
-
"The last request
|
4432
|
-
|
4455
|
+
f"The last request with ID {request_id} was authenticated "
|
4456
|
+
"with an API token that was rejected by the server: "
|
4457
|
+
f"{e}\n"
|
4433
4458
|
"Re-authenticating and retrying..."
|
4434
4459
|
)
|
4435
4460
|
re_authenticated = True
|
@@ -4441,13 +4466,21 @@ class RestZenStore(BaseZenStore):
|
|
4441
4466
|
# The last request was made after re-authenticating but
|
4442
4467
|
# still failed. Bailing out.
|
4443
4468
|
logger.debug(
|
4444
|
-
f"The last request failed after
|
4469
|
+
f"The last request with ID {request_id} failed after "
|
4470
|
+
"re-authenticating: {e}\n"
|
4445
4471
|
"Bailing out..."
|
4446
4472
|
)
|
4447
4473
|
raise CredentialsNotValid(
|
4448
4474
|
"The current credentials are no longer valid. Please "
|
4449
4475
|
"log in again using 'zenml login'."
|
4450
4476
|
) from e
|
4477
|
+
finally:
|
4478
|
+
end_time = time.time()
|
4479
|
+
duration = (end_time - start_time) * 1000
|
4480
|
+
logger.debug(
|
4481
|
+
f"Request to {path} with request ID {request_id} took "
|
4482
|
+
f"{duration:.2f}ms."
|
4483
|
+
)
|
4451
4484
|
|
4452
4485
|
def get(
|
4453
4486
|
self,
|
@@ -4467,7 +4500,6 @@ class RestZenStore(BaseZenStore):
|
|
4467
4500
|
Returns:
|
4468
4501
|
The response body.
|
4469
4502
|
"""
|
4470
|
-
logger.debug(f"Sending GET request to {path}...")
|
4471
4503
|
return self._request(
|
4472
4504
|
"GET",
|
4473
4505
|
self.url + API + VERSION_1 + path,
|
@@ -4496,7 +4528,6 @@ class RestZenStore(BaseZenStore):
|
|
4496
4528
|
Returns:
|
4497
4529
|
The response body.
|
4498
4530
|
"""
|
4499
|
-
logger.debug(f"Sending DELETE request to {path}...")
|
4500
4531
|
return self._request(
|
4501
4532
|
"DELETE",
|
4502
4533
|
self.url + API + VERSION_1 + path,
|
@@ -4526,7 +4557,6 @@ class RestZenStore(BaseZenStore):
|
|
4526
4557
|
Returns:
|
4527
4558
|
The response body.
|
4528
4559
|
"""
|
4529
|
-
logger.debug(f"Sending POST request to {path}...")
|
4530
4560
|
return self._request(
|
4531
4561
|
"POST",
|
4532
4562
|
self.url + API + VERSION_1 + path,
|
@@ -4556,7 +4586,6 @@ class RestZenStore(BaseZenStore):
|
|
4556
4586
|
Returns:
|
4557
4587
|
The response body.
|
4558
4588
|
"""
|
4559
|
-
logger.debug(f"Sending PUT request to {path}...")
|
4560
4589
|
json = (
|
4561
4590
|
body.model_dump(mode="json", exclude_unset=True) if body else None
|
4562
4591
|
)
|