snowflake-ml-python 1.11.0__py3-none-any.whl → 1.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. snowflake/cortex/_complete.py +3 -2
  2. snowflake/ml/_internal/telemetry.py +3 -1
  3. snowflake/ml/_internal/utils/service_logger.py +26 -1
  4. snowflake/ml/experiment/_client/artifact.py +76 -0
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
  6. snowflake/ml/experiment/experiment_tracking.py +113 -6
  7. snowflake/ml/feature_store/feature_store.py +1150 -131
  8. snowflake/ml/feature_store/feature_view.py +122 -0
  9. snowflake/ml/jobs/_utils/constants.py +8 -16
  10. snowflake/ml/jobs/_utils/feature_flags.py +16 -0
  11. snowflake/ml/jobs/_utils/payload_utils.py +19 -5
  12. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
  13. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +23 -5
  14. snowflake/ml/jobs/_utils/spec_utils.py +4 -6
  15. snowflake/ml/jobs/_utils/types.py +2 -1
  16. snowflake/ml/jobs/job.py +38 -19
  17. snowflake/ml/jobs/manager.py +136 -19
  18. snowflake/ml/model/__init__.py +6 -1
  19. snowflake/ml/model/_client/model/batch_inference_specs.py +25 -0
  20. snowflake/ml/model/_client/model/model_version_impl.py +62 -65
  21. snowflake/ml/model/_client/ops/model_ops.py +42 -9
  22. snowflake/ml/model/_client/ops/service_ops.py +75 -154
  23. snowflake/ml/model/_client/service/model_deployment_spec.py +23 -37
  24. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +15 -4
  25. snowflake/ml/model/_client/sql/service.py +4 -0
  26. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +309 -22
  27. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
  28. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -0
  29. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  30. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  31. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  32. snowflake/ml/model/_signatures/utils.py +4 -2
  33. snowflake/ml/model/models/huggingface_pipeline.py +23 -0
  34. snowflake/ml/model/openai_signatures.py +57 -0
  35. snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
  37. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  38. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  39. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  40. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  41. snowflake/ml/modeling/cluster/birch.py +1 -1
  42. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  43. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  44. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  45. snowflake/ml/modeling/cluster/k_means.py +1 -1
  46. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  47. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  48. snowflake/ml/modeling/cluster/optics.py +1 -1
  49. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  50. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  51. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  52. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  53. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  54. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  55. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  56. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  57. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  58. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  59. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  60. snowflake/ml/modeling/covariance/oas.py +1 -1
  61. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  62. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  63. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  64. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  65. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  66. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  67. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  68. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  69. snowflake/ml/modeling/decomposition/pca.py +1 -1
  70. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  71. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  72. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  73. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  74. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  75. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  76. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  77. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  78. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  79. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  80. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  81. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  82. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  83. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  84. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  85. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  86. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  87. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  88. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  89. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  90. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  91. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  92. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  93. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  94. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  95. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  96. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  97. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  98. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  99. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  100. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  101. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  102. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  103. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  104. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  105. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  106. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  107. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  108. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  109. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  110. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  111. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  112. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  113. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  114. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  115. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  116. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  117. snowflake/ml/modeling/linear_model/lars.py +1 -1
  118. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  119. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  120. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  121. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  122. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  123. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  124. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  125. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  126. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  127. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  128. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  129. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  130. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  131. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  132. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  133. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  134. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  135. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  136. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  137. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  138. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  139. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  140. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  141. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  142. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  143. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  144. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  145. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  146. snowflake/ml/modeling/manifold/isomap.py +1 -1
  147. snowflake/ml/modeling/manifold/mds.py +1 -1
  148. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  149. snowflake/ml/modeling/manifold/tsne.py +1 -1
  150. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  151. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  152. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  153. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  154. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  155. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  156. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  157. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  158. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  159. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  160. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  161. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  162. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  163. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  164. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  165. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  166. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  167. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  168. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  169. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  170. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  171. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  172. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  173. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  174. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  175. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  176. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  177. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  178. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  179. snowflake/ml/modeling/svm/svc.py +1 -1
  180. snowflake/ml/modeling/svm/svr.py +1 -1
  181. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  182. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  183. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  184. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  185. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  186. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  187. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  188. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  189. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
  190. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  191. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  192. snowflake/ml/monitoring/model_monitor.py +26 -0
  193. snowflake/ml/version.py +1 -1
  194. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/METADATA +82 -5
  195. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/RECORD +198 -194
  196. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/WHEEL +0 -0
  197. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/licenses/LICENSE.txt +0 -0
  198. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,9 @@
 import json
 import logging
 import os
+import shutil
+import time
+import uuid
 import warnings
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
 
@@ -11,7 +14,12 @@ from packaging import version
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
-from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
+from snowflake.ml.model import (
+    custom_model,
+    model_signature,
+    openai_signatures,
+    type_hints as model_types,
+)
 from snowflake.ml.model._packager.model_env import model_env
 from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
 from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
@@ -81,6 +89,7 @@ class HuggingFacePipelineHandler(
     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
 
     MODEL_BLOB_FILE_OR_DIR = "model"
+    MODEL_PICKLE_FILE = "snowml_huggingface_pipeline.pkl"
     ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
     DEFAULT_TARGET_METHODS = ["__call__"]
     IS_AUTO_SIGNATURE = True
@@ -151,7 +160,10 @@ class HuggingFacePipelineHandler(
         assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
         params = {**model.__dict__, **model.model_kwargs}
 
-        inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(task, params=params)
+        inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(
+            task,
+            params=params,
+        )
 
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
@@ -189,6 +201,7 @@ class HuggingFacePipelineHandler(
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
 
+        is_repo_downloaded = False
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
             model.save_pretrained(  # type:ignore[attr-defined]
@@ -214,11 +227,22 @@ class HuggingFacePipelineHandler(
             ) as f:
                 cloudpickle.dump(pipeline_params, f)
         else:
+            model_blob_file_or_dir = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
+            model_blob_pickle_file = os.path.join(model_blob_file_or_dir, cls.MODEL_PICKLE_FILE)
+            os.makedirs(model_blob_file_or_dir, exist_ok=True)
             with open(
-                os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR),
+                model_blob_pickle_file,
                 "wb",
             ) as f:
                 cloudpickle.dump(model, f)
+            if model.repo_snapshot_dir:
+                logger.info("model's repo_snapshot_dir is available, copying snapshot")
+                shutil.copytree(
+                    model.repo_snapshot_dir,
+                    model_blob_file_or_dir,
+                    dirs_exist_ok=True,
+                )
+                is_repo_downloaded = True
 
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
@@ -226,13 +250,12 @@ class HuggingFacePipelineHandler(
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.HuggingFacePipelineModelBlobOptions(
-                {
-                    "task": task,
-                    "batch_size": batch_size if batch_size is not None else 1,
-                    "has_tokenizer": has_tokenizer,
-                    "has_feature_extractor": has_feature_extractor,
-                    "has_image_preprocessor": has_image_preprocessor,
-                }
+                task=task,
+                batch_size=batch_size if batch_size is not None else 1,
+                has_tokenizer=has_tokenizer,
+                has_feature_extractor=has_feature_extractor,
+                has_image_preprocessor=has_image_preprocessor,
+                is_repo_downloaded=is_repo_downloaded,
             ),
         )
         model_meta.models[name] = base_meta
@@ -276,6 +299,27 @@ class HuggingFacePipelineHandler(
 
         return device_config
 
+    @staticmethod
+    def _load_pickle_model(
+        pickle_file: str,
+        **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
+    ) -> huggingface_pipeline.HuggingFacePipelineModel:
+        with open(pickle_file, "rb") as f:
+            m = cloudpickle.load(f)
+        assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
+        torch_dtype: Optional[str] = None
+        device_config = None
+        if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
+            device_config = HuggingFacePipelineHandler._get_device_config(**kwargs)
+            m.__dict__.update(device_config)
+
+        if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
+            torch_dtype = "auto"
+            m.__dict__.update(torch_dtype=torch_dtype)
+        else:
+            m.__dict__.update(torch_dtype=None)
+        return m
+
     @classmethod
     def load_model(
         cls,
@@ -300,7 +344,13 @@ class HuggingFacePipelineHandler(
             raise ValueError("Missing field `batch_size` in model blob metadata for type `huggingface_pipeline`")
 
         model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)
-        if os.path.isdir(model_blob_file_or_dir_path):
+        is_repo_downloaded = model_blob_options.get("is_repo_downloaded", False)
+
+        def _create_pipeline_from_dir(
+            model_blob_file_or_dir_path: str,
+            model_blob_options: model_meta_schema.HuggingFacePipelineModelBlobOptions,
+            **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
+        ) -> "transformers.Pipeline":
             import transformers
 
             additional_pipeline_params = {}
@@ -320,7 +370,7 @@ class HuggingFacePipelineHandler(
             ) as f:
                 pipeline_params = cloudpickle.load(f)
 
-            device_config = cls._get_device_config(**kwargs)
+            device_config = HuggingFacePipelineHandler._get_device_config(**kwargs)
 
             m = transformers.pipeline(
                 model_blob_options["task"],
@@ -349,18 +399,59 @@ class HuggingFacePipelineHandler(
                 m.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
 
             m.__dict__.update(pipeline_params)
+            return m
 
+        def _create_pipeline_from_model(
+            model_blob_file_or_dir_path: str,
+            m: huggingface_pipeline.HuggingFacePipelineModel,
+            **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
+        ) -> "transformers.Pipeline":
+            import transformers
+
+            return transformers.pipeline(
+                m.task,
+                model=model_blob_file_or_dir_path,
+                trust_remote_code=m.trust_remote_code,
+                torch_dtype=getattr(m, "torch_dtype", None),
+                revision=m.revision,
+                # pass device or device_map when creating the pipeline
+                **HuggingFacePipelineHandler._get_device_config(**kwargs),
+                # pass other model_kwargs to transformers.pipeline.from_pretrained method
+                **m.model_kwargs,
+            )
+
+        if os.path.isdir(model_blob_file_or_dir_path) and not is_repo_downloaded:
+            # the logged model is a transformers.Pipeline object
+            # weights of the model are saved in the directory
+            return _create_pipeline_from_dir(model_blob_file_or_dir_path, model_blob_options, **kwargs)
         else:
-            assert os.path.isfile(model_blob_file_or_dir_path)
-            with open(model_blob_file_or_dir_path, "rb") as f:
-                m = cloudpickle.load(f)
-                assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
-                if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
-                    m.__dict__.update(cls._get_device_config(**kwargs))
-
-                if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
-                    m.__dict__.update(torch_dtype="auto")
-            return m
+            # case 1: LEGACY logging, repo snapshot is not logged
+            if os.path.isfile(model_blob_file_or_dir_path):
+                # LEGACY logging that had model as a pickle file in the model blob directory
+                # the logged model is a huggingface_pipeline.HuggingFacePipelineModel object
+                # the model_blob_file_or_dir_path is the pickle file that holds
+                # the huggingface_pipeline.HuggingFacePipelineModel object
+                # the snapshot of the repo is not logged
+                return cls._load_pickle_model(model_blob_file_or_dir_path)
+            else:
+                assert os.path.isdir(model_blob_file_or_dir_path)
+                # the logged model is a huggingface_pipeline.HuggingFacePipelineModel object
+                # the pickle_file holds the huggingface_pipeline.HuggingFacePipelineModel object
+                pickle_file = os.path.join(model_blob_file_or_dir_path, cls.MODEL_PICKLE_FILE)
+                m = cls._load_pickle_model(pickle_file)
+
+                # case 2: logging without the snapshot of the repo
+                if not is_repo_downloaded:
+                    # we return the huggingface_pipeline.HuggingFacePipelineModel object
+                    return m
+                # case 3: logging with the snapshot of the repo
+                else:
+                    # the model_blob_file_or_dir_path is the directory that holds
+                    # weights of the model from `huggingface_hub.snapshot_download`
+                    # the huggingface_pipeline.HuggingFacePipelineModel object is logged
+                    # with a snapshot of the repo, we create a transformers.Pipeline object
+                    # by reading the snapshot directory
+                    return _create_pipeline_from_model(model_blob_file_or_dir_path, m, **kwargs)
 
     @classmethod
     def convert_as_custom_model(
@@ -401,6 +492,34 @@ class HuggingFacePipelineHandler(
                     ),
                     axis=1,
                 ).to_list()
+            elif raw_model.task == "text-generation":
+                # verify when the target method is __call__ and
+                # if the signature is default text-generation signature
+                # then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
+                if signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC:
+                    wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)
+
+                    temp_res = X.apply(
+                        lambda row: wrapped_model.generate_chat_completion(
+                            messages=row["messages"],
+                            max_completion_tokens=row.get("max_completion_tokens", None),
+                            temperature=row.get("temperature", None),
+                            stop_strings=row.get("stop", None),
+                            n=row.get("n", 1),
+                            stream=row.get("stream", False),
+                            top_p=row.get("top_p", 1.0),
+                            frequency_penalty=row.get("frequency_penalty", None),
+                            presence_penalty=row.get("presence_penalty", None),
+                        ),
+                        axis=1,
+                    ).to_list()
+                else:
+                    if len(signature.inputs) > 1:
+                        input_data = X.to_dict("records")
+                    # If it is only expecting one argument, Then it is expecting a list of something.
+                    else:
+                        input_data = X[signature.inputs[0].name].to_list()
+                    temp_res = getattr(raw_model, target_method)(input_data)
             else:
                 # For others, we could offer the whole dataframe as a list.
                 # Some of them may need some conversion
@@ -527,3 +646,171 @@ class HuggingFacePipelineHandler(
         hg_pipe_model = _HFPipelineModel(custom_model.ModelContext())
 
         return hg_pipe_model
+
+
+class HuggingFaceOpenAICompatibleModel:
+    """
+    A class to wrap a Hugging Face text generation model and provide an
+    OpenAI-compatible chat completion interface.
+    """
+
+    def __init__(self, pipeline: "transformers.Pipeline") -> None:
+        """
+        Initializes the model and tokenizer.
+
+        Args:
+            pipeline (transformers.pipeline): The Hugging Face pipeline to wrap.
+        """
+
+        self.pipeline = pipeline
+        self.model = self.pipeline.model
+        self.tokenizer = self.pipeline.tokenizer
+
+        self.model_name = self.pipeline.model.name_or_path
+
+    def _apply_chat_template(self, messages: list[dict[str, Any]]) -> str:
+        """
+        Applies a chat template to a list of messages.
+        If the tokenizer has a chat template, it uses that.
+        Otherwise, it falls back to a simple concatenation.
+
+        Args:
+            messages (list[dict]): A list of message dictionaries, e.g.,
+                [{"role": "user", "content": "Hello!"}, ...]
+
+        Returns:
+            The formatted prompt string ready for model input.
+        """
+
+        if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
+            # Use the tokenizer's built-in chat template if available
+            # `tokenize=False` means it returns a string, not token IDs
+            return self.tokenizer.apply_chat_template(  # type: ignore[no-any-return]
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+        else:
+            # Fallback to a simple concatenation for models without a specific chat template
+            # This is a basic example; real chat models often need specific formatting.
+            prompt = ""
+            for message in messages:
+                role = message.get("role", "user")
+                content = message.get("content", "")
+                if role == "system":
+                    prompt += f"System: {content}\n"
+                elif role == "user":
+                    prompt += f"User: {content}\n"
+                elif role == "assistant":
+                    prompt += f"Assistant: {content}\n"
+            prompt += "Assistant:"  # Indicate that the assistant should respond
+            return prompt
+
+    def generate_chat_completion(
+        self,
+        messages: list[dict[str, Any]],
+        max_completion_tokens: Optional[int] = None,
+        stream: Optional[bool] = False,
+        stop_strings: Optional[list[str]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        n: int = 1,
+    ) -> dict[str, Any]:
+        """
+        Generates a chat completion response in an OpenAI-compatible format.
+
+        Args:
+            messages (list[dict]): A list of message dictionaries, e.g.,
+                [{"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "What is deep learning?"}]
+            max_completion_tokens (int): The maximum number of completion tokens to generate.
+            stop_strings (list[str]): A list of strings to stop generation.
+            temperature (float): The temperature for sampling.
+            top_p (float): The top-p value for sampling.
+            stream (bool): Whether to stream the generation.
+            frequency_penalty (float): The frequency penalty for sampling.
+            presence_penalty (float): The presence penalty for sampling.
+            n (int): The number of samples to generate.
+
+        Returns:
+            dict: An OpenAI-compatible dictionary representing the chat completion.
+        """
+        # Apply chat template to convert messages into a single prompt string
+
+        prompt_text = self._apply_chat_template(messages)
+
+        # Tokenize the prompt
+        inputs = self.tokenizer(
+            prompt_text,
+            return_tensors="pt",
+            padding=True,
+        ).to(self.model.device)
+        prompt_tokens = inputs.input_ids.shape[1]
+
+        from transformers import GenerationConfig
+
+        generation_config = GenerationConfig(
+            max_new_tokens=max_completion_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            pad_token_id=self.tokenizer.pad_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            stop_strings=stop_strings,
+            stream=stream,
+            repetition_penalty=frequency_penalty,
+            diversity_penalty=presence_penalty if n > 1 else None,
+            num_return_sequences=n,
+            num_beams=max(2, n),  # must be >1
+            num_beam_groups=max(2, n) if presence_penalty else 1,
+            do_sample=False,
+        )
+
+        # Generate text
+        output_ids = self.model.generate(
+            inputs.input_ids,
+            attention_mask=inputs.attention_mask,
+            generation_config=generation_config,
+        )
+
+        generated_texts = []
+        completion_tokens = 0
+        total_tokens = prompt_tokens
+        for output_id in output_ids:
+            # The output_ids include the input prompt
+            # Decode the generated text, excluding the input prompt
+            # so we slice to get only new tokens
+            generated_tokens = output_id[prompt_tokens:]
+            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+            generated_texts.append(generated_text)
+
+            # Calculate completion tokens
+            completion_tokens += len(generated_tokens)
+            total_tokens += len(generated_tokens)
+
+        choices = []
+        for i, generated_text in enumerate(generated_texts):
+            choices.append(
+                {
+                    "index": i,
+                    "message": {"role": "assistant", "content": generated_text},
+                    "logprobs": None,  # Not directly supported in this basic implementation
+                    "finish_reason": "stop",  # Assuming stop for simplicity
+                }
+            )
+
+        # Construct OpenAI-compatible response
+        response = {
+            "id": f"chatcmpl-{uuid.uuid4().hex}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": self.model_name,
+            "choices": choices,
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+        return response
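
Note (not part of the diff): the HuggingFaceOpenAICompatibleModel class added above gives a text-generation pipeline an OpenAI-style chat-completion interface. Below is a minimal local sketch of how it could be exercised, assuming the class remains module-level in this handler file and a small checkpoint such as gpt2 is available; the import path and the pad-token workaround are assumptions for illustration, not documented API.

import transformers

# Assumed import path; the class is internal to the handler module shown in this diff.
from snowflake.ml.model._packager.model_handlers.huggingface_pipeline import (
    HuggingFaceOpenAICompatibleModel,
)

pipe = transformers.pipeline("text-generation", model="gpt2")
pipe.tokenizer.pad_token = pipe.tokenizer.eos_token  # gpt2 ships without a pad token

wrapper = HuggingFaceOpenAICompatibleModel(pipeline=pipe)
response = wrapper.generate_chat_completion(
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    max_completion_tokens=32,
    n=1,
)
print(response["choices"][0]["message"]["content"])
print(response["usage"])  # prompt_tokens / completion_tokens / total_tokens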
@@ -386,7 +386,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             predictor = model[-1] if isinstance(model, sklearn.pipeline.Pipeline) else model
             try:
                 explainer = shap.Explainer(predictor, transformed_bg_data)
-                return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values)
+                return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values).astype(
+                    np.float64, errors="ignore"
+                )
             except TypeError:
                 if isinstance(data, pd.DataFrame):
                     dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in input_specs}
@@ -229,6 +229,11 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
         enable_categorical = False
         for col, d_type in X.dtypes.items():
             if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
+                if pd.CategoricalDtype.is_dtype(d_type):
+                    enable_categorical = True
+                elif isinstance(d_type, pd.StringDtype):
+                    X[col] = X[col].astype("category")
+                    enable_categorical = True
                 continue
             if not np.issubdtype(d_type, np.number):
                 # categorical columns are converted to numpy's str dtype
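
Note (not part of the diff): the XGBoost handler change above now treats pandas nullable string columns as categoricals instead of skipping them. A small sketch of the equivalent conversion outside the handler (column names and values are made up):

import pandas as pd
import xgboost as xgb

X = pd.DataFrame(
    {
        "color": pd.array(["red", "blue", "red"], dtype="string"),  # pandas StringDtype (an ExtensionDtype)
        "size": [1.0, 2.0, 3.0],
    }
)
# Mirror the handler: cast StringDtype columns to category so XGBoost can consume them.
X["color"] = X["color"].astype("category")
dmat = xgb.DMatrix(X, enable_categorical=True)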
@@ -51,6 +51,7 @@ class HuggingFacePipelineModelBlobOptions(BaseModelBlobOptions):
     has_tokenizer: NotRequired[bool]
     has_feature_extractor: NotRequired[bool]
     has_image_preprocessor: NotRequired[bool]
+    is_repo_downloaded: NotRequired[Optional[bool]]
 
 
 class LightGBMModelBlobOptions(BaseModelBlobOptions):
@@ -14,7 +14,7 @@ REQUIREMENTS = [
     "packaging>=20.9,<25",
     "pandas>=2.1.4,<3",
     "platformdirs<5",
-    "pyarrow",
+    "pyarrow<19.0.0",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",
     "pytimeparse>=1.1.8,<2",
@@ -22,10 +22,10 @@ REQUIREMENTS = [
     "requests",
     "retrying>=1.3.3,<2",
     "s3fs>=2024.6.1,<2026",
-    "scikit-learn<1.6",
+    "scikit-learn<1.7",
     "scipy>=1.9,<2",
     "shap>=0.46.0,<1",
-    "snowflake-connector-python>=3.15.0,<4",
+    "snowflake-connector-python>=3.16.0,<4",
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
    "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
@@ -84,7 +84,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
                 return json.loads(x)
 
         for field in data.schema.fields:
-            if isinstance(field.datatype, spt.ArrayType):
+            if isinstance(field.datatype, (spt.ArrayType, spt.MapType, spt.StructType)):
                 df_local[identifier.get_unescaped_names(field.name)] = df_local[
                     identifier.get_unescaped_names(field.name)
                 ].map(load_if_not_null)
@@ -104,7 +104,10 @@ def rename_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec
     return data
 
 
-def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any]) -> Optional[core.ModelSignature]:
+def huggingface_pipeline_signature_auto_infer(
+    task: str,
+    params: dict[str, Any],
+) -> Optional[core.ModelSignature]:
     # Text
 
     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
@@ -297,7 +300,6 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any])
                 )
             ],
         )
-
     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.Text2TextGenerationPipeline
     if task == "text2text-generation":
         if params.get("return_tensors", False):
@@ -28,6 +28,10 @@ class HuggingFacePipelineModel:
         token: Optional[str] = None,
         trust_remote_code: Optional[bool] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
+        download_snapshot: bool = True,
+        # repo snapshot download args
+        allow_patterns: Optional[Union[list[str], str]] = None,
+        ignore_patterns: Optional[Union[list[str], str]] = None,
         **kwargs: Any,
     ) -> None:
         """
@@ -52,6 +56,9 @@ class HuggingFacePipelineModel:
                 Defaults to None.
             model_kwargs: Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,`.
                 Defaults to None.
+            download_snapshot: Whether to download the HuggingFace repository. Defaults to True.
+            allow_patterns: If provided, only files matching at least one pattern are downloaded.
+            ignore_patterns: If provided, files matching any of the patterns are not downloaded.
             kwargs: Additional keyword arguments passed along to the specific pipeline init (see the documentation for
                 the corresponding pipeline class for possible values).
 
@@ -220,6 +227,21 @@ class HuggingFacePipelineModel:
                 stacklevel=2,
             )
 
+        repo_snapshot_dir: Optional[str] = None
+        if download_snapshot:
+            try:
+                from huggingface_hub import snapshot_download
+
+                repo_snapshot_dir = snapshot_download(
+                    repo_id=model,
+                    revision=revision,
+                    token=token,
+                    allow_patterns=allow_patterns,
+                    ignore_patterns=ignore_patterns,
+                )
+            except ImportError:
+                logger.info("huggingface_hub package is not installed, skipping snapshot download")
+
         # ==== End pipeline logic from transformers ====
 
         self.task = normalized_task
@@ -229,6 +251,7 @@ class HuggingFacePipelineModel:
         self.trust_remote_code = trust_remote_code
         self.model_kwargs = model_kwargs
         self.tokenizer = tokenizer
+        self.repo_snapshot_dir = repo_snapshot_dir
         self.__dict__.update(kwargs)
 
     @telemetry.send_api_usage_telemetry(
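
Note (not part of the diff): with the new download_snapshot, allow_patterns and ignore_patterns arguments, constructing a HuggingFacePipelineModel can snapshot the repository locally via huggingface_hub, and the handler copies that snapshot into the model blob at save time. A hedged construction sketch; the repo id and patterns are illustrative only, and the keyword names follow the signature shown above.

from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel

pipeline_model = HuggingFacePipelineModel(
    task="text-generation",
    model="openai-community/gpt2",                         # illustrative repo id
    download_snapshot=True,                                # default; requires huggingface_hub
    allow_patterns=["*.json", "*.safetensors", "*.txt"],   # only fetch these files
    ignore_patterns=["*.bin"],                             # skip duplicate torch weights
)
# If huggingface_hub is installed, repo_snapshot_dir points at the local snapshot;
# otherwise it stays None and only the pickled HuggingFacePipelineModel is saved.
print(pipeline_model.repo_snapshot_dir)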
@@ -0,0 +1,57 @@
+from snowflake.ml.model._signatures import core
+
+_OPENAI_CHAT_SIGNATURE_SPEC = core.ModelSignature(
+    inputs=[
+        core.FeatureGroupSpec(
+            name="messages",
+            specs=[
+                core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="title", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureSpec(name="temperature", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="max_completion_tokens", dtype=core.DataType.INT64),
+        core.FeatureSpec(name="stop", dtype=core.DataType.STRING, shape=(-1,)),
+        core.FeatureSpec(name="n", dtype=core.DataType.INT32),
+        core.FeatureSpec(name="stream", dtype=core.DataType.BOOL),
+        core.FeatureSpec(name="top_p", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="frequency_penalty", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="presence_penalty", dtype=core.DataType.DOUBLE),
+    ],
+    outputs=[
+        core.FeatureSpec(name="id", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="object", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="created", dtype=core.DataType.FLOAT),
+        core.FeatureSpec(name="model", dtype=core.DataType.STRING),
+        core.FeatureGroupSpec(
+            name="choices",
+            specs=[
+                core.FeatureSpec(name="index", dtype=core.DataType.INT32),
+                core.FeatureGroupSpec(
+                    name="message",
+                    specs=[
+                        core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                    ],
+                ),
+                core.FeatureSpec(name="logprobs", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="finish_reason", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureGroupSpec(
+            name="usage",
+            specs=[
+                core.FeatureSpec(name="completion_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="prompt_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="total_tokens", dtype=core.DataType.INT32),
+            ],
+        ),
+    ],
+)
+
+OPENAI_CHAT_SIGNATURE = {"__call__": _OPENAI_CHAT_SIGNATURE_SPEC}
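
Note (not part of the diff): the new snowflake/ml/model/openai_signatures.py module exposes OPENAI_CHAT_SIGNATURE, which maps __call__ to the chat-completion signature above; per the convert_as_custom_model hunk, a text-generation pipeline logged with that exact signature routes __call__ through the OpenAI-compatible wrapper. A sketch of one request row shaped for this signature; the values are illustrative, and the DataFrame itself is only an example of the expected input layout.

import pandas as pd

from snowflake.ml.model import openai_signatures

# The mapping added by this release: {"__call__": <OpenAI-style chat ModelSignature>}.
print(openai_signatures.OPENAI_CHAT_SIGNATURE["__call__"])

# One request row shaped the way the signature expects (field names mirror the OpenAI chat API).
request = pd.DataFrame(
    [
        {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is deep learning?"},
            ],
            "temperature": 0.2,
            "max_completion_tokens": 64,
            "stop": None,
            "n": 1,
            "stream": False,
            "top_p": 1.0,
            "frequency_penalty": None,
            "presence_penalty": None,
        }
    ]
)
# Passing such a frame to the logged model's __call__ would return OpenAI-style
# fields: id, object, created, model, choices and usage.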