snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class MissingIndicator(BaseTransformer):
|
71
64
|
r"""Binary indicators for missing values
|
72
65
|
For more details on this class, see [sklearn.impute.MissingIndicator]
|
@@ -224,12 +217,7 @@ class MissingIndicator(BaseTransformer):
|
|
224
217
|
)
|
225
218
|
return selected_cols
|
226
219
|
|
227
|
-
|
228
|
-
project=_PROJECT,
|
229
|
-
subproject=_SUBPROJECT,
|
230
|
-
custom_tags=dict([("autogen", True)]),
|
231
|
-
)
|
232
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "MissingIndicator":
|
220
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "MissingIndicator":
|
233
221
|
"""Fit the transformer on `X`
|
234
222
|
For more details on this function, see [sklearn.impute.MissingIndicator.fit]
|
235
223
|
(https://scikit-learn.org/stable/modules/generated/sklearn.impute.MissingIndicator.html#sklearn.impute.MissingIndicator.fit)
|
@@ -256,12 +244,14 @@ class MissingIndicator(BaseTransformer):
|
|
256
244
|
|
257
245
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
258
246
|
|
259
|
-
|
247
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
260
248
|
if SNOWML_SPROC_ENV in os.environ:
|
261
249
|
statement_params = telemetry.get_function_usage_statement_params(
|
262
250
|
project=_PROJECT,
|
263
251
|
subproject=_SUBPROJECT,
|
264
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
252
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
253
|
+
inspect.currentframe(), MissingIndicator.__class__.__name__
|
254
|
+
),
|
265
255
|
api_calls=[Session.call],
|
266
256
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
267
257
|
)
|
@@ -282,27 +272,24 @@ class MissingIndicator(BaseTransformer):
|
|
282
272
|
)
|
283
273
|
self._sklearn_object = model_trainer.train()
|
284
274
|
self._is_fitted = True
|
285
|
-
self.
|
275
|
+
self._generate_model_signatures(dataset)
|
286
276
|
return self
|
287
277
|
|
288
278
|
def _batch_inference_validate_snowpark(
|
289
279
|
self,
|
290
280
|
dataset: DataFrame,
|
291
281
|
inference_method: str,
|
292
|
-
) ->
|
293
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
294
|
-
return the available package that exists in the snowflake anaconda channel
|
282
|
+
) -> None:
|
283
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
295
284
|
|
296
285
|
Args:
|
297
286
|
dataset: snowpark dataframe
|
298
287
|
inference_method: the inference method such as predict, score...
|
299
|
-
|
288
|
+
|
300
289
|
Raises:
|
301
290
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
302
291
|
SnowflakeMLException: If the session is None, raise error
|
303
292
|
|
304
|
-
Returns:
|
305
|
-
A list of available package that exists in the snowflake anaconda channel
|
306
293
|
"""
|
307
294
|
if not self._is_fitted:
|
308
295
|
raise exceptions.SnowflakeMLException(
|
@@ -320,9 +307,7 @@ class MissingIndicator(BaseTransformer):
|
|
320
307
|
"Session must not specified for snowpark dataset."
|
321
308
|
),
|
322
309
|
)
|
323
|
-
|
324
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
325
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
310
|
+
|
326
311
|
|
327
312
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
328
313
|
@telemetry.send_api_usage_telemetry(
|
@@ -356,7 +341,9 @@ class MissingIndicator(BaseTransformer):
|
|
356
341
|
# when it is classifier, infer the datatype from label columns
|
357
342
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
358
343
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
359
|
-
label_cols_signatures = [
|
344
|
+
label_cols_signatures = [
|
345
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
346
|
+
]
|
360
347
|
if len(label_cols_signatures) == 0:
|
361
348
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
362
349
|
raise exceptions.SnowflakeMLException(
|
@@ -364,25 +351,23 @@ class MissingIndicator(BaseTransformer):
|
|
364
351
|
original_exception=ValueError(error_str),
|
365
352
|
)
|
366
353
|
|
367
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
368
|
-
label_cols_signatures[0].as_snowpark_type()
|
369
|
-
)
|
354
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
370
355
|
|
371
|
-
self.
|
372
|
-
|
356
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
357
|
+
self._deps = self._get_dependencies()
|
358
|
+
assert isinstance(
|
359
|
+
dataset._session, Session
|
360
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
373
361
|
|
374
362
|
transform_kwargs = dict(
|
375
|
-
session
|
376
|
-
dependencies
|
377
|
-
drop_input_cols
|
378
|
-
expected_output_cols_type
|
363
|
+
session=dataset._session,
|
364
|
+
dependencies=self._deps,
|
365
|
+
drop_input_cols=self._drop_input_cols,
|
366
|
+
expected_output_cols_type=expected_type_inferred,
|
379
367
|
)
|
380
368
|
|
381
369
|
elif isinstance(dataset, pd.DataFrame):
|
382
|
-
transform_kwargs = dict(
|
383
|
-
snowpark_input_cols = self._snowpark_cols,
|
384
|
-
drop_input_cols = self._drop_input_cols
|
385
|
-
)
|
370
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
386
371
|
|
387
372
|
transform_handlers = ModelTransformerBuilder.build(
|
388
373
|
dataset=dataset,
|
@@ -424,7 +409,7 @@ class MissingIndicator(BaseTransformer):
|
|
424
409
|
Transformed dataset.
|
425
410
|
"""
|
426
411
|
super()._check_dataset_type(dataset)
|
427
|
-
inference_method="transform"
|
412
|
+
inference_method = "transform"
|
428
413
|
|
429
414
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
430
415
|
# are specific to the type of dataset used.
|
@@ -454,24 +439,19 @@ class MissingIndicator(BaseTransformer):
|
|
454
439
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
455
440
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
456
441
|
|
457
|
-
self.
|
458
|
-
|
459
|
-
inference_method=inference_method,
|
460
|
-
)
|
442
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
443
|
+
self._deps = self._get_dependencies()
|
461
444
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
462
445
|
|
463
446
|
transform_kwargs = dict(
|
464
|
-
session
|
465
|
-
dependencies
|
466
|
-
drop_input_cols
|
467
|
-
expected_output_cols_type
|
447
|
+
session=dataset._session,
|
448
|
+
dependencies=self._deps,
|
449
|
+
drop_input_cols=self._drop_input_cols,
|
450
|
+
expected_output_cols_type=expected_dtype,
|
468
451
|
)
|
469
452
|
|
470
453
|
elif isinstance(dataset, pd.DataFrame):
|
471
|
-
transform_kwargs = dict(
|
472
|
-
snowpark_input_cols = self._snowpark_cols,
|
473
|
-
drop_input_cols = self._drop_input_cols
|
474
|
-
)
|
454
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
475
455
|
|
476
456
|
transform_handlers = ModelTransformerBuilder.build(
|
477
457
|
dataset=dataset,
|
@@ -490,7 +470,11 @@ class MissingIndicator(BaseTransformer):
|
|
490
470
|
return output_df
|
491
471
|
|
492
472
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
493
|
-
def fit_predict(
|
473
|
+
def fit_predict(
|
474
|
+
self,
|
475
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
476
|
+
output_cols_prefix: str = "fit_predict_",
|
477
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
494
478
|
""" Method not supported for this class.
|
495
479
|
|
496
480
|
|
@@ -515,22 +499,106 @@ class MissingIndicator(BaseTransformer):
|
|
515
499
|
)
|
516
500
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
517
501
|
drop_input_cols=self._drop_input_cols,
|
518
|
-
expected_output_cols_list=
|
502
|
+
expected_output_cols_list=(
|
503
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
504
|
+
),
|
519
505
|
)
|
520
506
|
self._sklearn_object = fitted_estimator
|
521
507
|
self._is_fitted = True
|
522
508
|
return output_result
|
523
509
|
|
510
|
+
|
511
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
512
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
513
|
+
""" Generate missing values indicator for `X`
|
514
|
+
For more details on this function, see [sklearn.impute.MissingIndicator.fit_transform]
|
515
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.impute.MissingIndicator.html#sklearn.impute.MissingIndicator.fit_transform)
|
516
|
+
|
517
|
+
|
518
|
+
Raises:
|
519
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
524
520
|
|
525
|
-
|
526
|
-
|
527
|
-
|
521
|
+
Args:
|
522
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
523
|
+
Snowpark or Pandas DataFrame.
|
524
|
+
output_cols_prefix: Prefix for the response columns
|
528
525
|
Returns:
|
529
526
|
Transformed dataset.
|
530
527
|
"""
|
531
|
-
self.
|
532
|
-
|
533
|
-
|
528
|
+
self._infer_input_output_cols(dataset)
|
529
|
+
super()._check_dataset_type(dataset)
|
530
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
531
|
+
estimator=self._sklearn_object,
|
532
|
+
dataset=dataset,
|
533
|
+
input_cols=self.input_cols,
|
534
|
+
label_cols=self.label_cols,
|
535
|
+
sample_weight_col=self.sample_weight_col,
|
536
|
+
autogenerated=self._autogenerated,
|
537
|
+
subproject=_SUBPROJECT,
|
538
|
+
)
|
539
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
540
|
+
drop_input_cols=self._drop_input_cols,
|
541
|
+
expected_output_cols_list=self.output_cols,
|
542
|
+
)
|
543
|
+
self._sklearn_object = fitted_estimator
|
544
|
+
self._is_fitted = True
|
545
|
+
return output_result
|
546
|
+
|
547
|
+
|
548
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
549
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
550
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
551
|
+
"""
|
552
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
553
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
554
|
+
if output_cols:
|
555
|
+
output_cols = [
|
556
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
557
|
+
for c in output_cols
|
558
|
+
]
|
559
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
560
|
+
output_cols = [output_cols_prefix]
|
561
|
+
elif self._sklearn_object is not None:
|
562
|
+
classes = self._sklearn_object.classes_
|
563
|
+
if isinstance(classes, numpy.ndarray):
|
564
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
565
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
566
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
567
|
+
output_cols = []
|
568
|
+
for i, cl in enumerate(classes):
|
569
|
+
# For binary classification, there is only one output column for each class
|
570
|
+
# ndarray as the two classes are complementary.
|
571
|
+
if len(cl) == 2:
|
572
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
573
|
+
else:
|
574
|
+
output_cols.extend([
|
575
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
576
|
+
])
|
577
|
+
else:
|
578
|
+
output_cols = []
|
579
|
+
|
580
|
+
# Make sure column names are valid snowflake identifiers.
|
581
|
+
assert output_cols is not None # Make MyPy happy
|
582
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
583
|
+
|
584
|
+
return rv
|
585
|
+
|
586
|
+
def _align_expected_output_names(
|
587
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
588
|
+
) -> List[str]:
|
589
|
+
# in case the inferred output column names dimension is different
|
590
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
591
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
592
|
+
output_df_columns = list(output_df_pd.columns)
|
593
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
594
|
+
if self.sample_weight_col:
|
595
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
596
|
+
# if the dimension of inferred output column names is correct; use it
|
597
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
598
|
+
return expected_output_cols_list
|
599
|
+
# otherwise, use the sklearn estimator's output
|
600
|
+
else:
|
601
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
534
602
|
|
535
603
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
536
604
|
@telemetry.send_api_usage_telemetry(
|
@@ -562,24 +630,26 @@ class MissingIndicator(BaseTransformer):
|
|
562
630
|
# are specific to the type of dataset used.
|
563
631
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
564
632
|
|
633
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
634
|
+
|
565
635
|
if isinstance(dataset, DataFrame):
|
566
|
-
self.
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
636
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
637
|
+
self._deps = self._get_dependencies()
|
638
|
+
assert isinstance(
|
639
|
+
dataset._session, Session
|
640
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
571
641
|
transform_kwargs = dict(
|
572
642
|
session=dataset._session,
|
573
643
|
dependencies=self._deps,
|
574
|
-
drop_input_cols
|
644
|
+
drop_input_cols=self._drop_input_cols,
|
575
645
|
expected_output_cols_type="float",
|
576
646
|
)
|
647
|
+
expected_output_cols = self._align_expected_output_names(
|
648
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
649
|
+
)
|
577
650
|
|
578
651
|
elif isinstance(dataset, pd.DataFrame):
|
579
|
-
transform_kwargs = dict(
|
580
|
-
snowpark_input_cols = self._snowpark_cols,
|
581
|
-
drop_input_cols = self._drop_input_cols
|
582
|
-
)
|
652
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
583
653
|
|
584
654
|
transform_handlers = ModelTransformerBuilder.build(
|
585
655
|
dataset=dataset,
|
@@ -591,7 +661,7 @@ class MissingIndicator(BaseTransformer):
|
|
591
661
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
592
662
|
inference_method=inference_method,
|
593
663
|
input_cols=self.input_cols,
|
594
|
-
expected_output_cols=
|
664
|
+
expected_output_cols=expected_output_cols,
|
595
665
|
**transform_kwargs
|
596
666
|
)
|
597
667
|
return output_df
|
@@ -621,29 +691,30 @@ class MissingIndicator(BaseTransformer):
|
|
621
691
|
Output dataset with log probability of the sample for each class in the model.
|
622
692
|
"""
|
623
693
|
super()._check_dataset_type(dataset)
|
624
|
-
inference_method="predict_log_proba"
|
694
|
+
inference_method = "predict_log_proba"
|
695
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
625
696
|
|
626
697
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
627
698
|
# are specific to the type of dataset used.
|
628
699
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
629
700
|
|
630
701
|
if isinstance(dataset, DataFrame):
|
631
|
-
self.
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
702
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
703
|
+
self._deps = self._get_dependencies()
|
704
|
+
assert isinstance(
|
705
|
+
dataset._session, Session
|
706
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
636
707
|
transform_kwargs = dict(
|
637
708
|
session=dataset._session,
|
638
709
|
dependencies=self._deps,
|
639
|
-
drop_input_cols
|
710
|
+
drop_input_cols=self._drop_input_cols,
|
640
711
|
expected_output_cols_type="float",
|
641
712
|
)
|
713
|
+
expected_output_cols = self._align_expected_output_names(
|
714
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
715
|
+
)
|
642
716
|
elif isinstance(dataset, pd.DataFrame):
|
643
|
-
transform_kwargs = dict(
|
644
|
-
snowpark_input_cols = self._snowpark_cols,
|
645
|
-
drop_input_cols = self._drop_input_cols
|
646
|
-
)
|
717
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
647
718
|
|
648
719
|
transform_handlers = ModelTransformerBuilder.build(
|
649
720
|
dataset=dataset,
|
@@ -656,7 +727,7 @@ class MissingIndicator(BaseTransformer):
|
|
656
727
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
657
728
|
inference_method=inference_method,
|
658
729
|
input_cols=self.input_cols,
|
659
|
-
expected_output_cols=
|
730
|
+
expected_output_cols=expected_output_cols,
|
660
731
|
**transform_kwargs
|
661
732
|
)
|
662
733
|
return output_df
|
@@ -682,30 +753,32 @@ class MissingIndicator(BaseTransformer):
|
|
682
753
|
Output dataset with results of the decision function for the samples in input dataset.
|
683
754
|
"""
|
684
755
|
super()._check_dataset_type(dataset)
|
685
|
-
inference_method="decision_function"
|
756
|
+
inference_method = "decision_function"
|
686
757
|
|
687
758
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
688
759
|
# are specific to the type of dataset used.
|
689
760
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
690
761
|
|
762
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
763
|
+
|
691
764
|
if isinstance(dataset, DataFrame):
|
692
|
-
self.
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
765
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
766
|
+
self._deps = self._get_dependencies()
|
767
|
+
assert isinstance(
|
768
|
+
dataset._session, Session
|
769
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
697
770
|
transform_kwargs = dict(
|
698
771
|
session=dataset._session,
|
699
772
|
dependencies=self._deps,
|
700
|
-
drop_input_cols
|
773
|
+
drop_input_cols=self._drop_input_cols,
|
701
774
|
expected_output_cols_type="float",
|
702
775
|
)
|
776
|
+
expected_output_cols = self._align_expected_output_names(
|
777
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
778
|
+
)
|
703
779
|
|
704
780
|
elif isinstance(dataset, pd.DataFrame):
|
705
|
-
transform_kwargs = dict(
|
706
|
-
snowpark_input_cols = self._snowpark_cols,
|
707
|
-
drop_input_cols = self._drop_input_cols
|
708
|
-
)
|
781
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
709
782
|
|
710
783
|
transform_handlers = ModelTransformerBuilder.build(
|
711
784
|
dataset=dataset,
|
@@ -718,7 +791,7 @@ class MissingIndicator(BaseTransformer):
|
|
718
791
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
719
792
|
inference_method=inference_method,
|
720
793
|
input_cols=self.input_cols,
|
721
|
-
expected_output_cols=
|
794
|
+
expected_output_cols=expected_output_cols,
|
722
795
|
**transform_kwargs
|
723
796
|
)
|
724
797
|
return output_df
|
@@ -747,17 +820,17 @@ class MissingIndicator(BaseTransformer):
|
|
747
820
|
Output dataset with probability of the sample for each class in the model.
|
748
821
|
"""
|
749
822
|
super()._check_dataset_type(dataset)
|
750
|
-
inference_method="score_samples"
|
823
|
+
inference_method = "score_samples"
|
751
824
|
|
752
825
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
753
826
|
# are specific to the type of dataset used.
|
754
827
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
755
828
|
|
829
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
830
|
+
|
756
831
|
if isinstance(dataset, DataFrame):
|
757
|
-
self.
|
758
|
-
|
759
|
-
inference_method=inference_method,
|
760
|
-
)
|
832
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
833
|
+
self._deps = self._get_dependencies()
|
761
834
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
762
835
|
transform_kwargs = dict(
|
763
836
|
session=dataset._session,
|
@@ -765,6 +838,9 @@ class MissingIndicator(BaseTransformer):
|
|
765
838
|
drop_input_cols = self._drop_input_cols,
|
766
839
|
expected_output_cols_type="float",
|
767
840
|
)
|
841
|
+
expected_output_cols = self._align_expected_output_names(
|
842
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
843
|
+
)
|
768
844
|
|
769
845
|
elif isinstance(dataset, pd.DataFrame):
|
770
846
|
transform_kwargs = dict(
|
@@ -783,7 +859,7 @@ class MissingIndicator(BaseTransformer):
|
|
783
859
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
784
860
|
inference_method=inference_method,
|
785
861
|
input_cols=self.input_cols,
|
786
|
-
expected_output_cols=
|
862
|
+
expected_output_cols=expected_output_cols,
|
787
863
|
**transform_kwargs
|
788
864
|
)
|
789
865
|
return output_df
|
@@ -816,17 +892,15 @@ class MissingIndicator(BaseTransformer):
|
|
816
892
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
817
893
|
|
818
894
|
if isinstance(dataset, DataFrame):
|
819
|
-
self.
|
820
|
-
|
821
|
-
inference_method="score",
|
822
|
-
)
|
895
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
896
|
+
self._deps = self._get_dependencies()
|
823
897
|
selected_cols = self._get_active_columns()
|
824
898
|
if len(selected_cols) > 0:
|
825
899
|
dataset = dataset.select(selected_cols)
|
826
900
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
827
901
|
transform_kwargs = dict(
|
828
902
|
session=dataset._session,
|
829
|
-
dependencies=
|
903
|
+
dependencies=self._deps,
|
830
904
|
score_sproc_imports=['sklearn'],
|
831
905
|
)
|
832
906
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -891,11 +965,8 @@ class MissingIndicator(BaseTransformer):
|
|
891
965
|
|
892
966
|
if isinstance(dataset, DataFrame):
|
893
967
|
|
894
|
-
self.
|
895
|
-
|
896
|
-
inference_method=inference_method,
|
897
|
-
|
898
|
-
)
|
968
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
969
|
+
self._deps = self._get_dependencies()
|
899
970
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
900
971
|
transform_kwargs = dict(
|
901
972
|
session = dataset._session,
|
@@ -928,50 +999,84 @@ class MissingIndicator(BaseTransformer):
|
|
928
999
|
)
|
929
1000
|
return output_df
|
930
1001
|
|
1002
|
+
|
1003
|
+
|
1004
|
+
def to_sklearn(self) -> Any:
|
1005
|
+
"""Get sklearn.impute.MissingIndicator object.
|
1006
|
+
"""
|
1007
|
+
if self._sklearn_object is None:
|
1008
|
+
self._sklearn_object = self._create_sklearn_object()
|
1009
|
+
return self._sklearn_object
|
1010
|
+
|
1011
|
+
def to_xgboost(self) -> Any:
|
1012
|
+
raise exceptions.SnowflakeMLException(
|
1013
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1014
|
+
original_exception=AttributeError(
|
1015
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1016
|
+
"to_xgboost()",
|
1017
|
+
"to_sklearn()"
|
1018
|
+
)
|
1019
|
+
),
|
1020
|
+
)
|
931
1021
|
|
932
|
-
def
|
1022
|
+
def to_lightgbm(self) -> Any:
|
1023
|
+
raise exceptions.SnowflakeMLException(
|
1024
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1025
|
+
original_exception=AttributeError(
|
1026
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1027
|
+
"to_lightgbm()",
|
1028
|
+
"to_sklearn()"
|
1029
|
+
)
|
1030
|
+
),
|
1031
|
+
)
|
1032
|
+
|
1033
|
+
def _get_dependencies(self) -> List[str]:
|
1034
|
+
return self._deps
|
1035
|
+
|
1036
|
+
|
1037
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
933
1038
|
self._model_signature_dict = dict()
|
934
1039
|
|
935
1040
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
936
1041
|
|
937
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1042
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
938
1043
|
outputs: List[BaseFeatureSpec] = []
|
939
1044
|
if hasattr(self, "predict"):
|
940
1045
|
# keep mypy happy
|
941
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1046
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
942
1047
|
# For classifier, the type of predict is the same as the type of label
|
943
|
-
if self._sklearn_object._estimator_type ==
|
944
|
-
|
1048
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1049
|
+
# label columns is the desired type for output
|
945
1050
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
946
1051
|
# rename the output columns
|
947
1052
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
948
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
949
|
-
|
950
|
-
|
1053
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1054
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1055
|
+
)
|
951
1056
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
952
1057
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
953
|
-
# Clusterer returns int64 cluster labels.
|
1058
|
+
# Clusterer returns int64 cluster labels.
|
954
1059
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
955
1060
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
956
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
957
|
-
|
958
|
-
|
959
|
-
|
1061
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1062
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1063
|
+
)
|
1064
|
+
|
960
1065
|
# For regressor, the type of predict is float64
|
961
|
-
elif self._sklearn_object._estimator_type ==
|
1066
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
962
1067
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
963
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
964
|
-
|
965
|
-
|
966
|
-
|
1068
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1069
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1070
|
+
)
|
1071
|
+
|
967
1072
|
for prob_func in PROB_FUNCTIONS:
|
968
1073
|
if hasattr(self, prob_func):
|
969
1074
|
output_cols_prefix: str = f"{prob_func}_"
|
970
1075
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
971
1076
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
972
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
973
|
-
|
974
|
-
|
1077
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1078
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1079
|
+
)
|
975
1080
|
|
976
1081
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
977
1082
|
items = list(self._model_signature_dict.items())
|
@@ -984,10 +1089,10 @@ class MissingIndicator(BaseTransformer):
|
|
984
1089
|
"""Returns model signature of current class.
|
985
1090
|
|
986
1091
|
Raises:
|
987
|
-
|
1092
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
988
1093
|
|
989
1094
|
Returns:
|
990
|
-
Dict
|
1095
|
+
Dict with each method and its input output signature
|
991
1096
|
"""
|
992
1097
|
if self._model_signature_dict is None:
|
993
1098
|
raise exceptions.SnowflakeMLException(
|
@@ -995,35 +1100,3 @@ class MissingIndicator(BaseTransformer):
|
|
995
1100
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
996
1101
|
)
|
997
1102
|
return self._model_signature_dict
|
998
|
-
|
999
|
-
def to_sklearn(self) -> Any:
|
1000
|
-
"""Get sklearn.impute.MissingIndicator object.
|
1001
|
-
"""
|
1002
|
-
if self._sklearn_object is None:
|
1003
|
-
self._sklearn_object = self._create_sklearn_object()
|
1004
|
-
return self._sklearn_object
|
1005
|
-
|
1006
|
-
def to_xgboost(self) -> Any:
|
1007
|
-
raise exceptions.SnowflakeMLException(
|
1008
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1009
|
-
original_exception=AttributeError(
|
1010
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1011
|
-
"to_xgboost()",
|
1012
|
-
"to_sklearn()"
|
1013
|
-
)
|
1014
|
-
),
|
1015
|
-
)
|
1016
|
-
|
1017
|
-
def to_lightgbm(self) -> Any:
|
1018
|
-
raise exceptions.SnowflakeMLException(
|
1019
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1020
|
-
original_exception=AttributeError(
|
1021
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1022
|
-
"to_lightgbm()",
|
1023
|
-
"to_sklearn()"
|
1024
|
-
)
|
1025
|
-
),
|
1026
|
-
)
|
1027
|
-
|
1028
|
-
def _get_dependencies(self) -> List[str]:
|
1029
|
-
return self._deps
|