snowflake-ml-python 1.11.0__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. snowflake/cortex/_complete.py +3 -2
  2. snowflake/ml/_internal/utils/service_logger.py +26 -1
  3. snowflake/ml/experiment/_client/artifact.py +76 -0
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
  5. snowflake/ml/experiment/experiment_tracking.py +89 -4
  6. snowflake/ml/feature_store/feature_store.py +1150 -131
  7. snowflake/ml/feature_store/feature_view.py +122 -0
  8. snowflake/ml/jobs/_utils/constants.py +8 -16
  9. snowflake/ml/jobs/_utils/feature_flags.py +16 -0
  10. snowflake/ml/jobs/_utils/payload_utils.py +19 -5
  11. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +12 -4
  13. snowflake/ml/jobs/_utils/spec_utils.py +4 -6
  14. snowflake/ml/jobs/_utils/types.py +2 -1
  15. snowflake/ml/jobs/job.py +33 -17
  16. snowflake/ml/jobs/manager.py +107 -12
  17. snowflake/ml/model/__init__.py +6 -1
  18. snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
  19. snowflake/ml/model/_client/model/model_version_impl.py +61 -65
  20. snowflake/ml/model/_client/ops/service_ops.py +73 -154
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +20 -37
  22. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +14 -4
  23. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
  24. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
  25. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  26. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  27. snowflake/ml/model/_signatures/utils.py +4 -2
  28. snowflake/ml/model/openai_signatures.py +57 -0
  29. snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
  30. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
  31. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  32. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  33. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  34. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  35. snowflake/ml/modeling/cluster/birch.py +1 -1
  36. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  37. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  38. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  39. snowflake/ml/modeling/cluster/k_means.py +1 -1
  40. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  41. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  42. snowflake/ml/modeling/cluster/optics.py +1 -1
  43. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  44. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  45. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  46. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  47. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  48. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  49. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  50. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  51. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  52. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  53. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  54. snowflake/ml/modeling/covariance/oas.py +1 -1
  55. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  56. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  57. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  58. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  59. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  60. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  61. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  62. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  63. snowflake/ml/modeling/decomposition/pca.py +1 -1
  64. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  65. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  66. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  67. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  68. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  69. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  70. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  71. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  72. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  73. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  74. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  75. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  76. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  78. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  79. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  82. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  83. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  84. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  85. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  86. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  87. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  88. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  89. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  90. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  91. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  94. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  95. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  96. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  97. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  98. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  99. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  100. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  101. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  102. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  103. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  104. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  105. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  106. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  107. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  108. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  109. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  110. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  111. snowflake/ml/modeling/linear_model/lars.py +1 -1
  112. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  113. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  114. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  115. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  116. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  117. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  118. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  119. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  120. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  123. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  124. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  125. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  126. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  127. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  128. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  129. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  130. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  131. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  132. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  133. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  135. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  136. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  137. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  138. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  140. snowflake/ml/modeling/manifold/isomap.py +1 -1
  141. snowflake/ml/modeling/manifold/mds.py +1 -1
  142. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  143. snowflake/ml/modeling/manifold/tsne.py +1 -1
  144. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  145. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  146. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  147. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  148. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  149. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  150. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  151. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  152. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  153. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  154. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  155. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  156. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  157. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  158. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  159. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  160. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  161. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  162. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  163. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  164. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  165. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  166. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  167. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  168. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  169. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  170. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  171. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  172. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  173. snowflake/ml/modeling/svm/svc.py +1 -1
  174. snowflake/ml/modeling/svm/svr.py +1 -1
  175. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  176. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  177. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  178. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  179. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  180. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  181. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  182. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  183. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
  184. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  185. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  186. snowflake/ml/monitoring/model_monitor.py +26 -0
  187. snowflake/ml/version.py +1 -1
  188. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +66 -5
  189. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +192 -188
  190. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
  191. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
  192. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/service/model_deployment_spec.py

@@ -194,16 +194,14 @@ class ModelDeploymentSpec:
         self,
         job_name: sql_identifier.SqlIdentifier,
         inference_compute_pool_name: sql_identifier.SqlIdentifier,
+        function_name: str,
+        input_stage_location: str,
+        output_stage_location: str,
+        completion_filename: str,
+        input_file_pattern: str,
         warehouse: sql_identifier.SqlIdentifier,
-        target_method: str,
-        input_table_name: sql_identifier.SqlIdentifier,
-        output_table_name: sql_identifier.SqlIdentifier,
         job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
         job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
-        input_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
-        input_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
-        output_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
-        output_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
         cpu: Optional[str] = None,
         memory: Optional[str] = None,
         gpu: Optional[Union[str, int]] = None,
@@ -215,16 +213,14 @@ class ModelDeploymentSpec:
         Args:
             job_name: Name of the job.
             inference_compute_pool_name: Compute pool for inference.
+            warehouse: Warehouse for the job.
+            function_name: Function name.
+            input_stage_location: Stage location for input data.
+            output_stage_location: Stage location for output data.
             job_database_name: Database name for the job.
             job_schema_name: Schema name for the job.
-            warehouse: Warehouse for the job.
-            target_method: Target method for inference.
-            input_table_name: Input table name.
-            output_table_name: Output table name.
-            input_table_database_name: Database for input table.
-            input_table_schema_name: Schema for input table.
-            output_table_database_name: Database for output table.
-            output_table_schema_name: Schema for output table.
+            input_file_pattern: Pattern for input files (optional).
+            completion_filename: Name of completion file (default: "completion.txt").
             cpu: CPU requirement.
             memory: Memory requirement.
             gpu: GPU requirement.
@@ -242,41 +238,28 @@ class ModelDeploymentSpec:

         saved_job_database = job_database_name or self.database
         saved_job_schema = job_schema_name or self.schema
-        input_table_database_name = input_table_database_name or self.database
-        input_table_schema_name = input_table_schema_name or self.schema
-        output_table_database_name = output_table_database_name or self.database
-        output_table_schema_name = output_table_schema_name or self.schema

         assert saved_job_database is not None
         assert saved_job_schema is not None
-        assert input_table_database_name is not None
-        assert input_table_schema_name is not None
-        assert output_table_database_name is not None
-        assert output_table_schema_name is not None

         fq_job_name = identifier.get_schema_level_object_identifier(
             saved_job_database.identifier(), saved_job_schema.identifier(), job_name.identifier()
         )
-        fq_input_table_name = identifier.get_schema_level_object_identifier(
-            input_table_database_name.identifier(),
-            input_table_schema_name.identifier(),
-            input_table_name.identifier(),
-        )
-        fq_output_table_name = identifier.get_schema_level_object_identifier(
-            output_table_database_name.identifier(),
-            output_table_schema_name.identifier(),
-            output_table_name.identifier(),
-        )

         self._add_inference_spec(cpu, memory, gpu, num_workers, max_batch_rows)

         self._job = model_deployment_spec_schema.Job(
             name=fq_job_name,
             compute_pool=inference_compute_pool_name.identifier(),
-            warehouse=warehouse.identifier(),
-            target_method=target_method,
-            input_table_name=fq_input_table_name,
-            output_table_name=fq_output_table_name,
+            warehouse=warehouse.identifier() if warehouse else None,
+            function_name=function_name,
+            input=model_deployment_spec_schema.Input(
+                input_stage_location=input_stage_location, input_file_pattern=input_file_pattern
+            ),
+            output=model_deployment_spec_schema.Output(
+                output_stage_location=output_stage_location,
+                completion_filename=completion_filename,
+            ),
             **self._inference_spec,
         )
         return self
snowflake/ml/model/_client/service/model_deployment_spec_schema.py

@@ -35,6 +35,16 @@ class Service(BaseModel):
     inference_engine_spec: Optional[InferenceEngineSpec] = None


+class Input(BaseModel):
+    input_stage_location: str
+    input_file_pattern: str
+
+
+class Output(BaseModel):
+    output_stage_location: str
+    completion_filename: str
+
+
 class Job(BaseModel):
     name: str
     compute_pool: str
@@ -43,10 +53,10 @@ class Job(BaseModel):
     gpu: Optional[str] = None
     num_workers: Optional[int] = None
     max_batch_rows: Optional[int] = None
-    warehouse: str
-    target_method: str
-    input_table_name: str
-    output_table_name: str
+    warehouse: Optional[str] = None
+    function_name: str
+    input: Input
+    output: Output


 class LogModelArgs(BaseModel):
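
A minimal sketch of what the new stage-based job spec looks like when populated. The standalone pydantic models below mirror the Input/Output/Job classes from the diff (the real classes live in model_deployment_spec_schema); the stage paths and names are illustrative only.

    from typing import Optional

    from pydantic import BaseModel


    class Input(BaseModel):
        input_stage_location: str
        input_file_pattern: str


    class Output(BaseModel):
        output_stage_location: str
        completion_filename: str


    class Job(BaseModel):
        name: str
        compute_pool: str
        warehouse: Optional[str] = None
        function_name: str
        input: Input
        output: Output


    # Batch inference now reads from and writes to stage locations instead of tables.
    job = Job(
        name="MY_DB.MY_SCHEMA.MY_JOB",
        compute_pool="MY_POOL",
        function_name="predict",
        input=Input(input_stage_location="@MY_DB.MY_SCHEMA.MY_STAGE/input/", input_file_pattern="*.parquet"),
        output=Output(output_stage_location="@MY_DB.MY_SCHEMA.MY_STAGE/output/", completion_filename="completion.txt"),
    )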
snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py

@@ -1,6 +1,8 @@
 import json
 import logging
 import os
+import time
+import uuid
 import warnings
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final

@@ -11,7 +13,12 @@ from packaging import version
 from typing_extensions import TypeGuard, Unpack

 from snowflake.ml._internal import type_utils
-from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
+from snowflake.ml.model import (
+    custom_model,
+    model_signature,
+    openai_signatures,
+    type_hints as model_types,
+)
 from snowflake.ml.model._packager.model_env import model_env
 from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
 from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
@@ -151,7 +158,10 @@ class HuggingFacePipelineHandler(
         assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
         params = {**model.__dict__, **model.model_kwargs}

-        inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(task, params=params)
+        inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(
+            task,
+            params=params,
+        )

         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
@@ -401,6 +411,34 @@ class HuggingFacePipelineHandler(
                 ),
                 axis=1,
             ).to_list()
+        elif raw_model.task == "text-generation":
+            # verify when the target method is __call__ and
+            # if the signature is default text-generation signature
+            # then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
+            if signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC:
+                wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)
+
+                temp_res = X.apply(
+                    lambda row: wrapped_model.generate_chat_completion(
+                        messages=row["messages"],
+                        max_completion_tokens=row.get("max_completion_tokens", None),
+                        temperature=row.get("temperature", None),
+                        stop_strings=row.get("stop", None),
+                        n=row.get("n", 1),
+                        stream=row.get("stream", False),
+                        top_p=row.get("top_p", 1.0),
+                        frequency_penalty=row.get("frequency_penalty", None),
+                        presence_penalty=row.get("presence_penalty", None),
+                    ),
+                    axis=1,
+                ).to_list()
+            else:
+                if len(signature.inputs) > 1:
+                    input_data = X.to_dict("records")
+                # If it is only expecting one argument, then it is expecting a list of something.
+                else:
+                    input_data = X[signature.inputs[0].name].to_list()
+                temp_res = getattr(raw_model, target_method)(input_data)
         else:
             # For others, we could offer the whole dataframe as a list.
             # Some of them may need some conversion
@@ -527,3 +565,170 @@ class HuggingFacePipelineHandler(
         hg_pipe_model = _HFPipelineModel(custom_model.ModelContext())

         return hg_pipe_model
+
+
+class HuggingFaceOpenAICompatibleModel:
+    """
+    A class to wrap a Hugging Face text generation model and provide an
+    OpenAI-compatible chat completion interface.
+    """
+
+    def __init__(self, pipeline: "transformers.Pipeline") -> None:
+        """
+        Initializes the model and tokenizer.
+
+        Args:
+            pipeline (transformers.pipeline): The Hugging Face pipeline to wrap.
+        """
+
+        self.pipeline = pipeline
+        self.model = self.pipeline.model
+        self.tokenizer = self.pipeline.tokenizer
+
+        self.model_name = self.pipeline.model.name_or_path
+
+    def _apply_chat_template(self, messages: list[dict[str, Any]]) -> str:
+        """
+        Applies a chat template to a list of messages.
+        If the tokenizer has a chat template, it uses that.
+        Otherwise, it falls back to a simple concatenation.
+
+        Args:
+            messages (list[dict]): A list of message dictionaries, e.g.,
+                [{"role": "user", "content": "Hello!"}, ...]
+
+        Returns:
+            The formatted prompt string ready for model input.
+        """
+
+        if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
+            # Use the tokenizer's built-in chat template if available
+            # `tokenize=False` means it returns a string, not token IDs
+            return self.tokenizer.apply_chat_template(  # type: ignore[no-any-return]
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+        else:
+            # Fallback to a simple concatenation for models without a specific chat template
+            # This is a basic example; real chat models often need specific formatting.
+            prompt = ""
+            for message in messages:
+                role = message.get("role", "user")
+                content = message.get("content", "")
+                if role == "system":
+                    prompt += f"System: {content}\n"
+                elif role == "user":
+                    prompt += f"User: {content}\n"
+                elif role == "assistant":
+                    prompt += f"Assistant: {content}\n"
+            prompt += "Assistant:"  # Indicate that the assistant should respond
+            return prompt
+
+    def generate_chat_completion(
+        self,
+        messages: list[dict[str, Any]],
+        max_completion_tokens: Optional[int] = None,
+        stream: Optional[bool] = False,
+        stop_strings: Optional[list[str]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        n: int = 1,
+    ) -> dict[str, Any]:
+        """
+        Generates a chat completion response in an OpenAI-compatible format.
+
+        Args:
+            messages (list[dict]): A list of message dictionaries, e.g.,
+                [{"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "What is deep learning?"}]
+            max_completion_tokens (int): The maximum number of completion tokens to generate.
+            stop_strings (list[str]): A list of strings to stop generation.
+            temperature (float): The temperature for sampling.
+            top_p (float): The top-p value for sampling.
+            stream (bool): Whether to stream the generation.
+            frequency_penalty (float): The frequency penalty for sampling.
+            presence_penalty (float): The presence penalty for sampling.
+            n (int): The number of samples to generate.
+
+        Returns:
+            dict: An OpenAI-compatible dictionary representing the chat completion.
+        """
+        # Apply chat template to convert messages into a single prompt string
+
+        prompt_text = self._apply_chat_template(messages)
+
+        # Tokenize the prompt
+        inputs = self.tokenizer(
+            prompt_text,
+            return_tensors="pt",
+            padding=True,
+        )
+        prompt_tokens = inputs.input_ids.shape[1]
+
+        from transformers import GenerationConfig
+
+        generation_config = GenerationConfig(
+            max_new_tokens=max_completion_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            pad_token_id=self.tokenizer.pad_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            stop_strings=stop_strings,
+            stream=stream,
+            repetition_penalty=frequency_penalty,
+            diversity_penalty=presence_penalty if n > 1 else None,
+            num_return_sequences=n,
+            num_beams=max(2, n),  # must be >1
+            num_beam_groups=max(2, n) if presence_penalty else 1,
+        )
+
+        # Generate text
+        output_ids = self.model.generate(
+            inputs.input_ids,
+            attention_mask=inputs.attention_mask,
+            generation_config=generation_config,
+        )
+
+        generated_texts = []
+        completion_tokens = 0
+        total_tokens = prompt_tokens
+        for output_id in output_ids:
+            # The output_ids include the input prompt
+            # Decode the generated text, excluding the input prompt
+            # so we slice to get only new tokens
+            generated_tokens = output_id[prompt_tokens:]
+            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+            generated_texts.append(generated_text)
+
+            # Calculate completion tokens
+            completion_tokens += len(generated_tokens)
+            total_tokens += len(generated_tokens)
+
+        choices = []
+        for i, generated_text in enumerate(generated_texts):
+            choices.append(
+                {
+                    "index": i,
+                    "message": {"role": "assistant", "content": generated_text},
+                    "logprobs": None,  # Not directly supported in this basic implementation
+                    "finish_reason": "stop",  # Assuming stop for simplicity
+                }
+            )
+
+        # Construct OpenAI-compatible response
+        response = {
+            "id": f"chatcmpl-{uuid.uuid4().hex}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": self.model_name,
+            "choices": choices,
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+        return response
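
A quick way to exercise the new wrapper in isolation (a sketch, assuming transformers is installed and HuggingFaceOpenAICompatibleModel is importable from the handler module; gpt2 is only an illustrative model and ships without a pad token, hence the extra line):

    from transformers import pipeline

    pipe = pipeline("text-generation", model="gpt2")
    pipe.tokenizer.pad_token = pipe.tokenizer.eos_token  # gpt2 has no pad token; padding=True needs one

    wrapped = HuggingFaceOpenAICompatibleModel(pipeline=pipe)
    resp = wrapped.generate_chat_completion(
        messages=[{"role": "user", "content": "What is deep learning?"}],
        max_completion_tokens=32,
    )
    print(resp["choices"][0]["message"]["content"])
    print(resp["usage"])  # prompt_tokens / completion_tokens / total_tokens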
snowflake/ml/model/_packager/model_handlers/sklearn.py

@@ -386,7 +386,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             predictor = model[-1] if isinstance(model, sklearn.pipeline.Pipeline) else model
             try:
                 explainer = shap.Explainer(predictor, transformed_bg_data)
-                return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values)
+                return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values).astype(
+                    np.float64, errors="ignore"
+                )
             except TypeError:
                 if isinstance(data, pd.DataFrame):
                     dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in input_specs}
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py

@@ -14,7 +14,7 @@ REQUIREMENTS = [
     "packaging>=20.9,<25",
     "pandas>=2.1.4,<3",
     "platformdirs<5",
-    "pyarrow",
+    "pyarrow<19.0.0",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",
     "pytimeparse>=1.1.8,<2",
@@ -22,10 +22,10 @@ REQUIREMENTS = [
     "requests",
     "retrying>=1.3.3,<2",
     "s3fs>=2024.6.1,<2026",
-    "scikit-learn<1.6",
+    "scikit-learn<1.7",
     "scipy>=1.9,<2",
     "shap>=0.46.0,<1",
-    "snowflake-connector-python>=3.15.0,<4",
+    "snowflake-connector-python>=3.16.0,<4",
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
snowflake/ml/model/_signatures/snowpark_handler.py

@@ -84,7 +84,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
             return json.loads(x)

         for field in data.schema.fields:
-            if isinstance(field.datatype, spt.ArrayType):
+            if isinstance(field.datatype, (spt.ArrayType, spt.MapType, spt.StructType)):
                 df_local[identifier.get_unescaped_names(field.name)] = df_local[
                     identifier.get_unescaped_names(field.name)
                 ].map(load_if_not_null)
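
The local effect of this change, sketched with plain pandas: Snowpark's to_pandas() renders ARRAY, MAP, and STRUCT (OBJECT) columns as JSON strings, and the handler now decodes all three back into Python values instead of only ARRAY.

    import json

    import pandas as pd

    col = pd.Series(['{"key": "value"}', "[1, 2, 3]", None])  # OBJECT, ARRAY, NULL
    parsed = col.map(lambda x: json.loads(x) if x is not None else None)
    print(parsed[0]["key"], parsed[1], parsed[2])  # value [1, 2, 3] None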
snowflake/ml/model/_signatures/utils.py

@@ -104,7 +104,10 @@ def rename_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec
     return data


-def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any]) -> Optional[core.ModelSignature]:
+def huggingface_pipeline_signature_auto_infer(
+    task: str,
+    params: dict[str, Any],
+) -> Optional[core.ModelSignature]:
     # Text

     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
@@ -297,7 +300,6 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any])
             )
         ],
     )
-
     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.Text2TextGenerationPipeline
     if task == "text2text-generation":
         if params.get("return_tensors", False):
snowflake/ml/model/openai_signatures.py (new file)

@@ -0,0 +1,57 @@
+from snowflake.ml.model._signatures import core
+
+_OPENAI_CHAT_SIGNATURE_SPEC = core.ModelSignature(
+    inputs=[
+        core.FeatureGroupSpec(
+            name="messages",
+            specs=[
+                core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="title", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureSpec(name="temperature", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="max_completion_tokens", dtype=core.DataType.INT64),
+        core.FeatureSpec(name="stop", dtype=core.DataType.STRING, shape=(-1,)),
+        core.FeatureSpec(name="n", dtype=core.DataType.INT32),
+        core.FeatureSpec(name="stream", dtype=core.DataType.BOOL),
+        core.FeatureSpec(name="top_p", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="frequency_penalty", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="presence_penalty", dtype=core.DataType.DOUBLE),
+    ],
+    outputs=[
+        core.FeatureSpec(name="id", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="object", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="created", dtype=core.DataType.FLOAT),
+        core.FeatureSpec(name="model", dtype=core.DataType.STRING),
+        core.FeatureGroupSpec(
+            name="choices",
+            specs=[
+                core.FeatureSpec(name="index", dtype=core.DataType.INT32),
+                core.FeatureGroupSpec(
+                    name="message",
+                    specs=[
+                        core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                    ],
+                ),
+                core.FeatureSpec(name="logprobs", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="finish_reason", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureGroupSpec(
+            name="usage",
+            specs=[
+                core.FeatureSpec(name="completion_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="prompt_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="total_tokens", dtype=core.DataType.INT32),
+            ],
+        ),
+    ],
+)
+
+OPENAI_CHAT_SIGNATURE = {"__call__": _OPENAI_CHAT_SIGNATURE_SPEC}
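
An input batch conforming to this signature looks roughly like the following (a sketch; the column names come straight from the spec above, and the values are illustrative):

    import pandas as pd

    X = pd.DataFrame(
        [
            {
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "What is deep learning?"},
                ],
                "temperature": 0.7,
                "max_completion_tokens": 64,
                "stop": None,
                "n": 1,
                "stream": False,
                "top_p": 1.0,
                "frequency_penalty": None,
                "presence_penalty": None,
            }
        ]
    )
    # The text-generation handler branch above routes each such row through
    # HuggingFaceOpenAICompatibleModel.generate_chat_completion(...).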
snowflake/ml/modeling/_internal/estimator_utils.py

@@ -42,6 +42,26 @@ def validate_sklearn_args(args: dict[str, tuple[Any, Any, bool]], klass: type) -
                 error_code=error_codes.DEPENDENCY_VERSION_ERROR,
                 original_exception=RuntimeError(f"Arg {k} is not supported by current version of SKLearn/XGBoost."),
             )
+        elif v[0] == v[1] and v[0] != signature.parameters[k].default:
+            # If default value (pulled at autogen time) is not the same as the installed library's default value,
+            # we need to validate the parameter value against the parameter constraints.
+            # If the parameter value is invalid, we drop it.
+            try:
+                from sklearn.utils._param_validation import (
+                    InvalidParameterError,
+                    validate_parameter_constraints,
+                )
+
+                try:
+                    validate_parameter_constraints(
+                        klass._parameter_constraints,  # type: ignore[attr-defined]
+                        {k: v[0]},
+                        klass.__name__,
+                    )
+                except InvalidParameterError:
+                    continue  # Let the underlying estimator fill in the default value.
+            except (ImportError, AttributeError, TypeError):
+                result[k] = v[0]  # Try to use the value as is.
         else:
             result[k] = v[0]
     return result
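
The new branch leans on sklearn's parameter-constraint machinery; its behavior can be observed standalone (a sketch; sklearn.utils._param_validation is a private module and may move between releases):

    from sklearn.linear_model import LogisticRegression
    from sklearn.utils._param_validation import (
        InvalidParameterError,
        validate_parameter_constraints,
    )

    try:
        validate_parameter_constraints(
            LogisticRegression._parameter_constraints,
            {"penalty": "l3"},  # not a valid penalty value
            LogisticRegression.__name__,
        )
    except InvalidParameterError as exc:
        # This is the case where validate_sklearn_args drops the stale
        # autogen-time default and lets the estimator fill in its own.
        print(exc)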
@@ -199,7 +219,12 @@
         transformed_numpy_array = np.hstack(transformed_numpy_array)  # type: ignore[call-overload]

     if len(transformed_numpy_array.shape) == 1:
-        transformed_numpy_array = np.reshape(transformed_numpy_array, (-1, 1))
+        # Within a vectorized UDF, a single-row batch often yields a 1D array of length n_components.
+        # That must be reshaped to (1, n_components) to keep the number of rows aligned with the input batch.
+        if len(output_cols) > 1:
+            transformed_numpy_array = np.reshape(transformed_numpy_array, (1, -1))
+        else:
+            transformed_numpy_array = np.reshape(transformed_numpy_array, (-1, 1))

     shape = transformed_numpy_array.shape
     if len(shape) > 1:
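
The reshape fix in miniature (a sketch): a one-row batch with four output columns comes back as shape (4,) and must become (1, 4) to stay row-aligned with the input, not (4, 1).

    import numpy as np

    out = np.array([0.1, 0.2, 0.3, 0.4])  # 1D result from a single-row batch
    output_cols = ["PC1", "PC2", "PC3", "PC4"]

    if len(output_cols) > 1:
        out = np.reshape(out, (1, -1))  # one row, n_components columns
    else:
        out = np.reshape(out, (-1, 1))  # n rows, one output column
    print(out.shape)  # (1, 4)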
@@ -292,3 +317,20 @@
             return True

     return False
+
+
+def is_multi_task_estimator(estimator: object) -> bool:
+    """
+    Check if the estimator is a multi-task estimator that requires 2D targets.
+
+    Args:
+        estimator: The estimator to check
+
+    Returns:
+        True if the estimator is a multi-task estimator, False otherwise
+    """
+    # List of known multi-task estimators that require 2D targets
+    multi_task_estimators = {"MultiTaskElasticNet", "MultiTaskElasticNetCV", "MultiTaskLasso", "MultiTaskLassoCV"}
+
+    estimator_name = estimator.__class__.__name__
+    return estimator_name in multi_task_estimators
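
Why the trainers need this check (a sketch): sklearn's multi-task estimators reject 1D targets, which is exactly what .squeeze() produces for a single label column.

    import numpy as np
    from sklearn.linear_model import MultiTaskLasso

    rng = np.random.default_rng(0)
    X, Y = rng.random((20, 3)), rng.random((20, 2))

    MultiTaskLasso().fit(X, Y)  # fine: 2D targets
    try:
        MultiTaskLasso().fit(X, Y[:, 0])  # 1D targets, as .squeeze() would yield
    except ValueError as exc:
        print(exc)  # multi-task estimators insist on 2D y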
snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py

@@ -3,7 +3,10 @@ from typing import Optional

 import pandas as pd

-from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
+from snowflake.ml.modeling._internal.estimator_utils import (
+    handle_inference_result,
+    is_multi_task_estimator,
+)


 class PandasModelTrainer:
@@ -48,7 +51,11 @@

         if self.label_cols:
             label_arg_name = "Y" if "Y" in params else "y"
-            args[label_arg_name] = self.dataset[self.label_cols].squeeze()
+            # For multi-task estimators, avoid squeezing to maintain 2D shape
+            if is_multi_task_estimator(self.estimator):
+                args[label_arg_name] = self.dataset[self.label_cols]
+            else:
+                args[label_arg_name] = self.dataset[self.label_cols].squeeze()

         if self.sample_weight_col is not None and "sample_weight" in params:
             args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze()
@@ -115,7 +122,11 @@
         args = {"X": self.dataset[self.input_cols]}
         if self.label_cols:
             label_arg_name = "Y" if "Y" in params else "y"
-            args[label_arg_name] = self.dataset[self.label_cols].squeeze()
+            # For multi-task estimators, avoid squeezing to maintain 2D shape
+            if is_multi_task_estimator(self.estimator):
+                args[label_arg_name] = self.dataset[self.label_cols]
+            else:
+                args[label_arg_name] = self.dataset[self.label_cols].squeeze()

         if self.sample_weight_col is not None and "sample_weight" in params:
             args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze()
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py

@@ -22,6 +22,7 @@ from snowflake.ml._internal.utils import (
 from snowflake.ml.modeling._internal import estimator_utils
 from snowflake.ml.modeling._internal.estimator_utils import (
     handle_inference_result,
+    is_multi_task_estimator,
     should_include_sample_weight,
 )
 from snowflake.ml.modeling._internal.model_specifications import (
@@ -178,7 +179,12 @@
             args = {"X": df[input_cols]}
             if label_cols:
                 label_arg_name = "Y" if "Y" in params else "y"
-                args[label_arg_name] = df[label_cols].squeeze()
+                # For multi-task estimators, avoid squeezing to maintain 2D shape
+                if is_multi_task_estimator(estimator):
+                    args[label_arg_name] = df[label_cols]
+                else:
+                    args[label_arg_name] = df[label_cols].squeeze()

             # Sample weight is not included in search estimators parameters, check the underlying estimator.
             if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
@@ -416,7 +421,11 @@
             args = {"X": df[input_cols]}
             if label_cols:
                 label_arg_name = "Y" if "Y" in params else "y"
-                args[label_arg_name] = df[label_cols].squeeze()
+                # For multi-task estimators, avoid squeezing to maintain 2D shape
+                if is_multi_task_estimator(estimator):
+                    args[label_arg_name] = df[label_cols]
+                else:
+                    args[label_arg_name] = df[label_cols].squeeze()

             if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
                 args["sample_weight"] = df[sample_weight_col].squeeze()
@@ -734,12 +743,14 @@
         # Create a temp table in advance to store the output
         # This would allow us to use the same table outside the stored procedure
        df_one_line = dataset.limit(1).to_pandas(statement_params=statement_params)
-        df_one_line[
-            expected_output_cols_list[0]
-        ] = "[0]"  # Add one column as the output_col; this is a dummy value to represent the OBJECT type
+        # Pre-create ALL expected output columns so subsequent writes can target the same schema.
+        # Use a simple dummy string value to represent OBJECT-typed payloads.
+        for out_col in expected_output_cols_list:
+            df_one_line[out_col] = "[0]"
         if drop_input_cols:
+            # When input columns are dropped, the table should only contain the output columns.
             self.session.write_pandas(
-                df_one_line[expected_output_cols_list[0]],
+                df_one_line[expected_output_cols_list],
                 fit_transform_result_name,
                 auto_create_table=True,
                 table_type="temp",
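
The shape of this fix in plain pandas (a sketch): the old code indexed expected_output_cols_list[0], seeding only one output column and handing write_pandas a Series; the new code seeds every output column and writes a DataFrame.

    import pandas as pd

    df_one_line = pd.DataFrame({"FEATURE_1": [0.5]})
    expected_output_cols_list = ["OUTPUT_1", "OUTPUT_2"]

    for out_col in expected_output_cols_list:
        df_one_line[out_col] = "[0]"  # dummy value standing in for the OBJECT payload

    subset = df_one_line[expected_output_cols_list]  # a DataFrame with both columns
    # vs. the old df_one_line[expected_output_cols_list[0]], which is a single Series
    print(type(subset).__name__, list(subset.columns))  # DataFrame ['OUTPUT_1', 'OUTPUT_2']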
snowflake/ml/modeling/** (the same one-line change repeated across the generated estimator modules, files 32 to 182 in the list above)

@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

 INFER_SIGNATURE_MAX_ROWS = 100

-SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
+SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
 # Modeling library estimators require a smaller sklearn version range.
 if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
     raise Exception(