zenml-nightly 0.73.0.dev20250129__py3-none-any.whl → 0.73.0.dev20250131__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. zenml/VERSION +1 -1
  2. zenml/cli/code_repository.py +26 -0
  3. zenml/cli/utils.py +14 -9
  4. zenml/client.py +2 -7
  5. zenml/code_repositories/base_code_repository.py +30 -2
  6. zenml/code_repositories/git/local_git_repository_context.py +26 -10
  7. zenml/code_repositories/local_repository_context.py +11 -8
  8. zenml/constants.py +3 -0
  9. zenml/integrations/gcp/constants.py +1 -1
  10. zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +3 -1
  11. zenml/integrations/gcp/step_operators/vertex_step_operator.py +1 -0
  12. zenml/integrations/github/code_repositories/github_code_repository.py +17 -2
  13. zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py +17 -2
  14. zenml/integrations/huggingface/services/huggingface_deployment.py +72 -29
  15. zenml/integrations/pytorch/materializers/base_pytorch_materializer.py +1 -1
  16. zenml/integrations/vllm/services/vllm_deployment.py +6 -1
  17. zenml/pipelines/build_utils.py +42 -35
  18. zenml/pipelines/pipeline_definition.py +5 -2
  19. zenml/utils/code_repository_utils.py +11 -2
  20. zenml/utils/downloaded_repository_context.py +3 -5
  21. zenml/utils/source_utils.py +3 -3
  22. zenml/zen_stores/migrations/utils.py +48 -1
  23. zenml/zen_stores/migrations/versions/4d5524b92a30_add_run_metadata_tag_index.py +67 -0
  24. zenml/zen_stores/rest_zen_store.py +3 -13
  25. zenml/zen_stores/schemas/run_metadata_schemas.py +15 -2
  26. zenml/zen_stores/schemas/schema_utils.py +34 -2
  27. zenml/zen_stores/schemas/tag_schemas.py +14 -1
  28. zenml/zen_stores/secrets_stores/sql_secrets_store.py +5 -2
  29. zenml/zen_stores/sql_zen_store.py +24 -17
  30. {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/METADATA +1 -1
  31. {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/RECORD +34 -33
  32. {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/LICENSE +0 -0
  33. {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/WHEEL +0 -0
  34. {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/entry_points.txt +0 -0
zenml/VERSION CHANGED
@@ -1 +1 @@
1
- 0.73.0.dev20250129
1
+ 0.73.0.dev20250131
zenml/cli/code_repository.py CHANGED
@@ -162,6 +162,32 @@ def register_code_repository(
162
162
  cli_utils.declare(f"Successfully registered code repository `{name}`.")
163
163
 
164
164
 
165
+ @code_repository.command("describe", help="Describe a code repository.")
166
+ @click.argument(
167
+ "name_id_or_prefix",
168
+ type=str,
169
+ required=True,
170
+ )
171
+ def describe_code_repository(name_id_or_prefix: str) -> None:
172
+ """Describe a code repository.
173
+
174
+ Args:
175
+ name_id_or_prefix: Name, ID or prefix of the code repository.
176
+ """
177
+ client = Client()
178
+ try:
179
+ code_repository = client.get_code_repository(
180
+ name_id_or_prefix=name_id_or_prefix,
181
+ )
182
+ except KeyError as err:
183
+ cli_utils.error(str(err))
184
+ else:
185
+ cli_utils.print_pydantic_model(
186
+ title=f"Code repository '{code_repository.name}'",
187
+ model=code_repository,
188
+ )
189
+
190
+
165
191
  @code_repository.command("list", help="List all connected code repositories.")
166
192
  @list_options(CodeRepositoryFilter)
167
193
  def list_code_repositories(**kwargs: Any) -> None:
zenml/cli/utils.py CHANGED
@@ -76,6 +76,7 @@ from zenml.models import (
76
76
  from zenml.models.v2.base.filter import FilterGenerator
77
77
  from zenml.services import BaseService, ServiceState
78
78
  from zenml.stack import StackComponent
79
+ from zenml.stack.flavor import Flavor
79
80
  from zenml.stack.stack_component import StackComponentConfig
80
81
  from zenml.utils import secret_utils
81
82
  from zenml.utils.time_utils import expires_in
@@ -2151,10 +2152,11 @@ def _scrub_secret(config: StackComponentConfig) -> Dict[str, Any]:
2151
2152
  config_dict = {}
2152
2153
  config_fields = config.__class__.model_fields
2153
2154
  for key, value in config_fields.items():
2154
- if secret_utils.is_secret_field(value):
2155
- config_dict[key] = "********"
2156
- else:
2157
- config_dict[key] = getattr(config, key)
2155
+ if getattr(config, key):
2156
+ if secret_utils.is_secret_field(value):
2157
+ config_dict[key] = "********"
2158
+ else:
2159
+ config_dict[key] = getattr(config, key)
2158
2160
  return config_dict
2159
2161
 
2160
2162
 
@@ -2164,8 +2166,6 @@ def print_debug_stack() -> None:
2164
2166
 
2165
2167
  client = Client()
2166
2168
  stack = client.get_stack()
2167
- active_stack = client.active_stack
2168
- components = _get_stack_components(active_stack)
2169
2169
 
2170
2170
  declare("\nCURRENT STACK\n", bold=True)
2171
2171
  console.print(f"Name: {stack.name}")
@@ -2176,7 +2176,8 @@ def print_debug_stack() -> None:
2176
2176
  f"Workspace: {stack.workspace.name} / {str(stack.workspace.id)}"
2177
2177
  )
2178
2178
 
2179
- for component in components:
2179
+ for component_type, components in stack.components.items():
2180
+ component = components[0]
2180
2181
  component_response = client.get_stack_component(
2181
2182
  name_id_or_prefix=component.id, component_type=component.type
2182
2183
  )
@@ -2186,8 +2187,12 @@ def print_debug_stack() -> None:
2186
2187
  console.print(f"Name: {component.name}")
2187
2188
  console.print(f"ID: {str(component.id)}")
2188
2189
  console.print(f"Type: {component.type.value}")
2189
- console.print(f"Flavor: {component.flavor}")
2190
- console.print(f"Configuration: {_scrub_secret(component.config)}")
2190
+ console.print(f"Flavor: {component.flavor_name}")
2191
+
2192
+ flavor = Flavor.from_model(component.flavor)
2193
+ config = flavor.config_class(**component.configuration)
2194
+
2195
+ console.print(f"Configuration: {_scrub_secret(config)}")
2191
2196
  if (
2192
2197
  component_response.user
2193
2198
  and component_response.user.name
zenml/client.py CHANGED
@@ -34,7 +34,7 @@ from typing import (
34
34
  Union,
35
35
  cast,
36
36
  )
37
- from uuid import UUID, uuid4
37
+ from uuid import UUID
38
38
 
39
39
  from pydantic import ConfigDict, SecretStr
40
40
 
@@ -4980,12 +4980,7 @@ class Client(metaclass=ClientMetaClass):
4980
4980
  )
4981
4981
  )
4982
4982
  try:
4983
- # This does a login to verify the credentials
4984
- code_repo_class(id=uuid4(), config=config)
4985
-
4986
- # Explicitly access the config for pydantic validation, in case
4987
- # the login for some reason did not do that.
4988
- _ = code_repo_class.config
4983
+ code_repo_class.validate_config(config)
4989
4984
  except Exception as e:
4990
4985
  raise RuntimeError(
4991
4986
  "Failed to validate code repository config."
zenml/code_repositories/base_code_repository.py CHANGED
@@ -15,7 +15,7 @@
15
15
 
16
16
  from abc import ABC, abstractmethod
17
17
  from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type
18
- from uuid import UUID
18
+ from uuid import UUID, uuid4
19
19
 
20
20
  from zenml.config.secret_reference_mixin import SecretReferenceMixin
21
21
  from zenml.logger import get_logger
@@ -44,15 +44,18 @@ class BaseCodeRepository(ABC):
44
44
  def __init__(
45
45
  self,
46
46
  id: UUID,
47
+ name: str,
47
48
  config: Dict[str, Any],
48
49
  ) -> None:
49
50
  """Initializes a code repository.
50
51
 
51
52
  Args:
52
53
  id: The ID of the code repository.
54
+ name: The name of the code repository.
53
55
  config: The config of the code repository.
54
56
  """
55
57
  self._id = id
58
+ self._name = name
56
59
  self._config = config
57
60
  self.login()
58
61
 
@@ -80,7 +83,23 @@ class BaseCodeRepository(ABC):
80
83
  source=model.source, expected_class=BaseCodeRepository
81
84
  )
82
85
  )
83
- return class_(id=model.id, config=model.config)
86
+ return class_(id=model.id, name=model.name, config=model.config)
87
+
88
+ @classmethod
89
+ def validate_config(cls, config: Dict[str, Any]) -> None:
90
+ """Validate the code repository config.
91
+
92
+ This method should check that the config/credentials are valid and
93
+ the configured repository exists.
94
+
95
+ Args:
96
+ config: The configuration.
97
+ """
98
+ # The initialization calls the login to verify the credentials
99
+ code_repo = cls(id=uuid4(), name="", config=config)
100
+
101
+ # Explicitly access the config for pydantic validation
102
+ _ = code_repo.config
84
103
 
85
104
  @property
86
105
  def id(self) -> UUID:
@@ -91,6 +110,15 @@ class BaseCodeRepository(ABC):
91
110
  """
92
111
  return self._id
93
112
 
113
+ @property
114
+ def name(self) -> str:
115
+ """Name of the code repository.
116
+
117
+ Returns:
118
+ The name of the code repository.
119
+ """
120
+ return self._name
121
+
94
122
  @property
95
123
  def requirements(self) -> Set[str]:
96
124
  """Set of PyPI requirements for the repository.
zenml/code_repositories/git/local_git_repository_context.py CHANGED
@@ -14,11 +14,14 @@
14
14
  """Implementation of the Local git repository context."""
15
15
 
16
16
  from typing import TYPE_CHECKING, Callable, Optional, cast
17
- from uuid import UUID
18
17
 
19
18
  from zenml.code_repositories import (
20
19
  LocalRepositoryContext,
21
20
  )
21
+ from zenml.constants import (
22
+ ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES,
23
+ handle_bool_env_var,
24
+ )
22
25
  from zenml.logger import get_logger
23
26
 
24
27
  if TYPE_CHECKING:
@@ -26,6 +29,8 @@ if TYPE_CHECKING:
26
29
  from git.remote import Remote
27
30
  from git.repo.base import Repo
28
31
 
32
+ from zenml.code_repositories import BaseCodeRepository
33
+
29
34
  logger = get_logger(__name__)
30
35
 
31
36
 
@@ -33,16 +38,19 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
33
38
  """Local git repository context."""
34
39
 
35
40
  def __init__(
36
- self, code_repository_id: UUID, git_repo: "Repo", remote_name: str
41
+ self,
42
+ code_repository: "BaseCodeRepository",
43
+ git_repo: "Repo",
44
+ remote_name: str,
37
45
  ):
38
46
  """Initializes a local git repository context.
39
47
 
40
48
  Args:
41
- code_repository_id: The ID of the code repository.
49
+ code_repository: The code repository.
42
50
  git_repo: The git repo.
43
51
  remote_name: Name of the remote.
44
52
  """
45
- super().__init__(code_repository_id=code_repository_id)
53
+ super().__init__(code_repository=code_repository)
46
54
  self._git_repo = git_repo
47
55
  self._remote = git_repo.remote(name=remote_name)
48
56
 
@@ -50,14 +58,14 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
50
58
  def at(
51
59
  cls,
52
60
  path: str,
53
- code_repository_id: UUID,
61
+ code_repository: "BaseCodeRepository",
54
62
  remote_url_validation_callback: Callable[[str], bool],
55
63
  ) -> Optional["LocalGitRepositoryContext"]:
56
64
  """Returns a local git repository at the given path.
57
65
 
58
66
  Args:
59
67
  path: The path to the local git repository.
60
- code_repository_id: The ID of the code repository.
68
+ code_repository: The code repository.
61
69
  remote_url_validation_callback: A callback that validates the
62
70
  remote URL of the git repository.
63
71
 
@@ -70,11 +78,13 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
70
78
  from git.exc import InvalidGitRepositoryError
71
79
  from git.repo.base import Repo
72
80
  except ImportError:
81
+ logger.debug("Failed to import git library.")
73
82
  return None
74
83
 
75
84
  try:
76
85
  git_repo = Repo(path=path, search_parent_directories=True)
77
86
  except InvalidGitRepositoryError:
87
+ logger.debug("No git repository exists at path %s.", path)
78
88
  return None
79
89
 
80
90
  remote_name = None
@@ -87,7 +97,7 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
87
97
  return None
88
98
 
89
99
  return LocalGitRepositoryContext(
90
- code_repository_id=code_repository_id,
100
+ code_repository=code_repository,
91
101
  git_repo=git_repo,
92
102
  remote_name=remote_name,
93
103
  )
@@ -124,13 +134,19 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
124
134
  def is_dirty(self) -> bool:
125
135
  """Whether the git repo is dirty.
126
136
 
127
- A repository counts as dirty if it has any untracked or uncommitted
128
- changes.
137
+ By default, a repository counts as dirty if it has any untracked or
138
+ uncommitted changes. Users can use an environment variable to ignore
139
+ untracked files.
129
140
 
130
141
  Returns:
131
142
  True if the git repo is dirty, False otherwise.
132
143
  """
133
- return self.git_repo.is_dirty(untracked_files=True)
144
+ ignore_untracked_files = handle_bool_env_var(
145
+ ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES, default=False
146
+ )
147
+ return self.git_repo.is_dirty(
148
+ untracked_files=not ignore_untracked_files
149
+ )
134
150
 
135
151
  @property
136
152
  def has_local_changes(self) -> bool:
zenml/code_repositories/local_repository_context.py CHANGED
@@ -14,10 +14,13 @@
14
14
  """Base class for local code repository contexts."""
15
15
 
16
16
  from abc import ABC, abstractmethod
17
- from uuid import UUID
17
+ from typing import TYPE_CHECKING
18
18
 
19
19
  from zenml.logger import get_logger
20
20
 
21
+ if TYPE_CHECKING:
22
+ from zenml.code_repositories import BaseCodeRepository
23
+
21
24
  logger = get_logger(__name__)
22
25
 
23
26
 
@@ -30,22 +33,22 @@ class LocalRepositoryContext(ABC):
30
33
  commit, and whether the repository is dirty.
31
34
  """
32
35
 
33
- def __init__(self, code_repository_id: UUID) -> None:
36
+ def __init__(self, code_repository: "BaseCodeRepository") -> None:
34
37
  """Initializes a local repository context.
35
38
 
36
39
  Args:
37
- code_repository_id: The ID of the code repository.
40
+ code_repository: The code repository.
38
41
  """
39
- self._code_repository_id = code_repository_id
42
+ self._code_repository = code_repository
40
43
 
41
44
  @property
42
- def code_repository_id(self) -> UUID:
43
- """Returns the ID of the code repository.
45
+ def code_repository(self) -> "BaseCodeRepository":
46
+ """Returns the code repository.
44
47
 
45
48
  Returns:
46
- The ID of the code repository.
49
+ The code repository.
47
50
  """
48
- return self._code_repository_id
51
+ return self._code_repository
49
52
 
50
53
  @property
51
54
  @abstractmethod
zenml/constants.py CHANGED
@@ -175,6 +175,9 @@ ENV_ZENML_WHEEL_PACKAGE_NAME = "ZENML_WHEEL_PACKAGE_NAME"
175
175
  ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION = (
176
176
  "ZENML_PIPELINE_API_TOKEN_EXPIRATION"
177
177
  )
178
+ ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES = (
179
+ "ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES"
180
+ )
178
181
 
179
182
  # Materializer environment variables
180
183
  ENV_ZENML_MATERIALIZER_ALLOW_NON_ASCII_JSON_DUMPS = (
zenml/integrations/gcp/constants.py CHANGED
@@ -16,7 +16,7 @@
16
16
  from google.cloud.aiplatform_v1.types.job_state import JobState
17
17
 
18
18
  VERTEX_ENDPOINT_SUFFIX = "-aiplatform.googleapis.com"
19
- POLLING_INTERVAL_IN_SECONDS = 30
19
+ POLLING_INTERVAL_IN_SECONDS = 10
20
20
  CONNECTION_ERROR_RETRY_LIMIT = 5
21
21
  _VERTEX_JOB_STATE_SUCCEEDED = JobState.JOB_STATE_SUCCEEDED
22
22
  _VERTEX_JOB_STATE_FAILED = JobState.JOB_STATE_FAILED
zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py CHANGED
@@ -51,7 +51,8 @@ class VertexStepOperatorSettings(BaseSettings):
51
51
  https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
52
52
  boot_disk_type: Type of the boot disk. (Default: pd-ssd)
53
53
  https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
54
-
54
+ persistent_resource_id: The ID of the persistent resource to use for the job.
55
+ https://cloud.google.com/vertex-ai/docs/training/persistent-resource-overview
55
56
  """
56
57
 
57
58
  accelerator_type: Optional[str] = None
@@ -59,6 +60,7 @@ class VertexStepOperatorSettings(BaseSettings):
59
60
  machine_type: str = "n1-standard-4"
60
61
  boot_disk_size_gb: int = 100
61
62
  boot_disk_type: str = "pd-ssd"
63
+ persistent_resource_id: Optional[str] = None
62
64
 
63
65
 
64
66
  class VertexStepOperatorConfig(
zenml/integrations/gcp/step_operators/vertex_step_operator.py CHANGED
@@ -258,6 +258,7 @@ class VertexStepOperator(BaseStepOperator, GoogleCredentialsMixin):
258
258
  if self.config.reserved_ip_ranges
259
259
  else []
260
260
  ),
261
+ "persistent_resource_id": settings.persistent_resource_id,
261
262
  },
262
263
  "labels": job_labels,
263
264
  "encryption_spec": {
zenml/integrations/github/code_repositories/github_code_repository.py CHANGED
@@ -15,7 +15,8 @@
15
15
 
16
16
  import os
17
17
  import re
18
- from typing import List, Optional
18
+ from typing import Any, Dict, List, Optional
19
+ from uuid import uuid4
19
20
 
20
21
  import requests
21
22
  from github import Consts, Github, GithubException
@@ -62,6 +63,20 @@ class GitHubCodeRepositoryConfig(BaseCodeRepositoryConfig):
62
63
  class GitHubCodeRepository(BaseCodeRepository):
63
64
  """GitHub code repository."""
64
65
 
66
+ @classmethod
67
+ def validate_config(cls, config: Dict[str, Any]) -> None:
68
+ """Validate the code repository config.
69
+
70
+ This method should check that the config/credentials are valid and
71
+ the configured repository exists.
72
+
73
+ Args:
74
+ config: The configuration.
75
+ """
76
+ code_repo = cls(id=uuid4(), name="", config=config)
77
+ # Try to access the project to make sure it exists
78
+ _ = code_repo.github_repo
79
+
65
80
  @property
66
81
  def config(self) -> GitHubCodeRepositoryConfig:
67
82
  """Returns the `GitHubCodeRepositoryConfig` config.
@@ -190,7 +205,7 @@ class GitHubCodeRepository(BaseCodeRepository):
190
205
  """
191
206
  return LocalGitRepositoryContext.at(
192
207
  path=path,
193
- code_repository_id=self.id,
208
+ code_repository=self,
194
209
  remote_url_validation_callback=self.check_remote_url,
195
210
  )
196
211
 
zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py CHANGED
@@ -15,7 +15,8 @@
15
15
 
16
16
  import os
17
17
  import re
18
- from typing import Optional
18
+ from typing import Any, Dict, Optional
19
+ from uuid import uuid4
19
20
 
20
21
  from gitlab import Gitlab
21
22
  from gitlab.v4.objects import Project
@@ -63,6 +64,20 @@ class GitLabCodeRepositoryConfig(BaseCodeRepositoryConfig):
63
64
  class GitLabCodeRepository(BaseCodeRepository):
64
65
  """GitLab code repository."""
65
66
 
67
+ @classmethod
68
+ def validate_config(cls, config: Dict[str, Any]) -> None:
69
+ """Validate the code repository config.
70
+
71
+ This method should check that the config/credentials are valid and
72
+ the configured repository exists.
73
+
74
+ Args:
75
+ config: The configuration.
76
+ """
77
+ code_repo = cls(id=uuid4(), name="", config=config)
78
+ # Try to access the project to make sure it exists
79
+ _ = code_repo.gitlab_project
80
+
66
81
  @property
67
82
  def config(self) -> GitLabCodeRepositoryConfig:
68
83
  """Returns the `GitLabCodeRepositoryConfig` config.
@@ -147,7 +162,7 @@ class GitLabCodeRepository(BaseCodeRepository):
147
162
  """
148
163
  return LocalGitRepositoryContext.at(
149
164
  path=path,
150
- code_repository_id=self.id,
165
+ code_repository=self,
151
166
  remote_url_validation_callback=self.check_remote_url,
152
167
  )
153
168
 
zenml/integrations/huggingface/services/huggingface_deployment.py CHANGED
@@ -13,17 +13,18 @@
13
13
  # permissions and limitations under the License.
14
14
  """Implementation of the Hugging Face Deployment service."""
15
15
 
16
- from typing import Any, Generator, Optional, Tuple
16
+ from typing import Any, Dict, Generator, Optional, Tuple
17
17
 
18
18
  from huggingface_hub import (
19
19
  InferenceClient,
20
20
  InferenceEndpoint,
21
21
  InferenceEndpointError,
22
22
  InferenceEndpointStatus,
23
+ InferenceEndpointType,
23
24
  create_inference_endpoint,
24
25
  get_inference_endpoint,
25
26
  )
26
- from huggingface_hub.utils import HfHubHTTPError
27
+ from huggingface_hub.errors import HfHubHTTPError
27
28
  from pydantic import Field
28
29
 
29
30
  from zenml.client import Client
@@ -138,30 +139,67 @@ class HuggingFaceDeploymentService(BaseDeploymentService):
138
139
  """
139
140
  return self.hf_endpoint.client
140
141
 
142
+ def _validate_endpoint_configuration(self) -> Dict[str, str]:
143
+ """Validates the configuration to provision a Huggingface service.
144
+
145
+ Raises:
146
+ ValueError: if there is a missing value in the configuration
147
+
148
+ Returns:
149
+ The validated configuration values.
150
+ """
151
+ configuration = {}
152
+ missing_keys = []
153
+
154
+ for k, v in {
155
+ "repository": self.config.repository,
156
+ "framework": self.config.framework,
157
+ "accelerator": self.config.accelerator,
158
+ "instance_size": self.config.instance_size,
159
+ "instance_type": self.config.instance_type,
160
+ "region": self.config.region,
161
+ "vendor": self.config.vendor,
162
+ "endpoint_type": self.config.endpoint_type,
163
+ }.items():
164
+ if v is None:
165
+ missing_keys.append(k)
166
+ else:
167
+ configuration[k] = v
168
+
169
+ if missing_keys:
170
+ raise ValueError(
171
+ f"Missing values in the Huggingface Service "
172
+ f"configuration: {', '.join(missing_keys)}"
173
+ )
174
+
175
+ return configuration
176
+
141
177
  def provision(self) -> None:
142
178
  """Provision or update remote Hugging Face deployment instance.
143
179
 
144
180
  Raises:
145
- Exception: If any unexpected error while creating inference endpoint.
181
+ Exception: If any unexpected error while creating inference
182
+ endpoint.
146
183
  """
147
184
  try:
148
- # Attempt to create and wait for the inference endpoint
185
+ validated_config = self._validate_endpoint_configuration()
186
+
149
187
  hf_endpoint = create_inference_endpoint(
150
188
  name=self._generate_an_endpoint_name(),
151
- repository=self.config.repository,
152
- framework=self.config.framework,
153
- accelerator=self.config.accelerator,
154
- instance_size=self.config.instance_size,
155
- instance_type=self.config.instance_type,
156
- region=self.config.region,
157
- vendor=self.config.vendor,
189
+ repository=validated_config["repository"],
190
+ framework=validated_config["framework"],
191
+ accelerator=validated_config["accelerator"],
192
+ instance_size=validated_config["instance_size"],
193
+ instance_type=validated_config["instance_type"],
194
+ region=validated_config["region"],
195
+ vendor=validated_config["vendor"],
158
196
  account_id=self.config.account_id,
159
197
  min_replica=self.config.min_replica,
160
198
  max_replica=self.config.max_replica,
161
199
  revision=self.config.revision,
162
200
  task=self.config.task,
163
201
  custom_image=self.config.custom_image,
164
- type=self.config.endpoint_type,
202
+ type=InferenceEndpointType(validated_config["endpoint_type"]),
165
203
  token=self.get_token(),
166
204
  namespace=self.config.namespace,
167
205
  ).wait(timeout=POLLING_TIMEOUT)
@@ -172,21 +210,25 @@ class HuggingFaceDeploymentService(BaseDeploymentService):
172
210
  )
173
211
  # Catch-all for any other unexpected errors
174
212
  raise Exception(
175
- f"An unexpected error occurred while provisioning the Hugging Face inference endpoint: {e}"
213
+ "An unexpected error occurred while provisioning the "
214
+ f"Hugging Face inference endpoint: {e}"
176
215
  )
177
216
 
178
217
  # Check if the endpoint URL is available after provisioning
179
218
  if hf_endpoint.url:
180
219
  logger.info(
181
- f"Hugging Face inference endpoint successfully deployed and available. Endpoint URL: {hf_endpoint.url}"
220
+ "Hugging Face inference endpoint successfully deployed "
221
+ f"and available. Endpoint URL: {hf_endpoint.url}"
182
222
  )
183
223
  else:
184
224
  logger.error(
185
- "Failed to start Hugging Face inference endpoint service: No URL available, please check the Hugging Face console for more details."
225
+ "Failed to start Hugging Face inference endpoint "
226
+ "service: No URL available, please check the Hugging "
227
+ "Face console for more details."
186
228
  )
187
229
 
188
230
  def check_status(self) -> Tuple[ServiceState, str]:
189
- """Check the the current operational state of the Hugging Face deployment.
231
+ """Check the current operational state of the Hugging Face deployment.
190
232
 
191
233
  Returns:
192
234
  The operational state of the Hugging Face deployment and a message
@@ -196,26 +238,29 @@ class HuggingFaceDeploymentService(BaseDeploymentService):
196
238
  try:
197
239
  status = self.hf_endpoint.status
198
240
  if status == InferenceEndpointStatus.RUNNING:
199
- return (ServiceState.ACTIVE, "")
241
+ return ServiceState.ACTIVE, ""
200
242
 
201
243
  elif status == InferenceEndpointStatus.SCALED_TO_ZERO:
202
244
  return (
203
245
  ServiceState.SCALED_TO_ZERO,
204
- "Hugging Face Inference Endpoint is scaled to zero, but still running. It will be started on demand.",
246
+ "Hugging Face Inference Endpoint is scaled to zero, but "
247
+ "still running. It will be started on demand.",
205
248
  )
206
249
 
207
250
  elif status == InferenceEndpointStatus.FAILED:
208
251
  return (
209
252
  ServiceState.ERROR,
210
- "Hugging Face Inference Endpoint deployment is inactive or not found",
253
+ "Hugging Face Inference Endpoint deployment is inactive "
254
+ "or not found",
211
255
  )
212
256
  elif status == InferenceEndpointStatus.PENDING:
213
- return (ServiceState.PENDING_STARTUP, "")
214
- return (ServiceState.PENDING_STARTUP, "")
257
+ return ServiceState.PENDING_STARTUP, ""
258
+ return ServiceState.PENDING_STARTUP, ""
215
259
  except (InferenceEndpointError, HfHubHTTPError):
216
260
  return (
217
261
  ServiceState.INACTIVE,
218
- "Hugging Face Inference Endpoint deployment is inactive or not found",
262
+ "Hugging Face Inference Endpoint deployment is inactive or "
263
+ "not found",
219
264
  )
220
265
 
221
266
  def deprovision(self, force: bool = False) -> None:
@@ -253,15 +298,13 @@ class HuggingFaceDeploymentService(BaseDeploymentService):
253
298
  )
254
299
  if self.prediction_url is not None:
255
300
  if self.hf_endpoint.task == "text-generation":
256
- result = self.inference_client.task_generation(
301
+ return self.inference_client.text_generation(
257
302
  data, max_new_tokens=max_new_tokens
258
303
  )
259
- else:
260
- # TODO: Add support for all different supported tasks
261
- raise NotImplementedError(
262
- "Tasks other than text-generation is not implemented."
263
- )
264
- return result
304
+ # TODO: Add support for all different supported tasks
305
+ raise NotImplementedError(
306
+ "Tasks other than text-generation is not implemented."
307
+ )
265
308
 
266
309
  def get_logs(
267
310
  self, follow: bool = False, tail: Optional[int] = None
zenml/integrations/pytorch/materializers/base_pytorch_materializer.py CHANGED
@@ -44,7 +44,7 @@ class BasePyTorchMaterializer(BaseMaterializer):
44
44
  # NOTE (security): The `torch.load` function uses `pickle` as
45
45
  # the default unpickler, which is NOT secure. This materializer
46
46
  # is intended for use with trusted data sources.
47
- return torch.load(f) # nosec
47
+ return torch.load(f, weights_only=False) # nosec
48
48
 
49
49
  def save(self, obj: Any) -> None:
50
50
  """Uses `torch.save` to save a PyTorch object.