scitex 2.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +73 -0
- scitex/__main__.py +89 -0
- scitex/__version__.py +14 -0
- scitex/_sh.py +59 -0
- scitex/ai/_LearningCurveLogger.py +583 -0
- scitex/ai/__Classifiers.py +101 -0
- scitex/ai/__init__.py +55 -0
- scitex/ai/_gen_ai/_Anthropic.py +173 -0
- scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
- scitex/ai/_gen_ai/_DeepSeek.py +175 -0
- scitex/ai/_gen_ai/_Google.py +161 -0
- scitex/ai/_gen_ai/_Groq.py +97 -0
- scitex/ai/_gen_ai/_Llama.py +142 -0
- scitex/ai/_gen_ai/_OpenAI.py +230 -0
- scitex/ai/_gen_ai/_PARAMS.py +565 -0
- scitex/ai/_gen_ai/_Perplexity.py +191 -0
- scitex/ai/_gen_ai/__init__.py +32 -0
- scitex/ai/_gen_ai/_calc_cost.py +78 -0
- scitex/ai/_gen_ai/_format_output_func.py +183 -0
- scitex/ai/_gen_ai/_genai_factory.py +71 -0
- scitex/ai/act/__init__.py +8 -0
- scitex/ai/act/_define.py +11 -0
- scitex/ai/classification/__init__.py +7 -0
- scitex/ai/classification/classification_reporter.py +1137 -0
- scitex/ai/classification/classifier_server.py +131 -0
- scitex/ai/classification/classifiers.py +101 -0
- scitex/ai/classification_reporter.py +1161 -0
- scitex/ai/classifier_server.py +131 -0
- scitex/ai/clustering/__init__.py +11 -0
- scitex/ai/clustering/_pca.py +115 -0
- scitex/ai/clustering/_umap.py +376 -0
- scitex/ai/early_stopping.py +149 -0
- scitex/ai/feature_extraction/__init__.py +56 -0
- scitex/ai/feature_extraction/vit.py +148 -0
- scitex/ai/genai/__init__.py +277 -0
- scitex/ai/genai/anthropic.py +177 -0
- scitex/ai/genai/anthropic_provider.py +320 -0
- scitex/ai/genai/anthropic_refactored.py +109 -0
- scitex/ai/genai/auth_manager.py +200 -0
- scitex/ai/genai/base_genai.py +336 -0
- scitex/ai/genai/base_provider.py +291 -0
- scitex/ai/genai/calc_cost.py +78 -0
- scitex/ai/genai/chat_history.py +307 -0
- scitex/ai/genai/cost_tracker.py +276 -0
- scitex/ai/genai/deepseek.py +188 -0
- scitex/ai/genai/deepseek_provider.py +251 -0
- scitex/ai/genai/format_output_func.py +183 -0
- scitex/ai/genai/genai_factory.py +71 -0
- scitex/ai/genai/google.py +169 -0
- scitex/ai/genai/google_provider.py +228 -0
- scitex/ai/genai/groq.py +104 -0
- scitex/ai/genai/groq_provider.py +248 -0
- scitex/ai/genai/image_processor.py +250 -0
- scitex/ai/genai/llama.py +155 -0
- scitex/ai/genai/llama_provider.py +214 -0
- scitex/ai/genai/mock_provider.py +127 -0
- scitex/ai/genai/model_registry.py +304 -0
- scitex/ai/genai/openai.py +230 -0
- scitex/ai/genai/openai_provider.py +293 -0
- scitex/ai/genai/params.py +565 -0
- scitex/ai/genai/perplexity.py +202 -0
- scitex/ai/genai/perplexity_provider.py +205 -0
- scitex/ai/genai/provider_base.py +302 -0
- scitex/ai/genai/provider_factory.py +370 -0
- scitex/ai/genai/response_handler.py +235 -0
- scitex/ai/layer/_Pass.py +21 -0
- scitex/ai/layer/__init__.py +10 -0
- scitex/ai/layer/_switch.py +8 -0
- scitex/ai/loss/_L1L2Losses.py +34 -0
- scitex/ai/loss/__init__.py +12 -0
- scitex/ai/loss/multi_task_loss.py +47 -0
- scitex/ai/metrics/__init__.py +9 -0
- scitex/ai/metrics/_bACC.py +51 -0
- scitex/ai/metrics/silhoute_score_block.py +496 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
- scitex/ai/optim/__init__.py +13 -0
- scitex/ai/optim/_get_set.py +31 -0
- scitex/ai/optim/_optimizers.py +71 -0
- scitex/ai/plt/__init__.py +21 -0
- scitex/ai/plt/_conf_mat.py +592 -0
- scitex/ai/plt/_learning_curve.py +194 -0
- scitex/ai/plt/_optuna_study.py +111 -0
- scitex/ai/plt/aucs/__init__.py +2 -0
- scitex/ai/plt/aucs/example.py +60 -0
- scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
- scitex/ai/plt/aucs/roc_auc.py +246 -0
- scitex/ai/sampling/undersample.py +29 -0
- scitex/ai/sk/__init__.py +11 -0
- scitex/ai/sk/_clf.py +58 -0
- scitex/ai/sk/_to_sktime.py +100 -0
- scitex/ai/sklearn/__init__.py +26 -0
- scitex/ai/sklearn/clf.py +58 -0
- scitex/ai/sklearn/to_sktime.py +100 -0
- scitex/ai/training/__init__.py +7 -0
- scitex/ai/training/early_stopping.py +150 -0
- scitex/ai/training/learning_curve_logger.py +555 -0
- scitex/ai/utils/__init__.py +22 -0
- scitex/ai/utils/_check_params.py +50 -0
- scitex/ai/utils/_default_dataset.py +46 -0
- scitex/ai/utils/_format_samples_for_sktime.py +26 -0
- scitex/ai/utils/_label_encoder.py +134 -0
- scitex/ai/utils/_merge_labels.py +22 -0
- scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
- scitex/ai/utils/_under_sample.py +51 -0
- scitex/ai/utils/_verify_n_gpus.py +16 -0
- scitex/ai/utils/grid_search.py +148 -0
- scitex/context/__init__.py +9 -0
- scitex/context/_suppress_output.py +38 -0
- scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
- scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
- scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
- scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
- scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
- scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
- scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
- scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
- scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
- scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
- scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
- scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
- scitex/db/_BaseMixins/__init__.py +30 -0
- scitex/db/_PostgreSQL.py +126 -0
- scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
- scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
- scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
- scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
- scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
- scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
- scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
- scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
- scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
- scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
- scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
- scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
- scitex/db/_PostgreSQLMixins/__init__.py +34 -0
- scitex/db/_SQLite3.py +2136 -0
- scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
- scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
- scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
- scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
- scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
- scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
- scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
- scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
- scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
- scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
- scitex/db/_SQLite3Mixins/__init__.py +30 -0
- scitex/db/__init__.py +14 -0
- scitex/db/_delete_duplicates.py +397 -0
- scitex/db/_inspect.py +163 -0
- scitex/decorators/__init__.py +54 -0
- scitex/decorators/_auto_order.py +172 -0
- scitex/decorators/_batch_fn.py +127 -0
- scitex/decorators/_cache_disk.py +32 -0
- scitex/decorators/_cache_mem.py +12 -0
- scitex/decorators/_combined.py +98 -0
- scitex/decorators/_converters.py +282 -0
- scitex/decorators/_deprecated.py +26 -0
- scitex/decorators/_not_implemented.py +30 -0
- scitex/decorators/_numpy_fn.py +86 -0
- scitex/decorators/_pandas_fn.py +121 -0
- scitex/decorators/_preserve_doc.py +19 -0
- scitex/decorators/_signal_fn.py +95 -0
- scitex/decorators/_timeout.py +55 -0
- scitex/decorators/_torch_fn.py +136 -0
- scitex/decorators/_wrap.py +39 -0
- scitex/decorators/_xarray_fn.py +88 -0
- scitex/dev/__init__.py +15 -0
- scitex/dev/_analyze_code_flow.py +284 -0
- scitex/dev/_reload.py +59 -0
- scitex/dict/_DotDict.py +442 -0
- scitex/dict/__init__.py +18 -0
- scitex/dict/_listed_dict.py +42 -0
- scitex/dict/_pop_keys.py +36 -0
- scitex/dict/_replace.py +13 -0
- scitex/dict/_safe_merge.py +62 -0
- scitex/dict/_to_str.py +32 -0
- scitex/dsp/__init__.py +72 -0
- scitex/dsp/_crop.py +122 -0
- scitex/dsp/_demo_sig.py +331 -0
- scitex/dsp/_detect_ripples.py +212 -0
- scitex/dsp/_ensure_3d.py +18 -0
- scitex/dsp/_hilbert.py +78 -0
- scitex/dsp/_listen.py +702 -0
- scitex/dsp/_misc.py +30 -0
- scitex/dsp/_mne.py +32 -0
- scitex/dsp/_modulation_index.py +79 -0
- scitex/dsp/_pac.py +319 -0
- scitex/dsp/_psd.py +102 -0
- scitex/dsp/_resample.py +65 -0
- scitex/dsp/_time.py +36 -0
- scitex/dsp/_transform.py +68 -0
- scitex/dsp/_wavelet.py +212 -0
- scitex/dsp/add_noise.py +111 -0
- scitex/dsp/example.py +253 -0
- scitex/dsp/filt.py +155 -0
- scitex/dsp/norm.py +18 -0
- scitex/dsp/params.py +51 -0
- scitex/dsp/reference.py +43 -0
- scitex/dsp/template.py +25 -0
- scitex/dsp/utils/__init__.py +15 -0
- scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
- scitex/dsp/utils/_ensure_3d.py +18 -0
- scitex/dsp/utils/_ensure_even_len.py +10 -0
- scitex/dsp/utils/_zero_pad.py +48 -0
- scitex/dsp/utils/filter.py +408 -0
- scitex/dsp/utils/pac.py +177 -0
- scitex/dt/__init__.py +8 -0
- scitex/dt/_linspace.py +130 -0
- scitex/etc/__init__.py +15 -0
- scitex/etc/wait_key.py +34 -0
- scitex/gen/_DimHandler.py +196 -0
- scitex/gen/_TimeStamper.py +244 -0
- scitex/gen/__init__.py +95 -0
- scitex/gen/_alternate_kwarg.py +13 -0
- scitex/gen/_cache.py +11 -0
- scitex/gen/_check_host.py +34 -0
- scitex/gen/_ci.py +12 -0
- scitex/gen/_close.py +222 -0
- scitex/gen/_embed.py +78 -0
- scitex/gen/_inspect_module.py +257 -0
- scitex/gen/_is_ipython.py +12 -0
- scitex/gen/_less.py +48 -0
- scitex/gen/_list_packages.py +139 -0
- scitex/gen/_mat2py.py +88 -0
- scitex/gen/_norm.py +170 -0
- scitex/gen/_paste.py +18 -0
- scitex/gen/_print_config.py +84 -0
- scitex/gen/_shell.py +48 -0
- scitex/gen/_src.py +111 -0
- scitex/gen/_start.py +451 -0
- scitex/gen/_symlink.py +55 -0
- scitex/gen/_symlog.py +27 -0
- scitex/gen/_tee.py +238 -0
- scitex/gen/_title2path.py +60 -0
- scitex/gen/_title_case.py +88 -0
- scitex/gen/_to_even.py +84 -0
- scitex/gen/_to_odd.py +34 -0
- scitex/gen/_to_rank.py +39 -0
- scitex/gen/_transpose.py +37 -0
- scitex/gen/_type.py +78 -0
- scitex/gen/_var_info.py +73 -0
- scitex/gen/_wrap.py +17 -0
- scitex/gen/_xml2dict.py +76 -0
- scitex/gen/misc.py +730 -0
- scitex/gen/path.py +0 -0
- scitex/general/__init__.py +5 -0
- scitex/gists/_SigMacro_processFigure_S.py +128 -0
- scitex/gists/_SigMacro_toBlue.py +172 -0
- scitex/gists/__init__.py +12 -0
- scitex/io/_H5Explorer.py +292 -0
- scitex/io/__init__.py +82 -0
- scitex/io/_cache.py +101 -0
- scitex/io/_flush.py +24 -0
- scitex/io/_glob.py +103 -0
- scitex/io/_json2md.py +113 -0
- scitex/io/_load.py +168 -0
- scitex/io/_load_configs.py +146 -0
- scitex/io/_load_modules/__init__.py +38 -0
- scitex/io/_load_modules/_catboost.py +66 -0
- scitex/io/_load_modules/_con.py +20 -0
- scitex/io/_load_modules/_db.py +24 -0
- scitex/io/_load_modules/_docx.py +42 -0
- scitex/io/_load_modules/_eeg.py +110 -0
- scitex/io/_load_modules/_hdf5.py +196 -0
- scitex/io/_load_modules/_image.py +19 -0
- scitex/io/_load_modules/_joblib.py +19 -0
- scitex/io/_load_modules/_json.py +18 -0
- scitex/io/_load_modules/_markdown.py +103 -0
- scitex/io/_load_modules/_matlab.py +37 -0
- scitex/io/_load_modules/_numpy.py +39 -0
- scitex/io/_load_modules/_optuna.py +155 -0
- scitex/io/_load_modules/_pandas.py +69 -0
- scitex/io/_load_modules/_pdf.py +31 -0
- scitex/io/_load_modules/_pickle.py +24 -0
- scitex/io/_load_modules/_torch.py +16 -0
- scitex/io/_load_modules/_txt.py +126 -0
- scitex/io/_load_modules/_xml.py +49 -0
- scitex/io/_load_modules/_yaml.py +23 -0
- scitex/io/_mv_to_tmp.py +19 -0
- scitex/io/_path.py +286 -0
- scitex/io/_reload.py +78 -0
- scitex/io/_save.py +539 -0
- scitex/io/_save_modules/__init__.py +66 -0
- scitex/io/_save_modules/_catboost.py +22 -0
- scitex/io/_save_modules/_csv.py +89 -0
- scitex/io/_save_modules/_excel.py +49 -0
- scitex/io/_save_modules/_hdf5.py +249 -0
- scitex/io/_save_modules/_html.py +48 -0
- scitex/io/_save_modules/_image.py +140 -0
- scitex/io/_save_modules/_joblib.py +25 -0
- scitex/io/_save_modules/_json.py +25 -0
- scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
- scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
- scitex/io/_save_modules/_matlab.py +24 -0
- scitex/io/_save_modules/_mp4.py +29 -0
- scitex/io/_save_modules/_numpy.py +57 -0
- scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
- scitex/io/_save_modules/_pickle.py +45 -0
- scitex/io/_save_modules/_plotly.py +27 -0
- scitex/io/_save_modules/_text.py +23 -0
- scitex/io/_save_modules/_torch.py +26 -0
- scitex/io/_save_modules/_yaml.py +29 -0
- scitex/life/__init__.py +10 -0
- scitex/life/_monitor_rain.py +49 -0
- scitex/linalg/__init__.py +17 -0
- scitex/linalg/_distance.py +63 -0
- scitex/linalg/_geometric_median.py +64 -0
- scitex/linalg/_misc.py +73 -0
- scitex/nn/_AxiswiseDropout.py +27 -0
- scitex/nn/_BNet.py +126 -0
- scitex/nn/_BNet_Res.py +164 -0
- scitex/nn/_ChannelGainChanger.py +44 -0
- scitex/nn/_DropoutChannels.py +50 -0
- scitex/nn/_Filters.py +489 -0
- scitex/nn/_FreqGainChanger.py +110 -0
- scitex/nn/_GaussianFilter.py +48 -0
- scitex/nn/_Hilbert.py +111 -0
- scitex/nn/_MNet_1000.py +157 -0
- scitex/nn/_ModulationIndex.py +221 -0
- scitex/nn/_PAC.py +414 -0
- scitex/nn/_PSD.py +40 -0
- scitex/nn/_ResNet1D.py +120 -0
- scitex/nn/_SpatialAttention.py +25 -0
- scitex/nn/_Spectrogram.py +161 -0
- scitex/nn/_SwapChannels.py +50 -0
- scitex/nn/_TransposeLayer.py +19 -0
- scitex/nn/_Wavelet.py +183 -0
- scitex/nn/__init__.py +63 -0
- scitex/os/__init__.py +8 -0
- scitex/os/_mv.py +50 -0
- scitex/parallel/__init__.py +8 -0
- scitex/parallel/_run.py +151 -0
- scitex/path/__init__.py +33 -0
- scitex/path/_clean.py +52 -0
- scitex/path/_find.py +108 -0
- scitex/path/_get_module_path.py +51 -0
- scitex/path/_get_spath.py +35 -0
- scitex/path/_getsize.py +18 -0
- scitex/path/_increment_version.py +87 -0
- scitex/path/_mk_spath.py +51 -0
- scitex/path/_path.py +19 -0
- scitex/path/_split.py +23 -0
- scitex/path/_this_path.py +19 -0
- scitex/path/_version.py +101 -0
- scitex/pd/__init__.py +41 -0
- scitex/pd/_find_indi.py +126 -0
- scitex/pd/_find_pval.py +113 -0
- scitex/pd/_force_df.py +154 -0
- scitex/pd/_from_xyz.py +71 -0
- scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
- scitex/pd/_melt_cols.py +81 -0
- scitex/pd/_merge_columns.py +221 -0
- scitex/pd/_mv.py +63 -0
- scitex/pd/_replace.py +62 -0
- scitex/pd/_round.py +93 -0
- scitex/pd/_slice.py +63 -0
- scitex/pd/_sort.py +91 -0
- scitex/pd/_to_numeric.py +53 -0
- scitex/pd/_to_xy.py +59 -0
- scitex/pd/_to_xyz.py +110 -0
- scitex/plt/__init__.py +36 -0
- scitex/plt/_subplots/_AxesWrapper.py +182 -0
- scitex/plt/_subplots/_AxisWrapper.py +249 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
- scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
- scitex/plt/_subplots/_FigWrapper.py +226 -0
- scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
- scitex/plt/_subplots/__init__.py +111 -0
- scitex/plt/_subplots/_export_as_csv.py +232 -0
- scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
- scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
- scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
- scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
- scitex/plt/_tpl.py +28 -0
- scitex/plt/ax/__init__.py +114 -0
- scitex/plt/ax/_plot/__init__.py +53 -0
- scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
- scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
- scitex/plt/ax/_plot/_plot_cube.py +57 -0
- scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
- scitex/plt/ax/_plot/_plot_fillv.py +55 -0
- scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
- scitex/plt/ax/_plot/_plot_image.py +94 -0
- scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
- scitex/plt/ax/_plot/_plot_raster.py +172 -0
- scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
- scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
- scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
- scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
- scitex/plt/ax/_plot/_plot_violin.py +343 -0
- scitex/plt/ax/_style/__init__.py +38 -0
- scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
- scitex/plt/ax/_style/_add_panel.py +92 -0
- scitex/plt/ax/_style/_extend.py +64 -0
- scitex/plt/ax/_style/_force_aspect.py +37 -0
- scitex/plt/ax/_style/_format_label.py +23 -0
- scitex/plt/ax/_style/_hide_spines.py +84 -0
- scitex/plt/ax/_style/_map_ticks.py +182 -0
- scitex/plt/ax/_style/_rotate_labels.py +215 -0
- scitex/plt/ax/_style/_sci_note.py +279 -0
- scitex/plt/ax/_style/_set_log_scale.py +299 -0
- scitex/plt/ax/_style/_set_meta.py +261 -0
- scitex/plt/ax/_style/_set_n_ticks.py +37 -0
- scitex/plt/ax/_style/_set_size.py +16 -0
- scitex/plt/ax/_style/_set_supxyt.py +116 -0
- scitex/plt/ax/_style/_set_ticks.py +276 -0
- scitex/plt/ax/_style/_set_xyt.py +121 -0
- scitex/plt/ax/_style/_share_axes.py +264 -0
- scitex/plt/ax/_style/_shift.py +139 -0
- scitex/plt/ax/_style/_show_spines.py +333 -0
- scitex/plt/color/_PARAMS.py +70 -0
- scitex/plt/color/__init__.py +52 -0
- scitex/plt/color/_add_hue_col.py +41 -0
- scitex/plt/color/_colors.py +205 -0
- scitex/plt/color/_get_colors_from_cmap.py +134 -0
- scitex/plt/color/_interpolate.py +29 -0
- scitex/plt/color/_vizualize_colors.py +54 -0
- scitex/plt/utils/__init__.py +44 -0
- scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
- scitex/plt/utils/_calc_nice_ticks.py +101 -0
- scitex/plt/utils/_close.py +68 -0
- scitex/plt/utils/_colorbar.py +96 -0
- scitex/plt/utils/_configure_mpl.py +295 -0
- scitex/plt/utils/_histogram_utils.py +132 -0
- scitex/plt/utils/_im2grid.py +70 -0
- scitex/plt/utils/_is_valid_axis.py +78 -0
- scitex/plt/utils/_mk_colorbar.py +65 -0
- scitex/plt/utils/_mk_patches.py +26 -0
- scitex/plt/utils/_scientific_captions.py +638 -0
- scitex/plt/utils/_scitex_config.py +223 -0
- scitex/reproduce/__init__.py +14 -0
- scitex/reproduce/_fix_seeds.py +45 -0
- scitex/reproduce/_gen_ID.py +55 -0
- scitex/reproduce/_gen_timestamp.py +35 -0
- scitex/res/__init__.py +5 -0
- scitex/resource/__init__.py +13 -0
- scitex/resource/_get_processor_usages.py +281 -0
- scitex/resource/_get_specs.py +280 -0
- scitex/resource/_log_processor_usages.py +190 -0
- scitex/resource/_utils/__init__.py +31 -0
- scitex/resource/_utils/_get_env_info.py +481 -0
- scitex/resource/limit_ram.py +33 -0
- scitex/scholar/__init__.py +24 -0
- scitex/scholar/_local_search.py +454 -0
- scitex/scholar/_paper.py +244 -0
- scitex/scholar/_pdf_downloader.py +325 -0
- scitex/scholar/_search.py +393 -0
- scitex/scholar/_vector_search.py +370 -0
- scitex/scholar/_web_sources.py +457 -0
- scitex/stats/__init__.py +31 -0
- scitex/stats/_calc_partial_corr.py +17 -0
- scitex/stats/_corr_test_multi.py +94 -0
- scitex/stats/_corr_test_wrapper.py +115 -0
- scitex/stats/_describe_wrapper.py +90 -0
- scitex/stats/_multiple_corrections.py +63 -0
- scitex/stats/_nan_stats.py +93 -0
- scitex/stats/_p2stars.py +116 -0
- scitex/stats/_p2stars_wrapper.py +56 -0
- scitex/stats/_statistical_tests.py +73 -0
- scitex/stats/desc/__init__.py +40 -0
- scitex/stats/desc/_describe.py +189 -0
- scitex/stats/desc/_nan.py +289 -0
- scitex/stats/desc/_real.py +94 -0
- scitex/stats/multiple/__init__.py +14 -0
- scitex/stats/multiple/_bonferroni_correction.py +72 -0
- scitex/stats/multiple/_fdr_correction.py +400 -0
- scitex/stats/multiple/_multicompair.py +28 -0
- scitex/stats/tests/__corr_test.py +277 -0
- scitex/stats/tests/__corr_test_multi.py +343 -0
- scitex/stats/tests/__corr_test_single.py +277 -0
- scitex/stats/tests/__init__.py +22 -0
- scitex/stats/tests/_brunner_munzel_test.py +192 -0
- scitex/stats/tests/_nocorrelation_test.py +28 -0
- scitex/stats/tests/_smirnov_grubbs.py +98 -0
- scitex/str/__init__.py +113 -0
- scitex/str/_clean_path.py +75 -0
- scitex/str/_color_text.py +52 -0
- scitex/str/_decapitalize.py +58 -0
- scitex/str/_factor_out_digits.py +281 -0
- scitex/str/_format_plot_text.py +498 -0
- scitex/str/_grep.py +48 -0
- scitex/str/_latex.py +155 -0
- scitex/str/_latex_fallback.py +471 -0
- scitex/str/_mask_api.py +39 -0
- scitex/str/_mask_api_key.py +8 -0
- scitex/str/_parse.py +158 -0
- scitex/str/_print_block.py +47 -0
- scitex/str/_print_debug.py +68 -0
- scitex/str/_printc.py +62 -0
- scitex/str/_readable_bytes.py +38 -0
- scitex/str/_remove_ansi.py +23 -0
- scitex/str/_replace.py +134 -0
- scitex/str/_search.py +125 -0
- scitex/str/_squeeze_space.py +36 -0
- scitex/tex/__init__.py +10 -0
- scitex/tex/_preview.py +103 -0
- scitex/tex/_to_vec.py +116 -0
- scitex/torch/__init__.py +18 -0
- scitex/torch/_apply_to.py +34 -0
- scitex/torch/_nan_funcs.py +77 -0
- scitex/types/_ArrayLike.py +44 -0
- scitex/types/_ColorLike.py +21 -0
- scitex/types/__init__.py +14 -0
- scitex/types/_is_listed_X.py +70 -0
- scitex/utils/__init__.py +22 -0
- scitex/utils/_compress_hdf5.py +116 -0
- scitex/utils/_email.py +120 -0
- scitex/utils/_grid.py +148 -0
- scitex/utils/_notify.py +247 -0
- scitex/utils/_search.py +121 -0
- scitex/web/__init__.py +38 -0
- scitex/web/_search_pubmed.py +438 -0
- scitex/web/_summarize_url.py +158 -0
- scitex-2.0.0.dist-info/METADATA +307 -0
- scitex-2.0.0.dist-info/RECORD +572 -0
- scitex-2.0.0.dist-info/WHEEL +6 -0
- scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
- scitex-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from itertools import cycle
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from sklearn.metrics import roc_auc_score, roc_curve
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
import scitex
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def interpolate_roc_data_points(df):
    """Resample an ROC curve onto a fixed grid of 1001 FPR values.

    Each grid point ``x`` in ``[0, 1]`` (step 0.001) receives the TPR and the
    threshold of the ROC vertex that *starts* the segment containing ``x``
    (step / "previous value" interpolation). When ``x`` lies on the boundary
    of several consecutive segments, the last such segment wins. The first
    and last grid points are pinned to the first and last ROC vertices.

    Parameters
    ----------
    df : pandas.DataFrame
        ROC data with columns ``"fpr"``, ``"tpr"``, ``"threshold"`` and
        ``"roc_auc"``. Segment lookup assumes FPR is non-decreasing, as in
        the output of ``sklearn.metrics.roc_curve``. Must contain at least
        one row (an empty frame raises ``IndexError``).

    Returns
    -------
    pandas.DataFrame
        1001 rows with columns ``"x"`` (FPR grid), ``"y"`` (TPR),
        ``"threshold"`` and ``"roc_auc"`` (the first ``roc_auc`` value of
        ``df``, broadcast). Interior grid points not covered by any segment
        stay NaN.
    """
    grid = np.arange(1001) / 1000

    # Work on plain arrays. Row-wise ``df.iloc`` access is slow, and chained
    # assignment such as ``df_new["y"][mask] = v`` is unreliable in pandas:
    # it warns, and it is a silent no-op under Copy-on-Write.
    fpr = df["fpr"].to_numpy(dtype=float)
    tpr = df["tpr"].to_numpy(dtype=float)
    thresholds = df["threshold"].to_numpy(dtype=float)

    y = np.full(grid.shape, np.nan)
    thr = np.full(grid.shape, np.nan)
    for i in range(len(df) - 1):
        in_segment = (fpr[i] <= grid) & (grid <= fpr[i + 1])
        # Later segments intentionally overwrite shared boundary points.
        y[in_segment] = tpr[i]
        thr[in_segment] = thresholds[i]

    # Pin both ends of the grid to the first and last ROC vertices.
    y[0], y[-1] = tpr[0], tpr[-1]
    thr[0], thr[-1] = thresholds[0], thresholds[-1]

    df_new = pd.DataFrame({"x": grid, "y": y, "threshold": thr})
    df_new["roc_auc"] = df["roc_auc"].iloc[0]
    return df_new
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def to_onehot(labels, n_classes):
    """Encode integer class indices as one-hot rows.

    Parameters
    ----------
    labels : int or array-like of int
        Class indices. They are used directly as NumPy row indices into an
        ``n_classes`` x ``n_classes`` identity matrix, so NumPy indexing rules
        apply (negative values wrap, values >= ``n_classes`` raise
        ``IndexError``).
    n_classes : int
        Number of classes, i.e. the length of each one-hot row.

    Returns
    -------
    numpy.ndarray
        Integer array holding one length-``n_classes`` row per label.
    """
    # Row k of the identity matrix is exactly the one-hot vector of class k,
    # so fancy indexing with ``labels`` selects one row per label.
    return np.identity(n_classes, dtype=int)[labels]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _calc_roc_per_class(true_class_onehot, pred_proba, labels):
    """Compute one-vs-rest ROC curves and ROC-AUC scores for every class.

    A class whose ROC curve or AUC cannot be computed (scikit-learn raises
    ``ValueError``, e.g. when only one of "class i" / "not class i" occurs in
    the targets) is NaN-filled and reported with a warning instead of
    aborting the whole computation.

    Returns
    -------
    tuple of dict
        ``(fpr, tpr, threshold, aucs)``, each keyed by class index.
    """
    fpr, tpr, threshold, aucs = {}, {}, {}, {}
    for i, label in enumerate(labels):
        y_true_i = true_class_onehot[:, i]
        y_score_i = pred_proba[:, i]
        try:
            fpr[i], tpr[i], threshold[i] = roc_curve(y_true_i, y_score_i)
            aucs[i] = roc_auc_score(y_true_i, y_score_i)
        except ValueError as e:
            warnings.warn(
                f'ROC-AUC for "{label}" was not defined and NaN-filled '
                f"(excluded from the macro avg.): {e}"
            )
            fpr[i], tpr[i], threshold[i] = [np.nan], [np.nan], [np.nan]
            aucs[i] = np.nan
    return fpr, tpr, threshold, aucs


def _save_roc_csvs(sdir_for_csv, labels, fpr, tpr, threshold, aucs):
    """Write one grid-interpolated ROC CSV per class via ``scitex.io.save``."""
    for i, label in enumerate(labels):
        # str() so non-string labels (e.g. ints) also yield valid file names.
        class_name = str(label).replace(" ", "_")
        n_points = len(fpr[i])
        df = pd.DataFrame(
            data={
                "fpr": fpr[i],
                "tpr": tpr[i],
                "threshold": threshold[i],
                "roc_auc": np.full(n_points, aucs[i]),
            },
            index=pd.Index(data=np.arange(n_points), name=class_name),
        )
        df = interpolate_roc_data_points(df)
        # Plain concatenation: ``sdir_for_csv`` is a path prefix, so a
        # directory must end with a separator.
        scitex.io.save(df, f"{sdir_for_csv}{class_name}.csv")


def _plot_roc_curves(plt, labels, fpr, tpr, aucs):
    """Draw the chance diagonal plus one ROC curve per class; return the figure."""
    colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])

    fig, ax = plt.subplots()
    ax.set_box_aspect(1)
    lines = []
    legends = []

    # Chance level (the diagonal line)
    (chance_line,) = ax.plot(
        np.linspace(0.01, 1),
        np.linspace(0.01, 1),
        color="gray",
        lw=2,
        linestyle="--",
        alpha=0.8,
    )
    lines.append(chance_line)
    legends.append("Chance")

    # One curve per class. Draw on ``ax`` explicitly: ``plt.plot`` targets
    # whatever axes is "current", which is not guaranteed to be ``ax``.
    for i, (label, color) in enumerate(zip(labels, colors)):
        (line,) = ax.plot(fpr[i], tpr[i], color=color, lw=2)
        lines.append(line)
        legends.append(f"{label} (AUC = {aucs[i]:0.2f})")

    fig.subplots_adjust(bottom=0.25)
    ax.set_xlim([-0.01, 1.01])
    ax.set_ylim([-0.01, 1.01])
    ax.set_xticks([0.0, 0.5, 1.0])
    ax.set_yticks([0.0, 0.5, 1.0])
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")
    ax.set_title("ROC Curve")
    ax.legend(lines, legends, loc="lower right")
    return fig


def roc_auc(plt, true_class, pred_proba, labels, sdir_for_csv=None):
    """Calculate one-vs-rest ROC curves and ROC-AUC scores, and plot them.

    Parameters
    ----------
    plt : module or object
        Plotting interface providing ``subplots()`` (e.g.
        ``matplotlib.pyplot``); curves are drawn on the axes it returns.
    true_class : array-like of int
        True class index of each sample (used to index one-hot rows).
    pred_proba : array-like
        Per-class scores/probabilities, indexed as ``[sample, class]``; one
        column per entry of ``labels``.
    labels : sequence
        Class names, one per column of ``pred_proba``.
    sdir_for_csv : str, optional
        If given, one grid-interpolated ROC CSV per class is saved to
        ``f"{sdir_for_csv}{class_name}.csv"`` (spaces in the class name are
        replaced by underscores). The prefix is concatenated as-is.

    Returns
    -------
    fig
        The figure created by ``plt.subplots()``.
    metrics : dict
        ``{"roc_auc": ..., "fpr": ..., "tpr": ..., "threshold": ...}``. Each
        value is a dict keyed by class index and ``"micro"``; ``"roc_auc"``
        additionally has ``"macro"``. Classes whose AUC is undefined are
        NaN-filled.
    """
    pred_proba = np.asarray(pred_proba)
    n_classes = len(labels)
    # One-vs-rest targets: column i is 1 where the sample belongs to class i.
    true_class_onehot = to_onehot(true_class, n_classes)

    fpr, tpr, threshold, aucs = _calc_roc_per_class(
        true_class_onehot, pred_proba, labels
    )

    # Micro average: pool every (sample, class) decision into one binary task.
    fpr["micro"], tpr["micro"], threshold["micro"] = roc_curve(
        true_class_onehot.ravel(), pred_proba.ravel()
    )
    aucs["micro"] = roc_auc_score(true_class_onehot, pred_proba, average="micro")

    # Macro average: unweighted mean of the per-class AUCs, ignoring NaNs.
    aucs["macro"] = np.nanmean([aucs[i] for i in range(n_classes)])

    if sdir_for_csv is not None:
        _save_roc_csvs(sdir_for_csv, labels, fpr, tpr, threshold, aucs)

    fig = _plot_roc_curves(plt, labels, fpr, tpr, aucs)

    metrics = dict(roc_auc=aucs, fpr=fpr, tpr=tpr, threshold=threshold)
    return fig, metrics
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import numpy as np
    from scipy.special import softmax
    from sklearn import datasets, svm
    from sklearn.model_selection import train_test_split

    def mk_demo_data(n_classes=2, batch_size=16):
        """Return random (labels, true_class, pred_proba, pred_class) for a smoke test.

        ``labels`` are the names ``"cls0" .. "cls{n_classes - 1}"``;
        ``true_class`` holds ``batch_size`` random integers in
        ``[0, n_classes)``; ``pred_proba`` is a ``(batch_size, n_classes)``
        array of softmax-normalised random scores and ``pred_class`` is its
        row-wise argmax.
        """
        labels = ["cls{}".format(i_cls) for i_cls in range(n_classes)]
        true_class = np.random.randint(0, n_classes, size=(batch_size,))
        pred_proba = softmax(np.random.rand(batch_size, n_classes), axis=-1)
        pred_class = np.argmax(pred_proba, axis=-1)
        return labels, true_class, pred_proba, pred_class

    ## Fix seed
    np.random.seed(42)

    ################################################################################
    ## A Minimal Example (random data; uncomment to run)
    ################################################################################
    # labels, true_class, pred_proba, pred_class = mk_demo_data(
    #     n_classes=10, batch_size=256
    # )
    # fig, metrics_dict = roc_auc(
    #     plt, true_class, pred_proba, labels, sdir_for_csv="./tmp/roc_demo/"
    # )

    ################################################################################
    ## Handwritten digits (sklearn.datasets.load_digits)
    ################################################################################
    digits = datasets.load_digits()

    # Flatten each image into a 1D feature vector: (n_samples, n_features)
    n_samples = len(digits.images)
    data = digits.images.reshape((n_samples, -1))

    # Create a classifier: a support vector classifier. probability=True is
    # required for predict_proba below.
    clf = svm.SVC(gamma=0.001, probability=True)

    # Split data into 50% train and 50% test subsets (no shuffling)
    X_train, X_test, y_train, y_test = train_test_split(
        data, digits.target, test_size=0.5, shuffle=False
    )

    # Learn the digits on the train subset
    clf.fit(X_train, y_train)

    # Per-class probabilities on the test subset: (n_test_samples, n_classes)
    predicted_proba = clf.predict_proba(X_test)

    n_classes = len(np.unique(digits.target))
    labels = ["Class {}".format(i) for i in range(n_classes)]

    ## Configures matplotlib
    plt.rcParams["font.size"] = 20
    plt.rcParams["legend.fontsize"] = "xx-small"
    plt.rcParams["figure.figsize"] = (16 * 1.2, 9 * 1.2)

    # Relabel every true "9" as "8" so that class 9 has no positive samples in
    # y_test -- presumably to exercise how roc_auc copes with a class that is
    # absent from the ground truth.
    y_test[y_test == 9] = 8

    ## Main
    fig, metrics_dict = roc_auc(
        plt, y_test, predicted_proba, labels, sdir_for_csv="./tmp/roc_test/"
    )

    fig.show()

    print(metrics_dict.keys())
    # dict_keys(['roc_auc', 'fpr', 'tpr', 'threshold'])
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-24 10:13:17 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/ai/sampling/undersample.py
|
|
5
|
+
|
|
6
|
+
THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/sampling/undersample.py"
|
|
7
|
+
|
|
8
|
+
from typing import Tuple
|
|
9
|
+
from ...types import ArrayLike
|
|
10
|
+
from imblearn.under_sampling import RandomUnderSampler
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def undersample(
    X: ArrayLike, y: ArrayLike, random_state: int = 42
) -> Tuple[ArrayLike, ArrayLike]:
    """Balance class counts by random undersampling with imblearn.

    Thin wrapper around ``imblearn.under_sampling.RandomUnderSampler`` using
    its default sampling strategy; only ``random_state`` is configurable.

    Args:
        X: Features array-like of shape (n_samples, n_features).
        y: Labels array-like of shape (n_samples,).
        random_state: Seed forwarded to ``RandomUnderSampler`` so that the
            choice of retained samples is reproducible.

    Returns:
        The resampled ``(X, y)`` pair exactly as returned by
        ``RandomUnderSampler.fit_resample``. NOTE(review): preserving the
        input container type (e.g. pandas) relies on imblearn's own output
        handling -- confirm for the installed version.
    """
    sampler = RandomUnderSampler(random_state=random_state)
    features_kept, labels_kept = sampler.fit_resample(X, y)
    return features_kept, labels_kept
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# EOF
|
scitex/ai/sk/__init__.py
ADDED
scitex/ai/sk/_clf.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-03-23 17:36:05 (ywatanabe)"
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from sklearn.decomposition import PCA
|
|
7
|
+
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
|
|
8
|
+
from sklearn.feature_selection import SelectKBest, f_classif
|
|
9
|
+
from sklearn.linear_model import LogisticRegression, RidgeClassifierCV
|
|
10
|
+
from sklearn.pipeline import make_pipeline
|
|
11
|
+
from sklearn.svm import SVC, LinearSVC
|
|
12
|
+
from sktime.classification.deep_learning.cnn import CNNClassifier
|
|
13
|
+
from sktime.classification.deep_learning.inceptiontime import (
|
|
14
|
+
InceptionTimeClassifier,
|
|
15
|
+
)
|
|
16
|
+
from sktime.classification.deep_learning.lstmfcn import LSTMFCNClassifier
|
|
17
|
+
from sktime.classification.dummy import DummyClassifier
|
|
18
|
+
from sktime.classification.feature_based import TSFreshClassifier
|
|
19
|
+
from sktime.classification.hybrid import HIVECOTEV2
|
|
20
|
+
from sktime.classification.interval_based import TimeSeriesForestClassifier
|
|
21
|
+
from sktime.classification.kernel_based import RocketClassifier, TimeSeriesSVC
|
|
22
|
+
from sktime.transformations.panel.reduce import Tabularizer
|
|
23
|
+
from sktime.transformations.panel.rocket import Rocket
|
|
24
|
+
|
|
25
|
+
# _rocket_pipeline = make_pipeline(
|
|
26
|
+
# Rocket(n_jobs=-1),
|
|
27
|
+
# RidgeClassifierCV(alphas=np.logspace(-3, 3, 10)),
|
|
28
|
+
# )
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# def rocket_pipeline(*args, **kwargs):
|
|
32
|
+
# return _rocket_pipeline
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def rocket_pipeline(*args, **kwargs):
    """Return a new, unfitted ROCKET + logistic-regression pipeline.

    Every positional and keyword argument is forwarded untouched to sktime's
    ``Rocket`` transformer; the downstream classifier is always
    ``LogisticRegression(max_iter=1000)``. A fresh pipeline is built on each
    call, so callers never share fitted state.
    """
    rocket = Rocket(*args, **kwargs)
    # max_iter is raised above LogisticRegression's default to give the
    # solver room to converge; increase further if convergence warnings appear.
    head = LogisticRegression(max_iter=1000)
    return make_pipeline(rocket, head)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# def rocket_pipeline(*args, **kwargs):
|
|
47
|
+
# return make_pipeline(
|
|
48
|
+
# Rocket(*args, **kwargs),
|
|
49
|
+
# SelectKBest(f_classif, k=500),
|
|
50
|
+
# PCA(n_components=100),
|
|
51
|
+
# LinearSVC(dual=False, tol=1e-3, C=0.1, probability=True),
|
|
52
|
+
# )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Gradient-boosting baseline: sktime's Tabularizer turns the time-series panel
# into tabular data that sklearn's GradientBoostingClassifier can consume.
# NOTE(review): this is a single, shared module-level Pipeline instance --
# fitting it mutates the object every importer sees; use
# sklearn.base.clone(GB_pipeline) to obtain an independent, unfitted copy.
GB_pipeline = make_pipeline(Tabularizer(), GradientBoostingClassifier())
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-03-05 13:17:04 (ywatanabe)"
|
|
4
|
+
|
|
5
|
+
# import warnings
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def to_sktime_df(X):
    """
    Convert a 3D time-series dataset into sktime's nested pandas.DataFrame format.

    Each row of the result is one sample and each column one channel. Every
    cell holds a ``pandas.Series`` with that channel's ``seq_len`` values
    (cast to float64), named ``"dim_{channel}"``. Columns are labelled with
    the integer channel indices ``0 .. n_chs - 1``.

    Arguments:
        - X (numpy.ndarray or torch.Tensor or pandas.DataFrame): The input
          dataset with shape (n_samples, n_chs, seq_len).

    Return:
        - sktime_df (pandas.DataFrame): A DataFrame of shape (n_samples, n_chs)
          whose cells are pandas Series, each a univariate time series.

    Raises:
        - ValueError: If X is not one of the supported types, or if an
          array/tensor input is not 3-dimensional.

    Data Types and Shapes:
        - numpy.ndarray: must have shape (n_samples, n_chs, seq_len).
        - torch.Tensor: must have shape (n_samples, n_chs, seq_len). It is
          detached from the autograd graph and moved to CPU first, so tensors
          that require grad or live on a GPU are accepted.
        - pandas.DataFrame: assumed to already be in sktime format and returned
          as is (no copy, no validation).

    References:
        - sktime: https://github.com/alan-turing-institute/sktime

    Examples:
    --------
    >>> X_np = np.random.rand(4, 3, 100)
    >>> sktime_df = to_sktime_df(X_np)
    >>> type(sktime_df)
    <class 'pandas.core.frame.DataFrame'>
    >>> sktime_df.shape
    (4, 3)
    """
    if isinstance(X, pd.DataFrame):
        return X
    if torch.is_tensor(X):
        # .detach(): numpy() refuses tensors that require grad.
        # .cpu(): no-op on CPU, required for CUDA tensors.
        X = X.detach().cpu().numpy()
    elif not isinstance(X, np.ndarray):
        raise ValueError(
            "Input X must be a numpy.ndarray, torch.Tensor, or pandas.DataFrame"
        )

    # Without this check a 2D input would silently yield length-1 Series.
    if X.ndim != 3:
        raise ValueError(
            "Input X must have shape (n_samples, n_chs, seq_len); "
            f"got shape {X.shape}"
        )

    X = X.astype(np.float64)

    def _format_a_sample_for_sktime(x):
        """Wrap one (n_chs, seq_len) sample as a Series of per-channel Series."""
        return pd.Series(
            [pd.Series(x[d], name=f"dim_{d}") for d in range(x.shape[0])]
        )

    return pd.DataFrame([_format_a_sample_for_sktime(sample) for sample in X])
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# # Obsolete warning for future compatibility
|
|
68
|
+
# def to_sktime(*args, **kwargs):
|
|
69
|
+
# warnings.warn(
|
|
70
|
+
# "to_sktime is deprecated; use to_sktime_df instead.", FutureWarning
|
|
71
|
+
# )
|
|
72
|
+
# return to_sktime_df(*args, **kwargs)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# import pandas as pd
|
|
76
|
+
# import numpy as np
|
|
77
|
+
# import torch
|
|
78
|
+
|
|
79
|
+
# def to_sktime(X):
|
|
80
|
+
# """
|
|
81
|
+
# X.shape: (n_samples, n_chs, seq_len)
|
|
82
|
+
# """
|
|
83
|
+
|
|
84
|
+
# def _format_a_sample_for_sktime(x):
|
|
85
|
+
# """
|
|
86
|
+
# x.shape: (n_chs, seq_len)
|
|
87
|
+
# """
|
|
88
|
+
# dims = pd.Series(
|
|
89
|
+
# [pd.Series(x[d], name=f"dim_{d}") for d in range(len(x))],
|
|
90
|
+
# index=[f"dim_{i}" for i in np.arange(len(x))],
|
|
91
|
+
# )
|
|
92
|
+
# return dims
|
|
93
|
+
|
|
94
|
+
# if torch.is_tensor(X):
|
|
95
|
+
# X = X.numpy()
|
|
96
|
+
# X = X.astype(np.float64)
|
|
97
|
+
|
|
98
|
+
# return pd.DataFrame(
|
|
99
|
+
# [_format_a_sample_for_sktime(X[i]) for i in range(len(X))]
|
|
100
|
+
# )
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Sklearn wrappers and utilities."""
|
|
3
|
+
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
from .clf import *
|
|
8
|
+
except ImportError as e:
|
|
9
|
+
warnings.warn(
|
|
10
|
+
f"Could not import clf from scitex.ai.sklearn: {str(e)}. "
|
|
11
|
+
f"Some functionality may be unavailable. "
|
|
12
|
+
f"Consider installing missing dependencies if you need this module.",
|
|
13
|
+
ImportWarning,
|
|
14
|
+
stacklevel=2
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
from .to_sktime import *
|
|
19
|
+
except ImportError as e:
|
|
20
|
+
warnings.warn(
|
|
21
|
+
f"Could not import to_sktime from scitex.ai.sklearn: {str(e)}. "
|
|
22
|
+
f"Some functionality may be unavailable. "
|
|
23
|
+
f"Consider installing missing dependencies if you need this module.",
|
|
24
|
+
ImportWarning,
|
|
25
|
+
stacklevel=2
|
|
26
|
+
)
|
scitex/ai/sklearn/clf.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-03-23 17:36:05 (ywatanabe)"
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from sklearn.decomposition import PCA
|
|
7
|
+
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
|
|
8
|
+
from sklearn.feature_selection import SelectKBest, f_classif
|
|
9
|
+
from sklearn.linear_model import LogisticRegression, RidgeClassifierCV
|
|
10
|
+
from sklearn.pipeline import make_pipeline
|
|
11
|
+
from sklearn.svm import SVC, LinearSVC
|
|
12
|
+
from sktime.classification.deep_learning.cnn import CNNClassifier
|
|
13
|
+
from sktime.classification.deep_learning.inceptiontime import (
|
|
14
|
+
InceptionTimeClassifier,
|
|
15
|
+
)
|
|
16
|
+
from sktime.classification.deep_learning.lstmfcn import LSTMFCNClassifier
|
|
17
|
+
from sktime.classification.dummy import DummyClassifier
|
|
18
|
+
from sktime.classification.feature_based import TSFreshClassifier
|
|
19
|
+
from sktime.classification.hybrid import HIVECOTEV2
|
|
20
|
+
from sktime.classification.interval_based import TimeSeriesForestClassifier
|
|
21
|
+
from sktime.classification.kernel_based import RocketClassifier, TimeSeriesSVC
|
|
22
|
+
from sktime.transformations.panel.reduce import Tabularizer
|
|
23
|
+
from sktime.transformations.panel.rocket import Rocket
|
|
24
|
+
|
|
25
|
+
# _rocket_pipeline = make_pipeline(
|
|
26
|
+
# Rocket(n_jobs=-1),
|
|
27
|
+
# RidgeClassifierCV(alphas=np.logspace(-3, 3, 10)),
|
|
28
|
+
# )
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# def rocket_pipeline(*args, **kwargs):
|
|
32
|
+
# return _rocket_pipeline
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def rocket_pipeline(*args, **kwargs):
    """Return a new, unfitted ROCKET + logistic-regression pipeline.

    Every positional and keyword argument is forwarded untouched to sktime's
    ``Rocket`` transformer; the downstream classifier is always
    ``LogisticRegression(max_iter=1000)``. A fresh pipeline is built on each
    call, so callers never share fitted state.
    """
    rocket = Rocket(*args, **kwargs)
    # max_iter is raised above LogisticRegression's default to give the
    # solver room to converge; increase further if convergence warnings appear.
    head = LogisticRegression(max_iter=1000)
    return make_pipeline(rocket, head)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# def rocket_pipeline(*args, **kwargs):
|
|
47
|
+
# return make_pipeline(
|
|
48
|
+
# Rocket(*args, **kwargs),
|
|
49
|
+
# SelectKBest(f_classif, k=500),
|
|
50
|
+
# PCA(n_components=100),
|
|
51
|
+
# LinearSVC(dual=False, tol=1e-3, C=0.1, probability=True),
|
|
52
|
+
# )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Gradient-boosting baseline: sktime's Tabularizer turns the time-series panel
# into tabular data that sklearn's GradientBoostingClassifier can consume.
# NOTE(review): this is a single, shared module-level Pipeline instance --
# fitting it mutates the object every importer sees; use
# sklearn.base.clone(GB_pipeline) to obtain an independent, unfitted copy.
GB_pipeline = make_pipeline(Tabularizer(), GradientBoostingClassifier())
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-03-05 13:17:04 (ywatanabe)"
|
|
4
|
+
|
|
5
|
+
# import warnings
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def to_sktime_df(X):
    """
    Convert a 3D time-series dataset into sktime's nested pandas.DataFrame format.

    Each row of the result is one sample and each column one channel. Every
    cell holds a ``pandas.Series`` with that channel's ``seq_len`` values
    (cast to float64), named ``"dim_{channel}"``. Columns are labelled with
    the integer channel indices ``0 .. n_chs - 1``.

    Arguments:
        - X (numpy.ndarray or torch.Tensor or pandas.DataFrame): The input
          dataset with shape (n_samples, n_chs, seq_len).

    Return:
        - sktime_df (pandas.DataFrame): A DataFrame of shape (n_samples, n_chs)
          whose cells are pandas Series, each a univariate time series.

    Raises:
        - ValueError: If X is not one of the supported types, or if an
          array/tensor input is not 3-dimensional.

    Data Types and Shapes:
        - numpy.ndarray: must have shape (n_samples, n_chs, seq_len).
        - torch.Tensor: must have shape (n_samples, n_chs, seq_len). It is
          detached from the autograd graph and moved to CPU first, so tensors
          that require grad or live on a GPU are accepted.
        - pandas.DataFrame: assumed to already be in sktime format and returned
          as is (no copy, no validation).

    References:
        - sktime: https://github.com/alan-turing-institute/sktime

    Examples:
    --------
    >>> X_np = np.random.rand(4, 3, 100)
    >>> sktime_df = to_sktime_df(X_np)
    >>> type(sktime_df)
    <class 'pandas.core.frame.DataFrame'>
    >>> sktime_df.shape
    (4, 3)
    """
    if isinstance(X, pd.DataFrame):
        return X
    if torch.is_tensor(X):
        # .detach(): numpy() refuses tensors that require grad.
        # .cpu(): no-op on CPU, required for CUDA tensors.
        X = X.detach().cpu().numpy()
    elif not isinstance(X, np.ndarray):
        raise ValueError(
            "Input X must be a numpy.ndarray, torch.Tensor, or pandas.DataFrame"
        )

    # Without this check a 2D input would silently yield length-1 Series.
    if X.ndim != 3:
        raise ValueError(
            "Input X must have shape (n_samples, n_chs, seq_len); "
            f"got shape {X.shape}"
        )

    X = X.astype(np.float64)

    def _format_a_sample_for_sktime(x):
        """Wrap one (n_chs, seq_len) sample as a Series of per-channel Series."""
        return pd.Series(
            [pd.Series(x[d], name=f"dim_{d}") for d in range(x.shape[0])]
        )

    return pd.DataFrame([_format_a_sample_for_sktime(sample) for sample in X])
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# # Obsolete warning for future compatibility
|
|
68
|
+
# def to_sktime(*args, **kwargs):
|
|
69
|
+
# warnings.warn(
|
|
70
|
+
# "to_sktime is deprecated; use to_sktime_df instead.", FutureWarning
|
|
71
|
+
# )
|
|
72
|
+
# return to_sktime_df(*args, **kwargs)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# import pandas as pd
|
|
76
|
+
# import numpy as np
|
|
77
|
+
# import torch
|
|
78
|
+
|
|
79
|
+
# def to_sktime(X):
|
|
80
|
+
# """
|
|
81
|
+
# X.shape: (n_samples, n_chs, seq_len)
|
|
82
|
+
# """
|
|
83
|
+
|
|
84
|
+
# def _format_a_sample_for_sktime(x):
|
|
85
|
+
# """
|
|
86
|
+
# x.shape: (n_chs, seq_len)
|
|
87
|
+
# """
|
|
88
|
+
# dims = pd.Series(
|
|
89
|
+
# [pd.Series(x[d], name=f"dim_{d}") for d in range(len(x))],
|
|
90
|
+
# index=[f"dim_{i}" for i in np.arange(len(x))],
|
|
91
|
+
# )
|
|
92
|
+
# return dims
|
|
93
|
+
|
|
94
|
+
# if torch.is_tensor(X):
|
|
95
|
+
# X = X.numpy()
|
|
96
|
+
# X = X.astype(np.float64)
|
|
97
|
+
|
|
98
|
+
# return pd.DataFrame(
|
|
99
|
+
# [_format_a_sample_for_sktime(X[i]) for i in range(len(X))]
|
|
100
|
+
# )
|