snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322):
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/cortex/_classify_text.py +3 -3
  3. snowflake/cortex/_complete.py +23 -24
  4. snowflake/cortex/_embed_text_1024.py +4 -4
  5. snowflake/cortex/_embed_text_768.py +4 -4
  6. snowflake/cortex/_finetune.py +8 -8
  7. snowflake/cortex/_util.py +8 -12
  8. snowflake/ml/_internal/env.py +4 -3
  9. snowflake/ml/_internal/env_utils.py +63 -34
  10. snowflake/ml/_internal/file_utils.py +10 -21
  11. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  12. snowflake/ml/_internal/init_utils.py +2 -3
  13. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  14. snowflake/ml/_internal/platform_capabilities.py +18 -16
  15. snowflake/ml/_internal/telemetry.py +39 -52
  16. snowflake/ml/_internal/type_utils.py +3 -3
  17. snowflake/ml/_internal/utils/db_utils.py +2 -2
  18. snowflake/ml/_internal/utils/identifier.py +10 -10
  19. snowflake/ml/_internal/utils/import_utils.py +2 -2
  20. snowflake/ml/_internal/utils/parallelize.py +7 -7
  21. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  22. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  23. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  24. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  25. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  26. snowflake/ml/_internal/utils/table_manager.py +9 -9
  27. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  28. snowflake/ml/data/data_connector.py +15 -36
  29. snowflake/ml/data/data_ingestor.py +4 -15
  30. snowflake/ml/data/data_source.py +2 -2
  31. snowflake/ml/data/ingestor_utils.py +3 -3
  32. snowflake/ml/data/torch_utils.py +5 -5
  33. snowflake/ml/dataset/dataset.py +11 -11
  34. snowflake/ml/dataset/dataset_metadata.py +8 -8
  35. snowflake/ml/dataset/dataset_reader.py +7 -7
  36. snowflake/ml/feature_store/__init__.py +1 -1
  37. snowflake/ml/feature_store/access_manager.py +7 -7
  38. snowflake/ml/feature_store/entity.py +6 -6
  39. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  41. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  45. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  48. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  51. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  52. snowflake/ml/feature_store/feature_store.py +52 -64
  53. snowflake/ml/feature_store/feature_view.py +24 -24
  54. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  55. snowflake/ml/fileset/fileset.py +5 -5
  56. snowflake/ml/fileset/sfcfs.py +13 -13
  57. snowflake/ml/fileset/stage_fs.py +15 -15
  58. snowflake/ml/jobs/_utils/constants.py +1 -1
  59. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  60. snowflake/ml/jobs/_utils/payload_utils.py +45 -46
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  64. snowflake/ml/jobs/_utils/spec_utils.py +18 -29
  65. snowflake/ml/jobs/_utils/types.py +2 -2
  66. snowflake/ml/jobs/decorators.py +10 -5
  67. snowflake/ml/jobs/job.py +87 -30
  68. snowflake/ml/jobs/manager.py +86 -56
  69. snowflake/ml/lineage/lineage_node.py +5 -5
  70. snowflake/ml/model/_client/model/model_impl.py +3 -3
  71. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  72. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  73. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  74. snowflake/ml/model/_client/ops/service_ops.py +217 -32
  75. snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
  76. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
  77. snowflake/ml/model/_client/sql/model.py +8 -8
  78. snowflake/ml/model/_client/sql/model_version.py +26 -26
  79. snowflake/ml/model/_client/sql/service.py +17 -26
  80. snowflake/ml/model/_client/sql/stage.py +2 -2
  81. snowflake/ml/model/_client/sql/tag.py +6 -6
  82. snowflake/ml/model/_model_composer/model_composer.py +58 -32
  83. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  85. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  86. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  87. snowflake/ml/model/_packager/model_handler.py +4 -4
  88. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  89. snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
  90. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  91. snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
  92. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  93. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  95. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  96. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
  97. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  98. snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
  99. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  100. snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
  101. snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
  102. snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
  103. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  104. snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
  105. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
  106. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  107. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  109. snowflake/ml/model/_packager/model_packager.py +11 -9
  110. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  111. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  112. snowflake/ml/model/_signatures/core.py +16 -24
  113. snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
  114. snowflake/ml/model/_signatures/utils.py +6 -6
  115. snowflake/ml/model/custom_model.py +24 -11
  116. snowflake/ml/model/model_signature.py +12 -23
  117. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  118. snowflake/ml/model/type_hints.py +3 -3
  119. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  120. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  122. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  123. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  124. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  125. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  126. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  128. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  129. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  130. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  131. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  132. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  133. snowflake/ml/modeling/cluster/birch.py +9 -1
  134. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  135. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  136. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  137. snowflake/ml/modeling/cluster/k_means.py +9 -1
  138. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  139. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  140. snowflake/ml/modeling/cluster/optics.py +9 -1
  141. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  142. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  143. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  144. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  145. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  146. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  147. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  148. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  149. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  150. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  151. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  152. snowflake/ml/modeling/covariance/oas.py +9 -1
  153. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  154. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  155. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  156. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  157. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  158. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  159. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  160. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  161. snowflake/ml/modeling/decomposition/pca.py +9 -1
  162. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  163. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  164. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  165. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  166. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  167. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  168. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  169. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  170. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  171. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  172. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  173. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  174. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  175. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  176. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  177. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  178. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  179. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  180. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  181. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  182. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  183. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  184. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  185. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  186. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  187. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  188. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  189. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  190. snowflake/ml/modeling/framework/_utils.py +10 -10
  191. snowflake/ml/modeling/framework/base.py +32 -32
  192. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  193. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  194. snowflake/ml/modeling/impute/__init__.py +1 -1
  195. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  196. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  197. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  198. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  199. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  200. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  201. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  202. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  203. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  204. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  205. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  206. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  207. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  208. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  209. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  210. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  211. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  212. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  213. snowflake/ml/modeling/linear_model/lars.py +9 -1
  214. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  215. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  216. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  217. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  218. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  219. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  220. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  221. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  222. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  223. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  224. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  225. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  226. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  227. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  228. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  229. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  230. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  231. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  232. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  233. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  234. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  235. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  236. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  237. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  238. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  239. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  240. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  241. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  242. snowflake/ml/modeling/manifold/isomap.py +9 -1
  243. snowflake/ml/modeling/manifold/mds.py +9 -1
  244. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  245. snowflake/ml/modeling/manifold/tsne.py +9 -1
  246. snowflake/ml/modeling/metrics/__init__.py +1 -1
  247. snowflake/ml/modeling/metrics/classification.py +39 -39
  248. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  249. snowflake/ml/modeling/metrics/ranking.py +7 -7
  250. snowflake/ml/modeling/metrics/regression.py +13 -13
  251. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  252. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  253. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  254. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  255. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  256. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  257. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  258. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  259. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  260. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  261. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  262. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  263. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  264. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  265. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  266. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  267. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  268. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  269. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  270. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  271. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  272. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  273. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  274. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  275. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  276. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  277. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  278. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  279. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  280. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  281. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  282. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  283. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  284. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  285. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  286. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  287. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  288. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  289. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  290. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  291. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  292. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  293. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  294. snowflake/ml/modeling/svm/svc.py +9 -1
  295. snowflake/ml/modeling/svm/svr.py +9 -1
  296. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  297. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  298. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  299. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  300. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  301. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  302. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  303. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  304. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  305. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  306. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  307. snowflake/ml/monitoring/explain_visualize.py +286 -0
  308. snowflake/ml/registry/_manager/model_manager.py +55 -32
  309. snowflake/ml/registry/registry.py +39 -31
  310. snowflake/ml/utils/authentication.py +2 -2
  311. snowflake/ml/utils/connection_params.py +5 -5
  312. snowflake/ml/utils/sparse.py +5 -4
  313. snowflake/ml/utils/sql_client.py +1 -2
  314. snowflake/ml/version.py +2 -1
  315. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
  316. snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
  317. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  318. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  319. snowflake/ml/modeling/_internal/constants.py +0 -2
  320. snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
  321. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  322. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Mapping, Optional
1
+ from typing import Any, Mapping, Optional
2
2
 
3
3
  from snowflake import snowpark
4
4
  from snowflake.ml._internal.utils import (
@@ -15,7 +15,7 @@ MODEL_JSON_MODEL_NAME_FIELD = "model_name"
15
15
  MODEL_JSON_VERSION_NAME_FIELD = "version_name"
16
16
 
17
17
 
18
- def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str:
18
+ def _build_sql_list_from_columns(columns: list[sql_identifier.SqlIdentifier]) -> str:
19
19
  sql_list = ", ".join([f"'{column}'" for column in columns])
20
20
  return f"({sql_list})"
21
21
 
@@ -60,17 +60,17 @@ class ModelMonitorSQLClient:
60
60
  function_name: str,
61
61
  warehouse_name: sql_identifier.SqlIdentifier,
62
62
  timestamp_column: sql_identifier.SqlIdentifier,
63
- id_columns: List[sql_identifier.SqlIdentifier],
64
- prediction_score_columns: List[sql_identifier.SqlIdentifier],
65
- prediction_class_columns: List[sql_identifier.SqlIdentifier],
66
- actual_score_columns: List[sql_identifier.SqlIdentifier],
67
- actual_class_columns: List[sql_identifier.SqlIdentifier],
63
+ id_columns: list[sql_identifier.SqlIdentifier],
64
+ prediction_score_columns: list[sql_identifier.SqlIdentifier],
65
+ prediction_class_columns: list[sql_identifier.SqlIdentifier],
66
+ actual_score_columns: list[sql_identifier.SqlIdentifier],
67
+ actual_class_columns: list[sql_identifier.SqlIdentifier],
68
68
  refresh_interval: str,
69
69
  aggregation_window: str,
70
70
  baseline_database: Optional[sql_identifier.SqlIdentifier] = None,
71
71
  baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
72
72
  baseline: Optional[sql_identifier.SqlIdentifier] = None,
73
- statement_params: Optional[Dict[str, Any]] = None,
73
+ statement_params: Optional[dict[str, Any]] = None,
74
74
  ) -> None:
75
75
  baseline_sql = ""
76
76
  if baseline:
@@ -103,7 +103,7 @@ class ModelMonitorSQLClient:
103
103
  database_name: Optional[sql_identifier.SqlIdentifier] = None,
104
104
  schema_name: Optional[sql_identifier.SqlIdentifier] = None,
105
105
  monitor_name: sql_identifier.SqlIdentifier,
106
- statement_params: Optional[Dict[str, Any]] = None,
106
+ statement_params: Optional[dict[str, Any]] = None,
107
107
  ) -> None:
108
108
  search_database_name = database_name or self._database_name
109
109
  search_schema_name = schema_name or self._schema_name
@@ -116,8 +116,8 @@ class ModelMonitorSQLClient:
116
116
  def show_model_monitors(
117
117
  self,
118
118
  *,
119
- statement_params: Optional[Dict[str, Any]] = None,
120
- ) -> List[snowpark.Row]:
119
+ statement_params: Optional[dict[str, Any]] = None,
120
+ ) -> list[snowpark.Row]:
121
121
  fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
122
122
  return (
123
123
  query_result_checker.SqlResultValidator(
@@ -135,7 +135,7 @@ class ModelMonitorSQLClient:
135
135
  database_name: Optional[sql_identifier.SqlIdentifier] = None,
136
136
  schema_name: Optional[sql_identifier.SqlIdentifier] = None,
137
137
  monitor_name: sql_identifier.SqlIdentifier,
138
- statement_params: Optional[Dict[str, Any]] = None,
138
+ statement_params: Optional[dict[str, Any]] = None,
139
139
  ) -> bool:
140
140
  search_database_name = database_name or self._database_name
141
141
  search_schema_name = schema_name or self._schema_name
@@ -153,7 +153,7 @@ class ModelMonitorSQLClient:
153
153
  def validate_monitor_warehouse(
154
154
  self,
155
155
  warehouse_name: sql_identifier.SqlIdentifier,
156
- statement_params: Optional[Dict[str, Any]] = None,
156
+ statement_params: Optional[dict[str, Any]] = None,
157
157
  ) -> None:
158
158
  """Validate warehouse provided for monitoring exists.
159
159
 
@@ -177,11 +177,11 @@ class ModelMonitorSQLClient:
177
177
  *,
178
178
  source_column_schema: Mapping[str, types.DataType],
179
179
  timestamp_column: sql_identifier.SqlIdentifier,
180
- prediction_score_columns: List[sql_identifier.SqlIdentifier],
181
- prediction_class_columns: List[sql_identifier.SqlIdentifier],
182
- actual_score_columns: List[sql_identifier.SqlIdentifier],
183
- actual_class_columns: List[sql_identifier.SqlIdentifier],
184
- id_columns: List[sql_identifier.SqlIdentifier],
180
+ prediction_score_columns: list[sql_identifier.SqlIdentifier],
181
+ prediction_class_columns: list[sql_identifier.SqlIdentifier],
182
+ actual_score_columns: list[sql_identifier.SqlIdentifier],
183
+ actual_class_columns: list[sql_identifier.SqlIdentifier],
184
+ id_columns: list[sql_identifier.SqlIdentifier],
185
185
  ) -> None:
186
186
  """Ensures all columns exist in the source table.
187
187
 
@@ -221,11 +221,11 @@ class ModelMonitorSQLClient:
221
221
  source_schema: Optional[sql_identifier.SqlIdentifier],
222
222
  source: sql_identifier.SqlIdentifier,
223
223
  timestamp_column: sql_identifier.SqlIdentifier,
224
- prediction_score_columns: List[sql_identifier.SqlIdentifier],
225
- prediction_class_columns: List[sql_identifier.SqlIdentifier],
226
- actual_score_columns: List[sql_identifier.SqlIdentifier],
227
- actual_class_columns: List[sql_identifier.SqlIdentifier],
228
- id_columns: List[sql_identifier.SqlIdentifier],
224
+ prediction_score_columns: list[sql_identifier.SqlIdentifier],
225
+ prediction_class_columns: list[sql_identifier.SqlIdentifier],
226
+ actual_score_columns: list[sql_identifier.SqlIdentifier],
227
+ actual_class_columns: list[sql_identifier.SqlIdentifier],
228
+ id_columns: list[sql_identifier.SqlIdentifier],
229
229
  ) -> None:
230
230
  source_database = source_database or self._database_name
231
231
  source_schema = source_schema or self._schema_name
@@ -250,7 +250,7 @@ class ModelMonitorSQLClient:
250
250
  self,
251
251
  operation: str,
252
252
  monitor_name: sql_identifier.SqlIdentifier,
253
- statement_params: Optional[Dict[str, Any]] = None,
253
+ statement_params: Optional[dict[str, Any]] = None,
254
254
  ) -> None:
255
255
  if operation not in {"SUSPEND", "RESUME"}:
256
256
  raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables")
@@ -263,7 +263,7 @@ class ModelMonitorSQLClient:
263
263
  def suspend_monitor(
264
264
  self,
265
265
  monitor_name: sql_identifier.SqlIdentifier,
266
- statement_params: Optional[Dict[str, Any]] = None,
266
+ statement_params: Optional[dict[str, Any]] = None,
267
267
  ) -> None:
268
268
  self._alter_monitor(
269
269
  operation="SUSPEND",
@@ -274,7 +274,7 @@ class ModelMonitorSQLClient:
274
274
  def resume_monitor(
275
275
  self,
276
276
  monitor_name: sql_identifier.SqlIdentifier,
277
- statement_params: Optional[Dict[str, Any]] = None,
277
+ statement_params: Optional[dict[str, Any]] = None,
278
278
  ) -> None:
279
279
  self._alter_monitor(
280
280
  operation="RESUME",
@@ -1,5 +1,5 @@
1
1
  import json
2
- from typing import Any, Dict, List, Optional
2
+ from typing import Any, Optional
3
3
 
4
4
  from snowflake import snowpark
5
5
  from snowflake.ml._internal.utils import sql_identifier
@@ -20,7 +20,7 @@ class ModelMonitorManager:
20
20
  database_name: sql_identifier.SqlIdentifier,
21
21
  schema_name: sql_identifier.SqlIdentifier,
22
22
  *,
23
- statement_params: Optional[Dict[str, Any]] = None,
23
+ statement_params: Optional[dict[str, Any]] = None,
24
24
  ) -> None:
25
25
  """
26
26
  Opens a ModelMonitorManager for a given database and schema.
@@ -64,7 +64,7 @@ class ModelMonitorManager:
64
64
  f"Found: {existing_target_methods}."
65
65
  )
66
66
 
67
- def _build_column_list_from_input(self, columns: Optional[List[str]]) -> List[sql_identifier.SqlIdentifier]:
67
+ def _build_column_list_from_input(self, columns: Optional[list[str]]) -> list[sql_identifier.SqlIdentifier]:
68
68
  return [sql_identifier.SqlIdentifier(column_name) for column_name in columns] if columns else []
69
69
 
70
70
  def add_monitor(
@@ -172,7 +172,7 @@ class ModelMonitorManager:
172
172
  """
173
173
  rows = self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)
174
174
 
175
- def model_match_fn(model_details: Dict[str, str]) -> bool:
175
+ def model_match_fn(model_details: dict[str, str]) -> bool:
176
176
  return (
177
177
  model_details[model_monitor_sql_client.MODEL_JSON_MODEL_NAME_FIELD] == model_version.model_name
178
178
  and model_details[model_monitor_sql_client.MODEL_JSON_VERSION_NAME_FIELD] == model_version.version_name
@@ -215,7 +215,7 @@ class ModelMonitorManager:
215
215
  name=monitor_name_id,
216
216
  )
217
217
 
218
- def show_model_monitors(self) -> List[snowpark.Row]:
218
+ def show_model_monitors(self) -> list[snowpark.Row]:
219
219
  """Show all model monitors in the registry.
220
220
 
221
221
  Returns:
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import List, Optional
2
+ from typing import Optional
3
3
 
4
4
  from snowflake.ml.model._client.model import model_version_impl
5
5
 
@@ -14,20 +14,20 @@ class ModelMonitorSourceConfig:
14
14
  timestamp_column: str
15
15
  """Name of column in the source containing timestamp."""
16
16
 
17
- id_columns: List[str]
17
+ id_columns: list[str]
18
18
  """List of columns in the source containing unique identifiers."""
19
19
 
20
- prediction_score_columns: Optional[List[str]] = None
20
+ prediction_score_columns: Optional[list[str]] = None
21
21
  """List of columns in the source containing prediction scores.
22
22
  Can be regression scores for regression models and probability scores for classification models."""
23
23
 
24
- prediction_class_columns: Optional[List[str]] = None
24
+ prediction_class_columns: Optional[list[str]] = None
25
25
  """List of columns in the source containing prediction classes for classification models."""
26
26
 
27
- actual_score_columns: Optional[List[str]] = None
27
+ actual_score_columns: Optional[list[str]] = None
28
28
  """List of columns in the source containing actual scores."""
29
29
 
30
- actual_class_columns: Optional[List[str]] = None
30
+ actual_class_columns: Optional[list[str]] = None
31
31
  """List of columns in the source containing actual classes for classification models."""
32
32
 
33
33
  baseline: Optional[str] = None
@@ -0,0 +1,286 @@
1
+ from typing import Union, cast, overload
2
+
3
+ import altair as alt
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ import snowflake.snowpark.dataframe as sp_df
8
+ from snowflake import snowpark
9
+ from snowflake.ml.model import model_signature, type_hints
10
+ from snowflake.ml.model._signatures import snowpark_handler
11
+
12
+
13
@overload
def plot_force(
    shap_row: snowpark.Row,
    features_row: snowpark.Row,
    base_value: float = 0.0,
    figsize: tuple[float, float] = (600, 200),
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    ...


@overload
def plot_force(
    shap_row: pd.Series,
    features_row: pd.Series,
    base_value: float = 0.0,
    figsize: tuple[float, float] = (600, 200),
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    ...


def plot_force(
    shap_row: Union[pd.Series, snowpark.Row],
    features_row: Union[pd.Series, snowpark.Row],
    base_value: float = 0.0,
    figsize: tuple[float, float] = (600, 200),
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    """
    Create a force plot for SHAP values with stacked bars based on influence direction.

    Bars are stacked outward from ``base_value + sum(SHAP values)``. Positive
    contributions extend to its left and negative contributions to its right.
    A dashed rule marks ``base_value``. Positive bars and their arrow markers are
    red; negative ones are blue.

    Args:
        shap_row: pandas Series or snowpark Row containing SHAP values for a specific instance.
            SHAP values are paired with ``features_row`` by position, so both must list the
            features in the same order.
        features_row: pandas Series or snowpark Row containing the feature values for the same instance
        base_value: base value of the predictions. Defaults to 0, but is usually the model's average prediction
        figsize: tuple of (width, height) for the plot
        contribution_threshold:
            Only features whose absolute SHAP value is at least contribution_threshold as a fraction of the
            total absolute SHAP values will be plotted. Defaults to 0.05 (5%)

    Returns:
        Altair chart object
    """
    if isinstance(shap_row, snowpark.Row):
        shap_row = pd.Series(shap_row.as_dict())
    if isinstance(features_row, snowpark.Row):
        features_row = pd.Series(features_row.as_dict())

    positive_label = "Positive"
    negative_label = "Negative"
    # Bars and arrow markers share one color mapping so direction is encoded consistently.
    positive_color = "#d62728"  # red
    negative_color = "#1f77b4"  # blue
    direction_domain = [positive_label, negative_label]

    # One row per feature; SHAP and feature values are paired positionally.
    plot_df = pd.DataFrame(
        [
            {
                "feature": feature,
                "feature_value": features_row.iloc[index],
                "feature_annotated": f"{feature}: {features_row.iloc[index]}",
                "influence_value": shap_row.iloc[index],
                "bar_direction": positive_label if shap_row.iloc[index] >= 0 else negative_label,
            }
            for index, feature in enumerate(features_row.index)
        ]
    )

    # Stack the bars around the explained output (base value + sum of SHAP values), so that the
    # dashed rule drawn at base_value lives on the same axis as the bars.
    prediction = base_value + np.sum(shap_row)
    current_position_pos = prediction
    current_position_neg = prediction
    positions = []

    abs_influence = plot_df["influence_value"].abs()
    total_abs_value_sum = np.sum(abs_influence)
    # Gap between adjacent bars: 7% of the largest absolute SHAP value.
    spacing = abs_influence.max() * 0.07

    # Sort by absolute value so the largest impacts are stacked first, next to the prediction.
    plot_df = plot_df.reindex(abs_influence.sort_values(ascending=False).index)
    for _, row in plot_df.iterrows():
        row_influence_value = row["influence_value"]
        # Skip features whose share of the total absolute SHAP mass is below the threshold.
        # When every SHAP value is zero the share is undefined; nothing is filtered and no 0/0
        # is evaluated.
        if total_abs_value_sum > 0 and abs(row_influence_value) / total_abs_value_sum < contribution_threshold:
            continue

        if row_influence_value >= 0:
            start = current_position_pos - spacing
            end = current_position_pos - row_influence_value
            current_position_pos = end
        else:
            start = current_position_neg + spacing
            end = current_position_neg + abs(row_influence_value)
            current_position_neg = end

        positions.append(
            {
                "start": start,
                "end": end,
                "avg": (start + end) / 2,
                "influence_value": row_influence_value,
                "influence_annotated": f"Influence: {row_influence_value}",
                "feature_value": row["feature_value"],
                "feature_annotated": row["feature_annotated"],
                "bar_direction": row["bar_direction"],
            }
        )

    # Explicit columns keep the expected schema even if every feature was filtered out.
    position_df = pd.DataFrame(
        positions,
        columns=[
            "start",
            "end",
            "avg",
            "influence_value",
            "influence_annotated",
            "feature_value",
            "feature_annotated",
            "bar_direction",
        ],
    )

    width, height = figsize
    bars: alt.Chart = (
        alt.Chart(position_df)
        .mark_bar(size=10)
        .encode(
            x=alt.X("start:Q", title="Feature Impact"),
            x2=alt.X2("end:Q"),
            color=alt.Color(
                "bar_direction:N",
                scale=alt.Scale(domain=direction_domain, range=[positive_color, negative_color]),
                legend=alt.Legend(title="Influence Direction"),
            ),
            tooltip=["influence_value", "feature_value"],
        )
        .properties(title="Feature Influence (SHAP values)", width=width, height=height)
    ).interactive()

    # Triangle markers at the start of each bar, rotated by direction (90 positive, -90 negative)
    # and colored with the same scale as the bars.
    arrow: alt.Chart = (
        alt.Chart(position_df)
        .mark_point(shape="triangle", filled=True, fillOpacity=1)
        .encode(
            x=alt.X("start:Q"),
            angle=alt.Angle("bar_direction:N", scale=alt.Scale(domain=direction_domain, range=[90, -90])),
            color=alt.Color(
                "bar_direction:N",
                scale=alt.Scale(domain=direction_domain, range=[positive_color, negative_color]),
            ),
            size=alt.SizeValue(300),
            tooltip=alt.value(None),
        )
    )

    # Dashed vertical rule at the base value.
    base_line: alt.Chart = (
        alt.Chart(pd.DataFrame({"x": [base_value]})).mark_rule(strokeDash=[3, 3]).encode(x="x:Q")
    )

    # "feature: value" labels centered at each bar's midpoint.
    feature_labels = (
        alt.Chart(position_df)
        .mark_text(align="center", baseline="line-bottom", dy=30, fontSize=11)
        .encode(
            x=alt.X("avg:Q"),
            text=alt.Text("feature_annotated:N"),
            color=alt.value("grey"),  # same label color for both directions
            tooltip=["feature_value"],
        )
    )

    return cast(alt.LayerChart, bars + feature_labels + base_line + arrow)
170
+
171
+
172
def plot_influence_sensitivity(
    feature_values: pd.Series, shap_values: pd.Series, figsize: tuple[float, float] = (600, 400)
) -> alt.Chart:
    """
    Create a SHAP dependence scatter plot for a specific feature.

    The feature is treated as categorical when the most frequent value, or the average
    number of points per unique value, exceeds 10. Categorical features get a nominal
    x axis, one color per value and random horizontal jitter. Otherwise the x axis is
    quantitative.

    Args:
        feature_values: pandas Series containing the feature values for a specific feature
        shap_values: pandas Series containing the SHAP values for the same feature. The two
            Series are combined into one DataFrame by index alignment, so they should share
            the same index.
        figsize: tuple of (width, height) for the plot

    Returns:
        Altair chart object

    Raises:
        ValueError: If ``feature_values`` is empty.
    """
    # Without this guard an empty Series fails deep inside numpy with a zero-size reduction error.
    if len(feature_values) == 0:
        raise ValueError("feature_values must contain at least one value to plot.")

    # Count occurrences of each distinct value to decide whether the feature looks categorical.
    unique_vals = np.sort(np.unique(feature_values.values))
    max_points_per_unique_value = float(np.max(np.bincount(np.searchsorted(unique_vals, feature_values.values))))
    points_per_value = len(feature_values.values) / len(unique_vals)
    is_categorical = float(max(max_points_per_unique_value, points_per_value)) > 10

    kwargs = (
        {
            "x": alt.X("feature_value:N", title="Feature Value"),
            "color": alt.Color("feature_value:N").legend(None),
            "xOffset": "jitter:Q",
        }
        if is_categorical
        else {"x": alt.X("feature_value:Q", title="Feature Value")}
    )

    plot_df = pd.DataFrame({"feature_value": feature_values, "shap_value": shap_values})

    width, height = figsize

    # The "jitter" field is only referenced by the categorical encoding (xOffset).
    scatter = (
        alt.Chart(plot_df)
        .transform_calculate(jitter="random()")
        .mark_circle(size=60, opacity=0.7)
        .encode(
            y=alt.Y("shap_value:Q", title="SHAP Value"),
            tooltip=["feature_value", "shap_value"],
            **kwargs,
        )
        .properties(title="SHAP Dependence Scatter Plot", width=width, height=height)
    )

    return cast(alt.Chart, scatter)
222
+
223
+
224
def plot_violin(
    shap_df: type_hints.SupportedDataType,
    feature_df: type_hints.SupportedDataType,
    figsize: tuple[float, float] = (600, 200),
) -> alt.Chart:
    """
    Create a violin plot per feature showing the distribution of SHAP values.

    Column ``i`` of ``shap_df`` holds the SHAP values for column ``i`` of ``feature_df``.
    Only the column names of ``feature_df`` are used, as the per-feature row labels.
    Rows are ordered by each feature's total absolute SHAP value, largest first.

    Args:
        shap_df: 2D data containing SHAP values for multiple features
        feature_df: 2D data containing the corresponding feature values
        figsize: tuple of (width, height) for the plot

    Returns:
        Altair chart object

    Raises:
        ValueError: If either input is not 2D, or if the inputs have a different number of columns.
    """
    shap_df_pd = _convert_to_pandas_df(shap_df)
    feature_df_pd = _convert_to_pandas_df(feature_df)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if len(shap_df_pd.shape) != 2:
        raise ValueError(f"shap_df must be 2D, but got shape {shap_df_pd.shape}")
    if len(feature_df_pd.shape) != 2:
        raise ValueError(f"feature_df must be 2D, but got shape {feature_df_pd.shape}")
    # Feature names are paired with SHAP columns positionally, so the column counts must agree.
    if shap_df_pd.shape[1] != feature_df_pd.shape[1]:
        raise ValueError(
            "shap_df and feature_df must have the same number of columns, but got "
            f"{shap_df_pd.shape[1]} and {feature_df_pd.shape[1]}"
        )

    # Long format: each feature name repeated once per SHAP row, next to that column's values.
    plot_data = pd.DataFrame(
        {
            "feature_name": feature_df_pd.columns.repeat(shap_df_pd.shape[0]),
            "shap_value": shap_df_pd.transpose().values.flatten(),
        }
    )

    # Order rows by the absolute sum of SHAP values per feature. Positions are used instead of
    # labels, so duplicate SHAP column names cannot break the lookup.
    feature_abs_sum = shap_df_pd.abs().sum(axis=0).reset_index(drop=True)
    sorted_positions = feature_abs_sum.sort_values(ascending=False).index
    column_sort_order = [feature_df_pd.columns[position] for position in sorted_positions]

    width, height = figsize
    violin = (
        alt.Chart(plot_data)
        .transform_density(density="shap_value", groupby=["feature_name"], as_=["shap_value", "density"])
        .mark_area(orient="vertical")
        .encode(
            y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=True),
            x=alt.X("shap_value:Q", title="SHAP Value"),
            row=alt.Row("feature_name:N", sort=column_sort_order).spacing(0),
            color=alt.Color("feature_name:N", legend=None),
            tooltip=["feature_name", "shap_value"],
        )
        .properties(width=width, height=height)
    ).interactive()

    return cast(alt.Chart, violin)
278
+
279
+
280
def _convert_to_pandas_df(
    data: type_hints.SupportedDataType,
) -> pd.DataFrame:
    """Normalize any supported input into a pandas DataFrame.

    Snowpark DataFrames go through ``SnowparkDataFrameHandler.convert_to_df``;
    every other supported type goes through ``model_signature._convert_local_data_to_df``.
    """
    is_snowpark_frame = isinstance(data, sp_df.DataFrame)
    if not is_snowpark_frame:
        return model_signature._convert_local_data_to_df(data)
    return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(data)