snowpark-connect 0.20.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of snowpark-connect might be problematic. See the package's registry page for more details.

Files changed (879)
  1. snowflake/snowpark_connect/__init__.py +23 -0
  2. snowflake/snowpark_connect/analyze_plan/__init__.py +3 -0
  3. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +38 -0
  4. snowflake/snowpark_connect/column_name_handler.py +735 -0
  5. snowflake/snowpark_connect/config.py +576 -0
  6. snowflake/snowpark_connect/constants.py +47 -0
  7. snowflake/snowpark_connect/control_server.py +52 -0
  8. snowflake/snowpark_connect/dataframe_name_handler.py +54 -0
  9. snowflake/snowpark_connect/date_time_format_mapping.py +399 -0
  10. snowflake/snowpark_connect/empty_dataframe.py +18 -0
  11. snowflake/snowpark_connect/error/__init__.py +11 -0
  12. snowflake/snowpark_connect/error/error_mapping.py +6174 -0
  13. snowflake/snowpark_connect/error/error_utils.py +321 -0
  14. snowflake/snowpark_connect/error/exceptions.py +24 -0
  15. snowflake/snowpark_connect/execute_plan/__init__.py +3 -0
  16. snowflake/snowpark_connect/execute_plan/map_execution_command.py +204 -0
  17. snowflake/snowpark_connect/execute_plan/map_execution_root.py +173 -0
  18. snowflake/snowpark_connect/execute_plan/utils.py +183 -0
  19. snowflake/snowpark_connect/expression/__init__.py +3 -0
  20. snowflake/snowpark_connect/expression/literal.py +90 -0
  21. snowflake/snowpark_connect/expression/map_cast.py +343 -0
  22. snowflake/snowpark_connect/expression/map_expression.py +293 -0
  23. snowflake/snowpark_connect/expression/map_extension.py +104 -0
  24. snowflake/snowpark_connect/expression/map_sql_expression.py +633 -0
  25. snowflake/snowpark_connect/expression/map_udf.py +142 -0
  26. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +241 -0
  27. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +85 -0
  28. snowflake/snowpark_connect/expression/map_unresolved_function.py +9450 -0
  29. snowflake/snowpark_connect/expression/map_unresolved_star.py +218 -0
  30. snowflake/snowpark_connect/expression/map_update_fields.py +164 -0
  31. snowflake/snowpark_connect/expression/map_window_function.py +258 -0
  32. snowflake/snowpark_connect/expression/typer.py +125 -0
  33. snowflake/snowpark_connect/includes/__init__.py +0 -0
  34. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  35. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  36. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  37. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  39. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/hadoop-client-api-3.3.4.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/scala-compiler-2.12.18.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-kubernetes_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-mllib_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-streaming_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/python/__init__.py +21 -0
  99. snowflake/snowpark_connect/includes/python/pyspark/__init__.py +173 -0
  100. snowflake/snowpark_connect/includes/python/pyspark/_globals.py +71 -0
  101. snowflake/snowpark_connect/includes/python/pyspark/_typing.pyi +43 -0
  102. snowflake/snowpark_connect/includes/python/pyspark/accumulators.py +341 -0
  103. snowflake/snowpark_connect/includes/python/pyspark/broadcast.py +383 -0
  104. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/__init__.py +8 -0
  105. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle.py +948 -0
  106. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle_fast.py +844 -0
  107. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/compat.py +18 -0
  108. snowflake/snowpark_connect/includes/python/pyspark/conf.py +276 -0
  109. snowflake/snowpark_connect/includes/python/pyspark/context.py +2601 -0
  110. snowflake/snowpark_connect/includes/python/pyspark/daemon.py +218 -0
  111. snowflake/snowpark_connect/includes/python/pyspark/errors/__init__.py +70 -0
  112. snowflake/snowpark_connect/includes/python/pyspark/errors/error_classes.py +889 -0
  113. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/__init__.py +16 -0
  114. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/base.py +228 -0
  115. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/captured.py +307 -0
  116. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/connect.py +190 -0
  117. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/__init__.py +16 -0
  118. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/test_errors.py +60 -0
  119. snowflake/snowpark_connect/includes/python/pyspark/errors/utils.py +116 -0
  120. snowflake/snowpark_connect/includes/python/pyspark/files.py +165 -0
  121. snowflake/snowpark_connect/includes/python/pyspark/find_spark_home.py +95 -0
  122. snowflake/snowpark_connect/includes/python/pyspark/install.py +203 -0
  123. snowflake/snowpark_connect/includes/python/pyspark/instrumentation_utils.py +190 -0
  124. snowflake/snowpark_connect/includes/python/pyspark/java_gateway.py +248 -0
  125. snowflake/snowpark_connect/includes/python/pyspark/join.py +118 -0
  126. snowflake/snowpark_connect/includes/python/pyspark/ml/__init__.py +71 -0
  127. snowflake/snowpark_connect/includes/python/pyspark/ml/_typing.pyi +84 -0
  128. snowflake/snowpark_connect/includes/python/pyspark/ml/base.py +414 -0
  129. snowflake/snowpark_connect/includes/python/pyspark/ml/classification.py +4332 -0
  130. snowflake/snowpark_connect/includes/python/pyspark/ml/clustering.py +2188 -0
  131. snowflake/snowpark_connect/includes/python/pyspark/ml/common.py +146 -0
  132. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/__init__.py +44 -0
  133. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/base.py +346 -0
  134. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/classification.py +382 -0
  135. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/evaluation.py +291 -0
  136. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/feature.py +258 -0
  137. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/functions.py +77 -0
  138. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/io_utils.py +335 -0
  139. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/pipeline.py +262 -0
  140. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/summarizer.py +120 -0
  141. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/tuning.py +579 -0
  142. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/util.py +173 -0
  143. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/__init__.py +16 -0
  144. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/deepspeed_distributor.py +165 -0
  145. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py +306 -0
  146. snowflake/snowpark_connect/includes/python/pyspark/ml/dl_util.py +150 -0
  147. snowflake/snowpark_connect/includes/python/pyspark/ml/evaluation.py +1166 -0
  148. snowflake/snowpark_connect/includes/python/pyspark/ml/feature.py +7474 -0
  149. snowflake/snowpark_connect/includes/python/pyspark/ml/fpm.py +543 -0
  150. snowflake/snowpark_connect/includes/python/pyspark/ml/functions.py +842 -0
  151. snowflake/snowpark_connect/includes/python/pyspark/ml/image.py +271 -0
  152. snowflake/snowpark_connect/includes/python/pyspark/ml/linalg/__init__.py +1382 -0
  153. snowflake/snowpark_connect/includes/python/pyspark/ml/model_cache.py +55 -0
  154. snowflake/snowpark_connect/includes/python/pyspark/ml/param/__init__.py +602 -0
  155. snowflake/snowpark_connect/includes/python/pyspark/ml/param/_shared_params_code_gen.py +368 -0
  156. snowflake/snowpark_connect/includes/python/pyspark/ml/param/shared.py +878 -0
  157. snowflake/snowpark_connect/includes/python/pyspark/ml/pipeline.py +451 -0
  158. snowflake/snowpark_connect/includes/python/pyspark/ml/recommendation.py +748 -0
  159. snowflake/snowpark_connect/includes/python/pyspark/ml/regression.py +3335 -0
  160. snowflake/snowpark_connect/includes/python/pyspark/ml/stat.py +523 -0
  161. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/__init__.py +16 -0
  162. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_classification.py +53 -0
  163. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_evaluation.py +50 -0
  164. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_feature.py +43 -0
  165. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_function.py +114 -0
  166. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_pipeline.py +47 -0
  167. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_summarizer.py +43 -0
  168. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_tuning.py +46 -0
  169. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py +238 -0
  170. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py +194 -0
  171. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +156 -0
  172. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py +184 -0
  173. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py +78 -0
  174. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +292 -0
  175. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +50 -0
  176. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +152 -0
  177. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_algorithms.py +456 -0
  178. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_base.py +96 -0
  179. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_dl_util.py +186 -0
  180. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_evaluation.py +77 -0
  181. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_feature.py +401 -0
  182. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_functions.py +528 -0
  183. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_image.py +82 -0
  184. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_linalg.py +409 -0
  185. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_model_cache.py +55 -0
  186. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_param.py +441 -0
  187. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_persistence.py +546 -0
  188. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_pipeline.py +71 -0
  189. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_stat.py +52 -0
  190. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_training_summary.py +494 -0
  191. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_util.py +85 -0
  192. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_wrapper.py +138 -0
  193. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/__init__.py +16 -0
  194. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_basic.py +151 -0
  195. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_nested.py +97 -0
  196. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py +143 -0
  197. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tuning.py +551 -0
  198. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_basic.py +137 -0
  199. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_nested.py +96 -0
  200. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_pipeline.py +142 -0
  201. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/__init__.py +16 -0
  202. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/data.py +100 -0
  203. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/distributor.py +1133 -0
  204. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/log_communication.py +198 -0
  205. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/__init__.py +16 -0
  206. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_data_loader.py +137 -0
  207. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_distributor.py +561 -0
  208. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_log_communication.py +172 -0
  209. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/torch_run_process_wrapper.py +83 -0
  210. snowflake/snowpark_connect/includes/python/pyspark/ml/tree.py +434 -0
  211. snowflake/snowpark_connect/includes/python/pyspark/ml/tuning.py +1741 -0
  212. snowflake/snowpark_connect/includes/python/pyspark/ml/util.py +749 -0
  213. snowflake/snowpark_connect/includes/python/pyspark/ml/wrapper.py +465 -0
  214. snowflake/snowpark_connect/includes/python/pyspark/mllib/__init__.py +44 -0
  215. snowflake/snowpark_connect/includes/python/pyspark/mllib/_typing.pyi +33 -0
  216. snowflake/snowpark_connect/includes/python/pyspark/mllib/classification.py +989 -0
  217. snowflake/snowpark_connect/includes/python/pyspark/mllib/clustering.py +1318 -0
  218. snowflake/snowpark_connect/includes/python/pyspark/mllib/common.py +174 -0
  219. snowflake/snowpark_connect/includes/python/pyspark/mllib/evaluation.py +691 -0
  220. snowflake/snowpark_connect/includes/python/pyspark/mllib/feature.py +1085 -0
  221. snowflake/snowpark_connect/includes/python/pyspark/mllib/fpm.py +233 -0
  222. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/__init__.py +1653 -0
  223. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/distributed.py +1662 -0
  224. snowflake/snowpark_connect/includes/python/pyspark/mllib/random.py +698 -0
  225. snowflake/snowpark_connect/includes/python/pyspark/mllib/recommendation.py +389 -0
  226. snowflake/snowpark_connect/includes/python/pyspark/mllib/regression.py +1067 -0
  227. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/KernelDensity.py +59 -0
  228. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/__init__.py +34 -0
  229. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/_statistics.py +409 -0
  230. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/distribution.py +39 -0
  231. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/test.py +86 -0
  232. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/__init__.py +16 -0
  233. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_algorithms.py +353 -0
  234. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_feature.py +192 -0
  235. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_linalg.py +680 -0
  236. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_stat.py +206 -0
  237. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_streaming_algorithms.py +471 -0
  238. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_util.py +108 -0
  239. snowflake/snowpark_connect/includes/python/pyspark/mllib/tree.py +888 -0
  240. snowflake/snowpark_connect/includes/python/pyspark/mllib/util.py +659 -0
  241. snowflake/snowpark_connect/includes/python/pyspark/pandas/__init__.py +165 -0
  242. snowflake/snowpark_connect/includes/python/pyspark/pandas/_typing.py +52 -0
  243. snowflake/snowpark_connect/includes/python/pyspark/pandas/accessors.py +989 -0
  244. snowflake/snowpark_connect/includes/python/pyspark/pandas/base.py +1804 -0
  245. snowflake/snowpark_connect/includes/python/pyspark/pandas/categorical.py +822 -0
  246. snowflake/snowpark_connect/includes/python/pyspark/pandas/config.py +539 -0
  247. snowflake/snowpark_connect/includes/python/pyspark/pandas/correlation.py +262 -0
  248. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/__init__.py +16 -0
  249. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/base.py +519 -0
  250. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/binary_ops.py +98 -0
  251. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/boolean_ops.py +426 -0
  252. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/categorical_ops.py +141 -0
  253. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/complex_ops.py +145 -0
  254. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/date_ops.py +127 -0
  255. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/datetime_ops.py +171 -0
  256. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/null_ops.py +83 -0
  257. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/num_ops.py +588 -0
  258. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/string_ops.py +154 -0
  259. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/timedelta_ops.py +101 -0
  260. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/udt_ops.py +29 -0
  261. snowflake/snowpark_connect/includes/python/pyspark/pandas/datetimes.py +891 -0
  262. snowflake/snowpark_connect/includes/python/pyspark/pandas/exceptions.py +150 -0
  263. snowflake/snowpark_connect/includes/python/pyspark/pandas/extensions.py +388 -0
  264. snowflake/snowpark_connect/includes/python/pyspark/pandas/frame.py +13738 -0
  265. snowflake/snowpark_connect/includes/python/pyspark/pandas/generic.py +3560 -0
  266. snowflake/snowpark_connect/includes/python/pyspark/pandas/groupby.py +4448 -0
  267. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/__init__.py +21 -0
  268. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/base.py +2783 -0
  269. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/category.py +773 -0
  270. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/datetimes.py +843 -0
  271. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/multi.py +1323 -0
  272. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/numeric.py +210 -0
  273. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/timedelta.py +197 -0
  274. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexing.py +1862 -0
  275. snowflake/snowpark_connect/includes/python/pyspark/pandas/internal.py +1680 -0
  276. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/__init__.py +48 -0
  277. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/common.py +76 -0
  278. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/frame.py +63 -0
  279. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/general_functions.py +43 -0
  280. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/groupby.py +93 -0
  281. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/indexes.py +184 -0
  282. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/resample.py +101 -0
  283. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/scalars.py +29 -0
  284. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/series.py +69 -0
  285. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/window.py +168 -0
  286. snowflake/snowpark_connect/includes/python/pyspark/pandas/mlflow.py +238 -0
  287. snowflake/snowpark_connect/includes/python/pyspark/pandas/namespace.py +3807 -0
  288. snowflake/snowpark_connect/includes/python/pyspark/pandas/numpy_compat.py +260 -0
  289. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/__init__.py +17 -0
  290. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/core.py +1213 -0
  291. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/matplotlib.py +928 -0
  292. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/plotly.py +261 -0
  293. snowflake/snowpark_connect/includes/python/pyspark/pandas/resample.py +816 -0
  294. snowflake/snowpark_connect/includes/python/pyspark/pandas/series.py +7440 -0
  295. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_formatter.py +308 -0
  296. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_processor.py +394 -0
  297. snowflake/snowpark_connect/includes/python/pyspark/pandas/strings.py +2371 -0
  298. snowflake/snowpark_connect/includes/python/pyspark/pandas/supported_api_gen.py +378 -0
  299. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/__init__.py +16 -0
  300. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/__init__.py +16 -0
  301. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_any_all.py +177 -0
  302. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_apply_func.py +575 -0
  303. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_binary_ops.py +235 -0
  304. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_combine.py +653 -0
  305. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_compute.py +463 -0
  306. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_corrwith.py +86 -0
  307. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cov.py +151 -0
  308. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cumulative.py +139 -0
  309. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_describe.py +458 -0
  310. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_eval.py +86 -0
  311. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_melt.py +202 -0
  312. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_missing_data.py +520 -0
  313. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_pivot.py +361 -0
  314. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/__init__.py +16 -0
  315. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/__init__.py +16 -0
  316. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_any_all.py +40 -0
  317. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_apply_func.py +42 -0
  318. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_binary_ops.py +40 -0
  319. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_combine.py +37 -0
  320. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py +60 -0
  321. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_corrwith.py +40 -0
  322. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cov.py +40 -0
  323. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cumulative.py +90 -0
  324. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_describe.py +40 -0
  325. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_eval.py +40 -0
  326. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_melt.py +40 -0
  327. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_missing_data.py +42 -0
  328. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +37 -0
  329. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/__init__.py +16 -0
  330. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_base.py +36 -0
  331. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py +42 -0
  332. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py +47 -0
  333. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py +55 -0
  334. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_complex_ops.py +40 -0
  335. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py +47 -0
  336. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py +47 -0
  337. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py +42 -0
  338. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_arithmetic.py +43 -0
  339. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py +47 -0
  340. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_reverse.py +43 -0
  341. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py +47 -0
  342. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py +47 -0
  343. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py +40 -0
  344. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/testing_utils.py +226 -0
  345. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/__init__.py +16 -0
  346. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_align.py +39 -0
  347. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_basic_slow.py +55 -0
  348. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_cov_corrwith.py +39 -0
  349. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_frame.py +39 -0
  350. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_series.py +39 -0
  351. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_index.py +39 -0
  352. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_series.py +39 -0
  353. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_frame.py +43 -0
  354. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_series.py +43 -0
  355. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/__init__.py +16 -0
  356. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_attrs.py +40 -0
  357. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py +39 -0
  358. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_conversion.py +42 -0
  359. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reindexing.py +42 -0
  360. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +37 -0
  361. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_spark.py +40 -0
  362. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_take.py +42 -0
  363. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_time_series.py +48 -0
  364. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_truncate.py +40 -0
  365. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/__init__.py +16 -0
  366. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_aggregate.py +40 -0
  367. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_apply_func.py +41 -0
  368. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_cumulative.py +67 -0
  369. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_describe.py +40 -0
  370. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_groupby.py +55 -0
  371. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_head_tail.py +40 -0
  372. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_index.py +38 -0
  373. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py +55 -0
  374. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_split_apply.py +39 -0
  375. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +38 -0
  376. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/__init__.py +16 -0
  377. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_align.py +40 -0
  378. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py +50 -0
  379. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py +73 -0
  380. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_datetime.py +39 -0
  381. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing.py +40 -0
  382. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reindex.py +40 -0
  383. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_rename.py +40 -0
  384. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reset_index.py +48 -0
  385. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_timedelta.py +39 -0
  386. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/__init__.py +16 -0
  387. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/test_parity_io.py +40 -0
  388. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/__init__.py +16 -0
  389. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot.py +45 -0
  390. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_matplotlib.py +45 -0
  391. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_plotly.py +49 -0
  392. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot.py +37 -0
  393. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_matplotlib.py +53 -0
  394. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_plotly.py +45 -0
  395. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/__init__.py +16 -0
  396. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_all_any.py +38 -0
  397. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_arg_ops.py +37 -0
  398. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_of.py +37 -0
  399. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_type.py +38 -0
  400. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +37 -0
  401. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_conversion.py +40 -0
  402. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_cumulative.py +40 -0
  403. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_index.py +38 -0
  404. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_missing_data.py +40 -0
  405. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_series.py +37 -0
  406. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_sort.py +38 -0
  407. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_stat.py +38 -0
  408. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_categorical.py +66 -0
  409. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_config.py +37 -0
  410. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_csv.py +37 -0
  411. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_conversion.py +42 -0
  412. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_spark_io.py +39 -0
  413. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_default_index.py +49 -0
  414. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ewm.py +37 -0
  415. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_expanding.py +39 -0
  416. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_extension.py +49 -0
  417. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_frame_spark.py +53 -0
  418. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py +43 -0
  419. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexing.py +49 -0
  420. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexops_spark.py +39 -0
  421. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_internal.py +41 -0
  422. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_namespace.py +39 -0
  423. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_numpy_compat.py +60 -0
  424. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames.py +48 -0
  425. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby.py +39 -0
  426. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_expanding.py +44 -0
  427. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_rolling.py +84 -0
  428. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_repr.py +37 -0
  429. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_resample.py +45 -0
  430. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_reshape.py +39 -0
  431. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_rolling.py +39 -0
  432. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_scalars.py +37 -0
  433. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_conversion.py +39 -0
  434. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_datetime.py +39 -0
  435. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_string.py +39 -0
  436. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_spark_functions.py +39 -0
  437. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_sql.py +43 -0
  438. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_stats.py +37 -0
  439. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_typedef.py +36 -0
  440. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_utils.py +37 -0
  441. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_window.py +39 -0
  442. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/__init__.py +16 -0
  443. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_base.py +107 -0
  444. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +224 -0
  445. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +825 -0
  446. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +562 -0
  447. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +368 -0
  448. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +257 -0
  449. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +260 -0
  450. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +178 -0
  451. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_arithmetic.py +184 -0
  452. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +497 -0
  453. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_reverse.py +140 -0
  454. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +354 -0
  455. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +219 -0
  456. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +192 -0
  457. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +228 -0
  458. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/__init__.py +16 -0
  459. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_align.py +118 -0
  460. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_basic_slow.py +198 -0
  461. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_cov_corrwith.py +181 -0
  462. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_frame.py +103 -0
  463. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_series.py +141 -0
  464. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_index.py +109 -0
  465. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_series.py +136 -0
  466. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_frame.py +125 -0
  467. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_series.py +217 -0
  468. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/__init__.py +16 -0
  469. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_attrs.py +384 -0
  470. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_constructor.py +598 -0
  471. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_conversion.py +73 -0
  472. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reindexing.py +869 -0
  473. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reshaping.py +487 -0
  474. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_spark.py +309 -0
  475. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_take.py +156 -0
  476. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_time_series.py +149 -0
  477. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_truncate.py +163 -0
  478. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/__init__.py +16 -0
  479. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_aggregate.py +311 -0
  480. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_apply_func.py +524 -0
  481. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_cumulative.py +419 -0
  482. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_describe.py +144 -0
  483. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_groupby.py +979 -0
  484. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_head_tail.py +234 -0
  485. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_index.py +206 -0
  486. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_missing_data.py +421 -0
  487. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_split_apply.py +187 -0
  488. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_stat.py +397 -0
  489. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/__init__.py +16 -0
  490. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_align.py +100 -0
  491. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_base.py +2743 -0
  492. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_category.py +484 -0
  493. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_datetime.py +276 -0
  494. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_indexing.py +432 -0
  495. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reindex.py +310 -0
  496. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_rename.py +257 -0
  497. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reset_index.py +160 -0
  498. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_timedelta.py +128 -0
  499. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/__init__.py +16 -0
  500. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/test_io.py +137 -0
  501. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/__init__.py +16 -0
  502. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot.py +170 -0
  503. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +547 -0
  504. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +285 -0
  505. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot.py +106 -0
  506. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +409 -0
  507. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +247 -0
  508. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/__init__.py +16 -0
  509. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_all_any.py +105 -0
  510. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_arg_ops.py +197 -0
  511. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_of.py +137 -0
  512. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_type.py +227 -0
  513. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_compute.py +634 -0
  514. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_conversion.py +88 -0
  515. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_cumulative.py +139 -0
  516. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_index.py +475 -0
  517. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_missing_data.py +265 -0
  518. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_series.py +818 -0
  519. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_sort.py +162 -0
  520. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_stat.py +780 -0
  521. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_categorical.py +741 -0
  522. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_config.py +160 -0
  523. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_csv.py +453 -0
  524. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_conversion.py +281 -0
  525. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_spark_io.py +487 -0
  526. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_default_index.py +109 -0
  527. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ewm.py +434 -0
  528. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_expanding.py +253 -0
  529. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_extension.py +152 -0
  530. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_frame_spark.py +162 -0
  531. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_generic_functions.py +234 -0
  532. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexing.py +1339 -0
  533. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexops_spark.py +82 -0
  534. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_internal.py +124 -0
  535. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_namespace.py +638 -0
  536. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_numpy_compat.py +200 -0
  537. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +1355 -0
  538. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +655 -0
  539. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +113 -0
  540. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +118 -0
  541. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_repr.py +192 -0
  542. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_resample.py +346 -0
  543. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_reshape.py +495 -0
  544. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_rolling.py +263 -0
  545. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_scalars.py +59 -0
  546. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_conversion.py +85 -0
  547. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_datetime.py +364 -0
  548. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_string.py +362 -0
  549. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_spark_functions.py +46 -0
  550. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_sql.py +123 -0
  551. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_stats.py +581 -0
  552. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_typedef.py +447 -0
  553. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_utils.py +301 -0
  554. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_window.py +465 -0
  555. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/__init__.py +18 -0
  556. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/typehints.py +874 -0
  557. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/__init__.py +143 -0
  558. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/usage_logger.py +132 -0
  559. snowflake/snowpark_connect/includes/python/pyspark/pandas/utils.py +1063 -0
  560. snowflake/snowpark_connect/includes/python/pyspark/pandas/window.py +2702 -0
  561. snowflake/snowpark_connect/includes/python/pyspark/profiler.py +489 -0
  562. snowflake/snowpark_connect/includes/python/pyspark/py.typed +1 -0
  563. snowflake/snowpark_connect/includes/python/pyspark/python/pyspark/shell.py +123 -0
  564. snowflake/snowpark_connect/includes/python/pyspark/rdd.py +5518 -0
  565. snowflake/snowpark_connect/includes/python/pyspark/rddsampler.py +115 -0
  566. snowflake/snowpark_connect/includes/python/pyspark/resource/__init__.py +38 -0
  567. snowflake/snowpark_connect/includes/python/pyspark/resource/information.py +69 -0
  568. snowflake/snowpark_connect/includes/python/pyspark/resource/profile.py +317 -0
  569. snowflake/snowpark_connect/includes/python/pyspark/resource/requests.py +539 -0
  570. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/__init__.py +16 -0
  571. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/test_resources.py +83 -0
  572. snowflake/snowpark_connect/includes/python/pyspark/resultiterable.py +45 -0
  573. snowflake/snowpark_connect/includes/python/pyspark/serializers.py +681 -0
  574. snowflake/snowpark_connect/includes/python/pyspark/shell.py +123 -0
  575. snowflake/snowpark_connect/includes/python/pyspark/shuffle.py +854 -0
  576. snowflake/snowpark_connect/includes/python/pyspark/sql/__init__.py +75 -0
  577. snowflake/snowpark_connect/includes/python/pyspark/sql/_typing.pyi +80 -0
  578. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/__init__.py +18 -0
  579. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/functions.py +188 -0
  580. snowflake/snowpark_connect/includes/python/pyspark/sql/catalog.py +1270 -0
  581. snowflake/snowpark_connect/includes/python/pyspark/sql/column.py +1431 -0
  582. snowflake/snowpark_connect/includes/python/pyspark/sql/conf.py +99 -0
  583. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/__init__.py +18 -0
  584. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/_typing.py +90 -0
  585. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/__init__.py +18 -0
  586. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/functions.py +107 -0
  587. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/catalog.py +356 -0
  588. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/__init__.py +22 -0
  589. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/artifact.py +412 -0
  590. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/core.py +1689 -0
  591. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/reattach.py +340 -0
  592. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/column.py +514 -0
  593. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conf.py +128 -0
  594. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conversion.py +490 -0
  595. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/dataframe.py +2172 -0
  596. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/expressions.py +1056 -0
  597. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/functions.py +3937 -0
  598. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/group.py +418 -0
  599. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/plan.py +2289 -0
  600. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/__init__.py +25 -0
  601. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.py +203 -0
  602. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.pyi +2718 -0
  603. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2_grpc.py +423 -0
  604. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.py +109 -0
  605. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.pyi +1130 -0
  606. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.py +141 -0
  607. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.pyi +1766 -0
  608. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.py +47 -0
  609. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.pyi +123 -0
  610. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.py +53 -0
  611. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.pyi +112 -0
  612. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.py +107 -0
  613. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.pyi +1507 -0
  614. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.py +195 -0
  615. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.pyi +3613 -0
  616. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.py +95 -0
  617. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.pyi +980 -0
  618. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/__init__.py +18 -0
  619. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/functions.py +166 -0
  620. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/readwriter.py +861 -0
  621. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/session.py +952 -0
  622. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/__init__.py +22 -0
  623. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/query.py +295 -0
  624. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/readwriter.py +618 -0
  625. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/__init__.py +18 -0
  626. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +87 -0
  627. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +100 -0
  628. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/types.py +301 -0
  629. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udf.py +296 -0
  630. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udtf.py +200 -0
  631. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/utils.py +58 -0
  632. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/window.py +266 -0
  633. snowflake/snowpark_connect/includes/python/pyspark/sql/context.py +818 -0
  634. snowflake/snowpark_connect/includes/python/pyspark/sql/dataframe.py +5973 -0
  635. snowflake/snowpark_connect/includes/python/pyspark/sql/functions.py +15889 -0
  636. snowflake/snowpark_connect/includes/python/pyspark/sql/group.py +547 -0
  637. snowflake/snowpark_connect/includes/python/pyspark/sql/observation.py +152 -0
  638. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/__init__.py +21 -0
  639. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/__init__.pyi +344 -0
  640. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/__init__.pyi +17 -0
  641. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/frame.pyi +20 -0
  642. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/series.pyi +20 -0
  643. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/conversion.py +671 -0
  644. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.py +480 -0
  645. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.pyi +132 -0
  646. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/group_ops.py +523 -0
  647. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/map_ops.py +216 -0
  648. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/serializers.py +1019 -0
  649. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/typehints.py +172 -0
  650. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/types.py +972 -0
  651. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/utils.py +86 -0
  652. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/__init__.py +18 -0
  653. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/functions.py +334 -0
  654. snowflake/snowpark_connect/includes/python/pyspark/sql/readwriter.py +2159 -0
  655. snowflake/snowpark_connect/includes/python/pyspark/sql/session.py +2088 -0
  656. snowflake/snowpark_connect/includes/python/pyspark/sql/sql_formatter.py +84 -0
  657. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/__init__.py +21 -0
  658. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/listener.py +1050 -0
  659. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/query.py +746 -0
  660. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/readwriter.py +1652 -0
  661. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/state.py +288 -0
  662. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/__init__.py +16 -0
  663. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/__init__.py +16 -0
  664. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/__init__.py +16 -0
  665. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_artifact.py +420 -0
  666. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_client.py +358 -0
  667. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/__init__.py +16 -0
  668. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach.py +36 -0
  669. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +44 -0
  670. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +116 -0
  671. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +35 -0
  672. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py +3612 -0
  673. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_column.py +1042 -0
  674. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_function.py +2381 -0
  675. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_plan.py +1060 -0
  676. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow.py +163 -0
  677. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_map.py +38 -0
  678. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +48 -0
  679. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_catalog.py +36 -0
  680. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_column.py +55 -0
  681. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_conf.py +36 -0
  682. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_dataframe.py +96 -0
  683. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_datasources.py +44 -0
  684. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_errors.py +36 -0
  685. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_functions.py +59 -0
  686. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_group.py +36 -0
  687. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py +59 -0
  688. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py +74 -0
  689. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py +62 -0
  690. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_map.py +58 -0
  691. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +70 -0
  692. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py +50 -0
  693. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +68 -0
  694. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py +40 -0
  695. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_readwriter.py +46 -0
  696. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_serde.py +44 -0
  697. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_types.py +100 -0
  698. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udf.py +100 -0
  699. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udtf.py +163 -0
  700. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_session.py +181 -0
  701. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_utils.py +42 -0
  702. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/__init__.py +16 -0
  703. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +623 -0
  704. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +869 -0
  705. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +342 -0
  706. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_map.py +436 -0
  707. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf.py +363 -0
  708. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +592 -0
  709. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +1503 -0
  710. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +392 -0
  711. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py +375 -0
  712. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +411 -0
  713. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/__init__.py +16 -0
  714. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming.py +401 -0
  715. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +295 -0
  716. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +106 -0
  717. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_listener.py +558 -0
  718. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow.py +1346 -0
  719. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_map.py +182 -0
  720. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_python_udf.py +202 -0
  721. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_catalog.py +503 -0
  722. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_column.py +225 -0
  723. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_conf.py +83 -0
  724. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_context.py +201 -0
  725. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_dataframe.py +1931 -0
  726. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_datasources.py +256 -0
  727. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_errors.py +69 -0
  728. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_functions.py +1349 -0
  729. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_group.py +53 -0
  730. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_pandas_sqlmetrics.py +68 -0
  731. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_readwriter.py +283 -0
  732. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_serde.py +155 -0
  733. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_session.py +412 -0
  734. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_types.py +1581 -0
  735. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf.py +961 -0
  736. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf_profiler.py +165 -0
  737. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udtf.py +1456 -0
  738. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_utils.py +1686 -0
  739. snowflake/snowpark_connect/includes/python/pyspark/sql/types.py +2558 -0
  740. snowflake/snowpark_connect/includes/python/pyspark/sql/udf.py +714 -0
  741. snowflake/snowpark_connect/includes/python/pyspark/sql/udtf.py +325 -0
  742. snowflake/snowpark_connect/includes/python/pyspark/sql/utils.py +339 -0
  743. snowflake/snowpark_connect/includes/python/pyspark/sql/window.py +492 -0
  744. snowflake/snowpark_connect/includes/python/pyspark/statcounter.py +165 -0
  745. snowflake/snowpark_connect/includes/python/pyspark/status.py +112 -0
  746. snowflake/snowpark_connect/includes/python/pyspark/storagelevel.py +97 -0
  747. snowflake/snowpark_connect/includes/python/pyspark/streaming/__init__.py +22 -0
  748. snowflake/snowpark_connect/includes/python/pyspark/streaming/context.py +471 -0
  749. snowflake/snowpark_connect/includes/python/pyspark/streaming/dstream.py +933 -0
  750. snowflake/snowpark_connect/includes/python/pyspark/streaming/kinesis.py +205 -0
  751. snowflake/snowpark_connect/includes/python/pyspark/streaming/listener.py +83 -0
  752. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/__init__.py +16 -0
  753. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_context.py +184 -0
  754. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_dstream.py +706 -0
  755. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_kinesis.py +118 -0
  756. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_listener.py +160 -0
  757. snowflake/snowpark_connect/includes/python/pyspark/streaming/util.py +168 -0
  758. snowflake/snowpark_connect/includes/python/pyspark/taskcontext.py +502 -0
  759. snowflake/snowpark_connect/includes/python/pyspark/testing/__init__.py +21 -0
  760. snowflake/snowpark_connect/includes/python/pyspark/testing/connectutils.py +199 -0
  761. snowflake/snowpark_connect/includes/python/pyspark/testing/mllibutils.py +30 -0
  762. snowflake/snowpark_connect/includes/python/pyspark/testing/mlutils.py +275 -0
  763. snowflake/snowpark_connect/includes/python/pyspark/testing/objects.py +121 -0
  764. snowflake/snowpark_connect/includes/python/pyspark/testing/pandasutils.py +714 -0
  765. snowflake/snowpark_connect/includes/python/pyspark/testing/sqlutils.py +168 -0
  766. snowflake/snowpark_connect/includes/python/pyspark/testing/streamingutils.py +178 -0
  767. snowflake/snowpark_connect/includes/python/pyspark/testing/utils.py +636 -0
  768. snowflake/snowpark_connect/includes/python/pyspark/tests/__init__.py +16 -0
  769. snowflake/snowpark_connect/includes/python/pyspark/tests/test_appsubmit.py +306 -0
  770. snowflake/snowpark_connect/includes/python/pyspark/tests/test_broadcast.py +196 -0
  771. snowflake/snowpark_connect/includes/python/pyspark/tests/test_conf.py +44 -0
  772. snowflake/snowpark_connect/includes/python/pyspark/tests/test_context.py +346 -0
  773. snowflake/snowpark_connect/includes/python/pyspark/tests/test_daemon.py +89 -0
  774. snowflake/snowpark_connect/includes/python/pyspark/tests/test_install_spark.py +124 -0
  775. snowflake/snowpark_connect/includes/python/pyspark/tests/test_join.py +69 -0
  776. snowflake/snowpark_connect/includes/python/pyspark/tests/test_memory_profiler.py +167 -0
  777. snowflake/snowpark_connect/includes/python/pyspark/tests/test_pin_thread.py +194 -0
  778. snowflake/snowpark_connect/includes/python/pyspark/tests/test_profiler.py +168 -0
  779. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rdd.py +939 -0
  780. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddbarrier.py +52 -0
  781. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddsampler.py +66 -0
  782. snowflake/snowpark_connect/includes/python/pyspark/tests/test_readwrite.py +368 -0
  783. snowflake/snowpark_connect/includes/python/pyspark/tests/test_serializers.py +257 -0
  784. snowflake/snowpark_connect/includes/python/pyspark/tests/test_shuffle.py +267 -0
  785. snowflake/snowpark_connect/includes/python/pyspark/tests/test_stage_sched.py +153 -0
  786. snowflake/snowpark_connect/includes/python/pyspark/tests/test_statcounter.py +130 -0
  787. snowflake/snowpark_connect/includes/python/pyspark/tests/test_taskcontext.py +350 -0
  788. snowflake/snowpark_connect/includes/python/pyspark/tests/test_util.py +97 -0
  789. snowflake/snowpark_connect/includes/python/pyspark/tests/test_worker.py +271 -0
  790. snowflake/snowpark_connect/includes/python/pyspark/traceback_utils.py +81 -0
  791. snowflake/snowpark_connect/includes/python/pyspark/util.py +416 -0
  792. snowflake/snowpark_connect/includes/python/pyspark/version.py +19 -0
  793. snowflake/snowpark_connect/includes/python/pyspark/worker.py +1307 -0
  794. snowflake/snowpark_connect/includes/python/pyspark/worker_util.py +46 -0
  795. snowflake/snowpark_connect/proto/__init__.py +10 -0
  796. snowflake/snowpark_connect/proto/control_pb2.py +35 -0
  797. snowflake/snowpark_connect/proto/control_pb2.pyi +38 -0
  798. snowflake/snowpark_connect/proto/control_pb2_grpc.py +183 -0
  799. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +35 -0
  800. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +53 -0
  801. snowflake/snowpark_connect/proto/snowflake_rdd_pb2.pyi +39 -0
  802. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +47 -0
  803. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +111 -0
  804. snowflake/snowpark_connect/relation/__init__.py +3 -0
  805. snowflake/snowpark_connect/relation/catalogs/__init__.py +12 -0
  806. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +287 -0
  807. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +467 -0
  808. snowflake/snowpark_connect/relation/catalogs/utils.py +51 -0
  809. snowflake/snowpark_connect/relation/io_utils.py +76 -0
  810. snowflake/snowpark_connect/relation/map_aggregate.py +322 -0
  811. snowflake/snowpark_connect/relation/map_catalog.py +151 -0
  812. snowflake/snowpark_connect/relation/map_column_ops.py +1068 -0
  813. snowflake/snowpark_connect/relation/map_crosstab.py +48 -0
  814. snowflake/snowpark_connect/relation/map_extension.py +412 -0
  815. snowflake/snowpark_connect/relation/map_join.py +341 -0
  816. snowflake/snowpark_connect/relation/map_local_relation.py +326 -0
  817. snowflake/snowpark_connect/relation/map_map_partitions.py +146 -0
  818. snowflake/snowpark_connect/relation/map_relation.py +253 -0
  819. snowflake/snowpark_connect/relation/map_row_ops.py +716 -0
  820. snowflake/snowpark_connect/relation/map_sample_by.py +35 -0
  821. snowflake/snowpark_connect/relation/map_show_string.py +50 -0
  822. snowflake/snowpark_connect/relation/map_sql.py +1874 -0
  823. snowflake/snowpark_connect/relation/map_stats.py +324 -0
  824. snowflake/snowpark_connect/relation/map_subquery_alias.py +32 -0
  825. snowflake/snowpark_connect/relation/map_udtf.py +288 -0
  826. snowflake/snowpark_connect/relation/read/__init__.py +7 -0
  827. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +668 -0
  828. snowflake/snowpark_connect/relation/read/map_read.py +367 -0
  829. snowflake/snowpark_connect/relation/read/map_read_csv.py +142 -0
  830. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +108 -0
  831. snowflake/snowpark_connect/relation/read/map_read_json.py +344 -0
  832. snowflake/snowpark_connect/relation/read/map_read_parquet.py +194 -0
  833. snowflake/snowpark_connect/relation/read/map_read_socket.py +59 -0
  834. snowflake/snowpark_connect/relation/read/map_read_table.py +109 -0
  835. snowflake/snowpark_connect/relation/read/map_read_text.py +106 -0
  836. snowflake/snowpark_connect/relation/read/reader_config.py +399 -0
  837. snowflake/snowpark_connect/relation/read/utils.py +155 -0
  838. snowflake/snowpark_connect/relation/stage_locator.py +161 -0
  839. snowflake/snowpark_connect/relation/utils.py +219 -0
  840. snowflake/snowpark_connect/relation/write/__init__.py +3 -0
  841. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +339 -0
  842. snowflake/snowpark_connect/relation/write/map_write.py +436 -0
  843. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +48 -0
  844. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  845. snowflake/snowpark_connect/resources_initializer.py +75 -0
  846. snowflake/snowpark_connect/server.py +1136 -0
  847. snowflake/snowpark_connect/start_server.py +32 -0
  848. snowflake/snowpark_connect/tcm.py +8 -0
  849. snowflake/snowpark_connect/type_mapping.py +1003 -0
  850. snowflake/snowpark_connect/typed_column.py +94 -0
  851. snowflake/snowpark_connect/utils/__init__.py +3 -0
  852. snowflake/snowpark_connect/utils/artifacts.py +48 -0
  853. snowflake/snowpark_connect/utils/attribute_handling.py +72 -0
  854. snowflake/snowpark_connect/utils/cache.py +84 -0
  855. snowflake/snowpark_connect/utils/concurrent.py +124 -0
  856. snowflake/snowpark_connect/utils/context.py +390 -0
  857. snowflake/snowpark_connect/utils/describe_query_cache.py +231 -0
  858. snowflake/snowpark_connect/utils/interrupt.py +85 -0
  859. snowflake/snowpark_connect/utils/io_utils.py +35 -0
  860. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +117 -0
  861. snowflake/snowpark_connect/utils/profiling.py +47 -0
  862. snowflake/snowpark_connect/utils/session.py +180 -0
  863. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +38 -0
  864. snowflake/snowpark_connect/utils/telemetry.py +513 -0
  865. snowflake/snowpark_connect/utils/udf_cache.py +392 -0
  866. snowflake/snowpark_connect/utils/udf_helper.py +328 -0
  867. snowflake/snowpark_connect/utils/udf_utils.py +310 -0
  868. snowflake/snowpark_connect/utils/udtf_helper.py +420 -0
  869. snowflake/snowpark_connect/utils/udtf_utils.py +799 -0
  870. snowflake/snowpark_connect/utils/xxhash64.py +247 -0
  871. snowflake/snowpark_connect/version.py +6 -0
  872. snowpark_connect-0.20.2.data/scripts/snowpark-connect +71 -0
  873. snowpark_connect-0.20.2.data/scripts/snowpark-session +11 -0
  874. snowpark_connect-0.20.2.data/scripts/snowpark-submit +354 -0
  875. snowpark_connect-0.20.2.dist-info/METADATA +37 -0
  876. snowpark_connect-0.20.2.dist-info/RECORD +879 -0
  877. snowpark_connect-0.20.2.dist-info/WHEEL +5 -0
  878. snowpark_connect-0.20.2.dist-info/licenses/LICENSE.txt +202 -0
  879. snowpark_connect-0.20.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,3335 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one or more
3
+ # contributor license agreements. See the NOTICE file distributed with
4
+ # this work for additional information regarding copyright ownership.
5
+ # The ASF licenses this file to You under the Apache License, Version 2.0
6
+ # (the "License"); you may not use this file except in compliance with
7
+ # the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ import sys
19
+
20
+ from typing import Any, Dict, Generic, List, Optional, TypeVar, TYPE_CHECKING
21
+
22
+ from abc import ABCMeta
23
+
24
+ from pyspark import keyword_only, since
25
+ from pyspark.ml import Predictor, PredictionModel
26
+ from pyspark.ml.base import _PredictorParams
27
+ from pyspark.ml.param.shared import (
28
+ HasFeaturesCol,
29
+ HasLabelCol,
30
+ HasPredictionCol,
31
+ HasWeightCol,
32
+ Param,
33
+ Params,
34
+ TypeConverters,
35
+ HasMaxIter,
36
+ HasTol,
37
+ HasFitIntercept,
38
+ HasAggregationDepth,
39
+ HasMaxBlockSizeInMB,
40
+ HasRegParam,
41
+ HasSolver,
42
+ HasStepSize,
43
+ HasSeed,
44
+ HasElasticNetParam,
45
+ HasStandardization,
46
+ HasLoss,
47
+ HasVarianceCol,
48
+ )
49
+ from pyspark.ml.tree import (
50
+ _DecisionTreeModel,
51
+ _DecisionTreeParams,
52
+ _TreeEnsembleModel,
53
+ _RandomForestParams,
54
+ _GBTParams,
55
+ _TreeRegressorParams,
56
+ )
57
+ from pyspark.ml.base import Transformer
58
+ from pyspark.ml.linalg import Vector, Matrix
59
+ from pyspark.ml.util import (
60
+ JavaMLWritable,
61
+ JavaMLReadable,
62
+ HasTrainingSummary,
63
+ GeneralJavaMLWritable,
64
+ )
65
+ from pyspark.ml.wrapper import (
66
+ JavaEstimator,
67
+ JavaModel,
68
+ JavaPredictor,
69
+ JavaPredictionModel,
70
+ JavaTransformer,
71
+ JavaWrapper,
72
+ )
73
+ from pyspark.ml.common import inherit_doc
74
+ from pyspark.sql import DataFrame
75
+
76
# py4j is only needed for annotations on JVM-wrapper hooks; import it lazily
# so module import does not require a running gateway.
if TYPE_CHECKING:
    from py4j.java_gateway import JavaObject

# Type variables for the regressor/model hierarchy defined below.
T = TypeVar("T")  # input type of a PredictionModel
M = TypeVar("M", bound=Transformer)  # model type produced by a Predictor
JM = TypeVar("JM", bound=JavaTransformer)  # JVM-backed model type


# Public API of this module.
__all__ = [
    "AFTSurvivalRegression",
    "AFTSurvivalRegressionModel",
    "DecisionTreeRegressor",
    "DecisionTreeRegressionModel",
    "GBTRegressor",
    "GBTRegressionModel",
    "GeneralizedLinearRegression",
    "GeneralizedLinearRegressionModel",
    "GeneralizedLinearRegressionSummary",
    "GeneralizedLinearRegressionTrainingSummary",
    "IsotonicRegression",
    "IsotonicRegressionModel",
    "LinearRegression",
    "LinearRegressionModel",
    "LinearRegressionSummary",
    "LinearRegressionTrainingSummary",
    "RandomForestRegressor",
    "RandomForestRegressionModel",
    "FMRegressor",
    "FMRegressionModel",
]
106
+
107
+
108
class Regressor(Predictor[M], _PredictorParams, Generic[M], metaclass=ABCMeta):
    """
    Regressor for regression tasks.

    Abstract marker base: adds no behavior of its own, it only narrows
    :class:`Predictor` to estimators that produce a regression model of
    type ``M`` and carry the shared predictor params.

    .. versionadded:: 3.0.0
    """

    pass
116
+
117
+
118
class RegressionModel(PredictionModel[T], _PredictorParams, metaclass=ABCMeta):
    """
    Model produced by a ``Regressor``.

    Abstract marker base pairing :class:`Regressor`; contributes no behavior
    beyond the inherited prediction interface and shared predictor params.

    .. versionadded:: 3.0.0
    """

    pass
126
+
127
+
128
class _JavaRegressor(Regressor, JavaPredictor[JM], Generic[JM], metaclass=ABCMeta):
    """
    Java Regressor for regression tasks.

    Abstract marker base combining :class:`Regressor` with the JVM-backed
    predictor plumbing; concrete estimators in this module derive from it.

    .. versionadded:: 3.0.0
    """

    pass
136
+
137
+
138
class _JavaRegressionModel(RegressionModel, JavaPredictionModel[T], metaclass=ABCMeta):
    """
    Java Model produced by a ``_JavaRegressor``.
    To be mixed in with :class:`pyspark.ml.JavaModel`

    Abstract marker base; concrete models in this module derive from it to
    obtain the JVM-delegation behavior of :class:`JavaPredictionModel`.

    .. versionadded:: 3.0.0
    """

    pass
147
+
148
+
149
class _LinearRegressionParams(
    _PredictorParams,
    HasRegParam,
    HasElasticNetParam,
    HasMaxIter,
    HasTol,
    HasFitIntercept,
    HasStandardization,
    HasWeightCol,
    HasSolver,
    HasAggregationDepth,
    HasLoss,
    HasMaxBlockSizeInMB,
):
    """
    Params for :py:class:`LinearRegression` and :py:class:`LinearRegressionModel`.

    .. versionadded:: 3.0.0
    """

    # Re-declared (rather than inherited from HasSolver) so the description
    # can list the solvers this estimator actually supports.
    solver: Param[str] = Param(
        Params._dummy(),
        "solver",
        "The solver algorithm for optimization. Supported " + "options: auto, normal, l-bfgs.",
        typeConverter=TypeConverters.toString,
    )

    # Re-declared from HasLoss with this estimator's supported loss functions.
    loss: Param[str] = Param(
        Params._dummy(),
        "loss",
        "The loss function to be optimized. Supported " + "options: squaredError, huber.",
        typeConverter=TypeConverters.toString,
    )

    # Robustness shape parameter used only by the huber loss.
    epsilon: Param[float] = Param(
        Params._dummy(),
        "epsilon",
        "The shape parameter to control the amount of "
        + "robustness. Must be > 1.0. Only valid when loss is huber",
        typeConverter=TypeConverters.toFloat,
    )

    def __init__(self, *args: Any):
        super(_LinearRegressionParams, self).__init__(*args)
        # Defaults for the params declared or overridden by this mixin.
        # NOTE(review): presumably mirrors the JVM-side defaults — confirm.
        self._setDefault(
            maxIter=100,
            regParam=0.0,
            tol=1e-6,
            loss="squaredError",
            epsilon=1.35,
            maxBlockSizeInMB=0.0,
        )

    @since("2.3.0")
    def getEpsilon(self) -> float:
        """
        Gets the value of epsilon or its default value.
        """
        return self.getOrDefault(self.epsilon)
208
+
209
+
210
@inherit_doc
class LinearRegression(
    _JavaRegressor["LinearRegressionModel"],
    _LinearRegressionParams,
    JavaMLWritable,
    JavaMLReadable["LinearRegression"],
):
    """
    Linear regression.

    The learning objective is to minimize the specified loss function, with regularization.
    This supports two kinds of loss:

    * squaredError (a.k.a squared loss)
    * huber (a hybrid of squared error for relatively small errors and absolute error for \
    relatively large ones, and we estimate the scale parameter from training data)

    This supports multiple types of regularization:

    * none (a.k.a. ordinary least squares)
    * L2 (ridge regression)
    * L1 (Lasso)
    * L2 + L1 (elastic net)

    .. versionadded:: 1.4.0

    Notes
    -----
    Fitting with huber loss only supports none and L2 regularization.

    Examples
    --------
    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, 2.0, Vectors.dense(1.0)),
    ...     (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
    >>> lr = LinearRegression(regParam=0.0, solver="normal", weightCol="weight")
    >>> lr.setMaxIter(5)
    LinearRegression...
    >>> lr.getMaxIter()
    5
    >>> lr.setRegParam(0.1)
    LinearRegression...
    >>> lr.getRegParam()
    0.1
    >>> lr.setRegParam(0.0)
    LinearRegression...
    >>> model = lr.fit(df)
    >>> model.setFeaturesCol("features")
    LinearRegressionModel...
    >>> model.setPredictionCol("newPrediction")
    LinearRegressionModel...
    >>> model.getMaxIter()
    5
    >>> model.getMaxBlockSizeInMB()
    0.0
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> abs(model.predict(test0.head().features) - (-1.0)) < 0.001
    True
    >>> abs(model.transform(test0).head().newPrediction - (-1.0)) < 0.001
    True
    >>> abs(model.coefficients[0] - 1.0) < 0.001
    True
    >>> abs(model.intercept - 0.0) < 0.001
    True
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> abs(model.transform(test1).head().newPrediction - 1.0) < 0.001
    True
    >>> lr.setParams(featuresCol="vector")
    LinearRegression...
    >>> lr_path = temp_path + "/lr"
    >>> lr.save(lr_path)
    >>> lr2 = LinearRegression.load(lr_path)
    >>> lr2.getMaxIter()
    5
    >>> model_path = temp_path + "/lr_model"
    >>> model.save(model_path)
    >>> model2 = LinearRegressionModel.load(model_path)
    >>> model.coefficients[0] == model2.coefficients[0]
    True
    >>> model.intercept == model2.intercept
    True
    >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
    True
    >>> model.numFeatures
    1
    >>> model.write().format("pmml").save(model_path + "_2")
    """

    # Populated by the @keyword_only decorator with the kwargs of the call.
    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        featuresCol: str = "features",
        labelCol: str = "label",
        predictionCol: str = "prediction",
        maxIter: int = 100,
        regParam: float = 0.0,
        elasticNetParam: float = 0.0,
        tol: float = 1e-6,
        fitIntercept: bool = True,
        standardization: bool = True,
        solver: str = "auto",
        weightCol: Optional[str] = None,
        aggregationDepth: int = 2,
        loss: str = "squaredError",
        epsilon: float = 1.35,
        maxBlockSizeInMB: float = 0.0,
    ):
        """
        __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                 standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
                 loss="squaredError", epsilon=1.35, maxBlockSizeInMB=0.0)
        """
        super(LinearRegression, self).__init__()
        # Instantiate the JVM-side estimator this Python wrapper delegates to.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.LinearRegression", self.uid
        )
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(
        self,
        *,
        featuresCol: str = "features",
        labelCol: str = "label",
        predictionCol: str = "prediction",
        maxIter: int = 100,
        regParam: float = 0.0,
        elasticNetParam: float = 0.0,
        tol: float = 1e-6,
        fitIntercept: bool = True,
        standardization: bool = True,
        solver: str = "auto",
        weightCol: Optional[str] = None,
        aggregationDepth: int = 2,
        loss: str = "squaredError",
        epsilon: float = 1.35,
        maxBlockSizeInMB: float = 0.0,
    ) -> "LinearRegression":
        """
        setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                  standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
                  loss="squaredError", epsilon=1.35, maxBlockSizeInMB=0.0)
        Sets params for linear regression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model: "JavaObject") -> "LinearRegressionModel":
        # Wrap the fitted JVM model in its Python counterpart.
        return LinearRegressionModel(java_model)

    @since("2.3.0")
    def setEpsilon(self, value: float) -> "LinearRegression":
        """
        Sets the value of :py:attr:`epsilon`.
        """
        return self._set(epsilon=value)

    def setMaxIter(self, value: int) -> "LinearRegression":
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    def setRegParam(self, value: float) -> "LinearRegression":
        """
        Sets the value of :py:attr:`regParam`.
        """
        return self._set(regParam=value)

    def setTol(self, value: float) -> "LinearRegression":
        """
        Sets the value of :py:attr:`tol`.
        """
        return self._set(tol=value)

    def setElasticNetParam(self, value: float) -> "LinearRegression":
        """
        Sets the value of :py:attr:`elasticNetParam`.
        """
        return self._set(elasticNetParam=value)

    def setFitIntercept(self, value: bool) -> "LinearRegression":
        """
        Sets the value of :py:attr:`fitIntercept`.
        """
        return self._set(fitIntercept=value)

    def setStandardization(self, value: bool) -> "LinearRegression":
        """
        Sets the value of :py:attr:`standardization`.
        """
        return self._set(standardization=value)

    def setWeightCol(self, value: str) -> "LinearRegression":
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    def setSolver(self, value: str) -> "LinearRegression":
        """
        Sets the value of :py:attr:`solver`.
        """
        return self._set(solver=value)

    def setAggregationDepth(self, value: int) -> "LinearRegression":
        """
        Sets the value of :py:attr:`aggregationDepth`.
        """
        return self._set(aggregationDepth=value)

    def setLoss(self, value: str) -> "LinearRegression":
        """
        Sets the value of :py:attr:`loss`.
        """
        # Bug fix: this previously did self._set(lossType=value). This class
        # declares no ``lossType`` param (that name belongs to the tree-ensemble
        # params), so Params._set raised instead of setting ``loss``.
        return self._set(loss=value)

    @since("3.1.0")
    def setMaxBlockSizeInMB(self, value: float) -> "LinearRegression":
        """
        Sets the value of :py:attr:`maxBlockSizeInMB`.
        """
        return self._set(maxBlockSizeInMB=value)
441
+
442
+
443
class LinearRegressionModel(
    _JavaRegressionModel,
    _LinearRegressionParams,
    GeneralJavaMLWritable,
    JavaMLReadable["LinearRegressionModel"],
    HasTrainingSummary["LinearRegressionSummary"],
):
    """
    Model produced by fitting a :class:`LinearRegression` estimator.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def coefficients(self) -> Vector:
        """Weight vector learned for the features."""
        return self._call_java("coefficients")

    @property
    @since("1.4.0")
    def intercept(self) -> float:
        """Intercept (bias) term of the fitted model."""
        return self._call_java("intercept")

    @property
    @since("2.3.0")
    def scale(self) -> float:
        r"""
        The value by which :math:`\|y - X'w\|` is scaled down when the loss
        is "huber"; 1.0 for any other loss.
        """
        return self._call_java("scale")

    @property
    @since("2.0.0")
    def summary(self) -> "LinearRegressionTrainingSummary":
        """
        Summary (residuals, MSE, r-squared, ...) of the model on the
        training set.

        Raises
        ------
        RuntimeError
            If no training summary was recorded (`trainingSummary is None`).
        """
        # Guard clause: fail fast when the JVM side recorded no summary.
        if not self.hasSummary:
            raise RuntimeError(
                "No training summary available for this %s" % self.__class__.__name__
            )
        return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)

    def evaluate(self, dataset: DataFrame) -> "LinearRegressionSummary":
        """
        Evaluate the model on a test dataset.

        .. versionadded:: 2.0.0

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            Test dataset to evaluate the model on.
        """
        if not isinstance(dataset, DataFrame):
            raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
        # Delegate the evaluation to the JVM model and wrap its summary object.
        return LinearRegressionSummary(self._call_java("evaluate", dataset))
511
+
512
+
513
class LinearRegressionSummary(JavaWrapper):
    """
    Linear regression results evaluated on a dataset.

    All metrics are computed on the JVM side; each property simply
    delegates to the wrapped Java summary object.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def predictions(self) -> DataFrame:
        """DataFrame produced by the model's `transform` method."""
        return self._call_java("predictions")

    @property
    @since("2.0.0")
    def predictionCol(self) -> str:
        """
        Name of the field in "predictions" that holds the predicted value
        of the label for each instance.
        """
        return self._call_java("predictionCol")

    @property
    @since("2.0.0")
    def labelCol(self) -> str:
        """
        Name of the field in "predictions" that holds the true label of
        each instance.
        """
        return self._call_java("labelCol")

    @property
    @since("2.0.0")
    def featuresCol(self) -> str:
        """
        Name of the field in "predictions" that holds each instance's
        features as a vector.
        """
        return self._call_java("featuresCol")

    @property
    @since("2.0.0")
    def explainedVariance(self) -> float:
        r"""
        Explained variance regression score,
        explainedVariance = :math:`1 - \frac{variance(y - \hat{y})}{variance(y)}`.

        Notes
        -----
        Instance weights from `LinearRegression.weightCol` are ignored
        (all treated as 1.0). This will change in later Spark versions.

        For additional information see
        `Explained variation on Wikipedia \
        <http://en.wikipedia.org/wiki/Explained_variation>`_
        """
        return self._call_java("explainedVariance")

    @property
    @since("2.0.0")
    def meanAbsoluteError(self) -> float:
        """
        Mean absolute error: the expected value of the absolute error
        (l1-norm) loss.

        Notes
        -----
        Instance weights from `LinearRegression.weightCol` are ignored
        (all treated as 1.0). This will change in later Spark versions.
        """
        return self._call_java("meanAbsoluteError")

    @property
    @since("2.0.0")
    def meanSquaredError(self) -> float:
        """
        Mean squared error: the expected value of the squared error
        (quadratic) loss.

        Notes
        -----
        Instance weights from `LinearRegression.weightCol` are ignored
        (all treated as 1.0). This will change in later Spark versions.
        """
        return self._call_java("meanSquaredError")

    @property
    @since("2.0.0")
    def rootMeanSquaredError(self) -> float:
        """
        Root mean squared error: the square root of the mean squared error.

        Notes
        -----
        Instance weights from `LinearRegression.weightCol` are ignored
        (all treated as 1.0). This will change in later Spark versions.
        """
        return self._call_java("rootMeanSquaredError")

    @property
    @since("2.0.0")
    def r2(self) -> float:
        """
        R^2, the coefficient of determination.

        Notes
        -----
        Instance weights from `LinearRegression.weightCol` are ignored
        (all treated as 1.0). This will change in later Spark versions.

        See also `Wikipedia coefficient of determination \
        <http://en.wikipedia.org/wiki/Coefficient_of_determination>`_
        """
        return self._call_java("r2")

    @property
    @since("2.4.0")
    def r2adj(self) -> float:
        """
        Adjusted R^2, the adjusted coefficient of determination.

        Notes
        -----
        Instance weights from `LinearRegression.weightCol` are ignored
        (all treated as 1.0). This will change in later Spark versions.

        `Wikipedia coefficient of determination, Adjusted R^2 \
        <https://en.wikipedia.org/wiki/Coefficient_of_determination#Adjusted_R2>`_
        """
        return self._call_java("r2adj")

    @property
    @since("2.0.0")
    def residuals(self) -> DataFrame:
        """Residuals, i.e. (label - predicted value)."""
        return self._call_java("residuals")

    @property
    @since("2.0.0")
    def numInstances(self) -> int:
        """Number of instances in the predictions DataFrame."""
        return self._call_java("numInstances")

    @property
    @since("2.2.0")
    def degreesOfFreedom(self) -> int:
        """Degrees of freedom."""
        return self._call_java("degreesOfFreedom")

    @property
    @since("2.0.0")
    def devianceResiduals(self) -> List[float]:
        """
        Weighted residuals: the usual residuals rescaled by the square root
        of the instance weights.
        """
        return self._call_java("devianceResiduals")

    @property
    def coefficientStandardErrors(self) -> List[float]:
        """
        Standard errors of the estimated coefficients and intercept.
        Only available when using the "normal" solver.

        If :py:attr:`LinearRegression.fitIntercept` is set to True, the
        last element returned corresponds to the intercept.

        .. versionadded:: 2.0.0

        See Also
        --------
        LinearRegression.solver
        """
        return self._call_java("coefficientStandardErrors")

    @property
    def tValues(self) -> List[float]:
        """
        T-statistics of the estimated coefficients and intercept.
        Only available when using the "normal" solver.

        If :py:attr:`LinearRegression.fitIntercept` is set to True, the
        last element returned corresponds to the intercept.

        .. versionadded:: 2.0.0

        See Also
        --------
        LinearRegression.solver
        """
        return self._call_java("tValues")

    @property
    def pValues(self) -> List[float]:
        """
        Two-sided p-values of the estimated coefficients and intercept.
        Only available when using the "normal" solver.

        If :py:attr:`LinearRegression.fitIntercept` is set to True, the
        last element returned corresponds to the intercept.

        .. versionadded:: 2.0.0

        See Also
        --------
        LinearRegression.solver
        """
        return self._call_java("pValues")
737
+
738
+
739
@inherit_doc
class LinearRegressionTrainingSummary(LinearRegressionSummary):
    """
    Linear regression training results. The training summary currently
    ignores the training weights except for the objective trace.

    .. versionadded:: 2.0.0
    """

    @property
    def objectiveHistory(self) -> List[float]:
        """
        Value of the objective function (scaled loss + regularization) at
        each iteration. Only available when using the "l-bfgs" solver.

        .. versionadded:: 2.0.0

        See Also
        --------
        LinearRegression.solver
        """
        return self._call_java("objectiveHistory")

    @property
    def totalIterations(self) -> int:
        """
        Number of training iterations until termination. Only available
        when using the "l-bfgs" solver.

        .. versionadded:: 2.0.0

        See Also
        --------
        LinearRegression.solver
        """
        return self._call_java("totalIterations")
776
+
777
+
778
class _IsotonicRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol):
    """
    Shared params for :py:class:`IsotonicRegression` and
    :py:class:`IsotonicRegressionModel`.

    .. versionadded:: 3.0.0
    """

    # Direction of the fitted sequence: increasing (True) or decreasing (False).
    isotonic: Param[bool] = Param(
        Params._dummy(),
        "isotonic",
        "whether the output sequence should be isotonic/increasing (true) or"
        + "antitonic/decreasing (false).",
        typeConverter=TypeConverters.toBoolean,
    )
    # Which component of a vector-valued featuresCol to regress on.
    featureIndex: Param[int] = Param(
        Params._dummy(),
        "featureIndex",
        "The index of the feature if featuresCol is a vector column, no effect otherwise.",
        typeConverter=TypeConverters.toInt,
    )

    def __init__(self, *args: Any):
        super(_IsotonicRegressionParams, self).__init__(*args)
        self._setDefault(isotonic=True, featureIndex=0)

    def getIsotonic(self) -> bool:
        """Return the value of ``isotonic``, or its default when unset."""
        return self.getOrDefault(self.isotonic)

    def getFeatureIndex(self) -> int:
        """Return the value of ``featureIndex``, or its default when unset."""
        return self.getOrDefault(self.featureIndex)
814
+
815
+
816
@inherit_doc
class IsotonicRegression(
    JavaEstimator, _IsotonicRegressionParams, HasWeightCol, JavaMLWritable, JavaMLReadable
):
    """
    Isotonic regression, currently implemented using a parallelized
    pool-adjacent-violators algorithm. Only the univariate (single feature)
    algorithm is supported.

    .. versionadded:: 1.6.0

    Examples
    --------
    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> ir = IsotonicRegression()
    >>> model = ir.fit(df)
    >>> model.setFeaturesCol("features")
    IsotonicRegressionModel...
    >>> model.numFeatures
    1
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    0.0
    >>> model.predict(test0.head().features[model.getFeatureIndex()])
    0.0
    >>> model.boundaries
    DenseVector([0.0, 1.0])
    >>> ir_path = temp_path + "/ir"
    >>> ir.save(ir_path)
    >>> ir2 = IsotonicRegression.load(ir_path)
    >>> ir2.getIsotonic()
    True
    >>> model_path = temp_path + "/ir_model"
    >>> model.save(model_path)
    >>> model2 = IsotonicRegressionModel.load(model_path)
    >>> model.boundaries == model2.boundaries
    True
    >>> model.predictions == model2.predictions
    True
    >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
    True
    """

    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        featuresCol: str = "features",
        labelCol: str = "label",
        predictionCol: str = "prediction",
        weightCol: Optional[str] = None,
        isotonic: bool = True,
        featureIndex: int = 0,
    ):
        """
        __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 weightCol=None, isotonic=True, featureIndex=0):
        """
        super(IsotonicRegression, self).__init__()
        # Create the JVM-side estimator this Python wrapper delegates to.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.IsotonicRegression", self.uid
        )
        # @keyword_only stashed the caller's keyword args in _input_kwargs.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(
        self,
        *,
        featuresCol: str = "features",
        labelCol: str = "label",
        predictionCol: str = "prediction",
        weightCol: Optional[str] = None,
        isotonic: bool = True,
        featureIndex: int = 0,
    ) -> "IsotonicRegression":
        """
        setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  weightCol=None, isotonic=True, featureIndex=0):
        Set the params for IsotonicRegression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model: "JavaObject") -> "IsotonicRegressionModel":
        """Wrap the fitted JVM model in its Python counterpart."""
        return IsotonicRegressionModel(java_model)

    def setIsotonic(self, value: bool) -> "IsotonicRegression":
        """Set :py:attr:`isotonic` and return this estimator for chaining."""
        self._set(isotonic=value)
        return self

    def setFeatureIndex(self, value: int) -> "IsotonicRegression":
        """Set :py:attr:`featureIndex` and return this estimator for chaining."""
        self._set(featureIndex=value)
        return self

    @since("1.6.0")
    def setFeaturesCol(self, value: str) -> "IsotonicRegression":
        """Set :py:attr:`featuresCol` and return this estimator for chaining."""
        self._set(featuresCol=value)
        return self

    @since("1.6.0")
    def setPredictionCol(self, value: str) -> "IsotonicRegression":
        """Set :py:attr:`predictionCol` and return this estimator for chaining."""
        self._set(predictionCol=value)
        return self

    @since("1.6.0")
    def setLabelCol(self, value: str) -> "IsotonicRegression":
        """Set :py:attr:`labelCol` and return this estimator for chaining."""
        self._set(labelCol=value)
        return self

    @since("1.6.0")
    def setWeightCol(self, value: str) -> "IsotonicRegression":
        """Set :py:attr:`weightCol` and return this estimator for chaining."""
        self._set(weightCol=value)
        return self
946
+
947
+
948
class IsotonicRegressionModel(
    JavaModel,
    _IsotonicRegressionParams,
    JavaMLWritable,
    JavaMLReadable["IsotonicRegressionModel"],
):
    """
    Model produced by fitting an :class:`IsotonicRegression` estimator.

    .. versionadded:: 1.6.0
    """

    @since("3.0.0")
    def setFeaturesCol(self, value: str) -> "IsotonicRegressionModel":
        """Set :py:attr:`featuresCol` and return this model for chaining."""
        self._set(featuresCol=value)
        return self

    @since("3.0.0")
    def setPredictionCol(self, value: str) -> "IsotonicRegressionModel":
        """Set :py:attr:`predictionCol` and return this model for chaining."""
        self._set(predictionCol=value)
        return self

    def setFeatureIndex(self, value: int) -> "IsotonicRegressionModel":
        """Set :py:attr:`featureIndex` and return this model for chaining."""
        self._set(featureIndex=value)
        return self

    @property
    @since("1.6.0")
    def boundaries(self) -> Vector:
        """Boundaries, in increasing order, for which predictions are known."""
        return self._call_java("boundaries")

    @property
    @since("1.6.0")
    def predictions(self) -> Vector:
        """
        Predictions associated with the boundaries at the same index;
        monotone as a consequence of isotonic regression.
        """
        return self._call_java("predictions")

    @property
    @since("3.0.0")
    def numFeatures(self) -> int:
        """Number of features the model was trained on, or -1 if unknown."""
        return self._call_java("numFeatures")

    @since("3.0.0")
    def predict(self, value: float) -> float:
        """Predict the label for a single feature value."""
        return self._call_java("predict", value)
1011
+
1012
+
1013
class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, HasVarianceCol):
    """
    Shared params for :py:class:`DecisionTreeRegressor` and
    :py:class:`DecisionTreeRegressionModel`.

    .. versionadded:: 3.0.0
    """

    def __init__(self, *args: Any):
        super(_DecisionTreeRegressorParams, self).__init__(*args)
        # Defaults mirror the Scala-side DecisionTreeRegressor defaults.
        self._setDefault(
            maxDepth=5,
            maxBins=32,
            minInstancesPerNode=1,
            minInfoGain=0.0,
            maxMemoryInMB=256,
            cacheNodeIds=False,
            checkpointInterval=10,
            impurity="variance",
            leafCol="",
            minWeightFractionPerNode=0.0,
        )
1034
+
1035
+
1036
@inherit_doc
class DecisionTreeRegressor(
    _JavaRegressor["DecisionTreeRegressionModel"],
    _DecisionTreeRegressorParams,
    JavaMLWritable,
    JavaMLReadable["DecisionTreeRegressor"],
):
    """
    `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
    learning algorithm for regression. Supports both continuous and
    categorical features.

    .. versionadded:: 1.4.0

    Examples
    --------
    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> dt = DecisionTreeRegressor(maxDepth=2)
    >>> dt.setVarianceCol("variance")
    DecisionTreeRegressor...
    >>> model = dt.fit(df)
    >>> model.getVarianceCol()
    'variance'
    >>> model.setLeafCol("leafId")
    DecisionTreeRegressionModel...
    >>> model.depth
    1
    >>> model.numNodes
    3
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> model.numFeatures
    1
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.predict(test0.head().features)
    0.0
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> model.predictLeaf(test0.head().features)
    0.0
    >>> result.leafId
    0.0
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> dtr_path = temp_path + "/dtr"
    >>> dt.save(dtr_path)
    >>> dt2 = DecisionTreeRegressor.load(dtr_path)
    >>> dt2.getMaxDepth()
    2
    >>> model_path = temp_path + "/dtr_model"
    >>> model.save(model_path)
    >>> model2 = DecisionTreeRegressionModel.load(model_path)
    >>> model.numNodes == model2.numNodes
    True
    >>> model.depth == model2.depth
    True
    >>> model.transform(test1).head().variance
    0.0
    >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
    True
    >>> df3 = spark.createDataFrame([
    ...     (1.0, 0.2, Vectors.dense(1.0)),
    ...     (1.0, 0.8, Vectors.dense(1.0)),
    ...     (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
    >>> dt3 = DecisionTreeRegressor(maxDepth=2, weightCol="weight", varianceCol="variance")
    >>> model3 = dt3.fit(df3)
    >>> print(model3.toDebugString)
    DecisionTreeRegressionModel...depth=1, numNodes=3...
    """

    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        featuresCol: str = "features",
        labelCol: str = "label",
        predictionCol: str = "prediction",
        maxDepth: int = 5,
        maxBins: int = 32,
        minInstancesPerNode: int = 1,
        minInfoGain: float = 0.0,
        maxMemoryInMB: int = 256,
        cacheNodeIds: bool = False,
        checkpointInterval: int = 10,
        impurity: str = "variance",
        seed: Optional[int] = None,
        varianceCol: Optional[str] = None,
        weightCol: Optional[str] = None,
        leafCol: str = "",
        minWeightFractionPerNode: float = 0.0,
    ):
        """
        __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                 impurity="variance", seed=None, varianceCol=None, weightCol=None, \
                 leafCol="", minWeightFractionPerNode=0.0)
        """
        super(DecisionTreeRegressor, self).__init__()
        # Create the JVM-side estimator this Python wrapper delegates to.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid
        )
        # @keyword_only stashed the caller's keyword args in _input_kwargs.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(
        self,
        *,
        featuresCol: str = "features",
        labelCol: str = "label",
        predictionCol: str = "prediction",
        maxDepth: int = 5,
        maxBins: int = 32,
        minInstancesPerNode: int = 1,
        minInfoGain: float = 0.0,
        maxMemoryInMB: int = 256,
        cacheNodeIds: bool = False,
        checkpointInterval: int = 10,
        impurity: str = "variance",
        seed: Optional[int] = None,
        varianceCol: Optional[str] = None,
        weightCol: Optional[str] = None,
        leafCol: str = "",
        minWeightFractionPerNode: float = 0.0,
    ) -> "DecisionTreeRegressor":
        """
        setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  impurity="variance", seed=None, varianceCol=None, weightCol=None, \
                  leafCol="", minWeightFractionPerNode=0.0)
        Sets params for the DecisionTreeRegressor.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model: "JavaObject") -> "DecisionTreeRegressionModel":
        """Wrap the fitted JVM model in its Python counterpart."""
        return DecisionTreeRegressionModel(java_model)

    @since("1.4.0")
    def setMaxDepth(self, value: int) -> "DecisionTreeRegressor":
        """Set :py:attr:`maxDepth` and return this estimator for chaining."""
        self._set(maxDepth=value)
        return self

    @since("1.4.0")
    def setMaxBins(self, value: int) -> "DecisionTreeRegressor":
        """Set :py:attr:`maxBins` and return this estimator for chaining."""
        self._set(maxBins=value)
        return self

    @since("1.4.0")
    def setMinInstancesPerNode(self, value: int) -> "DecisionTreeRegressor":
        """Set :py:attr:`minInstancesPerNode` and return this estimator for chaining."""
        self._set(minInstancesPerNode=value)
        return self

    @since("3.0.0")
    def setMinWeightFractionPerNode(self, value: float) -> "DecisionTreeRegressor":
        """Set :py:attr:`minWeightFractionPerNode` and return this estimator for chaining."""
        self._set(minWeightFractionPerNode=value)
        return self

    @since("1.4.0")
    def setMinInfoGain(self, value: float) -> "DecisionTreeRegressor":
        """Set :py:attr:`minInfoGain` and return this estimator for chaining."""
        self._set(minInfoGain=value)
        return self

    @since("1.4.0")
    def setMaxMemoryInMB(self, value: int) -> "DecisionTreeRegressor":
        """Set :py:attr:`maxMemoryInMB` and return this estimator for chaining."""
        self._set(maxMemoryInMB=value)
        return self

    @since("1.4.0")
    def setCacheNodeIds(self, value: bool) -> "DecisionTreeRegressor":
        """Set :py:attr:`cacheNodeIds` and return this estimator for chaining."""
        self._set(cacheNodeIds=value)
        return self

    @since("1.4.0")
    def setImpurity(self, value: str) -> "DecisionTreeRegressor":
        """Set :py:attr:`impurity` and return this estimator for chaining."""
        self._set(impurity=value)
        return self

    @since("1.4.0")
    def setCheckpointInterval(self, value: int) -> "DecisionTreeRegressor":
        """Set :py:attr:`checkpointInterval` and return this estimator for chaining."""
        self._set(checkpointInterval=value)
        return self

    def setSeed(self, value: int) -> "DecisionTreeRegressor":
        """Set :py:attr:`seed` and return this estimator for chaining."""
        self._set(seed=value)
        return self

    @since("3.0.0")
    def setWeightCol(self, value: str) -> "DecisionTreeRegressor":
        """Set :py:attr:`weightCol` and return this estimator for chaining."""
        self._set(weightCol=value)
        return self

    @since("2.0.0")
    def setVarianceCol(self, value: str) -> "DecisionTreeRegressor":
        """Set :py:attr:`varianceCol` and return this estimator for chaining."""
        self._set(varianceCol=value)
        return self
1266
+
1267
+
1268
@inherit_doc
class DecisionTreeRegressionModel(
    _JavaRegressionModel,
    _DecisionTreeModel,
    _DecisionTreeRegressorParams,
    JavaMLWritable,
    JavaMLReadable["DecisionTreeRegressionModel"],
):
    """
    Model produced by fitting a :class:`DecisionTreeRegressor`.

    .. versionadded:: 1.4.0
    """

    @since("3.0.0")
    def setVarianceCol(self, value: str) -> "DecisionTreeRegressionModel":
        """Set :py:attr:`varianceCol` and return this model for chaining."""
        self._set(varianceCol=value)
        return self

    @property
    def featureImportances(self) -> Vector:
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from "Random Forests"
        documentation by Leo Breiman and Adele Cutler, and following the
        implementation from scikit-learn.

        The importance is calculated as follows:
        - importance(feature j) = sum (over nodes which split on feature j)
          of the gain, where gain is scaled by the number of instances
          passing through the node
        - importances are then normalized to sum to 1 over the tree.

        .. versionadded:: 2.0.0

        Notes
        -----
        Feature importance for single decision trees can have high variance
        due to correlated predictor variables. Consider using a
        :py:class:`RandomForestRegressor` to determine feature importance
        instead.
        """
        return self._call_java("featureImportances")
1312
+
1313
+
1314
class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams):
    """
    Shared params for :py:class:`RandomForestRegressor` and
    :py:class:`RandomForestRegressionModel`.

    .. versionadded:: 3.0.0
    """

    def __init__(self, *args: Any):
        super(_RandomForestRegressorParams, self).__init__(*args)
        # Defaults mirror the Scala-side RandomForestRegressor defaults.
        self._setDefault(
            maxDepth=5,
            maxBins=32,
            minInstancesPerNode=1,
            minInfoGain=0.0,
            maxMemoryInMB=256,
            cacheNodeIds=False,
            checkpointInterval=10,
            impurity="variance",
            subsamplingRate=1.0,
            numTrees=20,
            featureSubsetStrategy="auto",
            leafCol="",
            minWeightFractionPerNode=0.0,
            bootstrap=True,
        )
1339
+
1340
+
1341
+ @inherit_doc
1342
+ class RandomForestRegressor(
1343
+ _JavaRegressor["RandomForestRegressionModel"],
1344
+ _RandomForestRegressorParams,
1345
+ JavaMLWritable,
1346
+ JavaMLReadable["RandomForestRegressor"],
1347
+ ):
1348
+ """
1349
+ `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
1350
+ learning algorithm for regression.
1351
+ It supports both continuous and categorical features.
1352
+
1353
+ .. versionadded:: 1.4.0
1354
+
1355
+ Examples
1356
+ --------
1357
+ >>> from numpy import allclose
1358
+ >>> from pyspark.ml.linalg import Vectors
1359
+ >>> df = spark.createDataFrame([
1360
+ ... (1.0, Vectors.dense(1.0)),
1361
+ ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
1362
+ >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2)
1363
+ >>> rf.getMinWeightFractionPerNode()
1364
+ 0.0
1365
+ >>> rf.setSeed(42)
1366
+ RandomForestRegressor...
1367
+ >>> model = rf.fit(df)
1368
+ >>> model.getBootstrap()
1369
+ True
1370
+ >>> model.getSeed()
1371
+ 42
1372
+ >>> model.setLeafCol("leafId")
1373
+ RandomForestRegressionModel...
1374
+ >>> model.featureImportances
1375
+ SparseVector(1, {0: 1.0})
1376
+ >>> allclose(model.treeWeights, [1.0, 1.0])
1377
+ True
1378
+ >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
1379
+ >>> model.predict(test0.head().features)
1380
+ 0.0
1381
+ >>> model.predictLeaf(test0.head().features)
1382
+ DenseVector([0.0, 0.0])
1383
+ >>> result = model.transform(test0).head()
1384
+ >>> result.prediction
1385
+ 0.0
1386
+ >>> result.leafId
1387
+ DenseVector([0.0, 0.0])
1388
+ >>> model.numFeatures
1389
+ 1
1390
+ >>> model.trees
1391
+ [DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
1392
+ >>> model.getNumTrees
1393
+ 2
1394
+ >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
1395
+ >>> model.transform(test1).head().prediction
1396
+ 0.5
1397
+ >>> rfr_path = temp_path + "/rfr"
1398
+ >>> rf.save(rfr_path)
1399
+ >>> rf2 = RandomForestRegressor.load(rfr_path)
1400
+ >>> rf2.getNumTrees()
1401
+ 2
1402
+ >>> model_path = temp_path + "/rfr_model"
1403
+ >>> model.save(model_path)
1404
+ >>> model2 = RandomForestRegressionModel.load(model_path)
1405
+ >>> model.featureImportances == model2.featureImportances
1406
+ True
1407
+ >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
1408
+ True
1409
+ """
1410
+
1411
+ _input_kwargs: Dict[str, Any]
1412
+
1413
    @keyword_only
    def __init__(
        self,
        *,
        featuresCol: str = "features",
        labelCol: str = "label",
        predictionCol: str = "prediction",
        maxDepth: int = 5,
        maxBins: int = 32,
        minInstancesPerNode: int = 1,
        minInfoGain: float = 0.0,
        maxMemoryInMB: int = 256,
        cacheNodeIds: bool = False,
        checkpointInterval: int = 10,
        impurity: str = "variance",
        subsamplingRate: float = 1.0,
        seed: Optional[int] = None,
        numTrees: int = 20,
        featureSubsetStrategy: str = "auto",
        leafCol: str = "",
        minWeightFractionPerNode: float = 0.0,
        weightCol: Optional[str] = None,
        bootstrap: Optional[bool] = True,
    ):
        """
        __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                 impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
                 featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0, \
                 weightCol=None, bootstrap=True)
        """
        super(RandomForestRegressor, self).__init__()
        # Create the JVM-side estimator this Python wrapper delegates to.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.RandomForestRegressor", self.uid
        )
        # @keyword_only stashed the caller's keyword args in _input_kwargs;
        # forward them all to setParams so defaults and overrides are applied once.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)
1451
+
1452
    @keyword_only
    @since("1.4.0")
    def setParams(
        self,
        *,
        featuresCol: str = "features",
        labelCol: str = "label",
        predictionCol: str = "prediction",
        maxDepth: int = 5,
        maxBins: int = 32,
        minInstancesPerNode: int = 1,
        minInfoGain: float = 0.0,
        maxMemoryInMB: int = 256,
        cacheNodeIds: bool = False,
        checkpointInterval: int = 10,
        impurity: str = "variance",
        subsamplingRate: float = 1.0,
        seed: Optional[int] = None,
        numTrees: int = 20,
        featureSubsetStrategy: str = "auto",
        leafCol: str = "",
        minWeightFractionPerNode: float = 0.0,
        weightCol: Optional[str] = None,
        bootstrap: Optional[bool] = True,
    ) -> "RandomForestRegressor":
        """
        setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
                  featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0, \
                  weightCol=None, bootstrap=True)
        Sets params for the RandomForestRegressor.
        """
        # @keyword_only stashed the caller's keyword args in _input_kwargs.
        kwargs = self._input_kwargs
        return self._set(**kwargs)
1488
+
1489
+ def _create_model(self, java_model: "JavaObject") -> "RandomForestRegressionModel":
1490
+ return RandomForestRegressionModel(java_model)
1491
+
1492
+ def setMaxDepth(self, value: int) -> "RandomForestRegressor":
1493
+ """
1494
+ Sets the value of :py:attr:`maxDepth`.
1495
+ """
1496
+ return self._set(maxDepth=value)
1497
+
1498
+ def setMaxBins(self, value: int) -> "RandomForestRegressor":
1499
+ """
1500
+ Sets the value of :py:attr:`maxBins`.
1501
+ """
1502
+ return self._set(maxBins=value)
1503
+
1504
+ def setMinInstancesPerNode(self, value: int) -> "RandomForestRegressor":
1505
+ """
1506
+ Sets the value of :py:attr:`minInstancesPerNode`.
1507
+ """
1508
+ return self._set(minInstancesPerNode=value)
1509
+
1510
+ def setMinInfoGain(self, value: float) -> "RandomForestRegressor":
1511
+ """
1512
+ Sets the value of :py:attr:`minInfoGain`.
1513
+ """
1514
+ return self._set(minInfoGain=value)
1515
+
1516
+ def setMaxMemoryInMB(self, value: int) -> "RandomForestRegressor":
1517
+ """
1518
+ Sets the value of :py:attr:`maxMemoryInMB`.
1519
+ """
1520
+ return self._set(maxMemoryInMB=value)
1521
+
1522
+ def setCacheNodeIds(self, value: bool) -> "RandomForestRegressor":
1523
+ """
1524
+ Sets the value of :py:attr:`cacheNodeIds`.
1525
+ """
1526
+ return self._set(cacheNodeIds=value)
1527
+
1528
+ @since("1.4.0")
1529
+ def setImpurity(self, value: str) -> "RandomForestRegressor":
1530
+ """
1531
+ Sets the value of :py:attr:`impurity`.
1532
+ """
1533
+ return self._set(impurity=value)
1534
+
1535
+ @since("1.4.0")
1536
+ def setNumTrees(self, value: int) -> "RandomForestRegressor":
1537
+ """
1538
+ Sets the value of :py:attr:`numTrees`.
1539
+ """
1540
+ return self._set(numTrees=value)
1541
+
1542
+ @since("3.0.0")
1543
+ def setBootstrap(self, value: bool) -> "RandomForestRegressor":
1544
+ """
1545
+ Sets the value of :py:attr:`bootstrap`.
1546
+ """
1547
+ return self._set(bootstrap=value)
1548
+
1549
+ @since("1.4.0")
1550
+ def setSubsamplingRate(self, value: float) -> "RandomForestRegressor":
1551
+ """
1552
+ Sets the value of :py:attr:`subsamplingRate`.
1553
+ """
1554
+ return self._set(subsamplingRate=value)
1555
+
1556
+ @since("2.4.0")
1557
+ def setFeatureSubsetStrategy(self, value: str) -> "RandomForestRegressor":
1558
+ """
1559
+ Sets the value of :py:attr:`featureSubsetStrategy`.
1560
+ """
1561
+ return self._set(featureSubsetStrategy=value)
1562
+
1563
+ def setCheckpointInterval(self, value: int) -> "RandomForestRegressor":
1564
+ """
1565
+ Sets the value of :py:attr:`checkpointInterval`.
1566
+ """
1567
+ return self._set(checkpointInterval=value)
1568
+
1569
+ def setSeed(self, value: int) -> "RandomForestRegressor":
1570
+ """
1571
+ Sets the value of :py:attr:`seed`.
1572
+ """
1573
+ return self._set(seed=value)
1574
+
1575
+ @since("3.0.0")
1576
+ def setWeightCol(self, value: str) -> "RandomForestRegressor":
1577
+ """
1578
+ Sets the value of :py:attr:`weightCol`.
1579
+ """
1580
+ return self._set(weightCol=value)
1581
+
1582
+ @since("3.0.0")
1583
+ def setMinWeightFractionPerNode(self, value: float) -> "RandomForestRegressor":
1584
+ """
1585
+ Sets the value of :py:attr:`minWeightFractionPerNode`.
1586
+ """
1587
+ return self._set(minWeightFractionPerNode=value)
1588
+
1589
+
1590
+ class RandomForestRegressionModel(
1591
+ _JavaRegressionModel[Vector],
1592
+ _TreeEnsembleModel,
1593
+ _RandomForestRegressorParams,
1594
+ JavaMLWritable,
1595
+ JavaMLReadable["RandomForestRegressionModel"],
1596
+ ):
1597
+ """
1598
+ Model fitted by :class:`RandomForestRegressor`.
1599
+
1600
+ .. versionadded:: 1.4.0
1601
+ """
1602
+
1603
+ @property
1604
+ @since("2.0.0")
1605
+ def trees(self) -> List[DecisionTreeRegressionModel]:
1606
+ """Trees in this ensemble. Warning: These have null parent Estimators."""
1607
+ return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
1608
+
1609
+ @property
1610
+ def featureImportances(self) -> Vector:
1611
+ """
1612
+ Estimate of the importance of each feature.
1613
+
1614
+ Each feature's importance is the average of its importance across all trees in the ensemble
1615
+ The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
1616
+ (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
1617
+ and follows the implementation from scikit-learn.
1618
+
1619
+ .. versionadded:: 2.0.0
1620
+
1621
+ Examples
1622
+ --------
1623
+ DecisionTreeRegressionModel.featureImportances
1624
+ """
1625
+ return self._call_java("featureImportances")
1626
+
1627
+
1628
+ class _GBTRegressorParams(_GBTParams, _TreeRegressorParams):
1629
+ """
1630
+ Params for :py:class:`GBTRegressor` and :py:class:`GBTRegressorModel`.
1631
+
1632
+ .. versionadded:: 3.0.0
1633
+ """
1634
+
1635
+ supportedLossTypes: List[str] = ["squared", "absolute"]
1636
+
1637
+ lossType: Param[str] = Param(
1638
+ Params._dummy(),
1639
+ "lossType",
1640
+ "Loss function which GBT tries to minimize (case-insensitive). "
1641
+ + "Supported options: "
1642
+ + ", ".join(supportedLossTypes),
1643
+ typeConverter=TypeConverters.toString,
1644
+ )
1645
+
1646
+ def __init__(self, *args: Any):
1647
+ super(_GBTRegressorParams, self).__init__(*args)
1648
+ self._setDefault(
1649
+ maxDepth=5,
1650
+ maxBins=32,
1651
+ minInstancesPerNode=1,
1652
+ minInfoGain=0.0,
1653
+ maxMemoryInMB=256,
1654
+ cacheNodeIds=False,
1655
+ subsamplingRate=1.0,
1656
+ checkpointInterval=10,
1657
+ lossType="squared",
1658
+ maxIter=20,
1659
+ stepSize=0.1,
1660
+ impurity="variance",
1661
+ featureSubsetStrategy="all",
1662
+ validationTol=0.01,
1663
+ leafCol="",
1664
+ minWeightFractionPerNode=0.0,
1665
+ )
1666
+
1667
+ @since("1.4.0")
1668
+ def getLossType(self) -> str:
1669
+ """
1670
+ Gets the value of lossType or its default value.
1671
+ """
1672
+ return self.getOrDefault(self.lossType)
1673
+
1674
+
1675
+ @inherit_doc
1676
+ class GBTRegressor(
1677
+ _JavaRegressor["GBTRegressionModel"],
1678
+ _GBTRegressorParams,
1679
+ JavaMLWritable,
1680
+ JavaMLReadable["GBTRegressor"],
1681
+ ):
1682
+ """
1683
+ `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
1684
+ learning algorithm for regression.
1685
+ It supports both continuous and categorical features.
1686
+
1687
+ .. versionadded:: 1.4.0
1688
+
1689
+ Examples
1690
+ --------
1691
+ >>> from numpy import allclose
1692
+ >>> from pyspark.ml.linalg import Vectors
1693
+ >>> df = spark.createDataFrame([
1694
+ ... (1.0, Vectors.dense(1.0)),
1695
+ ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
1696
+ >>> gbt = GBTRegressor(maxDepth=2, seed=42, leafCol="leafId")
1697
+ >>> gbt.setMaxIter(5)
1698
+ GBTRegressor...
1699
+ >>> gbt.setMinWeightFractionPerNode(0.049)
1700
+ GBTRegressor...
1701
+ >>> gbt.getMaxIter()
1702
+ 5
1703
+ >>> print(gbt.getImpurity())
1704
+ variance
1705
+ >>> print(gbt.getFeatureSubsetStrategy())
1706
+ all
1707
+ >>> model = gbt.fit(df)
1708
+ >>> model.featureImportances
1709
+ SparseVector(1, {0: 1.0})
1710
+ >>> model.numFeatures
1711
+ 1
1712
+ >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
1713
+ True
1714
+ >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
1715
+ >>> model.predict(test0.head().features)
1716
+ 0.0
1717
+ >>> model.predictLeaf(test0.head().features)
1718
+ DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
1719
+ >>> result = model.transform(test0).head()
1720
+ >>> result.prediction
1721
+ 0.0
1722
+ >>> result.leafId
1723
+ DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
1724
+ >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
1725
+ >>> model.transform(test1).head().prediction
1726
+ 1.0
1727
+ >>> gbtr_path = temp_path + "gbtr"
1728
+ >>> gbt.save(gbtr_path)
1729
+ >>> gbt2 = GBTRegressor.load(gbtr_path)
1730
+ >>> gbt2.getMaxDepth()
1731
+ 2
1732
+ >>> model_path = temp_path + "gbtr_model"
1733
+ >>> model.save(model_path)
1734
+ >>> model2 = GBTRegressionModel.load(model_path)
1735
+ >>> model.featureImportances == model2.featureImportances
1736
+ True
1737
+ >>> model.treeWeights == model2.treeWeights
1738
+ True
1739
+ >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
1740
+ True
1741
+ >>> model.trees
1742
+ [DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
1743
+ >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
1744
+ ... ["label", "features"])
1745
+ >>> model.evaluateEachIteration(validation, "squared")
1746
+ [0.0, 0.0, 0.0, 0.0, 0.0]
1747
+ >>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
1748
+ >>> gbt.getValidationIndicatorCol()
1749
+ 'validationIndicator'
1750
+ >>> gbt.getValidationTol()
1751
+ 0.01
1752
+ """
1753
+
1754
+ _input_kwargs: Dict[str, Any]
1755
+
1756
+ @keyword_only
1757
+ def __init__(
1758
+ self,
1759
+ *,
1760
+ featuresCol: str = "features",
1761
+ labelCol: str = "label",
1762
+ predictionCol: str = "prediction",
1763
+ maxDepth: int = 5,
1764
+ maxBins: int = 32,
1765
+ minInstancesPerNode: int = 1,
1766
+ minInfoGain: float = 0.0,
1767
+ maxMemoryInMB: int = 256,
1768
+ cacheNodeIds: bool = False,
1769
+ subsamplingRate: float = 1.0,
1770
+ checkpointInterval: int = 10,
1771
+ lossType: str = "squared",
1772
+ maxIter: int = 20,
1773
+ stepSize: float = 0.1,
1774
+ seed: Optional[int] = None,
1775
+ impurity: str = "variance",
1776
+ featureSubsetStrategy: str = "all",
1777
+ validationTol: float = 0.1,
1778
+ validationIndicatorCol: Optional[str] = None,
1779
+ leafCol: str = "",
1780
+ minWeightFractionPerNode: float = 0.0,
1781
+ weightCol: Optional[str] = None,
1782
+ ):
1783
+ """
1784
+ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
1785
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
1786
+ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
1787
+ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
1788
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
1789
+ validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0,
1790
+ weightCol=None)
1791
+ """
1792
+ super(GBTRegressor, self).__init__()
1793
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
1794
+ kwargs = self._input_kwargs
1795
+ self.setParams(**kwargs)
1796
+
1797
+ @keyword_only
1798
+ @since("1.4.0")
1799
+ def setParams(
1800
+ self,
1801
+ *,
1802
+ featuresCol: str = "features",
1803
+ labelCol: str = "label",
1804
+ predictionCol: str = "prediction",
1805
+ maxDepth: int = 5,
1806
+ maxBins: int = 32,
1807
+ minInstancesPerNode: int = 1,
1808
+ minInfoGain: float = 0.0,
1809
+ maxMemoryInMB: int = 256,
1810
+ cacheNodeIds: bool = False,
1811
+ subsamplingRate: float = 1.0,
1812
+ checkpointInterval: int = 10,
1813
+ lossType: str = "squared",
1814
+ maxIter: int = 20,
1815
+ stepSize: float = 0.1,
1816
+ seed: Optional[int] = None,
1817
+ impurity: str = "variance",
1818
+ featureSubsetStrategy: str = "all",
1819
+ validationTol: float = 0.1,
1820
+ validationIndicatorCol: Optional[str] = None,
1821
+ leafCol: str = "",
1822
+ minWeightFractionPerNode: float = 0.0,
1823
+ weightCol: Optional[str] = None,
1824
+ ) -> "GBTRegressor":
1825
+ """
1826
+ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
1827
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
1828
+ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
1829
+ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
1830
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
1831
+ validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
1832
+ weightCol=None)
1833
+ Sets params for Gradient Boosted Tree Regression.
1834
+ """
1835
+ kwargs = self._input_kwargs
1836
+ return self._set(**kwargs)
1837
+
1838
+ def _create_model(self, java_model: "JavaObject") -> "GBTRegressionModel":
1839
+ return GBTRegressionModel(java_model)
1840
+
1841
+ @since("1.4.0")
1842
+ def setMaxDepth(self, value: int) -> "GBTRegressor":
1843
+ """
1844
+ Sets the value of :py:attr:`maxDepth`.
1845
+ """
1846
+ return self._set(maxDepth=value)
1847
+
1848
+ @since("1.4.0")
1849
+ def setMaxBins(self, value: int) -> "GBTRegressor":
1850
+ """
1851
+ Sets the value of :py:attr:`maxBins`.
1852
+ """
1853
+ return self._set(maxBins=value)
1854
+
1855
+ @since("1.4.0")
1856
+ def setMinInstancesPerNode(self, value: int) -> "GBTRegressor":
1857
+ """
1858
+ Sets the value of :py:attr:`minInstancesPerNode`.
1859
+ """
1860
+ return self._set(minInstancesPerNode=value)
1861
+
1862
+ @since("1.4.0")
1863
+ def setMinInfoGain(self, value: float) -> "GBTRegressor":
1864
+ """
1865
+ Sets the value of :py:attr:`minInfoGain`.
1866
+ """
1867
+ return self._set(minInfoGain=value)
1868
+
1869
+ @since("1.4.0")
1870
+ def setMaxMemoryInMB(self, value: int) -> "GBTRegressor":
1871
+ """
1872
+ Sets the value of :py:attr:`maxMemoryInMB`.
1873
+ """
1874
+ return self._set(maxMemoryInMB=value)
1875
+
1876
+ @since("1.4.0")
1877
+ def setCacheNodeIds(self, value: bool) -> "GBTRegressor":
1878
+ """
1879
+ Sets the value of :py:attr:`cacheNodeIds`.
1880
+ """
1881
+ return self._set(cacheNodeIds=value)
1882
+
1883
+ @since("1.4.0")
1884
+ def setImpurity(self, value: str) -> "GBTRegressor":
1885
+ """
1886
+ Sets the value of :py:attr:`impurity`.
1887
+ """
1888
+ return self._set(impurity=value)
1889
+
1890
+ @since("1.4.0")
1891
+ def setLossType(self, value: str) -> "GBTRegressor":
1892
+ """
1893
+ Sets the value of :py:attr:`lossType`.
1894
+ """
1895
+ return self._set(lossType=value)
1896
+
1897
+ @since("1.4.0")
1898
+ def setSubsamplingRate(self, value: float) -> "GBTRegressor":
1899
+ """
1900
+ Sets the value of :py:attr:`subsamplingRate`.
1901
+ """
1902
+ return self._set(subsamplingRate=value)
1903
+
1904
+ @since("2.4.0")
1905
+ def setFeatureSubsetStrategy(self, value: str) -> "GBTRegressor":
1906
+ """
1907
+ Sets the value of :py:attr:`featureSubsetStrategy`.
1908
+ """
1909
+ return self._set(featureSubsetStrategy=value)
1910
+
1911
+ @since("3.0.0")
1912
+ def setValidationIndicatorCol(self, value: str) -> "GBTRegressor":
1913
+ """
1914
+ Sets the value of :py:attr:`validationIndicatorCol`.
1915
+ """
1916
+ return self._set(validationIndicatorCol=value)
1917
+
1918
+ @since("1.4.0")
1919
+ def setMaxIter(self, value: int) -> "GBTRegressor":
1920
+ """
1921
+ Sets the value of :py:attr:`maxIter`.
1922
+ """
1923
+ return self._set(maxIter=value)
1924
+
1925
+ @since("1.4.0")
1926
+ def setCheckpointInterval(self, value: int) -> "GBTRegressor":
1927
+ """
1928
+ Sets the value of :py:attr:`checkpointInterval`.
1929
+ """
1930
+ return self._set(checkpointInterval=value)
1931
+
1932
+ @since("1.4.0")
1933
+ def setSeed(self, value: int) -> "GBTRegressor":
1934
+ """
1935
+ Sets the value of :py:attr:`seed`.
1936
+ """
1937
+ return self._set(seed=value)
1938
+
1939
+ @since("1.4.0")
1940
+ def setStepSize(self, value: float) -> "GBTRegressor":
1941
+ """
1942
+ Sets the value of :py:attr:`stepSize`.
1943
+ """
1944
+ return self._set(stepSize=value)
1945
+
1946
+ @since("3.0.0")
1947
+ def setWeightCol(self, value: str) -> "GBTRegressor":
1948
+ """
1949
+ Sets the value of :py:attr:`weightCol`.
1950
+ """
1951
+ return self._set(weightCol=value)
1952
+
1953
+ @since("3.0.0")
1954
+ def setMinWeightFractionPerNode(self, value: float) -> "GBTRegressor":
1955
+ """
1956
+ Sets the value of :py:attr:`minWeightFractionPerNode`.
1957
+ """
1958
+ return self._set(minWeightFractionPerNode=value)
1959
+
1960
+
1961
+ class GBTRegressionModel(
1962
+ _JavaRegressionModel[Vector],
1963
+ _TreeEnsembleModel,
1964
+ _GBTRegressorParams,
1965
+ JavaMLWritable,
1966
+ JavaMLReadable["GBTRegressionModel"],
1967
+ ):
1968
+ """
1969
+ Model fitted by :class:`GBTRegressor`.
1970
+
1971
+ .. versionadded:: 1.4.0
1972
+ """
1973
+
1974
+ @property
1975
+ def featureImportances(self) -> Vector:
1976
+ """
1977
+ Estimate of the importance of each feature.
1978
+
1979
+ Each feature's importance is the average of its importance across all trees in the ensemble
1980
+ The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
1981
+ (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
1982
+ and follows the implementation from scikit-learn.
1983
+
1984
+ .. versionadded:: 2.0.0
1985
+
1986
+ Examples
1987
+ --------
1988
+ DecisionTreeRegressionModel.featureImportances
1989
+ """
1990
+ return self._call_java("featureImportances")
1991
+
1992
+ @property
1993
+ @since("2.0.0")
1994
+ def trees(self) -> List[DecisionTreeRegressionModel]:
1995
+ """Trees in this ensemble. Warning: These have null parent Estimators."""
1996
+ return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
1997
+
1998
+ def evaluateEachIteration(self, dataset: DataFrame, loss: str) -> List[float]:
1999
+ """
2000
+ Method to compute error or loss for every iteration of gradient boosting.
2001
+
2002
+ .. versionadded:: 2.4.0
2003
+
2004
+ Parameters
2005
+ ----------
2006
+ dataset : :py:class:`pyspark.sql.DataFrame`
2007
+ Test dataset to evaluate model on, where dataset is an
2008
+ instance of :py:class:`pyspark.sql.DataFrame`
2009
+ loss : str
2010
+ The loss function used to compute error.
2011
+ Supported options: squared, absolute
2012
+ """
2013
+ return self._call_java("evaluateEachIteration", dataset, loss)
2014
+
2015
+
2016
+ class _AFTSurvivalRegressionParams(
2017
+ _PredictorParams, HasMaxIter, HasTol, HasFitIntercept, HasAggregationDepth, HasMaxBlockSizeInMB
2018
+ ):
2019
+ """
2020
+ Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`.
2021
+
2022
+ .. versionadded:: 3.0.0
2023
+ """
2024
+
2025
+ censorCol: Param[str] = Param(
2026
+ Params._dummy(),
2027
+ "censorCol",
2028
+ "censor column name. The value of this column could be 0 or 1. "
2029
+ + "If the value is 1, it means the event has occurred i.e. "
2030
+ + "uncensored; otherwise censored.",
2031
+ typeConverter=TypeConverters.toString,
2032
+ )
2033
+ quantileProbabilities: Param[List[float]] = Param(
2034
+ Params._dummy(),
2035
+ "quantileProbabilities",
2036
+ "quantile probabilities array. Values of the quantile probabilities array "
2037
+ + "should be in the range (0, 1) and the array should be non-empty.",
2038
+ typeConverter=TypeConverters.toListFloat,
2039
+ )
2040
+ quantilesCol: Param[str] = Param(
2041
+ Params._dummy(),
2042
+ "quantilesCol",
2043
+ "quantiles column name. This column will output quantiles of "
2044
+ + "corresponding quantileProbabilities if it is set.",
2045
+ typeConverter=TypeConverters.toString,
2046
+ )
2047
+
2048
+ def __init__(self, *args: Any):
2049
+ super(_AFTSurvivalRegressionParams, self).__init__(*args)
2050
+ self._setDefault(
2051
+ censorCol="censor",
2052
+ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
2053
+ maxIter=100,
2054
+ tol=1e-6,
2055
+ maxBlockSizeInMB=0.0,
2056
+ )
2057
+
2058
+ @since("1.6.0")
2059
+ def getCensorCol(self) -> str:
2060
+ """
2061
+ Gets the value of censorCol or its default value.
2062
+ """
2063
+ return self.getOrDefault(self.censorCol)
2064
+
2065
+ @since("1.6.0")
2066
+ def getQuantileProbabilities(self) -> List[float]:
2067
+ """
2068
+ Gets the value of quantileProbabilities or its default value.
2069
+ """
2070
+ return self.getOrDefault(self.quantileProbabilities)
2071
+
2072
+ @since("1.6.0")
2073
+ def getQuantilesCol(self) -> str:
2074
+ """
2075
+ Gets the value of quantilesCol or its default value.
2076
+ """
2077
+ return self.getOrDefault(self.quantilesCol)
2078
+
2079
+
2080
+ @inherit_doc
2081
+ class AFTSurvivalRegression(
2082
+ _JavaRegressor["AFTSurvivalRegressionModel"],
2083
+ _AFTSurvivalRegressionParams,
2084
+ JavaMLWritable,
2085
+ JavaMLReadable["AFTSurvivalRegression"],
2086
+ ):
2087
+ """
2088
+ Accelerated Failure Time (AFT) Model Survival Regression
2089
+
2090
+ Fit a parametric AFT survival regression model based on the Weibull distribution
2091
+ of the survival time.
2092
+
2093
+ Notes
2094
+ -----
2095
+ For more information see Wikipedia page on
2096
+ `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_
2097
+
2098
+
2099
+ Examples
2100
+ --------
2101
+ >>> from pyspark.ml.linalg import Vectors
2102
+ >>> df = spark.createDataFrame([
2103
+ ... (1.0, Vectors.dense(1.0), 1.0),
2104
+ ... (1e-40, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
2105
+ >>> aftsr = AFTSurvivalRegression()
2106
+ >>> aftsr.setMaxIter(10)
2107
+ AFTSurvivalRegression...
2108
+ >>> aftsr.getMaxIter()
2109
+ 10
2110
+ >>> aftsr.clear(aftsr.maxIter)
2111
+ >>> model = aftsr.fit(df)
2112
+ >>> model.getMaxBlockSizeInMB()
2113
+ 0.0
2114
+ >>> model.setFeaturesCol("features")
2115
+ AFTSurvivalRegressionModel...
2116
+ >>> model.predict(Vectors.dense(6.3))
2117
+ 1.0
2118
+ >>> model.predictQuantiles(Vectors.dense(6.3))
2119
+ DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
2120
+ >>> model.transform(df).show()
2121
+ +-------+---------+------+----------+
2122
+ | label| features|censor|prediction|
2123
+ +-------+---------+------+----------+
2124
+ | 1.0| [1.0]| 1.0| 1.0|
2125
+ |1.0E-40|(1,[],[])| 0.0| 1.0|
2126
+ +-------+---------+------+----------+
2127
+ ...
2128
+ >>> aftsr_path = temp_path + "/aftsr"
2129
+ >>> aftsr.save(aftsr_path)
2130
+ >>> aftsr2 = AFTSurvivalRegression.load(aftsr_path)
2131
+ >>> aftsr2.getMaxIter()
2132
+ 100
2133
+ >>> model_path = temp_path + "/aftsr_model"
2134
+ >>> model.save(model_path)
2135
+ >>> model2 = AFTSurvivalRegressionModel.load(model_path)
2136
+ >>> model.coefficients == model2.coefficients
2137
+ True
2138
+ >>> model.intercept == model2.intercept
2139
+ True
2140
+ >>> model.scale == model2.scale
2141
+ True
2142
+ >>> model.transform(df).take(1) == model2.transform(df).take(1)
2143
+ True
2144
+
2145
+ .. versionadded:: 1.6.0
2146
+ """
2147
+
2148
+ _input_kwargs: Dict[str, Any]
2149
+
2150
+ @keyword_only
2151
+ def __init__(
2152
+ self,
2153
+ *,
2154
+ featuresCol: str = "features",
2155
+ labelCol: str = "label",
2156
+ predictionCol: str = "prediction",
2157
+ fitIntercept: bool = True,
2158
+ maxIter: int = 100,
2159
+ tol: float = 1e-6,
2160
+ censorCol: str = "censor",
2161
+ quantileProbabilities: List[float] = [
2162
+ 0.01,
2163
+ 0.05,
2164
+ 0.1,
2165
+ 0.25,
2166
+ 0.5,
2167
+ 0.75,
2168
+ 0.9,
2169
+ 0.95,
2170
+ 0.99,
2171
+ ], # noqa: B005
2172
+ quantilesCol: Optional[str] = None,
2173
+ aggregationDepth: int = 2,
2174
+ maxBlockSizeInMB: float = 0.0,
2175
+ ):
2176
+ """
2177
+ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
2178
+ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
2179
+ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
2180
+ quantilesCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0)
2181
+ """
2182
+ super(AFTSurvivalRegression, self).__init__()
2183
+ self._java_obj = self._new_java_obj(
2184
+ "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid
2185
+ )
2186
+ kwargs = self._input_kwargs
2187
+ self.setParams(**kwargs)
2188
+
2189
+ @keyword_only
2190
+ @since("1.6.0")
2191
+ def setParams(
2192
+ self,
2193
+ *,
2194
+ featuresCol: str = "features",
2195
+ labelCol: str = "label",
2196
+ predictionCol: str = "prediction",
2197
+ fitIntercept: bool = True,
2198
+ maxIter: int = 100,
2199
+ tol: float = 1e-6,
2200
+ censorCol: str = "censor",
2201
+ quantileProbabilities: List[float] = [
2202
+ 0.01,
2203
+ 0.05,
2204
+ 0.1,
2205
+ 0.25,
2206
+ 0.5,
2207
+ 0.75,
2208
+ 0.9,
2209
+ 0.95,
2210
+ 0.99,
2211
+ ], # noqa: B005
2212
+ quantilesCol: Optional[str] = None,
2213
+ aggregationDepth: int = 2,
2214
+ maxBlockSizeInMB: float = 0.0,
2215
+ ) -> "AFTSurvivalRegression":
2216
+ """
2217
+ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
2218
+ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
2219
+ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
2220
+ quantilesCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0):
2221
+ """
2222
+ kwargs = self._input_kwargs
2223
+ return self._set(**kwargs)
2224
+
2225
+ def _create_model(self, java_model: "JavaObject") -> "AFTSurvivalRegressionModel":
2226
+ return AFTSurvivalRegressionModel(java_model)
2227
+
2228
+ @since("1.6.0")
2229
+ def setCensorCol(self, value: str) -> "AFTSurvivalRegression":
2230
+ """
2231
+ Sets the value of :py:attr:`censorCol`.
2232
+ """
2233
+ return self._set(censorCol=value)
2234
+
2235
+ @since("1.6.0")
2236
+ def setQuantileProbabilities(self, value: List[float]) -> "AFTSurvivalRegression":
2237
+ """
2238
+ Sets the value of :py:attr:`quantileProbabilities`.
2239
+ """
2240
+ return self._set(quantileProbabilities=value)
2241
+
2242
+ @since("1.6.0")
2243
+ def setQuantilesCol(self, value: str) -> "AFTSurvivalRegression":
2244
+ """
2245
+ Sets the value of :py:attr:`quantilesCol`.
2246
+ """
2247
+ return self._set(quantilesCol=value)
2248
+
2249
+ @since("1.6.0")
2250
+ def setMaxIter(self, value: int) -> "AFTSurvivalRegression":
2251
+ """
2252
+ Sets the value of :py:attr:`maxIter`.
2253
+ """
2254
+ return self._set(maxIter=value)
2255
+
2256
+ @since("1.6.0")
2257
+ def setTol(self, value: float) -> "AFTSurvivalRegression":
2258
+ """
2259
+ Sets the value of :py:attr:`tol`.
2260
+ """
2261
+ return self._set(tol=value)
2262
+
2263
+ @since("1.6.0")
2264
+ def setFitIntercept(self, value: bool) -> "AFTSurvivalRegression":
2265
+ """
2266
+ Sets the value of :py:attr:`fitIntercept`.
2267
+ """
2268
+ return self._set(fitIntercept=value)
2269
+
2270
+ @since("2.1.0")
2271
+ def setAggregationDepth(self, value: int) -> "AFTSurvivalRegression":
2272
+ """
2273
+ Sets the value of :py:attr:`aggregationDepth`.
2274
+ """
2275
+ return self._set(aggregationDepth=value)
2276
+
2277
+ @since("3.1.0")
2278
+ def setMaxBlockSizeInMB(self, value: int) -> "AFTSurvivalRegression":
2279
+ """
2280
+ Sets the value of :py:attr:`maxBlockSizeInMB`.
2281
+ """
2282
+ return self._set(maxBlockSizeInMB=value)
2283
+
2284
+
2285
+ class AFTSurvivalRegressionModel(
2286
+ _JavaRegressionModel[Vector],
2287
+ _AFTSurvivalRegressionParams,
2288
+ JavaMLWritable,
2289
+ JavaMLReadable["AFTSurvivalRegressionModel"],
2290
+ ):
2291
+ """
2292
+ Model fitted by :class:`AFTSurvivalRegression`.
2293
+
2294
+ .. versionadded:: 1.6.0
2295
+ """
2296
+
2297
+ @since("3.0.0")
2298
+ def setQuantileProbabilities(self, value: List[float]) -> "AFTSurvivalRegressionModel":
2299
+ """
2300
+ Sets the value of :py:attr:`quantileProbabilities`.
2301
+ """
2302
+ return self._set(quantileProbabilities=value)
2303
+
2304
+ @since("3.0.0")
2305
+ def setQuantilesCol(self, value: str) -> "AFTSurvivalRegressionModel":
2306
+ """
2307
+ Sets the value of :py:attr:`quantilesCol`.
2308
+ """
2309
+ return self._set(quantilesCol=value)
2310
+
2311
+ @property
2312
+ @since("2.0.0")
2313
+ def coefficients(self) -> Vector:
2314
+ """
2315
+ Model coefficients.
2316
+ """
2317
+ return self._call_java("coefficients")
2318
+
2319
+ @property
2320
+ @since("1.6.0")
2321
+ def intercept(self) -> float:
2322
+ """
2323
+ Model intercept.
2324
+ """
2325
+ return self._call_java("intercept")
2326
+
2327
+ @property
2328
+ @since("1.6.0")
2329
+ def scale(self) -> float:
2330
+ """
2331
+ Model scale parameter.
2332
+ """
2333
+ return self._call_java("scale")
2334
+
2335
+ @since("2.0.0")
2336
+ def predictQuantiles(self, features: Vector) -> Vector:
2337
+ """
2338
+ Predicted Quantiles
2339
+ """
2340
+ return self._call_java("predictQuantiles", features)
2341
+
2342
+
2343
+ class _GeneralizedLinearRegressionParams(
2344
+ _PredictorParams,
2345
+ HasFitIntercept,
2346
+ HasMaxIter,
2347
+ HasTol,
2348
+ HasRegParam,
2349
+ HasWeightCol,
2350
+ HasSolver,
2351
+ HasAggregationDepth,
2352
+ ):
2353
+ """
2354
+ Params for :py:class:`GeneralizedLinearRegression` and
2355
+ :py:class:`GeneralizedLinearRegressionModel`.
2356
+
2357
+ .. versionadded:: 3.0.0
2358
+ """
2359
+
2360
+ family: Param[str] = Param(
2361
+ Params._dummy(),
2362
+ "family",
2363
+ "The name of family which is a description of "
2364
+ + "the error distribution to be used in the model. Supported options: "
2365
+ + "gaussian (default), binomial, poisson, gamma and tweedie.",
2366
+ typeConverter=TypeConverters.toString,
2367
+ )
2368
+ link: Param[str] = Param(
2369
+ Params._dummy(),
2370
+ "link",
2371
+ "The name of link function which provides the "
2372
+ + "relationship between the linear predictor and the mean of the distribution "
2373
+ + "function. Supported options: identity, log, inverse, logit, probit, cloglog "
2374
+ + "and sqrt.",
2375
+ typeConverter=TypeConverters.toString,
2376
+ )
2377
+ linkPredictionCol: Param[str] = Param(
2378
+ Params._dummy(),
2379
+ "linkPredictionCol",
2380
+ "link prediction (linear " + "predictor) column name",
2381
+ typeConverter=TypeConverters.toString,
2382
+ )
2383
+ variancePower: Param[float] = Param(
2384
+ Params._dummy(),
2385
+ "variancePower",
2386
+ "The power in the variance function "
2387
+ + "of the Tweedie distribution which characterizes the relationship "
2388
+ + "between the variance and mean of the distribution. Only applicable "
2389
+ + "for the Tweedie family. Supported values: 0 and [1, Inf).",
2390
+ typeConverter=TypeConverters.toFloat,
2391
+ )
2392
+ linkPower: Param[float] = Param(
2393
+ Params._dummy(),
2394
+ "linkPower",
2395
+ "The index in the power link function. " + "Only applicable to the Tweedie family.",
2396
+ typeConverter=TypeConverters.toFloat,
2397
+ )
2398
+ solver: Param[str] = Param(
2399
+ Params._dummy(),
2400
+ "solver",
2401
+ "The solver algorithm for optimization. Supported " + "options: irls.",
2402
+ typeConverter=TypeConverters.toString,
2403
+ )
2404
+ offsetCol: Param[str] = Param(
2405
+ Params._dummy(),
2406
+ "offsetCol",
2407
+ "The offset column name. If this is not set "
2408
+ + "or empty, we treat all instance offsets as 0.0",
2409
+ typeConverter=TypeConverters.toString,
2410
+ )
2411
+
2412
+ def __init__(self, *args: Any):
2413
+ super(_GeneralizedLinearRegressionParams, self).__init__(*args)
2414
+ self._setDefault(
2415
+ family="gaussian",
2416
+ maxIter=25,
2417
+ tol=1e-6,
2418
+ regParam=0.0,
2419
+ solver="irls",
2420
+ variancePower=0.0,
2421
+ aggregationDepth=2,
2422
+ )
2423
+
2424
    @since("2.0.0")
    def getFamily(self) -> str:
        """
        Gets the value of family or its default value ("gaussian").
        """
        return self.getOrDefault(self.family)
2430
+
2431
    @since("2.0.0")
    def getLinkPredictionCol(self) -> str:
        """
        Gets the value of linkPredictionCol or its default value (no default is set).
        """
        return self.getOrDefault(self.linkPredictionCol)
2437
+
2438
    @since("2.0.0")
    def getLink(self) -> str:
        """
        Gets the value of link or its default value (no default is set).
        """
        return self.getOrDefault(self.link)
2444
+
2445
    @since("2.2.0")
    def getVariancePower(self) -> float:
        """
        Gets the value of variancePower or its default value (0.0).
        """
        return self.getOrDefault(self.variancePower)
2451
+
2452
    @since("2.2.0")
    def getLinkPower(self) -> float:
        """
        Gets the value of linkPower or its default value (no default is set).
        """
        return self.getOrDefault(self.linkPower)
2458
+
2459
    @since("2.3.0")
    def getOffsetCol(self) -> str:
        """
        Gets the value of offsetCol or its default value (no default is set).
        """
        return self.getOrDefault(self.offsetCol)
2465
+
2466
+
2467
@inherit_doc
class GeneralizedLinearRegression(
    _JavaRegressor["GeneralizedLinearRegressionModel"],
    _GeneralizedLinearRegressionParams,
    JavaMLWritable,
    JavaMLReadable["GeneralizedLinearRegression"],
):
    """
    Generalized Linear Regression.

    Fit a Generalized Linear Model specified by giving a symbolic description of the linear
    predictor (link function) and a description of the error distribution (family). It supports
    "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for
    each family is listed below. The first link function of each family is the default one.

    * "gaussian" -> "identity", "log", "inverse"

    * "binomial" -> "logit", "probit", "cloglog"

    * "poisson" -> "log", "identity", "sqrt"

    * "gamma" -> "inverse", "identity", "log"

    * "tweedie" -> power link function specified through "linkPower". \
        The default link power in the tweedie family is 1 - variancePower.

    .. versionadded:: 2.0.0

    Notes
    -----
    For more information see Wikipedia page on
    `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_

    Examples
    --------
    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(0.0, 0.0)),
    ...     (1.0, Vectors.dense(1.0, 2.0)),
    ...     (2.0, Vectors.dense(0.0, 0.0)),
    ...     (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
    >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
    >>> glr.setRegParam(0.1)
    GeneralizedLinearRegression...
    >>> glr.getRegParam()
    0.1
    >>> glr.clear(glr.regParam)
    >>> glr.setMaxIter(10)
    GeneralizedLinearRegression...
    >>> glr.getMaxIter()
    10
    >>> glr.clear(glr.maxIter)
    >>> model = glr.fit(df)
    >>> model.setFeaturesCol("features")
    GeneralizedLinearRegressionModel...
    >>> model.getMaxIter()
    25
    >>> model.getAggregationDepth()
    2
    >>> transformed = model.transform(df)
    >>> abs(transformed.head().prediction - 1.5) < 0.001
    True
    >>> abs(transformed.head().p - 1.5) < 0.001
    True
    >>> model.coefficients
    DenseVector([1.5..., -1.0...])
    >>> model.numFeatures
    2
    >>> abs(model.intercept - 1.5) < 0.001
    True
    >>> glr_path = temp_path + "/glr"
    >>> glr.save(glr_path)
    >>> glr2 = GeneralizedLinearRegression.load(glr_path)
    >>> glr.getFamily() == glr2.getFamily()
    True
    >>> model_path = temp_path + "/glr_model"
    >>> model.save(model_path)
    >>> model2 = GeneralizedLinearRegressionModel.load(model_path)
    >>> model.intercept == model2.intercept
    True
    >>> model.coefficients[0] == model2.coefficients[0]
    True
    >>> model.transform(df).take(1) == model2.transform(df).take(1)
    True
    """

    # Populated by the @keyword_only decorator with the caller's kwargs.
    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        labelCol: str = "label",
        featuresCol: str = "features",
        predictionCol: str = "prediction",
        family: str = "gaussian",
        link: Optional[str] = None,
        fitIntercept: bool = True,
        maxIter: int = 25,
        tol: float = 1e-6,
        regParam: float = 0.0,
        weightCol: Optional[str] = None,
        solver: str = "irls",
        linkPredictionCol: Optional[str] = None,
        variancePower: float = 0.0,
        linkPower: Optional[float] = None,
        offsetCol: Optional[str] = None,
        aggregationDepth: int = 2,
    ):
        """
        __init__(self, \\*, labelCol="label", featuresCol="features", predictionCol="prediction", \
                 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
                 regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
                 variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2)
        """
        super(GeneralizedLinearRegression, self).__init__()
        # Instantiate the JVM-side estimator that this Python class wraps.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid
        )
        kwargs = self._input_kwargs

        # Push every constructor keyword through the normal param-setting path.
        self.setParams(**kwargs)

    @keyword_only
    @since("2.0.0")
    def setParams(
        self,
        *,
        labelCol: str = "label",
        featuresCol: str = "features",
        predictionCol: str = "prediction",
        family: str = "gaussian",
        link: Optional[str] = None,
        fitIntercept: bool = True,
        maxIter: int = 25,
        tol: float = 1e-6,
        regParam: float = 0.0,
        weightCol: Optional[str] = None,
        solver: str = "irls",
        linkPredictionCol: Optional[str] = None,
        variancePower: float = 0.0,
        linkPower: Optional[float] = None,
        offsetCol: Optional[str] = None,
        aggregationDepth: int = 2,
    ) -> "GeneralizedLinearRegression":
        """
        setParams(self, \\*, labelCol="label", featuresCol="features", predictionCol="prediction", \
                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
                  regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
                  variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2)
        Sets params for generalized linear regression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model: "JavaObject") -> "GeneralizedLinearRegressionModel":
        # Wrap the fitted JVM model returned by fit() in its Python counterpart.
        return GeneralizedLinearRegressionModel(java_model)

    @since("2.0.0")
    def setFamily(self, value: str) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`family`.
        """
        return self._set(family=value)

    @since("2.0.0")
    def setLinkPredictionCol(self, value: str) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`linkPredictionCol`.
        """
        return self._set(linkPredictionCol=value)

    @since("2.0.0")
    def setLink(self, value: str) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`link`.
        """
        return self._set(link=value)

    @since("2.2.0")
    def setVariancePower(self, value: float) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`variancePower`.
        """
        return self._set(variancePower=value)

    @since("2.2.0")
    def setLinkPower(self, value: float) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`linkPower`.
        """
        return self._set(linkPower=value)

    @since("2.3.0")
    def setOffsetCol(self, value: str) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`offsetCol`.
        """
        return self._set(offsetCol=value)

    @since("2.0.0")
    def setMaxIter(self, value: int) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    @since("2.0.0")
    def setRegParam(self, value: float) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`regParam`.
        """
        return self._set(regParam=value)

    @since("2.0.0")
    def setTol(self, value: float) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`tol`.
        """
        return self._set(tol=value)

    @since("2.0.0")
    def setFitIntercept(self, value: bool) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`fitIntercept`.
        """
        return self._set(fitIntercept=value)

    @since("2.0.0")
    def setWeightCol(self, value: str) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    @since("2.0.0")
    def setSolver(self, value: str) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`solver`.
        """
        return self._set(solver=value)

    @since("3.0.0")
    def setAggregationDepth(self, value: int) -> "GeneralizedLinearRegression":
        """
        Sets the value of :py:attr:`aggregationDepth`.
        """
        return self._set(aggregationDepth=value)
2715
+
2716
+
2717
class GeneralizedLinearRegressionModel(
    _JavaRegressionModel[Vector],
    _GeneralizedLinearRegressionParams,
    JavaMLWritable,
    JavaMLReadable["GeneralizedLinearRegressionModel"],
    HasTrainingSummary["GeneralizedLinearRegressionTrainingSummary"],
):
    """
    Model fitted by :class:`GeneralizedLinearRegression`.

    .. versionadded:: 2.0.0
    """

    @since("3.0.0")
    def setLinkPredictionCol(self, value: str) -> "GeneralizedLinearRegressionModel":
        """
        Sets the value of :py:attr:`linkPredictionCol`.
        """
        return self._set(linkPredictionCol=value)

    @property
    @since("2.0.0")
    def coefficients(self) -> Vector:
        """
        Model coefficients.
        """
        return self._call_java("coefficients")

    @property
    @since("2.0.0")
    def intercept(self) -> float:
        """
        Model intercept.
        """
        return self._call_java("intercept")

    @property
    @since("2.0.0")
    def summary(self) -> "GeneralizedLinearRegressionTrainingSummary":
        """
        Gets summary (residuals, deviance, p-values) of model on
        training set. An exception is thrown if
        `trainingSummary is None`.
        """
        # Wrap the JVM summary object; raise explicitly rather than return
        # None when no training summary is attached to this model.
        if self.hasSummary:
            return GeneralizedLinearRegressionTrainingSummary(
                super(GeneralizedLinearRegressionModel, self).summary
            )
        else:
            raise RuntimeError(
                "No training summary available for this %s" % self.__class__.__name__
            )

    def evaluate(self, dataset: DataFrame) -> "GeneralizedLinearRegressionSummary":
        """
        Evaluates the model on a test dataset.

        .. versionadded:: 2.0.0

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            Test dataset to evaluate model on, where dataset is an
            instance of :py:class:`pyspark.sql.DataFrame`

        Raises
        ------
        TypeError
            If ``dataset`` is not a :py:class:`pyspark.sql.DataFrame`.
        """
        if not isinstance(dataset, DataFrame):
            raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
        # Delegate evaluation to the JVM and wrap the returned summary.
        java_glr_summary = self._call_java("evaluate", dataset)
        return GeneralizedLinearRegressionSummary(java_glr_summary)
2786
+
2787
+
2788
class GeneralizedLinearRegressionSummary(JavaWrapper):
    """
    Generalized linear regression results evaluated on a dataset.

    All properties below are thin read-only views over the corresponding
    accessors of the wrapped JVM summary object.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def predictions(self) -> DataFrame:
        """
        Predictions output by the model's `transform` method.
        """
        return self._call_java("predictions")

    @property
    @since("2.0.0")
    def predictionCol(self) -> str:
        """
        Field in :py:attr:`predictions` which gives the predicted value of each instance.
        This is set to a new column name if the original model's `predictionCol` is not set.
        """
        return self._call_java("predictionCol")

    @property
    @since("2.2.0")
    def numInstances(self) -> int:
        """
        Number of instances in DataFrame predictions.
        """
        return self._call_java("numInstances")

    @property
    @since("2.0.0")
    def rank(self) -> int:
        """
        The numeric rank of the fitted linear model.
        """
        return self._call_java("rank")

    @property
    @since("2.0.0")
    def degreesOfFreedom(self) -> int:
        """
        Degrees of freedom.
        """
        return self._call_java("degreesOfFreedom")

    @property
    @since("2.0.0")
    def residualDegreeOfFreedom(self) -> int:
        """
        The residual degrees of freedom.
        """
        return self._call_java("residualDegreeOfFreedom")

    @property
    @since("2.0.0")
    def residualDegreeOfFreedomNull(self) -> int:
        """
        The residual degrees of freedom for the null model.
        """
        return self._call_java("residualDegreeOfFreedomNull")

    def residuals(self, residualsType: str = "deviance") -> DataFrame:
        """
        Get the residuals of the fitted model by type.

        .. versionadded:: 2.0.0

        Parameters
        ----------
        residualsType : str, optional
            The type of residuals which should be returned.
            Supported options: deviance (default), pearson, working, and response.
        """
        return self._call_java("residuals", residualsType)

    @property
    @since("2.0.0")
    def nullDeviance(self) -> float:
        """
        The deviance for the null model.
        """
        return self._call_java("nullDeviance")

    @property
    @since("2.0.0")
    def deviance(self) -> float:
        """
        The deviance for the fitted model.
        """
        return self._call_java("deviance")

    @property
    @since("2.0.0")
    def dispersion(self) -> float:
        """
        The dispersion of the fitted model.
        It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise
        estimated by the residual Pearson's Chi-Squared statistic (which is defined as
        sum of the squares of the Pearson residuals) divided by the residual degrees of freedom.
        """
        return self._call_java("dispersion")

    @property
    @since("2.0.0")
    def aic(self) -> float:
        """
        Akaike's "An Information Criterion"(AIC) for the fitted model.
        """
        return self._call_java("aic")
2900
+
2901
+
2902
@inherit_doc
class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSummary):
    """
    Generalized linear regression training results.

    Extends :class:`GeneralizedLinearRegressionSummary` with statistics
    that are only available for the training run (iteration count, solver,
    standard errors, t-values, p-values).

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def numIterations(self) -> int:
        """
        Number of training iterations.
        """
        return self._call_java("numIterations")

    @property
    @since("2.0.0")
    def solver(self) -> str:
        """
        The numeric solver used for training.
        """
        return self._call_java("solver")

    @property
    @since("2.0.0")
    def coefficientStandardErrors(self) -> List[float]:
        """
        Standard error of estimated coefficients and intercept.

        If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.
        """
        return self._call_java("coefficientStandardErrors")

    @property
    @since("2.0.0")
    def tValues(self) -> List[float]:
        """
        T-statistic of estimated coefficients and intercept.

        If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.
        """
        return self._call_java("tValues")

    @property
    @since("2.0.0")
    def pValues(self) -> List[float]:
        """
        Two-sided p-value of estimated coefficients and intercept.

        If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.
        """
        return self._call_java("pValues")

    def __repr__(self) -> str:
        # Delegate to the JVM summary's toString for a readable representation.
        return self._call_java("toString")
2961
+
2962
+
2963
class _FactorizationMachinesParams(
    _PredictorParams,
    HasMaxIter,
    HasStepSize,
    HasTol,
    HasSolver,
    HasSeed,
    HasFitIntercept,
    HasRegParam,
    HasWeightCol,
):
    """
    Params for :py:class:`FMRegressor`, :py:class:`FMRegressionModel`, :py:class:`FMClassifier`
    and :py:class:`FMClassifierModel`.

    .. versionadded:: 3.0.0
    """

    # Dimensionality k of the pairwise-interaction factor vectors.
    factorSize: Param[int] = Param(
        Params._dummy(),
        "factorSize",
        "Dimensionality of the factor vectors, "
        + "which are used to get pairwise interactions between variables",
        typeConverter=TypeConverters.toInt,
    )

    fitLinear: Param[bool] = Param(
        Params._dummy(),
        "fitLinear",
        "whether to fit linear term (aka 1-way term)",
        typeConverter=TypeConverters.toBoolean,
    )

    miniBatchFraction: Param[float] = Param(
        Params._dummy(),
        "miniBatchFraction",
        "fraction of the input data "
        + "set that should be used for one iteration of gradient descent",
        typeConverter=TypeConverters.toFloat,
    )

    initStd: Param[float] = Param(
        Params._dummy(),
        "initStd",
        "standard deviation of initial coefficients",
        typeConverter=TypeConverters.toFloat,
    )

    # Overrides HasSolver.solver to document the FM-specific options.
    # Annotated as Param[str] for consistency with the sibling Param
    # declarations above (the annotation was previously missing here).
    solver: Param[str] = Param(
        Params._dummy(),
        "solver",
        "The solver algorithm for optimization. Supported " + "options: gd, adamW. (Default adamW)",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self, *args: Any):
        """Initialize the params mixin and register FM default values."""
        super(_FactorizationMachinesParams, self).__init__(*args)
        self._setDefault(
            factorSize=8,
            fitIntercept=True,
            fitLinear=True,
            regParam=0.0,
            miniBatchFraction=1.0,
            initStd=0.01,
            maxIter=100,
            stepSize=1.0,
            tol=1e-6,
            solver="adamW",
        )

    @since("3.0.0")
    def getFactorSize(self) -> int:
        """
        Gets the value of factorSize or its default value (8).
        """
        return self.getOrDefault(self.factorSize)

    @since("3.0.0")
    def getFitLinear(self) -> bool:
        """
        Gets the value of fitLinear or its default value (True).
        """
        return self.getOrDefault(self.fitLinear)

    @since("3.0.0")
    def getMiniBatchFraction(self) -> float:
        """
        Gets the value of miniBatchFraction or its default value (1.0).
        """
        return self.getOrDefault(self.miniBatchFraction)

    @since("3.0.0")
    def getInitStd(self) -> float:
        """
        Gets the value of initStd or its default value (0.01).
        """
        return self.getOrDefault(self.initStd)
3060
+
3061
+
3062
@inherit_doc
class FMRegressor(
    _JavaRegressor["FMRegressionModel"],
    _FactorizationMachinesParams,
    JavaMLWritable,
    JavaMLReadable["FMRegressor"],
):
    """
    Factorization Machines learning algorithm for regression.

    solver Supports:

    * gd (normal mini-batch gradient descent)
    * adamW (default)

    .. versionadded:: 3.0.0

    Examples
    --------
    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.regression import FMRegressor
    >>> df = spark.createDataFrame([
    ...     (2.0, Vectors.dense(2.0)),
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>>
    >>> fm = FMRegressor(factorSize=2)
    >>> fm.setSeed(16)
    FMRegressor...
    >>> model = fm.fit(df)
    >>> model.getMaxIter()
    100
    >>> test0 = spark.createDataFrame([
    ...     (Vectors.dense(-2.0),),
    ...     (Vectors.dense(0.5),),
    ...     (Vectors.dense(1.0),),
    ...     (Vectors.dense(4.0),)], ["features"])
    >>> model.transform(test0).show(10, False)
    +--------+-------------------+
    |features|prediction         |
    +--------+-------------------+
    |[-2.0]  |-1.9989237712341565|
    |[0.5]   |0.4956682219523814 |
    |[1.0]   |0.994586620589689  |
    |[4.0]   |3.9880970124135344 |
    +--------+-------------------+
    ...
    >>> model.intercept
    -0.0032501766849261557
    >>> model.linear
    DenseVector([0.9978])
    >>> model.factors
    DenseMatrix(1, 2, [0.0173, 0.0021], 1)
    >>> model_path = temp_path + "/fm_model"
    >>> model.save(model_path)
    >>> model2 = FMRegressionModel.load(model_path)
    >>> model2.intercept
    -0.0032501766849261557
    >>> model2.linear
    DenseVector([0.9978])
    >>> model2.factors
    DenseMatrix(1, 2, [0.0173, 0.0021], 1)
    >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
    True
    """

    # Populated by the @keyword_only decorator with the caller's kwargs.
    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        featuresCol: str = "features",
        labelCol: str = "label",
        predictionCol: str = "prediction",
        factorSize: int = 8,
        fitIntercept: bool = True,
        fitLinear: bool = True,
        regParam: float = 0.0,
        miniBatchFraction: float = 1.0,
        initStd: float = 0.01,
        maxIter: int = 100,
        stepSize: float = 1.0,
        tol: float = 1e-6,
        solver: str = "adamW",
        seed: Optional[int] = None,
    ):
        """
        __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
                 miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
                 tol=1e-6, solver="adamW", seed=None)
        """
        super(FMRegressor, self).__init__()
        # Instantiate the JVM-side estimator that this Python class wraps.
        self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.FMRegressor", self.uid)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("3.0.0")
    def setParams(
        self,
        *,
        featuresCol: str = "features",
        labelCol: str = "label",
        predictionCol: str = "prediction",
        factorSize: int = 8,
        fitIntercept: bool = True,
        fitLinear: bool = True,
        regParam: float = 0.0,
        miniBatchFraction: float = 1.0,
        initStd: float = 0.01,
        maxIter: int = 100,
        stepSize: float = 1.0,
        tol: float = 1e-6,
        solver: str = "adamW",
        seed: Optional[int] = None,
    ) -> "FMRegressor":
        """
        setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
                  miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
                  tol=1e-6, solver="adamW", seed=None)
        Sets Params for FMRegressor.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model: "JavaObject") -> "FMRegressionModel":
        # Wrap the fitted JVM model returned by fit() in its Python counterpart.
        return FMRegressionModel(java_model)

    @since("3.0.0")
    def setFactorSize(self, value: int) -> "FMRegressor":
        """
        Sets the value of :py:attr:`factorSize`.
        """
        return self._set(factorSize=value)

    @since("3.0.0")
    def setFitLinear(self, value: bool) -> "FMRegressor":
        """
        Sets the value of :py:attr:`fitLinear`.
        """
        return self._set(fitLinear=value)

    @since("3.0.0")
    def setMiniBatchFraction(self, value: float) -> "FMRegressor":
        """
        Sets the value of :py:attr:`miniBatchFraction`.
        """
        return self._set(miniBatchFraction=value)

    @since("3.0.0")
    def setInitStd(self, value: float) -> "FMRegressor":
        """
        Sets the value of :py:attr:`initStd`.
        """
        return self._set(initStd=value)

    @since("3.0.0")
    def setMaxIter(self, value: int) -> "FMRegressor":
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    @since("3.0.0")
    def setStepSize(self, value: float) -> "FMRegressor":
        """
        Sets the value of :py:attr:`stepSize`.
        """
        return self._set(stepSize=value)

    @since("3.0.0")
    def setTol(self, value: float) -> "FMRegressor":
        """
        Sets the value of :py:attr:`tol`.
        """
        return self._set(tol=value)

    @since("3.0.0")
    def setSolver(self, value: str) -> "FMRegressor":
        """
        Sets the value of :py:attr:`solver`.
        """
        return self._set(solver=value)

    @since("3.0.0")
    def setSeed(self, value: int) -> "FMRegressor":
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    @since("3.0.0")
    def setFitIntercept(self, value: bool) -> "FMRegressor":
        """
        Sets the value of :py:attr:`fitIntercept`.
        """
        return self._set(fitIntercept=value)

    @since("3.0.0")
    def setRegParam(self, value: float) -> "FMRegressor":
        """
        Sets the value of :py:attr:`regParam`.
        """
        return self._set(regParam=value)
3269
+
3270
+
3271
class FMRegressionModel(
    _JavaRegressionModel,
    _FactorizationMachinesParams,
    JavaMLWritable,
    JavaMLReadable["FMRegressionModel"],
):
    """
    Model fitted by :class:`FMRegressor`.

    The properties below are read-only views over the wrapped JVM model.

    .. versionadded:: 3.0.0
    """

    @property
    @since("3.0.0")
    def intercept(self) -> float:
        """
        Model intercept.
        """
        return self._call_java("intercept")

    @property
    @since("3.0.0")
    def linear(self) -> Vector:
        """
        Model linear term.
        """
        return self._call_java("linear")

    @property
    @since("3.0.0")
    def factors(self) -> Matrix:
        """
        Model factor term.
        """
        return self._call_java("factors")
3306
+
3307
+
3308
if __name__ == "__main__":
    # Doctest harness: runs the examples embedded in this module's docstrings
    # against a local 2-core SparkSession and exits non-zero on any failure.
    import doctest
    import pyspark.ml.regression
    from pyspark.sql import SparkSession

    globs = pyspark.ml.regression.__dict__.copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    spark = SparkSession.builder.master("local[2]").appName("ml.regression tests").getOrCreate()
    sc = spark.sparkContext
    globs["sc"] = sc
    globs["spark"] = spark
    import tempfile

    # Scratch directory used by the save/load doctest examples.
    temp_path = tempfile.mkdtemp()
    globs["temp_path"] = temp_path
    # Pre-bind so the exit check below cannot raise NameError (masking the
    # original error) if doctest.testmod itself throws.
    failure_count = 0
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Stop the session and remove the scratch directory even when the
        # doctest run raises; previously spark.stop() was skipped on error.
        spark.stop()
        from shutil import rmtree

        try:
            rmtree(temp_path)
        except OSError:
            pass
    if failure_count:
        sys.exit(-1)