snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -1,203 +0,0 @@
|
|
1
|
-
from typing import Dict, Optional, Tuple
|
2
|
-
|
3
|
-
from typing_extensions import TypedDict
|
4
|
-
|
5
|
-
from snowflake import snowpark
|
6
|
-
from snowflake.ml._internal import telemetry
|
7
|
-
from snowflake.snowpark import functions
|
8
|
-
|
9
|
-
_PROJECT = "MLOps"
|
10
|
-
_SUBPROJECT = "Monitor"
|
11
|
-
|
12
|
-
|
13
|
-
class BucketConfig(TypedDict):
|
14
|
-
""" "Options for bucketizing the data."""
|
15
|
-
|
16
|
-
min: int
|
17
|
-
max: int
|
18
|
-
size: int
|
19
|
-
|
20
|
-
|
21
|
-
@telemetry.send_api_usage_telemetry(
|
22
|
-
project=_PROJECT,
|
23
|
-
subproject=_SUBPROJECT,
|
24
|
-
)
|
25
|
-
@snowpark._internal.utils.private_preview(version="1.0.10") # TODO: update versions when release
|
26
|
-
def compare_udfs_outputs(
|
27
|
-
base_udf_name: str,
|
28
|
-
test_udf_name: str,
|
29
|
-
input_data_df: snowpark.DataFrame,
|
30
|
-
bucket_config: Optional[BucketConfig] = None,
|
31
|
-
) -> snowpark.DataFrame:
|
32
|
-
"""Compare outputs of 2 UDFs. Outputs are bucketized the based on bucketConfig.
|
33
|
-
This is useful when someone retrain a Model and deploy as UDF to compare against earlier UDF as ground truth.
|
34
|
-
NOTE: Only supports UDFs with single Column output.
|
35
|
-
|
36
|
-
Args:
|
37
|
-
base_udf_name: used as control ground truth UDF.
|
38
|
-
test_udf_name: output of this UDF is compared against that of `base_udf`.
|
39
|
-
input_data_df: Input data used for computing metric.
|
40
|
-
bucket_config: must have the kv as {"min":xx, "max":xx, "size"}, keys in lowercase; it's width_bucket
|
41
|
-
Sqloperator's config, using https://docs.snowflake.com/en/sql-reference/functions/width_bucket.
|
42
|
-
|
43
|
-
Returns:
|
44
|
-
snowpark.DataFrame.
|
45
|
-
"BASEUDF" is base_udf's bucketized output, "TESTUDF" is test_udf's bucketized output,
|
46
|
-
"""
|
47
|
-
if bucket_config:
|
48
|
-
assert len(bucket_config) == 3
|
49
|
-
assert "min" in bucket_config and "max" in bucket_config and "size" in bucket_config
|
50
|
-
|
51
|
-
argStr = ",".join(input_data_df.columns)
|
52
|
-
query1Str = _get_udf_query_str("BASEUDF", f"{base_udf_name}({argStr})", input_data_df, bucket_config)
|
53
|
-
query2Str = _get_udf_query_str("TESTUDF", f"{test_udf_name}({argStr})", input_data_df, bucket_config)
|
54
|
-
|
55
|
-
if bucket_config:
|
56
|
-
finalStr = (
|
57
|
-
"select A.bucket, BASEUDF, TESTUDF \n from ({}) as A \n join ({}) as B \n on A.bucket=B.bucket".format(
|
58
|
-
query1Str, query2Str
|
59
|
-
)
|
60
|
-
)
|
61
|
-
else: # don't bucket at all
|
62
|
-
finalStr = "select {},{} \n from ({})".format(query1Str, query2Str, input_data_df.queries["queries"][0])
|
63
|
-
|
64
|
-
assert input_data_df._session is not None
|
65
|
-
return input_data_df._session.sql(finalStr)
|
66
|
-
|
67
|
-
|
68
|
-
@telemetry.send_api_usage_telemetry(
|
69
|
-
project=_PROJECT,
|
70
|
-
subproject=_SUBPROJECT,
|
71
|
-
)
|
72
|
-
@snowpark._internal.utils.private_preview(version="1.0.10") # TODO: update versions when release
|
73
|
-
def get_basic_stats(df: snowpark.DataFrame) -> Tuple[Dict[str, int], Dict[str, int]]:
|
74
|
-
"""Get basic stats of 2 Columns
|
75
|
-
Note this isn't public API. Only support min, max, stddev, HLL--cardinality estimate
|
76
|
-
|
77
|
-
Args:
|
78
|
-
df: input Snowpark Dataframe, must have 2 and only 2 columns
|
79
|
-
|
80
|
-
Returns:
|
81
|
-
2 Dict for 2 columns' stats
|
82
|
-
"""
|
83
|
-
projStr = ""
|
84
|
-
stats = ["MIN", "MAX", "STDDEV", "HLL"]
|
85
|
-
assert len(df.columns) == 2
|
86
|
-
for colName in df.columns:
|
87
|
-
for stat in stats:
|
88
|
-
projStr += f"{stat}({colName}) as {colName}_{stat},"
|
89
|
-
finalStr = "select {} \n from ({})".format(projStr[:-1], df.queries["queries"][0])
|
90
|
-
assert df._session is not None
|
91
|
-
resDf = df._session.sql(finalStr).to_pandas()
|
92
|
-
d1 = {}
|
93
|
-
col1 = df.columns[0]
|
94
|
-
d2 = {}
|
95
|
-
col2 = df.columns[1]
|
96
|
-
for stat in stats:
|
97
|
-
d1[stat] = resDf.iloc[0][f"{col1}_{stat}"]
|
98
|
-
d2[stat] = resDf.iloc[0][f"{col2}_{stat}"]
|
99
|
-
return d1, d2
|
100
|
-
|
101
|
-
|
102
|
-
@telemetry.send_api_usage_telemetry(
|
103
|
-
project=_PROJECT,
|
104
|
-
subproject=_SUBPROJECT,
|
105
|
-
)
|
106
|
-
@snowpark._internal.utils.private_preview(version="1.0.10") # TODO: update versions when release
|
107
|
-
def jensenshannon(df1: snowpark.DataFrame, colname1: str, df2: snowpark.DataFrame, colname2: str) -> float:
|
108
|
-
"""
|
109
|
-
Similar to scipy implementation:
|
110
|
-
https://github.com/scipy/scipy/blob/e4dec2c5993faa381bb4f76dce551d0d79734f8f/scipy/spatial/distance.py#L1174
|
111
|
-
It's server solution, all computing being in Snowflake warehouse, so will be significantly faster than client.
|
112
|
-
|
113
|
-
Args:
|
114
|
-
df1: 1st Snowpark Dataframe;
|
115
|
-
colname1: the col to be selected in df1
|
116
|
-
df2: 2nd Snowpark Dataframe;
|
117
|
-
colname2: the col to be selected in df2
|
118
|
-
Supported data Tyte: any data type that Snowflake supports, including VARIANT, OBJECT...etc.
|
119
|
-
|
120
|
-
Returns:
|
121
|
-
a jensenshannon value
|
122
|
-
"""
|
123
|
-
df1 = df1.select(colname1)
|
124
|
-
df1 = (
|
125
|
-
df1.group_by(colname1)
|
126
|
-
.agg(functions.count(colname1).alias("c1"))
|
127
|
-
.select(functions.col(colname1).alias("d1"), "c1")
|
128
|
-
)
|
129
|
-
df2 = df2.select(colname2)
|
130
|
-
df2 = (
|
131
|
-
df2.group_by(colname2)
|
132
|
-
.agg(functions.count(colname2).alias("c2"))
|
133
|
-
.select(functions.col(colname2).alias("d2"), "c2")
|
134
|
-
)
|
135
|
-
|
136
|
-
dfsum = df1.select("c1").agg(functions.sum("c1").alias("SUM1"))
|
137
|
-
sum1 = dfsum.collect()[0].as_dict()["SUM1"]
|
138
|
-
dfsum = df2.select("c2").agg(functions.sum("c2").alias("SUM2"))
|
139
|
-
sum2 = dfsum.collect()[0].as_dict()["SUM2"]
|
140
|
-
|
141
|
-
df1 = df1.select("d1", functions.sql_expr("c1 / " + str(sum1)).alias("p"))
|
142
|
-
minp = df1.select(functions.min("P").alias("MINP")).collect()[0].as_dict()["MINP"]
|
143
|
-
df2 = df2.select("d2", functions.sql_expr("c2 / " + str(sum2)).alias("q"))
|
144
|
-
minq = df2.select(functions.min("Q").alias("MINQ")).collect()[0].as_dict()["MINQ"]
|
145
|
-
|
146
|
-
DECAY_FACTOR = 0.5
|
147
|
-
df = df1.join(df2, df1.d1 == df2.d2, "fullouter").select(
|
148
|
-
"d1",
|
149
|
-
"d2",
|
150
|
-
functions.sql_expr(
|
151
|
-
"""
|
152
|
-
CASE
|
153
|
-
WHEN p is NULL THEN {}*{}
|
154
|
-
ELSE p
|
155
|
-
END
|
156
|
-
""".format(
|
157
|
-
minp, DECAY_FACTOR
|
158
|
-
)
|
159
|
-
).alias("p"),
|
160
|
-
functions.sql_expr(
|
161
|
-
"""
|
162
|
-
CASE
|
163
|
-
WHEN q is NULL THEN {}*{}
|
164
|
-
ELSE q
|
165
|
-
END
|
166
|
-
""".format(
|
167
|
-
minq, DECAY_FACTOR
|
168
|
-
)
|
169
|
-
).alias("q"),
|
170
|
-
)
|
171
|
-
|
172
|
-
df = df.select("p", "q", functions.sql_expr("(p+q)/2.0").alias("m"))
|
173
|
-
df = df.select(
|
174
|
-
functions.sql_expr(
|
175
|
-
"""
|
176
|
-
CASE
|
177
|
-
WHEN p > 0 AND m > 0 THEN p * LOG(2, p/m)
|
178
|
-
ELSE 0
|
179
|
-
END
|
180
|
-
"""
|
181
|
-
).alias("left"),
|
182
|
-
functions.sql_expr(
|
183
|
-
"""
|
184
|
-
CASE
|
185
|
-
WHEN q > 0 AND m > 0 THEN q * LOG(2, q/m)
|
186
|
-
ELSE 0
|
187
|
-
END
|
188
|
-
"""
|
189
|
-
).alias("right"),
|
190
|
-
)
|
191
|
-
resdf = df.select(functions.sql_expr("sqrt((sum(left) + sum(right)) / 2.0)").alias("JS"))
|
192
|
-
return float(resdf.collect()[0].as_dict()["JS"])
|
193
|
-
|
194
|
-
|
195
|
-
def _get_udf_query_str(
|
196
|
-
name: str, col: str, df: snowpark.DataFrame, bucket_config: Optional[BucketConfig] = None
|
197
|
-
) -> str:
|
198
|
-
if bucket_config:
|
199
|
-
return "select count(1) as {}, width_bucket({}, {}, {}, {}) bucket from ({}) group by bucket".format(
|
200
|
-
name, col, bucket_config["min"], bucket_config["max"], bucket_config["size"], df.queries["queries"][0]
|
201
|
-
)
|
202
|
-
else: # don't bucket at all
|
203
|
-
return f"{col} as {name}"
|
@@ -1,142 +0,0 @@
|
|
1
|
-
from typing import Any, Dict, List, Tuple
|
2
|
-
|
3
|
-
from snowflake import snowpark
|
4
|
-
from snowflake.ml._internal.utils import identifier, query_result_checker, table_manager
|
5
|
-
|
6
|
-
# THIS FILE CONTAINS INITIAL REGISTRY SCHEMA.
|
7
|
-
# !!!!!!! WARNING !!!!!!!
|
8
|
-
# Please do not modify initial schema and use schema evolution mechanism in SchemaVersionManager to change the schema.
|
9
|
-
# If you are touching this file, make sure you understand what you are doing.
|
10
|
-
|
11
|
-
_INITIAL_VERSION: int = 0
|
12
|
-
|
13
|
-
_MODELS_TABLE_NAME: str = "_SYSTEM_REGISTRY_MODELS"
|
14
|
-
_METADATA_TABLE_NAME: str = "_SYSTEM_REGISTRY_METADATA"
|
15
|
-
_DEPLOYMENT_TABLE_NAME: str = "_SYSTEM_REGISTRY_DEPLOYMENTS"
|
16
|
-
_ARTIFACT_TABLE_NAME: str = "_SYSTEM_REGISTRY_ARTIFACTS"
|
17
|
-
|
18
|
-
_INITIAL_REGISTRY_TABLE_SCHEMA: List[Tuple[str, str]] = [
|
19
|
-
("CREATION_CONTEXT", "VARCHAR"),
|
20
|
-
("CREATION_ENVIRONMENT_SPEC", "OBJECT"),
|
21
|
-
("CREATION_ROLE", "VARCHAR"),
|
22
|
-
("CREATION_TIME", "TIMESTAMP_TZ"),
|
23
|
-
("ID", "VARCHAR PRIMARY KEY RELY"),
|
24
|
-
("INPUT_SPEC", "OBJECT"),
|
25
|
-
("NAME", "VARCHAR"),
|
26
|
-
("OUTPUT_SPEC", "OBJECT"),
|
27
|
-
("RUNTIME_ENVIRONMENT_SPEC", "OBJECT"),
|
28
|
-
("TRAINING_DATASET_ID", "VARCHAR"),
|
29
|
-
("TYPE", "VARCHAR"),
|
30
|
-
("URI", "VARCHAR"),
|
31
|
-
("VERSION", "VARCHAR"),
|
32
|
-
]
|
33
|
-
|
34
|
-
_INITIAL_METADATA_TABLE_SCHEMA: List[Tuple[str, str]] = [
|
35
|
-
("ATTRIBUTE_NAME", "VARCHAR"),
|
36
|
-
("EVENT_ID", "VARCHAR UNIQUE NOT NULL"),
|
37
|
-
("EVENT_TIMESTAMP", "TIMESTAMP_TZ"),
|
38
|
-
("MODEL_ID", "VARCHAR FOREIGN KEY REFERENCES {registry_table_name}(ID) RELY"),
|
39
|
-
("OPERATION", "VARCHAR"),
|
40
|
-
("ROLE", "VARCHAR"),
|
41
|
-
("SEQUENCE_ID", "BIGINT AUTOINCREMENT START 0 INCREMENT 1 PRIMARY KEY"),
|
42
|
-
("VALUE", "OBJECT"),
|
43
|
-
]
|
44
|
-
|
45
|
-
_INITIAL_DEPLOYMENTS_TABLE_SCHEMA: List[Tuple[str, str]] = [
|
46
|
-
("CREATION_TIME", "TIMESTAMP_TZ"),
|
47
|
-
("MODEL_ID", "VARCHAR FOREIGN KEY REFERENCES {registry_table_name}(ID) RELY"),
|
48
|
-
("DEPLOYMENT_NAME", "VARCHAR"),
|
49
|
-
("OPTIONS", "VARIANT"),
|
50
|
-
("TARGET_PLATFORM", "VARCHAR"),
|
51
|
-
("ROLE", "VARCHAR"),
|
52
|
-
("STAGE_PATH", "VARCHAR"),
|
53
|
-
("SIGNATURE", "VARIANT"),
|
54
|
-
("TARGET_METHOD", "VARCHAR"),
|
55
|
-
]
|
56
|
-
|
57
|
-
_INITIAL_ARTIFACT_TABLE_SCHEMA: List[Tuple[str, str]] = [
|
58
|
-
("ID", "VARCHAR"),
|
59
|
-
("TYPE", "VARCHAR"),
|
60
|
-
("NAME", "VARCHAR"),
|
61
|
-
("VERSION", "VARCHAR"),
|
62
|
-
("CREATION_ROLE", "VARCHAR"),
|
63
|
-
("CREATION_TIME", "TIMESTAMP_TZ"),
|
64
|
-
("ARTIFACT_SPEC", "OBJECT"),
|
65
|
-
# Below is out-of-line constraints of Snowflake table.
|
66
|
-
# See https://docs.snowflake.com/en/sql-reference/sql/create-table
|
67
|
-
("PRIMARY KEY", "(ID, TYPE) RELY"),
|
68
|
-
]
|
69
|
-
|
70
|
-
_INITIAL_TABLE_SCHEMAS = {
|
71
|
-
_MODELS_TABLE_NAME: _INITIAL_REGISTRY_TABLE_SCHEMA,
|
72
|
-
_METADATA_TABLE_NAME: _INITIAL_METADATA_TABLE_SCHEMA,
|
73
|
-
_DEPLOYMENT_TABLE_NAME: _INITIAL_DEPLOYMENTS_TABLE_SCHEMA,
|
74
|
-
_ARTIFACT_TABLE_NAME: _INITIAL_ARTIFACT_TABLE_SCHEMA,
|
75
|
-
}
|
76
|
-
|
77
|
-
|
78
|
-
def create_initial_registry_tables(
|
79
|
-
session: snowpark.Session,
|
80
|
-
database_name: str,
|
81
|
-
schema_name: str,
|
82
|
-
statement_params: Dict[str, Any],
|
83
|
-
) -> None:
|
84
|
-
"""Creates initial set of tables for registry. This is the legacy schema from which schema evolution is supported.
|
85
|
-
|
86
|
-
Args:
|
87
|
-
session: Active session to create tables.
|
88
|
-
database_name: Name of database in which tables will be created.
|
89
|
-
schema_name: Name of schema in which tables will be created.
|
90
|
-
statement_params: Statement parameters for telemetry tracking.
|
91
|
-
"""
|
92
|
-
model_table_full_path = table_manager.get_fully_qualified_table_name(database_name, schema_name, _MODELS_TABLE_NAME)
|
93
|
-
|
94
|
-
for table_name, schema_template in _INITIAL_TABLE_SCHEMAS.items():
|
95
|
-
table_schema = [(k, v.format(registry_table_name=model_table_full_path)) for k, v in schema_template]
|
96
|
-
table_manager.create_single_table(
|
97
|
-
session=session,
|
98
|
-
database_name=database_name,
|
99
|
-
schema_name=schema_name,
|
100
|
-
table_name=table_name,
|
101
|
-
table_schema=table_schema,
|
102
|
-
statement_params=statement_params,
|
103
|
-
)
|
104
|
-
|
105
|
-
|
106
|
-
def check_access(session: snowpark.Session, database_name: str, schema_name: str) -> None:
|
107
|
-
"""Check that the required tables exist and are accessible by the current role.
|
108
|
-
|
109
|
-
Args:
|
110
|
-
session: Active session to execution SQL queries.
|
111
|
-
database_name: Name of database where schema tables live.
|
112
|
-
schema_name: Name of schema where schema tables live.
|
113
|
-
"""
|
114
|
-
query_result_checker.SqlResultValidator(
|
115
|
-
session,
|
116
|
-
query=f"SHOW DATABASES LIKE '{identifier.get_unescaped_names(database_name)}'",
|
117
|
-
).has_dimensions(expected_rows=1).validate()
|
118
|
-
|
119
|
-
query_result_checker.SqlResultValidator(
|
120
|
-
session,
|
121
|
-
query=f"SHOW SCHEMAS LIKE '{identifier.get_unescaped_names(schema_name)}' IN DATABASE {database_name}",
|
122
|
-
).has_dimensions(expected_rows=1).validate()
|
123
|
-
|
124
|
-
full_qualified_schema_name = table_manager.get_fully_qualified_schema_name(database_name, schema_name)
|
125
|
-
|
126
|
-
table_manager.validate_table_exist(
|
127
|
-
session,
|
128
|
-
identifier.get_unescaped_names(_MODELS_TABLE_NAME),
|
129
|
-
full_qualified_schema_name,
|
130
|
-
)
|
131
|
-
table_manager.validate_table_exist(
|
132
|
-
session,
|
133
|
-
identifier.get_unescaped_names(_METADATA_TABLE_NAME),
|
134
|
-
full_qualified_schema_name,
|
135
|
-
)
|
136
|
-
table_manager.validate_table_exist(
|
137
|
-
session,
|
138
|
-
identifier.get_unescaped_names(_DEPLOYMENT_TABLE_NAME),
|
139
|
-
full_qualified_schema_name,
|
140
|
-
)
|
141
|
-
|
142
|
-
# TODO(zzhu): Also check validity of views.
|
snowflake/ml/registry/_schema.py
DELETED
@@ -1,82 +0,0 @@
|
|
1
|
-
from typing import Dict, List, Tuple, Type
|
2
|
-
|
3
|
-
from snowflake.ml.registry import _initial_schema, _schema_upgrade_plans
|
4
|
-
|
5
|
-
# BUMP THIS VERSION WHENEVER YOU CHANGE ANY SCHEMA TABLES.
|
6
|
-
# ALSO UPDATE SCHEMA UPGRADE PLANS.
|
7
|
-
_CURRENT_SCHEMA_VERSION = 3
|
8
|
-
|
9
|
-
_REGISTRY_TABLE_SCHEMA: List[Tuple[str, str]] = [
|
10
|
-
("CREATION_CONTEXT", "VARCHAR"),
|
11
|
-
("CREATION_ENVIRONMENT_SPEC", "OBJECT"),
|
12
|
-
("CREATION_ROLE", "VARCHAR"),
|
13
|
-
("CREATION_TIME", "TIMESTAMP_TZ"),
|
14
|
-
("ID", "VARCHAR PRIMARY KEY RELY"),
|
15
|
-
("INPUT_SPEC", "OBJECT"),
|
16
|
-
("NAME", "VARCHAR"),
|
17
|
-
("OUTPUT_SPEC", "OBJECT"),
|
18
|
-
("RUNTIME_ENVIRONMENT_SPEC", "OBJECT"),
|
19
|
-
("ARTIFACT_IDS", "ARRAY"),
|
20
|
-
("TYPE", "VARCHAR"),
|
21
|
-
("URI", "VARCHAR"),
|
22
|
-
("VERSION", "VARCHAR"),
|
23
|
-
]
|
24
|
-
|
25
|
-
_METADATA_TABLE_SCHEMA: List[Tuple[str, str]] = [
|
26
|
-
("ATTRIBUTE_NAME", "VARCHAR"),
|
27
|
-
("EVENT_ID", "VARCHAR UNIQUE NOT NULL"),
|
28
|
-
("EVENT_TIMESTAMP", "TIMESTAMP_TZ"),
|
29
|
-
("MODEL_ID", "VARCHAR FOREIGN KEY REFERENCES {registry_table_name}(ID) RELY"),
|
30
|
-
("OPERATION", "VARCHAR"),
|
31
|
-
("ROLE", "VARCHAR"),
|
32
|
-
("SEQUENCE_ID", "BIGINT AUTOINCREMENT START 0 INCREMENT 1 PRIMARY KEY"),
|
33
|
-
("VALUE", "OBJECT"),
|
34
|
-
]
|
35
|
-
|
36
|
-
_DEPLOYMENTS_TABLE_SCHEMA: List[Tuple[str, str]] = [
|
37
|
-
("CREATION_TIME", "TIMESTAMP_TZ"),
|
38
|
-
("MODEL_ID", "VARCHAR FOREIGN KEY REFERENCES {registry_table_name}(ID) RELY"),
|
39
|
-
("DEPLOYMENT_NAME", "VARCHAR"),
|
40
|
-
("OPTIONS", "VARIANT"),
|
41
|
-
("TARGET_PLATFORM", "VARCHAR"),
|
42
|
-
("ROLE", "VARCHAR"),
|
43
|
-
("STAGE_PATH", "VARCHAR"),
|
44
|
-
("SIGNATURE", "VARIANT"),
|
45
|
-
("TARGET_METHOD", "VARCHAR"),
|
46
|
-
]
|
47
|
-
|
48
|
-
_ARTIFACT_TABLE_SCHEMA: List[Tuple[str, str]] = [
|
49
|
-
("ID", "VARCHAR"),
|
50
|
-
("TYPE", "VARCHAR"),
|
51
|
-
("NAME", "VARCHAR"),
|
52
|
-
("VERSION", "VARCHAR"),
|
53
|
-
("CREATION_ROLE", "VARCHAR"),
|
54
|
-
("CREATION_TIME", "TIMESTAMP_TZ"),
|
55
|
-
("ARTIFACT_SPEC", "VARCHAR"),
|
56
|
-
# Below is out-of-line constraints of Snowflake table.
|
57
|
-
# See https://docs.snowflake.com/en/sql-reference/sql/create-table
|
58
|
-
("PRIMARY KEY", "(ID, TYPE) RELY"),
|
59
|
-
]
|
60
|
-
|
61
|
-
# Note, one can add/remove tables from this tuple as well. As long as correct schema update process is followed.
|
62
|
-
# In case of a new table, they should not be defined in _initial_schema.
|
63
|
-
_CURRENT_TABLE_SCHEMAS = {
|
64
|
-
_initial_schema._MODELS_TABLE_NAME: _REGISTRY_TABLE_SCHEMA,
|
65
|
-
_initial_schema._METADATA_TABLE_NAME: _METADATA_TABLE_SCHEMA,
|
66
|
-
_initial_schema._DEPLOYMENT_TABLE_NAME: _DEPLOYMENTS_TABLE_SCHEMA,
|
67
|
-
_initial_schema._ARTIFACT_TABLE_NAME: _ARTIFACT_TABLE_SCHEMA,
|
68
|
-
}
|
69
|
-
|
70
|
-
|
71
|
-
_SCHEMA_UPGRADE_PLANS: Dict[int, Type[_schema_upgrade_plans.BaseSchemaUpgradePlans]] = {
|
72
|
-
# Currently _CURRENT_SCHEMA_VERSION == _initial_schema._INITIAL_VERSION, so no entry.
|
73
|
-
# But if schema evolves it must contain:
|
74
|
-
# Key = a version number
|
75
|
-
# Value = a subclass of _schema_upgrades.BaseSchemaUpgrade
|
76
|
-
# NOTE, all version from _INITIAL_VERSION + 1 till _CURRENT_SCHEMA_VERSION must exists.
|
77
|
-
1: _schema_upgrade_plans.AddTrainingDatasetIdIfNotExists,
|
78
|
-
2: _schema_upgrade_plans.ReplaceTrainingDatasetIdWithArtifactIds,
|
79
|
-
3: _schema_upgrade_plans.ChangeArtifactSpecFromObjectToVarchar,
|
80
|
-
}
|
81
|
-
|
82
|
-
assert len(_SCHEMA_UPGRADE_PLANS) == _CURRENT_SCHEMA_VERSION - _initial_schema._INITIAL_VERSION
|
@@ -1,116 +0,0 @@
|
|
1
|
-
from abc import ABC, abstractmethod
|
2
|
-
from typing import Any, Dict, Optional
|
3
|
-
|
4
|
-
from snowflake import snowpark
|
5
|
-
from snowflake.ml._internal.utils import table_manager
|
6
|
-
from snowflake.ml.registry import _initial_schema
|
7
|
-
|
8
|
-
|
9
|
-
class BaseSchemaUpgradePlans(ABC):
|
10
|
-
"""Abstract Class for specifying schema upgrades for registry."""
|
11
|
-
|
12
|
-
def __init__(
|
13
|
-
self,
|
14
|
-
session: snowpark.Session,
|
15
|
-
database_name: str,
|
16
|
-
schema_name: str,
|
17
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
18
|
-
) -> None:
|
19
|
-
self._session = session
|
20
|
-
self._database = database_name
|
21
|
-
self._schema = schema_name
|
22
|
-
self._statement_params = statement_params
|
23
|
-
|
24
|
-
@abstractmethod
|
25
|
-
def upgrade(self) -> None:
|
26
|
-
"""Convert schema from previous version to `_current_version`."""
|
27
|
-
pass
|
28
|
-
|
29
|
-
|
30
|
-
class AddTrainingDatasetIdIfNotExists(BaseSchemaUpgradePlans):
|
31
|
-
"""Add Column TRAINING_DATASET_ID in registry schema table."""
|
32
|
-
|
33
|
-
def __init__(
|
34
|
-
self,
|
35
|
-
session: snowpark.Session,
|
36
|
-
database_name: str,
|
37
|
-
schema_name: str,
|
38
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
39
|
-
) -> None:
|
40
|
-
super().__init__(session, database_name, schema_name, statement_params)
|
41
|
-
|
42
|
-
def upgrade(self) -> None:
|
43
|
-
full_schema_path = f"{self._database}.{self._schema}"
|
44
|
-
table_schema_dict = table_manager.get_table_schema(
|
45
|
-
self._session, _initial_schema._MODELS_TABLE_NAME, full_schema_path
|
46
|
-
)
|
47
|
-
new_column = "TRAINING_DATASET_ID"
|
48
|
-
if new_column not in table_schema_dict:
|
49
|
-
self._session.sql(
|
50
|
-
f"""ALTER TABLE {self._database}.{self._schema}.{_initial_schema._MODELS_TABLE_NAME}
|
51
|
-
ADD COLUMN {new_column} VARCHAR
|
52
|
-
"""
|
53
|
-
).collect(statement_params=self._statement_params)
|
54
|
-
|
55
|
-
|
56
|
-
class ReplaceTrainingDatasetIdWithArtifactIds(BaseSchemaUpgradePlans):
|
57
|
-
"""Drop column `TRAINING_DATASET_ID`, add `ARTIFACT_IDS`."""
|
58
|
-
|
59
|
-
def __init__(
|
60
|
-
self,
|
61
|
-
session: snowpark.Session,
|
62
|
-
database_name: str,
|
63
|
-
schema_name: str,
|
64
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
65
|
-
) -> None:
|
66
|
-
super().__init__(session, database_name, schema_name, statement_params)
|
67
|
-
|
68
|
-
def upgrade(self) -> None:
|
69
|
-
full_schema_path = f"{self._database}.{self._schema}"
|
70
|
-
old_column = "TRAINING_DATASET_ID"
|
71
|
-
self._session.sql(
|
72
|
-
f"""ALTER TABLE {full_schema_path}.{_initial_schema._MODELS_TABLE_NAME}
|
73
|
-
DROP COLUMN {old_column}
|
74
|
-
"""
|
75
|
-
).collect(statement_params=self._statement_params)
|
76
|
-
|
77
|
-
new_column = "ARTIFACT_IDS"
|
78
|
-
self._session.sql(
|
79
|
-
f"""ALTER TABLE {full_schema_path}.{_initial_schema._MODELS_TABLE_NAME}
|
80
|
-
ADD COLUMN {new_column} ARRAY
|
81
|
-
"""
|
82
|
-
).collect(statement_params=self._statement_params)
|
83
|
-
|
84
|
-
|
85
|
-
class ChangeArtifactSpecFromObjectToVarchar(BaseSchemaUpgradePlans):
|
86
|
-
"""Change artifact spec type from object to varchar. It's fine to drop the column as it's empty."""
|
87
|
-
|
88
|
-
def __init__(
|
89
|
-
self,
|
90
|
-
session: snowpark.Session,
|
91
|
-
database_name: str,
|
92
|
-
schema_name: str,
|
93
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
94
|
-
) -> None:
|
95
|
-
super().__init__(session, database_name, schema_name, statement_params)
|
96
|
-
|
97
|
-
def upgrade(self) -> None:
|
98
|
-
full_schema_path = f"{self._database}.{self._schema}"
|
99
|
-
update_col = "ARTIFACT_SPEC"
|
100
|
-
self._session.sql(
|
101
|
-
f"""ALTER TABLE {full_schema_path}.{_initial_schema._ARTIFACT_TABLE_NAME}
|
102
|
-
DROP COLUMN {update_col}
|
103
|
-
"""
|
104
|
-
).collect(statement_params=self._statement_params)
|
105
|
-
|
106
|
-
self._session.sql(
|
107
|
-
f"""ALTER TABLE {full_schema_path}.{_initial_schema._ARTIFACT_TABLE_NAME}
|
108
|
-
ADD COLUMN {update_col} VARCHAR
|
109
|
-
"""
|
110
|
-
).collect(statement_params=self._statement_params)
|
111
|
-
|
112
|
-
self._session.sql(
|
113
|
-
f"""COMMENT ON COLUMN {full_schema_path}.{_initial_schema._ARTIFACT_TABLE_NAME}.{update_col} IS
|
114
|
-
'This column is VARCHAR but supposed to store a valid JSON object'
|
115
|
-
"""
|
116
|
-
).collect(statement_params=self._statement_params)
|