snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/callback/__init__.py +0 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +93 -3
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +354 -356
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +1 -445
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +2 -8
- snowflake/ml/jobs/manager.py +57 -135
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +130 -14
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +93 -8
- snowflake/ml/model/_client/ops/service_ops.py +32 -52
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +94 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +390 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +19 -208
- snowflake/ml/model/type_hints.py +7 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
- snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_model_composer/model_method/model_method.py

@@ -105,7 +105,7 @@ class ModelMethod:
         except ValueError as e:
             raise ValueError(
                 f"Your target method {self.target_method} cannot be resolved as valid SQL identifier. "
-                "Try
+                "Try specifying `case_sensitive` as True."
             ) from e

         if self.target_method not in self.model_meta.signatures.keys():
@@ -127,12 +127,41 @@ class ModelMethod:
         except ValueError as e:
             raise ValueError(
                 f"Your feature {feature.name} cannot be resolved as valid SQL identifier. "
-                "Try
+                "Try specifying `case_sensitive` as True."
             ) from e
         return model_manifest_schema.ModelMethodSignatureFieldWithName(
             name=feature_name.resolved(), type=type_utils.convert_sp_to_sf_type(feature.as_snowpark_type())
         )

+    @staticmethod
+    def _flatten_params(params: list[model_signature.BaseParamSpec]) -> list[model_signature.ParamSpec]:
+        """Flatten ParamGroupSpec into leaf ParamSpec items."""
+        result: list[model_signature.ParamSpec] = []
+        for param in params:
+            if isinstance(param, model_signature.ParamSpec):
+                result.append(param)
+            elif isinstance(param, model_signature.ParamGroupSpec):
+                result.extend(ModelMethod._flatten_params(param.specs))
+        return result
+
+    @staticmethod
+    def _get_method_arg_from_param(
+        param_spec: model_signature.ParamSpec,
+        case_sensitive: bool = False,
+    ) -> model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault:
+        try:
+            param_name = sql_identifier.SqlIdentifier(param_spec.name, case_sensitive=case_sensitive)
+        except ValueError as e:
+            raise ValueError(
+                f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
+                "Try specifying `case_sensitive` as True."
+            ) from e
+        return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
+            name=param_name.resolved(),
+            type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
+            default=param_spec.default_value,
+        )
+
     def save(
         self, workspace_path: pathlib.Path, options: Optional[function_generator.FunctionGenerateOptions] = None
     ) -> model_manifest_schema.ModelMethodDict:
@@ -182,6 +211,36 @@ class ModelMethod:
             inputs=input_list,
             outputs=outputs,
         )
+
+        # Add parameters if signature has parameters
+        if self.model_meta.signatures[self.target_method].params:
+            flat_params = ModelMethod._flatten_params(list(self.model_meta.signatures[self.target_method].params))
+            param_list = [
+                ModelMethod._get_method_arg_from_param(
+                    param_spec, case_sensitive=self.options.get("case_sensitive", False)
+                )
+                for param_spec in flat_params
+            ]
+            param_name_counter = collections.Counter([param_info["name"] for param_info in param_list])
+            dup_param_names = [k for k, v in param_name_counter.items() if v > 1]
+            if dup_param_names:
+                raise ValueError(
+                    f"Found duplicate parameter named resolved as {', '.join(dup_param_names)} in the method"
+                    f" {self.target_method}. This might be because you have parameters with same letters but "
+                    "different cases. In this case, set case_sensitive as True for those methods to distinguish them."
+                )
+
+            # Check for name collisions between parameters and inputs using existing counters
+            collision_names = [name for name in param_name_counter if name in input_name_counter]
+            if collision_names:
+                raise ValueError(
+                    f"Found parameter(s) with the same name as input feature(s): {', '.join(sorted(collision_names))} "
+                    f"in the method {self.target_method}. Parameters and inputs must have distinct names. "
+                    "Try using case_sensitive=True if the names differ only by case."
+                )
+
+            method_dict["params"] = param_list
+
         should_set_volatility = (
             platform_capabilities.PlatformCapabilities.get_instance().is_set_module_functions_volatility_from_manifest()
         )
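The new manifest code above flattens nested parameter groups into leaf parameters before emitting them as SQL method arguments. A minimal standalone sketch of that flattening behavior, using simplified stand-in spec classes rather than the real `model_signature.ParamSpec` / `ParamGroupSpec` definitions:

```python
from dataclasses import dataclass, field

# Simplified stand-ins for model_signature.ParamSpec / ParamGroupSpec (illustration only).
@dataclass
class ParamSpec:
    name: str
    default_value: object = None

@dataclass
class ParamGroupSpec:
    name: str
    specs: list = field(default_factory=list)

def flatten_params(params: list) -> list:
    """Recursively expand ParamGroupSpec containers into their leaf ParamSpec items."""
    result = []
    for param in params:
        if isinstance(param, ParamSpec):
            result.append(param)
        elif isinstance(param, ParamGroupSpec):
            result.extend(flatten_params(param.specs))
    return result

params = [
    ParamSpec("temperature", 1.0),
    ParamGroupSpec("sampling", specs=[ParamSpec("top_p", 0.9), ParamSpec("top_k", 50)]),
]
print([p.name for p in flatten_params(params)])  # ['temperature', 'top_p', 'top_k']
```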
snowflake/ml/model/_packager/model_handler.py

@@ -1,5 +1,6 @@
 import functools
 import importlib
+import logging
 import pkgutil
 from types import ModuleType
 from typing import Any, Callable, Optional, TypeVar, cast
@@ -11,6 +12,8 @@ _HANDLERS_BASE = "snowflake.ml.model._packager.model_handlers"
 _MODEL_HANDLER_REGISTRY: dict[str, type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
 _IS_HANDLER_LOADED = False

+logger = logging.getLogger(__name__)
+

 def _register_handlers() -> None:
     """
@@ -56,8 +59,11 @@ def find_handler(
     model: model_types.SupportedModelType,
 ) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
     for handler in _MODEL_HANDLER_REGISTRY.values():
-        if handler.can_handle(model):
-            return handler
+        try:
+            if handler.can_handle(model):
+                return handler
+        except Exception:
+            logger.error(f"Error in {handler.__name__} `can_handle` method for model {type(model)}", exc_info=True)
     return None


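The `find_handler` change wraps each handler's `can_handle` probe in a try/except so that one misbehaving handler (for example, a failed optional import) no longer breaks model-type detection for every other handler. A standalone sketch of that guarded-dispatch pattern, with invented handler names:

```python
import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

# Invented handlers for illustration; each probe may itself raise.
class BrokenHandler:
    @staticmethod
    def can_handle(model) -> bool:
        raise ImportError("optional dependency missing")

class DictHandler:
    @staticmethod
    def can_handle(model) -> bool:
        return isinstance(model, dict)

_REGISTRY = [BrokenHandler, DictHandler]

def find_handler(model):
    for handler in _REGISTRY:
        try:
            if handler.can_handle(model):
                return handler
        except Exception:
            # Log and keep probing instead of propagating the handler's failure.
            logger.error("Error in %s.can_handle for %s", handler.__name__, type(model), exc_info=True)
    return None

print(find_handler({"a": 1}))  # <class '__main__.DictHandler'>, despite BrokenHandler raising
```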
snowflake/ml/model/_packager/model_handlers/custom.py

@@ -86,6 +86,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
             get_prediction_fn=get_prediction,
         )

+        # Add parameters extracted from custom model inference methods to signatures
+        cls._add_method_parameters_to_signatures(model, model_meta)
+
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
         if model.context.artifacts:
@@ -188,6 +191,55 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         assert isinstance(model, custom_model.CustomModel)
         return model

+    @classmethod
+    def _add_method_parameters_to_signatures(
+        cls,
+        model: "custom_model.CustomModel",
+        model_meta: model_meta_api.ModelMetadata,
+    ) -> None:
+        """Extract parameters from custom model inference methods and add them to signatures.
+
+        For each inference method, if the signature doesn't already have parameters and the method
+        has keyword-only parameters with defaults, create ParamSpecs and add them to the signature.
+
+        Args:
+            model: The custom model instance.
+            model_meta: The model metadata containing signatures to augment.
+        """
+        for method in model._get_infer_methods():
+            method_name = method.__name__
+            if method_name not in model_meta.signatures:
+                continue
+
+            sig = model_meta.signatures[method_name]
+
+            # Skip if the signature already has parameters (user-provided or previously set)
+            if sig.params:
+                continue
+
+            # Extract parameters from the method
+            method_params = custom_model.get_method_parameters(method)
+            if not method_params:
+                continue
+
+            # Create ParamSpecs from the method parameters
+            param_specs = []
+            for param_name, param_type, param_default in method_params:
+                dtype = model_signature.DataType.from_python_type(param_type)
+                param_spec = model_signature.ParamSpec(
+                    name=param_name,
+                    dtype=dtype,
+                    default_value=param_default,
+                )
+                param_specs.append(param_spec)
+
+            # Create a new signature with parameters
+            model_meta.signatures[method_name] = model_signature.ModelSignature(
+                inputs=sig.inputs,
+                outputs=sig.outputs,
+                params=param_specs,
+            )
+
     @classmethod
     def convert_as_custom_model(
         cls,
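The custom handler now derives inference-time parameters from keyword-only arguments with defaults on a model's inference methods. Below is a standalone sketch of that style of introspection; the helper name and the `(name, type, default)` tuple shape are assumptions mirroring how `custom_model.get_method_parameters` is used in the diff, not the library's actual implementation:

```python
import inspect

def get_method_parameters(method):
    """Collect (name, annotated type, default) for keyword-only parameters that have defaults."""
    params = []
    for name, param in inspect.signature(method).parameters.items():
        if param.kind is inspect.Parameter.KEYWORD_ONLY and param.default is not inspect.Parameter.empty:
            annotation = param.annotation if param.annotation is not inspect.Parameter.empty else type(param.default)
            params.append((name, annotation, param.default))
    return params

class MyModel:
    # Positional data input plus keyword-only tuning knobs with defaults.
    def predict(self, data, *, temperature: float = 0.7, max_tokens: int = 256):
        return data

print(get_method_parameters(MyModel().predict))
# [('temperature', <class 'float'>, 0.7), ('max_tokens', <class 'int'>, 256)]
```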
snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py}

@@ -29,12 +29,16 @@ from snowflake.ml.model._packager.model_meta import (
     model_meta_schema,
 )
 from snowflake.ml.model._signatures import utils as model_signature_utils
-from snowflake.ml.model.models import
+from snowflake.ml.model.models import (
+    huggingface as huggingface_base,
+    huggingface_pipeline,
+)
 from snowflake.snowpark._internal import utils as snowpark_utils

 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
+    import torch
     import transformers

 DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501
@@ -77,8 +81,14 @@ def sanitize_output(data: Any) -> Any:


 @final
-class HuggingFacePipelineHandler(
-    _base.BaseModelHandler[
+class TransformersPipelineHandler(
+    _base.BaseModelHandler[
+        Union[
+            huggingface_base.TransformersPipeline,
+            huggingface_pipeline.HuggingFacePipelineModel,
+            "transformers.Pipeline",
+        ]
+    ]
 ):
     """Handler for custom model."""

@@ -97,35 +107,48 @@ class HuggingFacePipelineHandler(
     def can_handle(
         cls,
         model: model_types.SupportedModelType,
-    ) -> TypeGuard[
+    ) -> TypeGuard[
+        Union[
+            huggingface_base.TransformersPipeline,
+            huggingface_pipeline.HuggingFacePipelineModel,
+            "transformers.Pipeline",
+        ]
+    ]:
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             return True
         if isinstance(model, huggingface_pipeline.HuggingFacePipelineModel):
             return True
+        if isinstance(model, huggingface_base.TransformersPipeline):
+            return True
         return False

     @classmethod
     def cast_model(
         cls,
         model: model_types.SupportedModelType,
-    ) -> Union[
-
-
-
-
-
-        except ImportError:
-            assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
+    ) -> Union[
+        huggingface_base.TransformersPipeline,
+        huggingface_pipeline.HuggingFacePipelineModel,
+        "transformers.Pipeline",
+    ]:
+        if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             return model
-
-
+        elif isinstance(model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+            model, huggingface_base.TransformersPipeline
+        ):
             return model
+        else:
+            raise ValueError(f"Model {model} is not a valid Hugging Face model.")

     @classmethod
     def save_model(
         cls,
         name: str,
-        model: Union[
+        model: Union[
+            huggingface_base.TransformersPipeline,
+            huggingface_pipeline.HuggingFacePipelineModel,
+            "transformers.Pipeline",
+        ],
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
         sample_input_data: Optional[model_types.SupportedDataType] = None,
@@ -140,7 +163,9 @@ class HuggingFacePipelineHandler(
             framework = model.framework  # type:ignore[attr-defined]
             batch_size = model._batch_size  # type:ignore[attr-defined]
         else:
-            assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
+            assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+                model, huggingface_base.TransformersPipeline
+            )
             task = model.task
             framework = getattr(model, "framework", None)
             batch_size = getattr(model, "batch_size", None)
@@ -156,7 +181,9 @@ class HuggingFacePipelineHandler(
                 **model._postprocess_params,  # type:ignore[attr-defined]
             }
         else:
-            assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
+            assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+                model, huggingface_base.TransformersPipeline
+            )
             params = {**model.__dict__, **model.model_kwargs}

         inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(
@@ -177,7 +204,7 @@ class HuggingFacePipelineHandler(
         else:
             warnings.warn(
                 "It is impossible to validate your model signatures when using a"
-                "
+                f" {type(model).__name__} object. "
                 "Please make sure you are providing correct model signatures.",
                 UserWarning,
                 stacklevel=2,
@@ -302,14 +329,16 @@ class HuggingFacePipelineHandler(
         def _load_pickle_model(
             pickle_file: str,
             **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
-        ) -> huggingface_pipeline.HuggingFacePipelineModel:
+        ) -> Union[huggingface_pipeline.HuggingFacePipelineModel, huggingface_base.TransformersPipeline]:
             with open(pickle_file, "rb") as f:
                 m = cloudpickle.load(f)
-            assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
+            assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+                m, huggingface_base.TransformersPipeline
+            )
             torch_dtype: Optional[str] = None
             device_config = None
             if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
-                device_config =
+                device_config = TransformersPipelineHandler._get_device_config(**kwargs)
                 m.__dict__.update(device_config)

             if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
@@ -326,7 +355,9 @@ class HuggingFacePipelineHandler(
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
         **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
-    ) -> Union[
+    ) -> Union[
+        huggingface_pipeline.HuggingFacePipelineModel, huggingface_base.TransformersPipeline, "transformers.Pipeline"
+    ]:
         if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
             # We need to redirect the some folders to a writable location in the sandbox.
             os.environ["HF_HOME"] = "/tmp"
@@ -369,7 +400,7 @@ class HuggingFacePipelineHandler(
             ) as f:
                 pipeline_params = cloudpickle.load(f)

-            device_config =
+            device_config = TransformersPipelineHandler._get_device_config(**kwargs)

             m = transformers.pipeline(
                 model_blob_options["task"],
|
|
|
402
433
|
|
|
403
434
|
def _create_pipeline_from_model(
|
|
404
435
|
model_blob_file_or_dir_path: str,
|
|
405
|
-
m: huggingface_pipeline.HuggingFacePipelineModel,
|
|
436
|
+
m: Union[huggingface_pipeline.HuggingFacePipelineModel, huggingface_base.TransformersPipeline],
|
|
406
437
|
**kwargs: Unpack[model_types.HuggingFaceLoadOptions],
|
|
407
438
|
) -> "transformers.Pipeline":
|
|
408
439
|
import transformers
|
|
@@ -414,7 +445,7 @@ class HuggingFacePipelineHandler(
                 torch_dtype=getattr(m, "torch_dtype", None),
                 revision=m.revision,
                 # pass device or device_map when creating the pipeline
-                **
+                **TransformersPipelineHandler._get_device_config(**kwargs),
                 # pass other model_kwargs to transformers.pipeline.from_pretrained method
                 **m.model_kwargs,
             )
@@ -455,7 +486,11 @@ class HuggingFacePipelineHandler(
     @classmethod
     def convert_as_custom_model(
         cls,
-        raw_model: Union[
+        raw_model: Union[
+            huggingface_pipeline.HuggingFacePipelineModel,
+            huggingface_base.TransformersPipeline,
+            "transformers.Pipeline",
+        ],
         model_meta: model_meta_api.ModelMetadata,
         background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
@@ -609,7 +644,9 @@ class HuggingFacePipelineHandler(

             return _HFPipelineModel

-        if isinstance(raw_model, huggingface_pipeline.HuggingFacePipelineModel)
+        if isinstance(raw_model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+            raw_model, huggingface_base.TransformersPipeline
+        ):
             if version.parse(transformers.__version__) < version.parse("4.32.0"):
                 # Backward compatibility since HF interface change.
                 raw_model.__dict__["use_auth_token"] = raw_model.__dict__["token"]
@@ -668,43 +705,64 @@ class HuggingFaceOpenAICompatibleModel:

         self.model_name = self.pipeline.model.name_or_path

+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Ensure the tokenizer has a chat template.
+        # If not, we inject the default ChatML template which supports prompt generation.
+        if not getattr(self.tokenizer, "chat_template", None):
+            logger.warning(f"No chat template found for {self.model_name}. Using default ChatML template.")
+            self.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
+
     def _apply_chat_template(self, messages: list[dict[str, Any]]) -> str:
         """
         Applies a chat template to a list of messages.
-        If the tokenizer has a chat template, it uses that.
-        Otherwise, it falls back to a simple concatenation.

         Args:
-            messages (list[dict]): A list of message dictionaries
-                [{"role": "user", "content": "Hello!"}, ...]
+            messages (list[dict]): A list of message dictionaries.

         Returns:
            The formatted prompt string ready for model input.
         """
-
-
-
-            # `tokenize=False` means it returns a string, not token IDs
+        # Use the tokenizer's apply_chat_template method.
+        # We ensured a template exists in __init__.
+        if hasattr(self.tokenizer, "apply_chat_template"):
             return self.tokenizer.apply_chat_template(  # type: ignore[no-any-return]
                 messages,
                 tokenize=False,
                 add_generation_prompt=True,
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Fallback for very old transformers without apply_chat_template
+        # Manually apply ChatML-like formatting
+        prompt = ""
+        for message in messages:
+            role = message.get("role", "user")
+            content = message.get("content", "")
+            prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
+        prompt += "<|im_start|>assistant\n"
+        return prompt
+
+    def _get_stopping_criteria(self, stop_strings: list[str]) -> "transformers.StoppingCriteriaList":
+
+        import transformers
+
+        class StopStringsStoppingCriteria(transformers.StoppingCriteria):
+            def __init__(self, stop_strings: list[str], tokenizer: Any) -> None:
+                self.stop_strings = stop_strings
+                self.tokenizer = tokenizer
+
+            def __call__(self, input_ids: "torch.Tensor", scores: "torch.Tensor", **kwargs: Any) -> bool:
+                # Decode the generated text for each sequence
+                for i in range(input_ids.shape[0]):
+                    generated_text = self.tokenizer.decode(input_ids[i], skip_special_tokens=True)
+                    # Check if any stop string appears in the generated text
+                    for stop_str in self.stop_strings:
+                        if stop_str in generated_text:
+                            return True
+                return False
+
+        return transformers.StoppingCriteriaList([StopStringsStoppingCriteria(stop_strings, self.tokenizer)])

     def generate_chat_completion(
         self,
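When the tokenizer exposes no `apply_chat_template`, the fallback above builds a ChatML-style prompt by hand, matching the `DEFAULT_CHAT_TEMPLATE` defined earlier. A dependency-free sketch of what that formatting produces for a small conversation:

```python
# Mirrors the manual ChatML fallback: each message becomes an <|im_start|>role ... <|im_end|>
# block, and a trailing assistant header invites the model to generate the reply.
def format_chatml(messages: list[dict]) -> str:
    prompt = ""
    for message in messages:
        role = message.get("role", "user")
        content = message.get("content", "")
        prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
    prompt += "<|im_start|>assistant\n"
    return prompt

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is deep learning?"},
]
print(format_chatml(messages))
```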
@@ -727,18 +785,17 @@ class HuggingFaceOpenAICompatibleModel:
                 {"role": "user", "content": "What is deep learning?"}]
             max_completion_tokens (int): The maximum number of completion tokens to generate.
             stop_strings (list[str]): A list of strings to stop generation.
-            temperature (float): The temperature for sampling.
-            top_p (float): The top-p value for sampling.
-            stream (bool): Whether to stream the generation.
-            frequency_penalty (float): The frequency penalty for sampling.
-            presence_penalty (float): The presence penalty for sampling.
+            temperature (float): The temperature for sampling. 0 means greedy decoding.
+            top_p (float): The top-p value for nucleus sampling.
+            stream (bool): Whether to stream the generation (not yet supported).
+            frequency_penalty (float): The frequency penalty for sampling (maps to repetition_penalty).
+            presence_penalty (float): The presence penalty for sampling (not directly supported).
             n (int): The number of samples to generate.

         Returns:
             dict: An OpenAI-compatible dictionary representing the chat completion.
         """
         # Apply chat template to convert messages into a single prompt string
-
         prompt_text = self._apply_chat_template(messages)

         # Tokenize the prompt
@@ -749,42 +806,112 @@ class HuggingFaceOpenAICompatibleModel:
         ).to(self.model.device)
         prompt_tokens = inputs.input_ids.shape[1]

-
+        if stream:
+            logger.warning(
+                "Streaming is not supported using transformers.Pipeline implementation. Ignoring stream=True."
+            )
+            stream = False
+
+        if presence_penalty is not None:
+            logger.warning(
+                "Presence penalty is not supported using transformers.Pipeline implementation."
+                " Ignoring presence_penalty."
+            )
+            presence_penalty = None
+
+        import transformers
+
+        transformers_version = version.parse(transformers.__version__)
+
+        # Stop strings are supported in transformers >= 4.43.0
+        can_handle_stop_strings = transformers_version >= version.parse("4.43.0")

-
-
-
-
+        # Determine sampling based on temperature (following serve.py logic)
+        # Default temperature to 1.0 if not specified
+        actual_temperature = temperature if temperature is not None else 1.0
+        do_sample = actual_temperature > 0.0
+
+        # Set up generation config following best practices from serve.py
+        generation_config = transformers.GenerationConfig(
+            max_new_tokens=max_completion_tokens if max_completion_tokens is not None else 1024,
             pad_token_id=self.tokenizer.pad_token_id,
             eos_token_id=self.tokenizer.eos_token_id,
-
-            stream=stream,
-            num_return_sequences=n,
-            num_beams=max(1, n),  # must be >1
-            repetition_penalty=frequency_penalty,
-            # TODO: Handle diversity_penalty and num_beam_groups
-            # not all models support them making it hard to support any huggingface model
-            # diversity_penalty=presence_penalty if n > 1 else None,
-            # num_beam_groups=max(2, n) if presence_penalty else 1,
-            do_sample=False,
+            do_sample=do_sample,
         )

+        # Only set temperature and top_p if sampling is enabled
+        if do_sample:
+            generation_config.temperature = actual_temperature
+            if top_p is not None:
+                generation_config.top_p = top_p
+
+        # Handle repetition penalty (mapped from frequency_penalty)
+        if frequency_penalty is not None:
+            # OpenAI's frequency_penalty is typically in range [-2.0, 2.0]
+            # HuggingFace's repetition_penalty is typically > 0, with 1.0 = no penalty
+            # We need to convert: frequency_penalty=0 -> repetition_penalty=1.0
+            # Higher frequency_penalty should increase repetition_penalty
+            generation_config.repetition_penalty = 1.0 + (frequency_penalty if frequency_penalty > 0 else 0)
+
+        # For multiple completions (n > 1), use sampling not beam search
+        if n > 1:
+            generation_config.num_return_sequences = n
+            # Force sampling on for multiple sequences
+            if not do_sample:
+                logger.warning("Forcing do_sample=True for n>1. Consider setting temperature > 0 for better diversity.")
+                generation_config.do_sample = True
+                generation_config.temperature = 1.0
+        else:
+            generation_config.num_return_sequences = 1
+
+        # Handle stop strings if provided
+        stopping_criteria = None
+        if stop_strings and not can_handle_stop_strings:
+            logger.warning("Stop strings are not supported in transformers < 4.41.0. Ignoring stop strings.")
+
+        if stop_strings and can_handle_stop_strings:
+            stopping_criteria = self._get_stopping_criteria(stop_strings)
+            output_ids = self.model.generate(
+                inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                generation_config=generation_config,
+                # Pass tokenizer for proper handling of stop strings
+                tokenizer=self.tokenizer,
+                stopping_criteria=stopping_criteria,
+            )
+        else:
+            output_ids = self.model.generate(
+                inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                generation_config=generation_config,
+            )
+
         # Generate text
-
-
-
-            generation_config=generation_config,
-        )
+        # Handle the case where output might be 1D if n=1
+        if output_ids.dim() == 1:
+            output_ids = output_ids.unsqueeze(0)

         generated_texts = []
         completion_tokens = 0
         total_tokens = prompt_tokens
+
         for output_id in output_ids:
             # The output_ids include the input prompt
             # Decode the generated text, excluding the input prompt
             # so we slice to get only new tokens
             generated_tokens = output_id[prompt_tokens:]
             generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+            # Trim stop strings from generated text if they appear
+            # The stop criteria would stop generating further tokens, so we need to trim the generated text
+            if stop_strings and can_handle_stop_strings:
+                for stop_str in stop_strings:
+                    if stop_str in generated_text:
+                        # Find the first occurrence and trim everything from there
+                        stop_idx = generated_text.find(stop_str)
+                        generated_text = generated_text[:stop_idx]
+                        break  # Stop after finding the first stop string
+
             generated_texts.append(generated_text)

             # Calculate completion tokens
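The generation path above maps OpenAI-style request fields onto `transformers` generation settings: temperature 0 disables sampling, `frequency_penalty` is folded into `repetition_penalty`, and `n > 1` forces sampling. A dependency-free sketch of that mapping, with a plain dict standing in for `transformers.GenerationConfig`:

```python
# Illustrative only: the dict keys mirror the GenerationConfig fields set in the diff.
def build_generation_kwargs(temperature=None, top_p=None, frequency_penalty=None,
                            max_completion_tokens=None, n=1):
    actual_temperature = temperature if temperature is not None else 1.0
    do_sample = actual_temperature > 0.0  # temperature 0 -> greedy decoding

    cfg = {
        "max_new_tokens": max_completion_tokens if max_completion_tokens is not None else 1024,
        "do_sample": do_sample,
        "num_return_sequences": n if n > 1 else 1,
    }
    if do_sample:
        cfg["temperature"] = actual_temperature
        if top_p is not None:
            cfg["top_p"] = top_p
    if frequency_penalty is not None:
        # frequency_penalty 0 maps to repetition_penalty 1.0 (no penalty); negatives are ignored.
        cfg["repetition_penalty"] = 1.0 + (frequency_penalty if frequency_penalty > 0 else 0)
    if n > 1 and not do_sample:
        # Multiple return sequences require sampling to produce distinct outputs.
        cfg["do_sample"] = True
        cfg["temperature"] = 1.0
    return cfg

print(build_generation_kwargs(temperature=0.0, max_completion_tokens=64))
# {'max_new_tokens': 64, 'do_sample': False, 'num_return_sequences': 1}
```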
snowflake/ml/model/_packager/model_handlers/mlflow.py

@@ -148,12 +148,17 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):

         file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))

+        # MLflow 3.x may return file:// URIs for artifact_path; extract just the last path component
+        artifact_path = model_info.artifact_path
+        if artifact_path.startswith("file://"):
+            artifact_path = artifact_path.rstrip("/").split("/")[-1]
+
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
-            options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path":
+            options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path": artifact_path}),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION