scitex 2.0.0__py2.py3-none-any.whl → 2.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +53 -15
- scitex/__main__.py +72 -26
- scitex/__version__.py +1 -1
- scitex/_sh.py +145 -23
- scitex/ai/__init__.py +30 -16
- scitex/ai/_gen_ai/_Anthropic.py +5 -7
- scitex/ai/_gen_ai/_BaseGenAI.py +2 -2
- scitex/ai/_gen_ai/_DeepSeek.py +10 -2
- scitex/ai/_gen_ai/_Google.py +2 -2
- scitex/ai/_gen_ai/_Llama.py +2 -2
- scitex/ai/_gen_ai/_OpenAI.py +2 -2
- scitex/ai/_gen_ai/_PARAMS.py +51 -65
- scitex/ai/_gen_ai/_Perplexity.py +2 -2
- scitex/ai/_gen_ai/__init__.py +25 -14
- scitex/ai/_gen_ai/_format_output_func.py +4 -4
- scitex/ai/classification/{classifier_server.py → Classifier.py} +5 -5
- scitex/ai/classification/CrossValidationExperiment.py +374 -0
- scitex/ai/classification/__init__.py +43 -4
- scitex/ai/classification/reporters/_BaseClassificationReporter.py +281 -0
- scitex/ai/classification/reporters/_ClassificationReporter.py +773 -0
- scitex/ai/classification/reporters/_MultiClassificationReporter.py +406 -0
- scitex/ai/classification/reporters/_SingleClassificationReporter.py +1834 -0
- scitex/ai/classification/reporters/__init__.py +11 -0
- scitex/ai/classification/reporters/reporter_utils/_Plotter.py +1028 -0
- scitex/ai/classification/reporters/reporter_utils/__init__.py +80 -0
- scitex/ai/classification/reporters/reporter_utils/aggregation.py +457 -0
- scitex/ai/classification/reporters/reporter_utils/data_models.py +313 -0
- scitex/ai/classification/reporters/reporter_utils/reporting.py +1056 -0
- scitex/ai/classification/reporters/reporter_utils/storage.py +221 -0
- scitex/ai/classification/reporters/reporter_utils/validation.py +395 -0
- scitex/ai/classification/timeseries/_TimeSeriesBlockingSplit.py +568 -0
- scitex/ai/classification/timeseries/_TimeSeriesCalendarSplit.py +688 -0
- scitex/ai/classification/timeseries/_TimeSeriesMetadata.py +139 -0
- scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit.py +1716 -0
- scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit_v01-not-using-n_splits.py +1685 -0
- scitex/ai/classification/timeseries/_TimeSeriesStrategy.py +84 -0
- scitex/ai/classification/timeseries/_TimeSeriesStratifiedSplit.py +610 -0
- scitex/ai/classification/timeseries/__init__.py +39 -0
- scitex/ai/classification/timeseries/_normalize_timestamp.py +436 -0
- scitex/ai/clustering/_umap.py +2 -2
- scitex/ai/feature_extraction/vit.py +1 -0
- scitex/ai/feature_selection/__init__.py +30 -0
- scitex/ai/feature_selection/feature_selection.py +364 -0
- scitex/ai/loss/multi_task_loss.py +1 -1
- scitex/ai/metrics/__init__.py +51 -4
- scitex/ai/metrics/_calc_bacc.py +61 -0
- scitex/ai/metrics/_calc_bacc_from_conf_mat.py +38 -0
- scitex/ai/metrics/_calc_clf_report.py +78 -0
- scitex/ai/metrics/_calc_conf_mat.py +93 -0
- scitex/ai/metrics/_calc_feature_importance.py +183 -0
- scitex/ai/metrics/_calc_mcc.py +61 -0
- scitex/ai/metrics/_calc_pre_rec_auc.py +116 -0
- scitex/ai/metrics/_calc_roc_auc.py +110 -0
- scitex/ai/metrics/_calc_seizure_prediction_metrics.py +490 -0
- scitex/ai/metrics/{silhoute_score_block.py → _calc_silhouette_score.py} +15 -8
- scitex/ai/metrics/_normalize_labels.py +83 -0
- scitex/ai/plt/__init__.py +47 -8
- scitex/ai/plt/{_conf_mat.py → _plot_conf_mat.py} +158 -87
- scitex/ai/plt/_plot_feature_importance.py +323 -0
- scitex/ai/plt/_plot_learning_curve.py +345 -0
- scitex/ai/plt/_plot_optuna_study.py +225 -0
- scitex/ai/plt/_plot_pre_rec_curve.py +290 -0
- scitex/ai/plt/_plot_roc_curve.py +255 -0
- scitex/ai/training/{learning_curve_logger.py → _LearningCurveLogger.py} +197 -213
- scitex/ai/training/__init__.py +2 -2
- scitex/ai/utils/grid_search.py +3 -3
- scitex/benchmark/__init__.py +52 -0
- scitex/benchmark/benchmark.py +400 -0
- scitex/benchmark/monitor.py +370 -0
- scitex/benchmark/profiler.py +297 -0
- scitex/browser/__init__.py +48 -0
- scitex/browser/automation/CookieHandler.py +216 -0
- scitex/browser/automation/__init__.py +7 -0
- scitex/browser/collaboration/__init__.py +55 -0
- scitex/browser/collaboration/auth_helpers.py +94 -0
- scitex/browser/collaboration/collaborative_agent.py +136 -0
- scitex/browser/collaboration/credential_manager.py +188 -0
- scitex/browser/collaboration/interactive_panel.py +400 -0
- scitex/browser/collaboration/persistent_browser.py +170 -0
- scitex/browser/collaboration/shared_session.py +383 -0
- scitex/browser/collaboration/standard_interactions.py +246 -0
- scitex/browser/collaboration/visual_feedback.py +181 -0
- scitex/browser/core/BrowserMixin.py +326 -0
- scitex/browser/core/ChromeProfileManager.py +446 -0
- scitex/browser/core/__init__.py +9 -0
- scitex/browser/debugging/__init__.py +18 -0
- scitex/browser/debugging/_browser_logger.py +657 -0
- scitex/browser/debugging/_highlight_element.py +143 -0
- scitex/browser/debugging/_show_grid.py +154 -0
- scitex/browser/interaction/__init__.py +24 -0
- scitex/browser/interaction/click_center.py +149 -0
- scitex/browser/interaction/click_with_fallbacks.py +206 -0
- scitex/browser/interaction/close_popups.py +498 -0
- scitex/browser/interaction/fill_with_fallbacks.py +209 -0
- scitex/browser/pdf/__init__.py +14 -0
- scitex/browser/pdf/click_download_for_chrome_pdf_viewer.py +200 -0
- scitex/browser/pdf/detect_chrome_pdf_viewer.py +198 -0
- scitex/browser/remote/CaptchaHandler.py +434 -0
- scitex/browser/remote/ZenRowsAPIClient.py +347 -0
- scitex/browser/remote/ZenRowsBrowserManager.py +570 -0
- scitex/browser/remote/__init__.py +11 -0
- scitex/browser/stealth/HumanBehavior.py +344 -0
- scitex/browser/stealth/StealthManager.py +1008 -0
- scitex/browser/stealth/__init__.py +9 -0
- scitex/browser/template.py +122 -0
- scitex/capture/__init__.py +110 -0
- scitex/capture/__main__.py +25 -0
- scitex/capture/capture.py +848 -0
- scitex/capture/cli.py +233 -0
- scitex/capture/gif.py +344 -0
- scitex/capture/mcp_server.py +961 -0
- scitex/capture/session.py +70 -0
- scitex/capture/utils.py +705 -0
- scitex/cli/__init__.py +17 -0
- scitex/cli/cloud.py +447 -0
- scitex/cli/main.py +42 -0
- scitex/cli/scholar.py +280 -0
- scitex/context/_suppress_output.py +5 -3
- scitex/db/__init__.py +30 -3
- scitex/db/__main__.py +75 -0
- scitex/db/_check_health.py +381 -0
- scitex/db/_delete_duplicates.py +25 -386
- scitex/db/_inspect.py +335 -114
- scitex/db/_inspect_optimized.py +301 -0
- scitex/db/{_PostgreSQL.py → _postgresql/_PostgreSQL.py} +3 -3
- scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_BackupMixin.py +1 -1
- scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_BatchMixin.py +1 -1
- scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_BlobMixin.py +1 -1
- scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_ConnectionMixin.py +1 -1
- scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_MaintenanceMixin.py +1 -1
- scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_QueryMixin.py +1 -1
- scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_SchemaMixin.py +1 -1
- scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_TransactionMixin.py +1 -1
- scitex/db/_postgresql/__init__.py +6 -0
- scitex/db/_sqlite3/_SQLite3.py +210 -0
- scitex/db/_sqlite3/_SQLite3Mixins/_ArrayMixin.py +581 -0
- scitex/db/_sqlite3/_SQLite3Mixins/_ArrayMixin_v01-need-_hash-col.py +517 -0
- scitex/db/{_SQLite3Mixins → _sqlite3/_SQLite3Mixins}/_BatchMixin.py +1 -1
- scitex/db/_sqlite3/_SQLite3Mixins/_BlobMixin.py +281 -0
- scitex/db/_sqlite3/_SQLite3Mixins/_ColumnMixin.py +548 -0
- scitex/db/_sqlite3/_SQLite3Mixins/_ColumnMixin_v01-indentation-issues.py +583 -0
- scitex/db/{_SQLite3Mixins → _sqlite3/_SQLite3Mixins}/_ConnectionMixin.py +29 -13
- scitex/db/_sqlite3/_SQLite3Mixins/_GitMixin.py +583 -0
- scitex/db/{_SQLite3Mixins → _sqlite3/_SQLite3Mixins}/_ImportExportMixin.py +1 -1
- scitex/db/{_SQLite3Mixins → _sqlite3/_SQLite3Mixins}/_IndexMixin.py +1 -1
- scitex/db/{_SQLite3Mixins → _sqlite3/_SQLite3Mixins}/_MaintenanceMixin.py +2 -1
- scitex/db/{_SQLite3Mixins → _sqlite3/_SQLite3Mixins}/_QueryMixin.py +37 -10
- scitex/db/{_SQLite3Mixins → _sqlite3/_SQLite3Mixins}/_RowMixin.py +46 -6
- scitex/db/{_SQLite3Mixins → _sqlite3/_SQLite3Mixins}/_TableMixin.py +56 -10
- scitex/db/{_SQLite3Mixins → _sqlite3/_SQLite3Mixins}/_TransactionMixin.py +1 -1
- scitex/db/{_SQLite3Mixins → _sqlite3/_SQLite3Mixins}/__init__.py +14 -2
- scitex/db/_sqlite3/__init__.py +7 -0
- scitex/db/_sqlite3/_delete_duplicates.py +274 -0
- scitex/decorators/__init__.py +2 -0
- scitex/decorators/_cache_disk.py +13 -5
- scitex/decorators/_cache_disk_async.py +49 -0
- scitex/decorators/_deprecated.py +175 -10
- scitex/decorators/_timeout.py +1 -1
- scitex/dev/_analyze_code_flow.py +2 -2
- scitex/dict/_DotDict.py +73 -15
- scitex/dict/_DotDict_v01-not-handling-recursive-instantiations.py +442 -0
- scitex/dict/_DotDict_v02-not-serializing-Path-object.py +446 -0
- scitex/dict/__init__.py +2 -0
- scitex/dict/_flatten.py +27 -0
- scitex/dsp/_crop.py +2 -2
- scitex/dsp/_demo_sig.py +2 -2
- scitex/dsp/_detect_ripples.py +2 -2
- scitex/dsp/_hilbert.py +2 -2
- scitex/dsp/_listen.py +6 -6
- scitex/dsp/_modulation_index.py +2 -2
- scitex/dsp/_pac.py +1 -1
- scitex/dsp/_psd.py +2 -2
- scitex/dsp/_resample.py +2 -1
- scitex/dsp/_time.py +3 -2
- scitex/dsp/_wavelet.py +3 -2
- scitex/dsp/add_noise.py +2 -2
- scitex/dsp/example.py +1 -0
- scitex/dsp/filt.py +10 -9
- scitex/dsp/template.py +3 -2
- scitex/dsp/utils/_differential_bandpass_filters.py +1 -1
- scitex/dsp/utils/pac.py +2 -2
- scitex/dt/_normalize_timestamp.py +432 -0
- scitex/errors.py +572 -0
- scitex/gen/_DimHandler.py +2 -2
- scitex/gen/__init__.py +37 -7
- scitex/gen/_deprecated_close.py +80 -0
- scitex/gen/_deprecated_start.py +26 -0
- scitex/gen/_detect_environment.py +152 -0
- scitex/gen/_detect_notebook_path.py +169 -0
- scitex/gen/_embed.py +6 -2
- scitex/gen/_get_notebook_path.py +257 -0
- scitex/gen/_less.py +1 -1
- scitex/gen/_list_packages.py +2 -2
- scitex/gen/_norm.py +44 -9
- scitex/gen/_norm_cache.py +269 -0
- scitex/gen/_src.py +3 -5
- scitex/gen/_title_case.py +3 -3
- scitex/io/__init__.py +28 -6
- scitex/io/_glob.py +13 -7
- scitex/io/_load.py +108 -21
- scitex/io/_load_cache.py +303 -0
- scitex/io/_load_configs.py +40 -15
- scitex/io/{_H5Explorer.py → _load_modules/_H5Explorer.py} +80 -17
- scitex/io/_load_modules/_ZarrExplorer.py +114 -0
- scitex/io/_load_modules/_bibtex.py +207 -0
- scitex/io/_load_modules/_hdf5.py +53 -178
- scitex/io/_load_modules/_json.py +5 -3
- scitex/io/_load_modules/_pdf.py +871 -16
- scitex/io/_load_modules/_sqlite3.py +15 -0
- scitex/io/_load_modules/_txt.py +41 -12
- scitex/io/_load_modules/_yaml.py +4 -3
- scitex/io/_load_modules/_zarr.py +126 -0
- scitex/io/_save.py +429 -171
- scitex/io/_save_modules/__init__.py +6 -0
- scitex/io/_save_modules/_bibtex.py +194 -0
- scitex/io/_save_modules/_csv.py +8 -4
- scitex/io/_save_modules/_excel.py +174 -15
- scitex/io/_save_modules/_hdf5.py +251 -226
- scitex/io/_save_modules/_image.py +1 -3
- scitex/io/_save_modules/_json.py +49 -4
- scitex/io/_save_modules/_listed_dfs_as_csv.py +1 -3
- scitex/io/_save_modules/_listed_scalars_as_csv.py +1 -3
- scitex/io/_save_modules/_tex.py +277 -0
- scitex/io/_save_modules/_yaml.py +42 -3
- scitex/io/_save_modules/_zarr.py +160 -0
- scitex/io/utils/__init__.py +20 -0
- scitex/io/utils/h5_to_zarr.py +616 -0
- scitex/linalg/_geometric_median.py +6 -2
- scitex/{gen/_tee.py → logging/_Tee.py} +43 -84
- scitex/logging/__init__.py +122 -0
- scitex/logging/_config.py +158 -0
- scitex/logging/_context.py +103 -0
- scitex/logging/_formatters.py +128 -0
- scitex/logging/_handlers.py +64 -0
- scitex/logging/_levels.py +35 -0
- scitex/logging/_logger.py +163 -0
- scitex/logging/_print_capture.py +95 -0
- scitex/ml/__init__.py +69 -0
- scitex/{ai/genai/anthropic.py → ml/_gen_ai/_Anthropic.py} +13 -19
- scitex/{ai/genai/base_genai.py → ml/_gen_ai/_BaseGenAI.py} +5 -5
- scitex/{ai/genai/deepseek.py → ml/_gen_ai/_DeepSeek.py} +11 -16
- scitex/{ai/genai/google.py → ml/_gen_ai/_Google.py} +7 -15
- scitex/{ai/genai/groq.py → ml/_gen_ai/_Groq.py} +1 -8
- scitex/{ai/genai/llama.py → ml/_gen_ai/_Llama.py} +3 -16
- scitex/{ai/genai/openai.py → ml/_gen_ai/_OpenAI.py} +3 -3
- scitex/{ai/genai/params.py → ml/_gen_ai/_PARAMS.py} +51 -65
- scitex/{ai/genai/perplexity.py → ml/_gen_ai/_Perplexity.py} +3 -14
- scitex/ml/_gen_ai/__init__.py +43 -0
- scitex/{ai/genai/calc_cost.py → ml/_gen_ai/_calc_cost.py} +1 -1
- scitex/{ai/genai/format_output_func.py → ml/_gen_ai/_format_output_func.py} +4 -4
- scitex/{ai/genai/genai_factory.py → ml/_gen_ai/_genai_factory.py} +8 -8
- scitex/ml/activation/__init__.py +8 -0
- scitex/ml/activation/_define.py +11 -0
- scitex/{ai/classifier_server.py → ml/classification/Classifier.py} +5 -5
- scitex/ml/classification/CrossValidationExperiment.py +374 -0
- scitex/ml/classification/__init__.py +46 -0
- scitex/ml/classification/reporters/_BaseClassificationReporter.py +281 -0
- scitex/ml/classification/reporters/_ClassificationReporter.py +773 -0
- scitex/ml/classification/reporters/_MultiClassificationReporter.py +406 -0
- scitex/ml/classification/reporters/_SingleClassificationReporter.py +1834 -0
- scitex/ml/classification/reporters/__init__.py +11 -0
- scitex/ml/classification/reporters/reporter_utils/_Plotter.py +1028 -0
- scitex/ml/classification/reporters/reporter_utils/__init__.py +80 -0
- scitex/ml/classification/reporters/reporter_utils/aggregation.py +457 -0
- scitex/ml/classification/reporters/reporter_utils/data_models.py +313 -0
- scitex/ml/classification/reporters/reporter_utils/reporting.py +1056 -0
- scitex/ml/classification/reporters/reporter_utils/storage.py +221 -0
- scitex/ml/classification/reporters/reporter_utils/validation.py +395 -0
- scitex/ml/classification/timeseries/_TimeSeriesBlockingSplit.py +568 -0
- scitex/ml/classification/timeseries/_TimeSeriesCalendarSplit.py +688 -0
- scitex/ml/classification/timeseries/_TimeSeriesMetadata.py +139 -0
- scitex/ml/classification/timeseries/_TimeSeriesSlidingWindowSplit.py +1716 -0
- scitex/ml/classification/timeseries/_TimeSeriesSlidingWindowSplit_v01-not-using-n_splits.py +1685 -0
- scitex/ml/classification/timeseries/_TimeSeriesStrategy.py +84 -0
- scitex/ml/classification/timeseries/_TimeSeriesStratifiedSplit.py +610 -0
- scitex/ml/classification/timeseries/__init__.py +39 -0
- scitex/ml/classification/timeseries/_normalize_timestamp.py +436 -0
- scitex/ml/clustering/__init__.py +11 -0
- scitex/ml/clustering/_pca.py +115 -0
- scitex/ml/clustering/_umap.py +376 -0
- scitex/ml/feature_extraction/__init__.py +56 -0
- scitex/ml/feature_extraction/vit.py +149 -0
- scitex/ml/feature_selection/__init__.py +30 -0
- scitex/ml/feature_selection/feature_selection.py +364 -0
- scitex/ml/loss/_L1L2Losses.py +34 -0
- scitex/ml/loss/__init__.py +12 -0
- scitex/ml/loss/multi_task_loss.py +47 -0
- scitex/ml/metrics/__init__.py +56 -0
- scitex/ml/metrics/_calc_bacc.py +61 -0
- scitex/ml/metrics/_calc_bacc_from_conf_mat.py +38 -0
- scitex/ml/metrics/_calc_clf_report.py +78 -0
- scitex/ml/metrics/_calc_conf_mat.py +93 -0
- scitex/ml/metrics/_calc_feature_importance.py +183 -0
- scitex/ml/metrics/_calc_mcc.py +61 -0
- scitex/ml/metrics/_calc_pre_rec_auc.py +116 -0
- scitex/ml/metrics/_calc_roc_auc.py +110 -0
- scitex/ml/metrics/_calc_seizure_prediction_metrics.py +490 -0
- scitex/ml/metrics/_calc_silhouette_score.py +503 -0
- scitex/ml/metrics/_normalize_labels.py +83 -0
- scitex/ml/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
- scitex/ml/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
- scitex/ml/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
- scitex/ml/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
- scitex/ml/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
- scitex/ml/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
- scitex/ml/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
- scitex/ml/optim/__init__.py +13 -0
- scitex/ml/optim/_get_set.py +31 -0
- scitex/ml/optim/_optimizers.py +71 -0
- scitex/ml/plt/__init__.py +60 -0
- scitex/ml/plt/_plot_conf_mat.py +663 -0
- scitex/ml/plt/_plot_feature_importance.py +323 -0
- scitex/ml/plt/_plot_learning_curve.py +345 -0
- scitex/ml/plt/_plot_optuna_study.py +225 -0
- scitex/ml/plt/_plot_pre_rec_curve.py +290 -0
- scitex/ml/plt/_plot_roc_curve.py +255 -0
- scitex/ml/sk/__init__.py +11 -0
- scitex/ml/sk/_clf.py +58 -0
- scitex/ml/sk/_to_sktime.py +100 -0
- scitex/ml/sklearn/__init__.py +26 -0
- scitex/ml/sklearn/clf.py +58 -0
- scitex/ml/sklearn/to_sktime.py +100 -0
- scitex/{ai/training/early_stopping.py → ml/training/_EarlyStopping.py} +1 -2
- scitex/{ai → ml/training}/_LearningCurveLogger.py +198 -242
- scitex/ml/training/__init__.py +7 -0
- scitex/ml/utils/__init__.py +22 -0
- scitex/ml/utils/_check_params.py +50 -0
- scitex/ml/utils/_default_dataset.py +46 -0
- scitex/ml/utils/_format_samples_for_sktime.py +26 -0
- scitex/ml/utils/_label_encoder.py +134 -0
- scitex/ml/utils/_merge_labels.py +22 -0
- scitex/ml/utils/_sliding_window_data_augmentation.py +11 -0
- scitex/ml/utils/_under_sample.py +51 -0
- scitex/ml/utils/_verify_n_gpus.py +16 -0
- scitex/ml/utils/grid_search.py +148 -0
- scitex/nn/_BNet.py +15 -9
- scitex/nn/_Filters.py +2 -2
- scitex/nn/_ModulationIndex.py +2 -2
- scitex/nn/_PAC.py +1 -1
- scitex/nn/_Spectrogram.py +12 -3
- scitex/nn/__init__.py +9 -10
- scitex/path/__init__.py +18 -0
- scitex/path/_clean.py +4 -0
- scitex/path/_find.py +9 -4
- scitex/path/_symlink.py +348 -0
- scitex/path/_version.py +4 -3
- scitex/pd/__init__.py +2 -0
- scitex/pd/_get_unique.py +99 -0
- scitex/plt/__init__.py +114 -5
- scitex/plt/_subplots/_AxesWrapper.py +1 -3
- scitex/plt/_subplots/_AxisWrapper.py +7 -3
- scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +47 -13
- scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +160 -2
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +26 -4
- scitex/plt/_subplots/_AxisWrapperMixins/_UnitAwareMixin.py +322 -0
- scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +1 -0
- scitex/plt/_subplots/_FigWrapper.py +62 -6
- scitex/plt/_subplots/_export_as_csv.py +43 -27
- scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +5 -4
- scitex/plt/_subplots/_export_as_csv_formatters/_format_annotate.py +81 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +20 -5
- scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +35 -18
- scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +15 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter.py +35 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +6 -4
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_text.py +60 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +1 -3
- scitex/plt/_subplots/_export_as_csv_formatters.py +56 -59
- scitex/plt/ax/_style/_hide_spines.py +1 -3
- scitex/plt/ax/_style/_rotate_labels.py +180 -76
- scitex/plt/ax/_style/_rotate_labels_v01.py +248 -0
- scitex/plt/ax/_style/_set_meta.py +11 -4
- scitex/plt/ax/_style/_set_supxyt.py +3 -3
- scitex/plt/ax/_style/_set_xyt.py +3 -3
- scitex/plt/ax/_style/_share_axes.py +2 -2
- scitex/plt/color/__init__.py +4 -4
- scitex/plt/color/{_get_colors_from_cmap.py → _get_colors_from_conf_matap.py} +7 -7
- scitex/plt/utils/_configure_mpl.py +99 -86
- scitex/plt/utils/_histogram_utils.py +1 -3
- scitex/plt/utils/_is_valid_axis.py +1 -3
- scitex/plt/utils/_scitex_config.py +1 -0
- scitex/repro/__init__.py +75 -0
- scitex/{reproduce → repro}/_gen_ID.py +1 -1
- scitex/{reproduce → repro}/_gen_timestamp.py +1 -1
- scitex/repro_rng/_RandomStateManager.py +590 -0
- scitex/repro_rng/_RandomStateManager_v01-no-verbose-options.py +414 -0
- scitex/repro_rng/__init__.py +39 -0
- scitex/reproduce/__init__.py +25 -13
- scitex/reproduce/_hash_array.py +22 -0
- scitex/resource/_get_processor_usages.py +4 -4
- scitex/resource/_get_specs.py +2 -2
- scitex/resource/_log_processor_usages.py +2 -2
- scitex/rng/_RandomStateManager.py +590 -0
- scitex/rng/_RandomStateManager_v01-no-verbose-options.py +414 -0
- scitex/rng/__init__.py +39 -0
- scitex/scholar/__init__.py +309 -19
- scitex/scholar/__main__.py +319 -0
- scitex/scholar/auth/ScholarAuthManager.py +308 -0
- scitex/scholar/auth/__init__.py +12 -0
- scitex/scholar/auth/core/AuthenticationGateway.py +473 -0
- scitex/scholar/auth/core/BrowserAuthenticator.py +386 -0
- scitex/scholar/auth/core/StrategyResolver.py +309 -0
- scitex/scholar/auth/core/__init__.py +16 -0
- scitex/scholar/auth/gateway/_OpenURLLinkFinder.py +120 -0
- scitex/scholar/auth/gateway/_OpenURLResolver.py +209 -0
- scitex/scholar/auth/gateway/__init__.py +38 -0
- scitex/scholar/auth/gateway/_resolve_functions.py +101 -0
- scitex/scholar/auth/providers/BaseAuthenticator.py +166 -0
- scitex/scholar/auth/providers/EZProxyAuthenticator.py +484 -0
- scitex/scholar/auth/providers/OpenAthensAuthenticator.py +619 -0
- scitex/scholar/auth/providers/ShibbolethAuthenticator.py +686 -0
- scitex/scholar/auth/providers/__init__.py +18 -0
- scitex/scholar/auth/session/AuthCacheManager.py +189 -0
- scitex/scholar/auth/session/SessionManager.py +159 -0
- scitex/scholar/auth/session/__init__.py +11 -0
- scitex/scholar/auth/sso/BaseSSOAutomator.py +373 -0
- scitex/scholar/auth/sso/OpenAthensSSOAutomator.py +378 -0
- scitex/scholar/auth/sso/SSOAutomator.py +180 -0
- scitex/scholar/auth/sso/UniversityOfMelbourneSSOAutomator.py +380 -0
- scitex/scholar/auth/sso/__init__.py +15 -0
- scitex/scholar/browser/ScholarBrowserManager.py +705 -0
- scitex/scholar/browser/__init__.py +38 -0
- scitex/scholar/browser/utils/__init__.py +13 -0
- scitex/scholar/browser/utils/click_and_wait.py +205 -0
- scitex/scholar/browser/utils/close_unwanted_pages.py +140 -0
- scitex/scholar/browser/utils/wait_redirects.py +732 -0
- scitex/scholar/config/PublisherRules.py +132 -0
- scitex/scholar/config/ScholarConfig.py +126 -0
- scitex/scholar/config/__init__.py +17 -0
- scitex/scholar/core/Paper.py +627 -0
- scitex/scholar/core/Papers.py +722 -0
- scitex/scholar/core/Scholar.py +1975 -0
- scitex/scholar/core/__init__.py +9 -0
- scitex/scholar/impact_factor/ImpactFactorEngine.py +204 -0
- scitex/scholar/impact_factor/__init__.py +20 -0
- scitex/scholar/impact_factor/estimation/ImpactFactorEstimationEngine.py +0 -0
- scitex/scholar/impact_factor/estimation/__init__.py +40 -0
- scitex/scholar/impact_factor/estimation/build_database.py +0 -0
- scitex/scholar/impact_factor/estimation/core/__init__.py +28 -0
- scitex/scholar/impact_factor/estimation/core/cache_manager.py +523 -0
- scitex/scholar/impact_factor/estimation/core/calculator.py +355 -0
- scitex/scholar/impact_factor/estimation/core/journal_matcher.py +428 -0
- scitex/scholar/integration/__init__.py +59 -0
- scitex/scholar/integration/base.py +502 -0
- scitex/scholar/integration/mendeley/__init__.py +22 -0
- scitex/scholar/integration/mendeley/exporter.py +166 -0
- scitex/scholar/integration/mendeley/importer.py +236 -0
- scitex/scholar/integration/mendeley/linker.py +79 -0
- scitex/scholar/integration/mendeley/mapper.py +212 -0
- scitex/scholar/integration/zotero/__init__.py +27 -0
- scitex/scholar/integration/zotero/__main__.py +264 -0
- scitex/scholar/integration/zotero/exporter.py +351 -0
- scitex/scholar/integration/zotero/importer.py +372 -0
- scitex/scholar/integration/zotero/linker.py +415 -0
- scitex/scholar/integration/zotero/mapper.py +286 -0
- scitex/scholar/metadata_engines/ScholarEngine.py +588 -0
- scitex/scholar/metadata_engines/__init__.py +21 -0
- scitex/scholar/metadata_engines/individual/ArXivEngine.py +397 -0
- scitex/scholar/metadata_engines/individual/CrossRefEngine.py +274 -0
- scitex/scholar/metadata_engines/individual/CrossRefLocalEngine.py +263 -0
- scitex/scholar/metadata_engines/individual/OpenAlexEngine.py +350 -0
- scitex/scholar/metadata_engines/individual/PubMedEngine.py +329 -0
- scitex/scholar/metadata_engines/individual/SemanticScholarEngine.py +438 -0
- scitex/scholar/metadata_engines/individual/URLDOIEngine.py +410 -0
- scitex/scholar/metadata_engines/individual/_BaseDOIEngine.py +487 -0
- scitex/scholar/metadata_engines/individual/__init__.py +7 -0
- scitex/scholar/metadata_engines/utils/_PubMedConverter.py +469 -0
- scitex/scholar/metadata_engines/utils/_URLDOIExtractor.py +283 -0
- scitex/scholar/metadata_engines/utils/__init__.py +30 -0
- scitex/scholar/metadata_engines/utils/_metadata2bibtex.py +103 -0
- scitex/scholar/metadata_engines/utils/_standardize_metadata.py +376 -0
- scitex/scholar/pdf_download/ScholarPDFDownloader.py +579 -0
- scitex/scholar/pdf_download/__init__.py +5 -0
- scitex/scholar/pdf_download/strategies/__init__.py +38 -0
- scitex/scholar/pdf_download/strategies/chrome_pdf_viewer.py +376 -0
- scitex/scholar/pdf_download/strategies/direct_download.py +131 -0
- scitex/scholar/pdf_download/strategies/manual_download_fallback.py +167 -0
- scitex/scholar/pdf_download/strategies/manual_download_utils.py +996 -0
- scitex/scholar/pdf_download/strategies/response_body.py +207 -0
- scitex/scholar/pipelines/ScholarPipelineBibTeX.py +364 -0
- scitex/scholar/pipelines/ScholarPipelineParallel.py +478 -0
- scitex/scholar/pipelines/ScholarPipelineSingle.py +767 -0
- scitex/scholar/pipelines/__init__.py +49 -0
- scitex/scholar/storage/BibTeXHandler.py +1018 -0
- scitex/scholar/storage/PaperIO.py +468 -0
- scitex/scholar/storage/ScholarLibrary.py +182 -0
- scitex/scholar/storage/_DeduplicationManager.py +548 -0
- scitex/scholar/storage/_LibraryCacheManager.py +724 -0
- scitex/scholar/storage/_LibraryManager.py +1835 -0
- scitex/scholar/storage/__init__.py +28 -0
- scitex/scholar/url_finder/ScholarURLFinder.py +379 -0
- scitex/scholar/url_finder/__init__.py +7 -0
- scitex/scholar/url_finder/strategies/__init__.py +33 -0
- scitex/scholar/url_finder/strategies/find_pdf_urls_by_direct_links.py +261 -0
- scitex/scholar/url_finder/strategies/find_pdf_urls_by_dropdown.py +67 -0
- scitex/scholar/url_finder/strategies/find_pdf_urls_by_href.py +204 -0
- scitex/scholar/url_finder/strategies/find_pdf_urls_by_navigation.py +256 -0
- scitex/scholar/url_finder/strategies/find_pdf_urls_by_publisher_patterns.py +165 -0
- scitex/scholar/url_finder/strategies/find_pdf_urls_by_zotero_translators.py +163 -0
- scitex/scholar/url_finder/strategies/find_supplementary_urls_by_href.py +70 -0
- scitex/scholar/utils/__init__.py +22 -0
- scitex/scholar/utils/bibtex/__init__.py +9 -0
- scitex/scholar/utils/bibtex/_parse_bibtex.py +71 -0
- scitex/scholar/utils/cleanup/__init__.py +8 -0
- scitex/scholar/utils/cleanup/_cleanup_scholar_processes.py +96 -0
- scitex/scholar/utils/cleanup/cleanup_old_extractions.py +117 -0
- scitex/scholar/utils/text/_TextNormalizer.py +407 -0
- scitex/scholar/utils/text/__init__.py +9 -0
- scitex/scholar/zotero/__init__.py +38 -0
- scitex/session/__init__.py +51 -0
- scitex/session/_lifecycle.py +736 -0
- scitex/session/_manager.py +102 -0
- scitex/session/template.py +122 -0
- scitex/stats/__init__.py +30 -26
- scitex/stats/correct/__init__.py +21 -0
- scitex/stats/correct/_correct_bonferroni.py +551 -0
- scitex/stats/correct/_correct_fdr.py +634 -0
- scitex/stats/correct/_correct_holm.py +548 -0
- scitex/stats/correct/_correct_sidak.py +499 -0
- scitex/stats/descriptive/__init__.py +85 -0
- scitex/stats/descriptive/_circular.py +540 -0
- scitex/stats/descriptive/_describe.py +219 -0
- scitex/stats/descriptive/_nan.py +518 -0
- scitex/stats/descriptive/_real.py +189 -0
- scitex/stats/effect_sizes/__init__.py +41 -0
- scitex/stats/effect_sizes/_cliffs_delta.py +325 -0
- scitex/stats/effect_sizes/_cohens_d.py +342 -0
- scitex/stats/effect_sizes/_epsilon_squared.py +315 -0
- scitex/stats/effect_sizes/_eta_squared.py +302 -0
- scitex/stats/effect_sizes/_prob_superiority.py +296 -0
- scitex/stats/posthoc/__init__.py +19 -0
- scitex/stats/posthoc/_dunnett.py +463 -0
- scitex/stats/posthoc/_games_howell.py +383 -0
- scitex/stats/posthoc/_tukey_hsd.py +367 -0
- scitex/stats/power/__init__.py +19 -0
- scitex/stats/power/_power.py +433 -0
- scitex/stats/template.py +119 -0
- scitex/stats/utils/__init__.py +62 -0
- scitex/stats/utils/_effect_size.py +985 -0
- scitex/stats/utils/_formatters.py +270 -0
- scitex/stats/utils/_normalizers.py +927 -0
- scitex/stats/utils/_power.py +433 -0
- scitex/stats_v01/_EffectSizeCalculator.py +488 -0
- scitex/stats_v01/_StatisticalValidator.py +411 -0
- scitex/stats_v01/__init__.py +60 -0
- scitex/stats_v01/_additional_tests.py +415 -0
- scitex/{stats → stats_v01}/_p2stars.py +19 -5
- scitex/stats_v01/_two_sample_tests.py +141 -0
- scitex/stats_v01/desc/__init__.py +83 -0
- scitex/stats_v01/desc/_circular.py +540 -0
- scitex/stats_v01/desc/_describe.py +219 -0
- scitex/stats_v01/desc/_nan.py +518 -0
- scitex/{stats/desc/_nan.py → stats_v01/desc/_nan_v01-20250920_145731.py} +23 -12
- scitex/stats_v01/desc/_real.py +189 -0
- scitex/stats_v01/tests/__corr_test_optimized.py +221 -0
- scitex/stats_v01/tests/_corr_test_optimized.py +179 -0
- scitex/str/__init__.py +1 -3
- scitex/str/_clean_path.py +6 -2
- scitex/str/_latex_fallback.py +267 -160
- scitex/str/_parse.py +44 -36
- scitex/str/_printc.py +1 -3
- scitex/template/__init__.py +87 -0
- scitex/template/_create_project.py +267 -0
- scitex/template/create_pip_project.py +80 -0
- scitex/template/create_research.py +80 -0
- scitex/template/create_singularity.py +80 -0
- scitex/units.py +291 -0
- scitex/utils/_compress_hdf5.py +14 -3
- scitex/utils/_email.py +21 -2
- scitex/utils/_grid.py +6 -4
- scitex/utils/_notify.py +13 -10
- scitex/utils/_verify_scitex_format.py +589 -0
- scitex/utils/_verify_scitex_format_v01.py +370 -0
- scitex/utils/template.py +122 -0
- scitex/web/_search_pubmed.py +62 -16
- scitex-2.1.0.dist-info/LICENSE +21 -0
- scitex-2.1.0.dist-info/METADATA +677 -0
- scitex-2.1.0.dist-info/RECORD +919 -0
- {scitex-2.0.0.dist-info → scitex-2.1.0.dist-info}/WHEEL +1 -1
- scitex-2.1.0.dist-info/entry_points.txt +3 -0
- scitex/ai/__Classifiers.py +0 -101
- scitex/ai/classification/classification_reporter.py +0 -1137
- scitex/ai/classification/classifiers.py +0 -101
- scitex/ai/classification_reporter.py +0 -1161
- scitex/ai/genai/__init__.py +0 -277
- scitex/ai/genai/anthropic_provider.py +0 -320
- scitex/ai/genai/anthropic_refactored.py +0 -109
- scitex/ai/genai/auth_manager.py +0 -200
- scitex/ai/genai/base_provider.py +0 -291
- scitex/ai/genai/chat_history.py +0 -307
- scitex/ai/genai/cost_tracker.py +0 -276
- scitex/ai/genai/deepseek_provider.py +0 -251
- scitex/ai/genai/google_provider.py +0 -228
- scitex/ai/genai/groq_provider.py +0 -248
- scitex/ai/genai/image_processor.py +0 -250
- scitex/ai/genai/llama_provider.py +0 -214
- scitex/ai/genai/mock_provider.py +0 -127
- scitex/ai/genai/model_registry.py +0 -304
- scitex/ai/genai/openai_provider.py +0 -293
- scitex/ai/genai/perplexity_provider.py +0 -205
- scitex/ai/genai/provider_base.py +0 -302
- scitex/ai/genai/provider_factory.py +0 -370
- scitex/ai/genai/response_handler.py +0 -235
- scitex/ai/layer/_Pass.py +0 -21
- scitex/ai/layer/__init__.py +0 -10
- scitex/ai/layer/_switch.py +0 -8
- scitex/ai/metrics/_bACC.py +0 -51
- scitex/ai/plt/_learning_curve.py +0 -194
- scitex/ai/plt/_optuna_study.py +0 -111
- scitex/ai/plt/aucs/__init__.py +0 -2
- scitex/ai/plt/aucs/example.py +0 -60
- scitex/ai/plt/aucs/pre_rec_auc.py +0 -223
- scitex/ai/plt/aucs/roc_auc.py +0 -246
- scitex/ai/sampling/undersample.py +0 -29
- scitex/db/_SQLite3.py +0 -2136
- scitex/db/_SQLite3Mixins/_BlobMixin.py +0 -229
- scitex/gen/_close.py +0 -222
- scitex/gen/_start.py +0 -451
- scitex/general/__init__.py +0 -5
- scitex/io/_load_modules/_db.py +0 -24
- scitex/life/__init__.py +0 -10
- scitex/life/_monitor_rain.py +0 -49
- scitex/reproduce/_fix_seeds.py +0 -45
- scitex/res/__init__.py +0 -5
- scitex/scholar/_local_search.py +0 -454
- scitex/scholar/_paper.py +0 -244
- scitex/scholar/_pdf_downloader.py +0 -325
- scitex/scholar/_search.py +0 -393
- scitex/scholar/_vector_search.py +0 -370
- scitex/scholar/_web_sources.py +0 -457
- scitex/stats/desc/__init__.py +0 -40
- scitex-2.0.0.dist-info/METADATA +0 -307
- scitex-2.0.0.dist-info/RECORD +0 -572
- scitex-2.0.0.dist-info/licenses/LICENSE +0 -7
- /scitex/ai/{act → activation}/__init__.py +0 -0
- /scitex/ai/{act → activation}/_define.py +0 -0
- /scitex/ai/{early_stopping.py → training/_EarlyStopping.py} +0 -0
- /scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_ImportExportMixin.py +0 -0
- /scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_IndexMixin.py +0 -0
- /scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_RowMixin.py +0 -0
- /scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/_TableMixin.py +0 -0
- /scitex/db/{_PostgreSQLMixins → _postgresql/_PostgreSQLMixins}/__init__.py +0 -0
- /scitex/{stats → stats_v01}/_calc_partial_corr.py +0 -0
- /scitex/{stats → stats_v01}/_corr_test_multi.py +0 -0
- /scitex/{stats → stats_v01}/_corr_test_wrapper.py +0 -0
- /scitex/{stats → stats_v01}/_describe_wrapper.py +0 -0
- /scitex/{stats → stats_v01}/_multiple_corrections.py +0 -0
- /scitex/{stats → stats_v01}/_nan_stats.py +0 -0
- /scitex/{stats → stats_v01}/_p2stars_wrapper.py +0 -0
- /scitex/{stats → stats_v01}/_statistical_tests.py +0 -0
- /scitex/{stats/desc/_describe.py → stats_v01/desc/_describe_v01-20250920_145731.py} +0 -0
- /scitex/{stats/desc/_real.py → stats_v01/desc/_real_v01-20250920_145731.py} +0 -0
- /scitex/{stats → stats_v01}/multiple/__init__.py +0 -0
- /scitex/{stats → stats_v01}/multiple/_bonferroni_correction.py +0 -0
- /scitex/{stats → stats_v01}/multiple/_fdr_correction.py +0 -0
- /scitex/{stats → stats_v01}/multiple/_multicompair.py +0 -0
- /scitex/{stats → stats_v01}/tests/__corr_test.py +0 -0
- /scitex/{stats → stats_v01}/tests/__corr_test_multi.py +0 -0
- /scitex/{stats → stats_v01}/tests/__corr_test_single.py +0 -0
- /scitex/{stats → stats_v01}/tests/__init__.py +0 -0
- /scitex/{stats → stats_v01}/tests/_brunner_munzel_test.py +0 -0
- /scitex/{stats → stats_v01}/tests/_nocorrelation_test.py +0 -0
- /scitex/{stats → stats_v01}/tests/_smirnov_grubbs.py +0 -0
- {scitex-2.0.0.dist-info → scitex-2.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1028 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Timestamp: "2025-10-02 18:55:00 (ywatanabe)"
|
|
4
|
+
# File: /home/ywatanabe/proj/scitex_repo/src/scitex/ml/classification/reporters/reporter_utils/_Plotter.py
|
|
5
|
+
# ----------------------------------------
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
import os
|
|
8
|
+
# Source-tree location of this module and the directory that contains it.
__FILE__ = "./src/scitex/ml/classification/reporters/reporter_utils/_Plotter.py"
__DIR__ = os.path.split(__FILE__)[0]
|
|
10
|
+
# ----------------------------------------
|
|
11
|
+
|
|
12
|
+
"""
|
|
13
|
+
Classification Plotter - Delegates to stx.ml.plt functions.
|
|
14
|
+
|
|
15
|
+
This module provides a Plotter class that delegates to centralized
|
|
16
|
+
plotting functions in scitex.ml.plt to maintain DRY principle.
|
|
17
|
+
|
|
18
|
+
Features:
|
|
19
|
+
- Graceful error handling
|
|
20
|
+
- Headless environment support (Agg backend)
|
|
21
|
+
- Optional plotting with proper disabling
|
|
22
|
+
- Delegates to:
|
|
23
|
+
* stx.ml.plt.conf_mat (confusion matrices)
|
|
24
|
+
* stx.ml.plt.roc_auc (ROC curves)
|
|
25
|
+
* stx.ml.plt.pre_rec_auc (Precision-Recall curves)
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
import warnings
|
|
29
|
+
from pathlib import Path
|
|
30
|
+
from typing import Any, List, Optional, Union
|
|
31
|
+
|
|
32
|
+
import numpy as np
|
|
33
|
+
|
|
34
|
+
# Import centralized plotting functions from stx.ml.plt
|
|
35
|
+
try:
|
|
36
|
+
import matplotlib
|
|
37
|
+
import matplotlib.pyplot as plt
|
|
38
|
+
|
|
39
|
+
# Try to import seaborn for enhanced visualizations
|
|
40
|
+
try:
|
|
41
|
+
import seaborn as sns
|
|
42
|
+
except ImportError:
|
|
43
|
+
sns = None
|
|
44
|
+
|
|
45
|
+
# Import scitex plotting functions
|
|
46
|
+
import scitex as stx
|
|
47
|
+
from scitex.ml.plt.plot_conf_mat import plot_conf_mat as conf_mat
|
|
48
|
+
from scitex.ml.plt.plot_roc_curve import plot_roc_curve as roc_auc
|
|
49
|
+
from scitex.ml.plt.plot_pre_rec_curve import plot_pre_rec_curve as pre_rec_auc
|
|
50
|
+
|
|
51
|
+
PLOTTING_AVAILABLE = True
|
|
52
|
+
except ImportError:
|
|
53
|
+
PLOTTING_AVAILABLE = False
|
|
54
|
+
plt = None
|
|
55
|
+
sns = None
|
|
56
|
+
conf_mat = None
|
|
57
|
+
roc_auc = None
|
|
58
|
+
pre_rec_auc = None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Plotter:
|
|
62
|
+
"""
|
|
63
|
+
Enhanced plotter with graceful error handling.
|
|
64
|
+
|
|
65
|
+
Features:
|
|
66
|
+
- Automatically disables if plotting libraries unavailable
|
|
67
|
+
- Uses non-interactive backend when no display available
|
|
68
|
+
- Provides informative error messages
|
|
69
|
+
- Supports fallback options
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
def __init__(
    self,
    enable_plotting: bool = True,
    save_dir: Optional[Path] = None,
    verbose: bool = True,
):
    """
    Configure the plotter and decide whether plotting is active.

    Parameters
    ----------
    enable_plotting : bool, default True
        Request plotting. The request is only honoured when the module-level
        plotting imports succeeded (``PLOTTING_AVAILABLE``).
    save_dir : Path, optional
        Directory for saved plots; stored as a ``Path`` (falsy values -> None).
    verbose : bool, default True
        Stored as ``self.verbose``; several plotting methods consult it before
        emitting failure warnings.
    """
    self.verbose = verbose
    self.save_dir = None if not save_dir else Path(save_dir)
    self.enabled = enable_plotting and PLOTTING_AVAILABLE

    # Explain why a requested capability is unexpectedly switched off.
    requested_but_unavailable = enable_plotting and not PLOTTING_AVAILABLE
    if requested_but_unavailable:
        warnings.warn(
            "Plotting libraries not available. Plotting disabled.",
            UserWarning,
        )
|
|
97
|
+
|
|
98
|
+
def create_confusion_matrix_plot(
    self,
    confusion_matrix: np.ndarray,
    labels: Optional[List[str]] = None,
    save_path: Optional[Union[str, Path]] = None,
    verbose: bool = True,
    title: str = "Confusion Matrix",
) -> Optional[Any]:
    """
    Draw a confusion matrix by delegating to the shared ``conf_mat`` helper.

    Parameters
    ----------
    confusion_matrix : np.ndarray
        Matrix to draw; ``None`` short-circuits to ``None``.
    labels : List[str], optional
        Class labels forwarded to ``conf_mat``.
    save_path : Union[str, Path], optional
        Forwarded to ``conf_mat`` as ``spath``.
    verbose : bool
        Accepted for API symmetry; failure warnings are governed by
        ``self.verbose``, not by this argument.
    title : str, default "Confusion Matrix"
        Title forwarded to ``conf_mat``.

    Returns
    -------
    Optional[Any]
        Whatever ``conf_mat`` returns (the figure), or ``None`` when plotting
        is disabled, no matrix was supplied, or drawing raised.
    """
    if self.enabled and confusion_matrix is not None:
        try:
            return conf_mat(
                cm=confusion_matrix,
                labels=labels,
                title=title,
                spath=save_path,
            )
        except Exception as err:
            # Best effort: a plotting failure must not abort report generation.
            if self.verbose:
                warnings.warn(
                    f"Failed to create confusion matrix plot: {err}", UserWarning
                )
    return None
|
|
145
|
+
|
|
146
|
+
def create_roc_curve(
    self,
    y_true: np.ndarray,
    y_proba: np.ndarray,
    labels: Optional[List[str]] = None,
    save_path: Optional[Union[str, Path]] = None,
    verbose: bool = True,
    title: str = "ROC Curve",
) -> Optional[Any]:
    """
    Create ROC curve plot - delegates to the centralized ``roc_auc`` helper.

    Parameters
    ----------
    y_true : np.ndarray
        True labels.
    y_proba : np.ndarray
        Prediction probabilities.
    labels : List[str], optional
        Class labels (an empty list is forwarded when omitted).
    save_path : Union[str, Path], optional
        Forwarded to the helper as ``spath``.
    verbose : bool
        Accepted for API symmetry; diagnostics are governed by ``self.verbose``.
    title : str, default "ROC Curve"
        Currently not forwarded to the helper.

    Returns
    -------
    Optional[Any]
        First element of the helper's returned pair (the figure), or ``None``
        if plotting is disabled, inputs are missing, or plotting failed.
    """
    if not self.enabled or y_true is None or y_proba is None:
        return None

    try:
        fig, _ = roc_auc(
            true_class=y_true,
            pred_proba=y_proba,
            labels=labels or [],
            spath=save_path,
        )
        return fig
    except Exception as e:
        # Previously the traceback went to stderr even with verbose=False;
        # keep all failure output behind the instance's verbosity switch.
        if self.verbose:
            import sys
            import traceback

            print(f"ERROR in create_roc_curve: {e}", file=sys.stderr)
            traceback.print_exc()
            warnings.warn(f"Failed to create ROC curve: {e}", UserWarning)
        return None
|
|
198
|
+
|
|
199
|
+
def create_precision_recall_curve(
    self,
    y_true: np.ndarray,
    y_proba: np.ndarray,
    labels: Optional[List[str]] = None,
    save_path: Optional[Union[str, Path]] = None,
    verbose: bool = True,
    title: str = "Precision-Recall Curve",
) -> Optional[Any]:
    """
    Create Precision-Recall curve plot - delegates to ``pre_rec_auc``.

    Parameters
    ----------
    y_true : np.ndarray
        True labels.
    y_proba : np.ndarray
        Prediction probabilities.
    labels : List[str], optional
        Class labels (an empty list is forwarded when omitted).
    save_path : Union[str, Path], optional
        Forwarded to the helper as ``spath``.
    verbose : bool
        Accepted for API symmetry; diagnostics are governed by ``self.verbose``.
    title : str, default "Precision-Recall Curve"
        Currently not forwarded to the helper.

    Returns
    -------
    Optional[Any]
        First element of the helper's returned pair (the figure), or ``None``
        if plotting is disabled, inputs are missing, or plotting failed.
    """
    if not self.enabled or y_true is None or y_proba is None:
        return None

    try:
        fig, _ = pre_rec_auc(
            true_class=y_true,
            pred_proba=y_proba,
            labels=labels or [],
            spath=save_path,
        )
        return fig
    except Exception as e:
        # Keep failure diagnostics behind the instance's verbosity switch
        # instead of unconditionally dumping a traceback to stderr.
        if self.verbose:
            import sys
            import traceback

            print(f"ERROR in create_precision_recall_curve: {e}", file=sys.stderr)
            traceback.print_exc()
            warnings.warn(f"Failed to create PR curve: {e}", UserWarning)
        return None
|
|
251
|
+
|
|
252
|
+
def create_overall_roc_curve(
    self,
    y_true: np.ndarray,
    y_proba: np.ndarray,
    labels: Optional[List[str]] = None,
    save_path: Optional[Union[str, Path]] = None,
    verbose: bool = True,
    title: str = "ROC Curve (Overall)",
    auc_mean: Optional[float] = None,
    auc_std: Optional[float] = None,
) -> Optional[Any]:
    """
    Create overall (binary) ROC curve plot with AUC statistics.

    Parameters
    ----------
    y_true : np.ndarray
        True labels (numeric or string).
    y_proba : np.ndarray
        Scores: a 1-D positive-class vector or a two-column probability
        matrix whose column 1 is taken as the positive class.
    labels : List[str], optional
        Class names. ``labels[1]`` is used as the positive class, but only
        when that value actually occurs in ``y_true``.
    save_path : Union[str, Path], optional
        Path to save plot (saved via ``scitex.io.save``).
    verbose : bool
        Passed (OR-ed with ``self.verbose``) to the save call.
    title : str
        Title for the plot.
    auc_mean : float, optional
        Mean AUC across folds; shown instead of the pooled AUC when
        ``auc_std`` is also given.
    auc_std : float, optional
        Standard deviation of AUC across folds.

    Returns
    -------
    Optional[Any]
        The (already closed) Matplotlib figure, or ``None`` if plotting is
        disabled, inputs are missing, the problem is multiclass, or
        plotting failed.
    """
    if not self.enabled or y_true is None or y_proba is None:
        return None

    fig = None
    try:
        from sklearn.metrics import auc, roc_curve

        y_true = np.asarray(y_true)
        y_proba = np.asarray(y_proba)

        if y_proba.ndim == 1:
            y_score = y_proba
        elif y_proba.ndim == 2 and y_proba.shape[1] == 2:
            y_score = y_proba[:, 1]
        else:
            # Multiclass - not supported for overall curves yet
            return None

        # Choose pos_label only when it matches a value present in y_true;
        # a display label absent from y_true (e.g. int-encoded y_true with
        # string labels) would make roc_curve raise.
        present = np.unique(y_true).tolist()
        pos_label = None
        if labels and len(labels) >= 2 and labels[1] in present:
            pos_label = labels[1]
        elif y_true.dtype.kind in ("U", "S", "O") and len(present) == 2:
            # String classes without usable labels: assume column 1 matches
            # the second sorted class (scikit-learn `classes_` ordering).
            pos_label = present[1]

        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label)
        auc_value = auc(fpr, tpr)

        fig, ax = plt.subplots(figsize=(8, 8))  # Square figure

        # Prefer cross-fold mean/std when both are provided
        if auc_mean is not None and auc_std is not None:
            label = f"ROC Curve (AUC = {auc_mean:.3f} ± {auc_std:.3f})"
        else:
            label = f"ROC Curve (AUC = {auc_value:.3f})"

        ax.plot(fpr, tpr, label=label, linewidth=2)
        ax.plot([0, 1], [0, 1], "k--", label="Random", alpha=0.5)
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title(title)
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_aspect("equal")
        ax.legend(loc="lower right")
        ax.grid(True, alpha=0.3)

        if save_path:
            try:
                from scitex.io import save as stx_io_save

                # Resolve to absolute path to prevent _out directory creation
                save_path_abs = (
                    Path(save_path).resolve()
                    if isinstance(save_path, (str, Path))
                    else save_path
                )
                stx_io_save(
                    fig,
                    str(save_path_abs),
                    verbose=verbose or self.verbose,
                    use_caller_path=False,
                )
            except Exception as save_error:
                # A failed save should not discard the rendered figure.
                warnings.warn(
                    f"Failed to save ROC curve: {save_error}", UserWarning
                )
                if self.verbose:
                    import traceback

                    traceback.print_exc()

        plt.close(fig)  # Release figure memory; object is still returned
        return fig

    except Exception as e:
        if fig is not None:
            plt.close(fig)  # Avoid leaking a half-built figure
        if self.verbose:
            import traceback

            traceback.print_exc()
        warnings.warn(f"Failed to create overall ROC curve: {e}", UserWarning)
        return None
|
|
357
|
+
|
|
358
|
+
def create_overall_pr_curve(
    self,
    y_true: np.ndarray,
    y_proba: np.ndarray,
    labels: Optional[List[str]] = None,
    save_path: Optional[Union[str, Path]] = None,
    verbose: bool = True,
    title: str = "Precision-Recall Curve (Overall)",
    ap_mean: Optional[float] = None,
    ap_std: Optional[float] = None,
) -> Optional[Any]:
    """
    Create overall (binary) Precision-Recall curve plot with AP statistics.

    Parameters
    ----------
    y_true : np.ndarray
        True labels (numeric or string).
    y_proba : np.ndarray
        Scores: a 1-D positive-class vector or a two-column probability
        matrix whose column 1 is taken as the positive class.
    labels : List[str], optional
        Class names. ``labels[1]`` is used as the positive class when that
        value occurs in ``y_true``.
    save_path : Union[str, Path], optional
        Path to save plot (saved via ``scitex.io.save``).
    verbose : bool
        Passed (OR-ed with ``self.verbose``) to the save call.
    title : str
        Title for the plot.
    ap_mean : float, optional
        Mean Average Precision across folds; shown instead of the pooled AP
        when ``ap_std`` is also given.
    ap_std : float, optional
        Standard deviation of AP across folds.

    Returns
    -------
    Optional[Any]
        The (already closed) Matplotlib figure, or ``None`` if plotting is
        disabled, inputs are missing, the problem is multiclass, or
        plotting failed.
    """
    if not self.enabled or y_true is None or y_proba is None:
        return None

    fig = None
    try:
        from sklearn.metrics import average_precision_score, precision_recall_curve

        y_true = np.asarray(y_true)
        y_proba = np.asarray(y_proba)

        if y_proba.ndim == 1:
            y_score = y_proba
        elif y_proba.ndim == 2 and y_proba.shape[1] == 2:
            y_score = y_proba[:, 1]
        else:
            # Multiclass not well-defined for PR curves
            return None

        # String-valued y_true needs an explicit pos_label, otherwise both
        # sklearn calls raise. Only pass it when it is a value in y_true.
        present = np.unique(y_true).tolist()
        pos_label = None
        if labels and len(labels) >= 2 and labels[1] in present:
            pos_label = labels[1]
        elif y_true.dtype.kind in ("U", "S", "O") and len(present) == 2:
            # Assume column 1 matches the second sorted class
            # (scikit-learn `classes_` ordering).
            pos_label = present[1]
        # average_precision_score defaults pos_label to 1 and rejects None,
        # so only forward it when we resolved one.
        pr_kwargs = {} if pos_label is None else {"pos_label": pos_label}

        precision, recall, _ = precision_recall_curve(y_true, y_score, **pr_kwargs)
        avg_precision = average_precision_score(y_true, y_score, **pr_kwargs)

        fig, ax = plt.subplots(figsize=(8, 8))  # Square figure

        # Prefer cross-fold mean/std when both are provided
        if ap_mean is not None and ap_std is not None:
            label = f"PR Curve (Average Precision (AP) = {ap_mean:.3f} ± {ap_std:.3f})"
        else:
            label = f"PR Curve (Average Precision (AP) = {avg_precision:.3f})"

        ax.plot(recall, precision, label=label, linewidth=2)
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.set_title(title)
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_aspect("equal")
        ax.legend(loc="best")
        ax.grid(True, alpha=0.3)

        if save_path:
            try:
                from scitex.io import save as stx_io_save

                # Resolve to absolute path to prevent _out directory creation
                save_path_abs = (
                    Path(save_path).resolve()
                    if isinstance(save_path, (str, Path))
                    else save_path
                )
                stx_io_save(
                    fig,
                    str(save_path_abs),
                    verbose=verbose or self.verbose,
                    use_caller_path=False,
                )
            except Exception as save_error:
                # A failed save should not discard the rendered figure.
                warnings.warn(
                    f"Failed to save PR curve: {save_error}", UserWarning
                )

        plt.close(fig)  # Release figure memory; object is still returned
        return fig

    except Exception as e:
        if fig is not None:
            plt.close(fig)  # Avoid leaking a half-built figure
        warnings.warn(f"Failed to create overall PR curve: {e}", UserWarning)
        return None
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def create_metrics_visualization(
    self,
    metrics: dict,
    y_true: Optional[np.ndarray] = None,
    y_pred: Optional[np.ndarray] = None,
    y_proba: Optional[np.ndarray] = None,
    labels: Optional[List[str]] = None,
    save_path: Optional[Union[str, Path]] = None,
    title: str = "Classification Metrics Summary",
    fold: Optional[int] = None,
    verbose: bool = True,
) -> Optional[Any]:
    """
    Create comprehensive metrics visualization dashboard.

    The figure contains, depending on the available inputs:
    - Confusion matrix (if y_true and y_pred available; taken from
      ``metrics['confusion_matrix']`` or computed from the predictions)
    - ROC curve (if y_true and y_proba available; one-vs-rest per column
      for multiclass)
    - Precision-Recall curve (binary problems only)
    - Key metrics summary table (always)

    Parameters
    ----------
    metrics : dict
        Calculated metrics (balanced_accuracy, mcc, ...). Values may be bare
        numbers or dicts carrying a ``'value'`` key.
    y_true : np.ndarray, optional
        True labels.
    y_pred : np.ndarray, optional
        Predicted labels.
    y_proba : np.ndarray, optional
        Prediction probabilities (1-D positive-class scores or one column
        per class).
    labels : List[str], optional
        Class labels, used for tick labels, legend names and positive-class
        selection (``labels[1]``) when they match the data.
    save_path : Union[str, Path], optional
        Path to save the visualization (via ``scitex.io.save``).
    title : str, default "Classification Metrics Summary"
        Overall title for the figure.
    fold : int, optional
        Fold number appended to the title (for cross-validation).
    verbose : bool, default True
        Passed (OR-ed with ``self.verbose``) to the save call.

    Returns
    -------
    Optional[Any]
        Matplotlib figure or None if plotting failed.

    Examples
    --------
    >>> plotter = Plotter(enable_plotting=True)
    >>> metrics = {
    ...     'balanced_accuracy': 0.85,
    ...     'mcc': 0.75,
    ...     'roc_auc': 0.90
    ... }
    >>> fig = plotter.create_metrics_visualization(
    ...     metrics, y_true, y_pred, y_proba,
    ...     save_path='metrics_summary.png'
    ... )
    """
    if not self.enabled:
        return None

    def _is_text(arr):
        # Unicode, bytes, or object (string-like) label arrays.
        return arr.dtype.kind in ("U", "S", "O")

    fig = None
    try:
        metrics = metrics or {}
        y_true = None if y_true is None else np.asarray(y_true)
        y_pred = None if y_pred is None else np.asarray(y_pred)
        y_proba = None if y_proba is None else np.asarray(y_proba)

        has_cm = y_true is not None and y_pred is not None
        has_roc = y_true is not None and y_proba is not None
        is_binary = has_roc and (y_proba.ndim == 1 or y_proba.shape[1] == 2)
        # The PR panel is only drawn for binary problems, so only reserve a
        # grid slot in that case (otherwise a blank panel was left behind).
        has_pr = is_binary

        n_plots = 1 + int(has_cm) + int(has_roc) + int(has_pr)  # +1 table
        layouts = {
            4: ((16, 12), 2, 2),
            3: ((16, 6), 1, 3),
            2: ((12, 6), 1, 2),
            1: ((8, 6), 1, 1),
        }
        figsize, n_rows, n_cols = layouts[n_plots]
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(n_rows, n_cols, hspace=0.3, wspace=0.3)
        # Row-major slot order: fill panels left-to-right, top-to-bottom.
        slots = iter([divmod(k, n_cols) for k in range(n_plots)])

        def next_axes():
            return fig.add_subplot(gs[next(slots)])

        fold_suffix = f" (Fold {fold})" if fold is not None else ""
        fig.suptitle(f"{title}{fold_suffix}", fontsize=16, fontweight="bold")

        # Panel: Confusion Matrix
        if has_cm:
            ax = next_axes()
            cm = metrics.get("confusion_matrix")
            if isinstance(cm, dict) and "value" in cm:
                cm = cm["value"]
            if cm is None:
                # Not precomputed: derive it from the predictions.
                from sklearn.metrics import confusion_matrix

                observed = set(y_true.ravel().tolist()) | set(y_pred.ravel().tolist())
                cm_kwargs = {}
                if labels and set(labels) >= observed:
                    cm_kwargs["labels"] = list(labels)
                cm = confusion_matrix(y_true, y_pred, **cm_kwargs)
            # Stored matrices may be plain (possibly float) nested lists.
            cm = np.asarray(cm)

            # Only use labels as ticks when they line up with the matrix.
            tick_labels = (
                list(labels)
                if labels is not None and len(labels) == cm.shape[0]
                else None
            )
            if sns is not None:
                # seaborn requires "auto"/list/bool here; None would raise.
                ticks = tick_labels if tick_labels is not None else "auto"
                sns.heatmap(
                    cm,
                    annot=True,
                    fmt="d" if np.issubdtype(cm.dtype, np.integer) else ".2f",
                    cmap="Blues",
                    xticklabels=ticks,
                    yticklabels=ticks,
                    ax=ax,
                    cbar_kws={"label": "Count"},
                )
            else:
                ax.imshow(cm, cmap="Blues")
                for (i, j), count in np.ndenumerate(cm):
                    ax.text(j, i, str(count), ha="center", va="center")
                if tick_labels is not None:
                    ax.set_xticks(range(cm.shape[1]))
                    ax.set_xticklabels(tick_labels)
                    ax.set_yticks(range(cm.shape[0]))
                    ax.set_yticklabels(tick_labels)

            ax.set_xlabel("Predicted Label")
            ax.set_ylabel("True Label")
            ax.set_title("Confusion Matrix")

        # Panel: ROC Curve
        if has_roc:
            from sklearn.metrics import auc, roc_curve

            ax = next_axes()
            if is_binary:
                y_score = y_proba[:, 1] if y_proba.ndim == 2 else y_proba
                # Only use labels[1] as pos_label when it occurs in y_true.
                present = np.unique(y_true).tolist()
                pos_label = None
                if labels and len(labels) >= 2 and labels[1] in present:
                    pos_label = labels[1]
                elif _is_text(y_true) and len(present) == 2:
                    # Assume column 1 matches the second sorted class.
                    pos_label = present[1]
                fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label)
                ax.plot(fpr, tpr, label=f"AUC = {auc(fpr, tpr):.3f}", linewidth=2)
            else:
                # Multiclass: one-vs-rest curve per probability column.
                n_classes = y_proba.shape[1]
                if _is_text(y_true):
                    # Column i is assumed to correspond to labels[i], or to
                    # the i-th sorted class when labels don't fit.
                    if labels and len(labels) == n_classes:
                        class_values = list(labels)
                    else:
                        class_values = np.unique(y_true).tolist()
                else:
                    class_values = list(range(n_classes))
                for i in range(n_classes):
                    class_value = class_values[i] if i < len(class_values) else i
                    y_true_binary = (y_true == class_value).astype(int)
                    fpr, tpr, _ = roc_curve(y_true_binary, y_proba[:, i])
                    class_label = (
                        labels[i] if labels and i < len(labels) else f"Class {i}"
                    )
                    ax.plot(fpr, tpr, label=f"{class_label} (AUC={auc(fpr, tpr):.3f})")

            ax.plot([0, 1], [0, 1], "k--", label="Random", alpha=0.5)
            ax.set_xlabel("False Positive Rate")
            ax.set_ylabel("True Positive Rate")
            ax.set_title("ROC Curve")
            ax.set_xlim([0, 1])
            ax.set_ylim([0, 1])
            ax.legend(loc="lower right")
            ax.grid(True, alpha=0.3)

        # Panel: Precision-Recall Curve (binary only)
        if has_pr:
            from sklearn.metrics import average_precision_score, precision_recall_curve

            ax = next_axes()
            y_score = y_proba[:, 1] if y_proba.ndim == 2 else y_proba

            # Convert string labels to integer indices so index 1 is positive.
            y_true_for_pr = y_true
            if _is_text(y_true):
                if labels and set(labels) >= set(y_true.tolist()):
                    classes = list(labels)
                else:
                    classes = np.unique(y_true).tolist()
                label_to_idx = {c: idx for idx, c in enumerate(classes)}
                y_true_for_pr = np.array([label_to_idx[yt] for yt in y_true.tolist()])

            precision, recall, _ = precision_recall_curve(y_true_for_pr, y_score)
            avg_precision = average_precision_score(y_true_for_pr, y_score)

            ax.plot(recall, precision, label=f"AP = {avg_precision:.3f}", linewidth=2)
            ax.set_xlabel("Recall")
            ax.set_ylabel("Precision")
            ax.set_title("Precision-Recall Curve")
            ax.set_xlim([0, 1])
            ax.set_ylim([0, 1])
            ax.legend(loc="lower left")
            ax.grid(True, alpha=0.3)

        # Panel: Metrics Summary Table
        ax = next_axes()
        ax.axis("off")

        display_metrics = {
            "balanced_accuracy": "Balanced Accuracy",
            "mcc": "Matthews Corr Coef",
            "roc_auc": "ROC AUC",
            "pr_auc": "PR AUC",
            "pre_rec_auc": "PR AUC",
            "accuracy": "Accuracy",
            "precision": "Precision",
            "recall": "Recall",
            "f1_score": "F1 Score",
        }

        table_data = []
        shown = set()
        for key, display_name in display_metrics.items():
            # 'pr_auc' and 'pre_rec_auc' are aliases; show at most one row.
            if key not in metrics or display_name in shown:
                continue
            value = metrics[key]
            if isinstance(value, dict) and "value" in value:
                value = value["value"]
            if value is None:
                continue
            shown.add(display_name)
            # np.integer/np.floating cover e.g. np.float32, which is not a float.
            if isinstance(value, (int, float, np.integer, np.floating)):
                table_data.append((display_name, f"{value:.4f}"))
            else:
                table_data.append((display_name, str(value)))

        if table_data:
            table = ax.table(
                cellText=table_data,
                colLabels=["Metric", "Value"],
                cellLoc="left",
                loc="center",
                colWidths=[0.6, 0.4],
            )
            table.auto_set_font_size(False)
            table.set_fontsize(10)
            table.scale(1, 2)

            # Style header
            for col in range(2):
                table[(0, col)].set_facecolor("#40466e")
                table[(0, col)].set_text_props(weight="bold", color="white")

            # Alternate row colors (row 0 is the header)
            for row in range(2, len(table_data) + 1, 2):
                for col in range(2):
                    table[(row, col)].set_facecolor("#f0f0f0")

        ax.set_title("Performance Metrics", fontweight="bold", pad=20)

        if save_path:
            from scitex.io import save as stx_io_save

            # Resolve to absolute path to prevent _out directory creation
            save_path_abs = (
                Path(save_path).resolve()
                if isinstance(save_path, (str, Path))
                else save_path
            )
            stx_io_save(
                fig,
                str(save_path_abs),
                verbose=verbose or self.verbose,
                use_caller_path=False,
            )

        return fig

    except Exception as e:
        if fig is not None:
            plt.close(fig)  # Avoid leaking a half-built figure
        warnings.warn(
            f"Failed to create metrics visualization: {e}",
            UserWarning,
        )
        if self.verbose:
            import traceback

            traceback.print_exc()
        return None
|
|
743
|
+
|
|
744
|
+
|
|
745
|
+
def create_feature_importance_plot(
    self,
    feature_importance: Union[np.ndarray, dict],
    feature_names: Optional[List[str]] = None,
    top_n: int = 20,
    save_path: Optional[Union[str, Path]] = None,
    verbose: bool = True,
    title: str = "Feature Importance",
) -> Optional[Any]:
    """
    Plot the most important features via scitex's shared plotting helper.

    The actual drawing (and saving, when ``save_path`` is given) is delegated
    to ``scitex.ml.plt.plot_feature_importance``; this method only unwraps
    dict-shaped input and converts any failure into a warning.

    Parameters
    ----------
    feature_importance : np.ndarray or dict
        Importance values. A dict is unwrapped through its ``'importance'``
        entry, falling back to ``'value'``. If neither yields a non-None
        value, the dict itself is forwarded (presumably a
        ``{feature: importance}`` mapping -- the helper decides how to read it).
    feature_names : List[str], optional
        Feature names forwarded unchanged to the helper.
    top_n : int, default 20
        Number of top features forwarded to the helper.
    save_path : Union[str, Path], optional
        Forwarded to the helper as ``spath``.
    verbose : bool
        Accepted for interface symmetry with the other plotting methods;
        this method does not forward or read it.
    title : str
        Plot title forwarded to the helper.

    Returns
    -------
    Optional[Any]
        Whatever the helper returns (presumably a matplotlib figure), or
        None when plotting is disabled or any step raises. Failures emit a
        ``UserWarning`` and print the traceback.
    """
    if self.enabled:
        try:
            importance = feature_importance
            if isinstance(feature_importance, dict):
                # Prefer an explicit 'importance' entry, then 'value'; only
                # when both are absent/None is the raw dict passed through.
                unwrapped = feature_importance.get(
                    'importance', feature_importance.get('value')
                )
                if unwrapped is not None:
                    importance = unwrapped

            # Imported lazily so the plotting backend is only loaded on demand.
            from scitex.ml.plt import plot_feature_importance

            return plot_feature_importance(
                importance=importance,
                feature_names=feature_names,
                top_n=top_n,
                title=title,
                spath=save_path,
            )
        except Exception as err:
            import traceback

            warnings.warn(f"Failed to create feature importance plot: {err}", UserWarning)
            traceback.print_exc()
    return None
|
|
806
|
+
|
|
807
|
+
def create_cv_aggregation_plot(
    self,
    fold_predictions: List[Dict[str, Any]],
    curve_type: str = 'roc',
    class_labels: Optional[List[str]] = None,
    save_path: Optional[Union[str, Path]] = None,
    verbose: bool = True,
    title: Optional[str] = None,
    show_mean: bool = True,
    show_individual_folds: bool = True,
    fold_alpha: float = 0.15,
) -> Optional[Any]:
    """
    Create CV aggregation plot with faded individual fold lines.

    The plot shows:
    - Individual fold curves (faded gray, sharing one legend entry)
    - Mean curve across folds (bold) with a +/- 1 std. dev. band
    - Axis labels, a default title and (ROC only) a chance diagonal,
      regardless of which curve layers are enabled

    Each fold's curve is linearly interpolated onto a common 100-point grid
    on [0, 1] (FPR for ROC, recall for PR) so the folds can be averaged
    point-wise.

    Parameters
    ----------
    fold_predictions : List[Dict[str, Any]]
        One dict per fold with 'y_true' and 'y_proba' entries (array-likes;
        plain lists are accepted). An optional 'fold' entry is not used.
        ``y_proba`` may be 1-D (positive-class scores) or 2-D. With two
        columns the second is the positive class. With more columns only
        class 0 vs. rest is plotted (multiclass is not yet supported).
    curve_type : str, default 'roc'
        Type of curve: 'roc' or 'pr' (precision-recall)
    class_labels : List[str], optional
        Used to map string/object ``y_true`` labels to indices (labels not in
        the list map to 0). With fewer than two entries, the sorted unique
        labels of each fold are used instead.
    save_path : Union[str, Path], optional
        Path to save plot; resolved to an absolute path before saving.
    verbose : bool, default True
        Whether to print messages (OR-ed with ``self.verbose``)
    title : str, optional
        Custom title (auto-generated if None)
    show_mean : bool, default True
        Whether to show the mean curve and its std. dev. band
    show_individual_folds : bool, default True
        Whether to show individual fold curves
    fold_alpha : float, default 0.15
        Transparency for individual fold curves (0-1)

    Returns
    -------
    Optional[Any]
        Matplotlib figure, or None when plotting is disabled or fails.
        Failures emit a ``UserWarning`` and print the traceback.

    Examples
    --------
    >>> # ROC curve with faded fold lines
    >>> plotter.create_cv_aggregation_plot(
    ...     fold_predictions,
    ...     curve_type='roc',
    ...     title='Cross-Validation ROC Curves',
    ...     save_path='cv_roc.png'
    ... )
    >>> # PR curve without individual folds
    >>> plotter.create_cv_aggregation_plot(
    ...     fold_predictions,
    ...     curve_type='pr',
    ...     show_individual_folds=False,
    ...     save_path='cv_pr_mean_only.png'
    ... )
    """
    if not self.enabled:
        return None

    fig = None
    try:
        if curve_type not in ('roc', 'pr'):
            raise ValueError("curve_type must be 'roc' or 'pr'")
        # Fail early with a clear message. Averaging zero folds would otherwise
        # surface as an opaque IndexError when building the mean curve.
        if len(fold_predictions) == 0:
            raise ValueError("fold_predictions is empty; nothing to plot")

        from sklearn.metrics import (auc, average_precision_score,
                                     precision_recall_curve, roc_curve)

        fig, ax = plt.subplots(figsize=(8, 8))

        # Common x-grid (FPR for ROC, recall for PR) used to average folds
        grid = np.linspace(0, 1, 100)
        interp_curves = []  # per-fold TPR (ROC) or precision (PR) on `grid`
        fold_scores = []  # per-fold ROC AUC or average precision

        for fold_pos, fold_data in enumerate(fold_predictions):
            y_true = np.asarray(fold_data['y_true'])
            y_proba = np.asarray(fold_data['y_proba'])

            # Convert string labels to integer indices if needed
            y_true_numeric = y_true
            if y_true.dtype.kind in ('U', 'S', 'O'):  # Unicode, bytes, or object (string)
                if class_labels and len(class_labels) >= 2:
                    label_to_idx = {label: idx for idx, label in enumerate(class_labels)}
                    y_true_numeric = np.array([label_to_idx.get(yt, 0) for yt in y_true])
                else:
                    # Infer labels from the sorted unique values
                    unique_labels = np.unique(y_true)
                    label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
                    y_true_numeric = np.array([label_to_idx[yt] for yt in y_true])

            # Reduce probabilities to a single positive-class score
            if y_proba.ndim == 2 and y_proba.shape[1] == 2:
                y_proba_pos = y_proba[:, 1]
            elif y_proba.ndim == 1:
                y_proba_pos = y_proba
            else:
                # Multiclass - plot class 0 vs. rest for now
                # TODO: Extend for multiclass support
                y_proba_pos = y_proba[:, 0]
                y_true_numeric = (y_true_numeric == 0).astype(int)

            # Label only the first plotted fold so all fold lines share a single
            # legend entry, independent of how (or whether) folds are numbered.
            fold_label = 'Individual folds' if fold_pos == 0 else None

            if curve_type == 'roc':
                fpr, tpr, _ = roc_curve(y_true_numeric, y_proba_pos)
                fold_scores.append(auc(fpr, tpr))
                if show_individual_folds:
                    ax.plot(fpr, tpr, alpha=fold_alpha, color='gray', label=fold_label)
                interp = np.interp(grid, fpr, tpr)
                interp[0] = 0.0  # anchor every ROC curve at the origin
            else:  # pr
                precision, recall, _ = precision_recall_curve(y_true_numeric, y_proba_pos)
                fold_scores.append(average_precision_score(y_true_numeric, y_proba_pos))
                if show_individual_folds:
                    ax.plot(recall, precision, alpha=fold_alpha, color='gray', label=fold_label)
                # precision_recall_curve returns recall in decreasing order, while
                # np.interp needs increasing x-coordinates, so reverse both arrays.
                interp = np.interp(grid, recall[::-1], precision[::-1])
            interp_curves.append(interp)

        # Plot mean curve with a +/- 1 std. dev. band
        if show_mean:
            mean_curve = np.mean(interp_curves, axis=0)
            std_curve = np.std(interp_curves, axis=0)
            mean_score = np.mean(fold_scores)
            std_score = np.std(fold_scores)
            if curve_type == 'roc':
                mean_curve[-1] = 1.0  # every ROC curve ends at (1, 1)
                mean_label = f'Mean ROC (AUC = {mean_score:.3f} ± {std_score:.3f})'
            else:
                mean_label = f'Mean PR (AP = {mean_score:.3f} ± {std_score:.3f})'
            ax.plot(grid, mean_curve, color='b', linewidth=2, label=mean_label)
            ax.fill_between(grid,
                            np.maximum(mean_curve - std_curve, 0),
                            np.minimum(mean_curve + std_curve, 1),
                            color='b', alpha=0.2, label='± 1 std. dev.')

        # Axis decoration does not depend on which curve layers were drawn
        n_folds = len(fold_predictions)
        if curve_type == 'roc':
            ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Chance')
            ax.set_xlabel('False Positive Rate', fontsize=12)
            ax.set_ylabel('True Positive Rate', fontsize=12)
            default_title = f'ROC Curves - Cross Validation (n={n_folds} folds)'
        else:
            ax.set_xlabel('Recall', fontsize=12)
            ax.set_ylabel('Precision', fontsize=12)
            default_title = f'Precision-Recall Curves - Cross Validation (n={n_folds} folds)'
        if title is None:
            title = default_title

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_title(title, fontsize=14, fontweight='bold')
        # Only draw a legend when something is labeled. Otherwise matplotlib
        # warns, e.g. for PR with both curve layers hidden.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(loc='best', fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_aspect('equal')

        plt.tight_layout()

        if save_path:
            from pathlib import Path
            from scitex.io import save as stx_io_save
            # Resolve to absolute path to prevent _out directory creation
            save_path_abs = Path(save_path).resolve() if isinstance(save_path, (str, Path)) else save_path
            stx_io_save(fig, str(save_path_abs), verbose=verbose or self.verbose, use_caller_path=False)

        return fig

    except Exception as e:
        # Best-effort plotting: release the half-built figure, report, return None
        if fig is not None:
            plt.close(fig)
        warnings.warn(f"Failed to create CV aggregation plot: {e}", UserWarning)
        import traceback
        traceback.print_exc()
        return None
|
|
1014
|
+
|
|
1015
|
+
|
|
1016
|
+
def safe_plot_wrapper(func):
    """
    Decorate a plotting function so that failures degrade to ``None``.

    Any ``Exception`` raised by ``func`` is turned into a ``UserWarning``
    ("Plotting failed: <error>") and the wrapper returns ``None`` instead of
    propagating the error. ``functools.wraps`` copies the wrapped function's
    ``__name__``, ``__doc__`` and other metadata onto the wrapper, so
    decorated functions stay introspectable.

    Parameters
    ----------
    func : callable
        Plotting function to wrap.

    Returns
    -------
    callable
        Wrapper accepting the same arguments as ``func``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            warnings.warn(f"Plotting failed: {e}", UserWarning)
            return None

    return wrapper
|
|
1027
|
+
|
|
1028
|
+
# EOF
|