snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/cortex/_classify_text.py +3 -3
  3. snowflake/cortex/_complete.py +23 -24
  4. snowflake/cortex/_embed_text_1024.py +4 -4
  5. snowflake/cortex/_embed_text_768.py +4 -4
  6. snowflake/cortex/_finetune.py +8 -8
  7. snowflake/cortex/_util.py +8 -12
  8. snowflake/ml/_internal/env.py +4 -3
  9. snowflake/ml/_internal/env_utils.py +63 -34
  10. snowflake/ml/_internal/file_utils.py +10 -21
  11. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  12. snowflake/ml/_internal/init_utils.py +2 -3
  13. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  14. snowflake/ml/_internal/platform_capabilities.py +18 -16
  15. snowflake/ml/_internal/telemetry.py +39 -52
  16. snowflake/ml/_internal/type_utils.py +3 -3
  17. snowflake/ml/_internal/utils/db_utils.py +2 -2
  18. snowflake/ml/_internal/utils/identifier.py +10 -10
  19. snowflake/ml/_internal/utils/import_utils.py +2 -2
  20. snowflake/ml/_internal/utils/parallelize.py +7 -7
  21. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  22. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  23. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  24. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  25. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  26. snowflake/ml/_internal/utils/table_manager.py +9 -9
  27. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  28. snowflake/ml/data/data_connector.py +15 -36
  29. snowflake/ml/data/data_ingestor.py +4 -15
  30. snowflake/ml/data/data_source.py +2 -2
  31. snowflake/ml/data/ingestor_utils.py +3 -3
  32. snowflake/ml/data/torch_utils.py +5 -5
  33. snowflake/ml/dataset/dataset.py +11 -11
  34. snowflake/ml/dataset/dataset_metadata.py +8 -8
  35. snowflake/ml/dataset/dataset_reader.py +7 -7
  36. snowflake/ml/feature_store/__init__.py +1 -1
  37. snowflake/ml/feature_store/access_manager.py +7 -7
  38. snowflake/ml/feature_store/entity.py +6 -6
  39. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  41. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  45. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  48. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  51. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  52. snowflake/ml/feature_store/feature_store.py +52 -64
  53. snowflake/ml/feature_store/feature_view.py +24 -24
  54. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  55. snowflake/ml/fileset/fileset.py +5 -5
  56. snowflake/ml/fileset/sfcfs.py +13 -13
  57. snowflake/ml/fileset/stage_fs.py +15 -15
  58. snowflake/ml/jobs/_utils/constants.py +1 -1
  59. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  60. snowflake/ml/jobs/_utils/payload_utils.py +45 -46
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  64. snowflake/ml/jobs/_utils/spec_utils.py +18 -29
  65. snowflake/ml/jobs/_utils/types.py +2 -2
  66. snowflake/ml/jobs/decorators.py +10 -5
  67. snowflake/ml/jobs/job.py +87 -30
  68. snowflake/ml/jobs/manager.py +86 -56
  69. snowflake/ml/lineage/lineage_node.py +5 -5
  70. snowflake/ml/model/_client/model/model_impl.py +3 -3
  71. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  72. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  73. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  74. snowflake/ml/model/_client/ops/service_ops.py +217 -32
  75. snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
  76. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
  77. snowflake/ml/model/_client/sql/model.py +8 -8
  78. snowflake/ml/model/_client/sql/model_version.py +26 -26
  79. snowflake/ml/model/_client/sql/service.py +17 -26
  80. snowflake/ml/model/_client/sql/stage.py +2 -2
  81. snowflake/ml/model/_client/sql/tag.py +6 -6
  82. snowflake/ml/model/_model_composer/model_composer.py +58 -32
  83. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  85. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  86. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  87. snowflake/ml/model/_packager/model_handler.py +4 -4
  88. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  89. snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
  90. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  91. snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
  92. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  93. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  95. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  96. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
  97. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  98. snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
  99. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  100. snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
  101. snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
  102. snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
  103. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  104. snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
  105. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
  106. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  107. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  109. snowflake/ml/model/_packager/model_packager.py +11 -9
  110. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  111. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  112. snowflake/ml/model/_signatures/core.py +16 -24
  113. snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
  114. snowflake/ml/model/_signatures/utils.py +6 -6
  115. snowflake/ml/model/custom_model.py +24 -11
  116. snowflake/ml/model/model_signature.py +12 -23
  117. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  118. snowflake/ml/model/type_hints.py +3 -3
  119. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  120. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  122. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  123. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  124. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  125. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  126. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  128. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  129. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  130. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  131. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  132. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  133. snowflake/ml/modeling/cluster/birch.py +9 -1
  134. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  135. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  136. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  137. snowflake/ml/modeling/cluster/k_means.py +9 -1
  138. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  139. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  140. snowflake/ml/modeling/cluster/optics.py +9 -1
  141. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  142. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  143. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  144. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  145. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  146. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  147. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  148. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  149. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  150. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  151. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  152. snowflake/ml/modeling/covariance/oas.py +9 -1
  153. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  154. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  155. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  156. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  157. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  158. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  159. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  160. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  161. snowflake/ml/modeling/decomposition/pca.py +9 -1
  162. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  163. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  164. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  165. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  166. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  167. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  168. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  169. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  170. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  171. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  172. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  173. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  174. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  175. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  176. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  177. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  178. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  179. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  180. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  181. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  182. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  183. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  184. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  185. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  186. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  187. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  188. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  189. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  190. snowflake/ml/modeling/framework/_utils.py +10 -10
  191. snowflake/ml/modeling/framework/base.py +32 -32
  192. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  193. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  194. snowflake/ml/modeling/impute/__init__.py +1 -1
  195. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  196. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  197. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  198. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  199. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  200. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  201. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  202. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  203. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  204. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  205. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  206. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  207. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  208. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  209. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  210. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  211. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  212. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  213. snowflake/ml/modeling/linear_model/lars.py +9 -1
  214. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  215. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  216. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  217. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  218. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  219. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  220. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  221. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  222. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  223. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  224. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  225. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  226. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  227. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  228. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  229. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  230. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  231. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  232. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  233. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  234. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  235. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  236. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  237. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  238. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  239. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  240. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  241. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  242. snowflake/ml/modeling/manifold/isomap.py +9 -1
  243. snowflake/ml/modeling/manifold/mds.py +9 -1
  244. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  245. snowflake/ml/modeling/manifold/tsne.py +9 -1
  246. snowflake/ml/modeling/metrics/__init__.py +1 -1
  247. snowflake/ml/modeling/metrics/classification.py +39 -39
  248. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  249. snowflake/ml/modeling/metrics/ranking.py +7 -7
  250. snowflake/ml/modeling/metrics/regression.py +13 -13
  251. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  252. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  253. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  254. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  255. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  256. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  257. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  258. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  259. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  260. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  261. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  262. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  263. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  264. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  265. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  266. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  267. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  268. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  269. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  270. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  271. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  272. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  273. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  274. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  275. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  276. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  277. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  278. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  279. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  280. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  281. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  282. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  283. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  284. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  285. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  286. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  287. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  288. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  289. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  290. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  291. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  292. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  293. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  294. snowflake/ml/modeling/svm/svc.py +9 -1
  295. snowflake/ml/modeling/svm/svr.py +9 -1
  296. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  297. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  298. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  299. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  300. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  301. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  302. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  303. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  304. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  305. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  306. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  307. snowflake/ml/monitoring/explain_visualize.py +286 -0
  308. snowflake/ml/registry/_manager/model_manager.py +55 -32
  309. snowflake/ml/registry/registry.py +39 -31
  310. snowflake/ml/utils/authentication.py +2 -2
  311. snowflake/ml/utils/connection_params.py +5 -5
  312. snowflake/ml/utils/sparse.py +5 -4
  313. snowflake/ml/utils/sql_client.py +1 -2
  314. snowflake/ml/version.py +2 -1
  315. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
  316. snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
  317. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  318. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  319. snowflake/ml/modeling/_internal/constants.py +0 -2
  320. snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
  321. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  322. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  import json
2
2
  import pathlib
3
3
  import textwrap
4
- from typing import Any, Dict, List, Optional, Tuple
4
+ from typing import Any, Optional
5
5
  from urllib.parse import ParseResult
6
6
 
7
7
  from snowflake.ml._internal.utils import (
@@ -34,7 +34,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
34
34
  model_name: sql_identifier.SqlIdentifier,
35
35
  version_name: sql_identifier.SqlIdentifier,
36
36
  stage_path: str,
37
- statement_params: Optional[Dict[str, Any]] = None,
37
+ statement_params: Optional[dict[str, Any]] = None,
38
38
  ) -> None:
39
39
  query_result_checker.SqlResultValidator(
40
40
  self._session,
@@ -56,7 +56,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
56
56
  schema_name: Optional[sql_identifier.SqlIdentifier],
57
57
  model_name: sql_identifier.SqlIdentifier,
58
58
  version_name: sql_identifier.SqlIdentifier,
59
- statement_params: Optional[Dict[str, Any]] = None,
59
+ statement_params: Optional[dict[str, Any]] = None,
60
60
  ) -> None:
61
61
  fq_source_model_name = self.fully_qualified_object_name(
62
62
  source_database_name, source_schema_name, source_model_name
@@ -78,7 +78,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
78
78
  schema_name: Optional[sql_identifier.SqlIdentifier],
79
79
  model_name: sql_identifier.SqlIdentifier,
80
80
  version_name: sql_identifier.SqlIdentifier,
81
- statement_params: Optional[Dict[str, Any]] = None,
81
+ statement_params: Optional[dict[str, Any]] = None,
82
82
  ) -> None:
83
83
  sql = (
84
84
  f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
@@ -97,7 +97,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
97
97
  schema_name: Optional[sql_identifier.SqlIdentifier],
98
98
  model_name: sql_identifier.SqlIdentifier,
99
99
  version_name: sql_identifier.SqlIdentifier,
100
- statement_params: Optional[Dict[str, Any]] = None,
100
+ statement_params: Optional[dict[str, Any]] = None,
101
101
  ) -> None:
102
102
  sql = (
103
103
  f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
@@ -116,7 +116,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
116
116
  schema_name: Optional[sql_identifier.SqlIdentifier],
117
117
  model_name: sql_identifier.SqlIdentifier,
118
118
  version_name: sql_identifier.SqlIdentifier,
119
- statement_params: Optional[Dict[str, Any]] = None,
119
+ statement_params: Optional[dict[str, Any]] = None,
120
120
  ) -> None:
121
121
  sql = (
122
122
  f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
@@ -138,7 +138,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
138
138
  model_name: sql_identifier.SqlIdentifier,
139
139
  version_name: sql_identifier.SqlIdentifier,
140
140
  stage_path: str,
141
- statement_params: Optional[Dict[str, Any]] = None,
141
+ statement_params: Optional[dict[str, Any]] = None,
142
142
  ) -> None:
143
143
  query_result_checker.SqlResultValidator(
144
144
  self._session,
@@ -160,7 +160,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
160
160
  schema_name: Optional[sql_identifier.SqlIdentifier],
161
161
  model_name: sql_identifier.SqlIdentifier,
162
162
  version_name: sql_identifier.SqlIdentifier,
163
- statement_params: Optional[Dict[str, Any]] = None,
163
+ statement_params: Optional[dict[str, Any]] = None,
164
164
  ) -> None:
165
165
  fq_source_model_name = self.fully_qualified_object_name(
166
166
  source_database_name, source_schema_name, source_model_name
@@ -182,7 +182,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
182
182
  schema_name: Optional[sql_identifier.SqlIdentifier],
183
183
  model_name: sql_identifier.SqlIdentifier,
184
184
  version_name: sql_identifier.SqlIdentifier,
185
- statement_params: Optional[Dict[str, Any]] = None,
185
+ statement_params: Optional[dict[str, Any]] = None,
186
186
  ) -> None:
187
187
  query_result_checker.SqlResultValidator(
188
188
  self._session,
@@ -201,7 +201,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
201
201
  model_name: sql_identifier.SqlIdentifier,
202
202
  version_name: sql_identifier.SqlIdentifier,
203
203
  alias_name: sql_identifier.SqlIdentifier,
204
- statement_params: Optional[Dict[str, Any]] = None,
204
+ statement_params: Optional[dict[str, Any]] = None,
205
205
  ) -> None:
206
206
  query_result_checker.SqlResultValidator(
207
207
  self._session,
@@ -219,7 +219,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
219
219
  schema_name: Optional[sql_identifier.SqlIdentifier],
220
220
  model_name: sql_identifier.SqlIdentifier,
221
221
  version_or_alias_name: sql_identifier.SqlIdentifier,
222
- statement_params: Optional[Dict[str, Any]] = None,
222
+ statement_params: Optional[dict[str, Any]] = None,
223
223
  ) -> None:
224
224
  query_result_checker.SqlResultValidator(
225
225
  self._session,
@@ -239,8 +239,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
239
239
  version_name: sql_identifier.SqlIdentifier,
240
240
  file_path: pathlib.PurePosixPath,
241
241
  is_dir: bool = False,
242
- statement_params: Optional[Dict[str, Any]] = None,
243
- ) -> List[row.Row]:
242
+ statement_params: Optional[dict[str, Any]] = None,
243
+ ) -> list[row.Row]:
244
244
  # Workaround for snowURL bug.
245
245
  trailing_slash = "/" if is_dir else ""
246
246
 
@@ -276,7 +276,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
276
276
  version_name: sql_identifier.SqlIdentifier,
277
277
  file_path: pathlib.PurePosixPath,
278
278
  target_path: pathlib.Path,
279
- statement_params: Optional[Dict[str, Any]] = None,
279
+ statement_params: Optional[dict[str, Any]] = None,
280
280
  ) -> pathlib.Path:
281
281
  stage_location = pathlib.PurePosixPath(
282
282
  self.fully_qualified_object_name(database_name, schema_name, model_name),
@@ -310,8 +310,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
310
310
  schema_name: Optional[sql_identifier.SqlIdentifier],
311
311
  model_name: sql_identifier.SqlIdentifier,
312
312
  version_name: sql_identifier.SqlIdentifier,
313
- statement_params: Optional[Dict[str, Any]] = None,
314
- ) -> List[row.Row]:
313
+ statement_params: Optional[dict[str, Any]] = None,
314
+ ) -> list[row.Row]:
315
315
  res = query_result_checker.SqlResultValidator(
316
316
  self._session,
317
317
  (
@@ -331,7 +331,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
331
331
  model_name: sql_identifier.SqlIdentifier,
332
332
  version_name: sql_identifier.SqlIdentifier,
333
333
  comment: str,
334
- statement_params: Optional[Dict[str, Any]] = None,
334
+ statement_params: Optional[dict[str, Any]] = None,
335
335
  ) -> None:
336
336
  query_result_checker.SqlResultValidator(
337
337
  self._session,
@@ -351,9 +351,9 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
351
351
  version_name: sql_identifier.SqlIdentifier,
352
352
  method_name: sql_identifier.SqlIdentifier,
353
353
  input_df: dataframe.DataFrame,
354
- input_args: List[sql_identifier.SqlIdentifier],
355
- returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
356
- statement_params: Optional[Dict[str, Any]] = None,
354
+ input_args: list[sql_identifier.SqlIdentifier],
355
+ returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
356
+ statement_params: Optional[dict[str, Any]] = None,
357
357
  ) -> dataframe.DataFrame:
358
358
  with_statements = []
359
359
  if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -433,10 +433,10 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
433
433
  version_name: sql_identifier.SqlIdentifier,
434
434
  method_name: sql_identifier.SqlIdentifier,
435
435
  input_df: dataframe.DataFrame,
436
- input_args: List[sql_identifier.SqlIdentifier],
437
- returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
436
+ input_args: list[sql_identifier.SqlIdentifier],
437
+ returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
438
438
  partition_column: Optional[sql_identifier.SqlIdentifier],
439
- statement_params: Optional[Dict[str, Any]] = None,
439
+ statement_params: Optional[dict[str, Any]] = None,
440
440
  is_partitioned: bool = True,
441
441
  ) -> dataframe.DataFrame:
442
442
  with_statements = []
@@ -529,13 +529,13 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
529
529
 
530
530
  def set_metadata(
531
531
  self,
532
- metadata_dict: Dict[str, Any],
532
+ metadata_dict: dict[str, Any],
533
533
  *,
534
534
  database_name: Optional[sql_identifier.SqlIdentifier],
535
535
  schema_name: Optional[sql_identifier.SqlIdentifier],
536
536
  model_name: sql_identifier.SqlIdentifier,
537
537
  version_name: sql_identifier.SqlIdentifier,
538
- statement_params: Optional[Dict[str, Any]] = None,
538
+ statement_params: Optional[dict[str, Any]] = None,
539
539
  ) -> None:
540
540
  json_metadata = json.dumps(metadata_dict)
541
541
  query_result_checker.SqlResultValidator(
@@ -554,7 +554,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
554
554
  schema_name: Optional[sql_identifier.SqlIdentifier],
555
555
  model_name: sql_identifier.SqlIdentifier,
556
556
  version_name: sql_identifier.SqlIdentifier,
557
- statement_params: Optional[Dict[str, Any]] = None,
557
+ statement_params: Optional[dict[str, Any]] = None,
558
558
  ) -> None:
559
559
  query_result_checker.SqlResultValidator(
560
560
  self._session,
@@ -1,10 +1,9 @@
1
1
  import enum
2
2
  import json
3
3
  import textwrap
4
- from typing import Any, Dict, List, Optional, Tuple, Union
4
+ from typing import Any, Optional, Union
5
5
 
6
6
  from snowflake import snowpark
7
- from snowflake.ml._internal import platform_capabilities
8
7
  from snowflake.ml._internal.utils import (
9
8
  identifier,
10
9
  query_result_checker,
@@ -47,7 +46,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
47
46
  gpu: Optional[Union[str, int]],
48
47
  force_rebuild: bool,
49
48
  external_access_integration: sql_identifier.SqlIdentifier,
50
- statement_params: Optional[Dict[str, Any]] = None,
49
+ statement_params: Optional[dict[str, Any]] = None,
51
50
  ) -> None:
52
51
  actual_image_repo_database = image_repo_database_name or self._database_name
53
52
  actual_image_repo_schema = image_repo_schema_name or self._schema_name
@@ -76,8 +75,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
76
75
  stage_path: Optional[str] = None,
77
76
  model_deployment_spec_yaml_str: Optional[str] = None,
78
77
  model_deployment_spec_file_rel_path: Optional[str] = None,
79
- statement_params: Optional[Dict[str, Any]] = None,
80
- ) -> Tuple[str, snowpark.AsyncJob]:
78
+ statement_params: Optional[dict[str, Any]] = None,
79
+ ) -> tuple[str, snowpark.AsyncJob]:
81
80
  assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
82
81
  if model_deployment_spec_yaml_str:
83
82
  sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
@@ -95,9 +94,9 @@ class ServiceSQLClient(_base._BaseSQLClient):
95
94
  service_name: sql_identifier.SqlIdentifier,
96
95
  method_name: sql_identifier.SqlIdentifier,
97
96
  input_df: dataframe.DataFrame,
98
- input_args: List[sql_identifier.SqlIdentifier],
99
- returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
100
- statement_params: Optional[Dict[str, Any]] = None,
97
+ input_args: list[sql_identifier.SqlIdentifier],
98
+ returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
99
+ statement_params: Optional[dict[str, Any]] = None,
101
100
  ) -> dataframe.DataFrame:
102
101
  with_statements = []
103
102
  actual_database_name = database_name or self._database_name
@@ -133,18 +132,10 @@ class ServiceSQLClient(_base._BaseSQLClient):
133
132
  input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
134
133
  args_sql = f"object_construct_keep_null({input_args_sql})"
135
134
 
136
- if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
137
- fully_qualified_service_name = self.fully_qualified_object_name(
138
- actual_database_name, actual_schema_name, service_name
139
- )
140
- fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
141
- else:
142
- function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
143
- fully_qualified_function_name = identifier.get_schema_level_object_identifier(
144
- actual_database_name.identifier(),
145
- actual_schema_name.identifier(),
146
- function_name,
147
- )
135
+ fully_qualified_service_name = self.fully_qualified_object_name(
136
+ actual_database_name, actual_schema_name, service_name
137
+ )
138
+ fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
148
139
 
149
140
  sql = textwrap.dedent(
150
141
  f"""{with_sql}
@@ -181,7 +172,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
181
172
  service_name: sql_identifier.SqlIdentifier,
182
173
  instance_id: str = "0",
183
174
  container_name: str,
184
- statement_params: Optional[Dict[str, Any]] = None,
175
+ statement_params: Optional[dict[str, Any]] = None,
185
176
  ) -> str:
186
177
  system_func = "SYSTEM$GET_SERVICE_LOGS"
187
178
  rows = (
@@ -206,8 +197,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
206
197
  schema_name: Optional[sql_identifier.SqlIdentifier],
207
198
  service_name: sql_identifier.SqlIdentifier,
208
199
  include_message: bool = False,
209
- statement_params: Optional[Dict[str, Any]] = None,
210
- ) -> Tuple[ServiceStatus, Optional[str]]:
200
+ statement_params: Optional[dict[str, Any]] = None,
201
+ ) -> tuple[ServiceStatus, Optional[str]]:
211
202
  system_func = "SYSTEM$GET_SERVICE_STATUS"
212
203
  rows = (
213
204
  query_result_checker.SqlResultValidator(
@@ -231,7 +222,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
231
222
  database_name: Optional[sql_identifier.SqlIdentifier],
232
223
  schema_name: Optional[sql_identifier.SqlIdentifier],
233
224
  service_name: sql_identifier.SqlIdentifier,
234
- statement_params: Optional[Dict[str, Any]] = None,
225
+ statement_params: Optional[dict[str, Any]] = None,
235
226
  ) -> None:
236
227
  query_result_checker.SqlResultValidator(
237
228
  self._session,
@@ -245,8 +236,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
245
236
  database_name: Optional[sql_identifier.SqlIdentifier],
246
237
  schema_name: Optional[sql_identifier.SqlIdentifier],
247
238
  service_name: sql_identifier.SqlIdentifier,
248
- statement_params: Optional[Dict[str, Any]] = None,
249
- ) -> List[row.Row]:
239
+ statement_params: Optional[dict[str, Any]] = None,
240
+ ) -> list[row.Row]:
250
241
  fully_qualified_service_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
251
242
  res = (
252
243
  query_result_checker.SqlResultValidator(
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Optional
1
+ from typing import Any, Optional
2
2
 
3
3
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
4
  from snowflake.ml.model._client.sql import _base
@@ -11,7 +11,7 @@ class StageSQLClient(_base._BaseSQLClient):
11
11
  database_name: Optional[sql_identifier.SqlIdentifier],
12
12
  schema_name: Optional[sql_identifier.SqlIdentifier],
13
13
  stage_name: sql_identifier.SqlIdentifier,
14
- statement_params: Optional[Dict[str, Any]] = None,
14
+ statement_params: Optional[dict[str, Any]] = None,
15
15
  ) -> None:
16
16
  query_result_checker.SqlResultValidator(
17
17
  self._session,
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Optional
1
+ from typing import Any, Optional
2
2
 
3
3
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
4
  from snowflake.ml.model._client.sql import _base
@@ -16,7 +16,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
16
16
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
17
17
  tag_name: sql_identifier.SqlIdentifier,
18
18
  tag_value: str,
19
- statement_params: Optional[Dict[str, Any]] = None,
19
+ statement_params: Optional[dict[str, Any]] = None,
20
20
  ) -> None:
21
21
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
22
22
  fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
@@ -35,7 +35,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
35
35
  tag_database_name: Optional[sql_identifier.SqlIdentifier],
36
36
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
37
37
  tag_name: sql_identifier.SqlIdentifier,
38
- statement_params: Optional[Dict[str, Any]] = None,
38
+ statement_params: Optional[dict[str, Any]] = None,
39
39
  ) -> None:
40
40
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
41
41
  fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
@@ -54,7 +54,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
54
54
  tag_database_name: Optional[sql_identifier.SqlIdentifier],
55
55
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
56
56
  tag_name: sql_identifier.SqlIdentifier,
57
- statement_params: Optional[Dict[str, Any]] = None,
57
+ statement_params: Optional[dict[str, Any]] = None,
58
58
  ) -> row.Row:
59
59
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
60
60
  fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
@@ -75,8 +75,8 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
75
75
  database_name: Optional[sql_identifier.SqlIdentifier],
76
76
  schema_name: Optional[sql_identifier.SqlIdentifier],
77
77
  model_name: sql_identifier.SqlIdentifier,
78
- statement_params: Optional[Dict[str, Any]] = None,
79
- ) -> List[row.Row]:
78
+ statement_params: Optional[dict[str, Any]] = None,
79
+ ) -> list[row.Row]:
80
80
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
81
81
  actual_database_name = database_name or self._database_name
82
82
  return (
@@ -3,13 +3,14 @@ import tempfile
3
3
  import uuid
4
4
  import warnings
5
5
  from types import ModuleType
6
- from typing import Any, Dict, List, Optional, Union
6
+ from typing import Any, Optional, Union
7
7
  from urllib import parse
8
8
 
9
9
  from absl import logging
10
10
  from packaging import requirements
11
11
 
12
12
  from snowflake import snowpark
13
+ from snowflake.ml import version as snowml_version
13
14
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
14
15
  from snowflake.ml._internal.lineage import lineage_utils
15
16
  from snowflake.ml.data import data_source
@@ -43,7 +44,7 @@ class ModelComposer:
43
44
  session: Session,
44
45
  stage_path: str,
45
46
  *,
46
- statement_params: Optional[Dict[str, Any]] = None,
47
+ statement_params: Optional[dict[str, Any]] = None,
47
48
  save_location: Optional[str] = None,
48
49
  ) -> None:
49
50
  self.session = session
@@ -122,17 +123,18 @@ class ModelComposer:
122
123
  *,
123
124
  name: str,
124
125
  model: model_types.SupportedModelType,
125
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
126
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
126
127
  sample_input_data: Optional[model_types.SupportedDataType] = None,
127
- metadata: Optional[Dict[str, str]] = None,
128
- conda_dependencies: Optional[List[str]] = None,
129
- pip_requirements: Optional[List[str]] = None,
130
- artifact_repository_map: Optional[Dict[str, str]] = None,
131
- target_platforms: Optional[List[model_types.TargetPlatform]] = None,
128
+ metadata: Optional[dict[str, str]] = None,
129
+ conda_dependencies: Optional[list[str]] = None,
130
+ pip_requirements: Optional[list[str]] = None,
131
+ artifact_repository_map: Optional[dict[str, str]] = None,
132
+ resource_constraint: Optional[dict[str, str]] = None,
133
+ target_platforms: Optional[list[model_types.TargetPlatform]] = None,
132
134
  python_version: Optional[str] = None,
133
- user_files: Optional[Dict[str, List[str]]] = None,
134
- ext_modules: Optional[List[ModuleType]] = None,
135
- code_paths: Optional[List[str]] = None,
135
+ user_files: Optional[dict[str, list[str]]] = None,
136
+ ext_modules: Optional[list[ModuleType]] = None,
137
+ code_paths: Optional[list[str]] = None,
136
138
  task: model_types.Task = model_types.Task.UNKNOWN,
137
139
  options: Optional[model_types.ModelSaveOption] = None,
138
140
  ) -> model_meta.ModelMetadata:
@@ -140,40 +142,63 @@ class ModelComposer:
140
142
  conda_dep_dict = env_utils.validate_conda_dependency_string_list(
141
143
  conda_dependencies if conda_dependencies else []
142
144
  )
143
- is_warehouse_runnable = (
144
- not conda_dep_dict
145
- or all(
146
- chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
147
- for chan in conda_dep_dict
148
- )
149
- ) and (not pip_requirements)
150
- disable_explainability = (
151
- target_platforms and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
152
- ) or (not is_warehouse_runnable)
153
-
154
- if disable_explainability and options and options.get("enable_explainability", False):
155
- warnings.warn(
156
- ("The model can be deployed to Snowpark Container Services only if `enable_explainability=False`."),
157
- category=UserWarning,
158
- stacklevel=2,
145
+
146
+ enable_explainability = None
147
+
148
+ if options:
149
+ enable_explainability = options.get("enable_explainability", None)
150
+
151
+ # skip everything if user said False explicitly
152
+ if enable_explainability is None or enable_explainability is True:
153
+ is_warehouse_runnable = (
154
+ not conda_dep_dict
155
+ or all(
156
+ chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
157
+ for chan in conda_dep_dict
158
+ )
159
+ ) and (not pip_requirements)
160
+
161
+ only_spcs = (
162
+ target_platforms
163
+ and len(target_platforms) == 1
164
+ and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
159
165
  )
166
+ if only_spcs or (not is_warehouse_runnable):
167
+ # if only SPCS and user asked for explainability we fail
168
+ if enable_explainability is True:
169
+ raise ValueError(
170
+ "`enable_explainability` cannot be set to True when the model is not runnable in WH "
171
+ "or the target platforms include SPCS."
172
+ )
173
+ elif not options: # explicitly set flag to false in these cases if not specified
174
+ options = model_types.BaseModelSaveOption()
175
+ options["enable_explainability"] = False
176
+ elif (
177
+ target_platforms
178
+ and len(target_platforms) > 1
179
+ and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
180
+ ): # if both then only available for WH
181
+ if enable_explainability is True:
182
+ warnings.warn(
183
+ ("Explain function will only be available for model deployed to warehouse."),
184
+ category=UserWarning,
185
+ stacklevel=2,
186
+ )
160
187
 
161
188
  if not options:
162
189
  options = model_types.BaseModelSaveOption()
163
- if disable_explainability:
164
- options["enable_explainability"] = False
165
190
 
166
191
  if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
167
192
  snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
168
193
  self.session,
169
- reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_env.VERSION}")],
194
+ reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
170
195
  python_version=python_version or snowml_env.PYTHON_VERSION,
171
196
  statement_params=self._statement_params,
172
197
  ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
173
198
 
174
199
  if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
175
200
  logging.info(
176
- f"Local snowflake-ml-python library has version {snowml_env.VERSION},"
201
+ f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
177
202
  " which is not available in the Snowflake server, embedding local ML library automatically."
178
203
  )
179
204
  options["embed_local_ml_library"] = True
@@ -187,6 +212,7 @@ class ModelComposer:
187
212
  conda_dependencies=conda_dependencies,
188
213
  pip_requirements=pip_requirements,
189
214
  artifact_repository_map=artifact_repository_map,
215
+ resource_constraint=resource_constraint,
190
216
  target_platforms=target_platforms,
191
217
  python_version=python_version,
192
218
  ext_modules=ext_modules,
@@ -226,7 +252,7 @@ class ModelComposer:
226
252
 
227
253
  def _get_data_sources(
228
254
  self, model: model_types.SupportedModelType, sample_input_data: Optional[model_types.SupportedDataType] = None
229
- ) -> Optional[List[data_source.DataSource]]:
255
+ ) -> Optional[list[data_source.DataSource]]:
230
256
  data_sources = lineage_utils.get_data_sources(model)
231
257
  if not data_sources and sample_input_data is not None:
232
258
  data_sources = lineage_utils.get_data_sources(sample_input_data)
@@ -2,7 +2,7 @@ import collections
2
2
  import logging
3
3
  import pathlib
4
4
  import warnings
5
- from typing import Dict, List, Optional, cast
5
+ from typing import Optional, cast
6
6
 
7
7
  import yaml
8
8
 
@@ -45,10 +45,10 @@ class ModelManifest:
45
45
  self,
46
46
  model_meta: model_meta_api.ModelMetadata,
47
47
  model_rel_path: pathlib.PurePosixPath,
48
- user_files: Optional[Dict[str, List[str]]] = None,
48
+ user_files: Optional[dict[str, list[str]]] = None,
49
49
  options: Optional[type_hints.ModelSaveOption] = None,
50
- data_sources: Optional[List[data_source.DataSource]] = None,
51
- target_platforms: Optional[List[type_hints.TargetPlatform]] = None,
50
+ data_sources: Optional[list[data_source.DataSource]] = None,
51
+ target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
52
52
  ) -> None:
53
53
  if options is None:
54
54
  options = {}
@@ -78,12 +78,13 @@ class ModelManifest:
78
78
  logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
79
79
  logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
80
80
  logger.info(f"artifact_repository_map: {runtime_to_use.runtime_env.artifact_repository_map}")
81
+ logger.info(f"resource_constraint: {runtime_to_use.runtime_env.resource_constraint}")
81
82
  runtime_dict = runtime_to_use.save(
82
83
  self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
83
84
  )
84
85
 
85
86
  self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
86
- self.methods: List[model_method.ModelMethod] = []
87
+ self.methods: list[model_method.ModelMethod] = []
87
88
 
88
89
  for target_method in model_meta.signatures.keys():
89
90
  method = model_method.ModelMethod(
@@ -100,7 +101,7 @@ class ModelManifest:
100
101
 
101
102
  self.methods.append(method)
102
103
 
103
- self.user_files: List[model_user_file.ModelUserFile] = []
104
+ self.user_files: list[model_user_file.ModelUserFile] = []
104
105
 
105
106
  if user_files is not None:
106
107
  for subdirectory, paths in user_files.items():
@@ -127,16 +128,19 @@ class ModelManifest:
127
128
  if model_meta.env.artifact_repository_map:
128
129
  dependencies["artifact_repository_map"] = runtime_dict["dependencies"]["artifact_repository_map"]
129
130
 
131
+ runtime = model_manifest_schema.ModelRuntimeDict(
132
+ language="PYTHON",
133
+ version=runtime_to_use.runtime_env.python_version,
134
+ imports=runtime_dict["imports"],
135
+ dependencies=dependencies,
136
+ )
137
+
138
+ if runtime_dict["resource_constraint"]:
139
+ runtime["resource_constraint"] = runtime_dict["resource_constraint"]
140
+
130
141
  manifest_dict = model_manifest_schema.ModelManifestDict(
131
142
  manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
132
- runtimes={
133
- self._DEFAULT_RUNTIME_NAME: model_manifest_schema.ModelRuntimeDict(
134
- language="PYTHON",
135
- version=runtime_to_use.runtime_env.python_version,
136
- imports=runtime_dict["imports"],
137
- dependencies=dependencies,
138
- )
139
- },
143
+ runtimes={self._DEFAULT_RUNTIME_NAME: runtime},
140
144
  methods=[
141
145
  method.save(
142
146
  self.workspace_path,
@@ -178,8 +182,8 @@ class ModelManifest:
178
182
  return res
179
183
 
180
184
  def _extract_lineage_info(
181
- self, data_sources: Optional[List[data_source.DataSource]]
182
- ) -> List[model_manifest_schema.LineageSourceDict]:
185
+ self, data_sources: Optional[list[data_source.DataSource]]
186
+ ) -> list[model_manifest_schema.LineageSourceDict]:
183
187
  result = []
184
188
  if data_sources:
185
189
  for source in data_sources:
@@ -1,6 +1,6 @@
1
1
  # This files contains schema definition of what will be written into MANIFEST.yml
2
2
  import enum
3
- from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
3
+ from typing import Any, Literal, Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired, Required
6
6
 
@@ -20,14 +20,15 @@ class ModelMethodFunctionTypes(enum.Enum):
20
20
  class ModelRuntimeDependenciesDict(TypedDict):
21
21
  conda: NotRequired[str]
22
22
  pip: NotRequired[str]
23
- artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
23
+ artifact_repository_map: NotRequired[Optional[dict[str, str]]]
24
24
 
25
25
 
26
26
  class ModelRuntimeDict(TypedDict):
27
27
  language: Required[Literal["PYTHON"]]
28
28
  version: Required[str]
29
- imports: Required[List[str]]
29
+ imports: Required[list[str]]
30
30
  dependencies: Required[ModelRuntimeDependenciesDict]
31
+ resource_constraint: NotRequired[Optional[dict[str, str]]]
31
32
 
32
33
 
33
34
  class ModelMethodSignatureField(TypedDict):
@@ -43,8 +44,8 @@ class ModelFunctionMethodDict(TypedDict):
43
44
  runtime: Required[str]
44
45
  type: Required[str]
45
46
  handler: Required[str]
46
- inputs: Required[List[ModelMethodSignatureFieldWithName]]
47
- outputs: Required[Union[List[ModelMethodSignatureField], List[ModelMethodSignatureFieldWithName]]]
47
+ inputs: Required[list[ModelMethodSignatureFieldWithName]]
48
+ outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
48
49
 
49
50
 
50
51
  ModelMethodDict = ModelFunctionMethodDict
@@ -71,12 +72,12 @@ class ModelFunctionInfo(TypedDict):
71
72
  class ModelFunctionInfoDict(TypedDict):
72
73
  name: Required[str]
73
74
  target_method: Required[str]
74
- signature: Required[Dict[str, Any]]
75
+ signature: Required[dict[str, Any]]
75
76
 
76
77
 
77
78
  class SnowparkMLDataDict(TypedDict):
78
79
  schema_version: Required[str]
79
- functions: Required[List[ModelFunctionInfoDict]]
80
+ functions: Required[list[ModelFunctionInfoDict]]
80
81
 
81
82
 
82
83
  class LineageSourceTypes(enum.Enum):
@@ -92,9 +93,9 @@ class LineageSourceDict(TypedDict):
92
93
 
93
94
  class ModelManifestDict(TypedDict):
94
95
  manifest_version: Required[str]
95
- runtimes: Required[Dict[str, ModelRuntimeDict]]
96
- methods: Required[List[ModelMethodDict]]
97
- user_data: NotRequired[Dict[str, Any]]
98
- user_files: NotRequired[List[str]]
99
- lineage_sources: NotRequired[List[LineageSourceDict]]
100
- target_platforms: NotRequired[List[str]]
96
+ runtimes: Required[dict[str, ModelRuntimeDict]]
97
+ methods: Required[list[ModelMethodDict]]
98
+ user_data: NotRequired[dict[str, Any]]
99
+ user_files: NotRequired[list[str]]
100
+ lineage_sources: NotRequired[list[LineageSourceDict]]
101
+ target_platforms: NotRequired[list[str]]