snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class Birch(BaseTransformer):
|
71
64
|
r"""Implements the BIRCH clustering algorithm
|
72
65
|
For more details on this class, see [sklearn.cluster.Birch]
|
@@ -233,12 +226,7 @@ class Birch(BaseTransformer):
|
|
233
226
|
)
|
234
227
|
return selected_cols
|
235
228
|
|
236
|
-
|
237
|
-
project=_PROJECT,
|
238
|
-
subproject=_SUBPROJECT,
|
239
|
-
custom_tags=dict([("autogen", True)]),
|
240
|
-
)
|
241
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "Birch":
|
229
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "Birch":
|
242
230
|
"""Build a CF Tree for the input data
|
243
231
|
For more details on this function, see [sklearn.cluster.Birch.fit]
|
244
232
|
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html#sklearn.cluster.Birch.fit)
|
@@ -265,12 +253,14 @@ class Birch(BaseTransformer):
|
|
265
253
|
|
266
254
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
267
255
|
|
268
|
-
|
256
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
269
257
|
if SNOWML_SPROC_ENV in os.environ:
|
270
258
|
statement_params = telemetry.get_function_usage_statement_params(
|
271
259
|
project=_PROJECT,
|
272
260
|
subproject=_SUBPROJECT,
|
273
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
261
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
262
|
+
inspect.currentframe(), Birch.__class__.__name__
|
263
|
+
),
|
274
264
|
api_calls=[Session.call],
|
275
265
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
276
266
|
)
|
@@ -291,27 +281,24 @@ class Birch(BaseTransformer):
|
|
291
281
|
)
|
292
282
|
self._sklearn_object = model_trainer.train()
|
293
283
|
self._is_fitted = True
|
294
|
-
self.
|
284
|
+
self._generate_model_signatures(dataset)
|
295
285
|
return self
|
296
286
|
|
297
287
|
def _batch_inference_validate_snowpark(
|
298
288
|
self,
|
299
289
|
dataset: DataFrame,
|
300
290
|
inference_method: str,
|
301
|
-
) ->
|
302
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
303
|
-
return the available package that exists in the snowflake anaconda channel
|
291
|
+
) -> None:
|
292
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
304
293
|
|
305
294
|
Args:
|
306
295
|
dataset: snowpark dataframe
|
307
296
|
inference_method: the inference method such as predict, score...
|
308
|
-
|
297
|
+
|
309
298
|
Raises:
|
310
299
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
311
300
|
SnowflakeMLException: If the session is None, raise error
|
312
301
|
|
313
|
-
Returns:
|
314
|
-
A list of available package that exists in the snowflake anaconda channel
|
315
302
|
"""
|
316
303
|
if not self._is_fitted:
|
317
304
|
raise exceptions.SnowflakeMLException(
|
@@ -329,9 +316,7 @@ class Birch(BaseTransformer):
|
|
329
316
|
"Session must not specified for snowpark dataset."
|
330
317
|
),
|
331
318
|
)
|
332
|
-
|
333
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
334
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
319
|
+
|
335
320
|
|
336
321
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
337
322
|
@telemetry.send_api_usage_telemetry(
|
@@ -367,7 +352,9 @@ class Birch(BaseTransformer):
|
|
367
352
|
# when it is classifier, infer the datatype from label columns
|
368
353
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
369
354
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
370
|
-
label_cols_signatures = [
|
355
|
+
label_cols_signatures = [
|
356
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
357
|
+
]
|
371
358
|
if len(label_cols_signatures) == 0:
|
372
359
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
373
360
|
raise exceptions.SnowflakeMLException(
|
@@ -375,25 +362,23 @@ class Birch(BaseTransformer):
|
|
375
362
|
original_exception=ValueError(error_str),
|
376
363
|
)
|
377
364
|
|
378
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
379
|
-
label_cols_signatures[0].as_snowpark_type()
|
380
|
-
)
|
365
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
381
366
|
|
382
|
-
self.
|
383
|
-
|
367
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
368
|
+
self._deps = self._get_dependencies()
|
369
|
+
assert isinstance(
|
370
|
+
dataset._session, Session
|
371
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
384
372
|
|
385
373
|
transform_kwargs = dict(
|
386
|
-
session
|
387
|
-
dependencies
|
388
|
-
drop_input_cols
|
389
|
-
expected_output_cols_type
|
374
|
+
session=dataset._session,
|
375
|
+
dependencies=self._deps,
|
376
|
+
drop_input_cols=self._drop_input_cols,
|
377
|
+
expected_output_cols_type=expected_type_inferred,
|
390
378
|
)
|
391
379
|
|
392
380
|
elif isinstance(dataset, pd.DataFrame):
|
393
|
-
transform_kwargs = dict(
|
394
|
-
snowpark_input_cols = self._snowpark_cols,
|
395
|
-
drop_input_cols = self._drop_input_cols
|
396
|
-
)
|
381
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
397
382
|
|
398
383
|
transform_handlers = ModelTransformerBuilder.build(
|
399
384
|
dataset=dataset,
|
@@ -435,7 +420,7 @@ class Birch(BaseTransformer):
|
|
435
420
|
Transformed dataset.
|
436
421
|
"""
|
437
422
|
super()._check_dataset_type(dataset)
|
438
|
-
inference_method="transform"
|
423
|
+
inference_method = "transform"
|
439
424
|
|
440
425
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
441
426
|
# are specific to the type of dataset used.
|
@@ -465,24 +450,19 @@ class Birch(BaseTransformer):
|
|
465
450
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
466
451
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
467
452
|
|
468
|
-
self.
|
469
|
-
|
470
|
-
inference_method=inference_method,
|
471
|
-
)
|
453
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
454
|
+
self._deps = self._get_dependencies()
|
472
455
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
473
456
|
|
474
457
|
transform_kwargs = dict(
|
475
|
-
session
|
476
|
-
dependencies
|
477
|
-
drop_input_cols
|
478
|
-
expected_output_cols_type
|
458
|
+
session=dataset._session,
|
459
|
+
dependencies=self._deps,
|
460
|
+
drop_input_cols=self._drop_input_cols,
|
461
|
+
expected_output_cols_type=expected_dtype,
|
479
462
|
)
|
480
463
|
|
481
464
|
elif isinstance(dataset, pd.DataFrame):
|
482
|
-
transform_kwargs = dict(
|
483
|
-
snowpark_input_cols = self._snowpark_cols,
|
484
|
-
drop_input_cols = self._drop_input_cols
|
485
|
-
)
|
465
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
486
466
|
|
487
467
|
transform_handlers = ModelTransformerBuilder.build(
|
488
468
|
dataset=dataset,
|
@@ -501,7 +481,11 @@ class Birch(BaseTransformer):
|
|
501
481
|
return output_df
|
502
482
|
|
503
483
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
504
|
-
def fit_predict(
|
484
|
+
def fit_predict(
|
485
|
+
self,
|
486
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
487
|
+
output_cols_prefix: str = "fit_predict_",
|
488
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
505
489
|
""" Perform clustering on `X` and returns cluster labels
|
506
490
|
For more details on this function, see [sklearn.cluster.Birch.fit_predict]
|
507
491
|
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html#sklearn.cluster.Birch.fit_predict)
|
@@ -528,22 +512,106 @@ class Birch(BaseTransformer):
|
|
528
512
|
)
|
529
513
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
530
514
|
drop_input_cols=self._drop_input_cols,
|
531
|
-
expected_output_cols_list=
|
515
|
+
expected_output_cols_list=(
|
516
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
517
|
+
),
|
532
518
|
)
|
533
519
|
self._sklearn_object = fitted_estimator
|
534
520
|
self._is_fitted = True
|
535
521
|
return output_result
|
536
522
|
|
523
|
+
|
524
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
525
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
526
|
+
""" Fit to data, then transform it
|
527
|
+
For more details on this function, see [sklearn.cluster.Birch.fit_transform]
|
528
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html#sklearn.cluster.Birch.fit_transform)
|
529
|
+
|
530
|
+
|
531
|
+
Raises:
|
532
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
537
533
|
|
538
|
-
|
539
|
-
|
540
|
-
|
534
|
+
Args:
|
535
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
536
|
+
Snowpark or Pandas DataFrame.
|
537
|
+
output_cols_prefix: Prefix for the response columns
|
541
538
|
Returns:
|
542
539
|
Transformed dataset.
|
543
540
|
"""
|
544
|
-
self.
|
545
|
-
|
546
|
-
|
541
|
+
self._infer_input_output_cols(dataset)
|
542
|
+
super()._check_dataset_type(dataset)
|
543
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
544
|
+
estimator=self._sklearn_object,
|
545
|
+
dataset=dataset,
|
546
|
+
input_cols=self.input_cols,
|
547
|
+
label_cols=self.label_cols,
|
548
|
+
sample_weight_col=self.sample_weight_col,
|
549
|
+
autogenerated=self._autogenerated,
|
550
|
+
subproject=_SUBPROJECT,
|
551
|
+
)
|
552
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
553
|
+
drop_input_cols=self._drop_input_cols,
|
554
|
+
expected_output_cols_list=self.output_cols,
|
555
|
+
)
|
556
|
+
self._sklearn_object = fitted_estimator
|
557
|
+
self._is_fitted = True
|
558
|
+
return output_result
|
559
|
+
|
560
|
+
|
561
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
562
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
563
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
564
|
+
"""
|
565
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
566
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
567
|
+
if output_cols:
|
568
|
+
output_cols = [
|
569
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
570
|
+
for c in output_cols
|
571
|
+
]
|
572
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
573
|
+
output_cols = [output_cols_prefix]
|
574
|
+
elif self._sklearn_object is not None:
|
575
|
+
classes = self._sklearn_object.classes_
|
576
|
+
if isinstance(classes, numpy.ndarray):
|
577
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
578
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
579
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
580
|
+
output_cols = []
|
581
|
+
for i, cl in enumerate(classes):
|
582
|
+
# For binary classification, there is only one output column for each class
|
583
|
+
# ndarray as the two classes are complementary.
|
584
|
+
if len(cl) == 2:
|
585
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
586
|
+
else:
|
587
|
+
output_cols.extend([
|
588
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
589
|
+
])
|
590
|
+
else:
|
591
|
+
output_cols = []
|
592
|
+
|
593
|
+
# Make sure column names are valid snowflake identifiers.
|
594
|
+
assert output_cols is not None # Make MyPy happy
|
595
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
596
|
+
|
597
|
+
return rv
|
598
|
+
|
599
|
+
def _align_expected_output_names(
|
600
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
601
|
+
) -> List[str]:
|
602
|
+
# in case the inferred output column names dimension is different
|
603
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
604
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
605
|
+
output_df_columns = list(output_df_pd.columns)
|
606
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
607
|
+
if self.sample_weight_col:
|
608
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
609
|
+
# if the dimension of inferred output column names is correct; use it
|
610
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
611
|
+
return expected_output_cols_list
|
612
|
+
# otherwise, use the sklearn estimator's output
|
613
|
+
else:
|
614
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
547
615
|
|
548
616
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
549
617
|
@telemetry.send_api_usage_telemetry(
|
@@ -575,24 +643,26 @@ class Birch(BaseTransformer):
|
|
575
643
|
# are specific to the type of dataset used.
|
576
644
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
577
645
|
|
646
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
647
|
+
|
578
648
|
if isinstance(dataset, DataFrame):
|
579
|
-
self.
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
649
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
650
|
+
self._deps = self._get_dependencies()
|
651
|
+
assert isinstance(
|
652
|
+
dataset._session, Session
|
653
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
584
654
|
transform_kwargs = dict(
|
585
655
|
session=dataset._session,
|
586
656
|
dependencies=self._deps,
|
587
|
-
drop_input_cols
|
657
|
+
drop_input_cols=self._drop_input_cols,
|
588
658
|
expected_output_cols_type="float",
|
589
659
|
)
|
660
|
+
expected_output_cols = self._align_expected_output_names(
|
661
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
662
|
+
)
|
590
663
|
|
591
664
|
elif isinstance(dataset, pd.DataFrame):
|
592
|
-
transform_kwargs = dict(
|
593
|
-
snowpark_input_cols = self._snowpark_cols,
|
594
|
-
drop_input_cols = self._drop_input_cols
|
595
|
-
)
|
665
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
596
666
|
|
597
667
|
transform_handlers = ModelTransformerBuilder.build(
|
598
668
|
dataset=dataset,
|
@@ -604,7 +674,7 @@ class Birch(BaseTransformer):
|
|
604
674
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
605
675
|
inference_method=inference_method,
|
606
676
|
input_cols=self.input_cols,
|
607
|
-
expected_output_cols=
|
677
|
+
expected_output_cols=expected_output_cols,
|
608
678
|
**transform_kwargs
|
609
679
|
)
|
610
680
|
return output_df
|
@@ -634,29 +704,30 @@ class Birch(BaseTransformer):
|
|
634
704
|
Output dataset with log probability of the sample for each class in the model.
|
635
705
|
"""
|
636
706
|
super()._check_dataset_type(dataset)
|
637
|
-
inference_method="predict_log_proba"
|
707
|
+
inference_method = "predict_log_proba"
|
708
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
638
709
|
|
639
710
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
640
711
|
# are specific to the type of dataset used.
|
641
712
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
642
713
|
|
643
714
|
if isinstance(dataset, DataFrame):
|
644
|
-
self.
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
715
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
716
|
+
self._deps = self._get_dependencies()
|
717
|
+
assert isinstance(
|
718
|
+
dataset._session, Session
|
719
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
649
720
|
transform_kwargs = dict(
|
650
721
|
session=dataset._session,
|
651
722
|
dependencies=self._deps,
|
652
|
-
drop_input_cols
|
723
|
+
drop_input_cols=self._drop_input_cols,
|
653
724
|
expected_output_cols_type="float",
|
654
725
|
)
|
726
|
+
expected_output_cols = self._align_expected_output_names(
|
727
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
728
|
+
)
|
655
729
|
elif isinstance(dataset, pd.DataFrame):
|
656
|
-
transform_kwargs = dict(
|
657
|
-
snowpark_input_cols = self._snowpark_cols,
|
658
|
-
drop_input_cols = self._drop_input_cols
|
659
|
-
)
|
730
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
660
731
|
|
661
732
|
transform_handlers = ModelTransformerBuilder.build(
|
662
733
|
dataset=dataset,
|
@@ -669,7 +740,7 @@ class Birch(BaseTransformer):
|
|
669
740
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
670
741
|
inference_method=inference_method,
|
671
742
|
input_cols=self.input_cols,
|
672
|
-
expected_output_cols=
|
743
|
+
expected_output_cols=expected_output_cols,
|
673
744
|
**transform_kwargs
|
674
745
|
)
|
675
746
|
return output_df
|
@@ -695,30 +766,32 @@ class Birch(BaseTransformer):
|
|
695
766
|
Output dataset with results of the decision function for the samples in input dataset.
|
696
767
|
"""
|
697
768
|
super()._check_dataset_type(dataset)
|
698
|
-
inference_method="decision_function"
|
769
|
+
inference_method = "decision_function"
|
699
770
|
|
700
771
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
701
772
|
# are specific to the type of dataset used.
|
702
773
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
703
774
|
|
775
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
776
|
+
|
704
777
|
if isinstance(dataset, DataFrame):
|
705
|
-
self.
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
778
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
779
|
+
self._deps = self._get_dependencies()
|
780
|
+
assert isinstance(
|
781
|
+
dataset._session, Session
|
782
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
710
783
|
transform_kwargs = dict(
|
711
784
|
session=dataset._session,
|
712
785
|
dependencies=self._deps,
|
713
|
-
drop_input_cols
|
786
|
+
drop_input_cols=self._drop_input_cols,
|
714
787
|
expected_output_cols_type="float",
|
715
788
|
)
|
789
|
+
expected_output_cols = self._align_expected_output_names(
|
790
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
791
|
+
)
|
716
792
|
|
717
793
|
elif isinstance(dataset, pd.DataFrame):
|
718
|
-
transform_kwargs = dict(
|
719
|
-
snowpark_input_cols = self._snowpark_cols,
|
720
|
-
drop_input_cols = self._drop_input_cols
|
721
|
-
)
|
794
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
722
795
|
|
723
796
|
transform_handlers = ModelTransformerBuilder.build(
|
724
797
|
dataset=dataset,
|
@@ -731,7 +804,7 @@ class Birch(BaseTransformer):
|
|
731
804
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
732
805
|
inference_method=inference_method,
|
733
806
|
input_cols=self.input_cols,
|
734
|
-
expected_output_cols=
|
807
|
+
expected_output_cols=expected_output_cols,
|
735
808
|
**transform_kwargs
|
736
809
|
)
|
737
810
|
return output_df
|
@@ -760,17 +833,17 @@ class Birch(BaseTransformer):
|
|
760
833
|
Output dataset with probability of the sample for each class in the model.
|
761
834
|
"""
|
762
835
|
super()._check_dataset_type(dataset)
|
763
|
-
inference_method="score_samples"
|
836
|
+
inference_method = "score_samples"
|
764
837
|
|
765
838
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
766
839
|
# are specific to the type of dataset used.
|
767
840
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
768
841
|
|
842
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
843
|
+
|
769
844
|
if isinstance(dataset, DataFrame):
|
770
|
-
self.
|
771
|
-
|
772
|
-
inference_method=inference_method,
|
773
|
-
)
|
845
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
846
|
+
self._deps = self._get_dependencies()
|
774
847
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
775
848
|
transform_kwargs = dict(
|
776
849
|
session=dataset._session,
|
@@ -778,6 +851,9 @@ class Birch(BaseTransformer):
|
|
778
851
|
drop_input_cols = self._drop_input_cols,
|
779
852
|
expected_output_cols_type="float",
|
780
853
|
)
|
854
|
+
expected_output_cols = self._align_expected_output_names(
|
855
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
856
|
+
)
|
781
857
|
|
782
858
|
elif isinstance(dataset, pd.DataFrame):
|
783
859
|
transform_kwargs = dict(
|
@@ -796,7 +872,7 @@ class Birch(BaseTransformer):
|
|
796
872
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
797
873
|
inference_method=inference_method,
|
798
874
|
input_cols=self.input_cols,
|
799
|
-
expected_output_cols=
|
875
|
+
expected_output_cols=expected_output_cols,
|
800
876
|
**transform_kwargs
|
801
877
|
)
|
802
878
|
return output_df
|
@@ -829,17 +905,15 @@ class Birch(BaseTransformer):
|
|
829
905
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
830
906
|
|
831
907
|
if isinstance(dataset, DataFrame):
|
832
|
-
self.
|
833
|
-
|
834
|
-
inference_method="score",
|
835
|
-
)
|
908
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
909
|
+
self._deps = self._get_dependencies()
|
836
910
|
selected_cols = self._get_active_columns()
|
837
911
|
if len(selected_cols) > 0:
|
838
912
|
dataset = dataset.select(selected_cols)
|
839
913
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
840
914
|
transform_kwargs = dict(
|
841
915
|
session=dataset._session,
|
842
|
-
dependencies=
|
916
|
+
dependencies=self._deps,
|
843
917
|
score_sproc_imports=['sklearn'],
|
844
918
|
)
|
845
919
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -904,11 +978,8 @@ class Birch(BaseTransformer):
|
|
904
978
|
|
905
979
|
if isinstance(dataset, DataFrame):
|
906
980
|
|
907
|
-
self.
|
908
|
-
|
909
|
-
inference_method=inference_method,
|
910
|
-
|
911
|
-
)
|
981
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
982
|
+
self._deps = self._get_dependencies()
|
912
983
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
913
984
|
transform_kwargs = dict(
|
914
985
|
session = dataset._session,
|
@@ -941,50 +1012,84 @@ class Birch(BaseTransformer):
|
|
941
1012
|
)
|
942
1013
|
return output_df
|
943
1014
|
|
1015
|
+
|
1016
|
+
|
1017
|
+
def to_sklearn(self) -> Any:
|
1018
|
+
"""Get sklearn.cluster.Birch object.
|
1019
|
+
"""
|
1020
|
+
if self._sklearn_object is None:
|
1021
|
+
self._sklearn_object = self._create_sklearn_object()
|
1022
|
+
return self._sklearn_object
|
1023
|
+
|
1024
|
+
def to_xgboost(self) -> Any:
|
1025
|
+
raise exceptions.SnowflakeMLException(
|
1026
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1027
|
+
original_exception=AttributeError(
|
1028
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1029
|
+
"to_xgboost()",
|
1030
|
+
"to_sklearn()"
|
1031
|
+
)
|
1032
|
+
),
|
1033
|
+
)
|
944
1034
|
|
945
|
-
def
|
1035
|
+
def to_lightgbm(self) -> Any:
|
1036
|
+
raise exceptions.SnowflakeMLException(
|
1037
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1038
|
+
original_exception=AttributeError(
|
1039
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1040
|
+
"to_lightgbm()",
|
1041
|
+
"to_sklearn()"
|
1042
|
+
)
|
1043
|
+
),
|
1044
|
+
)
|
1045
|
+
|
1046
|
+
def _get_dependencies(self) -> List[str]:
|
1047
|
+
return self._deps
|
1048
|
+
|
1049
|
+
|
1050
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
946
1051
|
self._model_signature_dict = dict()
|
947
1052
|
|
948
1053
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
949
1054
|
|
950
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1055
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
951
1056
|
outputs: List[BaseFeatureSpec] = []
|
952
1057
|
if hasattr(self, "predict"):
|
953
1058
|
# keep mypy happy
|
954
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1059
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
955
1060
|
# For classifier, the type of predict is the same as the type of label
|
956
|
-
if self._sklearn_object._estimator_type ==
|
957
|
-
|
1061
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1062
|
+
# label columns is the desired type for output
|
958
1063
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
959
1064
|
# rename the output columns
|
960
1065
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
961
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
962
|
-
|
963
|
-
|
1066
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1067
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1068
|
+
)
|
964
1069
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
965
1070
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
966
|
-
# Clusterer returns int64 cluster labels.
|
1071
|
+
# Clusterer returns int64 cluster labels.
|
967
1072
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
968
1073
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
969
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
970
|
-
|
971
|
-
|
972
|
-
|
1074
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1075
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1076
|
+
)
|
1077
|
+
|
973
1078
|
# For regressor, the type of predict is float64
|
974
|
-
elif self._sklearn_object._estimator_type ==
|
1079
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
975
1080
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
976
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
977
|
-
|
978
|
-
|
979
|
-
|
1081
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1082
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1083
|
+
)
|
1084
|
+
|
980
1085
|
for prob_func in PROB_FUNCTIONS:
|
981
1086
|
if hasattr(self, prob_func):
|
982
1087
|
output_cols_prefix: str = f"{prob_func}_"
|
983
1088
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
984
1089
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
985
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
986
|
-
|
987
|
-
|
1090
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1091
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1092
|
+
)
|
988
1093
|
|
989
1094
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
990
1095
|
items = list(self._model_signature_dict.items())
|
@@ -997,10 +1102,10 @@ class Birch(BaseTransformer):
|
|
997
1102
|
"""Returns model signature of current class.
|
998
1103
|
|
999
1104
|
Raises:
|
1000
|
-
|
1105
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1001
1106
|
|
1002
1107
|
Returns:
|
1003
|
-
Dict
|
1108
|
+
Dict with each method and its input output signature
|
1004
1109
|
"""
|
1005
1110
|
if self._model_signature_dict is None:
|
1006
1111
|
raise exceptions.SnowflakeMLException(
|
@@ -1008,35 +1113,3 @@ class Birch(BaseTransformer):
|
|
1008
1113
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1009
1114
|
)
|
1010
1115
|
return self._model_signature_dict
|
1011
|
-
|
1012
|
-
def to_sklearn(self) -> Any:
|
1013
|
-
"""Get sklearn.cluster.Birch object.
|
1014
|
-
"""
|
1015
|
-
if self._sklearn_object is None:
|
1016
|
-
self._sklearn_object = self._create_sklearn_object()
|
1017
|
-
return self._sklearn_object
|
1018
|
-
|
1019
|
-
def to_xgboost(self) -> Any:
|
1020
|
-
raise exceptions.SnowflakeMLException(
|
1021
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1022
|
-
original_exception=AttributeError(
|
1023
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1024
|
-
"to_xgboost()",
|
1025
|
-
"to_sklearn()"
|
1026
|
-
)
|
1027
|
-
),
|
1028
|
-
)
|
1029
|
-
|
1030
|
-
def to_lightgbm(self) -> Any:
|
1031
|
-
raise exceptions.SnowflakeMLException(
|
1032
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1033
|
-
original_exception=AttributeError(
|
1034
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1035
|
-
"to_lightgbm()",
|
1036
|
-
"to_sklearn()"
|
1037
|
-
)
|
1038
|
-
),
|
1039
|
-
)
|
1040
|
-
|
1041
|
-
def _get_dependencies(self) -> List[str]:
|
1042
|
-
return self._deps
|