scitex 2.0.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (572)
  1. scitex/__init__.py +73 -0
  2. scitex/__main__.py +89 -0
  3. scitex/__version__.py +14 -0
  4. scitex/_sh.py +59 -0
  5. scitex/ai/_LearningCurveLogger.py +583 -0
  6. scitex/ai/__Classifiers.py +101 -0
  7. scitex/ai/__init__.py +55 -0
  8. scitex/ai/_gen_ai/_Anthropic.py +173 -0
  9. scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
  10. scitex/ai/_gen_ai/_DeepSeek.py +175 -0
  11. scitex/ai/_gen_ai/_Google.py +161 -0
  12. scitex/ai/_gen_ai/_Groq.py +97 -0
  13. scitex/ai/_gen_ai/_Llama.py +142 -0
  14. scitex/ai/_gen_ai/_OpenAI.py +230 -0
  15. scitex/ai/_gen_ai/_PARAMS.py +565 -0
  16. scitex/ai/_gen_ai/_Perplexity.py +191 -0
  17. scitex/ai/_gen_ai/__init__.py +32 -0
  18. scitex/ai/_gen_ai/_calc_cost.py +78 -0
  19. scitex/ai/_gen_ai/_format_output_func.py +183 -0
  20. scitex/ai/_gen_ai/_genai_factory.py +71 -0
  21. scitex/ai/act/__init__.py +8 -0
  22. scitex/ai/act/_define.py +11 -0
  23. scitex/ai/classification/__init__.py +7 -0
  24. scitex/ai/classification/classification_reporter.py +1137 -0
  25. scitex/ai/classification/classifier_server.py +131 -0
  26. scitex/ai/classification/classifiers.py +101 -0
  27. scitex/ai/classification_reporter.py +1161 -0
  28. scitex/ai/classifier_server.py +131 -0
  29. scitex/ai/clustering/__init__.py +11 -0
  30. scitex/ai/clustering/_pca.py +115 -0
  31. scitex/ai/clustering/_umap.py +376 -0
  32. scitex/ai/early_stopping.py +149 -0
  33. scitex/ai/feature_extraction/__init__.py +56 -0
  34. scitex/ai/feature_extraction/vit.py +148 -0
  35. scitex/ai/genai/__init__.py +277 -0
  36. scitex/ai/genai/anthropic.py +177 -0
  37. scitex/ai/genai/anthropic_provider.py +320 -0
  38. scitex/ai/genai/anthropic_refactored.py +109 -0
  39. scitex/ai/genai/auth_manager.py +200 -0
  40. scitex/ai/genai/base_genai.py +336 -0
  41. scitex/ai/genai/base_provider.py +291 -0
  42. scitex/ai/genai/calc_cost.py +78 -0
  43. scitex/ai/genai/chat_history.py +307 -0
  44. scitex/ai/genai/cost_tracker.py +276 -0
  45. scitex/ai/genai/deepseek.py +188 -0
  46. scitex/ai/genai/deepseek_provider.py +251 -0
  47. scitex/ai/genai/format_output_func.py +183 -0
  48. scitex/ai/genai/genai_factory.py +71 -0
  49. scitex/ai/genai/google.py +169 -0
  50. scitex/ai/genai/google_provider.py +228 -0
  51. scitex/ai/genai/groq.py +104 -0
  52. scitex/ai/genai/groq_provider.py +248 -0
  53. scitex/ai/genai/image_processor.py +250 -0
  54. scitex/ai/genai/llama.py +155 -0
  55. scitex/ai/genai/llama_provider.py +214 -0
  56. scitex/ai/genai/mock_provider.py +127 -0
  57. scitex/ai/genai/model_registry.py +304 -0
  58. scitex/ai/genai/openai.py +230 -0
  59. scitex/ai/genai/openai_provider.py +293 -0
  60. scitex/ai/genai/params.py +565 -0
  61. scitex/ai/genai/perplexity.py +202 -0
  62. scitex/ai/genai/perplexity_provider.py +205 -0
  63. scitex/ai/genai/provider_base.py +302 -0
  64. scitex/ai/genai/provider_factory.py +370 -0
  65. scitex/ai/genai/response_handler.py +235 -0
  66. scitex/ai/layer/_Pass.py +21 -0
  67. scitex/ai/layer/__init__.py +10 -0
  68. scitex/ai/layer/_switch.py +8 -0
  69. scitex/ai/loss/_L1L2Losses.py +34 -0
  70. scitex/ai/loss/__init__.py +12 -0
  71. scitex/ai/loss/multi_task_loss.py +47 -0
  72. scitex/ai/metrics/__init__.py +9 -0
  73. scitex/ai/metrics/_bACC.py +51 -0
  74. scitex/ai/metrics/silhoute_score_block.py +496 -0
  75. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
  76. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
  77. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
  78. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
  79. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
  80. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
  81. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
  82. scitex/ai/optim/__init__.py +13 -0
  83. scitex/ai/optim/_get_set.py +31 -0
  84. scitex/ai/optim/_optimizers.py +71 -0
  85. scitex/ai/plt/__init__.py +21 -0
  86. scitex/ai/plt/_conf_mat.py +592 -0
  87. scitex/ai/plt/_learning_curve.py +194 -0
  88. scitex/ai/plt/_optuna_study.py +111 -0
  89. scitex/ai/plt/aucs/__init__.py +2 -0
  90. scitex/ai/plt/aucs/example.py +60 -0
  91. scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
  92. scitex/ai/plt/aucs/roc_auc.py +246 -0
  93. scitex/ai/sampling/undersample.py +29 -0
  94. scitex/ai/sk/__init__.py +11 -0
  95. scitex/ai/sk/_clf.py +58 -0
  96. scitex/ai/sk/_to_sktime.py +100 -0
  97. scitex/ai/sklearn/__init__.py +26 -0
  98. scitex/ai/sklearn/clf.py +58 -0
  99. scitex/ai/sklearn/to_sktime.py +100 -0
  100. scitex/ai/training/__init__.py +7 -0
  101. scitex/ai/training/early_stopping.py +150 -0
  102. scitex/ai/training/learning_curve_logger.py +555 -0
  103. scitex/ai/utils/__init__.py +22 -0
  104. scitex/ai/utils/_check_params.py +50 -0
  105. scitex/ai/utils/_default_dataset.py +46 -0
  106. scitex/ai/utils/_format_samples_for_sktime.py +26 -0
  107. scitex/ai/utils/_label_encoder.py +134 -0
  108. scitex/ai/utils/_merge_labels.py +22 -0
  109. scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
  110. scitex/ai/utils/_under_sample.py +51 -0
  111. scitex/ai/utils/_verify_n_gpus.py +16 -0
  112. scitex/ai/utils/grid_search.py +148 -0
  113. scitex/context/__init__.py +9 -0
  114. scitex/context/_suppress_output.py +38 -0
  115. scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
  116. scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
  117. scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
  118. scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
  119. scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
  120. scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
  121. scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
  122. scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
  123. scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
  124. scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
  125. scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
  126. scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
  127. scitex/db/_BaseMixins/__init__.py +30 -0
  128. scitex/db/_PostgreSQL.py +126 -0
  129. scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
  130. scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
  131. scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
  132. scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
  133. scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
  134. scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
  135. scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
  136. scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
  137. scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
  138. scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
  139. scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
  140. scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
  141. scitex/db/_PostgreSQLMixins/__init__.py +34 -0
  142. scitex/db/_SQLite3.py +2136 -0
  143. scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
  144. scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
  145. scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
  146. scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
  147. scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
  148. scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
  149. scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
  150. scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
  151. scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
  152. scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
  153. scitex/db/_SQLite3Mixins/__init__.py +30 -0
  154. scitex/db/__init__.py +14 -0
  155. scitex/db/_delete_duplicates.py +397 -0
  156. scitex/db/_inspect.py +163 -0
  157. scitex/decorators/__init__.py +54 -0
  158. scitex/decorators/_auto_order.py +172 -0
  159. scitex/decorators/_batch_fn.py +127 -0
  160. scitex/decorators/_cache_disk.py +32 -0
  161. scitex/decorators/_cache_mem.py +12 -0
  162. scitex/decorators/_combined.py +98 -0
  163. scitex/decorators/_converters.py +282 -0
  164. scitex/decorators/_deprecated.py +26 -0
  165. scitex/decorators/_not_implemented.py +30 -0
  166. scitex/decorators/_numpy_fn.py +86 -0
  167. scitex/decorators/_pandas_fn.py +121 -0
  168. scitex/decorators/_preserve_doc.py +19 -0
  169. scitex/decorators/_signal_fn.py +95 -0
  170. scitex/decorators/_timeout.py +55 -0
  171. scitex/decorators/_torch_fn.py +136 -0
  172. scitex/decorators/_wrap.py +39 -0
  173. scitex/decorators/_xarray_fn.py +88 -0
  174. scitex/dev/__init__.py +15 -0
  175. scitex/dev/_analyze_code_flow.py +284 -0
  176. scitex/dev/_reload.py +59 -0
  177. scitex/dict/_DotDict.py +442 -0
  178. scitex/dict/__init__.py +18 -0
  179. scitex/dict/_listed_dict.py +42 -0
  180. scitex/dict/_pop_keys.py +36 -0
  181. scitex/dict/_replace.py +13 -0
  182. scitex/dict/_safe_merge.py +62 -0
  183. scitex/dict/_to_str.py +32 -0
  184. scitex/dsp/__init__.py +72 -0
  185. scitex/dsp/_crop.py +122 -0
  186. scitex/dsp/_demo_sig.py +331 -0
  187. scitex/dsp/_detect_ripples.py +212 -0
  188. scitex/dsp/_ensure_3d.py +18 -0
  189. scitex/dsp/_hilbert.py +78 -0
  190. scitex/dsp/_listen.py +702 -0
  191. scitex/dsp/_misc.py +30 -0
  192. scitex/dsp/_mne.py +32 -0
  193. scitex/dsp/_modulation_index.py +79 -0
  194. scitex/dsp/_pac.py +319 -0
  195. scitex/dsp/_psd.py +102 -0
  196. scitex/dsp/_resample.py +65 -0
  197. scitex/dsp/_time.py +36 -0
  198. scitex/dsp/_transform.py +68 -0
  199. scitex/dsp/_wavelet.py +212 -0
  200. scitex/dsp/add_noise.py +111 -0
  201. scitex/dsp/example.py +253 -0
  202. scitex/dsp/filt.py +155 -0
  203. scitex/dsp/norm.py +18 -0
  204. scitex/dsp/params.py +51 -0
  205. scitex/dsp/reference.py +43 -0
  206. scitex/dsp/template.py +25 -0
  207. scitex/dsp/utils/__init__.py +15 -0
  208. scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
  209. scitex/dsp/utils/_ensure_3d.py +18 -0
  210. scitex/dsp/utils/_ensure_even_len.py +10 -0
  211. scitex/dsp/utils/_zero_pad.py +48 -0
  212. scitex/dsp/utils/filter.py +408 -0
  213. scitex/dsp/utils/pac.py +177 -0
  214. scitex/dt/__init__.py +8 -0
  215. scitex/dt/_linspace.py +130 -0
  216. scitex/etc/__init__.py +15 -0
  217. scitex/etc/wait_key.py +34 -0
  218. scitex/gen/_DimHandler.py +196 -0
  219. scitex/gen/_TimeStamper.py +244 -0
  220. scitex/gen/__init__.py +95 -0
  221. scitex/gen/_alternate_kwarg.py +13 -0
  222. scitex/gen/_cache.py +11 -0
  223. scitex/gen/_check_host.py +34 -0
  224. scitex/gen/_ci.py +12 -0
  225. scitex/gen/_close.py +222 -0
  226. scitex/gen/_embed.py +78 -0
  227. scitex/gen/_inspect_module.py +257 -0
  228. scitex/gen/_is_ipython.py +12 -0
  229. scitex/gen/_less.py +48 -0
  230. scitex/gen/_list_packages.py +139 -0
  231. scitex/gen/_mat2py.py +88 -0
  232. scitex/gen/_norm.py +170 -0
  233. scitex/gen/_paste.py +18 -0
  234. scitex/gen/_print_config.py +84 -0
  235. scitex/gen/_shell.py +48 -0
  236. scitex/gen/_src.py +111 -0
  237. scitex/gen/_start.py +451 -0
  238. scitex/gen/_symlink.py +55 -0
  239. scitex/gen/_symlog.py +27 -0
  240. scitex/gen/_tee.py +238 -0
  241. scitex/gen/_title2path.py +60 -0
  242. scitex/gen/_title_case.py +88 -0
  243. scitex/gen/_to_even.py +84 -0
  244. scitex/gen/_to_odd.py +34 -0
  245. scitex/gen/_to_rank.py +39 -0
  246. scitex/gen/_transpose.py +37 -0
  247. scitex/gen/_type.py +78 -0
  248. scitex/gen/_var_info.py +73 -0
  249. scitex/gen/_wrap.py +17 -0
  250. scitex/gen/_xml2dict.py +76 -0
  251. scitex/gen/misc.py +730 -0
  252. scitex/gen/path.py +0 -0
  253. scitex/general/__init__.py +5 -0
  254. scitex/gists/_SigMacro_processFigure_S.py +128 -0
  255. scitex/gists/_SigMacro_toBlue.py +172 -0
  256. scitex/gists/__init__.py +12 -0
  257. scitex/io/_H5Explorer.py +292 -0
  258. scitex/io/__init__.py +82 -0
  259. scitex/io/_cache.py +101 -0
  260. scitex/io/_flush.py +24 -0
  261. scitex/io/_glob.py +103 -0
  262. scitex/io/_json2md.py +113 -0
  263. scitex/io/_load.py +168 -0
  264. scitex/io/_load_configs.py +146 -0
  265. scitex/io/_load_modules/__init__.py +38 -0
  266. scitex/io/_load_modules/_catboost.py +66 -0
  267. scitex/io/_load_modules/_con.py +20 -0
  268. scitex/io/_load_modules/_db.py +24 -0
  269. scitex/io/_load_modules/_docx.py +42 -0
  270. scitex/io/_load_modules/_eeg.py +110 -0
  271. scitex/io/_load_modules/_hdf5.py +196 -0
  272. scitex/io/_load_modules/_image.py +19 -0
  273. scitex/io/_load_modules/_joblib.py +19 -0
  274. scitex/io/_load_modules/_json.py +18 -0
  275. scitex/io/_load_modules/_markdown.py +103 -0
  276. scitex/io/_load_modules/_matlab.py +37 -0
  277. scitex/io/_load_modules/_numpy.py +39 -0
  278. scitex/io/_load_modules/_optuna.py +155 -0
  279. scitex/io/_load_modules/_pandas.py +69 -0
  280. scitex/io/_load_modules/_pdf.py +31 -0
  281. scitex/io/_load_modules/_pickle.py +24 -0
  282. scitex/io/_load_modules/_torch.py +16 -0
  283. scitex/io/_load_modules/_txt.py +126 -0
  284. scitex/io/_load_modules/_xml.py +49 -0
  285. scitex/io/_load_modules/_yaml.py +23 -0
  286. scitex/io/_mv_to_tmp.py +19 -0
  287. scitex/io/_path.py +286 -0
  288. scitex/io/_reload.py +78 -0
  289. scitex/io/_save.py +539 -0
  290. scitex/io/_save_modules/__init__.py +66 -0
  291. scitex/io/_save_modules/_catboost.py +22 -0
  292. scitex/io/_save_modules/_csv.py +89 -0
  293. scitex/io/_save_modules/_excel.py +49 -0
  294. scitex/io/_save_modules/_hdf5.py +249 -0
  295. scitex/io/_save_modules/_html.py +48 -0
  296. scitex/io/_save_modules/_image.py +140 -0
  297. scitex/io/_save_modules/_joblib.py +25 -0
  298. scitex/io/_save_modules/_json.py +25 -0
  299. scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
  300. scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
  301. scitex/io/_save_modules/_matlab.py +24 -0
  302. scitex/io/_save_modules/_mp4.py +29 -0
  303. scitex/io/_save_modules/_numpy.py +57 -0
  304. scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
  305. scitex/io/_save_modules/_pickle.py +45 -0
  306. scitex/io/_save_modules/_plotly.py +27 -0
  307. scitex/io/_save_modules/_text.py +23 -0
  308. scitex/io/_save_modules/_torch.py +26 -0
  309. scitex/io/_save_modules/_yaml.py +29 -0
  310. scitex/life/__init__.py +10 -0
  311. scitex/life/_monitor_rain.py +49 -0
  312. scitex/linalg/__init__.py +17 -0
  313. scitex/linalg/_distance.py +63 -0
  314. scitex/linalg/_geometric_median.py +64 -0
  315. scitex/linalg/_misc.py +73 -0
  316. scitex/nn/_AxiswiseDropout.py +27 -0
  317. scitex/nn/_BNet.py +126 -0
  318. scitex/nn/_BNet_Res.py +164 -0
  319. scitex/nn/_ChannelGainChanger.py +44 -0
  320. scitex/nn/_DropoutChannels.py +50 -0
  321. scitex/nn/_Filters.py +489 -0
  322. scitex/nn/_FreqGainChanger.py +110 -0
  323. scitex/nn/_GaussianFilter.py +48 -0
  324. scitex/nn/_Hilbert.py +111 -0
  325. scitex/nn/_MNet_1000.py +157 -0
  326. scitex/nn/_ModulationIndex.py +221 -0
  327. scitex/nn/_PAC.py +414 -0
  328. scitex/nn/_PSD.py +40 -0
  329. scitex/nn/_ResNet1D.py +120 -0
  330. scitex/nn/_SpatialAttention.py +25 -0
  331. scitex/nn/_Spectrogram.py +161 -0
  332. scitex/nn/_SwapChannels.py +50 -0
  333. scitex/nn/_TransposeLayer.py +19 -0
  334. scitex/nn/_Wavelet.py +183 -0
  335. scitex/nn/__init__.py +63 -0
  336. scitex/os/__init__.py +8 -0
  337. scitex/os/_mv.py +50 -0
  338. scitex/parallel/__init__.py +8 -0
  339. scitex/parallel/_run.py +151 -0
  340. scitex/path/__init__.py +33 -0
  341. scitex/path/_clean.py +52 -0
  342. scitex/path/_find.py +108 -0
  343. scitex/path/_get_module_path.py +51 -0
  344. scitex/path/_get_spath.py +35 -0
  345. scitex/path/_getsize.py +18 -0
  346. scitex/path/_increment_version.py +87 -0
  347. scitex/path/_mk_spath.py +51 -0
  348. scitex/path/_path.py +19 -0
  349. scitex/path/_split.py +23 -0
  350. scitex/path/_this_path.py +19 -0
  351. scitex/path/_version.py +101 -0
  352. scitex/pd/__init__.py +41 -0
  353. scitex/pd/_find_indi.py +126 -0
  354. scitex/pd/_find_pval.py +113 -0
  355. scitex/pd/_force_df.py +154 -0
  356. scitex/pd/_from_xyz.py +71 -0
  357. scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
  358. scitex/pd/_melt_cols.py +81 -0
  359. scitex/pd/_merge_columns.py +221 -0
  360. scitex/pd/_mv.py +63 -0
  361. scitex/pd/_replace.py +62 -0
  362. scitex/pd/_round.py +93 -0
  363. scitex/pd/_slice.py +63 -0
  364. scitex/pd/_sort.py +91 -0
  365. scitex/pd/_to_numeric.py +53 -0
  366. scitex/pd/_to_xy.py +59 -0
  367. scitex/pd/_to_xyz.py +110 -0
  368. scitex/plt/__init__.py +36 -0
  369. scitex/plt/_subplots/_AxesWrapper.py +182 -0
  370. scitex/plt/_subplots/_AxisWrapper.py +249 -0
  371. scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
  372. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
  373. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
  374. scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
  375. scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
  376. scitex/plt/_subplots/_FigWrapper.py +226 -0
  377. scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
  378. scitex/plt/_subplots/__init__.py +111 -0
  379. scitex/plt/_subplots/_export_as_csv.py +232 -0
  380. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
  381. scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
  382. scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
  383. scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
  384. scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
  385. scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
  386. scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
  387. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
  388. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
  389. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
  390. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
  391. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
  392. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
  393. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
  394. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
  395. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
  396. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
  397. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
  398. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
  399. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
  400. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
  401. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
  402. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
  403. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
  404. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
  405. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
  406. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
  407. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
  408. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
  409. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
  410. scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
  411. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
  412. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
  413. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
  414. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
  415. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
  416. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
  417. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
  418. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
  419. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
  420. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
  421. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
  422. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
  423. scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
  424. scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
  425. scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
  426. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
  427. scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
  428. scitex/plt/_tpl.py +28 -0
  429. scitex/plt/ax/__init__.py +114 -0
  430. scitex/plt/ax/_plot/__init__.py +53 -0
  431. scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
  432. scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
  433. scitex/plt/ax/_plot/_plot_cube.py +57 -0
  434. scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
  435. scitex/plt/ax/_plot/_plot_fillv.py +55 -0
  436. scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
  437. scitex/plt/ax/_plot/_plot_image.py +94 -0
  438. scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
  439. scitex/plt/ax/_plot/_plot_raster.py +172 -0
  440. scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
  441. scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
  442. scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
  443. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
  444. scitex/plt/ax/_plot/_plot_violin.py +343 -0
  445. scitex/plt/ax/_style/__init__.py +38 -0
  446. scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
  447. scitex/plt/ax/_style/_add_panel.py +92 -0
  448. scitex/plt/ax/_style/_extend.py +64 -0
  449. scitex/plt/ax/_style/_force_aspect.py +37 -0
  450. scitex/plt/ax/_style/_format_label.py +23 -0
  451. scitex/plt/ax/_style/_hide_spines.py +84 -0
  452. scitex/plt/ax/_style/_map_ticks.py +182 -0
  453. scitex/plt/ax/_style/_rotate_labels.py +215 -0
  454. scitex/plt/ax/_style/_sci_note.py +279 -0
  455. scitex/plt/ax/_style/_set_log_scale.py +299 -0
  456. scitex/plt/ax/_style/_set_meta.py +261 -0
  457. scitex/plt/ax/_style/_set_n_ticks.py +37 -0
  458. scitex/plt/ax/_style/_set_size.py +16 -0
  459. scitex/plt/ax/_style/_set_supxyt.py +116 -0
  460. scitex/plt/ax/_style/_set_ticks.py +276 -0
  461. scitex/plt/ax/_style/_set_xyt.py +121 -0
  462. scitex/plt/ax/_style/_share_axes.py +264 -0
  463. scitex/plt/ax/_style/_shift.py +139 -0
  464. scitex/plt/ax/_style/_show_spines.py +333 -0
  465. scitex/plt/color/_PARAMS.py +70 -0
  466. scitex/plt/color/__init__.py +52 -0
  467. scitex/plt/color/_add_hue_col.py +41 -0
  468. scitex/plt/color/_colors.py +205 -0
  469. scitex/plt/color/_get_colors_from_cmap.py +134 -0
  470. scitex/plt/color/_interpolate.py +29 -0
  471. scitex/plt/color/_vizualize_colors.py +54 -0
  472. scitex/plt/utils/__init__.py +44 -0
  473. scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
  474. scitex/plt/utils/_calc_nice_ticks.py +101 -0
  475. scitex/plt/utils/_close.py +68 -0
  476. scitex/plt/utils/_colorbar.py +96 -0
  477. scitex/plt/utils/_configure_mpl.py +295 -0
  478. scitex/plt/utils/_histogram_utils.py +132 -0
  479. scitex/plt/utils/_im2grid.py +70 -0
  480. scitex/plt/utils/_is_valid_axis.py +78 -0
  481. scitex/plt/utils/_mk_colorbar.py +65 -0
  482. scitex/plt/utils/_mk_patches.py +26 -0
  483. scitex/plt/utils/_scientific_captions.py +638 -0
  484. scitex/plt/utils/_scitex_config.py +223 -0
  485. scitex/reproduce/__init__.py +14 -0
  486. scitex/reproduce/_fix_seeds.py +45 -0
  487. scitex/reproduce/_gen_ID.py +55 -0
  488. scitex/reproduce/_gen_timestamp.py +35 -0
  489. scitex/res/__init__.py +5 -0
  490. scitex/resource/__init__.py +13 -0
  491. scitex/resource/_get_processor_usages.py +281 -0
  492. scitex/resource/_get_specs.py +280 -0
  493. scitex/resource/_log_processor_usages.py +190 -0
  494. scitex/resource/_utils/__init__.py +31 -0
  495. scitex/resource/_utils/_get_env_info.py +481 -0
  496. scitex/resource/limit_ram.py +33 -0
  497. scitex/scholar/__init__.py +24 -0
  498. scitex/scholar/_local_search.py +454 -0
  499. scitex/scholar/_paper.py +244 -0
  500. scitex/scholar/_pdf_downloader.py +325 -0
  501. scitex/scholar/_search.py +393 -0
  502. scitex/scholar/_vector_search.py +370 -0
  503. scitex/scholar/_web_sources.py +457 -0
  504. scitex/stats/__init__.py +31 -0
  505. scitex/stats/_calc_partial_corr.py +17 -0
  506. scitex/stats/_corr_test_multi.py +94 -0
  507. scitex/stats/_corr_test_wrapper.py +115 -0
  508. scitex/stats/_describe_wrapper.py +90 -0
  509. scitex/stats/_multiple_corrections.py +63 -0
  510. scitex/stats/_nan_stats.py +93 -0
  511. scitex/stats/_p2stars.py +116 -0
  512. scitex/stats/_p2stars_wrapper.py +56 -0
  513. scitex/stats/_statistical_tests.py +73 -0
  514. scitex/stats/desc/__init__.py +40 -0
  515. scitex/stats/desc/_describe.py +189 -0
  516. scitex/stats/desc/_nan.py +289 -0
  517. scitex/stats/desc/_real.py +94 -0
  518. scitex/stats/multiple/__init__.py +14 -0
  519. scitex/stats/multiple/_bonferroni_correction.py +72 -0
  520. scitex/stats/multiple/_fdr_correction.py +400 -0
  521. scitex/stats/multiple/_multicompair.py +28 -0
  522. scitex/stats/tests/__corr_test.py +277 -0
  523. scitex/stats/tests/__corr_test_multi.py +343 -0
  524. scitex/stats/tests/__corr_test_single.py +277 -0
  525. scitex/stats/tests/__init__.py +22 -0
  526. scitex/stats/tests/_brunner_munzel_test.py +192 -0
  527. scitex/stats/tests/_nocorrelation_test.py +28 -0
  528. scitex/stats/tests/_smirnov_grubbs.py +98 -0
  529. scitex/str/__init__.py +113 -0
  530. scitex/str/_clean_path.py +75 -0
  531. scitex/str/_color_text.py +52 -0
  532. scitex/str/_decapitalize.py +58 -0
  533. scitex/str/_factor_out_digits.py +281 -0
  534. scitex/str/_format_plot_text.py +498 -0
  535. scitex/str/_grep.py +48 -0
  536. scitex/str/_latex.py +155 -0
  537. scitex/str/_latex_fallback.py +471 -0
  538. scitex/str/_mask_api.py +39 -0
  539. scitex/str/_mask_api_key.py +8 -0
  540. scitex/str/_parse.py +158 -0
  541. scitex/str/_print_block.py +47 -0
  542. scitex/str/_print_debug.py +68 -0
  543. scitex/str/_printc.py +62 -0
  544. scitex/str/_readable_bytes.py +38 -0
  545. scitex/str/_remove_ansi.py +23 -0
  546. scitex/str/_replace.py +134 -0
  547. scitex/str/_search.py +125 -0
  548. scitex/str/_squeeze_space.py +36 -0
  549. scitex/tex/__init__.py +10 -0
  550. scitex/tex/_preview.py +103 -0
  551. scitex/tex/_to_vec.py +116 -0
  552. scitex/torch/__init__.py +18 -0
  553. scitex/torch/_apply_to.py +34 -0
  554. scitex/torch/_nan_funcs.py +77 -0
  555. scitex/types/_ArrayLike.py +44 -0
  556. scitex/types/_ColorLike.py +21 -0
  557. scitex/types/__init__.py +14 -0
  558. scitex/types/_is_listed_X.py +70 -0
  559. scitex/utils/__init__.py +22 -0
  560. scitex/utils/_compress_hdf5.py +116 -0
  561. scitex/utils/_email.py +120 -0
  562. scitex/utils/_grid.py +148 -0
  563. scitex/utils/_notify.py +247 -0
  564. scitex/utils/_search.py +121 -0
  565. scitex/web/__init__.py +38 -0
  566. scitex/web/_search_pubmed.py +438 -0
  567. scitex/web/_summarize_url.py +158 -0
  568. scitex-2.0.0.dist-info/METADATA +307 -0
  569. scitex-2.0.0.dist-info/RECORD +572 -0
  570. scitex-2.0.0.dist-info/WHEEL +6 -0
  571. scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
  572. scitex-2.0.0.dist-info/top_level.txt +1 -0
scitex/nn/_Filters.py ADDED
@@ -0,0 +1,489 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Timestamp: "2025-05-28 17:05:26 (ywatanabe)"
4
+ # File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/nn/_Filters.py
5
+ # ----------------------------------------
6
+ import os
7
+
8
__FILE__ = "./src/scitex/nn/_Filters.py"
__DIR__ = os.path.split(__FILE__)[0]
# ----------------------------------------

# Time-stamp: "2024-11-26 22:23:40 (ywatanabe)"

import numpy as np

# Absolute path of this module on the author's machine; not referenced
# anywhere else in this module.
THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/nn/_Filters.py"

# NOTE: the string below follows other statements, so Python does not treat
# it as the module ``__doc__``; it is kept as an in-file description.
"""
Implements various neural network filter layers:
- BaseFilter1D: Abstract base class for 1D filters
- BandPassFilter: Implements bandpass filtering
- BandStopFilter: Implements bandstop filtering
- LowPassFilter: Implements lowpass filtering
- HighPassFilter: Implements highpass filtering
- GaussianFilter: Implements Gaussian smoothing
- DifferentiableBandPassFilter: Implements learnable bandpass filtering
"""
28
+
29
+ # Imports
30
+ import sys
31
+ from abc import abstractmethod
32
+
33
+ import matplotlib.pyplot as plt
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+ from ..dsp.utils import build_bandpass_filters, init_bandpass_filters
40
+ from ..dsp.utils._ensure_3d import ensure_3d
41
+ from ..dsp.utils._ensure_even_len import ensure_even_len
42
+ from ..dsp.utils._zero_pad import zero_pad
43
+ from ..dsp.utils.filter import design_filter
44
+ from ..gen._to_even import to_even
45
+
46
+
47
class BaseFilter1D(nn.Module):
    """Base class for fixed-kernel 1D filter banks.

    Subclasses register a ``kernels`` buffer holding one kernel per row
    (``(n_kernels, kernel_len)``, as required by ``batch_conv``); ``forward``
    then applies every kernel to every channel of the input.
    """

    def __init__(self, fp16=False, in_place=False):
        super().__init__()
        self.fp16 = fp16
        self.in_place = in_place  # stored only; not read by this class

    @abstractmethod
    def init_kernels(
        self,
    ):
        """
        Abstract method to initialize filter kernels.
        Must be implemented by subclasses.

        NOTE: ``nn.Module`` does not use ``ABCMeta``, so this decorator is
        not enforced at instantiation time; it only documents the contract.
        """
        pass

    def forward(self, x, t=None, edge_len=0):
        """Apply every kernel to every channel of ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input signal. It is passed through ``ensure_3d`` and is then
            expected to be ``(batch_size, n_chs, seq_len)``.
        t : optional
            Time axis. If given, it is trimmed with ``remove_edges`` exactly
            like the filtered signal and returned alongside it.
        edge_len : int or "auto"
            Samples removed from both ends of the output (``"auto"`` means
            one eighth of the last-axis length; see ``remove_edges``).

        Returns
        -------
        torch.Tensor
            ``(batch_size, n_chs, n_kernels, seq_len)`` before edge removal,
            or the tuple ``(filtered, t)`` when ``t`` is given.

        Raises
        ------
        ValueError
            If no ``kernels`` buffer has been set.
        """
        if self.fp16:
            x = x.half()

        x = ensure_3d(x)
        batch_size, n_chs, seq_len = x.shape

        # getattr with a default: if a subclass never registered ``kernels``,
        # nn.Module.__getattr__ would raise AttributeError instead of the
        # intended ValueError.
        if getattr(self, "kernels", None) is None:
            raise ValueError("Filter kernels has not been initialized.")

        # Mirror-pad by half a kernel on each side, convolve without padding,
        # then keep the first seq_len samples.
        x = self.flip_extend(x, self.kernel_size // 2)
        x = self.batch_conv(x, self.kernels, padding=0)
        x = x[..., :seq_len]

        assert x.shape == (
            batch_size,
            n_chs,
            len(self.kernels),
            seq_len,
        ), f"The shape of the filtered signal ({x.shape}) does not match the expected shape: ({batch_size}, {n_chs}, {len(self.kernels)}, {seq_len})."

        x = self.remove_edges(x, edge_len)

        if t is None:
            return x
        t = self.remove_edges(t, edge_len)
        return x, t

    @property
    def kernel_size(
        self,
    ):
        """Length of the kernels along their last axis."""
        return self.kernels.shape[-1]

    @staticmethod
    def flip_extend(x, extension_length):
        """Mirror-pad the last axis of a 3D ``x`` by ``extension_length`` per side.

        A non-positive ``extension_length`` returns ``x`` unchanged.
        """
        # Guard: with 0, ``x[:, :, -0:]`` would select the WHOLE signal and
        # append a full reversed copy instead of nothing.
        if extension_length <= 0:
            return x
        first_segment = x[:, :, :extension_length].flip(dims=[-1])
        last_segment = x[:, :, -extension_length:].flip(dims=[-1])
        return torch.cat([first_segment, x, last_segment], dim=-1)

    @staticmethod
    def batch_conv(x, kernels, padding="same"):
        """Convolve every channel of ``x`` with every kernel.

        x: (batch_size, n_chs, seq_len)
        kernels: (n_kernels, seq_len_filt)

        Returns a tensor of shape ``(batch_size, n_chs, n_kernels, out_len)``.
        """
        assert x.ndim == 3
        assert kernels.ndim == 2
        batch_size, n_chs, _ = x.shape
        # Treat every (batch, channel) pair as an independent 1-channel signal.
        x = x.reshape(-1, x.shape[-1]).unsqueeze(1)
        kernels = kernels.unsqueeze(1)  # add the in-channel dimension
        n_kernels = len(kernels)
        filted = F.conv1d(x, kernels.type_as(x), padding=padding)
        return filted.reshape(batch_size, n_chs, n_kernels, -1)

    @staticmethod
    def remove_edges(x, edge_len):
        """Drop ``edge_len`` samples from both ends of the last axis.

        ``"auto"`` uses ``x.shape[-1] // 8``; non-positive values return ``x``.
        """
        edge_len = x.shape[-1] // 8 if edge_len == "auto" else edge_len

        if 0 < edge_len:
            return x[..., edge_len:-edge_len]
        return x
137
+
138
+
139
class BandPassFilter(BaseFilter1D):
    """Bank of FIR band-pass kernels, one per ``(low_hz, high_hz)`` row of ``bands``.

    Parameters
    ----------
    bands : np.ndarray or torch.Tensor
        2D array with one ``(low_hz, high_hz)`` pair per row.
    fs : float
        Sampling rate in Hz.
    seq_len : int
        Signal length forwarded to ``design_filter``.
    fp16 : bool
        Store the kernels (and cast inputs) in half precision.
    """

    def __init__(self, bands, fs, seq_len, fp16=False):
        super().__init__(fp16=fp16)

        self.fp16 = fp16

        # One (low, high) pair per row.
        assert bands.ndim == 2

        nyquist = fs / 2.0
        if isinstance(bands, np.ndarray):
            bands = torch.tensor(bands)
        # Pull every edge into [0.1, nyquist - 1] before validating the pairs.
        bands = torch.clip(bands, 0.1, nyquist - 1)
        for low, high in bands:
            assert 0 < low
            assert low < high
            assert high < nyquist

        kernels = self.init_kernels(seq_len, fs, bands)
        self.register_buffer("kernels", kernels.half() if fp16 else kernels)

    @staticmethod
    def init_kernels(seq_len, fs, bands):
        """Design one band-pass kernel per band and stack them into one tensor.

        The kernels are zero-padded to a common length by ``zero_pad`` and
        then passed through ``ensure_even_len``.
        """
        # design_filter is given 1-element numpy arrays rather than scalars.
        seq_len_arr = np.array([seq_len])
        fs_arr = np.array([fs])
        designed = []
        for low, high in bands:
            coeffs = design_filter(
                seq_len_arr,
                fs_arr,
                low_hz=low,
                high_hz=high,
                is_bandstop=False,
            )
            designed.append(
                coeffs if isinstance(coeffs, torch.Tensor) else torch.tensor(coeffs)
            )

        stacked = ensure_even_len(zero_pad(designed))
        if not isinstance(stacked, torch.Tensor):
            stacked = torch.tensor(stacked)
        return stacked.clone().detach()
196
+
197
+
198
+ # /home/ywatanabe/proj/scitex/src/scitex/nn/_Filters.py:155: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
199
+ # kernels = torch.tensor(kernels).clone().detach()
200
+
201
+
202
class BandStopFilter(BaseFilter1D):
    """Bank of FIR band-stop kernels, one per ``(low_hz, high_hz)`` row of ``bands``.

    Parameters
    ----------
    bands : array-like with ``.ndim``
        2D array with one ``(low_hz, high_hz)`` pair per row.
    fs : float
        Sampling rate in Hz.
    seq_len : int
        Signal length forwarded to ``design_filter``.
    """

    def __init__(self, bands, fs, seq_len):
        super().__init__()

        # One (low, high) pair per row.
        assert bands.ndim == 2

        nyquist = fs / 2.0
        # Pull every edge into [0.1, nyquist - 1] before validating the pairs.
        bands = np.clip(bands, 0.1, nyquist - 1)
        for low, high in bands:
            assert 0 < low
            assert low < high
            assert high < nyquist

        kernels = self.init_kernels(seq_len, fs, bands)
        self.register_buffer("kernels", kernels)

    @staticmethod
    def init_kernels(seq_len, fs, bands):
        """Design one band-stop kernel per band and stack them into one tensor."""
        # design_filter is given 1-element numpy arrays rather than scalars.
        seq_len_arr = np.array([seq_len])
        fs_arr = np.array([fs])
        designed = [
            design_filter(
                seq_len_arr, fs_arr, low_hz=low, high_hz=high, is_bandstop=True
            )
            for low, high in bands
        ]
        as_tensors = [
            d if isinstance(d, torch.Tensor) else torch.tensor(d) for d in designed
        ]
        stacked = ensure_even_len(zero_pad(as_tensors))
        return stacked if isinstance(stacked, torch.Tensor) else torch.tensor(stacked)
239
+
240
+
241
class LowPassFilter(BaseFilter1D):
    """Bank of FIR low-pass kernels, one per cutoff frequency.

    Parameters
    ----------
    cutoffs_hz : 1D array-like
        Cutoff frequencies in Hz; each must satisfy ``0 < cutoff < fs / 2``.
        Plain Python sequences are converted with ``np.asarray``.
    fs : float
        Sampling rate in Hz.
    seq_len : int
        Signal length forwarded to ``design_filter``.
    """

    def __init__(self, cutoffs_hz, fs, seq_len):
        super().__init__()

        # Accept lists/tuples as well; arrays and tensors are used unchanged.
        if not hasattr(cutoffs_hz, "ndim"):
            cutoffs_hz = np.asarray(cutoffs_hz)
        assert cutoffs_hz.ndim == 1

        # Cutoffs are validated, not clipped: they reach design_filter as given.
        nyq = fs / 2.0
        for cc in cutoffs_hz:
            assert 0 < cc
            assert cc < nyq

        self.register_buffer("kernels", self.init_kernels(seq_len, fs, cutoffs_hz))

    @staticmethod
    def init_kernels(seq_len, fs, cutoffs_hz):
        """Design one low-pass kernel per cutoff and stack them into one tensor."""
        # design_filter is given 1-element numpy arrays rather than scalars.
        seq_len_array = np.array([seq_len])
        fs_array = np.array([fs])
        filters = [
            design_filter(
                seq_len_array, fs_array, low_hz=None, high_hz=cc, is_bandstop=False
            )
            for cc in cutoffs_hz
        ]
        filters_tensors = [
            f if isinstance(f, torch.Tensor) else torch.tensor(f) for f in filters
        ]
        kernels = ensure_even_len(zero_pad(filters_tensors))
        if not isinstance(kernels, torch.Tensor):
            kernels = torch.tensor(kernels)
        return kernels
277
+
278
+
279
class HighPassFilter(BaseFilter1D):
    """Bank of FIR high-pass kernels, one per cutoff frequency.

    Parameters
    ----------
    cutoffs_hz : 1D array-like
        Cutoff frequencies in Hz; each must satisfy ``0 < cutoff < fs / 2``.
        Plain Python sequences are converted with ``np.asarray``.
    fs : float
        Sampling rate in Hz.
    seq_len : int
        Signal length forwarded to ``design_filter``.
    """

    def __init__(self, cutoffs_hz, fs, seq_len):
        super().__init__()

        # Accept lists/tuples as well; arrays and tensors are used unchanged.
        if not hasattr(cutoffs_hz, "ndim"):
            cutoffs_hz = np.asarray(cutoffs_hz)
        assert cutoffs_hz.ndim == 1

        # Cutoffs are validated, not clipped: they reach design_filter as given.
        nyq = fs / 2.0
        for cc in cutoffs_hz:
            assert 0 < cc
            assert cc < nyq

        self.register_buffer("kernels", self.init_kernels(seq_len, fs, cutoffs_hz))

    @staticmethod
    def init_kernels(seq_len, fs, cutoffs_hz):
        """Design one high-pass kernel per cutoff and stack them into one tensor."""
        # design_filter is given 1-element numpy arrays rather than scalars.
        seq_len_array = np.array([seq_len])
        fs_array = np.array([fs])
        filters = [
            design_filter(
                seq_len_array, fs_array, low_hz=cc, high_hz=None, is_bandstop=False
            )
            for cc in cutoffs_hz
        ]
        filters_tensors = [
            f if isinstance(f, torch.Tensor) else torch.tensor(f) for f in filters
        ]
        kernels = ensure_even_len(zero_pad(filters_tensors))
        if not isinstance(kernels, torch.Tensor):
            kernels = torch.tensor(kernels)
        return kernels
315
+
316
+
317
class GaussianFilter(BaseFilter1D):
    """Single-kernel Gaussian smoothing filter built on ``BaseFilter1D``.

    Parameters
    ----------
    sigma : int or float
        Standard deviation of the Gaussian, in samples.
    """

    def __init__(self, sigma):
        super().__init__()
        self.sigma = to_even(sigma)
        # NOTE(review): the kernel is built from the raw ``sigma`` rather than
        # from ``self.sigma`` (the ``to_even`` result) — confirm which is intended.
        self.register_buffer("kernels", self.init_kernels(sigma))

    @staticmethod
    def init_kernels(sigma):
        """Return a 1-row Gaussian kernel spanning about +/- 3 sigma.

        The kernel is normalised to sum to 1 before ``ensure_even_len`` is
        applied.
        """
        kernel_size = sigma * 6  # +/- 3SD
        kernel_range = torch.arange(0, kernel_size) - kernel_size // 2
        kernel = torch.exp(-0.5 * (kernel_range / sigma) ** 2)
        kernel /= kernel.sum()
        kernels = kernel.unsqueeze(0)  # n_filters = 1
        kernels = ensure_even_len(kernels)
        # torch.as_tensor leaves an existing tensor as-is instead of
        # copy-constructing it, which torch.tensor() warns about.
        return torch.as_tensor(kernels)
332
+
333
+
334
class DifferentiableBandPassFilter(BaseFilter1D):
    """Band-pass filter bank whose kernels are rebuilt from band centres on every call.

    The initial kernels and the phase/amplitude band centres (``pha_mids``,
    ``amp_mids``) come from ``init_bandpass_filters``. ``forward`` rebuilds the
    kernels from the current centres with ``build_bandpass_filters`` before
    filtering, so gradients can reach the centres if they require grad.

    Parameters
    ----------
    sig_len : int
        Signal length forwarded to the kernel builders.
    fs : float
        Sampling rate in Hz.
    pha_low_hz, pha_high_hz, pha_n_bands :
        Range and count of the phase bands.
    amp_low_hz, amp_high_hz, amp_n_bands :
        Range and count of the amplitude bands.
    cycle : int
        Forwarded to the kernel builders.
    fp16 : bool
        Store the initial kernels (and cast inputs) in half precision.
    """

    def __init__(
        self,
        sig_len,
        fs,
        pha_low_hz=2,
        pha_high_hz=20,
        pha_n_bands=30,
        amp_low_hz=80,
        amp_high_hz=160,
        amp_n_bands=50,
        cycle=3,
        fp16=False,
    ):
        super().__init__(fp16=fp16)

        # Attributes (the raw, unclipped ranges are what ``forward`` clamps to)
        self.pha_low_hz = pha_low_hz
        self.pha_high_hz = pha_high_hz
        self.amp_low_hz = amp_low_hz
        self.amp_high_hz = amp_high_hz
        self.sig_len = sig_len
        self.fs = fs
        self.cycle = cycle
        self.fp16 = fp16

        # Check bands definitions
        nyq = fs / 2.0
        pha_high_hz = torch.tensor(pha_high_hz).clip(0.1, nyq - 1)
        pha_low_hz = torch.tensor(pha_low_hz).clip(0.1, pha_high_hz - 1)
        amp_high_hz = torch.tensor(amp_high_hz).clip(0.1, nyq - 1)
        amp_low_hz = torch.tensor(amp_low_hz).clip(0.1, amp_high_hz - 1)

        assert pha_low_hz < pha_high_hz < nyq
        assert amp_low_hz < amp_high_hz < nyq

        # Prepare kernels
        self.init_kernels = init_bandpass_filters
        self.build_bandpass_filters = build_bandpass_filters
        kernels, self.pha_mids, self.amp_mids = self.init_kernels(
            sig_len=sig_len,
            fs=fs,
            pha_low_hz=pha_low_hz,
            pha_high_hz=pha_high_hz,
            pha_n_bands=pha_n_bands,
            amp_low_hz=amp_low_hz,
            amp_high_hz=amp_high_hz,
            amp_n_bands=amp_n_bands,
            cycle=cycle,
        )

        self.register_buffer(
            "kernels",
            kernels,
        )

        if fp16:
            self.kernels = self.kernels.half()

    def forward(self, x, t=None, edge_len=0):
        """Rebuild the kernels from the (clamped) band centres, then filter ``x``.

        See ``BaseFilter1D.forward`` for ``x``, ``t`` and ``edge_len``.
        """
        # Constrain the centre frequencies to their configured ranges.
        # torch.clip is out-of-place, so its result must actually be used;
        # it stays differentiable and leaves the stored centres untouched.
        pha_mids = torch.clip(self.pha_mids, self.pha_low_hz, self.pha_high_hz)
        amp_mids = torch.clip(self.amp_mids, self.amp_low_hz, self.amp_high_hz)

        self.kernels = self.build_bandpass_filters(
            self.sig_len, self.fs, pha_mids, amp_mids, self.cycle
        )
        return super().forward(x=x, t=t, edge_len=edge_len)
408
+
409
+
410
# Demo: filter a chirp with DifferentiableBandPassFilter, check that gradients
# flow, and save each band's trace and PSD to "traces.png".
# Requires a CUDA device (``.cuda()`` below) and the scitex package.
if __name__ == "__main__":
    import scitex

    # Start
    CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.gen.start(sys, plt, fig_scale=5)

    xx, tt, fs = scitex.dsp.demo_sig(sig_type="chirp", fs=1024)
    xx = torch.tensor(xx).cuda()
    # bands = np.array([[2, 3], [3, 4]])
    # BandPassFilter(bands, fs, xx.shape)
    m = DifferentiableBandPassFilter(xx.shape[-1], fs).cuda()

    scitex.ml.utils.check_params(m)
    # {'pha_mids': (torch.Size([30]), 'Learnable'),
    #  'amp_mids': (torch.Size([50]), 'Learnable')}

    # With the default 30 phase + 50 amplitude bands there are 80 kernels.
    xf = m(xx)  # (8, 19, 80, 2048)

    xf.sum().backward()  # OK, differentiable

    m.pha_mids
    # Parameter containing:
    # tensor([ 2.0000,  2.6207,  3.2414,  3.8621,  4.4828,  5.1034,  5.7241,  6.3448,
    #          6.9655,  7.5862,  8.2069,  8.8276,  9.4483, 10.0690, 10.6897, 11.3103,
    #         11.9310, 12.5517, 13.1724, 13.7931, 14.4138, 15.0345, 15.6552, 16.2759,
    #         16.8966, 17.5172, 18.1379, 18.7586, 19.3793, 20.0000],
    #        requires_grad=True)
    m.amp_mids
    # Parameter containing:
    # tensor([ 80.0000,  81.6327,  83.2653,  84.8980,  86.5306,  88.1633,  89.7959,
    #          91.4286,  93.0612,  94.6939,  96.3265,  97.9592,  99.5918, 101.2245,
    #         102.8571, 104.4898, 106.1225, 107.7551, 109.3878, 111.0204, 112.6531,
    #         114.2857, 115.9184, 117.5510, 119.1837, 120.8163, 122.4490, 124.0816,
    #         125.7143, 127.3469, 128.9796, 130.6122, 132.2449, 133.8775, 135.5102,
    #         137.1429, 138.7755, 140.4082, 142.0408, 143.6735, 145.3061, 146.9388,
    #         148.5714, 150.2041, 151.8367, 153.4694, 155.1020, 156.7347, 158.3673,
    #         160.0000], requires_grad=True)

    # PSD: one row per band centre (phase centres first, then amplitude).
    bands = torch.hstack([m.pha_mids, m.amp_mids])

    # Plots PSD
    # matplotlib.use("TkAgg")
    # Row 0 is the original signal; row i + 1 is filter i. Column 0 is the
    # time trace, column 1 its PSD.
    fig, axes = scitex.plt.subplots(nrows=1 + len(bands), ncols=2)

    psd, ff = scitex.dsp.psd(xx, fs)  # Orig
    axes[0, 0].plot(tt, xx[0, 0].detach().cpu().numpy(), label="orig")
    axes[0, 1].plot(
        ff.detach().cpu().numpy(),
        psd[0, 0].detach().cpu().numpy(),
        label="orig",
    )

    for i_filt in range(len(bands)):
        mid_hz = int(bands[i_filt].item())
        psd_f, ff_f = scitex.dsp.psd(xf[:, :, i_filt, :], fs)
        axes[i_filt + 1, 0].plot(
            tt,
            xf[0, 0, i_filt].detach().cpu().numpy(),
            label=f"filted at {mid_hz} Hz",
        )
        axes[i_filt + 1, 1].plot(
            ff_f.detach().cpu().numpy(),
            psd_f[0, 0].detach().cpu().numpy(),
            label=f"filted at {mid_hz} Hz",
        )
    for ax in axes.ravel():
        ax.legend(loc="upper left")

    scitex.io.save(fig, "traces.png")
    # plt.show()

    # Close
    scitex.gen.close(CONFIG)

# Former location of this module (kept as a bare string expression).
"""
/home/ywatanabe/proj/entrance/scitex/dsp/nn/_Filters.py
"""

# EOF
@@ -0,0 +1,110 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2023-04-23 11:02:34 (ywatanabe)"
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torchsummary import summary
9
+ import scitex
10
+ import numpy as np
11
+ import julius
12
+
13
+ # BANDS_LIM_HZ_DICT = {
14
+ # "delta": [0.5, 4],
15
+ # "theta": [4, 8],
16
+ # "lalpha": [8, 10],
17
+ # "halpha": [10, 13],
18
+ # "beta": [13, 32],
19
+ # "gamma": [32, 75],
20
+ # }
21
+
22
+
23
+ # class FreqDropout(nn.Module):
24
+ # def __init__(self, n_bands, samp_rate, dropout_ratio=0.5):
25
+ # super().__init__()
26
+ # self.dropout = nn.Dropout(p=0.5)
27
+ # self.n_bands = n_bands
28
+ # self.samp_rate = samp_rate
29
+ # # self.
30
+ # self.register_buffer("ones", torch.ones(self.n_bands))
31
+
32
+ # def forward(self, x):
33
+ # """x: [batch_size, n_chs, seq_len]"""
34
+ # x = julius.bands.split_bands(x, self.samp_rate, n_bands=self.n_bands)
35
+
36
+ # gains_orig = x.reshape(len(x), -1).abs().sum(axis=-1)
37
+ # sum_gains_orig = gains_orig.sum()
38
+
39
+ # # use_freqs = self.dropout(torch.ones(self.n_bands)).bool().long()
40
+ # use_freqs = self.dropout(self.ones) / 2 # .bool().long()
41
+
42
+ # gains = gains_orig * use_freqs
43
+ # sum_gains = gains.sum()
44
+ # gain_ratio = sum_gains / sum_gains_orig
45
+
46
+
47
+ # x *= use_freqs.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
48
+ # x /= gain_ratio
49
+ # x = x.sum(axis=0)
50
+
51
+ # return x
52
+
53
+
54
class FreqGainChanger(nn.Module):
    """Training-time augmentation that randomly re-weights frequency bands.

    In training mode the input is split into ``n_bands`` bands with
    ``julius.bands.split_bands``, and the bands are recombined with random
    weights. The weights are a softmax over ``U(0, 1) + 0.5``, so they sum to
    1. In eval mode the input is returned unchanged.

    Parameters
    ----------
    n_bands : int
        Number of bands to split the signal into.
    samp_rate : int or float
        Sampling rate passed to ``julius.bands.split_bands``.
    dropout_ratio : float
        Probability of the ``self.dropout`` layer. NOTE: that layer is kept
        as an attribute but is not applied in ``forward``.
    """

    def __init__(self, n_bands, samp_rate, dropout_ratio=0.5):
        super().__init__()
        # Honour the argument instead of a hard-coded p=0.5.
        self.dropout = nn.Dropout(p=dropout_ratio)
        self.n_bands = n_bands
        self.samp_rate = samp_rate

    def forward(self, x):
        """x: [batch_size, n_chs, seq_len]"""
        if self.training:
            # Bands are stacked along dim 0 (the dim reduced below) —
            # presumably (n_bands, *x.shape); confirm against julius docs.
            x = julius.bands.split_bands(x, self.samp_rate, n_bands=self.n_bands)
            # One gain per band, shaped to broadcast over the remaining dims.
            freq_gains = (
                torch.rand(self.n_bands)
                .unsqueeze(-1)
                .unsqueeze(-1)
                .unsqueeze(-1)
                .to(x.device)
                + 0.5
            )
            freq_gains = F.softmax(freq_gains, dim=0)
            # NOTE(review): the softmax weights sum to 1, so this is a weighted
            # average of the bands (roughly 1/n_bands of the original amplitude),
            # not a gain-adjusted reconstruction — confirm that is intended.
            x = (x * freq_gains).sum(axis=0)

        return x
95
+
96
+
97
# Demo: run FreqGainChanger on random data and check that backward works.
if __name__ == "__main__":
    # Parameters
    N_BANDS = 10
    SAMP_RATE = 1000
    BS, N_CHS, SEQ_LEN = 16, 360, 1000

    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Demo data. FreqGainChanger has no parameters, so the input itself must
    # require grad; otherwise ``y.sum().backward()`` raises a RuntimeError.
    x = torch.rand(BS, N_CHS, SEQ_LEN).to(device).requires_grad_()

    # Feedforward
    fgc = FreqGainChanger(N_BANDS, SAMP_RATE).to(device)
    # fgc.eval()
    y = fgc(x)
    y.sum().backward()
@@ -0,0 +1,48 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-04-01 18:14:44 (ywatanabe)"
4
+
5
+ import math
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchaudio.transforms as T
11
+
12
+
13
class GaussianFilter(nn.Module):
    """Depthwise 1D Gaussian smoothing with a ``2 * radius + 1``-tap kernel.

    Parameters
    ----------
    radius : int
        Half-width of the kernel. It is also used as the convolution padding,
        so the output keeps the input's length.
    sigma : float, optional
        Standard deviation in samples; defaults to ``radius / 2``.
    """

    def __init__(self, radius, sigma=None):
        super().__init__()
        if sigma is None:
            sigma = radius / 2
        self.radius = radius
        self.register_buffer("kernel", self.gen_kernel_1d(radius, sigma=sigma))

    @staticmethod
    def gen_kernel_1d(radius, sigma=None):
        """Return a sum-normalised Gaussian kernel of shape ``(1, 1, 2 * radius + 1)``.

        A zero ``sigma`` (e.g. ``radius=0`` with the default sigma) yields the
        limiting unit impulse instead of NaNs from ``0 / 0``.
        """
        if sigma is None:
            sigma = radius / 2

        kernel_size = 2 * radius + 1
        x = torch.arange(kernel_size).float() - radius

        if sigma == 0:
            kernel = (x == 0).float()
        else:
            kernel = torch.exp(-0.5 * (x / sigma) ** 2)
            # Dividing by the sum already normalises the kernel, so the usual
            # 1 / (sigma * sqrt(2 * pi)) prefactor would be redundant.
            kernel = kernel / kernel.sum()

        return kernel.unsqueeze(0).unsqueeze(0)

    def forward(self, x):
        """Smooth every channel of ``x`` independently.

        ``x`` may be ``(seq_len,)``, ``(batch_size, seq_len)`` or
        ``(batch_size, n_chs, seq_len)``; the result is always 3D.
        """
        if x.ndim == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.ndim == 2:
            x = x.unsqueeze(1)

        channels = x.size(1)
        # One copy of the kernel per channel for a depthwise (groups=channels) conv.
        kernel = self.kernel.expand(channels, 1, -1).to(device=x.device, dtype=x.dtype)

        return F.conv1d(x, kernel, padding=self.radius, groups=channels)