snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/cortex/_classify_text.py +3 -3
  3. snowflake/cortex/_complete.py +23 -24
  4. snowflake/cortex/_embed_text_1024.py +4 -4
  5. snowflake/cortex/_embed_text_768.py +4 -4
  6. snowflake/cortex/_finetune.py +8 -8
  7. snowflake/cortex/_util.py +8 -12
  8. snowflake/ml/_internal/env.py +4 -3
  9. snowflake/ml/_internal/env_utils.py +63 -34
  10. snowflake/ml/_internal/file_utils.py +10 -21
  11. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  12. snowflake/ml/_internal/init_utils.py +2 -3
  13. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  14. snowflake/ml/_internal/platform_capabilities.py +18 -16
  15. snowflake/ml/_internal/telemetry.py +39 -52
  16. snowflake/ml/_internal/type_utils.py +3 -3
  17. snowflake/ml/_internal/utils/db_utils.py +2 -2
  18. snowflake/ml/_internal/utils/identifier.py +10 -10
  19. snowflake/ml/_internal/utils/import_utils.py +2 -2
  20. snowflake/ml/_internal/utils/parallelize.py +7 -7
  21. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  22. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  23. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  24. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  25. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  26. snowflake/ml/_internal/utils/table_manager.py +9 -9
  27. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  28. snowflake/ml/data/data_connector.py +15 -36
  29. snowflake/ml/data/data_ingestor.py +4 -15
  30. snowflake/ml/data/data_source.py +2 -2
  31. snowflake/ml/data/ingestor_utils.py +3 -3
  32. snowflake/ml/data/torch_utils.py +5 -5
  33. snowflake/ml/dataset/dataset.py +11 -11
  34. snowflake/ml/dataset/dataset_metadata.py +8 -8
  35. snowflake/ml/dataset/dataset_reader.py +7 -7
  36. snowflake/ml/feature_store/__init__.py +1 -1
  37. snowflake/ml/feature_store/access_manager.py +7 -7
  38. snowflake/ml/feature_store/entity.py +6 -6
  39. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  41. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  45. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  48. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  51. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  52. snowflake/ml/feature_store/feature_store.py +52 -64
  53. snowflake/ml/feature_store/feature_view.py +24 -24
  54. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  55. snowflake/ml/fileset/fileset.py +5 -5
  56. snowflake/ml/fileset/sfcfs.py +13 -13
  57. snowflake/ml/fileset/stage_fs.py +15 -15
  58. snowflake/ml/jobs/_utils/constants.py +1 -1
  59. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  60. snowflake/ml/jobs/_utils/payload_utils.py +45 -46
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  64. snowflake/ml/jobs/_utils/spec_utils.py +18 -29
  65. snowflake/ml/jobs/_utils/types.py +2 -2
  66. snowflake/ml/jobs/decorators.py +10 -5
  67. snowflake/ml/jobs/job.py +87 -30
  68. snowflake/ml/jobs/manager.py +86 -56
  69. snowflake/ml/lineage/lineage_node.py +5 -5
  70. snowflake/ml/model/_client/model/model_impl.py +3 -3
  71. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  72. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  73. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  74. snowflake/ml/model/_client/ops/service_ops.py +217 -32
  75. snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
  76. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
  77. snowflake/ml/model/_client/sql/model.py +8 -8
  78. snowflake/ml/model/_client/sql/model_version.py +26 -26
  79. snowflake/ml/model/_client/sql/service.py +17 -26
  80. snowflake/ml/model/_client/sql/stage.py +2 -2
  81. snowflake/ml/model/_client/sql/tag.py +6 -6
  82. snowflake/ml/model/_model_composer/model_composer.py +58 -32
  83. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  85. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  86. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  87. snowflake/ml/model/_packager/model_handler.py +4 -4
  88. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  89. snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
  90. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  91. snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
  92. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  93. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  95. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  96. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
  97. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  98. snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
  99. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  100. snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
  101. snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
  102. snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
  103. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  104. snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
  105. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
  106. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  107. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  109. snowflake/ml/model/_packager/model_packager.py +11 -9
  110. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  111. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  112. snowflake/ml/model/_signatures/core.py +16 -24
  113. snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
  114. snowflake/ml/model/_signatures/utils.py +6 -6
  115. snowflake/ml/model/custom_model.py +24 -11
  116. snowflake/ml/model/model_signature.py +12 -23
  117. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  118. snowflake/ml/model/type_hints.py +3 -3
  119. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  120. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  122. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  123. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  124. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  125. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  126. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  128. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  129. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  130. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  131. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  132. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  133. snowflake/ml/modeling/cluster/birch.py +9 -1
  134. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  135. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  136. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  137. snowflake/ml/modeling/cluster/k_means.py +9 -1
  138. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  139. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  140. snowflake/ml/modeling/cluster/optics.py +9 -1
  141. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  142. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  143. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  144. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  145. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  146. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  147. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  148. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  149. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  150. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  151. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  152. snowflake/ml/modeling/covariance/oas.py +9 -1
  153. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  154. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  155. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  156. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  157. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  158. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  159. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  160. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  161. snowflake/ml/modeling/decomposition/pca.py +9 -1
  162. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  163. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  164. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  165. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  166. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  167. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  168. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  169. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  170. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  171. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  172. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  173. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  174. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  175. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  176. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  177. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  178. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  179. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  180. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  181. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  182. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  183. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  184. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  185. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  186. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  187. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  188. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  189. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  190. snowflake/ml/modeling/framework/_utils.py +10 -10
  191. snowflake/ml/modeling/framework/base.py +32 -32
  192. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  193. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  194. snowflake/ml/modeling/impute/__init__.py +1 -1
  195. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  196. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  197. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  198. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  199. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  200. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  201. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  202. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  203. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  204. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  205. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  206. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  207. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  208. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  209. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  210. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  211. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  212. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  213. snowflake/ml/modeling/linear_model/lars.py +9 -1
  214. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  215. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  216. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  217. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  218. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  219. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  220. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  221. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  222. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  223. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  224. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  225. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  226. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  227. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  228. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  229. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  230. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  231. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  232. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  233. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  234. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  235. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  236. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  237. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  238. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  239. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  240. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  241. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  242. snowflake/ml/modeling/manifold/isomap.py +9 -1
  243. snowflake/ml/modeling/manifold/mds.py +9 -1
  244. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  245. snowflake/ml/modeling/manifold/tsne.py +9 -1
  246. snowflake/ml/modeling/metrics/__init__.py +1 -1
  247. snowflake/ml/modeling/metrics/classification.py +39 -39
  248. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  249. snowflake/ml/modeling/metrics/ranking.py +7 -7
  250. snowflake/ml/modeling/metrics/regression.py +13 -13
  251. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  252. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  253. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  254. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  255. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  256. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  257. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  258. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  259. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  260. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  261. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  262. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  263. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  264. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  265. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  266. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  267. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  268. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  269. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  270. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  271. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  272. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  273. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  274. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  275. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  276. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  277. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  278. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  279. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  280. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  281. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  282. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  283. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  284. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  285. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  286. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  287. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  288. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  289. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  290. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  291. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  292. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  293. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  294. snowflake/ml/modeling/svm/svc.py +9 -1
  295. snowflake/ml/modeling/svm/svr.py +9 -1
  296. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  297. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  298. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  299. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  300. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  301. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  302. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  303. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  304. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  305. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  306. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  307. snowflake/ml/monitoring/explain_visualize.py +286 -0
  308. snowflake/ml/registry/_manager/model_manager.py +55 -32
  309. snowflake/ml/registry/registry.py +39 -31
  310. snowflake/ml/utils/authentication.py +2 -2
  311. snowflake/ml/utils/connection_params.py +5 -5
  312. snowflake/ml/utils/sparse.py +5 -4
  313. snowflake/ml/utils/sql_client.py +1 -2
  314. snowflake/ml/version.py +2 -1
  315. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
  316. snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
  317. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  318. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  319. snowflake/ml/modeling/_internal/constants.py +0 -2
  320. snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
  321. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  322. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  import inspect
2
2
  import logging
3
3
  import os
4
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
4
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
5
5
 
6
6
  import pandas as pd
7
7
  from typing_extensions import TypeGuard, Unpack
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
24
24
  logger = logging.getLogger(__name__)
25
25
 
26
26
 
27
- def _validate_sentence_transformers_signatures(sigs: Dict[str, model_signature.ModelSignature]) -> None:
27
+ def _validate_sentence_transformers_signatures(sigs: dict[str, model_signature.ModelSignature]) -> None:
28
28
  if list(sigs.keys()) != ["encode"]:
29
29
  raise ValueError("target_methods can only be ['encode']")
30
30
 
@@ -48,7 +48,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
48
48
  HANDLER_TYPE = "sentence_transformers"
49
49
  HANDLER_VERSION = "2024-03-15"
50
50
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
51
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
51
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
52
52
 
53
53
  MODEL_BLOB_FILE_OR_DIR = "model"
54
54
  DEFAULT_TARGET_METHODS = ["encode"]
@@ -166,7 +166,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
166
166
  ],
167
167
  check_local_version=True,
168
168
  )
169
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
169
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
170
170
 
171
171
  @staticmethod
172
172
  def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
@@ -224,7 +224,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
224
224
  def _create_custom_model(
225
225
  raw_model: "sentence_transformers.SentenceTransformer",
226
226
  model_meta: model_meta_api.ModelMetadata,
227
- ) -> Type[custom_model.CustomModel]:
227
+ ) -> type[custom_model.CustomModel]:
228
228
  batch_size = cast(
229
229
  model_meta_schema.SentenceTransformersModelBlobOptions, model_meta.models[model_meta.name].options
230
230
  ).get("batch_size", None)
@@ -1,13 +1,13 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union, cast, final
3
+ from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import numpy as np
7
7
  import pandas as pd
8
8
  from typing_extensions import TypeGuard, Unpack
9
9
 
10
- from snowflake.ml._internal import type_utils
10
+ from snowflake.ml._internal import env, type_utils
11
11
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
12
12
  from snowflake.ml.model._packager.model_env import model_env
13
13
  from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
@@ -19,7 +19,6 @@ from snowflake.ml.model._packager.model_meta import (
19
19
  )
20
20
  from snowflake.ml.model._packager.model_task import model_task_utils
21
21
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
22
- from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
23
22
 
24
23
  if TYPE_CHECKING:
25
24
  import sklearn.base
@@ -39,6 +38,35 @@ def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "s
39
38
  return model
40
39
 
41
40
 
41
+ def _apply_transforms_up_to_last_step(
42
+ model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
43
+ data: model_types.SupportedDataType,
44
+ input_feature_names: Optional[list[str]] = None,
45
+ ) -> pd.DataFrame:
46
+ """Apply all transformations in the sklearn pipeline model up to the last step."""
47
+ transformed_data = data
48
+ output_features_names = input_feature_names
49
+
50
+ if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
51
+ for step_name, step in model.steps[:-1]: # type: ignore[attr-defined]
52
+ if not hasattr(step, "transform"):
53
+ raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
54
+ transformed_data = step.transform(transformed_data)
55
+ if output_features_names is None:
56
+ continue
57
+ elif hasattr(step, "get_feature_names_out"):
58
+ output_features_names = step.get_feature_names_out(output_features_names)
59
+ else:
60
+ raise ValueError(
61
+ f"Step '{step_name}' in the pipeline does not have a 'get_feature_names_out' method. "
62
+ "Feature names cannot be propagated."
63
+ )
64
+ if type_utils.LazyType("scipy.sparse.csr_matrix").isinstance(transformed_data):
65
+ # Convert to dense array if it's a sparse matrix
66
+ transformed_data = transformed_data.toarray() # type: ignore[attr-defined]
67
+ return pd.DataFrame(transformed_data, columns=output_features_names)
68
+
69
+
42
70
  @final
43
71
  class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
44
72
  """Handler for scikit-learn based model.
@@ -49,7 +77,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
49
77
  HANDLER_TYPE = "sklearn"
50
78
  HANDLER_VERSION = "2023-12-01"
51
79
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
52
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
80
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
53
81
 
54
82
  DEFAULT_TARGET_METHODS = [
55
83
  "predict",
@@ -59,7 +87,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
59
87
  "decision_function",
60
88
  "score_samples",
61
89
  ]
62
- EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
90
+
91
+ # Prioritize predict_proba as it gives multi-class probabilities
92
+ EXPLAIN_TARGET_METHODS = ["predict_proba", "predict", "predict_log_proba"]
63
93
 
64
94
  @classmethod
65
95
  def can_handle(
@@ -113,7 +143,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
113
143
  raise ValueError("Sample input data is required to enable explainability.")
114
144
 
115
145
  # If this is a pipeline and we are in the container runtime, check for distributed estimator.
116
- if os.getenv(IN_ML_RUNTIME_ENV_VAR) and isinstance(model, sklearn.pipeline.Pipeline):
146
+ if env.IN_ML_RUNTIME and isinstance(model, sklearn.pipeline.Pipeline):
117
147
  model = _unpack_container_runtime_pipeline(model)
118
148
 
119
149
  if not is_sub_model:
@@ -161,17 +191,38 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
161
191
  stacklevel=1,
162
192
  )
163
193
  enable_explainability = False
164
- elif model_meta.task == model_types.Task.UNKNOWN or explain_target_method is None:
194
+ elif model_meta.task == model_types.Task.UNKNOWN:
195
+ enable_explainability = False
196
+ elif explain_target_method is None:
165
197
  enable_explainability = False
166
198
  else:
167
199
  enable_explainability = True
168
200
  if enable_explainability:
169
- model_meta = handlers_utils.add_explain_method_signature(
170
- model_meta=model_meta,
171
- explain_method="explain",
172
- target_method=explain_target_method,
173
- output_return_type=model_task_and_output_type.output_type,
201
+ explain_target_method = str(explain_target_method) # mypy complains if we don't cast to str here
202
+
203
+ input_signature = handlers_utils.get_input_signature(model_meta, explain_target_method)
204
+ transformed_background_data = _apply_transforms_up_to_last_step(
205
+ model=model,
206
+ data=background_data,
207
+ input_feature_names=[spec.name for spec in input_signature],
174
208
  )
209
+
210
+ try:
211
+ model_meta = handlers_utils.add_inferred_explain_method_signature(
212
+ model_meta=model_meta,
213
+ explain_method="explain",
214
+ target_method=explain_target_method,
215
+ background_data=background_data,
216
+ explain_fn=cls._build_explain_fn(model, background_data, input_signature),
217
+ output_feature_names=transformed_background_data.columns,
218
+ )
219
+ except ValueError:
220
+ if kwargs.get("enable_explainability", None):
221
+ # user explicitly enabled explainability, so we should raise the error
222
+ raise ValueError(
223
+ "Explainability for this model is not supported. Please set `enable_explainability=False`"
224
+ )
225
+
175
226
  handlers_utils.save_background_data(
176
227
  model_blobs_dir_path,
177
228
  cls.EXPLAIN_ARTIFACTS_DIR,
@@ -223,11 +274,13 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
223
274
  )
224
275
 
225
276
  if enable_explainability:
226
- model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
277
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
227
278
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
228
279
 
229
280
  model_meta.env.include_if_absent(
230
- [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")],
281
+ [
282
+ model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
283
+ ],
231
284
  check_local_version=True,
232
285
  )
233
286
 
@@ -265,7 +318,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
265
318
  def _create_custom_model(
266
319
  raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
267
320
  model_meta: model_meta_api.ModelMetadata,
268
- ) -> Type[custom_model.CustomModel]:
321
+ ) -> type[custom_model.CustomModel]:
269
322
  def fn_factory(
270
323
  raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
271
324
  signature: model_signature.ModelSignature,
@@ -287,37 +340,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
287
340
 
288
341
  @custom_model.inference_api
289
342
  def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
290
- import shap
291
-
292
- try:
293
- explainer = shap.Explainer(raw_model, background_data)
294
- df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
295
- except TypeError:
296
- try:
297
- dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
298
-
299
- if isinstance(X, pd.DataFrame):
300
- X = X.astype(dtype_map, copy=False)
301
- if hasattr(raw_model, "predict_proba"):
302
- if isinstance(X, np.ndarray):
303
- explanations = shap.Explainer(
304
- raw_model.predict_proba, background_data.values # type: ignore[union-attr]
305
- )(X).values
306
- else:
307
- explanations = shap.Explainer(raw_model.predict_proba, background_data)(X).values
308
- elif hasattr(raw_model, "predict"):
309
- if isinstance(X, np.ndarray):
310
- explanations = shap.Explainer(
311
- raw_model.predict, background_data.values # type: ignore[union-attr]
312
- )(X).values
313
- else:
314
- explanations = shap.Explainer(raw_model.predict, background_data)(X).values
315
- else:
316
- raise ValueError("Missing any supported target method to explain.")
317
- df = handlers_utils.convert_explanations_to_2D_df(raw_model, explanations)
318
- except TypeError as e:
319
- raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
320
- return model_signature_utils.rename_pandas_df(df, signature.outputs)
343
+ fn = cls._build_explain_fn(raw_model, background_data, signature.inputs)
344
+ return model_signature_utils.rename_pandas_df(fn(X), signature.outputs)
321
345
 
322
346
  if target_method == "explain":
323
347
  return explain_fn
@@ -340,3 +364,37 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
340
364
  skl_model = _SKLModel(custom_model.ModelContext())
341
365
 
342
366
  return skl_model
367
+
368
+ @classmethod
369
+ def _build_explain_fn(
370
+ cls,
371
+ model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
372
+ background_data: model_types.SupportedDataType,
373
+ input_specs: Sequence[model_signature.BaseFeatureSpec],
374
+ ) -> Callable[[model_types.SupportedDataType], pd.DataFrame]:
375
+ import shap
376
+ import sklearn.pipeline
377
+
378
+ transformed_bg_data = _apply_transforms_up_to_last_step(model, background_data)
379
+
380
+ def explain_fn(data: model_types.SupportedDataType) -> pd.DataFrame:
381
+ transformed_data = _apply_transforms_up_to_last_step(model, data)
382
+ predictor = model[-1] if isinstance(model, sklearn.pipeline.Pipeline) else model
383
+ try:
384
+ explainer = shap.Explainer(predictor, transformed_bg_data)
385
+ return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values)
386
+ except TypeError:
387
+ if isinstance(data, pd.DataFrame):
388
+ dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in input_specs}
389
+ transformed_data = _apply_transforms_up_to_last_step(model, data.astype(dtype_map))
390
+ for explain_target_method in cls.EXPLAIN_TARGET_METHODS:
391
+ if not hasattr(predictor, explain_target_method):
392
+ continue
393
+ explain_target_method_fn = getattr(predictor, explain_target_method)
394
+ explanations = shap.Explainer(explain_target_method_fn, transformed_bg_data.values)(
395
+ transformed_data.to_numpy()
396
+ ).values
397
+ return handlers_utils.convert_explanations_to_2D_df(model, explanations)
398
+ raise ValueError("Missing any supported target method to explain.")
399
+
400
+ return explain_fn
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import numpy as np
@@ -36,7 +36,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
36
36
  HANDLER_TYPE = "snowml"
37
37
  HANDLER_VERSION = "2023-12-01"
38
38
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
39
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
39
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
40
40
 
41
41
  DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
42
42
  EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
@@ -264,7 +264,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
264
264
  def _create_custom_model(
265
265
  raw_model: "BaseEstimator",
266
266
  model_meta: model_meta_api.ModelMetadata,
267
- ) -> Type[custom_model.CustomModel]:
267
+ ) -> type[custom_model.CustomModel]:
268
268
  def fn_factory(
269
269
  raw_model: "BaseEstimator",
270
270
  signature: model_signature.ModelSignature,
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
2
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
3
3
 
4
4
  import pandas as pd
5
5
  from packaging import version
@@ -38,7 +38,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
38
38
  HANDLER_TYPE = "tensorflow"
39
39
  HANDLER_VERSION = "2025-03-01"
40
40
  _MIN_SNOWPARK_ML_VERSION = "1.8.0"
41
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
41
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
42
42
  "2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
43
43
  "2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
44
44
  }
@@ -88,6 +88,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
88
88
  import tensorflow
89
89
 
90
90
  assert isinstance(model, tensorflow.Module)
91
+ multiple_inputs = kwargs.get("multiple_inputs", False)
91
92
 
92
93
  is_keras_model = type_utils.LazyType("keras.Model").isinstance(model)
93
94
  is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)
@@ -112,8 +113,6 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
112
113
  default_target_methods=default_target_methods,
113
114
  )
114
115
 
115
- multiple_inputs = kwargs.get("multiple_inputs", False)
116
-
117
116
  if is_keras_model and len(target_methods) > 1:
118
117
  raise ValueError("Keras model can only have one target method.")
119
118
 
@@ -188,7 +187,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
188
187
  dependencies,
189
188
  check_local_version=True,
190
189
  )
191
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
190
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
192
191
 
193
192
  @classmethod
194
193
  def load_model(
@@ -198,7 +197,6 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
198
197
  model_blobs_dir_path: str,
199
198
  **kwargs: Unpack[model_types.TensorflowLoadOptions],
200
199
  ) -> "tensorflow.Module":
201
- os.environ["TF_USE_LEGACY_KERAS"] = "1"
202
200
  import tensorflow
203
201
 
204
202
  model_blob_path = os.path.join(model_blobs_dir_path, name)
@@ -209,7 +207,12 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
209
207
  load_path = os.path.join(model_blob_path, model_blob_filename)
210
208
  save_format = model_blob_options.get("save_format", "keras_tf")
211
209
  if save_format == "keras_tf":
212
- m = tensorflow.keras.models.load_model(load_path)
210
+ if version.parse(tensorflow.keras.__version__) >= version.parse("3.0.0"):
211
+ import tf_keras
212
+
213
+ m = tf_keras.models.load_model(load_path)
214
+ else:
215
+ m = tensorflow.keras.models.load_model(load_path)
213
216
  else:
214
217
  m = tensorflow.saved_model.load(load_path)
215
218
 
@@ -230,7 +233,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
230
233
  def _create_custom_model(
231
234
  raw_model: "tensorflow.Module",
232
235
  model_meta: model_meta_api.ModelMetadata,
233
- ) -> Type[custom_model.CustomModel]:
236
+ ) -> type[custom_model.CustomModel]:
234
237
  multiple_inputs = cast(
235
238
  model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
236
239
  )["multiple_inputs"]
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
2
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
3
3
 
4
4
  import pandas as pd
5
5
  from typing_extensions import TypeGuard, Unpack
@@ -36,7 +36,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
36
36
  HANDLER_TYPE = "torchscript"
37
37
  HANDLER_VERSION = "2025-03-01"
38
38
  _MIN_SNOWPARK_ML_VERSION = "1.8.0"
39
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
39
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
40
40
  "2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
41
41
  }
42
42
 
@@ -76,6 +76,8 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
76
76
  if enable_explainability:
77
77
  raise NotImplementedError("Explainability is not supported for Torch Script model.")
78
78
 
79
+ multiple_inputs = kwargs.get("multiple_inputs", False)
80
+
79
81
  import torch
80
82
 
81
83
  assert isinstance(model, torch.jit.ScriptModule)
@@ -87,8 +89,6 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
87
89
  default_target_methods=cls.DEFAULT_TARGET_METHODS,
88
90
  )
89
91
 
90
- multiple_inputs = kwargs.get("multiple_inputs", False)
91
-
92
92
  def get_prediction(
93
93
  target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
94
94
  ) -> model_types.SupportedLocalDataType:
@@ -141,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
141
141
  model_meta.env.include_if_absent(
142
142
  [model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
143
143
  )
144
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
144
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
145
145
 
146
146
  @classmethod
147
147
  def load_model(
@@ -181,7 +181,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
181
181
  def _create_custom_model(
182
182
  raw_model: "torch.jit.ScriptModule",
183
183
  model_meta: model_meta_api.ModelMetadata,
184
- ) -> Type[custom_model.CustomModel]:
184
+ ) -> type[custom_model.CustomModel]:
185
185
  def fn_factory(
186
186
  raw_model: "torch.jit.ScriptModule",
187
187
  signature: model_signature.ModelSignature,
@@ -1,17 +1,7 @@
1
1
  # mypy: disable-error-code="import"
2
2
  import os
3
3
  import warnings
4
- from typing import (
5
- TYPE_CHECKING,
6
- Any,
7
- Callable,
8
- Dict,
9
- Optional,
10
- Type,
11
- Union,
12
- cast,
13
- final,
14
- )
4
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
15
5
 
16
6
  import numpy as np
17
7
  import pandas as pd
@@ -44,7 +34,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
44
34
  HANDLER_TYPE = "xgboost"
45
35
  HANDLER_VERSION = "2023-12-01"
46
36
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
47
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
37
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
48
38
 
49
39
  MODEL_BLOB_FILE_OR_DIR = "model.ubj"
50
40
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
@@ -154,7 +144,12 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
154
144
  model_type=cls.HANDLER_TYPE,
155
145
  handler_version=cls.HANDLER_VERSION,
156
146
  path=cls.MODEL_BLOB_FILE_OR_DIR,
157
- options=model_meta_schema.XgboostModelBlobOptions({"xgb_estimator_type": model.__class__.__name__}),
147
+ options=model_meta_schema.XgboostModelBlobOptions(
148
+ {
149
+ "xgb_estimator_type": model.__class__.__name__,
150
+ "enable_categorical": getattr(model, "enable_categorical", False),
151
+ }
152
+ ),
158
153
  )
159
154
  model_meta.models[name] = base_meta
160
155
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -162,11 +157,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
162
157
  model_meta.env.include_if_absent(
163
158
  [
164
159
  model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
165
- ],
166
- check_local_version=True,
167
- )
168
- model_meta.env.include_if_absent(
169
- [
170
160
  model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
171
161
  ],
172
162
  check_local_version=True,
@@ -175,7 +165,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
175
165
  if enable_explainability:
176
166
  model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
177
167
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
178
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
168
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
179
169
 
180
170
  @classmethod
181
171
  def load_model(
@@ -200,6 +190,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
200
190
  raise ValueError("Type of XGB estimator is illegal.")
201
191
  m = getattr(xgboost, xgb_estimator_type)()
202
192
  m.load_model(os.path.join(model_blob_path, model_blob_filename))
193
+ m.enable_categorical = model_blob_options.get("enable_categorical", False)
203
194
 
204
195
  if kwargs.get("use_gpu", False):
205
196
  assert type(kwargs.get("use_gpu", False)) == bool
@@ -227,7 +218,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
227
218
  def _create_custom_model(
228
219
  raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
229
220
  model_meta: model_meta_api.ModelMetadata,
230
- ) -> Type[custom_model.CustomModel]:
221
+ ) -> type[custom_model.CustomModel]:
231
222
  def fn_factory(
232
223
  raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
233
224
  signature: model_signature.ModelSignature,
@@ -235,8 +226,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
235
226
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
236
227
  @custom_model.inference_api
237
228
  def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
229
+ enable_categorical = False
230
+ for col, d_type in X.dtypes.items():
231
+ if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
232
+ continue
233
+ if not np.issubdtype(d_type, np.number):
234
+ # categorical columns are converted to numpy's str dtype
235
+ X[col] = X[col].astype("category")
236
+ enable_categorical = True
238
237
  if isinstance(raw_model, xgboost.Booster):
239
- X = xgboost.DMatrix(X)
238
+ X = xgboost.DMatrix(X, enable_categorical=enable_categorical)
240
239
 
241
240
  res = getattr(raw_model, target_method)(X)
242
241
 
@@ -261,7 +260,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
261
260
  return explain_fn
262
261
  return fn
263
262
 
264
- type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
263
+ type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
265
264
  for target_method_name, sig in model_meta.signatures.items():
266
265
  type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
267
266
 
@@ -1,4 +1,4 @@
1
- from typing import Dict, cast
1
+ from typing import cast
2
2
 
3
3
  from typing_extensions import Unpack
4
4
 
@@ -25,7 +25,7 @@ class ModelBlobMeta:
25
25
  self.handler_version = kwargs["handler_version"]
26
26
  self.function_properties = kwargs.get("function_properties", {})
27
27
 
28
- self.artifacts: Dict[str, str] = {}
28
+ self.artifacts: dict[str, str] = {}
29
29
  artifacts = kwargs.get("artifacts", None)
30
30
  if artifacts:
31
31
  self.artifacts = artifacts