snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/telemetry.py +4 -2
  7. snowflake/ml/_internal/type_utils.py +3 -3
  8. snowflake/ml/_internal/utils/import_utils.py +31 -0
  9. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
  10. snowflake/ml/data/__init__.py +5 -0
  11. snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
  12. snowflake/ml/data/data_connector.py +1 -1
  13. snowflake/ml/data/torch_utils.py +33 -14
  14. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
  15. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
  16. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
  18. snowflake/ml/feature_store/examples/example_helper.py +6 -3
  19. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
  20. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
  21. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
  22. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
  23. snowflake/ml/feature_store/feature_store.py +1 -2
  24. snowflake/ml/feature_store/feature_view.py +5 -1
  25. snowflake/ml/model/_client/model/model_version_impl.py +145 -11
  26. snowflake/ml/model/_client/ops/model_ops.py +56 -16
  27. snowflake/ml/model/_client/ops/service_ops.py +46 -30
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
  29. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  30. snowflake/ml/model/_client/sql/service.py +25 -1
  31. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  34. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
  36. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  37. snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
  38. snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
  39. snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
  41. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
  42. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
  43. snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
  44. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
  45. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  46. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  47. snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
  48. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  49. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  50. snowflake/ml/model/_packager/model_packager.py +0 -11
  51. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  52. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  53. snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
  54. snowflake/ml/model/_signatures/core.py +63 -16
  55. snowflake/ml/model/_signatures/pandas_handler.py +87 -27
  56. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  57. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  58. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  59. snowflake/ml/model/_signatures/utils.py +4 -0
  60. snowflake/ml/model/custom_model.py +47 -7
  61. snowflake/ml/model/model_signature.py +40 -9
  62. snowflake/ml/model/type_hints.py +9 -1
  63. snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
  64. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
  65. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
  66. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
  67. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
  68. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
  69. snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
  70. snowflake/ml/modeling/cluster/dbscan.py +5 -2
  71. snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
  72. snowflake/ml/modeling/cluster/k_means.py +14 -19
  73. snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
  74. snowflake/ml/modeling/cluster/optics.py +6 -6
  75. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
  76. snowflake/ml/modeling/compose/column_transformer.py +15 -5
  77. snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
  78. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  79. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  80. snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
  81. snowflake/ml/modeling/covariance/oas.py +1 -1
  82. snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
  83. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
  84. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
  85. snowflake/ml/modeling/decomposition/pca.py +28 -15
  86. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
  87. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
  88. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
  89. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
  90. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
  91. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
  92. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
  93. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
  94. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
  96. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  101. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
  102. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  103. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  104. snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
  105. snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
  106. snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
  107. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  108. snowflake/ml/modeling/linear_model/lars.py +0 -10
  109. snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
  110. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  111. snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
  112. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
  113. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
  114. snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
  115. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
  116. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  117. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
  119. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
  120. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  121. snowflake/ml/modeling/linear_model/perceptron.py +3 -3
  122. snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
  123. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
  124. snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
  125. snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
  126. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
  127. snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
  128. snowflake/ml/modeling/manifold/isomap.py +1 -1
  129. snowflake/ml/modeling/manifold/mds.py +3 -3
  130. snowflake/ml/modeling/manifold/tsne.py +10 -4
  131. snowflake/ml/modeling/metrics/classification.py +12 -16
  132. snowflake/ml/modeling/metrics/ranking.py +3 -3
  133. snowflake/ml/modeling/metrics/regression.py +3 -3
  134. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
  135. snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
  136. snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
  137. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
  138. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
  139. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
  140. snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
  141. snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
  142. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  143. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
  144. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  145. snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
  146. snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
  147. snowflake/ml/modeling/pipeline/pipeline.py +16 -14
  148. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
  149. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
  150. snowflake/ml/modeling/svm/linear_svc.py +25 -16
  151. snowflake/ml/modeling/svm/linear_svr.py +23 -17
  152. snowflake/ml/modeling/svm/nu_svc.py +5 -3
  153. snowflake/ml/modeling/svm/nu_svr.py +3 -1
  154. snowflake/ml/modeling/svm/svc.py +9 -5
  155. snowflake/ml/modeling/svm/svr.py +3 -1
  156. snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
  157. snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
  158. snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
  159. snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
  160. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
  161. snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
  162. snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
  163. snowflake/ml/monitoring/model_monitor.py +37 -0
  164. snowflake/ml/registry/_manager/model_manager.py +15 -1
  165. snowflake/ml/registry/registry.py +32 -37
  166. snowflake/ml/version.py +1 -1
  167. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
  168. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
  169. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
  170. snowflake/ml/monitoring/_client/model_monitor.py +0 -126
  171. snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
  172. snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
  173. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  174. /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
  175. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
  176. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -27,7 +27,7 @@ def get_model_method_options_from_options(
27
27
  options: type_hints.ModelSaveOption, target_method: str
28
28
  ) -> ModelMethodOptions:
29
29
  default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
30
- if options.get("enable_explainability", False) and target_method.startswith("explain"):
30
+ if target_method == "explain":
31
31
  default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
32
32
  method_option = options.get("method_options", {}).get(target_method, {})
33
33
  global_function_type = options.get("function_type", default_function_type)
@@ -174,6 +174,18 @@ class ModelEnv:
174
174
  except env_utils.DuplicateDependencyError:
175
175
  pass
176
176
 
177
+ def remove_if_present_conda(self, conda_pkgs: List[str]) -> None:
178
+ """Remove conda requirements from model env if present.
179
+
180
+ Args:
181
+ conda_pkgs: A list of package name to be removed from conda requirements.
182
+ """
183
+ for pkg_name in conda_pkgs:
184
+ spec_conda = env_utils._find_conda_dep_spec(self._conda_dependencies, pkg_name)
185
+ if spec_conda:
186
+ channel, spec = spec_conda
187
+ self._conda_dependencies[channel].remove(spec)
188
+
177
189
  def generate_env_for_cuda(self) -> None:
178
190
  if self.cuda_version is None:
179
191
  return
@@ -179,7 +179,7 @@ def convert_explanations_to_2D_df(
179
179
  return pd.DataFrame(explanations)
180
180
 
181
181
  if hasattr(model, "classes_"):
182
- classes_list = [str(cl) for cl in model.classes_] # type:ignore[union-attr]
182
+ classes_list = [str(cl) for cl in model.classes_]
183
183
  len_classes = len(classes_list)
184
184
  if explanations.shape[2] != len_classes:
185
185
  raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
@@ -191,7 +191,11 @@ def convert_explanations_to_2D_df(
191
191
  # convert to object or numpy creates strings of fixed length
192
192
  return np.asarray(json.dumps(dict(zip(classes_list, row)), cls=NumpyEncoder), dtype=object)
193
193
 
194
- exp_2d = np.apply_along_axis(row_to_dict, -1, explanations)
194
+ # convert to dict only for multiclass
195
+ if len(classes_list) > 2:
196
+ exp_2d = np.apply_along_axis(row_to_dict, -1, explanations)
197
+ else: # assumes index 1 is positive class always
198
+ exp_2d = np.apply_along_axis(lambda arr: arr[1], -1, explanations)
195
199
 
196
200
  return pd.DataFrame(exp_2d)
197
201
 
@@ -9,17 +9,14 @@ from typing_extensions import TypeGuard, Unpack
9
9
  from snowflake.ml._internal import type_utils
10
10
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
11
11
  from snowflake.ml.model._packager.model_env import model_env
12
- from snowflake.ml.model._packager.model_handlers import (
13
- _base,
14
- _utils as handlers_utils,
15
- model_objective_utils,
16
- )
12
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
17
13
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
18
14
  from snowflake.ml.model._packager.model_meta import (
19
15
  model_blob_meta,
20
16
  model_meta as model_meta_api,
21
17
  model_meta_schema,
22
18
  )
19
+ from snowflake.ml.model._packager.model_task import model_task_utils
23
20
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
24
21
 
25
22
  if TYPE_CHECKING:
@@ -97,8 +94,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
97
94
  sample_input_data=sample_input_data,
98
95
  get_prediction_fn=get_prediction,
99
96
  )
100
- model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
101
- model_meta.task = model_task_and_output.task
97
+ model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
98
+ model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
102
99
  if enable_explainability:
103
100
  explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
104
101
  model_meta = handlers_utils.add_explain_method_signature(
@@ -2,7 +2,7 @@ import inspect
2
2
  import os
3
3
  import pathlib
4
4
  import sys
5
- from typing import Dict, Optional, Type, final
5
+ from typing import Dict, Optional, Type, cast, final
6
6
 
7
7
  import anyio
8
8
  import cloudpickle
@@ -99,6 +99,8 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
99
99
  for sub_name, model_ref in model.context.model_refs.items():
100
100
  handler = model_handler.find_handler(model_ref.model)
101
101
  assert handler is not None
102
+ if handler is None:
103
+ raise TypeError("Your input type to custom model is not currently supported")
102
104
  sub_model = handler.cast_model(model_ref.model)
103
105
  handler.save_model(
104
106
  name=sub_name,
@@ -106,6 +108,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
106
108
  model_meta=model_meta,
107
109
  model_blobs_dir_path=model_blobs_dir_path,
108
110
  is_sub_model=True,
111
+ **cast(model_types.BaseModelSaveOption, kwargs),
109
112
  )
110
113
 
111
114
  # Make sure that the module where the model is defined get pickled by value as well.
@@ -173,6 +176,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
173
176
  name=sub_model_name,
174
177
  model_meta=model_meta,
175
178
  model_blobs_dir_path=model_blobs_dir_path,
179
+ **cast(model_types.BaseModelLoadOption, kwargs),
176
180
  )
177
181
  models[sub_model_name] = sub_model
178
182
  reconstructed_context = custom_model.ModelContext(artifacts=artifacts, models=models)
@@ -256,12 +256,20 @@ class HuggingFacePipelineHandler(
256
256
  @staticmethod
257
257
  def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> Dict[str, str]:
258
258
  device_config: Dict[str, Any] = {}
259
+ cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
260
+ gpu_nums = 0
261
+ if cuda_visible_devices is not None:
262
+ gpu_nums = len(cuda_visible_devices.split(","))
259
263
  if (
260
264
  kwargs.get("use_gpu", False)
261
265
  and kwargs.get("device_map", None) is None
262
266
  and kwargs.get("device", None) is None
263
267
  ):
264
- device_config["device_map"] = "auto"
268
+ if gpu_nums == 0 or gpu_nums > 1:
269
+ # Use accelerator if there are multiple GPUs or no GPU
270
+ device_config["device_map"] = "auto"
271
+ else:
272
+ device_config["device"] = "cuda"
265
273
  elif kwargs.get("device_map", None) is not None:
266
274
  device_config["device_map"] = kwargs["device_map"]
267
275
  elif kwargs.get("device", None) is not None:
@@ -310,6 +318,7 @@ class HuggingFacePipelineHandler(
310
318
  m = transformers.pipeline(
311
319
  model_blob_options["task"],
312
320
  model=model_blob_file_or_dir_path,
321
+ trust_remote_code=True,
313
322
  **device_config,
314
323
  )
315
324
 
@@ -20,17 +20,14 @@ from typing_extensions import TypeGuard, Unpack
20
20
  from snowflake.ml._internal import type_utils
21
21
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
22
22
  from snowflake.ml.model._packager.model_env import model_env
23
- from snowflake.ml.model._packager.model_handlers import (
24
- _base,
25
- _utils as handlers_utils,
26
- model_objective_utils,
27
- )
23
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
28
24
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
29
25
  from snowflake.ml.model._packager.model_meta import (
30
26
  model_blob_meta,
31
27
  model_meta as model_meta_api,
32
28
  model_meta_schema,
33
29
  )
30
+ from snowflake.ml.model._packager.model_task import model_task_utils
34
31
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
35
32
 
36
33
  if TYPE_CHECKING:
@@ -113,7 +110,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
113
110
  sample_input_data=sample_input_data,
114
111
  get_prediction_fn=get_prediction,
115
112
  )
116
- model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
113
+ model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
117
114
  model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
118
115
  if enable_explainability:
119
116
  explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
@@ -199,13 +196,14 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
199
196
  with open(model_blob_file_path, "rb") as f:
200
197
  model = cloudpickle.load(f)
201
198
  assert isinstance(model, getattr(lightgbm, lightgbm_estimator_type))
199
+ assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel)
202
200
 
203
201
  return model
204
202
 
205
203
  @classmethod
206
204
  def convert_as_custom_model(
207
205
  cls,
208
- raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
206
+ raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
209
207
  model_meta: model_meta_api.ModelMetadata,
210
208
  background_data: Optional[pd.DataFrame] = None,
211
209
  **kwargs: Unpack[model_types.LGBMModelLoadOptions],
@@ -1,3 +1,4 @@
1
+ import inspect
1
2
  import logging
2
3
  import os
3
4
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
@@ -155,8 +156,14 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
155
156
  model_blob_filename = model_blob_metadata.path
156
157
  model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)
157
158
 
159
+ additional_kwargs = {}
160
+ if "trust_remote_code" in inspect.signature(sentence_transformers.SentenceTransformer).parameters:
161
+ additional_kwargs["trust_remote_code"] = True
162
+
158
163
  model = sentence_transformers.SentenceTransformer(
159
- model_blob_file_or_dir_path, device=cls._get_device_config(**kwargs)
164
+ model_blob_file_or_dir_path,
165
+ device=cls._get_device_config(**kwargs),
166
+ **additional_kwargs,
160
167
  )
161
168
  return model
162
169
 
@@ -10,24 +10,35 @@ from typing_extensions import TypeGuard, Unpack
10
10
  from snowflake.ml._internal import type_utils
11
11
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
12
12
  from snowflake.ml.model._packager.model_env import model_env
13
- from snowflake.ml.model._packager.model_handlers import (
14
- _base,
15
- _utils as handlers_utils,
16
- model_objective_utils,
17
- )
13
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
18
14
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
19
15
  from snowflake.ml.model._packager.model_meta import (
20
16
  model_blob_meta,
21
17
  model_meta as model_meta_api,
22
18
  model_meta_schema,
23
19
  )
20
+ from snowflake.ml.model._packager.model_task import model_task_utils
24
21
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
22
+ from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
25
23
 
26
24
  if TYPE_CHECKING:
27
25
  import sklearn.base
28
26
  import sklearn.pipeline
29
27
 
30
28
 
29
+ def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "sklearn.pipeline.Pipeline":
30
+ new_steps = []
31
+ for step_name, step in model.steps:
32
+ new_reg = step
33
+ if hasattr(step, "_sklearn_estimator") and step._sklearn_estimator is not None:
34
+ # Unpack estimator to open source.
35
+ new_reg = step._sklearn_estimator
36
+ new_steps.append((step_name, new_reg))
37
+
38
+ model.steps = new_steps
39
+ return model
40
+
41
+
31
42
  @final
32
43
  class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
33
44
  """Handler for scikit-learn based model.
@@ -104,6 +115,10 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
104
115
  if sample_input_data is None:
105
116
  raise ValueError("Sample input data is required to enable explainability.")
106
117
 
118
+ # If this is a pipeline and we are in the container runtime, check for distributed estimator.
119
+ if os.getenv(IN_ML_RUNTIME_ENV_VAR) and isinstance(model, sklearn.pipeline.Pipeline):
120
+ model = _unpack_container_runtime_pipeline(model)
121
+
107
122
  if not is_sub_model:
108
123
  target_methods = handlers_utils.get_target_methods(
109
124
  model=model,
@@ -137,8 +152,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
137
152
  sample_input_data, model_meta, explain_target_method
138
153
  )
139
154
 
140
- model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(model)
141
- model_meta.task = model_task_and_output_type.task
155
+ model_task_and_output_type = model_task_utils.get_model_task_and_output_type(model)
156
+ model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
142
157
 
143
158
  # if users did not ask then we enable if we have background data
144
159
  if enable_explainability is None:
@@ -180,6 +195,35 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
180
195
  model_meta.models[name] = base_meta
181
196
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
182
197
 
198
+ # if model instance is a pipeline, check the pipeline steps
199
+ if isinstance(model, sklearn.pipeline.Pipeline):
200
+ for _, pipeline_step in model.steps:
201
+ if type_utils.LazyType("lightgbm.LGBMModel").isinstance(pipeline_step) or type_utils.LazyType(
202
+ "lightgbm.Booster"
203
+ ).isinstance(pipeline_step):
204
+ model_meta.env.include_if_absent(
205
+ [
206
+ model_env.ModelDependency(requirement="lightgbm", pip_name="lightgbm"),
207
+ ],
208
+ check_local_version=True,
209
+ )
210
+ elif type_utils.LazyType("xgboost.XGBModel").isinstance(pipeline_step) or type_utils.LazyType(
211
+ "xgboost.Booster"
212
+ ).isinstance(pipeline_step):
213
+ model_meta.env.include_if_absent(
214
+ [
215
+ model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
216
+ ],
217
+ check_local_version=True,
218
+ )
219
+ elif type_utils.LazyType("catboost.CatBoost").isinstance(pipeline_step):
220
+ model_meta.env.include_if_absent(
221
+ [
222
+ model_env.ModelDependency(requirement="catboost", pip_name="catboost"),
223
+ ],
224
+ check_local_version=True,
225
+ )
226
+
183
227
  if enable_explainability:
184
228
  model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
185
229
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
@@ -5,24 +5,20 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, fin
5
5
  import cloudpickle
6
6
  import numpy as np
7
7
  import pandas as pd
8
- from packaging import version
9
8
  from typing_extensions import TypeGuard, Unpack
10
9
 
11
10
  from snowflake.ml._internal import type_utils
12
11
  from snowflake.ml._internal.exceptions import exceptions
13
12
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
14
13
  from snowflake.ml.model._packager.model_env import model_env
15
- from snowflake.ml.model._packager.model_handlers import (
16
- _base,
17
- _utils as handlers_utils,
18
- model_objective_utils,
19
- )
14
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
20
15
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
21
16
  from snowflake.ml.model._packager.model_meta import (
22
17
  model_blob_meta,
23
18
  model_meta as model_meta_api,
24
19
  model_meta_schema,
25
20
  )
21
+ from snowflake.ml.model._packager.model_task import model_task_utils
26
22
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
27
23
 
28
24
  if TYPE_CHECKING:
@@ -72,41 +68,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
72
68
  return cast("BaseEstimator", model)
73
69
 
74
70
  @classmethod
75
- def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]:
76
- from importlib import metadata as importlib_metadata
77
-
78
- from packaging import version
79
-
80
- local_version = None
81
-
82
- try:
83
- local_dist = importlib_metadata.distribution(pkg_name)
84
- local_version = version.parse(local_dist.version)
85
- except importlib_metadata.PackageNotFoundError:
86
- pass
87
-
88
- return local_version
89
-
90
- @classmethod
91
- def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool:
92
-
93
- local_xgb_version = cls._get_local_version_package("xgboost")
94
-
95
- if local_xgb_version and local_xgb_version >= version.parse("2.1.0"):
96
- if enable_explainability:
97
- warnings.warn(
98
- f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
99
- + "If you want model explanations, lower the xgboost version to <2.1.0.",
100
- category=UserWarning,
101
- stacklevel=1,
102
- )
103
- return False
104
- return True
105
-
106
- @classmethod
107
- def _get_supported_object_for_explainability(
108
- cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
109
- ) -> Any:
71
+ def _get_supported_object_for_explainability(cls, estimator: "BaseEstimator") -> Any:
110
72
  from snowflake.ml.modeling import pipeline as snowml_pipeline
111
73
 
112
74
  # handle pipeline objects separately
@@ -118,8 +80,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
118
80
  if hasattr(estimator, method_name):
119
81
  try:
120
82
  result = getattr(estimator, method_name)()
121
- if method_name == "to_xgboost" and not cls._can_support_xgb(enable_explainability):
122
- return None
123
83
  return result
124
84
  except exceptions.SnowflakeMLException:
125
85
  pass # Do nothing and continue to the next method
@@ -168,7 +128,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
168
128
  model_meta.signatures = temp_model_signature_dict
169
129
 
170
130
  if enable_explainability or enable_explainability is None:
171
- python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability)
131
+ python_base_obj = cls._get_supported_object_for_explainability(model)
172
132
  if python_base_obj is None:
173
133
  if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
174
134
  raise ValueError(
@@ -177,8 +137,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
177
137
  # set None to False so we don't include shap in the environment
178
138
  enable_explainability = False
179
139
  else:
180
- model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(python_base_obj)
181
- model_meta.task = model_task_and_output_type.task
140
+ model_task_and_output_type = model_task_utils.get_model_task_and_output_type(python_base_obj)
141
+ model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
182
142
  explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
183
143
  model_meta = handlers_utils.add_explain_method_signature(
184
144
  model_meta=model_meta,
@@ -213,28 +173,10 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
213
173
  model_dependencies = model._get_dependencies()
214
174
  for dep in model_dependencies:
215
175
  pkg_name = dep.split("==")[0]
216
- if pkg_name != "xgboost":
217
- _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
218
- continue
219
-
220
- local_xgb_version = cls._get_local_version_package("xgboost")
221
- if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
222
- model_meta.env.include_if_absent(
223
- [
224
- model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
225
- ],
226
- check_local_version=False,
227
- )
228
- else:
229
- model_meta.env.include_if_absent(
230
- [
231
- model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
232
- ],
233
- check_local_version=True,
234
- )
176
+ _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
235
177
 
236
178
  if enable_explainability:
237
- model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
179
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
238
180
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
239
181
  model_meta.env.include_if_absent(_include_if_absent_pkgs, check_local_version=True)
240
182
 
@@ -13,6 +13,7 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
13
13
  from snowflake.ml.model._packager.model_meta import (
14
14
  model_blob_meta,
15
15
  model_meta as model_meta_api,
16
+ model_meta_schema,
16
17
  )
17
18
  from snowflake.ml.model._signatures import (
18
19
  numpy_handler,
@@ -76,7 +77,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
76
77
 
77
78
  assert isinstance(model, tensorflow.Module)
78
79
 
79
- if isinstance(model, tensorflow.keras.Model):
80
+ is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
81
+ "tf_keras.Model"
82
+ ).isinstance(model)
83
+
84
+ if is_keras_model:
80
85
  default_target_methods = ["predict"]
81
86
  else:
82
87
  default_target_methods = cls.DEFAULT_TARGET_METHODS
@@ -117,8 +122,14 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
117
122
 
118
123
  model_blob_path = os.path.join(model_blobs_dir_path, name)
119
124
  os.makedirs(model_blob_path, exist_ok=True)
120
- if isinstance(model, tensorflow.keras.Model):
125
+ if is_keras_model:
121
126
  tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
127
+ model_meta.env.include_if_absent(
128
+ [
129
+ model_env.ModelDependency(requirement="keras<3", pip_name="keras"),
130
+ ],
131
+ check_local_version=False,
132
+ )
122
133
  else:
123
134
  tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
124
135
 
@@ -127,12 +138,16 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
127
138
  model_type=cls.HANDLER_TYPE,
128
139
  handler_version=cls.HANDLER_VERSION,
129
140
  path=cls.MODEL_BLOB_FILE_OR_DIR,
141
+ options=model_meta_schema.TensorflowModelBlobOptions(is_keras_model=is_keras_model),
130
142
  )
131
143
  model_meta.models[name] = base_meta
132
144
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
133
145
 
134
146
  model_meta.env.include_if_absent(
135
- [model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow")], check_local_version=True
147
+ [
148
+ model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
149
+ ],
150
+ check_local_version=True,
136
151
  )
137
152
  model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
138
153
 
@@ -150,9 +165,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
150
165
  model_blobs_metadata = model_meta.models
151
166
  model_blob_metadata = model_blobs_metadata[name]
152
167
  model_blob_filename = model_blob_metadata.path
153
- m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False)
154
- if isinstance(m, tensorflow.keras.Model):
155
- return m
168
+ model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
169
+ if model_blob_options.get("is_keras_model", False):
170
+ m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False)
171
+ else:
172
+ m = tensorflow.saved_model.load(os.path.join(model_blob_path, model_blob_filename))
156
173
  return cast(tensorflow.Module, m)
157
174
 
158
175
  @classmethod
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
23
23
 
24
24
 
25
25
  @final
26
- class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # type:ignore[name-defined]
26
+ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
27
27
  """Handler for PyTorch JIT based model.
28
28
 
29
29
  Currently torch.jit.ScriptModule based classes are supported.
@@ -41,25 +41,25 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
41
41
  def can_handle(
42
42
  cls,
43
43
  model: model_types.SupportedModelType,
44
- ) -> TypeGuard["torch.jit.ScriptModule"]: # type:ignore[name-defined]
44
+ ) -> TypeGuard["torch.jit.ScriptModule"]:
45
45
  return type_utils.LazyType("torch.jit.ScriptModule").isinstance(model)
46
46
 
47
47
  @classmethod
48
48
  def cast_model(
49
49
  cls,
50
50
  model: model_types.SupportedModelType,
51
- ) -> "torch.jit.ScriptModule": # type:ignore[name-defined]
51
+ ) -> "torch.jit.ScriptModule":
52
52
  import torch
53
53
 
54
- assert isinstance(model, torch.jit.ScriptModule) # type:ignore[attr-defined]
54
+ assert isinstance(model, torch.jit.ScriptModule)
55
55
 
56
- return cast(torch.jit.ScriptModule, model) # type:ignore[name-defined]
56
+ return cast(torch.jit.ScriptModule, model)
57
57
 
58
58
  @classmethod
59
59
  def save_model(
60
60
  cls,
61
61
  name: str,
62
- model: "torch.jit.ScriptModule", # type:ignore[name-defined]
62
+ model: "torch.jit.ScriptModule",
63
63
  model_meta: model_meta_api.ModelMetadata,
64
64
  model_blobs_dir_path: str,
65
65
  sample_input_data: Optional[model_types.SupportedDataType] = None,
@@ -72,7 +72,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
72
72
 
73
73
  import torch
74
74
 
75
- assert isinstance(model, torch.jit.ScriptModule) # type:ignore[attr-defined]
75
+ assert isinstance(model, torch.jit.ScriptModule)
76
76
 
77
77
  if not is_sub_model:
78
78
  target_methods = handlers_utils.get_target_methods(
@@ -111,7 +111,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
111
111
  model_blob_path = os.path.join(model_blobs_dir_path, name)
112
112
  os.makedirs(model_blob_path, exist_ok=True)
113
113
  with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
114
- torch.jit.save(model, f) # type:ignore[no-untyped-call, attr-defined]
114
+ torch.jit.save(model, f) # type:ignore[no-untyped-call]
115
115
  base_meta = model_blob_meta.ModelBlobMeta(
116
116
  name=name,
117
117
  model_type=cls.HANDLER_TYPE,
@@ -133,7 +133,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
133
133
  model_meta: model_meta_api.ModelMetadata,
134
134
  model_blobs_dir_path: str,
135
135
  **kwargs: Unpack[model_types.TorchScriptLoadOptions],
136
- ) -> "torch.jit.ScriptModule": # type:ignore[name-defined]
136
+ ) -> "torch.jit.ScriptModule":
137
137
  import torch
138
138
 
139
139
  model_blob_path = os.path.join(model_blobs_dir_path, name)
@@ -141,10 +141,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
141
141
  model_blob_metadata = model_blobs_metadata[name]
142
142
  model_blob_filename = model_blob_metadata.path
143
143
  with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
144
- m = torch.jit.load( # type:ignore[no-untyped-call, attr-defined]
144
+ m = torch.jit.load( # type:ignore[no-untyped-call]
145
145
  f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu"
146
146
  )
147
- assert isinstance(m, torch.jit.ScriptModule) # type:ignore[attr-defined]
147
+ assert isinstance(m, torch.jit.ScriptModule)
148
148
 
149
149
  if kwargs.get("use_gpu", False):
150
150
  m = m.cuda()
@@ -154,7 +154,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
154
154
  @classmethod
155
155
  def convert_as_custom_model(
156
156
  cls,
157
- raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined]
157
+ raw_model: "torch.jit.ScriptModule",
158
158
  model_meta: model_meta_api.ModelMetadata,
159
159
  background_data: Optional[pd.DataFrame] = None,
160
160
  **kwargs: Unpack[model_types.TorchScriptLoadOptions],
@@ -162,11 +162,11 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
162
162
  from snowflake.ml.model import custom_model
163
163
 
164
164
  def _create_custom_model(
165
- raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined]
165
+ raw_model: "torch.jit.ScriptModule",
166
166
  model_meta: model_meta_api.ModelMetadata,
167
167
  ) -> Type[custom_model.CustomModel]:
168
168
  def fn_factory(
169
- raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined]
169
+ raw_model: "torch.jit.ScriptModule",
170
170
  signature: model_signature.ModelSignature,
171
171
  target_method: str,
172
172
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]: