snowflake-ml-python 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205) hide show
  1. snowflake/cortex/_complete.py +3 -2
  2. snowflake/ml/_internal/utils/service_logger.py +26 -1
  3. snowflake/ml/experiment/_client/artifact.py +76 -0
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
  5. snowflake/ml/experiment/callback/keras.py +63 -0
  6. snowflake/ml/experiment/callback/lightgbm.py +5 -1
  7. snowflake/ml/experiment/callback/xgboost.py +5 -1
  8. snowflake/ml/experiment/experiment_tracking.py +89 -4
  9. snowflake/ml/feature_store/feature_store.py +1150 -131
  10. snowflake/ml/feature_store/feature_view.py +122 -0
  11. snowflake/ml/jobs/_utils/__init__.py +0 -0
  12. snowflake/ml/jobs/_utils/constants.py +9 -14
  13. snowflake/ml/jobs/_utils/feature_flags.py +16 -0
  14. snowflake/ml/jobs/_utils/payload_utils.py +61 -19
  15. snowflake/ml/jobs/_utils/query_helper.py +5 -1
  16. snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
  17. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
  18. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +15 -7
  19. snowflake/ml/jobs/_utils/spec_utils.py +44 -13
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -9
  21. snowflake/ml/jobs/_utils/types.py +7 -8
  22. snowflake/ml/jobs/job.py +34 -18
  23. snowflake/ml/jobs/manager.py +107 -24
  24. snowflake/ml/model/__init__.py +6 -1
  25. snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
  26. snowflake/ml/model/_client/model/model_version_impl.py +225 -73
  27. snowflake/ml/model/_client/ops/service_ops.py +128 -174
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +123 -64
  29. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -9
  30. snowflake/ml/model/_model_composer/model_composer.py +1 -70
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
  32. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
  33. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
  34. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  35. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  36. snowflake/ml/model/_signatures/utils.py +4 -2
  37. snowflake/ml/model/inference_engine.py +5 -0
  38. snowflake/ml/model/models/huggingface_pipeline.py +4 -3
  39. snowflake/ml/model/openai_signatures.py +57 -0
  40. snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
  41. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
  42. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  43. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  44. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  45. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  46. snowflake/ml/modeling/cluster/birch.py +1 -1
  47. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  48. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  49. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  50. snowflake/ml/modeling/cluster/k_means.py +1 -1
  51. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  52. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  53. snowflake/ml/modeling/cluster/optics.py +1 -1
  54. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  55. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  56. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  57. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  60. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  61. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  62. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  63. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  64. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  65. snowflake/ml/modeling/covariance/oas.py +1 -1
  66. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  67. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  68. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  69. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  70. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  71. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  72. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  73. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  74. snowflake/ml/modeling/decomposition/pca.py +1 -1
  75. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  76. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  77. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  78. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  79. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  82. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  83. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  84. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  85. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  86. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  88. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  89. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  90. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  91. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  92. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  93. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  94. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  95. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  96. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  97. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  98. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  99. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  100. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  101. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  102. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  105. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  106. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  107. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  116. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  118. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  122. snowflake/ml/modeling/linear_model/lars.py +1 -1
  123. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  124. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  129. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  142. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  151. snowflake/ml/modeling/manifold/isomap.py +1 -1
  152. snowflake/ml/modeling/manifold/mds.py +1 -1
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  154. snowflake/ml/modeling/manifold/tsne.py +1 -1
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  157. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  158. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  159. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  160. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  161. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  162. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  163. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  164. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  165. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  166. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  167. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  168. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  169. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  170. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  171. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  172. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  173. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  174. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  175. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  176. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  177. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  178. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  179. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  180. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  181. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  182. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  183. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  184. snowflake/ml/modeling/svm/svc.py +1 -1
  185. snowflake/ml/modeling/svm/svr.py +1 -1
  186. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  187. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  188. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  189. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  190. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  191. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  192. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  193. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  194. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
  195. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  196. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  197. snowflake/ml/monitoring/model_monitor.py +26 -0
  198. snowflake/ml/registry/_manager/model_manager.py +7 -35
  199. snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
  200. snowflake/ml/version.py +1 -1
  201. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +87 -7
  202. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +205 -197
  203. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
  204. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
@@ -10,10 +10,15 @@ class Model(BaseModel):
10
10
  version: str
11
11
 
12
12
 
13
+ class InferenceEngineSpec(BaseModel):
14
+ inference_engine_name: str
15
+ inference_engine_args: Optional[list[str]] = None
16
+
17
+
13
18
  class ImageBuild(BaseModel):
14
- compute_pool: str
15
- image_repo: str
16
- force_rebuild: bool
19
+ compute_pool: Optional[str] = None
20
+ image_repo: Optional[str] = None
21
+ force_rebuild: Optional[bool] = None
17
22
  external_access_integrations: Optional[list[str]] = None
18
23
 
19
24
 
@@ -27,6 +32,17 @@ class Service(BaseModel):
27
32
  gpu: Optional[str] = None
28
33
  num_workers: Optional[int] = None
29
34
  max_batch_rows: Optional[int] = None
35
+ inference_engine_spec: Optional[InferenceEngineSpec] = None
36
+
37
+
38
+ class Input(BaseModel):
39
+ input_stage_location: str
40
+ input_file_pattern: str
41
+
42
+
43
+ class Output(BaseModel):
44
+ output_stage_location: str
45
+ completion_filename: str
30
46
 
31
47
 
32
48
  class Job(BaseModel):
@@ -37,10 +53,10 @@ class Job(BaseModel):
37
53
  gpu: Optional[str] = None
38
54
  num_workers: Optional[int] = None
39
55
  max_batch_rows: Optional[int] = None
40
- warehouse: str
41
- target_method: str
42
- input_table_name: str
43
- output_table_name: str
56
+ warehouse: Optional[str] = None
57
+ function_name: str
58
+ input: Input
59
+ output: Output
44
60
 
45
61
 
46
62
  class LogModelArgs(BaseModel):
@@ -68,13 +84,13 @@ class ModelLogging(BaseModel):
68
84
 
69
85
  class ModelServiceDeploymentSpec(BaseModel):
70
86
  models: list[Model]
71
- image_build: ImageBuild
87
+ image_build: Optional[ImageBuild] = None
72
88
  service: Service
73
89
  model_loggings: Optional[list[ModelLogging]] = None
74
90
 
75
91
 
76
92
  class ModelJobDeploymentSpec(BaseModel):
77
93
  models: list[Model]
78
- image_build: ImageBuild
94
+ image_build: Optional[ImageBuild] = None
79
95
  job: Job
80
96
  model_loggings: Optional[list[ModelLogging]] = None
@@ -1,17 +1,12 @@
1
1
  import pathlib
2
2
  import tempfile
3
3
  import uuid
4
- import warnings
5
4
  from types import ModuleType
6
5
  from typing import TYPE_CHECKING, Any, Optional, Union
7
6
  from urllib import parse
8
7
 
9
- from absl import logging
10
- from packaging import requirements
11
-
12
8
  from snowflake import snowpark
13
- from snowflake.ml import version as snowml_version
14
- from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
9
+ from snowflake.ml._internal import file_utils
15
10
  from snowflake.ml._internal.lineage import lineage_utils
16
11
  from snowflake.ml.data import data_source
17
12
  from snowflake.ml.model import model_signature, type_hints as model_types
@@ -19,7 +14,6 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest
19
14
  from snowflake.ml.model._packager import model_packager
20
15
  from snowflake.ml.model._packager.model_meta import model_meta
21
16
  from snowflake.snowpark import Session
22
- from snowflake.snowpark._internal import utils as snowpark_utils
23
17
 
24
18
  if TYPE_CHECKING:
25
19
  from snowflake.ml.experiment._experiment_info import ExperimentInfo
@@ -142,73 +136,10 @@ class ModelComposer:
142
136
  experiment_info: Optional["ExperimentInfo"] = None,
143
137
  options: Optional[model_types.ModelSaveOption] = None,
144
138
  ) -> model_meta.ModelMetadata:
145
- # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
146
- conda_dep_dict = env_utils.validate_conda_dependency_string_list(
147
- conda_dependencies if conda_dependencies else []
148
- )
149
-
150
- enable_explainability = None
151
-
152
- if options:
153
- enable_explainability = options.get("enable_explainability", None)
154
-
155
- # skip everything if user said False explicitly
156
- if enable_explainability is None or enable_explainability is True:
157
- is_warehouse_runnable = (
158
- not conda_dep_dict
159
- or all(
160
- chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
161
- for chan in conda_dep_dict
162
- )
163
- ) and (not pip_requirements)
164
-
165
- only_spcs = (
166
- target_platforms
167
- and len(target_platforms) == 1
168
- and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
169
- )
170
- if only_spcs or (not is_warehouse_runnable):
171
- # if only SPCS and user asked for explainability we fail
172
- if enable_explainability is True:
173
- raise ValueError(
174
- "`enable_explainability` cannot be set to True when the model is not runnable in WH "
175
- "or the target platforms include SPCS."
176
- )
177
- elif not options: # explicitly set flag to false in these cases if not specified
178
- options = model_types.BaseModelSaveOption()
179
- options["enable_explainability"] = False
180
- elif (
181
- target_platforms
182
- and len(target_platforms) > 1
183
- and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
184
- ): # if both then only available for WH
185
- if enable_explainability is True:
186
- warnings.warn(
187
- ("Explain function will only be available for model deployed to warehouse."),
188
- category=UserWarning,
189
- stacklevel=2,
190
- )
191
139
 
192
140
  if not options:
193
141
  options = model_types.BaseModelSaveOption()
194
142
 
195
- if not snowpark_utils.is_in_stored_procedure() and target_platforms != [ # type: ignore[no-untyped-call]
196
- model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES # no information schema check for SPCS-only models
197
- ]:
198
- snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
199
- self.session,
200
- reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
201
- python_version=python_version or snowml_env.PYTHON_VERSION,
202
- statement_params=self._statement_params,
203
- ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
204
-
205
- if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
206
- logging.info(
207
- f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
208
- " which is not available in the Snowflake server, embedding local ML library automatically."
209
- )
210
- options["embed_local_ml_library"] = True
211
-
212
143
  model_metadata: model_meta.ModelMetadata = self.packager.save(
213
144
  name=name,
214
145
  model=model,
@@ -1,13 +1,11 @@
1
1
  import collections
2
2
  import logging
3
3
  import pathlib
4
- import warnings
5
4
  from typing import TYPE_CHECKING, Optional, cast
6
5
 
7
6
  import yaml
8
7
 
9
8
  from snowflake.ml._internal import env_utils
10
- from snowflake.ml._internal.exceptions import error_codes, exceptions
11
9
  from snowflake.ml.data import data_source
12
10
  from snowflake.ml.model import type_hints
13
11
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -55,47 +53,8 @@ class ModelManifest:
55
53
  experiment_info: Optional["ExperimentInfo"] = None,
56
54
  target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
57
55
  ) -> None:
58
- if options is None:
59
- options = {}
60
-
61
- has_pip_requirements = len(model_meta.env.pip_requirements) > 0
62
- only_spcs = (
63
- target_platforms
64
- and len(target_platforms) == 1
65
- and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
66
- )
67
-
68
- if "relax_version" not in options:
69
- if has_pip_requirements or only_spcs:
70
- logger.info(
71
- "Setting `relax_version=False` as this model will run in Snowpark Container Services "
72
- "or in Warehouse with a specified artifact_repository_map where exact version "
73
- " specifications will be honored."
74
- )
75
- relax_version = False
76
- else:
77
- warnings.warn(
78
- (
79
- "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
80
- " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
81
- " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
82
- ),
83
- category=UserWarning,
84
- stacklevel=2,
85
- )
86
- relax_version = True
87
- options["relax_version"] = relax_version
88
- else:
89
- relax_version = options.get("relax_version", True)
90
- if relax_version and (has_pip_requirements or only_spcs):
91
- raise exceptions.SnowflakeMLException(
92
- error_code=error_codes.INVALID_ARGUMENT,
93
- original_exception=ValueError(
94
- "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
95
- "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
96
- "targeting only Snowpark Container Services."
97
- ),
98
- )
56
+ assert options is not None, "ModelParameterReconciler should have set options with relax_version"
57
+ relax_version = options["relax_version"]
99
58
 
100
59
  runtime_to_use = model_runtime.ModelRuntime(
101
60
  name=self._DEFAULT_RUNTIME_NAME,
@@ -1,6 +1,8 @@
1
1
  import json
2
2
  import logging
3
3
  import os
4
+ import time
5
+ import uuid
4
6
  import warnings
5
7
  from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
6
8
 
@@ -11,7 +13,12 @@ from packaging import version
11
13
  from typing_extensions import TypeGuard, Unpack
12
14
 
13
15
  from snowflake.ml._internal import type_utils
14
- from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
16
+ from snowflake.ml.model import (
17
+ custom_model,
18
+ model_signature,
19
+ openai_signatures,
20
+ type_hints as model_types,
21
+ )
15
22
  from snowflake.ml.model._packager.model_env import model_env
16
23
  from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
17
24
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
@@ -151,7 +158,10 @@ class HuggingFacePipelineHandler(
151
158
  assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
152
159
  params = {**model.__dict__, **model.model_kwargs}
153
160
 
154
- inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(task, params=params)
161
+ inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(
162
+ task,
163
+ params=params,
164
+ )
155
165
 
156
166
  if not is_sub_model:
157
167
  target_methods = handlers_utils.get_target_methods(
@@ -401,6 +411,34 @@ class HuggingFacePipelineHandler(
401
411
  ),
402
412
  axis=1,
403
413
  ).to_list()
414
+ elif raw_model.task == "text-generation":
415
+ # verify when the target method is __call__ and
416
+ # if the signature is default text-generation signature
417
+ # then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
418
+ if signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC:
419
+ wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)
420
+
421
+ temp_res = X.apply(
422
+ lambda row: wrapped_model.generate_chat_completion(
423
+ messages=row["messages"],
424
+ max_completion_tokens=row.get("max_completion_tokens", None),
425
+ temperature=row.get("temperature", None),
426
+ stop_strings=row.get("stop", None),
427
+ n=row.get("n", 1),
428
+ stream=row.get("stream", False),
429
+ top_p=row.get("top_p", 1.0),
430
+ frequency_penalty=row.get("frequency_penalty", None),
431
+ presence_penalty=row.get("presence_penalty", None),
432
+ ),
433
+ axis=1,
434
+ ).to_list()
435
+ else:
436
+ if len(signature.inputs) > 1:
437
+ input_data = X.to_dict("records")
438
+ # If it is only expecting one argument, Then it is expecting a list of something.
439
+ else:
440
+ input_data = X[signature.inputs[0].name].to_list()
441
+ temp_res = getattr(raw_model, target_method)(input_data)
404
442
  else:
405
443
  # For others, we could offer the whole dataframe as a list.
406
444
  # Some of them may need some conversion
@@ -527,3 +565,170 @@ class HuggingFacePipelineHandler(
527
565
  hg_pipe_model = _HFPipelineModel(custom_model.ModelContext())
528
566
 
529
567
  return hg_pipe_model
568
+
569
+
570
+ class HuggingFaceOpenAICompatibleModel:
571
+ """
572
+ A class to wrap a Hugging Face text generation model and provide an
573
+ OpenAI-compatible chat completion interface.
574
+ """
575
+
576
+ def __init__(self, pipeline: "transformers.Pipeline") -> None:
577
+ """
578
+ Initializes the model and tokenizer.
579
+
580
+ Args:
581
+ pipeline (transformers.pipeline): The Hugging Face pipeline to wrap.
582
+ """
583
+
584
+ self.pipeline = pipeline
585
+ self.model = self.pipeline.model
586
+ self.tokenizer = self.pipeline.tokenizer
587
+
588
+ self.model_name = self.pipeline.model.name_or_path
589
+
590
+ def _apply_chat_template(self, messages: list[dict[str, Any]]) -> str:
591
+ """
592
+ Applies a chat template to a list of messages.
593
+ If the tokenizer has a chat template, it uses that.
594
+ Otherwise, it falls back to a simple concatenation.
595
+
596
+ Args:
597
+ messages (list[dict]): A list of message dictionaries, e.g.,
598
+ [{"role": "user", "content": "Hello!"}, ...]
599
+
600
+ Returns:
601
+ The formatted prompt string ready for model input.
602
+ """
603
+
604
+ if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
605
+ # Use the tokenizer's built-in chat template if available
606
+ # `tokenize=False` means it returns a string, not token IDs
607
+ return self.tokenizer.apply_chat_template( # type: ignore[no-any-return]
608
+ messages,
609
+ tokenize=False,
610
+ add_generation_prompt=True,
611
+ )
612
+ else:
613
+ # Fallback to a simple concatenation for models without a specific chat template
614
+ # This is a basic example; real chat models often need specific formatting.
615
+ prompt = ""
616
+ for message in messages:
617
+ role = message.get("role", "user")
618
+ content = message.get("content", "")
619
+ if role == "system":
620
+ prompt += f"System: {content}\n"
621
+ elif role == "user":
622
+ prompt += f"User: {content}\n"
623
+ elif role == "assistant":
624
+ prompt += f"Assistant: {content}\n"
625
+ prompt += "Assistant:" # Indicate that the assistant should respond
626
+ return prompt
627
+
628
+ def generate_chat_completion(
629
+ self,
630
+ messages: list[dict[str, Any]],
631
+ max_completion_tokens: Optional[int] = None,
632
+ stream: Optional[bool] = False,
633
+ stop_strings: Optional[list[str]] = None,
634
+ temperature: Optional[float] = None,
635
+ top_p: Optional[float] = None,
636
+ frequency_penalty: Optional[float] = None,
637
+ presence_penalty: Optional[float] = None,
638
+ n: int = 1,
639
+ ) -> dict[str, Any]:
640
+ """
641
+ Generates a chat completion response in an OpenAI-compatible format.
642
+
643
+ Args:
644
+ messages (list[dict]): A list of message dictionaries, e.g.,
645
+ [{"role": "system", "content": "You are a helpful assistant."},
646
+ {"role": "user", "content": "What is deep learning?"}]
647
+ max_completion_tokens (int): The maximum number of completion tokens to generate.
648
+ stop_strings (list[str]): A list of strings to stop generation.
649
+ temperature (float): The temperature for sampling.
650
+ top_p (float): The top-p value for sampling.
651
+ stream (bool): Whether to stream the generation.
652
+ frequency_penalty (float): The frequency penalty for sampling.
653
+ presence_penalty (float): The presence penalty for sampling.
654
+ n (int): The number of samples to generate.
655
+
656
+ Returns:
657
+ dict: An OpenAI-compatible dictionary representing the chat completion.
658
+ """
659
+ # Apply chat template to convert messages into a single prompt string
660
+
661
+ prompt_text = self._apply_chat_template(messages)
662
+
663
+ # Tokenize the prompt
664
+ inputs = self.tokenizer(
665
+ prompt_text,
666
+ return_tensors="pt",
667
+ padding=True,
668
+ )
669
+ prompt_tokens = inputs.input_ids.shape[1]
670
+
671
+ from transformers import GenerationConfig
672
+
673
+ generation_config = GenerationConfig(
674
+ max_new_tokens=max_completion_tokens,
675
+ temperature=temperature,
676
+ top_p=top_p,
677
+ pad_token_id=self.tokenizer.pad_token_id,
678
+ eos_token_id=self.tokenizer.eos_token_id,
679
+ stop_strings=stop_strings,
680
+ stream=stream,
681
+ repetition_penalty=frequency_penalty,
682
+ diversity_penalty=presence_penalty if n > 1 else None,
683
+ num_return_sequences=n,
684
+ num_beams=max(2, n), # must be >1
685
+ num_beam_groups=max(2, n) if presence_penalty else 1,
686
+ )
687
+
688
+ # Generate text
689
+ output_ids = self.model.generate(
690
+ inputs.input_ids,
691
+ attention_mask=inputs.attention_mask,
692
+ generation_config=generation_config,
693
+ )
694
+
695
+ generated_texts = []
696
+ completion_tokens = 0
697
+ total_tokens = prompt_tokens
698
+ for output_id in output_ids:
699
+ # The output_ids include the input prompt
700
+ # Decode the generated text, excluding the input prompt
701
+ # so we slice to get only new tokens
702
+ generated_tokens = output_id[prompt_tokens:]
703
+ generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
704
+ generated_texts.append(generated_text)
705
+
706
+ # Calculate completion tokens
707
+ completion_tokens += len(generated_tokens)
708
+ total_tokens += len(generated_tokens)
709
+
710
+ choices = []
711
+ for i, generated_text in enumerate(generated_texts):
712
+ choices.append(
713
+ {
714
+ "index": i,
715
+ "message": {"role": "assistant", "content": generated_text},
716
+ "logprobs": None, # Not directly supported in this basic implementation
717
+ "finish_reason": "stop", # Assuming stop for simplicity
718
+ }
719
+ )
720
+
721
+ # Construct OpenAI-compatible response
722
+ response = {
723
+ "id": f"chatcmpl-{uuid.uuid4().hex}",
724
+ "object": "chat.completion",
725
+ "created": int(time.time()),
726
+ "model": self.model_name,
727
+ "choices": choices,
728
+ "usage": {
729
+ "prompt_tokens": prompt_tokens,
730
+ "completion_tokens": completion_tokens,
731
+ "total_tokens": total_tokens,
732
+ },
733
+ }
734
+ return response
@@ -386,7 +386,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
386
386
  predictor = model[-1] if isinstance(model, sklearn.pipeline.Pipeline) else model
387
387
  try:
388
388
  explainer = shap.Explainer(predictor, transformed_bg_data)
389
- return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values)
389
+ return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values).astype(
390
+ np.float64, errors="ignore"
391
+ )
390
392
  except TypeError:
391
393
  if isinstance(data, pd.DataFrame):
392
394
  dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in input_specs}
@@ -14,7 +14,7 @@ REQUIREMENTS = [
14
14
  "packaging>=20.9,<25",
15
15
  "pandas>=2.1.4,<3",
16
16
  "platformdirs<5",
17
- "pyarrow",
17
+ "pyarrow<19.0.0",
18
18
  "pydantic>=2.8.2, <3",
19
19
  "pyjwt>=2.0.0, <3",
20
20
  "pytimeparse>=1.1.8,<2",
@@ -22,10 +22,10 @@ REQUIREMENTS = [
22
22
  "requests",
23
23
  "retrying>=1.3.3,<2",
24
24
  "s3fs>=2024.6.1,<2026",
25
- "scikit-learn<1.6",
25
+ "scikit-learn<1.7",
26
26
  "scipy>=1.9,<2",
27
27
  "shap>=0.46.0,<1",
28
- "snowflake-connector-python>=3.15.0,<4",
28
+ "snowflake-connector-python>=3.16.0,<4",
29
29
  "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
30
30
  "snowflake.core>=1.0.2,<2",
31
31
  "sqlparse>=0.4,<1",
@@ -84,7 +84,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
84
84
  return json.loads(x)
85
85
 
86
86
  for field in data.schema.fields:
87
- if isinstance(field.datatype, spt.ArrayType):
87
+ if isinstance(field.datatype, (spt.ArrayType, spt.MapType, spt.StructType)):
88
88
  df_local[identifier.get_unescaped_names(field.name)] = df_local[
89
89
  identifier.get_unescaped_names(field.name)
90
90
  ].map(load_if_not_null)
@@ -104,7 +104,10 @@ def rename_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec
104
104
  return data
105
105
 
106
106
 
107
- def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any]) -> Optional[core.ModelSignature]:
107
+ def huggingface_pipeline_signature_auto_infer(
108
+ task: str,
109
+ params: dict[str, Any],
110
+ ) -> Optional[core.ModelSignature]:
108
111
  # Text
109
112
 
110
113
  # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
@@ -297,7 +300,6 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any])
297
300
  )
298
301
  ],
299
302
  )
300
-
301
303
  # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.Text2TextGenerationPipeline
302
304
  if task == "text2text-generation":
303
305
  if params.get("return_tensors", False):
@@ -0,0 +1,5 @@
1
import enum

# Closed set of inference-engine identifiers. Built via the enum functional
# API; members, values, and lookup semantics are identical to an equivalent
# `class InferenceEngine(enum.Enum)` declaration.
InferenceEngine = enum.Enum("InferenceEngine", {"VLLM": "vllm"})
InferenceEngine.__doc__ = """Identifiers for the supported inference engines."""
@@ -258,7 +258,7 @@ class HuggingFacePipelineModel:
258
258
  # model_version_impl.create_service parameters
259
259
  service_name: str,
260
260
  service_compute_pool: str,
261
- image_repo: str,
261
+ image_repo: Optional[str] = None,
262
262
  image_build_compute_pool: Optional[str] = None,
263
263
  ingress_enabled: bool = False,
264
264
  max_instances: int = 1,
@@ -282,7 +282,8 @@ class HuggingFacePipelineModel:
282
282
  comment: Comment for the model. Defaults to None.
283
283
  service_name: The name of the service to create.
284
284
  service_compute_pool: The compute pool for the service.
285
- image_repo: The name of the image repository.
285
+ image_repo: The name of the image repository. This can be None, in that case a default hidden image
286
+ repository will be used.
286
287
  image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
287
288
  the service compute pool if None.
288
289
  ingress_enabled: Whether ingress is enabled. Defaults to False.
@@ -356,7 +357,7 @@ class HuggingFacePipelineModel:
356
357
  else sql_identifier.SqlIdentifier(service_compute_pool)
357
358
  ),
358
359
  service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
359
- image_repo=image_repo,
360
+ image_repo_name=image_repo,
360
361
  ingress_enabled=ingress_enabled,
361
362
  max_instances=max_instances,
362
363
  cpu_requests=cpu_requests,
@@ -0,0 +1,57 @@
1
"""OpenAI-compatible chat-completion model signature.

Declares the input/output schema for a chat-completion call so that the model
registry can validate and serialize requests/responses shaped like the OpenAI
``chat.completions`` payload (``messages`` in; ``choices``/``usage`` out).
"""
from snowflake.ml.model._signatures import core

_OPENAI_CHAT_SIGNATURE_SPEC = core.ModelSignature(
    inputs=[
        # One row per conversation turn: a variable-length (-1) list of
        # message objects, each with role/content plus optional name/title.
        core.FeatureGroupSpec(
            name="messages",
            specs=[
                core.FeatureSpec(name="content", dtype=core.DataType.STRING),
                core.FeatureSpec(name="name", dtype=core.DataType.STRING),
                core.FeatureSpec(name="role", dtype=core.DataType.STRING),
                core.FeatureSpec(name="title", dtype=core.DataType.STRING),
            ],
            shape=(-1,),
        ),
        # Scalar sampling/decoding controls, mirroring OpenAI request fields.
        core.FeatureSpec(name="temperature", dtype=core.DataType.DOUBLE),
        core.FeatureSpec(name="max_completion_tokens", dtype=core.DataType.INT64),
        # ``stop``: variable-length list of stop strings.
        core.FeatureSpec(name="stop", dtype=core.DataType.STRING, shape=(-1,)),
        core.FeatureSpec(name="n", dtype=core.DataType.INT32),
        core.FeatureSpec(name="stream", dtype=core.DataType.BOOL),
        core.FeatureSpec(name="top_p", dtype=core.DataType.DOUBLE),
        core.FeatureSpec(name="frequency_penalty", dtype=core.DataType.DOUBLE),
        core.FeatureSpec(name="presence_penalty", dtype=core.DataType.DOUBLE),
    ],
    outputs=[
        # Top-level response envelope fields.
        core.FeatureSpec(name="id", dtype=core.DataType.STRING),
        core.FeatureSpec(name="object", dtype=core.DataType.STRING),
        # NOTE(review): declared FLOAT, but the handler builds this field with
        # ``int(time.time())`` — confirm whether INT64 was intended here.
        core.FeatureSpec(name="created", dtype=core.DataType.FLOAT),
        core.FeatureSpec(name="model", dtype=core.DataType.STRING),
        # ``choices``: one entry per returned completion (variable length).
        core.FeatureGroupSpec(
            name="choices",
            specs=[
                core.FeatureSpec(name="index", dtype=core.DataType.INT32),
                # Nested assistant message for each choice.
                core.FeatureGroupSpec(
                    name="message",
                    specs=[
                        core.FeatureSpec(name="content", dtype=core.DataType.STRING),
                        core.FeatureSpec(name="name", dtype=core.DataType.STRING),
                        core.FeatureSpec(name="role", dtype=core.DataType.STRING),
                    ],
                ),
                core.FeatureSpec(name="logprobs", dtype=core.DataType.STRING),
                core.FeatureSpec(name="finish_reason", dtype=core.DataType.STRING),
            ],
            shape=(-1,),
        ),
        # Token accounting, as in the OpenAI ``usage`` object.
        core.FeatureGroupSpec(
            name="usage",
            specs=[
                core.FeatureSpec(name="completion_tokens", dtype=core.DataType.INT32),
                core.FeatureSpec(name="prompt_tokens", dtype=core.DataType.INT32),
                core.FeatureSpec(name="total_tokens", dtype=core.DataType.INT32),
            ],
        ),
    ],
)

# Public mapping: exposes the spec keyed by the target method name.
OPENAI_CHAT_SIGNATURE = {"__call__": _OPENAI_CHAT_SIGNATURE_SPEC}