snowflake_ml_python-1.5.4-py3-none-any.whl → snowflake_ml_python-1.6.0-py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
Files changed (65)
  1. snowflake/cortex/__init__.py +2 -0
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +67 -10
  4. snowflake/cortex/_util.py +4 -4
  5. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  6. snowflake/ml/_internal/telemetry.py +12 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
  8. snowflake/ml/data/_internal/ingestor_utils.py +58 -0
  9. snowflake/ml/data/data_connector.py +133 -0
  10. snowflake/ml/data/data_ingestor.py +28 -0
  11. snowflake/ml/data/data_source.py +23 -0
  12. snowflake/ml/dataset/dataset.py +1 -13
  13. snowflake/ml/dataset/dataset_reader.py +18 -118
  14. snowflake/ml/feature_store/access_manager.py +7 -1
  15. snowflake/ml/feature_store/entity.py +19 -2
  16. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
  18. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
  19. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
  20. snowflake/ml/feature_store/examples/example_helper.py +240 -0
  21. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  22. snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
  23. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
  24. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
  25. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  26. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  27. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  28. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  29. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  30. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
  31. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
  32. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
  33. snowflake/ml/feature_store/feature_store.py +579 -53
  34. snowflake/ml/feature_store/feature_view.py +168 -5
  35. snowflake/ml/fileset/stage_fs.py +18 -10
  36. snowflake/ml/lineage/lineage_node.py +1 -1
  37. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  38. snowflake/ml/model/_model_composer/model_composer.py +11 -14
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +24 -16
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  41. snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
  42. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  43. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
  44. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
  45. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
  46. snowflake/ml/model/_packager/model_handlers/_base.py +11 -1
  47. snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
  48. snowflake/ml/model/_packager/model_handlers/catboost.py +42 -0
  49. snowflake/ml/model/_packager/model_handlers/lightgbm.py +68 -0
  50. snowflake/ml/model/_packager/model_handlers/xgboost.py +59 -0
  51. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
  52. snowflake/ml/model/model_signature.py +4 -4
  53. snowflake/ml/model/type_hints.py +4 -0
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  56. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  57. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  58. snowflake/ml/registry/registry.py +100 -13
  59. snowflake/ml/version.py +1 -1
  60. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +48 -2
  61. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +64 -42
  62. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
  63. snowflake/ml/_internal/lineage/data_source.py +0 -10
  64. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
--- a/snowflake/ml/model/_packager/model_handlers/_base.py
+++ b/snowflake/ml/model/_packager/model_handlers/_base.py
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from enum import Enum
 from typing import Dict, Generic, Optional, Protocol, Type, final
 
 from typing_extensions import TypeGuard, Unpack
@@ -8,6 +9,15 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import model_meta
 
 
+class ModelObjective(Enum):
+    # This is not getting stored anywhere as metadata yet so it should be fine to slowly extend it for better coverage
+    UNKNOWN = "unknown"
+    BINARY_CLASSIFICATION = "binary_classification"
+    MULTI_CLASSIFICATION = "multi_classification"
+    REGRESSION = "regression"
+    RANKING = "ranking"
+
+
 class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
     HANDLER_TYPE: model_types.SupportedModelHandlerType
     HANDLER_VERSION: str
@@ -16,7 +26,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
 
     @classmethod
     @abstractmethod
-    def can_handle(cls, model: model_types.SupportedDataType) -> TypeGuard[model_types._ModelType]:
+    def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard[model_types._ModelType]:
        """Whether this handler could support the type of the `model`.
 
        Args:
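
The new ModelObjective enum gives every handler a shared vocabulary for what a model was trained to do. As a rough standalone sketch (not part of the package), this is the dispatch idea the concrete handlers below implement, using the scikit-learn-style `classes_` attribute to separate binary from multi-class:

    from enum import Enum

    class ModelObjective(Enum):
        UNKNOWN = "unknown"
        BINARY_CLASSIFICATION = "binary_classification"
        MULTI_CLASSIFICATION = "multi_classification"
        REGRESSION = "regression"
        RANKING = "ranking"

    def infer_objective(model: object) -> ModelObjective:
        # Fitted classifiers expose `classes_`; exactly two classes means binary.
        classes = getattr(model, "classes_", None)
        if classes is None:
            return ModelObjective.UNKNOWN
        if len(classes) == 2:
            return ModelObjective.BINARY_CLASSIFICATION
        return ModelObjective.MULTI_CLASSIFICATION
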
--- a/snowflake/ml/model/_packager/model_handlers/_utils.py
+++ b/snowflake/ml/model/_packager/model_handlers/_utils.py
@@ -1,4 +1,9 @@
-from typing import Callable, Iterable, Optional, Sequence, cast
+import json
+from typing import Any, Callable, Iterable, Optional, Sequence, cast
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
 
 from snowflake.ml.model import model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_meta import model_meta
@@ -40,6 +45,24 @@ def validate_signature(
     return model_meta
 
 
+def add_explain_method_signature(
+    model_meta: model_meta.ModelMetadata,
+    explain_method: str,
+    target_method: str,
+    output_return_type: model_signature.DataType = model_signature.DataType.DOUBLE,
+) -> model_meta.ModelMetadata:
+    if target_method not in model_meta.signatures:
+        raise ValueError(f"Signature for target method {target_method} is missing")
+    inputs = model_meta.signatures[target_method].inputs
+    model_meta.signatures[explain_method] = model_signature.ModelSignature(
+        inputs=inputs,
+        outputs=[
+            model_signature.FeatureSpec(dtype=output_return_type, name=f"{spec.name}_explanation") for spec in inputs
+        ],
+    )
+    return model_meta
+
+
 def get_target_methods(
     model: model_types.SupportedModelType,
     target_methods: Optional[Sequence[str]],
@@ -56,3 +79,37 @@ def validate_target_methods(model: model_types.SupportedModelType, target_method
     for method_name in target_methods:
         if not _is_callable(model, method_name):
             raise ValueError(f"Target method {method_name} is not callable or does not exist in the model.")
+
+
+def get_num_classes_if_exists(model: model_types.SupportedModelType) -> int:
+    num_classes = getattr(model, "classes_", [])
+    return len(num_classes)
+
+
+def convert_explanations_to_2D_df(
+    model: model_types.SupportedModelType, explanations: npt.NDArray[Any]
+) -> pd.DataFrame:
+    if explanations.ndim != 3:
+        return pd.DataFrame(explanations)
+
+    if hasattr(model, "classes_"):
+        classes_list = [cl for cl in model.classes_]  # type:ignore[union-attr]
+        len_classes = len(classes_list)
+        if explanations.shape[2] != len_classes:
+            raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
+    else:
+        classes_list = [i for i in range(explanations.shape[2])]
+    exp_2d = []
+    # TODO (SNOW-1549044): Optimize this
+    for row in explanations:
+        col_list = []
+        for column in row:
+            class_explanations = {}
+            for cl, cl_exp in zip(classes_list, column):
+                if isinstance(cl, (int, np.integer)):
+                    cl = int(cl)
+                class_explanations[cl] = cl_exp
+            col_list.append(json.dumps(class_explanations))
+        exp_2d.append(col_list)
+
+    return pd.DataFrame(exp_2d)
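
convert_explanations_to_2D_df collapses a three-dimensional (rows × features × classes) SHAP array into a 2-D frame with one JSON object per cell, which is why multi-class explanations are surfaced as STRING columns. A minimal standalone illustration of that packing (class labels and values invented):

    import json
    import numpy as np
    import pandas as pd

    # Toy multi-class explanations: 2 rows, 2 features, 3 classes.
    explanations = np.arange(12, dtype=float).reshape(2, 2, 3)
    classes = ["a", "b", "c"]

    rows = []
    for row in explanations:  # one entry per input row
        cols = []
        for feature_contribs in row:  # per-class contributions for one feature
            cols.append(json.dumps(dict(zip(classes, feature_contribs.tolist()))))
        rows.append(cols)

    df = pd.DataFrame(rows)
    print(df.iloc[0, 0])  # {"a": 0.0, "b": 1.0, "c": 2.0}
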
--- a/snowflake/ml/model/_packager/model_handlers/catboost.py
+++ b/snowflake/ml/model/_packager/model_handlers/catboost.py
@@ -33,6 +33,22 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
     MODELE_BLOB_FILE_OR_DIR = "model.bin"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
 
+    @classmethod
+    def get_model_objective(cls, model: "catboost.CatBoost") -> _base.ModelObjective:
+        import catboost
+
+        if isinstance(model, catboost.CatBoostClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, catboost.CatBoostRanker):
+            return _base.ModelObjective.RANKING
+        if isinstance(model, catboost.CatBoostRegressor):
+            return _base.ModelObjective.REGRESSION
+        # TODO: Find out model type from the generic Catboost Model
+        return _base.ModelObjective.UNKNOWN
+
     @classmethod
     def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
         return (type_utils.LazyType("catboost.CatBoost").isinstance(model)) and any(
@@ -89,6 +105,16 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
                 sample_input_data=sample_input_data,
                 get_prediction_fn=get_prediction,
             )
+        if kwargs.get("enable_explainability", False):
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -112,6 +138,11 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", False):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
         return None
@@ -186,6 +217,17 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
 
             return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+        @custom_model.inference_api
+        def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+            import shap
+
+            explainer = shap.TreeExplainer(raw_model)
+            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+        if target_method == "explain":
+            return explain_fn
+
         return fn
 
     type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
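
Together these changes mean a CatBoost model saved with the new opt-in enable_explainability flag gains an extra `explain` target method backed by shap.TreeExplainer, and shap is added to the model environment. A hypothetical registry round-trip sketch (the session, model, and names are placeholders, and the exact option plumbing may differ):

    from typing import Any
    import pandas as pd
    from snowflake.ml.registry import Registry

    def log_with_explanations(session: Any, clf: Any, sample_df: pd.DataFrame) -> pd.DataFrame:
        """Log a fitted CatBoost classifier and run its new `explain` method."""
        reg = Registry(session=session)  # assumes an existing Snowpark session
        mv = reg.log_model(
            clf,
            model_name="MY_CATBOOST_MODEL",  # placeholder name
            version_name="v1",
            sample_input_data=sample_df,
            options={"enable_explainability": True},  # opt in to the shap-backed "explain" method
        )
        # Each input feature gains a corresponding "<feature>_explanation" output column.
        return mv.run(sample_df, function_name="explain")
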
--- a/snowflake/ml/model/_packager/model_handlers/lightgbm.py
+++ b/snowflake/ml/model/_packager/model_handlers/lightgbm.py
@@ -43,6 +43,45 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
 
     MODELE_BLOB_FILE_OR_DIR = "model.pkl"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
+    _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
+    _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
+    _RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
+    _REGRESSION_OBJECTIVES = [
+        "regression",
+        "regression_l1",
+        "huber",
+        "fair",
+        "poisson",
+        "quantile",
+        "tweedie",
+        "mape",
+        "gamma",
+    ]
+
+    @classmethod
+    def get_model_objective(cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> _base.ModelObjective:
+        import lightgbm
+
+        # does not account for cross-entropy and custom
+        if isinstance(model, lightgbm.LGBMClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, lightgbm.LGBMRanker):
+            return _base.ModelObjective.RANKING
+        if isinstance(model, lightgbm.LGBMRegressor):
+            return _base.ModelObjective.REGRESSION
+        model_objective = model.params["objective"]
+        if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
+            return _base.ModelObjective.BINARY_CLASSIFICATION
+        if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if model_objective in cls._RANKING_OBJECTIVES:
+            return _base.ModelObjective.RANKING
+        if model_objective in cls._REGRESSION_OBJECTIVES:
+            return _base.ModelObjective.REGRESSION
+        return _base.ModelObjective.UNKNOWN
 
     @classmethod
     def can_handle(
@@ -105,6 +144,19 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
                 sample_input_data=sample_input_data,
                 get_prediction_fn=get_prediction,
             )
+        if kwargs.get("enable_explainability", False):
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) in [
+                _base.ModelObjective.BINARY_CLASSIFICATION,
+                _base.ModelObjective.MULTI_CLASSIFICATION,
+            ]:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -130,6 +182,11 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", False):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
 
         return None
 
@@ -198,6 +255,17 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
 
             return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+        @custom_model.inference_api
+        def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+            import shap
+
+            explainer = shap.TreeExplainer(raw_model)
+            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+        if target_method == "explain":
+            return explain_fn
+
         return fn
 
     type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
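
For a raw lightgbm.Booster there is no sklearn-style wrapper class to dispatch on, so the handler falls back to the objective string recorded in Booster.params. A small sketch of that fallback, assuming lightgbm is installed:

    import lightgbm as lgb
    import numpy as np

    # Toy binary-classification booster on random data.
    X = np.random.rand(100, 4)
    y = np.random.randint(0, 2, 100)
    booster = lgb.train({"objective": "binary", "verbose": -1}, lgb.Dataset(X, y))

    objective = booster.params["objective"]  # "binary"
    assert objective in ["binary"]           # maps to BINARY_CLASSIFICATION in the handler
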
--- a/snowflake/ml/model/_packager/model_handlers/xgboost.py
+++ b/snowflake/ml/model/_packager/model_handlers/xgboost.py
@@ -1,4 +1,5 @@
 # mypy: disable-error-code="import"
+import json
 import os
 from typing import (
     TYPE_CHECKING,
@@ -46,6 +47,39 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
 
     MODELE_BLOB_FILE_OR_DIR = "model.ubj"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
+    _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
+    _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
+    _RANKING_OBJECTIVE_PREFIX = ["rank:"]
+    _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
+
+    @classmethod
+    def get_model_objective(cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> _base.ModelObjective:
+        import xgboost
+
+        if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
+            return _base.ModelObjective.REGRESSION
+        if isinstance(model, xgboost.XGBRanker):
+            return _base.ModelObjective.RANKING
+        model_params = json.loads(model.save_config())
+        model_objective = model_params["learner"]["objective"]
+        for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
+            if classification_objective in model_objective:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+        for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
+            if classification_objective in model_objective:
+                return _base.ModelObjective.MULTI_CLASSIFICATION
+        for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
+            if ranking_objective in model_objective:
+                return _base.ModelObjective.RANKING
+        for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
+            if regression_objective in model_objective:
+                return _base.ModelObjective.REGRESSION
+        return _base.ModelObjective.UNKNOWN
 
     @classmethod
     def can_handle(
@@ -112,6 +146,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
                 sample_input_data=sample_input_data,
                 get_prediction_fn=get_prediction,
             )
+        if kwargs.get("enable_explainability", False):
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -133,6 +177,11 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", False):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
     @classmethod
@@ -206,6 +255,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
 
             return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+        @custom_model.inference_api
+        def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+            import shap
+
+            explainer = shap.TreeExplainer(raw_model)
+            df = pd.DataFrame(explainer(X).values)
+            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+        if target_method == "explain":
+            return explain_fn
         return fn
 
     type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
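
The XGBoost fallback parses the JSON returned by Booster.save_config() and prefix-matches the objective name ("binary:", "multi:", "rank:", "reg:"). A standalone sketch of the idea, assuming xgboost is installed; note that in recent XGBoost releases the objective name sits under learner.objective.name, so the exact lookup path used here is an assumption:

    import json
    import numpy as np
    import xgboost as xgb

    # Toy regression booster on random data.
    X = np.random.rand(100, 4)
    y = np.random.rand(100)
    booster = xgb.train({"objective": "reg:squarederror"}, xgb.DMatrix(X, label=y), num_boost_round=2)

    config = json.loads(booster.save_config())
    objective_name = config["learner"]["objective"]["name"]  # e.g. "reg:squarederror"
    prefix = objective_name.split(":")[0] + ":"              # "reg:" -> REGRESSION in the handler
    print(prefix)
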
--- a/snowflake/ml/model/_packager/model_runtime/model_runtime.py
+++ b/snowflake/ml/model/_packager/model_runtime/model_runtime.py
@@ -35,7 +35,7 @@ class ModelRuntime:
         self,
         name: str,
         env: model_env.ModelEnv,
-        imports: Optional[List[pathlib.PurePosixPath]] = None,
+        imports: Optional[List[str]] = None,
         is_gpu: bool = False,
         loading_from_file: bool = False,
     ) -> None:
@@ -75,7 +75,7 @@ class ModelRuntime:
             snowpark_ml_lib_path = runtime_base_path / "snowflake-ml-python.zip"
             file_utils.zip_python_package(str(snowpark_ml_lib_path), "snowflake.ml")
             snowpark_ml_lib_rel_path = pathlib.PurePosixPath(snowpark_ml_lib_path.relative_to(packager_path).as_posix())
-            self.imports.append(snowpark_ml_lib_rel_path)
+            self.imports.append(str(snowpark_ml_lib_rel_path))
 
         self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
         self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path
@@ -108,6 +108,4 @@ class ModelRuntime:
             warnings.simplefilter("ignore")
             env.load_from_conda_file(packager_path / conda_env_rel_path)
             env.load_from_pip_file(packager_path / pip_requirements_rel_path)
-        return ModelRuntime(
-            name=name, env=env, imports=list(map(pathlib.PurePosixPath, loaded_dict["imports"])), loading_from_file=True
-        )
+        return ModelRuntime(name=name, env=env, imports=loaded_dict["imports"], loading_from_file=True)
--- a/snowflake/ml/model/model_signature.py
+++ b/snowflake/ml/model/model_signature.py
@@ -232,7 +232,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
         else:
-            if isinstance(data_col[0], list):
+            if isinstance(data_col.iloc[0], list):
                 if not ft_shape:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -266,7 +266,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                     ),
                 )
 
-            elif isinstance(data_col[0], np.ndarray):
+            elif isinstance(data_col.iloc[0], np.ndarray):
                 if not ft_shape:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -297,7 +297,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                     ),
                 )
 
-            elif isinstance(data_col[0], str):
+            elif isinstance(data_col.iloc[0], str):
                 if ft_shape is not None:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -316,7 +316,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                     ),
                 )
 
-            elif isinstance(data_col[0], bytes):
+            elif isinstance(data_col.iloc[0], bytes):
                 if ft_shape is not None:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
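
The data_col[0] → data_col.iloc[0] fixes matter because [0] on a Series is label-based: it raises KeyError once the DataFrame carries a non-default index (for example after filtering), while .iloc[0] is purely positional. A quick standalone illustration:

    import pandas as pd

    df = pd.DataFrame({"a": [[1, 2], [3, 4]]})
    subset = df[df.index > 0]       # index is now [1]; label 0 no longer exists

    print(subset["a"].iloc[0])      # [3, 4] -- positional access always works
    try:
        subset["a"][0]              # label-based lookup
    except KeyError:
        print("KeyError: label 0 not in index")
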
--- a/snowflake/ml/model/type_hints.py
+++ b/snowflake/ml/model/type_hints.py
@@ -232,11 +232,13 @@ class BaseModelSaveOption(TypedDict):
     _legacy_save: NotRequired[bool]
     function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
     method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
+    include_pip_dependencies: NotRequired[bool]
 
 
 class CatBoostModelSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    enable_explainability: NotRequired[bool]
 
 
 class CustomModelSaveOption(BaseModelSaveOption):
@@ -250,10 +252,12 @@ class SKLModelSaveOptions(BaseModelSaveOption):
 class XGBModelSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    enable_explainability: NotRequired[bool]
 
 
 class LGBMModelSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
+    enable_explainability: NotRequired[bool]
 
 
 class SNOWModelSaveOptions(BaseModelSaveOption):
--- a/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py
+++ b/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py
@@ -41,7 +41,7 @@ cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snow
 
 _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
-ENABLE_EFFICIENT_MEMORY_USAGE = False
+ENABLE_EFFICIENT_MEMORY_USAGE = True
 _UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"
 
 
--- a/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py
+++ b/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py
@@ -83,7 +83,19 @@ def _load_data_into_udf() -> Tuple[
     with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
         fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
 
-    # convert dataframe to numpy would save memory consumption
+    # Convert dataframe to numpy would save memory consumption
+    # Except for Pipeline, we need to keep the dataframe for the column names
+    from sklearn.pipeline import Pipeline
+    if isinstance(base_estimator, Pipeline):
+        return (
+            df[CONSTANTS['input_cols']],
+            df[CONSTANTS['label_cols']].squeeze(),
+            indices,
+            params_to_evaluate,
+            base_estimator,
+            fit_and_score_kwargs,
+            CONSTANTS
+        )
     return (
         df[CONSTANTS['input_cols']].to_numpy(),
         df[CONSTANTS['label_cols']].squeeze().to_numpy(),
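
The Pipeline special case exists because .to_numpy() discards column names, and a pipeline step such as ColumnTransformer may select columns by name. A minimal standalone illustration of why the DataFrame must survive for pipelines:

    import pandas as pd
    from sklearn.compose import ColumnTransformer
    from sklearn.preprocessing import StandardScaler

    df = pd.DataFrame({"age": [20, 30], "income": [1.0, 2.0]})
    ct = ColumnTransformer([("scale", StandardScaler(), ["age"])])

    ct.fit_transform(df)                 # fine: selects the "age" column by name
    try:
        ct.fit_transform(df.to_numpy())  # column names are gone
    except ValueError as e:
        print(type(e).__name__)          # string column selectors require a DataFrame
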
--- a/snowflake/ml/modeling/impute/simple_imputer.py
+++ b/snowflake/ml/modeling/impute/simple_imputer.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import copy
+import warnings
 from typing import Any, Dict, Iterable, Optional, Type, Union
 
 import numpy as np
@@ -10,6 +11,7 @@ from sklearn import impute
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import formatting
 from snowflake.ml.modeling.framework import _utils, base
 from snowflake.snowpark import functions as F, types as T
 from snowflake.snowpark._internal import utils as snowpark_utils
@@ -171,6 +173,14 @@ class SimpleImputer(base.BaseTransformer):
         self.set_output_cols(output_cols)
         self.set_passthrough_cols(passthrough_cols)
 
+    def _is_integer_type(self, column_type: T.DataType) -> bool:
+        return (
+            isinstance(column_type, T.ByteType)
+            or isinstance(column_type, T.ShortType)
+            or isinstance(column_type, T.IntegerType)
+            or isinstance(column_type, T.LongType)
+        )
+
     def _reset(self) -> None:
         """
         Reset internal data-dependent state of the imputer, if necessary.
@@ -389,6 +399,22 @@ class SimpleImputer(base.BaseTransformer):
                 # Use `fillna` for replacing nans. Check if the column has a string data type, or coerce a float.
                 if not isinstance(input_col_datatypes[input_col], T.StringType):
                     statistic = float(statistic)
+
+                    if self._is_integer_type(input_col_datatypes[input_col]):
+                        if statistic.is_integer():
+                            statistic = int(statistic)
+                        else:
+                            warnings.warn(
+                                formatting.unwrap(
+                                    f"""
+                                    Integer column may not be imputed with a non-integer value {statistic}.
+                                    In order to impute a non-integer value, convert the column to FloatType before imputing.
+                                    """
+                                ),
+                                category=UserWarning,
+                                stacklevel=1,
+                            )
+
                 transformed_dataset = transformed_dataset.na.fill({output_col: statistic})
             else:
                 transformed_dataset = transformed_dataset.na.replace(
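
The new branch casts the imputation statistic back to int only when that is lossless; otherwise it warns that an integer column cannot hold the value. A standalone sketch of just that decision logic (the helper name here is invented):

    import warnings

    def coerce_statistic_for_integer_column(statistic: float):
        # Mirrors the new SimpleImputer branch: cast when lossless, warn otherwise.
        if statistic.is_integer():
            return int(statistic)
        warnings.warn(
            f"Integer column may not be imputed with a non-integer value {statistic}. "
            "Convert the column to FloatType before imputing.",
            category=UserWarning,
        )
        return statistic

    print(coerce_statistic_for_integer_column(3.0))  # 3
    print(coerce_statistic_for_integer_column(2.5))  # 2.5, with a UserWarning
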
--- a/snowflake/ml/modeling/pipeline/pipeline.py
+++ b/snowflake/ml/modeling/pipeline/pipeline.py
@@ -99,10 +99,6 @@ class Pipeline(base.BaseTransformer):
                 must implement `fit` and `transform` methods.
                 The final step can be a transform or estimator, that is, it must implement
                 `fit` and `transform`/`predict` methods.
-        TODO: SKLearn pipeline expects last step(and only the last step) to be an estimator obj or a dummy
-              estimator(like None or passthrough). Currently this Pipeline class works with a list of all
-              transforms or a list of transforms ending with an estimator. Should we change this implementation
-              to only work with list of steps ending with an estimator or a dummy estimator like SKLearn?
 
         Args:
             steps: List of (name, transform) tuples (implementing `fit`/`transform`) that
@@ -111,6 +107,10 @@ class Pipeline(base.BaseTransformer):
         """
         super().__init__()
         self.steps = steps
+        # TODO(snandamuri): SKLearn pipeline expects last step(and only the last step) to be an estimator obj or a dummy
+        # estimator(like None or passthrough). Currently this Pipeline class works with a list of all
+        # transforms or a list of transforms ending with an estimator. Should we change this implementation
+        # to only work with list of steps ending with an estimator or a dummy estimator like SKLearn?
        self._is_final_step_estimator = Pipeline._is_estimator(steps[-1][1])
        self._is_fitted = False
        self._feature_names_in: List[np.ndarray[Any, np.dtype[Any]]] = []
  self._feature_names_in: List[np.ndarray[Any, np.dtype[Any]]] = []