zenml-nightly 0.73.0.dev20250130__py3-none-any.whl → 0.73.0.dev20250201__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27):
  1. zenml/VERSION +1 -1
  2. zenml/cli/code_repository.py +26 -0
  3. zenml/client.py +2 -7
  4. zenml/code_repositories/base_code_repository.py +30 -2
  5. zenml/code_repositories/git/local_git_repository_context.py +26 -10
  6. zenml/code_repositories/local_repository_context.py +11 -8
  7. zenml/constants.py +3 -0
  8. zenml/integrations/azure/artifact_stores/azure_artifact_store.py +18 -9
  9. zenml/integrations/azure/service_connectors/azure_service_connector.py +146 -29
  10. zenml/integrations/github/code_repositories/github_code_repository.py +17 -2
  11. zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py +17 -2
  12. zenml/integrations/pytorch/materializers/base_pytorch_materializer.py +1 -1
  13. zenml/pipelines/build_utils.py +42 -35
  14. zenml/pipelines/pipeline_definition.py +5 -2
  15. zenml/utils/code_repository_utils.py +11 -2
  16. zenml/utils/downloaded_repository_context.py +3 -5
  17. zenml/utils/source_utils.py +3 -3
  18. zenml/zen_stores/migrations/utils.py +48 -1
  19. zenml/zen_stores/migrations/versions/4d5524b92a30_add_run_metadata_tag_index.py +67 -0
  20. zenml/zen_stores/schemas/run_metadata_schemas.py +15 -2
  21. zenml/zen_stores/schemas/schema_utils.py +34 -2
  22. zenml/zen_stores/schemas/tag_schemas.py +14 -1
  23. {zenml_nightly-0.73.0.dev20250130.dist-info → zenml_nightly-0.73.0.dev20250201.dist-info}/METADATA +2 -1
  24. {zenml_nightly-0.73.0.dev20250130.dist-info → zenml_nightly-0.73.0.dev20250201.dist-info}/RECORD +27 -26
  25. {zenml_nightly-0.73.0.dev20250130.dist-info → zenml_nightly-0.73.0.dev20250201.dist-info}/LICENSE +0 -0
  26. {zenml_nightly-0.73.0.dev20250130.dist-info → zenml_nightly-0.73.0.dev20250201.dist-info}/WHEEL +0 -0
  27. {zenml_nightly-0.73.0.dev20250130.dist-info → zenml_nightly-0.73.0.dev20250201.dist-info}/entry_points.txt +0 -0
zenml/VERSION CHANGED
@@ -1 +1 @@
1
- 0.73.0.dev20250130
1
+ 0.73.0.dev20250201
@@ -162,6 +162,32 @@ def register_code_repository(
162
162
  cli_utils.declare(f"Successfully registered code repository `{name}`.")
163
163
 
164
164
 
165
+ @code_repository.command("describe", help="Describe a code repository.")
166
+ @click.argument(
167
+ "name_id_or_prefix",
168
+ type=str,
169
+ required=True,
170
+ )
171
+ def describe_code_repository(name_id_or_prefix: str) -> None:
172
+ """Describe a code repository.
173
+
174
+ Args:
175
+ name_id_or_prefix: Name, ID or prefix of the code repository.
176
+ """
177
+ client = Client()
178
+ try:
179
+ code_repository = client.get_code_repository(
180
+ name_id_or_prefix=name_id_or_prefix,
181
+ )
182
+ except KeyError as err:
183
+ cli_utils.error(str(err))
184
+ else:
185
+ cli_utils.print_pydantic_model(
186
+ title=f"Code repository '{code_repository.name}'",
187
+ model=code_repository,
188
+ )
189
+
190
+
165
191
  @code_repository.command("list", help="List all connected code repositories.")
166
192
  @list_options(CodeRepositoryFilter)
167
193
  def list_code_repositories(**kwargs: Any) -> None:
zenml/client.py CHANGED
@@ -34,7 +34,7 @@ from typing import (
34
34
  Union,
35
35
  cast,
36
36
  )
37
- from uuid import UUID, uuid4
37
+ from uuid import UUID
38
38
 
39
39
  from pydantic import ConfigDict, SecretStr
40
40
 
@@ -4980,12 +4980,7 @@ class Client(metaclass=ClientMetaClass):
4980
4980
  )
4981
4981
  )
4982
4982
  try:
4983
- # This does a login to verify the credentials
4984
- code_repo_class(id=uuid4(), config=config)
4985
-
4986
- # Explicitly access the config for pydantic validation, in case
4987
- # the login for some reason did not do that.
4988
- _ = code_repo_class.config
4983
+ code_repo_class.validate_config(config)
4989
4984
  except Exception as e:
4990
4985
  raise RuntimeError(
4991
4986
  "Failed to validate code repository config."
@@ -15,7 +15,7 @@
15
15
 
16
16
  from abc import ABC, abstractmethod
17
17
  from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type
18
- from uuid import UUID
18
+ from uuid import UUID, uuid4
19
19
 
20
20
  from zenml.config.secret_reference_mixin import SecretReferenceMixin
21
21
  from zenml.logger import get_logger
@@ -44,15 +44,18 @@ class BaseCodeRepository(ABC):
44
44
  def __init__(
45
45
  self,
46
46
  id: UUID,
47
+ name: str,
47
48
  config: Dict[str, Any],
48
49
  ) -> None:
49
50
  """Initializes a code repository.
50
51
 
51
52
  Args:
52
53
  id: The ID of the code repository.
54
+ name: The name of the code repository.
53
55
  config: The config of the code repository.
54
56
  """
55
57
  self._id = id
58
+ self._name = name
56
59
  self._config = config
57
60
  self.login()
58
61
 
@@ -80,7 +83,23 @@ class BaseCodeRepository(ABC):
80
83
  source=model.source, expected_class=BaseCodeRepository
81
84
  )
82
85
  )
83
- return class_(id=model.id, config=model.config)
86
+ return class_(id=model.id, name=model.name, config=model.config)
87
+
88
+ @classmethod
89
+ def validate_config(cls, config: Dict[str, Any]) -> None:
90
+ """Validate the code repository config.
91
+
92
+ This method should check that the config/credentials are valid and
93
+ the configured repository exists.
94
+
95
+ Args:
96
+ config: The configuration.
97
+ """
98
+ # The initialization calls the login to verify the credentials
99
+ code_repo = cls(id=uuid4(), name="", config=config)
100
+
101
+ # Explicitly access the config for pydantic validation
102
+ _ = code_repo.config
84
103
 
85
104
  @property
86
105
  def id(self) -> UUID:
@@ -91,6 +110,15 @@ class BaseCodeRepository(ABC):
91
110
  """
92
111
  return self._id
93
112
 
113
+ @property
114
+ def name(self) -> str:
115
+ """Name of the code repository.
116
+
117
+ Returns:
118
+ The name of the code repository.
119
+ """
120
+ return self._name
121
+
94
122
  @property
95
123
  def requirements(self) -> Set[str]:
96
124
  """Set of PyPI requirements for the repository.
@@ -14,11 +14,14 @@
14
14
  """Implementation of the Local git repository context."""
15
15
 
16
16
  from typing import TYPE_CHECKING, Callable, Optional, cast
17
- from uuid import UUID
18
17
 
19
18
  from zenml.code_repositories import (
20
19
  LocalRepositoryContext,
21
20
  )
21
+ from zenml.constants import (
22
+ ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES,
23
+ handle_bool_env_var,
24
+ )
22
25
  from zenml.logger import get_logger
23
26
 
24
27
  if TYPE_CHECKING:
@@ -26,6 +29,8 @@ if TYPE_CHECKING:
26
29
  from git.remote import Remote
27
30
  from git.repo.base import Repo
28
31
 
32
+ from zenml.code_repositories import BaseCodeRepository
33
+
29
34
  logger = get_logger(__name__)
30
35
 
31
36
 
@@ -33,16 +38,19 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
33
38
  """Local git repository context."""
34
39
 
35
40
  def __init__(
36
- self, code_repository_id: UUID, git_repo: "Repo", remote_name: str
41
+ self,
42
+ code_repository: "BaseCodeRepository",
43
+ git_repo: "Repo",
44
+ remote_name: str,
37
45
  ):
38
46
  """Initializes a local git repository context.
39
47
 
40
48
  Args:
41
- code_repository_id: The ID of the code repository.
49
+ code_repository: The code repository.
42
50
  git_repo: The git repo.
43
51
  remote_name: Name of the remote.
44
52
  """
45
- super().__init__(code_repository_id=code_repository_id)
53
+ super().__init__(code_repository=code_repository)
46
54
  self._git_repo = git_repo
47
55
  self._remote = git_repo.remote(name=remote_name)
48
56
 
@@ -50,14 +58,14 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
50
58
  def at(
51
59
  cls,
52
60
  path: str,
53
- code_repository_id: UUID,
61
+ code_repository: "BaseCodeRepository",
54
62
  remote_url_validation_callback: Callable[[str], bool],
55
63
  ) -> Optional["LocalGitRepositoryContext"]:
56
64
  """Returns a local git repository at the given path.
57
65
 
58
66
  Args:
59
67
  path: The path to the local git repository.
60
- code_repository_id: The ID of the code repository.
68
+ code_repository: The code repository.
61
69
  remote_url_validation_callback: A callback that validates the
62
70
  remote URL of the git repository.
63
71
 
@@ -70,11 +78,13 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
70
78
  from git.exc import InvalidGitRepositoryError
71
79
  from git.repo.base import Repo
72
80
  except ImportError:
81
+ logger.debug("Failed to import git library.")
73
82
  return None
74
83
 
75
84
  try:
76
85
  git_repo = Repo(path=path, search_parent_directories=True)
77
86
  except InvalidGitRepositoryError:
87
+ logger.debug("No git repository exists at path %s.", path)
78
88
  return None
79
89
 
80
90
  remote_name = None
@@ -87,7 +97,7 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
87
97
  return None
88
98
 
89
99
  return LocalGitRepositoryContext(
90
- code_repository_id=code_repository_id,
100
+ code_repository=code_repository,
91
101
  git_repo=git_repo,
92
102
  remote_name=remote_name,
93
103
  )
@@ -124,13 +134,19 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
124
134
  def is_dirty(self) -> bool:
125
135
  """Whether the git repo is dirty.
126
136
 
127
- A repository counts as dirty if it has any untracked or uncommitted
128
- changes.
137
+ By default, a repository counts as dirty if it has any untracked or
138
+ uncommitted changes. Users can use an environment variable to ignore
139
+ untracked files.
129
140
 
130
141
  Returns:
131
142
  True if the git repo is dirty, False otherwise.
132
143
  """
133
- return self.git_repo.is_dirty(untracked_files=True)
144
+ ignore_untracked_files = handle_bool_env_var(
145
+ ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES, default=False
146
+ )
147
+ return self.git_repo.is_dirty(
148
+ untracked_files=not ignore_untracked_files
149
+ )
134
150
 
135
151
  @property
136
152
  def has_local_changes(self) -> bool:
@@ -14,10 +14,13 @@
14
14
  """Base class for local code repository contexts."""
15
15
 
16
16
  from abc import ABC, abstractmethod
17
- from uuid import UUID
17
+ from typing import TYPE_CHECKING
18
18
 
19
19
  from zenml.logger import get_logger
20
20
 
21
+ if TYPE_CHECKING:
22
+ from zenml.code_repositories import BaseCodeRepository
23
+
21
24
  logger = get_logger(__name__)
22
25
 
23
26
 
@@ -30,22 +33,22 @@ class LocalRepositoryContext(ABC):
30
33
  commit, and whether the repository is dirty.
31
34
  """
32
35
 
33
- def __init__(self, code_repository_id: UUID) -> None:
36
+ def __init__(self, code_repository: "BaseCodeRepository") -> None:
34
37
  """Initializes a local repository context.
35
38
 
36
39
  Args:
37
- code_repository_id: The ID of the code repository.
40
+ code_repository: The code repository.
38
41
  """
39
- self._code_repository_id = code_repository_id
42
+ self._code_repository = code_repository
40
43
 
41
44
  @property
42
- def code_repository_id(self) -> UUID:
43
- """Returns the ID of the code repository.
45
+ def code_repository(self) -> "BaseCodeRepository":
46
+ """Returns the code repository.
44
47
 
45
48
  Returns:
46
- The ID of the code repository.
49
+ The code repository.
47
50
  """
48
- return self._code_repository_id
51
+ return self._code_repository
49
52
 
50
53
  @property
51
54
  @abstractmethod
zenml/constants.py CHANGED
@@ -175,6 +175,9 @@ ENV_ZENML_WHEEL_PACKAGE_NAME = "ZENML_WHEEL_PACKAGE_NAME"
175
175
  ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION = (
176
176
  "ZENML_PIPELINE_API_TOKEN_EXPIRATION"
177
177
  )
178
+ ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES = (
179
+ "ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES"
180
+ )
178
181
 
179
182
  # Materializer environment variables
180
183
  ENV_ZENML_MATERIALIZER_ALLOW_NON_ASCII_JSON_DUMPS = (
@@ -64,7 +64,10 @@ class AzureArtifactStore(BaseArtifactStore, AuthenticationMixin):
64
64
  """
65
65
  connector = self.get_connector()
66
66
  if connector:
67
- from azure.identity import ClientSecretCredential
67
+ from azure.identity import (
68
+ ClientSecretCredential,
69
+ DefaultAzureCredential,
70
+ )
68
71
  from azure.storage.blob import BlobServiceClient
69
72
 
70
73
  client = connector.connect()
@@ -77,18 +80,24 @@ class AzureArtifactStore(BaseArtifactStore, AuthenticationMixin):
77
80
  )
78
81
  # Get the credentials from the client
79
82
  credentials = client.credential
80
- if not isinstance(credentials, ClientSecretCredential):
83
+
84
+ if isinstance(credentials, ClientSecretCredential):
85
+ return AzureSecretSchema(
86
+ client_id=credentials._client_id,
87
+ client_secret=credentials._client_credential,
88
+ tenant_id=credentials._tenant_id,
89
+ account_name=client.account_name,
90
+ )
91
+
92
+ elif isinstance(credentials, DefaultAzureCredential):
93
+ return AzureSecretSchema(account_name=client.account_name)
94
+
95
+ else:
81
96
  raise RuntimeError(
82
97
  "The Azure Artifact Store connector can only be used "
83
98
  "with a service connector that is configured with "
84
- "Azure service principal credentials."
99
+ "Azure service principal credentials or implicit authentication"
85
100
  )
86
- return AzureSecretSchema(
87
- client_id=credentials._client_id,
88
- client_secret=credentials._client_credential,
89
- tenant_id=credentials._tenant_id,
90
- account_name=client.account_name,
91
- )
92
101
 
93
102
  secret = self.get_typed_authentication_secret(
94
103
  expected_schema_type=AzureSecretSchema
@@ -14,12 +14,14 @@
14
14
  """Azure Service Connector."""
15
15
 
16
16
  import datetime
17
+ import json
17
18
  import logging
18
19
  import re
19
20
  import subprocess
20
21
  from typing import Any, Dict, List, Optional, Tuple
21
22
  from uuid import UUID
22
23
 
24
+ import requests
23
25
  import yaml
24
26
  from azure.core.credentials import AccessToken, TokenCredential
25
27
  from azure.core.exceptions import AzureError
@@ -69,6 +71,7 @@ logger = get_logger(__name__)
69
71
  AZURE_MANAGEMENT_TOKEN_SCOPE = "https://management.azure.com/.default"
70
72
  AZURE_SESSION_TOKEN_DEFAULT_EXPIRATION_TIME = 60 * 60 # 1 hour
71
73
  AZURE_SESSION_EXPIRATION_BUFFER = 15 # 15 minutes
74
+ AZURE_ACR_OAUTH_SCOPE = "repository:*:*"
72
75
 
73
76
 
74
77
  class AzureBaseConfig(AuthenticationConfig):
@@ -362,8 +365,8 @@ neither a storage account nor a resource group is configured in the connector,
362
365
  all blob storage containers in all accessible storage accounts will be
363
366
  accessible.
364
367
 
365
- The only Azure authentication method that works with Azure blob storage
366
- resources is the service principal authentication method.
368
+ The only Azure authentication methods that work with Azure blob storage resources are the implicit
369
+ authentication and the service principal authentication method.
367
370
  """,
368
371
  auth_methods=AzureAuthenticationMethods.values(),
369
372
  # Request a blob container to be configured in the
@@ -435,12 +438,10 @@ following formats:
435
438
  If a resource group is configured in the connector, only ACR registries in that
436
439
  resource group will be accessible.
437
440
 
438
- If an authentication method other than the Azure service principal is used for
439
- authentication, the admin account must be enabled for the registry, otherwise
440
- clients will not be able to authenticate to the registry. See the official Azure
441
- [documentation on the admin account](https://docs.microsoft.com/en-us/azure/container-registry/container-registry-authentication#admin-account
442
- )
443
- for more information.
441
+ If an authentication method other than the Azure service principal is used, Entra ID authentication is used.
442
+ This requires the configured identity to have the `AcrPush` role to be configured.
443
+ If this fails, admin account authentication is tried. For this the admin account must be enabled for the registry.
444
+ See the official Azure[documentation on the admin account](https://docs.microsoft.com/en-us/azure/container-registry/container-registry-authentication#admin-account) for more information.
444
445
  """,
445
446
  auth_methods=AzureAuthenticationMethods.values(),
446
447
  supports_instances=True,
@@ -1718,10 +1719,14 @@ class AzureServiceConnector(ServiceConnector):
1718
1719
  resource_type=resource_type,
1719
1720
  resource_id=resource_id,
1720
1721
  )
1721
- resource_group: Optional[str] = None
1722
+
1723
+ resource_group: Optional[str]
1724
+ registry_name: str
1725
+ cluster_name: str
1722
1726
 
1723
1727
  if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
1724
1728
  registry_name = self._parse_acr_resource_id(resource_id)
1729
+ registry_domain = f"{registry_name}.azurecr.io"
1725
1730
 
1726
1731
  # If a service principal is used for authentication, the client ID
1727
1732
  # and client secret can be used to authenticate to the registry, if
@@ -1733,13 +1738,12 @@ class AzureServiceConnector(ServiceConnector):
1733
1738
  ):
1734
1739
  assert isinstance(self.config, AzureServicePrincipalConfig)
1735
1740
  username = str(self.config.client_id)
1736
- password = self.config.client_secret
1741
+ password = str(self.config.client_secret)
1737
1742
 
1738
- # Without a service principal, this only works if the admin account
1739
- # is enabled for the registry, but this is not recommended and
1740
- # disabled by default.
1743
+ # Without a service principal, we try to use the AzureDefaultCredentials to authenticate against the ACR.
1744
+ # If this fails, we try to use the admin account.
1745
+ # This has to be enabled for the registry, but this is not recommended and disabled by default.
1741
1746
  # https://docs.microsoft.com/en-us/azure/container-registry/container-registry-authentication#admin-account
1742
-
1743
1747
  else:
1744
1748
  registries = self._list_acr_registries(
1745
1749
  credential,
@@ -1748,22 +1752,35 @@ class AzureServiceConnector(ServiceConnector):
1748
1752
  registry_name, resource_group = registries.popitem()
1749
1753
 
1750
1754
  try:
1751
- client = ContainerRegistryManagementClient(
1752
- credential, subscription_id
1755
+ username = "00000000-0000-0000-0000-000000000000"
1756
+ password = _ACRTokenExchangeClient(
1757
+ credential
1758
+ ).get_acr_access_token(
1759
+ registry_domain, AZURE_ACR_OAUTH_SCOPE
1753
1760
  )
1754
-
1755
- registry_credentials = client.registries.list_credentials(
1756
- resource_group, registry_name
1761
+ except AuthorizationException:
1762
+ logger.warning(
1763
+ "Falling back to admin credentials for ACR authentication. Be sure to assign AcrPush role to the configured identity."
1757
1764
  )
1758
- username = registry_credentials.username
1759
- password = registry_credentials.passwords[0].value
1760
- except AzureError as e:
1761
- raise AuthorizationException(
1762
- f"failed to list admin credentials for Azure Container "
1763
- f"Registry '{registry_name}' in resource group "
1764
- f"'{resource_group}'. Make sure the registry is "
1765
- f"configured with an admin account: {e}"
1766
- ) from e
1765
+ try:
1766
+ client = ContainerRegistryManagementClient(
1767
+ credential, subscription_id
1768
+ )
1769
+
1770
+ registry_credentials = (
1771
+ client.registries.list_credentials(
1772
+ resource_group, registry_name
1773
+ )
1774
+ )
1775
+ username = registry_credentials.username
1776
+ password = registry_credentials.passwords[0].value
1777
+ except AzureError as e:
1778
+ raise AuthorizationException(
1779
+ f"failed to list admin credentials for Azure Container "
1780
+ f"Registry '{registry_name}' in resource group "
1781
+ f"'{resource_group}'. Make sure the registry is "
1782
+ f"configured with an admin account: {e}"
1783
+ ) from e
1767
1784
 
1768
1785
  # Create a client-side Docker connector instance with the temporary
1769
1786
  # Docker credentials
@@ -1775,7 +1792,7 @@ class AzureServiceConnector(ServiceConnector):
1775
1792
  config=DockerConfiguration(
1776
1793
  username=username,
1777
1794
  password=password,
1778
- registry=f"{registry_name}.azurecr.io",
1795
+ registry=registry_domain,
1779
1796
  ),
1780
1797
  expires_at=expires_at,
1781
1798
  )
@@ -1847,3 +1864,103 @@ class AzureServiceConnector(ServiceConnector):
1847
1864
  )
1848
1865
 
1849
1866
  raise ValueError(f"Unsupported resource type: {resource_type}")
1867
+
1868
+
1869
+ class _ACRTokenExchangeClient:
1870
+ def __init__(self, credential: TokenCredential):
1871
+ self._credential = credential
1872
+
1873
+ def _get_aad_access_token(self) -> str:
1874
+ aad_access_token: str = self._credential.get_token(
1875
+ AZURE_MANAGEMENT_TOKEN_SCOPE
1876
+ ).token
1877
+ return aad_access_token
1878
+
1879
+ # https://github.com/Azure/acr/blob/main/docs/AAD-OAuth.md#authenticating-docker-with-an-acr-refresh-token
1880
+ def get_acr_refresh_token(self, acr_url: str) -> str:
1881
+ try:
1882
+ aad_access_token = self._get_aad_access_token()
1883
+
1884
+ headers = {
1885
+ "Content-Type": "application/x-www-form-urlencoded",
1886
+ }
1887
+
1888
+ data = {
1889
+ "grant_type": "access_token",
1890
+ "service": acr_url,
1891
+ "access_token": aad_access_token,
1892
+ }
1893
+
1894
+ response = requests.post(
1895
+ f"https://{acr_url}/oauth2/exchange",
1896
+ headers=headers,
1897
+ data=data,
1898
+ timeout=5,
1899
+ )
1900
+
1901
+ if response.status_code != 200:
1902
+ raise AuthorizationException(
1903
+ f"failed to get refresh token for Azure Container "
1904
+ f"Registry '{acr_url}' in resource group. "
1905
+ f"Be sure to assign AcrPush to the configured principal. "
1906
+ f"The token exchange returned status {response.status_code} "
1907
+ f"with body '{response.content.decode()}'"
1908
+ )
1909
+
1910
+ acr_refresh_token_response = json.loads(response.content)
1911
+ acr_refresh_token: str = acr_refresh_token_response[
1912
+ "refresh_token"
1913
+ ]
1914
+ return acr_refresh_token
1915
+
1916
+ except (AzureError, requests.exceptions.RequestException) as e:
1917
+ raise AuthorizationException(
1918
+ f"failed to get refresh token for Azure Container "
1919
+ f"Registry '{acr_url}' in resource group. "
1920
+ f"Make sure the implicit authentication identity "
1921
+ f"has access to the configured registry: {e}"
1922
+ ) from e
1923
+
1924
+ # https://github.com/Azure/acr/blob/main/docs/AAD-OAuth.md#calling-post-oauth2token-to-get-an-acr-access-token
1925
+ def get_acr_access_token(self, acr_url: str, scope: str) -> str:
1926
+ acr_refresh_token = self.get_acr_refresh_token(acr_url)
1927
+
1928
+ headers = {
1929
+ "Content-Type": "application/x-www-form-urlencoded",
1930
+ }
1931
+
1932
+ data = {
1933
+ "grant_type": "refresh_token",
1934
+ "service": acr_url,
1935
+ "scope": scope,
1936
+ "refresh_token": acr_refresh_token,
1937
+ }
1938
+
1939
+ try:
1940
+ response = requests.post(
1941
+ f"https://{acr_url}/oauth2/token",
1942
+ headers=headers,
1943
+ data=data,
1944
+ timeout=5,
1945
+ )
1946
+
1947
+ if response.status_code != 200:
1948
+ raise AuthorizationException(
1949
+ f"failed to get access token for Azure Container "
1950
+ f"Registry '{acr_url}' in resource group. "
1951
+ f"Be sure to assign AcrPush to the configured principal. "
1952
+ f"The token exchange returned status {response.status_code} "
1953
+ f"with body '{response.content.decode()}'"
1954
+ )
1955
+
1956
+ acr_access_token_response = json.loads(response.content)
1957
+ acr_access_token: str = acr_access_token_response["access_token"]
1958
+ return acr_access_token
1959
+
1960
+ except (AzureError, requests.exceptions.RequestException) as e:
1961
+ raise AuthorizationException(
1962
+ f"failed to get access token for Azure Container "
1963
+ f"Registry '{acr_url}' in resource group. "
1964
+ f"Make sure the implicit authentication identity "
1965
+ f"has access to the configured registry: {e}"
1966
+ ) from e
@@ -15,7 +15,8 @@
15
15
 
16
16
  import os
17
17
  import re
18
- from typing import List, Optional
18
+ from typing import Any, Dict, List, Optional
19
+ from uuid import uuid4
19
20
 
20
21
  import requests
21
22
  from github import Consts, Github, GithubException
@@ -62,6 +63,20 @@ class GitHubCodeRepositoryConfig(BaseCodeRepositoryConfig):
62
63
  class GitHubCodeRepository(BaseCodeRepository):
63
64
  """GitHub code repository."""
64
65
 
66
+ @classmethod
67
+ def validate_config(cls, config: Dict[str, Any]) -> None:
68
+ """Validate the code repository config.
69
+
70
+ This method should check that the config/credentials are valid and
71
+ the configured repository exists.
72
+
73
+ Args:
74
+ config: The configuration.
75
+ """
76
+ code_repo = cls(id=uuid4(), name="", config=config)
77
+ # Try to access the project to make sure it exists
78
+ _ = code_repo.github_repo
79
+
65
80
  @property
66
81
  def config(self) -> GitHubCodeRepositoryConfig:
67
82
  """Returns the `GitHubCodeRepositoryConfig` config.
@@ -190,7 +205,7 @@ class GitHubCodeRepository(BaseCodeRepository):
190
205
  """
191
206
  return LocalGitRepositoryContext.at(
192
207
  path=path,
193
- code_repository_id=self.id,
208
+ code_repository=self,
194
209
  remote_url_validation_callback=self.check_remote_url,
195
210
  )
196
211
 
@@ -15,7 +15,8 @@
15
15
 
16
16
  import os
17
17
  import re
18
- from typing import Optional
18
+ from typing import Any, Dict, Optional
19
+ from uuid import uuid4
19
20
 
20
21
  from gitlab import Gitlab
21
22
  from gitlab.v4.objects import Project
@@ -63,6 +64,20 @@ class GitLabCodeRepositoryConfig(BaseCodeRepositoryConfig):
63
64
  class GitLabCodeRepository(BaseCodeRepository):
64
65
  """GitLab code repository."""
65
66
 
67
+ @classmethod
68
+ def validate_config(cls, config: Dict[str, Any]) -> None:
69
+ """Validate the code repository config.
70
+
71
+ This method should check that the config/credentials are valid and
72
+ the configured repository exists.
73
+
74
+ Args:
75
+ config: The configuration.
76
+ """
77
+ code_repo = cls(id=uuid4(), name="", config=config)
78
+ # Try to access the project to make sure it exists
79
+ _ = code_repo.gitlab_project
80
+
66
81
  @property
67
82
  def config(self) -> GitLabCodeRepositoryConfig:
68
83
  """Returns the `GitLabCodeRepositoryConfig` config.
@@ -147,7 +162,7 @@ class GitLabCodeRepository(BaseCodeRepository):
147
162
  """
148
163
  return LocalGitRepositoryContext.at(
149
164
  path=path,
150
- code_repository_id=self.id,
165
+ code_repository=self,
151
166
  remote_url_validation_callback=self.check_remote_url,
152
167
  )
153
168
 
@@ -44,7 +44,7 @@ class BasePyTorchMaterializer(BaseMaterializer):
44
44
  # NOTE (security): The `torch.load` function uses `pickle` as
45
45
  # the default unpickler, which is NOT secure. This materializer
46
46
  # is intended for use with trusted data sources.
47
- return torch.load(f) # nosec
47
+ return torch.load(f, weights_only=False) # nosec
48
48
 
49
49
  def save(self, obj: Any) -> None:
50
50
  """Uses `torch.save` to save a PyTorch object.