snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_sentiment.py +7 -4
- snowflake/ml/_internal/env_utils.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
- snowflake/ml/dataset/__init__.py +2 -1
- snowflake/ml/dataset/dataset.py +4 -3
- snowflake/ml/dataset/dataset_reader.py +5 -8
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +283 -0
- snowflake/ml/feature_store/feature_store.py +160 -100
- snowflake/ml/feature_store/feature_view.py +30 -19
- snowflake/ml/fileset/embedded_stage_fs.py +15 -12
- snowflake/ml/fileset/snowfs.py +2 -30
- snowflake/ml/fileset/stage_fs.py +25 -7
- snowflake/ml/model/_client/model/model_impl.py +46 -39
- snowflake/ml/model/_client/model/model_version_impl.py +24 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +174 -16
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +32 -39
- snowflake/ml/model/_client/sql/model_version.py +111 -42
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_model_composer/model_composer.py +8 -4
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +90 -142
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +159 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +8 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +8 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +8 -1
- snowflake/ml/modeling/cluster/birch.py +8 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +8 -1
- snowflake/ml/modeling/cluster/dbscan.py +8 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +8 -1
- snowflake/ml/modeling/cluster/k_means.py +8 -1
- snowflake/ml/modeling/cluster/mean_shift.py +8 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +8 -1
- snowflake/ml/modeling/cluster/optics.py +8 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +8 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +8 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +8 -1
- snowflake/ml/modeling/compose/column_transformer.py +8 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +8 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +8 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +8 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +8 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +8 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +8 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +8 -1
- snowflake/ml/modeling/covariance/oas.py +8 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +8 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +8 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +8 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +8 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +8 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +8 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +8 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +8 -1
- snowflake/ml/modeling/decomposition/pca.py +8 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +8 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +8 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +8 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +8 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +8 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +8 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +8 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +8 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +8 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +8 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +8 -1
- snowflake/ml/modeling/framework/base.py +4 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +8 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +8 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +8 -1
- snowflake/ml/modeling/impute/knn_imputer.py +8 -1
- snowflake/ml/modeling/impute/missing_indicator.py +8 -1
- snowflake/ml/modeling/impute/simple_imputer.py +21 -2
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +8 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +8 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +8 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +8 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +8 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +8 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +8 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +8 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +8 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +8 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +8 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/lars.py +8 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +8 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +8 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +8 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +8 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +8 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/perceptron.py +8 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ridge.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +8 -1
- snowflake/ml/modeling/manifold/isomap.py +8 -1
- snowflake/ml/modeling/manifold/mds.py +8 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +8 -1
- snowflake/ml/modeling/manifold/tsne.py +8 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +8 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +8 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +8 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +8 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +8 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +8 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +8 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +8 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +8 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +8 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +8 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +8 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +8 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +8 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +8 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +8 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +8 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +8 -1
- snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +27 -7
- snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +8 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +8 -1
- snowflake/ml/modeling/svm/linear_svc.py +8 -1
- snowflake/ml/modeling/svm/linear_svr.py +8 -1
- snowflake/ml/modeling/svm/nu_svc.py +8 -1
- snowflake/ml/modeling/svm/nu_svr.py +8 -1
- snowflake/ml/modeling/svm/svc.py +8 -1
- snowflake/ml/modeling/svm/svr.py +8 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +8 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +8 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +8 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +8 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +8 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +8 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +8 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +8 -1
- snowflake/ml/registry/_manager/model_manager.py +95 -8
- snowflake/ml/registry/registry.py +10 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/METADATA +66 -10
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/RECORD +196 -192
- snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/top_level.txt +0 -0
snowflake/ml/fileset/snowfs.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1
1
|
import collections
|
2
2
|
import logging
|
3
3
|
import re
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Optional
|
5
5
|
|
6
6
|
import fsspec
|
7
|
-
import packaging.version as pkg_version
|
8
7
|
|
9
8
|
from snowflake import snowpark
|
10
9
|
from snowflake.connector import connection
|
@@ -12,7 +11,7 @@ from snowflake.ml._internal.exceptions import (
|
|
12
11
|
error_codes,
|
13
12
|
exceptions as snowml_exceptions,
|
14
13
|
)
|
15
|
-
from snowflake.ml._internal.utils import identifier
|
14
|
+
from snowflake.ml._internal.utils import identifier
|
16
15
|
from snowflake.ml.fileset import embedded_stage_fs, sfcfs
|
17
16
|
|
18
17
|
PROTOCOL_NAME = "snow"
|
@@ -28,10 +27,6 @@ _SNOWURL_PATTERN = re.compile(
|
|
28
27
|
r"(?P<path>versions/(?:(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)"
|
29
28
|
)
|
30
29
|
|
31
|
-
# FIXME(dhung): Temporary fix for bug in GS version 8.17
|
32
|
-
_BUG_VERSION_MIN = pkg_version.Version("8.17") # Inclusive minimum version with bugged behavior
|
33
|
-
_BUG_VERSION_MAX = pkg_version.Version("8.18") # Exclusive maximum version with bugged behavior
|
34
|
-
|
35
30
|
|
36
31
|
class SnowFileSystem(sfcfs.SFFileSystem):
|
37
32
|
"""A filesystem that allows user to access Snowflake embedded stage files with valid Snowflake locations.
|
@@ -54,21 +49,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
|
|
54
49
|
) -> None:
|
55
50
|
super().__init__(sf_connection=sf_connection, snowpark_session=snowpark_session, **kwargs)
|
56
51
|
|
57
|
-
# FIXME(dhung): Temporary fix for bug in GS version 8.17
|
58
|
-
if SnowFileSystem._IS_BUGGED_VERSION is None:
|
59
|
-
try:
|
60
|
-
sf_version = snowflake_env.get_current_snowflake_version(self._session)
|
61
|
-
SnowFileSystem._IS_BUGGED_VERSION = _BUG_VERSION_MIN <= sf_version < _BUG_VERSION_MAX
|
62
|
-
except Exception:
|
63
|
-
SnowFileSystem._IS_BUGGED_VERSION = False
|
64
|
-
|
65
|
-
def info(self, path: str, **kwargs: Any) -> Dict[str, Any]:
|
66
|
-
# FIXME(dhung): Temporary fix for bug in GS version 8.17
|
67
|
-
res: Dict[str, Any] = super().info(path, **kwargs)
|
68
|
-
if res.get("type") == "directory" and not res["name"].endswith("/"):
|
69
|
-
res["name"] += "/"
|
70
|
-
return res
|
71
|
-
|
72
52
|
def _get_stage_fs(
|
73
53
|
self, sf_file_path: _SFFileEntityPath # type: ignore[override]
|
74
54
|
) -> embedded_stage_fs.SFEmbeddedStageFileSystem:
|
@@ -100,11 +80,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
|
|
100
80
|
if stage_name.startswith(protocol):
|
101
81
|
stage_name = stage_name[len(protocol) :]
|
102
82
|
abs_path = stage_name + "/" + path
|
103
|
-
# FIXME(dhung): Temporary fix for bug in GS version 8.17
|
104
|
-
if self._IS_BUGGED_VERSION:
|
105
|
-
match = _SNOWURL_PATTERN.fullmatch(abs_path)
|
106
|
-
assert match is not None
|
107
|
-
abs_path = abs_path.replace(match.group("relpath"), match.group("relpath").lstrip("/"))
|
108
83
|
return abs_path
|
109
84
|
|
110
85
|
@classmethod
|
@@ -143,9 +118,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
|
|
143
118
|
version = snowurl_match.group("version")
|
144
119
|
relative_path = snowurl_match.group("relpath") or ""
|
145
120
|
logging.debug(f"Parsed snow URL: {snowurl_match.groups()}")
|
146
|
-
# FIXME(dhung): Temporary fix for bug in GS version 8.17
|
147
|
-
if cls._IS_BUGGED_VERSION:
|
148
|
-
filepath = filepath.replace(f"{version}/", f"{version}//")
|
149
121
|
return _SFFileEntityPath(
|
150
122
|
domain=domain, name=name, version=version, relative_path=relative_path, filepath=filepath
|
151
123
|
)
|
snowflake/ml/fileset/stage_fs.py
CHANGED
@@ -2,13 +2,13 @@ import inspect
|
|
2
2
|
import logging
|
3
3
|
import time
|
4
4
|
from dataclasses import dataclass
|
5
|
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
5
|
+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
6
6
|
|
7
7
|
import fsspec
|
8
8
|
from fsspec.implementations import http as httpfs
|
9
9
|
|
10
10
|
from snowflake import snowpark
|
11
|
-
from snowflake.connector import connection, errorcode
|
11
|
+
from snowflake.connector import connection, errorcode, errors as snowpark_errors
|
12
12
|
from snowflake.ml._internal import telemetry
|
13
13
|
from snowflake.ml._internal.exceptions import (
|
14
14
|
error_codes,
|
@@ -18,6 +18,7 @@ from snowflake.ml._internal.exceptions import (
|
|
18
18
|
)
|
19
19
|
from snowflake.snowpark import exceptions as snowpark_exceptions
|
20
20
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
21
|
+
from snowflake.snowpark._internal.analyzer import snowflake_plan
|
21
22
|
|
22
23
|
# The default length of how long a presigned url stays active in seconds.
|
23
24
|
# Presigned url here is used to fetch file objects from Snowflake when SFStageFileSystem.open() is called.
|
@@ -167,7 +168,8 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
167
168
|
try:
|
168
169
|
loc = self.stage_name
|
169
170
|
path = path.lstrip("/")
|
170
|
-
|
171
|
+
async_job: snowpark.AsyncJob = self._session.sql(f"LIST '{loc}/{path}'").collect(block=False)
|
172
|
+
objects: List[snowpark.Row] = _resolve_async_job(async_job)
|
171
173
|
except snowpark_exceptions.SnowparkClientException as e:
|
172
174
|
if e.message.startswith(fileset_errors.ERRNO_DOMAIN_NOT_EXIST):
|
173
175
|
raise snowml_exceptions.SnowflakeMLException(
|
@@ -289,9 +291,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
289
291
|
original_exception=e,
|
290
292
|
)
|
291
293
|
|
292
|
-
def _parse_list_result(
|
293
|
-
self, list_result: List[Tuple[str, int, str, str]], search_path: str
|
294
|
-
) -> List[Dict[str, Any]]:
|
294
|
+
def _parse_list_result(self, list_result: List[snowpark.Row], search_path: str) -> List[Dict[str, Any]]:
|
295
295
|
"""Convert the result from LIST query to the expected format of fsspec ls() method.
|
296
296
|
|
297
297
|
Note that Snowflake LIST query has different behavior with ls(). LIST query will return all the stage files
|
@@ -312,7 +312,8 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
312
312
|
"""
|
313
313
|
files: Dict[str, Dict[str, Any]] = {}
|
314
314
|
search_path = search_path.strip("/")
|
315
|
-
for
|
315
|
+
for row in list_result:
|
316
|
+
name, size, md5, last_modified = row["name"], row["size"], row["md5"], row["last_modified"]
|
316
317
|
obj_path = self._stage_path_to_relative_path(name)
|
317
318
|
if obj_path == search_path:
|
318
319
|
# If there is a exact match, then the matched object will always be a file object.
|
@@ -408,3 +409,20 @@ def _match_error_code(ex: snowpark_exceptions.SnowparkSQLException, error_code:
|
|
408
409
|
# Snowpark writes error code to message instead of populating e.error_code
|
409
410
|
error_code_str = str(error_code)
|
410
411
|
return ex.error_code == error_code_str or error_code_str in ex.message
|
412
|
+
|
413
|
+
|
414
|
+
@snowflake_plan.SnowflakePlan.Decorator.wrap_exception # type: ignore[misc]
|
415
|
+
def _resolve_async_job(async_job: snowpark.AsyncJob) -> List[snowpark.Row]:
|
416
|
+
# Make sure Snowpark exceptions are properly caught and converted by wrap_exception wrapper
|
417
|
+
try:
|
418
|
+
query_result = cast(List[snowpark.Row], async_job.result("row"))
|
419
|
+
return query_result
|
420
|
+
except snowpark_errors.DatabaseError as e:
|
421
|
+
# HACK: Snowpark surfaces a generic exception if query doesn't complete immediately
|
422
|
+
# assume it's due to FileNotFound
|
423
|
+
if type(e) is snowpark_errors.DatabaseError and "results are unavailable" in str(e):
|
424
|
+
raise snowml_exceptions.SnowflakeMLException(
|
425
|
+
error_code=error_codes.SNOWML_NOT_FOUND,
|
426
|
+
original_exception=fileset_errors.StageNotFoundError("Query failed."),
|
427
|
+
) from e
|
428
|
+
raise
|
@@ -1,9 +1,9 @@
|
|
1
|
-
from typing import Dict, List, Optional,
|
1
|
+
from typing import Dict, List, Optional, Union
|
2
2
|
|
3
3
|
import pandas as pd
|
4
4
|
|
5
5
|
from snowflake.ml._internal import telemetry
|
6
|
-
from snowflake.ml._internal.utils import
|
6
|
+
from snowflake.ml._internal.utils import sql_identifier
|
7
7
|
from snowflake.ml.model._client.model import model_version_impl
|
8
8
|
from snowflake.ml.model._client.ops import model_ops
|
9
9
|
|
@@ -45,7 +45,7 @@ class Model:
|
|
45
45
|
@property
|
46
46
|
def fully_qualified_name(self) -> str:
|
47
47
|
"""Return the fully qualified name of the model that can be used to refer to it in SQL."""
|
48
|
-
return self._model_ops._model_version_client.
|
48
|
+
return self._model_ops._model_version_client.fully_qualified_object_name(None, None, self._model_name)
|
49
49
|
|
50
50
|
@property
|
51
51
|
@telemetry.send_api_usage_telemetry(
|
@@ -76,6 +76,8 @@ class Model:
|
|
76
76
|
subproject=_TELEMETRY_SUBPROJECT,
|
77
77
|
)
|
78
78
|
return self._model_ops.get_comment(
|
79
|
+
database_name=None,
|
80
|
+
schema_name=None,
|
79
81
|
model_name=self._model_name,
|
80
82
|
statement_params=statement_params,
|
81
83
|
)
|
@@ -92,6 +94,8 @@ class Model:
|
|
92
94
|
)
|
93
95
|
return self._model_ops.set_comment(
|
94
96
|
comment=comment,
|
97
|
+
database_name=None,
|
98
|
+
schema_name=None,
|
95
99
|
model_name=self._model_name,
|
96
100
|
statement_params=statement_params,
|
97
101
|
)
|
@@ -109,7 +113,7 @@ class Model:
|
|
109
113
|
class_name=self.__class__.__name__,
|
110
114
|
)
|
111
115
|
default_version_name = self._model_ops.get_default_version(
|
112
|
-
model_name=self._model_name, statement_params=statement_params
|
116
|
+
database_name=None, schema_name=None, model_name=self._model_name, statement_params=statement_params
|
113
117
|
)
|
114
118
|
return self.version(default_version_name)
|
115
119
|
|
@@ -129,7 +133,11 @@ class Model:
|
|
129
133
|
else:
|
130
134
|
version_name = version._version_name
|
131
135
|
self._model_ops.set_default_version(
|
132
|
-
|
136
|
+
database_name=None,
|
137
|
+
schema_name=None,
|
138
|
+
model_name=self._model_name,
|
139
|
+
version_name=version_name,
|
140
|
+
statement_params=statement_params,
|
133
141
|
)
|
134
142
|
|
135
143
|
@telemetry.send_api_usage_telemetry(
|
@@ -155,6 +163,8 @@ class Model:
|
|
155
163
|
)
|
156
164
|
version_id = sql_identifier.SqlIdentifier(version_name)
|
157
165
|
if self._model_ops.validate_existence(
|
166
|
+
database_name=None,
|
167
|
+
schema_name=None,
|
158
168
|
model_name=self._model_name,
|
159
169
|
version_name=version_id,
|
160
170
|
statement_params=statement_params,
|
@@ -184,6 +194,8 @@ class Model:
|
|
184
194
|
subproject=_TELEMETRY_SUBPROJECT,
|
185
195
|
)
|
186
196
|
version_names = self._model_ops.list_models_or_versions(
|
197
|
+
database_name=None,
|
198
|
+
schema_name=None,
|
187
199
|
model_name=self._model_name,
|
188
200
|
statement_params=statement_params,
|
189
201
|
)
|
@@ -211,6 +223,8 @@ class Model:
|
|
211
223
|
subproject=_TELEMETRY_SUBPROJECT,
|
212
224
|
)
|
213
225
|
rows = self._model_ops.show_models_or_versions(
|
226
|
+
database_name=None,
|
227
|
+
schema_name=None,
|
214
228
|
model_name=self._model_name,
|
215
229
|
statement_params=statement_params,
|
216
230
|
)
|
@@ -231,6 +245,8 @@ class Model:
|
|
231
245
|
subproject=_TELEMETRY_SUBPROJECT,
|
232
246
|
)
|
233
247
|
self._model_ops.delete_model_or_version(
|
248
|
+
database_name=None,
|
249
|
+
schema_name=None,
|
234
250
|
model_name=self._model_name,
|
235
251
|
version_name=sql_identifier.SqlIdentifier(version_name),
|
236
252
|
statement_params=statement_params,
|
@@ -250,29 +266,9 @@ class Model:
|
|
250
266
|
project=_TELEMETRY_PROJECT,
|
251
267
|
subproject=_TELEMETRY_SUBPROJECT,
|
252
268
|
)
|
253
|
-
return self._model_ops.show_tags(
|
254
|
-
|
255
|
-
|
256
|
-
self,
|
257
|
-
tag_name: str,
|
258
|
-
) -> Tuple[sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier]:
|
259
|
-
_tag_db, _tag_schema, _tag_name, _ = identifier.parse_schema_level_object_identifier(tag_name)
|
260
|
-
if _tag_db is None:
|
261
|
-
tag_db_id = self._model_ops._model_client._database_name
|
262
|
-
else:
|
263
|
-
tag_db_id = sql_identifier.SqlIdentifier(_tag_db)
|
264
|
-
|
265
|
-
if _tag_schema is None:
|
266
|
-
tag_schema_id = self._model_ops._model_client._schema_name
|
267
|
-
else:
|
268
|
-
tag_schema_id = sql_identifier.SqlIdentifier(_tag_schema)
|
269
|
-
|
270
|
-
if _tag_name is None:
|
271
|
-
raise ValueError(f"Unable parse the tag name `{tag_name}` you input.")
|
272
|
-
|
273
|
-
tag_name_id = sql_identifier.SqlIdentifier(_tag_name)
|
274
|
-
|
275
|
-
return tag_db_id, tag_schema_id, tag_name_id
|
269
|
+
return self._model_ops.show_tags(
|
270
|
+
database_name=None, schema_name=None, model_name=self._model_name, statement_params=statement_params
|
271
|
+
)
|
276
272
|
|
277
273
|
@telemetry.send_api_usage_telemetry(
|
278
274
|
project=_TELEMETRY_PROJECT,
|
@@ -292,8 +288,10 @@ class Model:
|
|
292
288
|
project=_TELEMETRY_PROJECT,
|
293
289
|
subproject=_TELEMETRY_SUBPROJECT,
|
294
290
|
)
|
295
|
-
tag_db_id, tag_schema_id, tag_name_id =
|
291
|
+
tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
|
296
292
|
return self._model_ops.get_tag_value(
|
293
|
+
database_name=None,
|
294
|
+
schema_name=None,
|
297
295
|
model_name=self._model_name,
|
298
296
|
tag_database_name=tag_db_id,
|
299
297
|
tag_schema_name=tag_schema_id,
|
@@ -317,8 +315,10 @@ class Model:
|
|
317
315
|
project=_TELEMETRY_PROJECT,
|
318
316
|
subproject=_TELEMETRY_SUBPROJECT,
|
319
317
|
)
|
320
|
-
tag_db_id, tag_schema_id, tag_name_id =
|
318
|
+
tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
|
321
319
|
self._model_ops.set_tag(
|
320
|
+
database_name=None,
|
321
|
+
schema_name=None,
|
322
322
|
model_name=self._model_name,
|
323
323
|
tag_database_name=tag_db_id,
|
324
324
|
tag_schema_name=tag_schema_id,
|
@@ -342,8 +342,10 @@ class Model:
|
|
342
342
|
project=_TELEMETRY_PROJECT,
|
343
343
|
subproject=_TELEMETRY_SUBPROJECT,
|
344
344
|
)
|
345
|
-
tag_db_id, tag_schema_id, tag_name_id =
|
345
|
+
tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
|
346
346
|
self._model_ops.unset_tag(
|
347
|
+
database_name=None,
|
348
|
+
schema_name=None,
|
347
349
|
model_name=self._model_name,
|
348
350
|
tag_database_name=tag_db_id,
|
349
351
|
tag_schema_name=tag_schema_id,
|
@@ -365,15 +367,20 @@ class Model:
|
|
365
367
|
project=_TELEMETRY_PROJECT,
|
366
368
|
subproject=_TELEMETRY_SUBPROJECT,
|
367
369
|
)
|
368
|
-
|
369
|
-
|
370
|
-
new_model_schema = sql_identifier.SqlIdentifier(schema) if schema else None
|
371
|
-
new_model_id = sql_identifier.SqlIdentifier(model)
|
370
|
+
new_db, new_schema, new_model = sql_identifier.parse_fully_qualified_name(model_name)
|
371
|
+
|
372
372
|
self._model_ops.rename(
|
373
|
+
database_name=None,
|
374
|
+
schema_name=None,
|
373
375
|
model_name=self._model_name,
|
374
|
-
new_model_db=
|
375
|
-
new_model_schema=
|
376
|
-
new_model_name=
|
376
|
+
new_model_db=new_db,
|
377
|
+
new_model_schema=new_schema,
|
378
|
+
new_model_name=new_model,
|
377
379
|
statement_params=statement_params,
|
378
380
|
)
|
379
|
-
self.
|
381
|
+
self._model_ops = model_ops.ModelOperator(
|
382
|
+
self._model_ops._session,
|
383
|
+
database_name=new_db or self._model_ops._model_client._database_name,
|
384
|
+
schema_name=new_schema or self._model_ops._model_client._schema_name,
|
385
|
+
)
|
386
|
+
self._model_name = new_model
|
@@ -72,7 +72,7 @@ class ModelVersion:
|
|
72
72
|
@property
|
73
73
|
def fully_qualified_model_name(self) -> str:
|
74
74
|
"""Return the fully qualified name of the model to which the model version belongs."""
|
75
|
-
return self._model_ops._model_version_client.
|
75
|
+
return self._model_ops._model_version_client.fully_qualified_object_name(None, None, self._model_name)
|
76
76
|
|
77
77
|
@property
|
78
78
|
@telemetry.send_api_usage_telemetry(
|
@@ -103,6 +103,8 @@ class ModelVersion:
|
|
103
103
|
subproject=_TELEMETRY_SUBPROJECT,
|
104
104
|
)
|
105
105
|
return self._model_ops.get_comment(
|
106
|
+
database_name=None,
|
107
|
+
schema_name=None,
|
106
108
|
model_name=self._model_name,
|
107
109
|
version_name=self._version_name,
|
108
110
|
statement_params=statement_params,
|
@@ -120,6 +122,8 @@ class ModelVersion:
|
|
120
122
|
)
|
121
123
|
return self._model_ops.set_comment(
|
122
124
|
comment=comment,
|
125
|
+
database_name=None,
|
126
|
+
schema_name=None,
|
123
127
|
model_name=self._model_name,
|
124
128
|
version_name=self._version_name,
|
125
129
|
statement_params=statement_params,
|
@@ -140,7 +144,11 @@ class ModelVersion:
|
|
140
144
|
subproject=_TELEMETRY_SUBPROJECT,
|
141
145
|
)
|
142
146
|
return self._model_ops._metadata_ops.load(
|
143
|
-
|
147
|
+
database_name=None,
|
148
|
+
schema_name=None,
|
149
|
+
model_name=self._model_name,
|
150
|
+
version_name=self._version_name,
|
151
|
+
statement_params=statement_params,
|
144
152
|
)["metrics"]
|
145
153
|
|
146
154
|
@telemetry.send_api_usage_telemetry(
|
@@ -183,6 +191,8 @@ class ModelVersion:
|
|
183
191
|
metrics[metric_name] = value
|
184
192
|
self._model_ops._metadata_ops.save(
|
185
193
|
metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
|
194
|
+
database_name=None,
|
195
|
+
schema_name=None,
|
186
196
|
model_name=self._model_name,
|
187
197
|
version_name=self._version_name,
|
188
198
|
statement_params=statement_params,
|
@@ -211,6 +221,8 @@ class ModelVersion:
|
|
211
221
|
del metrics[metric_name]
|
212
222
|
self._model_ops._metadata_ops.save(
|
213
223
|
metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
|
224
|
+
database_name=None,
|
225
|
+
schema_name=None,
|
214
226
|
model_name=self._model_name,
|
215
227
|
version_name=self._version_name,
|
216
228
|
statement_params=statement_params,
|
@@ -222,6 +234,8 @@ class ModelVersion:
|
|
222
234
|
subproject=_TELEMETRY_SUBPROJECT,
|
223
235
|
)
|
224
236
|
return self._model_ops.get_functions(
|
237
|
+
database_name=None,
|
238
|
+
schema_name=None,
|
225
239
|
model_name=self._model_name,
|
226
240
|
version_name=self._version_name,
|
227
241
|
statement_params=statement_params,
|
@@ -309,6 +323,8 @@ class ModelVersion:
|
|
309
323
|
method_function_type=target_function_info["target_method_function_type"],
|
310
324
|
signature=target_function_info["signature"],
|
311
325
|
X=X,
|
326
|
+
database_name=None,
|
327
|
+
schema_name=None,
|
312
328
|
model_name=self._model_name,
|
313
329
|
version_name=self._version_name,
|
314
330
|
strict_input_validation=strict_input_validation,
|
@@ -341,6 +357,8 @@ class ModelVersion:
|
|
341
357
|
subproject=_TELEMETRY_SUBPROJECT,
|
342
358
|
)
|
343
359
|
self._model_ops.download_files(
|
360
|
+
database_name=None,
|
361
|
+
schema_name=None,
|
344
362
|
model_name=self._model_name,
|
345
363
|
version_name=self._version_name,
|
346
364
|
target_path=target_local_path,
|
@@ -380,6 +398,8 @@ class ModelVersion:
|
|
380
398
|
with tempfile.TemporaryDirectory() as tmp_workspace_for_validation:
|
381
399
|
ws_path_for_validation = pathlib.Path(tmp_workspace_for_validation)
|
382
400
|
self._model_ops.download_files(
|
401
|
+
database_name=None,
|
402
|
+
schema_name=None,
|
383
403
|
model_name=self._model_name,
|
384
404
|
version_name=self._version_name,
|
385
405
|
target_path=ws_path_for_validation,
|
@@ -417,6 +437,8 @@ class ModelVersion:
|
|
417
437
|
# We need the folder to be existed.
|
418
438
|
workspace = pathlib.Path(tempfile.mkdtemp())
|
419
439
|
self._model_ops.download_files(
|
440
|
+
database_name=None,
|
441
|
+
schema_name=None,
|
420
442
|
model_name=self._model_name,
|
421
443
|
version_name=self._version_name,
|
422
444
|
target_path=workspace,
|
@@ -61,12 +61,18 @@ class MetadataOperator:
|
|
61
61
|
def _get_current_metadata_dict(
|
62
62
|
self,
|
63
63
|
*,
|
64
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
65
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
64
66
|
model_name: sql_identifier.SqlIdentifier,
|
65
67
|
version_name: sql_identifier.SqlIdentifier,
|
66
68
|
statement_params: Optional[Dict[str, Any]] = None,
|
67
69
|
) -> Dict[str, Any]:
|
68
70
|
version_info_list = self._model_client.show_versions(
|
69
|
-
|
71
|
+
database_name=database_name,
|
72
|
+
schema_name=schema_name,
|
73
|
+
model_name=model_name,
|
74
|
+
version_name=version_name,
|
75
|
+
statement_params=statement_params,
|
70
76
|
)
|
71
77
|
metadata_str = version_info_list[0][self._model_client.MODEL_VERSION_METADATA_COL_NAME]
|
72
78
|
if not metadata_str:
|
@@ -79,12 +85,18 @@ class MetadataOperator:
|
|
79
85
|
def load(
|
80
86
|
self,
|
81
87
|
*,
|
88
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
89
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
82
90
|
model_name: sql_identifier.SqlIdentifier,
|
83
91
|
version_name: sql_identifier.SqlIdentifier,
|
84
92
|
statement_params: Optional[Dict[str, Any]] = None,
|
85
93
|
) -> ModelVersionMetadataSchema:
|
86
94
|
metadata_dict = self._get_current_metadata_dict(
|
87
|
-
|
95
|
+
database_name=database_name,
|
96
|
+
schema_name=schema_name,
|
97
|
+
model_name=model_name,
|
98
|
+
version_name=version_name,
|
99
|
+
statement_params=statement_params,
|
88
100
|
)
|
89
101
|
return MetadataOperator._parse(metadata_dict)
|
90
102
|
|
@@ -92,14 +104,25 @@ class MetadataOperator:
|
|
92
104
|
self,
|
93
105
|
metadata: ModelVersionMetadataSchema,
|
94
106
|
*,
|
107
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
108
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
95
109
|
model_name: sql_identifier.SqlIdentifier,
|
96
110
|
version_name: sql_identifier.SqlIdentifier,
|
97
111
|
statement_params: Optional[Dict[str, Any]] = None,
|
98
112
|
) -> None:
|
99
113
|
metadata_dict = self._get_current_metadata_dict(
|
100
|
-
|
114
|
+
database_name=database_name,
|
115
|
+
schema_name=schema_name,
|
116
|
+
model_name=model_name,
|
117
|
+
version_name=version_name,
|
118
|
+
statement_params=statement_params,
|
101
119
|
)
|
102
120
|
metadata_dict.update({**metadata, "snowpark_ml_schema_version": MODEL_VERSION_METADATA_SCHEMA_VERSION})
|
103
121
|
self._model_version_client.set_metadata(
|
104
|
-
metadata_dict,
|
122
|
+
metadata_dict,
|
123
|
+
database_name=database_name,
|
124
|
+
schema_name=schema_name,
|
125
|
+
model_name=model_name,
|
126
|
+
version_name=version_name,
|
127
|
+
statement_params=statement_params,
|
105
128
|
)
|