julearn 0.3.4.dev13__tar.gz → 0.3.4.dev25__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (240) hide show
  1. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/PKG-INFO +1 -1
  2. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/main.rst +1 -0
  3. julearn-0.3.4.dev25/docs/changes/newsfragments/271.enh +1 -0
  4. julearn-0.3.4.dev25/docs/changes/newsfragments/293.enh +1 -0
  5. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/02_inspection/run_binary_inspect_folds.py +0 -1
  6. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/__init__.py +1 -1
  7. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/_version.py +2 -2
  8. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/api.py +418 -66
  9. julearn-0.3.4.dev25/julearn/model_selection/final_model_cv.py +96 -0
  10. julearn-0.3.4.dev25/julearn/model_selection/tests/test_final_model_cv.py +53 -0
  11. julearn-0.3.4.dev25/julearn/model_selection/utils.py +55 -0
  12. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/prepare.py +5 -1
  13. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/tests/test_api.py +54 -5
  14. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/utils/_cv.py +7 -0
  15. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/utils/typing.py +8 -1
  16. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn.egg-info/PKG-INFO +1 -1
  17. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn.egg-info/SOURCES.txt +5 -0
  18. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  19. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  20. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.github/ISSUE_TEMPLATE/documentation_request.yaml +0 -0
  21. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  22. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.github/workflows/check-stale.yml +0 -0
  23. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.github/workflows/ci-docs.yml +0 -0
  24. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.github/workflows/ci.yml +0 -0
  25. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.github/workflows/docs-preview.yml +0 -0
  26. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.github/workflows/docs.yml +0 -0
  27. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.github/workflows/lint.yml +0 -0
  28. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.github/workflows/pypi.yml +0 -0
  29. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.gitignore +0 -0
  30. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/.pre-commit-config.yaml +0 -0
  31. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/AUTHORS.rst +0 -0
  32. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/LICENSE.md +0 -0
  33. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/README.md +0 -0
  34. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/codecov.yml +0 -0
  35. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/Makefile +0 -0
  36. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/_static/css/custom.css +0 -0
  37. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/_static/js/custom.js +0 -0
  38. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/_templates/class.rst +0 -0
  39. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/_templates/function.rst +0 -0
  40. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/_templates/function_warning.rst +0 -0
  41. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/_templates/versions.html +0 -0
  42. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/base.rst +0 -0
  43. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/index.rst +0 -0
  44. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/inspect.rst +0 -0
  45. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/model_selection.rst +0 -0
  46. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/models.rst +0 -0
  47. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/pipeline.rst +0 -0
  48. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/prepare.rst +0 -0
  49. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/scoring.rst +0 -0
  50. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/stats.rst +0 -0
  51. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/transformers.rst +0 -0
  52. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/utils.rst +0 -0
  53. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/api/viz.rst +0 -0
  54. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/available_pipeline_steps.rst +0 -0
  55. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/changes/contributors.inc +0 -0
  56. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/changes/newsfragments/.gitignore +0 -0
  57. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/changes/newsfragments/268.bugfix +0 -0
  58. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/changes/newsfragments/270.enh +0 -0
  59. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/conf.py +0 -0
  60. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/configuration.rst +0 -0
  61. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/contributing.rst +0 -0
  62. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/examples.rst +0 -0
  63. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/faq.rst +0 -0
  64. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/getting_started.rst +0 -0
  65. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/corrected_ttest.png +0 -0
  66. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/final_estimator.png +0 -0
  67. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/iris_X.png +0 -0
  68. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/iris_df.png +0 -0
  69. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/iris_y.png +0 -0
  70. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/julearn_logo.png +0 -0
  71. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/julearn_logo_calm.png +0 -0
  72. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/julearn_logo_confbias.png +0 -0
  73. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/julearn_logo_cv.png +0 -0
  74. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/julearn_logo_generalization.png +0 -0
  75. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/julearn_logo_it.png +0 -0
  76. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/julearn_logo_ml.png +0 -0
  77. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/julearn_logo_mlit.png +0 -0
  78. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/multiple_scorers_run_cv.png +0 -0
  79. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/plot_scores.png +0 -0
  80. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/scores_run_cv.png +0 -0
  81. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/scores_run_cv_splitter.png +0 -0
  82. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/images/scores_run_cv_train.png +0 -0
  83. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/index.rst +0 -0
  84. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/links.inc +0 -0
  85. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/maintaining.rst +0 -0
  86. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/redirect.html +0 -0
  87. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/selected_deeper_topics/CBPM.rst +0 -0
  88. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/selected_deeper_topics/confound_removal.rst +0 -0
  89. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/selected_deeper_topics/cross_validation_splitter.rst +0 -0
  90. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/selected_deeper_topics/hyperparameter_tuning.rst +0 -0
  91. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/selected_deeper_topics/index.rst +0 -0
  92. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/selected_deeper_topics/model_inspect.rst +0 -0
  93. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/selected_deeper_topics/stacked_models.rst +0 -0
  94. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/selected_deeper_topics/target_transformers.rst +0 -0
  95. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/sphinxext/gh_substitutions.py +0 -0
  96. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/what_really_need_know/cross_validation.rst +0 -0
  97. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/what_really_need_know/data.rst +0 -0
  98. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/what_really_need_know/index.rst +0 -0
  99. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/what_really_need_know/model_comparison.rst +0 -0
  100. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/what_really_need_know/model_evaluation.rst +0 -0
  101. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/what_really_need_know/pipeline.rst +0 -0
  102. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/docs/whats_new.rst +0 -0
  103. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/00_starting/README.rst +0 -0
  104. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/00_starting/plot_cm_acc_multiclass.py +0 -0
  105. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/00_starting/plot_example_regression.py +0 -0
  106. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/00_starting/plot_stratified_kfold_reg.py +0 -0
  107. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/00_starting/run_combine_pandas.py +0 -0
  108. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/00_starting/run_grouped_cv.py +0 -0
  109. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/00_starting/run_simple_binary_classification.py +0 -0
  110. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/01_model_comparison/README.rst +0 -0
  111. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/01_model_comparison/plot_simple_model_comparison.py +0 -0
  112. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/02_inspection/README.rst +0 -0
  113. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/02_inspection/plot_groupcv_inspect_svm.py +0 -0
  114. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/02_inspection/plot_inspect_random_forest.py +0 -0
  115. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/02_inspection/plot_preprocess.py +0 -0
  116. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/03_complex_models/README.rst +0 -0
  117. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/03_complex_models/run_apply_to_target.py +0 -0
  118. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/03_complex_models/run_example_pca_featsets.py +0 -0
  119. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/03_complex_models/run_hyperparameter_multiple_grids.py +0 -0
  120. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/03_complex_models/run_hyperparameter_tuning.py +0 -0
  121. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py +0 -0
  122. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/03_complex_models/run_stacked_models.py +0 -0
  123. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/04_confounds/README.rst +0 -0
  124. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/04_confounds/plot_confound_removal_classification.py +0 -0
  125. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/04_confounds/run_return_confounds.py +0 -0
  126. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/05_customization/README.rst +0 -0
  127. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/05_customization/run_custom_scorers_regression.py +0 -0
  128. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/README.rst +0 -0
  129. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/run_cbpm_docs.py +0 -0
  130. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/run_confound_removal_docs.py +0 -0
  131. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/run_cv_splitters_docs.py +0 -0
  132. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/run_data_docs.py +0 -0
  133. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/run_hyperparameters_docs.py +0 -0
  134. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/run_model_comparison_docs.py +0 -0
  135. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/run_model_evaluation_docs.py +0 -0
  136. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/run_model_inspection_docs.py +0 -0
  137. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/run_pipeline_docs.py +0 -0
  138. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/run_stacked_models_docs.py +0 -0
  139. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/99_docs/run_target_transformer_docs.py +0 -0
  140. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/README.rst +0 -0
  141. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/XX_disabled/dis_run_n_jobs.py +0 -0
  142. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/examples/XX_disabled/dis_run_target_confound_removal.py +0 -0
  143. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/ignore_words.txt +0 -0
  144. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/base/__init__.py +0 -0
  145. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/base/column_types.py +0 -0
  146. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/base/estimators.py +0 -0
  147. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/base/tests/test_base_estimators.py +0 -0
  148. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/base/tests/test_column_types.py +0 -0
  149. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/config.py +0 -0
  150. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/conftest.py +0 -0
  151. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/external/optuna_searchcv.py +0 -0
  152. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/inspect/__init__.py +0 -0
  153. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/inspect/_cv.py +0 -0
  154. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/inspect/_pipeline.py +0 -0
  155. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/inspect/_preprocess.py +0 -0
  156. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/inspect/inspector.py +0 -0
  157. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/inspect/tests/test_cv.py +0 -0
  158. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/inspect/tests/test_inspector.py +0 -0
  159. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/inspect/tests/test_pipeline.py +0 -0
  160. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/inspect/tests/test_preprocess.py +0 -0
  161. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/model_selection/__init__.py +0 -0
  162. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/model_selection/_optuna_searcher.py +0 -0
  163. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/model_selection/_skopt_searcher.py +0 -0
  164. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/model_selection/available_searchers.py +0 -0
  165. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/model_selection/continuous_stratified_kfold.py +0 -0
  166. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/model_selection/stratified_bootstrap.py +0 -0
  167. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/model_selection/tests/test_available_searchers.py +0 -0
  168. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/model_selection/tests/test_continous_stratified_kfold.py +0 -0
  169. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/model_selection/tests/test_optuna_searcher.py +0 -0
  170. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/model_selection/tests/test_skopt_searcher.py +0 -0
  171. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/model_selection/tests/test_stratified_bootstrap.py +0 -0
  172. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/models/__init__.py +0 -0
  173. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/models/available_models.py +0 -0
  174. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/models/dynamic.py +0 -0
  175. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/models/tests/test_available_models.py +0 -0
  176. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/models/tests/test_dynamic.py +0 -0
  177. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/models/tests/test_models.py +0 -0
  178. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/pipeline/__init__.py +0 -0
  179. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/pipeline/merger.py +0 -0
  180. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/pipeline/pipeline_creator.py +0 -0
  181. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/pipeline/target_pipeline.py +0 -0
  182. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/pipeline/target_pipeline_creator.py +0 -0
  183. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/pipeline/tests/test_merger.py +0 -0
  184. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/pipeline/tests/test_pipeline_creator.py +0 -0
  185. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/pipeline/tests/test_target_pipeline.py +0 -0
  186. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/pipeline/tests/test_target_pipeline_creator.py +0 -0
  187. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/scoring/__init__.py +0 -0
  188. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/scoring/available_scorers.py +0 -0
  189. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/scoring/metrics.py +0 -0
  190. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/scoring/tests/test_available_scorers.py +0 -0
  191. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/scoring/tests/test_metrics.py +0 -0
  192. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/stats/__init__.py +0 -0
  193. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/stats/corrected_ttest.py +0 -0
  194. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/stats/tests/test_corrected_ttest.py +0 -0
  195. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/tests/test_config.py +0 -0
  196. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/tests/test_prepare.py +0 -0
  197. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/__init__.py +0 -0
  198. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/available_transformers.py +0 -0
  199. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/cbpm.py +0 -0
  200. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/confound_remover.py +0 -0
  201. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/dataframe/__init__.py +0 -0
  202. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/dataframe/change_column_types.py +0 -0
  203. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/dataframe/drop_columns.py +0 -0
  204. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/dataframe/filter_columns.py +0 -0
  205. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/dataframe/set_column_types.py +0 -0
  206. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/dataframe/tests/test_change_column_types.py +0 -0
  207. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/dataframe/tests/test_drop_columns.py +0 -0
  208. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/dataframe/tests/test_filter_columns.py +0 -0
  209. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/dataframe/tests/test_set_column_types.py +0 -0
  210. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/ju_column_transformer.py +0 -0
  211. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/target/__init__.py +0 -0
  212. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/target/available_target_transformers.py +0 -0
  213. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/target/ju_target_transformer.py +0 -0
  214. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/target/ju_transformed_target_model.py +0 -0
  215. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/target/target_confound_remover.py +0 -0
  216. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/target/tests/test_available_target_transformers.py +0 -0
  217. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/target/tests/test_ju_target_transformer.py +0 -0
  218. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/target/tests/test_ju_transformed_target_model.py +0 -0
  219. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/target/tests/test_target_confound_remover.py +0 -0
  220. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/tests/test_available_transformers.py +0 -0
  221. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/tests/test_cbpm.py +0 -0
  222. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/tests/test_confounds.py +0 -0
  223. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/transformers/tests/test_jucolumntransformers.py +0 -0
  224. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/utils/__init__.py +0 -0
  225. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/utils/checks.py +0 -0
  226. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/utils/logging.py +0 -0
  227. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/utils/testing.py +0 -0
  228. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/utils/tests/test_logging.py +0 -0
  229. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/utils/tests/test_version.py +0 -0
  230. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/utils/versions.py +0 -0
  231. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/viz/__init__.py +0 -0
  232. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/viz/_scores.py +0 -0
  233. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn/viz/res/julearn_logo_generalization.png +0 -0
  234. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn.egg-info/dependency_links.txt +0 -0
  235. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn.egg-info/requires.txt +0 -0
  236. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/julearn.egg-info/top_level.txt +0 -0
  237. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/pyproject.toml +0 -0
  238. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/setup.cfg +0 -0
  239. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/setup.py +0 -0
  240. {julearn-0.3.4.dev13 → julearn-0.3.4.dev25}/tox.ini +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: julearn
3
- Version: 0.3.4.dev13
3
+ Version: 0.3.4.dev25
4
4
  Summary: Juelich Machine Learning Library
5
5
  Author-email: Fede Raimondo <f.raimondo@fz-juelich.de>, Sami Hamdan <s.hamdan@fz-juelich.de>
6
6
  Maintainer-email: Sami Hamdan <s.hamdan@fz-juelich.de>
@@ -15,3 +15,4 @@ Functions
15
15
  :template: function.rst
16
16
 
17
17
  run_cross_validation
18
+ run_fit
@@ -0,0 +1 @@
1
+ Add :func:`.run_fit` that implements a model fitting procedure with the same API as :func:`.run_cross_validation` by `Fede Raimondo`_.
@@ -0,0 +1 @@
1
+ Change the internal logic of :func:`.run_cross_validation` to optimise joblib calls by `Fede Raimondo`_
@@ -44,7 +44,6 @@ creator = PipelineCreator(problem_type="classification")
44
44
  creator.add("zscore")
45
45
  creator.add("svm")
46
46
 
47
- cv = ShuffleSplit(n_splits=5, train_size=0.7, random_state=200)
48
47
  cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=4, random_state=200)
49
48
 
50
49
  scores, model, inspector = run_cross_validation(
@@ -14,5 +14,5 @@ from . import utils
14
14
  from . import prepare
15
15
  from . import api
16
16
  from . import stats
17
- from .api import run_cross_validation
17
+ from .api import run_cross_validation, run_fit
18
18
  from .pipeline import PipelineCreator, TargetPipelineCreator
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.3.4.dev13'
16
- __version_tuple__ = version_tuple = (0, 3, 4, 'dev13')
15
+ __version__ = version = '0.3.4.dev25'
16
+ __version_tuple__ = version_tuple = (0, 3, 4, 'dev25')
@@ -4,20 +4,20 @@
4
4
  # Sami Hamdan <s.hamdan@fz-juelich.de>
5
5
  # License: AGPL
6
6
 
7
- from typing import Dict, List, Optional, Union
7
+ from typing import Dict, List, Optional, Tuple, Union
8
8
 
9
9
  import numpy as np
10
10
  import pandas as pd
11
11
  import sklearn
12
12
  from sklearn.base import BaseEstimator
13
13
  from sklearn.model_selection import (
14
- check_cv,
15
14
  cross_validate,
16
15
  )
17
16
  from sklearn.model_selection._search import BaseSearchCV
18
17
  from sklearn.pipeline import Pipeline
19
18
 
20
19
  from .inspect import Inspector
20
+ from .model_selection.utils import check_cv
21
21
  from .pipeline import PipelineCreator
22
22
  from .pipeline.merger import merge_pipelines
23
23
  from .prepare import check_consistency, prepare_input_data
@@ -26,7 +26,7 @@ from .utils import _compute_cvmdsum, logger, raise_error
26
26
  from .utils.typing import CVLike
27
27
 
28
28
 
29
- def run_cross_validation( # noqa: C901
29
+ def _validata_api_params( # noqa: C901
30
30
  X: List[str], # noqa: N803
31
31
  y: str,
32
32
  model: Union[str, PipelineCreator, BaseEstimator, List[PipelineCreator]],
@@ -36,18 +36,22 @@ def run_cross_validation( # noqa: C901
36
36
  preprocess: Union[None, str, List[str]] = None,
37
37
  return_estimator: Optional[str] = None,
38
38
  return_inspector: bool = False,
39
- return_train_score: bool = False,
40
- cv: Optional[CVLike] = None,
41
39
  groups: Optional[str] = None,
42
- scoring: Union[str, List[str], None] = None,
43
40
  pos_labels: Union[str, List[str], None] = None,
44
41
  model_params: Optional[Dict] = None,
45
42
  search_params: Optional[Dict] = None,
46
43
  seed: Optional[int] = None,
47
- n_jobs: Optional[int] = None,
48
- verbose: Optional[int] = 0,
49
- ):
50
- """Run cross validation and score.
44
+ ) -> Tuple[
45
+ pd.DataFrame,
46
+ pd.Series,
47
+ Optional[pd.Series],
48
+ Union[Pipeline, BaseSearchCV],
49
+ Optional[str],
50
+ bool,
51
+ bool,
52
+ str,
53
+ ]:
54
+ """Validate the parameters passed to the API functions.
51
55
 
52
56
  Parameters
53
57
  ----------
@@ -95,28 +99,9 @@ def run_cross_validation( # noqa: C901
95
99
 
96
100
  return_inspector : bool
97
101
  Whether to return the inspector object (default is False)
98
-
99
- return_train_score : bool
100
- Whether to return the training score with the test scores
101
- (default is False).
102
- cv : int, str or cross-validation generator | None
103
- Cross-validation splitting strategy to use for model evaluation.
104
-
105
- Options are:
106
-
107
- * None: defaults to 5-fold
108
- * int: the number of folds in a `(Stratified)KFold`
109
- * CV Splitter (see scikit-learn documentation on CV)
110
- * An iterable yielding (train, test) splits as arrays of indices.
111
-
112
102
  groups : str | None
113
103
  The grouping labels in case a Group CV is used.
114
104
  See :ref:`data_usage` for details.
115
- scoring : ScorerLike, optional
116
- The scoring metric to use.
117
- See https://scikit-learn.org/stable/modules/model_evaluation.html for
118
- a comprehensive list of options. If None, use the model's default
119
- scorer.
120
105
  pos_labels : str, int, float or list | None
121
106
  The labels to interpret as positive. If not None, every element from y
122
107
  will be converted to 1 if is equal or in pos_labels and to 0 if not.
@@ -157,36 +142,27 @@ def run_cross_validation( # noqa: C901
157
142
  seed : int | None
158
143
  If not None, set the random seed before any operation. Useful for
159
144
  reproducibility.
160
- n_jobs : int, optional
161
- Number of jobs to run in parallel. Training the estimator and computing
162
- the score are parallelized over the cross-validation splits.
163
- ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
164
- ``-1`` means using all processors (default None).
165
- verbose: int
166
- Verbosity level of outer cross-validation.
167
- Follows scikit-learn/joblib converntions.
168
- 0 means no additional information is printed.
169
- Larger number generally mean more information is printed.
170
- Note: verbosity up to 50 will print into standard error,
171
- while larger than 50 will print in standrad output.
172
145
 
173
146
  Returns
174
147
  -------
175
- scores : pd.DataFrame
176
- The resulting scores (one column for each score specified).
177
- Additionally, a 'fit_time' column will be added.
178
- And, if ``return_estimator='all'`` or
179
- ``return_estimator='cv'``, an 'estimator' columns with the
180
- corresponding estimators fitted for each CV split.
181
- final_estimator : object
182
- The final estimator, fitted on all the data (only if
183
- ``return_estimator='all'`` or ``return_estimator='final'``)
184
- inspector : Inspector | None
185
- The inspector object (only if ``return_inspector=True``)
148
+ df_X : pd.DataFrame
149
+ The features DataFrame.
150
+ df_y : pd.Series
151
+ The target Series.
152
+ df_groups : pd.Series | None
153
+ The groups Series.
154
+ pipeline : Pipeline | BaseSearchCV
155
+ The pipeline to use.
156
+ return_estimator : str | None
157
+ The validated return_estimator parameter.
158
+ return_inspector : bool
159
+ The validated return_inspector parameter.
160
+ wrap_score : bool
161
+ Whether to wrap the score or not.
162
+ problem_type : str
163
+ The problem type.
186
164
 
187
165
  """
188
-
189
- # Validate parameters
190
166
  if return_estimator not in [None, "final", "cv", "all"]:
191
167
  raise_error(
192
168
  f"return_estimator must be one of None, 'final', 'cv', 'all'. "
@@ -365,16 +341,219 @@ def run_cross_validation( # noqa: C901
365
341
  elif problem_type == "regression":
366
342
  logger.info(f"\tTarget type: {df_y.dtype}")
367
343
 
344
+ out = (
345
+ df_X,
346
+ df_y,
347
+ df_groups,
348
+ pipeline,
349
+ return_estimator,
350
+ return_inspector,
351
+ wrap_score,
352
+ problem_type,
353
+ )
354
+ return out
355
+
356
+
357
+ def run_cross_validation(
358
+ X: List[str], # noqa: N803
359
+ y: str,
360
+ model: Union[str, PipelineCreator, BaseEstimator, List[PipelineCreator]],
361
+ data: pd.DataFrame,
362
+ X_types: Optional[Dict] = None, # noqa: N803
363
+ problem_type: Optional[str] = None,
364
+ preprocess: Union[None, str, List[str]] = None,
365
+ return_estimator: Optional[str] = None,
366
+ return_inspector: bool = False,
367
+ return_train_score: bool = False,
368
+ cv: Optional[CVLike] = None,
369
+ groups: Optional[str] = None,
370
+ scoring: Union[str, List[str], None] = None,
371
+ pos_labels: Union[str, List[str], None] = None,
372
+ model_params: Optional[Dict] = None,
373
+ search_params: Optional[Dict] = None,
374
+ seed: Optional[int] = None,
375
+ n_jobs: Optional[int] = None,
376
+ verbose: Optional[int] = 0,
377
+ ):
378
+ """Run cross validation and score.
379
+
380
+ Parameters
381
+ ----------
382
+ X : list of str
383
+ The features to use.
384
+ See :ref:`data_usage` for details.
385
+ y : str
386
+ The targets to predict.
387
+ See :ref:`data_usage` for details.
388
+ model : str or scikit-learn compatible model.
389
+ If string, it will use one of the available models.
390
+ data : pandas.DataFrame
391
+ DataFrame with the data. See :ref:`data_usage` for details.
392
+ X_types : dict[str, list of str]
393
+ A dictionary containing keys with column type as a str and the
394
+ columns of this column type as a list of str.
395
+ problem_type : str
396
+ The kind of problem to model.
397
+
398
+ Options are:
399
+
400
+ * "classification": Perform a classification
401
+ in which the target (y) has categorical classes (default).
402
+ The parameter pos_labels can be used to convert a target with
403
+ multiple_classes into binary.
404
+ * "regression". Perform a regression. The target (y) has to be
405
+ ordinal at least.
406
+
407
+ preprocess : str, TransformerLike or list or PipelineCreator | None
408
+ Transformer to apply to the features. If string, use one of the
409
+ available transformers. If list, each element can be a string or
410
+ scikit-learn compatible transformer. If None (default), no
411
+ transformation is applied.
412
+
413
+ See documentation for details.
414
+
415
+ return_estimator : str | None
416
+ Return the fitted estimator(s).
417
+ Options are:
418
+
419
+ * 'final': Return the estimator fitted on all the data.
420
+ * 'cv': Return the all the estimator from each CV split, fitted on the
421
+ training data.
422
+ * 'all': Return all the estimators (final and cv).
423
+
424
+ return_inspector : bool
425
+ Whether to return the inspector object (default is False)
426
+
427
+ return_train_score : bool
428
+ Whether to return the training score with the test scores
429
+ (default is False).
430
+ cv : int, str or cross-validation generator | None
431
+ Cross-validation splitting strategy to use for model evaluation.
432
+
433
+ Options are:
434
+
435
+ * None: defaults to 5-fold
436
+ * int: the number of folds in a `(Stratified)KFold`
437
+ * CV Splitter (see scikit-learn documentation on CV)
438
+ * An iterable yielding (train, test) splits as arrays of indices.
439
+
440
+ groups : str | None
441
+ The grouping labels in case a Group CV is used.
442
+ See :ref:`data_usage` for details.
443
+ scoring : ScorerLike, optional
444
+ The scoring metric to use.
445
+ See https://scikit-learn.org/stable/modules/model_evaluation.html for
446
+ a comprehensive list of options. If None, use the model's default
447
+ scorer.
448
+ pos_labels : str, int, float or list | None
449
+ The labels to interpret as positive. If not None, every element from y
450
+ will be converted to 1 if is equal or in pos_labels and to 0 if not.
451
+ model_params : dict | None
452
+ If not None, this dictionary specifies the model parameters to use
453
+
454
+ The dictionary can define the following keys:
455
+
456
+ * 'STEP__PARAMETER': A value (or several) to be used as PARAMETER for
457
+ STEP in the pipeline. Example: 'svm__probability': True will set
458
+ the parameter 'probability' of the 'svm' model. If more than option
459
+ is provided for at least one hyperparameter, a search will be
460
+ performed.
461
+
462
+ search_params : dict | None
463
+ Additional parameters in case Hyperparameter Tuning is performed, with
464
+ the following keys:
465
+
466
+ * 'kind': The kind of search algorithm to use, Valid options are:
467
+
468
+ * ``"grid"`` : :class:`~sklearn.model_selection.GridSearchCV`
469
+ * ``"random"`` :
470
+ :class:`~sklearn.model_selection.RandomizedSearchCV`
471
+ * ``"bayes"`` : :class:`~skopt.BayesSearchCV`
472
+ * ``"optuna"`` :
473
+ :class:`~optuna_integration.OptunaSearchCV`
474
+ * user-registered searcher name : see
475
+ :func:`~julearn.model_selection.register_searcher`
476
+ * ``scikit-learn``-compatible searcher
477
+
478
+ * 'cv': If a searcher is going to be used, the cross-validation
479
+ splitting strategy to use. Defaults to same CV as for the model
480
+ evaluation.
481
+ * 'scoring': If a searcher is going to be used, the scoring metric to
482
+ evaluate the performance.
483
+
484
+ See :ref:`hp_tuning` for details.
485
+ seed : int | None
486
+ If not None, set the random seed before any operation. Useful for
487
+ reproducibility.
488
+ n_jobs : int, optional
489
+ Number of jobs to run in parallel. Training the estimator and computing
490
+ the score are parallelized over the cross-validation splits.
491
+ ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
492
+ ``-1`` means using all processors (default None).
493
+ verbose: int
494
+ Verbosity level of outer cross-validation.
495
+ Follows scikit-learn/joblib converntions.
496
+ 0 means no additional information is printed.
497
+ Larger number generally mean more information is printed.
498
+ Note: verbosity up to 50 will print into standard error,
499
+ while larger than 50 will print in standrad output.
500
+
501
+ Returns
502
+ -------
503
+ scores : pd.DataFrame
504
+ The resulting scores (one column for each score specified).
505
+ Additionally, a 'fit_time' column will be added.
506
+ And, if ``return_estimator='all'`` or
507
+ ``return_estimator='cv'``, an 'estimator' columns with the
508
+ corresponding estimators fitted for each CV split.
509
+ final_estimator : object
510
+ The final estimator, fitted on all the data (only if
511
+ ``return_estimator='all'`` or ``return_estimator='final'``)
512
+ inspector : Inspector | None
513
+ The inspector object (only if ``return_inspector=True``)
514
+
515
+ """
516
+
517
+ # Validate parameters
518
+ (
519
+ df_X,
520
+ df_y,
521
+ df_groups,
522
+ pipeline,
523
+ return_estimator,
524
+ return_inspector,
525
+ wrap_score,
526
+ problem_type,
527
+ ) = _validata_api_params(
528
+ X=X,
529
+ y=y,
530
+ model=model,
531
+ data=data,
532
+ X_types=X_types,
533
+ problem_type=problem_type,
534
+ preprocess=preprocess,
535
+ return_estimator=return_estimator,
536
+ return_inspector=return_inspector,
537
+ groups=groups,
538
+ pos_labels=pos_labels,
539
+ model_params=model_params,
540
+ search_params=search_params,
541
+ seed=seed,
542
+ )
543
+
544
+ include_final_model = return_estimator in ["final", "all"]
545
+ cv_return_estimator = return_estimator in ["cv", "all", "final"]
546
+
368
547
  # Prepare cross validation
369
548
  cv_outer = check_cv(
370
549
  cv, # type: ignore
371
550
  classifier=problem_type == "classification",
551
+ include_final_model=include_final_model,
372
552
  )
373
553
  logger.info(f"Using outer CV scheme {cv_outer}")
374
554
 
375
555
  check_consistency(df_y, cv, groups, problem_type) # type: ignore
376
556
 
377
- cv_return_estimator = return_estimator in ["cv", "all"]
378
557
  scoring = check_scoring(
379
558
  pipeline, # type: ignore
380
559
  scoring,
@@ -407,18 +586,28 @@ def run_cross_validation( # noqa: C901
407
586
  **_sklearn_deprec_fit_params,
408
587
  )
409
588
 
410
- n_repeats = getattr(cv_outer, "n_repeats", 1)
411
- n_folds = len(scores["fit_time"]) // n_repeats
412
-
413
- repeats = np.repeat(np.arange(n_repeats), n_folds)
414
- folds = np.tile(np.arange(n_folds), n_repeats)
415
-
416
589
  fold_sizes = np.array(
417
590
  [
418
591
  list(map(len, x))
419
592
  for x in cv_outer.split(df_X, df_y, groups=df_groups)
420
593
  ]
421
594
  )
595
+
596
+ if include_final_model:
597
+ # If we include the final model, we need to remove the last item in
598
+ # the scores as this is the final model
599
+ pipeline = scores["estimator"][-1]
600
+ if return_estimator == "final":
601
+ scores.pop("estimator")
602
+ scores = {k: v[:-1] for k, v in scores.items()}
603
+ fold_sizes = fold_sizes[:-1]
604
+
605
+ n_repeats = getattr(cv_outer, "n_repeats", 1)
606
+ n_folds = len(scores["fit_time"]) // n_repeats
607
+
608
+ repeats = np.repeat(np.arange(n_repeats), n_folds)
609
+ folds = np.tile(np.arange(n_folds), n_repeats)
610
+
422
611
  scores["n_train"] = fold_sizes[:, 0]
423
612
  scores["n_test"] = fold_sizes[:, 1]
424
613
  scores["repeat"] = repeats
@@ -426,11 +615,10 @@ def run_cross_validation( # noqa: C901
426
615
  scores["cv_mdsum"] = cv_mdsum
427
616
 
428
617
  scores_df = pd.DataFrame(scores)
618
+
429
619
  out = scores_df
430
- if return_estimator in ["final", "all"]:
431
- logger.info("Fitting final model")
432
- pipeline.fit(df_X, df_y, **fit_params)
433
- out = scores_df, pipeline
620
+ if include_final_model:
621
+ out = out, pipeline
434
622
 
435
623
  if return_inspector:
436
624
  inspector = Inspector(
@@ -439,7 +627,7 @@ def run_cross_validation( # noqa: C901
439
627
  X=df_X,
440
628
  y=df_y,
441
629
  groups=df_groups,
442
- cv=cv_outer,
630
+ cv=cv_outer.cv if include_final_model else cv_outer,
443
631
  )
444
632
  if isinstance(out, tuple):
445
633
  out = (*out, inspector)
@@ -447,3 +635,167 @@ def run_cross_validation( # noqa: C901
447
635
  out = out, inspector
448
636
 
449
637
  return out
638
+
639
+
640
def run_fit(
    X: List[str],  # noqa: N803
    y: str,
    model: Union[str, PipelineCreator, BaseEstimator, List[PipelineCreator]],
    data: pd.DataFrame,
    X_types: Optional[Dict] = None,  # noqa: N803
    problem_type: Optional[str] = None,
    preprocess: Union[None, str, List[str]] = None,
    groups: Optional[str] = None,
    pos_labels: Union[str, List[str], None] = None,
    model_params: Optional[Dict] = None,
    search_params: Optional[Dict] = None,
    seed: Optional[int] = None,
    verbose: Optional[int] = 0,
):
    """Fit the model on all the data.

    Parameters
    ----------
    X : list of str
        The features to use.
        See :ref:`data_usage` for details.
    y : str
        The targets to predict.
        See :ref:`data_usage` for details.
    model : str or scikit-learn compatible model.
        If string, it will use one of the available models.
    data : pandas.DataFrame
        DataFrame with the data. See :ref:`data_usage` for details.
    X_types : dict[str, list of str]
        A dictionary containing keys with column type as a str and the
        columns of this column type as a list of str.
    problem_type : str
        The kind of problem to model.

        Options are:

        * "classification": Perform a classification
          in which the target (y) has categorical classes (default).
          The parameter pos_labels can be used to convert a target with
          multiple_classes into binary.
        * "regression". Perform a regression. The target (y) has to be
          ordinal at least.

    preprocess : str, TransformerLike or list or PipelineCreator | None
        Transformer to apply to the features. If string, use one of the
        available transformers. If list, each element can be a string or
        scikit-learn compatible transformer. If None (default), no
        transformation is applied.

        See documentation for details.

    groups : str | None
        The grouping labels. They are only passed to ``fit`` when the model
        is a hyperparameter searcher (so that a Group CV can be used for
        the search); a plain pipeline is fitted without them.
        See :ref:`data_usage` for details.
    pos_labels : str, int, float or list | None
        The labels to interpret as positive. If not None, every element from y
        will be converted to 1 if is equal or in pos_labels and to 0 if not.
    model_params : dict | None
        If not None, this dictionary specifies the model parameters to use

        The dictionary can define the following keys:

        * 'STEP__PARAMETER': A value (or several) to be used as PARAMETER for
          STEP in the pipeline. Example: 'svm__probability': True will set
          the parameter 'probability' of the 'svm' model. If more than one
          option is provided for at least one hyperparameter, a search will
          be performed.

    search_params : dict | None
        Additional parameters in case Hyperparameter Tuning is performed, with
        the following keys:

        * 'kind': The kind of search algorithm to use. Valid options are:

          * ``"grid"`` : :class:`~sklearn.model_selection.GridSearchCV`
          * ``"random"`` :
            :class:`~sklearn.model_selection.RandomizedSearchCV`
          * ``"bayes"`` : :class:`~skopt.BayesSearchCV`
          * ``"optuna"`` :
            :class:`~optuna_integration.OptunaSearchCV`
          * user-registered searcher name : see
            :func:`~julearn.model_selection.register_searcher`
          * ``scikit-learn``-compatible searcher

        * 'cv': If a searcher is going to be used, the cross-validation
          splitting strategy the searcher uses internally.
        * 'scoring': If a searcher is going to be used, the scoring metric to
          evaluate the performance.

        See :ref:`hp_tuning` for details.

    seed : int | None
        If not None, set the random seed before any operation. Useful for
        reproducibility.
    verbose : int
        Verbosity level. Currently not used by this function.

    Returns
    -------
    final_estimator : Pipeline | BaseSearchCV
        The estimator fitted on all the data. When a hyperparameter search
        is performed, this is the fitted searcher.

    """
    # Validate parameters and build the pipeline. ``return_estimator`` and
    # ``return_inspector`` only matter for cross-validation, so they are
    # fixed here and the corresponding validated outputs are discarded
    # (together with the validated problem type, which is not needed).
    (
        df_X,
        df_y,
        df_groups,
        pipeline,
        _,
        _,
        _,
        _,
    ) = _validata_api_params(
        X=X,
        y=y,
        model=model,
        data=data,
        X_types=X_types,
        problem_type=problem_type,
        preprocess=preprocess,
        return_estimator=None,
        return_inspector=False,
        groups=groups,
        pos_labels=pos_labels,
        model_params=model_params,
        search_params=search_params,
        seed=seed,
    )

    # Only a searcher runs an internal CV that can consume the groups; a
    # plain pipeline's ``fit`` is called without them.
    fit_params = {}
    if df_groups is not None and isinstance(pipeline, BaseSearchCV):
        fit_params["groups"] = df_groups.values

    logger.info("Fitting final model")
    pipeline.fit(df_X, df_y, **fit_params)

    return pipeline
+ return pipeline