snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )
 
-from snowflake.ml.model.model_signature import (
-    DataType,
-    FeatureSpec,
-    ModelSignature,
-    _infer_signature,
-    _rename_signature_with_snowflake_identifiers,
-    BaseFeatureSpec,
-)
-from snowflake.ml.model._signatures import utils as model_signature_utils
-
 _PROJECT = "ModelDevelopment"
 # Derive subproject from module name by removing "sklearn"
 # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace(
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class BernoulliRBM(BaseTransformer):
     r"""Bernoulli Restricted Boltzmann Machine (RBM)
     For more details on this class, see [sklearn.neural_network.BernoulliRBM]
@@ -232,12 +225,7 @@ class BernoulliRBM(BaseTransformer):
         )
         return selected_cols
 
-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        custom_tags=dict([("autogen", True)]),
-    )
-    def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "BernoulliRBM":
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "BernoulliRBM":
         """Fit the model to the data X
         For more details on this function, see [sklearn.neural_network.BernoulliRBM.fit]
         (https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.BernoulliRBM.html#sklearn.neural_network.BernoulliRBM.fit)
@@ -264,12 +252,14 @@ class BernoulliRBM(BaseTransformer):
 
         self._snowpark_cols = dataset.select(self.input_cols).columns
 
-
+        # If we are already in a stored procedure, no need to kick off another one.
         if SNOWML_SPROC_ENV in os.environ:
             statement_params = telemetry.get_function_usage_statement_params(
                 project=_PROJECT,
                 subproject=_SUBPROJECT,
-                function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), BernoulliRBM.__class__.__name__),
+                function_name=telemetry.get_statement_params_full_func_name(
+                    inspect.currentframe(), BernoulliRBM.__class__.__name__
+                ),
                 api_calls=[Session.call],
                 custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
             )
@@ -290,27 +280,24 @@ class BernoulliRBM(BaseTransformer):
             )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
-        self._get_model_signatures(dataset)
+        self._generate_model_signatures(dataset)
         return self
 
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -328,9 +315,7 @@ class BernoulliRBM(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -364,7 +349,9 @@ class BernoulliRBM(BaseTransformer):
             # when it is classifier, infer the datatype from label columns
             if expected_type_inferred == "" and 'predict' in self.model_signatures:
                 # Batch inference takes a single expected output column type. Use the first columns type for now.
-                label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
+                label_cols_signatures = [
+                    row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
+                ]
                 if len(label_cols_signatures) == 0:
                     error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
                     raise exceptions.SnowflakeMLException(
@@ -372,25 +359,23 @@ class BernoulliRBM(BaseTransformer):
                         original_exception=ValueError(error_str),
                     )
 
-                expected_type_inferred = convert_sp_to_sf_type(
-                    label_cols_signatures[0].as_snowpark_type()
-                )
+                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
-
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_type_inferred,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_type_inferred,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -432,7 +417,7 @@ class BernoulliRBM(BaseTransformer):
             Transformed dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="transform"
+        inference_method = "transform"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -462,24 +447,19 @@ class BernoulliRBM(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_dtype,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_dtype,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -498,7 +478,11 @@ class BernoulliRBM(BaseTransformer):
         return output_df
 
     @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
+    def fit_predict(
+        self,
+        dataset: Union[DataFrame, pd.DataFrame],
+        output_cols_prefix: str = "fit_predict_",
+    ) -> Union[DataFrame, pd.DataFrame]:
         """ Method not supported for this class.
 
 
@@ -523,22 +507,106 @@ class BernoulliRBM(BaseTransformer):
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
             drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
+            expected_output_cols_list=(
+                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
+            ),
         )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.neural_network.BernoulliRBM.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.BernoulliRBM.html#sklearn.neural_network.BernoulliRBM.fit_transform)
+
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
 
-
-
-
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
+
+
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
+        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
+        """
+        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
+        # The following condition is introduced for kneighbors methods, and not used in other methods
+        if output_cols:
+            output_cols = [
+                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
+                for c in output_cols
+            ]
+        elif getattr(self._sklearn_object, "classes_", None) is None:
+            output_cols = [output_cols_prefix]
+        elif self._sklearn_object is not None:
+            classes = self._sklearn_object.classes_
+            if isinstance(classes, numpy.ndarray):
+                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
+            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
+                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
+                output_cols = []
+                for i, cl in enumerate(classes):
+                    # For binary classification, there is only one output column for each class
+                    # ndarray as the two classes are complementary.
+                    if len(cl) == 2:
+                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
+                    else:
+                        output_cols.extend([
+                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
+                        ])
+        else:
+            output_cols = []
+
+        # Make sure column names are valid snowflake identifiers.
+        assert output_cols is not None  # Make MyPy happy
+        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
+
+        return rv
+
+    def _align_expected_output_names(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
+    ) -> List[str]:
+        # in case the inferred output column names dimension is different
+        # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
+        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        output_df_columns = list(output_df_pd.columns)
+        output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
+        if self.sample_weight_col:
+            output_df_columns_set -= set(self.sample_weight_col)
+        # if the dimension of inferred output column names is correct; use it
+        if len(expected_output_cols_list) == len(output_df_columns_set):
+            return expected_output_cols_list
+        # otherwise, use the sklearn estimator's output
+        else:
+            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
 
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
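The 1.4.0 `fit_transform` was gated behind `_is_fit_transform_method_enabled()`, whose `return False and callable(...)` always evaluated to False; 1.5.0 replaces it with a working implementation backed by `ModelTrainerBuilder.build_fit_transform`. A minimal usage sketch of the new API, assuming snowflake-ml-python 1.5.0; the column names and `n_components` value are illustrative, not taken from the diff:

```python
# Minimal sketch of the fit_transform API added in 1.5.0.
# Assumes snowflake-ml-python>=1.5.0; "F1"/"F2"/"COMPONENT_*" are illustrative names.
import pandas as pd
from snowflake.ml.modeling.neural_network import BernoulliRBM

df = pd.DataFrame({"F1": [0.0, 1.0, 0.0, 1.0], "F2": [1.0, 0.0, 1.0, 0.0]})

rbm = BernoulliRBM(
    input_cols=["F1", "F2"],                     # feature columns
    output_cols=["COMPONENT_0", "COMPONENT_1"],  # one column per hidden component
    n_components=2,
)

# Trains the estimator and returns the transformed dataset in a single call.
# If output_cols is unset, output columns use the "fit_transform_" prefix.
transformed = rbm.fit_transform(df)
print(transformed.head())
```

The same call accepts a `snowflake.snowpark.DataFrame`, in which case training is dispatched through the Snowpark trainer path shown in the hunk above.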
@@ -570,24 +638,26 @@ class BernoulliRBM(BaseTransformer):
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -629,29 +699,30 @@ class BernoulliRBM(BaseTransformer):
             Output dataset with log probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="predict_log_proba"
+        inference_method = "predict_log_proba"
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -664,7 +735,7 @@ class BernoulliRBM(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -690,30 +761,32 @@ class BernoulliRBM(BaseTransformer):
             Output dataset with results of the decision function for the samples in input dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="decision_function"
+        inference_method = "decision_function"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -726,7 +799,7 @@ class BernoulliRBM(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -757,17 +830,17 @@ class BernoulliRBM(BaseTransformer):
             Output dataset with probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="score_samples"
+        inference_method = "score_samples"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -775,6 +848,9 @@ class BernoulliRBM(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
             transform_kwargs = dict(
@@ -793,7 +869,7 @@ class BernoulliRBM(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -826,17 +902,15 @@ class BernoulliRBM(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -901,11 +975,8 @@ class BernoulliRBM(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
@@ -938,50 +1009,84 @@ class BernoulliRBM(BaseTransformer):
         )
         return output_df
 
+
+
+    def to_sklearn(self) -> Any:
+        """Get sklearn.neural_network.BernoulliRBM object.
+        """
+        if self._sklearn_object is None:
+            self._sklearn_object = self._create_sklearn_object()
+        return self._sklearn_object
+
+    def to_xgboost(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_xgboost()",
+                    "to_sklearn()"
+                )
+            ),
+        )
 
-    def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
+    def to_lightgbm(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_lightgbm()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def _get_dependencies(self) -> List[str]:
+        return self._deps
+
+
+    def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input"))
+        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
-            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
+            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
             # For classifier, the type of predict is the same as the type of label
-            if self._sklearn_object._estimator_type == 'classifier':
-
+            if self._sklearn_object._estimator_type == "classifier":
+                # label columns is the desired type for output
                 outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
-                self._model_signature_dict["predict"] = ModelSignature(
-
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
             # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
             # For outlier models, returns -1 for outliers and 1 for inliers.
-            # Clusterer returns int64 cluster labels.
+            # Clusterer returns int64 cluster labels.
             elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
                 outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(
-
-
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
             # For regressor, the type of predict is float64
-            elif self._sklearn_object._estimator_type == 'regressor':
+            elif self._sklearn_object._estimator_type == "regressor":
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(
-
-
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(
-
-
+                self._model_signature_dict[prob_func] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
 
         # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
         items = list(self._model_signature_dict.items())
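The `to_sklearn`/`to_xgboost`/`to_lightgbm` converters and `_get_dependencies` now sit ahead of the renamed `_generate_model_signatures` instead of trailing the class (see the final hunk below, which removes them from their old position). Continuing the `fit_transform` sketch above, only `to_sklearn()` succeeds for this sklearn-backed estimator; the other converters raise `SnowflakeMLException` wrapping an `AttributeError` that points back at `to_sklearn()`:

```python
# Continues the earlier sketch: `rbm` is the fitted BernoulliRBM wrapper.
raw = rbm.to_sklearn()                # underlying sklearn.neural_network.BernoulliRBM
print(type(raw).__module__, type(raw).__name__)

try:
    rbm.to_xgboost()                  # wrong framework for this estimator
except Exception as exc:              # SnowflakeMLException wrapping AttributeError
    print(exc)                        # message suggests calling to_sklearn() instead
```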
@@ -994,10 +1099,10 @@ class BernoulliRBM(BaseTransformer):
         """Returns model signature of current class.
 
         Raises:
-
+            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
 
         Returns:
-            Dict
+            Dict with each method and its input output signature
         """
         if self._model_signature_dict is None:
             raise exceptions.SnowflakeMLException(
@@ -1005,35 +1110,3 @@ class BernoulliRBM(BaseTransformer):
                 original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
             )
         return self._model_signature_dict
-
-    def to_sklearn(self) -> Any:
-        """Get sklearn.neural_network.BernoulliRBM object.
-        """
-        if self._sklearn_object is None:
-            self._sklearn_object = self._create_sklearn_object()
-        return self._sklearn_object
-
-    def to_xgboost(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_xgboost()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def to_lightgbm(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_lightgbm()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def _get_dependencies(self) -> List[str]:
-        return self._deps