zenml-nightly 0.71.0.dev20250105__py3-none-any.whl → 0.71.0.dev20250109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zenml/VERSION CHANGED
@@ -1 +1 @@
1
- 0.71.0.dev20250105
1
+ 0.71.0.dev20250109
@@ -490,15 +490,14 @@ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter):
490
490
  Returns:
491
491
  The query with filter applied.
492
492
  """
493
- from zenml.zen_stores.schemas import TagResourceSchema
493
+ from zenml.zen_stores.schemas import TagResourceSchema, TagSchema
494
494
 
495
495
  query = super().apply_filter(query=query, table=table)
496
496
  if self.tag:
497
- query = (
498
- query.join(getattr(table, "tags"))
499
- .join(TagResourceSchema.tag)
500
- .distinct()
501
- )
497
+ query = query.join(
498
+ TagResourceSchema,
499
+ TagResourceSchema.resource_id == getattr(table, "id"),
500
+ ).join(TagSchema, TagSchema.id == TagResourceSchema.tag_id)
502
501
 
503
502
  return query
504
503
 
zenml/zen_server/auth.py CHANGED
@@ -37,7 +37,12 @@ from zenml.constants import (
37
37
  LOGIN,
38
38
  VERSION_1,
39
39
  )
40
- from zenml.enums import AuthScheme, ExecutionStatus, OAuthDeviceStatus
40
+ from zenml.enums import (
41
+ AuthScheme,
42
+ ExecutionStatus,
43
+ OAuthDeviceStatus,
44
+ OnboardingStep,
45
+ )
41
46
  from zenml.exceptions import (
42
47
  AuthorizationException,
43
48
  CredentialsNotValid,
@@ -630,12 +635,15 @@ def authenticate_device(client_id: UUID, device_code: str) -> AuthContext:
630
635
  return AuthContext(user=device_model.user, device=device_model)
631
636
 
632
637
 
633
- def authenticate_external_user(external_access_token: str) -> AuthContext:
638
+ def authenticate_external_user(
639
+ external_access_token: str, request: Request
640
+ ) -> AuthContext:
634
641
  """Implement external authentication.
635
642
 
636
643
  Args:
637
644
  external_access_token: The access token used to authenticate the user
638
645
  to the external authenticator.
646
+ request: The request object.
639
647
 
640
648
  Returns:
641
649
  The authentication context reflecting the authenticated user.
@@ -761,6 +769,17 @@ def authenticate_external_user(external_access_token: str) -> AuthContext:
761
769
  )
762
770
  context.alias(user_id=external_user.id, previous_id=user.id)
763
771
 
772
+ # This is the best spot to update the onboarding state to mark the
773
+ # "zenml login" step as completed for ZenML Pro servers, because the
774
+ # user has just successfully logged in. However, we need to differentiate
775
+ # between web clients (i.e. the dashboard) and CLI clients (i.e. the
776
+ # zenml CLI).
777
+ user_agent = request.headers.get("User-Agent", "").lower()
778
+ if "zenml/" in user_agent:
779
+ store.update_onboarding_state(
780
+ completed_steps={OnboardingStep.DEVICE_VERIFIED}
781
+ )
782
+
764
783
  return AuthContext(user=user)
765
784
 
766
785
 
@@ -287,7 +287,8 @@ def token(
287
287
  return OAuthRedirectResponse(authorization_url=authorization_url)
288
288
 
289
289
  auth_context = authenticate_external_user(
290
- external_access_token=external_access_token
290
+ external_access_token=external_access_token,
291
+ request=request,
291
292
  )
292
293
 
293
294
  else:
@@ -4053,7 +4053,12 @@ class RestZenStore(BaseZenStore):
4053
4053
  )
4054
4054
 
4055
4055
  data: Optional[Dict[str, str]] = None
4056
- headers: Dict[str, str] = {}
4056
+
4057
+ # Use a custom user agent to identify the ZenML client in the server
4058
+ # logs.
4059
+ headers: Dict[str, str] = {
4060
+ "User-Agent": "zenml/" + zenml.__version__,
4061
+ }
4057
4062
 
4058
4063
  # Check if an API key is configured
4059
4064
  api_key = credentials_store.get_api_key(self.url)
@@ -4218,6 +4223,11 @@ class RestZenStore(BaseZenStore):
4218
4223
  self._session.mount("https://", HTTPAdapter(max_retries=retries))
4219
4224
  self._session.mount("http://", HTTPAdapter(max_retries=retries))
4220
4225
  self._session.verify = self.config.verify_ssl
4226
+ # Use a custom user agent to identify the ZenML client in the server
4227
+ # logs.
4228
+ self._session.headers.update(
4229
+ {"User-Agent": "zenml/" + zenml.__version__}
4230
+ )
4221
4231
 
4222
4232
  # Note that we return an unauthenticated session here. An API token
4223
4233
  # is only fetched and set in the authorization header when and if it is
@@ -59,10 +59,8 @@ if TYPE_CHECKING:
59
59
  from zenml.zen_stores.schemas.model_schemas import (
60
60
  ModelVersionArtifactSchema,
61
61
  )
62
- from zenml.zen_stores.schemas.run_metadata_schemas import (
63
- RunMetadataResourceSchema,
64
- )
65
- from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
62
+ from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
63
+ from zenml.zen_stores.schemas.tag_schemas import TagSchema
66
64
 
67
65
 
68
66
  class ArtifactSchema(NamedSchema, table=True):
@@ -82,11 +80,12 @@ class ArtifactSchema(NamedSchema, table=True):
82
80
  back_populates="artifact",
83
81
  sa_relationship_kwargs={"cascade": "delete"},
84
82
  )
85
- tags: List["TagResourceSchema"] = Relationship(
86
- back_populates="artifact",
83
+ tags: List["TagSchema"] = Relationship(
87
84
  sa_relationship_kwargs=dict(
88
- primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)",
89
- cascade="delete",
85
+ primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)",
86
+ secondary="tag_resource",
87
+ secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
88
+ order_by="TagSchema.name",
90
89
  overlaps="tags",
91
90
  ),
92
91
  )
@@ -136,7 +135,7 @@ class ArtifactSchema(NamedSchema, table=True):
136
135
  body = ArtifactResponseBody(
137
136
  created=self.created,
138
137
  updated=self.updated,
139
- tags=[t.tag.to_model() for t in self.tags],
138
+ tags=[tag.to_model() for tag in self.tags],
140
139
  latest_version_name=latest_name,
141
140
  latest_version_id=latest_id,
142
141
  )
@@ -192,11 +191,12 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True):
192
191
  uri: str = Field(sa_column=Column(TEXT, nullable=False))
193
192
  materializer: str = Field(sa_column=Column(TEXT, nullable=False))
194
193
  data_type: str = Field(sa_column=Column(TEXT, nullable=False))
195
- tags: List["TagResourceSchema"] = Relationship(
196
- back_populates="artifact_version",
194
+ tags: List["TagSchema"] = Relationship(
197
195
  sa_relationship_kwargs=dict(
198
- primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)",
199
- cascade="delete",
196
+ primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)",
197
+ secondary="tag_resource",
198
+ secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
199
+ order_by="TagSchema.name",
200
200
  overlaps="tags",
201
201
  ),
202
202
  )
@@ -244,12 +244,12 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True):
244
244
  workspace: "WorkspaceSchema" = Relationship(
245
245
  back_populates="artifact_versions"
246
246
  )
247
- run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship(
248
- back_populates="artifact_versions",
247
+ run_metadata: List["RunMetadataSchema"] = Relationship(
249
248
  sa_relationship_kwargs=dict(
250
- primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)",
251
- cascade="delete",
252
- overlaps="run_metadata_resources",
249
+ secondary="run_metadata_resource",
250
+ primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)",
251
+ secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)",
252
+ overlaps="run_metadata",
253
253
  ),
254
254
  )
255
255
  output_of_step_runs: List["StepRunOutputArtifactSchema"] = Relationship(
@@ -365,7 +365,7 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True):
365
365
  data_type=data_type,
366
366
  created=self.created,
367
367
  updated=self.updated,
368
- tags=[t.tag.to_model() for t in self.tags],
368
+ tags=[tag.to_model() for tag in self.tags],
369
369
  producer_pipeline_run_id=producer_pipeline_run_id,
370
370
  save_type=ArtifactSaveType(self.save_type),
371
371
  artifact_store_id=self.artifact_store_id,
@@ -56,11 +56,9 @@ from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema
56
56
  from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema
57
57
  from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME
58
58
  from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
59
- from zenml.zen_stores.schemas.run_metadata_schemas import (
60
- RunMetadataResourceSchema,
61
- )
59
+ from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
62
60
  from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
63
- from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
61
+ from zenml.zen_stores.schemas.tag_schemas import TagSchema
64
62
  from zenml.zen_stores.schemas.user_schemas import UserSchema
65
63
  from zenml.zen_stores.schemas.utils import (
66
64
  RunMetadataInterface,
@@ -114,11 +112,12 @@ class ModelSchema(NamedSchema, table=True):
114
112
  save_models_to_registry: bool = Field(
115
113
  sa_column=Column(BOOLEAN, nullable=False)
116
114
  )
117
- tags: List["TagResourceSchema"] = Relationship(
118
- back_populates="model",
115
+ tags: List["TagSchema"] = Relationship(
119
116
  sa_relationship_kwargs=dict(
120
- primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
121
- cascade="delete",
117
+ primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
118
+ secondary="tag_resource",
119
+ secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
120
+ order_by="TagSchema.name",
122
121
  overlaps="tags",
123
122
  ),
124
123
  )
@@ -168,7 +167,7 @@ class ModelSchema(NamedSchema, table=True):
168
167
  Returns:
169
168
  The created `ModelResponse`.
170
169
  """
171
- tags = [t.tag.to_model() for t in self.tags]
170
+ tags = [tag.to_model() for tag in self.tags]
172
171
 
173
172
  if self.model_versions:
174
173
  version_numbers = [mv.number for mv in self.model_versions]
@@ -299,11 +298,12 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
299
298
  back_populates="model_version",
300
299
  sa_relationship_kwargs={"cascade": "delete"},
301
300
  )
302
- tags: List["TagResourceSchema"] = Relationship(
303
- back_populates="model_version",
301
+ tags: List["TagSchema"] = Relationship(
304
302
  sa_relationship_kwargs=dict(
305
- primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
306
- cascade="delete",
303
+ primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
304
+ secondary="tag_resource",
305
+ secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
306
+ order_by="TagSchema.name",
307
307
  overlaps="tags",
308
308
  ),
309
309
  )
@@ -316,12 +316,12 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
316
316
  description: str = Field(sa_column=Column(TEXT, nullable=True))
317
317
  stage: str = Field(sa_column=Column(TEXT, nullable=True))
318
318
 
319
- run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship(
320
- back_populates="model_versions",
319
+ run_metadata: List["RunMetadataSchema"] = Relationship(
321
320
  sa_relationship_kwargs=dict(
322
- primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)",
323
- cascade="delete",
324
- overlaps="run_metadata_resources",
321
+ secondary="run_metadata_resource",
322
+ primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)",
323
+ secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)",
324
+ overlaps="run_metadata",
325
325
  ),
326
326
  )
327
327
  pipeline_runs: List["PipelineRunSchema"] = Relationship(
@@ -471,7 +471,7 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
471
471
  data_artifact_ids=data_artifact_ids,
472
472
  deployment_artifact_ids=deployment_artifact_ids,
473
473
  pipeline_run_ids=pipeline_run_ids,
474
- tags=[t.tag.to_model() for t in self.tags],
474
+ tags=[tag.to_model() for tag in self.tags],
475
475
  )
476
476
 
477
477
  return ModelVersionResponse(
@@ -58,12 +58,10 @@ if TYPE_CHECKING:
58
58
  ModelVersionPipelineRunSchema,
59
59
  ModelVersionSchema,
60
60
  )
61
- from zenml.zen_stores.schemas.run_metadata_schemas import (
62
- RunMetadataResourceSchema,
63
- )
61
+ from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
64
62
  from zenml.zen_stores.schemas.service_schemas import ServiceSchema
65
63
  from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema
66
- from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
64
+ from zenml.zen_stores.schemas.tag_schemas import TagSchema
67
65
 
68
66
 
69
67
  class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
@@ -140,12 +138,12 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
140
138
  )
141
139
  workspace: "WorkspaceSchema" = Relationship(back_populates="runs")
142
140
  user: Optional["UserSchema"] = Relationship(back_populates="runs")
143
- run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship(
144
- back_populates="pipeline_runs",
141
+ run_metadata: List["RunMetadataSchema"] = Relationship(
145
142
  sa_relationship_kwargs=dict(
146
- primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)",
147
- cascade="delete",
148
- overlaps="run_metadata_resources",
143
+ secondary="run_metadata_resource",
144
+ primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)",
145
+ secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)",
146
+ overlaps="run_metadata",
149
147
  ),
150
148
  )
151
149
  logs: Optional["LogsSchema"] = Relationship(
@@ -215,10 +213,12 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
215
213
  services: List["ServiceSchema"] = Relationship(
216
214
  back_populates="pipeline_run",
217
215
  )
218
- tags: List["TagResourceSchema"] = Relationship(
216
+ tags: List["TagSchema"] = Relationship(
219
217
  sa_relationship_kwargs=dict(
220
- primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.PIPELINE_RUN.value}', foreign(TagResourceSchema.resource_id)==PipelineRunSchema.id)",
221
- cascade="delete",
218
+ primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE_RUN.value}', foreign(TagResourceSchema.resource_id)==PipelineRunSchema.id)",
219
+ secondary="tag_resource",
220
+ secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
221
+ order_by="TagSchema.name",
222
222
  overlaps="tags",
223
223
  ),
224
224
  )
@@ -291,12 +291,6 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
291
291
  Raises:
292
292
  RuntimeError: if the model creation fails.
293
293
  """
294
- orchestrator_environment = (
295
- json.loads(self.orchestrator_environment)
296
- if self.orchestrator_environment
297
- else {}
298
- )
299
-
300
294
  if self.deployment is not None:
301
295
  deployment = self.deployment.to_model(include_metadata=True)
302
296
 
@@ -377,6 +371,11 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
377
371
  # in the response -> We need to reset the metadata here
378
372
  step.metadata = None
379
373
 
374
+ orchestrator_environment = (
375
+ json.loads(self.orchestrator_environment)
376
+ if self.orchestrator_environment
377
+ else {}
378
+ )
380
379
  metadata = PipelineRunResponseMetadata(
381
380
  workspace=self.workspace.to_model(),
382
381
  run_metadata=self.fetch_metadata(),
@@ -405,7 +404,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
405
404
 
406
405
  resources = PipelineRunResponseResources(
407
406
  model_version=model_version,
408
- tags=[t.tag.to_model() for t in self.tags],
407
+ tags=[tag.to_model() for tag in self.tags],
409
408
  )
410
409
 
411
410
  return PipelineRunResponse(
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
43
43
  )
44
44
  from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
45
45
  from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema
46
- from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
46
+ from zenml.zen_stores.schemas.tag_schemas import TagSchema
47
47
 
48
48
 
49
49
  class PipelineSchema(NamedSchema, table=True):
@@ -95,10 +95,12 @@ class PipelineSchema(NamedSchema, table=True):
95
95
  deployments: List["PipelineDeploymentSchema"] = Relationship(
96
96
  back_populates="pipeline",
97
97
  )
98
- tags: List["TagResourceSchema"] = Relationship(
98
+ tags: List["TagSchema"] = Relationship(
99
99
  sa_relationship_kwargs=dict(
100
- primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.PIPELINE.value}', foreign(TagResourceSchema.resource_id)==PipelineSchema.id)",
101
- cascade="delete",
100
+ primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE.value}', foreign(TagResourceSchema.resource_id)==PipelineSchema.id)",
101
+ secondary="tag_resource",
102
+ secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
103
+ order_by="TagSchema.name",
102
104
  overlaps="tags",
103
105
  ),
104
106
  )
@@ -162,7 +164,7 @@ class PipelineSchema(NamedSchema, table=True):
162
164
  latest_run_user=latest_run_user.to_model()
163
165
  if latest_run_user
164
166
  else None,
165
- tags=[t.tag.to_model() for t in self.tags],
167
+ tags=[tag.to_model() for tag in self.tags],
166
168
  )
167
169
 
168
170
  return PipelineResponse(
@@ -13,13 +13,12 @@
13
13
  # permissions and limitations under the License.
14
14
  """SQLModel implementation of pipeline run metadata tables."""
15
15
 
16
- from typing import TYPE_CHECKING, List, Optional
16
+ from typing import Optional
17
17
  from uuid import UUID, uuid4
18
18
 
19
19
  from sqlalchemy import TEXT, VARCHAR, Column
20
20
  from sqlmodel import Field, Relationship, SQLModel
21
21
 
22
- from zenml.enums import MetadataResourceTypes
23
22
  from zenml.zen_stores.schemas.base_schemas import BaseSchema
24
23
  from zenml.zen_stores.schemas.component_schemas import StackComponentSchema
25
24
  from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
@@ -27,22 +26,12 @@ from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema
27
26
  from zenml.zen_stores.schemas.user_schemas import UserSchema
28
27
  from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema
29
28
 
30
- if TYPE_CHECKING:
31
- from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema
32
- from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema
33
- from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
34
-
35
29
 
36
30
  class RunMetadataSchema(BaseSchema, table=True):
37
31
  """SQL Model for run metadata."""
38
32
 
39
33
  __tablename__ = "run_metadata"
40
34
 
41
- # Relationship to link to resources
42
- resources: List["RunMetadataResourceSchema"] = Relationship(
43
- back_populates="run_metadata",
44
- sa_relationship_kwargs={"cascade": "delete"},
45
- )
46
35
  stack_component_id: Optional[UUID] = build_foreign_key_field(
47
36
  source=__tablename__,
48
37
  target=StackComponentSchema.__tablename__,
@@ -105,36 +94,3 @@ class RunMetadataResourceSchema(SQLModel, table=True):
105
94
  ondelete="CASCADE",
106
95
  nullable=False,
107
96
  )
108
-
109
- # Relationship back to the base metadata table
110
- run_metadata: RunMetadataSchema = Relationship(back_populates="resources")
111
-
112
- # Relationship to link specific resource types
113
- pipeline_runs: List["PipelineRunSchema"] = Relationship(
114
- back_populates="run_metadata_resources",
115
- sa_relationship_kwargs=dict(
116
- primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)",
117
- overlaps="run_metadata_resources,step_runs,artifact_versions,model_versions",
118
- ),
119
- )
120
- step_runs: List["StepRunSchema"] = Relationship(
121
- back_populates="run_metadata_resources",
122
- sa_relationship_kwargs=dict(
123
- primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)",
124
- overlaps="run_metadata_resources,pipeline_runs,artifact_versions,model_versions",
125
- ),
126
- )
127
- artifact_versions: List["ArtifactVersionSchema"] = Relationship(
128
- back_populates="run_metadata_resources",
129
- sa_relationship_kwargs=dict(
130
- primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)",
131
- overlaps="run_metadata_resources,pipeline_runs,step_runs,model_versions",
132
- ),
133
- )
134
- model_versions: List["ModelVersionSchema"] = Relationship(
135
- back_populates="run_metadata_resources",
136
- sa_relationship_kwargs=dict(
137
- primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)",
138
- overlaps="run_metadata_resources,pipeline_runs,step_runs,artifact_versions",
139
- ),
140
- )
@@ -41,7 +41,7 @@ if TYPE_CHECKING:
41
41
  PipelineDeploymentSchema,
42
42
  )
43
43
  from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
44
- from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
44
+ from zenml.zen_stores.schemas.tag_schemas import TagSchema
45
45
 
46
46
 
47
47
  class RunTemplateSchema(BaseSchema, table=True):
@@ -110,10 +110,12 @@ class RunTemplateSchema(BaseSchema, table=True):
110
110
  }
111
111
  )
112
112
 
113
- tags: List["TagResourceSchema"] = Relationship(
113
+ tags: List["TagSchema"] = Relationship(
114
114
  sa_relationship_kwargs=dict(
115
- primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.RUN_TEMPLATE.value}', foreign(TagResourceSchema.resource_id)==RunTemplateSchema.id)",
116
- cascade="delete",
115
+ primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.RUN_TEMPLATE.value}', foreign(TagResourceSchema.resource_id)==RunTemplateSchema.id)",
116
+ secondary="tag_resource",
117
+ secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
118
+ order_by="TagSchema.name",
117
119
  overlaps="tags",
118
120
  ),
119
121
  )
@@ -253,7 +255,7 @@ class RunTemplateSchema(BaseSchema, table=True):
253
255
  pipeline=pipeline,
254
256
  build=build,
255
257
  code_reference=code_reference,
256
- tags=[t.tag.to_model() for t in self.tags],
258
+ tags=[tag.to_model() for tag in self.tags],
257
259
  )
258
260
 
259
261
  return RunTemplateResponse(
@@ -58,9 +58,7 @@ if TYPE_CHECKING:
58
58
  from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema
59
59
  from zenml.zen_stores.schemas.logs_schemas import LogsSchema
60
60
  from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema
61
- from zenml.zen_stores.schemas.run_metadata_schemas import (
62
- RunMetadataResourceSchema,
63
- )
61
+ from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
64
62
 
65
63
 
66
64
  class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
@@ -150,12 +148,12 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
150
148
  deployment: Optional["PipelineDeploymentSchema"] = Relationship(
151
149
  back_populates="step_runs"
152
150
  )
153
- run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship(
154
- back_populates="step_runs",
151
+ run_metadata: List["RunMetadataSchema"] = Relationship(
155
152
  sa_relationship_kwargs=dict(
156
- primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)",
157
- cascade="delete",
158
- overlaps="run_metadata_resources",
153
+ secondary="run_metadata_resource",
154
+ primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)",
155
+ secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)",
156
+ overlaps="run_metadata",
159
157
  ),
160
158
  )
161
159
  input_artifacts: List["StepRunInputArtifactSchema"] = Relationship(
@@ -14,7 +14,7 @@
14
14
  """SQLModel implementation of tag tables."""
15
15
 
16
16
  from datetime import datetime
17
- from typing import TYPE_CHECKING, Any, List
17
+ from typing import Any, List
18
18
  from uuid import UUID
19
19
 
20
20
  from sqlalchemy import VARCHAR, Column
@@ -33,16 +33,6 @@ from zenml.models import (
33
33
  from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema
34
34
  from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
35
35
 
36
- if TYPE_CHECKING:
37
- from zenml.zen_stores.schemas.artifact_schemas import (
38
- ArtifactSchema,
39
- ArtifactVersionSchema,
40
- )
41
- from zenml.zen_stores.schemas.model_schemas import (
42
- ModelSchema,
43
- ModelVersionSchema,
44
- )
45
-
46
36
 
47
37
  class TagSchema(NamedSchema, table=True):
48
38
  """SQL Model for tag."""
@@ -52,7 +42,7 @@ class TagSchema(NamedSchema, table=True):
52
42
  color: str = Field(sa_column=Column(VARCHAR(255), nullable=False))
53
43
  links: List["TagResourceSchema"] = Relationship(
54
44
  back_populates="tag",
55
- sa_relationship_kwargs={"cascade": "delete"},
45
+ sa_relationship_kwargs={"overlaps": "tags", "cascade": "delete"},
56
46
  )
57
47
 
58
48
  @classmethod
@@ -130,37 +120,11 @@ class TagResourceSchema(BaseSchema, table=True):
130
120
  ondelete="CASCADE",
131
121
  nullable=False,
132
122
  )
133
- tag: "TagSchema" = Relationship(back_populates="links")
123
+ tag: "TagSchema" = Relationship(
124
+ back_populates="links", sa_relationship_kwargs={"overlaps": "tags"}
125
+ )
134
126
  resource_id: UUID
135
127
  resource_type: str = Field(sa_column=Column(VARCHAR(255), nullable=False))
136
- artifact: List["ArtifactSchema"] = Relationship(
137
- back_populates="tags",
138
- sa_relationship_kwargs=dict(
139
- primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)",
140
- overlaps="tags,model,artifact_version,model_version",
141
- ),
142
- )
143
- artifact_version: List["ArtifactVersionSchema"] = Relationship(
144
- back_populates="tags",
145
- sa_relationship_kwargs=dict(
146
- primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)",
147
- overlaps="tags,model,artifact,model_version",
148
- ),
149
- )
150
- model: List["ModelSchema"] = Relationship(
151
- back_populates="tags",
152
- sa_relationship_kwargs=dict(
153
- primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
154
- overlaps="tags,artifact,artifact_version,model_version",
155
- ),
156
- )
157
- model_version: List["ModelVersionSchema"] = Relationship(
158
- back_populates="tags",
159
- sa_relationship_kwargs=dict(
160
- primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
161
- overlaps="tags,model,artifact,artifact_version",
162
- ),
163
- )
164
128
 
165
129
  @classmethod
166
130
  def from_request(cls, request: TagResourceRequest) -> "TagResourceSchema":