snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
import collections
|
2
2
|
import pathlib
|
3
|
-
from typing import
|
3
|
+
from typing import Optional, TypedDict, Union
|
4
4
|
|
5
5
|
from typing_extensions import NotRequired
|
6
6
|
|
@@ -137,8 +137,8 @@ class ModelMethod:
|
|
137
137
|
)
|
138
138
|
|
139
139
|
outputs: Union[
|
140
|
-
|
141
|
-
|
140
|
+
list[model_manifest_schema.ModelMethodSignatureField],
|
141
|
+
list[model_manifest_schema.ModelMethodSignatureFieldWithName],
|
142
142
|
]
|
143
143
|
if self.function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
144
144
|
outputs = [
|
@@ -3,10 +3,11 @@ import itertools
|
|
3
3
|
import os
|
4
4
|
import pathlib
|
5
5
|
import warnings
|
6
|
-
from typing import DefaultDict,
|
6
|
+
from typing import DefaultDict, Optional
|
7
7
|
|
8
8
|
from packaging import requirements, version
|
9
9
|
|
10
|
+
from snowflake.ml import version as snowml_version
|
10
11
|
from snowflake.ml._internal import env as snowml_env, env_utils
|
11
12
|
from snowflake.ml.model._packager.model_meta import model_meta_schema
|
12
13
|
|
@@ -19,9 +20,8 @@ _DEFAULT_CONDA_ENV_FILENAME = "conda.yml"
|
|
19
20
|
_DEFAULT_PIP_REQUIREMENTS_FILENAME = "requirements.txt"
|
20
21
|
|
21
22
|
# The default CUDA version is chosen based on the driver availability in SPCS.
|
22
|
-
#
|
23
|
-
|
24
|
-
DEFAULT_CUDA_VERSION = "11.8"
|
23
|
+
# Make sure they are aligned with default CUDA version in inference server.
|
24
|
+
DEFAULT_CUDA_VERSION = "12.4"
|
25
25
|
|
26
26
|
|
27
27
|
class ModelEnv:
|
@@ -38,15 +38,16 @@ class ModelEnv:
|
|
38
38
|
self.prefer_pip: bool = prefer_pip
|
39
39
|
self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
|
40
40
|
self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
|
41
|
-
self.artifact_repository_map: Optional[
|
42
|
-
self.
|
43
|
-
self.
|
41
|
+
self.artifact_repository_map: Optional[dict[str, str]] = None
|
42
|
+
self.resource_constraint: Optional[dict[str, str]] = None
|
43
|
+
self._conda_dependencies: DefaultDict[str, list[requirements.Requirement]] = collections.defaultdict(list)
|
44
|
+
self._pip_requirements: list[requirements.Requirement] = []
|
44
45
|
self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
|
45
46
|
self._cuda_version: Optional[version.Version] = None
|
46
|
-
self._snowpark_ml_version: version.Version = version.parse(
|
47
|
+
self._snowpark_ml_version: version.Version = version.parse(snowml_version.VERSION)
|
47
48
|
|
48
49
|
@property
|
49
|
-
def conda_dependencies(self) ->
|
50
|
+
def conda_dependencies(self) -> list[str]:
|
50
51
|
"""List of conda channel and dependencies from that to run the model"""
|
51
52
|
return sorted(
|
52
53
|
f"{chan}::{str(req)}" if chan else str(req)
|
@@ -57,24 +58,24 @@ class ModelEnv:
|
|
57
58
|
@conda_dependencies.setter
|
58
59
|
def conda_dependencies(
|
59
60
|
self,
|
60
|
-
conda_dependencies: Optional[
|
61
|
+
conda_dependencies: Optional[list[str]] = None,
|
61
62
|
) -> None:
|
62
63
|
self._conda_dependencies = env_utils.validate_conda_dependency_string_list(
|
63
|
-
conda_dependencies if conda_dependencies else []
|
64
|
+
conda_dependencies if conda_dependencies else [], add_local_version_specifier=True
|
64
65
|
)
|
65
66
|
|
66
67
|
@property
|
67
|
-
def pip_requirements(self) ->
|
68
|
+
def pip_requirements(self) -> list[str]:
|
68
69
|
"""List of pip Python packages requirements for running the model."""
|
69
70
|
return sorted(list(map(str, self._pip_requirements)))
|
70
71
|
|
71
72
|
@pip_requirements.setter
|
72
73
|
def pip_requirements(
|
73
74
|
self,
|
74
|
-
pip_requirements: Optional[
|
75
|
+
pip_requirements: Optional[list[str]] = None,
|
75
76
|
) -> None:
|
76
77
|
self._pip_requirements = env_utils.validate_pip_requirement_string_list(
|
77
|
-
pip_requirements if pip_requirements else []
|
78
|
+
pip_requirements if pip_requirements else [], add_local_version_specifier=True
|
78
79
|
)
|
79
80
|
|
80
81
|
@property
|
@@ -117,7 +118,7 @@ class ModelEnv:
|
|
117
118
|
|
118
119
|
def include_if_absent(
|
119
120
|
self,
|
120
|
-
pkgs:
|
121
|
+
pkgs: list[ModelDependency],
|
121
122
|
check_local_version: bool = False,
|
122
123
|
) -> None:
|
123
124
|
"""Append requirements into model env if absent. Depending on the environment, requirements may be added
|
@@ -128,7 +129,7 @@ class ModelEnv:
|
|
128
129
|
check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
|
129
130
|
"""
|
130
131
|
if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
|
131
|
-
pip_pkg_reqs:
|
132
|
+
pip_pkg_reqs: list[str] = []
|
132
133
|
warnings.warn(
|
133
134
|
(
|
134
135
|
"Dependencies specified from pip requirements."
|
@@ -145,7 +146,7 @@ class ModelEnv:
|
|
145
146
|
else:
|
146
147
|
self._include_if_absent_conda(pkgs, check_local_version)
|
147
148
|
|
148
|
-
def _include_if_absent_conda(self, pkgs:
|
149
|
+
def _include_if_absent_conda(self, pkgs: list[ModelDependency], check_local_version: bool = False) -> None:
|
149
150
|
"""Append requirements into model env conda dependencies if absent.
|
150
151
|
|
151
152
|
Args:
|
@@ -190,7 +191,7 @@ class ModelEnv:
|
|
190
191
|
stacklevel=2,
|
191
192
|
)
|
192
193
|
|
193
|
-
def _include_if_absent_pip(self, pkgs:
|
194
|
+
def _include_if_absent_pip(self, pkgs: list[str], check_local_version: bool = False) -> None:
|
194
195
|
"""Append pip requirements into model env pip requirements if absent.
|
195
196
|
|
196
197
|
Args:
|
@@ -207,7 +208,7 @@ class ModelEnv:
|
|
207
208
|
except env_utils.DuplicateDependencyError:
|
208
209
|
pass
|
209
210
|
|
210
|
-
def remove_if_present_conda(self, conda_pkgs:
|
211
|
+
def remove_if_present_conda(self, conda_pkgs: list[str]) -> None:
|
211
212
|
"""Remove conda requirements from model env if present.
|
212
213
|
|
213
214
|
Args:
|
@@ -352,13 +353,14 @@ class ModelEnv:
|
|
352
353
|
def load_from_dict(self, base_dir: pathlib.Path, env_dict: model_meta_schema.ModelEnvDict) -> None:
|
353
354
|
self.conda_env_rel_path = pathlib.PurePosixPath(env_dict["conda"])
|
354
355
|
self.pip_requirements_rel_path = pathlib.PurePosixPath(env_dict["pip"])
|
355
|
-
self.artifact_repository_map = env_dict.get("artifact_repository_map"
|
356
|
+
self.artifact_repository_map = env_dict.get("artifact_repository_map")
|
357
|
+
self.resource_constraint = env_dict.get("resource_constraint")
|
356
358
|
|
357
359
|
self.load_from_conda_file(base_dir / self.conda_env_rel_path)
|
358
360
|
self.load_from_pip_file(base_dir / self.pip_requirements_rel_path)
|
359
361
|
|
360
362
|
self.python_version = env_dict["python_version"]
|
361
|
-
self.cuda_version = env_dict.get("cuda_version"
|
363
|
+
self.cuda_version = env_dict.get("cuda_version")
|
362
364
|
self.snowpark_ml_version = env_dict["snowpark_ml_version"]
|
363
365
|
|
364
366
|
def save_as_dict(
|
@@ -381,7 +383,8 @@ class ModelEnv:
|
|
381
383
|
return {
|
382
384
|
"conda": self.conda_env_rel_path.as_posix(),
|
383
385
|
"pip": self.pip_requirements_rel_path.as_posix(),
|
384
|
-
"artifact_repository_map": self.artifact_repository_map
|
386
|
+
"artifact_repository_map": self.artifact_repository_map or {},
|
387
|
+
"resource_constraint": self.resource_constraint or {},
|
385
388
|
"python_version": self.python_version,
|
386
389
|
"cuda_version": self.cuda_version,
|
387
390
|
"snowpark_ml_version": self.snowpark_ml_version,
|
@@ -389,7 +392,7 @@ class ModelEnv:
|
|
389
392
|
|
390
393
|
def validate_with_local_env(
|
391
394
|
self, check_snowpark_ml_version: bool = False
|
392
|
-
) ->
|
395
|
+
) -> list[env_utils.IncorrectLocalEnvironmentError]:
|
393
396
|
errors = []
|
394
397
|
try:
|
395
398
|
env_utils.validate_py_runtime_version(str(self._python_version))
|
@@ -413,10 +416,10 @@ class ModelEnv:
|
|
413
416
|
|
414
417
|
if check_snowpark_ml_version:
|
415
418
|
# For Modeling model
|
416
|
-
if self._snowpark_ml_version.base_version !=
|
419
|
+
if self._snowpark_ml_version.base_version != snowml_version.VERSION:
|
417
420
|
errors.append(
|
418
421
|
env_utils.IncorrectLocalEnvironmentError(
|
419
|
-
f"The local installed version of Snowpark ML library is {
|
422
|
+
f"The local installed version of Snowpark ML library is {snowml_version.VERSION} "
|
420
423
|
f"which differs from required version {self.snowpark_ml_version}."
|
421
424
|
)
|
422
425
|
)
|
@@ -2,13 +2,13 @@ import functools
|
|
2
2
|
import importlib
|
3
3
|
import pkgutil
|
4
4
|
from types import ModuleType
|
5
|
-
from typing import Any, Callable,
|
5
|
+
from typing import Any, Callable, Optional, TypeVar, cast
|
6
6
|
|
7
7
|
from snowflake.ml.model import type_hints as model_types
|
8
8
|
from snowflake.ml.model._packager.model_handlers import _base
|
9
9
|
|
10
10
|
_HANDLERS_BASE = "snowflake.ml.model._packager.model_handlers"
|
11
|
-
_MODEL_HANDLER_REGISTRY:
|
11
|
+
_MODEL_HANDLER_REGISTRY: dict[str, type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
|
12
12
|
_IS_HANDLER_LOADED = False
|
13
13
|
|
14
14
|
|
@@ -54,7 +54,7 @@ def ensure_handlers_registration(fn: F) -> F:
|
|
54
54
|
@ensure_handlers_registration
|
55
55
|
def find_handler(
|
56
56
|
model: model_types.SupportedModelType,
|
57
|
-
) -> Optional[
|
57
|
+
) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
|
58
58
|
for handler in _MODEL_HANDLER_REGISTRY.values():
|
59
59
|
if handler.can_handle(model):
|
60
60
|
return handler
|
@@ -64,7 +64,7 @@ def find_handler(
|
|
64
64
|
@ensure_handlers_registration
|
65
65
|
def load_handler(
|
66
66
|
target_model_type: model_types.SupportedModelHandlerType,
|
67
|
-
) -> Optional[
|
67
|
+
) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
|
68
68
|
for model_type, handler in _MODEL_HANDLER_REGISTRY.items():
|
69
69
|
if target_model_type == model_type:
|
70
70
|
return handler
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
from abc import abstractmethod
|
3
|
-
from typing import
|
3
|
+
from typing import Generic, Optional, Protocol, final
|
4
4
|
|
5
5
|
import pandas as pd
|
6
6
|
from typing_extensions import TypeGuard, Unpack
|
@@ -14,7 +14,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
|
|
14
14
|
HANDLER_TYPE: model_types.SupportedModelHandlerType
|
15
15
|
HANDLER_VERSION: str
|
16
16
|
_MIN_SNOWPARK_ML_VERSION: str
|
17
|
-
_HANDLER_MIGRATOR_PLANS:
|
17
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]]
|
18
18
|
|
19
19
|
@classmethod
|
20
20
|
@abstractmethod
|
@@ -1,8 +1,9 @@
|
|
1
|
+
import importlib
|
1
2
|
import json
|
2
3
|
import os
|
3
4
|
import pathlib
|
4
5
|
import warnings
|
5
|
-
from typing import Any, Callable,
|
6
|
+
from typing import Any, Callable, Iterable, Optional, Sequence, cast
|
6
7
|
|
7
8
|
import numpy as np
|
8
9
|
import numpy.typing as npt
|
@@ -10,8 +11,10 @@ import pandas as pd
|
|
10
11
|
from absl import logging
|
11
12
|
|
12
13
|
import snowflake.snowpark.dataframe as sp_df
|
14
|
+
from snowflake.ml._internal import env
|
13
15
|
from snowflake.ml._internal.utils import identifier
|
14
16
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
17
|
+
from snowflake.ml.model._packager.model_env import model_env
|
15
18
|
from snowflake.ml.model._packager.model_meta import model_meta
|
16
19
|
from snowflake.ml.model._signatures import (
|
17
20
|
core,
|
@@ -106,6 +109,35 @@ def get_input_signature(
|
|
106
109
|
return input_sig
|
107
110
|
|
108
111
|
|
112
|
+
def add_inferred_explain_method_signature(
|
113
|
+
model_meta: model_meta.ModelMetadata,
|
114
|
+
explain_method: str,
|
115
|
+
target_method: str,
|
116
|
+
background_data: model_types.SupportedDataType,
|
117
|
+
explain_fn: Callable[[model_types.SupportedLocalDataType], model_types.SupportedLocalDataType],
|
118
|
+
output_feature_names: Optional[Sequence[str]] = None,
|
119
|
+
) -> model_meta.ModelMetadata:
|
120
|
+
inputs = get_input_signature(model_meta, target_method)
|
121
|
+
if output_feature_names is None: # If not provided, assume output feature names are the same as input feature names
|
122
|
+
output_feature_names = [spec.name for spec in inputs]
|
123
|
+
|
124
|
+
if model_meta.model_type == "snowml":
|
125
|
+
suffixed_output_names = [identifier.concat_names([name, "_explanation"]) for name in output_feature_names]
|
126
|
+
else:
|
127
|
+
suffixed_output_names = [f"{name}_explanation" for name in output_feature_names]
|
128
|
+
|
129
|
+
truncated_background_data = get_truncated_sample_data(background_data, 5)
|
130
|
+
sig = model_signature.infer_signature(
|
131
|
+
input_data=truncated_background_data,
|
132
|
+
output_data=explain_fn(truncated_background_data),
|
133
|
+
input_feature_names=[spec.name for spec in inputs],
|
134
|
+
output_feature_names=suffixed_output_names,
|
135
|
+
)
|
136
|
+
|
137
|
+
model_meta.signatures[explain_method] = sig
|
138
|
+
return model_meta
|
139
|
+
|
140
|
+
|
109
141
|
def add_explain_method_signature(
|
110
142
|
model_meta: model_meta.ModelMetadata,
|
111
143
|
explain_method: str,
|
@@ -231,10 +263,11 @@ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task
|
|
231
263
|
|
232
264
|
|
233
265
|
def get_explain_target_method(
|
234
|
-
model_metadata: model_meta.ModelMetadata, target_methods_list:
|
266
|
+
model_metadata: model_meta.ModelMetadata, target_methods_list: list[str]
|
235
267
|
) -> Optional[str]:
|
236
|
-
|
237
|
-
|
268
|
+
"""Returns the first target method that is found in the model metadata signatures."""
|
269
|
+
for method in target_methods_list:
|
270
|
+
if method in model_metadata.signatures.keys():
|
238
271
|
return method
|
239
272
|
return None
|
240
273
|
|
@@ -248,7 +281,7 @@ def save_transformers_config_with_auto_map(local_model_path: str) -> None:
|
|
248
281
|
config_dict = json.load(f)
|
249
282
|
|
250
283
|
# a. get repository and class_path from configs
|
251
|
-
auto_map_configs = cast(
|
284
|
+
auto_map_configs = cast(dict[str, str], config_dict.get("auto_map", {}))
|
252
285
|
for config_name, config_value in auto_map_configs.items():
|
253
286
|
repository, _, class_path = config_value.rpartition("--")
|
254
287
|
|
@@ -261,3 +294,12 @@ def save_transformers_config_with_auto_map(local_model_path: str) -> None:
|
|
261
294
|
|
262
295
|
with open(f_path, "w") as f:
|
263
296
|
json.dump(config_dict, f)
|
297
|
+
|
298
|
+
|
299
|
+
def get_default_cuda_version() -> str:
|
300
|
+
# Default to the env cuda version when running in ML runtime
|
301
|
+
if env.IN_ML_RUNTIME and importlib.util.find_spec("torch") is not None:
|
302
|
+
import torch
|
303
|
+
|
304
|
+
return torch.version.cuda or model_env.DEFAULT_CUDA_VERSION
|
305
|
+
return model_env.DEFAULT_CUDA_VERSION
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import TYPE_CHECKING, Any, Callable,
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
|
4
4
|
|
5
5
|
import numpy as np
|
6
6
|
import pandas as pd
|
@@ -30,7 +30,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
30
30
|
HANDLER_TYPE = "catboost"
|
31
31
|
HANDLER_VERSION = "2024-03-21"
|
32
32
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
33
|
-
_HANDLER_MIGRATOR_PLANS:
|
33
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
34
34
|
|
35
35
|
MODEL_BLOB_FILE_OR_DIR = "model.bin"
|
36
36
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
@@ -147,7 +147,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
147
147
|
if enable_explainability:
|
148
148
|
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
149
149
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
150
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
150
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
151
151
|
|
152
152
|
return None
|
153
153
|
|
@@ -202,7 +202,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
202
202
|
def _create_custom_model(
|
203
203
|
raw_model: "catboost.CatBoost",
|
204
204
|
model_meta: model_meta_api.ModelMetadata,
|
205
|
-
) ->
|
205
|
+
) -> type[custom_model.CustomModel]:
|
206
206
|
def fn_factory(
|
207
207
|
raw_model: "catboost.CatBoost",
|
208
208
|
signature: model_signature.ModelSignature,
|
@@ -235,7 +235,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
235
235
|
|
236
236
|
return fn
|
237
237
|
|
238
|
-
type_method_dict:
|
238
|
+
type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
|
239
239
|
for target_method_name, sig in model_meta.signatures.items():
|
240
240
|
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
241
241
|
|
@@ -2,7 +2,7 @@ import inspect
|
|
2
2
|
import os
|
3
3
|
import pathlib
|
4
4
|
import sys
|
5
|
-
from typing import
|
5
|
+
from typing import Optional, cast, final
|
6
6
|
|
7
7
|
import anyio
|
8
8
|
import cloudpickle
|
@@ -28,7 +28,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
28
28
|
HANDLER_TYPE = "custom"
|
29
29
|
HANDLER_VERSION = "2023-12-01"
|
30
30
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
31
|
-
_HANDLER_MIGRATOR_PLANS:
|
31
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
32
32
|
|
33
33
|
@classmethod
|
34
34
|
def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["custom_model.CustomModel"]:
|
@@ -72,7 +72,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
72
72
|
predictions_df = target_method(model, sample_input_data)
|
73
73
|
return predictions_df
|
74
74
|
|
75
|
-
for func_name in model.
|
75
|
+
for func_name in model._get_partitioned_methods():
|
76
76
|
function_properties = model_meta.function_properties.get(func_name, {})
|
77
77
|
function_properties[model_meta_schema.FunctionProperties.PARTITIONED.value] = True
|
78
78
|
model_meta.function_properties[func_name] = function_properties
|
@@ -99,7 +99,11 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
99
99
|
for sub_name, model_ref in model.context.model_refs.items():
|
100
100
|
handler = model_handler.find_handler(model_ref.model)
|
101
101
|
if handler is None:
|
102
|
-
raise TypeError(
|
102
|
+
raise TypeError(
|
103
|
+
f"Model {sub_name} in model context is not a supported model type. See "
|
104
|
+
"https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/"
|
105
|
+
"bring-your-own-model-types for more details."
|
106
|
+
)
|
103
107
|
sub_model = handler.cast_model(model_ref.model)
|
104
108
|
handler.save_model(
|
105
109
|
name=sub_name,
|
@@ -161,7 +165,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
161
165
|
name: str(pathlib.PurePath(model_blob_path) / pathlib.PurePosixPath(rel_path))
|
162
166
|
for name, rel_path in artifacts_meta.items()
|
163
167
|
}
|
164
|
-
models:
|
168
|
+
models: dict[str, model_types.SupportedModelType] = dict()
|
165
169
|
for sub_model_name, _ref in context.model_refs.items():
|
166
170
|
model_type = model_meta.models[sub_model_name].model_type
|
167
171
|
handler = model_handler.load_handler(model_type)
|
@@ -1,18 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
3
|
import warnings
|
4
|
-
from typing import
|
5
|
-
TYPE_CHECKING,
|
6
|
-
Any,
|
7
|
-
Callable,
|
8
|
-
Dict,
|
9
|
-
List,
|
10
|
-
Optional,
|
11
|
-
Type,
|
12
|
-
Union,
|
13
|
-
cast,
|
14
|
-
final,
|
15
|
-
)
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
|
16
5
|
|
17
6
|
import cloudpickle
|
18
7
|
import numpy as np
|
@@ -38,7 +27,7 @@ if TYPE_CHECKING:
|
|
38
27
|
import transformers
|
39
28
|
|
40
29
|
|
41
|
-
def get_requirements_from_task(task: str, spcs_only: bool = False) ->
|
30
|
+
def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
|
42
31
|
# Text
|
43
32
|
if task in [
|
44
33
|
"conversational",
|
@@ -84,7 +73,7 @@ class HuggingFacePipelineHandler(
|
|
84
73
|
HANDLER_TYPE = "huggingface_pipeline"
|
85
74
|
HANDLER_VERSION = "2023-12-01"
|
86
75
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
87
|
-
_HANDLER_MIGRATOR_PLANS:
|
76
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
88
77
|
|
89
78
|
MODEL_BLOB_FILE_OR_DIR = "model"
|
90
79
|
ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
|
@@ -250,20 +239,17 @@ class HuggingFacePipelineHandler(
|
|
250
239
|
task, spcs_only=(not type_utils.LazyType("transformers.Pipeline").isinstance(model))
|
251
240
|
)
|
252
241
|
if framework is None or framework == "pt":
|
253
|
-
# Since we set default cuda version to be 11.8, to make sure it works with GPU, we need to have a default
|
254
|
-
# Pytorch version that works with CUDA 11.8 as well. This is required for huggingface pipelines only as
|
255
|
-
# users are not required to install pytorch locally if they are using the wrapper.
|
256
242
|
pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
|
257
243
|
elif framework == "tf":
|
258
244
|
pkgs_requirements.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
|
259
245
|
model_meta.env.include_if_absent(
|
260
246
|
pkgs_requirements, check_local_version=(type_utils.LazyType("transformers.Pipeline").isinstance(model))
|
261
247
|
)
|
262
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
248
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
263
249
|
|
264
250
|
@staticmethod
|
265
|
-
def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) ->
|
266
|
-
device_config:
|
251
|
+
def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> dict[str, str]:
|
252
|
+
device_config: dict[str, Any] = {}
|
267
253
|
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
268
254
|
gpu_nums = 0
|
269
255
|
if cuda_visible_devices is not None:
|
@@ -369,7 +355,7 @@ class HuggingFacePipelineHandler(
|
|
369
355
|
def _create_custom_model(
|
370
356
|
raw_model: "transformers.Pipeline",
|
371
357
|
model_meta: model_meta_api.ModelMetadata,
|
372
|
-
) ->
|
358
|
+
) -> type[custom_model.CustomModel]:
|
373
359
|
def fn_factory(
|
374
360
|
raw_model: "transformers.Pipeline",
|
375
361
|
signature: model_signature.ModelSignature,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
3
3
|
|
4
4
|
import cloudpickle
|
5
5
|
import numpy as np
|
@@ -32,7 +32,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
|
32
32
|
HANDLER_TYPE = "keras"
|
33
33
|
HANDLER_VERSION = "2025-01-01"
|
34
34
|
_MIN_SNOWPARK_ML_VERSION = "1.7.5"
|
35
|
-
_HANDLER_MIGRATOR_PLANS:
|
35
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
36
36
|
|
37
37
|
MODEL_BLOB_FILE_OR_DIR = "model.keras"
|
38
38
|
CUSTOM_OBJECT_SAVE_PATH = "custom_objects.pkl"
|
@@ -146,7 +146,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
|
146
146
|
dependencies,
|
147
147
|
check_local_version=True,
|
148
148
|
)
|
149
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
149
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
150
150
|
|
151
151
|
@classmethod
|
152
152
|
def load_model(
|
@@ -185,7 +185,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
|
185
185
|
def _create_custom_model(
|
186
186
|
raw_model: "keras.Model",
|
187
187
|
model_meta: model_meta_api.ModelMetadata,
|
188
|
-
) ->
|
188
|
+
) -> type[custom_model.CustomModel]:
|
189
189
|
def fn_factory(
|
190
190
|
raw_model: "keras.Model",
|
191
191
|
signature: model_signature.ModelSignature,
|
@@ -1,16 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import
|
4
|
-
TYPE_CHECKING,
|
5
|
-
Any,
|
6
|
-
Callable,
|
7
|
-
Dict,
|
8
|
-
Optional,
|
9
|
-
Type,
|
10
|
-
Union,
|
11
|
-
cast,
|
12
|
-
final,
|
13
|
-
)
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
|
14
4
|
|
15
5
|
import cloudpickle
|
16
6
|
import numpy as np
|
@@ -41,7 +31,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
41
31
|
HANDLER_TYPE = "lightgbm"
|
42
32
|
HANDLER_VERSION = "2024-03-19"
|
43
33
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
44
|
-
_HANDLER_MIGRATOR_PLANS:
|
34
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
45
35
|
|
46
36
|
MODEL_BLOB_FILE_OR_DIR = "model.pkl"
|
47
37
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
@@ -215,7 +205,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
215
205
|
def _create_custom_model(
|
216
206
|
raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
|
217
207
|
model_meta: model_meta_api.ModelMetadata,
|
218
|
-
) ->
|
208
|
+
) -> type[custom_model.CustomModel]:
|
219
209
|
def fn_factory(
|
220
210
|
raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
|
221
211
|
signature: model_signature.ModelSignature,
|
@@ -250,7 +240,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
250
240
|
|
251
241
|
return fn
|
252
242
|
|
253
|
-
type_method_dict:
|
243
|
+
type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
|
254
244
|
for target_method_name, sig in model_meta.signatures.items():
|
255
245
|
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
256
246
|
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import os
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
|
-
from typing import TYPE_CHECKING, Callable,
|
4
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
5
5
|
|
6
6
|
import pandas as pd
|
7
7
|
from typing_extensions import TypeGuard, Unpack
|
@@ -61,7 +61,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
61
61
|
HANDLER_TYPE = "mlflow"
|
62
62
|
HANDLER_VERSION = "2023-12-01"
|
63
63
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
64
|
-
_HANDLER_MIGRATOR_PLANS:
|
64
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
65
65
|
|
66
66
|
MODEL_BLOB_FILE_OR_DIR = "model"
|
67
67
|
_DEFAULT_TARGET_METHOD = "predict"
|
@@ -204,7 +204,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
204
204
|
def _create_custom_model(
|
205
205
|
raw_model: "mlflow.pyfunc.PyFuncModel",
|
206
206
|
model_meta: model_meta_api.ModelMetadata,
|
207
|
-
) ->
|
207
|
+
) -> type[custom_model.CustomModel]:
|
208
208
|
def fn_factory(
|
209
209
|
raw_model: "mlflow.pyfunc.PyFuncModel",
|
210
210
|
signature: model_signature.ModelSignature,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import sys
|
3
|
-
from typing import TYPE_CHECKING, Callable,
|
3
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import pandas as pd
|
@@ -38,7 +38,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
38
38
|
HANDLER_TYPE = "pytorch"
|
39
39
|
HANDLER_VERSION = "2025-03-01"
|
40
40
|
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
41
|
-
_HANDLER_MIGRATOR_PLANS:
|
41
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
|
42
42
|
"2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
|
43
43
|
}
|
44
44
|
|
@@ -82,6 +82,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
82
82
|
enable_explainability = kwargs.get("enable_explainability", False)
|
83
83
|
if enable_explainability:
|
84
84
|
raise NotImplementedError("Explainability is not supported for PyTorch model.")
|
85
|
+
multiple_inputs = kwargs.get("multiple_inputs", False)
|
85
86
|
|
86
87
|
import torch
|
87
88
|
|
@@ -94,8 +95,6 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
94
95
|
default_target_methods=cls.DEFAULT_TARGET_METHODS,
|
95
96
|
)
|
96
97
|
|
97
|
-
multiple_inputs = kwargs.get("multiple_inputs", False)
|
98
|
-
|
99
98
|
def get_prediction(
|
100
99
|
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
101
100
|
) -> model_types.SupportedLocalDataType:
|
@@ -151,7 +150,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
151
150
|
model_meta.env.include_if_absent(
|
152
151
|
[model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
|
153
152
|
)
|
154
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
153
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
155
154
|
|
156
155
|
@classmethod
|
157
156
|
def load_model(
|
@@ -188,7 +187,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
188
187
|
def _create_custom_model(
|
189
188
|
raw_model: "torch.nn.Module",
|
190
189
|
model_meta: model_meta_api.ModelMetadata,
|
191
|
-
) ->
|
190
|
+
) -> type[custom_model.CustomModel]:
|
192
191
|
multiple_inputs = cast(
|
193
192
|
model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
|
194
193
|
)["multiple_inputs"]
|