scitex 2.14.0__py3-none-any.whl → 2.15.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +71 -17
- scitex/_env_loader.py +156 -0
- scitex/_mcp_resources/__init__.py +37 -0
- scitex/_mcp_resources/_cheatsheet.py +135 -0
- scitex/_mcp_resources/_figrecipe.py +138 -0
- scitex/_mcp_resources/_formats.py +102 -0
- scitex/_mcp_resources/_modules.py +337 -0
- scitex/_mcp_resources/_session.py +149 -0
- scitex/_mcp_tools/__init__.py +4 -0
- scitex/_mcp_tools/audio.py +66 -0
- scitex/_mcp_tools/diagram.py +11 -95
- scitex/_mcp_tools/introspect.py +210 -0
- scitex/_mcp_tools/plt.py +260 -305
- scitex/_mcp_tools/scholar.py +74 -0
- scitex/_mcp_tools/social.py +244 -0
- scitex/_mcp_tools/template.py +24 -0
- scitex/_mcp_tools/writer.py +21 -204
- scitex/ai/_gen_ai/_PARAMS.py +10 -7
- scitex/ai/classification/reporters/_SingleClassificationReporter.py +45 -1603
- scitex/ai/classification/reporters/_mixins/__init__.py +36 -0
- scitex/ai/classification/reporters/_mixins/_constants.py +67 -0
- scitex/ai/classification/reporters/_mixins/_cv_summary.py +387 -0
- scitex/ai/classification/reporters/_mixins/_feature_importance.py +119 -0
- scitex/ai/classification/reporters/_mixins/_metrics.py +275 -0
- scitex/ai/classification/reporters/_mixins/_plotting.py +179 -0
- scitex/ai/classification/reporters/_mixins/_reports.py +153 -0
- scitex/ai/classification/reporters/_mixins/_storage.py +160 -0
- scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit.py +30 -1550
- scitex/ai/classification/timeseries/_sliding_window_core.py +467 -0
- scitex/ai/classification/timeseries/_sliding_window_plotting.py +369 -0
- scitex/audio/README.md +40 -36
- scitex/audio/__init__.py +129 -61
- scitex/audio/_branding.py +185 -0
- scitex/audio/_mcp/__init__.py +32 -0
- scitex/audio/_mcp/handlers.py +59 -6
- scitex/audio/_mcp/speak_handlers.py +238 -0
- scitex/audio/_relay.py +225 -0
- scitex/audio/_tts.py +18 -10
- scitex/audio/engines/base.py +17 -10
- scitex/audio/engines/elevenlabs_engine.py +7 -2
- scitex/audio/mcp_server.py +228 -75
- scitex/canvas/README.md +1 -1
- scitex/canvas/editor/_dearpygui/__init__.py +25 -0
- scitex/canvas/editor/_dearpygui/_editor.py +147 -0
- scitex/canvas/editor/_dearpygui/_handlers.py +476 -0
- scitex/canvas/editor/_dearpygui/_panels/__init__.py +17 -0
- scitex/canvas/editor/_dearpygui/_panels/_control.py +119 -0
- scitex/canvas/editor/_dearpygui/_panels/_element_controls.py +190 -0
- scitex/canvas/editor/_dearpygui/_panels/_preview.py +43 -0
- scitex/canvas/editor/_dearpygui/_panels/_sections.py +390 -0
- scitex/canvas/editor/_dearpygui/_plotting.py +187 -0
- scitex/canvas/editor/_dearpygui/_rendering.py +504 -0
- scitex/canvas/editor/_dearpygui/_selection.py +295 -0
- scitex/canvas/editor/_dearpygui/_state.py +93 -0
- scitex/canvas/editor/_dearpygui/_utils.py +61 -0
- scitex/canvas/editor/flask_editor/_core/__init__.py +27 -0
- scitex/canvas/editor/flask_editor/_core/_bbox_extraction.py +200 -0
- scitex/canvas/editor/flask_editor/_core/_editor.py +173 -0
- scitex/canvas/editor/flask_editor/_core/_export_helpers.py +353 -0
- scitex/canvas/editor/flask_editor/_core/_routes_basic.py +190 -0
- scitex/canvas/editor/flask_editor/_core/_routes_export.py +332 -0
- scitex/canvas/editor/flask_editor/_core/_routes_panels.py +252 -0
- scitex/canvas/editor/flask_editor/_core/_routes_save.py +218 -0
- scitex/canvas/editor/flask_editor/_core.py +25 -1684
- scitex/canvas/editor/flask_editor/templates/__init__.py +32 -70
- scitex/cli/__init__.py +38 -43
- scitex/cli/audio.py +76 -27
- scitex/cli/capture.py +13 -20
- scitex/cli/introspect.py +481 -0
- scitex/cli/main.py +200 -109
- scitex/cli/mcp.py +60 -34
- scitex/cli/plt.py +357 -0
- scitex/cli/repro.py +15 -8
- scitex/cli/resource.py +15 -8
- scitex/cli/scholar/__init__.py +23 -8
- scitex/cli/scholar/_crossref_scitex.py +296 -0
- scitex/cli/scholar/_fetch.py +25 -3
- scitex/cli/social.py +314 -0
- scitex/cli/stats.py +15 -8
- scitex/cli/template.py +129 -12
- scitex/cli/tex.py +15 -8
- scitex/cli/writer.py +132 -8
- scitex/cloud/__init__.py +41 -2
- scitex/config/README.md +1 -1
- scitex/config/__init__.py +16 -2
- scitex/config/_env_registry.py +256 -0
- scitex/context/__init__.py +22 -0
- scitex/dev/__init__.py +20 -1
- scitex/diagram/__init__.py +42 -19
- scitex/diagram/mcp_server.py +13 -125
- scitex/gen/__init__.py +50 -14
- scitex/gen/_list_packages.py +4 -4
- scitex/introspect/__init__.py +82 -0
- scitex/introspect/_call_graph.py +303 -0
- scitex/introspect/_class_hierarchy.py +163 -0
- scitex/introspect/_core.py +41 -0
- scitex/introspect/_docstring.py +131 -0
- scitex/introspect/_examples.py +113 -0
- scitex/introspect/_imports.py +271 -0
- scitex/{gen/_inspect_module.py → introspect/_list_api.py} +43 -54
- scitex/introspect/_mcp/__init__.py +41 -0
- scitex/introspect/_mcp/handlers.py +233 -0
- scitex/introspect/_members.py +155 -0
- scitex/introspect/_resolve.py +89 -0
- scitex/introspect/_signature.py +131 -0
- scitex/introspect/_source.py +80 -0
- scitex/introspect/_type_hints.py +172 -0
- scitex/io/_save.py +1 -2
- scitex/io/bundle/README.md +1 -1
- scitex/logging/_formatters.py +19 -9
- scitex/mcp_server.py +98 -5
- scitex/os/__init__.py +4 -0
- scitex/{gen → os}/_check_host.py +4 -5
- scitex/plt/__init__.py +245 -550
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py +5 -10
- scitex/plt/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/plt/gallery/README.md +1 -1
- scitex/plt/utils/_hitmap/__init__.py +82 -0
- scitex/plt/utils/_hitmap/_artist_extraction.py +343 -0
- scitex/plt/utils/_hitmap/_color_application.py +346 -0
- scitex/plt/utils/_hitmap/_color_conversion.py +121 -0
- scitex/plt/utils/_hitmap/_constants.py +40 -0
- scitex/plt/utils/_hitmap/_hitmap_core.py +334 -0
- scitex/plt/utils/_hitmap/_path_extraction.py +357 -0
- scitex/plt/utils/_hitmap/_query.py +113 -0
- scitex/plt/utils/_hitmap.py +46 -1616
- scitex/plt/utils/_metadata/__init__.py +80 -0
- scitex/plt/utils/_metadata/_artists/__init__.py +25 -0
- scitex/plt/utils/_metadata/_artists/_base.py +195 -0
- scitex/plt/utils/_metadata/_artists/_collections.py +356 -0
- scitex/plt/utils/_metadata/_artists/_extract.py +57 -0
- scitex/plt/utils/_metadata/_artists/_images.py +80 -0
- scitex/plt/utils/_metadata/_artists/_lines.py +261 -0
- scitex/plt/utils/_metadata/_artists/_patches.py +247 -0
- scitex/plt/utils/_metadata/_artists/_text.py +106 -0
- scitex/plt/utils/_metadata/_csv.py +416 -0
- scitex/plt/utils/_metadata/_detect.py +225 -0
- scitex/plt/utils/_metadata/_legend.py +127 -0
- scitex/plt/utils/_metadata/_rounding.py +117 -0
- scitex/plt/utils/_metadata/_verification.py +202 -0
- scitex/schema/README.md +1 -1
- scitex/scholar/__init__.py +8 -0
- scitex/scholar/_mcp/crossref_handlers.py +265 -0
- scitex/scholar/core/Scholar.py +63 -1700
- scitex/scholar/core/_mixins/__init__.py +36 -0
- scitex/scholar/core/_mixins/_enrichers.py +270 -0
- scitex/scholar/core/_mixins/_library_handlers.py +100 -0
- scitex/scholar/core/_mixins/_loaders.py +103 -0
- scitex/scholar/core/_mixins/_pdf_download.py +375 -0
- scitex/scholar/core/_mixins/_pipeline.py +312 -0
- scitex/scholar/core/_mixins/_project_handlers.py +125 -0
- scitex/scholar/core/_mixins/_savers.py +69 -0
- scitex/scholar/core/_mixins/_search.py +103 -0
- scitex/scholar/core/_mixins/_services.py +88 -0
- scitex/scholar/core/_mixins/_url_finding.py +105 -0
- scitex/scholar/crossref_scitex.py +367 -0
- scitex/scholar/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/scholar/examples/00_run_all.sh +120 -0
- scitex/scholar/jobs/_executors.py +27 -3
- scitex/scholar/pdf_download/ScholarPDFDownloader.py +38 -416
- scitex/scholar/pdf_download/_cli.py +154 -0
- scitex/scholar/pdf_download/strategies/__init__.py +11 -8
- scitex/scholar/pdf_download/strategies/manual_download_fallback.py +80 -3
- scitex/scholar/pipelines/ScholarPipelineBibTeX.py +73 -121
- scitex/scholar/pipelines/ScholarPipelineParallel.py +80 -138
- scitex/scholar/pipelines/ScholarPipelineSingle.py +43 -63
- scitex/scholar/pipelines/_single_steps.py +71 -36
- scitex/scholar/storage/_LibraryManager.py +97 -1695
- scitex/scholar/storage/_mixins/__init__.py +30 -0
- scitex/scholar/storage/_mixins/_bibtex_handlers.py +128 -0
- scitex/scholar/storage/_mixins/_library_operations.py +218 -0
- scitex/scholar/storage/_mixins/_metadata_conversion.py +226 -0
- scitex/scholar/storage/_mixins/_paper_saving.py +456 -0
- scitex/scholar/storage/_mixins/_resolution.py +376 -0
- scitex/scholar/storage/_mixins/_storage_helpers.py +121 -0
- scitex/scholar/storage/_mixins/_symlink_handlers.py +226 -0
- scitex/scholar/url_finder/.tmp/open_url/KNOWN_RESOLVERS.py +462 -0
- scitex/scholar/url_finder/.tmp/open_url/README.md +223 -0
- scitex/scholar/url_finder/.tmp/open_url/_DOIToURLResolver.py +694 -0
- scitex/scholar/url_finder/.tmp/open_url/_OpenURLResolver.py +1160 -0
- scitex/scholar/url_finder/.tmp/open_url/_ResolverLinkFinder.py +344 -0
- scitex/scholar/url_finder/.tmp/open_url/__init__.py +24 -0
- scitex/security/README.md +3 -3
- scitex/session/README.md +1 -1
- scitex/session/__init__.py +26 -7
- scitex/session/_decorator.py +1 -1
- scitex/sh/README.md +1 -1
- scitex/sh/__init__.py +7 -4
- scitex/social/__init__.py +155 -0
- scitex/social/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/stats/_mcp/_handlers/__init__.py +31 -0
- scitex/stats/_mcp/_handlers/_corrections.py +113 -0
- scitex/stats/_mcp/_handlers/_descriptive.py +78 -0
- scitex/stats/_mcp/_handlers/_effect_size.py +106 -0
- scitex/stats/_mcp/_handlers/_format.py +94 -0
- scitex/stats/_mcp/_handlers/_normality.py +110 -0
- scitex/stats/_mcp/_handlers/_posthoc.py +224 -0
- scitex/stats/_mcp/_handlers/_power.py +247 -0
- scitex/stats/_mcp/_handlers/_recommend.py +102 -0
- scitex/stats/_mcp/_handlers/_run_test.py +279 -0
- scitex/stats/_mcp/_handlers/_stars.py +48 -0
- scitex/stats/_mcp/handlers.py +19 -1171
- scitex/stats/auto/_stat_style.py +175 -0
- scitex/stats/auto/_style_definitions.py +411 -0
- scitex/stats/auto/_styles.py +22 -620
- scitex/stats/descriptive/__init__.py +11 -8
- scitex/stats/descriptive/_ci.py +39 -0
- scitex/stats/power/_power.py +15 -4
- scitex/str/__init__.py +2 -1
- scitex/str/_title_case.py +63 -0
- scitex/template/README.md +1 -1
- scitex/template/__init__.py +25 -10
- scitex/template/_code_templates.py +147 -0
- scitex/template/_mcp/handlers.py +81 -0
- scitex/template/_mcp/tool_schemas.py +55 -0
- scitex/template/_templates/__init__.py +51 -0
- scitex/template/_templates/audio.py +233 -0
- scitex/template/_templates/canvas.py +312 -0
- scitex/template/_templates/capture.py +268 -0
- scitex/template/_templates/config.py +43 -0
- scitex/template/_templates/diagram.py +294 -0
- scitex/template/_templates/io.py +107 -0
- scitex/template/_templates/module.py +53 -0
- scitex/template/_templates/plt.py +202 -0
- scitex/template/_templates/scholar.py +267 -0
- scitex/template/_templates/session.py +130 -0
- scitex/template/_templates/session_minimal.py +43 -0
- scitex/template/_templates/session_plot.py +67 -0
- scitex/template/_templates/session_stats.py +77 -0
- scitex/template/_templates/stats.py +323 -0
- scitex/template/_templates/writer.py +296 -0
- scitex/template/clone_writer_directory.py +5 -5
- scitex/ui/_backends/_email.py +10 -2
- scitex/ui/_backends/_webhook.py +5 -1
- scitex/web/_search_pubmed.py +10 -6
- scitex/writer/README.md +1 -1
- scitex/writer/_mcp/handlers.py +11 -744
- scitex/writer/_mcp/tool_schemas.py +5 -335
- scitex-2.15.2.dist-info/METADATA +648 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/RECORD +246 -150
- scitex/canvas/editor/flask_editor/templates/_scripts.py +0 -4933
- scitex/canvas/editor/flask_editor/templates/_styles.py +0 -1658
- scitex/dev/plt/data/mpl/PLOTTING_FUNCTIONS.yaml +0 -90
- scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES.yaml +0 -1571
- scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES_DETAILED.yaml +0 -6262
- scitex/dev/plt/data/mpl/SIGNATURES_FLATTENED.yaml +0 -1274
- scitex/dev/plt/data/mpl/dir_ax.txt +0 -459
- scitex/diagram/_compile.py +0 -312
- scitex/diagram/_diagram.py +0 -355
- scitex/diagram/_mcp/__init__.py +0 -4
- scitex/diagram/_mcp/handlers.py +0 -400
- scitex/diagram/_mcp/tool_schemas.py +0 -157
- scitex/diagram/_presets.py +0 -173
- scitex/diagram/_schema.py +0 -182
- scitex/diagram/_split.py +0 -278
- scitex/gen/_ci.py +0 -12
- scitex/gen/_title_case.py +0 -89
- scitex/plt/_mcp/__init__.py +0 -4
- scitex/plt/_mcp/_handlers_annotation.py +0 -102
- scitex/plt/_mcp/_handlers_figure.py +0 -195
- scitex/plt/_mcp/_handlers_plot.py +0 -252
- scitex/plt/_mcp/_handlers_style.py +0 -219
- scitex/plt/_mcp/handlers.py +0 -74
- scitex/plt/_mcp/tool_schemas.py +0 -497
- scitex/plt/mcp_server.py +0 -231
- scitex/scholar/data/.gitkeep +0 -0
- scitex/scholar/data/README.md +0 -44
- scitex/scholar/data/bib_files/bibliography.bib +0 -1952
- scitex/scholar/data/bib_files/neurovista.bib +0 -277
- scitex/scholar/data/bib_files/neurovista_enriched.bib +0 -441
- scitex/scholar/data/bib_files/neurovista_enriched_enriched.bib +0 -441
- scitex/scholar/data/bib_files/neurovista_processed.bib +0 -338
- scitex/scholar/data/bib_files/openaccess.bib +0 -89
- scitex/scholar/data/bib_files/pac-seizure_prediction_enriched.bib +0 -2178
- scitex/scholar/data/bib_files/pac.bib +0 -698
- scitex/scholar/data/bib_files/pac_enriched.bib +0 -1061
- scitex/scholar/data/bib_files/pac_processed.bib +0 -0
- scitex/scholar/data/bib_files/pac_titles.txt +0 -75
- scitex/scholar/data/bib_files/paywalled.bib +0 -98
- scitex/scholar/data/bib_files/related-papers-by-coauthors.bib +0 -58
- scitex/scholar/data/bib_files/related-papers-by-coauthors_enriched.bib +0 -87
- scitex/scholar/data/bib_files/seizure_prediction.bib +0 -694
- scitex/scholar/data/bib_files/seizure_prediction_processed.bib +0 -0
- scitex/scholar/data/bib_files/test_complete_enriched.bib +0 -437
- scitex/scholar/data/bib_files/test_final_enriched.bib +0 -437
- scitex/scholar/data/bib_files/test_seizure.bib +0 -46
- scitex/scholar/data/impact_factor/JCR_IF_2022.xlsx +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024.db +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024.xlsx +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024_v01.db +0 -0
- scitex/scholar/data/impact_factor.db +0 -0
- scitex/scholar/examples/SUGGESTIONS.md +0 -865
- scitex/scholar/examples/dev.py +0 -38
- scitex-2.14.0.dist-info/METADATA +0 -1238
- /scitex/{gen → context}/_detect_environment.py +0 -0
- /scitex/{gen → context}/_get_notebook_path.py +0 -0
- /scitex/{gen/_shell.py → sh/_shell_legacy.py} +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/WHEEL +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/entry_points.txt +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_mixins/__init__.py
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Mixin classes for SingleTaskClassificationReporter.
|
|
7
|
+
|
|
8
|
+
Each mixin provides a specific set of methods for the reporter class.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from ._constants import (
|
|
12
|
+
FILENAME_PATTERNS,
|
|
13
|
+
FOLD_DIR_PREFIX_PATTERN,
|
|
14
|
+
FOLD_FILE_PREFIX_PATTERN,
|
|
15
|
+
)
|
|
16
|
+
from ._cv_summary import CVSummaryMixin
|
|
17
|
+
from ._feature_importance import FeatureImportanceMixin
|
|
18
|
+
from ._metrics import MetricsMixin
|
|
19
|
+
from ._plotting import PlottingMixin
|
|
20
|
+
from ._reports import ReportsMixin
|
|
21
|
+
from ._storage import StorageMixin
|
|
22
|
+
|
|
23
|
+
__all__ = [
    # File/directory naming constants
    "FILENAME_PATTERNS", "FOLD_DIR_PREFIX_PATTERN", "FOLD_FILE_PREFIX_PATTERN",
    # Mixin classes re-exported from the sibling submodules
    "MetricsMixin", "StorageMixin", "PlottingMixin",
    "FeatureImportanceMixin", "CVSummaryMixin", "ReportsMixin",
]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# EOF
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_mixins/_constants.py
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Constants for classification reporter file naming.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
# Fold directory and filename prefixes for consistent naming.
# Directories use an underscore (fold_00, fold_01, ...); per-fold filenames
# begin with a hyphenated chunk (fold-00_..., fold-01_...).
FOLD_DIR_PREFIX_PATTERN = "fold_{fold:02d}"
FOLD_FILE_PREFIX_PATTERN = "fold-{fold:02d}"

# Filename patterns for consistent naming across the reporter.
# Convention: hyphens within chunks, underscores between chunks.
# Per-fold patterns put the fold-{fold:02d} chunk first so that files group
# by fold when sorted. Remaining placeholders ({metric_name}, {value}, {bacc},
# {auc}, {ap}, {mean}, {std}, {n_folds}) are filled in by callers through
# str.format().
FILENAME_PATTERNS = {
    # Per-fold patterns: each suffix is joined to FOLD_FILE_PREFIX_PATTERN
    # with "_". The prefix and suffix are inserted verbatim, so their
    # placeholders survive for the later str.format() call.
    **{
        _key: f"{FOLD_FILE_PREFIX_PATTERN}_{_suffix}"
        for _key, _suffix in (
            # Individual fold metrics (with metric value in filename)
            ("fold_metric_with_value", "{metric_name}-{value:.3f}.json"),
            ("fold_metric", "{metric_name}.json"),
            # Confusion matrix
            ("confusion_matrix_csv", "confusion-matrix_bacc-{bacc:.3f}.csv"),
            ("confusion_matrix_csv_no_bacc", "confusion-matrix.csv"),
            ("confusion_matrix_jpg", "confusion-matrix_bacc-{bacc:.3f}.jpg"),
            ("confusion_matrix_jpg_no_bacc", "confusion-matrix.jpg"),
            # Classification report
            ("classification_report", "classification-report.csv"),
            # ROC curve
            ("roc_curve_csv", "roc-curve_auc-{auc:.3f}.csv"),
            ("roc_curve_csv_no_auc", "roc-curve.csv"),
            ("roc_curve_jpg", "roc-curve_auc-{auc:.3f}.jpg"),
            ("roc_curve_jpg_no_auc", "roc-curve.jpg"),
            # PR curve
            ("pr_curve_csv", "pr-curve_ap-{ap:.3f}.csv"),
            ("pr_curve_csv_no_ap", "pr-curve.csv"),
            ("pr_curve_jpg", "pr-curve_ap-{ap:.3f}.jpg"),
            ("pr_curve_jpg_no_ap", "pr-curve.jpg"),
            # Raw prediction data
            ("y_true", "y-true.csv"),
            ("y_pred", "y-pred.csv"),
            ("y_proba", "y-proba.csv"),
            # Metrics dashboard
            ("metrics_summary", "metrics-summary.jpg"),
            # Feature importance
            ("feature_importance_json", "feature-importance.json"),
            ("feature_importance_jpg", "feature-importance.jpg"),
            # Classification report edge cases
            ("classification_report_json", "classification-report.json"),
            ("classification_report_txt", "classification-report.txt"),
        )
    },
    # CV summary (aggregated across folds, so no fold prefix)
    "cv_summary_metric": "cv-summary_{metric_name}_mean-{mean:.3f}_std-{std:.3f}_n-{n_folds}.json",
    "cv_summary_confusion_matrix_csv": "cv-summary_confusion-matrix_bacc-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
    "cv_summary_confusion_matrix_jpg": "cv-summary_confusion-matrix_bacc-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
    "cv_summary_classification_report": "cv-summary_classification-report_n-{n_folds}.csv",
    "cv_summary_roc_curve_csv": "cv-summary_roc-curve_auc-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
    "cv_summary_roc_curve_jpg": "cv-summary_roc-curve_auc-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
    "cv_summary_pr_curve_csv": "cv-summary_pr-curve_ap-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
    "cv_summary_pr_curve_jpg": "cv-summary_pr-curve_ap-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
    "cv_summary_feature_importance_json": "cv-summary_feature-importance_n-{n_folds}.json",
    "cv_summary_feature_importance_jpg": "cv-summary_feature-importance_n-{n_folds}.jpg",
    "cv_summary_summary": "cv-summary_summary.json",
    # CV summary edge cases
    "cv_summary_confusion_matrix_csv_no_bacc": "cv-summary_confusion-matrix_n-{n_folds}.csv",
    "cv_summary_confusion_matrix_jpg_no_bacc": "cv-summary_confusion-matrix_n-{n_folds}.jpg",
}
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# EOF
|
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_mixins/_cv_summary.py
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
CV summary mixin for classification reporter.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Any, Dict
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
from scitex.logging import getLogger
|
|
17
|
+
|
|
18
|
+
from ._constants import FILENAME_PATTERNS
|
|
19
|
+
|
|
20
|
+
# Module-level logger shared by the CVSummaryMixin methods for their
# warning/info messages.
logger = getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CVSummaryMixin:
|
|
24
|
+
"""Mixin providing CV summary methods."""
|
|
25
|
+
|
|
26
|
+
def create_cv_summary_curves(self, summary: Dict[str, Any]) -> None:
    """Create CV summary ROC and PR curves from aggregated predictions.

    Per-fold predictions stored in ``self.all_predictions`` (entries with
    ``"y_true"`` and ``"y_proba"``) are concatenated into one pooled set.
    The AUC/AP shown in titles and filenames is the mean ± std of the
    per-fold ``"roc-auc"`` / ``"pr-auc"`` entries in ``self.fold_metrics``.
    When no fold provides a usable value, the metric is computed once on
    the pooled predictions and its std is reported as 0.

    Curve points are written as CSV via ``_save_cv_summary_curve_data``.
    The plots are saved as JPGs in the ``cv_summary`` subdirectory.

    Parameters
    ----------
    summary : dict
        CV summary dictionary. Not read by this method; accepted so that
        all CV-summary methods share the same call signature.
    """
    if not self.all_predictions:
        logger.warning("No predictions stored for CV summary curves")
        return

    all_y_true = np.concatenate([p["y_true"] for p in self.all_predictions])
    all_y_proba = np.concatenate([p["y_proba"] for p in self.all_predictions])

    def _fold_values(metric_key: str) -> list:
        """Collect per-fold scalar values of *metric_key*, skipping missing ones."""
        values = []
        for metrics in self.fold_metrics.values():
            if metric_key not in metrics:
                continue
            val = metrics[metric_key]
            # Metrics may be stored raw or wrapped as {"value": ...}.
            if isinstance(val, dict):
                val = val.get("value")
            # Skip folds where the metric could not be computed, so that
            # np.mean/np.std are never handed None (or an unwrapped dict).
            if val is not None:
                values.append(val)
        return values

    roc_values = _fold_values("roc-auc")
    pr_values = _fold_values("pr-auc")
    n_folds = len(self.fold_metrics)

    if roc_values:
        roc_mean = np.mean(roc_values)
        roc_std = np.std(roc_values)
    else:
        # No per-fold AUCs: fall back to one AUC on the pooled predictions.
        from ..reporter_utils.metrics import calc_roc_auc

        roc_mean = calc_roc_auc(all_y_true, all_y_proba)["value"]
        roc_std = 0.0

    if pr_values:
        pr_mean = np.mean(pr_values)
        pr_std = np.std(pr_values)
    else:
        # No per-fold APs: fall back to one AP on the pooled predictions.
        from ..reporter_utils.metrics import calc_pre_rec_auc

        pr_mean = calc_pre_rec_auc(all_y_true, all_y_proba)["value"]
        pr_std = 0.0

    cv_summary_dir = self._create_subdir_if_needed("cv_summary")
    cv_summary_dir.mkdir(parents=True, exist_ok=True)

    self._save_cv_summary_curve_data(
        all_y_true, all_y_proba, roc_mean, roc_std, pr_mean, pr_std, n_folds
    )

    from scitex.ai.metrics import _normalize_labels

    # Only the normalized y_true and the label names are needed for plotting.
    all_y_true_norm, _, label_names, _ = _normalize_labels(all_y_true, all_y_true)

    roc_title = f"ROC Curve (CV Summary) - AUC: {roc_mean:.3f} ± {roc_std:.3f} (n={n_folds})"
    roc_filename = FILENAME_PATTERNS["cv_summary_roc_curve_jpg"].format(
        mean=roc_mean, std=roc_std, n_folds=n_folds
    )
    self.plotter.create_overall_roc_curve(
        all_y_true_norm,
        all_y_proba,
        labels=label_names,
        save_path=cv_summary_dir / roc_filename,
        title=roc_title,
        auc_mean=roc_mean,
        auc_std=roc_std,
        verbose=True,
    )

    pr_title = f"Precision-Recall Curve (CV Summary) - AP: {pr_mean:.3f} ± {pr_std:.3f} (n={n_folds})"
    pr_filename = FILENAME_PATTERNS["cv_summary_pr_curve_jpg"].format(
        mean=pr_mean, std=pr_std, n_folds=n_folds
    )
    self.plotter.create_overall_pr_curve(
        all_y_true_norm,
        all_y_proba,
        labels=label_names,
        save_path=cv_summary_dir / pr_filename,
        title=pr_title,
        ap_mean=pr_mean,
        ap_std=pr_std,
        verbose=True,
    )

    logger.info(
        f"Created CV summary ROC curve: AUC = {roc_mean:.3f} ± {roc_std:.3f} (n={n_folds})"
    )
    logger.info(
        f"Created CV summary PR curve: AP = {pr_mean:.3f} ± {pr_std:.3f} (n={n_folds})"
    )
|
|
119
|
+
|
|
120
|
+
def _save_cv_summary_curve_data(
    self,
    y_true: np.ndarray,
    y_proba: np.ndarray,
    roc_mean: float,
    roc_std: float,
    pr_mean: float,
    pr_std: float,
    n_folds: int,
) -> None:
    """Save pooled CV summary ROC and PR curve points as CSV files.

    Only binary problems are exported. *y_proba* must be either 1-D
    (positive-class scores) or 2-D with two columns, in which case column 1
    is taken as the positive-class score. For any other shape nothing is
    written and an info message is logged.

    Parameters
    ----------
    y_true, y_proba : np.ndarray
        Pooled true labels and predicted probabilities across all folds.
    roc_mean, roc_std, pr_mean, pr_std : float
        Aggregated scores. They are used only to build the output filenames.
    n_folds : int
        Number of folds, embedded in the filenames.
    """
    from sklearn.metrics import precision_recall_curve, roc_curve

    if not (y_proba.ndim == 1 or y_proba.shape[1] == 2):
        # Previously this case was skipped silently; make the skip visible.
        logger.info(
            "Skipping CV summary ROC/PR curve CSV export: only binary "
            f"y_proba is supported (got shape {y_proba.shape})"
        )
        return

    y_proba_pos = y_proba[:, 1] if y_proba.ndim == 2 else y_proba

    from scitex.ai.metrics import _normalize_labels

    y_true_norm, _, _, _ = _normalize_labels(y_true, y_true)

    cv_summary_dir = "cv_summary"

    fpr, tpr, _ = roc_curve(y_true_norm, y_proba_pos)
    roc_df = pd.DataFrame({"FPR": fpr, "TPR": tpr})
    roc_filename = FILENAME_PATTERNS["cv_summary_roc_curve_csv"].format(
        mean=roc_mean, std=roc_std, n_folds=n_folds
    )
    self.storage.save(roc_df, f"{cv_summary_dir}/{roc_filename}")

    precision, recall, _ = precision_recall_curve(y_true_norm, y_proba_pos)
    pr_df = pd.DataFrame({"Recall": recall, "Precision": precision})
    pr_filename = FILENAME_PATTERNS["cv_summary_pr_curve_csv"].format(
        mean=pr_mean, std=pr_std, n_folds=n_folds
    )
    self.storage.save(pr_df, f"{cv_summary_dir}/{pr_filename}")
|
|
165
|
+
|
|
166
|
+
def save_cv_summary_confusion_matrix(self, summary: Dict[str, Any]) -> None:
    """Save and plot the CV summary confusion matrix.

    Per-fold confusion matrices in ``self.fold_metrics`` (stored raw or
    wrapped as ``{"value": ...}``) are summed element-wise into one overall
    matrix. The matrix is saved as CSV and plotted as JPG under
    ``cv_summary``. When balanced-accuracy mean and std are available, they
    appear in both filenames and in the plot title. Returns early without
    writing anything if no fold has a confusion matrix.

    Parameters
    ----------
    summary : dict
        CV summary; ``summary["metrics_summary"]["balanced-accuracy"]`` is
        read for ``"mean"``/``"std"`` when present. If *summary* has no
        ``"metrics_summary"``, ``self.get_summary()`` is consulted instead.
    """
    confusion_matrices = []
    for fold_metrics in self.fold_metrics.values():
        if "confusion_matrix" not in fold_metrics:
            continue
        cm_data = fold_metrics["confusion_matrix"]
        if isinstance(cm_data, dict) and "value" in cm_data:
            cm_data = cm_data["value"]
        if cm_data is not None:
            confusion_matrices.append(cm_data)

    if not confusion_matrices:
        return

    # Summing per-fold counts yields the pooled confusion matrix.
    overall_cm = np.sum(confusion_matrices, axis=0)

    # Labels come from the first fold providing them, either at fold level
    # or inside the wrapped confusion-matrix entry.
    labels = None
    for fold_metrics in self.fold_metrics.values():
        if "labels" in fold_metrics:
            labels = fold_metrics["labels"]
            break
        if "confusion_matrix" in fold_metrics:
            cm_data = fold_metrics["confusion_matrix"]
            if isinstance(cm_data, dict) and "labels" in cm_data:
                labels = cm_data["labels"]
                break

    # Read balanced accuracy from a single source so that the CSV filename,
    # the JPG filename and the plot title always agree.
    metrics_summary = summary.get("metrics_summary") if summary else None
    if metrics_summary is None:
        metrics_summary = self.get_summary().get("metrics_summary") or {}
    bacc_stats = metrics_summary.get("balanced-accuracy") or {}
    bacc_mean = bacc_stats.get("mean")
    bacc_std = bacc_stats.get("std")

    n_folds = len(self.fold_metrics)
    cv_summary_dir = self._create_subdir_if_needed("cv_summary")
    cv_summary_dir.mkdir(parents=True, exist_ok=True)

    if bacc_mean is not None and bacc_std is not None:
        csv_filename = FILENAME_PATTERNS["cv_summary_confusion_matrix_csv"].format(
            mean=bacc_mean, std=bacc_std, n_folds=n_folds
        )
        jpg_filename = FILENAME_PATTERNS["cv_summary_confusion_matrix_jpg"].format(
            mean=bacc_mean, std=bacc_std, n_folds=n_folds
        )
        title = f"Confusion Matrix (CV Summary) - Balanced Acc: {bacc_mean:.3f} ± {bacc_std:.3f} (n={n_folds})"
    else:
        csv_filename = FILENAME_PATTERNS[
            "cv_summary_confusion_matrix_csv_no_bacc"
        ].format(n_folds=n_folds)
        jpg_filename = FILENAME_PATTERNS[
            "cv_summary_confusion_matrix_jpg_no_bacc"
        ].format(n_folds=n_folds)
        title = f"Confusion Matrix (CV Summary) (n={n_folds})"

    # `labels` may be a list or a NumPy array; `if labels:` is ambiguous
    # (ValueError) for multi-element arrays, so test explicitly.
    if labels is not None and len(labels) > 0:
        cm_df = pd.DataFrame(
            overall_cm,
            index=[f"True_{label}" for label in labels],
            columns=[f"Pred_{label}" for label in labels],
        )
    else:
        cm_df = pd.DataFrame(overall_cm)

    self.storage.save(cm_df, f"cv_summary/{csv_filename}", index=True)

    self.plotter.create_confusion_matrix_plot(
        overall_cm,
        labels=labels,
        save_path=cv_summary_dir / jpg_filename,
        title=title,
    )
|
|
254
|
+
|
|
255
|
+
def _save_cv_summary_metrics(self, summary: Dict[str, Any]) -> None:
    """Write one JSON file per aggregated CV metric.

    For every entry of ``summary["metrics_summary"]`` that is a dict with a
    ``"mean"`` key, the entry itself is saved under ``cv_summary``. Its
    filename embeds the metric name, mean, std (0 when absent) and the
    number of folds. Does nothing when *summary* has no
    ``"metrics_summary"`` key.
    """
    if "metrics_summary" not in summary:
        return

    n_folds = len(self.fold_metrics)

    for metric_name, stats in summary["metrics_summary"].items():
        # Only dict-shaped entries carrying a mean are aggregated statistics.
        if not (isinstance(stats, dict) and "mean" in stats):
            continue

        out_name = FILENAME_PATTERNS["cv_summary_metric"].format(
            metric_name=metric_name,
            mean=stats.get("mean", 0),
            std=stats.get("std", 0),
            n_folds=n_folds,
        )
        self.storage.save(stats, "cv_summary/" + out_name)
|
|
276
|
+
|
|
277
|
+
def _save_cv_summary_classification_report(self, summary: Dict[str, Any]) -> None:
    """Save a classification report aggregated across CV folds.

    Every fold in ``self.fold_metrics`` with a ``"classification_report"``
    entry contributes one report. A report may be wrapped as
    ``{"value": report}``, and may be a dict of rows or a ``pd.DataFrame``
    (indexed by row name, or carrying a ``"class"`` column). All forms are
    normalized to ``{row_name: {metric: value}}`` before aggregation.

    For each class row:

    * precision, recall and f1-score become ``"mean ± std (n=K)"`` strings;
    * support becomes a per-fold mean (± std when it varies) plus the total.

    K is the number of folds that actually reported the value, so it can be
    smaller than the fold count when a class is absent from some folds.
    The average rows (``"micro avg"``, ``"macro avg"``, ``"weighted avg"``,
    ``"samples avg"``) get the same mean ± std treatment for precision,
    recall and f1-score. ``"accuracy"`` is skipped.

    The resulting table is saved under ``cv_summary/``. Its name comes from
    the ``cv_summary_classification_report`` pattern, formatted with
    ``n_folds = len(self.fold_metrics)``.

    Args:
        summary: CV summary dict. It is accepted for signature parity with
            ``_save_cv_summary_metrics`` and is not read here.
    """
    n_folds = len(self.fold_metrics)
    cv_summary_dir = "cv_summary"

    per_class_metrics = ("precision", "recall", "f1-score", "support")
    average_metrics = ("precision", "recall", "f1-score")
    average_rows = ("micro avg", "macro avg", "weighted avg", "samples avg")
    non_class_rows = {"accuracy", *average_rows}

    # --- Normalize each fold's report to {row_name: {metric: value}} ---
    all_reports = []
    for fold_metrics in self.fold_metrics.values():
        if "classification_report" not in fold_metrics:
            continue
        report = fold_metrics["classification_report"]
        # Stored metrics may be wrapped as {"value": <payload>, ...}.
        if isinstance(report, dict) and "value" in report:
            report = report["value"]

        if isinstance(report, pd.DataFrame):
            if "class" in report.columns:
                # One row per class, labelled by the "class" column.
                report = {
                    row["class"]: {
                        col: row[col] for col in report.columns if col != "class"
                    }
                    for _, row in report.iterrows()
                }
            else:
                # Row names live in the index.
                report = report.to_dict("index")

        if isinstance(report, dict):
            all_reports.append(report)

    if not all_reports:
        return

    def collect(row_name, metrics):
        """Gather the per-fold values of ``metrics`` for one report row."""
        collected = {metric: [] for metric in metrics}
        for fold_report in all_reports:
            if row_name not in fold_report:
                continue
            row = fold_report[row_name]
            for metric in metrics:
                if metric in row:
                    collected[metric].append(row[metric])
        return collected

    def mean_std(values):
        # Report how many folds actually contributed, not the total fold count.
        return f"{np.mean(values):.3f} ± {np.std(values):.3f} (n={len(values)})"

    classes = {
        key
        for fold_report in all_reports
        for key in fold_report
        if key not in non_class_rows
    }
    try:
        ordered_classes = sorted(classes)
    except TypeError:
        # Mixed label types (e.g. int and str) are not mutually orderable.
        ordered_classes = sorted(classes, key=str)

    summary_report = {}
    for cls in ordered_classes:
        cls_summary = {}
        for metric, values in collect(cls, per_class_metrics).items():
            if not values:
                continue
            if metric == "support":
                total_support = int(np.sum(values))
                mean_support = np.mean(values)
                std_support = np.std(values)
                if std_support > 0:
                    cls_summary[metric] = (
                        f"{mean_support:.1f} ± {std_support:.1f} (total={total_support})"
                    )
                else:
                    cls_summary[metric] = (
                        f"{int(mean_support)} per fold (total={total_support})"
                    )
            else:
                cls_summary[metric] = mean_std(values)
        summary_report[cls] = cls_summary

    for avg_type in average_rows:
        avg_values = collect(avg_type, average_metrics)
        if any(avg_values.values()):
            summary_report[avg_type] = {
                metric: mean_std(values)
                for metric, values in avg_values.items()
                if values
            }

    if not summary_report:
        return

    report_df = (
        pd.DataFrame(summary_report)
        .T.reset_index()
        .rename(columns={"index": "class"})
    )
    filename = FILENAME_PATTERNS["cv_summary_classification_report"].format(
        n_folds=n_folds
    )
    self.storage.save(report_df, f"{cv_summary_dir}/{filename}")
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# EOF
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_mixins/_feature_importance.py
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Feature importance mixin for classification reporter.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
from scitex.logging import getLogger
|
|
16
|
+
|
|
17
|
+
from ._constants import FILENAME_PATTERNS, FOLD_DIR_PREFIX_PATTERN
|
|
18
|
+
|
|
19
|
+
# Module-level logger. FeatureImportanceMixin uses it for the extraction
# warning and for the "Saved feature importance ..." progress messages.
logger = getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FeatureImportanceMixin:
    """Mixin providing feature importance methods.

    The host class must supply:

    * ``self.storage``: an object with ``save(obj, relative_path)``.
    * ``self.output_dir``: a directory path that supports ``/`` joins and
      whose results expose ``.parent.mkdir``, e.g. ``pathlib.Path``.
    * ``self.plotter``: an object with ``create_feature_importance_plot``.
    """

    def save_feature_importance(
        self,
        model,
        feature_names: List[str],
        fold: Optional[int] = None,
    ) -> Dict[str, float]:
        """Calculate, save and plot feature importance for a fitted model.

        Importances come from ``scitex.ai.metrics.calc_feature_importance``.
        If that call raises ``ValueError``, a warning is logged and ``{}`` is
        returned without writing anything.

        Two outputs are written, named by the ``feature_importance_json`` and
        ``feature_importance_jpg`` patterns. They go to the fold's directory
        when ``fold`` is given, otherwise to ``cv_summary``.

        Args:
            model: Estimator passed to ``calc_feature_importance``.
            feature_names: Names passed to ``calc_feature_importance`` and to
                the plot.
            fold: Fold index, or ``None`` for the CV-summary location.

        Returns:
            The ``{feature_name: importance}`` dict returned by
            ``calc_feature_importance``, or ``{}`` on failure.
        """
        from scitex.ai.metrics import calc_feature_importance

        try:
            importance_dict, importances = calc_feature_importance(model, feature_names)
        except ValueError as e:
            logger.warning(f"Could not extract feature importance: {e}")
            return {}

        fold_subdir = (
            FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
            if fold is not None
            else "cv_summary"
        )
        # NOTE(review): the file-name patterns are formatted with ``fold`` even
        # when it is None. If they use a numeric spec such as ``{fold:02d}``,
        # the CV-summary call raises TypeError. Confirm against _constants.
        json_filename = FILENAME_PATTERNS["feature_importance_json"].format(fold=fold)
        # Saved in the order produced by calc_feature_importance (no re-sort here).
        self.storage.save(
            dict(importance_dict.items()), f"{fold_subdir}/{json_filename}"
        )

        jpg_filename = FILENAME_PATTERNS["feature_importance_jpg"].format(fold=fold)
        save_path = self.output_dir / fold_subdir / jpg_filename
        save_path.parent.mkdir(parents=True, exist_ok=True)

        self.plotter.create_feature_importance_plot(
            feature_importance=importances,
            feature_names=feature_names,
            save_path=save_path,
            title=(
                f"Feature Importance (Fold {fold:02d})"
                if fold is not None
                else "Feature Importance (CV Summary)"
            ),
        )

        logger.info(
            "Saved feature importance"
            + (f" for fold {fold}" if fold is not None else "")
        )

        return importance_dict

    def save_feature_importance_summary(
        self,
        all_importances: List[Dict[str, float]],
    ) -> None:
        """Save and plot per-feature importance statistics across folds.

        Writes the per-feature mean, std and per-fold values (computed by
        ``_compute_feature_importance_stats``) to ``cv_summary/``. It then
        renders a summary figure with
        ``scitex.ai.plt.plot_feature_importance_cv_summary``. Does nothing
        when ``all_importances`` is empty.

        Args:
            all_importances: One ``{feature_name: importance}`` dict per fold.
        """
        if not all_importances:
            return

        feature_stats = self._compute_feature_importance_stats(all_importances)

        n_folds = len(all_importances)
        json_filename = FILENAME_PATTERNS["cv_summary_feature_importance_json"].format(
            n_folds=n_folds
        )
        self.storage.save(feature_stats, f"cv_summary/{json_filename}")

        from scitex.ai.plt import plot_feature_importance_cv_summary

        jpg_filename = FILENAME_PATTERNS["cv_summary_feature_importance_jpg"].format(
            n_folds=n_folds
        )
        save_path = self.output_dir / "cv_summary" / jpg_filename
        save_path.parent.mkdir(parents=True, exist_ok=True)

        plot_feature_importance_cv_summary(
            all_importances=all_importances,
            spath=save_path,
        )

        logger.info("Saved feature importance summary")

    @staticmethod
    def _compute_feature_importance_stats(
        all_importances: List[Dict[str, float]],
    ) -> Dict[str, Dict[str, object]]:
        """Aggregate per-fold importances into mean, std and values per feature.

        A feature missing from a fold counts as 0 for that fold. Entries are
        ordered by decreasing mean importance. Ties are broken by feature
        name, so the order does not depend on set iteration order, which
        varies with the string hash seed.

        Args:
            all_importances: One ``{feature_name: importance}`` dict per fold.

        Returns:
            An ordered dict mapping each feature name to
            ``{"mean": float, "std": float, "values": [float, ...]}``. The
            values list follows the order of ``all_importances``.
        """
        all_features = {name for imp_dict in all_importances for name in imp_dict}

        feature_stats = {}
        for feature in all_features:
            values = [imp_dict.get(feature, 0) for imp_dict in all_importances]
            feature_stats[feature] = {
                "mean": float(np.mean(values)),
                "std": float(np.std(values)),
                "values": [float(v) for v in values],
            }

        ordered = sorted(
            feature_stats.items(),
            key=lambda item: (-item[1]["mean"], str(item[0])),
        )
        return dict(ordered)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# EOF
|