snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -0,0 +1,163 @@
|
|
1
|
+
from types import ModuleType
|
2
|
+
from typing import Any, Dict, List, Optional
|
3
|
+
|
4
|
+
import pandas as pd
|
5
|
+
from absl.logging import logging
|
6
|
+
|
7
|
+
from snowflake.ml._internal.utils import sql_identifier
|
8
|
+
from snowflake.ml.model import model_signature, type_hints as model_types
|
9
|
+
from snowflake.ml.model._client.model import model_impl, model_version_impl
|
10
|
+
from snowflake.ml.model._client.ops import metadata_ops, model_ops
|
11
|
+
from snowflake.ml.model._model_composer import model_composer
|
12
|
+
from snowflake.snowpark import session
|
13
|
+
|
14
|
+
logger = logging.getLogger(__name__)
|
15
|
+
|
16
|
+
|
17
|
+
class ModelManager:
|
18
|
+
def __init__(
|
19
|
+
self,
|
20
|
+
session: session.Session,
|
21
|
+
*,
|
22
|
+
database_name: sql_identifier.SqlIdentifier,
|
23
|
+
schema_name: sql_identifier.SqlIdentifier,
|
24
|
+
) -> None:
|
25
|
+
self._database_name = database_name
|
26
|
+
self._schema_name = schema_name
|
27
|
+
self._model_ops = model_ops.ModelOperator(
|
28
|
+
session, database_name=self._database_name, schema_name=self._schema_name
|
29
|
+
)
|
30
|
+
|
31
|
+
def log_model(
|
32
|
+
self,
|
33
|
+
model: model_types.SupportedModelType,
|
34
|
+
*,
|
35
|
+
model_name: str,
|
36
|
+
version_name: str,
|
37
|
+
comment: Optional[str] = None,
|
38
|
+
metrics: Optional[Dict[str, Any]] = None,
|
39
|
+
conda_dependencies: Optional[List[str]] = None,
|
40
|
+
pip_requirements: Optional[List[str]] = None,
|
41
|
+
python_version: Optional[str] = None,
|
42
|
+
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
43
|
+
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
44
|
+
code_paths: Optional[List[str]] = None,
|
45
|
+
ext_modules: Optional[List[ModuleType]] = None,
|
46
|
+
options: Optional[model_types.ModelSaveOption] = None,
|
47
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
48
|
+
) -> model_version_impl.ModelVersion:
|
49
|
+
model_name_id = sql_identifier.SqlIdentifier(model_name)
|
50
|
+
|
51
|
+
version_name_id = sql_identifier.SqlIdentifier(version_name)
|
52
|
+
|
53
|
+
if self._model_ops.validate_existence(
|
54
|
+
model_name=model_name_id, statement_params=statement_params
|
55
|
+
) and self._model_ops.validate_existence(
|
56
|
+
model_name=model_name_id, version_name=version_name_id, statement_params=statement_params
|
57
|
+
):
|
58
|
+
raise ValueError(f"Model {model_name} version {version_name} already existed.")
|
59
|
+
|
60
|
+
stage_path = self._model_ops.prepare_model_stage_path(
|
61
|
+
statement_params=statement_params,
|
62
|
+
)
|
63
|
+
|
64
|
+
logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
|
65
|
+
|
66
|
+
mc = model_composer.ModelComposer(self._model_ops._session, stage_path=stage_path)
|
67
|
+
mc.save(
|
68
|
+
name=model_name_id.resolved(),
|
69
|
+
model=model,
|
70
|
+
signatures=signatures,
|
71
|
+
sample_input=sample_input_data,
|
72
|
+
conda_dependencies=conda_dependencies,
|
73
|
+
pip_requirements=pip_requirements,
|
74
|
+
python_version=python_version,
|
75
|
+
code_paths=code_paths,
|
76
|
+
ext_modules=ext_modules,
|
77
|
+
options=options,
|
78
|
+
)
|
79
|
+
|
80
|
+
logger.info("Start creating MODEL object for you in the Snowflake.")
|
81
|
+
|
82
|
+
self._model_ops.create_from_stage(
|
83
|
+
composed_model=mc,
|
84
|
+
model_name=model_name_id,
|
85
|
+
version_name=version_name_id,
|
86
|
+
statement_params=statement_params,
|
87
|
+
)
|
88
|
+
|
89
|
+
mv = model_version_impl.ModelVersion._ref(
|
90
|
+
self._model_ops,
|
91
|
+
model_name=model_name_id,
|
92
|
+
version_name=version_name_id,
|
93
|
+
)
|
94
|
+
|
95
|
+
if comment:
|
96
|
+
mv.comment = comment
|
97
|
+
|
98
|
+
if metrics:
|
99
|
+
self._model_ops._metadata_ops.save(
|
100
|
+
metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
|
101
|
+
model_name=model_name_id,
|
102
|
+
version_name=version_name_id,
|
103
|
+
statement_params=statement_params,
|
104
|
+
)
|
105
|
+
|
106
|
+
return mv
|
107
|
+
|
108
|
+
def get_model(
|
109
|
+
self,
|
110
|
+
model_name: str,
|
111
|
+
*,
|
112
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
113
|
+
) -> model_impl.Model:
|
114
|
+
model_name_id = sql_identifier.SqlIdentifier(model_name)
|
115
|
+
if self._model_ops.validate_existence(
|
116
|
+
model_name=model_name_id,
|
117
|
+
statement_params=statement_params,
|
118
|
+
):
|
119
|
+
return model_impl.Model._ref(
|
120
|
+
self._model_ops,
|
121
|
+
model_name=model_name_id,
|
122
|
+
)
|
123
|
+
else:
|
124
|
+
raise ValueError(f"Unable to find model {model_name}")
|
125
|
+
|
126
|
+
def models(
|
127
|
+
self,
|
128
|
+
*,
|
129
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
130
|
+
) -> List[model_impl.Model]:
|
131
|
+
model_names = self._model_ops.list_models_or_versions(
|
132
|
+
statement_params=statement_params,
|
133
|
+
)
|
134
|
+
return [
|
135
|
+
model_impl.Model._ref(
|
136
|
+
self._model_ops,
|
137
|
+
model_name=model_name,
|
138
|
+
)
|
139
|
+
for model_name in model_names
|
140
|
+
]
|
141
|
+
|
142
|
+
def show_models(
|
143
|
+
self,
|
144
|
+
*,
|
145
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
146
|
+
) -> pd.DataFrame:
|
147
|
+
rows = self._model_ops.show_models_or_versions(
|
148
|
+
statement_params=statement_params,
|
149
|
+
)
|
150
|
+
return pd.DataFrame([row.as_dict() for row in rows])
|
151
|
+
|
152
|
+
def delete_model(
|
153
|
+
self,
|
154
|
+
model_name: str,
|
155
|
+
*,
|
156
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
157
|
+
) -> None:
|
158
|
+
model_name_id = sql_identifier.SqlIdentifier(model_name)
|
159
|
+
|
160
|
+
self._model_ops.delete_model_or_version(
|
161
|
+
model_name=model_name_id,
|
162
|
+
statement_params=statement_params,
|
163
|
+
)
|
@@ -3,6 +3,7 @@ import json
|
|
3
3
|
import sys
|
4
4
|
import textwrap
|
5
5
|
import types
|
6
|
+
import warnings
|
6
7
|
from typing import (
|
7
8
|
TYPE_CHECKING,
|
8
9
|
Any,
|
@@ -305,6 +306,17 @@ class ModelRegistry:
|
|
305
306
|
schema_name: Desired name of the schema used by this model registry inside the database.
|
306
307
|
create_if_not_exists: create model registry if it's not exists already.
|
307
308
|
"""
|
309
|
+
|
310
|
+
warnings.warn(
|
311
|
+
"""
|
312
|
+
The `snowflake.ml.registry.model_registry.ModelRegistry` has been deprecated starting from version 1.2.0.
|
313
|
+
It will stay in the Private Preview phase. For future implementations, kindly utilize `snowflake.ml.registry.Registry`,
|
314
|
+
except when specifically required. The old model registry will be removed once all its primary functionalities are
|
315
|
+
fully integrated into the new registry.
|
316
|
+
""",
|
317
|
+
DeprecationWarning,
|
318
|
+
stacklevel=2,
|
319
|
+
)
|
308
320
|
if create_if_not_exists:
|
309
321
|
create_model_registry(session=session, database_name=database_name, schema_name=schema_name)
|
310
322
|
|
@@ -1,12 +1,17 @@
|
|
1
1
|
from types import ModuleType
|
2
|
-
from typing import Dict, List, Optional
|
2
|
+
from typing import Any, Dict, List, Optional
|
3
|
+
|
4
|
+
import pandas as pd
|
3
5
|
|
4
6
|
from snowflake.ml._internal import telemetry
|
5
7
|
from snowflake.ml._internal.utils import sql_identifier
|
6
|
-
from snowflake.ml.model import
|
7
|
-
|
8
|
-
|
9
|
-
|
8
|
+
from snowflake.ml.model import (
|
9
|
+
Model,
|
10
|
+
ModelVersion,
|
11
|
+
model_signature,
|
12
|
+
type_hints as model_types,
|
13
|
+
)
|
14
|
+
from snowflake.ml.registry._manager import model_manager
|
10
15
|
from snowflake.snowpark import session
|
11
16
|
|
12
17
|
_TELEMETRY_PROJECT = "MLOps"
|
@@ -21,6 +26,18 @@ class Registry:
|
|
21
26
|
database_name: Optional[str] = None,
|
22
27
|
schema_name: Optional[str] = None,
|
23
28
|
) -> None:
|
29
|
+
"""Opens a registry within a pre-created Snowflake schema.
|
30
|
+
|
31
|
+
Args:
|
32
|
+
session: The Snowpark Session to connect with Snowflake.
|
33
|
+
database_name: The name of the database. If None, the current database of the session
|
34
|
+
will be used. Defaults to None.
|
35
|
+
schema_name: The name of the schema. If None, the current schema of the session
|
36
|
+
will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.
|
37
|
+
|
38
|
+
Raises:
|
39
|
+
ValueError: When there is no specified or active database in the session.
|
40
|
+
"""
|
24
41
|
if database_name:
|
25
42
|
self._database_name = sql_identifier.SqlIdentifier(database_name)
|
26
43
|
else:
|
@@ -42,12 +59,13 @@ class Registry:
|
|
42
59
|
else sql_identifier.SqlIdentifier("PUBLIC")
|
43
60
|
)
|
44
61
|
|
45
|
-
self.
|
62
|
+
self._model_manager = model_manager.ModelManager(
|
46
63
|
session, database_name=self._database_name, schema_name=self._schema_name
|
47
64
|
)
|
48
65
|
|
49
66
|
@property
|
50
67
|
def location(self) -> str:
|
68
|
+
"""Get the location (database.schema) of the registry."""
|
51
69
|
return ".".join([self._database_name.identifier(), self._schema_name.identifier()])
|
52
70
|
|
53
71
|
@telemetry.send_api_usage_telemetry(
|
@@ -60,6 +78,8 @@ class Registry:
|
|
60
78
|
*,
|
61
79
|
model_name: str,
|
62
80
|
version_name: str,
|
81
|
+
comment: Optional[str] = None,
|
82
|
+
metrics: Optional[Dict[str, Any]] = None,
|
63
83
|
conda_dependencies: Optional[List[str]] = None,
|
64
84
|
pip_requirements: Optional[List[str]] = None,
|
65
85
|
python_version: Optional[str] = None,
|
@@ -68,148 +88,138 @@ class Registry:
|
|
68
88
|
code_paths: Optional[List[str]] = None,
|
69
89
|
ext_modules: Optional[List[ModuleType]] = None,
|
70
90
|
options: Optional[model_types.ModelSaveOption] = None,
|
71
|
-
) ->
|
72
|
-
"""
|
91
|
+
) -> ModelVersion:
|
92
|
+
"""
|
93
|
+
Log a model with various parameters and metadata.
|
73
94
|
|
74
95
|
Args:
|
75
|
-
model: Model
|
76
|
-
|
77
|
-
|
78
|
-
|
96
|
+
model: Model object of supported types such as Scikit-learn, XGBoost, Snowpark ML,
|
97
|
+
PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline,
|
98
|
+
Peft-finetuned LLM, or Custom Model.
|
99
|
+
model_name: Name to identify the model.
|
100
|
+
version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
|
101
|
+
comment: Comment associated with the model version. Defaults to None.
|
102
|
+
metrics: A JSON serializable dictionary containing metrics linked to the model version. Defaults to None.
|
103
|
+
signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
|
79
104
|
sample_input_data would be used to infer the signatures for those models that cannot automatically
|
80
|
-
infer the signature. If not None,
|
81
|
-
sample_input_data: Sample input data to infer
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
105
|
+
infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
|
106
|
+
sample_input_data: Sample input data to infer model signatures from. Defaults to None.
|
107
|
+
conda_dependencies: List of Conda package specifications. Use "[channel::]package [operator version]" syntax
|
108
|
+
to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
|
109
|
+
is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
|
110
|
+
pip_requirements: List of Pip package specifications. Defaults to None.
|
111
|
+
python_version: Python version in which the model is run. Defaults to None.
|
112
|
+
code_paths: List of directories containing code to import. Defaults to None.
|
113
|
+
ext_modules: List of external modules to pickle with the model object.
|
114
|
+
Only supported when logging the following types of model:
|
115
|
+
Scikit-learn, Snowpark ML, PyTorch, TorchScript and Custom Model. Defaults to None.
|
116
|
+
options (Dict[str, Any], optional): Additional model saving options.
|
117
|
+
|
118
|
+
Model Saving Options include:
|
119
|
+
|
120
|
+
- embed_local_ml_library: Embed local Snowpark ML into the code directory or folder.
|
121
|
+
Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda
|
122
|
+
Channel. Otherwise, defaults to False
|
123
|
+
- relax_version: Whether or not relax the version constraints of the dependencies.
|
124
|
+
It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to False.
|
125
|
+
- method_options: Per-method saving options including:
|
126
|
+
- case_sensitive: Indicates whether the method and its signature should be case sensitive.
|
127
|
+
This means when you refer the method in the SQL, you need to double quote it.
|
128
|
+
This will be helpful if you need case to tell apart your methods or features, or you have
|
129
|
+
non-alphabetic characters in your method or feature name. Defaults to False.
|
130
|
+
- max_batch_size: Maximum batch size that the method could accept in the Snowflake Warehouse.
|
131
|
+
Defaults to None, determined automatically by Snowflake.
|
93
132
|
|
94
133
|
Returns:
|
95
|
-
|
134
|
+
ModelVersion: ModelVersion object corresponding to the model just logged.
|
96
135
|
"""
|
97
136
|
|
98
137
|
statement_params = telemetry.get_statement_params(
|
99
138
|
project=_TELEMETRY_PROJECT,
|
100
139
|
subproject=_MODEL_TELEMETRY_SUBPROJECT,
|
101
140
|
)
|
102
|
-
|
103
|
-
|
104
|
-
version_name_id = sql_identifier.SqlIdentifier(version_name)
|
105
|
-
|
106
|
-
stage_path = self._model_ops.prepare_model_stage_path(
|
107
|
-
statement_params=statement_params,
|
108
|
-
)
|
109
|
-
|
110
|
-
mc = model_composer.ModelComposer(self._model_ops._session, stage_path=stage_path)
|
111
|
-
mc.save(
|
112
|
-
name=model_name_id.resolved(),
|
141
|
+
return self._model_manager.log_model(
|
113
142
|
model=model,
|
114
|
-
|
115
|
-
|
143
|
+
model_name=model_name,
|
144
|
+
version_name=version_name,
|
145
|
+
comment=comment,
|
146
|
+
metrics=metrics,
|
116
147
|
conda_dependencies=conda_dependencies,
|
117
148
|
pip_requirements=pip_requirements,
|
118
149
|
python_version=python_version,
|
150
|
+
signatures=signatures,
|
151
|
+
sample_input_data=sample_input_data,
|
119
152
|
code_paths=code_paths,
|
120
153
|
ext_modules=ext_modules,
|
121
154
|
options=options,
|
122
|
-
)
|
123
|
-
self._model_ops.create_from_stage(
|
124
|
-
composed_model=mc,
|
125
|
-
model_name=model_name_id,
|
126
|
-
version_name=version_name_id,
|
127
155
|
statement_params=statement_params,
|
128
156
|
)
|
129
157
|
|
130
|
-
return model_version_impl.ModelVersion._ref(
|
131
|
-
self._model_ops,
|
132
|
-
model_name=model_name_id,
|
133
|
-
version_name=version_name_id,
|
134
|
-
)
|
135
|
-
|
136
158
|
@telemetry.send_api_usage_telemetry(
|
137
159
|
project=_TELEMETRY_PROJECT,
|
138
160
|
subproject=_MODEL_TELEMETRY_SUBPROJECT,
|
139
161
|
)
|
140
|
-
def get_model(self, model_name: str) ->
|
141
|
-
"""Get the model object.
|
162
|
+
def get_model(self, model_name: str) -> Model:
|
163
|
+
"""Get the model object by its name.
|
142
164
|
|
143
165
|
Args:
|
144
|
-
model_name: The model
|
145
|
-
|
146
|
-
Raises:
|
147
|
-
ValueError: Raised when the model requested does not exist.
|
166
|
+
model_name: The name of the model.
|
148
167
|
|
149
168
|
Returns:
|
150
|
-
The model object.
|
169
|
+
The corresponding model object.
|
151
170
|
"""
|
152
|
-
model_name_id = sql_identifier.SqlIdentifier(model_name)
|
153
|
-
|
154
171
|
statement_params = telemetry.get_statement_params(
|
155
172
|
project=_TELEMETRY_PROJECT,
|
156
173
|
subproject=_MODEL_TELEMETRY_SUBPROJECT,
|
157
174
|
)
|
158
|
-
|
159
|
-
model_name=model_name_id,
|
160
|
-
statement_params=statement_params,
|
161
|
-
):
|
162
|
-
return model_impl.Model._ref(
|
163
|
-
self._model_ops,
|
164
|
-
model_name=model_name_id,
|
165
|
-
)
|
166
|
-
else:
|
167
|
-
raise ValueError(f"Unable to find model {model_name}")
|
175
|
+
return self._model_manager.get_model(model_name=model_name, statement_params=statement_params)
|
168
176
|
|
169
177
|
@telemetry.send_api_usage_telemetry(
|
170
178
|
project=_TELEMETRY_PROJECT,
|
171
179
|
subproject=_MODEL_TELEMETRY_SUBPROJECT,
|
172
180
|
)
|
173
|
-
def
|
174
|
-
"""
|
181
|
+
def models(self) -> List[Model]:
|
182
|
+
"""Get all models in the schema where the registry is opened.
|
175
183
|
|
176
184
|
Returns:
|
177
|
-
A
|
185
|
+
A list of Model objects representing all models in the opened registry.
|
178
186
|
"""
|
179
187
|
statement_params = telemetry.get_statement_params(
|
180
188
|
project=_TELEMETRY_PROJECT,
|
181
189
|
subproject=_MODEL_TELEMETRY_SUBPROJECT,
|
182
190
|
)
|
183
|
-
|
184
|
-
|
191
|
+
return self._model_manager.models(statement_params=statement_params)
|
192
|
+
|
193
|
+
@telemetry.send_api_usage_telemetry(
|
194
|
+
project=_TELEMETRY_PROJECT,
|
195
|
+
subproject=_MODEL_TELEMETRY_SUBPROJECT,
|
196
|
+
)
|
197
|
+
def show_models(self) -> pd.DataFrame:
|
198
|
+
"""Show information of all models in the schema where the registry is opened.
|
199
|
+
|
200
|
+
Returns:
|
201
|
+
A Pandas DataFrame containing information of all models in the schema.
|
202
|
+
"""
|
203
|
+
statement_params = telemetry.get_statement_params(
|
204
|
+
project=_TELEMETRY_PROJECT,
|
205
|
+
subproject=_MODEL_TELEMETRY_SUBPROJECT,
|
185
206
|
)
|
186
|
-
return
|
187
|
-
model_impl.Model._ref(
|
188
|
-
self._model_ops,
|
189
|
-
model_name=model_name,
|
190
|
-
)
|
191
|
-
for model_name in model_names
|
192
|
-
]
|
207
|
+
return self._model_manager.show_models(statement_params=statement_params)
|
193
208
|
|
194
209
|
@telemetry.send_api_usage_telemetry(
|
195
210
|
project=_TELEMETRY_PROJECT,
|
196
211
|
subproject=_MODEL_TELEMETRY_SUBPROJECT,
|
197
212
|
)
|
198
213
|
def delete_model(self, model_name: str) -> None:
|
199
|
-
"""
|
214
|
+
"""
|
215
|
+
Delete the model by its name.
|
200
216
|
|
201
217
|
Args:
|
202
|
-
model_name: The
|
203
|
-
If not, use database name and schema name of the registry.
|
218
|
+
model_name: The name of the model to be deleted.
|
204
219
|
"""
|
205
|
-
model_name_id = sql_identifier.SqlIdentifier(model_name)
|
206
|
-
|
207
220
|
statement_params = telemetry.get_statement_params(
|
208
221
|
project=_TELEMETRY_PROJECT,
|
209
222
|
subproject=_MODEL_TELEMETRY_SUBPROJECT,
|
210
223
|
)
|
211
224
|
|
212
|
-
self.
|
213
|
-
model_name=model_name_id,
|
214
|
-
statement_params=statement_params,
|
215
|
-
)
|
225
|
+
self._model_manager.delete_model(model_name=model_name, statement_params=statement_params)
|
snowflake/ml/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
VERSION="1.1
|
1
|
+
VERSION="1.2.1"
|