snowpark-connect 0.20.2 (snowpark_connect-0.20.2-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of snowpark-connect might be problematic.

Files changed (879)
  1. snowflake/snowpark_connect/__init__.py +23 -0
  2. snowflake/snowpark_connect/analyze_plan/__init__.py +3 -0
  3. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +38 -0
  4. snowflake/snowpark_connect/column_name_handler.py +735 -0
  5. snowflake/snowpark_connect/config.py +576 -0
  6. snowflake/snowpark_connect/constants.py +47 -0
  7. snowflake/snowpark_connect/control_server.py +52 -0
  8. snowflake/snowpark_connect/dataframe_name_handler.py +54 -0
  9. snowflake/snowpark_connect/date_time_format_mapping.py +399 -0
  10. snowflake/snowpark_connect/empty_dataframe.py +18 -0
  11. snowflake/snowpark_connect/error/__init__.py +11 -0
  12. snowflake/snowpark_connect/error/error_mapping.py +6174 -0
  13. snowflake/snowpark_connect/error/error_utils.py +321 -0
  14. snowflake/snowpark_connect/error/exceptions.py +24 -0
  15. snowflake/snowpark_connect/execute_plan/__init__.py +3 -0
  16. snowflake/snowpark_connect/execute_plan/map_execution_command.py +204 -0
  17. snowflake/snowpark_connect/execute_plan/map_execution_root.py +173 -0
  18. snowflake/snowpark_connect/execute_plan/utils.py +183 -0
  19. snowflake/snowpark_connect/expression/__init__.py +3 -0
  20. snowflake/snowpark_connect/expression/literal.py +90 -0
  21. snowflake/snowpark_connect/expression/map_cast.py +343 -0
  22. snowflake/snowpark_connect/expression/map_expression.py +293 -0
  23. snowflake/snowpark_connect/expression/map_extension.py +104 -0
  24. snowflake/snowpark_connect/expression/map_sql_expression.py +633 -0
  25. snowflake/snowpark_connect/expression/map_udf.py +142 -0
  26. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +241 -0
  27. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +85 -0
  28. snowflake/snowpark_connect/expression/map_unresolved_function.py +9450 -0
  29. snowflake/snowpark_connect/expression/map_unresolved_star.py +218 -0
  30. snowflake/snowpark_connect/expression/map_update_fields.py +164 -0
  31. snowflake/snowpark_connect/expression/map_window_function.py +258 -0
  32. snowflake/snowpark_connect/expression/typer.py +125 -0
  33. snowflake/snowpark_connect/includes/__init__.py +0 -0
  34. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  35. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  36. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  37. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  39. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/hadoop-client-api-3.3.4.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/scala-compiler-2.12.18.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-kubernetes_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-mllib_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-streaming_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/python/__init__.py +21 -0
  99. snowflake/snowpark_connect/includes/python/pyspark/__init__.py +173 -0
  100. snowflake/snowpark_connect/includes/python/pyspark/_globals.py +71 -0
  101. snowflake/snowpark_connect/includes/python/pyspark/_typing.pyi +43 -0
  102. snowflake/snowpark_connect/includes/python/pyspark/accumulators.py +341 -0
  103. snowflake/snowpark_connect/includes/python/pyspark/broadcast.py +383 -0
  104. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/__init__.py +8 -0
  105. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle.py +948 -0
  106. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle_fast.py +844 -0
  107. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/compat.py +18 -0
  108. snowflake/snowpark_connect/includes/python/pyspark/conf.py +276 -0
  109. snowflake/snowpark_connect/includes/python/pyspark/context.py +2601 -0
  110. snowflake/snowpark_connect/includes/python/pyspark/daemon.py +218 -0
  111. snowflake/snowpark_connect/includes/python/pyspark/errors/__init__.py +70 -0
  112. snowflake/snowpark_connect/includes/python/pyspark/errors/error_classes.py +889 -0
  113. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/__init__.py +16 -0
  114. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/base.py +228 -0
  115. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/captured.py +307 -0
  116. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/connect.py +190 -0
  117. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/__init__.py +16 -0
  118. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/test_errors.py +60 -0
  119. snowflake/snowpark_connect/includes/python/pyspark/errors/utils.py +116 -0
  120. snowflake/snowpark_connect/includes/python/pyspark/files.py +165 -0
  121. snowflake/snowpark_connect/includes/python/pyspark/find_spark_home.py +95 -0
  122. snowflake/snowpark_connect/includes/python/pyspark/install.py +203 -0
  123. snowflake/snowpark_connect/includes/python/pyspark/instrumentation_utils.py +190 -0
  124. snowflake/snowpark_connect/includes/python/pyspark/java_gateway.py +248 -0
  125. snowflake/snowpark_connect/includes/python/pyspark/join.py +118 -0
  126. snowflake/snowpark_connect/includes/python/pyspark/ml/__init__.py +71 -0
  127. snowflake/snowpark_connect/includes/python/pyspark/ml/_typing.pyi +84 -0
  128. snowflake/snowpark_connect/includes/python/pyspark/ml/base.py +414 -0
  129. snowflake/snowpark_connect/includes/python/pyspark/ml/classification.py +4332 -0
  130. snowflake/snowpark_connect/includes/python/pyspark/ml/clustering.py +2188 -0
  131. snowflake/snowpark_connect/includes/python/pyspark/ml/common.py +146 -0
  132. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/__init__.py +44 -0
  133. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/base.py +346 -0
  134. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/classification.py +382 -0
  135. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/evaluation.py +291 -0
  136. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/feature.py +258 -0
  137. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/functions.py +77 -0
  138. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/io_utils.py +335 -0
  139. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/pipeline.py +262 -0
  140. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/summarizer.py +120 -0
  141. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/tuning.py +579 -0
  142. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/util.py +173 -0
  143. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/__init__.py +16 -0
  144. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/deepspeed_distributor.py +165 -0
  145. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py +306 -0
  146. snowflake/snowpark_connect/includes/python/pyspark/ml/dl_util.py +150 -0
  147. snowflake/snowpark_connect/includes/python/pyspark/ml/evaluation.py +1166 -0
  148. snowflake/snowpark_connect/includes/python/pyspark/ml/feature.py +7474 -0
  149. snowflake/snowpark_connect/includes/python/pyspark/ml/fpm.py +543 -0
  150. snowflake/snowpark_connect/includes/python/pyspark/ml/functions.py +842 -0
  151. snowflake/snowpark_connect/includes/python/pyspark/ml/image.py +271 -0
  152. snowflake/snowpark_connect/includes/python/pyspark/ml/linalg/__init__.py +1382 -0
  153. snowflake/snowpark_connect/includes/python/pyspark/ml/model_cache.py +55 -0
  154. snowflake/snowpark_connect/includes/python/pyspark/ml/param/__init__.py +602 -0
  155. snowflake/snowpark_connect/includes/python/pyspark/ml/param/_shared_params_code_gen.py +368 -0
  156. snowflake/snowpark_connect/includes/python/pyspark/ml/param/shared.py +878 -0
  157. snowflake/snowpark_connect/includes/python/pyspark/ml/pipeline.py +451 -0
  158. snowflake/snowpark_connect/includes/python/pyspark/ml/recommendation.py +748 -0
  159. snowflake/snowpark_connect/includes/python/pyspark/ml/regression.py +3335 -0
  160. snowflake/snowpark_connect/includes/python/pyspark/ml/stat.py +523 -0
  161. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/__init__.py +16 -0
  162. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_classification.py +53 -0
  163. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_evaluation.py +50 -0
  164. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_feature.py +43 -0
  165. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_function.py +114 -0
  166. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_pipeline.py +47 -0
  167. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_summarizer.py +43 -0
  168. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_tuning.py +46 -0
  169. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py +238 -0
  170. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py +194 -0
  171. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +156 -0
  172. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py +184 -0
  173. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py +78 -0
  174. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +292 -0
  175. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +50 -0
  176. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +152 -0
  177. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_algorithms.py +456 -0
  178. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_base.py +96 -0
  179. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_dl_util.py +186 -0
  180. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_evaluation.py +77 -0
  181. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_feature.py +401 -0
  182. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_functions.py +528 -0
  183. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_image.py +82 -0
  184. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_linalg.py +409 -0
  185. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_model_cache.py +55 -0
  186. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_param.py +441 -0
  187. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_persistence.py +546 -0
  188. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_pipeline.py +71 -0
  189. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_stat.py +52 -0
  190. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_training_summary.py +494 -0
  191. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_util.py +85 -0
  192. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_wrapper.py +138 -0
  193. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/__init__.py +16 -0
  194. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_basic.py +151 -0
  195. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_nested.py +97 -0
  196. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py +143 -0
  197. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tuning.py +551 -0
  198. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_basic.py +137 -0
  199. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_nested.py +96 -0
  200. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_pipeline.py +142 -0
  201. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/__init__.py +16 -0
  202. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/data.py +100 -0
  203. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/distributor.py +1133 -0
  204. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/log_communication.py +198 -0
  205. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/__init__.py +16 -0
  206. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_data_loader.py +137 -0
  207. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_distributor.py +561 -0
  208. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_log_communication.py +172 -0
  209. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/torch_run_process_wrapper.py +83 -0
  210. snowflake/snowpark_connect/includes/python/pyspark/ml/tree.py +434 -0
  211. snowflake/snowpark_connect/includes/python/pyspark/ml/tuning.py +1741 -0
  212. snowflake/snowpark_connect/includes/python/pyspark/ml/util.py +749 -0
  213. snowflake/snowpark_connect/includes/python/pyspark/ml/wrapper.py +465 -0
  214. snowflake/snowpark_connect/includes/python/pyspark/mllib/__init__.py +44 -0
  215. snowflake/snowpark_connect/includes/python/pyspark/mllib/_typing.pyi +33 -0
  216. snowflake/snowpark_connect/includes/python/pyspark/mllib/classification.py +989 -0
  217. snowflake/snowpark_connect/includes/python/pyspark/mllib/clustering.py +1318 -0
  218. snowflake/snowpark_connect/includes/python/pyspark/mllib/common.py +174 -0
  219. snowflake/snowpark_connect/includes/python/pyspark/mllib/evaluation.py +691 -0
  220. snowflake/snowpark_connect/includes/python/pyspark/mllib/feature.py +1085 -0
  221. snowflake/snowpark_connect/includes/python/pyspark/mllib/fpm.py +233 -0
  222. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/__init__.py +1653 -0
  223. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/distributed.py +1662 -0
  224. snowflake/snowpark_connect/includes/python/pyspark/mllib/random.py +698 -0
  225. snowflake/snowpark_connect/includes/python/pyspark/mllib/recommendation.py +389 -0
  226. snowflake/snowpark_connect/includes/python/pyspark/mllib/regression.py +1067 -0
  227. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/KernelDensity.py +59 -0
  228. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/__init__.py +34 -0
  229. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/_statistics.py +409 -0
  230. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/distribution.py +39 -0
  231. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/test.py +86 -0
  232. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/__init__.py +16 -0
  233. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_algorithms.py +353 -0
  234. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_feature.py +192 -0
  235. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_linalg.py +680 -0
  236. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_stat.py +206 -0
  237. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_streaming_algorithms.py +471 -0
  238. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_util.py +108 -0
  239. snowflake/snowpark_connect/includes/python/pyspark/mllib/tree.py +888 -0
  240. snowflake/snowpark_connect/includes/python/pyspark/mllib/util.py +659 -0
  241. snowflake/snowpark_connect/includes/python/pyspark/pandas/__init__.py +165 -0
  242. snowflake/snowpark_connect/includes/python/pyspark/pandas/_typing.py +52 -0
  243. snowflake/snowpark_connect/includes/python/pyspark/pandas/accessors.py +989 -0
  244. snowflake/snowpark_connect/includes/python/pyspark/pandas/base.py +1804 -0
  245. snowflake/snowpark_connect/includes/python/pyspark/pandas/categorical.py +822 -0
  246. snowflake/snowpark_connect/includes/python/pyspark/pandas/config.py +539 -0
  247. snowflake/snowpark_connect/includes/python/pyspark/pandas/correlation.py +262 -0
  248. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/__init__.py +16 -0
  249. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/base.py +519 -0
  250. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/binary_ops.py +98 -0
  251. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/boolean_ops.py +426 -0
  252. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/categorical_ops.py +141 -0
  253. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/complex_ops.py +145 -0
  254. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/date_ops.py +127 -0
  255. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/datetime_ops.py +171 -0
  256. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/null_ops.py +83 -0
  257. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/num_ops.py +588 -0
  258. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/string_ops.py +154 -0
  259. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/timedelta_ops.py +101 -0
  260. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/udt_ops.py +29 -0
  261. snowflake/snowpark_connect/includes/python/pyspark/pandas/datetimes.py +891 -0
  262. snowflake/snowpark_connect/includes/python/pyspark/pandas/exceptions.py +150 -0
  263. snowflake/snowpark_connect/includes/python/pyspark/pandas/extensions.py +388 -0
  264. snowflake/snowpark_connect/includes/python/pyspark/pandas/frame.py +13738 -0
  265. snowflake/snowpark_connect/includes/python/pyspark/pandas/generic.py +3560 -0
  266. snowflake/snowpark_connect/includes/python/pyspark/pandas/groupby.py +4448 -0
  267. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/__init__.py +21 -0
  268. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/base.py +2783 -0
  269. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/category.py +773 -0
  270. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/datetimes.py +843 -0
  271. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/multi.py +1323 -0
  272. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/numeric.py +210 -0
  273. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/timedelta.py +197 -0
  274. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexing.py +1862 -0
  275. snowflake/snowpark_connect/includes/python/pyspark/pandas/internal.py +1680 -0
  276. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/__init__.py +48 -0
  277. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/common.py +76 -0
  278. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/frame.py +63 -0
  279. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/general_functions.py +43 -0
  280. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/groupby.py +93 -0
  281. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/indexes.py +184 -0
  282. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/resample.py +101 -0
  283. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/scalars.py +29 -0
  284. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/series.py +69 -0
  285. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/window.py +168 -0
  286. snowflake/snowpark_connect/includes/python/pyspark/pandas/mlflow.py +238 -0
  287. snowflake/snowpark_connect/includes/python/pyspark/pandas/namespace.py +3807 -0
  288. snowflake/snowpark_connect/includes/python/pyspark/pandas/numpy_compat.py +260 -0
  289. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/__init__.py +17 -0
  290. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/core.py +1213 -0
  291. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/matplotlib.py +928 -0
  292. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/plotly.py +261 -0
  293. snowflake/snowpark_connect/includes/python/pyspark/pandas/resample.py +816 -0
  294. snowflake/snowpark_connect/includes/python/pyspark/pandas/series.py +7440 -0
  295. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_formatter.py +308 -0
  296. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_processor.py +394 -0
  297. snowflake/snowpark_connect/includes/python/pyspark/pandas/strings.py +2371 -0
  298. snowflake/snowpark_connect/includes/python/pyspark/pandas/supported_api_gen.py +378 -0
  299. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/__init__.py +16 -0
  300. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/__init__.py +16 -0
  301. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_any_all.py +177 -0
  302. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_apply_func.py +575 -0
  303. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_binary_ops.py +235 -0
  304. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_combine.py +653 -0
  305. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_compute.py +463 -0
  306. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_corrwith.py +86 -0
  307. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cov.py +151 -0
  308. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cumulative.py +139 -0
  309. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_describe.py +458 -0
  310. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_eval.py +86 -0
  311. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_melt.py +202 -0
  312. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_missing_data.py +520 -0
  313. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_pivot.py +361 -0
  314. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/__init__.py +16 -0
  315. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/__init__.py +16 -0
  316. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_any_all.py +40 -0
  317. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_apply_func.py +42 -0
  318. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_binary_ops.py +40 -0
  319. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_combine.py +37 -0
  320. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py +60 -0
  321. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_corrwith.py +40 -0
  322. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cov.py +40 -0
  323. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cumulative.py +90 -0
  324. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_describe.py +40 -0
  325. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_eval.py +40 -0
  326. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_melt.py +40 -0
  327. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_missing_data.py +42 -0
  328. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +37 -0
  329. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/__init__.py +16 -0
  330. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_base.py +36 -0
  331. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py +42 -0
  332. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py +47 -0
  333. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py +55 -0
  334. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_complex_ops.py +40 -0
  335. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py +47 -0
  336. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py +47 -0
  337. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py +42 -0
  338. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_arithmetic.py +43 -0
  339. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py +47 -0
  340. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_reverse.py +43 -0
  341. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py +47 -0
  342. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py +47 -0
  343. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py +40 -0
  344. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/testing_utils.py +226 -0
  345. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/__init__.py +16 -0
  346. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_align.py +39 -0
  347. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_basic_slow.py +55 -0
  348. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_cov_corrwith.py +39 -0
  349. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_frame.py +39 -0
  350. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_series.py +39 -0
  351. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_index.py +39 -0
  352. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_series.py +39 -0
  353. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_frame.py +43 -0
  354. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_series.py +43 -0
  355. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/__init__.py +16 -0
  356. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_attrs.py +40 -0
  357. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py +39 -0
  358. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_conversion.py +42 -0
  359. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reindexing.py +42 -0
  360. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +37 -0
  361. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_spark.py +40 -0
  362. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_take.py +42 -0
  363. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_time_series.py +48 -0
  364. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_truncate.py +40 -0
  365. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/__init__.py +16 -0
  366. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_aggregate.py +40 -0
  367. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_apply_func.py +41 -0
  368. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_cumulative.py +67 -0
  369. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_describe.py +40 -0
  370. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_groupby.py +55 -0
  371. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_head_tail.py +40 -0
  372. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_index.py +38 -0
  373. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py +55 -0
  374. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_split_apply.py +39 -0
  375. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +38 -0
  376. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/__init__.py +16 -0
  377. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_align.py +40 -0
  378. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py +50 -0
  379. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py +73 -0
  380. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_datetime.py +39 -0
  381. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing.py +40 -0
  382. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reindex.py +40 -0
  383. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_rename.py +40 -0
  384. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reset_index.py +48 -0
  385. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_timedelta.py +39 -0
  386. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/__init__.py +16 -0
  387. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/test_parity_io.py +40 -0
  388. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/__init__.py +16 -0
  389. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot.py +45 -0
  390. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_matplotlib.py +45 -0
  391. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_plotly.py +49 -0
  392. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot.py +37 -0
  393. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_matplotlib.py +53 -0
  394. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_plotly.py +45 -0
  395. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/__init__.py +16 -0
  396. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_all_any.py +38 -0
  397. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_arg_ops.py +37 -0
  398. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_of.py +37 -0
  399. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_type.py +38 -0
  400. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +37 -0
  401. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_conversion.py +40 -0
  402. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_cumulative.py +40 -0
  403. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_index.py +38 -0
  404. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_missing_data.py +40 -0
  405. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_series.py +37 -0
  406. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_sort.py +38 -0
  407. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_stat.py +38 -0
  408. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_categorical.py +66 -0
  409. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_config.py +37 -0
  410. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_csv.py +37 -0
  411. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_conversion.py +42 -0
  412. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_spark_io.py +39 -0
  413. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_default_index.py +49 -0
  414. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ewm.py +37 -0
  415. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_expanding.py +39 -0
  416. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_extension.py +49 -0
  417. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_frame_spark.py +53 -0
  418. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py +43 -0
  419. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexing.py +49 -0
  420. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexops_spark.py +39 -0
  421. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_internal.py +41 -0
  422. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_namespace.py +39 -0
  423. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_numpy_compat.py +60 -0
  424. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames.py +48 -0
  425. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby.py +39 -0
  426. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_expanding.py +44 -0
  427. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_rolling.py +84 -0
  428. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_repr.py +37 -0
  429. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_resample.py +45 -0
  430. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_reshape.py +39 -0
  431. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_rolling.py +39 -0
  432. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_scalars.py +37 -0
  433. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_conversion.py +39 -0
  434. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_datetime.py +39 -0
  435. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_string.py +39 -0
  436. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_spark_functions.py +39 -0
  437. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_sql.py +43 -0
  438. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_stats.py +37 -0
  439. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_typedef.py +36 -0
  440. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_utils.py +37 -0
  441. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_window.py +39 -0
  442. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/__init__.py +16 -0
  443. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_base.py +107 -0
  444. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +224 -0
  445. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +825 -0
  446. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +562 -0
  447. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +368 -0
  448. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +257 -0
  449. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +260 -0
  450. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +178 -0
  451. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_arithmetic.py +184 -0
  452. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +497 -0
  453. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_reverse.py +140 -0
  454. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +354 -0
  455. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +219 -0
  456. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +192 -0
  457. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +228 -0
  458. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/__init__.py +16 -0
  459. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_align.py +118 -0
  460. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_basic_slow.py +198 -0
  461. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_cov_corrwith.py +181 -0
  462. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_frame.py +103 -0
  463. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_series.py +141 -0
  464. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_index.py +109 -0
  465. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_series.py +136 -0
  466. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_frame.py +125 -0
  467. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_series.py +217 -0
  468. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/__init__.py +16 -0
  469. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_attrs.py +384 -0
  470. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_constructor.py +598 -0
  471. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_conversion.py +73 -0
  472. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reindexing.py +869 -0
  473. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reshaping.py +487 -0
  474. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_spark.py +309 -0
  475. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_take.py +156 -0
  476. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_time_series.py +149 -0
  477. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_truncate.py +163 -0
  478. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/__init__.py +16 -0
  479. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_aggregate.py +311 -0
  480. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_apply_func.py +524 -0
  481. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_cumulative.py +419 -0
  482. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_describe.py +144 -0
  483. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_groupby.py +979 -0
  484. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_head_tail.py +234 -0
  485. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_index.py +206 -0
  486. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_missing_data.py +421 -0
  487. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_split_apply.py +187 -0
  488. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_stat.py +397 -0
  489. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/__init__.py +16 -0
  490. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_align.py +100 -0
  491. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_base.py +2743 -0
  492. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_category.py +484 -0
  493. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_datetime.py +276 -0
  494. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_indexing.py +432 -0
  495. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reindex.py +310 -0
  496. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_rename.py +257 -0
  497. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reset_index.py +160 -0
  498. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_timedelta.py +128 -0
  499. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/__init__.py +16 -0
  500. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/test_io.py +137 -0
  501. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/__init__.py +16 -0
  502. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot.py +170 -0
  503. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +547 -0
  504. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +285 -0
  505. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot.py +106 -0
  506. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +409 -0
  507. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +247 -0
  508. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/__init__.py +16 -0
  509. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_all_any.py +105 -0
  510. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_arg_ops.py +197 -0
  511. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_of.py +137 -0
  512. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_type.py +227 -0
  513. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_compute.py +634 -0
  514. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_conversion.py +88 -0
  515. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_cumulative.py +139 -0
  516. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_index.py +475 -0
  517. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_missing_data.py +265 -0
  518. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_series.py +818 -0
  519. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_sort.py +162 -0
  520. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_stat.py +780 -0
  521. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_categorical.py +741 -0
  522. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_config.py +160 -0
  523. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_csv.py +453 -0
  524. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_conversion.py +281 -0
  525. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_spark_io.py +487 -0
  526. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_default_index.py +109 -0
  527. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ewm.py +434 -0
  528. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_expanding.py +253 -0
  529. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_extension.py +152 -0
  530. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_frame_spark.py +162 -0
  531. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_generic_functions.py +234 -0
  532. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexing.py +1339 -0
  533. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexops_spark.py +82 -0
  534. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_internal.py +124 -0
  535. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_namespace.py +638 -0
  536. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_numpy_compat.py +200 -0
  537. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +1355 -0
  538. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +655 -0
  539. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +113 -0
  540. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +118 -0
  541. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_repr.py +192 -0
  542. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_resample.py +346 -0
  543. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_reshape.py +495 -0
  544. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_rolling.py +263 -0
  545. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_scalars.py +59 -0
  546. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_conversion.py +85 -0
  547. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_datetime.py +364 -0
  548. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_string.py +362 -0
  549. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_spark_functions.py +46 -0
  550. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_sql.py +123 -0
  551. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_stats.py +581 -0
  552. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_typedef.py +447 -0
  553. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_utils.py +301 -0
  554. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_window.py +465 -0
  555. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/__init__.py +18 -0
  556. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/typehints.py +874 -0
  557. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/__init__.py +143 -0
  558. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/usage_logger.py +132 -0
  559. snowflake/snowpark_connect/includes/python/pyspark/pandas/utils.py +1063 -0
  560. snowflake/snowpark_connect/includes/python/pyspark/pandas/window.py +2702 -0
  561. snowflake/snowpark_connect/includes/python/pyspark/profiler.py +489 -0
  562. snowflake/snowpark_connect/includes/python/pyspark/py.typed +1 -0
  563. snowflake/snowpark_connect/includes/python/pyspark/python/pyspark/shell.py +123 -0
  564. snowflake/snowpark_connect/includes/python/pyspark/rdd.py +5518 -0
  565. snowflake/snowpark_connect/includes/python/pyspark/rddsampler.py +115 -0
  566. snowflake/snowpark_connect/includes/python/pyspark/resource/__init__.py +38 -0
  567. snowflake/snowpark_connect/includes/python/pyspark/resource/information.py +69 -0
  568. snowflake/snowpark_connect/includes/python/pyspark/resource/profile.py +317 -0
  569. snowflake/snowpark_connect/includes/python/pyspark/resource/requests.py +539 -0
  570. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/__init__.py +16 -0
  571. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/test_resources.py +83 -0
  572. snowflake/snowpark_connect/includes/python/pyspark/resultiterable.py +45 -0
  573. snowflake/snowpark_connect/includes/python/pyspark/serializers.py +681 -0
  574. snowflake/snowpark_connect/includes/python/pyspark/shell.py +123 -0
  575. snowflake/snowpark_connect/includes/python/pyspark/shuffle.py +854 -0
  576. snowflake/snowpark_connect/includes/python/pyspark/sql/__init__.py +75 -0
  577. snowflake/snowpark_connect/includes/python/pyspark/sql/_typing.pyi +80 -0
  578. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/__init__.py +18 -0
  579. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/functions.py +188 -0
  580. snowflake/snowpark_connect/includes/python/pyspark/sql/catalog.py +1270 -0
  581. snowflake/snowpark_connect/includes/python/pyspark/sql/column.py +1431 -0
  582. snowflake/snowpark_connect/includes/python/pyspark/sql/conf.py +99 -0
  583. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/__init__.py +18 -0
  584. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/_typing.py +90 -0
  585. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/__init__.py +18 -0
  586. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/functions.py +107 -0
  587. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/catalog.py +356 -0
  588. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/__init__.py +22 -0
  589. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/artifact.py +412 -0
  590. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/core.py +1689 -0
  591. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/reattach.py +340 -0
  592. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/column.py +514 -0
  593. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conf.py +128 -0
  594. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conversion.py +490 -0
  595. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/dataframe.py +2172 -0
  596. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/expressions.py +1056 -0
  597. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/functions.py +3937 -0
  598. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/group.py +418 -0
  599. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/plan.py +2289 -0
  600. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/__init__.py +25 -0
  601. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.py +203 -0
  602. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.pyi +2718 -0
  603. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2_grpc.py +423 -0
  604. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.py +109 -0
  605. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.pyi +1130 -0
  606. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.py +141 -0
  607. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.pyi +1766 -0
  608. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.py +47 -0
  609. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.pyi +123 -0
  610. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.py +53 -0
  611. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.pyi +112 -0
  612. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.py +107 -0
  613. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.pyi +1507 -0
  614. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.py +195 -0
  615. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.pyi +3613 -0
  616. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.py +95 -0
  617. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.pyi +980 -0
  618. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/__init__.py +18 -0
  619. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/functions.py +166 -0
  620. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/readwriter.py +861 -0
  621. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/session.py +952 -0
  622. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/__init__.py +22 -0
  623. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/query.py +295 -0
  624. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/readwriter.py +618 -0
  625. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/__init__.py +18 -0
  626. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +87 -0
  627. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +100 -0
  628. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/types.py +301 -0
  629. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udf.py +296 -0
  630. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udtf.py +200 -0
  631. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/utils.py +58 -0
  632. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/window.py +266 -0
  633. snowflake/snowpark_connect/includes/python/pyspark/sql/context.py +818 -0
  634. snowflake/snowpark_connect/includes/python/pyspark/sql/dataframe.py +5973 -0
  635. snowflake/snowpark_connect/includes/python/pyspark/sql/functions.py +15889 -0
  636. snowflake/snowpark_connect/includes/python/pyspark/sql/group.py +547 -0
  637. snowflake/snowpark_connect/includes/python/pyspark/sql/observation.py +152 -0
  638. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/__init__.py +21 -0
  639. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/__init__.pyi +344 -0
  640. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/__init__.pyi +17 -0
  641. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/frame.pyi +20 -0
  642. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/series.pyi +20 -0
  643. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/conversion.py +671 -0
  644. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.py +480 -0
  645. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.pyi +132 -0
  646. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/group_ops.py +523 -0
  647. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/map_ops.py +216 -0
  648. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/serializers.py +1019 -0
  649. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/typehints.py +172 -0
  650. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/types.py +972 -0
  651. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/utils.py +86 -0
  652. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/__init__.py +18 -0
  653. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/functions.py +334 -0
  654. snowflake/snowpark_connect/includes/python/pyspark/sql/readwriter.py +2159 -0
  655. snowflake/snowpark_connect/includes/python/pyspark/sql/session.py +2088 -0
  656. snowflake/snowpark_connect/includes/python/pyspark/sql/sql_formatter.py +84 -0
  657. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/__init__.py +21 -0
  658. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/listener.py +1050 -0
  659. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/query.py +746 -0
  660. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/readwriter.py +1652 -0
  661. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/state.py +288 -0
  662. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/__init__.py +16 -0
  663. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/__init__.py +16 -0
  664. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/__init__.py +16 -0
  665. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_artifact.py +420 -0
  666. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_client.py +358 -0
  667. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/__init__.py +16 -0
  668. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach.py +36 -0
  669. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +44 -0
  670. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +116 -0
  671. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +35 -0
  672. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py +3612 -0
  673. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_column.py +1042 -0
  674. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_function.py +2381 -0
  675. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_plan.py +1060 -0
  676. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow.py +163 -0
  677. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_map.py +38 -0
  678. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +48 -0
  679. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_catalog.py +36 -0
  680. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_column.py +55 -0
  681. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_conf.py +36 -0
  682. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_dataframe.py +96 -0
  683. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_datasources.py +44 -0
  684. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_errors.py +36 -0
  685. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_functions.py +59 -0
  686. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_group.py +36 -0
  687. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py +59 -0
  688. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py +74 -0
  689. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py +62 -0
  690. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_map.py +58 -0
  691. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +70 -0
  692. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py +50 -0
  693. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +68 -0
  694. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py +40 -0
  695. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_readwriter.py +46 -0
  696. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_serde.py +44 -0
  697. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_types.py +100 -0
  698. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udf.py +100 -0
  699. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udtf.py +163 -0
  700. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_session.py +181 -0
  701. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_utils.py +42 -0
  702. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/__init__.py +16 -0
  703. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +623 -0
  704. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +869 -0
  705. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +342 -0
  706. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_map.py +436 -0
  707. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf.py +363 -0
  708. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +592 -0
  709. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +1503 -0
  710. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +392 -0
  711. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py +375 -0
  712. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +411 -0
  713. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/__init__.py +16 -0
  714. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming.py +401 -0
  715. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +295 -0
  716. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +106 -0
  717. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_listener.py +558 -0
  718. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow.py +1346 -0
  719. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_map.py +182 -0
  720. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_python_udf.py +202 -0
  721. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_catalog.py +503 -0
  722. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_column.py +225 -0
  723. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_conf.py +83 -0
  724. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_context.py +201 -0
  725. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_dataframe.py +1931 -0
  726. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_datasources.py +256 -0
  727. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_errors.py +69 -0
  728. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_functions.py +1349 -0
  729. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_group.py +53 -0
  730. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_pandas_sqlmetrics.py +68 -0
  731. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_readwriter.py +283 -0
  732. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_serde.py +155 -0
  733. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_session.py +412 -0
  734. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_types.py +1581 -0
  735. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf.py +961 -0
  736. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf_profiler.py +165 -0
  737. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udtf.py +1456 -0
  738. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_utils.py +1686 -0
  739. snowflake/snowpark_connect/includes/python/pyspark/sql/types.py +2558 -0
  740. snowflake/snowpark_connect/includes/python/pyspark/sql/udf.py +714 -0
  741. snowflake/snowpark_connect/includes/python/pyspark/sql/udtf.py +325 -0
  742. snowflake/snowpark_connect/includes/python/pyspark/sql/utils.py +339 -0
  743. snowflake/snowpark_connect/includes/python/pyspark/sql/window.py +492 -0
  744. snowflake/snowpark_connect/includes/python/pyspark/statcounter.py +165 -0
  745. snowflake/snowpark_connect/includes/python/pyspark/status.py +112 -0
  746. snowflake/snowpark_connect/includes/python/pyspark/storagelevel.py +97 -0
  747. snowflake/snowpark_connect/includes/python/pyspark/streaming/__init__.py +22 -0
  748. snowflake/snowpark_connect/includes/python/pyspark/streaming/context.py +471 -0
  749. snowflake/snowpark_connect/includes/python/pyspark/streaming/dstream.py +933 -0
  750. snowflake/snowpark_connect/includes/python/pyspark/streaming/kinesis.py +205 -0
  751. snowflake/snowpark_connect/includes/python/pyspark/streaming/listener.py +83 -0
  752. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/__init__.py +16 -0
  753. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_context.py +184 -0
  754. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_dstream.py +706 -0
  755. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_kinesis.py +118 -0
  756. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_listener.py +160 -0
  757. snowflake/snowpark_connect/includes/python/pyspark/streaming/util.py +168 -0
  758. snowflake/snowpark_connect/includes/python/pyspark/taskcontext.py +502 -0
  759. snowflake/snowpark_connect/includes/python/pyspark/testing/__init__.py +21 -0
  760. snowflake/snowpark_connect/includes/python/pyspark/testing/connectutils.py +199 -0
  761. snowflake/snowpark_connect/includes/python/pyspark/testing/mllibutils.py +30 -0
  762. snowflake/snowpark_connect/includes/python/pyspark/testing/mlutils.py +275 -0
  763. snowflake/snowpark_connect/includes/python/pyspark/testing/objects.py +121 -0
  764. snowflake/snowpark_connect/includes/python/pyspark/testing/pandasutils.py +714 -0
  765. snowflake/snowpark_connect/includes/python/pyspark/testing/sqlutils.py +168 -0
  766. snowflake/snowpark_connect/includes/python/pyspark/testing/streamingutils.py +178 -0
  767. snowflake/snowpark_connect/includes/python/pyspark/testing/utils.py +636 -0
  768. snowflake/snowpark_connect/includes/python/pyspark/tests/__init__.py +16 -0
  769. snowflake/snowpark_connect/includes/python/pyspark/tests/test_appsubmit.py +306 -0
  770. snowflake/snowpark_connect/includes/python/pyspark/tests/test_broadcast.py +196 -0
  771. snowflake/snowpark_connect/includes/python/pyspark/tests/test_conf.py +44 -0
  772. snowflake/snowpark_connect/includes/python/pyspark/tests/test_context.py +346 -0
  773. snowflake/snowpark_connect/includes/python/pyspark/tests/test_daemon.py +89 -0
  774. snowflake/snowpark_connect/includes/python/pyspark/tests/test_install_spark.py +124 -0
  775. snowflake/snowpark_connect/includes/python/pyspark/tests/test_join.py +69 -0
  776. snowflake/snowpark_connect/includes/python/pyspark/tests/test_memory_profiler.py +167 -0
  777. snowflake/snowpark_connect/includes/python/pyspark/tests/test_pin_thread.py +194 -0
  778. snowflake/snowpark_connect/includes/python/pyspark/tests/test_profiler.py +168 -0
  779. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rdd.py +939 -0
  780. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddbarrier.py +52 -0
  781. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddsampler.py +66 -0
  782. snowflake/snowpark_connect/includes/python/pyspark/tests/test_readwrite.py +368 -0
  783. snowflake/snowpark_connect/includes/python/pyspark/tests/test_serializers.py +257 -0
  784. snowflake/snowpark_connect/includes/python/pyspark/tests/test_shuffle.py +267 -0
  785. snowflake/snowpark_connect/includes/python/pyspark/tests/test_stage_sched.py +153 -0
  786. snowflake/snowpark_connect/includes/python/pyspark/tests/test_statcounter.py +130 -0
  787. snowflake/snowpark_connect/includes/python/pyspark/tests/test_taskcontext.py +350 -0
  788. snowflake/snowpark_connect/includes/python/pyspark/tests/test_util.py +97 -0
  789. snowflake/snowpark_connect/includes/python/pyspark/tests/test_worker.py +271 -0
  790. snowflake/snowpark_connect/includes/python/pyspark/traceback_utils.py +81 -0
  791. snowflake/snowpark_connect/includes/python/pyspark/util.py +416 -0
  792. snowflake/snowpark_connect/includes/python/pyspark/version.py +19 -0
  793. snowflake/snowpark_connect/includes/python/pyspark/worker.py +1307 -0
  794. snowflake/snowpark_connect/includes/python/pyspark/worker_util.py +46 -0
  795. snowflake/snowpark_connect/proto/__init__.py +10 -0
  796. snowflake/snowpark_connect/proto/control_pb2.py +35 -0
  797. snowflake/snowpark_connect/proto/control_pb2.pyi +38 -0
  798. snowflake/snowpark_connect/proto/control_pb2_grpc.py +183 -0
  799. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +35 -0
  800. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +53 -0
  801. snowflake/snowpark_connect/proto/snowflake_rdd_pb2.pyi +39 -0
  802. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +47 -0
  803. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +111 -0
  804. snowflake/snowpark_connect/relation/__init__.py +3 -0
  805. snowflake/snowpark_connect/relation/catalogs/__init__.py +12 -0
  806. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +287 -0
  807. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +467 -0
  808. snowflake/snowpark_connect/relation/catalogs/utils.py +51 -0
  809. snowflake/snowpark_connect/relation/io_utils.py +76 -0
  810. snowflake/snowpark_connect/relation/map_aggregate.py +322 -0
  811. snowflake/snowpark_connect/relation/map_catalog.py +151 -0
  812. snowflake/snowpark_connect/relation/map_column_ops.py +1068 -0
  813. snowflake/snowpark_connect/relation/map_crosstab.py +48 -0
  814. snowflake/snowpark_connect/relation/map_extension.py +412 -0
  815. snowflake/snowpark_connect/relation/map_join.py +341 -0
  816. snowflake/snowpark_connect/relation/map_local_relation.py +326 -0
  817. snowflake/snowpark_connect/relation/map_map_partitions.py +146 -0
  818. snowflake/snowpark_connect/relation/map_relation.py +253 -0
  819. snowflake/snowpark_connect/relation/map_row_ops.py +716 -0
  820. snowflake/snowpark_connect/relation/map_sample_by.py +35 -0
  821. snowflake/snowpark_connect/relation/map_show_string.py +50 -0
  822. snowflake/snowpark_connect/relation/map_sql.py +1874 -0
  823. snowflake/snowpark_connect/relation/map_stats.py +324 -0
  824. snowflake/snowpark_connect/relation/map_subquery_alias.py +32 -0
  825. snowflake/snowpark_connect/relation/map_udtf.py +288 -0
  826. snowflake/snowpark_connect/relation/read/__init__.py +7 -0
  827. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +668 -0
  828. snowflake/snowpark_connect/relation/read/map_read.py +367 -0
  829. snowflake/snowpark_connect/relation/read/map_read_csv.py +142 -0
  830. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +108 -0
  831. snowflake/snowpark_connect/relation/read/map_read_json.py +344 -0
  832. snowflake/snowpark_connect/relation/read/map_read_parquet.py +194 -0
  833. snowflake/snowpark_connect/relation/read/map_read_socket.py +59 -0
  834. snowflake/snowpark_connect/relation/read/map_read_table.py +109 -0
  835. snowflake/snowpark_connect/relation/read/map_read_text.py +106 -0
  836. snowflake/snowpark_connect/relation/read/reader_config.py +399 -0
  837. snowflake/snowpark_connect/relation/read/utils.py +155 -0
  838. snowflake/snowpark_connect/relation/stage_locator.py +161 -0
  839. snowflake/snowpark_connect/relation/utils.py +219 -0
  840. snowflake/snowpark_connect/relation/write/__init__.py +3 -0
  841. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +339 -0
  842. snowflake/snowpark_connect/relation/write/map_write.py +436 -0
  843. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +48 -0
  844. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  845. snowflake/snowpark_connect/resources_initializer.py +75 -0
  846. snowflake/snowpark_connect/server.py +1136 -0
  847. snowflake/snowpark_connect/start_server.py +32 -0
  848. snowflake/snowpark_connect/tcm.py +8 -0
  849. snowflake/snowpark_connect/type_mapping.py +1003 -0
  850. snowflake/snowpark_connect/typed_column.py +94 -0
  851. snowflake/snowpark_connect/utils/__init__.py +3 -0
  852. snowflake/snowpark_connect/utils/artifacts.py +48 -0
  853. snowflake/snowpark_connect/utils/attribute_handling.py +72 -0
  854. snowflake/snowpark_connect/utils/cache.py +84 -0
  855. snowflake/snowpark_connect/utils/concurrent.py +124 -0
  856. snowflake/snowpark_connect/utils/context.py +390 -0
  857. snowflake/snowpark_connect/utils/describe_query_cache.py +231 -0
  858. snowflake/snowpark_connect/utils/interrupt.py +85 -0
  859. snowflake/snowpark_connect/utils/io_utils.py +35 -0
  860. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +117 -0
  861. snowflake/snowpark_connect/utils/profiling.py +47 -0
  862. snowflake/snowpark_connect/utils/session.py +180 -0
  863. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +38 -0
  864. snowflake/snowpark_connect/utils/telemetry.py +513 -0
  865. snowflake/snowpark_connect/utils/udf_cache.py +392 -0
  866. snowflake/snowpark_connect/utils/udf_helper.py +328 -0
  867. snowflake/snowpark_connect/utils/udf_utils.py +310 -0
  868. snowflake/snowpark_connect/utils/udtf_helper.py +420 -0
  869. snowflake/snowpark_connect/utils/udtf_utils.py +799 -0
  870. snowflake/snowpark_connect/utils/xxhash64.py +247 -0
  871. snowflake/snowpark_connect/version.py +6 -0
  872. snowpark_connect-0.20.2.data/scripts/snowpark-connect +71 -0
  873. snowpark_connect-0.20.2.data/scripts/snowpark-session +11 -0
  874. snowpark_connect-0.20.2.data/scripts/snowpark-submit +354 -0
  875. snowpark_connect-0.20.2.dist-info/METADATA +37 -0
  876. snowpark_connect-0.20.2.dist-info/RECORD +879 -0
  877. snowpark_connect-0.20.2.dist-info/WHEEL +5 -0
  878. snowpark_connect-0.20.2.dist-info/licenses/LICENSE.txt +202 -0
  879. snowpark_connect-0.20.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,3612 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one or more
3
+ # contributor license agreements. See the NOTICE file distributed with
4
+ # this work for additional information regarding copyright ownership.
5
+ # The ASF licenses this file to You under the Apache License, Version 2.0
6
+ # (the "License"); you may not use this file except in compliance with
7
+ # the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ import array
19
+ import datetime
20
+ import os
21
+ import unittest
22
+ import random
23
+ import shutil
24
+ import string
25
+ import tempfile
26
+ import uuid
27
+ from collections import defaultdict
28
+
29
+ from pyspark.errors import (
30
+ PySparkAttributeError,
31
+ PySparkTypeError,
32
+ PySparkException,
33
+ PySparkValueError,
34
+ )
35
+ from pyspark.errors.exceptions.base import SessionNotSameException
36
+ from pyspark.sql import SparkSession as PySparkSession, Row
37
+ from pyspark.sql.types import (
38
+ StructType,
39
+ StructField,
40
+ LongType,
41
+ StringType,
42
+ IntegerType,
43
+ MapType,
44
+ ArrayType,
45
+ Row,
46
+ )
47
+
48
+ from pyspark.testing.objects import (
49
+ MyObject,
50
+ PythonOnlyUDT,
51
+ ExamplePoint,
52
+ PythonOnlyPoint,
53
+ )
54
+ from pyspark.testing.sqlutils import SQLTestUtils
55
+ from pyspark.testing.connectutils import (
56
+ should_test_connect,
57
+ ReusedConnectTestCase,
58
+ connect_requirement_message,
59
+ )
60
+ from pyspark.testing.pandasutils import PandasOnSparkTestUtils
61
+ from pyspark.errors.exceptions.connect import (
62
+ AnalysisException,
63
+ ParseException,
64
+ SparkConnectException,
65
+ )
66
+
67
+ if should_test_connect:
68
+ import grpc
69
+ import pandas as pd
70
+ import numpy as np
71
+ from pyspark.sql.connect.proto import Expression as ProtoExpression
72
+ from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
73
+ from pyspark.sql.connect.client import ChannelBuilder
74
+ from pyspark.sql.connect.column import Column
75
+ from pyspark.sql.connect.readwriter import DataFrameWriterV2
76
+ from pyspark.sql.dataframe import DataFrame
77
+ from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
78
+ from pyspark.sql import functions as SF
79
+ from pyspark.sql.connect import functions as CF
80
+ from pyspark.sql.connect.client.core import Retrying, SparkConnectClient
81
+
82
+
83
+ @unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access")
84
+ class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils):
85
+ """Parent test fixture class for all Spark Connect related
86
+ test cases."""
87
+
88
+ @classmethod
89
+ def setUpClass(cls):
90
+ super(SparkConnectSQLTestCase, cls).setUpClass()
91
+ # Disable the shared namespace so pyspark.sql.functions, etc point the regular
92
+ # PySpark libraries.
93
+ os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
94
+
95
+ cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session.
96
+ cls.spark = PySparkSession._instantiatedSession
97
+ assert cls.spark is not None
98
+
99
+ cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
100
+ cls.testDataStr = [Row(key=str(i)) for i in range(100)]
101
+ cls.df = cls.spark.sparkContext.parallelize(cls.testData).toDF()
102
+ cls.df_text = cls.spark.sparkContext.parallelize(cls.testDataStr).toDF()
103
+
104
+ cls.tbl_name = "test_connect_basic_table_1"
105
+ cls.tbl_name2 = "test_connect_basic_table_2"
106
+ cls.tbl_name3 = "test_connect_basic_table_3"
107
+ cls.tbl_name4 = "test_connect_basic_table_4"
108
+ cls.tbl_name_empty = "test_connect_basic_table_empty"
109
+
110
+ # Cleanup test data
111
+ cls.spark_connect_clean_up_test_data()
112
+ # Load test data
113
+ cls.spark_connect_load_test_data()
114
+
115
+ @classmethod
116
+ def tearDownClass(cls):
117
+ try:
118
+ cls.spark_connect_clean_up_test_data()
119
+ # Stopping Spark Connect closes the session in JVM at the server.
120
+ cls.spark = cls.connect
121
+ del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
122
+ finally:
123
+ super(SparkConnectSQLTestCase, cls).tearDownClass()
124
+
125
+ @classmethod
126
+ def spark_connect_load_test_data(cls):
127
+ df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
128
+ # Since we might create multiple Spark sessions, we need to create global temporary view
129
+ # that is specifically maintained in the "global_temp" schema.
130
+ df.write.saveAsTable(cls.tbl_name)
131
+ df2 = cls.spark.createDataFrame(
132
+ [(x, f"{x}", 2 * x) for x in range(100)], ["col1", "col2", "col3"]
133
+ )
134
+ df2.write.saveAsTable(cls.tbl_name2)
135
+ df3 = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "test\n_column"])
136
+ df3.write.saveAsTable(cls.tbl_name3)
137
+ df4 = cls.spark.createDataFrame(
138
+ [(x, {"a": x}, [x, x * 2]) for x in range(100)], ["id", "map_column", "array_column"]
139
+ )
140
+ df4.write.saveAsTable(cls.tbl_name4)
141
+ empty_table_schema = StructType(
142
+ [
143
+ StructField("firstname", StringType(), True),
144
+ StructField("middlename", StringType(), True),
145
+ StructField("lastname", StringType(), True),
146
+ ]
147
+ )
148
+ emptyRDD = cls.spark.sparkContext.emptyRDD()
149
+ empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema)
150
+ empty_df.write.saveAsTable(cls.tbl_name_empty)
151
+
152
+ @classmethod
153
+ def spark_connect_clean_up_test_data(cls):
154
+ cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
155
+ cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name2))
156
+ cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name3))
157
+ cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name4))
158
+ cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))
159
+
160
+
161
+ class SparkConnectBasicTests(SparkConnectSQLTestCase):
162
+ def test_df_getattr_behavior(self):
163
+ cdf = self.connect.range(10)
164
+ sdf = self.spark.range(10)
165
+
166
+ sdf._simple_extension = 10
167
+ cdf._simple_extension = 10
168
+
169
+ self.assertEqual(sdf._simple_extension, cdf._simple_extension)
170
+ self.assertEqual(type(sdf._simple_extension), type(cdf._simple_extension))
171
+
172
+ self.assertTrue(hasattr(cdf, "_simple_extension"))
173
+ self.assertFalse(hasattr(cdf, "_simple_extension_does_not_exsit"))
174
+
175
+ def test_df_get_item(self):
176
+ # SPARK-41779: test __getitem__
177
+
178
+ query = """
179
+ SELECT * FROM VALUES
180
+ (true, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
181
+ AS tab(a, b, c)
182
+ """
183
+
184
+ # +-----+----+----+
185
+ # | a| b| c|
186
+ # +-----+----+----+
187
+ # | true| 1|NULL|
188
+ # |false|NULL| 2.0|
189
+ # | NULL| 3| 3.0|
190
+ # +-----+----+----+
191
+
192
+ cdf = self.connect.sql(query)
193
+ sdf = self.spark.sql(query)
194
+
195
+ # filter
196
+ self.assert_eq(
197
+ cdf[cdf.a].toPandas(),
198
+ sdf[sdf.a].toPandas(),
199
+ )
200
+ self.assert_eq(
201
+ cdf[cdf.b.isin(2, 3)].toPandas(),
202
+ sdf[sdf.b.isin(2, 3)].toPandas(),
203
+ )
204
+ self.assert_eq(
205
+ cdf[cdf.c > 1.5].toPandas(),
206
+ sdf[sdf.c > 1.5].toPandas(),
207
+ )
208
+
209
+ # select
210
+ self.assert_eq(
211
+ cdf[[cdf.a, "b", cdf.c]].toPandas(),
212
+ sdf[[sdf.a, "b", sdf.c]].toPandas(),
213
+ )
214
+ self.assert_eq(
215
+ cdf[(cdf.a, "b", cdf.c)].toPandas(),
216
+ sdf[(sdf.a, "b", sdf.c)].toPandas(),
217
+ )
218
+
219
+ # select by index
220
+ self.assertTrue(isinstance(cdf[0], Column))
221
+ self.assertTrue(isinstance(cdf[1], Column))
222
+ self.assertTrue(isinstance(cdf[2], Column))
223
+
224
+ self.assert_eq(
225
+ cdf[[cdf[0], cdf[1], cdf[2]]].toPandas(),
226
+ sdf[[sdf[0], sdf[1], sdf[2]]].toPandas(),
227
+ )
228
+
229
+ # check error
230
+ with self.assertRaises(PySparkTypeError) as pe:
231
+ cdf[1.5]
232
+
233
+ self.check_error(
234
+ exception=pe.exception,
235
+ error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
236
+ message_parameters={
237
+ "arg_name": "item",
238
+ "arg_type": "float",
239
+ },
240
+ )
241
+
242
+ with self.assertRaises(PySparkTypeError) as pe:
243
+ cdf[None]
244
+
245
+ self.check_error(
246
+ exception=pe.exception,
247
+ error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
248
+ message_parameters={
249
+ "arg_name": "item",
250
+ "arg_type": "NoneType",
251
+ },
252
+ )
253
+
254
+ with self.assertRaises(PySparkTypeError) as pe:
255
+ cdf[cdf]
256
+
257
+ self.check_error(
258
+ exception=pe.exception,
259
+ error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
260
+ message_parameters={
261
+ "arg_name": "item",
262
+ "arg_type": "DataFrame",
263
+ },
264
+ )
265
+
266
+ def test_error_handling(self):
267
+ # SPARK-41533 Proper error handling for Spark Connect
268
+ df = self.connect.range(10).select("id2")
269
+ with self.assertRaises(AnalysisException):
270
+ df.collect()
271
+
272
+ def test_simple_read(self):
273
+ df = self.connect.read.table(self.tbl_name)
274
+ data = df.limit(10).toPandas()
275
+ # Check that the limit is applied
276
+ self.assertEqual(len(data.index), 10)
277
+
278
+ def test_json(self):
279
+ with tempfile.TemporaryDirectory() as d:
280
+ # Write a DataFrame into a JSON file
281
+ self.spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
282
+ "overwrite"
283
+ ).format("json").save(d)
284
+ # Read the JSON file as a DataFrame.
285
+ self.assert_eq(self.connect.read.json(d).toPandas(), self.spark.read.json(d).toPandas())
286
+
287
+ for schema in [
288
+ "age INT, name STRING",
289
+ StructType(
290
+ [
291
+ StructField("age", IntegerType()),
292
+ StructField("name", StringType()),
293
+ ]
294
+ ),
295
+ ]:
296
+ self.assert_eq(
297
+ self.connect.read.json(path=d, schema=schema).toPandas(),
298
+ self.spark.read.json(path=d, schema=schema).toPandas(),
299
+ )
300
+
301
+ self.assert_eq(
302
+ self.connect.read.json(path=d, primitivesAsString=True).toPandas(),
303
+ self.spark.read.json(path=d, primitivesAsString=True).toPandas(),
304
+ )
305
+
306
+ def test_parquet(self):
307
+ # SPARK-41445: Implement DataFrameReader.parquet
308
+ with tempfile.TemporaryDirectory() as d:
309
+ # Write a DataFrame into a JSON file
310
+ self.spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
311
+ "overwrite"
312
+ ).format("parquet").save(d)
313
+ # Read the Parquet file as a DataFrame.
314
+ self.assert_eq(
315
+ self.connect.read.parquet(d).toPandas(), self.spark.read.parquet(d).toPandas()
316
+ )
317
+
318
+ def test_text(self):
319
+ # SPARK-41849: Implement DataFrameReader.text
320
+ with tempfile.TemporaryDirectory() as d:
321
+ # Write a DataFrame into a text file
322
+ self.spark.createDataFrame(
323
+ [{"name": "Sandeep Singh"}, {"name": "Hyukjin Kwon"}]
324
+ ).write.mode("overwrite").format("text").save(d)
325
+ # Read the text file as a DataFrame.
326
+ self.assert_eq(self.connect.read.text(d).toPandas(), self.spark.read.text(d).toPandas())
327
+
328
+ def test_csv(self):
329
+ # SPARK-42011: Implement DataFrameReader.csv
330
+ with tempfile.TemporaryDirectory() as d:
331
+ # Write a DataFrame into a text file
332
+ self.spark.createDataFrame(
333
+ [{"name": "Sandeep Singh"}, {"name": "Hyukjin Kwon"}]
334
+ ).write.mode("overwrite").format("csv").save(d)
335
+ # Read the text file as a DataFrame.
336
+ self.assert_eq(self.connect.read.csv(d).toPandas(), self.spark.read.csv(d).toPandas())
337
+
338
+ def test_multi_paths(self):
339
+ # SPARK-42041: DataFrameReader should support list of paths
340
+
341
+ with tempfile.TemporaryDirectory() as d:
342
+ text_files = []
343
+ for i in range(0, 3):
344
+ text_file = f"{d}/text-{i}.text"
345
+ shutil.copyfile("python/test_support/sql/text-test.txt", text_file)
346
+ text_files.append(text_file)
347
+
348
+ self.assertEqual(
349
+ self.connect.read.text(text_files).collect(),
350
+ self.spark.read.text(text_files).collect(),
351
+ )
352
+
353
+ with tempfile.TemporaryDirectory() as d:
354
+ json_files = []
355
+ for i in range(0, 5):
356
+ json_file = f"{d}/json-{i}.json"
357
+ shutil.copyfile("python/test_support/sql/people.json", json_file)
358
+ json_files.append(json_file)
359
+
360
+ self.assertEqual(
361
+ self.connect.read.json(json_files).collect(),
362
+ self.spark.read.json(json_files).collect(),
363
+ )
364
+
365
+ def test_orc(self):
366
+ # SPARK-42012: Implement DataFrameReader.orc
367
+ with tempfile.TemporaryDirectory() as d:
368
+ # Write a DataFrame into a text file
369
+ self.spark.createDataFrame(
370
+ [{"name": "Sandeep Singh"}, {"name": "Hyukjin Kwon"}]
371
+ ).write.mode("overwrite").format("orc").save(d)
372
+ # Read the text file as a DataFrame.
373
+ self.assert_eq(self.connect.read.orc(d).toPandas(), self.spark.read.orc(d).toPandas())
374
+
375
+ def test_join_condition_column_list_columns(self):
376
+ left_connect_df = self.connect.read.table(self.tbl_name)
377
+ right_connect_df = self.connect.read.table(self.tbl_name2)
378
+ left_spark_df = self.spark.read.table(self.tbl_name)
379
+ right_spark_df = self.spark.read.table(self.tbl_name2)
380
+ joined_plan = left_connect_df.join(
381
+ other=right_connect_df, on=left_connect_df.id == right_connect_df.col1, how="inner"
382
+ )
383
+ joined_plan2 = left_spark_df.join(
384
+ other=right_spark_df, on=left_spark_df.id == right_spark_df.col1, how="inner"
385
+ )
386
+ self.assert_eq(joined_plan.toPandas(), joined_plan2.toPandas())
387
+
388
+ joined_plan3 = left_connect_df.join(
389
+ other=right_connect_df,
390
+ on=[
391
+ left_connect_df.id == right_connect_df.col1,
392
+ left_connect_df.name == right_connect_df.col2,
393
+ ],
394
+ how="inner",
395
+ )
396
+ joined_plan4 = left_spark_df.join(
397
+ other=right_spark_df,
398
+ on=[left_spark_df.id == right_spark_df.col1, left_spark_df.name == right_spark_df.col2],
399
+ how="inner",
400
+ )
401
+ self.assert_eq(joined_plan3.toPandas(), joined_plan4.toPandas())
402
+
403
+ def test_join_ambiguous_cols(self):
404
+ # SPARK-41812: test join with ambiguous columns
405
+ data1 = [Row(id=1, value="foo"), Row(id=2, value=None)]
406
+ cdf1 = self.connect.createDataFrame(data1)
407
+ sdf1 = self.spark.createDataFrame(data1)
408
+
409
+ data2 = [Row(value="bar"), Row(value=None), Row(value="foo")]
410
+ cdf2 = self.connect.createDataFrame(data2)
411
+ sdf2 = self.spark.createDataFrame(data2)
412
+
413
+ cdf3 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"])
414
+ sdf3 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"])
415
+
416
+ self.assertEqual(cdf3.schema, sdf3.schema)
417
+ self.assertEqual(cdf3.collect(), sdf3.collect())
418
+
419
+ cdf4 = cdf1.join(cdf2, cdf1["value"].eqNullSafe(cdf2["value"]))
420
+ sdf4 = sdf1.join(sdf2, sdf1["value"].eqNullSafe(sdf2["value"]))
421
+
422
+ self.assertEqual(cdf4.schema, sdf4.schema)
423
+ self.assertEqual(cdf4.collect(), sdf4.collect())
424
+
425
+ cdf5 = cdf1.join(
426
+ cdf2, (cdf1["value"] == cdf2["value"]) & (cdf1["value"].eqNullSafe(cdf2["value"]))
427
+ )
428
+ sdf5 = sdf1.join(
429
+ sdf2, (sdf1["value"] == sdf2["value"]) & (sdf1["value"].eqNullSafe(sdf2["value"]))
430
+ )
431
+
432
+ self.assertEqual(cdf5.schema, sdf5.schema)
433
+ self.assertEqual(cdf5.collect(), sdf5.collect())
434
+
435
+ cdf6 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf1.value)
436
+ sdf6 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf1.value)
437
+
438
+ self.assertEqual(cdf6.schema, sdf6.schema)
439
+ self.assertEqual(cdf6.collect(), sdf6.collect())
440
+
441
+ cdf7 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf2.value)
442
+ sdf7 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf2.value)
443
+
444
+ self.assertEqual(cdf7.schema, sdf7.schema)
445
+ self.assertEqual(cdf7.collect(), sdf7.collect())
446
+
447
+ def test_invalid_column(self):
448
+ # SPARK-41812: fail df1.select(df2.col)
449
+ data1 = [Row(a=1, b=2, c=3)]
450
+ cdf1 = self.connect.createDataFrame(data1)
451
+
452
+ data2 = [Row(a=2, b=0)]
453
+ cdf2 = self.connect.createDataFrame(data2)
454
+
455
+ with self.assertRaises(AnalysisException):
456
+ cdf1.select(cdf2.a).schema
457
+
458
+ with self.assertRaises(AnalysisException):
459
+ cdf2.withColumn("x", cdf1.a + 1).schema
460
+
461
+ with self.assertRaisesRegex(AnalysisException, "attribute.*missing"):
462
+ cdf3 = cdf1.select(cdf1.a)
463
+ cdf3.select(cdf1.b).schema
464
+
465
+ def test_collect(self):
466
+ cdf = self.connect.read.table(self.tbl_name)
467
+ sdf = self.spark.read.table(self.tbl_name)
468
+
469
+ data = cdf.limit(10).collect()
470
+ self.assertEqual(len(data), 10)
471
+ # Check Row has schema column names.
472
+ self.assertTrue("name" in data[0])
473
+ self.assertTrue("id" in data[0])
474
+
475
+ cdf = cdf.select(
476
+ CF.log("id"), CF.log("id"), CF.struct("id", "name"), CF.struct("id", "name")
477
+ ).limit(10)
478
+ sdf = sdf.select(
479
+ SF.log("id"), SF.log("id"), SF.struct("id", "name"), SF.struct("id", "name")
480
+ ).limit(10)
481
+
482
+ self.assertEqual(
483
+ cdf.collect(),
484
+ sdf.collect(),
485
+ )
486
+
487
+ def test_collect_timestamp(self):
488
+ query = """
489
+ SELECT * FROM VALUES
490
+ (TIMESTAMP('2022-12-25 10:30:00'), 1),
491
+ (TIMESTAMP('2022-12-25 10:31:00'), 2),
492
+ (TIMESTAMP('2022-12-25 10:32:00'), 1),
493
+ (TIMESTAMP('2022-12-25 10:33:00'), 2),
494
+ (TIMESTAMP('2022-12-26 09:30:00'), 1),
495
+ (TIMESTAMP('2022-12-26 09:35:00'), 3)
496
+ AS tab(date, val)
497
+ """
498
+
499
+ cdf = self.connect.sql(query)
500
+ sdf = self.spark.sql(query)
501
+
502
+ self.assertEqual(cdf.schema, sdf.schema)
503
+
504
+ self.assertEqual(cdf.collect(), sdf.collect())
505
+
506
+ self.assertEqual(
507
+ cdf.select(CF.date_trunc("year", cdf.date).alias("year")).collect(),
508
+ sdf.select(SF.date_trunc("year", sdf.date).alias("year")).collect(),
509
+ )
510
+
511
+ def test_with_columns_renamed(self):
512
+ # SPARK-41312: test DataFrame.withColumnsRenamed()
513
+ self.assertEqual(
514
+ self.connect.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema,
515
+ self.spark.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema,
516
+ )
517
+ self.assertEqual(
518
+ self.connect.read.table(self.tbl_name)
519
+ .withColumnsRenamed({"id": "id_new", "name": "name_new"})
520
+ .schema,
521
+ self.spark.read.table(self.tbl_name)
522
+ .withColumnsRenamed({"id": "id_new", "name": "name_new"})
523
+ .schema,
524
+ )
525
+
526
+ def test_with_local_data(self):
527
+ """SPARK-41114: Test creating a dataframe using local data"""
528
+ pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
529
+ df = self.connect.createDataFrame(pdf)
530
+ rows = df.filter(df.a == CF.lit(3)).collect()
531
+ self.assertTrue(len(rows) == 1)
532
+ self.assertEqual(rows[0][0], 3)
533
+ self.assertEqual(rows[0][1], "c")
534
+
535
+ # Check correct behavior for empty DataFrame
536
+ pdf = pd.DataFrame({"a": []})
537
+ with self.assertRaises(ValueError):
538
+ self.connect.createDataFrame(pdf)
539
+
540
+ def test_with_local_ndarray(self):
541
+ """SPARK-41446: Test creating a dataframe using local list"""
542
+ data = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
543
+
544
+ sdf = self.spark.createDataFrame(data)
545
+ cdf = self.connect.createDataFrame(data)
546
+ self.assertEqual(sdf.schema, cdf.schema)
547
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
548
+
549
+ for schema in [
550
+ StructType(
551
+ [
552
+ StructField("col1", IntegerType(), True),
553
+ StructField("col2", IntegerType(), True),
554
+ StructField("col3", IntegerType(), True),
555
+ StructField("col4", IntegerType(), True),
556
+ ]
557
+ ),
558
+ "struct<col1 int, col2 int, col3 int, col4 int>",
559
+ "col1 int, col2 int, col3 int, col4 int",
560
+ "col1 int, col2 long, col3 string, col4 long",
561
+ "col1 int, col2 string, col3 short, col4 long",
562
+ ["a", "b", "c", "d"],
563
+ ("x1", "x2", "x3", "x4"),
564
+ ]:
565
+ with self.subTest(schema=schema):
566
+ sdf = self.spark.createDataFrame(data, schema=schema)
567
+ cdf = self.connect.createDataFrame(data, schema=schema)
568
+
569
+ self.assertEqual(sdf.schema, cdf.schema)
570
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
571
+
572
+ with self.assertRaises(PySparkValueError) as pe:
573
+ self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
574
+
575
+ self.check_error(
576
+ exception=pe.exception,
577
+ error_class="AXIS_LENGTH_MISMATCH",
578
+ message_parameters={"expected_length": "5", "actual_length": "4"},
579
+ )
580
+
581
+ with self.assertRaises(ParseException):
582
+ self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
583
+
584
+ with self.assertRaises(PySparkValueError) as pe:
585
+ self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
586
+
587
+ self.check_error(
588
+ exception=pe.exception,
589
+ error_class="AXIS_LENGTH_MISMATCH",
590
+ message_parameters={"expected_length": "3", "actual_length": "4"},
591
+ )
592
+
593
+ # test 1 dim ndarray
594
+ data = np.array([1.0, 2.0, np.nan, 3.0, 4.0, float("NaN"), 5.0])
595
+ self.assertEqual(data.ndim, 1)
596
+
597
+ sdf = self.spark.createDataFrame(data)
598
+ cdf = self.connect.createDataFrame(data)
599
+ self.assertEqual(sdf.schema, cdf.schema)
600
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
601
+
602
+ def test_with_local_list(self):
603
+ """SPARK-41446: Test creating a dataframe using local list"""
604
+ data = [[1, 2, 3, 4]]
605
+
606
+ sdf = self.spark.createDataFrame(data)
607
+ cdf = self.connect.createDataFrame(data)
608
+ self.assertEqual(sdf.schema, cdf.schema)
609
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
610
+
611
+ for schema in [
612
+ "struct<col1 int, col2 int, col3 int, col4 int>",
613
+ "col1 int, col2 int, col3 int, col4 int",
614
+ "col1 int, col2 long, col3 string, col4 long",
615
+ "col1 int, col2 string, col3 short, col4 long",
616
+ ["a", "b", "c", "d"],
617
+ ("x1", "x2", "x3", "x4"),
618
+ ]:
619
+ sdf = self.spark.createDataFrame(data, schema=schema)
620
+ cdf = self.connect.createDataFrame(data, schema=schema)
621
+
622
+ self.assertEqual(sdf.schema, cdf.schema)
623
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
624
+
625
+ with self.assertRaises(PySparkValueError) as pe:
626
+ self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
627
+
628
+ self.check_error(
629
+ exception=pe.exception,
630
+ error_class="AXIS_LENGTH_MISMATCH",
631
+ message_parameters={"expected_length": "5", "actual_length": "4"},
632
+ )
633
+
634
+ with self.assertRaises(ParseException):
635
+ self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
636
+
637
+ with self.assertRaisesRegex(
638
+ ValueError,
639
+ "Length mismatch: Expected axis has 3 elements, new values have 4 elements",
640
+ ):
641
+ self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
642
+
643
+ def test_with_local_rows(self):
644
+ # SPARK-41789, SPARK-41810: Test creating a dataframe with list of rows and dictionaries
645
+ rows = [
646
+ Row(course="dotNET", year=2012, earnings=10000),
647
+ Row(course="Java", year=2012, earnings=20000),
648
+ Row(course="dotNET", year=2012, earnings=5000),
649
+ Row(course="dotNET", year=2013, earnings=48000),
650
+ Row(course="Java", year=2013, earnings=30000),
651
+ Row(course="Scala", year=2022, earnings=None),
652
+ ]
653
+ dicts = [row.asDict() for row in rows]
654
+
655
+ for data in [rows, dicts]:
656
+ sdf = self.spark.createDataFrame(data)
657
+ cdf = self.connect.createDataFrame(data)
658
+
659
+ self.assertEqual(sdf.schema, cdf.schema)
660
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
661
+
662
+ # test with rename
663
+ sdf = self.spark.createDataFrame(data, schema=["a", "b", "c"])
664
+ cdf = self.connect.createDataFrame(data, schema=["a", "b", "c"])
665
+
666
+ self.assertEqual(sdf.schema, cdf.schema)
667
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
668
+
669
+ def test_streaming_local_relation(self):
670
+ threshold_conf = "spark.sql.session.localRelationCacheThreshold"
671
+ old_threshold = self.connect.conf.get(threshold_conf)
672
+ threshold = 1024 * 1024
673
+ self.connect.conf.set(threshold_conf, threshold)
674
+ try:
675
+ suffix = "abcdef"
676
+ letters = string.ascii_lowercase
677
+ str = "".join(random.choice(letters) for i in range(threshold)) + suffix
678
+ data = [[0, str], [1, str]]
679
+ for i in range(0, 2):
680
+ cdf = self.connect.createDataFrame(data, ["a", "b"])
681
+ self.assert_eq(cdf.count(), len(data))
682
+ self.assert_eq(cdf.filter(f"endsWith(b, '{suffix}')").isEmpty(), False)
683
+ finally:
684
+ self.connect.conf.set(threshold_conf, old_threshold)
685
+
686
+ def test_with_atom_type(self):
687
+ for data in [[(1), (2), (3)], [1, 2, 3]]:
688
+ for schema in ["long", "int", "short"]:
689
+ sdf = self.spark.createDataFrame(data, schema=schema)
690
+ cdf = self.connect.createDataFrame(data, schema=schema)
691
+
692
+ self.assertEqual(sdf.schema, cdf.schema)
693
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
694
+
695
+ def test_with_none_and_nan(self):
696
+ # SPARK-41855: make createDataFrame support None and NaN
697
+ # SPARK-41814: test with eqNullSafe
698
+ data1 = [Row(id=1, value=float("NaN")), Row(id=2, value=42.0), Row(id=3, value=None)]
699
+ data2 = [Row(id=1, value=np.nan), Row(id=2, value=42.0), Row(id=3, value=None)]
700
+ data3 = [
701
+ {"id": 1, "value": float("NaN")},
702
+ {"id": 2, "value": 42.0},
703
+ {"id": 3, "value": None},
704
+ ]
705
+ data4 = [{"id": 1, "value": np.nan}, {"id": 2, "value": 42.0}, {"id": 3, "value": None}]
706
+ data5 = [(1, float("NaN")), (2, 42.0), (3, None)]
707
+ data6 = [(1, np.nan), (2, 42.0), (3, None)]
708
+ data7 = np.array([[1, float("NaN")], [2, 42.0], [3, None]])
709
+ data8 = np.array([[1, np.nan], [2, 42.0], [3, None]])
710
+
711
+ # +---+-----+
712
+ # | id|value|
713
+ # +---+-----+
714
+ # | 1| NaN|
715
+ # | 2| 42.0|
716
+ # | 3| NULL|
717
+ # +---+-----+
718
+
719
+ for data in [data1, data2, data3, data4, data5, data6, data7, data8]:
720
+ if isinstance(data[0], (Row, dict)):
721
+ # data1, data2, data3, data4
722
+ cdf = self.connect.createDataFrame(data)
723
+ sdf = self.spark.createDataFrame(data)
724
+ else:
725
+ # data5, data6, data7, data8
726
+ cdf = self.connect.createDataFrame(data, schema=["id", "value"])
727
+ sdf = self.spark.createDataFrame(data, schema=["id", "value"])
728
+
729
+ self.assert_eq(cdf.toPandas(), sdf.toPandas())
730
+
731
+ self.assert_eq(
732
+ cdf.select(
733
+ cdf["value"].eqNullSafe(None),
734
+ cdf["value"].eqNullSafe(float("NaN")),
735
+ cdf["value"].eqNullSafe(42.0),
736
+ ).toPandas(),
737
+ sdf.select(
738
+ sdf["value"].eqNullSafe(None),
739
+ sdf["value"].eqNullSafe(float("NaN")),
740
+ sdf["value"].eqNullSafe(42.0),
741
+ ).toPandas(),
742
+ )
743
+
744
+ # SPARK-41851: test with nanvl
745
+ data = [(1.0, float("nan")), (float("nan"), 2.0)]
746
+
747
+ cdf = self.connect.createDataFrame(data, ("a", "b"))
748
+ sdf = self.spark.createDataFrame(data, ("a", "b"))
749
+
750
+ self.assert_eq(cdf.toPandas(), sdf.toPandas())
751
+
752
+ self.assert_eq(
753
+ cdf.select(
754
+ CF.nanvl("a", "b").alias("r1"), CF.nanvl(cdf.a, cdf.b).alias("r2")
755
+ ).toPandas(),
756
+ sdf.select(
757
+ SF.nanvl("a", "b").alias("r1"), SF.nanvl(sdf.a, sdf.b).alias("r2")
758
+ ).toPandas(),
759
+ )
760
+
761
+ # SPARK-41852: test with pmod
762
+ data = [
763
+ (1.0, float("nan")),
764
+ (float("nan"), 2.0),
765
+ (10.0, 3.0),
766
+ (float("nan"), float("nan")),
767
+ (-3.0, 4.0),
768
+ (-10.0, 3.0),
769
+ (-5.0, -6.0),
770
+ (7.0, -8.0),
771
+ (1.0, 2.0),
772
+ ]
773
+
774
+ cdf = self.connect.createDataFrame(data, ("a", "b"))
775
+ sdf = self.spark.createDataFrame(data, ("a", "b"))
776
+
777
+ self.assert_eq(cdf.toPandas(), sdf.toPandas())
778
+
779
+ self.assert_eq(
780
+ cdf.select(CF.pmod("a", "b")).toPandas(),
781
+ sdf.select(SF.pmod("a", "b")).toPandas(),
782
+ )
783
+
784
+ def test_cast_with_ddl(self):
785
+ data = [Row(date=datetime.date(2021, 12, 27), add=2)]
786
+
787
+ cdf = self.connect.createDataFrame(data, "date date, add integer")
788
+ sdf = self.spark.createDataFrame(data, "date date, add integer")
789
+
790
+ self.assertEqual(cdf.schema, sdf.schema)
791
+
792
+ def test_create_empty_df(self):
793
+ for schema in [
794
+ "STRING",
795
+ "x STRING",
796
+ "x STRING, y INTEGER",
797
+ StringType(),
798
+ StructType(
799
+ [
800
+ StructField("x", StringType(), True),
801
+ StructField("y", IntegerType(), True),
802
+ ]
803
+ ),
804
+ ]:
805
+ cdf = self.connect.createDataFrame(data=[], schema=schema)
806
+ sdf = self.spark.createDataFrame(data=[], schema=schema)
807
+
808
+ self.assert_eq(cdf.toPandas(), sdf.toPandas())
809
+
810
+ # check error
811
+ with self.assertRaises(PySparkValueError) as pe:
812
+ self.connect.createDataFrame(data=[])
813
+
814
+ self.check_error(
815
+ exception=pe.exception,
816
+ error_class="CANNOT_INFER_EMPTY_SCHEMA",
817
+ message_parameters={},
818
+ )
819
+
820
+ def test_create_dataframe_from_arrays(self):
821
+ # SPARK-42021: createDataFrame support array.array
822
+ data1 = [Row(a=1, b=array.array("i", [1, 2, 3]), c=array.array("d", [4, 5, 6]))]
823
+ data2 = [(array.array("d", [1, 2, 3]), 2, "3")]
824
+ data3 = [{"a": 1, "b": array.array("i", [1, 2, 3])}]
825
+
826
+ for data in [data1, data2, data3]:
827
+ cdf = self.connect.createDataFrame(data)
828
+ sdf = self.spark.createDataFrame(data)
829
+
830
+ # TODO: the nullability is different, need to fix
831
+ # self.assertEqual(cdf.schema, sdf.schema)
832
+ self.assertEqual(cdf.collect(), sdf.collect())
833
+
834
+ def test_timestampe_create_from_rows(self):
835
+ data = [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)]
836
+
837
+ cdf = self.connect.createDataFrame(data, ["date", "val"])
838
+ sdf = self.spark.createDataFrame(data, ["date", "val"])
839
+
840
+ self.assertEqual(cdf.schema, sdf.schema)
841
+ self.assertEqual(cdf.collect(), sdf.collect())
842
+
843
+ def test_create_dataframe_with_coercion(self):
844
+ data1 = [[1.33, 1], ["2.1", 1]]
845
+ data2 = [[True, 1], ["false", 1]]
846
+
847
+ for data in [data1, data2]:
848
+ cdf = self.connect.createDataFrame(data, ["a", "b"])
849
+ sdf = self.spark.createDataFrame(data, ["a", "b"])
850
+
851
+ self.assertEqual(cdf.schema, sdf.schema)
852
+ self.assertEqual(cdf.collect(), sdf.collect())
853
+
854
+ def test_nested_type_create_from_rows(self):
855
+ data1 = [Row(a=1, b=Row(c=2, d=Row(e=3, f=Row(g=4, h=Row(i=5)))))]
856
+ # root
857
+ # |-- a: long (nullable = true)
858
+ # |-- b: struct (nullable = true)
859
+ # | |-- c: long (nullable = true)
860
+ # | |-- d: struct (nullable = true)
861
+ # | | |-- e: long (nullable = true)
862
+ # | | |-- f: struct (nullable = true)
863
+ # | | | |-- g: long (nullable = true)
864
+ # | | | |-- h: struct (nullable = true)
865
+ # | | | | |-- i: long (nullable = true)
866
+
867
+ data2 = [
868
+ (
869
+ 1,
870
+ "a",
871
+ Row(
872
+ a=1,
873
+ b=[1, 2, 3],
874
+ c={"a": "b"},
875
+ d=Row(x=1, y="y", z=Row(o=1, p=2, q=Row(g=1.5))),
876
+ ),
877
+ )
878
+ ]
879
+ # root
880
+ # |-- _1: long (nullable = true)
881
+ # |-- _2: string (nullable = true)
882
+ # |-- _3: struct (nullable = true)
883
+ # | |-- a: long (nullable = true)
884
+ # | |-- b: array (nullable = true)
885
+ # | | |-- element: long (containsNull = true)
886
+ # | |-- c: map (nullable = true)
887
+ # | | |-- key: string
888
+ # | | |-- value: string (valueContainsNull = true)
889
+ # | |-- d: struct (nullable = true)
890
+ # | | |-- x: long (nullable = true)
891
+ # | | |-- y: string (nullable = true)
892
+ # | | |-- z: struct (nullable = true)
893
+ # | | | |-- o: long (nullable = true)
894
+ # | | | |-- p: long (nullable = true)
895
+ # | | | |-- q: struct (nullable = true)
896
+ # | | | | |-- g: double (nullable = true)
897
+
898
+ data3 = [
899
+ Row(
900
+ a=1,
901
+ b=[1, 2, 3],
902
+ c={"a": "b"},
903
+ d=Row(x=1, y="y", z=Row(1, 2, 3)),
904
+ e=list("hello connect"),
905
+ )
906
+ ]
907
+ # root
908
+ # |-- a: long (nullable = true)
909
+ # |-- b: array (nullable = true)
910
+ # | |-- element: long (containsNull = true)
911
+ # |-- c: map (nullable = true)
912
+ # | |-- key: string
913
+ # | |-- value: string (valueContainsNull = true)
914
+ # |-- d: struct (nullable = true)
915
+ # | |-- x: long (nullable = true)
916
+ # | |-- y: string (nullable = true)
917
+ # | |-- z: struct (nullable = true)
918
+ # | | |-- _1: long (nullable = true)
919
+ # | | |-- _2: long (nullable = true)
920
+ # | | |-- _3: long (nullable = true)
921
+ # |-- e: array (nullable = true)
922
+ # | |-- element: string (containsNull = true)
923
+
924
+ data4 = [
925
+ {
926
+ "a": 1,
927
+ "b": Row(x=1, y=Row(z=2)),
928
+ "c": {"x": -1, "y": 2},
929
+ "d": [1, 2, 3, 4, 5],
930
+ }
931
+ ]
932
+ # root
933
+ # |-- a: long (nullable = true)
934
+ # |-- b: struct (nullable = true)
935
+ # | |-- x: long (nullable = true)
936
+ # | |-- y: struct (nullable = true)
937
+ # | | |-- z: long (nullable = true)
938
+ # |-- c: map (nullable = true)
939
+ # | |-- key: string
940
+ # | |-- value: long (valueContainsNull = true)
941
+ # |-- d: array (nullable = true)
942
+ # | |-- element: long (containsNull = true)
943
+
944
+ data5 = [
945
+ {
946
+ "a": [Row(x=1, y="2"), Row(x=-1, y="-2")],
947
+ "b": [[1, 2, 3], [4, 5], [6]],
948
+ "c": {3: {4: {5: 6}}, 7: {8: {9: 0}}},
949
+ }
950
+ ]
951
+ # root
952
+ # |-- a: array (nullable = true)
953
+ # | |-- element: struct (containsNull = true)
954
+ # | | |-- x: long (nullable = true)
955
+ # | | |-- y: string (nullable = true)
956
+ # |-- b: array (nullable = true)
957
+ # | |-- element: array (containsNull = true)
958
+ # | | |-- element: long (containsNull = true)
959
+ # |-- c: map (nullable = true)
960
+ # | |-- key: long
961
+ # | |-- value: map (valueContainsNull = true)
962
+ # | | |-- key: long
963
+ # | | |-- value: map (valueContainsNull = true)
964
+ # | | | |-- key: long
965
+ # | | | |-- value: long (valueContainsNull = true)
966
+
967
+ for data in [data1, data2, data3, data4, data5]:
968
+ with self.subTest(data=data):
969
+ cdf = self.connect.createDataFrame(data)
970
+ sdf = self.spark.createDataFrame(data)
971
+
972
+ self.assertEqual(cdf.schema, sdf.schema)
973
+ self.assertEqual(cdf.collect(), sdf.collect())
974
+
975
+ def test_create_df_from_objects(self):
976
+ data = [MyObject(1, "1"), MyObject(2, "2")]
977
+
978
+ # +---+-----+
979
+ # |key|value|
980
+ # +---+-----+
981
+ # | 1| 1|
982
+ # | 2| 2|
983
+ # +---+-----+
984
+
985
+ cdf = self.connect.createDataFrame(data)
986
+ sdf = self.spark.createDataFrame(data)
987
+
988
+ self.assertEqual(cdf.schema, sdf.schema)
989
+ self.assertEqual(cdf.collect(), sdf.collect())
990
+
991
+ def test_simple_explain_string(self):
992
+ df = self.connect.read.table(self.tbl_name).limit(10)
993
+ result = df._explain_string()
994
+ self.assertGreater(len(result), 0)
995
+
996
+ def test_schema(self):
997
+ schema = self.connect.read.table(self.tbl_name).schema
998
+ self.assertEqual(
999
+ StructType(
1000
+ [StructField("id", LongType(), True), StructField("name", StringType(), True)]
1001
+ ),
1002
+ schema,
1003
+ )
1004
+
1005
+ # test FloatType, DoubleType, DecimalType, StringType, BooleanType, NullType
1006
+ query = """
1007
+ SELECT * FROM VALUES
1008
+ (float(1.0), double(1.0), 1.0, "1", true, NULL),
1009
+ (float(2.0), double(2.0), 2.0, "2", false, NULL),
1010
+ (float(3.0), double(3.0), NULL, "3", false, NULL)
1011
+ AS tab(a, b, c, d, e, f)
1012
+ """
1013
+ self.assertEqual(
1014
+ self.spark.sql(query).schema,
1015
+ self.connect.sql(query).schema,
1016
+ )
1017
+
1018
+ # test TimestampType, DateType
1019
+ query = """
1020
+ SELECT * FROM VALUES
1021
+ (TIMESTAMP('2019-04-12 15:50:00'), DATE('2022-02-22')),
1022
+ (TIMESTAMP('2019-04-12 15:50:00'), NULL),
1023
+ (NULL, DATE('2022-02-22'))
1024
+ AS tab(a, b)
1025
+ """
1026
+ self.assertEqual(
1027
+ self.spark.sql(query).schema,
1028
+ self.connect.sql(query).schema,
1029
+ )
1030
+
1031
+ # test DayTimeIntervalType
1032
+ query = """ SELECT INTERVAL '100 10:30' DAY TO MINUTE AS interval """
1033
+ self.assertEqual(
1034
+ self.spark.sql(query).schema,
1035
+ self.connect.sql(query).schema,
1036
+ )
1037
+
1038
+ # test MapType
1039
+ query = """
1040
+ SELECT * FROM VALUES
1041
+ (MAP('a', 'ab'), MAP('a', 'ab'), MAP(1, 2, 3, 4)),
1042
+ (MAP('x', 'yz'), MAP('x', NULL), NULL),
1043
+ (MAP('c', 'de'), NULL, MAP(-1, NULL, -3, -4))
1044
+ AS tab(a, b, c)
1045
+ """
1046
+ self.assertEqual(
1047
+ self.spark.sql(query).schema,
1048
+ self.connect.sql(query).schema,
1049
+ )
1050
+
1051
+ # test ArrayType
1052
+ query = """
1053
+ SELECT * FROM VALUES
1054
+ (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3)),
1055
+ (ARRAY('x', NULL), NULL, ARRAY(1, 3)),
1056
+ (NULL, ARRAY(-1, -2, -3), Array())
1057
+ AS tab(a, b, c)
1058
+ """
1059
+ self.assertEqual(
1060
+ self.spark.sql(query).schema,
1061
+ self.connect.sql(query).schema,
1062
+ )
1063
+
1064
+ # test StructType
1065
+ query = """
1066
+ SELECT STRUCT(a, b, c, d), STRUCT(e, f, g), STRUCT(STRUCT(a, b), STRUCT(h)) FROM VALUES
1067
+ (float(1.0), double(1.0), 1.0, "1", true, NULL, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
1068
+ (float(2.0), double(2.0), 2.0, "2", false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
1069
+ (float(3.0), double(3.0), NULL, "3", false, NULL, ARRAY(NULL), NULL)
1070
+ AS tab(a, b, c, d, e, f, g, h)
1071
+ """
1072
+ self.assertEqual(
1073
+ self.spark.sql(query).schema,
1074
+ self.connect.sql(query).schema,
1075
+ )
1076
+
1077
+ def test_to(self):
1078
+ # SPARK-41464: test DataFrame.to()
1079
+
1080
+ cdf = self.connect.read.table(self.tbl_name)
1081
+ df = self.spark.read.table(self.tbl_name)
1082
+
1083
+ def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType):
1084
+ cdf_to = cdf.to(schema)
1085
+ df_to = df.to(schema)
1086
+ self.assertEqual(cdf_to.schema, df_to.schema)
1087
+ self.assert_eq(cdf_to.toPandas(), df_to.toPandas())
1088
+
1089
+ # The schema has not changed
1090
+ schema = StructType(
1091
+ [
1092
+ StructField("id", IntegerType(), True),
1093
+ StructField("name", StringType(), True),
1094
+ ]
1095
+ )
1096
+
1097
+ assert_eq_schema(cdf, df, schema)
1098
+
1099
+ # Change schema with struct
1100
+ schema2 = StructType([StructField("struct", schema, False)])
1101
+
1102
+ cdf_to = cdf.select(CF.struct("id", "name").alias("struct")).to(schema2)
1103
+ df_to = df.select(SF.struct("id", "name").alias("struct")).to(schema2)
1104
+
1105
+ self.assertEqual(cdf_to.schema, df_to.schema)
1106
+
1107
+ # Change the column name
1108
+ schema = StructType(
1109
+ [
1110
+ StructField("col1", IntegerType(), True),
1111
+ StructField("col2", StringType(), True),
1112
+ ]
1113
+ )
1114
+
1115
+ assert_eq_schema(cdf, df, schema)
1116
+
1117
+ # Change the column data type
1118
+ schema = StructType(
1119
+ [
1120
+ StructField("id", StringType(), True),
1121
+ StructField("name", StringType(), True),
1122
+ ]
1123
+ )
1124
+
1125
+ assert_eq_schema(cdf, df, schema)
1126
+
1127
+ # Reduce the number of columns and change the data type
1128
+ schema = StructType(
1129
+ [
1130
+ StructField("id", LongType(), True),
1131
+ ]
1132
+ )
1133
+
1134
+ assert_eq_schema(cdf, df, schema)
1135
+
1136
+ # incompatible field nullability
1137
+ schema = StructType([StructField("id", LongType(), False)])
1138
+ self.assertRaisesRegex(
1139
+ AnalysisException,
1140
+ "NULLABLE_COLUMN_OR_FIELD",
1141
+ lambda: cdf.to(schema).toPandas(),
1142
+ )
1143
+
1144
+ # field cannot upcast
1145
+ schema = StructType([StructField("name", LongType())])
1146
+ self.assertRaisesRegex(
1147
+ AnalysisException,
1148
+ "INVALID_COLUMN_OR_FIELD_DATA_TYPE",
1149
+ lambda: cdf.to(schema).toPandas(),
1150
+ )
1151
+
1152
+ schema = StructType(
1153
+ [
1154
+ StructField("id", IntegerType(), True),
1155
+ StructField("name", IntegerType(), True),
1156
+ ]
1157
+ )
1158
+ self.assertRaisesRegex(
1159
+ AnalysisException,
1160
+ "INVALID_COLUMN_OR_FIELD_DATA_TYPE",
1161
+ lambda: cdf.to(schema).toPandas(),
1162
+ )
1163
+
1164
+ # Test map type and array type
1165
+ schema = StructType(
1166
+ [
1167
+ StructField("id", StringType(), True),
1168
+ StructField("my_map", MapType(StringType(), IntegerType(), False), True),
1169
+ StructField("my_array", ArrayType(IntegerType(), False), True),
1170
+ ]
1171
+ )
1172
+ cdf = self.connect.read.table(self.tbl_name4)
1173
+ df = self.spark.read.table(self.tbl_name4)
1174
+
1175
+ assert_eq_schema(cdf, df, schema)
1176
+
1177
+ def test_toDF(self):
1178
+ # SPARK-41310: test DataFrame.toDF()
1179
+ self.assertEqual(
1180
+ self.connect.read.table(self.tbl_name).toDF("col1", "col2").schema,
1181
+ self.spark.read.table(self.tbl_name).toDF("col1", "col2").schema,
1182
+ )
1183
+
1184
+ def test_print_schema(self):
1185
+ # SPARK-41216: Test print schema
1186
+ tree_str = self.connect.sql("SELECT 1 AS X, 2 AS Y")._tree_string()
1187
+ # root
1188
+ # |-- X: integer (nullable = false)
1189
+ # |-- Y: integer (nullable = false)
1190
+ expected = "root\n |-- X: integer (nullable = false)\n |-- Y: integer (nullable = false)\n"
1191
+ self.assertEqual(tree_str, expected)
1192
+
1193
+ def test_is_local(self):
1194
+ # SPARK-41216: Test is local
1195
+ self.assertTrue(self.connect.sql("SHOW DATABASES").isLocal())
1196
+ self.assertFalse(self.connect.read.table(self.tbl_name).isLocal())
1197
+
1198
+ def test_is_streaming(self):
1199
+ # SPARK-41216: Test is streaming
1200
+ self.assertFalse(self.connect.read.table(self.tbl_name).isStreaming)
1201
+ self.assertFalse(self.connect.sql("SELECT 1 AS X LIMIT 0").isStreaming)
1202
+
1203
+ def test_input_files(self):
1204
+ # SPARK-41216: Test input files
1205
+ tmpPath = tempfile.mkdtemp()
1206
+ shutil.rmtree(tmpPath)
1207
+ try:
1208
+ self.df_text.write.text(tmpPath)
1209
+
1210
+ input_files_list1 = (
1211
+ self.spark.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
1212
+ )
1213
+ input_files_list2 = (
1214
+ self.connect.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
1215
+ )
1216
+
1217
+ self.assertTrue(len(input_files_list1) > 0)
1218
+ self.assertEqual(len(input_files_list1), len(input_files_list2))
1219
+ for file_path in input_files_list2:
1220
+ self.assertTrue(file_path in input_files_list1)
1221
+ finally:
1222
+ shutil.rmtree(tmpPath)
1223
+
1224
+ def test_limit_offset(self):
1225
+ df = self.connect.read.table(self.tbl_name)
1226
+ pd = df.limit(10).offset(1).toPandas()
1227
+ self.assertEqual(9, len(pd.index))
1228
+ pd2 = df.offset(98).limit(10).toPandas()
1229
+ self.assertEqual(2, len(pd2.index))
1230
+
1231
+ def test_tail(self):
1232
+ df = self.connect.read.table(self.tbl_name)
1233
+ df2 = self.spark.read.table(self.tbl_name)
1234
+ self.assertEqual(df.tail(10), df2.tail(10))
1235
+
1236
+ def test_sql(self):
1237
+ pdf = self.connect.sql("SELECT 1").toPandas()
1238
+ self.assertEqual(1, len(pdf.index))
1239
+
1240
+ def test_sql_with_named_args(self):
1241
+ df = self.connect.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
1242
+ df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
1243
+ self.assert_eq(df.toPandas(), df2.toPandas())
1244
+
1245
+ def test_namedargs_with_global_limit(self):
1246
+ sqlText = """SELECT * FROM VALUES (TIMESTAMP('2022-12-25 10:30:00'), 1) as tab(date, val)
1247
+ where val = :val"""
1248
+ df = self.connect.sql(sqlText, args={"val": 1})
1249
+ df2 = self.spark.sql(sqlText, args={"val": 1})
1250
+ self.assert_eq(df.toPandas(), df2.toPandas())
1251
+
1252
+ def test_sql_with_pos_args(self):
1253
+ df = self.connect.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
1254
+ df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
1255
+ self.assert_eq(df.toPandas(), df2.toPandas())
1256
+
1257
+ def test_head(self):
1258
+ # SPARK-41002: test `head` API in Python Client
1259
+ df = self.connect.read.table(self.tbl_name)
1260
+ self.assertIsNotNone(len(df.head()))
1261
+ self.assertIsNotNone(len(df.head(1)))
1262
+ self.assertIsNotNone(len(df.head(5)))
1263
+ df2 = self.connect.read.table(self.tbl_name_empty)
1264
+ self.assertIsNone(df2.head())
1265
+
1266
+ def test_deduplicate(self):
1267
+ # SPARK-41326: test distinct and dropDuplicates.
1268
+ df = self.connect.read.table(self.tbl_name)
1269
+ df2 = self.spark.read.table(self.tbl_name)
1270
+ self.assert_eq(df.distinct().toPandas(), df2.distinct().toPandas())
1271
+ self.assert_eq(df.dropDuplicates().toPandas(), df2.dropDuplicates().toPandas())
1272
+ self.assert_eq(
1273
+ df.dropDuplicates(["name"]).toPandas(), df2.dropDuplicates(["name"]).toPandas()
1274
+ )
1275
+
1276
+ def test_deduplicate_within_watermark_in_batch(self):
1277
+ df = self.connect.read.table(self.tbl_name)
1278
+ with self.assertRaisesRegex(
1279
+ AnalysisException,
1280
+ "dropDuplicatesWithinWatermark is not supported with batch DataFrames/DataSets",
1281
+ ):
1282
+ df.dropDuplicatesWithinWatermark().toPandas()
1283
+
1284
+ def test_first(self):
1285
+ # SPARK-41002: test `first` API in Python Client
1286
+ df = self.connect.read.table(self.tbl_name)
1287
+ self.assertIsNotNone(len(df.first()))
1288
+ df2 = self.connect.read.table(self.tbl_name_empty)
1289
+ self.assertIsNone(df2.first())
1290
+
1291
+ def test_take(self) -> None:
1292
+ # SPARK-41002: test `take` API in Python Client
1293
+ df = self.connect.read.table(self.tbl_name)
1294
+ self.assertEqual(5, len(df.take(5)))
1295
+ df2 = self.connect.read.table(self.tbl_name_empty)
1296
+ self.assertEqual(0, len(df2.take(5)))
1297
+
1298
+ def test_drop(self):
1299
+ # SPARK-41169: test drop
1300
+ query = """
1301
+ SELECT * FROM VALUES
1302
+ (false, 1, NULL), (false, NULL, 2), (NULL, 3, 3)
1303
+ AS tab(a, b, c)
1304
+ """
1305
+
1306
+ cdf = self.connect.sql(query)
1307
+ sdf = self.spark.sql(query)
1308
+ self.assert_eq(
1309
+ cdf.drop("a").toPandas(),
1310
+ sdf.drop("a").toPandas(),
1311
+ )
1312
+ self.assert_eq(
1313
+ cdf.drop("a", "b").toPandas(),
1314
+ sdf.drop("a", "b").toPandas(),
1315
+ )
1316
+ self.assert_eq(
1317
+ cdf.drop("a", "x").toPandas(),
1318
+ sdf.drop("a", "x").toPandas(),
1319
+ )
1320
+ self.assert_eq(
1321
+ cdf.drop(cdf.a, "x").toPandas(),
1322
+ sdf.drop(sdf.a, "x").toPandas(),
1323
+ )
1324
+
1325
+ def test_subquery_alias(self) -> None:
1326
+ # SPARK-40938: test subquery alias.
1327
+ plan_text = (
1328
+ self.connect.read.table(self.tbl_name)
1329
+ .alias("special_alias")
1330
+ ._explain_string(extended=True)
1331
+ )
1332
+ self.assertTrue("special_alias" in plan_text)
1333
+
1334
+ def test_sort(self):
1335
+ # SPARK-41332: test sort
1336
+ query = """
1337
+ SELECT * FROM VALUES
1338
+ (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1339
+ AS tab(a, b, c)
1340
+ """
1341
+ # +-----+----+----+
1342
+ # | a| b| c|
1343
+ # +-----+----+----+
1344
+ # |false| 1|NULL|
1345
+ # |false|NULL| 2.0|
1346
+ # | NULL| 3| 3.0|
1347
+ # +-----+----+----+
1348
+
1349
+ cdf = self.connect.sql(query)
1350
+ sdf = self.spark.sql(query)
1351
+ self.assert_eq(
1352
+ cdf.sort("a").toPandas(),
1353
+ sdf.sort("a").toPandas(),
1354
+ )
1355
+ self.assert_eq(
1356
+ cdf.sort("c").toPandas(),
1357
+ sdf.sort("c").toPandas(),
1358
+ )
1359
+ self.assert_eq(
1360
+ cdf.sort("b").toPandas(),
1361
+ sdf.sort("b").toPandas(),
1362
+ )
1363
+ self.assert_eq(
1364
+ cdf.sort(cdf.c, "b").toPandas(),
1365
+ sdf.sort(sdf.c, "b").toPandas(),
1366
+ )
1367
+ self.assert_eq(
1368
+ cdf.sort(cdf.c.desc(), "b").toPandas(),
1369
+ sdf.sort(sdf.c.desc(), "b").toPandas(),
1370
+ )
1371
+ self.assert_eq(
1372
+ cdf.sort(cdf.c.desc(), cdf.a.asc()).toPandas(),
1373
+ sdf.sort(sdf.c.desc(), sdf.a.asc()).toPandas(),
1374
+ )
1375
+
1376
+ def test_range(self):
1377
+ self.assert_eq(
1378
+ self.connect.range(start=0, end=10).toPandas(),
1379
+ self.spark.range(start=0, end=10).toPandas(),
1380
+ )
1381
+ self.assert_eq(
1382
+ self.connect.range(start=0, end=10, step=3).toPandas(),
1383
+ self.spark.range(start=0, end=10, step=3).toPandas(),
1384
+ )
1385
+ self.assert_eq(
1386
+ self.connect.range(start=0, end=10, step=3, numPartitions=2).toPandas(),
1387
+ self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas(),
1388
+ )
1389
+ # SPARK-41301
1390
+ self.assert_eq(
1391
+ self.connect.range(10).toPandas(), self.connect.range(start=0, end=10).toPandas()
1392
+ )
1393
+
1394
+ def test_create_global_temp_view(self):
1395
+ # SPARK-41127: test global temp view creation.
1396
+ with self.tempView("view_1"):
1397
+ self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
1398
+ self.connect.sql("SELECT 2 AS X LIMIT 1").createOrReplaceGlobalTempView("view_1")
1399
+ self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
1400
+
1401
+ # Creating a global temp view that already exists raises an AnalysisException
1402
+ self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
1403
+ with self.assertRaises(AnalysisException):
1404
+ self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
1405
+
1406
+ def test_create_session_local_temp_view(self):
1407
+ # SPARK-41372: test session local temp view creation.
1408
+ with self.tempView("view_local_temp"):
1409
+ self.connect.sql("SELECT 1 AS X").createTempView("view_local_temp")
1410
+ self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 1)
1411
+ self.connect.sql("SELECT 1 AS X LIMIT 0").createOrReplaceTempView("view_local_temp")
1412
+ self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0)
1413
+
1414
+ # Creating a local temp view that already exists raises an AnalysisException
1415
+ with self.assertRaises(AnalysisException):
1416
+ self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp")
1417
+
1418
+ def test_to_pandas(self):
1419
+ # SPARK-41005: Test to pandas
1420
+ query = """
1421
+ SELECT * FROM VALUES
1422
+ (false, 1, NULL),
1423
+ (false, NULL, float(2.0)),
1424
+ (NULL, 3, float(3.0))
1425
+ AS tab(a, b, c)
1426
+ """
1427
+
1428
+ self.assert_eq(
1429
+ self.connect.sql(query).toPandas(),
1430
+ self.spark.sql(query).toPandas(),
1431
+ )
1432
+
1433
+ query = """
1434
+ SELECT * FROM VALUES
1435
+ (1, 1, NULL),
1436
+ (2, NULL, float(2.0)),
1437
+ (3, 3, float(3.0))
1438
+ AS tab(a, b, c)
1439
+ """
1440
+
1441
+ self.assert_eq(
1442
+ self.connect.sql(query).toPandas(),
1443
+ self.spark.sql(query).toPandas(),
1444
+ )
1445
+
1446
+ query = """
1447
+ SELECT * FROM VALUES
1448
+ (double(1.0), 1, "1"),
1449
+ (NULL, NULL, NULL),
1450
+ (double(2.0), 3, "3")
1451
+ AS tab(a, b, c)
1452
+ """
1453
+
1454
+ self.assert_eq(
1455
+ self.connect.sql(query).toPandas(),
1456
+ self.spark.sql(query).toPandas(),
1457
+ )
1458
+
1459
+ query = """
1460
+ SELECT * FROM VALUES
1461
+ (float(1.0), double(1.0), 1, "1"),
1462
+ (float(2.0), double(2.0), 2, "2"),
1463
+ (float(3.0), double(3.0), 3, "3")
1464
+ AS tab(a, b, c, d)
1465
+ """
1466
+
1467
+ self.assert_eq(
1468
+ self.connect.sql(query).toPandas(),
1469
+ self.spark.sql(query).toPandas(),
1470
+ )
1471
+
1472
+ def test_create_dataframe_from_pandas_with_ns_timestamp(self):
1473
+ """Truncate the timestamps for nanoseconds."""
1474
+ from datetime import datetime, timezone, timedelta
1475
+ from pandas import Timestamp
1476
+ import pandas as pd
1477
+
1478
+ pdf = pd.DataFrame(
1479
+ {
1480
+ "naive": [datetime(2019, 1, 1, 0)],
1481
+ "aware": [
1482
+ Timestamp(
1483
+ year=2019, month=1, day=1, nanosecond=500, tz=timezone(timedelta(hours=-8))
1484
+ )
1485
+ ],
1486
+ }
1487
+ )
1488
+
1489
+ with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
1490
+ self.assertEqual(
1491
+ self.connect.createDataFrame(pdf).collect(),
1492
+ self.spark.createDataFrame(pdf).collect(),
1493
+ )
1494
+
1495
+ with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
1496
+ self.assertEqual(
1497
+ self.connect.createDataFrame(pdf).collect(),
1498
+ self.spark.createDataFrame(pdf).collect(),
1499
+ )
1500
+
1501
+ def test_select_expr(self):
1502
+ # SPARK-41201: test selectExpr API.
1503
+ self.assert_eq(
1504
+ self.connect.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
1505
+ self.spark.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
1506
+ )
1507
+ self.assert_eq(
1508
+ self.connect.read.table(self.tbl_name)
1509
+ .selectExpr(["id * 2", "cast(name as long) as name"])
1510
+ .toPandas(),
1511
+ self.spark.read.table(self.tbl_name)
1512
+ .selectExpr(["id * 2", "cast(name as long) as name"])
1513
+ .toPandas(),
1514
+ )
1515
+
1516
+ self.assert_eq(
1517
+ self.connect.read.table(self.tbl_name)
1518
+ .selectExpr("id * 2", "cast(name as long) as name")
1519
+ .toPandas(),
1520
+ self.spark.read.table(self.tbl_name)
1521
+ .selectExpr("id * 2", "cast(name as long) as name")
1522
+ .toPandas(),
1523
+ )
1524
+
1525
+ def test_select_star(self):
1526
+ data = [Row(a=1, b=Row(c=2, d=Row(e=3)))]
1527
+
1528
+ # +---+--------+
1529
+ # | a| b|
1530
+ # +---+--------+
1531
+ # | 1|{2, {3}}|
1532
+ # +---+--------+
1533
+
1534
+ cdf = self.connect.createDataFrame(data=data)
1535
+ sdf = self.spark.createDataFrame(data=data)
1536
+
1537
+ self.assertEqual(
1538
+ cdf.select("*").collect(),
1539
+ sdf.select("*").collect(),
1540
+ )
1541
+ self.assertEqual(
1542
+ cdf.select("a", "*").collect(),
1543
+ sdf.select("a", "*").collect(),
1544
+ )
1545
+ self.assertEqual(
1546
+ cdf.select("a", "b").collect(),
1547
+ sdf.select("a", "b").collect(),
1548
+ )
1549
+ self.assertEqual(
1550
+ cdf.select("a", "b.*").collect(),
1551
+ sdf.select("a", "b.*").collect(),
1552
+ )
1553
+
1554
+ def test_fill_na(self):
1555
+ # SPARK-41128: Test fill na
1556
+ query = """
1557
+ SELECT * FROM VALUES
1558
+ (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1559
+ AS tab(a, b, c)
1560
+ """
1561
+ # +-----+----+----+
1562
+ # | a| b| c|
1563
+ # +-----+----+----+
1564
+ # |false| 1|NULL|
1565
+ # |false|NULL| 2.0|
1566
+ # | NULL| 3| 3.0|
1567
+ # +-----+----+----+
1568
+
1569
+ self.assert_eq(
1570
+ self.connect.sql(query).fillna(True).toPandas(),
1571
+ self.spark.sql(query).fillna(True).toPandas(),
1572
+ )
1573
+ self.assert_eq(
1574
+ self.connect.sql(query).fillna(2).toPandas(),
1575
+ self.spark.sql(query).fillna(2).toPandas(),
1576
+ )
1577
+ self.assert_eq(
1578
+ self.connect.sql(query).fillna(2, ["a", "b"]).toPandas(),
1579
+ self.spark.sql(query).fillna(2, ["a", "b"]).toPandas(),
1580
+ )
1581
+ self.assert_eq(
1582
+ self.connect.sql(query).na.fill({"a": True, "b": 2}).toPandas(),
1583
+ self.spark.sql(query).na.fill({"a": True, "b": 2}).toPandas(),
1584
+ )
1585
+
1586
+ def test_drop_na(self):
1587
+ # SPARK-41148: Test drop na
1588
+ query = """
1589
+ SELECT * FROM VALUES
1590
+ (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1591
+ AS tab(a, b, c)
1592
+ """
1593
+ # +-----+----+----+
1594
+ # | a| b| c|
1595
+ # +-----+----+----+
1596
+ # |false| 1|NULL|
1597
+ # |false|NULL| 2.0|
1598
+ # | NULL| 3| 3.0|
1599
+ # +-----+----+----+
1600
+
1601
+ self.assert_eq(
1602
+ self.connect.sql(query).dropna().toPandas(),
1603
+ self.spark.sql(query).dropna().toPandas(),
1604
+ )
1605
+ self.assert_eq(
1606
+ self.connect.sql(query).na.drop(how="all", thresh=1).toPandas(),
1607
+ self.spark.sql(query).na.drop(how="all", thresh=1).toPandas(),
1608
+ )
1609
+ self.assert_eq(
1610
+ self.connect.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
1611
+ self.spark.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
1612
+ )
1613
+ self.assert_eq(
1614
+ self.connect.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
1615
+ self.spark.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
1616
+ )
1617
+
1618
+ def test_replace(self):
1619
+ # SPARK-41315: Test replace
1620
+ query = """
1621
+ SELECT * FROM VALUES
1622
+ (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
1623
+ AS tab(a, b, c)
1624
+ """
1625
+ # +-----+----+----+
1626
+ # | a| b| c|
1627
+ # +-----+----+----+
1628
+ # |false| 1|NULL|
1629
+ # |false|NULL| 2.0|
1630
+ # | NULL| 3| 3.0|
1631
+ # +-----+----+----+
1632
+
1633
+ self.assert_eq(
1634
+ self.connect.sql(query).replace(2, 3).toPandas(),
1635
+ self.spark.sql(query).replace(2, 3).toPandas(),
1636
+ )
1637
+ self.assert_eq(
1638
+ self.connect.sql(query).na.replace(False, True).toPandas(),
1639
+ self.spark.sql(query).na.replace(False, True).toPandas(),
1640
+ )
1641
+ self.assert_eq(
1642
+ self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
1643
+ self.spark.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
1644
+ )
1645
+ self.assert_eq(
1646
+ self.connect.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
1647
+ self.spark.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
1648
+ )
1649
+ self.assert_eq(
1650
+ self.connect.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
1651
+ self.spark.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
1652
+ )
1653
+
1654
+ with self.assertRaises(ValueError) as context:
1655
+ self.connect.sql(query).replace({None: 1}, subset="a").toPandas()
1656
+ self.assertTrue("Mixed type replacements are not supported" in str(context.exception))
1657
+
1658
+ with self.assertRaises(AnalysisException) as context:
1659
+ self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "x")).toPandas()
1660
+ self.assertIn(
1661
+ """Cannot resolve column name "x" among (a, b, c)""", str(context.exception)
1662
+ )
1663
+
1664
+ def test_unpivot(self):
1665
+ self.assert_eq(
1666
+ self.connect.read.table(self.tbl_name)
1667
+ .filter("id > 3")
1668
+ .unpivot(["id"], ["name"], "variable", "value")
1669
+ .toPandas(),
1670
+ self.spark.read.table(self.tbl_name)
1671
+ .filter("id > 3")
1672
+ .unpivot(["id"], ["name"], "variable", "value")
1673
+ .toPandas(),
1674
+ )
1675
+
1676
+ self.assert_eq(
1677
+ self.connect.read.table(self.tbl_name)
1678
+ .filter("id > 3")
1679
+ .unpivot("id", None, "variable", "value")
1680
+ .toPandas(),
1681
+ self.spark.read.table(self.tbl_name)
1682
+ .filter("id > 3")
1683
+ .unpivot("id", None, "variable", "value")
1684
+ .toPandas(),
1685
+ )
1686
+
1687
+ def test_union_by_name(self):
1688
+ # SPARK-41832: Test unionByName
1689
+ data1 = [(1, 2, 3)]
1690
+ data2 = [(6, 2, 5)]
1691
+ df1_connect = self.connect.createDataFrame(data1, ["a", "b", "c"])
1692
+ df2_connect = self.connect.createDataFrame(data2, ["a", "b", "c"])
1693
+ union_df_connect = df1_connect.unionByName(df2_connect)
1694
+
1695
+ df1_spark = self.spark.createDataFrame(data1, ["a", "b", "c"])
1696
+ df2_spark = self.spark.createDataFrame(data2, ["a", "b", "c"])
1697
+ union_df_spark = df1_spark.unionByName(df2_spark)
1698
+
1699
+ self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas())
1700
+
1701
+ df2_connect = self.connect.createDataFrame(data2, ["a", "B", "C"])
1702
+ union_df_connect = df1_connect.unionByName(df2_connect, allowMissingColumns=True)
1703
+
1704
+ df2_spark = self.spark.createDataFrame(data2, ["a", "B", "C"])
1705
+ union_df_spark = df1_spark.unionByName(df2_spark, allowMissingColumns=True)
1706
+
1707
+ self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas())
1708
+
1709
+ def test_random_split(self):
1710
+ # SPARK-41440: test randomSplit(weights, seed).
1711
+ relations = (
1712
+ self.connect.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2)
1713
+ )
1714
+ datasets = (
1715
+ self.spark.read.table(self.tbl_name).filter("id > 3").randomSplit([1.0, 2.0, 3.0], 2)
1716
+ )
1717
+
1718
+ self.assertTrue(len(relations) == len(datasets))
1719
+ i = 0
1720
+ while i < len(relations):
1721
+ self.assert_eq(relations[i].toPandas(), datasets[i].toPandas())
1722
+ i += 1
1723
+
1724
+ def test_observe(self):
1725
+ # SPARK-41527: test DataFrame.observe()
1726
+ observation_name = "my_metric"
1727
+
1728
+ self.assert_eq(
1729
+ self.connect.read.table(self.tbl_name)
1730
+ .filter("id > 3")
1731
+ .observe(observation_name, CF.min("id"), CF.max("id"), CF.sum("id"))
1732
+ .toPandas(),
1733
+ self.spark.read.table(self.tbl_name)
1734
+ .filter("id > 3")
1735
+ .observe(observation_name, SF.min("id"), SF.max("id"), SF.sum("id"))
1736
+ .toPandas(),
1737
+ )
1738
+
1739
+ from pyspark.sql.observation import Observation
1740
+
1741
+ observation = Observation(observation_name)
1742
+
1743
+ cdf = (
1744
+ self.connect.read.table(self.tbl_name)
1745
+ .filter("id > 3")
1746
+ .observe(observation, CF.min("id"), CF.max("id"), CF.sum("id"))
1747
+ .toPandas()
1748
+ )
1749
+ df = (
1750
+ self.spark.read.table(self.tbl_name)
1751
+ .filter("id > 3")
1752
+ .observe(observation, SF.min("id"), SF.max("id"), SF.sum("id"))
1753
+ .toPandas()
1754
+ )
1755
+
1756
+ self.assert_eq(cdf, df)
1757
+
1758
+ observed_metrics = cdf.attrs["observed_metrics"]
1759
+ self.assert_eq(len(observed_metrics), 1)
1760
+ self.assert_eq(observed_metrics[0].name, observation_name)
1761
+ self.assert_eq(len(observed_metrics[0].metrics), 3)
1762
+ for metric in observed_metrics[0].metrics:
1763
+ self.assertIsInstance(metric, ProtoExpression.Literal)
1764
+ values = list(map(lambda metric: metric.long, observed_metrics[0].metrics))
1765
+ self.assert_eq(values, [4, 99, 4944])
1766
+
1767
+ with self.assertRaises(PySparkValueError) as pe:
1768
+ self.connect.read.table(self.tbl_name).observe(observation_name)
1769
+
1770
+ self.check_error(
1771
+ exception=pe.exception,
1772
+ error_class="CANNOT_BE_EMPTY",
1773
+ message_parameters={"item": "exprs"},
1774
+ )
1775
+
1776
+ with self.assertRaises(PySparkTypeError) as pe:
1777
+ self.connect.read.table(self.tbl_name).observe(observation_name, CF.lit(1), "id")
1778
+
1779
+ self.check_error(
1780
+ exception=pe.exception,
1781
+ error_class="NOT_LIST_OF_COLUMN",
1782
+ message_parameters={"arg_name": "exprs"},
1783
+ )
1784
+
1785
+ def test_with_columns(self):
1786
+ # SPARK-41256: test withColumn(s).
1787
+ self.assert_eq(
1788
+ self.connect.read.table(self.tbl_name).withColumn("id", CF.lit(False)).toPandas(),
1789
+ self.spark.read.table(self.tbl_name).withColumn("id", SF.lit(False)).toPandas(),
1790
+ )
1791
+
1792
+ self.assert_eq(
1793
+ self.connect.read.table(self.tbl_name)
1794
+ .withColumns({"id": CF.lit(False), "col_not_exist": CF.lit(False)})
1795
+ .toPandas(),
1796
+ self.spark.read.table(self.tbl_name)
1797
+ .withColumns(
1798
+ {
1799
+ "id": SF.lit(False),
1800
+ "col_not_exist": SF.lit(False),
1801
+ }
1802
+ )
1803
+ .toPandas(),
1804
+ )
1805
+
1806
+ def test_hint(self):
1807
+ # SPARK-41349: Test hint
1808
+ self.assert_eq(
1809
+ self.connect.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(),
1810
+ self.spark.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(),
1811
+ )
1812
+
1813
+ # A hint with an unsupported name is ignored
1814
+ self.assert_eq(
1815
+ self.connect.read.table(self.tbl_name).hint("illegal").toPandas(),
1816
+ self.spark.read.table(self.tbl_name).hint("illegal").toPandas(),
1817
+ )
1818
+
1819
+ # Hint with all supported parameter values
1820
+ such_a_nice_list = ["itworks1", "itworks2", "itworks3"]
1821
+ self.assert_eq(
1822
+ self.connect.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(),
1823
+ self.spark.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(),
1824
+ )
1825
+
1826
+ # Hint with unsupported parameter values
1827
+ with self.assertRaises(AnalysisException):
1828
+ self.connect.read.table(self.tbl_name).hint("REPARTITION", "id+1").toPandas()
1829
+
1830
+ # Hint with unsupported parameter types
1831
+ with self.assertRaises(TypeError):
1832
+ self.connect.read.table(self.tbl_name).hint("REPARTITION", range(5)).toPandas()
1833
+
1834
+ # Hint with unsupported parameter types
1835
+ with self.assertRaises(TypeError):
1836
+ self.connect.read.table(self.tbl_name).hint(
1837
+ "my awesome hint", 1.2345, 2, such_a_nice_list, range(6)
1838
+ ).toPandas()
1839
+
1840
+ # Hint with wrong combination
1841
+ with self.assertRaises(AnalysisException):
1842
+ self.connect.read.table(self.tbl_name).hint("REPARTITION", "id", 3).toPandas()
1843
+
1844
+ def test_join_hint(self):
1845
+ cdf1 = self.connect.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
1846
+ cdf2 = self.connect.createDataFrame(
1847
+ [Row(height=80, name="Tom"), Row(height=85, name="Bob")]
1848
+ )
1849
+
1850
+ self.assertTrue(
1851
+ "BroadcastHashJoin" in cdf1.join(cdf2.hint("BROADCAST"), "name")._explain_string()
1852
+ )
1853
+ self.assertTrue("SortMergeJoin" in cdf1.join(cdf2.hint("MERGE"), "name")._explain_string())
1854
+ self.assertTrue(
1855
+ "ShuffledHashJoin" in cdf1.join(cdf2.hint("SHUFFLE_HASH"), "name")._explain_string()
1856
+ )
1857
+
1858
+ def test_different_spark_session_join_or_union(self):
1859
+ df = self.connect.range(10).limit(3)
1860
+
1861
+ spark2 = RemoteSparkSession(connection="sc://localhost")
1862
+ df2 = spark2.range(10).limit(3)
1863
+
1864
+ with self.assertRaises(SessionNotSameException) as e1:
1865
+ df.union(df2).collect()
1866
+ self.check_error(
1867
+ exception=e1.exception,
1868
+ error_class="SESSION_NOT_SAME",
1869
+ message_parameters={},
1870
+ )
1871
+
1872
+ with self.assertRaises(SessionNotSameException) as e2:
1873
+ df.unionByName(df2).collect()
1874
+ self.check_error(
1875
+ exception=e2.exception,
1876
+ error_class="SESSION_NOT_SAME",
1877
+ message_parameters={},
1878
+ )
1879
+
1880
+ with self.assertRaises(SessionNotSameException) as e3:
1881
+ df.join(df2).collect()
1882
+ self.check_error(
1883
+ exception=e3.exception,
1884
+ error_class="SESSION_NOT_SAME",
1885
+ message_parameters={},
1886
+ )
1887
+
1888
+ def test_extended_hint_types(self):
1889
+ cdf = self.connect.range(100).toDF("id")
1890
+
1891
+ cdf.hint(
1892
+ "my awesome hint",
1893
+ 1.2345,
1894
+ "what",
1895
+ ["itworks1", "itworks2", "itworks3"],
1896
+ ).show()
1897
+
1898
+ with self.assertRaises(PySparkTypeError) as pe:
1899
+ cdf.hint(
1900
+ "my awesome hint",
1901
+ 1.2345,
1902
+ "what",
1903
+ {"itworks1": "itworks2"},
1904
+ ).show()
1905
+
1906
+ self.check_error(
1907
+ exception=pe.exception,
1908
+ error_class="INVALID_ITEM_FOR_CONTAINER",
1909
+ message_parameters={
1910
+ "arg_name": "parameters",
1911
+ "allowed_types": "str, list, float, int",
1912
+ "item_type": "dict",
1913
+ },
1914
+ )
1915
+
1916
+ def test_empty_dataset(self):
1917
+ # SPARK-41005: Test arrow based collection with empty dataset.
1918
+ self.assertTrue(
1919
+ self.connect.sql("SELECT 1 AS X LIMIT 0")
1920
+ .toPandas()
1921
+ .equals(self.spark.sql("SELECT 1 AS X LIMIT 0").toPandas())
1922
+ )
1923
+ pdf = self.connect.sql("SELECT 1 AS X LIMIT 0").toPandas()
1924
+ self.assertEqual(0, len(pdf)) # empty dataset
1925
+ self.assertEqual(1, len(pdf.columns)) # one column
1926
+ self.assertEqual("X", pdf.columns[0])
1927
+
1928
+ def test_is_empty(self):
1929
+ # SPARK-41212: Test is empty
1930
+ self.assertFalse(self.connect.sql("SELECT 1 AS X").isEmpty())
1931
+ self.assertTrue(self.connect.sql("SELECT 1 AS X LIMIT 0").isEmpty())
1932
+
1933
+ def test_session(self):
1934
+ self.assertEqual(self.connect, self.connect.sql("SELECT 1").sparkSession)
1935
+
1936
+ def test_show(self):
1937
+ # SPARK-41111: Test the show method
1938
+ show_str = self.connect.sql("SELECT 1 AS X, 2 AS Y")._show_string()
1939
+ # +---+---+
1940
+ # | X| Y|
1941
+ # +---+---+
1942
+ # | 1| 2|
1943
+ # +---+---+
1944
+ expected = "+---+---+\n| X| Y|\n+---+---+\n| 1| 2|\n+---+---+\n"
1945
+ self.assertEqual(show_str, expected)
1946
+
1947
+ def test_describe(self):
1948
+ # SPARK-41403: Test the describe method
1949
+ self.assert_eq(
1950
+ self.connect.read.table(self.tbl_name).describe("id").toPandas(),
1951
+ self.spark.read.table(self.tbl_name).describe("id").toPandas(),
1952
+ )
1953
+ self.assert_eq(
1954
+ self.connect.read.table(self.tbl_name).describe("id", "name").toPandas(),
1955
+ self.spark.read.table(self.tbl_name).describe("id", "name").toPandas(),
1956
+ )
1957
+ self.assert_eq(
1958
+ self.connect.read.table(self.tbl_name).describe(["id", "name"]).toPandas(),
1959
+ self.spark.read.table(self.tbl_name).describe(["id", "name"]).toPandas(),
1960
+ )
1961
+
1962
+ def test_stat_cov(self):
1963
+ # SPARK-41067: Test the stat.cov method
1964
+ self.assertEqual(
1965
+ self.connect.read.table(self.tbl_name2).stat.cov("col1", "col3"),
1966
+ self.spark.read.table(self.tbl_name2).stat.cov("col1", "col3"),
1967
+ )
1968
+
1969
+ def test_stat_corr(self):
1970
+ # SPARK-41068: Test the stat.corr method
1971
+ self.assertEqual(
1972
+ self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3"),
1973
+ self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3"),
1974
+ )
1975
+
1976
+ self.assertEqual(
1977
+ self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"),
1978
+ self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"),
1979
+ )
1980
+
1981
+ with self.assertRaises(PySparkTypeError) as pe:
1982
+ self.connect.read.table(self.tbl_name2).stat.corr(1, "col3", "pearson")
1983
+
1984
+ self.check_error(
1985
+ exception=pe.exception,
1986
+ error_class="NOT_STR",
1987
+ message_parameters={
1988
+ "arg_name": "col1",
1989
+ "arg_type": "int",
1990
+ },
1991
+ )
1992
+
1993
+ with self.assertRaises(PySparkTypeError) as pe:
1994
+ self.connect.read.table(self.tbl_name).stat.corr("col1", 1, "pearson")
1995
+
1996
+ self.check_error(
1997
+ exception=pe.exception,
1998
+ error_class="NOT_STR",
1999
+ message_parameters={
2000
+ "arg_name": "col2",
2001
+ "arg_type": "int",
2002
+ },
2003
+ )
2004
+ with self.assertRaises(ValueError) as context:
2005
+ self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "spearman")
2006
+ self.assertTrue(
2007
+ "Currently only the calculation of the Pearson Correlation "
2008
+ + "coefficient is supported."
2009
+ in str(context.exception)
2010
+ )
2011
+
2012
+ def test_stat_approx_quantile(self):
2013
+ # SPARK-41069: Test the stat.approxQuantile method
2014
+ result = self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2015
+ ["col1", "col3"], [0.1, 0.5, 0.9], 0.1
2016
+ )
2017
+ self.assertEqual(len(result), 2)
2018
+ self.assertEqual(len(result[0]), 3)
2019
+ self.assertEqual(len(result[1]), 3)
2020
+
2021
+ result = self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2022
+ ["col1"], [0.1, 0.5, 0.9], 0.1
2023
+ )
2024
+ self.assertEqual(len(result), 1)
2025
+ self.assertEqual(len(result[0]), 3)
2026
+
2027
+ with self.assertRaises(PySparkTypeError) as pe:
2028
+ self.connect.read.table(self.tbl_name2).stat.approxQuantile(1, [0.1, 0.5, 0.9], 0.1)
2029
+
2030
+ self.check_error(
2031
+ exception=pe.exception,
2032
+ error_class="NOT_LIST_OR_STR_OR_TUPLE",
2033
+ message_parameters={
2034
+ "arg_name": "col",
2035
+ "arg_type": "int",
2036
+ },
2037
+ )
2038
+
2039
+ with self.assertRaises(PySparkTypeError) as pe:
2040
+ self.connect.read.table(self.tbl_name2).stat.approxQuantile(["col1", "col3"], 0.1, 0.1)
2041
+
2042
+ self.check_error(
2043
+ exception=pe.exception,
2044
+ error_class="NOT_LIST_OR_TUPLE",
2045
+ message_parameters={
2046
+ "arg_name": "probabilities",
2047
+ "arg_type": "float",
2048
+ },
2049
+ )
2050
+ with self.assertRaises(PySparkTypeError) as pe:
2051
+ self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2052
+ ["col1", "col3"], [-0.1], 0.1
2053
+ )
2054
+
2055
+ self.check_error(
2056
+ exception=pe.exception,
2057
+ error_class="NOT_LIST_OF_FLOAT_OR_INT",
2058
+ message_parameters={"arg_name": "probabilities", "arg_type": "float"},
2059
+ )
2060
+ with self.assertRaises(PySparkTypeError) as pe:
2061
+ self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2062
+ ["col1", "col3"], [0.1, 0.5, 0.9], "str"
2063
+ )
2064
+
2065
+ self.check_error(
2066
+ exception=pe.exception,
2067
+ error_class="NOT_FLOAT_OR_INT",
2068
+ message_parameters={
2069
+ "arg_name": "relativeError",
2070
+ "arg_type": "str",
2071
+ },
2072
+ )
2073
+ with self.assertRaises(PySparkValueError) as pe:
2074
+ self.connect.read.table(self.tbl_name2).stat.approxQuantile(
2075
+ ["col1", "col3"], [0.1, 0.5, 0.9], -0.1
2076
+ )
2077
+
2078
+ self.check_error(
2079
+ exception=pe.exception,
2080
+ error_class="NEGATIVE_VALUE",
2081
+ message_parameters={
2082
+ "arg_name": "relativeError",
2083
+ "arg_value": "-0.1",
2084
+ },
2085
+ )
2086
+
2087
+ def test_stat_freq_items(self):
2088
+ # SPARK-41065: Test the stat.freqItems method
2089
+ self.assert_eq(
2090
+ self.connect.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
2091
+ self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
2092
+ )
2093
+
2094
+ self.assert_eq(
2095
+ self.connect.read.table(self.tbl_name2)
2096
+ .stat.freqItems(["col1", "col3"], 0.4)
2097
+ .toPandas(),
2098
+ self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"], 0.4).toPandas(),
2099
+ )
2100
+
2101
+ with self.assertRaises(PySparkTypeError) as pe:
2102
+ self.connect.read.table(self.tbl_name2).stat.freqItems("col1")
2103
+
2104
+ self.check_error(
2105
+ exception=pe.exception,
2106
+ error_class="NOT_LIST_OR_TUPLE",
2107
+ message_parameters={
2108
+ "arg_name": "cols",
2109
+ "arg_type": "str",
2110
+ },
2111
+ )
2112
+
2113
+ def test_stat_sample_by(self):
2114
+ # SPARK-41069: Test stat.sample_by
2115
+
2116
+ cdf = self.connect.range(0, 100).select((CF.col("id") % 3).alias("key"))
2117
+ sdf = self.spark.range(0, 100).select((SF.col("id") % 3).alias("key"))
2118
+
2119
+ self.assert_eq(
2120
+ cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
2121
+ .groupBy("key")
2122
+ .agg(CF.count(CF.lit(1)))
2123
+ .orderBy("key")
2124
+ .toPandas(),
2125
+ sdf.sampleBy(sdf.key, fractions={0: 0.1, 1: 0.2}, seed=0)
2126
+ .groupBy("key")
2127
+ .agg(SF.count(SF.lit(1)))
2128
+ .orderBy("key")
2129
+ .toPandas(),
2130
+ )
2131
+
2132
+ with self.assertRaises(PySparkTypeError) as pe:
2133
+ cdf.stat.sampleBy(cdf.key, fractions={0: 0.1, None: 0.2}, seed=0)
2134
+
2135
+ self.check_error(
2136
+ exception=pe.exception,
2137
+ error_class="DISALLOWED_TYPE_FOR_CONTAINER",
2138
+ message_parameters={
2139
+ "arg_name": "fractions",
2140
+ "arg_type": "dict",
2141
+ "allowed_types": "float, int, str",
2142
+ "return_type": "NoneType",
2143
+ },
2144
+ )
2145
+
2146
+ with self.assertRaises(SparkConnectException):
2147
+ cdf.sampleBy(cdf.key, fractions={0: 0.1, 1: 1.2}, seed=0).show()
2148
+
2149
+ def test_repr(self):
2150
+ # SPARK-41213: Test the __repr__ method
2151
+ query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""
2152
+ self.assertEqual(
2153
+ self.connect.sql(query).__repr__(),
2154
+ self.spark.sql(query).__repr__(),
2155
+ )
2156
+
2157
+ def test_explain_string(self):
2158
+ # SPARK-41122: test explain API.
2159
+ plan_str = self.connect.sql("SELECT 1")._explain_string(extended=True)
2160
+ self.assertTrue("Parsed Logical Plan" in plan_str)
2161
+ self.assertTrue("Analyzed Logical Plan" in plan_str)
2162
+ self.assertTrue("Optimized Logical Plan" in plan_str)
2163
+ self.assertTrue("Physical Plan" in plan_str)
2164
+
2165
+ with self.assertRaises(PySparkValueError) as pe:
2166
+ self.connect.sql("SELECT 1")._explain_string(mode="unknown")
2167
+ self.check_error(
2168
+ exception=pe.exception,
2169
+ error_class="UNKNOWN_EXPLAIN_MODE",
2170
+ message_parameters={"explain_mode": "unknown"},
2171
+ )
2172
+
2173
+ def test_simple_datasource_read(self) -> None:
2174
+ writeDf = self.df_text
2175
+ tmpPath = tempfile.mkdtemp()
2176
+ shutil.rmtree(tmpPath)
2177
+ writeDf.write.text(tmpPath)
2178
+
2179
+ for schema in [
2180
+ "id STRING",
2181
+ StructType([StructField("id", StringType())]),
2182
+ ]:
2183
+ readDf = self.connect.read.format("text").schema(schema).load(path=tmpPath)
2184
+ expectResult = writeDf.collect()
2185
+ pandasResult = readDf.toPandas()
2186
+ if pandasResult is None:
2187
+ self.fail("Empty pandas dataframe")
2188
+ else:
2189
+ actualResult = pandasResult.values.tolist()
2190
+ self.assertEqual(len(expectResult), len(actualResult))
2191
+
2192
+ def test_simple_read_without_schema(self) -> None:
2193
+ """SPARK-41300: Schema not set when reading CSV."""
2194
+ writeDf = self.df_text
2195
+ tmpPath = tempfile.mkdtemp()
2196
+ shutil.rmtree(tmpPath)
2197
+ writeDf.write.csv(tmpPath, header=True)
2198
+
2199
+ readDf = self.connect.read.format("csv").option("header", True).load(path=tmpPath)
2200
+ expectResult = set(writeDf.collect())
2201
+ pandasResult = set(readDf.collect())
2202
+ self.assertEqual(expectResult, pandasResult)
2203
+
2204
+ def test_count(self) -> None:
2205
+ # SPARK-41308: test count() API.
2206
+ self.assertEqual(
2207
+ self.connect.read.table(self.tbl_name).count(),
2208
+ self.spark.read.table(self.tbl_name).count(),
2209
+ )
2210
+
2211
+ def test_simple_transform(self) -> None:
2212
+ """SPARK-41203: Support DF.transform"""
2213
+
2214
+ def transform_df(input_df: CDataFrame) -> CDataFrame:
2215
+ return input_df.select((CF.col("id") + CF.lit(10)).alias("id"))
2216
+
2217
+ df = self.connect.range(1, 100)
2218
+ result_left = df.transform(transform_df).collect()
2219
+ result_right = self.connect.range(11, 110).collect()
2220
+ self.assertEqual(result_right, result_left)
2221
+
2222
+ # Check assertion.
2223
+ with self.assertRaises(AssertionError):
2224
+ df.transform(lambda x: 2) # type: ignore
2225
+
2226
+ def test_alias(self) -> None:
2227
+ """Testing supported and unsupported alias"""
2228
+ col0 = (
2229
+ self.connect.range(1, 10)
2230
+ .select(CF.col("id").alias("name", metadata={"max": 99}))
2231
+ .schema.names[0]
2232
+ )
2233
+ self.assertEqual("name", col0)
2234
+
2235
+ with self.assertRaises(SparkConnectException) as exc:
2236
+ self.connect.range(1, 10).select(CF.col("id").alias("this", "is", "not")).collect()
2237
+ self.assertIn("(this, is, not)", str(exc.exception))
2238
+
2239
+ def test_column_regexp(self) -> None:
2240
+ # SPARK-41438: test dataframe.colRegex()
2241
+ ndf = self.connect.read.table(self.tbl_name3)
2242
+ df = self.spark.read.table(self.tbl_name3)
2243
+
2244
+ self.assert_eq(
2245
+ ndf.select(ndf.colRegex("`tes.*\n.*mn`")).toPandas(),
2246
+ df.select(df.colRegex("`tes.*\n.*mn`")).toPandas(),
2247
+ )
2248
+
2249
+ def test_repartition(self) -> None:
2250
+ # SPARK-41354: test dataframe.repartition(numPartitions)
2251
+ self.assert_eq(
2252
+ self.connect.read.table(self.tbl_name).repartition(10).toPandas(),
2253
+ self.spark.read.table(self.tbl_name).repartition(10).toPandas(),
2254
+ )
2255
+
2256
+ self.assert_eq(
2257
+ self.connect.read.table(self.tbl_name).coalesce(10).toPandas(),
2258
+ self.spark.read.table(self.tbl_name).coalesce(10).toPandas(),
2259
+ )
2260
+
2261
+ def test_repartition_by_expression(self) -> None:
2262
+ # SPARK-41354: test dataframe.repartition(expressions)
2263
+ self.assert_eq(
2264
+ self.connect.read.table(self.tbl_name).repartition(10, "id").toPandas(),
2265
+ self.spark.read.table(self.tbl_name).repartition(10, "id").toPandas(),
2266
+ )
2267
+
2268
+ self.assert_eq(
2269
+ self.connect.read.table(self.tbl_name).repartition("id").toPandas(),
2270
+ self.spark.read.table(self.tbl_name).repartition("id").toPandas(),
2271
+ )
2272
+
2273
+ # repartition with unsupported parameter values
2274
+ with self.assertRaises(AnalysisException):
2275
+ self.connect.read.table(self.tbl_name).repartition("id+1").toPandas()
2276
+
2277
+ def test_repartition_by_range(self) -> None:
2278
+ # SPARK-41354: test dataframe.repartitionByRange(expressions)
2279
+ cdf = self.connect.read.table(self.tbl_name)
2280
+ sdf = self.spark.read.table(self.tbl_name)
2281
+
2282
+ self.assert_eq(
2283
+ cdf.repartitionByRange(10, "id").toPandas(),
2284
+ sdf.repartitionByRange(10, "id").toPandas(),
2285
+ )
2286
+
2287
+ self.assert_eq(
2288
+ cdf.repartitionByRange("id").toPandas(),
2289
+ sdf.repartitionByRange("id").toPandas(),
2290
+ )
2291
+
2292
+ self.assert_eq(
2293
+ cdf.repartitionByRange(cdf.id.desc()).toPandas(),
2294
+ sdf.repartitionByRange(sdf.id.desc()).toPandas(),
2295
+ )
2296
+
2297
+ # repartitionByRange with unsupported parameter values
2298
+ with self.assertRaises(AnalysisException):
2299
+ self.connect.read.table(self.tbl_name).repartitionByRange("id+1").toPandas()
2300
+
2301
+ def test_agg_with_two_agg_exprs(self) -> None:
2302
+ # SPARK-41230: test dataframe.agg()
2303
+ self.assert_eq(
2304
+ self.connect.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
2305
+ self.spark.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
2306
+ )
2307
+
2308
+ def test_subtract(self):
2309
+ # SPARK-41453: test dataframe.subtract()
2310
+ ndf1 = self.connect.read.table(self.tbl_name)
2311
+ ndf2 = ndf1.filter("id > 3")
2312
+ df1 = self.spark.read.table(self.tbl_name)
2313
+ df2 = df1.filter("id > 3")
2314
+
2315
+ self.assert_eq(
2316
+ ndf1.subtract(ndf2).toPandas(),
2317
+ df1.subtract(df2).toPandas(),
2318
+ )
2319
+
2320
+ def test_write_operations(self):
2321
+ with tempfile.TemporaryDirectory() as d:
2322
+ df = self.connect.range(50)
2323
+ df.write.mode("overwrite").format("csv").save(d)
2324
+
2325
+ ndf = self.connect.read.schema("id int").load(d, format="csv")
2326
+ self.assertEqual(50, len(ndf.collect()))
2327
+ cd = ndf.collect()
2328
+ self.assertEqual(set(df.collect()), set(cd))
2329
+
2330
+ with tempfile.TemporaryDirectory() as d:
2331
+ df = self.connect.range(50)
2332
+ df.write.mode("overwrite").csv(d, lineSep="|")
2333
+
2334
+ ndf = self.connect.read.schema("id int").load(d, format="csv", lineSep="|")
2335
+ self.assertEqual(set(df.collect()), set(ndf.collect()))
2336
+
2337
+ df = self.connect.range(50)
2338
+ df.write.format("parquet").saveAsTable("parquet_test")
2339
+
2340
+ ndf = self.connect.read.table("parquet_test")
2341
+ self.assertEqual(set(df.collect()), set(ndf.collect()))
2342
+
2343
+ def test_writeTo_operations(self):
2344
+ # SPARK-42002: Implement DataFrameWriterV2
2345
+ import datetime
2346
+ from pyspark.sql.connect.functions import col, years, months, days, hours, bucket
2347
+
2348
+ df = self.connect.createDataFrame(
2349
+ [(1, datetime.datetime(2000, 1, 1), "foo")], ("id", "ts", "value")
2350
+ )
2351
+ writer = df.writeTo("table1")
2352
+ self.assertIsInstance(writer.option("property", "value"), DataFrameWriterV2)
2353
+ self.assertIsInstance(writer.options(property="value"), DataFrameWriterV2)
2354
+ self.assertIsInstance(writer.using("source"), DataFrameWriterV2)
2355
+ self.assertIsInstance(writer.partitionedBy(col("id")), DataFrameWriterV2)
2356
+ self.assertIsInstance(writer.tableProperty("foo", "bar"), DataFrameWriterV2)
2357
+ self.assertIsInstance(writer.partitionedBy(years("ts")), DataFrameWriterV2)
2358
+ self.assertIsInstance(writer.partitionedBy(months("ts")), DataFrameWriterV2)
2359
+ self.assertIsInstance(writer.partitionedBy(days("ts")), DataFrameWriterV2)
2360
+ self.assertIsInstance(writer.partitionedBy(hours("ts")), DataFrameWriterV2)
2361
+ self.assertIsInstance(writer.partitionedBy(bucket(11, "id")), DataFrameWriterV2)
2362
+ self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours("ts")), DataFrameWriterV2)
2363
+
2364
+ def test_agg_with_avg(self):
2365
+ # SPARK-41325: groupby.avg()
2366
+ df = (
2367
+ self.connect.range(10)
2368
+ .groupBy((CF.col("id") % CF.lit(2)).alias("moded"))
2369
+ .avg("id")
2370
+ .sort("moded")
2371
+ )
2372
+ res = df.collect()
2373
+ self.assertEqual(2, len(res))
2374
+ self.assertEqual(4.0, res[0][1])
2375
+ self.assertEqual(5.0, res[1][1])
2376
+
2377
+ # Additional GroupBy tests with 3 rows
2378
+
2379
+ df_a = self.connect.range(10).groupBy((CF.col("id") % CF.lit(3)).alias("moded"))
2380
+ df_b = self.spark.range(10).groupBy((SF.col("id") % SF.lit(3)).alias("moded"))
2381
+ self.assertEqual(
2382
+ set(df_b.agg(SF.sum("id")).collect()), set(df_a.agg(CF.sum("id")).collect())
2383
+ )
2384
+
2385
+ # Dict agg
2386
+ measures = {"id": "sum"}
2387
+ self.assertEqual(
2388
+ set(df_a.agg(measures).select("sum(id)").collect()),
2389
+ set(df_b.agg(measures).select("sum(id)").collect()),
2390
+ )
2391
+
2392
+ def test_column_cannot_be_constructed_from_string(self):
2393
+ with self.assertRaises(TypeError):
2394
+ Column("col")
2395
+
2396
+ def test_crossjoin(self):
2397
+ # SPARK-41227: Test CrossJoin
2398
+ connect_df = self.connect.read.table(self.tbl_name)
2399
+ spark_df = self.spark.read.table(self.tbl_name)
2400
+ self.assert_eq(
2401
+ set(
2402
+ connect_df.select("id")
2403
+ .join(other=connect_df.select("name"), how="cross")
2404
+ .toPandas()
2405
+ ),
2406
+ set(spark_df.select("id").join(other=spark_df.select("name"), how="cross").toPandas()),
2407
+ )
2408
+ self.assert_eq(
2409
+ set(connect_df.select("id").crossJoin(other=connect_df.select("name")).toPandas()),
2410
+ set(spark_df.select("id").crossJoin(other=spark_df.select("name")).toPandas()),
2411
+ )
2412
+
2413
+ def test_grouped_data(self):
2414
+ query = """
2415
+ SELECT * FROM VALUES
2416
+ ('James', 'Sales', 3000, 2020),
2417
+ ('Michael', 'Sales', 4600, 2020),
2418
+ ('Robert', 'Sales', 4100, 2020),
2419
+ ('Maria', 'Finance', 3000, 2020),
2420
+ ('James', 'Sales', 3000, 2019),
2421
+ ('Scott', 'Finance', 3300, 2020),
2422
+ ('Jen', 'Finance', 3900, 2020),
2423
+ ('Jeff', 'Marketing', 3000, 2020),
2424
+ ('Kumar', 'Marketing', 2000, 2020),
2425
+ ('Saif', 'Sales', 4100, 2020)
2426
+ AS T(name, department, salary, year)
2427
+ """
2428
+
2429
+ # +-------+----------+------+----+
2430
+ # | name|department|salary|year|
2431
+ # +-------+----------+------+----+
2432
+ # | James| Sales| 3000|2020|
2433
+ # |Michael| Sales| 4600|2020|
2434
+ # | Robert| Sales| 4100|2020|
2435
+ # | Maria| Finance| 3000|2020|
2436
+ # | James| Sales| 3000|2019|
2437
+ # | Scott| Finance| 3300|2020|
2438
+ # | Jen| Finance| 3900|2020|
2439
+ # | Jeff| Marketing| 3000|2020|
2440
+ # | Kumar| Marketing| 2000|2020|
2441
+ # | Saif| Sales| 4100|2020|
2442
+ # +-------+----------+------+----+
2443
+
2444
+ cdf = self.connect.sql(query)
2445
+ sdf = self.spark.sql(query)
2446
+
2447
+ # test groupby
2448
+ self.assert_eq(
2449
+ cdf.groupBy("name").agg(CF.sum(cdf.salary)).toPandas(),
2450
+ sdf.groupBy("name").agg(SF.sum(sdf.salary)).toPandas(),
2451
+ )
2452
+ self.assert_eq(
2453
+ cdf.groupBy("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
2454
+ sdf.groupBy("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
2455
+ )
2456
+
2457
+ # test rollup
2458
+ self.assert_eq(
2459
+ cdf.rollup("name").agg(CF.sum(cdf.salary)).toPandas(),
2460
+ sdf.rollup("name").agg(SF.sum(sdf.salary)).toPandas(),
2461
+ )
2462
+ self.assert_eq(
2463
+ cdf.rollup("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
2464
+ sdf.rollup("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
2465
+ )
2466
+
2467
+ # test cube
2468
+ self.assert_eq(
2469
+ cdf.cube("name").agg(CF.sum(cdf.salary)).toPandas(),
2470
+ sdf.cube("name").agg(SF.sum(sdf.salary)).toPandas(),
2471
+ )
2472
+ self.assert_eq(
2473
+ cdf.cube("name", cdf.department).agg(CF.max("year"), CF.min(cdf.salary)).toPandas(),
2474
+ sdf.cube("name", sdf.department).agg(SF.max("year"), SF.min(sdf.salary)).toPandas(),
2475
+ )
2476
+
2477
+ # test pivot
2478
+ # pivot with values
2479
+ self.assert_eq(
2480
+ cdf.groupBy("name")
2481
+ .pivot("department", ["Sales", "Marketing"])
2482
+ .agg(CF.sum(cdf.salary))
2483
+ .toPandas(),
2484
+ sdf.groupBy("name")
2485
+ .pivot("department", ["Sales", "Marketing"])
2486
+ .agg(SF.sum(sdf.salary))
2487
+ .toPandas(),
2488
+ )
2489
+ self.assert_eq(
2490
+ cdf.groupBy(cdf.name)
2491
+ .pivot("department", ["Sales", "Finance", "Marketing"])
2492
+ .agg(CF.sum(cdf.salary))
2493
+ .toPandas(),
2494
+ sdf.groupBy(sdf.name)
2495
+ .pivot("department", ["Sales", "Finance", "Marketing"])
2496
+ .agg(SF.sum(sdf.salary))
2497
+ .toPandas(),
2498
+ )
2499
+ self.assert_eq(
2500
+ cdf.groupBy(cdf.name)
2501
+ .pivot("department", ["Sales", "Finance", "Unknown"])
2502
+ .agg(CF.sum(cdf.salary))
2503
+ .toPandas(),
2504
+ sdf.groupBy(sdf.name)
2505
+ .pivot("department", ["Sales", "Finance", "Unknown"])
2506
+ .agg(SF.sum(sdf.salary))
2507
+ .toPandas(),
2508
+ )
2509
+
2510
+ # pivot without values
2511
+ self.assert_eq(
2512
+ cdf.groupBy("name").pivot("department").agg(CF.sum(cdf.salary)).toPandas(),
2513
+ sdf.groupBy("name").pivot("department").agg(SF.sum(sdf.salary)).toPandas(),
2514
+ )
2515
+
2516
+ self.assert_eq(
2517
+ cdf.groupBy("name").pivot("year").agg(CF.sum(cdf.salary)).toPandas(),
2518
+ sdf.groupBy("name").pivot("year").agg(SF.sum(sdf.salary)).toPandas(),
2519
+ )
2520
+
2521
+ # check error
2522
+ with self.assertRaisesRegex(
2523
+ Exception,
2524
+ "PIVOT after ROLLUP is not supported",
2525
+ ):
2526
+ cdf.rollup("name").pivot("department").agg(CF.sum(cdf.salary))
2527
+
2528
+ with self.assertRaisesRegex(
2529
+ Exception,
2530
+ "PIVOT after CUBE is not supported",
2531
+ ):
2532
+ cdf.cube("name").pivot("department").agg(CF.sum(cdf.salary))
2533
+
2534
+ with self.assertRaisesRegex(
2535
+ Exception,
2536
+ "Repeated PIVOT operation is not supported",
2537
+ ):
2538
+ cdf.groupBy("name").pivot("year").pivot("year").agg(CF.sum(cdf.salary))
2539
+
2540
+ with self.assertRaises(PySparkTypeError) as pe:
2541
+ cdf.groupBy("name").pivot("department", ["Sales", b"Marketing"]).agg(CF.sum(cdf.salary))
2542
+
2543
+ self.check_error(
2544
+ exception=pe.exception,
2545
+ error_class="NOT_BOOL_OR_FLOAT_OR_INT_OR_STR",
2546
+ message_parameters={
2547
+ "arg_name": "value",
2548
+ "arg_type": "bytes",
2549
+ },
2550
+ )
2551
+
2552
+ def test_numeric_aggregation(self):
2553
+ # SPARK-41737: test numeric aggregation
2554
+ query = """
2555
+ SELECT * FROM VALUES
2556
+ ('James', 'Sales', 3000, 2020),
2557
+ ('Michael', 'Sales', 4600, 2020),
2558
+ ('Robert', 'Sales', 4100, 2020),
2559
+ ('Maria', 'Finance', 3000, 2020),
2560
+ ('James', 'Sales', 3000, 2019),
2561
+ ('Scott', 'Finance', 3300, 2020),
2562
+ ('Jen', 'Finance', 3900, 2020),
2563
+ ('Jeff', 'Marketing', 3000, 2020),
2564
+ ('Kumar', 'Marketing', 2000, 2020),
2565
+ ('Saif', 'Sales', 4100, 2020)
2566
+ AS T(name, department, salary, year)
2567
+ """
2568
+
2569
+ # +-------+----------+------+----+
2570
+ # | name|department|salary|year|
2571
+ # +-------+----------+------+----+
2572
+ # | James| Sales| 3000|2020|
2573
+ # |Michael| Sales| 4600|2020|
2574
+ # | Robert| Sales| 4100|2020|
2575
+ # | Maria| Finance| 3000|2020|
2576
+ # | James| Sales| 3000|2019|
2577
+ # | Scott| Finance| 3300|2020|
2578
+ # | Jen| Finance| 3900|2020|
2579
+ # | Jeff| Marketing| 3000|2020|
2580
+ # | Kumar| Marketing| 2000|2020|
2581
+ # | Saif| Sales| 4100|2020|
2582
+ # +-------+----------+------+----+
2583
+
2584
+ cdf = self.connect.sql(query)
2585
+ sdf = self.spark.sql(query)
2586
+
2587
+ # test groupby
2588
+ self.assert_eq(
2589
+ cdf.groupBy("name").min().toPandas(),
2590
+ sdf.groupBy("name").min().toPandas(),
2591
+ )
2592
+ self.assert_eq(
2593
+ cdf.groupBy("name").min("salary").toPandas(),
2594
+ sdf.groupBy("name").min("salary").toPandas(),
2595
+ )
2596
+ self.assert_eq(
2597
+ cdf.groupBy("name").max("salary").toPandas(),
2598
+ sdf.groupBy("name").max("salary").toPandas(),
2599
+ )
2600
+ self.assert_eq(
2601
+ cdf.groupBy("name", cdf.department).avg("salary", "year").toPandas(),
2602
+ sdf.groupBy("name", sdf.department).avg("salary", "year").toPandas(),
2603
+ )
2604
+ self.assert_eq(
2605
+ cdf.groupBy("name", cdf.department).mean("salary", "year").toPandas(),
2606
+ sdf.groupBy("name", sdf.department).mean("salary", "year").toPandas(),
2607
+ )
2608
+ self.assert_eq(
2609
+ cdf.groupBy("name", cdf.department).sum("salary", "year").toPandas(),
2610
+ sdf.groupBy("name", sdf.department).sum("salary", "year").toPandas(),
2611
+ )
2612
+
2613
+ # test rollup
2614
+ self.assert_eq(
2615
+ cdf.rollup("name").max().toPandas(),
2616
+ sdf.rollup("name").max().toPandas(),
2617
+ )
2618
+ self.assert_eq(
2619
+ cdf.rollup("name").min("salary").toPandas(),
2620
+ sdf.rollup("name").min("salary").toPandas(),
2621
+ )
2622
+ self.assert_eq(
2623
+ cdf.rollup("name").max("salary").toPandas(),
2624
+ sdf.rollup("name").max("salary").toPandas(),
2625
+ )
2626
+ self.assert_eq(
2627
+ cdf.rollup("name", cdf.department).avg("salary", "year").toPandas(),
2628
+ sdf.rollup("name", sdf.department).avg("salary", "year").toPandas(),
2629
+ )
2630
+ self.assert_eq(
2631
+ cdf.rollup("name", cdf.department).mean("salary", "year").toPandas(),
2632
+ sdf.rollup("name", sdf.department).mean("salary", "year").toPandas(),
2633
+ )
2634
+ self.assert_eq(
2635
+ cdf.rollup("name", cdf.department).sum("salary", "year").toPandas(),
2636
+ sdf.rollup("name", sdf.department).sum("salary", "year").toPandas(),
2637
+ )
2638
+
2639
+ # test cube
2640
+ self.assert_eq(
2641
+ cdf.cube("name").avg().toPandas(),
2642
+ sdf.cube("name").avg().toPandas(),
2643
+ )
2644
+ self.assert_eq(
2645
+ cdf.cube("name").mean().toPandas(),
2646
+ sdf.cube("name").mean().toPandas(),
2647
+ )
2648
+ self.assert_eq(
2649
+ cdf.cube("name").min("salary").toPandas(),
2650
+ sdf.cube("name").min("salary").toPandas(),
2651
+ )
2652
+ self.assert_eq(
2653
+ cdf.cube("name").max("salary").toPandas(),
2654
+ sdf.cube("name").max("salary").toPandas(),
2655
+ )
2656
+ self.assert_eq(
2657
+ cdf.cube("name", cdf.department).avg("salary", "year").toPandas(),
2658
+ sdf.cube("name", sdf.department).avg("salary", "year").toPandas(),
2659
+ )
2660
+ self.assert_eq(
2661
+ cdf.cube("name", cdf.department).sum("salary", "year").toPandas(),
2662
+ sdf.cube("name", sdf.department).sum("salary", "year").toPandas(),
2663
+ )
2664
+
2665
+ # test pivot
2666
+ # pivot with values
2667
+ self.assert_eq(
2668
+ cdf.groupBy("name").pivot("department", ["Sales", "Marketing"]).sum().toPandas(),
2669
+ sdf.groupBy("name").pivot("department", ["Sales", "Marketing"]).sum().toPandas(),
2670
+ )
2671
+ self.assert_eq(
2672
+ cdf.groupBy("name")
2673
+ .pivot("department", ["Sales", "Marketing"])
2674
+ .min("salary")
2675
+ .toPandas(),
2676
+ sdf.groupBy("name")
2677
+ .pivot("department", ["Sales", "Marketing"])
2678
+ .min("salary")
2679
+ .toPandas(),
2680
+ )
2681
+ self.assert_eq(
2682
+ cdf.groupBy("name")
2683
+ .pivot("department", ["Sales", "Marketing"])
2684
+ .max("salary")
2685
+ .toPandas(),
2686
+ sdf.groupBy("name")
2687
+ .pivot("department", ["Sales", "Marketing"])
2688
+ .max("salary")
2689
+ .toPandas(),
2690
+ )
2691
+ self.assert_eq(
2692
+ cdf.groupBy(cdf.name)
2693
+ .pivot("department", ["Sales", "Finance", "Unknown"])
2694
+ .avg("salary", "year")
2695
+ .toPandas(),
2696
+ sdf.groupBy(sdf.name)
2697
+ .pivot("department", ["Sales", "Finance", "Unknown"])
2698
+ .avg("salary", "year")
2699
+ .toPandas(),
2700
+ )
2701
+ self.assert_eq(
2702
+ cdf.groupBy(cdf.name)
2703
+ .pivot("department", ["Sales", "Finance", "Unknown"])
2704
+ .sum("salary", "year")
2705
+ .toPandas(),
2706
+ sdf.groupBy(sdf.name)
2707
+ .pivot("department", ["Sales", "Finance", "Unknown"])
2708
+ .sum("salary", "year")
2709
+ .toPandas(),
2710
+ )
2711
+
2712
+ # pivot without values
2713
+ self.assert_eq(
2714
+ cdf.groupBy("name").pivot("department").min().toPandas(),
2715
+ sdf.groupBy("name").pivot("department").min().toPandas(),
2716
+ )
2717
+ self.assert_eq(
2718
+ cdf.groupBy("name").pivot("department").min("salary").toPandas(),
2719
+ sdf.groupBy("name").pivot("department").min("salary").toPandas(),
2720
+ )
2721
+ self.assert_eq(
2722
+ cdf.groupBy("name").pivot("department").max("salary").toPandas(),
2723
+ sdf.groupBy("name").pivot("department").max("salary").toPandas(),
2724
+ )
2725
+ self.assert_eq(
2726
+ cdf.groupBy(cdf.name).pivot("department").avg("salary", "year").toPandas(),
2727
+ sdf.groupBy(sdf.name).pivot("department").avg("salary", "year").toPandas(),
2728
+ )
2729
+ self.assert_eq(
2730
+ cdf.groupBy(cdf.name).pivot("department").sum("salary", "year").toPandas(),
2731
+ sdf.groupBy(sdf.name).pivot("department").sum("salary", "year").toPandas(),
2732
+ )
2733
+
2734
+ # check error
2735
+ with self.assertRaisesRegex(
2736
+ TypeError,
2737
+ "Numeric aggregation function can only be applied on numeric columns",
2738
+ ):
2739
+ cdf.groupBy("name").min("department").show()
2740
+
2741
+ with self.assertRaisesRegex(
2742
+ TypeError,
2743
+ "Numeric aggregation function can only be applied on numeric columns",
2744
+ ):
2745
+ cdf.groupBy("name").max("salary", "department").show()
2746
+
2747
+ with self.assertRaisesRegex(
2748
+ TypeError,
2749
+ "Numeric aggregation function can only be applied on numeric columns",
2750
+ ):
2751
+ cdf.rollup("name").avg("department").show()
2752
+
2753
+ with self.assertRaisesRegex(
2754
+ TypeError,
2755
+ "Numeric aggregation function can only be applied on numeric columns",
2756
+ ):
2757
+ cdf.rollup("name").sum("salary", "department").show()
2758
+
2759
+ with self.assertRaisesRegex(
2760
+ TypeError,
2761
+ "Numeric aggregation function can only be applied on numeric columns",
2762
+ ):
2763
+ cdf.cube("name").min("department").show()
2764
+
2765
+ with self.assertRaisesRegex(
2766
+ TypeError,
2767
+ "Numeric aggregation function can only be applied on numeric columns",
2768
+ ):
2769
+ cdf.cube("name").max("salary", "department").show()
2770
+
2771
+ with self.assertRaisesRegex(
2772
+ TypeError,
2773
+ "Numeric aggregation function can only be applied on numeric columns",
2774
+ ):
2775
+ cdf.groupBy("name").pivot("department").avg("department").show()
2776
+
2777
+ with self.assertRaisesRegex(
2778
+ TypeError,
2779
+ "Numeric aggregation function can only be applied on numeric columns",
2780
+ ):
2781
+ cdf.groupBy("name").pivot("department").sum("salary", "department").show()
2782
+
2783
+ def test_with_metadata(self):
2784
+ cdf = self.connect.createDataFrame(data=[(2, "Alice"), (5, "Bob")], schema=["age", "name"])
2785
+ self.assertEqual(cdf.schema["age"].metadata, {})
2786
+ self.assertEqual(cdf.schema["name"].metadata, {})
2787
+
2788
+ cdf1 = cdf.withMetadata(columnName="age", metadata={"max_age": 5})
2789
+ self.assertEqual(cdf1.schema["age"].metadata, {"max_age": 5})
2790
+
2791
+ cdf2 = cdf.withMetadata(columnName="name", metadata={"names": ["Alice", "Bob"]})
2792
+ self.assertEqual(cdf2.schema["name"].metadata, {"names": ["Alice", "Bob"]})
2793
+
2794
+ with self.assertRaises(PySparkTypeError) as pe:
2795
+ cdf.withMetadata(columnName="name", metadata=["magic"])
2796
+
2797
+ self.check_error(
2798
+ exception=pe.exception,
2799
+ error_class="NOT_DICT",
2800
+ message_parameters={
2801
+ "arg_name": "metadata",
2802
+ "arg_type": "list",
2803
+ },
2804
+ )
2805
+
2806
+ def test_collect_nested_type(self):
2807
+ query = """
2808
+ SELECT * FROM VALUES
2809
+ (1, 4, 0, 8, true, true, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
2810
+ (2, 5, -1, NULL, false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
2811
+ (3, 6, NULL, 0, false, NULL, ARRAY(NULL), NULL)
2812
+ AS tab(a, b, c, d, e, f, g, h)
2813
+ """
2814
+
2815
+ # +---+---+----+----+-----+----+------------+-------------------+
2816
+ # | a| b| c| d| e| f| g| h|
2817
+ # +---+---+----+----+-----+----+------------+-------------------+
2818
+ # | 1| 4| 0| 8| true|true|[1, null, 3]| {1 -> 2, 3 -> 4}|
2819
+ # | 2| 5| -1|NULL|false|NULL| [1, 3]|{1 -> null, 3 -> 4}|
2820
+ # | 3| 6|NULL| 0|false|NULL| [null]| NULL|
2821
+ # +---+---+----+----+-----+----+------------+-------------------+
2822
+
2823
+ cdf = self.connect.sql(query)
2824
+ sdf = self.spark.sql(query)
2825
+
2826
+ # test collect array
2827
+ # +--------------+-------------+------------+
2828
+ # |array(a, b, c)| array(e, f)| g|
2829
+ # +--------------+-------------+------------+
2830
+ # | [1, 4, 0]| [true, true]|[1, null, 3]|
2831
+ # | [2, 5, -1]|[false, null]| [1, 3]|
2832
+ # | [3, 6, null]|[false, null]| [null]|
2833
+ # +--------------+-------------+------------+
2834
+ self.assertEqual(
2835
+ cdf.select(CF.array("a", "b", "c"), CF.array("e", "f"), CF.col("g")).collect(),
2836
+ sdf.select(SF.array("a", "b", "c"), SF.array("e", "f"), SF.col("g")).collect(),
2837
+ )
2838
+
2839
+ # test collect nested array
2840
+ # +-----------------------------------+-------------------------+
2841
+ # |array(array(a), array(b), array(c))|array(array(e), array(f))|
2842
+ # +-----------------------------------+-------------------------+
2843
+ # | [[1], [4], [0]]| [[true], [true]]|
2844
+ # | [[2], [5], [-1]]| [[false], [null]]|
2845
+ # | [[3], [6], [null]]| [[false], [null]]|
2846
+ # +-----------------------------------+-------------------------+
2847
+ self.assertEqual(
2848
+ cdf.select(
2849
+ CF.array(CF.array("a"), CF.array("b"), CF.array("c")),
2850
+ CF.array(CF.array("e"), CF.array("f")),
2851
+ ).collect(),
2852
+ sdf.select(
2853
+ SF.array(SF.array("a"), SF.array("b"), SF.array("c")),
2854
+ SF.array(SF.array("e"), SF.array("f")),
2855
+ ).collect(),
2856
+ )
2857
+
2858
+ # test collect array of struct, map
2859
+ # +----------------+---------------------+
2860
+ # |array(struct(a))| array(h)|
2861
+ # +----------------+---------------------+
2862
+ # | [{1}]| [{1 -> 2, 3 -> 4}]|
2863
+ # | [{2}]|[{1 -> null, 3 -> 4}]|
2864
+ # | [{3}]| [null]|
2865
+ # +----------------+---------------------+
2866
+ self.assertEqual(
2867
+ cdf.select(CF.array(CF.struct("a")), CF.array("h")).collect(),
2868
+ sdf.select(SF.array(SF.struct("a")), SF.array("h")).collect(),
2869
+ )
2870
+
2871
+ # test collect map
2872
+ # +-------------------+-------------------+
2873
+ # | h| map(a, b, b, c)|
2874
+ # +-------------------+-------------------+
2875
+ # | {1 -> 2, 3 -> 4}| {1 -> 4, 4 -> 0}|
2876
+ # |{1 -> null, 3 -> 4}| {2 -> 5, 5 -> -1}|
2877
+ # | NULL|{3 -> 6, 6 -> null}|
2878
+ # +-------------------+-------------------+
2879
+ self.assertEqual(
2880
+ cdf.select(CF.col("h"), CF.create_map("a", "b", "b", "c")).collect(),
2881
+ sdf.select(SF.col("h"), SF.create_map("a", "b", "b", "c")).collect(),
2882
+ )
2883
+
2884
+ # test collect map of struct, array
2885
+ # +-------------------+------------------------+
2886
+ # | map(a, g)| map(a, struct(b, g))|
2887
+ # +-------------------+------------------------+
2888
+ # |{1 -> [1, null, 3]}|{1 -> {4, [1, null, 3]}}|
2889
+ # | {2 -> [1, 3]}| {2 -> {5, [1, 3]}}|
2890
+ # | {3 -> [null]}| {3 -> {6, [null]}}|
2891
+ # +-------------------+------------------------+
2892
+ self.assertEqual(
2893
+ cdf.select(CF.create_map("a", "g"), CF.create_map("a", CF.struct("b", "g"))).collect(),
2894
+ sdf.select(SF.create_map("a", "g"), SF.create_map("a", SF.struct("b", "g"))).collect(),
2895
+ )
2896
+
2897
+ # test collect struct
2898
+ # +------------------+--------------------------+
2899
+ # |struct(a, b, c, d)| struct(e, f, g)|
2900
+ # +------------------+--------------------------+
2901
+ # | {1, 4, 0, 8}|{true, true, [1, null, 3]}|
2902
+ # | {2, 5, -1, null}| {false, null, [1, 3]}|
2903
+ # | {3, 6, null, 0}| {false, null, [null]}|
2904
+ # +------------------+--------------------------+
2905
+ self.assertEqual(
2906
+ cdf.select(CF.struct("a", "b", "c", "d"), CF.struct("e", "f", "g")).collect(),
2907
+ sdf.select(SF.struct("a", "b", "c", "d"), SF.struct("e", "f", "g")).collect(),
2908
+ )
2909
+
2910
+ # test collect nested struct
2911
+ # +------------------------------------------+--------------------------+----------------------------+ # noqa
2912
+ # |struct(a, struct(a, struct(c, struct(d))))|struct(a, b, struct(c, d))| struct(e, f, struct(g))| # noqa
2913
+ # +------------------------------------------+--------------------------+----------------------------+ # noqa
2914
+ # | {1, {1, {0, {8}}}}| {1, 4, {0, 8}}|{true, true, {[1, null, 3]}}| # noqa
2915
+ # | {2, {2, {-1, {null}}}}| {2, 5, {-1, null}}| {false, null, {[1, 3]}}| # noqa
2916
+ # | {3, {3, {null, {0}}}}| {3, 6, {null, 0}}| {false, null, {[null]}}| # noqa
2917
+ # +------------------------------------------+--------------------------+----------------------------+ # noqa
2918
+ self.assertEqual(
2919
+ cdf.select(
2920
+ CF.struct("a", CF.struct("a", CF.struct("c", CF.struct("d")))),
2921
+ CF.struct("a", "b", CF.struct("c", "d")),
2922
+ CF.struct("e", "f", CF.struct("g")),
2923
+ ).collect(),
2924
+ sdf.select(
2925
+ SF.struct("a", SF.struct("a", SF.struct("c", SF.struct("d")))),
2926
+ SF.struct("a", "b", SF.struct("c", "d")),
2927
+ SF.struct("e", "f", SF.struct("g")),
2928
+ ).collect(),
2929
+ )
2930
+
2931
+ # test collect struct containing array, map
2932
+ # +--------------------------------------------+
2933
+ # | struct(a, struct(a, struct(g, struct(h))))|
2934
+ # +--------------------------------------------+
2935
+ # |{1, {1, {[1, null, 3], {{1 -> 2, 3 -> 4}}}}}|
2936
+ # | {2, {2, {[1, 3], {{1 -> null, 3 -> 4}}}}}|
2937
+ # | {3, {3, {[null], {null}}}}|
2938
+ # +--------------------------------------------+
2939
+ self.assertEqual(
2940
+ cdf.select(
2941
+ CF.struct("a", CF.struct("a", CF.struct("g", CF.struct("h")))),
2942
+ ).collect(),
2943
+ sdf.select(
2944
+ SF.struct("a", SF.struct("a", SF.struct("g", SF.struct("h")))),
2945
+ ).collect(),
2946
+ )
2947
+
2948
+ def test_simple_udt(self):
2949
+ from pyspark.ml.linalg import MatrixUDT, VectorUDT
2950
+
2951
+ for schema in [
2952
+ StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2953
+ StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())),
2954
+ StructType().add("key", LongType()).add("val", MapType(LongType(), PythonOnlyUDT())),
2955
+ StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2956
+ StructType().add("key", LongType()).add("vec", VectorUDT()),
2957
+ StructType().add("key", LongType()).add("mat", MatrixUDT()),
2958
+ ]:
2959
+ cdf = self.connect.createDataFrame(data=[], schema=schema)
2960
+ sdf = self.spark.createDataFrame(data=[], schema=schema)
2961
+
2962
+ self.assertEqual(cdf.schema, sdf.schema)
2963
+
2964
+ def test_simple_udt_from_read(self):
2965
+ from pyspark.ml.linalg import Matrices, Vectors
2966
+
2967
+ with tempfile.TemporaryDirectory() as d:
2968
+ path1 = f"{d}/df1.parquet"
2969
+ self.spark.createDataFrame(
2970
+ [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
2971
+ schema=StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2972
+ ).write.parquet(path1)
2973
+
2974
+ path2 = f"{d}/df2.parquet"
2975
+ self.spark.createDataFrame(
2976
+ [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
2977
+ schema=StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())),
2978
+ ).write.parquet(path2)
2979
+
2980
+ path3 = f"{d}/df3.parquet"
2981
+ self.spark.createDataFrame(
2982
+ [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
2983
+ schema=StructType()
2984
+ .add("key", LongType())
2985
+ .add("val", MapType(LongType(), PythonOnlyUDT())),
2986
+ ).write.parquet(path3)
2987
+
2988
+ path4 = f"{d}/df4.parquet"
2989
+ self.spark.createDataFrame(
2990
+ [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
2991
+ schema=StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
2992
+ ).write.parquet(path4)
2993
+
2994
+ path5 = f"{d}/df5.parquet"
2995
+ self.spark.createDataFrame(
2996
+ [Row(label=1.0, point=ExamplePoint(1.0, 2.0))]
2997
+ ).write.parquet(path5)
2998
+
2999
+ path6 = f"{d}/df6.parquet"
3000
+ self.spark.createDataFrame(
3001
+ [(Vectors.dense(1.0, 2.0, 3.0),), (Vectors.sparse(3, {1: 1.0, 2: 5.5}),)],
3002
+ ["vec"],
3003
+ ).write.parquet(path6)
3004
+
3005
+ path7 = f"{d}/df7.parquet"
3006
+ self.spark.createDataFrame(
3007
+ [
3008
+ (Matrices.dense(3, 2, [0, 1, 4, 5, 9, 10]),),
3009
+ (Matrices.sparse(1, 1, [0, 1], [0], [2.0]),),
3010
+ ],
3011
+ ["mat"],
3012
+ ).write.parquet(path7)
3013
+
3014
+ for path in [path1, path2, path3, path4, path5, path6, path7]:
3015
+ self.assertEqual(
3016
+ self.connect.read.parquet(path).schema,
3017
+ self.spark.read.parquet(path).schema,
3018
+ )
3019
+
3020
+ def test_version(self):
3021
+ self.assertEqual(
3022
+ self.connect.version,
3023
+ self.spark.version,
3024
+ )
3025
+
3026
+ def test_same_semantics(self):
3027
+ plan = self.connect.sql("SELECT 1")
3028
+ other = self.connect.sql("SELECT 1")
3029
+ self.assertTrue(plan.sameSemantics(other))
3030
+
3031
+ def test_semantic_hash(self):
3032
+ plan = self.connect.sql("SELECT 1")
3033
+ other = self.connect.sql("SELECT 1")
3034
+ self.assertEqual(
3035
+ plan.semanticHash(),
3036
+ other.semanticHash(),
3037
+ )
3038
+
3039
+ def test_unsupported_functions(self):
3040
+ # SPARK-41225: Disable unsupported functions.
3041
+ df = self.connect.read.table(self.tbl_name)
3042
+ for f in (
3043
+ "rdd",
3044
+ "foreach",
3045
+ "foreachPartition",
3046
+ "checkpoint",
3047
+ "localCheckpoint",
3048
+ ):
3049
+ with self.assertRaises(NotImplementedError):
3050
+ getattr(df, f)()
3051
+
3052
+ def test_unsupported_session_functions(self):
3053
+ # SPARK-41934: Disable unsupported functions.
3054
+
3055
+ with self.assertRaises(NotImplementedError):
3056
+ RemoteSparkSession.builder.enableHiveSupport()
3057
+
3058
+ for f in (
3059
+ "newSession",
3060
+ "sparkContext",
3061
+ ):
3062
+ with self.assertRaises(NotImplementedError):
3063
+ getattr(self.connect, f)()
3064
+
3065
+ def test_sql_with_command(self):
3066
+ # SPARK-42705: spark.sql should return values from the command.
3067
+ self.assertEqual(
3068
+ self.connect.sql("show functions").collect(), self.spark.sql("show functions").collect()
3069
+ )
3070
+
3071
+ def test_schema_has_nullable(self):
3072
+ schema_false = StructType().add("id", IntegerType(), False)
3073
+ cdf1 = self.connect.createDataFrame([[1]], schema=schema_false)
3074
+ sdf1 = self.spark.createDataFrame([[1]], schema=schema_false)
3075
+ self.assertEqual(cdf1.schema, sdf1.schema)
3076
+ self.assertEqual(cdf1.collect(), sdf1.collect())
3077
+
3078
+ schema_true = StructType().add("id", IntegerType(), True)
3079
+ cdf2 = self.connect.createDataFrame([[1]], schema=schema_true)
3080
+ sdf2 = self.spark.createDataFrame([[1]], schema=schema_true)
3081
+ self.assertEqual(cdf2.schema, sdf2.schema)
3082
+ self.assertEqual(cdf2.collect(), sdf2.collect())
3083
+
3084
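+ # Recreating the DataFrames from pandas below should preserve the declared nullability.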
+ pdf1 = cdf1.toPandas()
3085
+ cdf3 = self.connect.createDataFrame(pdf1, cdf1.schema)
3086
+ sdf3 = self.spark.createDataFrame(pdf1, sdf1.schema)
3087
+ self.assertEqual(cdf3.schema, sdf3.schema)
3088
+ self.assertEqual(cdf3.collect(), sdf3.collect())
3089
+
3090
+ pdf2 = cdf2.toPandas()
3091
+ cdf4 = self.connect.createDataFrame(pdf2, cdf2.schema)
3092
+ sdf4 = self.spark.createDataFrame(pdf2, sdf2.schema)
3093
+ self.assertEqual(cdf4.schema, sdf4.schema)
3094
+ self.assertEqual(cdf4.collect(), sdf4.collect())
3095
+
3096
+ def test_array_has_nullable(self):
3097
+ for schemas, data in [
3098
+ (
3099
+ [StructType().add("arr", ArrayType(IntegerType(), False), True)],
3100
+ [Row([1, 2]), Row([3]), Row(None)],
3101
+ ),
3102
+ (
3103
+ [
3104
+ StructType().add("arr", ArrayType(IntegerType(), True), True),
3105
+ "arr array<integer>",
3106
+ ],
3107
+ [Row([1, None]), Row([3]), Row(None)],
3108
+ ),
3109
+ (
3110
+ [StructType().add("arr", ArrayType(IntegerType(), False), False)],
3111
+ [Row([1, 2]), Row([3])],
3112
+ ),
3113
+ (
3114
+ [
3115
+ StructType().add("arr", ArrayType(IntegerType(), True), False),
3116
+ "arr array<integer> not null",
3117
+ ],
3118
+ [Row([1, None]), Row([3])],
3119
+ ),
3120
+ ]:
3121
+ for schema in schemas:
3122
+ with self.subTest(schema=schema):
3123
+ cdf = self.connect.createDataFrame(data, schema=schema)
3124
+ sdf = self.spark.createDataFrame(data, schema=schema)
3125
+ self.assertEqual(cdf.schema, sdf.schema)
3126
+ self.assertEqual(cdf.collect(), sdf.collect())
3127
+
3128
+ def test_map_has_nullable(self):
3129
+ for schemas, data in [
3130
+ (
3131
+ [StructType().add("map", MapType(StringType(), IntegerType(), False), True)],
3132
+ [Row({"a": 1, "b": 2}), Row({"a": 3}), Row(None)],
3133
+ ),
3134
+ (
3135
+ [
3136
+ StructType().add("map", MapType(StringType(), IntegerType(), True), True),
3137
+ "map map<string, integer>",
3138
+ ],
3139
+ [Row({"a": 1, "b": None}), Row({"a": 3}), Row(None)],
3140
+ ),
3141
+ (
3142
+ [StructType().add("map", MapType(StringType(), IntegerType(), False), False)],
3143
+ [Row({"a": 1, "b": 2}), Row({"a": 3})],
3144
+ ),
3145
+ (
3146
+ [
3147
+ StructType().add("map", MapType(StringType(), IntegerType(), True), False),
3148
+ "map map<string, integer> not null",
3149
+ ],
3150
+ [Row({"a": 1, "b": None}), Row({"a": 3})],
3151
+ ),
3152
+ ]:
3153
+ for schema in schemas:
3154
+ with self.subTest(schema=schema):
3155
+ cdf = self.connect.createDataFrame(data, schema=schema)
3156
+ sdf = self.spark.createDataFrame(data, schema=schema)
3157
+ self.assertEqual(cdf.schema, sdf.schema)
3158
+ self.assertEqual(cdf.collect(), sdf.collect())
3159
+
3160
+ def test_struct_has_nullable(self):
3161
+ for schemas, data in [
3162
+ (
3163
+ [
3164
+ StructType().add("struct", StructType().add("i", IntegerType(), False), True),
3165
+ "struct struct<i: integer not null>",
3166
+ ],
3167
+ [Row(Row(1)), Row(Row(2)), Row(None)],
3168
+ ),
3169
+ (
3170
+ [
3171
+ StructType().add("struct", StructType().add("i", IntegerType(), True), True),
3172
+ "struct struct<i: integer>",
3173
+ ],
3174
+ [Row(Row(1)), Row(Row(2)), Row(Row(None)), Row(None)],
3175
+ ),
3176
+ (
3177
+ [
3178
+ StructType().add("struct", StructType().add("i", IntegerType(), False), False),
3179
+ "struct struct<i: integer not null> not null",
3180
+ ],
3181
+ [Row(Row(1)), Row(Row(2))],
3182
+ ),
3183
+ (
3184
+ [
3185
+ StructType().add("struct", StructType().add("i", IntegerType(), True), False),
3186
+ "struct struct<i: integer> not null",
3187
+ ],
3188
+ [Row(Row(1)), Row(Row(2)), Row(Row(None))],
3189
+ ),
3190
+ ]:
3191
+ for schema in schemas:
3192
+ with self.subTest(schema=schema):
3193
+ cdf = self.connect.createDataFrame(data, schema=schema)
3194
+ sdf = self.spark.createDataFrame(data, schema=schema)
3195
+ self.assertEqual(cdf.schema, sdf.schema)
3196
+ self.assertEqual(cdf.collect(), sdf.collect())
3197
+
3198
+ def test_large_client_data(self):
3199
+ # SPARK-42816 support more than 4MB message size.
3200
+ # ~200 bytes per row
3201
+ cols = ["abcdefghijklmnoprstuvwxyz" for x in range(10)]
3202
+ # 100k rows => 20MB
3203
+ row_count = 100 * 1000
3204
+ rows = [cols] * row_count
3205
+ self.assertEqual(row_count, self.connect.createDataFrame(data=rows).count())
3206
+
3207
+ def test_unsupported_jvm_attribute(self):
3208
+ # Unsupported jvm attributes for Spark session.
3209
+ unsupported_attrs = ["_jsc", "_jconf", "_jvm", "_jsparkSession"]
3210
+ spark_session = self.connect
3211
+ for attr in unsupported_attrs:
3212
+ with self.assertRaises(PySparkAttributeError) as pe:
3213
+ getattr(spark_session, attr)
3214
+
3215
+ self.check_error(
3216
+ exception=pe.exception,
3217
+ error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3218
+ message_parameters={"attr_name": attr},
3219
+ )
3220
+
3221
+ # Unsupported jvm attributes for DataFrame.
3222
+ unsupported_attrs = ["_jseq", "_jdf", "_jmap", "_jcols"]
3223
+ cdf = self.connect.range(10)
3224
+ for attr in unsupported_attrs:
3225
+ with self.assertRaises(PySparkAttributeError) as pe:
3226
+ getattr(cdf, attr)
3227
+
3228
+ self.check_error(
3229
+ exception=pe.exception,
3230
+ error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3231
+ message_parameters={"attr_name": attr},
3232
+ )
3233
+
3234
+ # Unsupported jvm attributes for Column.
3235
+ with self.assertRaises(PySparkAttributeError) as pe:
3236
+ getattr(cdf.id, "_jc")
3237
+
3238
+ self.check_error(
3239
+ exception=pe.exception,
3240
+ error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3241
+ message_parameters={"attr_name": "_jc"},
3242
+ )
3243
+
3244
+ # Unsupported jvm attributes for DataFrameReader.
3245
+ with self.assertRaises(PySparkAttributeError) as pe:
3246
+ getattr(spark_session.read, "_jreader")
3247
+
3248
+ self.check_error(
3249
+ exception=pe.exception,
3250
+ error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
3251
+ message_parameters={"attr_name": "_jreader"},
3252
+ )
3253
+
3254
+ def test_df_cache(self):
3255
+ df = self.connect.range(10)
3256
+ df.cache()
3257
+ self.assert_eq(10, df.count())
3258
+ self.assertTrue(df.is_cached)
3259
+
3260
+
3261
+ @unittest.skipIf(
3262
+ "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Session creation different from local mode"
3263
+ )
3264
+ class SparkConnectSessionTests(ReusedConnectTestCase):
3265
+ def setUp(self) -> None:
3266
+ self.spark = (
3267
+ PySparkSession.builder.config(conf=self.conf())
3268
+ .appName(self.__class__.__name__)
3269
+ .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
3270
+ .getOrCreate()
3271
+ )
3272
+
3273
+ def tearDown(self):
3274
+ self.spark.stop()
3275
+
3276
+ def _check_no_active_session_error(self, e: PySparkException):
3277
+ self.check_error(exception=e, error_class="NO_ACTIVE_SESSION", message_parameters=dict())
3278
+
3279
+ def test_stop_session(self):
3280
+ df = self.spark.sql("select 1 as a, 2 as b")
3281
+ catalog = self.spark.catalog
3282
+ self.spark.stop()
3283
+
3284
+ # _execute_and_fetch
3285
+ with self.assertRaises(SparkConnectException) as e:
3286
+ self.spark.sql("select 1")
3287
+ self._check_no_active_session_error(e.exception)
3288
+
3289
+ with self.assertRaises(SparkConnectException) as e:
3290
+ catalog.tableExists("table")
3291
+ self._check_no_active_session_error(e.exception)
3292
+
3293
+ # _execute
3294
+ with self.assertRaises(SparkConnectException) as e:
3295
+ self.spark.udf.register("test_func", lambda x: x + 1)
3296
+ self._check_no_active_session_error(e.exception)
3297
+
3298
+ # _analyze
3299
+ with self.assertRaises(SparkConnectException) as e:
3300
+ df._explain_string(extended=True)
3301
+ self._check_no_active_session_error(e.exception)
3302
+
3303
+ # Config
3304
+ with self.assertRaises(SparkConnectException) as e:
3305
+ self.spark.conf.get("some.conf")
3306
+ self._check_no_active_session_error(e.exception)
3307
+
3308
+ def test_error_stack_trace(self):
3309
+ with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
3310
+ with self.assertRaises(AnalysisException) as e:
3311
+ self.spark.sql("select x").collect()
3312
+ self.assertTrue("JVM stacktrace" in e.exception.message)
3313
+ self.assertTrue(
3314
+ "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
3315
+ )
3316
+
3317
+ with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}):
3318
+ with self.assertRaises(AnalysisException) as e:
3319
+ self.spark.sql("select x").collect()
3320
+ self.assertFalse("JVM stacktrace" in e.exception.message)
3321
+ self.assertFalse(
3322
+ "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
3323
+ )
3324
+
3325
+ # Create a new session with a different stack trace size.
3326
+ self.spark.stop()
3327
+ spark = (
3328
+ PySparkSession.builder.config(conf=self.conf())
3329
+ .config("spark.connect.jvmStacktrace.maxSize", 128)
3330
+ .remote("local[4]")
3331
+ .getOrCreate()
3332
+ )
3333
+ spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
3334
+ with self.assertRaises(AnalysisException) as e:
3335
+ spark.sql("select x").collect()
3336
+ self.assertTrue("JVM stacktrace" in e.exception.message)
3337
+ self.assertFalse(
3338
+ "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
3339
+ )
3340
+ spark.stop()
3341
+
3342
+ def test_can_create_multiple_sessions_to_different_remotes(self):
3343
+ self.spark.stop()
3344
+ self.assertIsNotNone(self.spark._client)
3345
+ # Creates a new remote session.
3346
+ other = PySparkSession.builder.remote("sc://other.remote:114/").create()
3347
+ self.assertNotEqual(self.spark, other)
3348
+
3349
+ # Gets currently active session.
3350
+ same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate()
3351
+ self.assertEqual(other, same)
3352
+ same.stop()
3353
+
3354
+ # Make sure the environment is clean.
3355
+ self.spark.stop()
3356
+ with self.assertRaises(RuntimeError) as e:
3357
+ PySparkSession.builder.create()
3358
+ self.assertIn("Create a new SparkSession is only supported with SparkConnect.", str(e))
3359
+
3360
+
3361
+ @unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access")
3362
+ class SparkConnectSessionWithOptionsTest(unittest.TestCase):
3363
+ def setUp(self) -> None:
3364
+ self.spark = (
3365
+ PySparkSession.builder.config("string", "foo")
3366
+ .config("integer", 1)
3367
+ .config("boolean", False)
3368
+ .appName(self.__class__.__name__)
3369
+ .remote("local[4]")
3370
+ .getOrCreate()
3371
+ )
3372
+
3373
+ def tearDown(self):
3374
+ self.spark.stop()
3375
+
3376
+ def test_config(self):
3377
+ # Config
3378
+ self.assertEqual(self.spark.conf.get("string"), "foo")
3379
+ self.assertEqual(self.spark.conf.get("boolean"), "false")
3380
+ self.assertEqual(self.spark.conf.get("integer"), "1")
3381
+
3382
+
3383
+ @unittest.skipIf(not should_test_connect, connect_requirement_message)
3384
+ class ClientTests(unittest.TestCase):
3385
+ def test_retry_error_handling(self):
3386
+ # Helper class for wrapping the test.
3387
+ class TestError(grpc.RpcError, Exception):
3388
+ def __init__(self, code: grpc.StatusCode):
3389
+ self._code = code
3390
+
3391
+ def code(self):
3392
+ return self._code
3393
+
3394
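+ # Raises TestError with the given status code until the attempt count in 'w' reaches 'retries', then succeeds.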
+ def stub(retries, w, code):
3395
+ w["attempts"] += 1
3396
+ if w["attempts"] < retries:
3397
+ w["raised"] += 1
3398
+ raise TestError(code)
3399
+
3400
+ # Check that max_retries 1 is only one retry so two attempts.
3401
+ call_wrap = defaultdict(int)
3402
+ for attempt in Retrying(
3403
+ can_retry=lambda x: True,
3404
+ max_retries=1,
3405
+ backoff_multiplier=1,
3406
+ initial_backoff=1,
3407
+ max_backoff=10,
3408
+ jitter=0,
3409
+ min_jitter_threshold=0,
3410
+ ):
3411
+ with attempt:
3412
+ stub(2, call_wrap, grpc.StatusCode.INTERNAL)
3413
+
3414
+ self.assertEqual(2, call_wrap["attempts"])
3415
+ self.assertEqual(1, call_wrap["raised"])
3416
+
3417
+ # Check that if we have less than 4 retries all is ok.
3418
+ call_wrap = defaultdict(int)
3419
+ for attempt in Retrying(
3420
+ can_retry=lambda x: True,
3421
+ max_retries=4,
3422
+ backoff_multiplier=1,
3423
+ initial_backoff=1,
3424
+ max_backoff=10,
3425
+ jitter=0,
3426
+ min_jitter_threshold=0,
3427
+ ):
3428
+ with attempt:
3429
+ stub(2, call_wrap, grpc.StatusCode.INTERNAL)
3430
+
3431
+ self.assertTrue(call_wrap["attempts"] < 4)
3432
+ self.assertEqual(call_wrap["raised"], 1)
3433
+
3434
+ # Exceed the retries.
3435
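+ # With max_retries=2 the failure on the third attempt is not retried, so it propagates (3 attempts, 3 raises).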
+ call_wrap = defaultdict(int)
3436
+ with self.assertRaises(TestError):
3437
+ for attempt in Retrying(
3438
+ can_retry=lambda x: True,
3439
+ max_retries=2,
3440
+ max_backoff=50,
3441
+ backoff_multiplier=1,
3442
+ initial_backoff=50,
3443
+ jitter=0,
3444
+ min_jitter_threshold=0,
3445
+ ):
3446
+ with attempt:
3447
+ stub(5, call_wrap, grpc.StatusCode.INTERNAL)
3448
+
3449
+ self.assertTrue(call_wrap["attempts"] < 5)
3450
+ self.assertEqual(call_wrap["raised"], 3)
3451
+
3452
+ # Check that only specific exceptions are retried.
3453
+ # Check that if we have less than 4 retries all is ok.
3454
+ call_wrap = defaultdict(int)
3455
+ for attempt in Retrying(
3456
+ can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
3457
+ max_retries=4,
3458
+ backoff_multiplier=1,
3459
+ initial_backoff=1,
3460
+ max_backoff=10,
3461
+ jitter=0,
3462
+ min_jitter_threshold=0,
3463
+ ):
3464
+ with attempt:
3465
+ stub(2, call_wrap, grpc.StatusCode.UNAVAILABLE)
3466
+
3467
+ self.assertTrue(call_wrap["attempts"] < 4)
3468
+ self.assertEqual(call_wrap["raised"], 1)
3469
+
3470
+ # Exceed the retries.
3471
+ call_wrap = defaultdict(int)
3472
+ with self.assertRaises(TestError):
3473
+ for attempt in Retrying(
3474
+ can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
3475
+ max_retries=2,
3476
+ max_backoff=50,
3477
+ backoff_multiplier=1,
3478
+ initial_backoff=50,
3479
+ jitter=0,
3480
+ min_jitter_threshold=0,
3481
+ ):
3482
+ with attempt:
3483
+ stub(5, call_wrap, grpc.StatusCode.UNAVAILABLE)
3484
+
3485
+ self.assertTrue(call_wrap["attempts"] < 4)
3486
+ self.assertEqual(call_wrap["raised"], 3)
3487
+
3488
+ # Test that another error is always thrown.
3489
+ call_wrap = defaultdict(int)
3490
+ with self.assertRaises(TestError):
3491
+ for attempt in Retrying(
3492
+ can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
3493
+ max_retries=4,
3494
+ backoff_multiplier=1,
3495
+ initial_backoff=1,
3496
+ max_backoff=10,
3497
+ jitter=0,
3498
+ min_jitter_threshold=0,
3499
+ ):
3500
+ with attempt:
3501
+ stub(5, call_wrap, grpc.StatusCode.INTERNAL)
3502
+
3503
+ self.assertEqual(call_wrap["attempts"], 1)
3504
+ self.assertEqual(call_wrap["raised"], 1)
3505
+
3506
+
3507
+ @unittest.skipIf(not should_test_connect, connect_requirement_message)
3508
+ class ChannelBuilderTests(unittest.TestCase):
3509
+ def test_invalid_connection_strings(self):
3510
+ invalid = [
3511
+ "scc://host:12",
3512
+ "http://host",
3513
+ "sc:/host:1234/path",
3514
+ "sc://host/path",
3515
+ "sc://host/;parm1;param2",
3516
+ ]
3517
+ for i in invalid:
3518
+ self.assertRaises(PySparkValueError, ChannelBuilder, i)
3519
+
3520
+ def test_sensible_defaults(self):
3521
+ chan = ChannelBuilder("sc://host")
3522
+ self.assertFalse(chan.secure, "Default URL is not secure")
3523
+
3524
+ chan = ChannelBuilder("sc://host/;token=abcs")
3525
+ self.assertTrue(chan.secure, "specifying a token must set the channel to secure")
3526
+ self.assertRegex(
3527
+ chan.userAgent, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$"
3528
+ )
3529
+ chan = ChannelBuilder("sc://host/;use_ssl=abcs")
3530
+ self.assertFalse(chan.secure, "Garbage in, false out")
3531
+
3532
+ def test_user_agent(self):
3533
+ chan = ChannelBuilder("sc://host/;user_agent=Agent123%20%2F3.4")
3534
+ self.assertIn("Agent123 /3.4", chan.userAgent)
3535
+
3536
+ def test_user_agent_len(self):
3537
+ user_agent = "x" * 2049
3538
+ chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
3539
+ with self.assertRaises(SparkConnectException) as err:
3540
+ chan.userAgent
3541
+ self.assertRegex(err.exception.message, "'user_agent' parameter should not exceed")
3542
+
3543
+ user_agent = "%C3%A4" * 341 # "%C3%A4" -> "ä"; (341 * 6 = 2046) < 2048
3544
+ expected = "ä" * 341
3545
+ chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
3546
+ self.assertIn(expected, chan.userAgent)
3547
+
3548
+ def test_valid_channel_creation(self):
3549
+ chan = ChannelBuilder("sc://host").toChannel()
3550
+ self.assertIsInstance(chan, grpc.Channel)
3551
+
3552
+ # Sets up a secure channel when ssl is used together with a token.
3553
+ chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc").toChannel()
3554
+ self.assertIsInstance(chan, grpc.Channel)
3555
+
3556
+ chan = ChannelBuilder("sc://host/;use_ssl=true").toChannel()
3557
+ self.assertIsInstance(chan, grpc.Channel)
3558
+
3559
+ def test_channel_properties(self):
3560
+ chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;user_agent=foo;param1=120%2021")
3561
+ self.assertEqual("host:15002", chan.endpoint)
3562
+ self.assertIn("foo", chan.userAgent.split(" "))
3563
+ self.assertEqual(True, chan.secure)
3564
+ self.assertEqual("120 21", chan.get("param1"))
3565
+
3566
+ def test_metadata(self):
3567
+ chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;param1=120%2021;x-my-header=abcd")
3568
+ md = chan.metadata()
3569
+ self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], md)
3570
+
3571
+ def test_session_id(self):
3572
+ id = str(uuid.uuid4())
3573
+ chan = ChannelBuilder(f"sc://host/;session_id={id}")
3574
+ self.assertEqual(id, chan.session_id)
3575
+
3576
+ chan = ChannelBuilder(f"sc://host/;session_id={id};user_agent=acbd;token=abcd;use_ssl=true")
3577
+ md = chan.metadata()
3578
+ for kv in md:
3579
+ self.assertNotIn(
3580
+ kv[0],
3581
+ [
3582
+ ChannelBuilder.PARAM_SESSION_ID,
3583
+ ChannelBuilder.PARAM_TOKEN,
3584
+ ChannelBuilder.PARAM_USER_ID,
3585
+ ChannelBuilder.PARAM_USER_AGENT,
3586
+ ChannelBuilder.PARAM_USE_SSL,
3587
+ ],
3588
+ "Metadata must not contain fixed params",
3589
+ )
3590
+
3591
+ with self.assertRaises(ValueError) as ve:
3592
+ chan = ChannelBuilder("sc://host/;session_id=abcd")
3593
+ SparkConnectClient(chan)
3594
+ self.assertIn(
3595
+ "Parameter value 'session_id' must be a valid UUID format.", str(ve.exception)
3596
+ )
3597
+
3598
+ chan = ChannelBuilder("sc://host/")
3599
+ self.assertIsNone(chan.session_id)
3600
+
3601
+
3602
+ if __name__ == "__main__":
3603
+ from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401
3604
+
3605
+ try:
3606
+ import xmlrunner
3607
+
3608
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
3609
+ except ImportError:
3610
+ testRunner = None
3611
+
3612
+ unittest.main(testRunner=testRunner, verbosity=2)