snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class Ridge(BaseTransformer):
|
64
72
|
r"""Linear least squares with l2 regularization
|
65
73
|
For more details on this class, see [sklearn.linear_model.Ridge]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class RidgeClassifier(BaseTransformer):
|
64
72
|
r"""Classifier using Ridge regression
|
65
73
|
For more details on this class, see [sklearn.linear_model.RidgeClassifier]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class RidgeClassifierCV(BaseTransformer):
|
64
72
|
r"""Ridge classifier with built-in cross-validation
|
65
73
|
For more details on this class, see [sklearn.linear_model.RidgeClassifierCV]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class RidgeCV(BaseTransformer):
|
64
72
|
r"""Ridge regression with built-in cross-validation
|
65
73
|
For more details on this class, see [sklearn.linear_model.RidgeCV]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class SGDClassifier(BaseTransformer):
|
64
72
|
r"""Linear classifiers (SVM, logistic regression, etc
|
65
73
|
For more details on this class, see [sklearn.linear_model.SGDClassifier]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class SGDOneClassSVM(BaseTransformer):
|
64
72
|
r"""Solves linear One-Class SVM using Stochastic Gradient Descent
|
65
73
|
For more details on this class, see [sklearn.linear_model.SGDOneClassSVM]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class SGDRegressor(BaseTransformer):
|
64
72
|
r"""Linear model fitted by minimizing a regularized empirical loss with SGD
|
65
73
|
For more details on this class, see [sklearn.linear_model.SGDRegressor]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class TheilSenRegressor(BaseTransformer):
|
64
72
|
r"""Theil-Sen Estimator: robust multivariate regression model
|
65
73
|
For more details on this class, see [sklearn.linear_model.TheilSenRegressor]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class TweedieRegressor(BaseTransformer):
|
64
72
|
r"""Generalized Linear Model with a Tweedie distribution
|
65
73
|
For more details on this class, see [sklearn.linear_model.TweedieRegressor]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class Isomap(BaseTransformer):
|
64
72
|
r"""Isomap Embedding
|
65
73
|
For more details on this class, see [sklearn.manifold.Isomap]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class MDS(BaseTransformer):
|
64
72
|
r"""Multidimensional scaling
|
65
73
|
For more details on this class, see [sklearn.manifold.MDS]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class SpectralEmbedding(BaseTransformer):
|
64
72
|
r"""Spectral embedding for non-linear dimensionality reduction
|
65
73
|
For more details on this class, see [sklearn.manifold.SpectralEmbedding]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
|
64
|
+
# Modeling library estimators require a smaller sklearn version range.
|
65
|
+
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
66
|
+
raise Exception(
|
67
|
+
f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
|
68
|
+
)
|
69
|
+
|
70
|
+
|
63
71
|
class TSNE(BaseTransformer):
|
64
72
|
r"""T-distributed Stochastic Neighbor Embedding
|
65
73
|
For more details on this class, see [sklearn.manifold.TSNE]
|
@@ -5,7 +5,7 @@ import cloudpickle
|
|
5
5
|
from snowflake.ml._internal import init_utils
|
6
6
|
from snowflake.ml._internal.utils import result
|
7
7
|
|
8
|
-
pkg_dir = os.path.dirname(
|
8
|
+
pkg_dir = os.path.dirname(__file__)
|
9
9
|
pkg_name = __name__
|
10
10
|
exportable_functions = init_utils.fetch_functions_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
|
11
11
|
for k, v in exportable_functions.items():
|
@@ -2,7 +2,7 @@ import inspect
|
|
2
2
|
import json
|
3
3
|
import math
|
4
4
|
import warnings
|
5
|
-
from typing import Any,
|
5
|
+
from typing import Any, Iterable, Optional, Union
|
6
6
|
|
7
7
|
import cloudpickle
|
8
8
|
import numpy as np
|
@@ -32,8 +32,8 @@ _SUBPROJECT = "Metrics"
|
|
32
32
|
def accuracy_score(
|
33
33
|
*,
|
34
34
|
df: snowpark.DataFrame,
|
35
|
-
y_true_col_names: Union[str,
|
36
|
-
y_pred_col_names: Union[str,
|
35
|
+
y_true_col_names: Union[str, list[str]],
|
36
|
+
y_pred_col_names: Union[str, list[str]],
|
37
37
|
normalize: bool = True,
|
38
38
|
sample_weight_col_name: Optional[str] = None,
|
39
39
|
) -> float:
|
@@ -221,7 +221,7 @@ def confusion_matrix(
|
|
221
221
|
return cm
|
222
222
|
|
223
223
|
|
224
|
-
def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_params:
|
224
|
+
def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_params: dict[str, Any]) -> str:
|
225
225
|
"""Registers confusion matrix computation UDTF in Snowflake and returns the name of the UDTF.
|
226
226
|
|
227
227
|
Args:
|
@@ -247,7 +247,7 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
|
|
247
247
|
# Number of labels.
|
248
248
|
self._n_label = 0
|
249
249
|
|
250
|
-
def process(self, input_row:
|
250
|
+
def process(self, input_row: list[float], n_label: int) -> None:
|
251
251
|
"""Computes confusion matrix.
|
252
252
|
|
253
253
|
Args:
|
@@ -270,7 +270,7 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
|
|
270
270
|
self.update_confusion_matrix()
|
271
271
|
self._cur_count = 0
|
272
272
|
|
273
|
-
def end_partition(self) -> Iterable[
|
273
|
+
def end_partition(self) -> Iterable[tuple[bytes, str]]:
|
274
274
|
# 3. Compute sum and dot_prod for the remaining rows in the batch.
|
275
275
|
if self._cur_count > 0:
|
276
276
|
self.update_confusion_matrix()
|
@@ -313,8 +313,8 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
|
|
313
313
|
def f1_score(
|
314
314
|
*,
|
315
315
|
df: snowpark.DataFrame,
|
316
|
-
y_true_col_names: Union[str,
|
317
|
-
y_pred_col_names: Union[str,
|
316
|
+
y_true_col_names: Union[str, list[str]],
|
317
|
+
y_pred_col_names: Union[str, list[str]],
|
318
318
|
labels: Optional[npt.ArrayLike] = None,
|
319
319
|
pos_label: Union[str, int] = 1,
|
320
320
|
average: Optional[str] = "binary",
|
@@ -406,8 +406,8 @@ def f1_score(
|
|
406
406
|
def fbeta_score(
|
407
407
|
*,
|
408
408
|
df: snowpark.DataFrame,
|
409
|
-
y_true_col_names: Union[str,
|
410
|
-
y_pred_col_names: Union[str,
|
409
|
+
y_true_col_names: Union[str, list[str]],
|
410
|
+
y_pred_col_names: Union[str, list[str]],
|
411
411
|
beta: float,
|
412
412
|
labels: Optional[npt.ArrayLike] = None,
|
413
413
|
pos_label: Union[str, int] = 1,
|
@@ -501,8 +501,8 @@ def fbeta_score(
|
|
501
501
|
def log_loss(
|
502
502
|
*,
|
503
503
|
df: snowpark.DataFrame,
|
504
|
-
y_true_col_names: Union[str,
|
505
|
-
y_pred_col_names: Union[str,
|
504
|
+
y_true_col_names: Union[str, list[str]],
|
505
|
+
y_pred_col_names: Union[str, list[str]],
|
506
506
|
eps: Union[float, str] = "auto",
|
507
507
|
normalize: bool = True,
|
508
508
|
sample_weight_col_name: Optional[str] = None,
|
@@ -625,7 +625,7 @@ def log_loss(
|
|
625
625
|
def _register_log_loss_computer(
|
626
626
|
*,
|
627
627
|
session: snowpark.Session,
|
628
|
-
statement_params:
|
628
|
+
statement_params: dict[str, Any],
|
629
629
|
labels: Optional[npt.ArrayLike] = None,
|
630
630
|
) -> str:
|
631
631
|
"""Registers log loss computation UDTF in Snowflake and returns the name of the UDTF.
|
@@ -644,16 +644,16 @@ def _register_log_loss_computer(
|
|
644
644
|
class LogLossComputer:
|
645
645
|
def __init__(self) -> None:
|
646
646
|
self._labels = labels
|
647
|
-
self._y_true:
|
648
|
-
self._y_pred:
|
649
|
-
self._sample_weight:
|
647
|
+
self._y_true: list[list[int]] = []
|
648
|
+
self._y_pred: list[list[float]] = []
|
649
|
+
self._sample_weight: list[float] = []
|
650
650
|
|
651
|
-
def process(self, y_true:
|
651
|
+
def process(self, y_true: list[int], y_pred: list[float], sample_weight: float) -> None:
|
652
652
|
self._y_true.append(y_true)
|
653
653
|
self._y_pred.append(y_pred)
|
654
654
|
self._sample_weight.append(sample_weight)
|
655
655
|
|
656
|
-
def end_partition(self) -> Iterable[
|
656
|
+
def end_partition(self) -> Iterable[tuple[float]]:
|
657
657
|
res = metrics.log_loss(
|
658
658
|
self._y_true,
|
659
659
|
self._y_pred,
|
@@ -685,18 +685,18 @@ def _register_log_loss_computer(
|
|
685
685
|
def precision_recall_fscore_support(
|
686
686
|
*,
|
687
687
|
df: snowpark.DataFrame,
|
688
|
-
y_true_col_names: Union[str,
|
689
|
-
y_pred_col_names: Union[str,
|
688
|
+
y_true_col_names: Union[str, list[str]],
|
689
|
+
y_pred_col_names: Union[str, list[str]],
|
690
690
|
beta: float = 1.0,
|
691
691
|
labels: Optional[npt.ArrayLike] = None,
|
692
692
|
pos_label: Union[str, int] = 1,
|
693
693
|
average: Optional[str] = None,
|
694
|
-
warn_for: Union[
|
694
|
+
warn_for: Union[tuple[str, ...], set[str]] = ("precision", "recall", "f-score"),
|
695
695
|
sample_weight_col_name: Optional[str] = None,
|
696
696
|
zero_division: Union[str, int] = "warn",
|
697
697
|
) -> Union[
|
698
|
-
|
699
|
-
|
698
|
+
tuple[float, float, float, None],
|
699
|
+
tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
|
700
700
|
]:
|
701
701
|
"""
|
702
702
|
Compute precision, recall, F-measure and support for each class.
|
@@ -854,8 +854,8 @@ def precision_recall_fscore_support(
|
|
854
854
|
result_object = result.deserialize(session, precision_recall_fscore_support_anon_sproc(session, **kwargs))
|
855
855
|
|
856
856
|
res: Union[
|
857
|
-
|
858
|
-
|
857
|
+
tuple[float, float, float, None],
|
858
|
+
tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
|
859
859
|
] = result_object[:4]
|
860
860
|
warning = result_object[-1]
|
861
861
|
if warning:
|
@@ -1039,18 +1039,18 @@ def _register_multilabel_confusion_matrix_computer(
|
|
1039
1039
|
def __init__(self) -> None:
|
1040
1040
|
self._labels = labels
|
1041
1041
|
self._samplewise = samplewise
|
1042
|
-
self._y_true:
|
1043
|
-
self._y_pred:
|
1044
|
-
self._sample_weight:
|
1042
|
+
self._y_true: list[list[int]] = []
|
1043
|
+
self._y_pred: list[list[int]] = []
|
1044
|
+
self._sample_weight: list[float] = []
|
1045
1045
|
|
1046
|
-
def process(self, y_true:
|
1046
|
+
def process(self, y_true: list[int], y_pred: list[int], sample_weight: float) -> None:
|
1047
1047
|
self._y_true.append(y_true)
|
1048
1048
|
self._y_pred.append(y_pred)
|
1049
1049
|
self._sample_weight.append(sample_weight)
|
1050
1050
|
|
1051
1051
|
def end_partition(
|
1052
1052
|
self,
|
1053
|
-
) -> Iterable[
|
1053
|
+
) -> Iterable[tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]]:
|
1054
1054
|
MCM = metrics.multilabel_confusion_matrix(
|
1055
1055
|
self._y_true,
|
1056
1056
|
self._y_pred,
|
@@ -1093,8 +1093,8 @@ def _register_multilabel_confusion_matrix_computer(
|
|
1093
1093
|
def _binary_precision_score(
|
1094
1094
|
*,
|
1095
1095
|
df: snowpark.DataFrame,
|
1096
|
-
y_true_col_names: Union[str,
|
1097
|
-
y_pred_col_names: Union[str,
|
1096
|
+
y_true_col_names: Union[str, list[str]],
|
1097
|
+
y_pred_col_names: Union[str, list[str]],
|
1098
1098
|
pos_label: Union[str, int] = 1,
|
1099
1099
|
sample_weight_col_name: Optional[str] = None,
|
1100
1100
|
zero_division: Union[str, int] = "warn",
|
@@ -1166,8 +1166,8 @@ def _binary_precision_score(
|
|
1166
1166
|
def precision_score(
|
1167
1167
|
*,
|
1168
1168
|
df: snowpark.DataFrame,
|
1169
|
-
y_true_col_names: Union[str,
|
1170
|
-
y_pred_col_names: Union[str,
|
1169
|
+
y_true_col_names: Union[str, list[str]],
|
1170
|
+
y_pred_col_names: Union[str, list[str]],
|
1171
1171
|
labels: Optional[npt.ArrayLike] = None,
|
1172
1172
|
pos_label: Union[str, int] = 1,
|
1173
1173
|
average: Optional[str] = "binary",
|
@@ -1264,8 +1264,8 @@ def precision_score(
|
|
1264
1264
|
def recall_score(
|
1265
1265
|
*,
|
1266
1266
|
df: snowpark.DataFrame,
|
1267
|
-
y_true_col_names: Union[str,
|
1268
|
-
y_pred_col_names: Union[str,
|
1267
|
+
y_true_col_names: Union[str, list[str]],
|
1268
|
+
y_pred_col_names: Union[str, list[str]],
|
1269
1269
|
labels: Optional[npt.ArrayLike] = None,
|
1270
1270
|
pos_label: Union[str, int] = 1,
|
1271
1271
|
average: Optional[str] = "binary",
|
@@ -1376,9 +1376,9 @@ def _sum_array_col(df: snowpark.DataFrame, col_name: str) -> snowpark.DataFrame:
|
|
1376
1376
|
|
1377
1377
|
|
1378
1378
|
def _check_binary_labels(
|
1379
|
-
labels:
|
1379
|
+
labels: list[Any],
|
1380
1380
|
pos_label: Union[str, int] = 1,
|
1381
|
-
) ->
|
1381
|
+
) -> list[Any]:
|
1382
1382
|
"""Validation associated with binary average labels.
|
1383
1383
|
|
1384
1384
|
Args:
|
@@ -1411,7 +1411,7 @@ def _prf_divide(
|
|
1411
1411
|
metric: str,
|
1412
1412
|
modifier: str,
|
1413
1413
|
average: Optional[str] = None,
|
1414
|
-
warn_for: Union[
|
1414
|
+
warn_for: Union[tuple[str, ...], set[str]] = ("precision", "recall", "f-score"),
|
1415
1415
|
zero_division: Union[str, int] = "warn",
|
1416
1416
|
) -> npt.NDArray[np.float_]:
|
1417
1417
|
"""Performs division and handles divide-by-zero.
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import math
|
2
2
|
import warnings
|
3
|
-
from typing import Any, Collection,
|
3
|
+
from typing import Any, Collection, Iterable, Optional, Union
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
@@ -18,7 +18,7 @@ INDEX = "INDEX"
|
|
18
18
|
BATCH_SIZE = 1000
|
19
19
|
|
20
20
|
|
21
|
-
def register_accumulator_udtf(*, session: Session, statement_params:
|
21
|
+
def register_accumulator_udtf(*, session: Session, statement_params: dict[str, Any]) -> str:
|
22
22
|
"""Registers accumulator UDTF in Snowflake and returns the name of the UDTF.
|
23
23
|
|
24
24
|
Args:
|
@@ -47,7 +47,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, A
|
|
47
47
|
else:
|
48
48
|
self._accumulated_row = self._accumulated_row + row
|
49
49
|
|
50
|
-
def end_partition(self) -> Iterable[
|
50
|
+
def end_partition(self) -> Iterable[tuple[bytes]]:
|
51
51
|
yield (cloudpickle.dumps(self._accumulated_row),)
|
52
52
|
|
53
53
|
accumulator = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE_FUNCTION)
|
@@ -68,7 +68,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, A
|
|
68
68
|
return accumulator
|
69
69
|
|
70
70
|
|
71
|
-
def register_sharded_dot_sum_computer(*, session: Session, statement_params:
|
71
|
+
def register_sharded_dot_sum_computer(*, session: Session, statement_params: dict[str, Any]) -> str:
|
72
72
|
"""Registers dot and sum computation UDTF in Snowflake and returns the name of the UDTF.
|
73
73
|
|
74
74
|
Args:
|
@@ -110,7 +110,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dic
|
|
110
110
|
# Square root of count - ddof
|
111
111
|
self._sqrt_count_d = -1.0
|
112
112
|
|
113
|
-
def process(self, input_row:
|
113
|
+
def process(self, input_row: list[float], count: int, ddof: int) -> None:
|
114
114
|
"""Computes sum and dot product.
|
115
115
|
|
116
116
|
Args:
|
@@ -138,7 +138,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dic
|
|
138
138
|
self.accumulate_batch_sum_and_dot_prod()
|
139
139
|
self._cur_count = 0
|
140
140
|
|
141
|
-
def end_partition(self) -> Iterable[
|
141
|
+
def end_partition(self) -> Iterable[tuple[bytes, str]]:
|
142
142
|
# 3. Compute sum and dot_prod for the remaining rows in the batch.
|
143
143
|
if self._cur_count > 0:
|
144
144
|
self.accumulate_batch_sum_and_dot_prod()
|
@@ -185,7 +185,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dic
|
|
185
185
|
|
186
186
|
def validate_and_return_dataframe_and_columns(
|
187
187
|
*, df: snowpark.DataFrame, columns: Optional[Collection[str]] = None
|
188
|
-
) ->
|
188
|
+
) -> tuple[snowpark.DataFrame, Collection[str]]:
|
189
189
|
"""Validates that the columns are all numeric and returns a dataframe with those columns.
|
190
190
|
|
191
191
|
Args:
|
@@ -212,8 +212,8 @@ def validate_and_return_dataframe_and_columns(
|
|
212
212
|
|
213
213
|
|
214
214
|
def check_label_columns(
|
215
|
-
y_true_col_names: Union[str,
|
216
|
-
y_pred_col_names: Union[str,
|
215
|
+
y_true_col_names: Union[str, list[str]],
|
216
|
+
y_pred_col_names: Union[str, list[str]],
|
217
217
|
) -> None:
|
218
218
|
"""Check y true and y pred columns.
|
219
219
|
|
@@ -238,7 +238,7 @@ def check_label_columns(
|
|
238
238
|
)
|
239
239
|
|
240
240
|
|
241
|
-
def flatten_cols(cols:
|
241
|
+
def flatten_cols(cols: list[Optional[Union[str, list[str]]]]) -> list[str]:
|
242
242
|
res = []
|
243
243
|
for col in cols:
|
244
244
|
if isinstance(col, str):
|
@@ -251,7 +251,7 @@ def flatten_cols(cols: List[Optional[Union[str, List[str]]]]) -> List[str]:
|
|
251
251
|
def unique_labels(
|
252
252
|
*,
|
253
253
|
df: snowpark.DataFrame,
|
254
|
-
columns:
|
254
|
+
columns: list[snowpark.Column],
|
255
255
|
) -> snowpark.DataFrame:
|
256
256
|
"""Extract indexed ordered unique labels as a dataframe.
|
257
257
|
|
@@ -311,7 +311,7 @@ def weighted_sum(
|
|
311
311
|
sample_score_column: snowpark.Column,
|
312
312
|
sample_weight_column: Optional[snowpark.Column] = None,
|
313
313
|
normalize: bool = False,
|
314
|
-
statement_params:
|
314
|
+
statement_params: dict[str, str],
|
315
315
|
) -> float:
|
316
316
|
"""Weighted sum of the sample score column.
|
317
317
|
|