snowflake-ml-python 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +1 -1
- snowflake/cortex/_extract_answer.py +1 -1
- snowflake/cortex/_sentiment.py +1 -1
- snowflake/cortex/_summarize.py +1 -1
- snowflake/cortex/_translate.py +1 -1
- snowflake/ml/_internal/env_utils.py +68 -6
- snowflake/ml/_internal/file_utils.py +34 -4
- snowflake/ml/_internal/telemetry.py +79 -91
- snowflake/ml/_internal/utils/identifier.py +78 -72
- snowflake/ml/_internal/utils/retryable_http.py +16 -4
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/model/_api.py +21 -14
- snowflake/ml/model/_client/model/model_impl.py +176 -0
- snowflake/ml/model/_client/model/model_method_info.py +19 -0
- snowflake/ml/model/_client/model/model_version_impl.py +291 -0
- snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
- snowflake/ml/model/_client/ops/model_ops.py +308 -0
- snowflake/ml/model/_client/sql/model.py +75 -0
- snowflake/ml/model/_client/sql/model_version.py +213 -0
- snowflake/ml/model/_client/sql/stage.py +40 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
- snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +31 -9
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
- snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/model_signature.py +108 -53
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
- snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
- snowflake/ml/modeling/_internal/model_specifications.py +146 -0
- snowflake/ml/modeling/_internal/model_trainer.py +13 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
- snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
- snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
- snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +108 -135
- snowflake/ml/modeling/cluster/affinity_propagation.py +106 -135
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +106 -135
- snowflake/ml/modeling/cluster/birch.py +106 -135
- snowflake/ml/modeling/cluster/bisecting_k_means.py +106 -135
- snowflake/ml/modeling/cluster/dbscan.py +106 -135
- snowflake/ml/modeling/cluster/feature_agglomeration.py +106 -135
- snowflake/ml/modeling/cluster/k_means.py +105 -135
- snowflake/ml/modeling/cluster/mean_shift.py +106 -135
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +105 -135
- snowflake/ml/modeling/cluster/optics.py +106 -135
- snowflake/ml/modeling/cluster/spectral_biclustering.py +106 -135
- snowflake/ml/modeling/cluster/spectral_clustering.py +106 -135
- snowflake/ml/modeling/cluster/spectral_coclustering.py +106 -135
- snowflake/ml/modeling/compose/column_transformer.py +106 -135
- snowflake/ml/modeling/compose/transformed_target_regressor.py +108 -135
- snowflake/ml/modeling/covariance/elliptic_envelope.py +106 -135
- snowflake/ml/modeling/covariance/empirical_covariance.py +99 -128
- snowflake/ml/modeling/covariance/graphical_lasso.py +106 -135
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +106 -135
- snowflake/ml/modeling/covariance/ledoit_wolf.py +104 -133
- snowflake/ml/modeling/covariance/min_cov_det.py +106 -135
- snowflake/ml/modeling/covariance/oas.py +99 -128
- snowflake/ml/modeling/covariance/shrunk_covariance.py +103 -132
- snowflake/ml/modeling/decomposition/dictionary_learning.py +106 -135
- snowflake/ml/modeling/decomposition/factor_analysis.py +106 -135
- snowflake/ml/modeling/decomposition/fast_ica.py +106 -135
- snowflake/ml/modeling/decomposition/incremental_pca.py +106 -135
- snowflake/ml/modeling/decomposition/kernel_pca.py +106 -135
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +106 -135
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +106 -135
- snowflake/ml/modeling/decomposition/pca.py +106 -135
- snowflake/ml/modeling/decomposition/sparse_pca.py +106 -135
- snowflake/ml/modeling/decomposition/truncated_svd.py +106 -135
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +108 -135
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +108 -135
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/bagging_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/bagging_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/isolation_forest.py +106 -135
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/stacking_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/voting_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/voting_regressor.py +108 -135
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +101 -128
- snowflake/ml/modeling/feature_selection/select_fdr.py +99 -126
- snowflake/ml/modeling/feature_selection/select_fpr.py +99 -126
- snowflake/ml/modeling/feature_selection/select_fwe.py +99 -126
- snowflake/ml/modeling/feature_selection/select_k_best.py +100 -127
- snowflake/ml/modeling/feature_selection/select_percentile.py +99 -126
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +106 -135
- snowflake/ml/modeling/feature_selection/variance_threshold.py +95 -124
- snowflake/ml/modeling/framework/base.py +83 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +108 -135
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +108 -135
- snowflake/ml/modeling/impute/iterative_imputer.py +106 -135
- snowflake/ml/modeling/impute/knn_imputer.py +106 -135
- snowflake/ml/modeling/impute/missing_indicator.py +106 -135
- snowflake/ml/modeling/impute/simple_imputer.py +9 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +96 -125
- snowflake/ml/modeling/kernel_approximation/nystroem.py +106 -135
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +106 -135
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +105 -134
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +103 -132
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +108 -135
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +90 -118
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +90 -118
- snowflake/ml/modeling/linear_model/ard_regression.py +108 -135
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +108 -135
- snowflake/ml/modeling/linear_model/elastic_net.py +108 -135
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +108 -135
- snowflake/ml/modeling/linear_model/gamma_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/huber_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/lars.py +108 -135
- snowflake/ml/modeling/linear_model/lars_cv.py +108 -135
- snowflake/ml/modeling/linear_model/lasso.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_cv.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_lars.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +108 -135
- snowflake/ml/modeling/linear_model/linear_regression.py +108 -135
- snowflake/ml/modeling/linear_model/logistic_regression.py +108 -135
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +108 -135
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +108 -135
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +108 -135
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +107 -135
- snowflake/ml/modeling/linear_model/perceptron.py +107 -135
- snowflake/ml/modeling/linear_model/poisson_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/ransac_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/ridge.py +108 -135
- snowflake/ml/modeling/linear_model/ridge_classifier.py +108 -135
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +108 -135
- snowflake/ml/modeling/linear_model/ridge_cv.py +108 -135
- snowflake/ml/modeling/linear_model/sgd_classifier.py +108 -135
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +106 -135
- snowflake/ml/modeling/linear_model/sgd_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +108 -135
- snowflake/ml/modeling/manifold/isomap.py +106 -135
- snowflake/ml/modeling/manifold/mds.py +106 -135
- snowflake/ml/modeling/manifold/spectral_embedding.py +106 -135
- snowflake/ml/modeling/manifold/tsne.py +106 -135
- snowflake/ml/modeling/metrics/classification.py +196 -55
- snowflake/ml/modeling/metrics/correlation.py +4 -2
- snowflake/ml/modeling/metrics/covariance.py +7 -4
- snowflake/ml/modeling/metrics/ranking.py +32 -16
- snowflake/ml/modeling/metrics/regression.py +60 -32
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +106 -135
- snowflake/ml/modeling/mixture/gaussian_mixture.py +106 -135
- snowflake/ml/modeling/model_selection/grid_search_cv.py +91 -148
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +93 -154
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +105 -132
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +108 -135
- snowflake/ml/modeling/multiclass/output_code_classifier.py +108 -135
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +108 -135
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +108 -135
- snowflake/ml/modeling/naive_bayes/complement_nb.py +108 -135
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +98 -125
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +107 -134
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +108 -135
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +108 -135
- snowflake/ml/modeling/neighbors/kernel_density.py +106 -135
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +106 -135
- snowflake/ml/modeling/neighbors/nearest_centroid.py +108 -135
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +106 -135
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +108 -135
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +108 -135
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +108 -135
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +106 -135
- snowflake/ml/modeling/neural_network/mlp_classifier.py +108 -135
- snowflake/ml/modeling/neural_network/mlp_regressor.py +108 -135
- snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
- snowflake/ml/modeling/preprocessing/binarizer.py +25 -8
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +9 -4
- snowflake/ml/modeling/preprocessing/label_encoder.py +31 -11
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +27 -9
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +42 -14
- snowflake/ml/modeling/preprocessing/normalizer.py +9 -4
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +26 -10
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +37 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +106 -135
- snowflake/ml/modeling/preprocessing/robust_scaler.py +39 -13
- snowflake/ml/modeling/preprocessing/standard_scaler.py +36 -12
- snowflake/ml/modeling/semi_supervised/label_propagation.py +108 -135
- snowflake/ml/modeling/semi_supervised/label_spreading.py +108 -135
- snowflake/ml/modeling/svm/linear_svc.py +108 -135
- snowflake/ml/modeling/svm/linear_svr.py +108 -135
- snowflake/ml/modeling/svm/nu_svc.py +108 -135
- snowflake/ml/modeling/svm/nu_svr.py +108 -135
- snowflake/ml/modeling/svm/svc.py +108 -135
- snowflake/ml/modeling/svm/svr.py +108 -135
- snowflake/ml/modeling/tree/decision_tree_classifier.py +108 -135
- snowflake/ml/modeling/tree/decision_tree_regressor.py +108 -135
- snowflake/ml/modeling/tree/extra_tree_classifier.py +108 -135
- snowflake/ml/modeling/tree/extra_tree_regressor.py +108 -135
- snowflake/ml/modeling/xgboost/xgb_classifier.py +108 -136
- snowflake/ml/modeling/xgboost/xgb_regressor.py +108 -136
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +108 -136
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +108 -136
- snowflake/ml/registry/model_registry.py +2 -0
- snowflake/ml/registry/registry.py +215 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +34 -1
- snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
- snowflake_ml_python-1.1.0.dist-info/RECORD +0 -331
- {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,308 @@
|
|
1
|
+
import pathlib
|
2
|
+
import tempfile
|
3
|
+
from typing import Any, Dict, List, Optional, Union, cast
|
4
|
+
|
5
|
+
import yaml
|
6
|
+
|
7
|
+
from snowflake.ml._internal.utils import sql_identifier
|
8
|
+
from snowflake.ml.model import model_signature, type_hints
|
9
|
+
from snowflake.ml.model._client.ops import metadata_ops
|
10
|
+
from snowflake.ml.model._client.sql import (
|
11
|
+
model as model_sql,
|
12
|
+
model_version as model_version_sql,
|
13
|
+
stage as stage_sql,
|
14
|
+
)
|
15
|
+
from snowflake.ml.model._model_composer import model_composer
|
16
|
+
from snowflake.ml.model._model_composer.model_manifest import (
|
17
|
+
model_manifest,
|
18
|
+
model_manifest_schema,
|
19
|
+
)
|
20
|
+
from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema
|
21
|
+
from snowflake.ml.model._signatures import snowpark_handler
|
22
|
+
from snowflake.snowpark import dataframe, session
|
23
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
24
|
+
|
25
|
+
|
26
|
+
class ModelOperator:
    """High-level operations on models and model versions living in one database schema.

    The operator bundles a stage SQL client, a model SQL client, a model-version SQL client and a
    metadata operator, all constructed with the same session, database name and schema name.
    """

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Build the underlying SQL clients.

        Args:
            session: Session handed to every underlying client.
            database_name: Database the clients operate in.
            schema_name: Schema the clients operate in.
        """
        # Ideally, we should only keep the session object inside the clients. However, some components other than
        # the clients require a session object, like ModelComposer and SnowparkDataFrameHandler. We currently cannot
        # refactor them all, but we should avoid using the _session object here unless there is no other choice.
        self._session = session
        self._stage_client = stage_sql.StageSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )
        self._model_client = model_sql.ModelSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )
        self._model_version_client = model_version_sql.ModelVersionSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )
        self._metadata_ops = metadata_ops.MetadataOperator(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )

    def __eq__(self, __value: object) -> bool:
        """Compare two operators by their stage, model and model-version clients.

        The metadata operator and the session are not part of the comparison.
        """
        if not isinstance(__value, ModelOperator):
            return False
        return (
            self._stage_client == __value._stage_client
            and self._model_client == __value._model_client
            and self._model_version_client == __value._model_version_client
        )

    def prepare_model_stage_path(self, *, statement_params: Optional[Dict[str, Any]] = None) -> str:
        """Create a stage with a random name and return the path of its ``model`` sub-directory.

        Args:
            statement_params: Optional statement parameters forwarded to the stage client.

        Returns:
            A stage path of the form ``@<fully qualified stage name>/model``.
        """
        stage_name = sql_identifier.SqlIdentifier(
            snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
        )
        self._stage_client.create_tmp_stage(stage_name=stage_name, statement_params=statement_params)
        return f"@{self._stage_client.fully_qualified_stage_name(stage_name)}/model"

    def create_from_stage(
        self,
        composed_model: model_composer.ModelComposer,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Register the composed model from its stage as a new model or as a new version.

        If the model does not exist yet it is created together with the version; if the model exists,
        the version is added to it.

        Args:
            composed_model: Composer whose ``stage_path`` points at the model files.
            model_name: Name of the model.
            version_name: Name of the version to create.
            statement_params: Optional statement parameters forwarded to the SQL clients.

        Raises:
            ValueError: If the model already has a version with the given name.
        """
        stage_path = str(composed_model.stage_path)
        if self.validate_existence(
            model_name=model_name,
            statement_params=statement_params,
        ):
            if self.validate_existence(
                model_name=model_name,
                version_name=version_name,
                statement_params=statement_params,
            ):
                raise ValueError(
                    f"Model {self._model_version_client.fully_qualified_model_name(model_name)} "
                    f"version {version_name} already existed."
                )
            else:
                self._model_version_client.add_version_from_stage(
                    stage_path=stage_path,
                    model_name=model_name,
                    version_name=version_name,
                    statement_params=statement_params,
                )
        else:
            self._model_version_client.create_from_stage(
                stage_path=stage_path,
                model_name=model_name,
                version_name=version_name,
                statement_params=statement_params,
            )

    def list_models_or_versions(
        self,
        *,
        model_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> List[sql_identifier.SqlIdentifier]:
        """List version names of a model, or model names in the schema when no model is given.

        Args:
            model_name: When provided, list the versions of this model; otherwise list models.
            statement_params: Optional statement parameters forwarded to the model client.

        Returns:
            The ``name`` of every returned row, wrapped as a case-sensitive ``SqlIdentifier``.
        """
        if model_name:
            res = self._model_client.show_versions(
                model_name=model_name,
                statement_params=statement_params,
            )
        else:
            res = self._model_client.show_models(
                statement_params=statement_params,
            )
        return [sql_identifier.SqlIdentifier(row.name, case_sensitive=True) for row in res]

    def validate_existence(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Tell whether a model, or a specific version of it, exists.

        Args:
            model_name: Name of the model.
            version_name: When provided, check this version instead of the model itself.
            statement_params: Optional statement parameters forwarded to the model client.

        Returns:
            True when the SHOW query returns exactly one row, False otherwise.
        """
        if version_name:
            res = self._model_client.show_versions(
                model_name=model_name,
                version_name=version_name,
                statement_params=statement_params,
            )
        else:
            res = self._model_client.show_models(
                model_name=model_name,
                statement_params=statement_params,
            )
        return len(res) == 1

    def get_comment(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Return the comment of a model, or of one of its versions.

        Args:
            model_name: Name of the model.
            version_name: When provided, read the comment of this version instead of the model.
            statement_params: Optional statement parameters forwarded to the model client.

        Returns:
            The ``comment`` field of the single row returned by the SHOW query.
        """
        if version_name:
            res = self._model_client.show_versions(
                model_name=model_name,
                version_name=version_name,
                statement_params=statement_params,
            )
        else:
            res = self._model_client.show_models(
                model_name=model_name,
                statement_params=statement_params,
            )
        # Sanity check: the lookup is expected to match exactly one object.
        assert len(res) == 1, f"Expected exactly one matching row, got {len(res)}."
        return cast(str, res[0].comment)

    def set_comment(
        self,
        *,
        comment: str,
        model_name: sql_identifier.SqlIdentifier,
        version_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Set the comment of a model, or of one of its versions.

        Args:
            comment: Comment text to store.
            model_name: Name of the model.
            version_name: When provided, set the comment on this version instead of the model.
            statement_params: Optional statement parameters forwarded to the SQL clients.
        """
        if version_name:
            self._model_version_client.set_comment(
                comment=comment,
                model_name=model_name,
                version_name=version_name,
                statement_params=statement_params,
            )
        else:
            self._model_client.set_comment(
                comment=comment,
                model_name=model_name,
                statement_params=statement_params,
            )

    def get_model_version_manifest(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> model_manifest_schema.ModelManifestDict:
        """Download the manifest file of a model version and return its loaded content.

        Args:
            model_name: Name of the model.
            version_name: Name of the version.
            statement_params: Optional statement parameters forwarded to the model-version client.

        Returns:
            The manifest as loaded by ``ModelManifest.load``.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            self._model_version_client.get_file(
                model_name=model_name,
                version_name=version_name,
                file_path=pathlib.PurePosixPath(model_manifest.ModelManifest.MANIFEST_FILE_REL_PATH),
                target_path=pathlib.Path(tmpdir),
                statement_params=statement_params,
            )
            # Load while the temporary directory still exists.
            mm = model_manifest.ModelManifest(pathlib.Path(tmpdir))
            return mm.load()

    def get_model_version_native_packing_meta(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> model_meta_schema.ModelMetadataDict:
        """Download the model metadata file of a model version and return its validated content.

        Args:
            model_name: Name of the model.
            version_name: Name of the version.
            statement_params: Optional statement parameters forwarded to the model-version client.

        Returns:
            The YAML content of the metadata file after ``ModelMetadata._validate_model_metadata``.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            model_meta_file_path = self._model_version_client.get_file(
                model_name=model_name,
                version_name=version_name,
                file_path=pathlib.PurePosixPath(
                    model_composer.ModelComposer.MODEL_DIR_REL_PATH, model_meta.MODEL_METADATA_FILE
                ),
                target_path=pathlib.Path(tmpdir),
                statement_params=statement_params,
            )
            with open(model_meta_file_path, encoding="utf-8") as f:
                raw_model_meta = yaml.safe_load(f)
            return model_meta.ModelMetadata._validate_model_metadata(raw_model_meta)

    def invoke_method(
        self,
        *,
        method_name: sql_identifier.SqlIdentifier,
        signature: model_signature.ModelSignature,
        X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, str]] = None,
    ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
        """Invoke a method of a model version on the given data.

        Local (non-Snowpark) input is validated, converted to a Snowpark DataFrame, and the result is
        converted back to local data containing only the output columns. Snowpark DataFrame input is
        validated and the resulting DataFrame is returned as-is, input columns included.

        Args:
            method_name: Name of the model method to call.
            signature: Signature describing the input and output features of the method.
            X: Input data, either local data or a Snowpark DataFrame.
            model_name: Name of the model.
            version_name: Name of the version.
            statement_params: Optional statement parameters forwarded to the model-version client.

        Returns:
            A Snowpark DataFrame when ``X`` is one, otherwise local data built from the output features.
        """
        identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED

        # Validate and prepare input
        if not isinstance(X, dataframe.DataFrame):
            keep_order = True
            output_with_input_features = False
            df = model_signature._convert_and_validate_local_data(X, signature.inputs)
            s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(self._session, df, keep_order=keep_order)
        else:
            keep_order = False
            output_with_input_features = True
            identifier_rule = model_signature._validate_snowpark_data(X, signature.inputs)
            s_df = X

        # Work on a copy so that the removals below can never touch a list owned by the DataFrame.
        original_cols = list(s_df.columns)

        # Compose input and output names
        input_args = [
            identifier_rule.get_sql_identifier_from_feature(input_feature.name) for input_feature in signature.inputs
        ]

        returns = []
        for output_feature in signature.outputs:
            output_name = identifier_rule.get_sql_identifier_from_feature(output_feature.name)
            returns.append((output_feature.name, output_feature.as_snowpark_type(), output_name))
            # Avoid removing output cols when output_with_input_features is False
            if output_name in original_cols:
                original_cols.remove(output_name)

        df_res = self._model_version_client.invoke_method(
            method_name=method_name,
            input_df=s_df,
            input_args=input_args,
            returns=returns,
            model_name=model_name,
            version_name=version_name,
            statement_params=statement_params,
        )

        if keep_order:
            # NOTE(review): "_ID" is presumably the ordering column added by convert_from_df(keep_order=True).
            df_res = df_res.sort(
                "_ID",
                ascending=True,
            )

        if not output_with_input_features:
            df_res = df_res.drop(*original_cols)

        # Get final result
        if not isinstance(X, dataframe.DataFrame):
            return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs)
        else:
            return df_res

    def delete_model_or_version(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Drop a model.

        Args:
            model_name: Name of the model to drop.
            version_name: Must be left as None; deleting a single version is not supported yet.
            statement_params: Optional statement parameters forwarded to the model client.

        Raises:
            NotImplementedError: If ``version_name`` is provided.
        """
        # TODO: Delete version is not supported yet.
        # Refuse instead of silently dropping the whole model when the caller only asked for one version.
        if version_name is not None:
            raise NotImplementedError(
                "Deleting a single model version is not supported yet; omit version_name to drop the whole model."
            )
        self._model_client.drop_model(
            model_name=model_name,
            statement_params=statement_params,
        )
|
@@ -0,0 +1,75 @@
|
|
1
|
+
from typing import Any, Dict, List, Optional
|
2
|
+
|
3
|
+
from snowflake.ml._internal.utils import identifier, sql_identifier
|
4
|
+
from snowflake.snowpark import row, session
|
5
|
+
|
6
|
+
|
7
|
+
class ModelSQLClient:
    """Thin SQL wrapper for model-level statements scoped to one database and schema."""

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Remember the session and the database/schema every statement is issued against."""
        self._session = session
        self._database_name = database_name
        self._schema_name = schema_name

    def __eq__(self, __value: object) -> bool:
        """Two clients are equal when they target the same database and schema."""
        if isinstance(__value, ModelSQLClient):
            same_database = self._database_name == __value._database_name
            return same_database and self._schema_name == __value._schema_name
        return False

    def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
        """Return the schema-level object identifier built from database, schema and model name."""
        database = self._database_name.identifier()
        schema = self._schema_name.identifier()
        return identifier.get_schema_level_object_identifier(database, schema, model_name.identifier())

    def show_models(
        self,
        *,
        model_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> List[row.Row]:
        """Run ``SHOW MODELS`` in the client's schema, optionally with a ``LIKE`` clause on the model name.

        Args:
            model_name: When provided, its resolved form is used as the ``LIKE`` pattern.
            statement_params: Optional statement parameters passed to ``collect``.

        Returns:
            The collected result rows.
        """
        schema_fqn = ".".join((self._database_name.identifier(), self._schema_name.identifier()))
        name_filter = f" LIKE '{model_name.resolved()}'" if model_name else ""
        query = f"SHOW MODELS{name_filter} IN SCHEMA {schema_fqn}"
        return self._session.sql(query).collect(statement_params=statement_params)

    def show_versions(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> List[row.Row]:
        """Run ``SHOW VERSIONS`` for a model, optionally with a ``LIKE`` clause on the version name.

        Args:
            model_name: Model whose versions are listed.
            version_name: When provided, its resolved form is used as the ``LIKE`` pattern.
            statement_params: Optional statement parameters passed to ``collect``.

        Returns:
            The collected result rows.
        """
        version_filter = f" LIKE '{version_name.resolved()}'" if version_name else ""
        query = f"SHOW VERSIONS{version_filter} IN MODEL {self.fully_qualified_model_name(model_name)}"
        return self._session.sql(query).collect(statement_params=statement_params)

    def set_comment(
        self,
        *,
        comment: str,
        model_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Issue ``COMMENT ON MODEL`` with the comment wrapped in ``$$`` delimiters.

        Args:
            comment: Comment text, embedded verbatim between ``$$`` markers.
            model_name: Model to comment on.
            statement_params: Optional statement parameters passed to ``collect``.
        """
        target = self.fully_qualified_model_name(model_name)
        statement = f"COMMENT ON MODEL {target} IS $${comment}$$"
        self._session.sql(statement).collect(statement_params=statement_params)

    def drop_model(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Issue ``DROP MODEL`` for the given model.

        Args:
            model_name: Model to drop.
            statement_params: Optional statement parameters passed to ``collect``.
        """
        statement = f"DROP MODEL {self.fully_qualified_model_name(model_name)}"
        self._session.sql(statement).collect(statement_params=statement_params)
|
@@ -0,0 +1,213 @@
|
|
1
|
+
import json
|
2
|
+
import pathlib
|
3
|
+
import textwrap
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple
|
5
|
+
from urllib.parse import ParseResult
|
6
|
+
|
7
|
+
from snowflake.ml._internal.utils import identifier, sql_identifier
|
8
|
+
from snowflake.snowpark import dataframe, functions as F, session, types as spt
|
9
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
10
|
+
|
11
|
+
|
12
|
+
def _normalize_url_for_sql(url: str) -> str:
|
13
|
+
if url.startswith("'") and url.endswith("'"):
|
14
|
+
url = url[1:-1]
|
15
|
+
url = url.replace("'", "\\'")
|
16
|
+
return f"'{url}'"
|
17
|
+
|
18
|
+
|
19
|
+
class ModelVersionSQLClient:
    """Builds and runs model-version SQL statements scoped to one database and schema."""

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Store the session and the database/schema every statement is qualified with.

        Args:
            session: Session used to execute the generated SQL.
            database_name: Database that contains the models.
            schema_name: Schema that contains the models.
        """
        self._session = session
        self._database_name = database_name
        self._schema_name = schema_name

    def __eq__(self, __value: object) -> bool:
        """Two clients are equal when their database and schema names match; the session is not compared."""
        if not isinstance(__value, ModelVersionSQLClient):
            return False
        return self._database_name == __value._database_name and self._schema_name == __value._schema_name

    @staticmethod
    def _sql_string_literal(text: str) -> str:
        """Quote ``text`` as a SQL string constant that the text itself cannot terminate.

        The ``$$...$$`` form is kept whenever it is safe, so the emitted SQL is unchanged for
        those inputs. Text that contains ``$$`` or ends with ``$`` would close the dollar-quoted
        constant early, so it is emitted as a single-quoted constant instead, with backslashes
        and single quotes backslash-escaped.

        Args:
            text: Raw text to embed in a SQL statement.

        Returns:
            The quoted SQL string constant.
        """
        if "$$" not in text and not text.endswith("$"):
            return f"$${text}$$"
        # Escape backslashes first so the backslashes added for quotes are not doubled again.
        escaped = text.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
        """Return the database.schema.model identifier for ``model_name``."""
        return identifier.get_schema_level_object_identifier(
            self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
        )

    def create_from_stage(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        stage_path: str,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``CREATE MODEL ... WITH VERSION ... FROM <stage_path>``.

        Also records ``version_name`` on ``self._version_name``.

        Args:
            model_name: Name of the model to create.
            version_name: Name of the version created with the model.
            stage_path: Text placed verbatim after ``FROM``; it is not quoted or escaped here.
            statement_params: Passed through to ``collect``.
        """
        self._version_name = version_name
        self._session.sql(
            f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}"
            f" FROM {stage_path}"
        ).collect(statement_params=statement_params)

    # TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
    def add_version_from_stage(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        stage_path: str,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``ALTER MODEL ... ADD VERSION ... FROM <stage_path>``.

        Also records ``version_name`` on ``self._version_name``.

        Args:
            model_name: Name of the model to alter.
            version_name: Name of the version to add.
            stage_path: Text placed verbatim after ``FROM``; it is not quoted or escaped here.
            statement_params: Passed through to ``collect``.
        """
        self._version_name = version_name
        self._session.sql(
            f"ALTER MODEL {self.fully_qualified_model_name(model_name)} ADD VERSION {version_name.identifier()}"
            f" FROM {stage_path}"
        ).collect(statement_params=statement_params)

    def set_default_version(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``ALTER MODEL ... SET DEFAULT_VERSION = <version_name>``.

        Args:
            model_name: Name of the model to alter.
            version_name: Version to make the default.
            statement_params: Passed through to ``collect``.
        """
        self._session.sql(
            f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
            f"SET DEFAULT_VERSION = {version_name.identifier()}"
        ).collect(statement_params=statement_params)

    def get_default_version(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Return the ``"name"`` of the row flagged ``"is_default_version"`` in ``SHOW VERSIONS IN MODEL``.

        Args:
            model_name: Name of the model to inspect.
            statement_params: Passed through to ``collect``.

        Returns:
            The first column of the first matching row.

        NOTE(review): the result is indexed with ``[0][0]`` without a length check, so an empty
        result presumably surfaces as ``IndexError`` -- confirm against callers.
        """
        # TODO: Replace SHOW with DESC when available.
        default_version: str = (
            self._session.sql(f"SHOW VERSIONS IN MODEL {self.fully_qualified_model_name(model_name)}")
            .filter('"is_default_version" = TRUE')[['"name"']]
            .collect(statement_params=statement_params)[0][0]
        )
        return default_version

    def get_file(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        file_path: pathlib.PurePosixPath,
        target_path: pathlib.Path,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> pathlib.Path:
        """Run a ``GET`` from a model version's files to a local directory.

        The source URL is ``snow://model/<qualified model>/versions/<resolved version>/<file_path>``
        and the destination is a ``file://`` URL built from the absolute form of ``target_path``.

        Args:
            model_name: Name of the model that owns the file.
            version_name: Version that owns the file.
            file_path: Path of the file relative to the version root.
            target_path: Local directory passed to ``GET`` as the destination.
            statement_params: Passed through to ``collect``.

        Returns:
            ``target_path / file_path.name``.
        """
        stage_location = pathlib.PurePosixPath(
            self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path
        ).as_posix()
        stage_location_url = ParseResult(
            scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
        ).geturl()
        local_location = target_path.absolute().as_posix()
        local_location_url = ParseResult(
            scheme="file", netloc="", path=local_location, params="", query="", fragment=""
        ).geturl()

        self._session.sql(
            f"GET {_normalize_url_for_sql(stage_location_url)} {_normalize_url_for_sql(local_location_url)}"
        ).collect(statement_params=statement_params)
        return target_path / file_path.name

    def set_comment(
        self,
        *,
        comment: str,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``ALTER MODEL ... MODIFY VERSION ... SET COMMENT=<comment>``.

        Args:
            comment: Comment text; quoted so that it cannot terminate its own string constant.
            model_name: Name of the model to alter.
            version_name: Version whose comment is set.
            statement_params: Passed through to ``collect``.
        """
        comment_sql = (
            f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
            f"MODIFY VERSION {version_name.identifier()} SET COMMENT={self._sql_string_literal(comment)}"
        )
        self._session.sql(comment_sql).collect(statement_params=statement_params)

    def invoke_method(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        method_name: sql_identifier.SqlIdentifier,
        input_df: dataframe.DataFrame,
        input_args: List[sql_identifier.SqlIdentifier],
        returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> dataframe.DataFrame:
        """Build a DataFrame that calls a model-version method on every row of ``input_df``.

        ``input_df`` is first saved to a randomly named temporary table in this client's database
        and schema. The returned DataFrame selects every column of that table plus the method
        result, then expands the result into one column per entry of ``returns``.

        Args:
            model_name: Name of the model.
            version_name: Version whose method is called.
            method_name: Method to call.
            input_df: Rows to feed to the method.
            input_args: Column identifiers passed, in order, as the method arguments.
            returns: ``(key in the result object, type to cast to, output column name)`` triples.
            statement_params: Passed to ``save_as_table`` and, when truthy, attached to the result.

        Returns:
            The DataFrame described above, without the intermediate result column.
        """
        tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
        intermediate_table_name = identifier.get_schema_level_object_identifier(
            self._database_name.identifier(),
            self._schema_name.identifier(),
            tmp_table_name,
        )
        input_df.write.save_as_table(  # type: ignore[call-overload]
            table_name=intermediate_table_name,
            mode="errorifexists",
            table_type="temporary",
            statement_params=statement_params,
        )

        intermediate_obj_name = "TMP_RESULT"

        module_version_alias = "MODEL_VERSION_ALIAS"
        model_version_alias_sql = (
            f"WITH {module_version_alias} AS "
            f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
        )

        # The identifiers are joined as-is, exactly as the element-by-element copy did before.
        args_sql = ", ".join(input_args)

        sql = textwrap.dedent(
            f"""{model_version_alias_sql}
            SELECT *,
                {module_version_alias}!{method_name.identifier()}({args_sql}) AS {intermediate_obj_name}
            FROM {intermediate_table_name}"""
        )

        output_df = self._session.sql(sql)

        # Prepare the output: one cast column per declared return value.
        output_cols = []
        output_names = []
        for output_name, output_type, output_col_name in returns:
            output_cols.append(F.col(intermediate_obj_name)[output_name].astype(output_type))
            output_names.append(output_col_name)

        output_df = output_df.with_columns(
            col_names=output_names,
            values=output_cols,
        ).drop(intermediate_obj_name)

        if statement_params:
            output_df._statement_params = statement_params  # type: ignore[assignment]

        return output_df

    def set_metadata(
        self,
        metadata_dict: Dict[str, Any],
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``ALTER MODEL ... MODIFY VERSION ... SET METADATA=<json>``.

        Args:
            metadata_dict: Serialized with ``json.dumps``, then quoted so that it cannot terminate
                its own string constant.
            model_name: Name of the model to alter.
            version_name: Version whose metadata is set.
            statement_params: Passed through to ``collect``.
        """
        json_metadata = json.dumps(metadata_dict)
        sql = (
            f"ALTER MODEL {self.fully_qualified_model_name(model_name)} MODIFY VERSION {version_name.identifier()}"
            f" SET METADATA={self._sql_string_literal(json_metadata)}"
        )
        self._session.sql(sql).collect(statement_params=statement_params)
|
@@ -0,0 +1,40 @@
|
|
1
|
+
from typing import Any, Dict, Optional
|
2
|
+
|
3
|
+
from snowflake.ml._internal.utils import identifier, sql_identifier
|
4
|
+
from snowflake.snowpark import session
|
5
|
+
|
6
|
+
|
7
|
+
class StageSQLClient:
    """Runs stage-related SQL statements scoped to one database and schema."""

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Keep the session and the database/schema used to qualify stage names."""
        self._session = session
        self._database_name = database_name
        self._schema_name = schema_name

    def __eq__(self, __value: object) -> bool:
        """Compare by database and schema name; any other type is unequal."""
        if isinstance(__value, StageSQLClient):
            return self._database_name == __value._database_name and self._schema_name == __value._schema_name
        return False

    def fully_qualified_stage_name(
        self,
        stage_name: sql_identifier.SqlIdentifier,
    ) -> str:
        """Return the database.schema.stage identifier for ``stage_name``."""
        database_part = self._database_name.identifier()
        schema_part = self._schema_name.identifier()
        stage_part = stage_name.identifier()
        return identifier.get_schema_level_object_identifier(database_part, schema_part, stage_part)

    def create_tmp_stage(
        self,
        *,
        stage_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``CREATE TEMPORARY STAGE`` for the fully qualified ``stage_name``."""
        create_sql = f"CREATE TEMPORARY STAGE {self.fully_qualified_stage_name(stage_name)}"
        self._session.sql(create_sql).collect(statement_params=statement_params)
|
@@ -4,7 +4,6 @@ import posixpath
|
|
4
4
|
from string import Template
|
5
5
|
|
6
6
|
import importlib_resources
|
7
|
-
import yaml
|
8
7
|
|
9
8
|
from snowflake import snowpark
|
10
9
|
from snowflake.ml._internal import file_utils
|
@@ -180,7 +179,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
180
179
|
assert self.artifact_stage_location.startswith("@")
|
181
180
|
normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@"))
|
182
181
|
(db, schema, stage, path) = identifier.parse_schema_level_object_identifier(normed_artifact_stage_path)
|
183
|
-
content = Template(spec_template).
|
182
|
+
content = Template(spec_template).safe_substitute(
|
184
183
|
{
|
185
184
|
"base_image": base_image,
|
186
185
|
"container_name": constants.KANIKO_CONTAINER_NAME,
|
@@ -188,10 +187,10 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
188
187
|
# Remove @ in the beginning, append "/" to denote root directory.
|
189
188
|
"script_path": "/"
|
190
189
|
+ posixpath.normpath(identifier.remove_prefix(kaniko_shell_script_stage_location, "@")),
|
190
|
+
"mounted_token_path": constants.SPCS_MOUNTED_TOKEN_PATH,
|
191
191
|
}
|
192
192
|
)
|
193
|
-
|
194
|
-
yaml.dump(content_dict, spec_file)
|
193
|
+
spec_file.write(content)
|
195
194
|
spec_file.seek(0)
|
196
195
|
logger.debug(f"Kaniko job spec file: \n\n {spec_file.read()}")
|
197
196
|
|