snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,7 @@ import os
|
|
4
4
|
import pathlib
|
5
5
|
import tempfile
|
6
6
|
import warnings
|
7
|
-
from typing import Any,
|
7
|
+
from typing import Any, Literal, Optional, TypedDict, Union, cast, overload
|
8
8
|
|
9
9
|
import yaml
|
10
10
|
|
@@ -104,7 +104,7 @@ class ModelOperator:
|
|
104
104
|
*,
|
105
105
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
106
106
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
107
|
-
statement_params: Optional[
|
107
|
+
statement_params: Optional[dict[str, Any]] = None,
|
108
108
|
) -> str:
|
109
109
|
stage_name = sql_identifier.SqlIdentifier(
|
110
110
|
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
|
@@ -137,7 +137,7 @@ class ModelOperator:
|
|
137
137
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
138
138
|
model_name: sql_identifier.SqlIdentifier,
|
139
139
|
version_name: sql_identifier.SqlIdentifier,
|
140
|
-
statement_params: Optional[
|
140
|
+
statement_params: Optional[dict[str, Any]] = None,
|
141
141
|
) -> ModelAction:
|
142
142
|
if self.validate_existence(
|
143
143
|
database_name=database_name,
|
@@ -169,7 +169,7 @@ class ModelOperator:
|
|
169
169
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
170
170
|
model_name: sql_identifier.SqlIdentifier,
|
171
171
|
version_name: sql_identifier.SqlIdentifier,
|
172
|
-
statement_params: Optional[
|
172
|
+
statement_params: Optional[dict[str, Any]] = None,
|
173
173
|
) -> None:
|
174
174
|
model_action = self.get_model_action_from_model_name_and_version(
|
175
175
|
database_name=database_name,
|
@@ -205,7 +205,7 @@ class ModelOperator:
|
|
205
205
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
206
206
|
model_name: sql_identifier.SqlIdentifier,
|
207
207
|
version_name: sql_identifier.SqlIdentifier,
|
208
|
-
statement_params: Optional[
|
208
|
+
statement_params: Optional[dict[str, Any]] = None,
|
209
209
|
use_live_commit: Optional[bool] = False,
|
210
210
|
) -> None:
|
211
211
|
|
@@ -263,7 +263,7 @@ class ModelOperator:
|
|
263
263
|
model_name: sql_identifier.SqlIdentifier,
|
264
264
|
version_name: sql_identifier.SqlIdentifier,
|
265
265
|
model_exists: bool,
|
266
|
-
statement_params: Optional[
|
266
|
+
statement_params: Optional[dict[str, Any]] = None,
|
267
267
|
) -> None:
|
268
268
|
if model_exists:
|
269
269
|
return self._model_version_client.add_version_from_model_version(
|
@@ -296,8 +296,8 @@ class ModelOperator:
|
|
296
296
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
297
297
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
298
298
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
299
|
-
statement_params: Optional[
|
300
|
-
) ->
|
299
|
+
statement_params: Optional[dict[str, Any]] = None,
|
300
|
+
) -> list[row.Row]:
|
301
301
|
if model_name:
|
302
302
|
return self._model_client.show_versions(
|
303
303
|
database_name=database_name,
|
@@ -320,8 +320,8 @@ class ModelOperator:
|
|
320
320
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
321
321
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
322
322
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
323
|
-
statement_params: Optional[
|
324
|
-
) ->
|
323
|
+
statement_params: Optional[dict[str, Any]] = None,
|
324
|
+
) -> list[sql_identifier.SqlIdentifier]:
|
325
325
|
res = self.show_models_or_versions(
|
326
326
|
database_name=database_name,
|
327
327
|
schema_name=schema_name,
|
@@ -341,7 +341,7 @@ class ModelOperator:
|
|
341
341
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
342
342
|
model_name: sql_identifier.SqlIdentifier,
|
343
343
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
344
|
-
statement_params: Optional[
|
344
|
+
statement_params: Optional[dict[str, Any]] = None,
|
345
345
|
) -> bool:
|
346
346
|
if version_name:
|
347
347
|
res = self._model_client.show_versions(
|
@@ -369,7 +369,7 @@ class ModelOperator:
|
|
369
369
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
370
370
|
model_name: sql_identifier.SqlIdentifier,
|
371
371
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
372
|
-
statement_params: Optional[
|
372
|
+
statement_params: Optional[dict[str, Any]] = None,
|
373
373
|
) -> str:
|
374
374
|
if version_name:
|
375
375
|
res = self._model_client.show_versions(
|
@@ -398,7 +398,7 @@ class ModelOperator:
|
|
398
398
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
399
399
|
model_name: sql_identifier.SqlIdentifier,
|
400
400
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
401
|
-
statement_params: Optional[
|
401
|
+
statement_params: Optional[dict[str, Any]] = None,
|
402
402
|
) -> None:
|
403
403
|
if version_name:
|
404
404
|
self._model_version_client.set_comment(
|
@@ -426,7 +426,7 @@ class ModelOperator:
|
|
426
426
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
427
427
|
model_name: sql_identifier.SqlIdentifier,
|
428
428
|
version_name: sql_identifier.SqlIdentifier,
|
429
|
-
statement_params: Optional[
|
429
|
+
statement_params: Optional[dict[str, Any]] = None,
|
430
430
|
) -> None:
|
431
431
|
self._model_version_client.set_alias(
|
432
432
|
alias_name=alias_name,
|
@@ -444,7 +444,7 @@ class ModelOperator:
|
|
444
444
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
445
445
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
446
446
|
model_name: sql_identifier.SqlIdentifier,
|
447
|
-
statement_params: Optional[
|
447
|
+
statement_params: Optional[dict[str, Any]] = None,
|
448
448
|
) -> None:
|
449
449
|
self._model_version_client.unset_alias(
|
450
450
|
database_name=database_name,
|
@@ -461,7 +461,7 @@ class ModelOperator:
|
|
461
461
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
462
462
|
model_name: sql_identifier.SqlIdentifier,
|
463
463
|
version_name: sql_identifier.SqlIdentifier,
|
464
|
-
statement_params: Optional[
|
464
|
+
statement_params: Optional[dict[str, Any]] = None,
|
465
465
|
) -> None:
|
466
466
|
if not self.validate_existence(
|
467
467
|
database_name=database_name,
|
@@ -485,7 +485,7 @@ class ModelOperator:
|
|
485
485
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
486
486
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
487
487
|
model_name: sql_identifier.SqlIdentifier,
|
488
|
-
statement_params: Optional[
|
488
|
+
statement_params: Optional[dict[str, Any]] = None,
|
489
489
|
) -> sql_identifier.SqlIdentifier:
|
490
490
|
res = self._model_client.show_models(
|
491
491
|
database_name=database_name,
|
@@ -504,7 +504,7 @@ class ModelOperator:
|
|
504
504
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
505
505
|
model_name: sql_identifier.SqlIdentifier,
|
506
506
|
alias_name: sql_identifier.SqlIdentifier,
|
507
|
-
statement_params: Optional[
|
507
|
+
statement_params: Optional[dict[str, Any]] = None,
|
508
508
|
) -> Optional[sql_identifier.SqlIdentifier]:
|
509
509
|
res = self._model_client.show_versions(
|
510
510
|
database_name=database_name,
|
@@ -528,7 +528,7 @@ class ModelOperator:
|
|
528
528
|
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
529
529
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
530
530
|
tag_name: sql_identifier.SqlIdentifier,
|
531
|
-
statement_params: Optional[
|
531
|
+
statement_params: Optional[dict[str, Any]] = None,
|
532
532
|
) -> Optional[str]:
|
533
533
|
r = self._tag_client.get_tag_value(
|
534
534
|
database_name=database_name,
|
@@ -550,15 +550,15 @@ class ModelOperator:
|
|
550
550
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
551
551
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
552
552
|
model_name: sql_identifier.SqlIdentifier,
|
553
|
-
statement_params: Optional[
|
554
|
-
) ->
|
553
|
+
statement_params: Optional[dict[str, Any]] = None,
|
554
|
+
) -> dict[str, str]:
|
555
555
|
tags_info = self._tag_client.get_tag_list(
|
556
556
|
database_name=database_name,
|
557
557
|
schema_name=schema_name,
|
558
558
|
model_name=model_name,
|
559
559
|
statement_params=statement_params,
|
560
560
|
)
|
561
|
-
res:
|
561
|
+
res: dict[str, str] = {
|
562
562
|
identifier.get_schema_level_object_identifier(
|
563
563
|
sql_identifier.SqlIdentifier(r.TAG_DATABASE, case_sensitive=True),
|
564
564
|
sql_identifier.SqlIdentifier(r.TAG_SCHEMA, case_sensitive=True),
|
@@ -578,7 +578,7 @@ class ModelOperator:
|
|
578
578
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
579
579
|
tag_name: sql_identifier.SqlIdentifier,
|
580
580
|
tag_value: str,
|
581
|
-
statement_params: Optional[
|
581
|
+
statement_params: Optional[dict[str, Any]] = None,
|
582
582
|
) -> None:
|
583
583
|
self._tag_client.set_tag_on_model(
|
584
584
|
database_name=database_name,
|
@@ -600,7 +600,7 @@ class ModelOperator:
|
|
600
600
|
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
601
601
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
602
602
|
tag_name: sql_identifier.SqlIdentifier,
|
603
|
-
statement_params: Optional[
|
603
|
+
statement_params: Optional[dict[str, Any]] = None,
|
604
604
|
) -> None:
|
605
605
|
self._tag_client.unset_tag_on_model(
|
606
606
|
database_name=database_name,
|
@@ -619,8 +619,8 @@ class ModelOperator:
|
|
619
619
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
620
620
|
model_name: sql_identifier.SqlIdentifier,
|
621
621
|
version_name: sql_identifier.SqlIdentifier,
|
622
|
-
statement_params: Optional[
|
623
|
-
) ->
|
622
|
+
statement_params: Optional[dict[str, Any]] = None,
|
623
|
+
) -> list[ServiceInfo]:
|
624
624
|
res = self._model_client.show_versions(
|
625
625
|
database_name=database_name,
|
626
626
|
schema_name=schema_name,
|
@@ -682,7 +682,7 @@ class ModelOperator:
|
|
682
682
|
service_database_name: Optional[sql_identifier.SqlIdentifier],
|
683
683
|
service_schema_name: Optional[sql_identifier.SqlIdentifier],
|
684
684
|
service_name: sql_identifier.SqlIdentifier,
|
685
|
-
statement_params: Optional[
|
685
|
+
statement_params: Optional[dict[str, Any]] = None,
|
686
686
|
) -> None:
|
687
687
|
services = self.show_services(
|
688
688
|
database_name=database_name,
|
@@ -724,7 +724,7 @@ class ModelOperator:
|
|
724
724
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
725
725
|
model_name: sql_identifier.SqlIdentifier,
|
726
726
|
version_name: sql_identifier.SqlIdentifier,
|
727
|
-
statement_params: Optional[
|
727
|
+
statement_params: Optional[dict[str, Any]] = None,
|
728
728
|
) -> model_manifest_schema.ModelManifestDict:
|
729
729
|
with tempfile.TemporaryDirectory() as tmpdir:
|
730
730
|
self._model_version_client.get_file(
|
@@ -741,9 +741,9 @@ class ModelOperator:
|
|
741
741
|
|
742
742
|
@staticmethod
|
743
743
|
def _match_model_spec_with_sql_functions(
|
744
|
-
sql_functions_names:
|
745
|
-
) ->
|
746
|
-
res:
|
744
|
+
sql_functions_names: list[sql_identifier.SqlIdentifier], target_methods: list[str]
|
745
|
+
) -> dict[sql_identifier.SqlIdentifier, str]:
|
746
|
+
res: dict[sql_identifier.SqlIdentifier, str] = {}
|
747
747
|
|
748
748
|
for target_method in target_methods:
|
749
749
|
# Here we need to find the SQL function corresponding to the Python function.
|
@@ -766,7 +766,7 @@ class ModelOperator:
|
|
766
766
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
767
767
|
model_name: sql_identifier.SqlIdentifier,
|
768
768
|
version_name: sql_identifier.SqlIdentifier,
|
769
|
-
statement_params: Optional[
|
769
|
+
statement_params: Optional[dict[str, Any]] = None,
|
770
770
|
) -> model_meta_schema.ModelMetadataDict:
|
771
771
|
raw_model_spec_res = self._model_client.show_versions(
|
772
772
|
database_name=database_name,
|
@@ -787,7 +787,7 @@ class ModelOperator:
|
|
787
787
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
788
788
|
model_name: sql_identifier.SqlIdentifier,
|
789
789
|
version_name: sql_identifier.SqlIdentifier,
|
790
|
-
statement_params: Optional[
|
790
|
+
statement_params: Optional[dict[str, Any]] = None,
|
791
791
|
) -> type_hints.Task:
|
792
792
|
model_version = self._model_client.show_versions(
|
793
793
|
database_name=database_name,
|
@@ -809,8 +809,8 @@ class ModelOperator:
|
|
809
809
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
810
810
|
model_name: sql_identifier.SqlIdentifier,
|
811
811
|
version_name: sql_identifier.SqlIdentifier,
|
812
|
-
statement_params: Optional[
|
813
|
-
) ->
|
812
|
+
statement_params: Optional[dict[str, Any]] = None,
|
813
|
+
) -> list[model_manifest_schema.ModelFunctionInfo]:
|
814
814
|
model_spec = self._fetch_model_spec(
|
815
815
|
database_name=database_name,
|
816
816
|
schema_name=schema_name,
|
@@ -907,7 +907,7 @@ class ModelOperator:
|
|
907
907
|
version_name: sql_identifier.SqlIdentifier,
|
908
908
|
strict_input_validation: bool = False,
|
909
909
|
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
910
|
-
statement_params: Optional[
|
910
|
+
statement_params: Optional[dict[str, str]] = None,
|
911
911
|
is_partitioned: Optional[bool] = None,
|
912
912
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
913
913
|
...
|
@@ -923,7 +923,7 @@ class ModelOperator:
|
|
923
923
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
924
924
|
service_name: sql_identifier.SqlIdentifier,
|
925
925
|
strict_input_validation: bool = False,
|
926
|
-
statement_params: Optional[
|
926
|
+
statement_params: Optional[dict[str, str]] = None,
|
927
927
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
928
928
|
...
|
929
929
|
|
@@ -941,7 +941,7 @@ class ModelOperator:
|
|
941
941
|
service_name: Optional[sql_identifier.SqlIdentifier] = None,
|
942
942
|
strict_input_validation: bool = False,
|
943
943
|
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
944
|
-
statement_params: Optional[
|
944
|
+
statement_params: Optional[dict[str, str]] = None,
|
945
945
|
is_partitioned: Optional[bool] = None,
|
946
946
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
947
947
|
identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
|
@@ -1059,7 +1059,7 @@ class ModelOperator:
|
|
1059
1059
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
1060
1060
|
model_name: sql_identifier.SqlIdentifier,
|
1061
1061
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
1062
|
-
statement_params: Optional[
|
1062
|
+
statement_params: Optional[dict[str, Any]] = None,
|
1063
1063
|
) -> None:
|
1064
1064
|
if version_name:
|
1065
1065
|
self._model_version_client.drop_version(
|
@@ -1086,7 +1086,7 @@ class ModelOperator:
|
|
1086
1086
|
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
1087
1087
|
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
1088
1088
|
new_model_name: sql_identifier.SqlIdentifier,
|
1089
|
-
statement_params: Optional[
|
1089
|
+
statement_params: Optional[dict[str, Any]] = None,
|
1090
1090
|
) -> None:
|
1091
1091
|
self._model_client.rename(
|
1092
1092
|
database_name=database_name,
|
@@ -1121,7 +1121,7 @@ class ModelOperator:
|
|
1121
1121
|
version_name: sql_identifier.SqlIdentifier,
|
1122
1122
|
target_path: pathlib.Path,
|
1123
1123
|
mode: Literal["full", "model", "minimal"] = "model",
|
1124
|
-
statement_params: Optional[
|
1124
|
+
statement_params: Optional[dict[str, Any]] = None,
|
1125
1125
|
) -> None:
|
1126
1126
|
for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
|
1127
1127
|
list_file_res = self._model_version_client.list_file(
|
@@ -6,14 +6,16 @@ import re
|
|
6
6
|
import tempfile
|
7
7
|
import threading
|
8
8
|
import time
|
9
|
-
from typing import Any,
|
9
|
+
from typing import Any, Optional, Union, cast
|
10
10
|
|
11
11
|
from snowflake import snowpark
|
12
12
|
from snowflake.ml._internal import file_utils, platform_capabilities as pc
|
13
|
-
from snowflake.ml._internal.utils import service_logger, sql_identifier
|
13
|
+
from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
|
14
|
+
from snowflake.ml.model import model_signature, type_hints
|
14
15
|
from snowflake.ml.model._client.service import model_deployment_spec
|
15
16
|
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
16
|
-
from snowflake.
|
17
|
+
from snowflake.ml.model._signatures import snowpark_handler
|
18
|
+
from snowflake.snowpark import async_job, dataframe, exceptions, row, session
|
17
19
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
18
20
|
|
19
21
|
module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
|
@@ -104,9 +106,9 @@ class ServiceOperator:
|
|
104
106
|
num_workers: Optional[int],
|
105
107
|
max_batch_rows: Optional[int],
|
106
108
|
force_rebuild: bool,
|
107
|
-
build_external_access_integrations: Optional[
|
109
|
+
build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
|
108
110
|
block: bool,
|
109
|
-
statement_params: Optional[
|
111
|
+
statement_params: Optional[dict[str, Any]] = None,
|
110
112
|
) -> Union[str, async_job.AsyncJob]:
|
111
113
|
|
112
114
|
# Fall back to the registry's database and schema if not provided
|
@@ -120,32 +122,28 @@ class ServiceOperator:
|
|
120
122
|
image_repo_database_name = image_repo_database_name or database_name or self._database_name
|
121
123
|
image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
|
122
124
|
if self._workspace:
|
123
|
-
|
124
|
-
stage_name = sql_identifier.SqlIdentifier(
|
125
|
-
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
|
126
|
-
)
|
127
|
-
self._stage_client.create_tmp_stage(
|
128
|
-
database_name=database_name,
|
129
|
-
schema_name=schema_name,
|
130
|
-
stage_name=stage_name,
|
131
|
-
statement_params=statement_params,
|
132
|
-
)
|
133
|
-
stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
|
125
|
+
stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
|
134
126
|
else:
|
135
127
|
stage_path = None
|
136
|
-
|
128
|
+
self._model_deployment_spec.add_model_spec(
|
137
129
|
database_name=database_name,
|
138
130
|
schema_name=schema_name,
|
139
131
|
model_name=model_name,
|
140
132
|
version_name=version_name,
|
141
|
-
|
142
|
-
|
143
|
-
service_name=service_name,
|
133
|
+
)
|
134
|
+
self._model_deployment_spec.add_image_build_spec(
|
144
135
|
image_build_compute_pool_name=image_build_compute_pool_name,
|
145
|
-
service_compute_pool_name=service_compute_pool_name,
|
146
136
|
image_repo_database_name=image_repo_database_name,
|
147
137
|
image_repo_schema_name=image_repo_schema_name,
|
148
138
|
image_repo_name=image_repo_name,
|
139
|
+
force_rebuild=force_rebuild,
|
140
|
+
external_access_integrations=build_external_access_integrations,
|
141
|
+
)
|
142
|
+
self._model_deployment_spec.add_service_spec(
|
143
|
+
service_database_name=service_database_name,
|
144
|
+
service_schema_name=service_schema_name,
|
145
|
+
service_name=service_name,
|
146
|
+
inference_compute_pool_name=service_compute_pool_name,
|
149
147
|
ingress_enabled=ingress_enabled,
|
150
148
|
max_instances=max_instances,
|
151
149
|
cpu=cpu_requests,
|
@@ -153,9 +151,8 @@ class ServiceOperator:
|
|
153
151
|
gpu=gpu_requests,
|
154
152
|
num_workers=num_workers,
|
155
153
|
max_batch_rows=max_batch_rows,
|
156
|
-
force_rebuild=force_rebuild,
|
157
|
-
external_access_integrations=build_external_access_integrations,
|
158
154
|
)
|
155
|
+
spec_yaml_str_or_path = self._model_deployment_spec.save()
|
159
156
|
if self._workspace:
|
160
157
|
assert stage_path is not None
|
161
158
|
file_utils.upload_directory_to_stage(
|
@@ -210,7 +207,7 @@ class ServiceOperator:
|
|
210
207
|
if block:
|
211
208
|
log_thread.join()
|
212
209
|
|
213
|
-
res = cast(str, cast(
|
210
|
+
res = cast(str, cast(list[row.Row], async_job.result())[0][0])
|
214
211
|
module_logger.info(f"Inference service {service_name} deployment complete: {res}")
|
215
212
|
return res
|
216
213
|
else:
|
@@ -219,10 +216,10 @@ class ServiceOperator:
|
|
219
216
|
def _start_service_log_streaming(
|
220
217
|
self,
|
221
218
|
async_job: snowpark.AsyncJob,
|
222
|
-
services:
|
219
|
+
services: list[ServiceLogInfo],
|
223
220
|
model_inference_service_exists: bool,
|
224
221
|
force_rebuild: bool,
|
225
|
-
statement_params: Optional[
|
222
|
+
statement_params: Optional[dict[str, Any]] = None,
|
226
223
|
) -> threading.Thread:
|
227
224
|
"""Start the service log streaming in a separate thread."""
|
228
225
|
log_thread = threading.Thread(
|
@@ -241,14 +238,14 @@ class ServiceOperator:
|
|
241
238
|
def _stream_service_logs(
|
242
239
|
self,
|
243
240
|
async_job: snowpark.AsyncJob,
|
244
|
-
services:
|
241
|
+
services: list[ServiceLogInfo],
|
245
242
|
model_inference_service_exists: bool,
|
246
243
|
force_rebuild: bool,
|
247
|
-
statement_params: Optional[
|
244
|
+
statement_params: Optional[dict[str, Any]] = None,
|
248
245
|
) -> None:
|
249
246
|
"""Stream service logs while the async job is running."""
|
250
247
|
|
251
|
-
def fetch_logs(service: ServiceLogInfo, offset: int) ->
|
248
|
+
def fetch_logs(service: ServiceLogInfo, offset: int) -> tuple[str, int]:
|
252
249
|
service_logs = self._service_client.get_service_logs(
|
253
250
|
database_name=service.database_name,
|
254
251
|
schema_name=service.schema_name,
|
@@ -393,7 +390,7 @@ class ServiceOperator:
|
|
393
390
|
service_logger: logging.Logger,
|
394
391
|
service: ServiceLogInfo,
|
395
392
|
offset: int,
|
396
|
-
statement_params: Optional[
|
393
|
+
statement_params: Optional[dict[str, Any]] = None,
|
397
394
|
) -> None:
|
398
395
|
"""Fetch service logs after the async job is done to ensure no logs are missed."""
|
399
396
|
try:
|
@@ -425,8 +422,8 @@ class ServiceOperator:
|
|
425
422
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
426
423
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
427
424
|
service_name: sql_identifier.SqlIdentifier,
|
428
|
-
service_status_list_if_exists: Optional[
|
429
|
-
statement_params: Optional[
|
425
|
+
service_status_list_if_exists: Optional[list[service_sql.ServiceStatus]] = None,
|
426
|
+
statement_params: Optional[dict[str, Any]] = None,
|
430
427
|
) -> bool:
|
431
428
|
if service_status_list_if_exists is None:
|
432
429
|
service_status_list_if_exists = [
|
@@ -448,3 +445,191 @@ class ServiceOperator:
|
|
448
445
|
return any(service_status == status for status in service_status_list_if_exists)
|
449
446
|
except exceptions.SnowparkSQLException:
|
450
447
|
return False
|
448
|
+
|
449
|
+
def invoke_job_method(
|
450
|
+
self,
|
451
|
+
target_method: str,
|
452
|
+
signature: model_signature.ModelSignature,
|
453
|
+
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
454
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
455
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
456
|
+
model_name: sql_identifier.SqlIdentifier,
|
457
|
+
version_name: sql_identifier.SqlIdentifier,
|
458
|
+
job_database_name: Optional[sql_identifier.SqlIdentifier],
|
459
|
+
job_schema_name: Optional[sql_identifier.SqlIdentifier],
|
460
|
+
job_name: sql_identifier.SqlIdentifier,
|
461
|
+
compute_pool_name: sql_identifier.SqlIdentifier,
|
462
|
+
warehouse_name: sql_identifier.SqlIdentifier,
|
463
|
+
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
464
|
+
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
465
|
+
image_repo_name: sql_identifier.SqlIdentifier,
|
466
|
+
output_table_database_name: Optional[sql_identifier.SqlIdentifier],
|
467
|
+
output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
|
468
|
+
output_table_name: sql_identifier.SqlIdentifier,
|
469
|
+
cpu_requests: Optional[str],
|
470
|
+
memory_requests: Optional[str],
|
471
|
+
gpu_requests: Optional[Union[int, str]],
|
472
|
+
num_workers: Optional[int],
|
473
|
+
max_batch_rows: Optional[int],
|
474
|
+
force_rebuild: bool,
|
475
|
+
build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
|
476
|
+
statement_params: Optional[dict[str, Any]] = None,
|
477
|
+
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
478
|
+
# fall back to the registry's database and schema if not provided
|
479
|
+
database_name = database_name or self._database_name
|
480
|
+
schema_name = schema_name or self._schema_name
|
481
|
+
|
482
|
+
# fall back to the model's database and schema if not provided then to the registry's database and schema
|
483
|
+
job_database_name = job_database_name or database_name or self._database_name
|
484
|
+
job_schema_name = job_schema_name or schema_name or self._schema_name
|
485
|
+
|
486
|
+
image_repo_database_name = image_repo_database_name or database_name or self._database_name
|
487
|
+
image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
|
488
|
+
|
489
|
+
input_table_database_name = job_database_name
|
490
|
+
input_table_schema_name = job_schema_name
|
491
|
+
output_table_database_name = output_table_database_name or database_name or self._database_name
|
492
|
+
output_table_schema_name = output_table_schema_name or schema_name or self._schema_name
|
493
|
+
|
494
|
+
if self._workspace:
|
495
|
+
stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
|
496
|
+
else:
|
497
|
+
stage_path = None
|
498
|
+
|
499
|
+
# validate and prepare input
|
500
|
+
if not isinstance(X, dataframe.DataFrame):
|
501
|
+
keep_order = True
|
502
|
+
output_with_input_features = False
|
503
|
+
df = model_signature._convert_and_validate_local_data(X, signature.inputs)
|
504
|
+
s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
|
505
|
+
self._session, df, keep_order=keep_order, features=signature.inputs
|
506
|
+
)
|
507
|
+
else:
|
508
|
+
keep_order = False
|
509
|
+
output_with_input_features = True
|
510
|
+
s_df = X
|
511
|
+
|
512
|
+
# only write the index and feature input columns
|
513
|
+
cols = [snowpark_handler._KEEP_ORDER_COL_NAME] if snowpark_handler._KEEP_ORDER_COL_NAME in s_df.columns else []
|
514
|
+
cols += [
|
515
|
+
sql_identifier.SqlIdentifier(feature.name, case_sensitive=True).identifier() for feature in signature.inputs
|
516
|
+
]
|
517
|
+
s_df = s_df.select(cols)
|
518
|
+
original_cols = s_df.columns
|
519
|
+
|
520
|
+
# input/output tables
|
521
|
+
fq_output_table_name = identifier.get_schema_level_object_identifier(
|
522
|
+
output_table_database_name.identifier(),
|
523
|
+
output_table_schema_name.identifier(),
|
524
|
+
output_table_name.identifier(),
|
525
|
+
)
|
526
|
+
tmp_input_table_id = sql_identifier.SqlIdentifier(
|
527
|
+
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
528
|
+
)
|
529
|
+
fq_tmp_input_table_name = identifier.get_schema_level_object_identifier(
|
530
|
+
job_database_name.identifier(),
|
531
|
+
job_schema_name.identifier(),
|
532
|
+
tmp_input_table_id.identifier(),
|
533
|
+
)
|
534
|
+
s_df.write.save_as_table(
|
535
|
+
table_name=fq_tmp_input_table_name,
|
536
|
+
mode="errorifexists",
|
537
|
+
statement_params=statement_params,
|
538
|
+
)
|
539
|
+
|
540
|
+
try:
|
541
|
+
# save the spec
|
542
|
+
self._model_deployment_spec.add_model_spec(
|
543
|
+
database_name=database_name,
|
544
|
+
schema_name=schema_name,
|
545
|
+
model_name=model_name,
|
546
|
+
version_name=version_name,
|
547
|
+
)
|
548
|
+
self._model_deployment_spec.add_job_spec(
|
549
|
+
job_database_name=job_database_name,
|
550
|
+
job_schema_name=job_schema_name,
|
551
|
+
job_name=job_name,
|
552
|
+
inference_compute_pool_name=compute_pool_name,
|
553
|
+
cpu=cpu_requests,
|
554
|
+
memory=memory_requests,
|
555
|
+
gpu=gpu_requests,
|
556
|
+
num_workers=num_workers,
|
557
|
+
max_batch_rows=max_batch_rows,
|
558
|
+
warehouse=warehouse_name,
|
559
|
+
target_method=target_method,
|
560
|
+
input_table_database_name=input_table_database_name,
|
561
|
+
input_table_schema_name=input_table_schema_name,
|
562
|
+
input_table_name=tmp_input_table_id,
|
563
|
+
output_table_database_name=output_table_database_name,
|
564
|
+
output_table_schema_name=output_table_schema_name,
|
565
|
+
output_table_name=output_table_name,
|
566
|
+
)
|
567
|
+
|
568
|
+
self._model_deployment_spec.add_image_build_spec(
|
569
|
+
image_build_compute_pool_name=compute_pool_name,
|
570
|
+
image_repo_database_name=image_repo_database_name,
|
571
|
+
image_repo_schema_name=image_repo_schema_name,
|
572
|
+
image_repo_name=image_repo_name,
|
573
|
+
force_rebuild=force_rebuild,
|
574
|
+
external_access_integrations=build_external_access_integrations,
|
575
|
+
)
|
576
|
+
|
577
|
+
spec_yaml_str_or_path = self._model_deployment_spec.save()
|
578
|
+
if self._workspace:
|
579
|
+
assert stage_path is not None
|
580
|
+
file_utils.upload_directory_to_stage(
|
581
|
+
self._session,
|
582
|
+
local_path=pathlib.Path(self._workspace.name),
|
583
|
+
stage_path=pathlib.PurePosixPath(stage_path),
|
584
|
+
statement_params=statement_params,
|
585
|
+
)
|
586
|
+
|
587
|
+
# deploy the job
|
588
|
+
query_id, async_job = self._service_client.deploy_model(
|
589
|
+
stage_path=stage_path if self._workspace else None,
|
590
|
+
model_deployment_spec_file_rel_path=(
|
591
|
+
model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
|
592
|
+
),
|
593
|
+
model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
|
594
|
+
statement_params=statement_params,
|
595
|
+
)
|
596
|
+
|
597
|
+
while not async_job.is_done():
|
598
|
+
time.sleep(5)
|
599
|
+
finally:
|
600
|
+
self._session.table(fq_tmp_input_table_name).drop_table()
|
601
|
+
|
602
|
+
# handle the output
|
603
|
+
df_res = self._session.table(fq_output_table_name)
|
604
|
+
if keep_order:
|
605
|
+
df_res = df_res.sort(
|
606
|
+
snowpark_handler._KEEP_ORDER_COL_NAME,
|
607
|
+
ascending=True,
|
608
|
+
)
|
609
|
+
df_res = df_res.drop(snowpark_handler._KEEP_ORDER_COL_NAME)
|
610
|
+
|
611
|
+
if not output_with_input_features:
|
612
|
+
df_res = df_res.drop(*original_cols)
|
613
|
+
|
614
|
+
# get final result
|
615
|
+
if not isinstance(X, dataframe.DataFrame):
|
616
|
+
return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs)
|
617
|
+
else:
|
618
|
+
return df_res
|
619
|
+
|
620
|
+
def _create_temp_stage(
|
621
|
+
self,
|
622
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
623
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
624
|
+
statement_params: Optional[dict[str, Any]] = None,
|
625
|
+
) -> str:
|
626
|
+
stage_name = sql_identifier.SqlIdentifier(
|
627
|
+
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
|
628
|
+
)
|
629
|
+
self._stage_client.create_tmp_stage(
|
630
|
+
database_name=database_name,
|
631
|
+
schema_name=schema_name,
|
632
|
+
stage_name=stage_name,
|
633
|
+
statement_params=statement_params,
|
634
|
+
)
|
635
|
+
return self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name) # stage path
|