snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
1
|
import warnings
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional, Sequence
|
3
3
|
|
4
4
|
import numpy as np
|
5
5
|
import numpy.typing as npt
|
@@ -12,7 +12,7 @@ from snowflake.ml._internal.exceptions import (
|
|
12
12
|
from snowflake.ml.model._signatures import core
|
13
13
|
|
14
14
|
|
15
|
-
def convert_list_to_ndarray(data:
|
15
|
+
def convert_list_to_ndarray(data: list[Any]) -> npt.NDArray[Any]:
|
16
16
|
"""Create a numpy array from list or nested list. Avoid ragged list and unaligned types.
|
17
17
|
|
18
18
|
Args:
|
@@ -49,7 +49,7 @@ def convert_list_to_ndarray(data: List[Any]) -> npt.NDArray[Any]:
|
|
49
49
|
|
50
50
|
|
51
51
|
def rename_features(
|
52
|
-
features: Sequence[core.BaseFeatureSpec], feature_names: Optional[
|
52
|
+
features: Sequence[core.BaseFeatureSpec], feature_names: Optional[list[str]] = None
|
53
53
|
) -> Sequence[core.BaseFeatureSpec]:
|
54
54
|
"""It renames the feature in features provided optional feature names.
|
55
55
|
|
@@ -104,7 +104,7 @@ def rename_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec
|
|
104
104
|
return data
|
105
105
|
|
106
106
|
|
107
|
-
def huggingface_pipeline_signature_auto_infer(task: str, params:
|
107
|
+
def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any]) -> Optional[core.ModelSignature]:
|
108
108
|
# Text
|
109
109
|
|
110
110
|
# https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
|
@@ -351,7 +351,7 @@ def series_dropna(series: pd.Series) -> pd.Series:
|
|
351
351
|
return series.dropna(inplace=False).reset_index(drop=True).convert_dtypes()
|
352
352
|
|
353
353
|
|
354
|
-
def infer_list(name: str, data:
|
354
|
+
def infer_list(name: str, data: list[Any]) -> core.BaseFeatureSpec:
|
355
355
|
"""Infer the feature specification from a list.
|
356
356
|
|
357
357
|
Args:
|
@@ -382,7 +382,7 @@ def infer_list(name: str, data: List[Any]) -> core.BaseFeatureSpec:
|
|
382
382
|
return core.FeatureSpec(name=name, dtype=arr_dtype, shape=arr.shape)
|
383
383
|
|
384
384
|
|
385
|
-
def infer_dict(name: str, data:
|
385
|
+
def infer_dict(name: str, data: dict[str, Any]) -> core.FeatureGroupSpec:
|
386
386
|
"""Infer the feature specification from a dictionary.
|
387
387
|
|
388
388
|
Args:
|
@@ -1,9 +1,10 @@
|
|
1
1
|
import functools
|
2
2
|
import inspect
|
3
|
-
from typing import Any, Callable, Coroutine,
|
3
|
+
from typing import Any, Callable, Coroutine, Generator, Optional, Union
|
4
4
|
|
5
5
|
import anyio
|
6
6
|
import pandas as pd
|
7
|
+
from typing_extensions import deprecated
|
7
8
|
|
8
9
|
from snowflake.ml.model import type_hints as model_types
|
9
10
|
|
@@ -78,7 +79,7 @@ class ModelRef:
|
|
78
79
|
return MethodRef(self, method_name)
|
79
80
|
raise AttributeError(f"Method {method_name} not found in model {self._name}.")
|
80
81
|
|
81
|
-
def __getstate__(self) ->
|
82
|
+
def __getstate__(self) -> dict[str, Any]:
|
82
83
|
state = self.__dict__.copy()
|
83
84
|
del state["_model"]
|
84
85
|
return state
|
@@ -113,8 +114,8 @@ class ModelContext:
|
|
113
114
|
def __init__(
|
114
115
|
self,
|
115
116
|
*,
|
116
|
-
artifacts: Optional[Union[
|
117
|
-
models: Optional[Union[
|
117
|
+
artifacts: Optional[Union[dict[str, str], str, model_types.SupportedModelType]] = None,
|
118
|
+
models: Optional[Union[dict[str, model_types.SupportedModelType], str, model_types.SupportedModelType]] = None,
|
118
119
|
**kwargs: Optional[Union[str, model_types.SupportedModelType]],
|
119
120
|
) -> None:
|
120
121
|
"""Initialize the model context.
|
@@ -130,8 +131,8 @@ class ModelContext:
|
|
130
131
|
ValueError: Raised when the model name is duplicated.
|
131
132
|
"""
|
132
133
|
|
133
|
-
self.artifacts:
|
134
|
-
self.model_refs:
|
134
|
+
self.artifacts: dict[str, str] = dict()
|
135
|
+
self.model_refs: dict[str, ModelRef] = dict()
|
135
136
|
|
136
137
|
# In case that artifacts is a dictionary, assume the original usage,
|
137
138
|
# which is to pass in a dictionary of artifacts.
|
@@ -185,7 +186,7 @@ class ModelContext:
|
|
185
186
|
return self.model_refs[name]
|
186
187
|
|
187
188
|
def __getitem__(self, key: str) -> Union[str, ModelRef]:
|
188
|
-
combined:
|
189
|
+
combined: dict[str, Union[str, ModelRef]] = {**self.artifacts, **self.model_refs}
|
189
190
|
if key not in combined:
|
190
191
|
raise KeyError(f"Key {key} not found in the kwargs, current available keys are: {combined.keys()}")
|
191
192
|
return combined[key]
|
@@ -226,12 +227,12 @@ class CustomModel:
|
|
226
227
|
else:
|
227
228
|
raise TypeError("A non-method inference API function is not supported.")
|
228
229
|
|
229
|
-
def
|
230
|
-
"""Returns all methods in CLS with `
|
230
|
+
def _get_partitioned_methods(self) -> list[str]:
|
231
|
+
"""Returns all methods in CLS with `partitioned_api` as the outermost decorator."""
|
231
232
|
rv = []
|
232
233
|
for cls_method_str in dir(self):
|
233
234
|
cls_method = getattr(self, cls_method_str)
|
234
|
-
if getattr(cls_method, "
|
235
|
+
if getattr(cls_method, "_is_partitioned_api", False):
|
235
236
|
if inspect.ismethod(cls_method):
|
236
237
|
rv.append(cls_method_str)
|
237
238
|
else:
|
@@ -282,9 +283,21 @@ def inference_api(
|
|
282
283
|
return func
|
283
284
|
|
284
285
|
|
286
|
+
def partitioned_api(
|
287
|
+
func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
|
288
|
+
) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
|
289
|
+
func.__dict__["_is_inference_api"] = True
|
290
|
+
func.__dict__["_is_partitioned_api"] = True
|
291
|
+
return func
|
292
|
+
|
293
|
+
|
294
|
+
@deprecated(
|
295
|
+
"snowflake.ml.custom_model.partitioned_inference_api is deprecated and will be removed in a future release."
|
296
|
+
" Use snowflake.ml.custom_model.partitioned_api instead."
|
297
|
+
)
|
285
298
|
def partitioned_inference_api(
|
286
299
|
func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
|
287
300
|
) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
|
288
301
|
func.__dict__["_is_inference_api"] = True
|
289
|
-
func.__dict__["
|
302
|
+
func.__dict__["_is_partitioned_api"] = True
|
290
303
|
return func
|
@@ -1,18 +1,7 @@
|
|
1
1
|
import enum
|
2
2
|
import json
|
3
3
|
import warnings
|
4
|
-
from typing import
|
5
|
-
Any,
|
6
|
-
Dict,
|
7
|
-
List,
|
8
|
-
Literal,
|
9
|
-
Optional,
|
10
|
-
Sequence,
|
11
|
-
Tuple,
|
12
|
-
Type,
|
13
|
-
Union,
|
14
|
-
cast,
|
15
|
-
)
|
4
|
+
from typing import Any, Literal, Optional, Sequence, Union, cast
|
16
5
|
|
17
6
|
import numpy as np
|
18
7
|
import pandas as pd
|
@@ -30,7 +19,7 @@ from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
|
|
30
19
|
from snowflake.ml.model import type_hints as model_types
|
31
20
|
from snowflake.ml.model._signatures import (
|
32
21
|
base_handler,
|
33
|
-
builtins_handler
|
22
|
+
builtins_handler,
|
34
23
|
core,
|
35
24
|
dmatrix_handler,
|
36
25
|
numpy_handler,
|
@@ -48,7 +37,7 @@ FeatureGroupSpec = core.FeatureGroupSpec
|
|
48
37
|
ModelSignature = core.ModelSignature
|
49
38
|
|
50
39
|
|
51
|
-
_LOCAL_DATA_HANDLERS:
|
40
|
+
_LOCAL_DATA_HANDLERS: list[type[base_handler.BaseDataHandler[Any]]] = [
|
52
41
|
pandas_handler.PandasDataFrameHandler,
|
53
42
|
numpy_handler.NumpyArrayHandler,
|
54
43
|
builtins_handler.ListOfBuiltinHandler,
|
@@ -82,9 +71,9 @@ def _truncate_data(
|
|
82
71
|
warnings.warn(
|
83
72
|
formatting.unwrap(
|
84
73
|
f"""
|
85
|
-
The sample input has {row_count} rows
|
86
|
-
|
87
|
-
|
74
|
+
The sample input has {row_count} rows. Using the first 100 rows to define the inputs and outputs
|
75
|
+
of the model and the data types of each. Use `signatures` parameter to specify model inputs and
|
76
|
+
outputs manually if the automatic inference is not correct.
|
88
77
|
"""
|
89
78
|
),
|
90
79
|
category=UserWarning,
|
@@ -414,7 +403,7 @@ class SnowparkIdentifierRule(enum.Enum):
|
|
414
403
|
|
415
404
|
def _get_dataframe_values_range(
|
416
405
|
df: snowflake.snowpark.DataFrame,
|
417
|
-
) ->
|
406
|
+
) -> dict[str, Union[tuple[int, int], tuple[float, float]]]:
|
418
407
|
columns = [
|
419
408
|
F.array_construct(F.min(field.name), F.max(field.name)).as_(field.name)
|
420
409
|
for field in df.schema.fields
|
@@ -429,7 +418,7 @@ def _get_dataframe_values_range(
|
|
429
418
|
original_exception=ValueError(f"Unable to get the value range of fields {df.columns}"),
|
430
419
|
)
|
431
420
|
return cast(
|
432
|
-
|
421
|
+
dict[str, Union[tuple[int, int], tuple[float, float]]],
|
433
422
|
{
|
434
423
|
sql_identifier.SqlIdentifier(k, case_sensitive=True).identifier(): (json.loads(v)[0], json.loads(v)[1])
|
435
424
|
for k, v in res[0].as_dict().items()
|
@@ -456,7 +445,7 @@ def _validate_snowpark_data(
|
|
456
445
|
- inferred: signature `a` - Snowpark DF `"a"`, use `get_inferred_name`
|
457
446
|
- normalized: signature `a` - Snowpark DF `A`, use `resolve_identifier`
|
458
447
|
"""
|
459
|
-
errors:
|
448
|
+
errors: dict[SnowparkIdentifierRule, list[Exception]] = {
|
460
449
|
SnowparkIdentifierRule.INFERRED: [],
|
461
450
|
SnowparkIdentifierRule.NORMALIZED: [],
|
462
451
|
}
|
@@ -549,7 +538,7 @@ def _validate_snowpark_type_feature(
|
|
549
538
|
field: spt.StructField,
|
550
539
|
ft_type: DataType,
|
551
540
|
ft_name: str,
|
552
|
-
value_range: Optional[Union[
|
541
|
+
value_range: Optional[Union[tuple[int, int], tuple[float, float]]],
|
553
542
|
strict: bool = False,
|
554
543
|
) -> None:
|
555
544
|
field_data_type = field.datatype
|
@@ -716,8 +705,8 @@ def _convert_and_validate_local_data(
|
|
716
705
|
def infer_signature(
|
717
706
|
input_data: model_types.SupportedLocalDataType,
|
718
707
|
output_data: model_types.SupportedLocalDataType,
|
719
|
-
input_feature_names: Optional[
|
720
|
-
output_feature_names: Optional[
|
708
|
+
input_feature_names: Optional[list[str]] = None,
|
709
|
+
output_feature_names: Optional[list[str]] = None,
|
721
710
|
input_data_limit: Optional[int] = 100,
|
722
711
|
output_data_limit: Optional[int] = 100,
|
723
712
|
) -> core.ModelSignature:
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import warnings
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional
|
3
3
|
|
4
4
|
from packaging import version
|
5
5
|
|
@@ -13,7 +13,7 @@ class HuggingFacePipelineModel:
|
|
13
13
|
revision: Optional[str] = None,
|
14
14
|
token: Optional[str] = None,
|
15
15
|
trust_remote_code: Optional[bool] = None,
|
16
|
-
model_kwargs: Optional[
|
16
|
+
model_kwargs: Optional[dict[str, Any]] = None,
|
17
17
|
**kwargs: Any,
|
18
18
|
) -> None:
|
19
19
|
"""
|
@@ -65,6 +65,7 @@ class HuggingFacePipelineModel:
|
|
65
65
|
warnings.warn(
|
66
66
|
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.",
|
67
67
|
FutureWarning,
|
68
|
+
stacklevel=2,
|
68
69
|
)
|
69
70
|
if token is not None:
|
70
71
|
raise ValueError(
|
@@ -183,7 +184,8 @@ class HuggingFacePipelineModel:
|
|
183
184
|
warnings.warn(
|
184
185
|
f"No model was supplied, defaulted to {model} and revision"
|
185
186
|
f" {revision} ({transformers.pipelines.HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
|
186
|
-
"Using a pipeline without specifying a model name and revision in production is not recommended."
|
187
|
+
"Using a pipeline without specifying a model name and revision in production is not recommended.",
|
188
|
+
stacklevel=2,
|
187
189
|
)
|
188
190
|
if config is None and isinstance(model, str):
|
189
191
|
config_obj = transformers.AutoConfig.from_pretrained(
|
@@ -200,7 +202,8 @@ class HuggingFacePipelineModel:
|
|
200
202
|
if kwargs.get("device", None) is not None:
|
201
203
|
warnings.warn(
|
202
204
|
"Both `device` and `device_map` are specified. `device` will override `device_map`. You"
|
203
|
-
" will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`."
|
205
|
+
" will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`.",
|
206
|
+
stacklevel=2,
|
204
207
|
)
|
205
208
|
|
206
209
|
# ==== End pipeline logic from transformers ====
|
snowflake/ml/model/type_hints.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
2
|
from enum import Enum
|
3
|
-
from typing import TYPE_CHECKING,
|
3
|
+
from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union
|
4
4
|
|
5
5
|
import numpy.typing as npt
|
6
6
|
from typing_extensions import NotRequired
|
@@ -32,7 +32,7 @@ _SupportedBuiltins = Union[
|
|
32
32
|
bool,
|
33
33
|
str,
|
34
34
|
bytes,
|
35
|
-
|
35
|
+
dict[str, Union["_SupportedBuiltins", "_SupportedBuiltinsList"]],
|
36
36
|
"_SupportedBuiltinsList",
|
37
37
|
]
|
38
38
|
_SupportedNumpyDtype = Union[
|
@@ -153,7 +153,7 @@ class BaseModelSaveOption(TypedDict):
|
|
153
153
|
embed_local_ml_library: NotRequired[bool]
|
154
154
|
relax_version: NotRequired[bool]
|
155
155
|
function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
|
156
|
-
method_options: NotRequired[
|
156
|
+
method_options: NotRequired[dict[str, ModelMethodSaveOptions]]
|
157
157
|
enable_explainability: NotRequired[bool]
|
158
158
|
save_location: NotRequired[str]
|
159
159
|
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import inspect
|
2
2
|
import numbers
|
3
3
|
import os
|
4
|
-
from typing import Any, Callable
|
4
|
+
from typing import Any, Callable
|
5
5
|
|
6
6
|
import cloudpickle as cp
|
7
7
|
import numpy as np
|
@@ -16,7 +16,7 @@ from snowflake.snowpark import Session
|
|
16
16
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
17
17
|
|
18
18
|
|
19
|
-
def validate_sklearn_args(args:
|
19
|
+
def validate_sklearn_args(args: dict[str, tuple[Any, Any, bool]], klass: type) -> dict[str, Any]:
|
20
20
|
"""Validate if all the keyword args are supported by current version of SKLearn/XGBoost object.
|
21
21
|
|
22
22
|
Args:
|
@@ -71,7 +71,7 @@ def transform_snowml_obj_to_sklearn_obj(obj: Any) -> Any:
|
|
71
71
|
return obj
|
72
72
|
|
73
73
|
|
74
|
-
def gather_dependencies(obj: Any) ->
|
74
|
+
def gather_dependencies(obj: Any) -> set[str]:
|
75
75
|
"""Gathers dependencies from the SnowML Estimator and Transformer objects.
|
76
76
|
|
77
77
|
Args:
|
@@ -82,7 +82,7 @@ def gather_dependencies(obj: Any) -> Set[str]:
|
|
82
82
|
"""
|
83
83
|
|
84
84
|
if isinstance(obj, list) or isinstance(obj, tuple):
|
85
|
-
deps:
|
85
|
+
deps: set[str] = set()
|
86
86
|
for elem in obj:
|
87
87
|
deps = deps | set(gather_dependencies(elem))
|
88
88
|
return deps
|
@@ -167,8 +167,8 @@ def get_module_name(model: object) -> str:
|
|
167
167
|
|
168
168
|
|
169
169
|
def handle_inference_result(
|
170
|
-
inference_res: Any, output_cols:
|
171
|
-
) ->
|
170
|
+
inference_res: Any, output_cols: list[str], inference_method: str, within_udf: bool = False
|
171
|
+
) -> tuple[npt.NDArray[Any], list[str]]:
|
172
172
|
if isinstance(inference_res, list) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray):
|
173
173
|
# In case of multioutput estimators, predict_proba, decision_function etc., functions return a list of
|
174
174
|
# ndarrays. We need to concatenate them.
|
@@ -248,7 +248,7 @@ def create_temp_stage(session: Session) -> str:
|
|
248
248
|
|
249
249
|
|
250
250
|
def upload_model_to_stage(
|
251
|
-
stage_name: str, estimator: object, session: Session, statement_params:
|
251
|
+
stage_name: str, estimator: object, session: Session, statement_params: dict[str, str]
|
252
252
|
) -> str:
|
253
253
|
"""Util method to pickle and upload the model to a temp Snowflake stage.
|
254
254
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import inspect
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
|
@@ -38,9 +38,9 @@ class PandasTransformHandlers:
|
|
38
38
|
def batch_inference(
|
39
39
|
self,
|
40
40
|
inference_method: str,
|
41
|
-
input_cols:
|
42
|
-
expected_output_cols:
|
43
|
-
snowpark_input_cols: Optional[
|
41
|
+
input_cols: list[str],
|
42
|
+
expected_output_cols: list[str],
|
43
|
+
snowpark_input_cols: Optional[list[str]] = None,
|
44
44
|
drop_input_cols: Optional[bool] = False,
|
45
45
|
*args: Any,
|
46
46
|
**kwargs: Any,
|
@@ -147,8 +147,8 @@ class PandasTransformHandlers:
|
|
147
147
|
|
148
148
|
def score(
|
149
149
|
self,
|
150
|
-
input_cols:
|
151
|
-
label_cols:
|
150
|
+
input_cols: list[str],
|
151
|
+
label_cols: list[str],
|
152
152
|
sample_weight_col: Optional[str],
|
153
153
|
*args: Any,
|
154
154
|
**kwargs: Any,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import inspect
|
2
|
-
from typing import
|
2
|
+
from typing import Optional
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
|
@@ -15,8 +15,8 @@ class PandasModelTrainer:
|
|
15
15
|
self,
|
16
16
|
estimator: object,
|
17
17
|
dataset: pd.DataFrame,
|
18
|
-
input_cols:
|
19
|
-
label_cols: Optional[
|
18
|
+
input_cols: list[str],
|
19
|
+
label_cols: Optional[list[str]],
|
20
20
|
sample_weight_col: Optional[str],
|
21
21
|
) -> None:
|
22
22
|
"""
|
@@ -57,10 +57,10 @@ class PandasModelTrainer:
|
|
57
57
|
|
58
58
|
def train_fit_predict(
|
59
59
|
self,
|
60
|
-
expected_output_cols_list:
|
60
|
+
expected_output_cols_list: list[str],
|
61
61
|
drop_input_cols: Optional[bool] = False,
|
62
62
|
example_output_pd_df: Optional[pd.DataFrame] = None,
|
63
|
-
) ->
|
63
|
+
) -> tuple[pd.DataFrame, object]:
|
64
64
|
"""Trains the model using specified features and target columns from the dataset.
|
65
65
|
This API is different from fit itself because it would also provide the predict
|
66
66
|
output.
|
@@ -92,9 +92,9 @@ class PandasModelTrainer:
|
|
92
92
|
|
93
93
|
def train_fit_transform(
|
94
94
|
self,
|
95
|
-
expected_output_cols_list:
|
95
|
+
expected_output_cols_list: list[str],
|
96
96
|
drop_input_cols: Optional[bool] = False,
|
97
|
-
) ->
|
97
|
+
) -> tuple[pd.DataFrame, object]:
|
98
98
|
"""Trains the model using specified features and target columns from the dataset.
|
99
99
|
This API is different from fit itself because it would also provide the transform
|
100
100
|
output.
|
@@ -1,5 +1,3 @@
|
|
1
|
-
from typing import List
|
2
|
-
|
3
1
|
import cloudpickle as cp
|
4
2
|
import numpy as np
|
5
3
|
|
@@ -11,7 +9,7 @@ class ModelSpecifications:
|
|
11
9
|
A dataclass to define model based specifications like required imports, and package dependencies for Sproc/Udfs.
|
12
10
|
"""
|
13
11
|
|
14
|
-
def __init__(self, imports:
|
12
|
+
def __init__(self, imports: list[str], pkgDependencies: list[str]) -> None:
|
15
13
|
self.imports = imports
|
16
14
|
self.pkgDependencies = pkgDependencies
|
17
15
|
|
@@ -20,7 +18,7 @@ class SKLearnModelSpecifications(ModelSpecifications):
|
|
20
18
|
def __init__(self) -> None:
|
21
19
|
import sklearn
|
22
20
|
|
23
|
-
imports:
|
21
|
+
imports: list[str] = ["sklearn"]
|
24
22
|
# TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda.
|
25
23
|
pkgDependencies = [
|
26
24
|
f"numpy=={np.__version__}",
|
@@ -56,8 +54,8 @@ class XGBoostModelSpecifications(ModelSpecifications):
|
|
56
54
|
import sklearn
|
57
55
|
import xgboost
|
58
56
|
|
59
|
-
imports:
|
60
|
-
pkgDependencies:
|
57
|
+
imports: list[str] = ["xgboost"]
|
58
|
+
pkgDependencies: list[str] = [
|
61
59
|
f"numpy=={np.__version__}",
|
62
60
|
f"scikit-learn=={sklearn.__version__}",
|
63
61
|
f"xgboost=={xgboost.__version__}",
|
@@ -71,8 +69,8 @@ class LightGBMModelSpecifications(ModelSpecifications):
|
|
71
69
|
import lightgbm
|
72
70
|
import sklearn
|
73
71
|
|
74
|
-
imports:
|
75
|
-
pkgDependencies:
|
72
|
+
imports: list[str] = ["lightgbm"]
|
73
|
+
pkgDependencies: list[str] = [
|
76
74
|
f"numpy=={np.__version__}",
|
77
75
|
f"scikit-learn=={sklearn.__version__}",
|
78
76
|
f"lightgbm=={lightgbm.__version__}",
|
@@ -86,8 +84,8 @@ class SklearnModelSelectionModelSpecifications(ModelSpecifications):
|
|
86
84
|
import sklearn
|
87
85
|
import xgboost
|
88
86
|
|
89
|
-
imports:
|
90
|
-
pkgDependencies:
|
87
|
+
imports: list[str] = ["sklearn", "xgboost"]
|
88
|
+
pkgDependencies: list[str] = [
|
91
89
|
f"numpy=={np.__version__}",
|
92
90
|
f"scikit-learn=={sklearn.__version__}",
|
93
91
|
f"cloudpickle=={cp.__version__}",
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Optional, Protocol, Union
|
2
2
|
|
3
3
|
import pandas as pd
|
4
4
|
|
@@ -18,15 +18,15 @@ class ModelTrainer(Protocol):
|
|
18
18
|
|
19
19
|
def train_fit_predict(
|
20
20
|
self,
|
21
|
-
expected_output_cols_list:
|
21
|
+
expected_output_cols_list: list[str],
|
22
22
|
drop_input_cols: Optional[bool] = False,
|
23
23
|
example_output_pd_df: Optional[pd.DataFrame] = None,
|
24
|
-
) ->
|
24
|
+
) -> tuple[Union[DataFrame, pd.DataFrame], object]:
|
25
25
|
raise NotImplementedError
|
26
26
|
|
27
27
|
def train_fit_transform(
|
28
28
|
self,
|
29
|
-
expected_output_cols_list:
|
29
|
+
expected_output_cols_list: list[str],
|
30
30
|
drop_input_cols: Optional[bool] = False,
|
31
|
-
) ->
|
31
|
+
) -> tuple[Union[DataFrame, pd.DataFrame], object]:
|
32
32
|
raise NotImplementedError
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Optional, Union
|
2
2
|
|
3
3
|
import pandas as pd
|
4
4
|
from sklearn import model_selection
|
@@ -71,8 +71,8 @@ class ModelTrainerBuilder:
|
|
71
71
|
cls,
|
72
72
|
estimator: object,
|
73
73
|
dataset: Union[DataFrame, pd.DataFrame],
|
74
|
-
input_cols: Optional[
|
75
|
-
label_cols: Optional[
|
74
|
+
input_cols: Optional[list[str]] = None,
|
75
|
+
label_cols: Optional[list[str]] = None,
|
76
76
|
sample_weight_col: Optional[str] = None,
|
77
77
|
autogenerated: bool = False,
|
78
78
|
subproject: str = "",
|
@@ -130,7 +130,7 @@ class ModelTrainerBuilder:
|
|
130
130
|
cls,
|
131
131
|
estimator: object,
|
132
132
|
dataset: Union[DataFrame, pd.DataFrame],
|
133
|
-
input_cols:
|
133
|
+
input_cols: list[str],
|
134
134
|
autogenerated: bool = False,
|
135
135
|
subproject: str = "",
|
136
136
|
) -> ModelTrainer:
|
@@ -169,8 +169,8 @@ class ModelTrainerBuilder:
|
|
169
169
|
cls,
|
170
170
|
estimator: object,
|
171
171
|
dataset: Union[DataFrame, pd.DataFrame],
|
172
|
-
input_cols:
|
173
|
-
label_cols: Optional[
|
172
|
+
input_cols: list[str],
|
173
|
+
label_cols: Optional[list[str]] = None,
|
174
174
|
sample_weight_col: Optional[str] = None,
|
175
175
|
autogenerated: bool = False,
|
176
176
|
subproject: str = "",
|