snowflake-ml-python 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +1 -1
- snowflake/cortex/_extract_answer.py +1 -1
- snowflake/cortex/_sentiment.py +1 -1
- snowflake/cortex/_summarize.py +1 -1
- snowflake/cortex/_translate.py +1 -1
- snowflake/ml/_internal/env_utils.py +68 -6
- snowflake/ml/_internal/file_utils.py +34 -4
- snowflake/ml/_internal/telemetry.py +79 -91
- snowflake/ml/_internal/utils/identifier.py +78 -72
- snowflake/ml/_internal/utils/retryable_http.py +16 -4
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/model/_api.py +21 -14
- snowflake/ml/model/_client/model/model_impl.py +176 -0
- snowflake/ml/model/_client/model/model_method_info.py +19 -0
- snowflake/ml/model/_client/model/model_version_impl.py +291 -0
- snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
- snowflake/ml/model/_client/ops/model_ops.py +308 -0
- snowflake/ml/model/_client/sql/model.py +75 -0
- snowflake/ml/model/_client/sql/model_version.py +213 -0
- snowflake/ml/model/_client/sql/stage.py +40 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
- snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +31 -9
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
- snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/model_signature.py +108 -53
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
- snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
- snowflake/ml/modeling/_internal/model_specifications.py +146 -0
- snowflake/ml/modeling/_internal/model_trainer.py +13 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
- snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
- snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
- snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +108 -135
- snowflake/ml/modeling/cluster/affinity_propagation.py +106 -135
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +106 -135
- snowflake/ml/modeling/cluster/birch.py +106 -135
- snowflake/ml/modeling/cluster/bisecting_k_means.py +106 -135
- snowflake/ml/modeling/cluster/dbscan.py +106 -135
- snowflake/ml/modeling/cluster/feature_agglomeration.py +106 -135
- snowflake/ml/modeling/cluster/k_means.py +105 -135
- snowflake/ml/modeling/cluster/mean_shift.py +106 -135
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +105 -135
- snowflake/ml/modeling/cluster/optics.py +106 -135
- snowflake/ml/modeling/cluster/spectral_biclustering.py +106 -135
- snowflake/ml/modeling/cluster/spectral_clustering.py +106 -135
- snowflake/ml/modeling/cluster/spectral_coclustering.py +106 -135
- snowflake/ml/modeling/compose/column_transformer.py +106 -135
- snowflake/ml/modeling/compose/transformed_target_regressor.py +108 -135
- snowflake/ml/modeling/covariance/elliptic_envelope.py +106 -135
- snowflake/ml/modeling/covariance/empirical_covariance.py +99 -128
- snowflake/ml/modeling/covariance/graphical_lasso.py +106 -135
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +106 -135
- snowflake/ml/modeling/covariance/ledoit_wolf.py +104 -133
- snowflake/ml/modeling/covariance/min_cov_det.py +106 -135
- snowflake/ml/modeling/covariance/oas.py +99 -128
- snowflake/ml/modeling/covariance/shrunk_covariance.py +103 -132
- snowflake/ml/modeling/decomposition/dictionary_learning.py +106 -135
- snowflake/ml/modeling/decomposition/factor_analysis.py +106 -135
- snowflake/ml/modeling/decomposition/fast_ica.py +106 -135
- snowflake/ml/modeling/decomposition/incremental_pca.py +106 -135
- snowflake/ml/modeling/decomposition/kernel_pca.py +106 -135
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +106 -135
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +106 -135
- snowflake/ml/modeling/decomposition/pca.py +106 -135
- snowflake/ml/modeling/decomposition/sparse_pca.py +106 -135
- snowflake/ml/modeling/decomposition/truncated_svd.py +106 -135
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +108 -135
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +108 -135
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/bagging_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/bagging_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/isolation_forest.py +106 -135
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/stacking_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/voting_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/voting_regressor.py +108 -135
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +101 -128
- snowflake/ml/modeling/feature_selection/select_fdr.py +99 -126
- snowflake/ml/modeling/feature_selection/select_fpr.py +99 -126
- snowflake/ml/modeling/feature_selection/select_fwe.py +99 -126
- snowflake/ml/modeling/feature_selection/select_k_best.py +100 -127
- snowflake/ml/modeling/feature_selection/select_percentile.py +99 -126
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +106 -135
- snowflake/ml/modeling/feature_selection/variance_threshold.py +95 -124
- snowflake/ml/modeling/framework/base.py +83 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +108 -135
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +108 -135
- snowflake/ml/modeling/impute/iterative_imputer.py +106 -135
- snowflake/ml/modeling/impute/knn_imputer.py +106 -135
- snowflake/ml/modeling/impute/missing_indicator.py +106 -135
- snowflake/ml/modeling/impute/simple_imputer.py +9 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +96 -125
- snowflake/ml/modeling/kernel_approximation/nystroem.py +106 -135
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +106 -135
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +105 -134
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +103 -132
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +108 -135
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +90 -118
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +90 -118
- snowflake/ml/modeling/linear_model/ard_regression.py +108 -135
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +108 -135
- snowflake/ml/modeling/linear_model/elastic_net.py +108 -135
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +108 -135
- snowflake/ml/modeling/linear_model/gamma_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/huber_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/lars.py +108 -135
- snowflake/ml/modeling/linear_model/lars_cv.py +108 -135
- snowflake/ml/modeling/linear_model/lasso.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_cv.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_lars.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +108 -135
- snowflake/ml/modeling/linear_model/linear_regression.py +108 -135
- snowflake/ml/modeling/linear_model/logistic_regression.py +108 -135
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +108 -135
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +108 -135
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +108 -135
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +107 -135
- snowflake/ml/modeling/linear_model/perceptron.py +107 -135
- snowflake/ml/modeling/linear_model/poisson_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/ransac_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/ridge.py +108 -135
- snowflake/ml/modeling/linear_model/ridge_classifier.py +108 -135
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +108 -135
- snowflake/ml/modeling/linear_model/ridge_cv.py +108 -135
- snowflake/ml/modeling/linear_model/sgd_classifier.py +108 -135
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +106 -135
- snowflake/ml/modeling/linear_model/sgd_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +108 -135
- snowflake/ml/modeling/manifold/isomap.py +106 -135
- snowflake/ml/modeling/manifold/mds.py +106 -135
- snowflake/ml/modeling/manifold/spectral_embedding.py +106 -135
- snowflake/ml/modeling/manifold/tsne.py +106 -135
- snowflake/ml/modeling/metrics/classification.py +196 -55
- snowflake/ml/modeling/metrics/correlation.py +4 -2
- snowflake/ml/modeling/metrics/covariance.py +7 -4
- snowflake/ml/modeling/metrics/ranking.py +32 -16
- snowflake/ml/modeling/metrics/regression.py +60 -32
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +106 -135
- snowflake/ml/modeling/mixture/gaussian_mixture.py +106 -135
- snowflake/ml/modeling/model_selection/grid_search_cv.py +91 -148
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +93 -154
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +105 -132
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +108 -135
- snowflake/ml/modeling/multiclass/output_code_classifier.py +108 -135
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +108 -135
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +108 -135
- snowflake/ml/modeling/naive_bayes/complement_nb.py +108 -135
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +98 -125
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +107 -134
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +108 -135
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +108 -135
- snowflake/ml/modeling/neighbors/kernel_density.py +106 -135
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +106 -135
- snowflake/ml/modeling/neighbors/nearest_centroid.py +108 -135
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +106 -135
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +108 -135
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +108 -135
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +108 -135
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +106 -135
- snowflake/ml/modeling/neural_network/mlp_classifier.py +108 -135
- snowflake/ml/modeling/neural_network/mlp_regressor.py +108 -135
- snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
- snowflake/ml/modeling/preprocessing/binarizer.py +25 -8
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +9 -4
- snowflake/ml/modeling/preprocessing/label_encoder.py +31 -11
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +27 -9
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +42 -14
- snowflake/ml/modeling/preprocessing/normalizer.py +9 -4
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +26 -10
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +37 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +106 -135
- snowflake/ml/modeling/preprocessing/robust_scaler.py +39 -13
- snowflake/ml/modeling/preprocessing/standard_scaler.py +36 -12
- snowflake/ml/modeling/semi_supervised/label_propagation.py +108 -135
- snowflake/ml/modeling/semi_supervised/label_spreading.py +108 -135
- snowflake/ml/modeling/svm/linear_svc.py +108 -135
- snowflake/ml/modeling/svm/linear_svr.py +108 -135
- snowflake/ml/modeling/svm/nu_svc.py +108 -135
- snowflake/ml/modeling/svm/nu_svr.py +108 -135
- snowflake/ml/modeling/svm/svc.py +108 -135
- snowflake/ml/modeling/svm/svr.py +108 -135
- snowflake/ml/modeling/tree/decision_tree_classifier.py +108 -135
- snowflake/ml/modeling/tree/decision_tree_regressor.py +108 -135
- snowflake/ml/modeling/tree/extra_tree_classifier.py +108 -135
- snowflake/ml/modeling/tree/extra_tree_regressor.py +108 -135
- snowflake/ml/modeling/xgboost/xgb_classifier.py +108 -136
- snowflake/ml/modeling/xgboost/xgb_regressor.py +108 -136
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +108 -136
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +108 -136
- snowflake/ml/registry/model_registry.py +2 -0
- snowflake/ml/registry/registry.py +215 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +34 -1
- snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
- snowflake_ml_python-1.1.0.dist-info/RECORD +0 -331
- {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,19 @@
|
|
1
|
+
from typing import TypedDict
|
2
|
+
|
3
|
+
from typing_extensions import Required
|
4
|
+
|
5
|
+
from snowflake.ml.model import model_signature
|
6
|
+
|
7
|
+
|
8
|
+
class ModelMethodInfo(TypedDict):
    """Information describing one callable method of a model version.

    Attributes:
        name: Name of the method as it is called via SQL (and from the Python SDK).
        target_method: Name of the original method on the logged Python object that is actually invoked.
        signature: The model signature of that method.
    """

    # Every key is mandatory; ``Required`` states that explicitly for each field.
    name: Required[str]
    target_method: Required[str]
    signature: Required[model_signature.ModelSignature]
|
@@ -0,0 +1,291 @@
|
|
1
|
+
import re
|
2
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
3
|
+
|
4
|
+
import pandas as pd
|
5
|
+
|
6
|
+
from snowflake.ml._internal import telemetry
|
7
|
+
from snowflake.ml._internal.utils import sql_identifier
|
8
|
+
from snowflake.ml.model import model_signature
|
9
|
+
from snowflake.ml.model._client.model import model_method_info
|
10
|
+
from snowflake.ml.model._client.ops import metadata_ops, model_ops
|
11
|
+
from snowflake.snowpark import dataframe
|
12
|
+
|
13
|
+
_TELEMETRY_PROJECT = "MLOps"
|
14
|
+
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
15
|
+
|
16
|
+
|
17
|
+
class ModelVersion:
    """Model Version Object representing a specific version of the model that could be run.

    Instances are created through the ``_ref`` factory; calling ``ModelVersion()`` directly raises
    ``RuntimeError``. Every server interaction is delegated to the wrapped ``ModelOperator``.
    """

    _model_ops: model_ops.ModelOperator
    _model_name: sql_identifier.SqlIdentifier
    _version_name: sql_identifier.SqlIdentifier

    def __init__(self) -> None:
        raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")

    @classmethod
    def _ref(
        cls,
        model_ops: model_ops.ModelOperator,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
    ) -> "ModelVersion":
        """Build a ModelVersion referring to the given model version.

        ``object.__new__`` is used so that the intentionally failing ``__init__`` is bypassed.

        Args:
            model_ops: Operator used for every call made on behalf of this model version.
            model_name: Name of the model.
            version_name: Name of the version within the model.

        Returns:
            A ModelVersion bound to the given operator, model name and version name.
        """
        self: "ModelVersion" = object.__new__(cls)
        self._model_ops = model_ops
        self._model_name = model_name
        self._version_name = version_name
        return self

    def __eq__(self, __value: object) -> bool:
        # Equal when the operators are equal and both refer to the same model and version names.
        # NOTE: defining __eq__ without __hash__ makes instances unhashable.
        if not isinstance(__value, ModelVersion):
            return False
        return (
            self._model_ops == __value._model_ops
            and self._model_name == __value._model_name
            and self._version_name == __value._version_name
        )

    @property
    def model_name(self) -> str:
        """Name of the model this version belongs to, as given by ``SqlIdentifier.identifier()``."""
        return self._model_name.identifier()

    @property
    def version_name(self) -> str:
        """Name of this version, as given by ``SqlIdentifier.identifier()``."""
        return self._version_name.identifier()

    @property
    def fully_qualified_model_name(self) -> str:
        """Fully qualified model name, as computed by the model version SQL client."""
        return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)

    @property
    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def description(self) -> str:
        """The comment attached to this model version."""
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        return self._model_ops.get_comment(
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )

    @description.setter
    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def description(self, description: str) -> None:
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        # A property setter's return value is discarded, so the result of set_comment is not returned.
        self._model_ops.set_comment(
            comment=description,
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def list_metrics(self) -> Dict[str, Any]:
        """Show all metrics logged with the model version.

        Returns:
            A dictionary mapping metric names to their values (empty when no metrics were stored).
        """
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        return self._model_ops._metadata_ops.load(
            model_name=self._model_name, version_name=self._version_name, statement_params=statement_params
        )["metrics"]

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def get_metric(self, metric_name: str) -> Any:
        """Get the value of a specific metric.

        Args:
            metric_name: The name of the metric.

        Raises:
            KeyError: Raised when the requested metric name does not exist.

        Returns:
            The value of the metric.
        """
        metrics = self.list_metrics()
        if metric_name not in metrics:
            raise KeyError(f"Cannot find metric with name {metric_name}.")
        return metrics[metric_name]

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def set_metric(self, metric_name: str, value: Any) -> None:
        """Set the value of a specific metric name.

        The whole metrics dictionary is read, updated and written back.

        Args:
            metric_name: The name of the metric.
            value: The value of the metric.
        """
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        metrics = self.list_metrics()
        metrics[metric_name] = value
        self._model_ops._metadata_ops.save(
            metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def delete_metric(self, metric_name: str) -> None:
        """Delete a metric from metric storage.

        The whole metrics dictionary is read, the entry removed, and the result written back.

        Args:
            metric_name: The name of the metric to be deleted.

        Raises:
            KeyError: Raised when the requested metric name does not exist.
        """
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        metrics = self.list_metrics()
        if metric_name not in metrics:
            raise KeyError(f"Cannot find metric with name {metric_name}.")
        del metrics[metric_name]
        self._model_ops._metadata_ops.save(
            metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def list_methods(self) -> List[model_method_info.ModelMethodInfo]:
        """List all method information in a model version that is callable.

        Raises:
            ValueError: A method handler in the manifest is not of the form `functions.<target_method>.infer`.

        Returns:
            A list of ModelMethodInfo object containing the following information:
            - name: The name of the method to be called (both in SQL and in Python SDK).
            - target_method: The original method name in the logged Python object.
            - signature: Python signature of the original method.
        """
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        # TODO(SNOW-986673, SNOW-986675): Avoid parsing manifest and meta file and put Python signature into user_data.
        manifest = self._model_ops.get_model_version_manifest(
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )
        model_meta = self._model_ops.get_model_version_native_packing_meta(
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )
        return_methods_info: List[model_method_info.ModelMethodInfo] = []
        for method in manifest["methods"]:
            # Method's name is resolved so we need to use case_sensitive as True to get the user-facing identifier.
            method_name = sql_identifier.SqlIdentifier(method["name"], case_sensitive=True).identifier()
            # Method's handler is `functions.<target_method>.infer`. Validate it explicitly (an `assert` would be
            # stripped under `python -O`) and take the target method from the captured group.
            handler = method["handler"]
            handler_match = re.match(r"^functions\.([^\d\W]\w*)\.infer$", handler)
            if handler_match is None:
                raise ValueError(f"Got unexpected handler name {handler}")
            target_method = handler_match.group(1)
            signature_dict = model_meta["signatures"][target_method]
            method_info = model_method_info.ModelMethodInfo(
                name=method_name,
                target_method=target_method,
                signature=model_signature.ModelSignature.from_dict(signature_dict),
            )
            return_methods_info.append(method_info)

        return return_methods_info

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def run(
        self,
        X: Union[pd.DataFrame, dataframe.DataFrame],
        *,
        method_name: Optional[str] = None,
    ) -> Union[pd.DataFrame, dataframe.DataFrame]:
        """Invoke a method in a model version object.

        Args:
            X: The input data. Could be pandas DataFrame or Snowpark DataFrame
            method_name: The method name to run. It is the name you will use to call a method in SQL. Defaults to None.
                It can only be None if there is only 1 method.

        Raises:
            ValueError: No method with the corresponding name is available.
            ValueError: No method is available in the model version at all.
            ValueError: There are more than 1 target methods available in the model but no method name specified.

        Returns:
            The prediction data.
        """
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )

        methods: List[model_method_info.ModelMethodInfo] = self.list_methods()
        if method_name:
            # Normalize the user input the same way the listed method names were resolved before comparing.
            req_method_name = sql_identifier.SqlIdentifier(method_name).identifier()
            find_method: Callable[[model_method_info.ModelMethodInfo], bool] = (
                lambda method: method["name"] == req_method_name
            )
            target_method_info = next(
                filter(find_method, methods),
                None,
            )
            if target_method_info is None:
                raise ValueError(
                    f"There is no method with name {method_name} available in the model"
                    f" {self.fully_qualified_model_name} version {self.version_name}"
                )
        elif not methods:
            # Reported separately; otherwise an empty method list would be described as "more than 1" methods.
            raise ValueError(
                f"There is no method available in the model {self.fully_qualified_model_name}"
                f" version {self.version_name}."
            )
        elif len(methods) > 1:
            raise ValueError(
                f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
                f" version {self.version_name}. Please specify a `method_name` when calling the `run` method."
            )
        else:
            target_method_info = methods[0]
        return self._model_ops.invoke_method(
            method_name=sql_identifier.SqlIdentifier(target_method_info["name"]),
            signature=target_method_info["signature"],
            X=X,
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )
|
@@ -0,0 +1,107 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Any, Dict, Optional, TypedDict
|
3
|
+
|
4
|
+
from typing_extensions import NotRequired
|
5
|
+
|
6
|
+
from snowflake.ml._internal.utils import sql_identifier
|
7
|
+
from snowflake.ml.model._client.sql import (
|
8
|
+
model as model_sql,
|
9
|
+
model_version as model_version_sql,
|
10
|
+
)
|
11
|
+
from snowflake.snowpark import session
|
12
|
+
|
13
|
+
# Tag stored under "snowpark_ml_schema_version" in model version metadata; it is the only version this module parses.
MODEL_VERSION_METADATA_SCHEMA_VERSION: str = "2024-01-01"
|
14
|
+
|
15
|
+
|
16
|
+
class ModelVersionMetadataSchema(TypedDict):
|
17
|
+
metrics: NotRequired[Dict[str, Any]]
|
18
|
+
|
19
|
+
|
20
|
+
class MetadataOperator:
    """Read and write the JSON metadata attached to a model version.

    Metadata is read from the ``metadata`` field returned by the model client's ``show_versions`` and
    written back through the model version client's ``set_metadata``. Only metadata tagged with
    ``MODEL_VERSION_METADATA_SCHEMA_VERSION`` (or untagged metadata) can be parsed.
    """

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        # Both SQL clients share the same session and database/schema scope.
        scope = dict(database_name=database_name, schema_name=schema_name)
        self._model_client = model_sql.ModelSQLClient(session, **scope)
        self._model_version_client = model_version_sql.ModelVersionSQLClient(session, **scope)

    def __eq__(self, __value: object) -> bool:
        # Operators are equal when both of their underlying SQL clients compare equal.
        if not isinstance(__value, MetadataOperator):
            return False
        same_model_client = self._model_client == __value._model_client
        return same_model_client and self._model_version_client == __value._model_version_client

    @staticmethod
    def _parse(metadata_dict: Dict[str, Any]) -> ModelVersionMetadataSchema:
        """Convert a raw metadata dictionary into a ModelVersionMetadataSchema.

        Untagged metadata yields an empty metrics dictionary.

        Raises:
            ValueError: The schema version tag is unsupported, or "metrics" is not a dictionary.
        """
        schema_version = metadata_dict.get("snowpark_ml_schema_version", None)
        if schema_version is None:
            return ModelVersionMetadataSchema(metrics={})
        if not (isinstance(schema_version, str) and schema_version == MODEL_VERSION_METADATA_SCHEMA_VERSION):
            raise ValueError(f"Unsupported model metadata schema version {schema_version} confronted.")
        metrics = metadata_dict.get("metrics", {})
        if isinstance(metrics, dict):
            return ModelVersionMetadataSchema(metrics=metrics)
        raise ValueError(f"Metrics in the metadata is expected to be a dictionary, getting {metrics}")

    def _get_current_metadata_dict(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Fetch and JSON-decode the stored metadata of one model version ({} when none is stored)."""
        versions = self._model_client.show_versions(
            model_name=model_name, version_name=version_name, statement_params=statement_params
        )
        assert len(versions) == 1
        raw_metadata = versions[0].metadata
        if not raw_metadata:
            return {}
        parsed = json.loads(raw_metadata)
        if isinstance(parsed, dict):
            return parsed
        raise ValueError(f"Metadata is expected to be a dictionary, getting {parsed}")

    def load(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> ModelVersionMetadataSchema:
        """Load and parse the metadata of a model version."""
        return MetadataOperator._parse(
            self._get_current_metadata_dict(
                model_name=model_name, version_name=version_name, statement_params=statement_params
            )
        )

    def save(
        self,
        metadata: ModelVersionMetadataSchema,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Merge ``metadata`` into the stored metadata, tag it with the current schema version, and write it back.

        Keys already stored but absent from ``metadata`` are preserved.
        """
        current = self._get_current_metadata_dict(
            model_name=model_name, version_name=version_name, statement_params=statement_params
        )
        current.update(metadata)
        current["snowpark_ml_schema_version"] = MODEL_VERSION_METADATA_SCHEMA_VERSION
        self._model_version_client.set_metadata(
            current, model_name=model_name, version_name=version_name, statement_params=statement_params
        )
|