snowflake-ml-python 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two package versions.
Files changed (150)
  1. snowflake/ml/_internal/telemetry.py +4 -2
  2. snowflake/ml/_internal/utils/import_utils.py +31 -0
  3. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
  4. snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
  5. snowflake/ml/data/data_connector.py +1 -1
  6. snowflake/ml/data/torch_utils.py +33 -14
  7. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
  8. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
  9. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
  10. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
  11. snowflake/ml/feature_store/examples/example_helper.py +6 -3
  12. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
  13. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
  14. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
  15. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
  16. snowflake/ml/feature_store/feature_store.py +1 -2
  17. snowflake/ml/feature_store/feature_view.py +5 -1
  18. snowflake/ml/model/_client/model/model_version_impl.py +144 -10
  19. snowflake/ml/model/_client/ops/model_ops.py +25 -6
  20. snowflake/ml/model/_client/ops/service_ops.py +33 -28
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
  22. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  23. snowflake/ml/model/_client/sql/model.py +14 -0
  24. snowflake/ml/model/_client/sql/service.py +6 -18
  25. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  27. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  28. snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
  29. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -1
  30. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -6
  31. snowflake/ml/model/_packager/model_handlers/custom.py +2 -0
  32. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
  33. snowflake/ml/model/_packager/model_handlers/lightgbm.py +3 -6
  34. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
  35. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -6
  36. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -65
  37. snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
  38. snowflake/ml/model/_packager/model_packager.py +0 -11
  39. snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +13 -25
  40. snowflake/ml/model/_signatures/pandas_handler.py +16 -0
  41. snowflake/ml/model/custom_model.py +47 -7
  42. snowflake/ml/model/model_signature.py +2 -0
  43. snowflake/ml/model/type_hints.py +8 -0
  44. snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
  45. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
  46. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
  47. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
  49. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
  51. snowflake/ml/modeling/cluster/dbscan.py +5 -2
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
  53. snowflake/ml/modeling/cluster/k_means.py +14 -19
  54. snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
  55. snowflake/ml/modeling/cluster/optics.py +6 -6
  56. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
  57. snowflake/ml/modeling/compose/column_transformer.py +15 -5
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  60. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  61. snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
  62. snowflake/ml/modeling/covariance/oas.py +1 -1
  63. snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
  64. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
  65. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
  66. snowflake/ml/modeling/decomposition/pca.py +28 -15
  67. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
  68. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
  69. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
  70. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
  71. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
  72. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
  73. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
  74. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
  75. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
  76. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
  78. snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
  79. snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
  80. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
  81. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  82. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
  83. snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
  84. snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
  85. snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
  86. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  87. snowflake/ml/modeling/linear_model/lars.py +0 -10
  88. snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
  89. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  90. snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
  91. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
  92. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
  93. snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
  94. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
  95. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  96. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  97. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
  98. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
  99. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  100. snowflake/ml/modeling/linear_model/perceptron.py +3 -3
  101. snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
  102. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
  103. snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
  104. snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
  105. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
  106. snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
  107. snowflake/ml/modeling/manifold/isomap.py +1 -1
  108. snowflake/ml/modeling/manifold/mds.py +3 -3
  109. snowflake/ml/modeling/manifold/tsne.py +10 -4
  110. snowflake/ml/modeling/metrics/classification.py +12 -16
  111. snowflake/ml/modeling/metrics/ranking.py +3 -3
  112. snowflake/ml/modeling/metrics/regression.py +3 -3
  113. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
  114. snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
  115. snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
  116. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
  117. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
  118. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
  119. snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
  120. snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
  121. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  122. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
  123. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  124. snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
  125. snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
  126. snowflake/ml/modeling/pipeline/pipeline.py +16 -14
  127. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
  128. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
  129. snowflake/ml/modeling/svm/linear_svc.py +25 -16
  130. snowflake/ml/modeling/svm/linear_svr.py +23 -17
  131. snowflake/ml/modeling/svm/nu_svc.py +5 -3
  132. snowflake/ml/modeling/svm/nu_svr.py +3 -1
  133. snowflake/ml/modeling/svm/svc.py +9 -5
  134. snowflake/ml/modeling/svm/svr.py +3 -1
  135. snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
  136. snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
  137. snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
  138. snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
  139. snowflake/ml/monitoring/_client/{monitor_sql_client.py → model_monitor_sql_client.py} +1 -1
  140. snowflake/ml/monitoring/{_client → _manager}/model_monitor_manager.py +9 -8
  141. snowflake/ml/monitoring/{_client/model_monitor.py → model_monitor.py} +3 -3
  142. snowflake/ml/registry/_manager/model_manager.py +15 -1
  143. snowflake/ml/registry/registry.py +15 -8
  144. snowflake/ml/version.py +1 -1
  145. {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/METADATA +81 -9
  146. {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/RECORD +150 -150
  147. {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/WHEEL +1 -1
  148. /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
  149. {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/LICENSE.txt +0 -0
  150. {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/xgboost.py

@@ -1,7 +1,6 @@
 # mypy: disable-error-code="import"
 import os
 import warnings
-from importlib import metadata as importlib_metadata
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -16,23 +15,19 @@ from typing import (
 
 import numpy as np
 import pandas as pd
-from packaging import version
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
 from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_env import model_env
-from snowflake.ml.model._packager.model_handlers import (
-    _base,
-    _utils as handlers_utils,
-    model_objective_utils,
-)
+from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
 from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
     model_meta_schema,
 )
+from snowflake.ml.model._packager.model_task import model_task_utils
 from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
 
 if TYPE_CHECKING:
@@ -94,23 +89,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
 
         assert isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel)
 
-        local_xgb_version = None
-
-        try:
-            local_dist = importlib_metadata.distribution("xgboost")
-            local_xgb_version = version.parse(local_dist.version)
-        except importlib_metadata.PackageNotFoundError:
-            pass
-
-        if local_xgb_version and local_xgb_version >= version.parse("2.1.0") and enable_explainability:
-            warnings.warn(
-                f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
-                + "If you want model explanations, lower the xgboost version to <2.1.0.",
-                category=UserWarning,
-                stacklevel=1,
-            )
-            enable_explainability = False
-
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
                 model=model,
@@ -139,7 +117,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
                 sample_input_data=sample_input_data,
                 get_prediction_fn=get_prediction,
             )
-        model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
+        model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
         model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
         if enable_explainability:
             model_meta = handlers_utils.add_explain_method_signature(
@@ -187,23 +165,15 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             ],
             check_local_version=True,
         )
-        if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
-            model_meta.env.include_if_absent(
-                [
-                    model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
-                ],
-                check_local_version=False,
-            )
-        else:
-            model_meta.env.include_if_absent(
-                [
-                    model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
-                ],
-                check_local_version=True,
-            )
+        model_meta.env.include_if_absent(
+            [
+                model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
+            ],
+            check_local_version=True,
+        )
 
         if enable_explainability:
-            model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
+            model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
             model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
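The removed blocks above show that 1.7.0 drops the local-xgboost version gating (which silently disabled explainability on xgboost >= 2.1.0 and pinned xgboost==2.0.* in the model environment) and instead requires shap>=0.46.0, which supports current xgboost releases. A minimal sketch of what this enables, assuming the public registry API; session, X_train, and y_train are placeholders:

    import xgboost
    from snowflake.ml.registry import Registry

    model = xgboost.XGBClassifier().fit(X_train, y_train)
    reg = Registry(session=session)
    mv = reg.log_model(
        model,
        model_name="xgb_with_explanations",
        version_name="v1",
        sample_input_data=X_train,
        # No longer forced off on xgboost>=2.1.0; shap>=0.46.0 is added instead.
        options={"enable_explainability": True},
    )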
snowflake/ml/model/_packager/model_packager.py

@@ -61,17 +61,6 @@ class ModelPackager:
         if not options:
             options = model_types.BaseModelSaveOption()
 
-        # here handling the case of enable_explainability is False/None
-        enable_explainability = options.get("enable_explainability", None)
-        if enable_explainability is False or enable_explainability is None:
-            if (signatures is not None) and (sample_input_data is not None):
-                raise snowml_exceptions.SnowflakeMLException(
-                    error_code=error_codes.INVALID_ARGUMENT,
-                    original_exception=ValueError(
-                        "Signatures and sample_input_data both cannot be specified at the same time."
-                    ),
-                )
-
         handler = model_handler.find_handler(model)
         if handler is None:
             raise snowml_exceptions.SnowflakeMLException(
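With this guard removed, signatures and sample_input_data are no longer mutually exclusive at packaging time, presumably so that sample data can still serve other purposes (such as explanation background data) when explicit signatures are given. A hedged sketch against the public log_model surface; reg, model, and X_train are placeholders carried over from the previous sketch:

    from snowflake.ml.model import model_signature

    sig = model_signature.infer_signature(input_data=X_train, output_data=model.predict(X_train))
    mv = reg.log_model(
        model,
        model_name="my_model",
        version_name="v2",
        signatures={"predict": sig},  # explicit signatures...
        sample_input_data=X_train,    # ...no longer rejected alongside sample data
    )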
snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py}

@@ -128,42 +128,30 @@ def get_model_task_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> t
     return type_hints.Task.UNKNOWN
 
 
-def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
+def _get_model_task(model: Any) -> type_hints.Task:
     if type_utils.LazyType("xgboost.Booster").isinstance(model) or type_utils.LazyType("xgboost.XGBModel").isinstance(
         model
     ):
-        task = get_model_task_xgb(model)
-        output_type = model_signature.DataType.DOUBLE
-        if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
-            output_type = model_signature.DataType.STRING
-        return ModelTaskAndOutputType(task=task, output_type=output_type)
+        return get_model_task_xgb(model)
 
     if type_utils.LazyType("lightgbm.Booster").isinstance(model) or type_utils.LazyType(
         "lightgbm.LGBMModel"
     ).isinstance(model):
-        task = get_model_task_lightgbm(model)
-        output_type = model_signature.DataType.DOUBLE
-        if task in [
-            type_hints.Task.TABULAR_BINARY_CLASSIFICATION,
-            type_hints.Task.TABULAR_MULTI_CLASSIFICATION,
-        ]:
-            output_type = model_signature.DataType.STRING
-        return ModelTaskAndOutputType(task=task, output_type=output_type)
+        return get_model_task_lightgbm(model)
 
     if type_utils.LazyType("catboost.CatBoost").isinstance(model):
-        task = get_model_task_catboost(model)
-        output_type = model_signature.DataType.DOUBLE
-        if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
-            output_type = model_signature.DataType.STRING
-        return ModelTaskAndOutputType(task=task, output_type=output_type)
+        return get_model_task_catboost(model)
 
     if type_utils.LazyType("sklearn.base.BaseEstimator").isinstance(model) or type_utils.LazyType(
         "sklearn.pipeline.Pipeline"
     ).isinstance(model):
-        task = get_task_skl(model)
-        output_type = model_signature.DataType.DOUBLE
-        if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
-            output_type = model_signature.DataType.STRING
-        return ModelTaskAndOutputType(task=task, output_type=output_type)
-
+        return get_task_skl(model)
     raise ValueError(f"Model type {type(model)} is not supported")
+
+
+def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
+    task = _get_model_task(model)
+    output_type = model_signature.DataType.DOUBLE
+    if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
+        output_type = model_signature.DataType.STRING
+    return ModelTaskAndOutputType(task=task, output_type=output_type)
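Alongside the file's move from model_handlers/model_objective_utils.py to model_task/model_task_utils.py, this refactor collapses the per-framework branches into one task-to-output-type mapping. One behavior change is visible in the hunk: the old lightgbm branch mapped TABULAR_BINARY_CLASSIFICATION to a STRING output type as well, while the consolidated helper returns STRING only for TABULAR_MULTI_CLASSIFICATION. A small sketch of the internal helper as it now stands (model is any supported estimator):

    from snowflake.ml.model._packager.model_task import model_task_utils

    task_and_output = model_task_utils.get_model_task_and_output_type(model)
    task_and_output.task         # e.g. Task.TABULAR_REGRESSION
    task_and_output.output_type  # DataType.DOUBLE, or DataType.STRING for multi-class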
snowflake/ml/model/_signatures/pandas_handler.py

@@ -147,6 +147,22 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
                 specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
             elif isinstance(data[df_col].iloc[0], bytes):
                 specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
+            elif isinstance(df_col_dtype, pd.CategoricalDtype):
+                category_dtype = df_col_dtype.categories.dtype
+                if category_dtype == np.dtype("O"):
+                    if isinstance(df_col_dtype.categories[0], str):
+                        specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
+                    elif isinstance(df_col_dtype.categories[0], bytes):
+                        specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
+                    else:
+                        raise snowml_exceptions.SnowflakeMLException(
+                            error_code=error_codes.INVALID_DATA,
+                            original_exception=ValueError(
+                                f"Data Validation Error: Unsupported type confronted in {df_col_dtype.categories[0]}"
+                            ),
+                        )
+                else:
+                    specs.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(category_dtype), name=ft_name))
             elif isinstance(data[df_col].iloc[0], np.datetime64):
                 specs.append(core.FeatureSpec(dtype=core.DataType.TIMESTAMP_NTZ, name=ft_name))
             else:
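Together with the _validate_pandas_df change further down, this adds signature support for pd.Categorical columns: the FeatureSpec is derived from the dtype of the categories instead of falling through to the unsupported-type error. A hedged sketch of the resulting behavior; infer_signature is the public entry point and the inferred types follow the branches above:

    import pandas as pd
    from snowflake.ml.model import model_signature

    df = pd.DataFrame(
        {
            "color": pd.Categorical(["red", "green", "red"]),  # object categories -> STRING
            "bucket": pd.Categorical([1, 2, 1]),               # numeric categories -> from_numpy_type
        }
    )
    sig = model_signature.infer_signature(input_data=df, output_data=[0, 1, 0])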
snowflake/ml/model/custom_model.py

@@ -1,6 +1,6 @@
 import functools
 import inspect
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
 
 import anyio
 import pandas as pd
@@ -104,19 +104,53 @@ class ModelContext:
     def __init__(
         self,
         *,
-        artifacts: Optional[Dict[str, str]] = None,
-        models: Optional[Dict[str, model_types.SupportedModelType]] = None,
+        artifacts: Optional[Union[Dict[str, str], str, model_types.SupportedModelType]] = None,
+        models: Optional[Union[Dict[str, model_types.SupportedModelType], str, model_types.SupportedModelType]] = None,
+        **kwargs: Optional[Union[str, model_types.SupportedModelType]],
     ) -> None:
         """Initialize the model context.
 
         Args:
             artifacts: A dictionary mapping the name of the artifact to its currently available path. Defaults to None.
             models: A dictionary mapping the name of the sub-model to the corresponding model object. Defaults to None.
+            **kwargs: Additional keyword arguments to be used as artifacts or models.
+
+        Raises:
+            ValueError: Raised when the keyword argument is used as artifacts or models.
+            ValueError: Raised when the artifact name is duplicated.
+            ValueError: Raised when the model name is duplicated.
         """
-        self.artifacts: Dict[str, str] = artifacts if artifacts else dict()
-        self.model_refs: Dict[str, ModelRef] = (
-            {name: ModelRef(name, model) for name, model in models.items()} if models else dict()
-        )
+
+        self.artifacts: Dict[str, str] = dict()
+        self.model_refs: Dict[str, ModelRef] = dict()
+
+        # In case that artifacts is a dictionary, assume the original usage,
+        # which is to pass in a dictionary of artifacts.
+        # In other scenarios, (str or supported model types) we will try to parse the arguments as artifacts or models.
+        if isinstance(artifacts, dict):
+            self.artifacts = artifacts
+        elif isinstance(artifacts, str):
+            self.artifacts["artifacts"] = artifacts
+        elif artifacts is not None:
+            self.model_refs["artifacts"] = ModelRef("artifacts", artifacts)
+
+        if isinstance(models, dict):
+            self.model_refs = {name: ModelRef(name, model) for name, model in models.items()} if models else dict()
+        elif isinstance(models, str):
+            self.artifacts["models"] = models
+        elif models is not None:
+            self.model_refs["models"] = ModelRef("models", models)
+
+        # Handle any new arguments passed via kwargs
+        for key, value in kwargs.items():
+            if isinstance(value, str):
+                if key in self.artifacts:
+                    raise ValueError(f"Duplicate artifact name: {key}")
+                self.artifacts[key] = value
+            else:
+                if key in self.model_refs:
+                    raise ValueError(f"Duplicate model name: {key}")
+                self.model_refs[key] = ModelRef(key, value)
 
     def path(self, key: str) -> str:
         """Get the actual path to a specific artifact. This could be used when defining a Custom Model to retrieve
@@ -141,6 +175,12 @@ class ModelContext:
         """
         return self.model_refs[name]
 
+    def __getitem__(self, key: str) -> Union[str, ModelRef]:
+        combined: Dict[str, Union[str, ModelRef]] = {**self.artifacts, **self.model_refs}
+        if key not in combined:
+            raise KeyError(f"Key {key} not found in the kwargs, current available keys are: {combined.keys()}")
+        return combined[key]
+
 
 class CustomModel:
     """Abstract class for user defined custom model.
snowflake/ml/model/model_signature.py

@@ -214,6 +214,8 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
         assert isinstance(feature, core.FeatureSpec)  # assert for mypy.
         ft_type = feature._dtype
         ft_shape = feature._shape
+        if isinstance(df_col_dtype, pd.CategoricalDtype):
+            df_col_dtype = df_col_dtype.categories.dtype
         if df_col_dtype != np.dtype("O"):
             if not _validate_numpy_array(data_col.to_numpy(), ft_type, strict=strict):
                 raise snowml_exceptions.SnowflakeMLException(
snowflake/ml/model/type_hints.py

@@ -298,3 +298,11 @@ class Task(Enum):
     TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
     TABULAR_REGRESSION = "TABULAR_REGRESSION"
     TABULAR_RANKING = "TABULAR_RANKING"
+
+
+class TargetPlatform(Enum):
+    WAREHOUSE = "WAREHOUSE"
+    SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
+
+
+SupportedTargetPlatformType = Union[TargetPlatform, str]
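These additions give deployment targets a typed spelling. registry.py also changes in this release; a plausible use is a target_platforms argument on log_model, but that hunk is not shown here, so treat the argument name as an assumption in this sketch:

    from snowflake.ml.model import type_hints

    # Either the enum or its string form should satisfy SupportedTargetPlatformType.
    platforms = [type_hints.TargetPlatform.WAREHOUSE, "SNOWPARK_CONTAINER_SERVICES"]
    mv = reg.log_model(model, model_name="m", version_name="v1", target_platforms=platforms)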
snowflake/ml/modeling/_internal/estimator_utils.py

@@ -275,3 +275,16 @@ def upload_model_to_stage(
 
     temp_file_utils.cleanup_temp_files([local_transform_file_name])
     return os.path.basename(local_transform_file_name)
+
+
+def should_include_sample_weight(estimator: object, method_name: str) -> bool:
+    # If this is a Grid Search or Randomized Search estimator, check the underlying estimator.
+    underlying_estimator = (
+        estimator.estimator if ("_search" in estimator.__module__ and hasattr(estimator, "estimator")) else estimator
+    )
+    method = getattr(underlying_estimator, method_name)
+    underlying_estimator_params = inspect.signature(method).parameters
+    if "sample_weight" in underlying_estimator_params:
+        return True
+
+    return False
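The point of the helper: search estimators such as GridSearchCV accept fit parameters only via **fit_params, so a plain signature check for sample_weight always fails even when the wrapped estimator supports it. A runnable sketch of the difference:

    import inspect

    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import GridSearchCV

    from snowflake.ml.modeling._internal.estimator_utils import should_include_sample_weight

    search = GridSearchCV(LogisticRegression(), {"C": [0.1, 1.0]})
    print("sample_weight" in inspect.signature(search.fit).parameters)  # False
    print(should_include_sample_weight(search, "fit"))                  # True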
snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py

@@ -4,7 +4,10 @@ from typing import Any, List, Optional
 import pandas as pd
 
 from snowflake.ml._internal.exceptions import error_codes, exceptions
-from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
+from snowflake.ml.modeling._internal.estimator_utils import (
+    handle_inference_result,
+    should_include_sample_weight,
+)
 
 
 class PandasTransformHandlers:
@@ -166,6 +169,7 @@ class PandasTransformHandlers:
             SnowflakeMLException: The input column list does not have one of `X` and `X_test`.
         """
         assert hasattr(self.estimator, "score")  # make type checker happy
+
         params = inspect.signature(self.estimator.score).parameters
         if "X" in params:
             score_args = {"X": self.dataset[input_cols]}
@@ -181,7 +185,8 @@ class PandasTransformHandlers:
             label_arg_name = "Y" if "Y" in params else "y"
             score_args[label_arg_name] = self.dataset[label_cols].squeeze()
 
-        if sample_weight_col is not None and "sample_weight" in params:
+        # Sample weight is not included in search estimators parameters, check the underlying estimator.
+        if sample_weight_col is not None and should_include_sample_weight(self.estimator, "score"):
             score_args["sample_weight"] = self.dataset[sample_weight_col].squeeze()
 
         score = self.estimator.score(**score_args)
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py

@@ -19,6 +19,7 @@ from snowflake.ml._internal.utils import (
     snowpark_dataframe_utils,
     temp_file_utils,
 )
+from snowflake.ml.modeling._internal.estimator_utils import should_include_sample_weight
 from snowflake.ml.modeling._internal.model_specifications import (
     ModelSpecificationsBuilder,
 )
@@ -38,6 +39,7 @@ from snowflake.snowpark.udtf import UDTFRegistration
 cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))
+cp.register_pickle_by_value(inspect.getmodule(should_include_sample_weight))
 
 _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
@@ -393,7 +395,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         import pandas as pd
         import pyarrow.parquet as pq
         from sklearn.metrics import check_scoring
-        from sklearn.metrics._scorer import _check_multimetric_scoring
+        from sklearn.metrics._scorer import (
+            _check_multimetric_scoring,
+            _MultimetricScorer,
+        )
 
         for import_name in udf_imports:
             importlib.import_module(import_name)
@@ -606,6 +611,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
             estimator._check_refit_for_multimetric(scorers)
             refit_metric = original_refit
+            scorers = _MultimetricScorer(scorers=scorers)
 
         estimator.scorer_ = scorers
 
@@ -638,7 +644,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             if label_cols:
                 label_arg_name = "Y" if "Y" in argspec.args else "y"
                 args[label_arg_name] = y
-            if sample_weight_col is not None and "sample_weight" in argspec.args:
+            if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
                 args["sample_weight"] = df[sample_weight_col].squeeze()
             estimator.refit = original_refit
             refit_start_time = time.time()
@@ -797,8 +803,11 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         import pandas as pd
         import pyarrow.parquet as pq
         from sklearn.metrics import check_scoring
-        from sklearn.metrics._scorer import _check_multimetric_scoring
-        from sklearn.utils.validation import _check_fit_params, indexable
+        from sklearn.metrics._scorer import (
+            _check_multimetric_scoring,
+            _MultimetricScorer,
+        )
+        from sklearn.utils.validation import _check_method_params, indexable
 
         # import packages in sproc
         for import_name in udf_imports:
@@ -846,11 +855,12 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
             estimator._check_refit_for_multimetric(scorers)
             refit_metric = estimator.refit
+            scorers = _MultimetricScorer(scorers=scorers)
 
         # preprocess the attributes - (2) check fit_params
         groups = None
         X, y, _ = indexable(X, y, groups)
-        fit_params = _check_fit_params(X, fit_params)
+        fit_params = _check_method_params(X, fit_params)
 
         # preprocess the attributes - (3) safe clone base estimator
         base_estimator = clone(estimator.estimator)
@@ -863,6 +873,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         fit_and_score_kwargs = dict(
             scorer=scorers,
             fit_params=fit_params,
+            score_params=None,
             return_train_score=estimator.return_train_score,
             return_n_test_samples=True,
             return_times=True,
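These hunks appear to track private scikit-learn API changes around 1.4: _check_fit_params was renamed _check_method_params, _fit_and_score grew a score_params argument, and a dict of scorers must be wrapped in _MultimetricScorer before use. A minimal sketch of the wrapping step, with the caveat that these are internal sklearn symbols and may move again between releases:

    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics._scorer import _check_multimetric_scoring, _MultimetricScorer

    scorers = _check_multimetric_scoring(LogisticRegression(), ["accuracy", "f1"])
    scorers = _MultimetricScorer(scorers=scorers)  # callable wrapper expected downstream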
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py

@@ -18,7 +18,10 @@ from snowflake.ml._internal.utils import (
 )
 from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
 from snowflake.ml.modeling._internal import estimator_utils
-from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
+from snowflake.ml.modeling._internal.estimator_utils import (
+    handle_inference_result,
+    should_include_sample_weight,
+)
 from snowflake.snowpark import DataFrame, Session, functions as F, types as T
 from snowflake.snowpark._internal.utils import (
     TempObjectType,
@@ -28,6 +31,8 @@ from snowflake.snowpark._internal.utils import (
 cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
+cp.register_pickle_by_value(inspect.getmodule(should_include_sample_weight))
+
 
 _PROJECT = "ModelDevelopment"
@@ -330,7 +335,8 @@ class SnowparkTransformHandlers:
             label_arg_name = "Y" if "Y" in params else "y"
             args[label_arg_name] = df[label_cols].squeeze()
 
-        if sample_weight_col is not None and "sample_weight" in params:
+        # Sample weight is not included in search estimators parameters, check the underlying estimator.
+        if sample_weight_col is not None and should_include_sample_weight(estimator, "score"):
             args["sample_weight"] = df[sample_weight_col].squeeze()
 
         result: float = estimator.score(**args)
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py

@@ -20,7 +20,10 @@ from snowflake.ml._internal.utils import (
     temp_file_utils,
 )
 from snowflake.ml.modeling._internal import estimator_utils
-from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
+from snowflake.ml.modeling._internal.estimator_utils import (
+    handle_inference_result,
+    should_include_sample_weight,
+)
 from snowflake.ml.modeling._internal.model_specifications import (
     ModelSpecifications,
     ModelSpecificationsBuilder,
@@ -32,6 +35,7 @@ from snowflake.snowpark.stored_procedure import StoredProcedure
 cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
+cp.register_pickle_by_value(inspect.getmodule(should_include_sample_weight))
 
 _PROJECT = "ModelDevelopment"
 _ENABLE_ANONYMOUS_SPROC = False
@@ -170,12 +174,14 @@ class SnowparkModelTrainer:
             estimator = cp.load(local_transform_file_obj)
 
             params = inspect.signature(estimator.fit).parameters
+
             args = {"X": df[input_cols]}
             if label_cols:
                 label_arg_name = "Y" if "Y" in params else "y"
                 args[label_arg_name] = df[label_cols].squeeze()
 
-            if sample_weight_col is not None and "sample_weight" in params:
+            # Sample weight is not included in search estimators parameters, check the underlying estimator.
+            if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
                 args["sample_weight"] = df[sample_weight_col].squeeze()
 
             estimator.fit(**args)
@@ -412,7 +418,7 @@ class SnowparkModelTrainer:
                 label_arg_name = "Y" if "Y" in params else "y"
                 args[label_arg_name] = df[label_cols].squeeze()
 
-            if sample_weight_col is not None and "sample_weight" in params:
+            if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
                 args["sample_weight"] = df[sample_weight_col].squeeze()
 
             fit_transform_result = estimator.fit_transform(**args)
snowflake/ml/modeling/calibration/calibrated_classifier_cv.py

@@ -167,9 +167,6 @@ class CalibratedClassifierCV(BaseTransformer):
        `estimator` trained on all the data.
        Note that this method is also internally implemented in
        :mod:`sklearn.svm` estimators with the `probabilities=True` parameter.
-
-    base_estimator: estimator instance
-        This parameter is deprecated. Use `estimator` instead.
     """
 
     def __init__(  # type: ignore[no-untyped-def]
@@ -180,7 +177,6 @@ class CalibratedClassifierCV(BaseTransformer):
         cv=None,
         n_jobs=None,
         ensemble=True,
-        base_estimator="deprecated",
         input_cols: Optional[Union[str, Iterable[str]]] = None,
         output_cols: Optional[Union[str, Iterable[str]]] = None,
         label_cols: Optional[Union[str, Iterable[str]]] = None,
@@ -200,16 +196,13 @@ class CalibratedClassifierCV(BaseTransformer):
         self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
-        deps = deps | gather_dependencies(base_estimator)
         self._deps = list(deps)
         estimator = transform_snowml_obj_to_sklearn_obj(estimator)
-        base_estimator = transform_snowml_obj_to_sklearn_obj(base_estimator)
         init_args = {'estimator':(estimator, None, False),
                      'method':(method, "sigmoid", False),
                      'cv':(cv, None, False),
                      'n_jobs':(n_jobs, None, False),
-                     'ensemble':(ensemble, True, False),
-                     'base_estimator':(base_estimator, "deprecated", False),}
+                     'ensemble':(ensemble, True, False),}
         cleaned_up_init_args = validate_sklearn_args(
             args=init_args,
             klass=sklearn.calibration.CalibratedClassifierCV
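The wrapper follows scikit-learn, which deprecated base_estimator in 1.2 and removed it in 1.4; callers must switch to estimator. A hedged sketch with placeholder column names:

    from snowflake.ml.modeling.calibration import CalibratedClassifierCV
    from snowflake.ml.modeling.linear_model import LogisticRegression

    clf = CalibratedClassifierCV(
        estimator=LogisticRegression(),  # was: base_estimator=...
        method="sigmoid",
        input_cols=["F1", "F2"],
        label_cols=["LABEL"],
        output_cols=["PREDICTION"],
    )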
snowflake/ml/modeling/cluster/agglomerative_clustering.py

@@ -113,28 +113,18 @@ class AgglomerativeClustering(BaseTransformer):
        The number of clusters to find. It must be ``None`` if
        ``distance_threshold`` is not ``None``.
 
-    affinity: str or callable, default='euclidean'
-        The metric to use when calculating distance between instances in a
-        feature array. If metric is a string or callable, it must be one of
-        the options allowed by :func:`sklearn.metrics.pairwise_distances` for
-        its metric parameter.
-        If linkage is "ward", only "euclidean" is accepted.
-        If "precomputed", a distance matrix (instead of a similarity matrix)
-        is needed as input for the fit method.
-
-    metric: str or callable, default=None
+    metric: str or callable, default="euclidean"
         Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
-        "manhattan", "cosine", or "precomputed". If set to `None` then
-        "euclidean" is used. If linkage is "ward", only "euclidean" is
-        accepted. If "precomputed", a distance matrix is needed as input for
-        the fit method.
+        "manhattan", "cosine", or "precomputed". If linkage is "ward", only
+        "euclidean" is accepted. If "precomputed", a distance matrix is needed
+        as input for the fit method.
 
     memory: str or object with the joblib.Memory interface, default=None
         Used to cache the output of the computation of the tree.
         By default, no caching is done. If a string is given, it is the
         path to the caching directory.
 
-    connectivity: array-like or callable, default=None
+    connectivity: array-like, sparse matrix, or callable, default=None
         Connectivity matrix. Defines for each sample the neighboring
         samples following a given structure of the data.
         This can be a connectivity matrix itself or a callable that transforms
@@ -142,6 +132,10 @@ class AgglomerativeClustering(BaseTransformer):
         `kneighbors_graph`. Default is ``None``, i.e, the
         hierarchical clustering algorithm is unstructured.
 
+        For an example of connectivity matrix using
+        :class:`~sklearn.neighbors.kneighbors_graph`, see
+        :ref:`sphx_glr_auto_examples_cluster_plot_agglomerative_clustering.py`.
+
     compute_full_tree: 'auto' or bool, default='auto'
         Stop early the construction of the tree at ``n_clusters``. This is
         useful to decrease computation time if the number of clusters is not
@@ -167,6 +161,9 @@ class AgglomerativeClustering(BaseTransformer):
         - 'single' uses the minimum of the distances between all observations
           of the two sets.
 
+        For examples comparing different `linkage` criteria, see
+        :ref:`sphx_glr_auto_examples_cluster_plot_linkage_comparison.py`.
+
     distance_threshold: float, default=None
         The linkage distance threshold at or above which clusters will not be
         merged. If not ``None``, ``n_clusters`` must be ``None`` and
@@ -176,14 +173,16 @@ class AgglomerativeClustering(BaseTransformer):
         Computes distances between clusters even if `distance_threshold` is not
         used. This can be used to make dendrogram visualization, but introduces
         a computational and memory overhead.
+
+        For an example of dendrogram visualization, see
+        :ref:`sphx_glr_auto_examples_cluster_plot_agglomerative_dendrogram.py`.
     """
 
     def __init__(  # type: ignore[no-untyped-def]
         self,
         *,
         n_clusters=2,
-        affinity="deprecated",
-        metric=None,
+        metric="euclidean",
         memory=None,
         connectivity=None,
         compute_full_tree="auto",
@@ -212,8 +211,7 @@ class AgglomerativeClustering(BaseTransformer):
         self._deps = list(deps)
 
         init_args = {'n_clusters':(n_clusters, 2, False),
-                     'affinity':(affinity, "deprecated", False),
-                     'metric':(metric, None, False),
+                     'metric':(metric, "euclidean", False),
                      'memory':(memory, None, False),
                      'connectivity':(connectivity, None, False),
                      'compute_full_tree':(compute_full_tree, "auto", False),
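Same scikit-learn alignment here: the deprecated affinity parameter is gone and metric (now defaulting to "euclidean") takes over. A hedged sketch; linkage is assumed to be mirrored from scikit-learn like the other constructor parameters:

    from snowflake.ml.modeling.cluster import AgglomerativeClustering

    clus = AgglomerativeClustering(
        n_clusters=3,
        metric="manhattan",  # was: affinity="manhattan"
        linkage="average",
    )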
snowflake/ml/modeling/cluster/dbscan.py

@@ -117,8 +117,11 @@ class DBSCAN(BaseTransformer):
        and distance function.
 
     min_samples: int, default=5
-        The number of samples (or total weight) in a neighborhood for a point
-        to be considered as a core point. This includes the point itself.
+        The number of samples (or total weight) in a neighborhood for a point to
+        be considered as a core point. This includes the point itself. If
+        `min_samples` is set to a higher value, DBSCAN will find denser clusters,
+        whereas if it is set to a lower value, the found clusters will be more
+        sparse.
 
     metric: str, or callable, default='euclidean'
         The metric to use when calculating distance between instances in a