snowflake-ml-python 1.6.2__py3-none-any.whl → 1.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/utils/db_utils.py +50 -0
  13. snowflake/ml/_internal/utils/service_logger.py +63 -0
  14. snowflake/ml/_internal/utils/sql_identifier.py +25 -1
  15. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  16. snowflake/ml/data/ingestor_utils.py +20 -10
  17. snowflake/ml/feature_store/access_manager.py +3 -3
  18. snowflake/ml/feature_store/feature_store.py +19 -2
  19. snowflake/ml/feature_store/feature_view.py +82 -28
  20. snowflake/ml/fileset/stage_fs.py +2 -1
  21. snowflake/ml/lineage/lineage_node.py +7 -2
  22. snowflake/ml/model/__init__.py +1 -2
  23. snowflake/ml/model/_client/model/model_version_impl.py +78 -9
  24. snowflake/ml/model/_client/ops/model_ops.py +89 -7
  25. snowflake/ml/model/_client/ops/service_ops.py +200 -91
  26. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -0
  27. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  28. snowflake/ml/model/_client/sql/_base.py +5 -0
  29. snowflake/ml/model/_client/sql/model.py +1 -0
  30. snowflake/ml/model/_client/sql/model_version.py +9 -5
  31. snowflake/ml/model/_client/sql/service.py +35 -13
  32. snowflake/ml/model/_model_composer/model_composer.py +11 -41
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +29 -4
  34. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  35. snowflake/ml/model/_packager/model_handlers/_utils.py +106 -32
  36. snowflake/ml/model/_packager/model_handlers/catboost.py +26 -27
  37. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -3
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +21 -6
  39. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  40. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +111 -58
  41. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  42. snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -66
  43. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +36 -17
  44. snowflake/ml/model/_packager/model_handlers/xgboost.py +22 -7
  45. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -45
  46. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -6
  47. snowflake/ml/model/_packager/model_packager.py +14 -10
  48. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  49. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  50. snowflake/ml/model/type_hints.py +11 -152
  51. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +0 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  53. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  54. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -0
  55. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -0
  56. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -0
  57. snowflake/ml/modeling/cluster/birch.py +1 -0
  58. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -0
  59. snowflake/ml/modeling/cluster/dbscan.py +1 -0
  60. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -0
  61. snowflake/ml/modeling/cluster/k_means.py +1 -0
  62. snowflake/ml/modeling/cluster/mean_shift.py +1 -0
  63. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -0
  64. snowflake/ml/modeling/cluster/optics.py +1 -0
  65. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -0
  66. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -0
  67. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -0
  68. snowflake/ml/modeling/compose/column_transformer.py +1 -0
  69. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -0
  70. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -0
  71. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -0
  72. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -0
  73. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -0
  74. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -0
  75. snowflake/ml/modeling/covariance/min_cov_det.py +1 -0
  76. snowflake/ml/modeling/covariance/oas.py +1 -0
  77. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -0
  78. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -0
  79. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -0
  80. snowflake/ml/modeling/decomposition/fast_ica.py +1 -0
  81. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -0
  82. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -0
  83. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -0
  84. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -0
  85. snowflake/ml/modeling/decomposition/pca.py +1 -0
  86. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -0
  87. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -0
  88. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -0
  89. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -0
  90. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -0
  91. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -0
  92. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -0
  93. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -0
  94. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -0
  95. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -0
  96. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -0
  97. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -0
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -0
  99. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -0
  100. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -0
  101. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -0
  102. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -0
  103. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -0
  104. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -0
  105. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -0
  106. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -0
  107. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -0
  108. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -0
  109. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -0
  110. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -0
  111. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -0
  112. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -0
  113. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -0
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -0
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -0
  116. snowflake/ml/modeling/impute/iterative_imputer.py +1 -0
  117. snowflake/ml/modeling/impute/knn_imputer.py +1 -0
  118. snowflake/ml/modeling/impute/missing_indicator.py +1 -0
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -0
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -0
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -0
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -0
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -0
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -0
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -0
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -0
  127. snowflake/ml/modeling/linear_model/ard_regression.py +1 -0
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -0
  129. snowflake/ml/modeling/linear_model/elastic_net.py +1 -0
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -0
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -0
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -0
  133. snowflake/ml/modeling/linear_model/lars.py +1 -0
  134. snowflake/ml/modeling/linear_model/lars_cv.py +1 -0
  135. snowflake/ml/modeling/linear_model/lasso.py +1 -0
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -0
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -0
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -0
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -0
  140. snowflake/ml/modeling/linear_model/linear_regression.py +1 -0
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -0
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -0
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -0
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -0
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -0
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -0
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -0
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -0
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -0
  150. snowflake/ml/modeling/linear_model/perceptron.py +1 -0
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -0
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -0
  153. snowflake/ml/modeling/linear_model/ridge.py +1 -0
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -0
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -0
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -0
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -0
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -0
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -0
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -0
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -0
  162. snowflake/ml/modeling/manifold/isomap.py +1 -0
  163. snowflake/ml/modeling/manifold/mds.py +1 -0
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -0
  165. snowflake/ml/modeling/manifold/tsne.py +1 -0
  166. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  167. snowflake/ml/modeling/metrics/ranking.py +0 -3
  168. snowflake/ml/modeling/metrics/regression.py +0 -3
  169. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -0
  170. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -0
  171. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -0
  172. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -0
  173. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -0
  174. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -0
  175. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -0
  176. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -0
  177. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -0
  178. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -0
  179. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -0
  180. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -0
  181. snowflake/ml/modeling/neighbors/kernel_density.py +1 -0
  182. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -0
  183. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -0
  184. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -0
  185. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -0
  186. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -0
  187. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -0
  188. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -0
  189. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -0
  190. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -0
  191. snowflake/ml/modeling/pipeline/pipeline.py +0 -1
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -0
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -0
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -0
  195. snowflake/ml/modeling/svm/linear_svc.py +1 -0
  196. snowflake/ml/modeling/svm/linear_svr.py +1 -0
  197. snowflake/ml/modeling/svm/nu_svc.py +1 -0
  198. snowflake/ml/modeling/svm/nu_svr.py +1 -0
  199. snowflake/ml/modeling/svm/svc.py +1 -0
  200. snowflake/ml/modeling/svm/svr.py +1 -0
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -0
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -0
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -0
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -0
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -0
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -0
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -0
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -0
  209. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  210. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  211. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  212. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  213. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  214. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  215. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  216. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  217. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  218. snowflake/ml/registry/_manager/model_manager.py +4 -4
  219. snowflake/ml/registry/registry.py +165 -6
  220. snowflake/ml/version.py +1 -1
  221. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/METADATA +30 -9
  222. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/RECORD +225 -249
  223. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/WHEEL +1 -1
  224. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  225. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  226. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  227. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  228. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  229. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  230. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  231. snowflake/ml/_internal/utils/uri.py +0 -77
  232. snowflake/ml/model/_api.py +0 -568
  233. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  234. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  235. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  236. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  237. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  238. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  239. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  240. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  241. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  242. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  243. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  244. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  245. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  246. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  247. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  248. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  249. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  250. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  251. snowflake/ml/model/_packager/model_handlers/llm.py +0 -269
  252. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  253. snowflake/ml/model/deploy_platforms.py +0 -6
  254. snowflake/ml/model/models/llm.py +0 -106
  255. snowflake/ml/monitoring/monitor.py +0 -203
  256. snowflake/ml/registry/_initial_schema.py +0 -142
  257. snowflake/ml/registry/_schema.py +0 -82
  258. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  259. snowflake/ml/registry/_schema_version_manager.py +0 -163
  260. snowflake/ml/registry/model_registry.py +0 -2048
  261. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/LICENSE.txt +0 -0
  262. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,11 @@
1
- import glob
2
1
  import pathlib
3
2
  import tempfile
4
3
  import uuid
5
- import zipfile
6
4
  from types import ModuleType
7
5
  from typing import Any, Dict, List, Optional
8
6
 
9
7
  from absl import logging
10
8
  from packaging import requirements
11
- from typing_extensions import deprecated
12
9
 
13
10
  from snowflake import snowpark
14
11
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
@@ -92,7 +89,7 @@ class ModelComposer:
92
89
  python_version: Optional[str] = None,
93
90
  ext_modules: Optional[List[ModuleType]] = None,
94
91
  code_paths: Optional[List[str]] = None,
95
- model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
92
+ task: model_types.Task = model_types.Task.UNKNOWN,
96
93
  options: Optional[model_types.ModelSaveOption] = None,
97
94
  ) -> model_meta.ModelMetadata:
98
95
  if not options:
@@ -121,25 +118,20 @@ class ModelComposer:
121
118
  python_version=python_version,
122
119
  ext_modules=ext_modules,
123
120
  code_paths=code_paths,
124
- model_objective=model_objective,
121
+ task=task,
125
122
  options=options,
126
123
  )
127
124
  assert self.packager.meta is not None
128
125
 
129
- if not options.get("_legacy_save", False):
130
- # Keep both loose files and zipped file.
131
- # TODO(SNOW-726678): Remove once import a directory is possible.
132
- file_utils.copytree(
133
- str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
134
- )
135
- self.manifest.save(
136
- model_meta=self.packager.meta,
137
- model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
138
- options=options,
139
- data_sources=self._get_data_sources(model, sample_input_data),
140
- )
141
- else:
142
- file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
126
+ file_utils.copytree(
127
+ str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
128
+ )
129
+ self.manifest.save(
130
+ model_meta=self.packager.meta,
131
+ model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
132
+ options=options,
133
+ data_sources=self._get_data_sources(model, sample_input_data),
134
+ )
143
135
 
144
136
  file_utils.upload_directory_to_stage(
145
137
  self.session,
@@ -149,28 +141,6 @@ class ModelComposer:
149
141
  )
150
142
  return model_metadata
151
143
 
152
- @deprecated("Only used by PrPr model registry. Use static method version of load instead.")
153
- def legacy_load(
154
- self,
155
- *,
156
- meta_only: bool = False,
157
- options: Optional[model_types.ModelLoadOption] = None,
158
- ) -> None:
159
- file_utils.download_directory_from_stage(
160
- self.session,
161
- stage_path=self.stage_path,
162
- local_path=self.workspace_path,
163
- statement_params=self._statement_params,
164
- )
165
-
166
- # TODO (Server-side Model Rollout): Remove this section.
167
- model_zip_path = pathlib.Path(glob.glob(str(self.workspace_path / "*.zip"))[0])
168
- self.model_file_rel_path = str(model_zip_path.relative_to(self.workspace_path))
169
-
170
- with zipfile.ZipFile(self.model_local_path, mode="r", compression=zipfile.ZIP_DEFLATED) as zf:
171
- zf.extractall(path=self._packager_workspace_path)
172
- self.packager.load(meta_only=meta_only, options=options)
173
-
174
144
  @staticmethod
175
145
  def load(
176
146
  workspace_path: pathlib.Path,
@@ -1,6 +1,7 @@
1
1
  import collections
2
- import copy
2
+ import logging
3
3
  import pathlib
4
+ import warnings
4
5
  from typing import List, Optional, cast
5
6
 
6
7
  import yaml
@@ -17,6 +18,9 @@ from snowflake.ml.model._packager.model_meta import (
17
18
  model_meta as model_meta_api,
18
19
  model_meta_schema,
19
20
  )
21
+ from snowflake.ml.model._packager.model_runtime import model_runtime
22
+
23
+ logger = logging.getLogger(__name__)
20
24
 
21
25
 
22
26
  class ModelManifest:
@@ -44,9 +48,30 @@ class ModelManifest:
44
48
  if options is None:
45
49
  options = {}
46
50
 
47
- runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
48
- runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
49
- runtime_to_use.imports.append(str(model_rel_path) + "/")
51
+ if "relax_version" not in options:
52
+ warnings.warn(
53
+ (
54
+ "`relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed"
55
+ " from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility, "
56
+ "reproducibility, etc., set `options={'relax_version': False}` when logging the model."
57
+ ),
58
+ category=UserWarning,
59
+ stacklevel=2,
60
+ )
61
+ relax_version = options.get("relax_version", True)
62
+
63
+ runtime_to_use = model_runtime.ModelRuntime(
64
+ name=self._DEFAULT_RUNTIME_NAME,
65
+ env=model_meta.env,
66
+ imports=[str(model_rel_path) + "/"],
67
+ is_gpu=False,
68
+ is_warehouse=True,
69
+ )
70
+ if relax_version:
71
+ runtime_to_use.runtime_env.relax_version()
72
+ logger.info("Relaxing version constraints for dependencies in the model.")
73
+ logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
74
+ logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
50
75
  runtime_dict = runtime_to_use.save(
51
76
  self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
52
77
  )
@@ -21,7 +21,7 @@ _DEFAULT_PIP_REQUIREMENTS_FILENAME = "requirements.txt"
21
21
  # The default CUDA version is chosen based on the driver availability in SPCS.
22
22
  # If changing this version, we need also change the version of default PyTorch in HuggingFace pipeline handler to
23
23
  # make sure they are compatible.
24
- DEFAULT_CUDA_VERSION = "11.7"
24
+ DEFAULT_CUDA_VERSION = "11.8"
25
25
 
26
26
 
27
27
  class ModelEnv:
@@ -199,50 +199,16 @@ class ModelEnv:
199
199
  )
200
200
  if xgboost_spec:
201
201
  self.include_if_absent(
202
- [
203
- ModelDependency(
204
- requirement=f"conda-forge::py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost"
205
- )
206
- ],
202
+ [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
207
203
  check_local_version=False,
208
204
  )
209
205
 
210
- pytorch_spec = env_utils.find_dep_spec(
211
- self._conda_dependencies,
212
- self._pip_requirements,
213
- conda_pkg_name="pytorch",
214
- pip_pkg_name="torch",
215
- remove_spec=True,
216
- )
217
- pytorch_cuda_spec = env_utils.find_dep_spec(
218
- self._conda_dependencies,
219
- self._pip_requirements,
220
- conda_pkg_name="pytorch-cuda",
221
- remove_spec=False,
222
- )
223
- if pytorch_cuda_spec and not pytorch_cuda_spec.specifier.contains(self.cuda_version):
224
- raise ValueError(
225
- "The Pytorch-CUDA requirement you specified in your conda dependencies or pip requirements is"
226
- " conflicting with CUDA version required. Please do not specify Pytorch-CUDA dependency using conda"
227
- " dependencies or pip requirements."
228
- )
229
- if pytorch_spec:
230
- self.include_if_absent(
231
- [ModelDependency(requirement=f"pytorch::pytorch{pytorch_spec.specifier}", pip_name="torch")],
232
- check_local_version=False,
233
- )
234
- if not pytorch_cuda_spec:
235
- self.include_if_absent(
236
- [ModelDependency(requirement=f"pytorch::pytorch-cuda=={self.cuda_version}.*", pip_name="torch")],
237
- check_local_version=False,
238
- )
239
-
240
206
  tf_spec = env_utils.find_dep_spec(
241
207
  self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True
242
208
  )
243
209
  if tf_spec:
244
210
  self.include_if_absent(
245
- [ModelDependency(requirement=f"conda-forge::tensorflow-gpu{tf_spec.specifier}", pip_name="tensorflow")],
211
+ [ModelDependency(requirement=f"tensorflow-gpu{tf_spec.specifier}", pip_name="tensorflow")],
246
212
  check_local_version=False,
247
213
  )
248
214
 
@@ -252,7 +218,7 @@ class ModelEnv:
252
218
  if transformers_spec:
253
219
  self.include_if_absent(
254
220
  [
255
- ModelDependency(requirement="conda-forge::accelerate>=0.22.0", pip_name="accelerate"),
221
+ ModelDependency(requirement="accelerate>=0.22.0", pip_name="accelerate"),
256
222
  ModelDependency(requirement="scipy>=1.9", pip_name="scipy"),
257
223
  ],
258
224
  check_local_version=False,
@@ -1,17 +1,26 @@
1
1
  import json
2
+ import os
2
3
  import warnings
3
- from typing import Any, Callable, Iterable, Optional, Sequence, cast
4
+ from typing import Any, Callable, Iterable, List, Optional, Sequence, cast
4
5
 
5
6
  import numpy as np
6
7
  import numpy.typing as npt
7
8
  import pandas as pd
8
9
  from absl import logging
9
10
 
11
+ import snowflake.snowpark.dataframe as sp_df
12
+ from snowflake.ml._internal.utils import identifier
10
13
  from snowflake.ml.model import model_signature, type_hints as model_types
11
14
  from snowflake.ml.model._packager.model_meta import model_meta
12
- from snowflake.ml.model._signatures import snowpark_handler
15
+ from snowflake.ml.model._signatures import (
16
+ core,
17
+ snowpark_handler,
18
+ utils as model_signature_utils,
19
+ )
13
20
  from snowflake.snowpark import DataFrame as SnowparkDataFrame
14
21
 
22
+ EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT = 1000
23
+
15
24
 
16
25
  class NumpyEncoder(json.JSONEncoder):
17
26
  def default(self, obj: Any) -> Any:
@@ -28,6 +37,18 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> boo
28
37
  return callable(getattr(model, method_name, None))
29
38
 
30
39
 
40
+ def get_truncated_sample_data(sample_input_data: model_types.SupportedDataType) -> model_types.SupportedLocalDataType:
41
+ trunc_sample_input = model_signature._truncate_data(sample_input_data)
42
+ local_sample_input: model_types.SupportedLocalDataType = None
43
+ if isinstance(sample_input_data, SnowparkDataFrame):
44
+ # Added because of Any from missing stubs.
45
+ trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
46
+ local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
47
+ else:
48
+ local_sample_input = trunc_sample_input
49
+ return local_sample_input
50
+
51
+
31
52
  def validate_signature(
32
53
  model: model_types.SupportedRequireSignatureModelType,
33
54
  model_meta: model_meta.ModelMetadata,
@@ -37,19 +58,23 @@ def validate_signature(
37
58
  ) -> model_meta.ModelMetadata:
38
59
  if model_meta.signatures:
39
60
  validate_target_methods(model, list(model_meta.signatures.keys()))
61
+ if sample_input_data is not None:
62
+ local_sample_input = get_truncated_sample_data(sample_input_data)
63
+ for target_method in model_meta.signatures.keys():
64
+
65
+ model_signature_inst = model_meta.signatures.get(target_method)
66
+ if model_signature_inst is not None:
67
+ # strict validation the input signature
68
+ model_signature._convert_and_validate_local_data(
69
+ local_sample_input, model_signature_inst._inputs, True
70
+ )
40
71
  return model_meta
41
72
 
42
73
  # In this case sample_input_data should be available, because of the check in save_model.
43
74
  assert (
44
75
  sample_input_data is not None
45
76
  ), "Model signature and sample input are None at the same time. This should not happen with local model."
46
- trunc_sample_input = model_signature._truncate_data(sample_input_data)
47
- if isinstance(sample_input_data, SnowparkDataFrame):
48
- # Added because of Any from missing stubs.
49
- trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
50
- local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
51
- else:
52
- local_sample_input = trunc_sample_input
77
+ local_sample_input = get_truncated_sample_data(sample_input_data)
53
78
  for target_method in target_methods:
54
79
  predictions_df = get_prediction_fn(target_method, local_sample_input)
55
80
  sig = model_signature.infer_signature(local_sample_input, predictions_df)
@@ -58,24 +83,55 @@ def validate_signature(
58
83
  return model_meta
59
84
 
60
85
 
86
+ def get_input_signature(
87
+ model_meta: model_meta.ModelMetadata, target_method: Optional[str]
88
+ ) -> Sequence[core.BaseFeatureSpec]:
89
+ if target_method is None or target_method not in model_meta.signatures:
90
+ raise ValueError(f"Signature for target method {target_method} is missing or no method to explain.")
91
+ input_sig = model_meta.signatures[target_method].inputs
92
+ return input_sig
93
+
94
+
61
95
  def add_explain_method_signature(
62
96
  model_meta: model_meta.ModelMetadata,
63
97
  explain_method: str,
64
- target_method: str,
98
+ target_method: Optional[str],
65
99
  output_return_type: model_signature.DataType = model_signature.DataType.DOUBLE,
66
100
  ) -> model_meta.ModelMetadata:
67
- if target_method not in model_meta.signatures:
68
- raise ValueError(f"Signature for target method {target_method} is missing")
69
- inputs = model_meta.signatures[target_method].inputs
101
+ inputs = get_input_signature(model_meta, target_method)
102
+ if model_meta.model_type == "snowml":
103
+ output_feature_names = [identifier.concat_names([spec.name, "_explanation"]) for spec in inputs]
104
+ else:
105
+ output_feature_names = [f"{spec.name}_explanation" for spec in inputs]
70
106
  model_meta.signatures[explain_method] = model_signature.ModelSignature(
71
107
  inputs=inputs,
72
108
  outputs=[
73
- model_signature.FeatureSpec(dtype=output_return_type, name=f"{spec.name}_explanation") for spec in inputs
109
+ model_signature.FeatureSpec(dtype=output_return_type, name=output_name)
110
+ for output_name in output_feature_names
74
111
  ],
75
112
  )
76
113
  return model_meta
77
114
 
78
115
 
116
+ def get_explainability_supported_background(
117
+ sample_input_data: Optional[model_types.SupportedDataType],
118
+ meta: model_meta.ModelMetadata,
119
+ explain_target_method: Optional[str],
120
+ ) -> pd.DataFrame:
121
+ if sample_input_data is None:
122
+ return None
123
+
124
+ if isinstance(sample_input_data, pd.DataFrame):
125
+ return sample_input_data
126
+ if isinstance(sample_input_data, sp_df.DataFrame):
127
+ return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
128
+
129
+ df = model_signature._convert_local_data_to_df(sample_input_data)
130
+ input_signature_for_explain = get_input_signature(meta, explain_target_method)
131
+ df_with_named_cols = model_signature_utils.rename_pandas_df(df, input_signature_for_explain)
132
+ return df_with_named_cols
133
+
134
+
79
135
  def get_target_methods(
80
136
  model: model_types.SupportedModelType,
81
137
  target_methods: Optional[Sequence[str]],
@@ -88,6 +144,23 @@ def get_target_methods(
88
144
  return target_methods
89
145
 
90
146
 
147
+ def save_background_data(
148
+ model_blobs_dir_path: str,
149
+ explain_artifact_dir: str,
150
+ bg_data_file_suffix: str,
151
+ model_name: str,
152
+ background_data: pd.DataFrame,
153
+ ) -> None:
154
+ data_blob_path = os.path.join(model_blobs_dir_path, explain_artifact_dir)
155
+ os.makedirs(data_blob_path, exist_ok=True)
156
+ with open(os.path.join(data_blob_path, model_name + bg_data_file_suffix), "wb") as f:
157
+ # saving only the truncated data
158
+ trunc_background_data = background_data.head(
159
+ min(len(background_data.index), EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT)
160
+ )
161
+ trunc_background_data.to_parquet(f)
162
+
163
+
91
164
  def validate_target_methods(model: model_types.SupportedModelType, target_methods: Iterable[str]) -> None:
92
165
  for method_name in target_methods:
93
166
  if not _is_callable(model, method_name):
@@ -123,25 +196,26 @@ def convert_explanations_to_2D_df(
123
196
  return pd.DataFrame(exp_2d)
124
197
 
125
198
 
126
- def validate_model_objective(
127
- passed_model_objective: model_types.ModelObjective, inferred_model_objective: model_types.ModelObjective
128
- ) -> model_types.ModelObjective:
129
- if (
130
- passed_model_objective != model_types.ModelObjective.UNKNOWN
131
- and inferred_model_objective != model_types.ModelObjective.UNKNOWN
132
- ):
133
- if passed_model_objective != inferred_model_objective:
199
+ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task: model_types.Task) -> model_types.Task:
200
+ if passed_model_task != model_types.Task.UNKNOWN and inferred_model_task != model_types.Task.UNKNOWN:
201
+ if passed_model_task != inferred_model_task:
134
202
  warnings.warn(
135
- f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model "
136
- f"version and passed argument ModelObjective: {passed_model_objective.name} is ignored",
203
+ f"Inferred Task: {inferred_model_task.name} is used as task for this model "
204
+ f"version and passed argument Task: {passed_model_task.name} is ignored",
137
205
  category=UserWarning,
138
206
  stacklevel=1,
139
207
  )
140
- return inferred_model_objective
141
- elif inferred_model_objective != model_types.ModelObjective.UNKNOWN:
142
- logging.info(
143
- f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model "
144
- f"version"
145
- )
146
- return inferred_model_objective
147
- return passed_model_objective
208
+ return inferred_model_task
209
+ elif inferred_model_task != model_types.Task.UNKNOWN:
210
+ logging.info(f"Inferred Task: {inferred_model_task.name} is used as task for this model " f"version")
211
+ return inferred_model_task
212
+ return passed_model_task
213
+
214
+
215
+ def get_explain_target_method(
216
+ model_metadata: model_meta.ModelMetadata, target_methods_list: List[str]
217
+ ) -> Optional[str]:
218
+ for method in model_metadata.signatures.keys():
219
+ if method in target_methods_list:
220
+ return method
221
+ return None
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import warnings
2
3
  from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
3
4
 
4
5
  import numpy as np
@@ -8,7 +9,11 @@ from typing_extensions import TypeGuard, Unpack
8
9
  from snowflake.ml._internal import type_utils
9
10
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
10
11
  from snowflake.ml.model._packager.model_env import model_env
11
- from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
12
+ from snowflake.ml.model._packager.model_handlers import (
13
+ _base,
14
+ _utils as handlers_utils,
15
+ model_objective_utils,
16
+ )
12
17
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
13
18
  from snowflake.ml.model._packager.model_meta import (
14
19
  model_blob_meta,
@@ -32,22 +37,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
32
37
 
33
38
  MODEL_BLOB_FILE_OR_DIR = "model.bin"
34
39
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
35
-
36
- @classmethod
37
- def get_model_objective_and_output_type(cls, model: "catboost.CatBoost") -> model_types.ModelObjective:
38
- import catboost
39
-
40
- if isinstance(model, catboost.CatBoostClassifier):
41
- num_classes = handlers_utils.get_num_classes_if_exists(model)
42
- if num_classes == 2:
43
- return model_types.ModelObjective.BINARY_CLASSIFICATION
44
- return model_types.ModelObjective.MULTI_CLASSIFICATION
45
- if isinstance(model, catboost.CatBoostRanker):
46
- return model_types.ModelObjective.RANKING
47
- if isinstance(model, catboost.CatBoostRegressor):
48
- return model_types.ModelObjective.REGRESSION
49
- # TODO: Find out model type from the generic Catboost Model
50
- return model_types.ModelObjective.UNKNOWN
40
+ EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
51
41
 
52
42
  @classmethod
53
43
  def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
@@ -107,25 +97,34 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
107
97
  sample_input_data=sample_input_data,
108
98
  get_prediction_fn=get_prediction,
109
99
  )
110
- inferred_model_objective = cls.get_model_objective_and_output_type(model)
111
- model_meta.model_objective = handlers_utils.validate_model_objective(
112
- model_meta.model_objective, inferred_model_objective
113
- )
114
- model_objective = model_meta.model_objective
100
+ model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
101
+ model_meta.task = model_task_and_output.task
115
102
  if enable_explainability:
116
- output_type = model_signature.DataType.DOUBLE
117
- if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION:
118
- output_type = model_signature.DataType.STRING
103
+ explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
119
104
  model_meta = handlers_utils.add_explain_method_signature(
120
105
  model_meta=model_meta,
121
106
  explain_method="explain",
122
- target_method="predict",
123
- output_return_type=output_type,
107
+ target_method=explain_target_method,
108
+ output_return_type=model_task_and_output.output_type,
124
109
  )
125
110
  model_meta.function_properties = {
126
111
  "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
127
112
  }
128
113
 
114
+ background_data = handlers_utils.get_explainability_supported_background(
115
+ sample_input_data, model_meta, explain_target_method
116
+ )
117
+ if background_data is not None:
118
+ handlers_utils.save_background_data(
119
+ model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
120
+ )
121
+ else:
122
+ warnings.warn(
123
+ "sample_input_data should be provided for better explainability results",
124
+ category=UserWarning,
125
+ stacklevel=1,
126
+ )
127
+
129
128
  model_blob_path = os.path.join(model_blobs_dir_path, name)
130
129
  os.makedirs(model_blob_path, exist_ok=True)
131
130
  model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
@@ -242,10 +242,10 @@ class HuggingFacePipelineHandler(
242
242
  task, spcs_only=(not type_utils.LazyType("transformers.Pipeline").isinstance(model))
243
243
  )
244
244
  if framework is None or framework == "pt":
245
- # Since we set default cuda version to be 11.7, to make sure it works with GPU, we need to have a default
246
- # Pytorch version that works with CUDA 11.7 as well. This is required for huggingface pipelines only as
245
+ # Since we set default cuda version to be 11.8, to make sure it works with GPU, we need to have a default
246
+ # Pytorch version that works with CUDA 11.8 as well. This is required for huggingface pipelines only as
247
247
  # users are not required to install pytorch locally if they are using the wrapper.
248
- pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch==2.0.1", pip_name="torch"))
248
+ pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
249
249
  elif framework == "tf":
250
250
  pkgs_requirements.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
251
251
  model_meta.env.include_if_absent(
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import warnings
2
3
  from typing import (
3
4
  TYPE_CHECKING,
4
5
  Any,
@@ -47,6 +48,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
47
48
 
48
49
  MODEL_BLOB_FILE_OR_DIR = "model.pkl"
49
50
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
51
+ EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
50
52
 
51
53
  @classmethod
52
54
  def can_handle(
@@ -111,21 +113,34 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
111
113
  sample_input_data=sample_input_data,
112
114
  get_prediction_fn=get_prediction,
113
115
  )
114
- model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model)
115
- model_meta.model_objective = handlers_utils.validate_model_objective(
116
- model_meta.model_objective, model_objective_and_output.objective
117
- )
116
+ model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
117
+ model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
118
118
  if enable_explainability:
119
+ explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
119
120
  model_meta = handlers_utils.add_explain_method_signature(
120
121
  model_meta=model_meta,
121
122
  explain_method="explain",
122
- target_method="predict",
123
- output_return_type=model_objective_and_output.output_type,
123
+ target_method=explain_target_method,
124
+ output_return_type=model_task_and_output.output_type,
124
125
  )
125
126
  model_meta.function_properties = {
126
127
  "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
127
128
  }
128
129
 
130
+ background_data = handlers_utils.get_explainability_supported_background(
131
+ sample_input_data, model_meta, explain_target_method
132
+ )
133
+ if background_data is not None:
134
+ handlers_utils.save_background_data(
135
+ model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
136
+ )
137
+ else:
138
+ warnings.warn(
139
+ "sample_input_data should be provided for better explainability results",
140
+ category=UserWarning,
141
+ stacklevel=1,
142
+ )
143
+
129
144
  model_blob_path = os.path.join(model_blobs_dir_path, name)
130
145
  os.makedirs(model_blob_path, exist_ok=True)
131
146
 
@@ -168,11 +168,6 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
168
168
  ) -> "mlflow.pyfunc.PyFuncModel":
169
169
  import mlflow
170
170
 
171
- if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
172
- # We need to redirect the mlruns folder to a writable location in the sandbox.
173
- tmpdir = tempfile.TemporaryDirectory(dir="/tmp")
174
- mlflow.set_tracking_uri(f"file://{tmpdir}")
175
-
176
171
  model_blob_path = os.path.join(model_blobs_dir_path, name)
177
172
  model_blobs_metadata = model_meta.models
178
173
  model_blob_metadata = model_blobs_metadata[name]
@@ -183,6 +178,9 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
183
178
  model_artifact_path = model_blob_options["artifact_path"]
184
179
  model_blob_filename = model_blob_metadata.path
185
180
 
181
+ if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
182
+ return mlflow.pyfunc.load_model(os.path.join(model_blob_path, model_blob_filename, model_artifact_path))
183
+
186
184
  # This is to make sure the loaded model can be saved again.
187
185
  with mlflow.start_run() as run:
188
186
  mlflow.log_artifacts(