scitex 2.14.0__py3-none-any.whl → 2.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +47 -0
- scitex/_env_loader.py +156 -0
- scitex/_mcp_resources/__init__.py +37 -0
- scitex/_mcp_resources/_cheatsheet.py +135 -0
- scitex/_mcp_resources/_figrecipe.py +138 -0
- scitex/_mcp_resources/_formats.py +102 -0
- scitex/_mcp_resources/_modules.py +337 -0
- scitex/_mcp_resources/_session.py +149 -0
- scitex/_mcp_tools/__init__.py +4 -0
- scitex/_mcp_tools/audio.py +66 -0
- scitex/_mcp_tools/diagram.py +11 -95
- scitex/_mcp_tools/introspect.py +191 -0
- scitex/_mcp_tools/plt.py +260 -305
- scitex/_mcp_tools/scholar.py +74 -0
- scitex/_mcp_tools/social.py +244 -0
- scitex/_mcp_tools/writer.py +21 -204
- scitex/ai/_gen_ai/_PARAMS.py +10 -7
- scitex/ai/classification/reporters/_SingleClassificationReporter.py +45 -1603
- scitex/ai/classification/reporters/_mixins/__init__.py +36 -0
- scitex/ai/classification/reporters/_mixins/_constants.py +67 -0
- scitex/ai/classification/reporters/_mixins/_cv_summary.py +387 -0
- scitex/ai/classification/reporters/_mixins/_feature_importance.py +119 -0
- scitex/ai/classification/reporters/_mixins/_metrics.py +275 -0
- scitex/ai/classification/reporters/_mixins/_plotting.py +179 -0
- scitex/ai/classification/reporters/_mixins/_reports.py +153 -0
- scitex/ai/classification/reporters/_mixins/_storage.py +160 -0
- scitex/audio/README.md +40 -36
- scitex/audio/__init__.py +127 -59
- scitex/audio/_branding.py +185 -0
- scitex/audio/_mcp/__init__.py +32 -0
- scitex/audio/_mcp/handlers.py +59 -6
- scitex/audio/_mcp/speak_handlers.py +238 -0
- scitex/audio/_relay.py +225 -0
- scitex/audio/engines/elevenlabs_engine.py +6 -1
- scitex/audio/mcp_server.py +228 -75
- scitex/canvas/README.md +1 -1
- scitex/canvas/editor/_dearpygui/__init__.py +25 -0
- scitex/canvas/editor/_dearpygui/_editor.py +147 -0
- scitex/canvas/editor/_dearpygui/_handlers.py +476 -0
- scitex/canvas/editor/_dearpygui/_panels/__init__.py +17 -0
- scitex/canvas/editor/_dearpygui/_panels/_control.py +119 -0
- scitex/canvas/editor/_dearpygui/_panels/_element_controls.py +190 -0
- scitex/canvas/editor/_dearpygui/_panels/_preview.py +43 -0
- scitex/canvas/editor/_dearpygui/_panels/_sections.py +390 -0
- scitex/canvas/editor/_dearpygui/_plotting.py +187 -0
- scitex/canvas/editor/_dearpygui/_rendering.py +504 -0
- scitex/canvas/editor/_dearpygui/_selection.py +295 -0
- scitex/canvas/editor/_dearpygui/_state.py +93 -0
- scitex/canvas/editor/_dearpygui/_utils.py +61 -0
- scitex/canvas/editor/flask_editor/templates/__init__.py +32 -70
- scitex/cli/__init__.py +38 -43
- scitex/cli/audio.py +76 -27
- scitex/cli/capture.py +13 -20
- scitex/cli/introspect.py +443 -0
- scitex/cli/main.py +198 -109
- scitex/cli/mcp.py +60 -34
- scitex/cli/scholar/__init__.py +8 -0
- scitex/cli/scholar/_crossref_scitex.py +296 -0
- scitex/cli/scholar/_fetch.py +25 -3
- scitex/cli/social.py +314 -0
- scitex/cli/writer.py +117 -0
- scitex/config/README.md +1 -1
- scitex/config/__init__.py +16 -2
- scitex/config/_env_registry.py +191 -0
- scitex/diagram/__init__.py +42 -19
- scitex/diagram/mcp_server.py +13 -125
- scitex/introspect/__init__.py +75 -0
- scitex/introspect/_call_graph.py +303 -0
- scitex/introspect/_class_hierarchy.py +163 -0
- scitex/introspect/_core.py +42 -0
- scitex/introspect/_docstring.py +131 -0
- scitex/introspect/_examples.py +113 -0
- scitex/introspect/_imports.py +271 -0
- scitex/introspect/_mcp/__init__.py +37 -0
- scitex/introspect/_mcp/handlers.py +208 -0
- scitex/introspect/_members.py +151 -0
- scitex/introspect/_resolve.py +89 -0
- scitex/introspect/_signature.py +131 -0
- scitex/introspect/_source.py +80 -0
- scitex/introspect/_type_hints.py +172 -0
- scitex/io/bundle/README.md +1 -1
- scitex/mcp_server.py +98 -5
- scitex/plt/__init__.py +248 -550
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py +5 -10
- scitex/plt/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/plt/gallery/README.md +1 -1
- scitex/plt/utils/_hitmap/__init__.py +82 -0
- scitex/plt/utils/_hitmap/_artist_extraction.py +343 -0
- scitex/plt/utils/_hitmap/_color_application.py +346 -0
- scitex/plt/utils/_hitmap/_color_conversion.py +121 -0
- scitex/plt/utils/_hitmap/_constants.py +40 -0
- scitex/plt/utils/_hitmap/_hitmap_core.py +334 -0
- scitex/plt/utils/_hitmap/_path_extraction.py +357 -0
- scitex/plt/utils/_hitmap/_query.py +113 -0
- scitex/plt/utils/_hitmap.py +46 -1616
- scitex/plt/utils/_metadata/__init__.py +80 -0
- scitex/plt/utils/_metadata/_artists/__init__.py +25 -0
- scitex/plt/utils/_metadata/_artists/_base.py +195 -0
- scitex/plt/utils/_metadata/_artists/_collections.py +356 -0
- scitex/plt/utils/_metadata/_artists/_extract.py +57 -0
- scitex/plt/utils/_metadata/_artists/_images.py +80 -0
- scitex/plt/utils/_metadata/_artists/_lines.py +261 -0
- scitex/plt/utils/_metadata/_artists/_patches.py +247 -0
- scitex/plt/utils/_metadata/_artists/_text.py +106 -0
- scitex/plt/utils/_metadata/_csv.py +416 -0
- scitex/plt/utils/_metadata/_detect.py +225 -0
- scitex/plt/utils/_metadata/_legend.py +127 -0
- scitex/plt/utils/_metadata/_rounding.py +117 -0
- scitex/plt/utils/_metadata/_verification.py +202 -0
- scitex/schema/README.md +1 -1
- scitex/scholar/__init__.py +8 -0
- scitex/scholar/_mcp/crossref_handlers.py +265 -0
- scitex/scholar/core/Scholar.py +63 -1700
- scitex/scholar/core/_mixins/__init__.py +36 -0
- scitex/scholar/core/_mixins/_enrichers.py +270 -0
- scitex/scholar/core/_mixins/_library_handlers.py +100 -0
- scitex/scholar/core/_mixins/_loaders.py +103 -0
- scitex/scholar/core/_mixins/_pdf_download.py +375 -0
- scitex/scholar/core/_mixins/_pipeline.py +312 -0
- scitex/scholar/core/_mixins/_project_handlers.py +125 -0
- scitex/scholar/core/_mixins/_savers.py +69 -0
- scitex/scholar/core/_mixins/_search.py +103 -0
- scitex/scholar/core/_mixins/_services.py +88 -0
- scitex/scholar/core/_mixins/_url_finding.py +105 -0
- scitex/scholar/crossref_scitex.py +367 -0
- scitex/scholar/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/scholar/examples/00_run_all.sh +120 -0
- scitex/scholar/jobs/_executors.py +27 -3
- scitex/scholar/pdf_download/ScholarPDFDownloader.py +38 -416
- scitex/scholar/pdf_download/_cli.py +154 -0
- scitex/scholar/pdf_download/strategies/__init__.py +11 -8
- scitex/scholar/pdf_download/strategies/manual_download_fallback.py +80 -3
- scitex/scholar/pipelines/ScholarPipelineBibTeX.py +73 -121
- scitex/scholar/pipelines/ScholarPipelineParallel.py +80 -138
- scitex/scholar/pipelines/ScholarPipelineSingle.py +43 -63
- scitex/scholar/pipelines/_single_steps.py +71 -36
- scitex/scholar/storage/_LibraryManager.py +97 -1695
- scitex/scholar/storage/_mixins/__init__.py +30 -0
- scitex/scholar/storage/_mixins/_bibtex_handlers.py +128 -0
- scitex/scholar/storage/_mixins/_library_operations.py +218 -0
- scitex/scholar/storage/_mixins/_metadata_conversion.py +226 -0
- scitex/scholar/storage/_mixins/_paper_saving.py +456 -0
- scitex/scholar/storage/_mixins/_resolution.py +376 -0
- scitex/scholar/storage/_mixins/_storage_helpers.py +121 -0
- scitex/scholar/storage/_mixins/_symlink_handlers.py +226 -0
- scitex/scholar/url_finder/.tmp/open_url/KNOWN_RESOLVERS.py +462 -0
- scitex/scholar/url_finder/.tmp/open_url/README.md +223 -0
- scitex/scholar/url_finder/.tmp/open_url/_DOIToURLResolver.py +694 -0
- scitex/scholar/url_finder/.tmp/open_url/_OpenURLResolver.py +1160 -0
- scitex/scholar/url_finder/.tmp/open_url/_ResolverLinkFinder.py +344 -0
- scitex/scholar/url_finder/.tmp/open_url/__init__.py +24 -0
- scitex/security/README.md +3 -3
- scitex/session/README.md +1 -1
- scitex/sh/README.md +1 -1
- scitex/social/__init__.py +153 -0
- scitex/social/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/template/README.md +1 -1
- scitex/template/clone_writer_directory.py +5 -5
- scitex/writer/README.md +1 -1
- scitex/writer/_mcp/handlers.py +11 -744
- scitex/writer/_mcp/tool_schemas.py +5 -335
- scitex-2.15.1.dist-info/METADATA +648 -0
- {scitex-2.14.0.dist-info → scitex-2.15.1.dist-info}/RECORD +166 -111
- scitex/canvas/editor/flask_editor/templates/_scripts.py +0 -4933
- scitex/canvas/editor/flask_editor/templates/_styles.py +0 -1658
- scitex/dev/plt/data/mpl/PLOTTING_FUNCTIONS.yaml +0 -90
- scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES.yaml +0 -1571
- scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES_DETAILED.yaml +0 -6262
- scitex/dev/plt/data/mpl/SIGNATURES_FLATTENED.yaml +0 -1274
- scitex/dev/plt/data/mpl/dir_ax.txt +0 -459
- scitex/diagram/_compile.py +0 -312
- scitex/diagram/_diagram.py +0 -355
- scitex/diagram/_mcp/__init__.py +0 -4
- scitex/diagram/_mcp/handlers.py +0 -400
- scitex/diagram/_mcp/tool_schemas.py +0 -157
- scitex/diagram/_presets.py +0 -173
- scitex/diagram/_schema.py +0 -182
- scitex/diagram/_split.py +0 -278
- scitex/plt/_mcp/__init__.py +0 -4
- scitex/plt/_mcp/_handlers_annotation.py +0 -102
- scitex/plt/_mcp/_handlers_figure.py +0 -195
- scitex/plt/_mcp/_handlers_plot.py +0 -252
- scitex/plt/_mcp/_handlers_style.py +0 -219
- scitex/plt/_mcp/handlers.py +0 -74
- scitex/plt/_mcp/tool_schemas.py +0 -497
- scitex/plt/mcp_server.py +0 -231
- scitex/scholar/data/.gitkeep +0 -0
- scitex/scholar/data/README.md +0 -44
- scitex/scholar/data/bib_files/bibliography.bib +0 -1952
- scitex/scholar/data/bib_files/neurovista.bib +0 -277
- scitex/scholar/data/bib_files/neurovista_enriched.bib +0 -441
- scitex/scholar/data/bib_files/neurovista_enriched_enriched.bib +0 -441
- scitex/scholar/data/bib_files/neurovista_processed.bib +0 -338
- scitex/scholar/data/bib_files/openaccess.bib +0 -89
- scitex/scholar/data/bib_files/pac-seizure_prediction_enriched.bib +0 -2178
- scitex/scholar/data/bib_files/pac.bib +0 -698
- scitex/scholar/data/bib_files/pac_enriched.bib +0 -1061
- scitex/scholar/data/bib_files/pac_processed.bib +0 -0
- scitex/scholar/data/bib_files/pac_titles.txt +0 -75
- scitex/scholar/data/bib_files/paywalled.bib +0 -98
- scitex/scholar/data/bib_files/related-papers-by-coauthors.bib +0 -58
- scitex/scholar/data/bib_files/related-papers-by-coauthors_enriched.bib +0 -87
- scitex/scholar/data/bib_files/seizure_prediction.bib +0 -694
- scitex/scholar/data/bib_files/seizure_prediction_processed.bib +0 -0
- scitex/scholar/data/bib_files/test_complete_enriched.bib +0 -437
- scitex/scholar/data/bib_files/test_final_enriched.bib +0 -437
- scitex/scholar/data/bib_files/test_seizure.bib +0 -46
- scitex/scholar/data/impact_factor/JCR_IF_2022.xlsx +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024.db +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024.xlsx +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024_v01.db +0 -0
- scitex/scholar/data/impact_factor.db +0 -0
- scitex/scholar/examples/SUGGESTIONS.md +0 -865
- scitex/scholar/examples/dev.py +0 -38
- scitex-2.14.0.dist-info/METADATA +0 -1238
- {scitex-2.14.0.dist-info → scitex-2.15.1.dist-info}/WHEEL +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.1.dist-info}/entry_points.txt +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_mixins/_metrics.py
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Metrics calculation mixin for classification reporter.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from pprint import pprint
|
|
12
|
+
from typing import Any, Dict, List, Optional
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
|
|
17
|
+
from scitex.logging import getLogger
|
|
18
|
+
|
|
19
|
+
from ._constants import FILENAME_PATTERNS, FOLD_DIR_PREFIX_PATTERN
|
|
20
|
+
|
|
21
|
+
logger = getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class MetricsMixin:
    """Mixin providing metrics calculation methods.

    The host class must supply the attributes and helpers used below but
    not defined here: ``fold_metrics`` (mapping keyed by fold),
    ``all_predictions`` (list), ``storage`` (object exposing ``save``),
    ``output_dir``, ``_round_numeric``, ``_save_fold_metrics`` and
    ``_create_plots``.
    """

    def calculate_metrics(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        y_proba: Optional[np.ndarray] = None,
        labels: Optional[List[str]] = None,
        fold: Optional[int] = None,
        verbose: bool = True,
        store_y_true: bool = True,
        store_y_pred: bool = True,
        store_y_proba: bool = True,
        model: Optional[Any] = None,
        feature_names: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Calculate and save classification metrics using unified API.

        Parameters
        ----------
        y_true, y_pred : np.ndarray
            Ground-truth and predicted labels; must have equal length.
        y_proba : np.ndarray, optional
            Predicted probabilities. When given, ROC-AUC / PR-AUC are
            attempted and the fold is recorded in ``self.all_predictions``.
        labels : list of str, optional
            Class names. Defaults to ``Class_<value>`` for every distinct
            value found in ``y_true`` and ``y_pred``.
        fold : int, optional
            Fold identifier; ``None`` is stored as fold ``0``.
        verbose : bool
            Log progress and pretty-print the resulting metrics.
        store_y_true, store_y_pred, store_y_proba : bool
            Select which raw arrays are written out as CSV.
        model, feature_names : optional
            When both are given, feature importance is extracted and saved.

        Returns
        -------
        dict
            Metrics of this fold (after ``self._round_numeric``) plus the
            ``"labels"`` entry.

        Raises
        ------
        ValueError
            If ``y_pred`` or ``y_proba`` differ in length from ``y_true``.
        """
        from ..reporter_utils import (
            calc_bacc,
            calc_clf_report,
            calc_conf_mat,
            calc_mcc,
        )

        if verbose:
            # Fold 0 is a legitimate fold id, so test against None rather
            # than truthiness (``if fold:`` skipped the fold-specific log).
            if fold is not None:
                print()
                logger.info(f"Calculating metrics for fold #{fold:02d}...")
            else:
                logger.info("Calculating metrics...")

        if len(y_true) != len(y_pred):
            raise ValueError("y_true and y_pred must have same length")

        if y_proba is not None and len(y_true) != len(y_proba):
            raise ValueError("y_true and y_proba must have same length")

        if fold is None:
            fold = 0

        if labels is None:
            unique_labels = sorted(np.unique(np.concatenate([y_true, y_pred])))
            labels = [f"Class_{i}" for i in unique_labels]

        metrics = {}
        metrics["balanced-accuracy"] = calc_bacc(y_true, y_pred, fold=fold)
        metrics["mcc"] = calc_mcc(y_true, y_pred, fold=fold)
        metrics["confusion_matrix"] = calc_conf_mat(
            y_true=y_true, y_pred=y_pred, labels=labels, fold=fold
        )
        metrics["classification_report"] = calc_clf_report(
            y_true, y_pred, labels, fold=fold
        )

        if y_proba is not None:
            # AUC metrics are best-effort: a failure is logged, not raised.
            try:
                from scitex.ai.metrics import calc_pre_rec_auc, calc_roc_auc

                metrics["roc-auc"] = calc_roc_auc(
                    y_true, y_proba, labels=labels, fold=fold, return_curve=False
                )
                metrics["pr-auc"] = calc_pre_rec_auc(
                    y_true, y_proba, labels=labels, fold=fold, return_curve=False
                )
            except Exception as e:
                logger.warning(f"Could not calculate AUC metrics: {e}")

        metrics = self._round_numeric(metrics)
        metrics["labels"] = labels

        if verbose:
            logger.info("Metrics calculated:")
            pprint(metrics)

        # Shallow copy: later additions to ``metrics`` (feature importance)
        # are mirrored into ``fold_metrics`` explicitly by the helper below.
        self.fold_metrics[fold] = metrics.copy()

        if y_proba is not None:
            self.all_predictions.append(
                {
                    "fold": fold,
                    "y_true": y_true.copy(),
                    "y_proba": y_proba.copy(),
                }
            )

        self._save_fold_metrics(metrics, fold, labels)
        self._create_plots(y_true, y_pred, y_proba, labels, fold, metrics)

        if model is not None and feature_names is not None:
            self._extract_and_save_feature_importance(
                model, feature_names, fold, metrics, verbose
            )

        if store_y_true or store_y_pred or store_y_proba:
            self._store_raw_predictions(
                y_true, y_pred, y_proba, fold, store_y_true, store_y_pred, store_y_proba
            )

        return metrics

    def _extract_and_save_feature_importance(
        self, model, feature_names, fold, metrics, verbose
    ):
        """Extract feature importance from ``model`` and persist it.

        On success the importance dict is added to ``metrics`` and to
        ``self.fold_metrics[fold]`` under ``"feature-importance"`` and saved
        through ``self.storage``. Any failure is logged as a warning and
        otherwise ignored (best-effort step).
        """
        try:
            from scitex.ai.feature_selection import extract_feature_importance

            importance_dict = extract_feature_importance(
                model, feature_names, method="auto"
            )
            if importance_dict:
                metrics["feature-importance"] = importance_dict
                self.fold_metrics[fold]["feature-importance"] = importance_dict

                fold_dir = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
                filename = FILENAME_PATTERNS["feature_importance_json"].format(
                    fold=fold
                )
                self.storage.save(importance_dict, f"{fold_dir}/{filename}")

                if verbose:
                    logger.info("  Feature importance extracted and saved")
        except Exception as e:
            logger.warning(f"Could not extract feature importance: {e}")

    def _store_raw_predictions(
        self, y_true, y_pred, y_proba, fold, store_y_true, store_y_pred, store_y_proba
    ):
        """Store raw prediction data as CSV files.

        One table per selected array is saved through ``self.storage`` in
        the fold directory; every table carries ``sample_index`` and
        ``fold`` columns. A warning is logged first when the rough size
        estimate exceeds 10 MB.
        """
        fold_dir = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
        n_samples = len(y_true)
        sample_indices = np.arange(n_samples)
        save_proba = store_y_proba and y_proba is not None

        # Rough estimate: 0.0001 MB per stored value.
        estimated_size_mb = 0
        if store_y_true:
            estimated_size_mb += n_samples * 0.0001
        if store_y_pred:
            estimated_size_mb += n_samples * 0.0001
        if save_proba:
            n_classes = 1 if y_proba.ndim == 1 else y_proba.shape[1]
            estimated_size_mb += n_samples * n_classes * 0.0001

        if estimated_size_mb > 10:
            logger.warning(
                f"Storing raw predictions for fold {fold} will create "
                f"~{estimated_size_mb:.1f}MB of CSV files."
            )

        if store_y_true:
            filename = FILENAME_PATTERNS["y_true"].format(fold=fold)
            df_y_true = pd.DataFrame(
                {"sample_index": sample_indices, "fold": fold, "y_true": y_true}
            )
            self.storage.save(df_y_true, f"{fold_dir}/{filename}")

        if store_y_pred:
            filename = FILENAME_PATTERNS["y_pred"].format(fold=fold)
            df_y_pred = pd.DataFrame(
                {"sample_index": sample_indices, "fold": fold, "y_pred": y_pred}
            )
            self.storage.save(df_y_pred, f"{fold_dir}/{filename}")

        if save_proba:
            filename = FILENAME_PATTERNS["y_proba"].format(fold=fold)
            if y_proba.ndim == 1:
                df_y_proba = pd.DataFrame(
                    {"sample_index": sample_indices, "fold": fold, "y_proba": y_proba}
                )
            else:
                # One column per class for 2-D probability arrays.
                data = {"sample_index": sample_indices, "fold": fold}
                for i in range(y_proba.shape[1]):
                    data[f"proba_class_{i}"] = y_proba[:, i]
                df_y_proba = pd.DataFrame(data)
            self.storage.save(df_y_proba, f"{fold_dir}/{filename}")

    def _extract_metric_value(self, metric_data: Any) -> Optional[float]:
        """Extract numeric value from metric data.

        Accepts either a plain number or a dict wrapping the number under
        ``"value"``. Returns ``None`` when no number can be obtained
        (including a wrapped ``None`` or a value ``float()`` rejects).
        """
        if isinstance(metric_data, dict) and "value" in metric_data:
            wrapped = metric_data["value"]
            if wrapped is None:
                return None
            try:
                return float(wrapped)
            except (TypeError, ValueError):
                return None
        if isinstance(metric_data, (int, float, np.number)):
            return float(metric_data)
        return None

    def get_summary(self) -> Dict[str, Any]:
        """Get summary of all calculated metrics across folds.

        Returns
        -------
        dict
            ``{"error": ...}`` when no fold has been recorded. Otherwise a
            dict with ``output_dir``, ``total_folds``, ``metrics_summary``
            (mean/std/min/max/values per scalar metric), the summed and
            normalised confusion matrix when available, and aggregated
            ``feature-importance`` when any fold recorded one.
        """
        if not self.fold_metrics:
            return {"error": "No metrics calculated yet"}

        summary = {
            "output_dir": str(self.output_dir),
            "total_folds": len(self.fold_metrics),
            "metrics_summary": {},
        }

        confusion_matrices = []
        for fold_metrics in self.fold_metrics.values():
            if "confusion_matrix" in fold_metrics:
                cm_data = fold_metrics["confusion_matrix"]
                if isinstance(cm_data, dict) and "value" in cm_data:
                    cm_data = cm_data["value"]
                if cm_data is not None:
                    confusion_matrices.append(cm_data)

        if confusion_matrices:
            overall_cm = np.sum(confusion_matrices, axis=0)
            summary["overall_confusion_matrix"] = overall_cm.tolist()
            overall_cm_normalized = overall_cm / overall_cm.sum()
            summary["overall_confusion_matrix_normalized"] = self._round_numeric(
                overall_cm_normalized.tolist()
            )

        scalar_metrics = ["balanced-accuracy", "mcc", "roc-auc", "pr-auc"]
        for metric_name in scalar_metrics:
            values = []
            for fold_metrics in self.fold_metrics.values():
                if metric_name not in fold_metrics:
                    continue
                # Reuse the shared unwrapping helper; folds whose metric is
                # missing/None/non-numeric are skipped instead of breaking
                # the aggregate statistics.
                value = self._extract_metric_value(fold_metrics[metric_name])
                if value is not None:
                    values.append(value)

            if values:
                values = np.array(values)
                summary["metrics_summary"][metric_name] = {
                    "mean": self._round_numeric(np.mean(values)),
                    "std": self._round_numeric(np.std(values)),
                    "min": self._round_numeric(np.min(values)),
                    "max": self._round_numeric(np.max(values)),
                    "values": self._round_numeric(values.tolist()),
                }

        feature_importances_list = [
            fold_metrics["feature-importance"]
            for fold_metrics in self.fold_metrics.values()
            if "feature-importance" in fold_metrics
        ]

        if feature_importances_list:
            from scitex.ai.feature_selection import aggregate_feature_importances

            aggregated_importances = aggregate_feature_importances(
                feature_importances_list
            )
            summary["feature-importance"] = aggregated_importances

        return summary
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
# EOF
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_mixins/_plotting.py
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Plotting mixin for classification reporter.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any, Dict, List, Optional
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from scitex.logging import getLogger
|
|
17
|
+
|
|
18
|
+
from ._constants import FILENAME_PATTERNS, FOLD_DIR_PREFIX_PATTERN
|
|
19
|
+
|
|
20
|
+
logger = getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PlottingMixin:
    """Mixin providing plotting methods.

    The host class must supply ``plotter`` (object exposing the
    ``create_*`` plotting calls used below), ``all_predictions`` and
    ``_create_subdir_if_needed``; none of them are defined in this mixin.
    """

    @staticmethod
    def _plot_metric_value(metric_data: Any) -> Optional[float]:
        """Return a metric as ``float``, or ``None`` when it is unusable.

        Accepts a plain number or a dict wrapping the number under
        ``"value"``. Shared by all per-fold plot helpers so they agree on
        when a metric may be embedded in a filename or title.
        """
        if isinstance(metric_data, dict):
            metric_data = metric_data.get("value")
        if isinstance(metric_data, bool):
            return None
        if isinstance(metric_data, (int, float, np.number)):
            return float(metric_data)
        return None

    def _create_plots(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        y_proba: Optional[np.ndarray],
        labels: List[str],
        fold: int,
        metrics: Dict[str, Any],
    ) -> None:
        """Create and save plots with metric-based filenames.

        Produces, inside the fold directory: the confusion matrix (when
        present in ``metrics``), ROC and PR curves (when ``y_proba`` is
        given) and the metrics dashboard.
        """
        fold_dir = self._create_subdir_if_needed(
            FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
        )
        fold_dir.mkdir(parents=True, exist_ok=True)

        # Confusion matrix plot
        if "confusion_matrix" in metrics:
            self._create_confusion_matrix_plot(metrics, labels, fold, fold_dir)

        # ROC and PR curves
        if y_proba is not None:
            self._create_roc_curve_plot(
                y_true, y_proba, labels, fold, fold_dir, metrics
            )
            self._create_pr_curve_plot(y_true, y_proba, labels, fold, fold_dir, metrics)

        # Metrics dashboard
        summary_filename = FILENAME_PATTERNS["metrics_summary"].format(fold=fold)
        self.plotter.create_metrics_visualization(
            metrics=metrics,
            y_true=y_true,
            y_pred=y_pred,
            y_proba=y_proba,
            labels=labels,
            save_path=fold_dir / summary_filename,
            title="Classification Metrics Dashboard",
            fold=fold,
            verbose=False,
        )

    def _create_confusion_matrix_plot(self, metrics, labels, fold, fold_dir):
        """Create confusion matrix plot.

        The balanced accuracy is embedded in title and filename when it is
        available as a number; otherwise the metric-free variants are used.
        """
        cm_data = metrics["confusion_matrix"]
        if isinstance(cm_data, dict) and "value" in cm_data:
            cm_data = cm_data["value"]

        balanced_acc = self._plot_metric_value(metrics.get("balanced-accuracy"))

        if balanced_acc is not None:
            title = (
                f"Confusion Matrix (Fold {fold:02d}) - Balanced Acc: {balanced_acc:.3f}"
            )
            filename = FILENAME_PATTERNS["confusion_matrix_jpg"].format(
                fold=fold, bacc=balanced_acc
            )
        else:
            title = f"Confusion Matrix (Fold {fold:02d})"
            filename = FILENAME_PATTERNS["confusion_matrix_jpg_no_bacc"].format(
                fold=fold
            )

        self.plotter.create_confusion_matrix_plot(
            cm_data, labels=labels, save_path=fold_dir / filename, title=title
        )

    def _create_roc_curve_plot(self, y_true, y_proba, labels, fold, fold_dir, metrics):
        """Create ROC curve plot.

        The ROC-AUC is embedded in the filename when it is available as a
        number (plain or wrapped under ``"value"``); otherwise the
        metric-free filename is used.
        """
        roc_auc_val = self._plot_metric_value(metrics.get("roc-auc"))
        if roc_auc_val is not None:
            roc_filename = FILENAME_PATTERNS["roc_curve_jpg"].format(
                fold=fold, auc=roc_auc_val
            )
        else:
            roc_filename = FILENAME_PATTERNS["roc_curve_jpg_no_auc"].format(fold=fold)

        self.plotter.create_roc_curve(
            y_true,
            y_proba,
            labels=labels,
            save_path=fold_dir / roc_filename,
            title=f"ROC Curve (Fold {fold:02d})",
        )

    def _create_pr_curve_plot(self, y_true, y_proba, labels, fold, fold_dir, metrics):
        """Create precision-recall curve plot.

        The PR-AUC is embedded in the filename when it is available as a
        number (plain or wrapped under ``"value"``); otherwise the
        metric-free filename is used.
        """
        pr_auc_val = self._plot_metric_value(metrics.get("pr-auc"))
        if pr_auc_val is not None:
            pr_filename = FILENAME_PATTERNS["pr_curve_jpg"].format(
                fold=fold, ap=pr_auc_val
            )
        else:
            pr_filename = FILENAME_PATTERNS["pr_curve_jpg_no_ap"].format(fold=fold)

        self.plotter.create_precision_recall_curve(
            y_true,
            y_proba,
            labels=labels,
            save_path=fold_dir / pr_filename,
            title=f"Precision-Recall Curve (Fold {fold:02d})",
        )

    def create_cv_aggregation_visualizations(
        self,
        output_dir: Optional[Path] = None,
        show_individual_folds: bool = True,
        fold_alpha: float = 0.15,
    ) -> None:
        """Create CV aggregation visualizations with faded individual fold lines.

        Parameters
        ----------
        output_dir : Path or str, optional
            Destination directory; defaults to the ``cv_summary``
            sub-directory obtained from ``_create_subdir_if_needed``.
        show_individual_folds : bool
            Forwarded to the plotter.
        fold_alpha : float
            Forwarded to the plotter.

        Does nothing except log a warning when no predictions were stored.
        """
        if not self.all_predictions:
            logger.warning("No predictions stored for CV aggregation visualizations")
            return

        if output_dir is None:
            output_dir = self._create_subdir_if_needed("cv_summary")
        else:
            # Tolerate plain string paths from callers.
            output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        n_folds = len(self.all_predictions)

        # (curve_type, short tag used in the log line, title prefix)
        curve_specs = (
            ("roc", "ROC", "ROC Curves"),
            ("pr", "PR", "Precision-Recall Curves"),
        )
        for curve_type, tag, title_prefix in curve_specs:
            save_path = output_dir / f"{curve_type}_cv_aggregation_n{n_folds}.jpg"
            self.plotter.create_cv_aggregation_plot(
                fold_predictions=self.all_predictions,
                curve_type=curve_type,
                save_path=save_path,
                show_individual_folds=show_individual_folds,
                fold_alpha=fold_alpha,
                title=f"{title_prefix} - Cross Validation (n={n_folds} folds)",
                verbose=True,
            )
            logger.info(f"Created CV aggregation {tag} plot with faded fold lines")
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# EOF
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Timestamp: "2026-01-24 (ywatanabe)"
|
|
3
|
+
# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/reporters/_mixins/_reports.py
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Reports generation mixin for classification reporter.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Dict
|
|
14
|
+
|
|
15
|
+
from scitex.logging import getLogger
|
|
16
|
+
|
|
17
|
+
from ._constants import FOLD_DIR_PREFIX_PATTERN
|
|
18
|
+
|
|
19
|
+
logger = getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ReportsMixin:
    """Mixin providing report generation methods.

    The host class must supply ``fold_metrics``, ``output_dir`` (used with
    the ``/`` path operator), ``session_config``, ``get_summary`` and
    ``_create_subdir_if_needed``; none of them are defined in this mixin.
    """

    # Sample-count keys copied from a fold's features.json into the report.
    _REPORT_SAMPLE_COUNT_KEYS = (
        "n_train",
        "n_test",
        "n_train_seizure",
        "n_train_interictal",
        "n_test_seizure",
        "n_test_interictal",
    )

    def generate_reports(self) -> Dict[str, Path]:
        """Generate comprehensive reports in multiple formats.

        Builds a results structure (config, per-metric summary, per-fold
        metrics, plot paths), writes an org-mode report into the
        ``reports`` sub-directory and, when ``pdflatex`` is available,
        tries to compile ``classification_report.tex`` to PDF.

        Returns
        -------
        dict
            Maps ``"org"`` to the org report path and, when a PDF exists
            after compilation, ``"pdf"`` to its path.
        """
        from ..reporter_utils.reporting import generate_org_report

        results = {
            "config": {
                "n_folds": len(self.fold_metrics),
                "output_dir": str(self.output_dir),
            },
            "session_config": self.session_config,
            "summary": {},
            "folds": [],
            "plots": {},
        }

        summary = self.get_summary()

        if "metrics_summary" in summary:
            results["summary"] = summary["metrics_summary"]

        if "feature-importance" in summary:
            results["summary"]["feature-importance"] = summary["feature-importance"]

        for fold, fold_data in self.fold_metrics.items():
            fold_result = {"fold_id": fold}
            fold_result.update(fold_data)
            fold_result.update(self._reports_read_fold_counts(fold))
            results["folds"].append(fold_result)

        results["plots"].update(self._reports_collect_plots(self.output_dir))

        reports_dir = self._create_subdir_if_needed("reports")
        generated_files = {}

        org_path = reports_dir / "classification_report.org"
        generate_org_report(results, org_path, include_plots=True, convert_formats=True)
        generated_files["org"] = org_path
        logger.info(f"Generated org-mode report: {org_path}")

        pdf_path = self._reports_compile_pdf(reports_dir)
        if pdf_path is not None:
            generated_files["pdf"] = pdf_path

        return generated_files

    def _reports_read_fold_counts(self, fold) -> Dict[str, int]:
        """Read the sample counts of ``fold`` from its ``features.json``.

        Two candidate locations are probed, the ``storage_out`` tree next
        to ``reporter_utils`` first, then the regular output directory.
        This step is best-effort: on any error a warning is logged and the
        counts gathered so far are returned.
        """
        counts = {}
        try:
            fold_dir_name = FOLD_DIR_PREFIX_PATTERN.format(fold=fold)
            calling_file_dir = Path(__file__).parent.parent / "reporter_utils"
            # NOTE(review): with an absolute ``output_dir`` pathlib drops the
            # prefix, so both candidates are the same path - confirm intent.
            storage_out_path = (
                calling_file_dir
                / "storage_out"
                / self.output_dir
                / fold_dir_name
                / "features.json"
            )
            regular_path = self.output_dir / fold_dir_name / "features.json"

            features_json = None
            if storage_out_path.exists():
                features_json = storage_out_path
            elif regular_path.exists():
                features_json = regular_path

            if features_json:
                with open(features_json, encoding="utf-8") as f:
                    features_data = json.load(f)
                for key in self._REPORT_SAMPLE_COUNT_KEYS:
                    if key in features_data:
                        counts[key] = int(features_data[key])
        except Exception as e:
            # Previously swallowed silently; keep going but leave a trace.
            logger.warning(f"Could not read sample counts for fold {fold}: {e}")
        return counts

    @staticmethod
    def _reports_collect_plots(output_dir: Path) -> Dict[str, str]:
        """Map plot keys to ``*.jpg`` paths relative to ``output_dir``.

        ``cv_summary`` plots come first (keys ``cv_summary_<stem>``),
        followed by the plots of every ``fold_*`` entry in sorted order
        (keys ``fold_<suffix>_<stem>``).
        """
        output_dir = Path(output_dir)
        plots = {}

        cv_summary_dir = output_dir / "cv_summary"
        if cv_summary_dir.exists():
            for plot_file in cv_summary_dir.glob("*.jpg"):
                plot_key = f"cv_summary_{plot_file.stem}"
                plots[plot_key] = str(plot_file.relative_to(output_dir))

        for fold_dir in sorted(output_dir.glob("fold_*")):
            fold_num = fold_dir.name.replace("fold_", "")
            for plot_file in fold_dir.glob("*.jpg"):
                plot_key = f"fold_{fold_num}_{plot_file.stem}"
                plots[plot_key] = str(plot_file.relative_to(output_dir))

        return plots

    @staticmethod
    def _reports_compile_pdf(reports_dir: Path):
        """Compile ``classification_report.tex`` in ``reports_dir`` to PDF.

        Returns the PDF path when the file exists after compilation,
        otherwise ``None``. Failures are logged as warnings, never raised.
        """
        try:
            import shutil
            import subprocess

            if not shutil.which("pdflatex"):
                logger.warning("pdflatex not found. Skipping PDF generation.")
                return None

            # Run inside ``reports_dir`` via ``cwd=`` instead of os.chdir():
            # the process-wide working directory stays untouched, and the
            # path checks below stay valid for a relative ``reports_dir``.
            for _ in range(2):  # second pass as in the original two-run loop
                subprocess.run(
                    [
                        "pdflatex",
                        "-interaction=nonstopmode",
                        "classification_report.tex",
                    ],
                    cwd=reports_dir,
                    capture_output=True,
                    text=True,
                    timeout=30,
                )

            pdf_path = reports_dir / "classification_report.pdf"
            found_pdf = pdf_path.exists()
            if found_pdf:
                logger.info(f"Generated PDF report: {pdf_path}")

            # Remove LaTeX by-products left next to the report.
            for ext in [".aux", ".log", ".out", ".toc"]:
                aux_file = reports_dir / f"classification_report{ext}"
                if aux_file.exists():
                    aux_file.unlink()

            return pdf_path if found_pdf else None
        except Exception as e:
            logger.warning(f"Could not generate PDF report: {e}")
            return None
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
# EOF
|