scitex 2.14.0__py3-none-any.whl → 2.15.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +71 -17
- scitex/_env_loader.py +156 -0
- scitex/_mcp_resources/__init__.py +37 -0
- scitex/_mcp_resources/_cheatsheet.py +135 -0
- scitex/_mcp_resources/_figrecipe.py +138 -0
- scitex/_mcp_resources/_formats.py +102 -0
- scitex/_mcp_resources/_modules.py +337 -0
- scitex/_mcp_resources/_session.py +149 -0
- scitex/_mcp_tools/__init__.py +4 -0
- scitex/_mcp_tools/audio.py +66 -0
- scitex/_mcp_tools/diagram.py +11 -95
- scitex/_mcp_tools/introspect.py +210 -0
- scitex/_mcp_tools/plt.py +260 -305
- scitex/_mcp_tools/scholar.py +74 -0
- scitex/_mcp_tools/social.py +244 -0
- scitex/_mcp_tools/template.py +24 -0
- scitex/_mcp_tools/writer.py +21 -204
- scitex/ai/_gen_ai/_PARAMS.py +10 -7
- scitex/ai/classification/reporters/_SingleClassificationReporter.py +45 -1603
- scitex/ai/classification/reporters/_mixins/__init__.py +36 -0
- scitex/ai/classification/reporters/_mixins/_constants.py +67 -0
- scitex/ai/classification/reporters/_mixins/_cv_summary.py +387 -0
- scitex/ai/classification/reporters/_mixins/_feature_importance.py +119 -0
- scitex/ai/classification/reporters/_mixins/_metrics.py +275 -0
- scitex/ai/classification/reporters/_mixins/_plotting.py +179 -0
- scitex/ai/classification/reporters/_mixins/_reports.py +153 -0
- scitex/ai/classification/reporters/_mixins/_storage.py +160 -0
- scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit.py +30 -1550
- scitex/ai/classification/timeseries/_sliding_window_core.py +467 -0
- scitex/ai/classification/timeseries/_sliding_window_plotting.py +369 -0
- scitex/audio/README.md +40 -36
- scitex/audio/__init__.py +129 -61
- scitex/audio/_branding.py +185 -0
- scitex/audio/_mcp/__init__.py +32 -0
- scitex/audio/_mcp/handlers.py +59 -6
- scitex/audio/_mcp/speak_handlers.py +238 -0
- scitex/audio/_relay.py +225 -0
- scitex/audio/_tts.py +18 -10
- scitex/audio/engines/base.py +17 -10
- scitex/audio/engines/elevenlabs_engine.py +7 -2
- scitex/audio/mcp_server.py +228 -75
- scitex/canvas/README.md +1 -1
- scitex/canvas/editor/_dearpygui/__init__.py +25 -0
- scitex/canvas/editor/_dearpygui/_editor.py +147 -0
- scitex/canvas/editor/_dearpygui/_handlers.py +476 -0
- scitex/canvas/editor/_dearpygui/_panels/__init__.py +17 -0
- scitex/canvas/editor/_dearpygui/_panels/_control.py +119 -0
- scitex/canvas/editor/_dearpygui/_panels/_element_controls.py +190 -0
- scitex/canvas/editor/_dearpygui/_panels/_preview.py +43 -0
- scitex/canvas/editor/_dearpygui/_panels/_sections.py +390 -0
- scitex/canvas/editor/_dearpygui/_plotting.py +187 -0
- scitex/canvas/editor/_dearpygui/_rendering.py +504 -0
- scitex/canvas/editor/_dearpygui/_selection.py +295 -0
- scitex/canvas/editor/_dearpygui/_state.py +93 -0
- scitex/canvas/editor/_dearpygui/_utils.py +61 -0
- scitex/canvas/editor/flask_editor/_core/__init__.py +27 -0
- scitex/canvas/editor/flask_editor/_core/_bbox_extraction.py +200 -0
- scitex/canvas/editor/flask_editor/_core/_editor.py +173 -0
- scitex/canvas/editor/flask_editor/_core/_export_helpers.py +353 -0
- scitex/canvas/editor/flask_editor/_core/_routes_basic.py +190 -0
- scitex/canvas/editor/flask_editor/_core/_routes_export.py +332 -0
- scitex/canvas/editor/flask_editor/_core/_routes_panels.py +252 -0
- scitex/canvas/editor/flask_editor/_core/_routes_save.py +218 -0
- scitex/canvas/editor/flask_editor/_core.py +25 -1684
- scitex/canvas/editor/flask_editor/templates/__init__.py +32 -70
- scitex/cli/__init__.py +38 -43
- scitex/cli/audio.py +76 -27
- scitex/cli/capture.py +13 -20
- scitex/cli/introspect.py +481 -0
- scitex/cli/main.py +200 -109
- scitex/cli/mcp.py +60 -34
- scitex/cli/plt.py +357 -0
- scitex/cli/repro.py +15 -8
- scitex/cli/resource.py +15 -8
- scitex/cli/scholar/__init__.py +23 -8
- scitex/cli/scholar/_crossref_scitex.py +296 -0
- scitex/cli/scholar/_fetch.py +25 -3
- scitex/cli/social.py +314 -0
- scitex/cli/stats.py +15 -8
- scitex/cli/template.py +129 -12
- scitex/cli/tex.py +15 -8
- scitex/cli/writer.py +132 -8
- scitex/cloud/__init__.py +41 -2
- scitex/config/README.md +1 -1
- scitex/config/__init__.py +16 -2
- scitex/config/_env_registry.py +256 -0
- scitex/context/__init__.py +22 -0
- scitex/dev/__init__.py +20 -1
- scitex/diagram/__init__.py +42 -19
- scitex/diagram/mcp_server.py +13 -125
- scitex/gen/__init__.py +50 -14
- scitex/gen/_list_packages.py +4 -4
- scitex/introspect/__init__.py +82 -0
- scitex/introspect/_call_graph.py +303 -0
- scitex/introspect/_class_hierarchy.py +163 -0
- scitex/introspect/_core.py +41 -0
- scitex/introspect/_docstring.py +131 -0
- scitex/introspect/_examples.py +113 -0
- scitex/introspect/_imports.py +271 -0
- scitex/{gen/_inspect_module.py → introspect/_list_api.py} +43 -54
- scitex/introspect/_mcp/__init__.py +41 -0
- scitex/introspect/_mcp/handlers.py +233 -0
- scitex/introspect/_members.py +155 -0
- scitex/introspect/_resolve.py +89 -0
- scitex/introspect/_signature.py +131 -0
- scitex/introspect/_source.py +80 -0
- scitex/introspect/_type_hints.py +172 -0
- scitex/io/_save.py +1 -2
- scitex/io/bundle/README.md +1 -1
- scitex/logging/_formatters.py +19 -9
- scitex/mcp_server.py +98 -5
- scitex/os/__init__.py +4 -0
- scitex/{gen → os}/_check_host.py +4 -5
- scitex/plt/__init__.py +245 -550
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py +5 -10
- scitex/plt/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/plt/gallery/README.md +1 -1
- scitex/plt/utils/_hitmap/__init__.py +82 -0
- scitex/plt/utils/_hitmap/_artist_extraction.py +343 -0
- scitex/plt/utils/_hitmap/_color_application.py +346 -0
- scitex/plt/utils/_hitmap/_color_conversion.py +121 -0
- scitex/plt/utils/_hitmap/_constants.py +40 -0
- scitex/plt/utils/_hitmap/_hitmap_core.py +334 -0
- scitex/plt/utils/_hitmap/_path_extraction.py +357 -0
- scitex/plt/utils/_hitmap/_query.py +113 -0
- scitex/plt/utils/_hitmap.py +46 -1616
- scitex/plt/utils/_metadata/__init__.py +80 -0
- scitex/plt/utils/_metadata/_artists/__init__.py +25 -0
- scitex/plt/utils/_metadata/_artists/_base.py +195 -0
- scitex/plt/utils/_metadata/_artists/_collections.py +356 -0
- scitex/plt/utils/_metadata/_artists/_extract.py +57 -0
- scitex/plt/utils/_metadata/_artists/_images.py +80 -0
- scitex/plt/utils/_metadata/_artists/_lines.py +261 -0
- scitex/plt/utils/_metadata/_artists/_patches.py +247 -0
- scitex/plt/utils/_metadata/_artists/_text.py +106 -0
- scitex/plt/utils/_metadata/_csv.py +416 -0
- scitex/plt/utils/_metadata/_detect.py +225 -0
- scitex/plt/utils/_metadata/_legend.py +127 -0
- scitex/plt/utils/_metadata/_rounding.py +117 -0
- scitex/plt/utils/_metadata/_verification.py +202 -0
- scitex/schema/README.md +1 -1
- scitex/scholar/__init__.py +8 -0
- scitex/scholar/_mcp/crossref_handlers.py +265 -0
- scitex/scholar/core/Scholar.py +63 -1700
- scitex/scholar/core/_mixins/__init__.py +36 -0
- scitex/scholar/core/_mixins/_enrichers.py +270 -0
- scitex/scholar/core/_mixins/_library_handlers.py +100 -0
- scitex/scholar/core/_mixins/_loaders.py +103 -0
- scitex/scholar/core/_mixins/_pdf_download.py +375 -0
- scitex/scholar/core/_mixins/_pipeline.py +312 -0
- scitex/scholar/core/_mixins/_project_handlers.py +125 -0
- scitex/scholar/core/_mixins/_savers.py +69 -0
- scitex/scholar/core/_mixins/_search.py +103 -0
- scitex/scholar/core/_mixins/_services.py +88 -0
- scitex/scholar/core/_mixins/_url_finding.py +105 -0
- scitex/scholar/crossref_scitex.py +367 -0
- scitex/scholar/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/scholar/examples/00_run_all.sh +120 -0
- scitex/scholar/jobs/_executors.py +27 -3
- scitex/scholar/pdf_download/ScholarPDFDownloader.py +38 -416
- scitex/scholar/pdf_download/_cli.py +154 -0
- scitex/scholar/pdf_download/strategies/__init__.py +11 -8
- scitex/scholar/pdf_download/strategies/manual_download_fallback.py +80 -3
- scitex/scholar/pipelines/ScholarPipelineBibTeX.py +73 -121
- scitex/scholar/pipelines/ScholarPipelineParallel.py +80 -138
- scitex/scholar/pipelines/ScholarPipelineSingle.py +43 -63
- scitex/scholar/pipelines/_single_steps.py +71 -36
- scitex/scholar/storage/_LibraryManager.py +97 -1695
- scitex/scholar/storage/_mixins/__init__.py +30 -0
- scitex/scholar/storage/_mixins/_bibtex_handlers.py +128 -0
- scitex/scholar/storage/_mixins/_library_operations.py +218 -0
- scitex/scholar/storage/_mixins/_metadata_conversion.py +226 -0
- scitex/scholar/storage/_mixins/_paper_saving.py +456 -0
- scitex/scholar/storage/_mixins/_resolution.py +376 -0
- scitex/scholar/storage/_mixins/_storage_helpers.py +121 -0
- scitex/scholar/storage/_mixins/_symlink_handlers.py +226 -0
- scitex/scholar/url_finder/.tmp/open_url/KNOWN_RESOLVERS.py +462 -0
- scitex/scholar/url_finder/.tmp/open_url/README.md +223 -0
- scitex/scholar/url_finder/.tmp/open_url/_DOIToURLResolver.py +694 -0
- scitex/scholar/url_finder/.tmp/open_url/_OpenURLResolver.py +1160 -0
- scitex/scholar/url_finder/.tmp/open_url/_ResolverLinkFinder.py +344 -0
- scitex/scholar/url_finder/.tmp/open_url/__init__.py +24 -0
- scitex/security/README.md +3 -3
- scitex/session/README.md +1 -1
- scitex/session/__init__.py +26 -7
- scitex/session/_decorator.py +1 -1
- scitex/sh/README.md +1 -1
- scitex/sh/__init__.py +7 -4
- scitex/social/__init__.py +155 -0
- scitex/social/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/stats/_mcp/_handlers/__init__.py +31 -0
- scitex/stats/_mcp/_handlers/_corrections.py +113 -0
- scitex/stats/_mcp/_handlers/_descriptive.py +78 -0
- scitex/stats/_mcp/_handlers/_effect_size.py +106 -0
- scitex/stats/_mcp/_handlers/_format.py +94 -0
- scitex/stats/_mcp/_handlers/_normality.py +110 -0
- scitex/stats/_mcp/_handlers/_posthoc.py +224 -0
- scitex/stats/_mcp/_handlers/_power.py +247 -0
- scitex/stats/_mcp/_handlers/_recommend.py +102 -0
- scitex/stats/_mcp/_handlers/_run_test.py +279 -0
- scitex/stats/_mcp/_handlers/_stars.py +48 -0
- scitex/stats/_mcp/handlers.py +19 -1171
- scitex/stats/auto/_stat_style.py +175 -0
- scitex/stats/auto/_style_definitions.py +411 -0
- scitex/stats/auto/_styles.py +22 -620
- scitex/stats/descriptive/__init__.py +11 -8
- scitex/stats/descriptive/_ci.py +39 -0
- scitex/stats/power/_power.py +15 -4
- scitex/str/__init__.py +2 -1
- scitex/str/_title_case.py +63 -0
- scitex/template/README.md +1 -1
- scitex/template/__init__.py +25 -10
- scitex/template/_code_templates.py +147 -0
- scitex/template/_mcp/handlers.py +81 -0
- scitex/template/_mcp/tool_schemas.py +55 -0
- scitex/template/_templates/__init__.py +51 -0
- scitex/template/_templates/audio.py +233 -0
- scitex/template/_templates/canvas.py +312 -0
- scitex/template/_templates/capture.py +268 -0
- scitex/template/_templates/config.py +43 -0
- scitex/template/_templates/diagram.py +294 -0
- scitex/template/_templates/io.py +107 -0
- scitex/template/_templates/module.py +53 -0
- scitex/template/_templates/plt.py +202 -0
- scitex/template/_templates/scholar.py +267 -0
- scitex/template/_templates/session.py +130 -0
- scitex/template/_templates/session_minimal.py +43 -0
- scitex/template/_templates/session_plot.py +67 -0
- scitex/template/_templates/session_stats.py +77 -0
- scitex/template/_templates/stats.py +323 -0
- scitex/template/_templates/writer.py +296 -0
- scitex/template/clone_writer_directory.py +5 -5
- scitex/ui/_backends/_email.py +10 -2
- scitex/ui/_backends/_webhook.py +5 -1
- scitex/web/_search_pubmed.py +10 -6
- scitex/writer/README.md +1 -1
- scitex/writer/_mcp/handlers.py +11 -744
- scitex/writer/_mcp/tool_schemas.py +5 -335
- scitex-2.15.2.dist-info/METADATA +648 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/RECORD +246 -150
- scitex/canvas/editor/flask_editor/templates/_scripts.py +0 -4933
- scitex/canvas/editor/flask_editor/templates/_styles.py +0 -1658
- scitex/dev/plt/data/mpl/PLOTTING_FUNCTIONS.yaml +0 -90
- scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES.yaml +0 -1571
- scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES_DETAILED.yaml +0 -6262
- scitex/dev/plt/data/mpl/SIGNATURES_FLATTENED.yaml +0 -1274
- scitex/dev/plt/data/mpl/dir_ax.txt +0 -459
- scitex/diagram/_compile.py +0 -312
- scitex/diagram/_diagram.py +0 -355
- scitex/diagram/_mcp/__init__.py +0 -4
- scitex/diagram/_mcp/handlers.py +0 -400
- scitex/diagram/_mcp/tool_schemas.py +0 -157
- scitex/diagram/_presets.py +0 -173
- scitex/diagram/_schema.py +0 -182
- scitex/diagram/_split.py +0 -278
- scitex/gen/_ci.py +0 -12
- scitex/gen/_title_case.py +0 -89
- scitex/plt/_mcp/__init__.py +0 -4
- scitex/plt/_mcp/_handlers_annotation.py +0 -102
- scitex/plt/_mcp/_handlers_figure.py +0 -195
- scitex/plt/_mcp/_handlers_plot.py +0 -252
- scitex/plt/_mcp/_handlers_style.py +0 -219
- scitex/plt/_mcp/handlers.py +0 -74
- scitex/plt/_mcp/tool_schemas.py +0 -497
- scitex/plt/mcp_server.py +0 -231
- scitex/scholar/data/.gitkeep +0 -0
- scitex/scholar/data/README.md +0 -44
- scitex/scholar/data/bib_files/bibliography.bib +0 -1952
- scitex/scholar/data/bib_files/neurovista.bib +0 -277
- scitex/scholar/data/bib_files/neurovista_enriched.bib +0 -441
- scitex/scholar/data/bib_files/neurovista_enriched_enriched.bib +0 -441
- scitex/scholar/data/bib_files/neurovista_processed.bib +0 -338
- scitex/scholar/data/bib_files/openaccess.bib +0 -89
- scitex/scholar/data/bib_files/pac-seizure_prediction_enriched.bib +0 -2178
- scitex/scholar/data/bib_files/pac.bib +0 -698
- scitex/scholar/data/bib_files/pac_enriched.bib +0 -1061
- scitex/scholar/data/bib_files/pac_processed.bib +0 -0
- scitex/scholar/data/bib_files/pac_titles.txt +0 -75
- scitex/scholar/data/bib_files/paywalled.bib +0 -98
- scitex/scholar/data/bib_files/related-papers-by-coauthors.bib +0 -58
- scitex/scholar/data/bib_files/related-papers-by-coauthors_enriched.bib +0 -87
- scitex/scholar/data/bib_files/seizure_prediction.bib +0 -694
- scitex/scholar/data/bib_files/seizure_prediction_processed.bib +0 -0
- scitex/scholar/data/bib_files/test_complete_enriched.bib +0 -437
- scitex/scholar/data/bib_files/test_final_enriched.bib +0 -437
- scitex/scholar/data/bib_files/test_seizure.bib +0 -46
- scitex/scholar/data/impact_factor/JCR_IF_2022.xlsx +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024.db +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024.xlsx +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024_v01.db +0 -0
- scitex/scholar/data/impact_factor.db +0 -0
- scitex/scholar/examples/SUGGESTIONS.md +0 -865
- scitex/scholar/examples/dev.py +0 -38
- scitex-2.14.0.dist-info/METADATA +0 -1238
- /scitex/{gen → context}/_detect_environment.py +0 -0
- /scitex/{gen → context}/_get_notebook_path.py +0 -0
- /scitex/{gen/_shell.py → sh/_shell_legacy.py} +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/WHEEL +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/entry_points.txt +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,120 +1,58 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
# File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/ml/classification/reporters/_SingleClassificationReporter.py
|
|
5
|
-
# ----------------------------------------
|
|
6
|
-
from __future__ import annotations
|
|
7
|
-
import os
|
|
8
|
-
|
|
9
|
-
__FILE__ = "./src/scitex/ml/classification/reporters/_SingleClassificationReporter.py"
|
|
10
|
-
__DIR__ = os.path.dirname(__FILE__)
|
|
11
|
-
# ----------------------------------------
|
|
12
|
-
|
|
13
|
-
__FILE__ = __file__
|
|
14
|
-
|
|
15
|
-
from pprint import pprint
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_SingleClassificationReporter.py
|
|
16
4
|
|
|
17
5
|
"""
|
|
18
6
|
Improved Single Classification Reporter with unified API.
|
|
19
7
|
|
|
20
|
-
|
|
21
|
-
-
|
|
22
|
-
-
|
|
23
|
-
-
|
|
24
|
-
-
|
|
25
|
-
-
|
|
8
|
+
This module provides a comprehensive classification reporter that:
|
|
9
|
+
- Uses unified API interface
|
|
10
|
+
- Supports lazy directory creation
|
|
11
|
+
- Provides numerical precision control
|
|
12
|
+
- Creates visualizations with proper error handling
|
|
13
|
+
- Maintains consistent parameter naming
|
|
14
|
+
|
|
15
|
+
The main class inherits from multiple mixins for modular functionality:
|
|
16
|
+
- MetricsMixin: Metrics calculation and aggregation
|
|
17
|
+
- StorageMixin: File storage and organization
|
|
18
|
+
- PlottingMixin: Visualization generation
|
|
19
|
+
- FeatureImportanceMixin: Feature importance analysis
|
|
20
|
+
- CVSummaryMixin: Cross-validation summary generation
|
|
21
|
+
- ReportsMixin: Multi-format report generation
|
|
26
22
|
"""
|
|
27
23
|
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
28
26
|
from pathlib import Path
|
|
27
|
+
from pprint import pprint
|
|
29
28
|
from typing import Any, Dict, List, Optional, Union
|
|
30
29
|
|
|
31
|
-
import numpy as np
|
|
32
|
-
import pandas as pd
|
|
33
30
|
from scitex.logging import getLogger
|
|
34
31
|
|
|
35
|
-
# Import base class and utilities
|
|
36
32
|
from ._BaseClassificationReporter import BaseClassificationReporter, ReporterConfig
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
calc_pre_rec_auc,
|
|
45
|
-
calc_roc_auc,
|
|
33
|
+
from ._mixins import (
|
|
34
|
+
CVSummaryMixin,
|
|
35
|
+
FeatureImportanceMixin,
|
|
36
|
+
MetricsMixin,
|
|
37
|
+
PlottingMixin,
|
|
38
|
+
ReportsMixin,
|
|
39
|
+
StorageMixin,
|
|
46
40
|
)
|
|
47
41
|
from .reporter_utils._Plotter import Plotter
|
|
48
|
-
from .reporter_utils.
|
|
49
|
-
create_summary_statistics,
|
|
50
|
-
generate_latex_report,
|
|
51
|
-
generate_markdown_report,
|
|
52
|
-
generate_org_report,
|
|
53
|
-
)
|
|
54
|
-
from .reporter_utils.storage import MetricStorage, save_metric
|
|
42
|
+
from .reporter_utils.storage import MetricStorage
|
|
55
43
|
|
|
56
44
|
logger = getLogger(__name__)
|
|
57
45
|
|
|
58
46
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
"fold_metric_with_value": f"{FOLD_FILE_PREFIX_PATTERN}_{{metric_name}}-{{value:.3f}}.json",
|
|
69
|
-
"fold_metric": f"{FOLD_FILE_PREFIX_PATTERN}_{{metric_name}}.json",
|
|
70
|
-
# Confusion matrix
|
|
71
|
-
"confusion_matrix_csv": f"{FOLD_FILE_PREFIX_PATTERN}_confusion-matrix_bacc-{{bacc:.3f}}.csv",
|
|
72
|
-
"confusion_matrix_csv_no_bacc": f"{FOLD_FILE_PREFIX_PATTERN}_confusion-matrix.csv",
|
|
73
|
-
"confusion_matrix_jpg": f"{FOLD_FILE_PREFIX_PATTERN}_confusion-matrix_bacc-{{bacc:.3f}}.jpg",
|
|
74
|
-
"confusion_matrix_jpg_no_bacc": f"{FOLD_FILE_PREFIX_PATTERN}_confusion-matrix.jpg",
|
|
75
|
-
# Classification report
|
|
76
|
-
"classification_report": f"{FOLD_FILE_PREFIX_PATTERN}_classification-report.csv",
|
|
77
|
-
# ROC curve
|
|
78
|
-
"roc_curve_csv": f"{FOLD_FILE_PREFIX_PATTERN}_roc-curve_auc-{{auc:.3f}}.csv",
|
|
79
|
-
"roc_curve_csv_no_auc": f"{FOLD_FILE_PREFIX_PATTERN}_roc-curve.csv",
|
|
80
|
-
"roc_curve_jpg": f"{FOLD_FILE_PREFIX_PATTERN}_roc-curve_auc-{{auc:.3f}}.jpg",
|
|
81
|
-
"roc_curve_jpg_no_auc": f"{FOLD_FILE_PREFIX_PATTERN}_roc-curve.jpg",
|
|
82
|
-
# PR curve
|
|
83
|
-
"pr_curve_csv": f"{FOLD_FILE_PREFIX_PATTERN}_pr-curve_ap-{{ap:.3f}}.csv",
|
|
84
|
-
"pr_curve_csv_no_ap": f"{FOLD_FILE_PREFIX_PATTERN}_pr-curve.csv",
|
|
85
|
-
"pr_curve_jpg": f"{FOLD_FILE_PREFIX_PATTERN}_pr-curve_ap-{{ap:.3f}}.jpg",
|
|
86
|
-
"pr_curve_jpg_no_ap": f"{FOLD_FILE_PREFIX_PATTERN}_pr-curve.jpg",
|
|
87
|
-
# Raw prediction data (optional, enabled via calculate_metrics parameters)
|
|
88
|
-
"y_true": f"{FOLD_FILE_PREFIX_PATTERN}_y-true.csv",
|
|
89
|
-
"y_pred": f"{FOLD_FILE_PREFIX_PATTERN}_y-pred.csv",
|
|
90
|
-
"y_proba": f"{FOLD_FILE_PREFIX_PATTERN}_y-proba.csv",
|
|
91
|
-
# Metrics dashboard
|
|
92
|
-
"metrics_summary": f"{FOLD_FILE_PREFIX_PATTERN}_metrics-summary.jpg",
|
|
93
|
-
# Feature importance
|
|
94
|
-
"feature_importance_json": f"{FOLD_FILE_PREFIX_PATTERN}_feature-importance.json",
|
|
95
|
-
"feature_importance_jpg": f"{FOLD_FILE_PREFIX_PATTERN}_feature-importance.jpg",
|
|
96
|
-
# Classification report edge cases (when CSV conversion fails)
|
|
97
|
-
"classification_report_json": f"{FOLD_FILE_PREFIX_PATTERN}_classification-report.json",
|
|
98
|
-
"classification_report_txt": f"{FOLD_FILE_PREFIX_PATTERN}_classification-report.txt",
|
|
99
|
-
# Folds all (CV summary)
|
|
100
|
-
"cv_summary_metric": "cv-summary_{metric_name}_mean-{mean:.3f}_std-{std:.3f}_n-{n_folds}.json",
|
|
101
|
-
"cv_summary_confusion_matrix_csv": "cv-summary_confusion-matrix_bacc-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
|
|
102
|
-
"cv_summary_confusion_matrix_jpg": "cv-summary_confusion-matrix_bacc-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
|
|
103
|
-
"cv_summary_classification_report": "cv-summary_classification-report_n-{n_folds}.csv",
|
|
104
|
-
"cv_summary_roc_curve_csv": "cv-summary_roc-curve_auc-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
|
|
105
|
-
"cv_summary_roc_curve_jpg": "cv-summary_roc-curve_auc-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
|
|
106
|
-
"cv_summary_pr_curve_csv": "cv-summary_pr-curve_ap-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
|
|
107
|
-
"cv_summary_pr_curve_jpg": "cv-summary_pr-curve_ap-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
|
|
108
|
-
"cv_summary_feature_importance_json": "cv-summary_feature-importance_n-{n_folds}.json",
|
|
109
|
-
"cv_summary_feature_importance_jpg": "cv-summary_feature-importance_n-{n_folds}.jpg",
|
|
110
|
-
"cv_summary_summary": "cv-summary_summary.json",
|
|
111
|
-
# Folds all edge cases (when balanced_acc is None)
|
|
112
|
-
"cv_summary_confusion_matrix_csv_no_bacc": "cv-summary_confusion-matrix_n-{n_folds}.csv",
|
|
113
|
-
"cv_summary_confusion_matrix_jpg_no_bacc": "cv-summary_confusion-matrix_n-{n_folds}.jpg",
|
|
114
|
-
}
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
class SingleTaskClassificationReporter(BaseClassificationReporter):
|
|
47
|
+
class SingleTaskClassificationReporter(
|
|
48
|
+
MetricsMixin,
|
|
49
|
+
StorageMixin,
|
|
50
|
+
PlottingMixin,
|
|
51
|
+
FeatureImportanceMixin,
|
|
52
|
+
CVSummaryMixin,
|
|
53
|
+
ReportsMixin,
|
|
54
|
+
BaseClassificationReporter,
|
|
55
|
+
):
|
|
118
56
|
"""
|
|
119
57
|
Improved single-task classification reporter with unified API.
|
|
120
58
|
|
|
@@ -142,6 +80,8 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
|
|
|
142
80
|
Base directory for outputs. If None, creates timestamped directory.
|
|
143
81
|
config : ReporterConfig, optional
|
|
144
82
|
Configuration object for advanced settings
|
|
83
|
+
verbose : bool, default True
|
|
84
|
+
Print initialization message
|
|
145
85
|
**kwargs
|
|
146
86
|
Additional arguments passed to base class
|
|
147
87
|
|
|
@@ -174,11 +114,9 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
|
|
|
174
114
|
verbose: bool = True,
|
|
175
115
|
**kwargs,
|
|
176
116
|
):
|
|
177
|
-
# Use config or create default
|
|
178
117
|
if config is None:
|
|
179
118
|
config = ReporterConfig()
|
|
180
119
|
|
|
181
|
-
# Initialize base class with config settings
|
|
182
120
|
super().__init__(
|
|
183
121
|
output_dir=output_dir,
|
|
184
122
|
precision=config.precision,
|
|
@@ -186,14 +124,11 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
|
|
|
186
124
|
)
|
|
187
125
|
|
|
188
126
|
self.config = config
|
|
189
|
-
self.session_config = None
|
|
127
|
+
self.session_config = None
|
|
190
128
|
self.storage = MetricStorage(self.output_dir, precision=self.precision)
|
|
191
129
|
self.plotter = Plotter(enable_plotting=True)
|
|
192
130
|
|
|
193
|
-
# Track calculated metrics for summary
|
|
194
131
|
self.fold_metrics: Dict[int, Dict[str, Any]] = {}
|
|
195
|
-
|
|
196
|
-
# Store all predictions for overall curves
|
|
197
132
|
self.all_predictions: List[Dict[str, Any]] = []
|
|
198
133
|
|
|
199
134
|
if verbose:
|
|
@@ -212,1252 +147,6 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
|
|
|
212
147
|
"""
|
|
213
148
|
self.session_config = config
|
|
214
149
|
|
|
215
|
-
def calculate_metrics(
|
|
216
|
-
self,
|
|
217
|
-
y_true: np.ndarray,
|
|
218
|
-
y_pred: np.ndarray,
|
|
219
|
-
y_proba: Optional[np.ndarray] = None,
|
|
220
|
-
labels: Optional[List[str]] = None,
|
|
221
|
-
fold: Optional[int] = None,
|
|
222
|
-
verbose=True,
|
|
223
|
-
store_y_true: bool = True,
|
|
224
|
-
store_y_pred: bool = True,
|
|
225
|
-
store_y_proba: bool = True,
|
|
226
|
-
model=None,
|
|
227
|
-
feature_names: Optional[List[str]] = None,
|
|
228
|
-
) -> Dict[str, Any]:
|
|
229
|
-
"""
|
|
230
|
-
Calculate and save classification metrics using unified API.
|
|
231
|
-
|
|
232
|
-
Parameters
|
|
233
|
-
----------
|
|
234
|
-
y_true : np.ndarray
|
|
235
|
-
True class labels
|
|
236
|
-
y_pred : np.ndarray
|
|
237
|
-
Predicted class labels
|
|
238
|
-
y_proba : np.ndarray, optional
|
|
239
|
-
Prediction probabilities (required for AUC metrics)
|
|
240
|
-
labels : List[str], optional
|
|
241
|
-
Class labels for display
|
|
242
|
-
fold : int, optional
|
|
243
|
-
Fold index for cross-validation
|
|
244
|
-
verbose : bool, default True
|
|
245
|
-
Print progress messages
|
|
246
|
-
store_y_true : bool, default True
|
|
247
|
-
Save y_true as CSV with sample_index and fold columns
|
|
248
|
-
store_y_pred : bool, default True
|
|
249
|
-
Save y_pred as CSV with sample_index and fold columns
|
|
250
|
-
store_y_proba : bool, default True
|
|
251
|
-
Save y_proba as CSV with sample_index and fold columns
|
|
252
|
-
model : object, optional
|
|
253
|
-
Trained model for feature importance extraction
|
|
254
|
-
feature_names : List[str], optional
|
|
255
|
-
Feature names for feature importance (required if model is provided)
|
|
256
|
-
|
|
257
|
-
Returns
|
|
258
|
-
-------
|
|
259
|
-
Dict[str, Any]
|
|
260
|
-
Dictionary of calculated metrics
|
|
261
|
-
"""
|
|
262
|
-
|
|
263
|
-
if verbose:
|
|
264
|
-
if fold:
|
|
265
|
-
print()
|
|
266
|
-
logger.info(f"Calculating metrics for fold #{fold:02d}...")
|
|
267
|
-
else:
|
|
268
|
-
logger.info(f"Calculating metrics...")
|
|
269
|
-
|
|
270
|
-
# Validate inputs
|
|
271
|
-
if len(y_true) != len(y_pred):
|
|
272
|
-
raise ValueError("y_true and y_pred must have same length")
|
|
273
|
-
|
|
274
|
-
if y_proba is not None and len(y_true) != len(y_proba):
|
|
275
|
-
raise ValueError("y_true and y_proba must have same length")
|
|
276
|
-
|
|
277
|
-
# Set default fold index
|
|
278
|
-
if fold is None:
|
|
279
|
-
fold = 0
|
|
280
|
-
|
|
281
|
-
# Set default labels if not provided
|
|
282
|
-
if labels is None:
|
|
283
|
-
unique_labels = sorted(np.unique(np.concatenate([y_true, y_pred])))
|
|
284
|
-
labels = [f"Class_{i}" for i in unique_labels]
|
|
285
|
-
|
|
286
|
-
# Calculate all metrics
|
|
287
|
-
metrics = {}
|
|
288
|
-
|
|
289
|
-
# Core metrics (always calculated) - pass fold to all metrics
|
|
290
|
-
metrics["balanced-accuracy"] = calc_bacc(y_true, y_pred, fold=fold)
|
|
291
|
-
metrics["mcc"] = calc_mcc(y_true, y_pred, fold=fold)
|
|
292
|
-
|
|
293
|
-
metrics["confusion_matrix"] = calc_conf_mat(
|
|
294
|
-
y_true=y_true, y_pred=y_pred, labels=labels, fold=fold
|
|
295
|
-
)
|
|
296
|
-
metrics["classification_report"] = calc_clf_report(
|
|
297
|
-
y_true, y_pred, labels, fold=fold
|
|
298
|
-
)
|
|
299
|
-
|
|
300
|
-
# AUC metrics (only if probabilities available)
|
|
301
|
-
if y_proba is not None:
|
|
302
|
-
try:
|
|
303
|
-
from scitex.ai.metrics import calc_pre_rec_auc, calc_roc_auc
|
|
304
|
-
|
|
305
|
-
metrics["roc-auc"] = calc_roc_auc(
|
|
306
|
-
y_true,
|
|
307
|
-
y_proba,
|
|
308
|
-
labels=labels,
|
|
309
|
-
fold=fold,
|
|
310
|
-
return_curve=False,
|
|
311
|
-
)
|
|
312
|
-
metrics["pr-auc"] = calc_pre_rec_auc(
|
|
313
|
-
y_true,
|
|
314
|
-
y_proba,
|
|
315
|
-
labels=labels,
|
|
316
|
-
fold=fold,
|
|
317
|
-
return_curve=False,
|
|
318
|
-
)
|
|
319
|
-
except Exception as e:
|
|
320
|
-
logger.warning(f"Could not calculate AUC metrics: {e}")
|
|
321
|
-
|
|
322
|
-
# Round all numerical values
|
|
323
|
-
metrics = self._round_numeric(metrics)
|
|
324
|
-
|
|
325
|
-
# Add labels to metrics for later use
|
|
326
|
-
metrics["labels"] = labels
|
|
327
|
-
|
|
328
|
-
if verbose:
|
|
329
|
-
logger.info(f"Metrics calculated:")
|
|
330
|
-
pprint(metrics)
|
|
331
|
-
|
|
332
|
-
# Store metrics for summary
|
|
333
|
-
self.fold_metrics[fold] = metrics.copy()
|
|
334
|
-
|
|
335
|
-
# Store predictions for overall curves
|
|
336
|
-
if y_proba is not None:
|
|
337
|
-
self.all_predictions.append(
|
|
338
|
-
{
|
|
339
|
-
"fold": fold,
|
|
340
|
-
"y_true": y_true.copy(),
|
|
341
|
-
"y_proba": y_proba.copy(),
|
|
342
|
-
}
|
|
343
|
-
)
|
|
344
|
-
|
|
345
|
-
# Save metrics if requested
|
|
346
|
-
self._save_fold_metrics(metrics, fold, labels)
|
|
347
|
-
|
|
348
|
-
# Generate plots if requested
|
|
349
|
-
self._create_plots(y_true, y_pred, y_proba, labels, fold, metrics)
|
|
350
|
-
|
|
351
|
-
# Handle feature importance automatically if model provided
|
|
352
|
-
if model is not None and feature_names is not None:
|
|
353
|
-
try:
|
|
354
|
-
from scitex.ai.feature_selection import extract_feature_importance
|
|
355
|
-
|
|
356
|
-
importance_dict = extract_feature_importance(
|
|
357
|
-
model, feature_names, method="auto"
|
|
358
|
-
)
|
|
359
|
-
if importance_dict:
|
|
360
|
-
# Store in fold metrics for cross-fold aggregation
|
|
361
|
-
metrics["feature-importance"] = importance_dict
|
|
362
|
-
self.fold_metrics[fold]["feature-importance"] = importance_dict
|
|
363
|
-
|
|
364
|
-
# Save feature importance
|
|
365
|
-
fold_dir = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
|
|
366
|
-
filename = FILENAME_PATTERNS["feature_importance_json"].format(
|
|
367
|
-
fold=fold
|
|
368
|
-
)
|
|
369
|
-
self.storage.save(importance_dict, f"{fold_dir}/{filename}")
|
|
370
|
-
|
|
371
|
-
if verbose:
|
|
372
|
-
logger.info(f" Feature importance extracted and saved")
|
|
373
|
-
except Exception as e:
|
|
374
|
-
logger.warning(f"Could not extract feature importance: {e}")
|
|
375
|
-
|
|
376
|
-
# Save raw predictions if requested (as CSV using DataFrames)
|
|
377
|
-
# Include sample_index for easy concatenation across folds
|
|
378
|
-
if store_y_true or store_y_pred or store_y_proba:
|
|
379
|
-
fold_dir = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
|
|
380
|
-
sample_indices = np.arange(len(y_true))
|
|
381
|
-
|
|
382
|
-
# Warn if file size will be large (>10MB estimated)
|
|
383
|
-
n_samples = len(y_true)
|
|
384
|
-
estimated_size_mb = 0
|
|
385
|
-
if store_y_true:
|
|
386
|
-
estimated_size_mb += n_samples * 0.0001 # ~100 bytes per row
|
|
387
|
-
if store_y_pred:
|
|
388
|
-
estimated_size_mb += n_samples * 0.0001
|
|
389
|
-
if store_y_proba and y_proba is not None:
|
|
390
|
-
n_classes = 1 if y_proba.ndim == 1 else y_proba.shape[1]
|
|
391
|
-
estimated_size_mb += n_samples * n_classes * 0.0001
|
|
392
|
-
|
|
393
|
-
if estimated_size_mb > 10:
|
|
394
|
-
logger.warning(
|
|
395
|
-
f"Storing raw predictions for fold {fold} will create ~{estimated_size_mb:.1f}MB of CSV files. "
|
|
396
|
-
f"Set store_y_true/store_y_pred/store_y_proba=False to disable."
|
|
397
|
-
)
|
|
398
|
-
|
|
399
|
-
if store_y_true:
|
|
400
|
-
filename = FILENAME_PATTERNS["y_true"].format(fold=fold)
|
|
401
|
-
# Convert to DataFrame for CSV format with index
|
|
402
|
-
df_y_true = pd.DataFrame(
|
|
403
|
-
{"sample_index": sample_indices, "fold": fold, "y_true": y_true}
|
|
404
|
-
)
|
|
405
|
-
self.storage.save(df_y_true, f"{fold_dir}/{filename}")
|
|
406
|
-
|
|
407
|
-
if store_y_pred:
|
|
408
|
-
filename = FILENAME_PATTERNS["y_pred"].format(fold=fold)
|
|
409
|
-
# Convert to DataFrame for CSV format with index
|
|
410
|
-
df_y_pred = pd.DataFrame(
|
|
411
|
-
{"sample_index": sample_indices, "fold": fold, "y_pred": y_pred}
|
|
412
|
-
)
|
|
413
|
-
self.storage.save(df_y_pred, f"{fold_dir}/{filename}")
|
|
414
|
-
|
|
415
|
-
if store_y_proba and y_proba is not None:
|
|
416
|
-
filename = FILENAME_PATTERNS["y_proba"].format(fold=fold)
|
|
417
|
-
# Convert to DataFrame for CSV format with index
|
|
418
|
-
# Handle both 1D (binary) and 2D (multiclass) probability arrays
|
|
419
|
-
if y_proba.ndim == 1:
|
|
420
|
-
df_y_proba = pd.DataFrame(
|
|
421
|
-
{
|
|
422
|
-
"sample_index": sample_indices,
|
|
423
|
-
"fold": fold,
|
|
424
|
-
"y_proba": y_proba,
|
|
425
|
-
}
|
|
426
|
-
)
|
|
427
|
-
else:
|
|
428
|
-
# Create column names for each class
|
|
429
|
-
data = {"sample_index": sample_indices, "fold": fold}
|
|
430
|
-
for i in range(y_proba.shape[1]):
|
|
431
|
-
data[f"proba_class_{i}"] = y_proba[:, i]
|
|
432
|
-
df_y_proba = pd.DataFrame(data)
|
|
433
|
-
self.storage.save(df_y_proba, f"{fold_dir}/{filename}")
|
|
434
|
-
|
|
435
|
-
return metrics
|
|
436
|
-
|
|
437
|
-
def _save_fold_metrics(
    self, metrics: Dict[str, Any], fold: int, labels: List[str]
) -> None:
    """Save metrics for a specific fold in shallow directory structure.

    Writes, under the fold directory:

    - ``confusion_matrix`` as CSV with ``True_<label>`` rows and
      ``Pred_<label>`` columns (balanced accuracy in the filename when known);
    - ``classification_report`` as CSV, falling back to JSON for a dict that
      cannot be tabulated/saved, or to text for any other type;
    - ``balanced-accuracy``, ``mcc``, ``roc-auc`` and ``pr-auc`` through
      ``save_metric``, with the numeric value embedded in the filename (a
      scalar metric whose value cannot be extracted is not saved).

    Other metric keys are not written by this method.

    Parameters
    ----------
    metrics : Dict[str, Any]
        Metric name -> value; a value may be wrapped as ``{"value": ...}``.
    fold : int
        Fold index used for the directory and the filenames.
    labels : List[str]
        Class labels naming the confusion-matrix rows and columns.
    """
    fold_dir = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)

    # Numeric values of the scalar metrics, embedded in their filenames.
    scalar_values = {
        name: self._extract_metric_value(metrics.get(name))
        for name in ("balanced-accuracy", "mcc", "roc-auc", "pr-auc")
    }
    balanced_acc = scalar_values["balanced-accuracy"]

    for metric_name, metric_value in metrics.items():
        # Extract actual value if it's a wrapped dict with 'value' key
        if isinstance(metric_value, dict) and "value" in metric_value:
            actual_value = metric_value["value"]
        else:
            actual_value = metric_value

        if metric_name == "confusion_matrix":
            try:
                row_names = [f"True_{label}" for label in labels]
                col_names = [f"Pred_{label}" for label in labels]
                if isinstance(actual_value, pd.DataFrame):
                    # Already a DataFrame (from calc_conf_mat): relabel a copy
                    cm_df = actual_value.copy()
                    cm_df.index = row_names
                    cm_df.columns = col_names
                else:
                    # Fallback for numpy array / array-like input
                    cm_df = pd.DataFrame(
                        actual_value, index=row_names, columns=col_names
                    )
            except Exception as e:
                # e.g. the number of labels does not match the matrix shape
                logger.error(f"Error formatting confusion matrix: {e}")
                cm_df = None

            if cm_df is not None:
                if balanced_acc is not None:
                    cm_filename = FILENAME_PATTERNS["confusion_matrix_csv"].format(
                        fold=fold, bacc=balanced_acc
                    )
                else:
                    cm_filename = FILENAME_PATTERNS[
                        "confusion_matrix_csv_no_bacc"
                    ].format(fold=fold)
                # index=True preserves the True_<label> row labels
                self.storage.save(cm_df, f"{fold_dir}/{cm_filename}", index=True)

        elif metric_name == "classification_report":
            report_filename = FILENAME_PATTERNS["classification_report"].format(
                fold=fold
            )
            if isinstance(actual_value, pd.DataFrame):
                # Move the index into an ordinary, named column
                report_df = actual_value.reset_index().rename(
                    columns={"index": "class"}
                )
                self.storage.save(report_df, f"{fold_dir}/{report_filename}")
            elif isinstance(actual_value, dict):
                try:
                    report_df = pd.DataFrame(actual_value).transpose()
                    self.storage.save(report_df, f"{fold_dir}/{report_filename}")
                except Exception as e:
                    # Not tabular (or the CSV save failed): keep the raw dict
                    # as JSON instead. Exception (not a bare except) so that
                    # KeyboardInterrupt/SystemExit still propagate.
                    logger.warning(
                        f"Could not save classification report for fold {fold} "
                        f"as CSV ({e}); saving as JSON instead."
                    )
                    json_filename = FILENAME_PATTERNS[
                        "classification_report_json"
                    ].format(fold=fold)
                    self.storage.save(actual_value, f"{fold_dir}/{json_filename}")
            else:
                # String or other format
                txt_filename = FILENAME_PATTERNS[
                    "classification_report_txt"
                ].format(fold=fold)
                self.storage.save(actual_value, f"{fold_dir}/{txt_filename}")

        elif scalar_values.get(metric_name) is not None:
            # Scalar metric: save with its value in the filename
            filename = FILENAME_PATTERNS["fold_metric_with_value"].format(
                fold=fold,
                metric_name=metric_name,
                value=scalar_values[metric_name],
            )
            save_metric(
                actual_value,
                self.output_dir / f"{fold_dir}/{filename}",
                fold=fold,
                precision=self.precision,
            )
|
|
569
|
-
|
|
570
|
-
def _extract_metric_value(self, metric_data: Any) -> Optional[float]:
    """Extract numeric value from metric data.

    Accepts a plain number (Python or NumPy scalar) or a wrapper dict of the
    form ``{"value": x}``. Returns ``float(x)`` when it converts cleanly, and
    None when the data is None, the wrapper has no ``"value"`` key, the
    wrapped value is None or not convertible, or the type is unsupported.
    Callers treat None as "metric unavailable".
    """
    if metric_data is None:
        return None
    if isinstance(metric_data, dict):
        if "value" not in metric_data:
            return None
        raw = metric_data["value"]
        if raw is None:
            # e.g. a metric that could not be computed for this fold
            return None
        try:
            return float(raw)
        except (TypeError, ValueError):
            return None
    if isinstance(metric_data, (int, float, np.number)):
        return float(metric_data)
    return None
|
|
579
|
-
|
|
580
|
-
def _save_curve_data(
    self,
    y_true: np.ndarray,
    y_proba: Optional[np.ndarray],
    fold: int,
    metrics: Dict[str, Any],
) -> None:
    """Write one fold's binary ROC and PR curve points to CSV.

    Returns immediately when ``y_proba`` is None. Only binary problems are
    written: ``y_proba`` must be 1-D, or 2-D with two columns (column 1 is
    used as the positive-class score); other shapes write nothing. The AUC
    and average precision are recomputed here and embedded in the
    filenames. ``metrics`` is accepted but not read.
    """
    if y_proba is None:
        return

    from sklearn.metrics import (
        auc,
        average_precision_score,
        precision_recall_curve,
        roc_curve,
    )

    fold_dir = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)

    two_class = y_proba.ndim == 1 or y_proba.shape[1] == 2
    if not two_class:
        return

    positive_scores = y_proba if y_proba.ndim != 2 else y_proba[:, 1]

    # sklearn's curve functions want integer-coded labels
    from scitex.ai.metrics import _normalize_labels

    encoded_true, _, _, _ = _normalize_labels(y_true, y_true)

    # --- ROC ---
    false_pos_rate, true_pos_rate, _ = roc_curve(encoded_true, positive_scores)
    area = auc(false_pos_rate, true_pos_rate)
    roc_table = pd.DataFrame({"FPR": false_pos_rate, "TPR": true_pos_rate})
    roc_name = FILENAME_PATTERNS["roc_curve_csv"].format(fold=fold, auc=area)
    self.storage.save(roc_table, f"{fold_dir}/{roc_name}")

    # --- Precision-recall ---
    prec_points, rec_points, _ = precision_recall_curve(
        encoded_true, positive_scores
    )
    ap_score = average_precision_score(encoded_true, positive_scores)
    pr_table = pd.DataFrame({"Recall": rec_points, "Precision": prec_points})
    pr_name = FILENAME_PATTERNS["pr_curve_csv"].format(fold=fold, ap=ap_score)
    self.storage.save(pr_table, f"{fold_dir}/{pr_name}")
|
|
638
|
-
|
|
639
|
-
def _create_plots(
    self,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_proba: Optional[np.ndarray],
    labels: List[str],
    fold: int,
    metrics: Dict[str, Any],
) -> None:
    """Create and save plots with metric-based filenames in unified structure.

    Saves into the fold directory:

    - a confusion-matrix plot, when ``metrics`` contains ``confusion_matrix``;
    - ROC and precision-recall curves, when ``y_proba`` is given;
    - a metrics dashboard figure (always).

    Metric values (``balanced-accuracy``, ``roc-auc``, ``pr-auc``) may be plain
    numbers or wrapped as ``{"value": ...}``. A missing, None or non-numeric
    value selects the filename variant without the metric, instead of being
    formatted into the filename.
    """

    def _as_float(value: Any) -> Optional[float]:
        # Unwrap {"value": x}; return x as float, or None when not numeric.
        if isinstance(value, dict):
            value = value.get("value")
        if isinstance(value, (int, float, np.number)):
            return float(value)
        return None

    # Use unified fold directory
    fold_dir = self._create_subdir_if_needed(
        FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
    )
    fold_dir.mkdir(parents=True, exist_ok=True)

    # Confusion matrix plot with balanced accuracy in title and filename
    if "confusion_matrix" in metrics:
        cm_data = metrics["confusion_matrix"]
        if isinstance(cm_data, dict) and "value" in cm_data:
            cm_data = cm_data["value"]

        balanced_acc = _as_float(metrics.get("balanced-accuracy"))
        if balanced_acc is not None:
            title = (
                f"Confusion Matrix (Fold {fold:02d}) - "
                f"Balanced Acc: {balanced_acc:.3f}"
            )
            filename = FILENAME_PATTERNS["confusion_matrix_jpg"].format(
                fold=fold, bacc=balanced_acc
            )
        else:
            title = f"Confusion Matrix (Fold {fold:02d})"
            filename = FILENAME_PATTERNS["confusion_matrix_jpg_no_bacc"].format(
                fold=fold
            )

        self.plotter.create_confusion_matrix_plot(
            cm_data,
            labels=labels,
            save_path=fold_dir / filename,
            title=title,
        )

    # ROC and PR curves need probabilities
    if y_proba is not None:
        roc_auc_val = _as_float(metrics.get("roc-auc"))
        if roc_auc_val is not None:
            roc_filename = FILENAME_PATTERNS["roc_curve_jpg"].format(
                fold=fold, auc=roc_auc_val
            )
        else:
            roc_filename = FILENAME_PATTERNS["roc_curve_jpg_no_auc"].format(
                fold=fold
            )

        self.plotter.create_roc_curve(
            y_true,
            y_proba,
            labels=labels,
            save_path=fold_dir / roc_filename,
            title=f"ROC Curve (Fold {fold:02d})",
        )

        pr_auc_val = _as_float(metrics.get("pr-auc"))
        if pr_auc_val is not None:
            pr_filename = FILENAME_PATTERNS["pr_curve_jpg"].format(
                fold=fold, ap=pr_auc_val
            )
        else:
            pr_filename = FILENAME_PATTERNS["pr_curve_jpg_no_ap"].format(fold=fold)

        self.plotter.create_precision_recall_curve(
            y_true,
            y_proba,
            labels=labels,
            save_path=fold_dir / pr_filename,
            title=f"Precision-Recall Curve (Fold {fold:02d})",
        )

    # Comprehensive metrics dashboard (multi-panel figure built by the plotter)
    summary_filename = FILENAME_PATTERNS["metrics_summary"].format(fold=fold)
    self.plotter.create_metrics_visualization(
        metrics=metrics,
        y_true=y_true,
        y_pred=y_pred,
        y_proba=y_proba,
        labels=labels,
        save_path=fold_dir / summary_filename,
        title="Classification Metrics Dashboard",
        fold=fold,
        verbose=False,  # Already have verbose output from individual plots
    )
|
|
747
|
-
|
|
748
|
-
def get_summary(self) -> Dict[str, Any]:
    """
    Get summary of all calculated metrics across folds.

    The summary contains:

    - ``output_dir`` and ``total_folds``;
    - ``overall_confusion_matrix``: the fold confusion matrices summed
      element-wise, plus ``overall_confusion_matrix_normalized`` (fractions
      of the grand total; all zeros when the total is 0);
    - ``metrics_summary``: mean / std (population, NumPy default ddof=0) /
      min / max / values of ``balanced-accuracy``, ``mcc``, ``roc-auc`` and
      ``pr-auc`` over the folds that have a non-None value;
    - ``feature-importance``: aggregated importances, when any fold has one.

    Returns
    -------
    Dict[str, Any]
        Summary statistics across all folds, or ``{"error": ...}`` when no
        fold metrics have been calculated yet.
    """
    if not self.fold_metrics:
        return {"error": "No metrics calculated yet"}

    summary = {
        "output_dir": str(self.output_dir),
        "total_folds": len(self.fold_metrics),
        "metrics_summary": {},
    }

    def _unwrap(value: Any) -> Any:
        # Metric values may be stored as {"value": x}
        if isinstance(value, dict) and "value" in value:
            return value["value"]
        return value

    # Aggregate confusion matrices across all folds
    confusion_matrices = []
    for fold_metrics in self.fold_metrics.values():
        if "confusion_matrix" in fold_metrics:
            cm_data = _unwrap(fold_metrics["confusion_matrix"])
            if cm_data is not None:
                confusion_matrices.append(cm_data)

    if confusion_matrices:
        # Sum all confusion matrices to get overall counts
        overall_cm = np.sum(confusion_matrices, axis=0)
        summary["overall_confusion_matrix"] = overall_cm.tolist()

        # Normalized version; guard the all-zero case, which would otherwise
        # produce NaNs (0/0) and a RuntimeWarning.
        total = overall_cm.sum()
        if total:
            overall_cm_normalized = overall_cm / total
        else:
            overall_cm_normalized = np.zeros_like(overall_cm, dtype=float)
        summary["overall_confusion_matrix_normalized"] = self._round_numeric(
            overall_cm_normalized.tolist()
        )

    # Summary statistics for scalar metrics
    scalar_metrics = ["balanced-accuracy", "mcc", "roc-auc", "pr-auc"]
    for metric_name in scalar_metrics:
        values = []
        for fold_metrics in self.fold_metrics.values():
            if metric_name in fold_metrics:
                metric_val = _unwrap(fold_metrics[metric_name])
                # A None value (metric not computable for that fold) would
                # make np.mean/np.std fail on an object array.
                if metric_val is not None:
                    values.append(metric_val)

        if values:
            values = np.array(values)
            summary["metrics_summary"][metric_name] = {
                "mean": self._round_numeric(np.mean(values)),
                "std": self._round_numeric(np.std(values)),
                "min": self._round_numeric(np.min(values)),
                "max": self._round_numeric(np.max(values)),
                "values": self._round_numeric(values.tolist()),
            }

    # Aggregate feature importance across folds
    feature_importances_list = [
        fold_metrics["feature-importance"]
        for fold_metrics in self.fold_metrics.values()
        if "feature-importance" in fold_metrics
    ]
    if feature_importances_list:
        from scitex.ai.feature_selection import aggregate_feature_importances

        summary["feature-importance"] = aggregate_feature_importances(
            feature_importances_list
        )

    return summary
|
|
830
|
-
|
|
831
|
-
def create_cv_aggregation_visualizations(
    self,
    output_dir: Optional[Path] = None,
    show_individual_folds: bool = True,
    fold_alpha: float = 0.15,
) -> None:
    """
    Create CV aggregation visualizations with faded individual fold lines.

    Draws one ROC and one precision-recall figure from all stored fold
    predictions through ``self.plotter.create_cv_aggregation_plot``. The
    figures are saved as ``roc_cv_aggregation_n<n_folds>.jpg`` and
    ``pr_cv_aggregation_n<n_folds>.jpg``. They are intended to show faded
    per-fold curves, a bold mean curve and a spread band; the rendering
    itself is done by the plotter.

    Parameters
    ----------
    output_dir : Path or str, optional
        Directory to save plots (defaults to cv_summary). It is created if it
        does not exist; a ``str`` is accepted and converted to ``Path``.
    show_individual_folds : bool, default True
        Whether to show individual fold curves
    fold_alpha : float, default 0.15
        Transparency for individual fold curves (0-1)
    """
    if not self.all_predictions:
        logger.warning("No predictions stored for CV aggregation visualizations")
        return

    if output_dir is None:
        output_dir = self._create_subdir_if_needed("cv_summary")
    # Coerce so that `output_dir / name` also works for str paths, and make
    # sure a caller-supplied directory exists before plots are written to it.
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    n_folds = len(self.all_predictions)

    curve_specs = (
        (
            "roc",
            f"roc_cv_aggregation_n{n_folds}.jpg",
            f"ROC Curves - Cross Validation (n={n_folds} folds)",
            "Created CV aggregation ROC plot with faded fold lines",
        ),
        (
            "pr",
            f"pr_cv_aggregation_n{n_folds}.jpg",
            f"Precision-Recall Curves - Cross Validation (n={n_folds} folds)",
            "Created CV aggregation PR plot with faded fold lines",
        ),
    )
    for curve_type, filename, title, done_message in curve_specs:
        self.plotter.create_cv_aggregation_plot(
            fold_predictions=self.all_predictions,
            curve_type=curve_type,
            save_path=output_dir / filename,
            show_individual_folds=show_individual_folds,
            fold_alpha=fold_alpha,
            title=title,
            verbose=True,
        )
        logger.info(done_message)
|
|
889
|
-
|
|
890
|
-
def save_feature_importance(
    self,
    model,
    feature_names: List[str],
    fold: Optional[int] = None,
) -> Dict[str, float]:
    """
    Compute, persist and plot feature importances for a fitted model.

    The importances come from ``scitex.ai.metrics.calc_feature_importance``.
    They are saved as JSON and drawn as a bar plot under the fold directory,
    or under ``cv_summary`` when ``fold`` is None.

    Parameters
    ----------
    model : object
        Fitted classifier (must have feature_importances_ or coef_)
    feature_names : List[str]
        Names of features
    fold : int, optional
        Fold number for tracking

    Returns
    -------
    Dict[str, float]
        ``{feature_name: importance}`` as returned by
        ``calc_feature_importance``; an empty dict if it raised ValueError.
    """
    from scitex.ai.metrics import calc_feature_importance

    try:
        importance_dict, importances = calc_feature_importance(model, feature_names)
    except ValueError as e:
        logger.warning(f"Could not extract feature importance: {e}")
        return {}

    for_cv_summary = fold is None
    if for_cv_summary:
        target_subdir = "cv_summary"
    else:
        target_subdir = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)

    # calc_feature_importance already returns entries in sorted order;
    # persist a plain-dict copy of it.
    json_name = FILENAME_PATTERNS["feature_importance_json"].format(fold=fold)
    self.storage.save(dict(importance_dict), f"{target_subdir}/{json_name}")

    jpg_name = FILENAME_PATTERNS["feature_importance_jpg"].format(fold=fold)
    plot_path = self.output_dir / target_subdir / jpg_name
    plot_path.parent.mkdir(parents=True, exist_ok=True)

    if for_cv_summary:
        plot_title = "Feature Importance (CV Summary)"
    else:
        plot_title = f"Feature Importance (Fold {fold:02d})"

    self.plotter.create_feature_importance_plot(
        feature_importance=importances,
        feature_names=feature_names,
        save_path=plot_path,
        title=plot_title,
    )

    fold_suffix = "" if for_cv_summary else f" for fold {fold}"
    logger.info("Saved feature importance" + fold_suffix)

    return importance_dict
|
|
956
|
-
|
|
957
|
-
def save_feature_importance_summary(
    self,
    all_importances: List[Dict[str, float]],
) -> None:
    """
    Create summary visualization of feature importances across all folds.

    Writes ``cv_summary/<cv_summary_feature_importance_json>``, which maps
    each feature to its mean, std and per-fold values. Features are sorted by
    descending mean, and a feature absent from a fold counts as 0 for that
    fold. Also draws a summary figure with
    ``scitex.ai.plt.plot_feature_importance_cv_summary``.

    Parameters
    ----------
    all_importances : List[Dict[str, float]]
        List of feature importance dicts from each fold. Nothing is done
        when the list is empty.
    """
    if not all_importances:
        return

    # Union of feature names in first-seen order. Collecting them in a set
    # would make the output order of features with tied means depend on
    # string hash randomization, i.e. vary between interpreter runs.
    all_features = list(
        dict.fromkeys(name for imp_dict in all_importances for name in imp_dict)
    )

    # Calculate mean and std for each feature
    feature_stats = {}
    for feature in all_features:
        values = [imp_dict.get(feature, 0) for imp_dict in all_importances]
        feature_stats[feature] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "values": [float(v) for v in values],
        }

    # Sort by mean importance (sorted is stable, so ties keep first-seen order)
    sorted_features = sorted(
        feature_stats.items(), key=lambda item: item[1]["mean"], reverse=True
    )

    n_folds = len(all_importances)
    json_filename = FILENAME_PATTERNS["cv_summary_feature_importance_json"].format(
        n_folds=n_folds
    )
    self.storage.save(
        dict(sorted_features),
        f"cv_summary/{json_filename}",
    )

    # Create visualization using centralized plotting function
    from scitex.ai.plt import plot_feature_importance_cv_summary

    jpg_filename = FILENAME_PATTERNS["cv_summary_feature_importance_jpg"].format(
        n_folds=n_folds
    )
    save_path = self.output_dir / "cv_summary" / jpg_filename
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # The returned figure is not needed here.
    plot_feature_importance_cv_summary(
        all_importances=all_importances,
        spath=save_path,
    )

    logger.info("Saved feature importance summary")
|
|
1017
|
-
|
|
1018
|
-
def create_cv_summary_curves(self, summary: Dict[str, Any]) -> None:
    """
    Create CV summary ROC and PR curves from aggregated predictions.

    Pools ``y_true`` / ``y_proba`` from every stored fold, saves the pooled
    curve data as CSV and plots the pooled ROC and PR curves. Titles and
    filenames report the mean ± std of the per-fold ``roc-auc`` / ``pr-auc``
    metrics. When no fold has a usable value, the metric is computed once on
    the pooled predictions and the std is reported as 0. Nothing is drawn
    (a warning is logged) when there are no predictions or any fold lacks
    probabilities.

    Parameters
    ----------
    summary : Dict[str, Any]
        Not read; kept for interface compatibility.
    """
    if not self.all_predictions:
        logger.warning("No predictions stored for CV summary curves")
        return

    # Curves need probabilities from every fold; np.concatenate would fail
    # on a None entry.
    n_missing = sum(1 for p in self.all_predictions if p["y_proba"] is None)
    if n_missing:
        logger.warning(
            f"Skipping CV summary curves: y_proba missing for {n_missing} fold(s)"
        )
        return

    # Aggregate all predictions
    all_y_true = np.concatenate([p["y_true"] for p in self.all_predictions])
    all_y_proba = np.concatenate([p["y_proba"] for p in self.all_predictions])

    def _fold_values(metric_name: str) -> List[Any]:
        # Per-fold values of metric_name, unwrapped from {"value": ...};
        # None values are skipped so np.mean/np.std do not fail.
        collected = []
        for fold_metrics in self.fold_metrics.values():
            if metric_name not in fold_metrics:
                continue
            val = fold_metrics[metric_name]
            if isinstance(val, dict) and "value" in val:
                val = val["value"]
            if val is not None:
                collected.append(val)
        return collected

    roc_values = _fold_values("roc-auc")
    pr_values = _fold_values("pr-auc")

    # Calculate mean and std
    n_folds = len(self.fold_metrics)
    if roc_values:
        roc_mean = np.mean(roc_values)
        roc_std = np.std(roc_values)
    else:
        from .reporter_utils.metrics import calc_roc_auc

        overall_roc = calc_roc_auc(all_y_true, all_y_proba)
        roc_mean = overall_roc["value"]
        roc_std = 0.0

    if pr_values:
        pr_mean = np.mean(pr_values)
        pr_std = np.std(pr_values)
    else:
        from .reporter_utils.metrics import calc_pre_rec_auc

        overall_pr = calc_pre_rec_auc(all_y_true, all_y_proba)
        pr_mean = overall_pr["value"]
        pr_std = 0.0

    # Create cv_summary directory
    cv_summary_dir = self._create_subdir_if_needed("cv_summary")
    cv_summary_dir.mkdir(parents=True, exist_ok=True)

    # Save CV summary curve data
    self._save_cv_summary_curve_data(
        all_y_true,
        all_y_proba,
        roc_mean,
        roc_std,
        pr_mean,
        pr_std,
        n_folds,
    )

    # Normalize labels to integers for sklearn curve functions in plotter
    from scitex.ai.metrics import _normalize_labels

    all_y_true_norm, _, label_names, _ = _normalize_labels(all_y_true, all_y_true)

    # ROC Curve with mean±std and n_folds in filename
    roc_title = (
        f"ROC Curve (CV Summary) - AUC: {roc_mean:.3f} ± {roc_std:.3f} (n={n_folds})"
    )
    roc_filename = FILENAME_PATTERNS["cv_summary_roc_curve_jpg"].format(
        mean=roc_mean, std=roc_std, n_folds=n_folds
    )
    self.plotter.create_overall_roc_curve(
        all_y_true_norm,
        all_y_proba,
        labels=label_names,
        save_path=cv_summary_dir / roc_filename,
        title=roc_title,
        auc_mean=roc_mean,
        auc_std=roc_std,
        verbose=True,
    )

    # PR Curve with mean±std and n_folds in filename
    pr_title = (
        f"Precision-Recall Curve (CV Summary) - AP: {pr_mean:.3f} ± {pr_std:.3f} "
        f"(n={n_folds})"
    )
    pr_filename = FILENAME_PATTERNS["cv_summary_pr_curve_jpg"].format(
        mean=pr_mean, std=pr_std, n_folds=n_folds
    )
    self.plotter.create_overall_pr_curve(
        all_y_true_norm,
        all_y_proba,
        labels=label_names,
        save_path=cv_summary_dir / pr_filename,
        title=pr_title,
        ap_mean=pr_mean,
        ap_std=pr_std,
        verbose=True,
    )

    logger.info(
        f"Created CV summary ROC curve: AUC = {roc_mean:.3f} ± {roc_std:.3f} (n={n_folds})"
    )
    logger.info(
        f"Created CV summary PR curve: AP = {pr_mean:.3f} ± {pr_std:.3f} (n={n_folds})"
    )
|
|
1127
|
-
|
|
1128
|
-
def _save_cv_summary_curve_data(
    self,
    y_true: np.ndarray,
    y_proba: np.ndarray,
    roc_mean: float,
    roc_std: float,
    pr_mean: float,
    pr_std: float,
    n_folds: int,
) -> None:
    """Save CV summary ROC and PR curve data as CSV files with metric values in filenames.

    Only binary problems are written: ``y_proba`` must be 1-D, or 2-D with
    two columns (column 1 is used as the positive-class score); otherwise
    nothing is saved. Filenames carry the supplied per-fold ``mean``/``std``
    and ``n_folds``, not the AUC/AP of the pooled curve, so those are not
    computed here.
    """
    # Multiclass: nothing to write. Checked before importing sklearn.
    if not (y_proba.ndim == 1 or y_proba.shape[1] == 2):
        return

    from sklearn.metrics import precision_recall_curve, roc_curve

    from scitex.ai.metrics import _normalize_labels

    cv_summary_dir = "cv_summary"
    y_proba_pos = y_proba[:, 1] if y_proba.ndim == 2 else y_proba

    # Normalize labels to integers for sklearn curve functions
    y_true_norm, _, _, _ = _normalize_labels(y_true, y_true)

    # ROC curve data (FPR/TPR columns)
    fpr, tpr, _ = roc_curve(y_true_norm, y_proba_pos)
    roc_df = pd.DataFrame({"FPR": fpr, "TPR": tpr})
    roc_filename = FILENAME_PATTERNS["cv_summary_roc_curve_csv"].format(
        mean=roc_mean, std=roc_std, n_folds=n_folds
    )
    self.storage.save(roc_df, f"{cv_summary_dir}/{roc_filename}")

    # PR curve data (Recall/Precision columns)
    precision, recall, _ = precision_recall_curve(y_true_norm, y_proba_pos)
    pr_df = pd.DataFrame({"Recall": recall, "Precision": precision})
    pr_filename = FILENAME_PATTERNS["cv_summary_pr_curve_csv"].format(
        mean=pr_mean, std=pr_std, n_folds=n_folds
    )
    self.storage.save(pr_df, f"{cv_summary_dir}/{pr_filename}")
|
|
1186
|
-
|
|
1187
|
-
def save_cv_summary_confusion_matrix(self, summary: Dict[str, Any]) -> None:
|
|
1188
|
-
"""
|
|
1189
|
-
Save and plot the CV summary confusion matrix.
|
|
1190
|
-
|
|
1191
|
-
Parameters
|
|
1192
|
-
----------
|
|
1193
|
-
summary : Dict[str, Any]
|
|
1194
|
-
Summary dictionary containing overall_confusion_matrix
|
|
1195
|
-
"""
|
|
1196
|
-
# Aggregate confusion matrices across all folds
|
|
1197
|
-
confusion_matrices = []
|
|
1198
|
-
for fold_metrics in self.fold_metrics.values():
|
|
1199
|
-
if "confusion_matrix" in fold_metrics:
|
|
1200
|
-
cm_data = fold_metrics["confusion_matrix"]
|
|
1201
|
-
# Extract actual value if wrapped in dict
|
|
1202
|
-
if isinstance(cm_data, dict) and "value" in cm_data:
|
|
1203
|
-
cm_data = cm_data["value"]
|
|
1204
|
-
if cm_data is not None:
|
|
1205
|
-
confusion_matrices.append(cm_data)
|
|
1206
|
-
|
|
1207
|
-
if not confusion_matrices:
|
|
1208
|
-
return
|
|
1209
|
-
|
|
1210
|
-
# Sum all confusion matrices to get overall counts
|
|
1211
|
-
overall_cm = np.sum(confusion_matrices, axis=0)
|
|
1212
|
-
|
|
1213
|
-
# Get labels from one of the folds
|
|
1214
|
-
labels = None
|
|
1215
|
-
for fold_metrics in self.fold_metrics.values():
|
|
1216
|
-
# Labels are stored directly in fold_metrics now
|
|
1217
|
-
if "labels" in fold_metrics:
|
|
1218
|
-
labels = fold_metrics["labels"]
|
|
1219
|
-
break
|
|
1220
|
-
# Fallback: check if labels are in confusion_matrix dict
|
|
1221
|
-
elif "confusion_matrix" in fold_metrics:
|
|
1222
|
-
cm_data = fold_metrics["confusion_matrix"]
|
|
1223
|
-
if isinstance(cm_data, dict) and "labels" in cm_data:
|
|
1224
|
-
labels = cm_data["labels"]
|
|
1225
|
-
break
|
|
1226
|
-
|
|
1227
|
-
# Save as CSV with labels in cv_summary directory
|
|
1228
|
-
cv_summary_dir = self._create_subdir_if_needed("cv_summary")
|
|
1229
|
-
cv_summary_dir.mkdir(parents=True, exist_ok=True)
|
|
1230
|
-
|
|
1231
|
-
# Get balanced accuracy stats for filename
|
|
1232
|
-
balanced_acc_mean = None
|
|
1233
|
-
balanced_acc_std = None
|
|
1234
|
-
n_folds = len(self.fold_metrics)
|
|
1235
|
-
if "metrics_summary" in summary:
|
|
1236
|
-
if "balanced-accuracy" in summary["metrics_summary"]:
|
|
1237
|
-
balanced_acc_stats = summary["metrics_summary"]["balanced-accuracy"]
|
|
1238
|
-
balanced_acc_mean = balanced_acc_stats.get("mean")
|
|
1239
|
-
balanced_acc_std = balanced_acc_stats.get("std")
|
|
1240
|
-
|
|
1241
|
-
# Create filename with mean±std and n_folds
|
|
1242
|
-
if balanced_acc_mean is not None and balanced_acc_std is not None:
|
|
1243
|
-
cm_filename = FILENAME_PATTERNS["cv_summary_confusion_matrix_csv"].format(
|
|
1244
|
-
mean=balanced_acc_mean, std=balanced_acc_std, n_folds=n_folds
|
|
1245
|
-
)
|
|
1246
|
-
else:
|
|
1247
|
-
cm_filename = FILENAME_PATTERNS[
|
|
1248
|
-
"cv_summary_confusion_matrix_csv_no_bacc"
|
|
1249
|
-
].format(n_folds=n_folds)
|
|
1250
|
-
|
|
1251
|
-
if labels:
|
|
1252
|
-
cm_df = pd.DataFrame(
|
|
1253
|
-
overall_cm,
|
|
1254
|
-
index=[f"True_{label}" for label in labels],
|
|
1255
|
-
columns=[f"Pred_{label}" for label in labels],
|
|
1256
|
-
)
|
|
1257
|
-
else:
|
|
1258
|
-
cm_df = pd.DataFrame(overall_cm)
|
|
1259
|
-
|
|
1260
|
-
# Save with proper filename (with index=True to preserve row labels)
|
|
1261
|
-
self.storage.save(cm_df, f"cv_summary/{cm_filename}", index=True)
|
|
1262
|
-
|
|
1263
|
-
# Create plot for CV summary confusion matrix
|
|
1264
|
-
cv_summary_dir = self._create_subdir_if_needed("cv_summary")
|
|
1265
|
-
cv_summary_dir.mkdir(parents=True, exist_ok=True)
|
|
1266
|
-
|
|
1267
|
-
# Calculate balanced accuracy mean and std for overall confusion matrix title
|
|
1268
|
-
balanced_acc_mean = None
|
|
1269
|
-
balanced_acc_std = None
|
|
1270
|
-
if "metrics_summary" in self.get_summary():
|
|
1271
|
-
metrics_summary = self.get_summary()["metrics_summary"]
|
|
1272
|
-
if "balanced-accuracy" in metrics_summary:
|
|
1273
|
-
balanced_acc_stats = metrics_summary["balanced-accuracy"]
|
|
1274
|
-
balanced_acc_mean = balanced_acc_stats.get("mean")
|
|
1275
|
-
balanced_acc_std = balanced_acc_stats.get("std")
|
|
1276
|
-
|
|
1277
|
-
# Create title with balanced accuracy stats and filename with mean±std and n_folds
|
|
1278
|
-
if balanced_acc_mean is not None and balanced_acc_std is not None:
|
|
1279
|
-
title = f"Confusion Matrix (CV Summary) - Balanced Acc: {balanced_acc_mean:.3f} ± {balanced_acc_std:.3f} (n={n_folds})"
|
|
1280
|
-
filename = FILENAME_PATTERNS["cv_summary_confusion_matrix_jpg"].format(
|
|
1281
|
-
mean=balanced_acc_mean, std=balanced_acc_std, n_folds=n_folds
|
|
1282
|
-
)
|
|
1283
|
-
else:
|
|
1284
|
-
title = f"Confusion Matrix (CV Summary) (n={n_folds})"
|
|
1285
|
-
filename = FILENAME_PATTERNS[
|
|
1286
|
-
"cv_summary_confusion_matrix_jpg_no_bacc"
|
|
1287
|
-
].format(n_folds=n_folds)
|
|
1288
|
-
|
|
1289
|
-
# Create the plot with enhanced title
|
|
1290
|
-
self.plotter.create_confusion_matrix_plot(
|
|
1291
|
-
overall_cm,
|
|
1292
|
-
labels=labels,
|
|
1293
|
-
save_path=cv_summary_dir / filename,
|
|
1294
|
-
title=title,
|
|
1295
|
-
)
|
|
1296
|
-
|
|
1297
|
-
def generate_reports(self) -> Dict[str, Path]:
|
|
1298
|
-
"""
|
|
1299
|
-
Generate comprehensive reports in multiple formats.
|
|
1300
|
-
|
|
1301
|
-
Returns
|
|
1302
|
-
-------
|
|
1303
|
-
Dict[str, Path]
|
|
1304
|
-
Paths to generated report files
|
|
1305
|
-
"""
|
|
1306
|
-
# Prepare results dictionary for report generation
|
|
1307
|
-
results = {
|
|
1308
|
-
"config": {
|
|
1309
|
-
"n_folds": len(self.fold_metrics),
|
|
1310
|
-
"output_dir": str(self.output_dir),
|
|
1311
|
-
},
|
|
1312
|
-
"session_config": self.session_config, # Pass the SciTeX CONFIG
|
|
1313
|
-
"summary": {},
|
|
1314
|
-
"folds": [],
|
|
1315
|
-
"plots": {},
|
|
1316
|
-
}
|
|
1317
|
-
|
|
1318
|
-
# Get summary statistics
|
|
1319
|
-
summary = self.get_summary()
|
|
1320
|
-
|
|
1321
|
-
# Extract summary statistics for reporting
|
|
1322
|
-
if "metrics_summary" in summary:
|
|
1323
|
-
results["summary"] = summary["metrics_summary"]
|
|
1324
|
-
|
|
1325
|
-
# Add feature importance if available
|
|
1326
|
-
if "feature-importance" in summary:
|
|
1327
|
-
results["summary"]["feature-importance"] = summary["feature-importance"]
|
|
1328
|
-
|
|
1329
|
-
# Add per-fold results
|
|
1330
|
-
for fold, fold_data in self.fold_metrics.items():
|
|
1331
|
-
fold_result = {"fold_id": fold}
|
|
1332
|
-
fold_result.update(fold_data)
|
|
1333
|
-
|
|
1334
|
-
# Try to load sample size info from features.json
|
|
1335
|
-
# scitex.io.save transforms relative paths: adds storage_out to calling file's dir
|
|
1336
|
-
try:
|
|
1337
|
-
import json
|
|
1338
|
-
|
|
1339
|
-
# Construct the storage_out path where scitex.io.save actually saves files
|
|
1340
|
-
# Pattern: {calling_file_dir}/storage_out/{relative_path}
|
|
1341
|
-
calling_file_dir = Path(__file__).parent / "reporter_utils"
|
|
1342
|
-
storage_out_path = (
|
|
1343
|
-
calling_file_dir
|
|
1344
|
-
/ "storage_out"
|
|
1345
|
-
/ self.output_dir
|
|
1346
|
-
/ FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
|
|
1347
|
-
/ "features.json"
|
|
1348
|
-
)
|
|
1349
|
-
|
|
1350
|
-
# Also try regular path in case storage behavior changes
|
|
1351
|
-
regular_path = (
|
|
1352
|
-
self.output_dir
|
|
1353
|
-
/ FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
|
|
1354
|
-
/ "features.json"
|
|
1355
|
-
)
|
|
1356
|
-
|
|
1357
|
-
features_json = None
|
|
1358
|
-
if storage_out_path.exists():
|
|
1359
|
-
features_json = storage_out_path
|
|
1360
|
-
elif regular_path.exists():
|
|
1361
|
-
features_json = regular_path
|
|
1362
|
-
|
|
1363
|
-
if features_json:
|
|
1364
|
-
with open(features_json, "r") as f:
|
|
1365
|
-
features_data = json.load(f)
|
|
1366
|
-
# Add sample size info if available
|
|
1367
|
-
for key in [
|
|
1368
|
-
"n_train",
|
|
1369
|
-
"n_test",
|
|
1370
|
-
"n_train_seizure",
|
|
1371
|
-
"n_train_interictal",
|
|
1372
|
-
"n_test_seizure",
|
|
1373
|
-
"n_test_interictal",
|
|
1374
|
-
]:
|
|
1375
|
-
if key in features_data:
|
|
1376
|
-
fold_result[key] = int(features_data[key])
|
|
1377
|
-
except Exception:
|
|
1378
|
-
pass
|
|
1379
|
-
|
|
1380
|
-
results["folds"].append(fold_result)
|
|
1381
|
-
|
|
1382
|
-
# Add plot references with unified structure
|
|
1383
|
-
# CV summary plots in cv_summary directory
|
|
1384
|
-
cv_summary_dir = self.output_dir / "cv_summary"
|
|
1385
|
-
if cv_summary_dir.exists():
|
|
1386
|
-
for plot_file in cv_summary_dir.glob("*.jpg"):
|
|
1387
|
-
plot_key = f"cv_summary_{plot_file.stem}"
|
|
1388
|
-
results["plots"][plot_key] = str(plot_file.relative_to(self.output_dir))
|
|
1389
|
-
|
|
1390
|
-
# Per-fold plots in fold directories
|
|
1391
|
-
for fold_dir in sorted(self.output_dir.glob("fold_*")):
|
|
1392
|
-
# Extract fold number (directory is fold_XX, filename starts with fold-XX)
|
|
1393
|
-
fold_num = fold_dir.name.replace("fold_", "")
|
|
1394
|
-
for plot_file in fold_dir.glob("*.jpg"):
|
|
1395
|
-
# Use just the stem as the plot key since filename already contains fold info
|
|
1396
|
-
# e.g., "fold-00_confusion-matrix_bacc-0.500" becomes plot key "fold_00_confusion-matrix"
|
|
1397
|
-
plot_key = f"fold_{fold_num}_{plot_file.stem}"
|
|
1398
|
-
results["plots"][plot_key] = str(plot_file.relative_to(self.output_dir))
|
|
1399
|
-
|
|
1400
|
-
# Generate reports
|
|
1401
|
-
reports_dir = self._create_subdir_if_needed("reports")
|
|
1402
|
-
generated_files = {}
|
|
1403
|
-
|
|
1404
|
-
# Org-mode report (primary format) - will generate other formats via pandoc
|
|
1405
|
-
org_path = reports_dir / "classification_report.org"
|
|
1406
|
-
generate_org_report(results, org_path, include_plots=True, convert_formats=True)
|
|
1407
|
-
generated_files["org"] = org_path
|
|
1408
|
-
logger.info(f"Generated org-mode report: {org_path}")
|
|
1409
|
-
|
|
1410
|
-
# Note: Markdown, HTML, LaTeX, and DOCX are now generated via pandoc from org
|
|
1411
|
-
# This ensures consistency across all formats
|
|
1412
|
-
|
|
1413
|
-
# Try to compile LaTeX to PDF
|
|
1414
|
-
try:
|
|
1415
|
-
import shutil
|
|
1416
|
-
import subprocess
|
|
1417
|
-
|
|
1418
|
-
if shutil.which("pdflatex"):
|
|
1419
|
-
# Change to reports directory for compilation
|
|
1420
|
-
original_dir = Path.cwd()
|
|
1421
|
-
try:
|
|
1422
|
-
import os
|
|
1423
|
-
|
|
1424
|
-
os.chdir(reports_dir)
|
|
1425
|
-
|
|
1426
|
-
# Run pdflatex twice for proper references
|
|
1427
|
-
for _ in range(2):
|
|
1428
|
-
result = subprocess.run(
|
|
1429
|
-
[
|
|
1430
|
-
"pdflatex",
|
|
1431
|
-
"-interaction=nonstopmode",
|
|
1432
|
-
"classification_report.tex",
|
|
1433
|
-
],
|
|
1434
|
-
capture_output=True,
|
|
1435
|
-
text=True,
|
|
1436
|
-
timeout=30,
|
|
1437
|
-
)
|
|
1438
|
-
|
|
1439
|
-
pdf_path = reports_dir / "classification_report.pdf"
|
|
1440
|
-
if pdf_path.exists():
|
|
1441
|
-
generated_files["pdf"] = pdf_path
|
|
1442
|
-
logger.info(f"Generated PDF report: {pdf_path}")
|
|
1443
|
-
|
|
1444
|
-
# Clean up LaTeX auxiliary files
|
|
1445
|
-
for ext in [".aux", ".log", ".out", ".toc"]:
|
|
1446
|
-
aux_file = reports_dir / f"classification_report{ext}"
|
|
1447
|
-
if aux_file.exists():
|
|
1448
|
-
aux_file.unlink()
|
|
1449
|
-
finally:
|
|
1450
|
-
os.chdir(original_dir)
|
|
1451
|
-
else:
|
|
1452
|
-
logger.warning("pdflatex not found. Skipping PDF generation.")
|
|
1453
|
-
except Exception as e:
|
|
1454
|
-
logger.warning(f"Could not generate PDF report: {e}")
|
|
1455
|
-
|
|
1456
|
-
# Skip paper_exports - all data is already available in fold_XX/ and cv_summary/
|
|
1457
|
-
# with descriptive filenames perfect for sharing
|
|
1458
|
-
|
|
1459
|
-
return generated_files
|
|
1460
|
-
|
|
1461
150
|
def save_summary(
|
|
1462
151
|
self, filename: str = "cv_summary/summary.json", verbose: bool = True
|
|
1463
152
|
) -> Path:
|
|
@@ -1468,6 +157,8 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
|
|
|
1468
157
|
----------
|
|
1469
158
|
filename : str, default "cv_summary/summary.json"
|
|
1470
159
|
Filename for summary (now in cv_summary directory)
|
|
160
|
+
verbose : bool, default True
|
|
161
|
+
Print summary to console
|
|
1471
162
|
|
|
1472
163
|
Returns
|
|
1473
164
|
-------
|
|
@@ -1476,17 +167,11 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
|
|
|
1476
167
|
"""
|
|
1477
168
|
summary = self.get_summary()
|
|
1478
169
|
|
|
1479
|
-
# Try to load and include configuration
|
|
1480
170
|
try:
|
|
1481
|
-
# Try different possible locations for CONFIG.yaml
|
|
1482
171
|
possible_paths = [
|
|
1483
|
-
self.output_dir.parent
|
|
1484
|
-
/ "CONFIGS"
|
|
1485
|
-
/ "
|
|
1486
|
-
self.output_dir.parent.parent
|
|
1487
|
-
/ "CONFIGS"
|
|
1488
|
-
/ "CONFIG.yaml", # ../../CONFIGS/CONFIG.yaml
|
|
1489
|
-
self.output_dir / "CONFIGS" / "CONFIG.yaml", # ./CONFIGS/CONFIG.yaml
|
|
172
|
+
self.output_dir.parent / "CONFIGS" / "CONFIG.yaml",
|
|
173
|
+
self.output_dir.parent.parent / "CONFIGS" / "CONFIG.yaml",
|
|
174
|
+
self.output_dir / "CONFIGS" / "CONFIG.yaml",
|
|
1490
175
|
]
|
|
1491
176
|
|
|
1492
177
|
config_path = None
|
|
@@ -1498,33 +183,21 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
|
|
|
1498
183
|
if config_path and config_path.exists():
|
|
1499
184
|
import yaml
|
|
1500
185
|
|
|
1501
|
-
with open(config_path
|
|
186
|
+
with open(config_path) as config_file:
|
|
1502
187
|
config_data = yaml.safe_load(config_file)
|
|
1503
188
|
summary["experiment_configuration"] = config_data
|
|
1504
189
|
except Exception as e:
|
|
1505
190
|
logger.warning(f"Could not load CONFIG.yaml: {e}")
|
|
1506
191
|
|
|
1507
|
-
# Save CV summary metrics with proper filenames
|
|
1508
192
|
self._save_cv_summary_metrics(summary)
|
|
1509
|
-
|
|
1510
|
-
# Save and plot CV summary confusion matrix
|
|
1511
193
|
self.save_cv_summary_confusion_matrix(summary)
|
|
1512
|
-
|
|
1513
|
-
# Create CV summary ROC and PR curves
|
|
1514
194
|
self.create_cv_summary_curves(summary)
|
|
1515
|
-
|
|
1516
|
-
# Create CV aggregation visualizations with faded fold lines
|
|
1517
195
|
self.create_cv_aggregation_visualizations(
|
|
1518
196
|
show_individual_folds=True, fold_alpha=0.15
|
|
1519
197
|
)
|
|
1520
|
-
|
|
1521
|
-
# Save CV summary classification report
|
|
1522
198
|
self._save_cv_summary_classification_report(summary)
|
|
1523
|
-
|
|
1524
|
-
# Generate comprehensive reports in multiple formats
|
|
1525
199
|
self.generate_reports()
|
|
1526
200
|
|
|
1527
|
-
# Ensure cv_summary directory exists
|
|
1528
201
|
cv_summary_dir = self._create_subdir_if_needed("cv_summary")
|
|
1529
202
|
cv_summary_dir.mkdir(parents=True, exist_ok=True)
|
|
1530
203
|
|
|
@@ -1533,239 +206,8 @@ class SingleTaskClassificationReporter(BaseClassificationReporter):
|
|
|
1533
206
|
logger.info("Summary:")
|
|
1534
207
|
pprint(summary)
|
|
1535
208
|
|
|
1536
|
-
# Save summary in cv_summary directory
|
|
1537
209
|
return self.storage.save(summary, "cv_summary/summary.json")
|
|
1538
210
|
|
|
1539
|
-
def _save_cv_summary_metrics(self, summary: Dict[str, Any]) -> None:
|
|
1540
|
-
"""
|
|
1541
|
-
Save individual CV summary metrics with mean/std/n_folds in filenames.
|
|
1542
|
-
"""
|
|
1543
|
-
if "metrics_summary" not in summary:
|
|
1544
|
-
return
|
|
1545
|
-
|
|
1546
|
-
n_folds = len(self.fold_metrics)
|
|
1547
|
-
cv_summary_dir = "cv_summary"
|
|
1548
|
-
|
|
1549
|
-
for metric_name, stats in summary["metrics_summary"].items():
|
|
1550
|
-
if isinstance(stats, dict) and "mean" in stats:
|
|
1551
|
-
mean_val = stats.get("mean", 0)
|
|
1552
|
-
std_val = stats.get("std", 0)
|
|
1553
|
-
|
|
1554
|
-
# Create filename with mean_std_n format
|
|
1555
|
-
filename = FILENAME_PATTERNS["cv_summary_metric"].format(
|
|
1556
|
-
metric_name=metric_name,
|
|
1557
|
-
mean=mean_val,
|
|
1558
|
-
std=std_val,
|
|
1559
|
-
n_folds=n_folds,
|
|
1560
|
-
)
|
|
1561
|
-
|
|
1562
|
-
# Save metric statistics
|
|
1563
|
-
self.storage.save(stats, f"{cv_summary_dir}/{filename}")
|
|
1564
|
-
|
|
1565
|
-
def _save_cv_summary_classification_report(self, summary: Dict[str, Any]) -> None:
|
|
1566
|
-
"""
|
|
1567
|
-
Save CV summary classification report with mean ± std (n_folds=X) format.
|
|
1568
|
-
"""
|
|
1569
|
-
n_folds = len(self.fold_metrics)
|
|
1570
|
-
cv_summary_dir = "cv_summary"
|
|
1571
|
-
|
|
1572
|
-
# Collect classification reports from all folds
|
|
1573
|
-
all_reports = []
|
|
1574
|
-
for fold_num, fold_metrics in self.fold_metrics.items():
|
|
1575
|
-
if "classification_report" in fold_metrics:
|
|
1576
|
-
report = fold_metrics["classification_report"]
|
|
1577
|
-
if isinstance(report, dict) and "value" in report:
|
|
1578
|
-
report = report["value"]
|
|
1579
|
-
|
|
1580
|
-
# Convert DataFrame to dict if needed
|
|
1581
|
-
if isinstance(report, pd.DataFrame):
|
|
1582
|
-
# Convert DataFrame to dict format expected by aggregation
|
|
1583
|
-
# Assumes DataFrame has 'class' column and metric columns
|
|
1584
|
-
if "class" in report.columns:
|
|
1585
|
-
report_dict = {}
|
|
1586
|
-
for _, row in report.iterrows():
|
|
1587
|
-
class_name = row["class"]
|
|
1588
|
-
report_dict[class_name] = {
|
|
1589
|
-
col: row[col]
|
|
1590
|
-
for col in report.columns
|
|
1591
|
-
if col != "class"
|
|
1592
|
-
}
|
|
1593
|
-
report = report_dict
|
|
1594
|
-
else:
|
|
1595
|
-
# DataFrame with class names as index
|
|
1596
|
-
report = report.to_dict("index")
|
|
1597
|
-
|
|
1598
|
-
if isinstance(report, dict):
|
|
1599
|
-
all_reports.append(report)
|
|
1600
|
-
|
|
1601
|
-
if not all_reports:
|
|
1602
|
-
return
|
|
1603
|
-
|
|
1604
|
-
# Calculate mean and std for each metric in the classification report
|
|
1605
|
-
summary_report = {}
|
|
1606
|
-
|
|
1607
|
-
# Get all class labels (excluding summary rows)
|
|
1608
|
-
all_classes = set()
|
|
1609
|
-
for report in all_reports:
|
|
1610
|
-
all_classes.update(
|
|
1611
|
-
[
|
|
1612
|
-
k
|
|
1613
|
-
for k in report.keys()
|
|
1614
|
-
if k not in ["accuracy", "macro avg", "weighted avg"]
|
|
1615
|
-
]
|
|
1616
|
-
)
|
|
1617
|
-
|
|
1618
|
-
# Process each class
|
|
1619
|
-
for cls in sorted(all_classes):
|
|
1620
|
-
cls_metrics = {
|
|
1621
|
-
"precision": [],
|
|
1622
|
-
"recall": [],
|
|
1623
|
-
"f1-score": [],
|
|
1624
|
-
"support": [],
|
|
1625
|
-
}
|
|
1626
|
-
|
|
1627
|
-
for report in all_reports:
|
|
1628
|
-
if cls in report:
|
|
1629
|
-
for metric in [
|
|
1630
|
-
"precision",
|
|
1631
|
-
"recall",
|
|
1632
|
-
"f1-score",
|
|
1633
|
-
"support",
|
|
1634
|
-
]:
|
|
1635
|
-
if metric in report[cls]:
|
|
1636
|
-
cls_metrics[metric].append(report[cls][metric])
|
|
1637
|
-
|
|
1638
|
-
summary_report[cls] = {}
|
|
1639
|
-
for metric, values in cls_metrics.items():
|
|
1640
|
-
if values:
|
|
1641
|
-
if metric == "support":
|
|
1642
|
-
# For support, show total and mean±std to capture variability
|
|
1643
|
-
total_support = int(np.sum(values))
|
|
1644
|
-
mean_support = np.mean(values)
|
|
1645
|
-
std_support = np.std(values)
|
|
1646
|
-
if std_support > 0:
|
|
1647
|
-
# Show mean±std if there's variability across folds
|
|
1648
|
-
summary_report[cls][metric] = (
|
|
1649
|
-
f"{mean_support:.1f} ± {std_support:.1f} (total={total_support})"
|
|
1650
|
-
)
|
|
1651
|
-
else:
|
|
1652
|
-
# If constant across folds, just show the value
|
|
1653
|
-
summary_report[cls][metric] = (
|
|
1654
|
-
f"{int(mean_support)} per fold (total={total_support})"
|
|
1655
|
-
)
|
|
1656
|
-
else:
|
|
1657
|
-
mean_val = np.mean(values)
|
|
1658
|
-
std_val = np.std(values)
|
|
1659
|
-
summary_report[cls][metric] = (
|
|
1660
|
-
f"{mean_val:.3f} ± {std_val:.3f} (n={n_folds})"
|
|
1661
|
-
)
|
|
1662
|
-
|
|
1663
|
-
# Process summary rows (macro avg, weighted avg)
|
|
1664
|
-
for avg_type in ["macro avg", "weighted avg"]:
|
|
1665
|
-
avg_metrics = {"precision": [], "recall": [], "f1-score": []}
|
|
1666
|
-
|
|
1667
|
-
for report in all_reports:
|
|
1668
|
-
if avg_type in report:
|
|
1669
|
-
for metric in ["precision", "recall", "f1-score"]:
|
|
1670
|
-
if metric in report[avg_type]:
|
|
1671
|
-
avg_metrics[metric].append(report[avg_type][metric])
|
|
1672
|
-
|
|
1673
|
-
if any(avg_metrics.values()):
|
|
1674
|
-
summary_report[avg_type] = {}
|
|
1675
|
-
for metric, values in avg_metrics.items():
|
|
1676
|
-
if values:
|
|
1677
|
-
mean_val = np.mean(values)
|
|
1678
|
-
std_val = np.std(values)
|
|
1679
|
-
summary_report[avg_type][metric] = (
|
|
1680
|
-
f"{mean_val:.3f} ± {std_val:.3f} (n={n_folds})"
|
|
1681
|
-
)
|
|
1682
|
-
|
|
1683
|
-
# Convert to DataFrame for better visualization
|
|
1684
|
-
if summary_report:
|
|
1685
|
-
report_df = pd.DataFrame(summary_report).T
|
|
1686
|
-
# Reset index to make it an ordinary column with name
|
|
1687
|
-
report_df = report_df.reset_index()
|
|
1688
|
-
report_df = report_df.rename(columns={"index": "class"})
|
|
1689
|
-
|
|
1690
|
-
# Save as CSV
|
|
1691
|
-
filename = FILENAME_PATTERNS["cv_summary_classification_report"].format(
|
|
1692
|
-
n_folds=n_folds
|
|
1693
|
-
)
|
|
1694
|
-
self.storage.save(
|
|
1695
|
-
report_df,
|
|
1696
|
-
f"{cv_summary_dir}/{filename}",
|
|
1697
|
-
)
|
|
1698
|
-
|
|
1699
|
-
def save(
|
|
1700
|
-
self,
|
|
1701
|
-
data: Any,
|
|
1702
|
-
relative_path: Union[str, Path],
|
|
1703
|
-
fold: Optional[int] = None,
|
|
1704
|
-
) -> Path:
|
|
1705
|
-
"""
|
|
1706
|
-
Save custom data with automatic fold organization and filename prefixing.
|
|
1707
|
-
|
|
1708
|
-
Parameters
|
|
1709
|
-
----------
|
|
1710
|
-
data : Any
|
|
1711
|
-
Custom data to save (any format supported by stx.io.save)
|
|
1712
|
-
relative_path : Union[str, Path]
|
|
1713
|
-
Relative path from output_dir or fold directory. Examples:
|
|
1714
|
-
- When fold is provided: "custom_metrics.json" → "fold_00/fold-00_custom_metrics.json"
|
|
1715
|
-
- When fold is None: "cv_summary/results.csv" → "cv_summary/results.csv"
|
|
1716
|
-
fold : Optional[int], default None
|
|
1717
|
-
If provided, automatically prepends "fold_{fold:02d}/" to the path
|
|
1718
|
-
and adds "fold-{fold:02d}_" prefix to the filename
|
|
1719
|
-
|
|
1720
|
-
Returns
|
|
1721
|
-
-------
|
|
1722
|
-
Path
|
|
1723
|
-
Absolute path to the saved file
|
|
1724
|
-
|
|
1725
|
-
Examples
|
|
1726
|
-
--------
|
|
1727
|
-
>>> # Save custom metrics for fold 0 (automatic fold directory and prefix)
|
|
1728
|
-
>>> reporter.save(
|
|
1729
|
-
... {"metric1": 0.95, "metric2": 0.87},
|
|
1730
|
-
... "custom_metrics.json",
|
|
1731
|
-
... fold=0
|
|
1732
|
-
... ) # Saves to: fold_00/fold-00_custom_metrics.json
|
|
1733
|
-
|
|
1734
|
-
>>> # Save to cv_summary (no fold, no prefix)
|
|
1735
|
-
>>> reporter.save(
|
|
1736
|
-
... df_results,
|
|
1737
|
-
... "cv_summary/final_analysis.csv"
|
|
1738
|
-
... )
|
|
1739
|
-
|
|
1740
|
-
>>> # Save to reports directory
|
|
1741
|
-
>>> reporter.save(
|
|
1742
|
-
... report_content,
|
|
1743
|
-
... "reports/analysis.md"
|
|
1744
|
-
... )
|
|
1745
|
-
"""
|
|
1746
|
-
# Automatically prepend fold directory and prefix filename if fold is provided
|
|
1747
|
-
if fold is not None:
|
|
1748
|
-
# Parse the path to add prefix to filename only
|
|
1749
|
-
path_obj = Path(relative_path)
|
|
1750
|
-
filename = path_obj.name
|
|
1751
|
-
parent = path_obj.parent
|
|
1752
|
-
|
|
1753
|
-
# Add fold prefix to filename (e.g., "fold-00_custom_metrics.json")
|
|
1754
|
-
prefixed_filename = (
|
|
1755
|
-
f"{FOLD_FILE_PREFIX_PATTERN.format(fold=fold)}_{filename}"
|
|
1756
|
-
)
|
|
1757
|
-
|
|
1758
|
-
# Construct full path: fold_00/fold-00_filename.ext
|
|
1759
|
-
if parent and str(parent) != ".":
|
|
1760
|
-
relative_path = f"{FOLD_DIR_PREFIX_PATTERN.format(fold=fold)}/{parent}/{prefixed_filename}"
|
|
1761
|
-
else:
|
|
1762
|
-
relative_path = (
|
|
1763
|
-
f"{FOLD_DIR_PREFIX_PATTERN.format(fold=fold)}/{prefixed_filename}"
|
|
1764
|
-
)
|
|
1765
|
-
|
|
1766
|
-
# Use the existing storage.save method which already handles everything
|
|
1767
|
-
return self.storage.save(data, relative_path)
|
|
1768
|
-
|
|
1769
211
|
def __repr__(self) -> str:
|
|
1770
212
|
fold_count = len(self.fold_metrics)
|
|
1771
213
|
return (
|