snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,11 @@
|
|
1
1
|
from typing import Any, Dict, List, Optional
|
2
2
|
|
3
|
-
from snowflake.ml._internal.utils import
|
4
|
-
|
5
|
-
|
6
|
-
sql_identifier,
|
7
|
-
)
|
8
|
-
from snowflake.snowpark import row, session
|
3
|
+
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
|
+
from snowflake.ml.model._client.sql import _base
|
5
|
+
from snowflake.snowpark import row
|
9
6
|
|
10
7
|
|
11
|
-
class ModelSQLClient:
|
8
|
+
class ModelSQLClient(_base._BaseSQLClient):
|
12
9
|
MODEL_NAME_COL_NAME = "name"
|
13
10
|
MODEL_COMMENT_COL_NAME = "comment"
|
14
11
|
MODEL_DEFAULT_VERSION_NAME_COL_NAME = "default_version_name"
|
@@ -18,35 +15,18 @@ class ModelSQLClient:
|
|
18
15
|
MODEL_VERSION_METADATA_COL_NAME = "metadata"
|
19
16
|
MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
|
20
17
|
|
21
|
-
def __init__(
|
22
|
-
self,
|
23
|
-
session: session.Session,
|
24
|
-
*,
|
25
|
-
database_name: sql_identifier.SqlIdentifier,
|
26
|
-
schema_name: sql_identifier.SqlIdentifier,
|
27
|
-
) -> None:
|
28
|
-
self._session = session
|
29
|
-
self._database_name = database_name
|
30
|
-
self._schema_name = schema_name
|
31
|
-
|
32
|
-
def __eq__(self, __value: object) -> bool:
|
33
|
-
if not isinstance(__value, ModelSQLClient):
|
34
|
-
return False
|
35
|
-
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
36
|
-
|
37
|
-
def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
|
38
|
-
return identifier.get_schema_level_object_identifier(
|
39
|
-
self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
|
40
|
-
)
|
41
|
-
|
42
18
|
def show_models(
|
43
19
|
self,
|
44
20
|
*,
|
21
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
22
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
45
23
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
46
24
|
validate_result: bool = True,
|
47
25
|
statement_params: Optional[Dict[str, Any]] = None,
|
48
26
|
) -> List[row.Row]:
|
49
|
-
|
27
|
+
actual_database_name = database_name or self._database_name
|
28
|
+
actual_schema_name = schema_name or self._schema_name
|
29
|
+
fully_qualified_schema_name = ".".join([actual_database_name.identifier(), actual_schema_name.identifier()])
|
50
30
|
like_sql = ""
|
51
31
|
if model_name:
|
52
32
|
like_sql = f" LIKE '{model_name.resolved()}'"
|
@@ -69,6 +49,8 @@ class ModelSQLClient:
|
|
69
49
|
def show_versions(
|
70
50
|
self,
|
71
51
|
*,
|
52
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
53
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
72
54
|
model_name: sql_identifier.SqlIdentifier,
|
73
55
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
74
56
|
validate_result: bool = True,
|
@@ -82,7 +64,10 @@ class ModelSQLClient:
|
|
82
64
|
res = (
|
83
65
|
query_result_checker.SqlResultValidator(
|
84
66
|
self._session,
|
85
|
-
|
67
|
+
(
|
68
|
+
f"SHOW VERSIONS{like_sql} IN "
|
69
|
+
f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
70
|
+
),
|
86
71
|
statement_params=statement_params,
|
87
72
|
)
|
88
73
|
.has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
|
@@ -99,43 +84,53 @@ class ModelSQLClient:
|
|
99
84
|
def set_comment(
|
100
85
|
self,
|
101
86
|
*,
|
102
|
-
|
87
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
88
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
103
89
|
model_name: sql_identifier.SqlIdentifier,
|
90
|
+
comment: str,
|
104
91
|
statement_params: Optional[Dict[str, Any]] = None,
|
105
92
|
) -> None:
|
106
93
|
query_result_checker.SqlResultValidator(
|
107
94
|
self._session,
|
108
|
-
|
95
|
+
(
|
96
|
+
f"COMMENT ON MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
97
|
+
f" IS $${comment}$$"
|
98
|
+
),
|
109
99
|
statement_params=statement_params,
|
110
100
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
111
101
|
|
112
102
|
def drop_model(
|
113
103
|
self,
|
114
104
|
*,
|
105
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
106
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
115
107
|
model_name: sql_identifier.SqlIdentifier,
|
116
108
|
statement_params: Optional[Dict[str, Any]] = None,
|
117
109
|
) -> None:
|
118
110
|
query_result_checker.SqlResultValidator(
|
119
111
|
self._session,
|
120
|
-
f"DROP MODEL {self.
|
112
|
+
f"DROP MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}",
|
121
113
|
statement_params=statement_params,
|
122
114
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
123
115
|
|
124
|
-
def
|
116
|
+
def rename(
|
125
117
|
self,
|
126
118
|
*,
|
127
|
-
|
119
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
120
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
121
|
+
model_name: sql_identifier.SqlIdentifier,
|
122
|
+
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
123
|
+
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
124
|
+
new_model_name: sql_identifier.SqlIdentifier,
|
128
125
|
statement_params: Optional[Dict[str, Any]] = None,
|
129
126
|
) -> None:
|
130
|
-
if
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
statement_params=statement_params,
|
141
|
-
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
127
|
+
# Use registry's database and schema if a non fully qualified new model name is provided.
|
128
|
+
new_fully_qualified_name = self.fully_qualified_object_name(new_model_db, new_model_schema, new_model_name)
|
129
|
+
query_result_checker.SqlResultValidator(
|
130
|
+
self._session,
|
131
|
+
(
|
132
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
133
|
+
f" RENAME TO {new_fully_qualified_name}"
|
134
|
+
),
|
135
|
+
statement_params=statement_params,
|
136
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -9,7 +9,8 @@ from snowflake.ml._internal.utils import (
|
|
9
9
|
query_result_checker,
|
10
10
|
sql_identifier,
|
11
11
|
)
|
12
|
-
from snowflake.
|
12
|
+
from snowflake.ml.model._client.sql import _base
|
13
|
+
from snowflake.snowpark import dataframe, functions as F, row, types as spt
|
13
14
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
14
15
|
|
15
16
|
|
@@ -20,34 +21,15 @@ def _normalize_url_for_sql(url: str) -> str:
|
|
20
21
|
return f"'{url}'"
|
21
22
|
|
22
23
|
|
23
|
-
class ModelVersionSQLClient:
|
24
|
+
class ModelVersionSQLClient(_base._BaseSQLClient):
|
24
25
|
FUNCTION_NAME_COL_NAME = "name"
|
25
26
|
FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
|
26
27
|
|
27
|
-
def __init__(
|
28
|
-
self,
|
29
|
-
session: session.Session,
|
30
|
-
*,
|
31
|
-
database_name: sql_identifier.SqlIdentifier,
|
32
|
-
schema_name: sql_identifier.SqlIdentifier,
|
33
|
-
) -> None:
|
34
|
-
self._session = session
|
35
|
-
self._database_name = database_name
|
36
|
-
self._schema_name = schema_name
|
37
|
-
|
38
|
-
def __eq__(self, __value: object) -> bool:
|
39
|
-
if not isinstance(__value, ModelVersionSQLClient):
|
40
|
-
return False
|
41
|
-
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
42
|
-
|
43
|
-
def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
|
44
|
-
return identifier.get_schema_level_object_identifier(
|
45
|
-
self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
|
46
|
-
)
|
47
|
-
|
48
28
|
def create_from_stage(
|
49
29
|
self,
|
50
30
|
*,
|
31
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
32
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
51
33
|
model_name: sql_identifier.SqlIdentifier,
|
52
34
|
version_name: sql_identifier.SqlIdentifier,
|
53
35
|
stage_path: str,
|
@@ -56,8 +38,8 @@ class ModelVersionSQLClient:
|
|
56
38
|
query_result_checker.SqlResultValidator(
|
57
39
|
self._session,
|
58
40
|
(
|
59
|
-
f"CREATE MODEL {self.
|
60
|
-
f" FROM {stage_path}"
|
41
|
+
f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
42
|
+
f" WITH VERSION {version_name.identifier()} FROM {stage_path}"
|
61
43
|
),
|
62
44
|
statement_params=statement_params,
|
63
45
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -66,6 +48,8 @@ class ModelVersionSQLClient:
|
|
66
48
|
def add_version_from_stage(
|
67
49
|
self,
|
68
50
|
*,
|
51
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
52
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
69
53
|
model_name: sql_identifier.SqlIdentifier,
|
70
54
|
version_name: sql_identifier.SqlIdentifier,
|
71
55
|
stage_path: str,
|
@@ -74,8 +58,8 @@ class ModelVersionSQLClient:
|
|
74
58
|
query_result_checker.SqlResultValidator(
|
75
59
|
self._session,
|
76
60
|
(
|
77
|
-
f"ALTER MODEL {self.
|
78
|
-
f" FROM {stage_path}"
|
61
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
62
|
+
f" ADD VERSION {version_name.identifier()} FROM {stage_path}"
|
79
63
|
),
|
80
64
|
statement_params=statement_params,
|
81
65
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -83,6 +67,8 @@ class ModelVersionSQLClient:
|
|
83
67
|
def set_default_version(
|
84
68
|
self,
|
85
69
|
*,
|
70
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
71
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
86
72
|
model_name: sql_identifier.SqlIdentifier,
|
87
73
|
version_name: sql_identifier.SqlIdentifier,
|
88
74
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -90,15 +76,54 @@ class ModelVersionSQLClient:
|
|
90
76
|
query_result_checker.SqlResultValidator(
|
91
77
|
self._session,
|
92
78
|
(
|
93
|
-
f"ALTER MODEL {self.
|
79
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
|
94
80
|
f"SET DEFAULT_VERSION = {version_name.identifier()}"
|
95
81
|
),
|
96
82
|
statement_params=statement_params,
|
97
83
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
98
84
|
|
85
|
+
def list_file(
|
86
|
+
self,
|
87
|
+
*,
|
88
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
89
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
90
|
+
model_name: sql_identifier.SqlIdentifier,
|
91
|
+
version_name: sql_identifier.SqlIdentifier,
|
92
|
+
file_path: pathlib.PurePosixPath,
|
93
|
+
is_dir: bool = False,
|
94
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
95
|
+
) -> List[row.Row]:
|
96
|
+
# Workaround for snowURL bug.
|
97
|
+
trailing_slash = "/" if is_dir else ""
|
98
|
+
|
99
|
+
stage_location = (
|
100
|
+
pathlib.PurePosixPath(
|
101
|
+
self.fully_qualified_object_name(database_name, schema_name, model_name),
|
102
|
+
"versions",
|
103
|
+
version_name.resolved(),
|
104
|
+
file_path,
|
105
|
+
).as_posix()
|
106
|
+
+ trailing_slash
|
107
|
+
)
|
108
|
+
stage_location_url = ParseResult(
|
109
|
+
scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
|
110
|
+
).geturl()
|
111
|
+
|
112
|
+
return (
|
113
|
+
query_result_checker.SqlResultValidator(
|
114
|
+
self._session,
|
115
|
+
f"List {_normalize_url_for_sql(stage_location_url)}",
|
116
|
+
statement_params=statement_params,
|
117
|
+
)
|
118
|
+
.has_column("name", allow_empty=True)
|
119
|
+
.validate()
|
120
|
+
)
|
121
|
+
|
99
122
|
def get_file(
|
100
123
|
self,
|
101
124
|
*,
|
125
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
126
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
102
127
|
model_name: sql_identifier.SqlIdentifier,
|
103
128
|
version_name: sql_identifier.SqlIdentifier,
|
104
129
|
file_path: pathlib.PurePosixPath,
|
@@ -106,7 +131,10 @@ class ModelVersionSQLClient:
|
|
106
131
|
statement_params: Optional[Dict[str, Any]] = None,
|
107
132
|
) -> pathlib.Path:
|
108
133
|
stage_location = pathlib.PurePosixPath(
|
109
|
-
self.
|
134
|
+
self.fully_qualified_object_name(database_name, schema_name, model_name),
|
135
|
+
"versions",
|
136
|
+
version_name.resolved(),
|
137
|
+
file_path,
|
110
138
|
).as_posix()
|
111
139
|
stage_location_url = ParseResult(
|
112
140
|
scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
|
@@ -130,6 +158,8 @@ class ModelVersionSQLClient:
|
|
130
158
|
def show_functions(
|
131
159
|
self,
|
132
160
|
*,
|
161
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
162
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
133
163
|
model_name: sql_identifier.SqlIdentifier,
|
134
164
|
version_name: sql_identifier.SqlIdentifier,
|
135
165
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -137,7 +167,7 @@ class ModelVersionSQLClient:
|
|
137
167
|
res = query_result_checker.SqlResultValidator(
|
138
168
|
self._session,
|
139
169
|
(
|
140
|
-
f"SHOW FUNCTIONS IN MODEL {self.
|
170
|
+
f"SHOW FUNCTIONS IN MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
141
171
|
f" VERSION {version_name.identifier()}"
|
142
172
|
),
|
143
173
|
statement_params=statement_params,
|
@@ -148,23 +178,27 @@ class ModelVersionSQLClient:
|
|
148
178
|
def set_comment(
|
149
179
|
self,
|
150
180
|
*,
|
151
|
-
|
181
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
182
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
152
183
|
model_name: sql_identifier.SqlIdentifier,
|
153
184
|
version_name: sql_identifier.SqlIdentifier,
|
185
|
+
comment: str,
|
154
186
|
statement_params: Optional[Dict[str, Any]] = None,
|
155
187
|
) -> None:
|
156
188
|
query_result_checker.SqlResultValidator(
|
157
189
|
self._session,
|
158
190
|
(
|
159
|
-
f"ALTER MODEL {self.
|
191
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
|
160
192
|
f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$"
|
161
193
|
),
|
162
194
|
statement_params=statement_params,
|
163
195
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
164
196
|
|
165
|
-
def
|
197
|
+
def invoke_function_method(
|
166
198
|
self,
|
167
199
|
*,
|
200
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
201
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
168
202
|
model_name: sql_identifier.SqlIdentifier,
|
169
203
|
version_name: sql_identifier.SqlIdentifier,
|
170
204
|
method_name: sql_identifier.SqlIdentifier,
|
@@ -178,10 +212,12 @@ class ModelVersionSQLClient:
|
|
178
212
|
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
179
213
|
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
180
214
|
else:
|
215
|
+
actual_database_name = database_name or self._database_name
|
216
|
+
actual_schema_name = schema_name or self._schema_name
|
181
217
|
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
182
218
|
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
183
|
-
|
184
|
-
|
219
|
+
actual_database_name.identifier(),
|
220
|
+
actual_schema_name.identifier(),
|
185
221
|
tmp_table_name,
|
186
222
|
)
|
187
223
|
input_df.write.save_as_table( # type: ignore[call-overload]
|
@@ -196,7 +232,8 @@ class ModelVersionSQLClient:
|
|
196
232
|
module_version_alias = "MODEL_VERSION_ALIAS"
|
197
233
|
with_statements.append(
|
198
234
|
f"{module_version_alias} AS "
|
199
|
-
f"MODEL {self.
|
235
|
+
f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
236
|
+
f" VERSION {version_name.identifier()}"
|
200
237
|
)
|
201
238
|
|
202
239
|
args_sql_list = []
|
@@ -232,10 +269,93 @@ class ModelVersionSQLClient:
|
|
232
269
|
|
233
270
|
return output_df
|
234
271
|
|
272
|
+
def invoke_table_function_method(
|
273
|
+
self,
|
274
|
+
*,
|
275
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
276
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
277
|
+
model_name: sql_identifier.SqlIdentifier,
|
278
|
+
version_name: sql_identifier.SqlIdentifier,
|
279
|
+
method_name: sql_identifier.SqlIdentifier,
|
280
|
+
input_df: dataframe.DataFrame,
|
281
|
+
input_args: List[sql_identifier.SqlIdentifier],
|
282
|
+
returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
283
|
+
partition_column: Optional[sql_identifier.SqlIdentifier],
|
284
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
285
|
+
) -> dataframe.DataFrame:
|
286
|
+
with_statements = []
|
287
|
+
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
288
|
+
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
289
|
+
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
290
|
+
else:
|
291
|
+
actual_database_name = database_name or self._database_name
|
292
|
+
actual_schema_name = schema_name or self._schema_name
|
293
|
+
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
294
|
+
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
295
|
+
actual_database_name.identifier(),
|
296
|
+
actual_schema_name.identifier(),
|
297
|
+
tmp_table_name,
|
298
|
+
)
|
299
|
+
input_df.write.save_as_table( # type: ignore[call-overload]
|
300
|
+
table_name=INTERMEDIATE_TABLE_NAME,
|
301
|
+
mode="errorifexists",
|
302
|
+
table_type="temporary",
|
303
|
+
statement_params=statement_params,
|
304
|
+
)
|
305
|
+
|
306
|
+
module_version_alias = "MODEL_VERSION_ALIAS"
|
307
|
+
with_statements.append(
|
308
|
+
f"{module_version_alias} AS "
|
309
|
+
f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
310
|
+
f" VERSION {version_name.identifier()}"
|
311
|
+
)
|
312
|
+
|
313
|
+
partition_by = partition_column.identifier() if partition_column is not None else "1"
|
314
|
+
|
315
|
+
args_sql_list = []
|
316
|
+
for input_arg_value in input_args:
|
317
|
+
args_sql_list.append(input_arg_value)
|
318
|
+
|
319
|
+
args_sql = ", ".join(args_sql_list)
|
320
|
+
|
321
|
+
sql = textwrap.dedent(
|
322
|
+
f"""WITH {','.join(with_statements)}
|
323
|
+
SELECT *,
|
324
|
+
FROM {INTERMEDIATE_TABLE_NAME},
|
325
|
+
TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
|
326
|
+
OVER (PARTITION BY {partition_by}))"""
|
327
|
+
)
|
328
|
+
|
329
|
+
output_df = self._session.sql(sql)
|
330
|
+
|
331
|
+
# Prepare the output
|
332
|
+
output_cols = []
|
333
|
+
output_names = []
|
334
|
+
|
335
|
+
for output_name, output_type, output_col_name in returns:
|
336
|
+
output_cols.append(F.col(output_name).astype(output_type))
|
337
|
+
output_names.append(output_col_name)
|
338
|
+
|
339
|
+
if partition_column is not None:
|
340
|
+
output_cols.append(F.col(partition_column.identifier()))
|
341
|
+
output_names.append(partition_column)
|
342
|
+
|
343
|
+
output_df = output_df.with_columns(
|
344
|
+
col_names=output_names,
|
345
|
+
values=output_cols,
|
346
|
+
)
|
347
|
+
|
348
|
+
if statement_params:
|
349
|
+
output_df._statement_params = statement_params # type: ignore[assignment]
|
350
|
+
|
351
|
+
return output_df
|
352
|
+
|
235
353
|
def set_metadata(
|
236
354
|
self,
|
237
355
|
metadata_dict: Dict[str, Any],
|
238
356
|
*,
|
357
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
358
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
239
359
|
model_name: sql_identifier.SqlIdentifier,
|
240
360
|
version_name: sql_identifier.SqlIdentifier,
|
241
361
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -244,8 +364,8 @@ class ModelVersionSQLClient:
|
|
244
364
|
query_result_checker.SqlResultValidator(
|
245
365
|
self._session,
|
246
366
|
(
|
247
|
-
f"ALTER MODEL {self.
|
248
|
-
f" SET METADATA=$${json_metadata}$$"
|
367
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
368
|
+
f" MODIFY VERSION {version_name.identifier()} SET METADATA=$${json_metadata}$$"
|
249
369
|
),
|
250
370
|
statement_params=statement_params,
|
251
371
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -253,12 +373,17 @@ class ModelVersionSQLClient:
|
|
253
373
|
def drop_version(
|
254
374
|
self,
|
255
375
|
*,
|
376
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
377
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
256
378
|
model_name: sql_identifier.SqlIdentifier,
|
257
379
|
version_name: sql_identifier.SqlIdentifier,
|
258
380
|
statement_params: Optional[Dict[str, Any]] = None,
|
259
381
|
) -> None:
|
260
382
|
query_result_checker.SqlResultValidator(
|
261
383
|
self._session,
|
262
|
-
|
384
|
+
(
|
385
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
386
|
+
f" DROP VERSION {version_name.identifier()}"
|
387
|
+
),
|
263
388
|
statement_params=statement_params,
|
264
389
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -1,46 +1,20 @@
|
|
1
1
|
from typing import Any, Dict, Optional
|
2
2
|
|
3
|
-
from snowflake.ml._internal.utils import
|
4
|
-
|
5
|
-
query_result_checker,
|
6
|
-
sql_identifier,
|
7
|
-
)
|
8
|
-
from snowflake.snowpark import session
|
3
|
+
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
|
+
from snowflake.ml.model._client.sql import _base
|
9
5
|
|
10
6
|
|
11
|
-
class StageSQLClient:
|
12
|
-
def __init__(
|
13
|
-
self,
|
14
|
-
session: session.Session,
|
15
|
-
*,
|
16
|
-
database_name: sql_identifier.SqlIdentifier,
|
17
|
-
schema_name: sql_identifier.SqlIdentifier,
|
18
|
-
) -> None:
|
19
|
-
self._session = session
|
20
|
-
self._database_name = database_name
|
21
|
-
self._schema_name = schema_name
|
22
|
-
|
23
|
-
def __eq__(self, __value: object) -> bool:
|
24
|
-
if not isinstance(__value, StageSQLClient):
|
25
|
-
return False
|
26
|
-
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
27
|
-
|
28
|
-
def fully_qualified_stage_name(
|
29
|
-
self,
|
30
|
-
stage_name: sql_identifier.SqlIdentifier,
|
31
|
-
) -> str:
|
32
|
-
return identifier.get_schema_level_object_identifier(
|
33
|
-
self._database_name.identifier(), self._schema_name.identifier(), stage_name.identifier()
|
34
|
-
)
|
35
|
-
|
7
|
+
class StageSQLClient(_base._BaseSQLClient):
|
36
8
|
def create_tmp_stage(
|
37
9
|
self,
|
38
10
|
*,
|
11
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
12
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
39
13
|
stage_name: sql_identifier.SqlIdentifier,
|
40
14
|
statement_params: Optional[Dict[str, Any]] = None,
|
41
15
|
) -> None:
|
42
16
|
query_result_checker.SqlResultValidator(
|
43
17
|
self._session,
|
44
|
-
f"CREATE TEMPORARY STAGE {self.
|
18
|
+
f"CREATE TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
|
45
19
|
statement_params=statement_params,
|
46
20
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|