snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff shows the changes between two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (74)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +14 -0
  5. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  6. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  7. snowflake/ml/data/data_connector.py +59 -6
  8. snowflake/ml/data/data_ingestor.py +18 -1
  9. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  10. snowflake/ml/data/torch_dataset.py +33 -0
  11. snowflake/ml/dataset/dataset_metadata.py +3 -1
  12. snowflake/ml/dataset/dataset_reader.py +9 -3
  13. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  14. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  15. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  16. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  18. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  19. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  20. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  21. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  22. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  23. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  24. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  25. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  26. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  27. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  29. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  30. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  31. snowflake/ml/feature_store/feature_store.py +59 -24
  32. snowflake/ml/feature_store/feature_view.py +148 -4
  33. snowflake/ml/model/_client/model/model_impl.py +11 -2
  34. snowflake/ml/model/_client/model/model_version_impl.py +171 -20
  35. snowflake/ml/model/_client/ops/model_ops.py +105 -27
  36. snowflake/ml/model/_client/ops/service_ops.py +121 -0
  37. snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
  38. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
  39. snowflake/ml/model/_client/sql/model_version.py +13 -4
  40. snowflake/ml/model/_client/sql/service.py +129 -0
  41. snowflake/ml/model/_model_composer/model_composer.py +3 -0
  42. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +10 -2
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  44. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  45. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +19 -12
  47. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  48. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
  49. snowflake/ml/model/_packager/model_handlers/lightgbm.py +27 -18
  50. snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
  51. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  52. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  53. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  54. snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
  55. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  57. snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
  58. snowflake/ml/model/_packager/model_handlers/xgboost.py +25 -16
  59. snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
  60. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
  61. snowflake/ml/model/_packager/model_packager.py +2 -1
  62. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  63. snowflake/ml/model/type_hints.py +1 -3
  64. snowflake/ml/modeling/framework/base.py +28 -19
  65. snowflake/ml/modeling/pipeline/pipeline.py +3 -0
  66. snowflake/ml/registry/_manager/model_manager.py +16 -2
  67. snowflake/ml/utils/sql_client.py +22 -0
  68. snowflake/ml/version.py +1 -1
  69. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +35 -2
  70. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +73 -62
  71. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  72. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
  73. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +0 -0
  74. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/xgboost.py CHANGED
@@ -45,7 +45,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
      _MIN_SNOWPARK_ML_VERSION = "1.0.12"
      _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

-     MODELE_BLOB_FILE_OR_DIR = "model.ubj"
+     MODEL_BLOB_FILE_OR_DIR = "model.ubj"
      DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
      _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
      _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
@@ -53,33 +53,35 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
      _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]

      @classmethod
-     def get_model_objective(cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> _base.ModelObjective:
+     def get_model_objective(
+         cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]
+     ) -> model_meta_schema.ModelObjective:
          import xgboost

          if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
              num_classes = handlers_utils.get_num_classes_if_exists(model)
              if num_classes == 2:
-                 return _base.ModelObjective.BINARY_CLASSIFICATION
-             return _base.ModelObjective.MULTI_CLASSIFICATION
+                 return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+             return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
          if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
-             return _base.ModelObjective.REGRESSION
+             return model_meta_schema.ModelObjective.REGRESSION
          if isinstance(model, xgboost.XGBRanker):
-             return _base.ModelObjective.RANKING
+             return model_meta_schema.ModelObjective.RANKING
          model_params = json.loads(model.save_config())
          model_objective = model_params["learner"]["objective"]
          for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
              if classification_objective in model_objective:
-                 return _base.ModelObjective.BINARY_CLASSIFICATION
+                 return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
          for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
              if classification_objective in model_objective:
-                 return _base.ModelObjective.MULTI_CLASSIFICATION
+                 return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
          for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
              if ranking_objective in model_objective:
-                 return _base.ModelObjective.RANKING
+                 return model_meta_schema.ModelObjective.RANKING
          for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
              if regression_objective in model_objective:
-                 return _base.ModelObjective.REGRESSION
-         return _base.ModelObjective.UNKNOWN
+                 return model_meta_schema.ModelObjective.REGRESSION
+         return model_meta_schema.ModelObjective.UNKNOWN

      @classmethod
      def can_handle(
@@ -146,9 +148,11 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
              sample_input_data=sample_input_data,
              get_prediction_fn=get_prediction,
          )
-         if kwargs.get("enable_explainability", False):
+         model_objective = cls.get_model_objective(model)
+         model_meta.model_objective = model_objective
+         if kwargs.get("enable_explainability", True):
              output_type = model_signature.DataType.DOUBLE
-             if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
+             if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
                  output_type = model_signature.DataType.STRING
              model_meta = handlers_utils.add_explain_method_signature(
                  model_meta=model_meta,
@@ -156,15 +160,18 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
                  target_method="predict",
                  output_return_type=output_type,
              )
+             model_meta.function_properties = {
+                 "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
+             }

          model_blob_path = os.path.join(model_blobs_dir_path, name)
          os.makedirs(model_blob_path, exist_ok=True)
-         model.save_model(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
+         model.save_model(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
          base_meta = model_blob_meta.ModelBlobMeta(
              name=name,
              model_type=cls.HANDLER_TYPE,
              handler_version=cls.HANDLER_VERSION,
-             path=cls.MODELE_BLOB_FILE_OR_DIR,
+             path=cls.MODEL_BLOB_FILE_OR_DIR,
              options=model_meta_schema.XgboostModelBlobOptions({"xgb_estimator_type": model.__class__.__name__}),
          )
          model_meta.models[name] = base_meta
@@ -177,11 +184,12 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
              ],
              check_local_version=True,
          )
-         if kwargs.get("enable_explainability", False):
+         if kwargs.get("enable_explainability", True):
              model_meta.env.include_if_absent(
                  [model_env.ModelDependency(requirement="shap", pip_name="shap")],
                  check_local_version=True,
              )
+             model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
          model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)

      @classmethod
@@ -224,6 +232,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
          cls,
          raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
          model_meta: model_meta_api.ModelMetadata,
+         background_data: Optional[pd.DataFrame] = None,
          **kwargs: Unpack[model_types.XGBModelLoadOptions],
      ) -> custom_model.CustomModel:
          import xgboost
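
Net effect of the xgboost.py hunks above: the handler now records a `ModelObjective` on the model metadata, and `enable_explainability` defaults to True, pulling in `shap` and registering a SHAP explain method. A minimal opt-out sketch at the registry level; the `session` object and fitted `booster` are assumed to exist, and the model/version names are illustrative:

```python
from snowflake.ml.registry import Registry

reg = Registry(session=session)  # `session`: an existing Snowpark session (assumed)

mv = reg.log_model(
    booster,                    # a fitted xgboost model (assumed)
    model_name="XGB_DEMO",      # illustrative name
    version_name="V1",
    options={"enable_explainability": False},  # opt out of the new True default
)
```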
snowflake/ml/model/_packager/model_meta/model_meta.py CHANGED
@@ -237,6 +237,7 @@ class ModelMetadata:
          function_properties: A dict mapping function names to dict mapping function property key to value.
          metadata: User provided key-value metadata of the model. Defaults to None.
          creation_timestamp: Unix timestamp when the model metadata is created.
+         model_objective: Model objective like regression, classification etc.
      """

      def telemetry_metadata(self) -> ModelMetadataTelemetryDict:
@@ -260,6 +261,8 @@ class ModelMetadata:
          min_snowpark_ml_version: Optional[str] = None,
          models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None,
          original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
+         model_objective: Optional[model_meta_schema.ModelObjective] = model_meta_schema.ModelObjective.UNKNOWN,
+         explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
      ) -> None:
          self.name = name
          self.signatures: Dict[str, model_signature.ModelSignature] = dict()
@@ -284,6 +287,11 @@ class ModelMetadata:

          self.original_metadata_version = original_metadata_version

+         self.model_objective: model_meta_schema.ModelObjective = (
+             model_objective or model_meta_schema.ModelObjective.UNKNOWN
+         )
+         self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
+
      @property
      def min_snowpark_ml_version(self) -> str:
          return self._min_snowpark_ml_version.base_version
@@ -321,9 +329,11 @@ class ModelMetadata:
          model_dict = model_meta_schema.ModelMetadataDict(
              {
                  "creation_timestamp": self.creation_timestamp,
-                 "env": self.env.save_as_dict(pathlib.Path(model_dir_path)),
+                 "env": self.env.save_as_dict(
+                     pathlib.Path(model_dir_path), default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+                 ),
                  "runtimes": {
-                     runtime_name: runtime.save(pathlib.Path(model_dir_path))
+                     runtime_name: runtime.save(pathlib.Path(model_dir_path), default_channel_override="conda-forge")
                      for runtime_name, runtime in self.runtimes.items()
                  },
                  "metadata": self.metadata,
@@ -333,6 +343,13 @@ class ModelMetadata:
                  "signatures": {func_name: sig.to_dict() for func_name, sig in self.signatures.items()},
                  "version": model_meta_schema.MODEL_METADATA_VERSION,
                  "min_snowpark_ml_version": self.min_snowpark_ml_version,
+                 "model_objective": self.model_objective.value,
+                 "explainability": (
+                     model_meta_schema.ExplainabilityMetadataDict(algorithm=self.explain_algorithm.value)
+                     if self.explain_algorithm
+                     else None
+                 ),
+                 "function_properties": self.function_properties,
              }
          )

@@ -370,6 +387,9 @@ class ModelMetadata:
              signatures=loaded_meta["signatures"],
              version=original_loaded_meta_version,
              min_snowpark_ml_version=loaded_meta_min_snowpark_ml_version,
+             model_objective=loaded_meta.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value),
+             explainability=loaded_meta.get("explainability", None),
+             function_properties=loaded_meta.get("function_properties", {}),
          )

      @classmethod
@@ -406,6 +426,11 @@ class ModelMetadata:
          else:
              runtimes = None

+         explanation_algorithm_dict = model_dict.get("explainability", None)
+         explanation_algorithm = None
+         if explanation_algorithm_dict:
+             explanation_algorithm = model_meta_schema.ModelExplainAlgorithm(explanation_algorithm_dict["algorithm"])
+
          return cls(
              name=model_dict["name"],
              model_type=model_dict["model_type"],
@@ -417,4 +442,9 @@ class ModelMetadata:
              min_snowpark_ml_version=model_dict["min_snowpark_ml_version"],
              models=models,
              original_metadata_version=model_dict["version"],
+             model_objective=model_meta_schema.ModelObjective(
+                 model_dict.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value)
+             ),
+             explain_algorithm=explanation_algorithm,
+             function_properties=model_dict.get("function_properties", {}),
          )
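
The model_meta.py hunks above persist two new fields and round-trip `function_properties`. A trimmed, illustrative sketch of the extra keys that `save_as_dict` now emits (the values are examples, not output captured from the library):

```python
# Hypothetical fragment of the serialized model metadata dict.
model_dict_fragment = {
    "model_objective": "regression",          # self.model_objective.value
    "explainability": {"algorithm": "shap"},  # None when no explain algorithm is set
    "function_properties": {},                # per-function property map
}
```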
snowflake/ml/model/_packager/model_meta/model_meta_schema.py CHANGED
@@ -71,6 +71,10 @@ ModelBlobOptions = Union[
  ]


+ class ExplainabilityMetadataDict(TypedDict):
+     algorithm: Required[str]
+
+
  class ModelBlobMetadataDict(TypedDict):
      name: Required[str]
      model_type: Required[type_hints.SupportedModelHandlerType]
@@ -92,3 +96,18 @@ class ModelMetadataDict(TypedDict):
      signatures: Required[Dict[str, Dict[str, Any]]]
      version: Required[str]
      min_snowpark_ml_version: Required[str]
+     model_objective: Required[str]
+     explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
+     function_properties: NotRequired[Dict[str, Dict[str, Any]]]
+
+
+ class ModelObjective(Enum):
+     UNKNOWN = "unknown"
+     BINARY_CLASSIFICATION = "binary_classification"
+     MULTI_CLASSIFICATION = "multi_classification"
+     REGRESSION = "regression"
+     RANKING = "ranking"
+
+
+ class ModelExplainAlgorithm(Enum):
+     SHAP = "shap"
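
The new enums serialize by value (`save_as_dict` stores `.value`; loading reconstructs the member from the string). A tiny round-trip sketch:

```python
from snowflake.ml.model._packager.model_meta import model_meta_schema

obj = model_meta_schema.ModelObjective("binary_classification")
assert obj is model_meta_schema.ModelObjective.BINARY_CLASSIFICATION

algo = model_meta_schema.ModelExplainAlgorithm("shap")
assert algo is model_meta_schema.ModelExplainAlgorithm.SHAP
```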
snowflake/ml/model/_packager/model_packager.py CHANGED
@@ -146,7 +146,8 @@ class ModelPackager:
          m = handler.load_model(self.meta.name, self.meta, model_blobs_path, **options)

          if as_custom_model:
-             m = handler.convert_as_custom_model(m, self.meta, **options)
+             background_data = handler.load_background_data(self.meta.name, model_blobs_path)
+             m = handler.convert_as_custom_model(m, self.meta, background_data, **options)
              assert isinstance(m, custom_model.CustomModel)

          self.model = m
snowflake/ml/model/_packager/model_runtime/model_runtime.py CHANGED
@@ -67,7 +67,9 @@ class ModelRuntime:
      def runtime_rel_path(self) -> pathlib.PurePosixPath:
          return pathlib.PurePosixPath(ModelRuntime.RUNTIME_DIR_REL_PATH) / self.name

-     def save(self, packager_path: pathlib.Path) -> model_meta_schema.ModelRuntimeDict:
+     def save(
+         self, packager_path: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+     ) -> model_meta_schema.ModelRuntimeDict:
          runtime_base_path = packager_path / self.runtime_rel_path
          runtime_base_path.mkdir(parents=True, exist_ok=True)

@@ -80,7 +82,7 @@ class ModelRuntime:
          self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
          self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path

-         env_dict = self.runtime_env.save_as_dict(packager_path)
+         env_dict = self.runtime_env.save_as_dict(packager_path, default_channel_override=default_channel_override)

          return model_meta_schema.ModelRuntimeDict(
              imports=list(map(str, self.imports)),
snowflake/ml/model/type_hints.py CHANGED
@@ -233,12 +233,12 @@ class BaseModelSaveOption(TypedDict):
      function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
      method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
      include_pip_dependencies: NotRequired[bool]
+     enable_explainability: NotRequired[bool]


  class CatBoostModelSaveOptions(BaseModelSaveOption):
      target_methods: NotRequired[Sequence[str]]
      cuda_version: NotRequired[str]
-     enable_explainability: NotRequired[bool]


  class CustomModelSaveOption(BaseModelSaveOption):
@@ -252,12 +252,10 @@ class SKLModelSaveOptions(BaseModelSaveOption):
  class XGBModelSaveOptions(BaseModelSaveOption):
      target_methods: NotRequired[Sequence[str]]
      cuda_version: NotRequired[str]
-     enable_explainability: NotRequired[bool]


  class LGBMModelSaveOptions(BaseModelSaveOption):
      target_methods: NotRequired[Sequence[str]]
-     enable_explainability: NotRequired[bool]


  class SNOWModelSaveOptions(BaseModelSaveOption):
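
With `enable_explainability` hoisted into `BaseModelSaveOption`, every per-framework save-options TypedDict now accepts it. A small sketch (import path per the file list above; the cuda version is an illustrative value):

```python
from snowflake.ml.model import type_hints as model_types

# Inherited from BaseModelSaveOption, so it type-checks for XGBoost,
# LightGBM, CatBoost, sklearn, and the rest alike.
opts: model_types.XGBModelSaveOptions = {
    "enable_explainability": True,
    "cuda_version": "11.8",  # illustrative
}
```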
snowflake/ml/modeling/framework/base.py CHANGED
@@ -1,7 +1,6 @@
  #!/usr/bin/env python3
  import inspect
  from abc import abstractmethod
- from collections import defaultdict
  from datetime import datetime
  from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload

@@ -18,6 +17,7 @@ from snowflake.ml._internal.exceptions import (
  )
  from snowflake.ml._internal.lineage import lineage_utils
  from snowflake.ml._internal.utils import identifier, parallelize
+ from snowflake.ml.data import data_source
  from snowflake.ml.modeling.framework import _utils
  from snowflake.snowpark import functions as F

@@ -246,7 +246,7 @@ class Base:

      def get_params(self, deep: bool = True) -> Dict[str, Any]:
          """
-         Get parameters for this transformer.
+         Get the snowflake-ml parameters for this transformer.

          Args:
              deep: If True, will return the parameters for this transformer and
@@ -265,13 +265,13 @@ class Base:
                  out[key] = value
          return out

-     def set_params(self, **params: Dict[str, Any]) -> None:
+     def set_params(self, **params: Any) -> None:
          """
          Set the parameters of this transformer.

-         The method works on simple transformers as well as on nested objects.
-         The latter have parameters of the form ``<component>__<parameter>``
-         so that it's possible to update each component of a nested object.
+         The method works on simple transformers as well as on sklearn compatible pipelines with nested
+         objects, once the transformer has been fit. Nested objects have parameters of the form
+         ``<component>__<parameter>`` so that it's possible to update each component of a nested object.

          Args:
              **params: Transformer parameter names mapped to their values.
@@ -283,12 +283,28 @@ class Base:
              # simple optimization to gain speed (inspect is slow)
              return
          valid_params = self.get_params(deep=True)
+         valid_skl_params = {}
+         if hasattr(self, "_sklearn_object") and self._sklearn_object is not None:
+             valid_skl_params = self._sklearn_object.get_params()

-         nested_params: Dict[str, Any] = defaultdict(dict)  # grouped by prefix
          for key, value in params.items():
-             key, delim, sub_key = key.partition("__")
-             if key not in valid_params:
-                 local_valid_params = self._get_param_names()
+             if valid_params.get("steps"):
+                 # Recurse through pipeline steps
+                 key, _, sub_key = key.partition("__")
+                 for name, nested_object in valid_params["steps"]:
+                     if name == key:
+                         nested_object.set_params(**{sub_key: value})
+
+             elif key in valid_params:
+                 setattr(self, key, value)
+                 valid_params[key] = value
+             elif key in valid_skl_params:
+                 # This dictionary would be empty if the following assert were not true, as specified above.
+                 assert hasattr(self, "_sklearn_object") and self._sklearn_object is not None
+                 setattr(self._sklearn_object, key, value)
+                 valid_skl_params[key] = value
+             else:
+                 local_valid_params = self._get_param_names() + list(valid_skl_params.keys())
                  raise exceptions.SnowflakeMLException(
                      error_code=error_codes.INVALID_ARGUMENT,
                      original_exception=ValueError(
@@ -298,15 +314,6 @@ class Base:
                      ),
                  )

-             if delim:
-                 nested_params[key][sub_key] = value
-             else:
-                 setattr(self, key, value)
-                 valid_params[key] = value
-
-         for key, sub_params in nested_params.items():
-             valid_params[key].set_params(**sub_params)
-
      def get_sklearn_args(
          self,
          default_sklearn_obj: Optional[object] = None,
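
The rewritten `set_params` first recurses through pipeline `steps`, then falls back to the wrapper's own parameters, and finally to the fitted sklearn object's parameters. A hedged usage sketch; the estimators and column names are illustrative:

```python
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.preprocessing import StandardScaler
from snowflake.ml.modeling.xgboost import XGBRegressor

pipe = Pipeline(steps=[
    ("scaler", StandardScaler(input_cols=["X1"], output_cols=["X1_SCALED"])),
    ("model", XGBRegressor(input_cols=["X1_SCALED"], label_cols=["Y"])),
])

# `<step>__<param>` routes to the named step; a bare key is set on the
# transformer itself or, once fit, on the underlying sklearn estimator.
pipe.set_params(model__n_estimators=50)
```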
@@ -427,6 +434,8 @@ class BaseEstimator(Base):
      def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "BaseEstimator":
          """Runs universal logics for all fit implementations."""
          data_sources = lineage_utils.get_data_sources(dataset)
+         if not data_sources and isinstance(dataset, snowpark.DataFrame):
+             data_sources = [data_source.DataFrameInfo(dataset.queries["queries"][-1])]
          lineage_utils.set_data_sources(self, data_sources)
          return self._fit(dataset)

snowflake/ml/modeling/pipeline/pipeline.py CHANGED
@@ -19,6 +19,7 @@ from snowflake.ml._internal import file_utils, telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions
  from snowflake.ml._internal.lineage import lineage_utils
  from snowflake.ml._internal.utils import snowpark_dataframe_utils, temp_file_utils
+ from snowflake.ml.data import data_source
  from snowflake.ml.model.model_signature import ModelSignature, _infer_signature
  from snowflake.ml.modeling._internal.model_transformer_builder import (
      ModelTransformerBuilder,
@@ -431,6 +432,8 @@ class Pipeline(base.BaseTransformer):

          # Extract lineage information here since we're overriding fit() directly
          data_sources = lineage_utils.get_data_sources(dataset)
+         if not data_sources and isinstance(dataset, snowpark.DataFrame):
+             data_sources = [data_source.DataFrameInfo(dataset.queries["queries"][-1])]
          lineage_utils.set_data_sources(self, data_sources)

          if self._can_be_trained_in_ml_runtime(dataset):
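
Both `fit` overrides above now fall back to recording the final SQL text of the input Snowpark DataFrame when no lineage data sources are attached. A minimal sketch of what `DataFrameInfo` captures, assuming an open `session` and an existing (illustrative) table:

```python
df = session.table("TRAINING_DATA")  # illustrative table name

# Snowpark exposes the SQL behind a DataFrame; the last entry is the
# query text that DataFrameInfo records as the lineage source.
last_query = df.queries["queries"][-1]
```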
snowflake/ml/registry/_manager/model_manager.py CHANGED
@@ -9,7 +9,7 @@ from snowflake.ml._internal.human_readable_id import hrid_generator
  from snowflake.ml._internal.utils import sql_identifier
  from snowflake.ml.model import model_signature, type_hints as model_types
  from snowflake.ml.model._client.model import model_impl, model_version_impl
- from snowflake.ml.model._client.ops import metadata_ops, model_ops
+ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
  from snowflake.ml.model._model_composer import model_composer
  from snowflake.ml.model._packager.model_meta import model_meta
  from snowflake.snowpark import session
@@ -30,6 +30,9 @@ class ModelManager:
          self._model_ops = model_ops.ModelOperator(
              session, database_name=self._database_name, schema_name=self._schema_name
          )
+         self._service_ops = service_ops.ServiceOperator(
+             session, database_name=self._database_name, schema_name=self._schema_name
+         )
          self._hrid_generator = hrid_generator.HRID16()

      def log_model(
@@ -173,11 +176,16 @@ class ModelManager:
          )

          mv = model_version_impl.ModelVersion._ref(
-             model_ops.ModelOperator(
+             model_ops=model_ops.ModelOperator(
                  self._model_ops._session,
                  database_name=database_name_id or self._database_name,
                  schema_name=schema_name_id or self._schema_name,
              ),
+             service_ops=service_ops.ServiceOperator(
+                 self._service_ops._session,
+                 database_name=database_name_id or self._database_name,
+                 schema_name=schema_name_id or self._schema_name,
+             ),
              model_name=model_name_id,
              version_name=version_name_id,
          )
@@ -216,6 +224,11 @@ class ModelManager:
                      database_name=database_name_id or self._database_name,
                      schema_name=schema_name_id or self._schema_name,
                  ),
+                 service_ops=service_ops.ServiceOperator(
+                     self._service_ops._session,
+                     database_name=database_name_id or self._database_name,
+                     schema_name=schema_name_id or self._schema_name,
+                 ),
                  model_name=model_name_id,
              )
          else:
@@ -234,6 +247,7 @@ class ModelManager:
          return [
              model_impl.Model._ref(
                  self._model_ops,
+                 service_ops=self._service_ops,
                  model_name=model_name,
              )
              for model_name in model_names
snowflake/ml/utils/sql_client.py ADDED
@@ -0,0 +1,22 @@
+ from enum import Enum
+ from typing import Dict
+
+
+ class CreationOption(Enum):
+     FAIL_IF_NOT_EXIST = 1
+     CREATE_IF_NOT_EXIST = 2
+     OR_REPLACE = 3
+
+
+ class CreationMode:
+     def __init__(self, *, if_not_exists: bool = False, or_replace: bool = False) -> None:
+         self.if_not_exists = if_not_exists
+         self.or_replace = or_replace
+
+     def get_ddl_phrases(self) -> Dict[CreationOption, str]:
+         if_not_exists_sql = " IF NOT EXISTS" if self.if_not_exists else ""
+         or_replace_sql = " OR REPLACE" if self.or_replace else ""
+         return {
+             CreationOption.CREATE_IF_NOT_EXIST: if_not_exists_sql,
+             CreationOption.OR_REPLACE: or_replace_sql,
+         }
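
The new helper module is self-contained, so its behavior follows directly from the code above. A short usage sketch (the schema name is illustrative):

```python
from snowflake.ml.utils.sql_client import CreationMode, CreationOption

mode = CreationMode(if_not_exists=True)
phrases = mode.get_ddl_phrases()

# Yields "CREATE SCHEMA IF NOT EXISTS MY_SCHEMA" for this mode.
ddl = (
    f"CREATE{phrases[CreationOption.OR_REPLACE]}"
    f" SCHEMA{phrases[CreationOption.CREATE_IF_NOT_EXIST]} MY_SCHEMA"
)
```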
snowflake/ml/version.py CHANGED
@@ -1 +1 @@
- VERSION="1.6.0"
+ VERSION="1.6.1"
{snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: snowflake-ml-python
- Version: 1.6.0
+ Version: 1.6.1
  Summary: The machine learning client library that is used for interacting with Snowflake to build machine learning solutions.
  Author-email: "Snowflake, Inc" <support@snowflake.com>
  License:
@@ -373,7 +373,32 @@ be compatibility issues. Server-side functionality that `snowflake-ml-python` de

  # Release History

- ## 1.6.0
+ ## 1.6.1 (TBD)
+
+ ### Bug Fixes
+
+ - Feature Store: Support large metadata blob when generating dataset
+ - Feature Store: Added a hidden knob in FeatureView as kwargs for setting customized
+   refresh_mode
+ - Registry: Fix an error message in Model Version `run` when `function_name` is not provided and the model has multiple
+   target methods.
+ - Cortex inference: snowflake.cortex.Complete now only uses the REST API for streaming, and the use_rest_api_experimental
+   flag is no longer needed.
+ - Feature Store: Add a new API: FeatureView.list_columns(), which lists all column information.
+ - Data: Fix `DataFrame` ingestion with `ArrowIngestor`.
+
+ ### New Features
+
+ - Enable `set_params` to set the parameters of the underlying sklearn estimator, if the snowflake-ml model has been fit.
+ - Data: Add top-level exports for `DataConnector` and `DataSource` to `snowflake.ml.data`.
+ - Data: Add `snowflake.ml.data.ingestor_utils` module with utility functions helpful for `DataIngestor` implementations.
+ - Data: Add new `to_torch_dataset()` connector to `DataConnector` to replace the deprecated DataPipe.
+ - Registry: Option `enable_explainability` is now set to True by default for XGBoost, LightGBM and CatBoost as a PuPr feature.
+ - Registry: Option to `enable_explainability` when registering SHAP-supported sklearn models.
+
+ ### Behavior Changes
+
+ ## 1.6.0 (2024-07-29)

  ### Bug Fixes

@@ -402,6 +427,14 @@ be compatibility issues. Server-side functionality that `snowflake-ml-python` de
      distributed_hpo_trainer.ENABLE_EFFICIENT_MEMORY_USAGE = False
    `
  - Registry: Option to `enable_explainability` when registering LightGBM models as a pre-PuPr feature.
+ - Data: Add new `snowflake.ml.data` preview module which contains data reading utilities like `DataConnector`
+   - `DataConnector` provides efficient connectors from Snowpark `DataFrame`
+     and Snowpark ML `Dataset` to external frameworks like PyTorch, TensorFlow, and Pandas. Create `DataConnector`
+     instances using the classmethod constructors `DataConnector.from_dataset()` and `DataConnector.from_dataframe()`.
+ - Data: Add new `DataConnector.from_sources()` classmethod constructor for constructing from `DataSource` objects.
+ - Data: Add new `ingestor_class` arg to `DataConnector` classmethod constructors for easier `DataIngestor` injection.
+ - Dataset: `DatasetReader` now subclasses the new `DataConnector` class.
+ - Add optional `limit` arg to `DatasetReader.to_pandas()`

  ### Behavior Changes
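
The data-loading entries in the release notes above all funnel through `DataConnector`. A hedged end-to-end sketch, assuming an open Snowpark `session` and an existing (illustrative) table:

```python
from snowflake.ml.data import DataConnector  # top-level export added in 1.6.1

dc = DataConnector.from_dataframe(session.table("TRAIN_DATA"))

pandas_df = dc.to_pandas()        # materialize locally
torch_ds = dc.to_torch_dataset()  # new in 1.6.1; replaces the deprecated DataPipe
```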