snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py
RENAMED
@@ -23,7 +23,9 @@ from typing import Dict, List, Optional, Tuple
 
 import requests
 
-from snowflake.ml._internal.utils import image_registry_http_client
+from snowflake.ml._internal.container_services.image_registry import (
+    http_client as image_registry_http_client,
+)
 
 # Common HTTP headers
 _CONTENT_LENGTH_HEADER = "content-length"
snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py}
RENAMED
@@ -3,12 +3,14 @@ import logging
 from typing import Dict, Optional, cast
 from urllib.parse import urlunparse
 
+from snowflake.ml._internal.container_services.image_registry import (
+    http_client as image_registry_http_client,
+    imagelib,
+)
 from snowflake.ml._internal.exceptions import (
     error_codes,
     exceptions as snowml_exceptions,
 )
-from snowflake.ml._internal.utils import image_registry_http_client
-from snowflake.ml.model._deploy_client.utils import imagelib
 from snowflake.snowpark import Session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
snowflake/ml/_internal/env_utils.py
@@ -33,7 +33,6 @@ class CONDA_OS(Enum):
 
 _SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
 _NODEFAULTS = "nodefaults"
-_INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION: Optional[bool] = None
 _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
 _SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
 
@@ -267,18 +266,6 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req
     return new_req
 
 
-def _check_runtime_version_column_existence(session: session.Session) -> bool:
-    sql = textwrap.dedent(
-        """
-        SHOW COLUMNS
-        LIKE 'runtime_version'
-        IN TABLE information_schema.packages;
-        """
-    )
-    result = session.sql(sql).count()
-    return result == 1
-
-
 def get_matched_package_versions_in_snowflake_conda_channel(
     req: requirements.Requirement,
     python_version: str = snowml_env.PYTHON_VERSION,
@@ -325,9 +312,9 @@ get_matched_package_versions_in_snowflake_conda_channel(
     return matched_versions
 
 
-def validate_requirements_in_information_schema(
+def get_matched_package_versions_in_information_schema(
     session: session.Session, reqs: List[requirements.Requirement], python_version: str
-) ->
+) -> Dict[str, List[version.Version]]:
     """Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake
     Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might
     exist in the information_schema table but has not yet become available in the Snowflake Conda channel.
@@ -338,42 +325,35 @@ def validate_requirements_in_information_schema(
         python_version: A string of python version where model is run.
 
     Returns:
-        A
+        A Dict, whose key is the package name, and value is a list of versions match the requirements.
     """
-
-
-    if _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION is None:
-        _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION = _check_runtime_version_column_existence(session)
-    ret_list = []
-    reqs_to_request = []
+    ret_dict: Dict[str, List[version.Version]] = {}
+    reqs_to_request: List[requirements.Requirement] = []
     for req in reqs:
-        if req.name
+        if req.name in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE:
+            available_versions = list(
+                sorted(req.specifier.filter(set(_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req.name, []))))
+            )
+            ret_dict[req.name] = available_versions
+        else:
            reqs_to_request.append(req)
+
     if reqs_to_request:
         pkg_names_str = " OR ".join(
             f"package_name = '{req_name}'" for req_name in sorted(req.name for req in reqs_to_request)
         )
-
-
-
-
-
-
-
-
-
-
-
-
-        else:
-            sql = textwrap.dedent(
-                f"""
-                SELECT PACKAGE_NAME, VERSION
-                FROM information_schema.packages
-                WHERE ({pkg_names_str})
-                AND language = 'python';
-                """
-            )
+
+        parsed_python_version = version.Version(python_version)
+        sql = textwrap.dedent(
+            f"""
+            SELECT PACKAGE_NAME, VERSION
+            FROM information_schema.packages
+            WHERE ({pkg_names_str})
+            AND language = 'python'
+            AND (runtime_version = '{parsed_python_version.major}.{parsed_python_version.minor}'
+            OR runtime_version is null);
+            """
+        )
 
     try:
         result = (
@@ -392,14 +372,13 @@ def validate_requirements_in_information_schema(
                 cached_req_ver_list.append(req_ver)
             _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE[req_name] = cached_req_ver_list
         except snowflake.connector.DataError:
-            return
-    for req in
-        available_versions = list(
-
-
-
-
-    return sorted(ret_list)
+            return ret_dict
+    for req in reqs_to_request:
+        available_versions = list(
+            sorted(req.specifier.filter(set(_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req.name, []))))
+        )
+        ret_dict[req.name] = available_versions
+    return ret_dict
 
 
 def save_conda_env_file(
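Note: the old validate_requirements_in_information_schema helper is replaced by get_matched_package_versions_in_information_schema, which returns a dict keyed by package name instead of a flat list. A minimal sketch of consuming the new return shape, assuming an already-created Snowpark session (the connection_parameters dict below is a placeholder, not part of the package):

```python
from packaging import requirements

from snowflake.ml._internal import env_utils
from snowflake.snowpark import Session

session = Session.builder.configs(connection_parameters).create()  # placeholder credentials

matched = env_utils.get_matched_package_versions_in_information_schema(
    session,
    reqs=[requirements.Requirement("xgboost"), requirements.Requirement("scikit-learn>=1.3")],
    python_version="3.10",
)
# Keys are package names; values are the versions from information_schema.packages
# that satisfy each specifier for this Python runtime (or carry no runtime_version).
for name, versions in matched.items():
    print(name, [str(v) for v in versions])
```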
snowflake/ml/_internal/file_utils.py
@@ -362,3 +362,20 @@ download_directory_from_stage(
         wait_exponential_multiplier=100,
         wait_exponential_max=10000,
     )(file_operation.get)(str(stage_file_path), str(local_file_dir), statement_params=statement_params)
+
+
+def open_file(path: str, *args: Any, **kwargs: Any) -> Any:
+    """This function is a wrapper on top of the Python built-in "open" function, with a few added default values
+    to ensure successful execution across different platforms.
+
+    Args:
+        path: file path
+        *args: arguments.
+        **kwargs: key arguments.
+
+    Returns:
+        Open file and return a stream.
+    """
+    kwargs.setdefault("newline", "\n")
+    kwargs.setdefault("encoding", "utf-8")
+    return open(path, *args, **kwargs)
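The new open_file helper only pre-sets newline and encoding defaults on top of the built-in open; a short usage sketch (the file name is illustrative):

```python
from snowflake.ml._internal import file_utils

# Behaves like open("requirements.txt", "w", newline="\n", encoding="utf-8")
# unless the caller overrides either keyword explicitly.
with file_utils.open_file("requirements.txt", "w") as f:
    f.write("xgboost==1.7.6\n")
```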
snowflake/ml/_internal/telemetry.py
@@ -584,3 +584,22 @@ class _SourceTelemetryClient:
         """Send the telemetry data batch immediately."""
         if self._telemetry:
             self._telemetry.send_batch()
+
+
+def get_sproc_statement_params_kwargs(sproc: Callable[..., Any], statement_params: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Get statement_params keyword argument for sproc call.
+
+    Args:
+        sproc: sproc function
+        statement_params: dictionary to be passed as statement params, if possible
+
+    Returns:
+        Keyword arguments dict
+    """
+    sproc_argspec = inspect.getfullargspec(sproc)
+    kwargs = {}
+    if "statement_params" in sproc_argspec.args:
+        kwargs["statement_params"] = statement_params
+
+    return kwargs
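A small sketch of what get_sproc_statement_params_kwargs is for: it forwards statement_params only when the target sproc's Python signature declares that argument, so callers can splat the result into the call. The two sample sprocs below are hypothetical, not from the package:

```python
from typing import Any, Dict

from snowflake.ml._internal import telemetry


def fit_wrapper_sproc(session: Any, stage_name: str, statement_params: Dict[str, Any]) -> str:
    ...  # hypothetical sproc that accepts statement_params


def legacy_sproc(session: Any, stage_name: str) -> str:
    ...  # hypothetical sproc without statement_params


params = {"project": "MLOps", "subproject": "ModelDevelopment"}
print(telemetry.get_sproc_statement_params_kwargs(fit_wrapper_sproc, params))  # {'statement_params': {...}}
print(telemetry.get_sproc_statement_params_kwargs(legacy_sproc, params))       # {}
```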
snowflake/ml/_internal/utils/query_result_checker.py
@@ -60,9 +60,13 @@ def result_dimension_matcher(
     return True
 
 
-def column_name_matcher(
+def column_name_matcher(
+    expected_col_name: str, allow_empty: bool, result: list[snowpark.Row], sql: str | None = None
+) -> bool:
     """Returns true if `expected_col_name` is found. Raise exception otherwise."""
     if not result:
+        if allow_empty:
+            return True
         raise connector.DataError(f"Query Result is empty.{_query_log(sql)}")
     if expected_col_name not in result[0]:
         raise connector.DataError(
@@ -159,16 +163,17 @@ class ResultValidator:
         self._success_matchers.append(partial(result_dimension_matcher, expected_rows, expected_cols))
         return self
 
-    def has_column(self, expected_col_name: str) -> ResultValidator:
+    def has_column(self, expected_col_name: str, allow_empty: bool = False) -> ResultValidator:
         """Validate that the a column with the name `expected_column_name` exists in the result.
 
         Args:
             expected_col_name: Name of the column that is expected to be present in the result (case sensitive).
+            allow_empty: If the check will fail if the result is empty.
 
         Returns:
             ResultValidator object (self)
         """
-        self._success_matchers.append(partial(column_name_matcher, expected_col_name))
+        self._success_matchers.append(partial(column_name_matcher, expected_col_name, allow_empty))
         return self
 
     def has_named_value_match(self, row_idx: int, col_name: str, expected_value: Any) -> ResultValidator:
@@ -224,8 +229,6 @@ class ResultValidator:
         Returns:
             Query result.
         """
-        if len(self._success_matchers) == 0:
-            self._success_matchers = _DEFAULT_MATCHERS
         result = self._get_result()
         for matcher in self._success_matchers:
             assert matcher(result, self._query)
snowflake/ml/_internal/utils/snowflake_env.py
@@ -0,0 +1,95 @@
+import enum
+from typing import Any, Dict, Optional, TypedDict, cast
+
+from packaging import version
+from typing_extensions import Required
+
+from snowflake.ml._internal.utils import query_result_checker
+from snowflake.snowpark import session
+
+
+def get_current_snowflake_version(
+    sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None
+) -> version.Version:
+    """Get Snowflake Version as a version.Version object follow PEP way of versioning, that is to say:
+        "7.44.2 b202312132139364eb71238" to <Version('7.44.2+b202312132139364eb71238')>
+
+    Args:
+        sess: Snowpark Session.
+        statement_params: Statement params. Defaults to None.
+
+    Returns:
+        The version of Snowflake Version.
+    """
+    res = (
+        query_result_checker.SqlResultValidator(
+            sess, "SELECT CURRENT_VERSION() AS CURRENT_VERSION", statement_params=statement_params
+        )
+        .has_dimensions(expected_rows=1, expected_cols=1)
+        .validate()[0]
+    )
+
+    version_str = res.CURRENT_VERSION
+    assert isinstance(version_str, str)
+
+    version_str = "+".join(version_str.split())
+    return version.parse(version_str)
+
+
+class SnowflakeCloudType(enum.Enum):
+    AWS = "aws"
+    AZURE = "azure"
+    GCP = "gcp"
+
+    @classmethod
+    def from_value(cls, value: str) -> "SnowflakeCloudType":
+        assert value
+        for k in cls:
+            if k.value == value.lower():
+                return k
+        else:
+            raise ValueError(f"'{cls.__name__}' enum not found for '{value}'")
+
+
+class SnowflakeRegion(TypedDict):
+    region_group: Required[str]
+    snowflake_region: Required[str]
+    cloud: Required[SnowflakeCloudType]
+    region: Required[str]
+    display_name: Required[str]
+
+
+def get_regions(
+    sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None
+) -> Dict[str, SnowflakeRegion]:
+    res = (
+        query_result_checker.SqlResultValidator(sess, "SHOW REGIONS", statement_params=statement_params)
+        .has_column("region_group")
+        .has_column("snowflake_region")
+        .has_column("cloud")
+        .has_column("region")
+        .has_column("display_name")
+        .validate()
+    )
+    return {
+        f"{r.region_group}.{r.snowflake_region}": SnowflakeRegion(
+            region_group=r.region_group,
+            snowflake_region=r.snowflake_region,
+            cloud=SnowflakeCloudType.from_value(r.cloud),
+            region=r.region,
+            display_name=r.display_name,
+        )
+        for r in res
+    }
+
+
+def get_current_region_id(sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None) -> str:
+    res = (
+        query_result_checker.SqlResultValidator(
+            sess, "SELECT CURRENT_REGION() AS CURRENT_REGION", statement_params=statement_params
+        )
+        .has_dimensions(expected_rows=1, expected_cols=1)
+        .validate()[0]
+    )
+
+    return cast(str, res.CURRENT_REGION)
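A hedged sketch of the new snowflake_env helpers against an open Snowpark session (session creation and credentials are assumed, not part of the diff):

```python
from snowflake.ml._internal.utils import snowflake_env
from snowflake.snowpark import Session

session = Session.builder.configs(connection_parameters).create()  # placeholder credentials

# "7.44.2 b2023..." is normalized into a PEP 440 version such as 7.44.2+b2023...
sf_version = snowflake_env.get_current_snowflake_version(session)
print(sf_version.release, sf_version.local)

# SHOW REGIONS results keyed by "<region_group>.<snowflake_region>".
regions = snowflake_env.get_regions(session)
print(len(regions), snowflake_env.get_current_region_id(session))
```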
snowflake/ml/fileset/parquet_parser.py
@@ -1,4 +1,6 @@
 import collections
+import logging
+import time
 from typing import Any, Deque, Dict, Iterator, List
 
 import fsspec
@@ -83,7 +85,7 @@ class ParquetParser:
         np.random.shuffle(files)
         pa_dataset: ds.Dataset = ds.dataset(files, format="parquet", filesystem=self._fs)
 
-        for rb in pa_dataset
+        for rb in _retryable_batches(pa_dataset, batch_size=self._dataset_batch_size):
             if self._shuffle:
                 rb = rb.take(np.random.permutation(rb.num_rows))
             self._rb_buffer.append(rb)
@@ -138,3 +140,31 @@ def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
         array = column.to_numpy(zero_copy_only=False)
         batch_dict[column_schema.name] = array
     return batch_dict
+
+
+def _retryable_batches(
+    dataset: ds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
+) -> Iterator[pa.RecordBatch]:
+    """Make the Dataset to_batches retryable."""
+    retries = 0
+    current_batch_index = 0
+
+    while True:
+        try:
+            for batch_index, batch in enumerate(dataset.to_batches(batch_size=batch_size)):
+                if batch_index < current_batch_index:
+                    # Skip batches that have already been processed
+                    continue
+
+                yield batch
+                current_batch_index = batch_index + 1
+            # Exit the loop once all batches are processed
+            break
+
+        except Exception as e:
+            if retries < max_retries:
+                retries += 1
+                logging.info(f"Error encountered: {e}. Retrying {retries}/{max_retries}...")
+                time.sleep(delay)
+            else:
+                raise e
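The retry wrapper iterates any pyarrow Dataset and re-requests batches after a failure, skipping indices that were already yielded. A local sketch using a synthetic parquet file (note that _retryable_batches is a private helper, so this is illustrative only):

```python
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

from snowflake.ml.fileset.parquet_parser import _retryable_batches

# Synthetic single-file dataset standing in for the stage-backed FileSet files.
pq.write_table(pa.table({"x": list(range(10))}), "/tmp/example.parquet")
dataset = ds.dataset(["/tmp/example.parquet"], format="parquet")

# On an exception from to_batches(), iteration restarts and already-yielded
# batch indices are skipped, up to max_retries attempts with `delay` seconds between.
for batch in _retryable_batches(dataset, batch_size=4, max_retries=3, delay=1):
    print(batch.num_rows)
```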
snowflake/ml/model/__init__.py
@@ -0,0 +1,6 @@
+from snowflake.ml.model._client.model.model_impl import Model
+from snowflake.ml.model._client.model.model_version_impl import ModelVersion
+from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
+from snowflake.ml.model.models.llm import LLM, LLMOptions
+
+__all__ = ["Model", "ModelVersion", "HuggingFacePipelineModel", "LLM", "LLMOptions"]
snowflake/ml/model/_client/model/model_impl.py
@@ -1,7 +1,9 @@
-from typing import List, Union
+from typing import Dict, List, Optional, Tuple, Union
+
+import pandas as pd
 
 from snowflake.ml._internal import telemetry
-from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml._internal.utils import identifier, sql_identifier
 from snowflake.ml.model._client.model import model_version_impl
 from snowflake.ml.model._client.ops import model_ops
 
@@ -37,10 +39,12 @@ class Model:
 
     @property
     def name(self) -> str:
+        """Return the name of the model that can be used to refer to it in SQL."""
         return self._model_name.identifier()
 
     @property
     def fully_qualified_name(self) -> str:
+        """Return the fully qualified name of the model that can be used to refer to it in SQL."""
         return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)
 
     @property
@@ -49,6 +53,24 @@ class Model:
         subproject=_TELEMETRY_SUBPROJECT,
     )
     def description(self) -> str:
+        """The description for the model. This is an alias of `comment`."""
+        return self.comment
+
+    @description.setter
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+    )
+    def description(self, description: str) -> None:
+        self.comment = description
+
+    @property
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+    )
+    def comment(self) -> str:
+        """The comment to the model."""
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
             subproject=_TELEMETRY_SUBPROJECT,
@@ -58,18 +80,18 @@ class Model:
             statement_params=statement_params,
         )
 
-    @
+    @comment.setter
     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
     )
-    def
+    def comment(self, comment: str) -> None:
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
             subproject=_TELEMETRY_SUBPROJECT,
         )
         return self._model_ops.set_comment(
-            comment=
+            comment=comment,
             model_name=self._model_name,
             statement_params=statement_params,
         )
@@ -80,12 +102,13 @@ class Model:
         subproject=_TELEMETRY_SUBPROJECT,
     )
     def default(self) -> model_version_impl.ModelVersion:
+        """The default version of the model."""
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
             subproject=_TELEMETRY_SUBPROJECT,
             class_name=self.__class__.__name__,
         )
-        default_version_name = self._model_ops.
+        default_version_name = self._model_ops.get_default_version(
             model_name=self._model_name, statement_params=statement_params
         )
         return self.version(default_version_name)
@@ -105,7 +128,7 @@ class Model:
             version_name = sql_identifier.SqlIdentifier(version)
         else:
             version_name = version._version_name
-        self._model_ops.
+        self._model_ops.set_default_version(
             model_name=self._model_name, version_name=version_name, statement_params=statement_params
         )
 
@@ -114,13 +137,14 @@ class Model:
         subproject=_TELEMETRY_SUBPROJECT,
     )
     def version(self, version_name: str) -> model_version_impl.ModelVersion:
-        """
+        """
+        Get a model version object given a version name in the model.
 
         Args:
-            version_name: The name of version
+            version_name: The name of the version.
 
         Raises:
-            ValueError:
+            ValueError: When the requested version does not exist.
 
         Returns:
             The model version object.
@@ -149,11 +173,11 @@ class Model:
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
     )
-    def
-    """
+    def versions(self) -> List[model_version_impl.ModelVersion]:
+        """Get all versions in the model.
 
         Returns:
-            A
+            A list of ModelVersion objects representing all versions in the model.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -172,5 +196,140 @@ class Model:
             for version_name in version_names
         ]
 
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+    )
+    def show_versions(self) -> pd.DataFrame:
+        """Show information about all versions in the model.
+
+        Returns:
+            A Pandas DataFrame showing information about all versions in the model.
+        """
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+        rows = self._model_ops.show_models_or_versions(
+            model_name=self._model_name,
+            statement_params=statement_params,
+        )
+        return pd.DataFrame([row.as_dict() for row in rows])
+
     def delete_version(self, version_name: str) -> None:
         raise NotImplementedError("Deleting version has not been supported yet.")
+
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+    )
+    def show_tags(self) -> Dict[str, str]:
+        """Get a dictionary showing the tag and its value attached to the model.
+
+        Returns:
+            The model version object.
+        """
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+        return self._model_ops.show_tags(model_name=self._model_name, statement_params=statement_params)
+
+    def _parse_tag_name(
+        self,
+        tag_name: str,
+    ) -> Tuple[sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier]:
+        _tag_db, _tag_schema, _tag_name, _ = identifier.parse_schema_level_object_identifier(tag_name)
+        if _tag_db is None:
+            tag_db_id = self._model_ops._model_client._database_name
+        else:
+            tag_db_id = sql_identifier.SqlIdentifier(_tag_db)
+
+        if _tag_schema is None:
+            tag_schema_id = self._model_ops._model_client._schema_name
+        else:
+            tag_schema_id = sql_identifier.SqlIdentifier(_tag_schema)
+
+        if _tag_name is None:
+            raise ValueError(f"Unable parse the tag name `{tag_name}` you input.")
+
+        tag_name_id = sql_identifier.SqlIdentifier(_tag_name)
+
+        return tag_db_id, tag_schema_id, tag_name_id
+
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+    )
+    def get_tag(self, tag_name: str) -> Optional[str]:
+        """Get the value of a tag attached to the model.
+
+        Args:
+            tag_name: The name of the tag, can be fully qualified. If not fully qualified, the database or schema of
+                the model will be used.
+
+        Returns:
+            The tag value as a string if the tag is attached, otherwise None.
+        """
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+        tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
+        return self._model_ops.get_tag_value(
+            model_name=self._model_name,
+            tag_database_name=tag_db_id,
+            tag_schema_name=tag_schema_id,
+            tag_name=tag_name_id,
+            statement_params=statement_params,
+        )
+
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+    )
+    def set_tag(self, tag_name: str, tag_value: str) -> None:
+        """Set the value of a tag, attaching it to the model if not.
+
+        Args:
+            tag_name: The name of the tag, can be fully qualified. If not fully qualified, the database or schema of
+                the model will be used.
+            tag_value: The value of the tag
+        """
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+        tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
+        self._model_ops.set_tag(
+            model_name=self._model_name,
+            tag_database_name=tag_db_id,
+            tag_schema_name=tag_schema_id,
+            tag_name=tag_name_id,
+            tag_value=tag_value,
+            statement_params=statement_params,
+        )
+
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+    )
+    def unset_tag(self, tag_name: str) -> None:
+        """Unset a tag attached to a model.
+
+        Args:
+            tag_name: The name of the tag, can be fully qualified. If not fully qualified, the database or schema of
+                the model will be used.
+        """
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+        tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
+        self._model_ops.unset_tag(
+            model_name=self._model_name,
+            tag_database_name=tag_db_id,
+            tag_schema_name=tag_schema_id,
+            tag_name=tag_name_id,
+            statement_params=statement_params,
+        )
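Taken together, the model_impl.py additions give the registry Model object comment/description properties, version listing via show_versions, and tag management. A hedged end-to-end sketch against the 1.2.1 client API; the registry setup, the previously logged model, and the ENV tag are assumed to already exist:

```python
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

session = Session.builder.configs(connection_parameters).create()  # placeholder credentials
reg = Registry(session=session)

m = reg.get_model("MY_MODEL")               # assumes the model was logged earlier
m.comment = "churn model, retrained weekly"
print(m.description)                        # alias of comment

print(m.show_versions())                    # pandas DataFrame, one row per version
print(m.default)                            # ModelVersion object for the default version

m.set_tag("ENV", "staging")                 # "ENV" is an assumed pre-created tag
print(m.get_tag("ENV"))                     # "staging", or None if not attached
print(m.show_tags())                        # dict of tag name -> value
m.unset_tag("ENV")
```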