snowpark-connect 0.20.2 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect has been flagged as potentially problematic.

Files changed (879)
  1. snowflake/snowpark_connect/__init__.py +23 -0
  2. snowflake/snowpark_connect/analyze_plan/__init__.py +3 -0
  3. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +38 -0
  4. snowflake/snowpark_connect/column_name_handler.py +735 -0
  5. snowflake/snowpark_connect/config.py +576 -0
  6. snowflake/snowpark_connect/constants.py +47 -0
  7. snowflake/snowpark_connect/control_server.py +52 -0
  8. snowflake/snowpark_connect/dataframe_name_handler.py +54 -0
  9. snowflake/snowpark_connect/date_time_format_mapping.py +399 -0
  10. snowflake/snowpark_connect/empty_dataframe.py +18 -0
  11. snowflake/snowpark_connect/error/__init__.py +11 -0
  12. snowflake/snowpark_connect/error/error_mapping.py +6174 -0
  13. snowflake/snowpark_connect/error/error_utils.py +321 -0
  14. snowflake/snowpark_connect/error/exceptions.py +24 -0
  15. snowflake/snowpark_connect/execute_plan/__init__.py +3 -0
  16. snowflake/snowpark_connect/execute_plan/map_execution_command.py +204 -0
  17. snowflake/snowpark_connect/execute_plan/map_execution_root.py +173 -0
  18. snowflake/snowpark_connect/execute_plan/utils.py +183 -0
  19. snowflake/snowpark_connect/expression/__init__.py +3 -0
  20. snowflake/snowpark_connect/expression/literal.py +90 -0
  21. snowflake/snowpark_connect/expression/map_cast.py +343 -0
  22. snowflake/snowpark_connect/expression/map_expression.py +293 -0
  23. snowflake/snowpark_connect/expression/map_extension.py +104 -0
  24. snowflake/snowpark_connect/expression/map_sql_expression.py +633 -0
  25. snowflake/snowpark_connect/expression/map_udf.py +142 -0
  26. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +241 -0
  27. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +85 -0
  28. snowflake/snowpark_connect/expression/map_unresolved_function.py +9450 -0
  29. snowflake/snowpark_connect/expression/map_unresolved_star.py +218 -0
  30. snowflake/snowpark_connect/expression/map_update_fields.py +164 -0
  31. snowflake/snowpark_connect/expression/map_window_function.py +258 -0
  32. snowflake/snowpark_connect/expression/typer.py +125 -0
  33. snowflake/snowpark_connect/includes/__init__.py +0 -0
  34. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  35. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  36. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  37. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  39. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/hadoop-client-api-3.3.4.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/scala-compiler-2.12.18.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-kubernetes_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-mllib_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-streaming_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/python/__init__.py +21 -0
  99. snowflake/snowpark_connect/includes/python/pyspark/__init__.py +173 -0
  100. snowflake/snowpark_connect/includes/python/pyspark/_globals.py +71 -0
  101. snowflake/snowpark_connect/includes/python/pyspark/_typing.pyi +43 -0
  102. snowflake/snowpark_connect/includes/python/pyspark/accumulators.py +341 -0
  103. snowflake/snowpark_connect/includes/python/pyspark/broadcast.py +383 -0
  104. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/__init__.py +8 -0
  105. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle.py +948 -0
  106. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle_fast.py +844 -0
  107. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/compat.py +18 -0
  108. snowflake/snowpark_connect/includes/python/pyspark/conf.py +276 -0
  109. snowflake/snowpark_connect/includes/python/pyspark/context.py +2601 -0
  110. snowflake/snowpark_connect/includes/python/pyspark/daemon.py +218 -0
  111. snowflake/snowpark_connect/includes/python/pyspark/errors/__init__.py +70 -0
  112. snowflake/snowpark_connect/includes/python/pyspark/errors/error_classes.py +889 -0
  113. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/__init__.py +16 -0
  114. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/base.py +228 -0
  115. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/captured.py +307 -0
  116. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/connect.py +190 -0
  117. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/__init__.py +16 -0
  118. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/test_errors.py +60 -0
  119. snowflake/snowpark_connect/includes/python/pyspark/errors/utils.py +116 -0
  120. snowflake/snowpark_connect/includes/python/pyspark/files.py +165 -0
  121. snowflake/snowpark_connect/includes/python/pyspark/find_spark_home.py +95 -0
  122. snowflake/snowpark_connect/includes/python/pyspark/install.py +203 -0
  123. snowflake/snowpark_connect/includes/python/pyspark/instrumentation_utils.py +190 -0
  124. snowflake/snowpark_connect/includes/python/pyspark/java_gateway.py +248 -0
  125. snowflake/snowpark_connect/includes/python/pyspark/join.py +118 -0
  126. snowflake/snowpark_connect/includes/python/pyspark/ml/__init__.py +71 -0
  127. snowflake/snowpark_connect/includes/python/pyspark/ml/_typing.pyi +84 -0
  128. snowflake/snowpark_connect/includes/python/pyspark/ml/base.py +414 -0
  129. snowflake/snowpark_connect/includes/python/pyspark/ml/classification.py +4332 -0
  130. snowflake/snowpark_connect/includes/python/pyspark/ml/clustering.py +2188 -0
  131. snowflake/snowpark_connect/includes/python/pyspark/ml/common.py +146 -0
  132. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/__init__.py +44 -0
  133. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/base.py +346 -0
  134. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/classification.py +382 -0
  135. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/evaluation.py +291 -0
  136. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/feature.py +258 -0
  137. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/functions.py +77 -0
  138. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/io_utils.py +335 -0
  139. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/pipeline.py +262 -0
  140. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/summarizer.py +120 -0
  141. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/tuning.py +579 -0
  142. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/util.py +173 -0
  143. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/__init__.py +16 -0
  144. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/deepspeed_distributor.py +165 -0
  145. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py +306 -0
  146. snowflake/snowpark_connect/includes/python/pyspark/ml/dl_util.py +150 -0
  147. snowflake/snowpark_connect/includes/python/pyspark/ml/evaluation.py +1166 -0
  148. snowflake/snowpark_connect/includes/python/pyspark/ml/feature.py +7474 -0
  149. snowflake/snowpark_connect/includes/python/pyspark/ml/fpm.py +543 -0
  150. snowflake/snowpark_connect/includes/python/pyspark/ml/functions.py +842 -0
  151. snowflake/snowpark_connect/includes/python/pyspark/ml/image.py +271 -0
  152. snowflake/snowpark_connect/includes/python/pyspark/ml/linalg/__init__.py +1382 -0
  153. snowflake/snowpark_connect/includes/python/pyspark/ml/model_cache.py +55 -0
  154. snowflake/snowpark_connect/includes/python/pyspark/ml/param/__init__.py +602 -0
  155. snowflake/snowpark_connect/includes/python/pyspark/ml/param/_shared_params_code_gen.py +368 -0
  156. snowflake/snowpark_connect/includes/python/pyspark/ml/param/shared.py +878 -0
  157. snowflake/snowpark_connect/includes/python/pyspark/ml/pipeline.py +451 -0
  158. snowflake/snowpark_connect/includes/python/pyspark/ml/recommendation.py +748 -0
  159. snowflake/snowpark_connect/includes/python/pyspark/ml/regression.py +3335 -0
  160. snowflake/snowpark_connect/includes/python/pyspark/ml/stat.py +523 -0
  161. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/__init__.py +16 -0
  162. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_classification.py +53 -0
  163. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_evaluation.py +50 -0
  164. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_feature.py +43 -0
  165. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_function.py +114 -0
  166. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_pipeline.py +47 -0
  167. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_summarizer.py +43 -0
  168. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_tuning.py +46 -0
  169. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py +238 -0
  170. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py +194 -0
  171. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +156 -0
  172. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py +184 -0
  173. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py +78 -0
  174. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +292 -0
  175. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +50 -0
  176. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +152 -0
  177. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_algorithms.py +456 -0
  178. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_base.py +96 -0
  179. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_dl_util.py +186 -0
  180. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_evaluation.py +77 -0
  181. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_feature.py +401 -0
  182. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_functions.py +528 -0
  183. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_image.py +82 -0
  184. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_linalg.py +409 -0
  185. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_model_cache.py +55 -0
  186. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_param.py +441 -0
  187. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_persistence.py +546 -0
  188. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_pipeline.py +71 -0
  189. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_stat.py +52 -0
  190. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_training_summary.py +494 -0
  191. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_util.py +85 -0
  192. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_wrapper.py +138 -0
  193. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/__init__.py +16 -0
  194. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_basic.py +151 -0
  195. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_nested.py +97 -0
  196. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py +143 -0
  197. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tuning.py +551 -0
  198. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_basic.py +137 -0
  199. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_nested.py +96 -0
  200. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_pipeline.py +142 -0
  201. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/__init__.py +16 -0
  202. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/data.py +100 -0
  203. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/distributor.py +1133 -0
  204. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/log_communication.py +198 -0
  205. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/__init__.py +16 -0
  206. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_data_loader.py +137 -0
  207. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_distributor.py +561 -0
  208. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_log_communication.py +172 -0
  209. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/torch_run_process_wrapper.py +83 -0
  210. snowflake/snowpark_connect/includes/python/pyspark/ml/tree.py +434 -0
  211. snowflake/snowpark_connect/includes/python/pyspark/ml/tuning.py +1741 -0
  212. snowflake/snowpark_connect/includes/python/pyspark/ml/util.py +749 -0
  213. snowflake/snowpark_connect/includes/python/pyspark/ml/wrapper.py +465 -0
  214. snowflake/snowpark_connect/includes/python/pyspark/mllib/__init__.py +44 -0
  215. snowflake/snowpark_connect/includes/python/pyspark/mllib/_typing.pyi +33 -0
  216. snowflake/snowpark_connect/includes/python/pyspark/mllib/classification.py +989 -0
  217. snowflake/snowpark_connect/includes/python/pyspark/mllib/clustering.py +1318 -0
  218. snowflake/snowpark_connect/includes/python/pyspark/mllib/common.py +174 -0
  219. snowflake/snowpark_connect/includes/python/pyspark/mllib/evaluation.py +691 -0
  220. snowflake/snowpark_connect/includes/python/pyspark/mllib/feature.py +1085 -0
  221. snowflake/snowpark_connect/includes/python/pyspark/mllib/fpm.py +233 -0
  222. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/__init__.py +1653 -0
  223. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/distributed.py +1662 -0
  224. snowflake/snowpark_connect/includes/python/pyspark/mllib/random.py +698 -0
  225. snowflake/snowpark_connect/includes/python/pyspark/mllib/recommendation.py +389 -0
  226. snowflake/snowpark_connect/includes/python/pyspark/mllib/regression.py +1067 -0
  227. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/KernelDensity.py +59 -0
  228. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/__init__.py +34 -0
  229. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/_statistics.py +409 -0
  230. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/distribution.py +39 -0
  231. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/test.py +86 -0
  232. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/__init__.py +16 -0
  233. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_algorithms.py +353 -0
  234. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_feature.py +192 -0
  235. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_linalg.py +680 -0
  236. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_stat.py +206 -0
  237. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_streaming_algorithms.py +471 -0
  238. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_util.py +108 -0
  239. snowflake/snowpark_connect/includes/python/pyspark/mllib/tree.py +888 -0
  240. snowflake/snowpark_connect/includes/python/pyspark/mllib/util.py +659 -0
  241. snowflake/snowpark_connect/includes/python/pyspark/pandas/__init__.py +165 -0
  242. snowflake/snowpark_connect/includes/python/pyspark/pandas/_typing.py +52 -0
  243. snowflake/snowpark_connect/includes/python/pyspark/pandas/accessors.py +989 -0
  244. snowflake/snowpark_connect/includes/python/pyspark/pandas/base.py +1804 -0
  245. snowflake/snowpark_connect/includes/python/pyspark/pandas/categorical.py +822 -0
  246. snowflake/snowpark_connect/includes/python/pyspark/pandas/config.py +539 -0
  247. snowflake/snowpark_connect/includes/python/pyspark/pandas/correlation.py +262 -0
  248. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/__init__.py +16 -0
  249. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/base.py +519 -0
  250. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/binary_ops.py +98 -0
  251. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/boolean_ops.py +426 -0
  252. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/categorical_ops.py +141 -0
  253. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/complex_ops.py +145 -0
  254. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/date_ops.py +127 -0
  255. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/datetime_ops.py +171 -0
  256. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/null_ops.py +83 -0
  257. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/num_ops.py +588 -0
  258. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/string_ops.py +154 -0
  259. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/timedelta_ops.py +101 -0
  260. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/udt_ops.py +29 -0
  261. snowflake/snowpark_connect/includes/python/pyspark/pandas/datetimes.py +891 -0
  262. snowflake/snowpark_connect/includes/python/pyspark/pandas/exceptions.py +150 -0
  263. snowflake/snowpark_connect/includes/python/pyspark/pandas/extensions.py +388 -0
  264. snowflake/snowpark_connect/includes/python/pyspark/pandas/frame.py +13738 -0
  265. snowflake/snowpark_connect/includes/python/pyspark/pandas/generic.py +3560 -0
  266. snowflake/snowpark_connect/includes/python/pyspark/pandas/groupby.py +4448 -0
  267. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/__init__.py +21 -0
  268. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/base.py +2783 -0
  269. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/category.py +773 -0
  270. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/datetimes.py +843 -0
  271. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/multi.py +1323 -0
  272. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/numeric.py +210 -0
  273. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/timedelta.py +197 -0
  274. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexing.py +1862 -0
  275. snowflake/snowpark_connect/includes/python/pyspark/pandas/internal.py +1680 -0
  276. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/__init__.py +48 -0
  277. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/common.py +76 -0
  278. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/frame.py +63 -0
  279. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/general_functions.py +43 -0
  280. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/groupby.py +93 -0
  281. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/indexes.py +184 -0
  282. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/resample.py +101 -0
  283. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/scalars.py +29 -0
  284. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/series.py +69 -0
  285. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/window.py +168 -0
  286. snowflake/snowpark_connect/includes/python/pyspark/pandas/mlflow.py +238 -0
  287. snowflake/snowpark_connect/includes/python/pyspark/pandas/namespace.py +3807 -0
  288. snowflake/snowpark_connect/includes/python/pyspark/pandas/numpy_compat.py +260 -0
  289. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/__init__.py +17 -0
  290. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/core.py +1213 -0
  291. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/matplotlib.py +928 -0
  292. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/plotly.py +261 -0
  293. snowflake/snowpark_connect/includes/python/pyspark/pandas/resample.py +816 -0
  294. snowflake/snowpark_connect/includes/python/pyspark/pandas/series.py +7440 -0
  295. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_formatter.py +308 -0
  296. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_processor.py +394 -0
  297. snowflake/snowpark_connect/includes/python/pyspark/pandas/strings.py +2371 -0
  298. snowflake/snowpark_connect/includes/python/pyspark/pandas/supported_api_gen.py +378 -0
  299. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/__init__.py +16 -0
  300. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/__init__.py +16 -0
  301. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_any_all.py +177 -0
  302. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_apply_func.py +575 -0
  303. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_binary_ops.py +235 -0
  304. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_combine.py +653 -0
  305. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_compute.py +463 -0
  306. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_corrwith.py +86 -0
  307. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cov.py +151 -0
  308. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cumulative.py +139 -0
  309. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_describe.py +458 -0
  310. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_eval.py +86 -0
  311. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_melt.py +202 -0
  312. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_missing_data.py +520 -0
  313. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_pivot.py +361 -0
  314. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/__init__.py +16 -0
  315. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/__init__.py +16 -0
  316. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_any_all.py +40 -0
  317. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_apply_func.py +42 -0
  318. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_binary_ops.py +40 -0
  319. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_combine.py +37 -0
  320. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py +60 -0
  321. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_corrwith.py +40 -0
  322. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cov.py +40 -0
  323. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cumulative.py +90 -0
  324. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_describe.py +40 -0
  325. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_eval.py +40 -0
  326. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_melt.py +40 -0
  327. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_missing_data.py +42 -0
  328. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +37 -0
  329. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/__init__.py +16 -0
  330. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_base.py +36 -0
  331. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py +42 -0
  332. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py +47 -0
  333. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py +55 -0
  334. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_complex_ops.py +40 -0
  335. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py +47 -0
  336. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py +47 -0
  337. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py +42 -0
  338. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_arithmetic.py +43 -0
  339. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py +47 -0
  340. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_reverse.py +43 -0
  341. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py +47 -0
  342. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py +47 -0
  343. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py +40 -0
  344. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/testing_utils.py +226 -0
  345. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/__init__.py +16 -0
  346. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_align.py +39 -0
  347. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_basic_slow.py +55 -0
  348. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_cov_corrwith.py +39 -0
  349. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_frame.py +39 -0
  350. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_series.py +39 -0
  351. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_index.py +39 -0
  352. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_series.py +39 -0
  353. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_frame.py +43 -0
  354. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_series.py +43 -0
  355. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/__init__.py +16 -0
  356. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_attrs.py +40 -0
  357. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py +39 -0
  358. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_conversion.py +42 -0
  359. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reindexing.py +42 -0
  360. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +37 -0
  361. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_spark.py +40 -0
  362. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_take.py +42 -0
  363. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_time_series.py +48 -0
  364. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_truncate.py +40 -0
  365. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/__init__.py +16 -0
  366. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_aggregate.py +40 -0
  367. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_apply_func.py +41 -0
  368. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_cumulative.py +67 -0
  369. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_describe.py +40 -0
  370. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_groupby.py +55 -0
  371. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_head_tail.py +40 -0
  372. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_index.py +38 -0
  373. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py +55 -0
  374. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_split_apply.py +39 -0
  375. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +38 -0
  376. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/__init__.py +16 -0
  377. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_align.py +40 -0
  378. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py +50 -0
  379. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py +73 -0
  380. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_datetime.py +39 -0
  381. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing.py +40 -0
  382. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reindex.py +40 -0
  383. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_rename.py +40 -0
  384. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reset_index.py +48 -0
  385. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_timedelta.py +39 -0
  386. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/__init__.py +16 -0
  387. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/test_parity_io.py +40 -0
  388. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/__init__.py +16 -0
  389. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot.py +45 -0
  390. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_matplotlib.py +45 -0
  391. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_plotly.py +49 -0
  392. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot.py +37 -0
  393. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_matplotlib.py +53 -0
  394. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_plotly.py +45 -0
  395. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/__init__.py +16 -0
  396. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_all_any.py +38 -0
  397. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_arg_ops.py +37 -0
  398. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_of.py +37 -0
  399. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_type.py +38 -0
  400. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +37 -0
  401. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_conversion.py +40 -0
  402. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_cumulative.py +40 -0
  403. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_index.py +38 -0
  404. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_missing_data.py +40 -0
  405. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_series.py +37 -0
  406. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_sort.py +38 -0
  407. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_stat.py +38 -0
  408. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_categorical.py +66 -0
  409. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_config.py +37 -0
  410. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_csv.py +37 -0
  411. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_conversion.py +42 -0
  412. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_spark_io.py +39 -0
  413. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_default_index.py +49 -0
  414. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ewm.py +37 -0
  415. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_expanding.py +39 -0
  416. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_extension.py +49 -0
  417. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_frame_spark.py +53 -0
  418. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py +43 -0
  419. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexing.py +49 -0
  420. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexops_spark.py +39 -0
  421. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_internal.py +41 -0
  422. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_namespace.py +39 -0
  423. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_numpy_compat.py +60 -0
  424. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames.py +48 -0
  425. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby.py +39 -0
  426. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_expanding.py +44 -0
  427. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_rolling.py +84 -0
  428. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_repr.py +37 -0
  429. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_resample.py +45 -0
  430. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_reshape.py +39 -0
  431. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_rolling.py +39 -0
  432. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_scalars.py +37 -0
  433. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_conversion.py +39 -0
  434. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_datetime.py +39 -0
  435. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_string.py +39 -0
  436. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_spark_functions.py +39 -0
  437. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_sql.py +43 -0
  438. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_stats.py +37 -0
  439. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_typedef.py +36 -0
  440. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_utils.py +37 -0
  441. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_window.py +39 -0
  442. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/__init__.py +16 -0
  443. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_base.py +107 -0
  444. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +224 -0
  445. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +825 -0
  446. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +562 -0
  447. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +368 -0
  448. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +257 -0
  449. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +260 -0
  450. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +178 -0
  451. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_arithmetic.py +184 -0
  452. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +497 -0
  453. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_reverse.py +140 -0
  454. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +354 -0
  455. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +219 -0
  456. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +192 -0
  457. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +228 -0
  458. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/__init__.py +16 -0
  459. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_align.py +118 -0
  460. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_basic_slow.py +198 -0
  461. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_cov_corrwith.py +181 -0
  462. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_frame.py +103 -0
  463. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_series.py +141 -0
  464. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_index.py +109 -0
  465. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_series.py +136 -0
  466. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_frame.py +125 -0
  467. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_series.py +217 -0
  468. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/__init__.py +16 -0
  469. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_attrs.py +384 -0
  470. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_constructor.py +598 -0
  471. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_conversion.py +73 -0
  472. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reindexing.py +869 -0
  473. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reshaping.py +487 -0
  474. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_spark.py +309 -0
  475. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_take.py +156 -0
  476. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_time_series.py +149 -0
  477. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_truncate.py +163 -0
  478. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/__init__.py +16 -0
  479. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_aggregate.py +311 -0
  480. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_apply_func.py +524 -0
  481. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_cumulative.py +419 -0
  482. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_describe.py +144 -0
  483. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_groupby.py +979 -0
  484. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_head_tail.py +234 -0
  485. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_index.py +206 -0
  486. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_missing_data.py +421 -0
  487. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_split_apply.py +187 -0
  488. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_stat.py +397 -0
  489. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/__init__.py +16 -0
  490. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_align.py +100 -0
  491. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_base.py +2743 -0
  492. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_category.py +484 -0
  493. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_datetime.py +276 -0
  494. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_indexing.py +432 -0
  495. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reindex.py +310 -0
  496. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_rename.py +257 -0
  497. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reset_index.py +160 -0
  498. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_timedelta.py +128 -0
  499. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/__init__.py +16 -0
  500. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/test_io.py +137 -0
  501. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/__init__.py +16 -0
  502. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot.py +170 -0
  503. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +547 -0
  504. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +285 -0
  505. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot.py +106 -0
  506. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +409 -0
  507. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +247 -0
  508. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/__init__.py +16 -0
  509. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_all_any.py +105 -0
  510. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_arg_ops.py +197 -0
  511. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_of.py +137 -0
  512. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_type.py +227 -0
  513. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_compute.py +634 -0
  514. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_conversion.py +88 -0
  515. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_cumulative.py +139 -0
  516. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_index.py +475 -0
  517. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_missing_data.py +265 -0
  518. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_series.py +818 -0
  519. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_sort.py +162 -0
  520. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_stat.py +780 -0
  521. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_categorical.py +741 -0
  522. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_config.py +160 -0
  523. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_csv.py +453 -0
  524. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_conversion.py +281 -0
  525. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_spark_io.py +487 -0
  526. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_default_index.py +109 -0
  527. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ewm.py +434 -0
  528. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_expanding.py +253 -0
  529. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_extension.py +152 -0
  530. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_frame_spark.py +162 -0
  531. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_generic_functions.py +234 -0
  532. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexing.py +1339 -0
  533. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexops_spark.py +82 -0
  534. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_internal.py +124 -0
  535. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_namespace.py +638 -0
  536. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_numpy_compat.py +200 -0
  537. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +1355 -0
  538. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +655 -0
  539. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +113 -0
  540. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +118 -0
  541. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_repr.py +192 -0
  542. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_resample.py +346 -0
  543. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_reshape.py +495 -0
  544. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_rolling.py +263 -0
  545. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_scalars.py +59 -0
  546. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_conversion.py +85 -0
  547. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_datetime.py +364 -0
  548. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_string.py +362 -0
  549. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_spark_functions.py +46 -0
  550. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_sql.py +123 -0
  551. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_stats.py +581 -0
  552. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_typedef.py +447 -0
  553. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_utils.py +301 -0
  554. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_window.py +465 -0
  555. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/__init__.py +18 -0
  556. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/typehints.py +874 -0
  557. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/__init__.py +143 -0
  558. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/usage_logger.py +132 -0
  559. snowflake/snowpark_connect/includes/python/pyspark/pandas/utils.py +1063 -0
  560. snowflake/snowpark_connect/includes/python/pyspark/pandas/window.py +2702 -0
  561. snowflake/snowpark_connect/includes/python/pyspark/profiler.py +489 -0
  562. snowflake/snowpark_connect/includes/python/pyspark/py.typed +1 -0
  563. snowflake/snowpark_connect/includes/python/pyspark/python/pyspark/shell.py +123 -0
  564. snowflake/snowpark_connect/includes/python/pyspark/rdd.py +5518 -0
  565. snowflake/snowpark_connect/includes/python/pyspark/rddsampler.py +115 -0
  566. snowflake/snowpark_connect/includes/python/pyspark/resource/__init__.py +38 -0
  567. snowflake/snowpark_connect/includes/python/pyspark/resource/information.py +69 -0
  568. snowflake/snowpark_connect/includes/python/pyspark/resource/profile.py +317 -0
  569. snowflake/snowpark_connect/includes/python/pyspark/resource/requests.py +539 -0
  570. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/__init__.py +16 -0
  571. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/test_resources.py +83 -0
  572. snowflake/snowpark_connect/includes/python/pyspark/resultiterable.py +45 -0
  573. snowflake/snowpark_connect/includes/python/pyspark/serializers.py +681 -0
  574. snowflake/snowpark_connect/includes/python/pyspark/shell.py +123 -0
  575. snowflake/snowpark_connect/includes/python/pyspark/shuffle.py +854 -0
  576. snowflake/snowpark_connect/includes/python/pyspark/sql/__init__.py +75 -0
  577. snowflake/snowpark_connect/includes/python/pyspark/sql/_typing.pyi +80 -0
  578. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/__init__.py +18 -0
  579. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/functions.py +188 -0
  580. snowflake/snowpark_connect/includes/python/pyspark/sql/catalog.py +1270 -0
  581. snowflake/snowpark_connect/includes/python/pyspark/sql/column.py +1431 -0
  582. snowflake/snowpark_connect/includes/python/pyspark/sql/conf.py +99 -0
  583. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/__init__.py +18 -0
  584. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/_typing.py +90 -0
  585. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/__init__.py +18 -0
  586. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/functions.py +107 -0
  587. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/catalog.py +356 -0
  588. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/__init__.py +22 -0
  589. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/artifact.py +412 -0
  590. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/core.py +1689 -0
  591. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/reattach.py +340 -0
  592. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/column.py +514 -0
  593. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conf.py +128 -0
  594. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conversion.py +490 -0
  595. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/dataframe.py +2172 -0
  596. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/expressions.py +1056 -0
  597. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/functions.py +3937 -0
  598. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/group.py +418 -0
  599. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/plan.py +2289 -0
  600. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/__init__.py +25 -0
  601. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.py +203 -0
  602. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.pyi +2718 -0
  603. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2_grpc.py +423 -0
  604. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.py +109 -0
  605. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.pyi +1130 -0
  606. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.py +141 -0
  607. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.pyi +1766 -0
  608. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.py +47 -0
  609. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.pyi +123 -0
  610. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.py +53 -0
  611. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.pyi +112 -0
  612. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.py +107 -0
  613. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.pyi +1507 -0
  614. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.py +195 -0
  615. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.pyi +3613 -0
  616. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.py +95 -0
  617. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.pyi +980 -0
  618. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/__init__.py +18 -0
  619. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/functions.py +166 -0
  620. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/readwriter.py +861 -0
  621. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/session.py +952 -0
  622. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/__init__.py +22 -0
  623. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/query.py +295 -0
  624. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/readwriter.py +618 -0
  625. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/__init__.py +18 -0
  626. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +87 -0
  627. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +100 -0
  628. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/types.py +301 -0
  629. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udf.py +296 -0
  630. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udtf.py +200 -0
  631. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/utils.py +58 -0
  632. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/window.py +266 -0
  633. snowflake/snowpark_connect/includes/python/pyspark/sql/context.py +818 -0
  634. snowflake/snowpark_connect/includes/python/pyspark/sql/dataframe.py +5973 -0
  635. snowflake/snowpark_connect/includes/python/pyspark/sql/functions.py +15889 -0
  636. snowflake/snowpark_connect/includes/python/pyspark/sql/group.py +547 -0
  637. snowflake/snowpark_connect/includes/python/pyspark/sql/observation.py +152 -0
  638. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/__init__.py +21 -0
  639. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/__init__.pyi +344 -0
  640. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/__init__.pyi +17 -0
  641. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/frame.pyi +20 -0
  642. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/series.pyi +20 -0
  643. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/conversion.py +671 -0
  644. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.py +480 -0
  645. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.pyi +132 -0
  646. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/group_ops.py +523 -0
  647. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/map_ops.py +216 -0
  648. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/serializers.py +1019 -0
  649. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/typehints.py +172 -0
  650. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/types.py +972 -0
  651. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/utils.py +86 -0
  652. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/__init__.py +18 -0
  653. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/functions.py +334 -0
  654. snowflake/snowpark_connect/includes/python/pyspark/sql/readwriter.py +2159 -0
  655. snowflake/snowpark_connect/includes/python/pyspark/sql/session.py +2088 -0
  656. snowflake/snowpark_connect/includes/python/pyspark/sql/sql_formatter.py +84 -0
  657. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/__init__.py +21 -0
  658. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/listener.py +1050 -0
  659. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/query.py +746 -0
  660. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/readwriter.py +1652 -0
  661. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/state.py +288 -0
  662. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/__init__.py +16 -0
  663. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/__init__.py +16 -0
  664. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/__init__.py +16 -0
  665. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_artifact.py +420 -0
  666. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_client.py +358 -0
  667. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/__init__.py +16 -0
  668. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach.py +36 -0
  669. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +44 -0
  670. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +116 -0
  671. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +35 -0
  672. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py +3612 -0
  673. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_column.py +1042 -0
  674. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_function.py +2381 -0
  675. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_plan.py +1060 -0
  676. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow.py +163 -0
  677. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_map.py +38 -0
  678. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +48 -0
  679. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_catalog.py +36 -0
  680. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_column.py +55 -0
  681. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_conf.py +36 -0
  682. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_dataframe.py +96 -0
  683. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_datasources.py +44 -0
  684. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_errors.py +36 -0
  685. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_functions.py +59 -0
  686. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_group.py +36 -0
  687. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py +59 -0
  688. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py +74 -0
  689. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py +62 -0
  690. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_map.py +58 -0
  691. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +70 -0
  692. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py +50 -0
  693. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +68 -0
  694. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py +40 -0
  695. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_readwriter.py +46 -0
  696. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_serde.py +44 -0
  697. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_types.py +100 -0
  698. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udf.py +100 -0
  699. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udtf.py +163 -0
  700. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_session.py +181 -0
  701. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_utils.py +42 -0
  702. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/__init__.py +16 -0
  703. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +623 -0
  704. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +869 -0
  705. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +342 -0
  706. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_map.py +436 -0
  707. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf.py +363 -0
  708. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +592 -0
  709. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +1503 -0
  710. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +392 -0
  711. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py +375 -0
  712. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +411 -0
  713. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/__init__.py +16 -0
  714. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming.py +401 -0
  715. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +295 -0
  716. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +106 -0
  717. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_listener.py +558 -0
  718. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow.py +1346 -0
  719. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_map.py +182 -0
  720. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_python_udf.py +202 -0
  721. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_catalog.py +503 -0
  722. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_column.py +225 -0
  723. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_conf.py +83 -0
  724. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_context.py +201 -0
  725. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_dataframe.py +1931 -0
  726. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_datasources.py +256 -0
  727. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_errors.py +69 -0
  728. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_functions.py +1349 -0
  729. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_group.py +53 -0
  730. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_pandas_sqlmetrics.py +68 -0
  731. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_readwriter.py +283 -0
  732. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_serde.py +155 -0
  733. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_session.py +412 -0
  734. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_types.py +1581 -0
  735. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf.py +961 -0
  736. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf_profiler.py +165 -0
  737. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udtf.py +1456 -0
  738. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_utils.py +1686 -0
  739. snowflake/snowpark_connect/includes/python/pyspark/sql/types.py +2558 -0
  740. snowflake/snowpark_connect/includes/python/pyspark/sql/udf.py +714 -0
  741. snowflake/snowpark_connect/includes/python/pyspark/sql/udtf.py +325 -0
  742. snowflake/snowpark_connect/includes/python/pyspark/sql/utils.py +339 -0
  743. snowflake/snowpark_connect/includes/python/pyspark/sql/window.py +492 -0
  744. snowflake/snowpark_connect/includes/python/pyspark/statcounter.py +165 -0
  745. snowflake/snowpark_connect/includes/python/pyspark/status.py +112 -0
  746. snowflake/snowpark_connect/includes/python/pyspark/storagelevel.py +97 -0
  747. snowflake/snowpark_connect/includes/python/pyspark/streaming/__init__.py +22 -0
  748. snowflake/snowpark_connect/includes/python/pyspark/streaming/context.py +471 -0
  749. snowflake/snowpark_connect/includes/python/pyspark/streaming/dstream.py +933 -0
  750. snowflake/snowpark_connect/includes/python/pyspark/streaming/kinesis.py +205 -0
  751. snowflake/snowpark_connect/includes/python/pyspark/streaming/listener.py +83 -0
  752. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/__init__.py +16 -0
  753. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_context.py +184 -0
  754. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_dstream.py +706 -0
  755. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_kinesis.py +118 -0
  756. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_listener.py +160 -0
  757. snowflake/snowpark_connect/includes/python/pyspark/streaming/util.py +168 -0
  758. snowflake/snowpark_connect/includes/python/pyspark/taskcontext.py +502 -0
  759. snowflake/snowpark_connect/includes/python/pyspark/testing/__init__.py +21 -0
  760. snowflake/snowpark_connect/includes/python/pyspark/testing/connectutils.py +199 -0
  761. snowflake/snowpark_connect/includes/python/pyspark/testing/mllibutils.py +30 -0
  762. snowflake/snowpark_connect/includes/python/pyspark/testing/mlutils.py +275 -0
  763. snowflake/snowpark_connect/includes/python/pyspark/testing/objects.py +121 -0
  764. snowflake/snowpark_connect/includes/python/pyspark/testing/pandasutils.py +714 -0
  765. snowflake/snowpark_connect/includes/python/pyspark/testing/sqlutils.py +168 -0
  766. snowflake/snowpark_connect/includes/python/pyspark/testing/streamingutils.py +178 -0
  767. snowflake/snowpark_connect/includes/python/pyspark/testing/utils.py +636 -0
  768. snowflake/snowpark_connect/includes/python/pyspark/tests/__init__.py +16 -0
  769. snowflake/snowpark_connect/includes/python/pyspark/tests/test_appsubmit.py +306 -0
  770. snowflake/snowpark_connect/includes/python/pyspark/tests/test_broadcast.py +196 -0
  771. snowflake/snowpark_connect/includes/python/pyspark/tests/test_conf.py +44 -0
  772. snowflake/snowpark_connect/includes/python/pyspark/tests/test_context.py +346 -0
  773. snowflake/snowpark_connect/includes/python/pyspark/tests/test_daemon.py +89 -0
  774. snowflake/snowpark_connect/includes/python/pyspark/tests/test_install_spark.py +124 -0
  775. snowflake/snowpark_connect/includes/python/pyspark/tests/test_join.py +69 -0
  776. snowflake/snowpark_connect/includes/python/pyspark/tests/test_memory_profiler.py +167 -0
  777. snowflake/snowpark_connect/includes/python/pyspark/tests/test_pin_thread.py +194 -0
  778. snowflake/snowpark_connect/includes/python/pyspark/tests/test_profiler.py +168 -0
  779. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rdd.py +939 -0
  780. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddbarrier.py +52 -0
  781. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddsampler.py +66 -0
  782. snowflake/snowpark_connect/includes/python/pyspark/tests/test_readwrite.py +368 -0
  783. snowflake/snowpark_connect/includes/python/pyspark/tests/test_serializers.py +257 -0
  784. snowflake/snowpark_connect/includes/python/pyspark/tests/test_shuffle.py +267 -0
  785. snowflake/snowpark_connect/includes/python/pyspark/tests/test_stage_sched.py +153 -0
  786. snowflake/snowpark_connect/includes/python/pyspark/tests/test_statcounter.py +130 -0
  787. snowflake/snowpark_connect/includes/python/pyspark/tests/test_taskcontext.py +350 -0
  788. snowflake/snowpark_connect/includes/python/pyspark/tests/test_util.py +97 -0
  789. snowflake/snowpark_connect/includes/python/pyspark/tests/test_worker.py +271 -0
  790. snowflake/snowpark_connect/includes/python/pyspark/traceback_utils.py +81 -0
  791. snowflake/snowpark_connect/includes/python/pyspark/util.py +416 -0
  792. snowflake/snowpark_connect/includes/python/pyspark/version.py +19 -0
  793. snowflake/snowpark_connect/includes/python/pyspark/worker.py +1307 -0
  794. snowflake/snowpark_connect/includes/python/pyspark/worker_util.py +46 -0
  795. snowflake/snowpark_connect/proto/__init__.py +10 -0
  796. snowflake/snowpark_connect/proto/control_pb2.py +35 -0
  797. snowflake/snowpark_connect/proto/control_pb2.pyi +38 -0
  798. snowflake/snowpark_connect/proto/control_pb2_grpc.py +183 -0
  799. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +35 -0
  800. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +53 -0
  801. snowflake/snowpark_connect/proto/snowflake_rdd_pb2.pyi +39 -0
  802. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +47 -0
  803. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +111 -0
  804. snowflake/snowpark_connect/relation/__init__.py +3 -0
  805. snowflake/snowpark_connect/relation/catalogs/__init__.py +12 -0
  806. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +287 -0
  807. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +467 -0
  808. snowflake/snowpark_connect/relation/catalogs/utils.py +51 -0
  809. snowflake/snowpark_connect/relation/io_utils.py +76 -0
  810. snowflake/snowpark_connect/relation/map_aggregate.py +322 -0
  811. snowflake/snowpark_connect/relation/map_catalog.py +151 -0
  812. snowflake/snowpark_connect/relation/map_column_ops.py +1068 -0
  813. snowflake/snowpark_connect/relation/map_crosstab.py +48 -0
  814. snowflake/snowpark_connect/relation/map_extension.py +412 -0
  815. snowflake/snowpark_connect/relation/map_join.py +341 -0
  816. snowflake/snowpark_connect/relation/map_local_relation.py +326 -0
  817. snowflake/snowpark_connect/relation/map_map_partitions.py +146 -0
  818. snowflake/snowpark_connect/relation/map_relation.py +253 -0
  819. snowflake/snowpark_connect/relation/map_row_ops.py +716 -0
  820. snowflake/snowpark_connect/relation/map_sample_by.py +35 -0
  821. snowflake/snowpark_connect/relation/map_show_string.py +50 -0
  822. snowflake/snowpark_connect/relation/map_sql.py +1874 -0
  823. snowflake/snowpark_connect/relation/map_stats.py +324 -0
  824. snowflake/snowpark_connect/relation/map_subquery_alias.py +32 -0
  825. snowflake/snowpark_connect/relation/map_udtf.py +288 -0
  826. snowflake/snowpark_connect/relation/read/__init__.py +7 -0
  827. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +668 -0
  828. snowflake/snowpark_connect/relation/read/map_read.py +367 -0
  829. snowflake/snowpark_connect/relation/read/map_read_csv.py +142 -0
  830. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +108 -0
  831. snowflake/snowpark_connect/relation/read/map_read_json.py +344 -0
  832. snowflake/snowpark_connect/relation/read/map_read_parquet.py +194 -0
  833. snowflake/snowpark_connect/relation/read/map_read_socket.py +59 -0
  834. snowflake/snowpark_connect/relation/read/map_read_table.py +109 -0
  835. snowflake/snowpark_connect/relation/read/map_read_text.py +106 -0
  836. snowflake/snowpark_connect/relation/read/reader_config.py +399 -0
  837. snowflake/snowpark_connect/relation/read/utils.py +155 -0
  838. snowflake/snowpark_connect/relation/stage_locator.py +161 -0
  839. snowflake/snowpark_connect/relation/utils.py +219 -0
  840. snowflake/snowpark_connect/relation/write/__init__.py +3 -0
  841. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +339 -0
  842. snowflake/snowpark_connect/relation/write/map_write.py +436 -0
  843. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +48 -0
  844. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  845. snowflake/snowpark_connect/resources_initializer.py +75 -0
  846. snowflake/snowpark_connect/server.py +1136 -0
  847. snowflake/snowpark_connect/start_server.py +32 -0
  848. snowflake/snowpark_connect/tcm.py +8 -0
  849. snowflake/snowpark_connect/type_mapping.py +1003 -0
  850. snowflake/snowpark_connect/typed_column.py +94 -0
  851. snowflake/snowpark_connect/utils/__init__.py +3 -0
  852. snowflake/snowpark_connect/utils/artifacts.py +48 -0
  853. snowflake/snowpark_connect/utils/attribute_handling.py +72 -0
  854. snowflake/snowpark_connect/utils/cache.py +84 -0
  855. snowflake/snowpark_connect/utils/concurrent.py +124 -0
  856. snowflake/snowpark_connect/utils/context.py +390 -0
  857. snowflake/snowpark_connect/utils/describe_query_cache.py +231 -0
  858. snowflake/snowpark_connect/utils/interrupt.py +85 -0
  859. snowflake/snowpark_connect/utils/io_utils.py +35 -0
  860. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +117 -0
  861. snowflake/snowpark_connect/utils/profiling.py +47 -0
  862. snowflake/snowpark_connect/utils/session.py +180 -0
  863. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +38 -0
  864. snowflake/snowpark_connect/utils/telemetry.py +513 -0
  865. snowflake/snowpark_connect/utils/udf_cache.py +392 -0
  866. snowflake/snowpark_connect/utils/udf_helper.py +328 -0
  867. snowflake/snowpark_connect/utils/udf_utils.py +310 -0
  868. snowflake/snowpark_connect/utils/udtf_helper.py +420 -0
  869. snowflake/snowpark_connect/utils/udtf_utils.py +799 -0
  870. snowflake/snowpark_connect/utils/xxhash64.py +247 -0
  871. snowflake/snowpark_connect/version.py +6 -0
  872. snowpark_connect-0.20.2.data/scripts/snowpark-connect +71 -0
  873. snowpark_connect-0.20.2.data/scripts/snowpark-session +11 -0
  874. snowpark_connect-0.20.2.data/scripts/snowpark-submit +354 -0
  875. snowpark_connect-0.20.2.dist-info/METADATA +37 -0
  876. snowpark_connect-0.20.2.dist-info/RECORD +879 -0
  877. snowpark_connect-0.20.2.dist-info/WHEEL +5 -0
  878. snowpark_connect-0.20.2.dist-info/licenses/LICENSE.txt +202 -0
  879. snowpark_connect-0.20.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,4448 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one or more
3
+ # contributor license agreements. See the NOTICE file distributed with
4
+ # this work for additional information regarding copyright ownership.
5
+ # The ASF licenses this file to You under the Apache License, Version 2.0
6
+ # (the "License"); you may not use this file except in compliance with
7
+ # the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ """
19
+ A wrapper for GroupedData to behave like pandas GroupBy.
20
+ """
21
+ from abc import ABCMeta, abstractmethod
22
+ import inspect
23
+ from collections import defaultdict, namedtuple
24
+ from distutils.version import LooseVersion
25
+ from functools import partial
26
+ from itertools import product
27
+ from typing import (
28
+ Any,
29
+ Callable,
30
+ Dict,
31
+ Generic,
32
+ Iterator,
33
+ Mapping,
34
+ List,
35
+ Optional,
36
+ Sequence,
37
+ Set,
38
+ Tuple,
39
+ Type,
40
+ Union,
41
+ cast,
42
+ TYPE_CHECKING,
43
+ )
44
+ import warnings
45
+
46
+ import pandas as pd
47
+ from pandas.api.types import is_number, is_hashable, is_list_like # type: ignore[attr-defined]
48
+
49
+ if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"):
50
+ from pandas.core.common import _builtin_table # type: ignore[attr-defined]
51
+ else:
52
+ from pandas.core.base import SelectionMixin
53
+
54
+ _builtin_table = SelectionMixin._builtin_table # type: ignore[attr-defined]
55
+
56
+ from pyspark.sql import Column, DataFrame as SparkDataFrame, Window, functions as F
57
+ from pyspark.sql.types import (
58
+ BooleanType,
59
+ DataType,
60
+ DoubleType,
61
+ NumericType,
62
+ StructField,
63
+ StructType,
64
+ StringType,
65
+ )
66
+
67
+ from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
68
+ from pyspark.pandas._typing import Axis, FrameLike, Label, Name
69
+ from pyspark.pandas.typedef import infer_return_type, DataFrameType, ScalarType, SeriesType
70
+ from pyspark.pandas.frame import DataFrame
71
+ from pyspark.pandas.internal import (
72
+ InternalField,
73
+ InternalFrame,
74
+ HIDDEN_COLUMNS,
75
+ NATURAL_ORDER_COLUMN_NAME,
76
+ SPARK_INDEX_NAME_FORMAT,
77
+ SPARK_DEFAULT_SERIES_NAME,
78
+ SPARK_INDEX_NAME_PATTERN,
79
+ )
80
+ from pyspark.pandas.missing.groupby import (
81
+ MissingPandasLikeDataFrameGroupBy,
82
+ MissingPandasLikeSeriesGroupBy,
83
+ )
84
+ from pyspark.pandas.series import Series, first_series
85
+ from pyspark.pandas.spark import functions as SF
86
+ from pyspark.pandas.config import get_option
87
+ from pyspark.pandas.utils import (
88
+ align_diff_frames,
89
+ is_name_like_tuple,
90
+ is_name_like_value,
91
+ name_like_string,
92
+ same_anchor,
93
+ scol_for,
94
+ verify_temp_column_name,
95
+ log_advice,
96
+ )
97
+ from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
98
+ from pyspark.pandas.exceptions import DataError
99
+
100
+ if TYPE_CHECKING:
101
+ from pyspark.pandas.window import RollingGroupby, ExpandingGroupby, ExponentialMovingGroupby
102
+
103
+
104
+ # to keep it the same as pandas
105
+ NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"])
106
+
107
+
108
+ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
109
+ """
110
+ :ivar _psdf: The parent dataframe that is used to perform the groupby
111
+ :type _psdf: DataFrame
112
+ :ivar _groupkeys: The list of keys that will be used to perform the grouping
113
+ :type _groupkeys: List[Series]
114
+ """
115
+
116
+ def __init__(
117
+ self,
118
+ psdf: DataFrame,
119
+ groupkeys: List[Series],
120
+ as_index: bool,
121
+ dropna: bool,
122
+ column_labels_to_exclude: Set[Label],
123
+ agg_columns_selected: bool,
124
+ agg_columns: List[Series],
125
+ ):
126
+ self._psdf = psdf
127
+ self._groupkeys = groupkeys
128
+ self._as_index = as_index
129
+ self._dropna = dropna
130
+ self._column_labels_to_exclude = column_labels_to_exclude
131
+ self._agg_columns_selected = agg_columns_selected
132
+ self._agg_columns = agg_columns
133
+
134
+ @property
135
+ def _groupkeys_scols(self) -> List[Column]:
136
+ return [s.spark.column for s in self._groupkeys]
137
+
138
+ @property
139
+ def _agg_columns_scols(self) -> List[Column]:
140
+ return [s.spark.column for s in self._agg_columns]
141
+
142
+ @abstractmethod
143
+ def _apply_series_op(
144
+ self,
145
+ op: Callable[["SeriesGroupBy"], Series],
146
+ should_resolve: bool = False,
147
+ numeric_only: bool = False,
148
+ ) -> FrameLike:
149
+ pass
150
+
151
+ @abstractmethod
152
+ def _handle_output(self, psdf: DataFrame) -> FrameLike:
153
+ pass
154
+
155
+ # TODO: Series support is not implemented yet.
156
+ # TODO: not all arguments are implemented comparing to pandas' for now.
157
+ def aggregate(
158
+ self,
159
+ func_or_funcs: Optional[Union[str, List[str], Dict[Name, Union[str, List[str]]]]] = None,
160
+ *args: Any,
161
+ **kwargs: Any,
162
+ ) -> DataFrame:
163
+ """Aggregate using one or more operations over the specified axis.
164
+
165
+ Parameters
166
+ ----------
167
+ func_or_funcs : dict, str or list
168
+ a dict mapping from column name (string) to
169
+ aggregate functions (string or list of strings).
170
+
171
+ Returns
172
+ -------
173
+ Series or DataFrame
174
+
175
+ The return can be:
176
+
177
+ * Series : when DataFrame.agg is called with a single function
178
+ * DataFrame : when DataFrame.agg is called with several functions
179
+
180
+ Return Series or DataFrame.
181
+
182
+ Notes
183
+ -----
184
+ `agg` is an alias for `aggregate`. Use the alias.
185
+
186
+ See Also
187
+ --------
188
+ pyspark.pandas.Series.groupby
189
+ pyspark.pandas.DataFrame.groupby
190
+
191
+ Examples
192
+ --------
193
+ >>> df = ps.DataFrame({'A': [1, 1, 2, 2],
194
+ ... 'B': [1, 2, 3, 4],
195
+ ... 'C': [0.362, 0.227, 1.267, -0.562]},
196
+ ... columns=['A', 'B', 'C'])
197
+
198
+ >>> df
199
+ A B C
200
+ 0 1 1 0.362
201
+ 1 1 2 0.227
202
+ 2 2 3 1.267
203
+ 3 2 4 -0.562
204
+
205
+ Different aggregations per column
206
+
207
+ >>> aggregated = df.groupby('A').agg({'B': 'min', 'C': 'sum'})
208
+ >>> aggregated[['B', 'C']].sort_index() # doctest: +NORMALIZE_WHITESPACE
209
+ B C
210
+ A
211
+ 1 1 0.589
212
+ 2 3 0.705
213
+
214
+ >>> aggregated = df.groupby('A').agg({'B': ['min', 'max']})
215
+ >>> aggregated.sort_index() # doctest: +NORMALIZE_WHITESPACE
216
+ B
217
+ min max
218
+ A
219
+ 1 1 2
220
+ 2 3 4
221
+
222
+ >>> aggregated = df.groupby('A').agg('min')
223
+ >>> aggregated.sort_index() # doctest: +NORMALIZE_WHITESPACE
224
+ B C
225
+ A
226
+ 1 1 0.227
227
+ 2 3 -0.562
228
+
229
+ >>> aggregated = df.groupby('A').agg(['min', 'max'])
230
+ >>> aggregated.sort_index() # doctest: +NORMALIZE_WHITESPACE
231
+ B C
232
+ min max min max
233
+ A
234
+ 1 1 2 0.227 0.362
235
+ 2 3 4 -0.562 1.267
236
+
237
+ To control the output names with different aggregations per column, pandas-on-Spark
238
+ also supports 'named aggregation' or nested renaming in .agg. It can also be
239
+ used when applying multiple aggregation functions to specific columns.
240
+
241
+ >>> aggregated = df.groupby('A').agg(b_max=ps.NamedAgg(column='B', aggfunc='max'))
242
+ >>> aggregated.sort_index() # doctest: +NORMALIZE_WHITESPACE
243
+ b_max
244
+ A
245
+ 1 2
246
+ 2 4
247
+
248
+ >>> aggregated = df.groupby('A').agg(b_max=('B', 'max'), b_min=('B', 'min'))
249
+ >>> aggregated.sort_index() # doctest: +NORMALIZE_WHITESPACE
250
+ b_max b_min
251
+ A
252
+ 1 2 1
253
+ 2 4 3
254
+
255
+ >>> aggregated = df.groupby('A').agg(b_max=('B', 'max'), c_min=('C', 'min'))
256
+ >>> aggregated.sort_index() # doctest: +NORMALIZE_WHITESPACE
257
+ b_max c_min
258
+ A
259
+ 1 2 0.227
260
+ 2 4 -0.562
261
+ """
262
+ # I think current implementation of func and arguments in pandas-on-Spark for aggregate
263
+ # is different than pandas, later once arguments are added, this could be removed.
264
+ if func_or_funcs is None and kwargs is None:
265
+ raise ValueError("No aggregation argument or function specified.")
266
+
267
+ relabeling = func_or_funcs is None and is_multi_agg_with_relabel(**kwargs)
268
+ if relabeling:
269
+ (
270
+ func_or_funcs,
271
+ columns,
272
+ order,
273
+ ) = normalize_keyword_aggregation( # type: ignore[assignment]
274
+ kwargs
275
+ )
276
+
277
+ if not isinstance(func_or_funcs, (str, list)):
278
+ if not isinstance(func_or_funcs, dict) or not all(
279
+ is_name_like_value(key)
280
+ and (
281
+ isinstance(value, str)
282
+ or isinstance(value, list)
283
+ and all(isinstance(v, str) for v in value)
284
+ )
285
+ for key, value in func_or_funcs.items()
286
+ ):
287
+ raise ValueError(
288
+ "aggs must be a dict mapping from column name "
289
+ "to aggregate functions (string or list of strings)."
290
+ )
291
+
292
+ else:
293
+ agg_cols = [col.name for col in self._agg_columns]
294
+ func_or_funcs = {col: func_or_funcs for col in agg_cols}
295
+
296
+ psdf: DataFrame = DataFrame(
297
+ GroupBy._spark_groupby(self._psdf, func_or_funcs, self._groupkeys)
298
+ )
299
+
300
+ if self._dropna:
301
+ psdf = DataFrame(
302
+ psdf._internal.with_new_sdf(
303
+ psdf._internal.spark_frame.dropna(
304
+ subset=psdf._internal.index_spark_column_names
305
+ )
306
+ )
307
+ )
308
+
309
+ if not self._as_index:
310
+ should_drop_index = set(
311
+ i for i, gkey in enumerate(self._groupkeys) if gkey._psdf is not self._psdf
312
+ )
313
+ if len(should_drop_index) > 0:
314
+ psdf = psdf.reset_index(level=should_drop_index, drop=True)
315
+ if len(should_drop_index) < len(self._groupkeys):
316
+ psdf = psdf.reset_index()
317
+
318
+ if relabeling:
319
+ psdf = psdf[order]
320
+ psdf.columns = columns # type: ignore[assignment]
321
+ return psdf
322
+
323
+ agg = aggregate
324
+
325
+ @staticmethod
326
+ def _spark_groupby(
327
+ psdf: DataFrame,
328
+ func: Mapping[Name, Union[str, List[str]]],
329
+ groupkeys: Sequence[Series] = (),
330
+ ) -> InternalFrame:
331
+ groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(groupkeys))]
332
+ groupkey_scols = [s.spark.column.alias(name) for s, name in zip(groupkeys, groupkey_names)]
333
+
334
+ multi_aggs = any(isinstance(v, list) for v in func.values())
335
+ reordered = []
336
+ data_columns = []
337
+ column_labels = []
338
+ for key, value in func.items():
339
+ label = key if is_name_like_tuple(key) else (key,)
340
+ if len(label) != psdf._internal.column_labels_level:
341
+ raise TypeError("The length of the key must be the same as the column label level.")
342
+ for aggfunc in [value] if isinstance(value, str) else value:
343
+ column_label = tuple(list(label) + [aggfunc]) if multi_aggs else label
344
+ column_labels.append(column_label)
345
+
346
+ data_col = name_like_string(column_label)
347
+ data_columns.append(data_col)
348
+
349
+ col_name = psdf._internal.spark_column_name_for(label)
350
+ if aggfunc == "nunique":
351
+ reordered.append(
352
+ F.expr("count(DISTINCT `{0}`) as `{1}`".format(col_name, data_col))
353
+ )
354
+
355
+ # Implement "quartiles" aggregate function for ``describe``.
356
+ elif aggfunc == "quartiles":
357
+ reordered.append(
358
+ F.expr(
359
+ "percentile_approx(`{0}`, array(0.25, 0.5, 0.75)) as `{1}`".format(
360
+ col_name, data_col
361
+ )
362
+ )
363
+ )
364
+
365
+ else:
366
+ reordered.append(
367
+ F.expr("{1}(`{0}`) as `{2}`".format(col_name, aggfunc, data_col))
368
+ )
369
+
370
+ sdf = psdf._internal.spark_frame.select(groupkey_scols + psdf._internal.data_spark_columns)
371
+ sdf = sdf.groupby(*groupkey_names).agg(*reordered)
372
+
373
+ return InternalFrame(
374
+ spark_frame=sdf,
375
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
376
+ index_names=[psser._column_label for psser in groupkeys],
377
+ index_fields=[
378
+ psser._internal.data_fields[0].copy(name=name)
379
+ for psser, name in zip(groupkeys, groupkey_names)
380
+ ],
381
+ column_labels=column_labels,
382
+ data_spark_columns=[scol_for(sdf, col) for col in data_columns],
383
+ )
384
+
385
+ def count(self) -> FrameLike:
386
+ """
387
+ Compute count of group, excluding missing values.
388
+
389
+ See Also
390
+ --------
391
+ pyspark.pandas.Series.groupby
392
+ pyspark.pandas.DataFrame.groupby
393
+
394
+ Examples
395
+ --------
396
+ >>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
397
+ ... 'B': [np.nan, 2, 3, 4, 5],
398
+ ... 'C': [1, 2, 1, 1, 2]}, columns=['A', 'B', 'C'])
399
+ >>> df.groupby('A').count().sort_index() # doctest: +NORMALIZE_WHITESPACE
400
+ B C
401
+ A
402
+ 1 2 3
403
+ 2 2 2
404
+ """
405
+ return self._reduce_for_stat_function(F.count)
406
+
407
+ def first(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
408
+ """
409
+ Compute first of group values.
410
+
411
+ .. versionadded:: 3.3.0
412
+
413
+ Parameters
414
+ ----------
415
+ numeric_only : bool, default False
416
+ Include only float, int, boolean columns. If None, will attempt to use
417
+ everything, then use only numeric data.
418
+
419
+ .. versionadded:: 3.4.0
420
+ min_count : int, default -1
421
+ The required number of valid values to perform the operation. If fewer
422
+ than ``min_count`` non-NA values are present the result will be NA.
423
+
424
+ .. versionadded:: 3.4.0
425
+
426
+ See Also
427
+ --------
428
+ pyspark.pandas.Series.groupby
429
+ pyspark.pandas.DataFrame.groupby
430
+
431
+ Examples
432
+ --------
433
+ >>> df = ps.DataFrame({"A": [1, 2, 1, 2], "B": [True, False, False, True],
434
+ ... "C": [3, 3, 4, 4], "D": ["a", "b", "a", "a"]})
435
+ >>> df
436
+ A B C D
437
+ 0 1 True 3 a
438
+ 1 2 False 3 b
439
+ 2 1 False 4 a
440
+ 3 2 True 4 a
441
+
442
+ >>> df.groupby("A").first().sort_index()
443
+ B C D
444
+ A
445
+ 1 True 3 a
446
+ 2 False 3 b
447
+
448
+ Include only float, int, boolean columns when set numeric_only True.
449
+
450
+ >>> df.groupby("A").first(numeric_only=True).sort_index()
451
+ B C
452
+ A
453
+ 1 True 3
454
+ 2 False 3
455
+
456
+ >>> df.groupby("D").first().sort_index()
457
+ A B C
458
+ D
459
+ a 1 True 3
460
+ b 2 False 3
461
+
462
+ >>> df.groupby("D").first(min_count=3).sort_index()
463
+ A B C
464
+ D
465
+ a 1.0 True 3.0
466
+ b NaN None NaN
467
+ """
468
+ if not isinstance(min_count, int):
469
+ raise TypeError("min_count must be integer")
470
+
471
+ return self._reduce_for_stat_function(
472
+ lambda col: F.first(col, ignorenulls=True),
473
+ accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
474
+ min_count=min_count,
475
+ )
476
+
477
+ def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
478
+ """
479
+ Compute last of group values.
480
+
481
+ .. versionadded:: 3.3.0
482
+
483
+ Parameters
484
+ ----------
485
+ numeric_only : bool, default False
486
+ Include only float, int, boolean columns. If None, will attempt to use
487
+ everything, then use only numeric data.
488
+
489
+ .. versionadded:: 3.4.0
490
+ min_count : int, default -1
491
+ The required number of valid values to perform the operation. If fewer
492
+ than ``min_count`` non-NA values are present the result will be NA.
493
+
494
+ .. versionadded:: 3.4.0
495
+
496
+ See Also
497
+ --------
498
+ pyspark.pandas.Series.groupby
499
+ pyspark.pandas.DataFrame.groupby
500
+
501
+ Examples
502
+ --------
503
+ >>> df = ps.DataFrame({"A": [1, 2, 1, 2], "B": [True, False, False, True],
504
+ ... "C": [3, 3, 4, 4], "D": ["a", "a", "b", "a"]})
505
+ >>> df
506
+ A B C D
507
+ 0 1 True 3 a
508
+ 1 2 False 3 a
509
+ 2 1 False 4 b
510
+ 3 2 True 4 a
511
+
512
+ >>> df.groupby("A").last().sort_index()
513
+ B C D
514
+ A
515
+ 1 False 4 b
516
+ 2 True 4 a
517
+
518
+ Include only float, int, boolean columns when set numeric_only True.
519
+
520
+ >>> df.groupby("A").last(numeric_only=True).sort_index()
521
+ B C
522
+ A
523
+ 1 False 4
524
+ 2 True 4
525
+
526
+ >>> df.groupby("D").last().sort_index()
527
+ A B C
528
+ D
529
+ a 2 True 4
530
+ b 1 False 4
531
+
532
+ >>> df.groupby("D").last(min_count=3).sort_index()
533
+ A B C
534
+ D
535
+ a 2.0 True 4.0
536
+ b NaN None NaN
537
+ """
538
+ if not isinstance(min_count, int):
539
+ raise TypeError("min_count must be integer")
540
+
541
+ return self._reduce_for_stat_function(
542
+ lambda col: F.last(col, ignorenulls=True),
543
+ accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
544
+ min_count=min_count,
545
+ )
546
+
547
+ def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
548
+ """
549
+ Compute max of group values.
550
+
551
+ .. versionadded:: 3.3.0
552
+
553
+ Parameters
554
+ ----------
555
+ numeric_only : bool, default False
556
+ Include only float, int, boolean columns. If None, will attempt to use
557
+ everything, then use only numeric data.
558
+
559
+ .. versionadded:: 3.4.0
560
+ min_count : bool, default -1
561
+ The required number of valid values to perform the operation. If fewer
562
+ than min_count non-NA values are present the result will be NA.
563
+
564
+ .. versionadded:: 3.4.0
565
+
566
+ See Also
567
+ --------
568
+ pyspark.pandas.Series.groupby
569
+ pyspark.pandas.DataFrame.groupby
570
+
571
+ Examples
572
+ --------
573
+ >>> df = ps.DataFrame({"A": [1, 2, 1, 2], "B": [True, False, False, True],
574
+ ... "C": [3, 4, 3, 4], "D": ["a", "a", "b", "a"]})
575
+
576
+ >>> df.groupby("A").max().sort_index()
577
+ B C D
578
+ A
579
+ 1 True 3 b
580
+ 2 True 4 a
581
+
582
+ Include only float, int, boolean columns when set numeric_only True.
583
+
584
+ >>> df.groupby("A").max(numeric_only=True).sort_index()
585
+ B C
586
+ A
587
+ 1 True 3
588
+ 2 True 4
589
+
590
+ >>> df.groupby("D").max().sort_index()
591
+ A B C
592
+ D
593
+ a 2 True 4
594
+ b 1 False 3
595
+
596
+ >>> df.groupby("D").max(min_count=3).sort_index()
597
+ A B C
598
+ D
599
+ a 2.0 True 4.0
600
+ b NaN None NaN
601
+ """
602
+ if not isinstance(min_count, int):
603
+ raise TypeError("min_count must be integer")
604
+
605
+ return self._reduce_for_stat_function(
606
+ F.max,
607
+ accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
608
+ min_count=min_count,
609
+ )
610
+
611
+ def mean(self, numeric_only: Optional[bool] = True) -> FrameLike:
612
+ """
613
+ Compute mean of groups, excluding missing values.
614
+
615
+ Parameters
616
+ ----------
617
+ numeric_only : bool, default False
618
+ Include only float, int, boolean columns. If None, will attempt to use
619
+ everything, then use only numeric data.
620
+
621
+ .. versionadded:: 3.4.0
622
+
623
+ Returns
624
+ -------
625
+ pyspark.pandas.Series or pyspark.pandas.DataFrame
626
+
627
+ See Also
628
+ --------
629
+ pyspark.pandas.Series.groupby
630
+ pyspark.pandas.DataFrame.groupby
631
+
632
+ Examples
633
+ --------
634
+ >>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
635
+ ... 'B': [np.nan, 2, 3, 4, 5],
636
+ ... 'C': [1, 2, 1, 1, 2],
637
+ ... 'D': [True, False, True, False, True]})
638
+
639
+ Groupby one column and return the mean of the remaining columns in
640
+ each group.
641
+
642
+ >>> df.groupby('A').mean().sort_index() # doctest: +NORMALIZE_WHITESPACE
643
+ B C D
644
+ A
645
+ 1 3.0 1.333333 0.333333
646
+ 2 4.0 1.500000 1.000000
647
+ """
648
+ self._validate_agg_columns(numeric_only=numeric_only, function_name="median")
649
+ warnings.warn(
650
+ "Default value of `numeric_only` will be changed to `False` "
651
+ "instead of `True` in 4.0.0.",
652
+ FutureWarning,
653
+ )
654
+
655
+ return self._reduce_for_stat_function(
656
+ F.mean, accepted_spark_types=(NumericType,), bool_to_numeric=True
657
+ )
658
+
659
+ # TODO: 'q' accepts list like type
660
+ def quantile(self, q: float = 0.5, accuracy: int = 10000) -> FrameLike:
661
+ """
662
+ Return group values at the given quantile.
663
+
664
+ .. versionadded:: 3.4.0
665
+
666
+ Parameters
667
+ ----------
668
+ q : float, default 0.5 (50% quantile)
669
+ Value between 0 and 1 providing the quantile to compute.
670
+ accuracy : int, optional
671
+ Default accuracy of approximation. Larger value means better accuracy.
672
+ The relative error can be deduced by 1.0 / accuracy.
673
+ This is a panda-on-Spark specific parameter.
674
+
675
+ Returns
676
+ -------
677
+ pyspark.pandas.Series or pyspark.pandas.DataFrame
678
+ Return type determined by caller of GroupBy object.
679
+
680
+ Notes
681
+ -----
682
+ `quantile` in pandas-on-Spark are using distributed percentile approximation
683
+ algorithm unlike pandas, the result might be different with pandas, also
684
+ `interpolation` parameter is not supported yet.
685
+
686
+ See Also
687
+ --------
688
+ pyspark.pandas.Series.quantile
689
+ pyspark.pandas.DataFrame.quantile
690
+ pyspark.sql.functions.percentile_approx
691
+
692
+ Examples
693
+ --------
694
+ >>> df = ps.DataFrame([
695
+ ... ['a', 1], ['a', 2], ['a', 3],
696
+ ... ['b', 1], ['b', 3], ['b', 5]
697
+ ... ], columns=['key', 'val'])
698
+
699
+ Groupby one column and return the quantile of the remaining columns in
700
+ each group.
701
+
702
+ >>> df.groupby('key').quantile()
703
+ val
704
+ key
705
+ a 2.0
706
+ b 3.0
707
+ """
708
+ if is_list_like(q):
709
+ raise NotImplementedError("q doesn't support for list like type for now")
710
+ if not is_number(q):
711
+ raise TypeError("must be real number, not %s" % type(q).__name__)
712
+ if not 0 <= q <= 1:
713
+ raise ValueError("'q' must be between 0 and 1. Got '%s' instead" % q)
714
+ return self._reduce_for_stat_function(
715
+ lambda col: F.percentile_approx(col.cast(DoubleType()), q, accuracy),
716
+ accepted_spark_types=(NumericType, BooleanType),
717
+ bool_to_numeric=True,
718
+ )
719
+
720
+ def min(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
721
+ """
722
+ Compute min of group values.
723
+
724
+ .. versionadded:: 3.3.0
725
+
726
+ Parameters
727
+ ----------
728
+ numeric_only : bool, default False
729
+ Include only float, int, boolean columns. If None, will attempt to use
730
+ everything, then use only numeric data.
731
+
732
+ .. versionadded:: 3.4.0
733
+ min_count : bool, default -1
734
+ The required number of valid values to perform the operation. If fewer
735
+ than min_count non-NA values are present the result will be NA.
736
+
737
+ .. versionadded:: 3.4.0
738
+
739
+ See Also
740
+ --------
741
+ pyspark.pandas.Series.groupby
742
+ pyspark.pandas.DataFrame.groupby
743
+
744
+ Examples
745
+ --------
746
+ >>> df = ps.DataFrame({"A": [1, 2, 1, 2], "B": [True, False, False, True],
747
+ ... "C": [3, 4, 3, 4], "D": ["a", "a", "b", "a"]})
748
+ >>> df.groupby("A").min().sort_index()
749
+ B C D
750
+ A
751
+ 1 False 3 a
752
+ 2 False 4 a
753
+
754
+ Include only float, int, boolean columns when set numeric_only True.
755
+
756
+ >>> df.groupby("A").min(numeric_only=True).sort_index()
757
+ B C
758
+ A
759
+ 1 False 3
760
+ 2 False 4
761
+
762
+ >>> df.groupby("D").min().sort_index()
763
+ A B C
764
+ D
765
+ a 1 False 3
766
+ b 1 False 3
767
+
768
+
769
+ >>> df.groupby("D").min(min_count=3).sort_index()
770
+ A B C
771
+ D
772
+ a 1.0 False 3.0
773
+ b NaN None NaN
774
+ """
775
+ if not isinstance(min_count, int):
776
+ raise TypeError("min_count must be integer")
777
+
778
+ return self._reduce_for_stat_function(
779
+ F.min,
780
+ accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
781
+ min_count=min_count,
782
+ )
783
+
784
+ # TODO: sync the doc.
785
+ def std(self, ddof: int = 1) -> FrameLike:
786
+ """
787
+ Compute standard deviation of groups, excluding missing values.
788
+
789
+ .. versionadded:: 3.3.0
790
+
791
+ Parameters
792
+ ----------
793
+ ddof : int, default 1
794
+ Delta Degrees of Freedom. The divisor used in calculations is N - ddof,
795
+ where N represents the number of elements.
796
+
797
+ .. versionchanged:: 3.4.0
798
+ Supported including arbitary integers.
799
+
800
+ Examples
801
+ --------
802
+ >>> df = ps.DataFrame({"A": [1, 2, 1, 2], "B": [True, False, False, True],
803
+ ... "C": [3, 4, 3, 4], "D": ["a", "b", "b", "a"]})
804
+
805
+ >>> df.groupby("A").std()
806
+ B C
807
+ A
808
+ 1 0.707107 0.0
809
+ 2 0.707107 0.0
810
+
811
+ See Also
812
+ --------
813
+ pyspark.pandas.Series.groupby
814
+ pyspark.pandas.DataFrame.groupby
815
+ """
816
+ if not isinstance(ddof, int):
817
+ raise TypeError("ddof must be integer")
818
+
819
+ # Raise the TypeError when all aggregation columns are of unaccepted data types
820
+ any_accepted = any(
821
+ isinstance(_agg_col.spark.data_type, (NumericType, BooleanType))
822
+ for _agg_col in self._agg_columns
823
+ )
824
+ if not any_accepted:
825
+ raise TypeError(
826
+ "Unaccepted data types of aggregation columns; numeric or bool expected."
827
+ )
828
+
829
+ def std(col: Column) -> Column:
830
+ return SF.stddev(col, ddof)
831
+
832
+ return self._reduce_for_stat_function(
833
+ std,
834
+ accepted_spark_types=(NumericType,),
835
+ bool_to_numeric=True,
836
+ )
837
+
838
+ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameLike:
839
+ """
840
+ Compute sum of group values
841
+
842
+ .. versionadded:: 3.3.0
843
+
844
+ Parameters
845
+ ----------
846
+ numeric_only : bool, default False
847
+ Include only float, int, boolean columns. If None, will attempt to use
848
+ everything, then use only numeric data.
849
+ It takes no effect since only numeric columns can be support here.
850
+
851
+ .. versionadded:: 3.4.0
852
+ min_count : int, default 0
853
+ The required number of valid values to perform the operation.
854
+ If fewer than min_count non-NA values are present the result will be NA.
855
+
856
+ .. versionadded:: 3.4.0
857
+
858
+ Examples
859
+ --------
860
+ >>> df = ps.DataFrame({"A": [1, 2, 1, 2], "B": [True, False, False, True],
861
+ ... "C": [3, 4, 3, 4], "D": ["a", "a", "b", "a"]})
862
+
863
+ >>> df.groupby("A").sum().sort_index()
864
+ B C
865
+ A
866
+ 1 1 6
867
+ 2 1 8
868
+
869
+ >>> df.groupby("D").sum().sort_index()
870
+ A B C
871
+ D
872
+ a 5 2 11
873
+ b 1 0 3
874
+
875
+ >>> df.groupby("D").sum(min_count=3).sort_index()
876
+ A B C
877
+ D
878
+ a 5.0 2.0 11.0
879
+ b NaN NaN NaN
880
+
881
+ Notes
882
+ -----
883
+ There is a behavior difference between pandas-on-Spark and pandas:
884
+
885
+ * when there is a non-numeric aggregation column, it will be ignored
886
+ even if `numeric_only` is False.
887
+
888
+ See Also
889
+ --------
890
+ pyspark.pandas.Series.groupby
891
+ pyspark.pandas.DataFrame.groupby
892
+ """
893
+ warnings.warn(
894
+ "Default value of `numeric_only` will be changed to `False` "
895
+ "instead of `True` in 4.0.0.",
896
+ FutureWarning,
897
+ )
898
+ if numeric_only is not None and not isinstance(numeric_only, bool):
899
+ raise TypeError("numeric_only must be None or bool")
900
+ if not isinstance(min_count, int):
901
+ raise TypeError("min_count must be integer")
902
+
903
+ if numeric_only is not None and not numeric_only:
904
+ unsupported = [
905
+ col.name
906
+ for col in self._agg_columns
907
+ if not isinstance(col.spark.data_type, (NumericType, BooleanType))
908
+ ]
909
+ if len(unsupported) > 0:
910
+ log_advice(
911
+ "GroupBy.sum() can only support numeric and bool columns even if"
912
+ f"numeric_only=False, skip unsupported columns: {unsupported}"
913
+ )
914
+
915
+ return self._reduce_for_stat_function(
916
+ F.sum,
917
+ accepted_spark_types=(NumericType, BooleanType),
918
+ bool_to_numeric=True,
919
+ min_count=min_count,
920
+ )
921
+
922
+ # TODO: sync the doc.
923
+ def var(self, ddof: int = 1) -> FrameLike:
924
+ """
925
+ Compute variance of groups, excluding missing values.
926
+
927
+ .. versionadded:: 3.3.0
928
+
929
+ Parameters
930
+ ----------
931
+ ddof : int, default 1
932
+ Delta Degrees of Freedom. The divisor used in calculations is N - ddof,
933
+ where N represents the number of elements.
934
+
935
+ .. versionchanged:: 3.4.0
936
+ Supported including arbitary integers.
937
+
938
+ Examples
939
+ --------
940
+ >>> df = ps.DataFrame({"A": [1, 2, 1, 2], "B": [True, False, False, True],
941
+ ... "C": [3, 4, 3, 4], "D": ["a", "b", "b", "a"]})
942
+
943
+ >>> df.groupby("A").var()
944
+ B C
945
+ A
946
+ 1 0.5 0.0
947
+ 2 0.5 0.0
948
+
949
+ See Also
950
+ --------
951
+ pyspark.pandas.Series.groupby
952
+ pyspark.pandas.DataFrame.groupby
953
+ """
954
+ if not isinstance(ddof, int):
955
+ raise TypeError("ddof must be integer")
956
+
957
+ def var(col: Column) -> Column:
958
+ return SF.var(col, ddof)
959
+
960
+ return self._reduce_for_stat_function(
961
+ var,
962
+ accepted_spark_types=(NumericType,),
963
+ bool_to_numeric=True,
964
+ )
965
+
966
+ def skew(self) -> FrameLike:
967
+ """
968
+ Compute skewness of groups, excluding missing values.
969
+
970
+ .. versionadded:: 3.4.0
971
+
972
+ Examples
973
+ --------
974
+ >>> df = ps.DataFrame({"A": [1, 2, 1, 1], "B": [True, False, False, True],
975
+ ... "C": [3, 4, 3, 4], "D": ["a", "b", "b", "a"]})
976
+
977
+ >>> df.groupby("A").skew()
978
+ B C
979
+ A
980
+ 1 -1.732051 1.732051
981
+ 2 NaN NaN
982
+
983
+ See Also
984
+ --------
985
+ pyspark.pandas.Series.groupby
986
+ pyspark.pandas.DataFrame.groupby
987
+ """
988
+ return self._reduce_for_stat_function(
989
+ SF.skew,
990
+ accepted_spark_types=(NumericType,),
991
+ bool_to_numeric=True,
992
+ )
993
+
994
+ # TODO: 'axis', 'skipna', 'level' parameter should be implemented.
995
+ def mad(self) -> FrameLike:
996
+ """
997
+ Compute mean absolute deviation of groups, excluding missing values.
998
+
999
+ .. versionadded:: 3.4.0
1000
+
1001
+ .. deprecated:: 3.4.0
1002
+
1003
+ Examples
1004
+ --------
1005
+ >>> df = ps.DataFrame({"A": [1, 2, 1, 1], "B": [True, False, False, True],
1006
+ ... "C": [3, 4, 3, 4], "D": ["a", "b", "b", "a"]})
1007
+
1008
+ >>> df.groupby("A").mad()
1009
+ B C
1010
+ A
1011
+ 1 0.444444 0.444444
1012
+ 2 0.000000 0.000000
1013
+
1014
+ >>> df.B.groupby(df.A).mad()
1015
+ A
1016
+ 1 0.444444
1017
+ 2 0.000000
1018
+ Name: B, dtype: float64
1019
+
1020
+ See Also
1021
+ --------
1022
+ pyspark.pandas.Series.groupby
1023
+ pyspark.pandas.DataFrame.groupby
1024
+ """
1025
+ warnings.warn(
1026
+ "The 'mad' method is deprecated and will be removed in a future version. "
1027
+ "To compute the same result, you may do `(group_df - group_df.mean()).abs().mean()`.",
1028
+ FutureWarning,
1029
+ )
1030
+ groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
1031
+ internal, agg_columns, sdf = self._prepare_reduce(
1032
+ groupkey_names=groupkey_names,
1033
+ accepted_spark_types=(NumericType, BooleanType),
1034
+ bool_to_numeric=False,
1035
+ )
1036
+ psdf: DataFrame = DataFrame(internal)
1037
+
1038
+ if len(psdf._internal.column_labels) > 0:
1039
+ window = Window.partitionBy(groupkey_names).rowsBetween(
1040
+ Window.unboundedPreceding, Window.unboundedFollowing
1041
+ )
1042
+ new_agg_scols = {}
1043
+ new_stat_scols = []
1044
+ for agg_column in agg_columns:
1045
+ # We cannot use 'self._reduce_for_stat_function' directly here, because
1046
+ # "it is not allowed to use a window function inside an aggregate function";
1047
+ # instead, we create temporary columns to compute 'abs(x - avg(x))' first.
1048
+ agg_column_name = agg_column._internal.data_spark_column_names[0]
1049
+ new_agg_column_name = verify_temp_column_name(
1050
+ psdf._internal.spark_frame, "__tmp_agg_col_{}__".format(agg_column_name)
1051
+ )
1052
+ casted_agg_scol = F.col(agg_column_name).cast("double")
1053
+ new_agg_scols[new_agg_column_name] = F.abs(
1054
+ casted_agg_scol - F.avg(casted_agg_scol).over(window)
1055
+ )
1056
+ new_stat_scols.append(F.avg(F.col(new_agg_column_name)).alias(agg_column_name))
1057
+
1058
+ sdf = (
1059
+ psdf._internal.spark_frame.withColumns(new_agg_scols)
1060
+ .groupby(groupkey_names)
1061
+ .agg(*new_stat_scols)
1062
+ )
1063
+ else:
1064
+ sdf = sdf.select(*groupkey_names).distinct()
1065
+
1066
+ internal = internal.copy(
1067
+ spark_frame=sdf,
1068
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
1069
+ data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
1070
+ data_fields=None,
1071
+ )
1072
+
1073
+ return self._prepare_return(DataFrame(internal))
1074
+
1075
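The `mad` implementation above computes abs(x - avg(x)) with a window, then averages it in an ordinary aggregation. A standalone sketch of the same pattern on a plain Spark DataFrame (editorial; the session, column names, and data are assumptions):

from pyspark.sql import SparkSession, Window, functions as F

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([(1, 3.0), (1, 5.0), (2, 4.0)], ["key", "val"])
w = Window.partitionBy("key").rowsBetween(
    Window.unboundedPreceding, Window.unboundedFollowing
)
# absolute deviation from the per-group mean, then its per-group average == MAD
mad_sdf = (
    sdf.withColumn("abs_dev", F.abs(F.col("val") - F.avg("val").over(w)))
    .groupBy("key")
    .agg(F.avg("abs_dev").alias("val_mad"))
)
mad_sdf.show()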
+ def sem(self, ddof: int = 1) -> FrameLike:
1076
+ """
1077
+ Compute standard error of the mean of groups, excluding missing values.
1078
+
1079
+ .. versionadded:: 3.4.0
1080
+
1081
+ Parameters
1082
+ ----------
1083
+ ddof : int, default 1
1084
+ Delta Degrees of Freedom. The divisor used in calculations is N - ddof,
1085
+ where N represents the number of elements.
1086
+
1087
+ Examples
1088
+ --------
1089
+ >>> df = ps.DataFrame({"A": [1, 2, 1, 1], "B": [True, False, False, True],
1090
+ ... "C": [3, None, 3, 4], "D": ["a", "b", "b", "a"]})
1091
+
1092
+ >>> df.groupby("A").sem()
1093
+ B C
1094
+ A
1095
+ 1 0.333333 0.333333
1096
+ 2 NaN NaN
1097
+
1098
+ >>> df.groupby("D").sem(ddof=1)
1099
+ A B C
1100
+ D
1101
+ a 0.0 0.0 0.5
1102
+ b 0.5 0.0 NaN
1103
+
1104
+ >>> df.B.groupby(df.A).sem()
1105
+ A
1106
+ 1 0.333333
1107
+ 2 NaN
1108
+ Name: B, dtype: float64
1109
+
1110
+ See Also
1111
+ --------
1112
+ pyspark.pandas.Series.sem
1113
+ pyspark.pandas.DataFrame.sem
1114
+ """
1115
+ if not isinstance(ddof, int):
1116
+ raise TypeError("ddof must be integer")
1117
+
1118
+ # Raise the TypeError when all aggregation columns are of unaccepted data types
1119
+ any_accepted = any(
1120
+ isinstance(_agg_col.spark.data_type, (NumericType, BooleanType))
1121
+ for _agg_col in self._agg_columns
1122
+ )
1123
+ if not any_accepted:
1124
+ raise TypeError(
1125
+ "Unaccepted data types of aggregation columns; numeric or bool expected."
1126
+ )
1127
+
1128
+ def sem(col: Column) -> Column:
1129
+ return SF.stddev(col, ddof) / F.sqrt(F.count(col))
1130
+
1131
+ return self._reduce_for_stat_function(
1132
+ sem,
1133
+ accepted_spark_types=(NumericType, BooleanType),
1134
+ bool_to_numeric=True,
1135
+ )
1136
+
1137
+ # TODO: 1, 'n' accepts list and slice; 2, implement 'dropna' parameter
1138
+ def nth(self, n: int) -> FrameLike:
1139
+ """
1140
+ Take the nth row from each group.
1141
+
1142
+ .. versionadded:: 3.4.0
1143
+
1144
+ Parameters
1145
+ ----------
1146
+ n : int
1147
+ A single nth value for the row
1148
+
1149
+ Returns
1150
+ -------
1151
+ Series or DataFrame
1152
+
1153
+ Notes
1154
+ -----
1155
+ There is a behavior difference between pandas-on-Spark and pandas:
1156
+
1157
+ * when there is no aggregation column, and `n` is not equal to 0 or -1,
1158
+ the returned empty dataframe may have an index with a different length `__len__`.
1159
+
1160
+ Examples
1161
+ --------
1162
+ >>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
1163
+ ... 'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
1164
+ >>> g = df.groupby('A')
1165
+ >>> g.nth(0)
1166
+ B
1167
+ A
1168
+ 1 NaN
1169
+ 2 3.0
1170
+ >>> g.nth(1)
1171
+ B
1172
+ A
1173
+ 1 2.0
1174
+ 2 5.0
1175
+ >>> g.nth(-1)
1176
+ B
1177
+ A
1178
+ 1 4.0
1179
+ 2 5.0
1180
+
1181
+ See Also
1182
+ --------
1183
+ pyspark.pandas.Series.groupby
1184
+ pyspark.pandas.DataFrame.groupby
1185
+ """
1186
+ if isinstance(n, slice) or is_list_like(n):
1187
+ raise NotImplementedError("n doesn't support slice or list for now")
1188
+ if not isinstance(n, int):
1189
+ raise TypeError("Invalid index %s" % type(n).__name__)
1190
+
1191
+ groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
1192
+ internal, agg_columns, sdf = self._prepare_reduce(
1193
+ groupkey_names=groupkey_names,
1194
+ accepted_spark_types=None,
1195
+ bool_to_numeric=False,
1196
+ )
1197
+ psdf: DataFrame = DataFrame(internal)
1198
+
1199
+ if len(psdf._internal.column_labels) > 0:
1200
+ window1 = Window.partitionBy(*groupkey_names).orderBy(NATURAL_ORDER_COLUMN_NAME)
1201
+ tmp_row_number_col = verify_temp_column_name(sdf, "__tmp_row_number_col__")
1202
+ if n >= 0:
1203
+ sdf = (
1204
+ psdf._internal.spark_frame.withColumn(
1205
+ tmp_row_number_col, F.row_number().over(window1)
1206
+ )
1207
+ .where(F.col(tmp_row_number_col) == n + 1)
1208
+ .drop(tmp_row_number_col)
1209
+ )
1210
+ else:
1211
+ window2 = Window.partitionBy(*groupkey_names).rowsBetween(
1212
+ Window.unboundedPreceding, Window.unboundedFollowing
1213
+ )
1214
+ tmp_group_size_col = verify_temp_column_name(sdf, "__tmp_group_size_col__")
1215
+ sdf = (
1216
+ psdf._internal.spark_frame.withColumn(
1217
+ tmp_group_size_col, F.count(F.lit(0)).over(window2)
1218
+ )
1219
+ .withColumn(tmp_row_number_col, F.row_number().over(window1))
1220
+ .where(F.col(tmp_row_number_col) == F.col(tmp_group_size_col) + 1 + n)
1221
+ .drop(tmp_group_size_col, tmp_row_number_col)
1222
+ )
1223
+ else:
1224
+ sdf = sdf.select(*groupkey_names).distinct()
1225
+
1226
+ internal = internal.copy(
1227
+ spark_frame=sdf,
1228
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
1229
+ data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
1230
+ data_fields=None,
1231
+ )
1232
+
1233
+ return self._prepare_return(DataFrame(internal))
1234
+
1235
+ def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameLike:
1236
+ """
1237
+ Compute prod of groups.
1238
+
1239
+ .. versionadded:: 3.4.0
1240
+
1241
+ Parameters
1242
+ ----------
1243
+ numeric_only : bool, default True
1244
+ Include only float, int, boolean columns. If None, will attempt to use
1245
+ everything, then use only numeric data.
1246
+
1247
+ min_count : int, default 0
1248
+ The required number of valid values to perform the operation.
1249
+ If fewer than min_count non-NA values are present the result will be NA.
1250
+
1251
+ Returns
1252
+ -------
1253
+ Series or DataFrame
1254
+ Computed prod of values within each group.
1255
+
1256
+ See Also
1257
+ --------
1258
+ pyspark.pandas.Series.groupby
1259
+ pyspark.pandas.DataFrame.groupby
1260
+
1261
+ Examples
1262
+ --------
1263
+ >>> import numpy as np
1264
+ >>> df = ps.DataFrame(
1265
+ ... {
1266
+ ... "A": [1, 1, 2, 1, 2],
1267
+ ... "B": [np.nan, 2, 3, 4, 5],
1268
+ ... "C": [1, 2, 1, 1, 2],
1269
+ ... "D": [True, False, True, False, True],
1270
+ ... }
1271
+ ... )
1272
+
1273
+ Groupby one column and return the prod of the remaining columns in
1274
+ each group.
1275
+
1276
+ >>> df.groupby('A').prod().sort_index()
1277
+ B C D
1278
+ A
1279
+ 1 8.0 2 0
1280
+ 2 15.0 2 1
1281
+
1282
+ >>> df.groupby('A').prod(min_count=3).sort_index()
1283
+ B C D
1284
+ A
1285
+ 1 NaN 2.0 0.0
1286
+ 2 NaN NaN NaN
1287
+ """
1288
+ if not isinstance(min_count, int):
1289
+ raise TypeError("min_count must be integer")
1290
+
1291
+ warnings.warn(
1292
+ "Default value of `numeric_only` will be changed to `False` "
1293
+ "instead of `True` in 4.0.0.",
1294
+ FutureWarning,
1295
+ )
1296
+
1297
+ self._validate_agg_columns(numeric_only=numeric_only, function_name="prod")
1298
+
1299
+ return self._reduce_for_stat_function(
1300
+ lambda col: SF.product(col, True),
1301
+ accepted_spark_types=(NumericType, BooleanType),
1302
+ bool_to_numeric=True,
1303
+ min_count=min_count,
1304
+ )
1305
+
1306
+ def all(self, skipna: bool = True) -> FrameLike:
1307
+ """
1308
+ Returns True if all values in the group are truthful, else False.
1309
+
1310
+ Parameters
1311
+ ----------
1312
+ skipna : bool, default True
1313
+ Flag to ignore NA(nan/null) values during truth testing.
1314
+
1315
+ See Also
1316
+ --------
1317
+ pyspark.pandas.Series.groupby
1318
+ pyspark.pandas.DataFrame.groupby
1319
+
1320
+ Examples
1321
+ --------
1322
+ >>> df = ps.DataFrame({'A': [1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
1323
+ ... 'B': [True, True, True, False, False,
1324
+ ... False, None, True, None, False]},
1325
+ ... columns=['A', 'B'])
1326
+ >>> df
1327
+ A B
1328
+ 0 1 True
1329
+ 1 1 True
1330
+ 2 2 True
1331
+ 3 2 False
1332
+ 4 3 False
1333
+ 5 3 False
1334
+ 6 4 None
1335
+ 7 4 True
1336
+ 8 5 None
1337
+ 9 5 False
1338
+
1339
+ >>> df.groupby('A').all().sort_index() # doctest: +NORMALIZE_WHITESPACE
1340
+ B
1341
+ A
1342
+ 1 True
1343
+ 2 False
1344
+ 3 False
1345
+ 4 True
1346
+ 5 False
1347
+
1348
+ >>> df.groupby('A').all(skipna=False).sort_index() # doctest: +NORMALIZE_WHITESPACE
1349
+ B
1350
+ A
1351
+ 1 True
1352
+ 2 False
1353
+ 3 False
1354
+ 4 False
1355
+ 5 False
1356
+ """
1357
+ groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
1358
+ internal, _, sdf = self._prepare_reduce(groupkey_names)
1359
+ psdf: DataFrame = DataFrame(internal)
1360
+
1361
+ def sfun(scol: Column, scol_type: DataType) -> Column:
1362
+ if isinstance(scol_type, NumericType) or skipna:
1363
+ # np.nan has no effect on the result; None has no effect when `skipna` is True
1364
+ all_col = F.min(F.coalesce(scol.cast("boolean"), F.lit(True)))
1365
+ else:
1366
+ # Treat None as False when `skipna` is False
1367
+ all_col = F.min(F.when(scol.isNull(), F.lit(False)).otherwise(scol.cast("boolean")))
1368
+ return all_col
1369
+
1370
+ if len(psdf._internal.column_labels) > 0:
1371
+ stat_exprs = []
1372
+ for label in psdf._internal.column_labels:
1373
+ psser = psdf._psser_for(label)
1374
+ stat_exprs.append(
1375
+ sfun(
1376
+ psser._dtype_op.nan_to_null(psser).spark.column, psser.spark.data_type
1377
+ ).alias(psser._internal.data_spark_column_names[0])
1378
+ )
1379
+ sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
1380
+ else:
1381
+ sdf = sdf.select(*groupkey_names).distinct()
1382
+
1383
+ internal = internal.copy(
1384
+ spark_frame=sdf,
1385
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
1386
+ data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
1387
+ data_fields=None,
1388
+ )
1389
+
1390
+ return self._prepare_return(DataFrame(internal))
1391
+
1392
+ # TODO: skipna should be implemented.
1393
+ def any(self) -> FrameLike:
1394
+ """
1395
+ Returns True if any value in the group is truthful, else False.
1396
+
1397
+ See Also
1398
+ --------
1399
+ pyspark.pandas.Series.groupby
1400
+ pyspark.pandas.DataFrame.groupby
1401
+
1402
+ Examples
1403
+ --------
1404
+ >>> df = ps.DataFrame({'A': [1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
1405
+ ... 'B': [True, True, True, False, False,
1406
+ ... False, None, True, None, False]},
1407
+ ... columns=['A', 'B'])
1408
+ >>> df
1409
+ A B
1410
+ 0 1 True
1411
+ 1 1 True
1412
+ 2 2 True
1413
+ 3 2 False
1414
+ 4 3 False
1415
+ 5 3 False
1416
+ 6 4 None
1417
+ 7 4 True
1418
+ 8 5 None
1419
+ 9 5 False
1420
+
1421
+ >>> df.groupby('A').any().sort_index() # doctest: +NORMALIZE_WHITESPACE
1422
+ B
1423
+ A
1424
+ 1 True
1425
+ 2 True
1426
+ 3 False
1427
+ 4 True
1428
+ 5 False
1429
+ """
1430
+ return self._reduce_for_stat_function(
1431
+ lambda col: F.max(F.coalesce(col.cast("boolean"), F.lit(False)))
1432
+ )
1433
+
1434
+ # TODO: groupby on multiple columns should be implemented.
1435
+ def size(self) -> Series:
1436
+ """
1437
+ Compute group sizes.
1438
+
1439
+ See Also
1440
+ --------
1441
+ pyspark.pandas.Series.groupby
1442
+ pyspark.pandas.DataFrame.groupby
1443
+
1444
+ Examples
1445
+ --------
1446
+ >>> df = ps.DataFrame({'A': [1, 2, 2, 3, 3, 3],
1447
+ ... 'B': [1, 1, 2, 3, 3, 3]},
1448
+ ... columns=['A', 'B'])
1449
+ >>> df
1450
+ A B
1451
+ 0 1 1
1452
+ 1 2 1
1453
+ 2 2 2
1454
+ 3 3 3
1455
+ 4 3 3
1456
+ 5 3 3
1457
+
1458
+ >>> df.groupby('A').size().sort_index()
1459
+ A
1460
+ 1 1
1461
+ 2 2
1462
+ 3 3
1463
+ dtype: int64
1464
+
1465
+ >>> df.groupby(['A', 'B']).size().sort_index()
1466
+ A B
1467
+ 1 1 1
1468
+ 2 1 1
1469
+ 2 1
1470
+ 3 3 3
1471
+ dtype: int64
1472
+
1473
+ For Series,
1474
+
1475
+ >>> df.B.groupby(df.A).size().sort_index()
1476
+ A
1477
+ 1 1
1478
+ 2 2
1479
+ 3 3
1480
+ Name: B, dtype: int64
1481
+
1482
+ >>> df.groupby(df.A).B.size().sort_index()
1483
+ A
1484
+ 1 1
1485
+ 2 2
1486
+ 3 3
1487
+ Name: B, dtype: int64
1488
+ """
1489
+ groupkeys = self._groupkeys
1490
+ groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(groupkeys))]
1491
+ groupkey_scols = [s.spark.column.alias(name) for s, name in zip(groupkeys, groupkey_names)]
1492
+ sdf = self._psdf._internal.spark_frame.select(
1493
+ groupkey_scols + self._psdf._internal.data_spark_columns
1494
+ )
1495
+ sdf = sdf.groupby(*groupkey_names).count()
1496
+ internal = InternalFrame(
1497
+ spark_frame=sdf,
1498
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
1499
+ index_names=[psser._column_label for psser in groupkeys],
1500
+ index_fields=[
1501
+ psser._internal.data_fields[0].copy(name=name)
1502
+ for psser, name in zip(groupkeys, groupkey_names)
1503
+ ],
1504
+ column_labels=[None],
1505
+ data_spark_columns=[scol_for(sdf, "count")],
1506
+ )
1507
+ return first_series(DataFrame(internal))
1508
+
1509
+ def diff(self, periods: int = 1) -> FrameLike:
1510
+ """
1511
+ First discrete difference of element.
1512
+
1513
+ Calculates the difference of a DataFrame element compared with another element in the
1514
+ DataFrame group (default is the element in the same column of the previous row).
1515
+
1516
+ Parameters
1517
+ ----------
1518
+ periods : int, default 1
1519
+ Periods to shift for calculating difference, accepts negative values.
1520
+
1521
+ Returns
1522
+ -------
1523
+ diffed : DataFrame or Series
1524
+
1525
+ See Also
1526
+ --------
1527
+ pyspark.pandas.Series.groupby
1528
+ pyspark.pandas.DataFrame.groupby
1529
+
1530
+ Examples
1531
+ --------
1532
+ >>> df = ps.DataFrame({'a': [1, 2, 3, 4, 5, 6],
1533
+ ... 'b': [1, 1, 2, 3, 5, 8],
1534
+ ... 'c': [1, 4, 9, 16, 25, 36]}, columns=['a', 'b', 'c'])
1535
+ >>> df
1536
+ a b c
1537
+ 0 1 1 1
1538
+ 1 2 1 4
1539
+ 2 3 2 9
1540
+ 3 4 3 16
1541
+ 4 5 5 25
1542
+ 5 6 8 36
1543
+
1544
+ >>> df.groupby(['b']).diff().sort_index()
1545
+ a c
1546
+ 0 NaN NaN
1547
+ 1 1.0 3.0
1548
+ 2 NaN NaN
1549
+ 3 NaN NaN
1550
+ 4 NaN NaN
1551
+ 5 NaN NaN
1552
+
1553
+ Difference with previous value in a group.
1554
+
1555
+ >>> df.groupby(['b'])['a'].diff().sort_index()
1556
+ 0 NaN
1557
+ 1 1.0
1558
+ 2 NaN
1559
+ 3 NaN
1560
+ 4 NaN
1561
+ 5 NaN
1562
+ Name: a, dtype: float64
1563
+ """
1564
+ return self._apply_series_op(
1565
+ lambda sg: sg._psser._diff(periods, part_cols=sg._groupkeys_scols), should_resolve=True
1566
+ )
1567
+
1568
+ def cumcount(self, ascending: bool = True) -> Series:
1569
+ """
1570
+ Number each item in each group from 0 to the length of that group - 1.
1571
+
1572
+ Essentially this is equivalent to
1573
+
1574
+ .. code-block:: python
1575
+
1576
+ self.apply(lambda x: pd.Series(np.arange(len(x)), x.index))
1577
+
1578
+ Parameters
1579
+ ----------
1580
+ ascending : bool, default True
1581
+ If False, number in reverse, from length of group - 1 to 0.
1582
+
1583
+ Returns
1584
+ -------
1585
+ Series
1586
+ Sequence number of each element within each group.
1587
+
1588
+ Examples
1589
+ --------
1590
+
1591
+ >>> df = ps.DataFrame([['a'], ['a'], ['a'], ['b'], ['b'], ['a']],
1592
+ ... columns=['A'])
1593
+ >>> df
1594
+ A
1595
+ 0 a
1596
+ 1 a
1597
+ 2 a
1598
+ 3 b
1599
+ 4 b
1600
+ 5 a
1601
+ >>> df.groupby('A').cumcount().sort_index()
1602
+ 0 0
1603
+ 1 1
1604
+ 2 2
1605
+ 3 0
1606
+ 4 1
1607
+ 5 3
1608
+ dtype: int64
1609
+ >>> df.groupby('A').cumcount(ascending=False).sort_index()
1610
+ 0 3
1611
+ 1 2
1612
+ 2 1
1613
+ 3 1
1614
+ 4 0
1615
+ 5 0
1616
+ dtype: int64
1617
+ """
1618
+ ret = (
1619
+ self._groupkeys[0]
1620
+ .rename()
1621
+ .spark.transform(lambda _: F.lit(0))
1622
+ ._cum(F.count, True, part_cols=self._groupkeys_scols, ascending=ascending)
1623
+ - 1
1624
+ )
1625
+ internal = ret._internal.resolved_copy
1626
+ return first_series(DataFrame(internal))
1627
+
1628
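`cumcount` above is a windowed count over a constant column, minus one. The same idea on a plain Spark DataFrame (editorial sketch; "ord" stands in for the internal natural-order column):

from pyspark.sql import SparkSession, Window, functions as F

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([("a", 0), ("a", 1), ("b", 2), ("a", 3)], ["key", "ord"])
w = Window.partitionBy("key").orderBy("ord").rowsBetween(
    Window.unboundedPreceding, Window.currentRow
)
sdf.withColumn("cumcount", F.count(F.lit(0)).over(w) - 1).show()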
+ def cummax(self) -> FrameLike:
1629
+ """
1630
+ Cumulative max for each group.
1631
+
1632
+ Returns
1633
+ -------
1634
+ Series or DataFrame
1635
+
1636
+ See Also
1637
+ --------
1638
+ Series.cummax
1639
+ DataFrame.cummax
1640
+
1641
+ Examples
1642
+ --------
1643
+ >>> df = ps.DataFrame(
1644
+ ... [[1, None, 4], [1, 0.1, 3], [1, 20.0, 2], [4, 10.0, 1]],
1645
+ ... columns=list('ABC'))
1646
+ >>> df
1647
+ A B C
1648
+ 0 1 NaN 4
1649
+ 1 1 0.1 3
1650
+ 2 1 20.0 2
1651
+ 3 4 10.0 1
1652
+
1653
+ By default, iterates over rows and finds the cumulative max in each column.
1654
+
1655
+ >>> df.groupby("A").cummax().sort_index()
1656
+ B C
1657
+ 0 NaN 4
1658
+ 1 0.1 4
1659
+ 2 20.0 4
1660
+ 3 10.0 1
1661
+
1662
+ It works as below in Series.
1663
+
1664
+ >>> df.C.groupby(df.A).cummax().sort_index()
1665
+ 0 4
1666
+ 1 4
1667
+ 2 4
1668
+ 3 1
1669
+ Name: C, dtype: int64
1670
+ """
1671
+ return self._apply_series_op(
1672
+ lambda sg: sg._psser._cum(F.max, True, part_cols=sg._groupkeys_scols),
1673
+ should_resolve=True,
1674
+ numeric_only=True,
1675
+ )
1676
+
1677
+ def cummin(self) -> FrameLike:
1678
+ """
1679
+ Cumulative min for each group.
1680
+
1681
+ Returns
1682
+ -------
1683
+ Series or DataFrame
1684
+
1685
+ See Also
1686
+ --------
1687
+ Series.cummin
1688
+ DataFrame.cummin
1689
+
1690
+ Examples
1691
+ --------
1692
+ >>> df = ps.DataFrame(
1693
+ ... [[1, None, 4], [1, 0.1, 3], [1, 20.0, 2], [4, 10.0, 1]],
1694
+ ... columns=list('ABC'))
1695
+ >>> df
1696
+ A B C
1697
+ 0 1 NaN 4
1698
+ 1 1 0.1 3
1699
+ 2 1 20.0 2
1700
+ 3 4 10.0 1
1701
+
1702
+ By default, iterates over rows and finds the cumulative min in each column.
1703
+
1704
+ >>> df.groupby("A").cummin().sort_index()
1705
+ B C
1706
+ 0 NaN 4
1707
+ 1 0.1 3
1708
+ 2 0.1 2
1709
+ 3 10.0 1
1710
+
1711
+ It works as below in Series.
1712
+
1713
+ >>> df.B.groupby(df.A).cummin().sort_index()
1714
+ 0 NaN
1715
+ 1 0.1
1716
+ 2 0.1
1717
+ 3 10.0
1718
+ Name: B, dtype: float64
1719
+ """
1720
+ return self._apply_series_op(
1721
+ lambda sg: sg._psser._cum(F.min, True, part_cols=sg._groupkeys_scols),
1722
+ should_resolve=True,
1723
+ numeric_only=True,
1724
+ )
1725
+
1726
+ def cumprod(self) -> FrameLike:
1727
+ """
1728
+ Cumulative product for each group.
1729
+
1730
+ Returns
1731
+ -------
1732
+ Series or DataFrame
1733
+
1734
+ See Also
1735
+ --------
1736
+ Series.cumprod
1737
+ DataFrame.cumprod
1738
+
1739
+ Examples
1740
+ --------
1741
+ >>> df = ps.DataFrame(
1742
+ ... [[1, None, 4], [1, 0.1, 3], [1, 20.0, 2], [4, 10.0, 1]],
1743
+ ... columns=list('ABC'))
1744
+ >>> df
1745
+ A B C
1746
+ 0 1 NaN 4
1747
+ 1 1 0.1 3
1748
+ 2 1 20.0 2
1749
+ 3 4 10.0 1
1750
+
1751
+ By default, iterates over rows and finds the cumulative product in each column.
1752
+
1753
+ >>> df.groupby("A").cumprod().sort_index()
1754
+ B C
1755
+ 0 NaN 4
1756
+ 1 0.1 12
1757
+ 2 2.0 24
1758
+ 3 10.0 1
1759
+
1760
+ It works as below in Series.
1761
+
1762
+ >>> df.B.groupby(df.A).cumprod().sort_index()
1763
+ 0 NaN
1764
+ 1 0.1
1765
+ 2 2.0
1766
+ 3 10.0
1767
+ Name: B, dtype: float64
1768
+ """
1769
+ return self._apply_series_op(
1770
+ lambda sg: sg._psser._cumprod(True, part_cols=sg._groupkeys_scols),
1771
+ should_resolve=True,
1772
+ numeric_only=True,
1773
+ )
1774
+
1775
+ def cumsum(self) -> FrameLike:
1776
+ """
1777
+ Cumulative sum for each group.
1778
+
1779
+ Returns
1780
+ -------
1781
+ Series or DataFrame
1782
+
1783
+ See Also
1784
+ --------
1785
+ Series.cumsum
1786
+ DataFrame.cumsum
1787
+
1788
+ Examples
1789
+ --------
1790
+ >>> df = ps.DataFrame(
1791
+ ... [[1, None, 4], [1, 0.1, 3], [1, 20.0, 2], [4, 10.0, 1]],
1792
+ ... columns=list('ABC'))
1793
+ >>> df
1794
+ A B C
1795
+ 0 1 NaN 4
1796
+ 1 1 0.1 3
1797
+ 2 1 20.0 2
1798
+ 3 4 10.0 1
1799
+
1800
+ By default, iterates over rows and finds the cumulative sum in each column.
1801
+
1802
+ >>> df.groupby("A").cumsum().sort_index()
1803
+ B C
1804
+ 0 NaN 4
1805
+ 1 0.1 7
1806
+ 2 20.1 9
1807
+ 3 10.0 1
1808
+
1809
+ It works as below in Series.
1810
+
1811
+ >>> df.B.groupby(df.A).cumsum().sort_index()
1812
+ 0 NaN
1813
+ 1 0.1
1814
+ 2 20.1
1815
+ 3 10.0
1816
+ Name: B, dtype: float64
1817
+ """
1818
+ return self._apply_series_op(
1819
+ lambda sg: sg._psser._cumsum(True, part_cols=sg._groupkeys_scols),
1820
+ should_resolve=True,
1821
+ numeric_only=True,
1822
+ )
1823
+
1824
+ def apply(self, func: Callable, *args: Any, **kwargs: Any) -> Union[DataFrame, Series]:
1825
+ """
1826
+ Apply function `func` group-wise and combine the results together.
1827
+
1828
+ The function passed to `apply` must take a DataFrame as its first
1829
+ argument and return a DataFrame. `apply` will
1830
+ then take care of combining the results back together into a single
1831
+ dataframe. `apply` is therefore a highly flexible
1832
+ grouping method.
1833
+
1834
+ While `apply` is a very flexible method, its downside is that
1835
+ using it can be quite a bit slower than using more specific methods
1836
+ like `agg` or `transform`. pandas-on-Spark offers a wide range of methods that will
1837
+ be much faster than using `apply` for their specific purposes, so try to
1838
+ use them before reaching for `apply`.
1839
+
1840
+ .. note:: this API executes the function once to infer the type which is
1841
+ potentially expensive, for instance, when the dataset is created after
1842
+ aggregations or sorting.
1843
+
1844
+ To avoid this, specify return type in ``func``, for instance, as below:
1845
+
1846
+ >>> def pandas_div(x) -> ps.DataFrame[int, [float, float]]:
1847
+ ... return x[['B', 'C']] / x[['B', 'C']]
1848
+
1849
+ If the return type is specified, the output column names become
1850
+ `c0, c1, c2 ... cn`. These names are positionally mapped to the returned
1851
+ DataFrame in ``func``.
1852
+
1853
+ To specify the column names, you can assign them in a NumPy compound type style
1854
+ as below:
1855
+
1856
+ >>> def pandas_div(x) -> ps.DataFrame[("index", int), [("a", float), ("b", float)]]:
1857
+ ... return x[['B', 'C']] / x[['B', 'C']]
1858
+
1859
+ >>> pdf = pd.DataFrame({'B': [1.], 'C': [3.]})
1860
+ >>> def plus_one(x) -> ps.DataFrame[
1861
+ ... (pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)]:
1862
+ ... return x[['B', 'C']] / x[['B', 'C']]
1863
+
1864
+ .. note:: the dataframe within ``func`` is actually a pandas dataframe. Therefore,
1865
+ any pandas API within this function is allowed.
1866
+
1867
+ Parameters
1868
+ ----------
1869
+ func : callable
1870
+ A callable that takes a DataFrame as its first argument, and
1871
+ returns a dataframe.
1872
+ *args
1873
+ Positional arguments to pass to func.
1874
+ **kwargs
1875
+ Keyword arguments to pass to func.
1876
+
1877
+ Returns
1878
+ -------
1879
+ applied : DataFrame or Series
1880
+
1881
+ See Also
1882
+ --------
1883
+ aggregate : Apply aggregate function to the GroupBy object.
1884
+ DataFrame.apply : Apply a function to a DataFrame.
1885
+ Series.apply : Apply a function to a Series.
1886
+
1887
+ Examples
1888
+ --------
1889
+ >>> df = ps.DataFrame({'A': 'a a b'.split(),
1890
+ ... 'B': [1, 2, 3],
1891
+ ... 'C': [4, 6, 5]}, columns=['A', 'B', 'C'])
1892
+ >>> g = df.groupby('A')
1893
+
1894
+ Notice that ``g`` has two groups, ``a`` and ``b``.
1895
+ Calling `apply` in various ways, we can get different grouping results:
1896
+
1897
+ Below, the function passed to `apply` takes a DataFrame as
1898
+ its argument and returns a DataFrame. `apply` combines the results for
1899
+ each group into a new DataFrame:
1900
+
1901
+ >>> def plus_min(x):
1902
+ ... return x + x.min()
1903
+ >>> g.apply(plus_min).sort_index() # doctest: +SKIP
1904
+ A B C
1905
+ 0 aa 2 8
1906
+ 1 aa 3 10
1907
+ 2 bb 6 10
1908
+
1909
+ >>> g.apply(sum).sort_index() # doctest: +NORMALIZE_WHITESPACE
1910
+ A B C
1911
+ A
1912
+ a aa 3 10
1913
+ b b 3 5
1914
+
1915
+ >>> g.apply(len).sort_index() # doctest: +NORMALIZE_WHITESPACE
1916
+ A
1917
+ a 2
1918
+ b 1
1919
+ dtype: int64
1920
+
1921
+ You can specify the type hint and prevent schema inference for better performance.
1922
+
1923
+ >>> def pandas_div(x) -> ps.DataFrame[int, [float, float]]:
1924
+ ... return x[['B', 'C']] / x[['B', 'C']]
1925
+ >>> g.apply(pandas_div).sort_index() # doctest: +SKIP
1926
+ c0 c1
1927
+ 0 1.0 1.0
1928
+ 1 1.0 1.0
1929
+ 2 1.0 1.0
1930
+
1931
+ >>> def pandas_div(x) -> ps.DataFrame[("index", int), [("f1", float), ("f2", float)]]:
1932
+ ... return x[['B', 'C']] / x[['B', 'C']]
1933
+ >>> g.apply(pandas_div).sort_index() # doctest: +SKIP
1934
+ f1 f2
1935
+ index
1936
+ 0 1.0 1.0
1937
+ 1 1.0 1.0
1938
+ 2 1.0 1.0
1939
+
1940
+ In case of Series, it works as below.
1941
+
1942
+ >>> def plus_max(x) -> ps.Series[int]:
1943
+ ... return x + x.max()
1944
+ >>> df.B.groupby(df.A).apply(plus_max).sort_index() # doctest: +SKIP
1945
+ 0 6
1946
+ 1 3
1947
+ 2 4
1948
+ Name: B, dtype: int64
1949
+
1950
+ >>> def plus_min(x):
1951
+ ... return x + x.min()
1952
+ >>> df.B.groupby(df.A).apply(plus_min).sort_index() # doctest: +SKIP
1953
+ 0 2
1954
+ 1 3
1955
+ 2 6
1956
+ Name: B, dtype: int64
1957
+
1958
+ You can also return a scalar value as an aggregated value of the group:
1959
+
1960
+ >>> def plus_length(x) -> int:
1961
+ ... return len(x)
1962
+ >>> df.B.groupby(df.A).apply(plus_length).sort_index() # doctest: +SKIP
1963
+ 0 1
1964
+ 1 2
1965
+ Name: B, dtype: int64
1966
+
1967
+ The extra arguments to the function can be passed as below.
1968
+
1969
+ >>> def calculation(x, y, z) -> int:
1970
+ ... return len(x) + y * z
1971
+ >>> df.B.groupby(df.A).apply(calculation, 5, z=10).sort_index() # doctest: +SKIP
1972
+ 0 51
1973
+ 1 52
1974
+ Name: B, dtype: int64
1975
+ """
1976
+ if not callable(func):
1977
+ raise TypeError("%s object is not callable" % type(func).__name__)
1978
+
1979
+ spec = inspect.getfullargspec(func)
1980
+ return_sig = spec.annotations.get("return", None)
1981
+ should_infer_schema = return_sig is None
1982
+ should_retain_index = should_infer_schema
1983
+
1984
+ is_series_groupby = isinstance(self, SeriesGroupBy)
1985
+
1986
+ psdf = self._psdf
1987
+
1988
+ if self._agg_columns_selected:
1989
+ agg_columns = self._agg_columns
1990
+ else:
1991
+ agg_columns = [
1992
+ psdf._psser_for(label)
1993
+ for label in psdf._internal.column_labels
1994
+ if label not in self._column_labels_to_exclude
1995
+ ]
1996
+
1997
+ psdf, groupkey_labels, groupkey_names = GroupBy._prepare_group_map_apply(
1998
+ psdf, self._groupkeys, agg_columns
1999
+ )
2000
+
2001
+ if is_series_groupby:
2002
+ name = psdf.columns[-1]
2003
+ pandas_apply = _builtin_table.get(func, func)
2004
+ else:
2005
+ f = _builtin_table.get(func, func)
2006
+
2007
+ def pandas_apply(pdf: pd.DataFrame, *a: Any, **k: Any) -> Any:
2008
+ return f(pdf.drop(groupkey_names, axis=1), *a, **k)
2009
+
2010
+ should_return_series = False
2011
+
2012
+ if should_infer_schema:
2013
+ # Execute with the first `compute.shortcut_limit` + 1 rows to infer the return type.
2014
+ log_advice(
2015
+ "If the type hints is not specified for `groupby.apply`, "
2016
+ "it is expensive to infer the data type internally."
2017
+ )
2018
+ limit = get_option("compute.shortcut_limit")
2019
+ # Ensure at least 2 sampled rows so that apply's schema inference is accurate
2020
+ # See related: https://github.com/pandas-dev/pandas/issues/46893
2021
+ sample_limit = limit + 1 if limit else 2
2022
+ pdf = psdf.head(sample_limit)._to_internal_pandas()
2023
+ groupkeys = [
2024
+ pdf[groupkey_name].rename(psser.name)
2025
+ for groupkey_name, psser in zip(groupkey_names, self._groupkeys)
2026
+ ]
2027
+ grouped = pdf.groupby(groupkeys)
2028
+ if is_series_groupby:
2029
+ pser_or_pdf = grouped[name].apply(pandas_apply, *args, **kwargs)
2030
+ else:
2031
+ pser_or_pdf = grouped.apply(pandas_apply, *args, **kwargs)
2032
+ psser_or_psdf = ps.from_pandas(pser_or_pdf.infer_objects())
2033
+
2034
+ if len(pdf) <= limit:
2035
+ if isinstance(psser_or_psdf, ps.Series) and is_series_groupby:
2036
+ psser_or_psdf = psser_or_psdf.rename(cast(SeriesGroupBy, self)._psser.name)
2037
+ return cast(Union[Series, DataFrame], psser_or_psdf)
2038
+
2039
+ if len(grouped) <= 1:
2040
+ with warnings.catch_warnings():
2041
+ warnings.simplefilter("always")
2042
+ warnings.warn(
2043
+ "The amount of data for return type inference might not be large enough. "
2044
+ "Consider increasing an option `compute.shortcut_limit`."
2045
+ )
2046
+
2047
+ if isinstance(psser_or_psdf, Series):
2048
+ should_return_series = True
2049
+ psdf_from_pandas = psser_or_psdf._psdf
2050
+ else:
2051
+ psdf_from_pandas = cast(DataFrame, psser_or_psdf)
2052
+
2053
+ index_fields = [
2054
+ field.normalize_spark_type() for field in psdf_from_pandas._internal.index_fields
2055
+ ]
2056
+ data_fields = [
2057
+ field.normalize_spark_type() for field in psdf_from_pandas._internal.data_fields
2058
+ ]
2059
+ return_schema = StructType([field.struct_field for field in index_fields + data_fields])
2060
+ else:
2061
+ return_type = infer_return_type(func)
2062
+ if not is_series_groupby and isinstance(return_type, SeriesType):
2063
+ raise TypeError(
2064
+ "Series as a return type hint at frame groupby is not supported "
2065
+ "currently; however got [%s]. Use DataFrame type hint instead." % return_sig
2066
+ )
2067
+
2068
+ if isinstance(return_type, DataFrameType):
2069
+ data_fields = return_type.data_fields
2070
+ return_schema = return_type.spark_type
2071
+ index_fields = return_type.index_fields
2072
+ should_retain_index = len(index_fields) > 0
2073
+ psdf_from_pandas = None
2074
+ else:
2075
+ should_return_series = True
2076
+ dtype = cast(Union[SeriesType, ScalarType], return_type).dtype
2077
+ spark_type = cast(Union[SeriesType, ScalarType], return_type).spark_type
2078
+ if is_series_groupby:
2079
+ data_fields = [
2080
+ InternalField(
2081
+ dtype=dtype, struct_field=StructField(name=name, dataType=spark_type)
2082
+ )
2083
+ ]
2084
+ else:
2085
+ data_fields = [
2086
+ InternalField(
2087
+ dtype=dtype,
2088
+ struct_field=StructField(
2089
+ name=SPARK_DEFAULT_SERIES_NAME, dataType=spark_type
2090
+ ),
2091
+ )
2092
+ ]
2093
+ return_schema = StructType([field.struct_field for field in data_fields])
2094
+
2095
+ def pandas_groupby_apply(pdf: pd.DataFrame) -> pd.DataFrame:
2096
+
2097
+ if is_series_groupby:
2098
+ pdf_or_ser = pdf.groupby(groupkey_names)[name].apply(pandas_apply, *args, **kwargs)
2099
+ else:
2100
+ pdf_or_ser = pdf.groupby(groupkey_names).apply(pandas_apply, *args, **kwargs)
2101
+ if should_return_series and isinstance(pdf_or_ser, pd.DataFrame):
2102
+ pdf_or_ser = pdf_or_ser.stack()
2103
+
2104
+ if not isinstance(pdf_or_ser, pd.DataFrame):
2105
+ return pd.DataFrame(pdf_or_ser)
2106
+ else:
2107
+ return pdf_or_ser
2108
+
2109
+ sdf = GroupBy._spark_group_map_apply(
2110
+ psdf,
2111
+ pandas_groupby_apply,
2112
+ [psdf._internal.spark_column_for(label) for label in groupkey_labels],
2113
+ return_schema,
2114
+ retain_index=should_retain_index,
2115
+ )
2116
+
2117
+ if should_retain_index:
2118
+ # If schema is inferred, we can restore indexes too.
2119
+ if psdf_from_pandas is not None:
2120
+ internal = psdf_from_pandas._internal.with_new_sdf(
2121
+ spark_frame=sdf, index_fields=index_fields, data_fields=data_fields
2122
+ )
2123
+ else:
2124
+ index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
2125
+
2126
+ index_spark_columns = [
2127
+ scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
2128
+ ]
2129
+
2130
+ if not any(
2131
+ [
2132
+ SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
2133
+ for index_field in index_fields
2134
+ ]
2135
+ ):
2136
+ index_names = [(index_field.struct_field.name,) for index_field in index_fields]
2137
+ internal = InternalFrame(
2138
+ spark_frame=sdf,
2139
+ index_names=index_names,
2140
+ index_spark_columns=index_spark_columns,
2141
+ index_fields=index_fields,
2142
+ data_fields=data_fields,
2143
+ )
2144
+ else:
2145
+ # Otherwise, it loses index.
2146
+ internal = InternalFrame(
2147
+ spark_frame=sdf, index_spark_columns=None, data_fields=data_fields
2148
+ )
2149
+
2150
+ if should_return_series:
2151
+ psser = first_series(DataFrame(internal))
2152
+ if is_series_groupby:
2153
+ psser = psser.rename(cast(SeriesGroupBy, self)._psser.name)
2154
+ return psser
2155
+ else:
2156
+ return DataFrame(internal)
2157
+
2158
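The schema-inference path above samples `compute.shortcut_limit` + 1 rows before falling back to a full distributed apply. A small sketch of tuning that option (editorial; the value 500 is arbitrary):

import pyspark.pandas as ps

ps.set_option("compute.shortcut_limit", 500)   # sample more/fewer rows for inference
print(ps.get_option("compute.shortcut_limit"))
ps.reset_option("compute.shortcut_limit")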
+ # TODO: implement 'dropna' parameter
2159
+ def filter(self, func: Callable[[FrameLike], FrameLike]) -> FrameLike:
2160
+ """
2161
+ Return a copy of a DataFrame excluding elements from groups that
2162
+ do not satisfy the boolean criterion specified by func.
2163
+
2164
+ Parameters
2165
+ ----------
2166
+ func : function
2167
+ Function to apply to each subframe. Should return True or False.
2168
+ dropna : Drop groups that do not pass the filter. True by default;
2169
+ if False, groups that evaluate False are filled with NaNs.
2170
+
2171
+ Returns
2172
+ -------
2173
+ filtered : DataFrame or Series
2174
+
2175
+ Notes
2176
+ -----
2177
+ Each subframe is endowed the attribute 'name' in case you need to know
2178
+ which group you are working on.
2179
+
2180
+ Examples
2181
+ --------
2182
+ >>> df = ps.DataFrame({'A' : ['foo', 'bar', 'foo', 'bar',
2183
+ ... 'foo', 'bar'],
2184
+ ... 'B' : [1, 2, 3, 4, 5, 6],
2185
+ ... 'C' : [2.0, 5., 8., 1., 2., 9.]}, columns=['A', 'B', 'C'])
2186
+ >>> grouped = df.groupby('A')
2187
+ >>> grouped.filter(lambda x: x['B'].mean() > 3.)
2188
+ A B C
2189
+ 1 bar 2 5.0
2190
+ 3 bar 4 1.0
2191
+ 5 bar 6 9.0
2192
+
2193
+ >>> df.B.groupby(df.A).filter(lambda x: x.mean() > 3.)
2194
+ 1 2
2195
+ 3 4
2196
+ 5 6
2197
+ Name: B, dtype: int64
2198
+ """
2199
+ if not callable(func):
2200
+ raise TypeError("%s object is not callable" % type(func).__name__)
2201
+
2202
+ is_series_groupby = isinstance(self, SeriesGroupBy)
2203
+
2204
+ psdf = self._psdf
2205
+
2206
+ if self._agg_columns_selected:
2207
+ agg_columns = self._agg_columns
2208
+ else:
2209
+ agg_columns = [
2210
+ psdf._psser_for(label)
2211
+ for label in psdf._internal.column_labels
2212
+ if label not in self._column_labels_to_exclude
2213
+ ]
2214
+
2215
+ data_schema = (
2216
+ psdf[agg_columns]._internal.resolved_copy.spark_frame.drop(*HIDDEN_COLUMNS).schema
2217
+ )
2218
+
2219
+ psdf, groupkey_labels, groupkey_names = GroupBy._prepare_group_map_apply(
2220
+ psdf, self._groupkeys, agg_columns
2221
+ )
2222
+
2223
+ if is_series_groupby:
2224
+
2225
+ def pandas_filter(pdf: pd.DataFrame) -> pd.DataFrame:
2226
+ return pd.DataFrame(pdf.groupby(groupkey_names)[pdf.columns[-1]].filter(func))
2227
+
2228
+ else:
2229
+ f = _builtin_table.get(func, func)
2230
+
2231
+ def wrapped_func(pdf: pd.DataFrame) -> pd.DataFrame:
2232
+ return f(pdf.drop(groupkey_names, axis=1))
2233
+
2234
+ def pandas_filter(pdf: pd.DataFrame) -> pd.DataFrame:
2235
+ return pdf.groupby(groupkey_names).filter(wrapped_func).drop(groupkey_names, axis=1)
2236
+
2237
+ sdf = GroupBy._spark_group_map_apply(
2238
+ psdf,
2239
+ pandas_filter,
2240
+ [psdf._internal.spark_column_for(label) for label in groupkey_labels],
2241
+ data_schema,
2242
+ retain_index=True,
2243
+ )
2244
+
2245
+ psdf = DataFrame(self._psdf[agg_columns]._internal.with_new_sdf(sdf))
2246
+ if is_series_groupby:
2247
+ return cast(FrameLike, first_series(psdf))
2248
+ else:
2249
+ return cast(FrameLike, psdf)
2250
+
2251
+ @staticmethod
2252
+ def _prepare_group_map_apply(
2253
+ psdf: DataFrame, groupkeys: List[Series], agg_columns: List[Series]
2254
+ ) -> Tuple[DataFrame, List[Label], List[str]]:
2255
+ groupkey_labels: List[Label] = [
2256
+ verify_temp_column_name(psdf, "__groupkey_{}__".format(i))
2257
+ for i in range(len(groupkeys))
2258
+ ]
2259
+ psdf = psdf[[s.rename(label) for s, label in zip(groupkeys, groupkey_labels)] + agg_columns]
2260
+ groupkey_names = [label if len(label) > 1 else label[0] for label in groupkey_labels]
2261
+ return DataFrame(psdf._internal.resolved_copy), groupkey_labels, groupkey_names
2262
+
2263
+ @staticmethod
2264
+ def _spark_group_map_apply(
2265
+ psdf: DataFrame,
2266
+ func: Callable[[pd.DataFrame], pd.DataFrame],
2267
+ groupkeys_scols: List[Column],
2268
+ return_schema: StructType,
2269
+ retain_index: bool,
2270
+ ) -> SparkDataFrame:
2271
+ output_func = GroupBy._make_pandas_df_builder_func(psdf, func, return_schema, retain_index)
2272
+ sdf = psdf._internal.spark_frame.drop(*HIDDEN_COLUMNS)
2273
+ return sdf.groupby(*groupkeys_scols).applyInPandas(output_func, return_schema)
2274
+
2275
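`_spark_group_map_apply` above is a thin wrapper over Spark's grouped applyInPandas. A standalone sketch of that underlying call (editorial; the schema string, data, and `demean` function are assumptions):

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ["key", "val"])

def demean(pdf: pd.DataFrame) -> pd.DataFrame:
    # invoked once per group with a plain pandas DataFrame
    pdf["val"] = pdf["val"] - pdf["val"].mean()
    return pdf

sdf.groupby("key").applyInPandas(demean, schema="key long, val double").show()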
+ @staticmethod
2276
+ def _make_pandas_df_builder_func(
2277
+ psdf: DataFrame,
2278
+ func: Callable[[pd.DataFrame], pd.DataFrame],
2279
+ return_schema: StructType,
2280
+ retain_index: bool,
2281
+ ) -> Callable[[pd.DataFrame], pd.DataFrame]:
2282
+ """
2283
+ Creates a function that can be used inside the pandas UDF. This function can construct
2284
+ the same pandas DataFrame as if the pandas-on-Spark DataFrame were collected to the driver side.
2285
+ The index, column labels, etc. are re-constructed within the function.
2286
+ """
2287
+ from pyspark.sql.utils import is_timestamp_ntz_preferred
2288
+
2289
+ arguments_for_restore_index = psdf._internal.arguments_for_restore_index
2290
+ prefer_timestamp_ntz = is_timestamp_ntz_preferred()
2291
+
2292
+ def rename_output(pdf: pd.DataFrame) -> pd.DataFrame:
2293
+ pdf = InternalFrame.restore_index(pdf.copy(), **arguments_for_restore_index)
2294
+
2295
+ pdf = func(pdf)
2296
+
2297
+ # If schema should be inferred, we don't restore the index. pandas seems to restore
2298
+ # the index in some cases.
2299
+ # When the Spark output type is specified, we cannot know, without executing it,
2300
+ # whether we should restore the index or not. For instance, see the example in
2301
+ # https://github.com/databricks/koalas/issues/628.
2302
+ pdf, _, _, _, _ = InternalFrame.prepare_pandas_frame(
2303
+ pdf, retain_index=retain_index, prefer_timestamp_ntz=prefer_timestamp_ntz
2304
+ )
2305
+
2306
+ # Just positionally map the column names to the given schema's names.
2307
+ pdf.columns = return_schema.names
2308
+
2309
+ return pdf
2310
+
2311
+ return rename_output
2312
+
2313
+ def rank(self, method: str = "average", ascending: bool = True) -> FrameLike:
2314
+ """
2315
+ Provide the rank of values within each group.
2316
+
2317
+ Parameters
2318
+ ----------
2319
+ method : {'average', 'min', 'max', 'first', 'dense'}, default 'average'
2320
+ * average: average rank of group
2321
+ * min: lowest rank in group
2322
+ * max: highest rank in group
2323
+ * first: ranks assigned in order they appear in the array
2324
+ * dense: like 'min', but rank always increases by 1 between groups
2325
+ ascending : boolean, default True
2326
+ False for ranks by high (1) to low (N)
2327
+
2328
+ Returns
2329
+ -------
2330
+ DataFrame with ranking of values within each group
2331
+
2332
+ Examples
2333
+ --------
2334
+
2335
+ >>> df = ps.DataFrame({
2336
+ ... 'a': [1, 1, 1, 2, 2, 2, 3, 3, 3],
2337
+ ... 'b': [1, 2, 2, 2, 3, 3, 3, 4, 4]}, columns=['a', 'b'])
2338
+ >>> df
2339
+ a b
2340
+ 0 1 1
2341
+ 1 1 2
2342
+ 2 1 2
2343
+ 3 2 2
2344
+ 4 2 3
2345
+ 5 2 3
2346
+ 6 3 3
2347
+ 7 3 4
2348
+ 8 3 4
2349
+
2350
+ >>> df.groupby("a").rank().sort_index()
2351
+ b
2352
+ 0 1.0
2353
+ 1 2.5
2354
+ 2 2.5
2355
+ 3 1.0
2356
+ 4 2.5
2357
+ 5 2.5
2358
+ 6 1.0
2359
+ 7 2.5
2360
+ 8 2.5
2361
+
2362
+ >>> df.b.groupby(df.a).rank(method='max').sort_index()
2363
+ 0 1.0
2364
+ 1 3.0
2365
+ 2 3.0
2366
+ 3 1.0
2367
+ 4 3.0
2368
+ 5 3.0
2369
+ 6 1.0
2370
+ 7 3.0
2371
+ 8 3.0
2372
+ Name: b, dtype: float64
2373
+
2374
+ """
2375
+ return self._apply_series_op(
2376
+ lambda sg: sg._psser._rank(method, ascending, part_cols=sg._groupkeys_scols),
2377
+ should_resolve=True,
2378
+ )
2379
+
2380
+ # TODO: add axis parameter
2381
+ def idxmax(self, skipna: bool = True) -> FrameLike:
2382
+ """
2383
+ Return index of first occurrence of maximum over requested axis in group.
2384
+ NA/null values are excluded.
2385
+
2386
+ Parameters
2387
+ ----------
2388
+ skipna : boolean, default True
2389
+ Exclude NA/null values. If an entire row/column is NA, the result will be NA.
2390
+
2391
+ See Also
2392
+ --------
2393
+ Series.idxmax
2394
+ DataFrame.idxmax
2395
+ pyspark.pandas.Series.groupby
2396
+ pyspark.pandas.DataFrame.groupby
2397
+
2398
+ Examples
2399
+ --------
2400
+ >>> df = ps.DataFrame({'a': [1, 1, 2, 2, 3],
2401
+ ... 'b': [1, 2, 3, 4, 5],
2402
+ ... 'c': [5, 4, 3, 2, 1]}, columns=['a', 'b', 'c'])
2403
+
2404
+ >>> df.groupby(['a'])['b'].idxmax().sort_index() # doctest: +NORMALIZE_WHITESPACE
2405
+ a
2406
+ 1 1
2407
+ 2 3
2408
+ 3 4
2409
+ Name: b, dtype: int64
2410
+
2411
+ >>> df.groupby(['a']).idxmax().sort_index() # doctest: +NORMALIZE_WHITESPACE
2412
+ b c
2413
+ a
2414
+ 1 1 0
2415
+ 2 3 2
2416
+ 3 4 4
2417
+ """
2418
+ if self._psdf._internal.index_level != 1:
2419
+ raise ValueError("idxmax only support one-level index now")
2420
+
2421
+ groupkey_names = ["__groupkey_{}__".format(i) for i in range(len(self._groupkeys))]
2422
+
2423
+ sdf = self._psdf._internal.spark_frame
2424
+ for s, name in zip(self._groupkeys, groupkey_names):
2425
+ sdf = sdf.withColumn(name, s.spark.column)
2426
+ index = self._psdf._internal.index_spark_column_names[0]
2427
+
2428
+ stat_exprs = []
2429
+ for psser, scol in zip(self._agg_columns, self._agg_columns_scols):
2430
+ name = psser._internal.data_spark_column_names[0]
2431
+
2432
+ if skipna:
2433
+ order_column = scol.desc_nulls_last()
2434
+ else:
2435
+ order_column = scol.desc_nulls_first()
2436
+
2437
+ window = Window.partitionBy(*groupkey_names).orderBy(
2438
+ order_column, NATURAL_ORDER_COLUMN_NAME
2439
+ )
2440
+ sdf = sdf.withColumn(
2441
+ name, F.when(F.row_number().over(window) == 1, scol_for(sdf, index)).otherwise(None)
2442
+ )
2443
+ stat_exprs.append(F.max(scol_for(sdf, name)).alias(name))
2444
+
2445
+ sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
2446
+
2447
+ internal = InternalFrame(
2448
+ spark_frame=sdf,
2449
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
2450
+ index_names=[psser._column_label for psser in self._groupkeys],
2451
+ index_fields=[
2452
+ psser._internal.data_fields[0].copy(name=name)
2453
+ for psser, name in zip(self._groupkeys, groupkey_names)
2454
+ ],
2455
+ column_labels=[psser._column_label for psser in self._agg_columns],
2456
+ data_spark_columns=[
2457
+ scol_for(sdf, psser._internal.data_spark_column_names[0])
2458
+ for psser in self._agg_columns
2459
+ ],
2460
+ )
2461
+ return self._handle_output(DataFrame(internal))
2462
+
2463
+ # TODO: add axis parameter
2464
+ def idxmin(self, skipna: bool = True) -> FrameLike:
2465
+ """
2466
+ Return index of first occurrence of minimum over requested axis in group.
2467
+ NA/null values are excluded.
2468
+
2469
+ Parameters
2470
+ ----------
2471
+ skipna : boolean, default True
2472
+ Exclude NA/null values. If an entire row/column is NA, the result will be NA.
2473
+
2474
+ See Also
2475
+ --------
2476
+ Series.idxmin
2477
+ DataFrame.idxmin
2478
+ pyspark.pandas.Series.groupby
2479
+ pyspark.pandas.DataFrame.groupby
2480
+
2481
+ Examples
2482
+ --------
2483
+ >>> df = ps.DataFrame({'a': [1, 1, 2, 2, 3],
2484
+ ... 'b': [1, 2, 3, 4, 5],
2485
+ ... 'c': [5, 4, 3, 2, 1]}, columns=['a', 'b', 'c'])
2486
+
2487
+ >>> df.groupby(['a'])['b'].idxmin().sort_index() # doctest: +NORMALIZE_WHITESPACE
2488
+ a
2489
+ 1 0
2490
+ 2 2
2491
+ 3 4
2492
+ Name: b, dtype: int64
2493
+
2494
+ >>> df.groupby(['a']).idxmin().sort_index() # doctest: +NORMALIZE_WHITESPACE
2495
+ b c
2496
+ a
2497
+ 1 0 1
2498
+ 2 2 3
2499
+ 3 4 4
2500
+ """
2501
+ if self._psdf._internal.index_level != 1:
2502
+ raise ValueError("idxmin only support one-level index now")
2503
+
2504
+ groupkey_names = ["__groupkey_{}__".format(i) for i in range(len(self._groupkeys))]
2505
+
2506
+ sdf = self._psdf._internal.spark_frame
2507
+ for s, name in zip(self._groupkeys, groupkey_names):
2508
+ sdf = sdf.withColumn(name, s.spark.column)
2509
+ index = self._psdf._internal.index_spark_column_names[0]
2510
+
2511
+ stat_exprs = []
2512
+ for psser, scol in zip(self._agg_columns, self._agg_columns_scols):
2513
+ name = psser._internal.data_spark_column_names[0]
2514
+
2515
+ if skipna:
2516
+ order_column = scol.asc_nulls_last()
2517
+ else:
2518
+ order_column = scol.asc_nulls_first()
2519
+
2520
+ window = Window.partitionBy(*groupkey_names).orderBy(
2521
+ order_column, NATURAL_ORDER_COLUMN_NAME
2522
+ )
2523
+ sdf = sdf.withColumn(
2524
+ name, F.when(F.row_number().over(window) == 1, scol_for(sdf, index)).otherwise(None)
2525
+ )
2526
+ stat_exprs.append(F.max(scol_for(sdf, name)).alias(name))
2527
+
2528
+ sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
2529
+
2530
+ internal = InternalFrame(
2531
+ spark_frame=sdf,
2532
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
2533
+ index_names=[psser._column_label for psser in self._groupkeys],
2534
+ index_fields=[
2535
+ psser._internal.data_fields[0].copy(name=name)
2536
+ for psser, name in zip(self._groupkeys, groupkey_names)
2537
+ ],
2538
+ column_labels=[psser._column_label for psser in self._agg_columns],
2539
+ data_spark_columns=[
2540
+ scol_for(sdf, psser._internal.data_spark_column_names[0])
2541
+ for psser in self._agg_columns
2542
+ ],
2543
+ )
2544
+ return self._handle_output(DataFrame(internal))
2545
+
2546
+ def fillna(
2547
+ self,
2548
+ value: Optional[Any] = None,
2549
+ method: Optional[str] = None,
2550
+ axis: Optional[Axis] = None,
2551
+ inplace: bool = False,
2552
+ limit: Optional[int] = None,
2553
+ ) -> FrameLike:
2554
+ """Fill NA/NaN values in group.
2555
+
2556
+ Parameters
2557
+ ----------
2558
+ value : scalar, dict, Series
2559
+ Value to use to fill holes. Alternately a dict/Series of values
2560
+ specifying which value to use for each column.
2561
+ DataFrame is not supported.
2562
+ method : {'backfill', 'bfill', 'pad', 'ffill', None}, default None
2563
+ Method to use for filling holes in reindexed Series.
2564
+ pad / ffill: propagate the last valid observation forward to the next valid one.
2565
+ backfill / bfill: use the NEXT valid observation to fill the gap.
2566
+ axis : {0 or `index`}
2567
+ 1 and `columns` are not supported.
2568
+ inplace : boolean, default False
2569
+ Fill in place (do not create a new object)
2570
+ limit : int, default None
2571
+ If method is specified, this is the maximum number of consecutive NaN values to
2572
+ forward/backward fill. In other words, if there is a gap with more than this number of
2573
+ consecutive NaNs, it will only be partially filled. If method is not specified,
2574
+ this is the maximum number of entries along the entire axis where NaNs will be filled.
2575
+ Must be greater than 0 if not None
2576
+
2577
+ Returns
2578
+ -------
2579
+ DataFrame
2580
+ DataFrame with NA entries filled.
2581
+
2582
+ Examples
2583
+ --------
2584
+ >>> df = ps.DataFrame({
2585
+ ... 'A': [1, 1, 2, 2],
2586
+ ... 'B': [2, 4, None, 3],
2587
+ ... 'C': [None, None, None, 1],
2588
+ ... 'D': [0, 1, 5, 4]
2589
+ ... },
2590
+ ... columns=['A', 'B', 'C', 'D'])
2591
+ >>> df
2592
+ A B C D
2593
+ 0 1 2.0 NaN 0
2594
+ 1 1 4.0 NaN 1
2595
+ 2 2 NaN NaN 5
2596
+ 3 2 3.0 1.0 4
2597
+
2598
+ We can also propagate non-null values forward or backward in group.
2599
+
2600
+ >>> df.groupby(['A'])['B'].fillna(method='ffill').sort_index()
2601
+ 0 2.0
2602
+ 1 4.0
2603
+ 2 NaN
2604
+ 3 3.0
2605
+ Name: B, dtype: float64
2606
+
2607
+ >>> df.groupby(['A']).fillna(method='bfill').sort_index()
2608
+ B C D
2609
+ 0 2.0 NaN 0
2610
+ 1 4.0 NaN 1
2611
+ 2 3.0 1.0 5
2612
+ 3 3.0 1.0 4
2613
+ """
2614
+ return self._apply_series_op(
2615
+ lambda sg: sg._psser._fillna(
2616
+ value=value, method=method, axis=axis, limit=limit, part_cols=sg._groupkeys_scols
2617
+ ),
2618
+ should_resolve=(method is not None),
2619
+ )
2620
+
2621
+ def bfill(self, limit: Optional[int] = None) -> FrameLike:
2622
+ """
2623
+ Synonym for `DataFrame.fillna()` with ``method=`bfill```.
2624
+
2625
+ Parameters
2626
+ ----------
2627
+ axis : {0 or `index`}
2628
+ 1 and `columns` are not supported.
2629
+ inplace : boolean, default False
2630
+ Fill in place (do not create a new object)
2631
+ limit : int, default None
2632
+ If method is specified, this is the maximum number of consecutive NaN values to
2633
+ forward/backward fill. In other words, if there is a gap with more than this number of
2634
+ consecutive NaNs, it will only be partially filled. If method is not specified,
2635
+ this is the maximum number of entries along the entire axis where NaNs will be filled.
2636
+ Must be greater than 0 if not None
2637
+
2638
+ Returns
2639
+ -------
2640
+ DataFrame
2641
+ DataFrame with NA entries filled.
2642
+
2643
+ Examples
2644
+ --------
2645
+ >>> df = ps.DataFrame({
2646
+ ... 'A': [1, 1, 2, 2],
2647
+ ... 'B': [2, 4, None, 3],
2648
+ ... 'C': [None, None, None, 1],
2649
+ ... 'D': [0, 1, 5, 4]
2650
+ ... },
2651
+ ... columns=['A', 'B', 'C', 'D'])
2652
+ >>> df
2653
+ A B C D
2654
+ 0 1 2.0 NaN 0
2655
+ 1 1 4.0 NaN 1
2656
+ 2 2 NaN NaN 5
2657
+ 3 2 3.0 1.0 4
2658
+
2659
+ Propagate non-null values backward.
2660
+
2661
+ >>> df.groupby(['A']).bfill().sort_index()
2662
+ B C D
2663
+ 0 2.0 NaN 0
2664
+ 1 4.0 NaN 1
2665
+ 2 3.0 1.0 5
2666
+ 3 3.0 1.0 4
2667
+ """
2668
+ return self.fillna(method="bfill", limit=limit)
2669
+
2670
+ def backfill(self, limit: Optional[int] = None) -> FrameLike:
2671
+ """
2672
+ Alias for bfill.
2673
+
2674
+ .. deprecated:: 3.4.0
2675
+ """
2676
+ warnings.warn(
2677
+ "The GroupBy.backfill method is deprecated "
2678
+ "and will be removed in a future version. "
2679
+ "Use GroupBy.bfill instead.",
2680
+ FutureWarning,
2681
+ )
2682
+ return self.bfill(limit=limit)
2683
+
2684
+ def ffill(self, limit: Optional[int] = None) -> FrameLike:
2685
+ """
2686
+ Synonym for `DataFrame.fillna()` with ``method=`ffill```.
2687
+
2688
+ Parameters
2689
+ ----------
2690
+ axis : {0 or `index`}
2691
+ 1 and `columns` are not supported.
2692
+ inplace : boolean, default False
2693
+ Fill in place (do not create a new object)
2694
+ limit : int, default None
2695
+ If method is specified, this is the maximum number of consecutive NaN values to
2696
+ forward/backward fill. In other words, if there is a gap with more than this number of
2697
+ consecutive NaNs, it will only be partially filled. If method is not specified,
2698
+ this is the maximum number of entries along the entire axis where NaNs will be filled.
2699
+ Must be greater than 0 if not None
2700
+
2701
+ Returns
2702
+ -------
2703
+ DataFrame
2704
+ DataFrame with NA entries filled.
2705
+
2706
+ Examples
2707
+ --------
2708
+ >>> df = ps.DataFrame({
2709
+ ... 'A': [1, 1, 2, 2],
2710
+ ... 'B': [2, 4, None, 3],
2711
+ ... 'C': [None, None, None, 1],
2712
+ ... 'D': [0, 1, 5, 4]
2713
+ ... },
2714
+ ... columns=['A', 'B', 'C', 'D'])
2715
+ >>> df
2716
+ A B C D
2717
+ 0 1 2.0 NaN 0
2718
+ 1 1 4.0 NaN 1
2719
+ 2 2 NaN NaN 5
2720
+ 3 2 3.0 1.0 4
2721
+
2722
+ Propagate non-null values forward.
2723
+
2724
+ >>> df.groupby(['A']).ffill().sort_index()
2725
+ B C D
2726
+ 0 2.0 NaN 0
2727
+ 1 4.0 NaN 1
2728
+ 2 NaN NaN 5
2729
+ 3 3.0 1.0 4
2730
+ """
2731
+ return self.fillna(method="ffill", limit=limit)
2732
+
2733
+ def pad(self, limit: Optional[int] = None) -> FrameLike:
2734
+ """
2735
+ Alias for ffill.
2736
+
2737
+ .. deprecated:: 3.4.0
2738
+ """
2739
+ warnings.warn(
2740
+ "The GroupBy.pad method is deprecated "
2741
+ "and will be removed in a future version. "
2742
+ "Use GroupBy.ffill instead.",
2743
+ FutureWarning,
2744
+ )
2745
+ return self.ffill(limit=limit)
2746
+
2747
+ def _limit(self, n: int, asc: bool) -> FrameLike:
2748
+ """
2749
+ Private function for tail and head.
2750
+ """
2751
+ psdf = self._psdf
2752
+
2753
+ if self._agg_columns_selected:
2754
+ agg_columns = self._agg_columns
2755
+ else:
2756
+ agg_columns = [
2757
+ psdf._psser_for(label)
2758
+ for label in psdf._internal.column_labels
2759
+ if label not in self._column_labels_to_exclude
2760
+ ]
2761
+
2762
+ psdf, groupkey_labels, _ = GroupBy._prepare_group_map_apply(
2763
+ psdf,
2764
+ self._groupkeys,
2765
+ agg_columns,
2766
+ )
2767
+
2768
+ groupkey_scols = [psdf._internal.spark_column_for(label) for label in groupkey_labels]
2769
+
2770
+ sdf = psdf._internal.spark_frame
2771
+
2772
+ window = Window.partitionBy(*groupkey_scols)
2773
+ # This part is handled differently depending on whether it is a tail or a head.
2774
+ ordered_window = (
2775
+ window.orderBy(F.col(NATURAL_ORDER_COLUMN_NAME).asc())
2776
+ if asc
2777
+ else window.orderBy(F.col(NATURAL_ORDER_COLUMN_NAME).desc())
2778
+ )
2779
+
2780
+ if n >= 0 or LooseVersion(pd.__version__) < LooseVersion("1.4.0"):
2781
+ tmp_row_num_col = verify_temp_column_name(sdf, "__row_number__")
2782
+ sdf = (
2783
+ sdf.withColumn(tmp_row_num_col, F.row_number().over(ordered_window))
2784
+ .filter(F.col(tmp_row_num_col) <= n)
2785
+ .drop(tmp_row_num_col)
2786
+ )
2787
+ else:
2788
+ # Pandas supports Groupby positional indexing since v1.4.0
2789
+ # https://pandas.pydata.org/docs/whatsnew/v1.4.0.html#groupby-positional-indexing
2790
+ #
2791
+ # To support groupby positional indexing, we need to add a `__tmp_lag__` column to help
2792
+ # us filter rows before the specified offset row.
2793
+ #
2794
+ # For example for the dataframe:
2795
+ # >>> df = ps.DataFrame([["g", "g0"],
2796
+ # ... ["g", "g1"],
2797
+ # ... ["g", "g2"],
2798
+ # ... ["g", "g3"],
2799
+ # ... ["h", "h0"],
2800
+ # ... ["h", "h1"]], columns=["A", "B"])
2801
+ # >>> df.groupby("A").head(-1)
2802
+ #
2803
+ # Below is the result showing the `__tmp_lag__` column for the above df. With limit n =
2804
+ # `-1`, `__tmp_lag__` is set to `0` for rows[:-1], and the remaining rows are set to
2805
+ # `null`:
2806
+ #
2807
+ # >>> sdf.withColumn(tmp_lag_col, F.lag(F.lit(0), -1).over(ordered_window))
2808
+ # +-----------------+--------------+---+---+-----------------+-----------+
2809
+ # |__index_level_0__|__groupkey_0__| A| B|__natural_order__|__tmp_lag__|
2810
+ # +-----------------+--------------+---+---+-----------------+-----------+
2811
+ # | 0| g| g| g0| 0| 0|
2812
+ # | 1| g| g| g1| 8589934592| 0|
2813
+ # | 2| g| g| g2| 17179869184| 0|
2814
+ # | 3| g| g| g3| 25769803776| null|
2815
+ # | 4| h| h| h0| 34359738368| 0|
2816
+ # | 5| h| h| h1| 42949672960| null|
2817
+ # +-----------------+--------------+---+---+-----------------+-----------+
2818
+ #
2819
+ tmp_lag_col = verify_temp_column_name(sdf, "__tmp_lag__")
2820
+ sdf = (
2821
+ sdf.withColumn(tmp_lag_col, F.lag(F.lit(0), n).over(ordered_window))
2822
+ .where(~F.isnull(F.col(tmp_lag_col)))
2823
+ .drop(tmp_lag_col)
2824
+ )
2825
+
2826
+ internal = psdf._internal.with_new_sdf(sdf)
2827
+ return self._handle_output(DataFrame(internal).drop(groupkey_labels, axis=1))
2828
+
2829
+ def head(self, n: int = 5) -> FrameLike:
2830
+ """
2831
+ Return first n rows of each group.
2832
+
2833
+ Returns
2834
+ -------
2835
+ DataFrame or Series
2836
+
2837
+ Examples
2838
+ --------
2839
+ >>> df = ps.DataFrame({'a': [1, 1, 1, 1, 2, 2, 2, 3, 3, 3],
2840
+ ... 'b': [2, 3, 1, 4, 6, 9, 8, 10, 7, 5],
2841
+ ... 'c': [3, 5, 2, 5, 1, 2, 6, 4, 3, 6]},
2842
+ ... columns=['a', 'b', 'c'],
2843
+ ... index=[7, 2, 4, 1, 3, 4, 9, 10, 5, 6])
2844
+ >>> df
2845
+ a b c
2846
+ 7 1 2 3
2847
+ 2 1 3 5
2848
+ 4 1 1 2
2849
+ 1 1 4 5
2850
+ 3 2 6 1
2851
+ 4 2 9 2
2852
+ 9 2 8 6
2853
+ 10 3 10 4
2854
+ 5 3 7 3
2855
+ 6 3 5 6
2856
+
2857
+ >>> df.groupby('a').head(2).sort_index()
2858
+ a b c
2859
+ 2 1 3 5
2860
+ 3 2 6 1
2861
+ 4 2 9 2
2862
+ 5 3 7 3
2863
+ 7 1 2 3
2864
+ 10 3 10 4
2865
+
2866
+ >>> df.groupby('a')['b'].head(2).sort_index()
2867
+ 2 3
2868
+ 3 6
2869
+ 4 9
2870
+ 5 7
2871
+ 7 2
2872
+ 10 10
2873
+ Name: b, dtype: int64
2874
+
2875
+ Supports GroupBy positional indexing since pandas-on-Spark 3.4 (with pandas 1.4+):
2876
+
2877
+ >>> df = ps.DataFrame([["g", "g0"],
2878
+ ... ["g", "g1"],
2879
+ ... ["g", "g2"],
2880
+ ... ["g", "g3"],
2881
+ ... ["h", "h0"],
2882
+ ... ["h", "h1"]], columns=["A", "B"])
2883
+ >>> df.groupby("A").head(-1) # doctest: +SKIP
2884
+ A B
2885
+ 0 g g0
2886
+ 1 g g1
2887
+ 2 g g2
2888
+ 4 h h0
2889
+ """
2890
+ return self._limit(n, asc=True)
2891
+
2892
+ def tail(self, n: int = 5) -> FrameLike:
2893
+ """
2894
+ Return last n rows of each group.
2895
+
2896
+ Similar to `.apply(lambda x: x.tail(n))`, but it returns a subset of rows from
2897
+ the original DataFrame with original index and order preserved (`as_index` flag is ignored).
2898
+
2899
+ Negative values of n are supported only with pandas 1.4 or later (see the example below).
2900
+
2901
+ Returns
2902
+ -------
2903
+ DataFrame or Series
2904
+
2905
+ Examples
2906
+ --------
2907
+ >>> df = ps.DataFrame({'a': [1, 1, 1, 1, 2, 2, 2, 3, 3, 3],
2908
+ ... 'b': [2, 3, 1, 4, 6, 9, 8, 10, 7, 5],
2909
+ ... 'c': [3, 5, 2, 5, 1, 2, 6, 4, 3, 6]},
2910
+ ... columns=['a', 'b', 'c'],
2911
+ ... index=[7, 2, 3, 1, 3, 4, 9, 10, 5, 6])
2912
+ >>> df
2913
+ a b c
2914
+ 7 1 2 3
2915
+ 2 1 3 5
2916
+ 3 1 1 2
2917
+ 1 1 4 5
2918
+ 3 2 6 1
2919
+ 4 2 9 2
2920
+ 9 2 8 6
2921
+ 10 3 10 4
2922
+ 5 3 7 3
2923
+ 6 3 5 6
2924
+
2925
+ >>> df.groupby('a').tail(2).sort_index()
2926
+ a b c
2927
+ 1 1 4 5
2928
+ 3 1 1 2
2929
+ 4 2 9 2
2930
+ 5 3 7 3
2931
+ 6 3 5 6
2932
+ 9 2 8 6
2933
+
2934
+ >>> df.groupby('a')['b'].tail(2).sort_index()
2935
+ 1 4
2936
+ 3 1
2937
+ 4 9
2938
+ 5 7
2939
+ 6 5
2940
+ 9 8
2941
+ Name: b, dtype: int64
2942
+
2943
+ Supports groupby positional indexing since pandas-on-Spark 3.4 (with pandas 1.4+):
2944
+
2945
+ >>> df = ps.DataFrame([["g", "g0"],
2946
+ ... ["g", "g1"],
2947
+ ... ["g", "g2"],
2948
+ ... ["g", "g3"],
2949
+ ... ["h", "h0"],
2950
+ ... ["h", "h1"]], columns=["A", "B"])
2951
+ >>> df.groupby("A").tail(-1) # doctest: +SKIP
2952
+ A B
2953
+ 3 g g3
2954
+ 2 g g2
2955
+ 1 g g1
2956
+ 5 h h1
2957
+ """
2958
+ return self._limit(n, asc=False)
2959
+
2960
+ def shift(self, periods: int = 1, fill_value: Optional[Any] = None) -> FrameLike:
2961
+ """
2962
+ Shift each group by periods observations.
2963
+
2964
+ Parameters
2965
+ ----------
2966
+ periods : integer, default 1
2967
+ number of periods to shift
2968
+ fill_value : optional
2969
+
2970
+ Returns
2971
+ -------
2972
+ Series or DataFrame
2973
+ Object shifted within each group.
2974
+
2975
+ Examples
2976
+ --------
2977
+
2978
+ >>> df = ps.DataFrame({
2979
+ ... 'a': [1, 1, 1, 2, 2, 2, 3, 3, 3],
2980
+ ... 'b': [1, 2, 2, 2, 3, 3, 3, 4, 4]}, columns=['a', 'b'])
2981
+ >>> df
2982
+ a b
2983
+ 0 1 1
2984
+ 1 1 2
2985
+ 2 1 2
2986
+ 3 2 2
2987
+ 4 2 3
2988
+ 5 2 3
2989
+ 6 3 3
2990
+ 7 3 4
2991
+ 8 3 4
2992
+
2993
+ >>> df.groupby('a').shift().sort_index() # doctest: +SKIP
2994
+ b
2995
+ 0 NaN
2996
+ 1 1.0
2997
+ 2 2.0
2998
+ 3 NaN
2999
+ 4 2.0
3000
+ 5 3.0
3001
+ 6 NaN
3002
+ 7 3.0
3003
+ 8 4.0
3004
+
3005
+ >>> df.groupby('a').shift(periods=-1, fill_value=0).sort_index() # doctest: +SKIP
3006
+ b
3007
+ 0 2
3008
+ 1 2
3009
+ 2 0
3010
+ 3 3
3011
+ 4 3
3012
+ 5 0
3013
+ 6 4
3014
+ 7 4
3015
+ 8 0
3016
+ """
3017
+ return self._apply_series_op(
3018
+ lambda sg: sg._psser._shift(periods, fill_value, part_cols=sg._groupkeys_scols),
3019
+ should_resolve=True,
3020
+ )
3021
+
3022
+ def transform(self, func: Callable[..., pd.Series], *args: Any, **kwargs: Any) -> FrameLike:
3023
+ """
3024
+ Apply function column-by-column to the GroupBy object.
3025
+
3026
+ The function passed to `transform` must take a Series as its first
3027
+ argument and return a Series. The given function is executed for
3028
+ each series in each grouped data.
3029
+
3030
+ While `transform` is a very flexible method, its downside is that
3031
+ using it can be quite a bit slower than using more specific methods
3032
+ like `agg`. pandas-on-Spark offers a wide range of methods that will
3033
+ be much faster than using `transform` for their specific purposes, so try to
3034
+ use them before reaching for `transform`.
3035
+
3036
+ .. note:: this API executes the function once to infer the type which is
3037
+ potentially expensive, for instance, when the dataset is created after
3038
+ aggregations or sorting.
3039
+
3040
+ To avoid this, specify return type in ``func``, for instance, as below:
3041
+
3042
+ >>> def convert_to_string(x) -> ps.Series[str]:
3043
+ ... return x.apply("a string {}".format)
3044
+
3045
+ When the given function has the return type annotated, the original index of the
3046
+ GroupBy object will be lost, and a default index will be attached to the result.
3047
+ Please be careful about configuring the default index. See also `Default Index Type
3048
+ <https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type>`_.
3049
+
3050
+ .. note:: the series within ``func`` is actually a pandas series. Therefore,
3051
+ any pandas API within this function is allowed.
3052
+
3053
+
3054
+ Parameters
3055
+ ----------
3056
+ func : callable
3057
+ A callable that takes a Series as its first argument, and
3058
+ returns a Series.
3059
+ *args
3060
+ Positional arguments to pass to func.
3061
+ **kwargs
3062
+ Keyword arguments to pass to func.
3063
+
3064
+ Returns
3065
+ -------
3066
+ applied : DataFrame
3067
+
3068
+ See Also
3069
+ --------
3070
+ aggregate : Apply aggregate function to the GroupBy object.
3071
+ Series.apply : Apply a function to a Series.
3072
+
3073
+ Examples
3074
+ --------
3075
+
3076
+ >>> df = ps.DataFrame({'A': [0, 0, 1],
3077
+ ... 'B': [1, 2, 3],
3078
+ ... 'C': [4, 6, 5]}, columns=['A', 'B', 'C'])
3079
+
3080
+ >>> g = df.groupby('A')
3081
+
3082
+ Notice that ``g`` has two groups, ``0`` and ``1``.
3083
+ Calling `transform` in various ways, we can get different grouping results:
3084
+ Below, the function passed to `transform` takes a Series as
3085
+ its argument and returns a Series. `transform` applies the function on each series
3086
+ in each grouped data, and combines them into a new DataFrame:
3087
+
3088
+ >>> def convert_to_string(x) -> ps.Series[str]:
3089
+ ... return x.apply("a string {}".format)
3090
+ >>> g.transform(convert_to_string) # doctest: +NORMALIZE_WHITESPACE
3091
+ B C
3092
+ 0 a string 1 a string 4
3093
+ 1 a string 2 a string 6
3094
+ 2 a string 3 a string 5
3095
+
3096
+ >>> def plus_max(x) -> ps.Series[int]:
3097
+ ... return x + x.max()
3098
+ >>> g.transform(plus_max) # doctest: +NORMALIZE_WHITESPACE
3099
+ B C
3100
+ 0 3 10
3101
+ 1 4 12
3102
+ 2 6 10
3103
+
3104
+ You can omit the type hint and let pandas-on-Spark infer its type.
3105
+
3106
+ >>> def plus_min(x):
3107
+ ... return x + x.min()
3108
+ >>> g.transform(plus_min) # doctest: +NORMALIZE_WHITESPACE
3109
+ B C
3110
+ 0 2 8
3111
+ 1 3 10
3112
+ 2 6 10
3113
+
3114
+ In case of Series, it works as below.
3115
+
3116
+ >>> df.B.groupby(df.A).transform(plus_max)
3117
+ 0 3
3118
+ 1 4
3119
+ 2 6
3120
+ Name: B, dtype: int64
3121
+
3122
+ >>> (df * -1).B.groupby(df.A).transform(abs)
3123
+ 0 1
3124
+ 1 2
3125
+ 2 3
3126
+ Name: B, dtype: int64
3127
+
3128
+ You can also specify extra arguments to pass to the function.
3129
+
3130
+ >>> def calculation(x, y, z) -> ps.Series[int]:
3131
+ ... return x + x.min() + y + z
3132
+ >>> g.transform(calculation, 5, z=20) # doctest: +NORMALIZE_WHITESPACE
3133
+ B C
3134
+ 0 27 33
3135
+ 1 28 35
3136
+ 2 31 35
3137
+ """
3138
+ if not callable(func):
3139
+ raise TypeError("%s object is not callable" % type(func).__name__)
3140
+
3141
+ spec = inspect.getfullargspec(func)
3142
+ return_sig = spec.annotations.get("return", None)
3143
+
3144
+ psdf, groupkey_labels, groupkey_names = GroupBy._prepare_group_map_apply(
3145
+ self._psdf, self._groupkeys, agg_columns=self._agg_columns
3146
+ )
3147
+
3148
+ def pandas_transform(pdf: pd.DataFrame) -> pd.DataFrame:
3149
+ return pdf.groupby(groupkey_names).transform(func, *args, **kwargs)
3150
+
3151
+ should_infer_schema = return_sig is None
3152
+
3153
+ if should_infer_schema:
3154
+ # Here we execute with the first 1000 to get the return type.
3155
+ # If the records were less than 1000, it uses pandas API directly for a shortcut.
3156
+ log_advice(
3157
+ "If the type hints is not specified for `groupby.transform`, "
3158
+ "it is expensive to infer the data type internally."
3159
+ )
3160
+ limit = get_option("compute.shortcut_limit")
3161
+ pdf = psdf.head(limit + 1)._to_internal_pandas()
3162
+ pdf = pdf.groupby(groupkey_names).transform(func, *args, **kwargs)
3163
+ psdf_from_pandas: DataFrame = DataFrame(pdf)
3164
+ return_schema = force_decimal_precision_scale(
3165
+ as_nullable_spark_type(
3166
+ psdf_from_pandas._internal.spark_frame.drop(*HIDDEN_COLUMNS).schema
3167
+ )
3168
+ )
3169
+ if len(pdf) <= limit:
3170
+ return self._handle_output(psdf_from_pandas)
3171
+
3172
+ sdf = GroupBy._spark_group_map_apply(
3173
+ psdf,
3174
+ pandas_transform,
3175
+ [psdf._internal.spark_column_for(label) for label in groupkey_labels],
3176
+ return_schema,
3177
+ retain_index=True,
3178
+ )
3179
+ # If schema is inferred, we can restore indexes too.
3180
+ internal = psdf_from_pandas._internal.with_new_sdf(
3181
+ sdf,
3182
+ index_fields=[
3183
+ field.copy(nullable=True) for field in psdf_from_pandas._internal.index_fields
3184
+ ],
3185
+ data_fields=[
3186
+ field.copy(nullable=True) for field in psdf_from_pandas._internal.data_fields
3187
+ ],
3188
+ )
3189
+ else:
3190
+ return_type = infer_return_type(func)
3191
+ if not isinstance(return_type, SeriesType):
3192
+ raise TypeError(
3193
+ "Expected the return type of this function to be of Series type, "
3194
+ "but found type {}".format(return_type)
3195
+ )
3196
+
3197
+ dtype = return_type.dtype
3198
+ spark_type = return_type.spark_type
3199
+
3200
+ data_fields = [
3201
+ InternalField(dtype=dtype, struct_field=StructField(name=c, dataType=spark_type))
3202
+ for c in psdf._internal.data_spark_column_names
3203
+ if c not in groupkey_names
3204
+ ]
3205
+
3206
+ return_schema = StructType([field.struct_field for field in data_fields])
3207
+
3208
+ sdf = GroupBy._spark_group_map_apply(
3209
+ psdf,
3210
+ pandas_transform,
3211
+ [psdf._internal.spark_column_for(label) for label in groupkey_labels],
3212
+ return_schema,
3213
+ retain_index=False,
3214
+ )
3215
+ # Otherwise, it loses index.
3216
+ internal = InternalFrame(
3217
+ spark_frame=sdf, index_spark_columns=None, data_fields=data_fields
3218
+ )
3219
+
3220
+ return self._handle_output(DataFrame(internal))
3221
+
3222
+ def nunique(self, dropna: bool = True) -> FrameLike:
3223
+ """
3224
+ Return DataFrame with number of distinct observations per group for each column.
3225
+
3226
+ Parameters
3227
+ ----------
3228
+ dropna : boolean, default True
3229
+ Don’t include NaN in the counts.
3230
+
3231
+ Returns
3232
+ -------
3233
+ nunique : DataFrame or Series
3234
+
3235
+ Examples
3236
+ --------
3237
+
3238
+ >>> df = ps.DataFrame({'id': ['spam', 'egg', 'egg', 'spam',
3239
+ ... 'ham', 'ham'],
3240
+ ... 'value1': [1, 5, 5, 2, 5, 5],
3241
+ ... 'value2': list('abbaxy')}, columns=['id', 'value1', 'value2'])
3242
+ >>> df
3243
+ id value1 value2
3244
+ 0 spam 1 a
3245
+ 1 egg 5 b
3246
+ 2 egg 5 b
3247
+ 3 spam 2 a
3248
+ 4 ham 5 x
3249
+ 5 ham 5 y
3250
+
3251
+ >>> df.groupby('id').nunique().sort_index() # doctest: +SKIP
3252
+ value1 value2
3253
+ id
3254
+ egg 1 1
3255
+ ham 1 2
3256
+ spam 2 1
3257
+
3258
+ >>> df.groupby('id')['value1'].nunique().sort_index() # doctest: +NORMALIZE_WHITESPACE
3259
+ id
3260
+ egg 1
3261
+ ham 1
3262
+ spam 2
3263
+ Name: value1, dtype: int64
3264
+ """
3265
+ if dropna:
3266
+
3267
+ def stat_function(col: Column) -> Column:
3268
+ return F.countDistinct(col)
3269
+
3270
+ else:
3271
+
3272
+ def stat_function(col: Column) -> Column:
3273
+ return F.countDistinct(col) + F.when(
3274
+ F.count(F.when(col.isNull(), 1).otherwise(None)) >= 1, 1
3275
+ ).otherwise(0)
3276
+
3277
+ return self._reduce_for_stat_function(stat_function)
3278
+
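A small illustrative sketch (plain pandas, made-up data) of the `dropna=False` counting rule implemented above: the null bucket contributes exactly one extra distinct value whenever any null is present.

import pandas as pd

s = pd.Series([1.0, 1.0, None, 2.0])
# countDistinct ignores nulls (2 distinct values); the extra F.when(...) term above adds 1
# because at least one null exists.
print(s.nunique(dropna=True))   # 2
print(s.nunique(dropna=False))  # 3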
3279
+ def rolling(
3280
+ self, window: int, min_periods: Optional[int] = None
3281
+ ) -> "RollingGroupby[FrameLike]":
3282
+ """
3283
+ Return a rolling grouper, providing rolling
3284
+ functionality per group.
3285
+
3286
+ .. note:: 'min_periods' in pandas-on-Spark works as a fixed window size unlike pandas.
3287
+ Unlike pandas, NA is also counted as the period. This might be changed
3288
+ soon.
3289
+
3290
+ Parameters
3291
+ ----------
3292
+ window : int, or offset
3293
+ Size of the moving window.
3294
+ This is the number of observations used for calculating the statistic.
3295
+ Each window will be a fixed size.
3296
+
3297
+ min_periods : int, default 1
3298
+ Minimum number of observations in window required to have a value
3299
+ (otherwise result is NA).
3300
+
3301
+ See Also
3302
+ --------
3303
+ Series.groupby
3304
+ DataFrame.groupby
3305
+ """
3306
+ from pyspark.pandas.window import RollingGroupby
3307
+
3308
+ return RollingGroupby(self, window, min_periods=min_periods)
3309
+
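A brief usage sketch of the grouped rolling window returned above (illustrative data; an active Spark session is assumed, and `min_periods` behaves as a fixed window size as the note explains).

import pyspark.pandas as ps

psdf = ps.DataFrame({"g": ["a", "a", "a", "b", "b"], "v": [1, 2, 3, 4, 5]})
# Rolling sum over a window of size 2 within each group; the first row of each group
# has fewer than 2 observations and therefore yields NaN.
print(psdf.groupby("g")["v"].rolling(2).sum().sort_index())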
3310
+ def expanding(self, min_periods: int = 1) -> "ExpandingGroupby[FrameLike]":
3311
+ """
3312
+ Return an expanding grouper, providing expanding
3313
+ functionality per group.
3314
+
3315
+ .. note:: 'min_periods' in pandas-on-Spark works as a fixed window size unlike pandas.
3316
+ Unlike pandas, NA is also counted as the period. This might be changed
3317
+ soon.
3318
+
3319
+ Parameters
3320
+ ----------
3321
+ min_periods : int, default 1
3322
+ Minimum number of observations in window required to have a value
3323
+ (otherwise result is NA).
3324
+
3325
+ See Also
3326
+ --------
3327
+ Series.groupby
3328
+ DataFrame.groupby
3329
+ """
3330
+ from pyspark.pandas.window import ExpandingGroupby
3331
+
3332
+ return ExpandingGroupby(self, min_periods=min_periods)
3333
+
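Similarly, a short usage sketch of the expanding grouper (illustrative data, Spark session assumed): each value is aggregated over all rows of its group seen so far.

import pyspark.pandas as ps

psdf = ps.DataFrame({"g": ["a", "a", "b", "b", "b"], "v": [1, 2, 3, 4, 5]})
# Cumulative (expanding) sum within each group.
print(psdf.groupby("g")["v"].expanding().sum().sort_index())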
3334
+ # TODO: 'adjust', 'axis', 'method' parameter should be implemented.
3335
+ def ewm(
3336
+ self,
3337
+ com: Optional[float] = None,
3338
+ span: Optional[float] = None,
3339
+ halflife: Optional[float] = None,
3340
+ alpha: Optional[float] = None,
3341
+ min_periods: Optional[int] = None,
3342
+ ignore_na: bool = False,
3343
+ ) -> "ExponentialMovingGroupby[FrameLike]":
3344
+ """
3345
+ Return an ewm grouper, providing ewm functionality per group.
3346
+
3347
+ .. note:: 'min_periods' in pandas-on-Spark works as a fixed window size unlike pandas.
3348
+ Unlike pandas, NA is also counted as the period. This might be changed
3349
+ soon.
3350
+
3351
+ .. versionadded:: 3.4.0
3352
+
3353
+ Parameters
3354
+ ----------
3355
+ com : float, optional
3356
+ Specify decay in terms of center of mass.
3357
+ alpha = 1 / (1 + com), for com >= 0.
3358
+
3359
+ span : float, optional
3360
+ Specify decay in terms of span.
3361
+ alpha = 2 / (span + 1), for span >= 1.
3362
+
3363
+ halflife : float, optional
3364
+ Specify decay in terms of half-life.
3365
+ alpha = 1 - exp(-ln(2) / halflife), for halflife > 0.
3366
+
3367
+ alpha : float, optional
3368
+ Specify smoothing factor alpha directly.
3369
+ 0 < alpha <= 1.
3370
+
3371
+ min_periods : int, default None
3372
+ Minimum number of observations in window required to have a value
3373
+ (otherwise result is NA).
3374
+
3375
+ ignore_na : bool, default False
3376
+ Ignore missing values when calculating weights.
3377
+
3378
+ - When ``ignore_na=False`` (default), weights are based on absolute positions.
3379
+ For example, the weights of :math:`x_0` and :math:`x_2` used in calculating
3380
+ the final weighted average of [:math:`x_0`, None, :math:`x_2`] are
3381
+ :math:`(1-\alpha)^2` and :math:`1` if ``adjust=True``, and
3382
+ :math:`(1-\alpha)^2` and :math:`\alpha` if ``adjust=False``.
3383
+
3384
+ - When ``ignore_na=True``, weights are based
3385
+ on relative positions. For example, the weights of :math:`x_0` and :math:`x_2`
3386
+ used in calculating the final weighted average of
3387
+ [:math:`x_0`, None, :math:`x_2`] are :math:`1-\alpha` and :math:`1` if
3388
+ ``adjust=True``, and :math:`1-\alpha` and :math:`\alpha` if ``adjust=False``.
3389
+ """
3390
+ from pyspark.pandas.window import ExponentialMovingGroupby
3391
+
3392
+ return ExponentialMovingGroupby(
3393
+ self,
3394
+ com=com,
3395
+ span=span,
3396
+ halflife=halflife,
3397
+ alpha=alpha,
3398
+ min_periods=min_periods,
3399
+ ignore_na=ignore_na,
3400
+ )
3401
+
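A small worked sketch of the decay parameterisations documented above; `smoothing_alpha` is a hypothetical helper that simply restates the listed formulas.

import math

def smoothing_alpha(com=None, span=None, halflife=None, alpha=None):
    # Assumption for this sketch: exactly one of the four parameters is provided.
    if com is not None:
        return 1.0 / (1.0 + com)                          # alpha = 1 / (1 + com), com >= 0
    if span is not None:
        return 2.0 / (span + 1.0)                         # alpha = 2 / (span + 1), span >= 1
    if halflife is not None:
        return 1.0 - math.exp(-math.log(2.0) / halflife)  # alpha = 1 - exp(-ln(2) / halflife)
    return alpha                                          # 0 < alpha <= 1

print(smoothing_alpha(com=1))       # 0.5
print(smoothing_alpha(span=3))      # 0.5
print(smoothing_alpha(halflife=1))  # 0.5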
3402
+ def get_group(self, name: Union[Name, List[Name]]) -> FrameLike:
3403
+ """
3404
+ Construct DataFrame from group with provided name.
3405
+
3406
+ Parameters
3407
+ ----------
3408
+ name : object
3409
+ The name of the group to get as a DataFrame.
3410
+
3411
+ Returns
3412
+ -------
3413
+ group : same type as obj
3414
+
3415
+ Examples
3416
+ --------
3417
+ >>> psdf = ps.DataFrame([('falcon', 'bird', 389.0),
3418
+ ... ('parrot', 'bird', 24.0),
3419
+ ... ('lion', 'mammal', 80.5),
3420
+ ... ('monkey', 'mammal', np.nan)],
3421
+ ... columns=['name', 'class', 'max_speed'],
3422
+ ... index=[0, 2, 3, 1])
3423
+ >>> psdf
3424
+ name class max_speed
3425
+ 0 falcon bird 389.0
3426
+ 2 parrot bird 24.0
3427
+ 3 lion mammal 80.5
3428
+ 1 monkey mammal NaN
3429
+
3430
+ >>> psdf.groupby("class").get_group("bird").sort_index()
3431
+ name class max_speed
3432
+ 0 falcon bird 389.0
3433
+ 2 parrot bird 24.0
3434
+
3435
+ >>> psdf.groupby("class").get_group("mammal").sort_index()
3436
+ name class max_speed
3437
+ 1 monkey mammal NaN
3438
+ 3 lion mammal 80.5
3439
+ """
3440
+ groupkeys = self._groupkeys
3441
+ if not is_hashable(name):
3442
+ raise TypeError("unhashable type: '{}'".format(type(name).__name__))
3443
+ elif len(groupkeys) > 1:
3444
+ if not isinstance(name, tuple):
3445
+ raise ValueError("must supply a tuple to get_group with multiple grouping keys")
3446
+ if len(groupkeys) != len(name):
3447
+ raise ValueError(
3448
+ "must supply a same-length tuple to get_group with multiple grouping keys"
3449
+ )
3450
+ if not is_list_like(name):
3451
+ name = [name]
3452
+ cond = F.lit(True)
3453
+ for groupkey, item in zip(groupkeys, name):
3454
+ scol = groupkey.spark.column
3455
+ cond = cond & (scol == item)
3456
+ if self._agg_columns_selected:
3457
+ internal = self._psdf._internal
3458
+ spark_frame = internal.spark_frame.select(
3459
+ internal.index_spark_columns + self._agg_columns_scols
3460
+ ).filter(cond)
3461
+
3462
+ internal = internal.copy(
3463
+ spark_frame=spark_frame,
3464
+ index_spark_columns=[
3465
+ scol_for(spark_frame, col) for col in internal.index_spark_column_names
3466
+ ],
3467
+ column_labels=[s._column_label for s in self._agg_columns],
3468
+ data_spark_columns=[
3469
+ scol_for(spark_frame, s._internal.data_spark_column_names[0])
3470
+ for s in self._agg_columns
3471
+ ],
3472
+ data_fields=[s._internal.data_fields[0] for s in self._agg_columns],
3473
+ )
3474
+ else:
3475
+ internal = self._psdf._internal.with_filter(cond)
3476
+ if internal.spark_frame.head() is None:
3477
+ raise KeyError(name)
3478
+
3479
+ return self._handle_output(DataFrame(internal))
3480
+
3481
+ def median(self, numeric_only: Optional[bool] = True, accuracy: int = 10000) -> FrameLike:
3482
+ """
3483
+ Compute median of groups, excluding missing values.
3484
+
3485
+ For multiple groupings, the result index will be a MultiIndex
3486
+
3487
+ .. note:: Unlike pandas', the median in pandas-on-Spark is an approximated median based upon
3488
+ approximate percentile computation because computing median across a large dataset
3489
+ is extremely expensive.
3490
+
3491
+ Parameters
3492
+ ----------
3493
+ numeric_only : bool, default True
3494
+ Include only float, int, boolean columns. If None, will attempt to use
3495
+ everything, then use only numeric data.
3496
+
3497
+ .. versionadded:: 3.4.0
3498
+
3499
+ Returns
3500
+ -------
3501
+ Series or DataFrame
3502
+ Median of values within each group.
3503
+
3504
+ Examples
3505
+ --------
3506
+ >>> psdf = ps.DataFrame({'a': [1., 1., 1., 1., 2., 2., 2., 3., 3., 3.],
3507
+ ... 'b': [2., 3., 1., 4., 6., 9., 8., 10., 7., 5.],
3508
+ ... 'c': [3., 5., 2., 5., 1., 2., 6., 4., 3., 6.]},
3509
+ ... columns=['a', 'b', 'c'],
3510
+ ... index=[7, 2, 4, 1, 3, 4, 9, 10, 5, 6])
3511
+ >>> psdf
3512
+ a b c
3513
+ 7 1.0 2.0 3.0
3514
+ 2 1.0 3.0 5.0
3515
+ 4 1.0 1.0 2.0
3516
+ 1 1.0 4.0 5.0
3517
+ 3 2.0 6.0 1.0
3518
+ 4 2.0 9.0 2.0
3519
+ 9 2.0 8.0 6.0
3520
+ 10 3.0 10.0 4.0
3521
+ 5 3.0 7.0 3.0
3522
+ 6 3.0 5.0 6.0
3523
+
3524
+ DataFrameGroupBy
3525
+
3526
+ >>> psdf.groupby('a').median().sort_index() # doctest: +NORMALIZE_WHITESPACE
3527
+ b c
3528
+ a
3529
+ 1.0 2.0 3.0
3530
+ 2.0 8.0 2.0
3531
+ 3.0 7.0 4.0
3532
+
3533
+ SeriesGroupBy
3534
+
3535
+ >>> psdf.groupby('a')['b'].median().sort_index()
3536
+ a
3537
+ 1.0 2.0
3538
+ 2.0 8.0
3539
+ 3.0 7.0
3540
+ Name: b, dtype: float64
3541
+ """
3542
+ if not isinstance(accuracy, int):
3543
+ raise TypeError(
3544
+ "accuracy must be an integer; however, got [%s]" % type(accuracy).__name__
3545
+ )
3546
+
3547
+ self._validate_agg_columns(numeric_only=numeric_only, function_name="median")
3548
+
3549
+ warnings.warn(
3550
+ "Default value of `numeric_only` will be changed to `False` "
3551
+ "instead of `True` in 4.0.0.",
3552
+ FutureWarning,
3553
+ )
3554
+
3555
+ def stat_function(col: Column) -> Column:
3556
+ return F.percentile_approx(col, 0.5, accuracy)
3557
+
3558
+ return self._reduce_for_stat_function(
3559
+ stat_function,
3560
+ accepted_spark_types=(NumericType,),
3561
+ bool_to_numeric=True,
3562
+ )
3563
+
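A brief hedged usage note for the `accuracy` argument accepted above (not shown in the docstring examples): it is forwarded to Spark's approximate percentile, so it trades precision for cost. The data below is illustrative and a Spark session is assumed.

import pyspark.pandas as ps

psdf = ps.DataFrame({"a": [1.0, 1.0, 2.0, 2.0], "b": [1.0, 3.0, 5.0, 9.0]})
# A lower accuracy is cheaper but coarser; a higher accuracy approaches the exact median.
print(psdf.groupby("a").median(accuracy=100))
print(psdf.groupby("a").median(accuracy=100000))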
3564
+ def _validate_agg_columns(self, numeric_only: Optional[bool], function_name: str) -> None:
3565
+ """Validate aggregation columns and raise an error or a warning following pandas."""
3566
+ has_non_numeric = False
3567
+ for _agg_col in self._agg_columns:
3568
+ if not isinstance(_agg_col.spark.data_type, (NumericType, BooleanType)):
3569
+ has_non_numeric = True
3570
+ break
3571
+ if has_non_numeric:
3572
+ if isinstance(self, SeriesGroupBy):
3573
+ raise TypeError("Only numeric aggregation column is accepted.")
3574
+
3575
+ if not numeric_only and has_non_numeric:
3576
+ warnings.warn(
3577
+ "Dropping invalid columns in DataFrameGroupBy.%s is deprecated. "
3578
+ "In a future version, a TypeError will be raised. "
3579
+ "Before calling .%s, select only columns which should be "
3580
+ "valid for the function." % (function_name, function_name),
3581
+ FutureWarning,
3582
+ )
3583
+
3584
+ def _reduce_for_stat_function(
3585
+ self,
3586
+ sfun: Callable[[Column], Column],
3587
+ accepted_spark_types: Optional[Tuple[Type[DataType], ...]] = None,
3588
+ bool_to_numeric: bool = False,
3589
+ **kwargs: Any,
3590
+ ) -> FrameLike:
3591
+ """Apply an aggregate function `sfun` per column and reduce to a FrameLike.
3592
+
3593
+ Parameters
3594
+ ----------
3595
+ sfun : The aggregate function to apply per column.
3596
+ accepted_spark_types: Accepted spark types of columns to be aggregated;
3597
+ default None means all spark types are accepted.
3598
+ bool_to_numeric: If True, boolean columns are converted to numeric columns, which
3599
+ are accepted for all statistical functions regardless of
3600
+ `accepted_spark_types`.
3601
+ """
3602
+ groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
3603
+ internal, _, sdf = self._prepare_reduce(
3604
+ groupkey_names, accepted_spark_types, bool_to_numeric
3605
+ )
3606
+ psdf: DataFrame = DataFrame(internal)
3607
+
3608
+ if len(psdf._internal.column_labels) > 0:
3609
+ min_count = kwargs.get("min_count", 0)
3610
+ stat_exprs = []
3611
+ for label in psdf._internal.column_labels:
3612
+ psser = psdf._psser_for(label)
3613
+ input_scol = psser._dtype_op.nan_to_null(psser).spark.column
3614
+ output_scol = sfun(input_scol)
3615
+
3616
+ if min_count > 0:
3617
+ output_scol = F.when(
3618
+ F.count(F.when(~F.isnull(input_scol), F.lit(0))) >= min_count, output_scol
3619
+ )
3620
+
3621
+ stat_exprs.append(output_scol.alias(psser._internal.data_spark_column_names[0]))
3622
+ sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
3623
+ else:
3624
+ sdf = sdf.select(*groupkey_names).distinct()
3625
+
3626
+ internal = internal.copy(
3627
+ spark_frame=sdf,
3628
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
3629
+ data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
3630
+ data_fields=None,
3631
+ )
3632
+ psdf = DataFrame(internal)
3633
+
3634
+ return self._prepare_return(psdf)
3635
+
3636
+ def _prepare_return(self, psdf: DataFrame) -> FrameLike:
3637
+ if self._dropna:
3638
+ psdf = DataFrame(
3639
+ psdf._internal.with_new_sdf(
3640
+ psdf._internal.spark_frame.dropna(
3641
+ subset=psdf._internal.index_spark_column_names
3642
+ )
3643
+ )
3644
+ )
3645
+ if not self._as_index:
3646
+ should_drop_index = set(
3647
+ i for i, gkey in enumerate(self._groupkeys) if gkey._psdf is not self._psdf
3648
+ )
3649
+ if len(should_drop_index) > 0:
3650
+ psdf = psdf.reset_index(level=should_drop_index, drop=True)
3651
+ if len(should_drop_index) < len(self._groupkeys):
3652
+ psdf = psdf.reset_index()
3653
+ return self._handle_output(psdf)
3654
+
3655
+ def _prepare_reduce(
3656
+ self,
3657
+ groupkey_names: List,
3658
+ accepted_spark_types: Optional[Tuple[Type[DataType], ...]] = None,
3659
+ bool_to_numeric: bool = False,
3660
+ ) -> Tuple[InternalFrame, List[Series], SparkDataFrame]:
3661
+ groupkey_scols = [s.alias(name) for s, name in zip(self._groupkeys_scols, groupkey_names)]
3662
+ agg_columns = []
3663
+ for psser in self._agg_columns:
3664
+ if bool_to_numeric and isinstance(psser.spark.data_type, BooleanType):
3665
+ agg_columns.append(psser.astype(int))
3666
+ elif (accepted_spark_types is None) or isinstance(
3667
+ psser.spark.data_type, accepted_spark_types
3668
+ ):
3669
+ agg_columns.append(psser)
3670
+ sdf = self._psdf._internal.spark_frame.select(
3671
+ *groupkey_scols, *[psser.spark.column for psser in agg_columns]
3672
+ )
3673
+ internal = InternalFrame(
3674
+ spark_frame=sdf,
3675
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
3676
+ index_names=[psser._column_label for psser in self._groupkeys],
3677
+ index_fields=[
3678
+ psser._internal.data_fields[0].copy(name=name)
3679
+ for psser, name in zip(self._groupkeys, groupkey_names)
3680
+ ],
3681
+ data_spark_columns=[
3682
+ scol_for(sdf, psser._internal.data_spark_column_names[0]) for psser in agg_columns
3683
+ ],
3684
+ column_labels=[psser._column_label for psser in agg_columns],
3685
+ data_fields=[psser._internal.data_fields[0] for psser in agg_columns],
3686
+ column_label_names=self._psdf._internal.column_label_names,
3687
+ )
3688
+ return internal, agg_columns, sdf
3689
+
3690
+ @staticmethod
3691
+ def _resolve_grouping_from_diff_dataframes(
3692
+ psdf: DataFrame, by: List[Union[Series, Label]]
3693
+ ) -> Tuple[DataFrame, List[Series], Set[Label]]:
3694
+ column_labels_level = psdf._internal.column_labels_level
3695
+
3696
+ column_labels = []
3697
+ additional_pssers = []
3698
+ additional_column_labels = []
3699
+ tmp_column_labels = set()
3700
+ for i, col_or_s in enumerate(by):
3701
+ if isinstance(col_or_s, Series):
3702
+ if col_or_s._psdf is psdf:
3703
+ column_labels.append(col_or_s._column_label)
3704
+ elif same_anchor(col_or_s, psdf):
3705
+ temp_label = verify_temp_column_name(psdf, "__tmp_groupkey_{}__".format(i))
3706
+ column_labels.append(temp_label)
3707
+ additional_pssers.append(col_or_s.rename(temp_label))
3708
+ additional_column_labels.append(temp_label)
3709
+ else:
3710
+ temp_label = verify_temp_column_name(
3711
+ psdf,
3712
+ tuple(
3713
+ ([""] * (column_labels_level - 1)) + ["__tmp_groupkey_{}__".format(i)]
3714
+ ),
3715
+ )
3716
+ column_labels.append(temp_label)
3717
+ tmp_column_labels.add(temp_label)
3718
+ elif isinstance(col_or_s, tuple):
3719
+ psser = psdf[col_or_s]
3720
+ if not isinstance(psser, Series):
3721
+ raise ValueError(name_like_string(col_or_s))
3722
+ column_labels.append(col_or_s)
3723
+ else:
3724
+ raise ValueError(col_or_s)
3725
+
3726
+ psdf = DataFrame(
3727
+ psdf._internal.with_new_columns(
3728
+ [psdf._psser_for(label) for label in psdf._internal.column_labels]
3729
+ + additional_pssers
3730
+ )
3731
+ )
3732
+
3733
+ def assign_columns(
3734
+ psdf: DataFrame, this_column_labels: List[Label], that_column_labels: List[Label]
3735
+ ) -> Iterator[Tuple[Series, Label]]:
3736
+ raise NotImplementedError(
3737
+ "Duplicated labels with groupby() and "
3738
+ "'compute.ops_on_diff_frames' option is not supported currently "
3739
+ "Please use unique labels in series and frames."
3740
+ )
3741
+
3742
+ for col_or_s, label in zip(by, column_labels):
3743
+ if label in tmp_column_labels:
3744
+ psser = col_or_s
3745
+ psdf = align_diff_frames(
3746
+ assign_columns,
3747
+ psdf,
3748
+ psser.rename(label),
3749
+ fillna=False,
3750
+ how="inner",
3751
+ preserve_order_column=True,
3752
+ )
3753
+
3754
+ tmp_column_labels |= set(additional_column_labels)
3755
+
3756
+ new_by_series = []
3757
+ for col_or_s, label in zip(by, column_labels):
3758
+ if label in tmp_column_labels:
3759
+ psser = col_or_s
3760
+ new_by_series.append(psdf._psser_for(label).rename(psser.name))
3761
+ else:
3762
+ new_by_series.append(psdf._psser_for(label))
3763
+
3764
+ return psdf, new_by_series, tmp_column_labels
3765
+
3766
+ @staticmethod
3767
+ def _resolve_grouping(psdf: DataFrame, by: List[Union[Series, Label]]) -> List[Series]:
3768
+ new_by_series = []
3769
+ for col_or_s in by:
3770
+ if isinstance(col_or_s, Series):
3771
+ new_by_series.append(col_or_s)
3772
+ elif isinstance(col_or_s, tuple):
3773
+ psser = psdf[col_or_s]
3774
+ if not isinstance(psser, Series):
3775
+ raise ValueError(name_like_string(col_or_s))
3776
+ new_by_series.append(psser)
3777
+ else:
3778
+ raise ValueError(col_or_s)
3779
+ return new_by_series
3780
+
3781
+
3782
+ class DataFrameGroupBy(GroupBy[DataFrame]):
3783
+ @staticmethod
3784
+ def _build(
3785
+ psdf: DataFrame, by: List[Union[Series, Label]], as_index: bool, dropna: bool
3786
+ ) -> "DataFrameGroupBy":
3787
+ if any(isinstance(col_or_s, Series) and not same_anchor(psdf, col_or_s) for col_or_s in by):
3788
+ (
3789
+ psdf,
3790
+ new_by_series,
3791
+ column_labels_to_exclude,
3792
+ ) = GroupBy._resolve_grouping_from_diff_dataframes(psdf, by)
3793
+ else:
3794
+ new_by_series = GroupBy._resolve_grouping(psdf, by)
3795
+ column_labels_to_exclude = set()
3796
+ return DataFrameGroupBy(
3797
+ psdf,
3798
+ new_by_series,
3799
+ as_index=as_index,
3800
+ dropna=dropna,
3801
+ column_labels_to_exclude=column_labels_to_exclude,
3802
+ )
3803
+
3804
+ def __init__(
3805
+ self,
3806
+ psdf: DataFrame,
3807
+ by: List[Series],
3808
+ as_index: bool,
3809
+ dropna: bool,
3810
+ column_labels_to_exclude: Set[Label],
3811
+ agg_columns: List[Label] = None,
3812
+ ):
3813
+ agg_columns_selected = agg_columns is not None
3814
+ if agg_columns_selected:
3815
+ for label in agg_columns:
3816
+ if label in column_labels_to_exclude:
3817
+ raise KeyError(label)
3818
+ else:
3819
+ agg_columns = [
3820
+ label
3821
+ for label in psdf._internal.column_labels
3822
+ if not any(label == key._column_label and key._psdf is psdf for key in by)
3823
+ and label not in column_labels_to_exclude
3824
+ ]
3825
+
3826
+ super().__init__(
3827
+ psdf=psdf,
3828
+ groupkeys=by,
3829
+ as_index=as_index,
3830
+ dropna=dropna,
3831
+ column_labels_to_exclude=column_labels_to_exclude,
3832
+ agg_columns_selected=agg_columns_selected,
3833
+ agg_columns=[psdf[label] for label in agg_columns],
3834
+ )
3835
+
3836
+ def __getattr__(self, item: str) -> Any:
3837
+ if hasattr(MissingPandasLikeDataFrameGroupBy, item):
3838
+ property_or_func = getattr(MissingPandasLikeDataFrameGroupBy, item)
3839
+ if isinstance(property_or_func, property):
3840
+ return property_or_func.fget(self)
3841
+ else:
3842
+ return partial(property_or_func, self)
3843
+ return self.__getitem__(item)
3844
+
3845
+ def __getitem__(self, item: Any) -> GroupBy:
3846
+ if self._as_index and is_name_like_value(item):
3847
+ return SeriesGroupBy(
3848
+ self._psdf._psser_for(item if is_name_like_tuple(item) else (item,)),
3849
+ self._groupkeys,
3850
+ dropna=self._dropna,
3851
+ )
3852
+ else:
3853
+ if is_name_like_tuple(item):
3854
+ item = [item]
3855
+ elif is_name_like_value(item):
3856
+ item = [(item,)]
3857
+ else:
3858
+ item = [i if is_name_like_tuple(i) else (i,) for i in item]
3859
+ if not self._as_index:
3860
+ groupkey_names = set(key._column_label for key in self._groupkeys)
3861
+ for name in item:
3862
+ if name in groupkey_names:
3863
+ raise ValueError(
3864
+ "cannot insert {}, already exists".format(name_like_string(name))
3865
+ )
3866
+ return DataFrameGroupBy(
3867
+ self._psdf,
3868
+ self._groupkeys,
3869
+ as_index=self._as_index,
3870
+ dropna=self._dropna,
3871
+ column_labels_to_exclude=self._column_labels_to_exclude,
3872
+ agg_columns=item,
3873
+ )
3874
+
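A short usage sketch of the item selection implemented above (illustrative data, Spark session assumed): selecting a single label yields a SeriesGroupBy, while a list of labels yields a DataFrameGroupBy restricted to those aggregation columns.

import pyspark.pandas as ps

psdf = ps.DataFrame({"a": [1, 1, 2], "b": [3, 4, 5], "c": [6, 7, 8]})
g = psdf.groupby("a")
print(type(g["b"]).__name__)         # SeriesGroupBy
print(type(g[["b", "c"]]).__name__)  # DataFrameGroupBy
print(g[["b", "c"]].max().sort_index())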
3875
+ def _apply_series_op(
3876
+ self,
3877
+ op: Callable[["SeriesGroupBy"], Series],
3878
+ should_resolve: bool = False,
3879
+ numeric_only: bool = False,
3880
+ ) -> DataFrame:
3881
+ applied = []
3882
+ for column in self._agg_columns:
3883
+ applied.append(op(column.groupby(self._groupkeys)))
3884
+ if numeric_only:
3885
+ applied = [col for col in applied if isinstance(col.spark.data_type, NumericType)]
3886
+ if not applied:
3887
+ raise DataError("No numeric types to aggregate")
3888
+ internal = self._psdf._internal.with_new_columns(applied, keep_order=False)
3889
+ if should_resolve:
3890
+ internal = internal.resolved_copy
3891
+ return DataFrame(internal)
3892
+
3893
+ def _handle_output(self, psdf: DataFrame) -> DataFrame:
3894
+ return psdf
3895
+
3896
+ # TODO: Implement 'percentiles', 'include', and 'exclude' arguments.
3897
+ # TODO: Add ``DataFrame.select_dtypes`` to See Also when 'include'
3898
+ # and 'exclude' arguments are implemented.
3899
+ def describe(self) -> DataFrame:
3900
+ """
3901
+ Generate descriptive statistics that summarize the central tendency,
3902
+ dispersion and shape of a dataset's distribution, excluding
3903
+ ``NaN`` values.
3904
+
3905
+ Analyzes both numeric and object series, as well
3906
+ as ``DataFrame`` column sets of mixed data types. The output
3907
+ will vary depending on what is provided. Refer to the notes
3908
+ below for more detail.
3909
+
3910
+ .. note:: Unlike pandas, the percentiles in pandas-on-Spark are based upon
3911
+ approximate percentile computation because computing percentiles
3912
+ across a large dataset is extremely expensive.
3913
+
3914
+ Returns
3915
+ -------
3916
+ DataFrame
3917
+ Summary statistics of the DataFrame provided.
3918
+
3919
+ See Also
3920
+ --------
3921
+ DataFrame.count
3922
+ DataFrame.max
3923
+ DataFrame.min
3924
+ DataFrame.mean
3925
+ DataFrame.std
3926
+
3927
+ Examples
3928
+ --------
3929
+ >>> df = ps.DataFrame({'a': [1, 1, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
3930
+ >>> df
3931
+ a b c
3932
+ 0 1 4 7
3933
+ 1 1 5 8
3934
+ 2 3 6 9
3935
+
3936
+ Describing a ``DataFrame``. By default only numeric fields
3937
+ are returned.
3938
+
3939
+ >>> described = df.groupby('a').describe()
3940
+ >>> described.sort_index() # doctest: +NORMALIZE_WHITESPACE
3941
+ b c
3942
+ count mean std min 25% 50% 75% max count mean std min 25% 50% 75% max
3943
+ a
3944
+ 1 2.0 4.5 0.707107 4.0 4.0 4.0 5.0 5.0 2.0 7.5 0.707107 7.0 7.0 7.0 8.0 8.0
3945
+ 3 1.0 6.0 NaN 6.0 6.0 6.0 6.0 6.0 1.0 9.0 NaN 9.0 9.0 9.0 9.0 9.0
3946
+
3947
+ """
3948
+ for col in self._agg_columns:
3949
+ if isinstance(col.spark.data_type, StringType):
3950
+ raise NotImplementedError(
3951
+ "DataFrameGroupBy.describe() doesn't support for string type for now"
3952
+ )
3953
+
3954
+ psdf = self.aggregate(["count", "mean", "std", "min", "quartiles", "max"])
3955
+ sdf = psdf._internal.spark_frame
3956
+ agg_column_labels = [col._column_label for col in self._agg_columns]
3957
+ formatted_percentiles = ["25%", "50%", "75%"]
3958
+
3959
+ # Split "quartiles" columns into first, second, and third quartiles.
3960
+ for label in agg_column_labels:
3961
+ quartiles_col = name_like_string(tuple(list(label) + ["quartiles"]))
3962
+ for i, percentile in enumerate(formatted_percentiles):
3963
+ sdf = sdf.withColumn(
3964
+ name_like_string(tuple(list(label) + [percentile])),
3965
+ scol_for(sdf, quartiles_col)[i],
3966
+ )
3967
+ sdf = sdf.drop(quartiles_col)
3968
+
3969
+ # Reorder columns lexicographically by agg column followed by stats.
3970
+ stats = ["count", "mean", "std", "min"] + formatted_percentiles + ["max"]
3971
+ column_labels = [tuple(list(label) + [s]) for label, s in product(agg_column_labels, stats)]
3972
+ data_columns = map(name_like_string, column_labels)
3973
+
3974
+ # Reindex the DataFrame to reflect initial grouping and agg columns.
3975
+ internal = psdf._internal.copy(
3976
+ spark_frame=sdf,
3977
+ column_labels=column_labels,
3978
+ data_spark_columns=[scol_for(sdf, col) for col in data_columns],
3979
+ data_fields=None,
3980
+ )
3981
+
3982
+ # Cast columns to ``"float64"`` to match `pandas.DataFrame.groupby`.
3983
+ return DataFrame(internal).astype("float64")
3984
+
3985
+
3986
+ class SeriesGroupBy(GroupBy[Series]):
3987
+ @staticmethod
3988
+ def _build(
3989
+ psser: Series, by: List[Union[Series, Label]], as_index: bool, dropna: bool
3990
+ ) -> "SeriesGroupBy":
3991
+ if any(
3992
+ isinstance(col_or_s, Series) and not same_anchor(psser, col_or_s) for col_or_s in by
3993
+ ):
3994
+ psdf, new_by_series, _ = GroupBy._resolve_grouping_from_diff_dataframes(
3995
+ psser.to_frame(), by
3996
+ )
3997
+ return SeriesGroupBy(
3998
+ first_series(psdf).rename(psser.name),
3999
+ new_by_series,
4000
+ as_index=as_index,
4001
+ dropna=dropna,
4002
+ )
4003
+ else:
4004
+ new_by_series = GroupBy._resolve_grouping(psser._psdf, by)
4005
+ return SeriesGroupBy(psser, new_by_series, as_index=as_index, dropna=dropna)
4006
+
4007
+ def __init__(self, psser: Series, by: List[Series], as_index: bool = True, dropna: bool = True):
4008
+ if not as_index:
4009
+ raise TypeError("as_index=False only valid with DataFrame")
4010
+ super().__init__(
4011
+ psdf=psser._psdf,
4012
+ groupkeys=by,
4013
+ as_index=True,
4014
+ dropna=dropna,
4015
+ column_labels_to_exclude=set(),
4016
+ agg_columns_selected=True,
4017
+ agg_columns=[psser],
4018
+ )
4019
+ self._psser = psser
4020
+
4021
+ def __getattr__(self, item: str) -> Any:
4022
+ if hasattr(MissingPandasLikeSeriesGroupBy, item):
4023
+ property_or_func = getattr(MissingPandasLikeSeriesGroupBy, item)
4024
+ if isinstance(property_or_func, property):
4025
+ return property_or_func.fget(self)
4026
+ else:
4027
+ return partial(property_or_func, self)
4028
+ raise AttributeError(item)
4029
+
4030
+ def _apply_series_op(
4031
+ self,
4032
+ op: Callable[["SeriesGroupBy"], Series],
4033
+ should_resolve: bool = False,
4034
+ numeric_only: bool = False,
4035
+ ) -> Series:
4036
+ if numeric_only and not isinstance(self._agg_columns[0].spark.data_type, NumericType):
4037
+ raise DataError("No numeric types to aggregate")
4038
+ psser = op(self)
4039
+ if should_resolve:
4040
+ internal = psser._internal.resolved_copy
4041
+ return first_series(DataFrame(internal))
4042
+ else:
4043
+ return psser.copy()
4044
+
4045
+ def _handle_output(self, psdf: DataFrame) -> Series:
4046
+ return first_series(psdf).rename(self._psser.name)
4047
+
4048
+ def agg(self, *args: Any, **kwargs: Any) -> None:
4049
+ return MissingPandasLikeSeriesGroupBy.agg(self, *args, **kwargs)
4050
+
4051
+ def aggregate(self, *args: Any, **kwargs: Any) -> None:
4052
+ return MissingPandasLikeSeriesGroupBy.aggregate(self, *args, **kwargs)
4053
+
4054
+ def size(self) -> Series:
4055
+ return super().size().rename(self._psser.name)
4056
+
4057
+ size.__doc__ = GroupBy.size.__doc__
4058
+
4059
+ # TODO: add keep parameter
4060
+ def nsmallest(self, n: int = 5) -> Series:
4061
+ """
4062
+ Return the smallest `n` elements.
4063
+
4064
+ Parameters
4065
+ ----------
4066
+ n : int
4067
+ Number of items to retrieve.
4068
+
4069
+ See Also
4070
+ --------
4071
+ pyspark.pandas.Series.nsmallest
4072
+ pyspark.pandas.DataFrame.nsmallest
4073
+
4074
+ Examples
4075
+ --------
4076
+ >>> df = ps.DataFrame({'a': [1, 1, 1, 2, 2, 2, 3, 3, 3],
4077
+ ... 'b': [1, 2, 2, 2, 3, 3, 3, 4, 4]}, columns=['a', 'b'])
4078
+
4079
+ >>> df.groupby(['a'])['b'].nsmallest(1).sort_index() # doctest: +NORMALIZE_WHITESPACE
4080
+ a
4081
+ 1 0 1
4082
+ 2 3 2
4083
+ 3 6 3
4084
+ Name: b, dtype: int64
4085
+ """
4086
+ if self._psser._internal.index_level > 1:
4087
+ raise ValueError("nsmallest do not support multi-index now")
4088
+
4089
+ groupkey_col_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
4090
+ sdf = self._psser._internal.spark_frame.select(
4091
+ *[scol.alias(name) for scol, name in zip(self._groupkeys_scols, groupkey_col_names)],
4092
+ *[
4093
+ scol.alias(SPARK_INDEX_NAME_FORMAT(i + len(self._groupkeys)))
4094
+ for i, scol in enumerate(self._psser._internal.index_spark_columns)
4095
+ ],
4096
+ self._psser.spark.column,
4097
+ NATURAL_ORDER_COLUMN_NAME,
4098
+ )
4099
+
4100
+ window = Window.partitionBy(*groupkey_col_names).orderBy(
4101
+ scol_for(sdf, self._psser._internal.data_spark_column_names[0]).asc(),
4102
+ NATURAL_ORDER_COLUMN_NAME,
4103
+ )
4104
+
4105
+ temp_rank_column = verify_temp_column_name(sdf, "__rank__")
4106
+ sdf = (
4107
+ sdf.withColumn(temp_rank_column, F.row_number().over(window))
4108
+ .filter(F.col(temp_rank_column) <= n)
4109
+ .drop(temp_rank_column)
4110
+ ).drop(NATURAL_ORDER_COLUMN_NAME)
4111
+
4112
+ internal = InternalFrame(
4113
+ spark_frame=sdf,
4114
+ index_spark_columns=(
4115
+ [scol_for(sdf, col) for col in groupkey_col_names]
4116
+ + [
4117
+ scol_for(sdf, SPARK_INDEX_NAME_FORMAT(i + len(self._groupkeys)))
4118
+ for i in range(self._psdf._internal.index_level)
4119
+ ]
4120
+ ),
4121
+ index_names=(
4122
+ [psser._column_label for psser in self._groupkeys]
4123
+ + self._psdf._internal.index_names
4124
+ ),
4125
+ index_fields=(
4126
+ [
4127
+ psser._internal.data_fields[0].copy(name=name)
4128
+ for psser, name in zip(self._groupkeys, groupkey_col_names)
4129
+ ]
4130
+ + [
4131
+ field.copy(name=SPARK_INDEX_NAME_FORMAT(i + len(self._groupkeys)))
4132
+ for i, field in enumerate(self._psdf._internal.index_fields)
4133
+ ]
4134
+ ),
4135
+ column_labels=[self._psser._column_label],
4136
+ data_spark_columns=[scol_for(sdf, self._psser._internal.data_spark_column_names[0])],
4137
+ data_fields=[self._psser._internal.data_fields[0]],
4138
+ )
4139
+ return first_series(DataFrame(internal))
4140
+
4141
+ # TODO: add keep parameter
4142
+ def nlargest(self, n: int = 5) -> Series:
4143
+ """
4144
+ Return the first n rows ordered by columns in descending order in group.
4145
+
4146
+ Return the first n rows with the largest values in columns, in descending order.
4147
+ The columns that are not specified are returned as well, but not used for ordering.
4148
+
4149
+ Parameters
4150
+ ----------
4151
+ n : int
4152
+ Number of items to retrieve.
4153
+
4154
+ See Also
4155
+ --------
4156
+ pyspark.pandas.Series.nlargest
4157
+ pyspark.pandas.DataFrame.nlargest
4158
+
4159
+ Examples
4160
+ --------
4161
+ >>> df = ps.DataFrame({'a': [1, 1, 1, 2, 2, 2, 3, 3, 3],
4162
+ ... 'b': [1, 2, 2, 2, 3, 3, 3, 4, 4]}, columns=['a', 'b'])
4163
+
4164
+ >>> df.groupby(['a'])['b'].nlargest(1).sort_index() # doctest: +NORMALIZE_WHITESPACE
4165
+ a
4166
+ 1 1 2
4167
+ 2 4 3
4168
+ 3 7 4
4169
+ Name: b, dtype: int64
4170
+ """
4171
+ if self._psser._internal.index_level > 1:
4172
+ raise ValueError("nlargest do not support multi-index now")
4173
+
4174
+ groupkey_col_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
4175
+ sdf = self._psser._internal.spark_frame.select(
4176
+ *[scol.alias(name) for scol, name in zip(self._groupkeys_scols, groupkey_col_names)],
4177
+ *[
4178
+ scol.alias(SPARK_INDEX_NAME_FORMAT(i + len(self._groupkeys)))
4179
+ for i, scol in enumerate(self._psser._internal.index_spark_columns)
4180
+ ],
4181
+ self._psser.spark.column,
4182
+ NATURAL_ORDER_COLUMN_NAME,
4183
+ )
4184
+
4185
+ window = Window.partitionBy(*groupkey_col_names).orderBy(
4186
+ scol_for(sdf, self._psser._internal.data_spark_column_names[0]).desc(),
4187
+ NATURAL_ORDER_COLUMN_NAME,
4188
+ )
4189
+
4190
+ temp_rank_column = verify_temp_column_name(sdf, "__rank__")
4191
+ sdf = (
4192
+ sdf.withColumn(temp_rank_column, F.row_number().over(window))
4193
+ .filter(F.col(temp_rank_column) <= n)
4194
+ .drop(temp_rank_column)
4195
+ ).drop(NATURAL_ORDER_COLUMN_NAME)
4196
+
4197
+ internal = InternalFrame(
4198
+ spark_frame=sdf,
4199
+ index_spark_columns=(
4200
+ [scol_for(sdf, col) for col in groupkey_col_names]
4201
+ + [
4202
+ scol_for(sdf, SPARK_INDEX_NAME_FORMAT(i + len(self._groupkeys)))
4203
+ for i in range(self._psdf._internal.index_level)
4204
+ ]
4205
+ ),
4206
+ index_names=(
4207
+ [psser._column_label for psser in self._groupkeys]
4208
+ + self._psdf._internal.index_names
4209
+ ),
4210
+ index_fields=(
4211
+ [
4212
+ psser._internal.data_fields[0].copy(name=name)
4213
+ for psser, name in zip(self._groupkeys, groupkey_col_names)
4214
+ ]
4215
+ + [
4216
+ field.copy(name=SPARK_INDEX_NAME_FORMAT(i + len(self._groupkeys)))
4217
+ for i, field in enumerate(self._psdf._internal.index_fields)
4218
+ ]
4219
+ ),
4220
+ column_labels=[self._psser._column_label],
4221
+ data_spark_columns=[scol_for(sdf, self._psser._internal.data_spark_column_names[0])],
4222
+ data_fields=[self._psser._internal.data_fields[0]],
4223
+ )
4224
+ return first_series(DataFrame(internal))
4225
+
4226
+ # TODO: add bins, normalize parameter
4227
+ def value_counts(
4228
+ self, sort: Optional[bool] = None, ascending: Optional[bool] = None, dropna: bool = True
4229
+ ) -> Series:
4230
+ """
4231
+ Compute counts of unique values within each group.
4232
+
4233
+ Parameters
4234
+ ----------
4235
+ sort : boolean, default None
4236
+ Sort by frequencies.
4237
+ ascending : boolean, default False
4238
+ Sort in ascending order.
4239
+ dropna : boolean, default True
4240
+ Don't include counts of NaN.
4241
+
4242
+ See Also
4243
+ --------
4244
+ pyspark.pandas.Series.groupby
4245
+ pyspark.pandas.DataFrame.groupby
4246
+
4247
+ Examples
4248
+ --------
4249
+ >>> df = ps.DataFrame({'A': [1, 2, 2, 3, 3, 3],
4250
+ ... 'B': [1, 1, 2, 3, 3, np.nan]},
4251
+ ... columns=['A', 'B'])
4252
+ >>> df
4253
+ A B
4254
+ 0 1 1.0
4255
+ 1 2 1.0
4256
+ 2 2 2.0
4257
+ 3 3 3.0
4258
+ 4 3 3.0
4259
+ 5 3 NaN
4260
+
4261
+ >>> df.groupby('A')['B'].value_counts().sort_index() # doctest: +NORMALIZE_WHITESPACE
4262
+ A B
4263
+ 1 1.0 1
4264
+ 2 1.0 1
4265
+ 2.0 1
4266
+ 3 3.0 2
4267
+ Name: B, dtype: int64
4268
+
4269
+ Include counts of NaN when dropna is False.
4270
+
4271
+ >>> df.groupby('A')['B'].value_counts(
4272
+ ... dropna=False).sort_index() # doctest: +NORMALIZE_WHITESPACE
4273
+ A B
4274
+ 1 1.0 1
4275
+ 2 1.0 1
4276
+ 2.0 1
4277
+ 3 3.0 2
4278
+ NaN 1
4279
+ Name: B, dtype: int64
4280
+ """
4281
+ warnings.warn(
4282
+ "The resulting Series will have a fixed name of 'count' from 4.0.0.",
4283
+ FutureWarning,
4284
+ )
4285
+ groupkeys = self._groupkeys + self._agg_columns
4286
+ groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(groupkeys))]
4287
+ groupkey_cols = [s.spark.column.alias(name) for s, name in zip(groupkeys, groupkey_names)]
4288
+
4289
+ sdf = self._psdf._internal.spark_frame
4290
+
4291
+ agg_column = self._agg_columns[0]._internal.data_spark_column_names[0]
4292
+ sdf = sdf.groupby(*groupkey_cols).count().withColumnRenamed("count", agg_column)
4293
+
4294
+ if self._dropna:
4295
+ _groupkey_column_names = groupkey_names[: len(self._groupkeys)]
4296
+ sdf = sdf.dropna(subset=_groupkey_column_names)
4297
+
4298
+ if dropna:
4299
+ _agg_columns_names = groupkey_names[len(self._groupkeys) :]
4300
+ sdf = sdf.dropna(subset=_agg_columns_names)
4301
+
4302
+ if sort:
4303
+ if ascending:
4304
+ sdf = sdf.orderBy(scol_for(sdf, agg_column).asc())
4305
+ else:
4306
+ sdf = sdf.orderBy(scol_for(sdf, agg_column).desc())
4307
+
4308
+ internal = InternalFrame(
4309
+ spark_frame=sdf,
4310
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
4311
+ index_names=[psser._column_label for psser in groupkeys],
4312
+ index_fields=[
4313
+ psser._internal.data_fields[0].copy(name=name)
4314
+ for psser, name in zip(groupkeys, groupkey_names)
4315
+ ],
4316
+ column_labels=[self._agg_columns[0]._column_label],
4317
+ data_spark_columns=[scol_for(sdf, agg_column)],
4318
+ )
4319
+ return first_series(DataFrame(internal))
4320
+
4321
+ def unique(self) -> Series:
4322
+ """
4323
+ Return unique values in group.
4324
+
4325
+ Uniques are returned in an arbitrary order. It does NOT sort.
4326
+
4327
+ See Also
4328
+ --------
4329
+ pyspark.pandas.Series.unique
4330
+ pyspark.pandas.Index.unique
4331
+
4332
+ Examples
4333
+ --------
4334
+ >>> df = ps.DataFrame({'a': [1, 1, 1, 2, 2, 2, 3, 3, 3],
4335
+ ... 'b': [1, 2, 2, 2, 3, 3, 3, 4, 4]}, columns=['a', 'b'])
4336
+
4337
+ >>> df.groupby(['a'])['b'].unique().sort_index() # doctest: +SKIP
4338
+ a
4339
+ 1 [1, 2]
4340
+ 2 [2, 3]
4341
+ 3 [3, 4]
4342
+ Name: b, dtype: object
4343
+ """
4344
+ return self._reduce_for_stat_function(F.collect_set)
4345
+
4346
+
4347
+ def is_multi_agg_with_relabel(**kwargs: Any) -> bool:
4348
+ """
4349
+ Check whether the kwargs passed to .agg look like multi-agg with relabeling.
4350
+
4351
+ Parameters
4352
+ ----------
4353
+ **kwargs : dict
4354
+
4355
+ Returns
4356
+ -------
4357
+ bool
4358
+
4359
+ Examples
4360
+ --------
4361
+ >>> is_multi_agg_with_relabel(a='max')
4362
+ False
4363
+ >>> is_multi_agg_with_relabel(a_max=('a', 'max'),
4364
+ ... a_min=('a', 'min'))
4365
+ True
4366
+ >>> is_multi_agg_with_relabel()
4367
+ False
4368
+ """
4369
+ if not kwargs:
4370
+ return False
4371
+ return all(isinstance(v, tuple) and len(v) == 2 for v in kwargs.values())
4372
+
4373
+
4374
+ def normalize_keyword_aggregation(
4375
+ kwargs: Dict[str, Tuple[Name, str]],
4376
+ ) -> Tuple[Dict[Name, List[str]], List[str], List[Tuple]]:
4377
+ """
4378
+ Normalize user-provided kwargs.
4379
+
4380
+ Transforms from the new ``Dict[str, NamedAgg]`` style kwargs
4381
+ to the old defaultdict[str, List[scalar]].
4382
+
4383
+ Parameters
4384
+ ----------
4385
+ kwargs : dict
4386
+
4387
+ Returns
4388
+ -------
4389
+ aggspec : dict
4390
+ The transformed kwargs.
4391
+ columns : List[str]
4392
+ The user-provided keys.
4393
+ order : List[Tuple[str, str]]
4394
+ Pairs of the input and output column names.
4395
+
4396
+ Examples
4397
+ --------
4398
+ >>> normalize_keyword_aggregation({'output': ('input', 'sum')})
4399
+ (defaultdict(<class 'list'>, {'input': ['sum']}), ['output'], [('input', 'sum')])
4400
+ """
4401
+ aggspec: Dict[Union[Any, Tuple], List[str]] = defaultdict(list)
4402
+ order: List[Tuple] = []
4403
+ columns, pairs = zip(*kwargs.items())
4404
+
4405
+ for column, aggfunc in pairs:
4406
+ if column in aggspec:
4407
+ aggspec[column].append(aggfunc)
4408
+ else:
4409
+ aggspec[column] = [aggfunc]
4410
+
4411
+ order.append((column, aggfunc))
4412
+ # For MultiIndex, we need to flatten the tuple, e.g. (('y', 'A'), 'max') needs to be
4413
+ # flattened to ('y', 'A', 'max'), it won't do anything on normal Index.
4414
+ if isinstance(order[0][0], tuple):
4415
+ order = [(*levs, method) for levs, method in order]
4416
+ return aggspec, list(columns), order
4417
+
4418
+
4419
+ def _test() -> None:
4420
+ import os
4421
+ import doctest
4422
+ import sys
4423
+ import numpy
4424
+ from pyspark.sql import SparkSession
4425
+ import pyspark.pandas.groupby
4426
+
4427
+ os.chdir(os.environ["SPARK_HOME"])
4428
+
4429
+ globs = pyspark.pandas.groupby.__dict__.copy()
4430
+ globs["np"] = numpy
4431
+ globs["ps"] = pyspark.pandas
4432
+ spark = (
4433
+ SparkSession.builder.master("local[4]")
4434
+ .appName("pyspark.pandas.groupby tests")
4435
+ .getOrCreate()
4436
+ )
4437
+ (failure_count, test_count) = doctest.testmod(
4438
+ pyspark.pandas.groupby,
4439
+ globs=globs,
4440
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
4441
+ )
4442
+ spark.stop()
4443
+ if failure_count:
4444
+ sys.exit(-1)
4445
+
4446
+
4447
+ if __name__ == "__main__":
4448
+ _test()