snowflake-ml-python 1.15.0__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (201)
  1. snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
  2. snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
  3. snowflake/ml/_internal/platform_capabilities.py +4 -0
  4. snowflake/ml/_internal/utils/mixins.py +24 -9
  5. snowflake/ml/experiment/experiment_tracking.py +63 -19
  6. snowflake/ml/jobs/__init__.py +4 -0
  7. snowflake/ml/jobs/_interop/__init__.py +0 -0
  8. snowflake/ml/jobs/_interop/data_utils.py +124 -0
  9. snowflake/ml/jobs/_interop/dto_schema.py +95 -0
  10. snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
  11. snowflake/ml/jobs/_interop/legacy.py +225 -0
  12. snowflake/ml/jobs/_interop/protocols.py +471 -0
  13. snowflake/ml/jobs/_interop/results.py +51 -0
  14. snowflake/ml/jobs/_interop/utils.py +144 -0
  15. snowflake/ml/jobs/_utils/constants.py +4 -1
  16. snowflake/ml/jobs/_utils/feature_flags.py +37 -5
  17. snowflake/ml/jobs/_utils/payload_utils.py +1 -1
  18. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
  19. snowflake/ml/jobs/_utils/spec_utils.py +50 -11
  20. snowflake/ml/jobs/_utils/types.py +10 -0
  21. snowflake/ml/jobs/job.py +168 -36
  22. snowflake/ml/jobs/manager.py +54 -36
  23. snowflake/ml/model/__init__.py +16 -2
  24. snowflake/ml/model/_client/model/batch_inference_specs.py +18 -2
  25. snowflake/ml/model/_client/model/model_version_impl.py +44 -7
  26. snowflake/ml/model/_client/ops/model_ops.py +4 -0
  27. snowflake/ml/model/_client/ops/service_ops.py +50 -5
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  29. snowflake/ml/model/_client/sql/model_version.py +3 -1
  30. snowflake/ml/model/_client/sql/stage.py +8 -0
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  32. snowflake/ml/model/_model_composer/model_method/model_method.py +32 -4
  33. snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
  34. snowflake/ml/model/_packager/model_env/model_env.py +48 -21
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  37. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  38. snowflake/ml/model/type_hints.py +13 -0
  39. snowflake/ml/model/volatility.py +34 -0
  40. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
  41. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  42. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  43. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  44. snowflake/ml/modeling/cluster/birch.py +1 -1
  45. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  46. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  47. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  48. snowflake/ml/modeling/cluster/k_means.py +1 -1
  49. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  50. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  51. snowflake/ml/modeling/cluster/optics.py +1 -1
  52. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  53. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  54. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  55. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  56. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  57. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  58. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  59. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  60. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  61. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  62. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  63. snowflake/ml/modeling/covariance/oas.py +1 -1
  64. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  65. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  66. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  67. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  68. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  69. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  70. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  71. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  72. snowflake/ml/modeling/decomposition/pca.py +1 -1
  73. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  74. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  75. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  76. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  77. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  78. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  79. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  82. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  83. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  84. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  85. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  87. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  88. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  89. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  90. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  91. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  92. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  93. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  94. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  95. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  96. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  97. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  98. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  99. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  100. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  101. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  102. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  103. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  104. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  105. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  106. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  107. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  108. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  109. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  110. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  111. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  112. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  113. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  114. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  115. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  116. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  117. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  119. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  120. snowflake/ml/modeling/linear_model/lars.py +1 -1
  121. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  122. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  123. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  124. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  125. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  126. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  127. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  128. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  129. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  130. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  131. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  132. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  133. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  135. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  136. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  137. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  138. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  140. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  141. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  142. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  143. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  144. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  145. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  146. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  147. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  148. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  149. snowflake/ml/modeling/manifold/isomap.py +1 -1
  150. snowflake/ml/modeling/manifold/mds.py +1 -1
  151. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  152. snowflake/ml/modeling/manifold/tsne.py +1 -1
  153. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  154. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  155. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  156. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  157. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  158. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  159. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  160. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  161. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  162. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  163. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  164. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  165. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  166. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  167. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  168. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  169. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  170. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  171. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  172. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  173. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  174. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  175. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  176. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  177. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  178. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  179. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  180. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  181. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  182. snowflake/ml/modeling/svm/svc.py +1 -1
  183. snowflake/ml/modeling/svm/svr.py +1 -1
  184. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  185. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  186. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  187. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  188. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  189. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  190. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  191. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  192. snowflake/ml/registry/_manager/model_manager.py +1 -0
  193. snowflake/ml/registry/_manager/model_parameter_reconciler.py +27 -0
  194. snowflake/ml/registry/registry.py +15 -0
  195. snowflake/ml/utils/authentication.py +16 -0
  196. snowflake/ml/version.py +1 -1
  197. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +65 -5
  198. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +201 -192
  199. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
  200. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
  201. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
@@ -952,6 +952,7 @@ class ModelOperator:
952
952
  partition_column: Optional[sql_identifier.SqlIdentifier] = None,
953
953
  statement_params: Optional[dict[str, str]] = None,
954
954
  is_partitioned: Optional[bool] = None,
955
+ explain_case_sensitive: bool = False,
955
956
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
956
957
  ...
957
958
 
@@ -967,6 +968,7 @@ class ModelOperator:
967
968
  service_name: sql_identifier.SqlIdentifier,
968
969
  strict_input_validation: bool = False,
969
970
  statement_params: Optional[dict[str, str]] = None,
971
+ explain_case_sensitive: bool = False,
970
972
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
971
973
  ...
972
974
 
@@ -986,6 +988,7 @@ class ModelOperator:
986
988
  partition_column: Optional[sql_identifier.SqlIdentifier] = None,
987
989
  statement_params: Optional[dict[str, str]] = None,
988
990
  is_partitioned: Optional[bool] = None,
991
+ explain_case_sensitive: bool = False,
989
992
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
990
993
  identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
991
994
 
@@ -1068,6 +1071,7 @@ class ModelOperator:
1068
1071
  version_name=version_name,
1069
1072
  statement_params=statement_params,
1070
1073
  is_partitioned=is_partitioned or False,
1074
+ explain_case_sensitive=explain_case_sensitive,
1071
1075
  )
1072
1076
 
1073
1077
  if keep_order:
@@ -7,6 +7,7 @@ import re
7
7
  import tempfile
8
8
  import threading
9
9
  import time
10
+ import warnings
10
11
  from typing import Any, Optional, Union, cast
11
12
 
12
13
  from snowflake import snowpark
@@ -14,6 +15,7 @@ from snowflake.ml import jobs
14
15
  from snowflake.ml._internal import file_utils, platform_capabilities as pc
15
16
  from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
16
17
  from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
18
+ from snowflake.ml.model._client.model import batch_inference_specs
17
19
  from snowflake.ml.model._client.service import model_deployment_spec
18
20
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
19
21
  from snowflake.snowpark import async_job, exceptions, row, session
@@ -155,17 +157,17 @@ class ServiceOperator:
155
157
  database_name=database_name,
156
158
  schema_name=schema_name,
157
159
  )
160
+ self._stage_client = stage_sql.StageSQLClient(
161
+ session,
162
+ database_name=database_name,
163
+ schema_name=schema_name,
164
+ )
158
165
  self._use_inlined_deployment_spec = pc.PlatformCapabilities.get_instance().is_inlined_deployment_spec_enabled()
159
166
  if self._use_inlined_deployment_spec:
160
167
  self._workspace = None
161
168
  self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec()
162
169
  else:
163
170
  self._workspace = tempfile.TemporaryDirectory()
164
- self._stage_client = stage_sql.StageSQLClient(
165
- session,
166
- database_name=database_name,
167
- schema_name=schema_name,
168
- )
169
171
  self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
170
172
  workspace_path=pathlib.Path(self._workspace.name)
171
173
  )
@@ -651,6 +653,47 @@ class ServiceOperator:
651
653
  else:
652
654
  module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
653
655
 
656
+ def _enforce_save_mode(self, output_mode: batch_inference_specs.SaveMode, output_stage_location: str) -> None:
657
+ """Enforce the save mode for the output stage location.
658
+
659
+ Args:
660
+ output_mode: The output mode
661
+ output_stage_location: The output stage location to check/clean.
662
+
663
+ Raises:
664
+ FileExistsError: When ERROR mode is specified and files exist in the output location.
665
+ RuntimeError: When operations fail (checking files or removing files).
666
+ ValueError: When an invalid SaveMode is specified.
667
+ """
668
+ list_results = self._stage_client.list_stage(output_stage_location)
669
+
670
+ if output_mode == batch_inference_specs.SaveMode.ERROR:
671
+ if len(list_results) > 0:
672
+ raise FileExistsError(
673
+ f"Output stage location '{output_stage_location}' is not empty. "
674
+ f"Found {len(list_results)} existing files. When using ERROR mode, the output location "
675
+ f"must be empty. Please clear the existing files or use OVERWRITE mode."
676
+ )
677
+ elif output_mode == batch_inference_specs.SaveMode.OVERWRITE:
678
+ if len(list_results) > 0:
679
+ warnings.warn(
680
+ f"Output stage location '{output_stage_location}' is not empty. "
681
+ f"Found {len(list_results)} existing files. OVERWRITE mode will remove all existing files "
682
+ f"in the output location before running the batch inference job.",
683
+ stacklevel=2,
684
+ )
685
+ try:
686
+ self._session.sql(f"REMOVE {output_stage_location}").collect()
687
+ except Exception as e:
688
+ raise RuntimeError(
689
+ f"OVERWRITE was specified. However, failed to remove existing files in output stage "
690
+ f"{output_stage_location}: {e}. Please clear up the existing files manually and retry "
691
+ f"the operation."
692
+ )
693
+ else:
694
+ valid_modes = list(batch_inference_specs.SaveMode)
695
+ raise ValueError(f"Invalid SaveMode: {output_mode}. Must be one of {valid_modes}")
696
+
654
697
  def _stream_service_logs(
655
698
  self,
656
699
  async_job: snowpark.AsyncJob,
@@ -927,6 +970,7 @@ class ServiceOperator:
927
970
  max_batch_rows: Optional[int],
928
971
  cpu_requests: Optional[str],
929
972
  memory_requests: Optional[str],
973
+ gpu_requests: Optional[str],
930
974
  replicas: Optional[int],
931
975
  statement_params: Optional[dict[str, Any]] = None,
932
976
  ) -> jobs.MLJob[Any]:
@@ -961,6 +1005,7 @@ class ServiceOperator:
961
1005
  warehouse=warehouse,
962
1006
  cpu=cpu_requests,
963
1007
  memory=memory_requests,
1008
+ gpu=gpu_requests,
964
1009
  replicas=replicas,
965
1010
  )
966
1011
 
@@ -204,7 +204,7 @@ class ModelDeploymentSpec:
204
204
  job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
205
205
  cpu: Optional[str] = None,
206
206
  memory: Optional[str] = None,
207
- gpu: Optional[Union[str, int]] = None,
207
+ gpu: Optional[str] = None,
208
208
  num_workers: Optional[int] = None,
209
209
  max_batch_rows: Optional[int] = None,
210
210
  replicas: Optional[int] = None,
@@ -438,6 +438,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
438
438
  partition_column: Optional[sql_identifier.SqlIdentifier],
439
439
  statement_params: Optional[dict[str, Any]] = None,
440
440
  is_partitioned: bool = True,
441
+ explain_case_sensitive: bool = False,
441
442
  ) -> dataframe.DataFrame:
442
443
  with_statements = []
443
444
  if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -505,7 +506,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
505
506
  cols_to_drop = []
506
507
 
507
508
  for output_name, output_type, output_col_name in returns:
508
- output_identifier = sql_identifier.SqlIdentifier(output_name).identifier()
509
+ case_sensitive = "explain" in method_name.resolved().lower() and explain_case_sensitive
510
+ output_identifier = sql_identifier.SqlIdentifier(output_name, case_sensitive=case_sensitive).identifier()
509
511
  if output_identifier != output_col_name:
510
512
  cols_to_drop.append(output_identifier)
511
513
  output_cols.append(F.col(output_identifier).astype(output_type))
@@ -2,6 +2,7 @@ from typing import Any, Optional
2
2
 
3
3
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
4
  from snowflake.ml.model._client.sql import _base
5
+ from snowflake.snowpark import Row
5
6
 
6
7
 
7
8
  class StageSQLClient(_base._BaseSQLClient):
@@ -21,3 +22,10 @@ class StageSQLClient(_base._BaseSQLClient):
21
22
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
22
23
 
23
24
  return fq_stage_name
25
+
26
+ def list_stage(self, stage_name: str) -> list[Row]:
27
+ try:
28
+ list_results = self._session.sql(f"LIST {stage_name}").collect()
29
+ except Exception as e:
30
+ raise RuntimeError(f"Failed to check stage location '{stage_name}': {e}")
31
+ return list_results
@@ -46,6 +46,7 @@ class ModelFunctionMethodDict(TypedDict):
46
46
  handler: Required[str]
47
47
  inputs: Required[list[ModelMethodSignatureFieldWithName]]
48
48
  outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
49
+ volatility: NotRequired[str]
49
50
 
50
51
 
51
52
  ModelMethodDict = ModelFunctionMethodDict
@@ -4,14 +4,17 @@ from typing import Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired
6
6
 
7
+ from snowflake.ml._internal import platform_capabilities
7
8
  from snowflake.ml._internal.utils import sql_identifier
8
9
  from snowflake.ml.model import model_signature, type_hints
9
10
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
10
11
  from snowflake.ml.model._model_composer.model_method import (
11
12
  constants,
12
13
  function_generator,
14
+ utils,
13
15
  )
14
16
  from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
17
+ from snowflake.ml.model.volatility import Volatility
15
18
  from snowflake.snowpark._internal import type_utils
16
19
 
17
20
 
@@ -20,28 +23,43 @@ class ModelMethodOptions(TypedDict):
20
23
 
21
24
  case_sensitive: Specify when the name of the method should be considered as case sensitive when registered to SQL.
22
25
  function_type: One of `ModelMethodFunctionTypes` specifying function type.
26
+ volatility: One of `Volatility` enum values specifying function volatility.
23
27
  """
24
28
 
25
29
  case_sensitive: NotRequired[bool]
26
30
  function_type: NotRequired[str]
31
+ volatility: NotRequired[Volatility]
27
32
 
28
33
 
29
34
  def get_model_method_options_from_options(
30
35
  options: type_hints.ModelSaveOption, target_method: str
31
36
  ) -> ModelMethodOptions:
32
37
  default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
38
+ method_option = options.get("method_options", {}).get(target_method, {})
39
+ case_sensitive = method_option.get("case_sensitive", False)
33
40
  if target_method == "explain":
34
41
  default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
35
- method_option = options.get("method_options", {}).get(target_method, {})
42
+ case_sensitive = utils.determine_explain_case_sensitive_from_method_options(
43
+ options.get("method_options", {}), target_method
44
+ )
36
45
  global_function_type = options.get("function_type", default_function_type)
37
46
  function_type = method_option.get("function_type", global_function_type)
38
47
  if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
39
48
  raise NotImplementedError(f"Function type {function_type} is not supported.")
40
49
 
41
- return ModelMethodOptions(
42
- case_sensitive=method_option.get("case_sensitive", False),
50
+ default_volatility = options.get("volatility")
51
+ method_volatility = method_option.get("volatility")
52
+ resolved_volatility = method_volatility or default_volatility
53
+
54
+ # Only include volatility if explicitly provided in method options
55
+ result: ModelMethodOptions = ModelMethodOptions(
56
+ case_sensitive=case_sensitive,
43
57
  function_type=function_type,
44
58
  )
59
+ if resolved_volatility:
60
+ result["volatility"] = resolved_volatility
61
+
62
+ return result
45
63
 
46
64
 
47
65
  class ModelMethod:
@@ -94,6 +112,9 @@ class ModelMethod:
94
112
  "function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
95
113
  )
96
114
 
115
+ # Volatility is optional; when not provided, we omit it from the manifest
116
+ self.volatility = self.options.get("volatility")
117
+
97
118
  @staticmethod
98
119
  def _get_method_arg_from_feature(
99
120
  feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False
@@ -148,7 +169,7 @@ class ModelMethod:
148
169
  else:
149
170
  outputs = [model_manifest_schema.ModelMethodSignatureField(type="OBJECT")]
150
171
 
151
- return model_manifest_schema.ModelFunctionMethodDict(
172
+ method_dict = model_manifest_schema.ModelFunctionMethodDict(
152
173
  name=self.method_name.resolved(),
153
174
  runtime=self.runtime_name,
154
175
  type=self.function_type,
@@ -158,3 +179,10 @@ class ModelMethod:
158
179
  inputs=input_list,
159
180
  outputs=outputs,
160
181
  )
182
+ should_set_volatility = (
183
+ platform_capabilities.PlatformCapabilities.get_instance().is_set_module_functions_volatility_from_manifest()
184
+ )
185
+ if should_set_volatility and self.volatility is not None:
186
+ method_dict["volatility"] = self.volatility.name
187
+
188
+ return method_dict
@@ -0,0 +1,28 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Mapping, Optional
4
+
5
+
6
+ def determine_explain_case_sensitive_from_method_options(
7
+ method_options: Mapping[str, Optional[Mapping[str, Any]]],
8
+ target_method: str,
9
+ ) -> bool:
10
+ """Determine explain method case sensitivity from related predict methods.
11
+
12
+ Args:
13
+ method_options: Mapping from method name to its options. Each option may
14
+ contain ``"case_sensitive"`` to indicate SQL identifier sensitivity.
15
+ target_method: The target method name being resolved (e.g., an ``explain_*``
16
+ method).
17
+
18
+ Returns:
19
+ True if the explain method should be treated as case sensitive; otherwise False.
20
+ """
21
+ if "explain" not in target_method:
22
+ return False
23
+ predict_priority_methods = ["predict_proba", "predict", "predict_log_proba"]
24
+ for src_method in predict_priority_methods:
25
+ src_opts = method_options.get(src_method)
26
+ if src_opts is not None:
27
+ return bool(src_opts.get("case_sensitive", False))
28
+ return False
@@ -145,11 +145,12 @@ class ModelEnv:
145
145
  """
146
146
  if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
147
147
  pip_pkg_reqs: list[str] = []
148
- if self.targets_warehouse:
148
+ if self.targets_warehouse and not self.artifact_repository_map:
149
149
  self._warn_once(
150
150
  (
151
151
  "Dependencies specified from pip requirements."
152
152
  " This may prevent model deploying to Snowflake Warehouse."
153
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
153
154
  ),
154
155
  stacklevel=2,
155
156
  )
@@ -177,7 +178,11 @@ class ModelEnv:
177
178
  req_to_add.name = conda_req.name
178
179
  else:
179
180
  req_to_add = conda_req
180
- show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse
181
+ show_warning_message = (
182
+ conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME
183
+ and self.targets_warehouse
184
+ and not self.artifact_repository_map
185
+ )
181
186
 
182
187
  if any(added_pip_req.name == pip_name for added_pip_req in self._pip_requirements):
183
188
  if show_warning_message:
@@ -185,6 +190,7 @@ class ModelEnv:
185
190
  (
186
191
  f"Basic dependency {req_to_add.name} specified from pip requirements."
187
192
  " This may prevent model deploying to Snowflake Warehouse."
193
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
188
194
  ),
189
195
  stacklevel=2,
190
196
  )
@@ -234,14 +240,31 @@ class ModelEnv:
234
240
  self._conda_dependencies[channel].remove(spec)
235
241
 
236
242
  def generate_env_for_cuda(self) -> None:
243
+
244
+ # Insert py-xgboost-gpu only for XGBoost versions < 3.0.0
237
245
  xgboost_spec = env_utils.find_dep_spec(
238
- self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
246
+ self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=False
239
247
  )
240
248
  if xgboost_spec:
241
- self.include_if_absent(
242
- [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
243
- check_local_version=False,
244
- )
249
+ # Only handle explicitly pinned versions. Insert GPU variant iff pinned major < 3.
250
+ pinned_major: Optional[int] = None
251
+ for spec in xgboost_spec.specifier:
252
+ if spec.operator in ("==", "===", ">", ">="):
253
+ try:
254
+ pinned_major = version.parse(spec.version).major
255
+ except version.InvalidVersion:
256
+ pinned_major = None
257
+ break
258
+
259
+ if pinned_major is not None and pinned_major < 3:
260
+ xgboost_spec = env_utils.find_dep_spec(
261
+ self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
262
+ )
263
+ if xgboost_spec:
264
+ self.include_if_absent(
265
+ [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
266
+ check_local_version=False,
267
+ )
245
268
 
246
269
  tf_spec = env_utils.find_dep_spec(
247
270
  self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True
@@ -318,13 +341,15 @@ class ModelEnv:
318
341
  )
319
342
 
320
343
  if pip_requirements_list and self.targets_warehouse:
321
- self._warn_once(
322
- (
323
- "Found dependencies specified as pip requirements."
324
- " This may prevent model deploying to Snowflake Warehouse."
325
- ),
326
- stacklevel=2,
327
- )
344
+ if not self.artifact_repository_map:
345
+ self._warn_once(
346
+ (
347
+ "Found dependencies specified as pip requirements."
348
+ " This may prevent model deploying to Snowflake Warehouse."
349
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
350
+ ),
351
+ stacklevel=2,
352
+ )
328
353
  for pip_dependency in pip_requirements_list:
329
354
  if any(
330
355
  channel_dependency.name == pip_dependency.name
@@ -343,13 +368,15 @@ class ModelEnv:
343
368
  pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
344
369
 
345
370
  if pip_requirements_list and self.targets_warehouse:
346
- self._warn_once(
347
- (
348
- "Found dependencies specified as pip requirements."
349
- " This may prevent model deploying to Snowflake Warehouse."
350
- ),
351
- stacklevel=2,
352
- )
371
+ if not self.artifact_repository_map:
372
+ self._warn_once(
373
+ (
374
+ "Found dependencies specified as pip requirements."
375
+ " This may prevent model deploying to Snowflake Warehouse."
376
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
377
+ ),
378
+ stacklevel=2,
379
+ )
353
380
  for pip_dependency in pip_requirements_list:
354
381
  if any(
355
382
  channel_dependency.name == pip_dependency.name
@@ -116,6 +116,8 @@ def create_model_metadata(
116
116
  if embed_local_ml_library:
117
117
  env.snowpark_ml_version = f"{snowml_version.VERSION}+{file_utils.hash_directory(path_to_copy)}"
118
118
 
119
+ # Persist full method_options
120
+ method_options: dict[str, dict[str, Any]] = kwargs.pop("method_options", {})
119
121
  model_meta = ModelMetadata(
120
122
  name=name,
121
123
  env=env,
@@ -124,6 +126,7 @@ def create_model_metadata(
124
126
  signatures=signatures,
125
127
  function_properties=function_properties,
126
128
  task=task,
129
+ method_options=method_options,
127
130
  )
128
131
 
129
132
  code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR)
@@ -256,6 +259,7 @@ class ModelMetadata:
256
259
  original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
257
260
  task: model_types.Task = model_types.Task.UNKNOWN,
258
261
  explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
262
+ method_options: Optional[dict[str, dict[str, Any]]] = None,
259
263
  ) -> None:
260
264
  self.name = name
261
265
  self.signatures: dict[str, model_signature.ModelSignature] = dict()
@@ -283,6 +287,7 @@ class ModelMetadata:
283
287
 
284
288
  self.task: model_types.Task = task
285
289
  self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
290
+ self.method_options: dict[str, dict[str, Any]] = method_options or {}
286
291
 
287
292
  @property
288
293
  def min_snowpark_ml_version(self) -> str:
@@ -342,6 +347,7 @@ class ModelMetadata:
342
347
  else None
343
348
  ),
344
349
  "function_properties": self.function_properties,
350
+ "method_options": self.method_options,
345
351
  }
346
352
  )
347
353
  with open(model_yaml_path, "w", encoding="utf-8") as out:
@@ -381,6 +387,7 @@ class ModelMetadata:
381
387
  task=loaded_meta.get("task", model_types.Task.UNKNOWN.value),
382
388
  explainability=loaded_meta.get("explainability", None),
383
389
  function_properties=loaded_meta.get("function_properties", {}),
390
+ method_options=loaded_meta.get("method_options", {}),
384
391
  )
385
392
 
386
393
  @classmethod
@@ -436,4 +443,5 @@ class ModelMetadata:
436
443
  task=model_types.Task(model_dict.get("task", model_types.Task.UNKNOWN.value)),
437
444
  explain_algorithm=explanation_algorithm,
438
445
  function_properties=model_dict.get("function_properties", {}),
446
+ method_options=model_dict.get("method_options", {}),
439
447
  )
@@ -125,6 +125,7 @@ class ModelMetadataDict(TypedDict):
125
125
  task: Required[str]
126
126
  explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
127
127
  function_properties: NotRequired[dict[str, dict[str, Any]]]
128
+ method_options: NotRequired[dict[str, dict[str, Any]]]
128
129
 
129
130
 
130
131
  class ModelExplainAlgorithm(Enum):
@@ -21,14 +21,14 @@ REQUIREMENTS = [
21
21
  "requests",
22
22
  "retrying>=1.3.3,<2",
23
23
  "s3fs>=2024.6.1,<2026",
24
- "scikit-learn<1.7",
24
+ "scikit-learn<1.8",
25
25
  "scipy>=1.9,<2",
26
26
  "shap>=0.46.0,<1",
27
- "snowflake-connector-python>=3.16.0,<4",
27
+ "snowflake-connector-python>=3.17.0,<4",
28
28
  "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
29
29
  "snowflake.core>=1.0.2,<2",
30
30
  "sqlparse>=0.4,<1",
31
31
  "tqdm<5",
32
32
  "typing-extensions>=4.1.0,<5",
33
- "xgboost>=1.7.3,<3",
33
+ "xgboost<4",
34
34
  ]
@@ -15,6 +15,7 @@ from typing_extensions import NotRequired
15
15
 
16
16
  from snowflake.ml.model.target_platform import TargetPlatform
17
17
  from snowflake.ml.model.task import Task
18
+ from snowflake.ml.model.volatility import Volatility
18
19
 
19
20
  if TYPE_CHECKING:
20
21
  import catboost
@@ -150,6 +151,7 @@ class ModelMethodSaveOptions(TypedDict):
150
151
  case_sensitive: NotRequired[bool]
151
152
  max_batch_size: NotRequired[int]
152
153
  function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
154
+ volatility: NotRequired[Volatility]
153
155
 
154
156
 
155
157
  class BaseModelSaveOption(TypedDict):
@@ -158,12 +160,23 @@ class BaseModelSaveOption(TypedDict):
158
160
  embed_local_ml_library: Embedding local SnowML into the code directory of the folder.
159
161
  relax_version: Whether or not to relax the version constraints of the dependencies if unresolvable in Warehouse.
160
162
  It detects any ==x.y.z in specifiers and replaces it with >=x.y, <(x+1). Defaults to True.
163
+ function_type: Set the method function type globally. To set method function types individually see
164
+ function_type in method_options.
165
+ volatility: Set the volatility for all model methods globally. To set volatility for individual methods
166
+ see volatility in method_options. Defaults are set automatically based on model type: supported
167
+ models (sklearn, xgboost, pytorch, huggingface_pipeline, mlflow, etc.) default to IMMUTABLE, while
168
+ custom models default to VOLATILE. When both global volatility and per-method volatility are specified,
169
+ the per-method volatility takes precedence.
170
+ method_options: Per-method saving options. This dictionary has method names as keys and dictionary
171
+ values with the desired options.
172
+ enable_explainability: Whether to enable explainability features for the model.
161
173
  save_location: Local directory path to save the model and metadata.
162
174
  """
163
175
 
164
176
  embed_local_ml_library: NotRequired[bool]
165
177
  relax_version: NotRequired[bool]
166
178
  function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
179
+ volatility: NotRequired[Volatility]
167
180
  method_options: NotRequired[dict[str, ModelMethodSaveOptions]]
168
181
  enable_explainability: NotRequired[bool]
169
182
  save_location: NotRequired[str]
@@ -0,0 +1,34 @@
1
+ """Volatility definitions for model functions."""
2
+
3
+ from enum import Enum, auto
4
+
5
+
6
+ class Volatility(Enum):
7
+ """Volatility levels for model functions.
8
+
9
+ Attributes:
10
+ VOLATILE: Function results may change between calls with the same arguments.
11
+ Use this for functions that depend on external data or have non-deterministic behavior.
12
+ IMMUTABLE: Function results are guaranteed to be the same for the same arguments.
13
+ Use this for pure functions that always return the same output for the same input.
14
+ """
15
+
16
+ VOLATILE = auto()
17
+ IMMUTABLE = auto()
18
+
19
+
20
+ DEFAULT_VOLATILITY_BY_MODEL_TYPE = {
21
+ "catboost": Volatility.IMMUTABLE,
22
+ "custom": Volatility.VOLATILE,
23
+ "huggingface_pipeline": Volatility.IMMUTABLE,
24
+ "keras": Volatility.IMMUTABLE,
25
+ "lightgbm": Volatility.IMMUTABLE,
26
+ "mlflow": Volatility.IMMUTABLE,
27
+ "pytorch": Volatility.IMMUTABLE,
28
+ "sentence_transformers": Volatility.IMMUTABLE,
29
+ "sklearn": Volatility.IMMUTABLE,
30
+ "snowml": Volatility.IMMUTABLE,
31
+ "tensorflow": Volatility.IMMUTABLE,
32
+ "torchscript": Volatility.IMMUTABLE,
33
+ "xgboost": Volatility.IMMUTABLE,
34
+ }
@@ -93,7 +93,7 @@ def get_data_iterator(
93
93
  cache_dir_name = tempfile.mkdtemp()
94
94
  super().__init__(cache_prefix=os.path.join(cache_dir_name, "cache"))
95
95
 
96
- def next(self, batch_consumer_fn) -> int: # type: ignore[no-untyped-def]
96
+ def next(self, batch_consumer_fn) -> bool | int: # type: ignore[no-untyped-def]
97
97
  """Advance the iterator by 1 step and pass the data to XGBoost's batch_consumer_fn.
98
98
  This function is called by XGBoost during the construction of ``DMatrix``
99
99
 
@@ -101,7 +101,7 @@ def get_data_iterator(
101
101
  batch_consumer_fn: batch consumer function
102
102
 
103
103
  Returns:
104
- 0 if there is no more data, else 1.
104
+ False/0 if there is no more data, else True/1.
105
105
  """
106
106
  while (self._df is None) or (self._df.shape[0] < self._batch_size):
107
107
  # Read files and append data to temp df until batch size is reached.
@@ -117,7 +117,7 @@ def get_data_iterator(
117
117
 
118
118
  if (self._df is None) or (self._df.shape[0] == 0):
119
119
  # No more data
120
- return 0
120
+ return False
121
121
 
122
122
  # Slice the temp df and save the remainder in the temp df
123
123
  batch_end_index = min(self._batch_size, self._df.shape[0])
@@ -133,8 +133,8 @@ def get_data_iterator(
133
133
  func_args["weight"] = batch_df[self._sample_weight_col].squeeze()
134
134
 
135
135
  batch_consumer_fn(**func_args)
136
- # Return 1 to let XGBoost know we haven't seen all the files yet.
137
- return 1
136
+ # Return True to let XGBoost know we haven't seen all the files yet.
137
+ return True
138
138
 
139
139
  def reset(self) -> None:
140
140
  """Reset the iterator to its beginning"""
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(