scitex 2.14.0__py3-none-any.whl → 2.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +47 -0
- scitex/_env_loader.py +156 -0
- scitex/_mcp_resources/__init__.py +37 -0
- scitex/_mcp_resources/_cheatsheet.py +135 -0
- scitex/_mcp_resources/_figrecipe.py +138 -0
- scitex/_mcp_resources/_formats.py +102 -0
- scitex/_mcp_resources/_modules.py +337 -0
- scitex/_mcp_resources/_session.py +149 -0
- scitex/_mcp_tools/__init__.py +4 -0
- scitex/_mcp_tools/audio.py +66 -0
- scitex/_mcp_tools/diagram.py +11 -95
- scitex/_mcp_tools/introspect.py +191 -0
- scitex/_mcp_tools/plt.py +260 -305
- scitex/_mcp_tools/scholar.py +74 -0
- scitex/_mcp_tools/social.py +244 -0
- scitex/_mcp_tools/writer.py +21 -204
- scitex/ai/_gen_ai/_PARAMS.py +10 -7
- scitex/ai/classification/reporters/_SingleClassificationReporter.py +45 -1603
- scitex/ai/classification/reporters/_mixins/__init__.py +36 -0
- scitex/ai/classification/reporters/_mixins/_constants.py +67 -0
- scitex/ai/classification/reporters/_mixins/_cv_summary.py +387 -0
- scitex/ai/classification/reporters/_mixins/_feature_importance.py +119 -0
- scitex/ai/classification/reporters/_mixins/_metrics.py +275 -0
- scitex/ai/classification/reporters/_mixins/_plotting.py +179 -0
- scitex/ai/classification/reporters/_mixins/_reports.py +153 -0
- scitex/ai/classification/reporters/_mixins/_storage.py +160 -0
- scitex/audio/README.md +40 -36
- scitex/audio/__init__.py +127 -59
- scitex/audio/_branding.py +185 -0
- scitex/audio/_mcp/__init__.py +32 -0
- scitex/audio/_mcp/handlers.py +59 -6
- scitex/audio/_mcp/speak_handlers.py +238 -0
- scitex/audio/_relay.py +225 -0
- scitex/audio/engines/elevenlabs_engine.py +6 -1
- scitex/audio/mcp_server.py +228 -75
- scitex/canvas/README.md +1 -1
- scitex/canvas/editor/_dearpygui/__init__.py +25 -0
- scitex/canvas/editor/_dearpygui/_editor.py +147 -0
- scitex/canvas/editor/_dearpygui/_handlers.py +476 -0
- scitex/canvas/editor/_dearpygui/_panels/__init__.py +17 -0
- scitex/canvas/editor/_dearpygui/_panels/_control.py +119 -0
- scitex/canvas/editor/_dearpygui/_panels/_element_controls.py +190 -0
- scitex/canvas/editor/_dearpygui/_panels/_preview.py +43 -0
- scitex/canvas/editor/_dearpygui/_panels/_sections.py +390 -0
- scitex/canvas/editor/_dearpygui/_plotting.py +187 -0
- scitex/canvas/editor/_dearpygui/_rendering.py +504 -0
- scitex/canvas/editor/_dearpygui/_selection.py +295 -0
- scitex/canvas/editor/_dearpygui/_state.py +93 -0
- scitex/canvas/editor/_dearpygui/_utils.py +61 -0
- scitex/canvas/editor/flask_editor/templates/__init__.py +32 -70
- scitex/cli/__init__.py +38 -43
- scitex/cli/audio.py +76 -27
- scitex/cli/capture.py +13 -20
- scitex/cli/introspect.py +443 -0
- scitex/cli/main.py +198 -109
- scitex/cli/mcp.py +60 -34
- scitex/cli/scholar/__init__.py +8 -0
- scitex/cli/scholar/_crossref_scitex.py +296 -0
- scitex/cli/scholar/_fetch.py +25 -3
- scitex/cli/social.py +314 -0
- scitex/cli/writer.py +117 -0
- scitex/config/README.md +1 -1
- scitex/config/__init__.py +16 -2
- scitex/config/_env_registry.py +191 -0
- scitex/diagram/__init__.py +42 -19
- scitex/diagram/mcp_server.py +13 -125
- scitex/introspect/__init__.py +75 -0
- scitex/introspect/_call_graph.py +303 -0
- scitex/introspect/_class_hierarchy.py +163 -0
- scitex/introspect/_core.py +42 -0
- scitex/introspect/_docstring.py +131 -0
- scitex/introspect/_examples.py +113 -0
- scitex/introspect/_imports.py +271 -0
- scitex/introspect/_mcp/__init__.py +37 -0
- scitex/introspect/_mcp/handlers.py +208 -0
- scitex/introspect/_members.py +151 -0
- scitex/introspect/_resolve.py +89 -0
- scitex/introspect/_signature.py +131 -0
- scitex/introspect/_source.py +80 -0
- scitex/introspect/_type_hints.py +172 -0
- scitex/io/bundle/README.md +1 -1
- scitex/mcp_server.py +98 -5
- scitex/plt/__init__.py +248 -550
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py +5 -10
- scitex/plt/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/plt/gallery/README.md +1 -1
- scitex/plt/utils/_hitmap/__init__.py +82 -0
- scitex/plt/utils/_hitmap/_artist_extraction.py +343 -0
- scitex/plt/utils/_hitmap/_color_application.py +346 -0
- scitex/plt/utils/_hitmap/_color_conversion.py +121 -0
- scitex/plt/utils/_hitmap/_constants.py +40 -0
- scitex/plt/utils/_hitmap/_hitmap_core.py +334 -0
- scitex/plt/utils/_hitmap/_path_extraction.py +357 -0
- scitex/plt/utils/_hitmap/_query.py +113 -0
- scitex/plt/utils/_hitmap.py +46 -1616
- scitex/plt/utils/_metadata/__init__.py +80 -0
- scitex/plt/utils/_metadata/_artists/__init__.py +25 -0
- scitex/plt/utils/_metadata/_artists/_base.py +195 -0
- scitex/plt/utils/_metadata/_artists/_collections.py +356 -0
- scitex/plt/utils/_metadata/_artists/_extract.py +57 -0
- scitex/plt/utils/_metadata/_artists/_images.py +80 -0
- scitex/plt/utils/_metadata/_artists/_lines.py +261 -0
- scitex/plt/utils/_metadata/_artists/_patches.py +247 -0
- scitex/plt/utils/_metadata/_artists/_text.py +106 -0
- scitex/plt/utils/_metadata/_csv.py +416 -0
- scitex/plt/utils/_metadata/_detect.py +225 -0
- scitex/plt/utils/_metadata/_legend.py +127 -0
- scitex/plt/utils/_metadata/_rounding.py +117 -0
- scitex/plt/utils/_metadata/_verification.py +202 -0
- scitex/schema/README.md +1 -1
- scitex/scholar/__init__.py +8 -0
- scitex/scholar/_mcp/crossref_handlers.py +265 -0
- scitex/scholar/core/Scholar.py +63 -1700
- scitex/scholar/core/_mixins/__init__.py +36 -0
- scitex/scholar/core/_mixins/_enrichers.py +270 -0
- scitex/scholar/core/_mixins/_library_handlers.py +100 -0
- scitex/scholar/core/_mixins/_loaders.py +103 -0
- scitex/scholar/core/_mixins/_pdf_download.py +375 -0
- scitex/scholar/core/_mixins/_pipeline.py +312 -0
- scitex/scholar/core/_mixins/_project_handlers.py +125 -0
- scitex/scholar/core/_mixins/_savers.py +69 -0
- scitex/scholar/core/_mixins/_search.py +103 -0
- scitex/scholar/core/_mixins/_services.py +88 -0
- scitex/scholar/core/_mixins/_url_finding.py +105 -0
- scitex/scholar/crossref_scitex.py +367 -0
- scitex/scholar/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/scholar/examples/00_run_all.sh +120 -0
- scitex/scholar/jobs/_executors.py +27 -3
- scitex/scholar/pdf_download/ScholarPDFDownloader.py +38 -416
- scitex/scholar/pdf_download/_cli.py +154 -0
- scitex/scholar/pdf_download/strategies/__init__.py +11 -8
- scitex/scholar/pdf_download/strategies/manual_download_fallback.py +80 -3
- scitex/scholar/pipelines/ScholarPipelineBibTeX.py +73 -121
- scitex/scholar/pipelines/ScholarPipelineParallel.py +80 -138
- scitex/scholar/pipelines/ScholarPipelineSingle.py +43 -63
- scitex/scholar/pipelines/_single_steps.py +71 -36
- scitex/scholar/storage/_LibraryManager.py +97 -1695
- scitex/scholar/storage/_mixins/__init__.py +30 -0
- scitex/scholar/storage/_mixins/_bibtex_handlers.py +128 -0
- scitex/scholar/storage/_mixins/_library_operations.py +218 -0
- scitex/scholar/storage/_mixins/_metadata_conversion.py +226 -0
- scitex/scholar/storage/_mixins/_paper_saving.py +456 -0
- scitex/scholar/storage/_mixins/_resolution.py +376 -0
- scitex/scholar/storage/_mixins/_storage_helpers.py +121 -0
- scitex/scholar/storage/_mixins/_symlink_handlers.py +226 -0
- scitex/scholar/url_finder/.tmp/open_url/KNOWN_RESOLVERS.py +462 -0
- scitex/scholar/url_finder/.tmp/open_url/README.md +223 -0
- scitex/scholar/url_finder/.tmp/open_url/_DOIToURLResolver.py +694 -0
- scitex/scholar/url_finder/.tmp/open_url/_OpenURLResolver.py +1160 -0
- scitex/scholar/url_finder/.tmp/open_url/_ResolverLinkFinder.py +344 -0
- scitex/scholar/url_finder/.tmp/open_url/__init__.py +24 -0
- scitex/security/README.md +3 -3
- scitex/session/README.md +1 -1
- scitex/sh/README.md +1 -1
- scitex/social/__init__.py +153 -0
- scitex/social/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/template/README.md +1 -1
- scitex/template/clone_writer_directory.py +5 -5
- scitex/writer/README.md +1 -1
- scitex/writer/_mcp/handlers.py +11 -744
- scitex/writer/_mcp/tool_schemas.py +5 -335
- scitex-2.15.1.dist-info/METADATA +648 -0
- {scitex-2.14.0.dist-info → scitex-2.15.1.dist-info}/RECORD +166 -111
- scitex/canvas/editor/flask_editor/templates/_scripts.py +0 -4933
- scitex/canvas/editor/flask_editor/templates/_styles.py +0 -1658
- scitex/dev/plt/data/mpl/PLOTTING_FUNCTIONS.yaml +0 -90
- scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES.yaml +0 -1571
- scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES_DETAILED.yaml +0 -6262
- scitex/dev/plt/data/mpl/SIGNATURES_FLATTENED.yaml +0 -1274
- scitex/dev/plt/data/mpl/dir_ax.txt +0 -459
- scitex/diagram/_compile.py +0 -312
- scitex/diagram/_diagram.py +0 -355
- scitex/diagram/_mcp/__init__.py +0 -4
- scitex/diagram/_mcp/handlers.py +0 -400
- scitex/diagram/_mcp/tool_schemas.py +0 -157
- scitex/diagram/_presets.py +0 -173
- scitex/diagram/_schema.py +0 -182
- scitex/diagram/_split.py +0 -278
- scitex/plt/_mcp/__init__.py +0 -4
- scitex/plt/_mcp/_handlers_annotation.py +0 -102
- scitex/plt/_mcp/_handlers_figure.py +0 -195
- scitex/plt/_mcp/_handlers_plot.py +0 -252
- scitex/plt/_mcp/_handlers_style.py +0 -219
- scitex/plt/_mcp/handlers.py +0 -74
- scitex/plt/_mcp/tool_schemas.py +0 -497
- scitex/plt/mcp_server.py +0 -231
- scitex/scholar/data/.gitkeep +0 -0
- scitex/scholar/data/README.md +0 -44
- scitex/scholar/data/bib_files/bibliography.bib +0 -1952
- scitex/scholar/data/bib_files/neurovista.bib +0 -277
- scitex/scholar/data/bib_files/neurovista_enriched.bib +0 -441
- scitex/scholar/data/bib_files/neurovista_enriched_enriched.bib +0 -441
- scitex/scholar/data/bib_files/neurovista_processed.bib +0 -338
- scitex/scholar/data/bib_files/openaccess.bib +0 -89
- scitex/scholar/data/bib_files/pac-seizure_prediction_enriched.bib +0 -2178
- scitex/scholar/data/bib_files/pac.bib +0 -698
- scitex/scholar/data/bib_files/pac_enriched.bib +0 -1061
- scitex/scholar/data/bib_files/pac_processed.bib +0 -0
- scitex/scholar/data/bib_files/pac_titles.txt +0 -75
- scitex/scholar/data/bib_files/paywalled.bib +0 -98
- scitex/scholar/data/bib_files/related-papers-by-coauthors.bib +0 -58
- scitex/scholar/data/bib_files/related-papers-by-coauthors_enriched.bib +0 -87
- scitex/scholar/data/bib_files/seizure_prediction.bib +0 -694
- scitex/scholar/data/bib_files/seizure_prediction_processed.bib +0 -0
- scitex/scholar/data/bib_files/test_complete_enriched.bib +0 -437
- scitex/scholar/data/bib_files/test_final_enriched.bib +0 -437
- scitex/scholar/data/bib_files/test_seizure.bib +0 -46
- scitex/scholar/data/impact_factor/JCR_IF_2022.xlsx +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024.db +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024.xlsx +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024_v01.db +0 -0
- scitex/scholar/data/impact_factor.db +0 -0
- scitex/scholar/examples/SUGGESTIONS.md +0 -865
- scitex/scholar/examples/dev.py +0 -38
- scitex-2.14.0.dist-info/METADATA +0 -1238
- {scitex-2.14.0.dist-info → scitex-2.15.1.dist-info}/WHEEL +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.1.dist-info}/entry_points.txt +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_mixins/__init__.py
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Mixin classes for SingleTaskClassificationReporter.
|
|
7
|
+
|
|
8
|
+
Each mixin provides a specific set of methods for the reporter class.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from ._constants import (
|
|
12
|
+
FILENAME_PATTERNS,
|
|
13
|
+
FOLD_DIR_PREFIX_PATTERN,
|
|
14
|
+
FOLD_FILE_PREFIX_PATTERN,
|
|
15
|
+
)
|
|
16
|
+
from ._cv_summary import CVSummaryMixin
|
|
17
|
+
from ._feature_importance import FeatureImportanceMixin
|
|
18
|
+
from ._metrics import MetricsMixin
|
|
19
|
+
from ._plotting import PlottingMixin
|
|
20
|
+
from ._reports import ReportsMixin
|
|
21
|
+
from ._storage import StorageMixin
|
|
22
|
+
|
|
23
|
+
# Names re-exported by this package: the three naming constants imported from
# ._constants, followed by the six reporter mixin classes imported above.
__all__ = [
    # File / directory naming constants
    "FILENAME_PATTERNS",
    "FOLD_DIR_PREFIX_PATTERN",
    "FOLD_FILE_PREFIX_PATTERN",
    # Reporter mixins
    "MetricsMixin",
    "StorageMixin",
    "PlottingMixin",
    "FeatureImportanceMixin",
    "CVSummaryMixin",
    "ReportsMixin",
]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# EOF
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_mixins/_constants.py
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Constants for classification reporter file naming.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
# Fold directory and filename prefixes for consistent naming
|
|
10
|
+
FOLD_DIR_PREFIX_PATTERN = "fold_{fold:02d}"  # Directory: fold_00, fold_01, ...
FOLD_FILE_PREFIX_PATTERN = "fold-{fold:02d}"  # Filename prefix: fold-00_, fold-01_, ...

# Filename templates shared across the reporter; every value is a
# ``str.format`` template.
# Per-fold names start with the fold prefix so files group by fold when sorted.
# Convention: hyphens within chunks, underscores between chunks.
FILENAME_PATTERNS = {
    # Per-fold artifacts: each template is the fold file prefix plus a tail.
    **{
        key: FOLD_FILE_PREFIX_PATTERN + tail
        for key, tail in (
            # Individual fold metrics (optionally with the value in the name)
            ("fold_metric_with_value", "_{metric_name}-{value:.3f}.json"),
            ("fold_metric", "_{metric_name}.json"),
            # Confusion matrix
            ("confusion_matrix_csv", "_confusion-matrix_bacc-{bacc:.3f}.csv"),
            ("confusion_matrix_csv_no_bacc", "_confusion-matrix.csv"),
            ("confusion_matrix_jpg", "_confusion-matrix_bacc-{bacc:.3f}.jpg"),
            ("confusion_matrix_jpg_no_bacc", "_confusion-matrix.jpg"),
            # Classification report
            ("classification_report", "_classification-report.csv"),
            # ROC curve
            ("roc_curve_csv", "_roc-curve_auc-{auc:.3f}.csv"),
            ("roc_curve_csv_no_auc", "_roc-curve.csv"),
            ("roc_curve_jpg", "_roc-curve_auc-{auc:.3f}.jpg"),
            ("roc_curve_jpg_no_auc", "_roc-curve.jpg"),
            # PR curve
            ("pr_curve_csv", "_pr-curve_ap-{ap:.3f}.csv"),
            ("pr_curve_csv_no_ap", "_pr-curve.csv"),
            ("pr_curve_jpg", "_pr-curve_ap-{ap:.3f}.jpg"),
            ("pr_curve_jpg_no_ap", "_pr-curve.jpg"),
            # Raw prediction data
            ("y_true", "_y-true.csv"),
            ("y_pred", "_y-pred.csv"),
            ("y_proba", "_y-proba.csv"),
            # Metrics dashboard
            ("metrics_summary", "_metrics-summary.jpg"),
            # Feature importance
            ("feature_importance_json", "_feature-importance.json"),
            ("feature_importance_jpg", "_feature-importance.jpg"),
            # Classification report edge cases
            ("classification_report_json", "_classification-report.json"),
            ("classification_report_txt", "_classification-report.txt"),
        )
    },
    # CV summary artifacts (no fold prefix)
    "cv_summary_metric": "cv-summary_{metric_name}_mean-{mean:.3f}_std-{std:.3f}_n-{n_folds}.json",
    "cv_summary_confusion_matrix_csv": "cv-summary_confusion-matrix_bacc-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
    "cv_summary_confusion_matrix_jpg": "cv-summary_confusion-matrix_bacc-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
    "cv_summary_classification_report": "cv-summary_classification-report_n-{n_folds}.csv",
    "cv_summary_roc_curve_csv": "cv-summary_roc-curve_auc-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
    "cv_summary_roc_curve_jpg": "cv-summary_roc-curve_auc-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
    "cv_summary_pr_curve_csv": "cv-summary_pr-curve_ap-{mean:.3f}_{std:.3f}_n-{n_folds}.csv",
    "cv_summary_pr_curve_jpg": "cv-summary_pr-curve_ap-{mean:.3f}_{std:.3f}_n-{n_folds}.jpg",
    "cv_summary_feature_importance_json": "cv-summary_feature-importance_n-{n_folds}.json",
    "cv_summary_feature_importance_jpg": "cv-summary_feature-importance_n-{n_folds}.jpg",
    "cv_summary_summary": "cv-summary_summary.json",
    # CV summary edge cases
    "cv_summary_confusion_matrix_csv_no_bacc": "cv-summary_confusion-matrix_n-{n_folds}.csv",
    "cv_summary_confusion_matrix_jpg_no_bacc": "cv-summary_confusion-matrix_n-{n_folds}.jpg",
}
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# EOF
|
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_mixins/_cv_summary.py
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
CV summary mixin for classification reporter.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Any, Dict
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
from scitex.logging import getLogger
|
|
17
|
+
|
|
18
|
+
from ._constants import FILENAME_PATTERNS
|
|
19
|
+
|
|
20
|
+
logger = getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CVSummaryMixin:
    """Mixin providing CV summary methods.

    The methods read ``self.fold_metrics`` and ``self.all_predictions`` and
    delegate output to ``self.storage``, ``self.plotter``,
    ``self.get_summary()`` and ``self._create_subdir_if_needed()``. None of
    these are defined in this mixin, so the host class must supply them.
    """

    # ------------------------------------------------------------------
    # Small private helpers shared by the public methods below
    # ------------------------------------------------------------------

    def _cv_fold_scalar_values(self, metric_key: str) -> list:
        """Collect one metric across folds, unwrapping ``{"value": ...}`` dicts."""
        values = []
        for metrics in self.fold_metrics.values():
            if metric_key in metrics:
                val = metrics[metric_key]
                if isinstance(val, dict) and "value" in val:
                    values.append(val["value"])
                else:
                    values.append(val)
        return values

    @staticmethod
    def _cv_balanced_accuracy_stats(summary) -> tuple:
        """Return ``(mean, std)`` of balanced accuracy, or ``(None, None)``."""
        if "metrics_summary" in summary:
            metrics_summary = summary["metrics_summary"]
            if "balanced-accuracy" in metrics_summary:
                stats = metrics_summary["balanced-accuracy"]
                return stats.get("mean"), stats.get("std")
        return None, None

    @staticmethod
    def _cv_confusion_matrix_frame(overall_cm, labels) -> pd.DataFrame:
        """Wrap the summed confusion matrix in a DataFrame, labelled if possible."""
        # ``if labels:`` would raise "truth value of an array is ambiguous" for
        # a multi-element NumPy array / pandas Index, so test explicitly.
        if labels is not None and len(labels) > 0:
            return pd.DataFrame(
                overall_cm,
                index=[f"True_{label}" for label in labels],
                columns=[f"Pred_{label}" for label in labels],
            )
        return pd.DataFrame(overall_cm)

    @staticmethod
    def _cv_report_to_dict(report: pd.DataFrame) -> dict:
        """Convert a per-fold classification report DataFrame to a nested dict."""
        if "class" in report.columns:
            # One entry per row, keyed by the "class" column.
            return {
                row["class"]: {
                    col: row[col] for col in report.columns if col != "class"
                }
                for _, row in report.iterrows()
            }
        return report.to_dict("index")

    @staticmethod
    def _cv_mean_std_label(values, n_folds: int) -> str:
        """Format ``values`` as ``"mean ± std (n=n_folds)"`` with 3 decimals."""
        return f"{np.mean(values):.3f} ± {np.std(values):.3f} (n={n_folds})"

    # ------------------------------------------------------------------
    # Curves
    # ------------------------------------------------------------------

    def create_cv_summary_curves(self, summary: Dict[str, Any]) -> None:
        """Create CV summary ROC and PR curves from aggregated predictions.

        Predictions of all folds are concatenated and plotted once. The
        mean/std shown in titles and filenames come from the per-fold
        ``"roc-auc"`` / ``"pr-auc"`` entries of ``self.fold_metrics``; when no
        fold stored a metric, it is computed once on the pooled predictions
        and the std is reported as 0.0.

        Parameters
        ----------
        summary : Dict[str, Any]
            Kept for interface compatibility; this method does not read it.
        """
        if not self.all_predictions:
            logger.warning("No predictions stored for CV summary curves")
            return

        all_y_true = np.concatenate([p["y_true"] for p in self.all_predictions])
        all_y_proba = np.concatenate([p["y_proba"] for p in self.all_predictions])

        roc_values = self._cv_fold_scalar_values("roc-auc")
        pr_values = self._cv_fold_scalar_values("pr-auc")
        n_folds = len(self.fold_metrics)

        if roc_values:
            roc_mean = np.mean(roc_values)
            roc_std = np.std(roc_values)
        else:
            # No per-fold ROC-AUC recorded: fall back to the pooled value.
            from ..reporter_utils.metrics import calc_roc_auc

            overall_roc = calc_roc_auc(all_y_true, all_y_proba)
            roc_mean = overall_roc["value"]
            roc_std = 0.0

        if pr_values:
            pr_mean = np.mean(pr_values)
            pr_std = np.std(pr_values)
        else:
            # No per-fold PR-AUC recorded: fall back to the pooled value.
            from ..reporter_utils.metrics import calc_pre_rec_auc

            overall_pr = calc_pre_rec_auc(all_y_true, all_y_proba)
            pr_mean = overall_pr["value"]
            pr_std = 0.0

        cv_summary_dir = self._create_subdir_if_needed("cv_summary")
        cv_summary_dir.mkdir(parents=True, exist_ok=True)

        # Curve coordinates go to CSV first, then the figures are drawn.
        self._save_cv_summary_curve_data(
            all_y_true, all_y_proba, roc_mean, roc_std, pr_mean, pr_std, n_folds
        )

        from scitex.ai.metrics import _normalize_labels

        all_y_true_norm, _, label_names, _ = _normalize_labels(all_y_true, all_y_true)

        roc_title = f"ROC Curve (CV Summary) - AUC: {roc_mean:.3f} ± {roc_std:.3f} (n={n_folds})"
        roc_filename = FILENAME_PATTERNS["cv_summary_roc_curve_jpg"].format(
            mean=roc_mean, std=roc_std, n_folds=n_folds
        )
        self.plotter.create_overall_roc_curve(
            all_y_true_norm,
            all_y_proba,
            labels=label_names,
            save_path=cv_summary_dir / roc_filename,
            title=roc_title,
            auc_mean=roc_mean,
            auc_std=roc_std,
            verbose=True,
        )

        pr_title = f"Precision-Recall Curve (CV Summary) - AP: {pr_mean:.3f} ± {pr_std:.3f} (n={n_folds})"
        pr_filename = FILENAME_PATTERNS["cv_summary_pr_curve_jpg"].format(
            mean=pr_mean, std=pr_std, n_folds=n_folds
        )
        self.plotter.create_overall_pr_curve(
            all_y_true_norm,
            all_y_proba,
            labels=label_names,
            save_path=cv_summary_dir / pr_filename,
            title=pr_title,
            ap_mean=pr_mean,
            ap_std=pr_std,
            verbose=True,
        )

        logger.info(
            f"Created CV summary ROC curve: AUC = {roc_mean:.3f} ± {roc_std:.3f} (n={n_folds})"
        )
        logger.info(
            f"Created CV summary PR curve: AP = {pr_mean:.3f} ± {pr_std:.3f} (n={n_folds})"
        )

    def _save_cv_summary_curve_data(
        self,
        y_true: np.ndarray,
        y_proba: np.ndarray,
        roc_mean: float,
        roc_std: float,
        pr_mean: float,
        pr_std: float,
        n_folds: int,
    ) -> None:
        """Save CV summary ROC and PR curve data as CSV files.

        Data is written only when ``y_proba`` is 1-D or has exactly two
        columns; for any other shape the method returns without saving.
        The mean/std/n_folds arguments are used for the filenames only.
        """
        from sklearn.metrics import precision_recall_curve, roc_curve

        # Guard clause: same condition as before, negated.
        if y_proba.ndim != 1 and y_proba.shape[1] != 2:
            return

        cv_summary_dir = "cv_summary"

        # For an (n, 2) array the second column is used as the positive score.
        y_proba_pos = y_proba[:, 1] if y_proba.ndim == 2 else y_proba

        from scitex.ai.metrics import _normalize_labels

        y_true_norm, _, _, _ = _normalize_labels(y_true, y_true)

        fpr, tpr, _ = roc_curve(y_true_norm, y_proba_pos)
        roc_df = pd.DataFrame({"FPR": fpr, "TPR": tpr})
        roc_filename = FILENAME_PATTERNS["cv_summary_roc_curve_csv"].format(
            mean=roc_mean, std=roc_std, n_folds=n_folds
        )
        self.storage.save(roc_df, f"{cv_summary_dir}/{roc_filename}")

        precision, recall, _ = precision_recall_curve(y_true_norm, y_proba_pos)
        pr_df = pd.DataFrame({"Recall": recall, "Precision": precision})
        pr_filename = FILENAME_PATTERNS["cv_summary_pr_curve_csv"].format(
            mean=pr_mean, std=pr_std, n_folds=n_folds
        )
        self.storage.save(pr_df, f"{cv_summary_dir}/{pr_filename}")

    # ------------------------------------------------------------------
    # Confusion matrix
    # ------------------------------------------------------------------

    def save_cv_summary_confusion_matrix(self, summary: Dict[str, Any]) -> None:
        """Save and plot the CV summary confusion matrix.

        The per-fold confusion matrices in ``self.fold_metrics`` are summed
        element-wise, written as CSV through ``self.storage`` and plotted
        through ``self.plotter``. Returns without output when no fold holds a
        confusion matrix.

        Parameters
        ----------
        summary : Dict[str, Any]
            Summary whose ``["metrics_summary"]["balanced-accuracy"]`` entry,
            if present, supplies the mean/std used in the CSV filename.
        """
        confusion_matrices = []
        for fold_metrics in self.fold_metrics.values():
            if "confusion_matrix" not in fold_metrics:
                continue
            cm_data = fold_metrics["confusion_matrix"]
            if isinstance(cm_data, dict) and "value" in cm_data:
                cm_data = cm_data["value"]
            if cm_data is not None:
                confusion_matrices.append(cm_data)

        if not confusion_matrices:
            return

        overall_cm = np.sum(confusion_matrices, axis=0)

        # Use the first labels found, either on the fold or on its matrix dict.
        labels = None
        for fold_metrics in self.fold_metrics.values():
            if "labels" in fold_metrics:
                labels = fold_metrics["labels"]
                break
            elif "confusion_matrix" in fold_metrics:
                cm_data = fold_metrics["confusion_matrix"]
                if isinstance(cm_data, dict) and "labels" in cm_data:
                    labels = cm_data["labels"]
                    break

        # Created once; the same directory serves the CSV and the figure.
        cv_summary_dir = self._create_subdir_if_needed("cv_summary")
        cv_summary_dir.mkdir(parents=True, exist_ok=True)

        n_folds = len(self.fold_metrics)

        csv_mean, csv_std = self._cv_balanced_accuracy_stats(summary)
        if csv_mean is not None and csv_std is not None:
            cm_filename = FILENAME_PATTERNS["cv_summary_confusion_matrix_csv"].format(
                mean=csv_mean, std=csv_std, n_folds=n_folds
            )
        else:
            cm_filename = FILENAME_PATTERNS[
                "cv_summary_confusion_matrix_csv_no_bacc"
            ].format(n_folds=n_folds)

        cm_df = self._cv_confusion_matrix_frame(overall_cm, labels)
        self.storage.save(cm_df, f"cv_summary/{cm_filename}", index=True)

        # NOTE(review): the figure reads balanced accuracy from
        # self.get_summary() rather than from the ``summary`` argument used
        # for the CSV above. That source is kept as-is (now fetched once
        # instead of twice); confirm whether both should use ``summary``.
        plot_mean, plot_std = self._cv_balanced_accuracy_stats(self.get_summary())

        if plot_mean is not None and plot_std is not None:
            title = f"Confusion Matrix (CV Summary) - Balanced Acc: {plot_mean:.3f} ± {plot_std:.3f} (n={n_folds})"
            filename = FILENAME_PATTERNS["cv_summary_confusion_matrix_jpg"].format(
                mean=plot_mean, std=plot_std, n_folds=n_folds
            )
        else:
            title = f"Confusion Matrix (CV Summary) (n={n_folds})"
            filename = FILENAME_PATTERNS[
                "cv_summary_confusion_matrix_jpg_no_bacc"
            ].format(n_folds=n_folds)

        self.plotter.create_confusion_matrix_plot(
            overall_cm,
            labels=labels,
            save_path=cv_summary_dir / filename,
            title=title,
        )

    # ------------------------------------------------------------------
    # Scalar metrics and classification report
    # ------------------------------------------------------------------

    def _save_cv_summary_metrics(self, summary: Dict[str, Any]) -> None:
        """Save individual CV summary metrics with mean/std/n_folds in filenames.

        Each dict entry of ``summary["metrics_summary"]`` that contains a
        ``"mean"`` key is saved as its own file; other entries are skipped.
        A missing ``"std"`` is written as 0 in the filename.
        """
        if "metrics_summary" not in summary:
            return

        n_folds = len(self.fold_metrics)
        cv_summary_dir = "cv_summary"

        for metric_name, stats in summary["metrics_summary"].items():
            if not (isinstance(stats, dict) and "mean" in stats):
                continue

            filename = FILENAME_PATTERNS["cv_summary_metric"].format(
                metric_name=metric_name,
                mean=stats["mean"],
                std=stats.get("std", 0),
                n_folds=n_folds,
            )
            self.storage.save(stats, f"{cv_summary_dir}/{filename}")

    def _save_cv_summary_classification_report(self, summary: Dict[str, Any]) -> None:
        """Save CV summary classification report with mean ± std (n_folds=X) format.

        Per-fold reports (dicts or DataFrames) are gathered from
        ``self.fold_metrics``. Precision, recall and f1-score are aggregated
        per class and for the "macro avg" / "weighted avg" rows; support is
        reported as a per-fold mean plus the total. Nothing is written when
        no fold holds a usable report.

        Parameters
        ----------
        summary : Dict[str, Any]
            Kept for interface compatibility; this method does not read it.
        """
        n_folds = len(self.fold_metrics)
        cv_summary_dir = "cv_summary"

        all_reports = []
        for fold_metrics in self.fold_metrics.values():
            if "classification_report" not in fold_metrics:
                continue
            report = fold_metrics["classification_report"]
            if isinstance(report, dict) and "value" in report:
                report = report["value"]
            if isinstance(report, pd.DataFrame):
                report = self._cv_report_to_dict(report)
            if isinstance(report, dict):
                all_reports.append(report)

        if not all_reports:
            return

        summary_report = {}

        # Real classes are every key except the aggregate rows.
        aggregate_rows = ("accuracy", "macro avg", "weighted avg")
        all_classes = set()
        for report in all_reports:
            all_classes.update(k for k in report if k not in aggregate_rows)

        for cls in sorted(all_classes):
            cls_metrics = {
                "precision": [],
                "recall": [],
                "f1-score": [],
                "support": [],
            }
            for report in all_reports:
                if cls in report:
                    for metric in cls_metrics:
                        if metric in report[cls]:
                            cls_metrics[metric].append(report[cls][metric])

            summary_report[cls] = {}
            for metric, values in cls_metrics.items():
                if not values:
                    continue
                if metric == "support":
                    # Support is a count: show the per-fold mean and the total.
                    total_support = int(np.sum(values))
                    mean_support = np.mean(values)
                    std_support = np.std(values)
                    if std_support > 0:
                        summary_report[cls][metric] = (
                            f"{mean_support:.1f} ± {std_support:.1f} (total={total_support})"
                        )
                    else:
                        summary_report[cls][metric] = (
                            f"{int(mean_support)} per fold (total={total_support})"
                        )
                else:
                    summary_report[cls][metric] = self._cv_mean_std_label(
                        values, n_folds
                    )

        for avg_type in ["macro avg", "weighted avg"]:
            avg_metrics = {"precision": [], "recall": [], "f1-score": []}
            for report in all_reports:
                if avg_type in report:
                    for metric in avg_metrics:
                        if metric in report[avg_type]:
                            avg_metrics[metric].append(report[avg_type][metric])

            if any(avg_metrics.values()):
                summary_report[avg_type] = {
                    metric: self._cv_mean_std_label(values, n_folds)
                    for metric, values in avg_metrics.items()
                    if values
                }

        if summary_report:
            report_df = (
                pd.DataFrame(summary_report)
                .T.reset_index()
                .rename(columns={"index": "class"})
            )
            filename = FILENAME_PATTERNS["cv_summary_classification_report"].format(
                n_folds=n_folds
            )
            self.storage.save(report_df, f"{cv_summary_dir}/{filename}")
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# EOF
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_mixins/_feature_importance.py
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Feature importance mixin for classification reporter.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
from scitex.logging import getLogger
|
|
16
|
+
|
|
17
|
+
from ._constants import FILENAME_PATTERNS, FOLD_DIR_PREFIX_PATTERN
|
|
18
|
+
|
|
19
|
+
logger = getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FeatureImportanceMixin:
    """Mixin providing feature importance methods.

    The methods use attributes that the host class must supply:
    ``self.storage`` (called as ``storage.save(obj, relative_path)``),
    ``self.plotter`` (called as ``plotter.create_feature_importance_plot``)
    and ``self.output_dir`` (combined with ``/``; presumably a
    ``pathlib.Path`` -- confirm against the host class).
    """

    def save_feature_importance(
        self,
        model,
        feature_names: List[str],
        fold: Optional[int] = None,
    ) -> Dict[str, float]:
        """Calculate and save feature importance for tree-based models.

        Parameters
        ----------
        model
            Fitted estimator handed to ``calc_feature_importance``.
        feature_names : List[str]
            Feature names, forwarded to the calculation and to the plot.
        fold : Optional[int]
            Fold index. When given, outputs go to the fold directory built
            from ``FOLD_DIR_PREFIX_PATTERN``; when ``None`` they go to
            ``cv_summary``.

        Returns
        -------
        Dict[str, float]
            The importance mapping returned by ``calc_feature_importance``,
            or an empty dict when that call raises ``ValueError``.
        """
        from scitex.ai.metrics import calc_feature_importance

        try:
            importance_dict, importances = calc_feature_importance(model, feature_names)
        except ValueError as e:
            # Best effort: the failure is logged and reported as "no importances".
            logger.warning(f"Could not extract feature importance: {e}")
            return {}

        fold_subdir = (
            FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
            if fold is not None
            else "cv_summary"
        )

        # NOTE(review): both filename patterns are formatted with ``fold=None``
        # on the summary path; confirm the patterns in ``_constants`` do not use
        # a numeric format spec (e.g. ``{fold:02d}``), which would raise here.
        json_filename = FILENAME_PATTERNS["feature_importance_json"].format(fold=fold)
        # Save a copy so the mapping returned to the caller is never shared
        # with the storage layer.
        self.storage.save(dict(importance_dict), f"{fold_subdir}/{json_filename}")

        jpg_filename = FILENAME_PATTERNS["feature_importance_jpg"].format(fold=fold)
        save_path = self.output_dir / fold_subdir / jpg_filename
        save_path.parent.mkdir(parents=True, exist_ok=True)

        self.plotter.create_feature_importance_plot(
            feature_importance=importances,
            feature_names=feature_names,
            save_path=save_path,
            title=(
                f"Feature Importance (Fold {fold:02d})"
                if fold is not None
                else "Feature Importance (CV Summary)"
            ),
        )

        logger.info(
            "Saved feature importance"
            + (f" for fold {fold}" if fold is not None else "")
        )

        return importance_dict

    def save_feature_importance_summary(
        self,
        all_importances: List[Dict[str, float]],
    ) -> None:
        """Create summary visualization of feature importances across all folds.

        For every feature seen in any fold, the mean, standard deviation and
        per-fold values are computed (a fold that lacks the feature contributes
        0). The statistics are saved as JSON, ordered by descending mean, and a
        summary plot is written next to it under ``cv_summary``.

        Parameters
        ----------
        all_importances : List[Dict[str, float]]
            One importance mapping per fold. Nothing is written when empty.
        """
        if not all_importances:
            return

        # Collect feature names in first-seen order. A set would iterate in a
        # hash-dependent order that changes between interpreter runs, which
        # made the position of equal-mean features in the saved JSON
        # irreproducible.
        all_features: Dict[str, None] = {}
        for imp_dict in all_importances:
            all_features.update(dict.fromkeys(imp_dict))

        feature_stats = {}
        for feature in all_features:
            values = [imp_dict.get(feature, 0) for imp_dict in all_importances]
            feature_stats[feature] = {
                "mean": float(np.mean(values)),
                "std": float(np.std(values)),
                "values": [float(v) for v in values],
            }

        # sorted() is stable, so ties keep the first-seen order from above.
        sorted_features = sorted(
            feature_stats.items(), key=lambda x: x[1]["mean"], reverse=True
        )

        n_folds = len(all_importances)
        json_filename = FILENAME_PATTERNS["cv_summary_feature_importance_json"].format(
            n_folds=n_folds
        )
        self.storage.save(dict(sorted_features), f"cv_summary/{json_filename}")

        from scitex.ai.plt import plot_feature_importance_cv_summary

        jpg_filename = FILENAME_PATTERNS["cv_summary_feature_importance_jpg"].format(
            n_folds=n_folds
        )
        save_path = self.output_dir / "cv_summary" / jpg_filename
        save_path.parent.mkdir(parents=True, exist_ok=True)

        plot_feature_importance_cv_summary(
            all_importances=all_importances,
            spath=save_path,
        )

        logger.info("Saved feature importance summary")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# EOF
|