snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
Files changed (69)
  1. snowflake/cortex/__init__.py +16 -8
  2. snowflake/cortex/_classify_text.py +12 -1
  3. snowflake/cortex/_complete.py +82 -13
  4. snowflake/cortex/_embed_text_1024.py +9 -2
  5. snowflake/cortex/_embed_text_768.py +9 -2
  6. snowflake/cortex/_extract_answer.py +9 -2
  7. snowflake/cortex/_sentiment.py +9 -2
  8. snowflake/cortex/_summarize.py +9 -2
  9. snowflake/cortex/_translate.py +9 -2
  10. snowflake/ml/_internal/env_utils.py +7 -52
  11. snowflake/ml/_internal/utils/identifier.py +4 -2
  12. snowflake/ml/data/__init__.py +3 -0
  13. snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
  14. snowflake/ml/data/data_connector.py +53 -11
  15. snowflake/ml/data/data_ingestor.py +2 -1
  16. snowflake/ml/data/torch_utils.py +18 -5
  17. snowflake/ml/feature_store/examples/example_helper.py +2 -1
  18. snowflake/ml/fileset/fileset.py +18 -18
  19. snowflake/ml/model/_client/model/model_version_impl.py +5 -3
  20. snowflake/ml/model/_client/ops/model_ops.py +2 -6
  21. snowflake/ml/model/_client/sql/model_version.py +11 -0
  22. snowflake/ml/model/_model_composer/model_composer.py +8 -3
  23. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  25. snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
  26. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
  27. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
  28. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
  29. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  30. snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
  31. snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
  32. snowflake/ml/model/_packager/model_handlers/_utils.py +27 -2
  33. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
  34. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +5 -1
  35. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
  36. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
  37. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
  38. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
  39. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
  40. snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  43. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  44. snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
  45. snowflake/ml/model/_signatures/pandas_handler.py +1 -1
  46. snowflake/ml/model/_signatures/snowpark_handler.py +8 -2
  47. snowflake/ml/model/type_hints.py +1 -0
  48. snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
  49. snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
  50. snowflake/ml/modeling/pipeline/pipeline.py +6 -176
  51. snowflake/ml/modeling/xgboost/xgb_classifier.py +161 -88
  52. snowflake/ml/modeling/xgboost/xgb_regressor.py +160 -85
  53. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +160 -85
  54. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +160 -85
  55. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
  56. snowflake/ml/registry/_manager/model_manager.py +70 -33
  57. snowflake/ml/registry/registry.py +41 -22
  58. snowflake/ml/version.py +1 -1
  59. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/METADATA +38 -9
  60. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/RECORD +63 -67
  61. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/WHEEL +1 -1
  62. snowflake/ml/_internal/utils/retryable_http.py +0 -39
  63. snowflake/ml/fileset/parquet_parser.py +0 -170
  64. snowflake/ml/fileset/tf_dataset.py +0 -88
  65. snowflake/ml/fileset/torch_datapipe.py +0 -57
  66. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
  67. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
  68. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/LICENSE.txt +0 -0
  69. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/top_level.txt +0 -0
@@ -195,8 +195,12 @@ class HuggingFacePipelineHandler(
         os.makedirs(model_blob_path, exist_ok=True)
 
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
+            save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
             model.save_pretrained(  # type:ignore[attr-defined]
-                os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
+                save_path
+            )
+            handlers_utils.save_transformers_config_with_auto_map(
+                save_path,
             )
             pipeline_params = {
                 "_batch_size": model._batch_size,  # type:ignore[attr-defined]
@@ -110,8 +110,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
-        model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
+        model_task_and_output = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
+        model_meta.task = model_task_and_output.task
         if enable_explainability:
             explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
             model_meta = handlers_utils.add_explain_method_signature(
@@ -240,7 +240,9 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
                 import shap
 
                 explainer = shap.TreeExplainer(raw_model)
-                df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+                df = handlers_utils.convert_explanations_to_2D_df(
+                    raw_model, explainer.shap_values(X, from_call=True)
+                )
                 return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
             if target_method == "explain":
@@ -14,8 +14,8 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
+    model_meta_schema,
 )
-from snowflake.ml.model._signatures import utils as model_signature_utils
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 if TYPE_CHECKING:
@@ -24,6 +24,25 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+def _validate_sentence_transformers_signatures(sigs: Dict[str, model_signature.ModelSignature]) -> None:
+    if list(sigs.keys()) != ["encode"]:
+        raise ValueError("target_methods can only be ['encode']")
+
+    if len(sigs["encode"].inputs) != 1:
+        raise ValueError("SentenceTransformer can only accept 1 input column")
+
+    if len(sigs["encode"].outputs) != 1:
+        raise ValueError("SentenceTransformer can only return 1 output column")
+
+    assert isinstance(sigs["encode"].inputs[0], model_signature.FeatureSpec)
+
+    if sigs["encode"].inputs[0]._shape is not None:
+        raise ValueError("SentenceTransformer does not support input shape")
+
+    if sigs["encode"].inputs[0]._dtype != model_signature.DataType.STRING:
+        raise ValueError("SentenceTransformer only accepts string input")
+
+
 @final
 class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
     HANDLER_TYPE = "sentence_transformers"
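Note on the hunk above: signature validation now raises `ValueError` instead of relying on asserts. For reference, a signature that satisfies the new checks could be built roughly like this (a sketch against the public `model_signature` API; the column names and the 384-dimension output shape are made up):

```python
# Hedged sketch only: one string input column, one output column, no input shape.
from snowflake.ml.model import model_signature

sig = model_signature.ModelSignature(
    inputs=[model_signature.FeatureSpec(name="TEXT", dtype=model_signature.DataType.STRING)],
    outputs=[
        model_signature.FeatureSpec(name="EMBEDDING", dtype=model_signature.DataType.DOUBLE, shape=(384,))
    ],
)
# Two input columns, a non-string input, or an input declared with a shape would now raise ValueError.
```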
@@ -68,6 +87,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for Sentence Transformer model.")
 
+        batch_size = kwargs.get("batch_size", 32)
+        if not isinstance(batch_size, int) or batch_size <= 0:
+            raise ValueError("batch_size must be a positive integer")
+
         # Validate target methods and signature (if possible)
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
@@ -75,12 +98,23 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
                 target_methods=kwargs.pop("target_methods", None),
                 default_target_methods=cls.DEFAULT_TARGET_METHODS,
             )
-            assert target_methods == ["encode"], "target_methods can only be ['encode']"
+            if target_methods != ["encode"]:
+                raise ValueError("target_methods can only be ['encode']")
 
             def get_prediction(
                 target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
             ) -> model_types.SupportedLocalDataType:
-                return _sentence_transformer_encode(model, sample_input_data)
+                if not isinstance(sample_input_data, pd.DataFrame):
+                    sample_input_data = model_signature._convert_local_data_to_df(data=sample_input_data)
+
+                if sample_input_data.shape[1] != 1:
+                    raise ValueError(
+                        "SentenceTransformer can only accept 1 input column when converted to pd.DataFrame"
+                    )
+                X_list = sample_input_data.iloc[:, 0].tolist()
+
+                assert callable(getattr(model, "encode", None))
+                return pd.DataFrame({0: model.encode(X_list, batch_size=batch_size).tolist()})
 
             if model_meta.signatures:
                 handlers_utils.validate_target_methods(model, list(model_meta.signatures.keys()))
@@ -102,10 +136,16 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
                 get_prediction_fn=get_prediction,
             )
 
+        _validate_sentence_transformers_signatures(model_meta.signatures)
+
         # save model
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model.save(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
+        save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
+        model.save(save_path)
+        handlers_utils.save_transformers_config_with_auto_map(
+            save_path,
+        )
 
         # save model metadata
         base_meta = model_blob_meta.ModelBlobMeta(
@@ -113,6 +153,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
+            options=model_meta_schema.SentenceTransformersModelBlobOptions(batch_size=batch_size),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -149,6 +190,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
             # We need to redirect the same folders to a writable location in the sandbox.
             os.environ["TRANSFORMERS_CACHE"] = "/tmp"
+            os.environ["HF_HOME"] = "/tmp"
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         model_blobs_metadata = model_meta.models
@@ -183,6 +225,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         raw_model: "sentence_transformers.SentenceTransformer",
         model_meta: model_meta_api.ModelMetadata,
     ) -> Type[custom_model.CustomModel]:
+        batch_size = cast(
+            model_meta_schema.SentenceTransformersModelBlobOptions, model_meta.models[model_meta.name].options
+        ).get("batch_size", None)
+
         def get_prediction(
             raw_model: "sentence_transformers.SentenceTransformer",
             signature: model_signature.ModelSignature,
@@ -190,8 +236,11 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
             @custom_model.inference_api
             def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
-                predictions_df = _sentence_transformer_encode(raw_model, X)
-                return model_signature_utils.rename_pandas_df(predictions_df, signature.outputs)
+                X_list = X.iloc[:, 0].tolist()
+
+                return pd.DataFrame(
+                    {signature.outputs[0].name: raw_model.encode(X_list, batch_size=batch_size).tolist()}
+                )
 
             return fn
 
@@ -217,17 +266,3 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         predict_method = getattr(sentence_transformers_SentenceTransformer_model, "encode", None)
         assert callable(predict_method)
         return sentence_transformers_SentenceTransformer_model
-
-
-def _sentence_transformer_encode(
-    model: "sentence_transformers.SentenceTransformer", X: model_types.SupportedLocalDataType
-) -> model_types.SupportedLocalDataType:
-
-    if not isinstance(X, pd.DataFrame):
-        X = model_signature._convert_local_data_to_df(X)
-
-    assert X.shape[1] == 1, "SentenceTransformer can only accept 1 input column when converted to pd.DataFrame"
-    X_list = X.iloc[:, 0].tolist()
-
-    assert callable(getattr(model, "encode", None))
-    return pd.DataFrame({0: model.encode(X_list, batch_size=X.shape[0]).tolist()})
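Taken together, the sentence-transformers changes replace the removed `_sentence_transformer_encode` helper (which batched by the full input size) with a configurable `batch_size` that defaults to 32 and is persisted in the model blob options. A rough standalone sketch of the encode path the handler now follows (the checkpoint and column names are placeholders):

```python
# Illustrative only; requires the sentence-transformers package and a downloadable checkpoint.
import pandas as pd
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder checkpoint
X = pd.DataFrame({"TEXT": ["first sentence", "second sentence", "third sentence"]})

# One string input column -> list of strings -> encode with a fixed batch size -> one output column.
X_list = X.iloc[:, 0].tolist()
embeddings = pd.DataFrame({0: model.encode(X_list, batch_size=32).tolist()})
print(embeddings.shape)  # (3, 1): each cell holds one embedding vector as a list
```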
@@ -152,8 +152,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             sample_input_data, model_meta, explain_target_method
         )
 
-        model_task_and_output_type = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
+        model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
+        model_meta.task = model_task_and_output_type.task
 
         # if users did not ask then we enable if we have background data
         if enable_explainability is None:
@@ -164,11 +164,17 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                     stacklevel=1,
                 )
                 enable_explainability = False
-            elif model_meta.task == model_types.Task.UNKNOWN:
+            elif model_meta.task == model_types.Task.UNKNOWN or explain_target_method is None:
                 enable_explainability = False
             else:
                 enable_explainability = True
         if enable_explainability:
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method=explain_target_method,
+                output_return_type=model_task_and_output_type.output_type,
+            )
             handlers_utils.save_background_data(
                 model_blobs_dir_path,
                 cls.EXPLAIN_ARTIFACTS_DIR,
@@ -177,13 +183,6 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                 background_data,
             )
 
-        model_meta = handlers_utils.add_explain_method_signature(
-            model_meta=model_meta,
-            explain_method="explain",
-            target_method=explain_target_method,
-            output_return_type=model_task_and_output_type.output_type,
-        )
-
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
         with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
@@ -68,21 +68,45 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         return cast("BaseEstimator", model)
 
     @classmethod
-    def _get_supported_object_for_explainability(cls, estimator: "BaseEstimator") -> Any:
+    def _get_supported_object_for_explainability(
+        cls,
+        estimator: "BaseEstimator",
+        background_data: Optional[model_types.SupportedDataType],
+        enable_explainability: Optional[bool],
+    ) -> Any:
         from snowflake.ml.modeling import pipeline as snowml_pipeline
 
         # handle pipeline objects separately
         if isinstance(estimator, snowml_pipeline.Pipeline):  # type: ignore[attr-defined]
             return None
 
-        methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
-        for method_name in methods:
+        tree_methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
+        non_tree_methods = ["to_sklearn"]
+        for method_name in tree_methods:
+            if hasattr(estimator, method_name):
+                try:
+                    result = getattr(estimator, method_name)()
+                    return result
+                except exceptions.SnowflakeMLException:
+                    pass  # Do nothing and continue to the next method
+        for method_name in non_tree_methods:
             if hasattr(estimator, method_name):
                 try:
                     result = getattr(estimator, method_name)()
+                    if enable_explainability is None and background_data is None:
+                        return None  # cannot get explain without background data
+                    elif enable_explainability and background_data is None:
+                        raise ValueError(
+                            "Provide `sample_input_data` to generate explanations for sklearn Snowpark ML models."
+                        )
                     return result
                 except exceptions.SnowflakeMLException:
                     pass  # Do nothing and continue to the next method
+
+        if enable_explainability:
+            raise ValueError(
+                "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
+            )
         return None
 
     @classmethod
@@ -127,34 +151,39 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
                     raise ValueError(f"Target method {method_name} does not exist in the model.")
             model_meta.signatures = temp_model_signature_dict
 
-        if enable_explainability or enable_explainability is None:
-            python_base_obj = cls._get_supported_object_for_explainability(model)
-            if python_base_obj is None:
-                if enable_explainability:  # if user set enable_explainability to True, throw error else silently skip
-                    raise ValueError(
-                        "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
-                    )
+        python_base_obj = cls._get_supported_object_for_explainability(model, sample_input_data, enable_explainability)
+        explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
+
+        if enable_explainability:
+            if explain_target_method is None:
+                raise ValueError(
+                    "The model must have one of the following methods to enable explainability: "
+                    + ", ".join(cls.EXPLAIN_TARGET_METHODS)
+                )
+        if enable_explainability is None:
+            if python_base_obj is None or explain_target_method is None:
                 # set None to False so we don't include shap in the environment
                 enable_explainability = False
             else:
-                model_task_and_output_type = model_task_utils.get_model_task_and_output_type(python_base_obj)
-                model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
-                explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
-                model_meta = handlers_utils.add_explain_method_signature(
-                    model_meta=model_meta,
-                    explain_method="explain",
-                    target_method=explain_target_method,
-                    output_return_type=model_task_and_output_type.output_type,
-                )
                 enable_explainability = True
-
-            background_data = handlers_utils.get_explainability_supported_background(
-                sample_input_data, model_meta, explain_target_method
+        if enable_explainability:
+            model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(
+                python_base_obj, model_meta.task
+            )
+            model_meta.task = model_task_and_output_type.task
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method=explain_target_method,
+                output_return_type=model_task_and_output_type.output_type,
+            )
+            background_data = handlers_utils.get_explainability_supported_background(
+                sample_input_data, model_meta, explain_target_method
+            )
+            if background_data is not None:
+                handlers_utils.save_background_data(
+                    model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
                 )
-        if background_data is not None:
-            handlers_utils.save_background_data(
-                model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
-            )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
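Note on the snowmlmodel.py hunks above: explanations for sklearn-backed (non-tree) Snowpark ML models now require background data, and requesting them explicitly without it raises. A hedged usage sketch, assuming the registry API shipped in this package (not runnable as-is: `session`, `snowml_model`, and `train_df` are placeholders for an existing Snowpark session, a fitted Snowpark ML estimator, and its training DataFrame):

```python
# Hedged usage sketch with placeholder objects; names are illustrative only.
from snowflake.ml.registry import Registry

reg = Registry(session=session)
mv = reg.log_model(
    snowml_model,
    model_name="MY_MODEL",                    # placeholder name
    version_name="V1",
    sample_input_data=train_df.limit(100),    # background data: now needed for sklearn-backed explanations
    options={"enable_explainability": True},  # without sample_input_data this now raises ValueError
)
```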
@@ -237,8 +266,17 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
             def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
                 import shap
 
-                methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
-                for method_name in methods:
+                tree_methods = ["to_xgboost", "to_lightgbm"]
+                non_tree_methods = ["to_sklearn"]
+                for method_name in tree_methods:
+                    try:
+                        base_model = getattr(raw_model, method_name)()
+                        explainer = shap.TreeExplainer(base_model)
+                        df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
+                        return model_signature_utils.rename_pandas_df(df, signature.outputs)
+                    except exceptions.SnowflakeMLException:
+                        pass  # Do nothing and continue to the next method
+                for method_name in non_tree_methods:
                     try:
                         base_model = getattr(raw_model, method_name)()
                         explainer = shap.Explainer(base_model, masker=background_data)
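The rewritten `explain_fn` above tries the tree-model converters first (`to_xgboost`, `to_lightgbm`) with `shap.TreeExplainer`, and only falls back to `to_sklearn` with the generic masker-based explainer. A toy, self-contained illustration of the two shap call patterns involved (random data; not the handler code itself):

```python
# Toy illustration of the two explainer paths the handler now distinguishes.
import numpy as np
import shap
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = rng.random((50, 4))
y = X[:, 0] * 2.0 + rng.normal(scale=0.1, size=50)

# Tree path: converted tree models go to TreeExplainer.shap_values(X).
tree_model = RandomForestRegressor(n_estimators=10, random_state=0).fit(X, y)
tree_attributions = shap.TreeExplainer(tree_model).shap_values(X)

# Non-tree path: to_sklearn() estimators go through the generic Explainer with a background masker.
linear_model = LinearRegression().fit(X, y)
background = X[:10]
generic_attributions = shap.Explainer(linear_model, masker=background)(X).values

print(tree_attributions.shape, generic_attributions.shape)  # both (50, 4)
```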
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
 
 import numpy as np
 import pandas as pd
+from packaging import version
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
@@ -73,13 +74,42 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for Tensorflow model.")
 
+        # When tensorflow is installed, keras is also installed.
+        import keras
         import tensorflow
 
         assert isinstance(model, tensorflow.Module)
 
         is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
-            "tf_keras.Model"
+            "keras.Model"
         ).isinstance(model)
+        is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)
+        is_keras_functional_or_sequential_model = (
+            getattr(model, "_is_graph_network", False)
+            or type_utils.LazyType("tensorflow.keras.engine.sequential.Sequential").isinstance(model)
+            or type_utils.LazyType("keras.engine.sequential.Sequential").isinstance(model)
+            or type_utils.LazyType("tf_keras.engine.sequential.Sequential").isinstance(model)
+        )
+
+        assert isinstance(model, tensorflow.Module)
+
+        keras_version = version.parse(keras.__version__)
+
+        # Tensorflow and keras model save format is different.
+        # Keras functional or sequential models are saved as keras format
+        # Keras v3 other models are saved using cloudpickle
+        # Keras v2 other models are saved using tensorflow saved model format
+        # Tensorflow models are saved using tensorflow saved model format
+
+        if is_keras_model or is_tf_keras_model:
+            if is_keras_functional_or_sequential_model:
+                save_format = "keras"
+            elif keras_version.major == 2 or is_tf_keras_model:
+                save_format = "keras_tf"
+            else:
+                save_format = "cloudpickle"
+        else:
+            save_format = "tf"
 
         if is_keras_model:
             default_target_methods = ["predict"]
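The decision table introduced above can be restated as a small pure function (a hedged paraphrase of the hunk, not the handler's code; the arguments mirror `is_keras_model`, `is_tf_keras_model`, `is_keras_functional_or_sequential_model`, and `keras_version.major`):

```python
# Hedged restatement of the new save-format selection; inputs are plain values for clarity.
def pick_save_format(is_keras_model: bool, is_tf_keras_model: bool,
                     is_functional_or_sequential: bool, keras_major: int) -> str:
    if is_keras_model or is_tf_keras_model:
        if is_functional_or_sequential:
            return "keras"          # native Keras format
        if keras_major == 2 or is_tf_keras_model:
            return "keras_tf"       # TensorFlow SavedModel via model.save(..., save_format="tf")
        return "cloudpickle"        # Keras 3 subclassed models are pickled
    return "tf"                     # plain tf.Module -> tensorflow.saved_model.save

assert pick_save_format(True, False, True, 3) == "keras"
assert pick_save_format(True, False, False, 2) == "keras_tf"
assert pick_save_format(True, False, False, 3) == "cloudpickle"
assert pick_save_format(False, False, False, 0) == "tf"
```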
@@ -93,6 +123,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             default_target_methods=default_target_methods,
         )
 
+        if is_keras_model and len(target_methods) > 1:
+            raise ValueError("Keras model can only have one target method.")
+
         def get_prediction(
             target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
         ) -> model_types.SupportedLocalDataType:
@@ -122,31 +155,43 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        if is_keras_model:
-            tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
-            model_meta.env.include_if_absent(
-                [
-                    model_env.ModelDependency(requirement="keras<3", pip_name="keras"),
-                ],
-                check_local_version=False,
-            )
+        save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
+        if save_format == "keras":
+            model.save(save_path, save_format="keras")
+        elif save_format == "keras_tf":
+            model.save(save_path, save_format="tf")
+        elif save_format == "cloudpickle":
+            import cloudpickle
+
+            with open(save_path, "wb") as f:
+                cloudpickle.dump(model, f)
         else:
-            tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
+            tensorflow.saved_model.save(
+                model,
+                save_path,
+                options=tensorflow.saved_model.SaveOptions(experimental_custom_gradients=False),
+            )
 
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
-            options=model_meta_schema.TensorflowModelBlobOptions(is_keras_model=is_keras_model),
+            options=model_meta_schema.TensorflowModelBlobOptions(save_format=save_format),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
 
+        dependencies = [
+            model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
+        ]
+        if is_keras_model:
+            dependencies.append(model_env.ModelDependency(requirement="keras", pip_name="keras"))
+        elif is_tf_keras_model:
+            dependencies.append(model_env.ModelDependency(requirement="tf-keras", pip_name="tf-keras"))
+
         model_meta.env.include_if_absent(
-            [
-                model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
-            ],
+            dependencies,
             check_local_version=True,
         )
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
@@ -166,10 +211,18 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_metadata = model_blobs_metadata[name]
         model_blob_filename = model_blob_metadata.path
         model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
-        if model_blob_options.get("is_keras_model", False):
-            m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False)
+        load_path = os.path.join(model_blob_path, model_blob_filename)
+        save_format = model_blob_options.get("save_format", "tf")
+        if save_format == "keras" or save_format == "keras_tf":
+            m = tensorflow.keras.models.load_model(load_path)
+        elif save_format == "cloudpickle":
+            import cloudpickle
+
+            with open(load_path, "rb") as f:
+                m = cloudpickle.load(f)
         else:
-            m = tensorflow.saved_model.load(os.path.join(model_blob_path, model_blob_filename))
+            m = tensorflow.saved_model.load(load_path)
+
         return cast(tensorflow.Module, m)
 
     @classmethod
@@ -117,8 +117,8 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
-        model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
+        model_task_and_output = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
+        model_meta.task = model_task_and_output.task
         if enable_explainability:
             model_meta = handlers_utils.add_explain_method_signature(
                 model_meta=model_meta,
@@ -254,7 +254,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
                 import shap
 
                 explainer = shap.TreeExplainer(raw_model)
-                df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+                df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
                 return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
             if target_method == "explain":
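Note on the one-line change above: the handler now calls `TreeExplainer.shap_values(X)` directly rather than going through `explainer(X).values`. A minimal sketch of the two call styles on a toy xgboost booster (random data; not the handler code):

```python
# Toy comparison of the two shap call styles on an xgboost Booster.
import numpy as np
import shap
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.random((30, 3))
y = rng.random(30)
booster = xgb.train(
    {"max_depth": 2, "objective": "reg:squarederror"}, xgb.DMatrix(X, label=y), num_boost_round=5
)

explainer = shap.TreeExplainer(booster)
new_style = explainer.shap_values(X)   # what the handler calls after this change
old_style = explainer(X).values        # what it called before
print(np.allclose(new_style, old_style))  # expected True: same attributions, different wrapper
```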
@@ -215,6 +215,7 @@ class ModelMetadata:
         function_properties: A dict mapping function names to dict mapping function property key to value.
         metadata: User provided key-value metadata of the model. Defaults to None.
         creation_timestamp: Unix timestamp when the model metadata is created.
+        user_files: Dict mapping subdirectories to extra artifact file paths for files to include in the model.
         task: Model task like TABULAR_REGRESSION, tabular_classification, timeseries_forecasting etc.
     """
 
@@ -234,6 +235,7 @@
         runtimes: Optional[Dict[str, model_runtime.ModelRuntime]] = None,
         signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
         function_properties: Optional[Dict[str, Dict[str, Any]]] = None,
+        user_files: Optional[Dict[str, List[str]]] = None,
         metadata: Optional[Dict[str, str]] = None,
         creation_timestamp: Optional[str] = None,
         min_snowpark_ml_version: Optional[str] = None,
@@ -247,6 +249,7 @@
         if signatures:
             self.signatures = signatures
         self.function_properties = function_properties or {}
+        self.user_files = user_files
         self.metadata = metadata
         self.model_type = model_type
         self.env = env
@@ -59,7 +59,11 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
 
 
 class TensorflowModelBlobOptions(BaseModelBlobOptions):
-    is_keras_model: Required[bool]
+    save_format: Required[str]
+
+
+class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
+    batch_size: Required[int]
 
 
 ModelBlobOptions = Union[
@@ -68,6 +72,7 @@ ModelBlobOptions = Union[
     MLFlowModelBlobOptions,
     XgboostModelBlobOptions,
     TensorflowModelBlobOptions,
+    SentenceTransformersModelBlobOptions,
 ]
 
 
@@ -1,2 +1,2 @@
-REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
-ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.1.0,<2.4', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.10,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'huggingface_hub<0.26', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.12.0,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
@@ -149,8 +149,9 @@ def _get_model_task(model: Any) -> type_hints.Task:
     raise ValueError(f"Model type {type(model)} is not supported")
 
 
-def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
-    task = _get_model_task(model)
+def resolve_model_task_and_output_type(model: Any, passed_model_task: type_hints.Task) -> ModelTaskAndOutputType:
+    inferred_task = _get_model_task(model)
+    task = handlers_utils.validate_model_task(passed_model_task, inferred_task)
     output_type = model_signature.DataType.DOUBLE
     if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
         output_type = model_signature.DataType.STRING
@@ -224,6 +224,6 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
         df_col_dtypes = [df[col].dtype for col in df.columns]
         for df_col, df_col_dtype in zip(df_cols, df_col_dtypes):
             if df_col_dtype == np.dtype("O"):
-                if isinstance(df[df_col][0], np.ndarray):
+                if isinstance(df[df_col].iloc[0], np.ndarray):
                     df[df_col] = df[df_col].map(np.ndarray.tolist)
         return df
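Why the `.iloc[0]` switch above matters: `df[col][0]` is label-based and fails when the DataFrame's index does not contain the label `0`, whereas `.iloc[0]` always takes the first row positionally. A small self-contained illustration (the column name is made up):

```python
# Positional vs. label-based access on a DataFrame whose index does not start at 0.
import numpy as np
import pandas as pd

df = pd.DataFrame({"EMBEDDING": [np.array([1.0, 2.0]), np.array([3.0, 4.0])]}, index=[5, 6])
first = df["EMBEDDING"].iloc[0]          # works: positional access
assert isinstance(first, np.ndarray)
# df["EMBEDDING"][0] would raise a KeyError with this index.
```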
@@ -1,5 +1,5 @@
 import json
-from typing import Literal, Optional, Sequence, cast
+from typing import Any, Literal, Optional, Sequence, cast
 
 import numpy as np
 import pandas as pd
@@ -73,14 +73,20 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
             assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
             dtype_map[feature.name] = feature.as_dtype()
         df_local = data.to_pandas()
+
         # This is because Array will become string (Even though the correct schema is set)
         # and object will become variant type and requires an additional loads
         # to get correct data otherwise it would be string.
+        def load_if_not_null(x: str) -> Optional[Any]:
+            if x is None:
+                return None
+            return json.loads(x)
+
         for field in data.schema.fields:
             if isinstance(field.datatype, spt.ArrayType):
                 df_local[identifier.get_unescaped_names(field.name)] = df_local[
                     identifier.get_unescaped_names(field.name)
-                ].map(json.loads)
+                ].map(load_if_not_null)
         # Only when the feature is not from inference, we are confident to do the type casting.
         # Otherwise, dtype_map will be empty.
         # Errors are ignored to make sure None won't be converted and won't raise Error
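The `load_if_not_null` wrapper added above makes the array-column JSON decoding tolerant of SQL NULLs. A small self-contained illustration of the difference:

```python
# Null-safe JSON decoding over a column of JSON strings with a NULL entry.
import json
from typing import Any, Optional

import pandas as pd

def load_if_not_null(x: str) -> Optional[Any]:
    if x is None:
        return None
    return json.loads(x)

col = pd.Series(['[1, 2, 3]', None, '["a", "b"]'])
print(col.map(load_if_not_null).tolist())  # [[1, 2, 3], None, ['a', 'b']]
# col.map(json.loads) would raise TypeError on the None entry.
```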
@@ -199,6 +199,7 @@ class HuggingFaceSaveOptions(BaseModelSaveOption):
 class SentenceTransformersSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    batch_size: NotRequired[int]
 
 
 ModelSaveOption = Union[