snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/cortex/_classify_text.py +3 -3
  3. snowflake/cortex/_complete.py +23 -24
  4. snowflake/cortex/_embed_text_1024.py +4 -4
  5. snowflake/cortex/_embed_text_768.py +4 -4
  6. snowflake/cortex/_finetune.py +8 -8
  7. snowflake/cortex/_util.py +8 -12
  8. snowflake/ml/_internal/env.py +4 -3
  9. snowflake/ml/_internal/env_utils.py +63 -34
  10. snowflake/ml/_internal/file_utils.py +10 -21
  11. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  12. snowflake/ml/_internal/init_utils.py +2 -3
  13. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  14. snowflake/ml/_internal/platform_capabilities.py +18 -16
  15. snowflake/ml/_internal/telemetry.py +39 -52
  16. snowflake/ml/_internal/type_utils.py +3 -3
  17. snowflake/ml/_internal/utils/db_utils.py +2 -2
  18. snowflake/ml/_internal/utils/identifier.py +10 -10
  19. snowflake/ml/_internal/utils/import_utils.py +2 -2
  20. snowflake/ml/_internal/utils/parallelize.py +7 -7
  21. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  22. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  23. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  24. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  25. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  26. snowflake/ml/_internal/utils/table_manager.py +9 -9
  27. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  28. snowflake/ml/data/data_connector.py +15 -36
  29. snowflake/ml/data/data_ingestor.py +4 -15
  30. snowflake/ml/data/data_source.py +2 -2
  31. snowflake/ml/data/ingestor_utils.py +3 -3
  32. snowflake/ml/data/torch_utils.py +5 -5
  33. snowflake/ml/dataset/dataset.py +11 -11
  34. snowflake/ml/dataset/dataset_metadata.py +8 -8
  35. snowflake/ml/dataset/dataset_reader.py +7 -7
  36. snowflake/ml/feature_store/__init__.py +1 -1
  37. snowflake/ml/feature_store/access_manager.py +7 -7
  38. snowflake/ml/feature_store/entity.py +6 -6
  39. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  41. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  45. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  48. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  51. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  52. snowflake/ml/feature_store/feature_store.py +52 -64
  53. snowflake/ml/feature_store/feature_view.py +24 -24
  54. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  55. snowflake/ml/fileset/fileset.py +5 -5
  56. snowflake/ml/fileset/sfcfs.py +13 -13
  57. snowflake/ml/fileset/stage_fs.py +15 -15
  58. snowflake/ml/jobs/_utils/constants.py +1 -1
  59. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  60. snowflake/ml/jobs/_utils/payload_utils.py +45 -46
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  64. snowflake/ml/jobs/_utils/spec_utils.py +18 -29
  65. snowflake/ml/jobs/_utils/types.py +2 -2
  66. snowflake/ml/jobs/decorators.py +10 -5
  67. snowflake/ml/jobs/job.py +87 -30
  68. snowflake/ml/jobs/manager.py +86 -56
  69. snowflake/ml/lineage/lineage_node.py +5 -5
  70. snowflake/ml/model/_client/model/model_impl.py +3 -3
  71. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  72. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  73. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  74. snowflake/ml/model/_client/ops/service_ops.py +217 -32
  75. snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
  76. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
  77. snowflake/ml/model/_client/sql/model.py +8 -8
  78. snowflake/ml/model/_client/sql/model_version.py +26 -26
  79. snowflake/ml/model/_client/sql/service.py +17 -26
  80. snowflake/ml/model/_client/sql/stage.py +2 -2
  81. snowflake/ml/model/_client/sql/tag.py +6 -6
  82. snowflake/ml/model/_model_composer/model_composer.py +58 -32
  83. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  85. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  86. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  87. snowflake/ml/model/_packager/model_handler.py +4 -4
  88. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  89. snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
  90. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  91. snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
  92. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  93. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  95. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  96. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
  97. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  98. snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
  99. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  100. snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
  101. snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
  102. snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
  103. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  104. snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
  105. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
  106. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  107. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  109. snowflake/ml/model/_packager/model_packager.py +11 -9
  110. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  111. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  112. snowflake/ml/model/_signatures/core.py +16 -24
  113. snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
  114. snowflake/ml/model/_signatures/utils.py +6 -6
  115. snowflake/ml/model/custom_model.py +24 -11
  116. snowflake/ml/model/model_signature.py +12 -23
  117. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  118. snowflake/ml/model/type_hints.py +3 -3
  119. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  120. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  122. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  123. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  124. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  125. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  126. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  128. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  129. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  130. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  131. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  132. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  133. snowflake/ml/modeling/cluster/birch.py +9 -1
  134. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  135. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  136. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  137. snowflake/ml/modeling/cluster/k_means.py +9 -1
  138. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  139. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  140. snowflake/ml/modeling/cluster/optics.py +9 -1
  141. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  142. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  143. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  144. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  145. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  146. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  147. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  148. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  149. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  150. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  151. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  152. snowflake/ml/modeling/covariance/oas.py +9 -1
  153. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  154. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  155. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  156. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  157. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  158. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  159. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  160. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  161. snowflake/ml/modeling/decomposition/pca.py +9 -1
  162. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  163. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  164. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  165. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  166. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  167. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  168. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  169. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  170. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  171. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  172. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  173. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  174. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  175. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  176. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  177. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  178. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  179. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  180. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  181. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  182. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  183. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  184. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  185. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  186. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  187. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  188. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  189. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  190. snowflake/ml/modeling/framework/_utils.py +10 -10
  191. snowflake/ml/modeling/framework/base.py +32 -32
  192. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  193. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  194. snowflake/ml/modeling/impute/__init__.py +1 -1
  195. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  196. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  197. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  198. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  199. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  200. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  201. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  202. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  203. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  204. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  205. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  206. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  207. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  208. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  209. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  210. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  211. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  212. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  213. snowflake/ml/modeling/linear_model/lars.py +9 -1
  214. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  215. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  216. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  217. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  218. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  219. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  220. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  221. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  222. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  223. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  224. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  225. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  226. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  227. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  228. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  229. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  230. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  231. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  232. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  233. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  234. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  235. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  236. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  237. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  238. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  239. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  240. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  241. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  242. snowflake/ml/modeling/manifold/isomap.py +9 -1
  243. snowflake/ml/modeling/manifold/mds.py +9 -1
  244. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  245. snowflake/ml/modeling/manifold/tsne.py +9 -1
  246. snowflake/ml/modeling/metrics/__init__.py +1 -1
  247. snowflake/ml/modeling/metrics/classification.py +39 -39
  248. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  249. snowflake/ml/modeling/metrics/ranking.py +7 -7
  250. snowflake/ml/modeling/metrics/regression.py +13 -13
  251. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  252. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  253. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  254. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  255. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  256. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  257. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  258. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  259. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  260. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  261. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  262. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  263. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  264. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  265. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  266. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  267. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  268. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  269. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  270. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  271. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  272. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  273. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  274. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  275. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  276. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  277. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  278. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  279. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  280. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  281. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  282. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  283. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  284. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  285. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  286. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  287. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  288. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  289. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  290. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  291. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  292. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  293. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  294. snowflake/ml/modeling/svm/svc.py +9 -1
  295. snowflake/ml/modeling/svm/svr.py +9 -1
  296. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  297. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  298. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  299. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  300. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  301. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  302. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  303. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  304. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  305. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  306. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  307. snowflake/ml/monitoring/explain_visualize.py +286 -0
  308. snowflake/ml/registry/_manager/model_manager.py +55 -32
  309. snowflake/ml/registry/registry.py +39 -31
  310. snowflake/ml/utils/authentication.py +2 -2
  311. snowflake/ml/utils/connection_params.py +5 -5
  312. snowflake/ml/utils/sparse.py +5 -4
  313. snowflake/ml/utils/sql_client.py +1 -2
  314. snowflake/ml/version.py +2 -1
  315. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
  316. snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
  317. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  318. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  319. snowflake/ml/modeling/_internal/constants.py +0 -2
  320. snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
  321. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  322. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/utils/identifier.py
@@ -1,5 +1,5 @@
  import re
- from typing import Any, List, Optional, Tuple, Union, overload
+ from typing import Any, Optional, Union, overload

  from snowflake.snowpark._internal.analyzer import analyzer_utils

@@ -12,7 +12,7 @@ SF_IDENTIFIER_RE = re.compile(_SF_IDENTIFIER)
  _SF_SCHEMA_LEVEL_OBJECT = (
  rf"(?:(?:(?P<db>{_SF_IDENTIFIER})\.)?(?P<schema>{_SF_IDENTIFIER})\.)?(?P<object>{_SF_IDENTIFIER})"
  )
- _SF_STAGE_PATH = rf"{_SF_SCHEMA_LEVEL_OBJECT}(?P<path>.*)"
+ _SF_STAGE_PATH = rf"@?{_SF_SCHEMA_LEVEL_OBJECT}(?P<path>/.*)?"
  _SF_SCHEMA_LEVEL_OBJECT_RE = re.compile(_SF_SCHEMA_LEVEL_OBJECT)
  _SF_STAGE_PATH_RE = re.compile(_SF_STAGE_PATH)

@@ -112,7 +112,7 @@ def get_inferred_name(name: str) -> str:
  return escaped_id


- def concat_names(names: List[str]) -> str:
+ def concat_names(names: list[str]) -> str:
  """Concatenates `names` to form one valid id.


@@ -142,7 +142,7 @@ def rename_to_valid_snowflake_identifier(name: str) -> str:

  def parse_schema_level_object_identifier(
  object_name: str,
- ) -> Tuple[Union[str, Any], Union[str, Any], Union[str, Any]]:
+ ) -> tuple[Union[str, Any], Union[str, Any], Union[str, Any]]:
  """Parse a string which starts with schema level object.

  Args:
@@ -172,7 +172,7 @@ def parse_schema_level_object_identifier(

  def parse_snowflake_stage_path(
  path: str,
- ) -> Tuple[Union[str, Any], Union[str, Any], Union[str, Any], Union[str, Any]]:
+ ) -> tuple[Union[str, Any], Union[str, Any], Union[str, Any], Union[str, Any]]:
  """Parse a string which represents a snowflake stage path.

  Args:
@@ -197,7 +197,7 @@ def parse_snowflake_stage_path(
  res.group("db"),
  res.group("schema"),
  res.group("object"),
- res.group("path"),
+ res.group("path") or "",
  )


@@ -260,11 +260,11 @@ def get_unescaped_names(ids: str) -> str:


  @overload
- def get_unescaped_names(ids: List[str]) -> List[str]:
+ def get_unescaped_names(ids: list[str]) -> list[str]:
  ...


- def get_unescaped_names(ids: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
+ def get_unescaped_names(ids: Optional[Union[str, list[str]]]) -> Optional[Union[str, list[str]]]:
  """Given a user provided identifier(s), this method will compute the equivalent column name identifier(s) in the
  response pandas dataframe(i.e., in the response of snowpark_df.to_pandas()) using the rules defined here
  https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
@@ -308,11 +308,11 @@ def get_inferred_names(names: str) -> str:


  @overload
- def get_inferred_names(names: List[str]) -> List[str]:
+ def get_inferred_names(names: list[str]) -> list[str]:
  ...


- def get_inferred_names(names: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
+ def get_inferred_names(names: Optional[Union[str, list[str]]]) -> Optional[Union[str, list[str]]]:
  """Given a user provided *string(s)*, this method will compute the equivalent column name identifier(s)
  in case of column name contains special characters, and maintains case-sensitivity
  https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
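
Aside from the typing cleanup (typing.List/Tuple/Dict replaced by built-in generics throughout this release), the behavioral change in identifier.py is the relaxed stage-path pattern: a leading "@" is now accepted, the path group must start with "/", and a missing path is normalized to an empty string. A minimal sketch of what this implies for callers; the expected values are inferred from the new pattern rather than taken from the package's tests:

    from snowflake.ml._internal.utils import identifier

    # A leading "@" is tolerated under the new _SF_STAGE_PATH pattern.
    db, schema, obj, path = identifier.parse_snowflake_stage_path(
        "@MY_DB.MY_SCHEMA.MY_STAGE/data/part-0.parquet"
    )
    # db="MY_DB", schema="MY_SCHEMA", obj="MY_STAGE", path="/data/part-0.parquet"

    # With no trailing path the optional group matches nothing;
    # `res.group("path") or ""` normalizes it to an empty string.
    _, _, _, path = identifier.parse_snowflake_stage_path("MY_DB.MY_SCHEMA.MY_STAGE")
    # path == ""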
snowflake/ml/_internal/utils/import_utils.py
@@ -1,5 +1,5 @@
  import importlib
- from typing import Any, Tuple
+ from typing import Any


  class MissingOptionalDependency:
@@ -46,7 +46,7 @@ def import_with_fallbacks(*targets: str) -> Any:
  raise ImportError(f"None of the requested targets could be imported. Requested: {', '.join(targets)}")


- def import_or_get_dummy(target: str) -> Tuple[Any, bool]:
+ def import_or_get_dummy(target: str) -> tuple[Any, bool]:
  """Try to import the the given target or return a dummy object.

  If the import target (package/module/symbol) is available, the target will be returned. If it is not available,
snowflake/ml/_internal/utils/parallelize.py
@@ -1,7 +1,7 @@
  import math
  from contextlib import contextmanager
  from timeit import default_timer
- from typing import Any, Callable, Dict, Generator, Iterable, List, Optional
+ from typing import Any, Callable, Generator, Iterable, Optional

  import snowflake.snowpark.functions as F
  from snowflake import snowpark
@@ -17,17 +17,17 @@ def timer() -> Generator[Callable[[], float], None, None]:
  yield lambda: elapser()


- def _flatten(L: Iterable[List[Any]]) -> List[Any]:
+ def _flatten(L: Iterable[list[Any]]) -> list[Any]:
  return [val for sublist in L for val in sublist]


  def map_dataframe_by_column(
  df: snowpark.DataFrame,
- cols: List[str],
- map_func: Callable[[snowpark.DataFrame, List[str]], snowpark.DataFrame],
+ cols: list[str],
+ map_func: Callable[[snowpark.DataFrame, list[str]], snowpark.DataFrame],
  partition_size: int,
- statement_params: Optional[Dict[str, Any]] = None,
- ) -> List[List[Any]]:
+ statement_params: Optional[dict[str, Any]] = None,
+ ) -> list[list[Any]]:
  """Applies the `map_func` to the input DataFrame by parallelizing it over subsets of the column.

  Because the return results are materialized as Python lists *in memory*, this method should
@@ -84,7 +84,7 @@ def map_dataframe_by_column(
  unioned_df = mapped_df if unioned_df is None else unioned_df.union_all(mapped_df)

  # Store results in a list of size |n_partitions| x |n_rows| x |n_output_cols|
- all_results: List[List[List[Any]]] = [[] for _ in range(n_partitions - 1)]
+ all_results: list[list[list[Any]]] = [[] for _ in range(n_partitions - 1)]

  # Collect the results of the first n-1 partitions, removing the partition_id column
  unioned_result = unioned_df.collect(statement_params=statement_params) if unioned_df is not None else []
snowflake/ml/_internal/utils/pkg_version_utils.py
@@ -1,6 +1,6 @@
  import sys
  import warnings
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import Optional, Union

  from packaging.version import Version

@@ -8,7 +8,7 @@ from snowflake.ml._internal import telemetry
  from snowflake.snowpark import AsyncJob, Row, Session
  from snowflake.snowpark._internal import utils as snowpark_utils

- cache: Dict[str, Optional[str]] = {}
+ cache: dict[str, Optional[str]] = {}

  _PROJECT = "ModelDevelopment"
  _SUBPROJECT = "utils"
@@ -23,8 +23,8 @@ def is_relaxed() -> bool:


  def get_valid_pkg_versions_supported_in_snowflake_conda_channel(
- pkg_versions: List[str], session: Session, subproject: Optional[str] = None
- ) -> List[str]:
+ pkg_versions: list[str], session: Session, subproject: Optional[str] = None
+ ) -> list[str]:
  if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
  return pkg_versions
  else:
@@ -32,9 +32,9 @@ get_valid_pkg_versions_supported_in_snowflake_conda_channel(


  def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
- pkg_versions: List[str], session: Session, subproject: Optional[str] = None
- ) -> List[str]:
- pkg_version_async_job_list: List[Tuple[str, AsyncJob]] = []
+ pkg_versions: list[str], session: Session, subproject: Optional[str] = None
+ ) -> list[str]:
+ pkg_version_async_job_list: list[tuple[str, AsyncJob]] = []
  for pkg_version in pkg_versions:
  if pkg_version not in cache:
  # Execute pkg version queries asynchronously.
@@ -64,7 +64,7 @@ _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(

  def _query_pkg_version_supported_in_snowflake_conda_channel(
  pkg_version: str, session: Session, block: bool, subproject: Optional[str] = None
- ) -> Union[AsyncJob, List[Row]]:
+ ) -> Union[AsyncJob, list[Row]]:
  tokens = pkg_version.split("==")
  if len(tokens) != 2:
  raise RuntimeError(
@@ -102,9 +102,9 @@ _query_pkg_version_supported_in_snowflake_conda_channel(
  return pkg_version_list_or_async_job


- def _get_conda_packages_and_emit_warnings(pkg_versions: List[str]) -> List[str]:
- pkg_version_conda_list: List[str] = []
- pkg_version_warning_list: List[List[str]] = []
+ def _get_conda_packages_and_emit_warnings(pkg_versions: list[str]) -> list[str]:
+ pkg_version_conda_list: list[str] = []
+ pkg_version_warning_list: list[list[str]] = []
  for pkg_version in pkg_versions:
  try:
  conda_pkg_version = cache[pkg_version]
snowflake/ml/_internal/utils/query_result_checker.py
@@ -1,7 +1,7 @@
  from __future__ import annotations # for return self methods

  from functools import partial
- from typing import Any, Callable, Dict, List, Optional
+ from typing import Any, Callable, Optional

  from snowflake import connector, snowpark
  from snowflake.ml._internal.utils import formatting
@@ -123,7 +123,7 @@ def cell_value_by_column_matcher(
  return True


- _DEFAULT_MATCHERS: List[Callable[[List[snowpark.Row], Optional[str]], bool]] = [
+ _DEFAULT_MATCHERS: list[Callable[[list[snowpark.Row], Optional[str]], bool]] = [
  partial(result_dimension_matcher, 1, 1),
  partial(column_name_matcher, "status"),
  ]
@@ -252,12 +252,12 @@ class SqlResultValidator(ResultValidator):
  """

  def __init__(
- self, session: snowpark.Session, query: str, statement_params: Optional[Dict[str, Any]] = None
+ self, session: snowpark.Session, query: str, statement_params: Optional[dict[str, Any]] = None
  ) -> None:
  self._session: snowpark.Session = session
  self._query: str = query
  self._success_matchers: list[Callable[[list[snowpark.Row], Optional[str]], bool]] = []
- self._statement_params: Optional[Dict[str, Any]] = statement_params
+ self._statement_params: Optional[dict[str, Any]] = statement_params

  def _get_result(self) -> list[snowpark.Row]:
  """Collect the result of the given SQL query."""
snowflake/ml/_internal/utils/snowflake_env.py
@@ -1,15 +1,15 @@
  import enum
- from typing import Any, Dict, Optional, TypedDict, cast
+ from typing import Any, Optional, TypedDict, cast

  from packaging import version
  from typing_extensions import NotRequired, Required

  from snowflake.ml._internal.utils import query_result_checker
- from snowflake.snowpark import session
+ from snowflake.snowpark import exceptions as sp_exceptions, session


  def get_current_snowflake_version(
- sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None
+ sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None
  ) -> version.Version:
  """Get Snowflake Version as a version.Version object follow PEP way of versioning, that is to say:
  "7.44.2 b202312132139364eb71238" to <Version('7.44.2+b202312132139364eb71238')>
@@ -60,8 +60,8 @@ class SnowflakeRegion(TypedDict):


  def get_regions(
- sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None
- ) -> Dict[str, SnowflakeRegion]:
+ sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None
+ ) -> dict[str, SnowflakeRegion]:
  res = (
  query_result_checker.SqlResultValidator(sess, "SHOW REGIONS", statement_params=statement_params)
  .has_column("snowflake_region")
@@ -93,7 +93,7 @@ def get_regions(
  return res_dict


- def get_current_region_id(sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None) -> str:
+ def get_current_region_id(sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None) -> str:
  res = (
  query_result_checker.SqlResultValidator(
  sess, "SELECT CURRENT_REGION() AS CURRENT_REGION", statement_params=statement_params
@@ -103,3 +103,25 @@ get_current_region_id(sess: session.Session, *, statement_params: Optional[D
  )

  return cast(str, res.CURRENT_REGION)
+
+
+ def get_current_cloud(
+ sess: session.Session,
+ default: Optional[SnowflakeCloudType] = None,
+ *,
+ statement_params: Optional[dict[str, Any]] = None,
+ ) -> SnowflakeCloudType:
+ region_id = get_current_region_id(sess, statement_params=statement_params)
+ try:
+ region = get_regions(sess, statement_params=statement_params)[region_id]
+ return region["cloud"]
+ except sp_exceptions.SnowparkSQLException:
+ # SHOW REGIONS not available, try to infer cloud from region name
+ region_name = region_id.split(".", 1)[-1] # Drop region group if any, e.g. PUBLIC
+ cloud_name_maybe = region_name.split("_", 1)[0] # Extract cloud name, e.g. AWS_US_WEST -> AWS
+ try:
+ return SnowflakeCloudType.from_value(cloud_name_maybe)
+ except ValueError:
+ if default:
+ return default
+ raise
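
The new get_current_cloud helper first consults SHOW REGIONS and, when that command is unavailable, falls back to inferring the cloud from the region identifier. A hedged usage sketch; SnowflakeCloudType is the enum referenced above, and the AWS member name used for the fallback is an assumption not shown in this diff:

    from snowflake.ml._internal.utils import snowflake_env

    # `session` is an existing Snowpark session; `default=` is only used when the
    # cloud cannot be derived from the region name.
    cloud = snowflake_env.get_current_cloud(
        session,
        default=snowflake_env.SnowflakeCloudType.AWS,  # assumed member name
    )
    print(cloud)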
snowflake/ml/_internal/utils/snowpark_dataframe_utils.py
@@ -1,13 +1,13 @@
  import logging
  import warnings
- from typing import List, Optional
+ from typing import Optional

  from snowflake import snowpark
  from snowflake.ml._internal.utils import sql_identifier
  from snowflake.snowpark import functions, types


- def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[List[str]] = None) -> snowpark.DataFrame:
+ def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[list[str]] = None) -> snowpark.DataFrame:
  """Cast columns in the dataframe to types that are compatible with tensor.

  It assists FileSet.make() in performing implicit data casting.
snowflake/ml/_internal/utils/sql_identifier.py
@@ -1,4 +1,4 @@
- from typing import List, Optional, Tuple, Union
+ from typing import Optional, Union

  from snowflake.ml._internal.utils import identifier

@@ -77,13 +77,13 @@ class SqlIdentifier(str):
  return super().__hash__()


- def to_sql_identifiers(list_of_str: List[str], *, case_sensitive: bool = False) -> List[SqlIdentifier]:
+ def to_sql_identifiers(list_of_str: list[str], *, case_sensitive: bool = False) -> list[SqlIdentifier]:
  return [SqlIdentifier(val, case_sensitive=case_sensitive) for val in list_of_str]


  def parse_fully_qualified_name(
  name: str,
- ) -> Tuple[Optional[SqlIdentifier], Optional[SqlIdentifier], SqlIdentifier]:
+ ) -> tuple[Optional[SqlIdentifier], Optional[SqlIdentifier], SqlIdentifier]:
  db, schema, object = identifier.parse_schema_level_object_identifier(name)

  assert name is not None, f"Unable parse the input name `{name}` as fully qualified."
snowflake/ml/_internal/utils/table_manager.py
@@ -1,4 +1,4 @@
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any, Optional

  from snowflake import snowpark
  from snowflake.ml._internal.utils import formatting, identifier, query_result_checker
@@ -24,8 +24,8 @@ def create_single_table(
  database_name: str,
  schema_name: str,
  table_name: str,
- table_schema: List[Tuple[str, str]],
- statement_params: Optional[Dict[str, Any]] = None,
+ table_schema: list[tuple[str, str]],
+ statement_params: Optional[dict[str, Any]] = None,
  ) -> str:
  """Creates a single table for registry and returns the fully qualified name of the table.

@@ -55,7 +55,7 @@ def create_single_table(
  return fully_qualified_table_name


- def insert_table_entry(session: snowpark.Session, table: str, columns: Dict[str, Any]) -> List[snowpark.Row]:
+ def insert_table_entry(session: snowpark.Session, table: str, columns: dict[str, Any]) -> list[snowpark.Row]:
  """Insert an entry into an internal Model Registry table.

  Args:
@@ -99,9 +99,9 @@ def validate_table_exist(session: snowpark.Session, table: str, qualified_schema
  return len(tables) == 1


- def get_table_schema(session: snowpark.Session, table_name: str, qualified_schema_name: str) -> Dict[str, str]:
+ def get_table_schema(session: snowpark.Session, table_name: str, qualified_schema_name: str) -> dict[str, str]:
  result = session.sql(f"DESC TABLE {qualified_schema_name}.{table_name}").collect()
- schema_dict: Dict[str, str] = {}
+ schema_dict: dict[str, str] = {}
  for row in result:
  schema_dict[row["name"]] = row["type"]
  return schema_dict
@@ -112,13 +112,13 @@ def get_table_schema_types(
  database: str,
  schema: str,
  table_name: str,
- ) -> Dict[str, types.DataType]:
+ ) -> dict[str, types.DataType]:
  fully_qualified_table_name = identifier.get_schema_level_object_identifier(
  db=database, schema=schema, object_name=table_name
  )
- struct_fields: List[types.StructField] = session.table(fully_qualified_table_name).schema.fields
+ struct_fields: list[types.StructField] = session.table(fully_qualified_table_name).schema.fields

- schema_dict: Dict[str, types.DataType] = {}
+ schema_dict: dict[str, types.DataType] = {}
  for field in struct_fields:
  schema_dict[field.name] = field.datatype
  return schema_dict
snowflake/ml/data/_internal/arrow_ingestor.py
@@ -2,7 +2,7 @@ import collections
  import logging
  import os
  import time
- from typing import Any, Deque, Dict, Iterator, List, Optional, Sequence, Union
+ from typing import Any, Deque, Iterator, Optional, Sequence, Union

  import numpy as np
  import numpy.typing as npt
@@ -71,7 +71,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
  return cls(session, sources)

  @property
- def data_sources(self) -> List[data_source.DataSource]:
+ def data_sources(self) -> list[data_source.DataSource]:
  return self._data_sources

  def to_batches(
@@ -79,7 +79,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
  batch_size: int,
  shuffle: bool = True,
  drop_last_batch: bool = True,
- ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+ ) -> Iterator[dict[str, npt.NDArray[Any]]]:
  """Iterate through PyArrow Dataset to generate batches whose length equals to expected batch size.

  As we are generating batches with the exactly same length, the last few rows in each file might get left as they
@@ -120,7 +120,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):

  def _get_dataset(self, shuffle: bool) -> pds.Dataset:
  format = self._format
- sources: List[Any] = []
+ sources: list[Any] = []
  source_format = None
  for source in self._data_sources:
  if isinstance(source, str):
@@ -155,7 +155,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
  pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
  return pa_dataset

- def _get_batches_from_buffer(self, batch_size: int) -> Dict[str, npt.NDArray[Any]]:
+ def _get_batches_from_buffer(self, batch_size: int) -> dict[str, npt.NDArray[Any]]:
  """Generate new batches from the existing record batch buffer."""
  cnt_rbs_num_rows = 0
  candidates = []
@@ -180,7 +180,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
  return _record_batch_to_arrays(res)


- def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatch:
+ def _merge_record_batches(record_batches: list[pa.RecordBatch]) -> pa.RecordBatch:
  """Merge a list of arrow RecordBatches into one. Similar to MergeTables."""
  if not record_batches:
  return _EMPTY_RECORD_BATCH
@@ -192,7 +192,7 @@ def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatc
  return batches[0]


- def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
+ def _record_batch_to_arrays(rb: pa.RecordBatch) -> dict[str, npt.NDArray[Any]]:
  """Transform the record batch to a (string, numpy array) dict."""
  batch_dict = {}
  for column, column_schema in zip(rb, rb.schema):
snowflake/ml/data/data_connector.py
@@ -1,28 +1,13 @@
  import os
- from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Generator,
- List,
- Optional,
- Sequence,
- Type,
- TypeVar,
- cast,
- )
+ from typing import TYPE_CHECKING, Any, Generator, Optional, Sequence, TypeVar

  import numpy.typing as npt
  from typing_extensions import deprecated

  from snowflake import snowpark
- from snowflake.ml._internal import telemetry
+ from snowflake.ml._internal import env, telemetry
  from snowflake.ml.data import data_ingestor, data_source
  from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
- from snowflake.ml.modeling._internal.constants import (
- IN_ML_RUNTIME_ENV_VAR,
- USE_OPTIMIZED_DATA_INGESTOR,
- )
  from snowflake.snowpark import context as sf_context

  if TYPE_CHECKING:
@@ -43,7 +28,7 @@ DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
  class DataConnector:
  """Snowflake data reader which provides application integration connectors"""

- DEFAULT_INGESTOR_CLASS: Type[data_ingestor.DataIngestor] = ArrowIngestor
+ DEFAULT_INGESTOR_CLASS: type[data_ingestor.DataIngestor] = ArrowIngestor

  def __init__(
  self,
@@ -54,27 +39,22 @@ class DataConnector:
  self._kwargs = kwargs

  @classmethod
- @snowpark._internal.utils.private_preview(version="1.6.0")
  def from_dataframe(
- cls: Type[DataConnectorType],
+ cls: type[DataConnectorType],
  df: snowpark.DataFrame,
- ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+ ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
  **kwargs: Any,
  ) -> DataConnectorType:
  if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
  raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
- return cast(
- DataConnectorType,
- cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs),
- )
+ return cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs)

  @classmethod
- @snowpark._internal.utils.private_preview(version="1.7.3")
  def from_sql(
- cls: Type[DataConnectorType],
+ cls: type[DataConnectorType],
  query: str,
  session: Optional[snowpark.Session] = None,
- ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+ ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
  **kwargs: Any,
  ) -> DataConnectorType:
  session = session or sf_context.get_active_session()
@@ -83,9 +63,9 @@ class DataConnector:

  @classmethod
  def from_dataset(
- cls: Type[DataConnectorType],
+ cls: type[DataConnectorType],
  ds: "dataset.Dataset",
- ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+ ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
  **kwargs: Any,
  ) -> DataConnectorType:
  dsv = ds.selected_version
@@ -102,10 +82,10 @@ class DataConnector:
  func_params_to_log=["sources", "ingestor_class"],
  )
  def from_sources(
- cls: Type[DataConnectorType],
+ cls: type[DataConnectorType],
  session: snowpark.Session,
  sources: Sequence[data_source.DataSource],
- ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+ ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
  **kwargs: Any,
  ) -> DataConnectorType:
  ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
@@ -113,7 +93,7 @@ class DataConnector:
  return cls(ingestor, **kwargs)

  @property
- def data_sources(self) -> List[data_source.DataSource]:
+ def data_sources(self) -> list[data_source.DataSource]:
  return self._ingestor.data_sources

  @telemetry.send_api_usage_telemetry(
@@ -139,7 +119,7 @@ class DataConnector:
  """
  import tensorflow as tf

- def generator() -> Generator[Dict[str, npt.NDArray[Any]], None, None]:
+ def generator() -> Generator[dict[str, npt.NDArray[Any]], None, None]:
  yield from self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)

  # Derive TensorFlow signature
@@ -269,11 +249,10 @@ class DataConnector:

  # Switch to use Runtime's Data Ingester if running in ML runtime
  # Fail silently if the data ingester is not found
- if os.getenv(IN_ML_RUNTIME_ENV_VAR) and os.getenv(USE_OPTIMIZED_DATA_INGESTOR):
+ if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR):
  try:
  from runtime_external_entities import get_ingester_class

  DataConnector.DEFAULT_INGESTOR_CLASS = get_ingester_class()
  except ImportError:
  """Runtime Default Ingester not found, ignore"""
- pass
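
Besides the import cleanup, the visible API change here is that from_dataframe and from_sql no longer carry the Snowpark private_preview decorator, and the optimized-ingestor switch now keys off env.IN_ML_RUNTIME rather than a raw environment variable lookup. A small usage sketch, assuming an active Snowpark session named session already exists:

    from snowflake.ml.data.data_connector import DataConnector

    # Building a connector from a query; per the removed decorator, this no longer
    # emits a private-preview warning.
    dc = DataConnector.from_sql("SELECT * FROM MY_DB.MY_SCHEMA.MY_TABLE", session=session)
    print(dc.data_sources)  # annotated as list[data_source.DataSource] after this change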
snowflake/ml/data/data_ingestor.py
@@ -1,15 +1,4 @@
- from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Iterator,
- List,
- Optional,
- Protocol,
- Sequence,
- Type,
- TypeVar,
- )
+ from typing import TYPE_CHECKING, Any, Iterator, Optional, Protocol, Sequence, TypeVar

  from numpy import typing as npt

@@ -26,12 +15,12 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
  class DataIngestor(Protocol):
  @classmethod
  def from_sources(
- cls: Type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
+ cls: type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
  ) -> DataIngestorType:
  raise NotImplementedError

  @property
- def data_sources(self) -> List[data_source.DataSource]:
+ def data_sources(self) -> list[data_source.DataSource]:
  raise NotImplementedError

  def to_batches(
@@ -39,7 +28,7 @@ class DataIngestor(Protocol):
  batch_size: int,
  shuffle: bool = True,
  drop_last_batch: bool = True,
- ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+ ) -> Iterator[dict[str, npt.NDArray[Any]]]:
  raise NotImplementedError

  def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
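
Because DataIngestor is a Protocol, the modernized signatures ripple into any custom ingestor plugged in through DataConnector's ingestor_class parameter or DEFAULT_INGESTOR_CLASS. A structural sketch only; the class name and placeholder bodies below are illustrative, not part of the package:

    from typing import Any, Iterator, Optional, Sequence

    import numpy.typing as npt
    import pandas as pd

    from snowflake import snowpark
    from snowflake.ml.data import data_source


    class MyIngestor:
        """Illustrative ingestor matching the updated protocol signatures."""

        def __init__(self, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> None:
            self._sources = list(sources)

        @classmethod
        def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "MyIngestor":
            return cls(session, sources)

        @property
        def data_sources(self) -> list[data_source.DataSource]:
            return self._sources

        def to_batches(
            self, batch_size: int, shuffle: bool = True, drop_last_batch: bool = True
        ) -> Iterator[dict[str, npt.NDArray[Any]]]:
            raise NotImplementedError  # placeholder; a real ingestor yields column -> ndarray dicts

        def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
            raise NotImplementedError  # placeholder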
snowflake/ml/data/data_source.py
@@ -1,5 +1,5 @@
  import dataclasses
- from typing import List, Optional, Union
+ from typing import Optional, Union


  @dataclasses.dataclass(frozen=True)
@@ -17,7 +17,7 @@ class DatasetInfo:
  fully_qualified_name: str
  version: str
  url: Optional[str] = None
- exclude_cols: Optional[List[str]] = None
+ exclude_cols: Optional[list[str]] = None


  DataSource = Union[DataFrameInfo, DatasetInfo, str]
snowflake/ml/data/ingestor_utils.py
@@ -1,4 +1,4 @@
- from typing import List, Optional
+ from typing import Optional

  import fsspec
  import pyarrow as pa
@@ -33,7 +33,7 @@ def _get_dataframe_cursor(session: snowpark.Session, df_info: data_source.DataFr

  def get_dataframe_result_batches(
  session: snowpark.Session, df_info: data_source.DataFrameInfo
- ) -> List[result_batch.ResultBatch]:
+ ) -> list[result_batch.ResultBatch]:
  """Retrieve the ResultBatches for a given query"""
  cursor = _get_dataframe_cursor(session, df_info)
  batches = cursor.get_result_batches()
@@ -63,7 +63,7 @@ def get_dataset_filesystem(

  def get_dataset_files(
  session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
- ) -> List[str]:
+ ) -> list[str]:
  """Get the list of files in a given Dataset"""
  if filesystem is None:
  filesystem = get_dataset_filesystem(session, ds_info)