snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import warnings
2
3
  from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
3
4
 
4
5
  import numpy as np
@@ -8,7 +9,11 @@ from typing_extensions import TypeGuard, Unpack
8
9
  from snowflake.ml._internal import type_utils
9
10
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
10
11
  from snowflake.ml.model._packager.model_env import model_env
11
- from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
12
+ from snowflake.ml.model._packager.model_handlers import (
13
+ _base,
14
+ _utils as handlers_utils,
15
+ model_objective_utils,
16
+ )
12
17
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
13
18
  from snowflake.ml.model._packager.model_meta import (
14
19
  model_blob_meta,
@@ -32,22 +37,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
32
37
 
33
38
  MODEL_BLOB_FILE_OR_DIR = "model.bin"
34
39
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
35
-
36
- @classmethod
37
- def get_model_objective(cls, model: "catboost.CatBoost") -> model_meta_schema.ModelObjective:
38
- import catboost
39
-
40
- if isinstance(model, catboost.CatBoostClassifier):
41
- num_classes = handlers_utils.get_num_classes_if_exists(model)
42
- if num_classes == 2:
43
- return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
44
- return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
45
- if isinstance(model, catboost.CatBoostRanker):
46
- return model_meta_schema.ModelObjective.RANKING
47
- if isinstance(model, catboost.CatBoostRegressor):
48
- return model_meta_schema.ModelObjective.REGRESSION
49
- # TODO: Find out model type from the generic Catboost Model
50
- return model_meta_schema.ModelObjective.UNKNOWN
40
+ EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
51
41
 
52
42
  @classmethod
53
43
  def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
@@ -77,6 +67,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
77
67
  is_sub_model: Optional[bool] = False,
78
68
  **kwargs: Unpack[model_types.CatBoostModelSaveOptions],
79
69
  ) -> None:
70
+ enable_explainability = kwargs.get("enable_explainability", True)
71
+
80
72
  import catboost
81
73
 
82
74
  assert isinstance(model, catboost.CatBoost)
@@ -105,22 +97,34 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
105
97
  sample_input_data=sample_input_data,
106
98
  get_prediction_fn=get_prediction,
107
99
  )
108
- model_objective = cls.get_model_objective(model)
109
- model_meta.model_objective = model_objective
110
- if kwargs.get("enable_explainability", True):
111
- output_type = model_signature.DataType.DOUBLE
112
- if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
113
- output_type = model_signature.DataType.STRING
100
+ model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
101
+ model_meta.task = model_task_and_output.task
102
+ if enable_explainability:
103
+ explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
114
104
  model_meta = handlers_utils.add_explain_method_signature(
115
105
  model_meta=model_meta,
116
106
  explain_method="explain",
117
- target_method="predict",
118
- output_return_type=output_type,
107
+ target_method=explain_target_method,
108
+ output_return_type=model_task_and_output.output_type,
119
109
  )
120
110
  model_meta.function_properties = {
121
111
  "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
122
112
  }
123
113
 
114
+ background_data = handlers_utils.get_explainability_supported_background(
115
+ sample_input_data, model_meta, explain_target_method
116
+ )
117
+ if background_data is not None:
118
+ handlers_utils.save_background_data(
119
+ model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
120
+ )
121
+ else:
122
+ warnings.warn(
123
+ "sample_input_data should be provided for better explainability results",
124
+ category=UserWarning,
125
+ stacklevel=1,
126
+ )
127
+
124
128
  model_blob_path = os.path.join(model_blobs_dir_path, name)
125
129
  os.makedirs(model_blob_path, exist_ok=True)
126
130
  model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
@@ -143,11 +147,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
143
147
  ],
144
148
  check_local_version=True,
145
149
  )
146
- if kwargs.get("enable_explainability", True):
147
- model_meta.env.include_if_absent(
148
- [model_env.ModelDependency(requirement="shap", pip_name="shap")],
149
- check_local_version=True,
150
- )
150
+ if enable_explainability:
151
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
151
152
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
152
153
  model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
153
154
 
@@ -242,10 +242,10 @@ class HuggingFacePipelineHandler(
242
242
  task, spcs_only=(not type_utils.LazyType("transformers.Pipeline").isinstance(model))
243
243
  )
244
244
  if framework is None or framework == "pt":
245
- # Since we set default cuda version to be 11.7, to make sure it works with GPU, we need to have a default
246
- # Pytorch version that works with CUDA 11.7 as well. This is required for huggingface pipelines only as
245
+ # Since we set default cuda version to be 11.8, to make sure it works with GPU, we need to have a default
246
+ # Pytorch version that works with CUDA 11.8 as well. This is required for huggingface pipelines only as
247
247
  # users are not required to install pytorch locally if they are using the wrapper.
248
- pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch==2.0.1", pip_name="torch"))
248
+ pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
249
249
  elif framework == "tf":
250
250
  pkgs_requirements.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
251
251
  model_meta.env.include_if_absent(
@@ -369,7 +369,9 @@ class HuggingFacePipelineHandler(
369
369
  else:
370
370
  # For others, we could offer the whole dataframe as a list.
371
371
  # Some of them may need some conversion
372
- if isinstance(raw_model, transformers.ConversationalPipeline):
372
+ if hasattr(transformers, "ConversationalPipeline") and isinstance(
373
+ raw_model, transformers.ConversationalPipeline
374
+ ):
373
375
  input_data = [
374
376
  transformers.Conversation(
375
377
  text=conv_data["user_inputs"][0],
@@ -391,27 +393,33 @@ class HuggingFacePipelineHandler(
391
393
  # Making it not aligned with the auto-inferred signature.
392
394
  # If the output is a dict, we could blindly create a list containing that.
393
395
  # Otherwise, creating pandas DataFrame won't succeed.
394
- if isinstance(temp_res, (dict, transformers.Conversation)) or (
395
- # For some pipeline that is expected to generate a list of dict per input
396
- # When it omit outer list, it becomes list of dict instead of list of list of dict.
397
- # We need to distinguish them from those pipelines that designed to output a dict per input
398
- # So we need to check the pipeline type.
399
- isinstance(
400
- raw_model,
401
- (
402
- transformers.FillMaskPipeline,
403
- transformers.QuestionAnsweringPipeline,
404
- ),
396
+ if (
397
+ (hasattr(transformers, "Conversation") and isinstance(temp_res, transformers.Conversation))
398
+ or isinstance(temp_res, dict)
399
+ or (
400
+ # For some pipeline that is expected to generate a list of dict per input
401
+ # When it omit outer list, it becomes list of dict instead of list of list of dict.
402
+ # We need to distinguish them from those pipelines that designed to output a dict per input
403
+ # So we need to check the pipeline type.
404
+ isinstance(
405
+ raw_model,
406
+ (
407
+ transformers.FillMaskPipeline,
408
+ transformers.QuestionAnsweringPipeline,
409
+ ),
410
+ )
411
+ and X.shape[0] == 1
412
+ and isinstance(temp_res[0], dict)
405
413
  )
406
- and X.shape[0] == 1
407
- and isinstance(temp_res[0], dict)
408
414
  ):
409
415
  temp_res = [temp_res]
410
416
 
411
417
  if len(temp_res) == 0:
412
418
  return pd.DataFrame()
413
419
 
414
- if isinstance(raw_model, transformers.ConversationalPipeline):
420
+ if hasattr(transformers, "ConversationalPipeline") and isinstance(
421
+ raw_model, transformers.ConversationalPipeline
422
+ ):
415
423
  temp_res = [[conv.generated_responses] for conv in temp_res]
416
424
 
417
425
  # To concat those who outputs a list with one input.
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import warnings
2
3
  from typing import (
3
4
  TYPE_CHECKING,
4
5
  Any,
@@ -19,7 +20,11 @@ from typing_extensions import TypeGuard, Unpack
19
20
  from snowflake.ml._internal import type_utils
20
21
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
21
22
  from snowflake.ml.model._packager.model_env import model_env
22
- from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
23
+ from snowflake.ml.model._packager.model_handlers import (
24
+ _base,
25
+ _utils as handlers_utils,
26
+ model_objective_utils,
27
+ )
23
28
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
24
29
  from snowflake.ml.model._packager.model_meta import (
25
30
  model_blob_meta,
@@ -43,47 +48,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
43
48
 
44
49
  MODEL_BLOB_FILE_OR_DIR = "model.pkl"
45
50
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
46
- _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
47
- _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
48
- _RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
49
- _REGRESSION_OBJECTIVES = [
50
- "regression",
51
- "regression_l1",
52
- "huber",
53
- "fair",
54
- "poisson",
55
- "quantile",
56
- "tweedie",
57
- "mape",
58
- "gamma",
59
- ]
60
-
61
- @classmethod
62
- def get_model_objective(
63
- cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]
64
- ) -> model_meta_schema.ModelObjective:
65
- import lightgbm
66
-
67
- # does not account for cross-entropy and custom
68
- if isinstance(model, lightgbm.LGBMClassifier):
69
- num_classes = handlers_utils.get_num_classes_if_exists(model)
70
- if num_classes == 2:
71
- return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
72
- return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
73
- if isinstance(model, lightgbm.LGBMRanker):
74
- return model_meta_schema.ModelObjective.RANKING
75
- if isinstance(model, lightgbm.LGBMRegressor):
76
- return model_meta_schema.ModelObjective.REGRESSION
77
- model_objective = model.params["objective"]
78
- if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
79
- return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
80
- if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
81
- return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
82
- if model_objective in cls._RANKING_OBJECTIVES:
83
- return model_meta_schema.ModelObjective.RANKING
84
- if model_objective in cls._REGRESSION_OBJECTIVES:
85
- return model_meta_schema.ModelObjective.REGRESSION
86
- return model_meta_schema.ModelObjective.UNKNOWN
51
+ EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
87
52
 
88
53
  @classmethod
89
54
  def can_handle(
@@ -118,6 +83,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
118
83
  is_sub_model: Optional[bool] = False,
119
84
  **kwargs: Unpack[model_types.LGBMModelSaveOptions],
120
85
  ) -> None:
86
+ enable_explainability = kwargs.get("enable_explainability", True)
87
+
121
88
  import lightgbm
122
89
 
123
90
  assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel)
@@ -146,25 +113,34 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
146
113
  sample_input_data=sample_input_data,
147
114
  get_prediction_fn=get_prediction,
148
115
  )
149
- model_objective = cls.get_model_objective(model)
150
- model_meta.model_objective = model_objective
151
- if kwargs.get("enable_explainability", True):
152
- output_type = model_signature.DataType.DOUBLE
153
- if model_objective in [
154
- model_meta_schema.ModelObjective.BINARY_CLASSIFICATION,
155
- model_meta_schema.ModelObjective.MULTI_CLASSIFICATION,
156
- ]:
157
- output_type = model_signature.DataType.STRING
116
+ model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
117
+ model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
118
+ if enable_explainability:
119
+ explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
158
120
  model_meta = handlers_utils.add_explain_method_signature(
159
121
  model_meta=model_meta,
160
122
  explain_method="explain",
161
- target_method="predict",
162
- output_return_type=output_type,
123
+ target_method=explain_target_method,
124
+ output_return_type=model_task_and_output.output_type,
163
125
  )
164
126
  model_meta.function_properties = {
165
127
  "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
166
128
  }
167
129
 
130
+ background_data = handlers_utils.get_explainability_supported_background(
131
+ sample_input_data, model_meta, explain_target_method
132
+ )
133
+ if background_data is not None:
134
+ handlers_utils.save_background_data(
135
+ model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
136
+ )
137
+ else:
138
+ warnings.warn(
139
+ "sample_input_data should be provided for better explainability results",
140
+ category=UserWarning,
141
+ stacklevel=1,
142
+ )
143
+
168
144
  model_blob_path = os.path.join(model_blobs_dir_path, name)
169
145
  os.makedirs(model_blob_path, exist_ok=True)
170
146
 
@@ -189,11 +165,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
189
165
  ],
190
166
  check_local_version=True,
191
167
  )
192
- if kwargs.get("enable_explainability", True):
193
- model_meta.env.include_if_absent(
194
- [model_env.ModelDependency(requirement="shap", pip_name="shap")],
195
- check_local_version=True,
196
- )
168
+ if enable_explainability:
169
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
197
170
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
198
171
 
199
172
  return None
@@ -168,11 +168,6 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
168
168
  ) -> "mlflow.pyfunc.PyFuncModel":
169
169
  import mlflow
170
170
 
171
- if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
172
- # We need to redirect the mlruns folder to a writable location in the sandbox.
173
- tmpdir = tempfile.TemporaryDirectory(dir="/tmp")
174
- mlflow.set_tracking_uri(f"file://{tmpdir}")
175
-
176
171
  model_blob_path = os.path.join(model_blobs_dir_path, name)
177
172
  model_blobs_metadata = model_meta.models
178
173
  model_blob_metadata = model_blobs_metadata[name]
@@ -183,6 +178,9 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
183
178
  model_artifact_path = model_blob_options["artifact_path"]
184
179
  model_blob_filename = model_blob_metadata.path
185
180
 
181
+ if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
182
+ return mlflow.pyfunc.load_model(os.path.join(model_blob_path, model_blob_filename, model_artifact_path))
183
+
186
184
  # This is to make sure the loaded model can be saved again.
187
185
  with mlflow.start_run() as run:
188
186
  mlflow.log_artifacts(
@@ -0,0 +1,169 @@
1
+ import json
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING, Any, Union
4
+
5
+ from snowflake.ml._internal import type_utils
6
+ from snowflake.ml.model import model_signature, type_hints
7
+ from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils
8
+
9
+ if TYPE_CHECKING:
10
+ import catboost
11
+ import lightgbm
12
+ import sklearn
13
+ import sklearn.pipeline
14
+ import xgboost
15
+
16
+
17
+ @dataclass
18
+ class ModelTaskAndOutputType:
19
+ task: type_hints.Task
20
+ output_type: model_signature.DataType
21
+
22
+
23
+ def get_task_skl(model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]) -> type_hints.Task:
24
+ from sklearn.base import is_classifier, is_regressor
25
+
26
+ if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
27
+ return type_hints.Task.UNKNOWN
28
+ if is_regressor(model):
29
+ return type_hints.Task.TABULAR_REGRESSION
30
+ if is_classifier(model):
31
+ classes_list = getattr(model, "classes_", [])
32
+ num_classes = getattr(model, "n_classes_", None) or len(classes_list)
33
+ if isinstance(num_classes, int):
34
+ if num_classes > 2:
35
+ return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
36
+ return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
37
+ return type_hints.Task.UNKNOWN
38
+ return type_hints.Task.UNKNOWN
39
+
40
+
41
+ def get_model_task_catboost(model: "catboost.CatBoost") -> type_hints.Task:
42
+ loss_function = None
43
+ if type_utils.LazyType("catboost.CatBoost").isinstance(model):
44
+ loss_function = model.get_all_params()["loss_function"] # type: ignore[attr-defined]
45
+
46
+ if (type_utils.LazyType("catboost.CatBoostClassifier").isinstance(model)) or model._is_classification_objective(
47
+ loss_function
48
+ ):
49
+ num_classes = handlers_utils.get_num_classes_if_exists(model)
50
+ if num_classes == 0:
51
+ return type_hints.Task.UNKNOWN
52
+ if num_classes <= 2:
53
+ return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
54
+ return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
55
+ if (type_utils.LazyType("catboost.CatBoostRanker").isinstance(model)) or model._is_ranking_objective(loss_function):
56
+ return type_hints.Task.TABULAR_RANKING
57
+ if (type_utils.LazyType("catboost.CatBoostRegressor").isinstance(model)) or model._is_regression_objective(
58
+ loss_function
59
+ ):
60
+ return type_hints.Task.TABULAR_REGRESSION
61
+
62
+ return type_hints.Task.UNKNOWN
63
+
64
+
65
+ def get_model_task_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.Task:
66
+
67
+ _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
68
+ _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
69
+ _RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
70
+ _REGRESSION_OBJECTIVES = [
71
+ "regression",
72
+ "regression_l1",
73
+ "huber",
74
+ "fair",
75
+ "poisson",
76
+ "quantile",
77
+ "tweedie",
78
+ "mape",
79
+ "gamma",
80
+ ]
81
+
82
+ # does not account for cross-entropy and custom
83
+ model_task = ""
84
+ if type_utils.LazyType("lightgbm.Booster").isinstance(model):
85
+ model_task = model.params["objective"] # type: ignore[attr-defined]
86
+ elif hasattr(model, "objective_"):
87
+ model_task = model.objective_
88
+ if model_task in _BINARY_CLASSIFICATION_OBJECTIVES:
89
+ return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
90
+ if model_task in _MULTI_CLASSIFICATION_OBJECTIVES:
91
+ return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
92
+ if model_task in _RANKING_OBJECTIVES:
93
+ return type_hints.Task.TABULAR_RANKING
94
+ if model_task in _REGRESSION_OBJECTIVES:
95
+ return type_hints.Task.TABULAR_REGRESSION
96
+ return type_hints.Task.UNKNOWN
97
+
98
+
99
+ def get_model_task_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.Task:
100
+
101
+ _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
102
+ _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
103
+ _RANKING_OBJECTIVE_PREFIX = ["rank:"]
104
+ _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
105
+
106
+ model_task = ""
107
+ if type_utils.LazyType("xgboost.Booster").isinstance(model):
108
+ model_params = json.loads(model.save_config()) # type: ignore[attr-defined]
109
+ model_task = model_params.get("learner", {}).get("objective", "")
110
+ else:
111
+ if hasattr(model, "get_params"):
112
+ model_task = model.get_params().get("objective", "")
113
+
114
+ if isinstance(model_task, dict):
115
+ model_task = model_task.get("name", "")
116
+ for classification_objective in _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
117
+ if classification_objective in model_task:
118
+ return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
119
+ for classification_objective in _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
120
+ if classification_objective in model_task:
121
+ return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
122
+ for ranking_objective in _RANKING_OBJECTIVE_PREFIX:
123
+ if ranking_objective in model_task:
124
+ return type_hints.Task.TABULAR_RANKING
125
+ for regression_objective in _REGRESSION_OBJECTIVE_PREFIX:
126
+ if regression_objective in model_task:
127
+ return type_hints.Task.TABULAR_REGRESSION
128
+ return type_hints.Task.UNKNOWN
129
+
130
+
131
+ def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
132
+ if type_utils.LazyType("xgboost.Booster").isinstance(model) or type_utils.LazyType("xgboost.XGBModel").isinstance(
133
+ model
134
+ ):
135
+ task = get_model_task_xgb(model)
136
+ output_type = model_signature.DataType.DOUBLE
137
+ if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
138
+ output_type = model_signature.DataType.STRING
139
+ return ModelTaskAndOutputType(task=task, output_type=output_type)
140
+
141
+ if type_utils.LazyType("lightgbm.Booster").isinstance(model) or type_utils.LazyType(
142
+ "lightgbm.LGBMModel"
143
+ ).isinstance(model):
144
+ task = get_model_task_lightgbm(model)
145
+ output_type = model_signature.DataType.DOUBLE
146
+ if task in [
147
+ type_hints.Task.TABULAR_BINARY_CLASSIFICATION,
148
+ type_hints.Task.TABULAR_MULTI_CLASSIFICATION,
149
+ ]:
150
+ output_type = model_signature.DataType.STRING
151
+ return ModelTaskAndOutputType(task=task, output_type=output_type)
152
+
153
+ if type_utils.LazyType("catboost.CatBoost").isinstance(model):
154
+ task = get_model_task_catboost(model)
155
+ output_type = model_signature.DataType.DOUBLE
156
+ if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
157
+ output_type = model_signature.DataType.STRING
158
+ return ModelTaskAndOutputType(task=task, output_type=output_type)
159
+
160
+ if type_utils.LazyType("sklearn.base.BaseEstimator").isinstance(model) or type_utils.LazyType(
161
+ "sklearn.pipeline.Pipeline"
162
+ ).isinstance(model):
163
+ task = get_task_skl(model)
164
+ output_type = model_signature.DataType.DOUBLE
165
+ if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
166
+ output_type = model_signature.DataType.STRING
167
+ return ModelTaskAndOutputType(task=task, output_type=output_type)
168
+
169
+ raise ValueError(f"Model type {type(model)} is not supported")
@@ -2,7 +2,6 @@ import logging
2
2
  import os
3
3
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
4
4
 
5
- import cloudpickle
6
5
  import pandas as pd
7
6
  from typing_extensions import TypeGuard, Unpack
8
7
 
@@ -120,9 +119,21 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
120
119
  model_meta.env.include_if_absent(
121
120
  [
122
121
  model_env.ModelDependency(requirement="sentence-transformers", pip_name="sentence-transformers"),
122
+ model_env.ModelDependency(requirement="transformers", pip_name="transformers"),
123
+ model_env.ModelDependency(requirement="pytorch", pip_name="torch"),
123
124
  ],
124
125
  check_local_version=True,
125
126
  )
127
+ model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
128
+
129
+ @staticmethod
130
+ def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
131
+ if kwargs.get("device", None) is not None:
132
+ return kwargs["device"]
133
+ elif kwargs.get("use_gpu", False):
134
+ return "cuda"
135
+
136
+ return None
126
137
 
127
138
  @classmethod
128
139
  def load_model(
@@ -144,13 +155,9 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
144
155
  model_blob_filename = model_blob_metadata.path
145
156
  model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)
146
157
 
147
- if os.path.isdir(model_blob_file_or_dir_path): # if the saved model is a directory
148
- model = sentence_transformers.SentenceTransformer(model_blob_file_or_dir_path)
149
- else:
150
- assert os.path.isfile(model_blob_file_or_dir_path) # if the saved model is a file
151
- with open(model_blob_file_or_dir_path, "rb") as f:
152
- model = cloudpickle.load(f)
153
- assert isinstance(model, sentence_transformers.SentenceTransformer)
158
+ model = sentence_transformers.SentenceTransformer(
159
+ model_blob_file_or_dir_path, device=cls._get_device_config(**kwargs)
160
+ )
154
161
  return model
155
162
 
156
163
  @classmethod