snowpark-connect 0.20.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of snowpark-connect might be problematic. Click here for more details.

Files changed (879)
  1. snowflake/snowpark_connect/__init__.py +23 -0
  2. snowflake/snowpark_connect/analyze_plan/__init__.py +3 -0
  3. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +38 -0
  4. snowflake/snowpark_connect/column_name_handler.py +735 -0
  5. snowflake/snowpark_connect/config.py +576 -0
  6. snowflake/snowpark_connect/constants.py +47 -0
  7. snowflake/snowpark_connect/control_server.py +52 -0
  8. snowflake/snowpark_connect/dataframe_name_handler.py +54 -0
  9. snowflake/snowpark_connect/date_time_format_mapping.py +399 -0
  10. snowflake/snowpark_connect/empty_dataframe.py +18 -0
  11. snowflake/snowpark_connect/error/__init__.py +11 -0
  12. snowflake/snowpark_connect/error/error_mapping.py +6174 -0
  13. snowflake/snowpark_connect/error/error_utils.py +321 -0
  14. snowflake/snowpark_connect/error/exceptions.py +24 -0
  15. snowflake/snowpark_connect/execute_plan/__init__.py +3 -0
  16. snowflake/snowpark_connect/execute_plan/map_execution_command.py +204 -0
  17. snowflake/snowpark_connect/execute_plan/map_execution_root.py +173 -0
  18. snowflake/snowpark_connect/execute_plan/utils.py +183 -0
  19. snowflake/snowpark_connect/expression/__init__.py +3 -0
  20. snowflake/snowpark_connect/expression/literal.py +90 -0
  21. snowflake/snowpark_connect/expression/map_cast.py +343 -0
  22. snowflake/snowpark_connect/expression/map_expression.py +293 -0
  23. snowflake/snowpark_connect/expression/map_extension.py +104 -0
  24. snowflake/snowpark_connect/expression/map_sql_expression.py +633 -0
  25. snowflake/snowpark_connect/expression/map_udf.py +142 -0
  26. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +241 -0
  27. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +85 -0
  28. snowflake/snowpark_connect/expression/map_unresolved_function.py +9450 -0
  29. snowflake/snowpark_connect/expression/map_unresolved_star.py +218 -0
  30. snowflake/snowpark_connect/expression/map_update_fields.py +164 -0
  31. snowflake/snowpark_connect/expression/map_window_function.py +258 -0
  32. snowflake/snowpark_connect/expression/typer.py +125 -0
  33. snowflake/snowpark_connect/includes/__init__.py +0 -0
  34. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  35. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  36. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  37. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  39. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/hadoop-client-api-3.3.4.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/scala-compiler-2.12.18.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-kubernetes_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-mllib_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-streaming_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/python/__init__.py +21 -0
  99. snowflake/snowpark_connect/includes/python/pyspark/__init__.py +173 -0
  100. snowflake/snowpark_connect/includes/python/pyspark/_globals.py +71 -0
  101. snowflake/snowpark_connect/includes/python/pyspark/_typing.pyi +43 -0
  102. snowflake/snowpark_connect/includes/python/pyspark/accumulators.py +341 -0
  103. snowflake/snowpark_connect/includes/python/pyspark/broadcast.py +383 -0
  104. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/__init__.py +8 -0
  105. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle.py +948 -0
  106. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle_fast.py +844 -0
  107. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/compat.py +18 -0
  108. snowflake/snowpark_connect/includes/python/pyspark/conf.py +276 -0
  109. snowflake/snowpark_connect/includes/python/pyspark/context.py +2601 -0
  110. snowflake/snowpark_connect/includes/python/pyspark/daemon.py +218 -0
  111. snowflake/snowpark_connect/includes/python/pyspark/errors/__init__.py +70 -0
  112. snowflake/snowpark_connect/includes/python/pyspark/errors/error_classes.py +889 -0
  113. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/__init__.py +16 -0
  114. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/base.py +228 -0
  115. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/captured.py +307 -0
  116. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/connect.py +190 -0
  117. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/__init__.py +16 -0
  118. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/test_errors.py +60 -0
  119. snowflake/snowpark_connect/includes/python/pyspark/errors/utils.py +116 -0
  120. snowflake/snowpark_connect/includes/python/pyspark/files.py +165 -0
  121. snowflake/snowpark_connect/includes/python/pyspark/find_spark_home.py +95 -0
  122. snowflake/snowpark_connect/includes/python/pyspark/install.py +203 -0
  123. snowflake/snowpark_connect/includes/python/pyspark/instrumentation_utils.py +190 -0
  124. snowflake/snowpark_connect/includes/python/pyspark/java_gateway.py +248 -0
  125. snowflake/snowpark_connect/includes/python/pyspark/join.py +118 -0
  126. snowflake/snowpark_connect/includes/python/pyspark/ml/__init__.py +71 -0
  127. snowflake/snowpark_connect/includes/python/pyspark/ml/_typing.pyi +84 -0
  128. snowflake/snowpark_connect/includes/python/pyspark/ml/base.py +414 -0
  129. snowflake/snowpark_connect/includes/python/pyspark/ml/classification.py +4332 -0
  130. snowflake/snowpark_connect/includes/python/pyspark/ml/clustering.py +2188 -0
  131. snowflake/snowpark_connect/includes/python/pyspark/ml/common.py +146 -0
  132. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/__init__.py +44 -0
  133. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/base.py +346 -0
  134. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/classification.py +382 -0
  135. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/evaluation.py +291 -0
  136. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/feature.py +258 -0
  137. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/functions.py +77 -0
  138. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/io_utils.py +335 -0
  139. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/pipeline.py +262 -0
  140. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/summarizer.py +120 -0
  141. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/tuning.py +579 -0
  142. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/util.py +173 -0
  143. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/__init__.py +16 -0
  144. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/deepspeed_distributor.py +165 -0
  145. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py +306 -0
  146. snowflake/snowpark_connect/includes/python/pyspark/ml/dl_util.py +150 -0
  147. snowflake/snowpark_connect/includes/python/pyspark/ml/evaluation.py +1166 -0
  148. snowflake/snowpark_connect/includes/python/pyspark/ml/feature.py +7474 -0
  149. snowflake/snowpark_connect/includes/python/pyspark/ml/fpm.py +543 -0
  150. snowflake/snowpark_connect/includes/python/pyspark/ml/functions.py +842 -0
  151. snowflake/snowpark_connect/includes/python/pyspark/ml/image.py +271 -0
  152. snowflake/snowpark_connect/includes/python/pyspark/ml/linalg/__init__.py +1382 -0
  153. snowflake/snowpark_connect/includes/python/pyspark/ml/model_cache.py +55 -0
  154. snowflake/snowpark_connect/includes/python/pyspark/ml/param/__init__.py +602 -0
  155. snowflake/snowpark_connect/includes/python/pyspark/ml/param/_shared_params_code_gen.py +368 -0
  156. snowflake/snowpark_connect/includes/python/pyspark/ml/param/shared.py +878 -0
  157. snowflake/snowpark_connect/includes/python/pyspark/ml/pipeline.py +451 -0
  158. snowflake/snowpark_connect/includes/python/pyspark/ml/recommendation.py +748 -0
  159. snowflake/snowpark_connect/includes/python/pyspark/ml/regression.py +3335 -0
  160. snowflake/snowpark_connect/includes/python/pyspark/ml/stat.py +523 -0
  161. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/__init__.py +16 -0
  162. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_classification.py +53 -0
  163. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_evaluation.py +50 -0
  164. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_feature.py +43 -0
  165. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_function.py +114 -0
  166. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_pipeline.py +47 -0
  167. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_summarizer.py +43 -0
  168. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_tuning.py +46 -0
  169. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py +238 -0
  170. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py +194 -0
  171. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +156 -0
  172. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py +184 -0
  173. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py +78 -0
  174. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +292 -0
  175. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +50 -0
  176. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +152 -0
  177. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_algorithms.py +456 -0
  178. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_base.py +96 -0
  179. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_dl_util.py +186 -0
  180. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_evaluation.py +77 -0
  181. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_feature.py +401 -0
  182. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_functions.py +528 -0
  183. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_image.py +82 -0
  184. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_linalg.py +409 -0
  185. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_model_cache.py +55 -0
  186. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_param.py +441 -0
  187. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_persistence.py +546 -0
  188. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_pipeline.py +71 -0
  189. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_stat.py +52 -0
  190. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_training_summary.py +494 -0
  191. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_util.py +85 -0
  192. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_wrapper.py +138 -0
  193. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/__init__.py +16 -0
  194. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_basic.py +151 -0
  195. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_nested.py +97 -0
  196. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py +143 -0
  197. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tuning.py +551 -0
  198. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_basic.py +137 -0
  199. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_nested.py +96 -0
  200. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_pipeline.py +142 -0
  201. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/__init__.py +16 -0
  202. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/data.py +100 -0
  203. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/distributor.py +1133 -0
  204. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/log_communication.py +198 -0
  205. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/__init__.py +16 -0
  206. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_data_loader.py +137 -0
  207. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_distributor.py +561 -0
  208. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_log_communication.py +172 -0
  209. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/torch_run_process_wrapper.py +83 -0
  210. snowflake/snowpark_connect/includes/python/pyspark/ml/tree.py +434 -0
  211. snowflake/snowpark_connect/includes/python/pyspark/ml/tuning.py +1741 -0
  212. snowflake/snowpark_connect/includes/python/pyspark/ml/util.py +749 -0
  213. snowflake/snowpark_connect/includes/python/pyspark/ml/wrapper.py +465 -0
  214. snowflake/snowpark_connect/includes/python/pyspark/mllib/__init__.py +44 -0
  215. snowflake/snowpark_connect/includes/python/pyspark/mllib/_typing.pyi +33 -0
  216. snowflake/snowpark_connect/includes/python/pyspark/mllib/classification.py +989 -0
  217. snowflake/snowpark_connect/includes/python/pyspark/mllib/clustering.py +1318 -0
  218. snowflake/snowpark_connect/includes/python/pyspark/mllib/common.py +174 -0
  219. snowflake/snowpark_connect/includes/python/pyspark/mllib/evaluation.py +691 -0
  220. snowflake/snowpark_connect/includes/python/pyspark/mllib/feature.py +1085 -0
  221. snowflake/snowpark_connect/includes/python/pyspark/mllib/fpm.py +233 -0
  222. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/__init__.py +1653 -0
  223. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/distributed.py +1662 -0
  224. snowflake/snowpark_connect/includes/python/pyspark/mllib/random.py +698 -0
  225. snowflake/snowpark_connect/includes/python/pyspark/mllib/recommendation.py +389 -0
  226. snowflake/snowpark_connect/includes/python/pyspark/mllib/regression.py +1067 -0
  227. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/KernelDensity.py +59 -0
  228. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/__init__.py +34 -0
  229. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/_statistics.py +409 -0
  230. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/distribution.py +39 -0
  231. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/test.py +86 -0
  232. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/__init__.py +16 -0
  233. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_algorithms.py +353 -0
  234. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_feature.py +192 -0
  235. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_linalg.py +680 -0
  236. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_stat.py +206 -0
  237. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_streaming_algorithms.py +471 -0
  238. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_util.py +108 -0
  239. snowflake/snowpark_connect/includes/python/pyspark/mllib/tree.py +888 -0
  240. snowflake/snowpark_connect/includes/python/pyspark/mllib/util.py +659 -0
  241. snowflake/snowpark_connect/includes/python/pyspark/pandas/__init__.py +165 -0
  242. snowflake/snowpark_connect/includes/python/pyspark/pandas/_typing.py +52 -0
  243. snowflake/snowpark_connect/includes/python/pyspark/pandas/accessors.py +989 -0
  244. snowflake/snowpark_connect/includes/python/pyspark/pandas/base.py +1804 -0
  245. snowflake/snowpark_connect/includes/python/pyspark/pandas/categorical.py +822 -0
  246. snowflake/snowpark_connect/includes/python/pyspark/pandas/config.py +539 -0
  247. snowflake/snowpark_connect/includes/python/pyspark/pandas/correlation.py +262 -0
  248. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/__init__.py +16 -0
  249. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/base.py +519 -0
  250. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/binary_ops.py +98 -0
  251. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/boolean_ops.py +426 -0
  252. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/categorical_ops.py +141 -0
  253. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/complex_ops.py +145 -0
  254. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/date_ops.py +127 -0
  255. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/datetime_ops.py +171 -0
  256. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/null_ops.py +83 -0
  257. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/num_ops.py +588 -0
  258. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/string_ops.py +154 -0
  259. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/timedelta_ops.py +101 -0
  260. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/udt_ops.py +29 -0
  261. snowflake/snowpark_connect/includes/python/pyspark/pandas/datetimes.py +891 -0
  262. snowflake/snowpark_connect/includes/python/pyspark/pandas/exceptions.py +150 -0
  263. snowflake/snowpark_connect/includes/python/pyspark/pandas/extensions.py +388 -0
  264. snowflake/snowpark_connect/includes/python/pyspark/pandas/frame.py +13738 -0
  265. snowflake/snowpark_connect/includes/python/pyspark/pandas/generic.py +3560 -0
  266. snowflake/snowpark_connect/includes/python/pyspark/pandas/groupby.py +4448 -0
  267. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/__init__.py +21 -0
  268. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/base.py +2783 -0
  269. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/category.py +773 -0
  270. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/datetimes.py +843 -0
  271. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/multi.py +1323 -0
  272. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/numeric.py +210 -0
  273. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/timedelta.py +197 -0
  274. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexing.py +1862 -0
  275. snowflake/snowpark_connect/includes/python/pyspark/pandas/internal.py +1680 -0
  276. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/__init__.py +48 -0
  277. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/common.py +76 -0
  278. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/frame.py +63 -0
  279. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/general_functions.py +43 -0
  280. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/groupby.py +93 -0
  281. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/indexes.py +184 -0
  282. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/resample.py +101 -0
  283. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/scalars.py +29 -0
  284. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/series.py +69 -0
  285. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/window.py +168 -0
  286. snowflake/snowpark_connect/includes/python/pyspark/pandas/mlflow.py +238 -0
  287. snowflake/snowpark_connect/includes/python/pyspark/pandas/namespace.py +3807 -0
  288. snowflake/snowpark_connect/includes/python/pyspark/pandas/numpy_compat.py +260 -0
  289. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/__init__.py +17 -0
  290. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/core.py +1213 -0
  291. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/matplotlib.py +928 -0
  292. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/plotly.py +261 -0
  293. snowflake/snowpark_connect/includes/python/pyspark/pandas/resample.py +816 -0
  294. snowflake/snowpark_connect/includes/python/pyspark/pandas/series.py +7440 -0
  295. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_formatter.py +308 -0
  296. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_processor.py +394 -0
  297. snowflake/snowpark_connect/includes/python/pyspark/pandas/strings.py +2371 -0
  298. snowflake/snowpark_connect/includes/python/pyspark/pandas/supported_api_gen.py +378 -0
  299. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/__init__.py +16 -0
  300. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/__init__.py +16 -0
  301. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_any_all.py +177 -0
  302. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_apply_func.py +575 -0
  303. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_binary_ops.py +235 -0
  304. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_combine.py +653 -0
  305. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_compute.py +463 -0
  306. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_corrwith.py +86 -0
  307. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cov.py +151 -0
  308. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cumulative.py +139 -0
  309. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_describe.py +458 -0
  310. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_eval.py +86 -0
  311. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_melt.py +202 -0
  312. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_missing_data.py +520 -0
  313. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_pivot.py +361 -0
  314. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/__init__.py +16 -0
  315. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/__init__.py +16 -0
  316. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_any_all.py +40 -0
  317. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_apply_func.py +42 -0
  318. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_binary_ops.py +40 -0
  319. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_combine.py +37 -0
  320. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py +60 -0
  321. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_corrwith.py +40 -0
  322. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cov.py +40 -0
  323. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cumulative.py +90 -0
  324. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_describe.py +40 -0
  325. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_eval.py +40 -0
  326. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_melt.py +40 -0
  327. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_missing_data.py +42 -0
  328. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +37 -0
  329. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/__init__.py +16 -0
  330. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_base.py +36 -0
  331. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py +42 -0
  332. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py +47 -0
  333. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py +55 -0
  334. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_complex_ops.py +40 -0
  335. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py +47 -0
  336. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py +47 -0
  337. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py +42 -0
  338. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_arithmetic.py +43 -0
  339. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py +47 -0
  340. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_reverse.py +43 -0
  341. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py +47 -0
  342. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py +47 -0
  343. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py +40 -0
  344. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/testing_utils.py +226 -0
  345. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/__init__.py +16 -0
  346. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_align.py +39 -0
  347. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_basic_slow.py +55 -0
  348. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_cov_corrwith.py +39 -0
  349. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_frame.py +39 -0
  350. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_series.py +39 -0
  351. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_index.py +39 -0
  352. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_series.py +39 -0
  353. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_frame.py +43 -0
  354. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_series.py +43 -0
  355. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/__init__.py +16 -0
  356. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_attrs.py +40 -0
  357. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py +39 -0
  358. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_conversion.py +42 -0
  359. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reindexing.py +42 -0
  360. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +37 -0
  361. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_spark.py +40 -0
  362. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_take.py +42 -0
  363. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_time_series.py +48 -0
  364. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_truncate.py +40 -0
  365. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/__init__.py +16 -0
  366. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_aggregate.py +40 -0
  367. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_apply_func.py +41 -0
  368. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_cumulative.py +67 -0
  369. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_describe.py +40 -0
  370. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_groupby.py +55 -0
  371. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_head_tail.py +40 -0
  372. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_index.py +38 -0
  373. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py +55 -0
  374. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_split_apply.py +39 -0
  375. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +38 -0
  376. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/__init__.py +16 -0
  377. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_align.py +40 -0
  378. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py +50 -0
  379. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py +73 -0
  380. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_datetime.py +39 -0
  381. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing.py +40 -0
  382. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reindex.py +40 -0
  383. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_rename.py +40 -0
  384. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reset_index.py +48 -0
  385. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_timedelta.py +39 -0
  386. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/__init__.py +16 -0
  387. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/test_parity_io.py +40 -0
  388. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/__init__.py +16 -0
  389. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot.py +45 -0
  390. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_matplotlib.py +45 -0
  391. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_plotly.py +49 -0
  392. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot.py +37 -0
  393. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_matplotlib.py +53 -0
  394. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_plotly.py +45 -0
  395. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/__init__.py +16 -0
  396. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_all_any.py +38 -0
  397. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_arg_ops.py +37 -0
  398. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_of.py +37 -0
  399. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_type.py +38 -0
  400. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +37 -0
  401. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_conversion.py +40 -0
  402. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_cumulative.py +40 -0
  403. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_index.py +38 -0
  404. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_missing_data.py +40 -0
  405. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_series.py +37 -0
  406. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_sort.py +38 -0
  407. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_stat.py +38 -0
  408. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_categorical.py +66 -0
  409. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_config.py +37 -0
  410. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_csv.py +37 -0
  411. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_conversion.py +42 -0
  412. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_spark_io.py +39 -0
  413. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_default_index.py +49 -0
  414. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ewm.py +37 -0
  415. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_expanding.py +39 -0
  416. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_extension.py +49 -0
  417. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_frame_spark.py +53 -0
  418. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py +43 -0
  419. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexing.py +49 -0
  420. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexops_spark.py +39 -0
  421. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_internal.py +41 -0
  422. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_namespace.py +39 -0
  423. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_numpy_compat.py +60 -0
  424. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames.py +48 -0
  425. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby.py +39 -0
  426. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_expanding.py +44 -0
  427. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_rolling.py +84 -0
  428. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_repr.py +37 -0
  429. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_resample.py +45 -0
  430. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_reshape.py +39 -0
  431. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_rolling.py +39 -0
  432. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_scalars.py +37 -0
  433. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_conversion.py +39 -0
  434. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_datetime.py +39 -0
  435. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_string.py +39 -0
  436. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_spark_functions.py +39 -0
  437. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_sql.py +43 -0
  438. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_stats.py +37 -0
  439. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_typedef.py +36 -0
  440. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_utils.py +37 -0
  441. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_window.py +39 -0
  442. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/__init__.py +16 -0
  443. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_base.py +107 -0
  444. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +224 -0
  445. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +825 -0
  446. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +562 -0
  447. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +368 -0
  448. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +257 -0
  449. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +260 -0
  450. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +178 -0
  451. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_arithmetic.py +184 -0
  452. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +497 -0
  453. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_reverse.py +140 -0
  454. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +354 -0
  455. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +219 -0
  456. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +192 -0
  457. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +228 -0
  458. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/__init__.py +16 -0
  459. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_align.py +118 -0
  460. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_basic_slow.py +198 -0
  461. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_cov_corrwith.py +181 -0
  462. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_frame.py +103 -0
  463. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_series.py +141 -0
  464. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_index.py +109 -0
  465. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_series.py +136 -0
  466. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_frame.py +125 -0
  467. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_series.py +217 -0
  468. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/__init__.py +16 -0
  469. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_attrs.py +384 -0
  470. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_constructor.py +598 -0
  471. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_conversion.py +73 -0
  472. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reindexing.py +869 -0
  473. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reshaping.py +487 -0
  474. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_spark.py +309 -0
  475. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_take.py +156 -0
  476. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_time_series.py +149 -0
  477. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_truncate.py +163 -0
  478. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/__init__.py +16 -0
  479. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_aggregate.py +311 -0
  480. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_apply_func.py +524 -0
  481. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_cumulative.py +419 -0
  482. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_describe.py +144 -0
  483. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_groupby.py +979 -0
  484. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_head_tail.py +234 -0
  485. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_index.py +206 -0
  486. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_missing_data.py +421 -0
  487. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_split_apply.py +187 -0
  488. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_stat.py +397 -0
  489. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/__init__.py +16 -0
  490. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_align.py +100 -0
  491. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_base.py +2743 -0
  492. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_category.py +484 -0
  493. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_datetime.py +276 -0
  494. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_indexing.py +432 -0
  495. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reindex.py +310 -0
  496. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_rename.py +257 -0
  497. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reset_index.py +160 -0
  498. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_timedelta.py +128 -0
  499. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/__init__.py +16 -0
  500. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/test_io.py +137 -0
  501. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/__init__.py +16 -0
  502. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot.py +170 -0
  503. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +547 -0
  504. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +285 -0
  505. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot.py +106 -0
  506. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +409 -0
  507. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +247 -0
  508. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/__init__.py +16 -0
  509. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_all_any.py +105 -0
  510. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_arg_ops.py +197 -0
  511. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_of.py +137 -0
  512. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_type.py +227 -0
  513. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_compute.py +634 -0
  514. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_conversion.py +88 -0
  515. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_cumulative.py +139 -0
  516. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_index.py +475 -0
  517. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_missing_data.py +265 -0
  518. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_series.py +818 -0
  519. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_sort.py +162 -0
  520. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_stat.py +780 -0
  521. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_categorical.py +741 -0
  522. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_config.py +160 -0
  523. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_csv.py +453 -0
  524. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_conversion.py +281 -0
  525. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_spark_io.py +487 -0
  526. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_default_index.py +109 -0
  527. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ewm.py +434 -0
  528. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_expanding.py +253 -0
  529. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_extension.py +152 -0
  530. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_frame_spark.py +162 -0
  531. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_generic_functions.py +234 -0
  532. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexing.py +1339 -0
  533. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexops_spark.py +82 -0
  534. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_internal.py +124 -0
  535. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_namespace.py +638 -0
  536. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_numpy_compat.py +200 -0
  537. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +1355 -0
  538. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +655 -0
  539. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +113 -0
  540. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +118 -0
  541. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_repr.py +192 -0
  542. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_resample.py +346 -0
  543. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_reshape.py +495 -0
  544. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_rolling.py +263 -0
  545. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_scalars.py +59 -0
  546. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_conversion.py +85 -0
  547. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_datetime.py +364 -0
  548. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_string.py +362 -0
  549. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_spark_functions.py +46 -0
  550. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_sql.py +123 -0
  551. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_stats.py +581 -0
  552. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_typedef.py +447 -0
  553. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_utils.py +301 -0
  554. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_window.py +465 -0
  555. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/__init__.py +18 -0
  556. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/typehints.py +874 -0
  557. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/__init__.py +143 -0
  558. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/usage_logger.py +132 -0
  559. snowflake/snowpark_connect/includes/python/pyspark/pandas/utils.py +1063 -0
  560. snowflake/snowpark_connect/includes/python/pyspark/pandas/window.py +2702 -0
  561. snowflake/snowpark_connect/includes/python/pyspark/profiler.py +489 -0
  562. snowflake/snowpark_connect/includes/python/pyspark/py.typed +1 -0
  563. snowflake/snowpark_connect/includes/python/pyspark/python/pyspark/shell.py +123 -0
  564. snowflake/snowpark_connect/includes/python/pyspark/rdd.py +5518 -0
  565. snowflake/snowpark_connect/includes/python/pyspark/rddsampler.py +115 -0
  566. snowflake/snowpark_connect/includes/python/pyspark/resource/__init__.py +38 -0
  567. snowflake/snowpark_connect/includes/python/pyspark/resource/information.py +69 -0
  568. snowflake/snowpark_connect/includes/python/pyspark/resource/profile.py +317 -0
  569. snowflake/snowpark_connect/includes/python/pyspark/resource/requests.py +539 -0
  570. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/__init__.py +16 -0
  571. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/test_resources.py +83 -0
  572. snowflake/snowpark_connect/includes/python/pyspark/resultiterable.py +45 -0
  573. snowflake/snowpark_connect/includes/python/pyspark/serializers.py +681 -0
  574. snowflake/snowpark_connect/includes/python/pyspark/shell.py +123 -0
  575. snowflake/snowpark_connect/includes/python/pyspark/shuffle.py +854 -0
  576. snowflake/snowpark_connect/includes/python/pyspark/sql/__init__.py +75 -0
  577. snowflake/snowpark_connect/includes/python/pyspark/sql/_typing.pyi +80 -0
  578. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/__init__.py +18 -0
  579. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/functions.py +188 -0
  580. snowflake/snowpark_connect/includes/python/pyspark/sql/catalog.py +1270 -0
  581. snowflake/snowpark_connect/includes/python/pyspark/sql/column.py +1431 -0
  582. snowflake/snowpark_connect/includes/python/pyspark/sql/conf.py +99 -0
  583. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/__init__.py +18 -0
  584. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/_typing.py +90 -0
  585. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/__init__.py +18 -0
  586. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/functions.py +107 -0
  587. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/catalog.py +356 -0
  588. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/__init__.py +22 -0
  589. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/artifact.py +412 -0
  590. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/core.py +1689 -0
  591. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/reattach.py +340 -0
  592. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/column.py +514 -0
  593. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conf.py +128 -0
  594. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conversion.py +490 -0
  595. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/dataframe.py +2172 -0
  596. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/expressions.py +1056 -0
  597. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/functions.py +3937 -0
  598. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/group.py +418 -0
  599. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/plan.py +2289 -0
  600. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/__init__.py +25 -0
  601. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.py +203 -0
  602. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.pyi +2718 -0
  603. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2_grpc.py +423 -0
  604. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.py +109 -0
  605. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.pyi +1130 -0
  606. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.py +141 -0
  607. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.pyi +1766 -0
  608. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.py +47 -0
  609. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.pyi +123 -0
  610. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.py +53 -0
  611. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.pyi +112 -0
  612. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.py +107 -0
  613. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.pyi +1507 -0
  614. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.py +195 -0
  615. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.pyi +3613 -0
  616. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.py +95 -0
  617. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.pyi +980 -0
  618. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/__init__.py +18 -0
  619. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/functions.py +166 -0
  620. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/readwriter.py +861 -0
  621. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/session.py +952 -0
  622. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/__init__.py +22 -0
  623. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/query.py +295 -0
  624. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/readwriter.py +618 -0
  625. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/__init__.py +18 -0
  626. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +87 -0
  627. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +100 -0
  628. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/types.py +301 -0
  629. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udf.py +296 -0
  630. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udtf.py +200 -0
  631. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/utils.py +58 -0
  632. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/window.py +266 -0
  633. snowflake/snowpark_connect/includes/python/pyspark/sql/context.py +818 -0
  634. snowflake/snowpark_connect/includes/python/pyspark/sql/dataframe.py +5973 -0
  635. snowflake/snowpark_connect/includes/python/pyspark/sql/functions.py +15889 -0
  636. snowflake/snowpark_connect/includes/python/pyspark/sql/group.py +547 -0
  637. snowflake/snowpark_connect/includes/python/pyspark/sql/observation.py +152 -0
  638. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/__init__.py +21 -0
  639. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/__init__.pyi +344 -0
  640. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/__init__.pyi +17 -0
  641. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/frame.pyi +20 -0
  642. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/series.pyi +20 -0
  643. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/conversion.py +671 -0
  644. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.py +480 -0
  645. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.pyi +132 -0
  646. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/group_ops.py +523 -0
  647. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/map_ops.py +216 -0
  648. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/serializers.py +1019 -0
  649. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/typehints.py +172 -0
  650. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/types.py +972 -0
  651. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/utils.py +86 -0
  652. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/__init__.py +18 -0
  653. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/functions.py +334 -0
  654. snowflake/snowpark_connect/includes/python/pyspark/sql/readwriter.py +2159 -0
  655. snowflake/snowpark_connect/includes/python/pyspark/sql/session.py +2088 -0
  656. snowflake/snowpark_connect/includes/python/pyspark/sql/sql_formatter.py +84 -0
  657. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/__init__.py +21 -0
  658. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/listener.py +1050 -0
  659. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/query.py +746 -0
  660. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/readwriter.py +1652 -0
  661. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/state.py +288 -0
  662. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/__init__.py +16 -0
  663. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/__init__.py +16 -0
  664. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/__init__.py +16 -0
  665. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_artifact.py +420 -0
  666. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_client.py +358 -0
  667. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/__init__.py +16 -0
  668. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach.py +36 -0
  669. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +44 -0
  670. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +116 -0
  671. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +35 -0
  672. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py +3612 -0
  673. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_column.py +1042 -0
  674. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_function.py +2381 -0
  675. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_plan.py +1060 -0
  676. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow.py +163 -0
  677. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_map.py +38 -0
  678. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +48 -0
  679. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_catalog.py +36 -0
  680. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_column.py +55 -0
  681. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_conf.py +36 -0
  682. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_dataframe.py +96 -0
  683. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_datasources.py +44 -0
  684. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_errors.py +36 -0
  685. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_functions.py +59 -0
  686. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_group.py +36 -0
  687. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py +59 -0
  688. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py +74 -0
  689. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py +62 -0
  690. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_map.py +58 -0
  691. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +70 -0
  692. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py +50 -0
  693. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +68 -0
  694. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py +40 -0
  695. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_readwriter.py +46 -0
  696. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_serde.py +44 -0
  697. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_types.py +100 -0
  698. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udf.py +100 -0
  699. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udtf.py +163 -0
  700. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_session.py +181 -0
  701. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_utils.py +42 -0
  702. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/__init__.py +16 -0
  703. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +623 -0
  704. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +869 -0
  705. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +342 -0
  706. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_map.py +436 -0
  707. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf.py +363 -0
  708. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +592 -0
  709. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +1503 -0
  710. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +392 -0
  711. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py +375 -0
  712. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +411 -0
  713. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/__init__.py +16 -0
  714. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming.py +401 -0
  715. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +295 -0
  716. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +106 -0
  717. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_listener.py +558 -0
  718. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow.py +1346 -0
  719. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_map.py +182 -0
  720. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_python_udf.py +202 -0
  721. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_catalog.py +503 -0
  722. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_column.py +225 -0
  723. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_conf.py +83 -0
  724. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_context.py +201 -0
  725. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_dataframe.py +1931 -0
  726. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_datasources.py +256 -0
  727. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_errors.py +69 -0
  728. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_functions.py +1349 -0
  729. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_group.py +53 -0
  730. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_pandas_sqlmetrics.py +68 -0
  731. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_readwriter.py +283 -0
  732. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_serde.py +155 -0
  733. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_session.py +412 -0
  734. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_types.py +1581 -0
  735. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf.py +961 -0
  736. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf_profiler.py +165 -0
  737. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udtf.py +1456 -0
  738. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_utils.py +1686 -0
  739. snowflake/snowpark_connect/includes/python/pyspark/sql/types.py +2558 -0
  740. snowflake/snowpark_connect/includes/python/pyspark/sql/udf.py +714 -0
  741. snowflake/snowpark_connect/includes/python/pyspark/sql/udtf.py +325 -0
  742. snowflake/snowpark_connect/includes/python/pyspark/sql/utils.py +339 -0
  743. snowflake/snowpark_connect/includes/python/pyspark/sql/window.py +492 -0
  744. snowflake/snowpark_connect/includes/python/pyspark/statcounter.py +165 -0
  745. snowflake/snowpark_connect/includes/python/pyspark/status.py +112 -0
  746. snowflake/snowpark_connect/includes/python/pyspark/storagelevel.py +97 -0
  747. snowflake/snowpark_connect/includes/python/pyspark/streaming/__init__.py +22 -0
  748. snowflake/snowpark_connect/includes/python/pyspark/streaming/context.py +471 -0
  749. snowflake/snowpark_connect/includes/python/pyspark/streaming/dstream.py +933 -0
  750. snowflake/snowpark_connect/includes/python/pyspark/streaming/kinesis.py +205 -0
  751. snowflake/snowpark_connect/includes/python/pyspark/streaming/listener.py +83 -0
  752. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/__init__.py +16 -0
  753. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_context.py +184 -0
  754. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_dstream.py +706 -0
  755. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_kinesis.py +118 -0
  756. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_listener.py +160 -0
  757. snowflake/snowpark_connect/includes/python/pyspark/streaming/util.py +168 -0
  758. snowflake/snowpark_connect/includes/python/pyspark/taskcontext.py +502 -0
  759. snowflake/snowpark_connect/includes/python/pyspark/testing/__init__.py +21 -0
  760. snowflake/snowpark_connect/includes/python/pyspark/testing/connectutils.py +199 -0
  761. snowflake/snowpark_connect/includes/python/pyspark/testing/mllibutils.py +30 -0
  762. snowflake/snowpark_connect/includes/python/pyspark/testing/mlutils.py +275 -0
  763. snowflake/snowpark_connect/includes/python/pyspark/testing/objects.py +121 -0
  764. snowflake/snowpark_connect/includes/python/pyspark/testing/pandasutils.py +714 -0
  765. snowflake/snowpark_connect/includes/python/pyspark/testing/sqlutils.py +168 -0
  766. snowflake/snowpark_connect/includes/python/pyspark/testing/streamingutils.py +178 -0
  767. snowflake/snowpark_connect/includes/python/pyspark/testing/utils.py +636 -0
  768. snowflake/snowpark_connect/includes/python/pyspark/tests/__init__.py +16 -0
  769. snowflake/snowpark_connect/includes/python/pyspark/tests/test_appsubmit.py +306 -0
  770. snowflake/snowpark_connect/includes/python/pyspark/tests/test_broadcast.py +196 -0
  771. snowflake/snowpark_connect/includes/python/pyspark/tests/test_conf.py +44 -0
  772. snowflake/snowpark_connect/includes/python/pyspark/tests/test_context.py +346 -0
  773. snowflake/snowpark_connect/includes/python/pyspark/tests/test_daemon.py +89 -0
  774. snowflake/snowpark_connect/includes/python/pyspark/tests/test_install_spark.py +124 -0
  775. snowflake/snowpark_connect/includes/python/pyspark/tests/test_join.py +69 -0
  776. snowflake/snowpark_connect/includes/python/pyspark/tests/test_memory_profiler.py +167 -0
  777. snowflake/snowpark_connect/includes/python/pyspark/tests/test_pin_thread.py +194 -0
  778. snowflake/snowpark_connect/includes/python/pyspark/tests/test_profiler.py +168 -0
  779. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rdd.py +939 -0
  780. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddbarrier.py +52 -0
  781. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddsampler.py +66 -0
  782. snowflake/snowpark_connect/includes/python/pyspark/tests/test_readwrite.py +368 -0
  783. snowflake/snowpark_connect/includes/python/pyspark/tests/test_serializers.py +257 -0
  784. snowflake/snowpark_connect/includes/python/pyspark/tests/test_shuffle.py +267 -0
  785. snowflake/snowpark_connect/includes/python/pyspark/tests/test_stage_sched.py +153 -0
  786. snowflake/snowpark_connect/includes/python/pyspark/tests/test_statcounter.py +130 -0
  787. snowflake/snowpark_connect/includes/python/pyspark/tests/test_taskcontext.py +350 -0
  788. snowflake/snowpark_connect/includes/python/pyspark/tests/test_util.py +97 -0
  789. snowflake/snowpark_connect/includes/python/pyspark/tests/test_worker.py +271 -0
  790. snowflake/snowpark_connect/includes/python/pyspark/traceback_utils.py +81 -0
  791. snowflake/snowpark_connect/includes/python/pyspark/util.py +416 -0
  792. snowflake/snowpark_connect/includes/python/pyspark/version.py +19 -0
  793. snowflake/snowpark_connect/includes/python/pyspark/worker.py +1307 -0
  794. snowflake/snowpark_connect/includes/python/pyspark/worker_util.py +46 -0
  795. snowflake/snowpark_connect/proto/__init__.py +10 -0
  796. snowflake/snowpark_connect/proto/control_pb2.py +35 -0
  797. snowflake/snowpark_connect/proto/control_pb2.pyi +38 -0
  798. snowflake/snowpark_connect/proto/control_pb2_grpc.py +183 -0
  799. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +35 -0
  800. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +53 -0
  801. snowflake/snowpark_connect/proto/snowflake_rdd_pb2.pyi +39 -0
  802. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +47 -0
  803. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +111 -0
  804. snowflake/snowpark_connect/relation/__init__.py +3 -0
  805. snowflake/snowpark_connect/relation/catalogs/__init__.py +12 -0
  806. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +287 -0
  807. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +467 -0
  808. snowflake/snowpark_connect/relation/catalogs/utils.py +51 -0
  809. snowflake/snowpark_connect/relation/io_utils.py +76 -0
  810. snowflake/snowpark_connect/relation/map_aggregate.py +322 -0
  811. snowflake/snowpark_connect/relation/map_catalog.py +151 -0
  812. snowflake/snowpark_connect/relation/map_column_ops.py +1068 -0
  813. snowflake/snowpark_connect/relation/map_crosstab.py +48 -0
  814. snowflake/snowpark_connect/relation/map_extension.py +412 -0
  815. snowflake/snowpark_connect/relation/map_join.py +341 -0
  816. snowflake/snowpark_connect/relation/map_local_relation.py +326 -0
  817. snowflake/snowpark_connect/relation/map_map_partitions.py +146 -0
  818. snowflake/snowpark_connect/relation/map_relation.py +253 -0
  819. snowflake/snowpark_connect/relation/map_row_ops.py +716 -0
  820. snowflake/snowpark_connect/relation/map_sample_by.py +35 -0
  821. snowflake/snowpark_connect/relation/map_show_string.py +50 -0
  822. snowflake/snowpark_connect/relation/map_sql.py +1874 -0
  823. snowflake/snowpark_connect/relation/map_stats.py +324 -0
  824. snowflake/snowpark_connect/relation/map_subquery_alias.py +32 -0
  825. snowflake/snowpark_connect/relation/map_udtf.py +288 -0
  826. snowflake/snowpark_connect/relation/read/__init__.py +7 -0
  827. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +668 -0
  828. snowflake/snowpark_connect/relation/read/map_read.py +367 -0
  829. snowflake/snowpark_connect/relation/read/map_read_csv.py +142 -0
  830. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +108 -0
  831. snowflake/snowpark_connect/relation/read/map_read_json.py +344 -0
  832. snowflake/snowpark_connect/relation/read/map_read_parquet.py +194 -0
  833. snowflake/snowpark_connect/relation/read/map_read_socket.py +59 -0
  834. snowflake/snowpark_connect/relation/read/map_read_table.py +109 -0
  835. snowflake/snowpark_connect/relation/read/map_read_text.py +106 -0
  836. snowflake/snowpark_connect/relation/read/reader_config.py +399 -0
  837. snowflake/snowpark_connect/relation/read/utils.py +155 -0
  838. snowflake/snowpark_connect/relation/stage_locator.py +161 -0
  839. snowflake/snowpark_connect/relation/utils.py +219 -0
  840. snowflake/snowpark_connect/relation/write/__init__.py +3 -0
  841. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +339 -0
  842. snowflake/snowpark_connect/relation/write/map_write.py +436 -0
  843. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +48 -0
  844. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  845. snowflake/snowpark_connect/resources_initializer.py +75 -0
  846. snowflake/snowpark_connect/server.py +1136 -0
  847. snowflake/snowpark_connect/start_server.py +32 -0
  848. snowflake/snowpark_connect/tcm.py +8 -0
  849. snowflake/snowpark_connect/type_mapping.py +1003 -0
  850. snowflake/snowpark_connect/typed_column.py +94 -0
  851. snowflake/snowpark_connect/utils/__init__.py +3 -0
  852. snowflake/snowpark_connect/utils/artifacts.py +48 -0
  853. snowflake/snowpark_connect/utils/attribute_handling.py +72 -0
  854. snowflake/snowpark_connect/utils/cache.py +84 -0
  855. snowflake/snowpark_connect/utils/concurrent.py +124 -0
  856. snowflake/snowpark_connect/utils/context.py +390 -0
  857. snowflake/snowpark_connect/utils/describe_query_cache.py +231 -0
  858. snowflake/snowpark_connect/utils/interrupt.py +85 -0
  859. snowflake/snowpark_connect/utils/io_utils.py +35 -0
  860. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +117 -0
  861. snowflake/snowpark_connect/utils/profiling.py +47 -0
  862. snowflake/snowpark_connect/utils/session.py +180 -0
  863. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +38 -0
  864. snowflake/snowpark_connect/utils/telemetry.py +513 -0
  865. snowflake/snowpark_connect/utils/udf_cache.py +392 -0
  866. snowflake/snowpark_connect/utils/udf_helper.py +328 -0
  867. snowflake/snowpark_connect/utils/udf_utils.py +310 -0
  868. snowflake/snowpark_connect/utils/udtf_helper.py +420 -0
  869. snowflake/snowpark_connect/utils/udtf_utils.py +799 -0
  870. snowflake/snowpark_connect/utils/xxhash64.py +247 -0
  871. snowflake/snowpark_connect/version.py +6 -0
  872. snowpark_connect-0.20.2.data/scripts/snowpark-connect +71 -0
  873. snowpark_connect-0.20.2.data/scripts/snowpark-session +11 -0
  874. snowpark_connect-0.20.2.data/scripts/snowpark-submit +354 -0
  875. snowpark_connect-0.20.2.dist-info/METADATA +37 -0
  876. snowpark_connect-0.20.2.dist-info/RECORD +879 -0
  877. snowpark_connect-0.20.2.dist-info/WHEEL +5 -0
  878. snowpark_connect-0.20.2.dist-info/licenses/LICENSE.txt +202 -0
  879. snowpark_connect-0.20.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2188 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one or more
3
+ # contributor license agreements. See the NOTICE file distributed with
4
+ # this work for additional information regarding copyright ownership.
5
+ # The ASF licenses this file to You under the Apache License, Version 2.0
6
+ # (the "License"); you may not use this file except in compliance with
7
+ # the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ import sys
19
+ import warnings
20
+
21
+ from typing import Any, Dict, List, Optional, TYPE_CHECKING
22
+
23
+ import numpy as np
24
+
25
+ from pyspark import since, keyword_only
26
+ from pyspark.ml.param.shared import (
27
+ HasMaxIter,
28
+ HasFeaturesCol,
29
+ HasSeed,
30
+ HasPredictionCol,
31
+ HasAggregationDepth,
32
+ HasWeightCol,
33
+ HasTol,
34
+ HasProbabilityCol,
35
+ HasDistanceMeasure,
36
+ HasCheckpointInterval,
37
+ HasSolver,
38
+ HasMaxBlockSizeInMB,
39
+ Param,
40
+ Params,
41
+ TypeConverters,
42
+ )
43
+ from pyspark.ml.util import (
44
+ JavaMLWritable,
45
+ JavaMLReadable,
46
+ GeneralJavaMLWritable,
47
+ HasTrainingSummary,
48
+ SparkContext,
49
+ )
50
+ from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper
51
+ from pyspark.ml.common import inherit_doc, _java2py
52
+ from pyspark.ml.stat import MultivariateGaussian
53
+ from pyspark.sql import DataFrame
54
+ from pyspark.ml.linalg import Vector, Matrix
55
+
56
+ if TYPE_CHECKING:
57
+ from pyspark.ml._typing import M
58
+ from py4j.java_gateway import JavaObject
59
+
60
+
61
+ __all__ = [
62
+ "BisectingKMeans",
63
+ "BisectingKMeansModel",
64
+ "BisectingKMeansSummary",
65
+ "KMeans",
66
+ "KMeansModel",
67
+ "KMeansSummary",
68
+ "GaussianMixture",
69
+ "GaussianMixtureModel",
70
+ "GaussianMixtureSummary",
71
+ "LDA",
72
+ "LDAModel",
73
+ "LocalLDAModel",
74
+ "DistributedLDAModel",
75
+ "PowerIterationClustering",
76
+ ]
77
+
78
+
79
class ClusteringSummary(JavaWrapper):
    """
    Summary of the clustering results obtained for a given model.

    Each property below is read through ``_call_java`` using the accessor
    of the same name.

    .. versionadded:: 2.1.0
    """

    @property
    @since("2.1.0")
    def predictionCol(self) -> str:
        """
        Name of the column in `predictions` that holds the predicted clusters.
        """
        return self._call_java("predictionCol")

    @property
    @since("2.1.0")
    def predictions(self) -> DataFrame:
        """
        DataFrame that the model's `transform` method produced.
        """
        return self._call_java("predictions")

    @property
    @since("2.1.0")
    def featuresCol(self) -> str:
        """
        Name of the column in `predictions` that holds the features.
        """
        return self._call_java("featuresCol")

    @property
    @since("2.1.0")
    def k(self) -> int:
        """
        Number of clusters the model was trained with.
        """
        return self._call_java("k")

    @property
    @since("2.1.0")
    def cluster(self) -> DataFrame:
        """
        DataFrame of predicted cluster centers, one per training data point.
        """
        return self._call_java("cluster")

    @property
    @since("2.1.0")
    def clusterSizes(self) -> List[int]:
        """
        Size of each cluster, i.e. the number of data points it contains.
        """
        return self._call_java("clusterSizes")

    @property
    @since("2.4.0")
    def numIter(self) -> int:
        """
        Number of iterations.
        """
        return self._call_java("numIter")
141
+
142
+
143
@inherit_doc
class _GaussianMixtureParams(
    HasMaxIter,
    HasFeaturesCol,
    HasSeed,
    HasPredictionCol,
    HasProbabilityCol,
    HasTol,
    HasAggregationDepth,
    HasWeightCol,
):
    """
    Shared params of :py:class:`GaussianMixture` and
    :py:class:`GaussianMixtureModel`.

    .. versionadded:: 3.0.0
    """

    # Number of mixture components; values are converted with
    # ``TypeConverters.toInt``.
    k: Param[int] = Param(
        Params._dummy(),
        "k",
        "Number of independent Gaussians in the mixture model. Must be > 1.",
        typeConverter=TypeConverters.toInt,
    )

    def __init__(self, *args: Any):
        """
        Forward ``args`` to the parent initializer, then register the
        default values for ``k``, ``tol``, ``maxIter`` and
        ``aggregationDepth``.
        """
        super().__init__(*args)
        defaults: Dict[str, Any] = {
            "k": 2,
            "tol": 0.01,
            "maxIter": 100,
            "aggregationDepth": 2,
        }
        self._setDefault(**defaults)

    @since("2.0.0")
    def getK(self) -> int:
        """
        Gets the value of :py:attr:`k`, as returned by ``getOrDefault``.
        """
        return self.getOrDefault(self.k)
177
+
178
+
179
class GaussianMixtureModel(
    JavaModel,
    _GaussianMixtureParams,
    JavaMLWritable,
    JavaMLReadable["GaussianMixtureModel"],
    HasTrainingSummary["GaussianMixtureSummary"],
):
    """
    Model produced by fitting a GaussianMixture.

    .. versionadded:: 2.0.0
    """

    @since("3.0.0")
    def setFeaturesCol(self, value: str) -> "GaussianMixtureModel":
        """
        Sets :py:attr:`featuresCol` to ``value``.
        """
        return self._set(featuresCol=value)

    @since("3.0.0")
    def setPredictionCol(self, value: str) -> "GaussianMixtureModel":
        """
        Sets :py:attr:`predictionCol` to ``value``.
        """
        return self._set(predictionCol=value)

    @since("3.0.0")
    def setProbabilityCol(self, value: str) -> "GaussianMixtureModel":
        """
        Sets :py:attr:`probabilityCol` to ``value``.
        """
        return self._set(probabilityCol=value)

    @property
    @since("2.0.0")
    def weights(self) -> List[float]:
        """
        Weight of each Gaussian distribution in the mixture.
        The weights form a multinomial probability distribution over the
        k Gaussians: weights[i] is the weight of Gaussian i, and all
        weights sum to 1.
        """
        return self._call_java("weights")

    @property
    @since("3.0.0")
    def gaussians(self) -> List[MultivariateGaussian]:
        """
        List of :py:class:`MultivariateGaussian`, where gaussians[i] is the
        Multivariate Gaussian (Normal) Distribution for Gaussian i.
        """
        sc = SparkContext._active_spark_context
        assert sc is not None
        assert self._java_obj is not None

        # Convert the mean and covariance of every Java-side Gaussian into
        # Python objects before wrapping them.
        distributions: List[MultivariateGaussian] = []
        for jgaussian in self._java_obj.gaussians():
            mean = _java2py(sc, jgaussian.mean())
            cov = _java2py(sc, jgaussian.cov())
            distributions.append(MultivariateGaussian(mean, cov))
        return distributions

    @property
    @since("2.0.0")
    def gaussiansDF(self) -> DataFrame:
        """
        Gaussian distributions as a DataFrame, one row per distribution.
        The DataFrame has two columns: mean (Vector) and cov (Matrix).
        """
        return self._call_java("gaussiansDF")

    @property
    @since("2.1.0")
    def summary(self) -> "GaussianMixtureSummary":
        """
        Summary (cluster assignments, cluster sizes) of the model trained on
        the training set. Raises a RuntimeError when ``hasSummary`` is false.
        """
        if not self.hasSummary:
            raise RuntimeError(
                f"No training summary available for this {self.__class__.__name__}"
            )
        return GaussianMixtureSummary(super().summary)

    @since("3.0.0")
    def predict(self, value: Vector) -> int:
        """
        Predicts the label for the given features.
        """
        return self._call_java("predict", value)

    @since("3.0.0")
    def predictProbability(self, value: Vector) -> Vector:
        """
        Predicts the probability for the given features.
        """
        return self._call_java("predictProbability", value)
276
+
277
+
278
+ @inherit_doc
279
+ class GaussianMixture(
280
+ JavaEstimator[GaussianMixtureModel],
281
+ _GaussianMixtureParams,
282
+ JavaMLWritable,
283
+ JavaMLReadable["GaussianMixture"],
284
+ ):
285
+ """
286
+ GaussianMixture clustering.
287
+ This class performs expectation maximization for multivariate Gaussian
288
+ Mixture Models (GMMs). A GMM represents a composite distribution of
289
+ independent Gaussian distributions with associated "mixing" weights
290
+ specifying each's contribution to the composite.
291
+
292
+ Given a set of sample points, this class will maximize the log-likelihood
293
+ for a mixture of k Gaussians, iterating until the log-likelihood changes by
294
+ less than convergenceTol, or until it has reached the max number of iterations.
295
+ While this process is generally guaranteed to converge, it is not guaranteed
296
+ to find a global optimum.
297
+
298
+ .. versionadded:: 2.0.0
299
+
300
+ Notes
301
+ -----
302
+ For high-dimensional data (with many features), this algorithm may perform poorly.
303
+ This is due to high-dimensional data (a) making it difficult to cluster at all
304
+ (based on statistical/theoretical arguments) and (b) numerical issues with
305
+ Gaussian distributions.
306
+
307
+ Examples
308
+ --------
309
+ >>> from pyspark.ml.linalg import Vectors
310
+
311
+ >>> data = [(Vectors.dense([-0.1, -0.05 ]),),
312
+ ... (Vectors.dense([-0.01, -0.1]),),
313
+ ... (Vectors.dense([0.9, 0.8]),),
314
+ ... (Vectors.dense([0.75, 0.935]),),
315
+ ... (Vectors.dense([-0.83, -0.68]),),
316
+ ... (Vectors.dense([-0.91, -0.76]),)]
317
+ >>> df = spark.createDataFrame(data, ["features"])
318
+ >>> gm = GaussianMixture(k=3, tol=0.0001, seed=10)
319
+ >>> gm.getMaxIter()
320
+ 100
321
+ >>> gm.setMaxIter(30)
322
+ GaussianMixture...
323
+ >>> gm.getMaxIter()
324
+ 30
325
+ >>> model = gm.fit(df)
326
+ >>> model.getAggregationDepth()
327
+ 2
328
+ >>> model.getFeaturesCol()
329
+ 'features'
330
+ >>> model.setPredictionCol("newPrediction")
331
+ GaussianMixtureModel...
332
+ >>> model.predict(df.head().features)
333
+ 2
334
+ >>> model.predictProbability(df.head().features)
335
+ DenseVector([0.0, 0.0, 1.0])
336
+ >>> model.hasSummary
337
+ True
338
+ >>> summary = model.summary
339
+ >>> summary.k
340
+ 3
341
+ >>> summary.clusterSizes
342
+ [2, 2, 2]
343
+ >>> weights = model.weights
344
+ >>> len(weights)
345
+ 3
346
+ >>> gaussians = model.gaussians
347
+ >>> len(gaussians)
348
+ 3
349
+ >>> gaussians[0].mean
350
+ DenseVector([0.825, 0.8675])
351
+ >>> gaussians[0].cov
352
+ DenseMatrix(2, 2, [0.0056, -0.0051, -0.0051, 0.0046], 0)
353
+ >>> gaussians[1].mean
354
+ DenseVector([-0.87, -0.72])
355
+ >>> gaussians[1].cov
356
+ DenseMatrix(2, 2, [0.0016, 0.0016, 0.0016, 0.0016], 0)
357
+ >>> gaussians[2].mean
358
+ DenseVector([-0.055, -0.075])
359
+ >>> gaussians[2].cov
360
+ DenseMatrix(2, 2, [0.002, -0.0011, -0.0011, 0.0006], 0)
361
+ >>> model.gaussiansDF.select("mean").head()
362
+ Row(mean=DenseVector([0.825, 0.8675]))
363
+ >>> model.gaussiansDF.select("cov").head()
364
+ Row(cov=DenseMatrix(2, 2, [0.0056, -0.0051, -0.0051, 0.0046], False))
365
+ >>> transformed = model.transform(df).select("features", "newPrediction")
366
+ >>> rows = transformed.collect()
367
+ >>> rows[4].newPrediction == rows[5].newPrediction
368
+ True
369
+ >>> rows[2].newPrediction == rows[3].newPrediction
370
+ True
371
+ >>> gmm_path = temp_path + "/gmm"
372
+ >>> gm.save(gmm_path)
373
+ >>> gm2 = GaussianMixture.load(gmm_path)
374
+ >>> gm2.getK()
375
+ 3
376
+ >>> model_path = temp_path + "/gmm_model"
377
+ >>> model.save(model_path)
378
+ >>> model2 = GaussianMixtureModel.load(model_path)
379
+ >>> model2.hasSummary
380
+ False
381
+ >>> model2.weights == model.weights
382
+ True
383
+ >>> model2.gaussians[0].mean == model.gaussians[0].mean
384
+ True
385
+ >>> model2.gaussians[0].cov == model.gaussians[0].cov
386
+ True
387
+ >>> model2.gaussians[1].mean == model.gaussians[1].mean
388
+ True
389
+ >>> model2.gaussians[1].cov == model.gaussians[1].cov
390
+ True
391
+ >>> model2.gaussians[2].mean == model.gaussians[2].mean
392
+ True
393
+ >>> model2.gaussians[2].cov == model.gaussians[2].cov
394
+ True
395
+ >>> model2.gaussiansDF.select("mean").head()
396
+ Row(mean=DenseVector([0.825, 0.8675]))
397
+ >>> model2.gaussiansDF.select("cov").head()
398
+ Row(cov=DenseMatrix(2, 2, [0.0056, -0.0051, -0.0051, 0.0046], False))
399
+ >>> model.transform(df).take(1) == model2.transform(df).take(1)
400
+ True
401
+ >>> gm2.setWeightCol("weight")
402
+ GaussianMixture...
403
+ """
404
+
405
+ _input_kwargs: Dict[str, Any]
406
+
407
+ @keyword_only
408
+ def __init__(
409
+ self,
410
+ *,
411
+ featuresCol: str = "features",
412
+ predictionCol: str = "prediction",
413
+ k: int = 2,
414
+ probabilityCol: str = "probability",
415
+ tol: float = 0.01,
416
+ maxIter: int = 100,
417
+ seed: Optional[int] = None,
418
+ aggregationDepth: int = 2,
419
+ weightCol: Optional[str] = None,
420
+ ):
421
+ """
422
+ __init__(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
423
+ probabilityCol="probability", tol=0.01, maxIter=100, seed=None, \
424
+ aggregationDepth=2, weightCol=None)
425
+ """
426
+ super(GaussianMixture, self).__init__()
427
+ self._java_obj = self._new_java_obj(
428
+ "org.apache.spark.ml.clustering.GaussianMixture", self.uid
429
+ )
430
+ kwargs = self._input_kwargs
431
+ self.setParams(**kwargs)
432
+
433
def _create_model(self, java_model: "JavaObject") -> "GaussianMixtureModel":
    """Wrap ``java_model`` in a :py:class:`GaussianMixtureModel` and return it."""
    wrapped_model = GaussianMixtureModel(java_model)
    return wrapped_model
435
+
436
@keyword_only
@since("2.0.0")
def setParams(
    self,
    *,
    featuresCol: str = "features",
    predictionCol: str = "prediction",
    k: int = 2,
    probabilityCol: str = "probability",
    tol: float = 0.01,
    maxIter: int = 100,
    seed: Optional[int] = None,
    aggregationDepth: int = 2,
    weightCol: Optional[str] = None,
) -> "GaussianMixture":
    """
    setParams(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
              probabilityCol="probability", tol=0.01, maxIter=100, seed=None, \
              aggregationDepth=2, weightCol=None)

    Sets params for GaussianMixture.
    """
    # Only the values captured in _input_kwargs are forwarded to _set; the
    # named parameters above are not read directly.
    return self._set(**self._input_kwargs)
460
+
461
@since("2.0.0")
def setK(self, value: int) -> "GaussianMixture":
    """
    Sets the value of :py:attr:`k`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(k=value)
467
+
468
@since("2.0.0")
def setMaxIter(self, value: int) -> "GaussianMixture":
    """
    Sets the value of :py:attr:`maxIter`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(maxIter=value)
474
+
475
@since("2.0.0")
def setFeaturesCol(self, value: str) -> "GaussianMixture":
    """
    Sets the value of :py:attr:`featuresCol`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(featuresCol=value)
481
+
482
@since("2.0.0")
def setPredictionCol(self, value: str) -> "GaussianMixture":
    """
    Sets the value of :py:attr:`predictionCol`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(predictionCol=value)
488
+
489
@since("2.0.0")
def setProbabilityCol(self, value: str) -> "GaussianMixture":
    """
    Sets the value of :py:attr:`probabilityCol`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(probabilityCol=value)
495
+
496
@since("3.0.0")
def setWeightCol(self, value: str) -> "GaussianMixture":
    """
    Sets the value of :py:attr:`weightCol`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(weightCol=value)
502
+
503
@since("2.0.0")
def setSeed(self, value: int) -> "GaussianMixture":
    """
    Sets the value of :py:attr:`seed`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(seed=value)
509
+
510
@since("2.0.0")
def setTol(self, value: float) -> "GaussianMixture":
    """
    Sets the value of :py:attr:`tol`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(tol=value)
516
+
517
@since("3.0.0")
def setAggregationDepth(self, value: int) -> "GaussianMixture":
    """
    Sets the value of :py:attr:`aggregationDepth`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(aggregationDepth=value)
523
+
524
+
525
class GaussianMixtureSummary(ClusteringSummary):
    """
    Gaussian mixture clustering results for a given model.

    Each property below is fetched from the wrapped Java object by name.

    .. versionadded:: 2.1.0
    """

    @property
    @since("2.1.0")
    def probabilityCol(self) -> str:
        """
        Name for column of predicted probability of each cluster in `predictions`.
        """
        column_name = self._call_java("probabilityCol")
        return column_name

    @property
    @since("2.1.0")
    def probability(self) -> DataFrame:
        """
        DataFrame of probabilities of each cluster for each training data point.
        """
        probabilities = self._call_java("probability")
        return probabilities

    @property
    @since("2.2.0")
    def logLikelihood(self) -> float:
        """
        Total log-likelihood for this model on the given data.
        """
        log_likelihood = self._call_java("logLikelihood")
        return log_likelihood
555
+
556
+
557
class KMeansSummary(ClusteringSummary):
    """
    Summary of KMeans.

    .. versionadded:: 2.1.0
    """

    @property
    @since("2.4.0")
    def trainingCost(self) -> float:
        """
        K-means cost (sum of squared distances to the nearest centroid for all points in the
        training dataset). This is equivalent to sklearn's inertia.
        """
        cost = self._call_java("trainingCost")
        return cost
572
+
573
+
574
@inherit_doc
class _KMeansParams(
    HasMaxIter,
    HasFeaturesCol,
    HasSeed,
    HasPredictionCol,
    HasTol,
    HasDistanceMeasure,
    HasWeightCol,
    HasSolver,
    HasMaxBlockSizeInMB,
):
    """
    Params for :py:class:`KMeans` and :py:class:`KMeansModel`.

    Declares the ``k``, ``initMode``, ``initSteps`` and ``solver`` params and
    registers the defaults shared by the estimator and its model.

    .. versionadded:: 3.0.0
    """

    k: Param[int] = Param(
        Params._dummy(),
        "k",
        "The number of clusters to create. Must be > 1.",
        typeConverter=TypeConverters.toInt,
    )
    initMode: Param[str] = Param(
        Params._dummy(),
        "initMode",
        (
            'The initialization algorithm. This can be either "random" to '
            'choose random points as initial cluster centers, or "k-means||" '
            "to use a parallel variant of k-means++"
        ),
        typeConverter=TypeConverters.toString,
    )
    initSteps: Param[int] = Param(
        Params._dummy(),
        "initSteps",
        "The number of steps for k-means|| initialization mode. Must be > 0.",
        typeConverter=TypeConverters.toInt,
    )
    # NOTE(review): a ``solver`` param is also implied by the HasSolver base;
    # this declaration presumably overrides it with KMeans-specific help text.
    solver: Param[str] = Param(
        Params._dummy(),
        "solver",
        "The solver algorithm for optimization. Supported options: auto, row, block.",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self, *args: Any):
        """Initialise the mixin chain, then register this class's default values."""
        super().__init__(*args)
        default_values: Dict[str, Any] = {
            "k": 2,
            "initMode": "k-means||",
            "initSteps": 2,
            "tol": 1e-4,
            "maxIter": 20,
            "distanceMeasure": "euclidean",
            "solver": "auto",
            "maxBlockSizeInMB": 0.0,
        }
        self._setDefault(**default_values)

    @since("1.5.0")
    def getK(self) -> int:
        """
        Gets the value of `k`
        """
        k_param = self.k
        return self.getOrDefault(k_param)

    @since("1.5.0")
    def getInitMode(self) -> str:
        """
        Gets the value of `initMode`
        """
        init_mode_param = self.initMode
        return self.getOrDefault(init_mode_param)

    @since("1.5.0")
    def getInitSteps(self) -> int:
        """
        Gets the value of `initSteps`
        """
        init_steps_param = self.initSteps
        return self.getOrDefault(init_steps_param)
652
+
653
+
654
class KMeansModel(
    JavaModel,
    _KMeansParams,
    GeneralJavaMLWritable,
    JavaMLReadable["KMeansModel"],
    HasTrainingSummary["KMeansSummary"],
):
    """
    Model fitted by KMeans.

    .. versionadded:: 1.5.0
    """

    @since("3.0.0")
    def setFeaturesCol(self, value: str) -> "KMeansModel":
        """
        Sets the value of :py:attr:`featuresCol`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(featuresCol=value)

    @since("3.0.0")
    def setPredictionCol(self, value: str) -> "KMeansModel":
        """
        Sets the value of :py:attr:`predictionCol`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(predictionCol=value)

    @since("1.5.0")
    def clusterCenters(self) -> List[np.ndarray]:
        """Get the cluster centers, represented as a list of NumPy arrays."""
        centers: List[np.ndarray] = []
        # Each element returned by the Java call is converted with toArray().
        for center in self._call_java("clusterCenters"):
            centers.append(center.toArray())
        return centers

    @property
    @since("2.1.0")
    def summary(self) -> KMeansSummary:
        """
        Gets summary (cluster assignments, cluster sizes) of the model trained on the
        training set. An exception is thrown if no summary exists.
        """
        # Guard first: without a summary there is nothing to wrap.
        if not self.hasSummary:
            raise RuntimeError(
                "No training summary available for this %s" % self.__class__.__name__
            )
        return KMeansSummary(super().summary)

    @since("3.0.0")
    def predict(self, value: Vector) -> int:
        """
        Predict label for the given features.
        """
        label = self._call_java("predict", value)
        return label
706
+
707
+
708
@inherit_doc
class KMeans(JavaEstimator[KMeansModel], _KMeansParams, JavaMLWritable, JavaMLReadable["KMeans"]):
    """
    K-means clustering with a k-means++ like initialization mode
    (the k-means|| algorithm by Bahmani et al).

    .. versionadded:: 1.5.0

    Examples
    --------
    >>> from pyspark.ml.linalg import Vectors
    >>> data = [(Vectors.dense([0.0, 0.0]), 2.0), (Vectors.dense([1.0, 1.0]), 2.0),
    ...         (Vectors.dense([9.0, 8.0]), 2.0), (Vectors.dense([8.0, 9.0]), 2.0)]
    >>> df = spark.createDataFrame(data, ["features", "weighCol"])
    >>> kmeans = KMeans(k=2)
    >>> kmeans.setSeed(1)
    KMeans...
    >>> kmeans.setWeightCol("weighCol")
    KMeans...
    >>> kmeans.setMaxIter(10)
    KMeans...
    >>> kmeans.getMaxIter()
    10
    >>> kmeans.clear(kmeans.maxIter)
    >>> kmeans.getSolver()
    'auto'
    >>> model = kmeans.fit(df)
    >>> model.getMaxBlockSizeInMB()
    0.0
    >>> model.getDistanceMeasure()
    'euclidean'
    >>> model.setPredictionCol("newPrediction")
    KMeansModel...
    >>> model.predict(df.head().features)
    0
    >>> centers = model.clusterCenters()
    >>> len(centers)
    2
    >>> transformed = model.transform(df).select("features", "newPrediction")
    >>> rows = transformed.collect()
    >>> rows[0].newPrediction == rows[1].newPrediction
    True
    >>> rows[2].newPrediction == rows[3].newPrediction
    True
    >>> model.hasSummary
    True
    >>> summary = model.summary
    >>> summary.k
    2
    >>> summary.clusterSizes
    [2, 2]
    >>> summary.trainingCost
    4.0
    >>> kmeans_path = temp_path + "/kmeans"
    >>> kmeans.save(kmeans_path)
    >>> kmeans2 = KMeans.load(kmeans_path)
    >>> kmeans2.getK()
    2
    >>> model_path = temp_path + "/kmeans_model"
    >>> model.save(model_path)
    >>> model2 = KMeansModel.load(model_path)
    >>> model2.hasSummary
    False
    >>> model.clusterCenters()[0] == model2.clusterCenters()[0]
    array([ True,  True], dtype=bool)
    >>> model.clusterCenters()[1] == model2.clusterCenters()[1]
    array([ True,  True], dtype=bool)
    >>> model.transform(df).take(1) == model2.transform(df).take(1)
    True
    """

    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        featuresCol: str = "features",
        predictionCol: str = "prediction",
        k: int = 2,
        initMode: str = "k-means||",
        initSteps: int = 2,
        tol: float = 1e-4,
        maxIter: int = 20,
        seed: Optional[int] = None,
        distanceMeasure: str = "euclidean",
        weightCol: Optional[str] = None,
        solver: str = "auto",
        maxBlockSizeInMB: float = 0.0,
    ):
        """
        __init__(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
                 initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \
                 distanceMeasure="euclidean", weightCol=None, solver="auto", \
                 maxBlockSizeInMB=0.0)
        """
        super().__init__()
        # Build the backing Java object for this estimator, keyed by our uid.
        java_class_name = "org.apache.spark.ml.clustering.KMeans"
        self._java_obj = self._new_java_obj(java_class_name, self.uid)
        # NOTE(review): _input_kwargs is presumably filled in by @keyword_only
        # with the keyword arguments the caller passed -- confirm in the decorator.
        self.setParams(**self._input_kwargs)

    def _create_model(self, java_model: "JavaObject") -> KMeansModel:
        """Wrap ``java_model`` in a :py:class:`KMeansModel` and return it."""
        wrapped_model = KMeansModel(java_model)
        return wrapped_model

    @keyword_only
    @since("1.5.0")
    def setParams(
        self,
        *,
        featuresCol: str = "features",
        predictionCol: str = "prediction",
        k: int = 2,
        initMode: str = "k-means||",
        initSteps: int = 2,
        tol: float = 1e-4,
        maxIter: int = 20,
        seed: Optional[int] = None,
        distanceMeasure: str = "euclidean",
        weightCol: Optional[str] = None,
        solver: str = "auto",
        maxBlockSizeInMB: float = 0.0,
    ) -> "KMeans":
        """
        setParams(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
                  initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \
                  distanceMeasure="euclidean", weightCol=None, solver="auto", \
                  maxBlockSizeInMB=0.0)

        Sets params for KMeans.
        """
        # Only the values captured in _input_kwargs are forwarded to _set.
        return self._set(**self._input_kwargs)

    @since("1.5.0")
    def setK(self, value: int) -> "KMeans":
        """
        Sets the value of :py:attr:`k`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(k=value)

    @since("1.5.0")
    def setInitMode(self, value: str) -> "KMeans":
        """
        Sets the value of :py:attr:`initMode`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(initMode=value)

    @since("1.5.0")
    def setInitSteps(self, value: int) -> "KMeans":
        """
        Sets the value of :py:attr:`initSteps`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(initSteps=value)

    @since("2.4.0")
    def setDistanceMeasure(self, value: str) -> "KMeans":
        """
        Sets the value of :py:attr:`distanceMeasure`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(distanceMeasure=value)

    @since("1.5.0")
    def setMaxIter(self, value: int) -> "KMeans":
        """
        Sets the value of :py:attr:`maxIter`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(maxIter=value)

    @since("1.5.0")
    def setFeaturesCol(self, value: str) -> "KMeans":
        """
        Sets the value of :py:attr:`featuresCol`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(featuresCol=value)

    @since("1.5.0")
    def setPredictionCol(self, value: str) -> "KMeans":
        """
        Sets the value of :py:attr:`predictionCol`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(predictionCol=value)

    @since("1.5.0")
    def setSeed(self, value: int) -> "KMeans":
        """
        Sets the value of :py:attr:`seed`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(seed=value)

    @since("1.5.0")
    def setTol(self, value: float) -> "KMeans":
        """
        Sets the value of :py:attr:`tol`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(tol=value)

    @since("3.0.0")
    def setWeightCol(self, value: str) -> "KMeans":
        """
        Sets the value of :py:attr:`weightCol`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(weightCol=value)

    @since("3.4.0")
    def setSolver(self, value: str) -> "KMeans":
        """
        Sets the value of :py:attr:`solver`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(solver=value)

    @since("3.4.0")
    def setMaxBlockSizeInMB(self, value: float) -> "KMeans":
        """
        Sets the value of :py:attr:`maxBlockSizeInMB`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(maxBlockSizeInMB=value)
924
+
925
+
926
@inherit_doc
class _BisectingKMeansParams(
    HasMaxIter,
    HasFeaturesCol,
    HasSeed,
    HasPredictionCol,
    HasDistanceMeasure,
    HasWeightCol,
):
    """
    Params for :py:class:`BisectingKMeans` and :py:class:`BisectingKMeansModel`.

    Declares the ``k`` and ``minDivisibleClusterSize`` params and registers
    the defaults shared by the estimator and its model.

    .. versionadded:: 3.0.0
    """

    k: Param[int] = Param(
        Params._dummy(),
        "k",
        "The desired number of leaf clusters. Must be > 1.",
        typeConverter=TypeConverters.toInt,
    )
    minDivisibleClusterSize: Param[float] = Param(
        Params._dummy(),
        "minDivisibleClusterSize",
        (
            "The minimum number of points (if >= 1.0) or the minimum "
            "proportion of points (if < 1.0) of a divisible cluster."
        ),
        typeConverter=TypeConverters.toFloat,
    )

    def __init__(self, *args: Any):
        """Initialise the mixin chain, then register this class's default values."""
        super().__init__(*args)
        default_values: Dict[str, Any] = {
            "maxIter": 20,
            "k": 4,
            "minDivisibleClusterSize": 1.0,
        }
        self._setDefault(**default_values)

    @since("2.0.0")
    def getK(self) -> int:
        """
        Gets the value of `k` or its default value.
        """
        k_param = self.k
        return self.getOrDefault(k_param)

    @since("2.0.0")
    def getMinDivisibleClusterSize(self) -> float:
        """
        Gets the value of `minDivisibleClusterSize` or its default value.
        """
        size_param = self.minDivisibleClusterSize
        return self.getOrDefault(size_param)
972
+
973
+
974
class BisectingKMeansModel(
    JavaModel,
    _BisectingKMeansParams,
    JavaMLWritable,
    JavaMLReadable["BisectingKMeansModel"],
    HasTrainingSummary["BisectingKMeansSummary"],
):
    """
    Model fitted by BisectingKMeans.

    .. versionadded:: 2.0.0
    """

    @since("3.0.0")
    def setFeaturesCol(self, value: str) -> "BisectingKMeansModel":
        """
        Sets the value of :py:attr:`featuresCol`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(featuresCol=value)

    @since("3.0.0")
    def setPredictionCol(self, value: str) -> "BisectingKMeansModel":
        """
        Sets the value of :py:attr:`predictionCol`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(predictionCol=value)

    @since("2.0.0")
    def clusterCenters(self) -> List[np.ndarray]:
        """Get the cluster centers, represented as a list of NumPy arrays."""
        centers: List[np.ndarray] = []
        # Each element returned by the Java call is converted with toArray().
        for center in self._call_java("clusterCenters"):
            centers.append(center.toArray())
        return centers

    @since("2.0.0")
    def computeCost(self, dataset: DataFrame) -> float:
        """
        Computes the sum of squared distances between the input points
        and their corresponding cluster centers.

        .. deprecated:: 3.0.0
            It will be removed in future versions. Use :py:class:`ClusteringEvaluator` instead.
            You can also get the cost on the training dataset in the summary.
        """
        # Emit the deprecation notice before delegating to the Java method.
        deprecation_message = (
            "Deprecated in 3.0.0. It will be removed in future versions. Use "
            "ClusteringEvaluator instead. You can also get the cost on the training "
            "dataset in the summary."
        )
        warnings.warn(deprecation_message, FutureWarning)
        return self._call_java("computeCost", dataset)

    @property
    @since("2.1.0")
    def summary(self) -> "BisectingKMeansSummary":
        """
        Gets summary (cluster assignments, cluster sizes) of the model trained on the
        training set. An exception is thrown if no summary exists.
        """
        # Guard first: without a summary there is nothing to wrap.
        if not self.hasSummary:
            raise RuntimeError(
                "No training summary available for this %s" % self.__class__.__name__
            )
        return BisectingKMeansSummary(super().summary)

    @since("3.0.0")
    def predict(self, value: Vector) -> int:
        """
        Predict label for the given features.
        """
        label = self._call_java("predict", value)
        return label
1044
+
1045
+
1046
@inherit_doc
class BisectingKMeans(
    JavaEstimator[BisectingKMeansModel],
    _BisectingKMeansParams,
    JavaMLWritable,
    JavaMLReadable["BisectingKMeans"],
):
    """
    A bisecting k-means algorithm based on the paper "A comparison of document clustering
    techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark.
    The algorithm starts from a single cluster that contains all points.
    Iteratively it finds divisible clusters on the bottom level and bisects each of them using
    k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible.
    The bisecting steps of clusters on the same level are grouped together to increase parallelism.
    If bisecting all divisible clusters on the bottom level would result in more than `k` leaf
    clusters, larger clusters get higher priority.

    .. versionadded:: 2.0.0

    Examples
    --------
    >>> from pyspark.ml.linalg import Vectors
    >>> data = [(Vectors.dense([0.0, 0.0]), 2.0), (Vectors.dense([1.0, 1.0]), 2.0),
    ...         (Vectors.dense([9.0, 8.0]), 2.0), (Vectors.dense([8.0, 9.0]), 2.0)]
    >>> df = spark.createDataFrame(data, ["features", "weighCol"])
    >>> bkm = BisectingKMeans(k=2, minDivisibleClusterSize=1.0)
    >>> bkm.setMaxIter(10)
    BisectingKMeans...
    >>> bkm.getMaxIter()
    10
    >>> bkm.clear(bkm.maxIter)
    >>> bkm.setSeed(1)
    BisectingKMeans...
    >>> bkm.setWeightCol("weighCol")
    BisectingKMeans...
    >>> bkm.getSeed()
    1
    >>> bkm.clear(bkm.seed)
    >>> model = bkm.fit(df)
    >>> model.getMaxIter()
    20
    >>> model.setPredictionCol("newPrediction")
    BisectingKMeansModel...
    >>> model.predict(df.head().features)
    0
    >>> centers = model.clusterCenters()
    >>> len(centers)
    2
    >>> model.computeCost(df)
    2.0
    >>> model.hasSummary
    True
    >>> summary = model.summary
    >>> summary.k
    2
    >>> summary.clusterSizes
    [2, 2]
    >>> summary.trainingCost
    4.000...
    >>> transformed = model.transform(df).select("features", "newPrediction")
    >>> rows = transformed.collect()
    >>> rows[0].newPrediction == rows[1].newPrediction
    True
    >>> rows[2].newPrediction == rows[3].newPrediction
    True
    >>> bkm_path = temp_path + "/bkm"
    >>> bkm.save(bkm_path)
    >>> bkm2 = BisectingKMeans.load(bkm_path)
    >>> bkm2.getK()
    2
    >>> bkm2.getDistanceMeasure()
    'euclidean'
    >>> model_path = temp_path + "/bkm_model"
    >>> model.save(model_path)
    >>> model2 = BisectingKMeansModel.load(model_path)
    >>> model2.hasSummary
    False
    >>> model.clusterCenters()[0] == model2.clusterCenters()[0]
    array([ True,  True], dtype=bool)
    >>> model.clusterCenters()[1] == model2.clusterCenters()[1]
    array([ True,  True], dtype=bool)
    >>> model.transform(df).take(1) == model2.transform(df).take(1)
    True
    """

    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        featuresCol: str = "features",
        predictionCol: str = "prediction",
        maxIter: int = 20,
        seed: Optional[int] = None,
        k: int = 4,
        minDivisibleClusterSize: float = 1.0,
        distanceMeasure: str = "euclidean",
        weightCol: Optional[str] = None,
    ):
        """
        __init__(self, \\*, featuresCol="features", predictionCol="prediction", maxIter=20, \
                 seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean", \
                 weightCol=None)
        """
        super().__init__()
        # Build the backing Java object for this estimator, keyed by our uid.
        java_class_name = "org.apache.spark.ml.clustering.BisectingKMeans"
        self._java_obj = self._new_java_obj(java_class_name, self.uid)
        # NOTE(review): _input_kwargs is presumably filled in by @keyword_only
        # with the keyword arguments the caller passed -- confirm in the decorator.
        self.setParams(**self._input_kwargs)

    @keyword_only
    @since("2.0.0")
    def setParams(
        self,
        *,
        featuresCol: str = "features",
        predictionCol: str = "prediction",
        maxIter: int = 20,
        seed: Optional[int] = None,
        k: int = 4,
        minDivisibleClusterSize: float = 1.0,
        distanceMeasure: str = "euclidean",
        weightCol: Optional[str] = None,
    ) -> "BisectingKMeans":
        """
        setParams(self, \\*, featuresCol="features", predictionCol="prediction", maxIter=20, \
                  seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean", \
                  weightCol=None)

        Sets params for BisectingKMeans.
        """
        # Only the values captured in _input_kwargs are forwarded to _set.
        return self._set(**self._input_kwargs)

    @since("2.0.0")
    def setK(self, value: int) -> "BisectingKMeans":
        """
        Sets the value of :py:attr:`k`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(k=value)

    @since("2.0.0")
    def setMinDivisibleClusterSize(self, value: float) -> "BisectingKMeans":
        """
        Sets the value of :py:attr:`minDivisibleClusterSize`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(minDivisibleClusterSize=value)

    @since("2.4.0")
    def setDistanceMeasure(self, value: str) -> "BisectingKMeans":
        """
        Sets the value of :py:attr:`distanceMeasure`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(distanceMeasure=value)

    @since("2.0.0")
    def setMaxIter(self, value: int) -> "BisectingKMeans":
        """
        Sets the value of :py:attr:`maxIter`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(maxIter=value)

    @since("2.0.0")
    def setFeaturesCol(self, value: str) -> "BisectingKMeans":
        """
        Sets the value of :py:attr:`featuresCol`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(featuresCol=value)

    @since("2.0.0")
    def setPredictionCol(self, value: str) -> "BisectingKMeans":
        """
        Sets the value of :py:attr:`predictionCol`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(predictionCol=value)

    @since("2.0.0")
    def setSeed(self, value: int) -> "BisectingKMeans":
        """
        Sets the value of :py:attr:`seed`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(seed=value)

    @since("3.0.0")
    def setWeightCol(self, value: str) -> "BisectingKMeans":
        """
        Sets the value of :py:attr:`weightCol`.

        The result of the underlying ``_set`` call is returned.
        """
        return self._set(weightCol=value)

    def _create_model(self, java_model: "JavaObject") -> BisectingKMeansModel:
        """Wrap ``java_model`` in a :py:class:`BisectingKMeansModel` and return it."""
        wrapped_model = BisectingKMeansModel(java_model)
        return wrapped_model
1239
+
1240
+
1241
class BisectingKMeansSummary(ClusteringSummary):
    """
    Bisecting KMeans clustering results for a given model.

    .. versionadded:: 2.1.0
    """

    @property
    @since("3.0.0")
    def trainingCost(self) -> float:
        """
        Sum of squared distances to the nearest centroid for all points in the training dataset.
        This is equivalent to sklearn's inertia.
        """
        cost = self._call_java("trainingCost")
        return cost
1256
+
1257
+
1258
@inherit_doc
class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval):
    """
    Params for :py:class:`LDA` and :py:class:`LDAModel`.

    Declares the LDA-specific params, registers their default values in
    ``__init__`` and exposes one getter per param. Every getter returns
    ``self.getOrDefault(<param>)``.

    .. versionadded:: 3.0.0
    """

    k: Param[int] = Param(
        Params._dummy(),
        "k",
        "The number of topics (clusters) to infer. Must be > 1.",
        typeConverter=TypeConverters.toInt,
    )
    optimizer: Param[str] = Param(
        Params._dummy(),
        "optimizer",
        "Optimizer or inference algorithm used to estimate the LDA model. "
        "Supported: online, em",
        typeConverter=TypeConverters.toString,
    )
    learningOffset: Param[float] = Param(
        Params._dummy(),
        "learningOffset",
        "A (positive) learning parameter that downweights early iterations."
        " Larger values make early iterations count less",
        typeConverter=TypeConverters.toFloat,
    )
    # The first fragment must end with a space: adjacent literals are joined
    # verbatim, so without it the help text reads "set as anexponential".
    learningDecay: Param[float] = Param(
        Params._dummy(),
        "learningDecay",
        "Learning rate, set as an "
        "exponential decay rate. This should be between (0.5, 1.0] to "
        "guarantee asymptotic convergence.",
        typeConverter=TypeConverters.toFloat,
    )
    subsamplingRate: Param[float] = Param(
        Params._dummy(),
        "subsamplingRate",
        "Fraction of the corpus to be sampled and used in each iteration "
        "of mini-batch gradient descent, in range (0, 1].",
        typeConverter=TypeConverters.toFloat,
    )
    optimizeDocConcentration: Param[bool] = Param(
        Params._dummy(),
        "optimizeDocConcentration",
        "Indicates whether the docConcentration (Dirichlet parameter "
        "for document-topic distribution) will be optimized during "
        "training.",
        typeConverter=TypeConverters.toBoolean,
    )
    docConcentration: Param[List[float]] = Param(
        Params._dummy(),
        "docConcentration",
        'Concentration parameter (commonly named "alpha") for the '
        'prior placed on documents\' distributions over topics ("theta").',
        typeConverter=TypeConverters.toListFloat,
    )
    topicConcentration: Param[float] = Param(
        Params._dummy(),
        "topicConcentration",
        'Concentration parameter (commonly named "beta" or "eta") for '
        "the prior placed on topics' distributions over terms.",
        typeConverter=TypeConverters.toFloat,
    )
    topicDistributionCol: Param[str] = Param(
        Params._dummy(),
        "topicDistributionCol",
        "Output column with estimates of the topic mixture distribution "
        'for each document (often called "theta" in the literature). '
        "Returns a vector of zeros for an empty document.",
        typeConverter=TypeConverters.toString,
    )
    # The converter is passed by keyword here, matching every sibling Param.
    keepLastCheckpoint: Param[bool] = Param(
        Params._dummy(),
        "keepLastCheckpoint",
        "(For EM optimizer) If using checkpointing, this indicates whether"
        " to keep the last checkpoint. If false, then the checkpoint will be"
        " deleted. Deleting the checkpoint can cause failures if a data"
        " partition is lost, so set this bit with care.",
        typeConverter=TypeConverters.toBoolean,
    )

    def __init__(self, *args: Any):
        """Initialise the mixin chain, then register the LDA default values."""
        super(_LDAParams, self).__init__(*args)
        # docConcentration and topicConcentration get no default here.
        self._setDefault(
            maxIter=20,
            checkpointInterval=10,
            k=10,
            optimizer="online",
            learningOffset=1024.0,
            learningDecay=0.51,
            subsamplingRate=0.05,
            optimizeDocConcentration=True,
            topicDistributionCol="topicDistribution",
            keepLastCheckpoint=True,
        )

    @since("2.0.0")
    def getK(self) -> int:
        """
        Gets the value of :py:attr:`k` or its default value.
        """
        return self.getOrDefault(self.k)

    @since("2.0.0")
    def getOptimizer(self) -> str:
        """
        Gets the value of :py:attr:`optimizer` or its default value.
        """
        return self.getOrDefault(self.optimizer)

    @since("2.0.0")
    def getLearningOffset(self) -> float:
        """
        Gets the value of :py:attr:`learningOffset` or its default value.
        """
        return self.getOrDefault(self.learningOffset)

    @since("2.0.0")
    def getLearningDecay(self) -> float:
        """
        Gets the value of :py:attr:`learningDecay` or its default value.
        """
        return self.getOrDefault(self.learningDecay)

    @since("2.0.0")
    def getSubsamplingRate(self) -> float:
        """
        Gets the value of :py:attr:`subsamplingRate` or its default value.
        """
        return self.getOrDefault(self.subsamplingRate)

    @since("2.0.0")
    def getOptimizeDocConcentration(self) -> bool:
        """
        Gets the value of :py:attr:`optimizeDocConcentration` or its default value.
        """
        return self.getOrDefault(self.optimizeDocConcentration)

    @since("2.0.0")
    def getDocConcentration(self) -> List[float]:
        """
        Gets the value of :py:attr:`docConcentration` or its default value.
        """
        return self.getOrDefault(self.docConcentration)

    @since("2.0.0")
    def getTopicConcentration(self) -> float:
        """
        Gets the value of :py:attr:`topicConcentration` or its default value.
        """
        return self.getOrDefault(self.topicConcentration)

    @since("2.0.0")
    def getTopicDistributionCol(self) -> str:
        """
        Gets the value of :py:attr:`topicDistributionCol` or its default value.
        """
        return self.getOrDefault(self.topicDistributionCol)

    @since("2.0.0")
    def getKeepLastCheckpoint(self) -> bool:
        """
        Gets the value of :py:attr:`keepLastCheckpoint` or its default value.
        """
        return self.getOrDefault(self.keepLastCheckpoint)
1425
+
1426
+
1427
+ @inherit_doc
1428
+ class LDAModel(JavaModel, _LDAParams):
1429
+ """
1430
+ Latent Dirichlet Allocation (LDA) model.
1431
+ This abstraction permits for different underlying representations,
1432
+ including local and distributed data structures.
1433
+
1434
+ .. versionadded:: 2.0.0
1435
+ """
1436
+
1437
@since("3.0.0")
def setFeaturesCol(self: "M", value: str) -> "M":
    """
    Sets the value of :py:attr:`featuresCol`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(featuresCol=value)
1443
+
1444
@since("3.0.0")
def setSeed(self: "M", value: int) -> "M":
    """
    Sets the value of :py:attr:`seed`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(seed=value)
1450
+
1451
@since("3.0.0")
def setTopicDistributionCol(self: "M", value: str) -> "M":
    """
    Sets the value of :py:attr:`topicDistributionCol`.

    The result of the underlying ``_set`` call is returned.
    """
    return self._set(topicDistributionCol=value)
1457
+
1458
@since("2.0.0")
def isDistributed(self) -> bool:
    """
    Indicates whether this instance is of type DistributedLDAModel
    """
    distributed = self._call_java("isDistributed")
    return distributed
1464
+
1465
+ @since("2.0.0")
1466
+ def vocabSize(self) -> int:
1467
+ """Vocabulary size (number of terms or words in the vocabulary)"""
1468
+ return self._call_java("vocabSize")
1469
+
1470
+ @since("2.0.0")
1471
+ def topicsMatrix(self) -> Matrix:
1472
+ """
1473
+ Inferred topics, where each topic is represented by a distribution over terms.
1474
+ This is a matrix of size vocabSize x k, where each column is a topic.
1475
+ No guarantees are given about the ordering of the topics.
1476
+
1477
+ .. warning:: If this model is actually a :py:class:`DistributedLDAModel`
1478
+ instance produced by the Expectation-Maximization ("em") `optimizer`,
1479
+ then this method could involve collecting a large amount of data
1480
+ to the driver (on the order of vocabSize x k).
1481
+ """
1482
+ return self._call_java("topicsMatrix")
1483
+
1484
+ @since("2.0.0")
1485
+ def logLikelihood(self, dataset: DataFrame) -> float:
1486
+ """
1487
+ Calculates a lower bound on the log likelihood of the entire corpus.
1488
+ See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
1489
+
1490
+ .. warning:: If this model is an instance of :py:class:`DistributedLDAModel` (produced when
1491
+ :py:attr:`optimizer` is set to "em"), this involves collecting a large
1492
+ :py:func:`topicsMatrix` to the driver. This implementation may be changed in the future.
1493
+ """
1494
+ return self._call_java("logLikelihood", dataset)
1495
+
1496
+ @since("2.0.0")
1497
+ def logPerplexity(self, dataset: DataFrame) -> float:
1498
+ """
1499
+ Calculate an upper bound on perplexity. (Lower is better.)
1500
+ See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
1501
+
1502
+ .. warning:: If this model is an instance of :py:class:`DistributedLDAModel` (produced when
1503
+ :py:attr:`optimizer` is set to "em"), this involves collecting a large
1504
+ :py:func:`topicsMatrix` to the driver. This implementation may be changed in the future.
1505
+ """
1506
+ return self._call_java("logPerplexity", dataset)
1507
+
1508
+ @since("2.0.0")
1509
+ def describeTopics(self, maxTermsPerTopic: int = 10) -> DataFrame:
1510
+ """
1511
+ Return the topics described by their top-weighted terms.
1512
+ """
1513
+ return self._call_java("describeTopics", maxTermsPerTopic)
1514
+
1515
+ @since("2.0.0")
1516
+ def estimatedDocConcentration(self) -> Vector:
1517
+ """
1518
+ Value for :py:attr:`LDA.docConcentration` estimated from data.
1519
+ If Online LDA was used and :py:attr:`LDA.optimizeDocConcentration` was set to false,
1520
+ then this returns the fixed (given) value for the :py:attr:`LDA.docConcentration` parameter.
1521
+ """
1522
+ return self._call_java("estimatedDocConcentration")
1523
+
1524
+
1525
@inherit_doc
class DistributedLDAModel(LDAModel, JavaMLReadable["DistributedLDAModel"], JavaMLWritable):
    """
    Distributed model fitted by :py:class:`LDA`.

    At present only Expectation-Maximization (EM) produces this type of model.

    The model keeps the inferred topics, the full training dataset, and the topic
    distribution of each training document.

    .. versionadded:: 2.0.0
    """

    @since("2.0.0")
    def toLocal(self) -> "LocalLDAModel":
        """
        Converts this distributed model to a local representation, discarding the info
        about the training dataset.

        .. warning:: This involves collecting a large :py:func:`topicsMatrix` to the driver.
        """
        java_local_model = self._call_java("toLocal")
        local_model = LocalLDAModel(java_local_model)

        # SPARK-10931: Temporary fix to be removed once LDAModel defines Params
        local_model._create_params_from_java()
        local_model._transfer_params_from_java()

        return local_model

    @since("2.0.0")
    def trainingLogLikelihood(self) -> float:
        """
        Log likelihood of the observed tokens in the training set,
        given the current parameter estimates:
        log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters)

        Notes
        -----
        - The prior is excluded; use :py:func:`logPrior` for that.
        - Even with :py:func:`logPrior`, this is NOT the same as the data log likelihood
          given the hyperparameters.
        - The value comes from the topic distributions computed during training. Calling
          :py:func:`logLikelihood` on the same training dataset computes the topic
          distributions again, possibly giving different results.
        """
        log_likelihood = self._call_java("trainingLogLikelihood")
        return log_likelihood

    @since("2.0.0")
    def logPrior(self) -> float:
        """
        Log probability of the current parameter estimate:
        log P(topics, topic distributions for docs | alpha, eta)
        """
        log_prior = self._call_java("logPrior")
        return log_prior

    def getCheckpointFiles(self) -> List[str]:
        """
        If checkpointing is in use and :py:attr:`LDA.keepLastCheckpoint` is set to true,
        saved checkpoint files may exist. This method lets users manage those files.

        .. versionadded:: 2.0.0

        Returns
        -------
        list
            List of checkpoint files from training

        Notes
        -----
        Removing the checkpoints can cause failures if a partition is lost and is needed
        by certain :py:class:`DistributedLDAModel` methods.  Reference counting will clean up
        the checkpoints when this model and derivative data go out of scope.
        """
        checkpoint_files = self._call_java("getCheckpointFiles")
        return checkpoint_files
1598
+
1599
+
1600
@inherit_doc
class LocalLDAModel(LDAModel, JavaMLReadable["LocalLDAModel"], JavaMLWritable):
    """
    Local (non-distributed) model fitted by :py:class:`LDA`.

    Only the inferred topics are stored; no info about the training dataset is kept.

    .. versionadded:: 2.0.0
    """
1610
+
1611
+
1612
@inherit_doc
class LDA(JavaEstimator[LDAModel], _LDAParams, JavaMLReadable["LDA"], JavaMLWritable):
    """
    Latent Dirichlet Allocation (LDA), a topic model designed for text documents.

    Terminology:

     - "term" = "word": an element of the vocabulary
     - "token": instance of a term appearing in a document
     - "topic": multinomial distribution over terms representing some concept
     - "document": one piece of text, corresponding to one row in the input data

    Original LDA paper (journal version):
      Blei, Ng, and Jordan.  "Latent Dirichlet Allocation."  JMLR, 2003.

    Input data (featuresCol):
    The documents are supplied as input data through the featuresCol parameter.
    Every document is a :py:class:`Vector` of length vocabSize whose entries are the
    counts of the corresponding terms (words) in that document.  Feature transformers such as
    :py:class:`pyspark.ml.feature.Tokenizer` and :py:class:`pyspark.ml.feature.CountVectorizer`
    can be useful for converting text to word count vectors.

    .. versionadded:: 2.0.0

    Examples
    --------
    >>> from pyspark.ml.linalg import Vectors, SparseVector
    >>> from pyspark.ml.clustering import LDA
    >>> df = spark.createDataFrame([[1, Vectors.dense([0.0, 1.0])],
    ...      [2, SparseVector(2, {0: 1.0})],], ["id", "features"])
    >>> lda = LDA(k=2, seed=1, optimizer="em")
    >>> lda.setMaxIter(10)
    LDA...
    >>> lda.getMaxIter()
    10
    >>> lda.clear(lda.maxIter)
    >>> model = lda.fit(df)
    >>> model.setSeed(1)
    DistributedLDAModel...
    >>> model.getTopicDistributionCol()
    'topicDistribution'
    >>> model.isDistributed()
    True
    >>> localModel = model.toLocal()
    >>> localModel.isDistributed()
    False
    >>> model.vocabSize()
    2
    >>> model.describeTopics().show()
    +-----+-----------+--------------------+
    |topic|termIndices|         termWeights|
    +-----+-----------+--------------------+
    |    0|     [1, 0]|[0.50401530077160...|
    |    1|     [0, 1]|[0.50401530077160...|
    +-----+-----------+--------------------+
    ...
    >>> model.topicsMatrix()
    DenseMatrix(2, 2, [0.496, 0.504, 0.504, 0.496], 0)
    >>> lda_path = temp_path + "/lda"
    >>> lda.save(lda_path)
    >>> sameLDA = LDA.load(lda_path)
    >>> distributed_model_path = temp_path + "/lda_distributed_model"
    >>> model.save(distributed_model_path)
    >>> sameModel = DistributedLDAModel.load(distributed_model_path)
    >>> local_model_path = temp_path + "/lda_local_model"
    >>> localModel.save(local_model_path)
    >>> sameLocalModel = LocalLDAModel.load(local_model_path)
    >>> model.transform(df).take(1) == sameLocalModel.transform(df).take(1)
    True
    """

    # Populated by the ``keyword_only`` decorator with the keyword arguments of the call.
    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        featuresCol: str = "features",
        maxIter: int = 20,
        seed: Optional[int] = None,
        checkpointInterval: int = 10,
        k: int = 10,
        optimizer: str = "online",
        learningOffset: float = 1024.0,
        learningDecay: float = 0.51,
        subsamplingRate: float = 0.05,
        optimizeDocConcentration: bool = True,
        docConcentration: Optional[List[float]] = None,
        topicConcentration: Optional[float] = None,
        topicDistributionCol: str = "topicDistribution",
        keepLastCheckpoint: bool = True,
    ):
        """
        __init__(self, \\*, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\
                  k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
                  subsamplingRate=0.05, optimizeDocConcentration=True,\
                  docConcentration=None, topicConcentration=None,\
                  topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
        """
        super().__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.LDA", self.uid)
        # Forward exactly the keyword arguments captured for this call.
        self.setParams(**self._input_kwargs)

    def _create_model(self, java_model: "JavaObject") -> LDAModel:
        # The "em" optimizer yields a distributed model; anything else yields a local one.
        model_cls = DistributedLDAModel if self.getOptimizer() == "em" else LocalLDAModel
        return model_cls(java_model)

    @keyword_only
    @since("2.0.0")
    def setParams(
        self,
        *,
        featuresCol: str = "features",
        maxIter: int = 20,
        seed: Optional[int] = None,
        checkpointInterval: int = 10,
        k: int = 10,
        optimizer: str = "online",
        learningOffset: float = 1024.0,
        learningDecay: float = 0.51,
        subsamplingRate: float = 0.05,
        optimizeDocConcentration: bool = True,
        docConcentration: Optional[List[float]] = None,
        topicConcentration: Optional[float] = None,
        topicDistributionCol: str = "topicDistribution",
        keepLastCheckpoint: bool = True,
    ) -> "LDA":
        """
        setParams(self, \\*, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\
                  k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
                  subsamplingRate=0.05, optimizeDocConcentration=True,\
                  docConcentration=None, topicConcentration=None,\
                  topicDistributionCol="topicDistribution", keepLastCheckpoint=True)

        Sets params for LDA.
        """
        return self._set(**self._input_kwargs)

    @since("2.0.0")
    def setCheckpointInterval(self, value: int) -> "LDA":
        """
        Assigns a new value to :py:attr:`checkpointInterval`.
        """
        return self._set(checkpointInterval=value)

    @since("2.0.0")
    def setSeed(self, value: int) -> "LDA":
        """
        Assigns a new value to :py:attr:`seed`.
        """
        return self._set(seed=value)

    @since("2.0.0")
    def setK(self, value: int) -> "LDA":
        """
        Assigns a new value to :py:attr:`k`.

        Examples
        --------
        >>> algo = LDA().setK(10)
        >>> algo.getK()
        10
        """
        return self._set(k=value)

    @since("2.0.0")
    def setOptimizer(self, value: str) -> "LDA":
        """
        Assigns a new value to :py:attr:`optimizer`.
        Currently only 'em' and 'online' are supported.

        Examples
        --------
        >>> algo = LDA().setOptimizer("em")
        >>> algo.getOptimizer()
        'em'
        """
        return self._set(optimizer=value)

    @since("2.0.0")
    def setLearningOffset(self, value: float) -> "LDA":
        """
        Assigns a new value to :py:attr:`learningOffset`.

        Examples
        --------
        >>> algo = LDA().setLearningOffset(100)
        >>> algo.getLearningOffset()
        100.0
        """
        return self._set(learningOffset=value)

    @since("2.0.0")
    def setLearningDecay(self, value: float) -> "LDA":
        """
        Assigns a new value to :py:attr:`learningDecay`.

        Examples
        --------
        >>> algo = LDA().setLearningDecay(0.1)
        >>> algo.getLearningDecay()
        0.1...
        """
        return self._set(learningDecay=value)

    @since("2.0.0")
    def setSubsamplingRate(self, value: float) -> "LDA":
        """
        Assigns a new value to :py:attr:`subsamplingRate`.

        Examples
        --------
        >>> algo = LDA().setSubsamplingRate(0.1)
        >>> algo.getSubsamplingRate()
        0.1...
        """
        return self._set(subsamplingRate=value)

    @since("2.0.0")
    def setOptimizeDocConcentration(self, value: bool) -> "LDA":
        """
        Assigns a new value to :py:attr:`optimizeDocConcentration`.

        Examples
        --------
        >>> algo = LDA().setOptimizeDocConcentration(True)
        >>> algo.getOptimizeDocConcentration()
        True
        """
        return self._set(optimizeDocConcentration=value)

    @since("2.0.0")
    def setDocConcentration(self, value: List[float]) -> "LDA":
        """
        Assigns a new value to :py:attr:`docConcentration`.

        Examples
        --------
        >>> algo = LDA().setDocConcentration([0.1, 0.2])
        >>> algo.getDocConcentration()
        [0.1..., 0.2...]
        """
        return self._set(docConcentration=value)

    @since("2.0.0")
    def setTopicConcentration(self, value: float) -> "LDA":
        """
        Assigns a new value to :py:attr:`topicConcentration`.

        Examples
        --------
        >>> algo = LDA().setTopicConcentration(0.5)
        >>> algo.getTopicConcentration()
        0.5...
        """
        return self._set(topicConcentration=value)

    @since("2.0.0")
    def setTopicDistributionCol(self, value: str) -> "LDA":
        """
        Assigns a new value to :py:attr:`topicDistributionCol`.

        Examples
        --------
        >>> algo = LDA().setTopicDistributionCol("topicDistributionCol")
        >>> algo.getTopicDistributionCol()
        'topicDistributionCol'
        """
        return self._set(topicDistributionCol=value)

    @since("2.0.0")
    def setKeepLastCheckpoint(self, value: bool) -> "LDA":
        """
        Assigns a new value to :py:attr:`keepLastCheckpoint`.

        Examples
        --------
        >>> algo = LDA().setKeepLastCheckpoint(False)
        >>> algo.getKeepLastCheckpoint()
        False
        """
        return self._set(keepLastCheckpoint=value)

    @since("2.0.0")
    def setMaxIter(self, value: int) -> "LDA":
        """
        Assigns a new value to :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    @since("2.0.0")
    def setFeaturesCol(self, value: str) -> "LDA":
        """
        Assigns a new value to :py:attr:`featuresCol`.
        """
        return self._set(featuresCol=value)
1910
+
1911
+
1912
@inherit_doc
class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol):
    """
    Parameter mixin for :py:class:`PowerIterationClustering`.

    .. versionadded:: 3.0.0
    """

    k: Param[int] = Param(
        Params._dummy(),
        "k",
        "The number of clusters to create. Must be > 1.",
        typeConverter=TypeConverters.toInt,
    )
    # Adjacent literals are joined at compile time into a single help string.
    initMode: Param[str] = Param(
        Params._dummy(),
        "initMode",
        "The initialization algorithm. This can be either "
        "'random' to use a random vector as vertex properties, or 'degree' to use "
        "a normalized sum of similarities with other vertices.  Supported options: "
        "'random' and 'degree'.",
        typeConverter=TypeConverters.toString,
    )
    srcCol: Param[str] = Param(
        Params._dummy(),
        "srcCol",
        "Name of the input column for source vertex IDs.",
        typeConverter=TypeConverters.toString,
    )
    dstCol: Param[str] = Param(
        Params._dummy(),
        "dstCol",
        "Name of the input column for destination vertex IDs.",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self, *args: Any):
        super().__init__(*args)
        # Register the defaults used when a param has not been set explicitly.
        self._setDefault(
            k=2,
            maxIter=20,
            initMode="random",
            srcCol="src",
            dstCol="dst",
        )

    @since("2.4.0")
    def getK(self) -> int:
        """
        Returns the value of :py:attr:`k`, falling back to its default value.
        """
        k_param = self.k
        return self.getOrDefault(k_param)

    @since("2.4.0")
    def getInitMode(self) -> str:
        """
        Returns the value of :py:attr:`initMode`, falling back to its default value.
        """
        init_mode_param = self.initMode
        return self.getOrDefault(init_mode_param)

    @since("2.4.0")
    def getSrcCol(self) -> str:
        """
        Returns the value of :py:attr:`srcCol`, falling back to its default value.
        """
        src_col_param = self.srcCol
        return self.getOrDefault(src_col_param)

    @since("2.4.0")
    def getDstCol(self) -> str:
        """
        Returns the value of :py:attr:`dstCol`, falling back to its default value.
        """
        dst_col_param = self.dstCol
        return self.getOrDefault(dst_col_param)
1979
+
1980
+
1981
@inherit_doc
class PowerIterationClustering(
    _PowerIterationClusteringParams,
    JavaParams,
    JavaMLReadable["PowerIterationClustering"],
    JavaMLWritable,
):
    """
    Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by
    `Lin and Cohen <http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf>`_. From the
    abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power
    iteration on a normalized pair-wise similarity matrix of the data.

    This class is not yet an Estimator/Transformer; call the :py:func:`assignClusters`
    method to run the PowerIterationClustering algorithm.

    .. versionadded:: 2.4.0

    Notes
    -----
    See `Wikipedia on Spectral clustering <http://en.wikipedia.org/wiki/Spectral_clustering>`_

    Examples
    --------
    >>> data = [(1, 0, 0.5),
    ...         (2, 0, 0.5), (2, 1, 0.7),
    ...         (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9),
    ...         (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1),
    ...         (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)]
    >>> df = spark.createDataFrame(data).toDF("src", "dst", "weight").repartition(1)
    >>> pic = PowerIterationClustering(k=2, weightCol="weight")
    >>> pic.setMaxIter(40)
    PowerIterationClustering...
    >>> assignments = pic.assignClusters(df)
    >>> assignments.sort(assignments.id).show(truncate=False)
    +---+-------+
    |id |cluster|
    +---+-------+
    |0  |0      |
    |1  |0      |
    |2  |0      |
    |3  |0      |
    |4  |0      |
    |5  |1      |
    +---+-------+
    ...
    >>> pic_path = temp_path + "/pic"
    >>> pic.save(pic_path)
    >>> pic2 = PowerIterationClustering.load(pic_path)
    >>> pic2.getK()
    2
    >>> pic2.getMaxIter()
    40
    >>> pic2.assignClusters(df).take(6) == assignments.take(6)
    True
    """

    # Populated by the ``keyword_only`` decorator with the keyword arguments of the call.
    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        k: int = 2,
        maxIter: int = 20,
        initMode: str = "random",
        srcCol: str = "src",
        dstCol: str = "dst",
        weightCol: Optional[str] = None,
    ):
        """
        __init__(self, \\*, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\
                 weightCol=None)
        """
        super().__init__()
        java_class_name = "org.apache.spark.ml.clustering.PowerIterationClustering"
        self._java_obj = self._new_java_obj(java_class_name, self.uid)
        # Forward exactly the keyword arguments captured for this call.
        self.setParams(**self._input_kwargs)

    @keyword_only
    @since("2.4.0")
    def setParams(
        self,
        *,
        k: int = 2,
        maxIter: int = 20,
        initMode: str = "random",
        srcCol: str = "src",
        dstCol: str = "dst",
        weightCol: Optional[str] = None,
    ) -> "PowerIterationClustering":
        """
        setParams(self, \\*, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\
                  weightCol=None)

        Sets params for PowerIterationClustering.
        """
        return self._set(**self._input_kwargs)

    @since("2.4.0")
    def setK(self, value: int) -> "PowerIterationClustering":
        """
        Assigns a new value to :py:attr:`k`.
        """
        return self._set(k=value)

    @since("2.4.0")
    def setInitMode(self, value: str) -> "PowerIterationClustering":
        """
        Assigns a new value to :py:attr:`initMode`.
        """
        return self._set(initMode=value)

    @since("2.4.0")
    def setSrcCol(self, value: str) -> "PowerIterationClustering":
        """
        Assigns a new value to :py:attr:`srcCol`.
        """
        return self._set(srcCol=value)

    @since("2.4.0")
    def setDstCol(self, value: str) -> "PowerIterationClustering":
        """
        Assigns a new value to :py:attr:`dstCol`.
        """
        return self._set(dstCol=value)

    @since("2.4.0")
    def setMaxIter(self, value: int) -> "PowerIterationClustering":
        """
        Assigns a new value to :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    @since("2.4.0")
    def setWeightCol(self, value: str) -> "PowerIterationClustering":
        """
        Assigns a new value to :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    @since("2.4.0")
    def assignClusters(self, dataset: DataFrame) -> DataFrame:
        """
        Runs the PIC algorithm and returns a cluster assignment for each input vertex.

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            A dataset with columns src, dst, weight representing the affinity matrix,
            which is the matrix A in the PIC paper. Suppose the src column value is i,
            the dst column value is j, the weight column value is similarity s_ij,
            which must be nonnegative. This is a symmetric matrix and hence
            s_ij = s_ji. For any (i, j) with nonzero similarity, there should be
            either (i, j, s_ij) or (j, i, s_ji) in the input. Rows with i = j are
            ignored, because we assume s_ij = 0.0.

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame`
            A dataset that contains columns of vertex id and the corresponding cluster for
            the id. The schema of it will be:
            - id: Long
            - cluster: Int
        """
        # Push the Python-side params to the Java object before running.
        self._transfer_params_to_java()
        java_pic = self._java_obj
        assert java_pic is not None

        java_assignments = java_pic.assignClusters(dataset._jdf)
        return DataFrame(java_assignments, dataset.sparkSession)
2153
+
2154
+
2155
if __name__ == "__main__":
    # Doctest runner: executes the docstring examples of this module against a
    # local Spark session, then exits with -1 if any example failed.
    import doctest
    import sys  # imported here so the exit path never depends on a file-level import
    import tempfile
    from shutil import rmtree

    import numpy
    import pyspark.ml.clustering
    from pyspark.sql import SparkSession

    try:
        # NumPy 1.14+ changed its string format; ask for the legacy one so the
        # expected doctest output keeps matching.
        numpy.set_printoptions(legacy="1.13")
    except TypeError:
        # Best effort: presumably raised by NumPy versions that do not accept `legacy`.
        pass

    globs = pyspark.ml.clustering.__dict__.copy()
    # NOTE(review): the original comment here mentioned a "small batch size", but no
    # batch size is configured in this block; the session just runs on local[2].
    spark = SparkSession.builder.master("local[2]").appName("ml.clustering tests").getOrCreate()
    sc = spark.sparkContext
    globs["sc"] = sc
    globs["spark"] = spark

    # Scratch directory used by the save/load examples in the docstrings.
    temp_path = tempfile.mkdtemp()
    globs["temp_path"] = temp_path
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Stop Spark and remove the scratch directory even when testmod raises.
        try:
            spark.stop()
        finally:
            try:
                rmtree(temp_path)
            except OSError:
                pass
    if failure_count:
        sys.exit(-1)