zenml-nightly 0.73.0.dev20250129__py3-none-any.whl → 0.73.0.dev20250131__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/code_repository.py +26 -0
- zenml/cli/utils.py +14 -9
- zenml/client.py +2 -7
- zenml/code_repositories/base_code_repository.py +30 -2
- zenml/code_repositories/git/local_git_repository_context.py +26 -10
- zenml/code_repositories/local_repository_context.py +11 -8
- zenml/constants.py +3 -0
- zenml/integrations/gcp/constants.py +1 -1
- zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +3 -1
- zenml/integrations/gcp/step_operators/vertex_step_operator.py +1 -0
- zenml/integrations/github/code_repositories/github_code_repository.py +17 -2
- zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py +17 -2
- zenml/integrations/huggingface/services/huggingface_deployment.py +72 -29
- zenml/integrations/pytorch/materializers/base_pytorch_materializer.py +1 -1
- zenml/integrations/vllm/services/vllm_deployment.py +6 -1
- zenml/pipelines/build_utils.py +42 -35
- zenml/pipelines/pipeline_definition.py +5 -2
- zenml/utils/code_repository_utils.py +11 -2
- zenml/utils/downloaded_repository_context.py +3 -5
- zenml/utils/source_utils.py +3 -3
- zenml/zen_stores/migrations/utils.py +48 -1
- zenml/zen_stores/migrations/versions/4d5524b92a30_add_run_metadata_tag_index.py +67 -0
- zenml/zen_stores/rest_zen_store.py +3 -13
- zenml/zen_stores/schemas/run_metadata_schemas.py +15 -2
- zenml/zen_stores/schemas/schema_utils.py +34 -2
- zenml/zen_stores/schemas/tag_schemas.py +14 -1
- zenml/zen_stores/secrets_stores/sql_secrets_store.py +5 -2
- zenml/zen_stores/sql_zen_store.py +24 -17
- {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/METADATA +1 -1
- {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/RECORD +34 -33
- {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/entry_points.txt +0 -0
zenml/VERSION
CHANGED
@@ -1 +1 @@
-0.73.0.dev20250129
+0.73.0.dev20250131
zenml/cli/code_repository.py
CHANGED
@@ -162,6 +162,32 @@ def register_code_repository(
     cli_utils.declare(f"Successfully registered code repository `{name}`.")
 
 
+@code_repository.command("describe", help="Describe a code repository.")
+@click.argument(
+    "name_id_or_prefix",
+    type=str,
+    required=True,
+)
+def describe_code_repository(name_id_or_prefix: str) -> None:
+    """Describe a code repository.
+
+    Args:
+        name_id_or_prefix: Name, ID or prefix of the code repository.
+    """
+    client = Client()
+    try:
+        code_repository = client.get_code_repository(
+            name_id_or_prefix=name_id_or_prefix,
+        )
+    except KeyError as err:
+        cli_utils.error(str(err))
+    else:
+        cli_utils.print_pydantic_model(
+            title=f"Code repository '{code_repository.name}'",
+            model=code_repository,
+        )
+
+
 @code_repository.command("list", help="List all connected code repositories.")
 @list_options(CodeRepositoryFilter)
 def list_code_repositories(**kwargs: Any) -> None:
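The new `describe` command is a thin wrapper around the `Client.get_code_repository` lookup it calls above. A minimal programmatic equivalent, assuming a repository registered under the placeholder name "my-repo":

    from zenml.client import Client

    # Look up a registered code repository by name, ID or prefix
    # ("my-repo" is a placeholder, not taken from this diff).
    repo = Client().get_code_repository(name_id_or_prefix="my-repo")
    print(repo.name, repo.id)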
zenml/cli/utils.py
CHANGED
@@ -76,6 +76,7 @@ from zenml.models import (
 from zenml.models.v2.base.filter import FilterGenerator
 from zenml.services import BaseService, ServiceState
 from zenml.stack import StackComponent
+from zenml.stack.flavor import Flavor
 from zenml.stack.stack_component import StackComponentConfig
 from zenml.utils import secret_utils
 from zenml.utils.time_utils import expires_in
@@ -2151,10 +2152,11 @@ def _scrub_secret(config: StackComponentConfig) -> Dict[str, Any]:
     config_dict = {}
     config_fields = config.__class__.model_fields
     for key, value in config_fields.items():
-        if secret_utils.is_secret_field(value):
-            config_dict[key] = "********"
-        else:
-            config_dict[key] = getattr(config, key)
+        if getattr(config, key):
+            if secret_utils.is_secret_field(value):
+                config_dict[key] = "********"
+            else:
+                config_dict[key] = getattr(config, key)
     return config_dict
 
 
@@ -2164,8 +2166,6 @@ def print_debug_stack() -> None:
 
     client = Client()
     stack = client.get_stack()
-    active_stack = client.active_stack
-    components = _get_stack_components(active_stack)
 
     declare("\nCURRENT STACK\n", bold=True)
     console.print(f"Name: {stack.name}")
@@ -2176,7 +2176,8 @@ def print_debug_stack() -> None:
         f"Workspace: {stack.workspace.name} / {str(stack.workspace.id)}"
     )
 
-    for
+    for component_type, components in stack.components.items():
+        component = components[0]
         component_response = client.get_stack_component(
             name_id_or_prefix=component.id, component_type=component.type
         )
@@ -2186,8 +2187,12 @@ def print_debug_stack() -> None:
         console.print(f"Name: {component.name}")
         console.print(f"ID: {str(component.id)}")
         console.print(f"Type: {component.type.value}")
-        console.print(f"Flavor: {component.
-
+        console.print(f"Flavor: {component.flavor_name}")
+
+        flavor = Flavor.from_model(component.flavor)
+        config = flavor.config_class(**component.configuration)
+
+        console.print(f"Configuration: {_scrub_secret(config)}")
         if (
             component_response.user
             and component_response.user.name
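The `_scrub_secret` change above means only fields that are actually set are reported, with secret fields masked. A self-contained toy illustration of that behaviour (the `DummyConfig` model and its values are hypothetical, not ZenML code):

    from typing import Any, Dict, Set

    from pydantic import BaseModel


    def scrub(config: BaseModel, secret_fields: Set[str]) -> Dict[str, Any]:
        """Skip unset/empty fields and mask the ones flagged as secrets."""
        out: Dict[str, Any] = {}
        for key in type(config).model_fields:
            if getattr(config, key):
                out[key] = (
                    "********" if key in secret_fields else getattr(config, key)
                )
        return out


    class DummyConfig(BaseModel):
        username: str = "jane"
        password: str = "hunter2"
        endpoint: str = ""  # empty, so it is skipped


    print(scrub(DummyConfig(), secret_fields={"password"}))
    # {'username': 'jane', 'password': '********'}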
zenml/client.py
CHANGED
@@ -34,7 +34,7 @@ from typing import (
     Union,
     cast,
 )
-from uuid import UUID, uuid4
+from uuid import UUID
 
 from pydantic import ConfigDict, SecretStr
 
@@ -4980,12 +4980,7 @@ class Client(metaclass=ClientMetaClass):
             )
         )
         try:
-
-            code_repo_class(id=uuid4(), config=config)
-
-            # Explicitly access the config for pydantic validation, in case
-            # the login for some reason did not do that.
-            _ = code_repo_class.config
+            code_repo_class.validate_config(config)
         except Exception as e:
             raise RuntimeError(
                 "Failed to validate code repository config."
zenml/code_repositories/base_code_repository.py
CHANGED
@@ -15,7 +15,7 @@
 
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type
-from uuid import UUID
+from uuid import UUID, uuid4
 
 from zenml.config.secret_reference_mixin import SecretReferenceMixin
 from zenml.logger import get_logger
@@ -44,15 +44,18 @@ class BaseCodeRepository(ABC):
     def __init__(
         self,
         id: UUID,
+        name: str,
         config: Dict[str, Any],
     ) -> None:
         """Initializes a code repository.
 
         Args:
            id: The ID of the code repository.
+            name: The name of the code repository.
            config: The config of the code repository.
         """
         self._id = id
+        self._name = name
         self._config = config
         self.login()
 
@@ -80,7 +83,23 @@ class BaseCodeRepository(ABC):
                 source=model.source, expected_class=BaseCodeRepository
             )
         )
-        return class_(id=model.id, config=model.config)
+        return class_(id=model.id, name=model.name, config=model.config)
+
+    @classmethod
+    def validate_config(cls, config: Dict[str, Any]) -> None:
+        """Validate the code repository config.
+
+        This method should check that the config/credentials are valid and
+        the configured repository exists.
+
+        Args:
+            config: The configuration.
+        """
+        # The initialization calls the login to verify the credentials
+        code_repo = cls(id=uuid4(), name="", config=config)
+
+        # Explicitly access the config for pydantic validation
+        _ = code_repo.config
 
     @property
     def id(self) -> UUID:
@@ -91,6 +110,15 @@ class BaseCodeRepository(ABC):
         """
         return self._id
 
+    @property
+    def name(self) -> str:
+        """Name of the code repository.
+
+        Returns:
+            The name of the code repository.
+        """
+        return self._name
+
     @property
     def requirements(self) -> Set[str]:
         """Set of PyPI requirements for the repository.
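Since the constructor now receives the registered name, code that builds a runtime repository object from its model gets the name for free. A sketch using the public client, assuming the classmethod whose body changed above is the `from_model` factory (its name sits outside this hunk) and "my-repo" is a placeholder:

    from zenml.client import Client
    from zenml.code_repositories import BaseCodeRepository

    model = Client().get_code_repository(name_id_or_prefix="my-repo")
    # Loads the concrete class from model.source and forwards id, name, config.
    repo = BaseCodeRepository.from_model(model)
    print(repo.id, repo.name)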
zenml/code_repositories/git/local_git_repository_context.py
CHANGED
@@ -14,11 +14,14 @@
 """Implementation of the Local git repository context."""
 
 from typing import TYPE_CHECKING, Callable, Optional, cast
-from uuid import UUID
 
 from zenml.code_repositories import (
     LocalRepositoryContext,
 )
+from zenml.constants import (
+    ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES,
+    handle_bool_env_var,
+)
 from zenml.logger import get_logger
 
 if TYPE_CHECKING:
@@ -26,6 +29,8 @@ if TYPE_CHECKING:
     from git.remote import Remote
     from git.repo.base import Repo
 
+    from zenml.code_repositories import BaseCodeRepository
+
 logger = get_logger(__name__)
 
 
@@ -33,16 +38,19 @@ class LocalGitRepositoryContext(LocalRepositoryContext):
     """Local git repository context."""
 
     def __init__(
-        self,
+        self,
+        code_repository: "BaseCodeRepository",
+        git_repo: "Repo",
+        remote_name: str,
     ):
         """Initializes a local git repository context.
 
         Args:
-
+            code_repository: The code repository.
            git_repo: The git repo.
            remote_name: Name of the remote.
         """
-        super().__init__(
+        super().__init__(code_repository=code_repository)
         self._git_repo = git_repo
         self._remote = git_repo.remote(name=remote_name)
 
@@ -50,14 +58,14 @@
     def at(
         cls,
         path: str,
-
+        code_repository: "BaseCodeRepository",
         remote_url_validation_callback: Callable[[str], bool],
     ) -> Optional["LocalGitRepositoryContext"]:
         """Returns a local git repository at the given path.
 
         Args:
            path: The path to the local git repository.
-
+            code_repository: The code repository.
            remote_url_validation_callback: A callback that validates the
                remote URL of the git repository.
 
@@ -70,11 +78,13 @@
             from git.exc import InvalidGitRepositoryError
             from git.repo.base import Repo
         except ImportError:
+            logger.debug("Failed to import git library.")
             return None
 
         try:
             git_repo = Repo(path=path, search_parent_directories=True)
         except InvalidGitRepositoryError:
+            logger.debug("No git repository exists at path %s.", path)
             return None
 
         remote_name = None
@@ -87,7 +97,7 @@
             return None
 
         return LocalGitRepositoryContext(
-
+            code_repository=code_repository,
             git_repo=git_repo,
             remote_name=remote_name,
         )
@@ -124,13 +134,19 @@
     def is_dirty(self) -> bool:
         """Whether the git repo is dirty.
 
-
-        changes.
+        By default, a repository counts as dirty if it has any untracked or
+        uncommitted changes. Users can use an environment variable to ignore
+        untracked files.
 
         Returns:
            True if the git repo is dirty, False otherwise.
         """
-
+        ignore_untracked_files = handle_bool_env_var(
+            ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES, default=False
+        )
+        return self.git_repo.is_dirty(
+            untracked_files=not ignore_untracked_files
+        )
 
     @property
     def has_local_changes(self) -> bool:
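The `ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES` variable used in `is_dirty` above (and defined in `zenml/constants.py` further down) lets users keep untracked files from marking the checkout as dirty. A minimal sketch of opting in before running a pipeline:

    import os

    # Ignore untracked files when ZenML checks whether the local
    # repository checkout is dirty.
    os.environ["ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES"] = "True"

    # ...then run a pipeline as usual, e.g. my_pipeline() (hypothetical).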
zenml/code_repositories/local_repository_context.py
CHANGED
@@ -14,10 +14,13 @@
 """Base class for local code repository contexts."""
 
 from abc import ABC, abstractmethod
-from
+from typing import TYPE_CHECKING
 
 from zenml.logger import get_logger
 
+if TYPE_CHECKING:
+    from zenml.code_repositories import BaseCodeRepository
+
 logger = get_logger(__name__)
 
 
@@ -30,22 +33,22 @@
     commit, and whether the repository is dirty.
     """
 
-    def __init__(self,
+    def __init__(self, code_repository: "BaseCodeRepository") -> None:
         """Initializes a local repository context.
 
         Args:
-
+            code_repository: The code repository.
         """
-        self.
+        self._code_repository = code_repository
 
     @property
-    def
-        """Returns the
+    def code_repository(self) -> "BaseCodeRepository":
+        """Returns the code repository.
 
         Returns:
-            The
+            The code repository.
         """
-        return self.
+        return self._code_repository
 
     @property
     @abstractmethod
zenml/constants.py
CHANGED
@@ -175,6 +175,9 @@ ENV_ZENML_WHEEL_PACKAGE_NAME = "ZENML_WHEEL_PACKAGE_NAME"
 ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION = (
     "ZENML_PIPELINE_API_TOKEN_EXPIRATION"
 )
+ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES = (
+    "ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES"
+)
 
 # Materializer environment variables
 ENV_ZENML_MATERIALIZER_ALLOW_NON_ASCII_JSON_DUMPS = (
zenml/integrations/gcp/constants.py
CHANGED
@@ -16,7 +16,7 @@
 from google.cloud.aiplatform_v1.types.job_state import JobState
 
 VERTEX_ENDPOINT_SUFFIX = "-aiplatform.googleapis.com"
-POLLING_INTERVAL_IN_SECONDS =
+POLLING_INTERVAL_IN_SECONDS = 10
 CONNECTION_ERROR_RETRY_LIMIT = 5
 _VERTEX_JOB_STATE_SUCCEEDED = JobState.JOB_STATE_SUCCEEDED
 _VERTEX_JOB_STATE_FAILED = JobState.JOB_STATE_FAILED
zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py
CHANGED
@@ -51,7 +51,8 @@ class VertexStepOperatorSettings(BaseSettings):
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
        boot_disk_type: Type of the boot disk. (Default: pd-ssd)
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
-
+        persistent_resource_id: The ID of the persistent resource to use for the job.
+            https://cloud.google.com/vertex-ai/docs/training/persistent-resource-overview
     """
 
     accelerator_type: Optional[str] = None
@@ -59,6 +60,7 @@ class VertexStepOperatorSettings(BaseSettings):
     machine_type: str = "n1-standard-4"
     boot_disk_size_gb: int = 100
     boot_disk_type: str = "pd-ssd"
+    persistent_resource_id: Optional[str] = None
 
 
 class VertexStepOperatorConfig(
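A rough sketch of using the new `persistent_resource_id` setting from a step. The step operator name and the `step_operator.vertex` settings key follow ZenML's usual naming conventions rather than anything in this diff, and the resource ID is a placeholder:

    from zenml import step
    from zenml.integrations.gcp.flavors.vertex_step_operator_flavor import (
        VertexStepOperatorSettings,
    )

    vertex_settings = VertexStepOperatorSettings(
        machine_type="n1-standard-4",
        # Hypothetical persistent resource created beforehand in Vertex AI.
        persistent_resource_id="my-persistent-resource",
    )


    @step(
        step_operator="vertex",  # name of the step operator in your stack
        settings={"step_operator.vertex": vertex_settings},
    )
    def train() -> None:
        ...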
zenml/integrations/github/code_repositories/github_code_repository.py
CHANGED
@@ -15,7 +15,8 @@
 
 import os
 import re
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
+from uuid import uuid4
 
 import requests
 from github import Consts, Github, GithubException
@@ -62,6 +63,20 @@ class GitHubCodeRepositoryConfig(BaseCodeRepositoryConfig):
 class GitHubCodeRepository(BaseCodeRepository):
     """GitHub code repository."""
 
+    @classmethod
+    def validate_config(cls, config: Dict[str, Any]) -> None:
+        """Validate the code repository config.
+
+        This method should check that the config/credentials are valid and
+        the configured repository exists.
+
+        Args:
+            config: The configuration.
+        """
+        code_repo = cls(id=uuid4(), name="", config=config)
+        # Try to access the project to make sure it exists
+        _ = code_repo.github_repo
+
     @property
     def config(self) -> GitHubCodeRepositoryConfig:
         """Returns the `GitHubCodeRepositoryConfig` config.
@@ -190,7 +205,7 @@ class GitHubCodeRepository(BaseCodeRepository):
         """
         return LocalGitRepositoryContext.at(
             path=path,
-
+            code_repository=self,
             remote_url_validation_callback=self.check_remote_url,
         )
 
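With this hook in place, credentials can be checked before a repository is registered. A minimal sketch of calling it directly; the owner, repository and token values are placeholders and the keys mirror `GitHubCodeRepositoryConfig`:

    from zenml.integrations.github.code_repositories.github_code_repository import (
        GitHubCodeRepository,
    )

    config = {
        "owner": "my-org",        # placeholder values
        "repository": "my-repo",
        "token": "ghp_xxx",
    }

    # Raises if the token is invalid or the repository cannot be reached.
    GitHubCodeRepository.validate_config(config)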
zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py
CHANGED
@@ -15,7 +15,8 @@
 
 import os
 import re
-from typing import Optional
+from typing import Any, Dict, Optional
+from uuid import uuid4
 
 from gitlab import Gitlab
 from gitlab.v4.objects import Project
@@ -63,6 +64,20 @@ class GitLabCodeRepositoryConfig(BaseCodeRepositoryConfig):
 class GitLabCodeRepository(BaseCodeRepository):
     """GitLab code repository."""
 
+    @classmethod
+    def validate_config(cls, config: Dict[str, Any]) -> None:
+        """Validate the code repository config.
+
+        This method should check that the config/credentials are valid and
+        the configured repository exists.
+
+        Args:
+            config: The configuration.
+        """
+        code_repo = cls(id=uuid4(), name="", config=config)
+        # Try to access the project to make sure it exists
+        _ = code_repo.gitlab_project
+
     @property
     def config(self) -> GitLabCodeRepositoryConfig:
         """Returns the `GitLabCodeRepositoryConfig` config.
@@ -147,7 +162,7 @@ class GitLabCodeRepository(BaseCodeRepository):
         """
         return LocalGitRepositoryContext.at(
             path=path,
-
+            code_repository=self,
             remote_url_validation_callback=self.check_remote_url,
         )
 
zenml/integrations/huggingface/services/huggingface_deployment.py
CHANGED
@@ -13,17 +13,18 @@
 # permissions and limitations under the License.
 """Implementation of the Hugging Face Deployment service."""
 
-from typing import Any, Generator, Optional, Tuple
+from typing import Any, Dict, Generator, Optional, Tuple
 
 from huggingface_hub import (
     InferenceClient,
     InferenceEndpoint,
     InferenceEndpointError,
     InferenceEndpointStatus,
+    InferenceEndpointType,
     create_inference_endpoint,
     get_inference_endpoint,
 )
-from huggingface_hub.
+from huggingface_hub.errors import HfHubHTTPError
 from pydantic import Field
 
 from zenml.client import Client
@@ -138,30 +139,67 @@ class HuggingFaceDeploymentService(BaseDeploymentService):
         """
         return self.hf_endpoint.client
 
+    def _validate_endpoint_configuration(self) -> Dict[str, str]:
+        """Validates the configuration to provision a Huggingface service.
+
+        Raises:
+            ValueError: if there is a missing value in the configuration
+
+        Returns:
+            The validated configuration values.
+        """
+        configuration = {}
+        missing_keys = []
+
+        for k, v in {
+            "repository": self.config.repository,
+            "framework": self.config.framework,
+            "accelerator": self.config.accelerator,
+            "instance_size": self.config.instance_size,
+            "instance_type": self.config.instance_type,
+            "region": self.config.region,
+            "vendor": self.config.vendor,
+            "endpoint_type": self.config.endpoint_type,
+        }.items():
+            if v is None:
+                missing_keys.append(k)
+            else:
+                configuration[k] = v
+
+        if missing_keys:
+            raise ValueError(
+                f"Missing values in the Huggingface Service "
+                f"configuration: {', '.join(missing_keys)}"
+            )
+
+        return configuration
+
     def provision(self) -> None:
         """Provision or update remote Hugging Face deployment instance.
 
         Raises:
-            Exception: If any unexpected error while creating inference
+            Exception: If any unexpected error while creating inference
+                endpoint.
         """
         try:
-
+            validated_config = self._validate_endpoint_configuration()
+
             hf_endpoint = create_inference_endpoint(
                 name=self._generate_an_endpoint_name(),
-                repository=
-                framework=
-                accelerator=
-                instance_size=
-                instance_type=
-                region=
-                vendor=
+                repository=validated_config["repository"],
+                framework=validated_config["framework"],
+                accelerator=validated_config["accelerator"],
+                instance_size=validated_config["instance_size"],
+                instance_type=validated_config["instance_type"],
+                region=validated_config["region"],
+                vendor=validated_config["vendor"],
                 account_id=self.config.account_id,
                 min_replica=self.config.min_replica,
                 max_replica=self.config.max_replica,
                 revision=self.config.revision,
                 task=self.config.task,
                 custom_image=self.config.custom_image,
-                type=
+                type=InferenceEndpointType(validated_config["endpoint_type"]),
                 token=self.get_token(),
                 namespace=self.config.namespace,
             ).wait(timeout=POLLING_TIMEOUT)
@@ -172,21 +210,25 @@ class HuggingFaceDeploymentService(BaseDeploymentService):
             )
             # Catch-all for any other unexpected errors
             raise Exception(
-
+                "An unexpected error occurred while provisioning the "
+                f"Hugging Face inference endpoint: {e}"
             )
 
         # Check if the endpoint URL is available after provisioning
         if hf_endpoint.url:
             logger.info(
-
+                "Hugging Face inference endpoint successfully deployed "
+                f"and available. Endpoint URL: {hf_endpoint.url}"
             )
         else:
             logger.error(
-                "Failed to start Hugging Face inference endpoint
+                "Failed to start Hugging Face inference endpoint "
+                "service: No URL available, please check the Hugging "
+                "Face console for more details."
             )
 
     def check_status(self) -> Tuple[ServiceState, str]:
-        """Check the
+        """Check the current operational state of the Hugging Face deployment.
 
         Returns:
            The operational state of the Hugging Face deployment and a message
@@ -196,26 +238,29 @@ class HuggingFaceDeploymentService(BaseDeploymentService):
         try:
             status = self.hf_endpoint.status
             if status == InferenceEndpointStatus.RUNNING:
-                return
+                return ServiceState.ACTIVE, ""
 
             elif status == InferenceEndpointStatus.SCALED_TO_ZERO:
                 return (
                     ServiceState.SCALED_TO_ZERO,
-                    "Hugging Face Inference Endpoint is scaled to zero, but
+                    "Hugging Face Inference Endpoint is scaled to zero, but "
+                    "still running. It will be started on demand.",
                 )
 
             elif status == InferenceEndpointStatus.FAILED:
                 return (
                     ServiceState.ERROR,
-                    "Hugging Face Inference Endpoint deployment is inactive
+                    "Hugging Face Inference Endpoint deployment is inactive "
+                    "or not found",
                 )
             elif status == InferenceEndpointStatus.PENDING:
-                return
-            return
+                return ServiceState.PENDING_STARTUP, ""
+            return ServiceState.PENDING_STARTUP, ""
         except (InferenceEndpointError, HfHubHTTPError):
             return (
                 ServiceState.INACTIVE,
-                "Hugging Face Inference Endpoint deployment is inactive or
+                "Hugging Face Inference Endpoint deployment is inactive or "
+                "not found",
             )
 
     def deprovision(self, force: bool = False) -> None:
@@ -253,15 +298,13 @@ class HuggingFaceDeploymentService(BaseDeploymentService):
         )
         if self.prediction_url is not None:
             if self.hf_endpoint.task == "text-generation":
-
+                return self.inference_client.text_generation(
                     data, max_new_tokens=max_new_tokens
                 )
-
-
-
-
-            )
-            return result
+        # TODO: Add support for all different supported tasks
+        raise NotImplementedError(
+            "Tasks other than text-generation is not implemented."
+        )
 
     def get_logs(
         self, follow: bool = False, tail: Optional[int] = None
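The new `_validate_endpoint_configuration` gathers every missing field before failing instead of stopping at the first one. A standalone sketch of the same pattern; the helper name and sample values are illustrative, not ZenML API:

    from typing import Any, Dict, List


    def require(config: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
        """Return the requested keys, or raise one error naming every gap."""
        missing = [k for k in keys if config.get(k) is None]
        if missing:
            raise ValueError(f"Missing values: {', '.join(missing)}")
        return {k: config[k] for k in keys}


    require({"repository": "gpt2", "region": None}, ["repository", "region"])
    # -> ValueError: Missing values: region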
zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
CHANGED
@@ -44,7 +44,7 @@ class BasePyTorchMaterializer(BaseMaterializer):
            # NOTE (security): The `torch.load` function uses `pickle` as
            # the default unpickler, which is NOT secure. This materializer
            # is intended for use with trusted data sources.
-            return torch.load(f)  # nosec
+            return torch.load(f, weights_only=False)  # nosec
 
     def save(self, obj: Any) -> None:
         """Uses `torch.save` to save a PyTorch object.
|