snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class RadiusNeighborsClassifier(BaseTransformer):
|
64
72
|
r"""Classifier implementing a vote among neighbors within a given radius
|
65
73
|
For more details on this class, see [sklearn.neighbors.RadiusNeighborsClassifier]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class RadiusNeighborsRegressor(BaseTransformer):
|
64
72
|
r"""Regression based on neighbors within a fixed radius
|
65
73
|
For more details on this class, see [sklearn.neighbors.RadiusNeighborsRegressor]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class BernoulliRBM(BaseTransformer):
|
64
72
|
r"""Bernoulli Restricted Boltzmann Machine (RBM)
|
65
73
|
For more details on this class, see [sklearn.neural_network.BernoulliRBM]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class MLPClassifier(BaseTransformer):
|
64
72
|
r"""Multi-layer Perceptron classifier
|
65
73
|
For more details on this class, see [sklearn.neural_network.MLPClassifier]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class MLPRegressor(BaseTransformer):
|
64
72
|
r"""Multi-layer Perceptron regressor
|
65
73
|
For more details on this class, see [sklearn.neural_network.MLPRegressor]
|
@@ -2,7 +2,7 @@ import os
|
|
2
2
|
|
3
3
|
from snowflake.ml._internal import init_utils
|
4
4
|
|
5
|
-
pkg_dir = os.path.dirname(
|
5
|
+
pkg_dir = os.path.dirname(__file__)
|
6
6
|
pkg_name = __name__
|
7
7
|
exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
|
8
8
|
for k, v in exportable_classes.items():
|
@@ -4,7 +4,7 @@ import os
|
|
4
4
|
import posixpath
|
5
5
|
import tempfile
|
6
6
|
from itertools import chain
|
7
|
-
from typing import Any, Callable,
|
7
|
+
from typing import Any, Callable, Optional, Union
|
8
8
|
|
9
9
|
import cloudpickle as cp
|
10
10
|
import numpy as np
|
@@ -63,7 +63,7 @@ def has_callable_attr(obj: object, attr: str) -> bool:
|
|
63
63
|
return callable(getattr(obj, attr, None))
|
64
64
|
|
65
65
|
|
66
|
-
def _get_column_indices(all_columns:
|
66
|
+
def _get_column_indices(all_columns: list[str], target_columns: list[str]) -> list[int]:
|
67
67
|
"""
|
68
68
|
Extract the indices of the target_columns from all_columns.
|
69
69
|
|
@@ -96,7 +96,7 @@ def _get_column_indices(all_columns: List[str], target_columns: List[str]) -> Li
|
|
96
96
|
|
97
97
|
|
98
98
|
class Pipeline(base.BaseTransformer):
|
99
|
-
def __init__(self, steps:
|
99
|
+
def __init__(self, steps: list[tuple[str, Any]]) -> None:
|
100
100
|
"""
|
101
101
|
Pipeline of transforms.
|
102
102
|
|
@@ -119,14 +119,14 @@ class Pipeline(base.BaseTransformer):
|
|
119
119
|
# to only work with list of steps ending with an estimator or a dummy estimator like SKLearn?
|
120
120
|
self._is_final_step_estimator = Pipeline._is_estimator(steps[-1][1])
|
121
121
|
self._is_fitted = False
|
122
|
-
self._feature_names_in:
|
123
|
-
self._n_features_in:
|
124
|
-
self._transformers_to_input_indices:
|
122
|
+
self._feature_names_in: list[np.ndarray[Any, np.dtype[Any]]] = []
|
123
|
+
self._n_features_in: list[int] = []
|
124
|
+
self._transformers_to_input_indices: dict[str, list[int]] = {}
|
125
125
|
self._modifies_label_or_sample_weight = True
|
126
126
|
|
127
|
-
self._model_signature_dict: Optional[
|
127
|
+
self._model_signature_dict: Optional[dict[str, ModelSignature]] = None
|
128
128
|
|
129
|
-
deps:
|
129
|
+
deps: set[str] = {f"pandas=={pd.__version__}", f"scikit-learn=={skversion}"}
|
130
130
|
for _, obj in steps:
|
131
131
|
if isinstance(obj, base.BaseTransformer):
|
132
132
|
deps = deps | set(obj._get_dependencies())
|
@@ -146,10 +146,10 @@ class Pipeline(base.BaseTransformer):
|
|
146
146
|
def _is_transformer(obj: object) -> bool:
|
147
147
|
return has_callable_attr(obj, "fit") and has_callable_attr(obj, "transform")
|
148
148
|
|
149
|
-
def _get_transformers(self) ->
|
149
|
+
def _get_transformers(self) -> list[tuple[str, Any]]:
|
150
150
|
return self.steps[:-1] if self._is_final_step_estimator else self.steps
|
151
151
|
|
152
|
-
def _get_estimator(self) -> Optional[
|
152
|
+
def _get_estimator(self) -> Optional[tuple[str, Any]]:
|
153
153
|
return self.steps[-1] if self._is_final_step_estimator else None
|
154
154
|
|
155
155
|
def _validate_steps(self) -> None:
|
@@ -215,7 +215,7 @@ class Pipeline(base.BaseTransformer):
|
|
215
215
|
processed_cols = set(chain.from_iterable([trans.get_input_cols() for (_, trans) in self._get_transformers()]))
|
216
216
|
return len(target_cols & processed_cols) > 0
|
217
217
|
|
218
|
-
def _get_sanitized_list_of_columns(self, columns:
|
218
|
+
def _get_sanitized_list_of_columns(self, columns: list[str]) -> list[str]:
|
219
219
|
"""
|
220
220
|
Removes the label and sample_weight columns from the input list of columns and returns the results for the
|
221
221
|
purpous of computing column indices for SKLearn ColumnTransformer objects.
|
@@ -237,7 +237,7 @@ class Pipeline(base.BaseTransformer):
|
|
237
237
|
|
238
238
|
return [c for c in columns if c not in target_cols]
|
239
239
|
|
240
|
-
def _append_step_feature_consumption_info(self, step_name: str, all_cols:
|
240
|
+
def _append_step_feature_consumption_info(self, step_name: str, all_cols: list[str], input_cols: list[str]) -> None:
|
241
241
|
if self._modifies_label_or_sample_weight:
|
242
242
|
all_cols = self._get_sanitized_list_of_columns(all_cols)
|
243
243
|
self._feature_names_in.append(np.asarray(all_cols, dtype=object))
|
@@ -269,7 +269,7 @@ class Pipeline(base.BaseTransformer):
|
|
269
269
|
|
270
270
|
return transformed_dataset
|
271
271
|
|
272
|
-
def _upload_model_to_stage(self, stage_name: str, estimator: object, session: Session) ->
|
272
|
+
def _upload_model_to_stage(self, stage_name: str, estimator: object, session: Session) -> tuple[str, str]:
|
273
273
|
"""
|
274
274
|
Util method to pickle and upload the model to a temp Snowflake stage.
|
275
275
|
|
@@ -331,10 +331,10 @@ class Pipeline(base.BaseTransformer):
|
|
331
331
|
|
332
332
|
def pipeline_within_one_sproc(
|
333
333
|
session: Session,
|
334
|
-
sql_queries:
|
334
|
+
sql_queries: list[str],
|
335
335
|
stage_estimator_file_name: str,
|
336
336
|
stage_result_file_name: str,
|
337
|
-
sproc_statement_params:
|
337
|
+
sproc_statement_params: dict[str, str],
|
338
338
|
) -> str:
|
339
339
|
import os
|
340
340
|
|
@@ -774,7 +774,7 @@ class Pipeline(base.BaseTransformer):
|
|
774
774
|
|
775
775
|
return ct
|
776
776
|
|
777
|
-
def _get_label_cols(self) ->
|
777
|
+
def _get_label_cols(self) -> list[str]:
|
778
778
|
"""Util function to get the label columns from the pipeline.
|
779
779
|
The label column is only present in the estimator
|
780
780
|
|
@@ -885,7 +885,7 @@ class Pipeline(base.BaseTransformer):
|
|
885
885
|
|
886
886
|
return pipeline.Pipeline(steps=sksteps)
|
887
887
|
|
888
|
-
def _get_dependencies(self) ->
|
888
|
+
def _get_dependencies(self) -> list[str]:
|
889
889
|
return self._deps
|
890
890
|
|
891
891
|
def _generate_model_signatures(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> None:
|
@@ -919,7 +919,7 @@ class Pipeline(base.BaseTransformer):
|
|
919
919
|
)
|
920
920
|
|
921
921
|
@property
|
922
|
-
def model_signatures(self) ->
|
922
|
+
def model_signatures(self) -> dict[str, ModelSignature]:
|
923
923
|
if self._model_signature_dict is None:
|
924
924
|
raise exceptions.SnowflakeMLException(
|
925
925
|
error_code=error_codes.INVALID_ATTRIBUTE,
|
@@ -2,7 +2,7 @@ import os
|
|
2
2
|
|
3
3
|
from snowflake.ml._internal import init_utils
|
4
4
|
|
5
|
-
pkg_dir = os.path.dirname(
|
5
|
+
pkg_dir = os.path.dirname(__file__)
|
6
6
|
pkg_name = __name__
|
7
7
|
exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
|
8
8
|
for k, v in exportable_classes.items():
|
@@ -2,7 +2,7 @@
|
|
2
2
|
from __future__ import annotations
|
3
3
|
|
4
4
|
from itertools import chain
|
5
|
-
from typing import
|
5
|
+
from typing import Iterable, Optional, Union, cast
|
6
6
|
|
7
7
|
import numpy as np
|
8
8
|
import numpy.typing as npt
|
@@ -104,7 +104,7 @@ class KBinsDiscretizer(base.BaseTransformer):
|
|
104
104
|
def __init__(
|
105
105
|
self,
|
106
106
|
*,
|
107
|
-
n_bins: Union[int,
|
107
|
+
n_bins: Union[int, list[int]] = 5,
|
108
108
|
encode: str = "onehot",
|
109
109
|
strategy: str = "quantile",
|
110
110
|
input_cols: Optional[Union[str, Iterable[str]]] = None,
|
@@ -229,7 +229,7 @@ class KBinsDiscretizer(base.BaseTransformer):
|
|
229
229
|
# https://docs.google.com/document/d/1cilfCCtKYv6HvHqaqdZxfHAvQ0gg-t1AM8KYCQtJiLE/edit
|
230
230
|
agg_queries = []
|
231
231
|
for idx, col_name in enumerate(self.input_cols):
|
232
|
-
percentiles = np.linspace(0, 1, cast(
|
232
|
+
percentiles = np.linspace(0, 1, cast(list[int], self.n_bins)[idx] + 1)
|
233
233
|
for i, pct in enumerate(percentiles.tolist()):
|
234
234
|
agg_queries.append(F.percentile_cont(pct).within_group(col_name).alias(f"{col_name}_pct_{i}"))
|
235
235
|
state_df = dataset.agg(agg_queries)
|
@@ -246,7 +246,7 @@ class KBinsDiscretizer(base.BaseTransformer):
|
|
246
246
|
self.bin_edges_ = np.zeros(len(self.input_cols), dtype=object)
|
247
247
|
self.n_bins_ = np.zeros(len(self.input_cols), dtype=np.int_)
|
248
248
|
start = 0
|
249
|
-
for i, b in enumerate(cast(
|
249
|
+
for i, b in enumerate(cast(list[int], self.n_bins)):
|
250
250
|
self.bin_edges_[i] = decimal_to_float(state[start : start + b + 1])
|
251
251
|
start += b + 1
|
252
252
|
self.n_bins_[i] = len(self.bin_edges_[i]) - 1
|
@@ -275,7 +275,7 @@ class KBinsDiscretizer(base.BaseTransformer):
|
|
275
275
|
# 2. Populate internal state variables
|
276
276
|
self.bin_edges_ = np.zeros(len(self.input_cols), dtype=object)
|
277
277
|
self.n_bins_ = np.zeros(len(self.input_cols), dtype=np.int_)
|
278
|
-
for i, b in enumerate(cast(
|
278
|
+
for i, b in enumerate(cast(list[int], self.n_bins)):
|
279
279
|
self.bin_edges_[i] = np.linspace(state[i * 2], state[i * 2 + 1], b + 1)
|
280
280
|
self.n_bins_[i] = len(self.bin_edges_[i]) - 1
|
281
281
|
|
@@ -345,7 +345,7 @@ class KBinsDiscretizer(base.BaseTransformer):
|
|
345
345
|
session=dataset._session,
|
346
346
|
statement_params=telemetry.get_statement_params(base.PROJECT, base.SUBPROJECT, self.__class__.__name__),
|
347
347
|
)
|
348
|
-
def vec_bucketize_temp(x: T.PandasSeries[float], boarders: T.PandasSeries[
|
348
|
+
def vec_bucketize_temp(x: T.PandasSeries[float], boarders: T.PandasSeries[list[float]]) -> T.PandasSeries[int]:
|
349
349
|
# NB: vectorized udf doesn't work well with const array arg, so we pass it in as a list via PandasSeries
|
350
350
|
boarders = boarders[0]
|
351
351
|
res = np.searchsorted(boarders[1:-1], x, side="right")
|
@@ -387,9 +387,9 @@ class KBinsDiscretizer(base.BaseTransformer):
|
|
387
387
|
statement_params=telemetry.get_statement_params(base.PROJECT, base.SUBPROJECT, self.__class__.__name__),
|
388
388
|
)
|
389
389
|
def vec_bucketize_sparse_output_temp(
|
390
|
-
x: T.PandasSeries[float], boarders: T.PandasSeries[
|
391
|
-
) -> T.PandasSeries[
|
392
|
-
res:
|
390
|
+
x: T.PandasSeries[float], boarders: T.PandasSeries[list[float]]
|
391
|
+
) -> T.PandasSeries[dict[str, int]]:
|
392
|
+
res: list[dict[str, int]] = []
|
393
393
|
boarders = boarders[0]
|
394
394
|
buckets = np.searchsorted(boarders[1:-1], x, side="right")
|
395
395
|
assert isinstance(buckets, np.ndarray), f"expecting buckets to be numpy ndarray, got {type(buckets)}"
|
@@ -434,9 +434,9 @@ class KBinsDiscretizer(base.BaseTransformer):
|
|
434
434
|
statement_params=telemetry.get_statement_params(base.PROJECT, base.SUBPROJECT, self.__class__.__name__),
|
435
435
|
)
|
436
436
|
def vec_bucketize_dense_output_temp(
|
437
|
-
x: T.PandasSeries[float], boarders: T.PandasSeries[
|
438
|
-
) -> T.PandasSeries[
|
439
|
-
res:
|
437
|
+
x: T.PandasSeries[float], boarders: T.PandasSeries[list[float]]
|
438
|
+
) -> T.PandasSeries[list[int]]:
|
439
|
+
res: list[npt.NDArray[np.int32]] = []
|
440
440
|
boarders = boarders[0]
|
441
441
|
buckets = np.searchsorted(boarders[1:-1], x, side="right")
|
442
442
|
assert isinstance(buckets, np.ndarray), f"expecting buckets to be numpy ndarray, got {type(buckets)}"
|
@@ -491,7 +491,7 @@ class KBinsDiscretizer(base.BaseTransformer):
|
|
491
491
|
else:
|
492
492
|
return transformed_dataset
|
493
493
|
|
494
|
-
def get_output_cols(self) ->
|
494
|
+
def get_output_cols(self) -> list[str]:
|
495
495
|
"""
|
496
496
|
Get output column names.
|
497
497
|
Expand output column names for 'onehot-dense' encoding.
|
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python3
|
2
|
-
from typing import
|
2
|
+
from typing import Iterable, Optional, Union
|
3
3
|
|
4
4
|
import numpy as np
|
5
5
|
import pandas as pd
|
@@ -88,10 +88,10 @@ class MaxAbsScaler(base.BaseTransformer):
|
|
88
88
|
max_abs_: dict {column_name: value} or None
|
89
89
|
Per feature maximum absolute value.
|
90
90
|
"""
|
91
|
-
self.max_abs_:
|
92
|
-
self.scale_:
|
91
|
+
self.max_abs_: dict[str, float] = {}
|
92
|
+
self.scale_: dict[str, float] = {}
|
93
93
|
|
94
|
-
self.custom_states:
|
94
|
+
self.custom_states: list[str] = [
|
95
95
|
"SQL>>>max(abs({col_name}))",
|
96
96
|
]
|
97
97
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python3
|
2
|
-
from typing import
|
2
|
+
from typing import Iterable, Optional, Union
|
3
3
|
|
4
4
|
import numpy as np
|
5
5
|
import pandas as pd
|
@@ -69,7 +69,7 @@ class MinMaxScaler(base.BaseTransformer):
|
|
69
69
|
def __init__(
|
70
70
|
self,
|
71
71
|
*,
|
72
|
-
feature_range:
|
72
|
+
feature_range: tuple[float, float] = (0, 1),
|
73
73
|
clip: bool = False,
|
74
74
|
input_cols: Optional[Union[str, Iterable[str]]] = None,
|
75
75
|
output_cols: Optional[Union[str, Iterable[str]]] = None,
|
@@ -101,13 +101,13 @@ class MinMaxScaler(base.BaseTransformer):
|
|
101
101
|
self.feature_range = feature_range
|
102
102
|
self.clip = clip
|
103
103
|
|
104
|
-
self.min_:
|
105
|
-
self.scale_:
|
106
|
-
self.data_min_:
|
107
|
-
self.data_max_:
|
108
|
-
self.data_range_:
|
104
|
+
self.min_: dict[str, float] = {}
|
105
|
+
self.scale_: dict[str, float] = {}
|
106
|
+
self.data_min_: dict[str, float] = {}
|
107
|
+
self.data_max_: dict[str, float] = {}
|
108
|
+
self.data_range_: dict[str, float] = {}
|
109
109
|
|
110
|
-
self.custom_states:
|
110
|
+
self.custom_states: list[str] = [_utils.NumericStatistics.MIN, _utils.NumericStatistics.MAX]
|
111
111
|
|
112
112
|
super().__init__(drop_input_cols=drop_input_cols, custom_states=self.custom_states)
|
113
113
|
|
@@ -1,7 +1,7 @@
|
|
1
1
|
#!/usr/bin/env python3
|
2
2
|
import numbers
|
3
3
|
import uuid
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Iterable, Optional, Union
|
5
5
|
|
6
6
|
import numpy as np
|
7
7
|
import numpy.typing as npt
|
@@ -214,7 +214,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
214
214
|
def __init__(
|
215
215
|
self,
|
216
216
|
*,
|
217
|
-
categories: Union[str,
|
217
|
+
categories: Union[str, list[type_utils.LiteralNDArrayType], dict[str, type_utils.LiteralNDArrayType]] = "auto",
|
218
218
|
drop: Optional[Union[str, npt.ArrayLike]] = None,
|
219
219
|
sparse: bool = False,
|
220
220
|
handle_unknown: str = "error",
|
@@ -238,23 +238,23 @@ class OneHotEncoder(base.BaseTransformer):
|
|
238
238
|
) or self.min_frequency is not None
|
239
239
|
|
240
240
|
# Fit state
|
241
|
-
self.categories_:
|
242
|
-
self._categories_list:
|
241
|
+
self.categories_: dict[str, type_utils.LiteralNDArrayType] = {}
|
242
|
+
self._categories_list: list[type_utils.LiteralNDArrayType] = []
|
243
243
|
self.drop_idx_: Optional[npt.NDArray[np.int_]] = None
|
244
244
|
self._drop_idx_after_grouping: Optional[npt.NDArray[np.int_]] = None
|
245
|
-
self._n_features_outs:
|
246
|
-
self._snowpark_cols:
|
245
|
+
self._n_features_outs: list[int] = []
|
246
|
+
self._snowpark_cols: dict[str, list[str]] = dict()
|
247
247
|
|
248
248
|
# Fit state if output columns are set before fitting
|
249
|
-
self._dense_output_cols_mappings:
|
250
|
-
self._inferred_output_cols:
|
249
|
+
self._dense_output_cols_mappings: dict[str, list[str]] = {}
|
250
|
+
self._inferred_output_cols: list[str] = []
|
251
251
|
|
252
252
|
self.set_input_cols(input_cols)
|
253
253
|
self.set_output_cols(output_cols)
|
254
254
|
self.set_passthrough_cols(passthrough_cols)
|
255
255
|
|
256
256
|
@property
|
257
|
-
def infrequent_categories_(self) ->
|
257
|
+
def infrequent_categories_(self) -> list[Optional[type_utils.LiteralNDArrayType]]:
|
258
258
|
"""Infrequent categories for each feature."""
|
259
259
|
# raises an AttributeError if `_infrequent_indices` is not defined
|
260
260
|
infrequent_indices = self._infrequent_indices
|
@@ -329,7 +329,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
329
329
|
self._drop_idx_after_grouping = sklearn_encoder.drop_idx_
|
330
330
|
self._n_features_outs = sklearn_encoder._n_features_outs
|
331
331
|
|
332
|
-
_state_pandas_counts:
|
332
|
+
_state_pandas_counts: list[pd.DataFrame] = []
|
333
333
|
for idx, input_col in enumerate(self.input_cols):
|
334
334
|
self.categories_[input_col] = self._categories_list[idx]
|
335
335
|
_column_counts = (
|
@@ -362,7 +362,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
362
362
|
self._n_features_outs = self._compute_n_features_outs()
|
363
363
|
self._update_categories_state()
|
364
364
|
|
365
|
-
def _fit_category_state(self, dataset: snowpark.DataFrame, return_counts: bool) ->
|
365
|
+
def _fit_category_state(self, dataset: snowpark.DataFrame, return_counts: bool) -> dict[str, Any]:
|
366
366
|
"""
|
367
367
|
Get the number of samples, categories and (optional) category counts of dataset.
|
368
368
|
Fitted categories are assigned to the object.
|
@@ -552,7 +552,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
552
552
|
categories_pandas = categories_pandas.rename(columns={_STATE: categories_col})
|
553
553
|
|
554
554
|
# {column_name: ndarray([category])}
|
555
|
-
categories:
|
555
|
+
categories: dict[str, type_utils.LiteralNDArrayType] = categories_pandas.set_index(_COLUMN_NAME).to_dict()[
|
556
556
|
categories_col
|
557
557
|
]
|
558
558
|
# Giving the original type back to categories.
|
@@ -769,7 +769,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
769
769
|
|
770
770
|
state_pandas = self._state_pandas
|
771
771
|
|
772
|
-
def map_encoded_value(row: pd.Series) ->
|
772
|
+
def map_encoded_value(row: pd.Series) -> dict[str, Any]:
|
773
773
|
n_features_out = row[_N_FEATURES_OUT]
|
774
774
|
encoding = row[_ENCODING]
|
775
775
|
encoded_value = {str(encoding): 1, "array_length": n_features_out}
|
@@ -836,7 +836,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
836
836
|
"""
|
837
837
|
state_pandas = self._state_pandas
|
838
838
|
|
839
|
-
def map_encoded_value(row: pd.Series) ->
|
839
|
+
def map_encoded_value(row: pd.Series) -> list[int]:
|
840
840
|
n_features_out = row[_N_FEATURES_OUT]
|
841
841
|
encoding = row[_ENCODING]
|
842
842
|
encoded_value = [0] * n_features_out
|
@@ -934,7 +934,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
934
934
|
packages=["numpy", "scikit-learn"],
|
935
935
|
statement_params=telemetry.get_statement_params(base.PROJECT, base.SUBPROJECT, self.__class__.__name__),
|
936
936
|
)
|
937
|
-
def one_hot_encoder_sparse_transform(data: pd.DataFrame) ->
|
937
|
+
def one_hot_encoder_sparse_transform(data: pd.DataFrame) -> list[list[Optional[dict[Any, Any]]]]:
|
938
938
|
data = data.replace({np.nan: None}) # fill NA with None as represented in `categories_`
|
939
939
|
transformed_csr = encoder_sklearn.transform(data)
|
940
940
|
transformed_coo = transformed_csr.tocoo()
|
@@ -943,7 +943,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
943
943
|
transformed_vals = []
|
944
944
|
for _, row in data.iterrows():
|
945
945
|
base_encoding = 0
|
946
|
-
row_transformed_vals:
|
946
|
+
row_transformed_vals: list[Optional[dict[Any, Any]]] = []
|
947
947
|
for col_idx, val in row.items():
|
948
948
|
if val in encoder_sklearn.categories_[col_idx] or encoder_sklearn.handle_unknown != "ignore":
|
949
949
|
if col_idx > 0:
|
@@ -1101,7 +1101,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
1101
1101
|
def _handle_unknown_in_transform(
|
1102
1102
|
self,
|
1103
1103
|
transformed_dataset: snowpark.DataFrame,
|
1104
|
-
input_cols: Optional[
|
1104
|
+
input_cols: Optional[list[str]] = None,
|
1105
1105
|
) -> snowpark.DataFrame:
|
1106
1106
|
"""
|
1107
1107
|
Handle unknown values in the transformed dataset.
|
@@ -1206,7 +1206,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
1206
1206
|
if not self._infrequent_enabled:
|
1207
1207
|
return drop_idx
|
1208
1208
|
|
1209
|
-
default_to_infrequent: Optional[
|
1209
|
+
default_to_infrequent: Optional[list[int]] = self._default_to_infrequent_mappings[feature_idx]
|
1210
1210
|
if default_to_infrequent is None:
|
1211
1211
|
return drop_idx
|
1212
1212
|
|
@@ -1346,7 +1346,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
1346
1346
|
self.drop_idx_ = np.asarray(drop_idx_, dtype=object)
|
1347
1347
|
|
1348
1348
|
def _fit_infrequent_category_mapping(
|
1349
|
-
self, n_samples: int, category_counts:
|
1349
|
+
self, n_samples: int, category_counts: dict[str, dict[str, dict[str, int]]]
|
1350
1350
|
) -> None:
|
1351
1351
|
"""
|
1352
1352
|
Fit infrequent categories.
|
@@ -1442,7 +1442,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
1442
1442
|
output = np.flatnonzero(infrequent_mask)
|
1443
1443
|
return output if output.size > 0 else None
|
1444
1444
|
|
1445
|
-
def _compute_n_features_outs(self) ->
|
1445
|
+
def _compute_n_features_outs(self) -> list[int]:
|
1446
1446
|
"""Compute the n_features_out for each input feature."""
|
1447
1447
|
output = [len(cats) for cats in self._categories_list]
|
1448
1448
|
|
@@ -1463,7 +1463,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
1463
1463
|
|
1464
1464
|
return output
|
1465
1465
|
|
1466
|
-
def get_output_cols(self) ->
|
1466
|
+
def get_output_cols(self) -> list[str]:
|
1467
1467
|
"""
|
1468
1468
|
Output columns getter.
|
1469
1469
|
|
@@ -1472,7 +1472,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
1472
1472
|
"""
|
1473
1473
|
return self._inferred_output_cols
|
1474
1474
|
|
1475
|
-
def _get_inferred_output_cols(self) ->
|
1475
|
+
def _get_inferred_output_cols(self) -> list[str]:
|
1476
1476
|
"""
|
1477
1477
|
Get output column names meeting Snowflake requirements.
|
1478
1478
|
Only useful when fitting a pandas dataframe.
|
@@ -1556,11 +1556,11 @@ class OneHotEncoder(base.BaseTransformer):
|
|
1556
1556
|
sklearn_initial_keywords: Optional[Union[str, Iterable[str]]] = None,
|
1557
1557
|
sklearn_unused_keywords: Optional[Union[str, Iterable[str]]] = None,
|
1558
1558
|
snowml_only_keywords: Optional[Union[str, Iterable[str]]] = None,
|
1559
|
-
sklearn_added_keyword_to_version_dict: Optional[
|
1560
|
-
sklearn_added_kwarg_value_to_version_dict: Optional[
|
1561
|
-
sklearn_deprecated_keyword_to_version_dict: Optional[
|
1562
|
-
sklearn_removed_keyword_to_version_dict: Optional[
|
1563
|
-
) ->
|
1559
|
+
sklearn_added_keyword_to_version_dict: Optional[dict[str, str]] = None,
|
1560
|
+
sklearn_added_kwarg_value_to_version_dict: Optional[dict[str, dict[str, str]]] = None,
|
1561
|
+
sklearn_deprecated_keyword_to_version_dict: Optional[dict[str, str]] = None,
|
1562
|
+
sklearn_removed_keyword_to_version_dict: Optional[dict[str, str]] = None,
|
1563
|
+
) -> dict[str, Any]:
|
1564
1564
|
"""Modified snowflake.ml.framework.base.Base.get_sklearn_args with `sparse` and `sparse_output` handling."""
|
1565
1565
|
default_sklearn_args = _utils.get_default_args(default_sklearn_obj.__class__.__init__)
|
1566
1566
|
given_args = self.get_params()
|
@@ -1580,7 +1580,7 @@ class OneHotEncoder(base.BaseTransformer):
|
|
1580
1580
|
if version.parse(sklearn_version) >= version.parse(_SKLEARN_DEPRECATED_KEYWORD_TO_VERSION_DICT["sparse"]):
|
1581
1581
|
given_args["sparse_output"] = given_args.pop("sparse")
|
1582
1582
|
|
1583
|
-
sklearn_args:
|
1583
|
+
sklearn_args: dict[str, Any] = _utils.get_filtered_valid_sklearn_args(
|
1584
1584
|
args=given_args,
|
1585
1585
|
default_sklearn_args=default_sklearn_args,
|
1586
1586
|
sklearn_initial_keywords=sklearn_initial_keywords,
|