zenml-nightly 0.73.0.dev20250130__py3-none-any.whl → 0.73.0.dev20250201__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/code_repository.py +26 -0
- zenml/client.py +2 -7
- zenml/code_repositories/base_code_repository.py +30 -2
- zenml/code_repositories/git/local_git_repository_context.py +26 -10
- zenml/code_repositories/local_repository_context.py +11 -8
- zenml/constants.py +3 -0
- zenml/integrations/azure/artifact_stores/azure_artifact_store.py +18 -9
- zenml/integrations/azure/service_connectors/azure_service_connector.py +146 -29
- zenml/integrations/github/code_repositories/github_code_repository.py +17 -2
- zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py +17 -2
- zenml/integrations/pytorch/materializers/base_pytorch_materializer.py +1 -1
- zenml/pipelines/build_utils.py +42 -35
- zenml/pipelines/pipeline_definition.py +5 -2
- zenml/utils/code_repository_utils.py +11 -2
- zenml/utils/downloaded_repository_context.py +3 -5
- zenml/utils/source_utils.py +3 -3
- zenml/zen_stores/migrations/utils.py +48 -1
- zenml/zen_stores/migrations/versions/4d5524b92a30_add_run_metadata_tag_index.py +67 -0
- zenml/zen_stores/schemas/run_metadata_schemas.py +15 -2
- zenml/zen_stores/schemas/schema_utils.py +34 -2
- zenml/zen_stores/schemas/tag_schemas.py +14 -1
- {zenml_nightly-0.73.0.dev20250130.dist-info → zenml_nightly-0.73.0.dev20250201.dist-info}/METADATA +2 -1
- {zenml_nightly-0.73.0.dev20250130.dist-info → zenml_nightly-0.73.0.dev20250201.dist-info}/RECORD +27 -26
- {zenml_nightly-0.73.0.dev20250130.dist-info → zenml_nightly-0.73.0.dev20250201.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.73.0.dev20250130.dist-info → zenml_nightly-0.73.0.dev20250201.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.73.0.dev20250130.dist-info → zenml_nightly-0.73.0.dev20250201.dist-info}/entry_points.txt +0 -0
zenml/VERSION
CHANGED
@@ -1 +1 @@
|
|
1
|
-
0.73.0.
|
1
|
+
0.73.0.dev20250201
|
zenml/cli/code_repository.py
CHANGED
@@ -162,6 +162,32 @@ def register_code_repository(
|
|
162
162
|
cli_utils.declare(f"Successfully registered code repository `{name}`.")
|
163
163
|
|
164
164
|
|
165
|
+
@code_repository.command("describe", help="Describe a code repository.")
|
166
|
+
@click.argument(
|
167
|
+
"name_id_or_prefix",
|
168
|
+
type=str,
|
169
|
+
required=True,
|
170
|
+
)
|
171
|
+
def describe_code_repository(name_id_or_prefix: str) -> None:
|
172
|
+
"""Describe a code repository.
|
173
|
+
|
174
|
+
Args:
|
175
|
+
name_id_or_prefix: Name, ID or prefix of the code repository.
|
176
|
+
"""
|
177
|
+
client = Client()
|
178
|
+
try:
|
179
|
+
code_repository = client.get_code_repository(
|
180
|
+
name_id_or_prefix=name_id_or_prefix,
|
181
|
+
)
|
182
|
+
except KeyError as err:
|
183
|
+
cli_utils.error(str(err))
|
184
|
+
else:
|
185
|
+
cli_utils.print_pydantic_model(
|
186
|
+
title=f"Code repository '{code_repository.name}'",
|
187
|
+
model=code_repository,
|
188
|
+
)
|
189
|
+
|
190
|
+
|
165
191
|
@code_repository.command("list", help="List all connected code repositories.")
|
166
192
|
@list_options(CodeRepositoryFilter)
|
167
193
|
def list_code_repositories(**kwargs: Any) -> None:
|
zenml/client.py
CHANGED
@@ -34,7 +34,7 @@ from typing import (
|
|
34
34
|
Union,
|
35
35
|
cast,
|
36
36
|
)
|
37
|
-
from uuid import UUID
|
37
|
+
from uuid import UUID
|
38
38
|
|
39
39
|
from pydantic import ConfigDict, SecretStr
|
40
40
|
|
@@ -4980,12 +4980,7 @@ class Client(metaclass=ClientMetaClass):
|
|
4980
4980
|
)
|
4981
4981
|
)
|
4982
4982
|
try:
|
4983
|
-
|
4984
|
-
code_repo_class(id=uuid4(), config=config)
|
4985
|
-
|
4986
|
-
# Explicitly access the config for pydantic validation, in case
|
4987
|
-
# the login for some reason did not do that.
|
4988
|
-
_ = code_repo_class.config
|
4983
|
+
code_repo_class.validate_config(config)
|
4989
4984
|
except Exception as e:
|
4990
4985
|
raise RuntimeError(
|
4991
4986
|
"Failed to validate code repository config."
|
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
from abc import ABC, abstractmethod
|
17
17
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type
|
18
|
-
from uuid import UUID
|
18
|
+
from uuid import UUID, uuid4
|
19
19
|
|
20
20
|
from zenml.config.secret_reference_mixin import SecretReferenceMixin
|
21
21
|
from zenml.logger import get_logger
|
@@ -44,15 +44,18 @@ class BaseCodeRepository(ABC):
|
|
44
44
|
def __init__(
|
45
45
|
self,
|
46
46
|
id: UUID,
|
47
|
+
name: str,
|
47
48
|
config: Dict[str, Any],
|
48
49
|
) -> None:
|
49
50
|
"""Initializes a code repository.
|
50
51
|
|
51
52
|
Args:
|
52
53
|
id: The ID of the code repository.
|
54
|
+
name: The name of the code repository.
|
53
55
|
config: The config of the code repository.
|
54
56
|
"""
|
55
57
|
self._id = id
|
58
|
+
self._name = name
|
56
59
|
self._config = config
|
57
60
|
self.login()
|
58
61
|
|
@@ -80,7 +83,23 @@ class BaseCodeRepository(ABC):
|
|
80
83
|
source=model.source, expected_class=BaseCodeRepository
|
81
84
|
)
|
82
85
|
)
|
83
|
-
return class_(id=model.id, config=model.config)
|
86
|
+
return class_(id=model.id, name=model.name, config=model.config)
|
87
|
+
|
88
|
+
@classmethod
|
89
|
+
def validate_config(cls, config: Dict[str, Any]) -> None:
|
90
|
+
"""Validate the code repository config.
|
91
|
+
|
92
|
+
This method should check that the config/credentials are valid and
|
93
|
+
the configured repository exists.
|
94
|
+
|
95
|
+
Args:
|
96
|
+
config: The configuration.
|
97
|
+
"""
|
98
|
+
# The initialization calls the login to verify the credentials
|
99
|
+
code_repo = cls(id=uuid4(), name="", config=config)
|
100
|
+
|
101
|
+
# Explicitly access the config for pydantic validation
|
102
|
+
_ = code_repo.config
|
84
103
|
|
85
104
|
@property
|
86
105
|
def id(self) -> UUID:
|
@@ -91,6 +110,15 @@ class BaseCodeRepository(ABC):
|
|
91
110
|
"""
|
92
111
|
return self._id
|
93
112
|
|
113
|
+
@property
|
114
|
+
def name(self) -> str:
|
115
|
+
"""Name of the code repository.
|
116
|
+
|
117
|
+
Returns:
|
118
|
+
The name of the code repository.
|
119
|
+
"""
|
120
|
+
return self._name
|
121
|
+
|
94
122
|
@property
|
95
123
|
def requirements(self) -> Set[str]:
|
96
124
|
"""Set of PyPI requirements for the repository.
|
@@ -14,11 +14,14 @@
|
|
14
14
|
"""Implementation of the Local git repository context."""
|
15
15
|
|
16
16
|
from typing import TYPE_CHECKING, Callable, Optional, cast
|
17
|
-
from uuid import UUID
|
18
17
|
|
19
18
|
from zenml.code_repositories import (
|
20
19
|
LocalRepositoryContext,
|
21
20
|
)
|
21
|
+
from zenml.constants import (
|
22
|
+
ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES,
|
23
|
+
handle_bool_env_var,
|
24
|
+
)
|
22
25
|
from zenml.logger import get_logger
|
23
26
|
|
24
27
|
if TYPE_CHECKING:
|
@@ -26,6 +29,8 @@ if TYPE_CHECKING:
|
|
26
29
|
from git.remote import Remote
|
27
30
|
from git.repo.base import Repo
|
28
31
|
|
32
|
+
from zenml.code_repositories import BaseCodeRepository
|
33
|
+
|
29
34
|
logger = get_logger(__name__)
|
30
35
|
|
31
36
|
|
@@ -33,16 +38,19 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
|
|
33
38
|
"""Local git repository context."""
|
34
39
|
|
35
40
|
def __init__(
|
36
|
-
self,
|
41
|
+
self,
|
42
|
+
code_repository: "BaseCodeRepository",
|
43
|
+
git_repo: "Repo",
|
44
|
+
remote_name: str,
|
37
45
|
):
|
38
46
|
"""Initializes a local git repository context.
|
39
47
|
|
40
48
|
Args:
|
41
|
-
|
49
|
+
code_repository: The code repository.
|
42
50
|
git_repo: The git repo.
|
43
51
|
remote_name: Name of the remote.
|
44
52
|
"""
|
45
|
-
super().__init__(
|
53
|
+
super().__init__(code_repository=code_repository)
|
46
54
|
self._git_repo = git_repo
|
47
55
|
self._remote = git_repo.remote(name=remote_name)
|
48
56
|
|
@@ -50,14 +58,14 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
|
|
50
58
|
def at(
|
51
59
|
cls,
|
52
60
|
path: str,
|
53
|
-
|
61
|
+
code_repository: "BaseCodeRepository",
|
54
62
|
remote_url_validation_callback: Callable[[str], bool],
|
55
63
|
) -> Optional["LocalGitRepositoryContext"]:
|
56
64
|
"""Returns a local git repository at the given path.
|
57
65
|
|
58
66
|
Args:
|
59
67
|
path: The path to the local git repository.
|
60
|
-
|
68
|
+
code_repository: The code repository.
|
61
69
|
remote_url_validation_callback: A callback that validates the
|
62
70
|
remote URL of the git repository.
|
63
71
|
|
@@ -70,11 +78,13 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
|
|
70
78
|
from git.exc import InvalidGitRepositoryError
|
71
79
|
from git.repo.base import Repo
|
72
80
|
except ImportError:
|
81
|
+
logger.debug("Failed to import git library.")
|
73
82
|
return None
|
74
83
|
|
75
84
|
try:
|
76
85
|
git_repo = Repo(path=path, search_parent_directories=True)
|
77
86
|
except InvalidGitRepositoryError:
|
87
|
+
logger.debug("No git repository exists at path %s.", path)
|
78
88
|
return None
|
79
89
|
|
80
90
|
remote_name = None
|
@@ -87,7 +97,7 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
|
|
87
97
|
return None
|
88
98
|
|
89
99
|
return LocalGitRepositoryContext(
|
90
|
-
|
100
|
+
code_repository=code_repository,
|
91
101
|
git_repo=git_repo,
|
92
102
|
remote_name=remote_name,
|
93
103
|
)
|
@@ -124,13 +134,19 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
|
|
124
134
|
def is_dirty(self) -> bool:
|
125
135
|
"""Whether the git repo is dirty.
|
126
136
|
|
127
|
-
|
128
|
-
changes.
|
137
|
+
By default, a repository counts as dirty if it has any untracked or
|
138
|
+
uncommitted changes. Users can use an environment variable to ignore
|
139
|
+
untracked files.
|
129
140
|
|
130
141
|
Returns:
|
131
142
|
True if the git repo is dirty, False otherwise.
|
132
143
|
"""
|
133
|
-
|
144
|
+
ignore_untracked_files = handle_bool_env_var(
|
145
|
+
ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES, default=False
|
146
|
+
)
|
147
|
+
return self.git_repo.is_dirty(
|
148
|
+
untracked_files=not ignore_untracked_files
|
149
|
+
)
|
134
150
|
|
135
151
|
@property
|
136
152
|
def has_local_changes(self) -> bool:
|
@@ -14,10 +14,13 @@
|
|
14
14
|
"""Base class for local code repository contexts."""
|
15
15
|
|
16
16
|
from abc import ABC, abstractmethod
|
17
|
-
from
|
17
|
+
from typing import TYPE_CHECKING
|
18
18
|
|
19
19
|
from zenml.logger import get_logger
|
20
20
|
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
from zenml.code_repositories import BaseCodeRepository
|
23
|
+
|
21
24
|
logger = get_logger(__name__)
|
22
25
|
|
23
26
|
|
@@ -30,22 +33,22 @@ class LocalRepositoryContext(ABC):
|
|
30
33
|
commit, and whether the repository is dirty.
|
31
34
|
"""
|
32
35
|
|
33
|
-
def __init__(self,
|
36
|
+
def __init__(self, code_repository: "BaseCodeRepository") -> None:
|
34
37
|
"""Initializes a local repository context.
|
35
38
|
|
36
39
|
Args:
|
37
|
-
|
40
|
+
code_repository: The code repository.
|
38
41
|
"""
|
39
|
-
self.
|
42
|
+
self._code_repository = code_repository
|
40
43
|
|
41
44
|
@property
|
42
|
-
def
|
43
|
-
"""Returns the
|
45
|
+
def code_repository(self) -> "BaseCodeRepository":
|
46
|
+
"""Returns the code repository.
|
44
47
|
|
45
48
|
Returns:
|
46
|
-
The
|
49
|
+
The code repository.
|
47
50
|
"""
|
48
|
-
return self.
|
51
|
+
return self._code_repository
|
49
52
|
|
50
53
|
@property
|
51
54
|
@abstractmethod
|
zenml/constants.py
CHANGED
@@ -175,6 +175,9 @@ ENV_ZENML_WHEEL_PACKAGE_NAME = "ZENML_WHEEL_PACKAGE_NAME"
|
|
175
175
|
ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION = (
|
176
176
|
"ZENML_PIPELINE_API_TOKEN_EXPIRATION"
|
177
177
|
)
|
178
|
+
ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES = (
|
179
|
+
"ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES"
|
180
|
+
)
|
178
181
|
|
179
182
|
# Materializer environment variables
|
180
183
|
ENV_ZENML_MATERIALIZER_ALLOW_NON_ASCII_JSON_DUMPS = (
|
@@ -64,7 +64,10 @@ class AzureArtifactStore(BaseArtifactStore, AuthenticationMixin):
|
|
64
64
|
"""
|
65
65
|
connector = self.get_connector()
|
66
66
|
if connector:
|
67
|
-
from azure.identity import
|
67
|
+
from azure.identity import (
|
68
|
+
ClientSecretCredential,
|
69
|
+
DefaultAzureCredential,
|
70
|
+
)
|
68
71
|
from azure.storage.blob import BlobServiceClient
|
69
72
|
|
70
73
|
client = connector.connect()
|
@@ -77,18 +80,24 @@ class AzureArtifactStore(BaseArtifactStore, AuthenticationMixin):
|
|
77
80
|
)
|
78
81
|
# Get the credentials from the client
|
79
82
|
credentials = client.credential
|
80
|
-
|
83
|
+
|
84
|
+
if isinstance(credentials, ClientSecretCredential):
|
85
|
+
return AzureSecretSchema(
|
86
|
+
client_id=credentials._client_id,
|
87
|
+
client_secret=credentials._client_credential,
|
88
|
+
tenant_id=credentials._tenant_id,
|
89
|
+
account_name=client.account_name,
|
90
|
+
)
|
91
|
+
|
92
|
+
elif isinstance(credentials, DefaultAzureCredential):
|
93
|
+
return AzureSecretSchema(account_name=client.account_name)
|
94
|
+
|
95
|
+
else:
|
81
96
|
raise RuntimeError(
|
82
97
|
"The Azure Artifact Store connector can only be used "
|
83
98
|
"with a service connector that is configured with "
|
84
|
-
"Azure service principal credentials
|
99
|
+
"Azure service principal credentials or implicit authentication"
|
85
100
|
)
|
86
|
-
return AzureSecretSchema(
|
87
|
-
client_id=credentials._client_id,
|
88
|
-
client_secret=credentials._client_credential,
|
89
|
-
tenant_id=credentials._tenant_id,
|
90
|
-
account_name=client.account_name,
|
91
|
-
)
|
92
101
|
|
93
102
|
secret = self.get_typed_authentication_secret(
|
94
103
|
expected_schema_type=AzureSecretSchema
|
@@ -14,12 +14,14 @@
|
|
14
14
|
"""Azure Service Connector."""
|
15
15
|
|
16
16
|
import datetime
|
17
|
+
import json
|
17
18
|
import logging
|
18
19
|
import re
|
19
20
|
import subprocess
|
20
21
|
from typing import Any, Dict, List, Optional, Tuple
|
21
22
|
from uuid import UUID
|
22
23
|
|
24
|
+
import requests
|
23
25
|
import yaml
|
24
26
|
from azure.core.credentials import AccessToken, TokenCredential
|
25
27
|
from azure.core.exceptions import AzureError
|
@@ -69,6 +71,7 @@ logger = get_logger(__name__)
|
|
69
71
|
AZURE_MANAGEMENT_TOKEN_SCOPE = "https://management.azure.com/.default"
|
70
72
|
AZURE_SESSION_TOKEN_DEFAULT_EXPIRATION_TIME = 60 * 60 # 1 hour
|
71
73
|
AZURE_SESSION_EXPIRATION_BUFFER = 15 # 15 minutes
|
74
|
+
AZURE_ACR_OAUTH_SCOPE = "repository:*:*"
|
72
75
|
|
73
76
|
|
74
77
|
class AzureBaseConfig(AuthenticationConfig):
|
@@ -362,8 +365,8 @@ neither a storage account nor a resource group is configured in the connector,
|
|
362
365
|
all blob storage containers in all accessible storage accounts will be
|
363
366
|
accessible.
|
364
367
|
|
365
|
-
The only Azure authentication
|
366
|
-
|
368
|
+
The only Azure authentication methods that work with Azure blob storage resources are the implicit
|
369
|
+
authentication and the service principal authentication method.
|
367
370
|
""",
|
368
371
|
auth_methods=AzureAuthenticationMethods.values(),
|
369
372
|
# Request a blob container to be configured in the
|
@@ -435,12 +438,10 @@ following formats:
|
|
435
438
|
If a resource group is configured in the connector, only ACR registries in that
|
436
439
|
resource group will be accessible.
|
437
440
|
|
438
|
-
If an authentication method other than the Azure service principal is used
|
439
|
-
|
440
|
-
|
441
|
-
[documentation on the admin account](https://docs.microsoft.com/en-us/azure/container-registry/container-registry-authentication#admin-account
|
442
|
-
)
|
443
|
-
for more information.
|
441
|
+
If an authentication method other than the Azure service principal is used, Entra ID authentication is used.
|
442
|
+
This requires the `AcrPush` role to be assigned to the configured identity.
|
443
|
+
If this fails, admin account authentication is tried. For this the admin account must be enabled for the registry.
|
444
|
+
See the official Azure [documentation on the admin account](https://docs.microsoft.com/en-us/azure/container-registry/container-registry-authentication#admin-account) for more information.
|
444
445
|
""",
|
445
446
|
auth_methods=AzureAuthenticationMethods.values(),
|
446
447
|
supports_instances=True,
|
@@ -1718,10 +1719,14 @@ class AzureServiceConnector(ServiceConnector):
|
|
1718
1719
|
resource_type=resource_type,
|
1719
1720
|
resource_id=resource_id,
|
1720
1721
|
)
|
1721
|
-
|
1722
|
+
|
1723
|
+
resource_group: Optional[str]
|
1724
|
+
registry_name: str
|
1725
|
+
cluster_name: str
|
1722
1726
|
|
1723
1727
|
if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
|
1724
1728
|
registry_name = self._parse_acr_resource_id(resource_id)
|
1729
|
+
registry_domain = f"{registry_name}.azurecr.io"
|
1725
1730
|
|
1726
1731
|
# If a service principal is used for authentication, the client ID
|
1727
1732
|
# and client secret can be used to authenticate to the registry, if
|
@@ -1733,13 +1738,12 @@ class AzureServiceConnector(ServiceConnector):
|
|
1733
1738
|
):
|
1734
1739
|
assert isinstance(self.config, AzureServicePrincipalConfig)
|
1735
1740
|
username = str(self.config.client_id)
|
1736
|
-
password = self.config.client_secret
|
1741
|
+
password = str(self.config.client_secret)
|
1737
1742
|
|
1738
|
-
# Without a service principal,
|
1739
|
-
#
|
1740
|
-
# disabled by default.
|
1743
|
+
# Without a service principal, we try to use the AzureDefaultCredentials to authenticate against the ACR.
|
1744
|
+
# If this fails, we try to use the admin account.
|
1745
|
+
# This has to be enabled for the registry, but this is not recommended and disabled by default.
|
1741
1746
|
# https://docs.microsoft.com/en-us/azure/container-registry/container-registry-authentication#admin-account
|
1742
|
-
|
1743
1747
|
else:
|
1744
1748
|
registries = self._list_acr_registries(
|
1745
1749
|
credential,
|
@@ -1748,22 +1752,35 @@ class AzureServiceConnector(ServiceConnector):
|
|
1748
1752
|
registry_name, resource_group = registries.popitem()
|
1749
1753
|
|
1750
1754
|
try:
|
1751
|
-
|
1752
|
-
|
1755
|
+
username = "00000000-0000-0000-0000-000000000000"
|
1756
|
+
password = _ACRTokenExchangeClient(
|
1757
|
+
credential
|
1758
|
+
).get_acr_access_token(
|
1759
|
+
registry_domain, AZURE_ACR_OAUTH_SCOPE
|
1753
1760
|
)
|
1754
|
-
|
1755
|
-
|
1756
|
-
|
1761
|
+
except AuthorizationException:
|
1762
|
+
logger.warning(
|
1763
|
+
"Falling back to admin credentials for ACR authentication. Be sure to assign AcrPush role to the configured identity."
|
1757
1764
|
)
|
1758
|
-
|
1759
|
-
|
1760
|
-
|
1761
|
-
|
1762
|
-
|
1763
|
-
|
1764
|
-
|
1765
|
-
|
1766
|
-
|
1765
|
+
try:
|
1766
|
+
client = ContainerRegistryManagementClient(
|
1767
|
+
credential, subscription_id
|
1768
|
+
)
|
1769
|
+
|
1770
|
+
registry_credentials = (
|
1771
|
+
client.registries.list_credentials(
|
1772
|
+
resource_group, registry_name
|
1773
|
+
)
|
1774
|
+
)
|
1775
|
+
username = registry_credentials.username
|
1776
|
+
password = registry_credentials.passwords[0].value
|
1777
|
+
except AzureError as e:
|
1778
|
+
raise AuthorizationException(
|
1779
|
+
f"failed to list admin credentials for Azure Container "
|
1780
|
+
f"Registry '{registry_name}' in resource group "
|
1781
|
+
f"'{resource_group}'. Make sure the registry is "
|
1782
|
+
f"configured with an admin account: {e}"
|
1783
|
+
) from e
|
1767
1784
|
|
1768
1785
|
# Create a client-side Docker connector instance with the temporary
|
1769
1786
|
# Docker credentials
|
@@ -1775,7 +1792,7 @@ class AzureServiceConnector(ServiceConnector):
|
|
1775
1792
|
config=DockerConfiguration(
|
1776
1793
|
username=username,
|
1777
1794
|
password=password,
|
1778
|
-
registry=
|
1795
|
+
registry=registry_domain,
|
1779
1796
|
),
|
1780
1797
|
expires_at=expires_at,
|
1781
1798
|
)
|
@@ -1847,3 +1864,103 @@ class AzureServiceConnector(ServiceConnector):
|
|
1847
1864
|
)
|
1848
1865
|
|
1849
1866
|
raise ValueError(f"Unsupported resource type: {resource_type}")
|
1867
|
+
|
1868
|
+
|
1869
|
+
class _ACRTokenExchangeClient:
|
1870
|
+
def __init__(self, credential: TokenCredential):
|
1871
|
+
self._credential = credential
|
1872
|
+
|
1873
|
+
def _get_aad_access_token(self) -> str:
|
1874
|
+
aad_access_token: str = self._credential.get_token(
|
1875
|
+
AZURE_MANAGEMENT_TOKEN_SCOPE
|
1876
|
+
).token
|
1877
|
+
return aad_access_token
|
1878
|
+
|
1879
|
+
# https://github.com/Azure/acr/blob/main/docs/AAD-OAuth.md#authenticating-docker-with-an-acr-refresh-token
|
1880
|
+
def get_acr_refresh_token(self, acr_url: str) -> str:
|
1881
|
+
try:
|
1882
|
+
aad_access_token = self._get_aad_access_token()
|
1883
|
+
|
1884
|
+
headers = {
|
1885
|
+
"Content-Type": "application/x-www-form-urlencoded",
|
1886
|
+
}
|
1887
|
+
|
1888
|
+
data = {
|
1889
|
+
"grant_type": "access_token",
|
1890
|
+
"service": acr_url,
|
1891
|
+
"access_token": aad_access_token,
|
1892
|
+
}
|
1893
|
+
|
1894
|
+
response = requests.post(
|
1895
|
+
f"https://{acr_url}/oauth2/exchange",
|
1896
|
+
headers=headers,
|
1897
|
+
data=data,
|
1898
|
+
timeout=5,
|
1899
|
+
)
|
1900
|
+
|
1901
|
+
if response.status_code != 200:
|
1902
|
+
raise AuthorizationException(
|
1903
|
+
f"failed to get refresh token for Azure Container "
|
1904
|
+
f"Registry '{acr_url}' in resource group. "
|
1905
|
+
f"Be sure to assign AcrPush to the configured principal. "
|
1906
|
+
f"The token exchange returned status {response.status_code} "
|
1907
|
+
f"with body '{response.content.decode()}'"
|
1908
|
+
)
|
1909
|
+
|
1910
|
+
acr_refresh_token_response = json.loads(response.content)
|
1911
|
+
acr_refresh_token: str = acr_refresh_token_response[
|
1912
|
+
"refresh_token"
|
1913
|
+
]
|
1914
|
+
return acr_refresh_token
|
1915
|
+
|
1916
|
+
except (AzureError, requests.exceptions.RequestException) as e:
|
1917
|
+
raise AuthorizationException(
|
1918
|
+
f"failed to get refresh token for Azure Container "
|
1919
|
+
f"Registry '{acr_url}' in resource group. "
|
1920
|
+
f"Make sure the implicit authentication identity "
|
1921
|
+
f"has access to the configured registry: {e}"
|
1922
|
+
) from e
|
1923
|
+
|
1924
|
+
# https://github.com/Azure/acr/blob/main/docs/AAD-OAuth.md#calling-post-oauth2token-to-get-an-acr-access-token
|
1925
|
+
def get_acr_access_token(self, acr_url: str, scope: str) -> str:
|
1926
|
+
acr_refresh_token = self.get_acr_refresh_token(acr_url)
|
1927
|
+
|
1928
|
+
headers = {
|
1929
|
+
"Content-Type": "application/x-www-form-urlencoded",
|
1930
|
+
}
|
1931
|
+
|
1932
|
+
data = {
|
1933
|
+
"grant_type": "refresh_token",
|
1934
|
+
"service": acr_url,
|
1935
|
+
"scope": scope,
|
1936
|
+
"refresh_token": acr_refresh_token,
|
1937
|
+
}
|
1938
|
+
|
1939
|
+
try:
|
1940
|
+
response = requests.post(
|
1941
|
+
f"https://{acr_url}/oauth2/token",
|
1942
|
+
headers=headers,
|
1943
|
+
data=data,
|
1944
|
+
timeout=5,
|
1945
|
+
)
|
1946
|
+
|
1947
|
+
if response.status_code != 200:
|
1948
|
+
raise AuthorizationException(
|
1949
|
+
f"failed to get access token for Azure Container "
|
1950
|
+
f"Registry '{acr_url}' in resource group. "
|
1951
|
+
f"Be sure to assign AcrPush to the configured principal. "
|
1952
|
+
f"The token exchange returned status {response.status_code} "
|
1953
|
+
f"with body '{response.content.decode()}'"
|
1954
|
+
)
|
1955
|
+
|
1956
|
+
acr_access_token_response = json.loads(response.content)
|
1957
|
+
acr_access_token: str = acr_access_token_response["access_token"]
|
1958
|
+
return acr_access_token
|
1959
|
+
|
1960
|
+
except (AzureError, requests.exceptions.RequestException) as e:
|
1961
|
+
raise AuthorizationException(
|
1962
|
+
f"failed to get access token for Azure Container "
|
1963
|
+
f"Registry '{acr_url}' in resource group. "
|
1964
|
+
f"Make sure the implicit authentication identity "
|
1965
|
+
f"has access to the configured registry: {e}"
|
1966
|
+
) from e
|
@@ -15,7 +15,8 @@
|
|
15
15
|
|
16
16
|
import os
|
17
17
|
import re
|
18
|
-
from typing import List, Optional
|
18
|
+
from typing import Any, Dict, List, Optional
|
19
|
+
from uuid import uuid4
|
19
20
|
|
20
21
|
import requests
|
21
22
|
from github import Consts, Github, GithubException
|
@@ -62,6 +63,20 @@ class GitHubCodeRepositoryConfig(BaseCodeRepositoryConfig):
|
|
62
63
|
class GitHubCodeRepository(BaseCodeRepository):
|
63
64
|
"""GitHub code repository."""
|
64
65
|
|
66
|
+
@classmethod
|
67
|
+
def validate_config(cls, config: Dict[str, Any]) -> None:
|
68
|
+
"""Validate the code repository config.
|
69
|
+
|
70
|
+
This method should check that the config/credentials are valid and
|
71
|
+
the configured repository exists.
|
72
|
+
|
73
|
+
Args:
|
74
|
+
config: The configuration.
|
75
|
+
"""
|
76
|
+
code_repo = cls(id=uuid4(), name="", config=config)
|
77
|
+
# Try to access the repository to make sure it exists
|
78
|
+
_ = code_repo.github_repo
|
79
|
+
|
65
80
|
@property
|
66
81
|
def config(self) -> GitHubCodeRepositoryConfig:
|
67
82
|
"""Returns the `GitHubCodeRepositoryConfig` config.
|
@@ -190,7 +205,7 @@ class GitHubCodeRepository(BaseCodeRepository):
|
|
190
205
|
"""
|
191
206
|
return LocalGitRepositoryContext.at(
|
192
207
|
path=path,
|
193
|
-
|
208
|
+
code_repository=self,
|
194
209
|
remote_url_validation_callback=self.check_remote_url,
|
195
210
|
)
|
196
211
|
|
@@ -15,7 +15,8 @@
|
|
15
15
|
|
16
16
|
import os
|
17
17
|
import re
|
18
|
-
from typing import Optional
|
18
|
+
from typing import Any, Dict, Optional
|
19
|
+
from uuid import uuid4
|
19
20
|
|
20
21
|
from gitlab import Gitlab
|
21
22
|
from gitlab.v4.objects import Project
|
@@ -63,6 +64,20 @@ class GitLabCodeRepositoryConfig(BaseCodeRepositoryConfig):
|
|
63
64
|
class GitLabCodeRepository(BaseCodeRepository):
|
64
65
|
"""GitLab code repository."""
|
65
66
|
|
67
|
+
@classmethod
|
68
|
+
def validate_config(cls, config: Dict[str, Any]) -> None:
|
69
|
+
"""Validate the code repository config.
|
70
|
+
|
71
|
+
This method should check that the config/credentials are valid and
|
72
|
+
the configured repository exists.
|
73
|
+
|
74
|
+
Args:
|
75
|
+
config: The configuration.
|
76
|
+
"""
|
77
|
+
code_repo = cls(id=uuid4(), name="", config=config)
|
78
|
+
# Try to access the project to make sure it exists
|
79
|
+
_ = code_repo.gitlab_project
|
80
|
+
|
66
81
|
@property
|
67
82
|
def config(self) -> GitLabCodeRepositoryConfig:
|
68
83
|
"""Returns the `GitLabCodeRepositoryConfig` config.
|
@@ -147,7 +162,7 @@ class GitLabCodeRepository(BaseCodeRepository):
|
|
147
162
|
"""
|
148
163
|
return LocalGitRepositoryContext.at(
|
149
164
|
path=path,
|
150
|
-
|
165
|
+
code_repository=self,
|
151
166
|
remote_url_validation_callback=self.check_remote_url,
|
152
167
|
)
|
153
168
|
|
@@ -44,7 +44,7 @@ class BasePyTorchMaterializer(BaseMaterializer):
|
|
44
44
|
# NOTE (security): The `torch.load` function uses `pickle` as
|
45
45
|
# the default unpickler, which is NOT secure. This materializer
|
46
46
|
# is intended for use with trusted data sources.
|
47
|
-
return torch.load(f) # nosec
|
47
|
+
return torch.load(f, weights_only=False) # nosec
|
48
48
|
|
49
49
|
def save(self, obj: Any) -> None:
|
50
50
|
"""Uses `torch.save` to save a PyTorch object.
|