snowpark-connect 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
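As a rough illustration of how a reader might reproduce a wheel-to-wheel comparison like this one locally, the sketch below downloads both published wheels with pip and diffs their text members using only the Python standard library. This is not the registry's own tooling; the helper names are made up for the example, and binary members such as the bundled jars are simply skipped.

import difflib
import subprocess
import sys
import tempfile
import zipfile
from pathlib import Path


def download_wheel(version: str, dest: Path) -> Path:
    # Assumes pip is available in the current environment; fetches one wheel, no deps.
    subprocess.run(
        [
            sys.executable, "-m", "pip", "download",
            f"snowpark-connect=={version}",
            "--no-deps", "--only-binary=:all:", "-d", str(dest),
        ],
        check=True,
    )
    return next(dest.glob(f"*{version}*.whl"))


def text_members(wheel: Path) -> dict:
    # Map each UTF-8 member of the wheel to its lines; jars and other binaries are skipped.
    members = {}
    with zipfile.ZipFile(wheel) as zf:
        for name in zf.namelist():
            try:
                members[name] = zf.read(name).decode("utf-8").splitlines(keepends=True)
            except UnicodeDecodeError:
                pass
    return members


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
        old = text_members(download_wheel("0.23.0", tmp_path))
        new = text_members(download_wheel("0.25.0", tmp_path))
        for name in sorted(set(old) | set(new)):
            sys.stdout.writelines(
                difflib.unified_diff(
                    old.get(name, []), new.get(name, []), fromfile=name, tofile=name
                )
            )

Printing only the file names and added/removed line counts from that loop would give a summary equivalent to the "Files changed" table below.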

Potentially problematic release.


This version of snowpark-connect might be problematic.

Files changed (476)
  1. snowflake/snowpark_connect/column_name_handler.py +116 -4
  2. snowflake/snowpark_connect/config.py +13 -0
  3. snowflake/snowpark_connect/constants.py +0 -29
  4. snowflake/snowpark_connect/dataframe_container.py +6 -0
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +56 -1
  6. snowflake/snowpark_connect/expression/function_defaults.py +207 -0
  7. snowflake/snowpark_connect/expression/literal.py +18 -2
  8. snowflake/snowpark_connect/expression/map_cast.py +5 -8
  9. snowflake/snowpark_connect/expression/map_expression.py +10 -1
  10. snowflake/snowpark_connect/expression/map_extension.py +12 -2
  11. snowflake/snowpark_connect/expression/map_sql_expression.py +23 -1
  12. snowflake/snowpark_connect/expression/map_udf.py +26 -8
  13. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +199 -15
  14. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +44 -16
  15. snowflake/snowpark_connect/expression/map_unresolved_function.py +836 -365
  16. snowflake/snowpark_connect/expression/map_unresolved_star.py +3 -2
  17. snowflake/snowpark_connect/hidden_column.py +39 -0
  18. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  19. snowflake/snowpark_connect/includes/jars/{hadoop-client-api-3.3.4.jar → spark-connect-client-jvm_2.12-3.5.6.jar} +0 -0
  20. snowflake/snowpark_connect/relation/map_column_ops.py +18 -36
  21. snowflake/snowpark_connect/relation/map_extension.py +56 -15
  22. snowflake/snowpark_connect/relation/map_join.py +258 -62
  23. snowflake/snowpark_connect/relation/map_row_ops.py +2 -29
  24. snowflake/snowpark_connect/relation/map_sql.py +88 -11
  25. snowflake/snowpark_connect/relation/map_udtf.py +4 -2
  26. snowflake/snowpark_connect/relation/read/map_read.py +3 -3
  27. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +1 -1
  28. snowflake/snowpark_connect/relation/read/map_read_json.py +8 -1
  29. snowflake/snowpark_connect/relation/read/map_read_table.py +1 -9
  30. snowflake/snowpark_connect/relation/read/reader_config.py +3 -1
  31. snowflake/snowpark_connect/relation/read/utils.py +6 -7
  32. snowflake/snowpark_connect/relation/utils.py +1 -170
  33. snowflake/snowpark_connect/relation/write/map_write.py +62 -53
  34. snowflake/snowpark_connect/resources_initializer.py +29 -1
  35. snowflake/snowpark_connect/server.py +18 -3
  36. snowflake/snowpark_connect/type_mapping.py +29 -25
  37. snowflake/snowpark_connect/typed_column.py +14 -0
  38. snowflake/snowpark_connect/utils/artifacts.py +23 -0
  39. snowflake/snowpark_connect/utils/context.py +6 -1
  40. snowflake/snowpark_connect/utils/scala_udf_utils.py +588 -0
  41. snowflake/snowpark_connect/utils/telemetry.py +6 -17
  42. snowflake/snowpark_connect/utils/udf_helper.py +2 -0
  43. snowflake/snowpark_connect/utils/udf_utils.py +38 -7
  44. snowflake/snowpark_connect/utils/udtf_utils.py +17 -3
  45. snowflake/snowpark_connect/version.py +1 -1
  46. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.25.0.dist-info}/METADATA +1 -1
  47. snowpark_connect-0.25.0.dist-info/RECORD +477 -0
  48. snowflake/snowpark_connect/includes/jars/scala-compiler-2.12.18.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/spark-kubernetes_2.12-3.5.6.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/spark-mllib_2.12-3.5.6.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/spark-streaming_2.12-3.5.6.jar +0 -0
  52. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/__init__.py +0 -16
  53. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/test_errors.py +0 -60
  54. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py +0 -306
  55. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/__init__.py +0 -16
  56. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_classification.py +0 -53
  57. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_evaluation.py +0 -50
  58. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_feature.py +0 -43
  59. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_function.py +0 -114
  60. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_pipeline.py +0 -47
  61. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_summarizer.py +0 -43
  62. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_tuning.py +0 -46
  63. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py +0 -238
  64. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py +0 -194
  65. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +0 -156
  66. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py +0 -184
  67. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py +0 -78
  68. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +0 -292
  69. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +0 -50
  70. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +0 -152
  71. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_algorithms.py +0 -456
  72. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_base.py +0 -96
  73. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_dl_util.py +0 -186
  74. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_evaluation.py +0 -77
  75. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_feature.py +0 -401
  76. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_functions.py +0 -528
  77. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_image.py +0 -82
  78. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_linalg.py +0 -409
  79. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_model_cache.py +0 -55
  80. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_param.py +0 -441
  81. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_persistence.py +0 -546
  82. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_pipeline.py +0 -71
  83. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_stat.py +0 -52
  84. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_training_summary.py +0 -494
  85. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_util.py +0 -85
  86. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_wrapper.py +0 -138
  87. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/__init__.py +0 -16
  88. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_basic.py +0 -151
  89. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_nested.py +0 -97
  90. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py +0 -143
  91. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tuning.py +0 -551
  92. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_basic.py +0 -137
  93. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_nested.py +0 -96
  94. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_pipeline.py +0 -142
  95. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/__init__.py +0 -16
  96. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_data_loader.py +0 -137
  97. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_distributor.py +0 -561
  98. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_log_communication.py +0 -172
  99. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/__init__.py +0 -16
  100. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_algorithms.py +0 -353
  101. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_feature.py +0 -192
  102. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_linalg.py +0 -680
  103. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_stat.py +0 -206
  104. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_streaming_algorithms.py +0 -471
  105. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_util.py +0 -108
  106. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/__init__.py +0 -16
  107. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/__init__.py +0 -16
  108. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_any_all.py +0 -177
  109. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_apply_func.py +0 -575
  110. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_binary_ops.py +0 -235
  111. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_combine.py +0 -653
  112. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_compute.py +0 -463
  113. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_corrwith.py +0 -86
  114. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cov.py +0 -151
  115. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cumulative.py +0 -139
  116. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_describe.py +0 -458
  117. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_eval.py +0 -86
  118. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_melt.py +0 -202
  119. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_missing_data.py +0 -520
  120. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_pivot.py +0 -361
  121. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/__init__.py +0 -16
  122. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/__init__.py +0 -16
  123. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_any_all.py +0 -40
  124. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_apply_func.py +0 -42
  125. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_binary_ops.py +0 -40
  126. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_combine.py +0 -37
  127. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py +0 -60
  128. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_corrwith.py +0 -40
  129. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cov.py +0 -40
  130. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cumulative.py +0 -90
  131. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_describe.py +0 -40
  132. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_eval.py +0 -40
  133. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_melt.py +0 -40
  134. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_missing_data.py +0 -42
  135. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +0 -37
  136. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/__init__.py +0 -16
  137. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_base.py +0 -36
  138. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py +0 -42
  139. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py +0 -47
  140. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py +0 -55
  141. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_complex_ops.py +0 -40
  142. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py +0 -47
  143. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py +0 -47
  144. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py +0 -42
  145. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_arithmetic.py +0 -43
  146. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py +0 -47
  147. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_reverse.py +0 -43
  148. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py +0 -47
  149. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py +0 -47
  150. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py +0 -40
  151. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/testing_utils.py +0 -226
  152. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/__init__.py +0 -16
  153. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_align.py +0 -39
  154. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_basic_slow.py +0 -55
  155. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_cov_corrwith.py +0 -39
  156. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_frame.py +0 -39
  157. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_series.py +0 -39
  158. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_index.py +0 -39
  159. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_series.py +0 -39
  160. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_frame.py +0 -43
  161. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_series.py +0 -43
  162. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/__init__.py +0 -16
  163. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_attrs.py +0 -40
  164. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py +0 -39
  165. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_conversion.py +0 -42
  166. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reindexing.py +0 -42
  167. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +0 -37
  168. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_spark.py +0 -40
  169. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_take.py +0 -42
  170. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_time_series.py +0 -48
  171. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_truncate.py +0 -40
  172. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/__init__.py +0 -16
  173. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_aggregate.py +0 -40
  174. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_apply_func.py +0 -41
  175. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_cumulative.py +0 -67
  176. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_describe.py +0 -40
  177. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_groupby.py +0 -55
  178. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_head_tail.py +0 -40
  179. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_index.py +0 -38
  180. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py +0 -55
  181. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_split_apply.py +0 -39
  182. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +0 -38
  183. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/__init__.py +0 -16
  184. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_align.py +0 -40
  185. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py +0 -50
  186. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py +0 -73
  187. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_datetime.py +0 -39
  188. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing.py +0 -40
  189. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reindex.py +0 -40
  190. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_rename.py +0 -40
  191. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reset_index.py +0 -48
  192. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_timedelta.py +0 -39
  193. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/__init__.py +0 -16
  194. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/test_parity_io.py +0 -40
  195. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/__init__.py +0 -16
  196. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot.py +0 -45
  197. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_matplotlib.py +0 -45
  198. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_plotly.py +0 -49
  199. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot.py +0 -37
  200. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_matplotlib.py +0 -53
  201. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_plotly.py +0 -45
  202. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/__init__.py +0 -16
  203. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_all_any.py +0 -38
  204. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_arg_ops.py +0 -37
  205. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_of.py +0 -37
  206. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_type.py +0 -38
  207. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +0 -37
  208. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_conversion.py +0 -40
  209. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_cumulative.py +0 -40
  210. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_index.py +0 -38
  211. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_missing_data.py +0 -40
  212. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_series.py +0 -37
  213. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_sort.py +0 -38
  214. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_stat.py +0 -38
  215. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_categorical.py +0 -66
  216. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_config.py +0 -37
  217. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_csv.py +0 -37
  218. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_conversion.py +0 -42
  219. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_spark_io.py +0 -39
  220. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_default_index.py +0 -49
  221. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ewm.py +0 -37
  222. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_expanding.py +0 -39
  223. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_extension.py +0 -49
  224. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_frame_spark.py +0 -53
  225. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py +0 -43
  226. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexing.py +0 -49
  227. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexops_spark.py +0 -39
  228. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_internal.py +0 -41
  229. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_namespace.py +0 -39
  230. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_numpy_compat.py +0 -60
  231. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames.py +0 -48
  232. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby.py +0 -39
  233. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_expanding.py +0 -44
  234. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_rolling.py +0 -84
  235. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_repr.py +0 -37
  236. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_resample.py +0 -45
  237. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_reshape.py +0 -39
  238. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_rolling.py +0 -39
  239. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_scalars.py +0 -37
  240. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_conversion.py +0 -39
  241. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_datetime.py +0 -39
  242. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_string.py +0 -39
  243. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_spark_functions.py +0 -39
  244. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_sql.py +0 -43
  245. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_stats.py +0 -37
  246. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_typedef.py +0 -36
  247. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_utils.py +0 -37
  248. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_window.py +0 -39
  249. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/__init__.py +0 -16
  250. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_base.py +0 -107
  251. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +0 -224
  252. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +0 -825
  253. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +0 -562
  254. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +0 -368
  255. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +0 -257
  256. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +0 -260
  257. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +0 -178
  258. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_arithmetic.py +0 -184
  259. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +0 -497
  260. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_reverse.py +0 -140
  261. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +0 -354
  262. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +0 -219
  263. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +0 -192
  264. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +0 -228
  265. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/__init__.py +0 -16
  266. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_align.py +0 -118
  267. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_basic_slow.py +0 -198
  268. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_cov_corrwith.py +0 -181
  269. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_frame.py +0 -103
  270. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_series.py +0 -141
  271. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_index.py +0 -109
  272. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_series.py +0 -136
  273. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_frame.py +0 -125
  274. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_series.py +0 -217
  275. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/__init__.py +0 -16
  276. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_attrs.py +0 -384
  277. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_constructor.py +0 -598
  278. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_conversion.py +0 -73
  279. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reindexing.py +0 -869
  280. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reshaping.py +0 -487
  281. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_spark.py +0 -309
  282. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_take.py +0 -156
  283. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_time_series.py +0 -149
  284. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_truncate.py +0 -163
  285. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/__init__.py +0 -16
  286. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_aggregate.py +0 -311
  287. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_apply_func.py +0 -524
  288. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_cumulative.py +0 -419
  289. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_describe.py +0 -144
  290. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_groupby.py +0 -979
  291. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_head_tail.py +0 -234
  292. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_index.py +0 -206
  293. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_missing_data.py +0 -421
  294. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_split_apply.py +0 -187
  295. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_stat.py +0 -397
  296. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/__init__.py +0 -16
  297. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_align.py +0 -100
  298. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_base.py +0 -2743
  299. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_category.py +0 -484
  300. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_datetime.py +0 -276
  301. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_indexing.py +0 -432
  302. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reindex.py +0 -310
  303. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_rename.py +0 -257
  304. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reset_index.py +0 -160
  305. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_timedelta.py +0 -128
  306. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/__init__.py +0 -16
  307. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/test_io.py +0 -137
  308. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/__init__.py +0 -16
  309. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot.py +0 -170
  310. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +0 -547
  311. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +0 -285
  312. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot.py +0 -106
  313. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +0 -409
  314. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +0 -247
  315. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/__init__.py +0 -16
  316. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_all_any.py +0 -105
  317. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_arg_ops.py +0 -197
  318. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_of.py +0 -137
  319. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_type.py +0 -227
  320. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_compute.py +0 -634
  321. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_conversion.py +0 -88
  322. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_cumulative.py +0 -139
  323. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_index.py +0 -475
  324. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_missing_data.py +0 -265
  325. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_series.py +0 -818
  326. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_sort.py +0 -162
  327. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_stat.py +0 -780
  328. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_categorical.py +0 -741
  329. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_config.py +0 -160
  330. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_csv.py +0 -453
  331. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_conversion.py +0 -281
  332. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_spark_io.py +0 -487
  333. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_default_index.py +0 -109
  334. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ewm.py +0 -434
  335. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_expanding.py +0 -253
  336. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_extension.py +0 -152
  337. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_frame_spark.py +0 -162
  338. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_generic_functions.py +0 -234
  339. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexing.py +0 -1339
  340. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexops_spark.py +0 -82
  341. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_internal.py +0 -124
  342. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_namespace.py +0 -638
  343. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_numpy_compat.py +0 -200
  344. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +0 -1355
  345. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +0 -655
  346. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +0 -113
  347. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +0 -118
  348. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_repr.py +0 -192
  349. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_resample.py +0 -346
  350. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_reshape.py +0 -495
  351. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_rolling.py +0 -263
  352. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_scalars.py +0 -59
  353. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_conversion.py +0 -85
  354. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_datetime.py +0 -364
  355. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_string.py +0 -362
  356. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_spark_functions.py +0 -46
  357. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_sql.py +0 -123
  358. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_stats.py +0 -581
  359. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_typedef.py +0 -447
  360. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_utils.py +0 -301
  361. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_window.py +0 -465
  362. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/__init__.py +0 -16
  363. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/test_resources.py +0 -83
  364. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/__init__.py +0 -16
  365. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/__init__.py +0 -16
  366. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/__init__.py +0 -16
  367. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_artifact.py +0 -420
  368. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_client.py +0 -358
  369. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/__init__.py +0 -16
  370. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach.py +0 -36
  371. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +0 -44
  372. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +0 -116
  373. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +0 -35
  374. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py +0 -3612
  375. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_column.py +0 -1042
  376. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_function.py +0 -2381
  377. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_plan.py +0 -1060
  378. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow.py +0 -163
  379. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_map.py +0 -38
  380. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +0 -48
  381. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_catalog.py +0 -36
  382. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_column.py +0 -55
  383. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_conf.py +0 -36
  384. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_dataframe.py +0 -96
  385. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_datasources.py +0 -44
  386. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_errors.py +0 -36
  387. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_functions.py +0 -59
  388. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_group.py +0 -36
  389. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py +0 -59
  390. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py +0 -74
  391. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py +0 -62
  392. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_map.py +0 -58
  393. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +0 -70
  394. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py +0 -50
  395. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +0 -68
  396. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py +0 -40
  397. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_readwriter.py +0 -46
  398. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_serde.py +0 -44
  399. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_types.py +0 -100
  400. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udf.py +0 -100
  401. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udtf.py +0 -163
  402. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_session.py +0 -181
  403. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_utils.py +0 -42
  404. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/__init__.py +0 -16
  405. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +0 -623
  406. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +0 -869
  407. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +0 -342
  408. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_map.py +0 -436
  409. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf.py +0 -363
  410. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +0 -592
  411. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +0 -1503
  412. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +0 -392
  413. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py +0 -375
  414. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +0 -411
  415. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/__init__.py +0 -16
  416. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming.py +0 -401
  417. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +0 -295
  418. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +0 -106
  419. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_listener.py +0 -558
  420. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow.py +0 -1346
  421. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_map.py +0 -182
  422. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_python_udf.py +0 -202
  423. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_catalog.py +0 -503
  424. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_column.py +0 -225
  425. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_conf.py +0 -83
  426. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_context.py +0 -201
  427. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_dataframe.py +0 -1931
  428. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_datasources.py +0 -256
  429. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_errors.py +0 -69
  430. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_functions.py +0 -1349
  431. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_group.py +0 -53
  432. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_pandas_sqlmetrics.py +0 -68
  433. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_readwriter.py +0 -283
  434. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_serde.py +0 -155
  435. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_session.py +0 -412
  436. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_types.py +0 -1581
  437. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf.py +0 -961
  438. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf_profiler.py +0 -165
  439. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udtf.py +0 -1456
  440. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_utils.py +0 -1686
  441. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/__init__.py +0 -16
  442. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_context.py +0 -184
  443. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_dstream.py +0 -706
  444. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_kinesis.py +0 -118
  445. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_listener.py +0 -160
  446. snowflake/snowpark_connect/includes/python/pyspark/tests/__init__.py +0 -16
  447. snowflake/snowpark_connect/includes/python/pyspark/tests/test_appsubmit.py +0 -306
  448. snowflake/snowpark_connect/includes/python/pyspark/tests/test_broadcast.py +0 -196
  449. snowflake/snowpark_connect/includes/python/pyspark/tests/test_conf.py +0 -44
  450. snowflake/snowpark_connect/includes/python/pyspark/tests/test_context.py +0 -346
  451. snowflake/snowpark_connect/includes/python/pyspark/tests/test_daemon.py +0 -89
  452. snowflake/snowpark_connect/includes/python/pyspark/tests/test_install_spark.py +0 -124
  453. snowflake/snowpark_connect/includes/python/pyspark/tests/test_join.py +0 -69
  454. snowflake/snowpark_connect/includes/python/pyspark/tests/test_memory_profiler.py +0 -167
  455. snowflake/snowpark_connect/includes/python/pyspark/tests/test_pin_thread.py +0 -194
  456. snowflake/snowpark_connect/includes/python/pyspark/tests/test_profiler.py +0 -168
  457. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rdd.py +0 -939
  458. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddbarrier.py +0 -52
  459. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddsampler.py +0 -66
  460. snowflake/snowpark_connect/includes/python/pyspark/tests/test_readwrite.py +0 -368
  461. snowflake/snowpark_connect/includes/python/pyspark/tests/test_serializers.py +0 -257
  462. snowflake/snowpark_connect/includes/python/pyspark/tests/test_shuffle.py +0 -267
  463. snowflake/snowpark_connect/includes/python/pyspark/tests/test_stage_sched.py +0 -153
  464. snowflake/snowpark_connect/includes/python/pyspark/tests/test_statcounter.py +0 -130
  465. snowflake/snowpark_connect/includes/python/pyspark/tests/test_taskcontext.py +0 -350
  466. snowflake/snowpark_connect/includes/python/pyspark/tests/test_util.py +0 -97
  467. snowflake/snowpark_connect/includes/python/pyspark/tests/test_worker.py +0 -271
  468. snowpark_connect-0.23.0.dist-info/RECORD +0 -893
  469. {snowpark_connect-0.23.0.data → snowpark_connect-0.25.0.data}/scripts/snowpark-connect +0 -0
  470. {snowpark_connect-0.23.0.data → snowpark_connect-0.25.0.data}/scripts/snowpark-session +0 -0
  471. {snowpark_connect-0.23.0.data → snowpark_connect-0.25.0.data}/scripts/snowpark-submit +0 -0
  472. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.25.0.dist-info}/WHEEL +0 -0
  473. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.25.0.dist-info}/licenses/LICENSE-binary +0 -0
  474. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.25.0.dist-info}/licenses/LICENSE.txt +0 -0
  475. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.25.0.dist-info}/licenses/NOTICE-binary +0 -0
  476. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.25.0.dist-info}/top_level.txt +0 -0
--- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ /dev/null
@@ -1,3612 +0,0 @@
1
- #
2
- # Licensed to the Apache Software Foundation (ASF) under one or more
3
- # contributor license agreements. See the NOTICE file distributed with
4
- # this work for additional information regarding copyright ownership.
5
- # The ASF licenses this file to You under the Apache License, Version 2.0
6
- # (the "License"); you may not use this file except in compliance with
7
- # the License. You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- #
17
-
18
- import array
19
- import datetime
20
- import os
21
- import unittest
22
- import random
23
- import shutil
24
- import string
25
- import tempfile
26
- import uuid
27
- from collections import defaultdict
28
-
29
- from pyspark.errors import (
30
- PySparkAttributeError,
31
- PySparkTypeError,
32
- PySparkException,
33
- PySparkValueError,
34
- )
35
- from pyspark.errors.exceptions.base import SessionNotSameException
36
- from pyspark.sql import SparkSession as PySparkSession, Row
37
- from pyspark.sql.types import (
38
- StructType,
39
- StructField,
40
- LongType,
41
- StringType,
42
- IntegerType,
43
- MapType,
44
- ArrayType,
45
- Row,
46
- )
47
-
48
- from pyspark.testing.objects import (
49
- MyObject,
50
- PythonOnlyUDT,
51
- ExamplePoint,
52
- PythonOnlyPoint,
53
- )
54
- from pyspark.testing.sqlutils import SQLTestUtils
55
- from pyspark.testing.connectutils import (
56
- should_test_connect,
57
- ReusedConnectTestCase,
58
- connect_requirement_message,
59
- )
60
- from pyspark.testing.pandasutils import PandasOnSparkTestUtils
61
- from pyspark.errors.exceptions.connect import (
62
- AnalysisException,
63
- ParseException,
64
- SparkConnectException,
65
- )
66
-
67
- if should_test_connect:
68
- import grpc
69
- import pandas as pd
70
- import numpy as np
71
- from pyspark.sql.connect.proto import Expression as ProtoExpression
72
- from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
73
- from pyspark.sql.connect.client import ChannelBuilder
74
- from pyspark.sql.connect.column import Column
75
- from pyspark.sql.connect.readwriter import DataFrameWriterV2
76
- from pyspark.sql.dataframe import DataFrame
77
- from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
78
- from pyspark.sql import functions as SF
79
- from pyspark.sql.connect import functions as CF
80
- from pyspark.sql.connect.client.core import Retrying, SparkConnectClient
81
-
82
-
83
- @unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access")
84
- class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils):
85
- """Parent test fixture class for all Spark Connect related
86
- test cases."""
87
-
88
- @classmethod
89
- def setUpClass(cls):
90
- super(SparkConnectSQLTestCase, cls).setUpClass()
91
- # Disable the shared namespace so pyspark.sql.functions, etc point the regular
92
- # PySpark libraries.
93
- os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
94
-
95
- cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session.
96
- cls.spark = PySparkSession._instantiatedSession
97
- assert cls.spark is not None
98
-
99
- cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
100
- cls.testDataStr = [Row(key=str(i)) for i in range(100)]
101
- cls.df = cls.spark.sparkContext.parallelize(cls.testData).toDF()
102
- cls.df_text = cls.spark.sparkContext.parallelize(cls.testDataStr).toDF()
103
-
104
- cls.tbl_name = "test_connect_basic_table_1"
105
- cls.tbl_name2 = "test_connect_basic_table_2"
106
- cls.tbl_name3 = "test_connect_basic_table_3"
107
- cls.tbl_name4 = "test_connect_basic_table_4"
108
- cls.tbl_name_empty = "test_connect_basic_table_empty"
109
-
110
- # Cleanup test data
111
- cls.spark_connect_clean_up_test_data()
112
- # Load test data
113
- cls.spark_connect_load_test_data()
114
-
115
- @classmethod
116
- def tearDownClass(cls):
117
- try:
118
- cls.spark_connect_clean_up_test_data()
119
- # Stopping Spark Connect closes the session in JVM at the server.
120
- cls.spark = cls.connect
121
- del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
122
- finally:
123
- super(SparkConnectSQLTestCase, cls).tearDownClass()
124
-
125
- @classmethod
126
- def spark_connect_load_test_data(cls):
127
- df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
128
- # Since we might create multiple Spark sessions, we need to create global temporary view
129
- # that is specifically maintained in the "global_temp" schema.
130
- df.write.saveAsTable(cls.tbl_name)
131
- df2 = cls.spark.createDataFrame(
132
- [(x, f"{x}", 2 * x) for x in range(100)], ["col1", "col2", "col3"]
133
- )
134
- df2.write.saveAsTable(cls.tbl_name2)
135
- df3 = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "test\n_column"])
136
- df3.write.saveAsTable(cls.tbl_name3)
137
- df4 = cls.spark.createDataFrame(
138
- [(x, {"a": x}, [x, x * 2]) for x in range(100)], ["id", "map_column", "array_column"]
139
- )
140
- df4.write.saveAsTable(cls.tbl_name4)
141
- empty_table_schema = StructType(
142
- [
143
- StructField("firstname", StringType(), True),
144
- StructField("middlename", StringType(), True),
145
- StructField("lastname", StringType(), True),
146
- ]
147
- )
148
- emptyRDD = cls.spark.sparkContext.emptyRDD()
149
- empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema)
150
- empty_df.write.saveAsTable(cls.tbl_name_empty)
151
-
152
- @classmethod
153
- def spark_connect_clean_up_test_data(cls):
154
- cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
155
- cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name2))
156
- cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name3))
157
- cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name4))
158
- cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))
159
-
160
-
161
- class SparkConnectBasicTests(SparkConnectSQLTestCase):
162
- def test_df_getattr_behavior(self):
163
- cdf = self.connect.range(10)
164
- sdf = self.spark.range(10)
165
-
166
- sdf._simple_extension = 10
167
- cdf._simple_extension = 10
168
-
169
- self.assertEqual(sdf._simple_extension, cdf._simple_extension)
170
- self.assertEqual(type(sdf._simple_extension), type(cdf._simple_extension))
171
-
172
- self.assertTrue(hasattr(cdf, "_simple_extension"))
173
- self.assertFalse(hasattr(cdf, "_simple_extension_does_not_exsit"))
174
-
175
- def test_df_get_item(self):
176
- # SPARK-41779: test __getitem__
177
-
178
- query = """
179
- SELECT * FROM VALUES
180
- (true, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
181
- AS tab(a, b, c)
182
- """
183
-
184
- # +-----+----+----+
185
- # | a| b| c|
186
- # +-----+----+----+
187
- # | true| 1|NULL|
188
- # |false|NULL| 2.0|
189
- # | NULL| 3| 3.0|
190
- # +-----+----+----+
191
-
192
- cdf = self.connect.sql(query)
193
- sdf = self.spark.sql(query)
194
-
195
- # filter
196
- self.assert_eq(
197
- cdf[cdf.a].toPandas(),
198
- sdf[sdf.a].toPandas(),
199
- )
200
- self.assert_eq(
201
- cdf[cdf.b.isin(2, 3)].toPandas(),
202
- sdf[sdf.b.isin(2, 3)].toPandas(),
203
- )
204
- self.assert_eq(
205
- cdf[cdf.c > 1.5].toPandas(),
206
- sdf[sdf.c > 1.5].toPandas(),
207
- )
208
-
209
- # select
210
- self.assert_eq(
211
- cdf[[cdf.a, "b", cdf.c]].toPandas(),
212
- sdf[[sdf.a, "b", sdf.c]].toPandas(),
213
- )
214
- self.assert_eq(
215
- cdf[(cdf.a, "b", cdf.c)].toPandas(),
216
- sdf[(sdf.a, "b", sdf.c)].toPandas(),
217
- )
218
-
219
- # select by index
220
- self.assertTrue(isinstance(cdf[0], Column))
221
- self.assertTrue(isinstance(cdf[1], Column))
222
- self.assertTrue(isinstance(cdf[2], Column))
223
-
224
- self.assert_eq(
225
- cdf[[cdf[0], cdf[1], cdf[2]]].toPandas(),
226
- sdf[[sdf[0], sdf[1], sdf[2]]].toPandas(),
227
- )
228
-
229
- # check error
230
- with self.assertRaises(PySparkTypeError) as pe:
231
- cdf[1.5]
232
-
233
- self.check_error(
234
- exception=pe.exception,
235
- error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
236
- message_parameters={
237
- "arg_name": "item",
238
- "arg_type": "float",
239
- },
240
- )
241
-
242
- with self.assertRaises(PySparkTypeError) as pe:
243
- cdf[None]
244
-
245
- self.check_error(
246
- exception=pe.exception,
247
- error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
248
- message_parameters={
249
- "arg_name": "item",
250
- "arg_type": "NoneType",
251
- },
252
- )
253
-
254
- with self.assertRaises(PySparkTypeError) as pe:
255
- cdf[cdf]
256
-
257
- self.check_error(
258
- exception=pe.exception,
259
- error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
260
- message_parameters={
261
- "arg_name": "item",
262
- "arg_type": "DataFrame",
263
- },
264
- )
265
-
266
- def test_error_handling(self):
267
- # SPARK-41533 Proper error handling for Spark Connect
268
- df = self.connect.range(10).select("id2")
269
- with self.assertRaises(AnalysisException):
270
- df.collect()
271
-
272
- def test_simple_read(self):
273
- df = self.connect.read.table(self.tbl_name)
274
- data = df.limit(10).toPandas()
275
- # Check that the limit is applied
276
- self.assertEqual(len(data.index), 10)
277
-
278
- def test_json(self):
279
- with tempfile.TemporaryDirectory() as d:
280
- # Write a DataFrame into a JSON file
281
- self.spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
282
- "overwrite"
283
- ).format("json").save(d)
284
- # Read the JSON file as a DataFrame.
285
- self.assert_eq(self.connect.read.json(d).toPandas(), self.spark.read.json(d).toPandas())
286
-
287
- for schema in [
288
- "age INT, name STRING",
289
- StructType(
290
- [
291
- StructField("age", IntegerType()),
292
- StructField("name", StringType()),
293
- ]
294
- ),
295
- ]:
296
- self.assert_eq(
297
- self.connect.read.json(path=d, schema=schema).toPandas(),
298
- self.spark.read.json(path=d, schema=schema).toPandas(),
299
- )
300
-
301
- self.assert_eq(
302
- self.connect.read.json(path=d, primitivesAsString=True).toPandas(),
303
- self.spark.read.json(path=d, primitivesAsString=True).toPandas(),
304
- )
305
-
306
- def test_parquet(self):
307
- # SPARK-41445: Implement DataFrameReader.parquet
308
- with tempfile.TemporaryDirectory() as d:
309
- # Write a DataFrame into a Parquet file
310
- self.spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
311
- "overwrite"
312
- ).format("parquet").save(d)
313
- # Read the Parquet file as a DataFrame.
314
- self.assert_eq(
315
- self.connect.read.parquet(d).toPandas(), self.spark.read.parquet(d).toPandas()
316
- )
317
-
318
- def test_text(self):
319
- # SPARK-41849: Implement DataFrameReader.text
320
- with tempfile.TemporaryDirectory() as d:
321
- # Write a DataFrame into a text file
322
- self.spark.createDataFrame(
323
- [{"name": "Sandeep Singh"}, {"name": "Hyukjin Kwon"}]
324
- ).write.mode("overwrite").format("text").save(d)
325
- # Read the text file as a DataFrame.
326
- self.assert_eq(self.connect.read.text(d).toPandas(), self.spark.read.text(d).toPandas())
327
-
328
- def test_csv(self):
329
- # SPARK-42011: Implement DataFrameReader.csv
330
- with tempfile.TemporaryDirectory() as d:
331
- # Write a DataFrame into a CSV file
332
- self.spark.createDataFrame(
333
- [{"name": "Sandeep Singh"}, {"name": "Hyukjin Kwon"}]
334
- ).write.mode("overwrite").format("csv").save(d)
335
- # Read the CSV file as a DataFrame.
336
- self.assert_eq(self.connect.read.csv(d).toPandas(), self.spark.read.csv(d).toPandas())
337
-
338
- def test_multi_paths(self):
339
- # SPARK-42041: DataFrameReader should support list of paths
340
-
341
- with tempfile.TemporaryDirectory() as d:
342
- text_files = []
343
- for i in range(0, 3):
344
- text_file = f"{d}/text-{i}.text"
345
- shutil.copyfile("python/test_support/sql/text-test.txt", text_file)
346
- text_files.append(text_file)
347
-
348
- self.assertEqual(
349
- self.connect.read.text(text_files).collect(),
350
- self.spark.read.text(text_files).collect(),
351
- )
352
-
353
- with tempfile.TemporaryDirectory() as d:
354
- json_files = []
355
- for i in range(0, 5):
356
- json_file = f"{d}/json-{i}.json"
357
- shutil.copyfile("python/test_support/sql/people.json", json_file)
358
- json_files.append(json_file)
359
-
360
- self.assertEqual(
361
- self.connect.read.json(json_files).collect(),
362
- self.spark.read.json(json_files).collect(),
363
- )
364
-
365
- def test_orc(self):
366
- # SPARK-42012: Implement DataFrameReader.orc
367
- with tempfile.TemporaryDirectory() as d:
368
- # Write a DataFrame into an ORC file
369
- self.spark.createDataFrame(
370
- [{"name": "Sandeep Singh"}, {"name": "Hyukjin Kwon"}]
371
- ).write.mode("overwrite").format("orc").save(d)
372
- # Read the ORC file as a DataFrame.
373
- self.assert_eq(self.connect.read.orc(d).toPandas(), self.spark.read.orc(d).toPandas())
374
-
375
- def test_join_condition_column_list_columns(self):
376
- left_connect_df = self.connect.read.table(self.tbl_name)
377
- right_connect_df = self.connect.read.table(self.tbl_name2)
378
- left_spark_df = self.spark.read.table(self.tbl_name)
379
- right_spark_df = self.spark.read.table(self.tbl_name2)
380
- joined_plan = left_connect_df.join(
381
- other=right_connect_df, on=left_connect_df.id == right_connect_df.col1, how="inner"
382
- )
383
- joined_plan2 = left_spark_df.join(
384
- other=right_spark_df, on=left_spark_df.id == right_spark_df.col1, how="inner"
385
- )
386
- self.assert_eq(joined_plan.toPandas(), joined_plan2.toPandas())
387
-
388
- joined_plan3 = left_connect_df.join(
389
- other=right_connect_df,
390
- on=[
391
- left_connect_df.id == right_connect_df.col1,
392
- left_connect_df.name == right_connect_df.col2,
393
- ],
394
- how="inner",
395
- )
396
- joined_plan4 = left_spark_df.join(
397
- other=right_spark_df,
398
- on=[left_spark_df.id == right_spark_df.col1, left_spark_df.name == right_spark_df.col2],
399
- how="inner",
400
- )
401
- self.assert_eq(joined_plan3.toPandas(), joined_plan4.toPandas())
402
-
403
- def test_join_ambiguous_cols(self):
404
- # SPARK-41812: test join with ambiguous columns
405
- data1 = [Row(id=1, value="foo"), Row(id=2, value=None)]
406
- cdf1 = self.connect.createDataFrame(data1)
407
- sdf1 = self.spark.createDataFrame(data1)
408
-
409
- data2 = [Row(value="bar"), Row(value=None), Row(value="foo")]
410
- cdf2 = self.connect.createDataFrame(data2)
411
- sdf2 = self.spark.createDataFrame(data2)
412
-
413
- cdf3 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"])
414
- sdf3 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"])
415
-
416
- self.assertEqual(cdf3.schema, sdf3.schema)
417
- self.assertEqual(cdf3.collect(), sdf3.collect())
418
-
419
- cdf4 = cdf1.join(cdf2, cdf1["value"].eqNullSafe(cdf2["value"]))
420
- sdf4 = sdf1.join(sdf2, sdf1["value"].eqNullSafe(sdf2["value"]))
421
-
422
- self.assertEqual(cdf4.schema, sdf4.schema)
423
- self.assertEqual(cdf4.collect(), sdf4.collect())
424
-
425
- cdf5 = cdf1.join(
426
- cdf2, (cdf1["value"] == cdf2["value"]) & (cdf1["value"].eqNullSafe(cdf2["value"]))
427
- )
428
- sdf5 = sdf1.join(
429
- sdf2, (sdf1["value"] == sdf2["value"]) & (sdf1["value"].eqNullSafe(sdf2["value"]))
430
- )
431
-
432
- self.assertEqual(cdf5.schema, sdf5.schema)
433
- self.assertEqual(cdf5.collect(), sdf5.collect())
434
-
435
- cdf6 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf1.value)
436
- sdf6 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf1.value)
437
-
438
- self.assertEqual(cdf6.schema, sdf6.schema)
439
- self.assertEqual(cdf6.collect(), sdf6.collect())
440
-
441
- cdf7 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf2.value)
442
- sdf7 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf2.value)
443
-
444
- self.assertEqual(cdf7.schema, sdf7.schema)
445
- self.assertEqual(cdf7.collect(), sdf7.collect())
446
-
447
- def test_invalid_column(self):
448
- # SPARK-41812: fail df1.select(df2.col)
449
- data1 = [Row(a=1, b=2, c=3)]
450
- cdf1 = self.connect.createDataFrame(data1)
451
-
452
- data2 = [Row(a=2, b=0)]
453
- cdf2 = self.connect.createDataFrame(data2)
454
-
455
- with self.assertRaises(AnalysisException):
456
- cdf1.select(cdf2.a).schema
457
-
458
- with self.assertRaises(AnalysisException):
459
- cdf2.withColumn("x", cdf1.a + 1).schema
460
-
461
- with self.assertRaisesRegex(AnalysisException, "attribute.*missing"):
462
- cdf3 = cdf1.select(cdf1.a)
463
- cdf3.select(cdf1.b).schema
464
-
465
- def test_collect(self):
466
- cdf = self.connect.read.table(self.tbl_name)
467
- sdf = self.spark.read.table(self.tbl_name)
468
-
469
- data = cdf.limit(10).collect()
470
- self.assertEqual(len(data), 10)
471
- # Check Row has schema column names.
472
- self.assertTrue("name" in data[0])
473
- self.assertTrue("id" in data[0])
474
-
475
- cdf = cdf.select(
476
- CF.log("id"), CF.log("id"), CF.struct("id", "name"), CF.struct("id", "name")
477
- ).limit(10)
478
- sdf = sdf.select(
479
- SF.log("id"), SF.log("id"), SF.struct("id", "name"), SF.struct("id", "name")
480
- ).limit(10)
481
-
482
- self.assertEqual(
483
- cdf.collect(),
484
- sdf.collect(),
485
- )
486
-
487
- def test_collect_timestamp(self):
488
- query = """
489
- SELECT * FROM VALUES
490
- (TIMESTAMP('2022-12-25 10:30:00'), 1),
491
- (TIMESTAMP('2022-12-25 10:31:00'), 2),
492
- (TIMESTAMP('2022-12-25 10:32:00'), 1),
493
- (TIMESTAMP('2022-12-25 10:33:00'), 2),
494
- (TIMESTAMP('2022-12-26 09:30:00'), 1),
495
- (TIMESTAMP('2022-12-26 09:35:00'), 3)
496
- AS tab(date, val)
497
- """
498
-
499
- cdf = self.connect.sql(query)
500
- sdf = self.spark.sql(query)
501
-
502
- self.assertEqual(cdf.schema, sdf.schema)
503
-
504
- self.assertEqual(cdf.collect(), sdf.collect())
505
-
506
- self.assertEqual(
507
- cdf.select(CF.date_trunc("year", cdf.date).alias("year")).collect(),
508
- sdf.select(SF.date_trunc("year", sdf.date).alias("year")).collect(),
509
- )
510
-
511
- def test_with_columns_renamed(self):
512
- # SPARK-41312: test DataFrame.withColumnsRenamed()
513
- self.assertEqual(
514
- self.connect.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema,
515
- self.spark.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema,
516
- )
517
- self.assertEqual(
518
- self.connect.read.table(self.tbl_name)
519
- .withColumnsRenamed({"id": "id_new", "name": "name_new"})
520
- .schema,
521
- self.spark.read.table(self.tbl_name)
522
- .withColumnsRenamed({"id": "id_new", "name": "name_new"})
523
- .schema,
524
- )
525
-
526
- def test_with_local_data(self):
527
- """SPARK-41114: Test creating a dataframe using local data"""
528
- pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
529
- df = self.connect.createDataFrame(pdf)
530
- rows = df.filter(df.a == CF.lit(3)).collect()
531
- self.assertTrue(len(rows) == 1)
532
- self.assertEqual(rows[0][0], 3)
533
- self.assertEqual(rows[0][1], "c")
534
-
535
- # Check correct behavior for empty DataFrame
536
- pdf = pd.DataFrame({"a": []})
537
- with self.assertRaises(ValueError):
538
- self.connect.createDataFrame(pdf)
539
-
540
- def test_with_local_ndarray(self):
541
- """SPARK-41446: Test creating a dataframe using local list"""
542
- data = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
543
-
544
- sdf = self.spark.createDataFrame(data)
545
- cdf = self.connect.createDataFrame(data)
546
- self.assertEqual(sdf.schema, cdf.schema)
547
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
548
-
549
- for schema in [
550
- StructType(
551
- [
552
- StructField("col1", IntegerType(), True),
553
- StructField("col2", IntegerType(), True),
554
- StructField("col3", IntegerType(), True),
555
- StructField("col4", IntegerType(), True),
556
- ]
557
- ),
558
- "struct<col1 int, col2 int, col3 int, col4 int>",
559
- "col1 int, col2 int, col3 int, col4 int",
560
- "col1 int, col2 long, col3 string, col4 long",
561
- "col1 int, col2 string, col3 short, col4 long",
562
- ["a", "b", "c", "d"],
563
- ("x1", "x2", "x3", "x4"),
564
- ]:
565
- with self.subTest(schema=schema):
566
- sdf = self.spark.createDataFrame(data, schema=schema)
567
- cdf = self.connect.createDataFrame(data, schema=schema)
568
-
569
- self.assertEqual(sdf.schema, cdf.schema)
570
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
571
-
572
- with self.assertRaises(PySparkValueError) as pe:
573
- self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
574
-
575
- self.check_error(
576
- exception=pe.exception,
577
- error_class="AXIS_LENGTH_MISMATCH",
578
- message_parameters={"expected_length": "5", "actual_length": "4"},
579
- )
580
-
581
- with self.assertRaises(ParseException):
582
- self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
583
-
584
- with self.assertRaises(PySparkValueError) as pe:
585
- self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
586
-
587
- self.check_error(
588
- exception=pe.exception,
589
- error_class="AXIS_LENGTH_MISMATCH",
590
- message_parameters={"expected_length": "3", "actual_length": "4"},
591
- )
592
-
593
- # test 1 dim ndarray
594
- data = np.array([1.0, 2.0, np.nan, 3.0, 4.0, float("NaN"), 5.0])
595
- self.assertEqual(data.ndim, 1)
596
-
597
- sdf = self.spark.createDataFrame(data)
598
- cdf = self.connect.createDataFrame(data)
599
- self.assertEqual(sdf.schema, cdf.schema)
600
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
601
-
602
- def test_with_local_list(self):
603
- """SPARK-41446: Test creating a dataframe using local list"""
604
- data = [[1, 2, 3, 4]]
605
-
606
- sdf = self.spark.createDataFrame(data)
607
- cdf = self.connect.createDataFrame(data)
608
- self.assertEqual(sdf.schema, cdf.schema)
609
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
610
-
611
- for schema in [
612
- "struct<col1 int, col2 int, col3 int, col4 int>",
613
- "col1 int, col2 int, col3 int, col4 int",
614
- "col1 int, col2 long, col3 string, col4 long",
615
- "col1 int, col2 string, col3 short, col4 long",
616
- ["a", "b", "c", "d"],
617
- ("x1", "x2", "x3", "x4"),
618
- ]:
619
- sdf = self.spark.createDataFrame(data, schema=schema)
620
- cdf = self.connect.createDataFrame(data, schema=schema)
621
-
622
- self.assertEqual(sdf.schema, cdf.schema)
623
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
624
-
625
- with self.assertRaises(PySparkValueError) as pe:
626
- self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
627
-
628
- self.check_error(
629
- exception=pe.exception,
630
- error_class="AXIS_LENGTH_MISMATCH",
631
- message_parameters={"expected_length": "5", "actual_length": "4"},
632
- )
633
-
634
- with self.assertRaises(ParseException):
635
- self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
636
-
637
- with self.assertRaisesRegex(
638
- ValueError,
639
- "Length mismatch: Expected axis has 3 elements, new values have 4 elements",
640
- ):
641
- self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
642
-
643
- def test_with_local_rows(self):
644
- # SPARK-41789, SPARK-41810: Test creating a dataframe with list of rows and dictionaries
645
- rows = [
646
- Row(course="dotNET", year=2012, earnings=10000),
647
- Row(course="Java", year=2012, earnings=20000),
648
- Row(course="dotNET", year=2012, earnings=5000),
649
- Row(course="dotNET", year=2013, earnings=48000),
650
- Row(course="Java", year=2013, earnings=30000),
651
- Row(course="Scala", year=2022, earnings=None),
652
- ]
653
- dicts = [row.asDict() for row in rows]
654
-
655
- for data in [rows, dicts]:
656
- sdf = self.spark.createDataFrame(data)
657
- cdf = self.connect.createDataFrame(data)
658
-
659
- self.assertEqual(sdf.schema, cdf.schema)
660
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
661
-
662
- # test with rename
663
- sdf = self.spark.createDataFrame(data, schema=["a", "b", "c"])
664
- cdf = self.connect.createDataFrame(data, schema=["a", "b", "c"])
665
-
666
- self.assertEqual(sdf.schema, cdf.schema)
667
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
668
-
669
- def test_streaming_local_relation(self):
670
- threshold_conf = "spark.sql.session.localRelationCacheThreshold"
671
- old_threshold = self.connect.conf.get(threshold_conf)
672
- threshold = 1024 * 1024
673
- self.connect.conf.set(threshold_conf, threshold)
674
- try:
675
- suffix = "abcdef"
676
- letters = string.ascii_lowercase
677
- str = "".join(random.choice(letters) for i in range(threshold)) + suffix
678
- data = [[0, str], [1, str]]
679
- for i in range(0, 2):
680
- cdf = self.connect.createDataFrame(data, ["a", "b"])
681
- self.assert_eq(cdf.count(), len(data))
682
- self.assert_eq(cdf.filter(f"endsWith(b, '{suffix}')").isEmpty(), False)
683
- finally:
684
- self.connect.conf.set(threshold_conf, old_threshold)
685
-
686
- def test_with_atom_type(self):
687
- for data in [[(1), (2), (3)], [1, 2, 3]]:
688
- for schema in ["long", "int", "short"]:
689
- sdf = self.spark.createDataFrame(data, schema=schema)
690
- cdf = self.connect.createDataFrame(data, schema=schema)
691
-
692
- self.assertEqual(sdf.schema, cdf.schema)
693
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
694
-
695
- def test_with_none_and_nan(self):
696
- # SPARK-41855: make createDataFrame support None and NaN
697
- # SPARK-41814: test with eqNullSafe
698
- data1 = [Row(id=1, value=float("NaN")), Row(id=2, value=42.0), Row(id=3, value=None)]
699
- data2 = [Row(id=1, value=np.nan), Row(id=2, value=42.0), Row(id=3, value=None)]
700
- data3 = [
701
- {"id": 1, "value": float("NaN")},
702
- {"id": 2, "value": 42.0},
703
- {"id": 3, "value": None},
704
- ]
705
- data4 = [{"id": 1, "value": np.nan}, {"id": 2, "value": 42.0}, {"id": 3, "value": None}]
706
- data5 = [(1, float("NaN")), (2, 42.0), (3, None)]
707
- data6 = [(1, np.nan), (2, 42.0), (3, None)]
708
- data7 = np.array([[1, float("NaN")], [2, 42.0], [3, None]])
709
- data8 = np.array([[1, np.nan], [2, 42.0], [3, None]])
710
-
711
- # +---+-----+
712
- # | id|value|
713
- # +---+-----+
714
- # | 1| NaN|
715
- # | 2| 42.0|
716
- # | 3| NULL|
717
- # +---+-----+
718
-
719
- for data in [data1, data2, data3, data4, data5, data6, data7, data8]:
720
- if isinstance(data[0], (Row, dict)):
721
- # data1, data2, data3, data4
722
- cdf = self.connect.createDataFrame(data)
723
- sdf = self.spark.createDataFrame(data)
724
- else:
725
- # data5, data6, data7, data8
726
- cdf = self.connect.createDataFrame(data, schema=["id", "value"])
727
- sdf = self.spark.createDataFrame(data, schema=["id", "value"])
728
-
729
- self.assert_eq(cdf.toPandas(), sdf.toPandas())
730
-
731
- self.assert_eq(
732
- cdf.select(
733
- cdf["value"].eqNullSafe(None),
734
- cdf["value"].eqNullSafe(float("NaN")),
735
- cdf["value"].eqNullSafe(42.0),
736
- ).toPandas(),
737
- sdf.select(
738
- sdf["value"].eqNullSafe(None),
739
- sdf["value"].eqNullSafe(float("NaN")),
740
- sdf["value"].eqNullSafe(42.0),
741
- ).toPandas(),
742
- )
743
-
744
- # SPARK-41851: test with nanvl
745
- data = [(1.0, float("nan")), (float("nan"), 2.0)]
746
-
747
- cdf = self.connect.createDataFrame(data, ("a", "b"))
748
- sdf = self.spark.createDataFrame(data, ("a", "b"))
749
-
750
- self.assert_eq(cdf.toPandas(), sdf.toPandas())
751
-
752
- self.assert_eq(
753
- cdf.select(
754
- CF.nanvl("a", "b").alias("r1"), CF.nanvl(cdf.a, cdf.b).alias("r2")
755
- ).toPandas(),
756
- sdf.select(
757
- SF.nanvl("a", "b").alias("r1"), SF.nanvl(sdf.a, sdf.b).alias("r2")
758
- ).toPandas(),
759
- )
760
-
761
- # SPARK-41852: test with pmod
762
- data = [
763
- (1.0, float("nan")),
764
- (float("nan"), 2.0),
765
- (10.0, 3.0),
766
- (float("nan"), float("nan")),
767
- (-3.0, 4.0),
768
- (-10.0, 3.0),
769
- (-5.0, -6.0),
770
- (7.0, -8.0),
771
- (1.0, 2.0),
772
- ]
773
-
774
- cdf = self.connect.createDataFrame(data, ("a", "b"))
775
- sdf = self.spark.createDataFrame(data, ("a", "b"))
776
-
777
- self.assert_eq(cdf.toPandas(), sdf.toPandas())
778
-
779
- self.assert_eq(
780
- cdf.select(CF.pmod("a", "b")).toPandas(),
781
- sdf.select(SF.pmod("a", "b")).toPandas(),
782
- )
783
-
784
- def test_cast_with_ddl(self):
785
- data = [Row(date=datetime.date(2021, 12, 27), add=2)]
786
-
787
- cdf = self.connect.createDataFrame(data, "date date, add integer")
788
- sdf = self.spark.createDataFrame(data, "date date, add integer")
789
-
790
- self.assertEqual(cdf.schema, sdf.schema)
791
-
792
- def test_create_empty_df(self):
793
- for schema in [
794
- "STRING",
795
- "x STRING",
796
- "x STRING, y INTEGER",
797
- StringType(),
798
- StructType(
799
- [
800
- StructField("x", StringType(), True),
801
- StructField("y", IntegerType(), True),
802
- ]
803
- ),
804
- ]:
805
- cdf = self.connect.createDataFrame(data=[], schema=schema)
806
- sdf = self.spark.createDataFrame(data=[], schema=schema)
807
-
808
- self.assert_eq(cdf.toPandas(), sdf.toPandas())
809
-
810
- # check error
811
- with self.assertRaises(PySparkValueError) as pe:
812
- self.connect.createDataFrame(data=[])
813
-
814
- self.check_error(
815
- exception=pe.exception,
816
- error_class="CANNOT_INFER_EMPTY_SCHEMA",
817
- message_parameters={},
818
- )
819
-
820
- def test_create_dataframe_from_arrays(self):
821
- # SPARK-42021: createDataFrame support array.array
822
- data1 = [Row(a=1, b=array.array("i", [1, 2, 3]), c=array.array("d", [4, 5, 6]))]
823
- data2 = [(array.array("d", [1, 2, 3]), 2, "3")]
824
- data3 = [{"a": 1, "b": array.array("i", [1, 2, 3])}]
825
-
826
- for data in [data1, data2, data3]:
827
- cdf = self.connect.createDataFrame(data)
828
- sdf = self.spark.createDataFrame(data)
829
-
830
- # TODO: the nullability is different, need to fix
831
- # self.assertEqual(cdf.schema, sdf.schema)
832
- self.assertEqual(cdf.collect(), sdf.collect())
833
-
834
- def test_timestamp_create_from_rows(self):
835
- data = [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)]
836
-
837
- cdf = self.connect.createDataFrame(data, ["date", "val"])
838
- sdf = self.spark.createDataFrame(data, ["date", "val"])
839
-
840
- self.assertEqual(cdf.schema, sdf.schema)
841
- self.assertEqual(cdf.collect(), sdf.collect())
842
-
843
- def test_create_dataframe_with_coercion(self):
844
- data1 = [[1.33, 1], ["2.1", 1]]
845
- data2 = [[True, 1], ["false", 1]]
846
-
847
- for data in [data1, data2]:
848
- cdf = self.connect.createDataFrame(data, ["a", "b"])
849
- sdf = self.spark.createDataFrame(data, ["a", "b"])
850
-
851
- self.assertEqual(cdf.schema, sdf.schema)
852
- self.assertEqual(cdf.collect(), sdf.collect())
853
-
854
- def test_nested_type_create_from_rows(self):
855
- data1 = [Row(a=1, b=Row(c=2, d=Row(e=3, f=Row(g=4, h=Row(i=5)))))]
856
- # root
857
- # |-- a: long (nullable = true)
858
- # |-- b: struct (nullable = true)
859
- # | |-- c: long (nullable = true)
860
- # | |-- d: struct (nullable = true)
861
- # | | |-- e: long (nullable = true)
862
- # | | |-- f: struct (nullable = true)
863
- # | | | |-- g: long (nullable = true)
864
- # | | | |-- h: struct (nullable = true)
865
- # | | | | |-- i: long (nullable = true)
866
-
867
- data2 = [
868
- (
869
- 1,
870
- "a",
871
- Row(
872
- a=1,
873
- b=[1, 2, 3],
874
- c={"a": "b"},
875
- d=Row(x=1, y="y", z=Row(o=1, p=2, q=Row(g=1.5))),
876
- ),
877
- )
878
- ]
879
- # root
880
- # |-- _1: long (nullable = true)
881
- # |-- _2: string (nullable = true)
882
- # |-- _3: struct (nullable = true)
883
- # | |-- a: long (nullable = true)
884
- # | |-- b: array (nullable = true)
885
- # | | |-- element: long (containsNull = true)
886
- # | |-- c: map (nullable = true)
887
- # | | |-- key: string
888
- # | | |-- value: string (valueContainsNull = true)
889
- # | |-- d: struct (nullable = true)
890
- # | | |-- x: long (nullable = true)
891
- # | | |-- y: string (nullable = true)
892
- # | | |-- z: struct (nullable = true)
893
- # | | | |-- o: long (nullable = true)
894
- # | | | |-- p: long (nullable = true)
895
- # | | | |-- q: struct (nullable = true)
896
- # | | | | |-- g: double (nullable = true)
897
-
898
- data3 = [
899
- Row(
900
- a=1,
901
- b=[1, 2, 3],
902
- c={"a": "b"},
903
- d=Row(x=1, y="y", z=Row(1, 2, 3)),
904
- e=list("hello connect"),
905
- )
906
- ]
907
- # root
908
- # |-- a: long (nullable = true)
909
- # |-- b: array (nullable = true)
910
- # | |-- element: long (containsNull = true)
911
- # |-- c: map (nullable = true)
912
- # | |-- key: string
913
- # | |-- value: string (valueContainsNull = true)
914
- # |-- d: struct (nullable = true)
915
- # | |-- x: long (nullable = true)
916
- # | |-- y: string (nullable = true)
917
- # | |-- z: struct (nullable = true)
918
- # | | |-- _1: long (nullable = true)
919
- # | | |-- _2: long (nullable = true)
920
- # | | |-- _3: long (nullable = true)
921
- # |-- e: array (nullable = true)
922
- # | |-- element: string (containsNull = true)
923
-
924
- data4 = [
925
- {
926
- "a": 1,
927
- "b": Row(x=1, y=Row(z=2)),
928
- "c": {"x": -1, "y": 2},
929
- "d": [1, 2, 3, 4, 5],
930
- }
931
- ]
932
- # root
933
- # |-- a: long (nullable = true)
934
- # |-- b: struct (nullable = true)
935
- # | |-- x: long (nullable = true)
936
- # | |-- y: struct (nullable = true)
937
- # | | |-- z: long (nullable = true)
938
- # |-- c: map (nullable = true)
939
- # | |-- key: string
940
- # | |-- value: long (valueContainsNull = true)
941
- # |-- d: array (nullable = true)
942
- # | |-- element: long (containsNull = true)
943
-
944
- data5 = [
945
- {
946
- "a": [Row(x=1, y="2"), Row(x=-1, y="-2")],
947
- "b": [[1, 2, 3], [4, 5], [6]],
948
- "c": {3: {4: {5: 6}}, 7: {8: {9: 0}}},
949
- }
950
- ]
951
- # root
952
- # |-- a: array (nullable = true)
953
- # | |-- element: struct (containsNull = true)
954
- # | | |-- x: long (nullable = true)
955
- # | | |-- y: string (nullable = true)
956
- # |-- b: array (nullable = true)
957
- # | |-- element: array (containsNull = true)
958
- # | | |-- element: long (containsNull = true)
959
- # |-- c: map (nullable = true)
960
- # | |-- key: long
961
- # | |-- value: map (valueContainsNull = true)
962
- # | | |-- key: long
963
- # | | |-- value: map (valueContainsNull = true)
964
- # | | | |-- key: long
965
- # | | | |-- value: long (valueContainsNull = true)
966
-
967
- for data in [data1, data2, data3, data4, data5]:
968
- with self.subTest(data=data):
969
- cdf = self.connect.createDataFrame(data)
970
- sdf = self.spark.createDataFrame(data)
971
-
972
- self.assertEqual(cdf.schema, sdf.schema)
973
- self.assertEqual(cdf.collect(), sdf.collect())
974
-
975
- def test_create_df_from_objects(self):
976
- data = [MyObject(1, "1"), MyObject(2, "2")]
977
-
978
- # +---+-----+
979
- # |key|value|
980
- # +---+-----+
981
- # | 1| 1|
982
- # | 2| 2|
983
- # +---+-----+
984
-
985
- cdf = self.connect.createDataFrame(data)
986
- sdf = self.spark.createDataFrame(data)
987
-
988
- self.assertEqual(cdf.schema, sdf.schema)
989
- self.assertEqual(cdf.collect(), sdf.collect())
990
-
991
- def test_simple_explain_string(self):
992
- df = self.connect.read.table(self.tbl_name).limit(10)
993
- result = df._explain_string()
994
- self.assertGreater(len(result), 0)
995
-
996
- def test_schema(self):
997
- schema = self.connect.read.table(self.tbl_name).schema
998
- self.assertEqual(
999
- StructType(
1000
- [StructField("id", LongType(), True), StructField("name", StringType(), True)]
1001
- ),
1002
- schema,
1003
- )
1004
-
1005
- # test FloatType, DoubleType, DecimalType, StringType, BooleanType, NullType
1006
- query = """
1007
- SELECT * FROM VALUES
1008
- (float(1.0), double(1.0), 1.0, "1", true, NULL),
1009
- (float(2.0), double(2.0), 2.0, "2", false, NULL),
1010
- (float(3.0), double(3.0), NULL, "3", false, NULL)
1011
- AS tab(a, b, c, d, e, f)
1012
- """
1013
- self.assertEqual(
1014
- self.spark.sql(query).schema,
1015
- self.connect.sql(query).schema,
1016
- )
1017
-
1018
- # test TimestampType, DateType
1019
- query = """
1020
- SELECT * FROM VALUES
1021
- (TIMESTAMP('2019-04-12 15:50:00'), DATE('2022-02-22')),
1022
- (TIMESTAMP('2019-04-12 15:50:00'), NULL),
1023
- (NULL, DATE('2022-02-22'))
1024
- AS tab(a, b)
1025
- """
1026
- self.assertEqual(
1027
- self.spark.sql(query).schema,
1028
- self.connect.sql(query).schema,
1029
- )
1030
-
1031
- # test DayTimeIntervalType
1032
- query = """ SELECT INTERVAL '100 10:30' DAY TO MINUTE AS interval """
1033
- self.assertEqual(
1034
- self.spark.sql(query).schema,
1035
- self.connect.sql(query).schema,
1036
- )
1037
-
1038
- # test MapType
1039
- query = """
1040
- SELECT * FROM VALUES
1041
- (MAP('a', 'ab'), MAP('a', 'ab'), MAP(1, 2, 3, 4)),
1042
- (MAP('x', 'yz'), MAP('x', NULL), NULL),
1043
- (MAP('c', 'de'), NULL, MAP(-1, NULL, -3, -4))
1044
- AS tab(a, b, c)
1045
- """
1046
- self.assertEqual(
1047
- self.spark.sql(query).schema,
1048
- self.connect.sql(query).schema,
1049
- )
1050
-
1051
- # test ArrayType
1052
- query = """
1053
- SELECT * FROM VALUES
1054
- (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3)),
1055
- (ARRAY('x', NULL), NULL, ARRAY(1, 3)),
1056
- (NULL, ARRAY(-1, -2, -3), Array())
1057
- AS tab(a, b, c)
1058
- """
1059
- self.assertEqual(
1060
- self.spark.sql(query).schema,
1061
- self.connect.sql(query).schema,
1062
- )
1063
-
1064
- # test StructType
1065
- query = """
1066
- SELECT STRUCT(a, b, c, d), STRUCT(e, f, g), STRUCT(STRUCT(a, b), STRUCT(h)) FROM VALUES
1067
- (float(1.0), double(1.0), 1.0, "1", true, NULL, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
1068
- (float(2.0), double(2.0), 2.0, "2", false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
1069
- (float(3.0), double(3.0), NULL, "3", false, NULL, ARRAY(NULL), NULL)
1070
- AS tab(a, b, c, d, e, f, g, h)
1071
- """
1072
- self.assertEqual(
1073
- self.spark.sql(query).schema,
1074
- self.connect.sql(query).schema,
1075
- )
1076
-
1077
- def test_to(self):
1078
- # SPARK-41464: test DataFrame.to()
1079
-
1080
- cdf = self.connect.read.table(self.tbl_name)
1081
- df = self.spark.read.table(self.tbl_name)
1082
-
1083
- def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType):
1084
- cdf_to = cdf.to(schema)
1085
- df_to = df.to(schema)
1086
- self.assertEqual(cdf_to.schema, df_to.schema)
1087
- self.assert_eq(cdf_to.toPandas(), df_to.toPandas())
1088
-
1089
- # The schema has not changed
1090
- schema = StructType(
1091
- [
1092
- StructField("id", IntegerType(), True),
1093
- StructField("name", StringType(), True),
1094
- ]
1095
- )
1096
-
1097
- assert_eq_schema(cdf, df, schema)
1098
-
1099
- # Change schema with struct
1100
- schema2 = StructType([StructField("struct", schema, False)])
1101
-
1102
- cdf_to = cdf.select(CF.struct("id", "name").alias("struct")).to(schema2)
1103
- df_to = df.select(SF.struct("id", "name").alias("struct")).to(schema2)
1104
-
1105
- self.assertEqual(cdf_to.schema, df_to.schema)
1106
-
1107
- # Change the column name
1108
- schema = StructType(
1109
- [
1110
- StructField("col1", IntegerType(), True),
1111
- StructField("col2", StringType(), True),
1112
- ]
1113
- )
1114
-
1115
- assert_eq_schema(cdf, df, schema)
1116
-
1117
- # Change the column data type
1118
- schema = StructType(
1119
- [
1120
- StructField("id", StringType(), True),
1121
- StructField("name", StringType(), True),
1122
- ]
1123
- )
1124
-
1125
- assert_eq_schema(cdf, df, schema)
1126
-
1127
- # Reduce the number of columns and change the data type
1128
- schema = StructType(
1129
- [
1130
- StructField("id", LongType(), True),
1131
- ]
1132
- )
1133
-
1134
- assert_eq_schema(cdf, df, schema)
1135
-
1136
- # incompatible field nullability
1137
- schema = StructType([StructField("id", LongType(), False)])
1138
- self.assertRaisesRegex(
1139
- AnalysisException,
1140
- "NULLABLE_COLUMN_OR_FIELD",
1141
- lambda: cdf.to(schema).toPandas(),
1142
- )
1143
-
1144
- # field cannot be upcast
1145
- schema = StructType([StructField("name", LongType())])
1146
- self.assertRaisesRegex(
1147
- AnalysisException,
1148
- "INVALID_COLUMN_OR_FIELD_DATA_TYPE",
1149
- lambda: cdf.to(schema).toPandas(),
1150
- )
1151
-
1152
- schema = StructType(
1153
- [
1154
- StructField("id", IntegerType(), True),
1155
- StructField("name", IntegerType(), True),
1156
- ]
1157
- )
1158
- self.assertRaisesRegex(
1159
- AnalysisException,
1160
- "INVALID_COLUMN_OR_FIELD_DATA_TYPE",
1161
- lambda: cdf.to(schema).toPandas(),
1162
- )
1163
-
1164
- # Test map type and array type
1165
- schema = StructType(
1166
- [
1167
- StructField("id", StringType(), True),
1168
- StructField("my_map", MapType(StringType(), IntegerType(), False), True),
1169
- StructField("my_array", ArrayType(IntegerType(), False), True),
1170
- ]
1171
- )
1172
- cdf = self.connect.read.table(self.tbl_name4)
1173
- df = self.spark.read.table(self.tbl_name4)
1174
-
1175
- assert_eq_schema(cdf, df, schema)
1176
-
1177
- def test_toDF(self):
1178
- # SPARK-41310: test DataFrame.toDF()
1179
- self.assertEqual(
1180
- self.connect.read.table(self.tbl_name).toDF("col1", "col2").schema,
1181
- self.spark.read.table(self.tbl_name).toDF("col1", "col2").schema,
1182
- )
1183
-
1184
- def test_print_schema(self):
1185
- # SPARK-41216: Test print schema
1186
- tree_str = self.connect.sql("SELECT 1 AS X, 2 AS Y")._tree_string()
1187
- # root
1188
- # |-- X: integer (nullable = false)
1189
- # |-- Y: integer (nullable = false)
1190
- expected = "root\n |-- X: integer (nullable = false)\n |-- Y: integer (nullable = false)\n"
1191
- self.assertEqual(tree_str, expected)
1192
-
1193
- def test_is_local(self):
1194
- # SPARK-41216: Test is local
1195
- self.assertTrue(self.connect.sql("SHOW DATABASES").isLocal())
1196
- self.assertFalse(self.connect.read.table(self.tbl_name).isLocal())
1197
-
1198
- def test_is_streaming(self):
1199
- # SPARK-41216: Test is streaming
1200
- self.assertFalse(self.connect.read.table(self.tbl_name).isStreaming)
1201
- self.assertFalse(self.connect.sql("SELECT 1 AS X LIMIT 0").isStreaming)
1202
-
1203
- def test_input_files(self):
1204
- # SPARK-41216: Test input files
1205
- tmpPath = tempfile.mkdtemp()
1206
- shutil.rmtree(tmpPath)
1207
- try:
1208
- self.df_text.write.text(tmpPath)
1209
-
1210
- input_files_list1 = (
1211
- self.spark.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
1212
- )
1213
- input_files_list2 = (
1214
- self.connect.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
1215
- )
1216
-
1217
- self.assertTrue(len(input_files_list1) > 0)
1218
- self.assertEqual(len(input_files_list1), len(input_files_list2))
1219
- for file_path in input_files_list2:
1220
- self.assertTrue(file_path in input_files_list1)
1221
- finally:
1222
- shutil.rmtree(tmpPath)
1223
-
1224
- def test_limit_offset(self):
1225
- df = self.connect.read.table(self.tbl_name)
1226
- pd = df.limit(10).offset(1).toPandas()
1227
- self.assertEqual(9, len(pd.index))
1228
- pd2 = df.offset(98).limit(10).toPandas()
1229
- self.assertEqual(2, len(pd2.index))
1230
-
1231
- def test_tail(self):
1232
- df = self.connect.read.table(self.tbl_name)
1233
- df2 = self.spark.read.table(self.tbl_name)
1234
- self.assertEqual(df.tail(10), df2.tail(10))
1235
-
1236
- def test_sql(self):
1237
- pdf = self.connect.sql("SELECT 1").toPandas()
1238
- self.assertEqual(1, len(pdf.index))
1239
-
1240
- def test_sql_with_named_args(self):
1241
- df = self.connect.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
1242
- df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
1243
- self.assert_eq(df.toPandas(), df2.toPandas())
1244
-
1245
- def test_namedargs_with_global_limit(self):
1246
- sqlText = """SELECT * FROM VALUES (TIMESTAMP('2022-12-25 10:30:00'), 1) as tab(date, val)
1247
- where val = :val"""
1248
- df = self.connect.sql(sqlText, args={"val": 1})
1249
- df2 = self.spark.sql(sqlText, args={"val": 1})
1250
- self.assert_eq(df.toPandas(), df2.toPandas())
1251
-
1252
- def test_sql_with_pos_args(self):
1253
- df = self.connect.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
1254
- df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
1255
- self.assert_eq(df.toPandas(), df2.toPandas())
1256
-
1257
- def test_head(self):
1258
- # SPARK-41002: test `head` API in Python Client
1259
- df = self.connect.read.table(self.tbl_name)
1260
- self.assertIsNotNone(len(df.head()))
1261
- self.assertIsNotNone(len(df.head(1)))
1262
- self.assertIsNotNone(len(df.head(5)))
1263
- df2 = self.connect.read.table(self.tbl_name_empty)
1264
- self.assertIsNone(df2.head())
1265
-
1266
- def test_deduplicate(self):
1267
- # SPARK-41326: test distinct and dropDuplicates.
1268
- df = self.connect.read.table(self.tbl_name)
1269
- df2 = self.spark.read.table(self.tbl_name)
1270
- self.assert_eq(df.distinct().toPandas(), df2.distinct().toPandas())
1271
- self.assert_eq(df.dropDuplicates().toPandas(), df2.dropDuplicates().toPandas())
1272
- self.assert_eq(
1273
- df.dropDuplicates(["name"]).toPandas(), df2.dropDuplicates(["name"]).toPandas()
1274
- )
1275
-
1276
- def test_deduplicate_within_watermark_in_batch(self):
1277
- df = self.connect.read.table(self.tbl_name)
1278
- with self.assertRaisesRegex(
1279
- AnalysisException,
1280
- "dropDuplicatesWithinWatermark is not supported with batch DataFrames/DataSets",
1281
- ):
1282
- df.dropDuplicatesWithinWatermark().toPandas()
1283
-
1284
- def test_first(self):
1285
- # SPARK-41002: test `first` API in Python Client
1286
- df = self.connect.read.table(self.tbl_name)
1287
- self.assertIsNotNone(len(df.first()))
1288
- df2 = self.connect.read.table(self.tbl_name_empty)
1289
- self.assertIsNone(df2.first())
1290
-
1291
- def test_take(self) -> None:
1292
- # SPARK-41002: test `take` API in Python Client
1293
- df = self.connect.read.table(self.tbl_name)
1294
- self.assertEqual(5, len(df.take(5)))
1295
- df2 = self.connect.read.table(self.tbl_name_empty)
1296
- self.assertEqual(0, len(df2.take(5)))
1297
-
1298
- def test_drop(self):
1299
- # SPARK-41169: test drop
1300
- query = """
1301
- SELECT * FROM VALUES
1302
- (false, 1, NULL), (false, NULL, 2), (NULL, 3, 3)
1303
- AS tab(a, b, c)
1304
- """
1305
-
1306
- cdf = self.connect.sql(query)
1307
- sdf = self.spark.sql(query)
1308
- self.assert_eq(
1309
- cdf.drop("a").toPandas(),
1310
- sdf.drop("a").toPandas(),
1311
- )
1312
- self.assert_eq(
1313
- cdf.drop("a", "b").toPandas(),
1314
- sdf.drop("a", "b").toPandas(),
1315
- )
1316
- self.assert_eq(
1317
- cdf.drop("a", "x").toPandas(),
1318
- sdf.drop("a", "x").toPandas(),
1319
- )
1320
- self.assert_eq(
1321
- cdf.drop(cdf.a, "x").toPandas(),
1322
- sdf.drop(sdf.a, "x").toPandas(),
1323
- )
1324
-
1325
- def test_subquery_alias(self) -> None:
1326
- # SPARK-40938: test subquery alias.
1327
- plan_text = (
1328
- self.connect.read.table(self.tbl_name)
1329
- .alias("special_alias")
1330
- ._explain_string(extended=True)
1331
- )
1332
- self.assertTrue("special_alias" in plan_text)
1333
-
1334
- def test_sort(self):
1335
- # SPARK-41332: test sort
1336
- query = """
1337
- SELECT * FROM VALUES
1338
- (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1339
- AS tab(a, b, c)
1340
- """
1341
- # +-----+----+----+
1342
- # | a| b| c|
1343
- # +-----+----+----+
1344
- # |false| 1|NULL|
1345
- # |false|NULL| 2.0|
1346
- # | NULL| 3| 3.0|
1347
- # +-----+----+----+
1348
-
1349
- cdf = self.connect.sql(query)
1350
- sdf = self.spark.sql(query)
1351
- self.assert_eq(
1352
- cdf.sort("a").toPandas(),
1353
- sdf.sort("a").toPandas(),
1354
- )
1355
- self.assert_eq(
1356
- cdf.sort("c").toPandas(),
1357
- sdf.sort("c").toPandas(),
1358
- )
1359
- self.assert_eq(
1360
- cdf.sort("b").toPandas(),
1361
- sdf.sort("b").toPandas(),
1362
- )
1363
- self.assert_eq(
1364
- cdf.sort(cdf.c, "b").toPandas(),
1365
- sdf.sort(sdf.c, "b").toPandas(),
1366
- )
1367
- self.assert_eq(
1368
- cdf.sort(cdf.c.desc(), "b").toPandas(),
1369
- sdf.sort(sdf.c.desc(), "b").toPandas(),
1370
- )
1371
- self.assert_eq(
1372
- cdf.sort(cdf.c.desc(), cdf.a.asc()).toPandas(),
1373
- sdf.sort(sdf.c.desc(), sdf.a.asc()).toPandas(),
1374
- )
1375
-
1376
- def test_range(self):
1377
- self.assert_eq(
1378
- self.connect.range(start=0, end=10).toPandas(),
1379
- self.spark.range(start=0, end=10).toPandas(),
1380
- )
1381
- self.assert_eq(
1382
- self.connect.range(start=0, end=10, step=3).toPandas(),
1383
- self.spark.range(start=0, end=10, step=3).toPandas(),
1384
- )
1385
- self.assert_eq(
1386
- self.connect.range(start=0, end=10, step=3, numPartitions=2).toPandas(),
1387
- self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas(),
1388
- )
1389
- # SPARK-41301
1390
- self.assert_eq(
1391
- self.connect.range(10).toPandas(), self.connect.range(start=0, end=10).toPandas()
1392
- )
1393
-
1394
- def test_create_global_temp_view(self):
1395
- # SPARK-41127: test global temp view creation.
1396
- with self.tempView("view_1"):
1397
- self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
1398
- self.connect.sql("SELECT 2 AS X LIMIT 1").createOrReplaceGlobalTempView("view_1")
1399
- self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
1400
-
1401
- # Test creating a view that already exists
1402
- self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
1403
- with self.assertRaises(AnalysisException):
1404
- self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
1405
-
1406
- def test_create_session_local_temp_view(self):
1407
- # SPARK-41372: test session local temp view creation.
1408
- with self.tempView("view_local_temp"):
1409
- self.connect.sql("SELECT 1 AS X").createTempView("view_local_temp")
1410
- self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 1)
1411
- self.connect.sql("SELECT 1 AS X LIMIT 0").createOrReplaceTempView("view_local_temp")
1412
- self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0)
1413
-
1414
- # Test creating a view that already exists
1415
- with self.assertRaises(AnalysisException):
1416
- self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp")
1417
-
1418
- def test_to_pandas(self):
1419
- # SPARK-41005: Test to pandas
1420
- query = """
1421
- SELECT * FROM VALUES
1422
- (false, 1, NULL),
1423
- (false, NULL, float(2.0)),
1424
- (NULL, 3, float(3.0))
1425
- AS tab(a, b, c)
1426
- """
1427
-
1428
- self.assert_eq(
1429
- self.connect.sql(query).toPandas(),
1430
- self.spark.sql(query).toPandas(),
1431
- )
1432
-
1433
- query = """
1434
- SELECT * FROM VALUES
1435
- (1, 1, NULL),
1436
- (2, NULL, float(2.0)),
1437
- (3, 3, float(3.0))
1438
- AS tab(a, b, c)
1439
- """
1440
-
1441
- self.assert_eq(
1442
- self.connect.sql(query).toPandas(),
1443
- self.spark.sql(query).toPandas(),
1444
- )
1445
-
1446
- query = """
1447
- SELECT * FROM VALUES
1448
- (double(1.0), 1, "1"),
1449
- (NULL, NULL, NULL),
1450
- (double(2.0), 3, "3")
1451
- AS tab(a, b, c)
1452
- """
1453
-
1454
- self.assert_eq(
1455
- self.connect.sql(query).toPandas(),
1456
- self.spark.sql(query).toPandas(),
1457
- )
1458
-
1459
- query = """
1460
- SELECT * FROM VALUES
1461
- (float(1.0), double(1.0), 1, "1"),
1462
- (float(2.0), double(2.0), 2, "2"),
1463
- (float(3.0), double(3.0), 3, "3")
1464
- AS tab(a, b, c, d)
1465
- """
1466
-
1467
- self.assert_eq(
1468
- self.connect.sql(query).toPandas(),
1469
- self.spark.sql(query).toPandas(),
1470
- )
1471
-
1472
- def test_create_dataframe_from_pandas_with_ns_timestamp(self):
1473
- """Truncate the timestamps for nanoseconds."""
1474
- from datetime import datetime, timezone, timedelta
1475
- from pandas import Timestamp
1476
- import pandas as pd
1477
-
1478
- pdf = pd.DataFrame(
1479
- {
1480
- "naive": [datetime(2019, 1, 1, 0)],
1481
- "aware": [
1482
- Timestamp(
1483
- year=2019, month=1, day=1, nanosecond=500, tz=timezone(timedelta(hours=-8))
1484
- )
1485
- ],
1486
- }
1487
- )
1488
-
1489
- with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
1490
- self.assertEqual(
1491
- self.connect.createDataFrame(pdf).collect(),
1492
- self.spark.createDataFrame(pdf).collect(),
1493
- )
1494
-
1495
- with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
1496
- self.assertEqual(
1497
- self.connect.createDataFrame(pdf).collect(),
1498
- self.spark.createDataFrame(pdf).collect(),
1499
- )
1500
-
1501
- def test_select_expr(self):
1502
- # SPARK-41201: test selectExpr API.
1503
- self.assert_eq(
1504
- self.connect.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
1505
- self.spark.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
1506
- )
1507
- self.assert_eq(
1508
- self.connect.read.table(self.tbl_name)
1509
- .selectExpr(["id * 2", "cast(name as long) as name"])
1510
- .toPandas(),
1511
- self.spark.read.table(self.tbl_name)
1512
- .selectExpr(["id * 2", "cast(name as long) as name"])
1513
- .toPandas(),
1514
- )
1515
-
1516
- self.assert_eq(
1517
- self.connect.read.table(self.tbl_name)
1518
- .selectExpr("id * 2", "cast(name as long) as name")
1519
- .toPandas(),
1520
- self.spark.read.table(self.tbl_name)
1521
- .selectExpr("id * 2", "cast(name as long) as name")
1522
- .toPandas(),
1523
- )
1524
-
1525
- def test_select_star(self):
1526
- data = [Row(a=1, b=Row(c=2, d=Row(e=3)))]
1527
-
1528
- # +---+--------+
1529
- # | a| b|
1530
- # +---+--------+
1531
- # | 1|{2, {3}}|
1532
- # +---+--------+
1533
-
1534
- cdf = self.connect.createDataFrame(data=data)
1535
- sdf = self.spark.createDataFrame(data=data)
1536
-
1537
- self.assertEqual(
1538
- cdf.select("*").collect(),
1539
- sdf.select("*").collect(),
1540
- )
1541
- self.assertEqual(
1542
- cdf.select("a", "*").collect(),
1543
- sdf.select("a", "*").collect(),
1544
- )
1545
- self.assertEqual(
1546
- cdf.select("a", "b").collect(),
1547
- sdf.select("a", "b").collect(),
1548
- )
1549
- self.assertEqual(
1550
- cdf.select("a", "b.*").collect(),
1551
- sdf.select("a", "b.*").collect(),
1552
- )
1553
-
1554
- def test_fill_na(self):
1555
- # SPARK-41128: Test fill na
1556
- query = """
1557
- SELECT * FROM VALUES
1558
- (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1559
- AS tab(a, b, c)
1560
- """
1561
- # +-----+----+----+
1562
- # | a| b| c|
1563
- # +-----+----+----+
1564
- # |false| 1|NULL|
1565
- # |false|NULL| 2.0|
1566
- # | NULL| 3| 3.0|
1567
- # +-----+----+----+
1568
-
1569
- self.assert_eq(
1570
- self.connect.sql(query).fillna(True).toPandas(),
1571
- self.spark.sql(query).fillna(True).toPandas(),
1572
- )
1573
- self.assert_eq(
1574
- self.connect.sql(query).fillna(2).toPandas(),
1575
- self.spark.sql(query).fillna(2).toPandas(),
1576
- )
1577
- self.assert_eq(
1578
- self.connect.sql(query).fillna(2, ["a", "b"]).toPandas(),
1579
- self.spark.sql(query).fillna(2, ["a", "b"]).toPandas(),
1580
- )
1581
- self.assert_eq(
1582
- self.connect.sql(query).na.fill({"a": True, "b": 2}).toPandas(),
1583
- self.spark.sql(query).na.fill({"a": True, "b": 2}).toPandas(),
1584
- )
1585
-
1586
- def test_drop_na(self):
1587
- # SPARK-41148: Test drop na
1588
- query = """
1589
- SELECT * FROM VALUES
1590
- (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1591
- AS tab(a, b, c)
1592
- """
1593
- # +-----+----+----+
1594
- # | a| b| c|
1595
- # +-----+----+----+
1596
- # |false| 1|NULL|
1597
- # |false|NULL| 2.0|
1598
- # | NULL| 3| 3.0|
1599
- # +-----+----+----+
1600
-
1601
- self.assert_eq(
1602
- self.connect.sql(query).dropna().toPandas(),
1603
- self.spark.sql(query).dropna().toPandas(),
1604
- )
1605
- self.assert_eq(
1606
- self.connect.sql(query).na.drop(how="all", thresh=1).toPandas(),
1607
- self.spark.sql(query).na.drop(how="all", thresh=1).toPandas(),
1608
- )
1609
- self.assert_eq(
1610
- self.connect.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
1611
- self.spark.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
1612
- )
1613
- self.assert_eq(
1614
- self.connect.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
1615
- self.spark.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
1616
- )
1617
-
1618
- def test_replace(self):
1619
- # SPARK-41315: Test replace
1620
- query = """
1621
- SELECT * FROM VALUES
1622
- (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1623
- AS tab(a, b, c)
1624
- """
1625
- # +-----+----+----+
1626
- # | a| b| c|
1627
- # +-----+----+----+
1628
- # |false| 1|NULL|
1629
- # |false|NULL| 2.0|
1630
- # | NULL| 3| 3.0|
1631
- # +-----+----+----+
1632
-
1633
- self.assert_eq(
1634
- self.connect.sql(query).replace(2, 3).toPandas(),
1635
- self.spark.sql(query).replace(2, 3).toPandas(),
1636
- )
1637
- self.assert_eq(
1638
- self.connect.sql(query).na.replace(False, True).toPandas(),
1639
- self.spark.sql(query).na.replace(False, True).toPandas(),
1640
- )
1641
- self.assert_eq(
1642
- self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
1643
- self.spark.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
1644
- )
1645
- self.assert_eq(
1646
- self.connect.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
1647
- self.spark.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
1648
- )
1649
- self.assert_eq(
1650
- self.connect.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
1651
- self.spark.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
1652
- )
1653
-
1654
- with self.assertRaises(ValueError) as context:
1655
- self.connect.sql(query).replace({None: 1}, subset="a").toPandas()
1656
- self.assertTrue("Mixed type replacements are not supported" in str(context.exception))
1657
-
1658
- with self.assertRaises(AnalysisException) as context:
1659
- self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "x")).toPandas()
1660
- self.assertIn(
1661
- """Cannot resolve column name "x" among (a, b, c)""", str(context.exception)
1662
- )
1663
-
1664
- def test_unpivot(self):
1665
- self.assert_eq(
1666
- self.connect.read.table(self.tbl_name)
1667
- .filter("id > 3")
1668
- .unpivot(["id"], ["name"], "variable", "value")
1669
- .toPandas(),
1670
- self.spark.read.table(self.tbl_name)
1671
- .filter("id > 3")
1672
- .unpivot(["id"], ["name"], "variable", "value")
1673
- .toPandas(),
1674
- )
1675
-
1676
- self.assert_eq(
1677
- self.connect.read.table(self.tbl_name)
1678
- .filter("id > 3")
1679
- .unpivot("id", None, "variable", "value")
1680
- .toPandas(),
1681
- self.spark.read.table(self.tbl_name)
1682
- .filter("id > 3")
1683
- .unpivot("id", None, "variable", "value")
1684
- .toPandas(),
1685
- )
1686
-
1687
- def test_union_by_name(self):
1688
- # SPARK-41832: Test unionByName
1689
- data1 = [(1, 2, 3)]
1690
- data2 = [(6, 2, 5)]
1691
- df1_connect = self.connect.createDataFrame(data1, ["a", "b", "c"])
1692
- df2_connect = self.connect.createDataFrame(data2, ["a", "b", "c"])
1693
- union_df_connect = df1_connect.unionByName(df2_connect)
1694
-
1695
- df1_spark = self.spark.createDataFrame(data1, ["a", "b", "c"])
1696
- df2_spark = self.spark.createDataFrame(data2, ["a", "b", "c"])
1697
- union_df_spark = df1_spark.unionByName(df2_spark)
1698
-
1699
- self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas())
1700
-
1701
- df2_connect = self.connect.createDataFrame(data2, ["a", "B", "C"])
1702
- union_df_connect = df1_connect.unionByName(df2_connect, allowMissingColumns=True)
1703
-
1704
- df2_spark = self.spark.createDataFrame(data2, ["a", "B", "C"])
1705
- union_df_spark = df1_spark.unionByName(df2_spark, allowMissingColumns=True)
1706
-
1707
- self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas())
1708
-
1709
- def test_random_split(self):
1710
- # SPARK-41440: test randomSplit(weights, seed).
1711
- relations = (
1712
- self.connect.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2)
1713
- )
1714
- datasets = (
1715
- self.spark.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2)
1716
- )
1717
-
1718
- self.assertTrue(len(relations) == len(datasets))
1719
- i = 0
1720
- while i < len(relations):
1721
- self.assert_eq(relations[i].toPandas(), datasets[i].toPandas())
1722
- i += 1
1723
-
1724
- def test_observe(self):
1725
- # SPARK-41527: test DataFrame.observe()
1726
- observation_name = "my_metric"
1727
-
1728
- self.assert_eq(
1729
- self.connect.read.table(self.tbl_name)
1730
- .filter("id > 3")
1731
- .observe(observation_name, CF.min("id"), CF.max("id"), CF.sum("id"))
1732
- .toPandas(),
1733
- self.spark.read.table(self.tbl_name)
1734
- .filter("id > 3")
1735
- .observe(observation_name, SF.min("id"), SF.max("id"), SF.sum("id"))
1736
- .toPandas(),
1737
- )
1738
-
1739
- from pyspark.sql.observation import Observation
1740
-
1741
- observation = Observation(observation_name)
1742
-
1743
- cdf = (
1744
- self.connect.read.table(self.tbl_name)
1745
- .filter("id > 3")
1746
- .observe(observation, CF.min("id"), CF.max("id"), CF.sum("id"))
1747
- .toPandas()
1748
- )
1749
- df = (
1750
- self.spark.read.table(self.tbl_name)
1751
- .filter("id > 3")
1752
- .observe(observation, SF.min("id"), SF.max("id"), SF.sum("id"))
1753
- .toPandas()
1754
- )
1755
-
1756
- self.assert_eq(cdf, df)
1757
-
1758
- observed_metrics = cdf.attrs["observed_metrics"]
1759
- self.assert_eq(len(observed_metrics), 1)
1760
- self.assert_eq(observed_metrics[0].name, observation_name)
1761
- self.assert_eq(len(observed_metrics[0].metrics), 3)
1762
- for metric in observed_metrics[0].metrics:
1763
- self.assertIsInstance(metric, ProtoExpression.Literal)
1764
- values = list(map(lambda metric: metric.long, observed_metrics[0].metrics))
1765
- self.assert_eq(values, [4, 99, 4944])
1766
-
1767
- with self.assertRaises(PySparkValueError) as pe:
1768
- self.connect.read.table(self.tbl_name).observe(observation_name)
1769
-
1770
- self.check_error(
1771
- exception=pe.exception,
1772
- error_class="CANNOT_BE_EMPTY",
1773
- message_parameters={"item": "exprs"},
1774
- )
1775
-
1776
- with self.assertRaises(PySparkTypeError) as pe:
1777
- self.connect.read.table(self.tbl_name).observe(observation_name, CF.lit(1), "id")
1778
-
1779
- self.check_error(
1780
- exception=pe.exception,
1781
- error_class="NOT_LIST_OF_COLUMN",
1782
- message_parameters={"arg_name": "exprs"},
1783
- )
1784
-
1785
- def test_with_columns(self):
1786
- # SPARK-41256: test withColumn(s).
1787
- self.assert_eq(
1788
- self.connect.read.table(self.tbl_name).withColumn("id", CF.lit(False)).toPandas(),
1789
- self.spark.read.table(self.tbl_name).withColumn("id", SF.lit(False)).toPandas(),
1790
- )
1791
-
1792
- self.assert_eq(
1793
- self.connect.read.table(self.tbl_name)
1794
- .withColumns({"id": CF.lit(False), "col_not_exist": CF.lit(False)})
1795
- .toPandas(),
1796
- self.spark.read.table(self.tbl_name)
1797
- .withColumns(
1798
- {
1799
- "id": SF.lit(False),
1800
- "col_not_exist": SF.lit(False),
1801
- }
1802
- )
1803
- .toPandas(),
1804
- )
1805
-
1806
- def test_hint(self):
1807
- # SPARK-41349: Test hint
1808
- self.assert_eq(
1809
- self.connect.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(),
1810
- self.spark.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(),
1811
- )
1812
-
1813
- # A hint with an unsupported name will be ignored
1814
- self.assert_eq(
1815
- self.connect.read.table(self.tbl_name).hint("illegal").toPandas(),
1816
- self.spark.read.table(self.tbl_name).hint("illegal").toPandas(),
1817
- )
1818
-
1819
- # Hint with all supported parameter values
1820
- such_a_nice_list = ["itworks1", "itworks2", "itworks3"]
1821
- self.assert_eq(
1822
- self.connect.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(),
1823
- self.spark.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(),
1824
- )
1825
-
1826
- # Hint with unsupported parameter values
1827
- with self.assertRaises(AnalysisException):
1828
- self.connect.read.table(self.tbl_name).hint("REPARTITION", "id+1").toPandas()
1829
-
1830
- # Hint with unsupported parameter types
1831
- with self.assertRaises(TypeError):
1832
- self.connect.read.table(self.tbl_name).hint("REPARTITION", range(5)).toPandas()
1833
-
1834
- # Hint with unsupported parameter types
1835
- with self.assertRaises(TypeError):
1836
- self.connect.read.table(self.tbl_name).hint(
1837
- "my awesome hint", 1.2345, 2, such_a_nice_list, range(6)
1838
- ).toPandas()
1839
-
1840
- # Hint with wrong combination
1841
- with self.assertRaises(AnalysisException):
1842
- self.connect.read.table(self.tbl_name).hint("REPARTITION", "id", 3).toPandas()
1843
-
1844
- def test_join_hint(self):
1845
- cdf1 = self.connect.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
1846
- cdf2 = self.connect.createDataFrame(
1847
- [Row(height=80, name="Tom"), Row(height=85, name="Bob")]
1848
- )
1849
-
1850
- self.assertTrue(
1851
- "BroadcastHashJoin" in cdf1.join(cdf2.hint("BROADCAST"), "name")._explain_string()
1852
- )
1853
- self.assertTrue("SortMergeJoin" in cdf1.join(cdf2.hint("MERGE"), "name")._explain_string())
1854
- self.assertTrue(
1855
- "ShuffledHashJoin" in cdf1.join(cdf2.hint("SHUFFLE_HASH"), "name")._explain_string()
1856
- )
1857
-
1858
- def test_different_spark_session_join_or_union(self):
1859
- df = self.connect.range(10).limit(3)
1860
-
1861
- spark2 = RemoteSparkSession(connection="sc://localhost")
1862
- df2 = spark2.range(10).limit(3)
1863
-
1864
- with self.assertRaises(SessionNotSameException) as e1:
1865
- df.union(df2).collect()
1866
- self.check_error(
1867
- exception=e1.exception,
1868
- error_class="SESSION_NOT_SAME",
1869
- message_parameters={},
1870
- )
1871
-
1872
- with self.assertRaises(SessionNotSameException) as e2:
1873
- df.unionByName(df2).collect()
1874
- self.check_error(
1875
- exception=e2.exception,
1876
- error_class="SESSION_NOT_SAME",
1877
- message_parameters={},
1878
- )
1879
-
1880
- with self.assertRaises(SessionNotSameException) as e3:
1881
- df.join(df2).collect()
1882
- self.check_error(
1883
- exception=e3.exception,
1884
- error_class="SESSION_NOT_SAME",
1885
- message_parameters={},
1886
- )
1887
-
1888
- def test_extended_hint_types(self):
1889
- cdf = self.connect.range(100).toDF("id")
1890
-
1891
- cdf.hint(
1892
- "my awesome hint",
1893
- 1.2345,
1894
- "what",
1895
- ["itworks1", "itworks2", "itworks3"],
1896
- ).show()
1897
-
1898
- with self.assertRaises(PySparkTypeError) as pe:
1899
- cdf.hint(
1900
- "my awesome hint",
1901
- 1.2345,
1902
- "what",
1903
- {"itworks1": "itworks2"},
1904
- ).show()
1905
-
1906
- self.check_error(
1907
- exception=pe.exception,
1908
- error_class="INVALID_ITEM_FOR_CONTAINER",
1909
- message_parameters={
1910
- "arg_name": "parameters",
1911
- "allowed_types": "str, list, float, int",
1912
- "item_type": "dict",
1913
- },
1914
- )
1915
-
1916
- def test_empty_dataset(self):
1917
- # SPARK-41005: Test arrow based collection with empty dataset.
1918
- self.assertTrue(
1919
- self.connect.sql("SELECT 1 AS X LIMIT 0")
1920
- .toPandas()
1921
- .equals(self.spark.sql("SELECT 1 AS X LIMIT 0").toPandas())
1922
- )
1923
- pdf = self.connect.sql("SELECT 1 AS X LIMIT 0").toPandas()
1924
- self.assertEqual(0, len(pdf)) # empty dataset
1925
- self.assertEqual(1, len(pdf.columns)) # one column
1926
- self.assertEqual("X", pdf.columns[0])
1927
-
1928
- def test_is_empty(self):
1929
- # SPARK-41212: Test is empty
1930
- self.assertFalse(self.connect.sql("SELECT 1 AS X").isEmpty())
1931
- self.assertTrue(self.connect.sql("SELECT 1 AS X LIMIT 0").isEmpty())
1932
-
1933
- def test_session(self):
1934
- self.assertEqual(self.connect, self.connect.sql("SELECT 1").sparkSession)
1935
-
1936
- def test_show(self):
1937
- # SPARK-41111: Test the show method
1938
- show_str = self.connect.sql("SELECT 1 AS X, 2 AS Y")._show_string()
1939
- # +---+---+
1940
- # | X| Y|
1941
- # +---+---+
1942
- # | 1| 2|
1943
- # +---+---+
1944
- expected = "+---+---+\n| X| Y|\n+---+---+\n| 1| 2|\n+---+---+\n"
1945
- self.assertEqual(show_str, expected)
1946
-
1947
- def test_describe(self):
1948
- # SPARK-41403: Test the describe method
1949
- self.assert_eq(
1950
- self.connect.read.table(self.tbl_name).describe("id").toPandas(),
1951
- self.spark.read.table(self.tbl_name).describe("id").toPandas(),
1952
- )
1953
- self.assert_eq(
1954
- self.connect.read.table(self.tbl_name).describe("id", "name").toPandas(),
1955
- self.spark.read.table(self.tbl_name).describe("id", "name").toPandas(),
1956
- )
1957
- self.assert_eq(
1958
- self.connect.read.table(self.tbl_name).describe(["id", "name"]).toPandas(),
1959
- self.spark.read.table(self.tbl_name).describe(["id", "name"]).toPandas(),
1960
- )
1961
-
1962
- def test_stat_cov(self):
1963
- # SPARK-41067: Test the stat.cov method
1964
- self.assertEqual(
1965
- self.connect.read.table(self.tbl_name2).stat.cov("col1", "col3"),
1966
- self.spark.read.table(self.tbl_name2).stat.cov("col1", "col3"),
1967
- )
1968
-
1969
- def test_stat_corr(self):
1970
- # SPARK-41068: Test the stat.corr method
1971
- self.assertEqual(
1972
- self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3"),
1973
- self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3"),
1974
- )
1975
-
1976
- self.assertEqual(
1977
- self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"),
1978
- self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"),
1979
- )
1980
-
1981
- with self.assertRaises(PySparkTypeError) as pe:
1982
- self.connect.read.table(self.tbl_name2).stat.corr(1, "col3", "pearson")
1983
-
1984
- self.check_error(
1985
- exception=pe.exception,
1986
- error_class="NOT_STR",
1987
- message_parameters={
1988
- "arg_name": "col1",
1989
- "arg_type": "int",
1990
- },
1991
- )
1992
-
1993
- with self.assertRaises(PySparkTypeError) as pe:
1994
- self.connect.read.table(self.tbl_name).stat.corr("col1", 1, "pearson")
1995
-
1996
- self.check_error(
1997
- exception=pe.exception,
1998
- error_class="NOT_STR",
1999
- message_parameters={
2000
- "arg_name": "col2",
2001
- "arg_type": "int",
2002
- },
2003
- )
2004
- with self.assertRaises(ValueError) as context:
2005
- self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "spearman"),
2006
- self.assertTrue(
2007
- "Currently only the calculation of the Pearson Correlation "
2008
- + "coefficient is supported."
2009
- in str(context.exception)
2010
- )
2011
-
2012
- def test_stat_approx_quantile(self):
2013
- # SPARK-41069: Test the stat.approxQuantile method
2014
- result = self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2015
- ["col1", "col3"], [0.1, 0.5, 0.9], 0.1
2016
- )
2017
- self.assertEqual(len(result), 2)
2018
- self.assertEqual(len(result[0]), 3)
2019
- self.assertEqual(len(result[1]), 3)
2020
-
2021
- result = self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2022
- ["col1"], [0.1, 0.5, 0.9], 0.1
2023
- )
2024
- self.assertEqual(len(result), 1)
2025
- self.assertEqual(len(result[0]), 3)
2026
-
2027
- with self.assertRaises(PySparkTypeError) as pe:
2028
- self.connect.read.table(self.tbl_name2).stat.approxQuantile(1, [0.1, 0.5, 0.9], 0.1)
2029
-
2030
- self.check_error(
2031
- exception=pe.exception,
2032
- error_class="NOT_LIST_OR_STR_OR_TUPLE",
2033
- message_parameters={
2034
- "arg_name": "col",
2035
- "arg_type": "int",
2036
- },
2037
- )
2038
-
2039
- with self.assertRaises(PySparkTypeError) as pe:
2040
- self.connect.read.table(self.tbl_name2).stat.approxQuantile(["col1", "col3"], 0.1, 0.1)
2041
-
2042
- self.check_error(
2043
- exception=pe.exception,
2044
- error_class="NOT_LIST_OR_TUPLE",
2045
- message_parameters={
2046
- "arg_name": "probabilities",
2047
- "arg_type": "float",
2048
- },
2049
- )
2050
- with self.assertRaises(PySparkTypeError) as pe:
2051
- self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2052
- ["col1", "col3"], [-0.1], 0.1
2053
- )
2054
-
2055
- self.check_error(
2056
- exception=pe.exception,
2057
- error_class="NOT_LIST_OF_FLOAT_OR_INT",
2058
- message_parameters={"arg_name": "probabilities", "arg_type": "float"},
2059
- )
2060
- with self.assertRaises(PySparkTypeError) as pe:
2061
- self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2062
- ["col1", "col3"], [0.1, 0.5, 0.9], "str"
2063
- )
2064
-
2065
- self.check_error(
2066
- exception=pe.exception,
2067
- error_class="NOT_FLOAT_OR_INT",
2068
- message_parameters={
2069
- "arg_name": "relativeError",
2070
- "arg_type": "str",
2071
- },
2072
- )
2073
- with self.assertRaises(PySparkValueError) as pe:
2074
- self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2075
- ["col1", "col3"], [0.1, 0.5, 0.9], -0.1
2076
- )
2077
-
2078
- self.check_error(
2079
- exception=pe.exception,
2080
- error_class="NEGATIVE_VALUE",
2081
- message_parameters={
2082
- "arg_name": "relativeError",
2083
- "arg_value": "-0.1",
2084
- },
2085
- )
2086
-
2087
- def test_stat_freq_items(self):
2088
- # SPARK-41065: Test the stat.freqItems method
2089
- self.assert_eq(
2090
- self.connect.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
2091
- self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
2092
- )
2093
-
2094
- self.assert_eq(
2095
- self.connect.read.table(self.tbl_name2)
2096
- .stat.freqItems(["col1", "col3"], 0.4)
2097
- .toPandas(),
2098
- self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"], 0.4).toPandas(),
2099
- )
2100
-
2101
- with self.assertRaises(PySparkTypeError) as pe:
2102
- self.connect.read.table(self.tbl_name2).stat.freqItems("col1")
2103
-
2104
- self.check_error(
2105
- exception=pe.exception,
2106
- error_class="NOT_LIST_OR_TUPLE",
2107
- message_parameters={
2108
- "arg_name": "cols",
2109
- "arg_type": "str",
2110
- },
2111
- )
2112
-
2113
- def test_stat_sample_by(self):
2114
- # SPARK-41069: Test stat.sample_by
2115
-
2116
- cdf = self.connect.range(0, 100).select((CF.col("id") % 3).alias("key"))
2117
- sdf = self.spark.range(0, 100).select((SF.col("id") % 3).alias("key"))
2118
-
2119
- self.assert_eq(
2120
- cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
2121
- .groupBy("key")
2122
- .agg(CF.count(CF.lit(1)))
2123
- .orderBy("key")
2124
- .toPandas(),
2125
- sdf.sampleBy(sdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
2126
- .groupBy("key")
2127
- .agg(SF.count(SF.lit(1)))
2128
- .orderBy("key")
2129
- .toPandas(),
2130
- )
2131
-
2132
- with self.assertRaises(PySparkTypeError) as pe:
2133
- cdf.stat.sampleBy(cdf.key, fractions={0: 0.1, None: 0.2}, seed=0)
2134
-
2135
- self.check_error(
2136
- exception=pe.exception,
2137
- error_class="DISALLOWED_TYPE_FOR_CONTAINER",
2138
- message_parameters={
2139
- "arg_name": "fractions",
2140
- "arg_type": "dict",
2141
- "allowed_types": "float, int, str",
2142
- "return_type": "NoneType",
2143
- },
2144
- )
2145
-
2146
- with self.assertRaises(SparkConnectException):
2147
- cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 1.2}, seed=0).show()
2148
-
2149
- def test_repr(self):
2150
- # SPARK-41213: Test the __repr__ method
2151
- query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""
2152
- self.assertEqual(
2153
- self.connect.sql(query).__repr__(),
2154
- self.spark.sql(query).__repr__(),
2155
- )
2156
-
2157
- def test_explain_string(self):
2158
- # SPARK-41122: test explain API.
2159
- plan_str = self.connect.sql("SELECT 1")._explain_string(extended=True)
2160
- self.assertTrue("Parsed Logical Plan" in plan_str)
2161
- self.assertTrue("Analyzed Logical Plan" in plan_str)
2162
- self.assertTrue("Optimized Logical Plan" in plan_str)
2163
- self.assertTrue("Physical Plan" in plan_str)
2164
-
2165
- with self.assertRaises(PySparkValueError) as pe:
2166
- self.connect.sql("SELECT 1")._explain_string(mode="unknown")
2167
- self.check_error(
2168
- exception=pe.exception,
2169
- error_class="UNKNOWN_EXPLAIN_MODE",
2170
- message_parameters={"explain_mode": "unknown"},
2171
- )
2172
-
2173
- def test_simple_datasource_read(self) -> None:
2174
- writeDf = self.df_text
2175
- tmpPath = tempfile.mkdtemp()
2176
- shutil.rmtree(tmpPath)
2177
- writeDf.write.text(tmpPath)
2178
-
2179
- for schema in [
2180
- "id STRING",
2181
- StructType([StructField("id", StringType())]),
2182
- ]:
2183
- readDf = self.connect.read.format("text").schema(schema).load(path=tmpPath)
2184
- expectResult = writeDf.collect()
2185
- pandasResult = readDf.toPandas()
2186
- if pandasResult is None:
2187
- self.assertTrue(False, "Empty pandas dataframe")
2188
- else:
2189
- actualResult = pandasResult.values.tolist()
2190
- self.assertEqual(len(expectResult), len(actualResult))
2191
-
2192
- def test_simple_read_without_schema(self) -> None:
2193
- """SPARK-41300: Schema not set when reading CSV."""
2194
- writeDf = self.df_text
2195
- tmpPath = tempfile.mkdtemp()
2196
- shutil.rmtree(tmpPath)
2197
- writeDf.write.csv(tmpPath, header=True)
2198
-
2199
- readDf = self.connect.read.format("csv").option("header", True).load(path=tmpPath)
2200
- expectResult = set(writeDf.collect())
2201
- pandasResult = set(readDf.collect())
2202
- self.assertEqual(expectResult, pandasResult)
2203
-
2204
- def test_count(self) -> None:
2205
- # SPARK-41308: test count() API.
2206
- self.assertEqual(
2207
- self.connect.read.table(self.tbl_name).count(),
2208
- self.spark.read.table(self.tbl_name).count(),
2209
- )
2210
-
2211
- def test_simple_transform(self) -> None:
2212
- """SPARK-41203: Support DF.transform"""
2213
-
2214
- def transform_df(input_df: CDataFrame) -> CDataFrame:
2215
- return input_df.select((CF.col("id") + CF.lit(10)).alias("id"))
2216
-
2217
- df = self.connect.range(1, 100)
2218
- result_left = df.transform(transform_df).collect()
2219
- result_right = self.connect.range(11, 110).collect()
2220
- self.assertEqual(result_right, result_left)
2221
-
2222
- # Check assertion.
2223
- with self.assertRaises(AssertionError):
2224
- df.transform(lambda x: 2) # type: ignore
2225
-
2226
- def test_alias(self) -> None:
2227
- """Testing supported and unsupported alias"""
2228
- col0 = (
2229
- self.connect.range(1, 10)
2230
- .select(CF.col("id").alias("name", metadata={"max": 99}))
2231
- .schema.names[0]
2232
- )
2233
- self.assertEqual("name", col0)
2234
-
2235
- with self.assertRaises(SparkConnectException) as exc:
2236
- self.connect.range(1, 10).select(CF.col("id").alias("this", "is", "not")).collect()
2237
- self.assertIn("(this, is, not)", str(exc.exception))
2238
-
2239
- def test_column_regexp(self) -> None:
2240
- # SPARK-41438: test dataframe.colRegex()
2241
- ndf = self.connect.read.table(self.tbl_name3)
2242
- df = self.spark.read.table(self.tbl_name3)
2243
-
2244
- self.assert_eq(
2245
- ndf.select(ndf.colRegex("`tes.*\n.*mn`")).toPandas(),
2246
- df.select(df.colRegex("`tes.*\n.*mn`")).toPandas(),
2247
- )
2248
-
2249
- def test_repartition(self) -> None:
2250
- # SPARK-41354: test dataframe.repartition(numPartitions)
2251
- self.assert_eq(
2252
- self.connect.read.table(self.tbl_name).repartition(10).toPandas(),
2253
- self.spark.read.table(self.tbl_name).repartition(10).toPandas(),
2254
- )
2255
-
2256
- self.assert_eq(
2257
- self.connect.read.table(self.tbl_name).coalesce(10).toPandas(),
2258
- self.spark.read.table(self.tbl_name).coalesce(10).toPandas(),
2259
- )
2260
-
2261
- def test_repartition_by_expression(self) -> None:
2262
- # SPARK-41354: test dataframe.repartition(expressions)
2263
- self.assert_eq(
2264
- self.connect.read.table(self.tbl_name).repartition(10, "id").toPandas(),
2265
- self.spark.read.table(self.tbl_name).repartition(10, "id").toPandas(),
2266
- )
2267
-
2268
- self.assert_eq(
2269
- self.connect.read.table(self.tbl_name).repartition("id").toPandas(),
2270
- self.spark.read.table(self.tbl_name).repartition("id").toPandas(),
2271
- )
2272
-
2273
- # repartition with unsupported parameter values
2274
- with self.assertRaises(AnalysisException):
2275
- self.connect.read.table(self.tbl_name).repartition("id+1").toPandas()
2276
-
2277
- def test_repartition_by_range(self) -> None:
2278
- # SPARK-41354: test dataframe.repartitionByRange(expressions)
2279
- cdf = self.connect.read.table(self.tbl_name)
2280
- sdf = self.spark.read.table(self.tbl_name)
2281
-
2282
- self.assert_eq(
2283
- cdf.repartitionByRange(10, "id").toPandas(),
2284
- sdf.repartitionByRange(10, "id").toPandas(),
2285
- )
2286
-
2287
- self.assert_eq(
2288
- cdf.repartitionByRange("id").toPandas(),
2289
- sdf.repartitionByRange("id").toPandas(),
2290
- )
2291
-
2292
- self.assert_eq(
2293
- cdf.repartitionByRange(cdf.id.desc()).toPandas(),
2294
- sdf.repartitionByRange(sdf.id.desc()).toPandas(),
2295
- )
2296
-
2297
- # repartitionByRange with unsupported parameter values
2298
- with self.assertRaises(AnalysisException):
2299
- self.connect.read.table(self.tbl_name).repartitionByRange("id+1").toPandas()
2300
-
2301
- def test_agg_with_two_agg_exprs(self) -> None:
2302
- # SPARK-41230: test dataframe.agg()
2303
- self.assert_eq(
2304
- self.connect.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
2305
- self.spark.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
2306
- )
2307
-
2308
- def test_subtract(self):
2309
- # SPARK-41453: test dataframe.subtract()
2310
- ndf1 = self.connect.read.table(self.tbl_name)
2311
- ndf2 = ndf1.filter("id > 3")
2312
- df1 = self.spark.read.table(self.tbl_name)
2313
- df2 = df1.filter("id > 3")
2314
-
2315
- self.assert_eq(
2316
- ndf1.subtract(ndf2).toPandas(),
2317
- df1.subtract(df2).toPandas(),
2318
- )
2319
-
2320
- def test_write_operations(self):
2321
- with tempfile.TemporaryDirectory() as d:
2322
- df = self.connect.range(50)
2323
- df.write.mode("overwrite").format("csv").save(d)
2324
-
2325
- ndf = self.connect.read.schema("id int").load(d, format="csv")
2326
- self.assertEqual(50, len(ndf.collect()))
2327
- cd = ndf.collect()
2328
- self.assertEqual(set(df.collect()), set(cd))
2329
-
2330
- with tempfile.TemporaryDirectory() as d:
2331
- df = self.connect.range(50)
2332
- df.write.mode("overwrite").csv(d, lineSep="|")
2333
-
2334
- ndf = self.connect.read.schema("id int").load(d, format="csv", lineSep="|")
2335
- self.assertEqual(set(df.collect()), set(ndf.collect()))
2336
-
2337
- df = self.connect.range(50)
2338
- df.write.format("parquet").saveAsTable("parquet_test")
2339
-
2340
- ndf = self.connect.read.table("parquet_test")
2341
- self.assertEqual(set(df.collect()), set(ndf.collect()))
2342
-
2343
- def test_writeTo_operations(self):
2344
- # SPARK-42002: Implement DataFrameWriterV2
2345
- import datetime
2346
- from pyspark.sql.connect.functions import col, years, months, days, hours, bucket
2347
-
2348
- df = self.connect.createDataFrame(
2349
- [(1, datetime.datetime(2000, 1, 1), "foo")], ("id", "ts", "value")
2350
- )
2351
- writer = df.writeTo("table1")
2352
- self.assertIsInstance(writer.option("property", "value"), DataFrameWriterV2)
2353
- self.assertIsInstance(writer.options(property="value"), DataFrameWriterV2)
2354
- self.assertIsInstance(writer.using("source"), DataFrameWriterV2)
2355
- self.assertIsInstance(writer.partitionedBy(col("id")), DataFrameWriterV2)
2356
- self.assertIsInstance(writer.tableProperty("foo", "bar"), DataFrameWriterV2)
2357
- self.assertIsInstance(writer.partitionedBy(years("ts")), DataFrameWriterV2)
2358
- self.assertIsInstance(writer.partitionedBy(months("ts")), DataFrameWriterV2)
2359
- self.assertIsInstance(writer.partitionedBy(days("ts")), DataFrameWriterV2)
2360
- self.assertIsInstance(writer.partitionedBy(hours("ts")), DataFrameWriterV2)
2361
- self.assertIsInstance(writer.partitionedBy(bucket(11, "id")), DataFrameWriterV2)
2362
- self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours("ts")), DataFrameWriterV2)
2363
-
2364
- def test_agg_with_avg(self):
2365
- # SPARK-41325: groupby.avg()
2366
- df = (
2367
- self.connect.range(10)
2368
- .groupBy((CF.col("id") % CF.lit(2)).alias("moded"))
2369
- .avg("id")
2370
- .sort("moded")
2371
- )
2372
- res = df.collect()
2373
- self.assertEqual(2, len(res))
2374
- self.assertEqual(4.0, res[0][1])
2375
- self.assertEqual(5.0, res[1][1])
2376
-
2377
- # Additional GroupBy tests with 3 rows
2378
-
2379
- df_a = self.connect.range(10).groupBy((CF.col("id") % CF.lit(3)).alias("moded"))
2380
- df_b = self.spark.range(10).groupBy((SF.col("id") % SF.lit(3)).alias("moded"))
2381
- self.assertEqual(
2382
- set(df_b.agg(SF.sum("id")).collect()), set(df_a.agg(CF.sum("id")).collect())
2383
- )
2384
-
2385
- # Dict agg
2386
- measures = {"id": "sum"}
2387
- self.assertEqual(
2388
- set(df_a.agg(measures).select("sum(id)").collect()),
2389
- set(df_b.agg(measures).select("sum(id)").collect()),
2390
- )
2391
-
2392
- def test_column_cannot_be_constructed_from_string(self):
2393
- with self.assertRaises(TypeError):
2394
- Column("col")
2395
-
2396
- def test_crossjoin(self):
2397
- # SPARK-41227: Test CrossJoin
2398
- connect_df = self.connect.read.table(self.tbl_name)
2399
- spark_df = self.spark.read.table(self.tbl_name)
2400
- self.assert_eq(
2401
- set(
2402
- connect_df.select("id")
2403
- .join(other=connect_df.select("name"), how="cross")
2404
- .toPandas()
2405
- ),
2406
- set(spark_df.select("id").join(other=spark_df.select("name"), how="cross").toPandas()),
2407
- )
2408
- self.assert_eq(
2409
- set(connect_df.select("id").crossJoin(other=connect_df.select("name")).toPandas()),
2410
- set(spark_df.select("id").crossJoin(other=spark_df.select("name")).toPandas()),
2411
- )
2412
-
2413
- def test_grouped_data(self):
2414
- query = """
2415
- SELECT * FROM VALUES
2416
- ('James', 'Sales', 3000, 2020),
2417
- ('Michael', 'Sales', 4600, 2020),
2418
- ('Robert', 'Sales', 4100, 2020),
2419
- ('Maria', 'Finance', 3000, 2020),
2420
- ('James', 'Sales', 3000, 2019),
2421
- ('Scott', 'Finance', 3300, 2020),
2422
- ('Jen', 'Finance', 3900, 2020),
2423
- ('Jeff', 'Marketing', 3000, 2020),
2424
- ('Kumar', 'Marketing', 2000, 2020),
2425
- ('Saif', 'Sales', 4100, 2020)
2426
- AS T(name, department, salary, year)
2427
- """
2428
-
2429
- # +-------+----------+------+----+
2430
- # | name|department|salary|year|
2431
- # +-------+----------+------+----+
2432
- # | James| Sales| 3000|2020|
2433
- # |Michael| Sales| 4600|2020|
2434
- # | Robert| Sales| 4100|2020|
2435
- # | Maria| Finance| 3000|2020|
2436
- # | James| Sales| 3000|2019|
2437
- # | Scott| Finance| 3300|2020|
2438
- # | Jen| Finance| 3900|2020|
2439
- # | Jeff| Marketing| 3000|2020|
2440
- # | Kumar| Marketing| 2000|2020|
2441
- # | Saif| Sales| 4100|2020|
2442
- # +-------+----------+------+----+
2443
-
2444
- cdf = self.connect.sql(query)
2445
- sdf = self.spark.sql(query)
2446
-
2447
- # test groupby
2448
- self.assert_eq(
2449
- cdf.groupBy("name").agg(CF.sum(cdf.salary)).toPandas(),
2450
- sdf.groupBy("name").agg(SF.sum(sdf.salary)).toPandas(),
2451
- )
2452
- self.assert_eq(
2453
- cdf.groupBy("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
2454
- sdf.groupBy("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
2455
- )
2456
-
2457
- # test rollup
2458
- self.assert_eq(
2459
- cdf.rollup("name").agg(CF.sum(cdf.salary)).toPandas(),
2460
- sdf.rollup("name").agg(SF.sum(sdf.salary)).toPandas(),
2461
- )
2462
- self.assert_eq(
2463
- cdf.rollup("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
2464
- sdf.rollup("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
2465
- )
2466
-
2467
- # test cube
2468
- self.assert_eq(
2469
- cdf.cube("name").agg(CF.sum(cdf.salary)).toPandas(),
2470
- sdf.cube("name").agg(SF.sum(sdf.salary)).toPandas(),
2471
- )
2472
- self.assert_eq(
2473
- cdf.cube("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
2474
- sdf.cube("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
2475
- )
2476
-
2477
- # test pivot
2478
- # pivot with values
2479
- self.assert_eq(
2480
- cdf.groupBy("name")
2481
- .pivot("department", ["Sales", "Marketing"])
2482
- .agg(CF.sum(cdf.salary))
2483
- .toPandas(),
2484
- sdf.groupBy("name")
2485
- .pivot("department", ["Sales", "Marketing"])
2486
- .agg(SF.sum(sdf.salary))
2487
- .toPandas(),
2488
- )
2489
- self.assert_eq(
2490
- cdf.groupBy(cdf.name)
2491
- .pivot("department", ["Sales", "Finance", "Marketing"])
2492
- .agg(CF.sum(cdf.salary))
2493
- .toPandas(),
2494
- sdf.groupBy(sdf.name)
2495
- .pivot("department", ["Sales", "Finance", "Marketing"])
2496
- .agg(SF.sum(sdf.salary))
2497
- .toPandas(),
2498
- )
2499
- self.assert_eq(
2500
- cdf.groupBy(cdf.name)
2501
- .pivot("department", ["Sales", "Finance", "Unknown"])
2502
- .agg(CF.sum(cdf.salary))
2503
- .toPandas(),
2504
- sdf.groupBy(sdf.name)
2505
- .pivot("department", ["Sales", "Finance", "Unknown"])
2506
- .agg(SF.sum(sdf.salary))
2507
- .toPandas(),
2508
- )
2509
-
2510
- # pivot without values
2511
- self.assert_eq(
2512
- cdf.groupBy("name").pivot("department").agg(CF.sum(cdf.salary)).toPandas(),
2513
- sdf.groupBy("name").pivot("department").agg(SF.sum(sdf.salary)).toPandas(),
2514
- )
2515
-
2516
- self.assert_eq(
2517
- cdf.groupBy("name").pivot("year").agg(CF.sum(cdf.salary)).toPandas(),
2518
- sdf.groupBy("name").pivot("year").agg(SF.sum(sdf.salary)).toPandas(),
2519
- )
2520
-
2521
- # check error
2522
- with self.assertRaisesRegex(
2523
- Exception,
2524
- "PIVOT after ROLLUP is not supported",
2525
- ):
2526
- cdf.rollup("name").pivot("department").agg(CF.sum(cdf.salary))
2527
-
2528
- with self.assertRaisesRegex(
2529
- Exception,
2530
- "PIVOT after CUBE is not supported",
2531
- ):
2532
- cdf.cube("name").pivot("department").agg(CF.sum(cdf.salary))
2533
-
2534
- with self.assertRaisesRegex(
2535
- Exception,
2536
- "Repeated PIVOT operation is not supported",
2537
- ):
2538
- cdf.groupBy("name").pivot("year").pivot("year").agg(CF.sum(cdf.salary))
2539
-
2540
- with self.assertRaises(PySparkTypeError) as pe:
2541
- cdf.groupBy("name").pivot("department", ["Sales", b"Marketing"]).agg(CF.sum(cdf.salary))
2542
-
2543
- self.check_error(
2544
- exception=pe.exception,
2545
- error_class="NOT_BOOL_OR_FLOAT_OR_INT_OR_STR",
2546
- message_parameters={
2547
- "arg_name": "value",
2548
- "arg_type": "bytes",
2549
- },
2550
- )
2551
-
2552
- def test_numeric_aggregation(self):
2553
- # SPARK-41737: test numeric aggregation
2554
- query = """
2555
- SELECT * FROM VALUES
2556
- ('James', 'Sales', 3000, 2020),
2557
- ('Michael', 'Sales', 4600, 2020),
2558
- ('Robert', 'Sales', 4100, 2020),
2559
- ('Maria', 'Finance', 3000, 2020),
2560
- ('James', 'Sales', 3000, 2019),
2561
- ('Scott', 'Finance', 3300, 2020),
2562
- ('Jen', 'Finance', 3900, 2020),
2563
- ('Jeff', 'Marketing', 3000, 2020),
2564
- ('Kumar', 'Marketing', 2000, 2020),
2565
- ('Saif', 'Sales', 4100, 2020)
2566
- AS T(name, department, salary, year)
2567
- """
2568
-
2569
- # +-------+----------+------+----+
2570
- # | name|department|salary|year|
2571
- # +-------+----------+------+----+
2572
- # | James| Sales| 3000|2020|
2573
- # |Michael| Sales| 4600|2020|
2574
- # | Robert| Sales| 4100|2020|
2575
- # | Maria| Finance| 3000|2020|
2576
- # | James| Sales| 3000|2019|
2577
- # | Scott| Finance| 3300|2020|
2578
- # | Jen| Finance| 3900|2020|
2579
- # | Jeff| Marketing| 3000|2020|
2580
- # | Kumar| Marketing| 2000|2020|
2581
- # | Saif| Sales| 4100|2020|
2582
- # +-------+----------+------+----+
2583
-
2584
- cdf = self.connect.sql(query)
2585
- sdf = self.spark.sql(query)
2586
-
2587
- # test groupby
2588
- self.assert_eq(
2589
- cdf.groupBy("name").min().toPandas(),
2590
- sdf.groupBy("name").min().toPandas(),
2591
- )
2592
- self.assert_eq(
2593
- cdf.groupBy("name").min("salary").toPandas(),
2594
- sdf.groupBy("name").min("salary").toPandas(),
2595
- )
2596
- self.assert_eq(
2597
- cdf.groupBy("name").max("salary").toPandas(),
2598
- sdf.groupBy("name").max("salary").toPandas(),
2599
- )
2600
- self.assert_eq(
2601
- cdf.groupBy("name", cdf.department).avg("salary", "year").toPandas(),
2602
- sdf.groupBy("name", sdf.department).avg("salary", "year").toPandas(),
2603
- )
2604
- self.assert_eq(
2605
- cdf.groupBy("name", cdf.department).mean("salary", "year").toPandas(),
2606
- sdf.groupBy("name", sdf.department).mean("salary", "year").toPandas(),
2607
- )
2608
- self.assert_eq(
2609
- cdf.groupBy("name", cdf.department).sum("salary", "year").toPandas(),
2610
- sdf.groupBy("name", sdf.department).sum("salary", "year").toPandas(),
2611
- )
2612
-
2613
- # test rollup
2614
- self.assert_eq(
2615
- cdf.rollup("name").max().toPandas(),
2616
- sdf.rollup("name").max().toPandas(),
2617
- )
2618
- self.assert_eq(
2619
- cdf.rollup("name").min("salary").toPandas(),
2620
- sdf.rollup("name").min("salary").toPandas(),
2621
- )
2622
- self.assert_eq(
2623
- cdf.rollup("name").max("salary").toPandas(),
2624
- sdf.rollup("name").max("salary").toPandas(),
2625
- )
2626
- self.assert_eq(
2627
- cdf.rollup("name", cdf.department).avg("salary", "year").toPandas(),
2628
- sdf.rollup("name", sdf.department).avg("salary", "year").toPandas(),
2629
- )
2630
- self.assert_eq(
2631
- cdf.rollup("name", cdf.department).mean("salary", "year").toPandas(),
2632
- sdf.rollup("name", sdf.department).mean("salary", "year").toPandas(),
2633
- )
2634
- self.assert_eq(
2635
- cdf.rollup("name", cdf.department).sum("salary", "year").toPandas(),
2636
- sdf.rollup("name", sdf.department).sum("salary", "year").toPandas(),
2637
- )
2638
-
2639
- # test cube
2640
- self.assert_eq(
2641
- cdf.cube("name").avg().toPandas(),
2642
- sdf.cube("name").avg().toPandas(),
2643
- )
2644
- self.assert_eq(
2645
- cdf.cube("name").mean().toPandas(),
2646
- sdf.cube("name").mean().toPandas(),
2647
- )
2648
- self.assert_eq(
2649
- cdf.cube("name").min("salary").toPandas(),
2650
- sdf.cube("name").min("salary").toPandas(),
2651
- )
2652
- self.assert_eq(
2653
- cdf.cube("name").max("salary").toPandas(),
2654
- sdf.cube("name").max("salary").toPandas(),
2655
- )
2656
- self.assert_eq(
2657
- cdf.cube("name", cdf.department).avg("salary", "year").toPandas(),
2658
- sdf.cube("name", sdf.department).avg("salary", "year").toPandas(),
2659
- )
2660
- self.assert_eq(
2661
- cdf.cube("name", cdf.department).sum("salary", "year").toPandas(),
2662
- sdf.cube("name", sdf.department).sum("salary", "year").toPandas(),
2663
- )
2664
-
2665
- # test pivot
2666
- # pivot with values
2667
- self.assert_eq(
2668
- cdf.groupBy("name").pivot("department", ["Sales", "Marketing"]).sum().toPandas(),
2669
- sdf.groupBy("name").pivot("department", ["Sales", "Marketing"]).sum().toPandas(),
2670
- )
2671
- self.assert_eq(
2672
- cdf.groupBy("name")
2673
- .pivot("department", ["Sales", "Marketing"])
2674
- .min("salary")
2675
- .toPandas(),
2676
- sdf.groupBy("name")
2677
- .pivot("department", ["Sales", "Marketing"])
2678
- .min("salary")
2679
- .toPandas(),
2680
- )
2681
- self.assert_eq(
2682
- cdf.groupBy("name")
2683
- .pivot("department", ["Sales", "Marketing"])
2684
- .max("salary")
2685
- .toPandas(),
2686
- sdf.groupBy("name")
2687
- .pivot("department", ["Sales", "Marketing"])
2688
- .max("salary")
2689
- .toPandas(),
2690
- )
2691
- self.assert_eq(
2692
- cdf.groupBy(cdf.name)
2693
- .pivot("department", ["Sales", "Finance", "Unknown"])
2694
- .avg("salary", "year")
2695
- .toPandas(),
2696
- sdf.groupBy(sdf.name)
2697
- .pivot("department", ["Sales", "Finance", "Unknown"])
2698
- .avg("salary", "year")
2699
- .toPandas(),
2700
- )
2701
- self.assert_eq(
2702
- cdf.groupBy(cdf.name)
2703
- .pivot("department", ["Sales", "Finance", "Unknown"])
2704
- .sum("salary", "year")
2705
- .toPandas(),
2706
- sdf.groupBy(sdf.name)
2707
- .pivot("department", ["Sales", "Finance", "Unknown"])
2708
- .sum("salary", "year")
2709
- .toPandas(),
2710
- )
2711
-
2712
- # pivot without values
2713
- self.assert_eq(
2714
- cdf.groupBy("name").pivot("department").min().toPandas(),
2715
- sdf.groupBy("name").pivot("department").min().toPandas(),
2716
- )
2717
- self.assert_eq(
2718
- cdf.groupBy("name").pivot("department").min("salary").toPandas(),
2719
- sdf.groupBy("name").pivot("department").min("salary").toPandas(),
2720
- )
2721
- self.assert_eq(
2722
- cdf.groupBy("name").pivot("department").max("salary").toPandas(),
2723
- sdf.groupBy("name").pivot("department").max("salary").toPandas(),
2724
- )
2725
- self.assert_eq(
2726
- cdf.groupBy(cdf.name).pivot("department").avg("salary", "year").toPandas(),
2727
- sdf.groupBy(sdf.name).pivot("department").avg("salary", "year").toPandas(),
2728
- )
2729
- self.assert_eq(
2730
- cdf.groupBy(cdf.name).pivot("department").sum("salary", "year").toPandas(),
2731
- sdf.groupBy(sdf.name).pivot("department").sum("salary", "year").toPandas(),
2732
- )
2733
-
2734
- # check error
2735
- with self.assertRaisesRegex(
2736
- TypeError,
2737
- "Numeric aggregation function can only be applied on numeric columns",
2738
- ):
2739
- cdf.groupBy("name").min("department").show()
2740
-
2741
- with self.assertRaisesRegex(
2742
- TypeError,
2743
- "Numeric aggregation function can only be applied on numeric columns",
2744
- ):
2745
- cdf.groupBy("name").max("salary", "department").show()
2746
-
2747
- with self.assertRaisesRegex(
2748
- TypeError,
2749
- "Numeric aggregation function can only be applied on numeric columns",
2750
- ):
2751
- cdf.rollup("name").avg("department").show()
2752
-
2753
- with self.assertRaisesRegex(
2754
- TypeError,
2755
- "Numeric aggregation function can only be applied on numeric columns",
2756
- ):
2757
- cdf.rollup("name").sum("salary", "department").show()
2758
-
2759
- with self.assertRaisesRegex(
2760
- TypeError,
2761
- "Numeric aggregation function can only be applied on numeric columns",
2762
- ):
2763
- cdf.cube("name").min("department").show()
2764
-
2765
- with self.assertRaisesRegex(
2766
- TypeError,
2767
- "Numeric aggregation function can only be applied on numeric columns",
2768
- ):
2769
- cdf.cube("name").max("salary", "department").show()
2770
-
2771
- with self.assertRaisesRegex(
2772
- TypeError,
2773
- "Numeric aggregation function can only be applied on numeric columns",
2774
- ):
2775
- cdf.groupBy("name").pivot("department").avg("department").show()
2776
-
2777
- with self.assertRaisesRegex(
2778
- TypeError,
2779
- "Numeric aggregation function can only be applied on numeric columns",
2780
- ):
2781
- cdf.groupBy("name").pivot("department").sum("salary", "department").show()
2782
-
2783
- def test_with_metadata(self):
2784
- cdf = self.connect.createDataFrame(data=[(2, "Alice"), (5, "Bob")], schema=["age", "name"])
2785
- self.assertEqual(cdf.schema["age"].metadata, {})
2786
- self.assertEqual(cdf.schema["name"].metadata, {})
2787
-
2788
- cdf1 = cdf.withMetadata(columnName="age", metadata={"max_age": 5})
2789
- self.assertEqual(cdf1.schema["age"].metadata, {"max_age": 5})
2790
-
2791
- cdf2 = cdf.withMetadata(columnName="name", metadata={"names": ["Alice", "Bob"]})
2792
- self.assertEqual(cdf2.schema["name"].metadata, {"names": ["Alice", "Bob"]})
2793
-
2794
- with self.assertRaises(PySparkTypeError) as pe:
2795
- cdf.withMetadata(columnName="name", metadata=["magic"])
2796
-
2797
- self.check_error(
2798
- exception=pe.exception,
2799
- error_class="NOT_DICT",
2800
- message_parameters={
2801
- "arg_name": "metadata",
2802
- "arg_type": "list",
2803
- },
2804
- )
2805
-
2806
- def test_collect_nested_type(self):
2807
- query = """
2808
- SELECT * FROM VALUES
2809
- (1, 4, 0, 8, true, true, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
2810
- (2, 5, -1, NULL, false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
2811
- (3, 6, NULL, 0, false, NULL, ARRAY(NULL), NULL)
2812
- AS tab(a, b, c, d, e, f, g, h)
2813
- """
2814
-
2815
- # +---+---+----+----+-----+----+------------+-------------------+
2816
- # | a| b| c| d| e| f| g| h|
2817
- # +---+---+----+----+-----+----+------------+-------------------+
2818
- # | 1| 4| 0| 8| true|true|[1, null, 3]| {1 -> 2, 3 -> 4}|
2819
- # | 2| 5| -1|NULL|false|NULL| [1, 3]|{1 -> null, 3 -> 4}|
2820
- # | 3| 6|NULL| 0|false|NULL| [null]| NULL|
2821
- # +---+---+----+----+-----+----+------------+-------------------+
2822
-
2823
- cdf = self.connect.sql(query)
2824
- sdf = self.spark.sql(query)
2825
-
2826
- # test collect array
2827
- # +--------------+-------------+------------+
2828
- # |array(a, b, c)| array(e, f)| g|
2829
- # +--------------+-------------+------------+
2830
- # | [1, 4, 0]| [true, true]|[1, null, 3]|
2831
- # | [2, 5, -1]|[false, null]| [1, 3]|
2832
- # | [3, 6, null]|[false, null]| [null]|
2833
- # +--------------+-------------+------------+
2834
- self.assertEqual(
2835
- cdf.select(CF.array("a", "b", "c"), CF.array("e", "f"), CF.col("g")).collect(),
2836
- sdf.select(SF.array("a", "b", "c"), SF.array("e", "f"), SF.col("g")).collect(),
2837
- )
2838
-
2839
- # test collect nested array
2840
- # +-----------------------------------+-------------------------+
2841
- # |array(array(a), array(b), array(c))|array(array(e), array(f))|
2842
- # +-----------------------------------+-------------------------+
2843
- # | [[1], [4], [0]]| [[true], [true]]|
2844
- # | [[2], [5], [-1]]| [[false], [null]]|
2845
- # | [[3], [6], [null]]| [[false], [null]]|
2846
- # +-----------------------------------+-------------------------+
2847
- self.assertEqual(
2848
- cdf.select(
2849
- CF.array(CF.array("a"), CF.array("b"), CF.array("c")),
2850
- CF.array(CF.array("e"), CF.array("f")),
2851
- ).collect(),
2852
- sdf.select(
2853
- SF.array(SF.array("a"), SF.array("b"), SF.array("c")),
2854
- SF.array(SF.array("e"), SF.array("f")),
2855
- ).collect(),
2856
- )
2857
-
2858
- # test collect array of struct, map
2859
- # +----------------+---------------------+
2860
- # |array(struct(a))| array(h)|
2861
- # +----------------+---------------------+
2862
- # | [{1}]| [{1 -> 2, 3 -> 4}]|
2863
- # | [{2}]|[{1 -> null, 3 -> 4}]|
2864
- # | [{3}]| [null]|
2865
- # +----------------+---------------------+
2866
- self.assertEqual(
2867
- cdf.select(CF.array(CF.struct("a")), CF.array("h")).collect(),
2868
- sdf.select(SF.array(SF.struct("a")), SF.array("h")).collect(),
2869
- )
2870
-
2871
- # test collect map
2872
- # +-------------------+-------------------+
2873
- # | h| map(a, b, b, c)|
2874
- # +-------------------+-------------------+
2875
- # | {1 -> 2, 3 -> 4}| {1 -> 4, 4 -> 0}|
2876
- # |{1 -> null, 3 -> 4}| {2 -> 5, 5 -> -1}|
2877
- # | NULL|{3 -> 6, 6 -> null}|
2878
- # +-------------------+-------------------+
2879
- self.assertEqual(
2880
- cdf.select(CF.col("h"), CF.create_map("a", "b", "b", "c")).collect(),
2881
- sdf.select(SF.col("h"), SF.create_map("a", "b", "b", "c")).collect(),
2882
- )
2883
-
2884
- # test collect map of struct, array
2885
- # +-------------------+------------------------+
2886
- # | map(a, g)| map(a, struct(b, g))|
2887
- # +-------------------+------------------------+
2888
- # |{1 -> [1, null, 3]}|{1 -> {4, [1, null, 3]}}|
2889
- # | {2 -> [1, 3]}| {2 -> {5, [1, 3]}}|
2890
- # | {3 -> [null]}| {3 -> {6, [null]}}|
2891
- # +-------------------+------------------------+
2892
- self.assertEqual(
2893
- cdf.select(CF.create_map("a", "g"), CF.create_map("a", CF.struct("b", "g"))).collect(),
2894
- sdf.select(SF.create_map("a", "g"), SF.create_map("a", SF.struct("b", "g"))).collect(),
2895
- )
2896
-
2897
- # test collect struct
2898
- # +------------------+--------------------------+
2899
- # |struct(a, b, c, d)| struct(e, f, g)|
2900
- # +------------------+--------------------------+
2901
- # | {1, 4, 0, 8}|{true, true, [1, null, 3]}|
2902
- # | {2, 5, -1, null}| {false, null, [1, 3]}|
2903
- # | {3, 6, null, 0}| {false, null, [null]}|
2904
- # +------------------+--------------------------+
2905
- self.assertEqual(
2906
- cdf.select(CF.struct("a", "b", "c", "d"), CF.struct("e", "f", "g")).collect(),
2907
- sdf.select(SF.struct("a", "b", "c", "d"), SF.struct("e", "f", "g")).collect(),
2908
- )
2909
-
2910
- # test collect nested struct
2911
- # +------------------------------------------+--------------------------+----------------------------+ # noqa
2912
- # |struct(a, struct(a, struct(c, struct(d))))|struct(a, b, struct(c, d))| struct(e, f, struct(g))| # noqa
2913
- # +------------------------------------------+--------------------------+----------------------------+ # noqa
2914
- # | {1, {1, {0, {8}}}}| {1, 4, {0, 8}}|{true, true, {[1, null, 3]}}| # noqa
2915
- # | {2, {2, {-1, {null}}}}| {2, 5, {-1, null}}| {false, null, {[1, 3]}}| # noqa
2916
- # | {3, {3, {null, {0}}}}| {3, 6, {null, 0}}| {false, null, {[null]}}| # noqa
2917
- # +------------------------------------------+--------------------------+----------------------------+ # noqa
2918
- self.assertEqual(
2919
- cdf.select(
2920
- CF.struct("a", CF.struct("a", CF.struct("c", CF.struct("d")))),
2921
- CF.struct("a", "b", CF.struct("c", "d")),
2922
- CF.struct("e", "f", CF.struct("g")),
2923
- ).collect(),
2924
- sdf.select(
2925
- SF.struct("a", SF.struct("a", SF.struct("c", SF.struct("d")))),
2926
- SF.struct("a", "b", SF.struct("c", "d")),
2927
- SF.struct("e", "f", SF.struct("g")),
2928
- ).collect(),
2929
- )
2930
-
2931
- # test collect struct containing array, map
2932
- # +--------------------------------------------+
2933
- # | struct(a, struct(a, struct(g, struct(h))))|
2934
- # +--------------------------------------------+
2935
- # |{1, {1, {[1, null, 3], {{1 -> 2, 3 -> 4}}}}}|
2936
- # | {2, {2, {[1, 3], {{1 -> null, 3 -> 4}}}}}|
2937
- # | {3, {3, {[null], {null}}}}|
2938
- # +--------------------------------------------+
2939
- self.assertEqual(
2940
- cdf.select(
2941
- CF.struct("a", CF.struct("a", CF.struct("g", CF.struct("h")))),
2942
- ).collect(),
2943
- sdf.select(
2944
- SF.struct("a", SF.struct("a", SF.struct("g", SF.struct("h")))),
2945
- ).collect(),
2946
- )
2947
-
2948
- def test_simple_udt(self):
2949
- from pyspark.ml.linalg import MatrixUDT, VectorUDT
2950
-
2951
- for schema in [
2952
- StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2953
- StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())),
2954
- StructType().add("key", LongType()).add("val", MapType(LongType(), PythonOnlyUDT())),
2955
- StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2956
- StructType().add("key", LongType()).add("vec", VectorUDT()),
2957
- StructType().add("key", LongType()).add("mat", MatrixUDT()),
2958
- ]:
2959
- cdf = self.connect.createDataFrame(data=[], schema=schema)
2960
- sdf = self.spark.createDataFrame(data=[], schema=schema)
2961
-
2962
- self.assertEqual(cdf.schema, sdf.schema)
2963
-
2964
- def test_simple_udt_from_read(self):
2965
- from pyspark.ml.linalg import Matrices, Vectors
2966
-
2967
- with tempfile.TemporaryDirectory() as d:
2968
- path1 = f"{d}/df1.parquet"
2969
- self.spark.createDataFrame(
2970
- [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
2971
- schema=StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2972
- ).write.parquet(path1)
2973
-
2974
- path2 = f"{d}/df2.parquet"
2975
- self.spark.createDataFrame(
2976
- [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
2977
- schema=StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())),
2978
- ).write.parquet(path2)
2979
-
2980
- path3 = f"{d}/df3.parquet"
2981
- self.spark.createDataFrame(
2982
- [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
2983
- schema=StructType()
2984
- .add("key", LongType())
2985
- .add("val", MapType(LongType(), PythonOnlyUDT())),
2986
- ).write.parquet(path3)
2987
-
2988
- path4 = f"{d}/df4.parquet"
2989
- self.spark.createDataFrame(
2990
- [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
2991
- schema=StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2992
- ).write.parquet(path4)
2993
-
2994
- path5 = f"{d}/df5.parquet"
2995
- self.spark.createDataFrame(
2996
- [Row(label=1.0, point=ExamplePoint(1.0, 2.0))]
2997
- ).write.parquet(path5)
2998
-
2999
- path6 = f"{d}/df6.parquet"
3000
- self.spark.createDataFrame(
3001
- [(Vectors.dense(1.0, 2.0, 3.0),), (Vectors.sparse(3, {1: 1.0, 2: 5.5}),)],
3002
- ["vec"],
3003
- ).write.parquet(path6)
3004
-
3005
- path7 = f"{d}/df7.parquet"
3006
- self.spark.createDataFrame(
3007
- [
3008
- (Matrices.dense(3, 2, [0, 1, 4, 5, 9, 10]),),
3009
- (Matrices.sparse(1, 1, [0, 1], [0], [2.0]),),
3010
- ],
3011
- ["mat"],
3012
- ).write.parquet(path7)
3013
-
3014
- for path in [path1, path2, path3, path4, path5, path6, path7]:
3015
- self.assertEqual(
3016
- self.connect.read.parquet(path).schema,
3017
- self.spark.read.parquet(path).schema,
3018
- )
3019
-
3020
- def test_version(self):
3021
- self.assertEqual(
3022
- self.connect.version,
3023
- self.spark.version,
3024
- )
3025
-
3026
- def test_same_semantics(self):
3027
- plan = self.connect.sql("SELECT 1")
3028
- other = self.connect.sql("SELECT 1")
3029
- self.assertTrue(plan.sameSemantics(other))
3030
-
3031
- def test_semantic_hash(self):
3032
- plan = self.connect.sql("SELECT 1")
3033
- other = self.connect.sql("SELECT 1")
3034
- self.assertEqual(
3035
- plan.semanticHash(),
3036
- other.semanticHash(),
3037
- )
3038
-
3039
- def test_unsupported_functions(self):
3040
- # SPARK-41225: Disable unsupported functions.
3041
- df = self.connect.read.table(self.tbl_name)
3042
- for f in (
3043
- "rdd",
3044
- "foreach",
3045
- "foreachPartition",
3046
- "checkpoint",
3047
- "localCheckpoint",
3048
- ):
3049
- with self.assertRaises(NotImplementedError):
3050
- getattr(df, f)()
3051
-
3052
- def test_unsupported_session_functions(self):
3053
- # SPARK-41934: Disable unsupported functions.
3054
-
3055
- with self.assertRaises(NotImplementedError):
3056
- RemoteSparkSession.builder.enableHiveSupport()
3057
-
3058
- for f in (
3059
- "newSession",
3060
- "sparkContext",
3061
- ):
3062
- with self.assertRaises(NotImplementedError):
3063
- getattr(self.connect, f)()
3064
-
3065
- def test_sql_with_command(self):
3066
- # SPARK-42705: spark.sql should return values from the command.
3067
- self.assertEqual(
3068
- self.connect.sql("show functions").collect(), self.spark.sql("show functions").collect()
3069
- )
3070
-
3071
- def test_schema_has_nullable(self):
3072
- schema_false = StructType().add("id", IntegerType(), False)
3073
- cdf1 = self.connect.createDataFrame([[1]], schema=schema_false)
3074
- sdf1 = self.spark.createDataFrame([[1]], schema=schema_false)
3075
- self.assertEqual(cdf1.schema, sdf1.schema)
3076
- self.assertEqual(cdf1.collect(), sdf1.collect())
3077
-
3078
- schema_true = StructType().add("id", IntegerType(), True)
3079
- cdf2 = self.connect.createDataFrame([[1]], schema=schema_true)
3080
- sdf2 = self.spark.createDataFrame([[1]], schema=schema_true)
3081
- self.assertEqual(cdf2.schema, sdf2.schema)
3082
- self.assertEqual(cdf2.collect(), sdf2.collect())
3083
-
3084
- pdf1 = cdf1.toPandas()
3085
- cdf3 = self.connect.createDataFrame(pdf1, cdf1.schema)
3086
- sdf3 = self.spark.createDataFrame(pdf1, sdf1.schema)
3087
- self.assertEqual(cdf3.schema, sdf3.schema)
3088
- self.assertEqual(cdf3.collect(), sdf3.collect())
3089
-
3090
- pdf2 = cdf2.toPandas()
3091
- cdf4 = self.connect.createDataFrame(pdf2, cdf2.schema)
3092
- sdf4 = self.spark.createDataFrame(pdf2, sdf2.schema)
3093
- self.assertEqual(cdf4.schema, sdf4.schema)
3094
- self.assertEqual(cdf4.collect(), sdf4.collect())
3095
-
3096
- def test_array_has_nullable(self):
3097
- for schemas, data in [
3098
- (
3099
- [StructType().add("arr", ArrayType(IntegerType(), False), True)],
3100
- [Row([1, 2]), Row([3]), Row(None)],
3101
- ),
3102
- (
3103
- [
3104
- StructType().add("arr", ArrayType(IntegerType(), True), True),
3105
- "arr array<integer>",
3106
- ],
3107
- [Row([1, None]), Row([3]), Row(None)],
3108
- ),
3109
- (
3110
- [StructType().add("arr", ArrayType(IntegerType(), False), False)],
3111
- [Row([1, 2]), Row([3])],
3112
- ),
3113
- (
3114
- [
3115
- StructType().add("arr", ArrayType(IntegerType(), True), False),
3116
- "arr array<integer> not null",
3117
- ],
3118
- [Row([1, None]), Row([3])],
3119
- ),
3120
- ]:
3121
- for schema in schemas:
3122
- with self.subTest(schema=schema):
3123
- cdf = self.connect.createDataFrame(data, schema=schema)
3124
- sdf = self.spark.createDataFrame(data, schema=schema)
3125
- self.assertEqual(cdf.schema, sdf.schema)
3126
- self.assertEqual(cdf.collect(), sdf.collect())
3127
-
3128
- def test_map_has_nullable(self):
3129
- for schemas, data in [
3130
- (
3131
- [StructType().add("map", MapType(StringType(), IntegerType(), False), True)],
3132
- [Row({"a": 1, "b": 2}), Row({"a": 3}), Row(None)],
3133
- ),
3134
- (
3135
- [
3136
- StructType().add("map", MapType(StringType(), IntegerType(), True), True),
3137
- "map map<string, integer>",
3138
- ],
3139
- [Row({"a": 1, "b": None}), Row({"a": 3}), Row(None)],
3140
- ),
3141
- (
3142
- [StructType().add("map", MapType(StringType(), IntegerType(), False), False)],
3143
- [Row({"a": 1, "b": 2}), Row({"a": 3})],
3144
- ),
3145
- (
3146
- [
3147
- StructType().add("map", MapType(StringType(), IntegerType(), True), False),
3148
- "map map<string, integer> not null",
3149
- ],
3150
- [Row({"a": 1, "b": None}), Row({"a": 3})],
3151
- ),
3152
- ]:
3153
- for schema in schemas:
3154
- with self.subTest(schema=schema):
3155
- cdf = self.connect.createDataFrame(data, schema=schema)
3156
- sdf = self.spark.createDataFrame(data, schema=schema)
3157
- self.assertEqual(cdf.schema, sdf.schema)
3158
- self.assertEqual(cdf.collect(), sdf.collect())
3159
-
3160
- def test_struct_has_nullable(self):
3161
- for schemas, data in [
3162
- (
3163
- [
3164
- StructType().add("struct", StructType().add("i", IntegerType(), False), True),
3165
- "struct struct<i: integer not null>",
3166
- ],
3167
- [Row(Row(1)), Row(Row(2)), Row(None)],
3168
- ),
3169
- (
3170
- [
3171
- StructType().add("struct", StructType().add("i", IntegerType(), True), True),
3172
- "struct struct<i: integer>",
3173
- ],
3174
- [Row(Row(1)), Row(Row(2)), Row(Row(None)), Row(None)],
3175
- ),
3176
- (
3177
- [
3178
- StructType().add("struct", StructType().add("i", IntegerType(), False), False),
3179
- "struct struct<i: integer not null> not null",
3180
- ],
3181
- [Row(Row(1)), Row(Row(2))],
3182
- ),
3183
- (
3184
- [
3185
- StructType().add("struct", StructType().add("i", IntegerType(), True), False),
3186
- "struct struct<i: integer> not null",
3187
- ],
3188
- [Row(Row(1)), Row(Row(2)), Row(Row(None))],
3189
- ),
3190
- ]:
3191
- for schema in schemas:
3192
- with self.subTest(schema=schema):
3193
- cdf = self.connect.createDataFrame(data, schema=schema)
3194
- sdf = self.spark.createDataFrame(data, schema=schema)
3195
- self.assertEqual(cdf.schema, sdf.schema)
3196
- self.assertEqual(cdf.collect(), sdf.collect())
3197
-
3198
- def test_large_client_data(self):
3199
- # SPARK-42816 support more than 4MB message size.
3200
- # ~200bytes
3201
- cols = ["abcdefghijklmnoprstuvwxyz" for x in range(10)]
3202
- # 100k rows => 20MB
3203
- row_count = 100 * 1000
3204
- rows = [cols] * row_count
3205
- self.assertEqual(row_count, self.connect.createDataFrame(data=rows).count())
3206
-
3207
- def test_unsupported_jvm_attribute(self):
3208
- # Unsupported jvm attributes for Spark session.
3209
- unsupported_attrs = ["_jsc", "_jconf", "_jvm", "_jsparkSession"]
3210
- spark_session = self.connect
3211
- for attr in unsupported_attrs:
3212
- with self.assertRaises(PySparkAttributeError) as pe:
3213
- getattr(spark_session, attr)
3214
-
3215
- self.check_error(
3216
- exception=pe.exception,
3217
- error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3218
- message_parameters={"attr_name": attr},
3219
- )
3220
-
3221
- # Unsupported jvm attributes for DataFrame.
3222
- unsupported_attrs = ["_jseq", "_jdf", "_jmap", "_jcols"]
3223
- cdf = self.connect.range(10)
3224
- for attr in unsupported_attrs:
3225
- with self.assertRaises(PySparkAttributeError) as pe:
3226
- getattr(cdf, attr)
3227
-
3228
- self.check_error(
3229
- exception=pe.exception,
3230
- error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3231
- message_parameters={"attr_name": attr},
3232
- )
3233
-
3234
- # Unsupported jvm attributes for Column.
3235
- with self.assertRaises(PySparkAttributeError) as pe:
3236
- getattr(cdf.id, "_jc")
3237
-
3238
- self.check_error(
3239
- exception=pe.exception,
3240
- error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3241
- message_parameters={"attr_name": "_jc"},
3242
- )
3243
-
3244
- # Unsupported jvm attributes for DataFrameReader.
3245
- with self.assertRaises(PySparkAttributeError) as pe:
3246
- getattr(spark_session.read, "_jreader")
3247
-
3248
- self.check_error(
3249
- exception=pe.exception,
3250
- error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3251
- message_parameters={"attr_name": "_jreader"},
3252
- )
3253
-
3254
- def test_df_cache(self):
3255
- df = self.connect.range(10)
3256
- df.cache()
3257
- self.assert_eq(10, df.count())
3258
- self.assertTrue(df.is_cached)
3259
-
3260
-
3261
- @unittest.skipIf(
3262
- "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Session creation different from local mode"
3263
- )
3264
- class SparkConnectSessionTests(ReusedConnectTestCase):
3265
- def setUp(self) -> None:
3266
- self.spark = (
3267
- PySparkSession.builder.config(conf=self.conf())
3268
- .appName(self.__class__.__name__)
3269
- .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
3270
- .getOrCreate()
3271
- )
3272
-
3273
- def tearDown(self):
3274
- self.spark.stop()
3275
-
3276
- def _check_no_active_session_error(self, e: PySparkException):
3277
- self.check_error(exception=e, error_class="NO_ACTIVE_SESSION", message_parameters=dict())
3278
-
3279
- def test_stop_session(self):
3280
- df = self.spark.sql("select 1 as a, 2 as b")
3281
- catalog = self.spark.catalog
3282
- self.spark.stop()
3283
-
3284
- # _execute_and_fetch
3285
- with self.assertRaises(SparkConnectException) as e:
3286
- self.spark.sql("select 1")
3287
- self._check_no_active_session_error(e.exception)
3288
-
3289
- with self.assertRaises(SparkConnectException) as e:
3290
- catalog.tableExists("table")
3291
- self._check_no_active_session_error(e.exception)
3292
-
3293
- # _execute
3294
- with self.assertRaises(SparkConnectException) as e:
3295
- self.spark.udf.register("test_func", lambda x: x + 1)
3296
- self._check_no_active_session_error(e.exception)
3297
-
3298
- # _analyze
3299
- with self.assertRaises(SparkConnectException) as e:
3300
- df._explain_string(extended=True)
3301
- self._check_no_active_session_error(e.exception)
3302
-
3303
- # Config
3304
- with self.assertRaises(SparkConnectException) as e:
3305
- self.spark.conf.get("some.conf")
3306
- self._check_no_active_session_error(e.exception)
3307
-
3308
- def test_error_stack_trace(self):
3309
- with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
3310
- with self.assertRaises(AnalysisException) as e:
3311
- self.spark.sql("select x").collect()
3312
- self.assertTrue("JVM stacktrace" in e.exception.message)
3313
- self.assertTrue(
3314
- "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
3315
- )
3316
-
3317
- with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}):
3318
- with self.assertRaises(AnalysisException) as e:
3319
- self.spark.sql("select x").collect()
3320
- self.assertFalse("JVM stacktrace" in e.exception.message)
3321
- self.assertFalse(
3322
- "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
3323
- )
3324
-
3325
- # Create a new session with a different stack trace size.
3326
- self.spark.stop()
3327
- spark = (
3328
- PySparkSession.builder.config(conf=self.conf())
3329
- .config("spark.connect.jvmStacktrace.maxSize", 128)
3330
- .remote("local[4]")
3331
- .getOrCreate()
3332
- )
3333
- spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
3334
- with self.assertRaises(AnalysisException) as e:
3335
- spark.sql("select x").collect()
3336
- self.assertTrue("JVM stacktrace" in e.exception.message)
3337
- self.assertFalse(
3338
- "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
3339
- )
3340
- spark.stop()
3341
-
3342
- def test_can_create_multiple_sessions_to_different_remotes(self):
3343
- self.spark.stop()
3344
- self.assertIsNotNone(self.spark._client)
3345
- # Creates a new remote session.
3346
- other = PySparkSession.builder.remote("sc://other.remote:114/").create()
3347
- self.assertNotEqual(self.spark, other)
3348
-
3349
- # Gets currently active session.
3350
- same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate()
3351
- self.assertEqual(other, same)
3352
- same.stop()
3353
-
3354
- # Make sure the environment is clean.
3355
- self.spark.stop()
3356
- with self.assertRaises(RuntimeError) as e:
3357
- PySparkSession.builder.create()
3358
- self.assertIn("Create a new SparkSession is only supported with SparkConnect.", str(e))
3359
-
3360
-
3361
- @unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access")
3362
- class SparkConnectSessionWithOptionsTest(unittest.TestCase):
3363
- def setUp(self) -> None:
3364
- self.spark = (
3365
- PySparkSession.builder.config("string", "foo")
3366
- .config("integer", 1)
3367
- .config("boolean", False)
3368
- .appName(self.__class__.__name__)
3369
- .remote("local[4]")
3370
- .getOrCreate()
3371
- )
3372
-
3373
- def tearDown(self):
3374
- self.spark.stop()
3375
-
3376
- def test_config(self):
3377
- # Config
3378
- self.assertEqual(self.spark.conf.get("string"), "foo")
3379
- self.assertEqual(self.spark.conf.get("boolean"), "false")
3380
- self.assertEqual(self.spark.conf.get("integer"), "1")
3381
-
3382
-
3383
- @unittest.skipIf(not should_test_connect, connect_requirement_message)
3384
- class ClientTests(unittest.TestCase):
3385
- def test_retry_error_handling(self):
3386
- # Helper class for wrapping the test.
3387
- class TestError(grpc.RpcError, Exception):
3388
- def __init__(self, code: grpc.StatusCode):
3389
- self._code = code
3390
-
3391
- def code(self):
3392
- return self._code
3393
-
3394
- def stub(retries, w, code):
3395
- w["attempts"] += 1
3396
- if w["attempts"] < retries:
3397
- w["raised"] += 1
3398
- raise TestError(code)
3399
-
3400
- # Check that max_retries 1 is only one retry so two attempts.
3401
- call_wrap = defaultdict(int)
3402
- for attempt in Retrying(
3403
- can_retry=lambda x: True,
3404
- max_retries=1,
3405
- backoff_multiplier=1,
3406
- initial_backoff=1,
3407
- max_backoff=10,
3408
- jitter=0,
3409
- min_jitter_threshold=0,
3410
- ):
3411
- with attempt:
3412
- stub(2, call_wrap, grpc.StatusCode.INTERNAL)
3413
-
3414
- self.assertEqual(2, call_wrap["attempts"])
3415
- self.assertEqual(1, call_wrap["raised"])
3416
-
3417
- # Check that if we have less than 4 retries all is ok.
3418
- call_wrap = defaultdict(int)
3419
- for attempt in Retrying(
3420
- can_retry=lambda x: True,
3421
- max_retries=4,
3422
- backoff_multiplier=1,
3423
- initial_backoff=1,
3424
- max_backoff=10,
3425
- jitter=0,
3426
- min_jitter_threshold=0,
3427
- ):
3428
- with attempt:
3429
- stub(2, call_wrap, grpc.StatusCode.INTERNAL)
3430
-
3431
- self.assertTrue(call_wrap["attempts"] < 4)
3432
- self.assertEqual(call_wrap["raised"], 1)
3433
-
3434
- # Exceed the retries.
3435
- call_wrap = defaultdict(int)
3436
- with self.assertRaises(TestError):
3437
- for attempt in Retrying(
3438
- can_retry=lambda x: True,
3439
- max_retries=2,
3440
- max_backoff=50,
3441
- backoff_multiplier=1,
3442
- initial_backoff=50,
3443
- jitter=0,
3444
- min_jitter_threshold=0,
3445
- ):
3446
- with attempt:
3447
- stub(5, call_wrap, grpc.StatusCode.INTERNAL)
3448
-
3449
- self.assertTrue(call_wrap["attempts"] < 5)
3450
- self.assertEqual(call_wrap["raised"], 3)
3451
-
3452
- # Check that only specific exceptions are retried.
3453
- # Check that if we have less than 4 retries all is ok.
3454
- call_wrap = defaultdict(int)
3455
- for attempt in Retrying(
3456
- can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
3457
- max_retries=4,
3458
- backoff_multiplier=1,
3459
- initial_backoff=1,
3460
- max_backoff=10,
3461
- jitter=0,
3462
- min_jitter_threshold=0,
3463
- ):
3464
- with attempt:
3465
- stub(2, call_wrap, grpc.StatusCode.UNAVAILABLE)
3466
-
3467
- self.assertTrue(call_wrap["attempts"] < 4)
3468
- self.assertEqual(call_wrap["raised"], 1)
3469
-
3470
- # Exceed the retries.
3471
- call_wrap = defaultdict(int)
3472
- with self.assertRaises(TestError):
3473
- for attempt in Retrying(
3474
- can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
3475
- max_retries=2,
3476
- max_backoff=50,
3477
- backoff_multiplier=1,
3478
- initial_backoff=50,
3479
- jitter=0,
3480
- min_jitter_threshold=0,
3481
- ):
3482
- with attempt:
3483
- stub(5, call_wrap, grpc.StatusCode.UNAVAILABLE)
3484
-
3485
- self.assertTrue(call_wrap["attempts"] < 4)
3486
- self.assertEqual(call_wrap["raised"], 3)
3487
-
3488
- # Test that another error is always thrown.
3489
- call_wrap = defaultdict(int)
3490
- with self.assertRaises(TestError):
3491
- for attempt in Retrying(
3492
- can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
3493
- max_retries=4,
3494
- backoff_multiplier=1,
3495
- initial_backoff=1,
3496
- max_backoff=10,
3497
- jitter=0,
3498
- min_jitter_threshold=0,
3499
- ):
3500
- with attempt:
3501
- stub(5, call_wrap, grpc.StatusCode.INTERNAL)
3502
-
3503
- self.assertEqual(call_wrap["attempts"], 1)
3504
- self.assertEqual(call_wrap["raised"], 1)
3505
-
3506
-
3507
- @unittest.skipIf(not should_test_connect, connect_requirement_message)
3508
- class ChannelBuilderTests(unittest.TestCase):
3509
- def test_invalid_connection_strings(self):
3510
- invalid = [
3511
- "scc://host:12",
3512
- "http://host",
3513
- "sc:/host:1234/path",
3514
- "sc://host/path",
3515
- "sc://host/;parm1;param2",
3516
- ]
3517
- for i in invalid:
3518
- self.assertRaises(PySparkValueError, ChannelBuilder, i)
3519
-
3520
- def test_sensible_defaults(self):
3521
- chan = ChannelBuilder("sc://host")
3522
- self.assertFalse(chan.secure, "Default URL is not secure")
3523
-
3524
- chan = ChannelBuilder("sc://host/;token=abcs")
3525
- self.assertTrue(chan.secure, "specifying a token must set the channel to secure")
3526
- self.assertRegex(
3527
- chan.userAgent, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$"
3528
- )
3529
- chan = ChannelBuilder("sc://host/;use_ssl=abcs")
3530
- self.assertFalse(chan.secure, "Garbage in, false out")
3531
-
3532
- def test_user_agent(self):
3533
- chan = ChannelBuilder("sc://host/;user_agent=Agent123%20%2F3.4")
3534
- self.assertIn("Agent123 /3.4", chan.userAgent)
3535
-
3536
- def test_user_agent_len(self):
3537
- user_agent = "x" * 2049
3538
- chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
3539
- with self.assertRaises(SparkConnectException) as err:
3540
- chan.userAgent
3541
- self.assertRegex(err.exception.message, "'user_agent' parameter should not exceed")
3542
-
3543
- user_agent = "%C3%A4" * 341 # "%C3%A4" -> "ä"; (341 * 6 = 2046) < 2048
3544
- expected = "ä" * 341
3545
- chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
3546
- self.assertIn(expected, chan.userAgent)
3547
-
3548
- def test_valid_channel_creation(self):
3549
- chan = ChannelBuilder("sc://host").toChannel()
3550
- self.assertIsInstance(chan, grpc.Channel)
3551
-
3552
- # Sets up a channel without tokens because ssl is not used.
3553
- chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc").toChannel()
3554
- self.assertIsInstance(chan, grpc.Channel)
3555
-
3556
- chan = ChannelBuilder("sc://host/;use_ssl=true").toChannel()
3557
- self.assertIsInstance(chan, grpc.Channel)
3558
-
3559
- def test_channel_properties(self):
3560
- chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;user_agent=foo;param1=120%2021")
3561
- self.assertEqual("host:15002", chan.endpoint)
3562
- self.assertIn("foo", chan.userAgent.split(" "))
3563
- self.assertEqual(True, chan.secure)
3564
- self.assertEqual("120 21", chan.get("param1"))
3565
-
3566
- def test_metadata(self):
3567
- chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;param1=120%2021;x-my-header=abcd")
3568
- md = chan.metadata()
3569
- self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], md)
3570
-
3571
- def test_metadata(self):
3572
- id = str(uuid.uuid4())
3573
- chan = ChannelBuilder(f"sc://host/;session_id={id}")
3574
- self.assertEqual(id, chan.session_id)
3575
-
3576
- chan = ChannelBuilder(f"sc://host/;session_id={id};user_agent=acbd;token=abcd;use_ssl=true")
3577
- md = chan.metadata()
3578
- for kv in md:
3579
- self.assertNotIn(
3580
- kv[0],
3581
- [
3582
- ChannelBuilder.PARAM_SESSION_ID,
3583
- ChannelBuilder.PARAM_TOKEN,
3584
- ChannelBuilder.PARAM_USER_ID,
3585
- ChannelBuilder.PARAM_USER_AGENT,
3586
- ChannelBuilder.PARAM_USE_SSL,
3587
- ],
3588
- "Metadata must not contain fixed params",
3589
- )
3590
-
3591
- with self.assertRaises(ValueError) as ve:
3592
- chan = ChannelBuilder("sc://host/;session_id=abcd")
3593
- SparkConnectClient(chan)
3594
- self.assertIn(
3595
- "Parameter value 'session_id' must be a valid UUID format.", str(ve.exception)
3596
- )
3597
-
3598
- chan = ChannelBuilder("sc://host/")
3599
- self.assertIsNone(chan.session_id)
3600
-
3601
-
3602
- if __name__ == "__main__":
3603
- from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401
3604
-
3605
- try:
3606
- import xmlrunner
3607
-
3608
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
3609
- except ImportError:
3610
- testRunner = None
3611
-
3612
- unittest.main(testRunner=testRunner, verbosity=2)