julearn 0.3.2.dev24__tar.gz → 0.3.2.dev61__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (237) hide show
  1. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/PKG-INFO +6 -1
  2. julearn-0.3.2.dev61/docs/changes/newsfragments/260.enh +1 -0
  3. julearn-0.3.2.dev61/docs/changes/newsfragments/260.misc +1 -0
  4. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/conf.py +1 -0
  5. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/getting_started.rst +5 -1
  6. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/links.inc +1 -0
  7. julearn-0.3.2.dev61/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py +95 -0
  8. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/run_hyperparameters_docs.py +123 -13
  9. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/_version.py +2 -2
  10. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/api.py +46 -24
  11. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/base/estimators.py +26 -8
  12. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/base/tests/test_base_estimators.py +1 -1
  13. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/conftest.py +134 -1
  14. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/inspect/_cv.py +16 -10
  15. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/inspect/_preprocess.py +1 -1
  16. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/inspect/inspector.py +8 -5
  17. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/inspect/tests/test_cv.py +4 -2
  18. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/inspect/tests/test_inspector.py +12 -9
  19. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/inspect/tests/test_pipeline.py +31 -18
  20. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/inspect/tests/test_preprocess.py +8 -1
  21. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/model_selection/__init__.py +4 -0
  22. julearn-0.3.2.dev61/julearn/model_selection/_skopt_searcher.py +32 -0
  23. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/model_selection/available_searchers.py +69 -8
  24. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/model_selection/stratified_bootstrap.py +7 -5
  25. julearn-0.3.2.dev61/julearn/model_selection/tests/test_available_searchers.py +83 -0
  26. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/models/dynamic.py +1 -1
  27. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/models/tests/test_models.py +6 -6
  28. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/pipeline/merger.py +44 -35
  29. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/pipeline/pipeline_creator.py +88 -21
  30. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/pipeline/target_pipeline.py +1 -1
  31. {julearn-0.3.2.dev24/julearn/pipeline/test → julearn-0.3.2.dev61/julearn/pipeline/tests}/test_merger.py +27 -15
  32. {julearn-0.3.2.dev24/julearn/pipeline/test → julearn-0.3.2.dev61/julearn/pipeline/tests}/test_pipeline_creator.py +231 -8
  33. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/prepare.py +2 -4
  34. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/scoring/available_scorers.py +12 -6
  35. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/stats/corrected_ttest.py +8 -4
  36. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/stats/tests/test_corrected_ttest.py +22 -5
  37. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/tests/test_api.py +103 -45
  38. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/confound_remover.py +2 -2
  39. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/dataframe/tests/test_drop_columns.py +1 -1
  40. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/dataframe/tests/test_filter_columns.py +1 -1
  41. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/dataframe/tests/test_set_column_types.py +10 -6
  42. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/ju_column_transformer.py +1 -1
  43. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/target/ju_transformed_target_model.py +1 -1
  44. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/target/target_confound_remover.py +1 -1
  45. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/target/tests/test_ju_transformed_target_model.py +3 -1
  46. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/tests/test_cbpm.py +3 -3
  47. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/tests/test_confounds.py +17 -6
  48. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/tests/test_jucolumntransformers.py +8 -6
  49. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/utils/checks.py +3 -1
  50. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/utils/logging.py +1 -1
  51. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/utils/testing.py +21 -11
  52. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/utils/typing.py +39 -19
  53. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/viz/_scores.py +5 -4
  54. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn.egg-info/PKG-INFO +6 -1
  55. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn.egg-info/SOURCES.txt +8 -4
  56. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn.egg-info/requires.txt +7 -0
  57. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/pyproject.toml +25 -0
  58. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/tox.ini +15 -4
  59. julearn-0.3.2.dev24/julearn/model_selection/tests/test_available_searchers.py +0 -44
  60. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  61. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  62. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.github/ISSUE_TEMPLATE/documentation_request.yaml +0 -0
  63. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  64. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.github/workflows/check-stale.yml +0 -0
  65. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.github/workflows/ci-docs.yml +0 -0
  66. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.github/workflows/ci.yml +0 -0
  67. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.github/workflows/docs-preview.yml +0 -0
  68. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.github/workflows/docs.yml +0 -0
  69. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.github/workflows/lint.yml +0 -0
  70. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.github/workflows/pypi.yml +0 -0
  71. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.gitignore +0 -0
  72. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/.pre-commit-config.yaml +0 -0
  73. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/AUTHORS.rst +0 -0
  74. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/LICENSE.md +0 -0
  75. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/README.md +0 -0
  76. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/codecov.yml +0 -0
  77. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/Makefile +0 -0
  78. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/_static/css/custom.css +0 -0
  79. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/_static/js/custom.js +0 -0
  80. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/_templates/class.rst +0 -0
  81. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/_templates/function.rst +0 -0
  82. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/_templates/function_warning.rst +0 -0
  83. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/_templates/versions.html +0 -0
  84. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/base.rst +0 -0
  85. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/index.rst +0 -0
  86. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/inspect.rst +0 -0
  87. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/main.rst +0 -0
  88. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/model_selection.rst +0 -0
  89. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/models.rst +0 -0
  90. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/pipeline.rst +0 -0
  91. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/prepare.rst +0 -0
  92. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/scoring.rst +0 -0
  93. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/stats.rst +0 -0
  94. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/transformers.rst +0 -0
  95. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/utils.rst +0 -0
  96. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/api/viz.rst +0 -0
  97. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/available_pipeline_steps.rst +0 -0
  98. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/changes/contributors.inc +0 -0
  99. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/changes/newsfragments/.gitignore +0 -0
  100. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/changes/newsfragments/224.misc +0 -0
  101. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/changes/newsfragments/244.misc +0 -0
  102. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/changes/newsfragments/249.bugfix +0 -0
  103. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/changes/newsfragments/251.misc +0 -0
  104. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/changes/newsfragments/255.bugfix +0 -0
  105. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/configuration.rst +0 -0
  106. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/contributing.rst +0 -0
  107. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/examples.rst +0 -0
  108. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/faq.rst +0 -0
  109. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/corrected_ttest.png +0 -0
  110. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/final_estimator.png +0 -0
  111. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/iris_X.png +0 -0
  112. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/iris_df.png +0 -0
  113. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/iris_y.png +0 -0
  114. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/julearn_logo.png +0 -0
  115. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/julearn_logo_calm.png +0 -0
  116. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/julearn_logo_confbias.png +0 -0
  117. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/julearn_logo_cv.png +0 -0
  118. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/julearn_logo_generalization.png +0 -0
  119. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/julearn_logo_it.png +0 -0
  120. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/julearn_logo_ml.png +0 -0
  121. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/julearn_logo_mlit.png +0 -0
  122. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/multiple_scorers_run_cv.png +0 -0
  123. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/plot_scores.png +0 -0
  124. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/scores_run_cv.png +0 -0
  125. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/scores_run_cv_splitter.png +0 -0
  126. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/images/scores_run_cv_train.png +0 -0
  127. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/index.rst +0 -0
  128. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/maintaining.rst +0 -0
  129. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/redirect.html +0 -0
  130. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/selected_deeper_topics/CBPM.rst +0 -0
  131. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/selected_deeper_topics/confound_removal.rst +0 -0
  132. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/selected_deeper_topics/cross_validation_splitter.rst +0 -0
  133. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/selected_deeper_topics/hyperparameter_tuning.rst +0 -0
  134. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/selected_deeper_topics/index.rst +0 -0
  135. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/selected_deeper_topics/model_inspect.rst +0 -0
  136. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/selected_deeper_topics/stacked_models.rst +0 -0
  137. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/selected_deeper_topics/target_transformers.rst +0 -0
  138. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/sphinxext/gh_substitutions.py +0 -0
  139. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/what_really_need_know/cross_validation.rst +0 -0
  140. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/what_really_need_know/data.rst +0 -0
  141. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/what_really_need_know/index.rst +0 -0
  142. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/what_really_need_know/model_comparison.rst +0 -0
  143. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/what_really_need_know/model_evaluation.rst +0 -0
  144. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/what_really_need_know/pipeline.rst +0 -0
  145. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/docs/whats_new.rst +0 -0
  146. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/00_starting/README.rst +0 -0
  147. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/00_starting/plot_cm_acc_multiclass.py +0 -0
  148. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/00_starting/plot_example_regression.py +0 -0
  149. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/00_starting/plot_stratified_kfold_reg.py +0 -0
  150. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/00_starting/run_combine_pandas.py +0 -0
  151. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/00_starting/run_grouped_cv.py +0 -0
  152. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/00_starting/run_simple_binary_classification.py +0 -0
  153. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/01_model_comparison/README.rst +0 -0
  154. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/01_model_comparison/plot_simple_model_comparison.py +0 -0
  155. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/02_inspection/README.rst +0 -0
  156. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/02_inspection/plot_groupcv_inspect_svm.py +0 -0
  157. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/02_inspection/plot_inspect_random_forest.py +0 -0
  158. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/02_inspection/plot_preprocess.py +0 -0
  159. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/02_inspection/run_binary_inspect_folds.py +0 -0
  160. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/03_complex_models/README.rst +0 -0
  161. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/03_complex_models/run_apply_to_target.py +0 -0
  162. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/03_complex_models/run_example_pca_featsets.py +0 -0
  163. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/03_complex_models/run_hyperparameter_multiple_grids.py +0 -0
  164. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/03_complex_models/run_hyperparameter_tuning.py +0 -0
  165. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/03_complex_models/run_stacked_models.py +0 -0
  166. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/04_confounds/README.rst +0 -0
  167. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/04_confounds/plot_confound_removal_classification.py +0 -0
  168. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/04_confounds/run_return_confounds.py +0 -0
  169. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/05_customization/README.rst +0 -0
  170. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/05_customization/run_custom_scorers_regression.py +0 -0
  171. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/README.rst +0 -0
  172. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/run_cbpm_docs.py +0 -0
  173. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/run_confound_removal_docs.py +0 -0
  174. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/run_cv_splitters_docs.py +0 -0
  175. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/run_data_docs.py +0 -0
  176. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/run_model_comparison_docs.py +0 -0
  177. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/run_model_evaluation_docs.py +0 -0
  178. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/run_model_inspection_docs.py +0 -0
  179. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/run_pipeline_docs.py +0 -0
  180. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/run_stacked_models_docs.py +0 -0
  181. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/99_docs/run_target_transformer_docs.py +0 -0
  182. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/README.rst +0 -0
  183. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/XX_disabled/dis_run_n_jobs.py +0 -0
  184. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/examples/XX_disabled/dis_run_target_confound_removal.py +0 -0
  185. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/ignore_words.txt +0 -0
  186. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/__init__.py +0 -0
  187. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/base/__init__.py +0 -0
  188. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/base/column_types.py +0 -0
  189. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/base/tests/test_column_types.py +0 -0
  190. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/config.py +0 -0
  191. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/inspect/__init__.py +0 -0
  192. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/inspect/_pipeline.py +0 -0
  193. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/model_selection/continuous_stratified_kfold.py +0 -0
  194. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/model_selection/tests/test_continous_stratified_kfold.py +0 -0
  195. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/model_selection/tests/test_stratified_bootstrap.py +0 -0
  196. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/models/__init__.py +0 -0
  197. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/models/available_models.py +0 -0
  198. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/models/tests/test_available_models.py +0 -0
  199. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/models/tests/test_dynamic.py +0 -0
  200. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/pipeline/__init__.py +0 -0
  201. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/pipeline/target_pipeline_creator.py +0 -0
  202. {julearn-0.3.2.dev24/julearn/pipeline/test → julearn-0.3.2.dev61/julearn/pipeline/tests}/test_target_pipeline.py +0 -0
  203. {julearn-0.3.2.dev24/julearn/pipeline/test → julearn-0.3.2.dev61/julearn/pipeline/tests}/test_target_pipeline_creator.py +0 -0
  204. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/scoring/__init__.py +0 -0
  205. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/scoring/metrics.py +0 -0
  206. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/scoring/tests/test_available_scorers.py +0 -0
  207. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/scoring/tests/test_metrics.py +0 -0
  208. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/stats/__init__.py +0 -0
  209. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/tests/test_config.py +0 -0
  210. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/tests/test_prepare.py +0 -0
  211. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/__init__.py +0 -0
  212. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/available_transformers.py +0 -0
  213. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/cbpm.py +0 -0
  214. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/dataframe/__init__.py +0 -0
  215. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/dataframe/change_column_types.py +0 -0
  216. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/dataframe/drop_columns.py +0 -0
  217. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/dataframe/filter_columns.py +0 -0
  218. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/dataframe/set_column_types.py +0 -0
  219. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/dataframe/tests/test_change_column_types.py +0 -0
  220. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/target/__init__.py +0 -0
  221. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/target/available_target_transformers.py +0 -0
  222. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/target/ju_target_transformer.py +0 -0
  223. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/target/tests/test_available_target_transformers.py +0 -0
  224. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/target/tests/test_ju_target_transformer.py +0 -0
  225. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/target/tests/test_target_confound_remover.py +0 -0
  226. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/transformers/tests/test_available_transformers.py +0 -0
  227. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/utils/__init__.py +0 -0
  228. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/utils/_cv.py +0 -0
  229. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/utils/tests/test_logging.py +0 -0
  230. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/utils/tests/test_version.py +0 -0
  231. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/utils/versions.py +0 -0
  232. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/viz/__init__.py +0 -0
  233. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn/viz/res/julearn_logo_generalization.png +0 -0
  234. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn.egg-info/dependency_links.txt +0 -0
  235. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/julearn.egg-info/top_level.txt +0 -0
  236. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/setup.cfg +0 -0
  237. {julearn-0.3.2.dev24 → julearn-0.3.2.dev61}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: julearn
3
- Version: 0.3.2.dev24
3
+ Version: 0.3.2.dev61
4
4
  Summary: Juelich Machine Learning Library
5
5
  Author-email: Fede Raimondo <f.raimondo@fz-juelich.de>, Sami Hamdan <s.hamdan@fz-juelich.de>
6
6
  Maintainer-email: Sami Hamdan <s.hamdan@fz-juelich.de>
@@ -40,12 +40,17 @@ Requires-Dist: furo<2024.0.0,>=2022.9.29; extra == "docs"
40
40
  Requires-Dist: sphinx_copybutton<0.6,>=0.5.0; extra == "docs"
41
41
  Requires-Dist: numpydoc<1.6,>=1.5.0; extra == "docs"
42
42
  Requires-Dist: towncrier<24; extra == "docs"
43
+ Requires-Dist: scikit-optimize<0.11,>=0.10.0; extra == "docs"
43
44
  Provides-Extra: deslib
44
45
  Requires-Dist: deslib<0.4,>=0.3.5; extra == "deslib"
45
46
  Provides-Extra: viz
46
47
  Requires-Dist: panel>=1.3.0; extra == "viz"
47
48
  Requires-Dist: bokeh>=3.0.0; extra == "viz"
48
49
  Requires-Dist: param>=2.0.0; extra == "viz"
50
+ Provides-Extra: skopt
51
+ Requires-Dist: scikit-optimize<0.11,>=0.10.0; extra == "skopt"
52
+ Provides-Extra: all
53
+ Requires-Dist: julearn[skopt,viz]; extra == "all"
49
54
 
50
55
  # julearn
51
56
 
@@ -0,0 +1 @@
1
+ Add :class:`~skopt.BayesSearchCV` to the list of available searchers as 'bayes' by `Fede Raimondo`_
@@ -0,0 +1 @@
1
+ Add ``all`` as optional dependencies to install all functional dependencies by `Fede Raimondo`_
@@ -160,6 +160,7 @@ intersphinx_mapping = {
160
160
  # "sqlalchemy": ("https://docs.sqlalchemy.org/en/20/", None),
161
161
  "joblib": ("https://joblib.readthedocs.io/en/latest/", None),
162
162
  "scipy": ("https://docs.scipy.org/doc/scipy/", None),
163
+ "skopt": ("https://scikit-optimize.readthedocs.io/en/latest", None),
163
164
  }
164
165
 
165
166
 
@@ -86,4 +86,8 @@ The following optional dependencies are available:
86
86
 
87
87
  * ``viz``: Visualization tools for ``julearn``. This includes the
88
88
  :mod:`.viz` module.
89
- * ``deslib``: The :mod:`.dynamic` module requires the `deslib`_ package.
89
+ * ``deslib``: The :mod:`.dynamic` module requires the `deslib`_ package. This
90
+ module is not compatible with newer Python versions and it is unmaintained.
91
+ * ``skopt``: Using the ``"bayes"`` searcher (:class:`~skopt.BayesSearchCV`)
92
+ requires the `scikit-optimize`_ package.
93
+ * ``all``: Install all optional functional dependencies (except ``deslib``).
@@ -40,3 +40,4 @@
40
40
 
41
41
 
42
42
  .. _`DESlib`: https://github.com/scikit-learn-contrib/DESlib
43
+ .. _`scikit-optimize`: https://scikit-optimize.readthedocs.io/en/stable/
@@ -0,0 +1,95 @@
1
+ """
2
+ Tuning Hyperparameters using Bayesian Search
3
+ ============================================
4
+
5
+ This example uses the ``fmri`` dataset, performs simple binary classification
6
+ using a Support Vector Machine classifier and analyzes the model.
7
+
8
+ References
9
+ ----------
10
+
11
+ Waskom, M.L., Frank, M.C., Wagner, A.D. (2016). Adaptive engagement of
12
+ cognitive control in context-dependent decision-making. Cerebral Cortex.
13
+
14
+ .. include:: ../../links.inc
15
+ """
16
+
17
+ # Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
18
+ # License: AGPL
19
+
20
+ import numpy as np
21
+ from seaborn import load_dataset
22
+
23
+ from julearn import run_cross_validation
24
+ from julearn.utils import configure_logging, logger
25
+ from julearn.pipeline import PipelineCreator
26
+
27
+
28
+ ###############################################################################
29
+ # Set the logging level to info to see extra information.
30
+ configure_logging(level="INFO")
31
+
32
+ ###############################################################################
33
+ # Set the random seed to always have the same example.
34
+ np.random.seed(42)
35
+
36
+ ###############################################################################
37
+ # Load the dataset.
38
+ df_fmri = load_dataset("fmri")
39
+ df_fmri.head()
40
+
41
+ ###############################################################################
42
+ # Set the dataframe in the right format.
43
+ df_fmri = df_fmri.pivot(
44
+ index=["subject", "timepoint", "event"], columns="region", values="signal"
45
+ )
46
+
47
+ df_fmri = df_fmri.reset_index()
48
+ df_fmri.head()
49
+
50
+ ###############################################################################
51
+ # Following the hyperparameter tuning example, we will now use a Bayesian
52
+ # search to find the best hyperparameters for the SVM model.
53
+ X = ["frontal", "parietal"]
54
+ y = "event"
55
+
56
+ creator1 = PipelineCreator(problem_type="classification")
57
+ creator1.add("zscore")
58
+ creator1.add(
59
+ "svm",
60
+ kernel=["linear"],
61
+ C=(1e-6, 1e3, "log-uniform"),
62
+ )
63
+
64
+ creator2 = PipelineCreator(problem_type="classification")
65
+ creator2.add("zscore")
66
+ creator2.add(
67
+ "svm",
68
+ kernel=["rbf"],
69
+ C=(1e-6, 1e3, "log-uniform"),
70
+ gamma=(1e-6, 1e1, "log-uniform"),
71
+ )
72
+
73
+ search_params = {
74
+ "kind": "bayes",
75
+ "cv": 2, # to speed up the example
76
+ "n_iter": 10, # 10 iterations of bayesian search to speed up example
77
+ }
78
+
79
+
80
+ scores, estimator = run_cross_validation(
81
+ X=X,
82
+ y=y,
83
+ data=df_fmri,
84
+ model=[creator1, creator2],
85
+ cv=2, # to speed up the example
86
+ search_params=search_params,
87
+ return_estimator="final",
88
+ )
89
+
90
+ print(scores["test_score"].mean())
91
+
92
+
93
+ ###############################################################################
94
+ # It seems that we might have found a better model, but which one is it?
95
+ print(estimator.best_params_)
@@ -243,22 +243,132 @@ pprint(model_tuned.best_params_)
243
243
  # tries to find the best combination of values for the hyperparameters using
244
244
  # cross-validation.
245
245
  #
246
- # By default, ``julearn`` uses a :class:`~sklearn.model_selection.GridSearchCV`.
247
- # This searcher is very simple. First, it construct the "grid" of
248
- # hyperparameters to try. As we see above, we have 3 hyperparameters to tune.
249
- # So it constructs a 3-dimentional grid with all the possible combinations of
250
- # the hyperparameters values. The second step is to perform cross-validation
251
- # on each of the possible combinations of hyperparameters values.
246
+ # By default, ``julearn`` uses a
247
+ # :class:`~sklearn.model_selection.GridSearchCV`.
248
+ # This searcher, specified as ``"grid"``, is very simple. First, it constructs
249
+ # the _grid_ of hyperparameters to try. As we see above, we have 3
250
+ # hyperparameters to tune. So it constructs a 3-dimensional grid with all the
251
+ # possible combinations of the hyperparameters values. The second step is to
252
+ # perform cross-validation on each of the possible combinations of
253
+ # hyperparameters values.
252
254
  #
253
- # Another searcher that ``julearn`` provides is the
254
- # :class:`~sklearn.model_selection.RandomizedSearchCV`. This searcher is
255
- # similar to the :class:`~sklearn.model_selection.GridSearchCV`, but instead
256
- # of trying all the possible combinations of hyperparameters values, it tries
255
+ # Other searchers that ``julearn`` provides are the
256
+ # :class:`~sklearn.model_selection.RandomizedSearchCV` and
257
+ # :class:`~skopt.BayesSearchCV`.
258
+ #
259
+ # The randomized searcher
260
+ # (:class:`~sklearn.model_selection.RandomizedSearchCV`) is similar to the
261
+ # :class:`~sklearn.model_selection.GridSearchCV`, but instead
262
+ # of trying all the possible combinations of hyperparameter values, it tries
257
263
  # a random subset of them. This is useful when we have a lot of hyperparameters
258
- # to tune, since it can be very time consuming to try all the possible, as well
259
- # as continuous parameters that can be sampled out of a distribution. For
260
- # more information, see the
264
+ # to tune, since it can be very time consuming to try all the possible
265
+ # combinations, as well as continuous parameters that can be sampled out of a
266
+ # distribution. For more information, see the
261
267
  # :class:`~sklearn.model_selection.RandomizedSearchCV` documentation.
268
+ #
269
+ # The Bayesian searcher (:class:`~skopt.BayesSearchCV`) is a bit more
270
+ # complex. It uses Bayesian optimization to find the best hyperparameter set.
271
+ # As with the randomized search, it is useful when we have many
272
+ # hyperparameters to tune, and we don't want to try all the possible
273
+ # combinations due to computational constraints. For more information, see the
274
+ # :class:`~skopt.BayesSearchCV` documentation, including how to specify
275
+ # the prior distributions of the hyperparameters.
276
+ #
277
+ # We can specify the kind of searcher and its parametrization, by setting the
278
+ # ``search_params`` parameter in the :func:`.run_cross_validation` function.
279
+ # For example, we can use the
280
+ # :class:`~sklearn.model_selection.RandomizedSearchCV` searcher with
281
+ # 10 iterations of random search.
282
+
283
+ search_params = {
284
+ "kind": "random",
285
+ "n_iter": 10,
286
+ }
287
+
288
+ scores_tuned, model_tuned = run_cross_validation(
289
+ X=X,
290
+ y=y,
291
+ data=df,
292
+ X_types=X_types,
293
+ model=creator,
294
+ return_estimator="all",
295
+ search_params=search_params,
296
+ )
297
+
298
+ print(
299
+ "Scores with best hyperparameter using 10 iterations of "
300
+ f"randomized search: {scores_tuned['test_score'].mean()}"
301
+ )
302
+ pprint(model_tuned.best_params_)
303
+
304
+ ###############################################################################
305
+ # We can now see that the best hyperparameter might be different from the grid
306
+ # search. This is because it tried only 10 combinations and not the whole grid.
307
+ # Furthermore, the :class:`~sklearn.model_selection.RandomizedSearchCV`
308
+ # searcher can sample hyperparameters from distributions, which can be useful
309
+ # when we have continuous hyperparameters.
310
+ # Let's set both ``C`` and ``gamma`` to be sampled from log-uniform
311
+ # distributions. We can do this by setting the hyperparameter values as a
312
+ # tuple with the following format: ``(low, high, distribution)``. The
313
+ # distribution can be either ``"log-uniform"`` or ``"uniform"``.
314
+
315
+ creator = PipelineCreator(problem_type="classification")
316
+ creator.add("zscore")
317
+ creator.add("select_k", k=[2, 3, 4])
318
+ creator.add(
319
+ "svm",
320
+ C=(0.01, 10, "log-uniform"),
321
+ gamma=(1e-3, 1e-1, "log-uniform"),
322
+ )
323
+
324
+ print(creator)
325
+
326
+ scores_tuned, model_tuned = run_cross_validation(
327
+ X=X,
328
+ y=y,
329
+ data=df,
330
+ X_types=X_types,
331
+ model=creator,
332
+ return_estimator="all",
333
+ search_params=search_params,
334
+ )
335
+
336
+ print(
337
+ "Scores with best hyperparameter using 10 iterations of "
338
+ f"randomized search: {scores_tuned['test_score'].mean()}"
339
+ )
340
+ pprint(model_tuned.best_params_)
341
+
342
+
343
+ ###############################################################################
344
+ # We can also control the number of cross-validation folds used by the searcher
345
+ # by setting the ``cv`` parameter in the ``search_params`` dictionary. For
346
+ # example, we can use a bayesian search with 3 folds. Fortunately, the
347
+ # :class:`~skopt.BayesSearchCV` searcher also accepts distributions for the
348
+ # hyperparameters.
349
+
350
+ search_params = {
351
+ "kind": "bayes",
352
+ "n_iter": 10,
353
+ "cv": 3,
354
+ }
355
+
356
+ scores_tuned, model_tuned = run_cross_validation(
357
+ X=X,
358
+ y=y,
359
+ data=df,
360
+ X_types=X_types,
361
+ model=creator,
362
+ return_estimator="all",
363
+ search_params=search_params,
364
+ )
365
+
366
+ print(
367
+ "Scores with best hyperparameter using 10 iterations of "
368
+ f"bayesian search and 3-fold CV: {scores_tuned['test_score'].mean()}"
369
+ )
370
+ pprint(model_tuned.best_params_)
371
+
262
372
 
263
373
  ###############################################################################
264
374
  #
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.3.2.dev24'
16
- __version_tuple__ = version_tuple = (0, 3, 2, 'dev24')
15
+ __version__ = version = '0.3.2.dev61'
16
+ __version_tuple__ = version_tuple = (0, 3, 2, 'dev61')
@@ -8,8 +8,12 @@ from typing import Dict, List, Optional, Union
8
8
 
9
9
  import numpy as np
10
10
  import pandas as pd
11
+ import sklearn
11
12
  from sklearn.base import BaseEstimator
12
- from sklearn.model_selection import check_cv, cross_validate
13
+ from sklearn.model_selection import (
14
+ check_cv,
15
+ cross_validate,
16
+ )
13
17
  from sklearn.model_selection._search import BaseSearchCV
14
18
  from sklearn.pipeline import Pipeline
15
19
 
@@ -19,20 +23,21 @@ from .pipeline.merger import merge_pipelines
19
23
  from .prepare import check_consistency, prepare_input_data
20
24
  from .scoring import check_scoring
21
25
  from .utils import _compute_cvmdsum, logger, raise_error
26
+ from .utils.typing import CVLike
22
27
 
23
28
 
24
29
  def run_cross_validation( # noqa: C901
25
30
  X: List[str], # noqa: N803
26
31
  y: str,
27
32
  model: Union[str, PipelineCreator, BaseEstimator, List[PipelineCreator]],
33
+ data: pd.DataFrame,
28
34
  X_types: Optional[Dict] = None, # noqa: N803
29
- data: Optional[pd.DataFrame] = None,
30
35
  problem_type: Optional[str] = None,
31
36
  preprocess: Union[None, str, List[str]] = None,
32
37
  return_estimator: Optional[str] = None,
33
38
  return_inspector: bool = False,
34
39
  return_train_score: bool = False,
35
- cv: Optional[int] = None,
40
+ cv: Optional[CVLike] = None,
36
41
  groups: Optional[str] = None,
37
42
  scoring: Union[str, List[str], None] = None,
38
43
  pos_labels: Union[str, List[str], None] = None,
@@ -54,12 +59,11 @@ def run_cross_validation( # noqa: C901
54
59
  See :ref:`data_usage` for details.
55
60
  model : str or scikit-learn compatible model.
56
61
  If string, it will use one of the available models.
62
+ data : pandas.DataFrame
63
+ DataFrame with the data. See :ref:`data_usage` for details.
57
64
  X_types : dict[str, list of str]
58
65
  A dictionary containing keys with column type as a str and the
59
66
  columns of this column type as a list of str.
60
- data : pandas.DataFrame | None
61
- DataFrame with the data (optional).
62
- See :ref:`data_usage` for details.
63
67
  problem_type : str
64
68
  The kind of problem to model.
65
69
 
@@ -132,8 +136,8 @@ def run_cross_validation( # noqa: C901
132
136
  the following keys:
133
137
 
134
138
  * 'kind': The kind of search algorithm to use, e.g.:
135
- 'grid' or 'random'. Can be any valid julearn searcher name or
136
- scikit-learn compatible searcher.
139
+ 'grid', 'random' or 'bayes'. Can be any valid julearn searcher name
140
+ or scikit-learn compatible searcher.
137
141
  * 'cv': If a searcher is going to be used, the cross-validation
138
142
  splitting strategy to use. Defaults to same CV as for the model
139
143
  evaluation.
@@ -196,7 +200,7 @@ def run_cross_validation( # noqa: C901
196
200
  np.random.seed(seed)
197
201
 
198
202
  # Interpret the input data and prepare it to be used with the library
199
- df_X, y, df_groups, X_types = prepare_input_data(
203
+ df_X, df_y, df_groups, X_types = prepare_input_data(
200
204
  X=X,
201
205
  y=y,
202
206
  df=data,
@@ -267,7 +271,7 @@ def run_cross_validation( # noqa: C901
267
271
 
268
272
  if has_target_transformer:
269
273
  if isinstance(pipeline, BaseSearchCV):
270
- last_step = pipeline.estimator[-1]
274
+ last_step = pipeline.estimator[-1] # type: ignore
271
275
  else:
272
276
  last_step = pipeline[-1]
273
277
  if not last_step.can_inverse_transform():
@@ -313,7 +317,7 @@ def run_cross_validation( # noqa: C901
313
317
  "Cannot use model_params with a model object. Use either "
314
318
  "a string or a PipelineCreator"
315
319
  )
316
- pipeline_creator.add(step=model, **t_params)
320
+ pipeline_creator.add(step=model, **t_params) # type: ignore
317
321
 
318
322
  # Check for extra model_params that are not used
319
323
  unused_params = []
@@ -346,38 +350,52 @@ def run_cross_validation( # noqa: C901
346
350
  logger.info("")
347
351
 
348
352
  if problem_type == "classification":
349
- logger.info(f"\tNumber of classes: {len(np.unique(y))}")
350
- logger.info(f"\tTarget type: {y.dtype}")
351
- logger.info(f"\tClass distributions: {y.value_counts()}")
353
+ logger.info(f"\tNumber of classes: {len(np.unique(df_y))}")
354
+ logger.info(f"\tTarget type: {df_y.dtype}")
355
+ logger.info(f"\tClass distributions: {df_y.value_counts()}")
352
356
  elif problem_type == "regression":
353
- logger.info(f"\tTarget type: {y.dtype}")
357
+ logger.info(f"\tTarget type: {df_y.dtype}")
354
358
 
355
359
  # Prepare cross validation
356
- cv_outer = check_cv(cv, classifier=problem_type == "classification")
360
+ cv_outer = check_cv(
361
+ cv, # type: ignore
362
+ classifier=problem_type == "classification",
363
+ )
357
364
  logger.info(f"Using outer CV scheme {cv_outer}")
358
365
 
359
- check_consistency(y, cv, groups, problem_type)
366
+ check_consistency(df_y, cv, groups, problem_type) # type: ignore
360
367
 
361
368
  cv_return_estimator = return_estimator in ["cv", "all"]
362
- scoring = check_scoring(pipeline, scoring, wrap_score=wrap_score)
369
+ scoring = check_scoring(
370
+ pipeline, # type: ignore
371
+ scoring,
372
+ wrap_score=wrap_score,
373
+ )
363
374
 
364
375
  cv_mdsum = _compute_cvmdsum(cv_outer)
365
376
  fit_params = {}
366
377
  if df_groups is not None:
367
378
  if isinstance(pipeline, BaseSearchCV):
368
379
  fit_params["groups"] = df_groups.values
380
+
381
+ _sklearn_deprec_fit_params = {}
382
+ if sklearn.__version__ >= "1.4.0":
383
+ _sklearn_deprec_fit_params["params"] = fit_params
384
+ else:
385
+ _sklearn_deprec_fit_params["fit_params"] = fit_params
386
+
369
387
  scores = cross_validate(
370
388
  pipeline,
371
389
  df_X,
372
- y,
390
+ df_y,
373
391
  cv=cv_outer,
374
392
  scoring=scoring,
375
393
  groups=df_groups,
376
394
  return_estimator=cv_return_estimator,
377
395
  n_jobs=n_jobs,
378
396
  return_train_score=return_train_score,
379
- verbose=verbose,
380
- fit_params=fit_params,
397
+ verbose=verbose, # type: ignore
398
+ **_sklearn_deprec_fit_params,
381
399
  )
382
400
 
383
401
  n_repeats = getattr(cv_outer, "n_repeats", 1)
@@ -387,7 +405,10 @@ def run_cross_validation( # noqa: C901
387
405
  folds = np.tile(np.arange(n_folds), n_repeats)
388
406
 
389
407
  fold_sizes = np.array(
390
- [list(map(len, x)) for x in cv_outer.split(df_X, y, groups=df_groups)]
408
+ [
409
+ list(map(len, x))
410
+ for x in cv_outer.split(df_X, df_y, groups=df_groups)
411
+ ]
391
412
  )
392
413
  scores["n_train"] = fold_sizes[:, 0]
393
414
  scores["n_test"] = fold_sizes[:, 1]
@@ -398,7 +419,8 @@ def run_cross_validation( # noqa: C901
398
419
  scores_df = pd.DataFrame(scores)
399
420
  out = scores_df
400
421
  if return_estimator in ["final", "all"]:
401
- pipeline.fit(df_X, y, **fit_params)
422
+ logger.info("Fitting final model")
423
+ pipeline.fit(df_X, df_y, **fit_params)
402
424
  out = scores_df, pipeline
403
425
 
404
426
  if return_inspector:
@@ -406,7 +428,7 @@ def run_cross_validation( # noqa: C901
406
428
  scores=scores_df,
407
429
  model=pipeline,
408
430
  X=df_X,
409
- y=y,
431
+ y=df_y,
410
432
  groups=df_groups,
411
433
  cv=cv_outer,
412
434
  )
@@ -13,11 +13,11 @@ from sklearn.utils.metaestimators import available_if
13
13
 
14
14
 
15
15
  try: # sklearn < 1.4.0
16
- from sklearn.utils.validation import _check_fit_params
16
+ from sklearn.utils.validation import _check_fit_params # type: ignore
17
17
 
18
18
  fit_params_checker = _check_fit_params
19
19
  except ImportError: # sklearn >= 1.4.0
20
- from sklearn.utils.validation import _check_method_params
20
+ from sklearn.utils.validation import _check_method_params # type: ignore
21
21
 
22
22
  fit_params_checker = _check_method_params
23
23
 
@@ -180,7 +180,12 @@ class JuTransformer(JuBaseEstimator, TransformerMixin):
180
180
  self.row_select_col_type = row_select_col_type
181
181
  self.row_select_vals = row_select_vals
182
182
 
183
- def fit(self, X, y=None, **fit_params): # noqa: N803
183
+ def fit(
184
+ self,
185
+ X: pd.DataFrame, # noqa: N803
186
+ y: Optional[pd.Series] = None,
187
+ **fit_params,
188
+ ):
184
189
  """Fit the model.
185
190
 
186
191
  This method will fit the model using only the columns selected by
@@ -217,8 +222,21 @@ class JuTransformer(JuBaseEstimator, TransformerMixin):
217
222
  self.row_select_vals = [self.row_select_vals]
218
223
  return self._fit(**self._select_rows(X, y, **fit_params))
219
224
 
225
+ def _fit(
226
+ self,
227
+ X: pd.DataFrame, # noqa: N803,
228
+ y: Optional[pd.Series],
229
+ **kwargs,
230
+ ) -> None:
231
+ raise_error(
232
+ "This method should be implemented in the concrete class",
233
+ klass=NotImplementedError,
234
+ )
235
+
220
236
  def _add_backed_filtered(
221
- self, X: pd.DataFrame, X_trans: pd.DataFrame # noqa: N803
237
+ self,
238
+ X: pd.DataFrame, # noqa: N803
239
+ X_trans: pd.DataFrame, # noqa: N803
222
240
  ) -> pd.DataFrame:
223
241
  """Add the left-out columns back to the transformed data.
224
242
 
@@ -301,7 +319,7 @@ class WrapModel(JuBaseEstimator):
301
319
 
302
320
  def fit(
303
321
  self,
304
- X: pd.DataFrame, # noqa: N803
322
+ X: DataLike, # noqa: N803
305
323
  y: Optional[DataLike] = None,
306
324
  **fit_params: Any,
307
325
  ) -> "WrapModel":
@@ -312,7 +330,7 @@ class WrapModel(JuBaseEstimator):
312
330
 
313
331
  Parameters
314
332
  ----------
315
- X : pd.DataFrame
333
+ X : DataLike
316
334
  The data to fit the model on.
317
335
  y : DataLike, optional
318
336
  The target data (default is None).
@@ -329,9 +347,9 @@ class WrapModel(JuBaseEstimator):
329
347
  if self.needed_types is not None:
330
348
  self.needed_types = ensure_column_types(self.needed_types)
331
349
 
332
- Xt = self.filter_columns(X)
350
+ Xt = self.filter_columns(X) # type: ignore
333
351
  self.model_ = self.model
334
- self.model_.fit(Xt, y, **fit_params)
352
+ self.model_.fit(Xt, y, **fit_params) # type: ignore
335
353
  return self
336
354
 
337
355
  def predict(self, X: pd.DataFrame) -> DataLike: # noqa: N803
@@ -110,7 +110,7 @@ def test_WrapModel(
110
110
 
111
111
  np.random.seed(42)
112
112
  lr = model()
113
- lr.fit(X_iris_selected, y_iris)
113
+ lr.fit(X_iris_selected, y_iris) # type: ignore
114
114
  pred_sk = lr.predict(X_iris_selected)
115
115
 
116
116
  np.random.seed(42)