julearn 0.3.2.dev57__tar.gz → 0.3.2.dev78__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243) hide show
  1. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/PKG-INFO +7 -2
  2. julearn-0.3.2.dev78/docs/changes/newsfragments/262.doc +1 -0
  3. julearn-0.3.2.dev78/docs/changes/newsfragments/262.enh +1 -0
  4. julearn-0.3.2.dev78/docs/changes/newsfragments/262.feature +1 -0
  5. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/conf.py +5 -0
  6. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/links.inc +1 -0
  7. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/run_hyperparameters_docs.py +143 -2
  8. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/_version.py +2 -2
  9. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/api.py +31 -9
  10. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/base/estimators.py +26 -8
  11. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/base/tests/test_base_estimators.py +1 -1
  12. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/conftest.py +24 -1
  13. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/inspect/_cv.py +16 -10
  14. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/inspect/_preprocess.py +1 -1
  15. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/inspect/tests/test_cv.py +4 -2
  16. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/inspect/tests/test_inspector.py +12 -9
  17. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/inspect/tests/test_pipeline.py +27 -14
  18. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/inspect/tests/test_preprocess.py +8 -1
  19. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/model_selection/__init__.py +3 -0
  20. julearn-0.3.2.dev78/julearn/model_selection/_optuna_searcher.py +107 -0
  21. julearn-0.3.2.dev78/julearn/model_selection/_skopt_searcher.py +95 -0
  22. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/model_selection/available_searchers.py +3 -3
  23. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/model_selection/stratified_bootstrap.py +7 -5
  24. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/model_selection/tests/test_available_searchers.py +53 -21
  25. julearn-0.3.2.dev78/julearn/model_selection/tests/test_optuna_searcher.py +165 -0
  26. julearn-0.3.2.dev78/julearn/model_selection/tests/test_skopt_searcher.py +135 -0
  27. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/models/dynamic.py +1 -1
  28. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/models/tests/test_models.py +6 -6
  29. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/pipeline/pipeline_creator.py +65 -22
  30. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/pipeline/target_pipeline.py +1 -1
  31. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/pipeline/tests/test_merger.py +16 -13
  32. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/pipeline/tests/test_pipeline_creator.py +44 -3
  33. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/prepare.py +2 -4
  34. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/scoring/available_scorers.py +12 -6
  35. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/stats/corrected_ttest.py +8 -4
  36. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/stats/tests/test_corrected_ttest.py +22 -5
  37. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/tests/test_api.py +103 -45
  38. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/confound_remover.py +2 -2
  39. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/dataframe/tests/test_drop_columns.py +1 -1
  40. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/dataframe/tests/test_filter_columns.py +1 -1
  41. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/dataframe/tests/test_set_column_types.py +10 -6
  42. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/ju_column_transformer.py +1 -1
  43. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/target/ju_transformed_target_model.py +1 -1
  44. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/target/target_confound_remover.py +1 -1
  45. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/target/tests/test_ju_transformed_target_model.py +3 -1
  46. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/tests/test_cbpm.py +3 -3
  47. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/tests/test_confounds.py +17 -6
  48. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/tests/test_jucolumntransformers.py +8 -6
  49. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/utils/checks.py +3 -1
  50. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/utils/logging.py +1 -1
  51. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/utils/testing.py +21 -11
  52. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/utils/typing.py +39 -19
  53. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/viz/_scores.py +5 -4
  54. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn.egg-info/PKG-INFO +7 -2
  55. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn.egg-info/SOURCES.txt +6 -0
  56. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn.egg-info/requires.txt +7 -1
  57. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/pyproject.toml +22 -2
  58. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/tox.ini +6 -0
  59. julearn-0.3.2.dev57/julearn/model_selection/_skopt_searcher.py +0 -32
  60. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  61. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  62. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.github/ISSUE_TEMPLATE/documentation_request.yaml +0 -0
  63. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  64. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.github/workflows/check-stale.yml +0 -0
  65. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.github/workflows/ci-docs.yml +0 -0
  66. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.github/workflows/ci.yml +0 -0
  67. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.github/workflows/docs-preview.yml +0 -0
  68. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.github/workflows/docs.yml +0 -0
  69. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.github/workflows/lint.yml +0 -0
  70. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.github/workflows/pypi.yml +0 -0
  71. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.gitignore +0 -0
  72. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/.pre-commit-config.yaml +0 -0
  73. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/AUTHORS.rst +0 -0
  74. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/LICENSE.md +0 -0
  75. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/README.md +0 -0
  76. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/codecov.yml +0 -0
  77. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/Makefile +0 -0
  78. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/_static/css/custom.css +0 -0
  79. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/_static/js/custom.js +0 -0
  80. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/_templates/class.rst +0 -0
  81. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/_templates/function.rst +0 -0
  82. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/_templates/function_warning.rst +0 -0
  83. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/_templates/versions.html +0 -0
  84. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/base.rst +0 -0
  85. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/index.rst +0 -0
  86. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/inspect.rst +0 -0
  87. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/main.rst +0 -0
  88. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/model_selection.rst +0 -0
  89. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/models.rst +0 -0
  90. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/pipeline.rst +0 -0
  91. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/prepare.rst +0 -0
  92. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/scoring.rst +0 -0
  93. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/stats.rst +0 -0
  94. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/transformers.rst +0 -0
  95. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/utils.rst +0 -0
  96. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/api/viz.rst +0 -0
  97. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/available_pipeline_steps.rst +0 -0
  98. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/changes/contributors.inc +0 -0
  99. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/changes/newsfragments/.gitignore +0 -0
  100. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/changes/newsfragments/224.misc +0 -0
  101. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/changes/newsfragments/244.misc +0 -0
  102. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/changes/newsfragments/249.bugfix +0 -0
  103. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/changes/newsfragments/251.misc +0 -0
  104. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/changes/newsfragments/255.bugfix +0 -0
  105. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/changes/newsfragments/260.enh +0 -0
  106. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/changes/newsfragments/260.misc +0 -0
  107. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/configuration.rst +0 -0
  108. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/contributing.rst +0 -0
  109. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/examples.rst +0 -0
  110. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/faq.rst +0 -0
  111. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/getting_started.rst +0 -0
  112. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/corrected_ttest.png +0 -0
  113. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/final_estimator.png +0 -0
  114. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/iris_X.png +0 -0
  115. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/iris_df.png +0 -0
  116. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/iris_y.png +0 -0
  117. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/julearn_logo.png +0 -0
  118. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/julearn_logo_calm.png +0 -0
  119. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/julearn_logo_confbias.png +0 -0
  120. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/julearn_logo_cv.png +0 -0
  121. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/julearn_logo_generalization.png +0 -0
  122. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/julearn_logo_it.png +0 -0
  123. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/julearn_logo_ml.png +0 -0
  124. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/julearn_logo_mlit.png +0 -0
  125. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/multiple_scorers_run_cv.png +0 -0
  126. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/plot_scores.png +0 -0
  127. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/scores_run_cv.png +0 -0
  128. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/scores_run_cv_splitter.png +0 -0
  129. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/images/scores_run_cv_train.png +0 -0
  130. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/index.rst +0 -0
  131. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/maintaining.rst +0 -0
  132. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/redirect.html +0 -0
  133. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/selected_deeper_topics/CBPM.rst +0 -0
  134. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/selected_deeper_topics/confound_removal.rst +0 -0
  135. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/selected_deeper_topics/cross_validation_splitter.rst +0 -0
  136. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/selected_deeper_topics/hyperparameter_tuning.rst +0 -0
  137. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/selected_deeper_topics/index.rst +0 -0
  138. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/selected_deeper_topics/model_inspect.rst +0 -0
  139. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/selected_deeper_topics/stacked_models.rst +0 -0
  140. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/selected_deeper_topics/target_transformers.rst +0 -0
  141. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/sphinxext/gh_substitutions.py +0 -0
  142. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/what_really_need_know/cross_validation.rst +0 -0
  143. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/what_really_need_know/data.rst +0 -0
  144. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/what_really_need_know/index.rst +0 -0
  145. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/what_really_need_know/model_comparison.rst +0 -0
  146. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/what_really_need_know/model_evaluation.rst +0 -0
  147. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/what_really_need_know/pipeline.rst +0 -0
  148. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/docs/whats_new.rst +0 -0
  149. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/00_starting/README.rst +0 -0
  150. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/00_starting/plot_cm_acc_multiclass.py +0 -0
  151. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/00_starting/plot_example_regression.py +0 -0
  152. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/00_starting/plot_stratified_kfold_reg.py +0 -0
  153. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/00_starting/run_combine_pandas.py +0 -0
  154. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/00_starting/run_grouped_cv.py +0 -0
  155. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/00_starting/run_simple_binary_classification.py +0 -0
  156. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/01_model_comparison/README.rst +0 -0
  157. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/01_model_comparison/plot_simple_model_comparison.py +0 -0
  158. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/02_inspection/README.rst +0 -0
  159. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/02_inspection/plot_groupcv_inspect_svm.py +0 -0
  160. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/02_inspection/plot_inspect_random_forest.py +0 -0
  161. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/02_inspection/plot_preprocess.py +0 -0
  162. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/02_inspection/run_binary_inspect_folds.py +0 -0
  163. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/03_complex_models/README.rst +0 -0
  164. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/03_complex_models/run_apply_to_target.py +0 -0
  165. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/03_complex_models/run_example_pca_featsets.py +0 -0
  166. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/03_complex_models/run_hyperparameter_multiple_grids.py +0 -0
  167. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/03_complex_models/run_hyperparameter_tuning.py +0 -0
  168. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py +0 -0
  169. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/03_complex_models/run_stacked_models.py +0 -0
  170. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/04_confounds/README.rst +0 -0
  171. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/04_confounds/plot_confound_removal_classification.py +0 -0
  172. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/04_confounds/run_return_confounds.py +0 -0
  173. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/05_customization/README.rst +0 -0
  174. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/05_customization/run_custom_scorers_regression.py +0 -0
  175. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/README.rst +0 -0
  176. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/run_cbpm_docs.py +0 -0
  177. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/run_confound_removal_docs.py +0 -0
  178. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/run_cv_splitters_docs.py +0 -0
  179. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/run_data_docs.py +0 -0
  180. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/run_model_comparison_docs.py +0 -0
  181. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/run_model_evaluation_docs.py +0 -0
  182. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/run_model_inspection_docs.py +0 -0
  183. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/run_pipeline_docs.py +0 -0
  184. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/run_stacked_models_docs.py +0 -0
  185. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/99_docs/run_target_transformer_docs.py +0 -0
  186. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/README.rst +0 -0
  187. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/XX_disabled/dis_run_n_jobs.py +0 -0
  188. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/examples/XX_disabled/dis_run_target_confound_removal.py +0 -0
  189. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/ignore_words.txt +0 -0
  190. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/__init__.py +0 -0
  191. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/base/__init__.py +0 -0
  192. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/base/column_types.py +0 -0
  193. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/base/tests/test_column_types.py +0 -0
  194. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/config.py +0 -0
  195. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/inspect/__init__.py +0 -0
  196. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/inspect/_pipeline.py +0 -0
  197. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/inspect/inspector.py +0 -0
  198. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/model_selection/continuous_stratified_kfold.py +0 -0
  199. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/model_selection/tests/test_continous_stratified_kfold.py +0 -0
  200. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/model_selection/tests/test_stratified_bootstrap.py +0 -0
  201. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/models/__init__.py +0 -0
  202. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/models/available_models.py +0 -0
  203. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/models/tests/test_available_models.py +0 -0
  204. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/models/tests/test_dynamic.py +0 -0
  205. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/pipeline/__init__.py +0 -0
  206. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/pipeline/merger.py +0 -0
  207. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/pipeline/target_pipeline_creator.py +0 -0
  208. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/pipeline/tests/test_target_pipeline.py +0 -0
  209. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/pipeline/tests/test_target_pipeline_creator.py +0 -0
  210. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/scoring/__init__.py +0 -0
  211. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/scoring/metrics.py +0 -0
  212. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/scoring/tests/test_available_scorers.py +0 -0
  213. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/scoring/tests/test_metrics.py +0 -0
  214. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/stats/__init__.py +0 -0
  215. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/tests/test_config.py +0 -0
  216. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/tests/test_prepare.py +0 -0
  217. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/__init__.py +0 -0
  218. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/available_transformers.py +0 -0
  219. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/cbpm.py +0 -0
  220. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/dataframe/__init__.py +0 -0
  221. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/dataframe/change_column_types.py +0 -0
  222. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/dataframe/drop_columns.py +0 -0
  223. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/dataframe/filter_columns.py +0 -0
  224. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/dataframe/set_column_types.py +0 -0
  225. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/dataframe/tests/test_change_column_types.py +0 -0
  226. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/target/__init__.py +0 -0
  227. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/target/available_target_transformers.py +0 -0
  228. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/target/ju_target_transformer.py +0 -0
  229. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/target/tests/test_available_target_transformers.py +0 -0
  230. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/target/tests/test_ju_target_transformer.py +0 -0
  231. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/target/tests/test_target_confound_remover.py +0 -0
  232. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/transformers/tests/test_available_transformers.py +0 -0
  233. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/utils/__init__.py +0 -0
  234. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/utils/_cv.py +0 -0
  235. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/utils/tests/test_logging.py +0 -0
  236. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/utils/tests/test_version.py +0 -0
  237. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/utils/versions.py +0 -0
  238. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/viz/__init__.py +0 -0
  239. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn/viz/res/julearn_logo_generalization.png +0 -0
  240. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn.egg-info/dependency_links.txt +0 -0
  241. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/julearn.egg-info/top_level.txt +0 -0
  242. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/setup.cfg +0 -0
  243. {julearn-0.3.2.dev57 → julearn-0.3.2.dev78}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: julearn
3
- Version: 0.3.2.dev57
3
+ Version: 0.3.2.dev78
4
4
  Summary: Juelich Machine Learning Library
5
5
  Author-email: Fede Raimondo <f.raimondo@fz-juelich.de>, Sami Hamdan <s.hamdan@fz-juelich.de>
6
6
  Maintainer-email: Sami Hamdan <s.hamdan@fz-juelich.de>
@@ -41,6 +41,8 @@ Requires-Dist: sphinx_copybutton<0.6,>=0.5.0; extra == "docs"
41
41
  Requires-Dist: numpydoc<1.6,>=1.5.0; extra == "docs"
42
42
  Requires-Dist: towncrier<24; extra == "docs"
43
43
  Requires-Dist: scikit-optimize<0.11,>=0.10.0; extra == "docs"
44
+ Requires-Dist: optuna<3.7,>=3.6.0; extra == "docs"
45
+ Requires-Dist: optuna_integration<3.7,>=3.6.0; extra == "docs"
44
46
  Provides-Extra: deslib
45
47
  Requires-Dist: deslib<0.4,>=0.3.5; extra == "deslib"
46
48
  Provides-Extra: viz
@@ -49,8 +51,11 @@ Requires-Dist: bokeh>=3.0.0; extra == "viz"
49
51
  Requires-Dist: param>=2.0.0; extra == "viz"
50
52
  Provides-Extra: skopt
51
53
  Requires-Dist: scikit-optimize<0.11,>=0.10.0; extra == "skopt"
54
+ Provides-Extra: optuna
55
+ Requires-Dist: optuna<3.7,>=3.6.0; extra == "optuna"
56
+ Requires-Dist: optuna_integration<3.7,>=3.6.0; extra == "optuna"
52
57
  Provides-Extra: all
53
- Requires-Dist: julearn[skopt,viz]; extra == "all"
58
+ Requires-Dist: julearn[optuna,skopt,viz]; extra == "all"
54
59
 
55
60
  # julearn
56
61
 
@@ -0,0 +1 @@
1
+ Update documentation on Hyperparameter Tuning by `Fede Raimondo`_
@@ -0,0 +1 @@
1
+ Refactor how hyperparameters' distributions are specified by `Fede Raimondo`_
@@ -0,0 +1 @@
1
+ Add :class:`~optuna_integration.sklearn.OptunaSearchCV` to the list of available searchers as ``optuna`` by `Fede Raimondo`_
@@ -161,6 +161,11 @@ intersphinx_mapping = {
161
161
  "joblib": ("https://joblib.readthedocs.io/en/latest/", None),
162
162
  "scipy": ("https://docs.scipy.org/doc/scipy/", None),
163
163
  "skopt": ("https://scikit-optimize.readthedocs.io/en/latest", None),
164
+ "optuna": ("https://optuna.readthedocs.io/en/stable", None),
165
+ "optuna_integration": (
166
+ "https://optuna-integration.readthedocs.io/en/stable",
167
+ None,
168
+ ),
164
169
  }
165
170
 
166
171
 
@@ -41,3 +41,4 @@
41
41
 
42
42
  .. _`DESlib`: https://github.com/scikit-learn-contrib/DESlib
43
43
  .. _`scikit-optimize`: https://scikit-optimize.readthedocs.io/en/stable/
44
+ .. _`Optuna`: https://optuna.org
@@ -253,8 +253,9 @@ pprint(model_tuned.best_params_)
253
253
  # hyperparameters values.
254
254
  #
255
255
  # Other searchers that ``julearn`` provides are the
256
- # :class:`~sklearn.model_selection.RandomizedSearchCV` and
257
- # :class:`~skopt.BayesSearchCV`.
256
+ # :class:`~sklearn.model_selection.RandomizedSearchCV`,
257
+ # :class:`~skopt.BayesSearchCV` and
258
+ # :class:`~optuna_integration.sklearn.OptunaSearchCV`.
258
259
  #
259
260
  # The randomized searcher
260
261
  # (:class:`~sklearn.model_selection.RandomizedSearchCV`) is similar to the
@@ -274,6 +275,12 @@ pprint(model_tuned.best_params_)
274
275
  # :class:`~skopt.BayesSearchCV` documentation, including how to specify
275
276
  # the prior distributions of the hyperparameters.
276
277
  #
278
+ # The Optuna searcher (:class:`~optuna_integration.sklearn.OptunaSearchCV`)
279
+ # uses the Optuna library to find the best hyperparameter set. Optuna is a
280
+ # hyperparameter optimization framework that has several algorithms to find
281
+ # the best hyperparameter set. For more information, see the
282
+ # `Optuna`_ documentation.
283
+ #
277
284
  # We can specify the kind of searcher and its parametrization, by setting the
278
285
  # ``search_params`` parameter in the :func:`.run_cross_validation` function.
279
286
  # For example, we can use the
@@ -369,6 +376,140 @@ print(
369
376
  )
370
377
  pprint(model_tuned.best_params_)
371
378
 
379
+ ###############################################################################
380
+ # An example using the optuna searcher is shown below. The searcher is specified
381
+ # as ``"optuna"`` and the hyperparameters are specified as a dictionary with
382
+ # the hyperparameters to tune and their distributions as for the bayesian
383
+ # searcher. However, the optuna searcher behaviour is controlled by a
384
+ # :class:`~optuna.study.Study` object. This object can be passed to the
385
+ # searcher using the ``study`` parameter in the ``search_params`` dictionary.
386
+ #
387
+ # .. important::
388
+ # The optuna searcher requires that all the hyperparameters are specified
389
+ # as distributions, even the categorical ones.
390
+ #
391
+ # We first modify the pipeline creator so the ``select_k`` parameter is
392
+ # specified as a distribution. We exemplarily use a categorical distribution
393
+ # for the ``class_weight`` hyperparameter, trying the ``"balanced"`` and
394
+ # ``None`` values.
395
+
396
+ creator = PipelineCreator(problem_type="classification")
397
+ creator.add("zscore")
398
+ creator.add("select_k", k=(2, 4, "uniform"))
399
+ creator.add(
400
+ "svm",
401
+ C=(0.01, 10, "log-uniform"),
402
+ gamma=(1e-3, 1e-1, "log-uniform"),
403
+ class_weight=("balanced", None, "categorical")
404
+ )
405
+ print(creator)
406
+
407
+ ###############################################################################
408
+ # We can now use the optuna searcher with 10 trials and 3-fold cross-validation.
409
+
410
+ import optuna
411
+
412
+ study = optuna.create_study(
413
+ direction="maximize",
414
+ study_name="optuna-concept",
415
+ load_if_exists=True,
416
+ )
417
+
418
+ search_params = {
419
+ "kind": "optuna",
420
+ "study": study,
421
+ "cv": 3,
422
+ }
423
+ scores_tuned, model_tuned = run_cross_validation(
424
+ X=X,
425
+ y=y,
426
+ data=df,
427
+ X_types=X_types,
428
+ model=creator,
429
+ return_estimator="all",
430
+ search_params=search_params,
431
+ )
432
+
433
+ print(
434
+ "Scores with best hyperparameter using 10 iterations of "
435
+ f"optuna and 3-fold CV: {scores_tuned['test_score'].mean()}"
436
+ )
437
+ pprint(model_tuned.best_params_)
438
+
439
+ ###############################################################################
440
+ #
441
+ # Specifying distributions
442
+ # ~~~~~~~~~~~~~~~~~~~~~~~~
443
+ #
444
+ # The hyperparameters can be specified as distributions for the randomized
445
+ # searcher, bayesian searcher and optuna searcher. The distributions are
446
+ # either specified using a toolbox-specific method or a tuple convention with the
447
+ # following format: ``(low, high, distribution)`` where the distribution can
448
+ # be either ``"log-uniform"`` or ``"uniform"`` or
449
+ # ``(a, b, c, d, ..., "categorical")`` where ``a``, ``b``, ``c``, ``d``, etc.
450
+ # are the possible categorical values for the hyperparameter.
451
+ #
452
+ # For example, we can specify the ``C`` and ``gamma`` hyperparameters of the
453
+ # :class:`~sklearn.svm.SVC` as log-uniform distributions, while keeping
454
+ # the ``with_mean`` parameter of the
455
+ # :class:`~sklearn.preprocessing.StandardScaler` as a categorical parameter
456
+ # with two options.
457
+
458
+
459
+ creator = PipelineCreator(problem_type="classification")
460
+ creator.add("zscore", with_mean=(True, False, "categorical"))
461
+ creator.add(
462
+ "svm",
463
+ C=(0.01, 10, "log-uniform"),
464
+ gamma=(1e-3, 1e-1, "log-uniform"),
465
+ )
466
+ print(creator)
467
+
468
+ ###############################################################################
469
+ # While this will work for any of the ``random``, ``bayes`` or ``optuna``
470
+ # searcher options, it is important to note that both ``bayes`` and ``optuna``
471
+ # searchers accept further parameters to specify distributions. For example,
472
+ # the ``bayes`` searcher distributions are defined using the
473
+ # :class:`~skopt.space.space.Categorical`, :class:`~skopt.space.space.Integer`
474
+ # and :class:`~skopt.space.space.Real`.
475
+ #
476
+ # For example, we can define a log-uniform distribution with base 2 for the
477
+ # ``C`` hyperparameter of the :class:`~sklearn.svm.SVC` model:
478
+ from skopt.space import Real
479
+ creator = PipelineCreator(problem_type="classification")
480
+ creator.add("zscore", with_mean=(True, False, "categorical"))
481
+ creator.add(
482
+ "svm",
483
+ C=Real(0.01, 10, prior="log-uniform", base=2),
484
+ gamma=(1e-3, 1e-1, "log-uniform"),
485
+ )
486
+ print(creator)
487
+
488
+ ###############################################################################
489
+ # For the optuna searcher, the distributions are defined using the
490
+ # :class:`~optuna.distributions.CategoricalDistribution`,
491
+ # :class:`~optuna.distributions.FloatDistribution` and
492
+ # :class:`~optuna.distributions.IntDistribution`.
493
+ #
494
+ #
495
+ # For example, we can define a uniform distribution from 0.5 to 0.9 with a 0.05
496
+ # step for the ``n_components`` of a :class:`~sklearn.decomposition.PCA`
497
+ # transformer, while keeping a log-uniform distribution for the ``C`` and
498
+ # ``gamma`` hyperparameters of the :class:`~sklearn.svm.SVC` model.
499
+ from optuna.distributions import FloatDistribution
500
+ creator = PipelineCreator(problem_type="classification")
501
+ creator.add("zscore")
502
+ creator.add(
503
+ "pca",
504
+ n_components=FloatDistribution(0.5, 0.9, step=0.05),
505
+ )
506
+ creator.add(
507
+ "svm",
508
+ C=FloatDistribution(0.01, 10, log=True),
509
+ gamma=(1e-3, 1e-1, "log-uniform"),
510
+ )
511
+ print(creator)
512
+
372
513
 
373
514
  ###############################################################################
374
515
  #
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.3.2.dev57'
16
- __version_tuple__ = version_tuple = (0, 3, 2, 'dev57')
15
+ __version__ = version = '0.3.2.dev78'
16
+ __version_tuple__ = version_tuple = (0, 3, 2, 'dev78')
@@ -4,13 +4,13 @@
4
4
  # Sami Hamdan <s.hamdan@fz-juelich.de>
5
5
  # License: AGPL
6
6
 
7
- from typing import Dict, Iterable, List, Optional, Union
7
+ from typing import Dict, List, Optional, Union
8
8
 
9
9
  import numpy as np
10
10
  import pandas as pd
11
+ import sklearn
11
12
  from sklearn.base import BaseEstimator
12
13
  from sklearn.model_selection import (
13
- BaseCrossValidator,
14
14
  check_cv,
15
15
  cross_validate,
16
16
  )
@@ -23,6 +23,7 @@ from .pipeline.merger import merge_pipelines
23
23
  from .prepare import check_consistency, prepare_input_data
24
24
  from .scoring import check_scoring
25
25
  from .utils import _compute_cvmdsum, logger, raise_error
26
+ from .utils.typing import CVLike
26
27
 
27
28
 
28
29
  def run_cross_validation( # noqa: C901
@@ -36,7 +37,7 @@ def run_cross_validation( # noqa: C901
36
37
  return_estimator: Optional[str] = None,
37
38
  return_inspector: bool = False,
38
39
  return_train_score: bool = False,
39
- cv: Optional[Union[int, BaseCrossValidator, Iterable]] = None,
40
+ cv: Optional[CVLike] = None,
40
41
  groups: Optional[str] = None,
41
42
  scoring: Union[str, List[str], None] = None,
42
43
  pos_labels: Union[str, List[str], None] = None,
@@ -134,9 +135,18 @@ def run_cross_validation( # noqa: C901
134
135
  Additional parameters in case Hyperparameter Tuning is performed, with
135
136
  the following keys:
136
137
 
137
- * 'kind': The kind of search algorithm to use, e.g.:
138
- 'grid', 'random' or 'bayes'. Can be any valid julearn searcher name
139
- or scikit-learn compatible searcher.
138
+ * 'kind': The kind of search algorithm to use, Valid options are:
139
+
140
+ * ``"grid"`` : :class:`~sklearn.model_selection.GridSearchCV`
141
+ * ``"random"`` :
142
+ :class:`~sklearn.model_selection.RandomizedSearchCV`
143
+ * ``"bayes"`` : :class:`~skopt.BayesSearchCV`
144
+ * ``"optuna"`` :
145
+ :class:`~optuna_integration.sklearn.OptunaSearchCV`
146
+ * user-registered searcher name : see
147
+ :func:`~julearn.model_selection.register_searcher`
148
+ * ``scikit-learn``-compatible searcher
149
+
140
150
  * 'cv': If a searcher is going to be used, the cross-validation
141
151
  splitting strategy to use. Defaults to same CV as for the model
142
152
  evaluation.
@@ -357,20 +367,32 @@ def run_cross_validation( # noqa: C901
357
367
 
358
368
  # Prepare cross validation
359
369
  cv_outer = check_cv(
360
- cv, classifier=problem_type == "classification" # type: ignore
370
+ cv, # type: ignore
371
+ classifier=problem_type == "classification",
361
372
  )
362
373
  logger.info(f"Using outer CV scheme {cv_outer}")
363
374
 
364
375
  check_consistency(df_y, cv, groups, problem_type) # type: ignore
365
376
 
366
377
  cv_return_estimator = return_estimator in ["cv", "all"]
367
- scoring = check_scoring(pipeline, scoring, wrap_score=wrap_score)
378
+ scoring = check_scoring(
379
+ pipeline, # type: ignore
380
+ scoring,
381
+ wrap_score=wrap_score,
382
+ )
368
383
 
369
384
  cv_mdsum = _compute_cvmdsum(cv_outer)
370
385
  fit_params = {}
371
386
  if df_groups is not None:
372
387
  if isinstance(pipeline, BaseSearchCV):
373
388
  fit_params["groups"] = df_groups.values
389
+
390
+ _sklearn_deprec_fit_params = {}
391
+ if sklearn.__version__ >= "1.4.0":
392
+ _sklearn_deprec_fit_params["params"] = fit_params
393
+ else:
394
+ _sklearn_deprec_fit_params["fit_params"] = fit_params
395
+
374
396
  scores = cross_validate(
375
397
  pipeline,
376
398
  df_X,
@@ -382,7 +404,7 @@ def run_cross_validation( # noqa: C901
382
404
  n_jobs=n_jobs,
383
405
  return_train_score=return_train_score,
384
406
  verbose=verbose, # type: ignore
385
- fit_params=fit_params,
407
+ **_sklearn_deprec_fit_params,
386
408
  )
387
409
 
388
410
  n_repeats = getattr(cv_outer, "n_repeats", 1)
@@ -13,11 +13,11 @@ from sklearn.utils.metaestimators import available_if
13
13
 
14
14
 
15
15
  try: # sklearn < 1.4.0
16
- from sklearn.utils.validation import _check_fit_params
16
+ from sklearn.utils.validation import _check_fit_params # type: ignore
17
17
 
18
18
  fit_params_checker = _check_fit_params
19
19
  except ImportError: # sklearn >= 1.4.0
20
- from sklearn.utils.validation import _check_method_params
20
+ from sklearn.utils.validation import _check_method_params # type: ignore
21
21
 
22
22
  fit_params_checker = _check_method_params
23
23
 
@@ -180,7 +180,12 @@ class JuTransformer(JuBaseEstimator, TransformerMixin):
180
180
  self.row_select_col_type = row_select_col_type
181
181
  self.row_select_vals = row_select_vals
182
182
 
183
- def fit(self, X, y=None, **fit_params): # noqa: N803
183
+ def fit(
184
+ self,
185
+ X: pd.DataFrame, # noqa: N803
186
+ y: Optional[pd.Series] = None,
187
+ **fit_params,
188
+ ):
184
189
  """Fit the model.
185
190
 
186
191
  This method will fit the model using only the columns selected by
@@ -217,8 +222,21 @@ class JuTransformer(JuBaseEstimator, TransformerMixin):
217
222
  self.row_select_vals = [self.row_select_vals]
218
223
  return self._fit(**self._select_rows(X, y, **fit_params))
219
224
 
225
+ def _fit(
226
+ self,
227
+ X: pd.DataFrame, # noqa: N803,
228
+ y: Optional[pd.Series],
229
+ **kwargs,
230
+ ) -> None:
231
+ raise_error(
232
+ "This method should be implemented in the concrete class",
233
+ klass=NotImplementedError,
234
+ )
235
+
220
236
  def _add_backed_filtered(
221
- self, X: pd.DataFrame, X_trans: pd.DataFrame # noqa: N803
237
+ self,
238
+ X: pd.DataFrame, # noqa: N803
239
+ X_trans: pd.DataFrame, # noqa: N803
222
240
  ) -> pd.DataFrame:
223
241
  """Add the left-out columns back to the transformed data.
224
242
 
@@ -301,7 +319,7 @@ class WrapModel(JuBaseEstimator):
301
319
 
302
320
  def fit(
303
321
  self,
304
- X: pd.DataFrame, # noqa: N803
322
+ X: DataLike, # noqa: N803
305
323
  y: Optional[DataLike] = None,
306
324
  **fit_params: Any,
307
325
  ) -> "WrapModel":
@@ -312,7 +330,7 @@ class WrapModel(JuBaseEstimator):
312
330
 
313
331
  Parameters
314
332
  ----------
315
- X : pd.DataFrame
333
+ X : DataLike
316
334
  The data to fit the model on.
317
335
  y : DataLike, optional
318
336
  The target data (default is None).
@@ -329,9 +347,9 @@ class WrapModel(JuBaseEstimator):
329
347
  if self.needed_types is not None:
330
348
  self.needed_types = ensure_column_types(self.needed_types)
331
349
 
332
- Xt = self.filter_columns(X)
350
+ Xt = self.filter_columns(X) # type: ignore
333
351
  self.model_ = self.model
334
- self.model_.fit(Xt, y, **fit_params)
352
+ self.model_.fit(Xt, y, **fit_params) # type: ignore
335
353
  return self
336
354
 
337
355
  def predict(self, X: pd.DataFrame) -> DataLike: # noqa: N803
@@ -110,7 +110,7 @@ def test_WrapModel(
110
110
 
111
111
  np.random.seed(42)
112
112
  lr = model()
113
- lr.fit(X_iris_selected, y_iris)
113
+ lr.fit(X_iris_selected, y_iris) # type: ignore
114
114
  pred_sk = lr.predict(X_iris_selected)
115
115
 
116
116
  np.random.seed(42)
@@ -270,7 +270,7 @@ def search_params(request: FixtureRequest) -> Optional[Dict]:
270
270
  scope="function",
271
271
  )
272
272
  def bayes_search_params(request: FixtureRequest) -> Optional[Dict]:
273
- """Return different search_params argument for BayesSearchCV.
273
+ """Return different search_params argument for BayesSearchCV.
274
274
 
275
275
  Parameters
276
276
  ----------
@@ -286,6 +286,29 @@ def bayes_search_params(request: FixtureRequest) -> Optional[Dict]:
286
286
 
287
287
  return request.param
288
288
 
289
+ @fixture(
290
+ params=[
291
+ {"kind": "optuna", "n_trials": 10, "cv": 3},
292
+ {"kind": "optuna", "timeout": 20},
293
+ ],
294
+ scope="function",
295
+ )
296
+ def optuna_search_params(request: FixtureRequest) -> Optional[Dict]:
297
+ """Return different search_params argument for OptunaSearchCV.
298
+
299
+ Parameters
300
+ ----------
301
+ request : pytest.FixtureRequest
302
+ The request object.
303
+
304
+ Returns
305
+ -------
306
+ dict or None
307
+ A dictionary with the search_params argument.
308
+
309
+ """
310
+
311
+ return request.param
289
312
 
290
313
  _tuning_params = {
291
314
  "zscore": {"with_mean": [True, False]},
@@ -4,13 +4,14 @@
4
4
  # Sami Hamdan <s.hamdan@fz-juelich.de>
5
5
  # License: AGPL
6
6
 
7
- from typing import List, Optional, Union
7
+ from typing import Optional, Union
8
8
 
9
9
  import pandas as pd
10
10
  from sklearn.model_selection import BaseCrossValidator, check_cv
11
11
  from sklearn.utils.metaestimators import available_if
12
12
 
13
13
  from ..utils import _compute_cvmdsum, is_nonoverlapping_cv, raise_error
14
+ from ..utils.typing import DataLike
14
15
  from ._pipeline import PipelineInspector
15
16
 
16
17
 
@@ -60,14 +61,13 @@ class FoldsInspector:
60
61
  def __init__(
61
62
  self,
62
63
  scores: pd.DataFrame,
63
- cv: BaseCrossValidator,
64
- X: Union[str, List[str]], # noqa: N803
65
- y: str,
64
+ cv: Union[BaseCrossValidator, int],
65
+ X: DataLike, # noqa: N803
66
+ y: pd.Series,
66
67
  func: str = "predict",
67
- groups: Optional[str] = None,
68
+ groups: Optional[pd.Series] = None,
68
69
  ):
69
70
  self._scores = scores
70
- self._cv = cv
71
71
  self._X = X
72
72
  self._y = y
73
73
  self._func = func
@@ -92,7 +92,7 @@ class FoldsInspector:
92
92
  )
93
93
 
94
94
  cv = check_cv(cv)
95
-
95
+ self._cv = cv
96
96
  t_cv_mdsum = _compute_cvmdsum(cv)
97
97
  if t_cv_mdsum != cv_mdsums[0]:
98
98
  raise_error(
@@ -120,10 +120,16 @@ class FoldsInspector:
120
120
 
121
121
  predictions = []
122
122
  for i_fold, (_, test) in enumerate(
123
- self._cv.split(self._X, self._y, groups=self._groups)
123
+ self._cv.split(
124
+ self._X, # type: ignore
125
+ self._y,
126
+ groups=self._groups,
127
+ )
124
128
  ):
125
129
  t_model = self._scores["estimator"][i_fold]
126
- t_values = getattr(t_model, func)(self._X.iloc[test])
130
+ t_values = getattr(t_model, func)(
131
+ self._X.iloc[test] # type: ignore
132
+ )
127
133
  if t_values.ndim == 1:
128
134
  t_values = t_values[:, None]
129
135
  column_names = [f"p{i}" for i in range(t_values.shape[1])]
@@ -152,7 +158,7 @@ class FoldsInspector:
152
158
  t_df.columns = [f"fold{i_fold}_{x}" for x in t_df.columns]
153
159
  predictions = pd.concat(predictions, axis=1)
154
160
  predictions = predictions.sort_index()
155
- predictions["target"] = self._y.values
161
+ predictions["target"] = self._y.values # type: ignore
156
162
  return predictions
157
163
 
158
164
  def __getitem__(self, key):
@@ -53,7 +53,7 @@ def preprocess(
53
53
  else:
54
54
  raise_error(f"No step named {until} found.")
55
55
  df_out = pipeline[:i].transform(_X)
56
-
56
+ df_out = df_out.copy()
57
57
  if not isinstance(df_out, pd.DataFrame) and with_column_types is False:
58
58
  raise_error(
59
59
  "The output of the pipeline is not a DataFrame. Cannot remove "
@@ -3,7 +3,6 @@
3
3
  # Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
4
4
  # Sami Hamdan <s.hamdan@fz-juelich.de>
5
5
  # License: AGPL
6
-
7
6
  import numpy as np
8
7
  import pandas as pd
9
8
  import pytest
@@ -70,7 +69,10 @@ def scores(df_typed_iris, n_iters=5, mock_model=None):
70
69
  if mock_model is None:
71
70
  mock_model = MockModelReturnsIndex
72
71
 
73
- estimators = [WrapModel(mock_model()).fit(X, y) for _ in range(n_iters)]
72
+ estimators = [
73
+ WrapModel(mock_model()).fit(X, y) # type: ignore
74
+ for _ in range(n_iters)
75
+ ]
74
76
 
75
77
  return pd.DataFrame(
76
78
  {
@@ -18,28 +18,28 @@ if TYPE_CHECKING:
18
18
 
19
19
  def test_no_cv() -> None:
20
20
  """Test inspector with no cross-validation."""
21
- inspector = Inspector({})
21
+ inspector = Inspector({}) # type: ignore
22
22
  with pytest.raises(ValueError, match="No cv"):
23
23
  _ = inspector.folds
24
24
 
25
25
 
26
26
  def test_no_X() -> None:
27
27
  """Test inspector with no features."""
28
- inspector = Inspector({}, cv=5)
28
+ inspector = Inspector({}, cv=5) # type: ignore
29
29
  with pytest.raises(ValueError, match="No X"):
30
30
  _ = inspector.folds
31
31
 
32
32
 
33
33
  def test_no_y() -> None:
34
34
  """Test inspector with no targets."""
35
- inspector = Inspector({}, cv=5, X=[1, 2, 3])
35
+ inspector = Inspector({}, cv=5, X=[1, 2, 3]) # type: ignore
36
36
  with pytest.raises(ValueError, match="No y"):
37
37
  _ = inspector.folds
38
38
 
39
39
 
40
40
  def test_no_model() -> None:
41
41
  """Test inspector with no model."""
42
- inspector = Inspector({})
42
+ inspector = Inspector({}) # type: ignore
43
43
  with pytest.raises(ValueError, match="No model"):
44
44
  _ = inspector.model
45
45
 
@@ -63,8 +63,11 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None:
63
63
  return_inspector=True,
64
64
  problem_type="classification",
65
65
  )
66
- assert pipe == inspect.model._model
67
- for (_, score), inspect_fold in zip(scores.iterrows(), inspect.folds):
66
+ assert pipe == inspect.model._model # type: ignore
67
+ for (_, score), inspect_fold in zip(
68
+ scores.iterrows(), # type: ignore
69
+ inspect.folds, # type: ignore
70
+ ):
68
71
  assert score["estimator"] == inspect_fold.model._model
69
72
 
70
73
 
@@ -88,6 +91,6 @@ def test_normal_usage_with_search(df_iris: "pd.DataFrame") -> None:
88
91
  return_estimator="all",
89
92
  return_inspector=True,
90
93
  )
91
- assert pipe == inspect.model._model
92
- inspect.model.get_fitted_params()
93
- inspect.model.get_params()
94
+ assert pipe == inspect.model._model # type: ignore
95
+ inspect.model.get_fitted_params() # type: ignore
96
+ inspect.model.get_params() # type: ignore