snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
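The largest additions in the file list are the new Dataset API (`snowflake/ml/dataset/dataset.py`, `dataset_factory.py`, `dataset_reader.py`) and the per-estimator refactor under `snowflake/ml/modeling/`. As a minimal sketch of how the new dataset factory is meant to be used, based on the module names added above and the 1.5.0 release notes; exact signatures should be checked against the installed package, and the `session`, database, and table names below are hypothetical:

```python
# Sketch only: assumes an existing Snowpark `session`; names are hypothetical.
from snowflake.ml import dataset

# Materialize a Snowpark DataFrame as a versioned Dataset
# (backed by the new dataset_factory.py / dataset.py modules listed above).
ds = dataset.create_from_dataframe(
    session,
    name="MY_DB.MY_SCHEMA.MY_DATASET",
    version="v1",
    input_dataframe=session.table("MY_TRAINING_TABLE"),
)

# Read it back through the new DatasetReader (dataset_reader.py).
snowpark_df = ds.read.to_snowpark_dataframe()
pandas_df = ds.read.to_pandas()
```

The diff below, for `snowflake/ml/modeling/decomposition/dictionary_learning.py`, is representative of the refactor applied to every autogenerated estimator in the list above.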
snowflake/ml/modeling/decomposition/dictionary_learning.py

```diff
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model.model_signature import (
+    BaseFeatureSpec,
+    DataType,
+    FeatureSpec,
+    ModelSignature,
+    _infer_signature,
+    _rename_signature_with_snowflake_identifiers,
+)
 
 from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )
 
-from snowflake.ml.model.model_signature import (
-    DataType,
-    FeatureSpec,
-    ModelSignature,
-    _infer_signature,
-    _rename_signature_with_snowflake_identifiers,
-    BaseFeatureSpec,
-)
-from snowflake.ml.model._signatures import utils as model_signature_utils
-
 _PROJECT = "ModelDevelopment"
 # Derive subproject from module name by removing "sklearn"
 # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class DictionaryLearning(BaseTransformer):
     r"""Dictionary learning
     For more details on this class, see [sklearn.decomposition.DictionaryLearning]
@@ -314,12 +307,7 @@ class DictionaryLearning(BaseTransformer):
         )
         return selected_cols
 
-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        custom_tags=dict([("autogen", True)]),
-    )
-    def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "DictionaryLearning":
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "DictionaryLearning":
         """Fit the model from data in X
         For more details on this function, see [sklearn.decomposition.DictionaryLearning.fit]
         (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.DictionaryLearning.html#sklearn.decomposition.DictionaryLearning.fit)
@@ -346,12 +334,14 @@ class DictionaryLearning(BaseTransformer):
 
         self._snowpark_cols = dataset.select(self.input_cols).columns
 
-
+        # If we are already in a stored procedure, no need to kick off another one.
         if SNOWML_SPROC_ENV in os.environ:
             statement_params = telemetry.get_function_usage_statement_params(
                 project=_PROJECT,
                 subproject=_SUBPROJECT,
-                function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), DictionaryLearning.__class__.__name__),
+                function_name=telemetry.get_statement_params_full_func_name(
+                    inspect.currentframe(), DictionaryLearning.__class__.__name__
+                ),
                 api_calls=[Session.call],
                 custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
             )
@@ -372,27 +362,24 @@ class DictionaryLearning(BaseTransformer):
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
-        self._get_model_signatures(dataset)
+        self._generate_model_signatures(dataset)
         return self
 
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -410,9 +397,7 @@ class DictionaryLearning(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -446,7 +431,9 @@ class DictionaryLearning(BaseTransformer):
             # when it is classifier, infer the datatype from label columns
             if expected_type_inferred == "" and 'predict' in self.model_signatures:
                 # Batch inference takes a single expected output column type. Use the first columns type for now.
-                label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
+                label_cols_signatures = [
+                    row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
+                ]
                 if len(label_cols_signatures) == 0:
                     error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
                     raise exceptions.SnowflakeMLException(
@@ -454,25 +441,23 @@ class DictionaryLearning(BaseTransformer):
                         original_exception=ValueError(error_str),
                     )
 
-                expected_type_inferred = convert_sp_to_sf_type(
-                    label_cols_signatures[0].as_snowpark_type()
-                )
+                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_type_inferred,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_type_inferred,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -514,7 +499,7 @@ class DictionaryLearning(BaseTransformer):
             Transformed dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="transform"
+        inference_method = "transform"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -544,24 +529,19 @@ class DictionaryLearning(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_dtype,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_dtype,
            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -580,7 +560,11 @@ class DictionaryLearning(BaseTransformer):
         return output_df
 
     @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
+    def fit_predict(
+        self,
+        dataset: Union[DataFrame, pd.DataFrame],
+        output_cols_prefix: str = "fit_predict_",
+    ) -> Union[DataFrame, pd.DataFrame]:
         """ Method not supported for this class.
 
 
@@ -605,22 +589,106 @@ class DictionaryLearning(BaseTransformer):
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
             drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
+            expected_output_cols_list=(
+                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
+            ),
         )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit the model from data in X and return the transformed data
+        For more details on this function, see [sklearn.decomposition.DictionaryLearning.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.DictionaryLearning.html#sklearn.decomposition.DictionaryLearning.fit_transform)
+
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """ Method not supported for this class.
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
+
+
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
+        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
+        """
+        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
+        # The following condition is introduced for kneighbors methods, and not used in other methods
+        if output_cols:
+            output_cols = [
+                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
+                for c in output_cols
+            ]
+        elif getattr(self._sklearn_object, "classes_", None) is None:
+            output_cols = [output_cols_prefix]
+        elif self._sklearn_object is not None:
+            classes = self._sklearn_object.classes_
+            if isinstance(classes, numpy.ndarray):
+                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
+            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
+                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
+                output_cols = []
+                for i, cl in enumerate(classes):
+                    # For binary classification, there is only one output column for each class
+                    # ndarray as the two classes are complementary.
+                    if len(cl) == 2:
+                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
+                    else:
+                        output_cols.extend([
+                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
+                        ])
+        else:
+            output_cols = []
+
+        # Make sure column names are valid snowflake identifiers.
+        assert output_cols is not None  # Make MyPy happy
+        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
+
+        return rv
+
+    def _align_expected_output_names(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
+    ) -> List[str]:
+        # in case the inferred output column names dimension is different
+        # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
+        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        output_df_columns = list(output_df_pd.columns)
+        output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
+        if self.sample_weight_col:
+            output_df_columns_set -= set(self.sample_weight_col)
+        # if the dimension of inferred output column names is correct; use it
+        if len(expected_output_cols_list) == len(output_df_columns_set):
+            return expected_output_cols_list
+        # otherwise, use the sklearn estimator's output
+        else:
+            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
 
     @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -652,24 +720,26 @@ class DictionaryLearning(BaseTransformer):
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -681,7 +751,7 @@ class DictionaryLearning(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -711,29 +781,30 @@ class DictionaryLearning(BaseTransformer):
             Output dataset with log probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="predict_log_proba"
+        inference_method = "predict_log_proba"
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -746,7 +817,7 @@ class DictionaryLearning(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -772,30 +843,32 @@ class DictionaryLearning(BaseTransformer):
             Output dataset with results of the decision function for the samples in input dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="decision_function"
+        inference_method = "decision_function"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
        transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -808,7 +881,7 @@ class DictionaryLearning(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -837,17 +910,17 @@ class DictionaryLearning(BaseTransformer):
             Output dataset with probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="score_samples"
+        inference_method = "score_samples"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -855,6 +928,9 @@ class DictionaryLearning(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
             transform_kwargs = dict(
@@ -873,7 +949,7 @@ class DictionaryLearning(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -906,17 +982,15 @@ class DictionaryLearning(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -981,11 +1055,8 @@ class DictionaryLearning(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
@@ -1018,50 +1089,84 @@ class DictionaryLearning(BaseTransformer):
         )
         return output_df
 
+
+
+    def to_sklearn(self) -> Any:
+        """Get sklearn.decomposition.DictionaryLearning object.
+        """
+        if self._sklearn_object is None:
+            self._sklearn_object = self._create_sklearn_object()
+        return self._sklearn_object
+
+    def to_xgboost(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_xgboost()",
+                    "to_sklearn()"
+                )
+            ),
+        )
 
-    def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
+    def to_lightgbm(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_lightgbm()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def _get_dependencies(self) -> List[str]:
+        return self._deps
+
+
+    def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input"))
+        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
-            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
+            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
             # For classifier, the type of predict is the same as the type of label
-            if self._sklearn_object._estimator_type == 'classifier':
-                # label columns is the desired type for output
+            if self._sklearn_object._estimator_type == "classifier":
+                # label columns is the desired type for output
                 outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
            # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
            # For outlier models, returns -1 for outliers and 1 for inliers.
-            # Clusterer returns int64 cluster labels.
+            # Clusterer returns int64 cluster labels.
             elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
                 outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
            # For regressor, the type of predict is float64
-            elif self._sklearn_object._estimator_type == 'regressor':
+            elif self._sklearn_object._estimator_type == "regressor":
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
 
         # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
         items = list(self._model_signature_dict.items())
@@ -1074,10 +1179,10 @@ class DictionaryLearning(BaseTransformer):
         """Returns model signature of current class.
 
         Raises:
-            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
+            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
 
         Returns:
-            Dict[str, ModelSignature]: each method and its input output signature
+            Dict with each method and its input output signature
         """
         if self._model_signature_dict is None:
             raise exceptions.SnowflakeMLException(
@@ -1085,35 +1190,3 @@ class DictionaryLearning(BaseTransformer):
             original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
         )
         return self._model_signature_dict
-
-    def to_sklearn(self) -> Any:
-        """Get sklearn.decomposition.DictionaryLearning object.
-        """
-        if self._sklearn_object is None:
-            self._sklearn_object = self._create_sklearn_object()
-        return self._sklearn_object
-
-    def to_xgboost(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_xgboost()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def to_lightgbm(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_lightgbm()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def _get_dependencies(self) -> List[str]:
-        return self._deps
```