snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. snowflake/ml/_internal/platform_capabilities.py +36 -0
  2. snowflake/ml/_internal/utils/url.py +42 -0
  3. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  4. snowflake/ml/data/data_connector.py +103 -1
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  6. snowflake/ml/experiment/callback/__init__.py +0 -0
  7. snowflake/ml/experiment/callback/keras.py +25 -2
  8. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  9. snowflake/ml/experiment/callback/xgboost.py +25 -2
  10. snowflake/ml/experiment/experiment_tracking.py +93 -3
  11. snowflake/ml/experiment/utils.py +6 -0
  12. snowflake/ml/feature_store/feature_view.py +34 -24
  13. snowflake/ml/jobs/_interop/protocols.py +3 -0
  14. snowflake/ml/jobs/_utils/constants.py +1 -0
  15. snowflake/ml/jobs/_utils/payload_utils.py +354 -356
  16. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  17. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  18. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  19. snowflake/ml/jobs/_utils/spec_utils.py +1 -445
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  21. snowflake/ml/jobs/_utils/types.py +14 -7
  22. snowflake/ml/jobs/job.py +2 -8
  23. snowflake/ml/jobs/manager.py +57 -135
  24. snowflake/ml/lineage/lineage_node.py +1 -1
  25. snowflake/ml/model/__init__.py +6 -0
  26. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  27. snowflake/ml/model/_client/model/model_version_impl.py +130 -14
  28. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  29. snowflake/ml/model/_client/ops/model_ops.py +93 -8
  30. snowflake/ml/model/_client/ops/service_ops.py +32 -52
  31. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  34. snowflake/ml/model/_client/sql/model_version.py +30 -6
  35. snowflake/ml/model/_client/sql/service.py +94 -5
  36. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  39. snowflake/ml/model/_packager/model_handler.py +8 -2
  40. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  41. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  42. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  45. snowflake/ml/model/_packager/model_packager.py +1 -1
  46. snowflake/ml/model/_signatures/core.py +390 -8
  47. snowflake/ml/model/_signatures/utils.py +13 -4
  48. snowflake/ml/model/code_path.py +104 -0
  49. snowflake/ml/model/compute_pool.py +2 -0
  50. snowflake/ml/model/custom_model.py +55 -13
  51. snowflake/ml/model/model_signature.py +13 -1
  52. snowflake/ml/model/models/huggingface.py +285 -0
  53. snowflake/ml/model/models/huggingface_pipeline.py +19 -208
  54. snowflake/ml/model/type_hints.py +7 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  56. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  57. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  58. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  59. snowflake/ml/registry/_manager/model_manager.py +230 -15
  60. snowflake/ml/registry/registry.py +4 -4
  61. snowflake/ml/utils/html_utils.py +67 -1
  62. snowflake/ml/version.py +1 -1
  63. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
  64. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
  65. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  66. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  67. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  68. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
@@ -105,7 +105,7 @@ class ModelMethod:
  except ValueError as e:
  raise ValueError(
  f"Your target method {self.target_method} cannot be resolved as valid SQL identifier. "
- "Try specify `case_sensitive` as True."
+ "Try specifying `case_sensitive` as True."
  ) from e

  if self.target_method not in self.model_meta.signatures.keys():
@@ -127,12 +127,41 @@ class ModelMethod:
  except ValueError as e:
  raise ValueError(
  f"Your feature {feature.name} cannot be resolved as valid SQL identifier. "
- "Try specify `case_sensitive` as True."
+ "Try specifying `case_sensitive` as True."
  ) from e
  return model_manifest_schema.ModelMethodSignatureFieldWithName(
  name=feature_name.resolved(), type=type_utils.convert_sp_to_sf_type(feature.as_snowpark_type())
  )

+ @staticmethod
+ def _flatten_params(params: list[model_signature.BaseParamSpec]) -> list[model_signature.ParamSpec]:
+ """Flatten ParamGroupSpec into leaf ParamSpec items."""
+ result: list[model_signature.ParamSpec] = []
+ for param in params:
+ if isinstance(param, model_signature.ParamSpec):
+ result.append(param)
+ elif isinstance(param, model_signature.ParamGroupSpec):
+ result.extend(ModelMethod._flatten_params(param.specs))
+ return result
+
+ @staticmethod
+ def _get_method_arg_from_param(
+ param_spec: model_signature.ParamSpec,
+ case_sensitive: bool = False,
+ ) -> model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault:
+ try:
+ param_name = sql_identifier.SqlIdentifier(param_spec.name, case_sensitive=case_sensitive)
+ except ValueError as e:
+ raise ValueError(
+ f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
+ "Try specifying `case_sensitive` as True."
+ ) from e
+ return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
+ name=param_name.resolved(),
+ type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
+ default=param_spec.default_value,
+ )
+
  def save(
  self, workspace_path: pathlib.Path, options: Optional[function_generator.FunctionGenerateOptions] = None
  ) -> model_manifest_schema.ModelMethodDict:
@@ -182,6 +211,36 @@ class ModelMethod:
  inputs=input_list,
  outputs=outputs,
  )
+
+ # Add parameters if signature has parameters
+ if self.model_meta.signatures[self.target_method].params:
+ flat_params = ModelMethod._flatten_params(list(self.model_meta.signatures[self.target_method].params))
+ param_list = [
+ ModelMethod._get_method_arg_from_param(
+ param_spec, case_sensitive=self.options.get("case_sensitive", False)
+ )
+ for param_spec in flat_params
+ ]
+ param_name_counter = collections.Counter([param_info["name"] for param_info in param_list])
+ dup_param_names = [k for k, v in param_name_counter.items() if v > 1]
+ if dup_param_names:
+ raise ValueError(
+ f"Found duplicate parameter named resolved as {', '.join(dup_param_names)} in the method"
+ f" {self.target_method}. This might be because you have parameters with same letters but "
+ "different cases. In this case, set case_sensitive as True for those methods to distinguish them."
+ )
+
+ # Check for name collisions between parameters and inputs using existing counters
+ collision_names = [name for name in param_name_counter if name in input_name_counter]
+ if collision_names:
+ raise ValueError(
+ f"Found parameter(s) with the same name as input feature(s): {', '.join(sorted(collision_names))} "
+ f"in the method {self.target_method}. Parameters and inputs must have distinct names. "
+ "Try using case_sensitive=True if the names differ only by case."
+ )
+
+ method_dict["params"] = param_list
+
  should_set_volatility = (
  platform_capabilities.PlatformCapabilities.get_instance().is_set_module_functions_volatility_from_manifest()
  )
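
For illustration only (not part of the diff): a minimal sketch of the signature parameters this manifest logic consumes. ParamSpec is the class added to snowflake/ml/model/_signatures/core.py in this release; the parameter names, dtypes, and defaults below are hypothetical.

    from snowflake.ml.model import model_signature

    # Hypothetical leaf parameter specs. ParamGroupSpec containers, if present,
    # are flattened recursively into leaves like these before manifest generation.
    params = [
        model_signature.ParamSpec(name="temperature", dtype=model_signature.DataType.FLOAT, default_value=1.0),
        model_signature.ParamSpec(name="max_tokens", dtype=model_signature.DataType.INT64, default_value=256),
    ]
    # Each leaf ends up in the method manifest as a named, typed field with its default;
    # duplicate names, or collisions with input feature names, raise a ValueError as shown above.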
@@ -1,5 +1,6 @@
  import functools
  import importlib
+ import logging
  import pkgutil
  from types import ModuleType
  from typing import Any, Callable, Optional, TypeVar, cast
@@ -11,6 +12,8 @@ _HANDLERS_BASE = "snowflake.ml.model._packager.model_handlers"
  _MODEL_HANDLER_REGISTRY: dict[str, type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
  _IS_HANDLER_LOADED = False

+ logger = logging.getLogger(__name__)
+

  def _register_handlers() -> None:
  """
@@ -56,8 +59,11 @@ def find_handler(
  model: model_types.SupportedModelType,
  ) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
  for handler in _MODEL_HANDLER_REGISTRY.values():
- if handler.can_handle(model):
- return handler
+ try:
+ if handler.can_handle(model):
+ return handler
+ except Exception:
+ logger.error(f"Error in {handler.__name__} `can_handle` method for model {type(model)}", exc_info=True)
  return None


@@ -86,6 +86,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
  get_prediction_fn=get_prediction,
  )

+ # Add parameters extracted from custom model inference methods to signatures
+ cls._add_method_parameters_to_signatures(model, model_meta)
+
  model_blob_path = os.path.join(model_blobs_dir_path, name)
  os.makedirs(model_blob_path, exist_ok=True)
  if model.context.artifacts:
@@ -188,6 +191,55 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
  assert isinstance(model, custom_model.CustomModel)
  return model

+ @classmethod
+ def _add_method_parameters_to_signatures(
+ cls,
+ model: "custom_model.CustomModel",
+ model_meta: model_meta_api.ModelMetadata,
+ ) -> None:
+ """Extract parameters from custom model inference methods and add them to signatures.
+
+ For each inference method, if the signature doesn't already have parameters and the method
+ has keyword-only parameters with defaults, create ParamSpecs and add them to the signature.
+
+ Args:
+ model: The custom model instance.
+ model_meta: The model metadata containing signatures to augment.
+ """
+ for method in model._get_infer_methods():
+ method_name = method.__name__
+ if method_name not in model_meta.signatures:
+ continue
+
+ sig = model_meta.signatures[method_name]
+
+ # Skip if the signature already has parameters (user-provided or previously set)
+ if sig.params:
+ continue
+
+ # Extract parameters from the method
+ method_params = custom_model.get_method_parameters(method)
+ if not method_params:
+ continue
+
+ # Create ParamSpecs from the method parameters
+ param_specs = []
+ for param_name, param_type, param_default in method_params:
+ dtype = model_signature.DataType.from_python_type(param_type)
+ param_spec = model_signature.ParamSpec(
+ name=param_name,
+ dtype=dtype,
+ default_value=param_default,
+ )
+ param_specs.append(param_spec)
+
+ # Create a new signature with parameters
+ model_meta.signatures[method_name] = model_signature.ModelSignature(
+ inputs=sig.inputs,
+ outputs=sig.outputs,
+ params=param_specs,
+ )
+
  @classmethod
  def convert_as_custom_model(
  cls,
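
For illustration only (not part of the diff): a hedged sketch of the kind of custom model this extraction targets, assuming the @inference_api decorator and pandas interface of snowflake.ml.model.custom_model; the class, method, and scale parameter are hypothetical.

    import pandas as pd
    from snowflake.ml.model import custom_model

    class ScalingModel(custom_model.CustomModel):
        @custom_model.inference_api
        def predict(self, X: pd.DataFrame, *, scale: float = 1.0) -> pd.DataFrame:
            # `scale` is keyword-only with a default, so _add_method_parameters_to_signatures
            # above would add a ParamSpec(name="scale", dtype=FLOAT, default_value=1.0) to the
            # signature of `predict`, unless params were already supplied by the caller.
            return X * scale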
@@ -29,12 +29,16 @@ from snowflake.ml.model._packager.model_meta import (
  model_meta_schema,
  )
  from snowflake.ml.model._signatures import utils as model_signature_utils
- from snowflake.ml.model.models import huggingface_pipeline
+ from snowflake.ml.model.models import (
+ huggingface as huggingface_base,
+ huggingface_pipeline,
+ )
  from snowflake.snowpark._internal import utils as snowpark_utils

  logger = logging.getLogger(__name__)

  if TYPE_CHECKING:
+ import torch
  import transformers

  DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501
@@ -77,8 +81,14 @@ def sanitize_output(data: Any) -> Any:


  @final
- class HuggingFacePipelineHandler(
- _base.BaseModelHandler[Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"]]
+ class TransformersPipelineHandler(
+ _base.BaseModelHandler[
+ Union[
+ huggingface_base.TransformersPipeline,
+ huggingface_pipeline.HuggingFacePipelineModel,
+ "transformers.Pipeline",
+ ]
+ ]
  ):
  """Handler for custom model."""

@@ -97,35 +107,48 @@ class HuggingFacePipelineHandler(
  def can_handle(
  cls,
  model: model_types.SupportedModelType,
- ) -> TypeGuard[Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"]]:
+ ) -> TypeGuard[
+ Union[
+ huggingface_base.TransformersPipeline,
+ huggingface_pipeline.HuggingFacePipelineModel,
+ "transformers.Pipeline",
+ ]
+ ]:
  if type_utils.LazyType("transformers.Pipeline").isinstance(model):
  return True
  if isinstance(model, huggingface_pipeline.HuggingFacePipelineModel):
  return True
+ if isinstance(model, huggingface_base.TransformersPipeline):
+ return True
  return False

  @classmethod
  def cast_model(
  cls,
  model: model_types.SupportedModelType,
- ) -> Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"]:
- try:
- if isinstance(model, huggingface_pipeline.HuggingFacePipelineModel):
- raise ImportError
- else:
- import transformers
- except ImportError:
- assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
+ ) -> Union[
+ huggingface_base.TransformersPipeline,
+ huggingface_pipeline.HuggingFacePipelineModel,
+ "transformers.Pipeline",
+ ]:
+ if type_utils.LazyType("transformers.Pipeline").isinstance(model):
  return model
- else:
- assert isinstance(model, transformers.Pipeline)
+ elif isinstance(model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+ model, huggingface_base.TransformersPipeline
+ ):
  return model
+ else:
+ raise ValueError(f"Model {model} is not a valid Hugging Face model.")

  @classmethod
  def save_model(
  cls,
  name: str,
- model: Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"],
+ model: Union[
+ huggingface_base.TransformersPipeline,
+ huggingface_pipeline.HuggingFacePipelineModel,
+ "transformers.Pipeline",
+ ],
  model_meta: model_meta_api.ModelMetadata,
  model_blobs_dir_path: str,
  sample_input_data: Optional[model_types.SupportedDataType] = None,
@@ -140,7 +163,9 @@ class HuggingFacePipelineHandler(
  framework = model.framework # type:ignore[attr-defined]
  batch_size = model._batch_size # type:ignore[attr-defined]
  else:
- assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
+ assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+ model, huggingface_base.TransformersPipeline
+ )
  task = model.task
  framework = getattr(model, "framework", None)
  batch_size = getattr(model, "batch_size", None)
@@ -156,7 +181,9 @@ class HuggingFacePipelineHandler(
  **model._postprocess_params, # type:ignore[attr-defined]
  }
  else:
- assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
+ assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+ model, huggingface_base.TransformersPipeline
+ )
  params = {**model.__dict__, **model.model_kwargs}

  inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(
@@ -177,7 +204,7 @@ class HuggingFacePipelineHandler(
  else:
  warnings.warn(
  "It is impossible to validate your model signatures when using a"
- " `snowflake.ml.model.models.huggingface_pipeline.HuggingFacePipelineModel` object. "
+ f" {type(model).__name__} object. "
  "Please make sure you are providing correct model signatures.",
  UserWarning,
  stacklevel=2,
@@ -302,14 +329,16 @@ class HuggingFacePipelineHandler(
  def _load_pickle_model(
  pickle_file: str,
  **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
- ) -> huggingface_pipeline.HuggingFacePipelineModel:
+ ) -> Union[huggingface_pipeline.HuggingFacePipelineModel, huggingface_base.TransformersPipeline]:
  with open(pickle_file, "rb") as f:
  m = cloudpickle.load(f)
- assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
+ assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+ m, huggingface_base.TransformersPipeline
+ )
  torch_dtype: Optional[str] = None
  device_config = None
  if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
- device_config = HuggingFacePipelineHandler._get_device_config(**kwargs)
+ device_config = TransformersPipelineHandler._get_device_config(**kwargs)
  m.__dict__.update(device_config)

  if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
@@ -326,7 +355,9 @@ class HuggingFacePipelineHandler(
  model_meta: model_meta_api.ModelMetadata,
  model_blobs_dir_path: str,
  **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
- ) -> Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"]:
+ ) -> Union[
+ huggingface_pipeline.HuggingFacePipelineModel, huggingface_base.TransformersPipeline, "transformers.Pipeline"
+ ]:
  if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
  # We need to redirect the some folders to a writable location in the sandbox.
  os.environ["HF_HOME"] = "/tmp"
@@ -369,7 +400,7 @@ class HuggingFacePipelineHandler(
  ) as f:
  pipeline_params = cloudpickle.load(f)

- device_config = HuggingFacePipelineHandler._get_device_config(**kwargs)
+ device_config = TransformersPipelineHandler._get_device_config(**kwargs)

  m = transformers.pipeline(
  model_blob_options["task"],
@@ -402,7 +433,7 @@ class HuggingFacePipelineHandler(

  def _create_pipeline_from_model(
  model_blob_file_or_dir_path: str,
- m: huggingface_pipeline.HuggingFacePipelineModel,
+ m: Union[huggingface_pipeline.HuggingFacePipelineModel, huggingface_base.TransformersPipeline],
  **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
  ) -> "transformers.Pipeline":
  import transformers
@@ -414,7 +445,7 @@ class HuggingFacePipelineHandler(
  torch_dtype=getattr(m, "torch_dtype", None),
  revision=m.revision,
  # pass device or device_map when creating the pipeline
- **HuggingFacePipelineHandler._get_device_config(**kwargs),
+ **TransformersPipelineHandler._get_device_config(**kwargs),
  # pass other model_kwargs to transformers.pipeline.from_pretrained method
  **m.model_kwargs,
  )
@@ -455,7 +486,11 @@ class HuggingFacePipelineHandler(
  @classmethod
  def convert_as_custom_model(
  cls,
- raw_model: Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"],
+ raw_model: Union[
+ huggingface_pipeline.HuggingFacePipelineModel,
+ huggingface_base.TransformersPipeline,
+ "transformers.Pipeline",
+ ],
  model_meta: model_meta_api.ModelMetadata,
  background_data: Optional[pd.DataFrame] = None,
  **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
@@ -609,7 +644,9 @@ class HuggingFacePipelineHandler(

  return _HFPipelineModel

- if isinstance(raw_model, huggingface_pipeline.HuggingFacePipelineModel):
+ if isinstance(raw_model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+ raw_model, huggingface_base.TransformersPipeline
+ ):
  if version.parse(transformers.__version__) < version.parse("4.32.0"):
  # Backward compatibility since HF interface change.
  raw_model.__dict__["use_auth_token"] = raw_model.__dict__["token"]
@@ -668,43 +705,64 @@ class HuggingFaceOpenAICompatibleModel:

  self.model_name = self.pipeline.model.name_or_path

+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ # Ensure the tokenizer has a chat template.
+ # If not, we inject the default ChatML template which supports prompt generation.
+ if not getattr(self.tokenizer, "chat_template", None):
+ logger.warning(f"No chat template found for {self.model_name}. Using default ChatML template.")
+ self.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
+
  def _apply_chat_template(self, messages: list[dict[str, Any]]) -> str:
  """
  Applies a chat template to a list of messages.
- If the tokenizer has a chat template, it uses that.
- Otherwise, it falls back to a simple concatenation.

  Args:
- messages (list[dict]): A list of message dictionaries, e.g.,
- [{"role": "user", "content": "Hello!"}, ...]
+ messages (list[dict]): A list of message dictionaries.

  Returns:
  The formatted prompt string ready for model input.
  """
-
- if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
- # Use the tokenizer's built-in chat template if available
- # `tokenize=False` means it returns a string, not token IDs
+ # Use the tokenizer's apply_chat_template method.
+ # We ensured a template exists in __init__.
+ if hasattr(self.tokenizer, "apply_chat_template"):
  return self.tokenizer.apply_chat_template( # type: ignore[no-any-return]
  messages,
  tokenize=False,
  add_generation_prompt=True,
  )
- else:
- # Fallback to a simple concatenation for models without a specific chat template
- # This is a basic example; real chat models often need specific formatting.
- prompt = ""
- for message in messages:
- role = message.get("role", "user")
- content = message.get("content", "")
- if role == "system":
- prompt += f"System: {content}\n"
- elif role == "user":
- prompt += f"User: {content}\n"
- elif role == "assistant":
- prompt += f"Assistant: {content}\n"
- prompt += "Assistant:" # Indicate that the assistant should respond
- return prompt
+
+ # Fallback for very old transformers without apply_chat_template
+ # Manually apply ChatML-like formatting
+ prompt = ""
+ for message in messages:
+ role = message.get("role", "user")
+ content = message.get("content", "")
+ prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
+ prompt += "<|im_start|>assistant\n"
+ return prompt
+
+ def _get_stopping_criteria(self, stop_strings: list[str]) -> "transformers.StoppingCriteriaList":
+
+ import transformers
+
+ class StopStringsStoppingCriteria(transformers.StoppingCriteria):
+ def __init__(self, stop_strings: list[str], tokenizer: Any) -> None:
+ self.stop_strings = stop_strings
+ self.tokenizer = tokenizer
+
+ def __call__(self, input_ids: "torch.Tensor", scores: "torch.Tensor", **kwargs: Any) -> bool:
+ # Decode the generated text for each sequence
+ for i in range(input_ids.shape[0]):
+ generated_text = self.tokenizer.decode(input_ids[i], skip_special_tokens=True)
+ # Check if any stop string appears in the generated text
+ for stop_str in self.stop_strings:
+ if stop_str in generated_text:
+ return True
+ return False
+
+ return transformers.StoppingCriteriaList([StopStringsStoppingCriteria(stop_strings, self.tokenizer)])

  def generate_chat_completion(
  self,
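
For illustration only (not part of the diff): what the ChatML-style fallback added above would produce for a hypothetical message list.

    messages = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi"},
    ]
    # The fallback loop would build:
    # "<|im_start|>system\nYou are helpful.<|im_end|>\n"
    # "<|im_start|>user\nHi<|im_end|>\n"
    # "<|im_start|>assistant\n"
    # which matches the DEFAULT_CHAT_TEMPLATE injected in __init__ when a tokenizer has no chat template.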
@@ -727,18 +785,17 @@ class HuggingFaceOpenAICompatibleModel:
  {"role": "user", "content": "What is deep learning?"}]
  max_completion_tokens (int): The maximum number of completion tokens to generate.
  stop_strings (list[str]): A list of strings to stop generation.
- temperature (float): The temperature for sampling.
- top_p (float): The top-p value for sampling.
- stream (bool): Whether to stream the generation.
- frequency_penalty (float): The frequency penalty for sampling.
- presence_penalty (float): The presence penalty for sampling.
+ temperature (float): The temperature for sampling. 0 means greedy decoding.
+ top_p (float): The top-p value for nucleus sampling.
+ stream (bool): Whether to stream the generation (not yet supported).
+ frequency_penalty (float): The frequency penalty for sampling (maps to repetition_penalty).
+ presence_penalty (float): The presence penalty for sampling (not directly supported).
  n (int): The number of samples to generate.

  Returns:
  dict: An OpenAI-compatible dictionary representing the chat completion.
  """
  # Apply chat template to convert messages into a single prompt string
-
  prompt_text = self._apply_chat_template(messages)

  # Tokenize the prompt
@@ -749,42 +806,112 @@ class HuggingFaceOpenAICompatibleModel:
  ).to(self.model.device)
  prompt_tokens = inputs.input_ids.shape[1]

- from transformers import GenerationConfig
+ if stream:
+ logger.warning(
+ "Streaming is not supported using transformers.Pipeline implementation. Ignoring stream=True."
+ )
+ stream = False
+
+ if presence_penalty is not None:
+ logger.warning(
+ "Presence penalty is not supported using transformers.Pipeline implementation."
+ " Ignoring presence_penalty."
+ )
+ presence_penalty = None
+
+ import transformers
+
+ transformers_version = version.parse(transformers.__version__)
+
+ # Stop strings are supported in transformers >= 4.43.0
+ can_handle_stop_strings = transformers_version >= version.parse("4.43.0")

- generation_config = GenerationConfig(
- max_new_tokens=max_completion_tokens,
- temperature=temperature,
- top_p=top_p,
+ # Determine sampling based on temperature (following serve.py logic)
+ # Default temperature to 1.0 if not specified
+ actual_temperature = temperature if temperature is not None else 1.0
+ do_sample = actual_temperature > 0.0
+
+ # Set up generation config following best practices from serve.py
+ generation_config = transformers.GenerationConfig(
+ max_new_tokens=max_completion_tokens if max_completion_tokens is not None else 1024,
  pad_token_id=self.tokenizer.pad_token_id,
  eos_token_id=self.tokenizer.eos_token_id,
- stop_strings=stop_strings,
- stream=stream,
- num_return_sequences=n,
- num_beams=max(1, n), # must be >1
- repetition_penalty=frequency_penalty,
- # TODO: Handle diversity_penalty and num_beam_groups
- # not all models support them making it hard to support any huggingface model
- # diversity_penalty=presence_penalty if n > 1 else None,
- # num_beam_groups=max(2, n) if presence_penalty else 1,
- do_sample=False,
+ do_sample=do_sample,
  )

+ # Only set temperature and top_p if sampling is enabled
+ if do_sample:
+ generation_config.temperature = actual_temperature
+ if top_p is not None:
+ generation_config.top_p = top_p
+
+ # Handle repetition penalty (mapped from frequency_penalty)
+ if frequency_penalty is not None:
+ # OpenAI's frequency_penalty is typically in range [-2.0, 2.0]
+ # HuggingFace's repetition_penalty is typically > 0, with 1.0 = no penalty
+ # We need to convert: frequency_penalty=0 -> repetition_penalty=1.0
+ # Higher frequency_penalty should increase repetition_penalty
+ generation_config.repetition_penalty = 1.0 + (frequency_penalty if frequency_penalty > 0 else 0)
+
+ # For multiple completions (n > 1), use sampling not beam search
+ if n > 1:
+ generation_config.num_return_sequences = n
+ # Force sampling on for multiple sequences
+ if not do_sample:
+ logger.warning("Forcing do_sample=True for n>1. Consider setting temperature > 0 for better diversity.")
+ generation_config.do_sample = True
+ generation_config.temperature = 1.0
+ else:
+ generation_config.num_return_sequences = 1
+
+ # Handle stop strings if provided
+ stopping_criteria = None
+ if stop_strings and not can_handle_stop_strings:
+ logger.warning("Stop strings are not supported in transformers < 4.41.0. Ignoring stop strings.")
+
+ if stop_strings and can_handle_stop_strings:
+ stopping_criteria = self._get_stopping_criteria(stop_strings)
+ output_ids = self.model.generate(
+ inputs.input_ids,
+ attention_mask=inputs.attention_mask,
+ generation_config=generation_config,
+ # Pass tokenizer for proper handling of stop strings
+ tokenizer=self.tokenizer,
+ stopping_criteria=stopping_criteria,
+ )
+ else:
+ output_ids = self.model.generate(
+ inputs.input_ids,
+ attention_mask=inputs.attention_mask,
+ generation_config=generation_config,
+ )
+
  # Generate text
- output_ids = self.model.generate(
- inputs.input_ids,
- attention_mask=inputs.attention_mask,
- generation_config=generation_config,
- )
+ # Handle the case where output might be 1D if n=1
+ if output_ids.dim() == 1:
+ output_ids = output_ids.unsqueeze(0)

  generated_texts = []
  completion_tokens = 0
  total_tokens = prompt_tokens
+
  for output_id in output_ids:
  # The output_ids include the input prompt
  # Decode the generated text, excluding the input prompt
  # so we slice to get only new tokens
  generated_tokens = output_id[prompt_tokens:]
  generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+ # Trim stop strings from generated text if they appear
+ # The stop criteria would stop generating further tokens, so we need to trim the generated text
+ if stop_strings and can_handle_stop_strings:
+ for stop_str in stop_strings:
+ if stop_str in generated_text:
+ # Find the first occurrence and trim everything from there
+ stop_idx = generated_text.find(stop_str)
+ generated_text = generated_text[:stop_idx]
+ break # Stop after finding the first stop string
+
  generated_texts.append(generated_text)

  # Calculate completion tokens
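
For illustration only (not part of the diff): a small sketch of the sampling and penalty mapping above; the helper name and values are hypothetical.

    def map_frequency_penalty(frequency_penalty: float) -> float:
        # Mirrors the clamping above: only positive OpenAI-style penalties increase
        # the HuggingFace repetition_penalty; 1.0 means "no penalty".
        return 1.0 + (frequency_penalty if frequency_penalty > 0 else 0)

    assert map_frequency_penalty(0.5) == 1.5
    assert map_frequency_penalty(0.0) == 1.0
    assert map_frequency_penalty(-1.0) == 1.0

    # temperature=0.7 -> do_sample=True, temperature/top_p applied;
    # temperature=0.0 -> greedy decoding (do_sample=False), unless n > 1 forces sampling back on.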
@@ -148,12 +148,17 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):

  file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))

+ # MLflow 3.x may return file:// URIs for artifact_path; extract just the last path component
+ artifact_path = model_info.artifact_path
+ if artifact_path.startswith("file://"):
+ artifact_path = artifact_path.rstrip("/").split("/")[-1]
+
  base_meta = model_blob_meta.ModelBlobMeta(
  name=name,
  model_type=cls.HANDLER_TYPE,
  handler_version=cls.HANDLER_VERSION,
  path=cls.MODEL_BLOB_FILE_OR_DIR,
- options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path": model_info.artifact_path}),
+ options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path": artifact_path}),
  )
  model_meta.models[name] = base_meta
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
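
For illustration only (not part of the diff): a quick check of the artifact path normalization above, using a hypothetical MLflow 3.x file:// URI.

    artifact_path = "file:///tmp/mlruns/0/abc123/artifacts/model/"
    if artifact_path.startswith("file://"):
        artifact_path = artifact_path.rstrip("/").split("/")[-1]
    assert artifact_path == "model"  # only the final path component is stored in the blob options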