snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
import pathlib
|
3
3
|
import textwrap
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Optional
|
5
5
|
from urllib.parse import ParseResult
|
6
6
|
|
7
7
|
from snowflake.ml._internal.utils import (
|
@@ -34,7 +34,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
34
34
|
model_name: sql_identifier.SqlIdentifier,
|
35
35
|
version_name: sql_identifier.SqlIdentifier,
|
36
36
|
stage_path: str,
|
37
|
-
statement_params: Optional[
|
37
|
+
statement_params: Optional[dict[str, Any]] = None,
|
38
38
|
) -> None:
|
39
39
|
query_result_checker.SqlResultValidator(
|
40
40
|
self._session,
|
@@ -56,7 +56,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
56
56
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
57
57
|
model_name: sql_identifier.SqlIdentifier,
|
58
58
|
version_name: sql_identifier.SqlIdentifier,
|
59
|
-
statement_params: Optional[
|
59
|
+
statement_params: Optional[dict[str, Any]] = None,
|
60
60
|
) -> None:
|
61
61
|
fq_source_model_name = self.fully_qualified_object_name(
|
62
62
|
source_database_name, source_schema_name, source_model_name
|
@@ -78,7 +78,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
78
78
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
79
79
|
model_name: sql_identifier.SqlIdentifier,
|
80
80
|
version_name: sql_identifier.SqlIdentifier,
|
81
|
-
statement_params: Optional[
|
81
|
+
statement_params: Optional[dict[str, Any]] = None,
|
82
82
|
) -> None:
|
83
83
|
sql = (
|
84
84
|
f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
@@ -97,7 +97,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
97
97
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
98
98
|
model_name: sql_identifier.SqlIdentifier,
|
99
99
|
version_name: sql_identifier.SqlIdentifier,
|
100
|
-
statement_params: Optional[
|
100
|
+
statement_params: Optional[dict[str, Any]] = None,
|
101
101
|
) -> None:
|
102
102
|
sql = (
|
103
103
|
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
@@ -116,7 +116,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
116
116
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
117
117
|
model_name: sql_identifier.SqlIdentifier,
|
118
118
|
version_name: sql_identifier.SqlIdentifier,
|
119
|
-
statement_params: Optional[
|
119
|
+
statement_params: Optional[dict[str, Any]] = None,
|
120
120
|
) -> None:
|
121
121
|
sql = (
|
122
122
|
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
@@ -138,7 +138,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
138
138
|
model_name: sql_identifier.SqlIdentifier,
|
139
139
|
version_name: sql_identifier.SqlIdentifier,
|
140
140
|
stage_path: str,
|
141
|
-
statement_params: Optional[
|
141
|
+
statement_params: Optional[dict[str, Any]] = None,
|
142
142
|
) -> None:
|
143
143
|
query_result_checker.SqlResultValidator(
|
144
144
|
self._session,
|
@@ -160,7 +160,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
160
160
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
161
161
|
model_name: sql_identifier.SqlIdentifier,
|
162
162
|
version_name: sql_identifier.SqlIdentifier,
|
163
|
-
statement_params: Optional[
|
163
|
+
statement_params: Optional[dict[str, Any]] = None,
|
164
164
|
) -> None:
|
165
165
|
fq_source_model_name = self.fully_qualified_object_name(
|
166
166
|
source_database_name, source_schema_name, source_model_name
|
@@ -182,7 +182,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
182
182
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
183
183
|
model_name: sql_identifier.SqlIdentifier,
|
184
184
|
version_name: sql_identifier.SqlIdentifier,
|
185
|
-
statement_params: Optional[
|
185
|
+
statement_params: Optional[dict[str, Any]] = None,
|
186
186
|
) -> None:
|
187
187
|
query_result_checker.SqlResultValidator(
|
188
188
|
self._session,
|
@@ -201,7 +201,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
201
201
|
model_name: sql_identifier.SqlIdentifier,
|
202
202
|
version_name: sql_identifier.SqlIdentifier,
|
203
203
|
alias_name: sql_identifier.SqlIdentifier,
|
204
|
-
statement_params: Optional[
|
204
|
+
statement_params: Optional[dict[str, Any]] = None,
|
205
205
|
) -> None:
|
206
206
|
query_result_checker.SqlResultValidator(
|
207
207
|
self._session,
|
@@ -219,7 +219,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
219
219
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
220
220
|
model_name: sql_identifier.SqlIdentifier,
|
221
221
|
version_or_alias_name: sql_identifier.SqlIdentifier,
|
222
|
-
statement_params: Optional[
|
222
|
+
statement_params: Optional[dict[str, Any]] = None,
|
223
223
|
) -> None:
|
224
224
|
query_result_checker.SqlResultValidator(
|
225
225
|
self._session,
|
@@ -239,8 +239,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
239
239
|
version_name: sql_identifier.SqlIdentifier,
|
240
240
|
file_path: pathlib.PurePosixPath,
|
241
241
|
is_dir: bool = False,
|
242
|
-
statement_params: Optional[
|
243
|
-
) ->
|
242
|
+
statement_params: Optional[dict[str, Any]] = None,
|
243
|
+
) -> list[row.Row]:
|
244
244
|
# Workaround for snowURL bug.
|
245
245
|
trailing_slash = "/" if is_dir else ""
|
246
246
|
|
@@ -276,7 +276,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
276
276
|
version_name: sql_identifier.SqlIdentifier,
|
277
277
|
file_path: pathlib.PurePosixPath,
|
278
278
|
target_path: pathlib.Path,
|
279
|
-
statement_params: Optional[
|
279
|
+
statement_params: Optional[dict[str, Any]] = None,
|
280
280
|
) -> pathlib.Path:
|
281
281
|
stage_location = pathlib.PurePosixPath(
|
282
282
|
self.fully_qualified_object_name(database_name, schema_name, model_name),
|
@@ -310,8 +310,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
310
310
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
311
311
|
model_name: sql_identifier.SqlIdentifier,
|
312
312
|
version_name: sql_identifier.SqlIdentifier,
|
313
|
-
statement_params: Optional[
|
314
|
-
) ->
|
313
|
+
statement_params: Optional[dict[str, Any]] = None,
|
314
|
+
) -> list[row.Row]:
|
315
315
|
res = query_result_checker.SqlResultValidator(
|
316
316
|
self._session,
|
317
317
|
(
|
@@ -331,7 +331,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
331
331
|
model_name: sql_identifier.SqlIdentifier,
|
332
332
|
version_name: sql_identifier.SqlIdentifier,
|
333
333
|
comment: str,
|
334
|
-
statement_params: Optional[
|
334
|
+
statement_params: Optional[dict[str, Any]] = None,
|
335
335
|
) -> None:
|
336
336
|
query_result_checker.SqlResultValidator(
|
337
337
|
self._session,
|
@@ -351,9 +351,9 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
351
351
|
version_name: sql_identifier.SqlIdentifier,
|
352
352
|
method_name: sql_identifier.SqlIdentifier,
|
353
353
|
input_df: dataframe.DataFrame,
|
354
|
-
input_args:
|
355
|
-
returns:
|
356
|
-
statement_params: Optional[
|
354
|
+
input_args: list[sql_identifier.SqlIdentifier],
|
355
|
+
returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
356
|
+
statement_params: Optional[dict[str, Any]] = None,
|
357
357
|
) -> dataframe.DataFrame:
|
358
358
|
with_statements = []
|
359
359
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
@@ -433,10 +433,10 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
433
433
|
version_name: sql_identifier.SqlIdentifier,
|
434
434
|
method_name: sql_identifier.SqlIdentifier,
|
435
435
|
input_df: dataframe.DataFrame,
|
436
|
-
input_args:
|
437
|
-
returns:
|
436
|
+
input_args: list[sql_identifier.SqlIdentifier],
|
437
|
+
returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
438
438
|
partition_column: Optional[sql_identifier.SqlIdentifier],
|
439
|
-
statement_params: Optional[
|
439
|
+
statement_params: Optional[dict[str, Any]] = None,
|
440
440
|
is_partitioned: bool = True,
|
441
441
|
) -> dataframe.DataFrame:
|
442
442
|
with_statements = []
|
@@ -529,13 +529,13 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
529
529
|
|
530
530
|
def set_metadata(
|
531
531
|
self,
|
532
|
-
metadata_dict:
|
532
|
+
metadata_dict: dict[str, Any],
|
533
533
|
*,
|
534
534
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
535
535
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
536
536
|
model_name: sql_identifier.SqlIdentifier,
|
537
537
|
version_name: sql_identifier.SqlIdentifier,
|
538
|
-
statement_params: Optional[
|
538
|
+
statement_params: Optional[dict[str, Any]] = None,
|
539
539
|
) -> None:
|
540
540
|
json_metadata = json.dumps(metadata_dict)
|
541
541
|
query_result_checker.SqlResultValidator(
|
@@ -554,7 +554,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
554
554
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
555
555
|
model_name: sql_identifier.SqlIdentifier,
|
556
556
|
version_name: sql_identifier.SqlIdentifier,
|
557
|
-
statement_params: Optional[
|
557
|
+
statement_params: Optional[dict[str, Any]] = None,
|
558
558
|
) -> None:
|
559
559
|
query_result_checker.SqlResultValidator(
|
560
560
|
self._session,
|
@@ -1,10 +1,9 @@
|
|
1
1
|
import enum
|
2
2
|
import json
|
3
3
|
import textwrap
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Optional, Union
|
5
5
|
|
6
6
|
from snowflake import snowpark
|
7
|
-
from snowflake.ml._internal import platform_capabilities
|
8
7
|
from snowflake.ml._internal.utils import (
|
9
8
|
identifier,
|
10
9
|
query_result_checker,
|
@@ -47,7 +46,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
47
46
|
gpu: Optional[Union[str, int]],
|
48
47
|
force_rebuild: bool,
|
49
48
|
external_access_integration: sql_identifier.SqlIdentifier,
|
50
|
-
statement_params: Optional[
|
49
|
+
statement_params: Optional[dict[str, Any]] = None,
|
51
50
|
) -> None:
|
52
51
|
actual_image_repo_database = image_repo_database_name or self._database_name
|
53
52
|
actual_image_repo_schema = image_repo_schema_name or self._schema_name
|
@@ -76,8 +75,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
76
75
|
stage_path: Optional[str] = None,
|
77
76
|
model_deployment_spec_yaml_str: Optional[str] = None,
|
78
77
|
model_deployment_spec_file_rel_path: Optional[str] = None,
|
79
|
-
statement_params: Optional[
|
80
|
-
) ->
|
78
|
+
statement_params: Optional[dict[str, Any]] = None,
|
79
|
+
) -> tuple[str, snowpark.AsyncJob]:
|
81
80
|
assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
|
82
81
|
if model_deployment_spec_yaml_str:
|
83
82
|
sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
|
@@ -95,9 +94,9 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
95
94
|
service_name: sql_identifier.SqlIdentifier,
|
96
95
|
method_name: sql_identifier.SqlIdentifier,
|
97
96
|
input_df: dataframe.DataFrame,
|
98
|
-
input_args:
|
99
|
-
returns:
|
100
|
-
statement_params: Optional[
|
97
|
+
input_args: list[sql_identifier.SqlIdentifier],
|
98
|
+
returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
99
|
+
statement_params: Optional[dict[str, Any]] = None,
|
101
100
|
) -> dataframe.DataFrame:
|
102
101
|
with_statements = []
|
103
102
|
actual_database_name = database_name or self._database_name
|
@@ -133,18 +132,10 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
133
132
|
input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
|
134
133
|
args_sql = f"object_construct_keep_null({input_args_sql})"
|
135
134
|
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
|
141
|
-
else:
|
142
|
-
function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
|
143
|
-
fully_qualified_function_name = identifier.get_schema_level_object_identifier(
|
144
|
-
actual_database_name.identifier(),
|
145
|
-
actual_schema_name.identifier(),
|
146
|
-
function_name,
|
147
|
-
)
|
135
|
+
fully_qualified_service_name = self.fully_qualified_object_name(
|
136
|
+
actual_database_name, actual_schema_name, service_name
|
137
|
+
)
|
138
|
+
fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
|
148
139
|
|
149
140
|
sql = textwrap.dedent(
|
150
141
|
f"""{with_sql}
|
@@ -181,7 +172,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
181
172
|
service_name: sql_identifier.SqlIdentifier,
|
182
173
|
instance_id: str = "0",
|
183
174
|
container_name: str,
|
184
|
-
statement_params: Optional[
|
175
|
+
statement_params: Optional[dict[str, Any]] = None,
|
185
176
|
) -> str:
|
186
177
|
system_func = "SYSTEM$GET_SERVICE_LOGS"
|
187
178
|
rows = (
|
@@ -206,8 +197,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
206
197
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
207
198
|
service_name: sql_identifier.SqlIdentifier,
|
208
199
|
include_message: bool = False,
|
209
|
-
statement_params: Optional[
|
210
|
-
) ->
|
200
|
+
statement_params: Optional[dict[str, Any]] = None,
|
201
|
+
) -> tuple[ServiceStatus, Optional[str]]:
|
211
202
|
system_func = "SYSTEM$GET_SERVICE_STATUS"
|
212
203
|
rows = (
|
213
204
|
query_result_checker.SqlResultValidator(
|
@@ -231,7 +222,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
231
222
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
232
223
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
233
224
|
service_name: sql_identifier.SqlIdentifier,
|
234
|
-
statement_params: Optional[
|
225
|
+
statement_params: Optional[dict[str, Any]] = None,
|
235
226
|
) -> None:
|
236
227
|
query_result_checker.SqlResultValidator(
|
237
228
|
self._session,
|
@@ -245,8 +236,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
245
236
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
246
237
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
247
238
|
service_name: sql_identifier.SqlIdentifier,
|
248
|
-
statement_params: Optional[
|
249
|
-
) ->
|
239
|
+
statement_params: Optional[dict[str, Any]] = None,
|
240
|
+
) -> list[row.Row]:
|
250
241
|
fully_qualified_service_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
|
251
242
|
res = (
|
252
243
|
query_result_checker.SqlResultValidator(
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Optional
|
2
2
|
|
3
3
|
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
4
|
from snowflake.ml.model._client.sql import _base
|
@@ -11,7 +11,7 @@ class StageSQLClient(_base._BaseSQLClient):
|
|
11
11
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
12
12
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
13
13
|
stage_name: sql_identifier.SqlIdentifier,
|
14
|
-
statement_params: Optional[
|
14
|
+
statement_params: Optional[dict[str, Any]] = None,
|
15
15
|
) -> None:
|
16
16
|
query_result_checker.SqlResultValidator(
|
17
17
|
self._session,
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Optional
|
2
2
|
|
3
3
|
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
4
|
from snowflake.ml.model._client.sql import _base
|
@@ -16,7 +16,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
|
|
16
16
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
17
17
|
tag_name: sql_identifier.SqlIdentifier,
|
18
18
|
tag_value: str,
|
19
|
-
statement_params: Optional[
|
19
|
+
statement_params: Optional[dict[str, Any]] = None,
|
20
20
|
) -> None:
|
21
21
|
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
22
22
|
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
@@ -35,7 +35,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
|
|
35
35
|
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
36
36
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
37
37
|
tag_name: sql_identifier.SqlIdentifier,
|
38
|
-
statement_params: Optional[
|
38
|
+
statement_params: Optional[dict[str, Any]] = None,
|
39
39
|
) -> None:
|
40
40
|
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
41
41
|
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
@@ -54,7 +54,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
|
|
54
54
|
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
55
55
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
56
56
|
tag_name: sql_identifier.SqlIdentifier,
|
57
|
-
statement_params: Optional[
|
57
|
+
statement_params: Optional[dict[str, Any]] = None,
|
58
58
|
) -> row.Row:
|
59
59
|
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
60
60
|
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
@@ -75,8 +75,8 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
|
|
75
75
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
76
76
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
77
77
|
model_name: sql_identifier.SqlIdentifier,
|
78
|
-
statement_params: Optional[
|
79
|
-
) ->
|
78
|
+
statement_params: Optional[dict[str, Any]] = None,
|
79
|
+
) -> list[row.Row]:
|
80
80
|
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
81
81
|
actual_database_name = database_name or self._database_name
|
82
82
|
return (
|
@@ -3,13 +3,14 @@ import tempfile
|
|
3
3
|
import uuid
|
4
4
|
import warnings
|
5
5
|
from types import ModuleType
|
6
|
-
from typing import Any,
|
6
|
+
from typing import Any, Optional, Union
|
7
7
|
from urllib import parse
|
8
8
|
|
9
9
|
from absl import logging
|
10
10
|
from packaging import requirements
|
11
11
|
|
12
12
|
from snowflake import snowpark
|
13
|
+
from snowflake.ml import version as snowml_version
|
13
14
|
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
14
15
|
from snowflake.ml._internal.lineage import lineage_utils
|
15
16
|
from snowflake.ml.data import data_source
|
@@ -43,7 +44,7 @@ class ModelComposer:
|
|
43
44
|
session: Session,
|
44
45
|
stage_path: str,
|
45
46
|
*,
|
46
|
-
statement_params: Optional[
|
47
|
+
statement_params: Optional[dict[str, Any]] = None,
|
47
48
|
save_location: Optional[str] = None,
|
48
49
|
) -> None:
|
49
50
|
self.session = session
|
@@ -122,17 +123,18 @@ class ModelComposer:
|
|
122
123
|
*,
|
123
124
|
name: str,
|
124
125
|
model: model_types.SupportedModelType,
|
125
|
-
signatures: Optional[
|
126
|
+
signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
|
126
127
|
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
127
|
-
metadata: Optional[
|
128
|
-
conda_dependencies: Optional[
|
129
|
-
pip_requirements: Optional[
|
130
|
-
artifact_repository_map: Optional[
|
131
|
-
|
128
|
+
metadata: Optional[dict[str, str]] = None,
|
129
|
+
conda_dependencies: Optional[list[str]] = None,
|
130
|
+
pip_requirements: Optional[list[str]] = None,
|
131
|
+
artifact_repository_map: Optional[dict[str, str]] = None,
|
132
|
+
resource_constraint: Optional[dict[str, str]] = None,
|
133
|
+
target_platforms: Optional[list[model_types.TargetPlatform]] = None,
|
132
134
|
python_version: Optional[str] = None,
|
133
|
-
user_files: Optional[
|
134
|
-
ext_modules: Optional[
|
135
|
-
code_paths: Optional[
|
135
|
+
user_files: Optional[dict[str, list[str]]] = None,
|
136
|
+
ext_modules: Optional[list[ModuleType]] = None,
|
137
|
+
code_paths: Optional[list[str]] = None,
|
136
138
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
137
139
|
options: Optional[model_types.ModelSaveOption] = None,
|
138
140
|
) -> model_meta.ModelMetadata:
|
@@ -140,40 +142,63 @@ class ModelComposer:
|
|
140
142
|
conda_dep_dict = env_utils.validate_conda_dependency_string_list(
|
141
143
|
conda_dependencies if conda_dependencies else []
|
142
144
|
)
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
145
|
+
|
146
|
+
enable_explainability = None
|
147
|
+
|
148
|
+
if options:
|
149
|
+
enable_explainability = options.get("enable_explainability", None)
|
150
|
+
|
151
|
+
# skip everything if user said False explicitly
|
152
|
+
if enable_explainability is None or enable_explainability is True:
|
153
|
+
is_warehouse_runnable = (
|
154
|
+
not conda_dep_dict
|
155
|
+
or all(
|
156
|
+
chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
157
|
+
for chan in conda_dep_dict
|
158
|
+
)
|
159
|
+
) and (not pip_requirements)
|
160
|
+
|
161
|
+
only_spcs = (
|
162
|
+
target_platforms
|
163
|
+
and len(target_platforms) == 1
|
164
|
+
and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
|
159
165
|
)
|
166
|
+
if only_spcs or (not is_warehouse_runnable):
|
167
|
+
# if only SPCS and user asked for explainability we fail
|
168
|
+
if enable_explainability is True:
|
169
|
+
raise ValueError(
|
170
|
+
"`enable_explainability` cannot be set to True when the model is not runnable in WH "
|
171
|
+
"or the target platforms include SPCS."
|
172
|
+
)
|
173
|
+
elif not options: # explicitly set flag to false in these cases if not specified
|
174
|
+
options = model_types.BaseModelSaveOption()
|
175
|
+
options["enable_explainability"] = False
|
176
|
+
elif (
|
177
|
+
target_platforms
|
178
|
+
and len(target_platforms) > 1
|
179
|
+
and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
|
180
|
+
): # if both then only available for WH
|
181
|
+
if enable_explainability is True:
|
182
|
+
warnings.warn(
|
183
|
+
("Explain function will only be available for model deployed to warehouse."),
|
184
|
+
category=UserWarning,
|
185
|
+
stacklevel=2,
|
186
|
+
)
|
160
187
|
|
161
188
|
if not options:
|
162
189
|
options = model_types.BaseModelSaveOption()
|
163
|
-
if disable_explainability:
|
164
|
-
options["enable_explainability"] = False
|
165
190
|
|
166
191
|
if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
167
192
|
snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
|
168
193
|
self.session,
|
169
|
-
reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={
|
194
|
+
reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
|
170
195
|
python_version=python_version or snowml_env.PYTHON_VERSION,
|
171
196
|
statement_params=self._statement_params,
|
172
197
|
).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
|
173
198
|
|
174
199
|
if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
|
175
200
|
logging.info(
|
176
|
-
f"Local snowflake-ml-python library has version {
|
201
|
+
f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
|
177
202
|
" which is not available in the Snowflake server, embedding local ML library automatically."
|
178
203
|
)
|
179
204
|
options["embed_local_ml_library"] = True
|
@@ -187,6 +212,7 @@ class ModelComposer:
|
|
187
212
|
conda_dependencies=conda_dependencies,
|
188
213
|
pip_requirements=pip_requirements,
|
189
214
|
artifact_repository_map=artifact_repository_map,
|
215
|
+
resource_constraint=resource_constraint,
|
190
216
|
target_platforms=target_platforms,
|
191
217
|
python_version=python_version,
|
192
218
|
ext_modules=ext_modules,
|
@@ -226,7 +252,7 @@ class ModelComposer:
|
|
226
252
|
|
227
253
|
def _get_data_sources(
|
228
254
|
self, model: model_types.SupportedModelType, sample_input_data: Optional[model_types.SupportedDataType] = None
|
229
|
-
) -> Optional[
|
255
|
+
) -> Optional[list[data_source.DataSource]]:
|
230
256
|
data_sources = lineage_utils.get_data_sources(model)
|
231
257
|
if not data_sources and sample_input_data is not None:
|
232
258
|
data_sources = lineage_utils.get_data_sources(sample_input_data)
|
@@ -2,7 +2,7 @@ import collections
|
|
2
2
|
import logging
|
3
3
|
import pathlib
|
4
4
|
import warnings
|
5
|
-
from typing import
|
5
|
+
from typing import Optional, cast
|
6
6
|
|
7
7
|
import yaml
|
8
8
|
|
@@ -45,10 +45,10 @@ class ModelManifest:
|
|
45
45
|
self,
|
46
46
|
model_meta: model_meta_api.ModelMetadata,
|
47
47
|
model_rel_path: pathlib.PurePosixPath,
|
48
|
-
user_files: Optional[
|
48
|
+
user_files: Optional[dict[str, list[str]]] = None,
|
49
49
|
options: Optional[type_hints.ModelSaveOption] = None,
|
50
|
-
data_sources: Optional[
|
51
|
-
target_platforms: Optional[
|
50
|
+
data_sources: Optional[list[data_source.DataSource]] = None,
|
51
|
+
target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
|
52
52
|
) -> None:
|
53
53
|
if options is None:
|
54
54
|
options = {}
|
@@ -78,12 +78,13 @@ class ModelManifest:
|
|
78
78
|
logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
|
79
79
|
logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
|
80
80
|
logger.info(f"artifact_repository_map: {runtime_to_use.runtime_env.artifact_repository_map}")
|
81
|
+
logger.info(f"resource_constraint: {runtime_to_use.runtime_env.resource_constraint}")
|
81
82
|
runtime_dict = runtime_to_use.save(
|
82
83
|
self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
83
84
|
)
|
84
85
|
|
85
86
|
self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
|
86
|
-
self.methods:
|
87
|
+
self.methods: list[model_method.ModelMethod] = []
|
87
88
|
|
88
89
|
for target_method in model_meta.signatures.keys():
|
89
90
|
method = model_method.ModelMethod(
|
@@ -100,7 +101,7 @@ class ModelManifest:
|
|
100
101
|
|
101
102
|
self.methods.append(method)
|
102
103
|
|
103
|
-
self.user_files:
|
104
|
+
self.user_files: list[model_user_file.ModelUserFile] = []
|
104
105
|
|
105
106
|
if user_files is not None:
|
106
107
|
for subdirectory, paths in user_files.items():
|
@@ -127,16 +128,19 @@ class ModelManifest:
|
|
127
128
|
if model_meta.env.artifact_repository_map:
|
128
129
|
dependencies["artifact_repository_map"] = runtime_dict["dependencies"]["artifact_repository_map"]
|
129
130
|
|
131
|
+
runtime = model_manifest_schema.ModelRuntimeDict(
|
132
|
+
language="PYTHON",
|
133
|
+
version=runtime_to_use.runtime_env.python_version,
|
134
|
+
imports=runtime_dict["imports"],
|
135
|
+
dependencies=dependencies,
|
136
|
+
)
|
137
|
+
|
138
|
+
if runtime_dict["resource_constraint"]:
|
139
|
+
runtime["resource_constraint"] = runtime_dict["resource_constraint"]
|
140
|
+
|
130
141
|
manifest_dict = model_manifest_schema.ModelManifestDict(
|
131
142
|
manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
|
132
|
-
runtimes={
|
133
|
-
self._DEFAULT_RUNTIME_NAME: model_manifest_schema.ModelRuntimeDict(
|
134
|
-
language="PYTHON",
|
135
|
-
version=runtime_to_use.runtime_env.python_version,
|
136
|
-
imports=runtime_dict["imports"],
|
137
|
-
dependencies=dependencies,
|
138
|
-
)
|
139
|
-
},
|
143
|
+
runtimes={self._DEFAULT_RUNTIME_NAME: runtime},
|
140
144
|
methods=[
|
141
145
|
method.save(
|
142
146
|
self.workspace_path,
|
@@ -178,8 +182,8 @@ class ModelManifest:
|
|
178
182
|
return res
|
179
183
|
|
180
184
|
def _extract_lineage_info(
|
181
|
-
self, data_sources: Optional[
|
182
|
-
) ->
|
185
|
+
self, data_sources: Optional[list[data_source.DataSource]]
|
186
|
+
) -> list[model_manifest_schema.LineageSourceDict]:
|
183
187
|
result = []
|
184
188
|
if data_sources:
|
185
189
|
for source in data_sources:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# This files contains schema definition of what will be written into MANIFEST.yml
|
2
2
|
import enum
|
3
|
-
from typing import Any,
|
3
|
+
from typing import Any, Literal, Optional, TypedDict, Union
|
4
4
|
|
5
5
|
from typing_extensions import NotRequired, Required
|
6
6
|
|
@@ -20,14 +20,15 @@ class ModelMethodFunctionTypes(enum.Enum):
|
|
20
20
|
class ModelRuntimeDependenciesDict(TypedDict):
|
21
21
|
conda: NotRequired[str]
|
22
22
|
pip: NotRequired[str]
|
23
|
-
artifact_repository_map: NotRequired[Optional[
|
23
|
+
artifact_repository_map: NotRequired[Optional[dict[str, str]]]
|
24
24
|
|
25
25
|
|
26
26
|
class ModelRuntimeDict(TypedDict):
|
27
27
|
language: Required[Literal["PYTHON"]]
|
28
28
|
version: Required[str]
|
29
|
-
imports: Required[
|
29
|
+
imports: Required[list[str]]
|
30
30
|
dependencies: Required[ModelRuntimeDependenciesDict]
|
31
|
+
resource_constraint: NotRequired[Optional[dict[str, str]]]
|
31
32
|
|
32
33
|
|
33
34
|
class ModelMethodSignatureField(TypedDict):
|
@@ -43,8 +44,8 @@ class ModelFunctionMethodDict(TypedDict):
|
|
43
44
|
runtime: Required[str]
|
44
45
|
type: Required[str]
|
45
46
|
handler: Required[str]
|
46
|
-
inputs: Required[
|
47
|
-
outputs: Required[Union[
|
47
|
+
inputs: Required[list[ModelMethodSignatureFieldWithName]]
|
48
|
+
outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
|
48
49
|
|
49
50
|
|
50
51
|
ModelMethodDict = ModelFunctionMethodDict
|
@@ -71,12 +72,12 @@ class ModelFunctionInfo(TypedDict):
|
|
71
72
|
class ModelFunctionInfoDict(TypedDict):
|
72
73
|
name: Required[str]
|
73
74
|
target_method: Required[str]
|
74
|
-
signature: Required[
|
75
|
+
signature: Required[dict[str, Any]]
|
75
76
|
|
76
77
|
|
77
78
|
class SnowparkMLDataDict(TypedDict):
|
78
79
|
schema_version: Required[str]
|
79
|
-
functions: Required[
|
80
|
+
functions: Required[list[ModelFunctionInfoDict]]
|
80
81
|
|
81
82
|
|
82
83
|
class LineageSourceTypes(enum.Enum):
|
@@ -92,9 +93,9 @@ class LineageSourceDict(TypedDict):
|
|
92
93
|
|
93
94
|
class ModelManifestDict(TypedDict):
|
94
95
|
manifest_version: Required[str]
|
95
|
-
runtimes: Required[
|
96
|
-
methods: Required[
|
97
|
-
user_data: NotRequired[
|
98
|
-
user_files: NotRequired[
|
99
|
-
lineage_sources: NotRequired[
|
100
|
-
target_platforms: NotRequired[
|
96
|
+
runtimes: Required[dict[str, ModelRuntimeDict]]
|
97
|
+
methods: Required[list[ModelMethodDict]]
|
98
|
+
user_data: NotRequired[dict[str, Any]]
|
99
|
+
user_files: NotRequired[list[str]]
|
100
|
+
lineage_sources: NotRequired[list[LineageSourceDict]]
|
101
|
+
target_platforms: NotRequired[list[str]]
|