snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
CHANGED
@@ -1,16 +1,7 @@
|
|
1
|
+
import logging
|
1
2
|
import pathlib
|
2
3
|
import textwrap
|
3
|
-
from typing import
|
4
|
-
Any,
|
5
|
-
Callable,
|
6
|
-
Dict,
|
7
|
-
List,
|
8
|
-
Literal,
|
9
|
-
Optional,
|
10
|
-
TypeVar,
|
11
|
-
Union,
|
12
|
-
overload,
|
13
|
-
)
|
4
|
+
from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast, overload
|
14
5
|
from uuid import uuid4
|
15
6
|
|
16
7
|
import yaml
|
@@ -23,13 +14,14 @@ from snowflake.ml.jobs._utils import payload_utils, spec_utils
|
|
23
14
|
from snowflake.snowpark.context import get_active_session
|
24
15
|
from snowflake.snowpark.exceptions import SnowparkSQLException
|
25
16
|
|
17
|
+
logger = logging.getLogger(__name__)
|
18
|
+
|
26
19
|
_PROJECT = "MLJob"
|
27
20
|
JOB_ID_PREFIX = "MLJOB_"
|
28
21
|
|
29
22
|
T = TypeVar("T")
|
30
23
|
|
31
24
|
|
32
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
33
25
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
|
34
26
|
def list_jobs(
|
35
27
|
limit: int = 10,
|
@@ -60,7 +52,7 @@ def list_jobs(
|
|
60
52
|
query += f" LIMIT {limit}"
|
61
53
|
df = session.sql(query)
|
62
54
|
df = df.select(
|
63
|
-
df['"name"']
|
55
|
+
df['"name"'],
|
64
56
|
df['"owner"'],
|
65
57
|
df['"status"'],
|
66
58
|
df['"created_on"'],
|
@@ -69,21 +61,20 @@ def list_jobs(
|
|
69
61
|
return df
|
70
62
|
|
71
63
|
|
72
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
73
64
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
74
65
|
def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
|
75
66
|
"""Retrieve a job service from the backend."""
|
76
67
|
session = session or get_active_session()
|
77
|
-
|
78
68
|
try:
|
79
|
-
|
80
|
-
|
69
|
+
database, schema, job_name = identifier.parse_schema_level_object_identifier(job_id)
|
70
|
+
database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
|
71
|
+
schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
|
81
72
|
except ValueError as e:
|
82
73
|
raise ValueError(f"Invalid job ID: {job_id}") from e
|
83
74
|
|
75
|
+
job_id = f"{database}.{schema}.{job_name}"
|
84
76
|
try:
|
85
77
|
# Validate that job exists by doing a status check
|
86
|
-
# FIXME: Retrieve return path
|
87
78
|
job = jb.MLJob[Any](job_id, session=session)
|
88
79
|
_ = job.status
|
89
80
|
return job
|
@@ -93,7 +84,6 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
|
|
93
84
|
raise
|
94
85
|
|
95
86
|
|
96
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
97
87
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
98
88
|
def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
|
99
89
|
"""Delete a job service from the backend. Status and logs will be lost."""
|
@@ -106,21 +96,22 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
|
|
106
96
|
session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
|
107
97
|
|
108
98
|
|
109
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
110
99
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
111
100
|
def submit_file(
|
112
101
|
file_path: str,
|
113
102
|
compute_pool: str,
|
114
103
|
*,
|
115
104
|
stage_name: str,
|
116
|
-
args: Optional[
|
117
|
-
env_vars: Optional[
|
118
|
-
pip_requirements: Optional[
|
119
|
-
external_access_integrations: Optional[
|
105
|
+
args: Optional[list[str]] = None,
|
106
|
+
env_vars: Optional[dict[str, str]] = None,
|
107
|
+
pip_requirements: Optional[list[str]] = None,
|
108
|
+
external_access_integrations: Optional[list[str]] = None,
|
120
109
|
query_warehouse: Optional[str] = None,
|
121
|
-
spec_overrides: Optional[
|
110
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
122
111
|
num_instances: Optional[int] = None,
|
123
112
|
enable_metrics: bool = False,
|
113
|
+
database: Optional[str] = None,
|
114
|
+
schema: Optional[str] = None,
|
124
115
|
session: Optional[snowpark.Session] = None,
|
125
116
|
) -> jb.MLJob[None]:
|
126
117
|
"""
|
@@ -138,6 +129,8 @@ def submit_file(
|
|
138
129
|
spec_overrides: Custom service specification overrides to apply.
|
139
130
|
num_instances: The number of instances to use for the job. If none specified, single node job is created.
|
140
131
|
enable_metrics: Whether to enable metrics publishing for the job.
|
132
|
+
database: The database to use.
|
133
|
+
schema: The schema to use.
|
141
134
|
session: The Snowpark session to use. If none specified, uses active session.
|
142
135
|
|
143
136
|
Returns:
|
@@ -155,11 +148,12 @@ def submit_file(
|
|
155
148
|
spec_overrides=spec_overrides,
|
156
149
|
num_instances=num_instances,
|
157
150
|
enable_metrics=enable_metrics,
|
151
|
+
database=database,
|
152
|
+
schema=schema,
|
158
153
|
session=session,
|
159
154
|
)
|
160
155
|
|
161
156
|
|
162
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
163
157
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
164
158
|
def submit_directory(
|
165
159
|
dir_path: str,
|
@@ -167,14 +161,16 @@ def submit_directory(
|
|
167
161
|
*,
|
168
162
|
entrypoint: str,
|
169
163
|
stage_name: str,
|
170
|
-
args: Optional[
|
171
|
-
env_vars: Optional[
|
172
|
-
pip_requirements: Optional[
|
173
|
-
external_access_integrations: Optional[
|
164
|
+
args: Optional[list[str]] = None,
|
165
|
+
env_vars: Optional[dict[str, str]] = None,
|
166
|
+
pip_requirements: Optional[list[str]] = None,
|
167
|
+
external_access_integrations: Optional[list[str]] = None,
|
174
168
|
query_warehouse: Optional[str] = None,
|
175
|
-
spec_overrides: Optional[
|
169
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
176
170
|
num_instances: Optional[int] = None,
|
177
171
|
enable_metrics: bool = False,
|
172
|
+
database: Optional[str] = None,
|
173
|
+
schema: Optional[str] = None,
|
178
174
|
session: Optional[snowpark.Session] = None,
|
179
175
|
) -> jb.MLJob[None]:
|
180
176
|
"""
|
@@ -193,6 +189,8 @@ def submit_directory(
|
|
193
189
|
spec_overrides: Custom service specification overrides to apply.
|
194
190
|
num_instances: The number of instances to use for the job. If none specified, single node job is created.
|
195
191
|
enable_metrics: Whether to enable metrics publishing for the job.
|
192
|
+
database: The database to use.
|
193
|
+
schema: The schema to use.
|
196
194
|
session: The Snowpark session to use. If none specified, uses active session.
|
197
195
|
|
198
196
|
Returns:
|
@@ -211,6 +209,8 @@ def submit_directory(
|
|
211
209
|
spec_overrides=spec_overrides,
|
212
210
|
num_instances=num_instances,
|
213
211
|
enable_metrics=enable_metrics,
|
212
|
+
database=database,
|
213
|
+
schema=schema,
|
214
214
|
session=session,
|
215
215
|
)
|
216
216
|
|
@@ -222,14 +222,16 @@ def _submit_job(
|
|
222
222
|
*,
|
223
223
|
stage_name: str,
|
224
224
|
entrypoint: Optional[str] = None,
|
225
|
-
args: Optional[
|
226
|
-
env_vars: Optional[
|
227
|
-
pip_requirements: Optional[
|
228
|
-
external_access_integrations: Optional[
|
225
|
+
args: Optional[list[str]] = None,
|
226
|
+
env_vars: Optional[dict[str, str]] = None,
|
227
|
+
pip_requirements: Optional[list[str]] = None,
|
228
|
+
external_access_integrations: Optional[list[str]] = None,
|
229
229
|
query_warehouse: Optional[str] = None,
|
230
|
-
spec_overrides: Optional[
|
230
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
231
231
|
num_instances: Optional[int] = None,
|
232
232
|
enable_metrics: bool = False,
|
233
|
+
database: Optional[str] = None,
|
234
|
+
schema: Optional[str] = None,
|
233
235
|
session: Optional[snowpark.Session] = None,
|
234
236
|
) -> jb.MLJob[None]:
|
235
237
|
...
|
@@ -242,14 +244,16 @@ def _submit_job(
|
|
242
244
|
*,
|
243
245
|
stage_name: str,
|
244
246
|
entrypoint: Optional[str] = None,
|
245
|
-
args: Optional[
|
246
|
-
env_vars: Optional[
|
247
|
-
pip_requirements: Optional[
|
248
|
-
external_access_integrations: Optional[
|
247
|
+
args: Optional[list[str]] = None,
|
248
|
+
env_vars: Optional[dict[str, str]] = None,
|
249
|
+
pip_requirements: Optional[list[str]] = None,
|
250
|
+
external_access_integrations: Optional[list[str]] = None,
|
249
251
|
query_warehouse: Optional[str] = None,
|
250
|
-
spec_overrides: Optional[
|
252
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
251
253
|
num_instances: Optional[int] = None,
|
252
254
|
enable_metrics: bool = False,
|
255
|
+
database: Optional[str] = None,
|
256
|
+
schema: Optional[str] = None,
|
253
257
|
session: Optional[snowpark.Session] = None,
|
254
258
|
) -> jb.MLJob[T]:
|
255
259
|
...
|
@@ -263,6 +267,8 @@ def _submit_job(
|
|
263
267
|
# TODO: Log lengths of args, env_vars, and spec_overrides values
|
264
268
|
"pip_requirements",
|
265
269
|
"external_access_integrations",
|
270
|
+
"num_instances",
|
271
|
+
"enable_metrics",
|
266
272
|
],
|
267
273
|
)
|
268
274
|
def _submit_job(
|
@@ -271,14 +277,16 @@ def _submit_job(
|
|
271
277
|
*,
|
272
278
|
stage_name: str,
|
273
279
|
entrypoint: Optional[str] = None,
|
274
|
-
args: Optional[
|
275
|
-
env_vars: Optional[
|
276
|
-
pip_requirements: Optional[
|
277
|
-
external_access_integrations: Optional[
|
280
|
+
args: Optional[list[str]] = None,
|
281
|
+
env_vars: Optional[dict[str, str]] = None,
|
282
|
+
pip_requirements: Optional[list[str]] = None,
|
283
|
+
external_access_integrations: Optional[list[str]] = None,
|
278
284
|
query_warehouse: Optional[str] = None,
|
279
|
-
spec_overrides: Optional[
|
285
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
280
286
|
num_instances: Optional[int] = None,
|
281
287
|
enable_metrics: bool = False,
|
288
|
+
database: Optional[str] = None,
|
289
|
+
schema: Optional[str] = None,
|
282
290
|
session: Optional[snowpark.Session] = None,
|
283
291
|
) -> jb.MLJob[T]:
|
284
292
|
"""
|
@@ -297,6 +305,8 @@ def _submit_job(
|
|
297
305
|
spec_overrides: Custom service specification overrides to apply.
|
298
306
|
num_instances: The number of instances to use for the job. If none specified, single node job is created.
|
299
307
|
enable_metrics: Whether to enable metrics publishing for the job.
|
308
|
+
database: The database to use.
|
309
|
+
schema: The schema to use.
|
300
310
|
session: The Snowpark session to use. If none specified, uses active session.
|
301
311
|
|
302
312
|
Returns:
|
@@ -304,11 +314,28 @@ def _submit_job(
|
|
304
314
|
|
305
315
|
Raises:
|
306
316
|
RuntimeError: If required Snowflake features are not enabled.
|
317
|
+
ValueError: If database or schema value(s) are invalid
|
307
318
|
"""
|
319
|
+
# Display warning about PrPr parameters
|
320
|
+
if num_instances is not None:
|
321
|
+
logger.warning(
|
322
|
+
"_submit_job() parameter 'num_instances' is in private preview since 1.8.2. Do not use it in production.",
|
323
|
+
)
|
324
|
+
if database and not schema:
|
325
|
+
raise ValueError("Schema must be specified if database is specified.")
|
326
|
+
|
308
327
|
session = session or get_active_session()
|
309
|
-
|
310
|
-
|
311
|
-
|
328
|
+
|
329
|
+
# Validate database and schema identifiers on client side since
|
330
|
+
# SQL parser for EXECUTE JOB SERVICE seems to struggle with this
|
331
|
+
database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
|
332
|
+
schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
|
333
|
+
|
334
|
+
job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
|
335
|
+
job_id = f"{database}.{schema}.{job_name}"
|
336
|
+
stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
|
337
|
+
stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
|
338
|
+
stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
|
312
339
|
|
313
340
|
# Upload payload
|
314
341
|
uploaded_payload = payload_utils.JobPayload(
|
@@ -335,31 +362,34 @@ def _submit_job(
|
|
335
362
|
|
336
363
|
# Generate SQL command for job submission
|
337
364
|
query_template = textwrap.dedent(
|
338
|
-
|
365
|
+
"""\
|
339
366
|
EXECUTE JOB SERVICE
|
340
|
-
IN COMPUTE POOL
|
367
|
+
IN COMPUTE POOL IDENTIFIER(?)
|
341
368
|
FROM SPECIFICATION $$
|
342
|
-
{
|
369
|
+
{}
|
343
370
|
$$
|
344
|
-
NAME =
|
371
|
+
NAME = IDENTIFIER(?)
|
345
372
|
ASYNC = TRUE
|
346
373
|
"""
|
347
374
|
)
|
375
|
+
params: list[Any] = [compute_pool, job_id]
|
348
376
|
query = query_template.format(yaml.dump(spec)).splitlines()
|
349
377
|
if external_access_integrations:
|
350
378
|
external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
|
351
379
|
query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
|
352
380
|
query_warehouse = query_warehouse or session.get_current_warehouse()
|
353
381
|
if query_warehouse:
|
354
|
-
query.append(
|
382
|
+
query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
|
383
|
+
params.append(query_warehouse)
|
355
384
|
if num_instances:
|
356
|
-
query.append(
|
385
|
+
query.append("REPLICAS = ?")
|
386
|
+
params.append(num_instances)
|
357
387
|
|
358
388
|
# Submit job
|
359
389
|
query_text = "\n".join(line for line in query if line)
|
360
390
|
|
361
391
|
try:
|
362
|
-
_ = session.sql(query_text).collect()
|
392
|
+
_ = session.sql(query_text, params=params).collect()
|
363
393
|
except SnowparkSQLException as e:
|
364
394
|
if "invalid property 'ASYNC'" in e.message:
|
365
395
|
raise RuntimeError(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import json
|
2
2
|
from datetime import datetime
|
3
|
-
from typing import TYPE_CHECKING,
|
3
|
+
from typing import TYPE_CHECKING, Literal, Optional, Union
|
4
4
|
|
5
5
|
from snowflake import snowpark
|
6
6
|
from snowflake.ml._internal import telemetry
|
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|
12
12
|
from snowflake.ml.model._client.model import model_version_impl
|
13
13
|
|
14
14
|
_PROJECT = "LINEAGE"
|
15
|
-
DOMAIN_LINEAGE_REGISTRY:
|
15
|
+
DOMAIN_LINEAGE_REGISTRY: dict[str, type["LineageNode"]] = {}
|
16
16
|
|
17
17
|
|
18
18
|
class LineageNode:
|
@@ -87,8 +87,8 @@ class LineageNode:
|
|
87
87
|
def lineage(
|
88
88
|
self,
|
89
89
|
direction: Literal["upstream", "downstream"] = "downstream",
|
90
|
-
domain_filter: Optional[
|
91
|
-
) ->
|
90
|
+
domain_filter: Optional[set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
|
91
|
+
) -> list[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
|
92
92
|
"""
|
93
93
|
Retrieves the lineage nodes connected to this node.
|
94
94
|
|
@@ -109,7 +109,7 @@ class LineageNode:
|
|
109
109
|
if domain_filter is not None:
|
110
110
|
domain_filter = {d.lower() for d in domain_filter} # type: ignore[misc]
|
111
111
|
|
112
|
-
lineage_nodes:
|
112
|
+
lineage_nodes: list["LineageNode"] = []
|
113
113
|
for row in df.collect():
|
114
114
|
lineage_object = (
|
115
115
|
json.loads(row["TARGET_OBJECT"])
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Optional, Union
|
2
2
|
|
3
3
|
import pandas as pd
|
4
4
|
|
@@ -224,7 +224,7 @@ class Model:
|
|
224
224
|
project=_TELEMETRY_PROJECT,
|
225
225
|
subproject=_TELEMETRY_SUBPROJECT,
|
226
226
|
)
|
227
|
-
def versions(self) ->
|
227
|
+
def versions(self) -> list[model_version_impl.ModelVersion]:
|
228
228
|
"""Get all versions in the model.
|
229
229
|
|
230
230
|
Returns:
|
@@ -298,7 +298,7 @@ class Model:
|
|
298
298
|
project=_TELEMETRY_PROJECT,
|
299
299
|
subproject=_TELEMETRY_SUBPROJECT,
|
300
300
|
)
|
301
|
-
def show_tags(self) ->
|
301
|
+
def show_tags(self) -> dict[str, str]:
|
302
302
|
"""Get a dictionary showing the tag and its value attached to the model.
|
303
303
|
|
304
304
|
Returns:
|
@@ -2,10 +2,11 @@ import enum
|
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
4
|
import warnings
|
5
|
-
from typing import Any, Callable,
|
5
|
+
from typing import Any, Callable, Optional, Union, overload
|
6
6
|
|
7
7
|
import pandas as pd
|
8
8
|
|
9
|
+
from snowflake import snowpark
|
9
10
|
from snowflake.ml._internal import telemetry
|
10
11
|
from snowflake.ml._internal.utils import sql_identifier
|
11
12
|
from snowflake.ml.lineage import lineage_node
|
@@ -32,7 +33,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
32
33
|
_service_ops: service_ops.ServiceOperator
|
33
34
|
_model_name: sql_identifier.SqlIdentifier
|
34
35
|
_version_name: sql_identifier.SqlIdentifier
|
35
|
-
_functions:
|
36
|
+
_functions: list[model_manifest_schema.ModelFunctionInfo]
|
36
37
|
|
37
38
|
def __init__(self) -> None:
|
38
39
|
raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
|
@@ -152,7 +153,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
152
153
|
project=_TELEMETRY_PROJECT,
|
153
154
|
subproject=_TELEMETRY_SUBPROJECT,
|
154
155
|
)
|
155
|
-
def show_metrics(self) ->
|
156
|
+
def show_metrics(self) -> dict[str, Any]:
|
156
157
|
"""Show all metrics logged with the model version.
|
157
158
|
|
158
159
|
Returns:
|
@@ -293,7 +294,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
293
294
|
statement_params=statement_params,
|
294
295
|
)
|
295
296
|
|
296
|
-
def _get_functions(self) ->
|
297
|
+
def _get_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
|
297
298
|
statement_params = telemetry.get_statement_params(
|
298
299
|
project=_TELEMETRY_PROJECT,
|
299
300
|
subproject=_TELEMETRY_SUBPROJECT,
|
@@ -327,7 +328,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
327
328
|
project=_TELEMETRY_PROJECT,
|
328
329
|
subproject=_TELEMETRY_SUBPROJECT,
|
329
330
|
)
|
330
|
-
def show_functions(self) ->
|
331
|
+
def show_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
|
331
332
|
"""Show all functions information in a model version that is callable.
|
332
333
|
|
333
334
|
Returns:
|
@@ -405,11 +406,6 @@ class ModelVersion(lineage_node.LineageNode):
|
|
405
406
|
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
406
407
|
type validation to make sure your input data won't overflow when providing to the model.
|
407
408
|
|
408
|
-
Raises:
|
409
|
-
ValueError: When no method with the corresponding name is available.
|
410
|
-
ValueError: When there are more than 1 target methods available in the model but no function name specified.
|
411
|
-
ValueError: When the partition column is not a valid Snowflake identifier.
|
412
|
-
|
413
409
|
Returns:
|
414
410
|
The prediction data. It would be the same type dataframe as your input.
|
415
411
|
"""
|
@@ -422,29 +418,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
422
418
|
# Partition column must be a valid identifier
|
423
419
|
partition_column = sql_identifier.SqlIdentifier(partition_column)
|
424
420
|
|
425
|
-
|
426
|
-
|
427
|
-
if function_name:
|
428
|
-
req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
|
429
|
-
find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
|
430
|
-
lambda method: method["name"] == req_method_name
|
431
|
-
)
|
432
|
-
target_function_info = next(
|
433
|
-
filter(find_method, functions),
|
434
|
-
None,
|
435
|
-
)
|
436
|
-
if target_function_info is None:
|
437
|
-
raise ValueError(
|
438
|
-
f"There is no method with name {function_name} available in the model"
|
439
|
-
f" {self.fully_qualified_model_name} version {self.version_name}"
|
440
|
-
)
|
441
|
-
elif len(functions) != 1:
|
442
|
-
raise ValueError(
|
443
|
-
f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
|
444
|
-
f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
|
445
|
-
)
|
446
|
-
else:
|
447
|
-
target_function_info = functions[0]
|
421
|
+
target_function_info = self._get_function_info(function_name=function_name)
|
448
422
|
|
449
423
|
if service_name:
|
450
424
|
database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
|
@@ -475,6 +449,33 @@ class ModelVersion(lineage_node.LineageNode):
|
|
475
449
|
is_partitioned=target_function_info["is_partitioned"],
|
476
450
|
)
|
477
451
|
|
452
|
+
def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
|
453
|
+
functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions
|
454
|
+
|
455
|
+
if function_name:
|
456
|
+
req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
|
457
|
+
find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
|
458
|
+
lambda method: method["name"] == req_method_name
|
459
|
+
)
|
460
|
+
target_function_info = next(
|
461
|
+
filter(find_method, functions),
|
462
|
+
None,
|
463
|
+
)
|
464
|
+
if target_function_info is None:
|
465
|
+
raise ValueError(
|
466
|
+
f"There is no method with name {function_name} available in the model"
|
467
|
+
f" {self.fully_qualified_model_name} version {self.version_name}"
|
468
|
+
)
|
469
|
+
elif len(functions) != 1:
|
470
|
+
raise ValueError(
|
471
|
+
f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
|
472
|
+
f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
|
473
|
+
)
|
474
|
+
else:
|
475
|
+
target_function_info = functions[0]
|
476
|
+
|
477
|
+
return target_function_info
|
478
|
+
|
478
479
|
@telemetry.send_api_usage_telemetry(
|
479
480
|
project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
|
480
481
|
)
|
@@ -684,7 +685,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
684
685
|
num_workers: Optional[int] = None,
|
685
686
|
max_batch_rows: Optional[int] = None,
|
686
687
|
force_rebuild: bool = False,
|
687
|
-
build_external_access_integrations: Optional[
|
688
|
+
build_external_access_integrations: Optional[list[str]] = None,
|
688
689
|
block: bool = True,
|
689
690
|
) -> Union[str, async_job.AsyncJob]:
|
690
691
|
"""Create an inference service with the given spec.
|
@@ -751,7 +752,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
751
752
|
max_batch_rows: Optional[int] = None,
|
752
753
|
force_rebuild: bool = False,
|
753
754
|
build_external_access_integration: Optional[str] = None,
|
754
|
-
build_external_access_integrations: Optional[
|
755
|
+
build_external_access_integrations: Optional[list[str]] = None,
|
755
756
|
block: bool = True,
|
756
757
|
) -> Union[str, async_job.AsyncJob]:
|
757
758
|
"""Create an inference service with the given spec.
|
@@ -914,5 +915,72 @@ class ModelVersion(lineage_node.LineageNode):
|
|
914
915
|
statement_params=statement_params,
|
915
916
|
)
|
916
917
|
|
918
|
+
@snowpark._internal.utils.private_preview(version="1.8.3")
|
919
|
+
@telemetry.send_api_usage_telemetry(
|
920
|
+
project=_TELEMETRY_PROJECT,
|
921
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
922
|
+
)
|
923
|
+
def _run_job(
|
924
|
+
self,
|
925
|
+
X: Union[pd.DataFrame, "dataframe.DataFrame"],
|
926
|
+
*,
|
927
|
+
job_name: str,
|
928
|
+
compute_pool: str,
|
929
|
+
image_repo: str,
|
930
|
+
output_table_name: str,
|
931
|
+
function_name: Optional[str] = None,
|
932
|
+
cpu_requests: Optional[str] = None,
|
933
|
+
memory_requests: Optional[str] = None,
|
934
|
+
gpu_requests: Optional[Union[str, int]] = None,
|
935
|
+
num_workers: Optional[int] = None,
|
936
|
+
max_batch_rows: Optional[int] = None,
|
937
|
+
force_rebuild: bool = False,
|
938
|
+
build_external_access_integrations: Optional[list[str]] = None,
|
939
|
+
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
940
|
+
statement_params = telemetry.get_statement_params(
|
941
|
+
project=_TELEMETRY_PROJECT,
|
942
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
943
|
+
)
|
944
|
+
target_function_info = self._get_function_info(function_name=function_name)
|
945
|
+
job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
|
946
|
+
image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
|
947
|
+
output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
|
948
|
+
output_table_name
|
949
|
+
)
|
950
|
+
warehouse = self._service_ops._session.get_current_warehouse()
|
951
|
+
assert warehouse, "No active warehouse selected in the current session."
|
952
|
+
return self._service_ops.invoke_job_method(
|
953
|
+
target_method=target_function_info["target_method"],
|
954
|
+
signature=target_function_info["signature"],
|
955
|
+
X=X,
|
956
|
+
database_name=None,
|
957
|
+
schema_name=None,
|
958
|
+
model_name=self._model_name,
|
959
|
+
version_name=self._version_name,
|
960
|
+
job_database_name=job_db_id,
|
961
|
+
job_schema_name=job_schema_id,
|
962
|
+
job_name=job_id,
|
963
|
+
compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
|
964
|
+
warehouse_name=sql_identifier.SqlIdentifier(warehouse),
|
965
|
+
image_repo_database_name=image_repo_db_id,
|
966
|
+
image_repo_schema_name=image_repo_schema_id,
|
967
|
+
image_repo_name=image_repo_id,
|
968
|
+
output_table_database_name=output_table_db_id,
|
969
|
+
output_table_schema_name=output_table_schema_id,
|
970
|
+
output_table_name=output_table_id,
|
971
|
+
cpu_requests=cpu_requests,
|
972
|
+
memory_requests=memory_requests,
|
973
|
+
gpu_requests=gpu_requests,
|
974
|
+
num_workers=num_workers,
|
975
|
+
max_batch_rows=max_batch_rows,
|
976
|
+
force_rebuild=force_rebuild,
|
977
|
+
build_external_access_integrations=(
|
978
|
+
None
|
979
|
+
if build_external_access_integrations is None
|
980
|
+
else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
|
981
|
+
),
|
982
|
+
statement_params=statement_params,
|
983
|
+
)
|
984
|
+
|
917
985
|
|
918
986
|
lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import json
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional, TypedDict
|
3
3
|
|
4
4
|
from typing_extensions import NotRequired
|
5
5
|
|
@@ -14,7 +14,7 @@ MODEL_VERSION_METADATA_SCHEMA_VERSION = "2024-01-01"
|
|
14
14
|
|
15
15
|
|
16
16
|
class ModelVersionMetadataSchema(TypedDict):
|
17
|
-
metrics: NotRequired[
|
17
|
+
metrics: NotRequired[dict[str, Any]]
|
18
18
|
|
19
19
|
|
20
20
|
class MetadataOperator:
|
@@ -44,7 +44,7 @@ class MetadataOperator:
|
|
44
44
|
)
|
45
45
|
|
46
46
|
@staticmethod
|
47
|
-
def _parse(metadata_dict:
|
47
|
+
def _parse(metadata_dict: dict[str, Any]) -> ModelVersionMetadataSchema:
|
48
48
|
loaded_metadata_schema_version = metadata_dict.get("snowpark_ml_schema_version", None)
|
49
49
|
if loaded_metadata_schema_version is None:
|
50
50
|
return ModelVersionMetadataSchema(metrics={})
|
@@ -65,8 +65,8 @@ class MetadataOperator:
|
|
65
65
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
66
66
|
model_name: sql_identifier.SqlIdentifier,
|
67
67
|
version_name: sql_identifier.SqlIdentifier,
|
68
|
-
statement_params: Optional[
|
69
|
-
) ->
|
68
|
+
statement_params: Optional[dict[str, Any]] = None,
|
69
|
+
) -> dict[str, Any]:
|
70
70
|
version_info_list = self._model_client.show_versions(
|
71
71
|
database_name=database_name,
|
72
72
|
schema_name=schema_name,
|
@@ -89,7 +89,7 @@ class MetadataOperator:
|
|
89
89
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
90
90
|
model_name: sql_identifier.SqlIdentifier,
|
91
91
|
version_name: sql_identifier.SqlIdentifier,
|
92
|
-
statement_params: Optional[
|
92
|
+
statement_params: Optional[dict[str, Any]] = None,
|
93
93
|
) -> ModelVersionMetadataSchema:
|
94
94
|
metadata_dict = self._get_current_metadata_dict(
|
95
95
|
database_name=database_name,
|
@@ -108,7 +108,7 @@ class MetadataOperator:
|
|
108
108
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
109
109
|
model_name: sql_identifier.SqlIdentifier,
|
110
110
|
version_name: sql_identifier.SqlIdentifier,
|
111
|
-
statement_params: Optional[
|
111
|
+
statement_params: Optional[dict[str, Any]] = None,
|
112
112
|
) -> None:
|
113
113
|
metadata_dict = self._get_current_metadata_dict(
|
114
114
|
database_name=database_name,
|