snowflake-ml-python 1.15.0__py3-none-any.whl → 1.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178) hide show
  1. snowflake/ml/_internal/platform_capabilities.py +4 -0
  2. snowflake/ml/_internal/utils/mixins.py +24 -9
  3. snowflake/ml/experiment/experiment_tracking.py +63 -19
  4. snowflake/ml/jobs/_utils/spec_utils.py +49 -11
  5. snowflake/ml/jobs/manager.py +20 -0
  6. snowflake/ml/model/__init__.py +16 -2
  7. snowflake/ml/model/_client/model/batch_inference_specs.py +18 -2
  8. snowflake/ml/model/_client/model/model_version_impl.py +5 -0
  9. snowflake/ml/model/_client/ops/service_ops.py +50 -5
  10. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  11. snowflake/ml/model/_client/sql/stage.py +8 -0
  12. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  13. snowflake/ml/model/_model_composer/model_method/model_method.py +25 -2
  14. snowflake/ml/model/_packager/model_env/model_env.py +26 -16
  15. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  16. snowflake/ml/model/type_hints.py +13 -0
  17. snowflake/ml/model/volatility.py +34 -0
  18. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  19. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  20. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  21. snowflake/ml/modeling/cluster/birch.py +1 -1
  22. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  23. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  24. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  25. snowflake/ml/modeling/cluster/k_means.py +1 -1
  26. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  27. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  28. snowflake/ml/modeling/cluster/optics.py +1 -1
  29. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  30. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  31. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  32. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  33. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  34. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  35. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  36. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  37. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  38. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  39. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  40. snowflake/ml/modeling/covariance/oas.py +1 -1
  41. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  42. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  43. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  44. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  45. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  46. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  47. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  48. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  49. snowflake/ml/modeling/decomposition/pca.py +1 -1
  50. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  51. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  52. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  53. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  54. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  55. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  56. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  57. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  58. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  59. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  60. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  61. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  62. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  63. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  64. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  65. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  66. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  67. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  68. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  69. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  70. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  71. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  72. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  73. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  74. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  75. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  76. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  77. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  78. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  79. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  80. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  81. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  82. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  83. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  84. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  85. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  86. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  87. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  88. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  89. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  90. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  91. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  92. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  93. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  94. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  95. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  96. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  97. snowflake/ml/modeling/linear_model/lars.py +1 -1
  98. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  99. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  100. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  101. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  102. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  103. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  104. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  105. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  106. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  107. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  108. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  109. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  110. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  111. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  112. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  113. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  114. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  115. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  116. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  117. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  118. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  119. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  120. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  121. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  122. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  123. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  124. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  125. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  126. snowflake/ml/modeling/manifold/isomap.py +1 -1
  127. snowflake/ml/modeling/manifold/mds.py +1 -1
  128. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  129. snowflake/ml/modeling/manifold/tsne.py +1 -1
  130. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  131. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  132. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  133. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  134. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  135. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  136. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  137. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  138. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  139. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  140. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  141. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  142. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  143. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  144. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  145. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  146. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  147. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  148. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  149. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  150. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  151. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  152. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  153. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  154. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  155. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  156. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  157. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  158. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  159. snowflake/ml/modeling/svm/svc.py +1 -1
  160. snowflake/ml/modeling/svm/svr.py +1 -1
  161. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  162. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  163. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  164. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  165. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  166. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  167. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  168. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  169. snowflake/ml/registry/_manager/model_manager.py +1 -0
  170. snowflake/ml/registry/_manager/model_parameter_reconciler.py +27 -0
  171. snowflake/ml/registry/registry.py +15 -0
  172. snowflake/ml/utils/authentication.py +16 -0
  173. snowflake/ml/version.py +1 -1
  174. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/METADATA +41 -3
  175. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/RECORD +178 -177
  176. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/WHEEL +0 -0
  177. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/licenses/LICENSE.txt +0 -0
  178. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ from typing import Any, Optional
2
2
 
3
3
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
4
  from snowflake.ml.model._client.sql import _base
5
+ from snowflake.snowpark import Row
5
6
 
6
7
 
7
8
  class StageSQLClient(_base._BaseSQLClient):
@@ -21,3 +22,10 @@ class StageSQLClient(_base._BaseSQLClient):
21
22
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
22
23
 
23
24
  return fq_stage_name
25
+
26
+ def list_stage(self, stage_name: str) -> list[Row]:
27
+ try:
28
+ list_results = self._session.sql(f"LIST {stage_name}").collect()
29
+ except Exception as e:
30
+ raise RuntimeError(f"Failed to check stage location '{stage_name}': {e}")
31
+ return list_results
@@ -46,6 +46,7 @@ class ModelFunctionMethodDict(TypedDict):
46
46
  handler: Required[str]
47
47
  inputs: Required[list[ModelMethodSignatureFieldWithName]]
48
48
  outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
49
+ volatility: NotRequired[str]
49
50
 
50
51
 
51
52
  ModelMethodDict = ModelFunctionMethodDict
@@ -4,6 +4,7 @@ from typing import Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired
6
6
 
7
+ from snowflake.ml._internal import platform_capabilities
7
8
  from snowflake.ml._internal.utils import sql_identifier
8
9
  from snowflake.ml.model import model_signature, type_hints
9
10
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -12,6 +13,7 @@ from snowflake.ml.model._model_composer.model_method import (
12
13
  function_generator,
13
14
  )
14
15
  from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
16
+ from snowflake.ml.model.volatility import Volatility
15
17
  from snowflake.snowpark._internal import type_utils
16
18
 
17
19
 
@@ -20,10 +22,12 @@ class ModelMethodOptions(TypedDict):
20
22
 
21
23
  case_sensitive: Specify when the name of the method should be considered as case sensitive when registered to SQL.
22
24
  function_type: One of `ModelMethodFunctionTypes` specifying function type.
25
+ volatility: One of `Volatility` enum values specifying function volatility.
23
26
  """
24
27
 
25
28
  case_sensitive: NotRequired[bool]
26
29
  function_type: NotRequired[str]
30
+ volatility: NotRequired[Volatility]
27
31
 
28
32
 
29
33
  def get_model_method_options_from_options(
@@ -38,10 +42,19 @@ def get_model_method_options_from_options(
38
42
  if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
39
43
  raise NotImplementedError(f"Function type {function_type} is not supported.")
40
44
 
41
- return ModelMethodOptions(
45
+ default_volatility = options.get("volatility")
46
+ method_volatility = method_option.get("volatility")
47
+ resolved_volatility = method_volatility or default_volatility
48
+
49
+ # Only include volatility if explicitly provided in method options
50
+ result: ModelMethodOptions = ModelMethodOptions(
42
51
  case_sensitive=method_option.get("case_sensitive", False),
43
52
  function_type=function_type,
44
53
  )
54
+ if resolved_volatility:
55
+ result["volatility"] = resolved_volatility
56
+
57
+ return result
45
58
 
46
59
 
47
60
  class ModelMethod:
@@ -94,6 +107,9 @@ class ModelMethod:
94
107
  "function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
95
108
  )
96
109
 
110
+ # Volatility is optional; when not provided, we omit it from the manifest
111
+ self.volatility = self.options.get("volatility")
112
+
97
113
  @staticmethod
98
114
  def _get_method_arg_from_feature(
99
115
  feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False
@@ -148,7 +164,7 @@ class ModelMethod:
148
164
  else:
149
165
  outputs = [model_manifest_schema.ModelMethodSignatureField(type="OBJECT")]
150
166
 
151
- return model_manifest_schema.ModelFunctionMethodDict(
167
+ method_dict = model_manifest_schema.ModelFunctionMethodDict(
152
168
  name=self.method_name.resolved(),
153
169
  runtime=self.runtime_name,
154
170
  type=self.function_type,
@@ -158,3 +174,10 @@ class ModelMethod:
158
174
  inputs=input_list,
159
175
  outputs=outputs,
160
176
  )
177
+ should_set_volatility = (
178
+ platform_capabilities.PlatformCapabilities.get_instance().is_set_module_functions_volatility_from_manifest()
179
+ )
180
+ if should_set_volatility and self.volatility is not None:
181
+ method_dict["volatility"] = self.volatility.name
182
+
183
+ return method_dict
@@ -145,11 +145,12 @@ class ModelEnv:
145
145
  """
146
146
  if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
147
147
  pip_pkg_reqs: list[str] = []
148
- if self.targets_warehouse:
148
+ if self.targets_warehouse and not self.artifact_repository_map:
149
149
  self._warn_once(
150
150
  (
151
151
  "Dependencies specified from pip requirements."
152
152
  " This may prevent model deploying to Snowflake Warehouse."
153
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
153
154
  ),
154
155
  stacklevel=2,
155
156
  )
@@ -177,7 +178,11 @@ class ModelEnv:
177
178
  req_to_add.name = conda_req.name
178
179
  else:
179
180
  req_to_add = conda_req
180
- show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse
181
+ show_warning_message = (
182
+ conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME
183
+ and self.targets_warehouse
184
+ and not self.artifact_repository_map
185
+ )
181
186
 
182
187
  if any(added_pip_req.name == pip_name for added_pip_req in self._pip_requirements):
183
188
  if show_warning_message:
@@ -185,6 +190,7 @@ class ModelEnv:
185
190
  (
186
191
  f"Basic dependency {req_to_add.name} specified from pip requirements."
187
192
  " This may prevent model deploying to Snowflake Warehouse."
193
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
188
194
  ),
189
195
  stacklevel=2,
190
196
  )
@@ -318,13 +324,15 @@ class ModelEnv:
318
324
  )
319
325
 
320
326
  if pip_requirements_list and self.targets_warehouse:
321
- self._warn_once(
322
- (
323
- "Found dependencies specified as pip requirements."
324
- " This may prevent model deploying to Snowflake Warehouse."
325
- ),
326
- stacklevel=2,
327
- )
327
+ if not self.artifact_repository_map:
328
+ self._warn_once(
329
+ (
330
+ "Found dependencies specified as pip requirements."
331
+ " This may prevent model deploying to Snowflake Warehouse."
332
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
333
+ ),
334
+ stacklevel=2,
335
+ )
328
336
  for pip_dependency in pip_requirements_list:
329
337
  if any(
330
338
  channel_dependency.name == pip_dependency.name
@@ -343,13 +351,15 @@ class ModelEnv:
343
351
  pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
344
352
 
345
353
  if pip_requirements_list and self.targets_warehouse:
346
- self._warn_once(
347
- (
348
- "Found dependencies specified as pip requirements."
349
- " This may prevent model deploying to Snowflake Warehouse."
350
- ),
351
- stacklevel=2,
352
- )
354
+ if not self.artifact_repository_map:
355
+ self._warn_once(
356
+ (
357
+ "Found dependencies specified as pip requirements."
358
+ " This may prevent model deploying to Snowflake Warehouse."
359
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
360
+ ),
361
+ stacklevel=2,
362
+ )
353
363
  for pip_dependency in pip_requirements_list:
354
364
  if any(
355
365
  channel_dependency.name == pip_dependency.name
@@ -21,7 +21,7 @@ REQUIREMENTS = [
21
21
  "requests",
22
22
  "retrying>=1.3.3,<2",
23
23
  "s3fs>=2024.6.1,<2026",
24
- "scikit-learn<1.7",
24
+ "scikit-learn<1.8",
25
25
  "scipy>=1.9,<2",
26
26
  "shap>=0.46.0,<1",
27
27
  "snowflake-connector-python>=3.16.0,<4",
@@ -15,6 +15,7 @@ from typing_extensions import NotRequired
15
15
 
16
16
  from snowflake.ml.model.target_platform import TargetPlatform
17
17
  from snowflake.ml.model.task import Task
18
+ from snowflake.ml.model.volatility import Volatility
18
19
 
19
20
  if TYPE_CHECKING:
20
21
  import catboost
@@ -150,6 +151,7 @@ class ModelMethodSaveOptions(TypedDict):
150
151
  case_sensitive: NotRequired[bool]
151
152
  max_batch_size: NotRequired[int]
152
153
  function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
154
+ volatility: NotRequired[Volatility]
153
155
 
154
156
 
155
157
  class BaseModelSaveOption(TypedDict):
@@ -158,12 +160,23 @@ class BaseModelSaveOption(TypedDict):
158
160
  embed_local_ml_library: Embedding local SnowML into the code directory of the folder.
159
161
  relax_version: Whether or not relax the version constraints of the dependencies if unresolvable in Warehouse.
160
162
  It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
163
+ function_type: Set the method function type globally. To set method function types individually see
164
+ function_type in method_options.
165
+ volatility: Set the volatility for all model methods globally. To set volatility for individual methods
166
+ see volatility in method_options. Defaults are set automatically based on model type: supported
167
+ models (sklearn, xgboost, pytorch, huggingface_pipeline, mlflow, etc.) default to IMMUTABLE, while
168
+ custom models default to VOLATILE. When both global volatility and per-method volatility are specified,
169
+ the per-method volatility takes precedence.
170
+ method_options: Per-method saving options. This dictionary has method names as keys and dictionary
171
+ values with the desired options.
172
+ enable_explainability: Whether to enable explainability features for the model.
161
173
  save_location: Local directory path to save the model and metadata.
162
174
  """
163
175
 
164
176
  embed_local_ml_library: NotRequired[bool]
165
177
  relax_version: NotRequired[bool]
166
178
  function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
179
+ volatility: NotRequired[Volatility]
167
180
  method_options: NotRequired[dict[str, ModelMethodSaveOptions]]
168
181
  enable_explainability: NotRequired[bool]
169
182
  save_location: NotRequired[str]
@@ -0,0 +1,34 @@
1
+ """Volatility definitions for model functions."""
2
+
3
+ from enum import Enum, auto
4
+
5
+
6
+ class Volatility(Enum):
7
+ """Volatility levels for model functions.
8
+
9
+ Attributes:
10
+ VOLATILE: Function results may change between calls with the same arguments.
11
+ Use this for functions that depend on external data or have non-deterministic behavior.
12
+ IMMUTABLE: Function results are guaranteed to be the same for the same arguments.
13
+ Use this for pure functions that always return the same output for the same input.
14
+ """
15
+
16
+ VOLATILE = auto()
17
+ IMMUTABLE = auto()
18
+
19
+
20
+ DEFAULT_VOLATILITY_BY_MODEL_TYPE = {
21
+ "catboost": Volatility.IMMUTABLE,
22
+ "custom": Volatility.VOLATILE,
23
+ "huggingface_pipeline": Volatility.IMMUTABLE,
24
+ "keras": Volatility.IMMUTABLE,
25
+ "lightgbm": Volatility.IMMUTABLE,
26
+ "mlflow": Volatility.IMMUTABLE,
27
+ "pytorch": Volatility.IMMUTABLE,
28
+ "sentence_transformers": Volatility.IMMUTABLE,
29
+ "sklearn": Volatility.IMMUTABLE,
30
+ "snowml": Volatility.IMMUTABLE,
31
+ "tensorflow": Volatility.IMMUTABLE,
32
+ "torchscript": Volatility.IMMUTABLE,
33
+ "xgboost": Volatility.IMMUTABLE,
34
+ }
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(