snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (237) hide show
  1. snowflake/cortex/__init__.py +16 -8
  2. snowflake/cortex/_classify_text.py +12 -1
  3. snowflake/cortex/_complete.py +101 -13
  4. snowflake/cortex/_embed_text_1024.py +9 -2
  5. snowflake/cortex/_embed_text_768.py +9 -2
  6. snowflake/cortex/_extract_answer.py +9 -2
  7. snowflake/cortex/_sentiment.py +9 -2
  8. snowflake/cortex/_summarize.py +9 -2
  9. snowflake/cortex/_translate.py +9 -2
  10. snowflake/ml/_internal/env_utils.py +7 -52
  11. snowflake/ml/_internal/platform_capabilities.py +87 -0
  12. snowflake/ml/_internal/utils/identifier.py +4 -2
  13. snowflake/ml/data/__init__.py +3 -0
  14. snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
  15. snowflake/ml/data/data_connector.py +53 -11
  16. snowflake/ml/data/data_ingestor.py +2 -1
  17. snowflake/ml/data/torch_utils.py +18 -5
  18. snowflake/ml/dataset/dataset.py +0 -1
  19. snowflake/ml/feature_store/examples/example_helper.py +2 -1
  20. snowflake/ml/fileset/fileset.py +24 -18
  21. snowflake/ml/jobs/__init__.py +21 -0
  22. snowflake/ml/jobs/_utils/constants.py +51 -0
  23. snowflake/ml/jobs/_utils/payload_utils.py +352 -0
  24. snowflake/ml/jobs/_utils/spec_utils.py +298 -0
  25. snowflake/ml/jobs/_utils/types.py +39 -0
  26. snowflake/ml/jobs/decorators.py +91 -0
  27. snowflake/ml/jobs/job.py +113 -0
  28. snowflake/ml/jobs/manager.py +298 -0
  29. snowflake/ml/model/_client/model/model_version_impl.py +5 -3
  30. snowflake/ml/model/_client/ops/model_ops.py +13 -8
  31. snowflake/ml/model/_client/ops/service_ops.py +1 -11
  32. snowflake/ml/model/_client/sql/model_version.py +11 -0
  33. snowflake/ml/model/_client/sql/service.py +13 -6
  34. snowflake/ml/model/_model_composer/model_composer.py +8 -3
  35. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  37. snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
  38. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
  39. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
  40. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
  41. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  42. snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
  43. snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
  44. snowflake/ml/model/_packager/model_handlers/_utils.py +39 -5
  45. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
  46. snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +6 -1
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
  49. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
  50. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
  51. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
  52. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
  53. snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
  54. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
  55. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  56. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  57. snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
  58. snowflake/ml/model/_signatures/base_handler.py +1 -2
  59. snowflake/ml/model/_signatures/builtins_handler.py +2 -2
  60. snowflake/ml/model/_signatures/numpy_handler.py +6 -7
  61. snowflake/ml/model/_signatures/pandas_handler.py +3 -3
  62. snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
  63. snowflake/ml/model/_signatures/snowpark_handler.py +11 -5
  64. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
  65. snowflake/ml/model/model_signature.py +17 -4
  66. snowflake/ml/model/type_hints.py +1 -0
  67. snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
  68. snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
  69. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
  70. snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
  71. snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
  72. snowflake/ml/modeling/cluster/birch.py +6 -3
  73. snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
  74. snowflake/ml/modeling/cluster/dbscan.py +6 -3
  75. snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
  76. snowflake/ml/modeling/cluster/k_means.py +6 -3
  77. snowflake/ml/modeling/cluster/mean_shift.py +6 -3
  78. snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
  79. snowflake/ml/modeling/cluster/optics.py +6 -3
  80. snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
  81. snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
  82. snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
  83. snowflake/ml/modeling/compose/column_transformer.py +6 -3
  84. snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
  85. snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
  86. snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
  87. snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
  88. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
  89. snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
  90. snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
  91. snowflake/ml/modeling/covariance/oas.py +6 -3
  92. snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
  93. snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
  94. snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
  95. snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
  96. snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
  97. snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
  98. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
  99. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
  100. snowflake/ml/modeling/decomposition/pca.py +6 -3
  101. snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
  102. snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
  103. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
  104. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
  105. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
  106. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
  107. snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
  108. snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
  109. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
  110. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
  111. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
  112. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
  113. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
  114. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
  115. snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
  116. snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
  117. snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
  118. snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
  119. snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
  120. snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
  121. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
  122. snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
  123. snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
  124. snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
  125. snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
  126. snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
  127. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
  128. snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
  129. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
  130. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
  131. snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
  132. snowflake/ml/modeling/impute/knn_imputer.py +6 -3
  133. snowflake/ml/modeling/impute/missing_indicator.py +6 -3
  134. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
  135. snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
  136. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
  137. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
  138. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
  139. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
  140. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
  141. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
  142. snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
  143. snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
  144. snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
  145. snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
  146. snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
  147. snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
  148. snowflake/ml/modeling/linear_model/lars.py +6 -3
  149. snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
  150. snowflake/ml/modeling/linear_model/lasso.py +6 -3
  151. snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
  152. snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
  153. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
  154. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
  155. snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
  156. snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
  157. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
  158. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
  159. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
  160. snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
  161. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
  162. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
  163. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
  164. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
  165. snowflake/ml/modeling/linear_model/perceptron.py +6 -3
  166. snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
  167. snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
  168. snowflake/ml/modeling/linear_model/ridge.py +6 -3
  169. snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
  170. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
  171. snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
  172. snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
  173. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
  174. snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
  175. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
  176. snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
  177. snowflake/ml/modeling/manifold/isomap.py +6 -3
  178. snowflake/ml/modeling/manifold/mds.py +6 -3
  179. snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
  180. snowflake/ml/modeling/manifold/tsne.py +6 -3
  181. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
  182. snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
  183. snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
  184. snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
  185. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
  186. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
  187. snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
  188. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
  189. snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
  190. snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
  191. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
  192. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
  193. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
  194. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
  195. snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
  196. snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
  197. snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
  198. snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
  199. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
  200. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
  201. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
  202. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
  203. snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
  204. snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
  205. snowflake/ml/modeling/pipeline/pipeline.py +16 -178
  206. snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
  209. snowflake/ml/modeling/svm/linear_svc.py +6 -3
  210. snowflake/ml/modeling/svm/linear_svr.py +6 -3
  211. snowflake/ml/modeling/svm/nu_svc.py +6 -3
  212. snowflake/ml/modeling/svm/nu_svr.py +6 -3
  213. snowflake/ml/modeling/svm/svc.py +6 -3
  214. snowflake/ml/modeling/svm/svr.py +6 -3
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +167 -91
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +166 -88
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +166 -88
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +166 -88
  223. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
  224. snowflake/ml/registry/_manager/model_manager.py +70 -33
  225. snowflake/ml/registry/registry.py +41 -22
  226. snowflake/ml/version.py +1 -1
  227. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +63 -19
  228. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +231 -226
  229. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +1 -1
  230. snowflake/ml/_internal/utils/retryable_http.py +0 -39
  231. snowflake/ml/fileset/parquet_parser.py +0 -170
  232. snowflake/ml/fileset/tf_dataset.py +0 -88
  233. snowflake/ml/fileset/torch_datapipe.py +0 -57
  234. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
  235. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
  236. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
  237. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
@@ -88,6 +88,7 @@ class ModelComposer:
88
88
  pip_requirements: Optional[List[str]] = None,
89
89
  target_platforms: Optional[List[model_types.TargetPlatform]] = None,
90
90
  python_version: Optional[str] = None,
91
+ user_files: Optional[Dict[str, List[str]]] = None,
91
92
  ext_modules: Optional[List[ModuleType]] = None,
92
93
  code_paths: Optional[List[str]] = None,
93
94
  task: model_types.Task = model_types.Task.UNKNOWN,
@@ -97,9 +98,12 @@ class ModelComposer:
97
98
  options = model_types.BaseModelSaveOption()
98
99
 
99
100
  if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
100
- snowml_matched_versions = env_utils.get_matched_package_versions_in_snowflake_conda_channel(
101
- req=requirements.Requirement(f"snowflake-ml-python=={snowml_env.VERSION}")
102
- )
101
+ snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
102
+ self.session,
103
+ reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_env.VERSION}")],
104
+ python_version=python_version or snowml_env.PYTHON_VERSION,
105
+ statement_params=self._statement_params,
106
+ ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
103
107
 
104
108
  if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
105
109
  logging.info(
@@ -131,6 +135,7 @@ class ModelComposer:
131
135
  model_meta=self.packager.meta,
132
136
  model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
133
137
  options=options,
138
+ user_files=user_files,
134
139
  data_sources=self._get_data_sources(model, sample_input_data),
135
140
  target_platforms=target_platforms,
136
141
  )
@@ -2,7 +2,7 @@ import collections
2
2
  import logging
3
3
  import pathlib
4
4
  import warnings
5
- from typing import List, Optional, cast
5
+ from typing import Dict, List, Optional, cast
6
6
 
7
7
  import yaml
8
8
 
@@ -11,9 +11,11 @@ from snowflake.ml.data import data_source
11
11
  from snowflake.ml.model import type_hints
12
12
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
13
13
  from snowflake.ml.model._model_composer.model_method import (
14
+ constants,
14
15
  function_generator,
15
16
  model_method,
16
17
  )
18
+ from snowflake.ml.model._model_composer.model_user_file import model_user_file
17
19
  from snowflake.ml.model._packager.model_meta import (
18
20
  model_meta as model_meta_api,
19
21
  model_meta_schema,
@@ -30,9 +32,11 @@ class ModelManifest:
30
32
  workspace_path: A local path where model related files should be dumped to.
31
33
  runtimes: A list of ModelRuntime objects managing the runtimes and environment in the MODEL object.
32
34
  methods: A list of ModelMethod objects managing the method we registered to the MODEL object.
35
+ user_files: A list of ModelUserFile objects managing extra files uploaded to the workspace.
33
36
  """
34
37
 
35
38
  MANIFEST_FILE_REL_PATH = "MANIFEST.yml"
39
+ _ENABLE_USER_FILES = False
36
40
  _DEFAULT_RUNTIME_NAME = "python_runtime"
37
41
 
38
42
  def __init__(self, workspace_path: pathlib.Path) -> None:
@@ -42,6 +46,7 @@ class ModelManifest:
42
46
  self,
43
47
  model_meta: model_meta_api.ModelMetadata,
44
48
  model_rel_path: pathlib.PurePosixPath,
49
+ user_files: Optional[Dict[str, List[str]]] = None,
45
50
  options: Optional[type_hints.ModelSaveOption] = None,
46
51
  data_sources: Optional[List[data_source.DataSource]] = None,
47
52
  target_platforms: Optional[List[type_hints.TargetPlatform]] = None,
@@ -79,6 +84,7 @@ class ModelManifest:
79
84
 
80
85
  self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
81
86
  self.methods: List[model_method.ModelMethod] = []
87
+
82
88
  for target_method in model_meta.signatures.keys():
83
89
  method = model_method.ModelMethod(
84
90
  model_meta=model_meta,
@@ -88,11 +94,21 @@ class ModelManifest:
88
94
  is_partitioned_function=model_meta.function_properties.get(target_method, {}).get(
89
95
  model_meta_schema.FunctionProperties.PARTITIONED.value, False
90
96
  ),
97
+ wide_input=len(model_meta.signatures[target_method].inputs) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT,
91
98
  options=model_method.get_model_method_options_from_options(options, target_method),
92
99
  )
93
100
 
94
101
  self.methods.append(method)
95
102
 
103
+ self.user_files: List[model_user_file.ModelUserFile] = []
104
+
105
+ if user_files is not None:
106
+ for subdirectory, paths in user_files.items():
107
+ for path in paths:
108
+ self.user_files.append(
109
+ model_user_file.ModelUserFile(pathlib.PurePosixPath(subdirectory), pathlib.Path(path))
110
+ )
111
+
96
112
  method_name_counter = collections.Counter([method.method_name for method in self.methods])
97
113
  dup_method_names = [k for k, v in method_name_counter.items() if v > 1]
98
114
  if dup_method_names:
@@ -129,6 +145,9 @@ class ModelManifest:
129
145
  ],
130
146
  )
131
147
 
148
+ if self._ENABLE_USER_FILES:
149
+ manifest_dict["user_files"] = [user_file.save(self.workspace_path) for user_file in self.user_files]
150
+
132
151
  lineage_sources = self._extract_lineage_info(data_sources)
133
152
  if lineage_sources:
134
153
  manifest_dict["lineage_sources"] = lineage_sources
@@ -94,5 +94,6 @@ class ModelManifestDict(TypedDict):
94
94
  runtimes: Required[Dict[str, ModelRuntimeDict]]
95
95
  methods: Required[List[ModelMethodDict]]
96
96
  user_data: NotRequired[Dict[str, Any]]
97
+ user_files: NotRequired[List[str]]
97
98
  lineage_sources: NotRequired[List[LineageSourceDict]]
98
99
  target_platforms: NotRequired[List[str]]
@@ -0,0 +1 @@
1
+ SNOWPARK_UDF_INPUT_COL_LIMIT = 500  # signatures with more inputs than this are passed as a single OBJECT column ("wide_input")
@@ -43,6 +43,7 @@ class FunctionGenerator:
43
43
  target_method: str,
44
44
  function_type: str,
45
45
  is_partitioned_function: bool = False,
46
+ wide_input: bool = False,
46
47
  options: Optional[FunctionGenerateOptions] = None,
47
48
  ) -> None:
48
49
  import importlib_resources
@@ -70,6 +71,7 @@ class FunctionGenerator:
70
71
  model_dir_name=self.model_dir_rel_path.name,
71
72
  target_method=target_method,
72
73
  max_batch_size=options.get("max_batch_size", None),
74
+ wide_input=wide_input,
73
75
  function_name=FunctionGenerator.FUNCTION_NAME,
74
76
  )
75
77
  with open(function_file_path, "w", encoding="utf-8") as f:
@@ -43,7 +43,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
43
43
 
44
44
 
45
45
  # Actual function
46
- @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
46
+ @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
47
47
  def {function_name}(df: pd.DataFrame) -> dict:
48
48
  df.columns = input_cols
49
49
  input_df = df.astype(dtype=dtype_map)
@@ -48,7 +48,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
48
48
 
49
49
  # Actual table function
50
50
  class {function_name}:
51
- @vectorized(input=pd.DataFrame)
51
+ @vectorized(input=pd.DataFrame, flatten_object_input={wide_input})
52
52
  def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
53
53
  df.columns = input_cols
54
54
  input_df = df.astype(dtype=dtype_map)
@@ -43,7 +43,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
43
43
 
44
44
  # Actual table function
45
45
  class {function_name}:
46
- @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
46
+ @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
47
47
  def process(self, df: pd.DataFrame) -> pd.DataFrame:
48
48
  df.columns = input_cols
49
49
  input_df = df.astype(dtype=dtype_map)
@@ -7,7 +7,10 @@ from typing_extensions import NotRequired
7
7
  from snowflake.ml._internal.utils import sql_identifier
8
8
  from snowflake.ml.model import model_signature, type_hints
9
9
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
10
- from snowflake.ml.model._model_composer.model_method import function_generator
10
+ from snowflake.ml.model._model_composer.model_method import (
11
+ constants,
12
+ function_generator,
13
+ )
11
14
  from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
12
15
  from snowflake.snowpark._internal import type_utils
13
16
 
@@ -64,6 +67,7 @@ class ModelMethod:
64
67
  runtime_name: str,
65
68
  function_generator: function_generator.FunctionGenerator,
66
69
  is_partitioned_function: bool = False,
70
+ wide_input: bool = False,
67
71
  options: Optional[ModelMethodOptions] = None,
68
72
  ) -> None:
69
73
  self.model_meta = model_meta
@@ -71,6 +75,7 @@ class ModelMethod:
71
75
  self.function_generator = function_generator
72
76
  self.is_partitioned_function = is_partitioned_function
73
77
  self.runtime_name = runtime_name
78
+ self.wide_input = wide_input
74
79
  self.options = options or {}
75
80
  try:
76
81
  self.method_name = sql_identifier.SqlIdentifier(
@@ -114,12 +119,15 @@ class ModelMethod:
114
119
  self.target_method,
115
120
  self.function_type,
116
121
  self.is_partitioned_function,
122
+ self.wide_input,
117
123
  options=options,
118
124
  )
119
125
  input_list = [
120
126
  ModelMethod._get_method_arg_from_feature(ft, case_sensitive=self.options.get("case_sensitive", False))
121
127
  for ft in self.model_meta.signatures[self.target_method].inputs
122
128
  ]
129
+ if len(input_list) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT:
130
+ input_list = [{"name": "INPUT", "type": "OBJECT"}]
123
131
  input_name_counter = collections.Counter([input_info["name"] for input_info in input_list])
124
132
  dup_input_names = [k for k, v in input_name_counter.items() if v > 1]
125
133
  if dup_input_names:
@@ -0,0 +1,27 @@
1
+ import os
2
+ import pathlib
3
+
4
+ from snowflake.ml._internal import file_utils
5
+
6
+
7
+ class ModelUserFile:
8
+ """Class representing a user provided file.
9
+
10
+ Attributes:
11
+ subdirectory_name: Relative subdirectory under the workspace's "user_files" directory where the file is placed.
12
+ local_path: Local path of the user-provided file (or directory) that is copied into the workspace.
13
+ """
14
+
15
+ USER_FILES_DIR_REL_PATH = "user_files"
16
+
17
+ def __init__(self, subdirectory_name: pathlib.PurePosixPath, local_path: pathlib.Path) -> None:
18
+ self.subdirectory_name = subdirectory_name
19
+ self.local_path = local_path
20
+
21
+ def save(self, workspace_path: pathlib.Path) -> str:
22
+ user_files_path = workspace_path / ModelUserFile.USER_FILES_DIR_REL_PATH / self.subdirectory_name
23
+ user_files_path.mkdir(parents=True, exist_ok=True)
24
+
25
+ # copy the user file into <workspace_path>/user_files/<subdirectory_name>
26
+ file_utils.copy_file_or_tree(str(self.local_path), str(user_files_path))
27
+ return os.path.join(self.subdirectory_name, self.local_path.name)  # relative entry "<subdirectory>/<name>" recorded in the manifest; NOTE(review): os.path.join uses the host separator ("\\" on Windows) although subdirectory_name is a PurePosixPath — confirm entries should be POSIX
@@ -1,7 +1,8 @@
1
1
  import json
2
2
  import os
3
+ import pathlib
3
4
  import warnings
4
- from typing import Any, Callable, Iterable, List, Optional, Sequence, cast
5
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, cast
5
6
 
6
7
  import numpy as np
7
8
  import numpy.typing as npt
@@ -37,8 +38,10 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> boo
37
38
  return callable(getattr(model, method_name, None))
38
39
 
39
40
 
40
- def get_truncated_sample_data(sample_input_data: model_types.SupportedDataType) -> model_types.SupportedLocalDataType:
41
- trunc_sample_input = model_signature._truncate_data(sample_input_data)
41
+ def get_truncated_sample_data(
42
+ sample_input_data: model_types.SupportedDataType, length: int = 100
43
+ ) -> model_types.SupportedLocalDataType:
44
+ trunc_sample_input = model_signature._truncate_data(sample_input_data, length=length)
42
45
  local_sample_input: model_types.SupportedLocalDataType = None
43
46
  if isinstance(sample_input_data, SnowparkDataFrame):
44
47
  # Added because of Any from missing stubs.
@@ -77,7 +80,14 @@ def validate_signature(
77
80
  local_sample_input = get_truncated_sample_data(sample_input_data)
78
81
  for target_method in target_methods:
79
82
  predictions_df = get_prediction_fn(target_method, local_sample_input)
80
- sig = model_signature.infer_signature(local_sample_input, predictions_df)
83
+ sig = model_signature.infer_signature(
84
+ sample_input_data,
85
+ predictions_df,
86
+ input_feature_names=None,
87
+ output_feature_names=None,
88
+ input_data_limit=100,
89
+ output_data_limit=100,
90
+ )
81
91
  model_meta.signatures[target_method] = sig
82
92
 
83
93
  return model_meta
@@ -118,7 +128,7 @@ def get_explainability_supported_background(
118
128
  meta: model_meta.ModelMetadata,
119
129
  explain_target_method: Optional[str],
120
130
  ) -> pd.DataFrame:
121
- if sample_input_data is None:
131
+ if sample_input_data is None or explain_target_method is None:
122
132
  return None
123
133
 
124
134
  if isinstance(sample_input_data, pd.DataFrame):
@@ -223,3 +233,27 @@ def get_explain_target_method(
223
233
  if method in target_methods_list:
224
234
  return method
225
235
  return None
236
+
237
+
238
+ def save_transformers_config_with_auto_map(local_model_path: str) -> None:
+ """Strip the "<repo>--" prefix from auto_map entries in config.json / tokenizer_config.json.
+
+ For every auto_map value containing "--", the part before the last "--" is used as a
+ Hugging Face Hub repo id: that repo's snapshot is downloaded into local_model_path and the
+ entry is rewritten to the remaining class path. Values without "--" are left unchanged.
+ Each of the two files that is present is rewritten with json.dump, even when no entry
+ changed, so its original formatting is not preserved.
+
+ Args:
+ local_model_path: Directory into which the model files were saved.
+ """
239
+ import huggingface_hub
240
+
241
+ for f_path in pathlib.Path(local_model_path).iterdir():
242
+ if f_path.name in ["config.json", "tokenizer_config.json"]:
243
+ with open(f_path) as f:  # NOTE(review): no explicit encoding; relies on the platform default
244
+ config_dict = json.load(f)
245
+
246
+ # a. get repository and class_path from configs
247
+ auto_map_configs = cast(Dict[str, str], config_dict.get("auto_map", {}))
248
+ for config_name, config_value in auto_map_configs.items():
249
+ repository, _, class_path = config_value.rpartition("--")  # NOTE(review): tokenizer_config.json auto_map values may be lists (e.g. [slow, fast] classes); .rpartition would raise on those — confirm
250
+
251
+ # b. download required configs from hf hub
252
+ if repository:
253
+ huggingface_hub.snapshot_download(repo_id=repository, local_dir=local_model_path)  # runs once per entry, so the same repo may be downloaded repeatedly
254
+
255
+ # c. update config files
256
+ config_dict["auto_map"][config_name] = class_path
257
+
258
+ with open(f_path, "w") as f:
259
+ json.dump(config_dict, f)
@@ -94,8 +94,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
94
94
  sample_input_data=sample_input_data,
95
95
  get_prediction_fn=get_prediction,
96
96
  )
97
- model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
98
- model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
97
+ model_task_and_output = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
98
+ model_meta.task = model_task_and_output.task
99
99
  if enable_explainability:
100
100
  explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
101
101
  model_meta = handlers_utils.add_explain_method_signature(
@@ -227,7 +227,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
227
227
  import shap
228
228
 
229
229
  explainer = shap.TreeExplainer(raw_model)
230
- df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
230
+ df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
231
231
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
232
232
 
233
233
  if target_method == "explain":
@@ -66,7 +66,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
66
66
  sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
67
67
 
68
68
  if inspect.iscoroutinefunction(target_method):
69
- with anyio.start_blocking_portal() as portal:
69
+ with anyio.from_thread.start_blocking_portal() as portal:
70
70
  predictions_df = portal.call(target_method, model, sample_input_data)
71
71
  else:
72
72
  predictions_df = target_method(model, sample_input_data)
@@ -98,7 +98,6 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
98
98
  if model.context.model_refs:
99
99
  for sub_name, model_ref in model.context.model_refs.items():
100
100
  handler = model_handler.find_handler(model_ref.model)
101
- assert handler is not None
102
101
  if handler is None:
103
102
  raise TypeError("Your input type to custom model is not currently supported")
104
103
  sub_model = handler.cast_model(model_ref.model)
@@ -195,8 +195,12 @@ class HuggingFacePipelineHandler(
195
195
  os.makedirs(model_blob_path, exist_ok=True)
196
196
 
197
197
  if type_utils.LazyType("transformers.Pipeline").isinstance(model):
198
+ save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
198
199
  model.save_pretrained( # type:ignore[attr-defined]
199
- os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
200
+ save_path
201
+ )
202
+ handlers_utils.save_transformers_config_with_auto_map(
203
+ save_path,
200
204
  )
201
205
  pipeline_params = {
202
206
  "_batch_size": model._batch_size, # type:ignore[attr-defined]
@@ -319,6 +323,7 @@ class HuggingFacePipelineHandler(
319
323
  model_blob_options["task"],
320
324
  model=model_blob_file_or_dir_path,
321
325
  trust_remote_code=True,
326
+ torch_dtype="auto",
322
327
  **device_config,
323
328
  )
324
329
 
@@ -110,8 +110,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
110
110
  sample_input_data=sample_input_data,
111
111
  get_prediction_fn=get_prediction,
112
112
  )
113
- model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
114
- model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
113
+ model_task_and_output = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
114
+ model_meta.task = model_task_and_output.task
115
115
  if enable_explainability:
116
116
  explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
117
117
  model_meta = handlers_utils.add_explain_method_signature(
@@ -240,7 +240,9 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
240
240
  import shap
241
241
 
242
242
  explainer = shap.TreeExplainer(raw_model)
243
- df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
243
+ df = handlers_utils.convert_explanations_to_2D_df(
244
+ raw_model, explainer.shap_values(X, from_call=True)
245
+ )
244
246
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
245
247
 
246
248
  if target_method == "explain":
@@ -14,8 +14,8 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
14
14
  from snowflake.ml.model._packager.model_meta import (
15
15
  model_blob_meta,
16
16
  model_meta as model_meta_api,
17
+ model_meta_schema,
17
18
  )
18
- from snowflake.ml.model._signatures import utils as model_signature_utils
19
19
  from snowflake.snowpark._internal import utils as snowpark_utils
20
20
 
21
21
  if TYPE_CHECKING:
@@ -24,6 +24,25 @@ if TYPE_CHECKING:
24
24
  logger = logging.getLogger(__name__)
25
25
 
26
26
 
27
+ def _validate_sentence_transformers_signatures(sigs: Dict[str, model_signature.ModelSignature]) -> None:
28
+ if list(sigs.keys()) != ["encode"]:
29
+ raise ValueError("target_methods can only be ['encode']")
30
+
31
+ if len(sigs["encode"].inputs) != 1:
32
+ raise ValueError("SentenceTransformer can only accept 1 input column")
33
+
34
+ if len(sigs["encode"].outputs) != 1:
35
+ raise ValueError("SentenceTransformer can only return 1 output column")
36
+
37
+ assert isinstance(sigs["encode"].inputs[0], model_signature.FeatureSpec)
38
+
39
+ if sigs["encode"].inputs[0]._shape is not None:
40
+ raise ValueError("SentenceTransformer does not support input shape")
41
+
42
+ if sigs["encode"].inputs[0]._dtype != model_signature.DataType.STRING:
43
+ raise ValueError("SentenceTransformer only accepts string input")
44
+
45
+
27
46
  @final
28
47
  class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
29
48
  HANDLER_TYPE = "sentence_transformers"
@@ -68,6 +87,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
68
87
  if enable_explainability:
69
88
  raise NotImplementedError("Explainability is not supported for Sentence Transformer model.")
70
89
 
90
+ batch_size = kwargs.get("batch_size", 32)
91
+ if not isinstance(batch_size, int) or batch_size <= 0:
92
+ raise ValueError("batch_size must be a positive integer")
93
+
71
94
  # Validate target methods and signature (if possible)
72
95
  if not is_sub_model:
73
96
  target_methods = handlers_utils.get_target_methods(
@@ -75,12 +98,23 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
75
98
  target_methods=kwargs.pop("target_methods", None),
76
99
  default_target_methods=cls.DEFAULT_TARGET_METHODS,
77
100
  )
78
- assert target_methods == ["encode"], "target_methods can only be ['encode']"
101
+ if target_methods != ["encode"]:
102
+ raise ValueError("target_methods can only be ['encode']")
79
103
 
80
104
  def get_prediction(
81
105
  target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
82
106
  ) -> model_types.SupportedLocalDataType:
83
- return _sentence_transformer_encode(model, sample_input_data)
107
+ if not isinstance(sample_input_data, pd.DataFrame):
108
+ sample_input_data = model_signature._convert_local_data_to_df(data=sample_input_data)
109
+
110
+ if sample_input_data.shape[1] != 1:
111
+ raise ValueError(
112
+ "SentenceTransformer can only accept 1 input column when converted to pd.DataFrame"
113
+ )
114
+ X_list = sample_input_data.iloc[:, 0].tolist()
115
+
116
+ assert callable(getattr(model, "encode", None))
117
+ return pd.DataFrame({0: model.encode(X_list, batch_size=batch_size).tolist()})
84
118
 
85
119
  if model_meta.signatures:
86
120
  handlers_utils.validate_target_methods(model, list(model_meta.signatures.keys()))
@@ -102,10 +136,16 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
102
136
  get_prediction_fn=get_prediction,
103
137
  )
104
138
 
139
+ _validate_sentence_transformers_signatures(model_meta.signatures)
140
+
105
141
  # save model
106
142
  model_blob_path = os.path.join(model_blobs_dir_path, name)
107
143
  os.makedirs(model_blob_path, exist_ok=True)
108
- model.save(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
144
+ save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
145
+ model.save(save_path)
146
+ handlers_utils.save_transformers_config_with_auto_map(
147
+ save_path,
148
+ )
109
149
 
110
150
  # save model metadata
111
151
  base_meta = model_blob_meta.ModelBlobMeta(
@@ -113,6 +153,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
113
153
  model_type=cls.HANDLER_TYPE,
114
154
  handler_version=cls.HANDLER_VERSION,
115
155
  path=cls.MODEL_BLOB_FILE_OR_DIR,
156
+ options=model_meta_schema.SentenceTransformersModelBlobOptions(batch_size=batch_size),
116
157
  )
117
158
  model_meta.models[name] = base_meta
118
159
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -149,6 +190,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
149
190
  if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
150
191
  # We need to redirect the same folders to a writable location in the sandbox.
151
192
  os.environ["TRANSFORMERS_CACHE"] = "/tmp"
193
+ os.environ["HF_HOME"] = "/tmp"
152
194
 
153
195
  model_blob_path = os.path.join(model_blobs_dir_path, name)
154
196
  model_blobs_metadata = model_meta.models
@@ -183,6 +225,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
183
225
  raw_model: "sentence_transformers.SentenceTransformer",
184
226
  model_meta: model_meta_api.ModelMetadata,
185
227
  ) -> Type[custom_model.CustomModel]:
228
+ batch_size = cast(
229
+ model_meta_schema.SentenceTransformersModelBlobOptions, model_meta.models[model_meta.name].options
230
+ ).get("batch_size", None)
231
+
186
232
  def get_prediction(
187
233
  raw_model: "sentence_transformers.SentenceTransformer",
188
234
  signature: model_signature.ModelSignature,
@@ -190,8 +236,11 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
190
236
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
191
237
  @custom_model.inference_api
192
238
  def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
193
- predictions_df = _sentence_transformer_encode(raw_model, X)
194
- return model_signature_utils.rename_pandas_df(predictions_df, signature.outputs)
239
+ X_list = X.iloc[:, 0].tolist()
240
+
241
+ return pd.DataFrame(
242
+ {signature.outputs[0].name: raw_model.encode(X_list, batch_size=batch_size).tolist()}
243
+ )
195
244
 
196
245
  return fn
197
246
 
@@ -217,17 +266,3 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
217
266
  predict_method = getattr(sentence_transformers_SentenceTransformer_model, "encode", None)
218
267
  assert callable(predict_method)
219
268
  return sentence_transformers_SentenceTransformer_model
220
-
221
-
222
- def _sentence_transformer_encode(
223
- model: "sentence_transformers.SentenceTransformer", X: model_types.SupportedLocalDataType
224
- ) -> model_types.SupportedLocalDataType:
225
-
226
- if not isinstance(X, pd.DataFrame):
227
- X = model_signature._convert_local_data_to_df(X)
228
-
229
- assert X.shape[1] == 1, "SentenceTransformer can only accept 1 input column when converted to pd.DataFrame"
230
- X_list = X.iloc[:, 0].tolist()
231
-
232
- assert callable(getattr(model, "encode", None))
233
- return pd.DataFrame({0: model.encode(X_list, batch_size=X.shape[0]).tolist()})
@@ -152,8 +152,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
152
152
  sample_input_data, model_meta, explain_target_method
153
153
  )
154
154
 
155
- model_task_and_output_type = model_task_utils.get_model_task_and_output_type(model)
156
- model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
155
+ model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
156
+ model_meta.task = model_task_and_output_type.task
157
157
 
158
158
  # if users did not ask then we enable if we have background data
159
159
  if enable_explainability is None:
@@ -164,11 +164,17 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
164
164
  stacklevel=1,
165
165
  )
166
166
  enable_explainability = False
167
- elif model_meta.task == model_types.Task.UNKNOWN:
167
+ elif model_meta.task == model_types.Task.UNKNOWN or explain_target_method is None:
168
168
  enable_explainability = False
169
169
  else:
170
170
  enable_explainability = True
171
171
  if enable_explainability:
172
+ model_meta = handlers_utils.add_explain_method_signature(
173
+ model_meta=model_meta,
174
+ explain_method="explain",
175
+ target_method=explain_target_method,
176
+ output_return_type=model_task_and_output_type.output_type,
177
+ )
172
178
  handlers_utils.save_background_data(
173
179
  model_blobs_dir_path,
174
180
  cls.EXPLAIN_ARTIFACTS_DIR,
@@ -177,13 +183,6 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
177
183
  background_data,
178
184
  )
179
185
 
180
- model_meta = handlers_utils.add_explain_method_signature(
181
- model_meta=model_meta,
182
- explain_method="explain",
183
- target_method=explain_target_method,
184
- output_return_type=model_task_and_output_type.output_type,
185
- )
186
-
187
186
  model_blob_path = os.path.join(model_blobs_dir_path, name)
188
187
  os.makedirs(model_blob_path, exist_ok=True)
189
188
  with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f: