snowflake-ml-python 1.14.0__py3-none-any.whl → 1.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +13 -7
- snowflake/ml/_internal/utils/connection_params.py +5 -3
- snowflake/ml/_internal/utils/jwt_generator.py +3 -2
- snowflake/ml/_internal/utils/mixins.py +24 -9
- snowflake/ml/_internal/utils/temp_file_utils.py +1 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +16 -3
- snowflake/ml/experiment/_entities/__init__.py +2 -1
- snowflake/ml/experiment/_entities/run.py +0 -15
- snowflake/ml/experiment/_entities/run_metadata.py +3 -51
- snowflake/ml/experiment/experiment_tracking.py +71 -27
- snowflake/ml/jobs/_utils/spec_utils.py +49 -11
- snowflake/ml/jobs/manager.py +20 -0
- snowflake/ml/model/__init__.py +12 -2
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -4
- snowflake/ml/model/_client/model/inference_engine_utils.py +55 -0
- snowflake/ml/model/_client/model/model_version_impl.py +30 -62
- snowflake/ml/model/_client/ops/service_ops.py +68 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/sql/service.py +29 -2
- snowflake/ml/model/_client/sql/stage.py +8 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +25 -2
- snowflake/ml/model/_packager/model_env/model_env.py +26 -16
- snowflake/ml/model/_packager/model_handlers/_utils.py +4 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -5
- snowflake/ml/model/_packager/model_packager.py +4 -3
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_signatures/utils.py +0 -21
- snowflake/ml/model/models/huggingface_pipeline.py +56 -21
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/model/volatility.py +34 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +2 -1
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +29 -2
- snowflake/ml/registry/registry.py +15 -0
- snowflake/ml/utils/authentication.py +16 -0
- snowflake/ml/utils/connection_params.py +5 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/METADATA +81 -36
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/RECORD +193 -191
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,9 @@
|
|
|
1
|
+
import contextlib
|
|
1
2
|
import dataclasses
|
|
2
3
|
import enum
|
|
3
4
|
import logging
|
|
4
5
|
import textwrap
|
|
5
|
-
from typing import Any, Optional
|
|
6
|
+
from typing import Any, Generator, Optional
|
|
6
7
|
|
|
7
8
|
from snowflake import snowpark
|
|
8
9
|
from snowflake.ml._internal.utils import (
|
|
@@ -17,6 +18,11 @@ from snowflake.snowpark._internal import utils as snowpark_utils
|
|
|
17
18
|
|
|
18
19
|
logger = logging.getLogger(__name__)
|
|
19
20
|
|
|
21
|
+
# Using this token instead of '?' to avoid escaping issues
|
|
22
|
+
# After quotes are escaped, we replace this token with '|| ? ||'
|
|
23
|
+
QMARK_RESERVED_TOKEN = "<QMARK_RESERVED_TOKEN>"
|
|
24
|
+
QMARK_PARAMETER_TOKEN = "'|| ? ||'"
|
|
25
|
+
|
|
20
26
|
|
|
21
27
|
class ServiceStatus(enum.Enum):
|
|
22
28
|
PENDING = "PENDING"
|
|
@@ -70,12 +76,26 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
70
76
|
CONTAINER_STATUS = "status"
|
|
71
77
|
MESSAGE = "message"
|
|
72
78
|
|
|
79
|
+
@contextlib.contextmanager
|
|
80
|
+
def _qmark_paramstyle(self) -> Generator[None, None, None]:
|
|
81
|
+
"""Context manager that temporarily changes paramstyle to qmark and restores original value on exit."""
|
|
82
|
+
if not hasattr(self._session, "_options"):
|
|
83
|
+
yield
|
|
84
|
+
else:
|
|
85
|
+
original_paramstyle = self._session._options["paramstyle"]
|
|
86
|
+
try:
|
|
87
|
+
self._session._options["paramstyle"] = "qmark"
|
|
88
|
+
yield
|
|
89
|
+
finally:
|
|
90
|
+
self._session._options["paramstyle"] = original_paramstyle
|
|
91
|
+
|
|
73
92
|
def deploy_model(
|
|
74
93
|
self,
|
|
75
94
|
*,
|
|
76
95
|
stage_path: Optional[str] = None,
|
|
77
96
|
model_deployment_spec_yaml_str: Optional[str] = None,
|
|
78
97
|
model_deployment_spec_file_rel_path: Optional[str] = None,
|
|
98
|
+
query_params: Optional[list[Any]] = None,
|
|
79
99
|
statement_params: Optional[dict[str, Any]] = None,
|
|
80
100
|
) -> tuple[str, snowpark.AsyncJob]:
|
|
81
101
|
assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
|
|
@@ -83,11 +103,18 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
83
103
|
model_deployment_spec_yaml_str = snowpark_utils.escape_single_quotes(
|
|
84
104
|
model_deployment_spec_yaml_str
|
|
85
105
|
) # type: ignore[no-untyped-call]
|
|
106
|
+
model_deployment_spec_yaml_str = model_deployment_spec_yaml_str.replace( # type: ignore[union-attr]
|
|
107
|
+
QMARK_RESERVED_TOKEN, QMARK_PARAMETER_TOKEN
|
|
108
|
+
)
|
|
86
109
|
logger.info(f"Deploying model with spec={model_deployment_spec_yaml_str}")
|
|
87
110
|
sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
|
|
88
111
|
else:
|
|
89
112
|
sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
|
|
90
|
-
|
|
113
|
+
with self._qmark_paramstyle():
|
|
114
|
+
async_job = self._session.sql(
|
|
115
|
+
sql_str,
|
|
116
|
+
params=query_params if query_params else None,
|
|
117
|
+
).collect(block=False, statement_params=statement_params)
|
|
91
118
|
assert isinstance(async_job, snowpark.AsyncJob)
|
|
92
119
|
return async_job.query_id, async_job
|
|
93
120
|
|
|
@@ -2,6 +2,7 @@ from typing import Any, Optional
|
|
|
2
2
|
|
|
3
3
|
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
|
4
4
|
from snowflake.ml.model._client.sql import _base
|
|
5
|
+
from snowflake.snowpark import Row
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class StageSQLClient(_base._BaseSQLClient):
|
|
@@ -21,3 +22,10 @@ class StageSQLClient(_base._BaseSQLClient):
|
|
|
21
22
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
22
23
|
|
|
23
24
|
return fq_stage_name
|
|
25
|
+
|
|
26
|
+
def list_stage(self, stage_name: str) -> list[Row]:
|
|
27
|
+
try:
|
|
28
|
+
list_results = self._session.sql(f"LIST {stage_name}").collect()
|
|
29
|
+
except Exception as e:
|
|
30
|
+
raise RuntimeError(f"Failed to check stage location '{stage_name}': {e}")
|
|
31
|
+
return list_results
|
|
@@ -46,6 +46,7 @@ class ModelFunctionMethodDict(TypedDict):
|
|
|
46
46
|
handler: Required[str]
|
|
47
47
|
inputs: Required[list[ModelMethodSignatureFieldWithName]]
|
|
48
48
|
outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
|
|
49
|
+
volatility: NotRequired[str]
|
|
49
50
|
|
|
50
51
|
|
|
51
52
|
ModelMethodDict = ModelFunctionMethodDict
|
|
@@ -4,6 +4,7 @@ from typing import Optional, TypedDict, Union
|
|
|
4
4
|
|
|
5
5
|
from typing_extensions import NotRequired
|
|
6
6
|
|
|
7
|
+
from snowflake.ml._internal import platform_capabilities
|
|
7
8
|
from snowflake.ml._internal.utils import sql_identifier
|
|
8
9
|
from snowflake.ml.model import model_signature, type_hints
|
|
9
10
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
@@ -12,6 +13,7 @@ from snowflake.ml.model._model_composer.model_method import (
|
|
|
12
13
|
function_generator,
|
|
13
14
|
)
|
|
14
15
|
from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
|
|
16
|
+
from snowflake.ml.model.volatility import Volatility
|
|
15
17
|
from snowflake.snowpark._internal import type_utils
|
|
16
18
|
|
|
17
19
|
|
|
@@ -20,10 +22,12 @@ class ModelMethodOptions(TypedDict):
|
|
|
20
22
|
|
|
21
23
|
case_sensitive: Specify when the name of the method should be considered as case sensitive when registered to SQL.
|
|
22
24
|
function_type: One of `ModelMethodFunctionTypes` specifying function type.
|
|
25
|
+
volatility: One of `Volatility` enum values specifying function volatility.
|
|
23
26
|
"""
|
|
24
27
|
|
|
25
28
|
case_sensitive: NotRequired[bool]
|
|
26
29
|
function_type: NotRequired[str]
|
|
30
|
+
volatility: NotRequired[Volatility]
|
|
27
31
|
|
|
28
32
|
|
|
29
33
|
def get_model_method_options_from_options(
|
|
@@ -38,10 +42,19 @@ def get_model_method_options_from_options(
|
|
|
38
42
|
if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
|
|
39
43
|
raise NotImplementedError(f"Function type {function_type} is not supported.")
|
|
40
44
|
|
|
41
|
-
|
|
45
|
+
default_volatility = options.get("volatility")
|
|
46
|
+
method_volatility = method_option.get("volatility")
|
|
47
|
+
resolved_volatility = method_volatility or default_volatility
|
|
48
|
+
|
|
49
|
+
# Only include volatility if explicitly provided in method options
|
|
50
|
+
result: ModelMethodOptions = ModelMethodOptions(
|
|
42
51
|
case_sensitive=method_option.get("case_sensitive", False),
|
|
43
52
|
function_type=function_type,
|
|
44
53
|
)
|
|
54
|
+
if resolved_volatility:
|
|
55
|
+
result["volatility"] = resolved_volatility
|
|
56
|
+
|
|
57
|
+
return result
|
|
45
58
|
|
|
46
59
|
|
|
47
60
|
class ModelMethod:
|
|
@@ -94,6 +107,9 @@ class ModelMethod:
|
|
|
94
107
|
"function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
|
|
95
108
|
)
|
|
96
109
|
|
|
110
|
+
# Volatility is optional; when not provided, we omit it from the manifest
|
|
111
|
+
self.volatility = self.options.get("volatility")
|
|
112
|
+
|
|
97
113
|
@staticmethod
|
|
98
114
|
def _get_method_arg_from_feature(
|
|
99
115
|
feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False
|
|
@@ -148,7 +164,7 @@ class ModelMethod:
|
|
|
148
164
|
else:
|
|
149
165
|
outputs = [model_manifest_schema.ModelMethodSignatureField(type="OBJECT")]
|
|
150
166
|
|
|
151
|
-
|
|
167
|
+
method_dict = model_manifest_schema.ModelFunctionMethodDict(
|
|
152
168
|
name=self.method_name.resolved(),
|
|
153
169
|
runtime=self.runtime_name,
|
|
154
170
|
type=self.function_type,
|
|
@@ -158,3 +174,10 @@ class ModelMethod:
|
|
|
158
174
|
inputs=input_list,
|
|
159
175
|
outputs=outputs,
|
|
160
176
|
)
|
|
177
|
+
should_set_volatility = (
|
|
178
|
+
platform_capabilities.PlatformCapabilities.get_instance().is_set_module_functions_volatility_from_manifest()
|
|
179
|
+
)
|
|
180
|
+
if should_set_volatility and self.volatility is not None:
|
|
181
|
+
method_dict["volatility"] = self.volatility.name
|
|
182
|
+
|
|
183
|
+
return method_dict
|
|
@@ -145,11 +145,12 @@ class ModelEnv:
|
|
|
145
145
|
"""
|
|
146
146
|
if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
|
|
147
147
|
pip_pkg_reqs: list[str] = []
|
|
148
|
-
if self.targets_warehouse:
|
|
148
|
+
if self.targets_warehouse and not self.artifact_repository_map:
|
|
149
149
|
self._warn_once(
|
|
150
150
|
(
|
|
151
151
|
"Dependencies specified from pip requirements."
|
|
152
152
|
" This may prevent model deploying to Snowflake Warehouse."
|
|
153
|
+
" Use 'artifact_repository_map' to deploy the model to Warehouse."
|
|
153
154
|
),
|
|
154
155
|
stacklevel=2,
|
|
155
156
|
)
|
|
@@ -177,7 +178,11 @@ class ModelEnv:
|
|
|
177
178
|
req_to_add.name = conda_req.name
|
|
178
179
|
else:
|
|
179
180
|
req_to_add = conda_req
|
|
180
|
-
show_warning_message =
|
|
181
|
+
show_warning_message = (
|
|
182
|
+
conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME
|
|
183
|
+
and self.targets_warehouse
|
|
184
|
+
and not self.artifact_repository_map
|
|
185
|
+
)
|
|
181
186
|
|
|
182
187
|
if any(added_pip_req.name == pip_name for added_pip_req in self._pip_requirements):
|
|
183
188
|
if show_warning_message:
|
|
@@ -185,6 +190,7 @@ class ModelEnv:
|
|
|
185
190
|
(
|
|
186
191
|
f"Basic dependency {req_to_add.name} specified from pip requirements."
|
|
187
192
|
" This may prevent model deploying to Snowflake Warehouse."
|
|
193
|
+
" Use 'artifact_repository_map' to deploy the model to Warehouse."
|
|
188
194
|
),
|
|
189
195
|
stacklevel=2,
|
|
190
196
|
)
|
|
@@ -318,13 +324,15 @@ class ModelEnv:
|
|
|
318
324
|
)
|
|
319
325
|
|
|
320
326
|
if pip_requirements_list and self.targets_warehouse:
|
|
321
|
-
self.
|
|
322
|
-
(
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
327
|
+
if not self.artifact_repository_map:
|
|
328
|
+
self._warn_once(
|
|
329
|
+
(
|
|
330
|
+
"Found dependencies specified as pip requirements."
|
|
331
|
+
" This may prevent model deploying to Snowflake Warehouse."
|
|
332
|
+
" Use 'artifact_repository_map' to deploy the model to Warehouse."
|
|
333
|
+
),
|
|
334
|
+
stacklevel=2,
|
|
335
|
+
)
|
|
328
336
|
for pip_dependency in pip_requirements_list:
|
|
329
337
|
if any(
|
|
330
338
|
channel_dependency.name == pip_dependency.name
|
|
@@ -343,13 +351,15 @@ class ModelEnv:
|
|
|
343
351
|
pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
|
|
344
352
|
|
|
345
353
|
if pip_requirements_list and self.targets_warehouse:
|
|
346
|
-
self.
|
|
347
|
-
(
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
354
|
+
if not self.artifact_repository_map:
|
|
355
|
+
self._warn_once(
|
|
356
|
+
(
|
|
357
|
+
"Found dependencies specified as pip requirements."
|
|
358
|
+
" This may prevent model deploying to Snowflake Warehouse."
|
|
359
|
+
" Use 'artifact_repository_map' to deploy the model to Warehouse."
|
|
360
|
+
),
|
|
361
|
+
stacklevel=2,
|
|
362
|
+
)
|
|
353
363
|
for pip_dependency in pip_requirements_list:
|
|
354
364
|
if any(
|
|
355
365
|
channel_dependency.name == pip_dependency.name
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import importlib
|
|
2
2
|
import json
|
|
3
|
+
import logging
|
|
3
4
|
import os
|
|
4
5
|
import pathlib
|
|
5
6
|
import warnings
|
|
@@ -8,7 +9,6 @@ from typing import Any, Callable, Iterable, Optional, Sequence, cast
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
import numpy.typing as npt
|
|
10
11
|
import pandas as pd
|
|
11
|
-
from absl import logging
|
|
12
12
|
|
|
13
13
|
import snowflake.snowpark.dataframe as sp_df
|
|
14
14
|
from snowflake.ml._internal import env
|
|
@@ -23,6 +23,8 @@ from snowflake.ml.model._signatures import (
|
|
|
23
23
|
)
|
|
24
24
|
from snowflake.snowpark import DataFrame as SnowparkDataFrame
|
|
25
25
|
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
26
28
|
EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT = 1000
|
|
27
29
|
|
|
28
30
|
|
|
@@ -257,7 +259,7 @@ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task
|
|
|
257
259
|
)
|
|
258
260
|
return inferred_model_task
|
|
259
261
|
elif inferred_model_task != model_types.Task.UNKNOWN:
|
|
260
|
-
|
|
262
|
+
logger.info(f"Inferred Task: {inferred_model_task.name} is used as task for this model " f"version")
|
|
261
263
|
return inferred_model_task
|
|
262
264
|
return passed_model_task
|
|
263
265
|
|
|
@@ -43,7 +43,6 @@ DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message
|
|
|
43
43
|
def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
|
|
44
44
|
# Text
|
|
45
45
|
if task in [
|
|
46
|
-
"conversational",
|
|
47
46
|
"fill-mask",
|
|
48
47
|
"ner",
|
|
49
48
|
"token-classification",
|
|
@@ -521,6 +520,7 @@ class HuggingFacePipelineHandler(
|
|
|
521
520
|
input_data = X[signature.inputs[0].name].to_list()
|
|
522
521
|
temp_res = getattr(raw_model, target_method)(input_data)
|
|
523
522
|
else:
|
|
523
|
+
# TODO: remove conversational pipeline code
|
|
524
524
|
# For others, we could offer the whole dataframe as a list.
|
|
525
525
|
# Some of them may need some conversion
|
|
526
526
|
if hasattr(transformers, "ConversationalPipeline") and isinstance(
|
|
@@ -759,11 +759,13 @@ class HuggingFaceOpenAICompatibleModel:
|
|
|
759
759
|
eos_token_id=self.tokenizer.eos_token_id,
|
|
760
760
|
stop_strings=stop_strings,
|
|
761
761
|
stream=stream,
|
|
762
|
-
repetition_penalty=frequency_penalty,
|
|
763
|
-
diversity_penalty=presence_penalty if n > 1 else None,
|
|
764
762
|
num_return_sequences=n,
|
|
765
|
-
num_beams=max(
|
|
766
|
-
|
|
763
|
+
num_beams=max(1, n), # must be >1
|
|
764
|
+
repetition_penalty=frequency_penalty,
|
|
765
|
+
# TODO: Handle diversity_penalty and num_beam_groups
|
|
766
|
+
# not all models support them making it hard to support any huggingface model
|
|
767
|
+
# diversity_penalty=presence_penalty if n > 1 else None,
|
|
768
|
+
# num_beam_groups=max(2, n) if presence_penalty else 1,
|
|
767
769
|
do_sample=False,
|
|
768
770
|
)
|
|
769
771
|
|
|
@@ -1,9 +1,8 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import os
|
|
2
3
|
from types import ModuleType
|
|
3
4
|
from typing import Optional
|
|
4
5
|
|
|
5
|
-
from absl import logging
|
|
6
|
-
|
|
7
6
|
from snowflake.ml._internal.exceptions import (
|
|
8
7
|
error_codes,
|
|
9
8
|
exceptions as snowml_exceptions,
|
|
@@ -12,6 +11,8 @@ from snowflake.ml.model import custom_model, model_signature, type_hints as mode
|
|
|
12
11
|
from snowflake.ml.model._packager import model_handler
|
|
13
12
|
from snowflake.ml.model._packager.model_meta import model_meta
|
|
14
13
|
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
15
16
|
|
|
16
17
|
class ModelPackager:
|
|
17
18
|
"""Top-level class to save/load and manage a Snowflake Native formatted model.
|
|
@@ -96,7 +97,7 @@ class ModelPackager:
|
|
|
96
97
|
**options,
|
|
97
98
|
)
|
|
98
99
|
if signatures is None:
|
|
99
|
-
|
|
100
|
+
logger.info(f"Model signatures are auto inferred as:\n\n{meta.signatures}")
|
|
100
101
|
|
|
101
102
|
self.model = model
|
|
102
103
|
self.meta = meta
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
# Generate by running 'bazel run --config=pre_build //bazel/requirements:sync_requirements'
|
|
3
3
|
|
|
4
4
|
REQUIREMENTS = [
|
|
5
|
-
"absl-py>=0.15,<2",
|
|
6
5
|
"aiohttp!=4.0.0a0, !=4.0.0a1",
|
|
7
6
|
"anyio>=3.5.0,<5",
|
|
8
7
|
"cachetools>=3.1.1,<6",
|
|
@@ -22,7 +21,7 @@ REQUIREMENTS = [
|
|
|
22
21
|
"requests",
|
|
23
22
|
"retrying>=1.3.3,<2",
|
|
24
23
|
"s3fs>=2024.6.1,<2026",
|
|
25
|
-
"scikit-learn<1.
|
|
24
|
+
"scikit-learn<1.8",
|
|
26
25
|
"scipy>=1.9,<2",
|
|
27
26
|
"shap>=0.46.0,<1",
|
|
28
27
|
"snowflake-connector-python>=3.16.0,<4",
|
|
@@ -110,27 +110,6 @@ def huggingface_pipeline_signature_auto_infer(
|
|
|
110
110
|
) -> Optional[core.ModelSignature]:
|
|
111
111
|
# Text
|
|
112
112
|
|
|
113
|
-
# https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
|
|
114
|
-
# Needs to convert to conversation object.
|
|
115
|
-
if task == "conversational":
|
|
116
|
-
warnings.warn(
|
|
117
|
-
(
|
|
118
|
-
"Conversational pipeline is removed from transformers since 4.42.0. "
|
|
119
|
-
"Support will be removed from snowflake-ml-python soon."
|
|
120
|
-
),
|
|
121
|
-
category=DeprecationWarning,
|
|
122
|
-
stacklevel=1,
|
|
123
|
-
)
|
|
124
|
-
return core.ModelSignature(
|
|
125
|
-
inputs=[
|
|
126
|
-
core.FeatureSpec(name="user_inputs", dtype=core.DataType.STRING, shape=(-1,)),
|
|
127
|
-
core.FeatureSpec(name="generated_responses", dtype=core.DataType.STRING, shape=(-1,)),
|
|
128
|
-
],
|
|
129
|
-
outputs=[
|
|
130
|
-
core.FeatureSpec(name="generated_responses", dtype=core.DataType.STRING, shape=(-1,)),
|
|
131
|
-
],
|
|
132
|
-
)
|
|
133
|
-
|
|
134
113
|
# https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.TokenClassificationPipeline
|
|
135
114
|
if task == "fill-mask":
|
|
136
115
|
return core.ModelSignature(
|
|
@@ -8,6 +8,7 @@ from snowflake import snowpark
|
|
|
8
8
|
from snowflake.ml._internal import telemetry
|
|
9
9
|
from snowflake.ml._internal.human_readable_id import hrid_generator
|
|
10
10
|
from snowflake.ml._internal.utils import sql_identifier
|
|
11
|
+
from snowflake.ml.model._client.model import inference_engine_utils
|
|
11
12
|
from snowflake.ml.model._client.ops import service_ops
|
|
12
13
|
from snowflake.snowpark import async_job, session
|
|
13
14
|
|
|
@@ -77,6 +78,15 @@ class HuggingFacePipelineModel:
|
|
|
77
78
|
framework = kwargs.get("framework", None)
|
|
78
79
|
feature_extractor = kwargs.get("feature_extractor", None)
|
|
79
80
|
|
|
81
|
+
_can_download_snapshot = False
|
|
82
|
+
if download_snapshot:
|
|
83
|
+
try:
|
|
84
|
+
import huggingface_hub as hf_hub
|
|
85
|
+
|
|
86
|
+
_can_download_snapshot = True
|
|
87
|
+
except ImportError:
|
|
88
|
+
pass
|
|
89
|
+
|
|
80
90
|
# ==== Start pipeline logic from transformers ====
|
|
81
91
|
if model_kwargs is None:
|
|
82
92
|
model_kwargs = {}
|
|
@@ -141,22 +151,23 @@ class HuggingFacePipelineModel:
|
|
|
141
151
|
# Instantiate config if needed
|
|
142
152
|
config_obj = None
|
|
143
153
|
|
|
144
|
-
if
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
154
|
+
if not _can_download_snapshot:
|
|
155
|
+
if isinstance(config, str):
|
|
156
|
+
config_obj = transformers.AutoConfig.from_pretrained(
|
|
157
|
+
config, _from_pipeline=task, **hub_kwargs, **model_kwargs
|
|
158
|
+
)
|
|
159
|
+
hub_kwargs["_commit_hash"] = config_obj._commit_hash
|
|
160
|
+
elif config is None and isinstance(model, str):
|
|
161
|
+
config_obj = transformers.AutoConfig.from_pretrained(
|
|
162
|
+
model, _from_pipeline=task, **hub_kwargs, **model_kwargs
|
|
163
|
+
)
|
|
164
|
+
hub_kwargs["_commit_hash"] = config_obj._commit_hash
|
|
165
|
+
# We only support string as config argument.
|
|
166
|
+
elif config is not None and not isinstance(config, str):
|
|
167
|
+
raise RuntimeError(
|
|
168
|
+
"Impossible to use non-string config as input for HuggingFacePipelineModel. "
|
|
169
|
+
"Use transformers.Pipeline object if required."
|
|
170
|
+
)
|
|
160
171
|
|
|
161
172
|
# ==== Start pipeline logic (Task) from transformers ====
|
|
162
173
|
|
|
@@ -208,7 +219,7 @@ class HuggingFacePipelineModel:
|
|
|
208
219
|
"Using a pipeline without specifying a model name and revision in production is not recommended.",
|
|
209
220
|
stacklevel=2,
|
|
210
221
|
)
|
|
211
|
-
if config is None and isinstance(model, str):
|
|
222
|
+
if not _can_download_snapshot and config is None and isinstance(model, str):
|
|
212
223
|
config_obj = transformers.AutoConfig.from_pretrained(
|
|
213
224
|
model, _from_pipeline=task, **hub_kwargs, **model_kwargs
|
|
214
225
|
)
|
|
@@ -228,11 +239,10 @@ class HuggingFacePipelineModel:
|
|
|
228
239
|
)
|
|
229
240
|
|
|
230
241
|
repo_snapshot_dir: Optional[str] = None
|
|
231
|
-
if
|
|
242
|
+
if _can_download_snapshot:
|
|
232
243
|
try:
|
|
233
|
-
from huggingface_hub import snapshot_download
|
|
234
244
|
|
|
235
|
-
repo_snapshot_dir = snapshot_download(
|
|
245
|
+
repo_snapshot_dir = hf_hub.snapshot_download(
|
|
236
246
|
repo_id=model,
|
|
237
247
|
revision=revision,
|
|
238
248
|
token=token,
|
|
@@ -268,7 +278,7 @@ class HuggingFacePipelineModel:
|
|
|
268
278
|
],
|
|
269
279
|
)
|
|
270
280
|
@snowpark._internal.utils.private_preview(version="1.9.1")
|
|
271
|
-
def
|
|
281
|
+
def log_model_and_create_service(
|
|
272
282
|
self,
|
|
273
283
|
*,
|
|
274
284
|
session: session.Session,
|
|
@@ -293,6 +303,7 @@ class HuggingFacePipelineModel:
|
|
|
293
303
|
force_rebuild: bool = False,
|
|
294
304
|
build_external_access_integrations: Optional[list[str]] = None,
|
|
295
305
|
block: bool = True,
|
|
306
|
+
experimental_options: Optional[dict[str, Any]] = None,
|
|
296
307
|
) -> Union[str, async_job.AsyncJob]:
|
|
297
308
|
"""Logs a Hugging Face model and creates a service in Snowflake.
|
|
298
309
|
|
|
@@ -319,6 +330,10 @@ class HuggingFacePipelineModel:
|
|
|
319
330
|
force_rebuild: Whether to force rebuild the image. Defaults to False.
|
|
320
331
|
build_external_access_integrations: External access integrations for building the image. Defaults to None.
|
|
321
332
|
block: Whether to block the operation. Defaults to True.
|
|
333
|
+
experimental_options: Experimental options for the service creation with custom inference engine.
|
|
334
|
+
Currently, only `inference_engine` and `inference_engine_args_override` are supported.
|
|
335
|
+
`inference_engine` is the name of the inference engine to use.
|
|
336
|
+
`inference_engine_args_override` is a list of string arguments to pass to the inference engine.
|
|
322
337
|
|
|
323
338
|
Raises:
|
|
324
339
|
ValueError: if database and schema name is not provided and session doesn't have a
|
|
@@ -360,6 +375,24 @@ class HuggingFacePipelineModel:
|
|
|
360
375
|
)
|
|
361
376
|
logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")
|
|
362
377
|
|
|
378
|
+
# Check if model is HuggingFace text-generation before doing inference engine checks
|
|
379
|
+
inference_engine_args = None
|
|
380
|
+
if experimental_options:
|
|
381
|
+
if self.task != "text-generation":
|
|
382
|
+
raise ValueError(
|
|
383
|
+
"Currently, InferenceEngine using experimental_options is only supported for "
|
|
384
|
+
"HuggingFace text-generation models."
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)
|
|
388
|
+
|
|
389
|
+
# Enrich inference engine args if inference engine is specified
|
|
390
|
+
if inference_engine_args is not None:
|
|
391
|
+
inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
|
|
392
|
+
inference_engine_args,
|
|
393
|
+
gpu_requests,
|
|
394
|
+
)
|
|
395
|
+
|
|
363
396
|
from snowflake.ml.model import event_handler
|
|
364
397
|
from snowflake.snowpark import exceptions
|
|
365
398
|
|
|
@@ -412,6 +445,8 @@ class HuggingFacePipelineModel:
|
|
|
412
445
|
# TODO: remove warehouse in the next release
|
|
413
446
|
warehouse=session.get_current_warehouse(),
|
|
414
447
|
),
|
|
448
|
+
# inference engine
|
|
449
|
+
inference_engine_args=inference_engine_args,
|
|
415
450
|
)
|
|
416
451
|
status.update(label="HuggingFace model service created successfully", state="complete", expanded=False)
|
|
417
452
|
return result
|
snowflake/ml/model/type_hints.py
CHANGED
|
@@ -15,6 +15,7 @@ from typing_extensions import NotRequired
|
|
|
15
15
|
|
|
16
16
|
from snowflake.ml.model.target_platform import TargetPlatform
|
|
17
17
|
from snowflake.ml.model.task import Task
|
|
18
|
+
from snowflake.ml.model.volatility import Volatility
|
|
18
19
|
|
|
19
20
|
if TYPE_CHECKING:
|
|
20
21
|
import catboost
|
|
@@ -150,6 +151,7 @@ class ModelMethodSaveOptions(TypedDict):
|
|
|
150
151
|
case_sensitive: NotRequired[bool]
|
|
151
152
|
max_batch_size: NotRequired[int]
|
|
152
153
|
function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
|
|
154
|
+
volatility: NotRequired[Volatility]
|
|
153
155
|
|
|
154
156
|
|
|
155
157
|
class BaseModelSaveOption(TypedDict):
|
|
@@ -158,12 +160,23 @@ class BaseModelSaveOption(TypedDict):
|
|
|
158
160
|
embed_local_ml_library: Embedding local SnowML into the code directory of the folder.
|
|
159
161
|
relax_version: Whether or not relax the version constraints of the dependencies if unresolvable in Warehouse.
|
|
160
162
|
It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
|
|
163
|
+
function_type: Set the method function type globally. To set method function types individually see
|
|
164
|
+
function_type in method_options.
|
|
165
|
+
volatility: Set the volatility for all model methods globally. To set volatility for individual methods
|
|
166
|
+
see volatility in method_options. Defaults are set automatically based on model type: supported
|
|
167
|
+
models (sklearn, xgboost, pytorch, huggingface_pipeline, mlflow, etc.) default to IMMUTABLE, while
|
|
168
|
+
custom models default to VOLATILE. When both global volatility and per-method volatility are specified,
|
|
169
|
+
the per-method volatility takes precedence.
|
|
170
|
+
method_options: Per-method saving options. This dictionary has method names as keys and dictionary
|
|
171
|
+
values with the desired options.
|
|
172
|
+
enable_explainability: Whether to enable explainability features for the model.
|
|
161
173
|
save_location: Local directory path to save the model and metadata.
|
|
162
174
|
"""
|
|
163
175
|
|
|
164
176
|
embed_local_ml_library: NotRequired[bool]
|
|
165
177
|
relax_version: NotRequired[bool]
|
|
166
178
|
function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
|
|
179
|
+
volatility: NotRequired[Volatility]
|
|
167
180
|
method_options: NotRequired[dict[str, ModelMethodSaveOptions]]
|
|
168
181
|
enable_explainability: NotRequired[bool]
|
|
169
182
|
save_location: NotRequired[str]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Volatility definitions for model functions."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum, auto
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Volatility(Enum):
    """Volatility levels for model functions.

    Attributes:
        VOLATILE: Function results may change between calls with the same arguments.
            Use this for functions that depend on external data or have non-deterministic behavior.
        IMMUTABLE: Function results are guaranteed to be the same for the same arguments.
            Use this for pure functions that always return the same output for the same input.
    """

    # auto() assigns 1, 2, ... in declaration order, so reordering the members
    # below would change their underlying values.
    VOLATILE = auto()
    IMMUTABLE = auto()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Default volatility for a model's methods, keyed by model-type name.
# Every listed model type defaults to IMMUTABLE except "custom", which is
# VOLATILE.
DEFAULT_VOLATILITY_BY_MODEL_TYPE: dict[str, Volatility] = {
    "catboost": Volatility.IMMUTABLE,
    "custom": Volatility.VOLATILE,
    "huggingface_pipeline": Volatility.IMMUTABLE,
    "keras": Volatility.IMMUTABLE,
    "lightgbm": Volatility.IMMUTABLE,
    "mlflow": Volatility.IMMUTABLE,
    "pytorch": Volatility.IMMUTABLE,
    "sentence_transformers": Volatility.IMMUTABLE,
    "sklearn": Volatility.IMMUTABLE,
    "snowml": Volatility.IMMUTABLE,
    "tensorflow": Volatility.IMMUTABLE,
    "torchscript": Volatility.IMMUTABLE,
    "xgboost": Volatility.IMMUTABLE,
}
|
|
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
|
60
60
|
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
|
62
62
|
|
|
63
|
-
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
|
|
64
64
|
# Modeling library estimators require a smaller sklearn version range.
|
|
65
65
|
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
|
66
66
|
raise Exception(
|
|
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
|
60
60
|
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
|
62
62
|
|
|
63
|
-
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
|
|
64
64
|
# Modeling library estimators require a smaller sklearn version range.
|
|
65
65
|
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
|
66
66
|
raise Exception(
|
|
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
|
60
60
|
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
|
62
62
|
|
|
63
|
-
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
|
|
64
64
|
# Modeling library estimators require a smaller sklearn version range.
|
|
65
65
|
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
|
66
66
|
raise Exception(
|
|
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
|
60
60
|
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
|
62
62
|
|
|
63
|
-
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
|
|
64
64
|
# Modeling library estimators require a smaller sklearn version range.
|
|
65
65
|
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
|
66
66
|
raise Exception(
|