snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/cortex/_classify_text.py +3 -3
  3. snowflake/cortex/_complete.py +23 -24
  4. snowflake/cortex/_embed_text_1024.py +4 -4
  5. snowflake/cortex/_embed_text_768.py +4 -4
  6. snowflake/cortex/_finetune.py +8 -8
  7. snowflake/cortex/_util.py +8 -12
  8. snowflake/ml/_internal/env.py +4 -3
  9. snowflake/ml/_internal/env_utils.py +63 -34
  10. snowflake/ml/_internal/file_utils.py +10 -21
  11. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  12. snowflake/ml/_internal/init_utils.py +2 -3
  13. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  14. snowflake/ml/_internal/platform_capabilities.py +18 -16
  15. snowflake/ml/_internal/telemetry.py +39 -52
  16. snowflake/ml/_internal/type_utils.py +3 -3
  17. snowflake/ml/_internal/utils/db_utils.py +2 -2
  18. snowflake/ml/_internal/utils/identifier.py +10 -10
  19. snowflake/ml/_internal/utils/import_utils.py +2 -2
  20. snowflake/ml/_internal/utils/parallelize.py +7 -7
  21. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  22. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  23. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  24. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  25. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  26. snowflake/ml/_internal/utils/table_manager.py +9 -9
  27. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  28. snowflake/ml/data/data_connector.py +15 -36
  29. snowflake/ml/data/data_ingestor.py +4 -15
  30. snowflake/ml/data/data_source.py +2 -2
  31. snowflake/ml/data/ingestor_utils.py +3 -3
  32. snowflake/ml/data/torch_utils.py +5 -5
  33. snowflake/ml/dataset/dataset.py +11 -11
  34. snowflake/ml/dataset/dataset_metadata.py +8 -8
  35. snowflake/ml/dataset/dataset_reader.py +7 -7
  36. snowflake/ml/feature_store/__init__.py +1 -1
  37. snowflake/ml/feature_store/access_manager.py +7 -7
  38. snowflake/ml/feature_store/entity.py +6 -6
  39. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  41. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  45. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  48. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  51. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  52. snowflake/ml/feature_store/feature_store.py +52 -64
  53. snowflake/ml/feature_store/feature_view.py +24 -24
  54. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  55. snowflake/ml/fileset/fileset.py +5 -5
  56. snowflake/ml/fileset/sfcfs.py +13 -13
  57. snowflake/ml/fileset/stage_fs.py +15 -15
  58. snowflake/ml/jobs/_utils/constants.py +1 -1
  59. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  60. snowflake/ml/jobs/_utils/payload_utils.py +45 -46
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  64. snowflake/ml/jobs/_utils/spec_utils.py +18 -29
  65. snowflake/ml/jobs/_utils/types.py +2 -2
  66. snowflake/ml/jobs/decorators.py +10 -5
  67. snowflake/ml/jobs/job.py +87 -30
  68. snowflake/ml/jobs/manager.py +86 -56
  69. snowflake/ml/lineage/lineage_node.py +5 -5
  70. snowflake/ml/model/_client/model/model_impl.py +3 -3
  71. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  72. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  73. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  74. snowflake/ml/model/_client/ops/service_ops.py +217 -32
  75. snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
  76. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
  77. snowflake/ml/model/_client/sql/model.py +8 -8
  78. snowflake/ml/model/_client/sql/model_version.py +26 -26
  79. snowflake/ml/model/_client/sql/service.py +17 -26
  80. snowflake/ml/model/_client/sql/stage.py +2 -2
  81. snowflake/ml/model/_client/sql/tag.py +6 -6
  82. snowflake/ml/model/_model_composer/model_composer.py +58 -32
  83. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  85. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  86. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  87. snowflake/ml/model/_packager/model_handler.py +4 -4
  88. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  89. snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
  90. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  91. snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
  92. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  93. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  95. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  96. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
  97. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  98. snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
  99. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  100. snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
  101. snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
  102. snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
  103. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  104. snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
  105. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
  106. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  107. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  109. snowflake/ml/model/_packager/model_packager.py +11 -9
  110. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  111. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  112. snowflake/ml/model/_signatures/core.py +16 -24
  113. snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
  114. snowflake/ml/model/_signatures/utils.py +6 -6
  115. snowflake/ml/model/custom_model.py +24 -11
  116. snowflake/ml/model/model_signature.py +12 -23
  117. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  118. snowflake/ml/model/type_hints.py +3 -3
  119. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  120. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  122. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  123. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  124. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  125. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  126. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  128. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  129. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  130. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  131. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  132. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  133. snowflake/ml/modeling/cluster/birch.py +9 -1
  134. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  135. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  136. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  137. snowflake/ml/modeling/cluster/k_means.py +9 -1
  138. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  139. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  140. snowflake/ml/modeling/cluster/optics.py +9 -1
  141. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  142. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  143. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  144. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  145. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  146. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  147. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  148. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  149. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  150. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  151. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  152. snowflake/ml/modeling/covariance/oas.py +9 -1
  153. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  154. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  155. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  156. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  157. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  158. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  159. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  160. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  161. snowflake/ml/modeling/decomposition/pca.py +9 -1
  162. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  163. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  164. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  165. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  166. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  167. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  168. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  169. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  170. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  171. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  172. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  173. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  174. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  175. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  176. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  177. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  178. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  179. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  180. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  181. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  182. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  183. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  184. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  185. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  186. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  187. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  188. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  189. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  190. snowflake/ml/modeling/framework/_utils.py +10 -10
  191. snowflake/ml/modeling/framework/base.py +32 -32
  192. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  193. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  194. snowflake/ml/modeling/impute/__init__.py +1 -1
  195. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  196. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  197. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  198. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  199. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  200. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  201. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  202. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  203. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  204. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  205. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  206. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  207. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  208. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  209. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  210. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  211. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  212. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  213. snowflake/ml/modeling/linear_model/lars.py +9 -1
  214. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  215. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  216. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  217. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  218. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  219. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  220. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  221. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  222. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  223. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  224. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  225. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  226. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  227. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  228. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  229. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  230. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  231. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  232. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  233. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  234. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  235. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  236. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  237. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  238. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  239. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  240. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  241. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  242. snowflake/ml/modeling/manifold/isomap.py +9 -1
  243. snowflake/ml/modeling/manifold/mds.py +9 -1
  244. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  245. snowflake/ml/modeling/manifold/tsne.py +9 -1
  246. snowflake/ml/modeling/metrics/__init__.py +1 -1
  247. snowflake/ml/modeling/metrics/classification.py +39 -39
  248. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  249. snowflake/ml/modeling/metrics/ranking.py +7 -7
  250. snowflake/ml/modeling/metrics/regression.py +13 -13
  251. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  252. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  253. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  254. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  255. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  256. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  257. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  258. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  259. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  260. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  261. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  262. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  263. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  264. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  265. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  266. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  267. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  268. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  269. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  270. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  271. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  272. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  273. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  274. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  275. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  276. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  277. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  278. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  279. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  280. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  281. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  282. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  283. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  284. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  285. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  286. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  287. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  288. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  289. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  290. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  291. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  292. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  293. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  294. snowflake/ml/modeling/svm/svc.py +9 -1
  295. snowflake/ml/modeling/svm/svr.py +9 -1
  296. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  297. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  298. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  299. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  300. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  301. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  302. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  303. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  304. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  305. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  306. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  307. snowflake/ml/monitoring/explain_visualize.py +286 -0
  308. snowflake/ml/registry/_manager/model_manager.py +55 -32
  309. snowflake/ml/registry/registry.py +39 -31
  310. snowflake/ml/utils/authentication.py +2 -2
  311. snowflake/ml/utils/connection_params.py +5 -5
  312. snowflake/ml/utils/sparse.py +5 -4
  313. snowflake/ml/utils/sql_client.py +1 -2
  314. snowflake/ml/version.py +2 -1
  315. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
  316. snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
  317. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  318. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  319. snowflake/ml/modeling/_internal/constants.py +0 -2
  320. snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
  321. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  322. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  import collections
2
2
  import pathlib
3
- from typing import List, Optional, TypedDict, Union
3
+ from typing import Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired
6
6
 
@@ -137,8 +137,8 @@ class ModelMethod:
137
137
  )
138
138
 
139
139
  outputs: Union[
140
- List[model_manifest_schema.ModelMethodSignatureField],
141
- List[model_manifest_schema.ModelMethodSignatureFieldWithName],
140
+ list[model_manifest_schema.ModelMethodSignatureField],
141
+ list[model_manifest_schema.ModelMethodSignatureFieldWithName],
142
142
  ]
143
143
  if self.function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
144
144
  outputs = [
@@ -3,10 +3,11 @@ import itertools
3
3
  import os
4
4
  import pathlib
5
5
  import warnings
6
- from typing import DefaultDict, Dict, List, Optional
6
+ from typing import DefaultDict, Optional
7
7
 
8
8
  from packaging import requirements, version
9
9
 
10
+ from snowflake.ml import version as snowml_version
10
11
  from snowflake.ml._internal import env as snowml_env, env_utils
11
12
  from snowflake.ml.model._packager.model_meta import model_meta_schema
12
13
 
@@ -19,9 +20,8 @@ _DEFAULT_CONDA_ENV_FILENAME = "conda.yml"
19
20
  _DEFAULT_PIP_REQUIREMENTS_FILENAME = "requirements.txt"
20
21
 
21
22
  # The default CUDA version is chosen based on the driver availability in SPCS.
22
- # If changing this version, we need also change the version of default PyTorch in HuggingFace pipeline handler to
23
- # make sure they are compatible.
24
- DEFAULT_CUDA_VERSION = "11.8"
23
+ # Make sure they are aligned with default CUDA version in inference server.
24
+ DEFAULT_CUDA_VERSION = "12.4"
25
25
 
26
26
 
27
27
  class ModelEnv:
@@ -38,15 +38,16 @@ class ModelEnv:
38
38
  self.prefer_pip: bool = prefer_pip
39
39
  self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
40
40
  self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
41
- self.artifact_repository_map: Optional[Dict[str, str]] = None
42
- self._conda_dependencies: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
43
- self._pip_requirements: List[requirements.Requirement] = []
41
+ self.artifact_repository_map: Optional[dict[str, str]] = None
42
+ self.resource_constraint: Optional[dict[str, str]] = None
43
+ self._conda_dependencies: DefaultDict[str, list[requirements.Requirement]] = collections.defaultdict(list)
44
+ self._pip_requirements: list[requirements.Requirement] = []
44
45
  self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
45
46
  self._cuda_version: Optional[version.Version] = None
46
- self._snowpark_ml_version: version.Version = version.parse(snowml_env.VERSION)
47
+ self._snowpark_ml_version: version.Version = version.parse(snowml_version.VERSION)
47
48
 
48
49
  @property
49
- def conda_dependencies(self) -> List[str]:
50
+ def conda_dependencies(self) -> list[str]:
50
51
  """List of conda channel and dependencies from that to run the model"""
51
52
  return sorted(
52
53
  f"{chan}::{str(req)}" if chan else str(req)
@@ -57,24 +58,24 @@ class ModelEnv:
57
58
  @conda_dependencies.setter
58
59
  def conda_dependencies(
59
60
  self,
60
- conda_dependencies: Optional[List[str]] = None,
61
+ conda_dependencies: Optional[list[str]] = None,
61
62
  ) -> None:
62
63
  self._conda_dependencies = env_utils.validate_conda_dependency_string_list(
63
- conda_dependencies if conda_dependencies else []
64
+ conda_dependencies if conda_dependencies else [], add_local_version_specifier=True
64
65
  )
65
66
 
66
67
  @property
67
- def pip_requirements(self) -> List[str]:
68
+ def pip_requirements(self) -> list[str]:
68
69
  """List of pip Python packages requirements for running the model."""
69
70
  return sorted(list(map(str, self._pip_requirements)))
70
71
 
71
72
  @pip_requirements.setter
72
73
  def pip_requirements(
73
74
  self,
74
- pip_requirements: Optional[List[str]] = None,
75
+ pip_requirements: Optional[list[str]] = None,
75
76
  ) -> None:
76
77
  self._pip_requirements = env_utils.validate_pip_requirement_string_list(
77
- pip_requirements if pip_requirements else []
78
+ pip_requirements if pip_requirements else [], add_local_version_specifier=True
78
79
  )
79
80
 
80
81
  @property
@@ -117,7 +118,7 @@ class ModelEnv:
117
118
 
118
119
  def include_if_absent(
119
120
  self,
120
- pkgs: List[ModelDependency],
121
+ pkgs: list[ModelDependency],
121
122
  check_local_version: bool = False,
122
123
  ) -> None:
123
124
  """Append requirements into model env if absent. Depending on the environment, requirements may be added
@@ -128,7 +129,7 @@ class ModelEnv:
128
129
  check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
129
130
  """
130
131
  if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
131
- pip_pkg_reqs: List[str] = []
132
+ pip_pkg_reqs: list[str] = []
132
133
  warnings.warn(
133
134
  (
134
135
  "Dependencies specified from pip requirements."
@@ -145,7 +146,7 @@ class ModelEnv:
145
146
  else:
146
147
  self._include_if_absent_conda(pkgs, check_local_version)
147
148
 
148
- def _include_if_absent_conda(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
149
+ def _include_if_absent_conda(self, pkgs: list[ModelDependency], check_local_version: bool = False) -> None:
149
150
  """Append requirements into model env conda dependencies if absent.
150
151
 
151
152
  Args:
@@ -190,7 +191,7 @@ class ModelEnv:
190
191
  stacklevel=2,
191
192
  )
192
193
 
193
- def _include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
194
+ def _include_if_absent_pip(self, pkgs: list[str], check_local_version: bool = False) -> None:
194
195
  """Append pip requirements into model env pip requirements if absent.
195
196
 
196
197
  Args:
@@ -207,7 +208,7 @@ class ModelEnv:
207
208
  except env_utils.DuplicateDependencyError:
208
209
  pass
209
210
 
210
- def remove_if_present_conda(self, conda_pkgs: List[str]) -> None:
211
+ def remove_if_present_conda(self, conda_pkgs: list[str]) -> None:
211
212
  """Remove conda requirements from model env if present.
212
213
 
213
214
  Args:
@@ -352,13 +353,14 @@ class ModelEnv:
352
353
  def load_from_dict(self, base_dir: pathlib.Path, env_dict: model_meta_schema.ModelEnvDict) -> None:
353
354
  self.conda_env_rel_path = pathlib.PurePosixPath(env_dict["conda"])
354
355
  self.pip_requirements_rel_path = pathlib.PurePosixPath(env_dict["pip"])
355
- self.artifact_repository_map = env_dict.get("artifact_repository_map", None)
356
+ self.artifact_repository_map = env_dict.get("artifact_repository_map")
357
+ self.resource_constraint = env_dict.get("resource_constraint")
356
358
 
357
359
  self.load_from_conda_file(base_dir / self.conda_env_rel_path)
358
360
  self.load_from_pip_file(base_dir / self.pip_requirements_rel_path)
359
361
 
360
362
  self.python_version = env_dict["python_version"]
361
- self.cuda_version = env_dict.get("cuda_version", None)
363
+ self.cuda_version = env_dict.get("cuda_version")
362
364
  self.snowpark_ml_version = env_dict["snowpark_ml_version"]
363
365
 
364
366
  def save_as_dict(
@@ -381,7 +383,8 @@ class ModelEnv:
381
383
  return {
382
384
  "conda": self.conda_env_rel_path.as_posix(),
383
385
  "pip": self.pip_requirements_rel_path.as_posix(),
384
- "artifact_repository_map": self.artifact_repository_map if self.artifact_repository_map is not None else {},
386
+ "artifact_repository_map": self.artifact_repository_map or {},
387
+ "resource_constraint": self.resource_constraint or {},
385
388
  "python_version": self.python_version,
386
389
  "cuda_version": self.cuda_version,
387
390
  "snowpark_ml_version": self.snowpark_ml_version,
@@ -389,7 +392,7 @@ class ModelEnv:
389
392
 
390
393
  def validate_with_local_env(
391
394
  self, check_snowpark_ml_version: bool = False
392
- ) -> List[env_utils.IncorrectLocalEnvironmentError]:
395
+ ) -> list[env_utils.IncorrectLocalEnvironmentError]:
393
396
  errors = []
394
397
  try:
395
398
  env_utils.validate_py_runtime_version(str(self._python_version))
@@ -413,10 +416,10 @@ class ModelEnv:
413
416
 
414
417
  if check_snowpark_ml_version:
415
418
  # For Modeling model
416
- if self._snowpark_ml_version.base_version != snowml_env.VERSION:
419
+ if self._snowpark_ml_version.base_version != snowml_version.VERSION:
417
420
  errors.append(
418
421
  env_utils.IncorrectLocalEnvironmentError(
419
- f"The local installed version of Snowpark ML library is {snowml_env.VERSION} "
422
+ f"The local installed version of Snowpark ML library is {snowml_version.VERSION} "
420
423
  f"which differs from required version {self.snowpark_ml_version}."
421
424
  )
422
425
  )
@@ -2,13 +2,13 @@ import functools
2
2
  import importlib
3
3
  import pkgutil
4
4
  from types import ModuleType
5
- from typing import Any, Callable, Dict, Optional, Type, TypeVar, cast
5
+ from typing import Any, Callable, Optional, TypeVar, cast
6
6
 
7
7
  from snowflake.ml.model import type_hints as model_types
8
8
  from snowflake.ml.model._packager.model_handlers import _base
9
9
 
10
10
  _HANDLERS_BASE = "snowflake.ml.model._packager.model_handlers"
11
- _MODEL_HANDLER_REGISTRY: Dict[str, Type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
11
+ _MODEL_HANDLER_REGISTRY: dict[str, type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
12
12
  _IS_HANDLER_LOADED = False
13
13
 
14
14
 
@@ -54,7 +54,7 @@ def ensure_handlers_registration(fn: F) -> F:
54
54
  @ensure_handlers_registration
55
55
  def find_handler(
56
56
  model: model_types.SupportedModelType,
57
- ) -> Optional[Type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
57
+ ) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
58
58
  for handler in _MODEL_HANDLER_REGISTRY.values():
59
59
  if handler.can_handle(model):
60
60
  return handler
@@ -64,7 +64,7 @@ def find_handler(
64
64
  @ensure_handlers_registration
65
65
  def load_handler(
66
66
  target_model_type: model_types.SupportedModelHandlerType,
67
- ) -> Optional[Type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
67
+ ) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
68
68
  for model_type, handler in _MODEL_HANDLER_REGISTRY.items():
69
69
  if target_model_type == model_type:
70
70
  return handler
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  from abc import abstractmethod
3
- from typing import Dict, Generic, Optional, Protocol, Type, final
3
+ from typing import Generic, Optional, Protocol, final
4
4
 
5
5
  import pandas as pd
6
6
  from typing_extensions import TypeGuard, Unpack
@@ -14,7 +14,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
14
14
  HANDLER_TYPE: model_types.SupportedModelHandlerType
15
15
  HANDLER_VERSION: str
16
16
  _MIN_SNOWPARK_ML_VERSION: str
17
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]]
17
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]]
18
18
 
19
19
  @classmethod
20
20
  @abstractmethod
@@ -1,8 +1,9 @@
1
+ import importlib
1
2
  import json
2
3
  import os
3
4
  import pathlib
4
5
  import warnings
5
- from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, cast
6
+ from typing import Any, Callable, Iterable, Optional, Sequence, cast
6
7
 
7
8
  import numpy as np
8
9
  import numpy.typing as npt
@@ -10,8 +11,10 @@ import pandas as pd
10
11
  from absl import logging
11
12
 
12
13
  import snowflake.snowpark.dataframe as sp_df
14
+ from snowflake.ml._internal import env
13
15
  from snowflake.ml._internal.utils import identifier
14
16
  from snowflake.ml.model import model_signature, type_hints as model_types
17
+ from snowflake.ml.model._packager.model_env import model_env
15
18
  from snowflake.ml.model._packager.model_meta import model_meta
16
19
  from snowflake.ml.model._signatures import (
17
20
  core,
@@ -106,6 +109,35 @@ def get_input_signature(
106
109
  return input_sig
107
110
 
108
111
 
112
+ def add_inferred_explain_method_signature(
113
+ model_meta: model_meta.ModelMetadata,
114
+ explain_method: str,
115
+ target_method: str,
116
+ background_data: model_types.SupportedDataType,
117
+ explain_fn: Callable[[model_types.SupportedLocalDataType], model_types.SupportedLocalDataType],
118
+ output_feature_names: Optional[Sequence[str]] = None,
119
+ ) -> model_meta.ModelMetadata:
120
+ inputs = get_input_signature(model_meta, target_method)
121
+ if output_feature_names is None: # If not provided, assume output feature names are the same as input feature names
122
+ output_feature_names = [spec.name for spec in inputs]
123
+
124
+ if model_meta.model_type == "snowml":
125
+ suffixed_output_names = [identifier.concat_names([name, "_explanation"]) for name in output_feature_names]
126
+ else:
127
+ suffixed_output_names = [f"{name}_explanation" for name in output_feature_names]
128
+
129
+ truncated_background_data = get_truncated_sample_data(background_data, 5)
130
+ sig = model_signature.infer_signature(
131
+ input_data=truncated_background_data,
132
+ output_data=explain_fn(truncated_background_data),
133
+ input_feature_names=[spec.name for spec in inputs],
134
+ output_feature_names=suffixed_output_names,
135
+ )
136
+
137
+ model_meta.signatures[explain_method] = sig
138
+ return model_meta
139
+
140
+
109
141
  def add_explain_method_signature(
110
142
  model_meta: model_meta.ModelMetadata,
111
143
  explain_method: str,
@@ -231,10 +263,11 @@ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task
231
263
 
232
264
 
233
265
  def get_explain_target_method(
234
- model_metadata: model_meta.ModelMetadata, target_methods_list: List[str]
266
+ model_metadata: model_meta.ModelMetadata, target_methods_list: list[str]
235
267
  ) -> Optional[str]:
236
- for method in model_metadata.signatures.keys():
237
- if method in target_methods_list:
268
+ """Returns the first target method that is found in the model metadata signatures."""
269
+ for method in target_methods_list:
270
+ if method in model_metadata.signatures.keys():
238
271
  return method
239
272
  return None
240
273
 
@@ -248,7 +281,7 @@ def save_transformers_config_with_auto_map(local_model_path: str) -> None:
248
281
  config_dict = json.load(f)
249
282
 
250
283
  # a. get repository and class_path from configs
251
- auto_map_configs = cast(Dict[str, str], config_dict.get("auto_map", {}))
284
+ auto_map_configs = cast(dict[str, str], config_dict.get("auto_map", {}))
252
285
  for config_name, config_value in auto_map_configs.items():
253
286
  repository, _, class_path = config_value.rpartition("--")
254
287
 
@@ -261,3 +294,12 @@ def save_transformers_config_with_auto_map(local_model_path: str) -> None:
261
294
 
262
295
  with open(f_path, "w") as f:
263
296
  json.dump(config_dict, f)
297
+
298
+
299
+ def get_default_cuda_version() -> str:
300
+ # Default to the env cuda version when running in ML runtime
301
+ if env.IN_ML_RUNTIME and importlib.util.find_spec("torch") is not None:
302
+ import torch
303
+
304
+ return torch.version.cuda or model_env.DEFAULT_CUDA_VERSION
305
+ return model_env.DEFAULT_CUDA_VERSION
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
@@ -30,7 +30,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
30
30
  HANDLER_TYPE = "catboost"
31
31
  HANDLER_VERSION = "2024-03-21"
32
32
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
33
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
33
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
34
34
 
35
35
  MODEL_BLOB_FILE_OR_DIR = "model.bin"
36
36
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
@@ -147,7 +147,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
147
147
  if enable_explainability:
148
148
  model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
149
149
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
150
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
150
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
151
151
 
152
152
  return None
153
153
 
@@ -202,7 +202,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
202
202
  def _create_custom_model(
203
203
  raw_model: "catboost.CatBoost",
204
204
  model_meta: model_meta_api.ModelMetadata,
205
- ) -> Type[custom_model.CustomModel]:
205
+ ) -> type[custom_model.CustomModel]:
206
206
  def fn_factory(
207
207
  raw_model: "catboost.CatBoost",
208
208
  signature: model_signature.ModelSignature,
@@ -235,7 +235,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
235
235
 
236
236
  return fn
237
237
 
238
- type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
238
+ type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
239
239
  for target_method_name, sig in model_meta.signatures.items():
240
240
  type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
241
241
 
@@ -2,7 +2,7 @@ import inspect
2
2
  import os
3
3
  import pathlib
4
4
  import sys
5
- from typing import Dict, Optional, Type, cast, final
5
+ from typing import Optional, cast, final
6
6
 
7
7
  import anyio
8
8
  import cloudpickle
@@ -28,7 +28,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
28
28
  HANDLER_TYPE = "custom"
29
29
  HANDLER_VERSION = "2023-12-01"
30
30
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
31
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
31
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
32
32
 
33
33
  @classmethod
34
34
  def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["custom_model.CustomModel"]:
@@ -72,7 +72,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
72
72
  predictions_df = target_method(model, sample_input_data)
73
73
  return predictions_df
74
74
 
75
- for func_name in model._get_partitioned_infer_methods():
75
+ for func_name in model._get_partitioned_methods():
76
76
  function_properties = model_meta.function_properties.get(func_name, {})
77
77
  function_properties[model_meta_schema.FunctionProperties.PARTITIONED.value] = True
78
78
  model_meta.function_properties[func_name] = function_properties
@@ -99,7 +99,11 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
99
99
  for sub_name, model_ref in model.context.model_refs.items():
100
100
  handler = model_handler.find_handler(model_ref.model)
101
101
  if handler is None:
102
- raise TypeError("Your input type to custom model is not currently supported")
102
+ raise TypeError(
103
+ f"Model {sub_name} in model context is not a supported model type. See "
104
+ "https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/"
105
+ "bring-your-own-model-types for more details."
106
+ )
103
107
  sub_model = handler.cast_model(model_ref.model)
104
108
  handler.save_model(
105
109
  name=sub_name,
@@ -161,7 +165,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
161
165
  name: str(pathlib.PurePath(model_blob_path) / pathlib.PurePosixPath(rel_path))
162
166
  for name, rel_path in artifacts_meta.items()
163
167
  }
164
- models: Dict[str, model_types.SupportedModelType] = dict()
168
+ models: dict[str, model_types.SupportedModelType] = dict()
165
169
  for sub_model_name, _ref in context.model_refs.items():
166
170
  model_type = model_meta.models[sub_model_name].model_type
167
171
  handler = model_handler.load_handler(model_type)
@@ -1,18 +1,7 @@
1
1
  import json
2
2
  import os
3
3
  import warnings
4
- from typing import (
5
- TYPE_CHECKING,
6
- Any,
7
- Callable,
8
- Dict,
9
- List,
10
- Optional,
11
- Type,
12
- Union,
13
- cast,
14
- final,
15
- )
4
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
16
5
 
17
6
  import cloudpickle
18
7
  import numpy as np
@@ -38,7 +27,7 @@ if TYPE_CHECKING:
38
27
  import transformers
39
28
 
40
29
 
41
- def get_requirements_from_task(task: str, spcs_only: bool = False) -> List[model_env.ModelDependency]:
30
+ def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
42
31
  # Text
43
32
  if task in [
44
33
  "conversational",
@@ -84,7 +73,7 @@ class HuggingFacePipelineHandler(
84
73
  HANDLER_TYPE = "huggingface_pipeline"
85
74
  HANDLER_VERSION = "2023-12-01"
86
75
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
87
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
76
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
88
77
 
89
78
  MODEL_BLOB_FILE_OR_DIR = "model"
90
79
  ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
@@ -250,20 +239,17 @@ class HuggingFacePipelineHandler(
250
239
  task, spcs_only=(not type_utils.LazyType("transformers.Pipeline").isinstance(model))
251
240
  )
252
241
  if framework is None or framework == "pt":
253
- # Since we set default cuda version to be 11.8, to make sure it works with GPU, we need to have a default
254
- # Pytorch version that works with CUDA 11.8 as well. This is required for huggingface pipelines only as
255
- # users are not required to install pytorch locally if they are using the wrapper.
256
242
  pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
257
243
  elif framework == "tf":
258
244
  pkgs_requirements.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
259
245
  model_meta.env.include_if_absent(
260
246
  pkgs_requirements, check_local_version=(type_utils.LazyType("transformers.Pipeline").isinstance(model))
261
247
  )
262
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
248
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
263
249
 
264
250
  @staticmethod
265
- def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> Dict[str, str]:
266
- device_config: Dict[str, Any] = {}
251
+ def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> dict[str, str]:
252
+ device_config: dict[str, Any] = {}
267
253
  cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
268
254
  gpu_nums = 0
269
255
  if cuda_visible_devices is not None:
@@ -369,7 +355,7 @@ class HuggingFacePipelineHandler(
369
355
  def _create_custom_model(
370
356
  raw_model: "transformers.Pipeline",
371
357
  model_meta: model_meta_api.ModelMetadata,
372
- ) -> Type[custom_model.CustomModel]:
358
+ ) -> type[custom_model.CustomModel]:
373
359
  def fn_factory(
374
360
  raw_model: "transformers.Pipeline",
375
361
  signature: model_signature.ModelSignature,
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
2
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
3
3
 
4
4
  import cloudpickle
5
5
  import numpy as np
@@ -32,7 +32,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
32
32
  HANDLER_TYPE = "keras"
33
33
  HANDLER_VERSION = "2025-01-01"
34
34
  _MIN_SNOWPARK_ML_VERSION = "1.7.5"
35
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
35
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
36
36
 
37
37
  MODEL_BLOB_FILE_OR_DIR = "model.keras"
38
38
  CUSTOM_OBJECT_SAVE_PATH = "custom_objects.pkl"
@@ -146,7 +146,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
146
146
  dependencies,
147
147
  check_local_version=True,
148
148
  )
149
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
149
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
150
150
 
151
151
  @classmethod
152
152
  def load_model(
@@ -185,7 +185,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
185
185
  def _create_custom_model(
186
186
  raw_model: "keras.Model",
187
187
  model_meta: model_meta_api.ModelMetadata,
188
- ) -> Type[custom_model.CustomModel]:
188
+ ) -> type[custom_model.CustomModel]:
189
189
  def fn_factory(
190
190
  raw_model: "keras.Model",
191
191
  signature: model_signature.ModelSignature,
@@ -1,16 +1,6 @@
1
1
  import os
2
2
  import warnings
3
- from typing import (
4
- TYPE_CHECKING,
5
- Any,
6
- Callable,
7
- Dict,
8
- Optional,
9
- Type,
10
- Union,
11
- cast,
12
- final,
13
- )
3
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
14
4
 
15
5
  import cloudpickle
16
6
  import numpy as np
@@ -41,7 +31,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
41
31
  HANDLER_TYPE = "lightgbm"
42
32
  HANDLER_VERSION = "2024-03-19"
43
33
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
44
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
34
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
45
35
 
46
36
  MODEL_BLOB_FILE_OR_DIR = "model.pkl"
47
37
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
@@ -215,7 +205,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
215
205
  def _create_custom_model(
216
206
  raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
217
207
  model_meta: model_meta_api.ModelMetadata,
218
- ) -> Type[custom_model.CustomModel]:
208
+ ) -> type[custom_model.CustomModel]:
219
209
  def fn_factory(
220
210
  raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
221
211
  signature: model_signature.ModelSignature,
@@ -250,7 +240,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
250
240
 
251
241
  return fn
252
242
 
253
- type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
243
+ type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
254
244
  for target_method_name, sig in model_meta.signatures.items():
255
245
  type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
256
246
 
@@ -1,7 +1,7 @@
1
1
  import os
2
2
  import pathlib
3
3
  import tempfile
4
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
4
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
5
5
 
6
6
  import pandas as pd
7
7
  from typing_extensions import TypeGuard, Unpack
@@ -61,7 +61,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
61
61
  HANDLER_TYPE = "mlflow"
62
62
  HANDLER_VERSION = "2023-12-01"
63
63
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
64
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
64
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
65
65
 
66
66
  MODEL_BLOB_FILE_OR_DIR = "model"
67
67
  _DEFAULT_TARGET_METHOD = "predict"
@@ -204,7 +204,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
204
204
  def _create_custom_model(
205
205
  raw_model: "mlflow.pyfunc.PyFuncModel",
206
206
  model_meta: model_meta_api.ModelMetadata,
207
- ) -> Type[custom_model.CustomModel]:
207
+ ) -> type[custom_model.CustomModel]:
208
208
  def fn_factory(
209
209
  raw_model: "mlflow.pyfunc.PyFuncModel",
210
210
  signature: model_signature.ModelSignature,
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import sys
3
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import pandas as pd
@@ -38,7 +38,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
38
38
  HANDLER_TYPE = "pytorch"
39
39
  HANDLER_VERSION = "2025-03-01"
40
40
  _MIN_SNOWPARK_ML_VERSION = "1.8.0"
41
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
41
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
42
42
  "2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
43
43
  }
44
44
 
@@ -82,6 +82,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
82
82
  enable_explainability = kwargs.get("enable_explainability", False)
83
83
  if enable_explainability:
84
84
  raise NotImplementedError("Explainability is not supported for PyTorch model.")
85
+ multiple_inputs = kwargs.get("multiple_inputs", False)
85
86
 
86
87
  import torch
87
88
 
@@ -94,8 +95,6 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
94
95
  default_target_methods=cls.DEFAULT_TARGET_METHODS,
95
96
  )
96
97
 
97
- multiple_inputs = kwargs.get("multiple_inputs", False)
98
-
99
98
  def get_prediction(
100
99
  target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
101
100
  ) -> model_types.SupportedLocalDataType:
@@ -151,7 +150,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
151
150
  model_meta.env.include_if_absent(
152
151
  [model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
153
152
  )
154
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
153
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
155
154
 
156
155
  @classmethod
157
156
  def load_model(
@@ -188,7 +187,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
188
187
  def _create_custom_model(
189
188
  raw_model: "torch.nn.Module",
190
189
  model_meta: model_meta_api.ModelMetadata,
191
- ) -> Type[custom_model.CustomModel]:
190
+ ) -> type[custom_model.CustomModel]:
192
191
  multiple_inputs = cast(
193
192
  model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
194
193
  )["multiple_inputs"]