snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
|
|
1
|
+
import os
|
1
2
|
import pathlib
|
2
3
|
import tempfile
|
3
|
-
from
|
4
|
-
from typing import Any, Dict, Generator, List, Optional, Union, cast
|
4
|
+
from typing import Any, Dict, List, Literal, Optional, Union, cast
|
5
5
|
|
6
6
|
import yaml
|
7
7
|
|
@@ -19,7 +19,9 @@ from snowflake.ml.model._model_composer.model_manifest import (
|
|
19
19
|
model_manifest,
|
20
20
|
model_manifest_schema,
|
21
21
|
)
|
22
|
+
from snowflake.ml.model._packager.model_env import model_env
|
22
23
|
from snowflake.ml.model._packager.model_meta import model_meta
|
24
|
+
from snowflake.ml.model._packager.model_runtime import model_runtime
|
23
25
|
from snowflake.ml.model._signatures import snowpark_handler
|
24
26
|
from snowflake.snowpark import dataframe, row, session
|
25
27
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
@@ -72,37 +74,57 @@ class ModelOperator:
|
|
72
74
|
and self._model_version_client == __value._model_version_client
|
73
75
|
)
|
74
76
|
|
75
|
-
def prepare_model_stage_path(
|
77
|
+
def prepare_model_stage_path(
|
78
|
+
self,
|
79
|
+
*,
|
80
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
81
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
82
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
83
|
+
) -> str:
|
76
84
|
stage_name = sql_identifier.SqlIdentifier(
|
77
85
|
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
|
78
86
|
)
|
79
|
-
self._stage_client.create_tmp_stage(
|
80
|
-
|
87
|
+
self._stage_client.create_tmp_stage(
|
88
|
+
database_name=database_name,
|
89
|
+
schema_name=schema_name,
|
90
|
+
stage_name=stage_name,
|
91
|
+
statement_params=statement_params,
|
92
|
+
)
|
93
|
+
return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
|
81
94
|
|
82
95
|
def create_from_stage(
|
83
96
|
self,
|
84
97
|
composed_model: model_composer.ModelComposer,
|
85
98
|
*,
|
99
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
100
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
86
101
|
model_name: sql_identifier.SqlIdentifier,
|
87
102
|
version_name: sql_identifier.SqlIdentifier,
|
88
103
|
statement_params: Optional[Dict[str, Any]] = None,
|
89
104
|
) -> None:
|
90
105
|
stage_path = str(composed_model.stage_path)
|
91
106
|
if self.validate_existence(
|
107
|
+
database_name=database_name,
|
108
|
+
schema_name=schema_name,
|
92
109
|
model_name=model_name,
|
93
110
|
statement_params=statement_params,
|
94
111
|
):
|
95
112
|
if self.validate_existence(
|
113
|
+
database_name=database_name,
|
114
|
+
schema_name=schema_name,
|
96
115
|
model_name=model_name,
|
97
116
|
version_name=version_name,
|
98
117
|
statement_params=statement_params,
|
99
118
|
):
|
100
119
|
raise ValueError(
|
101
|
-
|
102
|
-
f"
|
120
|
+
"Model "
|
121
|
+
f"{self._model_version_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
122
|
+
f" version {version_name} already existed."
|
103
123
|
)
|
104
124
|
else:
|
105
125
|
self._model_version_client.add_version_from_stage(
|
126
|
+
database_name=database_name,
|
127
|
+
schema_name=schema_name,
|
106
128
|
stage_path=stage_path,
|
107
129
|
model_name=model_name,
|
108
130
|
version_name=version_name,
|
@@ -110,6 +132,8 @@ class ModelOperator:
|
|
110
132
|
)
|
111
133
|
else:
|
112
134
|
self._model_version_client.create_from_stage(
|
135
|
+
database_name=database_name,
|
136
|
+
schema_name=schema_name,
|
113
137
|
stage_path=stage_path,
|
114
138
|
model_name=model_name,
|
115
139
|
version_name=version_name,
|
@@ -119,17 +143,23 @@ class ModelOperator:
|
|
119
143
|
def show_models_or_versions(
|
120
144
|
self,
|
121
145
|
*,
|
146
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
147
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
122
148
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
123
149
|
statement_params: Optional[Dict[str, Any]] = None,
|
124
150
|
) -> List[row.Row]:
|
125
151
|
if model_name:
|
126
152
|
return self._model_client.show_versions(
|
153
|
+
database_name=database_name,
|
154
|
+
schema_name=schema_name,
|
127
155
|
model_name=model_name,
|
128
156
|
validate_result=False,
|
129
157
|
statement_params=statement_params,
|
130
158
|
)
|
131
159
|
else:
|
132
160
|
return self._model_client.show_models(
|
161
|
+
database_name=database_name,
|
162
|
+
schema_name=schema_name,
|
133
163
|
validate_result=False,
|
134
164
|
statement_params=statement_params,
|
135
165
|
)
|
@@ -137,10 +167,14 @@ class ModelOperator:
|
|
137
167
|
def list_models_or_versions(
|
138
168
|
self,
|
139
169
|
*,
|
170
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
171
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
140
172
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
141
173
|
statement_params: Optional[Dict[str, Any]] = None,
|
142
174
|
) -> List[sql_identifier.SqlIdentifier]:
|
143
175
|
res = self.show_models_or_versions(
|
176
|
+
database_name=database_name,
|
177
|
+
schema_name=schema_name,
|
144
178
|
model_name=model_name,
|
145
179
|
statement_params=statement_params,
|
146
180
|
)
|
@@ -153,12 +187,16 @@ class ModelOperator:
|
|
153
187
|
def validate_existence(
|
154
188
|
self,
|
155
189
|
*,
|
190
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
191
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
156
192
|
model_name: sql_identifier.SqlIdentifier,
|
157
193
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
158
194
|
statement_params: Optional[Dict[str, Any]] = None,
|
159
195
|
) -> bool:
|
160
196
|
if version_name:
|
161
197
|
res = self._model_client.show_versions(
|
198
|
+
database_name=database_name,
|
199
|
+
schema_name=schema_name,
|
162
200
|
model_name=model_name,
|
163
201
|
version_name=version_name,
|
164
202
|
validate_result=False,
|
@@ -166,6 +204,8 @@ class ModelOperator:
|
|
166
204
|
)
|
167
205
|
else:
|
168
206
|
res = self._model_client.show_models(
|
207
|
+
database_name=database_name,
|
208
|
+
schema_name=schema_name,
|
169
209
|
model_name=model_name,
|
170
210
|
validate_result=False,
|
171
211
|
statement_params=statement_params,
|
@@ -175,12 +215,16 @@ class ModelOperator:
|
|
175
215
|
def get_comment(
|
176
216
|
self,
|
177
217
|
*,
|
218
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
219
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
178
220
|
model_name: sql_identifier.SqlIdentifier,
|
179
221
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
180
222
|
statement_params: Optional[Dict[str, Any]] = None,
|
181
223
|
) -> str:
|
182
224
|
if version_name:
|
183
225
|
res = self._model_client.show_versions(
|
226
|
+
database_name=database_name,
|
227
|
+
schema_name=schema_name,
|
184
228
|
model_name=model_name,
|
185
229
|
version_name=version_name,
|
186
230
|
statement_params=statement_params,
|
@@ -188,6 +232,8 @@ class ModelOperator:
|
|
188
232
|
col_name = self._model_client.MODEL_VERSION_COMMENT_COL_NAME
|
189
233
|
else:
|
190
234
|
res = self._model_client.show_models(
|
235
|
+
database_name=database_name,
|
236
|
+
schema_name=schema_name,
|
191
237
|
model_name=model_name,
|
192
238
|
statement_params=statement_params,
|
193
239
|
)
|
@@ -198,6 +244,8 @@ class ModelOperator:
|
|
198
244
|
self,
|
199
245
|
*,
|
200
246
|
comment: str,
|
247
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
248
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
201
249
|
model_name: sql_identifier.SqlIdentifier,
|
202
250
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
203
251
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -205,6 +253,8 @@ class ModelOperator:
|
|
205
253
|
if version_name:
|
206
254
|
self._model_version_client.set_comment(
|
207
255
|
comment=comment,
|
256
|
+
database_name=database_name,
|
257
|
+
schema_name=schema_name,
|
208
258
|
model_name=model_name,
|
209
259
|
version_name=version_name,
|
210
260
|
statement_params=statement_params,
|
@@ -212,6 +262,8 @@ class ModelOperator:
|
|
212
262
|
else:
|
213
263
|
self._model_client.set_comment(
|
214
264
|
comment=comment,
|
265
|
+
database_name=database_name,
|
266
|
+
schema_name=schema_name,
|
215
267
|
model_name=model_name,
|
216
268
|
statement_params=statement_params,
|
217
269
|
)
|
@@ -219,25 +271,42 @@ class ModelOperator:
|
|
219
271
|
def set_default_version(
|
220
272
|
self,
|
221
273
|
*,
|
274
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
275
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
222
276
|
model_name: sql_identifier.SqlIdentifier,
|
223
277
|
version_name: sql_identifier.SqlIdentifier,
|
224
278
|
statement_params: Optional[Dict[str, Any]] = None,
|
225
279
|
) -> None:
|
226
280
|
if not self.validate_existence(
|
227
|
-
|
281
|
+
database_name=database_name,
|
282
|
+
schema_name=schema_name,
|
283
|
+
model_name=model_name,
|
284
|
+
version_name=version_name,
|
285
|
+
statement_params=statement_params,
|
228
286
|
):
|
229
287
|
raise ValueError(f"You cannot set version {version_name} as default version as it does not exist.")
|
230
288
|
self._model_version_client.set_default_version(
|
231
|
-
|
289
|
+
database_name=database_name,
|
290
|
+
schema_name=schema_name,
|
291
|
+
model_name=model_name,
|
292
|
+
version_name=version_name,
|
293
|
+
statement_params=statement_params,
|
232
294
|
)
|
233
295
|
|
234
296
|
def get_default_version(
|
235
297
|
self,
|
236
298
|
*,
|
299
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
300
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
237
301
|
model_name: sql_identifier.SqlIdentifier,
|
238
302
|
statement_params: Optional[Dict[str, Any]] = None,
|
239
303
|
) -> sql_identifier.SqlIdentifier:
|
240
|
-
res = self._model_client.show_models(
|
304
|
+
res = self._model_client.show_models(
|
305
|
+
database_name=database_name,
|
306
|
+
schema_name=schema_name,
|
307
|
+
model_name=model_name,
|
308
|
+
statement_params=statement_params,
|
309
|
+
)[0]
|
241
310
|
return sql_identifier.SqlIdentifier(
|
242
311
|
res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True
|
243
312
|
)
|
@@ -245,14 +314,18 @@ class ModelOperator:
|
|
245
314
|
def get_tag_value(
|
246
315
|
self,
|
247
316
|
*,
|
317
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
318
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
248
319
|
model_name: sql_identifier.SqlIdentifier,
|
249
|
-
tag_database_name: sql_identifier.SqlIdentifier,
|
250
|
-
tag_schema_name: sql_identifier.SqlIdentifier,
|
320
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
321
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
251
322
|
tag_name: sql_identifier.SqlIdentifier,
|
252
323
|
statement_params: Optional[Dict[str, Any]] = None,
|
253
324
|
) -> Optional[str]:
|
254
325
|
r = self._tag_client.get_tag_value(
|
255
|
-
|
326
|
+
database_name=database_name,
|
327
|
+
schema_name=schema_name,
|
328
|
+
model_name=model_name,
|
256
329
|
tag_database_name=tag_database_name,
|
257
330
|
tag_schema_name=tag_schema_name,
|
258
331
|
tag_name=tag_name,
|
@@ -266,11 +339,15 @@ class ModelOperator:
|
|
266
339
|
def show_tags(
|
267
340
|
self,
|
268
341
|
*,
|
342
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
343
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
269
344
|
model_name: sql_identifier.SqlIdentifier,
|
270
345
|
statement_params: Optional[Dict[str, Any]] = None,
|
271
346
|
) -> Dict[str, str]:
|
272
347
|
tags_info = self._tag_client.get_tag_list(
|
273
|
-
|
348
|
+
database_name=database_name,
|
349
|
+
schema_name=schema_name,
|
350
|
+
model_name=model_name,
|
274
351
|
statement_params=statement_params,
|
275
352
|
)
|
276
353
|
res: Dict[str, str] = {
|
@@ -286,14 +363,18 @@ class ModelOperator:
|
|
286
363
|
def set_tag(
|
287
364
|
self,
|
288
365
|
*,
|
366
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
367
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
289
368
|
model_name: sql_identifier.SqlIdentifier,
|
290
|
-
tag_database_name: sql_identifier.SqlIdentifier,
|
291
|
-
tag_schema_name: sql_identifier.SqlIdentifier,
|
369
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
370
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
292
371
|
tag_name: sql_identifier.SqlIdentifier,
|
293
372
|
tag_value: str,
|
294
373
|
statement_params: Optional[Dict[str, Any]] = None,
|
295
374
|
) -> None:
|
296
375
|
self._tag_client.set_tag_on_model(
|
376
|
+
database_name=database_name,
|
377
|
+
schema_name=schema_name,
|
297
378
|
model_name=model_name,
|
298
379
|
tag_database_name=tag_database_name,
|
299
380
|
tag_schema_name=tag_schema_name,
|
@@ -305,13 +386,17 @@ class ModelOperator:
|
|
305
386
|
def unset_tag(
|
306
387
|
self,
|
307
388
|
*,
|
389
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
390
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
308
391
|
model_name: sql_identifier.SqlIdentifier,
|
309
|
-
tag_database_name: sql_identifier.SqlIdentifier,
|
310
|
-
tag_schema_name: sql_identifier.SqlIdentifier,
|
392
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
393
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
311
394
|
tag_name: sql_identifier.SqlIdentifier,
|
312
395
|
statement_params: Optional[Dict[str, Any]] = None,
|
313
396
|
) -> None:
|
314
397
|
self._tag_client.unset_tag_on_model(
|
398
|
+
database_name=database_name,
|
399
|
+
schema_name=schema_name,
|
315
400
|
model_name=model_name,
|
316
401
|
tag_database_name=tag_database_name,
|
317
402
|
tag_schema_name=tag_schema_name,
|
@@ -322,12 +407,16 @@ class ModelOperator:
|
|
322
407
|
def get_model_version_manifest(
|
323
408
|
self,
|
324
409
|
*,
|
410
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
411
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
325
412
|
model_name: sql_identifier.SqlIdentifier,
|
326
413
|
version_name: sql_identifier.SqlIdentifier,
|
327
414
|
statement_params: Optional[Dict[str, Any]] = None,
|
328
415
|
) -> model_manifest_schema.ModelManifestDict:
|
329
416
|
with tempfile.TemporaryDirectory() as tmpdir:
|
330
417
|
self._model_version_client.get_file(
|
418
|
+
database_name=database_name,
|
419
|
+
schema_name=schema_name,
|
331
420
|
model_name=model_name,
|
332
421
|
version_name=version_name,
|
333
422
|
file_path=pathlib.PurePosixPath(model_manifest.ModelManifest.MANIFEST_FILE_REL_PATH),
|
@@ -337,16 +426,6 @@ class ModelOperator:
|
|
337
426
|
mm = model_manifest.ModelManifest(pathlib.Path(tmpdir))
|
338
427
|
return mm.load()
|
339
428
|
|
340
|
-
@contextmanager
|
341
|
-
def _enable_model_details(
|
342
|
-
self,
|
343
|
-
*,
|
344
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
345
|
-
) -> Generator[None, None, None]:
|
346
|
-
self._model_client.config_model_details(enable=True, statement_params=statement_params)
|
347
|
-
yield
|
348
|
-
self._model_client.config_model_details(enable=False, statement_params=statement_params)
|
349
|
-
|
350
429
|
@staticmethod
|
351
430
|
def _match_model_spec_with_sql_functions(
|
352
431
|
sql_functions_names: List[sql_identifier.SqlIdentifier], target_methods: List[str]
|
@@ -370,68 +449,75 @@ class ModelOperator:
|
|
370
449
|
def get_functions(
|
371
450
|
self,
|
372
451
|
*,
|
452
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
453
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
373
454
|
model_name: sql_identifier.SqlIdentifier,
|
374
455
|
version_name: sql_identifier.SqlIdentifier,
|
375
456
|
statement_params: Optional[Dict[str, Any]] = None,
|
376
457
|
) -> List[model_manifest_schema.ModelFunctionInfo]:
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
)
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
458
|
+
raw_model_spec_res = self._model_client.show_versions(
|
459
|
+
database_name=database_name,
|
460
|
+
schema_name=schema_name,
|
461
|
+
model_name=model_name,
|
462
|
+
version_name=version_name,
|
463
|
+
check_model_details=True,
|
464
|
+
statement_params={**(statement_params or {}), "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True},
|
465
|
+
)[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
|
466
|
+
model_spec_dict = yaml.safe_load(raw_model_spec_res)
|
467
|
+
model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
|
468
|
+
show_functions_res = self._model_version_client.show_functions(
|
469
|
+
database_name=database_name,
|
470
|
+
schema_name=schema_name,
|
471
|
+
model_name=model_name,
|
472
|
+
version_name=version_name,
|
473
|
+
statement_params=statement_params,
|
474
|
+
)
|
475
|
+
function_names_and_types = []
|
476
|
+
for r in show_functions_res:
|
477
|
+
function_name = sql_identifier.SqlIdentifier(
|
478
|
+
r[self._model_version_client.FUNCTION_NAME_COL_NAME], case_sensitive=True
|
390
479
|
)
|
391
|
-
function_names_and_types = []
|
392
|
-
for r in show_functions_res:
|
393
|
-
function_name = sql_identifier.SqlIdentifier(
|
394
|
-
r[self._model_version_client.FUNCTION_NAME_COL_NAME], case_sensitive=True
|
395
|
-
)
|
396
480
|
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
function_names_and_types.append((function_name, function_type))
|
407
|
-
|
408
|
-
signatures = model_spec["signatures"]
|
409
|
-
function_names = [name for name, _ in function_names_and_types]
|
410
|
-
function_name_mapping = ModelOperator._match_model_spec_with_sql_functions(
|
411
|
-
function_names, list(signatures.keys())
|
412
|
-
)
|
481
|
+
function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
|
482
|
+
try:
|
483
|
+
return_type = r[self._model_version_client.FUNCTION_RETURN_TYPE_COL_NAME]
|
484
|
+
except KeyError:
|
485
|
+
pass
|
486
|
+
else:
|
487
|
+
if "TABLE" in return_type:
|
488
|
+
function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
|
413
489
|
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
490
|
+
function_names_and_types.append((function_name, function_type))
|
491
|
+
|
492
|
+
signatures = model_spec["signatures"]
|
493
|
+
function_names = [name for name, _ in function_names_and_types]
|
494
|
+
function_name_mapping = ModelOperator._match_model_spec_with_sql_functions(
|
495
|
+
function_names, list(signatures.keys())
|
496
|
+
)
|
497
|
+
|
498
|
+
return [
|
499
|
+
model_manifest_schema.ModelFunctionInfo(
|
500
|
+
name=function_name.identifier(),
|
501
|
+
target_method=function_name_mapping[function_name],
|
502
|
+
target_method_function_type=function_type,
|
503
|
+
signature=model_signature.ModelSignature.from_dict(signatures[function_name_mapping[function_name]]),
|
504
|
+
)
|
505
|
+
for function_name, function_type in function_names_and_types
|
506
|
+
]
|
425
507
|
|
426
508
|
def invoke_method(
|
427
509
|
self,
|
428
510
|
*,
|
429
511
|
method_name: sql_identifier.SqlIdentifier,
|
512
|
+
method_function_type: str,
|
430
513
|
signature: model_signature.ModelSignature,
|
431
514
|
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
515
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
516
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
432
517
|
model_name: sql_identifier.SqlIdentifier,
|
433
518
|
version_name: sql_identifier.SqlIdentifier,
|
434
519
|
strict_input_validation: bool = False,
|
520
|
+
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
435
521
|
statement_params: Optional[Dict[str, str]] = None,
|
436
522
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
437
523
|
identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
|
@@ -469,15 +555,31 @@ class ModelOperator:
|
|
469
555
|
if output_name in original_cols:
|
470
556
|
original_cols.remove(output_name)
|
471
557
|
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
558
|
+
if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
|
559
|
+
df_res = self._model_version_client.invoke_function_method(
|
560
|
+
method_name=method_name,
|
561
|
+
input_df=s_df,
|
562
|
+
input_args=input_args,
|
563
|
+
returns=returns,
|
564
|
+
database_name=database_name,
|
565
|
+
schema_name=schema_name,
|
566
|
+
model_name=model_name,
|
567
|
+
version_name=version_name,
|
568
|
+
statement_params=statement_params,
|
569
|
+
)
|
570
|
+
elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
571
|
+
df_res = self._model_version_client.invoke_table_function_method(
|
572
|
+
method_name=method_name,
|
573
|
+
input_df=s_df,
|
574
|
+
input_args=input_args,
|
575
|
+
partition_column=partition_column,
|
576
|
+
returns=returns,
|
577
|
+
database_name=database_name,
|
578
|
+
schema_name=schema_name,
|
579
|
+
model_name=model_name,
|
580
|
+
version_name=version_name,
|
581
|
+
statement_params=statement_params,
|
582
|
+
)
|
481
583
|
|
482
584
|
if keep_order:
|
483
585
|
df_res = df_res.sort(
|
@@ -486,7 +588,11 @@ class ModelOperator:
|
|
486
588
|
)
|
487
589
|
|
488
590
|
if not output_with_input_features:
|
489
|
-
|
591
|
+
cols_to_drop = original_cols
|
592
|
+
if partition_column is not None:
|
593
|
+
# don't drop partition column
|
594
|
+
cols_to_drop.remove(partition_column.identifier())
|
595
|
+
df_res = df_res.drop(*cols_to_drop)
|
490
596
|
|
491
597
|
# Get final result
|
492
598
|
if not isinstance(X, dataframe.DataFrame):
|
@@ -497,18 +603,97 @@ class ModelOperator:
|
|
497
603
|
def delete_model_or_version(
|
498
604
|
self,
|
499
605
|
*,
|
606
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
607
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
500
608
|
model_name: sql_identifier.SqlIdentifier,
|
501
609
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
502
610
|
statement_params: Optional[Dict[str, Any]] = None,
|
503
611
|
) -> None:
|
504
612
|
if version_name:
|
505
613
|
self._model_version_client.drop_version(
|
614
|
+
database_name=database_name,
|
615
|
+
schema_name=schema_name,
|
506
616
|
model_name=model_name,
|
507
617
|
version_name=version_name,
|
508
618
|
statement_params=statement_params,
|
509
619
|
)
|
510
620
|
else:
|
511
621
|
self._model_client.drop_model(
|
622
|
+
database_name=database_name,
|
623
|
+
schema_name=schema_name,
|
624
|
+
model_name=model_name,
|
625
|
+
statement_params=statement_params,
|
626
|
+
)
|
627
|
+
|
628
|
+
def rename(
|
629
|
+
self,
|
630
|
+
*,
|
631
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
632
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
633
|
+
model_name: sql_identifier.SqlIdentifier,
|
634
|
+
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
635
|
+
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
636
|
+
new_model_name: sql_identifier.SqlIdentifier,
|
637
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
638
|
+
) -> None:
|
639
|
+
self._model_client.rename(
|
640
|
+
database_name=database_name,
|
641
|
+
schema_name=schema_name,
|
642
|
+
model_name=model_name,
|
643
|
+
new_model_db=new_model_db,
|
644
|
+
new_model_schema=new_model_schema,
|
645
|
+
new_model_name=new_model_name,
|
646
|
+
statement_params=statement_params,
|
647
|
+
)
|
648
|
+
|
649
|
+
# Map indicating in different modes, the path to list and download.
|
650
|
+
# The boolean value indicates if it is a directory,
|
651
|
+
MODEL_FILE_DOWNLOAD_PATTERN = {
|
652
|
+
"minimal": {
|
653
|
+
pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH)
|
654
|
+
/ model_meta.MODEL_METADATA_FILE: False,
|
655
|
+
pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH) / model_env._DEFAULT_ENV_DIR: True,
|
656
|
+
pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH)
|
657
|
+
/ model_runtime.ModelRuntime.RUNTIME_DIR_REL_PATH: True,
|
658
|
+
},
|
659
|
+
"model": {pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH): True},
|
660
|
+
"full": {pathlib.PurePosixPath(os.curdir): True},
|
661
|
+
}
|
662
|
+
|
663
|
+
def download_files(
|
664
|
+
self,
|
665
|
+
*,
|
666
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
667
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
668
|
+
model_name: sql_identifier.SqlIdentifier,
|
669
|
+
version_name: sql_identifier.SqlIdentifier,
|
670
|
+
target_path: pathlib.Path,
|
671
|
+
mode: Literal["full", "model", "minimal"] = "model",
|
672
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
673
|
+
) -> None:
|
674
|
+
for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
|
675
|
+
list_file_res = self._model_version_client.list_file(
|
676
|
+
database_name=database_name,
|
677
|
+
schema_name=schema_name,
|
512
678
|
model_name=model_name,
|
679
|
+
version_name=version_name,
|
680
|
+
file_path=remote_rel_path,
|
681
|
+
is_dir=is_dir,
|
513
682
|
statement_params=statement_params,
|
514
683
|
)
|
684
|
+
file_list = [
|
685
|
+
pathlib.PurePosixPath(*pathlib.PurePosixPath(row.name).parts[2:]) # versions/<version_name>/...
|
686
|
+
for row in list_file_res
|
687
|
+
]
|
688
|
+
for stage_file_path in file_list:
|
689
|
+
local_file_dir = target_path / stage_file_path.parent
|
690
|
+
local_file_dir.mkdir(parents=True, exist_ok=True)
|
691
|
+
self._model_version_client.get_file(
|
692
|
+
database_name=database_name,
|
693
|
+
schema_name=schema_name,
|
694
|
+
model_name=model_name,
|
695
|
+
version_name=version_name,
|
696
|
+
file_path=stage_file_path,
|
697
|
+
target_path=local_file_dir,
|
698
|
+
statement_params=statement_params,
|
699
|
+
)
|
@@ -0,0 +1,34 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
from snowflake.ml._internal.utils import identifier, sql_identifier
|
4
|
+
from snowflake.snowpark import session
|
5
|
+
|
6
|
+
|
7
|
+
class _BaseSQLClient:
|
8
|
+
def __init__(
|
9
|
+
self,
|
10
|
+
session: session.Session,
|
11
|
+
*,
|
12
|
+
database_name: sql_identifier.SqlIdentifier,
|
13
|
+
schema_name: sql_identifier.SqlIdentifier,
|
14
|
+
) -> None:
|
15
|
+
self._session = session
|
16
|
+
self._database_name = database_name
|
17
|
+
self._schema_name = schema_name
|
18
|
+
|
19
|
+
def __eq__(self, __value: object) -> bool:
|
20
|
+
if not isinstance(__value, _BaseSQLClient):
|
21
|
+
return False
|
22
|
+
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
23
|
+
|
24
|
+
def fully_qualified_object_name(
|
25
|
+
self,
|
26
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
27
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
28
|
+
object_name: sql_identifier.SqlIdentifier,
|
29
|
+
) -> str:
|
30
|
+
actual_database_name = database_name or self._database_name
|
31
|
+
actual_schema_name = schema_name or self._schema_name
|
32
|
+
return identifier.get_schema_level_object_identifier(
|
33
|
+
actual_database_name.identifier(), actual_schema_name.identifier(), object_name.identifier()
|
34
|
+
)
|