snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in a supported public registry. It is provided for informational purposes only.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
snowflake/cortex/__init__.py
CHANGED
```diff
@@ -1,5 +1,10 @@
 from snowflake.cortex._classify_text import ClassifyText, classify_text
-from snowflake.cortex._complete import
+from snowflake.cortex._complete import (
+    Complete,
+    CompleteOptions,
+    ConversationMessage,
+    complete,
+)
 from snowflake.cortex._embed_text_768 import EmbedText768, embed_text_768
 from snowflake.cortex._embed_text_1024 import EmbedText1024, embed_text_1024
 from snowflake.cortex._extract_answer import ExtractAnswer, extract_answer
@@ -14,6 +19,7 @@ __all__ = [
     "Complete",
     "complete",
     "CompleteOptions",
+    "ConversationMessage",
     "EmbedText768",
     "embed_text_768",
     "EmbedText1024",
```
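The net effect is that `ConversationMessage` joins `Complete`, `CompleteOptions`, and `complete` in the package's public surface. A short usage sketch, not taken from the package docs: it assumes an active Snowpark session, that `ConversationMessage` carries the usual `role`/`content` keys, and that `max_tokens` is one of the optional `CompleteOptions` keys; the model name is a placeholder.

```python
from snowflake.cortex import CompleteOptions, ConversationMessage, complete

messages: list[ConversationMessage] = [
    {"role": "user", "content": "Summarize the change log in one sentence."},
]
options = CompleteOptions(max_tokens=128)  # TypedDict: unspecified keys are simply omitted
print(complete("my-model", messages, options=options))
```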
snowflake/cortex/_classify_text.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from typing import
+from typing import Optional, Union, cast
 
 from typing_extensions import deprecated
 
@@ -12,7 +12,7 @@ from snowflake.ml._internal import telemetry
 )
 def classify_text(
     str_input: Union[str, snowpark.Column],
-    categories: Union[
+    categories: Union[list[str], snowpark.Column],
     session: Optional[snowpark.Session] = None,
 ) -> Union[str, snowpark.Column]:
     """Use the LLM inference service to classify the INPUT text into one of the target CATEGORIES.
@@ -32,7 +32,7 @@ def classify_text(
 def _classify_text_impl(
     function: str,
     str_input: Union[str, snowpark.Column],
-    categories: Union[
+    categories: Union[list[str], snowpark.Column],
     session: Optional[snowpark.Session] = None,
 ) -> Union[str, snowpark.Column]:
     return cast(Union[str, snowpark.Column], call_sql_function(function, session, str_input, categories))
```
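For reference, a minimal call against the updated signature; it assumes an active Snowpark session and relies on the top-level re-export shown in the `__init__.py` hunk above.

```python
from snowflake.cortex import classify_text

# `categories` is now annotated with the built-in generic list[str] instead of typing.List[str]
label = classify_text(
    "The drive arrived making a clicking noise and will not mount.",
    ["hardware issue", "software issue", "billing question"],
)
print(label)
```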
snowflake/cortex/_complete.py
CHANGED
```diff
@@ -3,7 +3,7 @@ import logging
 import time
 import typing
 from io import BytesIO
-from typing import Any, Callable,
+from typing import Any, Callable, Iterator, Optional, TypedDict, Union, cast
 from urllib.parse import urlunparse
 
 import requests
@@ -30,7 +30,7 @@ class ResponseFormat(TypedDict):
 
     type: str
     """The response format type (e.g. "json")"""
-    schema:
+    schema: dict[str, Any]
     """The schema defining the structure of the response. For json it should be a valid json schema object"""
 
 
@@ -71,12 +71,11 @@ class CompleteOptions(TypedDict):
 class ResponseParseException(Exception):
     """This exception is raised when the server response cannot be parsed."""
 
-    pass
-
 
 class MidStreamException(Exception):
     """The SSE (Server-sent Event) stream can contain error messages in the middle of the stream,
-    using the “error” event type. This exception is raised when there is such a mid-stream error.
+    using the “error” event type. This exception is raised when there is such a mid-stream error.
+    """
 
     def __init__(
         self,
@@ -135,7 +134,7 @@ def retry(func: Callable[..., requests.Response]) -> Callable[..., requests.Resp
     return inner
 
 
-def _make_common_request_headers() ->
+def _make_common_request_headers() -> dict[str, str]:
     headers = {
         "Content-Type": "application/json",
         "Accept": "application/json, text/event-stream",
@@ -143,7 +142,7 @@ def _make_common_request_headers() -> Dict[str, str]:
     return headers
 
 
-def _get_request_id(resp:
+def _get_request_id(resp: dict[str, Any]) -> Optional[Any]:
     request_id = None
     if "headers" in resp:
         for key, value in resp["headers"].items():
@@ -183,14 +182,14 @@ def _validate_response_format_object(options: CompleteOptions) -> None:
 
 def _make_request_body(
     model: str,
-    prompt: Union[str,
+    prompt: Union[str, list[ConversationMessage]],
     options: Optional[CompleteOptions] = None,
-) ->
+) -> dict[str, Any]:
     data = {
         "model": model,
         "stream": True,
     }
-    if isinstance(prompt,
+    if isinstance(prompt, list):
         data["messages"] = prompt
     else:
         data["messages"] = [{"content": prompt}]
@@ -217,7 +216,7 @@ def _make_request_body(
 
 # XP endpoint returns a dict response which needs to be converted to a format which can
 # be consumed by the SSEClient. This method does that.
-def _xp_dict_to_response(raw_resp:
+def _xp_dict_to_response(raw_resp: dict[str, Any]) -> requests.Response:
 
     response = requests.Response()
     response.status_code = int(raw_resp["status"])
@@ -251,9 +250,9 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
 
 @retry
 def _call_complete_xp(
-    snow_api_xp_request_handler: Optional[Callable[...,
+    snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
     model: str,
-    prompt: Union[str,
+    prompt: Union[str, list[ConversationMessage]],
     options: Optional[CompleteOptions] = None,
     deadline: Optional[float] = None,
 ) -> requests.Response:
@@ -267,7 +266,7 @@ def _call_complete_xp(
 @retry
 def _call_complete_rest(
     model: str,
-    prompt: Union[str,
+    prompt: Union[str, list[ConversationMessage]],
     options: Optional[CompleteOptions] = None,
     session: Optional[snowpark.Session] = None,
 ) -> requests.Response:
@@ -340,9 +339,9 @@ def _complete_call_sql_function_snowpark(
 
 
 def _complete_non_streaming_immediate(
-    snow_api_xp_request_handler: Optional[Callable[...,
+    snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
     model: str,
-    prompt: Union[str,
+    prompt: Union[str, list[ConversationMessage]],
    options: Optional[CompleteOptions],
     session: Optional[snowpark.Session] = None,
     deadline: Optional[float] = None,
@@ -359,10 +358,10 @@ def _complete_non_streaming_immediate(
 
 
 def _complete_non_streaming_impl(
-    snow_api_xp_request_handler: Optional[Callable[...,
+    snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
     function: str,
     model: Union[str, snowpark.Column],
-    prompt: Union[str,
+    prompt: Union[str, list[ConversationMessage], snowpark.Column],
     options: Optional[Union[CompleteOptions, snowpark.Column]],
     session: Optional[snowpark.Session] = None,
     deadline: Optional[float] = None,
@@ -389,9 +388,9 @@ def _complete_non_streaming_impl(
 
 
 def _complete_rest(
-    snow_api_xp_request_handler: Optional[Callable[...,
+    snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
     model: str,
-    prompt: Union[str,
+    prompt: Union[str, list[ConversationMessage]],
     options: Optional[CompleteOptions] = None,
     session: Optional[snowpark.Session] = None,
     deadline: Optional[float] = None,
@@ -414,8 +413,8 @@ def _complete_rest(
 
 def _complete_impl(
     model: Union[str, snowpark.Column],
-    prompt: Union[str,
-    snow_api_xp_request_handler: Optional[Callable[...,
+    prompt: Union[str, list[ConversationMessage], snowpark.Column],
+    snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]] = None,
     function: str = "snowflake.cortex.complete",
     options: Optional[CompleteOptions] = None,
     session: Optional[snowpark.Session] = None,
@@ -430,7 +429,7 @@ def _complete_impl(
     if stream:
         if not isinstance(model, str):
             raise ValueError("in REST mode, 'model' must be a string")
-        if not isinstance(prompt, str) and not isinstance(prompt,
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
             raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
         return _complete_rest(
             snow_api_xp_request_handler=snow_api_xp_request_handler,
@@ -456,7 +455,7 @@ def _complete_impl(
 )
 def complete(
     model: Union[str, snowpark.Column],
-    prompt: Union[str,
+    prompt: Union[str, list[ConversationMessage], snowpark.Column],
     *,
     options: Optional[CompleteOptions] = None,
     session: Optional[snowpark.Session] = None,
```
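Most of the edits above are the same Dict/List-to-builtin-generics migration; the behavioral detail worth keeping in mind is how `_make_request_body` treats the two prompt forms. A standalone sketch of just that branch, with the field set taken from the hunk above and nothing more:

```python
from typing import Any, Union


def sketch_request_body(model: str, prompt: Union[str, list[dict[str, str]]]) -> dict[str, Any]:
    data: dict[str, Any] = {"model": model, "stream": True}
    if isinstance(prompt, list):
        # chat-style prompt: a list of ConversationMessage dicts is passed through unchanged
        data["messages"] = prompt
    else:
        # plain string prompt: wrapped into a single message carrying only a "content" key
        data["messages"] = [{"content": prompt}]
    return data


print(sketch_request_body("my-model", "hello"))
print(sketch_request_body("my-model", [{"role": "user", "content": "hello"}]))
```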
snowflake/cortex/_embed_text_1024.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from typing import
+from typing import Optional, Union, cast
 
 from typing_extensions import deprecated
 
@@ -14,7 +14,7 @@ def embed_text_1024(
     model: Union[str, snowpark.Column],
     text: Union[str, snowpark.Column],
     session: Optional[snowpark.Session] = None,
-) -> Union[
+) -> Union[list[float], snowpark.Column]:
     """Calls into the LLM inference service to embed the text.
 
     Args:
@@ -35,8 +35,8 @@ def _embed_text_1024_impl(
     model: Union[str, snowpark.Column],
     text: Union[str, snowpark.Column],
     session: Optional[snowpark.Session] = None,
-) -> Union[
-    return cast(Union[
+) -> Union[list[float], snowpark.Column]:
+    return cast(Union[list[float], snowpark.Column], call_sql_function(function, session, model, text))
 
 
 EmbedText1024 = deprecated(
```
snowflake/cortex/_embed_text_768.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from typing import
+from typing import Optional, Union, cast
 
 from typing_extensions import deprecated
 
@@ -14,7 +14,7 @@ def embed_text_768(
     model: Union[str, snowpark.Column],
     text: Union[str, snowpark.Column],
     session: Optional[snowpark.Session] = None,
-) -> Union[
+) -> Union[list[float], snowpark.Column]:
     """Calls into the LLM inference service to embed the text.
 
     Args:
@@ -35,8 +35,8 @@ def _embed_text_768_impl(
     model: Union[str, snowpark.Column],
     text: Union[str, snowpark.Column],
     session: Optional[snowpark.Session] = None,
-) -> Union[
-    return cast(Union[
+) -> Union[list[float], snowpark.Column]:
+    return cast(Union[list[float], snowpark.Column], call_sql_function(function, session, model, text))
 
 
 EmbedText768 = deprecated(
```
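Both embedding helpers now spell their return type with the built-in `list` generic. A small usage sketch, assuming an active Snowpark session; the model names here are placeholders, not values taken from the diff:

```python
from snowflake.cortex import embed_text_768, embed_text_1024

vec_small: list[float] = embed_text_768("some-768d-model", "hello world")
vec_large: list[float] = embed_text_1024("some-1024d-model", "hello world")
print(len(vec_small), len(vec_large))
```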
snowflake/cortex/_finetune.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 import json
 from dataclasses import dataclass
-from typing import Any,
+from typing import Any, Optional, Union, cast
 
 from snowflake import snowpark
 from snowflake.cortex._util import (
@@ -53,7 +53,7 @@ class FinetuneStatus:
     created_on: Optional[int] = None
     """Creation timestamp of the Fine-tuning job in milliseconds."""
 
-    error: Optional[
+    error: Optional[dict[str, Any]] = None
     """Error message propagated from the job."""
 
     finished_on: Optional[int] = None
@@ -62,7 +62,7 @@ class FinetuneStatus:
     progress: Optional[float] = None
     """Progress made as a fraction of total [0.0,1.0]."""
 
-    training_result: Optional[
+    training_result: Optional[list[dict[str, Any]]] = None
     """Detailed metrics report for a completed training."""
 
     trained_tokens: Optional[int] = None
@@ -135,7 +135,7 @@ class FinetuneJob:
         """
         result_string = _finetune_impl(operation="DESCRIBE", session=self._session, function_args=[self.status.id])
 
-        result = FinetuneStatus(**cast(
+        result = FinetuneStatus(**cast(dict[str, Any], _try_load_json(result_string)))
         return result
 
 
@@ -167,7 +167,7 @@ class Finetune:
         base_model: str,
         training_data: Union[str, snowpark.DataFrame],
         validation_data: Optional[Union[str, snowpark.DataFrame]] = None,
-        options: Optional[
+        options: Optional[dict[str, Any]] = None,
     ) -> FinetuneJob:
         """Create a new fine-tuning runs.
 
@@ -240,7 +240,7 @@ class Finetune:
         project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
         subproject=CORTEX_FINETUNE_TELEMETRY_SUBPROJECT,
     )
-    def list_jobs(self) ->
+    def list_jobs(self) -> list["FinetuneJob"]:
        """Show current and past fine-tuning runs.
 
         Returns:
@@ -253,7 +253,7 @@ class Finetune:
         return [FinetuneJob(session=self._session, status=FinetuneStatus(**run_status)) for run_status in result]
 
 
-def _try_load_json(json_string: str) -> Union[
+def _try_load_json(json_string: str) -> Union[dict[Any, Any], list[Any]]:
     try:
         result = json.loads(str(json_string))
     except json.JSONDecodeError as e:
@@ -269,5 +269,5 @@ def _try_load_json(json_string: str) -> Union[Dict[Any, Any], List[Any]]:
     return result
 
 
-def _finetune_impl(operation: str, session: Optional[snowpark.Session], function_args:
+def _finetune_impl(operation: str, session: Optional[snowpark.Session], function_args: list[Any]) -> str:
     return call_sql_function_literals(_CORTEX_FINETUNE_SYSTEM_FUNCTION_NAME, session, operation, *function_args)
```
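A sketch of reading the retyped `FinetuneStatus` fields. It assumes `Finetune` is importable from the package root (as in recent releases; otherwise import from `snowflake.cortex._finetune`) and that an active Snowpark session is picked up when none is passed:

```python
from snowflake.cortex import Finetune

for job in Finetune().list_jobs():          # list_jobs() now returns list["FinetuneJob"]
    status = job.status
    print(status.id, status.progress)
    if status.error:                        # Optional[dict[str, Any]]
        print("failed:", status.error)
    if status.training_result:              # Optional[list[dict[str, Any]]]
        print("metrics:", status.training_result)
```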
snowflake/cortex/_util.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from typing import Any,
+from typing import Any, Optional, Union, cast
 
 from snowflake import snowpark
 from snowflake.ml._internal.exceptions import error_codes, exceptions
@@ -11,22 +11,18 @@ CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
 class SnowflakeAuthenticationException(Exception):
     """This exception is raised when there is an issue with Snowflake's configuration."""
 
-    pass
-
 
 class SnowflakeConfigurationException(Exception):
     """This exception is raised when there is an issue with Snowflake's configuration."""
 
-    pass
-
 
 # Calls a sql function, handling both immediate (e.g. python types) and batch
 # (e.g. snowpark column and literal type modes).
 def call_sql_function(
     function: str,
     session: Optional[snowpark.Session],
-    *args: Union[str,
-) -> Union[str,
+    *args: Union[str, list[str], snowpark.Column, dict[str, Union[int, float]]],
+) -> Union[str, list[float], snowpark.Column]:
     handle_as_column = False
 
     for arg in args:
@@ -34,15 +30,15 @@ def call_sql_function(
             handle_as_column = True
 
     if handle_as_column:
-        return cast(Union[str,
+        return cast(Union[str, list[float], snowpark.Column], _call_sql_function_column(function, *args))
     return cast(
-        Union[str,
+        Union[str, list[float], snowpark.Column],
         _call_sql_function_immediate(function, session, *args),
     )
 
 
 def _call_sql_function_column(
-    function: str, *args: Union[str,
+    function: str, *args: Union[str, list[str], snowpark.Column, dict[str, Union[int, float]]]
 ) -> snowpark.Column:
     return cast(snowpark.Column, functions.builtin(function)(*args))
 
@@ -50,8 +46,8 @@ def _call_sql_function_column(
 def _call_sql_function_immediate(
     function: str,
     session: Optional[snowpark.Session],
-    *args: Union[str,
-) -> Union[str,
+    *args: Union[str, list[str], snowpark.Column, dict[str, Union[int, float]]],
+) -> Union[str, list[float]]:
     session = session or context.get_active_session()
     if session is None:
         raise SnowflakeAuthenticationException(
```
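Beyond the typing changes, the structure of `call_sql_function` is unchanged: if any argument is a `snowpark.Column` the call is deferred as a column expression, otherwise it runs immediately against the active session. A self-contained sketch of that dispatch pattern; the class and names below are stand-ins, not the library's API:

```python
from typing import Union


class FakeColumn:
    """Stand-in for snowpark.Column, used only to illustrate the dispatch."""

    def __init__(self, expr: str) -> None:
        self.expr = expr


def dispatch(function: str, *args: Union[str, list[str], FakeColumn]) -> str:
    if any(isinstance(arg, FakeColumn) for arg in args):
        # column mode: build an expression and defer execution to the query engine
        rendered = ", ".join(a.expr if isinstance(a, FakeColumn) else repr(a) for a in args)
        return f"{function}({rendered})"
    # immediate mode: the real helper routes this through an active Snowpark session
    return f"<immediate call: {function}{args!r}>"


print(dispatch("snowflake.cortex.complete", "my-model", FakeColumn('"PROMPT"')))
print(dispatch("snowflake.cortex.complete", "my-model", "hello"))
```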
snowflake/ml/_internal/env.py
CHANGED
```diff
@@ -1,8 +1,9 @@
+import os
 import platform
 
-from snowflake.ml import version
-
 SOURCE = "SnowML"
-VERSION = version.VERSION
 PYTHON_VERSION = platform.python_version()
 OS = platform.system()
+IN_ML_RUNTIME_ENV_VAR = "IN_SPCS_ML_RUNTIME"
+IN_ML_RUNTIME = os.getenv(IN_ML_RUNTIME_ENV_VAR)
+USE_OPTIMIZED_DATA_INGESTOR = "USE_OPTIMIZED_DATA_INGESTOR"
```
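The new constants expose SPCS ML runtime detection to the rest of the package. A minimal sketch of how they can be consumed; the helper functions here are illustrative, not part of the module:

```python
import os

from snowflake.ml._internal import env


def running_in_spcs_ml_runtime() -> bool:
    # env.IN_ML_RUNTIME holds the raw os.getenv() value captured at import time,
    # so truthiness is the natural check
    return bool(env.IN_ML_RUNTIME)


def optimized_ingestor_requested() -> bool:
    # USE_OPTIMIZED_DATA_INGESTOR is the *name* of an environment variable,
    # mirroring IN_ML_RUNTIME_ENV_VAR
    return bool(os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR))


print(running_in_spcs_ml_runtime(), optimized_ingestor_requested())
```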
snowflake/ml/_internal/env_utils.py
CHANGED
```diff
@@ -6,12 +6,13 @@ import textwrap
 import warnings
 from enum import Enum
 from importlib import metadata as importlib_metadata
-from typing import Any, DefaultDict,
+from typing import Any, DefaultDict, Optional
 
 import yaml
 from packaging import requirements, specifiers, version
 
 import snowflake.connector
+from snowflake.ml import version as snowml_version
 from snowflake.ml._internal import env as snowml_env, relax_version_strategy
 from snowflake.ml._internal.utils import query_result_checker
 from snowflake.snowpark import context, exceptions, session
@@ -27,8 +28,8 @@ class CONDA_OS(Enum):
 
 
 _NODEFAULTS = "nodefaults"
-_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE:
-_SNOWFLAKE_CONDA_PACKAGE_CACHE:
+_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: dict[str, list[version.Version]] = {}
+_SNOWFLAKE_CONDA_PACKAGE_CACHE: dict[str, list[version.Version]] = {}
 _SUPPORTED_PACKAGE_SPEC_OPS = ["==", ">=", "<=", ">", "<"]
 
 DEFAULT_CHANNEL_NAME = ""
@@ -64,7 +65,7 @@ def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
     return r
 
 
-def _validate_conda_dependency_string(dep_str: str) ->
+def _validate_conda_dependency_string(dep_str: str) -> tuple[str, requirements.Requirement]:
     """Validate conda dependency string like `pytorch == 1.12.1` or `conda-forge::transformer` and split the channel
     name before the double colon and requirement specification after that.
 
@@ -115,7 +116,7 @@ class DuplicateDependencyInMultipleChannelsError(Exception):
     ...
 
 
-def append_requirement_list(req_list:
+def append_requirement_list(req_list: list[requirements.Requirement], p_req: requirements.Requirement) -> None:
     """Append a requirement to an existing requirement list. If need and able to merge, merge it, otherwise, append it.
 
     Args:
@@ -134,7 +135,7 @@ def append_requirement_list(req_list: List[requirements.Requirement], p_req: req
 
 
 def append_conda_dependency(
-    conda_chan_deps: DefaultDict[str,
+    conda_chan_deps: DefaultDict[str, list[requirements.Requirement]], p_chan_dep: tuple[str, requirements.Requirement]
 ) -> None:
     """Append a conda dependency to an existing conda dependencies dict, if not existed in any channel.
     To avoid making unnecessary modification to dict, we check the existence first, then try to merge, then append,
@@ -164,45 +165,73 @@ def append_conda_dependency(
         conda_chan_deps[p_channel].append(p_req)
 
 
-def validate_pip_requirement_string_list(
-
+def validate_pip_requirement_string_list(
+    req_str_list: list[str], add_local_version_specifier: bool = False
+) -> list[requirements.Requirement]:
+    """Validate the list of pip requirement strings according to PEP 508.
 
     Args:
-        req_str_list: The list of
+        req_str_list: The list of strings containing the pip requirement specification.
+        add_local_version_specifier: if True, add the version specifier of the locally installed package version to
+            requirements without version specifiers.
 
     Returns:
         A requirements.Requirement list containing the requirement information.
     """
-    seen_pip_requirement_list:
+    seen_pip_requirement_list: list[requirements.Requirement] = []
     for req_str in req_str_list:
         append_requirement_list(seen_pip_requirement_list, _validate_pip_requirement_string(req_str=req_str))
 
+    if add_local_version_specifier:
+        # For any requirement string that does not contain a specifier, add the specifier of a locally installed version
+        # if it exists.
+        seen_pip_requirement_list = list(
+            map(
+                lambda req: req if req.specifier else get_local_installed_version_of_pip_package(req),
+                seen_pip_requirement_list,
+            )
+        )
+
     return seen_pip_requirement_list
 
 
-def validate_conda_dependency_string_list(
+def validate_conda_dependency_string_list(
+    dep_str_list: list[str], add_local_version_specifier: bool = False
+) -> DefaultDict[str, list[requirements.Requirement]]:
     """Validate a list of conda dependency string, find any duplicate package across different channel and create a dict
     to represent the whole dependencies.
 
     Args:
         dep_str_list: The list of string contains the conda dependency specification.
+        add_local_version_specifier: if True, add the version specifier of the locally installed package version to
+            requirements without version specifiers.
 
     Returns:
         A dict mapping from the channel name to the list of requirements from that channel.
     """
     validated_conda_dependency_list = list(map(_validate_conda_dependency_string, dep_str_list))
-    ret_conda_dependency_dict: DefaultDict[str,
+    ret_conda_dependency_dict: DefaultDict[str, list[requirements.Requirement]] = collections.defaultdict(list)
     for p_channel, p_req in validated_conda_dependency_list:
         append_conda_dependency(ret_conda_dependency_dict, (p_channel, p_req))
 
+    if add_local_version_specifier:
+        # For any conda dependency string that does not contain a specifier, add the specifier of a locally installed
+        # version if it exists. This is best-effort: if the conda package does not have the same name as the pip
+        # package, it won't be found in the local environment.
+        for channel_str, reqs in ret_conda_dependency_dict.items():
+            reqs = list(
+                map(lambda req: req if req.specifier else get_local_installed_version_of_pip_package(req), reqs)
+            )
+            ret_conda_dependency_dict[channel_str] = reqs
+
     return ret_conda_dependency_dict
 
 
 def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement) -> requirements.Requirement:
     """Get the local installed version of a given pip package requirement.
-    If the package is locally installed, and the local version
+    If the package is locally installed, and the local version meets the specifier of the requirements, return a new
     requirement specifier that pins the version.
-    If the local version does not meet the specifier of the requirements, a
+    If the local version does not meet the specifier of the requirements, a warning will be emitted and returns
     the original package requirement.
     If the package is not locally installed or not found, the original package requirement is returned.
 
@@ -217,7 +246,7 @@ def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement
         local_dist_version = local_dist.version
     except importlib_metadata.PackageNotFoundError:
         if pip_req.name == SNOWPARK_ML_PKG_NAME:
-            local_dist_version =
+            local_dist_version = snowml_version.VERSION
         else:
             return pip_req
     new_pip_req = copy.deepcopy(pip_req)
@@ -372,8 +401,8 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req
 
 
 def get_matched_package_versions_in_information_schema_with_active_session(
-    reqs:
-) ->
+    reqs: list[requirements.Requirement], python_version: str
+) -> dict[str, list[version.Version]]:
     try:
         session = context.get_active_session()
     except exceptions.SnowparkSessionException:
@@ -383,10 +412,10 @@ def get_matched_package_versions_in_information_schema_with_active_session(
 
 def get_matched_package_versions_in_information_schema(
     session: session.Session,
-    reqs:
+    reqs: list[requirements.Requirement],
     python_version: str,
-    statement_params: Optional[
-) ->
+    statement_params: Optional[dict[str, Any]] = None,
+) -> dict[str, list[version.Version]]:
     """Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake
     Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might
     exist in the information_schema table but has not yet become available in the Snowflake Conda channel.
@@ -400,8 +429,8 @@ def get_matched_package_versions_in_information_schema(
     Returns:
         A Dict, whose key is the package name, and value is a list of versions match the requirements.
     """
-    ret_dict:
-    reqs_to_request:
+    ret_dict: dict[str, list[version.Version]] = {}
+    reqs_to_request: list[requirements.Requirement] = []
     for req in reqs:
         if req.name in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE:
             available_versions = list(
@@ -457,7 +486,7 @@ def get_matched_package_versions_in_information_schema(
 
 def save_conda_env_file(
     path: pathlib.Path,
-    conda_chan_deps: DefaultDict[str,
+    conda_chan_deps: DefaultDict[str, list[requirements.Requirement]],
     python_version: str,
     cuda_version: Optional[str] = None,
     default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
@@ -478,7 +507,7 @@ def save_conda_env_file(
     """
     assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
     path.parent.mkdir(parents=True, exist_ok=True)
-    env:
+    env: dict[str, Any] = dict()
     env["name"] = "snow-env"
     # Get all channels in the dependencies, ordered by the number of the packages which belongs to and put into
     # channels section.
@@ -505,7 +534,7 @@ def save_conda_env_file(
         yaml.safe_dump(env, stream=f, default_flow_style=False)
 
 
-def save_requirements_file(path: pathlib.Path, pip_deps:
+def save_requirements_file(path: pathlib.Path, pip_deps: list[requirements.Requirement]) -> None:
     """Generate Python requirements.txt file in the given directory path.
 
     Args:
@@ -521,9 +550,9 @@ def save_requirements_file(path: pathlib.Path, pip_deps: List[requirements.Requi
 
 def load_conda_env_file(
     path: pathlib.Path,
-) ->
-    DefaultDict[str,
-    Optional[
+) -> tuple[
+    DefaultDict[str, list[requirements.Requirement]],
+    Optional[list[requirements.Requirement]],
     Optional[str],
     Optional[str],
 ]:
@@ -601,7 +630,7 @@ def load_conda_env_file(
     return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version, cuda_version
 
 
-def load_requirements_file(path: pathlib.Path) ->
+def load_requirements_file(path: pathlib.Path) -> list[requirements.Requirement]:
     """Load Python requirements.txt file from the given directory path.
 
     Args:
@@ -641,8 +670,8 @@ def parse_python_version_string(dep: str) -> Optional[str]:
 
 
 def _find_conda_dep_spec(
-    conda_chan_deps: DefaultDict[str,
-) -> Optional[
+    conda_chan_deps: DefaultDict[str, list[requirements.Requirement]], pkg_name: str
+) -> Optional[tuple[str, requirements.Requirement]]:
     for channel in conda_chan_deps:
         spec = next(filter(lambda req: req.name == pkg_name, conda_chan_deps[channel]), None)
         if spec:
@@ -650,14 +679,14 @@ def _find_conda_dep_spec(
     return None
 
 
-def _find_pip_req_spec(pip_reqs:
+def _find_pip_req_spec(pip_reqs: list[requirements.Requirement], pkg_name: str) -> Optional[requirements.Requirement]:
     spec = next(filter(lambda req: req.name == pkg_name, pip_reqs), None)
     return spec
 
 
 def find_dep_spec(
-    conda_chan_deps: DefaultDict[str,
-    pip_reqs:
+    conda_chan_deps: DefaultDict[str, list[requirements.Requirement]],
+    pip_reqs: list[requirements.Requirement],
     conda_pkg_name: str,
     pip_pkg_name: Optional[str] = None,
     remove_spec: bool = False,
```