zenml-nightly 0.83.0.dev20250618__py3-none-any.whl → 0.83.0.dev20250621__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. zenml/VERSION +1 -1
  2. zenml/__init__.py +12 -2
  3. zenml/analytics/context.py +4 -2
  4. zenml/config/server_config.py +6 -1
  5. zenml/constants.py +3 -0
  6. zenml/entrypoints/step_entrypoint_configuration.py +14 -0
  7. zenml/models/__init__.py +15 -0
  8. zenml/models/v2/core/api_transaction.py +193 -0
  9. zenml/models/v2/core/pipeline_build.py +4 -0
  10. zenml/models/v2/core/pipeline_deployment.py +8 -1
  11. zenml/models/v2/core/pipeline_run.py +7 -0
  12. zenml/models/v2/core/step_run.py +6 -0
  13. zenml/orchestrators/input_utils.py +34 -11
  14. zenml/utils/json_utils.py +1 -1
  15. zenml/zen_server/auth.py +53 -31
  16. zenml/zen_server/cloud_utils.py +19 -7
  17. zenml/zen_server/middleware.py +424 -0
  18. zenml/zen_server/rbac/endpoint_utils.py +5 -2
  19. zenml/zen_server/rbac/utils.py +12 -7
  20. zenml/zen_server/request_management.py +556 -0
  21. zenml/zen_server/routers/auth_endpoints.py +1 -0
  22. zenml/zen_server/routers/model_versions_endpoints.py +3 -3
  23. zenml/zen_server/routers/models_endpoints.py +3 -3
  24. zenml/zen_server/routers/pipeline_builds_endpoints.py +2 -2
  25. zenml/zen_server/routers/pipeline_deployments_endpoints.py +9 -4
  26. zenml/zen_server/routers/pipelines_endpoints.py +4 -4
  27. zenml/zen_server/routers/run_templates_endpoints.py +3 -3
  28. zenml/zen_server/routers/runs_endpoints.py +4 -4
  29. zenml/zen_server/routers/service_connectors_endpoints.py +6 -6
  30. zenml/zen_server/routers/steps_endpoints.py +3 -3
  31. zenml/zen_server/utils.py +230 -63
  32. zenml/zen_server/zen_server_api.py +34 -399
  33. zenml/zen_stores/migrations/versions/3d7e39f3ac92_split_up_step_configurations.py +138 -0
  34. zenml/zen_stores/migrations/versions/857843db1bcf_add_api_transaction_table.py +69 -0
  35. zenml/zen_stores/rest_zen_store.py +52 -42
  36. zenml/zen_stores/schemas/__init__.py +4 -0
  37. zenml/zen_stores/schemas/api_transaction_schemas.py +141 -0
  38. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +88 -27
  39. zenml/zen_stores/schemas/pipeline_run_schemas.py +28 -11
  40. zenml/zen_stores/schemas/step_run_schemas.py +4 -4
  41. zenml/zen_stores/sql_zen_store.py +277 -42
  42. zenml/zen_stores/zen_store_interface.py +7 -1
  43. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/METADATA +1 -1
  44. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/RECORD +47 -41
  45. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/LICENSE +0 -0
  46. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/WHEEL +0 -0
  47. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/entry_points.txt +0 -0
@@ -22,10 +22,9 @@ import os
22
22
  import random
23
23
  import re
24
24
  import sys
25
- import threading
26
25
  import time
27
26
  from collections import defaultdict
28
- from datetime import datetime
27
+ from datetime import datetime, timedelta
29
28
  from functools import lru_cache
30
29
  from pathlib import Path
31
30
  from typing import (
@@ -60,7 +59,7 @@ from pydantic import (
60
59
  field_validator,
61
60
  model_validator,
62
61
  )
63
- from sqlalchemy import QueuePool, func
62
+ from sqlalchemy import QueuePool, func, update
64
63
  from sqlalchemy.engine import URL, Engine, make_url
65
64
  from sqlalchemy.exc import (
66
65
  ArgumentError,
@@ -160,6 +159,9 @@ from zenml.models import (
160
159
  APIKeyResponse,
161
160
  APIKeyRotateRequest,
162
161
  APIKeyUpdate,
162
+ ApiTransactionRequest,
163
+ ApiTransactionResponse,
164
+ ApiTransactionUpdate,
163
165
  ArtifactFilter,
164
166
  ArtifactRequest,
165
167
  ArtifactResponse,
@@ -331,6 +333,7 @@ from zenml.zen_stores.migrations.utils import MigrationUtils
331
333
  from zenml.zen_stores.schemas import (
332
334
  ActionSchema,
333
335
  APIKeySchema,
336
+ ApiTransactionSchema,
334
337
  ArtifactSchema,
335
338
  ArtifactVersionSchema,
336
339
  BaseSchema,
@@ -358,6 +361,7 @@ from zenml.zen_stores.schemas import (
358
361
  ServiceConnectorSchema,
359
362
  StackComponentSchema,
360
363
  StackSchema,
364
+ StepConfigurationSchema,
361
365
  StepRunInputArtifactSchema,
362
366
  StepRunOutputArtifactSchema,
363
367
  StepRunParentsSchema,
@@ -426,6 +430,46 @@ def exponential_backoff_with_jitter(
426
430
  class Session(SqlModelSession):
427
431
  """Session subclass that automatically tracks duration and calling context."""
428
432
 
433
+ def _get_metrics(self) -> Dict[str, Any]:
434
+ """Get the metrics for the session.
435
+
436
+ Returns:
437
+ The metrics for the session.
438
+ """
439
+ # Get SQLAlchemy connection pool info
440
+ assert isinstance(self.bind, Engine)
441
+ assert isinstance(self.bind.pool, QueuePool)
442
+ checked_out_connections = self.bind.pool.checkedout()
443
+ available_connections = self.bind.pool.checkedin()
444
+ overflow = self.bind.pool.overflow()
445
+
446
+ return {
447
+ "active_connections": checked_out_connections,
448
+ "idle_connections": available_connections,
449
+ "overflow_connections": overflow,
450
+ }
451
+
452
+ def _get_metrics_log_str(self) -> str:
453
+ """Get the metrics for the session as a string for logging.
454
+
455
+ Returns:
456
+ The metrics for the session as a string for logging.
457
+ """
458
+ if not logger.isEnabledFor(logging.DEBUG):
459
+ return ""
460
+ metrics = self._get_metrics()
461
+ # Add the server metrics if running in a server
462
+ if handle_bool_env_var(ENV_ZENML_SERVER):
463
+ from zenml.zen_server.utils import get_system_metrics
464
+
465
+ metrics.update(get_system_metrics())
466
+
467
+ return (
468
+ " [ "
469
+ + " ".join([f"{key}: {value}" for key, value in metrics.items()])
470
+ + " ]"
471
+ )
472
+
429
473
  def __enter__(self) -> "Session":
430
474
  """Enter the context manager.
431
475
 
@@ -433,15 +477,19 @@ class Session(SqlModelSession):
433
477
  The SqlModel session.
434
478
  """
435
479
  if logger.isEnabledFor(logging.DEBUG):
436
- # Get the request ID from the current thread object
437
- self.request_id = threading.current_thread().name
480
+ self.log_request_id = "N/A"
481
+ self.log_request = ""
482
+ if handle_bool_env_var(ENV_ZENML_SERVER):
483
+ # Running inside server
484
+ from zenml.zen_server.utils import get_current_request_context
438
485
 
439
- # Get SQLAlchemy connection pool info
440
- assert isinstance(self.bind, Engine)
441
- assert isinstance(self.bind.pool, QueuePool)
442
- checked_out_connections = self.bind.pool.checkedout()
443
- available_connections = self.bind.pool.checkedin()
444
- overflow = self.bind.pool.overflow()
486
+ # If the code is running on the server, use the auth context.
487
+ try:
488
+ request_context = get_current_request_context()
489
+ self.log_request_id = request_context.log_request_id
490
+ self.log_request = request_context.log_request
491
+ except RuntimeError:
492
+ pass
445
493
 
446
494
  # Look up the stack to find the SQLZenStore method
447
495
  for frame in inspect.stack():
@@ -456,10 +504,9 @@ class Session(SqlModelSession):
456
504
  self.caller_method = "unknown"
457
505
 
458
506
  logger.debug(
459
- f"[{self.request_id}] SQL STATS - "
460
- f"'{self.caller_method}' started [ conn(active): "
461
- f"{checked_out_connections} conn(idle): "
462
- f"{available_connections} conn(overflow): {overflow} ]"
507
+ f"[{self.log_request_id}] SQL STATS - "
508
+ f"{self.log_request} "
509
+ f"'{self.caller_method}' STARTED {self._get_metrics_log_str()}"
463
510
  )
464
511
 
465
512
  self.start_time = time.time()
@@ -482,19 +529,18 @@ class Session(SqlModelSession):
482
529
  if logger.isEnabledFor(logging.DEBUG):
483
530
  duration = (time.time() - self.start_time) * 1000
484
531
 
485
- # Get SQLAlchemy connection pool info
486
- assert isinstance(self.bind, Engine)
487
- assert isinstance(self.bind.pool, QueuePool)
488
- checked_out_connections = self.bind.pool.checkedout()
489
- available_connections = self.bind.pool.checkedin()
490
- overflow = self.bind.pool.overflow()
532
+ # Add error information to the log
533
+ error_info = ""
534
+ if exc_type is not None:
535
+ error_info = " with ERROR"
536
+
491
537
  logger.debug(
492
- f"[{self.request_id}] SQL STATS - "
493
- f"'{self.caller_method}' completed in "
494
- f"{duration:.2f}ms [ conn(active): "
495
- f"{checked_out_connections} conn(idle): "
496
- f"{available_connections} conn(overflow): {overflow} ]"
538
+ f"[{self.log_request_id}] SQL STATS - "
539
+ f"{self.log_request} "
540
+ f"'{self.caller_method}' COMPLETED in "
541
+ f"{duration:.2f}ms {error_info} {self._get_metrics_log_str()}"
497
542
  )
543
+
498
544
  super().__exit__(exc_type, exc_val, exc_tb)
499
545
 
500
546
 
@@ -2456,6 +2502,178 @@ class SqlZenStore(BaseZenStore):
2456
2502
  session.delete(api_key)
2457
2503
  session.commit()
2458
2504
 
2505
+ # -------------------- API Transactions --------------------
2506
+
2507
+ def _get_api_transaction(
2508
+ self,
2509
+ api_transaction_id: UUID,
2510
+ session: Session,
2511
+ method: Optional[str] = None,
2512
+ url: Optional[str] = None,
2513
+ ) -> ApiTransactionSchema:
2514
+ """Retrieve or create a new API transaction.
2515
+
2516
+ Args:
2517
+ api_transaction_id: The ID of the API transaction to retrieve.
2518
+ session: The session to use for the query.
2519
+ method: The HTTP method of the API transaction.
2520
+ url: The URL of the API transaction.
2521
+
2522
+ Returns:
2523
+ The API transaction.
2524
+
2525
+ Raises:
2526
+ KeyError: If the API transaction does not exist.
2527
+ EntityExistsError: If the API transaction exists but is not owned by
2528
+ the current user.
2529
+ """
2530
+ api_transaction_schema = session.exec(
2531
+ select(ApiTransactionSchema).where(
2532
+ ApiTransactionSchema.id == api_transaction_id
2533
+ )
2534
+ ).first()
2535
+
2536
+ if not api_transaction_schema:
2537
+ raise KeyError(
2538
+ f"API transaction with ID {api_transaction_id} not found."
2539
+ )
2540
+
2541
+ # As a security measure, we don't allow users to access other users'
2542
+ # API transactions.
2543
+ if (
2544
+ api_transaction_schema.user_id
2545
+ != self._get_active_user(session=session).id
2546
+ ):
2547
+ raise EntityExistsError(
2548
+ f"Unable to create API transaction with ID "
2549
+ f"{api_transaction_id}: A transaction with "
2550
+ "the same ID already exists for a different user."
2551
+ )
2552
+
2553
+ # As another security measure, we don't allow the same transaction
2554
+ # ID to be used with different method or URL.
2555
+ if (
2556
+ method is not None
2557
+ and api_transaction_schema.method != method
2558
+ or url is not None
2559
+ and api_transaction_schema.url != url
2560
+ ):
2561
+ raise EntityExistsError(
2562
+ f"Unable to get API transaction with ID "
2563
+ f"{api_transaction_id}: A transaction with "
2564
+ "the same ID already exists with a different method or URL."
2565
+ )
2566
+
2567
+ return api_transaction_schema
2568
+
2569
+ def _cleanup_expired_api_transactions(self, session: Session) -> None:
2570
+ """Delete completed API transactions that have expired.
2571
+
2572
+ Args:
2573
+ session: The session to use for the query.
2574
+ """
2575
+ session.execute(
2576
+ delete(ApiTransactionSchema).where(
2577
+ col(ApiTransactionSchema.completed),
2578
+ col(ApiTransactionSchema.expired) < utc_now(),
2579
+ )
2580
+ )
2581
+
2582
+ def get_or_create_api_transaction(
2583
+ self, api_transaction: ApiTransactionRequest
2584
+ ) -> Tuple[ApiTransactionResponse, bool]:
2585
+ """Retrieve or create a new API transaction.
2586
+
2587
+ Args:
2588
+ api_transaction: The API transaction to retrieve or create.
2589
+
2590
+ Returns:
2591
+ The API transaction and a boolean indicating whether the transaction
2592
+ was created.
2593
+ """
2594
+ with Session(self.engine) as session:
2595
+ self._set_request_user_id(
2596
+ request_model=api_transaction, session=session
2597
+ )
2598
+
2599
+ api_transaction_schema = ApiTransactionSchema.from_request(
2600
+ api_transaction
2601
+ )
2602
+ session.add(api_transaction_schema)
2603
+ created = False
2604
+ try:
2605
+ session.commit()
2606
+ session.refresh(api_transaction_schema)
2607
+ created = True
2608
+ except IntegrityError:
2609
+ # We have to rollback the failed session first in order to
2610
+ # continue using it
2611
+ session.rollback()
2612
+ api_transaction_schema = self._get_api_transaction(
2613
+ api_transaction_id=api_transaction_schema.id,
2614
+ method=api_transaction.method,
2615
+ url=api_transaction.url,
2616
+ session=session,
2617
+ )
2618
+
2619
+ return (
2620
+ api_transaction_schema.to_model(
2621
+ include_metadata=True, include_resources=True
2622
+ ),
2623
+ created,
2624
+ )
2625
+
2626
+ def finalize_api_transaction(
2627
+ self,
2628
+ api_transaction_id: UUID,
2629
+ api_transaction_update: ApiTransactionUpdate,
2630
+ ) -> None:
2631
+ """Finalize an API transaction.
2632
+
2633
+ Args:
2634
+ api_transaction_id: The ID of the API transaction to update.
2635
+ api_transaction_update: The update to be applied to the API transaction.
2636
+
2637
+ Raises:
2638
+ KeyError: If the API transaction is not found.
2639
+ """
2640
+ with Session(self.engine) as session:
2641
+ updated = utc_now()
2642
+ expired = updated + timedelta(
2643
+ seconds=api_transaction_update.cache_time
2644
+ )
2645
+ result = session.execute(
2646
+ update(ApiTransactionSchema)
2647
+ .where(col(ApiTransactionSchema.id) == api_transaction_id)
2648
+ .values(
2649
+ completed=True,
2650
+ updated=updated,
2651
+ expired=expired,
2652
+ result=api_transaction_update.get_result(),
2653
+ )
2654
+ )
2655
+ self._cleanup_expired_api_transactions(session=session)
2656
+ session.commit()
2657
+
2658
+ if result.rowcount == 0: # type: ignore[attr-defined]
2659
+ raise KeyError(
2660
+ f"API transaction with ID {api_transaction_id} not found."
2661
+ )
2662
+
2663
+ def delete_api_transaction(self, api_transaction_id: UUID) -> None:
2664
+ """Delete an API transaction.
2665
+
2666
+ Args:
2667
+ api_transaction_id: The ID of the API transaction to delete.
2668
+ """
2669
+ with Session(self.engine) as session:
2670
+ session.execute(
2671
+ delete(ApiTransactionSchema).where(
2672
+ col(ApiTransactionSchema.id) == api_transaction_id
2673
+ )
2674
+ )
2675
+ session.commit()
2676
+
2459
2677
  # -------------------- Services --------------------
2460
2678
 
2461
2679
  @staticmethod
@@ -4649,6 +4867,23 @@ class SqlZenStore(BaseZenStore):
4649
4867
  )
4650
4868
  session.add(new_deployment)
4651
4869
  session.commit()
4870
+
4871
+ for index, (step_name, step_configuration) in enumerate(
4872
+ deployment.step_configurations.items()
4873
+ ):
4874
+ step_configuration_schema = StepConfigurationSchema(
4875
+ index=index,
4876
+ name=step_name,
4877
+ # Don't include the merged config in the step
4878
+ # configurations, we reconstruct it in the `to_model` method
4879
+ # using the pipeline configuration.
4880
+ config=step_configuration.model_dump_json(
4881
+ exclude={"config"}
4882
+ ),
4883
+ deployment_id=new_deployment.id,
4884
+ )
4885
+ session.add(step_configuration_schema)
4886
+ session.commit()
4652
4887
  session.refresh(new_deployment)
4653
4888
 
4654
4889
  return new_deployment.to_model(
@@ -4656,7 +4891,10 @@ class SqlZenStore(BaseZenStore):
4656
4891
  )
4657
4892
 
4658
4893
  def get_deployment(
4659
- self, deployment_id: UUID, hydrate: bool = True
4894
+ self,
4895
+ deployment_id: UUID,
4896
+ hydrate: bool = True,
4897
+ step_configuration_filter: Optional[List[str]] = None,
4660
4898
  ) -> PipelineDeploymentResponse:
4661
4899
  """Get a deployment with a given ID.
4662
4900
 
@@ -4664,6 +4902,9 @@ class SqlZenStore(BaseZenStore):
4664
4902
  deployment_id: ID of the deployment.
4665
4903
  hydrate: Flag deciding whether to hydrate the output model(s)
4666
4904
  by including metadata fields in the response.
4905
+ step_configuration_filter: List of step configurations to include in
4906
+ the response. If not given, all step configurations will be
4907
+ included.
4667
4908
 
4668
4909
  Returns:
4669
4910
  The deployment.
@@ -4677,7 +4918,9 @@ class SqlZenStore(BaseZenStore):
4677
4918
  )
4678
4919
 
4679
4920
  return deployment.to_model(
4680
- include_metadata=hydrate, include_resources=True
4921
+ include_metadata=hydrate,
4922
+ include_resources=True,
4923
+ step_configuration_filter=step_configuration_filter,
4681
4924
  )
4682
4925
 
4683
4926
  def list_deployments(
@@ -5110,12 +5353,11 @@ class SqlZenStore(BaseZenStore):
5110
5353
  )
5111
5354
 
5112
5355
  steps = {
5113
- step_name: Step.from_dict(
5114
- config_dict, pipeline_configuration=pipeline_configuration
5356
+ config_table.name: Step.from_dict(
5357
+ json.loads(config_table.config),
5358
+ pipeline_configuration=pipeline_configuration,
5115
5359
  )
5116
- for step_name, config_dict in json.loads(
5117
- deployment.step_configurations
5118
- ).items()
5360
+ for config_table in deployment.get_step_configurations()
5119
5361
  }
5120
5362
  regular_output_artifact_nodes: Dict[
5121
5363
  str, Dict[str, PipelineRunDAG.Node]
@@ -8943,11 +9185,6 @@ class SqlZenStore(BaseZenStore):
8943
9185
  pipeline_run = session.exec(
8944
9186
  select(PipelineRunSchema)
8945
9187
  .with_for_update()
8946
- .options(
8947
- joinedload(
8948
- jl_arg(PipelineRunSchema.deployment), innerjoin=True
8949
- )
8950
- )
8951
9188
  .where(PipelineRunSchema.id == pipeline_run_id)
8952
9189
  ).one()
8953
9190
  step_run_statuses = session.exec(
@@ -8958,9 +9195,7 @@ class SqlZenStore(BaseZenStore):
8958
9195
 
8959
9196
  # Deployment always exists for pipeline runs of newer versions
8960
9197
  assert pipeline_run.deployment
8961
- num_steps = len(
8962
- json.loads(pipeline_run.deployment.step_configurations)
8963
- )
9198
+ num_steps = pipeline_run.deployment.step_count
8964
9199
  new_status = get_pipeline_run_status(
8965
9200
  step_statuses=[
8966
9201
  ExecutionStatus(status) for status in step_run_statuses
@@ -9499,7 +9734,7 @@ class SqlZenStore(BaseZenStore):
9499
9734
  """
9500
9735
  if handle_bool_env_var(ENV_ZENML_SERVER):
9501
9736
  # Running inside server
9502
- from zenml.zen_server.auth import get_auth_context
9737
+ from zenml.zen_server.utils import get_auth_context
9503
9738
 
9504
9739
  # If the code is running on the server, use the auth context.
9505
9740
  auth_context = get_auth_context()
@@ -1296,7 +1296,10 @@ class ZenStoreInterface(ABC):
1296
1296
 
1297
1297
  @abstractmethod
1298
1298
  def get_deployment(
1299
- self, deployment_id: UUID, hydrate: bool = True
1299
+ self,
1300
+ deployment_id: UUID,
1301
+ hydrate: bool = True,
1302
+ step_configuration_filter: Optional[List[str]] = None,
1300
1303
  ) -> PipelineDeploymentResponse:
1301
1304
  """Get a deployment with a given ID.
1302
1305
 
@@ -1304,6 +1307,9 @@ class ZenStoreInterface(ABC):
1304
1307
  deployment_id: ID of the deployment.
1305
1308
  hydrate: Flag deciding whether to hydrate the output model(s)
1306
1309
  by including metadata fields in the response.
1310
+ step_configuration_filter: List of step configurations to include in
1311
+ the response. If not given, all step configurations will be
1312
+ included.
1307
1313
 
1308
1314
  Returns:
1309
1315
  The deployment.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: zenml-nightly
3
- Version: 0.83.0.dev20250618
3
+ Version: 0.83.0.dev20250621
4
4
  Summary: ZenML: Write production-ready ML code.
5
5
  License: Apache-2.0
6
6
  Keywords: machine learning,production,pipeline,mlops,devops