snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_sentiment.py +7 -4
- snowflake/ml/_internal/env_utils.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
- snowflake/ml/dataset/__init__.py +2 -1
- snowflake/ml/dataset/dataset.py +4 -3
- snowflake/ml/dataset/dataset_reader.py +5 -8
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +283 -0
- snowflake/ml/feature_store/feature_store.py +160 -100
- snowflake/ml/feature_store/feature_view.py +30 -19
- snowflake/ml/fileset/embedded_stage_fs.py +15 -12
- snowflake/ml/fileset/snowfs.py +2 -30
- snowflake/ml/fileset/stage_fs.py +25 -7
- snowflake/ml/model/_client/model/model_impl.py +46 -39
- snowflake/ml/model/_client/model/model_version_impl.py +24 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +174 -16
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +32 -39
- snowflake/ml/model/_client/sql/model_version.py +111 -42
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_model_composer/model_composer.py +8 -4
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +90 -142
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +159 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +8 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +8 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +8 -1
- snowflake/ml/modeling/cluster/birch.py +8 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +8 -1
- snowflake/ml/modeling/cluster/dbscan.py +8 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +8 -1
- snowflake/ml/modeling/cluster/k_means.py +8 -1
- snowflake/ml/modeling/cluster/mean_shift.py +8 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +8 -1
- snowflake/ml/modeling/cluster/optics.py +8 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +8 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +8 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +8 -1
- snowflake/ml/modeling/compose/column_transformer.py +8 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +8 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +8 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +8 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +8 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +8 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +8 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +8 -1
- snowflake/ml/modeling/covariance/oas.py +8 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +8 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +8 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +8 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +8 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +8 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +8 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +8 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +8 -1
- snowflake/ml/modeling/decomposition/pca.py +8 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +8 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +8 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +8 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +8 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +8 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +8 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +8 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +8 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +8 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +8 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +8 -1
- snowflake/ml/modeling/framework/base.py +4 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +8 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +8 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +8 -1
- snowflake/ml/modeling/impute/knn_imputer.py +8 -1
- snowflake/ml/modeling/impute/missing_indicator.py +8 -1
- snowflake/ml/modeling/impute/simple_imputer.py +21 -2
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +8 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +8 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +8 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +8 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +8 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +8 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +8 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +8 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +8 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +8 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +8 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/lars.py +8 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +8 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +8 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +8 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +8 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +8 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/perceptron.py +8 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ridge.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +8 -1
- snowflake/ml/modeling/manifold/isomap.py +8 -1
- snowflake/ml/modeling/manifold/mds.py +8 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +8 -1
- snowflake/ml/modeling/manifold/tsne.py +8 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +8 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +8 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +8 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +8 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +8 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +8 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +8 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +8 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +8 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +8 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +8 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +8 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +8 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +8 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +8 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +8 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +8 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +8 -1
- snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +27 -7
- snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +8 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +8 -1
- snowflake/ml/modeling/svm/linear_svc.py +8 -1
- snowflake/ml/modeling/svm/linear_svr.py +8 -1
- snowflake/ml/modeling/svm/nu_svc.py +8 -1
- snowflake/ml/modeling/svm/nu_svr.py +8 -1
- snowflake/ml/modeling/svm/svc.py +8 -1
- snowflake/ml/modeling/svm/svr.py +8 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +8 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +8 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +8 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +8 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +8 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +8 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +8 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +8 -1
- snowflake/ml/registry/_manager/model_manager.py +95 -8
- snowflake/ml/registry/registry.py +10 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/METADATA +66 -10
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/RECORD +196 -192
- snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,8 @@ from snowflake.ml._internal.utils import (
|
|
9
9
|
query_result_checker,
|
10
10
|
sql_identifier,
|
11
11
|
)
|
12
|
-
from snowflake.
|
12
|
+
from snowflake.ml.model._client.sql import _base
|
13
|
+
from snowflake.snowpark import dataframe, functions as F, row, types as spt
|
13
14
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
14
15
|
|
15
16
|
|
@@ -20,44 +21,51 @@ def _normalize_url_for_sql(url: str) -> str:
|
|
20
21
|
return f"'{url}'"
|
21
22
|
|
22
23
|
|
23
|
-
class ModelVersionSQLClient:
|
24
|
+
class ModelVersionSQLClient(_base._BaseSQLClient):
|
24
25
|
FUNCTION_NAME_COL_NAME = "name"
|
25
26
|
FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
|
26
27
|
|
27
|
-
def
|
28
|
+
def create_from_stage(
|
28
29
|
self,
|
29
|
-
session: session.Session,
|
30
30
|
*,
|
31
|
-
database_name: sql_identifier.SqlIdentifier,
|
32
|
-
schema_name: sql_identifier.SqlIdentifier,
|
31
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
32
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
33
|
+
model_name: sql_identifier.SqlIdentifier,
|
34
|
+
version_name: sql_identifier.SqlIdentifier,
|
35
|
+
stage_path: str,
|
36
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
33
37
|
) -> None:
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
|
44
|
-
return identifier.get_schema_level_object_identifier(
|
45
|
-
self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
|
46
|
-
)
|
38
|
+
query_result_checker.SqlResultValidator(
|
39
|
+
self._session,
|
40
|
+
(
|
41
|
+
f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
42
|
+
f" WITH VERSION {version_name.identifier()} FROM {stage_path}"
|
43
|
+
),
|
44
|
+
statement_params=statement_params,
|
45
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
47
46
|
|
48
|
-
def
|
47
|
+
def create_from_model_version(
|
49
48
|
self,
|
50
49
|
*,
|
50
|
+
source_database_name: Optional[sql_identifier.SqlIdentifier],
|
51
|
+
source_schema_name: Optional[sql_identifier.SqlIdentifier],
|
52
|
+
source_model_name: sql_identifier.SqlIdentifier,
|
53
|
+
source_version_name: sql_identifier.SqlIdentifier,
|
54
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
55
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
51
56
|
model_name: sql_identifier.SqlIdentifier,
|
52
57
|
version_name: sql_identifier.SqlIdentifier,
|
53
|
-
stage_path: str,
|
54
58
|
statement_params: Optional[Dict[str, Any]] = None,
|
55
59
|
) -> None:
|
60
|
+
fq_source_model_name = self.fully_qualified_object_name(
|
61
|
+
source_database_name, source_schema_name, source_model_name
|
62
|
+
)
|
63
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
56
64
|
query_result_checker.SqlResultValidator(
|
57
65
|
self._session,
|
58
66
|
(
|
59
|
-
f"CREATE MODEL {
|
60
|
-
f"
|
67
|
+
f"CREATE MODEL {fq_model_name} WITH VERSION {version_name} FROM MODEL {fq_source_model_name}"
|
68
|
+
f" VERSION {source_version_name}"
|
61
69
|
),
|
62
70
|
statement_params=statement_params,
|
63
71
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -66,6 +74,8 @@ class ModelVersionSQLClient:
|
|
66
74
|
def add_version_from_stage(
|
67
75
|
self,
|
68
76
|
*,
|
77
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
78
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
69
79
|
model_name: sql_identifier.SqlIdentifier,
|
70
80
|
version_name: sql_identifier.SqlIdentifier,
|
71
81
|
stage_path: str,
|
@@ -74,8 +84,34 @@ class ModelVersionSQLClient:
|
|
74
84
|
query_result_checker.SqlResultValidator(
|
75
85
|
self._session,
|
76
86
|
(
|
77
|
-
f"ALTER MODEL {self.
|
78
|
-
f" FROM {stage_path}"
|
87
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
88
|
+
f" ADD VERSION {version_name.identifier()} FROM {stage_path}"
|
89
|
+
),
|
90
|
+
statement_params=statement_params,
|
91
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
92
|
+
|
93
|
+
def add_version_from_model_version(
|
94
|
+
self,
|
95
|
+
*,
|
96
|
+
source_database_name: Optional[sql_identifier.SqlIdentifier],
|
97
|
+
source_schema_name: Optional[sql_identifier.SqlIdentifier],
|
98
|
+
source_model_name: sql_identifier.SqlIdentifier,
|
99
|
+
source_version_name: sql_identifier.SqlIdentifier,
|
100
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
101
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
102
|
+
model_name: sql_identifier.SqlIdentifier,
|
103
|
+
version_name: sql_identifier.SqlIdentifier,
|
104
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
105
|
+
) -> None:
|
106
|
+
fq_source_model_name = self.fully_qualified_object_name(
|
107
|
+
source_database_name, source_schema_name, source_model_name
|
108
|
+
)
|
109
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
110
|
+
query_result_checker.SqlResultValidator(
|
111
|
+
self._session,
|
112
|
+
(
|
113
|
+
f"ALTER MODEL {fq_model_name} ADD VERSION {version_name} FROM MODEL {fq_source_model_name}"
|
114
|
+
f" VERSION {source_version_name}"
|
79
115
|
),
|
80
116
|
statement_params=statement_params,
|
81
117
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -83,6 +119,8 @@ class ModelVersionSQLClient:
|
|
83
119
|
def set_default_version(
|
84
120
|
self,
|
85
121
|
*,
|
122
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
123
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
86
124
|
model_name: sql_identifier.SqlIdentifier,
|
87
125
|
version_name: sql_identifier.SqlIdentifier,
|
88
126
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -90,7 +128,7 @@ class ModelVersionSQLClient:
|
|
90
128
|
query_result_checker.SqlResultValidator(
|
91
129
|
self._session,
|
92
130
|
(
|
93
|
-
f"ALTER MODEL {self.
|
131
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
|
94
132
|
f"SET DEFAULT_VERSION = {version_name.identifier()}"
|
95
133
|
),
|
96
134
|
statement_params=statement_params,
|
@@ -99,6 +137,8 @@ class ModelVersionSQLClient:
|
|
99
137
|
def list_file(
|
100
138
|
self,
|
101
139
|
*,
|
140
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
141
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
102
142
|
model_name: sql_identifier.SqlIdentifier,
|
103
143
|
version_name: sql_identifier.SqlIdentifier,
|
104
144
|
file_path: pathlib.PurePosixPath,
|
@@ -110,7 +150,10 @@ class ModelVersionSQLClient:
|
|
110
150
|
|
111
151
|
stage_location = (
|
112
152
|
pathlib.PurePosixPath(
|
113
|
-
self.
|
153
|
+
self.fully_qualified_object_name(database_name, schema_name, model_name),
|
154
|
+
"versions",
|
155
|
+
version_name.resolved(),
|
156
|
+
file_path,
|
114
157
|
).as_posix()
|
115
158
|
+ trailing_slash
|
116
159
|
)
|
@@ -124,13 +167,15 @@ class ModelVersionSQLClient:
|
|
124
167
|
f"List {_normalize_url_for_sql(stage_location_url)}",
|
125
168
|
statement_params=statement_params,
|
126
169
|
)
|
127
|
-
.has_column("name")
|
170
|
+
.has_column("name", allow_empty=True)
|
128
171
|
.validate()
|
129
172
|
)
|
130
173
|
|
131
174
|
def get_file(
|
132
175
|
self,
|
133
176
|
*,
|
177
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
178
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
134
179
|
model_name: sql_identifier.SqlIdentifier,
|
135
180
|
version_name: sql_identifier.SqlIdentifier,
|
136
181
|
file_path: pathlib.PurePosixPath,
|
@@ -138,7 +183,10 @@ class ModelVersionSQLClient:
|
|
138
183
|
statement_params: Optional[Dict[str, Any]] = None,
|
139
184
|
) -> pathlib.Path:
|
140
185
|
stage_location = pathlib.PurePosixPath(
|
141
|
-
self.
|
186
|
+
self.fully_qualified_object_name(database_name, schema_name, model_name),
|
187
|
+
"versions",
|
188
|
+
version_name.resolved(),
|
189
|
+
file_path,
|
142
190
|
).as_posix()
|
143
191
|
stage_location_url = ParseResult(
|
144
192
|
scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
|
@@ -149,7 +197,7 @@ class ModelVersionSQLClient:
|
|
149
197
|
if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
150
198
|
options = {"parallel": 10}
|
151
199
|
cursor = self._session._conn._cursor
|
152
|
-
cursor._download(stage_location_url, str(target_path), options) # type: ignore[attr
|
200
|
+
cursor._download(stage_location_url, str(target_path), options) # type: ignore[union-attr]
|
153
201
|
cursor.fetchall()
|
154
202
|
else:
|
155
203
|
query_result_checker.SqlResultValidator(
|
@@ -162,6 +210,8 @@ class ModelVersionSQLClient:
|
|
162
210
|
def show_functions(
|
163
211
|
self,
|
164
212
|
*,
|
213
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
214
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
165
215
|
model_name: sql_identifier.SqlIdentifier,
|
166
216
|
version_name: sql_identifier.SqlIdentifier,
|
167
217
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -169,7 +219,7 @@ class ModelVersionSQLClient:
|
|
169
219
|
res = query_result_checker.SqlResultValidator(
|
170
220
|
self._session,
|
171
221
|
(
|
172
|
-
f"SHOW FUNCTIONS IN MODEL {self.
|
222
|
+
f"SHOW FUNCTIONS IN MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
173
223
|
f" VERSION {version_name.identifier()}"
|
174
224
|
),
|
175
225
|
statement_params=statement_params,
|
@@ -180,15 +230,17 @@ class ModelVersionSQLClient:
|
|
180
230
|
def set_comment(
|
181
231
|
self,
|
182
232
|
*,
|
183
|
-
|
233
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
234
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
184
235
|
model_name: sql_identifier.SqlIdentifier,
|
185
236
|
version_name: sql_identifier.SqlIdentifier,
|
237
|
+
comment: str,
|
186
238
|
statement_params: Optional[Dict[str, Any]] = None,
|
187
239
|
) -> None:
|
188
240
|
query_result_checker.SqlResultValidator(
|
189
241
|
self._session,
|
190
242
|
(
|
191
|
-
f"ALTER MODEL {self.
|
243
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
|
192
244
|
f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$"
|
193
245
|
),
|
194
246
|
statement_params=statement_params,
|
@@ -197,6 +249,8 @@ class ModelVersionSQLClient:
|
|
197
249
|
def invoke_function_method(
|
198
250
|
self,
|
199
251
|
*,
|
252
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
253
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
200
254
|
model_name: sql_identifier.SqlIdentifier,
|
201
255
|
version_name: sql_identifier.SqlIdentifier,
|
202
256
|
method_name: sql_identifier.SqlIdentifier,
|
@@ -210,10 +264,12 @@ class ModelVersionSQLClient:
|
|
210
264
|
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
211
265
|
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
212
266
|
else:
|
267
|
+
actual_database_name = database_name or self._database_name
|
268
|
+
actual_schema_name = schema_name or self._schema_name
|
213
269
|
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
214
270
|
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
215
|
-
|
216
|
-
|
271
|
+
actual_database_name.identifier(),
|
272
|
+
actual_schema_name.identifier(),
|
217
273
|
tmp_table_name,
|
218
274
|
)
|
219
275
|
input_df.write.save_as_table( # type: ignore[call-overload]
|
@@ -228,7 +284,8 @@ class ModelVersionSQLClient:
|
|
228
284
|
module_version_alias = "MODEL_VERSION_ALIAS"
|
229
285
|
with_statements.append(
|
230
286
|
f"{module_version_alias} AS "
|
231
|
-
f"MODEL {self.
|
287
|
+
f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
288
|
+
f" VERSION {version_name.identifier()}"
|
232
289
|
)
|
233
290
|
|
234
291
|
args_sql_list = []
|
@@ -267,6 +324,8 @@ class ModelVersionSQLClient:
|
|
267
324
|
def invoke_table_function_method(
|
268
325
|
self,
|
269
326
|
*,
|
327
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
328
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
270
329
|
model_name: sql_identifier.SqlIdentifier,
|
271
330
|
version_name: sql_identifier.SqlIdentifier,
|
272
331
|
method_name: sql_identifier.SqlIdentifier,
|
@@ -281,10 +340,12 @@ class ModelVersionSQLClient:
|
|
281
340
|
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
282
341
|
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
283
342
|
else:
|
343
|
+
actual_database_name = database_name or self._database_name
|
344
|
+
actual_schema_name = schema_name or self._schema_name
|
284
345
|
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
285
346
|
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
286
|
-
|
287
|
-
|
347
|
+
actual_database_name.identifier(),
|
348
|
+
actual_schema_name.identifier(),
|
288
349
|
tmp_table_name,
|
289
350
|
)
|
290
351
|
input_df.write.save_as_table( # type: ignore[call-overload]
|
@@ -297,7 +358,8 @@ class ModelVersionSQLClient:
|
|
297
358
|
module_version_alias = "MODEL_VERSION_ALIAS"
|
298
359
|
with_statements.append(
|
299
360
|
f"{module_version_alias} AS "
|
300
|
-
f"MODEL {self.
|
361
|
+
f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
362
|
+
f" VERSION {version_name.identifier()}"
|
301
363
|
)
|
302
364
|
|
303
365
|
partition_by = partition_column.identifier() if partition_column is not None else "1"
|
@@ -344,6 +406,8 @@ class ModelVersionSQLClient:
|
|
344
406
|
self,
|
345
407
|
metadata_dict: Dict[str, Any],
|
346
408
|
*,
|
409
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
410
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
347
411
|
model_name: sql_identifier.SqlIdentifier,
|
348
412
|
version_name: sql_identifier.SqlIdentifier,
|
349
413
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -352,8 +416,8 @@ class ModelVersionSQLClient:
|
|
352
416
|
query_result_checker.SqlResultValidator(
|
353
417
|
self._session,
|
354
418
|
(
|
355
|
-
f"ALTER MODEL {self.
|
356
|
-
f" SET METADATA=$${json_metadata}$$"
|
419
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
420
|
+
f" MODIFY VERSION {version_name.identifier()} SET METADATA=$${json_metadata}$$"
|
357
421
|
),
|
358
422
|
statement_params=statement_params,
|
359
423
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -361,12 +425,17 @@ class ModelVersionSQLClient:
|
|
361
425
|
def drop_version(
|
362
426
|
self,
|
363
427
|
*,
|
428
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
429
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
364
430
|
model_name: sql_identifier.SqlIdentifier,
|
365
431
|
version_name: sql_identifier.SqlIdentifier,
|
366
432
|
statement_params: Optional[Dict[str, Any]] = None,
|
367
433
|
) -> None:
|
368
434
|
query_result_checker.SqlResultValidator(
|
369
435
|
self._session,
|
370
|
-
|
436
|
+
(
|
437
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
438
|
+
f" DROP VERSION {version_name.identifier()}"
|
439
|
+
),
|
371
440
|
statement_params=statement_params,
|
372
441
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -1,46 +1,20 @@
|
|
1
1
|
from typing import Any, Dict, Optional
|
2
2
|
|
3
|
-
from snowflake.ml._internal.utils import
|
4
|
-
|
5
|
-
query_result_checker,
|
6
|
-
sql_identifier,
|
7
|
-
)
|
8
|
-
from snowflake.snowpark import session
|
3
|
+
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
|
+
from snowflake.ml.model._client.sql import _base
|
9
5
|
|
10
6
|
|
11
|
-
class StageSQLClient:
|
12
|
-
def __init__(
|
13
|
-
self,
|
14
|
-
session: session.Session,
|
15
|
-
*,
|
16
|
-
database_name: sql_identifier.SqlIdentifier,
|
17
|
-
schema_name: sql_identifier.SqlIdentifier,
|
18
|
-
) -> None:
|
19
|
-
self._session = session
|
20
|
-
self._database_name = database_name
|
21
|
-
self._schema_name = schema_name
|
22
|
-
|
23
|
-
def __eq__(self, __value: object) -> bool:
|
24
|
-
if not isinstance(__value, StageSQLClient):
|
25
|
-
return False
|
26
|
-
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
27
|
-
|
28
|
-
def fully_qualified_stage_name(
|
29
|
-
self,
|
30
|
-
stage_name: sql_identifier.SqlIdentifier,
|
31
|
-
) -> str:
|
32
|
-
return identifier.get_schema_level_object_identifier(
|
33
|
-
self._database_name.identifier(), self._schema_name.identifier(), stage_name.identifier()
|
34
|
-
)
|
35
|
-
|
7
|
+
class StageSQLClient(_base._BaseSQLClient):
|
36
8
|
def create_tmp_stage(
|
37
9
|
self,
|
38
10
|
*,
|
11
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
12
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
39
13
|
stage_name: sql_identifier.SqlIdentifier,
|
40
14
|
statement_params: Optional[Dict[str, Any]] = None,
|
41
15
|
) -> None:
|
42
16
|
query_result_checker.SqlResultValidator(
|
43
17
|
self._session,
|
44
|
-
f"CREATE TEMPORARY STAGE {self.
|
18
|
+
f"CREATE TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
|
45
19
|
statement_params=statement_params,
|
46
20
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -1,52 +1,25 @@
|
|
1
1
|
from typing import Any, Dict, List, Optional
|
2
2
|
|
3
|
-
from snowflake.ml._internal.utils import
|
4
|
-
|
5
|
-
|
6
|
-
sql_identifier,
|
7
|
-
)
|
8
|
-
from snowflake.snowpark import row, session
|
3
|
+
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
|
+
from snowflake.ml.model._client.sql import _base
|
5
|
+
from snowflake.snowpark import row
|
9
6
|
|
10
7
|
|
11
|
-
class ModuleTagSQLClient:
|
12
|
-
def __init__(
|
13
|
-
self,
|
14
|
-
session: session.Session,
|
15
|
-
*,
|
16
|
-
database_name: sql_identifier.SqlIdentifier,
|
17
|
-
schema_name: sql_identifier.SqlIdentifier,
|
18
|
-
) -> None:
|
19
|
-
self._session = session
|
20
|
-
self._database_name = database_name
|
21
|
-
self._schema_name = schema_name
|
22
|
-
|
23
|
-
def __eq__(self, __value: object) -> bool:
|
24
|
-
if not isinstance(__value, ModuleTagSQLClient):
|
25
|
-
return False
|
26
|
-
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
27
|
-
|
28
|
-
def fully_qualified_module_name(
|
29
|
-
self,
|
30
|
-
module_name: sql_identifier.SqlIdentifier,
|
31
|
-
) -> str:
|
32
|
-
return identifier.get_schema_level_object_identifier(
|
33
|
-
self._database_name.identifier(), self._schema_name.identifier(), module_name.identifier()
|
34
|
-
)
|
35
|
-
|
8
|
+
class ModuleTagSQLClient(_base._BaseSQLClient):
|
36
9
|
def set_tag_on_model(
|
37
10
|
self,
|
38
|
-
model_name: sql_identifier.SqlIdentifier,
|
39
11
|
*,
|
40
|
-
|
41
|
-
|
12
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
13
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
14
|
+
model_name: sql_identifier.SqlIdentifier,
|
15
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
16
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
42
17
|
tag_name: sql_identifier.SqlIdentifier,
|
43
18
|
tag_value: str,
|
44
19
|
statement_params: Optional[Dict[str, Any]] = None,
|
45
20
|
) -> None:
|
46
|
-
fq_model_name = self.
|
47
|
-
fq_tag_name =
|
48
|
-
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
49
|
-
)
|
21
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
22
|
+
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
50
23
|
query_result_checker.SqlResultValidator(
|
51
24
|
self._session,
|
52
25
|
f"ALTER MODEL {fq_model_name} SET TAG {fq_tag_name} = $${tag_value}$$",
|
@@ -55,17 +28,17 @@ class ModuleTagSQLClient:
|
|
55
28
|
|
56
29
|
def unset_tag_on_model(
|
57
30
|
self,
|
58
|
-
model_name: sql_identifier.SqlIdentifier,
|
59
31
|
*,
|
60
|
-
|
61
|
-
|
32
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
33
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
34
|
+
model_name: sql_identifier.SqlIdentifier,
|
35
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
36
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
62
37
|
tag_name: sql_identifier.SqlIdentifier,
|
63
38
|
statement_params: Optional[Dict[str, Any]] = None,
|
64
39
|
) -> None:
|
65
|
-
fq_model_name = self.
|
66
|
-
fq_tag_name =
|
67
|
-
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
68
|
-
)
|
40
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
41
|
+
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
69
42
|
query_result_checker.SqlResultValidator(
|
70
43
|
self._session,
|
71
44
|
f"ALTER MODEL {fq_model_name} UNSET TAG {fq_tag_name}",
|
@@ -74,21 +47,21 @@ class ModuleTagSQLClient:
|
|
74
47
|
|
75
48
|
def get_tag_value(
|
76
49
|
self,
|
77
|
-
module_name: sql_identifier.SqlIdentifier,
|
78
50
|
*,
|
79
|
-
|
80
|
-
|
51
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
52
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
53
|
+
model_name: sql_identifier.SqlIdentifier,
|
54
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
55
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
81
56
|
tag_name: sql_identifier.SqlIdentifier,
|
82
57
|
statement_params: Optional[Dict[str, Any]] = None,
|
83
58
|
) -> row.Row:
|
84
|
-
|
85
|
-
fq_tag_name =
|
86
|
-
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
87
|
-
)
|
59
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
60
|
+
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
88
61
|
return (
|
89
62
|
query_result_checker.SqlResultValidator(
|
90
63
|
self._session,
|
91
|
-
f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${
|
64
|
+
f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${fq_model_name}$$, 'MODULE') AS TAG_VALUE",
|
92
65
|
statement_params=statement_params,
|
93
66
|
)
|
94
67
|
.has_dimensions(expected_rows=1, expected_cols=1)
|
@@ -98,16 +71,19 @@ class ModuleTagSQLClient:
|
|
98
71
|
|
99
72
|
def get_tag_list(
|
100
73
|
self,
|
101
|
-
module_name: sql_identifier.SqlIdentifier,
|
102
74
|
*,
|
75
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
76
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
77
|
+
model_name: sql_identifier.SqlIdentifier,
|
103
78
|
statement_params: Optional[Dict[str, Any]] = None,
|
104
79
|
) -> List[row.Row]:
|
105
|
-
|
80
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
81
|
+
actual_database_name = database_name or self._database_name
|
106
82
|
return (
|
107
83
|
query_result_checker.SqlResultValidator(
|
108
84
|
self._session,
|
109
85
|
f"""SELECT TAG_DATABASE, TAG_SCHEMA, TAG_NAME, TAG_VALUE
|
110
|
-
FROM TABLE({
|
86
|
+
FROM TABLE({actual_database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES($${fq_model_name}$$, 'MODULE'))""",
|
111
87
|
statement_params=statement_params,
|
112
88
|
)
|
113
89
|
.has_column("TAG_DATABASE", allow_empty=True)
|
@@ -11,7 +11,7 @@ from packaging import requirements
|
|
11
11
|
from typing_extensions import deprecated
|
12
12
|
|
13
13
|
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
14
|
-
from snowflake.ml._internal.lineage import data_source
|
14
|
+
from snowflake.ml._internal.lineage import data_source, lineage_utils
|
15
15
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
16
16
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest
|
17
17
|
from snowflake.ml.model._packager import model_packager
|
@@ -136,7 +136,7 @@ class ModelComposer:
|
|
136
136
|
model_meta=self.packager.meta,
|
137
137
|
model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
|
138
138
|
options=options,
|
139
|
-
data_sources=self._get_data_sources(model),
|
139
|
+
data_sources=self._get_data_sources(model, sample_input_data),
|
140
140
|
)
|
141
141
|
|
142
142
|
file_utils.upload_directory_to_stage(
|
@@ -179,8 +179,12 @@ class ModelComposer:
|
|
179
179
|
mp.load(meta_only=meta_only, options=options)
|
180
180
|
return mp
|
181
181
|
|
182
|
-
def _get_data_sources(
|
183
|
-
|
182
|
+
def _get_data_sources(
|
183
|
+
self, model: model_types.SupportedModelType, sample_input_data: Optional[model_types.SupportedDataType] = None
|
184
|
+
) -> Optional[List[data_source.DataSource]]:
|
185
|
+
data_sources = getattr(model, lineage_utils.DATA_SOURCES_ATTR, None)
|
186
|
+
if not data_sources and sample_input_data is not None:
|
187
|
+
data_sources = getattr(sample_input_data, lineage_utils.DATA_SOURCES_ATTR, None)
|
184
188
|
if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
|
185
189
|
return data_sources
|
186
190
|
return None
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import os
|
2
|
+
import pathlib
|
2
3
|
import tempfile
|
3
4
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
4
5
|
|
@@ -45,7 +46,7 @@ def _parse_mlflow_env(model_uri: str, env: model_env.ModelEnv) -> model_env.Mode
|
|
45
46
|
if not os.path.exists(conda_env_file_path):
|
46
47
|
raise ValueError("Cannot load MLFlow model dependencies.")
|
47
48
|
|
48
|
-
env.load_from_conda_file(conda_env_file_path)
|
49
|
+
env.load_from_conda_file(pathlib.Path(conda_env_file_path))
|
49
50
|
|
50
51
|
return env
|
51
52
|
|
@@ -281,9 +281,7 @@ class ModelMetadata:
|
|
281
281
|
"cpu": model_runtime.ModelRuntime("cpu", self.env),
|
282
282
|
}
|
283
283
|
if self.env.cuda_version:
|
284
|
-
runtimes.update(
|
285
|
-
{"gpu": model_runtime.ModelRuntime("gpu", self.env, is_gpu=True, server_availability_source="conda")}
|
286
|
-
)
|
284
|
+
runtimes.update({"gpu": model_runtime.ModelRuntime("gpu", self.env, is_gpu=True)})
|
287
285
|
return runtimes
|
288
286
|
|
289
287
|
def save(self, model_dir_path: str) -> None:
|
@@ -1,11 +1,11 @@
|
|
1
1
|
import copy
|
2
2
|
import pathlib
|
3
3
|
import warnings
|
4
|
-
from typing import List,
|
4
|
+
from typing import List, Optional
|
5
5
|
|
6
6
|
from packaging import requirements
|
7
7
|
|
8
|
-
from snowflake.ml._internal import
|
8
|
+
from snowflake.ml._internal import env_utils, file_utils
|
9
9
|
from snowflake.ml.model._packager.model_env import model_env
|
10
10
|
from snowflake.ml.model._packager.model_meta import model_meta_schema
|
11
11
|
from snowflake.ml.model._packager.model_runtime import (
|
@@ -37,7 +37,6 @@ class ModelRuntime:
|
|
37
37
|
env: model_env.ModelEnv,
|
38
38
|
imports: Optional[List[pathlib.PurePosixPath]] = None,
|
39
39
|
is_gpu: bool = False,
|
40
|
-
server_availability_source: Literal["snowflake", "conda"] = "snowflake",
|
41
40
|
loading_from_file: bool = False,
|
42
41
|
) -> None:
|
43
42
|
self.name = name
|
@@ -48,30 +47,7 @@ class ModelRuntime:
|
|
48
47
|
return
|
49
48
|
|
50
49
|
snowml_pkg_spec = f"{env_utils.SNOWPARK_ML_PKG_NAME}=={self.runtime_env.snowpark_ml_version}"
|
51
|
-
|
52
|
-
self.embed_local_ml_library = True
|
53
|
-
else:
|
54
|
-
if server_availability_source == "snowflake":
|
55
|
-
snowml_server_availability = (
|
56
|
-
len(
|
57
|
-
env_utils.get_matched_package_versions_in_information_schema_with_active_session(
|
58
|
-
reqs=[requirements.Requirement(snowml_pkg_spec)],
|
59
|
-
python_version=snowml_env.PYTHON_VERSION,
|
60
|
-
).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
|
61
|
-
)
|
62
|
-
>= 1
|
63
|
-
)
|
64
|
-
else:
|
65
|
-
snowml_server_availability = (
|
66
|
-
len(
|
67
|
-
env_utils.get_matched_package_versions_in_snowflake_conda_channel(
|
68
|
-
req=requirements.Requirement(snowml_pkg_spec),
|
69
|
-
python_version=snowml_env.PYTHON_VERSION,
|
70
|
-
)
|
71
|
-
)
|
72
|
-
>= 1
|
73
|
-
)
|
74
|
-
self.embed_local_ml_library = not snowml_server_availability
|
50
|
+
self.embed_local_ml_library = self.runtime_env._snowpark_ml_version.local
|
75
51
|
|
76
52
|
additional_package = (
|
77
53
|
_SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES if self.embed_local_ml_library else [snowml_pkg_spec]
|