snowpark-connect 0.20.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of snowpark-connect might be problematic. Click here for more details.

Files changed (879)
  1. snowflake/snowpark_connect/__init__.py +23 -0
  2. snowflake/snowpark_connect/analyze_plan/__init__.py +3 -0
  3. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +38 -0
  4. snowflake/snowpark_connect/column_name_handler.py +735 -0
  5. snowflake/snowpark_connect/config.py +576 -0
  6. snowflake/snowpark_connect/constants.py +47 -0
  7. snowflake/snowpark_connect/control_server.py +52 -0
  8. snowflake/snowpark_connect/dataframe_name_handler.py +54 -0
  9. snowflake/snowpark_connect/date_time_format_mapping.py +399 -0
  10. snowflake/snowpark_connect/empty_dataframe.py +18 -0
  11. snowflake/snowpark_connect/error/__init__.py +11 -0
  12. snowflake/snowpark_connect/error/error_mapping.py +6174 -0
  13. snowflake/snowpark_connect/error/error_utils.py +321 -0
  14. snowflake/snowpark_connect/error/exceptions.py +24 -0
  15. snowflake/snowpark_connect/execute_plan/__init__.py +3 -0
  16. snowflake/snowpark_connect/execute_plan/map_execution_command.py +204 -0
  17. snowflake/snowpark_connect/execute_plan/map_execution_root.py +173 -0
  18. snowflake/snowpark_connect/execute_plan/utils.py +183 -0
  19. snowflake/snowpark_connect/expression/__init__.py +3 -0
  20. snowflake/snowpark_connect/expression/literal.py +90 -0
  21. snowflake/snowpark_connect/expression/map_cast.py +343 -0
  22. snowflake/snowpark_connect/expression/map_expression.py +293 -0
  23. snowflake/snowpark_connect/expression/map_extension.py +104 -0
  24. snowflake/snowpark_connect/expression/map_sql_expression.py +633 -0
  25. snowflake/snowpark_connect/expression/map_udf.py +142 -0
  26. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +241 -0
  27. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +85 -0
  28. snowflake/snowpark_connect/expression/map_unresolved_function.py +9450 -0
  29. snowflake/snowpark_connect/expression/map_unresolved_star.py +218 -0
  30. snowflake/snowpark_connect/expression/map_update_fields.py +164 -0
  31. snowflake/snowpark_connect/expression/map_window_function.py +258 -0
  32. snowflake/snowpark_connect/expression/typer.py +125 -0
  33. snowflake/snowpark_connect/includes/__init__.py +0 -0
  34. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  35. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  36. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  37. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  39. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/hadoop-client-api-3.3.4.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/scala-compiler-2.12.18.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-kubernetes_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-mllib_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-streaming_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/python/__init__.py +21 -0
  99. snowflake/snowpark_connect/includes/python/pyspark/__init__.py +173 -0
  100. snowflake/snowpark_connect/includes/python/pyspark/_globals.py +71 -0
  101. snowflake/snowpark_connect/includes/python/pyspark/_typing.pyi +43 -0
  102. snowflake/snowpark_connect/includes/python/pyspark/accumulators.py +341 -0
  103. snowflake/snowpark_connect/includes/python/pyspark/broadcast.py +383 -0
  104. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/__init__.py +8 -0
  105. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle.py +948 -0
  106. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle_fast.py +844 -0
  107. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/compat.py +18 -0
  108. snowflake/snowpark_connect/includes/python/pyspark/conf.py +276 -0
  109. snowflake/snowpark_connect/includes/python/pyspark/context.py +2601 -0
  110. snowflake/snowpark_connect/includes/python/pyspark/daemon.py +218 -0
  111. snowflake/snowpark_connect/includes/python/pyspark/errors/__init__.py +70 -0
  112. snowflake/snowpark_connect/includes/python/pyspark/errors/error_classes.py +889 -0
  113. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/__init__.py +16 -0
  114. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/base.py +228 -0
  115. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/captured.py +307 -0
  116. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/connect.py +190 -0
  117. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/__init__.py +16 -0
  118. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/test_errors.py +60 -0
  119. snowflake/snowpark_connect/includes/python/pyspark/errors/utils.py +116 -0
  120. snowflake/snowpark_connect/includes/python/pyspark/files.py +165 -0
  121. snowflake/snowpark_connect/includes/python/pyspark/find_spark_home.py +95 -0
  122. snowflake/snowpark_connect/includes/python/pyspark/install.py +203 -0
  123. snowflake/snowpark_connect/includes/python/pyspark/instrumentation_utils.py +190 -0
  124. snowflake/snowpark_connect/includes/python/pyspark/java_gateway.py +248 -0
  125. snowflake/snowpark_connect/includes/python/pyspark/join.py +118 -0
  126. snowflake/snowpark_connect/includes/python/pyspark/ml/__init__.py +71 -0
  127. snowflake/snowpark_connect/includes/python/pyspark/ml/_typing.pyi +84 -0
  128. snowflake/snowpark_connect/includes/python/pyspark/ml/base.py +414 -0
  129. snowflake/snowpark_connect/includes/python/pyspark/ml/classification.py +4332 -0
  130. snowflake/snowpark_connect/includes/python/pyspark/ml/clustering.py +2188 -0
  131. snowflake/snowpark_connect/includes/python/pyspark/ml/common.py +146 -0
  132. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/__init__.py +44 -0
  133. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/base.py +346 -0
  134. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/classification.py +382 -0
  135. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/evaluation.py +291 -0
  136. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/feature.py +258 -0
  137. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/functions.py +77 -0
  138. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/io_utils.py +335 -0
  139. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/pipeline.py +262 -0
  140. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/summarizer.py +120 -0
  141. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/tuning.py +579 -0
  142. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/util.py +173 -0
  143. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/__init__.py +16 -0
  144. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/deepspeed_distributor.py +165 -0
  145. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py +306 -0
  146. snowflake/snowpark_connect/includes/python/pyspark/ml/dl_util.py +150 -0
  147. snowflake/snowpark_connect/includes/python/pyspark/ml/evaluation.py +1166 -0
  148. snowflake/snowpark_connect/includes/python/pyspark/ml/feature.py +7474 -0
  149. snowflake/snowpark_connect/includes/python/pyspark/ml/fpm.py +543 -0
  150. snowflake/snowpark_connect/includes/python/pyspark/ml/functions.py +842 -0
  151. snowflake/snowpark_connect/includes/python/pyspark/ml/image.py +271 -0
  152. snowflake/snowpark_connect/includes/python/pyspark/ml/linalg/__init__.py +1382 -0
  153. snowflake/snowpark_connect/includes/python/pyspark/ml/model_cache.py +55 -0
  154. snowflake/snowpark_connect/includes/python/pyspark/ml/param/__init__.py +602 -0
  155. snowflake/snowpark_connect/includes/python/pyspark/ml/param/_shared_params_code_gen.py +368 -0
  156. snowflake/snowpark_connect/includes/python/pyspark/ml/param/shared.py +878 -0
  157. snowflake/snowpark_connect/includes/python/pyspark/ml/pipeline.py +451 -0
  158. snowflake/snowpark_connect/includes/python/pyspark/ml/recommendation.py +748 -0
  159. snowflake/snowpark_connect/includes/python/pyspark/ml/regression.py +3335 -0
  160. snowflake/snowpark_connect/includes/python/pyspark/ml/stat.py +523 -0
  161. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/__init__.py +16 -0
  162. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_classification.py +53 -0
  163. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_evaluation.py +50 -0
  164. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_feature.py +43 -0
  165. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_function.py +114 -0
  166. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_pipeline.py +47 -0
  167. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_summarizer.py +43 -0
  168. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_tuning.py +46 -0
  169. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py +238 -0
  170. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py +194 -0
  171. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +156 -0
  172. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py +184 -0
  173. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py +78 -0
  174. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +292 -0
  175. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +50 -0
  176. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +152 -0
  177. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_algorithms.py +456 -0
  178. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_base.py +96 -0
  179. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_dl_util.py +186 -0
  180. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_evaluation.py +77 -0
  181. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_feature.py +401 -0
  182. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_functions.py +528 -0
  183. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_image.py +82 -0
  184. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_linalg.py +409 -0
  185. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_model_cache.py +55 -0
  186. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_param.py +441 -0
  187. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_persistence.py +546 -0
  188. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_pipeline.py +71 -0
  189. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_stat.py +52 -0
  190. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_training_summary.py +494 -0
  191. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_util.py +85 -0
  192. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_wrapper.py +138 -0
  193. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/__init__.py +16 -0
  194. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_basic.py +151 -0
  195. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_nested.py +97 -0
  196. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py +143 -0
  197. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tuning.py +551 -0
  198. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_basic.py +137 -0
  199. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_nested.py +96 -0
  200. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_pipeline.py +142 -0
  201. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/__init__.py +16 -0
  202. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/data.py +100 -0
  203. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/distributor.py +1133 -0
  204. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/log_communication.py +198 -0
  205. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/__init__.py +16 -0
  206. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_data_loader.py +137 -0
  207. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_distributor.py +561 -0
  208. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_log_communication.py +172 -0
  209. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/torch_run_process_wrapper.py +83 -0
  210. snowflake/snowpark_connect/includes/python/pyspark/ml/tree.py +434 -0
  211. snowflake/snowpark_connect/includes/python/pyspark/ml/tuning.py +1741 -0
  212. snowflake/snowpark_connect/includes/python/pyspark/ml/util.py +749 -0
  213. snowflake/snowpark_connect/includes/python/pyspark/ml/wrapper.py +465 -0
  214. snowflake/snowpark_connect/includes/python/pyspark/mllib/__init__.py +44 -0
  215. snowflake/snowpark_connect/includes/python/pyspark/mllib/_typing.pyi +33 -0
  216. snowflake/snowpark_connect/includes/python/pyspark/mllib/classification.py +989 -0
  217. snowflake/snowpark_connect/includes/python/pyspark/mllib/clustering.py +1318 -0
  218. snowflake/snowpark_connect/includes/python/pyspark/mllib/common.py +174 -0
  219. snowflake/snowpark_connect/includes/python/pyspark/mllib/evaluation.py +691 -0
  220. snowflake/snowpark_connect/includes/python/pyspark/mllib/feature.py +1085 -0
  221. snowflake/snowpark_connect/includes/python/pyspark/mllib/fpm.py +233 -0
  222. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/__init__.py +1653 -0
  223. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/distributed.py +1662 -0
  224. snowflake/snowpark_connect/includes/python/pyspark/mllib/random.py +698 -0
  225. snowflake/snowpark_connect/includes/python/pyspark/mllib/recommendation.py +389 -0
  226. snowflake/snowpark_connect/includes/python/pyspark/mllib/regression.py +1067 -0
  227. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/KernelDensity.py +59 -0
  228. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/__init__.py +34 -0
  229. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/_statistics.py +409 -0
  230. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/distribution.py +39 -0
  231. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/test.py +86 -0
  232. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/__init__.py +16 -0
  233. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_algorithms.py +353 -0
  234. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_feature.py +192 -0
  235. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_linalg.py +680 -0
  236. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_stat.py +206 -0
  237. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_streaming_algorithms.py +471 -0
  238. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_util.py +108 -0
  239. snowflake/snowpark_connect/includes/python/pyspark/mllib/tree.py +888 -0
  240. snowflake/snowpark_connect/includes/python/pyspark/mllib/util.py +659 -0
  241. snowflake/snowpark_connect/includes/python/pyspark/pandas/__init__.py +165 -0
  242. snowflake/snowpark_connect/includes/python/pyspark/pandas/_typing.py +52 -0
  243. snowflake/snowpark_connect/includes/python/pyspark/pandas/accessors.py +989 -0
  244. snowflake/snowpark_connect/includes/python/pyspark/pandas/base.py +1804 -0
  245. snowflake/snowpark_connect/includes/python/pyspark/pandas/categorical.py +822 -0
  246. snowflake/snowpark_connect/includes/python/pyspark/pandas/config.py +539 -0
  247. snowflake/snowpark_connect/includes/python/pyspark/pandas/correlation.py +262 -0
  248. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/__init__.py +16 -0
  249. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/base.py +519 -0
  250. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/binary_ops.py +98 -0
  251. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/boolean_ops.py +426 -0
  252. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/categorical_ops.py +141 -0
  253. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/complex_ops.py +145 -0
  254. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/date_ops.py +127 -0
  255. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/datetime_ops.py +171 -0
  256. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/null_ops.py +83 -0
  257. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/num_ops.py +588 -0
  258. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/string_ops.py +154 -0
  259. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/timedelta_ops.py +101 -0
  260. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/udt_ops.py +29 -0
  261. snowflake/snowpark_connect/includes/python/pyspark/pandas/datetimes.py +891 -0
  262. snowflake/snowpark_connect/includes/python/pyspark/pandas/exceptions.py +150 -0
  263. snowflake/snowpark_connect/includes/python/pyspark/pandas/extensions.py +388 -0
  264. snowflake/snowpark_connect/includes/python/pyspark/pandas/frame.py +13738 -0
  265. snowflake/snowpark_connect/includes/python/pyspark/pandas/generic.py +3560 -0
  266. snowflake/snowpark_connect/includes/python/pyspark/pandas/groupby.py +4448 -0
  267. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/__init__.py +21 -0
  268. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/base.py +2783 -0
  269. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/category.py +773 -0
  270. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/datetimes.py +843 -0
  271. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/multi.py +1323 -0
  272. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/numeric.py +210 -0
  273. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/timedelta.py +197 -0
  274. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexing.py +1862 -0
  275. snowflake/snowpark_connect/includes/python/pyspark/pandas/internal.py +1680 -0
  276. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/__init__.py +48 -0
  277. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/common.py +76 -0
  278. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/frame.py +63 -0
  279. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/general_functions.py +43 -0
  280. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/groupby.py +93 -0
  281. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/indexes.py +184 -0
  282. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/resample.py +101 -0
  283. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/scalars.py +29 -0
  284. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/series.py +69 -0
  285. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/window.py +168 -0
  286. snowflake/snowpark_connect/includes/python/pyspark/pandas/mlflow.py +238 -0
  287. snowflake/snowpark_connect/includes/python/pyspark/pandas/namespace.py +3807 -0
  288. snowflake/snowpark_connect/includes/python/pyspark/pandas/numpy_compat.py +260 -0
  289. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/__init__.py +17 -0
  290. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/core.py +1213 -0
  291. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/matplotlib.py +928 -0
  292. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/plotly.py +261 -0
  293. snowflake/snowpark_connect/includes/python/pyspark/pandas/resample.py +816 -0
  294. snowflake/snowpark_connect/includes/python/pyspark/pandas/series.py +7440 -0
  295. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_formatter.py +308 -0
  296. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_processor.py +394 -0
  297. snowflake/snowpark_connect/includes/python/pyspark/pandas/strings.py +2371 -0
  298. snowflake/snowpark_connect/includes/python/pyspark/pandas/supported_api_gen.py +378 -0
  299. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/__init__.py +16 -0
  300. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/__init__.py +16 -0
  301. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_any_all.py +177 -0
  302. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_apply_func.py +575 -0
  303. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_binary_ops.py +235 -0
  304. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_combine.py +653 -0
  305. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_compute.py +463 -0
  306. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_corrwith.py +86 -0
  307. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cov.py +151 -0
  308. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cumulative.py +139 -0
  309. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_describe.py +458 -0
  310. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_eval.py +86 -0
  311. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_melt.py +202 -0
  312. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_missing_data.py +520 -0
  313. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_pivot.py +361 -0
  314. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/__init__.py +16 -0
  315. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/__init__.py +16 -0
  316. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_any_all.py +40 -0
  317. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_apply_func.py +42 -0
  318. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_binary_ops.py +40 -0
  319. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_combine.py +37 -0
  320. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py +60 -0
  321. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_corrwith.py +40 -0
  322. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cov.py +40 -0
  323. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cumulative.py +90 -0
  324. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_describe.py +40 -0
  325. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_eval.py +40 -0
  326. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_melt.py +40 -0
  327. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_missing_data.py +42 -0
  328. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +37 -0
  329. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/__init__.py +16 -0
  330. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_base.py +36 -0
  331. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py +42 -0
  332. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py +47 -0
  333. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py +55 -0
  334. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_complex_ops.py +40 -0
  335. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py +47 -0
  336. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py +47 -0
  337. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py +42 -0
  338. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_arithmetic.py +43 -0
  339. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py +47 -0
  340. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_reverse.py +43 -0
  341. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py +47 -0
  342. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py +47 -0
  343. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py +40 -0
  344. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/testing_utils.py +226 -0
  345. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/__init__.py +16 -0
  346. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_align.py +39 -0
  347. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_basic_slow.py +55 -0
  348. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_cov_corrwith.py +39 -0
  349. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_frame.py +39 -0
  350. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_series.py +39 -0
  351. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_index.py +39 -0
  352. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_series.py +39 -0
  353. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_frame.py +43 -0
  354. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_series.py +43 -0
  355. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/__init__.py +16 -0
  356. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_attrs.py +40 -0
  357. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py +39 -0
  358. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_conversion.py +42 -0
  359. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reindexing.py +42 -0
  360. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +37 -0
  361. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_spark.py +40 -0
  362. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_take.py +42 -0
  363. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_time_series.py +48 -0
  364. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_truncate.py +40 -0
  365. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/__init__.py +16 -0
  366. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_aggregate.py +40 -0
  367. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_apply_func.py +41 -0
  368. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_cumulative.py +67 -0
  369. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_describe.py +40 -0
  370. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_groupby.py +55 -0
  371. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_head_tail.py +40 -0
  372. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_index.py +38 -0
  373. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py +55 -0
  374. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_split_apply.py +39 -0
  375. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +38 -0
  376. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/__init__.py +16 -0
  377. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_align.py +40 -0
  378. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py +50 -0
  379. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py +73 -0
  380. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_datetime.py +39 -0
  381. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing.py +40 -0
  382. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reindex.py +40 -0
  383. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_rename.py +40 -0
  384. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reset_index.py +48 -0
  385. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_timedelta.py +39 -0
  386. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/__init__.py +16 -0
  387. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/test_parity_io.py +40 -0
  388. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/__init__.py +16 -0
  389. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot.py +45 -0
  390. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_matplotlib.py +45 -0
  391. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_plotly.py +49 -0
  392. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot.py +37 -0
  393. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_matplotlib.py +53 -0
  394. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_plotly.py +45 -0
  395. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/__init__.py +16 -0
  396. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_all_any.py +38 -0
  397. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_arg_ops.py +37 -0
  398. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_of.py +37 -0
  399. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_type.py +38 -0
  400. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +37 -0
  401. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_conversion.py +40 -0
  402. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_cumulative.py +40 -0
  403. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_index.py +38 -0
  404. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_missing_data.py +40 -0
  405. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_series.py +37 -0
  406. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_sort.py +38 -0
  407. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_stat.py +38 -0
  408. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_categorical.py +66 -0
  409. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_config.py +37 -0
  410. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_csv.py +37 -0
  411. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_conversion.py +42 -0
  412. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_spark_io.py +39 -0
  413. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_default_index.py +49 -0
  414. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ewm.py +37 -0
  415. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_expanding.py +39 -0
  416. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_extension.py +49 -0
  417. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_frame_spark.py +53 -0
  418. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py +43 -0
  419. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexing.py +49 -0
  420. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexops_spark.py +39 -0
  421. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_internal.py +41 -0
  422. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_namespace.py +39 -0
  423. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_numpy_compat.py +60 -0
  424. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames.py +48 -0
  425. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby.py +39 -0
  426. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_expanding.py +44 -0
  427. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_rolling.py +84 -0
  428. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_repr.py +37 -0
  429. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_resample.py +45 -0
  430. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_reshape.py +39 -0
  431. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_rolling.py +39 -0
  432. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_scalars.py +37 -0
  433. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_conversion.py +39 -0
  434. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_datetime.py +39 -0
  435. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_string.py +39 -0
  436. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_spark_functions.py +39 -0
  437. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_sql.py +43 -0
  438. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_stats.py +37 -0
  439. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_typedef.py +36 -0
  440. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_utils.py +37 -0
  441. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_window.py +39 -0
  442. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/__init__.py +16 -0
  443. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_base.py +107 -0
  444. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +224 -0
  445. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +825 -0
  446. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +562 -0
  447. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +368 -0
  448. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +257 -0
  449. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +260 -0
  450. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +178 -0
  451. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_arithmetic.py +184 -0
  452. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +497 -0
  453. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_reverse.py +140 -0
  454. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +354 -0
  455. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +219 -0
  456. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +192 -0
  457. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +228 -0
  458. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/__init__.py +16 -0
  459. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_align.py +118 -0
  460. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_basic_slow.py +198 -0
  461. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_cov_corrwith.py +181 -0
  462. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_frame.py +103 -0
  463. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_series.py +141 -0
  464. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_index.py +109 -0
  465. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_series.py +136 -0
  466. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_frame.py +125 -0
  467. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_series.py +217 -0
  468. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/__init__.py +16 -0
  469. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_attrs.py +384 -0
  470. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_constructor.py +598 -0
  471. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_conversion.py +73 -0
  472. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reindexing.py +869 -0
  473. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reshaping.py +487 -0
  474. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_spark.py +309 -0
  475. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_take.py +156 -0
  476. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_time_series.py +149 -0
  477. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_truncate.py +163 -0
  478. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/__init__.py +16 -0
  479. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_aggregate.py +311 -0
  480. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_apply_func.py +524 -0
  481. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_cumulative.py +419 -0
  482. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_describe.py +144 -0
  483. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_groupby.py +979 -0
  484. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_head_tail.py +234 -0
  485. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_index.py +206 -0
  486. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_missing_data.py +421 -0
  487. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_split_apply.py +187 -0
  488. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_stat.py +397 -0
  489. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/__init__.py +16 -0
  490. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_align.py +100 -0
  491. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_base.py +2743 -0
  492. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_category.py +484 -0
  493. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_datetime.py +276 -0
  494. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_indexing.py +432 -0
  495. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reindex.py +310 -0
  496. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_rename.py +257 -0
  497. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reset_index.py +160 -0
  498. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_timedelta.py +128 -0
  499. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/__init__.py +16 -0
  500. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/test_io.py +137 -0
  501. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/__init__.py +16 -0
  502. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot.py +170 -0
  503. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +547 -0
  504. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +285 -0
  505. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot.py +106 -0
  506. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +409 -0
  507. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +247 -0
  508. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/__init__.py +16 -0
  509. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_all_any.py +105 -0
  510. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_arg_ops.py +197 -0
  511. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_of.py +137 -0
  512. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_type.py +227 -0
  513. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_compute.py +634 -0
  514. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_conversion.py +88 -0
  515. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_cumulative.py +139 -0
  516. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_index.py +475 -0
  517. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_missing_data.py +265 -0
  518. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_series.py +818 -0
  519. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_sort.py +162 -0
  520. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_stat.py +780 -0
  521. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_categorical.py +741 -0
  522. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_config.py +160 -0
  523. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_csv.py +453 -0
  524. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_conversion.py +281 -0
  525. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_spark_io.py +487 -0
  526. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_default_index.py +109 -0
  527. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ewm.py +434 -0
  528. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_expanding.py +253 -0
  529. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_extension.py +152 -0
  530. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_frame_spark.py +162 -0
  531. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_generic_functions.py +234 -0
  532. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexing.py +1339 -0
  533. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexops_spark.py +82 -0
  534. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_internal.py +124 -0
  535. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_namespace.py +638 -0
  536. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_numpy_compat.py +200 -0
  537. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +1355 -0
  538. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +655 -0
  539. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +113 -0
  540. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +118 -0
  541. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_repr.py +192 -0
  542. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_resample.py +346 -0
  543. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_reshape.py +495 -0
  544. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_rolling.py +263 -0
  545. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_scalars.py +59 -0
  546. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_conversion.py +85 -0
  547. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_datetime.py +364 -0
  548. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_string.py +362 -0
  549. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_spark_functions.py +46 -0
  550. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_sql.py +123 -0
  551. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_stats.py +581 -0
  552. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_typedef.py +447 -0
  553. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_utils.py +301 -0
  554. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_window.py +465 -0
  555. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/__init__.py +18 -0
  556. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/typehints.py +874 -0
  557. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/__init__.py +143 -0
  558. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/usage_logger.py +132 -0
  559. snowflake/snowpark_connect/includes/python/pyspark/pandas/utils.py +1063 -0
  560. snowflake/snowpark_connect/includes/python/pyspark/pandas/window.py +2702 -0
  561. snowflake/snowpark_connect/includes/python/pyspark/profiler.py +489 -0
  562. snowflake/snowpark_connect/includes/python/pyspark/py.typed +1 -0
  563. snowflake/snowpark_connect/includes/python/pyspark/python/pyspark/shell.py +123 -0
  564. snowflake/snowpark_connect/includes/python/pyspark/rdd.py +5518 -0
  565. snowflake/snowpark_connect/includes/python/pyspark/rddsampler.py +115 -0
  566. snowflake/snowpark_connect/includes/python/pyspark/resource/__init__.py +38 -0
  567. snowflake/snowpark_connect/includes/python/pyspark/resource/information.py +69 -0
  568. snowflake/snowpark_connect/includes/python/pyspark/resource/profile.py +317 -0
  569. snowflake/snowpark_connect/includes/python/pyspark/resource/requests.py +539 -0
  570. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/__init__.py +16 -0
  571. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/test_resources.py +83 -0
  572. snowflake/snowpark_connect/includes/python/pyspark/resultiterable.py +45 -0
  573. snowflake/snowpark_connect/includes/python/pyspark/serializers.py +681 -0
  574. snowflake/snowpark_connect/includes/python/pyspark/shell.py +123 -0
  575. snowflake/snowpark_connect/includes/python/pyspark/shuffle.py +854 -0
  576. snowflake/snowpark_connect/includes/python/pyspark/sql/__init__.py +75 -0
  577. snowflake/snowpark_connect/includes/python/pyspark/sql/_typing.pyi +80 -0
  578. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/__init__.py +18 -0
  579. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/functions.py +188 -0
  580. snowflake/snowpark_connect/includes/python/pyspark/sql/catalog.py +1270 -0
  581. snowflake/snowpark_connect/includes/python/pyspark/sql/column.py +1431 -0
  582. snowflake/snowpark_connect/includes/python/pyspark/sql/conf.py +99 -0
  583. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/__init__.py +18 -0
  584. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/_typing.py +90 -0
  585. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/__init__.py +18 -0
  586. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/functions.py +107 -0
  587. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/catalog.py +356 -0
  588. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/__init__.py +22 -0
  589. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/artifact.py +412 -0
  590. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/core.py +1689 -0
  591. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/reattach.py +340 -0
  592. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/column.py +514 -0
  593. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conf.py +128 -0
  594. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conversion.py +490 -0
  595. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/dataframe.py +2172 -0
  596. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/expressions.py +1056 -0
  597. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/functions.py +3937 -0
  598. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/group.py +418 -0
  599. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/plan.py +2289 -0
  600. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/__init__.py +25 -0
  601. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.py +203 -0
  602. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.pyi +2718 -0
  603. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2_grpc.py +423 -0
  604. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.py +109 -0
  605. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.pyi +1130 -0
  606. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.py +141 -0
  607. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.pyi +1766 -0
  608. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.py +47 -0
  609. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.pyi +123 -0
  610. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.py +53 -0
  611. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.pyi +112 -0
  612. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.py +107 -0
  613. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.pyi +1507 -0
  614. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.py +195 -0
  615. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.pyi +3613 -0
  616. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.py +95 -0
  617. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.pyi +980 -0
  618. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/__init__.py +18 -0
  619. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/functions.py +166 -0
  620. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/readwriter.py +861 -0
  621. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/session.py +952 -0
  622. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/__init__.py +22 -0
  623. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/query.py +295 -0
  624. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/readwriter.py +618 -0
  625. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/__init__.py +18 -0
  626. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +87 -0
  627. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +100 -0
  628. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/types.py +301 -0
  629. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udf.py +296 -0
  630. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udtf.py +200 -0
  631. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/utils.py +58 -0
  632. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/window.py +266 -0
  633. snowflake/snowpark_connect/includes/python/pyspark/sql/context.py +818 -0
  634. snowflake/snowpark_connect/includes/python/pyspark/sql/dataframe.py +5973 -0
  635. snowflake/snowpark_connect/includes/python/pyspark/sql/functions.py +15889 -0
  636. snowflake/snowpark_connect/includes/python/pyspark/sql/group.py +547 -0
  637. snowflake/snowpark_connect/includes/python/pyspark/sql/observation.py +152 -0
  638. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/__init__.py +21 -0
  639. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/__init__.pyi +344 -0
  640. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/__init__.pyi +17 -0
  641. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/frame.pyi +20 -0
  642. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/series.pyi +20 -0
  643. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/conversion.py +671 -0
  644. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.py +480 -0
  645. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.pyi +132 -0
  646. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/group_ops.py +523 -0
  647. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/map_ops.py +216 -0
  648. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/serializers.py +1019 -0
  649. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/typehints.py +172 -0
  650. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/types.py +972 -0
  651. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/utils.py +86 -0
  652. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/__init__.py +18 -0
  653. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/functions.py +334 -0
  654. snowflake/snowpark_connect/includes/python/pyspark/sql/readwriter.py +2159 -0
  655. snowflake/snowpark_connect/includes/python/pyspark/sql/session.py +2088 -0
  656. snowflake/snowpark_connect/includes/python/pyspark/sql/sql_formatter.py +84 -0
  657. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/__init__.py +21 -0
  658. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/listener.py +1050 -0
  659. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/query.py +746 -0
  660. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/readwriter.py +1652 -0
  661. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/state.py +288 -0
  662. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/__init__.py +16 -0
  663. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/__init__.py +16 -0
  664. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/__init__.py +16 -0
  665. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_artifact.py +420 -0
  666. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_client.py +358 -0
  667. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/__init__.py +16 -0
  668. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach.py +36 -0
  669. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +44 -0
  670. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +116 -0
  671. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +35 -0
  672. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py +3612 -0
  673. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_column.py +1042 -0
  674. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_function.py +2381 -0
  675. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_plan.py +1060 -0
  676. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow.py +163 -0
  677. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_map.py +38 -0
  678. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +48 -0
  679. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_catalog.py +36 -0
  680. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_column.py +55 -0
  681. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_conf.py +36 -0
  682. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_dataframe.py +96 -0
  683. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_datasources.py +44 -0
  684. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_errors.py +36 -0
  685. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_functions.py +59 -0
  686. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_group.py +36 -0
  687. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py +59 -0
  688. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py +74 -0
  689. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py +62 -0
  690. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_map.py +58 -0
  691. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +70 -0
  692. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py +50 -0
  693. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +68 -0
  694. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py +40 -0
  695. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_readwriter.py +46 -0
  696. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_serde.py +44 -0
  697. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_types.py +100 -0
  698. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udf.py +100 -0
  699. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udtf.py +163 -0
  700. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_session.py +181 -0
  701. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_utils.py +42 -0
  702. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/__init__.py +16 -0
  703. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +623 -0
  704. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +869 -0
  705. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +342 -0
  706. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_map.py +436 -0
  707. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf.py +363 -0
  708. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +592 -0
  709. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +1503 -0
  710. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +392 -0
  711. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py +375 -0
  712. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +411 -0
  713. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/__init__.py +16 -0
  714. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming.py +401 -0
  715. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +295 -0
  716. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +106 -0
  717. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_listener.py +558 -0
  718. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow.py +1346 -0
  719. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_map.py +182 -0
  720. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_python_udf.py +202 -0
  721. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_catalog.py +503 -0
  722. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_column.py +225 -0
  723. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_conf.py +83 -0
  724. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_context.py +201 -0
  725. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_dataframe.py +1931 -0
  726. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_datasources.py +256 -0
  727. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_errors.py +69 -0
  728. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_functions.py +1349 -0
  729. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_group.py +53 -0
  730. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_pandas_sqlmetrics.py +68 -0
  731. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_readwriter.py +283 -0
  732. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_serde.py +155 -0
  733. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_session.py +412 -0
  734. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_types.py +1581 -0
  735. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf.py +961 -0
  736. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf_profiler.py +165 -0
  737. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udtf.py +1456 -0
  738. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_utils.py +1686 -0
  739. snowflake/snowpark_connect/includes/python/pyspark/sql/types.py +2558 -0
  740. snowflake/snowpark_connect/includes/python/pyspark/sql/udf.py +714 -0
  741. snowflake/snowpark_connect/includes/python/pyspark/sql/udtf.py +325 -0
  742. snowflake/snowpark_connect/includes/python/pyspark/sql/utils.py +339 -0
  743. snowflake/snowpark_connect/includes/python/pyspark/sql/window.py +492 -0
  744. snowflake/snowpark_connect/includes/python/pyspark/statcounter.py +165 -0
  745. snowflake/snowpark_connect/includes/python/pyspark/status.py +112 -0
  746. snowflake/snowpark_connect/includes/python/pyspark/storagelevel.py +97 -0
  747. snowflake/snowpark_connect/includes/python/pyspark/streaming/__init__.py +22 -0
  748. snowflake/snowpark_connect/includes/python/pyspark/streaming/context.py +471 -0
  749. snowflake/snowpark_connect/includes/python/pyspark/streaming/dstream.py +933 -0
  750. snowflake/snowpark_connect/includes/python/pyspark/streaming/kinesis.py +205 -0
  751. snowflake/snowpark_connect/includes/python/pyspark/streaming/listener.py +83 -0
  752. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/__init__.py +16 -0
  753. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_context.py +184 -0
  754. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_dstream.py +706 -0
  755. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_kinesis.py +118 -0
  756. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_listener.py +160 -0
  757. snowflake/snowpark_connect/includes/python/pyspark/streaming/util.py +168 -0
  758. snowflake/snowpark_connect/includes/python/pyspark/taskcontext.py +502 -0
  759. snowflake/snowpark_connect/includes/python/pyspark/testing/__init__.py +21 -0
  760. snowflake/snowpark_connect/includes/python/pyspark/testing/connectutils.py +199 -0
  761. snowflake/snowpark_connect/includes/python/pyspark/testing/mllibutils.py +30 -0
  762. snowflake/snowpark_connect/includes/python/pyspark/testing/mlutils.py +275 -0
  763. snowflake/snowpark_connect/includes/python/pyspark/testing/objects.py +121 -0
  764. snowflake/snowpark_connect/includes/python/pyspark/testing/pandasutils.py +714 -0
  765. snowflake/snowpark_connect/includes/python/pyspark/testing/sqlutils.py +168 -0
  766. snowflake/snowpark_connect/includes/python/pyspark/testing/streamingutils.py +178 -0
  767. snowflake/snowpark_connect/includes/python/pyspark/testing/utils.py +636 -0
  768. snowflake/snowpark_connect/includes/python/pyspark/tests/__init__.py +16 -0
  769. snowflake/snowpark_connect/includes/python/pyspark/tests/test_appsubmit.py +306 -0
  770. snowflake/snowpark_connect/includes/python/pyspark/tests/test_broadcast.py +196 -0
  771. snowflake/snowpark_connect/includes/python/pyspark/tests/test_conf.py +44 -0
  772. snowflake/snowpark_connect/includes/python/pyspark/tests/test_context.py +346 -0
  773. snowflake/snowpark_connect/includes/python/pyspark/tests/test_daemon.py +89 -0
  774. snowflake/snowpark_connect/includes/python/pyspark/tests/test_install_spark.py +124 -0
  775. snowflake/snowpark_connect/includes/python/pyspark/tests/test_join.py +69 -0
  776. snowflake/snowpark_connect/includes/python/pyspark/tests/test_memory_profiler.py +167 -0
  777. snowflake/snowpark_connect/includes/python/pyspark/tests/test_pin_thread.py +194 -0
  778. snowflake/snowpark_connect/includes/python/pyspark/tests/test_profiler.py +168 -0
  779. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rdd.py +939 -0
  780. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddbarrier.py +52 -0
  781. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddsampler.py +66 -0
  782. snowflake/snowpark_connect/includes/python/pyspark/tests/test_readwrite.py +368 -0
  783. snowflake/snowpark_connect/includes/python/pyspark/tests/test_serializers.py +257 -0
  784. snowflake/snowpark_connect/includes/python/pyspark/tests/test_shuffle.py +267 -0
  785. snowflake/snowpark_connect/includes/python/pyspark/tests/test_stage_sched.py +153 -0
  786. snowflake/snowpark_connect/includes/python/pyspark/tests/test_statcounter.py +130 -0
  787. snowflake/snowpark_connect/includes/python/pyspark/tests/test_taskcontext.py +350 -0
  788. snowflake/snowpark_connect/includes/python/pyspark/tests/test_util.py +97 -0
  789. snowflake/snowpark_connect/includes/python/pyspark/tests/test_worker.py +271 -0
  790. snowflake/snowpark_connect/includes/python/pyspark/traceback_utils.py +81 -0
  791. snowflake/snowpark_connect/includes/python/pyspark/util.py +416 -0
  792. snowflake/snowpark_connect/includes/python/pyspark/version.py +19 -0
  793. snowflake/snowpark_connect/includes/python/pyspark/worker.py +1307 -0
  794. snowflake/snowpark_connect/includes/python/pyspark/worker_util.py +46 -0
  795. snowflake/snowpark_connect/proto/__init__.py +10 -0
  796. snowflake/snowpark_connect/proto/control_pb2.py +35 -0
  797. snowflake/snowpark_connect/proto/control_pb2.pyi +38 -0
  798. snowflake/snowpark_connect/proto/control_pb2_grpc.py +183 -0
  799. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +35 -0
  800. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +53 -0
  801. snowflake/snowpark_connect/proto/snowflake_rdd_pb2.pyi +39 -0
  802. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +47 -0
  803. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +111 -0
  804. snowflake/snowpark_connect/relation/__init__.py +3 -0
  805. snowflake/snowpark_connect/relation/catalogs/__init__.py +12 -0
  806. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +287 -0
  807. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +467 -0
  808. snowflake/snowpark_connect/relation/catalogs/utils.py +51 -0
  809. snowflake/snowpark_connect/relation/io_utils.py +76 -0
  810. snowflake/snowpark_connect/relation/map_aggregate.py +322 -0
  811. snowflake/snowpark_connect/relation/map_catalog.py +151 -0
  812. snowflake/snowpark_connect/relation/map_column_ops.py +1068 -0
  813. snowflake/snowpark_connect/relation/map_crosstab.py +48 -0
  814. snowflake/snowpark_connect/relation/map_extension.py +412 -0
  815. snowflake/snowpark_connect/relation/map_join.py +341 -0
  816. snowflake/snowpark_connect/relation/map_local_relation.py +326 -0
  817. snowflake/snowpark_connect/relation/map_map_partitions.py +146 -0
  818. snowflake/snowpark_connect/relation/map_relation.py +253 -0
  819. snowflake/snowpark_connect/relation/map_row_ops.py +716 -0
  820. snowflake/snowpark_connect/relation/map_sample_by.py +35 -0
  821. snowflake/snowpark_connect/relation/map_show_string.py +50 -0
  822. snowflake/snowpark_connect/relation/map_sql.py +1874 -0
  823. snowflake/snowpark_connect/relation/map_stats.py +324 -0
  824. snowflake/snowpark_connect/relation/map_subquery_alias.py +32 -0
  825. snowflake/snowpark_connect/relation/map_udtf.py +288 -0
  826. snowflake/snowpark_connect/relation/read/__init__.py +7 -0
  827. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +668 -0
  828. snowflake/snowpark_connect/relation/read/map_read.py +367 -0
  829. snowflake/snowpark_connect/relation/read/map_read_csv.py +142 -0
  830. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +108 -0
  831. snowflake/snowpark_connect/relation/read/map_read_json.py +344 -0
  832. snowflake/snowpark_connect/relation/read/map_read_parquet.py +194 -0
  833. snowflake/snowpark_connect/relation/read/map_read_socket.py +59 -0
  834. snowflake/snowpark_connect/relation/read/map_read_table.py +109 -0
  835. snowflake/snowpark_connect/relation/read/map_read_text.py +106 -0
  836. snowflake/snowpark_connect/relation/read/reader_config.py +399 -0
  837. snowflake/snowpark_connect/relation/read/utils.py +155 -0
  838. snowflake/snowpark_connect/relation/stage_locator.py +161 -0
  839. snowflake/snowpark_connect/relation/utils.py +219 -0
  840. snowflake/snowpark_connect/relation/write/__init__.py +3 -0
  841. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +339 -0
  842. snowflake/snowpark_connect/relation/write/map_write.py +436 -0
  843. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +48 -0
  844. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  845. snowflake/snowpark_connect/resources_initializer.py +75 -0
  846. snowflake/snowpark_connect/server.py +1136 -0
  847. snowflake/snowpark_connect/start_server.py +32 -0
  848. snowflake/snowpark_connect/tcm.py +8 -0
  849. snowflake/snowpark_connect/type_mapping.py +1003 -0
  850. snowflake/snowpark_connect/typed_column.py +94 -0
  851. snowflake/snowpark_connect/utils/__init__.py +3 -0
  852. snowflake/snowpark_connect/utils/artifacts.py +48 -0
  853. snowflake/snowpark_connect/utils/attribute_handling.py +72 -0
  854. snowflake/snowpark_connect/utils/cache.py +84 -0
  855. snowflake/snowpark_connect/utils/concurrent.py +124 -0
  856. snowflake/snowpark_connect/utils/context.py +390 -0
  857. snowflake/snowpark_connect/utils/describe_query_cache.py +231 -0
  858. snowflake/snowpark_connect/utils/interrupt.py +85 -0
  859. snowflake/snowpark_connect/utils/io_utils.py +35 -0
  860. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +117 -0
  861. snowflake/snowpark_connect/utils/profiling.py +47 -0
  862. snowflake/snowpark_connect/utils/session.py +180 -0
  863. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +38 -0
  864. snowflake/snowpark_connect/utils/telemetry.py +513 -0
  865. snowflake/snowpark_connect/utils/udf_cache.py +392 -0
  866. snowflake/snowpark_connect/utils/udf_helper.py +328 -0
  867. snowflake/snowpark_connect/utils/udf_utils.py +310 -0
  868. snowflake/snowpark_connect/utils/udtf_helper.py +420 -0
  869. snowflake/snowpark_connect/utils/udtf_utils.py +799 -0
  870. snowflake/snowpark_connect/utils/xxhash64.py +247 -0
  871. snowflake/snowpark_connect/version.py +6 -0
  872. snowpark_connect-0.20.2.data/scripts/snowpark-connect +71 -0
  873. snowpark_connect-0.20.2.data/scripts/snowpark-session +11 -0
  874. snowpark_connect-0.20.2.data/scripts/snowpark-submit +354 -0
  875. snowpark_connect-0.20.2.dist-info/METADATA +37 -0
  876. snowpark_connect-0.20.2.dist-info/RECORD +879 -0
  877. snowpark_connect-0.20.2.dist-info/WHEEL +5 -0
  878. snowpark_connect-0.20.2.dist-info/licenses/LICENSE.txt +202 -0
  879. snowpark_connect-0.20.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1133 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one or more
3
+ # contributor license agreements. See the NOTICE file distributed with
4
+ # this work for additional information regarding copyright ownership.
5
+ # The ASF licenses this file to You under the Apache License, Version 2.0
6
+ # (the "License"); you may not use this file except in compliance with
7
+ # the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ import json
18
+ from contextlib import contextmanager
19
+ import collections
20
+ import logging
21
+ import math
22
+ import os
23
+ import random
24
+ import re
25
+ import shutil
26
+ import subprocess
27
+ import sys
28
+ import tempfile
29
+ import textwrap
30
+ import time
31
+ from typing import (
32
+ Union,
33
+ Callable,
34
+ List,
35
+ Dict,
36
+ Optional,
37
+ Any,
38
+ Tuple,
39
+ Generator,
40
+ Iterator,
41
+ )
42
+
43
+ from pyspark import cloudpickle
44
+ from pyspark.resource.information import ResourceInformation
45
+ from pyspark.sql import DataFrame, SparkSession
46
+ from pyspark.taskcontext import BarrierTaskContext
47
+ from pyspark.ml.torch.log_communication import ( # type: ignore
48
+ LogStreamingClient,
49
+ LogStreamingServer,
50
+ )
51
+
52
+
53
def _get_resources(session: SparkSession) -> Dict[str, ResourceInformation]:
    """Return the resources mapping exposed by the given session.

    The mapping is read from ``session.sparkContext.resources``. If that
    lookup raises any ``Exception``, the result of
    ``session._client._resources()`` is returned instead.
    """
    try:
        return session.sparkContext.resources
    except Exception:
        # Reading the SparkContext resources failed; ask the session's client.
        return session._client._resources()  # type: ignore[attr-defined]
60
+
61
+
62
def _get_conf(spark: SparkSession, key: str, default_value: str) -> str:
    """Look up a conf entry on the session, falling back to a default.

    Parameters
    ----------
    spark : :class:`SparkSession`
        The session whose ``conf`` is queried.
    key : str
        Name of the conf entry to read.
    default_value : str
        Fallback handed to ``spark.conf.get`` for the given key.

    Returns
    -------
    str
        Whatever ``spark.conf.get(key, default_value)`` returns; asserted
        to be not ``None``.
    """
    conf_value = spark.conf.get(key, default_value)
    assert conf_value is not None
    return conf_value
83
+
84
+
85
+ # TODO(SPARK-41589): will move the functions and tests to an external file
86
+ # once we are in agreement about which functions should be in utils.py
87
def _get_conf_boolean(spark: SparkSession, key: str, default_value: str) -> bool:
    """Read a conf entry through ``_get_conf`` and interpret it as a boolean.

    The value is lower-cased and asserted to be either "true" or "false";
    the result is ``True`` only for "true".
    """
    normalized = _get_conf(spark=spark, key=key, default_value=default_value).lower()
    assert normalized in ["true", "false"]
    return normalized == "true"
92
+
93
+
94
+ def _get_logger(name: str) -> logging.Logger:
95
+ """
96
+ Gets a logger by name, or creates and configures it for the first time.
97
+ """
98
+ logger = logging.getLogger(name)
99
+ logger.setLevel(logging.INFO)
100
+ # If the logger is configured, skip the configure
101
+ if not logger.handlers and not logging.getLogger().handlers:
102
+ handler = logging.StreamHandler(sys.stderr)
103
+ logger.addHandler(handler)
104
+ return logger
105
+
106
+
107
def _get_gpus_owned(context: Union[SparkSession, BarrierTaskContext]) -> List[str]:
    """Gets the addresses of the GPUs that Spark scheduled to the calling task.

    Parameters
    ----------
    context : :class:`SparkSession` or :class:`BarrierTaskContext`
        The :class:`SparkSession` or :class:`BarrierTaskContext` that has GPUs available.

    Returns
    -------
    list
        The GPU addresses taken from the context's "gpu" resource. When the
        ``CUDA_VISIBLE_DEVICES`` environment variable is set, each address is
        used as an index into its comma-separated value and the selected
        entries are returned instead.

    Raises
    ------
    ValueError
        Raised if any address is not a non-negative integer without zero padding.
    """
    CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
    # Use fullmatch on an unanchored alternation: with "^[1-9][0-9]*|0$" the
    # "|" splits the anchors between the two branches, so a value such as
    # "1a" was accepted because only its prefix had to match.
    pattern = re.compile(r"[1-9][0-9]*|0")
    if isinstance(context, BarrierTaskContext):
        addresses = context.resources()["gpu"].addresses
    else:
        addresses = _get_resources(context)["gpu"].addresses

    if any(pattern.fullmatch(address) is None for address in addresses):
        raise ValueError(
            f"Found GPU addresses {addresses} which "
            "are not all in the correct format "
            "for CUDA_VISIBLE_DEVICES, which requires "
            "integers with no zero padding."
        )
    if CUDA_VISIBLE_DEVICES in os.environ:
        # Addresses are positions within the already-restricted device list.
        gpu_list = os.environ[CUDA_VISIBLE_DEVICES].split(",")
        return [gpu_list[int(address)] for address in addresses]
    return addresses
145
+
146
+
147
# Module-level string keys; each constant's value equals its own name.
# NOTE(review): presumably environment-variable names used to hand file paths
# to the training process -- their usage is not visible in this section, confirm.
SPARK_PARTITION_ARROW_DATA_FILE = "SPARK_PARTITION_ARROW_DATA_FILE"
SPARK_DATAFRAME_SCHEMA_FILE = "SPARK_DATAFRAME_SCHEMA_FILE"
149
+
150
+
151
class Distributor:
    """
    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
    """

    def __init__(
        self,
        num_processes: int = 1,
        local_mode: bool = True,
        use_gpu: bool = True,
        ssl_conf: Optional[str] = None,
    ):
        """Capture the active session and the distribution settings.

        Parameters
        ----------
        num_processes : int, optional
            Requested number of processes; may be lowered by
            :meth:`_get_num_tasks` in local GPU mode.
        local_mode : bool, optional
            Stored as ``self.local_mode`` and consulted by :meth:`_get_num_tasks`.
        use_gpu : bool, optional
            Stored as ``self.use_gpu`` and consulted by :meth:`_get_num_tasks`.
        ssl_conf : str, optional
            Name of the Spark conf that :meth:`_check_encryption` reads to decide
            whether an enabled ``spark.ssl.enabled`` may be ignored.

        Raises
        ------
        RuntimeError
            Propagated from :meth:`_get_num_tasks` when GPU settings are unusable.
        """
        from pyspark.sql.utils import is_remote

        self.is_remote = is_remote()
        self.spark = SparkSession.active()

        # indicate whether the server side is local mode
        # Refer to 'org.apache.spark.util.Utils#isLocalMaster'
        master = _get_conf(self.spark, "spark.master", "")
        self.is_spark_local_master = master == "local" or master.startswith("local[")

        self.logger = _get_logger(self.__class__.__name__)
        self.num_processes = num_processes
        self.local_mode = local_mode
        self.use_gpu = use_gpu
        self.num_tasks = self._get_num_tasks()
        self.ssl_conf = ssl_conf

    def _create_input_params(self) -> Dict[str, Any]:
        """Return a copy of the instance attributes without the session-bound ones.

        The entries "spark", "ssl_conf", "logger", "is_remote" and
        "is_spark_local_master" are left out; an entry that is absent is
        simply skipped rather than raising ``KeyError``.
        """
        unneeded_params = {
            "spark",
            "ssl_conf",
            "logger",
            "is_remote",
            "is_spark_local_master",
        }
        return {
            param: value
            for param, value in self.__dict__.items()
            if param not in unneeded_params
        }

    def _get_num_tasks(self) -> int:
        """
        Returns the number of Spark tasks to use for distributed training

        In local GPU mode ``self.num_processes`` is reduced (with a warning) to the
        value of 'spark.driver.resource.gpu.amount' when it exceeds that value.

        Returns
        -------
        int
            The number of Spark tasks to use for distributed training

        Raises
        ------
        RuntimeError
            Raised when the SparkConf was misconfigured.
        """
        if not self.use_gpu:
            return self.num_processes

        if not self.local_mode:
            key = "spark.task.resource.gpu.amount"
            task_gpu_amount = int(_get_conf(self.spark, key, "0"))
            if task_gpu_amount < 1:
                raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
            # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
            return math.ceil(self.num_processes / task_gpu_amount)

        key = "spark.driver.resource.gpu.amount"
        if "gpu" not in _get_resources(self.spark):
            raise RuntimeError("GPUs were unable to be found on the driver.")
        num_available_gpus = int(_get_conf(self.spark, key, "0"))
        if num_available_gpus == 0:
            raise RuntimeError("GPU resources were not configured properly on the driver.")
        if self.num_processes > num_available_gpus:
            self.logger.warning(
                "'num_processes' cannot be set to a value greater than the number of "
                f"available GPUs on the driver, which is {num_available_gpus}. "
                "'num_processes' was reset to be equal to the number of available GPUs.",
            )
            self.num_processes = num_available_gpus
        return self.num_processes

    def _validate_input_params(self) -> None:
        """Check that ``self.num_processes`` is positive.

        Raises
        ------
        ValueError
            Raised when ``self.num_processes`` is zero or negative.
        """
        if self.num_processes <= 0:
            raise ValueError("num_processes has to be a positive integer")

    def _check_encryption(self) -> None:
        """Checks to see if the user requires encryption of data.
        If required, throw an exception since we don't support that.

        When 'spark.ssl.enabled' is true and the conf named by ``self.ssl_conf``
        is also true, only a warning is logged and the method returns.

        Raises
        ------
        RuntimeError
            Thrown when the user requires ssl encryption or when the user initializes
            the Distributor parent class.
        """
        if not hasattr(self, "ssl_conf"):
            raise RuntimeError(
                "Distributor doesn't have this functionality. Use TorchDistributor instead."
            )
        is_ssl_enabled = _get_conf_boolean(self.spark, "spark.ssl.enabled", "false")
        ignore_ssl = _get_conf_boolean(self.spark, self.ssl_conf, "false")  # type: ignore
        if not is_ssl_enabled:
            return

        name = self.__class__.__name__
        if ignore_ssl:
            self.logger.warning(
                textwrap.dedent(
                    f"""
                    This cluster has TLS encryption enabled;
                    however, {name} does not
                    support data encryption in transit.
                    The Spark configuration
                    '{self.ssl_conf}' has been set to
                    'true' to override this
                    configuration and use {name} anyway. Please
                    note this will cause model
                    parameters and possibly training data to
                    be sent between nodes unencrypted.
                    """,
                )
            )
            return
        raise RuntimeError(
            textwrap.dedent(
                f"""
                This cluster has TLS encryption enabled;
                however, {name} does not support
                data encryption in transit. To override
                this configuration and use {name}
                anyway, you may set '{self.ssl_conf}'
                to 'true' in the Spark configuration. Please note this
                will cause model parameters and possibly training
                data to be sent between nodes unencrypted.
                """
            )
        )
286
+
287
+
288
+ class TorchDistributor(Distributor):
289
+ """
290
+ A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
291
+
292
+ .. versionadded:: 3.4.0
293
+
294
+ .. versionchanged:: 3.5.0
295
+ Supports Spark Connect.
296
+
297
+ Parameters
298
+ ----------
299
+ num_processes : int, optional
300
+ An integer that determines how many different concurrent
301
+ tasks are allowed. We expect spark.task.gpus = 1 for GPU-enabled training. Default
302
+ should be 1; we don't want to invoke multiple cores/gpus without explicit mention.
303
+ local_mode : bool, optional
304
+ A boolean that determines whether we are using the driver
305
+ node for training. Default should be false; we don't want to invoke executors without
306
+ explicit mention.
307
+ use_gpu : bool, optional
308
+ A boolean that indicates whether or not we are doing training
309
+ on the GPU. Note that there are differences in how GPU-enabled code looks like and
310
+ how CPU-specific code looks like.
311
+
312
+ Examples
313
+ --------
314
+ Run PyTorch Training locally on GPU (using a PyTorch native function)
315
+
316
+ >>> def train(learning_rate):
317
+ ... import torch.distributed
318
+ ... torch.distributed.init_process_group(backend="nccl")
319
+ ... # ...
320
+ ... torch.destroy_process_group()
321
+ ... return model # or anything else
322
+ ...
323
+ >>> distributor = TorchDistributor(
324
+ ... num_processes=2,
325
+ ... local_mode=True,
326
+ ... use_gpu=True)
327
+ >>> model = distributor.run(train, 1e-3)
328
+
329
+ Run PyTorch Training on GPU (using a file with PyTorch code)
330
+
331
+ >>> distributor = TorchDistributor(
332
+ ... num_processes=2,
333
+ ... local_mode=False,
334
+ ... use_gpu=True)
335
+ >>> distributor.run("/path/to/train.py", "--learning-rate=1e-3")
336
+
337
+ Run PyTorch Lightning Training on GPU
338
+
339
+ >>> num_proc = 2
340
+ >>> def train():
341
+ ... from pytorch_lightning import Trainer
342
+ ... # ...
343
+ ... # required to set devices = 1 and num_nodes = num_processes for multi node
344
+ ... # required to set devices = num_processes and num_nodes = 1 for single node multi GPU
345
+ ... trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
346
+ ... trainer.fit()
347
+ ... # ...
348
+ ... return trainer
349
+ ...
350
+ >>> distributor = TorchDistributor(
351
+ ... num_processes=num_proc,
352
+ ... local_mode=True,
353
+ ... use_gpu=True)
354
+ >>> trainer = distributor.run(train)
355
+ """
356
+
357
+ _PICKLED_FUNC_FILE = "func.pickle"
358
+ _TRAIN_FILE = "train.py"
359
+ _PICKLED_OUTPUT_FILE = "output.pickle"
360
+ _TORCH_SSL_CONF = "pytorch.spark.distributor.ignoreSsl"
361
+
362
def __init__(
    self,
    num_processes: int = 1,
    local_mode: bool = True,
    use_gpu: bool = True,
    _ssl_conf: str = _TORCH_SSL_CONF,
):
    """Initializes the distributor.

    Parameters
    ----------
    num_processes : int, optional
        An integer that determines how many different concurrent
        tasks are allowed. We expect spark.task.gpus = 1 for GPU-enabled training.
        Defaults to 1; we don't want to invoke multiple cores/gpus without explicit mention.
    local_mode : bool, optional
        A boolean that determines whether we are using the driver
        node for training. Defaults to True, i.e. executors are not invoked
        without explicit mention.
    use_gpu : bool, optional
        A boolean that indicates whether or not we are doing training
        on the GPU. Defaults to True. Note that there are differences in how
        GPU-enabled code looks like and how CPU-specific code looks like.
    _ssl_conf : str, optional
        Forwarded to the parent initializer as ``ssl_conf``. Defaults to
        ``_TORCH_SSL_CONF``.

    Raises
    ------
    ValueError
        If any of the parameters are incorrect.
    RuntimeError
        If an active SparkSession is unavailable.
    """
    super().__init__(num_processes, local_mode, use_gpu, ssl_conf=_ssl_conf)
    # Validate first, then cache the parameter dict that the command builders
    # and Spark task functions read from.
    self._validate_input_params()
    self.input_params = self._create_input_params()
396
+
397
+ @staticmethod
398
+ def _get_torchrun_args(local_mode: bool, num_processes: int) -> Tuple[List[Any], int]:
399
+ """
400
+ Given the mode and the number of processes, create the arguments to be given to for torch
401
+
402
+ Parameters
403
+ ---------
404
+ local_mode: bool
405
+ Whether or not we are running training locally or in a distributed fashion
406
+
407
+ num_processes: int
408
+ The number of processes that we are going to use
409
+
410
+ Returns
411
+ ------
412
+ Tuple[List[Any], int]
413
+ A tuple containing a list of arguments to pass as pytorch args,
414
+ as well as the number of processes per node
415
+ """
416
+ if local_mode:
417
+ torchrun_args = ["--standalone", "--nnodes=1"]
418
+ processes_per_node = num_processes
419
+ return torchrun_args, processes_per_node
420
+
421
+ master_addr = os.environ["MASTER_ADDR"]
422
+ master_port = os.environ["MASTER_PORT"]
423
+ node_rank = os.environ["RANK"]
424
+ torchrun_args = [
425
+ f"--nnodes={num_processes}",
426
+ f"--node_rank={node_rank}",
427
+ f"--rdzv_endpoint={master_addr}:{master_port}",
428
+ "--rdzv_id=0", # TODO: setup random ID that is gleaned from env variables
429
+ ]
430
+ processes_per_node = 1
431
+ return torchrun_args, processes_per_node
432
+
433
@staticmethod
def _create_torchrun_command(
    input_params: Dict[str, Any], path_to_train_file: str, *args: Any
) -> List[str]:
    """
    Assemble the command line that launches the training file.

    Parameters
    ----------
    input_params : Dict[str, Any]
        Must contain the ``"local_mode"`` and ``"num_processes"`` entries.
    path_to_train_file : str
        Path of the python file to run.
    *args : Any
        Extra command-line arguments; each one is converted with ``str``.

    Returns
    -------
    List[str]
        The argv list, starting with the current interpreter.
    """
    launcher_args, nproc_per_node = TorchDistributor._get_torchrun_args(
        local_mode=input_params["local_mode"],
        num_processes=input_params["num_processes"],
    )

    command = [sys.executable, "-m", "pyspark.ml.torch.torch_run_process_wrapper"]
    command.extend(launcher_args)
    command.append(f"--nproc_per_node={nproc_per_node}")
    command.append(path_to_train_file)
    # User-supplied arguments go last and are always passed as strings.
    command.extend(str(extra) for extra in args)
    return command
454
+
455
@staticmethod
def _execute_command(
    cmd: List[str],
    _prctl: bool = True,
    redirect_to_stdout: bool = True,
    log_streaming_client: Optional[LogStreamingClient] = None,
) -> None:
    """
    Run ``cmd`` as a child process, stream its output, and fail loudly on a
    non-zero exit.

    Parameters
    ----------
    cmd : List[str]
        The argv list to execute.
    _prctl : bool
        Not referenced by this function body; kept so existing callers that
        pass it keep working.
    redirect_to_stdout : bool
        Whether each output line is echoed to ``sys.stdout``.
    log_streaming_client : Optional[LogStreamingClient]
        When given, every output line is also sent through this client.

    Raises
    ------
    RuntimeError
        If the child exits with a status other than ``os.EX_OK``. The message
        includes the last lines of the child's output.
    """
    _TAIL_LINES_TO_KEEP = 100

    task = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        stdin=subprocess.PIPE,
        env=os.environ,
    )
    task.stdin.close()  # type: ignore
    tail: collections.deque = collections.deque(maxlen=_TAIL_LINES_TO_KEEP)
    try:
        for line in task.stdout:  # type: ignore
            # Training scripts may emit bytes that are not valid UTF-8; a strict
            # decode would raise here and tear down the whole training run, so
            # undecodable bytes are replaced instead.
            decoded = line.decode(errors="replace")
            tail.append(decoded)
            if redirect_to_stdout:
                # If log_streaming_client and log_stream_server are in the same
                # node (typical case is spark local mode), server side will
                # redirect the log to STDOUT; to avoid STDOUT outputs
                # duplication, skip redirecting logs to STDOUT in client side.
                server_echoes_locally = (
                    log_streaming_client
                    and not log_streaming_client.failed
                    and (
                        log_streaming_client.sock.getsockname()[0]
                        == log_streaming_client.sock.getpeername()[0]
                    )
                )
                if not server_echoes_locally:
                    sys.stdout.write(decoded)
            if log_streaming_client:
                log_streaming_client.send(decoded.rstrip())
        task.wait()
    finally:
        # Reached with the child still alive only when the loop above raised:
        # ask it to stop, then force it.
        if task.poll() is None:
            try:
                task.terminate()  # SIGTERM
                time.sleep(0.5)
                if task.poll() is None:
                    task.kill()  # SIGKILL
            except OSError:
                pass
    if task.returncode != os.EX_OK:
        if len(tail) == _TAIL_LINES_TO_KEEP:
            last_n_msg = f"last {_TAIL_LINES_TO_KEEP} lines of the task output are"
        else:
            last_n_msg = "task output is"
        task_output = "".join(tail)
        raise RuntimeError(
            f"Command {cmd} failed with return code {task.returncode}. "
            f"The {last_n_msg} included below: {task_output}"
        )
516
+
517
+ @staticmethod
518
+ def _get_output_from_framework_wrapper(
519
+ framework_wrapper: Optional[Callable],
520
+ input_params: Dict,
521
+ train_object: Union[Callable, str],
522
+ run_pytorch_file_fn: Optional[Callable],
523
+ *args: Any,
524
+ **kwargs: Any,
525
+ ) -> Optional[Any]:
526
+ """
527
+ This function is meant to get the output from framework wrapper function by passing in the
528
+ correct arguments, depending on the type of train_object.
529
+
530
+ Parameters
531
+ ----------
532
+ framework_wrapper: Optional[Callable]
533
+ Function pointer that will be invoked. Can either be the function that runs distributed
534
+ training on files if train_object is a string. Otherwise, it will be the function that
535
+ runs distributed training for functions if the train_object is a Callable
536
+ input_params: Dict
537
+ A dictionary that maps parameter to arguments for the command to be created.
538
+ train_object: Union[Callable, str]
539
+ This input comes from the user. If the user inputs a string, then this means
540
+ it's a filepath. Otherwise, if the input is a function, then this means that
541
+ the user wants to run this function in a distributed manner.
542
+ run_pytorch_file_fn: Optional[Callable]
543
+ The function that will be used to run distributed training of a file;
544
+ mainly used for the distributed training using a function.
545
+ *args: Any
546
+ Extra arguments to be used by framework wrapper.
547
+ **kwargs: Any
548
+ Extra keyword args to be used. Not currently supported but kept for
549
+ future improvement.
550
+
551
+ Returns
552
+ -------
553
+ Optional[Any]
554
+ Returns the result of the framework_wrapper
555
+ """
556
+ if not framework_wrapper:
557
+ raise RuntimeError("`framework_wrapper` is not set. ...")
558
+ # The object to train is a file path, so framework_wrapper is some
559
+ # run_training_on_pytorch_file function.
560
+ if type(train_object) is str:
561
+ return framework_wrapper(input_params, train_object, *args, **kwargs)
562
+ else:
563
+ # We are doing training with a function, will call run_training_on_pytorch_function
564
+ if not run_pytorch_file_fn:
565
+ run_pytorch_file_fn = TorchDistributor._run_training_on_pytorch_file
566
+ return framework_wrapper(
567
+ input_params, train_object, run_pytorch_file_fn, *args, **kwargs
568
+ )
569
+
570
def _run_local_training(
    self,
    framework_wrapper_fn: Callable,
    train_object: Union[Callable, str],
    run_pytorch_file_fn: Optional[Callable],
    *args: Any,
    **kwargs: Any,
) -> Optional[Any]:
    """
    Run training in the current process's environment (no Spark tasks).

    Parameters
    ----------
    framework_wrapper_fn : Callable
        The function that runs training on a file or on a function.
    train_object : Union[Callable, str]
        The user's train function or the path of the training file.
    run_pytorch_file_fn : Optional[Callable]
        The function used to run a training file.
    *args, **kwargs : Any
        Forwarded to the framework wrapper.

    Returns
    -------
    Optional[Any]
        Whatever the framework wrapper returns.

    Notes
    -----
    ``CUDA_VISIBLE_DEVICES`` is restored to its previous state (set or unset)
    when training finishes or fails.
    """
    CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
    cuda_state_was_set = CUDA_VISIBLE_DEVICES in os.environ
    old_cuda_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES, "")
    try:
        # Only replace the GPUs with 'SparkContext.resources' in legacy mode.
        # In connect mode, this replacement is skipped since only GPUs on the client side
        # can be used.
        if self.use_gpu and not self.is_remote:
            gpus_owned = _get_gpus_owned(self.spark)
            # A private generator seeded from the train object yields the same
            # selection as seeding the module-level generator would, without
            # clobbering the global `random` state of the calling process.
            rng = random.Random(hash(train_object))
            selected_gpus = [str(e) for e in rng.sample(gpus_owned, self.num_processes)]
            os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)

        self.logger.info(f"Started local training with {self.num_processes} processes")
        output = TorchDistributor._get_output_from_framework_wrapper(
            framework_wrapper_fn,
            self.input_params,
            train_object,
            run_pytorch_file_fn,
            *args,
            **kwargs,
        )
        self.logger.info(f"Finished local training with {self.num_processes} processes")

    finally:
        if cuda_state_was_set:
            os.environ[CUDA_VISIBLE_DEVICES] = old_cuda_visible_devices
        else:
            # The variable did not exist before; remove it if we added it.
            os.environ.pop(CUDA_VISIBLE_DEVICES, None)

    return output
610
+
611
def _get_spark_task_function(
    self,
    framework_wrapper_fn: Optional[Callable],
    train_object: Union[Callable, str],
    run_pytorch_file_fn: Optional[Callable],
    input_dataframe: Optional["DataFrame"],
    *args: Any,
    **kwargs: Any,
) -> Callable:
    """Creates a spark task function that is used inside `mapPartitions`.

    Parameters
    ----------
    framework_wrapper_fn : Optional[Callable]
        The function that determines whether we are running training
        on a PyTorch file or a PyTorch function.
    train_object : Union[Callable, str]
        The actual train function/file.
    run_pytorch_file_fn : Optional[Callable]
        Forwarded unchanged to `_get_output_from_framework_wrapper`.
    input_dataframe : Optional[DataFrame]
        When not None, its schema is captured as JSON and handed to
        `_setup_spark_partition_data` together with the partition iterator.
    *args, **kwargs : Any
        Forwarded unchanged to `_get_output_from_framework_wrapper`.

    Returns
    -------
    Callable
        The wrapped function ready for use with `mapPartitions`
    """
    # Copy everything the task needs into locals: the nested function below
    # reads these names and never touches `self`
    # (presumably so the closure does not have to carry the distributor itself
    # to the executors -- TODO confirm).
    num_processes = self.num_processes
    use_gpu = self.use_gpu
    input_params = self.input_params
    driver_address = self.driver_address
    log_streaming_server_port = self.log_streaming_server_port
    is_spark_local_master = self.is_spark_local_master
    driver_owned_gpus: List[str] = []
    # GPU ids are looked up here only for a local master with GPU training;
    # the local-master `set_gpus` below indexes into this list.
    if is_spark_local_master and use_gpu:
        driver_owned_gpus = _get_gpus_owned(self.spark)

    if input_dataframe is not None:
        schema_json = input_dataframe.schema.jsonValue()
    else:
        schema_json = None

    # Spark task program
    def wrapped_train_fn(iterator):  # type: ignore[no-untyped-def]
        import os
        import pandas as pd
        import pyarrow
        from pyspark import BarrierTaskContext

        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"

        def get_free_port(address: str, context: "BarrierTaskContext") -> int:
            # Only partition 0 binds a socket (port 0) and reads back the port
            # number it was given; every other partition contributes an empty
            # string to the gather below.
            port = ""
            if context.partitionId() == 0:
                try:
                    import socket

                    sock = socket.socket()
                    sock.bind((address, 0))
                    port = sock.getsockname()[1]
                except socket.error:
                    pass
            # All tasks use the first gathered entry
            # (presumably partition 0's value -- TODO confirm allGather ordering).
            available_port = context.allGather(str(port))[0]
            if not available_port:
                raise RuntimeError("Failed to find free port for distributed training.")
            return int(available_port)

        def set_torch_config(context: "BarrierTaskContext") -> None:
            # Host part of each task's "host:port" address.
            addrs = [e.address.split(":")[0] for e in context.getTaskInfos()]

            # The first task's host acts as the rendezvous master; these
            # variables are read back by `_get_torchrun_args`.
            os.environ["MASTER_ADDR"] = str(addrs[0])
            os.environ["MASTER_PORT"] = str(get_free_port(addrs[0], context))
            os.environ["WORLD_SIZE"] = str(num_processes)
            os.environ["NODE_RANK"] = str(context.partitionId())
            os.environ["RANK"] = str(context.partitionId())

            # More partitions than processes means the input dataframe was not
            # partitioned to match `num_processes`.
            if context.partitionId() >= num_processes:
                raise ValueError(
                    "TorchDistributor._train_on_dataframe requires setting num_processes "
                    "equal to input spark dataframe partition number."
                )

        # Both variants of `set_gpus` leave an already-set CUDA_VISIBLE_DEVICES
        # untouched.
        if is_spark_local_master:
            # distributed training on a local mode spark cluster
            def set_gpus(context: "BarrierTaskContext") -> None:
                if CUDA_VISIBLE_DEVICES in os.environ:
                    return

                # One GPU per task, picked by partition id from the list
                # collected before the task function was built.
                gpu_owned = driver_owned_gpus[context.partitionId()]
                os.environ[CUDA_VISIBLE_DEVICES] = gpu_owned

        else:

            def set_gpus(context: "BarrierTaskContext") -> None:
                if CUDA_VISIBLE_DEVICES in os.environ:
                    return

                # Expose every GPU reported for this task context.
                gpus_owned = _get_gpus_owned(context)
                os.environ[CUDA_VISIBLE_DEVICES] = ",".join(gpus_owned)

        context = BarrierTaskContext.get()

        if use_gpu:
            set_gpus(context)
        else:
            # CPU training: set the variable to an empty string.
            os.environ[CUDA_VISIBLE_DEVICES] = ""
        set_torch_config(context)

        # `_run_training_on_pytorch_file` picks the client up from
        # `input_params` under this key.
        log_streaming_client = LogStreamingClient(driver_address, log_streaming_server_port)
        input_params["log_streaming_client"] = log_streaming_client
        try:
            with TorchDistributor._setup_spark_partition_data(iterator, schema_json):
                output = TorchDistributor._get_output_from_framework_wrapper(
                    framework_wrapper_fn,
                    input_params,
                    train_object,
                    run_pytorch_file_fn,
                    *args,
                    **kwargs,
                )
        finally:
            # Best-effort teardown: a failure here must not replace the
            # training result or the training error.
            try:
                LogStreamingClient._destroy()
            except BaseException:
                pass

        # Only partition 0 reports a result; other partitions yield nothing.
        if context.partitionId() == 0:
            output_bytes = cloudpickle.dumps(output)
            output_size = len(output_bytes)

            # In Spark Connect, DataFrame.collect stacks rows to size
            # 'spark.connect.grpc.arrow.maxBatchSize' (default 4MiB),
            # here use 4KiB for each chunk, which mean each arrow batch
            # may contain about 1000 chunks.
            chunks = []
            chunk_size = 4096
            index = 0
            while index < output_size:
                chunks.append(output_bytes[index : index + chunk_size])
                index += chunk_size

            yield pyarrow.RecordBatch.from_pandas(pd.DataFrame(data={"chunk": chunks}))

    return wrapped_train_fn
752
+
753
def _run_distributed_training(
    self,
    framework_wrapper_fn: Callable,
    train_object: Union[Callable, str],
    run_pytorch_file_fn: Optional[Callable],
    spark_dataframe: Optional["DataFrame"],
    *args: Any,
    **kwargs: Any,
) -> Optional[Any]:
    """
    Run training as a barrier Spark job and return the unpickled result.

    Parameters
    ----------
    framework_wrapper_fn : Callable
        The function that runs training on a file or on a function.
    train_object : Union[Callable, str]
        The user's train function or the path of the training file.
    run_pytorch_file_fn : Optional[Callable]
        The function used to run a training file.
    spark_dataframe : Optional[DataFrame]
        Input data for the tasks. When None, a range dataframe with
        ``self.num_tasks`` partitions is used to spawn the tasks.
    *args, **kwargs : Any
        Forwarded to the Spark task function.

    Returns
    -------
    Optional[Any]
        The object unpickled from the chunks collected from the job.

    Raises
    ------
    RuntimeError
        If ``framework_wrapper_fn`` is falsy.
    """
    if not framework_wrapper_fn:
        raise RuntimeError("Unknown combination of parameters")

    log_streaming_server = LogStreamingServer()
    self.driver_address = _get_conf(self.spark, "spark.driver.host", "")
    assert self.driver_address != ""
    try:
        log_streaming_server.start(spark_host_address=self.driver_address)
        time.sleep(1)  # wait for the server to start
        self.log_streaming_server_port = log_streaming_server.port
    except Exception as e:
        # If starting log streaming server failed, we don't need to break
        # the distributor training but emit a warning instead.
        self.log_streaming_server_port = -1
        # The three literals must form ONE message. A comma between them would
        # pass the error text as a %-format argument to a message that has no
        # placeholder, which makes the logging call itself fail and drops the
        # error detail.
        self.logger.warning(
            "Start torch distributor log streaming server failed, "
            "You cannot receive logs sent from distributor workers, "
            f"error: {repr(e)}."
        )

    try:
        spark_task_function = self._get_spark_task_function(
            framework_wrapper_fn,
            train_object,
            run_pytorch_file_fn,
            spark_dataframe,
            *args,
            **kwargs,
        )
        self._check_encryption()
        self.logger.info(
            f"Started distributed training with {self.num_processes} executor processes"
        )
        if spark_dataframe is not None:
            input_df = spark_dataframe
        else:
            # No input data: the dataframe only exists to spawn one task per
            # partition.
            input_df = self.spark.range(
                start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks
            )
        rows = input_df.mapInArrow(
            func=spark_task_function, schema="chunk binary", barrier=True
        ).collect()
        # The task function emits the pickled output as a sequence of binary
        # chunks; stitch them back together before unpickling.
        output_bytes = b"".join([row.chunk for row in rows])
        result = cloudpickle.loads(output_bytes)
    finally:
        log_streaming_server.shutdown()
    self.logger.info(
        f"Finished distributed training with {self.num_processes} executor processes"
    )
    return result
812
+
813
+ @staticmethod
814
+ def _run_training_on_pytorch_file(
815
+ input_params: Dict[str, Any], train_path: str, *args: Any, **kwargs: Any
816
+ ) -> None:
817
+ if kwargs:
818
+ raise ValueError("Running pytorch file does not support key-word type arguments.")
819
+ log_streaming_client = input_params.get("log_streaming_client", None)
820
+ training_command = TorchDistributor._create_torchrun_command(
821
+ input_params, train_path, *args
822
+ )
823
+ TorchDistributor._execute_command(
824
+ training_command, log_streaming_client=log_streaming_client
825
+ )
826
+
827
@staticmethod
@contextmanager
def _setup_files(
    train_fn: Callable, *args: Any, **kwargs: Any
) -> Generator[Tuple[str, str], None, None]:
    """
    Context manager that stages the files needed to run ``train_fn`` through
    a generated training script.

    Parameters
    ----------
    train_fn : Callable
        The user's train function.
    *args, **kwargs : Any
        Arguments pickled alongside ``train_fn``.

    Yields
    ------
    Tuple[str, str]
        The path of the generated training script and the path where the
        pickled output is expected.

    Notes
    -----
    The save directory is removed on exit, including when pickling the
    function or writing the script fails before the caller's body runs.
    """
    save_dir = TorchDistributor._create_save_dir()
    # Everything after the directory exists is inside the `try`, so a failure
    # while pickling the function or writing the script cannot leak it.
    try:
        pickle_file_path = TorchDistributor._save_pickled_function(
            save_dir, train_fn, *args, **kwargs
        )
        output_file_path = os.path.join(save_dir, TorchDistributor._PICKLED_OUTPUT_FILE)
        train_file_path = TorchDistributor._create_torchrun_train_file(
            save_dir, pickle_file_path, output_file_path
        )
        yield (train_file_path, output_file_path)
    finally:
        TorchDistributor._cleanup_files(save_dir)
844
+
845
@staticmethod
@contextmanager
def _setup_spark_partition_data(
    partition_data_iterator: Iterator[Any], input_schema_json: Optional[Dict[str, Any]]
) -> Iterator[Any]:
    """
    Context manager that spills the partition data and its schema to
    temporary files and advertises their paths through environment variables.

    Parameters
    ----------
    partition_data_iterator : Iterator[Any]
        The partition's data, written out with ``ArrowStreamSerializer``.
    input_schema_json : Optional[Dict[str, Any]]
        JSON form of the dataframe schema. When None, nothing is staged and
        the context manager is a no-op.

    Raises
    ------
    ValueError
        If the partition is empty (nothing was written to the arrow file).
    """
    from pyspark.sql.pandas.serializers import ArrowStreamSerializer
    from pyspark.files import SparkFiles
    import json

    if input_schema_json is None:
        yield
        return

    # We need to temporarily write partition data into a temp dir,
    # partition data might be huge, so we need to write it under
    # configured `SPARK_LOCAL_DIRS`.
    save_dir = TorchDistributor._create_save_dir(root_dir=SparkFiles.getRootDirectory())

    try:
        serializer = ArrowStreamSerializer()
        arrow_file_path = os.path.join(save_dir, "data.arrow")
        with open(arrow_file_path, "wb") as f:
            serializer.dump_stream(partition_data_iterator, f)
            if f.tell() == 0:
                # Nothing is written to file, this partition is empty
                raise ValueError(
                    "Empty Spark partition is not allowed in "
                    "TorchDistributor.train_on_dataframe."
                )

        schema_file_path = os.path.join(save_dir, "schema.json")
        schema_json_string = json.dumps(input_schema_json)

        with open(schema_file_path, "w") as f:
            f.write(schema_json_string)

        os.environ[SPARK_PARTITION_ARROW_DATA_FILE] = arrow_file_path
        os.environ[SPARK_DATAFRAME_SCHEMA_FILE] = schema_file_path
        yield
    finally:
        # The variables may not have been set yet (e.g. the empty-partition
        # error above). Popping with a default keeps a KeyError from masking
        # the real exception and from skipping the directory cleanup.
        os.environ.pop(SPARK_PARTITION_ARROW_DATA_FILE, None)
        os.environ.pop(SPARK_DATAFRAME_SCHEMA_FILE, None)
        TorchDistributor._cleanup_files(save_dir)
888
+
889
@staticmethod
def _run_training_on_pytorch_function(
    input_params: Dict[str, Any],
    train_fn: Callable,
    run_pytorch_file_fn: Optional[Callable],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """
    Run ``train_fn`` by pickling it, launching a generated training script,
    and reading back the pickled output.

    Parameters
    ----------
    input_params : Dict[str, Any]
        Parameters passed to ``run_pytorch_file_fn``.
    train_fn : Callable
        The user's train function.
    run_pytorch_file_fn : Optional[Callable]
        Called as ``run_pytorch_file_fn(input_params, train_file_path)``. When
        falsy, ``TorchDistributor._run_training_on_pytorch_file`` is used.
    *args, **kwargs : Any
        Arguments for ``train_fn``.

    Returns
    -------
    Any
        The object loaded from the output file.

    Raises
    ------
    RuntimeError
        If the output file is missing after the run, or if loading it fails.
    """
    if not run_pytorch_file_fn:
        run_pytorch_file_fn = TorchDistributor._run_training_on_pytorch_file

    with TorchDistributor._setup_files(train_fn, *args, **kwargs) as (
        train_file_path,
        output_file_path,
    ):
        run_pytorch_file_fn(input_params, train_file_path)
        if not os.path.exists(output_file_path):
            # Note the trailing space: the two literals are concatenated into
            # one sentence pair.
            raise RuntimeError(
                "TorchDistributor failed during training. "
                "View stdout logs for detailed error message."
            )
        try:
            output = TorchDistributor._get_pickled_output(output_file_path)
        except Exception as e:
            raise RuntimeError(
                "TorchDistributor failed due to a pickling error. "
                "View stdout logs for detailed error message."
            ) from e
    return output
919
+
920
+ @staticmethod
921
+ def _create_save_dir(root_dir: Optional[str] = None) -> str:
922
+ # TODO: need to do this in a safe way to avoid issues during concurrent runs
923
+ return tempfile.mkdtemp(dir=root_dir)
924
+
925
+ @staticmethod
926
+ def _cleanup_files(save_dir: str) -> None:
927
+ shutil.rmtree(save_dir, ignore_errors=True)
928
+
929
@staticmethod
def _save_pickled_function(
    save_dir: str, train_fn: Union[str, Callable], *args: Any, **kwargs: Any
) -> str:
    """
    Pickle ``(train_fn, args, kwargs)`` into ``save_dir`` with cloudpickle.

    Returns
    -------
    str
        Path of the written pickle file.
    """
    target_path = os.path.join(save_dir, TorchDistributor._PICKLED_FUNC_FILE)
    payload = (train_fn, args, kwargs)
    with open(target_path, "wb") as sink:
        cloudpickle.dump(payload, sink)
    return target_path
937
+
938
@staticmethod
def _create_torchrun_train_file(
    save_dir_path: str, pickle_file_path: str, output_file_path: str
) -> str:
    """
    Write the launcher script that loads the pickled train function, calls
    it, and pickles its return value.

    Parameters
    ----------
    save_dir_path : str
        Directory in which the script is created.
    pickle_file_path : str
        Path of the pickled ``(train_fn, args, kwargs)`` tuple the script reads.
    output_file_path : str
        Path the script writes the pickled output to.

    Returns
    -------
    str
        Path of the written script.
    """
    # `!r` embeds each path as a valid Python string literal, so backslashes
    # or quote characters in a path cannot corrupt the generated script.
    code = textwrap.dedent(
        f"""
        from pyspark import cloudpickle
        import os

        if __name__ == "__main__":
            with open({pickle_file_path!r}, "rb") as f:
                train_fn, args, kwargs = cloudpickle.load(f)
            output = train_fn(*args, **kwargs)
            with open({output_file_path!r}, "wb") as f:
                cloudpickle.dump(output, f)
        """
    )
    saved_file_path = os.path.join(save_dir_path, TorchDistributor._TRAIN_FILE)
    with open(saved_file_path, "w") as f:
        f.write(code)
    return saved_file_path
959
+
960
@staticmethod
def _get_pickled_output(output_file_path: str) -> Any:
    """
    Load and return the object stored at ``output_file_path`` with cloudpickle.

    Parameters
    ----------
    output_file_path : str
        Path of the pickle file to read.
    """
    with open(output_file_path, "rb") as source:
        return cloudpickle.load(source)
965
+
966
def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
    """Runs distributed training.

    Parameters
    ----------
    train_object : callable object or str
        Either a PyTorch function, a PyTorch Lightning function, or the path to a
        python file that launches distributed training.
    args :
        When train_object is a python function, these are the positional inputs
        of that function, for example

        >>> model = distributor.run(train, 1e-3, 64)

        where train is a function and 1e-3 and 64 are regular numeric inputs to it.

        When train_object is the path to a python file, these are the command-line
        arguments of that file, all given as strings, for example

        >>> distributor.run("/path/to/train.py", "--learning-rate=1e-3", "--batch-size=64")

        so that they can be handled by argparse in that python file.
    kwargs :
        When train_object is a python function, these are the key-word inputs of
        that function, for example

        >>> model = distributor.run(train, tol=1e-3, max_iter=64)

        where train is a function of 2 arguments `tol` and `max_iter`.

        When train_object is a python file, you should not set kwargs arguments.

    Returns
    -------
        Returns the output of train_object called with args inside spark rank 0 task if the
        train_object is a Callable with an expected output. Returns None if train_object is
        a file.
    """
    file_runner = TorchDistributor._run_training_on_pytorch_file
    return self._run(train_object, file_runner, *args, **kwargs)
1008
+
1009
+ def _run(
1010
+ self,
1011
+ train_object: Union[Callable, str],
1012
+ run_pytorch_file_fn: Callable,
1013
+ *args: Any,
1014
+ **kwargs: Any,
1015
+ ) -> Optional[Any]:
1016
+ if isinstance(train_object, str):
1017
+ framework_wrapper_fn = run_pytorch_file_fn
1018
+ else:
1019
+ framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_function
1020
+ if self.local_mode:
1021
+ output = self._run_local_training(
1022
+ framework_wrapper_fn, train_object, run_pytorch_file_fn, *args, **kwargs
1023
+ )
1024
+ else:
1025
+ output = self._run_distributed_training(
1026
+ framework_wrapper_fn, train_object, run_pytorch_file_fn, None, *args, **kwargs
1027
+ )
1028
+ return output
1029
+
1030
+ def _train_on_dataframe(
1031
+ self,
1032
+ train_function: Callable,
1033
+ spark_dataframe: "DataFrame",
1034
+ *args: Any,
1035
+ **kwargs: Any,
1036
+ ) -> Any:
1037
+ """
1038
+ Runs distributed training using provided Spark DataFrame as input data.
1039
+ You should ensure the input Spark DataFrame have evenly distributed partitions,
1040
+ and this method starts a barrier Spark job that each Spark task in the job
1041
+ process one partition of the input Spark DataFrame.
1042
+
1043
+ Parameters
1044
+ ----------
1045
+ train_function :
1046
+ Either a PyTorch function, PyTorch Lightning function that launches distributed
1047
+ training. Note that inside the function, you can call
1048
+ `pyspark.ml.torch.distributor.get_spark_partition_data_loader` API to get a torch
1049
+ data loader, the data loader loads data from the corresponding partition of the
1050
+ input Spark DataFrame.
1051
+ spark_dataframe :
1052
+ An input Spark DataFrame that can be used in PyTorch `train_function` function.
1053
+ See `train_function` argument doc for details.
1054
+ args :
1055
+ `args` need to be the input parameters to `train_function` function. It would look like
1056
+
1057
+ >>> model = distributor.run(train, 1e-3, 64)
1058
+
1059
+ where train is a function and 1e-3 and 64 are regular numeric inputs to the function.
1060
+ kwargs :
1061
+ `kwargs` need to be the key-word input parameters to `train_function` function.
1062
+ It would look like
1063
+
1064
+ >>> model = distributor.run(train, tol=1e-3, max_iter=64)
1065
+
1066
+ where train is a function of 2 arguments `tol` and `max_iter`.
1067
+
1068
+ Returns
1069
+ -------
1070
+ Returns the output of `train_function` called with args inside Spark rank 0 task.
1071
+ """
1072
+
1073
+ if self.local_mode:
1074
+ raise ValueError(
1075
+ "TorchDistributor.train_on_dataframe requires setting "
1076
+ "TorchDistributor.local_mode to False."
1077
+ )
1078
+
1079
+ return self._run_distributed_training(
1080
+ TorchDistributor._run_training_on_pytorch_function,
1081
+ train_function,
1082
+ TorchDistributor._run_training_on_pytorch_file,
1083
+ spark_dataframe,
1084
+ *args,
1085
+ **kwargs,
1086
+ )
1087
+
1088
+
1089
def _get_spark_partition_data_loader(
    num_samples: int, batch_size: int, num_workers: int = 1, prefetch_factor: int = 2
) -> Any:
    """
    Return a pytorch data loader over the current Spark partition's data.

    This function must be called inside the `train_function` that is passed to
    `TorchDistributor.train_on_dataframe`; it locates the partition data through
    environment variables.

    Parameters
    ----------
    num_samples :
        Number of samples to generate per epoch. If `num_samples` is less than the
        number of rows in the spark partition, only the first `num_samples` rows are
        generated; if it is greater, the iterator wraps around to the first row after
        all rows of the partition have been loaded.
    batch_size:
        How many samples per batch to load.
    num_workers:
        How many subprocesses to use for data loading.
        0 means that the data will be loaded in the main process.
    prefetch_factor:
        Number of batches loaded in advance by each worker.
    """
    from pyspark.sql.types import StructType
    from pyspark.ml.torch.data import _SparkPartitionTorchDataset
    from torch.utils.data import DataLoader

    arrow_file = os.environ[SPARK_PARTITION_ARROW_DATA_FILE]
    schema_file = os.environ[SPARK_DATAFRAME_SCHEMA_FILE]

    with open(schema_file, "r") as fp:
        schema_dict = json.load(fp)
    schema = StructType.fromJson(schema_dict)

    dataset = _SparkPartitionTorchDataset(arrow_file, schema, num_samples)

    loader_options = {"num_workers": num_workers}
    if num_workers > 0:
        # if num_workers is zero, we cannot set `prefetch_factor` otherwise
        # torch will raise error.
        loader_options["prefetch_factor"] = prefetch_factor
    return DataLoader(dataset, batch_size, **loader_options)