scitex 2.14.0__py3-none-any.whl → 2.15.3__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (264)
  1. scitex/__init__.py +71 -17
  2. scitex/_env_loader.py +156 -0
  3. scitex/_mcp_resources/__init__.py +37 -0
  4. scitex/_mcp_resources/_cheatsheet.py +135 -0
  5. scitex/_mcp_resources/_figrecipe.py +138 -0
  6. scitex/_mcp_resources/_formats.py +102 -0
  7. scitex/_mcp_resources/_modules.py +337 -0
  8. scitex/_mcp_resources/_session.py +149 -0
  9. scitex/_mcp_tools/__init__.py +4 -0
  10. scitex/_mcp_tools/audio.py +66 -0
  11. scitex/_mcp_tools/diagram.py +11 -95
  12. scitex/_mcp_tools/introspect.py +210 -0
  13. scitex/_mcp_tools/plt.py +260 -305
  14. scitex/_mcp_tools/scholar.py +74 -0
  15. scitex/_mcp_tools/social.py +27 -0
  16. scitex/_mcp_tools/template.py +24 -0
  17. scitex/_mcp_tools/writer.py +17 -210
  18. scitex/ai/_gen_ai/_PARAMS.py +10 -7
  19. scitex/ai/classification/reporters/_SingleClassificationReporter.py +45 -1603
  20. scitex/ai/classification/reporters/_mixins/__init__.py +36 -0
  21. scitex/ai/classification/reporters/_mixins/_constants.py +67 -0
  22. scitex/ai/classification/reporters/_mixins/_cv_summary.py +387 -0
  23. scitex/ai/classification/reporters/_mixins/_feature_importance.py +119 -0
  24. scitex/ai/classification/reporters/_mixins/_metrics.py +275 -0
  25. scitex/ai/classification/reporters/_mixins/_plotting.py +179 -0
  26. scitex/ai/classification/reporters/_mixins/_reports.py +153 -0
  27. scitex/ai/classification/reporters/_mixins/_storage.py +160 -0
  28. scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit.py +30 -1550
  29. scitex/ai/classification/timeseries/_sliding_window_core.py +467 -0
  30. scitex/ai/classification/timeseries/_sliding_window_plotting.py +369 -0
  31. scitex/audio/README.md +40 -36
  32. scitex/audio/__init__.py +129 -61
  33. scitex/audio/_branding.py +185 -0
  34. scitex/audio/_mcp/__init__.py +32 -0
  35. scitex/audio/_mcp/handlers.py +59 -6
  36. scitex/audio/_mcp/speak_handlers.py +238 -0
  37. scitex/audio/_relay.py +225 -0
  38. scitex/audio/_tts.py +18 -10
  39. scitex/audio/engines/base.py +17 -10
  40. scitex/audio/engines/elevenlabs_engine.py +7 -2
  41. scitex/audio/mcp_server.py +228 -75
  42. scitex/canvas/README.md +1 -1
  43. scitex/canvas/editor/_dearpygui/__init__.py +25 -0
  44. scitex/canvas/editor/_dearpygui/_editor.py +147 -0
  45. scitex/canvas/editor/_dearpygui/_handlers.py +476 -0
  46. scitex/canvas/editor/_dearpygui/_panels/__init__.py +17 -0
  47. scitex/canvas/editor/_dearpygui/_panels/_control.py +119 -0
  48. scitex/canvas/editor/_dearpygui/_panels/_element_controls.py +190 -0
  49. scitex/canvas/editor/_dearpygui/_panels/_preview.py +43 -0
  50. scitex/canvas/editor/_dearpygui/_panels/_sections.py +390 -0
  51. scitex/canvas/editor/_dearpygui/_plotting.py +187 -0
  52. scitex/canvas/editor/_dearpygui/_rendering.py +504 -0
  53. scitex/canvas/editor/_dearpygui/_selection.py +295 -0
  54. scitex/canvas/editor/_dearpygui/_state.py +93 -0
  55. scitex/canvas/editor/_dearpygui/_utils.py +61 -0
  56. scitex/canvas/editor/flask_editor/_core/__init__.py +27 -0
  57. scitex/canvas/editor/flask_editor/_core/_bbox_extraction.py +200 -0
  58. scitex/canvas/editor/flask_editor/_core/_editor.py +173 -0
  59. scitex/canvas/editor/flask_editor/_core/_export_helpers.py +353 -0
  60. scitex/canvas/editor/flask_editor/_core/_routes_basic.py +190 -0
  61. scitex/canvas/editor/flask_editor/_core/_routes_export.py +332 -0
  62. scitex/canvas/editor/flask_editor/_core/_routes_panels.py +252 -0
  63. scitex/canvas/editor/flask_editor/_core/_routes_save.py +218 -0
  64. scitex/canvas/editor/flask_editor/_core.py +25 -1684
  65. scitex/canvas/editor/flask_editor/templates/__init__.py +32 -70
  66. scitex/cli/__init__.py +38 -43
  67. scitex/cli/audio.py +160 -41
  68. scitex/cli/capture.py +133 -20
  69. scitex/cli/introspect.py +488 -0
  70. scitex/cli/main.py +200 -109
  71. scitex/cli/mcp.py +60 -34
  72. scitex/cli/plt.py +414 -0
  73. scitex/cli/repro.py +15 -8
  74. scitex/cli/resource.py +15 -8
  75. scitex/cli/scholar/__init__.py +154 -8
  76. scitex/cli/scholar/_crossref_scitex.py +296 -0
  77. scitex/cli/scholar/_fetch.py +25 -3
  78. scitex/cli/social.py +355 -0
  79. scitex/cli/stats.py +136 -11
  80. scitex/cli/template.py +129 -12
  81. scitex/cli/tex.py +15 -8
  82. scitex/cli/writer.py +49 -299
  83. scitex/cloud/__init__.py +41 -2
  84. scitex/config/README.md +1 -1
  85. scitex/config/__init__.py +16 -2
  86. scitex/config/_env_registry.py +256 -0
  87. scitex/context/__init__.py +22 -0
  88. scitex/dev/__init__.py +20 -1
  89. scitex/diagram/__init__.py +42 -19
  90. scitex/diagram/mcp_server.py +13 -125
  91. scitex/gen/__init__.py +50 -14
  92. scitex/gen/_list_packages.py +4 -4
  93. scitex/introspect/__init__.py +82 -0
  94. scitex/introspect/_call_graph.py +303 -0
  95. scitex/introspect/_class_hierarchy.py +163 -0
  96. scitex/introspect/_core.py +41 -0
  97. scitex/introspect/_docstring.py +131 -0
  98. scitex/introspect/_examples.py +113 -0
  99. scitex/introspect/_imports.py +271 -0
  100. scitex/{gen/_inspect_module.py → introspect/_list_api.py} +48 -56
  101. scitex/introspect/_mcp/__init__.py +41 -0
  102. scitex/introspect/_mcp/handlers.py +233 -0
  103. scitex/introspect/_members.py +155 -0
  104. scitex/introspect/_resolve.py +89 -0
  105. scitex/introspect/_signature.py +131 -0
  106. scitex/introspect/_source.py +80 -0
  107. scitex/introspect/_type_hints.py +172 -0
  108. scitex/io/_save.py +1 -2
  109. scitex/io/bundle/README.md +1 -1
  110. scitex/logging/_formatters.py +19 -9
  111. scitex/mcp_server.py +98 -5
  112. scitex/os/__init__.py +4 -0
  113. scitex/{gen → os}/_check_host.py +4 -5
  114. scitex/plt/__init__.py +245 -550
  115. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py +5 -10
  116. scitex/plt/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
  117. scitex/plt/gallery/README.md +1 -1
  118. scitex/plt/utils/_hitmap/__init__.py +82 -0
  119. scitex/plt/utils/_hitmap/_artist_extraction.py +343 -0
  120. scitex/plt/utils/_hitmap/_color_application.py +346 -0
  121. scitex/plt/utils/_hitmap/_color_conversion.py +121 -0
  122. scitex/plt/utils/_hitmap/_constants.py +40 -0
  123. scitex/plt/utils/_hitmap/_hitmap_core.py +334 -0
  124. scitex/plt/utils/_hitmap/_path_extraction.py +357 -0
  125. scitex/plt/utils/_hitmap/_query.py +113 -0
  126. scitex/plt/utils/_hitmap.py +46 -1616
  127. scitex/plt/utils/_metadata/__init__.py +80 -0
  128. scitex/plt/utils/_metadata/_artists/__init__.py +25 -0
  129. scitex/plt/utils/_metadata/_artists/_base.py +195 -0
  130. scitex/plt/utils/_metadata/_artists/_collections.py +356 -0
  131. scitex/plt/utils/_metadata/_artists/_extract.py +57 -0
  132. scitex/plt/utils/_metadata/_artists/_images.py +80 -0
  133. scitex/plt/utils/_metadata/_artists/_lines.py +261 -0
  134. scitex/plt/utils/_metadata/_artists/_patches.py +247 -0
  135. scitex/plt/utils/_metadata/_artists/_text.py +106 -0
  136. scitex/plt/utils/_metadata/_csv.py +416 -0
  137. scitex/plt/utils/_metadata/_detect.py +225 -0
  138. scitex/plt/utils/_metadata/_legend.py +127 -0
  139. scitex/plt/utils/_metadata/_rounding.py +117 -0
  140. scitex/plt/utils/_metadata/_verification.py +202 -0
  141. scitex/schema/README.md +1 -1
  142. scitex/scholar/__init__.py +8 -0
  143. scitex/scholar/_mcp/crossref_handlers.py +265 -0
  144. scitex/scholar/core/Scholar.py +63 -1700
  145. scitex/scholar/core/_mixins/__init__.py +36 -0
  146. scitex/scholar/core/_mixins/_enrichers.py +270 -0
  147. scitex/scholar/core/_mixins/_library_handlers.py +100 -0
  148. scitex/scholar/core/_mixins/_loaders.py +103 -0
  149. scitex/scholar/core/_mixins/_pdf_download.py +375 -0
  150. scitex/scholar/core/_mixins/_pipeline.py +312 -0
  151. scitex/scholar/core/_mixins/_project_handlers.py +125 -0
  152. scitex/scholar/core/_mixins/_savers.py +69 -0
  153. scitex/scholar/core/_mixins/_search.py +103 -0
  154. scitex/scholar/core/_mixins/_services.py +88 -0
  155. scitex/scholar/core/_mixins/_url_finding.py +105 -0
  156. scitex/scholar/crossref_scitex.py +367 -0
  157. scitex/scholar/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
  158. scitex/scholar/examples/00_run_all.sh +120 -0
  159. scitex/scholar/jobs/_executors.py +27 -3
  160. scitex/scholar/pdf_download/ScholarPDFDownloader.py +38 -416
  161. scitex/scholar/pdf_download/_cli.py +154 -0
  162. scitex/scholar/pdf_download/strategies/__init__.py +11 -8
  163. scitex/scholar/pdf_download/strategies/manual_download_fallback.py +80 -3
  164. scitex/scholar/pipelines/ScholarPipelineBibTeX.py +73 -121
  165. scitex/scholar/pipelines/ScholarPipelineParallel.py +80 -138
  166. scitex/scholar/pipelines/ScholarPipelineSingle.py +43 -63
  167. scitex/scholar/pipelines/_single_steps.py +71 -36
  168. scitex/scholar/storage/_LibraryManager.py +97 -1695
  169. scitex/scholar/storage/_mixins/__init__.py +30 -0
  170. scitex/scholar/storage/_mixins/_bibtex_handlers.py +128 -0
  171. scitex/scholar/storage/_mixins/_library_operations.py +218 -0
  172. scitex/scholar/storage/_mixins/_metadata_conversion.py +226 -0
  173. scitex/scholar/storage/_mixins/_paper_saving.py +456 -0
  174. scitex/scholar/storage/_mixins/_resolution.py +376 -0
  175. scitex/scholar/storage/_mixins/_storage_helpers.py +121 -0
  176. scitex/scholar/storage/_mixins/_symlink_handlers.py +226 -0
  177. scitex/security/README.md +3 -3
  178. scitex/session/README.md +1 -1
  179. scitex/session/__init__.py +26 -7
  180. scitex/session/_decorator.py +1 -1
  181. scitex/sh/README.md +1 -1
  182. scitex/sh/__init__.py +7 -4
  183. scitex/social/__init__.py +155 -0
  184. scitex/social/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
  185. scitex/stats/_mcp/_handlers/__init__.py +31 -0
  186. scitex/stats/_mcp/_handlers/_corrections.py +113 -0
  187. scitex/stats/_mcp/_handlers/_descriptive.py +78 -0
  188. scitex/stats/_mcp/_handlers/_effect_size.py +106 -0
  189. scitex/stats/_mcp/_handlers/_format.py +94 -0
  190. scitex/stats/_mcp/_handlers/_normality.py +110 -0
  191. scitex/stats/_mcp/_handlers/_posthoc.py +224 -0
  192. scitex/stats/_mcp/_handlers/_power.py +247 -0
  193. scitex/stats/_mcp/_handlers/_recommend.py +102 -0
  194. scitex/stats/_mcp/_handlers/_run_test.py +279 -0
  195. scitex/stats/_mcp/_handlers/_stars.py +48 -0
  196. scitex/stats/_mcp/handlers.py +19 -1171
  197. scitex/stats/auto/_stat_style.py +175 -0
  198. scitex/stats/auto/_style_definitions.py +411 -0
  199. scitex/stats/auto/_styles.py +22 -620
  200. scitex/stats/descriptive/__init__.py +11 -8
  201. scitex/stats/descriptive/_ci.py +39 -0
  202. scitex/stats/power/_power.py +15 -4
  203. scitex/str/__init__.py +2 -1
  204. scitex/str/_title_case.py +63 -0
  205. scitex/template/README.md +1 -1
  206. scitex/template/__init__.py +25 -10
  207. scitex/template/_code_templates.py +147 -0
  208. scitex/template/_mcp/handlers.py +81 -0
  209. scitex/template/_mcp/tool_schemas.py +55 -0
  210. scitex/template/_templates/__init__.py +51 -0
  211. scitex/template/_templates/audio.py +233 -0
  212. scitex/template/_templates/canvas.py +312 -0
  213. scitex/template/_templates/capture.py +268 -0
  214. scitex/template/_templates/config.py +43 -0
  215. scitex/template/_templates/diagram.py +294 -0
  216. scitex/template/_templates/io.py +107 -0
  217. scitex/template/_templates/module.py +53 -0
  218. scitex/template/_templates/plt.py +202 -0
  219. scitex/template/_templates/scholar.py +267 -0
  220. scitex/template/_templates/session.py +130 -0
  221. scitex/template/_templates/session_minimal.py +43 -0
  222. scitex/template/_templates/session_plot.py +67 -0
  223. scitex/template/_templates/session_stats.py +77 -0
  224. scitex/template/_templates/stats.py +323 -0
  225. scitex/template/_templates/writer.py +296 -0
  226. scitex/template/clone_writer_directory.py +5 -5
  227. scitex/ui/_backends/_email.py +10 -2
  228. scitex/ui/_backends/_webhook.py +5 -1
  229. scitex/web/_search_pubmed.py +10 -6
  230. scitex/writer/README.md +1 -1
  231. scitex/writer/__init__.py +43 -34
  232. scitex/writer/_mcp/handlers.py +11 -744
  233. scitex/writer/_mcp/tool_schemas.py +5 -335
  234. scitex-2.15.3.dist-info/METADATA +667 -0
  235. {scitex-2.14.0.dist-info → scitex-2.15.3.dist-info}/RECORD +241 -120
  236. scitex/canvas/editor/flask_editor/templates/_scripts.py +0 -4933
  237. scitex/canvas/editor/flask_editor/templates/_styles.py +0 -1658
  238. scitex/diagram/_compile.py +0 -312
  239. scitex/diagram/_diagram.py +0 -355
  240. scitex/diagram/_mcp/__init__.py +0 -4
  241. scitex/diagram/_mcp/handlers.py +0 -400
  242. scitex/diagram/_mcp/tool_schemas.py +0 -157
  243. scitex/diagram/_presets.py +0 -173
  244. scitex/diagram/_schema.py +0 -182
  245. scitex/diagram/_split.py +0 -278
  246. scitex/gen/_ci.py +0 -12
  247. scitex/gen/_title_case.py +0 -89
  248. scitex/plt/_mcp/__init__.py +0 -4
  249. scitex/plt/_mcp/_handlers_annotation.py +0 -102
  250. scitex/plt/_mcp/_handlers_figure.py +0 -195
  251. scitex/plt/_mcp/_handlers_plot.py +0 -252
  252. scitex/plt/_mcp/_handlers_style.py +0 -219
  253. scitex/plt/_mcp/handlers.py +0 -74
  254. scitex/plt/_mcp/tool_schemas.py +0 -497
  255. scitex/plt/mcp_server.py +0 -231
  256. scitex/scholar/examples/SUGGESTIONS.md +0 -865
  257. scitex/scholar/examples/dev.py +0 -38
  258. scitex-2.14.0.dist-info/METADATA +0 -1238
  259. /scitex/{gen → context}/_detect_environment.py +0 -0
  260. /scitex/{gen → context}/_get_notebook_path.py +0 -0
  261. /scitex/{gen/_shell.py → sh/_shell_legacy.py} +0 -0
  262. {scitex-2.14.0.dist-info → scitex-2.15.3.dist-info}/WHEEL +0 -0
  263. {scitex-2.14.0.dist-info → scitex-2.15.3.dist-info}/entry_points.txt +0 -0
  264. {scitex-2.14.0.dist-info → scitex-2.15.3.dist-info}/licenses/LICENSE +0 -0
@@ -1,120 +1,58 @@
1
1
  #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # Timestamp: "2025-10-04 04:38:23 (ywatanabe)"
4
- # File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/ml/classification/reporters/_SingleClassificationReporter.py
5
- # ----------------------------------------
6
- from __future__ import annotations
7
- import os
8
-
9
- __FILE__ = "./src/scitex/ml/classification/reporters/_SingleClassificationReporter.py"
10
- __DIR__ = os.path.dirname(__FILE__)
11
- # ----------------------------------------
12
-
13
- __FILE__ = __file__
14
-
15
- from pprint import pprint
2
+ # Timestamp: "2026-01-24 (ywatanabe)"
3
+ # File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_SingleClassificationReporter.py
16
4
 
17
5
  """
18
6
  Improved Single Classification Reporter with unified API.
19
7
 
20
- Enhanced version that addresses all identified issues:
21
- - Unified API interface
22
- - Lazy directory creation
23
- - Numerical precision control
24
- - Graceful plotting with error handling
25
- - Consistent parameter names
8
+ This module provides a comprehensive classification reporter that:
9
+ - Uses unified API interface
10
+ - Supports lazy directory creation
11
+ - Provides numerical precision control
12
+ - Creates visualizations with proper error handling
13
+ - Maintains consistent parameter naming
14
+
15
+ The main class inherits from multiple mixins for modular functionality:
16
+ - MetricsMixin: Metrics calculation and aggregation
17
+ - StorageMixin: File storage and organization
18
+ - PlottingMixin: Visualization generation
19
+ - FeatureImportanceMixin: Feature importance analysis
20
+ - CVSummaryMixin: Cross-validation summary generation
21
+ - ReportsMixin: Multi-format report generation
26
22
  """
27
23
 
24
+ from __future__ import annotations
25
+
28
26
  from pathlib import Path
27
+ from pprint import pprint
29
28
  from typing import Any, Dict, List, Optional, Union
30
29
 
31
- import numpy as np
32
- import pandas as pd
33
30
  from scitex.logging import getLogger
34
31
 
35
- # Import base class and utilities
36
32
  from ._BaseClassificationReporter import BaseClassificationReporter, ReporterConfig
37
-
38
- # Import original metric calculation functions (these are good)
39
- from .reporter_utils import (
40
- calc_bacc,
41
- calc_clf_report,
42
- calc_conf_mat,
43
- calc_mcc,
44
- calc_pre_rec_auc,
45
- calc_roc_auc,
33
+ from ._mixins import (
34
+ CVSummaryMixin,
35
+ FeatureImportanceMixin,
36
+ MetricsMixin,
37
+ PlottingMixin,
38
+ ReportsMixin,
39
+ StorageMixin,
46
40
  )
47
41
  from .reporter_utils._Plotter import Plotter
48
- from .reporter_utils.reporting import (
49
- create_summary_statistics,
50
- generate_latex_report,
51
- generate_markdown_report,
52
- generate_org_report,
53
- )
54
- from .reporter_utils.storage import MetricStorage, save_metric
42
+ from .reporter_utils.storage import MetricStorage
55
43
 
56
44
  logger = getLogger(__name__)
57
45
 
58
46
 
59
- # Fold directory and filename prefixes for consistent naming
60
- FOLD_DIR_PREFIX_PATTERN = "fold_{fold:02d}" # Directory: fold_00, fold_01, ...
61
- FOLD_FILE_PREFIX_PATTERN = "fold-{fold:02d}" # Filename prefix: fold-00_, fold-01_, ...
62
-
63
- # Filename patterns for consistent naming across the reporter
64
- # Note: fold-{fold:02d} comes first to group files by fold when sorted
65
- # Convention: hyphens within chunks, underscores between chunks
66
- FILENAME_PATTERNS = {
67
- # Individual fold metrics (with metric value in filename)
68
- "fold_metric_with_value": f"{FOLD_FILE_PREFIX_PATTERN}_{{metric_name}}-{{value:.3f}}.json",
69
- "fold_metric": f"{FOLD_FILE_PREFIX_PATTERN}_{{metric_name}}.json",
70
- # Confusion matrix
71
- "confusion_matrix_csv": f"{FOLD_FILE_PREFIX_PATTERN}_confusion-matrix_bacc-{{bacc:.3f}}.csv",
72
- "confusion_matrix_csv_no_bacc": f"{FOLD_FILE_PREFIX_PATTERN}_confusion-matrix.csv",
73
- "confusion_matrix_jpg": f"{FOLD_FILE_PREFIX_PATTERN}_confusion-matrix_bacc-{{bacc:.3f}}.jpg",
74
- "confusion_matrix_jpg_no_bacc": f"{FOLD_FILE_PREFIX_PATTERN}_confusion-matrix.jpg",
75
- # Classification report
76
- "classification_report": f"{FOLD_FILE_PREFIX_PATTERN}_classification-report.csv",
77
- # ROC curve
78
- "roc_curve_csv": f"{FOLD_FILE_PREFIX_PATTERN}_roc-curve_auc-{{auc:.3f}}.csv",
79
- "roc_curve_csv_no_auc": f"{FOLD_FILE_PREFIX_PATTERN}_roc-curve.csv",
80
- "roc_curve_jpg": f"{FOLD_FILE_PREFIX_PATTERN}_roc-curve_auc-{{auc:.3f}}.jpg",
81
- "roc_curve_jpg_no_auc": f"{FOLD_FILE_PREFIX_PATTERN}_roc-curve.jpg",
82
- # PR curve
83
- "pr_curve_csv": f"{FOLD_FILE_PREFIX_PATTERN}_pr-curve_ap-{{ap:.3f}}.csv",
84
- "pr_curve_csv_no_ap": f"{FOLD_FILE_PREFIX_PATTERN}_pr-curve.csv",
85
- "pr_curve_jpg": f"{FOLD_FILE_PREFIX_PATTERN}_pr-curve_ap-{{ap:.3f}}.jpg",
86
- "pr_curve_jpg_no_ap": f"{FOLD_FILE_PREFIX_PATTERN}_pr-curve.jpg",
87
- # Raw prediction data (optional, enabled via calculate_metrics parameters)
88
- "y_true": f"{FOLD_FILE_PREFIX_PATTERN}_y-true.csv",
89
- "y_pred": f"{FOLD_FILE_PREFIX_PATTERN}_y-pred.csv",
90
- "y_proba": f"{FOLD_FILE_PREFIX_PATTERN}_y-proba.csv",
91
- # Metrics dashboard
92
- "metrics_summary": f"{FOLD_FILE_PREFIX_PATTERN}_metrics-summary.jpg",
93
- # Feature importance
94
- "feature_importance_json": f"{FOLD_FILE_PREFIX_PATTERN}_feature-importance.json",
95
- "feature_importance_jpg": f"{FOLD_FILE_PREFIX_PATTERN}_feature-importance.jpg",
96
- # Classification report edge cases (when CSV conversion fails)
97
- "classification_report_json": f"{FOLD_FILE_PREFIX_PATTERN}_classification-report.json",
98
- "classification_report_txt": f"{FOLD_FILE_PREFIX_PATTERN}_classification-report.txt",
99
- # Folds all (CV summary)
100
- "cv_summary_metric": "cv-summary_{metric_name}_mean-{mean:.3f}_std-{std:.3f}_n-{n_folds}.json",
101
- "cv_summary_confusion_matrix_csv": "cv-summary_confusion-matrix_bacc-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
102
- "cv_summary_confusion_matrix_jpg": "cv-summary_confusion-matrix_bacc-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
103
- "cv_summary_classification_report": "cv-summary_classification-report_n-{n_folds}.csv",
104
- "cv_summary_roc_curve_csv": "cv-summary_roc-curve_auc-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
105
- "cv_summary_roc_curve_jpg": "cv-summary_roc-curve_auc-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
106
- "cv_summary_pr_curve_csv": "cv-summary_pr-curve_ap-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
107
- "cv_summary_pr_curve_jpg": "cv-summary_pr-curve_ap-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
108
- "cv_summary_feature_importance_json": "cv-summary_feature-importance_n-{n_folds}.json",
109
- "cv_summary_feature_importance_jpg": "cv-summary_feature-importance_n-{n_folds}.jpg",
110
- "cv_summary_summary": "cv-summary_summary.json",
111
- # Folds all edge cases (when balanced_acc is None)
112
- "cv_summary_confusion_matrix_csv_no_bacc": "cv-summary_confusion-matrix_n-{n_folds}.csv",
113
- "cv_summary_confusion_matrix_jpg_no_bacc": "cv-summary_confusion-matrix_n-{n_folds}.jpg",
114
- }
115
-
116
-
117
- class SingleTaskClassificationReporter(BaseClassificationReporter):
47
+ class SingleTaskClassificationReporter(
48
+ MetricsMixin,
49
+ StorageMixin,
50
+ PlottingMixin,
51
+ FeatureImportanceMixin,
52
+ CVSummaryMixin,
53
+ ReportsMixin,
54
+ BaseClassificationReporter,
55
+ ):
118
56
  """
119
57
  Improved single-task classification reporter with unified API.
120
58
 
@@ -142,6 +80,8 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
142
80
  Base directory for outputs. If None, creates timestamped directory.
143
81
  config : ReporterConfig, optional
144
82
  Configuration object for advanced settings
83
+ verbose : bool, default True
84
+ Print initialization message
145
85
  **kwargs
146
86
  Additional arguments passed to base class
147
87
 
@@ -174,11 +114,9 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
174
114
  verbose: bool = True,
175
115
  **kwargs,
176
116
  ):
177
- # Use config or create default
178
117
  if config is None:
179
118
  config = ReporterConfig()
180
119
 
181
- # Initialize base class with config settings
182
120
  super().__init__(
183
121
  output_dir=output_dir,
184
122
  precision=config.precision,
@@ -186,14 +124,11 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
186
124
  )
187
125
 
188
126
  self.config = config
189
- self.session_config = None # Will store SciTeX session CONFIG if provided
127
+ self.session_config = None
190
128
  self.storage = MetricStorage(self.output_dir, precision=self.precision)
191
129
  self.plotter = Plotter(enable_plotting=True)
192
130
 
193
- # Track calculated metrics for summary
194
131
  self.fold_metrics: Dict[int, Dict[str, Any]] = {}
195
-
196
- # Store all predictions for overall curves
197
132
  self.all_predictions: List[Dict[str, Any]] = []
198
133
 
199
134
  if verbose:
@@ -212,1252 +147,6 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
212
147
  """
213
148
  self.session_config = config
214
149
 
215
- def calculate_metrics(
216
- self,
217
- y_true: np.ndarray,
218
- y_pred: np.ndarray,
219
- y_proba: Optional[np.ndarray] = None,
220
- labels: Optional[List[str]] = None,
221
- fold: Optional[int] = None,
222
- verbose=True,
223
- store_y_true: bool = True,
224
- store_y_pred: bool = True,
225
- store_y_proba: bool = True,
226
- model=None,
227
- feature_names: Optional[List[str]] = None,
228
- ) -> Dict[str, Any]:
229
- """
230
- Calculate and save classification metrics using unified API.
231
-
232
- Parameters
233
- ----------
234
- y_true : np.ndarray
235
- True class labels
236
- y_pred : np.ndarray
237
- Predicted class labels
238
- y_proba : np.ndarray, optional
239
- Prediction probabilities (required for AUC metrics)
240
- labels : List[str], optional
241
- Class labels for display
242
- fold : int, optional
243
- Fold index for cross-validation
244
- verbose : bool, default True
245
- Print progress messages
246
- store_y_true : bool, default True
247
- Save y_true as CSV with sample_index and fold columns
248
- store_y_pred : bool, default True
249
- Save y_pred as CSV with sample_index and fold columns
250
- store_y_proba : bool, default True
251
- Save y_proba as CSV with sample_index and fold columns
252
- model : object, optional
253
- Trained model for feature importance extraction
254
- feature_names : List[str], optional
255
- Feature names for feature importance (required if model is provided)
256
-
257
- Returns
258
- -------
259
- Dict[str, Any]
260
- Dictionary of calculated metrics
261
- """
262
-
263
- if verbose:
264
- if fold:
265
- print()
266
- logger.info(f"Calculating metrics for fold #{fold:02d}...")
267
- else:
268
- logger.info(f"Calculating metrics...")
269
-
270
- # Validate inputs
271
- if len(y_true) != len(y_pred):
272
- raise ValueError("y_true and y_pred must have same length")
273
-
274
- if y_proba is not None and len(y_true) != len(y_proba):
275
- raise ValueError("y_true and y_proba must have same length")
276
-
277
- # Set default fold index
278
- if fold is None:
279
- fold = 0
280
-
281
- # Set default labels if not provided
282
- if labels is None:
283
- unique_labels = sorted(np.unique(np.concatenate([y_true, y_pred])))
284
- labels = [f"Class_{i}" for i in unique_labels]
285
-
286
- # Calculate all metrics
287
- metrics = {}
288
-
289
- # Core metrics (always calculated) - pass fold to all metrics
290
- metrics["balanced-accuracy"] = calc_bacc(y_true, y_pred, fold=fold)
291
- metrics["mcc"] = calc_mcc(y_true, y_pred, fold=fold)
292
-
293
- metrics["confusion_matrix"] = calc_conf_mat(
294
- y_true=y_true, y_pred=y_pred, labels=labels, fold=fold
295
- )
296
- metrics["classification_report"] = calc_clf_report(
297
- y_true, y_pred, labels, fold=fold
298
- )
299
-
300
- # AUC metrics (only if probabilities available)
301
- if y_proba is not None:
302
- try:
303
- from scitex.ai.metrics import calc_pre_rec_auc, calc_roc_auc
304
-
305
- metrics["roc-auc"] = calc_roc_auc(
306
- y_true,
307
- y_proba,
308
- labels=labels,
309
- fold=fold,
310
- return_curve=False,
311
- )
312
- metrics["pr-auc"] = calc_pre_rec_auc(
313
- y_true,
314
- y_proba,
315
- labels=labels,
316
- fold=fold,
317
- return_curve=False,
318
- )
319
- except Exception as e:
320
- logger.warning(f"Could not calculate AUC metrics: {e}")
321
-
322
- # Round all numerical values
323
- metrics = self._round_numeric(metrics)
324
-
325
- # Add labels to metrics for later use
326
- metrics["labels"] = labels
327
-
328
- if verbose:
329
- logger.info(f"Metrics calculated:")
330
- pprint(metrics)
331
-
332
- # Store metrics for summary
333
- self.fold_metrics[fold] = metrics.copy()
334
-
335
- # Store predictions for overall curves
336
- if y_proba is not None:
337
- self.all_predictions.append(
338
- {
339
- "fold": fold,
340
- "y_true": y_true.copy(),
341
- "y_proba": y_proba.copy(),
342
- }
343
- )
344
-
345
- # Save metrics if requested
346
- self._save_fold_metrics(metrics, fold, labels)
347
-
348
- # Generate plots if requested
349
- self._create_plots(y_true, y_pred, y_proba, labels, fold, metrics)
350
-
351
- # Handle feature importance automatically if model provided
352
- if model is not None and feature_names is not None:
353
- try:
354
- from scitex.ai.feature_selection import extract_feature_importance
355
-
356
- importance_dict = extract_feature_importance(
357
- model, feature_names, method="auto"
358
- )
359
- if importance_dict:
360
- # Store in fold metrics for cross-fold aggregation
361
- metrics["feature-importance"] = importance_dict
362
- self.fold_metrics[fold]["feature-importance"] = importance_dict
363
-
364
- # Save feature importance
365
- fold_dir = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
366
- filename = FILENAME_PATTERNS["feature_importance_json"].format(
367
- fold=fold
368
- )
369
- self.storage.save(importance_dict, f"{fold_dir}/{filename}")
370
-
371
- if verbose:
372
- logger.info(f" Feature importance extracted and saved")
373
- except Exception as e:
374
- logger.warning(f"Could not extract feature importance: {e}")
375
-
376
- # Save raw predictions if requested (as CSV using DataFrames)
377
- # Include sample_index for easy concatenation across folds
378
- if store_y_true or store_y_pred or store_y_proba:
379
- fold_dir = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
380
- sample_indices = np.arange(len(y_true))
381
-
382
- # Warn if file size will be large (>10MB estimated)
383
- n_samples = len(y_true)
384
- estimated_size_mb = 0
385
- if store_y_true:
386
- estimated_size_mb += n_samples * 0.0001 # ~100 bytes per row
387
- if store_y_pred:
388
- estimated_size_mb += n_samples * 0.0001
389
- if store_y_proba and y_proba is not None:
390
- n_classes = 1 if y_proba.ndim == 1 else y_proba.shape[1]
391
- estimated_size_mb += n_samples * n_classes * 0.0001
392
-
393
- if estimated_size_mb > 10:
394
- logger.warning(
395
- f"Storing raw predictions for fold {fold} will create ~{estimated_size_mb:.1f}MB of CSV files. "
396
- f"Set store_y_true/store_y_pred/store_y_proba=False to disable."
397
- )
398
-
399
- if store_y_true:
400
- filename = FILENAME_PATTERNS["y_true"].format(fold=fold)
401
- # Convert to DataFrame for CSV format with index
402
- df_y_true = pd.DataFrame(
403
- {"sample_index": sample_indices, "fold": fold, "y_true": y_true}
404
- )
405
- self.storage.save(df_y_true, f"{fold_dir}/{filename}")
406
-
407
- if store_y_pred:
408
- filename = FILENAME_PATTERNS["y_pred"].format(fold=fold)
409
- # Convert to DataFrame for CSV format with index
410
- df_y_pred = pd.DataFrame(
411
- {"sample_index": sample_indices, "fold": fold, "y_pred": y_pred}
412
- )
413
- self.storage.save(df_y_pred, f"{fold_dir}/{filename}")
414
-
415
- if store_y_proba and y_proba is not None:
416
- filename = FILENAME_PATTERNS["y_proba"].format(fold=fold)
417
- # Convert to DataFrame for CSV format with index
418
- # Handle both 1D (binary) and 2D (multiclass) probability arrays
419
- if y_proba.ndim == 1:
420
- df_y_proba = pd.DataFrame(
421
- {
422
- "sample_index": sample_indices,
423
- "fold": fold,
424
- "y_proba": y_proba,
425
- }
426
- )
427
- else:
428
- # Create column names for each class
429
- data = {"sample_index": sample_indices, "fold": fold}
430
- for i in range(y_proba.shape[1]):
431
- data[f"proba_class_{i}"] = y_proba[:, i]
432
- df_y_proba = pd.DataFrame(data)
433
- self.storage.save(df_y_proba, f"{fold_dir}/{filename}")
434
-
435
- return metrics
436
-
437
- def _save_fold_metrics(
438
- self, metrics: Dict[str, Any], fold: int, labels: List[str]
439
- ) -> None:
440
- """Save metrics for a specific fold in shallow directory structure."""
441
- fold_dir = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
442
-
443
- # Extract metric values for filenames
444
- balanced_acc = self._extract_metric_value(metrics.get("balanced-accuracy"))
445
- mcc_value = self._extract_metric_value(metrics.get("mcc"))
446
- roc_auc_value = self._extract_metric_value(metrics.get("roc-auc"))
447
- pr_auc_value = self._extract_metric_value(metrics.get("pr-auc"))
448
-
449
- # Save individual metrics with values in filenames
450
- for metric_name, metric_value in metrics.items():
451
- # Extract actual value if it's a wrapped dict with 'value' key
452
- if isinstance(metric_value, dict) and "value" in metric_value:
453
- actual_value = metric_value["value"]
454
- else:
455
- actual_value = metric_value
456
-
457
- if metric_name == "confusion_matrix":
458
- # Save confusion matrix as CSV with proper formatting
459
- try:
460
- # actual_value is already a DataFrame from calc_conf_mat
461
- # Just rename the index and columns
462
- if isinstance(actual_value, pd.DataFrame):
463
- cm_df = actual_value.copy()
464
- cm_df.index = [f"True_{label}" for label in labels]
465
- cm_df.columns = [f"Pred_{label}" for label in labels]
466
- else:
467
- # Fallback for numpy array
468
- cm_df = pd.DataFrame(
469
- actual_value,
470
- index=[f"True_{label}" for label in labels],
471
- columns=[f"Pred_{label}" for label in labels],
472
- )
473
- except Exception as e:
474
- logger.error(f"Error formatting confusion matrix: {e}")
475
- cm_df = None
476
-
477
- # Save if cm_df was created successfully
478
- if cm_df is not None:
479
- # Create filename with balanced accuracy
480
- if balanced_acc is not None:
481
- cm_filename = FILENAME_PATTERNS["confusion_matrix_csv"].format(
482
- fold=fold, bacc=balanced_acc
483
- )
484
- else:
485
- cm_filename = FILENAME_PATTERNS[
486
- "confusion_matrix_csv_no_bacc"
487
- ].format(fold=fold)
488
-
489
- # Save with index=True to preserve row labels
490
- self.storage.save(cm_df, f"{fold_dir}/{cm_filename}", index=True)
491
-
492
- elif metric_name == "classification_report":
493
- # Save classification report with consistent naming
494
- report_filename = FILENAME_PATTERNS["classification_report"].format(
495
- fold=fold
496
- )
497
- if isinstance(actual_value, pd.DataFrame):
498
- # Reset index to make it an ordinary column with name
499
- report_df = actual_value.reset_index()
500
- report_df = report_df.rename(columns={"index": "class"})
501
- self.storage.save(report_df, f"{fold_dir}/{report_filename}")
502
- elif isinstance(actual_value, dict):
503
- # Try to create DataFrame from dict
504
- try:
505
- report_df = pd.DataFrame(actual_value).transpose()
506
- self.storage.save(report_df, f"{fold_dir}/{report_filename}")
507
- except:
508
- # Save as JSON if DataFrame conversion fails
509
- report_filename = FILENAME_PATTERNS[
510
- "classification_report_json"
511
- ].format(fold=fold)
512
- self.storage.save(
513
- actual_value,
514
- f"{fold_dir}/{report_filename}",
515
- )
516
- else:
517
- # String or other format
518
- report_filename = FILENAME_PATTERNS[
519
- "classification_report_txt"
520
- ].format(fold=fold)
521
- self.storage.save(actual_value, f"{fold_dir}/{report_filename}")
522
-
523
- elif metric_name == "balanced-accuracy" and balanced_acc is not None:
524
- # Save with value in filename
525
- filename = FILENAME_PATTERNS["fold_metric_with_value"].format(
526
- fold=fold,
527
- metric_name="balanced-accuracy",
528
- value=balanced_acc,
529
- )
530
- save_metric(
531
- actual_value,
532
- self.output_dir / f"{fold_dir}/{filename}",
533
- fold=fold,
534
- precision=self.precision,
535
- )
536
- elif metric_name == "mcc" and mcc_value is not None:
537
- # Save with value in filename
538
- filename = FILENAME_PATTERNS["fold_metric_with_value"].format(
539
- fold=fold, metric_name="mcc", value=mcc_value
540
- )
541
- save_metric(
542
- actual_value,
543
- self.output_dir / f"{fold_dir}/{filename}",
544
- fold=fold,
545
- precision=self.precision,
546
- )
547
- elif metric_name == "roc-auc" and roc_auc_value is not None:
548
- # Save with value in filename
549
- filename = FILENAME_PATTERNS["fold_metric_with_value"].format(
550
- fold=fold, metric_name="roc-auc", value=roc_auc_value
551
- )
552
- save_metric(
553
- actual_value,
554
- self.output_dir / f"{fold_dir}/{filename}",
555
- fold=fold,
556
- precision=self.precision,
557
- )
558
- elif metric_name == "pr-auc" and pr_auc_value is not None:
559
- # Save with value in filename
560
- filename = FILENAME_PATTERNS["fold_metric_with_value"].format(
561
- fold=fold, metric_name="pr-auc", value=pr_auc_value
562
- )
563
- save_metric(
564
- actual_value,
565
- self.output_dir / f"{fold_dir}/{filename}",
566
- fold=fold,
567
- precision=self.precision,
568
- )
569
-
570
- def _extract_metric_value(self, metric_data: Any) -> Optional[float]:
571
- """Extract numeric value from metric data."""
572
- if metric_data is None:
573
- return None
574
- if isinstance(metric_data, dict) and "value" in metric_data:
575
- return float(metric_data["value"])
576
- if isinstance(metric_data, (int, float, np.number)):
577
- return float(metric_data)
578
- return None
579
-
580
- def _save_curve_data(
581
- self,
582
- y_true: np.ndarray,
583
- y_proba: Optional[np.ndarray],
584
- fold: int,
585
- metrics: Dict[str, Any],
586
- ) -> None:
587
- """Save ROC and PR curve data as CSV files with metric values in filenames."""
588
- if y_proba is None:
589
- return
590
-
591
- from sklearn.metrics import (
592
- auc,
593
- average_precision_score,
594
- precision_recall_curve,
595
- roc_curve,
596
- )
597
-
598
- fold_dir = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
599
-
600
- # Handle binary vs multiclass
601
- if y_proba.ndim == 1 or y_proba.shape[1] == 2:
602
- # Binary classification
603
- if y_proba.ndim == 2:
604
- y_proba_pos = y_proba[:, 1]
605
- else:
606
- y_proba_pos = y_proba
607
-
608
- # Normalize labels to integers for sklearn curve functions
609
- from scitex.ai.metrics import _normalize_labels
610
-
611
- y_true_norm, _, _, _ = _normalize_labels(y_true, y_true)
612
-
613
- # ROC curve data
614
- fpr, tpr, _ = roc_curve(y_true_norm, y_proba_pos)
615
- roc_auc = auc(fpr, tpr)
616
-
617
- # Create ROC curve DataFrame with just FPR and TPR columns
618
- roc_df = pd.DataFrame({"FPR": fpr, "TPR": tpr})
619
-
620
- # Save with AUC value in filename
621
- roc_filename = FILENAME_PATTERNS["roc_curve_csv"].format(
622
- fold=fold, auc=roc_auc
623
- )
624
- self.storage.save(roc_df, f"{fold_dir}/{roc_filename}")
625
-
626
- # PR curve data
627
- precision, recall, _ = precision_recall_curve(y_true_norm, y_proba_pos)
628
- avg_precision = average_precision_score(y_true_norm, y_proba_pos)
629
-
630
- # Create PR curve DataFrame with Recall and Precision columns
631
- pr_df = pd.DataFrame({"Recall": recall, "Precision": precision})
632
-
633
- # Save with AP value in filename
634
- pr_filename = FILENAME_PATTERNS["pr_curve_csv"].format(
635
- fold=fold, ap=avg_precision
636
- )
637
- self.storage.save(pr_df, f"{fold_dir}/{pr_filename}")
638
-
639
- def _create_plots(
640
- self,
641
- y_true: np.ndarray,
642
- y_pred: np.ndarray,
643
- y_proba: Optional[np.ndarray],
644
- labels: List[str],
645
- fold: int,
646
- metrics: Dict[str, Any],
647
- ) -> None:
648
- """Create and save plots with metric-based filenames in unified structure."""
649
- # Use unified fold directory
650
- fold_dir = self._create_subdir_if_needed(
651
- FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
652
- )
653
- fold_dir.mkdir(parents=True, exist_ok=True)
654
-
655
- # # Save curve data for external plotting
656
- # self._save_curve_data(y_true, y_proba, fold, metrics)
657
-
658
- # Confusion matrix plot with metric in filename
659
- if "confusion_matrix" in metrics:
660
- # Extract actual confusion matrix value if wrapped in dict
661
- cm_data = metrics["confusion_matrix"]
662
- if isinstance(cm_data, dict) and "value" in cm_data:
663
- cm_data = cm_data["value"]
664
-
665
- # Get balanced accuracy for title and filename
666
- balanced_acc = metrics.get("balanced-accuracy", {})
667
- if isinstance(balanced_acc, dict) and "value" in balanced_acc:
668
- balanced_acc = balanced_acc["value"]
669
- elif isinstance(balanced_acc, (float, np.floating)):
670
- balanced_acc = float(balanced_acc)
671
- else:
672
- balanced_acc = None
673
-
674
- # Create title with balanced accuracy and filename with fold and metric
675
- if balanced_acc is not None:
676
- title = f"Confusion Matrix (Fold {fold:02d}) - Balanced Acc: {balanced_acc:.3f}"
677
- filename = FILENAME_PATTERNS["confusion_matrix_jpg"].format(
678
- fold=fold, bacc=balanced_acc
679
- )
680
- else:
681
- title = f"Confusion Matrix (Fold {fold:02d})"
682
- filename = FILENAME_PATTERNS["confusion_matrix_jpg_no_bacc"].format(
683
- fold=fold
684
- )
685
-
686
- self.plotter.create_confusion_matrix_plot(
687
- cm_data,
688
- labels=labels,
689
- save_path=fold_dir / filename,
690
- title=title,
691
- )
692
-
693
- # ROC curve with AUC in filename (if probabilities available)
694
- if y_proba is not None:
695
- # Get AUC for filename
696
- roc_auc = metrics.get("roc-auc", {})
697
- if isinstance(roc_auc, dict) and "value" in roc_auc:
698
- roc_auc_val = roc_auc["value"]
699
- roc_filename = FILENAME_PATTERNS["roc_curve_jpg"].format(
700
- fold=fold, auc=roc_auc_val
701
- )
702
- else:
703
- roc_filename = FILENAME_PATTERNS["roc_curve_jpg_no_auc"].format(
704
- fold=fold
705
- )
706
-
707
- self.plotter.create_roc_curve(
708
- y_true,
709
- y_proba,
710
- labels=labels,
711
- save_path=fold_dir / roc_filename,
712
- title=f"ROC Curve (Fold {fold:02d})",
713
- )
714
-
715
- # PR curve with AP in filename
716
- pr_auc = metrics.get("pr-auc", {})
717
- if isinstance(pr_auc, dict) and "value" in pr_auc:
718
- pr_auc_val = pr_auc["value"]
719
- pr_filename = FILENAME_PATTERNS["pr_curve_jpg"].format(
720
- fold=fold, ap=pr_auc_val
721
- )
722
- else:
723
- pr_filename = FILENAME_PATTERNS["pr_curve_jpg_no_ap"].format(fold=fold)
724
-
725
- self.plotter.create_precision_recall_curve(
726
- y_true,
727
- y_proba,
728
- labels=labels,
729
- save_path=fold_dir / pr_filename,
730
- title=f"Precision-Recall Curve (Fold {fold:02d})",
731
- )
732
-
733
- # NEW: Create comprehensive metrics visualization dashboard
734
- # This automatically creates a 4-panel figure with confusion matrix, ROC, PR curve, and metrics table
735
- summary_filename = FILENAME_PATTERNS["metrics_summary"].format(fold=fold)
736
- self.plotter.create_metrics_visualization(
737
- metrics=metrics,
738
- y_true=y_true,
739
- y_pred=y_pred,
740
- y_proba=y_proba,
741
- labels=labels,
742
- save_path=fold_dir / summary_filename,
743
- title="Classification Metrics Dashboard",
744
- fold=fold,
745
- verbose=False, # Already have verbose output from individual plots
746
- )
747
-
748
- def get_summary(self) -> Dict[str, Any]:
749
- """
750
- Get summary of all calculated metrics across folds.
751
-
752
- Returns
753
- -------
754
- Dict[str, Any]
755
- Summary statistics across all folds
756
- """
757
- if not self.fold_metrics:
758
- return {"error": "No metrics calculated yet"}
759
-
760
- summary = {
761
- "output_dir": str(self.output_dir),
762
- "total_folds": len(self.fold_metrics),
763
- "metrics_summary": {},
764
- }
765
-
766
- # Aggregate confusion matrices across all folds
767
- confusion_matrices = []
768
- for fold_metrics in self.fold_metrics.values():
769
- if "confusion_matrix" in fold_metrics:
770
- cm_data = fold_metrics["confusion_matrix"]
771
- # Extract actual value if wrapped in dict
772
- if isinstance(cm_data, dict) and "value" in cm_data:
773
- cm_data = cm_data["value"]
774
- if cm_data is not None:
775
- confusion_matrices.append(cm_data)
776
-
777
- if confusion_matrices:
778
- # Sum all confusion matrices to get overall counts
779
- overall_cm = np.sum(confusion_matrices, axis=0)
780
- summary["overall_confusion_matrix"] = overall_cm.tolist()
781
-
782
- # Also calculate normalized version (as percentages)
783
- overall_cm_normalized = overall_cm / overall_cm.sum()
784
- summary["overall_confusion_matrix_normalized"] = self._round_numeric(
785
- overall_cm_normalized.tolist()
786
- )
787
-
788
- # Calculate summary statistics for scalar metrics
789
- scalar_metrics = ["balanced-accuracy", "mcc", "roc-auc", "pr-auc"]
790
-
791
- for metric_name in scalar_metrics:
792
- values = []
793
- for fold_metrics in self.fold_metrics.values():
794
- if metric_name in fold_metrics:
795
- # Extract actual value if it's wrapped in dict
796
- metric_val = fold_metrics[metric_name]
797
- if isinstance(metric_val, dict) and "value" in metric_val:
798
- values.append(metric_val["value"])
799
- else:
800
- values.append(metric_val)
801
-
802
- if values:
803
- values = np.array(values)
804
- summary["metrics_summary"][metric_name] = {
805
- "mean": self._round_numeric(np.mean(values)),
806
- "std": self._round_numeric(np.std(values)),
807
- "min": self._round_numeric(np.min(values)),
808
- "max": self._round_numeric(np.max(values)),
809
- "values": self._round_numeric(values.tolist()),
810
- }
811
-
812
- # Aggregate feature importance across folds
813
- feature_importances_list = []
814
- for fold_metrics in self.fold_metrics.values():
815
- if "feature-importance" in fold_metrics:
816
- feature_importances_list.append(fold_metrics["feature-importance"])
817
-
818
- if feature_importances_list:
819
- from scitex.ai.feature_selection import (
820
- aggregate_feature_importances,
821
- create_feature_importance_dataframe,
822
- )
823
-
824
- aggregated_importances = aggregate_feature_importances(
825
- feature_importances_list
826
- )
827
- summary["feature-importance"] = aggregated_importances
828
-
829
- return summary
830
-
831
- def create_cv_aggregation_visualizations(
832
- self,
833
- output_dir: Optional[Path] = None,
834
- show_individual_folds: bool = True,
835
- fold_alpha: float = 0.15,
836
- ) -> None:
837
- """
838
- Create CV aggregation visualizations with faded individual fold lines.
839
-
840
- This creates publication-quality cross-validation plots showing:
841
- - Individual fold curves (faded/transparent)
842
- - Mean curve across folds (bold)
843
- - Confidence intervals (± 1 std. dev.)
844
-
845
- Parameters
846
- ----------
847
- output_dir : Path, optional
848
- Directory to save plots (defaults to cv_summary)
849
- show_individual_folds : bool, default True
850
- Whether to show individual fold curves
851
- fold_alpha : float, default 0.15
852
- Transparency for individual fold curves (0-1)
853
- """
854
- if not self.all_predictions:
855
- logger.warning("No predictions stored for CV aggregation visualizations")
856
- return
857
-
858
- if output_dir is None:
859
- output_dir = self._create_subdir_if_needed("cv_summary")
860
- output_dir.mkdir(parents=True, exist_ok=True)
861
-
862
- n_folds = len(self.all_predictions)
863
-
864
- # ROC curve with faded fold lines
865
- roc_save_path = output_dir / f"roc_cv_aggregation_n{n_folds}.jpg"
866
- self.plotter.create_cv_aggregation_plot(
867
- fold_predictions=self.all_predictions,
868
- curve_type="roc",
869
- save_path=roc_save_path,
870
- show_individual_folds=show_individual_folds,
871
- fold_alpha=fold_alpha,
872
- title=f"ROC Curves - Cross Validation (n={n_folds} folds)",
873
- verbose=True,
874
- )
875
- logger.info(f"Created CV aggregation ROC plot with faded fold lines")
876
-
877
- # PR curve with faded fold lines
878
- pr_save_path = output_dir / f"pr_cv_aggregation_n{n_folds}.jpg"
879
- self.plotter.create_cv_aggregation_plot(
880
- fold_predictions=self.all_predictions,
881
- curve_type="pr",
882
- save_path=pr_save_path,
883
- show_individual_folds=show_individual_folds,
884
- fold_alpha=fold_alpha,
885
- title=f"Precision-Recall Curves - Cross Validation (n={n_folds} folds)",
886
- verbose=True,
887
- )
888
- logger.info(f"Created CV aggregation PR plot with faded fold lines")
889
-
890
- def save_feature_importance(
891
- self,
892
- model,
893
- feature_names: List[str],
894
- fold: Optional[int] = None,
895
- ) -> Dict[str, float]:
896
- """
897
- Calculate and save feature importance for tree-based models.
898
-
899
- Parameters
900
- ----------
901
- model : object
902
- Fitted classifier (must have feature_importances_ or coef_)
903
- feature_names : List[str]
904
- Names of features
905
- fold : int, optional
906
- Fold number for tracking
907
-
908
- Returns
909
- -------
910
- Dict[str, float]
911
- Dictionary of feature importances {feature_name: importance}
912
- """
913
- # Use centralized metric calculation
914
- from scitex.ai.metrics import calc_feature_importance
915
-
916
- try:
917
- importance_dict, importances = calc_feature_importance(model, feature_names)
918
- except ValueError as e:
919
- logger.warning(f"Could not extract feature importance: {e}")
920
- return {}
921
-
922
- # Already sorted by calc_feature_importance
923
- sorted_importances = list(importance_dict.items())
924
-
925
- # Save as JSON using FILENAME_PATTERNS
926
- fold_subdir = (
927
- FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
928
- if fold is not None
929
- else "cv_summary"
930
- )
931
- json_filename = FILENAME_PATTERNS["feature_importance_json"].format(fold=fold)
932
- self.storage.save(dict(sorted_importances), f"{fold_subdir}/{json_filename}")
933
-
934
- # Create visualization using FILENAME_PATTERNS
935
- jpg_filename = FILENAME_PATTERNS["feature_importance_jpg"].format(fold=fold)
936
- save_path = self.output_dir / fold_subdir / jpg_filename
937
- save_path.parent.mkdir(parents=True, exist_ok=True)
938
-
939
- self.plotter.create_feature_importance_plot(
940
- feature_importance=importances,
941
- feature_names=feature_names,
942
- save_path=save_path,
943
- title=(
944
- f"Feature Importance (Fold {fold:02d})"
945
- if fold is not None
946
- else "Feature Importance (CV Summary)"
947
- ),
948
- )
949
-
950
- logger.info(
951
- f"Saved feature importance"
952
- + (f" for fold {fold}" if fold is not None else "")
953
- )
954
-
955
- return importance_dict
956
-
957
- def save_feature_importance_summary(
958
- self,
959
- all_importances: List[Dict[str, float]],
960
- ) -> None:
961
- """
962
- Create summary visualization of feature importances across all folds.
963
-
964
- Parameters
965
- ----------
966
- all_importances : List[Dict[str, float]]
967
- List of feature importance dicts from each fold
968
- """
969
- if not all_importances:
970
- return
971
-
972
- # Aggregate importances across folds
973
- all_features = set()
974
- for imp_dict in all_importances:
975
- all_features.update(imp_dict.keys())
976
-
977
- # Calculate mean and std for each feature
978
- feature_stats = {}
979
- for feature in all_features:
980
- values = [imp_dict.get(feature, 0) for imp_dict in all_importances]
981
- feature_stats[feature] = {
982
- "mean": float(np.mean(values)),
983
- "std": float(np.std(values)),
984
- "values": [float(v) for v in values],
985
- }
986
-
987
- # Sort by mean importance
988
- sorted_features = sorted(
989
- feature_stats.items(), key=lambda x: x[1]["mean"], reverse=True
990
- )
991
-
992
- # Save as JSON using FILENAME_PATTERNS
993
- n_folds = len(all_importances)
994
- json_filename = FILENAME_PATTERNS["cv_summary_feature_importance_json"].format(
995
- n_folds=n_folds
996
- )
997
- self.storage.save(
998
- dict(sorted_features),
999
- f"cv_summary/{json_filename}",
1000
- )
1001
-
1002
- # Create visualization using centralized plotting function
1003
- from scitex.ai.plt import plot_feature_importance_cv_summary
1004
-
1005
- jpg_filename = FILENAME_PATTERNS["cv_summary_feature_importance_jpg"].format(
1006
- n_folds=n_folds
1007
- )
1008
- save_path = self.output_dir / "cv_summary" / jpg_filename
1009
- save_path.parent.mkdir(parents=True, exist_ok=True)
1010
-
1011
- fig = plot_feature_importance_cv_summary(
1012
- all_importances=all_importances,
1013
- spath=save_path,
1014
- )
1015
-
1016
- logger.info("Saved feature importance summary")
1017
-
1018
- def create_cv_summary_curves(self, summary: Dict[str, Any]) -> None:
1019
- """
1020
- Create CV summary ROC and PR curves from aggregated predictions.
1021
- """
1022
- if not self.all_predictions:
1023
- logger.warning("No predictions stored for CV summary curves")
1024
- return
1025
-
1026
- # Aggregate all predictions
1027
- all_y_true = np.concatenate([p["y_true"] for p in self.all_predictions])
1028
- all_y_proba = np.concatenate([p["y_proba"] for p in self.all_predictions])
1029
-
1030
- # Get per-fold metrics for mean and std
1031
- roc_values = []
1032
- pr_values = []
1033
- for metrics in self.fold_metrics.values():
1034
- if "roc-auc" in metrics:
1035
- val = metrics["roc-auc"]
1036
- if isinstance(val, dict) and "value" in val:
1037
- roc_values.append(val["value"])
1038
- else:
1039
- roc_values.append(val)
1040
- if "pr-auc" in metrics:
1041
- val = metrics["pr-auc"]
1042
- if isinstance(val, dict) and "value" in val:
1043
- pr_values.append(val["value"])
1044
- else:
1045
- pr_values.append(val)
1046
-
1047
- # Calculate mean and std
1048
- n_folds = len(self.fold_metrics)
1049
- if roc_values:
1050
- roc_mean = np.mean(roc_values)
1051
- roc_std = np.std(roc_values)
1052
- else:
1053
- from .reporter_utils.metrics import calc_roc_auc
1054
-
1055
- overall_roc = calc_roc_auc(all_y_true, all_y_proba)
1056
- roc_mean = overall_roc["value"]
1057
- roc_std = 0.0
1058
-
1059
- if pr_values:
1060
- pr_mean = np.mean(pr_values)
1061
- pr_std = np.std(pr_values)
1062
- else:
1063
- from .reporter_utils.metrics import calc_pre_rec_auc
1064
-
1065
- overall_pr = calc_pre_rec_auc(all_y_true, all_y_proba)
1066
- pr_mean = overall_pr["value"]
1067
- pr_std = 0.0
1068
-
1069
- # Create cv_summary directory
1070
- cv_summary_dir = self._create_subdir_if_needed("cv_summary")
1071
- cv_summary_dir.mkdir(parents=True, exist_ok=True)
1072
-
1073
- # Save CV summary curve data
1074
- self._save_cv_summary_curve_data(
1075
- all_y_true,
1076
- all_y_proba,
1077
- roc_mean,
1078
- roc_std,
1079
- pr_mean,
1080
- pr_std,
1081
- n_folds,
1082
- )
1083
-
1084
- # Normalize labels to integers for sklearn curve functions in plotter
1085
- from scitex.ai.metrics import _normalize_labels
1086
-
1087
- all_y_true_norm, _, label_names, _ = _normalize_labels(all_y_true, all_y_true)
1088
-
1089
- # ROC Curve with mean±std and n_folds in filename
1090
- roc_title = f"ROC Curve (CV Summary) - AUC: {roc_mean:.3f} ± {roc_std:.3f} (n={n_folds})"
1091
- roc_filename = FILENAME_PATTERNS["cv_summary_roc_curve_jpg"].format(
1092
- mean=roc_mean, std=roc_std, n_folds=n_folds
1093
- )
1094
- self.plotter.create_overall_roc_curve(
1095
- all_y_true_norm,
1096
- all_y_proba,
1097
- labels=label_names,
1098
- save_path=cv_summary_dir / roc_filename,
1099
- title=roc_title,
1100
- auc_mean=roc_mean,
1101
- auc_std=roc_std,
1102
- verbose=True,
1103
- )
1104
-
1105
- # PR Curve with mean±std and n_folds in filename
1106
- pr_title = f"Precision-Recall Curve (CV Summary) - AP: {pr_mean:.3f} ± {pr_std:.3f} (n={n_folds})"
1107
- pr_filename = FILENAME_PATTERNS["cv_summary_pr_curve_jpg"].format(
1108
- mean=pr_mean, std=pr_std, n_folds=n_folds
1109
- )
1110
- self.plotter.create_overall_pr_curve(
1111
- all_y_true_norm,
1112
- all_y_proba,
1113
- labels=label_names,
1114
- save_path=cv_summary_dir / pr_filename,
1115
- title=pr_title,
1116
- ap_mean=pr_mean,
1117
- ap_std=pr_std,
1118
- verbose=True,
1119
- )
1120
-
1121
- logger.info(
1122
- f"Created CV summary ROC curve: AUC = {roc_mean:.3f} ± {roc_std:.3f} (n={n_folds})"
1123
- )
1124
- logger.info(
1125
- f"Created CV summary PR curve: AP = {pr_mean:.3f} ± {pr_std:.3f} (n={n_folds})"
1126
- )
1127
-
1128
- def _save_cv_summary_curve_data(
1129
- self,
1130
- y_true: np.ndarray,
1131
- y_proba: np.ndarray,
1132
- roc_mean: float,
1133
- roc_std: float,
1134
- pr_mean: float,
1135
- pr_std: float,
1136
- n_folds: int,
1137
- ) -> None:
1138
- """Save CV summary ROC and PR curve data as CSV files with metric values in filenames."""
1139
- from sklearn.metrics import (
1140
- auc,
1141
- average_precision_score,
1142
- precision_recall_curve,
1143
- roc_curve,
1144
- )
1145
-
1146
- cv_summary_dir = "cv_summary"
1147
-
1148
- # Handle binary vs multiclass
1149
- if y_proba.ndim == 1 or y_proba.shape[1] == 2:
1150
- # Binary classification
1151
- if y_proba.ndim == 2:
1152
- y_proba_pos = y_proba[:, 1]
1153
- else:
1154
- y_proba_pos = y_proba
1155
-
1156
- # Normalize labels to integers for sklearn curve functions
1157
- from scitex.ai.metrics import _normalize_labels
1158
-
1159
- y_true_norm, _, _, _ = _normalize_labels(y_true, y_true)
1160
-
1161
- # ROC curve data
1162
- fpr, tpr, _ = roc_curve(y_true_norm, y_proba_pos)
1163
- roc_auc = auc(fpr, tpr)
1164
-
1165
- # Create ROC curve DataFrame with just FPR and TPR columns
1166
- roc_df = pd.DataFrame({"FPR": fpr, "TPR": tpr})
1167
-
1168
- # Save with mean±std and n_folds in filename
1169
- roc_filename = FILENAME_PATTERNS["cv_summary_roc_curve_csv"].format(
1170
- mean=roc_mean, std=roc_std, n_folds=n_folds
1171
- )
1172
- self.storage.save(roc_df, f"{cv_summary_dir}/{roc_filename}")
1173
-
1174
- # PR curve data
1175
- precision, recall, _ = precision_recall_curve(y_true_norm, y_proba_pos)
1176
- avg_precision = average_precision_score(y_true_norm, y_proba_pos)
1177
-
1178
- # Create PR curve DataFrame with Recall and Precision columns
1179
- pr_df = pd.DataFrame({"Recall": recall, "Precision": precision})
1180
-
1181
- # Save with mean±std and n_folds in filename
1182
- pr_filename = FILENAME_PATTERNS["cv_summary_pr_curve_csv"].format(
1183
- mean=pr_mean, std=pr_std, n_folds=n_folds
1184
- )
1185
- self.storage.save(pr_df, f"{cv_summary_dir}/{pr_filename}")
1186
-
1187
- def save_cv_summary_confusion_matrix(self, summary: Dict[str, Any]) -> None:
1188
- """
1189
- Save and plot the CV summary confusion matrix.
1190
-
1191
- Parameters
1192
- ----------
1193
- summary : Dict[str, Any]
1194
- Summary dictionary containing overall_confusion_matrix
1195
- """
1196
- # Aggregate confusion matrices across all folds
1197
- confusion_matrices = []
1198
- for fold_metrics in self.fold_metrics.values():
1199
- if "confusion_matrix" in fold_metrics:
1200
- cm_data = fold_metrics["confusion_matrix"]
1201
- # Extract actual value if wrapped in dict
1202
- if isinstance(cm_data, dict) and "value" in cm_data:
1203
- cm_data = cm_data["value"]
1204
- if cm_data is not None:
1205
- confusion_matrices.append(cm_data)
1206
-
1207
- if not confusion_matrices:
1208
- return
1209
-
1210
- # Sum all confusion matrices to get overall counts
1211
- overall_cm = np.sum(confusion_matrices, axis=0)
1212
-
1213
- # Get labels from one of the folds
1214
- labels = None
1215
- for fold_metrics in self.fold_metrics.values():
1216
- # Labels are stored directly in fold_metrics now
1217
- if "labels" in fold_metrics:
1218
- labels = fold_metrics["labels"]
1219
- break
1220
- # Fallback: check if labels are in confusion_matrix dict
1221
- elif "confusion_matrix" in fold_metrics:
1222
- cm_data = fold_metrics["confusion_matrix"]
1223
- if isinstance(cm_data, dict) and "labels" in cm_data:
1224
- labels = cm_data["labels"]
1225
- break
1226
-
1227
- # Save as CSV with labels in cv_summary directory
1228
- cv_summary_dir = self._create_subdir_if_needed("cv_summary")
1229
- cv_summary_dir.mkdir(parents=True, exist_ok=True)
1230
-
1231
- # Get balanced accuracy stats for filename
1232
- balanced_acc_mean = None
1233
- balanced_acc_std = None
1234
- n_folds = len(self.fold_metrics)
1235
- if "metrics_summary" in summary:
1236
- if "balanced-accuracy" in summary["metrics_summary"]:
1237
- balanced_acc_stats = summary["metrics_summary"]["balanced-accuracy"]
1238
- balanced_acc_mean = balanced_acc_stats.get("mean")
1239
- balanced_acc_std = balanced_acc_stats.get("std")
1240
-
1241
- # Create filename with mean±std and n_folds
1242
- if balanced_acc_mean is not None and balanced_acc_std is not None:
1243
- cm_filename = FILENAME_PATTERNS["cv_summary_confusion_matrix_csv"].format(
1244
- mean=balanced_acc_mean, std=balanced_acc_std, n_folds=n_folds
1245
- )
1246
- else:
1247
- cm_filename = FILENAME_PATTERNS[
1248
- "cv_summary_confusion_matrix_csv_no_bacc"
1249
- ].format(n_folds=n_folds)
1250
-
1251
- if labels:
1252
- cm_df = pd.DataFrame(
1253
- overall_cm,
1254
- index=[f"True_{label}" for label in labels],
1255
- columns=[f"Pred_{label}" for label in labels],
1256
- )
1257
- else:
1258
- cm_df = pd.DataFrame(overall_cm)
1259
-
1260
- # Save with proper filename (with index=True to preserve row labels)
1261
- self.storage.save(cm_df, f"cv_summary/{cm_filename}", index=True)
1262
-
1263
- # Create plot for CV summary confusion matrix
1264
- cv_summary_dir = self._create_subdir_if_needed("cv_summary")
1265
- cv_summary_dir.mkdir(parents=True, exist_ok=True)
1266
-
1267
- # Calculate balanced accuracy mean and std for overall confusion matrix title
1268
- balanced_acc_mean = None
1269
- balanced_acc_std = None
1270
- if "metrics_summary" in self.get_summary():
1271
- metrics_summary = self.get_summary()["metrics_summary"]
1272
- if "balanced-accuracy" in metrics_summary:
1273
- balanced_acc_stats = metrics_summary["balanced-accuracy"]
1274
- balanced_acc_mean = balanced_acc_stats.get("mean")
1275
- balanced_acc_std = balanced_acc_stats.get("std")
1276
-
1277
- # Create title with balanced accuracy stats and filename with mean±std and n_folds
1278
- if balanced_acc_mean is not None and balanced_acc_std is not None:
1279
- title = f"Confusion Matrix (CV Summary) - Balanced Acc: {balanced_acc_mean:.3f} ± {balanced_acc_std:.3f} (n={n_folds})"
1280
- filename = FILENAME_PATTERNS["cv_summary_confusion_matrix_jpg"].format(
1281
- mean=balanced_acc_mean, std=balanced_acc_std, n_folds=n_folds
1282
- )
1283
- else:
1284
- title = f"Confusion Matrix (CV Summary) (n={n_folds})"
1285
- filename = FILENAME_PATTERNS[
1286
- "cv_summary_confusion_matrix_jpg_no_bacc"
1287
- ].format(n_folds=n_folds)
1288
-
1289
- # Create the plot with enhanced title
1290
- self.plotter.create_confusion_matrix_plot(
1291
- overall_cm,
1292
- labels=labels,
1293
- save_path=cv_summary_dir / filename,
1294
- title=title,
1295
- )
1296
-
1297
- def generate_reports(self) -> Dict[str, Path]:
1298
- """
1299
- Generate comprehensive reports in multiple formats.
1300
-
1301
- Returns
1302
- -------
1303
- Dict[str, Path]
1304
- Paths to generated report files
1305
- """
1306
- # Prepare results dictionary for report generation
1307
- results = {
1308
- "config": {
1309
- "n_folds": len(self.fold_metrics),
1310
- "output_dir": str(self.output_dir),
1311
- },
1312
- "session_config": self.session_config, # Pass the SciTeX CONFIG
1313
- "summary": {},
1314
- "folds": [],
1315
- "plots": {},
1316
- }
1317
-
1318
- # Get summary statistics
1319
- summary = self.get_summary()
1320
-
1321
- # Extract summary statistics for reporting
1322
- if "metrics_summary" in summary:
1323
- results["summary"] = summary["metrics_summary"]
1324
-
1325
- # Add feature importance if available
1326
- if "feature-importance" in summary:
1327
- results["summary"]["feature-importance"] = summary["feature-importance"]
1328
-
1329
- # Add per-fold results
1330
- for fold, fold_data in self.fold_metrics.items():
1331
- fold_result = {"fold_id": fold}
1332
- fold_result.update(fold_data)
1333
-
1334
- # Try to load sample size info from features.json
1335
- # scitex.io.save transforms relative paths: adds storage_out to calling file's dir
1336
- try:
1337
- import json
1338
-
1339
- # Construct the storage_out path where scitex.io.save actually saves files
1340
- # Pattern: {calling_file_dir}/storage_out/{relative_path}
1341
- calling_file_dir = Path(__file__).parent / "reporter_utils"
1342
- storage_out_path = (
1343
- calling_file_dir
1344
- / "storage_out"
1345
- / self.output_dir
1346
- / FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
1347
- / "features.json"
1348
- )
1349
-
1350
- # Also try regular path in case storage behavior changes
1351
- regular_path = (
1352
- self.output_dir
1353
- / FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
1354
- / "features.json"
1355
- )
1356
-
1357
- features_json = None
1358
- if storage_out_path.exists():
1359
- features_json = storage_out_path
1360
- elif regular_path.exists():
1361
- features_json = regular_path
1362
-
1363
- if features_json:
1364
- with open(features_json, "r") as f:
1365
- features_data = json.load(f)
1366
- # Add sample size info if available
1367
- for key in [
1368
- "n_train",
1369
- "n_test",
1370
- "n_train_seizure",
1371
- "n_train_interictal",
1372
- "n_test_seizure",
1373
- "n_test_interictal",
1374
- ]:
1375
- if key in features_data:
1376
- fold_result[key] = int(features_data[key])
1377
- except Exception:
1378
- pass
1379
-
1380
- results["folds"].append(fold_result)
1381
-
1382
- # Add plot references with unified structure
1383
- # CV summary plots in cv_summary directory
1384
- cv_summary_dir = self.output_dir / "cv_summary"
1385
- if cv_summary_dir.exists():
1386
- for plot_file in cv_summary_dir.glob("*.jpg"):
1387
- plot_key = f"cv_summary_{plot_file.stem}"
1388
- results["plots"][plot_key] = str(plot_file.relative_to(self.output_dir))
1389
-
1390
- # Per-fold plots in fold directories
1391
- for fold_dir in sorted(self.output_dir.glob("fold_*")):
1392
- # Extract fold number (directory is fold_XX, filename starts with fold-XX)
1393
- fold_num = fold_dir.name.replace("fold_", "")
1394
- for plot_file in fold_dir.glob("*.jpg"):
1395
- # Use just the stem as the plot key since filename already contains fold info
1396
- # e.g., "fold-00_confusion-matrix_bacc-0.500" becomes plot key "fold_00_confusion-matrix"
1397
- plot_key = f"fold_{fold_num}_{plot_file.stem}"
1398
- results["plots"][plot_key] = str(plot_file.relative_to(self.output_dir))
1399
-
1400
- # Generate reports
1401
- reports_dir = self._create_subdir_if_needed("reports")
1402
- generated_files = {}
1403
-
1404
- # Org-mode report (primary format) - will generate other formats via pandoc
1405
- org_path = reports_dir / "classification_report.org"
1406
- generate_org_report(results, org_path, include_plots=True, convert_formats=True)
1407
- generated_files["org"] = org_path
1408
- logger.info(f"Generated org-mode report: {org_path}")
1409
-
1410
- # Note: Markdown, HTML, LaTeX, and DOCX are now generated via pandoc from org
1411
- # This ensures consistency across all formats
1412
-
1413
- # Try to compile LaTeX to PDF
1414
- try:
1415
- import shutil
1416
- import subprocess
1417
-
1418
- if shutil.which("pdflatex"):
1419
- # Change to reports directory for compilation
1420
- original_dir = Path.cwd()
1421
- try:
1422
- import os
1423
-
1424
- os.chdir(reports_dir)
1425
-
1426
- # Run pdflatex twice for proper references
1427
- for _ in range(2):
1428
- result = subprocess.run(
1429
- [
1430
- "pdflatex",
1431
- "-interaction=nonstopmode",
1432
- "classification_report.tex",
1433
- ],
1434
- capture_output=True,
1435
- text=True,
1436
- timeout=30,
1437
- )
1438
-
1439
- pdf_path = reports_dir / "classification_report.pdf"
1440
- if pdf_path.exists():
1441
- generated_files["pdf"] = pdf_path
1442
- logger.info(f"Generated PDF report: {pdf_path}")
1443
-
1444
- # Clean up LaTeX auxiliary files
1445
- for ext in [".aux", ".log", ".out", ".toc"]:
1446
- aux_file = reports_dir / f"classification_report{ext}"
1447
- if aux_file.exists():
1448
- aux_file.unlink()
1449
- finally:
1450
- os.chdir(original_dir)
1451
- else:
1452
- logger.warning("pdflatex not found. Skipping PDF generation.")
1453
- except Exception as e:
1454
- logger.warning(f"Could not generate PDF report: {e}")
1455
-
1456
- # Skip paper_exports - all data is already available in fold_XX/ and cv_summary/
1457
- # with descriptive filenames perfect for sharing
1458
-
1459
- return generated_files
1460
-
1461
150
  def save_summary(
1462
151
  self, filename: str = "cv_summary/summary.json", verbose: bool = True
1463
152
  ) -> Path:
@@ -1468,6 +157,8 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
1468
157
  ----------
1469
158
  filename : str, default "cv_summary/summary.json"
1470
159
  Filename for summary (now in cv_summary directory)
160
+ verbose : bool, default True
161
+ Print summary to console
1471
162
 
1472
163
  Returns
1473
164
  -------
@@ -1476,17 +167,11 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
1476
167
  """
1477
168
  summary = self.get_summary()
1478
169
 
1479
- # Try to load and include configuration
1480
170
  try:
1481
- # Try different possible locations for CONFIG.yaml
1482
171
  possible_paths = [
1483
- self.output_dir.parent
1484
- / "CONFIGS"
1485
- / "CONFIG.yaml", # ../CONFIGS/CONFIG.yaml
1486
- self.output_dir.parent.parent
1487
- / "CONFIGS"
1488
- / "CONFIG.yaml", # ../../CONFIGS/CONFIG.yaml
1489
- self.output_dir / "CONFIGS" / "CONFIG.yaml", # ./CONFIGS/CONFIG.yaml
172
+ self.output_dir.parent / "CONFIGS" / "CONFIG.yaml",
173
+ self.output_dir.parent.parent / "CONFIGS" / "CONFIG.yaml",
174
+ self.output_dir / "CONFIGS" / "CONFIG.yaml",
1490
175
  ]
1491
176
 
1492
177
  config_path = None
@@ -1498,33 +183,21 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
1498
183
  if config_path and config_path.exists():
1499
184
  import yaml
1500
185
 
1501
- with open(config_path, "r") as config_file:
186
+ with open(config_path) as config_file:
1502
187
  config_data = yaml.safe_load(config_file)
1503
188
  summary["experiment_configuration"] = config_data
1504
189
  except Exception as e:
1505
190
  logger.warning(f"Could not load CONFIG.yaml: {e}")
1506
191
 
1507
- # Save CV summary metrics with proper filenames
1508
192
  self._save_cv_summary_metrics(summary)
1509
-
1510
- # Save and plot CV summary confusion matrix
1511
193
  self.save_cv_summary_confusion_matrix(summary)
1512
-
1513
- # Create CV summary ROC and PR curves
1514
194
  self.create_cv_summary_curves(summary)
1515
-
1516
- # Create CV aggregation visualizations with faded fold lines
1517
195
  self.create_cv_aggregation_visualizations(
1518
196
  show_individual_folds=True, fold_alpha=0.15
1519
197
  )
1520
-
1521
- # Save CV summary classification report
1522
198
  self._save_cv_summary_classification_report(summary)
1523
-
1524
- # Generate comprehensive reports in multiple formats
1525
199
  self.generate_reports()
1526
200
 
1527
- # Ensure cv_summary directory exists
1528
201
  cv_summary_dir = self._create_subdir_if_needed("cv_summary")
1529
202
  cv_summary_dir.mkdir(parents=True, exist_ok=True)
1530
203
 
@@ -1533,239 +206,8 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
1533
206
  logger.info("Summary:")
1534
207
  pprint(summary)
1535
208
 
1536
- # Save summary in cv_summary directory
1537
209
  return self.storage.save(summary, "cv_summary/summary.json")
1538
210
 
1539
- def _save_cv_summary_metrics(self, summary: Dict[str, Any]) -> None:
1540
- """
1541
- Save individual CV summary metrics with mean/std/n_folds in filenames.
1542
- """
1543
- if "metrics_summary" not in summary:
1544
- return
1545
-
1546
- n_folds = len(self.fold_metrics)
1547
- cv_summary_dir = "cv_summary"
1548
-
1549
- for metric_name, stats in summary["metrics_summary"].items():
1550
- if isinstance(stats, dict) and "mean" in stats:
1551
- mean_val = stats.get("mean", 0)
1552
- std_val = stats.get("std", 0)
1553
-
1554
- # Create filename with mean_std_n format
1555
- filename = FILENAME_PATTERNS["cv_summary_metric"].format(
1556
- metric_name=metric_name,
1557
- mean=mean_val,
1558
- std=std_val,
1559
- n_folds=n_folds,
1560
- )
1561
-
1562
- # Save metric statistics
1563
- self.storage.save(stats, f"{cv_summary_dir}/{filename}")
1564
-
1565
- def _save_cv_summary_classification_report(self, summary: Dict[str, Any]) -> None:
1566
- """
1567
- Save CV summary classification report with mean ± std (n_folds=X) format.
1568
- """
1569
- n_folds = len(self.fold_metrics)
1570
- cv_summary_dir = "cv_summary"
1571
-
1572
- # Collect classification reports from all folds
1573
- all_reports = []
1574
- for fold_num, fold_metrics in self.fold_metrics.items():
1575
- if "classification_report" in fold_metrics:
1576
- report = fold_metrics["classification_report"]
1577
- if isinstance(report, dict) and "value" in report:
1578
- report = report["value"]
1579
-
1580
- # Convert DataFrame to dict if needed
1581
- if isinstance(report, pd.DataFrame):
1582
- # Convert DataFrame to dict format expected by aggregation
1583
- # Assumes DataFrame has 'class' column and metric columns
1584
- if "class" in report.columns:
1585
- report_dict = {}
1586
- for _, row in report.iterrows():
1587
- class_name = row["class"]
1588
- report_dict[class_name] = {
1589
- col: row[col]
1590
- for col in report.columns
1591
- if col != "class"
1592
- }
1593
- report = report_dict
1594
- else:
1595
- # DataFrame with class names as index
1596
- report = report.to_dict("index")
1597
-
1598
- if isinstance(report, dict):
1599
- all_reports.append(report)
1600
-
1601
- if not all_reports:
1602
- return
1603
-
1604
- # Calculate mean and std for each metric in the classification report
1605
- summary_report = {}
1606
-
1607
- # Get all class labels (excluding summary rows)
1608
- all_classes = set()
1609
- for report in all_reports:
1610
- all_classes.update(
1611
- [
1612
- k
1613
- for k in report.keys()
1614
- if k not in ["accuracy", "macro avg", "weighted avg"]
1615
- ]
1616
- )
1617
-
1618
- # Process each class
1619
- for cls in sorted(all_classes):
1620
- cls_metrics = {
1621
- "precision": [],
1622
- "recall": [],
1623
- "f1-score": [],
1624
- "support": [],
1625
- }
1626
-
1627
- for report in all_reports:
1628
- if cls in report:
1629
- for metric in [
1630
- "precision",
1631
- "recall",
1632
- "f1-score",
1633
- "support",
1634
- ]:
1635
- if metric in report[cls]:
1636
- cls_metrics[metric].append(report[cls][metric])
1637
-
1638
- summary_report[cls] = {}
1639
- for metric, values in cls_metrics.items():
1640
- if values:
1641
- if metric == "support":
1642
- # For support, show total and mean±std to capture variability
1643
- total_support = int(np.sum(values))
1644
- mean_support = np.mean(values)
1645
- std_support = np.std(values)
1646
- if std_support > 0:
1647
- # Show mean±std if there's variability across folds
1648
- summary_report[cls][metric] = (
1649
- f"{mean_support:.1f} ± {std_support:.1f} (total={total_support})"
1650
- )
1651
- else:
1652
- # If constant across folds, just show the value
1653
- summary_report[cls][metric] = (
1654
- f"{int(mean_support)} per fold (total={total_support})"
1655
- )
1656
- else:
1657
- mean_val = np.mean(values)
1658
- std_val = np.std(values)
1659
- summary_report[cls][metric] = (
1660
- f"{mean_val:.3f} ± {std_val:.3f} (n={n_folds})"
1661
- )
1662
-
1663
- # Process summary rows (macro avg, weighted avg)
1664
- for avg_type in ["macro avg", "weighted avg"]:
1665
- avg_metrics = {"precision": [], "recall": [], "f1-score": []}
1666
-
1667
- for report in all_reports:
1668
- if avg_type in report:
1669
- for metric in ["precision", "recall", "f1-score"]:
1670
- if metric in report[avg_type]:
1671
- avg_metrics[metric].append(report[avg_type][metric])
1672
-
1673
- if any(avg_metrics.values()):
1674
- summary_report[avg_type] = {}
1675
- for metric, values in avg_metrics.items():
1676
- if values:
1677
- mean_val = np.mean(values)
1678
- std_val = np.std(values)
1679
- summary_report[avg_type][metric] = (
1680
- f"{mean_val:.3f} ± {std_val:.3f} (n={n_folds})"
1681
- )
1682
-
1683
- # Convert to DataFrame for better visualization
1684
- if summary_report:
1685
- report_df = pd.DataFrame(summary_report).T
1686
- # Reset index to make it an ordinary column with name
1687
- report_df = report_df.reset_index()
1688
- report_df = report_df.rename(columns={"index": "class"})
1689
-
1690
- # Save as CSV
1691
- filename = FILENAME_PATTERNS["cv_summary_classification_report"].format(
1692
- n_folds=n_folds
1693
- )
1694
- self.storage.save(
1695
- report_df,
1696
- f"{cv_summary_dir}/{filename}",
1697
- )
1698
-
1699
- def save(
1700
- self,
1701
- data: Any,
1702
- relative_path: Union[str, Path],
1703
- fold: Optional[int] = None,
1704
- ) -> Path:
1705
- """
1706
- Save custom data with automatic fold organization and filename prefixing.
1707
-
1708
- Parameters
1709
- ----------
1710
- data : Any
1711
- Custom data to save (any format supported by stx.io.save)
1712
- relative_path : Union[str, Path]
1713
- Relative path from output_dir or fold directory. Examples:
1714
- - When fold is provided: "custom_metrics.json" → "fold_00/fold-00_custom_metrics.json"
1715
- - When fold is None: "cv_summary/results.csv" → "cv_summary/results.csv"
1716
- fold : Optional[int], default None
1717
- If provided, automatically prepends "fold_{fold:02d}/" to the path
1718
- and adds "fold-{fold:02d}_" prefix to the filename
1719
-
1720
- Returns
1721
- -------
1722
- Path
1723
- Absolute path to the saved file
1724
-
1725
- Examples
1726
- --------
1727
- >>> # Save custom metrics for fold 0 (automatic fold directory and prefix)
1728
- >>> reporter.save(
1729
- ... {"metric1": 0.95, "metric2": 0.87},
1730
- ... "custom_metrics.json",
1731
- ... fold=0
1732
- ... ) # Saves to: fold_00/fold-00_custom_metrics.json
1733
-
1734
- >>> # Save to cv_summary (no fold, no prefix)
1735
- >>> reporter.save(
1736
- ... df_results,
1737
- ... "cv_summary/final_analysis.csv"
1738
- ... )
1739
-
1740
- >>> # Save to reports directory
1741
- >>> reporter.save(
1742
- ... report_content,
1743
- ... "reports/analysis.md"
1744
- ... )
1745
- """
1746
- # Automatically prepend fold directory and prefix filename if fold is provided
1747
- if fold is not None:
1748
- # Parse the path to add prefix to filename only
1749
- path_obj = Path(relative_path)
1750
- filename = path_obj.name
1751
- parent = path_obj.parent
1752
-
1753
- # Add fold prefix to filename (e.g., "fold-00_custom_metrics.json")
1754
- prefixed_filename = (
1755
- f"{FOLD_FILE_PREFIX_PATTERN.format(fold=fold)}_{filename}"
1756
- )
1757
-
1758
- # Construct full path: fold_00/fold-00_filename.ext
1759
- if parent and str(parent) != ".":
1760
- relative_path = f"{FOLD_DIR_PREFIX_PATTERN.format(fold=fold)}/{parent}/{prefixed_filename}"
1761
- else:
1762
- relative_path = (
1763
- f"{FOLD_DIR_PREFIX_PATTERN.format(fold=fold)}/{prefixed_filename}"
1764
- )
1765
-
1766
- # Use the existing storage.save method which already handles everything
1767
- return self.storage.save(data, relative_path)
1768
-
1769
211
  def __repr__(self) -> str:
1770
212
  fold_count = len(self.fold_metrics)
1771
213
  return (