snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class RBFSampler(BaseTransformer):
|
71
64
|
r"""Approximate a RBF kernel feature map using random Fourier features
|
72
65
|
For more details on this class, see [sklearn.kernel_approximation.RBFSampler]
|
@@ -210,12 +203,7 @@ class RBFSampler(BaseTransformer):
|
|
210
203
|
)
|
211
204
|
return selected_cols
|
212
205
|
|
213
|
-
|
214
|
-
project=_PROJECT,
|
215
|
-
subproject=_SUBPROJECT,
|
216
|
-
custom_tags=dict([("autogen", True)]),
|
217
|
-
)
|
218
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "RBFSampler":
|
206
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "RBFSampler":
|
219
207
|
"""Fit the model with X
|
220
208
|
For more details on this function, see [sklearn.kernel_approximation.RBFSampler.fit]
|
221
209
|
(https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.RBFSampler.html#sklearn.kernel_approximation.RBFSampler.fit)
|
@@ -242,12 +230,14 @@ class RBFSampler(BaseTransformer):
|
|
242
230
|
|
243
231
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
244
232
|
|
245
|
-
|
233
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
246
234
|
if SNOWML_SPROC_ENV in os.environ:
|
247
235
|
statement_params = telemetry.get_function_usage_statement_params(
|
248
236
|
project=_PROJECT,
|
249
237
|
subproject=_SUBPROJECT,
|
250
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
238
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
239
|
+
inspect.currentframe(), RBFSampler.__class__.__name__
|
240
|
+
),
|
251
241
|
api_calls=[Session.call],
|
252
242
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
253
243
|
)
|
@@ -268,27 +258,24 @@ class RBFSampler(BaseTransformer):
|
|
268
258
|
)
|
269
259
|
self._sklearn_object = model_trainer.train()
|
270
260
|
self._is_fitted = True
|
271
|
-
self.
|
261
|
+
self._generate_model_signatures(dataset)
|
272
262
|
return self
|
273
263
|
|
274
264
|
def _batch_inference_validate_snowpark(
|
275
265
|
self,
|
276
266
|
dataset: DataFrame,
|
277
267
|
inference_method: str,
|
278
|
-
) ->
|
279
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
280
|
-
return the available package that exists in the snowflake anaconda channel
|
268
|
+
) -> None:
|
269
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
281
270
|
|
282
271
|
Args:
|
283
272
|
dataset: snowpark dataframe
|
284
273
|
inference_method: the inference method such as predict, score...
|
285
|
-
|
274
|
+
|
286
275
|
Raises:
|
287
276
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
288
277
|
SnowflakeMLException: If the session is None, raise error
|
289
278
|
|
290
|
-
Returns:
|
291
|
-
A list of available package that exists in the snowflake anaconda channel
|
292
279
|
"""
|
293
280
|
if not self._is_fitted:
|
294
281
|
raise exceptions.SnowflakeMLException(
|
@@ -306,9 +293,7 @@ class RBFSampler(BaseTransformer):
|
|
306
293
|
"Session must not specified for snowpark dataset."
|
307
294
|
),
|
308
295
|
)
|
309
|
-
|
310
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
311
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
296
|
+
|
312
297
|
|
313
298
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
314
299
|
@telemetry.send_api_usage_telemetry(
|
@@ -342,7 +327,9 @@ class RBFSampler(BaseTransformer):
|
|
342
327
|
# when it is classifier, infer the datatype from label columns
|
343
328
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
344
329
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
345
|
-
label_cols_signatures = [
|
330
|
+
label_cols_signatures = [
|
331
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
332
|
+
]
|
346
333
|
if len(label_cols_signatures) == 0:
|
347
334
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
348
335
|
raise exceptions.SnowflakeMLException(
|
@@ -350,25 +337,23 @@ class RBFSampler(BaseTransformer):
|
|
350
337
|
original_exception=ValueError(error_str),
|
351
338
|
)
|
352
339
|
|
353
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
354
|
-
label_cols_signatures[0].as_snowpark_type()
|
355
|
-
)
|
340
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
356
341
|
|
357
|
-
self.
|
358
|
-
|
342
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
343
|
+
self._deps = self._get_dependencies()
|
344
|
+
assert isinstance(
|
345
|
+
dataset._session, Session
|
346
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
359
347
|
|
360
348
|
transform_kwargs = dict(
|
361
|
-
session
|
362
|
-
dependencies
|
363
|
-
drop_input_cols
|
364
|
-
expected_output_cols_type
|
349
|
+
session=dataset._session,
|
350
|
+
dependencies=self._deps,
|
351
|
+
drop_input_cols=self._drop_input_cols,
|
352
|
+
expected_output_cols_type=expected_type_inferred,
|
365
353
|
)
|
366
354
|
|
367
355
|
elif isinstance(dataset, pd.DataFrame):
|
368
|
-
transform_kwargs = dict(
|
369
|
-
snowpark_input_cols = self._snowpark_cols,
|
370
|
-
drop_input_cols = self._drop_input_cols
|
371
|
-
)
|
356
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
372
357
|
|
373
358
|
transform_handlers = ModelTransformerBuilder.build(
|
374
359
|
dataset=dataset,
|
@@ -410,7 +395,7 @@ class RBFSampler(BaseTransformer):
|
|
410
395
|
Transformed dataset.
|
411
396
|
"""
|
412
397
|
super()._check_dataset_type(dataset)
|
413
|
-
inference_method="transform"
|
398
|
+
inference_method = "transform"
|
414
399
|
|
415
400
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
416
401
|
# are specific to the type of dataset used.
|
@@ -440,24 +425,19 @@ class RBFSampler(BaseTransformer):
|
|
440
425
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
441
426
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
442
427
|
|
443
|
-
self.
|
444
|
-
|
445
|
-
inference_method=inference_method,
|
446
|
-
)
|
428
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
429
|
+
self._deps = self._get_dependencies()
|
447
430
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
448
431
|
|
449
432
|
transform_kwargs = dict(
|
450
|
-
session
|
451
|
-
dependencies
|
452
|
-
drop_input_cols
|
453
|
-
expected_output_cols_type
|
433
|
+
session=dataset._session,
|
434
|
+
dependencies=self._deps,
|
435
|
+
drop_input_cols=self._drop_input_cols,
|
436
|
+
expected_output_cols_type=expected_dtype,
|
454
437
|
)
|
455
438
|
|
456
439
|
elif isinstance(dataset, pd.DataFrame):
|
457
|
-
transform_kwargs = dict(
|
458
|
-
snowpark_input_cols = self._snowpark_cols,
|
459
|
-
drop_input_cols = self._drop_input_cols
|
460
|
-
)
|
440
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
461
441
|
|
462
442
|
transform_handlers = ModelTransformerBuilder.build(
|
463
443
|
dataset=dataset,
|
@@ -476,7 +456,11 @@ class RBFSampler(BaseTransformer):
|
|
476
456
|
return output_df
|
477
457
|
|
478
458
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
479
|
-
def fit_predict(
|
459
|
+
def fit_predict(
|
460
|
+
self,
|
461
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
462
|
+
output_cols_prefix: str = "fit_predict_",
|
463
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
480
464
|
""" Method not supported for this class.
|
481
465
|
|
482
466
|
|
@@ -501,22 +485,106 @@ class RBFSampler(BaseTransformer):
|
|
501
485
|
)
|
502
486
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
503
487
|
drop_input_cols=self._drop_input_cols,
|
504
|
-
expected_output_cols_list=
|
488
|
+
expected_output_cols_list=(
|
489
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
490
|
+
),
|
505
491
|
)
|
506
492
|
self._sklearn_object = fitted_estimator
|
507
493
|
self._is_fitted = True
|
508
494
|
return output_result
|
509
495
|
|
496
|
+
|
497
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
498
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
499
|
+
""" Fit to data, then transform it
|
500
|
+
For more details on this function, see [sklearn.kernel_approximation.RBFSampler.fit_transform]
|
501
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.RBFSampler.html#sklearn.kernel_approximation.RBFSampler.fit_transform)
|
502
|
+
|
503
|
+
|
504
|
+
Raises:
|
505
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
510
506
|
|
511
|
-
|
512
|
-
|
513
|
-
|
507
|
+
Args:
|
508
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
509
|
+
Snowpark or Pandas DataFrame.
|
510
|
+
output_cols_prefix: Prefix for the response columns
|
514
511
|
Returns:
|
515
512
|
Transformed dataset.
|
516
513
|
"""
|
517
|
-
self.
|
518
|
-
|
519
|
-
|
514
|
+
self._infer_input_output_cols(dataset)
|
515
|
+
super()._check_dataset_type(dataset)
|
516
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
517
|
+
estimator=self._sklearn_object,
|
518
|
+
dataset=dataset,
|
519
|
+
input_cols=self.input_cols,
|
520
|
+
label_cols=self.label_cols,
|
521
|
+
sample_weight_col=self.sample_weight_col,
|
522
|
+
autogenerated=self._autogenerated,
|
523
|
+
subproject=_SUBPROJECT,
|
524
|
+
)
|
525
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
526
|
+
drop_input_cols=self._drop_input_cols,
|
527
|
+
expected_output_cols_list=self.output_cols,
|
528
|
+
)
|
529
|
+
self._sklearn_object = fitted_estimator
|
530
|
+
self._is_fitted = True
|
531
|
+
return output_result
|
532
|
+
|
533
|
+
|
534
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
535
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
536
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
537
|
+
"""
|
538
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
539
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
540
|
+
if output_cols:
|
541
|
+
output_cols = [
|
542
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
543
|
+
for c in output_cols
|
544
|
+
]
|
545
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
546
|
+
output_cols = [output_cols_prefix]
|
547
|
+
elif self._sklearn_object is not None:
|
548
|
+
classes = self._sklearn_object.classes_
|
549
|
+
if isinstance(classes, numpy.ndarray):
|
550
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
551
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
552
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
553
|
+
output_cols = []
|
554
|
+
for i, cl in enumerate(classes):
|
555
|
+
# For binary classification, there is only one output column for each class
|
556
|
+
# ndarray as the two classes are complementary.
|
557
|
+
if len(cl) == 2:
|
558
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
559
|
+
else:
|
560
|
+
output_cols.extend([
|
561
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
562
|
+
])
|
563
|
+
else:
|
564
|
+
output_cols = []
|
565
|
+
|
566
|
+
# Make sure column names are valid snowflake identifiers.
|
567
|
+
assert output_cols is not None # Make MyPy happy
|
568
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
569
|
+
|
570
|
+
return rv
|
571
|
+
|
572
|
+
def _align_expected_output_names(
|
573
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
574
|
+
) -> List[str]:
|
575
|
+
# in case the inferred output column names dimension is different
|
576
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
577
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
578
|
+
output_df_columns = list(output_df_pd.columns)
|
579
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
580
|
+
if self.sample_weight_col:
|
581
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
582
|
+
# if the dimension of inferred output column names is correct; use it
|
583
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
584
|
+
return expected_output_cols_list
|
585
|
+
# otherwise, use the sklearn estimator's output
|
586
|
+
else:
|
587
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
520
588
|
|
521
589
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
522
590
|
@telemetry.send_api_usage_telemetry(
|
@@ -548,24 +616,26 @@ class RBFSampler(BaseTransformer):
|
|
548
616
|
# are specific to the type of dataset used.
|
549
617
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
550
618
|
|
619
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
620
|
+
|
551
621
|
if isinstance(dataset, DataFrame):
|
552
|
-
self.
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
622
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
623
|
+
self._deps = self._get_dependencies()
|
624
|
+
assert isinstance(
|
625
|
+
dataset._session, Session
|
626
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
557
627
|
transform_kwargs = dict(
|
558
628
|
session=dataset._session,
|
559
629
|
dependencies=self._deps,
|
560
|
-
drop_input_cols
|
630
|
+
drop_input_cols=self._drop_input_cols,
|
561
631
|
expected_output_cols_type="float",
|
562
632
|
)
|
633
|
+
expected_output_cols = self._align_expected_output_names(
|
634
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
635
|
+
)
|
563
636
|
|
564
637
|
elif isinstance(dataset, pd.DataFrame):
|
565
|
-
transform_kwargs = dict(
|
566
|
-
snowpark_input_cols = self._snowpark_cols,
|
567
|
-
drop_input_cols = self._drop_input_cols
|
568
|
-
)
|
638
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
569
639
|
|
570
640
|
transform_handlers = ModelTransformerBuilder.build(
|
571
641
|
dataset=dataset,
|
@@ -577,7 +647,7 @@ class RBFSampler(BaseTransformer):
|
|
577
647
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
578
648
|
inference_method=inference_method,
|
579
649
|
input_cols=self.input_cols,
|
580
|
-
expected_output_cols=
|
650
|
+
expected_output_cols=expected_output_cols,
|
581
651
|
**transform_kwargs
|
582
652
|
)
|
583
653
|
return output_df
|
@@ -607,29 +677,30 @@ class RBFSampler(BaseTransformer):
|
|
607
677
|
Output dataset with log probability of the sample for each class in the model.
|
608
678
|
"""
|
609
679
|
super()._check_dataset_type(dataset)
|
610
|
-
inference_method="predict_log_proba"
|
680
|
+
inference_method = "predict_log_proba"
|
681
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
611
682
|
|
612
683
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
613
684
|
# are specific to the type of dataset used.
|
614
685
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
615
686
|
|
616
687
|
if isinstance(dataset, DataFrame):
|
617
|
-
self.
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
688
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
689
|
+
self._deps = self._get_dependencies()
|
690
|
+
assert isinstance(
|
691
|
+
dataset._session, Session
|
692
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
622
693
|
transform_kwargs = dict(
|
623
694
|
session=dataset._session,
|
624
695
|
dependencies=self._deps,
|
625
|
-
drop_input_cols
|
696
|
+
drop_input_cols=self._drop_input_cols,
|
626
697
|
expected_output_cols_type="float",
|
627
698
|
)
|
699
|
+
expected_output_cols = self._align_expected_output_names(
|
700
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
701
|
+
)
|
628
702
|
elif isinstance(dataset, pd.DataFrame):
|
629
|
-
transform_kwargs = dict(
|
630
|
-
snowpark_input_cols = self._snowpark_cols,
|
631
|
-
drop_input_cols = self._drop_input_cols
|
632
|
-
)
|
703
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
633
704
|
|
634
705
|
transform_handlers = ModelTransformerBuilder.build(
|
635
706
|
dataset=dataset,
|
@@ -642,7 +713,7 @@ class RBFSampler(BaseTransformer):
|
|
642
713
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
643
714
|
inference_method=inference_method,
|
644
715
|
input_cols=self.input_cols,
|
645
|
-
expected_output_cols=
|
716
|
+
expected_output_cols=expected_output_cols,
|
646
717
|
**transform_kwargs
|
647
718
|
)
|
648
719
|
return output_df
|
@@ -668,30 +739,32 @@ class RBFSampler(BaseTransformer):
|
|
668
739
|
Output dataset with results of the decision function for the samples in input dataset.
|
669
740
|
"""
|
670
741
|
super()._check_dataset_type(dataset)
|
671
|
-
inference_method="decision_function"
|
742
|
+
inference_method = "decision_function"
|
672
743
|
|
673
744
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
674
745
|
# are specific to the type of dataset used.
|
675
746
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
676
747
|
|
748
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
749
|
+
|
677
750
|
if isinstance(dataset, DataFrame):
|
678
|
-
self.
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
751
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
752
|
+
self._deps = self._get_dependencies()
|
753
|
+
assert isinstance(
|
754
|
+
dataset._session, Session
|
755
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
683
756
|
transform_kwargs = dict(
|
684
757
|
session=dataset._session,
|
685
758
|
dependencies=self._deps,
|
686
|
-
drop_input_cols
|
759
|
+
drop_input_cols=self._drop_input_cols,
|
687
760
|
expected_output_cols_type="float",
|
688
761
|
)
|
762
|
+
expected_output_cols = self._align_expected_output_names(
|
763
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
764
|
+
)
|
689
765
|
|
690
766
|
elif isinstance(dataset, pd.DataFrame):
|
691
|
-
transform_kwargs = dict(
|
692
|
-
snowpark_input_cols = self._snowpark_cols,
|
693
|
-
drop_input_cols = self._drop_input_cols
|
694
|
-
)
|
767
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
695
768
|
|
696
769
|
transform_handlers = ModelTransformerBuilder.build(
|
697
770
|
dataset=dataset,
|
@@ -704,7 +777,7 @@ class RBFSampler(BaseTransformer):
|
|
704
777
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
705
778
|
inference_method=inference_method,
|
706
779
|
input_cols=self.input_cols,
|
707
|
-
expected_output_cols=
|
780
|
+
expected_output_cols=expected_output_cols,
|
708
781
|
**transform_kwargs
|
709
782
|
)
|
710
783
|
return output_df
|
@@ -733,17 +806,17 @@ class RBFSampler(BaseTransformer):
|
|
733
806
|
Output dataset with probability of the sample for each class in the model.
|
734
807
|
"""
|
735
808
|
super()._check_dataset_type(dataset)
|
736
|
-
inference_method="score_samples"
|
809
|
+
inference_method = "score_samples"
|
737
810
|
|
738
811
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
739
812
|
# are specific to the type of dataset used.
|
740
813
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
741
814
|
|
815
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
816
|
+
|
742
817
|
if isinstance(dataset, DataFrame):
|
743
|
-
self.
|
744
|
-
|
745
|
-
inference_method=inference_method,
|
746
|
-
)
|
818
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
819
|
+
self._deps = self._get_dependencies()
|
747
820
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
748
821
|
transform_kwargs = dict(
|
749
822
|
session=dataset._session,
|
@@ -751,6 +824,9 @@ class RBFSampler(BaseTransformer):
|
|
751
824
|
drop_input_cols = self._drop_input_cols,
|
752
825
|
expected_output_cols_type="float",
|
753
826
|
)
|
827
|
+
expected_output_cols = self._align_expected_output_names(
|
828
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
829
|
+
)
|
754
830
|
|
755
831
|
elif isinstance(dataset, pd.DataFrame):
|
756
832
|
transform_kwargs = dict(
|
@@ -769,7 +845,7 @@ class RBFSampler(BaseTransformer):
|
|
769
845
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
770
846
|
inference_method=inference_method,
|
771
847
|
input_cols=self.input_cols,
|
772
|
-
expected_output_cols=
|
848
|
+
expected_output_cols=expected_output_cols,
|
773
849
|
**transform_kwargs
|
774
850
|
)
|
775
851
|
return output_df
|
@@ -802,17 +878,15 @@ class RBFSampler(BaseTransformer):
|
|
802
878
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
803
879
|
|
804
880
|
if isinstance(dataset, DataFrame):
|
805
|
-
self.
|
806
|
-
|
807
|
-
inference_method="score",
|
808
|
-
)
|
881
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
882
|
+
self._deps = self._get_dependencies()
|
809
883
|
selected_cols = self._get_active_columns()
|
810
884
|
if len(selected_cols) > 0:
|
811
885
|
dataset = dataset.select(selected_cols)
|
812
886
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
813
887
|
transform_kwargs = dict(
|
814
888
|
session=dataset._session,
|
815
|
-
dependencies=
|
889
|
+
dependencies=self._deps,
|
816
890
|
score_sproc_imports=['sklearn'],
|
817
891
|
)
|
818
892
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -877,11 +951,8 @@ class RBFSampler(BaseTransformer):
|
|
877
951
|
|
878
952
|
if isinstance(dataset, DataFrame):
|
879
953
|
|
880
|
-
self.
|
881
|
-
|
882
|
-
inference_method=inference_method,
|
883
|
-
|
884
|
-
)
|
954
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
955
|
+
self._deps = self._get_dependencies()
|
885
956
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
886
957
|
transform_kwargs = dict(
|
887
958
|
session = dataset._session,
|
@@ -914,50 +985,84 @@ class RBFSampler(BaseTransformer):
|
|
914
985
|
)
|
915
986
|
return output_df
|
916
987
|
|
988
|
+
|
989
|
+
|
990
|
+
def to_sklearn(self) -> Any:
|
991
|
+
"""Get sklearn.kernel_approximation.RBFSampler object.
|
992
|
+
"""
|
993
|
+
if self._sklearn_object is None:
|
994
|
+
self._sklearn_object = self._create_sklearn_object()
|
995
|
+
return self._sklearn_object
|
996
|
+
|
997
|
+
def to_xgboost(self) -> Any:
|
998
|
+
raise exceptions.SnowflakeMLException(
|
999
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1000
|
+
original_exception=AttributeError(
|
1001
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1002
|
+
"to_xgboost()",
|
1003
|
+
"to_sklearn()"
|
1004
|
+
)
|
1005
|
+
),
|
1006
|
+
)
|
917
1007
|
|
918
|
-
def
|
1008
|
+
def to_lightgbm(self) -> Any:
|
1009
|
+
raise exceptions.SnowflakeMLException(
|
1010
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1011
|
+
original_exception=AttributeError(
|
1012
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1013
|
+
"to_lightgbm()",
|
1014
|
+
"to_sklearn()"
|
1015
|
+
)
|
1016
|
+
),
|
1017
|
+
)
|
1018
|
+
|
1019
|
+
def _get_dependencies(self) -> List[str]:
|
1020
|
+
return self._deps
|
1021
|
+
|
1022
|
+
|
1023
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
919
1024
|
self._model_signature_dict = dict()
|
920
1025
|
|
921
1026
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
922
1027
|
|
923
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1028
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
924
1029
|
outputs: List[BaseFeatureSpec] = []
|
925
1030
|
if hasattr(self, "predict"):
|
926
1031
|
# keep mypy happy
|
927
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1032
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
928
1033
|
# For classifier, the type of predict is the same as the type of label
|
929
|
-
if self._sklearn_object._estimator_type ==
|
930
|
-
|
1034
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1035
|
+
# label columns is the desired type for output
|
931
1036
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
932
1037
|
# rename the output columns
|
933
1038
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
934
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
935
|
-
|
936
|
-
|
1039
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1040
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1041
|
+
)
|
937
1042
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
938
1043
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
939
|
-
# Clusterer returns int64 cluster labels.
|
1044
|
+
# Clusterer returns int64 cluster labels.
|
940
1045
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
941
1046
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
942
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
943
|
-
|
944
|
-
|
945
|
-
|
1047
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1048
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1049
|
+
)
|
1050
|
+
|
946
1051
|
# For regressor, the type of predict is float64
|
947
|
-
elif self._sklearn_object._estimator_type ==
|
1052
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
948
1053
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
949
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
950
|
-
|
951
|
-
|
952
|
-
|
1054
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1055
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1056
|
+
)
|
1057
|
+
|
953
1058
|
for prob_func in PROB_FUNCTIONS:
|
954
1059
|
if hasattr(self, prob_func):
|
955
1060
|
output_cols_prefix: str = f"{prob_func}_"
|
956
1061
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
957
1062
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
958
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
959
|
-
|
960
|
-
|
1063
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1064
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1065
|
+
)
|
961
1066
|
|
962
1067
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
963
1068
|
items = list(self._model_signature_dict.items())
|
@@ -970,10 +1075,10 @@ class RBFSampler(BaseTransformer):
|
|
970
1075
|
"""Returns model signature of current class.
|
971
1076
|
|
972
1077
|
Raises:
|
973
|
-
|
1078
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
974
1079
|
|
975
1080
|
Returns:
|
976
|
-
Dict
|
1081
|
+
Dict with each method and its input output signature
|
977
1082
|
"""
|
978
1083
|
if self._model_signature_dict is None:
|
979
1084
|
raise exceptions.SnowflakeMLException(
|
@@ -981,35 +1086,3 @@ class RBFSampler(BaseTransformer):
|
|
981
1086
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
982
1087
|
)
|
983
1088
|
return self._model_signature_dict
|
984
|
-
|
985
|
-
def to_sklearn(self) -> Any:
|
986
|
-
"""Get sklearn.kernel_approximation.RBFSampler object.
|
987
|
-
"""
|
988
|
-
if self._sklearn_object is None:
|
989
|
-
self._sklearn_object = self._create_sklearn_object()
|
990
|
-
return self._sklearn_object
|
991
|
-
|
992
|
-
def to_xgboost(self) -> Any:
|
993
|
-
raise exceptions.SnowflakeMLException(
|
994
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
995
|
-
original_exception=AttributeError(
|
996
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
997
|
-
"to_xgboost()",
|
998
|
-
"to_sklearn()"
|
999
|
-
)
|
1000
|
-
),
|
1001
|
-
)
|
1002
|
-
|
1003
|
-
def to_lightgbm(self) -> Any:
|
1004
|
-
raise exceptions.SnowflakeMLException(
|
1005
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1006
|
-
original_exception=AttributeError(
|
1007
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1008
|
-
"to_lightgbm()",
|
1009
|
-
"to_sklearn()"
|
1010
|
-
)
|
1011
|
-
),
|
1012
|
-
)
|
1013
|
-
|
1014
|
-
def _get_dependencies(self) -> List[str]:
|
1015
|
-
return self._deps
|