snowflake-ml-python 1.14.0__py3-none-any.whl → 1.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. snowflake/ml/_internal/platform_capabilities.py +13 -7
  2. snowflake/ml/_internal/utils/connection_params.py +5 -3
  3. snowflake/ml/_internal/utils/jwt_generator.py +3 -2
  4. snowflake/ml/_internal/utils/mixins.py +24 -9
  5. snowflake/ml/_internal/utils/temp_file_utils.py +1 -2
  6. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +16 -3
  7. snowflake/ml/experiment/_entities/__init__.py +2 -1
  8. snowflake/ml/experiment/_entities/run.py +0 -15
  9. snowflake/ml/experiment/_entities/run_metadata.py +3 -51
  10. snowflake/ml/experiment/experiment_tracking.py +71 -27
  11. snowflake/ml/jobs/_utils/spec_utils.py +49 -11
  12. snowflake/ml/jobs/manager.py +20 -0
  13. snowflake/ml/model/__init__.py +12 -2
  14. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -4
  15. snowflake/ml/model/_client/model/inference_engine_utils.py +55 -0
  16. snowflake/ml/model/_client/model/model_version_impl.py +30 -62
  17. snowflake/ml/model/_client/ops/service_ops.py +68 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  19. snowflake/ml/model/_client/sql/service.py +29 -2
  20. snowflake/ml/model/_client/sql/stage.py +8 -0
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  22. snowflake/ml/model/_model_composer/model_method/model_method.py +25 -2
  23. snowflake/ml/model/_packager/model_env/model_env.py +26 -16
  24. snowflake/ml/model/_packager/model_handlers/_utils.py +4 -2
  25. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -5
  26. snowflake/ml/model/_packager/model_packager.py +4 -3
  27. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  28. snowflake/ml/model/_signatures/utils.py +0 -21
  29. snowflake/ml/model/models/huggingface_pipeline.py +56 -21
  30. snowflake/ml/model/type_hints.py +13 -0
  31. snowflake/ml/model/volatility.py +34 -0
  32. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  33. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  34. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  35. snowflake/ml/modeling/cluster/birch.py +1 -1
  36. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  37. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  38. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  39. snowflake/ml/modeling/cluster/k_means.py +1 -1
  40. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  41. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  42. snowflake/ml/modeling/cluster/optics.py +1 -1
  43. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  44. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  45. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  46. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  47. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  48. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  49. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  50. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  51. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  52. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  53. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  54. snowflake/ml/modeling/covariance/oas.py +1 -1
  55. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  56. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  57. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  58. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  59. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  60. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  61. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  62. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  63. snowflake/ml/modeling/decomposition/pca.py +1 -1
  64. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  65. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  66. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  67. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  68. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  69. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  70. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  71. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  72. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  73. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  74. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  75. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  76. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  78. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  79. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  82. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  83. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  84. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  85. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  86. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  87. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  88. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  89. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  90. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  91. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  94. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  95. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  96. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  97. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  98. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  99. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  100. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  101. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  102. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  103. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  104. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  105. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  106. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  107. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  108. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  109. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  110. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  111. snowflake/ml/modeling/linear_model/lars.py +1 -1
  112. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  113. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  114. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  115. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  116. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  117. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  118. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  119. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  120. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  123. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  124. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  125. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  126. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  127. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  128. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  129. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  130. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  131. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  132. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  133. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  135. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  136. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  137. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  138. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  140. snowflake/ml/modeling/manifold/isomap.py +1 -1
  141. snowflake/ml/modeling/manifold/mds.py +1 -1
  142. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  143. snowflake/ml/modeling/manifold/tsne.py +1 -1
  144. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  145. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  146. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  147. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  148. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  149. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  150. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  151. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  152. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  153. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  154. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  155. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  156. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  157. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  158. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  159. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  160. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  161. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  162. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  163. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  164. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  165. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  166. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  167. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  168. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  169. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  170. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  171. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  172. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  173. snowflake/ml/modeling/svm/svc.py +1 -1
  174. snowflake/ml/modeling/svm/svr.py +1 -1
  175. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  176. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  177. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  178. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  179. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  180. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  181. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  182. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  183. snowflake/ml/registry/_manager/model_manager.py +2 -1
  184. snowflake/ml/registry/_manager/model_parameter_reconciler.py +29 -2
  185. snowflake/ml/registry/registry.py +15 -0
  186. snowflake/ml/utils/authentication.py +16 -0
  187. snowflake/ml/utils/connection_params.py +5 -3
  188. snowflake/ml/version.py +1 -1
  189. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/METADATA +81 -36
  190. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/RECORD +193 -191
  191. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/WHEEL +0 -0
  192. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/licenses/LICENSE.txt +0 -0
  193. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,9 @@
1
+ import contextlib
1
2
  import dataclasses
2
3
  import enum
3
4
  import logging
4
5
  import textwrap
5
- from typing import Any, Optional
6
+ from typing import Any, Generator, Optional
6
7
 
7
8
  from snowflake import snowpark
8
9
  from snowflake.ml._internal.utils import (
@@ -17,6 +18,11 @@ from snowflake.snowpark._internal import utils as snowpark_utils
17
18
 
18
19
  logger = logging.getLogger(__name__)
19
20
 
21
+ # Using this token instead of '?' to avoid escaping issues
22
+ # After quotes are escaped, we replace this token with '|| ? ||'
23
+ QMARK_RESERVED_TOKEN = "<QMARK_RESERVED_TOKEN>"
24
+ QMARK_PARAMETER_TOKEN = "'|| ? ||'"
25
+
20
26
 
21
27
  class ServiceStatus(enum.Enum):
22
28
  PENDING = "PENDING"
@@ -70,12 +76,26 @@ class ServiceSQLClient(_base._BaseSQLClient):
70
76
  CONTAINER_STATUS = "status"
71
77
  MESSAGE = "message"
72
78
 
79
+ @contextlib.contextmanager
80
+ def _qmark_paramstyle(self) -> Generator[None, None, None]:
81
+ """Context manager that temporarily changes paramstyle to qmark and restores original value on exit."""
82
+ if not hasattr(self._session, "_options"):
83
+ yield
84
+ else:
85
+ original_paramstyle = self._session._options["paramstyle"]
86
+ try:
87
+ self._session._options["paramstyle"] = "qmark"
88
+ yield
89
+ finally:
90
+ self._session._options["paramstyle"] = original_paramstyle
91
+
73
92
  def deploy_model(
74
93
  self,
75
94
  *,
76
95
  stage_path: Optional[str] = None,
77
96
  model_deployment_spec_yaml_str: Optional[str] = None,
78
97
  model_deployment_spec_file_rel_path: Optional[str] = None,
98
+ query_params: Optional[list[Any]] = None,
79
99
  statement_params: Optional[dict[str, Any]] = None,
80
100
  ) -> tuple[str, snowpark.AsyncJob]:
81
101
  assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
@@ -83,11 +103,18 @@ class ServiceSQLClient(_base._BaseSQLClient):
83
103
  model_deployment_spec_yaml_str = snowpark_utils.escape_single_quotes(
84
104
  model_deployment_spec_yaml_str
85
105
  ) # type: ignore[no-untyped-call]
106
+ model_deployment_spec_yaml_str = model_deployment_spec_yaml_str.replace( # type: ignore[union-attr]
107
+ QMARK_RESERVED_TOKEN, QMARK_PARAMETER_TOKEN
108
+ )
86
109
  logger.info(f"Deploying model with spec={model_deployment_spec_yaml_str}")
87
110
  sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
88
111
  else:
89
112
  sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
90
- async_job = self._session.sql(sql_str).collect(block=False, statement_params=statement_params)
113
+ with self._qmark_paramstyle():
114
+ async_job = self._session.sql(
115
+ sql_str,
116
+ params=query_params if query_params else None,
117
+ ).collect(block=False, statement_params=statement_params)
91
118
  assert isinstance(async_job, snowpark.AsyncJob)
92
119
  return async_job.query_id, async_job
93
120
 
@@ -2,6 +2,7 @@ from typing import Any, Optional
2
2
 
3
3
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
4
  from snowflake.ml.model._client.sql import _base
5
+ from snowflake.snowpark import Row
5
6
 
6
7
 
7
8
  class StageSQLClient(_base._BaseSQLClient):
@@ -21,3 +22,10 @@ class StageSQLClient(_base._BaseSQLClient):
21
22
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
22
23
 
23
24
  return fq_stage_name
25
+
26
+ def list_stage(self, stage_name: str) -> list[Row]:
27
+ try:
28
+ list_results = self._session.sql(f"LIST {stage_name}").collect()
29
+ except Exception as e:
30
+ raise RuntimeError(f"Failed to check stage location '{stage_name}': {e}")
31
+ return list_results
@@ -46,6 +46,7 @@ class ModelFunctionMethodDict(TypedDict):
46
46
  handler: Required[str]
47
47
  inputs: Required[list[ModelMethodSignatureFieldWithName]]
48
48
  outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
49
+ volatility: NotRequired[str]
49
50
 
50
51
 
51
52
  ModelMethodDict = ModelFunctionMethodDict
@@ -4,6 +4,7 @@ from typing import Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired
6
6
 
7
+ from snowflake.ml._internal import platform_capabilities
7
8
  from snowflake.ml._internal.utils import sql_identifier
8
9
  from snowflake.ml.model import model_signature, type_hints
9
10
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -12,6 +13,7 @@ from snowflake.ml.model._model_composer.model_method import (
12
13
  function_generator,
13
14
  )
14
15
  from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
16
+ from snowflake.ml.model.volatility import Volatility
15
17
  from snowflake.snowpark._internal import type_utils
16
18
 
17
19
 
@@ -20,10 +22,12 @@ class ModelMethodOptions(TypedDict):
20
22
 
21
23
  case_sensitive: Specify when the name of the method should be considered as case sensitive when registered to SQL.
22
24
  function_type: One of `ModelMethodFunctionTypes` specifying function type.
25
+ volatility: One of `Volatility` enum values specifying function volatility.
23
26
  """
24
27
 
25
28
  case_sensitive: NotRequired[bool]
26
29
  function_type: NotRequired[str]
30
+ volatility: NotRequired[Volatility]
27
31
 
28
32
 
29
33
  def get_model_method_options_from_options(
@@ -38,10 +42,19 @@ def get_model_method_options_from_options(
38
42
  if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
39
43
  raise NotImplementedError(f"Function type {function_type} is not supported.")
40
44
 
41
- return ModelMethodOptions(
45
+ default_volatility = options.get("volatility")
46
+ method_volatility = method_option.get("volatility")
47
+ resolved_volatility = method_volatility or default_volatility
48
+
49
+ # Only include volatility if explicitly provided in method options
50
+ result: ModelMethodOptions = ModelMethodOptions(
42
51
  case_sensitive=method_option.get("case_sensitive", False),
43
52
  function_type=function_type,
44
53
  )
54
+ if resolved_volatility:
55
+ result["volatility"] = resolved_volatility
56
+
57
+ return result
45
58
 
46
59
 
47
60
  class ModelMethod:
@@ -94,6 +107,9 @@ class ModelMethod:
94
107
  "function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
95
108
  )
96
109
 
110
+ # Volatility is optional; when not provided, we omit it from the manifest
111
+ self.volatility = self.options.get("volatility")
112
+
97
113
  @staticmethod
98
114
  def _get_method_arg_from_feature(
99
115
  feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False
@@ -148,7 +164,7 @@ class ModelMethod:
148
164
  else:
149
165
  outputs = [model_manifest_schema.ModelMethodSignatureField(type="OBJECT")]
150
166
 
151
- return model_manifest_schema.ModelFunctionMethodDict(
167
+ method_dict = model_manifest_schema.ModelFunctionMethodDict(
152
168
  name=self.method_name.resolved(),
153
169
  runtime=self.runtime_name,
154
170
  type=self.function_type,
@@ -158,3 +174,10 @@ class ModelMethod:
158
174
  inputs=input_list,
159
175
  outputs=outputs,
160
176
  )
177
+ should_set_volatility = (
178
+ platform_capabilities.PlatformCapabilities.get_instance().is_set_module_functions_volatility_from_manifest()
179
+ )
180
+ if should_set_volatility and self.volatility is not None:
181
+ method_dict["volatility"] = self.volatility.name
182
+
183
+ return method_dict
@@ -145,11 +145,12 @@ class ModelEnv:
145
145
  """
146
146
  if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
147
147
  pip_pkg_reqs: list[str] = []
148
- if self.targets_warehouse:
148
+ if self.targets_warehouse and not self.artifact_repository_map:
149
149
  self._warn_once(
150
150
  (
151
151
  "Dependencies specified from pip requirements."
152
152
  " This may prevent model deploying to Snowflake Warehouse."
153
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
153
154
  ),
154
155
  stacklevel=2,
155
156
  )
@@ -177,7 +178,11 @@ class ModelEnv:
177
178
  req_to_add.name = conda_req.name
178
179
  else:
179
180
  req_to_add = conda_req
180
- show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse
181
+ show_warning_message = (
182
+ conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME
183
+ and self.targets_warehouse
184
+ and not self.artifact_repository_map
185
+ )
181
186
 
182
187
  if any(added_pip_req.name == pip_name for added_pip_req in self._pip_requirements):
183
188
  if show_warning_message:
@@ -185,6 +190,7 @@ class ModelEnv:
185
190
  (
186
191
  f"Basic dependency {req_to_add.name} specified from pip requirements."
187
192
  " This may prevent model deploying to Snowflake Warehouse."
193
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
188
194
  ),
189
195
  stacklevel=2,
190
196
  )
@@ -318,13 +324,15 @@ class ModelEnv:
318
324
  )
319
325
 
320
326
  if pip_requirements_list and self.targets_warehouse:
321
- self._warn_once(
322
- (
323
- "Found dependencies specified as pip requirements."
324
- " This may prevent model deploying to Snowflake Warehouse."
325
- ),
326
- stacklevel=2,
327
- )
327
+ if not self.artifact_repository_map:
328
+ self._warn_once(
329
+ (
330
+ "Found dependencies specified as pip requirements."
331
+ " This may prevent model deploying to Snowflake Warehouse."
332
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
333
+ ),
334
+ stacklevel=2,
335
+ )
328
336
  for pip_dependency in pip_requirements_list:
329
337
  if any(
330
338
  channel_dependency.name == pip_dependency.name
@@ -343,13 +351,15 @@ class ModelEnv:
343
351
  pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
344
352
 
345
353
  if pip_requirements_list and self.targets_warehouse:
346
- self._warn_once(
347
- (
348
- "Found dependencies specified as pip requirements."
349
- " This may prevent model deploying to Snowflake Warehouse."
350
- ),
351
- stacklevel=2,
352
- )
354
+ if not self.artifact_repository_map:
355
+ self._warn_once(
356
+ (
357
+ "Found dependencies specified as pip requirements."
358
+ " This may prevent model deploying to Snowflake Warehouse."
359
+ " Use 'artifact_repository_map' to deploy the model to Warehouse."
360
+ ),
361
+ stacklevel=2,
362
+ )
353
363
  for pip_dependency in pip_requirements_list:
354
364
  if any(
355
365
  channel_dependency.name == pip_dependency.name
@@ -1,5 +1,6 @@
1
1
  import importlib
2
2
  import json
3
+ import logging
3
4
  import os
4
5
  import pathlib
5
6
  import warnings
@@ -8,7 +9,6 @@ from typing import Any, Callable, Iterable, Optional, Sequence, cast
8
9
  import numpy as np
9
10
  import numpy.typing as npt
10
11
  import pandas as pd
11
- from absl import logging
12
12
 
13
13
  import snowflake.snowpark.dataframe as sp_df
14
14
  from snowflake.ml._internal import env
@@ -23,6 +23,8 @@ from snowflake.ml.model._signatures import (
23
23
  )
24
24
  from snowflake.snowpark import DataFrame as SnowparkDataFrame
25
25
 
26
+ logger = logging.getLogger(__name__)
27
+
26
28
  EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT = 1000
27
29
 
28
30
 
@@ -257,7 +259,7 @@ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task
257
259
  )
258
260
  return inferred_model_task
259
261
  elif inferred_model_task != model_types.Task.UNKNOWN:
260
- logging.info(f"Inferred Task: {inferred_model_task.name} is used as task for this model " f"version")
262
+ logger.info(f"Inferred Task: {inferred_model_task.name} is used as task for this model " f"version")
261
263
  return inferred_model_task
262
264
  return passed_model_task
263
265
 
@@ -43,7 +43,6 @@ DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message
43
43
  def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
44
44
  # Text
45
45
  if task in [
46
- "conversational",
47
46
  "fill-mask",
48
47
  "ner",
49
48
  "token-classification",
@@ -521,6 +520,7 @@ class HuggingFacePipelineHandler(
521
520
  input_data = X[signature.inputs[0].name].to_list()
522
521
  temp_res = getattr(raw_model, target_method)(input_data)
523
522
  else:
523
+ # TODO: remove conversational pipeline code
524
524
  # For others, we could offer the whole dataframe as a list.
525
525
  # Some of them may need some conversion
526
526
  if hasattr(transformers, "ConversationalPipeline") and isinstance(
@@ -759,11 +759,13 @@ class HuggingFaceOpenAICompatibleModel:
759
759
  eos_token_id=self.tokenizer.eos_token_id,
760
760
  stop_strings=stop_strings,
761
761
  stream=stream,
762
- repetition_penalty=frequency_penalty,
763
- diversity_penalty=presence_penalty if n > 1 else None,
764
762
  num_return_sequences=n,
765
- num_beams=max(2, n), # must be >1
766
- num_beam_groups=max(2, n) if presence_penalty else 1,
763
+ num_beams=max(1, n), # must be >1
764
+ repetition_penalty=frequency_penalty,
765
+ # TODO: Handle diversity_penalty and num_beam_groups
766
+ # not all models support them making it hard to support any huggingface model
767
+ # diversity_penalty=presence_penalty if n > 1 else None,
768
+ # num_beam_groups=max(2, n) if presence_penalty else 1,
767
769
  do_sample=False,
768
770
  )
769
771
 
@@ -1,9 +1,8 @@
1
+ import logging
1
2
  import os
2
3
  from types import ModuleType
3
4
  from typing import Optional
4
5
 
5
- from absl import logging
6
-
7
6
  from snowflake.ml._internal.exceptions import (
8
7
  error_codes,
9
8
  exceptions as snowml_exceptions,
@@ -12,6 +11,8 @@ from snowflake.ml.model import custom_model, model_signature, type_hints as mode
12
11
  from snowflake.ml.model._packager import model_handler
13
12
  from snowflake.ml.model._packager.model_meta import model_meta
14
13
 
14
+ logger = logging.getLogger(__name__)
15
+
15
16
 
16
17
  class ModelPackager:
17
18
  """Top-level class to save/load and manage a Snowflake Native formatted model.
@@ -96,7 +97,7 @@ class ModelPackager:
96
97
  **options,
97
98
  )
98
99
  if signatures is None:
99
- logging.info(f"Model signatures are auto inferred as:\n\n{meta.signatures}")
100
+ logger.info(f"Model signatures are auto inferred as:\n\n{meta.signatures}")
100
101
 
101
102
  self.model = model
102
103
  self.meta = meta
@@ -2,7 +2,6 @@
2
2
  # Generate by running 'bazel run --config=pre_build //bazel/requirements:sync_requirements'
3
3
 
4
4
  REQUIREMENTS = [
5
- "absl-py>=0.15,<2",
6
5
  "aiohttp!=4.0.0a0, !=4.0.0a1",
7
6
  "anyio>=3.5.0,<5",
8
7
  "cachetools>=3.1.1,<6",
@@ -22,7 +21,7 @@ REQUIREMENTS = [
22
21
  "requests",
23
22
  "retrying>=1.3.3,<2",
24
23
  "s3fs>=2024.6.1,<2026",
25
- "scikit-learn<1.7",
24
+ "scikit-learn<1.8",
26
25
  "scipy>=1.9,<2",
27
26
  "shap>=0.46.0,<1",
28
27
  "snowflake-connector-python>=3.16.0,<4",
@@ -110,27 +110,6 @@ def huggingface_pipeline_signature_auto_infer(
110
110
  ) -> Optional[core.ModelSignature]:
111
111
  # Text
112
112
 
113
- # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
114
- # Needs to convert to conversation object.
115
- if task == "conversational":
116
- warnings.warn(
117
- (
118
- "Conversational pipeline is removed from transformers since 4.42.0. "
119
- "Support will be removed from snowflake-ml-python soon."
120
- ),
121
- category=DeprecationWarning,
122
- stacklevel=1,
123
- )
124
- return core.ModelSignature(
125
- inputs=[
126
- core.FeatureSpec(name="user_inputs", dtype=core.DataType.STRING, shape=(-1,)),
127
- core.FeatureSpec(name="generated_responses", dtype=core.DataType.STRING, shape=(-1,)),
128
- ],
129
- outputs=[
130
- core.FeatureSpec(name="generated_responses", dtype=core.DataType.STRING, shape=(-1,)),
131
- ],
132
- )
133
-
134
113
  # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.TokenClassificationPipeline
135
114
  if task == "fill-mask":
136
115
  return core.ModelSignature(
@@ -8,6 +8,7 @@ from snowflake import snowpark
8
8
  from snowflake.ml._internal import telemetry
9
9
  from snowflake.ml._internal.human_readable_id import hrid_generator
10
10
  from snowflake.ml._internal.utils import sql_identifier
11
+ from snowflake.ml.model._client.model import inference_engine_utils
11
12
  from snowflake.ml.model._client.ops import service_ops
12
13
  from snowflake.snowpark import async_job, session
13
14
 
@@ -77,6 +78,15 @@ class HuggingFacePipelineModel:
77
78
  framework = kwargs.get("framework", None)
78
79
  feature_extractor = kwargs.get("feature_extractor", None)
79
80
 
81
+ _can_download_snapshot = False
82
+ if download_snapshot:
83
+ try:
84
+ import huggingface_hub as hf_hub
85
+
86
+ _can_download_snapshot = True
87
+ except ImportError:
88
+ pass
89
+
80
90
  # ==== Start pipeline logic from transformers ====
81
91
  if model_kwargs is None:
82
92
  model_kwargs = {}
@@ -141,22 +151,23 @@ class HuggingFacePipelineModel:
141
151
  # Instantiate config if needed
142
152
  config_obj = None
143
153
 
144
- if isinstance(config, str):
145
- config_obj = transformers.AutoConfig.from_pretrained(
146
- config, _from_pipeline=task, **hub_kwargs, **model_kwargs
147
- )
148
- hub_kwargs["_commit_hash"] = config_obj._commit_hash
149
- elif config is None and isinstance(model, str):
150
- config_obj = transformers.AutoConfig.from_pretrained(
151
- model, _from_pipeline=task, **hub_kwargs, **model_kwargs
152
- )
153
- hub_kwargs["_commit_hash"] = config_obj._commit_hash
154
- # We only support string as config argument.
155
- elif config is not None and not isinstance(config, str):
156
- raise RuntimeError(
157
- "Impossible to use non-string config as input for HuggingFacePipelineModel. Use transformers.Pipeline"
158
- " object if required."
159
- )
154
+ if not _can_download_snapshot:
155
+ if isinstance(config, str):
156
+ config_obj = transformers.AutoConfig.from_pretrained(
157
+ config, _from_pipeline=task, **hub_kwargs, **model_kwargs
158
+ )
159
+ hub_kwargs["_commit_hash"] = config_obj._commit_hash
160
+ elif config is None and isinstance(model, str):
161
+ config_obj = transformers.AutoConfig.from_pretrained(
162
+ model, _from_pipeline=task, **hub_kwargs, **model_kwargs
163
+ )
164
+ hub_kwargs["_commit_hash"] = config_obj._commit_hash
165
+ # We only support string as config argument.
166
+ elif config is not None and not isinstance(config, str):
167
+ raise RuntimeError(
168
+ "Impossible to use non-string config as input for HuggingFacePipelineModel. "
169
+ "Use transformers.Pipeline object if required."
170
+ )
160
171
 
161
172
  # ==== Start pipeline logic (Task) from transformers ====
162
173
 
@@ -208,7 +219,7 @@ class HuggingFacePipelineModel:
208
219
  "Using a pipeline without specifying a model name and revision in production is not recommended.",
209
220
  stacklevel=2,
210
221
  )
211
- if config is None and isinstance(model, str):
222
+ if not _can_download_snapshot and config is None and isinstance(model, str):
212
223
  config_obj = transformers.AutoConfig.from_pretrained(
213
224
  model, _from_pipeline=task, **hub_kwargs, **model_kwargs
214
225
  )
@@ -228,11 +239,10 @@ class HuggingFacePipelineModel:
228
239
  )
229
240
 
230
241
  repo_snapshot_dir: Optional[str] = None
231
- if download_snapshot:
242
+ if _can_download_snapshot:
232
243
  try:
233
- from huggingface_hub import snapshot_download
234
244
 
235
- repo_snapshot_dir = snapshot_download(
245
+ repo_snapshot_dir = hf_hub.snapshot_download(
236
246
  repo_id=model,
237
247
  revision=revision,
238
248
  token=token,
@@ -268,7 +278,7 @@ class HuggingFacePipelineModel:
268
278
  ],
269
279
  )
270
280
  @snowpark._internal.utils.private_preview(version="1.9.1")
271
- def create_service(
281
+ def log_model_and_create_service(
272
282
  self,
273
283
  *,
274
284
  session: session.Session,
@@ -293,6 +303,7 @@ class HuggingFacePipelineModel:
293
303
  force_rebuild: bool = False,
294
304
  build_external_access_integrations: Optional[list[str]] = None,
295
305
  block: bool = True,
306
+ experimental_options: Optional[dict[str, Any]] = None,
296
307
  ) -> Union[str, async_job.AsyncJob]:
297
308
  """Logs a Hugging Face model and creates a service in Snowflake.
298
309
 
@@ -319,6 +330,10 @@ class HuggingFacePipelineModel:
319
330
  force_rebuild: Whether to force rebuild the image. Defaults to False.
320
331
  build_external_access_integrations: External access integrations for building the image. Defaults to None.
321
332
  block: Whether to block the operation. Defaults to True.
333
+ experimental_options: Experimental options for the service creation with custom inference engine.
334
+ Currently, only `inference_engine` and `inference_engine_args_override` are supported.
335
+ `inference_engine` is the name of the inference engine to use.
336
+ `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
322
337
 
323
338
  Raises:
324
339
  ValueError: if database and schema name is not provided and session doesn't have a
@@ -360,6 +375,24 @@ class HuggingFacePipelineModel:
360
375
  )
361
376
  logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")
362
377
 
378
+ # Check if model is HuggingFace text-generation before doing inference engine checks
379
+ inference_engine_args = None
380
+ if experimental_options:
381
+ if self.task != "text-generation":
382
+ raise ValueError(
383
+ "Currently, InferenceEngine using experimental_options is only supported for "
384
+ "HuggingFace text-generation models."
385
+ )
386
+
387
+ inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)
388
+
389
+ # Enrich inference engine args if inference engine is specified
390
+ if inference_engine_args is not None:
391
+ inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
392
+ inference_engine_args,
393
+ gpu_requests,
394
+ )
395
+
363
396
  from snowflake.ml.model import event_handler
364
397
  from snowflake.snowpark import exceptions
365
398
 
@@ -412,6 +445,8 @@ class HuggingFacePipelineModel:
412
445
  # TODO: remove warehouse in the next release
413
446
  warehouse=session.get_current_warehouse(),
414
447
  ),
448
+ # inference engine
449
+ inference_engine_args=inference_engine_args,
415
450
  )
416
451
  status.update(label="HuggingFace model service created successfully", state="complete", expanded=False)
417
452
  return result
@@ -15,6 +15,7 @@ from typing_extensions import NotRequired
15
15
 
16
16
  from snowflake.ml.model.target_platform import TargetPlatform
17
17
  from snowflake.ml.model.task import Task
18
+ from snowflake.ml.model.volatility import Volatility
18
19
 
19
20
  if TYPE_CHECKING:
20
21
  import catboost
@@ -150,6 +151,7 @@ class ModelMethodSaveOptions(TypedDict):
150
151
  case_sensitive: NotRequired[bool]
151
152
  max_batch_size: NotRequired[int]
152
153
  function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
154
+ volatility: NotRequired[Volatility]
153
155
 
154
156
 
155
157
  class BaseModelSaveOption(TypedDict):
@@ -158,12 +160,23 @@ class BaseModelSaveOption(TypedDict):
158
160
  embed_local_ml_library: Embedding local SnowML into the code directory of the folder.
159
161
  relax_version: Whether or not relax the version constraints of the dependencies if unresolvable in Warehouse.
160
162
  It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
163
+ function_type: Set the method function type globally. To set method function types individually see
164
+ function_type in method_options.
165
+ volatility: Set the volatility for all model methods globally. To set volatility for individual methods
166
+ see volatility in method_options. Defaults are set automatically based on model type: supported
167
+ models (sklearn, xgboost, pytorch, huggingface_pipeline, mlflow, etc.) default to IMMUTABLE, while
168
+ custom models default to VOLATILE. When both global volatility and per-method volatility are specified,
169
+ the per-method volatility takes precedence.
170
+ method_options: Per-method saving options. This dictionary has method names as keys and dictionary
171
+ values with the desired options.
172
+ enable_explainability: Whether to enable explainability features for the model.
161
173
  save_location: Local directory path to save the model and metadata.
162
174
  """
163
175
 
164
176
  embed_local_ml_library: NotRequired[bool]
165
177
  relax_version: NotRequired[bool]
166
178
  function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
179
+ volatility: NotRequired[Volatility]
167
180
  method_options: NotRequired[dict[str, ModelMethodSaveOptions]]
168
181
  enable_explainability: NotRequired[bool]
169
182
  save_location: NotRequired[str]
@@ -0,0 +1,34 @@
1
+ """Volatility definitions for model functions."""
2
+
3
+ from enum import Enum, auto
4
+
5
+
6
+ class Volatility(Enum):
7
+ """Volatility levels for model functions.
8
+
9
+ Attributes:
10
+ VOLATILE: Function results may change between calls with the same arguments.
11
+ Use this for functions that depend on external data or have non-deterministic behavior.
12
+ IMMUTABLE: Function results are guaranteed to be the same for the same arguments.
13
+ Use this for pure functions that always return the same output for the same input.
14
+ """
15
+
16
+ VOLATILE = auto()
17
+ IMMUTABLE = auto()
18
+
19
+
20
+ DEFAULT_VOLATILITY_BY_MODEL_TYPE = {
21
+ "catboost": Volatility.IMMUTABLE,
22
+ "custom": Volatility.VOLATILE,
23
+ "huggingface_pipeline": Volatility.IMMUTABLE,
24
+ "keras": Volatility.IMMUTABLE,
25
+ "lightgbm": Volatility.IMMUTABLE,
26
+ "mlflow": Volatility.IMMUTABLE,
27
+ "pytorch": Volatility.IMMUTABLE,
28
+ "sentence_transformers": Volatility.IMMUTABLE,
29
+ "sklearn": Volatility.IMMUTABLE,
30
+ "snowml": Volatility.IMMUTABLE,
31
+ "tensorflow": Volatility.IMMUTABLE,
32
+ "torchscript": Volatility.IMMUTABLE,
33
+ "xgboost": Volatility.IMMUTABLE,
34
+ }
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
- SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
64
64
  # Modeling library estimators require a smaller sklearn version range.
65
65
  if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
66
  raise Exception(