snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
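Release highlights visible in the file list: a new snowflake/ml/dataset subpackage (factory, reader, metadata), dataset lineage plumbing, two new fileset backends (embedded_stage_fs, snowfs), and new CatBoost and LightGBM model handlers. As rough orientation only, logging one of the newly supported model types should follow the existing registry flow; the sketch below assumes Registry.log_model simply accepts the new model classes, which the handler files imply but this page does not show:

    # Hedged sketch: assumes the new handlers (catboost.py, lightgbm.py above)
    # let Registry.log_model accept CatBoost/LightGBM models directly.
    # `session` is an existing Snowpark session; `catboost_model` is a fitted
    # catboost estimator.
    from snowflake.ml.registry import Registry

    reg = Registry(session=session)
    model_version = reg.log_model(
        catboost_model,
        model_name="MY_CATBOOST_MODEL",
        version_name="V1",
    )

The hunks below reproduce one representative generated-estimator file; the same template change repeats across the modeling files listed above with +246/-175 style counts.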
snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py

@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model.model_signature import (
+    BaseFeatureSpec,
+    DataType,
+    FeatureSpec,
+    ModelSignature,
+    _infer_signature,
+    _rename_signature_with_snowflake_identifiers,
+)
 
 from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )
 
-from snowflake.ml.model.model_signature import (
-    DataType,
-    FeatureSpec,
-    ModelSignature,
-    _infer_signature,
-    _rename_signature_with_snowflake_identifiers,
-    BaseFeatureSpec,
-)
-from snowflake.ml.model._signatures import utils as model_signature_utils
-
 _PROJECT = "ModelDevelopment"
 # Derive subproject from module name by removing "sklearn"
 # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class MiniBatchDictionaryLearning(BaseTransformer):
     r"""Mini-batch dictionary learning
     For more details on this class, see [sklearn.decomposition.MiniBatchDictionaryLearning]
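A detail worth pausing on in the hunk above: the removed 1.4.x gate could never enable fit_transform, because "False and ..." short-circuits before the callable check runs:

    # The removed availability gate, reduced to its essence: it returns False
    # for every estimator, so the generated fit_transform was never exposed.
    def check(obj: object) -> bool:
        return False and callable(getattr(obj, "fit_transform", None))

    assert check(object()) is False  # always

1.5.0 deletes the gate and, later in this diff, adds a real fit_transform implementation.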
@@ -339,12 +332,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         )
         return selected_cols
 
-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        custom_tags=dict([("autogen", True)]),
-    )
-    def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "MiniBatchDictionaryLearning":
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "MiniBatchDictionaryLearning":
         """Fit the model from data in X
         For more details on this function, see [sklearn.decomposition.MiniBatchDictionaryLearning.fit]
         (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.MiniBatchDictionaryLearning.html#sklearn.decomposition.MiniBatchDictionaryLearning.fit)
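The fit-to-_fit rename, together with the dropped per-class telemetry decorator and the framework/base.py +72 -37 entry in the file list, suggests the public fit() and its telemetry moved into the shared BaseTransformer, leaving each generated estimator to supply only _fit(). A hypothetical sketch of that template-method split; the base-class side is not shown on this page, so these details are assumed:

    # Hypothetical sketch; the real code is in snowflake/ml/modeling/framework/base.py,
    # which this page lists only as "+72 -37".
    class BaseTransformer:
        def fit(self, dataset):
            # telemetry and validation handled once here, instead of being
            # regenerated into every estimator subclass
            return self._fit(dataset)

        def _fit(self, dataset):
            raise NotImplementedError  # each generated estimator overrides this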
@@ -371,12 +359,14 @@ class MiniBatchDictionaryLearning(BaseTransformer):
 
         self._snowpark_cols = dataset.select(self.input_cols).columns
 
-
+        # If we are already in a stored procedure, no need to kick off another one.
         if SNOWML_SPROC_ENV in os.environ:
             statement_params = telemetry.get_function_usage_statement_params(
                 project=_PROJECT,
                 subproject=_SUBPROJECT,
-                function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), MiniBatchDictionaryLearning.__class__.__name__),
+                function_name=telemetry.get_statement_params_full_func_name(
+                    inspect.currentframe(), MiniBatchDictionaryLearning.__class__.__name__
+                ),
                 api_calls=[Session.call],
                 custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
             )
@@ -397,27 +387,24 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
-        self._get_model_signatures(dataset)
+        self._generate_model_signatures(dataset)
         return self
 
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -435,9 +422,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -471,7 +456,9 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             # when it is classifier, infer the datatype from label columns
             if expected_type_inferred == "" and 'predict' in self.model_signatures:
                 # Batch inference takes a single expected output column type. Use the first columns type for now.
-                label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
+                label_cols_signatures = [
+                    row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
+                ]
                 if len(label_cols_signatures) == 0:
                     error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
                     raise exceptions.SnowflakeMLException(
@@ -479,25 +466,23 @@ class MiniBatchDictionaryLearning(BaseTransformer):
                         original_exception=ValueError(error_str),
                     )
 
-                expected_type_inferred = convert_sp_to_sf_type(
-                    label_cols_signatures[0].as_snowpark_type()
-                )
+                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
-
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_type_inferred,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_type_inferred,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
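The change above repeats in every inference method in this file: in 1.4.x, one call both validated the Snowpark dataframe and returned the resolved Anaconda package list; in 1.5.0, validation returns None and dependencies come from the new _get_dependencies() accessor added near the end of this diff. Schematically:

    # 1.4.x: validation and dependency resolution coupled in one call
    #   self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=...)
    # 1.5.0: validation is a pure check; dependencies are fetched separately
    #   self._batch_inference_validate_snowpark(dataset=dataset, inference_method=...)
    #   self._deps = self._get_dependencies()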
@@ -539,7 +524,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             Transformed dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="transform"
+        inference_method = "transform"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -569,24 +554,19 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_dtype,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_dtype,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -605,7 +585,11 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         return output_df
 
     @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
+    def fit_predict(
+        self,
+        dataset: Union[DataFrame, pd.DataFrame],
+        output_cols_prefix: str = "fit_predict_",
+    ) -> Union[DataFrame, pd.DataFrame]:
         """ Method not supported for this class.
 
 
@@ -630,22 +614,106 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
             drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
+            expected_output_cols_list=(
+                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
+            ),
         )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.decomposition.MiniBatchDictionaryLearning.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.MiniBatchDictionaryLearning.html#sklearn.decomposition.MiniBatchDictionaryLearning.fit_transform)
+
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
-        """ Method not supported for this class.
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
+
+
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
+        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
+        """
+        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
+        # The following condition is introduced for kneighbors methods, and not used in other methods
+        if output_cols:
+            output_cols = [
+                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
+                for c in output_cols
+            ]
+        elif getattr(self._sklearn_object, "classes_", None) is None:
+            output_cols = [output_cols_prefix]
+        elif self._sklearn_object is not None:
+            classes = self._sklearn_object.classes_
+            if isinstance(classes, numpy.ndarray):
+                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
+            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
+                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
+                output_cols = []
+                for i, cl in enumerate(classes):
+                    # For binary classification, there is only one output column for each class
+                    # ndarray as the two classes are complementary.
+                    if len(cl) == 2:
+                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
+                    else:
+                        output_cols.extend([
+                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
+                        ])
+        else:
+            output_cols = []
+
+        # Make sure column names are valid snowflake identifiers.
+        assert output_cols is not None  # Make MyPy happy
+        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
+
+        return rv
+
+    def _align_expected_output_names(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
+    ) -> List[str]:
+        # in case the inferred output column names dimension is different
+        # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
+        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        output_df_columns = list(output_df_pd.columns)
+        output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
+        if self.sample_weight_col:
+            output_df_columns_set -= set(self.sample_weight_col)
+        # if the dimension of inferred output column names is correct; use it
+        if len(expected_output_cols_list) == len(output_df_columns_set):
+            return expected_output_cols_list
+        # otherwise, use the sklearn estimator's output
+        else:
+            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
 
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
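The multioutput branch of _get_output_column_names above is easiest to see with concrete values. A small self-contained rendition of just that naming logic (the snowflake identifier resolution is elided, so the names are illustrative):

    import numpy

    def name_outputs(prefix, classes):
        # mirrors the list-of-ndarrays branch added above
        cols = []
        for i, cl in enumerate(classes):
            if len(cl) == 2:
                # binary output: one column per class ndarray; the two classes are complementary
                cols.append(f"{prefix}{i}_{cl[0]}")
            else:
                cols.extend(f"{prefix}{i}_{c}" for c in cl.tolist())
        return cols

    classes = [numpy.array([0, 1]), numpy.array(["a", "b", "c"])]
    print(name_outputs("predict_proba_", classes))
    # ['predict_proba_0_0', 'predict_proba_1_a', 'predict_proba_1_b', 'predict_proba_1_c']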
@@ -677,24 +745,26 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -706,7 +776,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -736,29 +806,30 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             Output dataset with log probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="predict_log_proba"
+        inference_method = "predict_log_proba"
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -771,7 +842,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -797,30 +868,32 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             Output dataset with results of the decision function for the samples in input dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="decision_function"
+        inference_method = "decision_function"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -833,7 +906,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -862,17 +935,17 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             Output dataset with probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="score_samples"
+        inference_method = "score_samples"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -880,6 +953,9 @@ class MiniBatchDictionaryLearning(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
             transform_kwargs = dict(
@@ -898,7 +974,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -931,17 +1007,15 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1006,11 +1080,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
@@ -1043,50 +1114,84 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         )
         return output_df
 
+
+
+    def to_sklearn(self) -> Any:
+        """Get sklearn.decomposition.MiniBatchDictionaryLearning object.
+        """
+        if self._sklearn_object is None:
+            self._sklearn_object = self._create_sklearn_object()
+        return self._sklearn_object
+
+    def to_xgboost(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_xgboost()",
+                    "to_sklearn()"
+                )
+            ),
+        )
 
-    def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
+    def to_lightgbm(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_lightgbm()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def _get_dependencies(self) -> List[str]:
+        return self._deps
+
+
+    def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input"))
+        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
-            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
+            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
             # For classifier, the type of predict is the same as the type of label
-            if self._sklearn_object._estimator_type == 'classifier':
-                # label columns is the desired type for output
+            if self._sklearn_object._estimator_type == "classifier":
+                # label columns is the desired type for output
                 outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
             # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
             # For outlier models, returns -1 for outliers and 1 for inliers.
-            # Clusterer returns int64 cluster labels.
+            # Clusterer returns int64 cluster labels.
             elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
                 outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
             # For regressor, the type of predict is float64
-            elif self._sklearn_object._estimator_type == 'regressor':
+            elif self._sklearn_object._estimator_type == "regressor":
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
 
         # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
         items = list(self._model_signature_dict.items())
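Once _generate_model_signatures has run at the end of _fit, each supported method has a ModelSignature keyed by method name. A short usage sketch; the column names are illustrative and the public fit() wrapper is assumed to come from the base class:

    mbdl = MiniBatchDictionaryLearning(input_cols=["C1", "C2"], output_cols=["OUT"])
    mbdl.fit(training_df)                    # training_df: Snowpark or pandas DataFrame
    sig = mbdl.model_signatures["predict"]   # raises SnowflakeMLException before fit
    print(sig.inputs)   # FeatureSpecs inferred from C1, C2
    print(sig.outputs)  # the input specs (unless drop_input_cols) plus OUT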
@@ -1099,10 +1204,10 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         """Returns model signature of current class.
 
         Raises:
-            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
+            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
 
         Returns:
-            Dict
+            Dict with each method and its input output signature
         """
         if self._model_signature_dict is None:
             raise exceptions.SnowflakeMLException(
@@ -1110,35 +1215,3 @@ class MiniBatchDictionaryLearning(BaseTransformer):
                 original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
             )
         return self._model_signature_dict
-
-    def to_sklearn(self) -> Any:
-        """Get sklearn.decomposition.MiniBatchDictionaryLearning object.
-        """
-        if self._sklearn_object is None:
-            self._sklearn_object = self._create_sklearn_object()
-        return self._sklearn_object
-
-    def to_xgboost(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_xgboost()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def to_lightgbm(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_lightgbm()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def _get_dependencies(self) -> List[str]:
-        return self._deps