snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/callback/__init__.py +0 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +93 -3
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +354 -356
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +1 -445
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +2 -8
- snowflake/ml/jobs/manager.py +57 -135
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +130 -14
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +93 -8
- snowflake/ml/model/_client/ops/service_ops.py +32 -52
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +94 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +390 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +19 -208
- snowflake/ml/model/type_hints.py +7 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
- snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/code_path.py (new file)
@@ -0,0 +1,104 @@
+"""CodePath class for selective code packaging in model registry."""
+
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+_ERR_ROOT_NOT_FOUND = "CodePath: root '{root}' does not exist (resolved to {resolved})."
+_ERR_WILDCARDS_NOT_SUPPORTED = "CodePath: Wildcards are not supported in filter. Got '{filter}'. Use exact paths only."
+_ERR_FILTER_MUST_BE_RELATIVE = "CodePath: filter must be a relative path, got absolute path '{filter}'."
+_ERR_FILTER_HOME_PATH = "CodePath: filter must be a relative path, got home directory path '{filter}'."
+_ERR_FILTER_ON_FILE_ROOT = (
+    "CodePath: cannot apply filter to a file root. " "Root '{root}' is a file. Use filter only with directory roots."
+)
+_ERR_FILTER_ESCAPES_ROOT = (
+    "CodePath: filter '{filter}' escapes root directory '{root}'. " "Relative paths must stay within root."
+)
+_ERR_FILTER_NOT_FOUND = "CodePath: filter '{filter}' under root '{root}' does not exist (resolved to {resolved})."
+
+
+@dataclass(frozen=True)
+class CodePath:
+    """Specifies a code path with optional filtering for selective inclusion.
+
+    Args:
+        root: The root directory or file path (absolute or relative to cwd).
+        filter: Optional relative path under root to select a subdirectory or file.
+            The filter also determines the destination path under code/.
+
+    Examples:
+        CodePath("project/src/")  # Copy entire src/ to code/src/
+        CodePath("project/src/", filter="utils")  # Copy utils/ to code/utils/
+        CodePath("project/src/", filter="lib/helpers")  # Copy to code/lib/helpers/
+    """
+
+    root: str
+    filter: Optional[str] = None
+
+    def __post_init__(self) -> None:
+        if self.filter == "":
+            object.__setattr__(self, "filter", None)
+
+    def __repr__(self) -> str:
+        if self.filter:
+            return f"CodePath({self.root!r}, filter={self.filter!r})"
+        return f"CodePath({self.root!r})"
+
+    def _validate_filter(self) -> Optional[str]:
+        """Validate and normalize filter, returning normalized filter or None.
+
+        Returns:
+            Normalized filter path, or None if no filter is set.
+
+        Raises:
+            ValueError: If filter contains wildcards or is an absolute path.
+        """
+        if self.filter is None:
+            return None
+
+        if any(c in self.filter for c in ["*", "?", "[", "]"]):
+            raise ValueError(_ERR_WILDCARDS_NOT_SUPPORTED.format(filter=self.filter))
+
+        if self.filter.startswith("~"):
+            raise ValueError(_ERR_FILTER_HOME_PATH.format(filter=self.filter))
+
+        filter_normalized = os.path.normpath(self.filter)
+
+        if os.path.isabs(filter_normalized):
+            raise ValueError(_ERR_FILTER_MUST_BE_RELATIVE.format(filter=self.filter))
+
+        return filter_normalized
+
+    def _resolve(self) -> tuple[str, str]:
+        """Resolve the source path and destination path.
+
+        Returns:
+            Tuple of (source_path, destination_relative_path)
+
+        Raises:
+            FileNotFoundError: If root or filter path does not exist.
+            ValueError: If filter is invalid (wildcards, absolute, escapes root, or applied to file).
+        """
+        filter_normalized = self._validate_filter()
+        root_normalized = os.path.normpath(os.path.abspath(self.root))
+
+        if filter_normalized is None:
+            if not os.path.exists(root_normalized):
+                raise FileNotFoundError(_ERR_ROOT_NOT_FOUND.format(root=self.root, resolved=root_normalized))
+            return root_normalized, os.path.basename(root_normalized)
+
+        if not os.path.exists(root_normalized):
+            raise FileNotFoundError(_ERR_ROOT_NOT_FOUND.format(root=self.root, resolved=root_normalized))
+
+        if os.path.isfile(root_normalized):
+            raise ValueError(_ERR_FILTER_ON_FILE_ROOT.format(root=self.root))
+
+        source = os.path.normpath(os.path.join(root_normalized, filter_normalized))
+
+        if not (source.startswith(root_normalized + os.sep) or source == root_normalized):
+            raise ValueError(_ERR_FILTER_ESCAPES_ROOT.format(filter=self.filter, root=self.root))
+
+        if not os.path.exists(source):
+            raise FileNotFoundError(_ERR_FILTER_NOT_FOUND.format(filter=self.filter, root=self.root, resolved=source))
+
+        return source, filter_normalized
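The new CodePath class resolves a (root, filter) pair into a copy source and a destination under the package's code/ directory. A minimal sketch of the resolution behavior, using a scratch directory layout invented for illustration (note that _resolve is a private helper, shown here only to demonstrate the rules):

    import os
    import tempfile

    from snowflake.ml.model.code_path import CodePath

    # Hypothetical layout: <tmp>/project/src/lib/helpers/util.py
    tmp = tempfile.mkdtemp()
    helpers = os.path.join(tmp, "project", "src", "lib", "helpers")
    os.makedirs(helpers)
    open(os.path.join(helpers, "util.py"), "w").close()

    src_root = os.path.join(tmp, "project", "src")

    # No filter: the whole directory is copied, landing at code/src/.
    source, dest = CodePath(src_root)._resolve()
    assert dest == "src"

    # With a filter: only lib/helpers is selected, landing at code/lib/helpers/.
    source, dest = CodePath(src_root, filter="lib/helpers")._resolve()
    assert dest == os.path.join("lib", "helpers")

    # Filters may not escape the root.
    try:
        CodePath(src_root, filter="../..")._resolve()
    except ValueError:
        pass  # _ERR_FILTER_ESCAPES_ROOT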
snowflake/ml/model/custom_model.py
@@ -4,10 +4,13 @@ from typing import Any, Callable, Coroutine, Generator, Optional, Union

 import anyio
 import pandas as pd
-from typing_extensions import deprecated
+from typing_extensions import Concatenate, ParamSpec, deprecated

 from snowflake.ml.model import type_hints as model_types

+# Captures additional keyword-only parameters for inference methods
+InferenceParams = ParamSpec("InferenceParams")
+

 class MethodRef:
     """Represents a method invocation of an instance of `ModelRef`.
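The import change above (Concatenate plus ParamSpec) is what lets the decorators later in this file describe "a method taking (self, X) plus arbitrary extra parameters" without erasing the extras' types. A standalone sketch of the typing pattern, not package code:

    from typing import Callable

    import pandas as pd
    from typing_extensions import Concatenate, ParamSpec

    P = ParamSpec("P")

    def marker(
        func: Callable[Concatenate[object, pd.DataFrame, P], pd.DataFrame],
    ) -> Callable[Concatenate[object, pd.DataFrame, P], pd.DataFrame]:
        # Identity decorator: P captures any extra (keyword-only) parameters,
        # so type checkers see the wrapped method's full signature unchanged.
        return func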
@@ -217,7 +220,7 @@ class CustomModel:

     def _get_infer_methods(
         self,
-    ) -> Generator[Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame], None, None]:
+    ) -> Generator[Callable[..., pd.DataFrame], None, None]:
         """Returns all methods in CLS with `inference_api` decorator as the outermost decorator."""
         for cls_method_str in dir(self):
             cls_method = getattr(self, cls_method_str)
@@ -240,7 +243,7 @@ class CustomModel:
         return rv


-def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]) -> None:
+def _validate_predict_function(func: Callable[..., pd.DataFrame]) -> None:
     """Validate the user provided predict method.

     Args:
@@ -248,19 +251,22 @@ def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]) -> None:

     Raises:
         TypeError: Raised when the method is not a callable object.
-        TypeError: Raised when the method does not have 2 arguments (self and X).
+        TypeError: Raised when the method does not have at least 2 arguments (self and X).
         TypeError: Raised when the method does not have typing annotation.
         TypeError: Raised when the method's input (X) does not have type pd.DataFrame.
         TypeError: Raised when the method's output does not have type pd.DataFrame.
+        TypeError: Raised when additional parameters are not keyword-only with defaults.
     """
     if not callable(func):
         raise TypeError("Predict method is not callable.")

     func_signature = inspect.signature(func)
-    if len(func_signature.parameters) != 2:
-        raise TypeError("Predict method should have 2 arguments.")
+    func_signature_params = list(func_signature.parameters.values())
+
+    if len(func_signature_params) < 2:
+        raise TypeError("Predict method should have at least 2 arguments.")

-    input_annotation = list(func_signature.parameters.values())[1].annotation
+    input_annotation = func_signature_params[1].annotation
     output_annotation = func_signature.return_annotation

     if input_annotation == inspect.Parameter.empty or output_annotation == inspect.Signature.empty:
@@ -275,17 +281,53 @@ def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]) -> None:
     ):
         raise TypeError("Output for predict method should have type pandas.DataFrame.")

+    # Validate additional parameters (beyond self and input) are keyword-only with defaults
+    for func_signature_param in func_signature_params[2:]:
+        _validate_parameter(func_signature_param)
+
+
+def _validate_parameter(param: inspect.Parameter) -> None:
+    """Validate a parameter."""
+    if param.kind != inspect.Parameter.KEYWORD_ONLY:
+        raise TypeError(f"Parameter '{param.name}' must be keyword-only (defined after '*' in signature).")
+    if param.default == inspect.Parameter.empty:
+        raise TypeError(f"Parameter '{param.name}' must have a default value.")
+    if param.annotation == inspect.Parameter.empty:
+        raise TypeError(f"Parameter '{param.name}' must have a type annotation.")
+
+    # Validate annotation is a supported type
+    supported_types = {int, float, str, bool}
+    if param.annotation not in supported_types:
+        raise TypeError(
+            f"Parameter '{param.name}' has unsupported type annotation '{param.annotation}'. "
+            f"Supported types are: int, float, str, bool"
+        )
+
+
+def get_method_parameters(func: Callable[..., Any]) -> list[tuple[str, Any, Any]]:
+    """Extract keyword-only parameters with defaults from an inference method.
+
+    Args:
+        func: The inference method.
+
+    Returns:
+        A list of tuples (name, type annotation, default value) for each keyword-only parameter.
+    """
+    func_signature = inspect.signature(func)
+    params = list(func_signature.parameters.values())
+    return [(param.name, param.annotation, param.default) for param in params[2:]]
+

 def inference_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     return func


 def partitioned_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     func.__dict__["_is_partitioned_api"] = True
     return func
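Together, these hunks let @inference_api methods declare extra parameters, provided each one is keyword-only, has a default, and is annotated int, float, str, or bool. A sketch of a model that passes the new validation (the class, column, and parameter names are invented for illustration):

    import pandas as pd

    from snowflake.ml.model import custom_model

    class EchoModel(custom_model.CustomModel):
        @custom_model.inference_api
        def predict(self, X: pd.DataFrame, *, scale: float = 1.0, label: str = "out") -> pd.DataFrame:
            # Extra parameters must be keyword-only, defaulted, and typed
            # int/float/str/bool to satisfy _validate_predict_function.
            return pd.DataFrame({label: X.iloc[:, 0] * scale})

    # get_method_parameters introspects the extras:
    # [('scale', <class 'float'>, 1.0), ('label', <class 'str'>, 'out')]
    print(custom_model.get_method_parameters(EchoModel.predict))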
@@ -296,8 +338,8 @@ def partitioned_api(
     " Use snowflake.ml.custom_model.partitioned_api instead."
 )
 def partitioned_inference_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     func.__dict__["_is_partitioned_api"] = True
     return func
snowflake/ml/model/model_signature.py
@@ -34,6 +34,9 @@ DataType = core.DataType
 BaseFeatureSpec = core.BaseFeatureSpec
 FeatureSpec = core.FeatureSpec
 FeatureGroupSpec = core.FeatureGroupSpec
+BaseParamSpec = core.BaseParamSpec
+ParamSpec = core.ParamSpec
+ParamGroupSpec = core.ParamGroupSpec
 ModelSignature = core.ModelSignature


@@ -711,6 +714,7 @@ def infer_signature(
     output_feature_names: Optional[list[str]] = None,
     input_data_limit: Optional[int] = 100,
     output_data_limit: Optional[int] = 100,
+    params: Optional[Sequence[core.BaseParamSpec]] = None,
 ) -> core.ModelSignature:
     """
     Infer model signature from given input and output sample data.
@@ -740,12 +744,20 @@ def infer_signature(
         output_data_limit: Limit the number of rows to be used in signature inference in the output data. Defaults to
             100. If None, all rows are used. If the number of rows in the output data is less than the limit, all rows
             are used.
+        params: Optional sequence of parameter specifications to include in the signature. Parameters define
+            optional configuration values that can be passed to model inference. Defaults to None.
+
+    Raises:
+        SnowflakeMLException: ValueError: Raised when input data contains columns matching parameter names.

     Returns:
         A model signature inferred from the given input and output sample data.
+
+    # noqa: DAR402
     """
     inputs = _infer_signature(_truncate_data(input_data, input_data_limit), role="input")
     inputs = utils.rename_features(inputs, input_feature_names)
+
     outputs = _infer_signature(_truncate_data(output_data, output_data_limit), role="output")
     outputs = utils.rename_features(outputs, output_feature_names)
-    return core.ModelSignature(inputs, outputs)
+    return core.ModelSignature(inputs, outputs, params=params)
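infer_signature now threads the new params argument into the returned ModelSignature. A hedged sketch of the call shape; ParamSpec here is the signature-side spec re-exported above, and its constructor arguments (name, dtype, default) are assumed from context rather than confirmed against _signatures/core.py:

    import pandas as pd

    from snowflake.ml.model import model_signature

    sig = model_signature.infer_signature(
        input_data=pd.DataFrame({"text": ["hello"]}),
        output_data=pd.DataFrame({"score": [0.5]}),
        # Assumed constructor shape -- see core.ParamSpec (extended in
        # _signatures/core.py this release) for the real definition.
        params=[model_signature.ParamSpec(name="temperature", dtype=model_signature.DataType.FLOAT, default=0.7)],
    )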
snowflake/ml/model/models/huggingface.py (new file)
@@ -0,0 +1,285 @@
+import logging
+import warnings
+from typing import Any, Optional, Union
+
+from packaging import version
+
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model.compute_pool import DEFAULT_CPU_COMPUTE_POOL
+
+logger = logging.getLogger(__name__)
+
+
+_TELEMETRY_PROJECT = "MLOps"
+_TELEMETRY_SUBPROJECT = "ModelManagement"
+
+
+class TransformersPipeline:
+    def __init__(
+        self,
+        task: Optional[str] = None,
+        model: Optional[str] = None,
+        *,
+        revision: Optional[str] = None,
+        token_or_secret: Optional[str] = None,
+        trust_remote_code: Optional[bool] = None,
+        model_kwargs: Optional[dict[str, Any]] = None,
+        compute_pool_for_log: Optional[str] = DEFAULT_CPU_COMPUTE_POOL,
+        # repo snapshot download args
+        allow_patterns: Optional[Union[list[str], str]] = None,
+        ignore_patterns: Optional[Union[list[str], str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Utility factory method to build a wrapper over transformers [`Pipeline`].
+        When deploying, this wrapper will create a real pipeline object and load tokenizers and models.
+
+        For pipelines docs, please refer:
+        https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline
+
+        Args:
+            task: The task that the pipeline will be used for. If None, it will be inferred from the model.
+                For available tasks, please refer to Transformers's documentation. Defaults to None.
+            model: The model that will be used by the pipeline to make predictions. This can only be a model identifier
+                currently. If not provided, the default for the `task` will be loaded. Defaults to None.
+            revision: When passing a task name or a string model identifier: The specific model version to use. It can
+                be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and
+                other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. Defaults to None.
+            token_or_secret: The token to use as HTTP bearer authorization for remote files. Defaults to None.
+                The token can be a token or a secret. If a secret is provided, it must be a fully qualified secret name.
+            trust_remote_code: Whether or not to allow for custom code defined on the Hub in their own modeling,
+                configuration, tokenization or even pipeline files. This option should only be set to `True` for
+                repositories you trust and in which you have read the code, as it will execute code present on the Hub.
+                Defaults to None.
+            model_kwargs: Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...)`.
+                Defaults to None.
+            compute_pool_for_log: The compute pool to use for logging the model. Defaults to DEFAULT_CPU_COMPUTE_POOL.
+                If a string is provided, it will be used as the compute pool name. This override allows for logging
+                the model when there is no system compute pool available.
+                If None is passed:
+                    if `huggingface_hub` is installed, the model artifacts will be downloaded
+                    from the HuggingFace repository;
+                    otherwise, only the metadata will be logged to Snowflake.
+            allow_patterns: If provided, only files matching at least one pattern are downloaded.
+            ignore_patterns: If provided, files matching any of the patterns are not downloaded.
+            kwargs: Additional keyword arguments passed along to the specific pipeline init (see the documentation for
+                the corresponding pipeline class for possible values).
+
+        Raises:
+            RuntimeError: Raised when the input argument cannot determine the pipeline.
+            ValueError: Raised when the pipeline contains remote code but trust_remote_code is not set or False.
+            ValueError: Raised when having conflicting arguments.
+
+        .. # noqa: DAR003
+        """
+        import transformers
+
+        config = kwargs.get("config", None)
+        tokenizer = kwargs.get("tokenizer", None)
+        framework = kwargs.get("framework", None)
+        feature_extractor = kwargs.get("feature_extractor", None)
+
+        self.secret_identifier: Optional[str] = None
+        uses_secret = False
+        if token_or_secret is not None and isinstance(token_or_secret, str):
+            db, schema, secret_name = sql_identifier.parse_fully_qualified_name(token_or_secret)
+            if db is not None and schema is not None and secret_name is not None:
+                self.secret_identifier = sql_identifier.get_fully_qualified_name(
+                    db=db,
+                    schema=schema,
+                    object=secret_name,
+                )
+                uses_secret = True
+            else:
+                logger.info("The token_or_secret is not a fully qualified secret name. It will be used as is.")
+
+        can_download_snapshot = False
+        if compute_pool_for_log is None:
+            try:
+                import huggingface_hub as hf_hub
+
+                can_download_snapshot = True
+            except ImportError:
+                pass
+
+        if compute_pool_for_log is None and not can_download_snapshot:
+            logger.info(
+                "The model will be logged with metadata only. No model artifacts will be downloaded. "
+                "During deployment, the model artifacts will be downloaded from the HuggingFace repository."
+            )
+
+        # ==== Start pipeline logic from transformers ====
+        if model_kwargs is None:
+            model_kwargs = {}
+
+        use_auth_token = model_kwargs.pop("use_auth_token", None)
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if token_or_secret is not None:
+                raise ValueError(
+                    "`token_or_secret` and `use_auth_token` are both specified. "
+                    "Please set only the argument `token_or_secret`."
+                )
+            token_or_secret = use_auth_token
+
+        hub_kwargs = {
+            "revision": revision,
+            "token": token_or_secret,
+            "trust_remote_code": trust_remote_code,
+            "_commit_hash": None,
+        }
+
+        # Backward compatibility since HF interface change.
+        if version.parse(transformers.__version__) < version.parse("4.32.0"):
+            # Backward compatibility since HF interface change.
+            hub_kwargs["use_auth_token"] = hub_kwargs["token"]
+            del hub_kwargs["token"]
+
+        if task is None and model is None:
+            raise RuntimeError(
+                "Impossible to instantiate a pipeline without either a task or a model being specified. "
+            )
+
+        if model is None and tokenizer is not None:
+            raise RuntimeError(
+                "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided"
+                " tokenizer may not be compatible with the default model. Please provide an identifier to a pretrained"
+                " model when providing tokenizer."
+            )
+        if model is None and feature_extractor is not None:
+            raise RuntimeError(
+                "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the "
+                "provided feature_extractor may not be compatible with the default model. Please provide an identifier"
+                " to a pretrained model when providing feature_extractor."
+            )
+
+        # ==== End pipeline logic from transformers ====
+
+        # We only support string as model argument.
+
+        if model is not None and not isinstance(model, str):
+            raise RuntimeError(f"Impossible to use non-string model as input for class {self.__class__.__name__}.")
+
+        # ==== Start pipeline logic (Config) from transformers ====
+
+        # Config is the primordial information item.
+        # Instantiate config if needed
+        config_obj = None
+
+        if not can_download_snapshot:
+            if isinstance(config, str):
+                config_obj = transformers.AutoConfig.from_pretrained(
+                    config, _from_pipeline=task, **hub_kwargs, **model_kwargs
+                )
+                hub_kwargs["_commit_hash"] = config_obj._commit_hash
+            elif config is None and isinstance(model, str):
+                config_obj = transformers.AutoConfig.from_pretrained(
+                    model, _from_pipeline=task, **hub_kwargs, **model_kwargs
+                )
+                hub_kwargs["_commit_hash"] = config_obj._commit_hash
+            # We only support string as config argument.
+            elif config is not None and not isinstance(config, str):
+                raise RuntimeError(
+                    f"Impossible to use non-string config as input for class {self.__class__.__name__}. "
+                    "Use transformers.Pipeline object if required."
+                )
+
+        # ==== Start pipeline logic (Task) from transformers ====
+
+        custom_tasks = {}
+        if config_obj is not None and len(getattr(config_obj, "custom_pipelines", {})) > 0:
+            custom_tasks = config_obj.custom_pipelines
+            if task is None and trust_remote_code is not False:
+                if len(custom_tasks) == 1:
+                    task = list(custom_tasks.keys())[0]
+                else:
+                    raise RuntimeError(
+                        "We can't infer the task automatically for this model as there are multiple tasks available. "
+                        f"Pick one in {', '.join(custom_tasks.keys())}"
+                    )
+
+        if task is None and model is not None:
+            task = transformers.pipelines.get_task(model, token_or_secret)
+
+        # Retrieve the task
+        if task in custom_tasks:
+            normalized_task = task
+            targeted_task, task_options = transformers.pipelines.clean_custom_task(custom_tasks[task])
+            if not trust_remote_code:
+                raise ValueError(
+                    "Loading this pipeline requires you to execute the code in the pipeline file in that"
+                    " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
+                    " set the option `trust_remote_code=True` to remove this error."
+                )
+        else:
+            (
+                normalized_task,
+                targeted_task,
+                task_options,
+            ) = transformers.pipelines.check_task(task)
+
+        # ==== Start pipeline logic (Model) from transformers ====
+
+        # Use default model/config/tokenizer for the task if no model is provided
+        if model is None:
+            # At that point framework might still be undetermined
+            (
+                model,
+                default_revision,
+            ) = transformers.pipelines.get_default_model_and_revision(targeted_task, framework, task_options)
+            revision = revision if revision is not None else default_revision
+            warnings.warn(
+                f"No model was supplied, defaulted to {model} and revision"
+                f" {revision} ({transformers.pipelines.HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
+                "Using a pipeline without specifying a model name and revision in production is not recommended.",
+                stacklevel=2,
+            )
+            if not can_download_snapshot and config is None and isinstance(model, str):
+                config_obj = transformers.AutoConfig.from_pretrained(
+                    model, _from_pipeline=task, **hub_kwargs, **model_kwargs
+                )
+                hub_kwargs["_commit_hash"] = config_obj._commit_hash
+
+        if kwargs.get("device_map", None) is not None:
+            if "device_map" in model_kwargs:
+                raise ValueError(
+                    'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
+                    " arguments might conflict, use only one.)"
+                )
+            if kwargs.get("device", None) is not None:
+                warnings.warn(
+                    "Both `device` and `device_map` are specified. `device` will override `device_map`. You"
+                    " will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`.",
+                    stacklevel=2,
+                )
+
+        repo_snapshot_dir: Optional[str] = None
+        if can_download_snapshot and not uses_secret:
+            try:
+                repo_snapshot_dir = hf_hub.snapshot_download(
+                    repo_id=model,
+                    revision=revision,
+                    token=token_or_secret,
+                    allow_patterns=allow_patterns,
+                    ignore_patterns=ignore_patterns,
+                )
+            except ImportError:
+                logger.info("huggingface_hub package is not installed, skipping snapshot download")
+
+        # ==== End pipeline logic from transformers ====
+
+        self.model = model
+        self.task = normalized_task
+        self.revision = revision
+        self.token_or_secret = token_or_secret
+        self.trust_remote_code = trust_remote_code
+        self.model_kwargs = model_kwargs
+        self.tokenizer = tokenizer
+
+        self.repo_snapshot_dir = repo_snapshot_dir
+        self.compute_pool_for_log = compute_pool_for_log
+        self.__dict__.update(kwargs)
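A sketch of how the new wrapper might be used; the model identifier is a real Hub model, but the registry call is illustrative only, mirroring the existing log_model pattern rather than quoting documented usage:

    from snowflake.ml.model.models.huggingface import TransformersPipeline
    from snowflake.ml.registry import Registry

    # With the default compute_pool_for_log, no model artifacts are downloaded
    # at construction; deployment fetches them from the HuggingFace repository.
    pipe = TransformersPipeline(
        task="text-classification",
        model="distilbert-base-uncased-finetuned-sst-2-english",
    )

    # Hypothetical logging flow (requires an active Snowpark session):
    # reg = Registry(session)
    # mv = reg.log_model(pipe, model_name="sentiment_clf", version_name="v1")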