snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.compose".replace("sklear
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class ColumnTransformer(BaseTransformer):
|
71
64
|
r"""Applies transformers to columns of an array or pandas DataFrame
|
72
65
|
For more details on this class, see [sklearn.compose.ColumnTransformer]
|
@@ -270,12 +263,7 @@ class ColumnTransformer(BaseTransformer):
|
|
270
263
|
)
|
271
264
|
return selected_cols
|
272
265
|
|
273
|
-
|
274
|
-
project=_PROJECT,
|
275
|
-
subproject=_SUBPROJECT,
|
276
|
-
custom_tags=dict([("autogen", True)]),
|
277
|
-
)
|
278
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ColumnTransformer":
|
266
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ColumnTransformer":
|
279
267
|
"""Fit all transformers using X
|
280
268
|
For more details on this function, see [sklearn.compose.ColumnTransformer.fit]
|
281
269
|
(https://scikit-learn.org/stable/modules/generated/sklearn.compose.ColumnTransformer.html#sklearn.compose.ColumnTransformer.fit)
|
@@ -302,12 +290,14 @@ class ColumnTransformer(BaseTransformer):
|
|
302
290
|
|
303
291
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
304
292
|
|
305
|
-
|
293
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
306
294
|
if SNOWML_SPROC_ENV in os.environ:
|
307
295
|
statement_params = telemetry.get_function_usage_statement_params(
|
308
296
|
project=_PROJECT,
|
309
297
|
subproject=_SUBPROJECT,
|
310
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
298
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
299
|
+
inspect.currentframe(), ColumnTransformer.__class__.__name__
|
300
|
+
),
|
311
301
|
api_calls=[Session.call],
|
312
302
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
313
303
|
)
|
@@ -328,27 +318,24 @@ class ColumnTransformer(BaseTransformer):
|
|
328
318
|
)
|
329
319
|
self._sklearn_object = model_trainer.train()
|
330
320
|
self._is_fitted = True
|
331
|
-
self.
|
321
|
+
self._generate_model_signatures(dataset)
|
332
322
|
return self
|
333
323
|
|
334
324
|
def _batch_inference_validate_snowpark(
|
335
325
|
self,
|
336
326
|
dataset: DataFrame,
|
337
327
|
inference_method: str,
|
338
|
-
) ->
|
339
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
340
|
-
return the available package that exists in the snowflake anaconda channel
|
328
|
+
) -> None:
|
329
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
341
330
|
|
342
331
|
Args:
|
343
332
|
dataset: snowpark dataframe
|
344
333
|
inference_method: the inference method such as predict, score...
|
345
|
-
|
334
|
+
|
346
335
|
Raises:
|
347
336
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
348
337
|
SnowflakeMLException: If the session is None, raise error
|
349
338
|
|
350
|
-
Returns:
|
351
|
-
A list of available package that exists in the snowflake anaconda channel
|
352
339
|
"""
|
353
340
|
if not self._is_fitted:
|
354
341
|
raise exceptions.SnowflakeMLException(
|
@@ -366,9 +353,7 @@ class ColumnTransformer(BaseTransformer):
|
|
366
353
|
"Session must not specified for snowpark dataset."
|
367
354
|
),
|
368
355
|
)
|
369
|
-
|
370
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
371
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
356
|
+
|
372
357
|
|
373
358
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
374
359
|
@telemetry.send_api_usage_telemetry(
|
@@ -402,7 +387,9 @@ class ColumnTransformer(BaseTransformer):
|
|
402
387
|
# when it is classifier, infer the datatype from label columns
|
403
388
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
404
389
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
405
|
-
label_cols_signatures = [
|
390
|
+
label_cols_signatures = [
|
391
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
392
|
+
]
|
406
393
|
if len(label_cols_signatures) == 0:
|
407
394
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
408
395
|
raise exceptions.SnowflakeMLException(
|
@@ -410,25 +397,23 @@ class ColumnTransformer(BaseTransformer):
|
|
410
397
|
original_exception=ValueError(error_str),
|
411
398
|
)
|
412
399
|
|
413
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
414
|
-
label_cols_signatures[0].as_snowpark_type()
|
415
|
-
)
|
400
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
416
401
|
|
417
|
-
self.
|
418
|
-
|
402
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
403
|
+
self._deps = self._get_dependencies()
|
404
|
+
assert isinstance(
|
405
|
+
dataset._session, Session
|
406
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
419
407
|
|
420
408
|
transform_kwargs = dict(
|
421
|
-
session
|
422
|
-
dependencies
|
423
|
-
drop_input_cols
|
424
|
-
expected_output_cols_type
|
409
|
+
session=dataset._session,
|
410
|
+
dependencies=self._deps,
|
411
|
+
drop_input_cols=self._drop_input_cols,
|
412
|
+
expected_output_cols_type=expected_type_inferred,
|
425
413
|
)
|
426
414
|
|
427
415
|
elif isinstance(dataset, pd.DataFrame):
|
428
|
-
transform_kwargs = dict(
|
429
|
-
snowpark_input_cols = self._snowpark_cols,
|
430
|
-
drop_input_cols = self._drop_input_cols
|
431
|
-
)
|
416
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
432
417
|
|
433
418
|
transform_handlers = ModelTransformerBuilder.build(
|
434
419
|
dataset=dataset,
|
@@ -470,7 +455,7 @@ class ColumnTransformer(BaseTransformer):
|
|
470
455
|
Transformed dataset.
|
471
456
|
"""
|
472
457
|
super()._check_dataset_type(dataset)
|
473
|
-
inference_method="transform"
|
458
|
+
inference_method = "transform"
|
474
459
|
|
475
460
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
476
461
|
# are specific to the type of dataset used.
|
@@ -500,24 +485,19 @@ class ColumnTransformer(BaseTransformer):
|
|
500
485
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
501
486
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
502
487
|
|
503
|
-
self.
|
504
|
-
|
505
|
-
inference_method=inference_method,
|
506
|
-
)
|
488
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
489
|
+
self._deps = self._get_dependencies()
|
507
490
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
508
491
|
|
509
492
|
transform_kwargs = dict(
|
510
|
-
session
|
511
|
-
dependencies
|
512
|
-
drop_input_cols
|
513
|
-
expected_output_cols_type
|
493
|
+
session=dataset._session,
|
494
|
+
dependencies=self._deps,
|
495
|
+
drop_input_cols=self._drop_input_cols,
|
496
|
+
expected_output_cols_type=expected_dtype,
|
514
497
|
)
|
515
498
|
|
516
499
|
elif isinstance(dataset, pd.DataFrame):
|
517
|
-
transform_kwargs = dict(
|
518
|
-
snowpark_input_cols = self._snowpark_cols,
|
519
|
-
drop_input_cols = self._drop_input_cols
|
520
|
-
)
|
500
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
521
501
|
|
522
502
|
transform_handlers = ModelTransformerBuilder.build(
|
523
503
|
dataset=dataset,
|
@@ -536,7 +516,11 @@ class ColumnTransformer(BaseTransformer):
|
|
536
516
|
return output_df
|
537
517
|
|
538
518
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
539
|
-
def fit_predict(
|
519
|
+
def fit_predict(
|
520
|
+
self,
|
521
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
522
|
+
output_cols_prefix: str = "fit_predict_",
|
523
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
540
524
|
""" Method not supported for this class.
|
541
525
|
|
542
526
|
|
@@ -561,22 +545,106 @@ class ColumnTransformer(BaseTransformer):
|
|
561
545
|
)
|
562
546
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
563
547
|
drop_input_cols=self._drop_input_cols,
|
564
|
-
expected_output_cols_list=
|
548
|
+
expected_output_cols_list=(
|
549
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
550
|
+
),
|
565
551
|
)
|
566
552
|
self._sklearn_object = fitted_estimator
|
567
553
|
self._is_fitted = True
|
568
554
|
return output_result
|
569
555
|
|
556
|
+
|
557
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
558
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
559
|
+
""" Fit all transformers, transform the data and concatenate results
|
560
|
+
For more details on this function, see [sklearn.compose.ColumnTransformer.fit_transform]
|
561
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.compose.ColumnTransformer.html#sklearn.compose.ColumnTransformer.fit_transform)
|
562
|
+
|
563
|
+
|
564
|
+
Raises:
|
565
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
570
566
|
|
571
|
-
|
572
|
-
|
573
|
-
|
567
|
+
Args:
|
568
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
569
|
+
Snowpark or Pandas DataFrame.
|
570
|
+
output_cols_prefix: Prefix for the response columns
|
574
571
|
Returns:
|
575
572
|
Transformed dataset.
|
576
573
|
"""
|
577
|
-
self.
|
578
|
-
|
579
|
-
|
574
|
+
self._infer_input_output_cols(dataset)
|
575
|
+
super()._check_dataset_type(dataset)
|
576
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
577
|
+
estimator=self._sklearn_object,
|
578
|
+
dataset=dataset,
|
579
|
+
input_cols=self.input_cols,
|
580
|
+
label_cols=self.label_cols,
|
581
|
+
sample_weight_col=self.sample_weight_col,
|
582
|
+
autogenerated=self._autogenerated,
|
583
|
+
subproject=_SUBPROJECT,
|
584
|
+
)
|
585
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
586
|
+
drop_input_cols=self._drop_input_cols,
|
587
|
+
expected_output_cols_list=self.output_cols,
|
588
|
+
)
|
589
|
+
self._sklearn_object = fitted_estimator
|
590
|
+
self._is_fitted = True
|
591
|
+
return output_result
|
592
|
+
|
593
|
+
|
594
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
595
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
596
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
597
|
+
"""
|
598
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
599
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
600
|
+
if output_cols:
|
601
|
+
output_cols = [
|
602
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
603
|
+
for c in output_cols
|
604
|
+
]
|
605
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
606
|
+
output_cols = [output_cols_prefix]
|
607
|
+
elif self._sklearn_object is not None:
|
608
|
+
classes = self._sklearn_object.classes_
|
609
|
+
if isinstance(classes, numpy.ndarray):
|
610
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
611
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
612
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
613
|
+
output_cols = []
|
614
|
+
for i, cl in enumerate(classes):
|
615
|
+
# For binary classification, there is only one output column for each class
|
616
|
+
# ndarray as the two classes are complementary.
|
617
|
+
if len(cl) == 2:
|
618
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
619
|
+
else:
|
620
|
+
output_cols.extend([
|
621
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
622
|
+
])
|
623
|
+
else:
|
624
|
+
output_cols = []
|
625
|
+
|
626
|
+
# Make sure column names are valid snowflake identifiers.
|
627
|
+
assert output_cols is not None # Make MyPy happy
|
628
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
629
|
+
|
630
|
+
return rv
|
631
|
+
|
632
|
+
def _align_expected_output_names(
|
633
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
634
|
+
) -> List[str]:
|
635
|
+
# in case the inferred output column names dimension is different
|
636
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
637
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
638
|
+
output_df_columns = list(output_df_pd.columns)
|
639
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
640
|
+
if self.sample_weight_col:
|
641
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
642
|
+
# if the dimension of inferred output column names is correct; use it
|
643
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
644
|
+
return expected_output_cols_list
|
645
|
+
# otherwise, use the sklearn estimator's output
|
646
|
+
else:
|
647
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
580
648
|
|
581
649
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
582
650
|
@telemetry.send_api_usage_telemetry(
|
@@ -608,24 +676,26 @@ class ColumnTransformer(BaseTransformer):
|
|
608
676
|
# are specific to the type of dataset used.
|
609
677
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
610
678
|
|
679
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
680
|
+
|
611
681
|
if isinstance(dataset, DataFrame):
|
612
|
-
self.
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
682
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
683
|
+
self._deps = self._get_dependencies()
|
684
|
+
assert isinstance(
|
685
|
+
dataset._session, Session
|
686
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
617
687
|
transform_kwargs = dict(
|
618
688
|
session=dataset._session,
|
619
689
|
dependencies=self._deps,
|
620
|
-
drop_input_cols
|
690
|
+
drop_input_cols=self._drop_input_cols,
|
621
691
|
expected_output_cols_type="float",
|
622
692
|
)
|
693
|
+
expected_output_cols = self._align_expected_output_names(
|
694
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
695
|
+
)
|
623
696
|
|
624
697
|
elif isinstance(dataset, pd.DataFrame):
|
625
|
-
transform_kwargs = dict(
|
626
|
-
snowpark_input_cols = self._snowpark_cols,
|
627
|
-
drop_input_cols = self._drop_input_cols
|
628
|
-
)
|
698
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
629
699
|
|
630
700
|
transform_handlers = ModelTransformerBuilder.build(
|
631
701
|
dataset=dataset,
|
@@ -637,7 +707,7 @@ class ColumnTransformer(BaseTransformer):
|
|
637
707
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
638
708
|
inference_method=inference_method,
|
639
709
|
input_cols=self.input_cols,
|
640
|
-
expected_output_cols=
|
710
|
+
expected_output_cols=expected_output_cols,
|
641
711
|
**transform_kwargs
|
642
712
|
)
|
643
713
|
return output_df
|
@@ -667,29 +737,30 @@ class ColumnTransformer(BaseTransformer):
|
|
667
737
|
Output dataset with log probability of the sample for each class in the model.
|
668
738
|
"""
|
669
739
|
super()._check_dataset_type(dataset)
|
670
|
-
inference_method="predict_log_proba"
|
740
|
+
inference_method = "predict_log_proba"
|
741
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
671
742
|
|
672
743
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
673
744
|
# are specific to the type of dataset used.
|
674
745
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
675
746
|
|
676
747
|
if isinstance(dataset, DataFrame):
|
677
|
-
self.
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
748
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
749
|
+
self._deps = self._get_dependencies()
|
750
|
+
assert isinstance(
|
751
|
+
dataset._session, Session
|
752
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
682
753
|
transform_kwargs = dict(
|
683
754
|
session=dataset._session,
|
684
755
|
dependencies=self._deps,
|
685
|
-
drop_input_cols
|
756
|
+
drop_input_cols=self._drop_input_cols,
|
686
757
|
expected_output_cols_type="float",
|
687
758
|
)
|
759
|
+
expected_output_cols = self._align_expected_output_names(
|
760
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
761
|
+
)
|
688
762
|
elif isinstance(dataset, pd.DataFrame):
|
689
|
-
transform_kwargs = dict(
|
690
|
-
snowpark_input_cols = self._snowpark_cols,
|
691
|
-
drop_input_cols = self._drop_input_cols
|
692
|
-
)
|
763
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
693
764
|
|
694
765
|
transform_handlers = ModelTransformerBuilder.build(
|
695
766
|
dataset=dataset,
|
@@ -702,7 +773,7 @@ class ColumnTransformer(BaseTransformer):
|
|
702
773
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
703
774
|
inference_method=inference_method,
|
704
775
|
input_cols=self.input_cols,
|
705
|
-
expected_output_cols=
|
776
|
+
expected_output_cols=expected_output_cols,
|
706
777
|
**transform_kwargs
|
707
778
|
)
|
708
779
|
return output_df
|
@@ -728,30 +799,32 @@ class ColumnTransformer(BaseTransformer):
|
|
728
799
|
Output dataset with results of the decision function for the samples in input dataset.
|
729
800
|
"""
|
730
801
|
super()._check_dataset_type(dataset)
|
731
|
-
inference_method="decision_function"
|
802
|
+
inference_method = "decision_function"
|
732
803
|
|
733
804
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
734
805
|
# are specific to the type of dataset used.
|
735
806
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
736
807
|
|
808
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
809
|
+
|
737
810
|
if isinstance(dataset, DataFrame):
|
738
|
-
self.
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
811
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
812
|
+
self._deps = self._get_dependencies()
|
813
|
+
assert isinstance(
|
814
|
+
dataset._session, Session
|
815
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
743
816
|
transform_kwargs = dict(
|
744
817
|
session=dataset._session,
|
745
818
|
dependencies=self._deps,
|
746
|
-
drop_input_cols
|
819
|
+
drop_input_cols=self._drop_input_cols,
|
747
820
|
expected_output_cols_type="float",
|
748
821
|
)
|
822
|
+
expected_output_cols = self._align_expected_output_names(
|
823
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
824
|
+
)
|
749
825
|
|
750
826
|
elif isinstance(dataset, pd.DataFrame):
|
751
|
-
transform_kwargs = dict(
|
752
|
-
snowpark_input_cols = self._snowpark_cols,
|
753
|
-
drop_input_cols = self._drop_input_cols
|
754
|
-
)
|
827
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
755
828
|
|
756
829
|
transform_handlers = ModelTransformerBuilder.build(
|
757
830
|
dataset=dataset,
|
@@ -764,7 +837,7 @@ class ColumnTransformer(BaseTransformer):
|
|
764
837
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
765
838
|
inference_method=inference_method,
|
766
839
|
input_cols=self.input_cols,
|
767
|
-
expected_output_cols=
|
840
|
+
expected_output_cols=expected_output_cols,
|
768
841
|
**transform_kwargs
|
769
842
|
)
|
770
843
|
return output_df
|
@@ -793,17 +866,17 @@ class ColumnTransformer(BaseTransformer):
|
|
793
866
|
Output dataset with probability of the sample for each class in the model.
|
794
867
|
"""
|
795
868
|
super()._check_dataset_type(dataset)
|
796
|
-
inference_method="score_samples"
|
869
|
+
inference_method = "score_samples"
|
797
870
|
|
798
871
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
799
872
|
# are specific to the type of dataset used.
|
800
873
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
801
874
|
|
875
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
876
|
+
|
802
877
|
if isinstance(dataset, DataFrame):
|
803
|
-
self.
|
804
|
-
|
805
|
-
inference_method=inference_method,
|
806
|
-
)
|
878
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
879
|
+
self._deps = self._get_dependencies()
|
807
880
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
808
881
|
transform_kwargs = dict(
|
809
882
|
session=dataset._session,
|
@@ -811,6 +884,9 @@ class ColumnTransformer(BaseTransformer):
|
|
811
884
|
drop_input_cols = self._drop_input_cols,
|
812
885
|
expected_output_cols_type="float",
|
813
886
|
)
|
887
|
+
expected_output_cols = self._align_expected_output_names(
|
888
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
889
|
+
)
|
814
890
|
|
815
891
|
elif isinstance(dataset, pd.DataFrame):
|
816
892
|
transform_kwargs = dict(
|
@@ -829,7 +905,7 @@ class ColumnTransformer(BaseTransformer):
|
|
829
905
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
830
906
|
inference_method=inference_method,
|
831
907
|
input_cols=self.input_cols,
|
832
|
-
expected_output_cols=
|
908
|
+
expected_output_cols=expected_output_cols,
|
833
909
|
**transform_kwargs
|
834
910
|
)
|
835
911
|
return output_df
|
@@ -862,17 +938,15 @@ class ColumnTransformer(BaseTransformer):
|
|
862
938
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
863
939
|
|
864
940
|
if isinstance(dataset, DataFrame):
|
865
|
-
self.
|
866
|
-
|
867
|
-
inference_method="score",
|
868
|
-
)
|
941
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
942
|
+
self._deps = self._get_dependencies()
|
869
943
|
selected_cols = self._get_active_columns()
|
870
944
|
if len(selected_cols) > 0:
|
871
945
|
dataset = dataset.select(selected_cols)
|
872
946
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
873
947
|
transform_kwargs = dict(
|
874
948
|
session=dataset._session,
|
875
|
-
dependencies=
|
949
|
+
dependencies=self._deps,
|
876
950
|
score_sproc_imports=['sklearn'],
|
877
951
|
)
|
878
952
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -937,11 +1011,8 @@ class ColumnTransformer(BaseTransformer):
|
|
937
1011
|
|
938
1012
|
if isinstance(dataset, DataFrame):
|
939
1013
|
|
940
|
-
self.
|
941
|
-
|
942
|
-
inference_method=inference_method,
|
943
|
-
|
944
|
-
)
|
1014
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1015
|
+
self._deps = self._get_dependencies()
|
945
1016
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
946
1017
|
transform_kwargs = dict(
|
947
1018
|
session = dataset._session,
|
@@ -974,50 +1045,84 @@ class ColumnTransformer(BaseTransformer):
|
|
974
1045
|
)
|
975
1046
|
return output_df
|
976
1047
|
|
1048
|
+
|
1049
|
+
|
1050
|
+
def to_sklearn(self) -> Any:
|
1051
|
+
"""Get sklearn.compose.ColumnTransformer object.
|
1052
|
+
"""
|
1053
|
+
if self._sklearn_object is None:
|
1054
|
+
self._sklearn_object = self._create_sklearn_object()
|
1055
|
+
return self._sklearn_object
|
1056
|
+
|
1057
|
+
def to_xgboost(self) -> Any:
|
1058
|
+
raise exceptions.SnowflakeMLException(
|
1059
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1060
|
+
original_exception=AttributeError(
|
1061
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1062
|
+
"to_xgboost()",
|
1063
|
+
"to_sklearn()"
|
1064
|
+
)
|
1065
|
+
),
|
1066
|
+
)
|
977
1067
|
|
978
|
-
def
|
1068
|
+
def to_lightgbm(self) -> Any:
|
1069
|
+
raise exceptions.SnowflakeMLException(
|
1070
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1071
|
+
original_exception=AttributeError(
|
1072
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1073
|
+
"to_lightgbm()",
|
1074
|
+
"to_sklearn()"
|
1075
|
+
)
|
1076
|
+
),
|
1077
|
+
)
|
1078
|
+
|
1079
|
+
def _get_dependencies(self) -> List[str]:
|
1080
|
+
return self._deps
|
1081
|
+
|
1082
|
+
|
1083
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
979
1084
|
self._model_signature_dict = dict()
|
980
1085
|
|
981
1086
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
982
1087
|
|
983
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1088
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
984
1089
|
outputs: List[BaseFeatureSpec] = []
|
985
1090
|
if hasattr(self, "predict"):
|
986
1091
|
# keep mypy happy
|
987
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1092
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
988
1093
|
# For classifier, the type of predict is the same as the type of label
|
989
|
-
if self._sklearn_object._estimator_type ==
|
990
|
-
|
1094
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1095
|
+
# label columns is the desired type for output
|
991
1096
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
992
1097
|
# rename the output columns
|
993
1098
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
994
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
995
|
-
|
996
|
-
|
1099
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1100
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1101
|
+
)
|
997
1102
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
998
1103
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
999
|
-
# Clusterer returns int64 cluster labels.
|
1104
|
+
# Clusterer returns int64 cluster labels.
|
1000
1105
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1001
1106
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1002
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1003
|
-
|
1004
|
-
|
1005
|
-
|
1107
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1108
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1109
|
+
)
|
1110
|
+
|
1006
1111
|
# For regressor, the type of predict is float64
|
1007
|
-
elif self._sklearn_object._estimator_type ==
|
1112
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1008
1113
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1009
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1114
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1115
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1116
|
+
)
|
1117
|
+
|
1013
1118
|
for prob_func in PROB_FUNCTIONS:
|
1014
1119
|
if hasattr(self, prob_func):
|
1015
1120
|
output_cols_prefix: str = f"{prob_func}_"
|
1016
1121
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1017
1122
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1018
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1019
|
-
|
1020
|
-
|
1123
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1124
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1125
|
+
)
|
1021
1126
|
|
1022
1127
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1023
1128
|
items = list(self._model_signature_dict.items())
|
@@ -1030,10 +1135,10 @@ class ColumnTransformer(BaseTransformer):
|
|
1030
1135
|
"""Returns model signature of current class.
|
1031
1136
|
|
1032
1137
|
Raises:
|
1033
|
-
|
1138
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1034
1139
|
|
1035
1140
|
Returns:
|
1036
|
-
Dict
|
1141
|
+
Dict with each method and its input output signature
|
1037
1142
|
"""
|
1038
1143
|
if self._model_signature_dict is None:
|
1039
1144
|
raise exceptions.SnowflakeMLException(
|
@@ -1041,35 +1146,3 @@ class ColumnTransformer(BaseTransformer):
|
|
1041
1146
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1042
1147
|
)
|
1043
1148
|
return self._model_signature_dict
|
1044
|
-
|
1045
|
-
def to_sklearn(self) -> Any:
|
1046
|
-
"""Get sklearn.compose.ColumnTransformer object.
|
1047
|
-
"""
|
1048
|
-
if self._sklearn_object is None:
|
1049
|
-
self._sklearn_object = self._create_sklearn_object()
|
1050
|
-
return self._sklearn_object
|
1051
|
-
|
1052
|
-
def to_xgboost(self) -> Any:
|
1053
|
-
raise exceptions.SnowflakeMLException(
|
1054
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1055
|
-
original_exception=AttributeError(
|
1056
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1057
|
-
"to_xgboost()",
|
1058
|
-
"to_sklearn()"
|
1059
|
-
)
|
1060
|
-
),
|
1061
|
-
)
|
1062
|
-
|
1063
|
-
def to_lightgbm(self) -> Any:
|
1064
|
-
raise exceptions.SnowflakeMLException(
|
1065
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1066
|
-
original_exception=AttributeError(
|
1067
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1068
|
-
"to_lightgbm()",
|
1069
|
-
"to_sklearn()"
|
1070
|
-
)
|
1071
|
-
),
|
1072
|
-
)
|
1073
|
-
|
1074
|
-
def _get_dependencies(self) -> List[str]:
|
1075
|
-
return self._deps
|