snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/cortex/_classify_text.py +3 -3
  3. snowflake/cortex/_complete.py +23 -24
  4. snowflake/cortex/_embed_text_1024.py +4 -4
  5. snowflake/cortex/_embed_text_768.py +4 -4
  6. snowflake/cortex/_finetune.py +8 -8
  7. snowflake/cortex/_util.py +8 -12
  8. snowflake/ml/_internal/env.py +4 -3
  9. snowflake/ml/_internal/env_utils.py +63 -34
  10. snowflake/ml/_internal/file_utils.py +10 -21
  11. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  12. snowflake/ml/_internal/init_utils.py +2 -3
  13. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  14. snowflake/ml/_internal/platform_capabilities.py +18 -16
  15. snowflake/ml/_internal/telemetry.py +39 -52
  16. snowflake/ml/_internal/type_utils.py +3 -3
  17. snowflake/ml/_internal/utils/db_utils.py +2 -2
  18. snowflake/ml/_internal/utils/identifier.py +10 -10
  19. snowflake/ml/_internal/utils/import_utils.py +2 -2
  20. snowflake/ml/_internal/utils/parallelize.py +7 -7
  21. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  22. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  23. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  24. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  25. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  26. snowflake/ml/_internal/utils/table_manager.py +9 -9
  27. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  28. snowflake/ml/data/data_connector.py +15 -36
  29. snowflake/ml/data/data_ingestor.py +4 -15
  30. snowflake/ml/data/data_source.py +2 -2
  31. snowflake/ml/data/ingestor_utils.py +3 -3
  32. snowflake/ml/data/torch_utils.py +5 -5
  33. snowflake/ml/dataset/dataset.py +11 -11
  34. snowflake/ml/dataset/dataset_metadata.py +8 -8
  35. snowflake/ml/dataset/dataset_reader.py +7 -7
  36. snowflake/ml/feature_store/__init__.py +1 -1
  37. snowflake/ml/feature_store/access_manager.py +7 -7
  38. snowflake/ml/feature_store/entity.py +6 -6
  39. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  41. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  45. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  48. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  51. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  52. snowflake/ml/feature_store/feature_store.py +52 -64
  53. snowflake/ml/feature_store/feature_view.py +24 -24
  54. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  55. snowflake/ml/fileset/fileset.py +5 -5
  56. snowflake/ml/fileset/sfcfs.py +13 -13
  57. snowflake/ml/fileset/stage_fs.py +15 -15
  58. snowflake/ml/jobs/_utils/constants.py +1 -1
  59. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  60. snowflake/ml/jobs/_utils/payload_utils.py +45 -46
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  64. snowflake/ml/jobs/_utils/spec_utils.py +18 -29
  65. snowflake/ml/jobs/_utils/types.py +2 -2
  66. snowflake/ml/jobs/decorators.py +10 -5
  67. snowflake/ml/jobs/job.py +87 -30
  68. snowflake/ml/jobs/manager.py +86 -56
  69. snowflake/ml/lineage/lineage_node.py +5 -5
  70. snowflake/ml/model/_client/model/model_impl.py +3 -3
  71. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  72. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  73. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  74. snowflake/ml/model/_client/ops/service_ops.py +217 -32
  75. snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
  76. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
  77. snowflake/ml/model/_client/sql/model.py +8 -8
  78. snowflake/ml/model/_client/sql/model_version.py +26 -26
  79. snowflake/ml/model/_client/sql/service.py +17 -26
  80. snowflake/ml/model/_client/sql/stage.py +2 -2
  81. snowflake/ml/model/_client/sql/tag.py +6 -6
  82. snowflake/ml/model/_model_composer/model_composer.py +58 -32
  83. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  85. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  86. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  87. snowflake/ml/model/_packager/model_handler.py +4 -4
  88. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  89. snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
  90. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  91. snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
  92. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  93. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  95. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  96. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
  97. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  98. snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
  99. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  100. snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
  101. snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
  102. snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
  103. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  104. snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
  105. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
  106. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  107. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  109. snowflake/ml/model/_packager/model_packager.py +11 -9
  110. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  111. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  112. snowflake/ml/model/_signatures/core.py +16 -24
  113. snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
  114. snowflake/ml/model/_signatures/utils.py +6 -6
  115. snowflake/ml/model/custom_model.py +24 -11
  116. snowflake/ml/model/model_signature.py +12 -23
  117. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  118. snowflake/ml/model/type_hints.py +3 -3
  119. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  120. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  122. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  123. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  124. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  125. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  126. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  128. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  129. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  130. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  131. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  132. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  133. snowflake/ml/modeling/cluster/birch.py +9 -1
  134. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  135. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  136. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  137. snowflake/ml/modeling/cluster/k_means.py +9 -1
  138. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  139. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  140. snowflake/ml/modeling/cluster/optics.py +9 -1
  141. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  142. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  143. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  144. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  145. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  146. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  147. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  148. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  149. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  150. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  151. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  152. snowflake/ml/modeling/covariance/oas.py +9 -1
  153. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  154. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  155. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  156. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  157. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  158. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  159. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  160. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  161. snowflake/ml/modeling/decomposition/pca.py +9 -1
  162. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  163. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  164. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  165. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  166. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  167. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  168. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  169. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  170. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  171. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  172. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  173. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  174. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  175. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  176. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  177. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  178. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  179. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  180. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  181. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  182. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  183. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  184. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  185. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  186. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  187. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  188. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  189. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  190. snowflake/ml/modeling/framework/_utils.py +10 -10
  191. snowflake/ml/modeling/framework/base.py +32 -32
  192. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  193. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  194. snowflake/ml/modeling/impute/__init__.py +1 -1
  195. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  196. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  197. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  198. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  199. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  200. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  201. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  202. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  203. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  204. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  205. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  206. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  207. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  208. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  209. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  210. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  211. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  212. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  213. snowflake/ml/modeling/linear_model/lars.py +9 -1
  214. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  215. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  216. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  217. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  218. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  219. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  220. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  221. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  222. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  223. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  224. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  225. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  226. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  227. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  228. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  229. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  230. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  231. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  232. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  233. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  234. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  235. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  236. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  237. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  238. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  239. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  240. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  241. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  242. snowflake/ml/modeling/manifold/isomap.py +9 -1
  243. snowflake/ml/modeling/manifold/mds.py +9 -1
  244. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  245. snowflake/ml/modeling/manifold/tsne.py +9 -1
  246. snowflake/ml/modeling/metrics/__init__.py +1 -1
  247. snowflake/ml/modeling/metrics/classification.py +39 -39
  248. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  249. snowflake/ml/modeling/metrics/ranking.py +7 -7
  250. snowflake/ml/modeling/metrics/regression.py +13 -13
  251. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  252. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  253. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  254. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  255. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  256. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  257. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  258. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  259. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  260. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  261. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  262. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  263. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  264. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  265. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  266. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  267. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  268. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  269. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  270. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  271. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  272. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  273. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  274. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  275. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  276. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  277. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  278. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  279. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  280. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  281. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  282. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  283. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  284. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  285. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  286. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  287. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  288. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  289. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  290. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  291. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  292. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  293. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  294. snowflake/ml/modeling/svm/svc.py +9 -1
  295. snowflake/ml/modeling/svm/svr.py +9 -1
  296. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  297. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  298. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  299. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  300. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  301. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  302. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  303. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  304. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  305. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  306. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  307. snowflake/ml/monitoring/explain_visualize.py +286 -0
  308. snowflake/ml/registry/_manager/model_manager.py +55 -32
  309. snowflake/ml/registry/registry.py +39 -31
  310. snowflake/ml/utils/authentication.py +2 -2
  311. snowflake/ml/utils/connection_params.py +5 -5
  312. snowflake/ml/utils/sparse.py +5 -4
  313. snowflake/ml/utils/sql_client.py +1 -2
  314. snowflake/ml/version.py +2 -1
  315. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
  316. snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
  317. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  318. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  319. snowflake/ml/modeling/_internal/constants.py +0 -2
  320. snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
  321. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  322. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,7 @@ import os
5
5
  import posixpath
6
6
  import sys
7
7
  import uuid
8
- from typing import Any, Dict, List, Optional, Tuple, Union
8
+ from typing import Any, Optional, Union
9
9
 
10
10
  import cloudpickle as cp
11
11
  import numpy as np
@@ -50,11 +50,11 @@ _UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}
50
50
  def construct_cv_results(
51
51
  estimator: Union[GridSearchCV, RandomizedSearchCV],
52
52
  n_split: int,
53
- param_grid: List[Dict[str, Any]],
54
- cv_results_raw_hex: List[Row],
53
+ param_grid: list[dict[str, Any]],
54
+ cv_results_raw_hex: list[Row],
55
55
  cross_validator_indices_length: int,
56
56
  parameter_grid_length: int,
57
- ) -> Tuple[bool, Dict[str, Any]]:
57
+ ) -> tuple[bool, dict[str, Any]]:
58
58
  """Construct the cross validation result from the UDF. Because we accelerate the process
59
59
  by the number of cross validation number, and the combination of parameter grids.
60
60
  Therefore, we need to stick them back together instead of returning the raw result
@@ -158,11 +158,11 @@ def construct_cv_results(
158
158
  def construct_cv_results_memory_efficient_version(
159
159
  estimator: Union[GridSearchCV, RandomizedSearchCV],
160
160
  n_split: int,
161
- param_grid: List[Dict[str, Any]],
162
- cv_results_raw_hex: List[Row],
161
+ param_grid: list[dict[str, Any]],
162
+ cv_results_raw_hex: list[Row],
163
163
  cross_validator_indices_length: int,
164
164
  parameter_grid_length: int,
165
- ) -> Tuple[Any, Dict[str, Any]]:
165
+ ) -> tuple[Any, dict[str, Any]]:
166
166
  """Construct the cross validation result from the UDF.
167
167
  The output is a raw dictionary generated by _fit_and_score, encoded into hex binary.
168
168
  This function need to decode the string and then call _format_result to stick them back together
@@ -210,7 +210,7 @@ def construct_cv_results_memory_efficient_version(
210
210
  # because original SearchCV is ranked by parameter first and cv second,
211
211
  # to make the memory efficient, we implemented by fitting on cv first and parameter second
212
212
  # when retrieving the results back, the ordering should revert back to remain the same result as original SearchCV
213
- def generate_the_order_by_parameter_index(all_combination_length: int) -> List[int]:
213
+ def generate_the_order_by_parameter_index(all_combination_length: int) -> list[int]:
214
214
  pattern = []
215
215
  for i in range(all_combination_length):
216
216
  if i % parameter_grid_length == 0:
@@ -221,7 +221,7 @@ def construct_cv_results_memory_efficient_version(
221
221
  pattern.append(j)
222
222
  return pattern
223
223
 
224
- def rerank_array(original_array: List[Any], pattern: List[int]) -> List[Any]:
224
+ def rerank_array(original_array: list[Any], pattern: list[int]) -> list[Any]:
225
225
  reranked_array = []
226
226
  for index in pattern:
227
227
  reranked_array.append(original_array[index])
@@ -251,8 +251,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
251
251
  estimator: object,
252
252
  dataset: DataFrame,
253
253
  session: Session,
254
- input_cols: List[str],
255
- label_cols: Optional[List[str]],
254
+ input_cols: list[str],
255
+ label_cols: Optional[list[str]],
256
256
  sample_weight_col: Optional[str],
257
257
  autogenerated: bool = False,
258
258
  subproject: str = "",
@@ -289,10 +289,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
289
289
  dataset: DataFrame,
290
290
  session: Session,
291
291
  estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
292
- dependencies: List[str],
293
- udf_imports: List[str],
294
- input_cols: List[str],
295
- label_cols: Optional[List[str]],
292
+ dependencies: list[str],
293
+ udf_imports: list[str],
294
+ input_cols: list[str],
295
+ label_cols: Optional[list[str]],
296
296
  sample_weight_col: Optional[str],
297
297
  ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
298
298
  from itertools import product
@@ -382,10 +382,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
382
382
  )
383
383
  def _distributed_search(
384
384
  session: Session,
385
- imports: List[str],
385
+ imports: list[str],
386
386
  stage_estimator_file_name: str,
387
- input_cols: List[str],
388
- label_cols: Optional[List[str]],
387
+ input_cols: list[str],
388
+ label_cols: Optional[list[str]],
389
389
  ) -> str:
390
390
  import os
391
391
  import time
@@ -455,12 +455,12 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
455
455
  assert estimator is not None
456
456
 
457
457
  @cachetools.cached(cache={})
458
- def _load_data_into_udf() -> Tuple[
459
- Dict[str, pd.DataFrame],
458
+ def _load_data_into_udf() -> tuple[
459
+ dict[str, pd.DataFrame],
460
460
  Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
461
461
  pd.DataFrame,
462
462
  int,
463
- List[Dict[str, Any]],
463
+ list[dict[str, Any]],
464
464
  ]:
465
465
  import pyarrow.parquet as pq
466
466
 
@@ -512,7 +512,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
512
512
  self.data_length = data_length
513
513
  self.params_to_evaluate = params_to_evaluate
514
514
 
515
- def process(self, params_idx: int, cv_idx: int) -> Iterator[Tuple[str]]:
515
+ def process(self, params_idx: int, cv_idx: int) -> Iterator[tuple[str]]:
516
516
  # Assign parameter to GridSearchCV
517
517
  if hasattr(estimator, "param_grid"):
518
518
  self.estimator.param_grid = self.params_to_evaluate[params_idx]
@@ -699,10 +699,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
699
699
  dataset: DataFrame,
700
700
  session: Session,
701
701
  estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
702
- dependencies: List[str],
703
- udf_imports: List[str],
704
- input_cols: List[str],
705
- label_cols: Optional[List[str]],
702
+ dependencies: list[str],
703
+ udf_imports: list[str],
704
+ input_cols: list[str],
705
+ label_cols: Optional[list[str]],
706
706
  sample_weight_col: Optional[str],
707
707
  ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
708
708
  from itertools import product
@@ -727,7 +727,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
727
727
  # Create a temp file and dump the estimator to that file.
728
728
  estimator_file_name = temp_file_utils.get_temp_file_path()
729
729
  params_to_evaluate = list(param_grid)
730
- CONSTANTS: Dict[str, Any] = dict()
730
+ CONSTANTS: dict[str, Any] = dict()
731
731
  CONSTANTS["dataset_snowpark_cols"] = dataset.columns
732
732
  CONSTANTS["n_candidates"] = len(params_to_evaluate)
733
733
  CONSTANTS["_N_JOBS"] = estimator.n_jobs
@@ -791,10 +791,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
791
791
  )
792
792
  def _distributed_search(
793
793
  session: Session,
794
- imports: List[str],
794
+ imports: list[str],
795
795
  stage_estimator_file_name: str,
796
- input_cols: List[str],
797
- label_cols: Optional[List[str]],
796
+ input_cols: list[str],
797
+ label_cols: Optional[list[str]],
798
798
  ) -> str:
799
799
  import os
800
800
  import time
@@ -3,7 +3,7 @@ import inspect
3
3
  import os
4
4
  import posixpath
5
5
  import sys
6
- from typing import Any, Dict, List, Optional
6
+ from typing import Any, Optional
7
7
  from uuid import uuid4
8
8
 
9
9
  import cloudpickle as cp
@@ -73,10 +73,10 @@ class SnowparkTransformHandlers:
73
73
  def batch_inference(
74
74
  self,
75
75
  inference_method: str,
76
- input_cols: List[str],
77
- expected_output_cols: List[str],
76
+ input_cols: list[str],
77
+ expected_output_cols: list[str],
78
78
  session: Session,
79
- dependencies: List[str],
79
+ dependencies: list[str],
80
80
  drop_input_cols: Optional[bool] = False,
81
81
  expected_output_cols_type: Optional[str] = "",
82
82
  *args: Any,
@@ -229,11 +229,11 @@ class SnowparkTransformHandlers:
229
229
 
230
230
  def score(
231
231
  self,
232
- input_cols: List[str],
233
- label_cols: List[str],
232
+ input_cols: list[str],
233
+ label_cols: list[str],
234
234
  session: Session,
235
- dependencies: List[str],
236
- score_sproc_imports: List[str],
235
+ dependencies: list[str],
236
+ score_sproc_imports: list[str],
237
237
  sample_weight_col: Optional[str] = None,
238
238
  *args: Any,
239
239
  **kwargs: Any,
@@ -308,12 +308,12 @@ class SnowparkTransformHandlers:
308
308
  )
309
309
  def score_wrapper_sproc(
310
310
  session: Session,
311
- sql_queries: List[str],
311
+ sql_queries: list[str],
312
312
  stage_score_file_name: str,
313
- input_cols: List[str],
314
- label_cols: List[str],
313
+ input_cols: list[str],
314
+ label_cols: list[str],
315
315
  sample_weight_col: Optional[str],
316
- score_statement_params: Dict[str, str],
316
+ score_statement_params: dict[str, str],
317
317
  ) -> float:
318
318
  import inspect
319
319
  import os
@@ -382,7 +382,7 @@ class SnowparkTransformHandlers:
382
382
 
383
383
  return score
384
384
 
385
- def _get_validated_snowpark_dependencies(self, session: Session, dependencies: List[str]) -> List[str]:
385
+ def _get_validated_snowpark_dependencies(self, session: Session, dependencies: list[str]) -> list[str]:
386
386
  """A helper function to validate dependencies and return the available packages that exists
387
387
  in the snowflake anaconda channel
388
388
 
@@ -2,7 +2,7 @@ import importlib
2
2
  import inspect
3
3
  import os
4
4
  import posixpath
5
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+ from typing import Any, Callable, Optional, Union
6
6
 
7
7
  import cloudpickle as cp
8
8
  import pandas as pd
@@ -55,8 +55,8 @@ class SnowparkModelTrainer:
55
55
  estimator: object,
56
56
  dataset: DataFrame,
57
57
  session: Session,
58
- input_cols: List[str],
59
- label_cols: Optional[List[str]],
58
+ input_cols: list[str],
59
+ label_cols: Optional[list[str]],
60
60
  sample_weight_col: Optional[str],
61
61
  autogenerated: bool = False,
62
62
  subproject: str = "",
@@ -84,7 +84,7 @@ class SnowparkModelTrainer:
84
84
  self._subproject = subproject
85
85
  self._class_name = estimator.__class__.__name__
86
86
 
87
- def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: Dict[str, str]) -> object:
87
+ def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: dict[str, str]) -> object:
88
88
  """
89
89
  Downloads the serialized model from a stage location and unpickles it.
90
90
 
@@ -112,7 +112,7 @@ class SnowparkModelTrainer:
112
112
  def _build_fit_wrapper_sproc(
113
113
  self,
114
114
  model_spec: ModelSpecifications,
115
- ) -> Callable[[Any, List[str], str, List[str], List[str], Optional[str], Dict[str, str]], str]:
115
+ ) -> Callable[[Any, list[str], str, list[str], list[str], Optional[str], dict[str, str]], str]:
116
116
  """
117
117
  Constructs and returns a python stored procedure function to be used for training model.
118
118
 
@@ -129,12 +129,12 @@ class SnowparkModelTrainer:
129
129
 
130
130
  def fit_wrapper_function(
131
131
  session: Session,
132
- sql_queries: List[str],
132
+ sql_queries: list[str],
133
133
  temp_stage_name: str,
134
- input_cols: List[str],
135
- label_cols: List[str],
134
+ input_cols: list[str],
135
+ label_cols: list[str],
136
136
  sample_weight_col: Optional[str],
137
- statement_params: Dict[str, str],
137
+ statement_params: dict[str, str],
138
138
  ) -> str:
139
139
  import inspect
140
140
  import os
@@ -218,7 +218,7 @@ class SnowparkModelTrainer:
218
218
 
219
219
  return fit_wrapper_function
220
220
 
221
- def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
221
+ def _get_fit_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
222
222
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
223
223
 
224
224
  fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -243,7 +243,7 @@ class SnowparkModelTrainer:
243
243
  def _build_fit_predict_wrapper_sproc(
244
244
  self,
245
245
  model_spec: ModelSpecifications,
246
- ) -> Callable[[Session, List[str], str, List[str], Dict[str, str], bool, List[str], str], str]:
246
+ ) -> Callable[[Session, list[str], str, list[str], dict[str, str], bool, list[str], str], str]:
247
247
  """
248
248
  Constructs and returns a python stored procedure function to be used for training model.
249
249
 
@@ -258,12 +258,12 @@ class SnowparkModelTrainer:
258
258
 
259
259
  def fit_predict_wrapper_function(
260
260
  session: Session,
261
- sql_queries: List[str],
261
+ sql_queries: list[str],
262
262
  temp_stage_name: str,
263
- input_cols: List[str],
264
- statement_params: Dict[str, str],
263
+ input_cols: list[str],
264
+ statement_params: dict[str, str],
265
265
  drop_input_cols: bool,
266
- expected_output_cols_list: List[str],
266
+ expected_output_cols_list: list[str],
267
267
  fit_predict_result_name: str,
268
268
  ) -> str:
269
269
  import os
@@ -346,14 +346,14 @@ class SnowparkModelTrainer:
346
346
  ) -> Callable[
347
347
  [
348
348
  Session,
349
- List[str],
349
+ list[str],
350
350
  str,
351
- List[str],
352
- Optional[List[str]],
351
+ list[str],
352
+ Optional[list[str]],
353
353
  Optional[str],
354
- Dict[str, str],
354
+ dict[str, str],
355
355
  bool,
356
- List[str],
356
+ list[str],
357
357
  str,
358
358
  ],
359
359
  str,
@@ -372,14 +372,14 @@ class SnowparkModelTrainer:
372
372
 
373
373
  def fit_transform_wrapper_function(
374
374
  session: Session,
375
- sql_queries: List[str],
375
+ sql_queries: list[str],
376
376
  temp_stage_name: str,
377
- input_cols: List[str],
378
- label_cols: Optional[List[str]],
377
+ input_cols: list[str],
378
+ label_cols: Optional[list[str]],
379
379
  sample_weight_col: Optional[str],
380
- statement_params: Dict[str, str],
380
+ statement_params: dict[str, str],
381
381
  drop_input_cols: bool,
382
- expected_output_cols_list: List[str],
382
+ expected_output_cols_list: list[str],
383
383
  fit_transform_result_name: str,
384
384
  ) -> str:
385
385
  import os
@@ -473,7 +473,7 @@ class SnowparkModelTrainer:
473
473
 
474
474
  return fit_transform_wrapper_function
475
475
 
476
- def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
476
+ def _get_fit_predict_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
477
477
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
478
478
 
479
479
  fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -495,7 +495,7 @@ class SnowparkModelTrainer:
495
495
 
496
496
  return fit_predict_wrapper_sproc
497
497
 
498
- def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
498
+ def _get_fit_transform_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
499
499
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
500
500
 
501
501
  fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -586,10 +586,10 @@ class SnowparkModelTrainer:
586
586
 
587
587
  def train_fit_predict(
588
588
  self,
589
- expected_output_cols_list: List[str],
589
+ expected_output_cols_list: list[str],
590
590
  drop_input_cols: Optional[bool] = False,
591
591
  example_output_pd_df: Optional[pd.DataFrame] = None,
592
- ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
592
+ ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
593
593
  """Trains the model by pushing down the compute into Snowflake using stored procedures.
594
594
  This API is different from fit itself because it would also provide the predict
595
595
  output.
@@ -682,9 +682,9 @@ class SnowparkModelTrainer:
682
682
 
683
683
  def train_fit_transform(
684
684
  self,
685
- expected_output_cols_list: List[str],
685
+ expected_output_cols_list: list[str],
686
686
  drop_input_cols: Optional[bool] = False,
687
- ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
687
+ ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
688
688
  """Trains the model by pushing down the compute into Snowflake using stored procedures.
689
689
  This API is different from fit itself because it would also provide the transform
690
690
  output.
@@ -1,7 +1,7 @@
1
1
  import inspect
2
2
  import os
3
3
  import tempfile
4
- from typing import Any, Dict, List, Optional
4
+ from typing import Any, Optional
5
5
 
6
6
  import cloudpickle as cp
7
7
  import pandas as pd
@@ -41,13 +41,13 @@ _PROJECT = "ModelDevelopment"
41
41
 
42
42
 
43
43
  def get_data_iterator(
44
- file_paths: List[str],
44
+ file_paths: list[str],
45
45
  batch_size: int,
46
- input_cols: List[str],
47
- label_cols: List[str],
46
+ input_cols: list[str],
47
+ label_cols: list[str],
48
48
  sample_weight_col: Optional[str] = None,
49
49
  ) -> Any:
50
- from typing import List, Optional
50
+ from typing import Optional
51
51
 
52
52
  import xgboost
53
53
 
@@ -60,10 +60,10 @@ def get_data_iterator(
60
60
 
61
61
  def __init__(
62
62
  self,
63
- file_paths: List[str],
63
+ file_paths: list[str],
64
64
  batch_size: int,
65
- input_cols: List[str],
66
- label_cols: List[str],
65
+ input_cols: list[str],
66
+ label_cols: list[str],
67
67
  sample_weight_col: Optional[str] = None,
68
68
  ) -> None:
69
69
  """
@@ -151,10 +151,10 @@ def get_data_iterator(
151
151
 
152
152
  def train_xgboost_model(
153
153
  estimator: object,
154
- file_paths: List[str],
154
+ file_paths: list[str],
155
155
  batch_size: int,
156
- input_cols: List[str],
157
- label_cols: List[str],
156
+ input_cols: list[str],
157
+ label_cols: list[str],
158
158
  sample_weight_col: Optional[str] = None,
159
159
  ) -> object:
160
160
  """
@@ -247,8 +247,8 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
247
247
  estimator: object,
248
248
  dataset: DataFrame,
249
249
  session: Session,
250
- input_cols: List[str],
251
- label_cols: Optional[List[str]],
250
+ input_cols: list[str],
251
+ label_cols: Optional[list[str]],
252
252
  sample_weight_col: Optional[str],
253
253
  autogenerated: bool = False,
254
254
  subproject: str = "",
@@ -285,8 +285,8 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
285
285
  self,
286
286
  model_spec: ModelSpecifications,
287
287
  session: Session,
288
- statement_params: Dict[str, str],
289
- import_file_paths: List[str],
288
+ statement_params: dict[str, str],
289
+ import_file_paths: list[str],
290
290
  ) -> Any:
291
291
  fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
292
292
 
@@ -308,10 +308,10 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
308
308
  session: Session,
309
309
  dataset_stage_name: str,
310
310
  batch_size: int,
311
- input_cols: List[str],
312
- label_cols: List[str],
311
+ input_cols: list[str],
312
+ label_cols: list[str],
313
313
  sample_weight_col: Optional[str],
314
- statement_params: Dict[str, str],
314
+ statement_params: dict[str, str],
315
315
  ) -> str:
316
316
  import os
317
317
  import sys
@@ -365,7 +365,7 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
365
365
 
366
366
  return fit_wrapper_sproc
367
367
 
368
- def _write_training_data_to_stage(self, dataset_stage_name: str) -> List[str]:
368
+ def _write_training_data_to_stage(self, dataset_stage_name: str) -> list[str]:
369
369
  """
370
370
  Materializes the training to the specified stage and returns the list of stage file paths.
371
371
 
@@ -1,4 +1,4 @@
1
- from typing import Any, List, Optional, Protocol, TypedDict, Union
1
+ from typing import Any, Optional, Protocol, TypedDict, Union
2
2
 
3
3
  import pandas as pd
4
4
 
@@ -29,9 +29,9 @@ class LocalModelTransformHandlers(Protocol):
29
29
  def batch_inference(
30
30
  self,
31
31
  inference_method: str,
32
- input_cols: List[str],
33
- expected_output_cols: List[str],
34
- snowpark_input_cols: Optional[List[str]],
32
+ input_cols: list[str],
33
+ expected_output_cols: list[str],
34
+ snowpark_input_cols: Optional[list[str]],
35
35
  drop_input_cols: Optional[bool] = False,
36
36
  *args: Any,
37
37
  **kwargs: Any,
@@ -57,8 +57,8 @@ class LocalModelTransformHandlers(Protocol):
57
57
 
58
58
  def score(
59
59
  self,
60
- input_cols: List[str],
61
- label_cols: List[str],
60
+ input_cols: list[str],
61
+ label_cols: list[str],
62
62
  sample_weight_col: Optional[str],
63
63
  *args: Any,
64
64
  **kwargs: Any,
@@ -105,10 +105,10 @@ class RemoteModelTransformHandlers(Protocol):
105
105
  def batch_inference(
106
106
  self,
107
107
  inference_method: str,
108
- input_cols: List[str],
109
- expected_output_cols: List[str],
108
+ input_cols: list[str],
109
+ expected_output_cols: list[str],
110
110
  session: snowpark.Session,
111
- dependencies: List[str],
111
+ dependencies: list[str],
112
112
  drop_input_cols: Optional[bool] = False,
113
113
  expected_output_cols_type: Optional[str] = "",
114
114
  *args: Any,
@@ -137,11 +137,11 @@ class RemoteModelTransformHandlers(Protocol):
137
137
 
138
138
  def score(
139
139
  self,
140
- input_cols: List[str],
141
- label_cols: List[str],
140
+ input_cols: list[str],
141
+ label_cols: list[str],
142
142
  session: snowpark.Session,
143
- dependencies: List[str],
144
- score_sproc_imports: List[str],
143
+ dependencies: list[str],
144
+ score_sproc_imports: list[str],
145
145
  sample_weight_col: Optional[str] = None,
146
146
  *args: Any,
147
147
  **kwargs: Any,
@@ -173,10 +173,10 @@ ModelTransformHandlers = Union[LocalModelTransformHandlers, RemoteModelTransform
173
173
  class BatchInferenceKwargsTypedDict(TypedDict, total=False):
174
174
  """A typed dict specifying all possible optional keyword args accepted by batch_inference() methods."""
175
175
 
176
- snowpark_input_cols: Optional[List[str]]
176
+ snowpark_input_cols: Optional[list[str]]
177
177
  drop_input_cols: Optional[bool]
178
178
  session: snowpark.Session
179
- dependencies: List[str]
179
+ dependencies: list[str]
180
180
  expected_output_cols_type: str
181
181
  n_neighbors: Optional[int]
182
182
  return_distance: bool
@@ -186,5 +186,5 @@ class ScoreKwargsTypedDict(TypedDict, total=False):
186
186
  """A typed dict specifying all possible optional keyword args accepted by score() methods."""
187
187
 
188
188
  session: snowpark.Session
189
- dependencies: List[str]
190
- score_sproc_imports: List[str]
189
+ dependencies: list[str]
190
+ score_sproc_imports: list[str]
@@ -11,7 +11,7 @@ import cloudpickle as cp
11
11
  import numpy as np
12
12
  import pandas as pd
13
13
  from numpy import typing as npt
14
-
14
+ from packaging import version
15
15
 
16
16
  import numpy
17
17
  import sklearn
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
64
+ # Modeling library estimators require a smaller sklearn version range.
65
+ if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
+ raise Exception(
67
+ f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
68
+ )
69
+
70
+
63
71
  class CalibratedClassifierCV(BaseTransformer):
64
72
  r"""Probability calibration with isotonic regression or logistic regression
65
73
  For more details on this class, see [sklearn.calibration.CalibratedClassifierCV]
@@ -11,7 +11,7 @@ import cloudpickle as cp
11
11
  import numpy as np
12
12
  import pandas as pd
13
13
  from numpy import typing as npt
14
-
14
+ from packaging import version
15
15
 
16
16
  import numpy
17
17
  import sklearn
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
64
+ # Modeling library estimators require a smaller sklearn version range.
65
+ if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
+ raise Exception(
67
+ f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
68
+ )
69
+
70
+
63
71
  class AffinityPropagation(BaseTransformer):
64
72
  r"""Perform Affinity Propagation Clustering of data
65
73
  For more details on this class, see [sklearn.cluster.AffinityPropagation]
@@ -11,7 +11,7 @@ import cloudpickle as cp
11
11
  import numpy as np
12
12
  import pandas as pd
13
13
  from numpy import typing as npt
14
-
14
+ from packaging import version
15
15
 
16
16
  import numpy
17
17
  import sklearn
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
64
+ # Modeling library estimators require a smaller sklearn version range.
65
+ if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
+ raise Exception(
67
+ f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
68
+ )
69
+
70
+
63
71
  class AgglomerativeClustering(BaseTransformer):
64
72
  r"""Agglomerative Clustering
65
73
  For more details on this class, see [sklearn.cluster.AgglomerativeClustering]
@@ -11,7 +11,7 @@ import cloudpickle as cp
11
11
  import numpy as np
12
12
  import pandas as pd
13
13
  from numpy import typing as npt
14
-
14
+ from packaging import version
15
15
 
16
16
  import numpy
17
17
  import sklearn
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
60
60
 
61
61
  INFER_SIGNATURE_MAX_ROWS = 100
62
62
 
63
+ SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
64
+ # Modeling library estimators require a smaller sklearn version range.
65
+ if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
66
+ raise Exception(
67
+ f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
68
+ )
69
+
70
+
63
71
  class Birch(BaseTransformer):
64
72
  r"""Implements the BIRCH clustering algorithm
65
73
  For more details on this class, see [sklearn.cluster.Birch]