snowpark-connect 0.20.2 (py3-none-any.whl)

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of snowpark-connect might be problematic; details are available in the registry's advisory for this release.
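The file listing below shows that this wheel vendors a complete PySpark 3.5 client under includes/python/pyspark/ and Spark 3.5.6 jars under includes/jars/, alongside plan-mapping modules (map_execution_root.py, map_unresolved_function.py, and so on) that translate Spark Connect plans for execution on Snowpark. As a minimal sketch of the client side only, this is how a stock PySpark 3.5 installation reaches any Spark Connect endpoint; the endpoint address is a hypothetical placeholder, not something this diff specifies:

    # Sketch: a standard PySpark 3.5 Spark Connect client session.
    # The sc://localhost:15002 address is a hypothetical placeholder;
    # this diff does not say where snowpark-connect serves plans.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    # The client builds an unresolved plan and ships it over gRPC; the
    # server side (here, the snowpark-connect layer) resolves and runs it.
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])
    df.show()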

Files changed (879)
  1. snowflake/snowpark_connect/__init__.py +23 -0
  2. snowflake/snowpark_connect/analyze_plan/__init__.py +3 -0
  3. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +38 -0
  4. snowflake/snowpark_connect/column_name_handler.py +735 -0
  5. snowflake/snowpark_connect/config.py +576 -0
  6. snowflake/snowpark_connect/constants.py +47 -0
  7. snowflake/snowpark_connect/control_server.py +52 -0
  8. snowflake/snowpark_connect/dataframe_name_handler.py +54 -0
  9. snowflake/snowpark_connect/date_time_format_mapping.py +399 -0
  10. snowflake/snowpark_connect/empty_dataframe.py +18 -0
  11. snowflake/snowpark_connect/error/__init__.py +11 -0
  12. snowflake/snowpark_connect/error/error_mapping.py +6174 -0
  13. snowflake/snowpark_connect/error/error_utils.py +321 -0
  14. snowflake/snowpark_connect/error/exceptions.py +24 -0
  15. snowflake/snowpark_connect/execute_plan/__init__.py +3 -0
  16. snowflake/snowpark_connect/execute_plan/map_execution_command.py +204 -0
  17. snowflake/snowpark_connect/execute_plan/map_execution_root.py +173 -0
  18. snowflake/snowpark_connect/execute_plan/utils.py +183 -0
  19. snowflake/snowpark_connect/expression/__init__.py +3 -0
  20. snowflake/snowpark_connect/expression/literal.py +90 -0
  21. snowflake/snowpark_connect/expression/map_cast.py +343 -0
  22. snowflake/snowpark_connect/expression/map_expression.py +293 -0
  23. snowflake/snowpark_connect/expression/map_extension.py +104 -0
  24. snowflake/snowpark_connect/expression/map_sql_expression.py +633 -0
  25. snowflake/snowpark_connect/expression/map_udf.py +142 -0
  26. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +241 -0
  27. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +85 -0
  28. snowflake/snowpark_connect/expression/map_unresolved_function.py +9450 -0
  29. snowflake/snowpark_connect/expression/map_unresolved_star.py +218 -0
  30. snowflake/snowpark_connect/expression/map_update_fields.py +164 -0
  31. snowflake/snowpark_connect/expression/map_window_function.py +258 -0
  32. snowflake/snowpark_connect/expression/typer.py +125 -0
  33. snowflake/snowpark_connect/includes/__init__.py +0 -0
  34. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  35. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  36. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  37. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  39. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/hadoop-client-api-3.3.4.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/scala-compiler-2.12.18.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-kubernetes_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-mllib_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-streaming_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/python/__init__.py +21 -0
  99. snowflake/snowpark_connect/includes/python/pyspark/__init__.py +173 -0
  100. snowflake/snowpark_connect/includes/python/pyspark/_globals.py +71 -0
  101. snowflake/snowpark_connect/includes/python/pyspark/_typing.pyi +43 -0
  102. snowflake/snowpark_connect/includes/python/pyspark/accumulators.py +341 -0
  103. snowflake/snowpark_connect/includes/python/pyspark/broadcast.py +383 -0
  104. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/__init__.py +8 -0
  105. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle.py +948 -0
  106. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle_fast.py +844 -0
  107. snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/compat.py +18 -0
  108. snowflake/snowpark_connect/includes/python/pyspark/conf.py +276 -0
  109. snowflake/snowpark_connect/includes/python/pyspark/context.py +2601 -0
  110. snowflake/snowpark_connect/includes/python/pyspark/daemon.py +218 -0
  111. snowflake/snowpark_connect/includes/python/pyspark/errors/__init__.py +70 -0
  112. snowflake/snowpark_connect/includes/python/pyspark/errors/error_classes.py +889 -0
  113. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/__init__.py +16 -0
  114. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/base.py +228 -0
  115. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/captured.py +307 -0
  116. snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/connect.py +190 -0
  117. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/__init__.py +16 -0
  118. snowflake/snowpark_connect/includes/python/pyspark/errors/tests/test_errors.py +60 -0
  119. snowflake/snowpark_connect/includes/python/pyspark/errors/utils.py +116 -0
  120. snowflake/snowpark_connect/includes/python/pyspark/files.py +165 -0
  121. snowflake/snowpark_connect/includes/python/pyspark/find_spark_home.py +95 -0
  122. snowflake/snowpark_connect/includes/python/pyspark/install.py +203 -0
  123. snowflake/snowpark_connect/includes/python/pyspark/instrumentation_utils.py +190 -0
  124. snowflake/snowpark_connect/includes/python/pyspark/java_gateway.py +248 -0
  125. snowflake/snowpark_connect/includes/python/pyspark/join.py +118 -0
  126. snowflake/snowpark_connect/includes/python/pyspark/ml/__init__.py +71 -0
  127. snowflake/snowpark_connect/includes/python/pyspark/ml/_typing.pyi +84 -0
  128. snowflake/snowpark_connect/includes/python/pyspark/ml/base.py +414 -0
  129. snowflake/snowpark_connect/includes/python/pyspark/ml/classification.py +4332 -0
  130. snowflake/snowpark_connect/includes/python/pyspark/ml/clustering.py +2188 -0
  131. snowflake/snowpark_connect/includes/python/pyspark/ml/common.py +146 -0
  132. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/__init__.py +44 -0
  133. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/base.py +346 -0
  134. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/classification.py +382 -0
  135. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/evaluation.py +291 -0
  136. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/feature.py +258 -0
  137. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/functions.py +77 -0
  138. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/io_utils.py +335 -0
  139. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/pipeline.py +262 -0
  140. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/summarizer.py +120 -0
  141. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/tuning.py +579 -0
  142. snowflake/snowpark_connect/includes/python/pyspark/ml/connect/util.py +173 -0
  143. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/__init__.py +16 -0
  144. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/deepspeed_distributor.py +165 -0
  145. snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py +306 -0
  146. snowflake/snowpark_connect/includes/python/pyspark/ml/dl_util.py +150 -0
  147. snowflake/snowpark_connect/includes/python/pyspark/ml/evaluation.py +1166 -0
  148. snowflake/snowpark_connect/includes/python/pyspark/ml/feature.py +7474 -0
  149. snowflake/snowpark_connect/includes/python/pyspark/ml/fpm.py +543 -0
  150. snowflake/snowpark_connect/includes/python/pyspark/ml/functions.py +842 -0
  151. snowflake/snowpark_connect/includes/python/pyspark/ml/image.py +271 -0
  152. snowflake/snowpark_connect/includes/python/pyspark/ml/linalg/__init__.py +1382 -0
  153. snowflake/snowpark_connect/includes/python/pyspark/ml/model_cache.py +55 -0
  154. snowflake/snowpark_connect/includes/python/pyspark/ml/param/__init__.py +602 -0
  155. snowflake/snowpark_connect/includes/python/pyspark/ml/param/_shared_params_code_gen.py +368 -0
  156. snowflake/snowpark_connect/includes/python/pyspark/ml/param/shared.py +878 -0
  157. snowflake/snowpark_connect/includes/python/pyspark/ml/pipeline.py +451 -0
  158. snowflake/snowpark_connect/includes/python/pyspark/ml/recommendation.py +748 -0
  159. snowflake/snowpark_connect/includes/python/pyspark/ml/regression.py +3335 -0
  160. snowflake/snowpark_connect/includes/python/pyspark/ml/stat.py +523 -0
  161. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/__init__.py +16 -0
  162. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_classification.py +53 -0
  163. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_evaluation.py +50 -0
  164. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_feature.py +43 -0
  165. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_function.py +114 -0
  166. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_pipeline.py +47 -0
  167. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_summarizer.py +43 -0
  168. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_tuning.py +46 -0
  169. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py +238 -0
  170. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py +194 -0
  171. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +156 -0
  172. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py +184 -0
  173. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py +78 -0
  174. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +292 -0
  175. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +50 -0
  176. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +152 -0
  177. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_algorithms.py +456 -0
  178. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_base.py +96 -0
  179. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_dl_util.py +186 -0
  180. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_evaluation.py +77 -0
  181. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_feature.py +401 -0
  182. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_functions.py +528 -0
  183. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_image.py +82 -0
  184. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_linalg.py +409 -0
  185. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_model_cache.py +55 -0
  186. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_param.py +441 -0
  187. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_persistence.py +546 -0
  188. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_pipeline.py +71 -0
  189. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_stat.py +52 -0
  190. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_training_summary.py +494 -0
  191. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_util.py +85 -0
  192. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_wrapper.py +138 -0
  193. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/__init__.py +16 -0
  194. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_basic.py +151 -0
  195. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_nested.py +97 -0
  196. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py +143 -0
  197. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tuning.py +551 -0
  198. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_basic.py +137 -0
  199. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_nested.py +96 -0
  200. snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_pipeline.py +142 -0
  201. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/__init__.py +16 -0
  202. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/data.py +100 -0
  203. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/distributor.py +1133 -0
  204. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/log_communication.py +198 -0
  205. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/__init__.py +16 -0
  206. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_data_loader.py +137 -0
  207. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_distributor.py +561 -0
  208. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_log_communication.py +172 -0
  209. snowflake/snowpark_connect/includes/python/pyspark/ml/torch/torch_run_process_wrapper.py +83 -0
  210. snowflake/snowpark_connect/includes/python/pyspark/ml/tree.py +434 -0
  211. snowflake/snowpark_connect/includes/python/pyspark/ml/tuning.py +1741 -0
  212. snowflake/snowpark_connect/includes/python/pyspark/ml/util.py +749 -0
  213. snowflake/snowpark_connect/includes/python/pyspark/ml/wrapper.py +465 -0
  214. snowflake/snowpark_connect/includes/python/pyspark/mllib/__init__.py +44 -0
  215. snowflake/snowpark_connect/includes/python/pyspark/mllib/_typing.pyi +33 -0
  216. snowflake/snowpark_connect/includes/python/pyspark/mllib/classification.py +989 -0
  217. snowflake/snowpark_connect/includes/python/pyspark/mllib/clustering.py +1318 -0
  218. snowflake/snowpark_connect/includes/python/pyspark/mllib/common.py +174 -0
  219. snowflake/snowpark_connect/includes/python/pyspark/mllib/evaluation.py +691 -0
  220. snowflake/snowpark_connect/includes/python/pyspark/mllib/feature.py +1085 -0
  221. snowflake/snowpark_connect/includes/python/pyspark/mllib/fpm.py +233 -0
  222. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/__init__.py +1653 -0
  223. snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/distributed.py +1662 -0
  224. snowflake/snowpark_connect/includes/python/pyspark/mllib/random.py +698 -0
  225. snowflake/snowpark_connect/includes/python/pyspark/mllib/recommendation.py +389 -0
  226. snowflake/snowpark_connect/includes/python/pyspark/mllib/regression.py +1067 -0
  227. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/KernelDensity.py +59 -0
  228. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/__init__.py +34 -0
  229. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/_statistics.py +409 -0
  230. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/distribution.py +39 -0
  231. snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/test.py +86 -0
  232. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/__init__.py +16 -0
  233. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_algorithms.py +353 -0
  234. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_feature.py +192 -0
  235. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_linalg.py +680 -0
  236. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_stat.py +206 -0
  237. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_streaming_algorithms.py +471 -0
  238. snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_util.py +108 -0
  239. snowflake/snowpark_connect/includes/python/pyspark/mllib/tree.py +888 -0
  240. snowflake/snowpark_connect/includes/python/pyspark/mllib/util.py +659 -0
  241. snowflake/snowpark_connect/includes/python/pyspark/pandas/__init__.py +165 -0
  242. snowflake/snowpark_connect/includes/python/pyspark/pandas/_typing.py +52 -0
  243. snowflake/snowpark_connect/includes/python/pyspark/pandas/accessors.py +989 -0
  244. snowflake/snowpark_connect/includes/python/pyspark/pandas/base.py +1804 -0
  245. snowflake/snowpark_connect/includes/python/pyspark/pandas/categorical.py +822 -0
  246. snowflake/snowpark_connect/includes/python/pyspark/pandas/config.py +539 -0
  247. snowflake/snowpark_connect/includes/python/pyspark/pandas/correlation.py +262 -0
  248. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/__init__.py +16 -0
  249. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/base.py +519 -0
  250. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/binary_ops.py +98 -0
  251. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/boolean_ops.py +426 -0
  252. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/categorical_ops.py +141 -0
  253. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/complex_ops.py +145 -0
  254. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/date_ops.py +127 -0
  255. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/datetime_ops.py +171 -0
  256. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/null_ops.py +83 -0
  257. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/num_ops.py +588 -0
  258. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/string_ops.py +154 -0
  259. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/timedelta_ops.py +101 -0
  260. snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/udt_ops.py +29 -0
  261. snowflake/snowpark_connect/includes/python/pyspark/pandas/datetimes.py +891 -0
  262. snowflake/snowpark_connect/includes/python/pyspark/pandas/exceptions.py +150 -0
  263. snowflake/snowpark_connect/includes/python/pyspark/pandas/extensions.py +388 -0
  264. snowflake/snowpark_connect/includes/python/pyspark/pandas/frame.py +13738 -0
  265. snowflake/snowpark_connect/includes/python/pyspark/pandas/generic.py +3560 -0
  266. snowflake/snowpark_connect/includes/python/pyspark/pandas/groupby.py +4448 -0
  267. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/__init__.py +21 -0
  268. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/base.py +2783 -0
  269. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/category.py +773 -0
  270. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/datetimes.py +843 -0
  271. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/multi.py +1323 -0
  272. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/numeric.py +210 -0
  273. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/timedelta.py +197 -0
  274. snowflake/snowpark_connect/includes/python/pyspark/pandas/indexing.py +1862 -0
  275. snowflake/snowpark_connect/includes/python/pyspark/pandas/internal.py +1680 -0
  276. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/__init__.py +48 -0
  277. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/common.py +76 -0
  278. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/frame.py +63 -0
  279. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/general_functions.py +43 -0
  280. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/groupby.py +93 -0
  281. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/indexes.py +184 -0
  282. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/resample.py +101 -0
  283. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/scalars.py +29 -0
  284. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/series.py +69 -0
  285. snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/window.py +168 -0
  286. snowflake/snowpark_connect/includes/python/pyspark/pandas/mlflow.py +238 -0
  287. snowflake/snowpark_connect/includes/python/pyspark/pandas/namespace.py +3807 -0
  288. snowflake/snowpark_connect/includes/python/pyspark/pandas/numpy_compat.py +260 -0
  289. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/__init__.py +17 -0
  290. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/core.py +1213 -0
  291. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/matplotlib.py +928 -0
  292. snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/plotly.py +261 -0
  293. snowflake/snowpark_connect/includes/python/pyspark/pandas/resample.py +816 -0
  294. snowflake/snowpark_connect/includes/python/pyspark/pandas/series.py +7440 -0
  295. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_formatter.py +308 -0
  296. snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_processor.py +394 -0
  297. snowflake/snowpark_connect/includes/python/pyspark/pandas/strings.py +2371 -0
  298. snowflake/snowpark_connect/includes/python/pyspark/pandas/supported_api_gen.py +378 -0
  299. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/__init__.py +16 -0
  300. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/__init__.py +16 -0
  301. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_any_all.py +177 -0
  302. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_apply_func.py +575 -0
  303. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_binary_ops.py +235 -0
  304. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_combine.py +653 -0
  305. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_compute.py +463 -0
  306. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_corrwith.py +86 -0
  307. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cov.py +151 -0
  308. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cumulative.py +139 -0
  309. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_describe.py +458 -0
  310. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_eval.py +86 -0
  311. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_melt.py +202 -0
  312. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_missing_data.py +520 -0
  313. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_pivot.py +361 -0
  314. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/__init__.py +16 -0
  315. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/__init__.py +16 -0
  316. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_any_all.py +40 -0
  317. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_apply_func.py +42 -0
  318. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_binary_ops.py +40 -0
  319. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_combine.py +37 -0
  320. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py +60 -0
  321. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_corrwith.py +40 -0
  322. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cov.py +40 -0
  323. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cumulative.py +90 -0
  324. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_describe.py +40 -0
  325. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_eval.py +40 -0
  326. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_melt.py +40 -0
  327. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_missing_data.py +42 -0
  328. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +37 -0
  329. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/__init__.py +16 -0
  330. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_base.py +36 -0
  331. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py +42 -0
  332. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py +47 -0
  333. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py +55 -0
  334. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_complex_ops.py +40 -0
  335. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py +47 -0
  336. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py +47 -0
  337. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py +42 -0
  338. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_arithmetic.py +43 -0
  339. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py +47 -0
  340. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_reverse.py +43 -0
  341. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py +47 -0
  342. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py +47 -0
  343. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py +40 -0
  344. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/testing_utils.py +226 -0
  345. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/__init__.py +16 -0
  346. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_align.py +39 -0
  347. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_basic_slow.py +55 -0
  348. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_cov_corrwith.py +39 -0
  349. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_frame.py +39 -0
  350. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_series.py +39 -0
  351. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_index.py +39 -0
  352. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_series.py +39 -0
  353. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_frame.py +43 -0
  354. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_series.py +43 -0
  355. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/__init__.py +16 -0
  356. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_attrs.py +40 -0
  357. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py +39 -0
  358. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_conversion.py +42 -0
  359. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reindexing.py +42 -0
  360. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +37 -0
  361. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_spark.py +40 -0
  362. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_take.py +42 -0
  363. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_time_series.py +48 -0
  364. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_truncate.py +40 -0
  365. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/__init__.py +16 -0
  366. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_aggregate.py +40 -0
  367. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_apply_func.py +41 -0
  368. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_cumulative.py +67 -0
  369. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_describe.py +40 -0
  370. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_groupby.py +55 -0
  371. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_head_tail.py +40 -0
  372. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_index.py +38 -0
  373. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py +55 -0
  374. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_split_apply.py +39 -0
  375. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +38 -0
  376. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/__init__.py +16 -0
  377. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_align.py +40 -0
  378. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py +50 -0
  379. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py +73 -0
  380. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_datetime.py +39 -0
  381. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing.py +40 -0
  382. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reindex.py +40 -0
  383. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_rename.py +40 -0
  384. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reset_index.py +48 -0
  385. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_timedelta.py +39 -0
  386. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/__init__.py +16 -0
  387. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/test_parity_io.py +40 -0
  388. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/__init__.py +16 -0
  389. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot.py +45 -0
  390. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_matplotlib.py +45 -0
  391. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_plotly.py +49 -0
  392. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot.py +37 -0
  393. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_matplotlib.py +53 -0
  394. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_plotly.py +45 -0
  395. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/__init__.py +16 -0
  396. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_all_any.py +38 -0
  397. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_arg_ops.py +37 -0
  398. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_of.py +37 -0
  399. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_type.py +38 -0
  400. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +37 -0
  401. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_conversion.py +40 -0
  402. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_cumulative.py +40 -0
  403. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_index.py +38 -0
  404. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_missing_data.py +40 -0
  405. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_series.py +37 -0
  406. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_sort.py +38 -0
  407. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_stat.py +38 -0
  408. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_categorical.py +66 -0
  409. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_config.py +37 -0
  410. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_csv.py +37 -0
  411. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_conversion.py +42 -0
  412. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_spark_io.py +39 -0
  413. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_default_index.py +49 -0
  414. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ewm.py +37 -0
  415. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_expanding.py +39 -0
  416. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_extension.py +49 -0
  417. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_frame_spark.py +53 -0
  418. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py +43 -0
  419. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexing.py +49 -0
  420. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexops_spark.py +39 -0
  421. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_internal.py +41 -0
  422. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_namespace.py +39 -0
  423. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_numpy_compat.py +60 -0
  424. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames.py +48 -0
  425. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby.py +39 -0
  426. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_expanding.py +44 -0
  427. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_rolling.py +84 -0
  428. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_repr.py +37 -0
  429. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_resample.py +45 -0
  430. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_reshape.py +39 -0
  431. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_rolling.py +39 -0
  432. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_scalars.py +37 -0
  433. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_conversion.py +39 -0
  434. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_datetime.py +39 -0
  435. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_string.py +39 -0
  436. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_spark_functions.py +39 -0
  437. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_sql.py +43 -0
  438. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_stats.py +37 -0
  439. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_typedef.py +36 -0
  440. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_utils.py +37 -0
  441. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_window.py +39 -0
  442. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/__init__.py +16 -0
  443. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_base.py +107 -0
  444. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +224 -0
  445. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +825 -0
  446. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +562 -0
  447. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +368 -0
  448. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +257 -0
  449. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +260 -0
  450. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +178 -0
  451. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_arithmetic.py +184 -0
  452. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +497 -0
  453. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_reverse.py +140 -0
  454. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +354 -0
  455. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +219 -0
  456. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +192 -0
  457. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +228 -0
  458. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/__init__.py +16 -0
  459. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_align.py +118 -0
  460. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_basic_slow.py +198 -0
  461. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_cov_corrwith.py +181 -0
  462. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_frame.py +103 -0
  463. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_series.py +141 -0
  464. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_index.py +109 -0
  465. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_series.py +136 -0
  466. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_frame.py +125 -0
  467. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_series.py +217 -0
  468. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/__init__.py +16 -0
  469. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_attrs.py +384 -0
  470. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_constructor.py +598 -0
  471. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_conversion.py +73 -0
  472. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reindexing.py +869 -0
  473. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reshaping.py +487 -0
  474. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_spark.py +309 -0
  475. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_take.py +156 -0
  476. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_time_series.py +149 -0
  477. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_truncate.py +163 -0
  478. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/__init__.py +16 -0
  479. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_aggregate.py +311 -0
  480. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_apply_func.py +524 -0
  481. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_cumulative.py +419 -0
  482. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_describe.py +144 -0
  483. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_groupby.py +979 -0
  484. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_head_tail.py +234 -0
  485. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_index.py +206 -0
  486. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_missing_data.py +421 -0
  487. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_split_apply.py +187 -0
  488. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_stat.py +397 -0
  489. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/__init__.py +16 -0
  490. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_align.py +100 -0
  491. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_base.py +2743 -0
  492. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_category.py +484 -0
  493. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_datetime.py +276 -0
  494. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_indexing.py +432 -0
  495. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reindex.py +310 -0
  496. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_rename.py +257 -0
  497. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reset_index.py +160 -0
  498. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_timedelta.py +128 -0
  499. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/__init__.py +16 -0
  500. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/test_io.py +137 -0
  501. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/__init__.py +16 -0
  502. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot.py +170 -0
  503. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +547 -0
  504. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +285 -0
  505. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot.py +106 -0
  506. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +409 -0
  507. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +247 -0
  508. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/__init__.py +16 -0
  509. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_all_any.py +105 -0
  510. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_arg_ops.py +197 -0
  511. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_of.py +137 -0
  512. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_type.py +227 -0
  513. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_compute.py +634 -0
  514. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_conversion.py +88 -0
  515. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_cumulative.py +139 -0
  516. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_index.py +475 -0
  517. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_missing_data.py +265 -0
  518. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_series.py +818 -0
  519. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_sort.py +162 -0
  520. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_stat.py +780 -0
  521. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_categorical.py +741 -0
  522. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_config.py +160 -0
  523. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_csv.py +453 -0
  524. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_conversion.py +281 -0
  525. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_spark_io.py +487 -0
  526. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_default_index.py +109 -0
  527. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ewm.py +434 -0
  528. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_expanding.py +253 -0
  529. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_extension.py +152 -0
  530. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_frame_spark.py +162 -0
  531. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_generic_functions.py +234 -0
  532. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexing.py +1339 -0
  533. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexops_spark.py +82 -0
  534. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_internal.py +124 -0
  535. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_namespace.py +638 -0
  536. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_numpy_compat.py +200 -0
  537. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +1355 -0
  538. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +655 -0
  539. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +113 -0
  540. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +118 -0
  541. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_repr.py +192 -0
  542. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_resample.py +346 -0
  543. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_reshape.py +495 -0
  544. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_rolling.py +263 -0
  545. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_scalars.py +59 -0
  546. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_conversion.py +85 -0
  547. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_datetime.py +364 -0
  548. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_string.py +362 -0
  549. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_spark_functions.py +46 -0
  550. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_sql.py +123 -0
  551. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_stats.py +581 -0
  552. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_typedef.py +447 -0
  553. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_utils.py +301 -0
  554. snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_window.py +465 -0
  555. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/__init__.py +18 -0
  556. snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/typehints.py +874 -0
  557. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/__init__.py +143 -0
  558. snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/usage_logger.py +132 -0
  559. snowflake/snowpark_connect/includes/python/pyspark/pandas/utils.py +1063 -0
  560. snowflake/snowpark_connect/includes/python/pyspark/pandas/window.py +2702 -0
  561. snowflake/snowpark_connect/includes/python/pyspark/profiler.py +489 -0
  562. snowflake/snowpark_connect/includes/python/pyspark/py.typed +1 -0
  563. snowflake/snowpark_connect/includes/python/pyspark/python/pyspark/shell.py +123 -0
  564. snowflake/snowpark_connect/includes/python/pyspark/rdd.py +5518 -0
  565. snowflake/snowpark_connect/includes/python/pyspark/rddsampler.py +115 -0
  566. snowflake/snowpark_connect/includes/python/pyspark/resource/__init__.py +38 -0
  567. snowflake/snowpark_connect/includes/python/pyspark/resource/information.py +69 -0
  568. snowflake/snowpark_connect/includes/python/pyspark/resource/profile.py +317 -0
  569. snowflake/snowpark_connect/includes/python/pyspark/resource/requests.py +539 -0
  570. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/__init__.py +16 -0
  571. snowflake/snowpark_connect/includes/python/pyspark/resource/tests/test_resources.py +83 -0
  572. snowflake/snowpark_connect/includes/python/pyspark/resultiterable.py +45 -0
  573. snowflake/snowpark_connect/includes/python/pyspark/serializers.py +681 -0
  574. snowflake/snowpark_connect/includes/python/pyspark/shell.py +123 -0
  575. snowflake/snowpark_connect/includes/python/pyspark/shuffle.py +854 -0
  576. snowflake/snowpark_connect/includes/python/pyspark/sql/__init__.py +75 -0
  577. snowflake/snowpark_connect/includes/python/pyspark/sql/_typing.pyi +80 -0
  578. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/__init__.py +18 -0
  579. snowflake/snowpark_connect/includes/python/pyspark/sql/avro/functions.py +188 -0
  580. snowflake/snowpark_connect/includes/python/pyspark/sql/catalog.py +1270 -0
  581. snowflake/snowpark_connect/includes/python/pyspark/sql/column.py +1431 -0
  582. snowflake/snowpark_connect/includes/python/pyspark/sql/conf.py +99 -0
  583. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/__init__.py +18 -0
  584. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/_typing.py +90 -0
  585. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/__init__.py +18 -0
  586. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/functions.py +107 -0
  587. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/catalog.py +356 -0
  588. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/__init__.py +22 -0
  589. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/artifact.py +412 -0
  590. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/core.py +1689 -0
  591. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/reattach.py +340 -0
  592. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/column.py +514 -0
  593. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conf.py +128 -0
  594. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conversion.py +490 -0
  595. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/dataframe.py +2172 -0
  596. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/expressions.py +1056 -0
  597. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/functions.py +3937 -0
  598. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/group.py +418 -0
  599. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/plan.py +2289 -0
  600. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/__init__.py +25 -0
  601. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.py +203 -0
  602. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.pyi +2718 -0
  603. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2_grpc.py +423 -0
  604. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.py +109 -0
  605. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.pyi +1130 -0
  606. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.py +141 -0
  607. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.pyi +1766 -0
  608. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.py +47 -0
  609. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.pyi +123 -0
  610. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.py +53 -0
  611. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.pyi +112 -0
  612. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.py +107 -0
  613. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.pyi +1507 -0
  614. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.py +195 -0
  615. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.pyi +3613 -0
  616. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.py +95 -0
  617. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.pyi +980 -0
  618. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/__init__.py +18 -0
  619. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/functions.py +166 -0
  620. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/readwriter.py +861 -0
  621. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/session.py +952 -0
  622. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/__init__.py +22 -0
  623. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/query.py +295 -0
  624. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/readwriter.py +618 -0
  625. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/__init__.py +18 -0
  626. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +87 -0
  627. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +100 -0
  628. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/types.py +301 -0
  629. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udf.py +296 -0
  630. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udtf.py +200 -0
  631. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/utils.py +58 -0
  632. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/window.py +266 -0
  633. snowflake/snowpark_connect/includes/python/pyspark/sql/context.py +818 -0
  634. snowflake/snowpark_connect/includes/python/pyspark/sql/dataframe.py +5973 -0
  635. snowflake/snowpark_connect/includes/python/pyspark/sql/functions.py +15889 -0
  636. snowflake/snowpark_connect/includes/python/pyspark/sql/group.py +547 -0
  637. snowflake/snowpark_connect/includes/python/pyspark/sql/observation.py +152 -0
  638. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/__init__.py +21 -0
  639. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/__init__.pyi +344 -0
  640. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/__init__.pyi +17 -0
  641. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/frame.pyi +20 -0
  642. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/series.pyi +20 -0
  643. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/conversion.py +671 -0
  644. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.py +480 -0
  645. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.pyi +132 -0
  646. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/group_ops.py +523 -0
  647. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/map_ops.py +216 -0
  648. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/serializers.py +1019 -0
  649. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/typehints.py +172 -0
  650. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/types.py +972 -0
  651. snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/utils.py +86 -0
  652. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/__init__.py +18 -0
  653. snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/functions.py +334 -0
  654. snowflake/snowpark_connect/includes/python/pyspark/sql/readwriter.py +2159 -0
  655. snowflake/snowpark_connect/includes/python/pyspark/sql/session.py +2088 -0
  656. snowflake/snowpark_connect/includes/python/pyspark/sql/sql_formatter.py +84 -0
  657. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/__init__.py +21 -0
  658. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/listener.py +1050 -0
  659. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/query.py +746 -0
  660. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/readwriter.py +1652 -0
  661. snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/state.py +288 -0
  662. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/__init__.py +16 -0
  663. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/__init__.py +16 -0
  664. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/__init__.py +16 -0
  665. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_artifact.py +420 -0
  666. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_client.py +358 -0
  667. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/__init__.py +16 -0
  668. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach.py +36 -0
  669. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +44 -0
  670. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +116 -0
  671. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +35 -0
  672. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py +3612 -0
  673. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_column.py +1042 -0
  674. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_function.py +2381 -0
  675. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_plan.py +1060 -0
  676. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow.py +163 -0
  677. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_map.py +38 -0
  678. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +48 -0
  679. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_catalog.py +36 -0
  680. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_column.py +55 -0
  681. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_conf.py +36 -0
  682. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_dataframe.py +96 -0
  683. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_datasources.py +44 -0
  684. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_errors.py +36 -0
  685. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_functions.py +59 -0
  686. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_group.py +36 -0
  687. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py +59 -0
  688. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py +74 -0
  689. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py +62 -0
  690. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_map.py +58 -0
  691. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +70 -0
  692. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py +50 -0
  693. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +68 -0
  694. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py +40 -0
  695. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_readwriter.py +46 -0
  696. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_serde.py +44 -0
  697. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_types.py +100 -0
  698. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udf.py +100 -0
  699. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udtf.py +163 -0
  700. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_session.py +181 -0
  701. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_utils.py +42 -0
  702. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/__init__.py +16 -0
  703. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +623 -0
  704. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +869 -0
  705. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +342 -0
  706. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_map.py +436 -0
  707. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf.py +363 -0
  708. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +592 -0
  709. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +1503 -0
  710. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +392 -0
  711. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py +375 -0
  712. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +411 -0
  713. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/__init__.py +16 -0
  714. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming.py +401 -0
  715. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +295 -0
  716. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +106 -0
  717. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_listener.py +558 -0
  718. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow.py +1346 -0
  719. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_map.py +182 -0
  720. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_python_udf.py +202 -0
  721. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_catalog.py +503 -0
  722. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_column.py +225 -0
  723. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_conf.py +83 -0
  724. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_context.py +201 -0
  725. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_dataframe.py +1931 -0
  726. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_datasources.py +256 -0
  727. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_errors.py +69 -0
  728. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_functions.py +1349 -0
  729. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_group.py +53 -0
  730. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_pandas_sqlmetrics.py +68 -0
  731. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_readwriter.py +283 -0
  732. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_serde.py +155 -0
  733. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_session.py +412 -0
  734. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_types.py +1581 -0
  735. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf.py +961 -0
  736. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf_profiler.py +165 -0
  737. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udtf.py +1456 -0
  738. snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_utils.py +1686 -0
  739. snowflake/snowpark_connect/includes/python/pyspark/sql/types.py +2558 -0
  740. snowflake/snowpark_connect/includes/python/pyspark/sql/udf.py +714 -0
  741. snowflake/snowpark_connect/includes/python/pyspark/sql/udtf.py +325 -0
  742. snowflake/snowpark_connect/includes/python/pyspark/sql/utils.py +339 -0
  743. snowflake/snowpark_connect/includes/python/pyspark/sql/window.py +492 -0
  744. snowflake/snowpark_connect/includes/python/pyspark/statcounter.py +165 -0
  745. snowflake/snowpark_connect/includes/python/pyspark/status.py +112 -0
  746. snowflake/snowpark_connect/includes/python/pyspark/storagelevel.py +97 -0
  747. snowflake/snowpark_connect/includes/python/pyspark/streaming/__init__.py +22 -0
  748. snowflake/snowpark_connect/includes/python/pyspark/streaming/context.py +471 -0
  749. snowflake/snowpark_connect/includes/python/pyspark/streaming/dstream.py +933 -0
  750. snowflake/snowpark_connect/includes/python/pyspark/streaming/kinesis.py +205 -0
  751. snowflake/snowpark_connect/includes/python/pyspark/streaming/listener.py +83 -0
  752. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/__init__.py +16 -0
  753. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_context.py +184 -0
  754. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_dstream.py +706 -0
  755. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_kinesis.py +118 -0
  756. snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_listener.py +160 -0
  757. snowflake/snowpark_connect/includes/python/pyspark/streaming/util.py +168 -0
  758. snowflake/snowpark_connect/includes/python/pyspark/taskcontext.py +502 -0
  759. snowflake/snowpark_connect/includes/python/pyspark/testing/__init__.py +21 -0
  760. snowflake/snowpark_connect/includes/python/pyspark/testing/connectutils.py +199 -0
  761. snowflake/snowpark_connect/includes/python/pyspark/testing/mllibutils.py +30 -0
  762. snowflake/snowpark_connect/includes/python/pyspark/testing/mlutils.py +275 -0
  763. snowflake/snowpark_connect/includes/python/pyspark/testing/objects.py +121 -0
  764. snowflake/snowpark_connect/includes/python/pyspark/testing/pandasutils.py +714 -0
  765. snowflake/snowpark_connect/includes/python/pyspark/testing/sqlutils.py +168 -0
  766. snowflake/snowpark_connect/includes/python/pyspark/testing/streamingutils.py +178 -0
  767. snowflake/snowpark_connect/includes/python/pyspark/testing/utils.py +636 -0
  768. snowflake/snowpark_connect/includes/python/pyspark/tests/__init__.py +16 -0
  769. snowflake/snowpark_connect/includes/python/pyspark/tests/test_appsubmit.py +306 -0
  770. snowflake/snowpark_connect/includes/python/pyspark/tests/test_broadcast.py +196 -0
  771. snowflake/snowpark_connect/includes/python/pyspark/tests/test_conf.py +44 -0
  772. snowflake/snowpark_connect/includes/python/pyspark/tests/test_context.py +346 -0
  773. snowflake/snowpark_connect/includes/python/pyspark/tests/test_daemon.py +89 -0
  774. snowflake/snowpark_connect/includes/python/pyspark/tests/test_install_spark.py +124 -0
  775. snowflake/snowpark_connect/includes/python/pyspark/tests/test_join.py +69 -0
  776. snowflake/snowpark_connect/includes/python/pyspark/tests/test_memory_profiler.py +167 -0
  777. snowflake/snowpark_connect/includes/python/pyspark/tests/test_pin_thread.py +194 -0
  778. snowflake/snowpark_connect/includes/python/pyspark/tests/test_profiler.py +168 -0
  779. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rdd.py +939 -0
  780. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddbarrier.py +52 -0
  781. snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddsampler.py +66 -0
  782. snowflake/snowpark_connect/includes/python/pyspark/tests/test_readwrite.py +368 -0
  783. snowflake/snowpark_connect/includes/python/pyspark/tests/test_serializers.py +257 -0
  784. snowflake/snowpark_connect/includes/python/pyspark/tests/test_shuffle.py +267 -0
  785. snowflake/snowpark_connect/includes/python/pyspark/tests/test_stage_sched.py +153 -0
  786. snowflake/snowpark_connect/includes/python/pyspark/tests/test_statcounter.py +130 -0
  787. snowflake/snowpark_connect/includes/python/pyspark/tests/test_taskcontext.py +350 -0
  788. snowflake/snowpark_connect/includes/python/pyspark/tests/test_util.py +97 -0
  789. snowflake/snowpark_connect/includes/python/pyspark/tests/test_worker.py +271 -0
  790. snowflake/snowpark_connect/includes/python/pyspark/traceback_utils.py +81 -0
  791. snowflake/snowpark_connect/includes/python/pyspark/util.py +416 -0
  792. snowflake/snowpark_connect/includes/python/pyspark/version.py +19 -0
  793. snowflake/snowpark_connect/includes/python/pyspark/worker.py +1307 -0
  794. snowflake/snowpark_connect/includes/python/pyspark/worker_util.py +46 -0
  795. snowflake/snowpark_connect/proto/__init__.py +10 -0
  796. snowflake/snowpark_connect/proto/control_pb2.py +35 -0
  797. snowflake/snowpark_connect/proto/control_pb2.pyi +38 -0
  798. snowflake/snowpark_connect/proto/control_pb2_grpc.py +183 -0
  799. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +35 -0
  800. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +53 -0
  801. snowflake/snowpark_connect/proto/snowflake_rdd_pb2.pyi +39 -0
  802. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +47 -0
  803. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +111 -0
  804. snowflake/snowpark_connect/relation/__init__.py +3 -0
  805. snowflake/snowpark_connect/relation/catalogs/__init__.py +12 -0
  806. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +287 -0
  807. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +467 -0
  808. snowflake/snowpark_connect/relation/catalogs/utils.py +51 -0
  809. snowflake/snowpark_connect/relation/io_utils.py +76 -0
  810. snowflake/snowpark_connect/relation/map_aggregate.py +322 -0
  811. snowflake/snowpark_connect/relation/map_catalog.py +151 -0
  812. snowflake/snowpark_connect/relation/map_column_ops.py +1068 -0
  813. snowflake/snowpark_connect/relation/map_crosstab.py +48 -0
  814. snowflake/snowpark_connect/relation/map_extension.py +412 -0
  815. snowflake/snowpark_connect/relation/map_join.py +341 -0
  816. snowflake/snowpark_connect/relation/map_local_relation.py +326 -0
  817. snowflake/snowpark_connect/relation/map_map_partitions.py +146 -0
  818. snowflake/snowpark_connect/relation/map_relation.py +253 -0
  819. snowflake/snowpark_connect/relation/map_row_ops.py +716 -0
  820. snowflake/snowpark_connect/relation/map_sample_by.py +35 -0
  821. snowflake/snowpark_connect/relation/map_show_string.py +50 -0
  822. snowflake/snowpark_connect/relation/map_sql.py +1874 -0
  823. snowflake/snowpark_connect/relation/map_stats.py +324 -0
  824. snowflake/snowpark_connect/relation/map_subquery_alias.py +32 -0
  825. snowflake/snowpark_connect/relation/map_udtf.py +288 -0
  826. snowflake/snowpark_connect/relation/read/__init__.py +7 -0
  827. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +668 -0
  828. snowflake/snowpark_connect/relation/read/map_read.py +367 -0
  829. snowflake/snowpark_connect/relation/read/map_read_csv.py +142 -0
  830. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +108 -0
  831. snowflake/snowpark_connect/relation/read/map_read_json.py +344 -0
  832. snowflake/snowpark_connect/relation/read/map_read_parquet.py +194 -0
  833. snowflake/snowpark_connect/relation/read/map_read_socket.py +59 -0
  834. snowflake/snowpark_connect/relation/read/map_read_table.py +109 -0
  835. snowflake/snowpark_connect/relation/read/map_read_text.py +106 -0
  836. snowflake/snowpark_connect/relation/read/reader_config.py +399 -0
  837. snowflake/snowpark_connect/relation/read/utils.py +155 -0
  838. snowflake/snowpark_connect/relation/stage_locator.py +161 -0
  839. snowflake/snowpark_connect/relation/utils.py +219 -0
  840. snowflake/snowpark_connect/relation/write/__init__.py +3 -0
  841. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +339 -0
  842. snowflake/snowpark_connect/relation/write/map_write.py +436 -0
  843. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +48 -0
  844. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  845. snowflake/snowpark_connect/resources_initializer.py +75 -0
  846. snowflake/snowpark_connect/server.py +1136 -0
  847. snowflake/snowpark_connect/start_server.py +32 -0
  848. snowflake/snowpark_connect/tcm.py +8 -0
  849. snowflake/snowpark_connect/type_mapping.py +1003 -0
  850. snowflake/snowpark_connect/typed_column.py +94 -0
  851. snowflake/snowpark_connect/utils/__init__.py +3 -0
  852. snowflake/snowpark_connect/utils/artifacts.py +48 -0
  853. snowflake/snowpark_connect/utils/attribute_handling.py +72 -0
  854. snowflake/snowpark_connect/utils/cache.py +84 -0
  855. snowflake/snowpark_connect/utils/concurrent.py +124 -0
  856. snowflake/snowpark_connect/utils/context.py +390 -0
  857. snowflake/snowpark_connect/utils/describe_query_cache.py +231 -0
  858. snowflake/snowpark_connect/utils/interrupt.py +85 -0
  859. snowflake/snowpark_connect/utils/io_utils.py +35 -0
  860. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +117 -0
  861. snowflake/snowpark_connect/utils/profiling.py +47 -0
  862. snowflake/snowpark_connect/utils/session.py +180 -0
  863. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +38 -0
  864. snowflake/snowpark_connect/utils/telemetry.py +513 -0
  865. snowflake/snowpark_connect/utils/udf_cache.py +392 -0
  866. snowflake/snowpark_connect/utils/udf_helper.py +328 -0
  867. snowflake/snowpark_connect/utils/udf_utils.py +310 -0
  868. snowflake/snowpark_connect/utils/udtf_helper.py +420 -0
  869. snowflake/snowpark_connect/utils/udtf_utils.py +799 -0
  870. snowflake/snowpark_connect/utils/xxhash64.py +247 -0
  871. snowflake/snowpark_connect/version.py +6 -0
  872. snowpark_connect-0.20.2.data/scripts/snowpark-connect +71 -0
  873. snowpark_connect-0.20.2.data/scripts/snowpark-session +11 -0
  874. snowpark_connect-0.20.2.data/scripts/snowpark-submit +354 -0
  875. snowpark_connect-0.20.2.dist-info/METADATA +37 -0
  876. snowpark_connect-0.20.2.dist-info/RECORD +879 -0
  877. snowpark_connect-0.20.2.dist-info/WHEEL +5 -0
  878. snowpark_connect-0.20.2.dist-info/licenses/LICENSE.txt +202 -0
  879. snowpark_connect-0.20.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,4332 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one or more
+ # contributor license agreements. See the NOTICE file distributed with
+ # this work for additional information regarding copyright ownership.
+ # The ASF licenses this file to You under the Apache License, Version 2.0
+ # (the "License"); you may not use this file except in compliance with
+ # the License. You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import os
+ import operator
+ import sys
+ import uuid
+ import warnings
+ from abc import ABCMeta, abstractmethod
+ from multiprocessing.pool import ThreadPool
+
+ from typing import (
+     Any,
+     Dict,
+     Generic,
+     Iterable,
+     List,
+     Optional,
+     Type,
+     TypeVar,
+     Union,
+     cast,
+     overload,
+     TYPE_CHECKING,
+ )
+
+ from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
+ from pyspark.ml import Estimator, Predictor, PredictionModel, Model
+ from pyspark.ml.param.shared import (
+     HasRawPredictionCol,
+     HasProbabilityCol,
+     HasThresholds,
+     HasRegParam,
+     HasMaxIter,
+     HasFitIntercept,
+     HasTol,
+     HasStandardization,
+     HasWeightCol,
+     HasAggregationDepth,
+     HasThreshold,
+     HasBlockSize,
+     HasMaxBlockSizeInMB,
+     Param,
+     Params,
+     TypeConverters,
+     HasElasticNetParam,
+     HasSeed,
+     HasStepSize,
+     HasSolver,
+     HasParallelism,
+ )
+ from pyspark.ml.tree import (
+     _DecisionTreeModel,
+     _DecisionTreeParams,
+     _TreeEnsembleModel,
+     _RandomForestParams,
+     _GBTParams,
+     _HasVarianceImpurity,
+     _TreeClassifierParams,
+ )
+ from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel
+ from pyspark.ml.base import _PredictorParams
+ from pyspark.ml.util import (
+     DefaultParamsReader,
+     DefaultParamsWriter,
+     JavaMLReadable,
+     JavaMLReader,
+     JavaMLWritable,
+     JavaMLWriter,
+     MLReader,
+     MLReadable,
+     MLWriter,
+     MLWritable,
+     HasTrainingSummary,
+ )
+ from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
+ from pyspark.ml.common import inherit_doc
+ from pyspark.ml.linalg import Matrix, Vector, Vectors, VectorUDT
+ from pyspark.sql import DataFrame, Row
+ from pyspark.sql.functions import udf, when
+ from pyspark.sql.types import ArrayType, DoubleType
+ from pyspark.storagelevel import StorageLevel
+
+
+ if TYPE_CHECKING:
+     from pyspark.ml._typing import P, ParamMap
+     from py4j.java_gateway import JavaObject
+
+
+ T = TypeVar("T")
+ JPM = TypeVar("JPM", bound=JavaPredictionModel)
+ CM = TypeVar("CM", bound="ClassificationModel")
+
+ __all__ = [
+     "LinearSVC",
+     "LinearSVCModel",
+     "LinearSVCSummary",
+     "LinearSVCTrainingSummary",
+     "LogisticRegression",
+     "LogisticRegressionModel",
+     "LogisticRegressionSummary",
+     "LogisticRegressionTrainingSummary",
+     "BinaryLogisticRegressionSummary",
+     "BinaryLogisticRegressionTrainingSummary",
+     "DecisionTreeClassifier",
+     "DecisionTreeClassificationModel",
+     "GBTClassifier",
+     "GBTClassificationModel",
+     "RandomForestClassifier",
+     "RandomForestClassificationModel",
+     "RandomForestClassificationSummary",
+     "RandomForestClassificationTrainingSummary",
+     "BinaryRandomForestClassificationSummary",
+     "BinaryRandomForestClassificationTrainingSummary",
+     "NaiveBayes",
+     "NaiveBayesModel",
+     "MultilayerPerceptronClassifier",
+     "MultilayerPerceptronClassificationModel",
+     "MultilayerPerceptronClassificationSummary",
+     "MultilayerPerceptronClassificationTrainingSummary",
+     "OneVsRest",
+     "OneVsRestModel",
+     "FMClassifier",
+     "FMClassificationModel",
+     "FMClassificationSummary",
+     "FMClassificationTrainingSummary",
+ ]
+
+
+ class _ClassifierParams(HasRawPredictionCol, _PredictorParams):
+     """
+     Classifier Params for classification tasks.
+
+     .. versionadded:: 3.0.0
+     """
+
+     pass
+
+
+ @inherit_doc
+ class Classifier(Predictor[CM], _ClassifierParams, Generic[CM], metaclass=ABCMeta):
+     """
+     Classifier for classification tasks.
+     Classes are indexed {0, 1, ..., numClasses - 1}.
+     """
+
+     @since("3.0.0")
+     def setRawPredictionCol(self: "P", value: str) -> "P":
+         """
+         Sets the value of :py:attr:`rawPredictionCol`.
+         """
+         return self._set(rawPredictionCol=value)
+
+
+ @inherit_doc
+ class ClassificationModel(PredictionModel, _ClassifierParams, metaclass=ABCMeta):
+     """
+     Model produced by a ``Classifier``.
+     Classes are indexed {0, 1, ..., numClasses - 1}.
+     """
+
+     @since("3.0.0")
+     def setRawPredictionCol(self: "P", value: str) -> "P":
+         """
+         Sets the value of :py:attr:`rawPredictionCol`.
+         """
+         return self._set(rawPredictionCol=value)
+
+     @property
+     @abstractmethod
+     @since("2.1.0")
+     def numClasses(self) -> int:
+         """
+         Number of classes (values which the label can take).
+         """
+         raise NotImplementedError()
+
+     @abstractmethod
+     @since("3.0.0")
+     def predictRaw(self, value: Vector) -> Vector:
+         """
+         Raw prediction for each possible label.
+         """
+         raise NotImplementedError()
+
+
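The Classifier/ClassificationModel contract above assumes labels arrive as doubles indexed 0.0 through numClasses - 1. A minimal sketch of producing that encoding with pyspark.ml.feature.StringIndexer, assuming a running SparkSession bound to a variable named spark (the session variable is an assumption, not part of the diff):

    from pyspark.ml.feature import StringIndexer

    # `spark` is an assumed, already-created SparkSession.
    df = spark.createDataFrame([("spam",), ("ham",), ("spam",)], ["category"])
    # StringIndexer assigns 0.0, 1.0, ... by descending label frequency,
    # which matches the {0, 1, ..., numClasses - 1} indexing required above.
    indexed = StringIndexer(inputCol="category", outputCol="label").fit(df).transform(df)
    indexed.show()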
+ class _ProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _ClassifierParams):
+     """
+     Params for :py:class:`ProbabilisticClassifier` and
+     :py:class:`ProbabilisticClassificationModel`.
+
+     .. versionadded:: 3.0.0
+     """
+
+     pass
+
+
+ @inherit_doc
+ class ProbabilisticClassifier(Classifier, _ProbabilisticClassifierParams, metaclass=ABCMeta):
+     """
+     Probabilistic Classifier for classification tasks.
+     """
+
+     @since("3.0.0")
+     def setProbabilityCol(self: "P", value: str) -> "P":
+         """
+         Sets the value of :py:attr:`probabilityCol`.
+         """
+         return self._set(probabilityCol=value)
+
+     @since("3.0.0")
+     def setThresholds(self: "P", value: List[float]) -> "P":
+         """
+         Sets the value of :py:attr:`thresholds`.
+         """
+         return self._set(thresholds=value)
+
+
+ @inherit_doc
+ class ProbabilisticClassificationModel(
+     ClassificationModel, _ProbabilisticClassifierParams, metaclass=ABCMeta
+ ):
+     """
+     Model produced by a ``ProbabilisticClassifier``.
+     """
+
+     @since("3.0.0")
+     def setProbabilityCol(self: CM, value: str) -> CM:
+         """
+         Sets the value of :py:attr:`probabilityCol`.
+         """
+         return self._set(probabilityCol=value)
+
+     @since("3.0.0")
+     def setThresholds(self: CM, value: List[float]) -> CM:
+         """
+         Sets the value of :py:attr:`thresholds`.
+         """
+         return self._set(thresholds=value)
+
+     @abstractmethod
+     @since("3.0.0")
+     def predictProbability(self, value: Vector) -> Vector:
+         """
+         Predict the probability of each class given the features.
+         """
+         raise NotImplementedError()
+
+
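predictRaw returns per-class margins and predictProbability maps them through the model's link function; for logistic regression that map is a softmax over the raw vector. A plain-Python sketch of the relationship (an editorial illustration, not code from this package):

    import math

    def softmax(raw):
        # Normalize exponentiated margins into class probabilities.
        exps = [math.exp(v) for v in raw]
        total = sum(exps)
        return [e / total for e in exps]

    print(softmax([-4.0, 4.0]))  # ~[0.000335, 0.999665]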
+ @inherit_doc
+ class _JavaClassifier(Classifier, JavaPredictor[JPM], Generic[JPM], metaclass=ABCMeta):
+     """
+     Java Classifier for classification tasks.
+     Classes are indexed {0, 1, ..., numClasses - 1}.
+     """
+
+     @since("3.0.0")
+     def setRawPredictionCol(self: "P", value: str) -> "P":
+         """
+         Sets the value of :py:attr:`rawPredictionCol`.
+         """
+         return self._set(rawPredictionCol=value)
+
+
+ @inherit_doc
+ class _JavaClassificationModel(ClassificationModel, JavaPredictionModel[T]):
+     """
+     Java Model produced by a ``Classifier``.
+     Classes are indexed {0, 1, ..., numClasses - 1}.
+     To be mixed in with :class:`pyspark.ml.JavaModel`
+     """
+
+     @property
+     @since("2.1.0")
+     def numClasses(self) -> int:
+         """
+         Number of classes (values which the label can take).
+         """
+         return self._call_java("numClasses")
+
+     @since("3.0.0")
+     def predictRaw(self, value: Vector) -> Vector:
+         """
+         Raw prediction for each possible label.
+         """
+         return self._call_java("predictRaw", value)
+
+
+ @inherit_doc
+ class _JavaProbabilisticClassifier(
+     ProbabilisticClassifier, _JavaClassifier[JPM], Generic[JPM], metaclass=ABCMeta
+ ):
+     """
+     Java Probabilistic Classifier for classification tasks.
+     """
+
+     pass
+
+
+ @inherit_doc
+ class _JavaProbabilisticClassificationModel(
+     ProbabilisticClassificationModel, _JavaClassificationModel[T]
+ ):
+     """
+     Java Model produced by a ``ProbabilisticClassifier``.
+     """
+
+     @since("3.0.0")
+     def predictProbability(self, value: Vector) -> Vector:
+         """
+         Predict the probability of each class given the features.
+         """
+         return self._call_java("predictProbability", value)
+
+
+ @inherit_doc
+ class _ClassificationSummary(JavaWrapper):
+     """
+     Abstraction for multiclass classification results for a given model.
+
+     .. versionadded:: 3.1.0
+     """
+
+     @property
+     @since("3.1.0")
+     def predictions(self) -> DataFrame:
+         """
+         Dataframe outputted by the model's `transform` method.
+         """
+         return self._call_java("predictions")
+
+     @property
+     @since("3.1.0")
+     def predictionCol(self) -> str:
+         """
+         Field in "predictions" which gives the prediction of each class.
+         """
+         return self._call_java("predictionCol")
+
+     @property
+     @since("3.1.0")
+     def labelCol(self) -> str:
+         """
+         Field in "predictions" which gives the true label of each
+         instance.
+         """
+         return self._call_java("labelCol")
+
+     @property
+     @since("3.1.0")
+     def weightCol(self) -> str:
+         """
+         Field in "predictions" which gives the weight of each instance
+         as a vector.
+         """
+         return self._call_java("weightCol")
+
+     @property
+     def labels(self) -> List[str]:
+         """
+         Returns the sequence of labels in ascending order. This order matches the order used
+         in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
+
+         .. versionadded:: 3.1.0
+
+         Notes
+         -----
+         In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the
+         training set is missing a label, then all of the arrays over labels
+         (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
+         expected numClasses.
+         """
+         return self._call_java("labels")
+
+     @property
+     @since("3.1.0")
+     def truePositiveRateByLabel(self) -> List[float]:
+         """
+         Returns true positive rate for each label (category).
+         """
+         return self._call_java("truePositiveRateByLabel")
+
+     @property
+     @since("3.1.0")
+     def falsePositiveRateByLabel(self) -> List[float]:
+         """
+         Returns false positive rate for each label (category).
+         """
+         return self._call_java("falsePositiveRateByLabel")
+
+     @property
+     @since("3.1.0")
+     def precisionByLabel(self) -> List[float]:
+         """
+         Returns precision for each label (category).
+         """
+         return self._call_java("precisionByLabel")
+
+     @property
+     @since("3.1.0")
+     def recallByLabel(self) -> List[float]:
+         """
+         Returns recall for each label (category).
+         """
+         return self._call_java("recallByLabel")
+
+     @since("3.1.0")
+     def fMeasureByLabel(self, beta: float = 1.0) -> List[float]:
+         """
+         Returns f-measure for each label (category).
+         """
+         return self._call_java("fMeasureByLabel", beta)
+
+     @property
+     @since("3.1.0")
+     def accuracy(self) -> float:
+         """
+         Returns accuracy.
+         (equals to the total number of correctly classified instances
+         out of the total number of instances.)
+         """
+         return self._call_java("accuracy")
+
+     @property
+     @since("3.1.0")
+     def weightedTruePositiveRate(self) -> float:
+         """
+         Returns weighted true positive rate.
+         (equals to precision, recall and f-measure)
+         """
+         return self._call_java("weightedTruePositiveRate")
+
+     @property
+     @since("3.1.0")
+     def weightedFalsePositiveRate(self) -> float:
+         """
+         Returns weighted false positive rate.
+         """
+         return self._call_java("weightedFalsePositiveRate")
+
+     @property
+     @since("3.1.0")
+     def weightedRecall(self) -> float:
+         """
+         Returns weighted averaged recall.
+         (equals to precision, recall and f-measure)
+         """
+         return self._call_java("weightedRecall")
+
+     @property
+     @since("3.1.0")
+     def weightedPrecision(self) -> float:
+         """
+         Returns weighted averaged precision.
+         """
+         return self._call_java("weightedPrecision")
+
+     @since("3.1.0")
+     def weightedFMeasure(self, beta: float = 1.0) -> float:
+         """
+         Returns weighted averaged f-measure.
+         """
+         return self._call_java("weightedFMeasure", beta)
+
+
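The weighted* properties above average the per-label metrics, weighting each label by its (instance-weighted) frequency in the predictions. A self-contained sketch of weightedRecall under that definition, using made-up (label, prediction) pairs and unit instance weights:

    from collections import Counter

    # Hand-made (label, prediction) pairs; instance weights are all 1.
    pairs = [(1.0, 1.0), (0.0, 1.0), (0.0, 0.0), (1.0, 1.0), (0.0, 0.0)]
    label_counts = Counter(label for label, _ in pairs)
    total = len(pairs)

    def recall(label):
        hits = sum(1 for y, p in pairs if y == label and p == label)
        return hits / label_counts[label]

    # weightedRecall: per-label recall, weighted by each label's frequency.
    weighted_recall = sum(recall(y) * n / total for y, n in label_counts.items())
    print(weighted_recall)  # 1.0 * 2/5 + (2/3) * 3/5 = 0.8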
+ @inherit_doc
+ class _TrainingSummary(JavaWrapper):
+     """
+     Abstraction for Training results.
+
+     .. versionadded:: 3.1.0
+     """
+
+     @property
+     @since("3.1.0")
+     def objectiveHistory(self) -> List[float]:
+         """
+         Objective function (scaled loss + regularization) at each
+         iteration. It contains one more element, the initial state,
+         than the number of iterations.
+         """
+         return self._call_java("objectiveHistory")
+
+     @property
+     @since("3.1.0")
+     def totalIterations(self) -> int:
+         """
+         Number of training iterations until termination.
+         """
+         return self._call_java("totalIterations")
+
+
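Because objectiveHistory records the initial objective plus one value per completed iteration, its length is always totalIterations + 1. A quick sanity-check sketch, assuming `model` is any fitted model exposing the summary properties above:

    summary = model.summary  # assumes a fitted model with a training summary
    # One entry for the starting point, then one per completed iteration.
    assert len(summary.objectiveHistory) == summary.totalIterations + 1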
+ @inherit_doc
+ class _BinaryClassificationSummary(_ClassificationSummary):
+     """
+     Binary classification results for a given model.
+
+     .. versionadded:: 3.1.0
+     """
+
+     @property
+     @since("3.1.0")
+     def scoreCol(self) -> str:
+         """
+         Field in "predictions" which gives the probability or raw prediction
+         of each class as a vector.
+         """
+         return self._call_java("scoreCol")
+
+     @property
+     def roc(self) -> DataFrame:
+         """
+         Returns the receiver operating characteristic (ROC) curve,
+         which is a Dataframe having two fields (FPR, TPR) with
+         (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
+
+         .. versionadded:: 3.1.0
+
+         Notes
+         -----
+         `Wikipedia reference <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_
+         """
+         return self._call_java("roc")
+
+     @property
+     @since("3.1.0")
+     def areaUnderROC(self) -> float:
+         """
+         Computes the area under the receiver operating characteristic
+         (ROC) curve.
+         """
+         return self._call_java("areaUnderROC")
+
+     @property
+     @since("3.1.0")
+     def pr(self) -> DataFrame:
+         """
+         Returns the precision-recall curve, which is a Dataframe
+         containing two fields recall, precision with (0.0, 1.0) prepended
+         to it.
+         """
+         return self._call_java("pr")
+
+     @property
+     @since("3.1.0")
+     def fMeasureByThreshold(self) -> DataFrame:
+         """
+         Returns a dataframe with two fields (threshold, F-Measure) curve
+         with beta = 1.0.
+         """
+         return self._call_java("fMeasureByThreshold")
+
+     @property
+     @since("3.1.0")
+     def precisionByThreshold(self) -> DataFrame:
+         """
+         Returns a dataframe with two fields (threshold, precision) curve.
+         Every possible probability obtained in transforming the dataset
+         is used as a threshold in calculating the precision.
+         """
+         return self._call_java("precisionByThreshold")
+
+     @property
+     @since("3.1.0")
+     def recallByThreshold(self) -> DataFrame:
+         """
+         Returns a dataframe with two fields (threshold, recall) curve.
+         Every possible probability obtained in transforming the dataset
+         is used as a threshold in calculating the recall.
+         """
+         return self._call_java("recallByThreshold")
+
+
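Since roc returns the (FPR, TPR) points with (0.0, 0.0) prepended and (1.0, 1.0) appended, areaUnderROC is the trapezoidal integral of that curve. A plain-Python sketch of the computation on hand-made points (illustrative values only, not output from this diff):

    def area_under_curve(points):
        # Trapezoidal rule over (x, y) points already sorted by x.
        return sum(
            (x2 - x1) * (y1 + y2) / 2.0
            for (x1, y1), (x2, y2) in zip(points, points[1:])
        )

    roc = [(0.0, 0.0), (0.0, 0.5), (0.5, 1.0), (1.0, 1.0)]  # (FPR, TPR)
    print(area_under_curve(roc))  # 0.875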
+ class _LinearSVCParams(
+     _ClassifierParams,
+     HasRegParam,
+     HasMaxIter,
+     HasFitIntercept,
+     HasTol,
+     HasStandardization,
+     HasWeightCol,
+     HasAggregationDepth,
+     HasThreshold,
+     HasMaxBlockSizeInMB,
+ ):
+     """
+     Params for :py:class:`LinearSVC` and :py:class:`LinearSVCModel`.
+
+     .. versionadded:: 3.0.0
+     """
+
+     threshold: Param[float] = Param(
+         Params._dummy(),
+         "threshold",
+         "The threshold in binary classification applied to the linear model"
+         " prediction. This threshold can be any real number, where Inf will make"
+         " all predictions 0.0 and -Inf will make all predictions 1.0.",
+         typeConverter=TypeConverters.toFloat,
+     )
+
+     def __init__(self, *args: Any) -> None:
+         super(_LinearSVCParams, self).__init__(*args)
+         self._setDefault(
+             maxIter=100,
+             regParam=0.0,
+             tol=1e-6,
+             fitIntercept=True,
+             standardization=True,
+             threshold=0.0,
+             aggregationDepth=2,
+             maxBlockSizeInMB=0.0,
+         )
+
+
+ @inherit_doc
+ class LinearSVC(
+     _JavaClassifier["LinearSVCModel"],
+     _LinearSVCParams,
+     JavaMLWritable,
+     JavaMLReadable["LinearSVC"],
+ ):
+     """
+     This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
+     Only supports L2 regularization currently.
+
+     .. versionadded:: 2.2.0
+
+     Notes
+     -----
+     `Linear SVM Classifier <https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM>`_
+
+     Examples
+     --------
+     >>> from pyspark.sql import Row
+     >>> from pyspark.ml.linalg import Vectors
+     >>> df = sc.parallelize([
+     ...     Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
+     ...     Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
+     >>> svm = LinearSVC()
+     >>> svm.getMaxIter()
+     100
+     >>> svm.setMaxIter(5)
+     LinearSVC...
+     >>> svm.getMaxIter()
+     5
+     >>> svm.getRegParam()
+     0.0
+     >>> svm.setRegParam(0.01)
+     LinearSVC...
+     >>> svm.getRegParam()
+     0.01
+     >>> model = svm.fit(df)
+     >>> model.setPredictionCol("newPrediction")
+     LinearSVCModel...
+     >>> model.getPredictionCol()
+     'newPrediction'
+     >>> model.setThreshold(0.5)
+     LinearSVCModel...
+     >>> model.getThreshold()
+     0.5
+     >>> model.getMaxBlockSizeInMB()
+     0.0
+     >>> model.coefficients
+     DenseVector([0.0, -1.0319, -0.5159])
+     >>> model.intercept
+     2.579645978780695
+     >>> model.numClasses
+     2
+     >>> model.numFeatures
+     3
+     >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()
+     >>> model.predict(test0.head().features)
+     1.0
+     >>> model.predictRaw(test0.head().features)
+     DenseVector([-4.1274, 4.1274])
+     >>> result = model.transform(test0).head()
+     >>> result.newPrediction
+     1.0
+     >>> result.rawPrediction
+     DenseVector([-4.1274, 4.1274])
+     >>> svm_path = temp_path + "/svm"
+     >>> svm.save(svm_path)
+     >>> svm2 = LinearSVC.load(svm_path)
+     >>> svm2.getMaxIter()
+     5
+     >>> model_path = temp_path + "/svm_model"
+     >>> model.save(model_path)
+     >>> model2 = LinearSVCModel.load(model_path)
+     >>> model.coefficients[0] == model2.coefficients[0]
+     True
+     >>> model.intercept == model2.intercept
+     True
+     >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
+     True
+     """
+
+     _input_kwargs: Dict[str, Any]
+
+     @keyword_only
+     def __init__(
+         self,
+         *,
+         featuresCol: str = "features",
+         labelCol: str = "label",
+         predictionCol: str = "prediction",
+         maxIter: int = 100,
+         regParam: float = 0.0,
+         tol: float = 1e-6,
+         rawPredictionCol: str = "rawPrediction",
+         fitIntercept: bool = True,
+         standardization: bool = True,
+         threshold: float = 0.0,
+         weightCol: Optional[str] = None,
+         aggregationDepth: int = 2,
+         maxBlockSizeInMB: float = 0.0,
+     ):
+         """
+         __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                  maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
+                  fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \
+                  aggregationDepth=2, maxBlockSizeInMB=0.0):
+         """
+         super(LinearSVC, self).__init__()
+         self._java_obj = self._new_java_obj(
+             "org.apache.spark.ml.classification.LinearSVC", self.uid
+         )
+         kwargs = self._input_kwargs
+         self.setParams(**kwargs)
+
+     @keyword_only
+     @since("2.2.0")
+     def setParams(
+         self,
+         *,
+         featuresCol: str = "features",
+         labelCol: str = "label",
+         predictionCol: str = "prediction",
+         maxIter: int = 100,
+         regParam: float = 0.0,
+         tol: float = 1e-6,
+         rawPredictionCol: str = "rawPrediction",
+         fitIntercept: bool = True,
+         standardization: bool = True,
+         threshold: float = 0.0,
+         weightCol: Optional[str] = None,
+         aggregationDepth: int = 2,
+         maxBlockSizeInMB: float = 0.0,
+     ) -> "LinearSVC":
+         """
+         setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                   maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
+                   fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \
+                   aggregationDepth=2, maxBlockSizeInMB=0.0):
+         Sets params for Linear SVM Classifier.
+         """
+         kwargs = self._input_kwargs
+         return self._set(**kwargs)
+
+     def _create_model(self, java_model: "JavaObject") -> "LinearSVCModel":
+         return LinearSVCModel(java_model)
+
+     @since("2.2.0")
+     def setMaxIter(self, value: int) -> "LinearSVC":
+         """
+         Sets the value of :py:attr:`maxIter`.
+         """
+         return self._set(maxIter=value)
+
+     @since("2.2.0")
+     def setRegParam(self, value: float) -> "LinearSVC":
+         """
+         Sets the value of :py:attr:`regParam`.
+         """
+         return self._set(regParam=value)
+
+     @since("2.2.0")
+     def setTol(self, value: float) -> "LinearSVC":
+         """
+         Sets the value of :py:attr:`tol`.
+         """
+         return self._set(tol=value)
+
+     @since("2.2.0")
+     def setFitIntercept(self, value: bool) -> "LinearSVC":
+         """
+         Sets the value of :py:attr:`fitIntercept`.
+         """
+         return self._set(fitIntercept=value)
+
+     @since("2.2.0")
+     def setStandardization(self, value: bool) -> "LinearSVC":
+         """
+         Sets the value of :py:attr:`standardization`.
+         """
+         return self._set(standardization=value)
+
+     @since("2.2.0")
+     def setThreshold(self, value: float) -> "LinearSVC":
+         """
+         Sets the value of :py:attr:`threshold`.
+         """
+         return self._set(threshold=value)
+
+     @since("2.2.0")
+     def setWeightCol(self, value: str) -> "LinearSVC":
+         """
+         Sets the value of :py:attr:`weightCol`.
+         """
+         return self._set(weightCol=value)
+
+     @since("2.2.0")
+     def setAggregationDepth(self, value: int) -> "LinearSVC":
+         """
+         Sets the value of :py:attr:`aggregationDepth`.
+         """
+         return self._set(aggregationDepth=value)
+
+     @since("3.1.0")
+     def setMaxBlockSizeInMB(self, value: float) -> "LinearSVC":
+         """
+         Sets the value of :py:attr:`maxBlockSizeInMB`.
+         """
+         return self._set(maxBlockSizeInMB=value)
+
+
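Note that the LinearSVC threshold documented above applies to the raw margin rather than a probability, so any real value is legal. A sketch of the resulting decision rule, assuming the two-element rawPrediction vector [-margin, margin] shown in the doctest (an illustration, not code from this package):

    def linear_svc_predict(raw_prediction, threshold=0.0):
        # rawPrediction is [-margin, margin]; predict the positive class
        # when the margin exceeds the (unbounded, real-valued) threshold.
        return 1.0 if raw_prediction[1] > threshold else 0.0

    print(linear_svc_predict([-4.1274, 4.1274]))                 # 1.0
    print(linear_svc_predict([-4.1274, 4.1274], threshold=5.0))  # 0.0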
+ class LinearSVCModel(
+     _JavaClassificationModel[Vector],
+     _LinearSVCParams,
+     JavaMLWritable,
+     JavaMLReadable["LinearSVCModel"],
+     HasTrainingSummary["LinearSVCTrainingSummary"],
+ ):
+     """
+     Model fitted by LinearSVC.
+
+     .. versionadded:: 2.2.0
+     """
+
+     @since("3.0.0")
+     def setThreshold(self, value: float) -> "LinearSVCModel":
+         """
+         Sets the value of :py:attr:`threshold`.
+         """
+         return self._set(threshold=value)
+
+     @property
+     @since("2.2.0")
+     def coefficients(self) -> Vector:
+         """
+         Model coefficients of Linear SVM Classifier.
+         """
+         return self._call_java("coefficients")
+
+     @property
+     @since("2.2.0")
+     def intercept(self) -> float:
+         """
+         Model intercept of Linear SVM Classifier.
+         """
+         return self._call_java("intercept")
+
+     @property
+     @since("3.1.0")
+     def summary(self) -> "LinearSVCTrainingSummary":
+         """
+         Gets summary (accuracy/precision/recall, objective history, total iterations) of model
+         trained on the training set. An exception is thrown if `trainingSummary is None`.
+         """
+         if self.hasSummary:
+             return LinearSVCTrainingSummary(super(LinearSVCModel, self).summary)
+         else:
+             raise RuntimeError(
+                 "No training summary available for this %s" % self.__class__.__name__
+             )
+
+     def evaluate(self, dataset: DataFrame) -> "LinearSVCSummary":
+         """
+         Evaluates the model on a test dataset.
+
+         .. versionadded:: 3.1.0
+
+         Parameters
+         ----------
+         dataset : :py:class:`pyspark.sql.DataFrame`
+             Test dataset to evaluate model on.
+         """
+         if not isinstance(dataset, DataFrame):
+             raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
+         java_lsvc_summary = self._call_java("evaluate", dataset)
+         return LinearSVCSummary(java_lsvc_summary)
+
+
+ class LinearSVCSummary(_BinaryClassificationSummary):
+     """
+     Abstraction for LinearSVC Results for a given model.
+
+     .. versionadded:: 3.1.0
+     """
+
+     pass
+
+
+ @inherit_doc
+ class LinearSVCTrainingSummary(LinearSVCSummary, _TrainingSummary):
+     """
+     Abstraction for LinearSVC Training results.
+
+     .. versionadded:: 3.1.0
+     """
+
+     pass
+
+
927
+ class _LogisticRegressionParams(
928
+ _ProbabilisticClassifierParams,
929
+ HasRegParam,
930
+ HasElasticNetParam,
931
+ HasMaxIter,
932
+ HasFitIntercept,
933
+ HasTol,
934
+ HasStandardization,
935
+ HasWeightCol,
936
+ HasAggregationDepth,
937
+ HasThreshold,
938
+ HasMaxBlockSizeInMB,
939
+ ):
940
+ """
941
+ Params for :py:class:`LogisticRegression` and :py:class:`LogisticRegressionModel`.
942
+
943
+ .. versionadded:: 3.0.0
944
+ """
945
+
946
+ threshold: Param[float] = Param(
947
+ Params._dummy(),
948
+ "threshold",
949
+ "Threshold in binary classification prediction, in range [0, 1]."
950
+ + " If threshold and thresholds are both set, they must match."
951
+ + "e.g. if threshold is p, then thresholds must be equal to [1-p, p].",
952
+ typeConverter=TypeConverters.toFloat,
953
+ )
954
+
955
+ family: Param[str] = Param(
956
+ Params._dummy(),
957
+ "family",
958
+ "The name of family which is a description of the label distribution to "
959
+ + "be used in the model. Supported options: auto, binomial, multinomial",
960
+ typeConverter=TypeConverters.toString,
961
+ )
962
+
963
+ lowerBoundsOnCoefficients: Param[Matrix] = Param(
964
+ Params._dummy(),
965
+ "lowerBoundsOnCoefficients",
966
+ "The lower bounds on coefficients if fitting under bound "
967
+ "constrained optimization. The bound matrix must be "
968
+ "compatible with the shape "
969
+ "(1, number of features) for binomial regression, or "
970
+ "(number of classes, number of features) "
971
+ "for multinomial regression.",
972
+ typeConverter=TypeConverters.toMatrix,
973
+ )
974
+
975
+ upperBoundsOnCoefficients: Param[Matrix] = Param(
976
+ Params._dummy(),
977
+ "upperBoundsOnCoefficients",
978
+ "The upper bounds on coefficients if fitting under bound "
979
+ "constrained optimization. The bound matrix must be "
980
+ "compatible with the shape "
981
+ "(1, number of features) for binomial regression, or "
982
+ "(number of classes, number of features) "
983
+ "for multinomial regression.",
984
+ typeConverter=TypeConverters.toMatrix,
985
+ )
986
+
987
+ lowerBoundsOnIntercepts: Param[Vector] = Param(
988
+ Params._dummy(),
989
+ "lowerBoundsOnIntercepts",
990
+ "The lower bounds on intercepts if fitting under bound "
991
+ "constrained optimization. The bounds vector size must be"
992
+ "equal with 1 for binomial regression, or the number of"
993
+ "lasses for multinomial regression.",
994
+ typeConverter=TypeConverters.toVector,
995
+ )
996
+
997
+ upperBoundsOnIntercepts: Param[Vector] = Param(
998
+ Params._dummy(),
999
+ "upperBoundsOnIntercepts",
1000
+ "The upper bounds on intercepts if fitting under bound "
1001
+ "constrained optimization. The bound vector size must be "
1002
+ "equal with 1 for binomial regression, or the number of "
1003
+ "classes for multinomial regression.",
1004
+ typeConverter=TypeConverters.toVector,
1005
+ )
1006
+
1007
+ def __init__(self, *args: Any):
1008
+ super(_LogisticRegressionParams, self).__init__(*args)
1009
+ self._setDefault(
1010
+ maxIter=100, regParam=0.0, tol=1e-6, threshold=0.5, family="auto", maxBlockSizeInMB=0.0
1011
+ )
1012
+
1013
+ @since("1.4.0")
1014
+ def setThreshold(self: "P", value: float) -> "P":
1015
+ """
1016
+ Sets the value of :py:attr:`threshold`.
1017
+ Clears value of :py:attr:`thresholds` if it has been set.
1018
+ """
1019
+ self._set(threshold=value)
1020
+ self.clear(self.thresholds) # type: ignore[attr-defined]
1021
+ return self
1022
+
1023
+ @since("1.4.0")
1024
+ def getThreshold(self) -> float:
1025
+ """
1026
+ Get threshold for binary classification.
1027
+
1028
+ If :py:attr:`thresholds` is set with length 2 (i.e., binary classification),
1029
+ this returns the equivalent threshold:
1030
+ :math:`\\frac{1}{1 + \\frac{thresholds(0)}{thresholds(1)}}`.
1031
+ Otherwise, returns :py:attr:`threshold` if set or its default value if unset.
1032
+ """
1033
+ self._checkThresholdConsistency()
1034
+ if self.isSet(self.thresholds):
1035
+ ts = self.getOrDefault(self.thresholds)
1036
+ if len(ts) != 2:
1037
+ raise ValueError(
1038
+ "Logistic Regression getThreshold only applies to"
1039
+ + " binary classification, but thresholds has length != 2."
1040
+ + " thresholds: {ts}".format(ts=ts)
1041
+ )
1042
+ return 1.0 / (1.0 + ts[0] / ts[1])
1043
+ else:
1044
+ return self.getOrDefault(self.threshold)
1045
+
1046
+ @since("1.5.0")
1047
+ def setThresholds(self: "P", value: List[float]) -> "P":
1048
+ """
1049
+ Sets the value of :py:attr:`thresholds`.
1050
+ Clears value of :py:attr:`threshold` if it has been set.
1051
+ """
1052
+ self._set(thresholds=value)
1053
+ self.clear(self.threshold) # type: ignore[attr-defined]
1054
+ return self
1055
+
1056
+ @since("1.5.0")
1057
+ def getThresholds(self) -> List[float]:
1058
+ """
1059
+ If :py:attr:`thresholds` is set, return its value.
1060
+ Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary
1061
+ classification: (1-threshold, threshold).
1062
+ If neither is set, throw an error.
1063
+ """
1064
+ self._checkThresholdConsistency()
1065
+ if not self.isSet(self.thresholds) and self.isSet(self.threshold):
1066
+ t = self.getOrDefault(self.threshold)
1067
+ return [1.0 - t, t]
1068
+ else:
1069
+ return self.getOrDefault(self.thresholds)
1070
+
1071
+ def _checkThresholdConsistency(self) -> None:
1072
+ if self.isSet(self.threshold) and self.isSet(self.thresholds):
1073
+ ts = self.getOrDefault(self.thresholds)
1074
+ if len(ts) != 2:
1075
+ raise ValueError(
1076
+ "Logistic Regression getThreshold only applies to"
1077
+ + " binary classification, but thresholds has length != 2."
1078
+ + " thresholds: {0}".format(str(ts))
1079
+ )
1080
+ t = 1.0 / (1.0 + ts[0] / ts[1])
1081
+ t2 = self.getOrDefault(self.threshold)
1082
+ if abs(t2 - t) >= 1e-5:
1083
+ raise ValueError(
1084
+ "Logistic Regression getThreshold found inconsistent values for"
1085
+ + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)
1086
+ )
1087
+
1088
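The threshold/thresholds round trip that `getThreshold`, `getThresholds`, and `_checkThresholdConsistency` rely on, reduced to plain arithmetic (the values are illustrative):

    # thresholds -> threshold: with thresholds = [t0, t1], the equivalent
    # binary threshold is 1 / (1 + t0 / t1).
    thresholds = [0.3, 0.7]
    threshold = 1.0 / (1.0 + thresholds[0] / thresholds[1])
    assert abs(threshold - 0.7) < 1e-9

    # threshold -> thresholds: the reverse mapping is (1 - t, t).
    t = 0.7
    ts = [1.0 - t, t]
    assert abs(ts[0] - 0.3) < 1e-9 and ts[1] == 0.7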
+ @since("2.1.0")
1089
+ def getFamily(self) -> str:
1090
+ """
1091
+ Gets the value of :py:attr:`family` or its default value.
1092
+ """
1093
+ return self.getOrDefault(self.family)
1094
+
1095
+ @since("2.3.0")
1096
+ def getLowerBoundsOnCoefficients(self) -> Matrix:
1097
+ """
1098
+ Gets the value of :py:attr:`lowerBoundsOnCoefficients`
1099
+ """
1100
+ return self.getOrDefault(self.lowerBoundsOnCoefficients)
1101
+
1102
+ @since("2.3.0")
1103
+ def getUpperBoundsOnCoefficients(self) -> Matrix:
1104
+ """
1105
+ Gets the value of :py:attr:`upperBoundsOnCoefficients`
1106
+ """
1107
+ return self.getOrDefault(self.upperBoundsOnCoefficients)
1108
+
1109
+ @since("2.3.0")
1110
+ def getLowerBoundsOnIntercepts(self) -> Vector:
1111
+ """
1112
+ Gets the value of :py:attr:`lowerBoundsOnIntercepts`
1113
+ """
1114
+ return self.getOrDefault(self.lowerBoundsOnIntercepts)
1115
+
1116
+ @since("2.3.0")
1117
+ def getUpperBoundsOnIntercepts(self) -> Vector:
1118
+ """
1119
+ Gets the value of :py:attr:`upperBoundsOnIntercepts`
1120
+ """
1121
+ return self.getOrDefault(self.upperBoundsOnIntercepts)
1122
+
1123
+
1124
+ @inherit_doc
1125
+ class LogisticRegression(
1126
+ _JavaProbabilisticClassifier["LogisticRegressionModel"],
1127
+ _LogisticRegressionParams,
1128
+ JavaMLWritable,
1129
+ JavaMLReadable["LogisticRegression"],
1130
+ ):
1131
+ """
1132
+ Logistic regression.
1133
+ This class supports multinomial logistic (softmax) and binomial logistic regression.
1134
+
1135
+ .. versionadded:: 1.3.0
1136
+
1137
+ Examples
1138
+ --------
1139
+ >>> from pyspark.sql import Row
1140
+ >>> from pyspark.ml.linalg import Vectors
1141
+ >>> bdf = sc.parallelize([
1142
+ ... Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)),
1143
+ ... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)),
1144
+ ... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)),
1145
+ ... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF()
1146
+ >>> blor = LogisticRegression(weightCol="weight")
1147
+ >>> blor.getRegParam()
1148
+ 0.0
1149
+ >>> blor.setRegParam(0.01)
1150
+ LogisticRegression...
1151
+ >>> blor.getRegParam()
1152
+ 0.01
1153
+ >>> blor.setMaxIter(10)
1154
+ LogisticRegression...
1155
+ >>> blor.getMaxIter()
1156
+ 10
1157
+ >>> blor.clear(blor.maxIter)
1158
+ >>> blorModel = blor.fit(bdf)
1159
+ >>> blorModel.setFeaturesCol("features")
1160
+ LogisticRegressionModel...
1161
+ >>> blorModel.setProbabilityCol("newProbability")
1162
+ LogisticRegressionModel...
1163
+ >>> blorModel.getProbabilityCol()
1164
+ 'newProbability'
1165
+ >>> blorModel.getMaxBlockSizeInMB()
1166
+ 0.0
1167
+ >>> blorModel.setThreshold(0.1)
1168
+ LogisticRegressionModel...
1169
+ >>> blorModel.getThreshold()
1170
+ 0.1
1171
+ >>> blorModel.coefficients
1172
+ DenseVector([-1.080..., -0.646...])
1173
+ >>> blorModel.intercept
1174
+ 3.112...
1175
+ >>> blorModel.evaluate(bdf).accuracy == blorModel.summary.accuracy
1176
+ True
1177
+ >>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
1178
+ >>> mdf = spark.read.format("libsvm").load(data_path)
1179
+ >>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial")
1180
+ >>> mlorModel = mlor.fit(mdf)
1181
+ >>> mlorModel.coefficientMatrix
1182
+ SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1)
1183
+ >>> mlorModel.interceptVector
1184
+ DenseVector([0.04..., -0.42..., 0.37...])
1185
+ >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()
1186
+ >>> blorModel.predict(test0.head().features)
1187
+ 1.0
1188
+ >>> blorModel.predictRaw(test0.head().features)
1189
+ DenseVector([-3.54..., 3.54...])
1190
+ >>> blorModel.predictProbability(test0.head().features)
1191
+ DenseVector([0.028, 0.972])
1192
+ >>> result = blorModel.transform(test0).head()
1193
+ >>> result.prediction
1194
+ 1.0
1195
+ >>> result.newProbability
1196
+ DenseVector([0.02..., 0.97...])
1197
+ >>> result.rawPrediction
1198
+ DenseVector([-3.54..., 3.54...])
1199
+ >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
1200
+ >>> blorModel.transform(test1).head().prediction
1201
+ 1.0
1202
+ >>> blor.setParams("vector")
1203
+ Traceback (most recent call last):
1204
+ ...
1205
+ TypeError: Method setParams forces keyword arguments.
1206
+ >>> lr_path = temp_path + "/lr"
1207
+ >>> blor.save(lr_path)
1208
+ >>> lr2 = LogisticRegression.load(lr_path)
1209
+ >>> lr2.getRegParam()
1210
+ 0.01
1211
+ >>> model_path = temp_path + "/lr_model"
1212
+ >>> blorModel.save(model_path)
1213
+ >>> model2 = LogisticRegressionModel.load(model_path)
1214
+ >>> blorModel.coefficients[0] == model2.coefficients[0]
1215
+ True
1216
+ >>> blorModel.intercept == model2.intercept
1217
+ True
1218
+ >>> model2
1219
+ LogisticRegressionModel: uid=..., numClasses=2, numFeatures=2
1220
+ >>> blorModel.transform(test0).take(1) == model2.transform(test0).take(1)
1221
+ True
1222
+ """
1223
+
1224
+ _input_kwargs: Dict[str, Any]
1225
+
1226
+ @overload
1227
+ def __init__(
1228
+ self,
1229
+ *,
1230
+ featuresCol: str = ...,
1231
+ labelCol: str = ...,
1232
+ predictionCol: str = ...,
1233
+ maxIter: int = ...,
1234
+ regParam: float = ...,
1235
+ elasticNetParam: float = ...,
1236
+ tol: float = ...,
1237
+ fitIntercept: bool = ...,
1238
+ threshold: float = ...,
1239
+ probabilityCol: str = ...,
1240
+ rawPredictionCol: str = ...,
1241
+ standardization: bool = ...,
1242
+ weightCol: Optional[str] = ...,
1243
+ aggregationDepth: int = ...,
1244
+ family: str = ...,
1245
+ lowerBoundsOnCoefficients: Optional[Matrix] = ...,
1246
+ upperBoundsOnCoefficients: Optional[Matrix] = ...,
1247
+ lowerBoundsOnIntercepts: Optional[Vector] = ...,
1248
+ upperBoundsOnIntercepts: Optional[Vector] = ...,
1249
+ maxBlockSizeInMB: float = ...,
1250
+ ):
1251
+ ...
1252
+
1253
+ @overload
1254
+ def __init__(
1255
+ self,
1256
+ *,
1257
+ featuresCol: str = ...,
1258
+ labelCol: str = ...,
1259
+ predictionCol: str = ...,
1260
+ maxIter: int = ...,
1261
+ regParam: float = ...,
1262
+ elasticNetParam: float = ...,
1263
+ tol: float = ...,
1264
+ fitIntercept: bool = ...,
1265
+ thresholds: Optional[List[float]] = ...,
1266
+ probabilityCol: str = ...,
1267
+ rawPredictionCol: str = ...,
1268
+ standardization: bool = ...,
1269
+ weightCol: Optional[str] = ...,
1270
+ aggregationDepth: int = ...,
1271
+ family: str = ...,
1272
+ lowerBoundsOnCoefficients: Optional[Matrix] = ...,
1273
+ upperBoundsOnCoefficients: Optional[Matrix] = ...,
1274
+ lowerBoundsOnIntercepts: Optional[Vector] = ...,
1275
+ upperBoundsOnIntercepts: Optional[Vector] = ...,
1276
+ maxBlockSizeInMB: float = ...,
1277
+ ):
1278
+ ...
1279
+
1280
+ @keyword_only
1281
+ def __init__(
1282
+ self,
1283
+ *,
1284
+ featuresCol: str = "features",
1285
+ labelCol: str = "label",
1286
+ predictionCol: str = "prediction",
1287
+ maxIter: int = 100,
1288
+ regParam: float = 0.0,
1289
+ elasticNetParam: float = 0.0,
1290
+ tol: float = 1e-6,
1291
+ fitIntercept: bool = True,
1292
+ threshold: float = 0.5,
1293
+ thresholds: Optional[List[float]] = None,
1294
+ probabilityCol: str = "probability",
1295
+ rawPredictionCol: str = "rawPrediction",
1296
+ standardization: bool = True,
1297
+ weightCol: Optional[str] = None,
1298
+ aggregationDepth: int = 2,
1299
+ family: str = "auto",
1300
+ lowerBoundsOnCoefficients: Optional[Matrix] = None,
1301
+ upperBoundsOnCoefficients: Optional[Matrix] = None,
1302
+ lowerBoundsOnIntercepts: Optional[Vector] = None,
1303
+ upperBoundsOnIntercepts: Optional[Vector] = None,
1304
+ maxBlockSizeInMB: float = 0.0,
1305
+ ):
1306
+ """
1307
+ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
1308
+ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
1309
+ threshold=0.5, thresholds=None, probabilityCol="probability", \
1310
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
1311
+ aggregationDepth=2, family="auto", \
1312
+ lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \
1313
+ lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, \
1314
+ maxBlockSizeInMB=0.0):
1315
+ If the threshold and thresholds Params are both set, they must be equivalent.
1316
+ """
1317
+ super(LogisticRegression, self).__init__()
1318
+ self._java_obj = self._new_java_obj(
1319
+ "org.apache.spark.ml.classification.LogisticRegression", self.uid
1320
+ )
1321
+ kwargs = self._input_kwargs
1322
+ self.setParams(**kwargs)
1323
+ self._checkThresholdConsistency()
1324
+
1325
+ @overload
1326
+ def setParams(
1327
+ self,
1328
+ *,
1329
+ featuresCol: str = ...,
1330
+ labelCol: str = ...,
1331
+ predictionCol: str = ...,
1332
+ maxIter: int = ...,
1333
+ regParam: float = ...,
1334
+ elasticNetParam: float = ...,
1335
+ tol: float = ...,
1336
+ fitIntercept: bool = ...,
1337
+ threshold: float = ...,
1338
+ probabilityCol: str = ...,
1339
+ rawPredictionCol: str = ...,
1340
+ standardization: bool = ...,
1341
+ weightCol: Optional[str] = ...,
1342
+ aggregationDepth: int = ...,
1343
+ family: str = ...,
1344
+ lowerBoundsOnCoefficients: Optional[Matrix] = ...,
1345
+ upperBoundsOnCoefficients: Optional[Matrix] = ...,
1346
+ lowerBoundsOnIntercepts: Optional[Vector] = ...,
1347
+ upperBoundsOnIntercepts: Optional[Vector] = ...,
1348
+ maxBlockSizeInMB: float = ...,
1349
+ ) -> "LogisticRegression":
1350
+ ...
1351
+
1352
+ @overload
1353
+ def setParams(
1354
+ self,
1355
+ *,
1356
+ featuresCol: str = ...,
1357
+ labelCol: str = ...,
1358
+ predictionCol: str = ...,
1359
+ maxIter: int = ...,
1360
+ regParam: float = ...,
1361
+ elasticNetParam: float = ...,
1362
+ tol: float = ...,
1363
+ fitIntercept: bool = ...,
1364
+ thresholds: Optional[List[float]] = ...,
1365
+ probabilityCol: str = ...,
1366
+ rawPredictionCol: str = ...,
1367
+ standardization: bool = ...,
1368
+ weightCol: Optional[str] = ...,
1369
+ aggregationDepth: int = ...,
1370
+ family: str = ...,
1371
+ lowerBoundsOnCoefficients: Optional[Matrix] = ...,
1372
+ upperBoundsOnCoefficients: Optional[Matrix] = ...,
1373
+ lowerBoundsOnIntercepts: Optional[Vector] = ...,
1374
+ upperBoundsOnIntercepts: Optional[Vector] = ...,
1375
+ maxBlockSizeInMB: float = ...,
1376
+ ) -> "LogisticRegression":
1377
+ ...
1378
+
1379
+ @keyword_only
1380
+ @since("1.3.0")
1381
+ def setParams(
1382
+ self,
1383
+ *,
1384
+ featuresCol: str = "features",
1385
+ labelCol: str = "label",
1386
+ predictionCol: str = "prediction",
1387
+ maxIter: int = 100,
1388
+ regParam: float = 0.0,
1389
+ elasticNetParam: float = 0.0,
1390
+ tol: float = 1e-6,
1391
+ fitIntercept: bool = True,
1392
+ threshold: float = 0.5,
1393
+ thresholds: Optional[List[float]] = None,
1394
+ probabilityCol: str = "probability",
1395
+ rawPredictionCol: str = "rawPrediction",
1396
+ standardization: bool = True,
1397
+ weightCol: Optional[str] = None,
1398
+ aggregationDepth: int = 2,
1399
+ family: str = "auto",
1400
+ lowerBoundsOnCoefficients: Optional[Matrix] = None,
1401
+ upperBoundsOnCoefficients: Optional[Matrix] = None,
1402
+ lowerBoundsOnIntercepts: Optional[Vector] = None,
1403
+ upperBoundsOnIntercepts: Optional[Vector] = None,
1404
+ maxBlockSizeInMB: float = 0.0,
1405
+ ) -> "LogisticRegression":
1406
+ """
1407
+ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
1408
+ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
1409
+ threshold=0.5, thresholds=None, probabilityCol="probability", \
1410
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
1411
+ aggregationDepth=2, family="auto", \
1412
+ lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \
1413
+ lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, \
1414
+ maxBlockSizeInMB=0.0):
1415
+ Sets params for logistic regression.
1416
+ If the threshold and thresholds Params are both set, they must be equivalent.
1417
+ """
1418
+ kwargs = self._input_kwargs
1419
+ self._set(**kwargs)
1420
+ self._checkThresholdConsistency()
1421
+ return self
1422
+
1423
+ def _create_model(self, java_model: "JavaObject") -> "LogisticRegressionModel":
1424
+ return LogisticRegressionModel(java_model)
1425
+
1426
+ @since("2.1.0")
1427
+ def setFamily(self, value: str) -> "LogisticRegression":
1428
+ """
1429
+ Sets the value of :py:attr:`family`.
1430
+ """
1431
+ return self._set(family=value)
1432
+
1433
+ @since("2.3.0")
1434
+ def setLowerBoundsOnCoefficients(self, value: Matrix) -> "LogisticRegression":
1435
+ """
1436
+ Sets the value of :py:attr:`lowerBoundsOnCoefficients`
1437
+ """
1438
+ return self._set(lowerBoundsOnCoefficients=value)
1439
+
1440
+ @since("2.3.0")
1441
+ def setUpperBoundsOnCoefficients(self, value: Matrix) -> "LogisticRegression":
1442
+ """
1443
+ Sets the value of :py:attr:`upperBoundsOnCoefficients`
1444
+ """
1445
+ return self._set(upperBoundsOnCoefficients=value)
1446
+
1447
+ @since("2.3.0")
1448
+ def setLowerBoundsOnIntercepts(self, value: Vector) -> "LogisticRegression":
1449
+ """
1450
+ Sets the value of :py:attr:`lowerBoundsOnIntercepts`
1451
+ """
1452
+ return self._set(lowerBoundsOnIntercepts=value)
1453
+
1454
+ @since("2.3.0")
1455
+ def setUpperBoundsOnIntercepts(self, value: Vector) -> "LogisticRegression":
1456
+ """
1457
+ Sets the value of :py:attr:`upperBoundsOnIntercepts`
1458
+ """
1459
+ return self._set(upperBoundsOnIntercepts=value)
1460
+
1461
+ def setMaxIter(self, value: int) -> "LogisticRegression":
1462
+ """
1463
+ Sets the value of :py:attr:`maxIter`.
1464
+ """
1465
+ return self._set(maxIter=value)
1466
+
1467
+ def setRegParam(self, value: float) -> "LogisticRegression":
1468
+ """
1469
+ Sets the value of :py:attr:`regParam`.
1470
+ """
1471
+ return self._set(regParam=value)
1472
+
1473
+ def setTol(self, value: float) -> "LogisticRegression":
1474
+ """
1475
+ Sets the value of :py:attr:`tol`.
1476
+ """
1477
+ return self._set(tol=value)
1478
+
1479
+ def setElasticNetParam(self, value: float) -> "LogisticRegression":
1480
+ """
1481
+ Sets the value of :py:attr:`elasticNetParam`.
1482
+ """
1483
+ return self._set(elasticNetParam=value)
1484
+
1485
+ def setFitIntercept(self, value: bool) -> "LogisticRegression":
1486
+ """
1487
+ Sets the value of :py:attr:`fitIntercept`.
1488
+ """
1489
+ return self._set(fitIntercept=value)
1490
+
1491
+ def setStandardization(self, value: bool) -> "LogisticRegression":
1492
+ """
1493
+ Sets the value of :py:attr:`standardization`.
1494
+ """
1495
+ return self._set(standardization=value)
1496
+
1497
+ def setWeightCol(self, value: str) -> "LogisticRegression":
1498
+ """
1499
+ Sets the value of :py:attr:`weightCol`.
1500
+ """
1501
+ return self._set(weightCol=value)
1502
+
1503
+ def setAggregationDepth(self, value: int) -> "LogisticRegression":
1504
+ """
1505
+ Sets the value of :py:attr:`aggregationDepth`.
1506
+ """
1507
+ return self._set(aggregationDepth=value)
1508
+
1509
+ @since("3.1.0")
1510
+ def setMaxBlockSizeInMB(self, value: float) -> "LogisticRegression":
1511
+ """
1512
+ Sets the value of :py:attr:`maxBlockSizeInMB`.
1513
+ """
1514
+ return self._set(maxBlockSizeInMB=value)
1515
+
1516
+
1517
+ class LogisticRegressionModel(
1518
+ _JavaProbabilisticClassificationModel[Vector],
1519
+ _LogisticRegressionParams,
1520
+ JavaMLWritable,
1521
+ JavaMLReadable["LogisticRegressionModel"],
1522
+ HasTrainingSummary["LogisticRegressionTrainingSummary"],
1523
+ ):
1524
+ """
1525
+ Model fitted by LogisticRegression.
1526
+
1527
+ .. versionadded:: 1.3.0
1528
+ """
1529
+
1530
+ @property
1531
+ @since("2.0.0")
1532
+ def coefficients(self) -> Vector:
1533
+ """
1534
+ Model coefficients of binomial logistic regression.
1535
+ An exception is thrown in the case of multinomial logistic regression.
1536
+ """
1537
+ return self._call_java("coefficients")
1538
+
1539
+ @property
1540
+ @since("1.4.0")
1541
+ def intercept(self) -> float:
1542
+ """
1543
+ Model intercept of binomial logistic regression.
1544
+ An exception is thrown in the case of multinomial logistic regression.
1545
+ """
1546
+ return self._call_java("intercept")
1547
+
1548
+ @property
1549
+ @since("2.1.0")
1550
+ def coefficientMatrix(self) -> Matrix:
1551
+ """
1552
+ Model coefficients.
1553
+ """
1554
+ return self._call_java("coefficientMatrix")
1555
+
1556
+ @property
1557
+ @since("2.1.0")
1558
+ def interceptVector(self) -> Vector:
1559
+ """
1560
+ Model intercept.
1561
+ """
1562
+ return self._call_java("interceptVector")
1563
+
1564
+ @property
1565
+ @since("2.0.0")
1566
+ def summary(self) -> "LogisticRegressionTrainingSummary":
1567
+ """
1568
+ Gets summary (accuracy/precision/recall, objective history, total iterations) of model
1569
+ trained on the training set. An exception is thrown if `trainingSummary is None`.
1570
+ """
1571
+ if self.hasSummary:
1572
+ if self.numClasses <= 2:
1573
+ return BinaryLogisticRegressionTrainingSummary(
1574
+ super(LogisticRegressionModel, self).summary
1575
+ )
1576
+ else:
1577
+ return LogisticRegressionTrainingSummary(
1578
+ super(LogisticRegressionModel, self).summary
1579
+ )
1580
+ else:
1581
+ raise RuntimeError(
1582
+ "No training summary available for this %s" % self.__class__.__name__
1583
+ )
1584
+
1585
+ def evaluate(self, dataset: DataFrame) -> "LogisticRegressionSummary":
1586
+ """
1587
+ Evaluates the model on a test dataset.
1588
+
1589
+ .. versionadded:: 2.0.0
1590
+
1591
+ Parameters
1592
+ ----------
1593
+ dataset : :py:class:`pyspark.sql.DataFrame`
1594
+ Test dataset to evaluate model on.
1595
+ """
1596
+ if not isinstance(dataset, DataFrame):
1597
+ raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
1598
+ java_blr_summary = self._call_java("evaluate", dataset)
1599
+ if self.numClasses <= 2:
1600
+ return BinaryLogisticRegressionSummary(java_blr_summary)
1601
+ else:
1602
+ return LogisticRegressionSummary(java_blr_summary)
1603
+
1604
+
1605
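A hedged usage sketch of the summary/evaluate dispatch above (assumes the fitted binary model `blorModel` and a labeled DataFrame `test_df`, in the spirit of the doctests; the variable names are illustrative):

    train_summary = blorModel.summary           # BinaryLogisticRegressionTrainingSummary
    print(train_summary.accuracy)               # metric on the training set
    test_summary = blorModel.evaluate(test_df)  # BinaryLogisticRegressionSummary
    print(test_summary.accuracy)                # same metric on held-out data
    # A model with numClasses > 2 would receive the non-binary summary
    # classes from the same two call sites.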
+ class LogisticRegressionSummary(_ClassificationSummary):
1606
+ """
1607
+ Abstraction for Logistic Regression Results for a given model.
1608
+
1609
+ .. versionadded:: 2.0.0
1610
+ """
1611
+
1612
+ @property
1613
+ @since("2.0.0")
1614
+ def probabilityCol(self) -> str:
1615
+ """
1616
+ Field in "predictions" which gives the probability
1617
+ of each class as a vector.
1618
+ """
1619
+ return self._call_java("probabilityCol")
1620
+
1621
+ @property
1622
+ @since("2.0.0")
1623
+ def featuresCol(self) -> str:
1624
+ """
1625
+ Field in "predictions" which gives the features of each instance
1626
+ as a vector.
1627
+ """
1628
+ return self._call_java("featuresCol")
1629
+
1630
+
1631
+ @inherit_doc
1632
+ class LogisticRegressionTrainingSummary(LogisticRegressionSummary, _TrainingSummary):
1633
+ """
1634
+ Abstraction for multinomial Logistic Regression Training results.
1635
+
1636
+ .. versionadded:: 2.0.0
1637
+ """
1638
+
1639
+ pass
1640
+
1641
+
1642
+ @inherit_doc
1643
+ class BinaryLogisticRegressionSummary(_BinaryClassificationSummary, LogisticRegressionSummary):
1644
+ """
1645
+ Binary Logistic regression results for a given model.
1646
+
1647
+ .. versionadded:: 2.0.0
1648
+ """
1649
+
1650
+ pass
1651
+
1652
+
1653
+ @inherit_doc
1654
+ class BinaryLogisticRegressionTrainingSummary(
1655
+ BinaryLogisticRegressionSummary, LogisticRegressionTrainingSummary
1656
+ ):
1657
+ """
1658
+ Binary Logistic regression training results for a given model.
1659
+
1660
+ .. versionadded:: 2.0.0
1661
+ """
1662
+
1663
+ pass
1664
+
1665
+
1666
+ @inherit_doc
1667
+ class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
1668
+ """
1669
+ Params for :py:class:`DecisionTreeClassifier` and :py:class:`DecisionTreeClassificationModel`.
1670
+ """
1671
+
1672
+ def __init__(self, *args: Any):
1673
+ super(_DecisionTreeClassifierParams, self).__init__(*args)
1674
+ self._setDefault(
1675
+ maxDepth=5,
1676
+ maxBins=32,
1677
+ minInstancesPerNode=1,
1678
+ minInfoGain=0.0,
1679
+ maxMemoryInMB=256,
1680
+ cacheNodeIds=False,
1681
+ checkpointInterval=10,
1682
+ impurity="gini",
1683
+ leafCol="",
1684
+ minWeightFractionPerNode=0.0,
1685
+ )
1686
+
1687
+
1688
+ @inherit_doc
1689
+ class DecisionTreeClassifier(
1690
+ _JavaProbabilisticClassifier["DecisionTreeClassificationModel"],
1691
+ _DecisionTreeClassifierParams,
1692
+ JavaMLWritable,
1693
+ JavaMLReadable["DecisionTreeClassifier"],
1694
+ ):
1695
+ """
1696
+ `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
1697
+ learning algorithm for classification.
1698
+ It supports both binary and multiclass labels, as well as both continuous and categorical
1699
+ features.
1700
+
1701
+ .. versionadded:: 1.4.0
1702
+
1703
+ Examples
1704
+ --------
1705
+ >>> from pyspark.ml.linalg import Vectors
1706
+ >>> from pyspark.ml.feature import StringIndexer
1707
+ >>> df = spark.createDataFrame([
1708
+ ... (1.0, Vectors.dense(1.0)),
1709
+ ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
1710
+ >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
1711
+ >>> si_model = stringIndexer.fit(df)
1712
+ >>> td = si_model.transform(df)
1713
+ >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed", leafCol="leafId")
1714
+ >>> model = dt.fit(td)
1715
+ >>> model.getLabelCol()
1716
+ 'indexed'
1717
+ >>> model.setFeaturesCol("features")
1718
+ DecisionTreeClassificationModel...
1719
+ >>> model.numNodes
1720
+ 3
1721
+ >>> model.depth
1722
+ 1
1723
+ >>> model.featureImportances
1724
+ SparseVector(1, {0: 1.0})
1725
+ >>> model.numFeatures
1726
+ 1
1727
+ >>> model.numClasses
1728
+ 2
1729
+ >>> print(model.toDebugString)
1730
+ DecisionTreeClassificationModel...depth=1, numNodes=3...
1731
+ >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
1732
+ >>> model.predict(test0.head().features)
1733
+ 0.0
1734
+ >>> model.predictRaw(test0.head().features)
1735
+ DenseVector([1.0, 0.0])
1736
+ >>> model.predictProbability(test0.head().features)
1737
+ DenseVector([1.0, 0.0])
1738
+ >>> result = model.transform(test0).head()
1739
+ >>> result.prediction
1740
+ 0.0
1741
+ >>> result.probability
1742
+ DenseVector([1.0, 0.0])
1743
+ >>> result.rawPrediction
1744
+ DenseVector([1.0, 0.0])
1745
+ >>> result.leafId
1746
+ 0.0
1747
+ >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
1748
+ >>> model.transform(test1).head().prediction
1749
+ 1.0
1750
+ >>> dtc_path = temp_path + "/dtc"
1751
+ >>> dt.save(dtc_path)
1752
+ >>> dt2 = DecisionTreeClassifier.load(dtc_path)
1753
+ >>> dt2.getMaxDepth()
1754
+ 2
1755
+ >>> model_path = temp_path + "/dtc_model"
1756
+ >>> model.save(model_path)
1757
+ >>> model2 = DecisionTreeClassificationModel.load(model_path)
1758
+ >>> model.featureImportances == model2.featureImportances
1759
+ True
1760
+ >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
1761
+ True
1762
+ >>> df3 = spark.createDataFrame([
1763
+ ... (1.0, 0.2, Vectors.dense(1.0)),
1764
+ ... (1.0, 0.8, Vectors.dense(1.0)),
1765
+ ... (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
1766
+ >>> si3 = StringIndexer(inputCol="label", outputCol="indexed")
1767
+ >>> si_model3 = si3.fit(df3)
1768
+ >>> td3 = si_model3.transform(df3)
1769
+ >>> dt3 = DecisionTreeClassifier(maxDepth=2, weightCol="weight", labelCol="indexed")
1770
+ >>> model3 = dt3.fit(td3)
1771
+ >>> print(model3.toDebugString)
1772
+ DecisionTreeClassificationModel...depth=1, numNodes=3...
1773
+ """
1774
+
1775
+ _input_kwargs: Dict[str, Any]
1776
+
1777
+ @keyword_only
1778
+ def __init__(
1779
+ self,
1780
+ *,
1781
+ featuresCol: str = "features",
1782
+ labelCol: str = "label",
1783
+ predictionCol: str = "prediction",
1784
+ probabilityCol: str = "probability",
1785
+ rawPredictionCol: str = "rawPrediction",
1786
+ maxDepth: int = 5,
1787
+ maxBins: int = 32,
1788
+ minInstancesPerNode: int = 1,
1789
+ minInfoGain: float = 0.0,
1790
+ maxMemoryInMB: int = 256,
1791
+ cacheNodeIds: bool = False,
1792
+ checkpointInterval: int = 10,
1793
+ impurity: str = "gini",
1794
+ seed: Optional[int] = None,
1795
+ weightCol: Optional[str] = None,
1796
+ leafCol: str = "",
1797
+ minWeightFractionPerNode: float = 0.0,
1798
+ ):
1799
+ """
1800
+ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
1801
+ probabilityCol="probability", rawPredictionCol="rawPrediction", \
1802
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
1803
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
1804
+ seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
1805
+ """
1806
+ super(DecisionTreeClassifier, self).__init__()
1807
+ self._java_obj = self._new_java_obj(
1808
+ "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid
1809
+ )
1810
+ kwargs = self._input_kwargs
1811
+ self.setParams(**kwargs)
1812
+
1813
+ @keyword_only
1814
+ @since("1.4.0")
1815
+ def setParams(
1816
+ self,
1817
+ *,
1818
+ featuresCol: str = "features",
1819
+ labelCol: str = "label",
1820
+ predictionCol: str = "prediction",
1821
+ probabilityCol: str = "probability",
1822
+ rawPredictionCol: str = "rawPrediction",
1823
+ maxDepth: int = 5,
1824
+ maxBins: int = 32,
1825
+ minInstancesPerNode: int = 1,
1826
+ minInfoGain: float = 0.0,
1827
+ maxMemoryInMB: int = 256,
1828
+ cacheNodeIds: bool = False,
1829
+ checkpointInterval: int = 10,
1830
+ impurity: str = "gini",
1831
+ seed: Optional[int] = None,
1832
+ weightCol: Optional[str] = None,
1833
+ leafCol: str = "",
1834
+ minWeightFractionPerNode: float = 0.0,
1835
+ ) -> "DecisionTreeClassifier":
1836
+ """
1837
+ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
1838
+ probabilityCol="probability", rawPredictionCol="rawPrediction", \
1839
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
1840
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
1841
+ seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
1842
+ Sets params for the DecisionTreeClassifier.
1843
+ """
1844
+ kwargs = self._input_kwargs
1845
+ return self._set(**kwargs)
1846
+
1847
+ def _create_model(self, java_model: "JavaObject") -> "DecisionTreeClassificationModel":
1848
+ return DecisionTreeClassificationModel(java_model)
1849
+
1850
+ def setMaxDepth(self, value: int) -> "DecisionTreeClassifier":
1851
+ """
1852
+ Sets the value of :py:attr:`maxDepth`.
1853
+ """
1854
+ return self._set(maxDepth=value)
1855
+
1856
+ def setMaxBins(self, value: int) -> "DecisionTreeClassifier":
1857
+ """
1858
+ Sets the value of :py:attr:`maxBins`.
1859
+ """
1860
+ return self._set(maxBins=value)
1861
+
1862
+ def setMinInstancesPerNode(self, value: int) -> "DecisionTreeClassifier":
1863
+ """
1864
+ Sets the value of :py:attr:`minInstancesPerNode`.
1865
+ """
1866
+ return self._set(minInstancesPerNode=value)
1867
+
1868
+ @since("3.0.0")
1869
+ def setMinWeightFractionPerNode(self, value: float) -> "DecisionTreeClassifier":
1870
+ """
1871
+ Sets the value of :py:attr:`minWeightFractionPerNode`.
1872
+ """
1873
+ return self._set(minWeightFractionPerNode=value)
1874
+
1875
+ def setMinInfoGain(self, value: float) -> "DecisionTreeClassifier":
1876
+ """
1877
+ Sets the value of :py:attr:`minInfoGain`.
1878
+ """
1879
+ return self._set(minInfoGain=value)
1880
+
1881
+ def setMaxMemoryInMB(self, value: int) -> "DecisionTreeClassifier":
1882
+ """
1883
+ Sets the value of :py:attr:`maxMemoryInMB`.
1884
+ """
1885
+ return self._set(maxMemoryInMB=value)
1886
+
1887
+ def setCacheNodeIds(self, value: bool) -> "DecisionTreeClassifier":
1888
+ """
1889
+ Sets the value of :py:attr:`cacheNodeIds`.
1890
+ """
1891
+ return self._set(cacheNodeIds=value)
1892
+
1893
+ @since("1.4.0")
1894
+ def setImpurity(self, value: str) -> "DecisionTreeClassifier":
1895
+ """
1896
+ Sets the value of :py:attr:`impurity`.
1897
+ """
1898
+ return self._set(impurity=value)
1899
+
1900
+ @since("1.4.0")
1901
+ def setCheckpointInterval(self, value: int) -> "DecisionTreeClassifier":
1902
+ """
1903
+ Sets the value of :py:attr:`checkpointInterval`.
1904
+ """
1905
+ return self._set(checkpointInterval=value)
1906
+
1907
+ def setSeed(self, value: int) -> "DecisionTreeClassifier":
1908
+ """
1909
+ Sets the value of :py:attr:`seed`.
1910
+ """
1911
+ return self._set(seed=value)
1912
+
1913
+ @since("3.0.0")
1914
+ def setWeightCol(self, value: str) -> "DecisionTreeClassifier":
1915
+ """
1916
+ Sets the value of :py:attr:`weightCol`.
1917
+ """
1918
+ return self._set(weightCol=value)
1919
+
1920
+
1921
+ @inherit_doc
1922
+ class DecisionTreeClassificationModel(
1923
+ _DecisionTreeModel,
1924
+ _JavaProbabilisticClassificationModel[Vector],
1925
+ _DecisionTreeClassifierParams,
1926
+ JavaMLWritable,
1927
+ JavaMLReadable["DecisionTreeClassificationModel"],
1928
+ ):
1929
+ """
1930
+ Model fitted by DecisionTreeClassifier.
1931
+
1932
+ .. versionadded:: 1.4.0
1933
+ """
1934
+
1935
+ @property
1936
+ def featureImportances(self) -> Vector:
1937
+ """
1938
+ Estimate of the importance of each feature.
1939
+
1940
+ This generalizes the idea of "Gini" importance to other losses,
1941
+ following the explanation of Gini importance from "Random Forests" documentation
1942
+ by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
1943
+
1944
+ This feature importance is calculated as follows:
1945
+ - importance(feature j) = sum (over nodes which split on feature j) of the gain,
1946
+ where gain is scaled by the number of instances passing through the node
1947
+ - Normalize importances for tree to sum to 1.
1948
+
1949
+ .. versionadded:: 2.0.0
1950
+
1951
+ Notes
1952
+ -----
1953
+ Feature importance for single decision trees can have high variance due to
1954
+ correlated predictor variables. Consider using a :py:class:`RandomForestClassifier`
1955
+ to determine feature importance instead.
1956
+ """
1957
+ return self._call_java("featureImportances")
1958
+
1959
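A toy recomputation of the importance formula in the docstring above (plain Python; the gains and counts are made up):

    # importance(j) = sum over nodes splitting on feature j of
    # (gain at the node * instances reaching the node), then normalize.
    nodes = [
        {"feature": 0, "gain": 0.5, "count": 100},
        {"feature": 1, "gain": 0.2, "count": 40},
    ]
    raw = {}
    for n in nodes:
        raw[n["feature"]] = raw.get(n["feature"], 0.0) + n["gain"] * n["count"]
    total = sum(raw.values())                        # 50.0 + 8.0 = 58.0
    importances = {j: v / total for j, v in raw.items()}
    assert abs(sum(importances.values()) - 1.0) < 1e-9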
+
1960
+ @inherit_doc
1961
+ class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
1962
+ """
1963
+ Params for :py:class:`RandomForestClassifier` and :py:class:`RandomForestClassificationModel`.
1964
+ """
1965
+
1966
+ def __init__(self, *args: Any):
1967
+ super(_RandomForestClassifierParams, self).__init__(*args)
1968
+ self._setDefault(
1969
+ maxDepth=5,
1970
+ maxBins=32,
1971
+ minInstancesPerNode=1,
1972
+ minInfoGain=0.0,
1973
+ maxMemoryInMB=256,
1974
+ cacheNodeIds=False,
1975
+ checkpointInterval=10,
1976
+ impurity="gini",
1977
+ numTrees=20,
1978
+ featureSubsetStrategy="auto",
1979
+ subsamplingRate=1.0,
1980
+ leafCol="",
1981
+ minWeightFractionPerNode=0.0,
1982
+ bootstrap=True,
1983
+ )
1984
+
1985
+
1986
+ @inherit_doc
1987
+ class RandomForestClassifier(
1988
+ _JavaProbabilisticClassifier["RandomForestClassificationModel"],
1989
+ _RandomForestClassifierParams,
1990
+ JavaMLWritable,
1991
+ JavaMLReadable["RandomForestClassifier"],
1992
+ ):
1993
+ """
1994
+ `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
1995
+ learning algorithm for classification.
1996
+ It supports both binary and multiclass labels, as well as both continuous and categorical
1997
+ features.
1998
+
1999
+ .. versionadded:: 1.4.0
2000
+
2001
+ Examples
2002
+ --------
2003
+ >>> import numpy
2004
+ >>> from numpy import allclose
2005
+ >>> from pyspark.ml.linalg import Vectors
2006
+ >>> from pyspark.ml.feature import StringIndexer
2007
+ >>> df = spark.createDataFrame([
2008
+ ... (1.0, Vectors.dense(1.0)),
2009
+ ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
2010
+ >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
2011
+ >>> si_model = stringIndexer.fit(df)
2012
+ >>> td = si_model.transform(df)
2013
+ >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42,
2014
+ ... leafCol="leafId")
2015
+ >>> rf.getMinWeightFractionPerNode()
2016
+ 0.0
2017
+ >>> model = rf.fit(td)
2018
+ >>> model.getLabelCol()
2019
+ 'indexed'
2020
+ >>> model.setFeaturesCol("features")
2021
+ RandomForestClassificationModel...
2022
+ >>> model.setRawPredictionCol("newRawPrediction")
2023
+ RandomForestClassificationModel...
2024
+ >>> model.getBootstrap()
2025
+ True
2026
+ >>> model.getRawPredictionCol()
2027
+ 'newRawPrediction'
2028
+ >>> model.featureImportances
2029
+ SparseVector(1, {0: 1.0})
2030
+ >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
2031
+ True
2032
+ >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
2033
+ >>> model.predict(test0.head().features)
2034
+ 0.0
2035
+ >>> model.predictRaw(test0.head().features)
2036
+ DenseVector([2.0, 0.0])
2037
+ >>> model.predictProbability(test0.head().features)
2038
+ DenseVector([1.0, 0.0])
2039
+ >>> result = model.transform(test0).head()
2040
+ >>> result.prediction
2041
+ 0.0
2042
+ >>> numpy.argmax(result.probability)
2043
+ 0
2044
+ >>> numpy.argmax(result.newRawPrediction)
2045
+ 0
2046
+ >>> result.leafId
2047
+ DenseVector([0.0, 0.0, 0.0])
2048
+ >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
2049
+ >>> model.transform(test1).head().prediction
2050
+ 1.0
2051
+ >>> model.trees
2052
+ [DecisionTreeClassificationModel...depth=..., DecisionTreeClassificationModel...]
2053
+ >>> rfc_path = temp_path + "/rfc"
2054
+ >>> rf.save(rfc_path)
2055
+ >>> rf2 = RandomForestClassifier.load(rfc_path)
2056
+ >>> rf2.getNumTrees()
2057
+ 3
2058
+ >>> model_path = temp_path + "/rfc_model"
2059
+ >>> model.save(model_path)
2060
+ >>> model2 = RandomForestClassificationModel.load(model_path)
2061
+ >>> model.featureImportances == model2.featureImportances
2062
+ True
2063
+ >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
2064
+ True
2065
+ """
2066
+
2067
+ _input_kwargs: Dict[str, Any]
2068
+
2069
+ @keyword_only
2070
+ def __init__(
2071
+ self,
2072
+ *,
2073
+ featuresCol: str = "features",
2074
+ labelCol: str = "label",
2075
+ predictionCol: str = "prediction",
2076
+ probabilityCol: str = "probability",
2077
+ rawPredictionCol: str = "rawPrediction",
2078
+ maxDepth: int = 5,
2079
+ maxBins: int = 32,
2080
+ minInstancesPerNode: int = 1,
2081
+ minInfoGain: float = 0.0,
2082
+ maxMemoryInMB: int = 256,
2083
+ cacheNodeIds: bool = False,
2084
+ checkpointInterval: int = 10,
2085
+ impurity: str = "gini",
2086
+ numTrees: int = 20,
2087
+ featureSubsetStrategy: str = "auto",
2088
+ seed: Optional[int] = None,
2089
+ subsamplingRate: float = 1.0,
2090
+ leafCol: str = "",
2091
+ minWeightFractionPerNode: float = 0.0,
2092
+ weightCol: Optional[str] = None,
2093
+ bootstrap: Optional[bool] = True,
2094
+ ):
2095
+ """
2096
+ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
2097
+ probabilityCol="probability", rawPredictionCol="rawPrediction", \
2098
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
2099
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
2100
+ numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
2101
+ leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)
2102
+ """
2103
+ super(RandomForestClassifier, self).__init__()
2104
+ self._java_obj = self._new_java_obj(
2105
+ "org.apache.spark.ml.classification.RandomForestClassifier", self.uid
2106
+ )
2107
+ kwargs = self._input_kwargs
2108
+ self.setParams(**kwargs)
2109
+
2110
+ @keyword_only
2111
+ @since("1.4.0")
2112
+ def setParams(
2113
+ self,
2114
+ *,
2115
+ featuresCol: str = "features",
2116
+ labelCol: str = "label",
2117
+ predictionCol: str = "prediction",
2118
+ probabilityCol: str = "probability",
2119
+ rawPredictionCol: str = "rawPrediction",
2120
+ maxDepth: int = 5,
2121
+ maxBins: int = 32,
2122
+ minInstancesPerNode: int = 1,
2123
+ minInfoGain: float = 0.0,
2124
+ maxMemoryInMB: int = 256,
2125
+ cacheNodeIds: bool = False,
2126
+ checkpointInterval: int = 10,
2127
+ impurity: str = "gini",
2128
+ numTrees: int = 20,
2129
+ featureSubsetStrategy: str = "auto",
2130
+ seed: Optional[int] = None,
2131
+ subsamplingRate: float = 1.0,
2132
+ leafCol: str = "",
2133
+ minWeightFractionPerNode: float = 0.0,
2134
+ weightCol: Optional[str] = None,
2135
+ bootstrap: Optional[bool] = True,
2136
+ ) -> "RandomForestClassifier":
2137
+ """
2138
+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2139
+ probabilityCol="probability", rawPredictionCol="rawPrediction", \
2140
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
2141
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
2142
+ impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \
2143
+ leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)
2144
+ Sets params for random forest classification.
2145
+ """
2146
+ kwargs = self._input_kwargs
2147
+ return self._set(**kwargs)
2148
+
2149
+ def _create_model(self, java_model: "JavaObject") -> "RandomForestClassificationModel":
2150
+ return RandomForestClassificationModel(java_model)
2151
+
2152
+ def setMaxDepth(self, value: int) -> "RandomForestClassifier":
2153
+ """
2154
+ Sets the value of :py:attr:`maxDepth`.
2155
+ """
2156
+ return self._set(maxDepth=value)
2157
+
2158
+ def setMaxBins(self, value: int) -> "RandomForestClassifier":
2159
+ """
2160
+ Sets the value of :py:attr:`maxBins`.
2161
+ """
2162
+ return self._set(maxBins=value)
2163
+
2164
+ def setMinInstancesPerNode(self, value: int) -> "RandomForestClassifier":
2165
+ """
2166
+ Sets the value of :py:attr:`minInstancesPerNode`.
2167
+ """
2168
+ return self._set(minInstancesPerNode=value)
2169
+
2170
+ def setMinInfoGain(self, value: float) -> "RandomForestClassifier":
2171
+ """
2172
+ Sets the value of :py:attr:`minInfoGain`.
2173
+ """
2174
+ return self._set(minInfoGain=value)
2175
+
2176
+ def setMaxMemoryInMB(self, value: int) -> "RandomForestClassifier":
2177
+ """
2178
+ Sets the value of :py:attr:`maxMemoryInMB`.
2179
+ """
2180
+ return self._set(maxMemoryInMB=value)
2181
+
2182
+ def setCacheNodeIds(self, value: bool) -> "RandomForestClassifier":
2183
+ """
2184
+ Sets the value of :py:attr:`cacheNodeIds`.
2185
+ """
2186
+ return self._set(cacheNodeIds=value)
2187
+
2188
+ @since("1.4.0")
2189
+ def setImpurity(self, value: str) -> "RandomForestClassifier":
2190
+ """
2191
+ Sets the value of :py:attr:`impurity`.
2192
+ """
2193
+ return self._set(impurity=value)
2194
+
2195
+ @since("1.4.0")
2196
+ def setNumTrees(self, value: int) -> "RandomForestClassifier":
2197
+ """
2198
+ Sets the value of :py:attr:`numTrees`.
2199
+ """
2200
+ return self._set(numTrees=value)
2201
+
2202
+ @since("3.0.0")
2203
+ def setBootstrap(self, value: bool) -> "RandomForestClassifier":
2204
+ """
2205
+ Sets the value of :py:attr:`bootstrap`.
2206
+ """
2207
+ return self._set(bootstrap=value)
2208
+
2209
+ @since("1.4.0")
2210
+ def setSubsamplingRate(self, value: float) -> "RandomForestClassifier":
2211
+ """
2212
+ Sets the value of :py:attr:`subsamplingRate`.
2213
+ """
2214
+ return self._set(subsamplingRate=value)
2215
+
2216
+ @since("2.4.0")
2217
+ def setFeatureSubsetStrategy(self, value: str) -> "RandomForestClassifier":
2218
+ """
2219
+ Sets the value of :py:attr:`featureSubsetStrategy`.
2220
+ """
2221
+ return self._set(featureSubsetStrategy=value)
2222
+
2223
+ def setSeed(self, value: int) -> "RandomForestClassifier":
2224
+ """
2225
+ Sets the value of :py:attr:`seed`.
2226
+ """
2227
+ return self._set(seed=value)
2228
+
2229
+ def setCheckpointInterval(self, value: int) -> "RandomForestClassifier":
2230
+ """
2231
+ Sets the value of :py:attr:`checkpointInterval`.
2232
+ """
2233
+ return self._set(checkpointInterval=value)
2234
+
2235
+ @since("3.0.0")
2236
+ def setWeightCol(self, value: str) -> "RandomForestClassifier":
2237
+ """
2238
+ Sets the value of :py:attr:`weightCol`.
2239
+ """
2240
+ return self._set(weightCol=value)
2241
+
2242
+ @since("3.0.0")
2243
+ def setMinWeightFractionPerNode(self, value: float) -> "RandomForestClassifier":
2244
+ """
2245
+ Sets the value of :py:attr:`minWeightFractionPerNode`.
2246
+ """
2247
+ return self._set(minWeightFractionPerNode=value)
2248
+
2249
+
2250
+ class RandomForestClassificationModel(
2251
+ _TreeEnsembleModel,
2252
+ _JavaProbabilisticClassificationModel[Vector],
2253
+ _RandomForestClassifierParams,
2254
+ JavaMLWritable,
2255
+ JavaMLReadable["RandomForestClassificationModel"],
2256
+ HasTrainingSummary["RandomForestClassificationTrainingSummary"],
2257
+ ):
2258
+ """
2259
+ Model fitted by RandomForestClassifier.
2260
+
2261
+ .. versionadded:: 1.4.0
2262
+ """
2263
+
2264
+ @property
2265
+ def featureImportances(self) -> Vector:
2266
+ """
2267
+ Estimate of the importance of each feature.
2268
+
2269
+ Each feature's importance is the average of its importance across all trees in the ensemble.
2270
+ The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
2271
+ (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
2272
+ and follows the implementation from scikit-learn.
2273
+
2274
+ .. versionadded:: 2.0.0
2275
+
2276
+ See Also
2277
+ --------
2278
+ DecisionTreeClassificationModel.featureImportances
2279
+ """
2280
+ return self._call_java("featureImportances")
2281
+
2282
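The ensemble averaging described above, as a toy computation (plain Python, equal tree weights; the numbers are made up):

    # Average each feature's importance across trees, then renormalize.
    tree_importances = [[0.8, 0.2], [0.6, 0.4], [1.0, 0.0]]  # one row per tree
    avg = [sum(col) / len(tree_importances) for col in zip(*tree_importances)]
    importances = [v / sum(avg) for v in avg]                # [0.8, 0.2]
    assert abs(sum(importances) - 1.0) < 1e-9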
+ @property
2283
+ @since("2.0.0")
2284
+ def trees(self) -> List[DecisionTreeClassificationModel]:
2285
+ """Trees in this ensemble. Warning: These have null parent Estimators."""
2286
+ return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
2287
+
2288
+ @property
2289
+ @since("3.1.0")
2290
+ def summary(self) -> "RandomForestClassificationTrainingSummary":
2291
+ """
2292
+ Gets summary (accuracy/precision/recall, objective history, total iterations) of model
2293
+ trained on the training set. An exception is thrown if `trainingSummary is None`.
2294
+ """
2295
+ if self.hasSummary:
2296
+ if self.numClasses <= 2:
2297
+ return BinaryRandomForestClassificationTrainingSummary(
2298
+ super(RandomForestClassificationModel, self).summary
2299
+ )
2300
+ else:
2301
+ return RandomForestClassificationTrainingSummary(
2302
+ super(RandomForestClassificationModel, self).summary
2303
+ )
2304
+ else:
2305
+ raise RuntimeError(
2306
+ "No training summary available for this %s" % self.__class__.__name__
2307
+ )
2308
+
2309
+ def evaluate(
2310
+ self, dataset: DataFrame
2311
+ ) -> Union["BinaryRandomForestClassificationSummary", "RandomForestClassificationSummary"]:
2312
+ """
2313
+ Evaluates the model on a test dataset.
2314
+
2315
+ .. versionadded:: 3.1.0
2316
+
2317
+ Parameters
2318
+ ----------
2319
+ dataset : :py:class:`pyspark.sql.DataFrame`
2320
+ Test dataset to evaluate model on.
2321
+ """
2322
+ if not isinstance(dataset, DataFrame):
2323
+ raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
2324
+ java_rf_summary = self._call_java("evaluate", dataset)
2325
+ if self.numClasses <= 2:
2326
+ return BinaryRandomForestClassificationSummary(java_rf_summary)
2327
+ else:
2328
+ return RandomForestClassificationSummary(java_rf_summary)
2329
+
2330
+
2331
+ class RandomForestClassificationSummary(_ClassificationSummary):
2332
+ """
2333
+ Abstraction for RandomForestClassification Results for a given model.
2334
+
2335
+ .. versionadded:: 3.1.0
2336
+ """
2337
+
2338
+ pass
2339
+
2340
+
2341
+ @inherit_doc
2342
+ class RandomForestClassificationTrainingSummary(
2343
+ RandomForestClassificationSummary, _TrainingSummary
2344
+ ):
2345
+ """
2346
+ Abstraction for RandomForestClassification Training results.
2347
+
2348
+ .. versionadded:: 3.1.0
2349
+ """
2350
+
2351
+ pass
2352
+
2353
+
2354
+ @inherit_doc
2355
+ class BinaryRandomForestClassificationSummary(_BinaryClassificationSummary):
2356
+ """
2357
+ BinaryRandomForestClassification results for a given model.
2358
+
2359
+ .. versionadded:: 3.1.0
2360
+ """
2361
+
2362
+ pass
2363
+
2364
+
2365
+ @inherit_doc
2366
+ class BinaryRandomForestClassificationTrainingSummary(
2367
+ BinaryRandomForestClassificationSummary, RandomForestClassificationTrainingSummary
2368
+ ):
2369
+ """
2370
+ BinaryRandomForestClassification training results for a given model.
2371
+
2372
+ .. versionadded:: 3.1.0
2373
+ """
2374
+
2375
+ pass
2376
+
2377
+
2378
+ class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
2379
+ """
2380
+ Params for :py:class:`GBTClassifier` and :py:class:`GBTClassifierModel`.
2381
+
2382
+ .. versionadded:: 3.0.0
2383
+ """
2384
+
2385
+ supportedLossTypes: List[str] = ["logistic"]
2386
+
2387
+ lossType: Param[str] = Param(
2388
+ Params._dummy(),
2389
+ "lossType",
2390
+ "Loss function which GBT tries to minimize (case-insensitive). "
2391
+ + "Supported options: "
2392
+ + ", ".join(supportedLossTypes),
2393
+ typeConverter=TypeConverters.toString,
2394
+ )
2395
+
2396
+ def __init__(self, *args: Any):
2397
+ super(_GBTClassifierParams, self).__init__(*args)
2398
+ self._setDefault(
2399
+ maxDepth=5,
2400
+ maxBins=32,
2401
+ minInstancesPerNode=1,
2402
+ minInfoGain=0.0,
2403
+ maxMemoryInMB=256,
2404
+ cacheNodeIds=False,
2405
+ checkpointInterval=10,
2406
+ lossType="logistic",
2407
+ maxIter=20,
2408
+ stepSize=0.1,
2409
+ subsamplingRate=1.0,
2410
+ impurity="variance",
2411
+ featureSubsetStrategy="all",
2412
+ validationTol=0.01,
2413
+ leafCol="",
2414
+ minWeightFractionPerNode=0.0,
2415
+ )
2416
+
2417
+ @since("1.4.0")
2418
+ def getLossType(self) -> str:
2419
+ """
2420
+ Gets the value of lossType or its default value.
2421
+ """
2422
+ return self.getOrDefault(self.lossType)
2423
+
2424
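A small hedged sketch of the lossType Param in use (assumes an active SparkSession, since the estimator is JVM-backed, and the stock pyspark import path for illustration):

    from pyspark.ml.classification import GBTClassifier

    gbt = GBTClassifier(maxIter=5)
    assert gbt.getLossType() == "logistic"  # the default, and currently the
                                            # only entry in supportedLossTypes
    gbt.setLossType("logistic")             # an unsupported value would be
                                            # rejected downstream at fit time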
+
2425
+ @inherit_doc
2426
+ class GBTClassifier(
2427
+ _JavaProbabilisticClassifier["GBTClassificationModel"],
2428
+ _GBTClassifierParams,
2429
+ JavaMLWritable,
2430
+ JavaMLReadable["GBTClassifier"],
2431
+ ):
2432
+ """
2433
+ `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
2434
+ learning algorithm for classification.
2435
+ It supports binary labels, as well as both continuous and categorical features.
2436
+
2437
+ .. versionadded:: 1.4.0
2438
+
2439
+ Notes
2440
+ -----
2441
+ Multiclass labels are not currently supported.
2442
+
2443
+ The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
2444
+
2445
+ Gradient Boosting vs. TreeBoost:
2446
+
2447
+ - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
2448
+ - Both algorithms learn tree ensembles by minimizing loss functions.
2449
+ - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
2450
+ based on the loss function, whereas the original gradient boosting method does not.
2451
+ - We expect to implement TreeBoost in the future:
2452
+ `SPARK-4240 <https://issues.apache.org/jira/browse/SPARK-4240>`_
2453
+
2454
+ Examples
2455
+ --------
2456
+ >>> from numpy import allclose
2457
+ >>> from pyspark.ml.linalg import Vectors
2458
+ >>> from pyspark.ml.feature import StringIndexer
2459
+ >>> df = spark.createDataFrame([
2460
+ ... (1.0, Vectors.dense(1.0)),
2461
+ ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
2462
+ >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
2463
+ >>> si_model = stringIndexer.fit(df)
2464
+ >>> td = si_model.transform(df)
2465
+ >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42,
2466
+ ... leafCol="leafId")
2467
+ >>> gbt.setMaxIter(5)
2468
+ GBTClassifier...
2469
+ >>> gbt.setMinWeightFractionPerNode(0.049)
2470
+ GBTClassifier...
2471
+ >>> gbt.getMaxIter()
2472
+ 5
2473
+ >>> gbt.getFeatureSubsetStrategy()
2474
+ 'all'
2475
+ >>> model = gbt.fit(td)
2476
+ >>> model.getLabelCol()
2477
+ 'indexed'
2478
+ >>> model.setFeaturesCol("features")
2479
+ GBTClassificationModel...
2480
+ >>> model.setThresholds([0.3, 0.7])
2481
+ GBTClassificationModel...
2482
+ >>> model.getThresholds()
2483
+ [0.3, 0.7]
2484
+ >>> model.featureImportances
2485
+ SparseVector(1, {0: 1.0})
2486
+ >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
2487
+ True
2488
+ >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
2489
+ >>> model.predict(test0.head().features)
2490
+ 0.0
2491
+ >>> model.predictRaw(test0.head().features)
2492
+ DenseVector([1.1697, -1.1697])
2493
+ >>> model.predictProbability(test0.head().features)
2494
+ DenseVector([0.9121, 0.0879])
2495
+ >>> result = model.transform(test0).head()
2496
+ >>> result.prediction
2497
+ 0.0
2498
+ >>> result.leafId
2499
+ DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
2500
+ >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
2501
+ >>> model.transform(test1).head().prediction
2502
+ 1.0
2503
+ >>> model.totalNumNodes
2504
+ 15
2505
+ >>> print(model.toDebugString)
2506
+ GBTClassificationModel...numTrees=5...
2507
+ >>> gbtc_path = temp_path + "gbtc"
2508
+ >>> gbt.save(gbtc_path)
2509
+ >>> gbt2 = GBTClassifier.load(gbtc_path)
2510
+ >>> gbt2.getMaxDepth()
2511
+ 2
2512
+ >>> model_path = temp_path + "gbtc_model"
2513
+ >>> model.save(model_path)
2514
+ >>> model2 = GBTClassificationModel.load(model_path)
2515
+ >>> model.featureImportances == model2.featureImportances
2516
+ True
2517
+ >>> model.treeWeights == model2.treeWeights
2518
+ True
2519
+ >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
2520
+ True
2521
+ >>> model.trees
2522
+ [DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
2523
+ >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
2524
+ ... ["indexed", "features"])
2525
+ >>> model.evaluateEachIteration(validation)
2526
+ [0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
2527
+ >>> model.numClasses
2528
+ 2
2529
+ >>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
2530
+ >>> gbt.getValidationIndicatorCol()
2531
+ 'validationIndicator'
2532
+ >>> gbt.getValidationTol()
2533
+ 0.01
2534
+ """
2535
+
2536
+ _input_kwargs: Dict[str, Any]
2537
+
2538
+ @keyword_only
2539
+ def __init__(
2540
+ self,
2541
+ *,
2542
+ featuresCol: str = "features",
2543
+ labelCol: str = "label",
2544
+ predictionCol: str = "prediction",
2545
+ maxDepth: int = 5,
2546
+ maxBins: int = 32,
2547
+ minInstancesPerNode: int = 1,
2548
+ minInfoGain: float = 0.0,
2549
+ maxMemoryInMB: int = 256,
2550
+ cacheNodeIds: bool = False,
2551
+ checkpointInterval: int = 10,
2552
+ lossType: str = "logistic",
2553
+ maxIter: int = 20,
2554
+ stepSize: float = 0.1,
2555
+ seed: Optional[int] = None,
2556
+ subsamplingRate: float = 1.0,
2557
+ impurity: str = "variance",
2558
+ featureSubsetStrategy: str = "all",
2559
+ validationTol: float = 0.01,
2560
+ validationIndicatorCol: Optional[str] = None,
2561
+ leafCol: str = "",
2562
+ minWeightFractionPerNode: float = 0.0,
2563
+ weightCol: Optional[str] = None,
2564
+ ):
2565
+ """
2566
+ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
2567
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
2568
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
2569
+ lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
+ validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
+ weightCol=None)
+ """
+ super(GBTClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.GBTClassifier", self.uid
+ )
+ kwargs = self._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("1.4.0")
+ def setParams(
+ self,
+ *,
+ featuresCol: str = "features",
+ labelCol: str = "label",
+ predictionCol: str = "prediction",
+ maxDepth: int = 5,
+ maxBins: int = 32,
+ minInstancesPerNode: int = 1,
+ minInfoGain: float = 0.0,
+ maxMemoryInMB: int = 256,
+ cacheNodeIds: bool = False,
+ checkpointInterval: int = 10,
+ lossType: str = "logistic",
+ maxIter: int = 20,
+ stepSize: float = 0.1,
+ seed: Optional[int] = None,
+ subsamplingRate: float = 1.0,
+ impurity: str = "variance",
+ featureSubsetStrategy: str = "all",
+ validationTol: float = 0.01,
+ validationIndicatorCol: Optional[str] = None,
+ leafCol: str = "",
+ minWeightFractionPerNode: float = 0.0,
+ weightCol: Optional[str] = None,
+ ) -> "GBTClassifier":
+ """
+ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
+ lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
+ validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
+ weightCol=None)
+ Sets params for Gradient Boosted Tree Classification.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model: "JavaObject") -> "GBTClassificationModel":
+ return GBTClassificationModel(java_model)
+
+ def setMaxDepth(self, value: int) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`maxDepth`.
+ """
+ return self._set(maxDepth=value)
+
+ def setMaxBins(self, value: int) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`maxBins`.
+ """
+ return self._set(maxBins=value)
+
+ def setMinInstancesPerNode(self, value: int) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`minInstancesPerNode`.
+ """
+ return self._set(minInstancesPerNode=value)
+
+ def setMinInfoGain(self, value: float) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`minInfoGain`.
+ """
+ return self._set(minInfoGain=value)
+
+ def setMaxMemoryInMB(self, value: int) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`maxMemoryInMB`.
+ """
+ return self._set(maxMemoryInMB=value)
+
+ def setCacheNodeIds(self, value: bool) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`cacheNodeIds`.
+ """
+ return self._set(cacheNodeIds=value)
+
+ @since("1.4.0")
+ def setImpurity(self, value: str) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`impurity`.
+ """
+ return self._set(impurity=value)
+
+ @since("1.4.0")
+ def setLossType(self, value: str) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`lossType`.
+ """
+ return self._set(lossType=value)
+
+ @since("1.4.0")
+ def setSubsamplingRate(self, value: float) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`subsamplingRate`.
+ """
+ return self._set(subsamplingRate=value)
+
+ @since("2.4.0")
+ def setFeatureSubsetStrategy(self, value: str) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`featureSubsetStrategy`.
+ """
+ return self._set(featureSubsetStrategy=value)
+
+ @since("3.0.0")
+ def setValidationIndicatorCol(self, value: str) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`validationIndicatorCol`.
+ """
+ return self._set(validationIndicatorCol=value)
+
+ @since("1.4.0")
+ def setMaxIter(self, value: int) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("1.4.0")
+ def setCheckpointInterval(self, value: int) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`checkpointInterval`.
+ """
+ return self._set(checkpointInterval=value)
+
+ @since("1.4.0")
+ def setSeed(self, value: int) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ @since("1.4.0")
+ def setStepSize(self, value: float) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`stepSize`.
+ """
+ return self._set(stepSize=value)
+
+ @since("3.0.0")
+ def setWeightCol(self, value: str) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
+ @since("3.0.0")
+ def setMinWeightFractionPerNode(self, value: float) -> "GBTClassifier":
+ """
+ Sets the value of :py:attr:`minWeightFractionPerNode`.
+ """
+ return self._set(minWeightFractionPerNode=value)
+
+
+ class GBTClassificationModel(
+ _TreeEnsembleModel,
+ _JavaProbabilisticClassificationModel[Vector],
+ _GBTClassifierParams,
+ JavaMLWritable,
+ JavaMLReadable["GBTClassificationModel"],
+ ):
+ """
+ Model fitted by GBTClassifier.
+
+ .. versionadded:: 1.4.0
+ """
+
+ @property
+ def featureImportances(self) -> Vector:
+ """
+ Estimate of the importance of each feature.
+
+ Each feature's importance is the average of its importance across all trees in the ensemble.
+ The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ and follows the implementation from scikit-learn.
+
+ .. versionadded:: 2.0.0
+
+ See Also
+ --------
+ DecisionTreeClassificationModel.featureImportances
+ """
+ return self._call_java("featureImportances")
+
+ @property
+ @since("2.0.0")
+ def trees(self) -> List[DecisionTreeRegressionModel]:
+ """Trees in this ensemble. Warning: These have null parent Estimators."""
+ return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+
+ def evaluateEachIteration(self, dataset: DataFrame) -> List[float]:
+ """
+ Method to compute error or loss for every iteration of gradient boosting.
+
+ .. versionadded:: 2.4.0
+
+ Parameters
+ ----------
+ dataset : :py:class:`pyspark.sql.DataFrame`
+ Test dataset to evaluate model on.
+ """
+ return self._call_java("evaluateEachIteration", dataset)
+
+
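+ # The list returned by evaluateEachIteration above holds one loss value per
+ # boosting iteration, which makes it easy to check how many trees actually help
+ # on held-out data. A minimal sketch (commented out; assumes a live SparkSession
+ # named `spark`; the DataFrame and variable names are illustrative, not part of
+ # this module):
+ #
+ #   from pyspark.ml.linalg import Vectors
+ #   train = spark.createDataFrame(
+ #       [(0.0, Vectors.dense(0.0)), (1.0, Vectors.dense(1.0))], ["label", "features"])
+ #   gbt_model = GBTClassifier(maxIter=10).fit(train)
+ #   losses = gbt_model.evaluateEachIteration(train)  # one loss per iteration
+ #   best_num_trees = losses.index(min(losses)) + 1   # smallest loss wins
+
+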
+ class _NaiveBayesParams(_PredictorParams, HasWeightCol):
+ """
+ Params for :py:class:`NaiveBayes` and :py:class:`NaiveBayesModel`.
+
+ .. versionadded:: 3.0.0
+ """
+
+ smoothing: Param[float] = Param(
+ Params._dummy(),
+ "smoothing",
+ "The smoothing parameter, should be >= 0, " + "default is 1.0",
+ typeConverter=TypeConverters.toFloat,
+ )
+ modelType: Param[str] = Param(
+ Params._dummy(),
+ "modelType",
+ "The model type which is a string "
+ + "(case-sensitive). Supported options: multinomial (default), bernoulli "
+ + "and gaussian.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ def __init__(self, *args: Any):
+ super(_NaiveBayesParams, self).__init__(*args)
+ self._setDefault(smoothing=1.0, modelType="multinomial")
+
+ @since("1.5.0")
+ def getSmoothing(self) -> float:
+ """
+ Gets the value of smoothing or its default value.
+ """
+ return self.getOrDefault(self.smoothing)
+
+ @since("1.5.0")
+ def getModelType(self) -> str:
+ """
+ Gets the value of modelType or its default value.
+ """
+ return self.getOrDefault(self.modelType)
+
+
+ @inherit_doc
+ class NaiveBayes(
+ _JavaProbabilisticClassifier["NaiveBayesModel"],
+ _NaiveBayesParams,
+ HasThresholds,
+ HasWeightCol,
+ JavaMLWritable,
+ JavaMLReadable["NaiveBayes"],
+ ):
+ """
+ Naive Bayes Classifiers.
+ It supports both Multinomial and Bernoulli NB. `Multinomial NB \
+ <http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html>`_
+ can handle finitely supported discrete data. For example, by converting documents into
+ TF-IDF vectors, it can be used for document classification. By making every vector a
+ binary (0/1) data, it can also be used as `Bernoulli NB \
+ <http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html>`_.
+
+ The input feature values for Multinomial NB and Bernoulli NB must be nonnegative.
+ Since 3.0.0, it supports Complement NB which is an adaptation of the Multinomial NB.
+ Specifically, Complement NB uses statistics from the complement of each class to compute
+ the model's coefficients. The inventors of Complement NB show empirically that the parameter
+ estimates for CNB are more stable than those for Multinomial NB. Like Multinomial NB, the
+ input feature values for Complement NB must be nonnegative.
+ Since 3.0.0, it also supports `Gaussian NB \
+ <https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Gaussian_naive_Bayes>`_
+ which can handle continuous data.
+
+ .. versionadded:: 1.5.0
+
+ Examples
+ --------
+ >>> from pyspark.sql import Row
+ >>> from pyspark.ml.linalg import Vectors
+ >>> df = spark.createDataFrame([
+ ... Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
+ ... Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
+ ... Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))])
+ >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
+ >>> model = nb.fit(df)
+ >>> model.setFeaturesCol("features")
+ NaiveBayesModel...
+ >>> model.getSmoothing()
+ 1.0
+ >>> model.pi
+ DenseVector([-0.81..., -0.58...])
+ >>> model.theta
+ DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)
+ >>> model.sigma
+ DenseMatrix(0, 0, [...], ...)
+ >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
+ >>> model.predict(test0.head().features)
+ 1.0
+ >>> model.predictRaw(test0.head().features)
+ DenseVector([-1.72..., -0.99...])
+ >>> model.predictProbability(test0.head().features)
+ DenseVector([0.32..., 0.67...])
+ >>> result = model.transform(test0).head()
+ >>> result.prediction
+ 1.0
+ >>> result.probability
+ DenseVector([0.32..., 0.67...])
+ >>> result.rawPrediction
+ DenseVector([-1.72..., -0.99...])
+ >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
+ >>> model.transform(test1).head().prediction
+ 1.0
+ >>> nb_path = temp_path + "/nb"
+ >>> nb.save(nb_path)
+ >>> nb2 = NaiveBayes.load(nb_path)
+ >>> nb2.getSmoothing()
+ 1.0
+ >>> model_path = temp_path + "/nb_model"
+ >>> model.save(model_path)
+ >>> model2 = NaiveBayesModel.load(model_path)
+ >>> model.pi == model2.pi
+ True
+ >>> model.theta == model2.theta
+ True
+ >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
+ True
+ >>> nb = nb.setThresholds([0.01, 10.00])
+ >>> model3 = nb.fit(df)
+ >>> result = model3.transform(test0).head()
+ >>> result.prediction
+ 0.0
+ >>> nb3 = NaiveBayes().setModelType("gaussian")
+ >>> model4 = nb3.fit(df)
+ >>> model4.getModelType()
+ 'gaussian'
+ >>> model4.sigma
+ DenseMatrix(2, 2, [0.0, 0.25, 0.0, 0.0], 1)
+ >>> nb5 = NaiveBayes(smoothing=1.0, modelType="complement", weightCol="weight")
+ >>> model5 = nb5.fit(df)
+ >>> model5.getModelType()
+ 'complement'
+ >>> model5.theta
+ DenseMatrix(2, 2, [...], 1)
+ >>> model5.sigma
+ DenseMatrix(0, 0, [...], ...)
+ """
+
+ _input_kwargs: Dict[str, Any]
+
+ @keyword_only
+ def __init__(
+ self,
+ *,
+ featuresCol: str = "features",
+ labelCol: str = "label",
+ predictionCol: str = "prediction",
+ probabilityCol: str = "probability",
+ rawPredictionCol: str = "rawPrediction",
+ smoothing: float = 1.0,
+ modelType: str = "multinomial",
+ thresholds: Optional[List[float]] = None,
+ weightCol: Optional[str] = None,
+ ):
+ """
+ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
+ modelType="multinomial", thresholds=None, weightCol=None)
+ """
+ super(NaiveBayes, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.NaiveBayes", self.uid
+ )
+ kwargs = self._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("1.5.0")
+ def setParams(
+ self,
+ *,
+ featuresCol: str = "features",
+ labelCol: str = "label",
+ predictionCol: str = "prediction",
+ probabilityCol: str = "probability",
+ rawPredictionCol: str = "rawPrediction",
+ smoothing: float = 1.0,
+ modelType: str = "multinomial",
+ thresholds: Optional[List[float]] = None,
+ weightCol: Optional[str] = None,
+ ) -> "NaiveBayes":
+ """
+ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
+ modelType="multinomial", thresholds=None, weightCol=None)
+ Sets params for Naive Bayes.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model: "JavaObject") -> "NaiveBayesModel":
+ return NaiveBayesModel(java_model)
+
+ @since("1.5.0")
+ def setSmoothing(self, value: float) -> "NaiveBayes":
+ """
+ Sets the value of :py:attr:`smoothing`.
+ """
+ return self._set(smoothing=value)
+
+ @since("1.5.0")
+ def setModelType(self, value: str) -> "NaiveBayes":
+ """
+ Sets the value of :py:attr:`modelType`.
+ """
+ return self._set(modelType=value)
+
+ def setWeightCol(self, value: str) -> "NaiveBayes":
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
+
+ class NaiveBayesModel(
+ _JavaProbabilisticClassificationModel[Vector],
+ _NaiveBayesParams,
+ JavaMLWritable,
+ JavaMLReadable["NaiveBayesModel"],
+ ):
+ """
+ Model fitted by NaiveBayes.
+
+ .. versionadded:: 1.5.0
+ """
+
+ @property
+ @since("2.0.0")
+ def pi(self) -> Vector:
+ """
+ log of class priors.
+ """
+ return self._call_java("pi")
+
+ @property
+ @since("2.0.0")
+ def theta(self) -> Matrix:
+ """
+ log of class conditional probabilities.
+ """
+ return self._call_java("theta")
+
+ @property
+ @since("3.0.0")
+ def sigma(self) -> Matrix:
+ """
+ variance of each feature.
+ """
+ return self._call_java("sigma")
+
+
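+ # `pi` and `theta` above are stored in log space; exponentiating them recovers
+ # the class priors and, for multinomial models, the class-conditional
+ # probabilities. A minimal sketch (commented out; `train_df` and the numpy usage
+ # are illustrative assumptions, not part of this module):
+ #
+ #   import numpy as np
+ #   nb_model = NaiveBayes(modelType="multinomial").fit(train_df)
+ #   priors = np.exp(nb_model.pi.toArray())   # P(class); sums to 1
+ #   cond = np.exp(nb_model.theta.toArray())  # P(feature | class); one row per class
+
+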
+ class _MultilayerPerceptronParams(
+ _ProbabilisticClassifierParams,
+ HasSeed,
+ HasMaxIter,
+ HasTol,
+ HasStepSize,
+ HasSolver,
+ HasBlockSize,
+ ):
+ """
+ Params for :py:class:`MultilayerPerceptronClassifier`.
+
+ .. versionadded:: 3.0.0
+ """
+
+ layers: Param[List[int]] = Param(
+ Params._dummy(),
+ "layers",
+ "Sizes of layers from input layer to output layer "
+ + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 "
+ + "neurons and output layer of 10 neurons.",
+ typeConverter=TypeConverters.toListInt,
+ )
+ solver: Param[str] = Param(
+ Params._dummy(),
+ "solver",
+ "The solver algorithm for optimization. Supported " + "options: l-bfgs, gd.",
+ typeConverter=TypeConverters.toString,
+ )
+ initialWeights: Param[Vector] = Param(
+ Params._dummy(),
+ "initialWeights",
+ "The initial weights of the model.",
+ typeConverter=TypeConverters.toVector,
+ )
+
+ def __init__(self, *args: Any):
+ super(_MultilayerPerceptronParams, self).__init__(*args)
+ self._setDefault(maxIter=100, tol=1e-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
+
+ @since("1.6.0")
+ def getLayers(self) -> List[int]:
+ """
+ Gets the value of layers or its default value.
+ """
+ return self.getOrDefault(self.layers)
+
+ @since("2.0.0")
+ def getInitialWeights(self) -> Vector:
+ """
+ Gets the value of initialWeights or its default value.
+ """
+ return self.getOrDefault(self.initialWeights)
+
+
+ @inherit_doc
+ class MultilayerPerceptronClassifier(
+ _JavaProbabilisticClassifier["MultilayerPerceptronClassificationModel"],
+ _MultilayerPerceptronParams,
+ JavaMLWritable,
+ JavaMLReadable["MultilayerPerceptronClassifier"],
+ ):
+ """
+ Classifier trainer based on the Multilayer Perceptron.
+ Each layer has sigmoid activation function, output layer has softmax.
+ Number of inputs has to be equal to the size of feature vectors.
+ Number of outputs has to be equal to the total number of labels.
+
+ .. versionadded:: 1.6.0
+
+ Examples
+ --------
+ >>> from pyspark.ml.linalg import Vectors
+ >>> df = spark.createDataFrame([
+ ... (0.0, Vectors.dense([0.0, 0.0])),
+ ... (1.0, Vectors.dense([0.0, 1.0])),
+ ... (1.0, Vectors.dense([1.0, 0.0])),
+ ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
+ >>> mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)
+ >>> mlp.setMaxIter(100)
+ MultilayerPerceptronClassifier...
+ >>> mlp.getMaxIter()
+ 100
+ >>> mlp.getBlockSize()
+ 128
+ >>> mlp.setBlockSize(1)
+ MultilayerPerceptronClassifier...
+ >>> mlp.getBlockSize()
+ 1
+ >>> model = mlp.fit(df)
+ >>> model.setFeaturesCol("features")
+ MultilayerPerceptronClassificationModel...
+ >>> model.getMaxIter()
+ 100
+ >>> model.getLayers()
+ [2, 2, 2]
+ >>> model.weights.size
+ 12
+ >>> testDF = spark.createDataFrame([
+ ... (Vectors.dense([1.0, 0.0]),),
+ ... (Vectors.dense([0.0, 0.0]),)], ["features"])
+ >>> model.predict(testDF.head().features)
+ 1.0
+ >>> model.predictRaw(testDF.head().features)
+ DenseVector([-16.208, 16.344])
+ >>> model.predictProbability(testDF.head().features)
+ DenseVector([0.0, 1.0])
+ >>> model.transform(testDF).select("features", "prediction").show()
+ +---------+----------+
+ | features|prediction|
+ +---------+----------+
+ |[1.0,0.0]| 1.0|
+ |[0.0,0.0]| 0.0|
+ +---------+----------+
+ ...
+ >>> mlp_path = temp_path + "/mlp"
+ >>> mlp.save(mlp_path)
+ >>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path)
+ >>> mlp2.getBlockSize()
+ 1
+ >>> model_path = temp_path + "/mlp_model"
+ >>> model.save(model_path)
+ >>> model2 = MultilayerPerceptronClassificationModel.load(model_path)
+ >>> model.getLayers() == model2.getLayers()
+ True
+ >>> model.weights == model2.weights
+ True
+ >>> model.transform(testDF).take(1) == model2.transform(testDF).take(1)
+ True
+ >>> mlp2 = mlp2.setInitialWeights(list(range(0, 12)))
+ >>> model3 = mlp2.fit(df)
+ >>> model3.weights != model2.weights
+ True
+ >>> model3.getLayers() == model.getLayers()
+ True
+ """
+
+ _input_kwargs: Dict[str, Any]
+
+ @keyword_only
+ def __init__(
+ self,
+ *,
+ featuresCol: str = "features",
+ labelCol: str = "label",
+ predictionCol: str = "prediction",
+ maxIter: int = 100,
+ tol: float = 1e-6,
+ seed: Optional[int] = None,
+ layers: Optional[List[int]] = None,
+ blockSize: int = 128,
+ stepSize: float = 0.03,
+ solver: str = "l-bfgs",
+ initialWeights: Optional[Vector] = None,
+ probabilityCol: str = "probability",
+ rawPredictionCol: str = "rawPrediction",
+ ):
+ """
+ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
+ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
+ rawPredictionCol="rawPrediction")
+ """
+ super(MultilayerPerceptronClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid
+ )
+ kwargs = self._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("1.6.0")
+ def setParams(
+ self,
+ *,
+ featuresCol: str = "features",
+ labelCol: str = "label",
+ predictionCol: str = "prediction",
+ maxIter: int = 100,
+ tol: float = 1e-6,
+ seed: Optional[int] = None,
+ layers: Optional[List[int]] = None,
+ blockSize: int = 128,
+ stepSize: float = 0.03,
+ solver: str = "l-bfgs",
+ initialWeights: Optional[Vector] = None,
+ probabilityCol: str = "probability",
+ rawPredictionCol: str = "rawPrediction",
+ ) -> "MultilayerPerceptronClassifier":
+ """
+ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
+ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
+ rawPredictionCol="rawPrediction"):
+ Sets params for MultilayerPerceptronClassifier.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model: "JavaObject") -> "MultilayerPerceptronClassificationModel":
+ return MultilayerPerceptronClassificationModel(java_model)
+
+ @since("1.6.0")
+ def setLayers(self, value: List[int]) -> "MultilayerPerceptronClassifier":
+ """
+ Sets the value of :py:attr:`layers`.
+ """
+ return self._set(layers=value)
+
+ @since("1.6.0")
+ def setBlockSize(self, value: int) -> "MultilayerPerceptronClassifier":
+ """
+ Sets the value of :py:attr:`blockSize`.
+ """
+ return self._set(blockSize=value)
+
+ @since("2.0.0")
+ def setInitialWeights(self, value: Vector) -> "MultilayerPerceptronClassifier":
+ """
+ Sets the value of :py:attr:`initialWeights`.
+ """
+ return self._set(initialWeights=value)
+
+ def setMaxIter(self, value: int) -> "MultilayerPerceptronClassifier":
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ def setSeed(self, value: int) -> "MultilayerPerceptronClassifier":
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ def setTol(self, value: float) -> "MultilayerPerceptronClassifier":
+ """
+ Sets the value of :py:attr:`tol`.
+ """
+ return self._set(tol=value)
+
+ @since("2.0.0")
+ def setStepSize(self, value: float) -> "MultilayerPerceptronClassifier":
+ """
+ Sets the value of :py:attr:`stepSize`.
+ """
+ return self._set(stepSize=value)
+
+ def setSolver(self, value: str) -> "MultilayerPerceptronClassifier":
+ """
+ Sets the value of :py:attr:`solver`.
+ """
+ return self._set(solver=value)
+
+
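+ # The `layers` parameter fixes the network shape: its first entry must match the
+ # feature vector size and its last entry the number of classes. A minimal sketch
+ # of deriving both from the data (commented out; `train_df` and the hidden-layer
+ # width are illustrative assumptions):
+ #
+ #   num_features = len(train_df.head().features)
+ #   num_classes = train_df.select("label").distinct().count()
+ #   mlp = MultilayerPerceptronClassifier(layers=[num_features, 16, num_classes], seed=42)
+ #   mlp_model = mlp.fit(train_df)
+
+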
+ class MultilayerPerceptronClassificationModel(
+ _JavaProbabilisticClassificationModel[Vector],
+ _MultilayerPerceptronParams,
+ JavaMLWritable,
+ JavaMLReadable["MultilayerPerceptronClassificationModel"],
+ HasTrainingSummary["MultilayerPerceptronClassificationTrainingSummary"],
+ ):
+ """
+ Model fitted by MultilayerPerceptronClassifier.
+
+ .. versionadded:: 1.6.0
+ """
+
+ @property
+ @since("2.0.0")
+ def weights(self) -> Vector:
+ """
+ the weights of layers.
+ """
+ return self._call_java("weights")
+
+ @since("3.1.0")
+ def summary(self) -> "MultilayerPerceptronClassificationTrainingSummary":
+ """
+ Gets summary (accuracy/precision/recall, objective history, total iterations) of model
+ trained on the training set. An exception is thrown if `trainingSummary is None`.
+ """
+ if self.hasSummary:
+ return MultilayerPerceptronClassificationTrainingSummary(
+ super(MultilayerPerceptronClassificationModel, self).summary
+ )
+ else:
+ raise RuntimeError(
+ "No training summary available for this %s" % self.__class__.__name__
+ )
+
+ def evaluate(self, dataset: DataFrame) -> "MultilayerPerceptronClassificationSummary":
+ """
+ Evaluates the model on a test dataset.
+
+ .. versionadded:: 3.1.0
+
+ Parameters
+ ----------
+ dataset : :py:class:`pyspark.sql.DataFrame`
+ Test dataset to evaluate model on.
+ """
+ if not isinstance(dataset, DataFrame):
+ raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
+ java_mlp_summary = self._call_java("evaluate", dataset)
+ return MultilayerPerceptronClassificationSummary(java_mlp_summary)
+
+
+ class MultilayerPerceptronClassificationSummary(_ClassificationSummary):
+ """
+ Abstraction for MultilayerPerceptronClassifier Results for a given model.
+
+ .. versionadded:: 3.1.0
+ """
+
+ pass
+
+
+ @inherit_doc
+ class MultilayerPerceptronClassificationTrainingSummary(
+ MultilayerPerceptronClassificationSummary, _TrainingSummary
+ ):
+ """
+ Abstraction for MultilayerPerceptronClassifier Training results.
+
+ .. versionadded:: 3.1.0
+ """
+
+ pass
+
+
+ class _OneVsRestParams(_ClassifierParams, HasWeightCol):
+ """
+ Params for :py:class:`OneVsRest` and :py:class:`OneVsRestModel`.
+ """
+
+ classifier: Param[Classifier] = Param(Params._dummy(), "classifier", "base binary classifier")
+
+ @since("2.0.0")
+ def getClassifier(self) -> Classifier:
+ """
+ Gets the value of classifier or its default value.
+ """
+ return self.getOrDefault(self.classifier)
+
+
+ @inherit_doc
+ class OneVsRest(
+ Estimator["OneVsRestModel"],
+ _OneVsRestParams,
+ HasParallelism,
+ MLReadable["OneVsRest"],
+ MLWritable,
+ Generic[CM],
+ ):
+ """
+ Reduction of Multiclass Classification to Binary Classification.
+ Performs reduction using one against all strategy.
+ For a multiclass classification with k classes, train k models (one per class).
+ Each example is scored against all k models and the model with highest score
+ is picked to label the example.
+
+ .. versionadded:: 2.0.0
+
+ Examples
+ --------
+ >>> from pyspark.sql import Row
+ >>> from pyspark.ml.linalg import Vectors
+ >>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
+ >>> df = spark.read.format("libsvm").load(data_path)
+ >>> lr = LogisticRegression(regParam=0.01)
+ >>> ovr = OneVsRest(classifier=lr)
+ >>> ovr.getRawPredictionCol()
+ 'rawPrediction'
+ >>> ovr.setPredictionCol("newPrediction")
+ OneVsRest...
+ >>> model = ovr.fit(df)
+ >>> model.models[0].coefficients
+ DenseVector([0.5..., -1.0..., 3.4..., 4.2...])
+ >>> model.models[1].coefficients
+ DenseVector([-2.1..., 3.1..., -2.6..., -2.3...])
+ >>> model.models[2].coefficients
+ DenseVector([0.3..., -3.4..., 1.0..., -1.1...])
+ >>> [x.intercept for x in model.models]
+ [-2.7..., -2.5..., -1.3...]
+ >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF()
+ >>> model.transform(test0).head().newPrediction
+ 0.0
+ >>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF()
+ >>> model.transform(test1).head().newPrediction
+ 2.0
+ >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF()
+ >>> model.transform(test2).head().newPrediction
+ 0.0
+ >>> model_path = temp_path + "/ovr_model"
+ >>> model.save(model_path)
+ >>> model2 = OneVsRestModel.load(model_path)
+ >>> model2.transform(test0).head().newPrediction
+ 0.0
+ >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
+ True
+ >>> model.transform(test2).columns
+ ['features', 'rawPrediction', 'newPrediction']
+ """
+
+ _input_kwargs: Dict[str, Any]
+
+ @keyword_only
+ def __init__(
+ self,
+ *,
+ featuresCol: str = "features",
+ labelCol: str = "label",
+ predictionCol: str = "prediction",
+ rawPredictionCol: str = "rawPrediction",
+ classifier: Optional[Classifier[CM]] = None,
+ weightCol: Optional[str] = None,
+ parallelism: int = 1,
+ ):
+ """
+ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
+ """
+ super(OneVsRest, self).__init__()
+ self._setDefault(parallelism=1)
+ kwargs = self._input_kwargs
+ self._set(**kwargs)
+
+ @keyword_only
+ @since("2.0.0")
+ def setParams(
+ self,
+ *,
+ featuresCol: str = "features",
+ labelCol: str = "label",
+ predictionCol: str = "prediction",
+ rawPredictionCol: str = "rawPrediction",
+ classifier: Optional[Classifier[CM]] = None,
+ weightCol: Optional[str] = None,
+ parallelism: int = 1,
+ ) -> "OneVsRest":
+ """
+ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
+ Sets params for OneVsRest.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ @since("2.0.0")
+ def setClassifier(self, value: Classifier[CM]) -> "OneVsRest":
+ """
+ Sets the value of :py:attr:`classifier`.
+ """
+ return self._set(classifier=value)
+
+ def setLabelCol(self, value: str) -> "OneVsRest":
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ def setFeaturesCol(self, value: str) -> "OneVsRest":
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ def setPredictionCol(self, value: str) -> "OneVsRest":
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ def setRawPredictionCol(self, value: str) -> "OneVsRest":
+ """
+ Sets the value of :py:attr:`rawPredictionCol`.
+ """
+ return self._set(rawPredictionCol=value)
+
+ def setWeightCol(self, value: str) -> "OneVsRest":
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
+ def setParallelism(self, value: int) -> "OneVsRest":
+ """
+ Sets the value of :py:attr:`parallelism`.
+ """
+ return self._set(parallelism=value)
+
+ def _fit(self, dataset: DataFrame) -> "OneVsRestModel":
+ labelCol = self.getLabelCol()
+ featuresCol = self.getFeaturesCol()
+ predictionCol = self.getPredictionCol()
+ classifier = self.getClassifier()
+
+ numClasses = (
+ int(cast(Row, dataset.agg({labelCol: "max"}).head())["max(" + labelCol + ")"]) + 1
+ )
+
+ weightCol = None
+ if self.isDefined(self.weightCol) and self.getWeightCol():
+ if isinstance(classifier, HasWeightCol):
+ weightCol = self.getWeightCol()
+ else:
+ warnings.warn(
+ "weightCol is ignored, " "as it is not supported by {} now.".format(classifier)
+ )
+
+ if weightCol:
+ multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
+ else:
+ multiclassLabeled = dataset.select(labelCol, featuresCol)
+
+ # persist if underlying dataset is not persistent.
+ handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
+ if handlePersistence:
+ multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
+
+ def trainSingleClass(index: int) -> CM:
+ binaryLabelCol = "mc2b$" + str(index)
+ trainingDataset = multiclassLabeled.withColumn(
+ binaryLabelCol,
+ when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+ )
+ paramMap = dict(
+ [
+ (classifier.labelCol, binaryLabelCol),
+ (classifier.featuresCol, featuresCol),
+ (classifier.predictionCol, predictionCol),
+ ]
+ )
+ if weightCol:
+ paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
+ return classifier.fit(trainingDataset, paramMap)
+
+ pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
+
+ models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
+
+ if handlePersistence:
+ multiclassLabeled.unpersist()
+
+ return self._copyValues(OneVsRestModel(models=models))
+
+ def copy(self, extra: Optional["ParamMap"] = None) -> "OneVsRest":
+ """
+ Creates a copy of this instance with a randomly generated uid
+ and some extra params. This creates a deep copy of the embedded paramMap,
+ and copies the embedded and extra parameters over.
+
+ .. versionadded:: 2.0.0
+
+ Parameters
+ ----------
+ extra : dict, optional
+ Extra parameters to copy to the new instance
+
+ Returns
+ -------
+ :py:class:`OneVsRest`
+ Copy of this instance
+ """
+ if extra is None:
+ extra = dict()
+ newOvr = Params.copy(self, extra)
+ if self.isSet(self.classifier):
+ newOvr.setClassifier(self.getClassifier().copy(extra))
+ return newOvr
+
+ @classmethod
+ def _from_java(cls, java_stage: "JavaObject") -> "OneVsRest":
+ """
+ Given a Java OneVsRest, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ featuresCol = java_stage.getFeaturesCol()
+ labelCol = java_stage.getLabelCol()
+ predictionCol = java_stage.getPredictionCol()
+ rawPredictionCol = java_stage.getRawPredictionCol()
+ classifier: Classifier = JavaParams._from_java(java_stage.getClassifier())
+ parallelism = java_stage.getParallelism()
+ py_stage = cls(
+ featuresCol=featuresCol,
+ labelCol=labelCol,
+ predictionCol=predictionCol,
+ rawPredictionCol=rawPredictionCol,
+ classifier=classifier,
+ parallelism=parallelism,
+ )
+ if java_stage.isDefined(java_stage.getParam("weightCol")):
+ py_stage.setWeightCol(java_stage.getWeightCol())
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self) -> "JavaObject":
+ """
+ Transfer this instance to a Java OneVsRest. Used for ML persistence.
+
+ Returns
+ -------
+ py4j.java_gateway.JavaObject
+ Java object equivalent to this instance.
+ """
+ _java_obj = JavaParams._new_java_obj(
+ "org.apache.spark.ml.classification.OneVsRest", self.uid
+ )
+ _java_obj.setClassifier(cast(_JavaClassifier, self.getClassifier())._to_java())
+ _java_obj.setParallelism(self.getParallelism())
+ _java_obj.setFeaturesCol(self.getFeaturesCol())
+ _java_obj.setLabelCol(self.getLabelCol())
+ _java_obj.setPredictionCol(self.getPredictionCol())
+ if self.isDefined(self.weightCol) and self.getWeightCol():
+ _java_obj.setWeightCol(self.getWeightCol())
+ _java_obj.setRawPredictionCol(self.getRawPredictionCol())
+ return _java_obj
+
+ @classmethod
+ def read(cls) -> "OneVsRestReader":
+ return OneVsRestReader(cls)
+
+ def write(self) -> MLWriter:
+ if isinstance(self.getClassifier(), JavaMLWritable):
+ return JavaMLWriter(self) # type: ignore[arg-type]
+ else:
+ return OneVsRestWriter(self)
+
+
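+ # OneVsRest wraps any binary Classifier: k per-class models are trained (in
+ # parallel, up to `parallelism` threads) and the class whose model scores highest
+ # labels each row. A minimal sketch (commented out; `multiclass_df` is an
+ # illustrative assumption, not part of this module):
+ #
+ #   lr = LogisticRegression(maxIter=10)
+ #   ovr = OneVsRest(classifier=lr, parallelism=2)
+ #   ovr_model = ovr.fit(multiclass_df)       # one LogisticRegressionModel per class
+ #   scored = ovr_model.transform(multiclass_df)
+
+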
+ class _OneVsRestSharedReadWrite:
+ @staticmethod
+ def saveImpl(
+ instance: Union[OneVsRest, "OneVsRestModel"],
+ sc: SparkContext,
+ path: str,
+ extraMetadata: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ skipParams = ["classifier"]
+ jsonParams = DefaultParamsWriter.extractJsonParams(instance, skipParams)
+ DefaultParamsWriter.saveMetadata(
+ instance, path, sc, paramMap=jsonParams, extraMetadata=extraMetadata
+ )
+ classifierPath = os.path.join(path, "classifier")
+ cast(MLWritable, instance.getClassifier()).save(classifierPath)
+
+ @staticmethod
+ def loadClassifier(path: str, sc: SparkContext) -> Union[OneVsRest, "OneVsRestModel"]:
+ classifierPath = os.path.join(path, "classifier")
+ return DefaultParamsReader.loadParamsInstance(classifierPath, sc)
+
+ @staticmethod
+ def validateParams(instance: Union[OneVsRest, "OneVsRestModel"]) -> None:
+ elems_to_check: List[Params] = [instance.getClassifier()]
+ if isinstance(instance, OneVsRestModel):
+ elems_to_check.extend(instance.models)
+
+ for elem in elems_to_check:
+ if not isinstance(elem, MLWritable):
+ raise ValueError(
+ f"OneVsRest write will fail because it contains {elem.uid} "
+ f"which is not writable."
+ )
+
+
+ @inherit_doc
+ class OneVsRestReader(MLReader[OneVsRest]):
+ def __init__(self, cls: Type[OneVsRest]) -> None:
+ super(OneVsRestReader, self).__init__()
+ self.cls = cls
+
+ def load(self, path: str) -> OneVsRest:
+ metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ if not DefaultParamsReader.isPythonParamsInstance(metadata):
+ return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
+ else:
+ classifier = cast(Classifier, _OneVsRestSharedReadWrite.loadClassifier(path, self.sc))
+ ova: OneVsRest = OneVsRest(classifier=classifier)._resetUid(metadata["uid"])
+ DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=["classifier"])
+ return ova
+
+
+ @inherit_doc
+ class OneVsRestWriter(MLWriter):
+ def __init__(self, instance: OneVsRest):
+ super(OneVsRestWriter, self).__init__()
+ self.instance = instance
+
+ def saveImpl(self, path: str) -> None:
+ _OneVsRestSharedReadWrite.validateParams(self.instance)
+ _OneVsRestSharedReadWrite.saveImpl(self.instance, self.sc, path)
+
+
+ class OneVsRestModel(
+ Model,
+ _OneVsRestParams,
+ MLReadable["OneVsRestModel"],
+ MLWritable,
+ ):
+ """
+ Model fitted by OneVsRest.
+ This stores the models resulting from training k binary classifiers: one for each class.
+ Each example is scored against all k models, and the model with the highest score
+ is picked to label the example.
+
+ .. versionadded:: 2.0.0
+ """
+
+ def setFeaturesCol(self, value: str) -> "OneVsRestModel":
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ def setPredictionCol(self, value: str) -> "OneVsRestModel":
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ def setRawPredictionCol(self, value: str) -> "OneVsRestModel":
+ """
+ Sets the value of :py:attr:`rawPredictionCol`.
+ """
+ return self._set(rawPredictionCol=value)
+
+ def __init__(self, models: List[ClassificationModel]):
+ super(OneVsRestModel, self).__init__()
+ self.models = models
+ if not isinstance(models[0], JavaMLWritable):
+ return
+ # set java instance
+ java_models = [cast(_JavaClassificationModel, model)._to_java() for model in self.models]
+ sc = SparkContext._active_spark_context
+ assert sc is not None and sc._gateway is not None
+
+ java_models_array = JavaWrapper._new_java_array(
+ java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel
+ )
+ # TODO: need to set metadata
+ metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
+ self._java_obj = JavaParams._new_java_obj(
+ "org.apache.spark.ml.classification.OneVsRestModel",
+ self.uid,
+ metadata.empty(),
+ java_models_array,
+ )
+
+ def _transform(self, dataset: DataFrame) -> DataFrame:
+ # determine the input columns: these need to be passed through
+ origCols = dataset.columns
+
+ # add an accumulator column to store predictions of all the models
+ accColName = "mbc$acc" + str(uuid.uuid4())
+ initUDF = udf(lambda _: [], ArrayType(DoubleType()))
+ newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))
+
+ # persist if underlying dataset is not persistent.
+ handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
+ if handlePersistence:
+ newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+
+ # update the accumulator column with the result of prediction of models
+ aggregatedDataset = newDataset
+ for index, model in enumerate(self.models):
+ rawPredictionCol = self.getRawPredictionCol()
+
+ columns = origCols + [rawPredictionCol, accColName]
+
+ # add temporary column to store intermediate scores and update
+ tmpColName = "mbc$tmp" + str(uuid.uuid4())
+ updateUDF = udf(
+ lambda predictions, prediction: predictions + [prediction.tolist()[1]],
+ ArrayType(DoubleType()),
+ )
+ transformedDataset = model.transform(aggregatedDataset).select(*columns)
+ updatedDataset = transformedDataset.withColumn(
+ tmpColName,
+ updateUDF(transformedDataset[accColName], transformedDataset[rawPredictionCol]),
+ )
+ newColumns = origCols + [tmpColName]
+
+ # switch out the intermediate column with the accumulator column
+ aggregatedDataset = updatedDataset.select(*newColumns).withColumnRenamed(
+ tmpColName, accColName
+ )
+
+ if handlePersistence:
+ newDataset.unpersist()
+
+ if self.getRawPredictionCol():
+
+ def func(predictions: Iterable[float]) -> Vector:
+ predArray: List[float] = []
+ for x in predictions:
+ predArray.append(x)
+ return Vectors.dense(predArray)
+
+ rawPredictionUDF = udf(func, VectorUDT())
+ aggregatedDataset = aggregatedDataset.withColumn(
+ self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName])
+ )
+
+ if self.getPredictionCol():
+ # output the index of the classifier with highest confidence as prediction
+ labelUDF = udf(
+ lambda predictions: float(
+ max(enumerate(predictions), key=operator.itemgetter(1))[0]
+ ),
+ DoubleType(),
+ )
+ aggregatedDataset = aggregatedDataset.withColumn(
+ self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])
+ )
+ return aggregatedDataset.drop(accColName)
+
+ def copy(self, extra: Optional["ParamMap"] = None) -> "OneVsRestModel":
+ """
+ Creates a copy of this instance with a randomly generated uid
+ and some extra params. This creates a deep copy of the embedded paramMap,
+ and copies the embedded and extra parameters over.
+
+ .. versionadded:: 2.0.0
+
+ Parameters
+ ----------
+ extra : dict, optional
+ Extra parameters to copy to the new instance
+
+ Returns
+ -------
+ :py:class:`OneVsRestModel`
+ Copy of this instance
+ """
+ if extra is None:
+ extra = dict()
+ newModel = Params.copy(self, extra)
+ newModel.models = [model.copy(extra) for model in self.models]
+ return newModel
+
+ @classmethod
+ def _from_java(cls, java_stage: "JavaObject") -> "OneVsRestModel":
+ """
+ Given a Java OneVsRestModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ featuresCol = java_stage.getFeaturesCol()
+ labelCol = java_stage.getLabelCol()
+ predictionCol = java_stage.getPredictionCol()
+ classifier: Classifier = JavaParams._from_java(java_stage.getClassifier())
+ models: List[ClassificationModel] = [
+ JavaParams._from_java(model) for model in java_stage.models()
+ ]
+ py_stage = cls(models=models).setPredictionCol(predictionCol).setFeaturesCol(featuresCol)
+ py_stage._set(labelCol=labelCol)
+ if java_stage.isDefined(java_stage.getParam("weightCol")):
+ py_stage._set(weightCol=java_stage.getWeightCol())
+ py_stage._set(classifier=classifier)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self) -> "JavaObject":
+ """
+ Transfer this instance to a Java OneVsRestModel. Used for ML persistence.
+
+ Returns
+ -------
+ py4j.java_gateway.JavaObject
+ Java object equivalent to this instance.
+ """
+ sc = SparkContext._active_spark_context
+ assert sc is not None and sc._gateway is not None
+
+ java_models = [cast(_JavaClassificationModel, model)._to_java() for model in self.models]
+ java_models_array = JavaWrapper._new_java_array(
+ java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel
+ )
+ metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
+ _java_obj = JavaParams._new_java_obj(
+ "org.apache.spark.ml.classification.OneVsRestModel",
+ self.uid,
+ metadata.empty(),
+ java_models_array,
+ )
+ _java_obj.set("classifier", cast(_JavaClassifier, self.getClassifier())._to_java())
+ _java_obj.set("featuresCol", self.getFeaturesCol())
+ _java_obj.set("labelCol", self.getLabelCol())
+ _java_obj.set("predictionCol", self.getPredictionCol())
+ if self.isDefined(self.weightCol) and self.getWeightCol():
+ _java_obj.set("weightCol", self.getWeightCol())
+ return _java_obj
+
+ @classmethod
+ def read(cls) -> "OneVsRestModelReader":
+ return OneVsRestModelReader(cls)
+
+ def write(self) -> MLWriter:
+ if all(
+ map(
+ lambda elem: isinstance(elem, JavaMLWritable),
+ [self.getClassifier()] + self.models, # type: ignore[operator]
+ )
+ ):
+ return JavaMLWriter(self) # type: ignore[arg-type]
+ else:
+ return OneVsRestModelWriter(self)
+
+
+ @inherit_doc
+ class OneVsRestModelReader(MLReader[OneVsRestModel]):
+ def __init__(self, cls: Type[OneVsRestModel]):
+ super(OneVsRestModelReader, self).__init__()
+ self.cls = cls
+
+ def load(self, path: str) -> OneVsRestModel:
+ metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ if not DefaultParamsReader.isPythonParamsInstance(metadata):
+ return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
+ else:
+ classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
+ numClasses = metadata["numClasses"]
+ subModels = [None] * numClasses
+ for idx in range(numClasses):
+ subModelPath = os.path.join(path, f"model_{idx}")
+ subModels[idx] = DefaultParamsReader.loadParamsInstance(subModelPath, self.sc)
+ ovaModel = OneVsRestModel(cast(List[ClassificationModel], subModels))._resetUid(
+ metadata["uid"]
+ )
+ ovaModel.set(ovaModel.classifier, classifier)
+ DefaultParamsReader.getAndSetParams(ovaModel, metadata, skipParams=["classifier"])
+ return ovaModel
+
+
+ @inherit_doc
+ class OneVsRestModelWriter(MLWriter):
+ def __init__(self, instance: OneVsRestModel):
+ super(OneVsRestModelWriter, self).__init__()
+ self.instance = instance
+
+ def saveImpl(self, path: str) -> None:
+ _OneVsRestSharedReadWrite.validateParams(self.instance)
+ instance = self.instance
+ numClasses = len(instance.models)
+ extraMetadata = {"numClasses": numClasses}
+ _OneVsRestSharedReadWrite.saveImpl(instance, self.sc, path, extraMetadata=extraMetadata)
+ for idx in range(numClasses):
+ subModelPath = os.path.join(path, f"model_{idx}")
+ cast(MLWritable, instance.models[idx]).save(subModelPath)
+
+
+ @inherit_doc
+ class FMClassifier(
+ _JavaProbabilisticClassifier["FMClassificationModel"],
+ _FactorizationMachinesParams,
+ JavaMLWritable,
+ JavaMLReadable["FMClassifier"],
+ ):
+ """
+ Factorization Machines learning algorithm for classification.
+
+ Solver supports:
+
+ * gd (normal mini-batch gradient descent)
+ * adamW (default)
+
+ .. versionadded:: 3.0.0
+
+ Examples
+ --------
+ >>> from pyspark.ml.linalg import Vectors
+ >>> from pyspark.ml.classification import FMClassifier
+ >>> df = spark.createDataFrame([
+ ... (1.0, Vectors.dense(1.0)),
+ ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
+ >>> fm = FMClassifier(factorSize=2)
+ >>> fm.setSeed(11)
+ FMClassifier...
+ >>> model = fm.fit(df)
+ >>> model.getMaxIter()
+ 100
+ >>> test0 = spark.createDataFrame([
+ ... (Vectors.dense(-1.0),),
+ ... (Vectors.dense(0.5),),
+ ... (Vectors.dense(1.0),),
+ ... (Vectors.dense(2.0),)], ["features"])
+ >>> model.predictRaw(test0.head().features)
+ DenseVector([22.13..., -22.13...])
+ >>> model.predictProbability(test0.head().features)
+ DenseVector([1.0, 0.0])
+ >>> model.transform(test0).select("features", "probability").show(10, False)
+ +--------+------------------------------------------+
+ |features|probability |
+ +--------+------------------------------------------+
+ |[-1.0] |[0.9999999997574736,2.425264676902229E-10]|
+ |[0.5] |[0.47627851732981163,0.5237214826701884] |
+ |[1.0] |[5.491554426243495E-4,0.9994508445573757] |
+ |[2.0] |[2.005766663870645E-10,0.9999999997994233]|
+ +--------+------------------------------------------+
+ ...
+ >>> model.intercept
+ -7.316665276826291
+ >>> model.linear
+ DenseVector([14.8232])
+ >>> model.factors
+ DenseMatrix(1, 2, [0.0163, -0.0051], 1)
+ >>> model_path = temp_path + "/fm_model"
+ >>> model.save(model_path)
+ >>> model2 = FMClassificationModel.load(model_path)
+ >>> model2.intercept
+ -7.316665276826291
+ >>> model2.linear
+ DenseVector([14.8232])
+ >>> model2.factors
+ DenseMatrix(1, 2, [0.0163, -0.0051], 1)
+ >>> model.transform(test0).take(1) == model2.transform(test0).take(1)
+ True
+ """
+
+ _input_kwargs: Dict[str, Any]
+
+ @keyword_only
+ def __init__(
+ self,
+ *,
+ featuresCol: str = "features",
+ labelCol: str = "label",
+ predictionCol: str = "prediction",
+ probabilityCol: str = "probability",
+ rawPredictionCol: str = "rawPrediction",
+ factorSize: int = 8,
+ fitIntercept: bool = True,
+ fitLinear: bool = True,
+ regParam: float = 0.0,
+ miniBatchFraction: float = 1.0,
+ initStd: float = 0.01,
+ maxIter: int = 100,
+ stepSize: float = 1.0,
+ tol: float = 1e-6,
+ solver: str = "adamW",
+ thresholds: Optional[List[float]] = None,
+ seed: Optional[int] = None,
+ ):
+ """
+ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ probabilityCol="probability", rawPredictionCol="rawPrediction", \
+ factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
+ miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
+ tol=1e-6, solver="adamW", thresholds=None, seed=None)
+ """
+ super(FMClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.FMClassifier", self.uid
+ )
+ kwargs = self._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("3.0.0")
+ def setParams(
+ self,
+ *,
+ featuresCol: str = "features",
+ labelCol: str = "label",
+ predictionCol: str = "prediction",
+ probabilityCol: str = "probability",
+ rawPredictionCol: str = "rawPrediction",
+ factorSize: int = 8,
+ fitIntercept: bool = True,
+ fitLinear: bool = True,
+ regParam: float = 0.0,
+ miniBatchFraction: float = 1.0,
+ initStd: float = 0.01,
+ maxIter: int = 100,
+ stepSize: float = 1.0,
+ tol: float = 1e-6,
+ solver: str = "adamW",
+ thresholds: Optional[List[float]] = None,
+ seed: Optional[int] = None,
+ ) -> "FMClassifier":
+ """
+ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ probabilityCol="probability", rawPredictionCol="rawPrediction", \
+ factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
+ miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
+ tol=1e-6, solver="adamW", thresholds=None, seed=None)
+ Sets Params for FMClassifier.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model: "JavaObject") -> "FMClassificationModel":
+ return FMClassificationModel(java_model)
+
+ @since("3.0.0")
+ def setFactorSize(self, value: int) -> "FMClassifier":
+ """
+ Sets the value of :py:attr:`factorSize`.
+ """
+ return self._set(factorSize=value)
+
+ @since("3.0.0")
+ def setFitLinear(self, value: bool) -> "FMClassifier":
+ """
+ Sets the value of :py:attr:`fitLinear`.
+ """
+ return self._set(fitLinear=value)
+
+ @since("3.0.0")
+ def setMiniBatchFraction(self, value: float) -> "FMClassifier":
+ """
+ Sets the value of :py:attr:`miniBatchFraction`.
+ """
+ return self._set(miniBatchFraction=value)
+
+ @since("3.0.0")
+ def setInitStd(self, value: float) -> "FMClassifier":
+ """
+ Sets the value of :py:attr:`initStd`.
+ """
+ return self._set(initStd=value)
+
+ @since("3.0.0")
+ def setMaxIter(self, value: int) -> "FMClassifier":
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("3.0.0")
+ def setStepSize(self, value: float) -> "FMClassifier":
+ """
+ Sets the value of :py:attr:`stepSize`.
+ """
+ return self._set(stepSize=value)
+
+ @since("3.0.0")
+ def setTol(self, value: float) -> "FMClassifier":
+ """
+ Sets the value of :py:attr:`tol`.
+ """
+ return self._set(tol=value)
+
+ @since("3.0.0")
+ def setSolver(self, value: str) -> "FMClassifier":
+ """
+ Sets the value of :py:attr:`solver`.
+ """
+ return self._set(solver=value)
+
+ @since("3.0.0")
+ def setSeed(self, value: int) -> "FMClassifier":
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ @since("3.0.0")
+ def setFitIntercept(self, value: bool) -> "FMClassifier":
+ """
+ Sets the value of :py:attr:`fitIntercept`.
+ """
+ return self._set(fitIntercept=value)
+
+ @since("3.0.0")
+ def setRegParam(self, value: float) -> "FMClassifier":
+ """
+ Sets the value of :py:attr:`regParam`.
+ """
+ return self._set(regParam=value)
+
+
+ class FMClassificationModel(
+ _JavaProbabilisticClassificationModel[Vector],
+ _FactorizationMachinesParams,
+ JavaMLWritable,
+ JavaMLReadable["FMClassificationModel"],
+ HasTrainingSummary,
+ ):
+ """
+ Model fitted by :class:`FMClassifier`.
+
+ .. versionadded:: 3.0.0
+ """
+
+ @property
+ @since("3.0.0")
+ def intercept(self) -> float:
+ """
+ Model intercept.
+ """
+ return self._call_java("intercept")
+
+ @property
+ @since("3.0.0")
+ def linear(self) -> Vector:
+ """
+ Model linear term.
+ """
+ return self._call_java("linear")
+
+ @property
+ @since("3.0.0")
+ def factors(self) -> Matrix:
+ """
+ Model factor term.
+ """
+ return self._call_java("factors")
+
+ @since("3.1.0")
+ def summary(self) -> "FMClassificationTrainingSummary":
+ """
+ Gets summary (accuracy/precision/recall, objective history, total iterations) of model
+ trained on the training set. An exception is thrown if `trainingSummary is None`.
+ """
+ if self.hasSummary:
+ return FMClassificationTrainingSummary(super(FMClassificationModel, self).summary)
+ else:
+ raise RuntimeError(
+ "No training summary available for this %s" % self.__class__.__name__
+ )
+
+ def evaluate(self, dataset: DataFrame) -> "FMClassificationSummary":
+ """
+ Evaluates the model on a test dataset.
+
+ .. versionadded:: 3.1.0
+
+ Parameters
+ ----------
+ dataset : :py:class:`pyspark.sql.DataFrame`
+ Test dataset to evaluate model on.
+ """
+ if not isinstance(dataset, DataFrame):
+ raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
+ java_fm_summary = self._call_java("evaluate", dataset)
+ return FMClassificationSummary(java_fm_summary)
+
+
+ class FMClassificationSummary(_BinaryClassificationSummary):
+ """
+ Abstraction for FMClassifier Results for a given model.
+
+ .. versionadded:: 3.1.0
+ """
+
+ pass
+
+
+ @inherit_doc
+ class FMClassificationTrainingSummary(FMClassificationSummary, _TrainingSummary):
+ """
+ Abstraction for FMClassifier Training results.
+
+ .. versionadded:: 3.1.0
+ """
+
+ pass
+
+
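+ # An FM model scores y(x) = b + <w, x> + sum_{i<j} <v_i, v_j> * x_i * x_j, where
+ # b is `intercept`, w is `linear`, and row i of `factors` is v_i. A minimal
+ # sketch reproducing the pairwise term with the usual O(n*k) identity (commented
+ # out; numpy and `fm_model` are illustrative assumptions, not part of this module):
+ #
+ #   import numpy as np
+ #   x = np.array([1.0])                      # a single-feature example
+ #   V = fm_model.factors.toArray()           # shape (numFeatures, factorSize)
+ #   w = fm_model.linear.toArray()
+ #   pairwise = 0.5 * ((V.T @ x) ** 2 - (V.T ** 2) @ (x ** 2)).sum()
+ #   raw_score = fm_model.intercept + w @ x + pairwise
+
+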
+ if __name__ == "__main__":
+ import doctest
+ import pyspark.ml.classification
+ from pyspark.sql import SparkSession
+
+ globs = pyspark.ml.classification.__dict__.copy()
+ # The small batch size here ensures that we see multiple batches,
+ # even in these small test examples:
+ spark = SparkSession.builder.master("local[2]").appName("ml.classification tests").getOrCreate()
+ sc = spark.sparkContext
+ globs["sc"] = sc
+ globs["spark"] = spark
+ import tempfile
+
+ temp_path = tempfile.mkdtemp()
+ globs["temp_path"] = temp_path
+ try:
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ spark.stop()
+ finally:
+ from shutil import rmtree
+
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
+ if failure_count:
+ sys.exit(-1)