snowpark-connect 0.20.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of snowpark-connect might be problematic. Click here for more details.

Files changed (879):
  1. snowflake/snowpark_connect/__init__.py +23 -0
  2. snowflake/snowpark_connect/analyze_plan/__init__.py +3 -0
  3. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +38 -0
  4. snowflake/snowpark_connect/column_name_handler.py +735 -0
  5. snowflake/snowpark_connect/config.py +576 -0
  6. snowflake/snowpark_connect/constants.py +47 -0
  7. snowflake/snowpark_connect/control_server.py +52 -0
  8. snowflake/snowpark_connect/dataframe_name_handler.py +54 -0
  9. snowflake/snowpark_connect/date_time_format_mapping.py +399 -0
  10. snowflake/snowpark_connect/empty_dataframe.py +18 -0
  11. snowflake/snowpark_connect/error/__init__.py +11 -0
  12. snowflake/snowpark_connect/error/error_mapping.py +6174 -0
  13. snowflake/snowpark_connect/error/error_utils.py +321 -0
  14. snowflake/snowpark_connect/error/exceptions.py +24 -0
  15. snowflake/snowpark_connect/execute_plan/__init__.py +3 -0
  16. snowflake/snowpark_connect/execute_plan/map_execution_command.py +204 -0
  17. snowflake/snowpark_connect/execute_plan/map_execution_root.py +173 -0
  18. snowflake/snowpark_connect/execute_plan/utils.py +183 -0
  19. snowflake/snowpark_connect/expression/__init__.py +3 -0
  20. snowflake/snowpark_connect/expression/literal.py +90 -0
  21. snowflake/snowpark_connect/expression/map_cast.py +343 -0
  22. snowflake/snowpark_connect/expression/map_expression.py +293 -0
  23. snowflake/snowpark_connect/expression/map_extension.py +104 -0
  24. snowflake/snowpark_connect/expression/map_sql_expression.py +633 -0
  25. snowflake/snowpark_connect/expression/map_udf.py +142 -0
  26. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +241 -0
  27. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +85 -0
  28. snowflake/snowpark_connect/expression/map_unresolved_function.py +9450 -0
  29. snowflake/snowpark_connect/expression/map_unresolved_star.py +218 -0
  30. snowflake/snowpark_connect/expression/map_update_fields.py +164 -0
  31. snowflake/snowpark_connect/expression/map_window_function.py +258 -0
  32. snowflake/snowpark_connect/expression/typer.py +125 -0
  33. snowflake/snowpark_connect/includes/__init__.py +0 -0
  34. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  35. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  36. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  37. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  39. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/hadoop-client-api-3.3.4.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/scala-compiler-2.12.18.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-kubernetes_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-mllib_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-streaming_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/python/__init__.py +21 -0
  99. snowflake/snowpark_connect/includes/python/pyspark/__init__.py +173 -0
  100. snowflake/snowpark_connect/includes/python/pyspark/_globals.py +71 -0
  101. snowflake/snowpark_connect/includes/python/pyspark/_typing.pyi +43 -0
  102. snowflake/snowpark_connect/includes/python/pyspark/accumulators.py +341 -0
  103. snowflake/snowpark_connect/includes/python/pyspark/broadcast.py +383 -0
  104. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/__init__.py +8 -0
  105. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle.py +948 -0
  106. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle_fast.py +844 -0
  107. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/compat.py +18 -0
  108. snowflake/snowpark_connect/includes/python/pyspark/conf.py +276 -0
  109. snowflake/snowpark_connect/includes/python/pyspark/context.py +2601 -0
  110. snowflake/snowpark_connect/includes/python/pyspark/daemon.py +218 -0
  111. snowflake/snowpark_connect/includes/python/pyspark/errors/__init__.py +70 -0
  112. snowflake/snowpark_connect/includes/python/pyspark/errors/error_classes.py +889 -0
  113. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/__init__.py +16 -0
  114. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/base.py +228 -0
  115. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/captured.py +307 -0
  116. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/connect.py +190 -0
  117. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/__init__.py +16 -0
  118. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/test_errors.py +60 -0
  119. snowflake/snowpark_connect/includes/python/pyspark/errors/utils.py +116 -0
  120. snowflake/snowpark_connect/includes/python/pyspark/files.py +165 -0
  121. snowflake/snowpark_connect/includes/python/pyspark/find_spark_home.py +95 -0
  122. snowflake/snowpark_connect/includes/python/pyspark/install.py +203 -0
  123. snowflake/snowpark_connect/includes/python/pyspark/instrumentation_utils.py +190 -0
  124. snowflake/snowpark_connect/includes/python/pyspark/java_gateway.py +248 -0
  125. snowflake/snowpark_connect/includes/python/pyspark/join.py +118 -0
  126. snowflake/snowpark_connect/includes/python/pyspark/ml/__init__.py +71 -0
  127. snowflake/snowpark_connect/includes/python/pyspark/ml/_typing.pyi +84 -0
  128. snowflake/snowpark_connect/includes/python/pyspark/ml/base.py +414 -0
  129. snowflake/snowpark_connect/includes/python/pyspark/ml/classification.py +4332 -0
  130. snowflake/snowpark_connect/includes/python/pyspark/ml/clustering.py +2188 -0
  131. snowflake/snowpark_connect/includes/python/pyspark/ml/common.py +146 -0
  132. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/__init__.py +44 -0
  133. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/base.py +346 -0
  134. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/classification.py +382 -0
  135. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/evaluation.py +291 -0
  136. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/feature.py +258 -0
  137. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/functions.py +77 -0
  138. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/io_utils.py +335 -0
  139. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/pipeline.py +262 -0
  140. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/summarizer.py +120 -0
  141. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/tuning.py +579 -0
  142. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/util.py +173 -0
  143. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/__init__.py +16 -0
  144. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/deepspeed_distributor.py +165 -0
  145. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py +306 -0
  146. snowflake/snowpark_connect/includes/python/pyspark/ml/dl_util.py +150 -0
  147. snowflake/snowpark_connect/includes/python/pyspark/ml/evaluation.py +1166 -0
  148. snowflake/snowpark_connect/includes/python/pyspark/ml/feature.py +7474 -0
  149. snowflake/snowpark_connect/includes/python/pyspark/ml/fpm.py +543 -0
  150. snowflake/snowpark_connect/includes/python/pyspark/ml/functions.py +842 -0
  151. snowflake/snowpark_connect/includes/python/pyspark/ml/image.py +271 -0
  152. snowflake/snowpark_connect/includes/python/pyspark/ml/linalg/__init__.py +1382 -0
  153. snowflake/snowpark_connect/includes/python/pyspark/ml/model_cache.py +55 -0
  154. snowflake/snowpark_connect/includes/python/pyspark/ml/param/__init__.py +602 -0
  155. snowflake/snowpark_connect/includes/python/pyspark/ml/param/_shared_params_code_gen.py +368 -0
  156. snowflake/snowpark_connect/includes/python/pyspark/ml/param/shared.py +878 -0
  157. snowflake/snowpark_connect/includes/python/pyspark/ml/pipeline.py +451 -0
  158. snowflake/snowpark_connect/includes/python/pyspark/ml/recommendation.py +748 -0
  159. snowflake/snowpark_connect/includes/python/pyspark/ml/regression.py +3335 -0
  160. snowflake/snowpark_connect/includes/python/pyspark/ml/stat.py +523 -0
  161. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/__init__.py +16 -0
  162. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_classification.py +53 -0
  163. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_evaluation.py +50 -0
  164. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_feature.py +43 -0
  165. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_function.py +114 -0
  166. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_pipeline.py +47 -0
  167. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_summarizer.py +43 -0
  168. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_tuning.py +46 -0
  169. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py +238 -0
  170. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py +194 -0
  171. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +156 -0
  172. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py +184 -0
  173. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py +78 -0
  174. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +292 -0
  175. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +50 -0
  176. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +152 -0
  177. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_algorithms.py +456 -0
  178. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_base.py +96 -0
  179. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_dl_util.py +186 -0
  180. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_evaluation.py +77 -0
  181. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_feature.py +401 -0
  182. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_functions.py +528 -0
  183. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_image.py +82 -0
  184. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_linalg.py +409 -0
  185. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_model_cache.py +55 -0
  186. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_param.py +441 -0
  187. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_persistence.py +546 -0
  188. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_pipeline.py +71 -0
  189. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_stat.py +52 -0
  190. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_training_summary.py +494 -0
  191. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_util.py +85 -0
  192. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_wrapper.py +138 -0
  193. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/__init__.py +16 -0
  194. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_basic.py +151 -0
  195. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_nested.py +97 -0
  196. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py +143 -0
  197. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tuning.py +551 -0
  198. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_basic.py +137 -0
  199. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_nested.py +96 -0
  200. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_pipeline.py +142 -0
  201. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/__init__.py +16 -0
  202. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/data.py +100 -0
  203. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/distributor.py +1133 -0
  204. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/log_communication.py +198 -0
  205. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/__init__.py +16 -0
  206. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_data_loader.py +137 -0
  207. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_distributor.py +561 -0
  208. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_log_communication.py +172 -0
  209. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/torch_run_process_wrapper.py +83 -0
  210. snowflake/snowpark_connect/includes/python/pyspark/ml/tree.py +434 -0
  211. snowflake/snowpark_connect/includes/python/pyspark/ml/tuning.py +1741 -0
  212. snowflake/snowpark_connect/includes/python/pyspark/ml/util.py +749 -0
  213. snowflake/snowpark_connect/includes/python/pyspark/ml/wrapper.py +465 -0
  214. snowflake/snowpark_connect/includes/python/pyspark/mllib/__init__.py +44 -0
  215. snowflake/snowpark_connect/includes/python/pyspark/mllib/_typing.pyi +33 -0
  216. snowflake/snowpark_connect/includes/python/pyspark/mllib/classification.py +989 -0
  217. snowflake/snowpark_connect/includes/python/pyspark/mllib/clustering.py +1318 -0
  218. snowflake/snowpark_connect/includes/python/pyspark/mllib/common.py +174 -0
  219. snowflake/snowpark_connect/includes/python/pyspark/mllib/evaluation.py +691 -0
  220. snowflake/snowpark_connect/includes/python/pyspark/mllib/feature.py +1085 -0
  221. snowflake/snowpark_connect/includes/python/pyspark/mllib/fpm.py +233 -0
  222. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/__init__.py +1653 -0
  223. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/distributed.py +1662 -0
  224. snowflake/snowpark_connect/includes/python/pyspark/mllib/random.py +698 -0
  225. snowflake/snowpark_connect/includes/python/pyspark/mllib/recommendation.py +389 -0
  226. snowflake/snowpark_connect/includes/python/pyspark/mllib/regression.py +1067 -0
  227. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/KernelDensity.py +59 -0
  228. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/__init__.py +34 -0
  229. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/_statistics.py +409 -0
  230. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/distribution.py +39 -0
  231. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/test.py +86 -0
  232. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/__init__.py +16 -0
  233. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_algorithms.py +353 -0
  234. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_feature.py +192 -0
  235. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_linalg.py +680 -0
  236. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_stat.py +206 -0
  237. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_streaming_algorithms.py +471 -0
  238. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_util.py +108 -0
  239. snowflake/snowpark_connect/includes/python/pyspark/mllib/tree.py +888 -0
  240. snowflake/snowpark_connect/includes/python/pyspark/mllib/util.py +659 -0
  241. snowflake/snowpark_connect/includes/python/pyspark/pandas/__init__.py +165 -0
  242. snowflake/snowpark_connect/includes/python/pyspark/pandas/_typing.py +52 -0
  243. snowflake/snowpark_connect/includes/python/pyspark/pandas/accessors.py +989 -0
  244. snowflake/snowpark_connect/includes/python/pyspark/pandas/base.py +1804 -0
  245. snowflake/snowpark_connect/includes/python/pyspark/pandas/categorical.py +822 -0
  246. snowflake/snowpark_connect/includes/python/pyspark/pandas/config.py +539 -0
  247. snowflake/snowpark_connect/includes/python/pyspark/pandas/correlation.py +262 -0
  248. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/__init__.py +16 -0
  249. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/base.py +519 -0
  250. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/binary_ops.py +98 -0
  251. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/boolean_ops.py +426 -0
  252. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/categorical_ops.py +141 -0
  253. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/complex_ops.py +145 -0
  254. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/date_ops.py +127 -0
  255. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/datetime_ops.py +171 -0
  256. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/null_ops.py +83 -0
  257. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/num_ops.py +588 -0
  258. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/string_ops.py +154 -0
  259. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/timedelta_ops.py +101 -0
  260. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/udt_ops.py +29 -0
  261. snowflake/snowpark_connect/includes/python/pyspark/pandas/datetimes.py +891 -0
  262. snowflake/snowpark_connect/includes/python/pyspark/pandas/exceptions.py +150 -0
  263. snowflake/snowpark_connect/includes/python/pyspark/pandas/extensions.py +388 -0
  264. snowflake/snowpark_connect/includes/python/pyspark/pandas/frame.py +13738 -0
  265. snowflake/snowpark_connect/includes/python/pyspark/pandas/generic.py +3560 -0
  266. snowflake/snowpark_connect/includes/python/pyspark/pandas/groupby.py +4448 -0
  267. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/__init__.py +21 -0
  268. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/base.py +2783 -0
  269. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/category.py +773 -0
  270. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/datetimes.py +843 -0
  271. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/multi.py +1323 -0
  272. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/numeric.py +210 -0
  273. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/timedelta.py +197 -0
  274. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexing.py +1862 -0
  275. snowflake/snowpark_connect/includes/python/pyspark/pandas/internal.py +1680 -0
  276. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/__init__.py +48 -0
  277. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/common.py +76 -0
  278. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/frame.py +63 -0
  279. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/general_functions.py +43 -0
  280. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/groupby.py +93 -0
  281. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/indexes.py +184 -0
  282. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/resample.py +101 -0
  283. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/scalars.py +29 -0
  284. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/series.py +69 -0
  285. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/window.py +168 -0
  286. snowflake/snowpark_connect/includes/python/pyspark/pandas/mlflow.py +238 -0
  287. snowflake/snowpark_connect/includes/python/pyspark/pandas/namespace.py +3807 -0
  288. snowflake/snowpark_connect/includes/python/pyspark/pandas/numpy_compat.py +260 -0
  289. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/__init__.py +17 -0
  290. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/core.py +1213 -0
  291. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/matplotlib.py +928 -0
  292. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/plotly.py +261 -0
  293. snowflake/snowpark_connect/includes/python/pyspark/pandas/resample.py +816 -0
  294. snowflake/snowpark_connect/includes/python/pyspark/pandas/series.py +7440 -0
  295. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_formatter.py +308 -0
  296. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_processor.py +394 -0
  297. snowflake/snowpark_connect/includes/python/pyspark/pandas/strings.py +2371 -0
  298. snowflake/snowpark_connect/includes/python/pyspark/pandas/supported_api_gen.py +378 -0
  299. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/__init__.py +16 -0
  300. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/__init__.py +16 -0
  301. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_any_all.py +177 -0
  302. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_apply_func.py +575 -0
  303. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_binary_ops.py +235 -0
  304. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_combine.py +653 -0
  305. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_compute.py +463 -0
  306. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_corrwith.py +86 -0
  307. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cov.py +151 -0
  308. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cumulative.py +139 -0
  309. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_describe.py +458 -0
  310. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_eval.py +86 -0
  311. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_melt.py +202 -0
  312. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_missing_data.py +520 -0
  313. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_pivot.py +361 -0
  314. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/__init__.py +16 -0
  315. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/__init__.py +16 -0
  316. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_any_all.py +40 -0
  317. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_apply_func.py +42 -0
  318. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_binary_ops.py +40 -0
  319. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_combine.py +37 -0
  320. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py +60 -0
  321. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_corrwith.py +40 -0
  322. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cov.py +40 -0
  323. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cumulative.py +90 -0
  324. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_describe.py +40 -0
  325. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_eval.py +40 -0
  326. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_melt.py +40 -0
  327. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_missing_data.py +42 -0
  328. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +37 -0
  329. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/__init__.py +16 -0
  330. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_base.py +36 -0
  331. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py +42 -0
  332. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py +47 -0
  333. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py +55 -0
  334. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_complex_ops.py +40 -0
  335. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py +47 -0
  336. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py +47 -0
  337. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py +42 -0
  338. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_arithmetic.py +43 -0
  339. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py +47 -0
  340. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_reverse.py +43 -0
  341. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py +47 -0
  342. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py +47 -0
  343. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py +40 -0
  344. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/testing_utils.py +226 -0
  345. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/__init__.py +16 -0
  346. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_align.py +39 -0
  347. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_basic_slow.py +55 -0
  348. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_cov_corrwith.py +39 -0
  349. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_frame.py +39 -0
  350. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_series.py +39 -0
  351. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_index.py +39 -0
  352. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_series.py +39 -0
  353. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_frame.py +43 -0
  354. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_series.py +43 -0
  355. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/__init__.py +16 -0
  356. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_attrs.py +40 -0
  357. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py +39 -0
  358. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_conversion.py +42 -0
  359. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reindexing.py +42 -0
  360. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +37 -0
  361. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_spark.py +40 -0
  362. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_take.py +42 -0
  363. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_time_series.py +48 -0
  364. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_truncate.py +40 -0
  365. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/__init__.py +16 -0
  366. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_aggregate.py +40 -0
  367. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_apply_func.py +41 -0
  368. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_cumulative.py +67 -0
  369. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_describe.py +40 -0
  370. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_groupby.py +55 -0
  371. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_head_tail.py +40 -0
  372. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_index.py +38 -0
  373. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py +55 -0
  374. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_split_apply.py +39 -0
  375. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +38 -0
  376. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/__init__.py +16 -0
  377. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_align.py +40 -0
  378. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py +50 -0
  379. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py +73 -0
  380. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_datetime.py +39 -0
  381. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing.py +40 -0
  382. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reindex.py +40 -0
  383. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_rename.py +40 -0
  384. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reset_index.py +48 -0
  385. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_timedelta.py +39 -0
  386. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/__init__.py +16 -0
  387. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/test_parity_io.py +40 -0
  388. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/__init__.py +16 -0
  389. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot.py +45 -0
  390. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_matplotlib.py +45 -0
  391. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_plotly.py +49 -0
  392. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot.py +37 -0
  393. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_matplotlib.py +53 -0
  394. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_plotly.py +45 -0
  395. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/__init__.py +16 -0
  396. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_all_any.py +38 -0
  397. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_arg_ops.py +37 -0
  398. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_of.py +37 -0
  399. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_type.py +38 -0
  400. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +37 -0
  401. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_conversion.py +40 -0
  402. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_cumulative.py +40 -0
  403. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_index.py +38 -0
  404. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_missing_data.py +40 -0
  405. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_series.py +37 -0
  406. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_sort.py +38 -0
  407. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_stat.py +38 -0
  408. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_categorical.py +66 -0
  409. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_config.py +37 -0
  410. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_csv.py +37 -0
  411. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_conversion.py +42 -0
  412. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_spark_io.py +39 -0
  413. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_default_index.py +49 -0
  414. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ewm.py +37 -0
  415. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_expanding.py +39 -0
  416. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_extension.py +49 -0
  417. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_frame_spark.py +53 -0
  418. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py +43 -0
  419. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexing.py +49 -0
  420. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexops_spark.py +39 -0
  421. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_internal.py +41 -0
  422. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_namespace.py +39 -0
  423. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_numpy_compat.py +60 -0
  424. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames.py +48 -0
  425. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby.py +39 -0
  426. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_expanding.py +44 -0
  427. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_rolling.py +84 -0
  428. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_repr.py +37 -0
  429. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_resample.py +45 -0
  430. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_reshape.py +39 -0
  431. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_rolling.py +39 -0
  432. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_scalars.py +37 -0
  433. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_conversion.py +39 -0
  434. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_datetime.py +39 -0
  435. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_string.py +39 -0
  436. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_spark_functions.py +39 -0
  437. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_sql.py +43 -0
  438. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_stats.py +37 -0
  439. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_typedef.py +36 -0
  440. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_utils.py +37 -0
  441. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_window.py +39 -0
  442. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/__init__.py +16 -0
  443. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_base.py +107 -0
  444. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +224 -0
  445. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +825 -0
  446. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +562 -0
  447. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +368 -0
  448. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +257 -0
  449. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +260 -0
  450. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +178 -0
  451. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_arithmetic.py +184 -0
  452. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +497 -0
  453. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_reverse.py +140 -0
  454. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +354 -0
  455. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +219 -0
  456. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +192 -0
  457. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +228 -0
  458. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/__init__.py +16 -0
  459. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_align.py +118 -0
  460. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_basic_slow.py +198 -0
  461. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_cov_corrwith.py +181 -0
  462. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_frame.py +103 -0
  463. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_series.py +141 -0
  464. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_index.py +109 -0
  465. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_series.py +136 -0
  466. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_frame.py +125 -0
  467. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_series.py +217 -0
  468. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/__init__.py +16 -0
  469. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_attrs.py +384 -0
  470. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_constructor.py +598 -0
  471. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_conversion.py +73 -0
  472. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reindexing.py +869 -0
  473. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reshaping.py +487 -0
  474. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_spark.py +309 -0
  475. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_take.py +156 -0
  476. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_time_series.py +149 -0
  477. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_truncate.py +163 -0
  478. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/__init__.py +16 -0
  479. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_aggregate.py +311 -0
  480. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_apply_func.py +524 -0
  481. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_cumulative.py +419 -0
  482. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_describe.py +144 -0
  483. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_groupby.py +979 -0
  484. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_head_tail.py +234 -0
  485. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_index.py +206 -0
  486. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_missing_data.py +421 -0
  487. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_split_apply.py +187 -0
  488. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_stat.py +397 -0
  489. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/__init__.py +16 -0
  490. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_align.py +100 -0
  491. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_base.py +2743 -0
  492. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_category.py +484 -0
  493. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_datetime.py +276 -0
  494. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_indexing.py +432 -0
  495. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reindex.py +310 -0
  496. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_rename.py +257 -0
  497. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reset_index.py +160 -0
  498. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_timedelta.py +128 -0
  499. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/__init__.py +16 -0
  500. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/test_io.py +137 -0
  501. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/__init__.py +16 -0
  502. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot.py +170 -0
  503. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +547 -0
  504. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +285 -0
  505. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot.py +106 -0
  506. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +409 -0
  507. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +247 -0
  508. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/__init__.py +16 -0
  509. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_all_any.py +105 -0
  510. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_arg_ops.py +197 -0
  511. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_of.py +137 -0
  512. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_type.py +227 -0
  513. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_compute.py +634 -0
  514. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_conversion.py +88 -0
  515. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_cumulative.py +139 -0
  516. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_index.py +475 -0
  517. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_missing_data.py +265 -0
  518. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_series.py +818 -0
  519. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_sort.py +162 -0
  520. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_stat.py +780 -0
  521. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_categorical.py +741 -0
  522. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_config.py +160 -0
  523. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_csv.py +453 -0
  524. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_conversion.py +281 -0
  525. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_spark_io.py +487 -0
  526. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_default_index.py +109 -0
  527. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ewm.py +434 -0
  528. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_expanding.py +253 -0
  529. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_extension.py +152 -0
  530. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_frame_spark.py +162 -0
  531. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_generic_functions.py +234 -0
  532. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexing.py +1339 -0
  533. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexops_spark.py +82 -0
  534. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_internal.py +124 -0
  535. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_namespace.py +638 -0
  536. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_numpy_compat.py +200 -0
  537. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +1355 -0
  538. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +655 -0
  539. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +113 -0
  540. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +118 -0
  541. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_repr.py +192 -0
  542. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_resample.py +346 -0
  543. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_reshape.py +495 -0
  544. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_rolling.py +263 -0
  545. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_scalars.py +59 -0
  546. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_conversion.py +85 -0
  547. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_datetime.py +364 -0
  548. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_string.py +362 -0
  549. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_spark_functions.py +46 -0
  550. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_sql.py +123 -0
  551. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_stats.py +581 -0
  552. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_typedef.py +447 -0
  553. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_utils.py +301 -0
  554. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_window.py +465 -0
  555. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/__init__.py +18 -0
  556. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/typehints.py +874 -0
  557. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/__init__.py +143 -0
  558. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/usage_logger.py +132 -0
  559. snowflake/snowpark_connect/includes/python/pyspark/pandas/utils.py +1063 -0
  560. snowflake/snowpark_connect/includes/python/pyspark/pandas/window.py +2702 -0
  561. snowflake/snowpark_connect/includes/python/pyspark/profiler.py +489 -0
  562. snowflake/snowpark_connect/includes/python/pyspark/py.typed +1 -0
  563. snowflake/snowpark_connect/includes/python/pyspark/python/pyspark/shell.py +123 -0
  564. snowflake/snowpark_connect/includes/python/pyspark/rdd.py +5518 -0
  565. snowflake/snowpark_connect/includes/python/pyspark/rddsampler.py +115 -0
  566. snowflake/snowpark_connect/includes/python/pyspark/resource/__init__.py +38 -0
  567. snowflake/snowpark_connect/includes/python/pyspark/resource/information.py +69 -0
  568. snowflake/snowpark_connect/includes/python/pyspark/resource/profile.py +317 -0
  569. snowflake/snowpark_connect/includes/python/pyspark/resource/requests.py +539 -0
  570. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/__init__.py +16 -0
  571. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/test_resources.py +83 -0
  572. snowflake/snowpark_connect/includes/python/pyspark/resultiterable.py +45 -0
  573. snowflake/snowpark_connect/includes/python/pyspark/serializers.py +681 -0
  574. snowflake/snowpark_connect/includes/python/pyspark/shell.py +123 -0
  575. snowflake/snowpark_connect/includes/python/pyspark/shuffle.py +854 -0
  576. snowflake/snowpark_connect/includes/python/pyspark/sql/__init__.py +75 -0
  577. snowflake/snowpark_connect/includes/python/pyspark/sql/_typing.pyi +80 -0
  578. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/__init__.py +18 -0
  579. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/functions.py +188 -0
  580. snowflake/snowpark_connect/includes/python/pyspark/sql/catalog.py +1270 -0
  581. snowflake/snowpark_connect/includes/python/pyspark/sql/column.py +1431 -0
  582. snowflake/snowpark_connect/includes/python/pyspark/sql/conf.py +99 -0
  583. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/__init__.py +18 -0
  584. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/_typing.py +90 -0
  585. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/__init__.py +18 -0
  586. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/functions.py +107 -0
  587. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/catalog.py +356 -0
  588. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/__init__.py +22 -0
  589. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/artifact.py +412 -0
  590. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/core.py +1689 -0
  591. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/reattach.py +340 -0
  592. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/column.py +514 -0
  593. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conf.py +128 -0
  594. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conversion.py +490 -0
  595. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/dataframe.py +2172 -0
  596. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/expressions.py +1056 -0
  597. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/functions.py +3937 -0
  598. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/group.py +418 -0
  599. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/plan.py +2289 -0
  600. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/__init__.py +25 -0
  601. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.py +203 -0
  602. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.pyi +2718 -0
  603. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2_grpc.py +423 -0
  604. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.py +109 -0
  605. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.pyi +1130 -0
  606. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.py +141 -0
  607. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.pyi +1766 -0
  608. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.py +47 -0
  609. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.pyi +123 -0
  610. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.py +53 -0
  611. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.pyi +112 -0
  612. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.py +107 -0
  613. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.pyi +1507 -0
  614. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.py +195 -0
  615. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.pyi +3613 -0
  616. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.py +95 -0
  617. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.pyi +980 -0
  618. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/__init__.py +18 -0
  619. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/functions.py +166 -0
  620. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/readwriter.py +861 -0
  621. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/session.py +952 -0
  622. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/__init__.py +22 -0
  623. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/query.py +295 -0
  624. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/readwriter.py +618 -0
  625. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/__init__.py +18 -0
  626. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +87 -0
  627. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +100 -0
  628. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/types.py +301 -0
  629. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udf.py +296 -0
  630. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udtf.py +200 -0
  631. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/utils.py +58 -0
  632. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/window.py +266 -0
  633. snowflake/snowpark_connect/includes/python/pyspark/sql/context.py +818 -0
  634. snowflake/snowpark_connect/includes/python/pyspark/sql/dataframe.py +5973 -0
  635. snowflake/snowpark_connect/includes/python/pyspark/sql/functions.py +15889 -0
  636. snowflake/snowpark_connect/includes/python/pyspark/sql/group.py +547 -0
  637. snowflake/snowpark_connect/includes/python/pyspark/sql/observation.py +152 -0
  638. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/__init__.py +21 -0
  639. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/__init__.pyi +344 -0
  640. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/__init__.pyi +17 -0
  641. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/frame.pyi +20 -0
  642. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/series.pyi +20 -0
  643. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/conversion.py +671 -0
  644. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.py +480 -0
  645. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.pyi +132 -0
  646. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/group_ops.py +523 -0
  647. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/map_ops.py +216 -0
  648. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/serializers.py +1019 -0
  649. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/typehints.py +172 -0
  650. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/types.py +972 -0
  651. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/utils.py +86 -0
  652. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/__init__.py +18 -0
  653. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/functions.py +334 -0
  654. snowflake/snowpark_connect/includes/python/pyspark/sql/readwriter.py +2159 -0
  655. snowflake/snowpark_connect/includes/python/pyspark/sql/session.py +2088 -0
  656. snowflake/snowpark_connect/includes/python/pyspark/sql/sql_formatter.py +84 -0
  657. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/__init__.py +21 -0
  658. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/listener.py +1050 -0
  659. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/query.py +746 -0
  660. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/readwriter.py +1652 -0
  661. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/state.py +288 -0
  662. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/__init__.py +16 -0
  663. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/__init__.py +16 -0
  664. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/__init__.py +16 -0
  665. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_artifact.py +420 -0
  666. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_client.py +358 -0
  667. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/__init__.py +16 -0
  668. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach.py +36 -0
  669. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +44 -0
  670. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +116 -0
  671. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +35 -0
  672. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py +3612 -0
  673. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_column.py +1042 -0
  674. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_function.py +2381 -0
  675. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_plan.py +1060 -0
  676. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow.py +163 -0
  677. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_map.py +38 -0
  678. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +48 -0
  679. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_catalog.py +36 -0
  680. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_column.py +55 -0
  681. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_conf.py +36 -0
  682. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_dataframe.py +96 -0
  683. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_datasources.py +44 -0
  684. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_errors.py +36 -0
  685. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_functions.py +59 -0
  686. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_group.py +36 -0
  687. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py +59 -0
  688. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py +74 -0
  689. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py +62 -0
  690. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_map.py +58 -0
  691. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +70 -0
  692. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py +50 -0
  693. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +68 -0
  694. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py +40 -0
  695. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_readwriter.py +46 -0
  696. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_serde.py +44 -0
  697. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_types.py +100 -0
  698. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udf.py +100 -0
  699. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udtf.py +163 -0
  700. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_session.py +181 -0
  701. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_utils.py +42 -0
  702. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/__init__.py +16 -0
  703. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +623 -0
  704. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +869 -0
  705. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +342 -0
  706. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_map.py +436 -0
  707. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf.py +363 -0
  708. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +592 -0
  709. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +1503 -0
  710. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +392 -0
  711. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py +375 -0
  712. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +411 -0
  713. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/__init__.py +16 -0
  714. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming.py +401 -0
  715. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +295 -0
  716. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +106 -0
  717. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_listener.py +558 -0
  718. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow.py +1346 -0
  719. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_map.py +182 -0
  720. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_python_udf.py +202 -0
  721. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_catalog.py +503 -0
  722. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_column.py +225 -0
  723. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_conf.py +83 -0
  724. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_context.py +201 -0
  725. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_dataframe.py +1931 -0
  726. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_datasources.py +256 -0
  727. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_errors.py +69 -0
  728. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_functions.py +1349 -0
  729. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_group.py +53 -0
  730. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_pandas_sqlmetrics.py +68 -0
  731. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_readwriter.py +283 -0
  732. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_serde.py +155 -0
  733. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_session.py +412 -0
  734. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_types.py +1581 -0
  735. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf.py +961 -0
  736. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf_profiler.py +165 -0
  737. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udtf.py +1456 -0
  738. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_utils.py +1686 -0
  739. snowflake/snowpark_connect/includes/python/pyspark/sql/types.py +2558 -0
  740. snowflake/snowpark_connect/includes/python/pyspark/sql/udf.py +714 -0
  741. snowflake/snowpark_connect/includes/python/pyspark/sql/udtf.py +325 -0
  742. snowflake/snowpark_connect/includes/python/pyspark/sql/utils.py +339 -0
  743. snowflake/snowpark_connect/includes/python/pyspark/sql/window.py +492 -0
  744. snowflake/snowpark_connect/includes/python/pyspark/statcounter.py +165 -0
  745. snowflake/snowpark_connect/includes/python/pyspark/status.py +112 -0
  746. snowflake/snowpark_connect/includes/python/pyspark/storagelevel.py +97 -0
  747. snowflake/snowpark_connect/includes/python/pyspark/streaming/__init__.py +22 -0
  748. snowflake/snowpark_connect/includes/python/pyspark/streaming/context.py +471 -0
  749. snowflake/snowpark_connect/includes/python/pyspark/streaming/dstream.py +933 -0
  750. snowflake/snowpark_connect/includes/python/pyspark/streaming/kinesis.py +205 -0
  751. snowflake/snowpark_connect/includes/python/pyspark/streaming/listener.py +83 -0
  752. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/__init__.py +16 -0
  753. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_context.py +184 -0
  754. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_dstream.py +706 -0
  755. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_kinesis.py +118 -0
  756. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_listener.py +160 -0
  757. snowflake/snowpark_connect/includes/python/pyspark/streaming/util.py +168 -0
  758. snowflake/snowpark_connect/includes/python/pyspark/taskcontext.py +502 -0
  759. snowflake/snowpark_connect/includes/python/pyspark/testing/__init__.py +21 -0
  760. snowflake/snowpark_connect/includes/python/pyspark/testing/connectutils.py +199 -0
  761. snowflake/snowpark_connect/includes/python/pyspark/testing/mllibutils.py +30 -0
  762. snowflake/snowpark_connect/includes/python/pyspark/testing/mlutils.py +275 -0
  763. snowflake/snowpark_connect/includes/python/pyspark/testing/objects.py +121 -0
  764. snowflake/snowpark_connect/includes/python/pyspark/testing/pandasutils.py +714 -0
  765. snowflake/snowpark_connect/includes/python/pyspark/testing/sqlutils.py +168 -0
  766. snowflake/snowpark_connect/includes/python/pyspark/testing/streamingutils.py +178 -0
  767. snowflake/snowpark_connect/includes/python/pyspark/testing/utils.py +636 -0
  768. snowflake/snowpark_connect/includes/python/pyspark/tests/__init__.py +16 -0
  769. snowflake/snowpark_connect/includes/python/pyspark/tests/test_appsubmit.py +306 -0
  770. snowflake/snowpark_connect/includes/python/pyspark/tests/test_broadcast.py +196 -0
  771. snowflake/snowpark_connect/includes/python/pyspark/tests/test_conf.py +44 -0
  772. snowflake/snowpark_connect/includes/python/pyspark/tests/test_context.py +346 -0
  773. snowflake/snowpark_connect/includes/python/pyspark/tests/test_daemon.py +89 -0
  774. snowflake/snowpark_connect/includes/python/pyspark/tests/test_install_spark.py +124 -0
  775. snowflake/snowpark_connect/includes/python/pyspark/tests/test_join.py +69 -0
  776. snowflake/snowpark_connect/includes/python/pyspark/tests/test_memory_profiler.py +167 -0
  777. snowflake/snowpark_connect/includes/python/pyspark/tests/test_pin_thread.py +194 -0
  778. snowflake/snowpark_connect/includes/python/pyspark/tests/test_profiler.py +168 -0
  779. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rdd.py +939 -0
  780. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddbarrier.py +52 -0
  781. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddsampler.py +66 -0
  782. snowflake/snowpark_connect/includes/python/pyspark/tests/test_readwrite.py +368 -0
  783. snowflake/snowpark_connect/includes/python/pyspark/tests/test_serializers.py +257 -0
  784. snowflake/snowpark_connect/includes/python/pyspark/tests/test_shuffle.py +267 -0
  785. snowflake/snowpark_connect/includes/python/pyspark/tests/test_stage_sched.py +153 -0
  786. snowflake/snowpark_connect/includes/python/pyspark/tests/test_statcounter.py +130 -0
  787. snowflake/snowpark_connect/includes/python/pyspark/tests/test_taskcontext.py +350 -0
  788. snowflake/snowpark_connect/includes/python/pyspark/tests/test_util.py +97 -0
  789. snowflake/snowpark_connect/includes/python/pyspark/tests/test_worker.py +271 -0
  790. snowflake/snowpark_connect/includes/python/pyspark/traceback_utils.py +81 -0
  791. snowflake/snowpark_connect/includes/python/pyspark/util.py +416 -0
  792. snowflake/snowpark_connect/includes/python/pyspark/version.py +19 -0
  793. snowflake/snowpark_connect/includes/python/pyspark/worker.py +1307 -0
  794. snowflake/snowpark_connect/includes/python/pyspark/worker_util.py +46 -0
  795. snowflake/snowpark_connect/proto/__init__.py +10 -0
  796. snowflake/snowpark_connect/proto/control_pb2.py +35 -0
  797. snowflake/snowpark_connect/proto/control_pb2.pyi +38 -0
  798. snowflake/snowpark_connect/proto/control_pb2_grpc.py +183 -0
  799. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +35 -0
  800. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +53 -0
  801. snowflake/snowpark_connect/proto/snowflake_rdd_pb2.pyi +39 -0
  802. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +47 -0
  803. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +111 -0
  804. snowflake/snowpark_connect/relation/__init__.py +3 -0
  805. snowflake/snowpark_connect/relation/catalogs/__init__.py +12 -0
  806. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +287 -0
  807. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +467 -0
  808. snowflake/snowpark_connect/relation/catalogs/utils.py +51 -0
  809. snowflake/snowpark_connect/relation/io_utils.py +76 -0
  810. snowflake/snowpark_connect/relation/map_aggregate.py +322 -0
  811. snowflake/snowpark_connect/relation/map_catalog.py +151 -0
  812. snowflake/snowpark_connect/relation/map_column_ops.py +1068 -0
  813. snowflake/snowpark_connect/relation/map_crosstab.py +48 -0
  814. snowflake/snowpark_connect/relation/map_extension.py +412 -0
  815. snowflake/snowpark_connect/relation/map_join.py +341 -0
  816. snowflake/snowpark_connect/relation/map_local_relation.py +326 -0
  817. snowflake/snowpark_connect/relation/map_map_partitions.py +146 -0
  818. snowflake/snowpark_connect/relation/map_relation.py +253 -0
  819. snowflake/snowpark_connect/relation/map_row_ops.py +716 -0
  820. snowflake/snowpark_connect/relation/map_sample_by.py +35 -0
  821. snowflake/snowpark_connect/relation/map_show_string.py +50 -0
  822. snowflake/snowpark_connect/relation/map_sql.py +1874 -0
  823. snowflake/snowpark_connect/relation/map_stats.py +324 -0
  824. snowflake/snowpark_connect/relation/map_subquery_alias.py +32 -0
  825. snowflake/snowpark_connect/relation/map_udtf.py +288 -0
  826. snowflake/snowpark_connect/relation/read/__init__.py +7 -0
  827. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +668 -0
  828. snowflake/snowpark_connect/relation/read/map_read.py +367 -0
  829. snowflake/snowpark_connect/relation/read/map_read_csv.py +142 -0
  830. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +108 -0
  831. snowflake/snowpark_connect/relation/read/map_read_json.py +344 -0
  832. snowflake/snowpark_connect/relation/read/map_read_parquet.py +194 -0
  833. snowflake/snowpark_connect/relation/read/map_read_socket.py +59 -0
  834. snowflake/snowpark_connect/relation/read/map_read_table.py +109 -0
  835. snowflake/snowpark_connect/relation/read/map_read_text.py +106 -0
  836. snowflake/snowpark_connect/relation/read/reader_config.py +399 -0
  837. snowflake/snowpark_connect/relation/read/utils.py +155 -0
  838. snowflake/snowpark_connect/relation/stage_locator.py +161 -0
  839. snowflake/snowpark_connect/relation/utils.py +219 -0
  840. snowflake/snowpark_connect/relation/write/__init__.py +3 -0
  841. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +339 -0
  842. snowflake/snowpark_connect/relation/write/map_write.py +436 -0
  843. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +48 -0
  844. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  845. snowflake/snowpark_connect/resources_initializer.py +75 -0
  846. snowflake/snowpark_connect/server.py +1136 -0
  847. snowflake/snowpark_connect/start_server.py +32 -0
  848. snowflake/snowpark_connect/tcm.py +8 -0
  849. snowflake/snowpark_connect/type_mapping.py +1003 -0
  850. snowflake/snowpark_connect/typed_column.py +94 -0
  851. snowflake/snowpark_connect/utils/__init__.py +3 -0
  852. snowflake/snowpark_connect/utils/artifacts.py +48 -0
  853. snowflake/snowpark_connect/utils/attribute_handling.py +72 -0
  854. snowflake/snowpark_connect/utils/cache.py +84 -0
  855. snowflake/snowpark_connect/utils/concurrent.py +124 -0
  856. snowflake/snowpark_connect/utils/context.py +390 -0
  857. snowflake/snowpark_connect/utils/describe_query_cache.py +231 -0
  858. snowflake/snowpark_connect/utils/interrupt.py +85 -0
  859. snowflake/snowpark_connect/utils/io_utils.py +35 -0
  860. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +117 -0
  861. snowflake/snowpark_connect/utils/profiling.py +47 -0
  862. snowflake/snowpark_connect/utils/session.py +180 -0
  863. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +38 -0
  864. snowflake/snowpark_connect/utils/telemetry.py +513 -0
  865. snowflake/snowpark_connect/utils/udf_cache.py +392 -0
  866. snowflake/snowpark_connect/utils/udf_helper.py +328 -0
  867. snowflake/snowpark_connect/utils/udf_utils.py +310 -0
  868. snowflake/snowpark_connect/utils/udtf_helper.py +420 -0
  869. snowflake/snowpark_connect/utils/udtf_utils.py +799 -0
  870. snowflake/snowpark_connect/utils/xxhash64.py +247 -0
  871. snowflake/snowpark_connect/version.py +6 -0
  872. snowpark_connect-0.20.2.data/scripts/snowpark-connect +71 -0
  873. snowpark_connect-0.20.2.data/scripts/snowpark-session +11 -0
  874. snowpark_connect-0.20.2.data/scripts/snowpark-submit +354 -0
  875. snowpark_connect-0.20.2.dist-info/METADATA +37 -0
  876. snowpark_connect-0.20.2.dist-info/RECORD +879 -0
  877. snowpark_connect-0.20.2.dist-info/WHEEL +5 -0
  878. snowpark_connect-0.20.2.dist-info/licenses/LICENSE.txt +202 -0
  879. snowpark_connect-0.20.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1741 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one or more
3
+ # contributor license agreements. See the NOTICE file distributed with
4
+ # this work for additional information regarding copyright ownership.
5
+ # The ASF licenses this file to You under the Apache License, Version 2.0
6
+ # (the "License"); you may not use this file except in compliance with
7
+ # the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ import os
19
+ import sys
20
+ import itertools
21
+ from multiprocessing.pool import ThreadPool
22
+
23
+ from typing import (
24
+ Any,
25
+ Callable,
26
+ Dict,
27
+ Iterable,
28
+ List,
29
+ Optional,
30
+ Sequence,
31
+ Tuple,
32
+ Type,
33
+ Union,
34
+ cast,
35
+ overload,
36
+ TYPE_CHECKING,
37
+ )
38
+
39
+ import numpy as np
40
+
41
+ from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
42
+ from pyspark.ml import Estimator, Transformer, Model
43
+ from pyspark.ml.common import inherit_doc, _py2java, _java2py
44
+ from pyspark.ml.evaluation import Evaluator, JavaEvaluator
45
+ from pyspark.ml.param import Params, Param, TypeConverters
46
+ from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
47
+ from pyspark.ml.util import (
48
+ DefaultParamsReader,
49
+ DefaultParamsWriter,
50
+ MetaAlgorithmReadWrite,
51
+ MLReadable,
52
+ MLReader,
53
+ MLWritable,
54
+ MLWriter,
55
+ JavaMLReader,
56
+ JavaMLWriter,
57
+ )
58
+ from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
59
+ from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
60
+ from pyspark.sql.types import BooleanType
61
+
62
+ from pyspark.sql.dataframe import DataFrame
63
+
64
+ if TYPE_CHECKING:
65
+ from pyspark.ml._typing import ParamMap
66
+ from py4j.java_gateway import JavaObject
67
+ from py4j.java_collections import JavaArray
68
+
69
+ __all__ = [
70
+ "ParamGridBuilder",
71
+ "CrossValidator",
72
+ "CrossValidatorModel",
73
+ "TrainValidationSplit",
74
+ "TrainValidationSplitModel",
75
+ ]
76
+
77
+
78
def _parallelFitTasks(
    est: Estimator,
    train: DataFrame,
    eva: Evaluator,
    validation: DataFrame,
    epm: Sequence["ParamMap"],
    collectSubModel: bool,
) -> List[Callable[[], Tuple[int, float, Optional[Transformer]]]]:
    """
    Creates a list of callables which can be called from different threads to fit and evaluate
    an estimator in parallel. Each callable returns an `(index, metric, subModel)` triple.

    Parameters
    ----------
    est : :py:class:`pyspark.ml.Estimator`
        The estimator to be fit.
    train : :py:class:`pyspark.sql.DataFrame`
        DataFrame, training data set, used for fitting.
    eva : :py:class:`pyspark.ml.evaluation.Evaluator`
        used to compute `metric`
    validation : :py:class:`pyspark.sql.DataFrame`
        DataFrame, validation data set, used for evaluation.
    epm : :py:class:`collections.abc.Sequence`
        Sequence of ParamMap, params maps to be used during fitting & evaluation.
    collectSubModel : bool
        Whether to collect sub model.

    Returns
    -------
    list
        ``len(epm)`` references to the same callable. Each call returns
        ``(index, metric, subModel)``: an index into `epm`, the associated metric value,
        and the fitted model, or ``None`` in place of the model when `collectSubModel`
        is false.
    """
    modelIter = est.fitMultiple(train, epm)

    def singleTask() -> Tuple[int, float, Optional[Transformer]]:
        # The index is taken from the shared iterator, not from the position of the
        # callable in the returned list, so every call consumes the next fitted model.
        index, model = next(modelIter)
        # TODO: duplicate evaluator to take extra params from input
        #  Note: Supporting tuning params in evaluator need update method
        #  `MetaAlgorithmReadWrite.getAllNestedStages`, make it return
        #  all nested stages and evaluators
        metric = eva.evaluate(model.transform(validation, epm[index]))
        return index, metric, model if collectSubModel else None

    return [singleTask] * len(epm)
122
+
123
+
124
class ParamGridBuilder:
    r"""
    Builder for a param grid used in grid search-based model selection.

    Values registered through :py:meth:`addGrid` and :py:meth:`baseOn` are stored per
    parameter; :py:meth:`build` returns one param map for every combination of them.

    .. versionadded:: 1.4.0

    Examples
    --------
    >>> from pyspark.ml.classification import LogisticRegression
    >>> lr = LogisticRegression()
    >>> output = ParamGridBuilder() \
    ...     .baseOn({lr.labelCol: 'l'}) \
    ...     .baseOn([lr.predictionCol, 'p']) \
    ...     .addGrid(lr.regParam, [1.0, 2.0]) \
    ...     .addGrid(lr.maxIter, [1, 5]) \
    ...     .build()
    >>> expected = [
    ...     {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
    >>> len(output) == len(expected)
    True
    >>> all([m in expected for m in output])
    True
    """

    def __init__(self) -> None:
        # Maps each registered Param to the list of candidate values for it.
        self._param_grid: "ParamMap" = {}

    @since("1.4.0")
    def addGrid(self, param: Param[Any], values: List[Any]) -> "ParamGridBuilder":
        """
        Sets the list of candidate values for the given parameter in this grid,
        replacing any values previously registered for it.

        param must be an instance of Param associated with an instance of Params
        (such as Estimator or Transformer).
        """
        if not isinstance(param, Param):
            raise TypeError("param must be an instance of Param")
        self._param_grid[param] = values

        return self

    @overload
    def baseOn(self, __args: "ParamMap") -> "ParamGridBuilder":
        ...

    @overload
    def baseOn(self, *args: Tuple[Param, Any]) -> "ParamGridBuilder":
        ...

    @since("1.4.0")
    def baseOn(self, *args: Union["ParamMap", Tuple[Param, Any]]) -> "ParamGridBuilder":
        """
        Sets the given parameters in this grid to fixed values.
        Accepts either a parameter dictionary or a list of (parameter, value) pairs.
        Calling it with nothing to set (no arguments or an empty dictionary) is a no-op.
        """
        # Guard the `args[0]` lookup: an empty dictionary re-enters this method with
        # no positional arguments, which previously failed with IndexError.
        if not args:
            return self

        if isinstance(args[0], dict):
            self.baseOn(*args[0].items())
        else:
            for (param, value) in args:
                # A fixed value is a grid dimension with a single candidate.
                self.addGrid(param, [value])

        return self

    @since("1.4.0")
    def build(self) -> List["ParamMap"]:
        """
        Builds and returns all combinations of parameters specified
        by the param grid.
        """
        keys = self._param_grid.keys()
        grid_values = self._param_grid.values()

        def to_key_value_pairs(
            keys: Iterable[Param], values: Iterable[Any]
        ) -> Sequence[Tuple[Param, Any]]:
            # Each value is passed through its parameter's own type converter.
            return [(key, key.typeConverter(value)) for key, value in zip(keys, values)]

        return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)]
207
+
208
+
209
class _ValidatorParams(HasSeed):
    """
    Params shared by TrainValidationSplit and CrossValidator: the estimator to tune,
    the candidate param maps for it, and the evaluator used to compare them.
    """

    estimator: Param[Estimator] = Param(
        Params._dummy(), "estimator", "estimator to be cross-validated"
    )
    estimatorParamMaps: Param[List["ParamMap"]] = Param(
        Params._dummy(), "estimatorParamMaps", "estimator param maps"
    )
    evaluator: Param[Evaluator] = Param(
        Params._dummy(),
        "evaluator",
        "evaluator used to select hyper-parameters that maximize the validator metric",
    )

    @since("2.0.0")
    def getEstimator(self) -> Estimator:
        """
        Returns the estimator param value, falling back to its default.
        """
        return self.getOrDefault(self.estimator)

    @since("2.0.0")
    def getEstimatorParamMaps(self) -> List["ParamMap"]:
        """
        Returns the estimatorParamMaps param value, falling back to its default.
        """
        return self.getOrDefault(self.estimatorParamMaps)

    @since("2.0.0")
    def getEvaluator(self) -> Evaluator:
        """
        Returns the evaluator param value, falling back to its default.
        """
        return self.getOrDefault(self.evaluator)

    @classmethod
    def _from_java_impl(
        cls, java_stage: "JavaObject"
    ) -> Tuple[Estimator, List["ParamMap"], Evaluator]:
        """
        Build the Python estimator, estimatorParamMaps and evaluator out of a Java
        ValidatorParams object.

        Raises ValueError when the estimator is neither a JavaEstimator nor a
        meta estimator.
        """
        py_estimator: Estimator = JavaParams._from_java(java_stage.getEstimator())
        py_evaluator: Evaluator = JavaParams._from_java(java_stage.getEvaluator())

        if isinstance(py_estimator, JavaEstimator):
            # Plain Java-backed estimator: it can translate each param map itself.
            param_maps = []
            for java_epm in java_stage.getEstimatorParamMaps():
                param_maps.append(py_estimator._transfer_param_map_from_java(java_epm))
        elif MetaAlgorithmReadWrite.isMetaEstimator(py_estimator):
            # Meta estimator such as Pipeline, OneVsRest
            param_maps = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_from_java(
                py_estimator, java_stage.getEstimatorParamMaps()
            )
        else:
            raise ValueError("Unsupported estimator used in tuning: " + str(py_estimator))

        return py_estimator, param_maps, py_evaluator

    def _to_java_impl(self) -> Tuple["JavaObject", "JavaObject", "JavaObject"]:
        """
        Build the Java estimator, estimatorParamMaps and evaluator out of this
        Python instance.

        Raises ValueError when the estimator is neither a JavaEstimator nor a
        meta estimator.
        """
        gateway = SparkContext._gateway
        assert gateway is not None and SparkContext._jvm is not None

        param_map_cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap

        py_estimator = self.getEstimator()
        py_param_maps = self.getEstimatorParamMaps()

        if isinstance(py_estimator, JavaEstimator):
            # Fill a Java ParamMap[] slot by slot with the translated maps.
            java_param_maps = gateway.new_array(param_map_cls, len(py_param_maps))
            for slot, py_param_map in enumerate(py_param_maps):
                java_param_maps[slot] = py_estimator._transfer_param_map_to_java(py_param_map)
        elif MetaAlgorithmReadWrite.isMetaEstimator(py_estimator):
            # Meta estimator such as Pipeline, OneVsRest
            java_param_maps = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_to_java(
                py_estimator, py_param_maps
            )
        else:
            raise ValueError("Unsupported estimator used in tuning: " + str(py_estimator))

        return (
            cast(JavaEstimator, py_estimator)._to_java(),
            java_param_maps,
            cast(JavaEvaluator, self.getEvaluator())._to_java(),
        )
299
+
300
+
301
+ class _ValidatorSharedReadWrite:
302
@staticmethod
def meta_estimator_transfer_param_maps_to_java(
    pyEstimator: Estimator, pyParamMaps: Sequence["ParamMap"]
) -> "JavaArray":
    """
    Convert the Python param maps of a meta estimator into a Java ``ParamMap`` array.

    Every param is resolved against the nested stages of `pyEstimator`; the first
    stage that owns it supplies the Java param. Raises ValueError when no stage
    owns a param.
    """
    nested_stages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
    stage_pairs = [(stage, cast(JavaParams, stage)._to_java()) for stage in nested_stages]
    sc = SparkContext._active_spark_context

    assert (
        sc is not None and SparkContext._jvm is not None and SparkContext._gateway is not None
    )

    param_map_cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
    java_param_maps = SparkContext._gateway.new_array(param_map_cls, len(pyParamMaps))

    for slot, py_param_map in enumerate(pyParamMaps):
        java_param_map = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
        for py_param, py_value in py_param_map.items():
            # First owning stage wins; None means no nested stage owns the param.
            java_param = next(
                (
                    java_stage.getParam(py_param.name)
                    for py_stage, java_stage in stage_pairs
                    if py_stage._testOwnParam(py_param.parent, py_param.name)
                ),
                None,
            )
            if java_param is None:
                raise ValueError("Resolve param in estimatorParamMaps failed: " + str(py_param))

            # Params values that can convert themselves do so; anything else goes
            # through the generic Python-to-Java conversion.
            if isinstance(py_value, Params) and hasattr(py_value, "_to_java"):
                java_value = cast(JavaParams, py_value)._to_java()
            else:
                java_value = _py2java(sc, py_value)

            java_param_map.put([java_param.w(java_value)])
        java_param_maps[slot] = java_param_map

    return java_param_maps
335
+
336
@staticmethod
def meta_estimator_transfer_param_maps_from_java(
    pyEstimator: Estimator, javaParamMaps: "JavaArray"
) -> List["ParamMap"]:
    """
    Convert a Java ``ParamMap`` array of a meta estimator into Python param maps.

    Every Java param is resolved against the nested stages of `pyEstimator`; the
    first stage that owns it supplies the Python param. Raises ValueError when no
    stage owns a param.
    """
    pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
    # NOTE(review): the Java side of each pair is not read below; the conversion is
    # kept because `_to_java()` may have effects beyond its return value — confirm
    # before removing it.
    stagePairs = list(map(lambda stage: (stage, cast(JavaParams, stage)._to_java()), pyStages))
    sc = SparkContext._active_spark_context

    assert sc is not None and sc._jvm is not None

    # Loop-invariant JVM lookup, resolved once instead of once per param pair.
    writableCls = sc._jvm.Class.forName("org.apache.spark.ml.util.DefaultParamsWritable")

    pyParamMaps = []
    for javaParamMap in javaParamMaps:
        pyParamMap = dict()
        for javaPair in javaParamMap.toList():
            javaParam = javaPair.param()
            pyParam = None
            for pyStage, _ in stagePairs:
                if pyStage._testOwnParam(javaParam.parent(), javaParam.name()):
                    pyParam = pyStage.getParam(javaParam.name())
                    # Stop at the first owning stage, matching the to_java direction.
                    break
            if pyParam is None:
                raise ValueError(
                    "Resolve param in estimatorParamMaps failed: "
                    + javaParam.parent()
                    + "."
                    + javaParam.name()
                )
            javaValue = javaPair.value()
            pyValue: Any
            # DefaultParamsWritable values are rebuilt as Python Params objects;
            # anything else goes through the generic Java-to-Python conversion.
            if writableCls.isInstance(javaValue):
                pyValue = JavaParams._from_java(javaValue)
            else:
                pyValue = _java2py(sc, javaValue)
            pyParamMap[pyParam] = pyValue
        pyParamMaps.append(pyParamMap)
    return pyParamMaps
373
+
374
@staticmethod
def is_java_convertible(instance: _ValidatorParams) -> bool:
    """
    Tell whether ``instance`` can be handed over to the Java implementation.

    True only when the evaluator is a ``JavaParams`` and every nested stage of
    the estimator exposes a ``_to_java`` attribute.
    """
    stages = MetaAlgorithmReadWrite.getAllNestedStages(instance.getEstimator())
    evaluator_ok = isinstance(instance.getEvaluator(), JavaParams)
    stages_ok = all(hasattr(stage, "_to_java") for stage in stages)
    return stages_ok and evaluator_ok
380
+
381
@staticmethod
def saveImpl(
    path: str,
    instance: _ValidatorParams,
    sc: SparkContext,
    extraMetadata: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Persist the parts shared by all validators under ``path``.

    Writes the metadata (including a JSON form of ``estimatorParamMaps``),
    then the evaluator under ``<path>/evaluator`` and the estimator under
    ``<path>/estimator``. Param values that are non-meta estimators,
    transformers or evaluators are saved to their own ``epm_<name><n>``
    sub-directory and referenced by relative path.

    Raises
    ------
    RuntimeError
        If a param value is an ``MLWritable`` of an unsupported kind.
    """
    saved_count = 0
    serialized_maps = []
    for param_map in instance.getEstimatorParamMaps():
        serialized_map = []
        for param, value in param_map.items():
            entry: Dict[str, Any] = {"parent": param.parent, "name": param.name}
            plain_estimator = isinstance(value, Estimator) and not (
                MetaAlgorithmReadWrite.isMetaEstimator(value)
            )
            if plain_estimator or isinstance(value, (Transformer, Evaluator)):
                # Stored on disk; the metadata only keeps the relative path.
                relative_path = f"epm_{param.name}{saved_count}"
                saved_count += 1
                cast(MLWritable, value).save(os.path.join(path, relative_path))
                entry["value"] = relative_path
                entry["isJson"] = False
            elif isinstance(value, MLWritable):
                raise RuntimeError(
                    "ValidatorSharedReadWrite.saveImpl does not handle parameters of type: "
                    "MLWritable that are not Estimator/Evaluator/Transformer, and if parameter "
                    "is estimator, it cannot be meta estimator such as Validator or OneVsRest"
                )
            else:
                # Stored inline in the metadata.
                entry["value"] = value
                entry["isJson"] = True
            serialized_map.append(entry)
        serialized_maps.append(serialized_map)

    json_params = DefaultParamsWriter.extractJsonParams(
        instance, ["estimator", "evaluator", "estimatorParamMaps"]
    )
    json_params["estimatorParamMaps"] = serialized_maps

    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, json_params)
    cast(MLWritable, instance.getEvaluator()).save(os.path.join(path, "evaluator"))
    cast(MLWritable, instance.getEstimator()).save(os.path.join(path, "estimator"))
426
+
427
@staticmethod
def load(
    path: str, sc: SparkContext, metadata: Dict[str, Any]
) -> Tuple[Dict[str, Any], Estimator, Evaluator, List["ParamMap"]]:
    """
    Load the parts shared by all validators from ``path``.

    Returns the metadata unchanged, the estimator, the evaluator and the
    rebuilt list of estimator param maps.
    """
    evaluator: Evaluator = DefaultParamsReader.loadParamsInstance(
        os.path.join(path, "evaluator"), sc
    )
    estimator: Estimator = DefaultParamsReader.loadParamsInstance(
        os.path.join(path, "estimator"), sc
    )

    # Params are looked up on their owner, found by uid.
    owners = MetaAlgorithmReadWrite.getUidMap(estimator)
    owners[evaluator.uid] = evaluator

    rebuilt_maps = []
    for json_map in metadata["paramMap"]["estimatorParamMaps"]:
        rebuilt = {}
        for json_param in json_map:
            owner = owners[json_param["parent"]]
            param = getattr(owner, json_param["name"])
            # Entries without an "isJson" flag are treated as inline values.
            if json_param.get("isJson", True):
                value = json_param["value"]
            else:
                value = DefaultParamsReader.loadParamsInstance(
                    os.path.join(path, json_param["value"]), sc
                )
            rebuilt[param] = value
        rebuilt_maps.append(rebuilt)

    return metadata, estimator, evaluator, rebuilt_maps
457
+
458
@staticmethod
def validateParams(instance: _ValidatorParams) -> None:
    """
    Check that ``instance`` can be written by the Python writer.

    Raises
    ------
    ValueError
        If the evaluator or any stage in the estimator's uid map is not
        ``MLWritable``, or if a param in ``estimatorParamMaps`` has a parent
        uid that is not in the estimator's uid map.
    """
    estimator = instance.getEstimator()
    evaluator = instance.getEvaluator()
    uid_map = MetaAlgorithmReadWrite.getUidMap(estimator)

    for component in [evaluator, *uid_map.values()]:
        if isinstance(component, MLWritable):
            continue
        raise ValueError(
            f"Validator write will fail because it contains {component.uid} "
            f"which is not writable."
        )

    param_error = (
        "Validator save requires all Params in estimatorParamMaps to apply to "
        "its Estimator, An extraneous Param was found: "
    )
    for param_map in instance.getEstimatorParamMaps():
        for param in param_map:
            if param.parent not in uid_map:
                raise ValueError(param_error + repr(param))
480
+
481
@staticmethod
def getValidatorModelWriterPersistSubModelsParam(writer: MLWriter) -> bool:
    """
    Resolve whether ``writer`` should persist sub-models.

    An explicit ``persistsubmodels`` writer option wins (case-insensitive
    "true"/"false"); otherwise sub-models are persisted when the instance
    has any.

    Raises
    ------
    ValueError
        If the option is set to something other than "true" or "false".
    """
    if "persistsubmodels" not in writer.optionMap:
        return writer.instance.subModels is not None  # type: ignore[attr-defined]

    option = writer.optionMap["persistsubmodels"].lower()
    if option == "true":
        return True
    if option == "false":
        return False
    raise ValueError(
        f"persistSubModels option value {option} is invalid, "
        f"the possible values are True, 'True' or False, 'False'"
    )
496
+
497
+
498
# Raised by the model writers when sub-models were requested but the model has none.
_save_with_persist_submodels_no_submodels_found_err: str = (
    "When persisting tuning models, you can only set persistSubModels to true "
    "if the tuning was done with collectSubModels set to true. "
    "To save the sub-models, try rerunning fitting with collectSubModels set to true."
)
503
+
504
+
505
@inherit_doc
class CrossValidatorReader(MLReader["CrossValidator"]):
    """Reader for :py:class:`CrossValidator` instances."""

    def __init__(self, cls: Type["CrossValidator"]):
        super().__init__()
        self.cls = cls

    def load(self, path: str) -> "CrossValidator":
        """Load a :py:class:`CrossValidator` from ``path``."""
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            # Not written by the Python writer: delegate to the Java reader.
            return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]

        metadata, estimator, evaluator, param_maps = _ValidatorSharedReadWrite.load(
            path, self.sc, metadata
        )
        validator = CrossValidator(
            estimator=estimator, estimatorParamMaps=param_maps, evaluator=evaluator
        )
        validator = validator._resetUid(metadata["uid"])
        DefaultParamsReader.getAndSetParams(
            validator, metadata, skipParams=["estimatorParamMaps"]
        )
        return validator
525
+
526
+
527
@inherit_doc
class CrossValidatorWriter(MLWriter):
    """Python writer for :py:class:`CrossValidator` instances."""

    def __init__(self, instance: "CrossValidator"):
        super().__init__()
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        """Validate the instance, then save it under ``path``."""
        validator = self.instance
        _ValidatorSharedReadWrite.validateParams(validator)
        _ValidatorSharedReadWrite.saveImpl(path, validator, self.sc)
536
+
537
+
538
@inherit_doc
class CrossValidatorModelReader(MLReader["CrossValidatorModel"]):
    """Reader for :py:class:`CrossValidatorModel` instances."""

    def __init__(self, cls: Type["CrossValidatorModel"]):
        super(CrossValidatorModelReader, self).__init__()
        self.cls = cls

    def load(self, path: str) -> "CrossValidatorModel":
        """
        Load a :py:class:`CrossValidatorModel` from ``path``.

        Models not written by the Python writer are delegated to the Java
        reader. Sub-models are loaded only when the metadata has a truthy
        ``persistSubModels`` entry.
        """
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]

        metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
            path, self.sc, metadata
        )
        numFolds = metadata["paramMap"]["numFolds"]
        bestModelPath = os.path.join(path, "bestModel")
        bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
        avgMetrics = metadata["avgMetrics"]
        # "stdMetrics" is only written when the saved model had some.
        stdMetrics = metadata["stdMetrics"] if "stdMetrics" in metadata else None
        persistSubModels = ("persistSubModels" in metadata) and metadata["persistSubModels"]

        subModels: Optional[List[List[Model]]] = None
        if persistSubModels:
            # One independent inner list per fold. Multiplying the outer list
            # (``[[None] * n] * numFolds``) would make all folds share a single
            # inner list, so every fold would end up with the last fold's models.
            subModels = [
                [
                    DefaultParamsReader.loadParamsInstance(
                        os.path.join(path, "subModels", f"fold{splitIndex}", f"{paramIndex}"),
                        self.sc,
                    )
                    for paramIndex in range(len(estimatorParamMaps))
                ]
                for splitIndex in range(numFolds)
            ]

        cvModel = CrossValidatorModel(
            bestModel,
            avgMetrics=avgMetrics,
            subModels=cast(List[List[Model]], subModels),
            stdMetrics=stdMetrics,
        )
        cvModel = cvModel._resetUid(metadata["uid"])
        cvModel.set(cvModel.estimator, estimator)
        cvModel.set(cvModel.estimatorParamMaps, estimatorParamMaps)
        cvModel.set(cvModel.evaluator, evaluator)
        DefaultParamsReader.getAndSetParams(
            cvModel, metadata, skipParams=["estimatorParamMaps"]
        )
        return cvModel
589
+
590
+
591
@inherit_doc
class CrossValidatorModelWriter(MLWriter):
    """Python writer for :py:class:`CrossValidatorModel` instances."""

    def __init__(self, instance: "CrossValidatorModel"):
        super().__init__()
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        """
        Save the model under ``path``.

        Writes the shared validator data, the best model under
        ``<path>/bestModel`` and, when requested, each sub-model under
        ``<path>/subModels/fold<i>/<j>``.

        Raises
        ------
        ValueError
            If sub-models should be persisted but the model has none.
        """
        _ValidatorSharedReadWrite.validateParams(self.instance)
        model = self.instance
        persist = _ValidatorSharedReadWrite.getValidatorModelWriterPersistSubModelsParam(self)

        extra = {"avgMetrics": model.avgMetrics, "persistSubModels": persist}
        if model.stdMetrics:
            extra["stdMetrics"] = model.stdMetrics

        _ValidatorSharedReadWrite.saveImpl(path, model, self.sc, extraMetadata=extra)
        cast(MLWritable, model.bestModel).save(os.path.join(path, "bestModel"))

        if not persist:
            return
        if model.subModels is None:
            raise ValueError(_save_with_persist_submodels_no_submodels_found_err)

        sub_models_root = os.path.join(path, "subModels")
        for fold in range(model.getNumFolds()):
            fold_dir = os.path.join(sub_models_root, f"fold{fold}")
            for index in range(len(model.getEstimatorParamMaps())):
                target = os.path.join(fold_dir, f"{index}")
                cast(MLWritable, model.subModels[fold][index]).save(target)
619
+
620
+
621
class _CrossValidatorParams(_ValidatorParams):
    """
    Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.

    .. versionadded:: 3.0.0
    """

    # Number of folds; converted with TypeConverters.toInt.
    numFolds: Param[int] = Param(
        Params._dummy(),
        "numFolds",
        "number of folds for cross validation",
        typeConverter=TypeConverters.toInt,
    )

    # Optional column holding user-assigned fold numbers.
    foldCol: Param[str] = Param(
        Params._dummy(),
        "foldCol",
        "Param for the column name of user "
        "specified fold number. Once this is specified, :py:class:`CrossValidator` "
        "won't do random k-fold split. Note that this column should be integer type "
        "with range [0, numFolds) and Spark will throw exception on out-of-range "
        "fold numbers.",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self, *args: Any):
        super().__init__(*args)
        # Defaults: three folds and no user-specified fold column.
        self._setDefault(numFolds=3, foldCol="")

    @since("1.4.0")
    def getNumFolds(self) -> int:
        """
        Return numFolds, falling back to its default when unset.
        """
        return self.getOrDefault(self.numFolds)

    @since("3.1.0")
    def getFoldCol(self) -> str:
        """
        Return foldCol, falling back to its default when unset.
        """
        return self.getOrDefault(self.foldCol)
663
+
664
+
665
class CrossValidator(
    Estimator["CrossValidatorModel"],
    _CrossValidatorParams,
    HasParallelism,
    HasCollectSubModels,
    MLReadable["CrossValidator"],
    MLWritable,
):
    """

    K-fold cross validation performs model selection by splitting the dataset into a set of
    non-overlapping randomly partitioned folds which are used as separate training and test datasets
    e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs,
    each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the
    test set exactly once.

    .. versionadded:: 1.4.0

    Examples
    --------
    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, CrossValidatorModel
    >>> import tempfile
    >>> dataset = spark.createDataFrame(
    ...     [(Vectors.dense([0.0]), 0.0),
    ...      (Vectors.dense([0.4]), 1.0),
    ...      (Vectors.dense([0.5]), 0.0),
    ...      (Vectors.dense([0.6]), 1.0),
    ...      (Vectors.dense([1.0]), 1.0)] * 10,
    ...     ["features", "label"])
    >>> lr = LogisticRegression()
    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    >>> evaluator = BinaryClassificationEvaluator()
    >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
    ...     parallelism=2)
    >>> cvModel = cv.fit(dataset)
    >>> cvModel.getNumFolds()
    3
    >>> cvModel.avgMetrics[0]
    0.5
    >>> path = tempfile.mkdtemp()
    >>> model_path = path + "/model"
    >>> cvModel.write().save(model_path)
    >>> cvModelRead = CrossValidatorModel.read().load(model_path)
    >>> cvModelRead.avgMetrics
    [0.5, ...
    >>> evaluator.evaluate(cvModel.transform(dataset))
    0.8333...
    >>> evaluator.evaluate(cvModelRead.transform(dataset))
    0.8333...
    """

    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        estimator: Optional[Estimator] = None,
        estimatorParamMaps: Optional[List["ParamMap"]] = None,
        evaluator: Optional[Evaluator] = None,
        numFolds: int = 3,
        seed: Optional[int] = None,
        parallelism: int = 1,
        collectSubModels: bool = False,
        foldCol: str = "",
    ) -> None:
        """
        __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
                 seed=None, parallelism=1, collectSubModels=False, foldCol="")
        """
        super(CrossValidator, self).__init__()
        self._setDefault(parallelism=1)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(
        self,
        *,
        estimator: Optional[Estimator] = None,
        estimatorParamMaps: Optional[List["ParamMap"]] = None,
        evaluator: Optional[Evaluator] = None,
        numFolds: int = 3,
        seed: Optional[int] = None,
        parallelism: int = 1,
        collectSubModels: bool = False,
        foldCol: str = "",
    ) -> "CrossValidator":
        """
        setParams(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
                  seed=None, parallelism=1, collectSubModels=False, foldCol=""):
        Sets params for cross validator.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.0.0")
    def setEstimator(self, value: Estimator) -> "CrossValidator":
        """
        Sets the value of :py:attr:`estimator`.
        """
        return self._set(estimator=value)

    @since("2.0.0")
    def setEstimatorParamMaps(self, value: List["ParamMap"]) -> "CrossValidator":
        """
        Sets the value of :py:attr:`estimatorParamMaps`.
        """
        return self._set(estimatorParamMaps=value)

    @since("2.0.0")
    def setEvaluator(self, value: Evaluator) -> "CrossValidator":
        """
        Sets the value of :py:attr:`evaluator`.
        """
        return self._set(evaluator=value)

    @since("1.4.0")
    def setNumFolds(self, value: int) -> "CrossValidator":
        """
        Sets the value of :py:attr:`numFolds`.
        """
        return self._set(numFolds=value)

    @since("3.1.0")
    def setFoldCol(self, value: str) -> "CrossValidator":
        """
        Sets the value of :py:attr:`foldCol`.
        """
        return self._set(foldCol=value)

    def setSeed(self, value: int) -> "CrossValidator":
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    def setParallelism(self, value: int) -> "CrossValidator":
        """
        Sets the value of :py:attr:`parallelism`.
        """
        return self._set(parallelism=value)

    def setCollectSubModels(self, value: bool) -> "CrossValidator":
        """
        Sets the value of :py:attr:`collectSubModels`.
        """
        return self._set(collectSubModels=value)

    @staticmethod
    def _gen_avg_and_std_metrics(metrics_all: List[List[float]]) -> Tuple[List[float], List[float]]:
        """
        Return the per-column mean and standard deviation of ``metrics_all``.

        ``metrics_all`` holds one row per fold and one column per param map,
        so both results have one entry per param map.
        """
        avg_metrics = np.mean(metrics_all, axis=0)
        std_metrics = np.std(metrics_all, axis=0)
        return list(avg_metrics), list(std_metrics)

    def _fit(self, dataset: DataFrame) -> "CrossValidatorModel":
        """
        Run the cross validation and fit the best param map on ``dataset``.

        For every fold, each param map is fitted on the training split and
        evaluated on the validation split using a thread pool. The param map
        with the best average metric is then refitted on the whole dataset.
        """
        est = self.getOrDefault(self.estimator)
        epm = self.getOrDefault(self.estimatorParamMaps)
        numModels = len(epm)
        eva = self.getOrDefault(self.evaluator)
        nFolds = self.getOrDefault(self.numFolds)
        metrics_all = [[0.0] * numModels for i in range(nFolds)]

        subModels = None
        collectSubModelsParam = self.getCollectSubModels()
        if collectSubModelsParam:
            subModels = [[None for j in range(numModels)] for i in range(nFolds)]

        datasets = self._kFold(dataset)
        # The pool is used as a context manager so its worker threads are
        # released on every exit path instead of being left to the garbage collector.
        with ThreadPool(processes=min(self.getParallelism(), numModels)) as pool:
            for i in range(nFolds):
                validation = datasets[i][1].cache()
                train = datasets[i][0].cache()
                try:
                    tasks = map(
                        inheritable_thread_target,
                        _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
                    )
                    for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
                        metrics_all[i][j] = metric
                        if collectSubModelsParam:
                            assert subModels is not None
                            subModels[i][j] = subModel
                finally:
                    # Release the cached splits even when a fit or evaluation fails.
                    validation.unpersist()
                    train.unpersist()

        metrics, std_metrics = CrossValidator._gen_avg_and_std_metrics(metrics_all)

        if eva.isLargerBetter():
            bestIndex = np.argmax(metrics)
        else:
            bestIndex = np.argmin(metrics)
        bestModel = est.fit(dataset, epm[bestIndex])
        return self._copyValues(
            CrossValidatorModel(bestModel, metrics, cast(List[List[Model]], subModels), std_metrics)
        )

    def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
        """
        Split ``dataset`` into ``numFolds`` (training, validation) pairs.

        Without ``foldCol`` a random column is added and each fold takes the
        rows whose random value falls in its ``[i/numFolds, (i+1)/numFolds)``
        slice. With ``foldCol`` the user-provided fold numbers are used.

        Raises
        ------
        ValueError
            With ``foldCol`` set, if the training or validation data of a
            fold is empty.
        """
        nFolds = self.getOrDefault(self.numFolds)
        foldCol = self.getOrDefault(self.foldCol)

        datasets = []
        if not foldCol:
            # Do random k-fold split.
            seed = self.getOrDefault(self.seed)
            h = 1.0 / nFolds
            randCol = self.uid + "_rand"
            df = dataset.select("*", rand(seed).alias(randCol))
            for i in range(nFolds):
                validateLB = i * h
                validateUB = (i + 1) * h
                condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
                validation = df.filter(condition)
                train = df.filter(~condition)
                datasets.append((train, validation))
        else:
            # Use user-specified fold numbers.
            def checker(foldNum: int) -> bool:
                if foldNum < 0 or foldNum >= nFolds:
                    raise ValueError(
                        "Fold number must be in range [0, %s), but got %s." % (nFolds, foldNum)
                    )
                return True

            checker_udf = UserDefinedFunction(checker, BooleanType())
            for i in range(nFolds):
                training = dataset.filter(checker_udf(dataset[foldCol]) & (col(foldCol) != lit(i)))
                validation = dataset.filter(
                    checker_udf(dataset[foldCol]) & (col(foldCol) == lit(i))
                )
                if training.rdd.getNumPartitions() == 0 or len(training.take(1)) == 0:
                    raise ValueError("The training data at fold %s is empty." % i)
                if validation.rdd.getNumPartitions() == 0 or len(validation.take(1)) == 0:
                    raise ValueError("The validation data at fold %s is empty." % i)
                datasets.append((training, validation))

        return datasets

    def copy(self, extra: Optional["ParamMap"] = None) -> "CrossValidator":
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This creates a deep copy of
        the embedded paramMap, and copies the embedded and extra parameters over.


        .. versionadded:: 1.4.0

        Parameters
        ----------
        extra : dict, optional
            Extra parameters to copy to the new instance

        Returns
        -------
        :py:class:`CrossValidator`
            Copy of this instance
        """
        if extra is None:
            extra = dict()
        newCV = Params.copy(self, extra)
        if self.isSet(self.estimator):
            newCV.setEstimator(self.getEstimator().copy(extra))
        # estimatorParamMaps remain the same
        if self.isSet(self.evaluator):
            newCV.setEvaluator(self.getEvaluator().copy(extra))
        return newCV

    @since("2.3.0")
    def write(self) -> MLWriter:
        """Returns an MLWriter instance for this ML instance."""
        if _ValidatorSharedReadWrite.is_java_convertible(self):
            return JavaMLWriter(self)  # type: ignore[arg-type]
        return CrossValidatorWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls) -> CrossValidatorReader:
        """Returns an MLReader instance for this class."""
        return CrossValidatorReader(cls)

    @classmethod
    def _from_java(cls, java_stage: "JavaObject") -> "CrossValidator":
        """
        Given a Java CrossValidator, create and return a Python wrapper of it.
        Used for ML persistence.
        """

        estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
        numFolds = java_stage.getNumFolds()
        seed = java_stage.getSeed()
        parallelism = java_stage.getParallelism()
        collectSubModels = java_stage.getCollectSubModels()
        foldCol = java_stage.getFoldCol()
        # Create a new instance of this stage.
        py_stage = cls(
            estimator=estimator,
            estimatorParamMaps=epms,
            evaluator=evaluator,
            numFolds=numFolds,
            seed=seed,
            parallelism=parallelism,
            collectSubModels=collectSubModels,
            foldCol=foldCol,
        )
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self) -> "JavaObject":
        """
        Transfer this instance to a Java CrossValidator. Used for ML persistence.

        Returns
        -------
        py4j.java_gateway.JavaObject
            Java object equivalent to this instance.
        """

        estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()

        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
        _java_obj.setEstimatorParamMaps(epms)
        _java_obj.setEvaluator(evaluator)
        _java_obj.setEstimator(estimator)
        _java_obj.setSeed(self.getSeed())
        _java_obj.setNumFolds(self.getNumFolds())
        _java_obj.setParallelism(self.getParallelism())
        _java_obj.setCollectSubModels(self.getCollectSubModels())
        _java_obj.setFoldCol(self.getFoldCol())

        return _java_obj
1000
+
1001
+
1002
class CrossValidatorModel(
    Model, _CrossValidatorParams, MLReadable["CrossValidatorModel"], MLWritable
):
    """
    CrossValidatorModel contains the model with the highest average cross-validation
    metric across folds and uses this model to transform input data. CrossValidatorModel
    also tracks the metrics for each param map evaluated.

    .. versionadded:: 1.4.0

    Notes
    -----
    Since version 3.3.0, CrossValidatorModel contains a new attribute "stdMetrics",
    which represent standard deviation of metrics for each paramMap in
    CrossValidator.estimatorParamMaps.
    """

    def __init__(
        self,
        bestModel: Model,
        avgMetrics: Optional[List[float]] = None,
        subModels: Optional[List[List[Model]]] = None,
        stdMetrics: Optional[List[float]] = None,
    ):
        super(CrossValidatorModel, self).__init__()
        #: best model from cross validation
        self.bestModel = bestModel
        #: Average cross-validation metrics for each paramMap in
        #: CrossValidator.estimatorParamMaps, in the corresponding order.
        self.avgMetrics = avgMetrics or []
        #: sub model list from cross validation
        self.subModels = subModels
        #: standard deviation of metrics for each paramMap in
        #: CrossValidator.estimatorParamMaps, in the corresponding order.
        self.stdMetrics = stdMetrics or []

    def _transform(self, dataset: DataFrame) -> DataFrame:
        """Transform ``dataset`` with the best model."""
        return self.bestModel.transform(dataset)

    def copy(self, extra: Optional["ParamMap"] = None) -> "CrossValidatorModel":
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This copies the underlying bestModel,
        creates a deep copy of the embedded paramMap, and
        copies the embedded and extra parameters over.
        It does not copy the extra Params into the subModels.

        .. versionadded:: 1.4.0

        Parameters
        ----------
        extra : dict, optional
            Extra parameters to copy to the new instance

        Returns
        -------
        :py:class:`CrossValidatorModel`
            Copy of this instance
        """
        if extra is None:
            extra = dict()
        bestModel = self.bestModel.copy(extra)
        avgMetrics = list(self.avgMetrics)
        # subModels is None unless the model was fitted with collectSubModels
        # enabled, so copying must not require them to be present.
        subModels: Optional[List[List[Model]]] = None
        if self.subModels is not None:
            subModels = [
                [sub_model.copy() for sub_model in fold_sub_models]
                for fold_sub_models in self.subModels
            ]
        stdMetrics = list(self.stdMetrics)
        return self._copyValues(
            CrossValidatorModel(bestModel, avgMetrics, subModels, stdMetrics), extra=extra
        )

    @since("2.3.0")
    def write(self) -> MLWriter:
        """Returns an MLWriter instance for this ML instance."""
        if _ValidatorSharedReadWrite.is_java_convertible(self):
            return JavaMLWriter(self)  # type: ignore[arg-type]
        return CrossValidatorModelWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls) -> CrossValidatorModelReader:
        """Returns an MLReader instance for this class."""
        return CrossValidatorModelReader(cls)

    @classmethod
    def _from_java(cls, java_stage: "JavaObject") -> "CrossValidatorModel":
        """
        Given a Java CrossValidatorModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        sc = SparkContext._active_spark_context
        assert sc is not None

        bestModel: Model = JavaParams._from_java(java_stage.bestModel())
        avgMetrics = _java2py(sc, java_stage.avgMetrics())
        estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)

        py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics)
        params = {
            "evaluator": evaluator,
            "estimator": estimator,
            "estimatorParamMaps": epms,
            "numFolds": java_stage.getNumFolds(),
            "foldCol": java_stage.getFoldCol(),
            "seed": java_stage.getSeed(),
        }
        for param_name, param_val in params.items():
            py_stage = py_stage._set(**{param_name: param_val})

        if java_stage.hasSubModels():
            py_stage.subModels = [
                [JavaParams._from_java(sub_model) for sub_model in fold_sub_models]
                for fold_sub_models in java_stage.subModels()
            ]

        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self) -> "JavaObject":
        """
        Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.

        Returns
        -------
        py4j.java_gateway.JavaObject
            Java object equivalent to this instance.
        """

        sc = SparkContext._active_spark_context
        assert sc is not None

        _java_obj = JavaParams._new_java_obj(
            "org.apache.spark.ml.tuning.CrossValidatorModel",
            self.uid,
            cast(JavaParams, self.bestModel)._to_java(),
            _py2java(sc, self.avgMetrics),
        )
        estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()

        params = {
            "evaluator": evaluator,
            "estimator": estimator,
            "estimatorParamMaps": epms,
            "numFolds": self.getNumFolds(),
            "foldCol": self.getFoldCol(),
            "seed": self.getSeed(),
        }
        for param_name, param_val in params.items():
            java_param = _java_obj.getParam(param_name)
            pair = java_param.w(param_val)
            _java_obj.set(pair)

        if self.subModels is not None:
            java_sub_models = [
                [cast(JavaParams, sub_model)._to_java() for sub_model in fold_sub_models]
                for fold_sub_models in self.subModels
            ]
            _java_obj.setSubModels(java_sub_models)
        return _java_obj
1163
+
1164
+
1165
@inherit_doc
class TrainValidationSplitReader(MLReader["TrainValidationSplit"]):
    """Reader for :py:class:`TrainValidationSplit` instances."""

    def __init__(self, cls: Type["TrainValidationSplit"]):
        super().__init__()
        self.cls = cls

    def load(self, path: str) -> "TrainValidationSplit":
        """Load a :py:class:`TrainValidationSplit` from ``path``."""
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            # Not written by the Python writer: delegate to the Java reader.
            return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]

        metadata, estimator, evaluator, param_maps = _ValidatorSharedReadWrite.load(
            path, self.sc, metadata
        )
        splitter = TrainValidationSplit(
            estimator=estimator, estimatorParamMaps=param_maps, evaluator=evaluator
        )
        splitter = splitter._resetUid(metadata["uid"])
        DefaultParamsReader.getAndSetParams(
            splitter, metadata, skipParams=["estimatorParamMaps"]
        )
        return splitter
1185
+
1186
+
1187
@inherit_doc
class TrainValidationSplitWriter(MLWriter):
    """Python writer for :py:class:`TrainValidationSplit` instances."""

    def __init__(self, instance: "TrainValidationSplit"):
        super().__init__()
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        """Validate the instance, then save it under ``path``."""
        splitter = self.instance
        _ValidatorSharedReadWrite.validateParams(splitter)
        _ValidatorSharedReadWrite.saveImpl(path, splitter, self.sc)
1196
+
1197
+
1198
@inherit_doc
class TrainValidationSplitModelReader(MLReader["TrainValidationSplitModel"]):
    """Reader for :py:class:`TrainValidationSplitModel` instances."""

    def __init__(self, cls: Type["TrainValidationSplitModel"]):
        super().__init__()
        self.cls = cls

    def load(self, path: str) -> "TrainValidationSplitModel":
        """
        Load a :py:class:`TrainValidationSplitModel` from ``path``.

        Sub-models are read from ``<path>/subModels/<index>`` only when the
        metadata has a truthy ``persistSubModels`` entry.
        """
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            # Not written by the Python writer: delegate to the Java reader.
            return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]

        metadata, estimator, evaluator, param_maps = _ValidatorSharedReadWrite.load(
            path, self.sc, metadata
        )
        best_model: Model = DefaultParamsReader.loadParamsInstance(
            os.path.join(path, "bestModel"), self.sc
        )
        validation_metrics = metadata["validationMetrics"]
        persisted = ("persistSubModels" in metadata) and metadata["persistSubModels"]

        if persisted:
            sub_models = [
                DefaultParamsReader.loadParamsInstance(
                    os.path.join(path, "subModels", f"{index}"), self.sc
                )
                for index in range(len(param_maps))
            ]
        else:
            sub_models = None

        model = TrainValidationSplitModel(
            best_model,
            validationMetrics=validation_metrics,
            subModels=cast(Optional[List[Model]], sub_models),
        )
        model = model._resetUid(metadata["uid"])
        model.set(model.estimator, estimator)
        model.set(model.estimatorParamMaps, param_maps)
        model.set(model.evaluator, evaluator)
        DefaultParamsReader.getAndSetParams(model, metadata, skipParams=["estimatorParamMaps"])
        return model
1240
+
1241
+
1242
@inherit_doc
class TrainValidationSplitModelWriter(MLWriter):
    """Writer used to save a :py:class:`TrainValidationSplitModel` from the Python side."""

    def __init__(self, instance: "TrainValidationSplitModel"):
        super().__init__()
        # The fitted model that ``saveImpl`` will write out.
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        """
        Save the wrapped model under ``path``.

        Writes the shared validator data (with the validation metrics and the
        ``persistSubModels`` flag as extra metadata), then the best model under
        ``bestModel`` and, when the flag is set, each sub-model under
        ``subModels/<index>``.

        Raises
        ------
        ValueError
            If sub-models should be persisted but the model holds none.
        """
        model = self.instance
        _ValidatorSharedReadWrite.validateParams(model)
        persistSubModels = _ValidatorSharedReadWrite.getValidatorModelWriterPersistSubModelsParam(
            self
        )

        _ValidatorSharedReadWrite.saveImpl(
            path,
            model,
            self.sc,
            extraMetadata={
                "validationMetrics": model.validationMetrics,
                "persistSubModels": persistSubModels,
            },
        )
        cast(MLWritable, model.bestModel).save(os.path.join(path, "bestModel"))

        if not persistSubModels:
            return
        if model.subModels is None:
            raise ValueError(_save_with_persist_submodels_no_submodels_found_err)

        # One directory per estimator param map, named by the param map's index.
        subModelsPath = os.path.join(path, "subModels")
        for index in range(len(model.getEstimatorParamMaps())):
            target = os.path.join(subModelsPath, f"{index}")
            cast(MLWritable, model.subModels[index]).save(target)
1269
+
1270
+
1271
class _TrainValidationSplitParams(_ValidatorParams):
    """
    Params for :py:class:`TrainValidationSplit` and :py:class:`TrainValidationSplitModel`.

    .. versionadded:: 3.0.0
    """

    # The description is assembled from adjacent literals instead of a backslash
    # continuation inside one literal: a continuation splices whatever leading
    # whitespace the next source line has (or none at all) into the user-visible text.
    trainRatio: Param[float] = Param(
        Params._dummy(),
        "trainRatio",
        "Param for ratio between train and validation data. "
        "Must be between 0 and 1.",
        typeConverter=TypeConverters.toFloat,
    )

    def __init__(self, *args: Any):
        """Forward ``args`` to the parent params class and register the trainRatio default."""
        super(_TrainValidationSplitParams, self).__init__(*args)
        self._setDefault(trainRatio=0.75)

    @since("2.0.0")
    def getTrainRatio(self) -> float:
        """
        Gets the value of trainRatio or its default value.
        """
        return self.getOrDefault(self.trainRatio)
1296
+
1297
+
1298
class TrainValidationSplit(
    Estimator["TrainValidationSplitModel"],
    _TrainValidationSplitParams,
    HasParallelism,
    HasCollectSubModels,
    MLReadable["TrainValidationSplit"],
    MLWritable,
):
    """
    Validation for hyper-parameter tuning. Randomly splits the input dataset into train and
    validation sets, and uses evaluation metric on the validation set to select the best model.
    Similar to :class:`CrossValidator`, but only splits the set once.

    .. versionadded:: 2.0.0

    Examples
    --------
    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder
    >>> from pyspark.ml.tuning import TrainValidationSplitModel
    >>> import tempfile
    >>> dataset = spark.createDataFrame(
    ...     [(Vectors.dense([0.0]), 0.0),
    ...      (Vectors.dense([0.4]), 1.0),
    ...      (Vectors.dense([0.5]), 0.0),
    ...      (Vectors.dense([0.6]), 1.0),
    ...      (Vectors.dense([1.0]), 1.0)] * 10,
    ...     ["features", "label"]).repartition(1)
    >>> lr = LogisticRegression()
    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    >>> evaluator = BinaryClassificationEvaluator()
    >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
    ...     parallelism=1, seed=42)
    >>> tvsModel = tvs.fit(dataset)
    >>> tvsModel.getTrainRatio()
    0.75
    >>> tvsModel.validationMetrics
    [0.5, ...
    >>> path = tempfile.mkdtemp()
    >>> model_path = path + "/model"
    >>> tvsModel.write().save(model_path)
    >>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
    >>> tvsModelRead.validationMetrics
    [0.5, ...
    >>> evaluator.evaluate(tvsModel.transform(dataset))
    0.833...
    >>> evaluator.evaluate(tvsModelRead.transform(dataset))
    0.833...
    """

    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        estimator: Optional[Estimator] = None,
        estimatorParamMaps: Optional[List["ParamMap"]] = None,
        evaluator: Optional[Evaluator] = None,
        trainRatio: float = 0.75,
        parallelism: int = 1,
        collectSubModels: bool = False,
        seed: Optional[int] = None,
    ) -> None:
        """
        __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \
                 trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None)
        """
        super(TrainValidationSplit, self).__init__()
        self._setDefault(parallelism=1)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @since("2.0.0")
    @keyword_only
    def setParams(
        self,
        *,
        estimator: Optional[Estimator] = None,
        estimatorParamMaps: Optional[List["ParamMap"]] = None,
        evaluator: Optional[Evaluator] = None,
        trainRatio: float = 0.75,
        parallelism: int = 1,
        collectSubModels: bool = False,
        seed: Optional[int] = None,
    ) -> "TrainValidationSplit":
        """
        setParams(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \
                  trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None):
        Sets params for the train validation split.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.0.0")
    def setEstimator(self, value: Estimator) -> "TrainValidationSplit":
        """
        Sets the value of :py:attr:`estimator`.
        """
        return self._set(estimator=value)

    @since("2.0.0")
    def setEstimatorParamMaps(self, value: List["ParamMap"]) -> "TrainValidationSplit":
        """
        Sets the value of :py:attr:`estimatorParamMaps`.
        """
        return self._set(estimatorParamMaps=value)

    @since("2.0.0")
    def setEvaluator(self, value: Evaluator) -> "TrainValidationSplit":
        """
        Sets the value of :py:attr:`evaluator`.
        """
        return self._set(evaluator=value)

    @since("2.0.0")
    def setTrainRatio(self, value: float) -> "TrainValidationSplit":
        """
        Sets the value of :py:attr:`trainRatio`.
        """
        return self._set(trainRatio=value)

    def setSeed(self, value: int) -> "TrainValidationSplit":
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    def setParallelism(self, value: int) -> "TrainValidationSplit":
        """
        Sets the value of :py:attr:`parallelism`.
        """
        return self._set(parallelism=value)

    def setCollectSubModels(self, value: bool) -> "TrainValidationSplit":
        """
        Sets the value of :py:attr:`collectSubModels`.
        """
        return self._set(collectSubModels=value)

    def _fit(self, dataset: DataFrame) -> "TrainValidationSplitModel":
        """
        Fit one model per estimator param map and keep the best one.

        A random column is appended to ``dataset``; rows whose random value is
        greater than or equal to ``trainRatio`` form the validation set and the
        remaining rows form the training set. Every param map is fitted and
        evaluated through a thread pool, the best metric (largest or smallest,
        depending on ``evaluator.isLargerBetter()``) selects a param map, and
        the estimator is then refitted with it on the full ``dataset``.

        The thread pool is released and both cached splits are unpersisted even
        when fitting or evaluating a param map raises.
        """
        est = self.getOrDefault(self.estimator)
        epm = self.getOrDefault(self.estimatorParamMaps)
        numModels = len(epm)
        eva = self.getOrDefault(self.evaluator)
        tRatio = self.getOrDefault(self.trainRatio)
        seed = self.getOrDefault(self.seed)
        randCol = self.uid + "_rand"
        df = dataset.select("*", rand(seed).alias(randCol))
        condition = df[randCol] >= tRatio
        validation = df.filter(condition).cache()
        train = df.filter(~condition).cache()

        collectSubModelsParam = self.getCollectSubModels()
        # Sub-models are only retained when collectSubModels is enabled.
        subModels = [None] * numModels if collectSubModelsParam else None
        metrics = [None] * numModels

        try:
            tasks = map(
                inheritable_thread_target,
                _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
            )
            # The context manager shuts the pool down on exit, so worker threads
            # do not accumulate across repeated fits.
            with ThreadPool(processes=min(self.getParallelism(), numModels)) as pool:
                # Results arrive in completion order; each carries the index of
                # the param map it belongs to.
                for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
                    metrics[j] = metric
                    if subModels is not None:
                        subModels[j] = subModel
        finally:
            # Drop the cached splits whether or not the fits succeeded.
            train.unpersist()
            validation.unpersist()

        if eva.isLargerBetter():
            bestIndex = np.argmax(cast(List[float], metrics))
        else:
            bestIndex = np.argmin(cast(List[float], metrics))
        bestModel = est.fit(dataset, epm[bestIndex])
        return self._copyValues(
            TrainValidationSplitModel(
                bestModel,
                cast(List[float], metrics),
                subModels,  # type: ignore[arg-type]
            )
        )

    def copy(self, extra: Optional["ParamMap"] = None) -> "TrainValidationSplit":
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This creates a deep copy of
        the embedded paramMap, and copies the embedded and extra parameters over.

        .. versionadded:: 2.0.0

        Parameters
        ----------
        extra : dict, optional
            Extra parameters to copy to the new instance

        Returns
        -------
        :py:class:`TrainValidationSplit`
            Copy of this instance
        """
        if extra is None:
            extra = dict()
        newTVS = Params.copy(self, extra)
        if self.isSet(self.estimator):
            newTVS.setEstimator(self.getEstimator().copy(extra))
        # estimatorParamMaps remain the same
        if self.isSet(self.evaluator):
            newTVS.setEvaluator(self.getEvaluator().copy(extra))
        return newTVS

    @since("2.3.0")
    def write(self) -> MLWriter:
        """Returns an MLWriter instance for this ML instance."""
        if _ValidatorSharedReadWrite.is_java_convertible(self):
            return JavaMLWriter(self)  # type: ignore[arg-type]
        return TrainValidationSplitWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls) -> TrainValidationSplitReader:
        """Returns an MLReader instance for this class."""
        return TrainValidationSplitReader(cls)

    @classmethod
    def _from_java(cls, java_stage: "JavaObject") -> "TrainValidationSplit":
        """
        Given a Java TrainValidationSplit, create and return a Python wrapper of it.
        Used for ML persistence.
        """

        estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
        trainRatio = java_stage.getTrainRatio()
        seed = java_stage.getSeed()
        parallelism = java_stage.getParallelism()
        collectSubModels = java_stage.getCollectSubModels()
        # Create a new instance of this stage.
        py_stage = cls(
            estimator=estimator,
            estimatorParamMaps=epms,
            evaluator=evaluator,
            trainRatio=trainRatio,
            seed=seed,
            parallelism=parallelism,
            collectSubModels=collectSubModels,
        )
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self) -> "JavaObject":
        """
        Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.

        Returns
        -------
        py4j.java_gateway.JavaObject
            Java object equivalent to this instance.
        """

        estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()

        _java_obj = JavaParams._new_java_obj(
            "org.apache.spark.ml.tuning.TrainValidationSplit", self.uid
        )
        _java_obj.setEstimatorParamMaps(epms)
        _java_obj.setEvaluator(evaluator)
        _java_obj.setEstimator(estimator)
        _java_obj.setTrainRatio(self.getTrainRatio())
        _java_obj.setSeed(self.getSeed())
        _java_obj.setParallelism(self.getParallelism())
        _java_obj.setCollectSubModels(self.getCollectSubModels())
        return _java_obj
1575
+
1576
+
1577
class TrainValidationSplitModel(
    Model, _TrainValidationSplitParams, MLReadable["TrainValidationSplitModel"], MLWritable
):
    """
    Model from train validation split.

    .. versionadded:: 2.0.0
    """

    def __init__(
        self,
        bestModel: Model,
        validationMetrics: Optional[List[float]] = None,
        subModels: Optional[List[Model]] = None,
    ):
        super(TrainValidationSplitModel, self).__init__()
        #: best model from train validation split
        self.bestModel = bestModel
        #: evaluated validation metrics
        self.validationMetrics = validationMetrics or []
        #: sub models from train validation split
        self.subModels = subModels

    def _transform(self, dataset: DataFrame) -> DataFrame:
        """Transform ``dataset`` with the best model."""
        return self.bestModel.transform(dataset)

    def copy(self, extra: Optional["ParamMap"] = None) -> "TrainValidationSplitModel":
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This copies the underlying bestModel,
        creates a deep copy of the embedded paramMap, and
        copies the embedded and extra parameters over.
        And, this creates a shallow copy of the validationMetrics.
        It does not copy the extra Params into the subModels.
        When this instance holds no subModels, neither does the copy.

        .. versionadded:: 2.0.0

        Parameters
        ----------
        extra : dict, optional
            Extra parameters to copy to the new instance

        Returns
        -------
        :py:class:`TrainValidationSplitModel`
            Copy of this instance
        """
        if extra is None:
            extra = dict()
        bestModel = self.bestModel.copy(extra)
        validationMetrics = list(self.validationMetrics)
        # subModels is optional (the constructor defaults it to None), so a model
        # without sub-models must still be copyable.
        subModels = (
            None if self.subModels is None else [model.copy() for model in self.subModels]
        )
        return self._copyValues(
            TrainValidationSplitModel(bestModel, validationMetrics, subModels), extra=extra
        )

    @since("2.3.0")
    def write(self) -> MLWriter:
        """Returns an MLWriter instance for this ML instance."""
        if _ValidatorSharedReadWrite.is_java_convertible(self):
            return JavaMLWriter(self)  # type: ignore[arg-type]
        return TrainValidationSplitModelWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls) -> TrainValidationSplitModelReader:
        """Returns an MLReader instance for this class."""
        return TrainValidationSplitModelReader(cls)

    @classmethod
    def _from_java(cls, java_stage: "JavaObject") -> "TrainValidationSplitModel":
        """
        Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """

        # Load information from java_stage to the instance.
        sc = SparkContext._active_spark_context
        assert sc is not None

        bestModel: Model = JavaParams._from_java(java_stage.bestModel())
        validationMetrics = _java2py(sc, java_stage.validationMetrics())
        estimator, epms, evaluator = super(TrainValidationSplitModel, cls)._from_java_impl(
            java_stage
        )
        # Create a new instance of this stage.
        py_stage = cls(bestModel=bestModel, validationMetrics=validationMetrics)
        params = {
            "evaluator": evaluator,
            "estimator": estimator,
            "estimatorParamMaps": epms,
            "trainRatio": java_stage.getTrainRatio(),
            "seed": java_stage.getSeed(),
        }
        for param_name, param_val in params.items():
            py_stage = py_stage._set(**{param_name: param_val})

        # Sub-models are wrapped only when the Java stage reports having them.
        if java_stage.hasSubModels():
            py_stage.subModels = [
                JavaParams._from_java(sub_model) for sub_model in java_stage.subModels()
            ]

        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self) -> "JavaObject":
        """
        Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.

        Returns
        -------
        py4j.java_gateway.JavaObject
            Java object equivalent to this instance.
        """

        sc = SparkContext._active_spark_context
        assert sc is not None

        _java_obj = JavaParams._new_java_obj(
            "org.apache.spark.ml.tuning.TrainValidationSplitModel",
            self.uid,
            cast(JavaParams, self.bestModel)._to_java(),
            _py2java(sc, self.validationMetrics),
        )
        estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()

        params = {
            "evaluator": evaluator,
            "estimator": estimator,
            "estimatorParamMaps": epms,
            "trainRatio": self.getTrainRatio(),
            "seed": self.getSeed(),
        }
        for param_name, param_val in params.items():
            java_param = _java_obj.getParam(param_name)
            pair = java_param.w(param_val)
            _java_obj.set(pair)

        # Sub-models are transferred only when this instance holds some.
        if self.subModels is not None:
            java_sub_models = [
                cast(JavaParams, sub_model)._to_java() for sub_model in self.subModels
            ]
            _java_obj.setSubModels(java_sub_models)

        return _java_obj
1723
+
1724
+
1725
if __name__ == "__main__":
    import doctest

    from pyspark.sql import SparkSession

    # Doctests run against a copy of this module's globals, extended below with
    # a live ``spark`` session and its context as ``sc``.
    globs = globals().copy()

    spark = (
        SparkSession.builder.master("local[2]")
        .appName("ml.tuning tests")
        .getOrCreate()
    )
    sc = spark.sparkContext
    globs.update(sc=sc, spark=spark)

    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    # Stop the session before reporting, so it is shut down on failure as well.
    spark.stop()
    if failure_count:
        sys.exit(-1)