zenml-nightly 0.83.0.dev20250619__py3-none-any.whl → 0.83.0.dev20250622__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/__init__.py +12 -2
- zenml/analytics/context.py +4 -2
- zenml/config/server_config.py +6 -1
- zenml/constants.py +3 -0
- zenml/entrypoints/step_entrypoint_configuration.py +14 -0
- zenml/models/__init__.py +15 -0
- zenml/models/v2/core/api_transaction.py +193 -0
- zenml/models/v2/core/pipeline_build.py +4 -0
- zenml/models/v2/core/pipeline_deployment.py +8 -1
- zenml/models/v2/core/pipeline_run.py +7 -0
- zenml/models/v2/core/step_run.py +6 -0
- zenml/orchestrators/input_utils.py +34 -11
- zenml/utils/json_utils.py +1 -1
- zenml/zen_server/auth.py +53 -31
- zenml/zen_server/cloud_utils.py +19 -7
- zenml/zen_server/middleware.py +424 -0
- zenml/zen_server/rbac/endpoint_utils.py +5 -2
- zenml/zen_server/rbac/utils.py +12 -7
- zenml/zen_server/request_management.py +556 -0
- zenml/zen_server/routers/auth_endpoints.py +1 -0
- zenml/zen_server/routers/model_versions_endpoints.py +3 -3
- zenml/zen_server/routers/models_endpoints.py +3 -3
- zenml/zen_server/routers/pipeline_builds_endpoints.py +2 -2
- zenml/zen_server/routers/pipeline_deployments_endpoints.py +9 -4
- zenml/zen_server/routers/pipelines_endpoints.py +4 -4
- zenml/zen_server/routers/run_templates_endpoints.py +3 -3
- zenml/zen_server/routers/runs_endpoints.py +4 -4
- zenml/zen_server/routers/service_connectors_endpoints.py +6 -6
- zenml/zen_server/routers/steps_endpoints.py +3 -3
- zenml/zen_server/utils.py +230 -63
- zenml/zen_server/zen_server_api.py +34 -399
- zenml/zen_stores/migrations/versions/3d7e39f3ac92_split_up_step_configurations.py +138 -0
- zenml/zen_stores/migrations/versions/857843db1bcf_add_api_transaction_table.py +69 -0
- zenml/zen_stores/rest_zen_store.py +52 -42
- zenml/zen_stores/schemas/__init__.py +4 -0
- zenml/zen_stores/schemas/api_transaction_schemas.py +141 -0
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +88 -27
- zenml/zen_stores/schemas/pipeline_run_schemas.py +28 -11
- zenml/zen_stores/schemas/step_run_schemas.py +4 -4
- zenml/zen_stores/sql_zen_store.py +277 -42
- zenml/zen_stores/zen_store_interface.py +7 -1
- {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/METADATA +1 -1
- {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/RECORD +47 -41
- {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,556 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Request management utilities."""
|
15
|
+
|
16
|
+
import asyncio
|
17
|
+
import base64
|
18
|
+
import json
|
19
|
+
from contextvars import ContextVar
|
20
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set
|
21
|
+
from uuid import UUID, uuid4
|
22
|
+
|
23
|
+
from fastapi import Request, Response
|
24
|
+
from fastapi.responses import JSONResponse
|
25
|
+
|
26
|
+
from zenml.constants import MEDIUMTEXT_MAX_LENGTH
|
27
|
+
from zenml.exceptions import EntityExistsError
|
28
|
+
from zenml.logger import get_logger
|
29
|
+
from zenml.models import ApiTransactionRequest, ApiTransactionUpdate
|
30
|
+
from zenml.utils.json_utils import pydantic_encoder
|
31
|
+
from zenml.utils.time_utils import utc_now
|
32
|
+
|
33
|
+
if TYPE_CHECKING:
|
34
|
+
from zenml.zen_server.auth import AuthContext
|
35
|
+
|
36
|
+
|
37
|
+
# Module-level logger used by the request context and request manager below.
logger = get_logger(__name__)
|
38
|
+
|
39
|
+
|
40
|
+
class RequestContext:
    """Context for a request.

    Captures the identifiers (request ID, idempotency/transaction ID, trace
    ID), the client source and the reception time of an incoming request,
    which are used for logging and for de-duplicating requests.
    """

    def __init__(self, request: Request) -> None:
        """Initialize the request context.

        Args:
            request: The request.

        Raises:
            ValueError: If the idempotency key is not a valid UUID.
        """
        self.request = request
        # Client-supplied request ID; falls back to a short random ID so that
        # every request can be correlated in the logs.
        self.request_id = request.headers.get("X-Request-ID", str(uuid4())[:8])
        self.transaction_id: Optional[UUID] = None
        transaction_id = request.headers.get("Idempotency-Key")
        if transaction_id:
            try:
                self.transaction_id = UUID(transaction_id)
            except ValueError as e:
                # Chain the original parsing error to keep the root cause.
                raise ValueError(
                    f"Invalid UUID idempotency key: {transaction_id}. "
                    "Please use a valid UUID."
                ) from e

        # Use a random trace ID to identify the request internally in the logs.
        self.trace_id = str(uuid4())[:4]

        self.source = request.headers.get("User-Agent") or ""
        self.received_at = utc_now()

        # Set later, once the request has been authenticated.
        self.auth_context: Optional["AuthContext"] = None

    @property
    def process_time(self) -> float:
        """Get the process time in seconds of the request.

        Returns:
            The time in seconds elapsed since the request was received.
        """
        return (utc_now() - self.received_at).total_seconds()

    @property
    def log_request_id(self) -> str:
        """Get the full request ID for logging.

        Returns:
            The request ID, combined with the client type (the first
            `/`-separated component of the User-Agent) and the trace ID.
        """
        source_type = self.source.split("/")[0]
        return f"{self.request_id}/{source_type}/{self.trace_id}"

    @property
    def log_request(self) -> str:
        """Get the request details for logging.

        Returns:
            The request details for logging.
        """
        client_ip = (
            self.request.client.host if self.request.client else "unknown"
        )
        url_path = self.request.url.path
        method = self.request.method
        return f"{method} {url_path} from {client_ip}"

    @property
    def log_duration(self) -> str:
        """Get the duration of the request.

        Returns:
            The duration of the request in milliseconds, formatted with two
            decimals.
        """
        return f"{self.process_time * 1000:.2f}ms"

    @property
    def is_cacheable(self) -> bool:
        """Check if the request is cacheable.

        Returns:
            Whether the request is cacheable.
        """
        # Only cache requests that are authenticated and are part of a
        # transaction.
        return (
            self.auth_context is not None and self.transaction_id is not None
        )
|
129
|
+
|
130
|
+
|
131
|
+
class RequestRecord:
    """Bookkeeping entry for a request that is in flight or already finished.

    Pairs the asyncio future through which waiters receive the outcome of a
    request with the context of the request that produced it.
    """

    future: asyncio.Future[Any]
    request_context: RequestContext

    def __init__(
        self, future: asyncio.Future[Any], request_context: RequestContext
    ) -> None:
        """Create a record for a request.

        Args:
            future: Future that will carry the outcome of the request.
            request_context: Context of the request tracked by this record.
        """
        self.future = future
        self.request_context = request_context
        # Flipped to True only once the future has been resolved.
        self.completed = False

    def set_result(self, result: Any) -> None:
        """Resolve the record with a successful outcome.

        Args:
            result: Value produced by the request.
        """
        pending = self.future
        pending.set_result(result)
        self.completed = True

    def set_exception(self, exception: Exception) -> None:
        """Resolve the record with a failure.

        Args:
            exception: Error raised while processing the request.
        """
        pending = self.future
        pending.set_exception(exception)
        self.completed = True
|
167
|
+
|
168
|
+
|
169
|
+
class RequestManager:
    """A manager for requests.

    This class is used to manage requests by caching the results of requests
    that have already been executed. Requests carrying an idempotency key
    (transaction ID) can be de-duplicated in memory, for concurrent retries
    reaching this server instance, and through API transactions persisted in
    the zen store, for retries reaching other instances.
    """

    def __init__(
        self,
        transaction_ttl: int,
        request_timeout: float,
        deduplicate: bool = True,
    ) -> None:
        """Initialize the request manager.

        Args:
            transaction_ttl: The time to live for cached transactions. Comes
                into effect after a request is completed.
            request_timeout: The timeout for requests. If a request takes longer
                than this, a 429 error will be returned to the client to slow
                down the request rate.
            deduplicate: Whether to deduplicate requests.
        """
        self.deduplicate = deduplicate
        self.transactions: Dict[UUID, RequestRecord] = dict()
        self.lock = asyncio.Lock()
        self.transaction_ttl = transaction_ttl
        self.request_timeout = request_timeout

        self.request_contexts: ContextVar[Optional[RequestContext]] = (
            ContextVar("request_contexts", default=None)
        )

        # Strong references to the background tasks started by `execute`.
        # The event loop only keeps weak references to tasks, so a task that
        # is not referenced anywhere else may be garbage collected before it
        # completes. Tasks remove themselves from this set when done.
        self._background_tasks: Set["asyncio.Task[None]"] = set()

    @property
    def current_request(self) -> RequestContext:
        """Get the current request context.

        Returns:
            The current request context.

        Raises:
            RuntimeError: If no request context is set.
        """
        request_context = self.request_contexts.get()
        if request_context is None:
            raise RuntimeError("No request context set")
        return request_context

    @current_request.setter
    def current_request(self, request_context: RequestContext) -> None:
        """Set the current request context.

        Args:
            request_context: The request context.
        """
        self.request_contexts.set(request_context)

    async def startup(self) -> None:
        """Start the request manager."""
        pass

    async def shutdown(self) -> None:
        """Shutdown the request manager."""
        pass

    async def async_run_and_cache_result(
        self,
        func: Callable[..., Any],
        deduplicate: bool,
        request_record: RequestRecord,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Run a request and cache the result.

        This method is called in the background to run a request and cache the
        result. The outcome (result or exception) is always delivered through
        the future of `request_record`, so that callers waiting on it in
        `execute` are not left waiting indefinitely.

        Args:
            func: The function to execute.
            deduplicate: Whether to enable or disable request deduplication for
                this request.
            request_record: The request record to cache the result in.
            *args: The arguments to pass to the function.
            **kwargs: The keyword arguments to pass to the function.
        """
        from starlette.concurrency import run_in_threadpool

        from zenml.zen_server.utils import get_system_metrics_log_str

        request_context = request_record.request_context
        transaction_id = request_context.transaction_id

        def sync_run_and_cache_result(*args: Any, **kwargs: Any) -> Any:
            from zenml.zen_server.utils import zen_store

            # Copy the deduplicate flag to a local variable to avoid modifying
            # the argument in the outer scope.
            deduplicate_request = deduplicate

            logger.debug(
                f"[{request_context.log_request_id}] ENDPOINT STATS - "
                f"{request_context.log_request} "
                f"sync {func.__name__} STARTED "
                f"{get_system_metrics_log_str(request_context.request)}"
            )

            try:
                # Create or get the API transaction from the database
                if deduplicate_request:
                    assert transaction_id is not None
                    try:
                        api_transaction, transaction_created = (
                            zen_store().get_or_create_api_transaction(
                                api_transaction=ApiTransactionRequest(
                                    transaction_id=transaction_id,
                                    method=request_context.request.method,
                                    url=str(request_context.request.url),
                                )
                            )
                        )
                    except EntityExistsError:
                        logger.error(
                            f"[{request_context.log_request_id}] "
                            f"Transaction {transaction_id} already exists "
                            f"with method {request_context.request.method} and "
                            f"URL {str(request_context.request.url)}. Skipping "
                            "caching."
                        )
                        deduplicate_request = False

                if deduplicate_request:
                    if api_transaction.completed:
                        logger.debug(
                            f"[{request_context.log_request_id}] "
                            "ENDPOINT STATS - "
                            f"{request_context.log_request} "
                            f"sync {func.__name__} CACHE HIT "
                            f"{get_system_metrics_log_str(request_context.request)}"
                        )

                        # The transaction already completed, we can return the
                        # result right away.
                        result = api_transaction.get_result()
                        if result is not None:
                            return Response(
                                base64.b64decode(result),
                                media_type="application/json",
                            )
                        else:
                            return None

                    elif not transaction_created:
                        logger.debug(
                            f"[{request_context.log_request_id}] "
                            "ENDPOINT STATS - "
                            f"{request_context.log_request} "
                            f"sync {func.__name__} DELAYED "
                            f"{get_system_metrics_log_str(request_context.request)}"
                        )

                        # The transaction is being processed by another server
                        # instance. We need to wait for it to complete. Instead
                        # of blocking this worker thread, we return a 429 error
                        # to the client to force it to retry later.
                        return JSONResponse(
                            {
                                "error": "Server too busy. Please try again later."
                            },
                            status_code=429,
                        )

                try:
                    result = func(*args, **kwargs)
                except Exception:
                    if deduplicate_request:
                        assert transaction_id is not None
                        # We don't cache exceptions. If the client retries, the
                        # request will be executed again and the exception, if
                        # persistent, will be raised again.
                        zen_store().delete_api_transaction(
                            api_transaction_id=transaction_id,
                        )
                    raise

                if deduplicate_request:
                    assert transaction_id is not None
                    cache_result = True
                    result_to_cache: Optional[bytes] = None
                    if result is not None:
                        try:
                            result_to_cache = base64.b64encode(
                                json.dumps(
                                    result, default=pydantic_encoder
                                ).encode("utf-8")
                            )
                        except Exception:
                            # If the result is not serializable, we don't cache it.
                            cache_result = False
                            logger.exception(
                                f"Failed to serialize result of {func.__name__} "
                                f"for transaction {transaction_id}. Skipping "
                                "caching."
                            )
                        else:
                            if len(result_to_cache) > MEDIUMTEXT_MAX_LENGTH:
                                # If the result is too large, we also don't cache it.
                                cache_result = False
                                result_to_cache = None
                                logger.error(
                                    f"Result of {func.__name__} "
                                    f"for transaction {transaction_id} is too "
                                    "large. Skipping caching."
                                )

                    if cache_result:
                        api_transaction_update = ApiTransactionUpdate(
                            cache_time=self.transaction_ttl,
                        )
                        if result_to_cache is not None:
                            api_transaction_update.set_result(
                                result_to_cache.decode("utf-8")
                            )
                        zen_store().finalize_api_transaction(
                            api_transaction_id=transaction_id,
                            api_transaction_update=api_transaction_update,
                        )
                    else:
                        # If the result is not cacheable, there is no point in
                        # keeping the transaction around.
                        zen_store().delete_api_transaction(
                            api_transaction_id=transaction_id,
                        )

                return result
            finally:
                logger.debug(
                    f"[{request_context.log_request_id}] ENDPOINT STATS - "
                    f"{request_context.log_request} "
                    f"sync {func.__name__} COMPLETED "
                    f"{get_system_metrics_log_str(request_context.request)}"
                )

        result: Any
        try:
            # The STARTED log is inside the try so that a failure while
            # logging is delivered to the waiters instead of leaving the
            # future unresolved.
            logger.debug(
                f"[{request_context.log_request_id}] ENDPOINT STATS - "
                f"{request_context.log_request} "
                f"async {func.__name__} STARTED "
                f"{get_system_metrics_log_str(request_context.request)}"
            )
            result = await run_in_threadpool(
                sync_run_and_cache_result, *args, **kwargs
            )
        except Exception as e:
            result = e

        if deduplicate and transaction_id is not None:
            # The request is no longer in flight on this server instance.
            async with self.lock:
                self.transactions.pop(transaction_id, None)

        # Resolve the future before the final log so that waiters are always
        # unblocked.
        if isinstance(result, Exception):
            request_record.set_exception(result)
        else:
            request_record.set_result(result)

        logger.debug(
            f"[{request_context.log_request_id}] ENDPOINT STATS - "
            f"{request_context.log_request} "
            f"async {func.__name__} COMPLETED "
            f"{get_system_metrics_log_str(request_context.request)}"
        )

    async def execute(
        self,
        func: Callable[..., Any],
        deduplicate: Optional[bool],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Execute a request with in-memory de-duplication.

        Call this method to execute a request with in-memory de-duplication.
        If the request (identified by its transaction ID, i.e. the value of
        the `Idempotency-Key` header) is already in-flight on this server
        instance, the method waits for it to complete. Otherwise, the method
        starts execution in the background, where the result is looked up in
        or stored to the API transaction cache. When de-duplication is
        enabled and the request does not complete within the configured
        request timeout, a 429 response is returned instead.

        Args:
            func: The function to execute.
            deduplicate: Whether to enable or disable request deduplication for
                this request. If not specified, by default, the deduplication
                is enabled for POST requests and disabled for other requests.
            *args: The arguments to pass to the function.
            **kwargs: The keyword arguments to pass to the function.

        Returns:
            The result of the request.
        """
        from zenml.zen_server.utils import get_system_metrics_log_str

        # Raises RuntimeError if no request context has been set.
        request_context = self.current_request

        logger.debug(
            f"[{request_context.log_request_id}] ENDPOINT STATS - "
            f"{request_context.log_request} "
            f"{func.__name__} STARTED "
            f"{get_system_metrics_log_str(request_context.request)}"
        )

        transaction_id = request_context.transaction_id

        if deduplicate is None:
            # If not specified, by default, the deduplication is enabled for
            # POST requests and disabled for other requests.
            deduplicate = request_context.request.method == "POST"

        deduplicate = (
            deduplicate and self.deduplicate and request_context.is_cacheable
        )

        async with self.lock:
            if deduplicate and transaction_id in self.transactions:
                # The transaction is still being processed on the same
                # server instance. We just wait for it to complete.
                fut = self.transactions[transaction_id].future
                logger.debug(
                    f"[{request_context.log_request_id}] ENDPOINT STATS - "
                    f"{request_context.log_request} "
                    f"{func.__name__} RESUMED "
                    f"{get_system_metrics_log_str(request_context.request)}"
                )
            else:
                # Start execution in background, use the future to wait for it
                # to complete.
                fut = asyncio.get_running_loop().create_future()
                request_record = RequestRecord(
                    future=fut, request_context=request_context
                )
                if deduplicate:
                    assert transaction_id is not None
                    # Also record the transaction for deduplication.
                    self.transactions[transaction_id] = request_record
                task = asyncio.create_task(
                    self.async_run_and_cache_result(
                        func,
                        deduplicate,
                        request_record,
                        *args,
                        **kwargs,
                    )
                )
                # Keep a strong reference until the task is done, otherwise
                # it may be garbage collected mid-execution.
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)

        # Wait for the request to complete; timeout if deduplication is enabled
        try:
            # We take into account the time that has already elapsed since the
            # request was received to avoid keeping the request for too long.
            timeout = max(
                0, self.request_timeout - request_context.process_time
            )
            result = await asyncio.wait_for(
                # We use asyncio.shield to prevent the request from being
                # cancelled when the timeout is reached.
                asyncio.shield(fut),
                timeout=timeout if deduplicate else None,
            )

            logger.debug(
                f"[{request_context.log_request_id}] ENDPOINT STATS - "
                f"{request_context.log_request} "
                f"{func.__name__} COMPLETED "
                f"{get_system_metrics_log_str(request_context.request)}"
            )
            return result
        except asyncio.TimeoutError:
            logger.debug(
                f"[{request_context.log_request_id}] ENDPOINT STATS - "
                f"{request_context.log_request} "
                f"{func.__name__} TIMEOUT "
                f"{get_system_metrics_log_str(request_context.request)}"
            )
            return JSONResponse(
                {"error": "Server too busy. Please try again later."},
                status_code=429,
            )
|
@@ -594,6 +594,7 @@ def api_token(
|
|
594
594
|
if pipeline_run_id:
|
595
595
|
# The pipeline run must exist and the run must not be concluded
|
596
596
|
try:
|
597
|
+
# TODO: this is expensive, we should only fetch the minimum data here
|
597
598
|
pipeline_run = zen_store().get_run(pipeline_run_id, hydrate=True)
|
598
599
|
except KeyError:
|
599
600
|
raise ValueError(
|
@@ -127,7 +127,7 @@ def create_model_version(
|
|
127
127
|
"",
|
128
128
|
responses={401: error_response, 404: error_response, 422: error_response},
|
129
129
|
)
|
130
|
-
@async_fastapi_endpoint_wrapper
|
130
|
+
@async_fastapi_endpoint_wrapper(deduplicate=True)
|
131
131
|
def list_model_versions(
|
132
132
|
model_version_filter_model: ModelVersionFilter = Depends(
|
133
133
|
make_dependable(ModelVersionFilter)
|
@@ -176,7 +176,7 @@ def list_model_versions(
|
|
176
176
|
"/{model_version_id}",
|
177
177
|
responses={401: error_response, 404: error_response, 422: error_response},
|
178
178
|
)
|
179
|
-
@async_fastapi_endpoint_wrapper
|
179
|
+
@async_fastapi_endpoint_wrapper(deduplicate=True)
|
180
180
|
def get_model_version(
|
181
181
|
model_version_id: UUID,
|
182
182
|
hydrate: bool = True,
|
@@ -204,7 +204,7 @@ def get_model_version(
|
|
204
204
|
"/{model_version_id}",
|
205
205
|
responses={401: error_response, 404: error_response, 422: error_response},
|
206
206
|
)
|
207
|
-
@async_fastapi_endpoint_wrapper
|
207
|
+
@async_fastapi_endpoint_wrapper(deduplicate=True)
|
208
208
|
def update_model_version(
|
209
209
|
model_version_id: UUID,
|
210
210
|
model_version_update_model: ModelVersionUpdate,
|
@@ -101,7 +101,7 @@ def create_model(
|
|
101
101
|
"",
|
102
102
|
responses={401: error_response, 404: error_response, 422: error_response},
|
103
103
|
)
|
104
|
-
@async_fastapi_endpoint_wrapper
|
104
|
+
@async_fastapi_endpoint_wrapper(deduplicate=True)
|
105
105
|
def list_models(
|
106
106
|
model_filter_model: ModelFilter = Depends(make_dependable(ModelFilter)),
|
107
107
|
hydrate: bool = False,
|
@@ -130,7 +130,7 @@ def list_models(
|
|
130
130
|
"/{model_id}",
|
131
131
|
responses={401: error_response, 404: error_response, 422: error_response},
|
132
132
|
)
|
133
|
-
@async_fastapi_endpoint_wrapper
|
133
|
+
@async_fastapi_endpoint_wrapper(deduplicate=True)
|
134
134
|
def get_model(
|
135
135
|
model_id: UUID,
|
136
136
|
hydrate: bool = True,
|
@@ -155,7 +155,7 @@ def get_model(
|
|
155
155
|
"/{model_id}",
|
156
156
|
responses={401: error_response, 404: error_response, 422: error_response},
|
157
157
|
)
|
158
|
-
@async_fastapi_endpoint_wrapper
|
158
|
+
@async_fastapi_endpoint_wrapper(deduplicate=True)
|
159
159
|
def update_model(
|
160
160
|
model_id: UUID,
|
161
161
|
model_update: ModelUpdate,
|
@@ -97,7 +97,7 @@ def create_build(
|
|
97
97
|
deprecated=True,
|
98
98
|
tags=["builds"],
|
99
99
|
)
|
100
|
-
@async_fastapi_endpoint_wrapper
|
100
|
+
@async_fastapi_endpoint_wrapper(deduplicate=True)
|
101
101
|
def list_builds(
|
102
102
|
build_filter_model: PipelineBuildFilter = Depends(
|
103
103
|
make_dependable(PipelineBuildFilter)
|
@@ -133,7 +133,7 @@ def list_builds(
|
|
133
133
|
"/{build_id}",
|
134
134
|
responses={401: error_response, 404: error_response, 422: error_response},
|
135
135
|
)
|
136
|
-
@async_fastapi_endpoint_wrapper
|
136
|
+
@async_fastapi_endpoint_wrapper(deduplicate=True)
|
137
137
|
def get_build(
|
138
138
|
build_id: UUID,
|
139
139
|
hydrate: bool = True,
|
@@ -13,10 +13,10 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Endpoint definitions for deployments."""
|
15
15
|
|
16
|
-
from typing import Any, Optional, Union
|
16
|
+
from typing import Any, List, Optional, Union
|
17
17
|
from uuid import UUID
|
18
18
|
|
19
|
-
from fastapi import APIRouter, Depends, Request, Security
|
19
|
+
from fastapi import APIRouter, Depends, Query, Request, Security
|
20
20
|
|
21
21
|
from zenml.constants import API, PIPELINE_DEPLOYMENTS, VERSION_1
|
22
22
|
from zenml.logging.step_logging import fetch_logs
|
@@ -142,7 +142,7 @@ def create_deployment(
|
|
142
142
|
deprecated=True,
|
143
143
|
tags=["deployments"],
|
144
144
|
)
|
145
|
-
@async_fastapi_endpoint_wrapper
|
145
|
+
@async_fastapi_endpoint_wrapper(deduplicate=True)
|
146
146
|
def list_deployments(
|
147
147
|
request: Request,
|
148
148
|
deployment_filter_model: PipelineDeploymentFilter = Depends(
|
@@ -196,11 +196,12 @@ def list_deployments(
|
|
196
196
|
"/{deployment_id}",
|
197
197
|
responses={401: error_response, 404: error_response, 422: error_response},
|
198
198
|
)
|
199
|
-
@async_fastapi_endpoint_wrapper
|
199
|
+
@async_fastapi_endpoint_wrapper(deduplicate=True)
|
200
200
|
def get_deployment(
|
201
201
|
request: Request,
|
202
202
|
deployment_id: UUID,
|
203
203
|
hydrate: bool = True,
|
204
|
+
step_configuration_filter: Optional[List[str]] = Query(None),
|
204
205
|
_: AuthContext = Security(authorize),
|
205
206
|
) -> Any:
|
206
207
|
"""Gets a specific deployment using its unique id.
|
@@ -210,6 +211,9 @@ def get_deployment(
|
|
210
211
|
deployment_id: ID of the deployment to get.
|
211
212
|
hydrate: Flag deciding whether to hydrate the output model(s)
|
212
213
|
by including metadata fields in the response.
|
214
|
+
step_configuration_filter: List of step configurations to include in
|
215
|
+
the response. If not given, all step configurations will be
|
216
|
+
included.
|
213
217
|
|
214
218
|
Returns:
|
215
219
|
A specific deployment object.
|
@@ -218,6 +222,7 @@ def get_deployment(
|
|
218
222
|
id=deployment_id,
|
219
223
|
get_method=zen_store().get_deployment,
|
220
224
|
hydrate=hydrate,
|
225
|
+
step_configuration_filter=step_configuration_filter,
|
221
226
|
)
|
222
227
|
|
223
228
|
exclude = None
|