snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class RandomForestClassifier(BaseTransformer):
|
71
64
|
r"""A random forest classifier
|
72
65
|
For more details on this class, see [sklearn.ensemble.RandomForestClassifier]
|
@@ -375,12 +368,7 @@ class RandomForestClassifier(BaseTransformer):
|
|
375
368
|
)
|
376
369
|
return selected_cols
|
377
370
|
|
378
|
-
|
379
|
-
project=_PROJECT,
|
380
|
-
subproject=_SUBPROJECT,
|
381
|
-
custom_tags=dict([("autogen", True)]),
|
382
|
-
)
|
383
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "RandomForestClassifier":
|
371
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "RandomForestClassifier":
|
384
372
|
"""Build a forest of trees from the training set (X, y)
|
385
373
|
For more details on this function, see [sklearn.ensemble.RandomForestClassifier.fit]
|
386
374
|
(https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html#sklearn.ensemble.RandomForestClassifier.fit)
|
@@ -407,12 +395,14 @@ class RandomForestClassifier(BaseTransformer):
|
|
407
395
|
|
408
396
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
409
397
|
|
410
|
-
|
398
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
411
399
|
if SNOWML_SPROC_ENV in os.environ:
|
412
400
|
statement_params = telemetry.get_function_usage_statement_params(
|
413
401
|
project=_PROJECT,
|
414
402
|
subproject=_SUBPROJECT,
|
415
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
403
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
404
|
+
inspect.currentframe(), RandomForestClassifier.__class__.__name__
|
405
|
+
),
|
416
406
|
api_calls=[Session.call],
|
417
407
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
418
408
|
)
|
@@ -433,27 +423,24 @@ class RandomForestClassifier(BaseTransformer):
|
|
433
423
|
)
|
434
424
|
self._sklearn_object = model_trainer.train()
|
435
425
|
self._is_fitted = True
|
436
|
-
self.
|
426
|
+
self._generate_model_signatures(dataset)
|
437
427
|
return self
|
438
428
|
|
439
429
|
def _batch_inference_validate_snowpark(
|
440
430
|
self,
|
441
431
|
dataset: DataFrame,
|
442
432
|
inference_method: str,
|
443
|
-
) ->
|
444
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
445
|
-
return the available package that exists in the snowflake anaconda channel
|
433
|
+
) -> None:
|
434
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
446
435
|
|
447
436
|
Args:
|
448
437
|
dataset: snowpark dataframe
|
449
438
|
inference_method: the inference method such as predict, score...
|
450
|
-
|
439
|
+
|
451
440
|
Raises:
|
452
441
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
453
442
|
SnowflakeMLException: If the session is None, raise error
|
454
443
|
|
455
|
-
Returns:
|
456
|
-
A list of available package that exists in the snowflake anaconda channel
|
457
444
|
"""
|
458
445
|
if not self._is_fitted:
|
459
446
|
raise exceptions.SnowflakeMLException(
|
@@ -471,9 +458,7 @@ class RandomForestClassifier(BaseTransformer):
|
|
471
458
|
"Session must not specified for snowpark dataset."
|
472
459
|
),
|
473
460
|
)
|
474
|
-
|
475
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
476
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
461
|
+
|
477
462
|
|
478
463
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
479
464
|
@telemetry.send_api_usage_telemetry(
|
@@ -509,7 +494,9 @@ class RandomForestClassifier(BaseTransformer):
|
|
509
494
|
# when it is classifier, infer the datatype from label columns
|
510
495
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
511
496
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
512
|
-
label_cols_signatures = [
|
497
|
+
label_cols_signatures = [
|
498
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
499
|
+
]
|
513
500
|
if len(label_cols_signatures) == 0:
|
514
501
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
515
502
|
raise exceptions.SnowflakeMLException(
|
@@ -517,25 +504,23 @@ class RandomForestClassifier(BaseTransformer):
|
|
517
504
|
original_exception=ValueError(error_str),
|
518
505
|
)
|
519
506
|
|
520
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
521
|
-
label_cols_signatures[0].as_snowpark_type()
|
522
|
-
)
|
507
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
523
508
|
|
524
|
-
self.
|
525
|
-
|
509
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
510
|
+
self._deps = self._get_dependencies()
|
511
|
+
assert isinstance(
|
512
|
+
dataset._session, Session
|
513
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
526
514
|
|
527
515
|
transform_kwargs = dict(
|
528
|
-
session
|
529
|
-
dependencies
|
530
|
-
drop_input_cols
|
531
|
-
expected_output_cols_type
|
516
|
+
session=dataset._session,
|
517
|
+
dependencies=self._deps,
|
518
|
+
drop_input_cols=self._drop_input_cols,
|
519
|
+
expected_output_cols_type=expected_type_inferred,
|
532
520
|
)
|
533
521
|
|
534
522
|
elif isinstance(dataset, pd.DataFrame):
|
535
|
-
transform_kwargs = dict(
|
536
|
-
snowpark_input_cols = self._snowpark_cols,
|
537
|
-
drop_input_cols = self._drop_input_cols
|
538
|
-
)
|
523
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
539
524
|
|
540
525
|
transform_handlers = ModelTransformerBuilder.build(
|
541
526
|
dataset=dataset,
|
@@ -575,7 +560,7 @@ class RandomForestClassifier(BaseTransformer):
|
|
575
560
|
Transformed dataset.
|
576
561
|
"""
|
577
562
|
super()._check_dataset_type(dataset)
|
578
|
-
inference_method="transform"
|
563
|
+
inference_method = "transform"
|
579
564
|
|
580
565
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
581
566
|
# are specific to the type of dataset used.
|
@@ -605,24 +590,19 @@ class RandomForestClassifier(BaseTransformer):
|
|
605
590
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
606
591
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
607
592
|
|
608
|
-
self.
|
609
|
-
|
610
|
-
inference_method=inference_method,
|
611
|
-
)
|
593
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
594
|
+
self._deps = self._get_dependencies()
|
612
595
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
613
596
|
|
614
597
|
transform_kwargs = dict(
|
615
|
-
session
|
616
|
-
dependencies
|
617
|
-
drop_input_cols
|
618
|
-
expected_output_cols_type
|
598
|
+
session=dataset._session,
|
599
|
+
dependencies=self._deps,
|
600
|
+
drop_input_cols=self._drop_input_cols,
|
601
|
+
expected_output_cols_type=expected_dtype,
|
619
602
|
)
|
620
603
|
|
621
604
|
elif isinstance(dataset, pd.DataFrame):
|
622
|
-
transform_kwargs = dict(
|
623
|
-
snowpark_input_cols = self._snowpark_cols,
|
624
|
-
drop_input_cols = self._drop_input_cols
|
625
|
-
)
|
605
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
626
606
|
|
627
607
|
transform_handlers = ModelTransformerBuilder.build(
|
628
608
|
dataset=dataset,
|
@@ -641,7 +621,11 @@ class RandomForestClassifier(BaseTransformer):
|
|
641
621
|
return output_df
|
642
622
|
|
643
623
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
644
|
-
def fit_predict(
|
624
|
+
def fit_predict(
|
625
|
+
self,
|
626
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
627
|
+
output_cols_prefix: str = "fit_predict_",
|
628
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
645
629
|
""" Method not supported for this class.
|
646
630
|
|
647
631
|
|
@@ -666,22 +650,104 @@ class RandomForestClassifier(BaseTransformer):
|
|
666
650
|
)
|
667
651
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
668
652
|
drop_input_cols=self._drop_input_cols,
|
669
|
-
expected_output_cols_list=
|
653
|
+
expected_output_cols_list=(
|
654
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
655
|
+
),
|
670
656
|
)
|
671
657
|
self._sklearn_object = fitted_estimator
|
672
658
|
self._is_fitted = True
|
673
659
|
return output_result
|
674
660
|
|
661
|
+
|
662
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
663
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
664
|
+
""" Method not supported for this class.
|
665
|
+
|
675
666
|
|
676
|
-
|
677
|
-
|
678
|
-
|
667
|
+
Raises:
|
668
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
669
|
+
|
670
|
+
Args:
|
671
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
672
|
+
Snowpark or Pandas DataFrame.
|
673
|
+
output_cols_prefix: Prefix for the response columns
|
679
674
|
Returns:
|
680
675
|
Transformed dataset.
|
681
676
|
"""
|
682
|
-
self.
|
683
|
-
|
684
|
-
|
677
|
+
self._infer_input_output_cols(dataset)
|
678
|
+
super()._check_dataset_type(dataset)
|
679
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
680
|
+
estimator=self._sklearn_object,
|
681
|
+
dataset=dataset,
|
682
|
+
input_cols=self.input_cols,
|
683
|
+
label_cols=self.label_cols,
|
684
|
+
sample_weight_col=self.sample_weight_col,
|
685
|
+
autogenerated=self._autogenerated,
|
686
|
+
subproject=_SUBPROJECT,
|
687
|
+
)
|
688
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
689
|
+
drop_input_cols=self._drop_input_cols,
|
690
|
+
expected_output_cols_list=self.output_cols,
|
691
|
+
)
|
692
|
+
self._sklearn_object = fitted_estimator
|
693
|
+
self._is_fitted = True
|
694
|
+
return output_result
|
695
|
+
|
696
|
+
|
697
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
698
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
699
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
700
|
+
"""
|
701
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
702
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
703
|
+
if output_cols:
|
704
|
+
output_cols = [
|
705
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
706
|
+
for c in output_cols
|
707
|
+
]
|
708
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
709
|
+
output_cols = [output_cols_prefix]
|
710
|
+
elif self._sklearn_object is not None:
|
711
|
+
classes = self._sklearn_object.classes_
|
712
|
+
if isinstance(classes, numpy.ndarray):
|
713
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
714
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
715
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
716
|
+
output_cols = []
|
717
|
+
for i, cl in enumerate(classes):
|
718
|
+
# For binary classification, there is only one output column for each class
|
719
|
+
# ndarray as the two classes are complementary.
|
720
|
+
if len(cl) == 2:
|
721
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
722
|
+
else:
|
723
|
+
output_cols.extend([
|
724
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
725
|
+
])
|
726
|
+
else:
|
727
|
+
output_cols = []
|
728
|
+
|
729
|
+
# Make sure column names are valid snowflake identifiers.
|
730
|
+
assert output_cols is not None # Make MyPy happy
|
731
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
732
|
+
|
733
|
+
return rv
|
734
|
+
|
735
|
+
def _align_expected_output_names(
|
736
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
737
|
+
) -> List[str]:
|
738
|
+
# in case the inferred output column names dimension is different
|
739
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
740
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
741
|
+
output_df_columns = list(output_df_pd.columns)
|
742
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
743
|
+
if self.sample_weight_col:
|
744
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
745
|
+
# if the dimension of inferred output column names is correct; use it
|
746
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
747
|
+
return expected_output_cols_list
|
748
|
+
# otherwise, use the sklearn estimator's output
|
749
|
+
else:
|
750
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
685
751
|
|
686
752
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
687
753
|
@telemetry.send_api_usage_telemetry(
|
@@ -715,24 +781,26 @@ class RandomForestClassifier(BaseTransformer):
|
|
715
781
|
# are specific to the type of dataset used.
|
716
782
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
717
783
|
|
784
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
785
|
+
|
718
786
|
if isinstance(dataset, DataFrame):
|
719
|
-
self.
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
787
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
788
|
+
self._deps = self._get_dependencies()
|
789
|
+
assert isinstance(
|
790
|
+
dataset._session, Session
|
791
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
724
792
|
transform_kwargs = dict(
|
725
793
|
session=dataset._session,
|
726
794
|
dependencies=self._deps,
|
727
|
-
drop_input_cols
|
795
|
+
drop_input_cols=self._drop_input_cols,
|
728
796
|
expected_output_cols_type="float",
|
729
797
|
)
|
798
|
+
expected_output_cols = self._align_expected_output_names(
|
799
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
800
|
+
)
|
730
801
|
|
731
802
|
elif isinstance(dataset, pd.DataFrame):
|
732
|
-
transform_kwargs = dict(
|
733
|
-
snowpark_input_cols = self._snowpark_cols,
|
734
|
-
drop_input_cols = self._drop_input_cols
|
735
|
-
)
|
803
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
736
804
|
|
737
805
|
transform_handlers = ModelTransformerBuilder.build(
|
738
806
|
dataset=dataset,
|
@@ -744,7 +812,7 @@ class RandomForestClassifier(BaseTransformer):
|
|
744
812
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
745
813
|
inference_method=inference_method,
|
746
814
|
input_cols=self.input_cols,
|
747
|
-
expected_output_cols=
|
815
|
+
expected_output_cols=expected_output_cols,
|
748
816
|
**transform_kwargs
|
749
817
|
)
|
750
818
|
return output_df
|
@@ -776,29 +844,30 @@ class RandomForestClassifier(BaseTransformer):
|
|
776
844
|
Output dataset with log probability of the sample for each class in the model.
|
777
845
|
"""
|
778
846
|
super()._check_dataset_type(dataset)
|
779
|
-
inference_method="predict_log_proba"
|
847
|
+
inference_method = "predict_log_proba"
|
848
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
780
849
|
|
781
850
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
782
851
|
# are specific to the type of dataset used.
|
783
852
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
784
853
|
|
785
854
|
if isinstance(dataset, DataFrame):
|
786
|
-
self.
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
855
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
856
|
+
self._deps = self._get_dependencies()
|
857
|
+
assert isinstance(
|
858
|
+
dataset._session, Session
|
859
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
791
860
|
transform_kwargs = dict(
|
792
861
|
session=dataset._session,
|
793
862
|
dependencies=self._deps,
|
794
|
-
drop_input_cols
|
863
|
+
drop_input_cols=self._drop_input_cols,
|
795
864
|
expected_output_cols_type="float",
|
796
865
|
)
|
866
|
+
expected_output_cols = self._align_expected_output_names(
|
867
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
868
|
+
)
|
797
869
|
elif isinstance(dataset, pd.DataFrame):
|
798
|
-
transform_kwargs = dict(
|
799
|
-
snowpark_input_cols = self._snowpark_cols,
|
800
|
-
drop_input_cols = self._drop_input_cols
|
801
|
-
)
|
870
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
802
871
|
|
803
872
|
transform_handlers = ModelTransformerBuilder.build(
|
804
873
|
dataset=dataset,
|
@@ -811,7 +880,7 @@ class RandomForestClassifier(BaseTransformer):
|
|
811
880
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
812
881
|
inference_method=inference_method,
|
813
882
|
input_cols=self.input_cols,
|
814
|
-
expected_output_cols=
|
883
|
+
expected_output_cols=expected_output_cols,
|
815
884
|
**transform_kwargs
|
816
885
|
)
|
817
886
|
return output_df
|
@@ -837,30 +906,32 @@ class RandomForestClassifier(BaseTransformer):
|
|
837
906
|
Output dataset with results of the decision function for the samples in input dataset.
|
838
907
|
"""
|
839
908
|
super()._check_dataset_type(dataset)
|
840
|
-
inference_method="decision_function"
|
909
|
+
inference_method = "decision_function"
|
841
910
|
|
842
911
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
843
912
|
# are specific to the type of dataset used.
|
844
913
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
845
914
|
|
915
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
916
|
+
|
846
917
|
if isinstance(dataset, DataFrame):
|
847
|
-
self.
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
918
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
919
|
+
self._deps = self._get_dependencies()
|
920
|
+
assert isinstance(
|
921
|
+
dataset._session, Session
|
922
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
852
923
|
transform_kwargs = dict(
|
853
924
|
session=dataset._session,
|
854
925
|
dependencies=self._deps,
|
855
|
-
drop_input_cols
|
926
|
+
drop_input_cols=self._drop_input_cols,
|
856
927
|
expected_output_cols_type="float",
|
857
928
|
)
|
929
|
+
expected_output_cols = self._align_expected_output_names(
|
930
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
931
|
+
)
|
858
932
|
|
859
933
|
elif isinstance(dataset, pd.DataFrame):
|
860
|
-
transform_kwargs = dict(
|
861
|
-
snowpark_input_cols = self._snowpark_cols,
|
862
|
-
drop_input_cols = self._drop_input_cols
|
863
|
-
)
|
934
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
864
935
|
|
865
936
|
transform_handlers = ModelTransformerBuilder.build(
|
866
937
|
dataset=dataset,
|
@@ -873,7 +944,7 @@ class RandomForestClassifier(BaseTransformer):
|
|
873
944
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
874
945
|
inference_method=inference_method,
|
875
946
|
input_cols=self.input_cols,
|
876
|
-
expected_output_cols=
|
947
|
+
expected_output_cols=expected_output_cols,
|
877
948
|
**transform_kwargs
|
878
949
|
)
|
879
950
|
return output_df
|
@@ -902,17 +973,17 @@ class RandomForestClassifier(BaseTransformer):
|
|
902
973
|
Output dataset with probability of the sample for each class in the model.
|
903
974
|
"""
|
904
975
|
super()._check_dataset_type(dataset)
|
905
|
-
inference_method="score_samples"
|
976
|
+
inference_method = "score_samples"
|
906
977
|
|
907
978
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
908
979
|
# are specific to the type of dataset used.
|
909
980
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
910
981
|
|
982
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
983
|
+
|
911
984
|
if isinstance(dataset, DataFrame):
|
912
|
-
self.
|
913
|
-
|
914
|
-
inference_method=inference_method,
|
915
|
-
)
|
985
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
986
|
+
self._deps = self._get_dependencies()
|
916
987
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
917
988
|
transform_kwargs = dict(
|
918
989
|
session=dataset._session,
|
@@ -920,6 +991,9 @@ class RandomForestClassifier(BaseTransformer):
|
|
920
991
|
drop_input_cols = self._drop_input_cols,
|
921
992
|
expected_output_cols_type="float",
|
922
993
|
)
|
994
|
+
expected_output_cols = self._align_expected_output_names(
|
995
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
996
|
+
)
|
923
997
|
|
924
998
|
elif isinstance(dataset, pd.DataFrame):
|
925
999
|
transform_kwargs = dict(
|
@@ -938,7 +1012,7 @@ class RandomForestClassifier(BaseTransformer):
|
|
938
1012
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
939
1013
|
inference_method=inference_method,
|
940
1014
|
input_cols=self.input_cols,
|
941
|
-
expected_output_cols=
|
1015
|
+
expected_output_cols=expected_output_cols,
|
942
1016
|
**transform_kwargs
|
943
1017
|
)
|
944
1018
|
return output_df
|
@@ -973,17 +1047,15 @@ class RandomForestClassifier(BaseTransformer):
|
|
973
1047
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
974
1048
|
|
975
1049
|
if isinstance(dataset, DataFrame):
|
976
|
-
self.
|
977
|
-
|
978
|
-
inference_method="score",
|
979
|
-
)
|
1050
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1051
|
+
self._deps = self._get_dependencies()
|
980
1052
|
selected_cols = self._get_active_columns()
|
981
1053
|
if len(selected_cols) > 0:
|
982
1054
|
dataset = dataset.select(selected_cols)
|
983
1055
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
984
1056
|
transform_kwargs = dict(
|
985
1057
|
session=dataset._session,
|
986
|
-
dependencies=
|
1058
|
+
dependencies=self._deps,
|
987
1059
|
score_sproc_imports=['sklearn'],
|
988
1060
|
)
|
989
1061
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1048,11 +1120,8 @@ class RandomForestClassifier(BaseTransformer):
|
|
1048
1120
|
|
1049
1121
|
if isinstance(dataset, DataFrame):
|
1050
1122
|
|
1051
|
-
self.
|
1052
|
-
|
1053
|
-
inference_method=inference_method,
|
1054
|
-
|
1055
|
-
)
|
1123
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1124
|
+
self._deps = self._get_dependencies()
|
1056
1125
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1057
1126
|
transform_kwargs = dict(
|
1058
1127
|
session = dataset._session,
|
@@ -1085,50 +1154,84 @@ class RandomForestClassifier(BaseTransformer):
|
|
1085
1154
|
)
|
1086
1155
|
return output_df
|
1087
1156
|
|
1157
|
+
|
1158
|
+
|
1159
|
+
def to_sklearn(self) -> Any:
|
1160
|
+
"""Get sklearn.ensemble.RandomForestClassifier object.
|
1161
|
+
"""
|
1162
|
+
if self._sklearn_object is None:
|
1163
|
+
self._sklearn_object = self._create_sklearn_object()
|
1164
|
+
return self._sklearn_object
|
1165
|
+
|
1166
|
+
def to_xgboost(self) -> Any:
|
1167
|
+
raise exceptions.SnowflakeMLException(
|
1168
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1169
|
+
original_exception=AttributeError(
|
1170
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1171
|
+
"to_xgboost()",
|
1172
|
+
"to_sklearn()"
|
1173
|
+
)
|
1174
|
+
),
|
1175
|
+
)
|
1176
|
+
|
1177
|
+
def to_lightgbm(self) -> Any:
|
1178
|
+
raise exceptions.SnowflakeMLException(
|
1179
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1180
|
+
original_exception=AttributeError(
|
1181
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1182
|
+
"to_lightgbm()",
|
1183
|
+
"to_sklearn()"
|
1184
|
+
)
|
1185
|
+
),
|
1186
|
+
)
|
1187
|
+
|
1188
|
+
def _get_dependencies(self) -> List[str]:
|
1189
|
+
return self._deps
|
1190
|
+
|
1088
1191
|
|
1089
|
-
def
|
1192
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1090
1193
|
self._model_signature_dict = dict()
|
1091
1194
|
|
1092
1195
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1093
1196
|
|
1094
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1197
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1095
1198
|
outputs: List[BaseFeatureSpec] = []
|
1096
1199
|
if hasattr(self, "predict"):
|
1097
1200
|
# keep mypy happy
|
1098
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1201
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1099
1202
|
# For classifier, the type of predict is the same as the type of label
|
1100
|
-
if self._sklearn_object._estimator_type ==
|
1101
|
-
|
1203
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1204
|
+
# label columns is the desired type for output
|
1102
1205
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1103
1206
|
# rename the output columns
|
1104
1207
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1105
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1106
|
-
|
1107
|
-
|
1208
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1209
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1210
|
+
)
|
1108
1211
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1109
1212
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1110
|
-
# Clusterer returns int64 cluster labels.
|
1213
|
+
# Clusterer returns int64 cluster labels.
|
1111
1214
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1112
1215
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1113
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1114
|
-
|
1115
|
-
|
1116
|
-
|
1216
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1217
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1218
|
+
)
|
1219
|
+
|
1117
1220
|
# For regressor, the type of predict is float64
|
1118
|
-
elif self._sklearn_object._estimator_type ==
|
1221
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1119
1222
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1120
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1121
|
-
|
1122
|
-
|
1123
|
-
|
1223
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1224
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1225
|
+
)
|
1226
|
+
|
1124
1227
|
for prob_func in PROB_FUNCTIONS:
|
1125
1228
|
if hasattr(self, prob_func):
|
1126
1229
|
output_cols_prefix: str = f"{prob_func}_"
|
1127
1230
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1128
1231
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1129
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1130
|
-
|
1131
|
-
|
1232
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1233
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1234
|
+
)
|
1132
1235
|
|
1133
1236
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1134
1237
|
items = list(self._model_signature_dict.items())
|
@@ -1141,10 +1244,10 @@ class RandomForestClassifier(BaseTransformer):
|
|
1141
1244
|
"""Returns model signature of current class.
|
1142
1245
|
|
1143
1246
|
Raises:
|
1144
|
-
|
1247
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1145
1248
|
|
1146
1249
|
Returns:
|
1147
|
-
Dict
|
1250
|
+
Dict with each method and its input output signature
|
1148
1251
|
"""
|
1149
1252
|
if self._model_signature_dict is None:
|
1150
1253
|
raise exceptions.SnowflakeMLException(
|
@@ -1152,35 +1255,3 @@ class RandomForestClassifier(BaseTransformer):
|
|
1152
1255
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1153
1256
|
)
|
1154
1257
|
return self._model_signature_dict
|
1155
|
-
|
1156
|
-
def to_sklearn(self) -> Any:
|
1157
|
-
"""Get sklearn.ensemble.RandomForestClassifier object.
|
1158
|
-
"""
|
1159
|
-
if self._sklearn_object is None:
|
1160
|
-
self._sklearn_object = self._create_sklearn_object()
|
1161
|
-
return self._sklearn_object
|
1162
|
-
|
1163
|
-
def to_xgboost(self) -> Any:
|
1164
|
-
raise exceptions.SnowflakeMLException(
|
1165
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1166
|
-
original_exception=AttributeError(
|
1167
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1168
|
-
"to_xgboost()",
|
1169
|
-
"to_sklearn()"
|
1170
|
-
)
|
1171
|
-
),
|
1172
|
-
)
|
1173
|
-
|
1174
|
-
def to_lightgbm(self) -> Any:
|
1175
|
-
raise exceptions.SnowflakeMLException(
|
1176
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1177
|
-
original_exception=AttributeError(
|
1178
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1179
|
-
"to_lightgbm()",
|
1180
|
-
"to_sklearn()"
|
1181
|
-
)
|
1182
|
-
),
|
1183
|
-
)
|
1184
|
-
|
1185
|
-
def _get_dependencies(self) -> List[str]:
|
1186
|
-
return self._deps
|