snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/telemetry.py +4 -2
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/import_utils.py +31 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/data/torch_utils.py +33 -14
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
- snowflake/ml/feature_store/examples/example_helper.py +6 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
- snowflake/ml/feature_store/feature_store.py +1 -2
- snowflake/ml/feature_store/feature_view.py +5 -1
- snowflake/ml/model/_client/model/model_version_impl.py +145 -11
- snowflake/ml/model/_client/ops/model_ops.py +56 -16
- snowflake/ml/model/_client/ops/service_ops.py +46 -30
- snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
- snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
- snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_packager.py +0 -11
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +87 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/custom_model.py +47 -7
- snowflake/ml/model/model_signature.py +40 -9
- snowflake/ml/model/type_hints.py +9 -1
- snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
- snowflake/ml/modeling/cluster/dbscan.py +5 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
- snowflake/ml/modeling/cluster/k_means.py +14 -19
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
- snowflake/ml/modeling/cluster/optics.py +6 -6
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
- snowflake/ml/modeling/compose/column_transformer.py +15 -5
- snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
- snowflake/ml/modeling/decomposition/pca.py +28 -15
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
- snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +0 -10
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
- snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +3 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
- snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
- snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +3 -3
- snowflake/ml/modeling/manifold/tsne.py +10 -4
- snowflake/ml/modeling/metrics/classification.py +12 -16
- snowflake/ml/modeling/metrics/ranking.py +3 -3
- snowflake/ml/modeling/metrics/regression.py +3 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
- snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +16 -14
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
- snowflake/ml/modeling/svm/linear_svc.py +25 -16
- snowflake/ml/modeling/svm/linear_svr.py +23 -17
- snowflake/ml/modeling/svm/nu_svc.py +5 -3
- snowflake/ml/modeling/svm/nu_svr.py +3 -1
- snowflake/ml/modeling/svm/svc.py +9 -5
- snowflake/ml/modeling/svm/svr.py +3 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
- snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
- snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
- snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
- snowflake/ml/monitoring/model_monitor.py +37 -0
- snowflake/ml/registry/_manager/model_manager.py +15 -1
- snowflake/ml/registry/registry.py +32 -37
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/_client/model_monitor.py +0 -126
- snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
- snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_signatures/tensorflow_handler.py
CHANGED
@@ -109,10 +109,10 @@ class SeqOfTensorflowTensorHandler(
             dtype = core.DataType.from_numpy_type(data_col.dtype.as_numpy_dtype)
             ft_name = f"{role_prefix}{feature_prefix}{i}"
             if len(data_col.shape) == 1:
-                features.append(core.FeatureSpec(dtype=dtype, name=ft_name))
+                features.append(core.FeatureSpec(dtype=dtype, name=ft_name, nullable=False))
             else:
                 ft_shape = tuple(data_col.shape[1:])
-                features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape))
+                features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape, nullable=False))
         return features

     @staticmethod
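The two changed lines above mark inferred tensor features as explicitly non-nullable. A minimal sketch of what such a signature entry looks like, assuming the public re-exports in snowflake.ml.model.model_signature (the diff itself only shows the core.FeatureSpec call):

    from snowflake.ml.model.model_signature import DataType, FeatureSpec

    # Tensors cannot carry NULLs, so the handler now emits nullable=False
    # explicitly instead of relying on the default.
    scalar_ft = FeatureSpec(dtype=DataType.FLOAT, name="input_feature_0", nullable=False)
    tensor_ft = FeatureSpec(dtype=DataType.FLOAT, name="input_feature_1", shape=(3,), nullable=False)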
snowflake/ml/model/custom_model.py
CHANGED
@@ -1,6 +1,6 @@
 import functools
 import inspect
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union

 import anyio
 import pandas as pd
@@ -104,19 +104,53 @@ class ModelContext:
     def __init__(
         self,
         *,
-        artifacts: Optional[Dict[str, str]] = None,
-        models: Optional[Dict[str, model_types.SupportedModelType]] = None,
+        artifacts: Optional[Union[Dict[str, str], str, model_types.SupportedModelType]] = None,
+        models: Optional[Union[Dict[str, model_types.SupportedModelType], str, model_types.SupportedModelType]] = None,
+        **kwargs: Optional[Union[str, model_types.SupportedModelType]],
     ) -> None:
         """Initialize the model context.

         Args:
             artifacts: A dictionary mapping the name of the artifact to its currently available path. Defaults to None.
             models: A dictionary mapping the name of the sub-model to the corresponding model object. Defaults to None.
+            **kwargs: Additional keyword arguments to be used as artifacts or models.
+
+        Raises:
+            ValueError: Raised when the keyword argument is used as artifacts or models.
+            ValueError: Raised when the artifact name is duplicated.
+            ValueError: Raised when the model name is duplicated.
         """
-
-        self.
-
-
+
+        self.artifacts: Dict[str, str] = dict()
+        self.model_refs: Dict[str, ModelRef] = dict()
+
+        # In case that artifacts is a dictionary, assume the original usage,
+        # which is to pass in a dictionary of artifacts.
+        # In other scenarios, (str or supported model types) we will try to parse the arguments as artifacts or models.
+        if isinstance(artifacts, dict):
+            self.artifacts = artifacts
+        elif isinstance(artifacts, str):
+            self.artifacts["artifacts"] = artifacts
+        elif artifacts is not None:
+            self.model_refs["artifacts"] = ModelRef("artifacts", artifacts)
+
+        if isinstance(models, dict):
+            self.model_refs = {name: ModelRef(name, model) for name, model in models.items()} if models else dict()
+        elif isinstance(models, str):
+            self.artifacts["models"] = models
+        elif models is not None:
+            self.model_refs["models"] = ModelRef("models", models)
+
+        # Handle any new arguments passed via kwargs
+        for key, value in kwargs.items():
+            if isinstance(value, str):
+                if key in self.artifacts:
+                    raise ValueError(f"Duplicate artifact name: {key}")
+                self.artifacts[key] = value
+            else:
+                if key in self.model_refs:
+                    raise ValueError(f"Duplicate model name: {key}")
+                self.model_refs[key] = ModelRef(key, value)

     def path(self, key: str) -> str:
         """Get the actual path to a specific artifact. This could be used when defining a Custom Model to retrieve
@@ -141,6 +175,12 @@ class ModelContext:
         """
         return self.model_refs[name]

+    def __getitem__(self, key: str) -> Union[str, ModelRef]:
+        combined: Dict[str, Union[str, ModelRef]] = {**self.artifacts, **self.model_refs}
+        if key not in combined:
+            raise KeyError(f"Key {key} not found in the kwargs, current available keys are: {combined.keys()}")
+        return combined[key]
+

 class CustomModel:
     """Abstract class for user defined custom model.
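Taken together, these hunks let ModelContext accept sub-models and artifact paths as plain keyword arguments and be indexed directly. A usage sketch under those assumptions (the sub-model and path below are placeholders; CustomModel and inference_api come from the same module):

    import pandas as pd
    from sklearn.linear_model import LinearRegression
    from snowflake.ml.model import custom_model

    sub_model = LinearRegression().fit([[0.0], [1.0]], [0.0, 1.0])  # placeholder sub-model

    # New: plain keyword arguments instead of the artifacts=/models= dictionaries.
    ctx = custom_model.ModelContext(
        my_model=sub_model,                 # model objects land in model_refs
        config_file="configs/config.json",  # str values land in artifacts
    )

    class MyModel(custom_model.CustomModel):
        @custom_model.inference_api
        def predict(self, X: pd.DataFrame) -> pd.DataFrame:
            model_ref = self.context["my_model"]  # ModelRef, via the new __getitem__
            return pd.DataFrame({"output": model_ref.predict(X.values)})

    my_model = MyModel(ctx)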
snowflake/ml/model/model_signature.py
CHANGED
@@ -139,9 +139,32 @@ def _rename_signature_with_snowflake_identifiers(
     return signature


-def _validate_numpy_array(
-    arr: model_types._SupportedNumpyArray, feature_type: core.DataType, strict: bool = False
+def _validate_array_or_series_type(
+    arr: Union[model_types._SupportedNumpyArray, pd.Series], feature_type: core.DataType, strict: bool = False
 ) -> bool:
+    original_dtype = arr.dtype
+    dtype = arr.dtype
+    if isinstance(
+        dtype,
+        (
+            pd.Int8Dtype,
+            pd.Int16Dtype,
+            pd.Int32Dtype,
+            pd.Int64Dtype,
+            pd.UInt8Dtype,
+            pd.UInt16Dtype,
+            pd.UInt32Dtype,
+            pd.UInt64Dtype,
+            pd.Float32Dtype,
+            pd.Float64Dtype,
+            pd.BooleanDtype,
+        ),
+    ):
+        dtype = dtype.type
+    elif isinstance(dtype, pd.CategoricalDtype):
+        dtype = dtype.categories.dtype
+    elif isinstance(dtype, pd.StringDtype):
+        dtype = np.str_
     if feature_type in [
         core.DataType.INT8,
         core.DataType.INT16,
@@ -152,14 +175,17 @@ def _validate_numpy_array(
         core.DataType.UINT32,
         core.DataType.UINT64,
     ]:
-        if not (np.issubdtype(arr.dtype, np.integer)):
+        if not (np.issubdtype(dtype, np.integer)):
            return False
        if not strict:
            return True
-        min_v, max_v = arr.min(), arr.max()
+        if isinstance(original_dtype, pd.CategoricalDtype):
+            min_v, max_v = arr.cat.as_ordered().min(), arr.cat.as_ordered().min()  # type: ignore[union-attr]
+        else:
+            min_v, max_v = arr.min(), arr.max()
        return bool(max_v <= np.iinfo(feature_type._numpy_type).max and min_v >= np.iinfo(feature_type._numpy_type).min)
    elif feature_type in [core.DataType.FLOAT, core.DataType.DOUBLE]:
-        if not (np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.floating)):
+        if not (np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.floating)):
            return False
        if not strict:
            return True
@@ -171,7 +197,7 @@ def _validate_numpy_array(
    elif feature_type in [core.DataType.TIMESTAMP_NTZ]:
        return np.issubdtype(arr.dtype, np.datetime64)
    else:
-        return np.can_cast(arr.dtype, feature_type._numpy_type, casting="no")
+        return np.can_cast(dtype, feature_type._numpy_type, casting="no")


 def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec], strict: bool = False) -> None:
@@ -204,7 +230,10 @@ def _validate_pandas_df(
                original_exception=ValueError(f"Data Validation Error: feature {ft_name} does not exist in data."),
            )

+        if data_col.isnull().any():
+            data_col = utils.series_dropna(data_col)
        df_col_dtype = data_col.dtype
+
        if isinstance(feature, core.FeatureGroupSpec):
            raise snowml_exceptions.SnowflakeMLException(
                error_code=error_codes.NOT_IMPLEMENTED,
@@ -214,8 +243,10 @@ def _validate_pandas_df(
        assert isinstance(feature, core.FeatureSpec)  # assert for mypy.
        ft_type = feature._dtype
        ft_shape = feature._shape
+        if isinstance(df_col_dtype, pd.CategoricalDtype):
+            df_col_dtype = df_col_dtype.categories.dtype
        if df_col_dtype != np.dtype("O"):
-            if not _validate_numpy_array(data_col.to_numpy(), ft_type, strict=strict):
+            if not _validate_array_or_series_type(data_col, ft_type, strict=strict):
                raise snowml_exceptions.SnowflakeMLException(
                    error_code=error_codes.INVALID_DATA,
                    original_exception=ValueError(
@@ -245,7 +276,7 @@ def _validate_pandas_df(
                    converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in data_col]

                    if not all(
-                        _validate_numpy_array(converted_data, ft_type, strict=strict)
+                        _validate_array_or_series_type(converted_data, ft_type, strict=strict)
                        for converted_data in converted_data_list
                    ):
                        raise snowml_exceptions.SnowflakeMLException(
@@ -276,7 +307,7 @@ def _validate_pandas_df(
                    ),
                )

-                if not all(_validate_numpy_array(data_row, ft_type, strict=strict) for data_row in data_col):
+                if not all(_validate_array_or_series_type(data_row, ft_type, strict=strict) for data_row in data_col):
                    raise snowml_exceptions.SnowflakeMLException(
                        error_code=error_codes.INVALID_DATA,
                        original_exception=ValueError(
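The renamed validator's key move is unwrapping pandas extension dtypes (nullable integers/floats/booleans, categoricals, strings) to numpy-comparable types before the np.issubdtype checks. A standalone illustration of that unwrapping, runnable without snowflake-ml:

    import numpy as np
    import pandas as pd

    s = pd.Series([1, 2, None], dtype=pd.Int64Dtype())
    # Nullable extension dtypes expose the underlying numpy scalar type via
    # .type, which is what the validator compares against.
    assert np.issubdtype(s.dtype.type, np.integer)

    cat = pd.Series([1, 2, 1], dtype="category")
    # Categoricals are validated against the dtype of their categories.
    assert np.issubdtype(cat.dtype.categories.dtype, np.integer)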
snowflake/ml/model/type_hints.py
CHANGED
@@ -66,7 +66,7 @@ SupportedRequireSignatureModelType = Union[
     "xgboost.XGBModel",
     "xgboost.Booster",
     "torch.nn.Module",
-    "torch.jit.ScriptModule",
+    "torch.jit.ScriptModule",
     "tensorflow.Module",
 ]

@@ -298,3 +298,11 @@ class Task(Enum):
     TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
     TABULAR_REGRESSION = "TABULAR_REGRESSION"
     TABULAR_RANKING = "TABULAR_RANKING"
+
+
+class TargetPlatform(Enum):
+    WAREHOUSE = "WAREHOUSE"
+    SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
+
+
+SupportedTargetPlatformType = Union[TargetPlatform, str]
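The new enum gives deployment targets a typed spelling while SupportedTargetPlatformType keeps plain strings working. A small sketch, assuming the module path shown in the diff:

    from snowflake.ml.model import type_hints

    # Either form should be accepted wherever SupportedTargetPlatformType is used.
    wh = type_hints.TargetPlatform.WAREHOUSE
    spcs: type_hints.SupportedTargetPlatformType = "SNOWPARK_CONTAINER_SERVICES"
    print(wh.value, spcs)  # WAREHOUSE SNOWPARK_CONTAINER_SERVICES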
snowflake/ml/modeling/_internal/estimator_utils.py
CHANGED
@@ -275,3 +275,16 @@ def upload_model_to_stage(

     temp_file_utils.cleanup_temp_files([local_transform_file_name])
     return os.path.basename(local_transform_file_name)
+
+
+def should_include_sample_weight(estimator: object, method_name: str) -> bool:
+    # If this is a Grid Search or Randomized Search estimator, check the underlying estimator.
+    underlying_estimator = (
+        estimator.estimator if ("_search" in estimator.__module__ and hasattr(estimator, "estimator")) else estimator
+    )
+    method = getattr(underlying_estimator, method_name)
+    underlying_estimator_params = inspect.signature(method).parameters
+    if "sample_weight" in underlying_estimator_params:
+        return True
+
+    return False
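The helper inspects the fit/score signature of the underlying estimator rather than the search wrapper, whose own signature never advertises sample_weight. A quick check against scikit-learn, assuming the helper is importable from the internal module above:

    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import GridSearchCV
    from snowflake.ml.modeling._internal.estimator_utils import should_include_sample_weight

    search = GridSearchCV(LinearRegression(), {"fit_intercept": [True, False]})
    # GridSearchCV lives in sklearn.model_selection._search, so the helper
    # unwraps it and inspects LinearRegression.fit, which accepts sample_weight.
    print(should_include_sample_weight(search, "fit"))               # True
    print(should_include_sample_weight(LinearRegression(), "fit"))   # True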
snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py
CHANGED
@@ -4,7 +4,10 @@ from typing import Any, List, Optional
 import pandas as pd

 from snowflake.ml._internal.exceptions import error_codes, exceptions
-from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
+from snowflake.ml.modeling._internal.estimator_utils import (
+    handle_inference_result,
+    should_include_sample_weight,
+)


 class PandasTransformHandlers:
@@ -166,6 +169,7 @@ class PandasTransformHandlers:
             SnowflakeMLException: The input column list does not have one of `X` and `X_test`.
         """
         assert hasattr(self.estimator, "score")  # make type checker happy
+
         params = inspect.signature(self.estimator.score).parameters
         if "X" in params:
             score_args = {"X": self.dataset[input_cols]}
@@ -181,7 +185,8 @@ class PandasTransformHandlers:
             label_arg_name = "Y" if "Y" in params else "y"
             score_args[label_arg_name] = self.dataset[label_cols].squeeze()

-        if sample_weight_col is not None and "sample_weight" in params:
+        # Sample weight is not included in search estimators parameters, check the underlying estimator.
+        if sample_weight_col is not None and should_include_sample_weight(self.estimator, "score"):
             score_args["sample_weight"] = self.dataset[sample_weight_col].squeeze()

         score = self.estimator.score(**score_args)
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py
CHANGED
@@ -19,6 +19,7 @@ from snowflake.ml._internal.utils import (
     snowpark_dataframe_utils,
     temp_file_utils,
 )
+from snowflake.ml.modeling._internal.estimator_utils import should_include_sample_weight
 from snowflake.ml.modeling._internal.model_specifications import (
     ModelSpecificationsBuilder,
 )
@@ -38,6 +39,7 @@ from snowflake.snowpark.udtf import UDTFRegistration
 cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))
+cp.register_pickle_by_value(inspect.getmodule(should_include_sample_weight))

 _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
@@ -393,7 +395,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         import pandas as pd
         import pyarrow.parquet as pq
         from sklearn.metrics import check_scoring
-        from sklearn.metrics._scorer import _check_multimetric_scoring
+        from sklearn.metrics._scorer import (
+            _check_multimetric_scoring,
+            _MultimetricScorer,
+        )

         for import_name in udf_imports:
             importlib.import_module(import_name)
@@ -606,6 +611,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
             estimator._check_refit_for_multimetric(scorers)
             refit_metric = original_refit
+            scorers = _MultimetricScorer(scorers=scorers)

         estimator.scorer_ = scorers

@@ -638,7 +644,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         if label_cols:
             label_arg_name = "Y" if "Y" in argspec.args else "y"
             args[label_arg_name] = y
-        if sample_weight_col is not None and "sample_weight" in argspec.args:
+        if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
             args["sample_weight"] = df[sample_weight_col].squeeze()
         estimator.refit = original_refit
         refit_start_time = time.time()
@@ -797,8 +803,11 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         import pandas as pd
         import pyarrow.parquet as pq
         from sklearn.metrics import check_scoring
-        from sklearn.metrics._scorer import _check_multimetric_scoring
-        from sklearn.utils.validation import _check_fit_params, indexable
+        from sklearn.metrics._scorer import (
+            _check_multimetric_scoring,
+            _MultimetricScorer,
+        )
+        from sklearn.utils.validation import _check_method_params, indexable

         # import packages in sproc
         for import_name in udf_imports:
@@ -846,11 +855,12 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
         estimator._check_refit_for_multimetric(scorers)
         refit_metric = estimator.refit
+        scorers = _MultimetricScorer(scorers=scorers)

         # preprocess the attributes - (2) check fit_params
         groups = None
         X, y, _ = indexable(X, y, groups)
-        fit_params = _check_fit_params(X, fit_params)
+        fit_params = _check_method_params(X, fit_params)

         # preprocess the attributes - (3) safe clone base estimator
         base_estimator = clone(estimator.estimator)
@@ -863,6 +873,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         fit_and_score_kwargs = dict(
             scorer=scorers,
             fit_params=fit_params,
+            score_params=None,
             return_train_score=estimator.return_train_score,
             return_n_test_samples=True,
             return_times=True,
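These trainer hunks track private scikit-learn APIs that moved in recent releases: the dict from _check_multimetric_scoring now gets wrapped in _MultimetricScorer, _check_fit_params became _check_method_params, and _fit_and_score grew a score_params argument. A minimal sketch of the scorer wrapping (private APIs; assumes scikit-learn >= 1.4):

    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics._scorer import _check_multimetric_scoring, _MultimetricScorer

    estimator = LogisticRegression()
    scorers = _check_multimetric_scoring(estimator, ["accuracy", "f1_macro"])  # dict of scorers
    # Newer sklearn expects one callable scorer object here, not a raw dict.
    scorer = _MultimetricScorer(scorers=scorers)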
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py
CHANGED
@@ -18,7 +18,10 @@ from snowflake.ml._internal.utils import (
 )
 from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
 from snowflake.ml.modeling._internal import estimator_utils
-from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
+from snowflake.ml.modeling._internal.estimator_utils import (
+    handle_inference_result,
+    should_include_sample_weight,
+)
 from snowflake.snowpark import DataFrame, Session, functions as F, types as T
 from snowflake.snowpark._internal.utils import (
     TempObjectType,
@@ -28,6 +31,8 @@ from snowflake.snowpark._internal.utils import (
 cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
+cp.register_pickle_by_value(inspect.getmodule(should_include_sample_weight))
+

 _PROJECT = "ModelDevelopment"
@@ -330,7 +335,8 @@ class SnowparkTransformHandlers:
             label_arg_name = "Y" if "Y" in params else "y"
             args[label_arg_name] = df[label_cols].squeeze()

-            if sample_weight_col is not None and "sample_weight" in params:
+            # Sample weight is not included in search estimators parameters, check the underlying estimator.
+            if sample_weight_col is not None and should_include_sample_weight(estimator, "score"):
                 args["sample_weight"] = df[sample_weight_col].squeeze()

             result: float = estimator.score(**args)
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py
CHANGED
@@ -20,7 +20,10 @@ from snowflake.ml._internal.utils import (
     temp_file_utils,
 )
 from snowflake.ml.modeling._internal import estimator_utils
-from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
+from snowflake.ml.modeling._internal.estimator_utils import (
+    handle_inference_result,
+    should_include_sample_weight,
+)
 from snowflake.ml.modeling._internal.model_specifications import (
     ModelSpecifications,
     ModelSpecificationsBuilder,
@@ -32,6 +35,7 @@ from snowflake.snowpark.stored_procedure import StoredProcedure
 cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
+cp.register_pickle_by_value(inspect.getmodule(should_include_sample_weight))

 _PROJECT = "ModelDevelopment"
 _ENABLE_ANONYMOUS_SPROC = False
@@ -170,12 +174,14 @@ class SnowparkModelTrainer:
             estimator = cp.load(local_transform_file_obj)

             params = inspect.signature(estimator.fit).parameters
+
             args = {"X": df[input_cols]}
             if label_cols:
                 label_arg_name = "Y" if "Y" in params else "y"
                 args[label_arg_name] = df[label_cols].squeeze()

-            if sample_weight_col is not None and "sample_weight" in params:
+            # Sample weight is not included in search estimators parameters, check the underlying estimator.
+            if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
                 args["sample_weight"] = df[sample_weight_col].squeeze()

             estimator.fit(**args)
@@ -412,7 +418,7 @@ class SnowparkModelTrainer:
             label_arg_name = "Y" if "Y" in params else "y"
             args[label_arg_name] = df[label_cols].squeeze()

-            if sample_weight_col is not None and "sample_weight" in params:
+            if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
                 args["sample_weight"] = df[sample_weight_col].squeeze()

             fit_transform_result = estimator.fit_transform(**args)
snowflake/ml/modeling/calibration/calibrated_classifier_cv.py
CHANGED
@@ -167,9 +167,6 @@ class CalibratedClassifierCV(BaseTransformer):
         `estimator` trained on all the data.
         Note that this method is also internally implemented in
         :mod:`sklearn.svm` estimators with the `probabilities=True` parameter.
-
-    base_estimator: estimator instance
-        This parameter is deprecated. Use `estimator` instead.
     """

     def __init__(  # type: ignore[no-untyped-def]
@@ -180,7 +177,6 @@ class CalibratedClassifierCV(BaseTransformer):
         cv=None,
         n_jobs=None,
         ensemble=True,
-        base_estimator="deprecated",
         input_cols: Optional[Union[str, Iterable[str]]] = None,
         output_cols: Optional[Union[str, Iterable[str]]] = None,
         label_cols: Optional[Union[str, Iterable[str]]] = None,
@@ -200,16 +196,13 @@ class CalibratedClassifierCV(BaseTransformer):
         self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
-        deps = deps | gather_dependencies(base_estimator)
         self._deps = list(deps)
         estimator = transform_snowml_obj_to_sklearn_obj(estimator)
-        base_estimator = transform_snowml_obj_to_sklearn_obj(base_estimator)
         init_args = {'estimator':(estimator, None, False),
                      'method':(method, "sigmoid", False),
                      'cv':(cv, None, False),
                      'n_jobs':(n_jobs, None, False),
-                     'ensemble':(ensemble, True, False),
-                     'base_estimator':(base_estimator, "deprecated", False),}
+                     'ensemble':(ensemble, True, False),}
         cleaned_up_init_args = validate_sklearn_args(
             args=init_args,
             klass=sklearn.calibration.CalibratedClassifierCV
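This mirrors upstream scikit-learn, which dropped the long-deprecated base_estimator alias (removed in scikit-learn 1.4); only estimator remains. Migration sketch against the sklearn class the wrapper validates with:

    from sklearn.calibration import CalibratedClassifierCV
    from sklearn.svm import LinearSVC

    # Before: CalibratedClassifierCV(base_estimator=LinearSVC(), ...)
    # Now: pass the classifier via estimator=.
    clf = CalibratedClassifierCV(estimator=LinearSVC(), method="sigmoid", cv=3)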
snowflake/ml/modeling/cluster/agglomerative_clustering.py
CHANGED
@@ -113,28 +113,18 @@ class AgglomerativeClustering(BaseTransformer):
         The number of clusters to find. It must be ``None`` if
         ``distance_threshold`` is not ``None``.

-    affinity: str or callable, default='euclidean'
-        The metric to use when calculating distance between instances in a
-        feature array. If metric is a string or callable, it must be one of
-        the options allowed by :func:`sklearn.metrics.pairwise_distances` for
-        its metric parameter.
-        If linkage is "ward", only "euclidean" is accepted.
-        If "precomputed", a distance matrix (instead of a similarity matrix)
-        is needed as input for the fit method.
-
-    metric: str or callable, default=None
+    metric: str or callable, default="euclidean"
         Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
-        "manhattan", "cosine", or "precomputed". If set to `None` then
-        "euclidean" is used. If linkage is "ward", only "euclidean" is
-        accepted. If "precomputed", a distance matrix is needed as input for
-        the fit method.
+        "manhattan", "cosine", or "precomputed". If linkage is "ward", only
+        "euclidean" is accepted. If "precomputed", a distance matrix is needed
+        as input for the fit method.

     memory: str or object with the joblib.Memory interface, default=None
         Used to cache the output of the computation of the tree.
         By default, no caching is done. If a string is given, it is the
         path to the caching directory.

-    connectivity: array-like or callable, default=None
+    connectivity: array-like, sparse matrix, or callable, default=None
         Connectivity matrix. Defines for each sample the neighboring
         samples following a given structure of the data.
         This can be a connectivity matrix itself or a callable that transforms
@@ -142,6 +132,10 @@ class AgglomerativeClustering(BaseTransformer):
         `kneighbors_graph`. Default is ``None``, i.e, the
         hierarchical clustering algorithm is unstructured.

+        For an example of connectivity matrix using
+        :class:`~sklearn.neighbors.kneighbors_graph`, see
+        :ref:`sphx_glr_auto_examples_cluster_plot_agglomerative_clustering.py`.
+
     compute_full_tree: 'auto' or bool, default='auto'
         Stop early the construction of the tree at ``n_clusters``. This is
         useful to decrease computation time if the number of clusters is not
@@ -167,6 +161,9 @@ class AgglomerativeClustering(BaseTransformer):
         - 'single' uses the minimum of the distances between all observations
           of the two sets.

+        For examples comparing different `linkage` criteria, see
+        :ref:`sphx_glr_auto_examples_cluster_plot_linkage_comparison.py`.
+
     distance_threshold: float, default=None
         The linkage distance threshold at or above which clusters will not be
         merged. If not ``None``, ``n_clusters`` must be ``None`` and
@@ -176,14 +173,16 @@ class AgglomerativeClustering(BaseTransformer):
         Computes distances between clusters even if `distance_threshold` is not
         used. This can be used to make dendrogram visualization, but introduces
         a computational and memory overhead.
+
+        For an example of dendrogram visualization, see
+        :ref:`sphx_glr_auto_examples_cluster_plot_agglomerative_dendrogram.py`.
     """

     def __init__(  # type: ignore[no-untyped-def]
         self,
         *,
         n_clusters=2,
-        affinity="deprecated",
-        metric=None,
+        metric="euclidean",
         memory=None,
         connectivity=None,
         compute_full_tree="auto",
@@ -212,8 +211,7 @@ class AgglomerativeClustering(BaseTransformer):
         self._deps = list(deps)

         init_args = {'n_clusters':(n_clusters, 2, False),
-                     'affinity':(affinity, "deprecated", False),
-                     'metric':(metric, None, False),
+                     'metric':(metric, "euclidean", False),
                      'memory':(memory, None, False),
                      'connectivity':(connectivity, None, False),
                      'compute_full_tree':(compute_full_tree, "auto", False),
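The docstring and defaults now match scikit-learn's post-affinity API: the deprecated affinity parameter is gone and metric defaults to "euclidean" instead of None. A small sketch with the upstream class:

    import numpy as np
    from sklearn.cluster import AgglomerativeClustering

    X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
    # metric (formerly affinity) now defaults to "euclidean".
    model = AgglomerativeClustering(n_clusters=2, metric="euclidean", linkage="average")
    print(model.fit_predict(X))  # e.g. [0 0 1 1] (cluster ids may be swapped)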
snowflake/ml/modeling/cluster/dbscan.py
CHANGED
@@ -117,8 +117,11 @@ class DBSCAN(BaseTransformer):
         and distance function.

     min_samples: int, default=5
-        The number of samples (or total weight) in a neighborhood for a point
-        to be considered as a core point. This includes the point itself.
+        The number of samples (or total weight) in a neighborhood for a point to
+        be considered as a core point. This includes the point itself. If
+        `min_samples` is set to a higher value, DBSCAN will find denser clusters,
+        whereas if it is set to a lower value, the found clusters will be more
+        sparse.

     metric: str, or callable, default='euclidean'
         The metric to use when calculating distance between instances in a
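The expanded min_samples wording can be checked directly: raising it makes DBSCAN demand denser neighborhoods, so more points typically end up as noise (label -1). Sketch with upstream scikit-learn:

    import numpy as np
    from sklearn.cluster import DBSCAN

    rng = np.random.default_rng(0)
    X = rng.normal(size=(40, 2))
    loose = DBSCAN(eps=0.6, min_samples=3).fit_predict(X)
    strict = DBSCAN(eps=0.6, min_samples=12).fit_predict(X)
    # Higher min_samples -> denser clusters required -> usually more -1 labels.
    print((loose == -1).sum(), (strict == -1).sum())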
snowflake/ml/modeling/cluster/feature_agglomeration.py
CHANGED
@@ -113,28 +113,18 @@ class FeatureAgglomeration(BaseTransformer):
         The number of clusters to find. It must be ``None`` if
         ``distance_threshold`` is not ``None``.

-    affinity: str or callable, default='euclidean'
-        The metric to use when calculating distance between instances in a
-        feature array. If metric is a string or callable, it must be one of
-        the options allowed by :func:`sklearn.metrics.pairwise_distances` for
-        its metric parameter.
-        If linkage is "ward", only "euclidean" is accepted.
-        If "precomputed", a distance matrix (instead of a similarity matrix)
-        is needed as input for the fit method.
-
-    metric: str or callable, default=None
+    metric: str or callable, default="euclidean"
         Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
-        "manhattan", "cosine", or "precomputed". If set to `None` then
-        "euclidean" is used. If linkage is "ward", only "euclidean" is
-        accepted. If "precomputed", a distance matrix is needed as input for
-        the fit method.
+        "manhattan", "cosine", or "precomputed". If linkage is "ward", only
+        "euclidean" is accepted. If "precomputed", a distance matrix is needed
+        as input for the fit method.

     memory: str or object with the joblib.Memory interface, default=None
         Used to cache the output of the computation of the tree.
         By default, no caching is done. If a string is given, it is the
         path to the caching directory.

-    connectivity: array-like or callable, default=None
+    connectivity: array-like, sparse matrix, or callable, default=None
         Connectivity matrix. Defines for each feature the neighboring
         features following a given structure of the data.
         This can be a connectivity matrix itself or a callable that transforms
@@ -187,8 +177,7 @@ class FeatureAgglomeration(BaseTransformer):
         self,
         *,
         n_clusters=2,
-        affinity="deprecated",
-        metric=None,
+        metric="euclidean",
         memory=None,
         connectivity=None,
         compute_full_tree="auto",
@@ -218,8 +207,7 @@ class FeatureAgglomeration(BaseTransformer):
         self._deps = list(deps)

         init_args = {'n_clusters':(n_clusters, 2, False),
-                     'affinity':(affinity, "deprecated", False),
-                     'metric':(metric, None, False),
+                     'metric':(metric, "euclidean", False),
                      'memory':(memory, None, False),
                      'connectivity':(connectivity, None, False),
                      'compute_full_tree':(compute_full_tree, "auto", False),