snowflake-ml-python 1.4.0__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +11 -1
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/feature_store/feature_store.py +151 -78
- snowflake/ml/feature_store/feature_view.py +12 -24
- snowflake/ml/fileset/sfcfs.py +56 -50
- snowflake/ml/fileset/stage_fs.py +48 -13
- snowflake/ml/model/_client/model/model_version_impl.py +2 -50
- snowflake/ml/model/_client/ops/model_ops.py +78 -29
- snowflake/ml/model/_client/sql/model.py +23 -2
- snowflake/ml/model/_client/sql/model_version.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -2
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
- snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
- snowflake/ml/modeling/cluster/birch.py +195 -123
- snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
- snowflake/ml/modeling/cluster/dbscan.py +195 -123
- snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
- snowflake/ml/modeling/cluster/k_means.py +195 -123
- snowflake/ml/modeling/cluster/mean_shift.py +195 -123
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
- snowflake/ml/modeling/cluster/optics.py +195 -123
- snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
- snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
- snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
- snowflake/ml/modeling/compose/column_transformer.py +195 -123
- snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
- snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
- snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
- snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
- snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
- snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
- snowflake/ml/modeling/covariance/oas.py +195 -123
- snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
- snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
- snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
- snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
- snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
- snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
- snowflake/ml/modeling/decomposition/pca.py +195 -123
- snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
- snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
- snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
- snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
- snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
- snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
- snowflake/ml/modeling/impute/knn_imputer.py +195 -123
- snowflake/ml/modeling/impute/missing_indicator.py +195 -123
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
- snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +195 -123
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
- snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
- snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/lars.py +195 -123
- snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
- snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
- snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/perceptron.py +195 -123
- snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ridge.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
- snowflake/ml/modeling/manifold/isomap.py +195 -123
- snowflake/ml/modeling/manifold/mds.py +195 -123
- snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
- snowflake/ml/modeling/manifold/tsne.py +195 -123
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
- snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
- snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
- snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
- snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
- snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
- snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
- snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
- snowflake/ml/modeling/pipeline/pipeline.py +4 -4
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
- snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
- snowflake/ml/modeling/svm/linear_svc.py +195 -123
- snowflake/ml/modeling/svm/linear_svr.py +195 -123
- snowflake/ml/modeling/svm/nu_svc.py +195 -123
- snowflake/ml/modeling/svm/nu_svr.py +195 -123
- snowflake/ml/modeling/svm/svc.py +195 -123
- snowflake/ml/modeling/svm/svr.py +195 -123
- snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
- snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
- snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
- snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
- snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
- snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +68 -57
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +202 -200
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -322,12 +321,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
322
321
|
)
|
323
322
|
return selected_cols
|
324
323
|
|
325
|
-
|
326
|
-
project=_PROJECT,
|
327
|
-
subproject=_SUBPROJECT,
|
328
|
-
custom_tags=dict([("autogen", True)]),
|
329
|
-
)
|
330
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ExtraTreeClassifier":
|
324
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ExtraTreeClassifier":
|
331
325
|
"""Build a decision tree classifier from the training set (X, y)
|
332
326
|
For more details on this function, see [sklearn.tree.ExtraTreeClassifier.fit]
|
333
327
|
(https://scikit-learn.org/stable/modules/generated/sklearn.tree.ExtraTreeClassifier.html#sklearn.tree.ExtraTreeClassifier.fit)
|
@@ -354,12 +348,14 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
354
348
|
|
355
349
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
356
350
|
|
357
|
-
|
351
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
358
352
|
if SNOWML_SPROC_ENV in os.environ:
|
359
353
|
statement_params = telemetry.get_function_usage_statement_params(
|
360
354
|
project=_PROJECT,
|
361
355
|
subproject=_SUBPROJECT,
|
362
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
356
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
357
|
+
inspect.currentframe(), ExtraTreeClassifier.__class__.__name__
|
358
|
+
),
|
363
359
|
api_calls=[Session.call],
|
364
360
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
365
361
|
)
|
@@ -380,7 +376,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
380
376
|
)
|
381
377
|
self._sklearn_object = model_trainer.train()
|
382
378
|
self._is_fitted = True
|
383
|
-
self.
|
379
|
+
self._generate_model_signatures(dataset)
|
384
380
|
return self
|
385
381
|
|
386
382
|
def _batch_inference_validate_snowpark(
|
@@ -456,7 +452,9 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
456
452
|
# when it is classifier, infer the datatype from label columns
|
457
453
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
458
454
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
459
|
-
label_cols_signatures = [
|
455
|
+
label_cols_signatures = [
|
456
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
457
|
+
]
|
460
458
|
if len(label_cols_signatures) == 0:
|
461
459
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
462
460
|
raise exceptions.SnowflakeMLException(
|
@@ -464,25 +462,22 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
464
462
|
original_exception=ValueError(error_str),
|
465
463
|
)
|
466
464
|
|
467
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
468
|
-
label_cols_signatures[0].as_snowpark_type()
|
469
|
-
)
|
465
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
470
466
|
|
471
467
|
self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
472
|
-
assert isinstance(
|
468
|
+
assert isinstance(
|
469
|
+
dataset._session, Session
|
470
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
473
471
|
|
474
472
|
transform_kwargs = dict(
|
475
|
-
session
|
476
|
-
dependencies
|
477
|
-
drop_input_cols
|
478
|
-
expected_output_cols_type
|
473
|
+
session=dataset._session,
|
474
|
+
dependencies=self._deps,
|
475
|
+
drop_input_cols=self._drop_input_cols,
|
476
|
+
expected_output_cols_type=expected_type_inferred,
|
479
477
|
)
|
480
478
|
|
481
479
|
elif isinstance(dataset, pd.DataFrame):
|
482
|
-
transform_kwargs = dict(
|
483
|
-
snowpark_input_cols = self._snowpark_cols,
|
484
|
-
drop_input_cols = self._drop_input_cols
|
485
|
-
)
|
480
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
486
481
|
|
487
482
|
transform_handlers = ModelTransformerBuilder.build(
|
488
483
|
dataset=dataset,
|
@@ -522,7 +517,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
522
517
|
Transformed dataset.
|
523
518
|
"""
|
524
519
|
super()._check_dataset_type(dataset)
|
525
|
-
inference_method="transform"
|
520
|
+
inference_method = "transform"
|
526
521
|
|
527
522
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
528
523
|
# are specific to the type of dataset used.
|
@@ -559,17 +554,14 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
559
554
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
560
555
|
|
561
556
|
transform_kwargs = dict(
|
562
|
-
session
|
563
|
-
dependencies
|
564
|
-
drop_input_cols
|
565
|
-
expected_output_cols_type
|
557
|
+
session=dataset._session,
|
558
|
+
dependencies=self._deps,
|
559
|
+
drop_input_cols=self._drop_input_cols,
|
560
|
+
expected_output_cols_type=expected_dtype,
|
566
561
|
)
|
567
562
|
|
568
563
|
elif isinstance(dataset, pd.DataFrame):
|
569
|
-
transform_kwargs = dict(
|
570
|
-
snowpark_input_cols = self._snowpark_cols,
|
571
|
-
drop_input_cols = self._drop_input_cols
|
572
|
-
)
|
564
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
573
565
|
|
574
566
|
transform_handlers = ModelTransformerBuilder.build(
|
575
567
|
dataset=dataset,
|
@@ -588,7 +580,11 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
588
580
|
return output_df
|
589
581
|
|
590
582
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
591
|
-
def fit_predict(
|
583
|
+
def fit_predict(
|
584
|
+
self,
|
585
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
586
|
+
output_cols_prefix: str = "fit_predict_",
|
587
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
592
588
|
""" Method not supported for this class.
|
593
589
|
|
594
590
|
|
@@ -613,7 +609,9 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
613
609
|
)
|
614
610
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
615
611
|
drop_input_cols=self._drop_input_cols,
|
616
|
-
expected_output_cols_list=
|
612
|
+
expected_output_cols_list=(
|
613
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
614
|
+
),
|
617
615
|
)
|
618
616
|
self._sklearn_object = fitted_estimator
|
619
617
|
self._is_fitted = True
|
@@ -630,6 +628,62 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
630
628
|
assert self._sklearn_object is not None
|
631
629
|
return self._sklearn_object.embedding_
|
632
630
|
|
631
|
+
|
632
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
633
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
634
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
635
|
+
"""
|
636
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
637
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
638
|
+
if output_cols:
|
639
|
+
output_cols = [
|
640
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
641
|
+
for c in output_cols
|
642
|
+
]
|
643
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
644
|
+
output_cols = [output_cols_prefix]
|
645
|
+
elif self._sklearn_object is not None:
|
646
|
+
classes = self._sklearn_object.classes_
|
647
|
+
if isinstance(classes, numpy.ndarray):
|
648
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
649
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
650
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
651
|
+
output_cols = []
|
652
|
+
for i, cl in enumerate(classes):
|
653
|
+
# For binary classification, there is only one output column for each class
|
654
|
+
# ndarray as the two classes are complementary.
|
655
|
+
if len(cl) == 2:
|
656
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
657
|
+
else:
|
658
|
+
output_cols.extend([
|
659
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
660
|
+
])
|
661
|
+
else:
|
662
|
+
output_cols = []
|
663
|
+
|
664
|
+
# Make sure column names are valid snowflake identifiers.
|
665
|
+
assert output_cols is not None # Make MyPy happy
|
666
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
667
|
+
|
668
|
+
return rv
|
669
|
+
|
670
|
+
def _align_expected_output_names(
|
671
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
672
|
+
) -> List[str]:
|
673
|
+
# in case the inferred output column names dimension is different
|
674
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
675
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
676
|
+
output_df_columns = list(output_df_pd.columns)
|
677
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
678
|
+
if self.sample_weight_col:
|
679
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
680
|
+
# if the dimension of inferred output column names is correct; use it
|
681
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
682
|
+
return expected_output_cols_list
|
683
|
+
# otherwise, use the sklearn estimator's output
|
684
|
+
else:
|
685
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
686
|
+
|
633
687
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
634
688
|
@telemetry.send_api_usage_telemetry(
|
635
689
|
project=_PROJECT,
|
@@ -662,24 +716,28 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
662
716
|
# are specific to the type of dataset used.
|
663
717
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
664
718
|
|
719
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
720
|
+
|
665
721
|
if isinstance(dataset, DataFrame):
|
666
722
|
self._deps = self._batch_inference_validate_snowpark(
|
667
723
|
dataset=dataset,
|
668
724
|
inference_method=inference_method,
|
669
725
|
)
|
670
|
-
assert isinstance(
|
726
|
+
assert isinstance(
|
727
|
+
dataset._session, Session
|
728
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
671
729
|
transform_kwargs = dict(
|
672
730
|
session=dataset._session,
|
673
731
|
dependencies=self._deps,
|
674
|
-
drop_input_cols
|
732
|
+
drop_input_cols=self._drop_input_cols,
|
675
733
|
expected_output_cols_type="float",
|
676
734
|
)
|
735
|
+
expected_output_cols = self._align_expected_output_names(
|
736
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
737
|
+
)
|
677
738
|
|
678
739
|
elif isinstance(dataset, pd.DataFrame):
|
679
|
-
transform_kwargs = dict(
|
680
|
-
snowpark_input_cols = self._snowpark_cols,
|
681
|
-
drop_input_cols = self._drop_input_cols
|
682
|
-
)
|
740
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
683
741
|
|
684
742
|
transform_handlers = ModelTransformerBuilder.build(
|
685
743
|
dataset=dataset,
|
@@ -691,7 +749,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
691
749
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
692
750
|
inference_method=inference_method,
|
693
751
|
input_cols=self.input_cols,
|
694
|
-
expected_output_cols=
|
752
|
+
expected_output_cols=expected_output_cols,
|
695
753
|
**transform_kwargs
|
696
754
|
)
|
697
755
|
return output_df
|
@@ -723,7 +781,8 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
723
781
|
Output dataset with log probability of the sample for each class in the model.
|
724
782
|
"""
|
725
783
|
super()._check_dataset_type(dataset)
|
726
|
-
inference_method="predict_log_proba"
|
784
|
+
inference_method = "predict_log_proba"
|
785
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
727
786
|
|
728
787
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
729
788
|
# are specific to the type of dataset used.
|
@@ -734,18 +793,20 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
734
793
|
dataset=dataset,
|
735
794
|
inference_method=inference_method,
|
736
795
|
)
|
737
|
-
assert isinstance(
|
796
|
+
assert isinstance(
|
797
|
+
dataset._session, Session
|
798
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
738
799
|
transform_kwargs = dict(
|
739
800
|
session=dataset._session,
|
740
801
|
dependencies=self._deps,
|
741
|
-
drop_input_cols
|
802
|
+
drop_input_cols=self._drop_input_cols,
|
742
803
|
expected_output_cols_type="float",
|
743
804
|
)
|
805
|
+
expected_output_cols = self._align_expected_output_names(
|
806
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
807
|
+
)
|
744
808
|
elif isinstance(dataset, pd.DataFrame):
|
745
|
-
transform_kwargs = dict(
|
746
|
-
snowpark_input_cols = self._snowpark_cols,
|
747
|
-
drop_input_cols = self._drop_input_cols
|
748
|
-
)
|
809
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
749
810
|
|
750
811
|
transform_handlers = ModelTransformerBuilder.build(
|
751
812
|
dataset=dataset,
|
@@ -758,7 +819,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
758
819
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
759
820
|
inference_method=inference_method,
|
760
821
|
input_cols=self.input_cols,
|
761
|
-
expected_output_cols=
|
822
|
+
expected_output_cols=expected_output_cols,
|
762
823
|
**transform_kwargs
|
763
824
|
)
|
764
825
|
return output_df
|
@@ -784,30 +845,34 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
784
845
|
Output dataset with results of the decision function for the samples in input dataset.
|
785
846
|
"""
|
786
847
|
super()._check_dataset_type(dataset)
|
787
|
-
inference_method="decision_function"
|
848
|
+
inference_method = "decision_function"
|
788
849
|
|
789
850
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
790
851
|
# are specific to the type of dataset used.
|
791
852
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
792
853
|
|
854
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
855
|
+
|
793
856
|
if isinstance(dataset, DataFrame):
|
794
857
|
self._deps = self._batch_inference_validate_snowpark(
|
795
858
|
dataset=dataset,
|
796
859
|
inference_method=inference_method,
|
797
860
|
)
|
798
|
-
assert isinstance(
|
861
|
+
assert isinstance(
|
862
|
+
dataset._session, Session
|
863
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
799
864
|
transform_kwargs = dict(
|
800
865
|
session=dataset._session,
|
801
866
|
dependencies=self._deps,
|
802
|
-
drop_input_cols
|
867
|
+
drop_input_cols=self._drop_input_cols,
|
803
868
|
expected_output_cols_type="float",
|
804
869
|
)
|
870
|
+
expected_output_cols = self._align_expected_output_names(
|
871
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
872
|
+
)
|
805
873
|
|
806
874
|
elif isinstance(dataset, pd.DataFrame):
|
807
|
-
transform_kwargs = dict(
|
808
|
-
snowpark_input_cols = self._snowpark_cols,
|
809
|
-
drop_input_cols = self._drop_input_cols
|
810
|
-
)
|
875
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
811
876
|
|
812
877
|
transform_handlers = ModelTransformerBuilder.build(
|
813
878
|
dataset=dataset,
|
@@ -820,7 +885,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
820
885
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
821
886
|
inference_method=inference_method,
|
822
887
|
input_cols=self.input_cols,
|
823
|
-
expected_output_cols=
|
888
|
+
expected_output_cols=expected_output_cols,
|
824
889
|
**transform_kwargs
|
825
890
|
)
|
826
891
|
return output_df
|
@@ -849,12 +914,14 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
849
914
|
Output dataset with probability of the sample for each class in the model.
|
850
915
|
"""
|
851
916
|
super()._check_dataset_type(dataset)
|
852
|
-
inference_method="score_samples"
|
917
|
+
inference_method = "score_samples"
|
853
918
|
|
854
919
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
855
920
|
# are specific to the type of dataset used.
|
856
921
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
857
922
|
|
923
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
924
|
+
|
858
925
|
if isinstance(dataset, DataFrame):
|
859
926
|
self._deps = self._batch_inference_validate_snowpark(
|
860
927
|
dataset=dataset,
|
@@ -867,6 +934,9 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
867
934
|
drop_input_cols = self._drop_input_cols,
|
868
935
|
expected_output_cols_type="float",
|
869
936
|
)
|
937
|
+
expected_output_cols = self._align_expected_output_names(
|
938
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
939
|
+
)
|
870
940
|
|
871
941
|
elif isinstance(dataset, pd.DataFrame):
|
872
942
|
transform_kwargs = dict(
|
@@ -885,7 +955,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
885
955
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
886
956
|
inference_method=inference_method,
|
887
957
|
input_cols=self.input_cols,
|
888
|
-
expected_output_cols=
|
958
|
+
expected_output_cols=expected_output_cols,
|
889
959
|
**transform_kwargs
|
890
960
|
)
|
891
961
|
return output_df
|
@@ -1032,50 +1102,84 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
1032
1102
|
)
|
1033
1103
|
return output_df
|
1034
1104
|
|
1105
|
+
|
1106
|
+
|
1107
|
+
def to_sklearn(self) -> Any:
|
1108
|
+
"""Get sklearn.tree.ExtraTreeClassifier object.
|
1109
|
+
"""
|
1110
|
+
if self._sklearn_object is None:
|
1111
|
+
self._sklearn_object = self._create_sklearn_object()
|
1112
|
+
return self._sklearn_object
|
1113
|
+
|
1114
|
+
def to_xgboost(self) -> Any:
|
1115
|
+
raise exceptions.SnowflakeMLException(
|
1116
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1117
|
+
original_exception=AttributeError(
|
1118
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1119
|
+
"to_xgboost()",
|
1120
|
+
"to_sklearn()"
|
1121
|
+
)
|
1122
|
+
),
|
1123
|
+
)
|
1124
|
+
|
1125
|
+
def to_lightgbm(self) -> Any:
|
1126
|
+
raise exceptions.SnowflakeMLException(
|
1127
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1128
|
+
original_exception=AttributeError(
|
1129
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1130
|
+
"to_lightgbm()",
|
1131
|
+
"to_sklearn()"
|
1132
|
+
)
|
1133
|
+
),
|
1134
|
+
)
|
1035
1135
|
|
1036
|
-
def
|
1136
|
+
def _get_dependencies(self) -> List[str]:
|
1137
|
+
return self._deps
|
1138
|
+
|
1139
|
+
|
1140
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1037
1141
|
self._model_signature_dict = dict()
|
1038
1142
|
|
1039
1143
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1040
1144
|
|
1041
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1145
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1042
1146
|
outputs: List[BaseFeatureSpec] = []
|
1043
1147
|
if hasattr(self, "predict"):
|
1044
1148
|
# keep mypy happy
|
1045
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1149
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1046
1150
|
# For classifier, the type of predict is the same as the type of label
|
1047
|
-
if self._sklearn_object._estimator_type ==
|
1048
|
-
|
1151
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1152
|
+
# label columns is the desired type for output
|
1049
1153
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1050
1154
|
# rename the output columns
|
1051
1155
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1052
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1053
|
-
|
1054
|
-
|
1156
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1157
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1158
|
+
)
|
1055
1159
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1056
1160
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1057
|
-
# Clusterer returns int64 cluster labels.
|
1161
|
+
# Clusterer returns int64 cluster labels.
|
1058
1162
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1059
1163
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1060
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1164
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1165
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1166
|
+
)
|
1167
|
+
|
1064
1168
|
# For regressor, the type of predict is float64
|
1065
|
-
elif self._sklearn_object._estimator_type ==
|
1169
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1066
1170
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1067
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1171
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1172
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1173
|
+
)
|
1174
|
+
|
1071
1175
|
for prob_func in PROB_FUNCTIONS:
|
1072
1176
|
if hasattr(self, prob_func):
|
1073
1177
|
output_cols_prefix: str = f"{prob_func}_"
|
1074
1178
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1075
1179
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1076
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1077
|
-
|
1078
|
-
|
1180
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1181
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1182
|
+
)
|
1079
1183
|
|
1080
1184
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1081
1185
|
items = list(self._model_signature_dict.items())
|
@@ -1088,10 +1192,10 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
1088
1192
|
"""Returns model signature of current class.
|
1089
1193
|
|
1090
1194
|
Raises:
|
1091
|
-
|
1195
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1092
1196
|
|
1093
1197
|
Returns:
|
1094
|
-
Dict
|
1198
|
+
Dict with each method and its input output signature
|
1095
1199
|
"""
|
1096
1200
|
if self._model_signature_dict is None:
|
1097
1201
|
raise exceptions.SnowflakeMLException(
|
@@ -1099,35 +1203,3 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
1099
1203
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1100
1204
|
)
|
1101
1205
|
return self._model_signature_dict
|
1102
|
-
|
1103
|
-
def to_sklearn(self) -> Any:
|
1104
|
-
"""Get sklearn.tree.ExtraTreeClassifier object.
|
1105
|
-
"""
|
1106
|
-
if self._sklearn_object is None:
|
1107
|
-
self._sklearn_object = self._create_sklearn_object()
|
1108
|
-
return self._sklearn_object
|
1109
|
-
|
1110
|
-
def to_xgboost(self) -> Any:
|
1111
|
-
raise exceptions.SnowflakeMLException(
|
1112
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1113
|
-
original_exception=AttributeError(
|
1114
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1115
|
-
"to_xgboost()",
|
1116
|
-
"to_sklearn()"
|
1117
|
-
)
|
1118
|
-
),
|
1119
|
-
)
|
1120
|
-
|
1121
|
-
def to_lightgbm(self) -> Any:
|
1122
|
-
raise exceptions.SnowflakeMLException(
|
1123
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1124
|
-
original_exception=AttributeError(
|
1125
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1126
|
-
"to_lightgbm()",
|
1127
|
-
"to_sklearn()"
|
1128
|
-
)
|
1129
|
-
),
|
1130
|
-
)
|
1131
|
-
|
1132
|
-
def _get_dependencies(self) -> List[str]:
|
1133
|
-
return self._deps
|