julearn 0.3.6.dev47__tar.gz → 0.3.6.dev72__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262) hide show
  1. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/PKG-INFO +5 -3
  2. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/models.rst +15 -0
  3. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/available_pipeline_steps.rst +14 -0
  4. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/conf.py +1 -0
  5. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/_version.py +3 -3
  6. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/models/available_models.py +31 -0
  7. julearn-0.3.6.dev72/julearn/models/tests/test_xgb_cvearlystopping.py +476 -0
  8. julearn-0.3.6.dev72/julearn/models/xgb_cvearlystopping.py +382 -0
  9. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn.egg-info/PKG-INFO +5 -3
  10. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn.egg-info/SOURCES.txt +2 -0
  11. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn.egg-info/requires.txt +5 -2
  12. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/pyproject.toml +6 -2
  13. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/tox.ini +3 -0
  14. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/uv.lock +40 -3
  15. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  16. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  17. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/ISSUE_TEMPLATE/documentation_request.yaml +0 -0
  18. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  19. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/dependabot.yml +0 -0
  20. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/workflows/check-stale.yml +0 -0
  21. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/workflows/ci-docs.yml +0 -0
  22. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/workflows/ci.yml +0 -0
  23. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/workflows/docs-preview.yml +0 -0
  24. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/workflows/docs.yml +0 -0
  25. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/workflows/lint.yml +0 -0
  26. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.github/workflows/publish.yml +0 -0
  27. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.gitignore +0 -0
  28. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/.pre-commit-config.yaml +0 -0
  29. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/AUTHORS.rst +0 -0
  30. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/LICENSE.md +0 -0
  31. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/README.md +0 -0
  32. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/codecov.yml +0 -0
  33. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/Makefile +0 -0
  34. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/_static/css/custom.css +0 -0
  35. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/_static/css/version_selector.css +0 -0
  36. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/_static/js/custom.js +0 -0
  37. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/_templates/class.rst +0 -0
  38. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/_templates/function.rst +0 -0
  39. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/_templates/function_warning.rst +0 -0
  40. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/_templates/protocol.rst +0 -0
  41. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/_templates/type.rst +0 -0
  42. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/_templates/versions.html +0 -0
  43. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/base.rst +0 -0
  44. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/config.rst +0 -0
  45. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/index.rst +0 -0
  46. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/inspect.rst +0 -0
  47. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/main.rst +0 -0
  48. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/model_selection.rst +0 -0
  49. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/pipeline.rst +0 -0
  50. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/prepare.rst +0 -0
  51. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/scoring.rst +0 -0
  52. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/stats.rst +0 -0
  53. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/transformers.rst +0 -0
  54. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/utils.rst +0 -0
  55. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/api/viz.rst +0 -0
  56. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/changes/contributors.inc +0 -0
  57. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/changes/newsfragments/.gitignore +0 -0
  58. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/changes/newsfragments/301.misc +0 -0
  59. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/changes/newsfragments/303.feature +0 -0
  60. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/changes/newsfragments/307.bugfix +0 -0
  61. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/changes/newsfragments/307.misc +0 -0
  62. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/configuration.rst +0 -0
  63. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/contributing.rst +0 -0
  64. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/examples.rst +0 -0
  65. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/faq.rst +0 -0
  66. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/getting_started.rst +0 -0
  67. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/corrected_ttest.png +0 -0
  68. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/final_estimator.png +0 -0
  69. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/iris_X.png +0 -0
  70. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/iris_df.png +0 -0
  71. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/iris_y.png +0 -0
  72. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/joblib_htcondor/condor_q.png +0 -0
  73. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/joblib_htcondor/ui_main.png +0 -0
  74. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/joblib_htcondor/ui_open.png +0 -0
  75. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/joblib_htcondor/ui_stacked.png +0 -0
  76. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/julearn_logo.png +0 -0
  77. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/julearn_logo_calm.png +0 -0
  78. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/julearn_logo_confbias.png +0 -0
  79. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/julearn_logo_cv.png +0 -0
  80. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/julearn_logo_generalization.png +0 -0
  81. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/julearn_logo_it.png +0 -0
  82. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/julearn_logo_ml.png +0 -0
  83. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/julearn_logo_mlit.png +0 -0
  84. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/multiple_scorers_run_cv.png +0 -0
  85. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/plot_scores.png +0 -0
  86. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/scores_run_cv.png +0 -0
  87. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/scores_run_cv_splitter.png +0 -0
  88. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/images/scores_run_cv_train.png +0 -0
  89. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/index.rst +0 -0
  90. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/links.inc +0 -0
  91. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/maintaining.rst +0 -0
  92. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/poly.py +0 -0
  93. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/poly_files/patches/version_banner.html +0 -0
  94. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/poly_files/patches/version_banner_rtd.html +0 -0
  95. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/poly_files/patches/versions.html +0 -0
  96. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/poly_files/patches/versions_rtd.html +0 -0
  97. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/poly_files/templates/index.html +0 -0
  98. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/redirect.html +0 -0
  99. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/selected_deeper_topics/CBPM.rst +0 -0
  100. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/selected_deeper_topics/confound_removal.rst +0 -0
  101. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/selected_deeper_topics/cross_validation_splitter.rst +0 -0
  102. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/selected_deeper_topics/hyperparameter_tuning.rst +0 -0
  103. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/selected_deeper_topics/index.rst +0 -0
  104. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/selected_deeper_topics/joblib.rst +0 -0
  105. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/selected_deeper_topics/model_inspect.rst +0 -0
  106. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/selected_deeper_topics/stacked_models.rst +0 -0
  107. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/selected_deeper_topics/target_transformers.rst +0 -0
  108. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/what_really_need_know/cross_validation.rst +0 -0
  109. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/what_really_need_know/data.rst +0 -0
  110. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/what_really_need_know/index.rst +0 -0
  111. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/what_really_need_know/model_comparison.rst +0 -0
  112. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/what_really_need_know/model_evaluation.rst +0 -0
  113. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/what_really_need_know/pipeline.rst +0 -0
  114. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/docs/whats_new.rst +0 -0
  115. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/00_starting/README.rst +0 -0
  116. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/00_starting/plot_cm_acc_multiclass.py +0 -0
  117. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/00_starting/plot_example_regression.py +0 -0
  118. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/00_starting/plot_stratified_kfold_reg.py +0 -0
  119. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/00_starting/run_combine_pandas.py +0 -0
  120. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/00_starting/run_grouped_cv.py +0 -0
  121. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/00_starting/run_simple_binary_classification.py +0 -0
  122. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/01_model_comparison/README.rst +0 -0
  123. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/01_model_comparison/plot_simple_model_comparison.py +0 -0
  124. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/02_inspection/README.rst +0 -0
  125. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/02_inspection/plot_groupcv_inspect_svm.py +0 -0
  126. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/02_inspection/plot_inspect_random_forest.py +0 -0
  127. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/02_inspection/plot_preprocess.py +0 -0
  128. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/02_inspection/run_binary_inspect_folds.py +0 -0
  129. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/03_complex_models/README.rst +0 -0
  130. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/03_complex_models/run_apply_to_target.py +0 -0
  131. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/03_complex_models/run_example_pca_featsets.py +0 -0
  132. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/03_complex_models/run_generate_target.py +0 -0
  133. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/03_complex_models/run_hyperparameter_multiple_grids.py +0 -0
  134. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/03_complex_models/run_hyperparameter_tuning.py +0 -0
  135. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py +0 -0
  136. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/03_complex_models/run_stacked_models.py +0 -0
  137. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/04_confounds/README.rst +0 -0
  138. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/04_confounds/plot_confound_removal_classification.py +0 -0
  139. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/04_confounds/run_return_confounds.py +0 -0
  140. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/05_customization/README.rst +0 -0
  141. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/05_customization/run_custom_scorers_regression.py +0 -0
  142. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/README.rst +0 -0
  143. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/run_cbpm_docs.py +0 -0
  144. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/run_confound_removal_docs.py +0 -0
  145. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/run_cv_splitters_docs.py +0 -0
  146. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/run_data_docs.py +0 -0
  147. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/run_hyperparameters_docs.py +0 -0
  148. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/run_model_comparison_docs.py +0 -0
  149. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/run_model_evaluation_docs.py +0 -0
  150. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/run_model_inspection_docs.py +0 -0
  151. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/run_pipeline_docs.py +0 -0
  152. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/run_stacked_models_docs.py +0 -0
  153. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/99_docs/run_target_transformer_docs.py +0 -0
  154. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/README.rst +0 -0
  155. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/XX_disabled/dis_run_n_jobs.py +0 -0
  156. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/examples/XX_disabled/dis_run_target_confound_removal.py +0 -0
  157. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/ignore_words.txt +0 -0
  158. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/__init__.py +0 -0
  159. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/api.py +0 -0
  160. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/base/__init__.py +0 -0
  161. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/base/column_types.py +0 -0
  162. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/base/estimators.py +0 -0
  163. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/base/tests/test_base_estimators.py +0 -0
  164. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/base/tests/test_column_types.py +0 -0
  165. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/config.py +0 -0
  166. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/conftest.py +0 -0
  167. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/external/optuna_searchcv.py +0 -0
  168. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/inspect/__init__.py +0 -0
  169. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/inspect/_cv.py +0 -0
  170. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/inspect/_pipeline.py +0 -0
  171. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/inspect/_preprocess.py +0 -0
  172. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/inspect/inspector.py +0 -0
  173. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/inspect/tests/test_cv.py +0 -0
  174. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/inspect/tests/test_inspector.py +0 -0
  175. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/inspect/tests/test_pipeline.py +0 -0
  176. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/inspect/tests/test_preprocess.py +0 -0
  177. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/__init__.py +0 -0
  178. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/_optuna_searcher.py +0 -0
  179. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/_skopt_searcher.py +0 -0
  180. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/available_searchers.py +0 -0
  181. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/continuous_stratified_kfold.py +0 -0
  182. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/final_model_cv.py +0 -0
  183. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/stratified_bootstrap.py +0 -0
  184. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/tests/test_available_searchers.py +0 -0
  185. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/tests/test_continous_stratified_kfold.py +0 -0
  186. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/tests/test_final_model_cv.py +0 -0
  187. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/tests/test_optuna_searcher.py +0 -0
  188. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/tests/test_skopt_searcher.py +0 -0
  189. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/tests/test_stratified_bootstrap.py +0 -0
  190. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/model_selection/utils.py +0 -0
  191. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/models/__init__.py +0 -0
  192. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/models/dynamic.py +0 -0
  193. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/models/tests/test_available_models.py +0 -0
  194. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/models/tests/test_dynamic.py +0 -0
  195. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/models/tests/test_models.py +0 -0
  196. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/pipeline/__init__.py +0 -0
  197. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/pipeline/merger.py +0 -0
  198. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/pipeline/pipeline_creator.py +0 -0
  199. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/pipeline/target_pipeline.py +0 -0
  200. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/pipeline/target_pipeline_creator.py +0 -0
  201. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/pipeline/tests/test_merger.py +0 -0
  202. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/pipeline/tests/test_pipeline_creator.py +0 -0
  203. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/pipeline/tests/test_target_pipeline.py +0 -0
  204. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/pipeline/tests/test_target_pipeline_creator.py +0 -0
  205. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/prepare.py +0 -0
  206. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/scoring/__init__.py +0 -0
  207. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/scoring/available_scorers.py +0 -0
  208. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/scoring/metrics.py +0 -0
  209. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/scoring/tests/test_available_scorers.py +0 -0
  210. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/scoring/tests/test_metrics.py +0 -0
  211. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/stats/__init__.py +0 -0
  212. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/stats/corrected_ttest.py +0 -0
  213. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/stats/tests/test_corrected_ttest.py +0 -0
  214. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/tests/test_api.py +0 -0
  215. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/tests/test_config.py +0 -0
  216. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/tests/test_prepare.py +0 -0
  217. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/__init__.py +0 -0
  218. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/available_transformers.py +0 -0
  219. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/cbpm.py +0 -0
  220. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/confound_remover.py +0 -0
  221. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/dataframe/__init__.py +0 -0
  222. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/dataframe/change_column_types.py +0 -0
  223. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/dataframe/drop_columns.py +0 -0
  224. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/dataframe/filter_columns.py +0 -0
  225. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/dataframe/pick_columns.py +0 -0
  226. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/dataframe/set_column_types.py +0 -0
  227. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/dataframe/tests/test_change_column_types.py +0 -0
  228. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/dataframe/tests/test_drop_columns.py +0 -0
  229. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/dataframe/tests/test_filter_columns.py +0 -0
  230. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/dataframe/tests/test_set_column_types.py +0 -0
  231. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/ju_column_transformer.py +0 -0
  232. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/target/__init__.py +0 -0
  233. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/target/available_target_transformers.py +0 -0
  234. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/target/ju_generated_target_model.py +0 -0
  235. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/target/ju_target_transformer.py +0 -0
  236. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/target/ju_transformed_target_model.py +0 -0
  237. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/target/target_confound_remover.py +0 -0
  238. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/target/tests/test_available_target_transformers.py +0 -0
  239. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/target/tests/test_ju_generated_target_model.py +0 -0
  240. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/target/tests/test_ju_target_transformer.py +0 -0
  241. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/target/tests/test_ju_transformed_target_model.py +0 -0
  242. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/target/tests/test_target_confound_remover.py +0 -0
  243. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/tests/test_available_transformers.py +0 -0
  244. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/tests/test_cbpm.py +0 -0
  245. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/tests/test_confounds.py +0 -0
  246. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/transformers/tests/test_jucolumntransformers.py +0 -0
  247. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/utils/__init__.py +0 -0
  248. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/utils/_cv.py +0 -0
  249. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/utils/checks.py +0 -0
  250. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/utils/logging.py +0 -0
  251. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/utils/testing.py +0 -0
  252. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/utils/tests/test_logging.py +0 -0
  253. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/utils/tests/test_version.py +0 -0
  254. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/utils/typing.py +0 -0
  255. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/utils/versions.py +0 -0
  256. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/viz/__init__.py +0 -0
  257. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/viz/_scores.py +0 -0
  258. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn/viz/res/julearn_logo_generalization.png +0 -0
  259. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn.egg-info/dependency_links.txt +0 -0
  260. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/julearn.egg-info/top_level.txt +0 -0
  261. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/setup.cfg +0 -0
  262. {julearn-0.3.6.dev47 → julearn-0.3.6.dev72}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: julearn
3
- Version: 0.3.6.dev47
3
+ Version: 0.3.6.dev72
4
4
  Summary: Juelich Machine Learning Library
5
5
  Author-email: Fede Raimondo <f.raimondo@fz-juelich.de>, Sami Hamdan <s.hamdan@fz-juelich.de>
6
6
  Maintainer-email: Sami Hamdan <s.hamdan@fz-juelich.de>
@@ -59,10 +59,12 @@ Requires-Dist: scikit-optimize<0.11.0,>=0.10.2; extra == "skopt"
59
59
  Provides-Extra: optuna
60
60
  Requires-Dist: optuna<5.0.0,>=4.0.0; extra == "optuna"
61
61
  Requires-Dist: optuna_integration<5.0.0,>=4.0.0; extra == "optuna"
62
+ Provides-Extra: xgboost
63
+ Requires-Dist: xgboost<4.0.0,>=3.0.0; extra == "xgboost"
62
64
  Provides-Extra: docs
63
- Requires-Dist: julearn[optuna,skopt,sphinx,viz]; extra == "docs"
65
+ Requires-Dist: julearn[optuna,skopt,sphinx,viz,xgboost]; extra == "docs"
64
66
  Provides-Extra: all
65
- Requires-Dist: julearn[optuna,skopt,viz]; extra == "all"
67
+ Requires-Dist: julearn[optuna,skopt,viz,xgboost]; extra == "all"
66
68
  Dynamic: license-file
67
69
 
68
70
  # julearn
@@ -19,6 +19,21 @@ Functions
19
19
  register_model
20
20
  reset_model_register
21
21
 
22
+ Julearn custom models
23
+ ---------------------
24
+
25
+ This is a list of models implemented by Julearn that are not simple wrappers
26
+ around existing models in other libraries but rather variants of existing
27
+ models or novel models.
28
+
29
+ .. autosummary::
30
+ :nosignatures:
31
+ :toctree: generated/
32
+ :template: class.rst
33
+
34
+ xgb_cvearlystopping.XGBClassifierCVEarlyStopping
35
+ xgb_cvearlystopping.XGBRegressorCVEarlyStopping
36
+
22
37
  Dynamic Selection (DESLib)
23
38
  ==========================
24
39
 
@@ -235,6 +235,20 @@ Ensemble
235
235
  - Y
236
236
  - Y
237
237
  - Y
238
+ * - ``xgb``
239
+ - XGBoost
240
+ - | :class:`~xgboost.XGBClassifier` and
241
+ | :class:`~xgboost.XGBRegressor`
242
+ - Y
243
+ - Y
244
+ - Y
245
+ * - ``xgb_cvearlystopping``
246
+ - XGBoost with Cross-Validation and Early Stopping
247
+ - | :class:`~julearn.models.xgb_cvearlystopping.XGBClassifierCVEarlyStopping` and
248
+ | :class:`~julearn.models.xgb_cvearlystopping.XGBRegressorCVEarlyStopping`
249
+ - Y
250
+ - Y
251
+ - Y
238
252
 
239
253
  Gaussian Processes
240
254
  ~~~~~~~~~~~~~~~~~~
@@ -231,6 +231,7 @@ intersphinx_mapping = {
231
231
  None,
232
232
  ),
233
233
  "panel": ("https://panel.holoviz.org/", None),
234
+ "xgboost": ("https://xgboost.readthedocs.io/en/stable/", None),
234
235
  }
235
236
 
236
237
  # -- sphinx.ext.extlinks configuration ---------------------------------------
@@ -18,7 +18,7 @@ version_tuple: tuple[int | str, ...]
18
18
  commit_id: str | None
19
19
  __commit_id__: str | None
20
20
 
21
- __version__ = version = '0.3.6.dev47'
22
- __version_tuple__ = version_tuple = (0, 3, 6, 'dev47')
21
+ __version__ = version = '0.3.6.dev72'
22
+ __version_tuple__ = version_tuple = (0, 3, 6, 'dev72')
23
23
 
24
- __commit_id__ = commit_id = 'g5501ac265'
24
+ __commit_id__ = commit_id = 'gbc239b21b'
@@ -46,6 +46,19 @@ from sklearn.naive_bayes import (
46
46
  )
47
47
  from sklearn.svm import SVC, SVR
48
48
 
49
+
50
+ try: # pragma: no cover
51
+ from xgboost import XGBClassifier, XGBRegressor
52
+
53
+ from .xgb_cvearlystopping import (
54
+ XGBClassifierCVEarlyStopping,
55
+ XGBRegressorCVEarlyStopping,
56
+ )
57
+
58
+ _has_xgboost = True
59
+ except ImportError:
60
+ _has_xgboost = False
61
+
49
62
  from ..utils import logger, raise_error, warn_with_log
50
63
  from ..utils.logging import DelayedFmtMessage as __
51
64
  from ..utils.typing import ModelLike
@@ -137,6 +150,24 @@ _available_models: dict[str, dict[str, Any]] = {
137
150
  },
138
151
  }
139
152
 
153
+ if _has_xgboost is True:
154
+ _available_models["xgb"] = {
155
+ "regression": XGBRegressor,
156
+ "classification": XGBClassifier,
157
+ }
158
+ _available_models["xgb_cvearlystopping"] = {
159
+ "regression": XGBRegressorCVEarlyStopping,
160
+ "classification": XGBClassifierCVEarlyStopping,
161
+ }
162
+ logger.info(
163
+ "XGBoost is available and has been added to the model registry."
164
+ )
165
+ else:
166
+ logger.info(
167
+ "XGBoost is not available and has not been added to the model "
168
+ "registry. To use XGBoost models, please install the xgboost package."
169
+ )
170
+
140
171
  _available_models_reset = deepcopy(_available_models)
141
172
 
142
173
 
@@ -0,0 +1,476 @@
1
+ """Provide tests for XGBEarlyStoppingCV."""
2
+
3
+ # Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
4
+ # License: AGPL
5
+
6
+ import pandas as pd
7
+ import pytest
8
+ from sklearn.utils.validation import _is_fitted
9
+
10
+ from julearn.models.xgb_cvearlystopping import (
11
+ XGBClassifierCVEarlyStopping,
12
+ XGBRegressorCVEarlyStopping,
13
+ )
14
+
15
+
16
def test_XGBRegressorCVEarlyStopping_grouped(df_iris) -> None:
    """Check XGBRegressorCVEarlyStopping when fitted with sample groups.

    Parameters
    ----------
    df_iris : pd.DataFrame
        The iris dataset as a DataFrame.

    """
    features = ["sepal_length", "sepal_width", "petal_width"]
    target = "petal_length"

    # Assign each row to one of 20 equal-width bins of the row index and
    # use the bin label as its group.
    n_bins = 20
    df_iris["group"] = pd.cut(
        df_iris.index.values, bins=n_bins, labels=list(range(n_bins))
    ).astype(int)

    estimator = XGBRegressorCVEarlyStopping(
        test_size=0.2, early_stopping_rounds=5, random_state=42
    )

    # Fresh estimator: nothing fitted, no split strategy recorded yet.
    assert _is_fitted(estimator) is False
    assert not hasattr(estimator, "_grouped_cv")

    estimator.fit(
        df_iris[features], df_iris[target], groups=df_iris["group"]
    )

    assert _is_fitted(estimator)
    assert hasattr(estimator, "_grouped_cv")
    assert estimator._grouped_cv is True

    # The refit inner model must not early-stop and must use exactly
    # ``best_iteration + 1`` boosting rounds.
    inner = estimator._model.get_params()
    assert inner["num_parallel_tree"] is None
    assert inner["early_stopping_rounds"] is None
    assert inner["random_state"] == 42
    assert estimator._best_iteration is not None
    assert inner["n_estimators"] == estimator._best_iteration + 1

    predictions = estimator.predict(df_iris[features])
    assert predictions.shape == (len(df_iris),)

    r2 = estimator.score(df_iris[features], df_iris[target])
    assert isinstance(r2, float)
58
+
59
+
60
def test_XGBRegressorCVEarlyStopping_notgrouped(df_iris) -> None:
    """Check XGBRegressorCVEarlyStopping when fitted without groups.

    Parameters
    ----------
    df_iris : pd.DataFrame
        The iris dataset as a DataFrame.

    """
    features = ["sepal_length", "sepal_width", "petal_width"]
    target = "petal_length"

    estimator = XGBRegressorCVEarlyStopping(
        test_size=0.2, early_stopping_rounds=5, random_state=42
    )

    assert _is_fitted(estimator) is False
    assert not hasattr(estimator, "_grouped_cv")

    estimator.fit(df_iris[features], df_iris[target])

    # Without ``groups`` the wrapper must record a non-grouped split.
    assert _is_fitted(estimator)
    assert hasattr(estimator, "_grouped_cv")
    assert estimator._grouped_cv is False

    # Refit inner model: no early stopping, ``best_iteration + 1`` rounds.
    inner = estimator._model.get_params()
    assert inner["num_parallel_tree"] is None
    assert inner["early_stopping_rounds"] is None
    assert inner["random_state"] == 42
    assert estimator._best_iteration is not None
    assert inner["n_estimators"] == estimator._best_iteration + 1
91
+
92
+
93
def test_XGBRegressorCVEarlyStopping_numpy(df_iris) -> None:
    """Check XGBRegressorCVEarlyStopping on numpy input with parallel trees.

    Parameters
    ----------
    df_iris : pd.DataFrame
        The iris dataset as a DataFrame.

    """
    features = ["sepal_length", "sepal_width", "petal_width"]
    target = "petal_length"
    trees_per_round = 2

    estimator = XGBRegressorCVEarlyStopping(
        test_size=0.2,
        early_stopping_rounds=5,
        random_state=42,
        num_parallel_tree=trees_per_round,
    )

    assert _is_fitted(estimator) is False
    assert not hasattr(estimator, "_grouped_cv")

    estimator.fit(df_iris[features].values, df_iris[target].values)

    assert _is_fitted(estimator)
    assert hasattr(estimator, "_grouped_cv")
    assert estimator._grouped_cv is False

    # The refit inner model keeps ``num_parallel_tree``; the test expects
    # ``n_estimators`` to scale with it.
    inner = estimator._model.get_params()
    assert inner["early_stopping_rounds"] is None
    assert inner["random_state"] == 42
    assert estimator._best_iteration is not None
    assert inner["num_parallel_tree"] == trees_per_round
    expected_trees = (estimator._best_iteration + 1) * trees_per_round
    assert inner["n_estimators"] == expected_trees
128
+
129
+
130
def test_XGBClassifierCVEarlyStopping_notgrouped(df_iris) -> None:
    """Check XGBClassifierCVEarlyStopping when fitted without groups.

    Parameters
    ----------
    df_iris : pd.DataFrame
        The iris dataset as a DataFrame.

    """
    features = ["sepal_length", "sepal_width", "petal_width"]
    target = "species"
    n_classes = 3

    estimator = XGBClassifierCVEarlyStopping(
        test_size=0.2, early_stopping_rounds=5, random_state=42
    )

    assert _is_fitted(estimator) is False
    assert not hasattr(estimator, "_grouped_cv")

    estimator.fit(df_iris[features], df_iris[target])

    assert _is_fitted(estimator)
    assert hasattr(estimator, "_grouped_cv")
    assert estimator._grouped_cv is False

    inner = estimator._model.get_params()
    assert inner["num_parallel_tree"] is None
    assert inner["early_stopping_rounds"] is None
    assert inner["random_state"] == 42
    assert estimator._best_iteration is not None

    # With three classes the expected tree count is three per round.
    assert inner["n_estimators"] == (
        (estimator._best_iteration + 1) * n_classes
    )
164
+
165
+
166
def test_XGBClassifierCVEarlyStopping_grouped(df_iris) -> None:
    """Check XGBClassifierCVEarlyStopping when fitted with sample groups.

    Parameters
    ----------
    df_iris : pd.DataFrame
        The iris dataset as a DataFrame.

    """
    features = ["sepal_length", "sepal_width", "petal_width"]
    target = "species"
    n_classes = 3

    # Group label = equal-width bin (out of 20) of the row index.
    n_bins = 20
    df_iris["group"] = pd.cut(
        df_iris.index.values, bins=n_bins, labels=list(range(n_bins))
    ).astype(int)

    estimator = XGBClassifierCVEarlyStopping(
        test_size=0.2, early_stopping_rounds=5, random_state=42
    )

    assert _is_fitted(estimator) is False
    assert not hasattr(estimator, "_grouped_cv")

    estimator.fit(
        df_iris[features], df_iris[target], groups=df_iris["group"]
    )

    assert _is_fitted(estimator)
    assert hasattr(estimator, "_grouped_cv")
    assert estimator._grouped_cv is True

    # The wrapper's own hyperparameters survive fitting unchanged.
    outer = estimator.get_params()
    for name, expected in (
        ("test_size", 0.2),
        ("early_stopping_rounds", 5),
        ("random_state", 42),
    ):
        assert outer[name] == expected

    inner = estimator._model.get_params()
    assert inner["early_stopping_rounds"] is None
    assert inner["random_state"] == 42
    assert estimator._best_iteration is not None

    # Three classes -> three trees per boosting round.
    assert inner["n_estimators"] == (
        (estimator._best_iteration + 1) * n_classes
    )

    predictions = estimator.predict(df_iris[features])
    assert predictions.shape == (len(df_iris),)
    assert set(predictions).issubset(set(df_iris[target]))

    probas = estimator.predict_proba(df_iris[features])
    assert probas.shape == (len(df_iris), n_classes)
    assert ((probas >= 0) & (probas <= 1)).all()

    accuracy = estimator.score(df_iris[features], df_iris[target])
    assert isinstance(accuracy, float)
218
+
219
+
220
def test_XGBClassifierCVEarlyStopping_binary(df_binary) -> None:
    """Check XGBClassifierCVEarlyStopping on a two-class problem.

    Parameters
    ----------
    df_binary : pd.DataFrame
        The binary classification dataset as a DataFrame.

    """
    features = ["sepal_length", "sepal_width", "petal_width"]
    target = "species"
    n_classes = 2

    # No ``random_state`` given: the inner model should inherit ``None``.
    estimator = XGBClassifierCVEarlyStopping(
        test_size=0.2, early_stopping_rounds=5
    )

    assert _is_fitted(estimator) is False
    assert not hasattr(estimator, "_grouped_cv")

    estimator.fit(df_binary[features], df_binary[target])

    assert _is_fitted(estimator)
    assert hasattr(estimator, "_grouped_cv")
    assert estimator._grouped_cv is False

    outer = estimator.get_params()
    assert outer["test_size"] == 0.2
    assert outer["early_stopping_rounds"] == 5

    inner = estimator._model.get_params()
    assert inner["early_stopping_rounds"] is None
    assert inner["random_state"] is None
    assert estimator._best_iteration is not None

    # The test expects two trees per boosting round here.
    assert inner["n_estimators"] == (
        (estimator._best_iteration + 1) * n_classes
    )

    predictions = estimator.predict(df_binary[features])
    assert predictions.shape == (len(df_binary),)
    assert set(predictions).issubset(set(df_binary[target]))

    probas = estimator.predict_proba(df_binary[features])
    assert probas.shape == (len(df_binary), n_classes)
    assert ((probas >= 0) & (probas <= 1)).all()

    accuracy = estimator.score(df_binary[features], df_binary[target])
    assert isinstance(accuracy, float)
264
+
265
+
266
def test_XGBClassifierCVEarlyStopping_grouped_numpy(df_iris) -> None:
    """Test XGBClassifierCVEarlyStopping with grouped data and numpy arrays.

    Features, target and groups are all passed as plain numpy arrays, both
    for fitting and for predicting/scoring, so the numpy code path is
    exercised end to end.

    Parameters
    ----------
    df_iris : pd.DataFrame
        The iris dataset as a DataFrame.

    """
    X = ["sepal_length", "sepal_width", "petal_width"]
    y = "species"
    n_groups = 20
    bins = pd.cut(
        df_iris.index.values, labels=list(range(n_groups)), bins=n_groups
    )
    df_iris["group"] = bins.astype(int)

    # ``Series.to_numpy`` always returns a numpy ndarray regardless of the
    # column dtype. ``Series.values`` may return an ndarray (which has no
    # ``to_numpy`` method) or an extension array, so ``.values.to_numpy()``
    # is fragile.
    X_np = df_iris[X].to_numpy()
    y_np = df_iris[y].to_numpy()
    groups_np = df_iris["group"].to_numpy()

    model = XGBClassifierCVEarlyStopping(
        test_size=0.2, early_stopping_rounds=5, random_state=42
    )

    assert _is_fitted(model) is False
    assert not hasattr(model, "_grouped_cv")
    model.fit(X_np, y_np, groups=groups_np)
    assert _is_fitted(model)
    assert hasattr(model, "_grouped_cv")
    assert model._grouped_cv is True
    assert model.get_params()["test_size"] == 0.2
    assert model.get_params()["early_stopping_rounds"] == 5
    assert model.get_params()["random_state"] == 42

    # Check that the model was refit with the best number of iterations
    assert model._model.get_params()["early_stopping_rounds"] is None
    assert model._model.get_params()["random_state"] == 42
    assert model._best_iteration is not None

    # Three classes, so the number of trees is the best iteration times 3
    assert (
        model._model.get_params()["n_estimators"]
        == (model._best_iteration + 1) * 3
    )

    # Predict/score on numpy too, matching the input type used for fitting.
    y_pred = model.predict(X_np)
    assert y_pred.shape == (len(df_iris),)
    assert set(y_pred).issubset(set(y_np))

    y_probas = model.predict_proba(X_np)
    assert y_probas.shape == (len(df_iris), 3)
    assert (y_probas >= 0).all() and (y_probas <= 1).all()

    score = model.score(X_np, y_np)
    assert isinstance(score, float)
322
+
323
+
324
def test_XGBClassifierCVEarlyStopping_errors() -> None:
    """Test XGBClassifierCVEarlyStopping error handling.

    Covers a constructor that rejects ``early_stopping_rounds=None`` and
    ``predict`` / ``predict_proba`` calls on an unfitted estimator.
    """
    # The constructor itself must raise, so there is no object to bind.
    with pytest.raises(ValueError, match="early_stopping_rounds"):
        XGBClassifierCVEarlyStopping(
            test_size=0.2, early_stopping_rounds=None, random_state=42
        )

    with pytest.raises(ValueError, match="not fitted"):
        model = XGBClassifierCVEarlyStopping(
            test_size=None, early_stopping_rounds=5, random_state=42
        )
        model.predict([[1, 2], [3, 4], [5, 6]])

    with pytest.raises(ValueError, match="not fitted"):
        model = XGBClassifierCVEarlyStopping(
            test_size=None, early_stopping_rounds=5, random_state=42
        )
        model.predict_proba([[1, 2], [3, 4], [5, 6]])
342
+
343
+
344
def test_XGBClassifierCVEarlyStopping_numpy(df_iris) -> None:
    """Test XGBClassifierCVEarlyStopping with numpy data.

    First fits the multiclass (three species) problem, then refits the
    same estimator on a boolean "is setosa" target to check the binary case.

    Parameters
    ----------
    df_iris : pd.DataFrame
        The iris dataset as a DataFrame.

    """
    X = ["sepal_length", "sepal_width", "petal_width"]
    y = "species"

    # ``Series.to_numpy`` always yields a numpy ndarray. ``Series.values``
    # can be an extension array, and calling ``.to_numpy()`` on it breaks
    # when it is already a plain ndarray.
    X_np = df_iris[X].to_numpy()
    y_np = df_iris[y].to_numpy()

    model = XGBClassifierCVEarlyStopping(
        test_size=0.2, early_stopping_rounds=5, random_state=42
    )

    assert _is_fitted(model) is False
    assert not hasattr(model, "_grouped_cv")
    model.fit(X_np, y_np)
    assert _is_fitted(model)
    assert hasattr(model, "_grouped_cv")
    assert model._grouped_cv is False

    # Check that the model was refit with the best number of iterations
    assert model._model.get_params()["early_stopping_rounds"] is None
    assert model._model.get_params()["random_state"] == 42
    assert model._best_iteration is not None

    # Three classes, so the number of trees is the best iteration times 3
    assert (
        model._model.get_params()["n_estimators"]
        == (model._best_iteration + 1) * 3
    )

    # Boolean (non-string) binary target.
    y_nostring = y_np == "setosa"

    model.fit(X_np, y_nostring)
    assert _is_fitted(model)
    assert hasattr(model, "_grouped_cv")
    assert model._grouped_cv is False

    # Check that the model was refit with the best number of iterations
    assert model._model.get_params()["early_stopping_rounds"] is None
    assert model._model.get_params()["random_state"] == 42
    assert model._best_iteration is not None

    # Two classes, so the number of trees is the best iteration times 2
    assert (
        model._model.get_params()["n_estimators"]
        == (model._best_iteration + 1) * 2
    )

    y_pred = model.predict(X_np)
    assert y_pred.shape == (len(df_iris),)
    assert set(y_pred).issubset(set(y_nostring))

    y_probas = model.predict_proba(X_np)
    assert y_probas.shape == (len(df_iris), 2)
    assert (y_probas >= 0).all() and (y_probas <= 1).all()

    score = model.score(X_np, y_nostring)
    assert isinstance(score, float)
406
+
407
+
408
def test_XGBClassifierCVEarlyStopping_set_params(df_iris) -> None:
    """Test XGBClassifierCVEarlyStopping ``set_params`` followed by a refit.

    Fits once with the initial parameters, then updates ``test_size``,
    ``early_stopping_rounds``, ``random_state`` and ``num_parallel_tree``
    via ``set_params``. It checks that the new values are reported by
    ``get_params`` and propagated to the refit inner model.

    Parameters
    ----------
    df_iris : pd.DataFrame
        The iris dataset as a DataFrame.

    """
    X = ["sepal_length", "sepal_width", "petal_width"]
    y = "species"
    n_groups = 20
    bins = pd.cut(
        df_iris.index.values, labels=list(range(n_groups)), bins=n_groups
    )
    df_iris["group"] = bins.astype(int)

    model = XGBClassifierCVEarlyStopping(
        test_size=0.2, early_stopping_rounds=5, random_state=42
    )

    assert _is_fitted(model) is False
    assert not hasattr(model, "_grouped_cv")
    model.fit(df_iris[X], df_iris[y], groups=df_iris["group"])
    assert _is_fitted(model)
    assert hasattr(model, "_grouped_cv")
    assert model._grouped_cv is True
    assert model.get_params()["test_size"] == 0.2
    assert model.get_params()["early_stopping_rounds"] == 5
    assert model.get_params()["random_state"] == 42

    # Check that the model was refit with the best number of iterations
    assert model._model.get_params()["early_stopping_rounds"] is None
    assert model._model.get_params()["random_state"] == 42
    assert model._best_iteration is not None

    # Three classes, so the number of trees is the best iteration times 3
    assert (
        model._model.get_params()["n_estimators"]
        == (model._best_iteration + 1) * 3
    )

    model.set_params(
        test_size=0.3,
        early_stopping_rounds=10,
        random_state=24,
        num_parallel_tree=2,
    )
    assert model.get_params()["test_size"] == 0.3
    assert model.get_params()["early_stopping_rounds"] == 10
    assert model.get_params()["random_state"] == 24
    assert model.get_params()["num_parallel_tree"] == 2
    model.fit(df_iris[X], df_iris[y], groups=df_iris["group"])
    assert _is_fitted(model)
    assert hasattr(model, "_grouped_cv")
    assert model._grouped_cv is True
    assert model.get_params()["test_size"] == 0.3
    assert model.get_params()["early_stopping_rounds"] == 10
    assert model.get_params()["random_state"] == 24
    assert model.get_params()["num_parallel_tree"] == 2
    # Check that the model was refit with the best number of iterations
    assert model._model.get_params()["early_stopping_rounds"] is None
    assert model._model.get_params()["random_state"] == 24
    assert model._best_iteration is not None
    # Three classes and two parallel trees per round, so the number of
    # trees is the best iteration times 3 times 2
    assert (
        model._model.get_params()["n_estimators"]
        == (model._best_iteration + 1) * 3 * 2
    )