snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -4,7 +4,11 @@ import textwrap
|
|
4
4
|
from typing import Any, Dict, List, Optional, Tuple
|
5
5
|
from urllib.parse import ParseResult
|
6
6
|
|
7
|
-
from snowflake.ml._internal.utils import
|
7
|
+
from snowflake.ml._internal.utils import (
|
8
|
+
identifier,
|
9
|
+
query_result_checker,
|
10
|
+
sql_identifier,
|
11
|
+
)
|
8
12
|
from snowflake.snowpark import dataframe, functions as F, session, types as spt
|
9
13
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
10
14
|
|
@@ -46,11 +50,14 @@ class ModelVersionSQLClient:
|
|
46
50
|
stage_path: str,
|
47
51
|
statement_params: Optional[Dict[str, Any]] = None,
|
48
52
|
) -> None:
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
53
|
+
query_result_checker.SqlResultValidator(
|
54
|
+
self._session,
|
55
|
+
(
|
56
|
+
f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}"
|
57
|
+
f" FROM {stage_path}"
|
58
|
+
),
|
59
|
+
statement_params=statement_params,
|
60
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
54
61
|
|
55
62
|
# TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
|
56
63
|
def add_version_from_stage(
|
@@ -61,11 +68,14 @@ class ModelVersionSQLClient:
|
|
61
68
|
stage_path: str,
|
62
69
|
statement_params: Optional[Dict[str, Any]] = None,
|
63
70
|
) -> None:
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
71
|
+
query_result_checker.SqlResultValidator(
|
72
|
+
self._session,
|
73
|
+
(
|
74
|
+
f"ALTER MODEL {self.fully_qualified_model_name(model_name)} ADD VERSION {version_name.identifier()}"
|
75
|
+
f" FROM {stage_path}"
|
76
|
+
),
|
77
|
+
statement_params=statement_params,
|
78
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
69
79
|
|
70
80
|
def set_default_version(
|
71
81
|
self,
|
@@ -74,24 +84,14 @@ class ModelVersionSQLClient:
|
|
74
84
|
version_name: sql_identifier.SqlIdentifier,
|
75
85
|
statement_params: Optional[Dict[str, Any]] = None,
|
76
86
|
) -> None:
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
model_name: sql_identifier.SqlIdentifier,
|
86
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
87
|
-
) -> str:
|
88
|
-
# TODO: Replace SHOW with DESC when available.
|
89
|
-
default_version: str = (
|
90
|
-
self._session.sql(f"SHOW VERSIONS IN MODEL {self.fully_qualified_model_name(model_name)}")
|
91
|
-
.filter('"is_default_version" = TRUE')[['"name"']]
|
92
|
-
.collect(statement_params=statement_params)[0][0]
|
93
|
-
)
|
94
|
-
return default_version
|
87
|
+
query_result_checker.SqlResultValidator(
|
88
|
+
self._session,
|
89
|
+
(
|
90
|
+
f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
|
91
|
+
f"SET DEFAULT_VERSION = {version_name.identifier()}"
|
92
|
+
),
|
93
|
+
statement_params=statement_params,
|
94
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
95
95
|
|
96
96
|
def get_file(
|
97
97
|
self,
|
@@ -108,14 +108,14 @@ class ModelVersionSQLClient:
|
|
108
108
|
stage_location_url = ParseResult(
|
109
109
|
scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
|
110
110
|
).geturl()
|
111
|
-
local_location = target_path.
|
112
|
-
local_location_url =
|
113
|
-
scheme="file", netloc="", path=local_location, params="", query="", fragment=""
|
114
|
-
).geturl()
|
111
|
+
local_location = target_path.resolve().as_posix()
|
112
|
+
local_location_url = f"file://{local_location}"
|
115
113
|
|
116
|
-
|
117
|
-
|
118
|
-
|
114
|
+
query_result_checker.SqlResultValidator(
|
115
|
+
self._session,
|
116
|
+
f"GET {_normalize_url_for_sql(stage_location_url)} {_normalize_url_for_sql(local_location_url)}",
|
117
|
+
statement_params=statement_params,
|
118
|
+
).has_dimensions(expected_rows=1).validate()
|
119
119
|
return target_path / file_path.name
|
120
120
|
|
121
121
|
def set_comment(
|
@@ -126,11 +126,14 @@ class ModelVersionSQLClient:
|
|
126
126
|
version_name: sql_identifier.SqlIdentifier,
|
127
127
|
statement_params: Optional[Dict[str, Any]] = None,
|
128
128
|
) -> None:
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
129
|
+
query_result_checker.SqlResultValidator(
|
130
|
+
self._session,
|
131
|
+
(
|
132
|
+
f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
|
133
|
+
f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$"
|
134
|
+
),
|
135
|
+
statement_params=statement_params,
|
136
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
134
137
|
|
135
138
|
def invoke_method(
|
136
139
|
self,
|
@@ -143,24 +146,29 @@ class ModelVersionSQLClient:
|
|
143
146
|
returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
144
147
|
statement_params: Optional[Dict[str, Any]] = None,
|
145
148
|
) -> dataframe.DataFrame:
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
149
|
+
with_statements = []
|
150
|
+
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
151
|
+
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
152
|
+
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
153
|
+
else:
|
154
|
+
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
155
|
+
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
156
|
+
self._database_name.identifier(),
|
157
|
+
self._schema_name.identifier(),
|
158
|
+
tmp_table_name,
|
159
|
+
)
|
160
|
+
input_df.write.save_as_table( # type: ignore[call-overload]
|
161
|
+
table_name=INTERMEDIATE_TABLE_NAME,
|
162
|
+
mode="errorifexists",
|
163
|
+
table_type="temporary",
|
164
|
+
statement_params=statement_params,
|
165
|
+
)
|
158
166
|
|
159
167
|
INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
|
160
168
|
|
161
169
|
module_version_alias = "MODEL_VERSION_ALIAS"
|
162
|
-
|
163
|
-
f"
|
170
|
+
with_statements.append(
|
171
|
+
f"{module_version_alias} AS "
|
164
172
|
f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
|
165
173
|
)
|
166
174
|
|
@@ -171,7 +179,7 @@ class ModelVersionSQLClient:
|
|
171
179
|
args_sql = ", ".join(args_sql_list)
|
172
180
|
|
173
181
|
sql = textwrap.dedent(
|
174
|
-
f"""{
|
182
|
+
f"""WITH {','.join(with_statements)}
|
175
183
|
SELECT *,
|
176
184
|
{module_version_alias}!{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
|
177
185
|
FROM {INTERMEDIATE_TABLE_NAME}"""
|
@@ -206,8 +214,11 @@ class ModelVersionSQLClient:
|
|
206
214
|
statement_params: Optional[Dict[str, Any]] = None,
|
207
215
|
) -> None:
|
208
216
|
json_metadata = json.dumps(metadata_dict)
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
217
|
+
query_result_checker.SqlResultValidator(
|
218
|
+
self._session,
|
219
|
+
(
|
220
|
+
f"ALTER MODEL {self.fully_qualified_model_name(model_name)} MODIFY VERSION {version_name.identifier()}"
|
221
|
+
f" SET METADATA=$${json_metadata}$$"
|
222
|
+
),
|
223
|
+
statement_params=statement_params,
|
224
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -1,6 +1,10 @@
|
|
1
1
|
from typing import Any, Dict, Optional
|
2
2
|
|
3
|
-
from snowflake.ml._internal.utils import
|
3
|
+
from snowflake.ml._internal.utils import (
|
4
|
+
identifier,
|
5
|
+
query_result_checker,
|
6
|
+
sql_identifier,
|
7
|
+
)
|
4
8
|
from snowflake.snowpark import session
|
5
9
|
|
6
10
|
|
@@ -35,6 +39,8 @@ class StageSQLClient:
|
|
35
39
|
stage_name: sql_identifier.SqlIdentifier,
|
36
40
|
statement_params: Optional[Dict[str, Any]] = None,
|
37
41
|
) -> None:
|
38
|
-
|
39
|
-
|
40
|
-
|
42
|
+
query_result_checker.SqlResultValidator(
|
43
|
+
self._session,
|
44
|
+
f"CREATE TEMPORARY STAGE {self.fully_qualified_stage_name(stage_name)}",
|
45
|
+
statement_params=statement_params,
|
46
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -0,0 +1,118 @@
|
|
1
|
+
from typing import Any, Dict, List, Optional
|
2
|
+
|
3
|
+
from snowflake.ml._internal.utils import (
|
4
|
+
identifier,
|
5
|
+
query_result_checker,
|
6
|
+
sql_identifier,
|
7
|
+
)
|
8
|
+
from snowflake.snowpark import row, session
|
9
|
+
|
10
|
+
|
11
|
+
class ModuleTagSQLClient:
|
12
|
+
def __init__(
|
13
|
+
self,
|
14
|
+
session: session.Session,
|
15
|
+
*,
|
16
|
+
database_name: sql_identifier.SqlIdentifier,
|
17
|
+
schema_name: sql_identifier.SqlIdentifier,
|
18
|
+
) -> None:
|
19
|
+
self._session = session
|
20
|
+
self._database_name = database_name
|
21
|
+
self._schema_name = schema_name
|
22
|
+
|
23
|
+
def __eq__(self, __value: object) -> bool:
|
24
|
+
if not isinstance(__value, ModuleTagSQLClient):
|
25
|
+
return False
|
26
|
+
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
27
|
+
|
28
|
+
def fully_qualified_module_name(
|
29
|
+
self,
|
30
|
+
module_name: sql_identifier.SqlIdentifier,
|
31
|
+
) -> str:
|
32
|
+
return identifier.get_schema_level_object_identifier(
|
33
|
+
self._database_name.identifier(), self._schema_name.identifier(), module_name.identifier()
|
34
|
+
)
|
35
|
+
|
36
|
+
def set_tag_on_model(
|
37
|
+
self,
|
38
|
+
model_name: sql_identifier.SqlIdentifier,
|
39
|
+
*,
|
40
|
+
tag_database_name: sql_identifier.SqlIdentifier,
|
41
|
+
tag_schema_name: sql_identifier.SqlIdentifier,
|
42
|
+
tag_name: sql_identifier.SqlIdentifier,
|
43
|
+
tag_value: str,
|
44
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
45
|
+
) -> None:
|
46
|
+
fq_model_name = self.fully_qualified_module_name(model_name)
|
47
|
+
fq_tag_name = identifier.get_schema_level_object_identifier(
|
48
|
+
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
49
|
+
)
|
50
|
+
query_result_checker.SqlResultValidator(
|
51
|
+
self._session,
|
52
|
+
f"ALTER MODEL {fq_model_name} SET TAG {fq_tag_name} = $${tag_value}$$",
|
53
|
+
statement_params=statement_params,
|
54
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
55
|
+
|
56
|
+
def unset_tag_on_model(
|
57
|
+
self,
|
58
|
+
model_name: sql_identifier.SqlIdentifier,
|
59
|
+
*,
|
60
|
+
tag_database_name: sql_identifier.SqlIdentifier,
|
61
|
+
tag_schema_name: sql_identifier.SqlIdentifier,
|
62
|
+
tag_name: sql_identifier.SqlIdentifier,
|
63
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
64
|
+
) -> None:
|
65
|
+
fq_model_name = self.fully_qualified_module_name(model_name)
|
66
|
+
fq_tag_name = identifier.get_schema_level_object_identifier(
|
67
|
+
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
68
|
+
)
|
69
|
+
query_result_checker.SqlResultValidator(
|
70
|
+
self._session,
|
71
|
+
f"ALTER MODEL {fq_model_name} UNSET TAG {fq_tag_name}",
|
72
|
+
statement_params=statement_params,
|
73
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
74
|
+
|
75
|
+
def get_tag_value(
|
76
|
+
self,
|
77
|
+
module_name: sql_identifier.SqlIdentifier,
|
78
|
+
*,
|
79
|
+
tag_database_name: sql_identifier.SqlIdentifier,
|
80
|
+
tag_schema_name: sql_identifier.SqlIdentifier,
|
81
|
+
tag_name: sql_identifier.SqlIdentifier,
|
82
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
83
|
+
) -> row.Row:
|
84
|
+
fq_module_name = self.fully_qualified_module_name(module_name)
|
85
|
+
fq_tag_name = identifier.get_schema_level_object_identifier(
|
86
|
+
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
87
|
+
)
|
88
|
+
return (
|
89
|
+
query_result_checker.SqlResultValidator(
|
90
|
+
self._session,
|
91
|
+
f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${fq_module_name}$$, 'MODULE') AS TAG_VALUE",
|
92
|
+
statement_params=statement_params,
|
93
|
+
)
|
94
|
+
.has_dimensions(expected_rows=1, expected_cols=1)
|
95
|
+
.has_column("TAG_VALUE")
|
96
|
+
.validate()[0]
|
97
|
+
)
|
98
|
+
|
99
|
+
def get_tag_list(
|
100
|
+
self,
|
101
|
+
module_name: sql_identifier.SqlIdentifier,
|
102
|
+
*,
|
103
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
104
|
+
) -> List[row.Row]:
|
105
|
+
fq_module_name = self.fully_qualified_module_name(module_name)
|
106
|
+
return (
|
107
|
+
query_result_checker.SqlResultValidator(
|
108
|
+
self._session,
|
109
|
+
f"""SELECT TAG_DATABASE, TAG_SCHEMA, TAG_NAME, TAG_VALUE
|
110
|
+
FROM TABLE({self._database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES($${fq_module_name}$$, 'MODULE'))""",
|
111
|
+
statement_params=statement_params,
|
112
|
+
)
|
113
|
+
.has_column("TAG_DATABASE", allow_empty=True)
|
114
|
+
.has_column("TAG_SCHEMA", allow_empty=True)
|
115
|
+
.has_column("TAG_NAME", allow_empty=True)
|
116
|
+
.has_column("TAG_VALUE", allow_empty=True)
|
117
|
+
.validate()
|
118
|
+
)
|
@@ -9,11 +9,11 @@ from enum import Enum
|
|
9
9
|
from typing import List
|
10
10
|
|
11
11
|
from snowflake import snowpark
|
12
|
+
from snowflake.ml._internal.container_services.image_registry import credential
|
12
13
|
from snowflake.ml._internal.exceptions import (
|
13
14
|
error_codes,
|
14
15
|
exceptions as snowml_exceptions,
|
15
16
|
)
|
16
|
-
from snowflake.ml._internal.utils import spcs_image_registry
|
17
17
|
from snowflake.ml.model._deploy_client.image_builds import base_image_builder
|
18
18
|
|
19
19
|
logger = logging.getLogger(__name__)
|
@@ -106,7 +106,7 @@ class ClientImageBuilder(base_image_builder.ImageBuilder):
|
|
106
106
|
self._run_docker_commands(commands)
|
107
107
|
|
108
108
|
self.validate_docker_client_env()
|
109
|
-
with
|
109
|
+
with credential.generate_image_registry_credential(
|
110
110
|
self.session
|
111
111
|
) as registry_cred, tempfile.TemporaryDirectory() as docker_config_dir:
|
112
112
|
try:
|
@@ -2,7 +2,6 @@ import os
|
|
2
2
|
import posixpath
|
3
3
|
import shutil
|
4
4
|
import string
|
5
|
-
from abc import ABC
|
6
5
|
from typing import Optional
|
7
6
|
|
8
7
|
import importlib_resources
|
@@ -15,7 +14,7 @@ from snowflake.ml.model._packager.model_meta import model_meta
|
|
15
14
|
from snowflake.snowpark import FileOperation, Session
|
16
15
|
|
17
16
|
|
18
|
-
class DockerContext
|
17
|
+
class DockerContext:
|
19
18
|
"""
|
20
19
|
Constructs the Docker context directory required for image building.
|
21
20
|
"""
|
@@ -53,12 +52,13 @@ class DockerContext(ABC):
|
|
53
52
|
|
54
53
|
def _copy_entrypoint_script_to_docker_context(self) -> None:
|
55
54
|
"""Copy gunicorn_run.sh entrypoint to docker context directory."""
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
55
|
+
script_path = importlib_resources.files(image_builds).joinpath( # type: ignore[no-untyped-call]
|
56
|
+
constants.ENTRYPOINT_SCRIPT
|
57
|
+
)
|
58
|
+
target_path = os.path.join(self.context_dir, constants.ENTRYPOINT_SCRIPT)
|
59
|
+
|
60
|
+
with open(script_path, encoding="utf-8") as source_file, file_utils.open_file(target_path, "w") as target_file:
|
61
|
+
target_file.write(source_file.read())
|
62
62
|
|
63
63
|
def _copy_model_env_dependency_to_docker_context(self) -> None:
|
64
64
|
"""
|
@@ -105,6 +105,8 @@ def _run_setup() -> None:
|
|
105
105
|
|
106
106
|
# TODO (Server-side Model Rollout):
|
107
107
|
# Keep try block only
|
108
|
+
# SPCS spec will convert all environment variables as strings.
|
109
|
+
use_gpu = os.environ.get("SNOWML_USE_GPU", "False").lower() == "true"
|
108
110
|
try:
|
109
111
|
from snowflake.ml.model._packager import model_packager
|
110
112
|
|
@@ -112,9 +114,7 @@ def _run_setup() -> None:
|
|
112
114
|
pk.load(
|
113
115
|
as_custom_model=True,
|
114
116
|
meta_only=False,
|
115
|
-
options=model_types.ModelLoadOption(
|
116
|
-
{"use_gpu": cast(bool, os.environ.get("SNOWML_USE_GPU", False))}
|
117
|
-
),
|
117
|
+
options=model_types.ModelLoadOption({"use_gpu": use_gpu}),
|
118
118
|
)
|
119
119
|
_LOADED_MODEL = pk.model
|
120
120
|
_LOADED_META = pk.meta
|
@@ -132,9 +132,7 @@ def _run_setup() -> None:
|
|
132
132
|
_LOADED_MODEL, meta_LOADED_META = model_api._load(
|
133
133
|
local_dir_path=extracted_dir,
|
134
134
|
as_custom_model=True,
|
135
|
-
options=model_types.ModelLoadOption(
|
136
|
-
{"use_gpu": cast(bool, os.environ.get("SNOWML_USE_GPU", False))}
|
137
|
-
),
|
135
|
+
options=model_types.ModelLoadOption({"use_gpu": use_gpu}),
|
138
136
|
)
|
139
137
|
_MODEL_LOADING_STATE = _ModelLoadingState.SUCCEEDED
|
140
138
|
logger.info("Successfully loaded model into memory")
|
@@ -7,6 +7,9 @@ import importlib_resources
|
|
7
7
|
|
8
8
|
from snowflake import snowpark
|
9
9
|
from snowflake.ml._internal import file_utils
|
10
|
+
from snowflake.ml._internal.container_services.image_registry import (
|
11
|
+
registry_client as image_registry_client,
|
12
|
+
)
|
10
13
|
from snowflake.ml._internal.exceptions import (
|
11
14
|
error_codes,
|
12
15
|
exceptions as snowml_exceptions,
|
@@ -14,11 +17,7 @@ from snowflake.ml._internal.exceptions import (
|
|
14
17
|
from snowflake.ml._internal.utils import identifier
|
15
18
|
from snowflake.ml.model._deploy_client import image_builds
|
16
19
|
from snowflake.ml.model._deploy_client.image_builds import base_image_builder
|
17
|
-
from snowflake.ml.model._deploy_client.utils import
|
18
|
-
constants,
|
19
|
-
image_registry_client,
|
20
|
-
snowservice_client,
|
21
|
-
)
|
20
|
+
from snowflake.ml.model._deploy_client.utils import constants, snowservice_client
|
22
21
|
|
23
22
|
logger = logging.getLogger(__name__)
|
24
23
|
|
@@ -117,7 +116,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
117
116
|
|
118
117
|
kaniko_shell_file = os.path.join(self.context_dir, constants.KANIKO_SHELL_SCRIPT_NAME)
|
119
118
|
|
120
|
-
with
|
119
|
+
with file_utils.open_file(kaniko_shell_file, "w+") as script_file:
|
121
120
|
normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@"))
|
122
121
|
params = {
|
123
122
|
# Remove @ in the beginning, append "/" to denote root directory.
|
@@ -175,7 +174,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
175
174
|
os.path.dirname(self.context_dir), f"{constants.IMAGE_BUILD_JOB_SPEC_TEMPLATE}.yaml"
|
176
175
|
)
|
177
176
|
|
178
|
-
with
|
177
|
+
with file_utils.open_file(spec_file_path, "w+") as spec_file:
|
179
178
|
assert self.artifact_stage_location.startswith("@")
|
180
179
|
normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@"))
|
181
180
|
(db, schema, stage, path) = identifier.parse_schema_level_object_identifier(normed_artifact_stage_path)
|
@@ -14,6 +14,9 @@ from packaging import requirements
|
|
14
14
|
from typing_extensions import Unpack
|
15
15
|
|
16
16
|
from snowflake.ml._internal import env_utils, file_utils
|
17
|
+
from snowflake.ml._internal.container_services.image_registry import (
|
18
|
+
registry_client as image_registry_client,
|
19
|
+
)
|
17
20
|
from snowflake.ml._internal.exceptions import (
|
18
21
|
error_codes,
|
19
22
|
exceptions as snowml_exceptions,
|
@@ -32,11 +35,7 @@ from snowflake.ml.model._deploy_client.image_builds import (
|
|
32
35
|
server_image_builder,
|
33
36
|
)
|
34
37
|
from snowflake.ml.model._deploy_client.snowservice import deploy_options, instance_types
|
35
|
-
from snowflake.ml.model._deploy_client.utils import (
|
36
|
-
constants,
|
37
|
-
image_registry_client,
|
38
|
-
snowservice_client,
|
39
|
-
)
|
38
|
+
from snowflake.ml.model._deploy_client.utils import constants, snowservice_client
|
40
39
|
from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema
|
41
40
|
from snowflake.snowpark import Session
|
42
41
|
|
@@ -1,2 +1,10 @@
|
|
1
1
|
# Snowpark Container Service GPU instance type and corresponding GPU counts.
|
2
|
-
INSTANCE_TYPE_TO_GPU_COUNT = {
|
2
|
+
INSTANCE_TYPE_TO_GPU_COUNT = {
|
3
|
+
"GPU_3": 1,
|
4
|
+
"GPU_5": 1,
|
5
|
+
"GPU_7": 4,
|
6
|
+
"GPU_10": 8,
|
7
|
+
"GPU_NV_S": 1,
|
8
|
+
"GPU_NV_M": 4,
|
9
|
+
"GPU_NV_L": 8,
|
10
|
+
}
|
@@ -2,6 +2,7 @@ import copy
|
|
2
2
|
import logging
|
3
3
|
import posixpath
|
4
4
|
import tempfile
|
5
|
+
import textwrap
|
5
6
|
from types import ModuleType
|
6
7
|
from typing import IO, List, Optional, Tuple, TypedDict, Union
|
7
8
|
|
@@ -154,7 +155,7 @@ def _get_model_final_packages(
|
|
154
155
|
Returns:
|
155
156
|
List of final packages string that is accepted by Snowpark register UDF call.
|
156
157
|
"""
|
157
|
-
|
158
|
+
|
158
159
|
if (
|
159
160
|
any(channel.lower() not in [env_utils.DEFAULT_CHANNEL_NAME] for channel in meta.env._conda_dependencies.keys())
|
160
161
|
or meta.env.pip_requirements
|
@@ -173,21 +174,29 @@ def _get_model_final_packages(
|
|
173
174
|
else:
|
174
175
|
required_packages = meta.env._conda_dependencies[env_utils.DEFAULT_CHANNEL_NAME]
|
175
176
|
|
176
|
-
|
177
|
+
package_availability_dict = env_utils.get_matched_package_versions_in_information_schema(
|
177
178
|
session, required_packages, python_version=meta.env.python_version
|
178
179
|
)
|
179
|
-
|
180
|
-
|
180
|
+
no_version_available_packages = [
|
181
|
+
req_name for req_name, ver_list in package_availability_dict.items() if len(ver_list) < 1
|
182
|
+
]
|
183
|
+
unavailable_packages = [req.name for req in required_packages if req.name not in package_availability_dict]
|
184
|
+
if no_version_available_packages or unavailable_packages:
|
181
185
|
relax_version_info_str = "" if relax_version else "Try to set relax_version as True in the options. "
|
186
|
+
required_package_str = " ".join(map(lambda x: f'"{x}"', required_packages))
|
182
187
|
raise snowml_exceptions.SnowflakeMLException(
|
183
188
|
error_code=error_codes.DEPENDENCY_VERSION_ERROR,
|
184
189
|
original_exception=RuntimeError(
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
190
|
+
textwrap.dedent(
|
191
|
+
f"""
|
192
|
+
The model's dependencies are not available in Snowflake Anaconda Channel. {relax_version_info_str}
|
193
|
+
Required packages are: {required_package_str}
|
194
|
+
Required Python version is: {meta.env.python_version}
|
195
|
+
Packages that are not available are: {unavailable_packages}
|
196
|
+
Packages that cannot meet your requirements are: {no_version_available_packages}
|
197
|
+
Package availability information of those you requested is: {package_availability_dict}
|
198
|
+
"""
|
199
|
+
),
|
191
200
|
),
|
192
201
|
)
|
193
|
-
return
|
202
|
+
return list(sorted(map(str, required_packages)))
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import collections
|
2
2
|
import pathlib
|
3
|
-
from typing import List, Optional, cast
|
3
|
+
from typing import Any, Dict, List, Optional, cast
|
4
4
|
|
5
5
|
import yaml
|
6
6
|
|
@@ -83,7 +83,11 @@ class ModelManifest:
|
|
83
83
|
],
|
84
84
|
)
|
85
85
|
|
86
|
+
manifest_dict["user_data"] = self.generate_user_data_with_client_data(model_meta)
|
87
|
+
|
86
88
|
with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f:
|
89
|
+
# Anchors are not supported in the server, avoid that.
|
90
|
+
yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
|
87
91
|
yaml.safe_dump(manifest_dict, f)
|
88
92
|
|
89
93
|
def load(self) -> model_manifest_schema.ModelManifestDict:
|
@@ -99,3 +103,43 @@ class ModelManifest:
|
|
99
103
|
res = cast(model_manifest_schema.ModelManifestDict, raw_input)
|
100
104
|
|
101
105
|
return res
|
106
|
+
|
107
|
+
def generate_user_data_with_client_data(self, model_meta: model_meta_api.ModelMetadata) -> Dict[str, Any]:
|
108
|
+
client_data = model_manifest_schema.SnowparkMLDataDict(
|
109
|
+
schema_version=model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION,
|
110
|
+
functions=[
|
111
|
+
model_manifest_schema.ModelFunctionInfoDict(
|
112
|
+
name=method.method_name.identifier(),
|
113
|
+
target_method=method.target_method,
|
114
|
+
signature=model_meta.signatures[method.target_method].to_dict(),
|
115
|
+
)
|
116
|
+
for method in self.methods
|
117
|
+
],
|
118
|
+
)
|
119
|
+
return {model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: client_data}
|
120
|
+
|
121
|
+
@staticmethod
|
122
|
+
def parse_client_data_from_user_data(raw_user_data: Dict[str, Any]) -> model_manifest_schema.SnowparkMLDataDict:
|
123
|
+
raw_client_data = raw_user_data.get(model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME, {})
|
124
|
+
if not isinstance(raw_client_data, dict) or "schema_version" not in raw_client_data:
|
125
|
+
raise ValueError(f"Ill-formatted client data {raw_client_data} in user data found.")
|
126
|
+
loaded_client_data_schema_version = raw_client_data["schema_version"]
|
127
|
+
if (
|
128
|
+
not isinstance(loaded_client_data_schema_version, str)
|
129
|
+
or loaded_client_data_schema_version != model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION
|
130
|
+
):
|
131
|
+
raise ValueError(f"Unsupported client data schema version {loaded_client_data_schema_version} confronted.")
|
132
|
+
|
133
|
+
return_functions_info: List[model_manifest_schema.ModelFunctionInfoDict] = []
|
134
|
+
loaded_functions_info = raw_client_data.get("functions", [])
|
135
|
+
for func in loaded_functions_info:
|
136
|
+
fi = model_manifest_schema.ModelFunctionInfoDict(
|
137
|
+
name=func["name"],
|
138
|
+
target_method=func["target_method"],
|
139
|
+
signature=func["signature"],
|
140
|
+
)
|
141
|
+
return_functions_info.append(fi)
|
142
|
+
|
143
|
+
return model_manifest_schema.SnowparkMLDataDict(
|
144
|
+
schema_version=loaded_client_data_schema_version, functions=return_functions_info
|
145
|
+
)
|