zenml-nightly 0.83.0.dev20250619__py3-none-any.whl → 0.83.0.dev20250622__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. zenml/VERSION +1 -1
  2. zenml/__init__.py +12 -2
  3. zenml/analytics/context.py +4 -2
  4. zenml/config/server_config.py +6 -1
  5. zenml/constants.py +3 -0
  6. zenml/entrypoints/step_entrypoint_configuration.py +14 -0
  7. zenml/models/__init__.py +15 -0
  8. zenml/models/v2/core/api_transaction.py +193 -0
  9. zenml/models/v2/core/pipeline_build.py +4 -0
  10. zenml/models/v2/core/pipeline_deployment.py +8 -1
  11. zenml/models/v2/core/pipeline_run.py +7 -0
  12. zenml/models/v2/core/step_run.py +6 -0
  13. zenml/orchestrators/input_utils.py +34 -11
  14. zenml/utils/json_utils.py +1 -1
  15. zenml/zen_server/auth.py +53 -31
  16. zenml/zen_server/cloud_utils.py +19 -7
  17. zenml/zen_server/middleware.py +424 -0
  18. zenml/zen_server/rbac/endpoint_utils.py +5 -2
  19. zenml/zen_server/rbac/utils.py +12 -7
  20. zenml/zen_server/request_management.py +556 -0
  21. zenml/zen_server/routers/auth_endpoints.py +1 -0
  22. zenml/zen_server/routers/model_versions_endpoints.py +3 -3
  23. zenml/zen_server/routers/models_endpoints.py +3 -3
  24. zenml/zen_server/routers/pipeline_builds_endpoints.py +2 -2
  25. zenml/zen_server/routers/pipeline_deployments_endpoints.py +9 -4
  26. zenml/zen_server/routers/pipelines_endpoints.py +4 -4
  27. zenml/zen_server/routers/run_templates_endpoints.py +3 -3
  28. zenml/zen_server/routers/runs_endpoints.py +4 -4
  29. zenml/zen_server/routers/service_connectors_endpoints.py +6 -6
  30. zenml/zen_server/routers/steps_endpoints.py +3 -3
  31. zenml/zen_server/utils.py +230 -63
  32. zenml/zen_server/zen_server_api.py +34 -399
  33. zenml/zen_stores/migrations/versions/3d7e39f3ac92_split_up_step_configurations.py +138 -0
  34. zenml/zen_stores/migrations/versions/857843db1bcf_add_api_transaction_table.py +69 -0
  35. zenml/zen_stores/rest_zen_store.py +52 -42
  36. zenml/zen_stores/schemas/__init__.py +4 -0
  37. zenml/zen_stores/schemas/api_transaction_schemas.py +141 -0
  38. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +88 -27
  39. zenml/zen_stores/schemas/pipeline_run_schemas.py +28 -11
  40. zenml/zen_stores/schemas/step_run_schemas.py +4 -4
  41. zenml/zen_stores/sql_zen_store.py +277 -42
  42. zenml/zen_stores/zen_store_interface.py +7 -1
  43. {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/METADATA +1 -1
  44. {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/RECORD +47 -41
  45. {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/LICENSE +0 -0
  46. {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/WHEEL +0 -0
  47. {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,556 @@
1
+ # Copyright (c) ZenML GmbH 2025. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """Request management utilities."""
15
+
16
+ import asyncio
17
+ import base64
18
+ import json
19
+ from contextvars import ContextVar
20
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
21
+ from uuid import UUID, uuid4
22
+
23
+ from fastapi import Request, Response
24
+ from fastapi.responses import JSONResponse
25
+
26
+ from zenml.constants import MEDIUMTEXT_MAX_LENGTH
27
+ from zenml.exceptions import EntityExistsError
28
+ from zenml.logger import get_logger
29
+ from zenml.models import ApiTransactionRequest, ApiTransactionUpdate
30
+ from zenml.utils.json_utils import pydantic_encoder
31
+ from zenml.utils.time_utils import utc_now
32
+
33
+ if TYPE_CHECKING:
34
+ from zenml.zen_server.auth import AuthContext
35
+
36
+
37
+ logger = get_logger(__name__)
38
+
39
+
40
class RequestContext:
    """Context for a request.

    Captures the per-request metadata that the server needs for logging and
    for request de-duplication: the client-supplied request ID and
    idempotency key, a short random trace ID, the client's user agent and the
    time at which the request was received.
    """

    def __init__(self, request: Request) -> None:
        """Initialize the request context.

        Args:
            request: The request.

        Raises:
            ValueError: If the idempotency key is not a valid UUID.
        """
        self.request = request
        # Fall back to a short random ID if the client did not send one.
        self.request_id = request.headers.get("X-Request-ID", str(uuid4())[:8])

        # The idempotency key, when present, doubles as the transaction ID
        # used to de-duplicate requests.
        self.transaction_id: Optional[UUID] = None
        transaction_id = request.headers.get("Idempotency-Key")
        if transaction_id:
            try:
                self.transaction_id = UUID(transaction_id)
            except ValueError as e:
                # Chain the original parsing error so that the root cause
                # stays visible in tracebacks.
                raise ValueError(
                    f"Invalid UUID idempotency key: {transaction_id}. "
                    "Please use a valid UUID."
                ) from e

        # Use a random trace ID to identify the request internally in the logs.
        self.trace_id = str(uuid4())[:4]

        # A missing User-Agent header is normalized to an empty string.
        self.source = request.headers.get("User-Agent") or ""
        self.received_at = utc_now()

        # Not set here; starts out as None, which makes the request
        # non-cacheable (see `is_cacheable`).
        self.auth_context: Optional["AuthContext"] = None

    @property
    def process_time(self) -> float:
        """Get the process time in seconds of the request.

        Returns:
            The time elapsed, in seconds, since the request was received.
        """
        return (utc_now() - self.received_at).total_seconds()

    @property
    def log_request_id(self) -> str:
        """Get the full request ID for logging.

        Returns:
            The request ID, the part of the user agent that precedes the
            first slash and the trace ID, joined by slashes.
        """
        source_type = self.source.split("/")[0]
        return f"{self.request_id}/{source_type}/{self.trace_id}"

    @property
    def log_request(self) -> str:
        """Get the request details for logging.

        Returns:
            The request method, URL path and client host ("unknown" if the
            request carries no client information).
        """
        client_ip = (
            self.request.client.host if self.request.client else "unknown"
        )
        url_path = self.request.url.path
        method = self.request.method
        return f"{method} {url_path} from {client_ip}"

    @property
    def log_duration(self) -> str:
        """Get the duration of the request.

        Returns:
            The time elapsed since the request was received, formatted in
            milliseconds with two decimals (e.g. "12.34ms").
        """
        return f"{self.process_time * 1000:.2f}ms"

    @property
    def is_cacheable(self) -> bool:
        """Check if the request is cacheable.

        Returns:
            Whether the request is cacheable.
        """
        # Only cache requests that are authenticated and are part of a
        # transaction.
        return (
            self.auth_context is not None and self.transaction_id is not None
        )
129
+
130
+
131
class RequestRecord:
    """A record of an in-flight or cached request.

    Bundles the future through which the outcome of a request is delivered
    with the context of the request that started it, and tracks whether an
    outcome has been delivered yet.
    """

    # Future that is resolved with the result or exception of the request.
    future: asyncio.Future[Any]
    # Context of the request that this record was created for.
    request_context: RequestContext

    def __init__(
        self, future: asyncio.Future[Any], request_context: RequestContext
    ) -> None:
        """Initialize the request record.

        A new record always starts out as not completed.

        Args:
            future: The future of the request.
            request_context: The request context.
        """
        self.completed = False
        self.request_context = request_context
        self.future = future

    def set_result(self, result: Any) -> None:
        """Resolve the future with a result and mark the record completed.

        Args:
            result: The result of the request.
        """
        # The flag is only flipped once the future has accepted the result.
        self.future.set_result(result)
        self.completed = True

    def set_exception(self, exception: Exception) -> None:
        """Fail the future with an exception and mark the record completed.

        Args:
            exception: The exception of the request.
        """
        # The flag is only flipped once the future has accepted the exception.
        self.future.set_exception(exception)
        self.completed = True
167
+
168
+
169
class RequestManager:
    """A manager for requests.

    This class is used to manage requests by caching the results of requests
    that have already been executed.

    Requests are de-duplicated on two levels: in memory, by sharing the
    future of a transaction that is still in flight on this server instance,
    and through the API transaction records kept in the ZenML store.
    """

    def __init__(
        self,
        transaction_ttl: int,
        request_timeout: float,
        deduplicate: bool = True,
    ) -> None:
        """Initialize the request manager.

        Args:
            transaction_ttl: The time to live for cached transactions. Comes
                into effect after a request is completed.
            request_timeout: The timeout for requests. If a request takes longer
                than this, a 429 error will be returned to the client to slow
                down the request rate. Only enforced for requests for which
                deduplication is in effect.
            deduplicate: Whether to deduplicate requests.
        """
        self.deduplicate = deduplicate
        # In-flight transactions on this server instance, by transaction ID.
        self.transactions: Dict[UUID, RequestRecord] = dict()
        self.lock = asyncio.Lock()
        self.transaction_ttl = transaction_ttl
        self.request_timeout = request_timeout

        self.request_contexts: ContextVar[Optional[RequestContext]] = (
            ContextVar("request_contexts", default=None)
        )

        # Strong references to the background tasks started by `execute`.
        # The event loop only keeps weak references to tasks, so a task that
        # is not referenced anywhere else may be garbage collected before it
        # finishes.
        self._background_tasks: set[asyncio.Task[None]] = set()

    @property
    def current_request(self) -> RequestContext:
        """Get the current request context.

        Returns:
            The current request context.

        Raises:
            RuntimeError: If no request context is set.
        """
        request_context = self.request_contexts.get()
        if request_context is None:
            raise RuntimeError("No request context set")
        return request_context

    @current_request.setter
    def current_request(self, request_context: RequestContext) -> None:
        """Set the current request context.

        Args:
            request_context: The request context.
        """
        self.request_contexts.set(request_context)

    async def startup(self) -> None:
        """Start the request manager."""
        pass

    async def shutdown(self) -> None:
        """Shutdown the request manager."""
        pass

    async def async_run_and_cache_result(
        self,
        func: Callable[..., Any],
        deduplicate: bool,
        request_record: RequestRecord,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Run a request and cache the result.

        This method is called in the background to run a request and cache the
        result. The function is executed in a worker thread. Its outcome
        (result or exception) is delivered through the request record rather
        than returned or raised by this method.

        Args:
            func: The function to execute.
            deduplicate: Whether to enable or disable request deduplication for
                this request.
            request_record: The request record to cache the result in.
            *args: The arguments to pass to the function.
            **kwargs: The keyword arguments to pass to the function.
        """
        from starlette.concurrency import run_in_threadpool

        from zenml.zen_server.utils import get_system_metrics_log_str

        request_context = request_record.request_context
        transaction_id = request_context.transaction_id

        logger.debug(
            f"[{request_context.log_request_id}] ENDPOINT STATS - "
            f"{request_context.log_request} "
            f"async {func.__name__} STARTED "
            f"{get_system_metrics_log_str(request_context.request)}"
        )

        def sync_run_and_cache_result(*args: Any, **kwargs: Any) -> Any:
            from zenml.zen_server.utils import zen_store

            # Copy the deduplicate flag to a local variable to avoid modifying
            # the argument in the outer scope.
            deduplicate_request = deduplicate

            logger.debug(
                f"[{request_context.log_request_id}] ENDPOINT STATS - "
                f"{request_context.log_request} "
                f"sync {func.__name__} STARTED "
                f"{get_system_metrics_log_str(request_context.request)}"
            )

            try:
                # Create or get the API transaction from the database
                if deduplicate_request:
                    assert transaction_id is not None
                    try:
                        api_transaction, transaction_created = (
                            zen_store().get_or_create_api_transaction(
                                api_transaction=ApiTransactionRequest(
                                    transaction_id=transaction_id,
                                    method=request_context.request.method,
                                    url=str(request_context.request.url),
                                )
                            )
                        )
                    except EntityExistsError:
                        logger.error(
                            f"[{request_context.log_request_id}] "
                            f"Transaction {transaction_id} already exists "
                            f"with method {request_context.request.method} and "
                            f"URL {str(request_context.request.url)}. Skipping "
                            "caching."
                        )
                        # Run the request without deduplication from here on.
                        # `api_transaction` is unbound in this case and must
                        # not be used.
                        deduplicate_request = False

                if deduplicate_request:
                    if api_transaction.completed:
                        logger.debug(
                            f"[{request_context.log_request_id}] "
                            "ENDPOINT STATS - "
                            f"{request_context.log_request} "
                            f"sync {func.__name__} CACHE HIT "
                            f"{get_system_metrics_log_str(request_context.request)}"
                        )

                        # The transaction already completed, we can return the
                        # result right away.
                        result = api_transaction.get_result()
                        if result is not None:
                            return Response(
                                base64.b64decode(result),
                                media_type="application/json",
                            )
                        else:
                            return

                    elif not transaction_created:
                        logger.debug(
                            f"[{request_context.log_request_id}] "
                            "ENDPOINT STATS - "
                            f"{request_context.log_request} "
                            f"sync {func.__name__} DELAYED "
                            f"{get_system_metrics_log_str(request_context.request)}"
                        )

                        # The transaction is being processed by another server
                        # instance. We need to wait for it to complete. Instead
                        # of blocking this worker thread, we return a 429 error
                        # to the client to force it to retry later.
                        return JSONResponse(
                            {
                                "error": "Server too busy. Please try again later."
                            },
                            status_code=429,
                        )

                try:
                    result = func(*args, **kwargs)
                except Exception:
                    if deduplicate_request:
                        assert transaction_id is not None
                        # We don't cache exceptions. If the client retries, the
                        # request will be executed again and the exception, if
                        # persistent, will be raised again.
                        zen_store().delete_api_transaction(
                            api_transaction_id=transaction_id,
                        )
                    raise

                if deduplicate_request:
                    assert transaction_id is not None
                    cache_result = True
                    result_to_cache: Optional[bytes] = None
                    if result is not None:
                        try:
                            # The result is stored as base64 encoded JSON.
                            result_to_cache = base64.b64encode(
                                json.dumps(
                                    result, default=pydantic_encoder
                                ).encode("utf-8")
                            )
                        except Exception:
                            # If the result is not serializable, we don't cache it.
                            cache_result = False
                            logger.exception(
                                f"Failed to serialize result of {func.__name__} "
                                f"for transaction {transaction_id}. Skipping "
                                "caching."
                            )
                        else:
                            if len(result_to_cache) > MEDIUMTEXT_MAX_LENGTH:
                                # If the result is too large, we also don't cache it.
                                cache_result = False
                                result_to_cache = None
                                logger.error(
                                    f"Result of {func.__name__} "
                                    f"for transaction {transaction_id} is too "
                                    "large. Skipping caching."
                                )

                    if cache_result:
                        api_transaction_update = ApiTransactionUpdate(
                            cache_time=self.transaction_ttl,
                        )
                        if result_to_cache is not None:
                            api_transaction_update.set_result(
                                result_to_cache.decode("utf-8")
                            )
                        zen_store().finalize_api_transaction(
                            api_transaction_id=transaction_id,
                            api_transaction_update=api_transaction_update,
                        )
                    else:
                        # If the result is not cacheable, there is no point in
                        # keeping the transaction around.
                        zen_store().delete_api_transaction(
                            api_transaction_id=transaction_id,
                        )

                return result
            finally:
                logger.debug(
                    f"[{request_context.log_request_id}] ENDPOINT STATS - "
                    f"{request_context.log_request} "
                    f"sync {func.__name__} COMPLETED "
                    f"{get_system_metrics_log_str(request_context.request)}"
                )

        # Track the outcome in two separate variables so that a function that
        # legitimately returns an exception object as its result is not
        # mistaken for a failed request.
        result: Any = None
        error: Optional[Exception] = None
        try:
            result = await run_in_threadpool(
                sync_run_and_cache_result, *args, **kwargs
            )
        except Exception as e:
            error = e

        if deduplicate:
            # The request is no longer in flight; later requests with the same
            # transaction ID will not find it in the in-memory records.
            async with self.lock:
                if transaction_id in self.transactions:
                    del self.transactions[transaction_id]

        logger.debug(
            f"[{request_context.log_request_id}] ENDPOINT STATS - "
            f"{request_context.log_request} "
            f"async {func.__name__} COMPLETED "
            f"{get_system_metrics_log_str(request_context.request)}"
        )

        if error is not None:
            request_record.set_exception(error)
        else:
            request_record.set_result(result)

    async def execute(
        self,
        func: Callable[..., Any],
        deduplicate: Optional[bool],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Execute a request with in-memory de-duplication.

        Call this method to execute a request with in-memory de-duplication.
        If the request (identified by the transaction ID taken from its
        `Idempotency-Key` header) is in-flight on this server instance, the
        method will wait for it to complete. If the request is not in-flight,
        the method will start execution in the background, which also takes
        care of looking up and caching the result in the API transaction
        records.

        Args:
            func: The function to execute.
            deduplicate: Whether to enable or disable request deduplication for
                this request. If not specified, by default, the deduplication
                is enabled for POST requests and disabled for other requests.
            *args: The arguments to pass to the function.
            **kwargs: The keyword arguments to pass to the function.

        Returns:
            The result of the request, or a 429 JSON response if deduplication
            is in effect and the request did not complete within the request
            timeout.
        """
        from zenml.zen_server.utils import get_system_metrics_log_str

        # Raises a RuntimeError if no request context is set.
        request_context = self.current_request

        logger.debug(
            f"[{request_context.log_request_id}] ENDPOINT STATS - "
            f"{request_context.log_request} "
            f"{func.__name__} STARTED "
            f"{get_system_metrics_log_str(request_context.request)}"
        )

        transaction_id = request_context.transaction_id

        if deduplicate is None:
            # If not specified, by default, the deduplication is enabled for
            # POST requests and disabled for other requests.
            deduplicate = request_context.request.method == "POST"

        # Deduplication must also be enabled globally and the request must be
        # authenticated and carry a transaction ID.
        deduplicate = (
            deduplicate and self.deduplicate and request_context.is_cacheable
        )

        async with self.lock:
            if deduplicate and transaction_id in self.transactions:
                # The transaction is still being processed on the same
                # server instance. We just wait for it to complete.
                fut = self.transactions[transaction_id].future
                logger.debug(
                    f"[{request_context.log_request_id}] ENDPOINT STATS - "
                    f"{request_context.log_request} "
                    f"{func.__name__} RESUMED "
                    f"{get_system_metrics_log_str(request_context.request)}"
                )
            else:
                # Start execution in background, use the future to wait for it
                # to complete.
                fut = asyncio.get_running_loop().create_future()
                request_record = RequestRecord(
                    future=fut, request_context=request_context
                )
                if deduplicate:
                    assert transaction_id is not None
                    # Also record the transaction for deduplication.
                    self.transactions[transaction_id] = request_record
                task = asyncio.create_task(
                    self.async_run_and_cache_result(
                        func,
                        deduplicate,
                        request_record,
                        *args,
                        **kwargs,
                    )
                )
                # Keep a strong reference to the task until it is done so that
                # it cannot be garbage collected while it is still running.
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)

        # Wait for the request to complete; timeout if deduplication is enabled
        try:
            # We take into account the time that has already elapsed since the
            # request was received to avoid keeping the request for too long.
            timeout = max(
                0, self.request_timeout - request_context.process_time
            )
            result = await asyncio.wait_for(
                # We use asyncio.shield to prevent the request from being
                # cancelled when the timeout is reached.
                asyncio.shield(fut),
                timeout=timeout if deduplicate else None,
            )

            logger.debug(
                f"[{request_context.log_request_id}] ENDPOINT STATS - "
                f"{request_context.log_request} "
                f"{func.__name__} COMPLETED "
                f"{get_system_metrics_log_str(request_context.request)}"
            )
            return result
        except asyncio.TimeoutError:
            logger.debug(
                f"[{request_context.log_request_id}] ENDPOINT STATS - "
                f"{request_context.log_request} "
                f"{func.__name__} TIMEOUT "
                f"{get_system_metrics_log_str(request_context.request)}"
            )
            return JSONResponse(
                {"error": "Server too busy. Please try again later."},
                status_code=429,
            )
@@ -594,6 +594,7 @@ def api_token(
594
594
  if pipeline_run_id:
595
595
  # The pipeline run must exist and the run must not be concluded
596
596
  try:
597
+ # TODO: this is expensive, we should only fetch the minimum data here
597
598
  pipeline_run = zen_store().get_run(pipeline_run_id, hydrate=True)
598
599
  except KeyError:
599
600
  raise ValueError(
@@ -127,7 +127,7 @@ def create_model_version(
127
127
  "",
128
128
  responses={401: error_response, 404: error_response, 422: error_response},
129
129
  )
130
- @async_fastapi_endpoint_wrapper
130
+ @async_fastapi_endpoint_wrapper(deduplicate=True)
131
131
  def list_model_versions(
132
132
  model_version_filter_model: ModelVersionFilter = Depends(
133
133
  make_dependable(ModelVersionFilter)
@@ -176,7 +176,7 @@ def list_model_versions(
176
176
  "/{model_version_id}",
177
177
  responses={401: error_response, 404: error_response, 422: error_response},
178
178
  )
179
- @async_fastapi_endpoint_wrapper
179
+ @async_fastapi_endpoint_wrapper(deduplicate=True)
180
180
  def get_model_version(
181
181
  model_version_id: UUID,
182
182
  hydrate: bool = True,
@@ -204,7 +204,7 @@ def get_model_version(
204
204
  "/{model_version_id}",
205
205
  responses={401: error_response, 404: error_response, 422: error_response},
206
206
  )
207
- @async_fastapi_endpoint_wrapper
207
+ @async_fastapi_endpoint_wrapper(deduplicate=True)
208
208
  def update_model_version(
209
209
  model_version_id: UUID,
210
210
  model_version_update_model: ModelVersionUpdate,
@@ -101,7 +101,7 @@ def create_model(
101
101
  "",
102
102
  responses={401: error_response, 404: error_response, 422: error_response},
103
103
  )
104
- @async_fastapi_endpoint_wrapper
104
+ @async_fastapi_endpoint_wrapper(deduplicate=True)
105
105
  def list_models(
106
106
  model_filter_model: ModelFilter = Depends(make_dependable(ModelFilter)),
107
107
  hydrate: bool = False,
@@ -130,7 +130,7 @@ def list_models(
130
130
  "/{model_id}",
131
131
  responses={401: error_response, 404: error_response, 422: error_response},
132
132
  )
133
- @async_fastapi_endpoint_wrapper
133
+ @async_fastapi_endpoint_wrapper(deduplicate=True)
134
134
  def get_model(
135
135
  model_id: UUID,
136
136
  hydrate: bool = True,
@@ -155,7 +155,7 @@ def get_model(
155
155
  "/{model_id}",
156
156
  responses={401: error_response, 404: error_response, 422: error_response},
157
157
  )
158
- @async_fastapi_endpoint_wrapper
158
+ @async_fastapi_endpoint_wrapper(deduplicate=True)
159
159
  def update_model(
160
160
  model_id: UUID,
161
161
  model_update: ModelUpdate,
@@ -97,7 +97,7 @@ def create_build(
97
97
  deprecated=True,
98
98
  tags=["builds"],
99
99
  )
100
- @async_fastapi_endpoint_wrapper
100
+ @async_fastapi_endpoint_wrapper(deduplicate=True)
101
101
  def list_builds(
102
102
  build_filter_model: PipelineBuildFilter = Depends(
103
103
  make_dependable(PipelineBuildFilter)
@@ -133,7 +133,7 @@ def list_builds(
133
133
  "/{build_id}",
134
134
  responses={401: error_response, 404: error_response, 422: error_response},
135
135
  )
136
- @async_fastapi_endpoint_wrapper
136
+ @async_fastapi_endpoint_wrapper(deduplicate=True)
137
137
  def get_build(
138
138
  build_id: UUID,
139
139
  hydrate: bool = True,
@@ -13,10 +13,10 @@
13
13
  # permissions and limitations under the License.
14
14
  """Endpoint definitions for deployments."""
15
15
 
16
- from typing import Any, Optional, Union
16
+ from typing import Any, List, Optional, Union
17
17
  from uuid import UUID
18
18
 
19
- from fastapi import APIRouter, Depends, Request, Security
19
+ from fastapi import APIRouter, Depends, Query, Request, Security
20
20
 
21
21
  from zenml.constants import API, PIPELINE_DEPLOYMENTS, VERSION_1
22
22
  from zenml.logging.step_logging import fetch_logs
@@ -142,7 +142,7 @@ def create_deployment(
142
142
  deprecated=True,
143
143
  tags=["deployments"],
144
144
  )
145
- @async_fastapi_endpoint_wrapper
145
+ @async_fastapi_endpoint_wrapper(deduplicate=True)
146
146
  def list_deployments(
147
147
  request: Request,
148
148
  deployment_filter_model: PipelineDeploymentFilter = Depends(
@@ -196,11 +196,12 @@ def list_deployments(
196
196
  "/{deployment_id}",
197
197
  responses={401: error_response, 404: error_response, 422: error_response},
198
198
  )
199
- @async_fastapi_endpoint_wrapper
199
+ @async_fastapi_endpoint_wrapper(deduplicate=True)
200
200
  def get_deployment(
201
201
  request: Request,
202
202
  deployment_id: UUID,
203
203
  hydrate: bool = True,
204
+ step_configuration_filter: Optional[List[str]] = Query(None),
204
205
  _: AuthContext = Security(authorize),
205
206
  ) -> Any:
206
207
  """Gets a specific deployment using its unique id.
@@ -210,6 +211,9 @@ def get_deployment(
210
211
  deployment_id: ID of the deployment to get.
211
212
  hydrate: Flag deciding whether to hydrate the output model(s)
212
213
  by including metadata fields in the response.
214
+ step_configuration_filter: List of step configurations to include in
215
+ the response. If not given, all step configurations will be
216
+ included.
213
217
 
214
218
  Returns:
215
219
  A specific deployment object.
@@ -218,6 +222,7 @@ def get_deployment(
218
222
  id=deployment_id,
219
223
  get_method=zen_store().get_deployment,
220
224
  hydrate=hydrate,
225
+ step_configuration_filter=step_configuration_filter,
221
226
  )
222
227
 
223
228
  exclude = None