snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
1
|
import re
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional, Union, overload
|
3
3
|
|
4
4
|
from snowflake.snowpark._internal.analyzer import analyzer_utils
|
5
5
|
|
@@ -12,7 +12,7 @@ SF_IDENTIFIER_RE = re.compile(_SF_IDENTIFIER)
|
|
12
12
|
_SF_SCHEMA_LEVEL_OBJECT = (
|
13
13
|
rf"(?:(?:(?P<db>{_SF_IDENTIFIER})\.)?(?P<schema>{_SF_IDENTIFIER})\.)?(?P<object>{_SF_IDENTIFIER})"
|
14
14
|
)
|
15
|
-
_SF_STAGE_PATH = rf"{_SF_SCHEMA_LEVEL_OBJECT}(?P<path
|
15
|
+
_SF_STAGE_PATH = rf"@?{_SF_SCHEMA_LEVEL_OBJECT}(?P<path>/.*)?"
|
16
16
|
_SF_SCHEMA_LEVEL_OBJECT_RE = re.compile(_SF_SCHEMA_LEVEL_OBJECT)
|
17
17
|
_SF_STAGE_PATH_RE = re.compile(_SF_STAGE_PATH)
|
18
18
|
|
@@ -112,7 +112,7 @@ def get_inferred_name(name: str) -> str:
|
|
112
112
|
return escaped_id
|
113
113
|
|
114
114
|
|
115
|
-
def concat_names(names:
|
115
|
+
def concat_names(names: list[str]) -> str:
|
116
116
|
"""Concatenates `names` to form one valid id.
|
117
117
|
|
118
118
|
|
@@ -142,7 +142,7 @@ def rename_to_valid_snowflake_identifier(name: str) -> str:
|
|
142
142
|
|
143
143
|
def parse_schema_level_object_identifier(
|
144
144
|
object_name: str,
|
145
|
-
) ->
|
145
|
+
) -> tuple[Union[str, Any], Union[str, Any], Union[str, Any]]:
|
146
146
|
"""Parse a string which starts with schema level object.
|
147
147
|
|
148
148
|
Args:
|
@@ -172,7 +172,7 @@ def parse_schema_level_object_identifier(
|
|
172
172
|
|
173
173
|
def parse_snowflake_stage_path(
|
174
174
|
path: str,
|
175
|
-
) ->
|
175
|
+
) -> tuple[Union[str, Any], Union[str, Any], Union[str, Any], Union[str, Any]]:
|
176
176
|
"""Parse a string which represents a snowflake stage path.
|
177
177
|
|
178
178
|
Args:
|
@@ -197,7 +197,7 @@ def parse_snowflake_stage_path(
|
|
197
197
|
res.group("db"),
|
198
198
|
res.group("schema"),
|
199
199
|
res.group("object"),
|
200
|
-
res.group("path"),
|
200
|
+
res.group("path") or "",
|
201
201
|
)
|
202
202
|
|
203
203
|
|
@@ -260,11 +260,11 @@ def get_unescaped_names(ids: str) -> str:
|
|
260
260
|
|
261
261
|
|
262
262
|
@overload
|
263
|
-
def get_unescaped_names(ids:
|
263
|
+
def get_unescaped_names(ids: list[str]) -> list[str]:
|
264
264
|
...
|
265
265
|
|
266
266
|
|
267
|
-
def get_unescaped_names(ids: Optional[Union[str,
|
267
|
+
def get_unescaped_names(ids: Optional[Union[str, list[str]]]) -> Optional[Union[str, list[str]]]:
|
268
268
|
"""Given a user provided identifier(s), this method will compute the equivalent column name identifier(s) in the
|
269
269
|
response pandas dataframe(i.e., in the response of snowpark_df.to_pandas()) using the rules defined here
|
270
270
|
https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
|
@@ -308,11 +308,11 @@ def get_inferred_names(names: str) -> str:
|
|
308
308
|
|
309
309
|
|
310
310
|
@overload
|
311
|
-
def get_inferred_names(names:
|
311
|
+
def get_inferred_names(names: list[str]) -> list[str]:
|
312
312
|
...
|
313
313
|
|
314
314
|
|
315
|
-
def get_inferred_names(names: Optional[Union[str,
|
315
|
+
def get_inferred_names(names: Optional[Union[str, list[str]]]) -> Optional[Union[str, list[str]]]:
|
316
316
|
"""Given a user provided *string(s)*, this method will compute the equivalent column name identifier(s)
|
317
317
|
in case of column name contains special characters, and maintains case-sensitivity
|
318
318
|
https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import importlib
|
2
|
-
from typing import Any
|
2
|
+
from typing import Any
|
3
3
|
|
4
4
|
|
5
5
|
class MissingOptionalDependency:
|
@@ -46,7 +46,7 @@ def import_with_fallbacks(*targets: str) -> Any:
|
|
46
46
|
raise ImportError(f"None of the requested targets could be imported. Requested: {', '.join(targets)}")
|
47
47
|
|
48
48
|
|
49
|
-
def import_or_get_dummy(target: str) ->
|
49
|
+
def import_or_get_dummy(target: str) -> tuple[Any, bool]:
|
50
50
|
"""Try to import the the given target or return a dummy object.
|
51
51
|
|
52
52
|
If the import target (package/module/symbol) is available, the target will be returned. If it is not available,
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import math
|
2
2
|
from contextlib import contextmanager
|
3
3
|
from timeit import default_timer
|
4
|
-
from typing import Any, Callable,
|
4
|
+
from typing import Any, Callable, Generator, Iterable, Optional
|
5
5
|
|
6
6
|
import snowflake.snowpark.functions as F
|
7
7
|
from snowflake import snowpark
|
@@ -17,17 +17,17 @@ def timer() -> Generator[Callable[[], float], None, None]:
|
|
17
17
|
yield lambda: elapser()
|
18
18
|
|
19
19
|
|
20
|
-
def _flatten(L: Iterable[
|
20
|
+
def _flatten(L: Iterable[list[Any]]) -> list[Any]:
|
21
21
|
return [val for sublist in L for val in sublist]
|
22
22
|
|
23
23
|
|
24
24
|
def map_dataframe_by_column(
|
25
25
|
df: snowpark.DataFrame,
|
26
|
-
cols:
|
27
|
-
map_func: Callable[[snowpark.DataFrame,
|
26
|
+
cols: list[str],
|
27
|
+
map_func: Callable[[snowpark.DataFrame, list[str]], snowpark.DataFrame],
|
28
28
|
partition_size: int,
|
29
|
-
statement_params: Optional[
|
30
|
-
) ->
|
29
|
+
statement_params: Optional[dict[str, Any]] = None,
|
30
|
+
) -> list[list[Any]]:
|
31
31
|
"""Applies the `map_func` to the input DataFrame by parallelizing it over subsets of the column.
|
32
32
|
|
33
33
|
Because the return results are materialized as Python lists *in memory*, this method should
|
@@ -84,7 +84,7 @@ def map_dataframe_by_column(
|
|
84
84
|
unioned_df = mapped_df if unioned_df is None else unioned_df.union_all(mapped_df)
|
85
85
|
|
86
86
|
# Store results in a list of size |n_partitions| x |n_rows| x |n_output_cols|
|
87
|
-
all_results:
|
87
|
+
all_results: list[list[list[Any]]] = [[] for _ in range(n_partitions - 1)]
|
88
88
|
|
89
89
|
# Collect the results of the first n-1 partitions, removing the partition_id column
|
90
90
|
unioned_result = unioned_df.collect(statement_params=statement_params) if unioned_df is not None else []
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import sys
|
2
2
|
import warnings
|
3
|
-
from typing import
|
3
|
+
from typing import Optional, Union
|
4
4
|
|
5
5
|
from packaging.version import Version
|
6
6
|
|
@@ -8,7 +8,7 @@ from snowflake.ml._internal import telemetry
|
|
8
8
|
from snowflake.snowpark import AsyncJob, Row, Session
|
9
9
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
10
10
|
|
11
|
-
cache:
|
11
|
+
cache: dict[str, Optional[str]] = {}
|
12
12
|
|
13
13
|
_PROJECT = "ModelDevelopment"
|
14
14
|
_SUBPROJECT = "utils"
|
@@ -23,8 +23,8 @@ def is_relaxed() -> bool:
|
|
23
23
|
|
24
24
|
|
25
25
|
def get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
26
|
-
pkg_versions:
|
27
|
-
) ->
|
26
|
+
pkg_versions: list[str], session: Session, subproject: Optional[str] = None
|
27
|
+
) -> list[str]:
|
28
28
|
if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
29
29
|
return pkg_versions
|
30
30
|
else:
|
@@ -32,9 +32,9 @@ def get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
|
32
32
|
|
33
33
|
|
34
34
|
def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
|
35
|
-
pkg_versions:
|
36
|
-
) ->
|
37
|
-
pkg_version_async_job_list:
|
35
|
+
pkg_versions: list[str], session: Session, subproject: Optional[str] = None
|
36
|
+
) -> list[str]:
|
37
|
+
pkg_version_async_job_list: list[tuple[str, AsyncJob]] = []
|
38
38
|
for pkg_version in pkg_versions:
|
39
39
|
if pkg_version not in cache:
|
40
40
|
# Execute pkg version queries asynchronously.
|
@@ -64,7 +64,7 @@ def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
|
|
64
64
|
|
65
65
|
def _query_pkg_version_supported_in_snowflake_conda_channel(
|
66
66
|
pkg_version: str, session: Session, block: bool, subproject: Optional[str] = None
|
67
|
-
) -> Union[AsyncJob,
|
67
|
+
) -> Union[AsyncJob, list[Row]]:
|
68
68
|
tokens = pkg_version.split("==")
|
69
69
|
if len(tokens) != 2:
|
70
70
|
raise RuntimeError(
|
@@ -102,9 +102,9 @@ def _query_pkg_version_supported_in_snowflake_conda_channel(
|
|
102
102
|
return pkg_version_list_or_async_job
|
103
103
|
|
104
104
|
|
105
|
-
def _get_conda_packages_and_emit_warnings(pkg_versions:
|
106
|
-
pkg_version_conda_list:
|
107
|
-
pkg_version_warning_list:
|
105
|
+
def _get_conda_packages_and_emit_warnings(pkg_versions: list[str]) -> list[str]:
|
106
|
+
pkg_version_conda_list: list[str] = []
|
107
|
+
pkg_version_warning_list: list[list[str]] = []
|
108
108
|
for pkg_version in pkg_versions:
|
109
109
|
try:
|
110
110
|
conda_pkg_version = cache[pkg_version]
|
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations # for return self methods
|
2
2
|
|
3
3
|
from functools import partial
|
4
|
-
from typing import Any, Callable,
|
4
|
+
from typing import Any, Callable, Optional
|
5
5
|
|
6
6
|
from snowflake import connector, snowpark
|
7
7
|
from snowflake.ml._internal.utils import formatting
|
@@ -123,7 +123,7 @@ def cell_value_by_column_matcher(
|
|
123
123
|
return True
|
124
124
|
|
125
125
|
|
126
|
-
_DEFAULT_MATCHERS:
|
126
|
+
_DEFAULT_MATCHERS: list[Callable[[list[snowpark.Row], Optional[str]], bool]] = [
|
127
127
|
partial(result_dimension_matcher, 1, 1),
|
128
128
|
partial(column_name_matcher, "status"),
|
129
129
|
]
|
@@ -252,12 +252,12 @@ class SqlResultValidator(ResultValidator):
|
|
252
252
|
"""
|
253
253
|
|
254
254
|
def __init__(
|
255
|
-
self, session: snowpark.Session, query: str, statement_params: Optional[
|
255
|
+
self, session: snowpark.Session, query: str, statement_params: Optional[dict[str, Any]] = None
|
256
256
|
) -> None:
|
257
257
|
self._session: snowpark.Session = session
|
258
258
|
self._query: str = query
|
259
259
|
self._success_matchers: list[Callable[[list[snowpark.Row], Optional[str]], bool]] = []
|
260
|
-
self._statement_params: Optional[
|
260
|
+
self._statement_params: Optional[dict[str, Any]] = statement_params
|
261
261
|
|
262
262
|
def _get_result(self) -> list[snowpark.Row]:
|
263
263
|
"""Collect the result of the given SQL query."""
|
@@ -1,15 +1,15 @@
|
|
1
1
|
import enum
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional, TypedDict, cast
|
3
3
|
|
4
4
|
from packaging import version
|
5
5
|
from typing_extensions import NotRequired, Required
|
6
6
|
|
7
7
|
from snowflake.ml._internal.utils import query_result_checker
|
8
|
-
from snowflake.snowpark import session
|
8
|
+
from snowflake.snowpark import exceptions as sp_exceptions, session
|
9
9
|
|
10
10
|
|
11
11
|
def get_current_snowflake_version(
|
12
|
-
sess: session.Session, *, statement_params: Optional[
|
12
|
+
sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None
|
13
13
|
) -> version.Version:
|
14
14
|
"""Get Snowflake Version as a version.Version object follow PEP way of versioning, that is to say:
|
15
15
|
"7.44.2 b202312132139364eb71238" to <Version('7.44.2+b202312132139364eb71238')>
|
@@ -60,8 +60,8 @@ class SnowflakeRegion(TypedDict):
|
|
60
60
|
|
61
61
|
|
62
62
|
def get_regions(
|
63
|
-
sess: session.Session, *, statement_params: Optional[
|
64
|
-
) ->
|
63
|
+
sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None
|
64
|
+
) -> dict[str, SnowflakeRegion]:
|
65
65
|
res = (
|
66
66
|
query_result_checker.SqlResultValidator(sess, "SHOW REGIONS", statement_params=statement_params)
|
67
67
|
.has_column("snowflake_region")
|
@@ -93,7 +93,7 @@ def get_regions(
|
|
93
93
|
return res_dict
|
94
94
|
|
95
95
|
|
96
|
-
def get_current_region_id(sess: session.Session, *, statement_params: Optional[
|
96
|
+
def get_current_region_id(sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None) -> str:
|
97
97
|
res = (
|
98
98
|
query_result_checker.SqlResultValidator(
|
99
99
|
sess, "SELECT CURRENT_REGION() AS CURRENT_REGION", statement_params=statement_params
|
@@ -103,3 +103,25 @@ def get_current_region_id(sess: session.Session, *, statement_params: Optional[D
|
|
103
103
|
)
|
104
104
|
|
105
105
|
return cast(str, res.CURRENT_REGION)
|
106
|
+
|
107
|
+
|
108
|
+
def get_current_cloud(
|
109
|
+
sess: session.Session,
|
110
|
+
default: Optional[SnowflakeCloudType] = None,
|
111
|
+
*,
|
112
|
+
statement_params: Optional[dict[str, Any]] = None,
|
113
|
+
) -> SnowflakeCloudType:
|
114
|
+
region_id = get_current_region_id(sess, statement_params=statement_params)
|
115
|
+
try:
|
116
|
+
region = get_regions(sess, statement_params=statement_params)[region_id]
|
117
|
+
return region["cloud"]
|
118
|
+
except sp_exceptions.SnowparkSQLException:
|
119
|
+
# SHOW REGIONS not available, try to infer cloud from region name
|
120
|
+
region_name = region_id.split(".", 1)[-1] # Drop region group if any, e.g. PUBLIC
|
121
|
+
cloud_name_maybe = region_name.split("_", 1)[0] # Extract cloud name, e.g. AWS_US_WEST -> AWS
|
122
|
+
try:
|
123
|
+
return SnowflakeCloudType.from_value(cloud_name_maybe)
|
124
|
+
except ValueError:
|
125
|
+
if default:
|
126
|
+
return default
|
127
|
+
raise
|
@@ -1,13 +1,13 @@
|
|
1
1
|
import logging
|
2
2
|
import warnings
|
3
|
-
from typing import
|
3
|
+
from typing import Optional
|
4
4
|
|
5
5
|
from snowflake import snowpark
|
6
6
|
from snowflake.ml._internal.utils import sql_identifier
|
7
7
|
from snowflake.snowpark import functions, types
|
8
8
|
|
9
9
|
|
10
|
-
def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[
|
10
|
+
def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[list[str]] = None) -> snowpark.DataFrame:
|
11
11
|
"""Cast columns in the dataframe to types that are compatible with tensor.
|
12
12
|
|
13
13
|
It assists FileSet.make() in performing implicit data casting.
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Optional, Union
|
2
2
|
|
3
3
|
from snowflake.ml._internal.utils import identifier
|
4
4
|
|
@@ -77,13 +77,13 @@ class SqlIdentifier(str):
|
|
77
77
|
return super().__hash__()
|
78
78
|
|
79
79
|
|
80
|
-
def to_sql_identifiers(list_of_str:
|
80
|
+
def to_sql_identifiers(list_of_str: list[str], *, case_sensitive: bool = False) -> list[SqlIdentifier]:
|
81
81
|
return [SqlIdentifier(val, case_sensitive=case_sensitive) for val in list_of_str]
|
82
82
|
|
83
83
|
|
84
84
|
def parse_fully_qualified_name(
|
85
85
|
name: str,
|
86
|
-
) ->
|
86
|
+
) -> tuple[Optional[SqlIdentifier], Optional[SqlIdentifier], SqlIdentifier]:
|
87
87
|
db, schema, object = identifier.parse_schema_level_object_identifier(name)
|
88
88
|
|
89
89
|
assert name is not None, f"Unable parse the input name `{name}` as fully qualified."
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Optional
|
2
2
|
|
3
3
|
from snowflake import snowpark
|
4
4
|
from snowflake.ml._internal.utils import formatting, identifier, query_result_checker
|
@@ -24,8 +24,8 @@ def create_single_table(
|
|
24
24
|
database_name: str,
|
25
25
|
schema_name: str,
|
26
26
|
table_name: str,
|
27
|
-
table_schema:
|
28
|
-
statement_params: Optional[
|
27
|
+
table_schema: list[tuple[str, str]],
|
28
|
+
statement_params: Optional[dict[str, Any]] = None,
|
29
29
|
) -> str:
|
30
30
|
"""Creates a single table for registry and returns the fully qualified name of the table.
|
31
31
|
|
@@ -55,7 +55,7 @@ def create_single_table(
|
|
55
55
|
return fully_qualified_table_name
|
56
56
|
|
57
57
|
|
58
|
-
def insert_table_entry(session: snowpark.Session, table: str, columns:
|
58
|
+
def insert_table_entry(session: snowpark.Session, table: str, columns: dict[str, Any]) -> list[snowpark.Row]:
|
59
59
|
"""Insert an entry into an internal Model Registry table.
|
60
60
|
|
61
61
|
Args:
|
@@ -99,9 +99,9 @@ def validate_table_exist(session: snowpark.Session, table: str, qualified_schema
|
|
99
99
|
return len(tables) == 1
|
100
100
|
|
101
101
|
|
102
|
-
def get_table_schema(session: snowpark.Session, table_name: str, qualified_schema_name: str) ->
|
102
|
+
def get_table_schema(session: snowpark.Session, table_name: str, qualified_schema_name: str) -> dict[str, str]:
|
103
103
|
result = session.sql(f"DESC TABLE {qualified_schema_name}.{table_name}").collect()
|
104
|
-
schema_dict:
|
104
|
+
schema_dict: dict[str, str] = {}
|
105
105
|
for row in result:
|
106
106
|
schema_dict[row["name"]] = row["type"]
|
107
107
|
return schema_dict
|
@@ -112,13 +112,13 @@ def get_table_schema_types(
|
|
112
112
|
database: str,
|
113
113
|
schema: str,
|
114
114
|
table_name: str,
|
115
|
-
) ->
|
115
|
+
) -> dict[str, types.DataType]:
|
116
116
|
fully_qualified_table_name = identifier.get_schema_level_object_identifier(
|
117
117
|
db=database, schema=schema, object_name=table_name
|
118
118
|
)
|
119
|
-
struct_fields:
|
119
|
+
struct_fields: list[types.StructField] = session.table(fully_qualified_table_name).schema.fields
|
120
120
|
|
121
|
-
schema_dict:
|
121
|
+
schema_dict: dict[str, types.DataType] = {}
|
122
122
|
for field in struct_fields:
|
123
123
|
schema_dict[field.name] = field.datatype
|
124
124
|
return schema_dict
|
@@ -2,7 +2,7 @@ import collections
|
|
2
2
|
import logging
|
3
3
|
import os
|
4
4
|
import time
|
5
|
-
from typing import Any, Deque,
|
5
|
+
from typing import Any, Deque, Iterator, Optional, Sequence, Union
|
6
6
|
|
7
7
|
import numpy as np
|
8
8
|
import numpy.typing as npt
|
@@ -71,7 +71,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
|
|
71
71
|
return cls(session, sources)
|
72
72
|
|
73
73
|
@property
|
74
|
-
def data_sources(self) ->
|
74
|
+
def data_sources(self) -> list[data_source.DataSource]:
|
75
75
|
return self._data_sources
|
76
76
|
|
77
77
|
def to_batches(
|
@@ -79,7 +79,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
|
|
79
79
|
batch_size: int,
|
80
80
|
shuffle: bool = True,
|
81
81
|
drop_last_batch: bool = True,
|
82
|
-
) -> Iterator[
|
82
|
+
) -> Iterator[dict[str, npt.NDArray[Any]]]:
|
83
83
|
"""Iterate through PyArrow Dataset to generate batches whose length equals to expected batch size.
|
84
84
|
|
85
85
|
As we are generating batches with the exactly same length, the last few rows in each file might get left as they
|
@@ -120,7 +120,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
|
|
120
120
|
|
121
121
|
def _get_dataset(self, shuffle: bool) -> pds.Dataset:
|
122
122
|
format = self._format
|
123
|
-
sources:
|
123
|
+
sources: list[Any] = []
|
124
124
|
source_format = None
|
125
125
|
for source in self._data_sources:
|
126
126
|
if isinstance(source, str):
|
@@ -155,7 +155,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
|
|
155
155
|
pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
|
156
156
|
return pa_dataset
|
157
157
|
|
158
|
-
def _get_batches_from_buffer(self, batch_size: int) ->
|
158
|
+
def _get_batches_from_buffer(self, batch_size: int) -> dict[str, npt.NDArray[Any]]:
|
159
159
|
"""Generate new batches from the existing record batch buffer."""
|
160
160
|
cnt_rbs_num_rows = 0
|
161
161
|
candidates = []
|
@@ -180,7 +180,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
|
|
180
180
|
return _record_batch_to_arrays(res)
|
181
181
|
|
182
182
|
|
183
|
-
def _merge_record_batches(record_batches:
|
183
|
+
def _merge_record_batches(record_batches: list[pa.RecordBatch]) -> pa.RecordBatch:
|
184
184
|
"""Merge a list of arrow RecordBatches into one. Similar to MergeTables."""
|
185
185
|
if not record_batches:
|
186
186
|
return _EMPTY_RECORD_BATCH
|
@@ -192,7 +192,7 @@ def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatc
|
|
192
192
|
return batches[0]
|
193
193
|
|
194
194
|
|
195
|
-
def _record_batch_to_arrays(rb: pa.RecordBatch) ->
|
195
|
+
def _record_batch_to_arrays(rb: pa.RecordBatch) -> dict[str, npt.NDArray[Any]]:
|
196
196
|
"""Transform the record batch to a (string, numpy array) dict."""
|
197
197
|
batch_dict = {}
|
198
198
|
for column, column_schema in zip(rb, rb.schema):
|
@@ -1,28 +1,13 @@
|
|
1
1
|
import os
|
2
|
-
from typing import
|
3
|
-
TYPE_CHECKING,
|
4
|
-
Any,
|
5
|
-
Dict,
|
6
|
-
Generator,
|
7
|
-
List,
|
8
|
-
Optional,
|
9
|
-
Sequence,
|
10
|
-
Type,
|
11
|
-
TypeVar,
|
12
|
-
cast,
|
13
|
-
)
|
2
|
+
from typing import TYPE_CHECKING, Any, Generator, Optional, Sequence, TypeVar
|
14
3
|
|
15
4
|
import numpy.typing as npt
|
16
5
|
from typing_extensions import deprecated
|
17
6
|
|
18
7
|
from snowflake import snowpark
|
19
|
-
from snowflake.ml._internal import telemetry
|
8
|
+
from snowflake.ml._internal import env, telemetry
|
20
9
|
from snowflake.ml.data import data_ingestor, data_source
|
21
10
|
from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
|
22
|
-
from snowflake.ml.modeling._internal.constants import (
|
23
|
-
IN_ML_RUNTIME_ENV_VAR,
|
24
|
-
USE_OPTIMIZED_DATA_INGESTOR,
|
25
|
-
)
|
26
11
|
from snowflake.snowpark import context as sf_context
|
27
12
|
|
28
13
|
if TYPE_CHECKING:
|
@@ -43,7 +28,7 @@ DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
|
|
43
28
|
class DataConnector:
|
44
29
|
"""Snowflake data reader which provides application integration connectors"""
|
45
30
|
|
46
|
-
DEFAULT_INGESTOR_CLASS:
|
31
|
+
DEFAULT_INGESTOR_CLASS: type[data_ingestor.DataIngestor] = ArrowIngestor
|
47
32
|
|
48
33
|
def __init__(
|
49
34
|
self,
|
@@ -54,27 +39,22 @@ class DataConnector:
|
|
54
39
|
self._kwargs = kwargs
|
55
40
|
|
56
41
|
@classmethod
|
57
|
-
@snowpark._internal.utils.private_preview(version="1.6.0")
|
58
42
|
def from_dataframe(
|
59
|
-
cls:
|
43
|
+
cls: type[DataConnectorType],
|
60
44
|
df: snowpark.DataFrame,
|
61
|
-
ingestor_class: Optional[
|
45
|
+
ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
|
62
46
|
**kwargs: Any,
|
63
47
|
) -> DataConnectorType:
|
64
48
|
if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
|
65
49
|
raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
|
66
|
-
return
|
67
|
-
DataConnectorType,
|
68
|
-
cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs),
|
69
|
-
)
|
50
|
+
return cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs)
|
70
51
|
|
71
52
|
@classmethod
|
72
|
-
@snowpark._internal.utils.private_preview(version="1.7.3")
|
73
53
|
def from_sql(
|
74
|
-
cls:
|
54
|
+
cls: type[DataConnectorType],
|
75
55
|
query: str,
|
76
56
|
session: Optional[snowpark.Session] = None,
|
77
|
-
ingestor_class: Optional[
|
57
|
+
ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
|
78
58
|
**kwargs: Any,
|
79
59
|
) -> DataConnectorType:
|
80
60
|
session = session or sf_context.get_active_session()
|
@@ -83,9 +63,9 @@ class DataConnector:
|
|
83
63
|
|
84
64
|
@classmethod
|
85
65
|
def from_dataset(
|
86
|
-
cls:
|
66
|
+
cls: type[DataConnectorType],
|
87
67
|
ds: "dataset.Dataset",
|
88
|
-
ingestor_class: Optional[
|
68
|
+
ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
|
89
69
|
**kwargs: Any,
|
90
70
|
) -> DataConnectorType:
|
91
71
|
dsv = ds.selected_version
|
@@ -102,10 +82,10 @@ class DataConnector:
|
|
102
82
|
func_params_to_log=["sources", "ingestor_class"],
|
103
83
|
)
|
104
84
|
def from_sources(
|
105
|
-
cls:
|
85
|
+
cls: type[DataConnectorType],
|
106
86
|
session: snowpark.Session,
|
107
87
|
sources: Sequence[data_source.DataSource],
|
108
|
-
ingestor_class: Optional[
|
88
|
+
ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
|
109
89
|
**kwargs: Any,
|
110
90
|
) -> DataConnectorType:
|
111
91
|
ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
|
@@ -113,7 +93,7 @@ class DataConnector:
|
|
113
93
|
return cls(ingestor, **kwargs)
|
114
94
|
|
115
95
|
@property
|
116
|
-
def data_sources(self) ->
|
96
|
+
def data_sources(self) -> list[data_source.DataSource]:
|
117
97
|
return self._ingestor.data_sources
|
118
98
|
|
119
99
|
@telemetry.send_api_usage_telemetry(
|
@@ -139,7 +119,7 @@ class DataConnector:
|
|
139
119
|
"""
|
140
120
|
import tensorflow as tf
|
141
121
|
|
142
|
-
def generator() -> Generator[
|
122
|
+
def generator() -> Generator[dict[str, npt.NDArray[Any]], None, None]:
|
143
123
|
yield from self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)
|
144
124
|
|
145
125
|
# Derive TensorFlow signature
|
@@ -269,11 +249,10 @@ class DataConnector:
|
|
269
249
|
|
270
250
|
# Switch to use Runtime's Data Ingester if running in ML runtime
|
271
251
|
# Fail silently if the data ingester is not found
|
272
|
-
if
|
252
|
+
if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR):
|
273
253
|
try:
|
274
254
|
from runtime_external_entities import get_ingester_class
|
275
255
|
|
276
256
|
DataConnector.DEFAULT_INGESTOR_CLASS = get_ingester_class()
|
277
257
|
except ImportError:
|
278
258
|
"""Runtime Default Ingester not found, ignore"""
|
279
|
-
pass
|
@@ -1,15 +1,4 @@
|
|
1
|
-
from typing import
|
2
|
-
TYPE_CHECKING,
|
3
|
-
Any,
|
4
|
-
Dict,
|
5
|
-
Iterator,
|
6
|
-
List,
|
7
|
-
Optional,
|
8
|
-
Protocol,
|
9
|
-
Sequence,
|
10
|
-
Type,
|
11
|
-
TypeVar,
|
12
|
-
)
|
1
|
+
from typing import TYPE_CHECKING, Any, Iterator, Optional, Protocol, Sequence, TypeVar
|
13
2
|
|
14
3
|
from numpy import typing as npt
|
15
4
|
|
@@ -26,12 +15,12 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
|
|
26
15
|
class DataIngestor(Protocol):
|
27
16
|
@classmethod
|
28
17
|
def from_sources(
|
29
|
-
cls:
|
18
|
+
cls: type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
|
30
19
|
) -> DataIngestorType:
|
31
20
|
raise NotImplementedError
|
32
21
|
|
33
22
|
@property
|
34
|
-
def data_sources(self) ->
|
23
|
+
def data_sources(self) -> list[data_source.DataSource]:
|
35
24
|
raise NotImplementedError
|
36
25
|
|
37
26
|
def to_batches(
|
@@ -39,7 +28,7 @@ class DataIngestor(Protocol):
|
|
39
28
|
batch_size: int,
|
40
29
|
shuffle: bool = True,
|
41
30
|
drop_last_batch: bool = True,
|
42
|
-
) -> Iterator[
|
31
|
+
) -> Iterator[dict[str, npt.NDArray[Any]]]:
|
43
32
|
raise NotImplementedError
|
44
33
|
|
45
34
|
def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
|
snowflake/ml/data/data_source.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
import dataclasses
|
2
|
-
from typing import
|
2
|
+
from typing import Optional, Union
|
3
3
|
|
4
4
|
|
5
5
|
@dataclasses.dataclass(frozen=True)
|
@@ -17,7 +17,7 @@ class DatasetInfo:
|
|
17
17
|
fully_qualified_name: str
|
18
18
|
version: str
|
19
19
|
url: Optional[str] = None
|
20
|
-
exclude_cols: Optional[
|
20
|
+
exclude_cols: Optional[list[str]] = None
|
21
21
|
|
22
22
|
|
23
23
|
DataSource = Union[DataFrameInfo, DatasetInfo, str]
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Optional
|
2
2
|
|
3
3
|
import fsspec
|
4
4
|
import pyarrow as pa
|
@@ -33,7 +33,7 @@ def _get_dataframe_cursor(session: snowpark.Session, df_info: data_source.DataFr
|
|
33
33
|
|
34
34
|
def get_dataframe_result_batches(
|
35
35
|
session: snowpark.Session, df_info: data_source.DataFrameInfo
|
36
|
-
) ->
|
36
|
+
) -> list[result_batch.ResultBatch]:
|
37
37
|
"""Retrieve the ResultBatches for a given query"""
|
38
38
|
cursor = _get_dataframe_cursor(session, df_info)
|
39
39
|
batches = cursor.get_result_batches()
|
@@ -63,7 +63,7 @@ def get_dataset_filesystem(
|
|
63
63
|
|
64
64
|
def get_dataset_files(
|
65
65
|
session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
|
66
|
-
) ->
|
66
|
+
) -> list[str]:
|
67
67
|
"""Get the list of files in a given Dataset"""
|
68
68
|
if filesystem is None:
|
69
69
|
filesystem = get_dataset_filesystem(session, ds_info)
|