snowflake-ml-python 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +4 -2
- snowflake/ml/_internal/utils/import_utils.py +31 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/data/torch_utils.py +33 -14
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
- snowflake/ml/feature_store/examples/example_helper.py +6 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
- snowflake/ml/feature_store/feature_store.py +1 -2
- snowflake/ml/feature_store/feature_view.py +5 -1
- snowflake/ml/model/_client/model/model_version_impl.py +144 -10
- snowflake/ml/model/_client/ops/model_ops.py +25 -6
- snowflake/ml/model/_client/ops/service_ops.py +33 -28
- snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_client/sql/model.py +14 -0
- snowflake/ml/model/_client/sql/service.py +6 -18
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -6
- snowflake/ml/model/_packager/model_handlers/custom.py +2 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +3 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -65
- snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
- snowflake/ml/model/_packager/model_packager.py +0 -11
- snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +13 -25
- snowflake/ml/model/_signatures/pandas_handler.py +16 -0
- snowflake/ml/model/custom_model.py +47 -7
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/type_hints.py +8 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
- snowflake/ml/modeling/cluster/dbscan.py +5 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
- snowflake/ml/modeling/cluster/k_means.py +14 -19
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
- snowflake/ml/modeling/cluster/optics.py +6 -6
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
- snowflake/ml/modeling/compose/column_transformer.py +15 -5
- snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
- snowflake/ml/modeling/decomposition/pca.py +28 -15
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
- snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
- snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +0 -10
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
- snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +3 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
- snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
- snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +3 -3
- snowflake/ml/modeling/manifold/tsne.py +10 -4
- snowflake/ml/modeling/metrics/classification.py +12 -16
- snowflake/ml/modeling/metrics/ranking.py +3 -3
- snowflake/ml/modeling/metrics/regression.py +3 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
- snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +16 -14
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
- snowflake/ml/modeling/svm/linear_svc.py +25 -16
- snowflake/ml/modeling/svm/linear_svr.py +23 -17
- snowflake/ml/modeling/svm/nu_svc.py +5 -3
- snowflake/ml/modeling/svm/nu_svr.py +3 -1
- snowflake/ml/modeling/svm/svc.py +9 -5
- snowflake/ml/modeling/svm/svr.py +3 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
- snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
- snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
- snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
- snowflake/ml/monitoring/_client/{monitor_sql_client.py → model_monitor_sql_client.py} +1 -1
- snowflake/ml/monitoring/{_client → _manager}/model_monitor_manager.py +9 -8
- snowflake/ml/monitoring/{_client/model_monitor.py → model_monitor.py} +3 -3
- snowflake/ml/registry/_manager/model_manager.py +15 -1
- snowflake/ml/registry/registry.py +15 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/METADATA +81 -9
- {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/RECORD +150 -150
- {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/WHEEL +1 -1
- /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
- {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/xgboost.py
CHANGED
@@ -1,7 +1,6 @@
 # mypy: disable-error-code="import"
 import os
 import warnings
-from importlib import metadata as importlib_metadata
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -16,23 +15,19 @@ from typing import (
 
 import numpy as np
 import pandas as pd
-from packaging import version
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
 from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_env import model_env
-from snowflake.ml.model._packager.model_handlers import (
-    _base,
-    _utils as handlers_utils,
-    model_objective_utils,
-)
+from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
 from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
     model_meta_schema,
 )
+from snowflake.ml.model._packager.model_task import model_task_utils
 from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
 
 if TYPE_CHECKING:
@@ -94,23 +89,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.XGBModel"]]):
 
         assert isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel)
 
-        local_xgb_version = None
-
-        try:
-            local_dist = importlib_metadata.distribution("xgboost")
-            local_xgb_version = version.parse(local_dist.version)
-        except importlib_metadata.PackageNotFoundError:
-            pass
-
-        if local_xgb_version and local_xgb_version >= version.parse("2.1.0") and enable_explainability:
-            warnings.warn(
-                f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
-                + "If you want model explanations, lower the xgboost version to <2.1.0.",
-                category=UserWarning,
-                stacklevel=1,
-            )
-            enable_explainability = False
-
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
                 model=model,
@@ -139,7 +117,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.XGBModel"]]):
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
-        model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
+        model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
         model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
         if enable_explainability:
             model_meta = handlers_utils.add_explain_method_signature(
@@ -187,23 +165,15 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.XGBModel"]]):
                 ],
                 check_local_version=True,
             )
-
-
-
-
-
-
-            )
-        else:
-            model_meta.env.include_if_absent(
-                [
-                    model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
-                ],
-                check_local_version=True,
-            )
+        model_meta.env.include_if_absent(
+            [
+                model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
+            ],
+            check_local_version=True,
+        )
 
         if enable_explainability:
-            model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
+            model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
             model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
             model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
|
@@ -61,17 +61,6 @@ class ModelPackager:
|
|
61
61
|
if not options:
|
62
62
|
options = model_types.BaseModelSaveOption()
|
63
63
|
|
64
|
-
# here handling the case of enable_explainability is False/None
|
65
|
-
enable_explainability = options.get("enable_explainability", None)
|
66
|
-
if enable_explainability is False or enable_explainability is None:
|
67
|
-
if (signatures is not None) and (sample_input_data is not None):
|
68
|
-
raise snowml_exceptions.SnowflakeMLException(
|
69
|
-
error_code=error_codes.INVALID_ARGUMENT,
|
70
|
-
original_exception=ValueError(
|
71
|
-
"Signatures and sample_input_data both cannot be specified at the same time."
|
72
|
-
),
|
73
|
-
)
|
74
|
-
|
75
64
|
handler = model_handler.find_handler(model)
|
76
65
|
if handler is None:
|
77
66
|
raise snowml_exceptions.SnowflakeMLException(
|
snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py}
RENAMED
@@ -128,42 +128,30 @@ def get_model_task_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.Task:
     return type_hints.Task.UNKNOWN
 
 
-def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
+def _get_model_task(model: Any) -> type_hints.Task:
     if type_utils.LazyType("xgboost.Booster").isinstance(model) or type_utils.LazyType("xgboost.XGBModel").isinstance(
         model
     ):
-        task = get_model_task_xgb(model)
-        output_type = model_signature.DataType.DOUBLE
-        if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
-            output_type = model_signature.DataType.STRING
-        return ModelTaskAndOutputType(task=task, output_type=output_type)
+        return get_model_task_xgb(model)
 
     if type_utils.LazyType("lightgbm.Booster").isinstance(model) or type_utils.LazyType(
         "lightgbm.LGBMModel"
     ).isinstance(model):
-        task = get_model_task_lightgbm(model)
-        output_type = model_signature.DataType.DOUBLE
-        if task in [
-            type_hints.Task.TABULAR_BINARY_CLASSIFICATION,
-            type_hints.Task.TABULAR_MULTI_CLASSIFICATION,
-        ]:
-            output_type = model_signature.DataType.STRING
-        return ModelTaskAndOutputType(task=task, output_type=output_type)
+        return get_model_task_lightgbm(model)
 
     if type_utils.LazyType("catboost.CatBoost").isinstance(model):
-        task = get_model_task_catboost(model)
-        output_type = model_signature.DataType.DOUBLE
-        if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
-            output_type = model_signature.DataType.STRING
-        return ModelTaskAndOutputType(task=task, output_type=output_type)
+        return get_model_task_catboost(model)
 
     if type_utils.LazyType("sklearn.base.BaseEstimator").isinstance(model) or type_utils.LazyType(
         "sklearn.pipeline.Pipeline"
     ).isinstance(model):
-        task = get_task_skl(model)
-        output_type = model_signature.DataType.DOUBLE
-        if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
-            output_type = model_signature.DataType.STRING
-        return ModelTaskAndOutputType(task=task, output_type=output_type)
-
+        return get_task_skl(model)
     raise ValueError(f"Model type {type(model)} is not supported")
+
+
+def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
+    task = _get_model_task(model)
+    output_type = model_signature.DataType.DOUBLE
+    if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
+        output_type = model_signature.DataType.STRING
+    return ModelTaskAndOutputType(task=task, output_type=output_type)
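The net effect of this hunk is easier to see outside the diff: task detection stays framework-specific, while the task-to-output-type mapping is factored into one shared function. A minimal, self-contained sketch of that shape (simplified stand-ins, not the library source):

```python
# Standalone sketch: the framework branches now return only a task, and one
# shared mapping derives the output type -- STRING labels for multi-class
# classification, DOUBLE for everything else.
from enum import Enum
from typing import NamedTuple


class Task(Enum):
    TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
    TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
    TABULAR_REGRESSION = "TABULAR_REGRESSION"


class DataType(Enum):
    DOUBLE = "DOUBLE"
    STRING = "STRING"


class ModelTaskAndOutputType(NamedTuple):
    task: Task
    output_type: DataType


def get_model_task_and_output_type(task: Task) -> ModelTaskAndOutputType:
    # The single mapping that replaces the four per-framework copies above.
    output_type = DataType.DOUBLE
    if task == Task.TABULAR_MULTI_CLASSIFICATION:
        output_type = DataType.STRING
    return ModelTaskAndOutputType(task=task, output_type=output_type)


print(get_model_task_and_output_type(Task.TABULAR_MULTI_CLASSIFICATION))
# -> output_type = DataType.STRING
```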
snowflake/ml/model/_signatures/pandas_handler.py
CHANGED
@@ -147,6 +147,22 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
                 specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
             elif isinstance(data[df_col].iloc[0], bytes):
                 specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
+            elif isinstance(df_col_dtype, pd.CategoricalDtype):
+                category_dtype = df_col_dtype.categories.dtype
+                if category_dtype == np.dtype("O"):
+                    if isinstance(df_col_dtype.categories[0], str):
+                        specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
+                    elif isinstance(df_col_dtype.categories[0], bytes):
+                        specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
+                    else:
+                        raise snowml_exceptions.SnowflakeMLException(
+                            error_code=error_codes.INVALID_DATA,
+                            original_exception=ValueError(
+                                f"Data Validation Error: Unsupported type confronted in {df_col_dtype.categories[0]}"
+                            ),
+                        )
+                else:
+                    specs.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(category_dtype), name=ft_name))
             elif isinstance(data[df_col].iloc[0], np.datetime64):
                 specs.append(core.FeatureSpec(dtype=core.DataType.TIMESTAMP_NTZ, name=ft_name))
             else:
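The added branch types a categorical column by the dtype of its categories rather than by the `category` dtype itself, falling back to the first category value to split object categories into strings and bytes. A small standalone sketch of that resolution logic (illustrative only, outside the library):

```python
# Sketch of the dtype resolution the new branch performs: a pandas
# categorical column is typed by the dtype of its categories.
import numpy as np
import pandas as pd

df = pd.DataFrame(
    {
        "color": pd.Categorical(["red", "green", "red"]),
        "size": pd.Categorical([1, 2, 3]),
    }
)

for col in df.columns:
    dtype = df[col].dtype
    if isinstance(dtype, pd.CategoricalDtype):
        category_dtype = dtype.categories.dtype
        if category_dtype == np.dtype("O"):
            # Object categories: inspect the first category (str vs bytes).
            kind = "STRING" if isinstance(dtype.categories[0], str) else "BYTES"
        else:
            # Numeric categories map through the numpy dtype.
            kind = str(category_dtype)
        print(col, "->", kind)  # color -> STRING, size -> int64
```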
snowflake/ml/model/custom_model.py
CHANGED
@@ -1,6 +1,6 @@
 import functools
 import inspect
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
 
 import anyio
 import pandas as pd
@@ -104,19 +104,53 @@ class ModelContext:
     def __init__(
         self,
         *,
-        artifacts: Optional[Dict[str, str]] = None,
-        models: Optional[Dict[str, model_types.SupportedModelType]] = None,
+        artifacts: Optional[Union[Dict[str, str], str, model_types.SupportedModelType]] = None,
+        models: Optional[Union[Dict[str, model_types.SupportedModelType], str, model_types.SupportedModelType]] = None,
+        **kwargs: Optional[Union[str, model_types.SupportedModelType]],
     ) -> None:
         """Initialize the model context.
 
         Args:
             artifacts: A dictionary mapping the name of the artifact to its currently available path. Defaults to None.
             models: A dictionary mapping the name of the sub-model to the corresponding model object. Defaults to None.
+            **kwargs: Additional keyword arguments to be used as artifacts or models.
+
+        Raises:
+            ValueError: Raised when the keyword argument is used as artifacts or models.
+            ValueError: Raised when the artifact name is duplicated.
+            ValueError: Raised when the model name is duplicated.
         """
-        self.artifacts: Dict[str, str] = artifacts if artifacts else dict()
-        self.model_refs: Dict[str, ModelRef] = (
-            {name: ModelRef(name, model) for name, model in models.items()} if models else dict()
-        )
+
+        self.artifacts: Dict[str, str] = dict()
+        self.model_refs: Dict[str, ModelRef] = dict()
+
+        # In case that artifacts is a dictionary, assume the original usage,
+        # which is to pass in a dictionary of artifacts.
+        # In other scenarios, (str or supported model types) we will try to parse the arguments as artifacts or models.
+        if isinstance(artifacts, dict):
+            self.artifacts = artifacts
+        elif isinstance(artifacts, str):
+            self.artifacts["artifacts"] = artifacts
+        elif artifacts is not None:
+            self.model_refs["artifacts"] = ModelRef("artifacts", artifacts)
+
+        if isinstance(models, dict):
+            self.model_refs = {name: ModelRef(name, model) for name, model in models.items()} if models else dict()
+        elif isinstance(models, str):
+            self.artifacts["models"] = models
+        elif models is not None:
+            self.model_refs["models"] = ModelRef("models", models)
+
+        # Handle any new arguments passed via kwargs
+        for key, value in kwargs.items():
+            if isinstance(value, str):
+                if key in self.artifacts:
+                    raise ValueError(f"Duplicate artifact name: {key}")
+                self.artifacts[key] = value
+            else:
+                if key in self.model_refs:
+                    raise ValueError(f"Duplicate model name: {key}")
+                self.model_refs[key] = ModelRef(key, value)
 
     def path(self, key: str) -> str:
         """Get the actual path to a specific artifact. This could be used when defining a Custom Model to retrieve
@@ -141,6 +175,12 @@ class ModelContext:
         """
         return self.model_refs[name]
 
+    def __getitem__(self, key: str) -> Union[str, ModelRef]:
+        combined: Dict[str, Union[str, ModelRef]] = {**self.artifacts, **self.model_refs}
+        if key not in combined:
+            raise KeyError(f"Key {key} not found in the kwargs, current available keys are: {combined.keys()}")
+        return combined[key]
+
 
 class CustomModel:
     """Abstract class for user defined custom model.
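Taken together with the relaxed `__init__` above, this lets a `ModelContext` be built from plain keyword arguments and read back by subscript. A hypothetical usage sketch (the artifact name and path are made up):

```python
from snowflake.ml.model import custom_model

# A str kwarg is stored as an artifact path; a supported model object
# would be wrapped in a ModelRef instead.
ctx = custom_model.ModelContext(config="configs/prod.yaml")

print(ctx.path("config"))  # "configs/prod.yaml" -- pre-existing accessor
print(ctx["config"])       # same value via the new subscript access
```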
snowflake/ml/model/model_signature.py
CHANGED
@@ -214,6 +214,8 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec]
         assert isinstance(feature, core.FeatureSpec)  # assert for mypy.
         ft_type = feature._dtype
         ft_shape = feature._shape
+        if isinstance(df_col_dtype, pd.CategoricalDtype):
+            df_col_dtype = df_col_dtype.categories.dtype
         if df_col_dtype != np.dtype("O"):
             if not _validate_numpy_array(data_col.to_numpy(), ft_type, strict=strict):
                 raise snowml_exceptions.SnowflakeMLException(
snowflake/ml/model/type_hints.py
CHANGED
@@ -298,3 +298,11 @@ class Task(Enum):
     TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
     TABULAR_REGRESSION = "TABULAR_REGRESSION"
     TABULAR_RANKING = "TABULAR_RANKING"
+
+
+class TargetPlatform(Enum):
+    WAREHOUSE = "WAREHOUSE"
+    SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
+
+
+SupportedTargetPlatformType = Union[TargetPlatform, str]
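Since `SupportedTargetPlatformType` accepts either the enum or its string name, callers may want to normalize early. A sketch with a hypothetical helper (`to_target_platform` is not part of the package):

```python
from snowflake.ml.model.type_hints import TargetPlatform


def to_target_platform(value) -> TargetPlatform:
    # Accept either the enum member or its string value, per the Union above.
    return value if isinstance(value, TargetPlatform) else TargetPlatform(value)


print(to_target_platform("WAREHOUSE"))                                 # TargetPlatform.WAREHOUSE
print(to_target_platform(TargetPlatform.SNOWPARK_CONTAINER_SERVICES))  # TargetPlatform.SNOWPARK_CONTAINER_SERVICES
```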
snowflake/ml/modeling/_internal/estimator_utils.py
CHANGED
@@ -275,3 +275,16 @@ def upload_model_to_stage(
 
     temp_file_utils.cleanup_temp_files([local_transform_file_name])
     return os.path.basename(local_transform_file_name)
+
+
+def should_include_sample_weight(estimator: object, method_name: str) -> bool:
+    # If this is a Grid Search or Randomized Search estimator, check the underlying estimator.
+    underlying_estimator = (
+        estimator.estimator if ("_search" in estimator.__module__ and hasattr(estimator, "estimator")) else estimator
+    )
+    method = getattr(underlying_estimator, method_name)
+    underlying_estimator_params = inspect.signature(method).parameters
+    if "sample_weight" in underlying_estimator_params:
+        return True
+
+    return False
snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py
CHANGED
@@ -4,7 +4,10 @@ from typing import Any, List, Optional
 import pandas as pd
 
 from snowflake.ml._internal.exceptions import error_codes, exceptions
-from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
+from snowflake.ml.modeling._internal.estimator_utils import (
+    handle_inference_result,
+    should_include_sample_weight,
+)
 
 
 class PandasTransformHandlers:
@@ -166,6 +169,7 @@ class PandasTransformHandlers:
             SnowflakeMLException: The input column list does not have one of `X` and `X_test`.
         """
         assert hasattr(self.estimator, "score")  # make type checker happy
+
         params = inspect.signature(self.estimator.score).parameters
         if "X" in params:
             score_args = {"X": self.dataset[input_cols]}
@@ -181,7 +185,8 @@ class PandasTransformHandlers:
             label_arg_name = "Y" if "Y" in params else "y"
             score_args[label_arg_name] = self.dataset[label_cols].squeeze()
 
-        if sample_weight_col is not None and "sample_weight" in params:
+        # Sample weight is not included in search estimators parameters, check the underlying estimator.
+        if sample_weight_col is not None and should_include_sample_weight(self.estimator, "score"):
             score_args["sample_weight"] = self.dataset[sample_weight_col].squeeze()
 
         score = self.estimator.score(**score_args)
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py
CHANGED
@@ -19,6 +19,7 @@ from snowflake.ml._internal.utils import (
     snowpark_dataframe_utils,
     temp_file_utils,
 )
+from snowflake.ml.modeling._internal.estimator_utils import should_include_sample_weight
 from snowflake.ml.modeling._internal.model_specifications import (
     ModelSpecificationsBuilder,
 )
@@ -38,6 +39,7 @@ from snowflake.snowpark.udtf import UDTFRegistration
 cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))
+cp.register_pickle_by_value(inspect.getmodule(should_include_sample_weight))
 
 _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
@@ -393,7 +395,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             import pandas as pd
             import pyarrow.parquet as pq
             from sklearn.metrics import check_scoring
-            from sklearn.metrics._scorer import _check_multimetric_scoring
+            from sklearn.metrics._scorer import (
+                _check_multimetric_scoring,
+                _MultimetricScorer,
+            )
 
             for import_name in udf_imports:
                 importlib.import_module(import_name)
@@ -606,6 +611,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
             estimator._check_refit_for_multimetric(scorers)
             refit_metric = original_refit
+            scorers = _MultimetricScorer(scorers=scorers)
 
         estimator.scorer_ = scorers
 
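For context on the added wrapping: on recent scikit-learn releases, the dict returned by `_check_multimetric_scoring` is wrapped in `_MultimetricScorer` to obtain a single callable. A sketch using those private scikit-learn APIs, so it is version-sensitive:

```python
# _MultimetricScorer turns a dict of scorers into one callable that
# evaluates all metrics in a single pass: scorer(estimator, X, y) -> dict.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics._scorer import _check_multimetric_scoring, _MultimetricScorer

X, y = make_classification(random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X, y)

scorers = _check_multimetric_scoring(clf, ["accuracy", "f1"])
scorer = _MultimetricScorer(scorers=scorers)
print(scorer(clf, X, y))  # e.g. {"accuracy": ..., "f1": ...}
```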
@@ -638,7 +644,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             if label_cols:
                 label_arg_name = "Y" if "Y" in argspec.args else "y"
                 args[label_arg_name] = y
-            if sample_weight_col is not None and "sample_weight" in argspec.args:
+            if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
                 args["sample_weight"] = df[sample_weight_col].squeeze()
             estimator.refit = original_refit
             refit_start_time = time.time()
@@ -797,8 +803,11 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         import pandas as pd
         import pyarrow.parquet as pq
         from sklearn.metrics import check_scoring
-        from sklearn.metrics._scorer import _check_multimetric_scoring
-        from sklearn.utils.validation import _check_fit_params, indexable
+        from sklearn.metrics._scorer import (
+            _check_multimetric_scoring,
+            _MultimetricScorer,
+        )
+        from sklearn.utils.validation import _check_method_params, indexable
 
         # import packages in sproc
         for import_name in udf_imports:
@@ -846,11 +855,12 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
             estimator._check_refit_for_multimetric(scorers)
             refit_metric = estimator.refit
+            scorers = _MultimetricScorer(scorers=scorers)
 
         # preprocess the attributes - (2) check fit_params
         groups = None
         X, y, _ = indexable(X, y, groups)
-        fit_params = _check_fit_params(X, fit_params)
+        fit_params = _check_method_params(X, fit_params)
 
         # preprocess the attributes - (3) safe clone base estimator
         base_estimator = clone(estimator.estimator)
@@ -863,6 +873,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         fit_and_score_kwargs = dict(
             scorer=scorers,
             fit_params=fit_params,
+            score_params=None,
             return_train_score=estimator.return_train_score,
             return_n_test_samples=True,
             return_times=True,
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py
CHANGED
@@ -18,7 +18,10 @@ from snowflake.ml._internal.utils import (
 )
 from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
 from snowflake.ml.modeling._internal import estimator_utils
-from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
+from snowflake.ml.modeling._internal.estimator_utils import (
+    handle_inference_result,
+    should_include_sample_weight,
+)
 from snowflake.snowpark import DataFrame, Session, functions as F, types as T
 from snowflake.snowpark._internal.utils import (
     TempObjectType,
@@ -28,6 +31,8 @@ from snowflake.snowpark._internal.utils import (
 cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
+cp.register_pickle_by_value(inspect.getmodule(should_include_sample_weight))
+
 
 _PROJECT = "ModelDevelopment"
 
@@ -330,7 +335,8 @@ class SnowparkTransformHandlers:
             label_arg_name = "Y" if "Y" in params else "y"
             args[label_arg_name] = df[label_cols].squeeze()
 
-        if sample_weight_col is not None and "sample_weight" in params:
+        # Sample weight is not included in search estimators parameters, check the underlying estimator.
+        if sample_weight_col is not None and should_include_sample_weight(estimator, "score"):
             args["sample_weight"] = df[sample_weight_col].squeeze()
 
         result: float = estimator.score(**args)
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py
CHANGED
@@ -20,7 +20,10 @@ from snowflake.ml._internal.utils import (
     temp_file_utils,
 )
 from snowflake.ml.modeling._internal import estimator_utils
-from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
+from snowflake.ml.modeling._internal.estimator_utils import (
+    handle_inference_result,
+    should_include_sample_weight,
+)
 from snowflake.ml.modeling._internal.model_specifications import (
     ModelSpecifications,
     ModelSpecificationsBuilder,
@@ -32,6 +35,7 @@ from snowflake.snowpark.stored_procedure import StoredProcedure
 cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
+cp.register_pickle_by_value(inspect.getmodule(should_include_sample_weight))
 
 _PROJECT = "ModelDevelopment"
 _ENABLE_ANONYMOUS_SPROC = False
@@ -170,12 +174,14 @@ class SnowparkModelTrainer:
             estimator = cp.load(local_transform_file_obj)
 
             params = inspect.signature(estimator.fit).parameters
+
             args = {"X": df[input_cols]}
             if label_cols:
                 label_arg_name = "Y" if "Y" in params else "y"
                 args[label_arg_name] = df[label_cols].squeeze()
 
-            if sample_weight_col is not None and "sample_weight" in params:
+            # Sample weight is not included in search estimators parameters, check the underlying estimator.
+            if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
                 args["sample_weight"] = df[sample_weight_col].squeeze()
 
             estimator.fit(**args)
@@ -412,7 +418,7 @@ class SnowparkModelTrainer:
                 label_arg_name = "Y" if "Y" in params else "y"
                 args[label_arg_name] = df[label_cols].squeeze()
 
-            if sample_weight_col is not None and "sample_weight" in params:
+            if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
                 args["sample_weight"] = df[sample_weight_col].squeeze()
 
             fit_transform_result = estimator.fit_transform(**args)
snowflake/ml/modeling/calibration/calibrated_classifier_cv.py
CHANGED
@@ -167,9 +167,6 @@ class CalibratedClassifierCV(BaseTransformer):
         `estimator` trained on all the data.
         Note that this method is also internally implemented in
         :mod:`sklearn.svm` estimators with the `probabilities=True` parameter.
-
-    base_estimator: estimator instance
-        This parameter is deprecated. Use `estimator` instead.
     """
 
     def __init__(  # type: ignore[no-untyped-def]
@@ -180,7 +177,6 @@ class CalibratedClassifierCV(BaseTransformer):
         cv=None,
         n_jobs=None,
         ensemble=True,
-        base_estimator="deprecated",
         input_cols: Optional[Union[str, Iterable[str]]] = None,
         output_cols: Optional[Union[str, Iterable[str]]] = None,
         label_cols: Optional[Union[str, Iterable[str]]] = None,
@@ -200,16 +196,13 @@ class CalibratedClassifierCV(BaseTransformer):
         self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
-        deps = deps | gather_dependencies(base_estimator)
         self._deps = list(deps)
         estimator = transform_snowml_obj_to_sklearn_obj(estimator)
-        base_estimator = transform_snowml_obj_to_sklearn_obj(base_estimator)
         init_args = {'estimator':(estimator, None, False),
                      'method':(method, "sigmoid", False),
                      'cv':(cv, None, False),
                      'n_jobs':(n_jobs, None, False),
-                     'ensemble':(ensemble, True, False),
-                     'base_estimator':(base_estimator, "deprecated", False),}
+                     'ensemble':(ensemble, True, False),}
         cleaned_up_init_args = validate_sklearn_args(
             args=init_args,
             klass=sklearn.calibration.CalibratedClassifierCV
snowflake/ml/modeling/cluster/agglomerative_clustering.py
CHANGED
@@ -113,28 +113,18 @@ class AgglomerativeClustering(BaseTransformer):
         The number of clusters to find. It must be ``None`` if
         ``distance_threshold`` is not ``None``.
 
-    affinity: str or callable, default='euclidean'
-        The metric to use when calculating distance between instances in a
-        feature array. If metric is a string or callable, it must be one of
-        the options allowed by :func:`sklearn.metrics.pairwise_distances` for
-        its metric parameter.
-        If linkage is "ward", only "euclidean" is accepted.
-        If "precomputed", a distance matrix (instead of a similarity matrix)
-        is needed as input for the fit method.
-
-    metric: str or callable, default=None
+    metric: str or callable, default="euclidean"
         Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
-        "manhattan", "cosine", or "precomputed". If set to `None`, then the
-        default is "euclidean". If linkage is "ward", only "euclidean" is
-        accepted. If "precomputed", a distance matrix is needed as input for
-        the fit method.
+        "manhattan", "cosine", or "precomputed". If linkage is "ward", only
+        "euclidean" is accepted. If "precomputed", a distance matrix is needed
+        as input for the fit method.
 
     memory: str or object with the joblib.Memory interface, default=None
         Used to cache the output of the computation of the tree.
         By default, no caching is done. If a string is given, it is the
         path to the caching directory.
 
-    connectivity: array-like or callable, default=None
+    connectivity: array-like, sparse matrix, or callable, default=None
         Connectivity matrix. Defines for each sample the neighboring
         samples following a given structure of the data.
         This can be a connectivity matrix itself or a callable that transforms
@@ -142,6 +132,10 @@ class AgglomerativeClustering(BaseTransformer):
         `kneighbors_graph`. Default is ``None``, i.e, the
         hierarchical clustering algorithm is unstructured.
 
+        For an example of connectivity matrix using
+        :class:`~sklearn.neighbors.kneighbors_graph`, see
+        :ref:`sphx_glr_auto_examples_cluster_plot_agglomerative_clustering.py`.
+
     compute_full_tree: 'auto' or bool, default='auto'
         Stop early the construction of the tree at ``n_clusters``. This is
         useful to decrease computation time if the number of clusters is not
@@ -167,6 +161,9 @@ class AgglomerativeClustering(BaseTransformer):
         - 'single' uses the minimum of the distances between all observations
           of the two sets.
 
+        For examples comparing different `linkage` criteria, see
+        :ref:`sphx_glr_auto_examples_cluster_plot_linkage_comparison.py`.
+
     distance_threshold: float, default=None
         The linkage distance threshold at or above which clusters will not be
         merged. If not ``None``, ``n_clusters`` must be ``None`` and
@@ -176,14 +173,16 @@ class AgglomerativeClustering(BaseTransformer):
         Computes distances between clusters even if `distance_threshold` is not
         used. This can be used to make dendrogram visualization, but introduces
         a computational and memory overhead.
+
+        For an example of dendrogram visualization, see
+        :ref:`sphx_glr_auto_examples_cluster_plot_agglomerative_dendrogram.py`.
     """
 
     def __init__(  # type: ignore[no-untyped-def]
         self,
         *,
         n_clusters=2,
-        affinity="deprecated",
-        metric=None,
+        metric="euclidean",
         memory=None,
         connectivity=None,
         compute_full_tree="auto",
@@ -212,8 +211,7 @@ class AgglomerativeClustering(BaseTransformer):
         self._deps = list(deps)
 
         init_args = {'n_clusters':(n_clusters, 2, False),
-                     'affinity':(affinity, "deprecated", False),
-                     'metric':(metric, None, False),
+                     'metric':(metric, "euclidean", False),
                      'memory':(memory, None, False),
                      'connectivity':(connectivity, None, False),
                      'compute_full_tree':(compute_full_tree, "auto", False),
snowflake/ml/modeling/cluster/dbscan.py
CHANGED
@@ -117,8 +117,11 @@ class DBSCAN(BaseTransformer):
         and distance function.
 
     min_samples: int, default=5
-        The number of samples (or total weight) in a neighborhood for a point
-        to be considered as a core point. This includes the point itself.
+        The number of samples (or total weight) in a neighborhood for a point to
+        be considered as a core point. This includes the point itself. If
+        `min_samples` is set to a higher value, DBSCAN will find denser clusters,
+        whereas if it is set to a lower value, the found clusters will be more
+        sparse.
 
     metric: str, or callable, default='euclidean'
         The metric to use when calculating distance between instances in a