snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class AdditiveChi2Sampler(BaseTransformer):
|
71
64
|
r"""Approximate feature map for additive chi2 kernel
|
72
65
|
For more details on this class, see [sklearn.kernel_approximation.AdditiveChi2Sampler]
|
@@ -199,12 +192,7 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
199
192
|
)
|
200
193
|
return selected_cols
|
201
194
|
|
202
|
-
|
203
|
-
project=_PROJECT,
|
204
|
-
subproject=_SUBPROJECT,
|
205
|
-
custom_tags=dict([("autogen", True)]),
|
206
|
-
)
|
207
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "AdditiveChi2Sampler":
|
195
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "AdditiveChi2Sampler":
|
208
196
|
"""Only validates estimator's parameters
|
209
197
|
For more details on this function, see [sklearn.kernel_approximation.AdditiveChi2Sampler.fit]
|
210
198
|
(https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.AdditiveChi2Sampler.html#sklearn.kernel_approximation.AdditiveChi2Sampler.fit)
|
@@ -231,12 +219,14 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
231
219
|
|
232
220
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
233
221
|
|
234
|
-
|
222
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
235
223
|
if SNOWML_SPROC_ENV in os.environ:
|
236
224
|
statement_params = telemetry.get_function_usage_statement_params(
|
237
225
|
project=_PROJECT,
|
238
226
|
subproject=_SUBPROJECT,
|
239
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
227
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
228
|
+
inspect.currentframe(), AdditiveChi2Sampler.__class__.__name__
|
229
|
+
),
|
240
230
|
api_calls=[Session.call],
|
241
231
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
242
232
|
)
|
@@ -257,27 +247,24 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
257
247
|
)
|
258
248
|
self._sklearn_object = model_trainer.train()
|
259
249
|
self._is_fitted = True
|
260
|
-
self.
|
250
|
+
self._generate_model_signatures(dataset)
|
261
251
|
return self
|
262
252
|
|
263
253
|
def _batch_inference_validate_snowpark(
|
264
254
|
self,
|
265
255
|
dataset: DataFrame,
|
266
256
|
inference_method: str,
|
267
|
-
) ->
|
268
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
269
|
-
return the available package that exists in the snowflake anaconda channel
|
257
|
+
) -> None:
|
258
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
270
259
|
|
271
260
|
Args:
|
272
261
|
dataset: snowpark dataframe
|
273
262
|
inference_method: the inference method such as predict, score...
|
274
|
-
|
263
|
+
|
275
264
|
Raises:
|
276
265
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
277
266
|
SnowflakeMLException: If the session is None, raise error
|
278
267
|
|
279
|
-
Returns:
|
280
|
-
A list of available package that exists in the snowflake anaconda channel
|
281
268
|
"""
|
282
269
|
if not self._is_fitted:
|
283
270
|
raise exceptions.SnowflakeMLException(
|
@@ -295,9 +282,7 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
295
282
|
"Session must not specified for snowpark dataset."
|
296
283
|
),
|
297
284
|
)
|
298
|
-
|
299
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
300
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
285
|
+
|
301
286
|
|
302
287
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
303
288
|
@telemetry.send_api_usage_telemetry(
|
@@ -331,7 +316,9 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
331
316
|
# when it is classifier, infer the datatype from label columns
|
332
317
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
333
318
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
334
|
-
label_cols_signatures = [
|
319
|
+
label_cols_signatures = [
|
320
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
321
|
+
]
|
335
322
|
if len(label_cols_signatures) == 0:
|
336
323
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
337
324
|
raise exceptions.SnowflakeMLException(
|
@@ -339,25 +326,23 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
339
326
|
original_exception=ValueError(error_str),
|
340
327
|
)
|
341
328
|
|
342
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
343
|
-
label_cols_signatures[0].as_snowpark_type()
|
344
|
-
)
|
329
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
345
330
|
|
346
|
-
self.
|
347
|
-
|
331
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
332
|
+
self._deps = self._get_dependencies()
|
333
|
+
assert isinstance(
|
334
|
+
dataset._session, Session
|
335
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
348
336
|
|
349
337
|
transform_kwargs = dict(
|
350
|
-
session
|
351
|
-
dependencies
|
352
|
-
drop_input_cols
|
353
|
-
expected_output_cols_type
|
338
|
+
session=dataset._session,
|
339
|
+
dependencies=self._deps,
|
340
|
+
drop_input_cols=self._drop_input_cols,
|
341
|
+
expected_output_cols_type=expected_type_inferred,
|
354
342
|
)
|
355
343
|
|
356
344
|
elif isinstance(dataset, pd.DataFrame):
|
357
|
-
transform_kwargs = dict(
|
358
|
-
snowpark_input_cols = self._snowpark_cols,
|
359
|
-
drop_input_cols = self._drop_input_cols
|
360
|
-
)
|
345
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
361
346
|
|
362
347
|
transform_handlers = ModelTransformerBuilder.build(
|
363
348
|
dataset=dataset,
|
@@ -399,7 +384,7 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
399
384
|
Transformed dataset.
|
400
385
|
"""
|
401
386
|
super()._check_dataset_type(dataset)
|
402
|
-
inference_method="transform"
|
387
|
+
inference_method = "transform"
|
403
388
|
|
404
389
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
405
390
|
# are specific to the type of dataset used.
|
@@ -429,24 +414,19 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
429
414
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
430
415
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
431
416
|
|
432
|
-
self.
|
433
|
-
|
434
|
-
inference_method=inference_method,
|
435
|
-
)
|
417
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
418
|
+
self._deps = self._get_dependencies()
|
436
419
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
437
420
|
|
438
421
|
transform_kwargs = dict(
|
439
|
-
session
|
440
|
-
dependencies
|
441
|
-
drop_input_cols
|
442
|
-
expected_output_cols_type
|
422
|
+
session=dataset._session,
|
423
|
+
dependencies=self._deps,
|
424
|
+
drop_input_cols=self._drop_input_cols,
|
425
|
+
expected_output_cols_type=expected_dtype,
|
443
426
|
)
|
444
427
|
|
445
428
|
elif isinstance(dataset, pd.DataFrame):
|
446
|
-
transform_kwargs = dict(
|
447
|
-
snowpark_input_cols = self._snowpark_cols,
|
448
|
-
drop_input_cols = self._drop_input_cols
|
449
|
-
)
|
429
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
450
430
|
|
451
431
|
transform_handlers = ModelTransformerBuilder.build(
|
452
432
|
dataset=dataset,
|
@@ -465,7 +445,11 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
465
445
|
return output_df
|
466
446
|
|
467
447
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
468
|
-
def fit_predict(
|
448
|
+
def fit_predict(
|
449
|
+
self,
|
450
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
451
|
+
output_cols_prefix: str = "fit_predict_",
|
452
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
469
453
|
""" Method not supported for this class.
|
470
454
|
|
471
455
|
|
@@ -490,22 +474,106 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
490
474
|
)
|
491
475
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
492
476
|
drop_input_cols=self._drop_input_cols,
|
493
|
-
expected_output_cols_list=
|
477
|
+
expected_output_cols_list=(
|
478
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
479
|
+
),
|
494
480
|
)
|
495
481
|
self._sklearn_object = fitted_estimator
|
496
482
|
self._is_fitted = True
|
497
483
|
return output_result
|
498
484
|
|
485
|
+
|
486
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
487
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
488
|
+
""" Fit to data, then transform it
|
489
|
+
For more details on this function, see [sklearn.kernel_approximation.AdditiveChi2Sampler.fit_transform]
|
490
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.AdditiveChi2Sampler.html#sklearn.kernel_approximation.AdditiveChi2Sampler.fit_transform)
|
491
|
+
|
492
|
+
|
493
|
+
Raises:
|
494
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
499
495
|
|
500
|
-
|
501
|
-
|
502
|
-
|
496
|
+
Args:
|
497
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
498
|
+
Snowpark or Pandas DataFrame.
|
499
|
+
output_cols_prefix: Prefix for the response columns
|
503
500
|
Returns:
|
504
501
|
Transformed dataset.
|
505
502
|
"""
|
506
|
-
self.
|
507
|
-
|
508
|
-
|
503
|
+
self._infer_input_output_cols(dataset)
|
504
|
+
super()._check_dataset_type(dataset)
|
505
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
506
|
+
estimator=self._sklearn_object,
|
507
|
+
dataset=dataset,
|
508
|
+
input_cols=self.input_cols,
|
509
|
+
label_cols=self.label_cols,
|
510
|
+
sample_weight_col=self.sample_weight_col,
|
511
|
+
autogenerated=self._autogenerated,
|
512
|
+
subproject=_SUBPROJECT,
|
513
|
+
)
|
514
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
515
|
+
drop_input_cols=self._drop_input_cols,
|
516
|
+
expected_output_cols_list=self.output_cols,
|
517
|
+
)
|
518
|
+
self._sklearn_object = fitted_estimator
|
519
|
+
self._is_fitted = True
|
520
|
+
return output_result
|
521
|
+
|
522
|
+
|
523
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
524
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
525
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
526
|
+
"""
|
527
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
528
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
529
|
+
if output_cols:
|
530
|
+
output_cols = [
|
531
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
532
|
+
for c in output_cols
|
533
|
+
]
|
534
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
535
|
+
output_cols = [output_cols_prefix]
|
536
|
+
elif self._sklearn_object is not None:
|
537
|
+
classes = self._sklearn_object.classes_
|
538
|
+
if isinstance(classes, numpy.ndarray):
|
539
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
540
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
541
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
542
|
+
output_cols = []
|
543
|
+
for i, cl in enumerate(classes):
|
544
|
+
# For binary classification, there is only one output column for each class
|
545
|
+
# ndarray as the two classes are complementary.
|
546
|
+
if len(cl) == 2:
|
547
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
548
|
+
else:
|
549
|
+
output_cols.extend([
|
550
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
551
|
+
])
|
552
|
+
else:
|
553
|
+
output_cols = []
|
554
|
+
|
555
|
+
# Make sure column names are valid snowflake identifiers.
|
556
|
+
assert output_cols is not None # Make MyPy happy
|
557
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
558
|
+
|
559
|
+
return rv
|
560
|
+
|
561
|
+
def _align_expected_output_names(
|
562
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
563
|
+
) -> List[str]:
|
564
|
+
# in case the inferred output column names dimension is different
|
565
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
566
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
567
|
+
output_df_columns = list(output_df_pd.columns)
|
568
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
569
|
+
if self.sample_weight_col:
|
570
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
571
|
+
# if the dimension of inferred output column names is correct; use it
|
572
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
573
|
+
return expected_output_cols_list
|
574
|
+
# otherwise, use the sklearn estimator's output
|
575
|
+
else:
|
576
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
509
577
|
|
510
578
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
511
579
|
@telemetry.send_api_usage_telemetry(
|
@@ -537,24 +605,26 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
537
605
|
# are specific to the type of dataset used.
|
538
606
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
539
607
|
|
608
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
609
|
+
|
540
610
|
if isinstance(dataset, DataFrame):
|
541
|
-
self.
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
611
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
612
|
+
self._deps = self._get_dependencies()
|
613
|
+
assert isinstance(
|
614
|
+
dataset._session, Session
|
615
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
546
616
|
transform_kwargs = dict(
|
547
617
|
session=dataset._session,
|
548
618
|
dependencies=self._deps,
|
549
|
-
drop_input_cols
|
619
|
+
drop_input_cols=self._drop_input_cols,
|
550
620
|
expected_output_cols_type="float",
|
551
621
|
)
|
622
|
+
expected_output_cols = self._align_expected_output_names(
|
623
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
624
|
+
)
|
552
625
|
|
553
626
|
elif isinstance(dataset, pd.DataFrame):
|
554
|
-
transform_kwargs = dict(
|
555
|
-
snowpark_input_cols = self._snowpark_cols,
|
556
|
-
drop_input_cols = self._drop_input_cols
|
557
|
-
)
|
627
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
558
628
|
|
559
629
|
transform_handlers = ModelTransformerBuilder.build(
|
560
630
|
dataset=dataset,
|
@@ -566,7 +636,7 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
566
636
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
567
637
|
inference_method=inference_method,
|
568
638
|
input_cols=self.input_cols,
|
569
|
-
expected_output_cols=
|
639
|
+
expected_output_cols=expected_output_cols,
|
570
640
|
**transform_kwargs
|
571
641
|
)
|
572
642
|
return output_df
|
@@ -596,29 +666,30 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
596
666
|
Output dataset with log probability of the sample for each class in the model.
|
597
667
|
"""
|
598
668
|
super()._check_dataset_type(dataset)
|
599
|
-
inference_method="predict_log_proba"
|
669
|
+
inference_method = "predict_log_proba"
|
670
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
600
671
|
|
601
672
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
602
673
|
# are specific to the type of dataset used.
|
603
674
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
604
675
|
|
605
676
|
if isinstance(dataset, DataFrame):
|
606
|
-
self.
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
677
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
678
|
+
self._deps = self._get_dependencies()
|
679
|
+
assert isinstance(
|
680
|
+
dataset._session, Session
|
681
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
611
682
|
transform_kwargs = dict(
|
612
683
|
session=dataset._session,
|
613
684
|
dependencies=self._deps,
|
614
|
-
drop_input_cols
|
685
|
+
drop_input_cols=self._drop_input_cols,
|
615
686
|
expected_output_cols_type="float",
|
616
687
|
)
|
688
|
+
expected_output_cols = self._align_expected_output_names(
|
689
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
690
|
+
)
|
617
691
|
elif isinstance(dataset, pd.DataFrame):
|
618
|
-
transform_kwargs = dict(
|
619
|
-
snowpark_input_cols = self._snowpark_cols,
|
620
|
-
drop_input_cols = self._drop_input_cols
|
621
|
-
)
|
692
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
622
693
|
|
623
694
|
transform_handlers = ModelTransformerBuilder.build(
|
624
695
|
dataset=dataset,
|
@@ -631,7 +702,7 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
631
702
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
632
703
|
inference_method=inference_method,
|
633
704
|
input_cols=self.input_cols,
|
634
|
-
expected_output_cols=
|
705
|
+
expected_output_cols=expected_output_cols,
|
635
706
|
**transform_kwargs
|
636
707
|
)
|
637
708
|
return output_df
|
@@ -657,30 +728,32 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
657
728
|
Output dataset with results of the decision function for the samples in input dataset.
|
658
729
|
"""
|
659
730
|
super()._check_dataset_type(dataset)
|
660
|
-
inference_method="decision_function"
|
731
|
+
inference_method = "decision_function"
|
661
732
|
|
662
733
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
663
734
|
# are specific to the type of dataset used.
|
664
735
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
665
736
|
|
737
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
738
|
+
|
666
739
|
if isinstance(dataset, DataFrame):
|
667
|
-
self.
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
740
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
741
|
+
self._deps = self._get_dependencies()
|
742
|
+
assert isinstance(
|
743
|
+
dataset._session, Session
|
744
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
672
745
|
transform_kwargs = dict(
|
673
746
|
session=dataset._session,
|
674
747
|
dependencies=self._deps,
|
675
|
-
drop_input_cols
|
748
|
+
drop_input_cols=self._drop_input_cols,
|
676
749
|
expected_output_cols_type="float",
|
677
750
|
)
|
751
|
+
expected_output_cols = self._align_expected_output_names(
|
752
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
753
|
+
)
|
678
754
|
|
679
755
|
elif isinstance(dataset, pd.DataFrame):
|
680
|
-
transform_kwargs = dict(
|
681
|
-
snowpark_input_cols = self._snowpark_cols,
|
682
|
-
drop_input_cols = self._drop_input_cols
|
683
|
-
)
|
756
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
684
757
|
|
685
758
|
transform_handlers = ModelTransformerBuilder.build(
|
686
759
|
dataset=dataset,
|
@@ -693,7 +766,7 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
693
766
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
694
767
|
inference_method=inference_method,
|
695
768
|
input_cols=self.input_cols,
|
696
|
-
expected_output_cols=
|
769
|
+
expected_output_cols=expected_output_cols,
|
697
770
|
**transform_kwargs
|
698
771
|
)
|
699
772
|
return output_df
|
@@ -722,17 +795,17 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
722
795
|
Output dataset with probability of the sample for each class in the model.
|
723
796
|
"""
|
724
797
|
super()._check_dataset_type(dataset)
|
725
|
-
inference_method="score_samples"
|
798
|
+
inference_method = "score_samples"
|
726
799
|
|
727
800
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
728
801
|
# are specific to the type of dataset used.
|
729
802
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
730
803
|
|
804
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
805
|
+
|
731
806
|
if isinstance(dataset, DataFrame):
|
732
|
-
self.
|
733
|
-
|
734
|
-
inference_method=inference_method,
|
735
|
-
)
|
807
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
808
|
+
self._deps = self._get_dependencies()
|
736
809
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
737
810
|
transform_kwargs = dict(
|
738
811
|
session=dataset._session,
|
@@ -740,6 +813,9 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
740
813
|
drop_input_cols = self._drop_input_cols,
|
741
814
|
expected_output_cols_type="float",
|
742
815
|
)
|
816
|
+
expected_output_cols = self._align_expected_output_names(
|
817
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
818
|
+
)
|
743
819
|
|
744
820
|
elif isinstance(dataset, pd.DataFrame):
|
745
821
|
transform_kwargs = dict(
|
@@ -758,7 +834,7 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
758
834
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
759
835
|
inference_method=inference_method,
|
760
836
|
input_cols=self.input_cols,
|
761
|
-
expected_output_cols=
|
837
|
+
expected_output_cols=expected_output_cols,
|
762
838
|
**transform_kwargs
|
763
839
|
)
|
764
840
|
return output_df
|
@@ -791,17 +867,15 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
791
867
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
792
868
|
|
793
869
|
if isinstance(dataset, DataFrame):
|
794
|
-
self.
|
795
|
-
|
796
|
-
inference_method="score",
|
797
|
-
)
|
870
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
871
|
+
self._deps = self._get_dependencies()
|
798
872
|
selected_cols = self._get_active_columns()
|
799
873
|
if len(selected_cols) > 0:
|
800
874
|
dataset = dataset.select(selected_cols)
|
801
875
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
802
876
|
transform_kwargs = dict(
|
803
877
|
session=dataset._session,
|
804
|
-
dependencies=
|
878
|
+
dependencies=self._deps,
|
805
879
|
score_sproc_imports=['sklearn'],
|
806
880
|
)
|
807
881
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -866,11 +940,8 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
866
940
|
|
867
941
|
if isinstance(dataset, DataFrame):
|
868
942
|
|
869
|
-
self.
|
870
|
-
|
871
|
-
inference_method=inference_method,
|
872
|
-
|
873
|
-
)
|
943
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
944
|
+
self._deps = self._get_dependencies()
|
874
945
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
875
946
|
transform_kwargs = dict(
|
876
947
|
session = dataset._session,
|
@@ -903,50 +974,84 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
903
974
|
)
|
904
975
|
return output_df
|
905
976
|
|
977
|
+
|
978
|
+
|
979
|
+
def to_sklearn(self) -> Any:
|
980
|
+
"""Get sklearn.kernel_approximation.AdditiveChi2Sampler object.
|
981
|
+
"""
|
982
|
+
if self._sklearn_object is None:
|
983
|
+
self._sklearn_object = self._create_sklearn_object()
|
984
|
+
return self._sklearn_object
|
985
|
+
|
986
|
+
def to_xgboost(self) -> Any:
|
987
|
+
raise exceptions.SnowflakeMLException(
|
988
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
989
|
+
original_exception=AttributeError(
|
990
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
991
|
+
"to_xgboost()",
|
992
|
+
"to_sklearn()"
|
993
|
+
)
|
994
|
+
),
|
995
|
+
)
|
906
996
|
|
907
|
-
def
|
997
|
+
def to_lightgbm(self) -> Any:
|
998
|
+
raise exceptions.SnowflakeMLException(
|
999
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1000
|
+
original_exception=AttributeError(
|
1001
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1002
|
+
"to_lightgbm()",
|
1003
|
+
"to_sklearn()"
|
1004
|
+
)
|
1005
|
+
),
|
1006
|
+
)
|
1007
|
+
|
1008
|
+
def _get_dependencies(self) -> List[str]:
|
1009
|
+
return self._deps
|
1010
|
+
|
1011
|
+
|
1012
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
908
1013
|
self._model_signature_dict = dict()
|
909
1014
|
|
910
1015
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
911
1016
|
|
912
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1017
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
913
1018
|
outputs: List[BaseFeatureSpec] = []
|
914
1019
|
if hasattr(self, "predict"):
|
915
1020
|
# keep mypy happy
|
916
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1021
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
917
1022
|
# For classifier, the type of predict is the same as the type of label
|
918
|
-
if self._sklearn_object._estimator_type ==
|
919
|
-
|
1023
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1024
|
+
# label columns is the desired type for output
|
920
1025
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
921
1026
|
# rename the output columns
|
922
1027
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
923
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
924
|
-
|
925
|
-
|
1028
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1029
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1030
|
+
)
|
926
1031
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
927
1032
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
928
|
-
# Clusterer returns int64 cluster labels.
|
1033
|
+
# Clusterer returns int64 cluster labels.
|
929
1034
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
930
1035
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
931
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
932
|
-
|
933
|
-
|
934
|
-
|
1036
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1037
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1038
|
+
)
|
1039
|
+
|
935
1040
|
# For regressor, the type of predict is float64
|
936
|
-
elif self._sklearn_object._estimator_type ==
|
1041
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
937
1042
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
938
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
939
|
-
|
940
|
-
|
941
|
-
|
1043
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1044
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1045
|
+
)
|
1046
|
+
|
942
1047
|
for prob_func in PROB_FUNCTIONS:
|
943
1048
|
if hasattr(self, prob_func):
|
944
1049
|
output_cols_prefix: str = f"{prob_func}_"
|
945
1050
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
946
1051
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
947
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
948
|
-
|
949
|
-
|
1052
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1053
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1054
|
+
)
|
950
1055
|
|
951
1056
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
952
1057
|
items = list(self._model_signature_dict.items())
|
@@ -959,10 +1064,10 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
959
1064
|
"""Returns model signature of current class.
|
960
1065
|
|
961
1066
|
Raises:
|
962
|
-
|
1067
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
963
1068
|
|
964
1069
|
Returns:
|
965
|
-
Dict
|
1070
|
+
Dict with each method and its input output signature
|
966
1071
|
"""
|
967
1072
|
if self._model_signature_dict is None:
|
968
1073
|
raise exceptions.SnowflakeMLException(
|
@@ -970,35 +1075,3 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
970
1075
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
971
1076
|
)
|
972
1077
|
return self._model_signature_dict
|
973
|
-
|
974
|
-
def to_sklearn(self) -> Any:
|
975
|
-
"""Get sklearn.kernel_approximation.AdditiveChi2Sampler object.
|
976
|
-
"""
|
977
|
-
if self._sklearn_object is None:
|
978
|
-
self._sklearn_object = self._create_sklearn_object()
|
979
|
-
return self._sklearn_object
|
980
|
-
|
981
|
-
def to_xgboost(self) -> Any:
|
982
|
-
raise exceptions.SnowflakeMLException(
|
983
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
984
|
-
original_exception=AttributeError(
|
985
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
986
|
-
"to_xgboost()",
|
987
|
-
"to_sklearn()"
|
988
|
-
)
|
989
|
-
),
|
990
|
-
)
|
991
|
-
|
992
|
-
def to_lightgbm(self) -> Any:
|
993
|
-
raise exceptions.SnowflakeMLException(
|
994
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
995
|
-
original_exception=AttributeError(
|
996
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
997
|
-
"to_lightgbm()",
|
998
|
-
"to_sklearn()"
|
999
|
-
)
|
1000
|
-
),
|
1001
|
-
)
|
1002
|
-
|
1003
|
-
def _get_dependencies(self) -> List[str]:
|
1004
|
-
return self._deps
|