snowpark-connect 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of snowpark-connect might be problematic.

Files changed (474)
  1. snowflake/snowpark_connect/column_name_handler.py +116 -4
  2. snowflake/snowpark_connect/config.py +13 -0
  3. snowflake/snowpark_connect/constants.py +0 -29
  4. snowflake/snowpark_connect/dataframe_container.py +6 -0
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +56 -1
  6. snowflake/snowpark_connect/expression/literal.py +13 -2
  7. snowflake/snowpark_connect/expression/map_cast.py +5 -8
  8. snowflake/snowpark_connect/expression/map_sql_expression.py +23 -1
  9. snowflake/snowpark_connect/expression/map_udf.py +26 -8
  10. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +199 -15
  11. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +44 -16
  12. snowflake/snowpark_connect/expression/map_unresolved_function.py +825 -353
  13. snowflake/snowpark_connect/expression/map_unresolved_star.py +3 -2
  14. snowflake/snowpark_connect/hidden_column.py +39 -0
  15. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  16. snowflake/snowpark_connect/includes/jars/{hadoop-client-api-3.3.4.jar → spark-connect-client-jvm_2.12-3.5.6.jar} +0 -0
  17. snowflake/snowpark_connect/relation/map_column_ops.py +17 -4
  18. snowflake/snowpark_connect/relation/map_extension.py +52 -11
  19. snowflake/snowpark_connect/relation/map_join.py +258 -62
  20. snowflake/snowpark_connect/relation/map_sql.py +88 -11
  21. snowflake/snowpark_connect/relation/map_udtf.py +4 -2
  22. snowflake/snowpark_connect/relation/read/map_read.py +3 -3
  23. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +1 -1
  24. snowflake/snowpark_connect/relation/read/map_read_json.py +8 -1
  25. snowflake/snowpark_connect/relation/read/map_read_table.py +1 -9
  26. snowflake/snowpark_connect/relation/read/reader_config.py +3 -1
  27. snowflake/snowpark_connect/relation/write/map_write.py +62 -53
  28. snowflake/snowpark_connect/resources_initializer.py +29 -1
  29. snowflake/snowpark_connect/server.py +18 -3
  30. snowflake/snowpark_connect/type_mapping.py +29 -25
  31. snowflake/snowpark_connect/typed_column.py +14 -0
  32. snowflake/snowpark_connect/utils/artifacts.py +23 -0
  33. snowflake/snowpark_connect/utils/context.py +6 -1
  34. snowflake/snowpark_connect/utils/scala_udf_utils.py +588 -0
  35. snowflake/snowpark_connect/utils/telemetry.py +6 -17
  36. snowflake/snowpark_connect/utils/udf_helper.py +2 -0
  37. snowflake/snowpark_connect/utils/udf_utils.py +38 -7
  38. snowflake/snowpark_connect/utils/udtf_utils.py +17 -3
  39. snowflake/snowpark_connect/version.py +1 -1
  40. {snowpark_connect-0.24.0.dist-info → snowpark_connect-0.25.0.dist-info}/METADATA +1 -1
  41. snowpark_connect-0.25.0.dist-info/RECORD +477 -0
  42. snowflake/snowpark_connect/includes/jars/scala-compiler-2.12.18.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/spark-kubernetes_2.12-3.5.6.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/spark-mllib_2.12-3.5.6.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/spark-streaming_2.12-3.5.6.jar +0 -0
  46. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/__init__.py +0 -16
  47. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/test_errors.py +0 -60
  48. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py +0 -306
  49. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/__init__.py +0 -16
  50. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_classification.py +0 -53
  51. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_evaluation.py +0 -50
  52. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_feature.py +0 -43
  53. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_function.py +0 -114
  54. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_pipeline.py +0 -47
  55. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_summarizer.py +0 -43
  56. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_tuning.py +0 -46
  57. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py +0 -238
  58. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py +0 -194
  59. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +0 -156
  60. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py +0 -184
  61. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py +0 -78
  62. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +0 -292
  63. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +0 -50
  64. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +0 -152
  65. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_algorithms.py +0 -456
  66. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_base.py +0 -96
  67. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_dl_util.py +0 -186
  68. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_evaluation.py +0 -77
  69. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_feature.py +0 -401
  70. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_functions.py +0 -528
  71. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_image.py +0 -82
  72. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_linalg.py +0 -409
  73. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_model_cache.py +0 -55
  74. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_param.py +0 -441
  75. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_persistence.py +0 -546
  76. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_pipeline.py +0 -71
  77. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_stat.py +0 -52
  78. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_training_summary.py +0 -494
  79. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_util.py +0 -85
  80. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_wrapper.py +0 -138
  81. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/__init__.py +0 -16
  82. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_basic.py +0 -151
  83. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_nested.py +0 -97
  84. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py +0 -143
  85. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tuning.py +0 -551
  86. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_basic.py +0 -137
  87. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_nested.py +0 -96
  88. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_pipeline.py +0 -142
  89. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/__init__.py +0 -16
  90. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_data_loader.py +0 -137
  91. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_distributor.py +0 -561
  92. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_log_communication.py +0 -172
  93. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/__init__.py +0 -16
  94. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_algorithms.py +0 -353
  95. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_feature.py +0 -192
  96. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_linalg.py +0 -680
  97. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_stat.py +0 -206
  98. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_streaming_algorithms.py +0 -471
  99. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_util.py +0 -108
  100. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +0 -16
  101. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +0 -1281
  102. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +0 -203
  103. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +0 -202
  104. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/__init__.py +0 -16
  105. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/__init__.py +0 -16
  106. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_any_all.py +0 -177
  107. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_apply_func.py +0 -575
  108. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_binary_ops.py +0 -235
  109. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_combine.py +0 -653
  110. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_compute.py +0 -463
  111. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_corrwith.py +0 -86
  112. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cov.py +0 -151
  113. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cumulative.py +0 -139
  114. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_describe.py +0 -458
  115. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_eval.py +0 -86
  116. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_melt.py +0 -202
  117. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_missing_data.py +0 -520
  118. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_pivot.py +0 -361
  119. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/__init__.py +0 -16
  120. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/__init__.py +0 -16
  121. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_any_all.py +0 -40
  122. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_apply_func.py +0 -42
  123. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_binary_ops.py +0 -40
  124. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_combine.py +0 -37
  125. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py +0 -60
  126. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_corrwith.py +0 -40
  127. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cov.py +0 -40
  128. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cumulative.py +0 -90
  129. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_describe.py +0 -40
  130. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_eval.py +0 -40
  131. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_melt.py +0 -40
  132. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_missing_data.py +0 -42
  133. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +0 -37
  134. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/__init__.py +0 -16
  135. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_base.py +0 -36
  136. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py +0 -42
  137. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py +0 -47
  138. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py +0 -55
  139. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_complex_ops.py +0 -40
  140. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py +0 -47
  141. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py +0 -47
  142. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py +0 -42
  143. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_arithmetic.py +0 -43
  144. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py +0 -47
  145. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_reverse.py +0 -43
  146. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py +0 -47
  147. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py +0 -47
  148. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py +0 -40
  149. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/testing_utils.py +0 -226
  150. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/__init__.py +0 -16
  151. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_align.py +0 -39
  152. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_basic_slow.py +0 -55
  153. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_cov_corrwith.py +0 -39
  154. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_frame.py +0 -39
  155. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_series.py +0 -39
  156. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_index.py +0 -39
  157. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_series.py +0 -39
  158. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_frame.py +0 -43
  159. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_series.py +0 -43
  160. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/__init__.py +0 -16
  161. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_attrs.py +0 -40
  162. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py +0 -39
  163. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_conversion.py +0 -42
  164. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reindexing.py +0 -42
  165. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +0 -37
  166. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_spark.py +0 -40
  167. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_take.py +0 -42
  168. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_time_series.py +0 -48
  169. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_truncate.py +0 -40
  170. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/__init__.py +0 -16
  171. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_aggregate.py +0 -40
  172. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_apply_func.py +0 -41
  173. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_cumulative.py +0 -67
  174. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_describe.py +0 -40
  175. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_groupby.py +0 -55
  176. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_head_tail.py +0 -40
  177. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_index.py +0 -38
  178. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py +0 -55
  179. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_split_apply.py +0 -39
  180. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +0 -38
  181. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/__init__.py +0 -16
  182. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_align.py +0 -40
  183. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py +0 -50
  184. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py +0 -73
  185. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_datetime.py +0 -39
  186. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing.py +0 -40
  187. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reindex.py +0 -40
  188. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_rename.py +0 -40
  189. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reset_index.py +0 -48
  190. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_timedelta.py +0 -39
  191. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/__init__.py +0 -16
  192. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/test_parity_io.py +0 -40
  193. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/__init__.py +0 -16
  194. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot.py +0 -45
  195. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_matplotlib.py +0 -45
  196. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_plotly.py +0 -49
  197. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot.py +0 -37
  198. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_matplotlib.py +0 -53
  199. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_plotly.py +0 -45
  200. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/__init__.py +0 -16
  201. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_all_any.py +0 -38
  202. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_arg_ops.py +0 -37
  203. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_of.py +0 -37
  204. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_type.py +0 -38
  205. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +0 -37
  206. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_conversion.py +0 -40
  207. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_cumulative.py +0 -40
  208. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_index.py +0 -38
  209. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_missing_data.py +0 -40
  210. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_series.py +0 -37
  211. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_sort.py +0 -38
  212. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_stat.py +0 -38
  213. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_categorical.py +0 -66
  214. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_config.py +0 -37
  215. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_csv.py +0 -37
  216. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_conversion.py +0 -42
  217. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_spark_io.py +0 -39
  218. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_default_index.py +0 -49
  219. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ewm.py +0 -37
  220. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_expanding.py +0 -39
  221. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_extension.py +0 -49
  222. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_frame_spark.py +0 -53
  223. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py +0 -43
  224. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexing.py +0 -49
  225. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexops_spark.py +0 -39
  226. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_internal.py +0 -41
  227. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_namespace.py +0 -39
  228. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_numpy_compat.py +0 -60
  229. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames.py +0 -48
  230. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby.py +0 -39
  231. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_expanding.py +0 -44
  232. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_rolling.py +0 -84
  233. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_repr.py +0 -37
  234. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_resample.py +0 -45
  235. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_reshape.py +0 -39
  236. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_rolling.py +0 -39
  237. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_scalars.py +0 -37
  238. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_conversion.py +0 -39
  239. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_datetime.py +0 -39
  240. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_string.py +0 -39
  241. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_spark_functions.py +0 -39
  242. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_sql.py +0 -43
  243. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_stats.py +0 -37
  244. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_typedef.py +0 -36
  245. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_utils.py +0 -37
  246. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_window.py +0 -39
  247. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/__init__.py +0 -16
  248. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_base.py +0 -107
  249. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +0 -224
  250. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +0 -825
  251. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +0 -562
  252. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +0 -368
  253. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +0 -257
  254. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +0 -260
  255. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +0 -178
  256. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_arithmetic.py +0 -184
  257. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +0 -497
  258. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_reverse.py +0 -140
  259. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +0 -354
  260. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +0 -219
  261. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +0 -192
  262. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +0 -228
  263. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/__init__.py +0 -16
  264. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_align.py +0 -118
  265. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_basic_slow.py +0 -198
  266. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_cov_corrwith.py +0 -181
  267. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_frame.py +0 -103
  268. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_series.py +0 -141
  269. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_index.py +0 -109
  270. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_series.py +0 -136
  271. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_frame.py +0 -125
  272. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_series.py +0 -217
  273. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/__init__.py +0 -16
  274. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_attrs.py +0 -384
  275. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_constructor.py +0 -598
  276. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_conversion.py +0 -73
  277. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reindexing.py +0 -869
  278. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reshaping.py +0 -487
  279. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_spark.py +0 -309
  280. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_take.py +0 -156
  281. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_time_series.py +0 -149
  282. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_truncate.py +0 -163
  283. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/__init__.py +0 -16
  284. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_aggregate.py +0 -311
  285. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_apply_func.py +0 -524
  286. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_cumulative.py +0 -419
  287. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_describe.py +0 -144
  288. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_groupby.py +0 -979
  289. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_head_tail.py +0 -234
  290. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_index.py +0 -206
  291. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_missing_data.py +0 -421
  292. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_split_apply.py +0 -187
  293. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_stat.py +0 -397
  294. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/__init__.py +0 -16
  295. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_align.py +0 -100
  296. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_base.py +0 -2743
  297. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_category.py +0 -484
  298. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_datetime.py +0 -276
  299. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_indexing.py +0 -432
  300. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reindex.py +0 -310
  301. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_rename.py +0 -257
  302. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reset_index.py +0 -160
  303. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_timedelta.py +0 -128
  304. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/__init__.py +0 -16
  305. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/test_io.py +0 -137
  306. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/__init__.py +0 -16
  307. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot.py +0 -170
  308. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +0 -547
  309. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +0 -285
  310. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot.py +0 -106
  311. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +0 -409
  312. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +0 -247
  313. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/__init__.py +0 -16
  314. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_all_any.py +0 -105
  315. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_arg_ops.py +0 -197
  316. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_of.py +0 -137
  317. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_type.py +0 -227
  318. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_compute.py +0 -634
  319. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_conversion.py +0 -88
  320. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_cumulative.py +0 -139
  321. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_index.py +0 -475
  322. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_missing_data.py +0 -265
  323. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_series.py +0 -818
  324. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_sort.py +0 -162
  325. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_stat.py +0 -780
  326. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_categorical.py +0 -741
  327. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_config.py +0 -160
  328. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_csv.py +0 -453
  329. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_conversion.py +0 -281
  330. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_spark_io.py +0 -487
  331. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_default_index.py +0 -109
  332. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ewm.py +0 -434
  333. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_expanding.py +0 -253
  334. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_extension.py +0 -152
  335. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_frame_spark.py +0 -162
  336. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_generic_functions.py +0 -234
  337. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexing.py +0 -1339
  338. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexops_spark.py +0 -82
  339. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_internal.py +0 -124
  340. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_namespace.py +0 -638
  341. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_numpy_compat.py +0 -200
  342. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +0 -1355
  343. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +0 -655
  344. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +0 -113
  345. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +0 -118
  346. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_repr.py +0 -192
  347. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_resample.py +0 -346
  348. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_reshape.py +0 -495
  349. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_rolling.py +0 -263
  350. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_scalars.py +0 -59
  351. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_conversion.py +0 -85
  352. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_datetime.py +0 -364
  353. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_string.py +0 -362
  354. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_spark_functions.py +0 -46
  355. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_sql.py +0 -123
  356. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_stats.py +0 -581
  357. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_typedef.py +0 -447
  358. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_utils.py +0 -301
  359. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_window.py +0 -465
  360. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/__init__.py +0 -16
  361. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/test_resources.py +0 -83
  362. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/__init__.py +0 -16
  363. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/__init__.py +0 -16
  364. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/__init__.py +0 -16
  365. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_artifact.py +0 -420
  366. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_client.py +0 -358
  367. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/__init__.py +0 -16
  368. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach.py +0 -36
  369. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +0 -44
  370. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +0 -116
  371. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +0 -35
  372. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py +0 -3612
  373. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_column.py +0 -1042
  374. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_function.py +0 -2381
  375. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_plan.py +0 -1060
  376. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow.py +0 -163
  377. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_map.py +0 -38
  378. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +0 -48
  379. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_catalog.py +0 -36
  380. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_column.py +0 -55
  381. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_conf.py +0 -36
  382. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_dataframe.py +0 -96
  383. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_datasources.py +0 -44
  384. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_errors.py +0 -36
  385. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_functions.py +0 -59
  386. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_group.py +0 -36
  387. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py +0 -59
  388. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py +0 -74
  389. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py +0 -62
  390. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_map.py +0 -58
  391. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +0 -70
  392. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py +0 -50
  393. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +0 -68
  394. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py +0 -40
  395. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_readwriter.py +0 -46
  396. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_serde.py +0 -44
  397. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_types.py +0 -100
  398. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udf.py +0 -100
  399. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udtf.py +0 -163
  400. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_session.py +0 -181
  401. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_utils.py +0 -42
  402. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/__init__.py +0 -16
  403. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +0 -623
  404. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +0 -869
  405. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +0 -342
  406. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_map.py +0 -436
  407. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf.py +0 -363
  408. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +0 -592
  409. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +0 -1503
  410. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +0 -392
  411. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py +0 -375
  412. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +0 -411
  413. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/__init__.py +0 -16
  414. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming.py +0 -401
  415. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +0 -295
  416. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +0 -106
  417. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_listener.py +0 -558
  418. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow.py +0 -1346
  419. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_map.py +0 -182
  420. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_python_udf.py +0 -202
  421. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_catalog.py +0 -503
  422. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_column.py +0 -225
  423. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_conf.py +0 -83
  424. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_context.py +0 -201
  425. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_dataframe.py +0 -1931
  426. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_datasources.py +0 -256
  427. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_errors.py +0 -69
  428. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_functions.py +0 -1349
  429. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_group.py +0 -53
  430. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_pandas_sqlmetrics.py +0 -68
  431. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_readwriter.py +0 -283
  432. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_serde.py +0 -155
  433. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_session.py +0 -412
  434. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_types.py +0 -1581
  435. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf.py +0 -961
  436. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf_profiler.py +0 -165
  437. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udtf.py +0 -1456
  438. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_utils.py +0 -1686
  439. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/__init__.py +0 -16
  440. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_context.py +0 -184
  441. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_dstream.py +0 -706
  442. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_kinesis.py +0 -118
  443. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_listener.py +0 -160
  444. snowflake/snowpark_connect/includes/python/pyspark/tests/__init__.py +0 -16
  445. snowflake/snowpark_connect/includes/python/pyspark/tests/test_appsubmit.py +0 -306
  446. snowflake/snowpark_connect/includes/python/pyspark/tests/test_broadcast.py +0 -196
  447. snowflake/snowpark_connect/includes/python/pyspark/tests/test_conf.py +0 -44
  448. snowflake/snowpark_connect/includes/python/pyspark/tests/test_context.py +0 -346
  449. snowflake/snowpark_connect/includes/python/pyspark/tests/test_daemon.py +0 -89
  450. snowflake/snowpark_connect/includes/python/pyspark/tests/test_install_spark.py +0 -124
  451. snowflake/snowpark_connect/includes/python/pyspark/tests/test_join.py +0 -69
  452. snowflake/snowpark_connect/includes/python/pyspark/tests/test_memory_profiler.py +0 -167
  453. snowflake/snowpark_connect/includes/python/pyspark/tests/test_pin_thread.py +0 -194
  454. snowflake/snowpark_connect/includes/python/pyspark/tests/test_profiler.py +0 -168
  455. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rdd.py +0 -939
  456. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddbarrier.py +0 -52
  457. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddsampler.py +0 -66
  458. snowflake/snowpark_connect/includes/python/pyspark/tests/test_readwrite.py +0 -368
  459. snowflake/snowpark_connect/includes/python/pyspark/tests/test_serializers.py +0 -257
  460. snowflake/snowpark_connect/includes/python/pyspark/tests/test_shuffle.py +0 -267
  461. snowflake/snowpark_connect/includes/python/pyspark/tests/test_stage_sched.py +0 -153
  462. snowflake/snowpark_connect/includes/python/pyspark/tests/test_statcounter.py +0 -130
  463. snowflake/snowpark_connect/includes/python/pyspark/tests/test_taskcontext.py +0 -350
  464. snowflake/snowpark_connect/includes/python/pyspark/tests/test_util.py +0 -97
  465. snowflake/snowpark_connect/includes/python/pyspark/tests/test_worker.py +0 -271
  466. snowpark_connect-0.24.0.dist-info/RECORD +0 -898
  467. {snowpark_connect-0.24.0.data → snowpark_connect-0.25.0.data}/scripts/snowpark-connect +0 -0
  468. {snowpark_connect-0.24.0.data → snowpark_connect-0.25.0.data}/scripts/snowpark-session +0 -0
  469. {snowpark_connect-0.24.0.data → snowpark_connect-0.25.0.data}/scripts/snowpark-submit +0 -0
  470. {snowpark_connect-0.24.0.dist-info → snowpark_connect-0.25.0.dist-info}/WHEEL +0 -0
  471. {snowpark_connect-0.24.0.dist-info → snowpark_connect-0.25.0.dist-info}/licenses/LICENSE-binary +0 -0
  472. {snowpark_connect-0.24.0.dist-info → snowpark_connect-0.25.0.dist-info}/licenses/LICENSE.txt +0 -0
  473. {snowpark_connect-0.24.0.dist-info → snowpark_connect-0.25.0.dist-info}/licenses/NOTICE-binary +0 -0
  474. {snowpark_connect-0.24.0.dist-info → snowpark_connect-0.25.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py (deleted)
@@ -1,3612 +0,0 @@
1
- #
2
- # Licensed to the Apache Software Foundation (ASF) under one or more
3
- # contributor license agreements. See the NOTICE file distributed with
4
- # this work for additional information regarding copyright ownership.
5
- # The ASF licenses this file to You under the Apache License, Version 2.0
6
- # (the "License"); you may not use this file except in compliance with
7
- # the License. You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- #
17
-
18
- import array
19
- import datetime
20
- import os
21
- import unittest
22
- import random
23
- import shutil
24
- import string
25
- import tempfile
26
- import uuid
27
- from collections import defaultdict
28
-
29
- from pyspark.errors import (
30
- PySparkAttributeError,
31
- PySparkTypeError,
32
- PySparkException,
33
- PySparkValueError,
34
- )
35
- from pyspark.errors.exceptions.base import SessionNotSameException
36
- from pyspark.sql import SparkSession as PySparkSession, Row
37
- from pyspark.sql.types import (
38
- StructType,
39
- StructField,
40
- LongType,
41
- StringType,
42
- IntegerType,
43
- MapType,
44
- ArrayType,
45
- Row,
46
- )
47
-
48
- from pyspark.testing.objects import (
49
- MyObject,
50
- PythonOnlyUDT,
51
- ExamplePoint,
52
- PythonOnlyPoint,
53
- )
54
- from pyspark.testing.sqlutils import SQLTestUtils
55
- from pyspark.testing.connectutils import (
56
- should_test_connect,
57
- ReusedConnectTestCase,
58
- connect_requirement_message,
59
- )
60
- from pyspark.testing.pandasutils import PandasOnSparkTestUtils
61
- from pyspark.errors.exceptions.connect import (
62
- AnalysisException,
63
- ParseException,
64
- SparkConnectException,
65
- )
66
-
67
- if should_test_connect:
68
- import grpc
69
- import pandas as pd
70
- import numpy as np
71
- from pyspark.sql.connect.proto import Expression as ProtoExpression
72
- from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
73
- from pyspark.sql.connect.client import ChannelBuilder
74
- from pyspark.sql.connect.column import Column
75
- from pyspark.sql.connect.readwriter import DataFrameWriterV2
76
- from pyspark.sql.dataframe import DataFrame
77
- from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
78
- from pyspark.sql import functions as SF
79
- from pyspark.sql.connect import functions as CF
80
- from pyspark.sql.connect.client.core import Retrying, SparkConnectClient
81
-
82
-
83
- @unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access")
84
- class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils):
85
- """Parent test fixture class for all Spark Connect related
86
- test cases."""
87
-
88
- @classmethod
89
- def setUpClass(cls):
90
- super(SparkConnectSQLTestCase, cls).setUpClass()
91
- # Disable the shared namespace so pyspark.sql.functions, etc point the regular
92
- # PySpark libraries.
93
- os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
94
-
95
- cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session.
96
- cls.spark = PySparkSession._instantiatedSession
97
- assert cls.spark is not None
98
-
99
- cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
100
- cls.testDataStr = [Row(key=str(i)) for i in range(100)]
101
- cls.df = cls.spark.sparkContext.parallelize(cls.testData).toDF()
102
- cls.df_text = cls.spark.sparkContext.parallelize(cls.testDataStr).toDF()
103
-
104
- cls.tbl_name = "test_connect_basic_table_1"
105
- cls.tbl_name2 = "test_connect_basic_table_2"
106
- cls.tbl_name3 = "test_connect_basic_table_3"
107
- cls.tbl_name4 = "test_connect_basic_table_4"
108
- cls.tbl_name_empty = "test_connect_basic_table_empty"
109
-
110
- # Cleanup test data
111
- cls.spark_connect_clean_up_test_data()
112
- # Load test data
113
- cls.spark_connect_load_test_data()
114
-
115
- @classmethod
116
- def tearDownClass(cls):
117
- try:
118
- cls.spark_connect_clean_up_test_data()
119
- # Stopping Spark Connect closes the session in JVM at the server.
120
- cls.spark = cls.connect
121
- del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
122
- finally:
123
- super(SparkConnectSQLTestCase, cls).tearDownClass()
124
-
125
- @classmethod
126
- def spark_connect_load_test_data(cls):
127
- df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
128
- # Since we might create multiple Spark sessions, we need to create global temporary view
129
- # that is specifically maintained in the "global_temp" schema.
130
- df.write.saveAsTable(cls.tbl_name)
131
- df2 = cls.spark.createDataFrame(
132
- [(x, f"{x}", 2 * x) for x in range(100)], ["col1", "col2", "col3"]
133
- )
134
- df2.write.saveAsTable(cls.tbl_name2)
135
- df3 = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "test\n_column"])
136
- df3.write.saveAsTable(cls.tbl_name3)
137
- df4 = cls.spark.createDataFrame(
138
- [(x, {"a": x}, [x, x * 2]) for x in range(100)], ["id", "map_column", "array_column"]
139
- )
140
- df4.write.saveAsTable(cls.tbl_name4)
141
- empty_table_schema = StructType(
142
- [
143
- StructField("firstname", StringType(), True),
144
- StructField("middlename", StringType(), True),
145
- StructField("lastname", StringType(), True),
146
- ]
147
- )
148
- emptyRDD = cls.spark.sparkContext.emptyRDD()
149
- empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema)
150
- empty_df.write.saveAsTable(cls.tbl_name_empty)
151
-
152
- @classmethod
153
- def spark_connect_clean_up_test_data(cls):
154
- cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
155
- cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name2))
156
- cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name3))
157
- cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name4))
158
- cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))
159
-
160
-
161
- class SparkConnectBasicTests(SparkConnectSQLTestCase):
162
- def test_df_getattr_behavior(self):
163
- cdf = self.connect.range(10)
164
- sdf = self.spark.range(10)
165
-
166
- sdf._simple_extension = 10
167
- cdf._simple_extension = 10
168
-
169
- self.assertEqual(sdf._simple_extension, cdf._simple_extension)
170
- self.assertEqual(type(sdf._simple_extension), type(cdf._simple_extension))
171
-
172
- self.assertTrue(hasattr(cdf, "_simple_extension"))
173
- self.assertFalse(hasattr(cdf, "_simple_extension_does_not_exsit"))
174
-
175
- def test_df_get_item(self):
176
- # SPARK-41779: test __getitem__
177
-
178
- query = """
179
- SELECT * FROM VALUES
180
- (true, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
181
- AS tab(a, b, c)
182
- """
183
-
184
- # +-----+----+----+
185
- # | a| b| c|
186
- # +-----+----+----+
187
- # | true| 1|NULL|
188
- # |false|NULL| 2.0|
189
- # | NULL| 3| 3.0|
190
- # +-----+----+----+
191
-
192
- cdf = self.connect.sql(query)
193
- sdf = self.spark.sql(query)
194
-
195
- # filter
196
- self.assert_eq(
197
- cdf[cdf.a].toPandas(),
198
- sdf[sdf.a].toPandas(),
199
- )
200
- self.assert_eq(
201
- cdf[cdf.b.isin(2, 3)].toPandas(),
202
- sdf[sdf.b.isin(2, 3)].toPandas(),
203
- )
204
- self.assert_eq(
205
- cdf[cdf.c > 1.5].toPandas(),
206
- sdf[sdf.c > 1.5].toPandas(),
207
- )
208
-
209
- # select
210
- self.assert_eq(
211
- cdf[[cdf.a, "b", cdf.c]].toPandas(),
212
- sdf[[sdf.a, "b", sdf.c]].toPandas(),
213
- )
214
- self.assert_eq(
215
- cdf[(cdf.a, "b", cdf.c)].toPandas(),
216
- sdf[(sdf.a, "b", sdf.c)].toPandas(),
217
- )
218
-
219
- # select by index
220
- self.assertTrue(isinstance(cdf[0], Column))
221
- self.assertTrue(isinstance(cdf[1], Column))
222
- self.assertTrue(isinstance(cdf[2], Column))
223
-
224
- self.assert_eq(
225
- cdf[[cdf[0], cdf[1], cdf[2]]].toPandas(),
226
- sdf[[sdf[0], sdf[1], sdf[2]]].toPandas(),
227
- )
228
-
229
- # check error
230
- with self.assertRaises(PySparkTypeError) as pe:
231
- cdf[1.5]
232
-
233
- self.check_error(
234
- exception=pe.exception,
235
- error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
236
- message_parameters={
237
- "arg_name": "item",
238
- "arg_type": "float",
239
- },
240
- )
241
-
242
- with self.assertRaises(PySparkTypeError) as pe:
243
- cdf[None]
244
-
245
- self.check_error(
246
- exception=pe.exception,
247
- error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
248
- message_parameters={
249
- "arg_name": "item",
250
- "arg_type": "NoneType",
251
- },
252
- )
253
-
254
- with self.assertRaises(PySparkTypeError) as pe:
255
- cdf[cdf]
256
-
257
- self.check_error(
258
- exception=pe.exception,
259
- error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
260
- message_parameters={
261
- "arg_name": "item",
262
- "arg_type": "DataFrame",
263
- },
264
- )
265
-
266
- def test_error_handling(self):
267
- # SPARK-41533 Proper error handling for Spark Connect
268
- df = self.connect.range(10).select("id2")
269
- with self.assertRaises(AnalysisException):
270
- df.collect()
271
-
272
- def test_simple_read(self):
273
- df = self.connect.read.table(self.tbl_name)
274
- data = df.limit(10).toPandas()
275
- # Check that the limit is applied
276
- self.assertEqual(len(data.index), 10)
277
-
278
- def test_json(self):
279
- with tempfile.TemporaryDirectory() as d:
280
- # Write a DataFrame into a JSON file
281
- self.spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
282
- "overwrite"
283
- ).format("json").save(d)
284
- # Read the JSON file as a DataFrame.
285
- self.assert_eq(self.connect.read.json(d).toPandas(), self.spark.read.json(d).toPandas())
286
-
287
- for schema in [
288
- "age INT, name STRING",
289
- StructType(
290
- [
291
- StructField("age", IntegerType()),
292
- StructField("name", StringType()),
293
- ]
294
- ),
295
- ]:
296
- self.assert_eq(
297
- self.connect.read.json(path=d, schema=schema).toPandas(),
298
- self.spark.read.json(path=d, schema=schema).toPandas(),
299
- )
300
-
301
- self.assert_eq(
302
- self.connect.read.json(path=d, primitivesAsString=True).toPandas(),
303
- self.spark.read.json(path=d, primitivesAsString=True).toPandas(),
304
- )
305
-
306
- def test_parquet(self):
307
- # SPARK-41445: Implement DataFrameReader.parquet
308
- with tempfile.TemporaryDirectory() as d:
309
- # Write a DataFrame into a Parquet file
310
- self.spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
311
- "overwrite"
312
- ).format("parquet").save(d)
313
- # Read the Parquet file as a DataFrame.
314
- self.assert_eq(
315
- self.connect.read.parquet(d).toPandas(), self.spark.read.parquet(d).toPandas()
316
- )
317
-
318
- def test_text(self):
319
- # SPARK-41849: Implement DataFrameReader.text
320
- with tempfile.TemporaryDirectory() as d:
321
- # Write a DataFrame into a text file
322
- self.spark.createDataFrame(
323
- [{"name": "Sandeep Singh"}, {"name": "Hyukjin Kwon"}]
324
- ).write.mode("overwrite").format("text").save(d)
325
- # Read the text file as a DataFrame.
326
- self.assert_eq(self.connect.read.text(d).toPandas(), self.spark.read.text(d).toPandas())
327
-
328
- def test_csv(self):
329
- # SPARK-42011: Implement DataFrameReader.csv
330
- with tempfile.TemporaryDirectory() as d:
331
- # Write a DataFrame into a CSV file
332
- self.spark.createDataFrame(
333
- [{"name": "Sandeep Singh"}, {"name": "Hyukjin Kwon"}]
334
- ).write.mode("overwrite").format("csv").save(d)
335
- # Read the CSV file as a DataFrame.
336
- self.assert_eq(self.connect.read.csv(d).toPandas(), self.spark.read.csv(d).toPandas())
337
-
338
- def test_multi_paths(self):
339
- # SPARK-42041: DataFrameReader should support list of paths
340
-
341
- with tempfile.TemporaryDirectory() as d:
342
- text_files = []
343
- for i in range(0, 3):
344
- text_file = f"{d}/text-{i}.text"
345
- shutil.copyfile("python/test_support/sql/text-test.txt", text_file)
346
- text_files.append(text_file)
347
-
348
- self.assertEqual(
349
- self.connect.read.text(text_files).collect(),
350
- self.spark.read.text(text_files).collect(),
351
- )
352
-
353
- with tempfile.TemporaryDirectory() as d:
354
- json_files = []
355
- for i in range(0, 5):
356
- json_file = f"{d}/json-{i}.json"
357
- shutil.copyfile("python/test_support/sql/people.json", json_file)
358
- json_files.append(json_file)
359
-
360
- self.assertEqual(
361
- self.connect.read.json(json_files).collect(),
362
- self.spark.read.json(json_files).collect(),
363
- )
364
-
365
- def test_orc(self):
366
- # SPARK-42012: Implement DataFrameReader.orc
367
- with tempfile.TemporaryDirectory() as d:
368
- # Write a DataFrame into an ORC file
369
- self.spark.createDataFrame(
370
- [{"name": "Sandeep Singh"}, {"name": "Hyukjin Kwon"}]
371
- ).write.mode("overwrite").format("orc").save(d)
372
- # Read the ORC file as a DataFrame.
373
- self.assert_eq(self.connect.read.orc(d).toPandas(), self.spark.read.orc(d).toPandas())
374
-
375
- def test_join_condition_column_list_columns(self):
376
- left_connect_df = self.connect.read.table(self.tbl_name)
377
- right_connect_df = self.connect.read.table(self.tbl_name2)
378
- left_spark_df = self.spark.read.table(self.tbl_name)
379
- right_spark_df = self.spark.read.table(self.tbl_name2)
380
- joined_plan = left_connect_df.join(
381
- other=right_connect_df, on=left_connect_df.id == right_connect_df.col1, how="inner"
382
- )
383
- joined_plan2 = left_spark_df.join(
384
- other=right_spark_df, on=left_spark_df.id == right_spark_df.col1, how="inner"
385
- )
386
- self.assert_eq(joined_plan.toPandas(), joined_plan2.toPandas())
387
-
388
- joined_plan3 = left_connect_df.join(
389
- other=right_connect_df,
390
- on=[
391
- left_connect_df.id == right_connect_df.col1,
392
- left_connect_df.name == right_connect_df.col2,
393
- ],
394
- how="inner",
395
- )
396
- joined_plan4 = left_spark_df.join(
397
- other=right_spark_df,
398
- on=[left_spark_df.id == right_spark_df.col1, left_spark_df.name == right_spark_df.col2],
399
- how="inner",
400
- )
401
- self.assert_eq(joined_plan3.toPandas(), joined_plan4.toPandas())
402
-
403
- def test_join_ambiguous_cols(self):
404
- # SPARK-41812: test join with ambiguous columns
405
- data1 = [Row(id=1, value="foo"), Row(id=2, value=None)]
406
- cdf1 = self.connect.createDataFrame(data1)
407
- sdf1 = self.spark.createDataFrame(data1)
408
-
409
- data2 = [Row(value="bar"), Row(value=None), Row(value="foo")]
410
- cdf2 = self.connect.createDataFrame(data2)
411
- sdf2 = self.spark.createDataFrame(data2)
412
-
413
- cdf3 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"])
414
- sdf3 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"])
415
-
416
- self.assertEqual(cdf3.schema, sdf3.schema)
417
- self.assertEqual(cdf3.collect(), sdf3.collect())
418
-
419
- cdf4 = cdf1.join(cdf2, cdf1["value"].eqNullSafe(cdf2["value"]))
420
- sdf4 = sdf1.join(sdf2, sdf1["value"].eqNullSafe(sdf2["value"]))
421
-
422
- self.assertEqual(cdf4.schema, sdf4.schema)
423
- self.assertEqual(cdf4.collect(), sdf4.collect())
424
-
425
- cdf5 = cdf1.join(
426
- cdf2, (cdf1["value"] == cdf2["value"]) & (cdf1["value"].eqNullSafe(cdf2["value"]))
427
- )
428
- sdf5 = sdf1.join(
429
- sdf2, (sdf1["value"] == sdf2["value"]) & (sdf1["value"].eqNullSafe(sdf2["value"]))
430
- )
431
-
432
- self.assertEqual(cdf5.schema, sdf5.schema)
433
- self.assertEqual(cdf5.collect(), sdf5.collect())
434
-
435
- cdf6 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf1.value)
436
- sdf6 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf1.value)
437
-
438
- self.assertEqual(cdf6.schema, sdf6.schema)
439
- self.assertEqual(cdf6.collect(), sdf6.collect())
440
-
441
- cdf7 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf2.value)
442
- sdf7 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf2.value)
443
-
444
- self.assertEqual(cdf7.schema, sdf7.schema)
445
- self.assertEqual(cdf7.collect(), sdf7.collect())
446
-
447
- def test_invalid_column(self):
448
- # SPARK-41812: fail df1.select(df2.col)
449
- data1 = [Row(a=1, b=2, c=3)]
450
- cdf1 = self.connect.createDataFrame(data1)
451
-
452
- data2 = [Row(a=2, b=0)]
453
- cdf2 = self.connect.createDataFrame(data2)
454
-
455
- with self.assertRaises(AnalysisException):
456
- cdf1.select(cdf2.a).schema
457
-
458
- with self.assertRaises(AnalysisException):
459
- cdf2.withColumn("x", cdf1.a + 1).schema
460
-
461
- with self.assertRaisesRegex(AnalysisException, "attribute.*missing"):
462
- cdf3 = cdf1.select(cdf1.a)
463
- cdf3.select(cdf1.b).schema
464
-
465
- def test_collect(self):
466
- cdf = self.connect.read.table(self.tbl_name)
467
- sdf = self.spark.read.table(self.tbl_name)
468
-
469
- data = cdf.limit(10).collect()
470
- self.assertEqual(len(data), 10)
471
- # Check Row has schema column names.
472
- self.assertTrue("name" in data[0])
473
- self.assertTrue("id" in data[0])
474
-
475
- cdf = cdf.select(
476
- CF.log("id"), CF.log("id"), CF.struct("id", "name"), CF.struct("id", "name")
477
- ).limit(10)
478
- sdf = sdf.select(
479
- SF.log("id"), SF.log("id"), SF.struct("id", "name"), SF.struct("id", "name")
480
- ).limit(10)
481
-
482
- self.assertEqual(
483
- cdf.collect(),
484
- sdf.collect(),
485
- )
486
-
487
- def test_collect_timestamp(self):
488
- query = """
489
- SELECT * FROM VALUES
490
- (TIMESTAMP('2022-12-25 10:30:00'), 1),
491
- (TIMESTAMP('2022-12-25 10:31:00'), 2),
492
- (TIMESTAMP('2022-12-25 10:32:00'), 1),
493
- (TIMESTAMP('2022-12-25 10:33:00'), 2),
494
- (TIMESTAMP('2022-12-26 09:30:00'), 1),
495
- (TIMESTAMP('2022-12-26 09:35:00'), 3)
496
- AS tab(date, val)
497
- """
498
-
499
- cdf = self.connect.sql(query)
500
- sdf = self.spark.sql(query)
501
-
502
- self.assertEqual(cdf.schema, sdf.schema)
503
-
504
- self.assertEqual(cdf.collect(), sdf.collect())
505
-
506
- self.assertEqual(
507
- cdf.select(CF.date_trunc("year", cdf.date).alias("year")).collect(),
508
- sdf.select(SF.date_trunc("year", sdf.date).alias("year")).collect(),
509
- )
510
-
511
- def test_with_columns_renamed(self):
512
- # SPARK-41312: test DataFrame.withColumnsRenamed()
513
- self.assertEqual(
514
- self.connect.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema,
515
- self.spark.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema,
516
- )
517
- self.assertEqual(
518
- self.connect.read.table(self.tbl_name)
519
- .withColumnsRenamed({"id": "id_new", "name": "name_new"})
520
- .schema,
521
- self.spark.read.table(self.tbl_name)
522
- .withColumnsRenamed({"id": "id_new", "name": "name_new"})
523
- .schema,
524
- )
525
-
526
- def test_with_local_data(self):
527
- """SPARK-41114: Test creating a dataframe using local data"""
528
- pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
529
- df = self.connect.createDataFrame(pdf)
530
- rows = df.filter(df.a == CF.lit(3)).collect()
531
- self.assertTrue(len(rows) == 1)
532
- self.assertEqual(rows[0][0], 3)
533
- self.assertEqual(rows[0][1], "c")
534
-
535
- # Check correct behavior for empty DataFrame
536
- pdf = pd.DataFrame({"a": []})
537
- with self.assertRaises(ValueError):
538
- self.connect.createDataFrame(pdf)
539
-
540
- def test_with_local_ndarray(self):
541
- """SPARK-41446: Test creating a dataframe using local list"""
542
- data = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
543
-
544
- sdf = self.spark.createDataFrame(data)
545
- cdf = self.connect.createDataFrame(data)
546
- self.assertEqual(sdf.schema, cdf.schema)
547
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
548
-
549
- for schema in [
550
- StructType(
551
- [
552
- StructField("col1", IntegerType(), True),
553
- StructField("col2", IntegerType(), True),
554
- StructField("col3", IntegerType(), True),
555
- StructField("col4", IntegerType(), True),
556
- ]
557
- ),
558
- "struct<col1 int, col2 int, col3 int, col4 int>",
559
- "col1 int, col2 int, col3 int, col4 int",
560
- "col1 int, col2 long, col3 string, col4 long",
561
- "col1 int, col2 string, col3 short, col4 long",
562
- ["a", "b", "c", "d"],
563
- ("x1", "x2", "x3", "x4"),
564
- ]:
565
- with self.subTest(schema=schema):
566
- sdf = self.spark.createDataFrame(data, schema=schema)
567
- cdf = self.connect.createDataFrame(data, schema=schema)
568
-
569
- self.assertEqual(sdf.schema, cdf.schema)
570
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
571
-
572
- with self.assertRaises(PySparkValueError) as pe:
573
- self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
574
-
575
- self.check_error(
576
- exception=pe.exception,
577
- error_class="AXIS_LENGTH_MISMATCH",
578
- message_parameters={"expected_length": "5", "actual_length": "4"},
579
- )
580
-
581
- with self.assertRaises(ParseException):
582
- self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
583
-
584
- with self.assertRaises(PySparkValueError) as pe:
585
- self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
586
-
587
- self.check_error(
588
- exception=pe.exception,
589
- error_class="AXIS_LENGTH_MISMATCH",
590
- message_parameters={"expected_length": "3", "actual_length": "4"},
591
- )
592
-
593
- # test 1 dim ndarray
594
- data = np.array([1.0, 2.0, np.nan, 3.0, 4.0, float("NaN"), 5.0])
595
- self.assertEqual(data.ndim, 1)
596
-
597
- sdf = self.spark.createDataFrame(data)
598
- cdf = self.connect.createDataFrame(data)
599
- self.assertEqual(sdf.schema, cdf.schema)
600
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
601
-
602
- def test_with_local_list(self):
603
- """SPARK-41446: Test creating a dataframe using local list"""
604
- data = [[1, 2, 3, 4]]
605
-
606
- sdf = self.spark.createDataFrame(data)
607
- cdf = self.connect.createDataFrame(data)
608
- self.assertEqual(sdf.schema, cdf.schema)
609
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
610
-
611
- for schema in [
612
- "struct<col1 int, col2 int, col3 int, col4 int>",
613
- "col1 int, col2 int, col3 int, col4 int",
614
- "col1 int, col2 long, col3 string, col4 long",
615
- "col1 int, col2 string, col3 short, col4 long",
616
- ["a", "b", "c", "d"],
617
- ("x1", "x2", "x3", "x4"),
618
- ]:
619
- sdf = self.spark.createDataFrame(data, schema=schema)
620
- cdf = self.connect.createDataFrame(data, schema=schema)
621
-
622
- self.assertEqual(sdf.schema, cdf.schema)
623
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
624
-
625
- with self.assertRaises(PySparkValueError) as pe:
626
- self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
627
-
628
- self.check_error(
629
- exception=pe.exception,
630
- error_class="AXIS_LENGTH_MISMATCH",
631
- message_parameters={"expected_length": "5", "actual_length": "4"},
632
- )
633
-
634
- with self.assertRaises(ParseException):
635
- self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
636
-
637
- with self.assertRaisesRegex(
638
- ValueError,
639
- "Length mismatch: Expected axis has 3 elements, new values have 4 elements",
640
- ):
641
- self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
642
-
643
- def test_with_local_rows(self):
644
- # SPARK-41789, SPARK-41810: Test creating a dataframe with list of rows and dictionaries
645
- rows = [
646
- Row(course="dotNET", year=2012, earnings=10000),
647
- Row(course="Java", year=2012, earnings=20000),
648
- Row(course="dotNET", year=2012, earnings=5000),
649
- Row(course="dotNET", year=2013, earnings=48000),
650
- Row(course="Java", year=2013, earnings=30000),
651
- Row(course="Scala", year=2022, earnings=None),
652
- ]
653
- dicts = [row.asDict() for row in rows]
654
-
655
- for data in [rows, dicts]:
656
- sdf = self.spark.createDataFrame(data)
657
- cdf = self.connect.createDataFrame(data)
658
-
659
- self.assertEqual(sdf.schema, cdf.schema)
660
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
661
-
662
- # test with rename
663
- sdf = self.spark.createDataFrame(data, schema=["a", "b", "c"])
664
- cdf = self.connect.createDataFrame(data, schema=["a", "b", "c"])
665
-
666
- self.assertEqual(sdf.schema, cdf.schema)
667
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
668
-
669
- def test_streaming_local_relation(self):
670
- threshold_conf = "spark.sql.session.localRelationCacheThreshold"
671
- old_threshold = self.connect.conf.get(threshold_conf)
672
- threshold = 1024 * 1024
673
- self.connect.conf.set(threshold_conf, threshold)
674
- try:
675
- suffix = "abcdef"
676
- letters = string.ascii_lowercase
677
- str = "".join(random.choice(letters) for i in range(threshold)) + suffix
678
- data = [[0, str], [1, str]]
679
- for i in range(0, 2):
680
- cdf = self.connect.createDataFrame(data, ["a", "b"])
681
- self.assert_eq(cdf.count(), len(data))
682
- self.assert_eq(cdf.filter(f"endsWith(b, '{suffix}')").isEmpty(), False)
683
- finally:
684
- self.connect.conf.set(threshold_conf, old_threshold)
685
-
686
- def test_with_atom_type(self):
687
- for data in [[(1), (2), (3)], [1, 2, 3]]:
688
- for schema in ["long", "int", "short"]:
689
- sdf = self.spark.createDataFrame(data, schema=schema)
690
- cdf = self.connect.createDataFrame(data, schema=schema)
691
-
692
- self.assertEqual(sdf.schema, cdf.schema)
693
- self.assert_eq(sdf.toPandas(), cdf.toPandas())
694
-
695
- def test_with_none_and_nan(self):
696
- # SPARK-41855: make createDataFrame support None and NaN
697
- # SPARK-41814: test with eqNullSafe
698
- data1 = [Row(id=1, value=float("NaN")), Row(id=2, value=42.0), Row(id=3, value=None)]
699
- data2 = [Row(id=1, value=np.nan), Row(id=2, value=42.0), Row(id=3, value=None)]
700
- data3 = [
701
- {"id": 1, "value": float("NaN")},
702
- {"id": 2, "value": 42.0},
703
- {"id": 3, "value": None},
704
- ]
705
- data4 = [{"id": 1, "value": np.nan}, {"id": 2, "value": 42.0}, {"id": 3, "value": None}]
706
- data5 = [(1, float("NaN")), (2, 42.0), (3, None)]
707
- data6 = [(1, np.nan), (2, 42.0), (3, None)]
708
- data7 = np.array([[1, float("NaN")], [2, 42.0], [3, None]])
709
- data8 = np.array([[1, np.nan], [2, 42.0], [3, None]])
710
-
711
- # +---+-----+
712
- # | id|value|
713
- # +---+-----+
714
- # | 1| NaN|
715
- # | 2| 42.0|
716
- # | 3| NULL|
717
- # +---+-----+
718
-
719
- for data in [data1, data2, data3, data4, data5, data6, data7, data8]:
720
- if isinstance(data[0], (Row, dict)):
721
- # data1, data2, data3, data4
722
- cdf = self.connect.createDataFrame(data)
723
- sdf = self.spark.createDataFrame(data)
724
- else:
725
- # data5, data6, data7, data8
726
- cdf = self.connect.createDataFrame(data, schema=["id", "value"])
727
- sdf = self.spark.createDataFrame(data, schema=["id", "value"])
728
-
729
- self.assert_eq(cdf.toPandas(), sdf.toPandas())
730
-
731
- self.assert_eq(
732
- cdf.select(
733
- cdf["value"].eqNullSafe(None),
734
- cdf["value"].eqNullSafe(float("NaN")),
735
- cdf["value"].eqNullSafe(42.0),
736
- ).toPandas(),
737
- sdf.select(
738
- sdf["value"].eqNullSafe(None),
739
- sdf["value"].eqNullSafe(float("NaN")),
740
- sdf["value"].eqNullSafe(42.0),
741
- ).toPandas(),
742
- )
743
-
744
- # SPARK-41851: test with nanvl
745
- data = [(1.0, float("nan")), (float("nan"), 2.0)]
746
-
747
- cdf = self.connect.createDataFrame(data, ("a", "b"))
748
- sdf = self.spark.createDataFrame(data, ("a", "b"))
749
-
750
- self.assert_eq(cdf.toPandas(), sdf.toPandas())
751
-
752
- self.assert_eq(
753
- cdf.select(
754
- CF.nanvl("a", "b").alias("r1"), CF.nanvl(cdf.a, cdf.b).alias("r2")
755
- ).toPandas(),
756
- sdf.select(
757
- SF.nanvl("a", "b").alias("r1"), SF.nanvl(sdf.a, sdf.b).alias("r2")
758
- ).toPandas(),
759
- )
760
-
761
- # SPARK-41852: test with pmod
762
- data = [
763
- (1.0, float("nan")),
764
- (float("nan"), 2.0),
765
- (10.0, 3.0),
766
- (float("nan"), float("nan")),
767
- (-3.0, 4.0),
768
- (-10.0, 3.0),
769
- (-5.0, -6.0),
770
- (7.0, -8.0),
771
- (1.0, 2.0),
772
- ]
773
-
774
- cdf = self.connect.createDataFrame(data, ("a", "b"))
775
- sdf = self.spark.createDataFrame(data, ("a", "b"))
776
-
777
- self.assert_eq(cdf.toPandas(), sdf.toPandas())
778
-
779
- self.assert_eq(
780
- cdf.select(CF.pmod("a", "b")).toPandas(),
781
- sdf.select(SF.pmod("a", "b")).toPandas(),
782
- )
783
-
784
- def test_cast_with_ddl(self):
785
- data = [Row(date=datetime.date(2021, 12, 27), add=2)]
786
-
787
- cdf = self.connect.createDataFrame(data, "date date, add integer")
788
- sdf = self.spark.createDataFrame(data, "date date, add integer")
789
-
790
- self.assertEqual(cdf.schema, sdf.schema)
791
-
792
- def test_create_empty_df(self):
793
- for schema in [
794
- "STRING",
795
- "x STRING",
796
- "x STRING, y INTEGER",
797
- StringType(),
798
- StructType(
799
- [
800
- StructField("x", StringType(), True),
801
- StructField("y", IntegerType(), True),
802
- ]
803
- ),
804
- ]:
805
- cdf = self.connect.createDataFrame(data=[], schema=schema)
806
- sdf = self.spark.createDataFrame(data=[], schema=schema)
807
-
808
- self.assert_eq(cdf.toPandas(), sdf.toPandas())
809
-
810
- # check error
811
- with self.assertRaises(PySparkValueError) as pe:
812
- self.connect.createDataFrame(data=[])
813
-
814
- self.check_error(
815
- exception=pe.exception,
816
- error_class="CANNOT_INFER_EMPTY_SCHEMA",
817
- message_parameters={},
818
- )
819
-
820
- def test_create_dataframe_from_arrays(self):
821
- # SPARK-42021: createDataFrame support array.array
822
- data1 = [Row(a=1, b=array.array("i", [1, 2, 3]), c=array.array("d", [4, 5, 6]))]
823
- data2 = [(array.array("d", [1, 2, 3]), 2, "3")]
824
- data3 = [{"a": 1, "b": array.array("i", [1, 2, 3])}]
825
-
826
- for data in [data1, data2, data3]:
827
- cdf = self.connect.createDataFrame(data)
828
- sdf = self.spark.createDataFrame(data)
829
-
830
- # TODO: the nullability is different, need to fix
831
- # self.assertEqual(cdf.schema, sdf.schema)
832
- self.assertEqual(cdf.collect(), sdf.collect())
833
-
834
- def test_timestampe_create_from_rows(self):
835
- data = [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)]
836
-
837
- cdf = self.connect.createDataFrame(data, ["date", "val"])
838
- sdf = self.spark.createDataFrame(data, ["date", "val"])
839
-
840
- self.assertEqual(cdf.schema, sdf.schema)
841
- self.assertEqual(cdf.collect(), sdf.collect())
842
-
843
- def test_create_dataframe_with_coercion(self):
844
- data1 = [[1.33, 1], ["2.1", 1]]
845
- data2 = [[True, 1], ["false", 1]]
846
-
847
- for data in [data1, data2]:
848
- cdf = self.connect.createDataFrame(data, ["a", "b"])
849
- sdf = self.spark.createDataFrame(data, ["a", "b"])
850
-
851
- self.assertEqual(cdf.schema, sdf.schema)
852
- self.assertEqual(cdf.collect(), sdf.collect())
853
-
854
- def test_nested_type_create_from_rows(self):
855
- data1 = [Row(a=1, b=Row(c=2, d=Row(e=3, f=Row(g=4, h=Row(i=5)))))]
856
- # root
857
- # |-- a: long (nullable = true)
858
- # |-- b: struct (nullable = true)
859
- # | |-- c: long (nullable = true)
860
- # | |-- d: struct (nullable = true)
861
- # | | |-- e: long (nullable = true)
862
- # | | |-- f: struct (nullable = true)
863
- # | | | |-- g: long (nullable = true)
864
- # | | | |-- h: struct (nullable = true)
865
- # | | | | |-- i: long (nullable = true)
866
-
867
- data2 = [
868
- (
869
- 1,
870
- "a",
871
- Row(
872
- a=1,
873
- b=[1, 2, 3],
874
- c={"a": "b"},
875
- d=Row(x=1, y="y", z=Row(o=1, p=2, q=Row(g=1.5))),
876
- ),
877
- )
878
- ]
879
- # root
880
- # |-- _1: long (nullable = true)
881
- # |-- _2: string (nullable = true)
882
- # |-- _3: struct (nullable = true)
883
- # | |-- a: long (nullable = true)
884
- # | |-- b: array (nullable = true)
885
- # | | |-- element: long (containsNull = true)
886
- # | |-- c: map (nullable = true)
887
- # | | |-- key: string
888
- # | | |-- value: string (valueContainsNull = true)
889
- # | |-- d: struct (nullable = true)
890
- # | | |-- x: long (nullable = true)
891
- # | | |-- y: string (nullable = true)
892
- # | | |-- z: struct (nullable = true)
893
- # | | | |-- o: long (nullable = true)
894
- # | | | |-- p: long (nullable = true)
895
- # | | | |-- q: struct (nullable = true)
896
- # | | | | |-- g: double (nullable = true)
897
-
898
- data3 = [
899
- Row(
900
- a=1,
901
- b=[1, 2, 3],
902
- c={"a": "b"},
903
- d=Row(x=1, y="y", z=Row(1, 2, 3)),
904
- e=list("hello connect"),
905
- )
906
- ]
907
- # root
908
- # |-- a: long (nullable = true)
909
- # |-- b: array (nullable = true)
910
- # | |-- element: long (containsNull = true)
911
- # |-- c: map (nullable = true)
912
- # | |-- key: string
913
- # | |-- value: string (valueContainsNull = true)
914
- # |-- d: struct (nullable = true)
915
- # | |-- x: long (nullable = true)
916
- # | |-- y: string (nullable = true)
917
- # | |-- z: struct (nullable = true)
918
- # | | |-- _1: long (nullable = true)
919
- # | | |-- _2: long (nullable = true)
920
- # | | |-- _3: long (nullable = true)
921
- # |-- e: array (nullable = true)
922
- # | |-- element: string (containsNull = true)
923
-
924
- data4 = [
925
- {
926
- "a": 1,
927
- "b": Row(x=1, y=Row(z=2)),
928
- "c": {"x": -1, "y": 2},
929
- "d": [1, 2, 3, 4, 5],
930
- }
931
- ]
932
- # root
933
- # |-- a: long (nullable = true)
934
- # |-- b: struct (nullable = true)
935
- # | |-- x: long (nullable = true)
936
- # | |-- y: struct (nullable = true)
937
- # | | |-- z: long (nullable = true)
938
- # |-- c: map (nullable = true)
939
- # | |-- key: string
940
- # | |-- value: long (valueContainsNull = true)
941
- # |-- d: array (nullable = true)
942
- # | |-- element: long (containsNull = true)
943
-
944
- data5 = [
945
- {
946
- "a": [Row(x=1, y="2"), Row(x=-1, y="-2")],
947
- "b": [[1, 2, 3], [4, 5], [6]],
948
- "c": {3: {4: {5: 6}}, 7: {8: {9: 0}}},
949
- }
950
- ]
951
- # root
952
- # |-- a: array (nullable = true)
953
- # | |-- element: struct (containsNull = true)
954
- # | | |-- x: long (nullable = true)
955
- # | | |-- y: string (nullable = true)
956
- # |-- b: array (nullable = true)
957
- # | |-- element: array (containsNull = true)
958
- # | | |-- element: long (containsNull = true)
959
- # |-- c: map (nullable = true)
960
- # | |-- key: long
961
- # | |-- value: map (valueContainsNull = true)
962
- # | | |-- key: long
963
- # | | |-- value: map (valueContainsNull = true)
964
- # | | | |-- key: long
965
- # | | | |-- value: long (valueContainsNull = true)
966
-
967
- for data in [data1, data2, data3, data4, data5]:
968
- with self.subTest(data=data):
969
- cdf = self.connect.createDataFrame(data)
970
- sdf = self.spark.createDataFrame(data)
971
-
972
- self.assertEqual(cdf.schema, sdf.schema)
973
- self.assertEqual(cdf.collect(), sdf.collect())
974
-
975
- def test_create_df_from_objects(self):
976
- data = [MyObject(1, "1"), MyObject(2, "2")]
977
-
978
- # +---+-----+
979
- # |key|value|
980
- # +---+-----+
981
- # | 1| 1|
982
- # | 2| 2|
983
- # +---+-----+
984
-
985
- cdf = self.connect.createDataFrame(data)
986
- sdf = self.spark.createDataFrame(data)
987
-
988
- self.assertEqual(cdf.schema, sdf.schema)
989
- self.assertEqual(cdf.collect(), sdf.collect())
990
-
991
- def test_simple_explain_string(self):
992
- df = self.connect.read.table(self.tbl_name).limit(10)
993
- result = df._explain_string()
994
- self.assertGreater(len(result), 0)
995
-
996
- def test_schema(self):
997
- schema = self.connect.read.table(self.tbl_name).schema
998
- self.assertEqual(
999
- StructType(
1000
- [StructField("id", LongType(), True), StructField("name", StringType(), True)]
1001
- ),
1002
- schema,
1003
- )
1004
-
1005
- # test FloatType, DoubleType, DecimalType, StringType, BooleanType, NullType
1006
- query = """
1007
- SELECT * FROM VALUES
1008
- (float(1.0), double(1.0), 1.0, "1", true, NULL),
1009
- (float(2.0), double(2.0), 2.0, "2", false, NULL),
1010
- (float(3.0), double(3.0), NULL, "3", false, NULL)
1011
- AS tab(a, b, c, d, e, f)
1012
- """
1013
- self.assertEqual(
1014
- self.spark.sql(query).schema,
1015
- self.connect.sql(query).schema,
1016
- )
1017
-
1018
- # test TimestampType, DateType
1019
- query = """
1020
- SELECT * FROM VALUES
1021
- (TIMESTAMP('2019-04-12 15:50:00'), DATE('2022-02-22')),
1022
- (TIMESTAMP('2019-04-12 15:50:00'), NULL),
1023
- (NULL, DATE('2022-02-22'))
1024
- AS tab(a, b)
1025
- """
1026
- self.assertEqual(
1027
- self.spark.sql(query).schema,
1028
- self.connect.sql(query).schema,
1029
- )
1030
-
1031
- # test DayTimeIntervalType
1032
- query = """ SELECT INTERVAL '100 10:30' DAY TO MINUTE AS interval """
1033
- self.assertEqual(
1034
- self.spark.sql(query).schema,
1035
- self.connect.sql(query).schema,
1036
- )
1037
-
1038
- # test MapType
1039
- query = """
1040
- SELECT * FROM VALUES
1041
- (MAP('a', 'ab'), MAP('a', 'ab'), MAP(1, 2, 3, 4)),
1042
- (MAP('x', 'yz'), MAP('x', NULL), NULL),
1043
- (MAP('c', 'de'), NULL, MAP(-1, NULL, -3, -4))
1044
- AS tab(a, b, c)
1045
- """
1046
- self.assertEqual(
1047
- self.spark.sql(query).schema,
1048
- self.connect.sql(query).schema,
1049
- )
1050
-
1051
- # test ArrayType
1052
- query = """
1053
- SELECT * FROM VALUES
1054
- (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3)),
1055
- (ARRAY('x', NULL), NULL, ARRAY(1, 3)),
1056
- (NULL, ARRAY(-1, -2, -3), Array())
1057
- AS tab(a, b, c)
1058
- """
1059
- self.assertEqual(
1060
- self.spark.sql(query).schema,
1061
- self.connect.sql(query).schema,
1062
- )
1063
-
1064
- # test StructType
1065
- query = """
1066
- SELECT STRUCT(a, b, c, d), STRUCT(e, f, g), STRUCT(STRUCT(a, b), STRUCT(h)) FROM VALUES
1067
- (float(1.0), double(1.0), 1.0, "1", true, NULL, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
1068
- (float(2.0), double(2.0), 2.0, "2", false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
1069
- (float(3.0), double(3.0), NULL, "3", false, NULL, ARRAY(NULL), NULL)
1070
- AS tab(a, b, c, d, e, f, g, h)
1071
- """
1072
- self.assertEqual(
1073
- self.spark.sql(query).schema,
1074
- self.connect.sql(query).schema,
1075
- )
1076
-
1077
- def test_to(self):
1078
- # SPARK-41464: test DataFrame.to()
1079
-
1080
- cdf = self.connect.read.table(self.tbl_name)
1081
- df = self.spark.read.table(self.tbl_name)
1082
-
1083
- def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType):
1084
- cdf_to = cdf.to(schema)
1085
- df_to = df.to(schema)
1086
- self.assertEqual(cdf_to.schema, df_to.schema)
1087
- self.assert_eq(cdf_to.toPandas(), df_to.toPandas())
1088
-
1089
- # The schema has not changed
1090
- schema = StructType(
1091
- [
1092
- StructField("id", IntegerType(), True),
1093
- StructField("name", StringType(), True),
1094
- ]
1095
- )
1096
-
1097
- assert_eq_schema(cdf, df, schema)
1098
-
1099
- # Change schema with struct
1100
- schema2 = StructType([StructField("struct", schema, False)])
1101
-
1102
- cdf_to = cdf.select(CF.struct("id", "name").alias("struct")).to(schema2)
1103
- df_to = df.select(SF.struct("id", "name").alias("struct")).to(schema2)
1104
-
1105
- self.assertEqual(cdf_to.schema, df_to.schema)
1106
-
1107
- # Change the column name
1108
- schema = StructType(
1109
- [
1110
- StructField("col1", IntegerType(), True),
1111
- StructField("col2", StringType(), True),
1112
- ]
1113
- )
1114
-
1115
- assert_eq_schema(cdf, df, schema)
1116
-
1117
- # Change the column data type
1118
- schema = StructType(
1119
- [
1120
- StructField("id", StringType(), True),
1121
- StructField("name", StringType(), True),
1122
- ]
1123
- )
1124
-
1125
- assert_eq_schema(cdf, df, schema)
1126
-
1127
- # Reduce the column quantity and change data type
1128
- schema = StructType(
1129
- [
1130
- StructField("id", LongType(), True),
1131
- ]
1132
- )
1133
-
1134
- assert_eq_schema(cdf, df, schema)
1135
-
1136
- # incompatible field nullability
1137
- schema = StructType([StructField("id", LongType(), False)])
1138
- self.assertRaisesRegex(
1139
- AnalysisException,
1140
- "NULLABLE_COLUMN_OR_FIELD",
1141
- lambda: cdf.to(schema).toPandas(),
1142
- )
1143
-
1144
- # field cannot upcast
1145
- schema = StructType([StructField("name", LongType())])
1146
- self.assertRaisesRegex(
1147
- AnalysisException,
1148
- "INVALID_COLUMN_OR_FIELD_DATA_TYPE",
1149
- lambda: cdf.to(schema).toPandas(),
1150
- )
1151
-
1152
- schema = StructType(
1153
- [
1154
- StructField("id", IntegerType(), True),
1155
- StructField("name", IntegerType(), True),
1156
- ]
1157
- )
1158
- self.assertRaisesRegex(
1159
- AnalysisException,
1160
- "INVALID_COLUMN_OR_FIELD_DATA_TYPE",
1161
- lambda: cdf.to(schema).toPandas(),
1162
- )
1163
-
1164
- # Test map type and array type
1165
- schema = StructType(
1166
- [
1167
- StructField("id", StringType(), True),
1168
- StructField("my_map", MapType(StringType(), IntegerType(), False), True),
1169
- StructField("my_array", ArrayType(IntegerType(), False), True),
1170
- ]
1171
- )
1172
- cdf = self.connect.read.table(self.tbl_name4)
1173
- df = self.spark.read.table(self.tbl_name4)
1174
-
1175
- assert_eq_schema(cdf, df, schema)
1176
-
1177
- def test_toDF(self):
1178
- # SPARK-41310: test DataFrame.toDF()
1179
- self.assertEqual(
1180
- self.connect.read.table(self.tbl_name).toDF("col1", "col2").schema,
1181
- self.spark.read.table(self.tbl_name).toDF("col1", "col2").schema,
1182
- )
1183
-
1184
- def test_print_schema(self):
1185
- # SPARK-41216: Test print schema
1186
- tree_str = self.connect.sql("SELECT 1 AS X, 2 AS Y")._tree_string()
1187
- # root
1188
- # |-- X: integer (nullable = false)
1189
- # |-- Y: integer (nullable = false)
1190
- expected = "root\n |-- X: integer (nullable = false)\n |-- Y: integer (nullable = false)\n"
1191
- self.assertEqual(tree_str, expected)
1192
-
1193
- def test_is_local(self):
1194
- # SPARK-41216: Test is local
1195
- self.assertTrue(self.connect.sql("SHOW DATABASES").isLocal())
1196
- self.assertFalse(self.connect.read.table(self.tbl_name).isLocal())
1197
-
1198
- def test_is_streaming(self):
1199
- # SPARK-41216: Test is streaming
1200
- self.assertFalse(self.connect.read.table(self.tbl_name).isStreaming)
1201
- self.assertFalse(self.connect.sql("SELECT 1 AS X LIMIT 0").isStreaming)
1202
-
1203
- def test_input_files(self):
1204
- # SPARK-41216: Test input files
1205
- tmpPath = tempfile.mkdtemp()
1206
- shutil.rmtree(tmpPath)
1207
- try:
1208
- self.df_text.write.text(tmpPath)
1209
-
1210
- input_files_list1 = (
1211
- self.spark.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
1212
- )
1213
- input_files_list2 = (
1214
- self.connect.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
1215
- )
1216
-
1217
- self.assertTrue(len(input_files_list1) > 0)
1218
- self.assertEqual(len(input_files_list1), len(input_files_list2))
1219
- for file_path in input_files_list2:
1220
- self.assertTrue(file_path in input_files_list1)
1221
- finally:
1222
- shutil.rmtree(tmpPath)
1223
-
1224
- def test_limit_offset(self):
1225
- df = self.connect.read.table(self.tbl_name)
1226
- pd = df.limit(10).offset(1).toPandas()
1227
- self.assertEqual(9, len(pd.index))
1228
- pd2 = df.offset(98).limit(10).toPandas()
1229
- self.assertEqual(2, len(pd2.index))
1230
-
1231
- def test_tail(self):
1232
- df = self.connect.read.table(self.tbl_name)
1233
- df2 = self.spark.read.table(self.tbl_name)
1234
- self.assertEqual(df.tail(10), df2.tail(10))
1235
-
1236
- def test_sql(self):
1237
- pdf = self.connect.sql("SELECT 1").toPandas()
1238
- self.assertEqual(1, len(pdf.index))
1239
-
1240
- def test_sql_with_named_args(self):
1241
- df = self.connect.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
1242
- df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
1243
- self.assert_eq(df.toPandas(), df2.toPandas())
1244
-
1245
- def test_namedargs_with_global_limit(self):
1246
- sqlText = """SELECT * FROM VALUES (TIMESTAMP('2022-12-25 10:30:00'), 1) as tab(date, val)
1247
- where val = :val"""
1248
- df = self.connect.sql(sqlText, args={"val": 1})
1249
- df2 = self.spark.sql(sqlText, args={"val": 1})
1250
- self.assert_eq(df.toPandas(), df2.toPandas())
1251
-
1252
- def test_sql_with_pos_args(self):
1253
- df = self.connect.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
1254
- df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
1255
- self.assert_eq(df.toPandas(), df2.toPandas())
1256
-
1257
- def test_head(self):
1258
- # SPARK-41002: test `head` API in Python Client
1259
- df = self.connect.read.table(self.tbl_name)
1260
- self.assertIsNotNone(len(df.head()))
1261
- self.assertIsNotNone(len(df.head(1)))
1262
- self.assertIsNotNone(len(df.head(5)))
1263
- df2 = self.connect.read.table(self.tbl_name_empty)
1264
- self.assertIsNone(df2.head())
1265
-
1266
- def test_deduplicate(self):
1267
- # SPARK-41326: test distinct and dropDuplicates.
1268
- df = self.connect.read.table(self.tbl_name)
1269
- df2 = self.spark.read.table(self.tbl_name)
1270
- self.assert_eq(df.distinct().toPandas(), df2.distinct().toPandas())
1271
- self.assert_eq(df.dropDuplicates().toPandas(), df2.dropDuplicates().toPandas())
1272
- self.assert_eq(
1273
- df.dropDuplicates(["name"]).toPandas(), df2.dropDuplicates(["name"]).toPandas()
1274
- )
1275
-
1276
- def test_deduplicate_within_watermark_in_batch(self):
1277
- df = self.connect.read.table(self.tbl_name)
1278
- with self.assertRaisesRegex(
1279
- AnalysisException,
1280
- "dropDuplicatesWithinWatermark is not supported with batch DataFrames/DataSets",
1281
- ):
1282
- df.dropDuplicatesWithinWatermark().toPandas()
1283
-
1284
- def test_first(self):
1285
- # SPARK-41002: test `first` API in Python Client
1286
- df = self.connect.read.table(self.tbl_name)
1287
- self.assertIsNotNone(len(df.first()))
1288
- df2 = self.connect.read.table(self.tbl_name_empty)
1289
- self.assertIsNone(df2.first())
1290
-
1291
- def test_take(self) -> None:
1292
- # SPARK-41002: test `take` API in Python Client
1293
- df = self.connect.read.table(self.tbl_name)
1294
- self.assertEqual(5, len(df.take(5)))
1295
- df2 = self.connect.read.table(self.tbl_name_empty)
1296
- self.assertEqual(0, len(df2.take(5)))
1297
-
1298
- def test_drop(self):
1299
- # SPARK-41169: test drop
1300
- query = """
1301
- SELECT * FROM VALUES
1302
- (false, 1, NULL), (false, NULL, 2), (NULL, 3, 3)
1303
- AS tab(a, b, c)
1304
- """
1305
-
1306
- cdf = self.connect.sql(query)
1307
- sdf = self.spark.sql(query)
1308
- self.assert_eq(
1309
- cdf.drop("a").toPandas(),
1310
- sdf.drop("a").toPandas(),
1311
- )
1312
- self.assert_eq(
1313
- cdf.drop("a", "b").toPandas(),
1314
- sdf.drop("a", "b").toPandas(),
1315
- )
1316
- self.assert_eq(
1317
- cdf.drop("a", "x").toPandas(),
1318
- sdf.drop("a", "x").toPandas(),
1319
- )
1320
- self.assert_eq(
1321
- cdf.drop(cdf.a, "x").toPandas(),
1322
- sdf.drop(sdf.a, "x").toPandas(),
1323
- )
1324
-
1325
- def test_subquery_alias(self) -> None:
1326
- # SPARK-40938: test subquery alias.
1327
- plan_text = (
1328
- self.connect.read.table(self.tbl_name)
1329
- .alias("special_alias")
1330
- ._explain_string(extended=True)
1331
- )
1332
- self.assertTrue("special_alias" in plan_text)
1333
-
1334
- def test_sort(self):
1335
- # SPARK-41332: test sort
1336
- query = """
1337
- SELECT * FROM VALUES
1338
- (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1339
- AS tab(a, b, c)
1340
- """
1341
- # +-----+----+----+
1342
- # | a| b| c|
1343
- # +-----+----+----+
1344
- # |false| 1|NULL|
1345
- # |false|NULL| 2.0|
1346
- # | NULL| 3| 3.0|
1347
- # +-----+----+----+
1348
-
1349
- cdf = self.connect.sql(query)
1350
- sdf = self.spark.sql(query)
1351
- self.assert_eq(
1352
- cdf.sort("a").toPandas(),
1353
- sdf.sort("a").toPandas(),
1354
- )
1355
- self.assert_eq(
1356
- cdf.sort("c").toPandas(),
1357
- sdf.sort("c").toPandas(),
1358
- )
1359
- self.assert_eq(
1360
- cdf.sort("b").toPandas(),
1361
- sdf.sort("b").toPandas(),
1362
- )
1363
- self.assert_eq(
1364
- cdf.sort(cdf.c, "b").toPandas(),
1365
- sdf.sort(sdf.c, "b").toPandas(),
1366
- )
1367
- self.assert_eq(
1368
- cdf.sort(cdf.c.desc(), "b").toPandas(),
1369
- sdf.sort(sdf.c.desc(), "b").toPandas(),
1370
- )
1371
- self.assert_eq(
1372
- cdf.sort(cdf.c.desc(), cdf.a.asc()).toPandas(),
1373
- sdf.sort(sdf.c.desc(), sdf.a.asc()).toPandas(),
1374
- )
1375
-
1376
- def test_range(self):
1377
- self.assert_eq(
1378
- self.connect.range(start=0, end=10).toPandas(),
1379
- self.spark.range(start=0, end=10).toPandas(),
1380
- )
1381
- self.assert_eq(
1382
- self.connect.range(start=0, end=10, step=3).toPandas(),
1383
- self.spark.range(start=0, end=10, step=3).toPandas(),
1384
- )
1385
- self.assert_eq(
1386
- self.connect.range(start=0, end=10, step=3, numPartitions=2).toPandas(),
1387
- self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas(),
1388
- )
1389
- # SPARK-41301
1390
- self.assert_eq(
1391
- self.connect.range(10).toPandas(), self.connect.range(start=0, end=10).toPandas()
1392
- )
1393
-
1394
- def test_create_global_temp_view(self):
1395
- # SPARK-41127: test global temp view creation.
1396
- with self.tempView("view_1"):
1397
- self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
1398
- self.connect.sql("SELECT 2 AS X LIMIT 1").createOrReplaceGlobalTempView("view_1")
1399
- self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
1400
-
1401
- # Test that creating a view which already exists raises an error
1402
- self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
1403
- with self.assertRaises(AnalysisException):
1404
- self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
1405
-
1406
- def test_create_session_local_temp_view(self):
1407
- # SPARK-41372: test session local temp view creation.
1408
- with self.tempView("view_local_temp"):
1409
- self.connect.sql("SELECT 1 AS X").createTempView("view_local_temp")
1410
- self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 1)
1411
- self.connect.sql("SELECT 1 AS X LIMIT 0").createOrReplaceTempView("view_local_temp")
1412
- self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0)
1413
-
1414
- # Test that creating a view which already exists raises an error
1415
- with self.assertRaises(AnalysisException):
1416
- self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp")
1417
-
1418
- def test_to_pandas(self):
1419
- # SPARK-41005: Test to pandas
1420
- query = """
1421
- SELECT * FROM VALUES
1422
- (false, 1, NULL),
1423
- (false, NULL, float(2.0)),
1424
- (NULL, 3, float(3.0))
1425
- AS tab(a, b, c)
1426
- """
1427
-
1428
- self.assert_eq(
1429
- self.connect.sql(query).toPandas(),
1430
- self.spark.sql(query).toPandas(),
1431
- )
1432
-
1433
- query = """
1434
- SELECT * FROM VALUES
1435
- (1, 1, NULL),
1436
- (2, NULL, float(2.0)),
1437
- (3, 3, float(3.0))
1438
- AS tab(a, b, c)
1439
- """
1440
-
1441
- self.assert_eq(
1442
- self.connect.sql(query).toPandas(),
1443
- self.spark.sql(query).toPandas(),
1444
- )
1445
-
1446
- query = """
1447
- SELECT * FROM VALUES
1448
- (double(1.0), 1, "1"),
1449
- (NULL, NULL, NULL),
1450
- (double(2.0), 3, "3")
1451
- AS tab(a, b, c)
1452
- """
1453
-
1454
- self.assert_eq(
1455
- self.connect.sql(query).toPandas(),
1456
- self.spark.sql(query).toPandas(),
1457
- )
1458
-
1459
- query = """
1460
- SELECT * FROM VALUES
1461
- (float(1.0), double(1.0), 1, "1"),
1462
- (float(2.0), double(2.0), 2, "2"),
1463
- (float(3.0), double(3.0), 3, "3")
1464
- AS tab(a, b, c, d)
1465
- """
1466
-
1467
- self.assert_eq(
1468
- self.connect.sql(query).toPandas(),
1469
- self.spark.sql(query).toPandas(),
1470
- )
1471
-
1472
- def test_create_dataframe_from_pandas_with_ns_timestamp(self):
1473
- """Truncate the timestamps for nanoseconds."""
1474
- from datetime import datetime, timezone, timedelta
1475
- from pandas import Timestamp
1476
- import pandas as pd
1477
-
1478
- pdf = pd.DataFrame(
1479
- {
1480
- "naive": [datetime(2019, 1, 1, 0)],
1481
- "aware": [
1482
- Timestamp(
1483
- year=2019, month=1, day=1, nanosecond=500, tz=timezone(timedelta(hours=-8))
1484
- )
1485
- ],
1486
- }
1487
- )
1488
-
1489
- with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
1490
- self.assertEqual(
1491
- self.connect.createDataFrame(pdf).collect(),
1492
- self.spark.createDataFrame(pdf).collect(),
1493
- )
1494
-
1495
- with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
1496
- self.assertEqual(
1497
- self.connect.createDataFrame(pdf).collect(),
1498
- self.spark.createDataFrame(pdf).collect(),
1499
- )
1500
-
1501
- def test_select_expr(self):
1502
- # SPARK-41201: test selectExpr API.
1503
- self.assert_eq(
1504
- self.connect.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
1505
- self.spark.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
1506
- )
1507
- self.assert_eq(
1508
- self.connect.read.table(self.tbl_name)
1509
- .selectExpr(["id * 2", "cast(name as long) as name"])
1510
- .toPandas(),
1511
- self.spark.read.table(self.tbl_name)
1512
- .selectExpr(["id * 2", "cast(name as long) as name"])
1513
- .toPandas(),
1514
- )
1515
-
1516
- self.assert_eq(
1517
- self.connect.read.table(self.tbl_name)
1518
- .selectExpr("id * 2", "cast(name as long) as name")
1519
- .toPandas(),
1520
- self.spark.read.table(self.tbl_name)
1521
- .selectExpr("id * 2", "cast(name as long) as name")
1522
- .toPandas(),
1523
- )
1524
-
1525
- def test_select_star(self):
1526
- data = [Row(a=1, b=Row(c=2, d=Row(e=3)))]
1527
-
1528
- # +---+--------+
1529
- # | a| b|
1530
- # +---+--------+
1531
- # | 1|{2, {3}}|
1532
- # +---+--------+
1533
-
1534
- cdf = self.connect.createDataFrame(data=data)
1535
- sdf = self.spark.createDataFrame(data=data)
1536
-
1537
- self.assertEqual(
1538
- cdf.select("*").collect(),
1539
- sdf.select("*").collect(),
1540
- )
1541
- self.assertEqual(
1542
- cdf.select("a", "*").collect(),
1543
- sdf.select("a", "*").collect(),
1544
- )
1545
- self.assertEqual(
1546
- cdf.select("a", "b").collect(),
1547
- sdf.select("a", "b").collect(),
1548
- )
1549
- self.assertEqual(
1550
- cdf.select("a", "b.*").collect(),
1551
- sdf.select("a", "b.*").collect(),
1552
- )
1553
-
1554
- def test_fill_na(self):
1555
- # SPARK-41128: Test fill na
1556
- query = """
1557
- SELECT * FROM VALUES
1558
- (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1559
- AS tab(a, b, c)
1560
- """
1561
- # +-----+----+----+
1562
- # | a| b| c|
1563
- # +-----+----+----+
1564
- # |false| 1|NULL|
1565
- # |false|NULL| 2.0|
1566
- # | NULL| 3| 3.0|
1567
- # +-----+----+----+
1568
-
1569
- self.assert_eq(
1570
- self.connect.sql(query).fillna(True).toPandas(),
1571
- self.spark.sql(query).fillna(True).toPandas(),
1572
- )
1573
- self.assert_eq(
1574
- self.connect.sql(query).fillna(2).toPandas(),
1575
- self.spark.sql(query).fillna(2).toPandas(),
1576
- )
1577
- self.assert_eq(
1578
- self.connect.sql(query).fillna(2, ["a", "b"]).toPandas(),
1579
- self.spark.sql(query).fillna(2, ["a", "b"]).toPandas(),
1580
- )
1581
- self.assert_eq(
1582
- self.connect.sql(query).na.fill({"a": True, "b": 2}).toPandas(),
1583
- self.spark.sql(query).na.fill({"a": True, "b": 2}).toPandas(),
1584
- )
1585
-
1586
- def test_drop_na(self):
1587
- # SPARK-41148: Test drop na
1588
- query = """
1589
- SELECT * FROM VALUES
1590
- (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1591
- AS tab(a, b, c)
1592
- """
1593
- # +-----+----+----+
1594
- # | a| b| c|
1595
- # +-----+----+----+
1596
- # |false| 1|NULL|
1597
- # |false|NULL| 2.0|
1598
- # | NULL| 3| 3.0|
1599
- # +-----+----+----+
1600
-
1601
- self.assert_eq(
1602
- self.connect.sql(query).dropna().toPandas(),
1603
- self.spark.sql(query).dropna().toPandas(),
1604
- )
1605
- self.assert_eq(
1606
- self.connect.sql(query).na.drop(how="all", thresh=1).toPandas(),
1607
- self.spark.sql(query).na.drop(how="all", thresh=1).toPandas(),
1608
- )
1609
- self.assert_eq(
1610
- self.connect.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
1611
- self.spark.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
1612
- )
1613
- self.assert_eq(
1614
- self.connect.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
1615
- self.spark.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
1616
- )
1617
-
1618
- def test_replace(self):
1619
- # SPARK-41315: Test replace
1620
- query = """
1621
- SELECT * FROM VALUES
1622
- (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1623
- AS tab(a, b, c)
1624
- """
1625
- # +-----+----+----+
1626
- # | a| b| c|
1627
- # +-----+----+----+
1628
- # |false| 1|NULL|
1629
- # |false|NULL| 2.0|
1630
- # | NULL| 3| 3.0|
1631
- # +-----+----+----+
1632
-
1633
- self.assert_eq(
1634
- self.connect.sql(query).replace(2, 3).toPandas(),
1635
- self.spark.sql(query).replace(2, 3).toPandas(),
1636
- )
1637
- self.assert_eq(
1638
- self.connect.sql(query).na.replace(False, True).toPandas(),
1639
- self.spark.sql(query).na.replace(False, True).toPandas(),
1640
- )
1641
- self.assert_eq(
1642
- self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
1643
- self.spark.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
1644
- )
1645
- self.assert_eq(
1646
- self.connect.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
1647
- self.spark.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
1648
- )
1649
- self.assert_eq(
1650
- self.connect.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
1651
- self.spark.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
1652
- )
1653
-
1654
- with self.assertRaises(ValueError) as context:
1655
- self.connect.sql(query).replace({None: 1}, subset="a").toPandas()
1656
- self.assertTrue("Mixed type replacements are not supported" in str(context.exception))
1657
-
1658
- with self.assertRaises(AnalysisException) as context:
1659
- self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "x")).toPandas()
1660
- self.assertIn(
1661
- """Cannot resolve column name "x" among (a, b, c)""", str(context.exception)
1662
- )
1663
-
1664
- def test_unpivot(self):
1665
- self.assert_eq(
1666
- self.connect.read.table(self.tbl_name)
1667
- .filter("id > 3")
1668
- .unpivot(["id"], ["name"], "variable", "value")
1669
- .toPandas(),
1670
- self.spark.read.table(self.tbl_name)
1671
- .filter("id > 3")
1672
- .unpivot(["id"], ["name"], "variable", "value")
1673
- .toPandas(),
1674
- )
1675
-
1676
- self.assert_eq(
1677
- self.connect.read.table(self.tbl_name)
1678
- .filter("id > 3")
1679
- .unpivot("id", None, "variable", "value")
1680
- .toPandas(),
1681
- self.spark.read.table(self.tbl_name)
1682
- .filter("id > 3")
1683
- .unpivot("id", None, "variable", "value")
1684
- .toPandas(),
1685
- )
1686
-
1687
- def test_union_by_name(self):
1688
- # SPARK-41832: Test unionByName
1689
- data1 = [(1, 2, 3)]
1690
- data2 = [(6, 2, 5)]
1691
- df1_connect = self.connect.createDataFrame(data1, ["a", "b", "c"])
1692
- df2_connect = self.connect.createDataFrame(data2, ["a", "b", "c"])
1693
- union_df_connect = df1_connect.unionByName(df2_connect)
1694
-
1695
- df1_spark = self.spark.createDataFrame(data1, ["a", "b", "c"])
1696
- df2_spark = self.spark.createDataFrame(data2, ["a", "b", "c"])
1697
- union_df_spark = df1_spark.unionByName(df2_spark)
1698
-
1699
- self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas())
1700
-
1701
- df2_connect = self.connect.createDataFrame(data2, ["a", "B", "C"])
1702
- union_df_connect = df1_connect.unionByName(df2_connect, allowMissingColumns=True)
1703
-
1704
- df2_spark = self.spark.createDataFrame(data2, ["a", "B", "C"])
1705
- union_df_spark = df1_spark.unionByName(df2_spark, allowMissingColumns=True)
1706
-
1707
- self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas())
1708
-
1709
- def test_random_split(self):
1710
- # SPARK-41440: test randomSplit(weights, seed).
1711
- relations = (
1712
- self.connect.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2)
1713
- )
1714
- datasets = (
1715
- self.spark.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2)
1716
- )
1717
-
1718
- self.assertTrue(len(relations) == len(datasets))
1719
- i = 0
1720
- while i < len(relations):
1721
- self.assert_eq(relations[i].toPandas(), datasets[i].toPandas())
1722
- i += 1
1723
-
1724
- def test_observe(self):
1725
- # SPARK-41527: test DataFrame.observe()
1726
- observation_name = "my_metric"
1727
-
1728
- self.assert_eq(
1729
- self.connect.read.table(self.tbl_name)
1730
- .filter("id > 3")
1731
- .observe(observation_name, CF.min("id"), CF.max("id"), CF.sum("id"))
1732
- .toPandas(),
1733
- self.spark.read.table(self.tbl_name)
1734
- .filter("id > 3")
1735
- .observe(observation_name, SF.min("id"), SF.max("id"), SF.sum("id"))
1736
- .toPandas(),
1737
- )
1738
-
1739
- from pyspark.sql.observation import Observation
1740
-
1741
- observation = Observation(observation_name)
1742
-
1743
- cdf = (
1744
- self.connect.read.table(self.tbl_name)
1745
- .filter("id > 3")
1746
- .observe(observation, CF.min("id"), CF.max("id"), CF.sum("id"))
1747
- .toPandas()
1748
- )
1749
- df = (
1750
- self.spark.read.table(self.tbl_name)
1751
- .filter("id > 3")
1752
- .observe(observation, SF.min("id"), SF.max("id"), SF.sum("id"))
1753
- .toPandas()
1754
- )
1755
-
1756
- self.assert_eq(cdf, df)
1757
-
1758
- observed_metrics = cdf.attrs["observed_metrics"]
1759
- self.assert_eq(len(observed_metrics), 1)
1760
- self.assert_eq(observed_metrics[0].name, observation_name)
1761
- self.assert_eq(len(observed_metrics[0].metrics), 3)
1762
- for metric in observed_metrics[0].metrics:
1763
- self.assertIsInstance(metric, ProtoExpression.Literal)
1764
- values = list(map(lambda metric: metric.long, observed_metrics[0].metrics))
1765
- self.assert_eq(values, [4, 99, 4944])
1766
-
1767
- with self.assertRaises(PySparkValueError) as pe:
1768
- self.connect.read.table(self.tbl_name).observe(observation_name)
1769
-
1770
- self.check_error(
1771
- exception=pe.exception,
1772
- error_class="CANNOT_BE_EMPTY",
1773
- message_parameters={"item": "exprs"},
1774
- )
1775
-
1776
- with self.assertRaises(PySparkTypeError) as pe:
1777
- self.connect.read.table(self.tbl_name).observe(observation_name, CF.lit(1), "id")
1778
-
1779
- self.check_error(
1780
- exception=pe.exception,
1781
- error_class="NOT_LIST_OF_COLUMN",
1782
- message_parameters={"arg_name": "exprs"},
1783
- )
1784
-
1785
- def test_with_columns(self):
1786
- # SPARK-41256: test withColumn(s).
1787
- self.assert_eq(
1788
- self.connect.read.table(self.tbl_name).withColumn("id", CF.lit(False)).toPandas(),
1789
- self.spark.read.table(self.tbl_name).withColumn("id", SF.lit(False)).toPandas(),
1790
- )
1791
-
1792
- self.assert_eq(
1793
- self.connect.read.table(self.tbl_name)
1794
- .withColumns({"id": CF.lit(False), "col_not_exist": CF.lit(False)})
1795
- .toPandas(),
1796
- self.spark.read.table(self.tbl_name)
1797
- .withColumns(
1798
- {
1799
- "id": SF.lit(False),
1800
- "col_not_exist": SF.lit(False),
1801
- }
1802
- )
1803
- .toPandas(),
1804
- )
1805
-
1806
- def test_hint(self):
1807
- # SPARK-41349: Test hint
1808
- self.assert_eq(
1809
- self.connect.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(),
1810
- self.spark.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(),
1811
- )
1812
-
1813
- # A hint with an unsupported name will be ignored
1814
- self.assert_eq(
1815
- self.connect.read.table(self.tbl_name).hint("illegal").toPandas(),
1816
- self.spark.read.table(self.tbl_name).hint("illegal").toPandas(),
1817
- )
1818
-
1819
- # Hint with all supported parameter values
1820
- such_a_nice_list = ["itworks1", "itworks2", "itworks3"]
1821
- self.assert_eq(
1822
- self.connect.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(),
1823
- self.spark.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(),
1824
- )
1825
-
1826
- # Hint with unsupported parameter values
1827
- with self.assertRaises(AnalysisException):
1828
- self.connect.read.table(self.tbl_name).hint("REPARTITION", "id+1").toPandas()
1829
-
1830
- # Hint with unsupported parameter types
1831
- with self.assertRaises(TypeError):
1832
- self.connect.read.table(self.tbl_name).hint("REPARTITION", range(5)).toPandas()
1833
-
1834
- # Hint with unsupported parameter types
1835
- with self.assertRaises(TypeError):
1836
- self.connect.read.table(self.tbl_name).hint(
1837
- "my awesome hint", 1.2345, 2, such_a_nice_list, range(6)
1838
- ).toPandas()
1839
-
1840
- # Hint with wrong combination
1841
- with self.assertRaises(AnalysisException):
1842
- self.connect.read.table(self.tbl_name).hint("REPARTITION", "id", 3).toPandas()
1843
-
1844
- def test_join_hint(self):
1845
- cdf1 = self.connect.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
1846
- cdf2 = self.connect.createDataFrame(
1847
- [Row(height=80, name="Tom"), Row(height=85, name="Bob")]
1848
- )
1849
-
1850
- self.assertTrue(
1851
- "BroadcastHashJoin" in cdf1.join(cdf2.hint("BROADCAST"), "name")._explain_string()
1852
- )
1853
- self.assertTrue("SortMergeJoin" in cdf1.join(cdf2.hint("MERGE"), "name")._explain_string())
1854
- self.assertTrue(
1855
- "ShuffledHashJoin" in cdf1.join(cdf2.hint("SHUFFLE_HASH"), "name")._explain_string()
1856
- )
1857
-
1858
- def test_different_spark_session_join_or_union(self):
1859
- df = self.connect.range(10).limit(3)
1860
-
1861
- spark2 = RemoteSparkSession(connection="sc://localhost")
1862
- df2 = spark2.range(10).limit(3)
1863
-
1864
- with self.assertRaises(SessionNotSameException) as e1:
1865
- df.union(df2).collect()
1866
- self.check_error(
1867
- exception=e1.exception,
1868
- error_class="SESSION_NOT_SAME",
1869
- message_parameters={},
1870
- )
1871
-
1872
- with self.assertRaises(SessionNotSameException) as e2:
1873
- df.unionByName(df2).collect()
1874
- self.check_error(
1875
- exception=e2.exception,
1876
- error_class="SESSION_NOT_SAME",
1877
- message_parameters={},
1878
- )
1879
-
1880
- with self.assertRaises(SessionNotSameException) as e3:
1881
- df.join(df2).collect()
1882
- self.check_error(
1883
- exception=e3.exception,
1884
- error_class="SESSION_NOT_SAME",
1885
- message_parameters={},
1886
- )
1887
-
1888
- def test_extended_hint_types(self):
1889
- cdf = self.connect.range(100).toDF("id")
1890
-
1891
- cdf.hint(
1892
- "my awesome hint",
1893
- 1.2345,
1894
- "what",
1895
- ["itworks1", "itworks2", "itworks3"],
1896
- ).show()
1897
-
1898
- with self.assertRaises(PySparkTypeError) as pe:
1899
- cdf.hint(
1900
- "my awesome hint",
1901
- 1.2345,
1902
- "what",
1903
- {"itworks1": "itworks2"},
1904
- ).show()
1905
-
1906
- self.check_error(
1907
- exception=pe.exception,
1908
- error_class="INVALID_ITEM_FOR_CONTAINER",
1909
- message_parameters={
1910
- "arg_name": "parameters",
1911
- "allowed_types": "str, list, float, int",
1912
- "item_type": "dict",
1913
- },
1914
- )
1915
-
1916
- def test_empty_dataset(self):
1917
- # SPARK-41005: Test arrow based collection with empty dataset.
1918
- self.assertTrue(
1919
- self.connect.sql("SELECT 1 AS X LIMIT 0")
1920
- .toPandas()
1921
- .equals(self.spark.sql("SELECT 1 AS X LIMIT 0").toPandas())
1922
- )
1923
- pdf = self.connect.sql("SELECT 1 AS X LIMIT 0").toPandas()
1924
- self.assertEqual(0, len(pdf)) # empty dataset
1925
- self.assertEqual(1, len(pdf.columns)) # one column
1926
- self.assertEqual("X", pdf.columns[0])
1927
-
1928
- def test_is_empty(self):
1929
- # SPARK-41212: Test is empty
1930
- self.assertFalse(self.connect.sql("SELECT 1 AS X").isEmpty())
1931
- self.assertTrue(self.connect.sql("SELECT 1 AS X LIMIT 0").isEmpty())
1932
-
1933
- def test_session(self):
1934
- self.assertEqual(self.connect, self.connect.sql("SELECT 1").sparkSession)
1935
-
1936
- def test_show(self):
1937
- # SPARK-41111: Test the show method
1938
- show_str = self.connect.sql("SELECT 1 AS X, 2 AS Y")._show_string()
1939
- # +---+---+
1940
- # | X| Y|
1941
- # +---+---+
1942
- # | 1| 2|
1943
- # +---+---+
1944
- expected = "+---+---+\n| X| Y|\n+---+---+\n| 1| 2|\n+---+---+\n"
1945
- self.assertEqual(show_str, expected)
1946
-
1947
- def test_describe(self):
1948
- # SPARK-41403: Test the describe method
1949
- self.assert_eq(
1950
- self.connect.read.table(self.tbl_name).describe("id").toPandas(),
1951
- self.spark.read.table(self.tbl_name).describe("id").toPandas(),
1952
- )
1953
- self.assert_eq(
1954
- self.connect.read.table(self.tbl_name).describe("id", "name").toPandas(),
1955
- self.spark.read.table(self.tbl_name).describe("id", "name").toPandas(),
1956
- )
1957
- self.assert_eq(
1958
- self.connect.read.table(self.tbl_name).describe(["id", "name"]).toPandas(),
1959
- self.spark.read.table(self.tbl_name).describe(["id", "name"]).toPandas(),
1960
- )
1961
-
1962
- def test_stat_cov(self):
1963
- # SPARK-41067: Test the stat.cov method
1964
- self.assertEqual(
1965
- self.connect.read.table(self.tbl_name2).stat.cov("col1", "col3"),
1966
- self.spark.read.table(self.tbl_name2).stat.cov("col1", "col3"),
1967
- )
1968
-
1969
- def test_stat_corr(self):
1970
- # SPARK-41068: Test the stat.corr method
1971
- self.assertEqual(
1972
- self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3"),
1973
- self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3"),
1974
- )
1975
-
1976
- self.assertEqual(
1977
- self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"),
1978
- self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"),
1979
- )
1980
-
1981
- with self.assertRaises(PySparkTypeError) as pe:
1982
- self.connect.read.table(self.tbl_name2).stat.corr(1, "col3", "pearson")
1983
-
1984
- self.check_error(
1985
- exception=pe.exception,
1986
- error_class="NOT_STR",
1987
- message_parameters={
1988
- "arg_name": "col1",
1989
- "arg_type": "int",
1990
- },
1991
- )
1992
-
1993
- with self.assertRaises(PySparkTypeError) as pe:
1994
- self.connect.read.table(self.tbl_name).stat.corr("col1", 1, "pearson")
1995
-
1996
- self.check_error(
1997
- exception=pe.exception,
1998
- error_class="NOT_STR",
1999
- message_parameters={
2000
- "arg_name": "col2",
2001
- "arg_type": "int",
2002
- },
2003
- )
2004
- with self.assertRaises(ValueError) as context:
2005
- self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "spearman"),
2006
- self.assertTrue(
2007
- "Currently only the calculation of the Pearson Correlation "
2008
- + "coefficient is supported."
2009
- in str(context.exception)
2010
- )
2011
-
2012
- def test_stat_approx_quantile(self):
2013
- # SPARK-41069: Test the stat.approxQuantile method
2014
- result = self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2015
- ["col1", "col3"], [0.1, 0.5, 0.9], 0.1
2016
- )
2017
- self.assertEqual(len(result), 2)
2018
- self.assertEqual(len(result[0]), 3)
2019
- self.assertEqual(len(result[1]), 3)
2020
-
2021
- result = self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2022
- ["col1"], [0.1, 0.5, 0.9], 0.1
2023
- )
2024
- self.assertEqual(len(result), 1)
2025
- self.assertEqual(len(result[0]), 3)
2026
-
2027
- with self.assertRaises(PySparkTypeError) as pe:
2028
- self.connect.read.table(self.tbl_name2).stat.approxQuantile(1, [0.1, 0.5, 0.9], 0.1)
2029
-
2030
- self.check_error(
2031
- exception=pe.exception,
2032
- error_class="NOT_LIST_OR_STR_OR_TUPLE",
2033
- message_parameters={
2034
- "arg_name": "col",
2035
- "arg_type": "int",
2036
- },
2037
- )
2038
-
2039
- with self.assertRaises(PySparkTypeError) as pe:
2040
- self.connect.read.table(self.tbl_name2).stat.approxQuantile(["col1", "col3"], 0.1, 0.1)
2041
-
2042
- self.check_error(
2043
- exception=pe.exception,
2044
- error_class="NOT_LIST_OR_TUPLE",
2045
- message_parameters={
2046
- "arg_name": "probabilities",
2047
- "arg_type": "float",
2048
- },
2049
- )
2050
- with self.assertRaises(PySparkTypeError) as pe:
2051
- self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2052
- ["col1", "col3"], [-0.1], 0.1
2053
- )
2054
-
2055
- self.check_error(
2056
- exception=pe.exception,
2057
- error_class="NOT_LIST_OF_FLOAT_OR_INT",
2058
- message_parameters={"arg_name": "probabilities", "arg_type": "float"},
2059
- )
2060
- with self.assertRaises(PySparkTypeError) as pe:
2061
- self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2062
- ["col1", "col3"], [0.1, 0.5, 0.9], "str"
2063
- )
2064
-
2065
- self.check_error(
2066
- exception=pe.exception,
2067
- error_class="NOT_FLOAT_OR_INT",
2068
- message_parameters={
2069
- "arg_name": "relativeError",
2070
- "arg_type": "str",
2071
- },
2072
- )
2073
- with self.assertRaises(PySparkValueError) as pe:
2074
- self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2075
- ["col1", "col3"], [0.1, 0.5, 0.9], -0.1
2076
- )
2077
-
2078
- self.check_error(
2079
- exception=pe.exception,
2080
- error_class="NEGATIVE_VALUE",
2081
- message_parameters={
2082
- "arg_name": "relativeError",
2083
- "arg_value": "-0.1",
2084
- },
2085
- )
2086
-
2087
- def test_stat_freq_items(self):
2088
- # SPARK-41065: Test the stat.freqItems method
2089
- self.assert_eq(
2090
- self.connect.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
2091
- self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
2092
- )
2093
-
2094
- self.assert_eq(
2095
- self.connect.read.table(self.tbl_name2)
2096
- .stat.freqItems(["col1", "col3"], 0.4)
2097
- .toPandas(),
2098
- self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"], 0.4).toPandas(),
2099
- )
2100
-
2101
- with self.assertRaises(PySparkTypeError) as pe:
2102
- self.connect.read.table(self.tbl_name2).stat.freqItems("col1")
2103
-
2104
- self.check_error(
2105
- exception=pe.exception,
2106
- error_class="NOT_LIST_OR_TUPLE",
2107
- message_parameters={
2108
- "arg_name": "cols",
2109
- "arg_type": "str",
2110
- },
2111
- )
2112
-
2113
- def test_stat_sample_by(self):
2114
- # SPARK-41069: Test stat.sample_by
2115
-
2116
- cdf = self.connect.range(0, 100).select((CF.col("id") % 3).alias("key"))
2117
- sdf = self.spark.range(0, 100).select((SF.col("id") % 3).alias("key"))
2118
-
2119
- self.assert_eq(
2120
- cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
2121
- .groupBy("key")
2122
- .agg(CF.count(CF.lit(1)))
2123
- .orderBy("key")
2124
- .toPandas(),
2125
- sdf.sampleBy(sdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
2126
- .groupBy("key")
2127
- .agg(SF.count(SF.lit(1)))
2128
- .orderBy("key")
2129
- .toPandas(),
2130
- )
2131
-
2132
- with self.assertRaises(PySparkTypeError) as pe:
2133
- cdf.stat.sampleBy(cdf.key, fractions={0: 0.1, None: 0.2}, seed=0)
2134
-
2135
- self.check_error(
2136
- exception=pe.exception,
2137
- error_class="DISALLOWED_TYPE_FOR_CONTAINER",
2138
- message_parameters={
2139
- "arg_name": "fractions",
2140
- "arg_type": "dict",
2141
- "allowed_types": "float, int, str",
2142
- "return_type": "NoneType",
2143
- },
2144
- )
2145
-
2146
- with self.assertRaises(SparkConnectException):
2147
- cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 1.2}, seed=0).show()
2148
-
2149
- def test_repr(self):
2150
- # SPARK-41213: Test the __repr__ method
2151
- query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""
2152
- self.assertEqual(
2153
- self.connect.sql(query).__repr__(),
2154
- self.spark.sql(query).__repr__(),
2155
- )
2156
-
2157
- def test_explain_string(self):
2158
- # SPARK-41122: test explain API.
2159
- plan_str = self.connect.sql("SELECT 1")._explain_string(extended=True)
2160
- self.assertTrue("Parsed Logical Plan" in plan_str)
2161
- self.assertTrue("Analyzed Logical Plan" in plan_str)
2162
- self.assertTrue("Optimized Logical Plan" in plan_str)
2163
- self.assertTrue("Physical Plan" in plan_str)
2164
-
2165
- with self.assertRaises(PySparkValueError) as pe:
2166
- self.connect.sql("SELECT 1")._explain_string(mode="unknown")
2167
- self.check_error(
2168
- exception=pe.exception,
2169
- error_class="UNKNOWN_EXPLAIN_MODE",
2170
- message_parameters={"explain_mode": "unknown"},
2171
- )
2172
-
2173
- def test_simple_datasource_read(self) -> None:
2174
- writeDf = self.df_text
2175
- tmpPath = tempfile.mkdtemp()
2176
- shutil.rmtree(tmpPath)
2177
- writeDf.write.text(tmpPath)
2178
-
2179
- for schema in [
2180
- "id STRING",
2181
- StructType([StructField("id", StringType())]),
2182
- ]:
2183
- readDf = self.connect.read.format("text").schema(schema).load(path=tmpPath)
2184
- expectResult = writeDf.collect()
2185
- pandasResult = readDf.toPandas()
2186
- if pandasResult is None:
2187
- self.assertTrue(False, "Empty pandas dataframe")
2188
- else:
2189
- actualResult = pandasResult.values.tolist()
2190
- self.assertEqual(len(expectResult), len(actualResult))
2191
-
2192
- def test_simple_read_without_schema(self) -> None:
2193
- """SPARK-41300: Schema not set when reading CSV."""
2194
- writeDf = self.df_text
2195
- tmpPath = tempfile.mkdtemp()
2196
- shutil.rmtree(tmpPath)
2197
- writeDf.write.csv(tmpPath, header=True)
2198
-
2199
- readDf = self.connect.read.format("csv").option("header", True).load(path=tmpPath)
2200
- expectResult = set(writeDf.collect())
2201
- pandasResult = set(readDf.collect())
2202
- self.assertEqual(expectResult, pandasResult)
2203
-
2204
- def test_count(self) -> None:
2205
- # SPARK-41308: test count() API.
2206
- self.assertEqual(
2207
- self.connect.read.table(self.tbl_name).count(),
2208
- self.spark.read.table(self.tbl_name).count(),
2209
- )
2210
-
2211
- def test_simple_transform(self) -> None:
2212
- """SPARK-41203: Support DF.transform"""
2213
-
2214
- def transform_df(input_df: CDataFrame) -> CDataFrame:
2215
- return input_df.select((CF.col("id") + CF.lit(10)).alias("id"))
2216
-
2217
- df = self.connect.range(1, 100)
2218
- result_left = df.transform(transform_df).collect()
2219
- result_right = self.connect.range(11, 110).collect()
2220
- self.assertEqual(result_right, result_left)
2221
-
2222
- # Check assertion.
2223
- with self.assertRaises(AssertionError):
2224
- df.transform(lambda x: 2) # type: ignore
2225
-
2226
- def test_alias(self) -> None:
2227
- """Testing supported and unsupported alias"""
2228
- col0 = (
2229
- self.connect.range(1, 10)
2230
- .select(CF.col("id").alias("name", metadata={"max": 99}))
2231
- .schema.names[0]
2232
- )
2233
- self.assertEqual("name", col0)
2234
-
2235
- with self.assertRaises(SparkConnectException) as exc:
2236
- self.connect.range(1, 10).select(CF.col("id").alias("this", "is", "not")).collect()
2237
- self.assertIn("(this, is, not)", str(exc.exception))
2238
-
2239
- def test_column_regexp(self) -> None:
2240
- # SPARK-41438: test dataframe.colRegex()
2241
- ndf = self.connect.read.table(self.tbl_name3)
2242
- df = self.spark.read.table(self.tbl_name3)
2243
-
2244
- self.assert_eq(
2245
- ndf.select(ndf.colRegex("`tes.*\n.*mn`")).toPandas(),
2246
- df.select(df.colRegex("`tes.*\n.*mn`")).toPandas(),
2247
- )
2248
-
2249
- def test_repartition(self) -> None:
2250
- # SPARK-41354: test dataframe.repartition(numPartitions)
2251
- self.assert_eq(
2252
- self.connect.read.table(self.tbl_name).repartition(10).toPandas(),
2253
- self.spark.read.table(self.tbl_name).repartition(10).toPandas(),
2254
- )
2255
-
2256
- self.assert_eq(
2257
- self.connect.read.table(self.tbl_name).coalesce(10).toPandas(),
2258
- self.spark.read.table(self.tbl_name).coalesce(10).toPandas(),
2259
- )
2260
-
2261
- def test_repartition_by_expression(self) -> None:
2262
- # SPARK-41354: test dataframe.repartition(expressions)
2263
- self.assert_eq(
2264
- self.connect.read.table(self.tbl_name).repartition(10, "id").toPandas(),
2265
- self.spark.read.table(self.tbl_name).repartition(10, "id").toPandas(),
2266
- )
2267
-
2268
- self.assert_eq(
2269
- self.connect.read.table(self.tbl_name).repartition("id").toPandas(),
2270
- self.spark.read.table(self.tbl_name).repartition("id").toPandas(),
2271
- )
2272
-
2273
- # repartition with unsupported parameter values
2274
- with self.assertRaises(AnalysisException):
2275
- self.connect.read.table(self.tbl_name).repartition("id+1").toPandas()
2276
-
2277
- def test_repartition_by_range(self) -> None:
2278
- # SPARK-41354: test dataframe.repartitionByRange(expressions)
2279
- cdf = self.connect.read.table(self.tbl_name)
2280
- sdf = self.spark.read.table(self.tbl_name)
2281
-
2282
- self.assert_eq(
2283
- cdf.repartitionByRange(10, "id").toPandas(),
2284
- sdf.repartitionByRange(10, "id").toPandas(),
2285
- )
2286
-
2287
- self.assert_eq(
2288
- cdf.repartitionByRange("id").toPandas(),
2289
- sdf.repartitionByRange("id").toPandas(),
2290
- )
2291
-
2292
- self.assert_eq(
2293
- cdf.repartitionByRange(cdf.id.desc()).toPandas(),
2294
- sdf.repartitionByRange(sdf.id.desc()).toPandas(),
2295
- )
2296
-
2297
- # repartitionByRange with unsupported parameter values
2298
- with self.assertRaises(AnalysisException):
2299
- self.connect.read.table(self.tbl_name).repartitionByRange("id+1").toPandas()
2300
-
2301
- def test_agg_with_two_agg_exprs(self) -> None:
2302
- # SPARK-41230: test dataframe.agg()
2303
- self.assert_eq(
2304
- self.connect.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
2305
- self.spark.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
2306
- )
2307
-
2308
- def test_subtract(self):
2309
- # SPARK-41453: test dataframe.subtract()
2310
- ndf1 = self.connect.read.table(self.tbl_name)
2311
- ndf2 = ndf1.filter("id > 3")
2312
- df1 = self.spark.read.table(self.tbl_name)
2313
- df2 = df1.filter("id > 3")
2314
-
2315
- self.assert_eq(
2316
- ndf1.subtract(ndf2).toPandas(),
2317
- df1.subtract(df2).toPandas(),
2318
- )
2319
-
2320
- def test_write_operations(self):
2321
- with tempfile.TemporaryDirectory() as d:
2322
- df = self.connect.range(50)
2323
- df.write.mode("overwrite").format("csv").save(d)
2324
-
2325
- ndf = self.connect.read.schema("id int").load(d, format="csv")
2326
- self.assertEqual(50, len(ndf.collect()))
2327
- cd = ndf.collect()
2328
- self.assertEqual(set(df.collect()), set(cd))
2329
-
2330
- with tempfile.TemporaryDirectory() as d:
2331
- df = self.connect.range(50)
2332
- df.write.mode("overwrite").csv(d, lineSep="|")
2333
-
2334
- ndf = self.connect.read.schema("id int").load(d, format="csv", lineSep="|")
2335
- self.assertEqual(set(df.collect()), set(ndf.collect()))
2336
-
2337
- df = self.connect.range(50)
2338
- df.write.format("parquet").saveAsTable("parquet_test")
2339
-
2340
- ndf = self.connect.read.table("parquet_test")
2341
- self.assertEqual(set(df.collect()), set(ndf.collect()))
2342
-
2343
- def test_writeTo_operations(self):
2344
- # SPARK-42002: Implement DataFrameWriterV2
2345
- import datetime
2346
- from pyspark.sql.connect.functions import col, years, months, days, hours, bucket
2347
-
2348
- df = self.connect.createDataFrame(
2349
- [(1, datetime.datetime(2000, 1, 1), "foo")], ("id", "ts", "value")
2350
- )
2351
- writer = df.writeTo("table1")
2352
- self.assertIsInstance(writer.option("property", "value"), DataFrameWriterV2)
2353
- self.assertIsInstance(writer.options(property="value"), DataFrameWriterV2)
2354
- self.assertIsInstance(writer.using("source"), DataFrameWriterV2)
2355
- self.assertIsInstance(writer.partitionedBy(col("id")), DataFrameWriterV2)
2356
- self.assertIsInstance(writer.tableProperty("foo", "bar"), DataFrameWriterV2)
2357
- self.assertIsInstance(writer.partitionedBy(years("ts")), DataFrameWriterV2)
2358
- self.assertIsInstance(writer.partitionedBy(months("ts")), DataFrameWriterV2)
2359
- self.assertIsInstance(writer.partitionedBy(days("ts")), DataFrameWriterV2)
2360
- self.assertIsInstance(writer.partitionedBy(hours("ts")), DataFrameWriterV2)
2361
- self.assertIsInstance(writer.partitionedBy(bucket(11, "id")), DataFrameWriterV2)
2362
- self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours("ts")), DataFrameWriterV2)
2363
-
2364
- def test_agg_with_avg(self):
2365
- # SPARK-41325: groupby.avg()
2366
- df = (
2367
- self.connect.range(10)
2368
- .groupBy((CF.col("id") % CF.lit(2)).alias("moded"))
2369
- .avg("id")
2370
- .sort("moded")
2371
- )
2372
- res = df.collect()
2373
- self.assertEqual(2, len(res))
2374
- self.assertEqual(4.0, res[0][1])
2375
- self.assertEqual(5.0, res[1][1])
2376
-
2377
- # Additional GroupBy tests with 3 rows
2378
-
2379
- df_a = self.connect.range(10).groupBy((CF.col("id") % CF.lit(3)).alias("moded"))
2380
- df_b = self.spark.range(10).groupBy((SF.col("id") % SF.lit(3)).alias("moded"))
2381
- self.assertEqual(
2382
- set(df_b.agg(SF.sum("id")).collect()), set(df_a.agg(CF.sum("id")).collect())
2383
- )
2384
-
2385
- # Dict agg
2386
- measures = {"id": "sum"}
2387
- self.assertEqual(
2388
- set(df_a.agg(measures).select("sum(id)").collect()),
2389
- set(df_b.agg(measures).select("sum(id)").collect()),
2390
- )
2391
-
2392
- def test_column_cannot_be_constructed_from_string(self):
2393
- with self.assertRaises(TypeError):
2394
- Column("col")
2395
-
2396
- def test_crossjoin(self):
2397
- # SPARK-41227: Test CrossJoin
2398
- connect_df = self.connect.read.table(self.tbl_name)
2399
- spark_df = self.spark.read.table(self.tbl_name)
2400
- self.assert_eq(
2401
- set(
2402
- connect_df.select("id")
2403
- .join(other=connect_df.select("name"), how="cross")
2404
- .toPandas()
2405
- ),
2406
- set(spark_df.select("id").join(other=spark_df.select("name"), how="cross").toPandas()),
2407
- )
2408
- self.assert_eq(
2409
- set(connect_df.select("id").crossJoin(other=connect_df.select("name")).toPandas()),
2410
- set(spark_df.select("id").crossJoin(other=spark_df.select("name")).toPandas()),
2411
- )
2412
-
2413
- def test_grouped_data(self):
2414
- query = """
2415
- SELECT * FROM VALUES
2416
- ('James', 'Sales', 3000, 2020),
2417
- ('Michael', 'Sales', 4600, 2020),
2418
- ('Robert', 'Sales', 4100, 2020),
2419
- ('Maria', 'Finance', 3000, 2020),
2420
- ('James', 'Sales', 3000, 2019),
2421
- ('Scott', 'Finance', 3300, 2020),
2422
- ('Jen', 'Finance', 3900, 2020),
2423
- ('Jeff', 'Marketing', 3000, 2020),
2424
- ('Kumar', 'Marketing', 2000, 2020),
2425
- ('Saif', 'Sales', 4100, 2020)
2426
- AS T(name, department, salary, year)
2427
- """
2428
-
2429
- # +-------+----------+------+----+
2430
- # | name|department|salary|year|
2431
- # +-------+----------+------+----+
2432
- # | James| Sales| 3000|2020|
2433
- # |Michael| Sales| 4600|2020|
2434
- # | Robert| Sales| 4100|2020|
2435
- # | Maria| Finance| 3000|2020|
2436
- # | James| Sales| 3000|2019|
2437
- # | Scott| Finance| 3300|2020|
2438
- # | Jen| Finance| 3900|2020|
2439
- # | Jeff| Marketing| 3000|2020|
2440
- # | Kumar| Marketing| 2000|2020|
2441
- # | Saif| Sales| 4100|2020|
2442
- # +-------+----------+------+----+
2443
-
2444
- cdf = self.connect.sql(query)
2445
- sdf = self.spark.sql(query)
2446
-
2447
- # test groupby
2448
- self.assert_eq(
2449
- cdf.groupBy("name").agg(CF.sum(cdf.salary)).toPandas(),
2450
- sdf.groupBy("name").agg(SF.sum(sdf.salary)).toPandas(),
2451
- )
2452
- self.assert_eq(
2453
- cdf.groupBy("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
2454
- sdf.groupBy("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
2455
- )
2456
-
2457
- # test rollup
2458
- self.assert_eq(
2459
- cdf.rollup("name").agg(CF.sum(cdf.salary)).toPandas(),
2460
- sdf.rollup("name").agg(SF.sum(sdf.salary)).toPandas(),
2461
- )
2462
- self.assert_eq(
2463
- cdf.rollup("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
2464
- sdf.rollup("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
2465
- )
2466
-
2467
- # test cube
2468
- self.assert_eq(
2469
- cdf.cube("name").agg(CF.sum(cdf.salary)).toPandas(),
2470
- sdf.cube("name").agg(SF.sum(sdf.salary)).toPandas(),
2471
- )
2472
- self.assert_eq(
2473
- cdf.cube("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
2474
- sdf.cube("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
2475
- )
2476
-
2477
- # test pivot
2478
- # pivot with values
2479
- self.assert_eq(
2480
- cdf.groupBy("name")
2481
- .pivot("department", ["Sales", "Marketing"])
2482
- .agg(CF.sum(cdf.salary))
2483
- .toPandas(),
2484
- sdf.groupBy("name")
2485
- .pivot("department", ["Sales", "Marketing"])
2486
- .agg(SF.sum(sdf.salary))
2487
- .toPandas(),
2488
- )
2489
- self.assert_eq(
2490
- cdf.groupBy(cdf.name)
2491
- .pivot("department", ["Sales", "Finance", "Marketing"])
2492
- .agg(CF.sum(cdf.salary))
2493
- .toPandas(),
2494
- sdf.groupBy(sdf.name)
2495
- .pivot("department", ["Sales", "Finance", "Marketing"])
2496
- .agg(SF.sum(sdf.salary))
2497
- .toPandas(),
2498
- )
2499
- self.assert_eq(
2500
- cdf.groupBy(cdf.name)
2501
- .pivot("department", ["Sales", "Finance", "Unknown"])
2502
- .agg(CF.sum(cdf.salary))
2503
- .toPandas(),
2504
- sdf.groupBy(sdf.name)
2505
- .pivot("department", ["Sales", "Finance", "Unknown"])
2506
- .agg(SF.sum(sdf.salary))
2507
- .toPandas(),
2508
- )
2509
-
2510
- # pivot without values
2511
- self.assert_eq(
2512
- cdf.groupBy("name").pivot("department").agg(CF.sum(cdf.salary)).toPandas(),
2513
- sdf.groupBy("name").pivot("department").agg(SF.sum(sdf.salary)).toPandas(),
2514
- )
2515
-
2516
- self.assert_eq(
2517
- cdf.groupBy("name").pivot("year").agg(CF.sum(cdf.salary)).toPandas(),
2518
- sdf.groupBy("name").pivot("year").agg(SF.sum(sdf.salary)).toPandas(),
2519
- )
2520
-
2521
- # check error
2522
- with self.assertRaisesRegex(
2523
- Exception,
2524
- "PIVOT after ROLLUP is not supported",
2525
- ):
2526
- cdf.rollup("name").pivot("department").agg(CF.sum(cdf.salary))
2527
-
2528
- with self.assertRaisesRegex(
2529
- Exception,
2530
- "PIVOT after CUBE is not supported",
2531
- ):
2532
- cdf.cube("name").pivot("department").agg(CF.sum(cdf.salary))
2533
-
2534
- with self.assertRaisesRegex(
2535
- Exception,
2536
- "Repeated PIVOT operation is not supported",
2537
- ):
2538
- cdf.groupBy("name").pivot("year").pivot("year").agg(CF.sum(cdf.salary))
2539
-
2540
- with self.assertRaises(PySparkTypeError) as pe:
2541
- cdf.groupBy("name").pivot("department", ["Sales", b"Marketing"]).agg(CF.sum(cdf.salary))
2542
-
2543
- self.check_error(
2544
- exception=pe.exception,
2545
- error_class="NOT_BOOL_OR_FLOAT_OR_INT_OR_STR",
2546
- message_parameters={
2547
- "arg_name": "value",
2548
- "arg_type": "bytes",
2549
- },
2550
- )
2551
-
2552
- def test_numeric_aggregation(self):
2553
- # SPARK-41737: test numeric aggregation
2554
- query = """
2555
- SELECT * FROM VALUES
2556
- ('James', 'Sales', 3000, 2020),
2557
- ('Michael', 'Sales', 4600, 2020),
2558
- ('Robert', 'Sales', 4100, 2020),
2559
- ('Maria', 'Finance', 3000, 2020),
2560
- ('James', 'Sales', 3000, 2019),
2561
- ('Scott', 'Finance', 3300, 2020),
2562
- ('Jen', 'Finance', 3900, 2020),
2563
- ('Jeff', 'Marketing', 3000, 2020),
2564
- ('Kumar', 'Marketing', 2000, 2020),
2565
- ('Saif', 'Sales', 4100, 2020)
2566
- AS T(name, department, salary, year)
2567
- """
2568
-
2569
- # +-------+----------+------+----+
2570
- # | name|department|salary|year|
2571
- # +-------+----------+------+----+
2572
- # | James| Sales| 3000|2020|
2573
- # |Michael| Sales| 4600|2020|
2574
- # | Robert| Sales| 4100|2020|
2575
- # | Maria| Finance| 3000|2020|
2576
- # | James| Sales| 3000|2019|
2577
- # | Scott| Finance| 3300|2020|
2578
- # | Jen| Finance| 3900|2020|
2579
- # | Jeff| Marketing| 3000|2020|
2580
- # | Kumar| Marketing| 2000|2020|
2581
- # | Saif| Sales| 4100|2020|
2582
- # +-------+----------+------+----+
2583
-
2584
- cdf = self.connect.sql(query)
2585
- sdf = self.spark.sql(query)
2586
-
2587
- # test groupby
2588
- self.assert_eq(
2589
- cdf.groupBy("name").min().toPandas(),
2590
- sdf.groupBy("name").min().toPandas(),
2591
- )
2592
- self.assert_eq(
2593
- cdf.groupBy("name").min("salary").toPandas(),
2594
- sdf.groupBy("name").min("salary").toPandas(),
2595
- )
2596
- self.assert_eq(
2597
- cdf.groupBy("name").max("salary").toPandas(),
2598
- sdf.groupBy("name").max("salary").toPandas(),
2599
- )
2600
- self.assert_eq(
2601
- cdf.groupBy("name", cdf.department).avg("salary", "year").toPandas(),
2602
- sdf.groupBy("name", sdf.department).avg("salary", "year").toPandas(),
2603
- )
2604
- self.assert_eq(
2605
- cdf.groupBy("name", cdf.department).mean("salary", "year").toPandas(),
2606
- sdf.groupBy("name", sdf.department).mean("salary", "year").toPandas(),
2607
- )
2608
- self.assert_eq(
2609
- cdf.groupBy("name", cdf.department).sum("salary", "year").toPandas(),
2610
- sdf.groupBy("name", sdf.department).sum("salary", "year").toPandas(),
2611
- )
2612
-
2613
- # test rollup
2614
- self.assert_eq(
2615
- cdf.rollup("name").max().toPandas(),
2616
- sdf.rollup("name").max().toPandas(),
2617
- )
2618
- self.assert_eq(
2619
- cdf.rollup("name").min("salary").toPandas(),
2620
- sdf.rollup("name").min("salary").toPandas(),
2621
- )
2622
- self.assert_eq(
2623
- cdf.rollup("name").max("salary").toPandas(),
2624
- sdf.rollup("name").max("salary").toPandas(),
2625
- )
2626
- self.assert_eq(
2627
- cdf.rollup("name", cdf.department).avg("salary", "year").toPandas(),
2628
- sdf.rollup("name", sdf.department).avg("salary", "year").toPandas(),
2629
- )
2630
- self.assert_eq(
2631
- cdf.rollup("name", cdf.department).mean("salary", "year").toPandas(),
2632
- sdf.rollup("name", sdf.department).mean("salary", "year").toPandas(),
2633
- )
2634
- self.assert_eq(
2635
- cdf.rollup("name", cdf.department).sum("salary", "year").toPandas(),
2636
- sdf.rollup("name", sdf.department).sum("salary", "year").toPandas(),
2637
- )
2638
-
2639
- # test cube
2640
- self.assert_eq(
2641
- cdf.cube("name").avg().toPandas(),
2642
- sdf.cube("name").avg().toPandas(),
2643
- )
2644
- self.assert_eq(
2645
- cdf.cube("name").mean().toPandas(),
2646
- sdf.cube("name").mean().toPandas(),
2647
- )
2648
- self.assert_eq(
2649
- cdf.cube("name").min("salary").toPandas(),
2650
- sdf.cube("name").min("salary").toPandas(),
2651
- )
2652
- self.assert_eq(
2653
- cdf.cube("name").max("salary").toPandas(),
2654
- sdf.cube("name").max("salary").toPandas(),
2655
- )
2656
- self.assert_eq(
2657
- cdf.cube("name", cdf.department).avg("salary", "year").toPandas(),
2658
- sdf.cube("name", sdf.department).avg("salary", "year").toPandas(),
2659
- )
2660
- self.assert_eq(
2661
- cdf.cube("name", cdf.department).sum("salary", "year").toPandas(),
2662
- sdf.cube("name", sdf.department).sum("salary", "year").toPandas(),
2663
- )
2664
-
2665
- # test pivot
2666
- # pivot with values
2667
- self.assert_eq(
2668
- cdf.groupBy("name").pivot("department", ["Sales", "Marketing"]).sum().toPandas(),
2669
- sdf.groupBy("name").pivot("department", ["Sales", "Marketing"]).sum().toPandas(),
2670
- )
2671
- self.assert_eq(
2672
- cdf.groupBy("name")
2673
- .pivot("department", ["Sales", "Marketing"])
2674
- .min("salary")
2675
- .toPandas(),
2676
- sdf.groupBy("name")
2677
- .pivot("department", ["Sales", "Marketing"])
2678
- .min("salary")
2679
- .toPandas(),
2680
- )
2681
- self.assert_eq(
2682
- cdf.groupBy("name")
2683
- .pivot("department", ["Sales", "Marketing"])
2684
- .max("salary")
2685
- .toPandas(),
2686
- sdf.groupBy("name")
2687
- .pivot("department", ["Sales", "Marketing"])
2688
- .max("salary")
2689
- .toPandas(),
2690
- )
2691
- self.assert_eq(
2692
- cdf.groupBy(cdf.name)
2693
- .pivot("department", ["Sales", "Finance", "Unknown"])
2694
- .avg("salary", "year")
2695
- .toPandas(),
2696
- sdf.groupBy(sdf.name)
2697
- .pivot("department", ["Sales", "Finance", "Unknown"])
2698
- .avg("salary", "year")
2699
- .toPandas(),
2700
- )
2701
- self.assert_eq(
2702
- cdf.groupBy(cdf.name)
2703
- .pivot("department", ["Sales", "Finance", "Unknown"])
2704
- .sum("salary", "year")
2705
- .toPandas(),
2706
- sdf.groupBy(sdf.name)
2707
- .pivot("department", ["Sales", "Finance", "Unknown"])
2708
- .sum("salary", "year")
2709
- .toPandas(),
2710
- )
2711
-
2712
- # pivot without values
2713
- self.assert_eq(
2714
- cdf.groupBy("name").pivot("department").min().toPandas(),
2715
- sdf.groupBy("name").pivot("department").min().toPandas(),
2716
- )
2717
- self.assert_eq(
2718
- cdf.groupBy("name").pivot("department").min("salary").toPandas(),
2719
- sdf.groupBy("name").pivot("department").min("salary").toPandas(),
2720
- )
2721
- self.assert_eq(
2722
- cdf.groupBy("name").pivot("department").max("salary").toPandas(),
2723
- sdf.groupBy("name").pivot("department").max("salary").toPandas(),
2724
- )
2725
- self.assert_eq(
2726
- cdf.groupBy(cdf.name).pivot("department").avg("salary", "year").toPandas(),
2727
- sdf.groupBy(sdf.name).pivot("department").avg("salary", "year").toPandas(),
2728
- )
2729
- self.assert_eq(
2730
- cdf.groupBy(cdf.name).pivot("department").sum("salary", "year").toPandas(),
2731
- sdf.groupBy(sdf.name).pivot("department").sum("salary", "year").toPandas(),
2732
- )
2733
-
2734
- # check error
2735
- with self.assertRaisesRegex(
2736
- TypeError,
2737
- "Numeric aggregation function can only be applied on numeric columns",
2738
- ):
2739
- cdf.groupBy("name").min("department").show()
2740
-
2741
- with self.assertRaisesRegex(
2742
- TypeError,
2743
- "Numeric aggregation function can only be applied on numeric columns",
2744
- ):
2745
- cdf.groupBy("name").max("salary", "department").show()
2746
-
2747
- with self.assertRaisesRegex(
2748
- TypeError,
2749
- "Numeric aggregation function can only be applied on numeric columns",
2750
- ):
2751
- cdf.rollup("name").avg("department").show()
2752
-
2753
- with self.assertRaisesRegex(
2754
- TypeError,
2755
- "Numeric aggregation function can only be applied on numeric columns",
2756
- ):
2757
- cdf.rollup("name").sum("salary", "department").show()
2758
-
2759
- with self.assertRaisesRegex(
2760
- TypeError,
2761
- "Numeric aggregation function can only be applied on numeric columns",
2762
- ):
2763
- cdf.cube("name").min("department").show()
2764
-
2765
- with self.assertRaisesRegex(
2766
- TypeError,
2767
- "Numeric aggregation function can only be applied on numeric columns",
2768
- ):
2769
- cdf.cube("name").max("salary", "department").show()
2770
-
2771
- with self.assertRaisesRegex(
2772
- TypeError,
2773
- "Numeric aggregation function can only be applied on numeric columns",
2774
- ):
2775
- cdf.groupBy("name").pivot("department").avg("department").show()
2776
-
2777
- with self.assertRaisesRegex(
2778
- TypeError,
2779
- "Numeric aggregation function can only be applied on numeric columns",
2780
- ):
2781
- cdf.groupBy("name").pivot("department").sum("salary", "department").show()
2782
-
2783
- def test_with_metadata(self):
2784
- cdf = self.connect.createDataFrame(data=[(2, "Alice"), (5, "Bob")], schema=["age", "name"])
2785
- self.assertEqual(cdf.schema["age"].metadata, {})
2786
- self.assertEqual(cdf.schema["name"].metadata, {})
2787
-
2788
- cdf1 = cdf.withMetadata(columnName="age", metadata={"max_age": 5})
2789
- self.assertEqual(cdf1.schema["age"].metadata, {"max_age": 5})
2790
-
2791
- cdf2 = cdf.withMetadata(columnName="name", metadata={"names": ["Alice", "Bob"]})
2792
- self.assertEqual(cdf2.schema["name"].metadata, {"names": ["Alice", "Bob"]})
2793
-
2794
- with self.assertRaises(PySparkTypeError) as pe:
2795
- cdf.withMetadata(columnName="name", metadata=["magic"])
2796
-
2797
- self.check_error(
2798
- exception=pe.exception,
2799
- error_class="NOT_DICT",
2800
- message_parameters={
2801
- "arg_name": "metadata",
2802
- "arg_type": "list",
2803
- },
2804
- )
2805
-
2806
- def test_collect_nested_type(self):
2807
- query = """
2808
- SELECT * FROM VALUES
2809
- (1, 4, 0, 8, true, true, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
2810
- (2, 5, -1, NULL, false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
2811
- (3, 6, NULL, 0, false, NULL, ARRAY(NULL), NULL)
2812
- AS tab(a, b, c, d, e, f, g, h)
2813
- """
2814
-
2815
- # +---+---+----+----+-----+----+------------+-------------------+
2816
- # | a| b| c| d| e| f| g| h|
2817
- # +---+---+----+----+-----+----+------------+-------------------+
2818
- # | 1| 4| 0| 8| true|true|[1, null, 3]| {1 -> 2, 3 -> 4}|
2819
- # | 2| 5| -1|NULL|false|NULL| [1, 3]|{1 -> null, 3 -> 4}|
2820
- # | 3| 6|NULL| 0|false|NULL| [null]| NULL|
2821
- # +---+---+----+----+-----+----+------------+-------------------+
2822
-
2823
- cdf = self.connect.sql(query)
2824
- sdf = self.spark.sql(query)
2825
-
2826
- # test collect array
2827
- # +--------------+-------------+------------+
2828
- # |array(a, b, c)| array(e, f)| g|
2829
- # +--------------+-------------+------------+
2830
- # | [1, 4, 0]| [true, true]|[1, null, 3]|
2831
- # | [2, 5, -1]|[false, null]| [1, 3]|
2832
- # | [3, 6, null]|[false, null]| [null]|
2833
- # +--------------+-------------+------------+
2834
- self.assertEqual(
2835
- cdf.select(CF.array("a", "b", "c"), CF.array("e", "f"), CF.col("g")).collect(),
2836
- sdf.select(SF.array("a", "b", "c"), SF.array("e", "f"), SF.col("g")).collect(),
2837
- )
2838
-
2839
- # test collect nested array
2840
- # +-----------------------------------+-------------------------+
2841
- # |array(array(a), array(b), array(c))|array(array(e), array(f))|
2842
- # +-----------------------------------+-------------------------+
2843
- # | [[1], [4], [0]]| [[true], [true]]|
2844
- # | [[2], [5], [-1]]| [[false], [null]]|
2845
- # | [[3], [6], [null]]| [[false], [null]]|
2846
- # +-----------------------------------+-------------------------+
2847
- self.assertEqual(
2848
- cdf.select(
2849
- CF.array(CF.array("a"), CF.array("b"), CF.array("c")),
2850
- CF.array(CF.array("e"), CF.array("f")),
2851
- ).collect(),
2852
- sdf.select(
2853
- SF.array(SF.array("a"), SF.array("b"), SF.array("c")),
2854
- SF.array(SF.array("e"), SF.array("f")),
2855
- ).collect(),
2856
- )
2857
-
2858
- # test collect array of struct, map
2859
- # +----------------+---------------------+
2860
- # |array(struct(a))| array(h)|
2861
- # +----------------+---------------------+
2862
- # | [{1}]| [{1 -> 2, 3 -> 4}]|
2863
- # | [{2}]|[{1 -> null, 3 -> 4}]|
2864
- # | [{3}]| [null]|
2865
- # +----------------+---------------------+
2866
- self.assertEqual(
2867
- cdf.select(CF.array(CF.struct("a")), CF.array("h")).collect(),
2868
- sdf.select(SF.array(SF.struct("a")), SF.array("h")).collect(),
2869
- )
2870
-
2871
- # test collect map
2872
- # +-------------------+-------------------+
2873
- # | h| map(a, b, b, c)|
2874
- # +-------------------+-------------------+
2875
- # | {1 -> 2, 3 -> 4}| {1 -> 4, 4 -> 0}|
2876
- # |{1 -> null, 3 -> 4}| {2 -> 5, 5 -> -1}|
2877
- # | NULL|{3 -> 6, 6 -> null}|
2878
- # +-------------------+-------------------+
2879
- self.assertEqual(
2880
- cdf.select(CF.col("h"), CF.create_map("a", "b", "b", "c")).collect(),
2881
- sdf.select(SF.col("h"), SF.create_map("a", "b", "b", "c")).collect(),
2882
- )
2883
-
2884
- # test collect map of struct, array
2885
- # +-------------------+------------------------+
2886
- # | map(a, g)| map(a, struct(b, g))|
2887
- # +-------------------+------------------------+
2888
- # |{1 -> [1, null, 3]}|{1 -> {4, [1, null, 3]}}|
2889
- # | {2 -> [1, 3]}| {2 -> {5, [1, 3]}}|
2890
- # | {3 -> [null]}| {3 -> {6, [null]}}|
2891
- # +-------------------+------------------------+
2892
- self.assertEqual(
2893
- cdf.select(CF.create_map("a", "g"), CF.create_map("a", CF.struct("b", "g"))).collect(),
2894
- sdf.select(SF.create_map("a", "g"), SF.create_map("a", SF.struct("b", "g"))).collect(),
2895
- )
2896
-
2897
- # test collect struct
2898
- # +------------------+--------------------------+
2899
- # |struct(a, b, c, d)| struct(e, f, g)|
2900
- # +------------------+--------------------------+
2901
- # | {1, 4, 0, 8}|{true, true, [1, null, 3]}|
2902
- # | {2, 5, -1, null}| {false, null, [1, 3]}|
2903
- # | {3, 6, null, 0}| {false, null, [null]}|
2904
- # +------------------+--------------------------+
2905
- self.assertEqual(
2906
- cdf.select(CF.struct("a", "b", "c", "d"), CF.struct("e", "f", "g")).collect(),
2907
- sdf.select(SF.struct("a", "b", "c", "d"), SF.struct("e", "f", "g")).collect(),
2908
- )
2909
-
2910
- # test collect nested struct
2911
- # +------------------------------------------+--------------------------+----------------------------+ # noqa
2912
- # |struct(a, struct(a, struct(c, struct(d))))|struct(a, b, struct(c, d))| struct(e, f, struct(g))| # noqa
2913
- # +------------------------------------------+--------------------------+----------------------------+ # noqa
2914
- # | {1, {1, {0, {8}}}}| {1, 4, {0, 8}}|{true, true, {[1, null, 3]}}| # noqa
2915
- # | {2, {2, {-1, {null}}}}| {2, 5, {-1, null}}| {false, null, {[1, 3]}}| # noqa
2916
- # | {3, {3, {null, {0}}}}| {3, 6, {null, 0}}| {false, null, {[null]}}| # noqa
2917
- # +------------------------------------------+--------------------------+----------------------------+ # noqa
2918
- self.assertEqual(
2919
- cdf.select(
2920
- CF.struct("a", CF.struct("a", CF.struct("c", CF.struct("d")))),
2921
- CF.struct("a", "b", CF.struct("c", "d")),
2922
- CF.struct("e", "f", CF.struct("g")),
2923
- ).collect(),
2924
- sdf.select(
2925
- SF.struct("a", SF.struct("a", SF.struct("c", SF.struct("d")))),
2926
- SF.struct("a", "b", SF.struct("c", "d")),
2927
- SF.struct("e", "f", SF.struct("g")),
2928
- ).collect(),
2929
- )
2930
-
2931
- # test collect struct containing array, map
2932
- # +--------------------------------------------+
2933
- # | struct(a, struct(a, struct(g, struct(h))))|
2934
- # +--------------------------------------------+
2935
- # |{1, {1, {[1, null, 3], {{1 -> 2, 3 -> 4}}}}}|
2936
- # | {2, {2, {[1, 3], {{1 -> null, 3 -> 4}}}}}|
2937
- # | {3, {3, {[null], {null}}}}|
2938
- # +--------------------------------------------+
2939
- self.assertEqual(
2940
- cdf.select(
2941
- CF.struct("a", CF.struct("a", CF.struct("g", CF.struct("h")))),
2942
- ).collect(),
2943
- sdf.select(
2944
- SF.struct("a", SF.struct("a", SF.struct("g", SF.struct("h")))),
2945
- ).collect(),
2946
- )
2947
-
2948
- def test_simple_udt(self):
2949
- from pyspark.ml.linalg import MatrixUDT, VectorUDT
2950
-
2951
- for schema in [
2952
- StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2953
- StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())),
2954
- StructType().add("key", LongType()).add("val", MapType(LongType(), PythonOnlyUDT())),
2955
- StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2956
- StructType().add("key", LongType()).add("vec", VectorUDT()),
2957
- StructType().add("key", LongType()).add("mat", MatrixUDT()),
2958
- ]:
2959
- cdf = self.connect.createDataFrame(data=[], schema=schema)
2960
- sdf = self.spark.createDataFrame(data=[], schema=schema)
2961
-
2962
- self.assertEqual(cdf.schema, sdf.schema)
2963
-
2964
- def test_simple_udt_from_read(self):
2965
- from pyspark.ml.linalg import Matrices, Vectors
2966
-
2967
- with tempfile.TemporaryDirectory() as d:
2968
- path1 = f"{d}/df1.parquet"
2969
- self.spark.createDataFrame(
2970
- [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
2971
- schema=StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2972
- ).write.parquet(path1)
2973
-
2974
- path2 = f"{d}/df2.parquet"
2975
- self.spark.createDataFrame(
2976
- [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
2977
- schema=StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())),
2978
- ).write.parquet(path2)
2979
-
2980
- path3 = f"{d}/df3.parquet"
2981
- self.spark.createDataFrame(
2982
- [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
2983
- schema=StructType()
2984
- .add("key", LongType())
2985
- .add("val", MapType(LongType(), PythonOnlyUDT())),
2986
- ).write.parquet(path3)
2987
-
2988
- path4 = f"{d}/df4.parquet"
2989
- self.spark.createDataFrame(
2990
- [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
2991
- schema=StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2992
- ).write.parquet(path4)
2993
-
2994
- path5 = f"{d}/df5.parquet"
2995
- self.spark.createDataFrame(
2996
- [Row(label=1.0, point=ExamplePoint(1.0, 2.0))]
2997
- ).write.parquet(path5)
2998
-
2999
- path6 = f"{d}/df6.parquet"
3000
- self.spark.createDataFrame(
3001
- [(Vectors.dense(1.0, 2.0, 3.0),), (Vectors.sparse(3, {1: 1.0, 2: 5.5}),)],
3002
- ["vec"],
3003
- ).write.parquet(path6)
3004
-
3005
- path7 = f"{d}/df7.parquet"
3006
- self.spark.createDataFrame(
3007
- [
3008
- (Matrices.dense(3, 2, [0, 1, 4, 5, 9, 10]),),
3009
- (Matrices.sparse(1, 1, [0, 1], [0], [2.0]),),
3010
- ],
3011
- ["mat"],
3012
- ).write.parquet(path7)
3013
-
3014
- for path in [path1, path2, path3, path4, path5, path6, path7]:
3015
- self.assertEqual(
3016
- self.connect.read.parquet(path).schema,
3017
- self.spark.read.parquet(path).schema,
3018
- )
3019
-
3020
- def test_version(self):
3021
- self.assertEqual(
3022
- self.connect.version,
3023
- self.spark.version,
3024
- )
3025
-
3026
- def test_same_semantics(self):
3027
- plan = self.connect.sql("SELECT 1")
3028
- other = self.connect.sql("SELECT 1")
3029
- self.assertTrue(plan.sameSemantics(other))
3030
-
3031
- def test_semantic_hash(self):
3032
- plan = self.connect.sql("SELECT 1")
3033
- other = self.connect.sql("SELECT 1")
3034
- self.assertEqual(
3035
- plan.semanticHash(),
3036
- other.semanticHash(),
3037
- )
3038
-
3039
- def test_unsupported_functions(self):
3040
- # SPARK-41225: Disable unsupported functions.
3041
- df = self.connect.read.table(self.tbl_name)
3042
- for f in (
3043
- "rdd",
3044
- "foreach",
3045
- "foreachPartition",
3046
- "checkpoint",
3047
- "localCheckpoint",
3048
- ):
3049
- with self.assertRaises(NotImplementedError):
3050
- getattr(df, f)()
3051
-
3052
- def test_unsupported_session_functions(self):
3053
- # SPARK-41934: Disable unsupported functions.
3054
-
3055
- with self.assertRaises(NotImplementedError):
3056
- RemoteSparkSession.builder.enableHiveSupport()
3057
-
3058
- for f in (
3059
- "newSession",
3060
- "sparkContext",
3061
- ):
3062
- with self.assertRaises(NotImplementedError):
3063
- getattr(self.connect, f)()
3064
-
3065
- def test_sql_with_command(self):
3066
- # SPARK-42705: spark.sql should return values from the command.
3067
- self.assertEqual(
3068
- self.connect.sql("show functions").collect(), self.spark.sql("show functions").collect()
3069
- )
3070
-
3071
- def test_schema_has_nullable(self):
3072
- schema_false = StructType().add("id", IntegerType(), False)
3073
- cdf1 = self.connect.createDataFrame([[1]], schema=schema_false)
3074
- sdf1 = self.spark.createDataFrame([[1]], schema=schema_false)
3075
- self.assertEqual(cdf1.schema, sdf1.schema)
3076
- self.assertEqual(cdf1.collect(), sdf1.collect())
3077
-
3078
- schema_true = StructType().add("id", IntegerType(), True)
3079
- cdf2 = self.connect.createDataFrame([[1]], schema=schema_true)
3080
- sdf2 = self.spark.createDataFrame([[1]], schema=schema_true)
3081
- self.assertEqual(cdf2.schema, sdf2.schema)
3082
- self.assertEqual(cdf2.collect(), sdf2.collect())
3083
-
3084
- pdf1 = cdf1.toPandas()
3085
- cdf3 = self.connect.createDataFrame(pdf1, cdf1.schema)
3086
- sdf3 = self.spark.createDataFrame(pdf1, sdf1.schema)
3087
- self.assertEqual(cdf3.schema, sdf3.schema)
3088
- self.assertEqual(cdf3.collect(), sdf3.collect())
3089
-
3090
- pdf2 = cdf2.toPandas()
3091
- cdf4 = self.connect.createDataFrame(pdf2, cdf2.schema)
3092
- sdf4 = self.spark.createDataFrame(pdf2, sdf2.schema)
3093
- self.assertEqual(cdf4.schema, sdf4.schema)
3094
- self.assertEqual(cdf4.collect(), sdf4.collect())
3095
-
3096
- def test_array_has_nullable(self):
3097
- for schemas, data in [
3098
- (
3099
- [StructType().add("arr", ArrayType(IntegerType(), False), True)],
3100
- [Row([1, 2]), Row([3]), Row(None)],
3101
- ),
3102
- (
3103
- [
3104
- StructType().add("arr", ArrayType(IntegerType(), True), True),
3105
- "arr array<integer>",
3106
- ],
3107
- [Row([1, None]), Row([3]), Row(None)],
3108
- ),
3109
- (
3110
- [StructType().add("arr", ArrayType(IntegerType(), False), False)],
3111
- [Row([1, 2]), Row([3])],
3112
- ),
3113
- (
3114
- [
3115
- StructType().add("arr", ArrayType(IntegerType(), True), False),
3116
- "arr array<integer> not null",
3117
- ],
3118
- [Row([1, None]), Row([3])],
3119
- ),
3120
- ]:
3121
- for schema in schemas:
3122
- with self.subTest(schema=schema):
3123
- cdf = self.connect.createDataFrame(data, schema=schema)
3124
- sdf = self.spark.createDataFrame(data, schema=schema)
3125
- self.assertEqual(cdf.schema, sdf.schema)
3126
- self.assertEqual(cdf.collect(), sdf.collect())
3127
-
3128
- def test_map_has_nullable(self):
3129
- for schemas, data in [
3130
- (
3131
- [StructType().add("map", MapType(StringType(), IntegerType(), False), True)],
3132
- [Row({"a": 1, "b": 2}), Row({"a": 3}), Row(None)],
3133
- ),
3134
- (
3135
- [
3136
- StructType().add("map", MapType(StringType(), IntegerType(), True), True),
3137
- "map map<string, integer>",
3138
- ],
3139
- [Row({"a": 1, "b": None}), Row({"a": 3}), Row(None)],
3140
- ),
3141
- (
3142
- [StructType().add("map", MapType(StringType(), IntegerType(), False), False)],
3143
- [Row({"a": 1, "b": 2}), Row({"a": 3})],
3144
- ),
3145
- (
3146
- [
3147
- StructType().add("map", MapType(StringType(), IntegerType(), True), False),
3148
- "map map<string, integer> not null",
3149
- ],
3150
- [Row({"a": 1, "b": None}), Row({"a": 3})],
3151
- ),
3152
- ]:
3153
- for schema in schemas:
3154
- with self.subTest(schema=schema):
3155
- cdf = self.connect.createDataFrame(data, schema=schema)
3156
- sdf = self.spark.createDataFrame(data, schema=schema)
3157
- self.assertEqual(cdf.schema, sdf.schema)
3158
- self.assertEqual(cdf.collect(), sdf.collect())
3159
-
3160
- def test_struct_has_nullable(self):
3161
- for schemas, data in [
3162
- (
3163
- [
3164
- StructType().add("struct", StructType().add("i", IntegerType(), False), True),
3165
- "struct struct<i: integer not null>",
3166
- ],
3167
- [Row(Row(1)), Row(Row(2)), Row(None)],
3168
- ),
3169
- (
3170
- [
3171
- StructType().add("struct", StructType().add("i", IntegerType(), True), True),
3172
- "struct struct<i: integer>",
3173
- ],
3174
- [Row(Row(1)), Row(Row(2)), Row(Row(None)), Row(None)],
3175
- ),
3176
- (
3177
- [
3178
- StructType().add("struct", StructType().add("i", IntegerType(), False), False),
3179
- "struct struct<i: integer not null> not null",
3180
- ],
3181
- [Row(Row(1)), Row(Row(2))],
3182
- ),
3183
- (
3184
- [
3185
- StructType().add("struct", StructType().add("i", IntegerType(), True), False),
3186
- "struct struct<i: integer> not null",
3187
- ],
3188
- [Row(Row(1)), Row(Row(2)), Row(Row(None))],
3189
- ),
3190
- ]:
3191
- for schema in schemas:
3192
- with self.subTest(schema=schema):
3193
- cdf = self.connect.createDataFrame(data, schema=schema)
3194
- sdf = self.spark.createDataFrame(data, schema=schema)
3195
- self.assertEqual(cdf.schema, sdf.schema)
3196
- self.assertEqual(cdf.collect(), sdf.collect())
3197
-
3198
- def test_large_client_data(self):
3199
- # SPARK-42816 support more than 4MB message size.
3200
- # ~200bytes
3201
- cols = ["abcdefghijklmnoprstuvwxyz" for x in range(10)]
3202
- # 100k rows => 20MB
3203
- row_count = 100 * 1000
3204
- rows = [cols] * row_count
3205
- self.assertEqual(row_count, self.connect.createDataFrame(data=rows).count())
3206
-
3207
- def test_unsupported_jvm_attribute(self):
3208
- # Unsupported jvm attributes for Spark session.
3209
- unsupported_attrs = ["_jsc", "_jconf", "_jvm", "_jsparkSession"]
3210
- spark_session = self.connect
3211
- for attr in unsupported_attrs:
3212
- with self.assertRaises(PySparkAttributeError) as pe:
3213
- getattr(spark_session, attr)
3214
-
3215
- self.check_error(
3216
- exception=pe.exception,
3217
- error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3218
- message_parameters={"attr_name": attr},
3219
- )
3220
-
3221
- # Unsupported jvm attributes for DataFrame.
3222
- unsupported_attrs = ["_jseq", "_jdf", "_jmap", "_jcols"]
3223
- cdf = self.connect.range(10)
3224
- for attr in unsupported_attrs:
3225
- with self.assertRaises(PySparkAttributeError) as pe:
3226
- getattr(cdf, attr)
3227
-
3228
- self.check_error(
3229
- exception=pe.exception,
3230
- error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3231
- message_parameters={"attr_name": attr},
3232
- )
3233
-
3234
- # Unsupported jvm attributes for Column.
3235
- with self.assertRaises(PySparkAttributeError) as pe:
3236
- getattr(cdf.id, "_jc")
3237
-
3238
- self.check_error(
3239
- exception=pe.exception,
3240
- error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3241
- message_parameters={"attr_name": "_jc"},
3242
- )
3243
-
3244
- # Unsupported jvm attributes for DataFrameReader.
3245
- with self.assertRaises(PySparkAttributeError) as pe:
3246
- getattr(spark_session.read, "_jreader")
3247
-
3248
- self.check_error(
3249
- exception=pe.exception,
3250
- error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3251
- message_parameters={"attr_name": "_jreader"},
3252
- )
3253
-
3254
- def test_df_caache(self):
3255
- df = self.connect.range(10)
3256
- df.cache()
3257
- self.assert_eq(10, df.count())
3258
- self.assertTrue(df.is_cached)
3259
-
3260
-
3261
- @unittest.skipIf(
3262
- "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Session creation different from local mode"
3263
- )
3264
- class SparkConnectSessionTests(ReusedConnectTestCase):
3265
- def setUp(self) -> None:
3266
- self.spark = (
3267
- PySparkSession.builder.config(conf=self.conf())
3268
- .appName(self.__class__.__name__)
3269
- .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
3270
- .getOrCreate()
3271
- )
3272
-
3273
- def tearDown(self):
3274
- self.spark.stop()
3275
-
3276
- def _check_no_active_session_error(self, e: PySparkException):
3277
- self.check_error(exception=e, error_class="NO_ACTIVE_SESSION", message_parameters=dict())
3278
-
3279
- def test_stop_session(self):
3280
- df = self.spark.sql("select 1 as a, 2 as b")
3281
- catalog = self.spark.catalog
3282
- self.spark.stop()
3283
-
3284
- # _execute_and_fetch
3285
- with self.assertRaises(SparkConnectException) as e:
3286
- self.spark.sql("select 1")
3287
- self._check_no_active_session_error(e.exception)
3288
-
3289
- with self.assertRaises(SparkConnectException) as e:
3290
- catalog.tableExists("table")
3291
- self._check_no_active_session_error(e.exception)
3292
-
3293
- # _execute
3294
- with self.assertRaises(SparkConnectException) as e:
3295
- self.spark.udf.register("test_func", lambda x: x + 1)
3296
- self._check_no_active_session_error(e.exception)
3297
-
3298
- # _analyze
3299
- with self.assertRaises(SparkConnectException) as e:
3300
- df._explain_string(extended=True)
3301
- self._check_no_active_session_error(e.exception)
3302
-
3303
- # Config
3304
- with self.assertRaises(SparkConnectException) as e:
3305
- self.spark.conf.get("some.conf")
3306
- self._check_no_active_session_error(e.exception)
3307
-
3308
- def test_error_stack_trace(self):
3309
- with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
3310
- with self.assertRaises(AnalysisException) as e:
3311
- self.spark.sql("select x").collect()
3312
- self.assertTrue("JVM stacktrace" in e.exception.message)
3313
- self.assertTrue(
3314
- "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
3315
- )
3316
-
3317
- with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}):
3318
- with self.assertRaises(AnalysisException) as e:
3319
- self.spark.sql("select x").collect()
3320
- self.assertFalse("JVM stacktrace" in e.exception.message)
3321
- self.assertFalse(
3322
- "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
3323
- )
3324
-
3325
- # Create a new session with a different stack trace size.
3326
- self.spark.stop()
3327
- spark = (
3328
- PySparkSession.builder.config(conf=self.conf())
3329
- .config("spark.connect.jvmStacktrace.maxSize", 128)
3330
- .remote("local[4]")
3331
- .getOrCreate()
3332
- )
3333
- spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
3334
- with self.assertRaises(AnalysisException) as e:
3335
- spark.sql("select x").collect()
3336
- self.assertTrue("JVM stacktrace" in e.exception.message)
3337
- self.assertFalse(
3338
- "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
3339
- )
3340
- spark.stop()
3341
-
3342
- def test_can_create_multiple_sessions_to_different_remotes(self):
3343
- self.spark.stop()
3344
- self.assertIsNotNone(self.spark._client)
3345
- # Creates a new remote session.
3346
- other = PySparkSession.builder.remote("sc://other.remote:114/").create()
3347
- self.assertNotEquals(self.spark, other)
3348
-
3349
- # Gets currently active session.
3350
- same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate()
3351
- self.assertEquals(other, same)
3352
- same.stop()
3353
-
3354
- # Make sure the environment is clean.
3355
- self.spark.stop()
3356
- with self.assertRaises(RuntimeError) as e:
3357
- PySparkSession.builder.create()
3358
- self.assertIn("Create a new SparkSession is only supported with SparkConnect.", str(e))
3359
-
3360
-
3361
- @unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access")
3362
- class SparkConnectSessionWithOptionsTest(unittest.TestCase):
3363
- def setUp(self) -> None:
3364
- self.spark = (
3365
- PySparkSession.builder.config("string", "foo")
3366
- .config("integer", 1)
3367
- .config("boolean", False)
3368
- .appName(self.__class__.__name__)
3369
- .remote("local[4]")
3370
- .getOrCreate()
3371
- )
3372
-
3373
- def tearDown(self):
3374
- self.spark.stop()
3375
-
3376
- def test_config(self):
3377
- # Config
3378
- self.assertEqual(self.spark.conf.get("string"), "foo")
3379
- self.assertEqual(self.spark.conf.get("boolean"), "false")
3380
- self.assertEqual(self.spark.conf.get("integer"), "1")
3381
-
3382
-
3383
- @unittest.skipIf(not should_test_connect, connect_requirement_message)
- class ClientTests(unittest.TestCase):
-     def test_retry_error_handling(self):
-         # Helper class for wrapping the test.
-         class TestError(grpc.RpcError, Exception):
-             def __init__(self, code: grpc.StatusCode):
-                 self._code = code
-
-             def code(self):
-                 return self._code
-
-         def stub(retries, w, code):
-             w["attempts"] += 1
-             if w["attempts"] < retries:
-                 w["raised"] += 1
-                 raise TestError(code)
-
-         # Check that max_retries 1 is only one retry so two attempts.
-         call_wrap = defaultdict(int)
-         for attempt in Retrying(
-             can_retry=lambda x: True,
-             max_retries=1,
-             backoff_multiplier=1,
-             initial_backoff=1,
-             max_backoff=10,
-             jitter=0,
-             min_jitter_threshold=0,
-         ):
-             with attempt:
-                 stub(2, call_wrap, grpc.StatusCode.INTERNAL)
-
-         self.assertEqual(2, call_wrap["attempts"])
-         self.assertEqual(1, call_wrap["raised"])
-
-         # Check that if we have less than 4 retries all is ok.
-         call_wrap = defaultdict(int)
-         for attempt in Retrying(
-             can_retry=lambda x: True,
-             max_retries=4,
-             backoff_multiplier=1,
-             initial_backoff=1,
-             max_backoff=10,
-             jitter=0,
-             min_jitter_threshold=0,
-         ):
-             with attempt:
-                 stub(2, call_wrap, grpc.StatusCode.INTERNAL)
-
-         self.assertTrue(call_wrap["attempts"] < 4)
-         self.assertEqual(call_wrap["raised"], 1)
-
-         # Exceed the retries.
-         call_wrap = defaultdict(int)
-         with self.assertRaises(TestError):
-             for attempt in Retrying(
-                 can_retry=lambda x: True,
-                 max_retries=2,
-                 max_backoff=50,
-                 backoff_multiplier=1,
-                 initial_backoff=50,
-                 jitter=0,
-                 min_jitter_threshold=0,
-             ):
-                 with attempt:
-                     stub(5, call_wrap, grpc.StatusCode.INTERNAL)
-
-         self.assertTrue(call_wrap["attempts"] < 5)
-         self.assertEqual(call_wrap["raised"], 3)
-
-         # Check that only specific exceptions are retried.
-         # Check that if we have less than 4 retries all is ok.
-         call_wrap = defaultdict(int)
-         for attempt in Retrying(
-             can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
-             max_retries=4,
-             backoff_multiplier=1,
-             initial_backoff=1,
-             max_backoff=10,
-             jitter=0,
-             min_jitter_threshold=0,
-         ):
-             with attempt:
-                 stub(2, call_wrap, grpc.StatusCode.UNAVAILABLE)
-
-         self.assertTrue(call_wrap["attempts"] < 4)
-         self.assertEqual(call_wrap["raised"], 1)
-
-         # Exceed the retries.
-         call_wrap = defaultdict(int)
-         with self.assertRaises(TestError):
-             for attempt in Retrying(
-                 can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
-                 max_retries=2,
-                 max_backoff=50,
-                 backoff_multiplier=1,
-                 initial_backoff=50,
-                 jitter=0,
-                 min_jitter_threshold=0,
-             ):
-                 with attempt:
-                     stub(5, call_wrap, grpc.StatusCode.UNAVAILABLE)
-
-         self.assertTrue(call_wrap["attempts"] < 4)
-         self.assertEqual(call_wrap["raised"], 3)
-
-         # Test that another error is always thrown.
-         call_wrap = defaultdict(int)
-         with self.assertRaises(TestError):
-             for attempt in Retrying(
-                 can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
-                 max_retries=4,
-                 backoff_multiplier=1,
-                 initial_backoff=1,
-                 max_backoff=10,
-                 jitter=0,
-                 min_jitter_threshold=0,
-             ):
-                 with attempt:
-                     stub(5, call_wrap, grpc.StatusCode.INTERNAL)
-
-         self.assertEqual(call_wrap["attempts"], 1)
-         self.assertEqual(call_wrap["raised"], 1)
-
-
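For reference, a condensed sketch of the retry-loop pattern these assertions exercise: each attempt is a context manager that swallows retriable errors until max_retries is exhausted, after which the last error propagates (editorial illustration; the import path and the send_rpc helper are assumptions):

import grpc
from collections import defaultdict
from pyspark.sql.connect.client import Retrying  # import path is an assumption

def send_rpc():
    raise grpc.RpcError()  # hypothetical transient failure

state = defaultdict(int)
try:
    for attempt in Retrying(
        can_retry=lambda e: isinstance(e, grpc.RpcError),  # retry only gRPC errors
        max_retries=1,          # one retry, i.e. at most two attempts
        backoff_multiplier=1,
        initial_backoff=1,
        max_backoff=10,
        jitter=0,
        min_jitter_threshold=0,
    ):
        with attempt:           # a retriable error raised here triggers the next attempt
            state["attempts"] += 1
            send_rpc()
except grpc.RpcError:
    print(f"gave up after {state['attempts']} attempts")  # prints 2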
- @unittest.skipIf(not should_test_connect, connect_requirement_message)
- class ChannelBuilderTests(unittest.TestCase):
-     def test_invalid_connection_strings(self):
-         invalid = [
-             "scc://host:12",
-             "http://host",
-             "sc:/host:1234/path",
-             "sc://host/path",
-             "sc://host/;parm1;param2",
-         ]
-         for i in invalid:
-             self.assertRaises(PySparkValueError, ChannelBuilder, i)
-
-     def test_sensible_defaults(self):
-         chan = ChannelBuilder("sc://host")
-         self.assertFalse(chan.secure, "Default URL is not secure")
-
-         chan = ChannelBuilder("sc://host/;token=abcs")
-         self.assertTrue(chan.secure, "specifying a token must set the channel to secure")
-         self.assertRegex(
-             chan.userAgent, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$"
-         )
-         chan = ChannelBuilder("sc://host/;use_ssl=abcs")
-         self.assertFalse(chan.secure, "Garbage in, false out")
-
-     def test_user_agent(self):
-         chan = ChannelBuilder("sc://host/;user_agent=Agent123%20%2F3.4")
-         self.assertIn("Agent123 /3.4", chan.userAgent)
-
-     def test_user_agent_len(self):
-         user_agent = "x" * 2049
-         chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
-         with self.assertRaises(SparkConnectException) as err:
-             chan.userAgent
-         self.assertRegex(err.exception.message, "'user_agent' parameter should not exceed")
-
-         user_agent = "%C3%A4" * 341  # "%C3%A4" -> "ä"; (341 * 6 = 2046) < 2048
-         expected = "ä" * 341
-         chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
-         self.assertIn(expected, chan.userAgent)
-
-     def test_valid_channel_creation(self):
-         chan = ChannelBuilder("sc://host").toChannel()
-         self.assertIsInstance(chan, grpc.Channel)
-
-         # Sets up a channel without tokens because ssl is not used.
-         chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc").toChannel()
-         self.assertIsInstance(chan, grpc.Channel)
-
-         chan = ChannelBuilder("sc://host/;use_ssl=true").toChannel()
-         self.assertIsInstance(chan, grpc.Channel)
-
-     def test_channel_properties(self):
-         chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;user_agent=foo;param1=120%2021")
-         self.assertEqual("host:15002", chan.endpoint)
-         self.assertIn("foo", chan.userAgent.split(" "))
-         self.assertEqual(True, chan.secure)
-         self.assertEqual("120 21", chan.get("param1"))
-
-     def test_metadata(self):
-         chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;param1=120%2021;x-my-header=abcd")
-         md = chan.metadata()
-         self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], md)
-
-     def test_metadata_session_id(self):
-         id = str(uuid.uuid4())
-         chan = ChannelBuilder(f"sc://host/;session_id={id}")
-         self.assertEqual(id, chan.session_id)
-
-         chan = ChannelBuilder(f"sc://host/;session_id={id};user_agent=acbd;token=abcd;use_ssl=true")
-         md = chan.metadata()
-         for kv in md:
-             self.assertNotIn(
-                 kv[0],
-                 [
-                     ChannelBuilder.PARAM_SESSION_ID,
-                     ChannelBuilder.PARAM_TOKEN,
-                     ChannelBuilder.PARAM_USER_ID,
-                     ChannelBuilder.PARAM_USER_AGENT,
-                     ChannelBuilder.PARAM_USE_SSL,
-                 ],
-                 "Metadata must not contain fixed params",
-             )
-
-         with self.assertRaises(ValueError) as ve:
-             chan = ChannelBuilder("sc://host/;session_id=abcd")
-             SparkConnectClient(chan)
-         self.assertIn(
-             "Parameter value 'session_id' must be a valid UUID format.", str(ve.exception)
-         )
-
-         chan = ChannelBuilder("sc://host/")
-         self.assertIsNone(chan.session_id)
-
-
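For reference, a condensed sketch of the sc:// connection-string handling that the ChannelBuilder tests above cover (editorial illustration; the import path is an assumption and may differ between Spark versions):

from pyspark.sql.connect.client import ChannelBuilder  # import path is an assumption

# Format: sc://host[:port]/;key=value;key=value  (values are URL-encoded)
chan = ChannelBuilder("sc://myhost/;use_ssl=true;token=secret;user_agent=my%20app;x-my-header=abcd")

print(chan.endpoint)    # "myhost:15002" -- the port defaults to 15002 when omitted
print(chan.secure)      # True -- a token or use_ssl=true switches the channel to TLS
print(chan.userAgent)   # contains "my app" plus spark/os/python metadata
print(chan.metadata())  # [("x-my-header", "abcd")] -- fixed params are stripped out
channel = chan.toChannel()  # a grpc.Channel ready to back a SparkConnectClient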
- if __name__ == "__main__":
-     from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401
-
-     try:
-         import xmlrunner
-
-         testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
-     except ImportError:
-         testRunner = None
-
-     unittest.main(testRunner=testRunner, verbosity=2)