scitex 2.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +73 -0
- scitex/__main__.py +89 -0
- scitex/__version__.py +14 -0
- scitex/_sh.py +59 -0
- scitex/ai/_LearningCurveLogger.py +583 -0
- scitex/ai/__Classifiers.py +101 -0
- scitex/ai/__init__.py +55 -0
- scitex/ai/_gen_ai/_Anthropic.py +173 -0
- scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
- scitex/ai/_gen_ai/_DeepSeek.py +175 -0
- scitex/ai/_gen_ai/_Google.py +161 -0
- scitex/ai/_gen_ai/_Groq.py +97 -0
- scitex/ai/_gen_ai/_Llama.py +142 -0
- scitex/ai/_gen_ai/_OpenAI.py +230 -0
- scitex/ai/_gen_ai/_PARAMS.py +565 -0
- scitex/ai/_gen_ai/_Perplexity.py +191 -0
- scitex/ai/_gen_ai/__init__.py +32 -0
- scitex/ai/_gen_ai/_calc_cost.py +78 -0
- scitex/ai/_gen_ai/_format_output_func.py +183 -0
- scitex/ai/_gen_ai/_genai_factory.py +71 -0
- scitex/ai/act/__init__.py +8 -0
- scitex/ai/act/_define.py +11 -0
- scitex/ai/classification/__init__.py +7 -0
- scitex/ai/classification/classification_reporter.py +1137 -0
- scitex/ai/classification/classifier_server.py +131 -0
- scitex/ai/classification/classifiers.py +101 -0
- scitex/ai/classification_reporter.py +1161 -0
- scitex/ai/classifier_server.py +131 -0
- scitex/ai/clustering/__init__.py +11 -0
- scitex/ai/clustering/_pca.py +115 -0
- scitex/ai/clustering/_umap.py +376 -0
- scitex/ai/early_stopping.py +149 -0
- scitex/ai/feature_extraction/__init__.py +56 -0
- scitex/ai/feature_extraction/vit.py +148 -0
- scitex/ai/genai/__init__.py +277 -0
- scitex/ai/genai/anthropic.py +177 -0
- scitex/ai/genai/anthropic_provider.py +320 -0
- scitex/ai/genai/anthropic_refactored.py +109 -0
- scitex/ai/genai/auth_manager.py +200 -0
- scitex/ai/genai/base_genai.py +336 -0
- scitex/ai/genai/base_provider.py +291 -0
- scitex/ai/genai/calc_cost.py +78 -0
- scitex/ai/genai/chat_history.py +307 -0
- scitex/ai/genai/cost_tracker.py +276 -0
- scitex/ai/genai/deepseek.py +188 -0
- scitex/ai/genai/deepseek_provider.py +251 -0
- scitex/ai/genai/format_output_func.py +183 -0
- scitex/ai/genai/genai_factory.py +71 -0
- scitex/ai/genai/google.py +169 -0
- scitex/ai/genai/google_provider.py +228 -0
- scitex/ai/genai/groq.py +104 -0
- scitex/ai/genai/groq_provider.py +248 -0
- scitex/ai/genai/image_processor.py +250 -0
- scitex/ai/genai/llama.py +155 -0
- scitex/ai/genai/llama_provider.py +214 -0
- scitex/ai/genai/mock_provider.py +127 -0
- scitex/ai/genai/model_registry.py +304 -0
- scitex/ai/genai/openai.py +230 -0
- scitex/ai/genai/openai_provider.py +293 -0
- scitex/ai/genai/params.py +565 -0
- scitex/ai/genai/perplexity.py +202 -0
- scitex/ai/genai/perplexity_provider.py +205 -0
- scitex/ai/genai/provider_base.py +302 -0
- scitex/ai/genai/provider_factory.py +370 -0
- scitex/ai/genai/response_handler.py +235 -0
- scitex/ai/layer/_Pass.py +21 -0
- scitex/ai/layer/__init__.py +10 -0
- scitex/ai/layer/_switch.py +8 -0
- scitex/ai/loss/_L1L2Losses.py +34 -0
- scitex/ai/loss/__init__.py +12 -0
- scitex/ai/loss/multi_task_loss.py +47 -0
- scitex/ai/metrics/__init__.py +9 -0
- scitex/ai/metrics/_bACC.py +51 -0
- scitex/ai/metrics/silhoute_score_block.py +496 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
- scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
- scitex/ai/optim/__init__.py +13 -0
- scitex/ai/optim/_get_set.py +31 -0
- scitex/ai/optim/_optimizers.py +71 -0
- scitex/ai/plt/__init__.py +21 -0
- scitex/ai/plt/_conf_mat.py +592 -0
- scitex/ai/plt/_learning_curve.py +194 -0
- scitex/ai/plt/_optuna_study.py +111 -0
- scitex/ai/plt/aucs/__init__.py +2 -0
- scitex/ai/plt/aucs/example.py +60 -0
- scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
- scitex/ai/plt/aucs/roc_auc.py +246 -0
- scitex/ai/sampling/undersample.py +29 -0
- scitex/ai/sk/__init__.py +11 -0
- scitex/ai/sk/_clf.py +58 -0
- scitex/ai/sk/_to_sktime.py +100 -0
- scitex/ai/sklearn/__init__.py +26 -0
- scitex/ai/sklearn/clf.py +58 -0
- scitex/ai/sklearn/to_sktime.py +100 -0
- scitex/ai/training/__init__.py +7 -0
- scitex/ai/training/early_stopping.py +150 -0
- scitex/ai/training/learning_curve_logger.py +555 -0
- scitex/ai/utils/__init__.py +22 -0
- scitex/ai/utils/_check_params.py +50 -0
- scitex/ai/utils/_default_dataset.py +46 -0
- scitex/ai/utils/_format_samples_for_sktime.py +26 -0
- scitex/ai/utils/_label_encoder.py +134 -0
- scitex/ai/utils/_merge_labels.py +22 -0
- scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
- scitex/ai/utils/_under_sample.py +51 -0
- scitex/ai/utils/_verify_n_gpus.py +16 -0
- scitex/ai/utils/grid_search.py +148 -0
- scitex/context/__init__.py +9 -0
- scitex/context/_suppress_output.py +38 -0
- scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
- scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
- scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
- scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
- scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
- scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
- scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
- scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
- scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
- scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
- scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
- scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
- scitex/db/_BaseMixins/__init__.py +30 -0
- scitex/db/_PostgreSQL.py +126 -0
- scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
- scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
- scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
- scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
- scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
- scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
- scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
- scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
- scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
- scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
- scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
- scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
- scitex/db/_PostgreSQLMixins/__init__.py +34 -0
- scitex/db/_SQLite3.py +2136 -0
- scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
- scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
- scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
- scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
- scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
- scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
- scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
- scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
- scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
- scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
- scitex/db/_SQLite3Mixins/__init__.py +30 -0
- scitex/db/__init__.py +14 -0
- scitex/db/_delete_duplicates.py +397 -0
- scitex/db/_inspect.py +163 -0
- scitex/decorators/__init__.py +54 -0
- scitex/decorators/_auto_order.py +172 -0
- scitex/decorators/_batch_fn.py +127 -0
- scitex/decorators/_cache_disk.py +32 -0
- scitex/decorators/_cache_mem.py +12 -0
- scitex/decorators/_combined.py +98 -0
- scitex/decorators/_converters.py +282 -0
- scitex/decorators/_deprecated.py +26 -0
- scitex/decorators/_not_implemented.py +30 -0
- scitex/decorators/_numpy_fn.py +86 -0
- scitex/decorators/_pandas_fn.py +121 -0
- scitex/decorators/_preserve_doc.py +19 -0
- scitex/decorators/_signal_fn.py +95 -0
- scitex/decorators/_timeout.py +55 -0
- scitex/decorators/_torch_fn.py +136 -0
- scitex/decorators/_wrap.py +39 -0
- scitex/decorators/_xarray_fn.py +88 -0
- scitex/dev/__init__.py +15 -0
- scitex/dev/_analyze_code_flow.py +284 -0
- scitex/dev/_reload.py +59 -0
- scitex/dict/_DotDict.py +442 -0
- scitex/dict/__init__.py +18 -0
- scitex/dict/_listed_dict.py +42 -0
- scitex/dict/_pop_keys.py +36 -0
- scitex/dict/_replace.py +13 -0
- scitex/dict/_safe_merge.py +62 -0
- scitex/dict/_to_str.py +32 -0
- scitex/dsp/__init__.py +72 -0
- scitex/dsp/_crop.py +122 -0
- scitex/dsp/_demo_sig.py +331 -0
- scitex/dsp/_detect_ripples.py +212 -0
- scitex/dsp/_ensure_3d.py +18 -0
- scitex/dsp/_hilbert.py +78 -0
- scitex/dsp/_listen.py +702 -0
- scitex/dsp/_misc.py +30 -0
- scitex/dsp/_mne.py +32 -0
- scitex/dsp/_modulation_index.py +79 -0
- scitex/dsp/_pac.py +319 -0
- scitex/dsp/_psd.py +102 -0
- scitex/dsp/_resample.py +65 -0
- scitex/dsp/_time.py +36 -0
- scitex/dsp/_transform.py +68 -0
- scitex/dsp/_wavelet.py +212 -0
- scitex/dsp/add_noise.py +111 -0
- scitex/dsp/example.py +253 -0
- scitex/dsp/filt.py +155 -0
- scitex/dsp/norm.py +18 -0
- scitex/dsp/params.py +51 -0
- scitex/dsp/reference.py +43 -0
- scitex/dsp/template.py +25 -0
- scitex/dsp/utils/__init__.py +15 -0
- scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
- scitex/dsp/utils/_ensure_3d.py +18 -0
- scitex/dsp/utils/_ensure_even_len.py +10 -0
- scitex/dsp/utils/_zero_pad.py +48 -0
- scitex/dsp/utils/filter.py +408 -0
- scitex/dsp/utils/pac.py +177 -0
- scitex/dt/__init__.py +8 -0
- scitex/dt/_linspace.py +130 -0
- scitex/etc/__init__.py +15 -0
- scitex/etc/wait_key.py +34 -0
- scitex/gen/_DimHandler.py +196 -0
- scitex/gen/_TimeStamper.py +244 -0
- scitex/gen/__init__.py +95 -0
- scitex/gen/_alternate_kwarg.py +13 -0
- scitex/gen/_cache.py +11 -0
- scitex/gen/_check_host.py +34 -0
- scitex/gen/_ci.py +12 -0
- scitex/gen/_close.py +222 -0
- scitex/gen/_embed.py +78 -0
- scitex/gen/_inspect_module.py +257 -0
- scitex/gen/_is_ipython.py +12 -0
- scitex/gen/_less.py +48 -0
- scitex/gen/_list_packages.py +139 -0
- scitex/gen/_mat2py.py +88 -0
- scitex/gen/_norm.py +170 -0
- scitex/gen/_paste.py +18 -0
- scitex/gen/_print_config.py +84 -0
- scitex/gen/_shell.py +48 -0
- scitex/gen/_src.py +111 -0
- scitex/gen/_start.py +451 -0
- scitex/gen/_symlink.py +55 -0
- scitex/gen/_symlog.py +27 -0
- scitex/gen/_tee.py +238 -0
- scitex/gen/_title2path.py +60 -0
- scitex/gen/_title_case.py +88 -0
- scitex/gen/_to_even.py +84 -0
- scitex/gen/_to_odd.py +34 -0
- scitex/gen/_to_rank.py +39 -0
- scitex/gen/_transpose.py +37 -0
- scitex/gen/_type.py +78 -0
- scitex/gen/_var_info.py +73 -0
- scitex/gen/_wrap.py +17 -0
- scitex/gen/_xml2dict.py +76 -0
- scitex/gen/misc.py +730 -0
- scitex/gen/path.py +0 -0
- scitex/general/__init__.py +5 -0
- scitex/gists/_SigMacro_processFigure_S.py +128 -0
- scitex/gists/_SigMacro_toBlue.py +172 -0
- scitex/gists/__init__.py +12 -0
- scitex/io/_H5Explorer.py +292 -0
- scitex/io/__init__.py +82 -0
- scitex/io/_cache.py +101 -0
- scitex/io/_flush.py +24 -0
- scitex/io/_glob.py +103 -0
- scitex/io/_json2md.py +113 -0
- scitex/io/_load.py +168 -0
- scitex/io/_load_configs.py +146 -0
- scitex/io/_load_modules/__init__.py +38 -0
- scitex/io/_load_modules/_catboost.py +66 -0
- scitex/io/_load_modules/_con.py +20 -0
- scitex/io/_load_modules/_db.py +24 -0
- scitex/io/_load_modules/_docx.py +42 -0
- scitex/io/_load_modules/_eeg.py +110 -0
- scitex/io/_load_modules/_hdf5.py +196 -0
- scitex/io/_load_modules/_image.py +19 -0
- scitex/io/_load_modules/_joblib.py +19 -0
- scitex/io/_load_modules/_json.py +18 -0
- scitex/io/_load_modules/_markdown.py +103 -0
- scitex/io/_load_modules/_matlab.py +37 -0
- scitex/io/_load_modules/_numpy.py +39 -0
- scitex/io/_load_modules/_optuna.py +155 -0
- scitex/io/_load_modules/_pandas.py +69 -0
- scitex/io/_load_modules/_pdf.py +31 -0
- scitex/io/_load_modules/_pickle.py +24 -0
- scitex/io/_load_modules/_torch.py +16 -0
- scitex/io/_load_modules/_txt.py +126 -0
- scitex/io/_load_modules/_xml.py +49 -0
- scitex/io/_load_modules/_yaml.py +23 -0
- scitex/io/_mv_to_tmp.py +19 -0
- scitex/io/_path.py +286 -0
- scitex/io/_reload.py +78 -0
- scitex/io/_save.py +539 -0
- scitex/io/_save_modules/__init__.py +66 -0
- scitex/io/_save_modules/_catboost.py +22 -0
- scitex/io/_save_modules/_csv.py +89 -0
- scitex/io/_save_modules/_excel.py +49 -0
- scitex/io/_save_modules/_hdf5.py +249 -0
- scitex/io/_save_modules/_html.py +48 -0
- scitex/io/_save_modules/_image.py +140 -0
- scitex/io/_save_modules/_joblib.py +25 -0
- scitex/io/_save_modules/_json.py +25 -0
- scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
- scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
- scitex/io/_save_modules/_matlab.py +24 -0
- scitex/io/_save_modules/_mp4.py +29 -0
- scitex/io/_save_modules/_numpy.py +57 -0
- scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
- scitex/io/_save_modules/_pickle.py +45 -0
- scitex/io/_save_modules/_plotly.py +27 -0
- scitex/io/_save_modules/_text.py +23 -0
- scitex/io/_save_modules/_torch.py +26 -0
- scitex/io/_save_modules/_yaml.py +29 -0
- scitex/life/__init__.py +10 -0
- scitex/life/_monitor_rain.py +49 -0
- scitex/linalg/__init__.py +17 -0
- scitex/linalg/_distance.py +63 -0
- scitex/linalg/_geometric_median.py +64 -0
- scitex/linalg/_misc.py +73 -0
- scitex/nn/_AxiswiseDropout.py +27 -0
- scitex/nn/_BNet.py +126 -0
- scitex/nn/_BNet_Res.py +164 -0
- scitex/nn/_ChannelGainChanger.py +44 -0
- scitex/nn/_DropoutChannels.py +50 -0
- scitex/nn/_Filters.py +489 -0
- scitex/nn/_FreqGainChanger.py +110 -0
- scitex/nn/_GaussianFilter.py +48 -0
- scitex/nn/_Hilbert.py +111 -0
- scitex/nn/_MNet_1000.py +157 -0
- scitex/nn/_ModulationIndex.py +221 -0
- scitex/nn/_PAC.py +414 -0
- scitex/nn/_PSD.py +40 -0
- scitex/nn/_ResNet1D.py +120 -0
- scitex/nn/_SpatialAttention.py +25 -0
- scitex/nn/_Spectrogram.py +161 -0
- scitex/nn/_SwapChannels.py +50 -0
- scitex/nn/_TransposeLayer.py +19 -0
- scitex/nn/_Wavelet.py +183 -0
- scitex/nn/__init__.py +63 -0
- scitex/os/__init__.py +8 -0
- scitex/os/_mv.py +50 -0
- scitex/parallel/__init__.py +8 -0
- scitex/parallel/_run.py +151 -0
- scitex/path/__init__.py +33 -0
- scitex/path/_clean.py +52 -0
- scitex/path/_find.py +108 -0
- scitex/path/_get_module_path.py +51 -0
- scitex/path/_get_spath.py +35 -0
- scitex/path/_getsize.py +18 -0
- scitex/path/_increment_version.py +87 -0
- scitex/path/_mk_spath.py +51 -0
- scitex/path/_path.py +19 -0
- scitex/path/_split.py +23 -0
- scitex/path/_this_path.py +19 -0
- scitex/path/_version.py +101 -0
- scitex/pd/__init__.py +41 -0
- scitex/pd/_find_indi.py +126 -0
- scitex/pd/_find_pval.py +113 -0
- scitex/pd/_force_df.py +154 -0
- scitex/pd/_from_xyz.py +71 -0
- scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
- scitex/pd/_melt_cols.py +81 -0
- scitex/pd/_merge_columns.py +221 -0
- scitex/pd/_mv.py +63 -0
- scitex/pd/_replace.py +62 -0
- scitex/pd/_round.py +93 -0
- scitex/pd/_slice.py +63 -0
- scitex/pd/_sort.py +91 -0
- scitex/pd/_to_numeric.py +53 -0
- scitex/pd/_to_xy.py +59 -0
- scitex/pd/_to_xyz.py +110 -0
- scitex/plt/__init__.py +36 -0
- scitex/plt/_subplots/_AxesWrapper.py +182 -0
- scitex/plt/_subplots/_AxisWrapper.py +249 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
- scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
- scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
- scitex/plt/_subplots/_FigWrapper.py +226 -0
- scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
- scitex/plt/_subplots/__init__.py +111 -0
- scitex/plt/_subplots/_export_as_csv.py +232 -0
- scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
- scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
- scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
- scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
- scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
- scitex/plt/_tpl.py +28 -0
- scitex/plt/ax/__init__.py +114 -0
- scitex/plt/ax/_plot/__init__.py +53 -0
- scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
- scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
- scitex/plt/ax/_plot/_plot_cube.py +57 -0
- scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
- scitex/plt/ax/_plot/_plot_fillv.py +55 -0
- scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
- scitex/plt/ax/_plot/_plot_image.py +94 -0
- scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
- scitex/plt/ax/_plot/_plot_raster.py +172 -0
- scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
- scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
- scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
- scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
- scitex/plt/ax/_plot/_plot_violin.py +343 -0
- scitex/plt/ax/_style/__init__.py +38 -0
- scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
- scitex/plt/ax/_style/_add_panel.py +92 -0
- scitex/plt/ax/_style/_extend.py +64 -0
- scitex/plt/ax/_style/_force_aspect.py +37 -0
- scitex/plt/ax/_style/_format_label.py +23 -0
- scitex/plt/ax/_style/_hide_spines.py +84 -0
- scitex/plt/ax/_style/_map_ticks.py +182 -0
- scitex/plt/ax/_style/_rotate_labels.py +215 -0
- scitex/plt/ax/_style/_sci_note.py +279 -0
- scitex/plt/ax/_style/_set_log_scale.py +299 -0
- scitex/plt/ax/_style/_set_meta.py +261 -0
- scitex/plt/ax/_style/_set_n_ticks.py +37 -0
- scitex/plt/ax/_style/_set_size.py +16 -0
- scitex/plt/ax/_style/_set_supxyt.py +116 -0
- scitex/plt/ax/_style/_set_ticks.py +276 -0
- scitex/plt/ax/_style/_set_xyt.py +121 -0
- scitex/plt/ax/_style/_share_axes.py +264 -0
- scitex/plt/ax/_style/_shift.py +139 -0
- scitex/plt/ax/_style/_show_spines.py +333 -0
- scitex/plt/color/_PARAMS.py +70 -0
- scitex/plt/color/__init__.py +52 -0
- scitex/plt/color/_add_hue_col.py +41 -0
- scitex/plt/color/_colors.py +205 -0
- scitex/plt/color/_get_colors_from_cmap.py +134 -0
- scitex/plt/color/_interpolate.py +29 -0
- scitex/plt/color/_vizualize_colors.py +54 -0
- scitex/plt/utils/__init__.py +44 -0
- scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
- scitex/plt/utils/_calc_nice_ticks.py +101 -0
- scitex/plt/utils/_close.py +68 -0
- scitex/plt/utils/_colorbar.py +96 -0
- scitex/plt/utils/_configure_mpl.py +295 -0
- scitex/plt/utils/_histogram_utils.py +132 -0
- scitex/plt/utils/_im2grid.py +70 -0
- scitex/plt/utils/_is_valid_axis.py +78 -0
- scitex/plt/utils/_mk_colorbar.py +65 -0
- scitex/plt/utils/_mk_patches.py +26 -0
- scitex/plt/utils/_scientific_captions.py +638 -0
- scitex/plt/utils/_scitex_config.py +223 -0
- scitex/reproduce/__init__.py +14 -0
- scitex/reproduce/_fix_seeds.py +45 -0
- scitex/reproduce/_gen_ID.py +55 -0
- scitex/reproduce/_gen_timestamp.py +35 -0
- scitex/res/__init__.py +5 -0
- scitex/resource/__init__.py +13 -0
- scitex/resource/_get_processor_usages.py +281 -0
- scitex/resource/_get_specs.py +280 -0
- scitex/resource/_log_processor_usages.py +190 -0
- scitex/resource/_utils/__init__.py +31 -0
- scitex/resource/_utils/_get_env_info.py +481 -0
- scitex/resource/limit_ram.py +33 -0
- scitex/scholar/__init__.py +24 -0
- scitex/scholar/_local_search.py +454 -0
- scitex/scholar/_paper.py +244 -0
- scitex/scholar/_pdf_downloader.py +325 -0
- scitex/scholar/_search.py +393 -0
- scitex/scholar/_vector_search.py +370 -0
- scitex/scholar/_web_sources.py +457 -0
- scitex/stats/__init__.py +31 -0
- scitex/stats/_calc_partial_corr.py +17 -0
- scitex/stats/_corr_test_multi.py +94 -0
- scitex/stats/_corr_test_wrapper.py +115 -0
- scitex/stats/_describe_wrapper.py +90 -0
- scitex/stats/_multiple_corrections.py +63 -0
- scitex/stats/_nan_stats.py +93 -0
- scitex/stats/_p2stars.py +116 -0
- scitex/stats/_p2stars_wrapper.py +56 -0
- scitex/stats/_statistical_tests.py +73 -0
- scitex/stats/desc/__init__.py +40 -0
- scitex/stats/desc/_describe.py +189 -0
- scitex/stats/desc/_nan.py +289 -0
- scitex/stats/desc/_real.py +94 -0
- scitex/stats/multiple/__init__.py +14 -0
- scitex/stats/multiple/_bonferroni_correction.py +72 -0
- scitex/stats/multiple/_fdr_correction.py +400 -0
- scitex/stats/multiple/_multicompair.py +28 -0
- scitex/stats/tests/__corr_test.py +277 -0
- scitex/stats/tests/__corr_test_multi.py +343 -0
- scitex/stats/tests/__corr_test_single.py +277 -0
- scitex/stats/tests/__init__.py +22 -0
- scitex/stats/tests/_brunner_munzel_test.py +192 -0
- scitex/stats/tests/_nocorrelation_test.py +28 -0
- scitex/stats/tests/_smirnov_grubbs.py +98 -0
- scitex/str/__init__.py +113 -0
- scitex/str/_clean_path.py +75 -0
- scitex/str/_color_text.py +52 -0
- scitex/str/_decapitalize.py +58 -0
- scitex/str/_factor_out_digits.py +281 -0
- scitex/str/_format_plot_text.py +498 -0
- scitex/str/_grep.py +48 -0
- scitex/str/_latex.py +155 -0
- scitex/str/_latex_fallback.py +471 -0
- scitex/str/_mask_api.py +39 -0
- scitex/str/_mask_api_key.py +8 -0
- scitex/str/_parse.py +158 -0
- scitex/str/_print_block.py +47 -0
- scitex/str/_print_debug.py +68 -0
- scitex/str/_printc.py +62 -0
- scitex/str/_readable_bytes.py +38 -0
- scitex/str/_remove_ansi.py +23 -0
- scitex/str/_replace.py +134 -0
- scitex/str/_search.py +125 -0
- scitex/str/_squeeze_space.py +36 -0
- scitex/tex/__init__.py +10 -0
- scitex/tex/_preview.py +103 -0
- scitex/tex/_to_vec.py +116 -0
- scitex/torch/__init__.py +18 -0
- scitex/torch/_apply_to.py +34 -0
- scitex/torch/_nan_funcs.py +77 -0
- scitex/types/_ArrayLike.py +44 -0
- scitex/types/_ColorLike.py +21 -0
- scitex/types/__init__.py +14 -0
- scitex/types/_is_listed_X.py +70 -0
- scitex/utils/__init__.py +22 -0
- scitex/utils/_compress_hdf5.py +116 -0
- scitex/utils/_email.py +120 -0
- scitex/utils/_grid.py +148 -0
- scitex/utils/_notify.py +247 -0
- scitex/utils/_search.py +121 -0
- scitex/web/__init__.py +38 -0
- scitex/web/_search_pubmed.py +438 -0
- scitex/web/_summarize_url.py +158 -0
- scitex-2.0.0.dist-info/METADATA +307 -0
- scitex-2.0.0.dist-info/RECORD +572 -0
- scitex-2.0.0.dist-info/WHEEL +6 -0
- scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
- scitex-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-04-02 09:21:12 (ywatanabe)"
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
from ..decorators import numpy_fn, torch_fn
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Spectrogram(nn.Module):
    """
    STFT magnitude spectrogram computed independently for every channel.

    Parameters:
    - sampling_rate: Sampling rate of the input; used to build the frequency
      and time axes returned by ``forward``.
    - n_fft (int): FFT size passed to ``torch.stft``.
    - hop_length (int or None): Hop between frames; ``n_fft // 4`` when None.
    - win_length (int or None): Window length; ``n_fft`` when None.
    - window (str): Window type. Only "hann" is supported.

    Raises:
    - ValueError: If ``window`` is anything other than "hann".
    """

    def __init__(
        self,
        sampling_rate,
        n_fft=256,
        hop_length=None,
        win_length=None,
        window="hann",
    ):
        super().__init__()
        self.sampling_rate = sampling_rate
        self.n_fft = n_fft
        self.hop_length = hop_length if hop_length is not None else n_fft // 4
        self.win_length = win_length if win_length is not None else n_fft
        if window != "hann":
            raise ValueError(
                "Unsupported window type. Extend this to support more window types."
            )
        self.window = torch.hann_window(window_length=self.win_length)

    def forward(self, x):
        """
        Computes the spectrogram for each channel in the input signal.

        Parameters:
        - x (torch.Tensor): Input signal. It is passed through
          ``scitex.dsp.ensure_3d`` and must then unpack as
          (batch_size, n_chs, seq_len).

        Returns:
        - spectrograms (torch.Tensor): STFT magnitudes, one per channel,
          concatenated along dimension 1.
        - freqs (torch.Tensor): ``n_fft // 2 + 1`` evenly spaced values from 0
          to ``sampling_rate / 2``.
        - times_sec (torch.Tensor): Start time of every frame in seconds,
          i.e. ``frame_index * hop_length / sampling_rate``.
        """
        # The module header does not import `scitex` (only the __main__ demo
        # does), so the name is bound here; otherwise calling forward() on an
        # imported module raises NameError.
        import scitex

        x = scitex.dsp.ensure_3d(x)

        # Unpacking doubles as a check that x really is 3-D.
        _, n_chs, _ = x.shape

        # Loop invariant: move the window to the input's device only once.
        window = self.window.to(x.device)

        spectrograms = []
        for ch in range(n_chs):
            spec = torch.stft(
                x[:, ch, :],
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=window,
                center=True,
                pad_mode="reflect",
                normalized=False,
                return_complex=True,
            )
            # Re-insert the channel dimension so the results can be joined.
            spectrograms.append(torch.abs(spec).unsqueeze(1))

        # Concatenate spectrograms along channel dimension
        spectrograms = torch.cat(spectrograms, dim=1)

        # Calculate frequencies (y-axis)
        freqs = torch.linspace(0, self.sampling_rate / 2, steps=self.n_fft // 2 + 1)

        # Calculate times (x-axis): the last dimension of the spectrogram is
        # the number of frames; each frame is hop_length samples apart.
        n_frames = spectrograms.shape[-1]
        times_sec = torch.arange(0, n_frames) * (self.hop_length / self.sampling_rate)

        return spectrograms, freqs, times_sec
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@torch_fn
def spectrograms(x, fs, cuda=False):
    """
    Build a ``Spectrogram`` module for sampling rate ``fs`` and apply it to ``x``.

    Returns whatever ``Spectrogram.forward`` returns for ``x``.
    """
    # NOTE(review): a later function in this module is also named
    # `spectrograms`, so this definition is rebound at import time and is
    # unreachable by name — confirm which one is meant to be public.
    # NOTE(review): `cuda` is not read here; presumably the `torch_fn`
    # decorator consumes it — confirm.
    stft_module = Spectrogram(fs)
    return stft_module(x)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@torch_fn
def my_softmax(x, dim=-1):
    """Return the softmax of ``x`` taken along ``dim`` (last dimension by default)."""
    probabilities = F.softmax(x, dim=dim)
    return probabilities
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@torch_fn
def unbias(x, func="min", dim=-1, cuda=False):
    """
    Subtract a per-slice offset from ``x`` along ``dim``.

    Parameters:
    - x (torch.Tensor): Input tensor.
    - func (str): "min" subtracts the minimum along ``dim``; "mean" subtracts
      the mean along ``dim``.
    - dim (int): Dimension along which the offset is computed.
    - cuda (bool): Not read in this body; presumably consumed by the
      ``torch_fn`` decorator — TODO confirm.

    Returns:
    - torch.Tensor: ``x`` minus the selected statistic, same shape as ``x``.

    Raises:
    - ValueError: If ``func`` is neither "min" nor "mean".
    """
    if func == "min":
        # Tensor.min(dim=...) returns (values, indices); [0] selects the values.
        return x - x.min(dim=dim, keepdim=True)[0]
    if func == "mean":
        # Tensor.mean returns a plain tensor, so it must NOT be indexed with
        # [0]; doing so subtracted the first slice's mean from every slice.
        return x - x.mean(dim=dim, keepdim=True)
    raise ValueError(f"Unsupported func {func!r}; expected 'min' or 'mean'.")
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@torch_fn
def normalize(x, axis=-1, amp=1.0, cuda=False):
    """
    Scale ``x`` so that its largest absolute extreme along ``axis`` equals ``amp``.

    Parameters:
    - x (torch.Tensor): Input tensor.
    - axis (int): Dimension along which the maximum and minimum are taken.
    - amp (float): Target amplitude after scaling.
    - cuda (bool): Not read in this body; presumably consumed by the
      ``torch_fn`` decorator — TODO confirm.

    Returns:
    - torch.Tensor: ``amp * x`` divided by the larger of |max| and |min|.
    """
    upper_extreme = x.max(axis=axis, keepdims=True)[0]
    lower_extreme = x.min(axis=axis, keepdims=True)[0]
    scale = torch.maximum(torch.abs(upper_extreme), torch.abs(lower_extreme))
    return amp * x / scale
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@torch_fn
def spectrograms(x, fs, dj=0.125, cuda=False):
    """
    Continuous wavelet transform of the first channel of ``x``.

    Parameters:
    - x (torch.Tensor): Signal that unpacks as (batch_size, n_chs, seq_len).
      Only channel 0 is transformed.
    - fs: Sampling rate; the sample spacing passed to the transform is ``1 / fs``.
    - dj (float): Scale resolution passed to ``WaveletTransformTorch``.
    - cuda (bool): Forwarded to ``WaveletTransformTorch`` to select GPU
      execution.

    Returns:
    - The result of ``WaveletTransformTorch.cwt`` for channel 0.
    """
    from wavelets_pytorch.transform import (
        WaveletTransformTorch,
    )  # PyTorch version

    dt = 1 / fs

    # Unpacking is kept only as a check that x is 3-D.
    _batch_size, _n_chs, _seq_len = x.shape

    # detach() lets tensors that require grad be converted as well.
    x = x.detach().cpu().numpy()

    # Initialize wavelet filter bank (torch implementation). The caller's
    # `cuda` flag is honoured; it used to be ignored in favour of a
    # hard-coded True, which fails on machines without a GPU.
    wa_torch = WaveletTransformTorch(dt, dj, cuda=cuda)

    # Keep channel 0 and add an axis after the batch dimension. `None` is what
    # `np.newaxis` is; `np` is never imported in this module.
    x = x[:, 0][:, None]

    # Performing wavelet transform (and compute scalogram)
    cwt_torch = wa_torch.cwt(x)

    return cwt_torch
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
if __name__ == "__main__":
    # Demo script: plots a ripple demo signal together with its spectrograms
    # from this module and from torchaudio. Requires a CUDA device.
    #
    # `plt` and `np` are used below but were never imported anywhere in this
    # module; matplotlib is already required by seaborn.
    import matplotlib.pyplot as plt
    import numpy as np
    import scitex
    import seaborn as sns
    import torchaudio

    fs = 1024  # 128
    t_sec = 10
    x = scitex.dsp.np.demo_sig(t_sec=t_sec, fs=fs, type="ripple")

    # NOTE(review): the result is discarded — this only checks the calls run.
    normalize(unbias(x, cuda=True), cuda=True)

    # My implementation
    # NOTE(review): both calls below resolve to the LAST `spectrograms`
    # defined in this module; confirm the second call's three-way unpacking
    # matches what that function returns.
    ss = spectrograms(x, fs, cuda=True)
    fig, axes = plt.subplots(nrows=2)
    # np.arange needs the sample count, not the signal array itself.
    axes[0].plot(np.arange(len(x[0, 0])) / fs, x[0, 0])
    sns.heatmap(ss[0], ax=axes[1])
    plt.show()

    ss, ff, tt = spectrograms(x, fs, cuda=True)
    fig, axes = plt.subplots(nrows=2)
    axes[0].plot(np.arange(len(x[0, 0])) / fs, x[0, 0])
    sns.heatmap(ss[0], ax=axes[1])
    plt.show()

    # Torch Audio
    transform = torchaudio.transforms.Spectrogram(n_fft=16, normalized=True).cuda()
    xx = torch.tensor(x).float().cuda()[0, 0]
    ss = transform(xx)
    sns.heatmap(ss.detach().cpu().numpy())

    plt.show()
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2023-05-04 21:21:19 (ywatanabe)"
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
from torchsummary import summary
|
|
9
|
+
import scitex
|
|
10
|
+
import numpy as np
|
|
11
|
+
import random
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SwapChannels(nn.Module):
    """
    Training-time augmentation that permutes a random subset of channels.

    A dropout layer applied to a vector of ones decides which channels stay
    in place (non-zero after dropout) and which are shuffled among themselves
    (zeroed by dropout). In eval mode the input is returned untouched.
    """

    def __init__(self, dropout=0.5):
        super().__init__()
        # Dropout is used only as a random channel selector, never on the data.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """x: [batch_size, n_chs, seq_len]"""
        if not self.training:
            return x

        n_channels = x.shape[1]
        channel_ids = torch.arange(n_channels)

        # True where the channel keeps its position.
        keep_mask = self.dropout(torch.ones(n_channels)).bool()
        shuffle_mask = ~keep_mask

        # Shuffle the selected channel indices among themselves.
        candidates = channel_ids[shuffle_mask].tolist()
        shuffled = random.sample(candidates, len(candidates))

        permutation = channel_ids.clone()
        permutation[shuffle_mask] = torch.LongTensor(shuffled)

        return x[:, permutation.long(), :]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
if __name__ == "__main__":
|
|
40
|
+
## Demo data
|
|
41
|
+
bs, n_chs, seq_len = 16, 360, 1000
|
|
42
|
+
x = torch.rand(bs, n_chs, seq_len)
|
|
43
|
+
|
|
44
|
+
sc = SwapChannels()
|
|
45
|
+
print(sc(x).shape)  # [16, 360, 1000]
|
|
46
|
+
|
|
47
|
+
# sb = SubjectBlock(n_chs=n_chs)
|
|
48
|
+
# print(sb(x, s).shape) # [16, 270, 1000]
|
|
49
|
+
|
|
50
|
+
# summary(sb, x, s)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-03-30 07:26:35 (ywatanabe)"
|
|
4
|
+
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TransposeLayer(nn.Module):
    """Layer that swaps two axes of its input via ``Tensor.transpose``.

    Parameters
    ----------
    axis1, axis2
        The two dimensions to exchange; they are forwarded unchanged to
        ``x.transpose``.
    """

    def __init__(self, axis1, axis2):
        super().__init__()
        self.axis1, self.axis2 = axis1, axis2

    def forward(self, x):
        """Return ``x`` with ``axis1`` and ``axis2`` exchanged."""
        dims = (self.axis1, self.axis2)
        return x.transpose(*dims)
|
scitex/nn/_Wavelet.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-03 07:17:26 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/nn/_Wavelet.py
|
|
5
|
+
|
|
6
|
+
#!/usr/bin/env python3
|
|
7
|
+
# -*- coding: utf-8 -*-
|
|
8
|
+
# Time-stamp: "2024-05-30 11:04:45 (ywatanabe)"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
import scitex
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn as nn
|
|
15
|
+
import torch.nn.functional as F
|
|
16
|
+
from ..gen._to_even import to_even
|
|
17
|
+
from ..gen._to_odd import to_odd
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Wavelet(nn.Module):
    """Bank of complex Morlet wavelets applied to a signal with ``F.conv1d``.

    The kernels are generated once in ``__init__`` (see
    ``gen_morlet_to_nyquist``) and stored as a plain attribute ``self.kernel``
    (not a registered buffer); ``forward`` moves it to the input's device.

    Parameters
    ----------
    samp_rate : int or float
        Sampling rate in Hz; the bank covers frequencies up to ``samp_rate / 2``.
    kernel_size : int, optional
        Kernel length in samples. Defaults to ``int(samp_rate)``.
    freq_scale : {"linear", "log"}
        Spacing of the frequency bands.
    out_scale : str
        ``"log"`` makes ``forward`` return ``log(amp + 1e-5)``; any other
        value returns the raw amplitude.
    """

    def __init__(
        self, samp_rate, kernel_size=None, freq_scale="linear", out_scale="log"
    ):
        super().__init__()
        # Scalar buffer used only to look up the module's current device.
        self.register_buffer("dummy", torch.tensor(0))
        self.kernel = None
        self.init_kernel(samp_rate, kernel_size=kernel_size, freq_scale=freq_scale)
        self.out_scale = out_scale

    def forward(self, x):
        """Apply the 2D filter (n_filts, kernel_size) to input signal x with shape: (batch_size, n_chs, seq_len)

        Returns
        -------
        pha : Tensor, (batch_size, n_chs, n_filts, seq_len)
            Phase (``angle``) of the complex filter response.
        amp : Tensor, (batch_size, n_chs, n_filts, seq_len)
            Magnitude of the response, or ``log(amp + 1e-5)`` when
            ``out_scale == "log"``.
        freqs : Tensor, (batch_size, n_chs, n_filts)
            Center frequency of every filter, repeated per batch and channel.

        Raises
        ------
        ValueError
            If no kernel has been initialized.
        """
        x = scitex.dsp.ensure_3d(x).to(self.dummy.device)
        seq_len = x.shape[-1]

        # The kernel cannot be rebuilt here because the sampling rate is not
        # stored on the module, so a missing kernel is reported directly.
        if self.kernel is None:
            raise ValueError("Filter kernel has not been initialized.")
        assert self.kernel.ndim == 2
        self.kernel = self.kernel.to(x.device)  # complex128 (built from numpy)

        # Edge handling: prepend/append flipped copies of the first/last
        # `radius` samples so the convolution has context at both ends.
        extension_length = self.radius
        first_segment = x[:, :, :extension_length].flip(dims=[-1])
        last_segment = x[:, :, -extension_length:].flip(dims=[-1])
        extended_x = torch.cat([first_segment, x, last_segment], dim=-1)

        # Fold (batch, channel) into one axis so every channel is filtered
        # independently by every wavelet: (B * C, 1, L) * (n_filts, 1, K).
        kernel_batched = self.kernel.unsqueeze(1)
        extended_x_reshaped = extended_x.view(-1, 1, extended_x.shape[-1])

        # conv1d is run separately on the real and imaginary parts of the
        # kernel (cast to float32) and the two outputs are recombined.
        filtered_x_real = F.conv1d(
            extended_x_reshaped, kernel_batched.real.float(), groups=1
        )
        filtered_x_imag = F.conv1d(
            extended_x_reshaped, kernel_batched.imag.float(), groups=1
        )
        filtered_x = torch.view_as_complex(
            torch.stack([filtered_x_real, filtered_x_imag], dim=-1)
        )

        # Restore the (batch, channel, filter, time) layout and crop to the
        # original signal length.
        filtered_x = filtered_x.view(
            x.shape[0], x.shape[1], kernel_batched.shape[0], -1
        )
        filtered_x = filtered_x[..., :seq_len]
        assert filtered_x.shape[-1] == seq_len

        pha = filtered_x.angle()
        amp = filtered_x.abs()

        # Repeats freqs
        freqs = (
            self.freqs.unsqueeze(0).unsqueeze(0).repeat(pha.shape[0], pha.shape[1], 1)
        )

        if self.out_scale == "log":
            return pha, torch.log(amp + 1e-5), freqs
        else:
            return pha, amp, freqs

    def init_kernel(self, samp_rate, kernel_size=None, freq_scale="log"):
        """Build the wavelet bank and store it in ``self.kernel`` / ``self.freqs``.

        Both tensors are placed on the device the module lives on at call time.
        """
        device = self.dummy.device
        morlets, freqs = self.gen_morlet_to_nyquist(
            samp_rate, kernel_size=kernel_size, freq_scale=freq_scale
        )
        self.kernel = torch.tensor(morlets).to(device)
        self.freqs = torch.tensor(freqs).float().to(device)

    @staticmethod
    def gen_morlet_to_nyquist(samp_rate, kernel_size=None, freq_scale="linear"):
        """
        Generates Morlet wavelets for frequency bands up to the Nyquist frequency.

        With ``freq_scale="linear"`` the bands are evenly spaced up to the
        Nyquist frequency; with ``freq_scale="log"`` they are centered on
        powers of two (2, 4, 8, ... Hz).

        Parameters:
        - samp_rate (int): The sampling rate of the signal, in Hertz.
        - kernel_size (int): The size of the kernel, in number of samples.
          Defaults to ``int(samp_rate)``.
        - freq_scale (str): Either "linear" or "log".

        Returns:
        - np.ndarray: A 2D array of complex values representing the Morlet wavelets for each frequency band.
        - np.ndarray: A 1D array with the center frequency (Hz) of each wavelet.

        Raises:
        - ValueError: If ``freq_scale`` is neither "linear" nor "log".
        """
        if kernel_size is None:
            kernel_size = int(samp_rate)  # * 2.5)

        nyquist_freq = samp_rate / 2

        # Log freq_scale
        def calc_freq_boundaries_log(nyquist_freq):
            n_kernels = int(np.floor(np.log2(nyquist_freq)))
            mid_hz = np.array([2 ** (n + 1) for n in range(n_kernels)])
            width_hz = np.hstack([np.array([1]), np.diff(mid_hz) / 2]) + 1
            low_hz = mid_hz - width_hz
            high_hz = mid_hz + width_hz
            low_hz[0] = 0.1  # keep the lowest band edge strictly positive
            return low_hz, high_hz

        def calc_freq_boundaries_linear(nyquist_freq):
            n_kernels = int(nyquist_freq)
            high_hz = np.linspace(1, nyquist_freq, n_kernels)
            low_hz = high_hz - np.hstack([np.array(1), np.diff(high_hz)])
            low_hz[0] = 0.1  # keep the lowest band edge strictly positive
            return low_hz, high_hz

        if freq_scale == "linear":
            fn = calc_freq_boundaries_linear
        elif freq_scale == "log":
            fn = calc_freq_boundaries_log
        else:
            raise ValueError(
                f"freq_scale must be 'linear' or 'log', got {freq_scale!r}"
            )
        low_hz, high_hz = fn(nyquist_freq)

        # The time axis (seconds, centered on zero) is the same for every band.
        t = np.arange(-kernel_size // 2, kernel_size // 2) / samp_rate

        morlets = []
        freqs = []

        for ll, hh in zip(low_hz, high_hz):
            if ll > nyquist_freq:
                break

            center_frequency = (ll + hh) / 2

            # Calculate standard deviation of the gaussian window for a given center frequency
            sigma = 7 / (2 * np.pi * center_frequency)
            sine_wave = np.exp(2j * np.pi * center_frequency * t)
            gaussian_window = np.exp(-(t**2) / (2 * sigma**2))
            morlet_wavelet = sine_wave * gaussian_window

            freqs.append(center_frequency)
            morlets.append(morlet_wavelet)

        return np.array(morlets), np.array(freqs)

    @property
    def kernel_size(
        self,
    ):
        """Kernel length (last dim of ``self.kernel``) passed through ``to_even``."""
        return to_even(self.kernel.shape[-1])

    @property
    def radius(
        self,
    ):
        """Half of ``kernel_size`` passed through ``to_even``; used as padding length."""
        return to_even(self.kernel_size // 2)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
if __name__ == "__main__":
|
|
169
|
+
import matplotlib.pyplot as plt
|
|
170
|
+
import scitex
|
|
171
|
+
|
|
172
|
+
xx, tt, fs = scitex.dsp.demo_sig(sig_type="chirp")
|
|
173
|
+
|
|
174
|
+
pha, amp, ff = scitex.dsp.wavelet(xx, fs)
|
|
175
|
+
|
|
176
|
+
fig, ax = scitex.plt.subplots()
|
|
177
|
+
ax.imshow2d(amp[0, 0].T)
|
|
178
|
+
ax = scitex.plt.ax.set_ticks(ax, xticks=tt, yticks=ff)
|
|
179
|
+
ax = scitex.plt.ax.set_n_ticks(ax)
|
|
180
|
+
plt.show()
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
# EOF
|
scitex/nn/__init__.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Scitex nn module."""
|
|
3
|
+
|
|
4
|
+
from ._AxiswiseDropout import AxiswiseDropout
|
|
5
|
+
from ._BNet import BHead, BNet, BNet_config
|
|
6
|
+
from ._BNet_Res import BHead, BNet, BNet_config
|
|
7
|
+
from ._ChannelGainChanger import ChannelGainChanger
|
|
8
|
+
from ._DropoutChannels import DropoutChannels
|
|
9
|
+
from ._Filters import BandPassFilter, BandStopFilter, BaseFilter1D, DifferentiableBandPassFilter, GaussianFilter, HighPassFilter, LowPassFilter
|
|
10
|
+
from ._FreqGainChanger import FreqGainChanger
|
|
11
|
+
from ._GaussianFilter import GaussianFilter
|
|
12
|
+
from ._Hilbert import Hilbert
|
|
13
|
+
from ._MNet_1000 import MNet1000, MNet_1000, MNet_config, ReshapeLayer, SwapLayer
|
|
14
|
+
from ._ModulationIndex import ModulationIndex
|
|
15
|
+
from ._PAC import PAC
|
|
16
|
+
from ._PSD import PSD
|
|
17
|
+
from ._ResNet1D import ResNet1D, ResNetBasicBlock
|
|
18
|
+
from ._SpatialAttention import SpatialAttention
|
|
19
|
+
from ._Spectrogram import Spectrogram, my_softmax, normalize, spectrograms, unbias
|
|
20
|
+
from ._SwapChannels import SwapChannels
|
|
21
|
+
from ._TransposeLayer import TransposeLayer
|
|
22
|
+
from ._Wavelet import Wavelet
|
|
23
|
+
|
|
24
|
+
# Public API of scitex.nn. Each name is listed exactly once.
# NOTE(review): BHead, BNet and BNet_config are imported from both ._BNet and
# ._BNet_Res, and GaussianFilter from both ._Filters and ._GaussianFilter; the
# later import rebinds the name, so these entries refer to the later module's
# objects -- confirm that this shadowing is intended.
__all__ = [
    "AxiswiseDropout",
    "BHead",
    "BNet",
    "BNet_config",
    "BandPassFilter",
    "BandStopFilter",
    "BaseFilter1D",
    "ChannelGainChanger",
    "DifferentiableBandPassFilter",
    "DropoutChannels",
    "FreqGainChanger",
    "GaussianFilter",
    "HighPassFilter",
    "Hilbert",
    "LowPassFilter",
    "MNet1000",
    "MNet_1000",
    "MNet_config",
    "ModulationIndex",
    "PAC",
    "PSD",
    "ResNet1D",
    "ResNetBasicBlock",
    "ReshapeLayer",
    "SpatialAttention",
    "Spectrogram",
    "SwapChannels",
    "SwapLayer",
    "TransposeLayer",
    "Wavelet",
    "my_softmax",
    "normalize",
    "spectrograms",
    "unbias",
]
|
scitex/os/__init__.py
ADDED
scitex/os/_mv.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-04-06 09:00:45 (ywatanabe)"
|
|
4
|
+
|
|
5
|
+
# import os
|
|
6
|
+
# import shutil
|
|
7
|
+
|
|
8
|
+
# def mv(src, tgt):
|
|
9
|
+
# successful = True
|
|
10
|
+
# os.makedirs(tgt, exist_ok=True)
|
|
11
|
+
|
|
12
|
+
# if os.path.isdir(src):
|
|
13
|
+
# # Iterate over the items in the directory
|
|
14
|
+
# for item in os.listdir(src):
|
|
15
|
+
# item_path = os.path.join(src, item)
|
|
16
|
+
# # Check if the item is a file
|
|
17
|
+
# if os.path.isfile(item_path):
|
|
18
|
+
# try:
|
|
19
|
+
# shutil.move(item_path, tgt)
|
|
20
|
+
# print(f"\nMoved file from {item_path} to {tgt}")
|
|
21
|
+
# except OSError as e:
|
|
22
|
+
# print(f"\nError: {e}")
|
|
23
|
+
# successful = False
|
|
24
|
+
# else:
|
|
25
|
+
# print(f"\nSkipped directory {item_path}")
|
|
26
|
+
# else:
|
|
27
|
+
# # If src is a file, just move it
|
|
28
|
+
# try:
|
|
29
|
+
# shutil.move(src, tgt)
|
|
30
|
+
# print(f"\nMoved from {src} to {tgt}")
|
|
31
|
+
# except OSError as e:
|
|
32
|
+
# print(f"\nError: {e}")
|
|
33
|
+
# successful = False
|
|
34
|
+
|
|
35
|
+
# return successful
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def mv(src, tgt):
    """Move ``src`` into the directory ``tgt`` and report whether it worked.

    ``tgt`` is always treated as a directory: it is created (including
    parents) if it does not exist, and ``src`` is then moved with
    ``shutil.move``. Progress and errors are printed to stdout.

    Parameters
    ----------
    src : str or path-like
        File or directory to move.
    tgt : str or path-like
        Destination directory.

    Returns
    -------
    bool
        ``True`` if the move succeeded, ``False`` if ``shutil.move`` raised
        an ``OSError`` (e.g. missing source, destination already exists).

    Raises
    ------
    OSError
        If ``tgt`` cannot be created (this happens before the guarded move).
    """
    import os
    import shutil

    os.makedirs(tgt, exist_ok=True)

    try:
        shutil.move(src, tgt)
        print(f"\nMoved from {src} to {tgt}")
    except OSError as e:
        print(f"\nError: {e}")
        return False

    # Previously the success flag was computed but never returned, so callers
    # always received None.
    return True
|
scitex/parallel/_run.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Time-stamp: "2024-11-14 23:12:20 (ywatanabe)"
|
|
4
|
+
# File: ./scitex_repo/src/scitex/parallel/_run.py
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
1. Functionality:
|
|
8
|
+
- Runs functions in parallel using ThreadPoolExecutor
|
|
9
|
+
- Handles both single and multiple return values
|
|
10
|
+
- Supports automatic CPU core detection
|
|
11
|
+
2. Input:
|
|
12
|
+
- Function to run
|
|
13
|
+
- List of items to process
|
|
14
|
+
- Optional parameters for execution control
|
|
15
|
+
3. Output:
|
|
16
|
+
- List of results, or a tuple of per-position lists when the function returns tuples
|
|
17
|
+
4. Prerequisites:
|
|
18
|
+
- concurrent.futures
|
|
19
|
+
- pandas
|
|
20
|
+
- tqdm
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import multiprocessing
|
|
24
|
+
import warnings
|
|
25
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
26
|
+
from typing import Any, Callable, List
|
|
27
|
+
|
|
28
|
+
from tqdm import tqdm
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def run(
    func: Callable,
    args_list: List[tuple],
    n_jobs: int = -1,
    desc: str = "Processing",
) -> List[Any]:
    """Call ``func(*args)`` for every tuple in ``args_list`` on a thread pool.

    Parameters
    ----------
    func : Callable
        Function to execute.
    args_list : List[tuple]
        One tuple of positional arguments per call.
    n_jobs : int, optional
        Number of worker threads. Any negative value selects
        ``multiprocessing.cpu_count()``.
    desc : str, optional
        Label shown on the tqdm progress bar.

    Returns
    -------
    List[Any]
        Results in the same order as ``args_list``. If the first result is a
        tuple, the results are transposed into a tuple of lists instead
        (one list per tuple position).

    Raises
    ------
    ValueError
        If ``args_list`` is empty, ``func`` is not callable, or ``n_jobs``
        is 0.

    Examples
    --------
    >>> def add(x, y):
    ...     return x + y
    >>> args_list = [(1, 4), (2, 5), (3, 6)]
    >>> run(add, args_list)
    [5, 7, 9]
    """
    if not args_list:
        raise ValueError("Args list cannot be empty")
    if not callable(func):
        raise ValueError("Func must be callable")

    cpu_count = multiprocessing.cpu_count()
    if n_jobs < 0:
        n_jobs = cpu_count

    if n_jobs > cpu_count:
        warnings.warn(f"n_jobs ({n_jobs}) is greater than CPU count ({cpu_count})")
    if n_jobs < 1:
        raise ValueError("n_jobs must be >= 1 or -1")

    n_tasks = len(args_list)
    # Slots are filled by submission index so completion order does not matter.
    results = [None] * n_tasks

    with ThreadPoolExecutor(max_workers=n_jobs) as executor:
        position_of = {}
        for position, call_args in enumerate(args_list):
            position_of[executor.submit(func, *call_args)] = position

        for finished in tqdm(as_completed(position_of), total=n_tasks, desc=desc):
            results[position_of[finished]] = finished.result()

    # If results contain multiple values (tuples), transpose them
    if results and isinstance(results[0], tuple):
        width = len(results[0])
        return tuple([row[col] for row in results] for col in range(width))

    return results
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# def run(
|
|
95
|
+
# func: Callable,
|
|
96
|
+
# items: List[Any],
|
|
97
|
+
# n_jobs: int = -1,
|
|
98
|
+
# desc: str = "Processing",
|
|
99
|
+
# ) -> List[Any]:
|
|
100
|
+
# """Runs function in parallel using ThreadPoolExecutor.
|
|
101
|
+
|
|
102
|
+
# Parameters
|
|
103
|
+
# ----------
|
|
104
|
+
# func : Callable
|
|
105
|
+
# Function to run in parallel
|
|
106
|
+
# items : List[Any]
|
|
107
|
+
# List of items to process
|
|
108
|
+
# n_jobs : int, optional
|
|
109
|
+
# Number of jobs to run in parallel. -1 means using all processors
|
|
110
|
+
# desc : str, optional
|
|
111
|
+
# Description for progress bar
|
|
112
|
+
|
|
113
|
+
# Returns
|
|
114
|
+
# -------
|
|
115
|
+
# List[Any]
|
|
116
|
+
# Results of parallel execution
|
|
117
|
+
# """
|
|
118
|
+
# if not items:
|
|
119
|
+
# raise ValueError("Items list cannot be empty")
|
|
120
|
+
# if not callable(func):
|
|
121
|
+
# raise ValueError("Func must be callable")
|
|
122
|
+
# if not isinstance(items, (list, tuple)):
|
|
123
|
+
# raise TypeError("Items must be a list or tuple")
|
|
124
|
+
# if not isinstance(n_jobs, int):
|
|
125
|
+
# raise TypeError("n_jobs must be an integer")
|
|
126
|
+
|
|
127
|
+
# cpu_count = multiprocessing.cpu_count()
|
|
128
|
+
# n_jobs = cpu_count if n_jobs < 0 else n_jobs
|
|
129
|
+
|
|
130
|
+
# if n_jobs > cpu_count:
|
|
131
|
+
# warnings.warn(f"n_jobs ({n_jobs}) is greater than CPU count ({cpu_count})")
|
|
132
|
+
# if n_jobs < 1:
|
|
133
|
+
# raise ValueError("n_jobs must be >= 1 or -1")
|
|
134
|
+
|
|
135
|
+
# results = [None] * len(items) # Pre-allocate list
|
|
136
|
+
# with ThreadPoolExecutor(max_workers=n_jobs) as executor:
|
|
137
|
+
# futures = {executor.submit(func, item): idx
|
|
138
|
+
# for idx, item in enumerate(items)}
|
|
139
|
+
# for future in tqdm(as_completed(futures), total=len(items), desc=desc):
|
|
140
|
+
# idx = futures[future]
|
|
141
|
+
# results[idx] = future.result()
|
|
142
|
+
|
|
143
|
+
# # If results contain multiple values (tuples), transpose them
|
|
144
|
+
# if results and isinstance(results[0], tuple):
|
|
145
|
+
# n_vars = len(results[0])
|
|
146
|
+
# return tuple([result[i] for result in results] for i in range(n_vars))
|
|
147
|
+
|
|
148
|
+
# return results
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# EOF
|