zenml-nightly 0.83.0.dev20250618__py3-none-any.whl → 0.83.0.dev20250621__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. zenml/VERSION +1 -1
  2. zenml/__init__.py +12 -2
  3. zenml/analytics/context.py +4 -2
  4. zenml/config/server_config.py +6 -1
  5. zenml/constants.py +3 -0
  6. zenml/entrypoints/step_entrypoint_configuration.py +14 -0
  7. zenml/models/__init__.py +15 -0
  8. zenml/models/v2/core/api_transaction.py +193 -0
  9. zenml/models/v2/core/pipeline_build.py +4 -0
  10. zenml/models/v2/core/pipeline_deployment.py +8 -1
  11. zenml/models/v2/core/pipeline_run.py +7 -0
  12. zenml/models/v2/core/step_run.py +6 -0
  13. zenml/orchestrators/input_utils.py +34 -11
  14. zenml/utils/json_utils.py +1 -1
  15. zenml/zen_server/auth.py +53 -31
  16. zenml/zen_server/cloud_utils.py +19 -7
  17. zenml/zen_server/middleware.py +424 -0
  18. zenml/zen_server/rbac/endpoint_utils.py +5 -2
  19. zenml/zen_server/rbac/utils.py +12 -7
  20. zenml/zen_server/request_management.py +556 -0
  21. zenml/zen_server/routers/auth_endpoints.py +1 -0
  22. zenml/zen_server/routers/model_versions_endpoints.py +3 -3
  23. zenml/zen_server/routers/models_endpoints.py +3 -3
  24. zenml/zen_server/routers/pipeline_builds_endpoints.py +2 -2
  25. zenml/zen_server/routers/pipeline_deployments_endpoints.py +9 -4
  26. zenml/zen_server/routers/pipelines_endpoints.py +4 -4
  27. zenml/zen_server/routers/run_templates_endpoints.py +3 -3
  28. zenml/zen_server/routers/runs_endpoints.py +4 -4
  29. zenml/zen_server/routers/service_connectors_endpoints.py +6 -6
  30. zenml/zen_server/routers/steps_endpoints.py +3 -3
  31. zenml/zen_server/utils.py +230 -63
  32. zenml/zen_server/zen_server_api.py +34 -399
  33. zenml/zen_stores/migrations/versions/3d7e39f3ac92_split_up_step_configurations.py +138 -0
  34. zenml/zen_stores/migrations/versions/857843db1bcf_add_api_transaction_table.py +69 -0
  35. zenml/zen_stores/rest_zen_store.py +52 -42
  36. zenml/zen_stores/schemas/__init__.py +4 -0
  37. zenml/zen_stores/schemas/api_transaction_schemas.py +141 -0
  38. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +88 -27
  39. zenml/zen_stores/schemas/pipeline_run_schemas.py +28 -11
  40. zenml/zen_stores/schemas/step_run_schemas.py +4 -4
  41. zenml/zen_stores/sql_zen_store.py +277 -42
  42. zenml/zen_stores/zen_store_interface.py +7 -1
  43. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/METADATA +1 -1
  44. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/RECORD +47 -41
  45. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/LICENSE +0 -0
  46. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/WHEEL +0 -0
  47. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,424 @@
1
+ # Copyright (c) ZenML GmbH 2022. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """Server middlewares."""
15
+
16
+ import logging
17
+ from asyncio import Lock
18
+ from asyncio.log import logger
19
+ from datetime import datetime, timedelta
20
+ from typing import Any, Set
21
+
22
+ from anyio import CapacityLimiter, to_thread
23
+ from fastapi import FastAPI, Request
24
+ from fastapi.responses import PlainTextResponse
25
+ from starlette.middleware.base import (
26
+ BaseHTTPMiddleware,
27
+ RequestResponseEndpoint,
28
+ )
29
+ from starlette.middleware.cors import CORSMiddleware
30
+ from starlette.responses import (
31
+ JSONResponse,
32
+ Response,
33
+ )
34
+ from starlette.types import ASGIApp
35
+
36
+ from zenml.analytics import source_context
37
+ from zenml.constants import (
38
+ DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS,
39
+ HEALTH,
40
+ READY,
41
+ )
42
+ from zenml.enums import SourceContextTypes
43
+ from zenml.utils.time_utils import utc_now
44
+ from zenml.zen_server.request_management import RequestContext
45
+ from zenml.zen_server.secure_headers import (
46
+ secure_headers,
47
+ )
48
+ from zenml.zen_server.utils import (
49
+ get_system_metrics_log_str,
50
+ is_user_request,
51
+ request_manager,
52
+ server_config,
53
+ zen_store,
54
+ )
55
+
56
# Number of requests currently being processed. It is a plain integer, so
# every update must happen while holding `active_requests_lock`.
active_requests_lock = Lock()
active_requests_count = 0


# In-memory timestamp of the most recent user request, guarded by
# `last_user_activity_lock`.
last_user_activity_lock = Lock()
last_user_activity: datetime = utc_now()
# Back-date the "last reported" marker by one full reporting interval so the
# first user request after startup is already due for a report.
last_user_activity_reported: datetime = last_user_activity - timedelta(
    seconds=DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS
)
# Dedicated limiter with a single token: at most one user activity update
# runs in a worker thread at any time.
last_user_activity_thread_limiter = CapacityLimiter(1)
70
+
71
+
72
class RequestBodyLimit(BaseHTTPMiddleware):
    """Rejects requests whose declared body size exceeds a fixed limit.

    Only the `Content-Length` header is inspected. Requests that do not
    declare a content length are forwarded without any size check.
    """

    def __init__(self, app: ASGIApp, max_bytes: int) -> None:
        """Limits the size of the request body.

        Args:
            app: The FastAPI app.
            max_bytes: The maximum size of the request body.
        """
        super().__init__(app)
        self.max_bytes = max_bytes

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Limits the size of the request body.

        Args:
            request: The incoming request.
            call_next: The next function to be called.

        Returns:
            A 400 response if the `Content-Length` header is not a valid
            non-negative integer, a 413 response if it exceeds the configured
            limit, a generic 500 response if the downstream handler raises,
            and the downstream response otherwise.
        """
        if content_length := request.headers.get("content-length"):
            try:
                declared_size = int(content_length)
            except ValueError:
                # The header is client-controlled: a malformed value is a
                # client error, not a server failure.
                return Response(status_code=400)  # Bad Request

            if declared_size < 0:
                return Response(status_code=400)  # Bad Request

            if declared_size > self.max_bytes:
                return Response(status_code=413)  # Request Entity Too Large

        try:
            return await call_next(request)
        except Exception:
            logger.exception("An error occurred while processing the request")
            return JSONResponse(
                status_code=500,
                content={"detail": "An unexpected error occurred."},
            )
109
+
110
+
111
class RestrictFileUploadsMiddleware(BaseHTTPMiddleware):
    """Restrict multipart file uploads to an explicit set of paths."""

    def __init__(self, app: FastAPI, allowed_paths: Set[str]) -> None:
        """Restrict file uploads to certain paths.

        Args:
            app: The FastAPI app.
            allowed_paths: The allowed paths.
        """
        super().__init__(app)
        self.allowed_paths = allowed_paths

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Restrict file uploads to certain paths.

        Only `POST` requests are inspected. A request counts as a file upload
        when its `Content-Type` header contains `multipart/form-data`.

        Args:
            request: The incoming request.
            call_next: The next function to be called.

        Returns:
            A 403 response for a multipart upload to a path that is not in
            the allowed set, a generic 500 response if the downstream handler
            raises, and the downstream response otherwise.
        """
        if request.method == "POST":
            # Media types are case-insensitive, so normalize the header before
            # matching. Otherwise e.g. `Multipart/Form-Data` would slip past
            # this check.
            content_type = request.headers.get("content-type", "").lower()
            if (
                "multipart/form-data" in content_type
                and request.url.path not in self.allowed_paths
            ):
                return JSONResponse(
                    status_code=403,
                    content={
                        "detail": "File uploads are not allowed on this endpoint."
                    },
                )

        try:
            return await call_next(request)
        except Exception:
            logger.exception("An error occurred while processing the request")
            return JSONResponse(
                status_code=500,
                content={"detail": "An unexpected error occurred."},
            )
157
+
158
+
159
# Request paths on which `multipart/form-data` POST uploads are accepted by
# `RestrictFileUploadsMiddleware`. It is empty here, so unless paths are added
# elsewhere, such uploads are rejected on every endpoint.
ALLOWED_FOR_FILE_UPLOAD: Set[str] = set()
160
+
161
+
162
async def track_last_user_activity(request: Request, call_next: Any) -> Any:
    """Middleware that keeps track of the last user activity.

    For user requests, the in-memory last activity timestamp is refreshed.
    When more than the configured reporting interval has passed since the
    last report, the timestamp is also written to the store from a worker
    thread. Any failure in this bookkeeping is logged at debug level and never
    prevents the request from being processed.

    Args:
        request: The incoming request object.
        call_next: A function that will receive the request as a parameter and
            pass it to the corresponding path operation.

    Returns:
        The response to the request.
    """
    global last_user_activity
    global last_user_activity_reported

    now = utc_now()

    try:
        if is_user_request(request):
            should_report = False
            async with last_user_activity_lock:
                last_user_activity = now
                seconds_since_report = (
                    now - last_user_activity_reported
                ).total_seconds()
                if (
                    seconds_since_report
                    > DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS
                ):
                    last_user_activity_reported = now
                    should_report = True

            if should_report:
                # The store update is a blocking call. Running it directly
                # here would stall the asyncio event loop, so it is handed
                # off to a worker thread instead.
                ctx = request_manager().current_request

                def persist_last_user_activity() -> None:
                    logger.debug(
                        f"[{ctx.log_request_id}] API STATS - "
                        f"{ctx.log_request} "
                        f"UPDATING LAST USER ACTIVITY "
                        f"{get_system_metrics_log_str(ctx.request)}"
                    )

                    try:
                        zen_store()._update_last_user_activity_timestamp(
                            last_user_activity=last_user_activity,
                        )
                    finally:
                        logger.debug(
                            f"[{ctx.log_request_id}] API STATS - "
                            f"{ctx.log_request} "
                            f"UPDATED LAST USER ACTIVITY "
                            f"{get_system_metrics_log_str(ctx.request)}"
                        )

                await to_thread.run_sync(
                    persist_last_user_activity,
                    limiter=last_user_activity_thread_limiter,
                )

    except Exception as e:
        logger.debug(
            f"An unexpected error occurred while checking user activity: {e}"
        )

    try:
        return await call_next(request)
    except Exception:
        logger.exception("An error occurred while processing the request")
        return JSONResponse(
            status_code=500,
            content={"detail": "An unexpected error occurred."},
        )
240
+
241
+
242
async def infer_source_context(request: Request, call_next: Any) -> Any:
    """Middleware that applies the source context sent by the client.

    The source context is read from the request header named after the ZenML
    source context variable and applied on the API side, so that outgoing
    analytics requests can include it as an additional field. A missing header
    means `API`. If the value cannot be converted, a warning is logged and
    `API` is used.

    Args:
        request: the incoming request object.
        call_next: a function that will receive the request as a parameter and
            pass it to the corresponding path operation.

    Returns:
        the response to the request.
    """
    try:
        header_value = request.headers.get(
            source_context.name, SourceContextTypes.API.value
        )
        source_context.set(SourceContextTypes(header_value))
    except Exception as e:
        logger.warning(
            f"An unexpected error occurred while getting the source "
            f"context: {e}"
        )
        source_context.set(SourceContextTypes.API)

    try:
        return await call_next(request)
    except Exception:
        logger.exception("An error occurred while processing the request")
        return JSONResponse(
            status_code=500,
            content={"detail": "An unexpected error occurred."},
        )
278
+
279
+
280
async def set_secure_headers(request: Request, call_next: Any) -> Any:
    """Middleware that adds the secure headers to outgoing responses.

    If the downstream handler raises, the error is logged and a generic 500
    response is used in its place. Responses for the OpenAPI docs paths
    (`/docs`, `/redoc`) are returned without secure headers.

    Args:
        request: The incoming request.
        call_next: The next function to be called.

    Returns:
        The response with secure headers set.
    """
    try:
        response = await call_next(request)
    except Exception:
        logger.exception("An error occurred while processing the request")
        response = JSONResponse(
            status_code=500,
            content={"detail": "An unexpected error occurred."},
        )

    # Leave the OpenAPI docs pages untouched
    if not request.url.path.startswith(("/docs", "/redoc")):
        secure_headers().framework.fastapi(response)

    return response
307
+
308
+
309
async def log_requests(request: Request, call_next: Any) -> Any:
    """Middleware that writes debug statistics for every request.

    When DEBUG logging is disabled, the request is forwarded untouched and the
    active request counter is not maintained.

    NOTE(review): the module-level `logger` is imported from `asyncio.log`, so
    the DEBUG check below follows the level of the `asyncio` logger. Confirm
    that this is intended.

    Args:
        request: The incoming request object.
        call_next: A function that will receive the request as a parameter and
            pass it to the corresponding path operation.

    Returns:
        The response to the request.
    """
    global active_requests_count

    debug_enabled = logger.isEnabledFor(logging.DEBUG)
    if not debug_enabled:
        return await call_next(request)

    async with active_requests_lock:
        active_requests_count += 1

    ctx = request_manager().current_request

    logger.debug(
        f"[{ctx.log_request_id}] API STATS - "
        f"{ctx.log_request} "
        f"RECEIVED {get_system_metrics_log_str(request)}"
    )

    try:
        response = await call_next(request)

        logger.debug(
            f"[{ctx.log_request_id}] API STATS - "
            f"{response.status_code} {ctx.log_request} "
            f"took {ctx.log_duration} "
            f"{get_system_metrics_log_str(request)}"
        )

        return response
    finally:
        # Always release the slot, even if the downstream handler raised
        async with active_requests_lock:
            active_requests_count -= 1
350
+
351
+
352
async def record_requests(request: Request, call_next: Any) -> Any:
    """Middleware that registers the request with the request manager.

    A new request context is created for the incoming request and stored as
    the request manager's current request before the request is forwarded. If
    the downstream handler raises, the error is logged and a generic 500
    response is returned.

    Args:
        request: The incoming request object.
        call_next: A function that will receive the request as a parameter and
            pass it to the corresponding path operation.

    Returns:
        The response to the request.
    """
    # Make the request context available to everything further down the chain
    request_manager().current_request = RequestContext(request=request)

    try:
        return await call_next(request)
    except Exception:
        logger.exception("An error occurred while processing the request")
        return JSONResponse(
            status_code=500,
            content={"detail": "An unexpected error occurred."},
        )
377
+
378
+
379
async def skip_health_middleware(request: Request, call_next: Any) -> Any:
    """Answer health and ready requests without further processing.

    Requests to the health and ready paths get a plain-text `ok` response
    right away. All other requests are forwarded.

    Args:
        request: The incoming request.
        call_next: The next function to be called.

    Returns:
        The response to the request.
    """
    is_probe = request.url.path in (HEALTH, READY)
    if not is_probe:
        return await call_next(request)

    # Skip expensive processing
    return PlainTextResponse("ok")
394
+
395
+
396
def add_middlewares(app: FastAPI) -> None:
    """Register all server middlewares on the FastAPI app.

    The registration order is significant and must be preserved. Starlette
    wraps the app with each newly added middleware, so the middleware that is
    registered last is the first one to see an incoming request.

    Args:
        app: The FastAPI app.
    """
    for dispatch in (track_last_user_activity, infer_source_context):
        app.add_middleware(BaseHTTPMiddleware, dispatch=dispatch)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=server_config().cors_allow_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    app.add_middleware(
        RequestBodyLimit,
        max_bytes=server_config().max_request_body_size_in_bytes,
    )
    app.add_middleware(
        RestrictFileUploadsMiddleware, allowed_paths=ALLOWED_FOR_FILE_UPLOAD
    )

    for dispatch in (
        set_secure_headers,
        log_requests,
        record_requests,
        skip_health_middleware,
    ):
        app.add_middleware(BaseHTTPMiddleware, dispatch=dispatch)
@@ -25,7 +25,6 @@ from zenml.models import (
25
25
  ProjectScopedFilter,
26
26
  UserScopedRequest,
27
27
  )
28
- from zenml.zen_server.auth import get_auth_context
29
28
  from zenml.zen_server.feature_gate.endpoint_utils import (
30
29
  check_entitlement,
31
30
  report_usage,
@@ -42,7 +41,11 @@ from zenml.zen_server.rbac.utils import (
42
41
  verify_permission,
43
42
  verify_permission_for_model,
44
43
  )
45
- from zenml.zen_server.utils import server_config, set_filter_project_scope
44
+ from zenml.zen_server.utils import (
45
+ get_auth_context,
46
+ server_config,
47
+ set_filter_project_scope,
48
+ )
46
49
 
47
50
  AnyRequest = TypeVar("AnyRequest", bound=BaseRequest)
48
51
  AnyResponse = TypeVar("AnyResponse", bound=BaseIdentifiedResponse) # type: ignore[type-arg]
@@ -37,9 +37,8 @@ from zenml.models import (
37
37
  UserResponse,
38
38
  UserScopedResponse,
39
39
  )
40
- from zenml.zen_server.auth import get_auth_context
41
40
  from zenml.zen_server.rbac.models import Action, Resource, ResourceType
42
- from zenml.zen_server.utils import rbac, server_config
41
+ from zenml.zen_server.utils import get_auth_context, rbac, server_config
43
42
 
44
43
  if TYPE_CHECKING:
45
44
  from zenml.zen_stores.schemas import BaseSchema
@@ -120,12 +119,16 @@ def dehydrate_response_model(
120
119
  )
121
120
 
122
121
  dehydrated_values = {}
122
+ skip_dehydration = getattr(model, "__zenml_skip_dehydration__", [])
123
123
  # See `get_subresources_for_model(...)` for a detailed explanation why we
124
124
  # need to use `model.__iter__()` here
125
125
  for key, value in model.__iter__():
126
- dehydrated_values[key] = _dehydrate_value(
127
- value, permissions=permissions
128
- )
126
+ if key in skip_dehydration:
127
+ dehydrated_values[key] = value
128
+ else:
129
+ dehydrated_values[key] = _dehydrate_value(
130
+ value, permissions=permissions
131
+ )
129
132
 
130
133
  return type(model).model_validate(dehydrated_values)
131
134
 
@@ -579,8 +582,10 @@ def get_subresources_for_model(
579
582
  for item in model:
580
583
  resources.update(_get_subresources_for_value(item))
581
584
  else:
582
- for _, value in model.__iter__():
583
- resources.update(_get_subresources_for_value(value))
585
+ skip_dehydration = getattr(model, "__zenml_skip_dehydration__", [])
586
+ for key, value in model.__iter__():
587
+ if key not in skip_dehydration:
588
+ resources.update(_get_subresources_for_value(value))
584
589
 
585
590
  return resources
586
591