snowflake-ml-python 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +2 -0
- snowflake/ml/jobs/_utils/payload_utils.py +38 -18
- snowflake/ml/jobs/_utils/query_helper.py +8 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +58 -4
- snowflake/ml/jobs/_utils/spec_utils.py +0 -31
- snowflake/ml/jobs/_utils/stage_utils.py +2 -2
- snowflake/ml/jobs/_utils/types.py +22 -2
- snowflake/ml/jobs/job_definition.py +232 -0
- snowflake/ml/jobs/manager.py +16 -177
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +109 -32
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +45 -2
- snowflake/ml/model/_client/ops/param_utils.py +124 -0
- snowflake/ml/model/_client/ops/service_ops.py +81 -61
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +24 -9
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +4 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +30 -29
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +62 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/huggingface.py +54 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +52 -16
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +85 -0
- snowflake/ml/model/_signatures/utils.py +55 -0
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/openai_signatures.py +97 -0
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/METADATA +95 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/RECORD +52 -46
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/sentence_transformers.py:

```diff
@@ -23,24 +23,55 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# Allowlist of supported target methods for SentenceTransformer models.
+_ALLOWED_TARGET_METHODS = ["encode", "encode_queries", "encode_documents"]
+
 
 def _validate_sentence_transformers_signatures(sigs: dict[str, model_signature.ModelSignature]) -> None:
-
-
+    """Validate signatures for SentenceTransformer models.
+
+    Args:
+        sigs: Dictionary mapping method names to their signatures.
+
+    Raises:
+        ValueError: If signatures are empty, contain unsupported methods, or violate per-method constraints.
+    """
+    # Check that signatures are non-empty
+    if not sigs:
+        raise ValueError("At least one signature must be provided.")
+
+    # Check that all methods are in the allowlist
+    unsupported_methods = set(sigs.keys()) - set(_ALLOWED_TARGET_METHODS)
+    if unsupported_methods:
+        raise ValueError(
+            f"Unsupported target methods: {sorted(unsupported_methods)}. "
+            f"Supported methods are: {_ALLOWED_TARGET_METHODS}."
+        )
 
-
-
+    # Validate per-method constraints
+    for method_name, sig in sigs.items():
+        if len(sig.inputs) != 1:
+            raise ValueError(f"SentenceTransformer method '{method_name}' must have exactly 1 input column.")
 
-
-
+        if len(sig.outputs) != 1:
+            raise ValueError(f"SentenceTransformer method '{method_name}' must have exactly 1 output column.")
 
-
+        # FeatureSpec is expected here; FeatureGroupSpec would indicate a nested/grouped input
+        # which SentenceTransformer does not support.
+        if not isinstance(sig.inputs[0], model_signature.FeatureSpec):
+            raise ValueError(
+                f"SentenceTransformer method '{method_name}' requires a FeatureSpec input, "
+                f"got {type(sig.inputs[0]).__name__}."
+            )
 
-
-
+        if sig.inputs[0]._shape is not None:
+            raise ValueError(f"SentenceTransformer method '{method_name}' does not support input shape.")
 
-
-
+        if sig.inputs[0]._dtype != model_signature.DataType.STRING:
+            raise ValueError(
+                f"SentenceTransformer method '{method_name}' only accepts STRING input, "
+                f"got {sig.inputs[0]._dtype.name}."
+            )
 
 
 @final
```
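Under these rules, a passing signature has exactly one STRING `FeatureSpec` input and one output per method. A minimal sketch; the column names and the 384-wide embedding shape are illustrative, not from the package:

```python
from snowflake.ml.model import model_signature

# One FeatureSpec STRING input and one output per method, as
# _validate_sentence_transformers_signatures requires.
sigs = {
    "encode": model_signature.ModelSignature(
        inputs=[model_signature.FeatureSpec(name="text", dtype=model_signature.DataType.STRING)],
        outputs=[
            model_signature.FeatureSpec(
                name="embedding", dtype=model_signature.DataType.DOUBLE, shape=(384,)
            )
        ],
    ),
}
```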
```diff
@@ -51,7 +82,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
 
     MODEL_BLOB_FILE_OR_DIR = "model"
-    DEFAULT_TARGET_METHODS = ["encode"]
+    DEFAULT_TARGET_METHODS = ["encode", "encode_queries", "encode_documents"]
 
     @classmethod
     def can_handle(
@@ -98,8 +129,13 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
             target_methods=kwargs.pop("target_methods", None),
             default_target_methods=cls.DEFAULT_TARGET_METHODS,
         )
-
-
+
+        # Validate target_methods
+        if not target_methods:
+            raise ValueError("At least one target method must be specified.")
+
+        if not set(target_methods).issubset(_ALLOWED_TARGET_METHODS):
+            raise ValueError(f"target_methods {target_methods} must be a subset of {_ALLOWED_TARGET_METHODS}.")
 
         def get_prediction(
             target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
@@ -246,10 +282,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
 
         type_method_dict = {}
         for target_method_name, sig in model_meta.signatures.items():
-            if target_method_name
+            if target_method_name in _ALLOWED_TARGET_METHODS:
                 type_method_dict[target_method_name] = get_prediction(raw_model, sig, target_method_name)
             else:
-                ValueError(f"{target_method_name} is currently not supported.")
+                raise ValueError(f"{target_method_name} is currently not supported.")
 
         _SentenceTransformer = type(
             "_SentenceTransformer",
```
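With the expanded allowlist, the new query/document methods can be requested at logging time. A hedged sketch; `session` and `st_model` are assumed to exist, and the options dict uses the registry's documented `target_methods` save option:

```python
from snowflake.ml.registry import Registry

reg = Registry(session=session)  # assumes an existing Snowpark `session`
mv = reg.log_model(
    st_model,  # a sentence_transformers.SentenceTransformer instance
    model_name="my_embedder",
    options={"target_methods": ["encode", "encode_queries", "encode_documents"]},
)
```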
snowflake/ml/model/_packager/model_handlers/xgboost.py:

```diff
@@ -194,7 +194,18 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
 
         if kwargs.get("use_gpu", False):
             assert type(kwargs.get("use_gpu", False)) == bool
-
+            from packaging import version
+
+            xgb_version = version.parse(xgboost.__version__)
+            if xgb_version >= version.parse("3.1.0"):
+                # XGBoost 3.1.0+: Use device="cuda" for GPU acceleration
+                # gpu_hist and gpu_predictor were removed in XGBoost 3.1.0
+                # See: https://xgboost.readthedocs.io/en/latest/changes/v3.1.0.html
+                gpu_params = {"tree_method": "hist", "device": "cuda"}
+            else:
+                # XGBoost < 3.1.0: Use legacy gpu_hist tree_method
+                gpu_params = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
+
             if isinstance(m, xgboost.Booster):
                 m.set_param(gpu_params)
             elif isinstance(m, xgboost.XGBModel):
```
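The same version gate can be reproduced outside the handler; a minimal standalone sketch using only `packaging` and `xgboost`:

```python
from packaging import version
import xgboost

# Mirrors the handler's gate above: XGBoost 3.1.0 removed gpu_hist and
# gpu_predictor in favor of tree_method="hist" plus device="cuda".
if version.parse(xgboost.__version__) >= version.parse("3.1.0"):
    gpu_params = {"tree_method": "hist", "device": "cuda"}
else:
    gpu_params = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}

model = xgboost.XGBClassifier(**gpu_params)
```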
```diff
@@ -256,6 +267,20 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
         @custom_model.inference_api
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
             import shap
+            from packaging import version
+
+            xgb_version = version.parse(xgboost.__version__)
+            shap_version = version.parse(shap.__version__)
+
+            # SHAP < 0.50.0 is incompatible with XGBoost >= 3.1.0 due to base_score format change
+            # (base_score is now stored as a vector for multi-output models)
+            # See: https://xgboost.readthedocs.io/en/latest/changes/v3.1.0.html
+            if xgb_version >= version.parse("3.1.0") and shap_version < version.parse("0.50.0"):
+                raise RuntimeError(
+                    f"SHAP version {shap.__version__} is incompatible with XGBoost version "
+                    f"{xgboost.__version__}. XGBoost 3.1+ changed the model format which requires "
+                    f"SHAP >= 0.50.0. Please upgrade SHAP or use XGBoost < 3.1."
+                )
 
             explainer = shap.TreeExplainer(raw_model)
             df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
```
snowflake/ml/model/_packager/model_meta/model_meta.py:

```diff
@@ -20,6 +20,7 @@ from snowflake.ml.model._packager.model_env import model_env
 from snowflake.ml.model._packager.model_meta import model_blob_meta, model_meta_schema
 from snowflake.ml.model._packager.model_meta_migrator import migrator_plans
 from snowflake.ml.model._packager.model_runtime import model_runtime
+from snowflake.ml.model.code_path import CodePath
 
 MODEL_METADATA_FILE = "model.yaml"
 MODEL_CODE_DIR = "code"
@@ -39,7 +40,7 @@ def create_model_metadata(
     signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
     function_properties: Optional[dict[str, dict[str, Any]]] = None,
     metadata: Optional[dict[str, str]] = None,
-    code_paths: Optional[list[
+    code_paths: Optional[list[model_types.CodePathLike]] = None,
     ext_modules: Optional[list[ModuleType]] = None,
     conda_dependencies: Optional[list[str]] = None,
     pip_requirements: Optional[list[str]] = None,
@@ -77,7 +78,8 @@ def create_model_metadata(
         **kwargs: Dict of attributes and values of the metadata. Used when loading from file.
 
     Raises:
-        ValueError: Raised when the code path contains reserved file or directory.
+        ValueError: Raised when the code path contains reserved file or directory, or destination conflicts.
+        FileNotFoundError: Raised when a code path does not exist.
 
     Yields:
         A model metadata object.
```
```diff
@@ -134,13 +136,44 @@ def create_model_metadata(
     os.makedirs(code_dir_path, exist_ok=True)
 
     if code_paths:
+        # Resolve all code paths and check for conflicts
+        resolved_paths: list[tuple[str, str]] = []  # (source, destination_relative)
         for code_path in code_paths:
-
-
-
-
+            if isinstance(code_path, CodePath):
+                source, dest_relative = code_path._resolve()
+            else:
+                # String path: keep existing behavior
+                source = os.path.normpath(os.path.abspath(code_path))
+                if not os.path.exists(source):
+                    raise FileNotFoundError(f"Code path '{code_path}' does not exist (resolved to {source}).")
+                dest_relative = os.path.basename(source)
+            resolved_paths.append((source, dest_relative))
+
+        # Check for destination conflicts
+        seen: dict[str, str] = {}
+        for source, dest in resolved_paths:
+            if dest in seen:
+                raise ValueError(
+                    f"Destination path conflict: '{dest}' is targeted by both '{seen[dest]}' and '{source}'."
+                )
+            seen[dest] = source
+
+        # Copy files
+        for source, dest_relative in resolved_paths:
+            # Prevent reserved name conflicts
+            dest_name = dest_relative.split(os.sep)[0] if os.sep in dest_relative else dest_relative
+            if (os.path.isfile(source) and os.path.splitext(dest_name)[0] == _SNOWFLAKE_PKG_NAME) or (
+                os.path.isdir(source) and dest_name == _SNOWFLAKE_PKG_NAME
+            ):
                 raise ValueError("`snowflake` is a reserved name and you cannot contain that into code path.")
-
+
+            parent_dir = (
+                os.path.join(code_dir_path, os.path.dirname(dest_relative))
+                if os.path.dirname(dest_relative)
+                else code_dir_path
+            )
+            os.makedirs(parent_dir, exist_ok=True)
+            file_utils.copy_file_or_tree(source, parent_dir)
 
     try:
         imported_modules = []
```
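The conflict check means two entries may no longer resolve to the same destination under `code/`. An illustrative sketch; the paths are made up:

```python
from snowflake.ml.model.code_path import CodePath

# Both entries land at code/utils, so create_model_metadata raises
# ValueError("Destination path conflict: ...") before any copy happens.
code_paths = [
    "project_a/utils",                          # plain string -> code/utils
    CodePath("project_b/src", filter="utils"),  # filter -> code/utils as well
]
```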
snowflake/ml/model/_packager/model_packager.py:

```diff
@@ -49,7 +49,7 @@ class ModelPackager:
         target_platforms: Optional[list[model_types.TargetPlatform]] = None,
         python_version: Optional[str] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        code_paths: Optional[list[
+        code_paths: Optional[list[model_types.CodePathLike]] = None,
         options: model_types.ModelSaveOption,
         task: model_types.Task = model_types.Task.UNKNOWN,
     ) -> model_meta.ModelMetadata:
```
snowflake/ml/model/_signatures/core.py:

```diff
@@ -191,6 +191,35 @@ class DataType(Enum):
             original_exception=NotImplementedError(f"Type {snowpark_type} is not supported as a DataType."),
         )
 
+    @classmethod
+    def from_python_type(cls, python_type: type) -> "DataType":
+        """Translate Python built-in type to DataType for signature definition.
+
+        Args:
+            python_type: A Python built-in type (int, float, str, bool).
+
+        Raises:
+            SnowflakeMLException: NotImplementedError: Raised when the given Python type is not supported.
+
+        Returns:
+            Corresponding DataType.
+        """
+        python_to_snowml_type_mapping: dict[type, "DataType"] = {
+            int: DataType.INT64,
+            float: DataType.DOUBLE,
+            str: DataType.STRING,
+            bool: DataType.BOOL,
+        }
+        if python_type in python_to_snowml_type_mapping:
+            return python_to_snowml_type_mapping[python_type]
+        raise snowml_exceptions.SnowflakeMLException(
+            error_code=error_codes.NOT_IMPLEMENTED,
+            original_exception=NotImplementedError(
+                f"Python type {python_type} is not supported as a DataType. "
+                f"Supported types are: {list(python_to_snowml_type_mapping.keys())}."
+            ),
+        )
+
 
 class BaseFeatureSpec(ABC):
     """Abstract Class for specification of a feature."""
```
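The mapping is small enough to verify inline; `DataType` here is the same class re-exported through `snowflake.ml.model.model_signature`:

```python
from snowflake.ml.model.model_signature import DataType

assert DataType.from_python_type(int) is DataType.INT64
assert DataType.from_python_type(float) is DataType.DOUBLE
assert DataType.from_python_type(str) is DataType.STRING
assert DataType.from_python_type(bool) is DataType.BOOL
# Anything else, e.g. bytes, raises SnowflakeMLException (NotImplementedError).
```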
```diff
@@ -764,10 +793,17 @@ class ModelSignature:
                 the output of the model.
             params: A sequence of parameter specifications and parameter group specifications that will compose
                 the parameters of the model. Defaults to None.
+
+        Raises:
+            SnowflakeMLException: ValueError: When the parameters have duplicate names or the same
+                names as input features.
+
+        # noqa: DAR402
         """
         self._inputs = inputs
         self._outputs = outputs
         self._params = params or []
+        self._name_validation()
 
     @property
     def inputs(self) -> Sequence[BaseFeatureSpec]:
```
```diff
@@ -879,6 +915,55 @@ class ModelSignature:
 
         return html_utils.create_base_container("Model Signature", content)
 
+    def _name_validation(self) -> None:
+        """Validate the names of the inputs and parameters.
+
+        Names are compared case-insensitively (matches Snowflake identifier behavior).
+
+        Raises:
+            SnowflakeMLException: ValueError: When the parameters have duplicate names or the same
+                names as input features.
+        """
+        input_names: set[str] = set()
+        for input_spec in self._inputs:
+            names = (
+                [spec.name.upper() for spec in input_spec._specs]
+                if isinstance(input_spec, FeatureGroupSpec)
+                else [input_spec.name.upper()]
+            )
+            input_names.update(names)
+
+        param_names: set[str] = set()
+        dup_params: set[str] = set()
+        collision_names: set[str] = set()
+
+        for param in self._params:
+            names = [spec.name for spec in param.specs] if isinstance(param, ParamGroupSpec) else [param.name]
+            for name in names:
+                if name.upper() in param_names:
+                    dup_params.add(name)
+                if name.upper() in input_names:
+                    collision_names.add(name)
+                param_names.add(name.upper())
+
+        if dup_params:
+            raise snowml_exceptions.SnowflakeMLException(
+                error_code=error_codes.INVALID_ARGUMENT,
+                original_exception=ValueError(
+                    f"Found duplicate parameter named resolved as {', '.join(sorted(dup_params))}."
+                    " Parameters must have distinct names (case-insensitive)."
+                ),
+            )
+
+        if collision_names:
+            raise snowml_exceptions.SnowflakeMLException(
+                error_code=error_codes.INVALID_ARGUMENT,
+                original_exception=ValueError(
+                    f"Found parameter(s) with the same name as input feature(s): {', '.join(sorted(collision_names))}."
+                    " Parameters and inputs must have distinct names (case-insensitive)."
+                ),
+            )
+
     @classmethod
     def from_mlflow_sig(cls, mlflow_sig: "mlflow.models.ModelSignature") -> "ModelSignature":
         return ModelSignature(
```
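To see the collision rule fire, a sketch; the `ParamSpec(name=..., dtype=..., default=...)` constructor and its import location are assumptions about the parameter-spec API, which this hunk only references:

```python
from snowflake.ml.model.model_signature import DataType, FeatureSpec, ModelSignature
from snowflake.ml.model._signatures.core import ParamSpec  # assumed import location

# "TEMPERATURE" matches input "temperature" case-insensitively, so the
# constructor now raises SnowflakeMLException wrapping a ValueError.
ModelSignature(
    inputs=[FeatureSpec(name="temperature", dtype=DataType.DOUBLE)],
    outputs=[FeatureSpec(name="out", dtype=DataType.DOUBLE)],
    params=[ParamSpec(name="TEMPERATURE", dtype=DataType.DOUBLE, default=1.0)],  # hypothetical signature
)
```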
snowflake/ml/model/_signatures/utils.py:

```diff
@@ -9,6 +9,7 @@ from snowflake.ml._internal.exceptions import (
     error_codes,
     exceptions as snowml_exceptions,
 )
+from snowflake.ml.model import openai_signatures
 from snowflake.ml.model._signatures import core
 
 
@@ -259,6 +260,48 @@ def huggingface_pipeline_signature_auto_infer(
         ],
     )
 
+    # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ImageClassificationPipeline
+    if task == "image-classification":
+        return core.ModelSignature(
+            inputs=[
+                core.FeatureSpec(name="images", dtype=core.DataType.BYTES),
+            ],
+            outputs=[
+                core.FeatureGroupSpec(
+                    name="labels",
+                    specs=[
+                        core.FeatureSpec(name="label", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="score", dtype=core.DataType.DOUBLE),
+                    ],
+                    shape=(-1,),
+                ),
+            ],
+        )
+
+    # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline
+    if task == "automatic-speech-recognition":
+        return core.ModelSignature(
+            inputs=[
+                core.FeatureSpec(name="audio", dtype=core.DataType.BYTES),
+            ],
+            outputs=[
+                core.FeatureGroupSpec(
+                    name="outputs",
+                    specs=[
+                        core.FeatureSpec(name="text", dtype=core.DataType.STRING),
+                        core.FeatureGroupSpec(
+                            name="chunks",
+                            specs=[
+                                core.FeatureSpec(name="timestamp", dtype=core.DataType.DOUBLE, shape=(2,)),
+                                core.FeatureSpec(name="text", dtype=core.DataType.STRING),
+                            ],
+                            shape=(-1,),  # Variable length list of chunks
+                        ),
+                    ],
+                ),
+            ],
+        )
+
     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.TextGenerationPipeline
     if task == "text-generation":
         if params.get("return_tensors", False):
```
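Since the inferred input is a single BYTES column named `images`, an inference frame can be built from raw file bytes; `cat.png` is illustrative:

```python
import pandas as pd

# One BYTES "images" column per the FeatureSpec above; the file name is made up.
with open("cat.png", "rb") as f:
    input_df = pd.DataFrame({"images": [f.read()]})
```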
```diff
@@ -288,6 +331,18 @@ def huggingface_pipeline_signature_auto_infer(
             )
         ],
     )
+
+    # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ImageTextToTextPipeline
+    if task == "image-text-to-text":
+        if params.get("return_tensors", False):
+            raise NotImplementedError(
+                f"Auto deployment for HuggingFace pipeline {task} "
+                "when `return_tensors` set to `True` has not been supported yet."
+            )
+        # Always generate a dict per input
+        # defaulting to OPENAI_CHAT_SIGNATURE_SPEC for image-text-to-text pipeline
+        return openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC
+
     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.Text2TextGenerationPipeline
     if task == "text2text-generation":
         if params.get("return_tensors", False):
```
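With these branches, the newly covered tasks need no explicit signature at logging time. A hedged sketch; `session` and the model id are placeholders:

```python
import transformers
from snowflake.ml.registry import Registry

pipe = transformers.pipeline("image-classification", model="google/vit-base-patch16-224")
reg = Registry(session=session)  # assumes an existing Snowpark `session`
mv = reg.log_model(pipe, model_name="vit_classifier")  # signature auto-inferred
```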
snowflake/ml/model/code_path.py (new file):

```diff
@@ -0,0 +1,104 @@
+"""CodePath class for selective code packaging in model registry."""
+
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+_ERR_ROOT_NOT_FOUND = "CodePath: root '{root}' does not exist (resolved to {resolved})."
+_ERR_WILDCARDS_NOT_SUPPORTED = "CodePath: Wildcards are not supported in filter. Got '{filter}'. Use exact paths only."
+_ERR_FILTER_MUST_BE_RELATIVE = "CodePath: filter must be a relative path, got absolute path '{filter}'."
+_ERR_FILTER_HOME_PATH = "CodePath: filter must be a relative path, got home directory path '{filter}'."
+_ERR_FILTER_ON_FILE_ROOT = (
+    "CodePath: cannot apply filter to a file root. " "Root '{root}' is a file. Use filter only with directory roots."
+)
+_ERR_FILTER_ESCAPES_ROOT = (
+    "CodePath: filter '{filter}' escapes root directory '{root}'. " "Relative paths must stay within root."
+)
+_ERR_FILTER_NOT_FOUND = "CodePath: filter '{filter}' under root '{root}' does not exist (resolved to {resolved})."
+
+
+@dataclass(frozen=True)
+class CodePath:
+    """Specifies a code path with optional filtering for selective inclusion.
+
+    Args:
+        root: The root directory or file path (absolute or relative to cwd).
+        filter: Optional relative path under root to select a subdirectory or file.
+            The filter also determines the destination path under code/.
+
+    Examples:
+        CodePath("project/src/")  # Copy entire src/ to code/src/
+        CodePath("project/src/", filter="utils")  # Copy utils/ to code/utils/
+        CodePath("project/src/", filter="lib/helpers")  # Copy to code/lib/helpers/
+    """
+
+    root: str
+    filter: Optional[str] = None
+
+    def __post_init__(self) -> None:
+        if self.filter == "":
+            object.__setattr__(self, "filter", None)
+
+    def __repr__(self) -> str:
+        if self.filter:
+            return f"CodePath({self.root!r}, filter={self.filter!r})"
+        return f"CodePath({self.root!r})"
+
+    def _validate_filter(self) -> Optional[str]:
+        """Validate and normalize filter, returning normalized filter or None.
+
+        Returns:
+            Normalized filter path, or None if no filter is set.
+
+        Raises:
+            ValueError: If filter contains wildcards or is an absolute path.
+        """
+        if self.filter is None:
+            return None
+
+        if any(c in self.filter for c in ["*", "?", "[", "]"]):
+            raise ValueError(_ERR_WILDCARDS_NOT_SUPPORTED.format(filter=self.filter))
+
+        if self.filter.startswith("~"):
+            raise ValueError(_ERR_FILTER_HOME_PATH.format(filter=self.filter))
+
+        filter_normalized = os.path.normpath(self.filter)
+
+        if os.path.isabs(filter_normalized):
+            raise ValueError(_ERR_FILTER_MUST_BE_RELATIVE.format(filter=self.filter))
+
+        return filter_normalized
+
+    def _resolve(self) -> tuple[str, str]:
+        """Resolve the source path and destination path.
+
+        Returns:
+            Tuple of (source_path, destination_relative_path)
+
+        Raises:
+            FileNotFoundError: If root or filter path does not exist.
+            ValueError: If filter is invalid (wildcards, absolute, escapes root, or applied to file).
+        """
+        filter_normalized = self._validate_filter()
+        root_normalized = os.path.normpath(os.path.abspath(self.root))
+
+        if filter_normalized is None:
+            if not os.path.exists(root_normalized):
+                raise FileNotFoundError(_ERR_ROOT_NOT_FOUND.format(root=self.root, resolved=root_normalized))
+            return root_normalized, os.path.basename(root_normalized)
+
+        if not os.path.exists(root_normalized):
+            raise FileNotFoundError(_ERR_ROOT_NOT_FOUND.format(root=self.root, resolved=root_normalized))
+
+        if os.path.isfile(root_normalized):
+            raise ValueError(_ERR_FILTER_ON_FILE_ROOT.format(root=self.root))
+
+        source = os.path.normpath(os.path.join(root_normalized, filter_normalized))
+
+        if not (source.startswith(root_normalized + os.sep) or source == root_normalized):
+            raise ValueError(_ERR_FILTER_ESCAPES_ROOT.format(filter=self.filter, root=self.root))
+
+        if not os.path.exists(source):
+            raise FileNotFoundError(_ERR_FILTER_NOT_FOUND.format(filter=self.filter, root=self.root, resolved=source))
+
+        return source, filter_normalized
```
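Usage mirrors the class docstring examples; the project layout is illustrative, and the list can mix `CodePath` objects with plain string paths:

```python
from snowflake.ml.model.code_path import CodePath

code_paths = [
    CodePath("project/src"),                        # whole tree -> code/src/
    CodePath("project/src", filter="lib/helpers"),  # subtree -> code/lib/helpers/
]
# Pass alongside plain strings via log_model(..., code_paths=code_paths).
```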
snowflake/ml/model/custom_model.py:

```diff
@@ -4,10 +4,13 @@ from typing import Any, Callable, Coroutine, Generator, Optional, Union
 
 import anyio
 import pandas as pd
-from typing_extensions import deprecated
+from typing_extensions import Concatenate, ParamSpec, deprecated
 
 from snowflake.ml.model import type_hints as model_types
 
+# Captures additional keyword-only parameters for inference methods
+InferenceParams = ParamSpec("InferenceParams")
+
 
 class MethodRef:
     """Represents a method invocation of an instance of `ModelRef`.
@@ -217,7 +220,7 @@ class CustomModel:
 
     def _get_infer_methods(
         self,
-    ) -> Generator[Callable[
+    ) -> Generator[Callable[..., pd.DataFrame], None, None]:
         """Returns all methods in CLS with `inference_api` decorator as the outermost decorator."""
         for cls_method_str in dir(self):
             cls_method = getattr(self, cls_method_str)
@@ -240,7 +243,7 @@ class CustomModel:
         return rv
 
 
-def _validate_predict_function(func: Callable[
+def _validate_predict_function(func: Callable[..., pd.DataFrame]) -> None:
     """Validate the user provided predict method.
 
     Args:
@@ -248,19 +251,22 @@ def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.D
 
     Raises:
         TypeError: Raised when the method is not a callable object.
-        TypeError: Raised when the method does not have 2 arguments (self and X).
+        TypeError: Raised when the method does not have at least 2 arguments (self and X).
         TypeError: Raised when the method does not have typing annotation.
         TypeError: Raised when the method's input (X) does not have type pd.DataFrame.
         TypeError: Raised when the method's output does not have type pd.DataFrame.
+        TypeError: Raised when additional parameters are not keyword-only with defaults.
     """
     if not callable(func):
         raise TypeError("Predict method is not callable.")
 
     func_signature = inspect.signature(func)
-
-
+    func_signature_params = list(func_signature.parameters.values())
+
+    if len(func_signature_params) < 2:
+        raise TypeError("Predict method should have at least 2 arguments.")
 
-    input_annotation =
+    input_annotation = func_signature_params[1].annotation
     output_annotation = func_signature.return_annotation
 
     if input_annotation == inspect.Parameter.empty or output_annotation == inspect.Signature.empty:
```
```diff
@@ -275,17 +281,53 @@ def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.D
     ):
         raise TypeError("Output for predict method should have type pandas.DataFrame.")
 
+    # Validate additional parameters (beyond self and input) are keyword-only with defaults
+    for func_signature_param in func_signature_params[2:]:
+        _validate_parameter(func_signature_param)
+
+
+def _validate_parameter(param: inspect.Parameter) -> None:
+    """Validate a parameter."""
+    if param.kind != inspect.Parameter.KEYWORD_ONLY:
+        raise TypeError(f"Parameter '{param.name}' must be keyword-only (defined after '*' in signature).")
+    if param.default == inspect.Parameter.empty:
+        raise TypeError(f"Parameter '{param.name}' must have a default value.")
+    if param.annotation == inspect.Parameter.empty:
+        raise TypeError(f"Parameter '{param.name}' must have a type annotation.")
+
+    # Validate annotation is a supported type
+    supported_types = {int, float, str, bool}
+    if param.annotation not in supported_types:
+        raise TypeError(
+            f"Parameter '{param.name}' has unsupported type annotation '{param.annotation}'. "
+            f"Supported types are: int, float, str, bool"
+        )
+
+
+def get_method_parameters(func: Callable[..., Any]) -> list[tuple[str, Any, Any]]:
+    """Extract keyword-only parameters with defaults from an inference method.
+
+    Args:
+        func: The inference method.
+
+    Returns:
+        A list of tuples (name, type annotation, default value) for each keyword-only parameter.
+    """
+    func_signature = inspect.signature(func)
+    params = list(func_signature.parameters.values())
+    return [(param.name, param.annotation, param.default) for param in params[2:]]
+
 
 def inference_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     return func
 
 
 def partitioned_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     func.__dict__["_is_partitioned_api"] = True
     return func
```
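A custom model that satisfies the new validation rules; the scaling logic is a toy example:

```python
import pandas as pd
from snowflake.ml.model import custom_model

class ScalerModel(custom_model.CustomModel):
    @custom_model.inference_api
    def predict(self, X: pd.DataFrame, *, factor: float = 1.0) -> pd.DataFrame:
        # Extra parameters must be keyword-only, annotated with one of
        # int/float/str/bool, and carry a default, per _validate_parameter.
        return X * factor

model = ScalerModel(custom_model.ModelContext())
```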
```diff
@@ -296,8 +338,8 @@ def partitioned_api(
     " Use snowflake.ml.custom_model.partitioned_api instead."
 )
 def partitioned_inference_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     func.__dict__["_is_partitioned_api"] = True
     return func
```