scitex 2.0.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (572) hide show
  1. scitex/__init__.py +73 -0
  2. scitex/__main__.py +89 -0
  3. scitex/__version__.py +14 -0
  4. scitex/_sh.py +59 -0
  5. scitex/ai/_LearningCurveLogger.py +583 -0
  6. scitex/ai/__Classifiers.py +101 -0
  7. scitex/ai/__init__.py +55 -0
  8. scitex/ai/_gen_ai/_Anthropic.py +173 -0
  9. scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
  10. scitex/ai/_gen_ai/_DeepSeek.py +175 -0
  11. scitex/ai/_gen_ai/_Google.py +161 -0
  12. scitex/ai/_gen_ai/_Groq.py +97 -0
  13. scitex/ai/_gen_ai/_Llama.py +142 -0
  14. scitex/ai/_gen_ai/_OpenAI.py +230 -0
  15. scitex/ai/_gen_ai/_PARAMS.py +565 -0
  16. scitex/ai/_gen_ai/_Perplexity.py +191 -0
  17. scitex/ai/_gen_ai/__init__.py +32 -0
  18. scitex/ai/_gen_ai/_calc_cost.py +78 -0
  19. scitex/ai/_gen_ai/_format_output_func.py +183 -0
  20. scitex/ai/_gen_ai/_genai_factory.py +71 -0
  21. scitex/ai/act/__init__.py +8 -0
  22. scitex/ai/act/_define.py +11 -0
  23. scitex/ai/classification/__init__.py +7 -0
  24. scitex/ai/classification/classification_reporter.py +1137 -0
  25. scitex/ai/classification/classifier_server.py +131 -0
  26. scitex/ai/classification/classifiers.py +101 -0
  27. scitex/ai/classification_reporter.py +1161 -0
  28. scitex/ai/classifier_server.py +131 -0
  29. scitex/ai/clustering/__init__.py +11 -0
  30. scitex/ai/clustering/_pca.py +115 -0
  31. scitex/ai/clustering/_umap.py +376 -0
  32. scitex/ai/early_stopping.py +149 -0
  33. scitex/ai/feature_extraction/__init__.py +56 -0
  34. scitex/ai/feature_extraction/vit.py +148 -0
  35. scitex/ai/genai/__init__.py +277 -0
  36. scitex/ai/genai/anthropic.py +177 -0
  37. scitex/ai/genai/anthropic_provider.py +320 -0
  38. scitex/ai/genai/anthropic_refactored.py +109 -0
  39. scitex/ai/genai/auth_manager.py +200 -0
  40. scitex/ai/genai/base_genai.py +336 -0
  41. scitex/ai/genai/base_provider.py +291 -0
  42. scitex/ai/genai/calc_cost.py +78 -0
  43. scitex/ai/genai/chat_history.py +307 -0
  44. scitex/ai/genai/cost_tracker.py +276 -0
  45. scitex/ai/genai/deepseek.py +188 -0
  46. scitex/ai/genai/deepseek_provider.py +251 -0
  47. scitex/ai/genai/format_output_func.py +183 -0
  48. scitex/ai/genai/genai_factory.py +71 -0
  49. scitex/ai/genai/google.py +169 -0
  50. scitex/ai/genai/google_provider.py +228 -0
  51. scitex/ai/genai/groq.py +104 -0
  52. scitex/ai/genai/groq_provider.py +248 -0
  53. scitex/ai/genai/image_processor.py +250 -0
  54. scitex/ai/genai/llama.py +155 -0
  55. scitex/ai/genai/llama_provider.py +214 -0
  56. scitex/ai/genai/mock_provider.py +127 -0
  57. scitex/ai/genai/model_registry.py +304 -0
  58. scitex/ai/genai/openai.py +230 -0
  59. scitex/ai/genai/openai_provider.py +293 -0
  60. scitex/ai/genai/params.py +565 -0
  61. scitex/ai/genai/perplexity.py +202 -0
  62. scitex/ai/genai/perplexity_provider.py +205 -0
  63. scitex/ai/genai/provider_base.py +302 -0
  64. scitex/ai/genai/provider_factory.py +370 -0
  65. scitex/ai/genai/response_handler.py +235 -0
  66. scitex/ai/layer/_Pass.py +21 -0
  67. scitex/ai/layer/__init__.py +10 -0
  68. scitex/ai/layer/_switch.py +8 -0
  69. scitex/ai/loss/_L1L2Losses.py +34 -0
  70. scitex/ai/loss/__init__.py +12 -0
  71. scitex/ai/loss/multi_task_loss.py +47 -0
  72. scitex/ai/metrics/__init__.py +9 -0
  73. scitex/ai/metrics/_bACC.py +51 -0
  74. scitex/ai/metrics/silhoute_score_block.py +496 -0
  75. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
  76. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
  77. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
  78. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
  79. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
  80. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
  81. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
  82. scitex/ai/optim/__init__.py +13 -0
  83. scitex/ai/optim/_get_set.py +31 -0
  84. scitex/ai/optim/_optimizers.py +71 -0
  85. scitex/ai/plt/__init__.py +21 -0
  86. scitex/ai/plt/_conf_mat.py +592 -0
  87. scitex/ai/plt/_learning_curve.py +194 -0
  88. scitex/ai/plt/_optuna_study.py +111 -0
  89. scitex/ai/plt/aucs/__init__.py +2 -0
  90. scitex/ai/plt/aucs/example.py +60 -0
  91. scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
  92. scitex/ai/plt/aucs/roc_auc.py +246 -0
  93. scitex/ai/sampling/undersample.py +29 -0
  94. scitex/ai/sk/__init__.py +11 -0
  95. scitex/ai/sk/_clf.py +58 -0
  96. scitex/ai/sk/_to_sktime.py +100 -0
  97. scitex/ai/sklearn/__init__.py +26 -0
  98. scitex/ai/sklearn/clf.py +58 -0
  99. scitex/ai/sklearn/to_sktime.py +100 -0
  100. scitex/ai/training/__init__.py +7 -0
  101. scitex/ai/training/early_stopping.py +150 -0
  102. scitex/ai/training/learning_curve_logger.py +555 -0
  103. scitex/ai/utils/__init__.py +22 -0
  104. scitex/ai/utils/_check_params.py +50 -0
  105. scitex/ai/utils/_default_dataset.py +46 -0
  106. scitex/ai/utils/_format_samples_for_sktime.py +26 -0
  107. scitex/ai/utils/_label_encoder.py +134 -0
  108. scitex/ai/utils/_merge_labels.py +22 -0
  109. scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
  110. scitex/ai/utils/_under_sample.py +51 -0
  111. scitex/ai/utils/_verify_n_gpus.py +16 -0
  112. scitex/ai/utils/grid_search.py +148 -0
  113. scitex/context/__init__.py +9 -0
  114. scitex/context/_suppress_output.py +38 -0
  115. scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
  116. scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
  117. scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
  118. scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
  119. scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
  120. scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
  121. scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
  122. scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
  123. scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
  124. scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
  125. scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
  126. scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
  127. scitex/db/_BaseMixins/__init__.py +30 -0
  128. scitex/db/_PostgreSQL.py +126 -0
  129. scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
  130. scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
  131. scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
  132. scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
  133. scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
  134. scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
  135. scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
  136. scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
  137. scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
  138. scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
  139. scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
  140. scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
  141. scitex/db/_PostgreSQLMixins/__init__.py +34 -0
  142. scitex/db/_SQLite3.py +2136 -0
  143. scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
  144. scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
  145. scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
  146. scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
  147. scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
  148. scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
  149. scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
  150. scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
  151. scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
  152. scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
  153. scitex/db/_SQLite3Mixins/__init__.py +30 -0
  154. scitex/db/__init__.py +14 -0
  155. scitex/db/_delete_duplicates.py +397 -0
  156. scitex/db/_inspect.py +163 -0
  157. scitex/decorators/__init__.py +54 -0
  158. scitex/decorators/_auto_order.py +172 -0
  159. scitex/decorators/_batch_fn.py +127 -0
  160. scitex/decorators/_cache_disk.py +32 -0
  161. scitex/decorators/_cache_mem.py +12 -0
  162. scitex/decorators/_combined.py +98 -0
  163. scitex/decorators/_converters.py +282 -0
  164. scitex/decorators/_deprecated.py +26 -0
  165. scitex/decorators/_not_implemented.py +30 -0
  166. scitex/decorators/_numpy_fn.py +86 -0
  167. scitex/decorators/_pandas_fn.py +121 -0
  168. scitex/decorators/_preserve_doc.py +19 -0
  169. scitex/decorators/_signal_fn.py +95 -0
  170. scitex/decorators/_timeout.py +55 -0
  171. scitex/decorators/_torch_fn.py +136 -0
  172. scitex/decorators/_wrap.py +39 -0
  173. scitex/decorators/_xarray_fn.py +88 -0
  174. scitex/dev/__init__.py +15 -0
  175. scitex/dev/_analyze_code_flow.py +284 -0
  176. scitex/dev/_reload.py +59 -0
  177. scitex/dict/_DotDict.py +442 -0
  178. scitex/dict/__init__.py +18 -0
  179. scitex/dict/_listed_dict.py +42 -0
  180. scitex/dict/_pop_keys.py +36 -0
  181. scitex/dict/_replace.py +13 -0
  182. scitex/dict/_safe_merge.py +62 -0
  183. scitex/dict/_to_str.py +32 -0
  184. scitex/dsp/__init__.py +72 -0
  185. scitex/dsp/_crop.py +122 -0
  186. scitex/dsp/_demo_sig.py +331 -0
  187. scitex/dsp/_detect_ripples.py +212 -0
  188. scitex/dsp/_ensure_3d.py +18 -0
  189. scitex/dsp/_hilbert.py +78 -0
  190. scitex/dsp/_listen.py +702 -0
  191. scitex/dsp/_misc.py +30 -0
  192. scitex/dsp/_mne.py +32 -0
  193. scitex/dsp/_modulation_index.py +79 -0
  194. scitex/dsp/_pac.py +319 -0
  195. scitex/dsp/_psd.py +102 -0
  196. scitex/dsp/_resample.py +65 -0
  197. scitex/dsp/_time.py +36 -0
  198. scitex/dsp/_transform.py +68 -0
  199. scitex/dsp/_wavelet.py +212 -0
  200. scitex/dsp/add_noise.py +111 -0
  201. scitex/dsp/example.py +253 -0
  202. scitex/dsp/filt.py +155 -0
  203. scitex/dsp/norm.py +18 -0
  204. scitex/dsp/params.py +51 -0
  205. scitex/dsp/reference.py +43 -0
  206. scitex/dsp/template.py +25 -0
  207. scitex/dsp/utils/__init__.py +15 -0
  208. scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
  209. scitex/dsp/utils/_ensure_3d.py +18 -0
  210. scitex/dsp/utils/_ensure_even_len.py +10 -0
  211. scitex/dsp/utils/_zero_pad.py +48 -0
  212. scitex/dsp/utils/filter.py +408 -0
  213. scitex/dsp/utils/pac.py +177 -0
  214. scitex/dt/__init__.py +8 -0
  215. scitex/dt/_linspace.py +130 -0
  216. scitex/etc/__init__.py +15 -0
  217. scitex/etc/wait_key.py +34 -0
  218. scitex/gen/_DimHandler.py +196 -0
  219. scitex/gen/_TimeStamper.py +244 -0
  220. scitex/gen/__init__.py +95 -0
  221. scitex/gen/_alternate_kwarg.py +13 -0
  222. scitex/gen/_cache.py +11 -0
  223. scitex/gen/_check_host.py +34 -0
  224. scitex/gen/_ci.py +12 -0
  225. scitex/gen/_close.py +222 -0
  226. scitex/gen/_embed.py +78 -0
  227. scitex/gen/_inspect_module.py +257 -0
  228. scitex/gen/_is_ipython.py +12 -0
  229. scitex/gen/_less.py +48 -0
  230. scitex/gen/_list_packages.py +139 -0
  231. scitex/gen/_mat2py.py +88 -0
  232. scitex/gen/_norm.py +170 -0
  233. scitex/gen/_paste.py +18 -0
  234. scitex/gen/_print_config.py +84 -0
  235. scitex/gen/_shell.py +48 -0
  236. scitex/gen/_src.py +111 -0
  237. scitex/gen/_start.py +451 -0
  238. scitex/gen/_symlink.py +55 -0
  239. scitex/gen/_symlog.py +27 -0
  240. scitex/gen/_tee.py +238 -0
  241. scitex/gen/_title2path.py +60 -0
  242. scitex/gen/_title_case.py +88 -0
  243. scitex/gen/_to_even.py +84 -0
  244. scitex/gen/_to_odd.py +34 -0
  245. scitex/gen/_to_rank.py +39 -0
  246. scitex/gen/_transpose.py +37 -0
  247. scitex/gen/_type.py +78 -0
  248. scitex/gen/_var_info.py +73 -0
  249. scitex/gen/_wrap.py +17 -0
  250. scitex/gen/_xml2dict.py +76 -0
  251. scitex/gen/misc.py +730 -0
  252. scitex/gen/path.py +0 -0
  253. scitex/general/__init__.py +5 -0
  254. scitex/gists/_SigMacro_processFigure_S.py +128 -0
  255. scitex/gists/_SigMacro_toBlue.py +172 -0
  256. scitex/gists/__init__.py +12 -0
  257. scitex/io/_H5Explorer.py +292 -0
  258. scitex/io/__init__.py +82 -0
  259. scitex/io/_cache.py +101 -0
  260. scitex/io/_flush.py +24 -0
  261. scitex/io/_glob.py +103 -0
  262. scitex/io/_json2md.py +113 -0
  263. scitex/io/_load.py +168 -0
  264. scitex/io/_load_configs.py +146 -0
  265. scitex/io/_load_modules/__init__.py +38 -0
  266. scitex/io/_load_modules/_catboost.py +66 -0
  267. scitex/io/_load_modules/_con.py +20 -0
  268. scitex/io/_load_modules/_db.py +24 -0
  269. scitex/io/_load_modules/_docx.py +42 -0
  270. scitex/io/_load_modules/_eeg.py +110 -0
  271. scitex/io/_load_modules/_hdf5.py +196 -0
  272. scitex/io/_load_modules/_image.py +19 -0
  273. scitex/io/_load_modules/_joblib.py +19 -0
  274. scitex/io/_load_modules/_json.py +18 -0
  275. scitex/io/_load_modules/_markdown.py +103 -0
  276. scitex/io/_load_modules/_matlab.py +37 -0
  277. scitex/io/_load_modules/_numpy.py +39 -0
  278. scitex/io/_load_modules/_optuna.py +155 -0
  279. scitex/io/_load_modules/_pandas.py +69 -0
  280. scitex/io/_load_modules/_pdf.py +31 -0
  281. scitex/io/_load_modules/_pickle.py +24 -0
  282. scitex/io/_load_modules/_torch.py +16 -0
  283. scitex/io/_load_modules/_txt.py +126 -0
  284. scitex/io/_load_modules/_xml.py +49 -0
  285. scitex/io/_load_modules/_yaml.py +23 -0
  286. scitex/io/_mv_to_tmp.py +19 -0
  287. scitex/io/_path.py +286 -0
  288. scitex/io/_reload.py +78 -0
  289. scitex/io/_save.py +539 -0
  290. scitex/io/_save_modules/__init__.py +66 -0
  291. scitex/io/_save_modules/_catboost.py +22 -0
  292. scitex/io/_save_modules/_csv.py +89 -0
  293. scitex/io/_save_modules/_excel.py +49 -0
  294. scitex/io/_save_modules/_hdf5.py +249 -0
  295. scitex/io/_save_modules/_html.py +48 -0
  296. scitex/io/_save_modules/_image.py +140 -0
  297. scitex/io/_save_modules/_joblib.py +25 -0
  298. scitex/io/_save_modules/_json.py +25 -0
  299. scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
  300. scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
  301. scitex/io/_save_modules/_matlab.py +24 -0
  302. scitex/io/_save_modules/_mp4.py +29 -0
  303. scitex/io/_save_modules/_numpy.py +57 -0
  304. scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
  305. scitex/io/_save_modules/_pickle.py +45 -0
  306. scitex/io/_save_modules/_plotly.py +27 -0
  307. scitex/io/_save_modules/_text.py +23 -0
  308. scitex/io/_save_modules/_torch.py +26 -0
  309. scitex/io/_save_modules/_yaml.py +29 -0
  310. scitex/life/__init__.py +10 -0
  311. scitex/life/_monitor_rain.py +49 -0
  312. scitex/linalg/__init__.py +17 -0
  313. scitex/linalg/_distance.py +63 -0
  314. scitex/linalg/_geometric_median.py +64 -0
  315. scitex/linalg/_misc.py +73 -0
  316. scitex/nn/_AxiswiseDropout.py +27 -0
  317. scitex/nn/_BNet.py +126 -0
  318. scitex/nn/_BNet_Res.py +164 -0
  319. scitex/nn/_ChannelGainChanger.py +44 -0
  320. scitex/nn/_DropoutChannels.py +50 -0
  321. scitex/nn/_Filters.py +489 -0
  322. scitex/nn/_FreqGainChanger.py +110 -0
  323. scitex/nn/_GaussianFilter.py +48 -0
  324. scitex/nn/_Hilbert.py +111 -0
  325. scitex/nn/_MNet_1000.py +157 -0
  326. scitex/nn/_ModulationIndex.py +221 -0
  327. scitex/nn/_PAC.py +414 -0
  328. scitex/nn/_PSD.py +40 -0
  329. scitex/nn/_ResNet1D.py +120 -0
  330. scitex/nn/_SpatialAttention.py +25 -0
  331. scitex/nn/_Spectrogram.py +161 -0
  332. scitex/nn/_SwapChannels.py +50 -0
  333. scitex/nn/_TransposeLayer.py +19 -0
  334. scitex/nn/_Wavelet.py +183 -0
  335. scitex/nn/__init__.py +63 -0
  336. scitex/os/__init__.py +8 -0
  337. scitex/os/_mv.py +50 -0
  338. scitex/parallel/__init__.py +8 -0
  339. scitex/parallel/_run.py +151 -0
  340. scitex/path/__init__.py +33 -0
  341. scitex/path/_clean.py +52 -0
  342. scitex/path/_find.py +108 -0
  343. scitex/path/_get_module_path.py +51 -0
  344. scitex/path/_get_spath.py +35 -0
  345. scitex/path/_getsize.py +18 -0
  346. scitex/path/_increment_version.py +87 -0
  347. scitex/path/_mk_spath.py +51 -0
  348. scitex/path/_path.py +19 -0
  349. scitex/path/_split.py +23 -0
  350. scitex/path/_this_path.py +19 -0
  351. scitex/path/_version.py +101 -0
  352. scitex/pd/__init__.py +41 -0
  353. scitex/pd/_find_indi.py +126 -0
  354. scitex/pd/_find_pval.py +113 -0
  355. scitex/pd/_force_df.py +154 -0
  356. scitex/pd/_from_xyz.py +71 -0
  357. scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
  358. scitex/pd/_melt_cols.py +81 -0
  359. scitex/pd/_merge_columns.py +221 -0
  360. scitex/pd/_mv.py +63 -0
  361. scitex/pd/_replace.py +62 -0
  362. scitex/pd/_round.py +93 -0
  363. scitex/pd/_slice.py +63 -0
  364. scitex/pd/_sort.py +91 -0
  365. scitex/pd/_to_numeric.py +53 -0
  366. scitex/pd/_to_xy.py +59 -0
  367. scitex/pd/_to_xyz.py +110 -0
  368. scitex/plt/__init__.py +36 -0
  369. scitex/plt/_subplots/_AxesWrapper.py +182 -0
  370. scitex/plt/_subplots/_AxisWrapper.py +249 -0
  371. scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
  372. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
  373. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
  374. scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
  375. scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
  376. scitex/plt/_subplots/_FigWrapper.py +226 -0
  377. scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
  378. scitex/plt/_subplots/__init__.py +111 -0
  379. scitex/plt/_subplots/_export_as_csv.py +232 -0
  380. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
  381. scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
  382. scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
  383. scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
  384. scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
  385. scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
  386. scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
  387. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
  388. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
  389. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
  390. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
  391. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
  392. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
  393. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
  394. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
  395. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
  396. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
  397. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
  398. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
  399. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
  400. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
  401. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
  402. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
  403. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
  404. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
  405. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
  406. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
  407. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
  408. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
  409. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
  410. scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
  411. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
  412. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
  413. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
  414. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
  415. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
  416. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
  417. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
  418. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
  419. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
  420. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
  421. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
  422. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
  423. scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
  424. scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
  425. scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
  426. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
  427. scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
  428. scitex/plt/_tpl.py +28 -0
  429. scitex/plt/ax/__init__.py +114 -0
  430. scitex/plt/ax/_plot/__init__.py +53 -0
  431. scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
  432. scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
  433. scitex/plt/ax/_plot/_plot_cube.py +57 -0
  434. scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
  435. scitex/plt/ax/_plot/_plot_fillv.py +55 -0
  436. scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
  437. scitex/plt/ax/_plot/_plot_image.py +94 -0
  438. scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
  439. scitex/plt/ax/_plot/_plot_raster.py +172 -0
  440. scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
  441. scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
  442. scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
  443. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
  444. scitex/plt/ax/_plot/_plot_violin.py +343 -0
  445. scitex/plt/ax/_style/__init__.py +38 -0
  446. scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
  447. scitex/plt/ax/_style/_add_panel.py +92 -0
  448. scitex/plt/ax/_style/_extend.py +64 -0
  449. scitex/plt/ax/_style/_force_aspect.py +37 -0
  450. scitex/plt/ax/_style/_format_label.py +23 -0
  451. scitex/plt/ax/_style/_hide_spines.py +84 -0
  452. scitex/plt/ax/_style/_map_ticks.py +182 -0
  453. scitex/plt/ax/_style/_rotate_labels.py +215 -0
  454. scitex/plt/ax/_style/_sci_note.py +279 -0
  455. scitex/plt/ax/_style/_set_log_scale.py +299 -0
  456. scitex/plt/ax/_style/_set_meta.py +261 -0
  457. scitex/plt/ax/_style/_set_n_ticks.py +37 -0
  458. scitex/plt/ax/_style/_set_size.py +16 -0
  459. scitex/plt/ax/_style/_set_supxyt.py +116 -0
  460. scitex/plt/ax/_style/_set_ticks.py +276 -0
  461. scitex/plt/ax/_style/_set_xyt.py +121 -0
  462. scitex/plt/ax/_style/_share_axes.py +264 -0
  463. scitex/plt/ax/_style/_shift.py +139 -0
  464. scitex/plt/ax/_style/_show_spines.py +333 -0
  465. scitex/plt/color/_PARAMS.py +70 -0
  466. scitex/plt/color/__init__.py +52 -0
  467. scitex/plt/color/_add_hue_col.py +41 -0
  468. scitex/plt/color/_colors.py +205 -0
  469. scitex/plt/color/_get_colors_from_cmap.py +134 -0
  470. scitex/plt/color/_interpolate.py +29 -0
  471. scitex/plt/color/_vizualize_colors.py +54 -0
  472. scitex/plt/utils/__init__.py +44 -0
  473. scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
  474. scitex/plt/utils/_calc_nice_ticks.py +101 -0
  475. scitex/plt/utils/_close.py +68 -0
  476. scitex/plt/utils/_colorbar.py +96 -0
  477. scitex/plt/utils/_configure_mpl.py +295 -0
  478. scitex/plt/utils/_histogram_utils.py +132 -0
  479. scitex/plt/utils/_im2grid.py +70 -0
  480. scitex/plt/utils/_is_valid_axis.py +78 -0
  481. scitex/plt/utils/_mk_colorbar.py +65 -0
  482. scitex/plt/utils/_mk_patches.py +26 -0
  483. scitex/plt/utils/_scientific_captions.py +638 -0
  484. scitex/plt/utils/_scitex_config.py +223 -0
  485. scitex/reproduce/__init__.py +14 -0
  486. scitex/reproduce/_fix_seeds.py +45 -0
  487. scitex/reproduce/_gen_ID.py +55 -0
  488. scitex/reproduce/_gen_timestamp.py +35 -0
  489. scitex/res/__init__.py +5 -0
  490. scitex/resource/__init__.py +13 -0
  491. scitex/resource/_get_processor_usages.py +281 -0
  492. scitex/resource/_get_specs.py +280 -0
  493. scitex/resource/_log_processor_usages.py +190 -0
  494. scitex/resource/_utils/__init__.py +31 -0
  495. scitex/resource/_utils/_get_env_info.py +481 -0
  496. scitex/resource/limit_ram.py +33 -0
  497. scitex/scholar/__init__.py +24 -0
  498. scitex/scholar/_local_search.py +454 -0
  499. scitex/scholar/_paper.py +244 -0
  500. scitex/scholar/_pdf_downloader.py +325 -0
  501. scitex/scholar/_search.py +393 -0
  502. scitex/scholar/_vector_search.py +370 -0
  503. scitex/scholar/_web_sources.py +457 -0
  504. scitex/stats/__init__.py +31 -0
  505. scitex/stats/_calc_partial_corr.py +17 -0
  506. scitex/stats/_corr_test_multi.py +94 -0
  507. scitex/stats/_corr_test_wrapper.py +115 -0
  508. scitex/stats/_describe_wrapper.py +90 -0
  509. scitex/stats/_multiple_corrections.py +63 -0
  510. scitex/stats/_nan_stats.py +93 -0
  511. scitex/stats/_p2stars.py +116 -0
  512. scitex/stats/_p2stars_wrapper.py +56 -0
  513. scitex/stats/_statistical_tests.py +73 -0
  514. scitex/stats/desc/__init__.py +40 -0
  515. scitex/stats/desc/_describe.py +189 -0
  516. scitex/stats/desc/_nan.py +289 -0
  517. scitex/stats/desc/_real.py +94 -0
  518. scitex/stats/multiple/__init__.py +14 -0
  519. scitex/stats/multiple/_bonferroni_correction.py +72 -0
  520. scitex/stats/multiple/_fdr_correction.py +400 -0
  521. scitex/stats/multiple/_multicompair.py +28 -0
  522. scitex/stats/tests/__corr_test.py +277 -0
  523. scitex/stats/tests/__corr_test_multi.py +343 -0
  524. scitex/stats/tests/__corr_test_single.py +277 -0
  525. scitex/stats/tests/__init__.py +22 -0
  526. scitex/stats/tests/_brunner_munzel_test.py +192 -0
  527. scitex/stats/tests/_nocorrelation_test.py +28 -0
  528. scitex/stats/tests/_smirnov_grubbs.py +98 -0
  529. scitex/str/__init__.py +113 -0
  530. scitex/str/_clean_path.py +75 -0
  531. scitex/str/_color_text.py +52 -0
  532. scitex/str/_decapitalize.py +58 -0
  533. scitex/str/_factor_out_digits.py +281 -0
  534. scitex/str/_format_plot_text.py +498 -0
  535. scitex/str/_grep.py +48 -0
  536. scitex/str/_latex.py +155 -0
  537. scitex/str/_latex_fallback.py +471 -0
  538. scitex/str/_mask_api.py +39 -0
  539. scitex/str/_mask_api_key.py +8 -0
  540. scitex/str/_parse.py +158 -0
  541. scitex/str/_print_block.py +47 -0
  542. scitex/str/_print_debug.py +68 -0
  543. scitex/str/_printc.py +62 -0
  544. scitex/str/_readable_bytes.py +38 -0
  545. scitex/str/_remove_ansi.py +23 -0
  546. scitex/str/_replace.py +134 -0
  547. scitex/str/_search.py +125 -0
  548. scitex/str/_squeeze_space.py +36 -0
  549. scitex/tex/__init__.py +10 -0
  550. scitex/tex/_preview.py +103 -0
  551. scitex/tex/_to_vec.py +116 -0
  552. scitex/torch/__init__.py +18 -0
  553. scitex/torch/_apply_to.py +34 -0
  554. scitex/torch/_nan_funcs.py +77 -0
  555. scitex/types/_ArrayLike.py +44 -0
  556. scitex/types/_ColorLike.py +21 -0
  557. scitex/types/__init__.py +14 -0
  558. scitex/types/_is_listed_X.py +70 -0
  559. scitex/utils/__init__.py +22 -0
  560. scitex/utils/_compress_hdf5.py +116 -0
  561. scitex/utils/_email.py +120 -0
  562. scitex/utils/_grid.py +148 -0
  563. scitex/utils/_notify.py +247 -0
  564. scitex/utils/_search.py +121 -0
  565. scitex/web/__init__.py +38 -0
  566. scitex/web/_search_pubmed.py +438 -0
  567. scitex/web/_summarize_url.py +158 -0
  568. scitex-2.0.0.dist-info/METADATA +307 -0
  569. scitex-2.0.0.dist-info/RECORD +572 -0
  570. scitex-2.0.0.dist-info/WHEEL +6 -0
  571. scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
  572. scitex-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,161 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-04-02 09:21:12 (ywatanabe)"
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from ..decorators import numpy_fn, torch_fn
9
+
10
+
11
class Spectrogram(nn.Module):
    """Magnitude STFT spectrogram computed for every channel of a signal.

    Parameters
    ----------
    sampling_rate : float
        Sampling rate of the input in Hz; used to build the frequency and
        time axes returned by ``forward``.
    n_fft : int, default 256
        FFT size passed to ``torch.stft``.
    hop_length : int or None, default None
        Hop between frames; ``None`` means ``n_fft // 4``.
    win_length : int or None, default None
        Window length; ``None`` means ``n_fft``.
    window : str, default "hann"
        Window type. Only ``"hann"`` is supported.

    Raises
    ------
    ValueError
        If ``window`` is not ``"hann"``.
    """

    def __init__(
        self,
        sampling_rate,
        n_fft=256,
        hop_length=None,
        win_length=None,
        window="hann",
    ):
        super().__init__()
        self.sampling_rate = sampling_rate
        self.n_fft = n_fft
        self.hop_length = hop_length if hop_length is not None else n_fft // 4
        self.win_length = win_length if win_length is not None else n_fft
        if window != "hann":
            raise ValueError(
                "Unsupported window type. Extend this to support more window types."
            )
        # Non-persistent buffer: ``module.to(device)`` moves it, while
        # ``state_dict()`` stays exactly as before (no "window" entry).
        self.register_buffer(
            "window",
            torch.hann_window(window_length=self.win_length),
            persistent=False,
        )

    @staticmethod
    def _ensure_3d(x):
        """Promote (seq_len,) to (1, 1, seq_len) and (batch, seq_len) to
        (batch, 1, seq_len); 3-D input is returned unchanged.

        Replaces a call to ``scitex.dsp.ensure_3d`` that raised NameError
        because ``scitex`` is not imported in this module.
        NOTE(review): confirm this matches scitex.dsp.ensure_3d's convention.
        """
        if x.ndim == 1:
            return x.unsqueeze(0).unsqueeze(0)
        if x.ndim == 2:
            return x.unsqueeze(1)
        if x.ndim != 3:
            raise ValueError(
                f"Expected a 1-D, 2-D or 3-D signal, got shape {tuple(x.shape)}"
            )
        return x

    def forward(self, x):
        """Compute the magnitude spectrogram of each channel.

        Parameters
        ----------
        x : torch.Tensor or array-like
            Signal of shape (batch_size, n_chs, seq_len). 1-D and 2-D inputs
            are promoted to 3-D (see ``_ensure_3d``).

        Returns
        -------
        spectrograms : torch.Tensor
            Magnitudes, shape (batch_size, n_chs, n_fft // 2 + 1, n_frames).
        freqs : torch.Tensor
            Frequency of each row in Hz, shape (n_fft // 2 + 1,). Correct for
            both even and odd ``n_fft``.
        times_sec : torch.Tensor
            Time of each frame in seconds (``frame_index * hop / fs``; frames
            are centred because ``center=True``), shape (n_frames,).
        """
        x = self._ensure_3d(torch.as_tensor(x))
        batch_size, n_chs, seq_len = x.shape

        window = self.window.to(x.device)
        if x.is_floating_point():
            # torch.stft expects the window to match the input's float dtype.
            window = window.to(x.dtype)

        # torch.stft accepts a batch of 1-D signals, so fold channels into the
        # batch dimension rather than looping over channels in Python.
        spec = torch.stft(
            x.reshape(batch_size * n_chs, seq_len),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            center=True,
            pad_mode="reflect",
            normalized=False,
            return_complex=True,
        )
        spectrograms = spec.abs().reshape(batch_size, n_chs, *spec.shape[-2:])

        # One-sided FFT bin centres k * fs / n_fft. The previous
        # linspace(0, fs/2, n_fft//2+1) was wrong whenever n_fft is odd.
        freqs = torch.fft.rfftfreq(self.n_fft, d=1.0 / self.sampling_rate)

        n_frames = spectrograms.shape[-1]
        times_sec = torch.arange(0, n_frames) * (self.hop_length / self.sampling_rate)

        return spectrograms, freqs, times_sec
77
+
78
+
79
@torch_fn
def spectrograms(x, fs, cuda=False):
    """Return ``Spectrogram(fs)(x)``, i.e. (magnitudes, freqs, times_sec).

    ``cuda`` is accepted but not used in this body.

    NOTE(review): a second function named ``spectrograms`` is defined later
    in this module and rebinds the name at import time, so this definition
    is unreachable. Consider renaming one of the two.
    """
    transform = Spectrogram(fs)
    output = transform(x)
    return output
82
+
83
+
84
@torch_fn
def my_softmax(x, dim=-1):
    """Apply ``torch.nn.functional.softmax`` to ``x`` along ``dim``."""
    probabilities = F.softmax(x, dim=dim)
    return probabilities
87
+
88
+
89
@torch_fn
def unbias(x, func="min", dim=-1, cuda=False):
    """Remove a per-slice offset from ``x`` along ``dim``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    func : {"min", "mean"}, default "min"
        ``"min"`` subtracts the minimum along ``dim``; ``"mean"`` subtracts
        the mean along ``dim``.
    dim : int, default -1
        Dimension along which the offset is computed.
    cuda : bool, default False
        Accepted but not used in this body.

    Raises
    ------
    ValueError
        If ``func`` is neither ``"min"`` nor ``"mean"`` (previously the
        function silently returned ``None``).
    """
    if func == "min":
        # Tensor.min(dim=...) returns (values, indices); keep the values.
        return x - x.min(dim=dim, keepdim=True).values
    if func == "mean":
        # Tensor.mean returns a plain tensor. The former trailing ``[0]``
        # sliced its first row, so every row was shifted by row 0's mean.
        return x - x.mean(dim=dim, keepdim=True)
    raise ValueError(f"Unsupported func {func!r}; expected 'min' or 'mean'.")
95
+
96
+
97
@torch_fn
def normalize(x, axis=-1, amp=1.0, cuda=False):
    """Scale ``x`` so its peak absolute value along ``axis`` equals ``amp``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    axis : int or tuple of int, default -1
        Dimension(s) over which the peak is taken.
    amp : float, default 1.0
        Target peak amplitude.
    cuda : bool, default False
        Accepted but not used in this body.

    Returns
    -------
    torch.Tensor
        ``amp * x / max(|x|)`` along ``axis``. Slices that are entirely zero
        are returned as zeros instead of NaN.
    """
    # max(|max(x)|, |min(x)|) is simply max(|x|).
    peak = x.abs().amax(dim=axis, keepdim=True)
    # Avoid 0 / 0 -> NaN for all-zero slices; they stay zero after scaling.
    safe_peak = torch.where(peak == 0, torch.ones_like(peak), peak)
    return amp * x / safe_peak
102
+
103
+
104
@torch_fn
def spectrograms(x, fs, dj=0.125, cuda=False):
    """Continuous wavelet transform of channel 0 of ``x``.

    Uses the third-party ``wavelets_pytorch`` package, imported lazily.

    Parameters
    ----------
    x : torch.Tensor
        Signal of shape (batch_size, n_chs, seq_len); only channel 0 is used.
    fs : float
        Sampling rate in Hz; the transform's time step is ``1 / fs``.
    dj : float, default 0.125
        Scale spacing passed to ``WaveletTransformTorch``.
    cuda : bool, default False
        Forwarded to ``WaveletTransformTorch(..., cuda=cuda)``. Previously
        ``cuda=True`` was hard-coded regardless of this argument.

    Returns
    -------
    Whatever ``WaveletTransformTorch.cwt`` returns for the NumPy array of
    shape (batch_size, 1, seq_len) built from channel 0.

    Raises
    ------
    ValueError
        If ``x`` is not 3-D.
    """
    from wavelets_pytorch.transform import WaveletTransformTorch

    if x.ndim != 3:
        raise ValueError(
            f"Expected x of shape (batch_size, n_chs, seq_len), got {tuple(x.shape)}"
        )

    dt = 1 / fs
    # Channel 0 with a singleton channel axis kept: (batch_size, 1, seq_len).
    # ``0:1`` replaces ``x[:, 0][:, np.newaxis]``, which raised NameError
    # because numpy is not imported in this module. ``detach`` allows inputs
    # that require grad to be converted to NumPy.
    signal = x.detach().cpu().numpy()[:, 0:1]

    wa_torch = WaveletTransformTorch(dt, dj, cuda=cuda)
    return wa_torch.cwt(signal)
129
+
130
+
131
if __name__ == "__main__":
    # Demo: plot this module's STFT spectrogram next to torchaudio's.
    import matplotlib.pyplot as plt
    import numpy as np
    import scitex
    import seaborn as sns
    import torchaudio

    def _to_numpy(arr):
        """Return ``arr`` as a NumPy array, moving torch tensors to the CPU."""
        if torch.is_tensor(arr):
            return arr.detach().cpu().numpy()
        return np.asarray(arr)

    fs = 1024  # 128
    t_sec = 10
    # NOTE(review): confirm the location, keyword names and return value of
    # this demo-signal helper; the code below assumes a
    # (batch, n_chs, seq_len) array-like.
    x = scitex.dsp.np.demo_sig(t_sec=t_sec, fs=fs, type="ripple")

    # Keep the preprocessed signal (it was previously computed and discarded).
    x = _to_numpy(normalize(unbias(x, cuda=True), cuda=True))
    time_sec = np.arange(x.shape[-1]) / fs

    # My implementation. The module-level name ``spectrograms`` refers to the
    # later, wavelet-based definition, so the STFT class is used directly to
    # obtain the (magnitudes, freqs, times) triple.
    ss, ff, tt = Spectrogram(fs)(torch.as_tensor(x).float())
    fig, axes = plt.subplots(nrows=2)
    axes[0].plot(time_sec, x[0, 0])
    sns.heatmap(_to_numpy(ss[0, 0]), ax=axes[1])
    plt.show()

    # Torch Audio reference (falls back to CPU when no GPU is available)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    transform = torchaudio.transforms.Spectrogram(n_fft=16, normalized=True).to(device)
    xx = torch.tensor(x).float().to(device)[0, 0]
    ss = transform(xx)
    fig, ax = plt.subplots()
    sns.heatmap(ss.detach().cpu().numpy(), ax=ax)
    plt.show()
@@ -0,0 +1,50 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2023-05-04 21:21:19 (ywatanabe)"
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torchsummary import summary
9
+ import scitex
10
+ import numpy as np
11
+ import random
12
+
13
+
14
class SwapChannels(nn.Module):
    """Channel-shuffling augmentation for ``(batch_size, n_chs, seq_len)`` input.

    In training mode, a random subset of channels is chosen by passing a
    vector of ones through ``nn.Dropout``: every zeroed entry marks a channel
    to move. Those channels are then randomly permuted among themselves, and
    all other channels keep their position. In eval mode the input is
    returned untouched.

    Parameters
    ----------
    dropout : float
        Probability that any given channel takes part in the shuffle.
    """

    def __init__(self, dropout=0.5):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """x: [batch_size, n_chs, seq_len]"""
        if not self.training:
            return x

        n_chs = x.shape[1]
        channel_ids = torch.arange(n_chs)

        # Dropout zeroes entries with probability p. Zeroed entries are the
        # channels to shuffle.
        keep_mask = self.dropout(torch.ones(n_chs)).bool()
        swap_mask = ~keep_mask
        candidates = channel_ids[swap_mask]

        # random.sample over the full population gives a random permutation.
        shuffled = random.sample(list(np.array(candidates)), len(candidates))

        new_order = channel_ids.clone()
        new_order[swap_mask] = torch.LongTensor(shuffled)
        return x[:, new_order.long(), :]
37
+
38
+
39
if __name__ == "__main__":
    ## Demo data
    bs, n_chs, seq_len = 16, 360, 1000
    x = torch.rand(bs, n_chs, seq_len)

    sc = SwapChannels()
    # Channels are only permuted, so the shape is unchanged.
    print(sc(x).shape)  # torch.Size([16, 360, 1000])
@@ -0,0 +1,19 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-03-30 07:26:35 (ywatanabe)"
4
+
5
+ import torch.nn as nn
6
+
7
+
8
class TransposeLayer(nn.Module):
    """Wrap ``Tensor.transpose`` as a module so it can be used in ``nn.Sequential``.

    Parameters
    ----------
    axis1, axis2 : int
        The two dimensions swapped on every forward pass.
    """

    def __init__(self, axis1, axis2):
        super().__init__()
        self.axis1, self.axis2 = axis1, axis2

    def forward(self, x):
        """Return ``x`` with dimensions ``axis1`` and ``axis2`` swapped."""
        dims = (self.axis1, self.axis2)
        return x.transpose(*dims)
scitex/nn/_Wavelet.py ADDED
@@ -0,0 +1,183 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-11-03 07:17:26 (ywatanabe)"
4
+ # File: ./scitex_repo/src/scitex/nn/_Wavelet.py
5
+
6
+ #!/usr/bin/env python3
7
+ # -*- coding: utf-8 -*-
8
+ # Time-stamp: "2024-05-30 11:04:45 (ywatanabe)"
9
+
10
+
11
+ import scitex
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from ..gen._to_even import to_even
17
+ from ..gen._to_odd import to_odd
18
+
19
+
20
class Wavelet(nn.Module):
    """Complex Morlet wavelet filter bank.

    Convolves every channel of the input with a bank of complex Morlet
    wavelets (one per centre frequency) and returns phase, amplitude and the
    wavelet centre frequencies.

    Parameters
    ----------
    samp_rate : float
        Sampling rate of the input, in Hz.
    kernel_size : int, optional
        Wavelet length in samples. ``None`` means ``int(samp_rate)``.
    freq_scale : {"linear", "log"}
        Spacing of the wavelet centre frequencies.
    out_scale : str
        ``"log"`` returns ``log(amp + 1e-5)`` as the amplitude. Any other
        value returns the raw amplitude.
    """

    def __init__(
        self, samp_rate, kernel_size=None, freq_scale="linear", out_scale="log"
    ):
        super().__init__()
        # Scalar buffer whose only job is to follow the module across devices.
        self.register_buffer("dummy", torch.tensor(0))
        self.kernel = None
        self.init_kernel(samp_rate, kernel_size=kernel_size, freq_scale=freq_scale)
        self.out_scale = out_scale

    def forward(self, x):
        """Apply the 2D filter (n_filts, kernel_size) to x of shape (batch_size, n_chs, seq_len).

        Returns
        -------
        pha : Tensor, shape (batch_size, n_chs, n_filts, seq_len)
        amp : Tensor, same shape as ``pha``; log-scaled when ``out_scale == "log"``
        freqs : Tensor, shape (batch_size, n_chs, n_filts)
        """
        x = scitex.dsp.ensure_3d(x).to(self.dummy.device)
        seq_len = x.shape[-1]

        # If the kernel has been cleared, rebuild it from the arguments
        # recorded by the most recent init_kernel() call.
        if self.kernel is None:
            kernel_config = getattr(self, "_kernel_config", None)
            if kernel_config is None:
                raise ValueError("Filter kernel has not been initialized.")
            self.init_kernel(*kernel_config)
        assert self.kernel.ndim == 2
        self.kernel = self.kernel.to(x.device)  # complex128 (built from NumPy)

        # Edge handling: mirror the first and last `radius` samples onto each
        # end of the signal before convolving.
        extension_length = self.radius
        first_segment = x[:, :, :extension_length].flip(dims=[-1])
        last_segment = x[:, :, -extension_length:].flip(dims=[-1])
        extended_x = torch.cat([first_segment, x, last_segment], dim=-1)

        # Fold batch and channel axes so every signal is convolved with every
        # wavelet: input (B*C, 1, L_ext), weight (n_filts, 1, K).
        kernel_batched = self.kernel.unsqueeze(1)
        extended_x_reshaped = extended_x.view(-1, 1, extended_x.shape[-1])

        # Real and imaginary kernel parts are convolved separately (as float32)
        # and recombined into a complex response.
        filtered_x_real = F.conv1d(
            extended_x_reshaped, kernel_batched.real.float(), groups=1
        )
        filtered_x_imag = F.conv1d(
            extended_x_reshaped, kernel_batched.imag.float(), groups=1
        )
        filtered_x = torch.view_as_complex(
            torch.stack([filtered_x_real, filtered_x_imag], dim=-1)
        )

        # Unfold back to (batch_size, n_chs, n_filts, time) and trim to seq_len.
        filtered_x = filtered_x.view(
            x.shape[0], x.shape[1], kernel_batched.shape[0], -1
        )
        filtered_x = filtered_x[..., :seq_len]
        assert filtered_x.shape[-1] == seq_len

        pha = filtered_x.angle()
        amp = filtered_x.abs()

        # Repeat the centre frequencies to (batch_size, n_chs, n_filts).
        freqs = (
            self.freqs.unsqueeze(0).unsqueeze(0).repeat(pha.shape[0], pha.shape[1], 1)
        )

        if self.out_scale == "log":
            return pha, torch.log(amp + 1e-5), freqs
        else:
            return pha, amp, freqs

    def init_kernel(self, samp_rate, kernel_size=None, freq_scale="log"):
        """Build the Morlet bank into ``self.kernel`` and its centre frequencies into ``self.freqs``.

        The arguments are recorded so that ``forward`` can rebuild the kernel
        if it is later reset to ``None``.
        """
        self._kernel_config = (samp_rate, kernel_size, freq_scale)
        device = self.dummy.device
        morlets, freqs = self.gen_morlet_to_nyquist(
            samp_rate, kernel_size=kernel_size, freq_scale=freq_scale
        )
        self.kernel = torch.tensor(morlets).to(device)
        self.freqs = torch.tensor(freqs).float().to(device)

    @staticmethod
    def gen_morlet_to_nyquist(samp_rate, kernel_size=None, freq_scale="linear"):
        """
        Generates complex Morlet wavelets for frequency bands up to the Nyquist frequency.

        Parameters:
        - samp_rate (int): The sampling rate of the signal, in Hertz.
        - kernel_size (int): The size of the kernel, in number of samples.
          Defaults to int(samp_rate).
        - freq_scale (str): "linear" (bands about 1 Hz apart) or "log" (band
          centres at powers of two).

        Returns:
        - np.ndarray: 2D complex array (n_bands, kernel_size), one Morlet wavelet per band.
        - np.ndarray: 1D array of the centre frequency of each band, in Hz.

        Raises:
        - ValueError: if freq_scale is neither "linear" nor "log".
        """
        if kernel_size is None:
            kernel_size = int(samp_rate)  # * 2.5)

        nyquist_freq = samp_rate / 2

        # Log freq_scale: centres at 2, 4, 8, ... up to 2**floor(log2(nyquist)).
        def calc_freq_boundaries_log(nyquist_freq):
            n_kernels = int(np.floor(np.log2(nyquist_freq)))
            mid_hz = np.array([2 ** (n + 1) for n in range(n_kernels)])
            width_hz = np.hstack([np.array([1]), np.diff(mid_hz) / 2]) + 1
            low_hz = mid_hz - width_hz
            high_hz = mid_hz + width_hz
            low_hz[0] = 0.1
            return low_hz, high_hz

        # Linear freq_scale: upper edges evenly spaced from 1 Hz to Nyquist.
        def calc_freq_boundaries_linear(nyquist_freq):
            n_kernels = int(nyquist_freq)
            high_hz = np.linspace(1, nyquist_freq, n_kernels)
            low_hz = high_hz - np.hstack([np.array(1), np.diff(high_hz)])
            low_hz[0] = 0.1
            return low_hz, high_hz

        if freq_scale == "linear":
            fn = calc_freq_boundaries_linear
        elif freq_scale == "log":
            fn = calc_freq_boundaries_log
        else:
            raise ValueError(
                f"freq_scale must be 'linear' or 'log', got {freq_scale!r}"
            )
        low_hz, high_hz = fn(nyquist_freq)

        morlets = []
        freqs = []

        for ll, hh in zip(low_hz, high_hz):
            if ll > nyquist_freq:
                break

            center_frequency = (ll + hh) / 2

            t = np.arange(-kernel_size // 2, kernel_size // 2) / samp_rate
            # Gaussian width sigma = 7 / (2*pi*f): a fixed number of cycles per wavelet.
            sigma = 7 / (2 * np.pi * center_frequency)
            sine_wave = np.exp(2j * np.pi * center_frequency * t)
            gaussian_window = np.exp(-(t**2) / (2 * sigma**2))
            morlet_wavelet = sine_wave * gaussian_window

            freqs.append(center_frequency)
            morlets.append(morlet_wavelet)

        return np.array(morlets), np.array(freqs)

    @property
    def kernel_size(
        self,
    ):
        return to_even(self.kernel.shape[-1])

    @property
    def radius(
        self,
    ):
        return to_even(self.kernel_size // 2)
166
+
167
+
168
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import scitex

    # Demo: wavelet transform of a chirp, shown as a time-frequency image.
    signal, time_points, samp_rate = scitex.dsp.demo_sig(sig_type="chirp")
    phase, amplitude, freqs = scitex.dsp.wavelet(signal, samp_rate)

    fig, ax = scitex.plt.subplots()
    ax.imshow2d(amplitude[0, 0].T)
    ax = scitex.plt.ax.set_ticks(ax, xticks=time_points, yticks=freqs)
    ax = scitex.plt.ax.set_n_ticks(ax)
    plt.show()
181
+
182
+
183
+ # EOF
scitex/nn/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ #!/usr/bin/env python3
2
+ """Scitex nn module."""
3
+
4
+ from ._AxiswiseDropout import AxiswiseDropout
5
+ from ._BNet import BHead, BNet, BNet_config
6
+ from ._BNet_Res import BHead, BNet, BNet_config
7
+ from ._ChannelGainChanger import ChannelGainChanger
8
+ from ._DropoutChannels import DropoutChannels
9
+ from ._Filters import BandPassFilter, BandStopFilter, BaseFilter1D, DifferentiableBandPassFilter, GaussianFilter, HighPassFilter, LowPassFilter
10
+ from ._FreqGainChanger import FreqGainChanger
11
+ from ._GaussianFilter import GaussianFilter
12
+ from ._Hilbert import Hilbert
13
+ from ._MNet_1000 import MNet1000, MNet_1000, MNet_config, ReshapeLayer, SwapLayer
14
+ from ._ModulationIndex import ModulationIndex
15
+ from ._PAC import PAC
16
+ from ._PSD import PSD
17
+ from ._ResNet1D import ResNet1D, ResNetBasicBlock
18
+ from ._SpatialAttention import SpatialAttention
19
+ from ._Spectrogram import Spectrogram, my_softmax, normalize, spectrograms, unbias
20
+ from ._SwapChannels import SwapChannels
21
+ from ._TransposeLayer import TransposeLayer
22
+ from ._Wavelet import Wavelet
23
+
24
# NOTE: BHead, BNet and BNet_config are imported from both ._BNet and
# ._BNet_Res above. The later ._BNet_Res import is the one bound here.
# Likewise, GaussianFilter resolves to ._GaussianFilter, not ._Filters.
# NOTE(review): confirm this shadowing is intended. Each public name is
# listed exactly once.
__all__ = [
    "AxiswiseDropout",
    "BHead",
    "BNet",
    "BNet_config",
    "BandPassFilter",
    "BandStopFilter",
    "BaseFilter1D",
    "ChannelGainChanger",
    "DifferentiableBandPassFilter",
    "DropoutChannels",
    "FreqGainChanger",
    "GaussianFilter",
    "HighPassFilter",
    "Hilbert",
    "LowPassFilter",
    "MNet1000",
    "MNet_1000",
    "MNet_config",
    "ModulationIndex",
    "PAC",
    "PSD",
    "ResNet1D",
    "ResNetBasicBlock",
    "ReshapeLayer",
    "SpatialAttention",
    "Spectrogram",
    "SwapChannels",
    "SwapLayer",
    "TransposeLayer",
    "Wavelet",
    "my_softmax",
    "normalize",
    "spectrograms",
    "unbias",
]
scitex/os/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env python3
2
+ """Scitex os module."""
3
+
4
+ from ._mv import mv
5
+
6
# Public API of scitex.os.
__all__ = ["mv"]
+ ]
scitex/os/_mv.py ADDED
@@ -0,0 +1,50 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-04-06 09:00:45 (ywatanabe)"
4
+
5
+ # import os
6
+ # import shutil
7
+
8
+ # def mv(src, tgt):
9
+ # successful = True
10
+ # os.makedirs(tgt, exist_ok=True)
11
+
12
+ # if os.path.isdir(src):
13
+ # # Iterate over the items in the directory
14
+ # for item in os.listdir(src):
15
+ # item_path = os.path.join(src, item)
16
+ # # Check if the item is a file
17
+ # if os.path.isfile(item_path):
18
+ # try:
19
+ # shutil.move(item_path, tgt)
20
+ # print(f"\nMoved file from {item_path} to {tgt}")
21
+ # except OSError as e:
22
+ # print(f"\nError: {e}")
23
+ # successful = False
24
+ # else:
25
+ # print(f"\nSkipped directory {item_path}")
26
+ # else:
27
+ # # If src is a file, just move it
28
+ # try:
29
+ # shutil.move(src, tgt)
30
+ # print(f"\nMoved from {src} to {tgt}")
31
+ # except OSError as e:
32
+ # print(f"\nError: {e}")
33
+ # successful = False
34
+
35
+ # return successful
36
+
37
+
38
def mv(src, tgt):
    """Move ``src`` into the directory ``tgt``, creating ``tgt`` if needed.

    Parameters
    ----------
    src : str or path-like
        File or directory to move.
    tgt : str or path-like
        Destination directory. It is always treated as a directory and is
        created (with parents) when missing.

    Returns
    -------
    bool
        True if the move succeeded. False if ``shutil.move`` raised an
        OSError, in which case the error is printed rather than raised.
    """
    import os
    import shutil

    successful = True
    # Because tgt now exists as a directory, shutil.move places src inside it.
    os.makedirs(tgt, exist_ok=True)

    try:
        shutil.move(src, tgt)
        print(f"\nMoved from {src} to {tgt}")
    except OSError as e:
        print(f"\nError: {e}")
        successful = False

    return successful
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env python3
2
+ """Scitex parallel module."""
3
+
4
+ from ._run import run
5
+
6
# Public API of scitex.parallel.
__all__ = ["run"]
@@ -0,0 +1,151 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-11-14 23:12:20 (ywatanabe)"
4
+ # File: ./scitex_repo/src/scitex/parallel/_run.py
5
+
6
+ """
7
+ 1. Functionality:
8
+ - Runs functions in parallel using ThreadPoolExecutor
9
+ - Handles both single and multiple return values
10
+ - Supports automatic CPU core detection
11
+ 2. Input:
12
+ - Function to run
13
+ - List of argument tuples, one tuple per function call
14
+ - Optional parameters for execution control
15
+ 3. Output:
16
+ - List of results in input order, or a tuple of per-position result lists when the function returns tuples
17
+ 4. Prerequisites:
18
+ - concurrent.futures
19
+ - pandas
20
+ - tqdm
21
+ """
22
+
23
+ import multiprocessing
24
+ import warnings
25
+ from concurrent.futures import ThreadPoolExecutor, as_completed
26
+ from typing import Any, Callable, List
27
+
28
+ from tqdm import tqdm
29
+
30
+
31
def run(
    func: Callable,
    args_list: List[tuple],
    n_jobs: int = -1,
    desc: str = "Processing",
) -> List[Any]:
    """Call ``func(*args)`` for every tuple in ``args_list`` on a thread pool.

    Parameters
    ----------
    func : Callable
        Function to call. Each call receives one unpacked tuple from ``args_list``.
    args_list : List[tuple]
        One argument tuple per call. Must be non-empty.
    n_jobs : int, optional
        Number of worker threads. Any negative value means "use
        ``multiprocessing.cpu_count()`` workers", and 0 is rejected. A value
        above the CPU count only triggers a warning.
    desc : str, optional
        Label for the tqdm progress bar.

    Returns
    -------
    List[Any] or tuple of lists
        Results in the same order as ``args_list``. If the first result is a
        tuple, the results are transposed into a tuple of lists, one list per
        tuple position.

    Raises
    ------
    ValueError
        If ``args_list`` is empty, ``func`` is not callable, or ``n_jobs`` is 0.

    Examples
    --------
    >>> def add(x, y):
    ...     return x + y
    >>> args_list = [(1, 4), (2, 5), (3, 6)]
    >>> run(add, args_list)
    [5, 7, 9]
    """
    if not args_list:
        raise ValueError("Args list cannot be empty")
    if not callable(func):
        raise ValueError("Func must be callable")

    n_cpus = multiprocessing.cpu_count()
    if n_jobs < 0:
        n_jobs = n_cpus

    if n_jobs > n_cpus:
        warnings.warn(f"n_jobs ({n_jobs}) is greater than CPU count ({n_cpus})")
    if n_jobs < 1:
        raise ValueError("n_jobs must be >= 1 or -1")

    n_tasks = len(args_list)
    ordered = [None] * n_tasks  # filled by index as futures complete

    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        index_of = {}
        for position, call_args in enumerate(args_list):
            index_of[pool.submit(func, *call_args)] = position
        for done in tqdm(as_completed(index_of), total=n_tasks, desc=desc):
            ordered[index_of[done]] = done.result()

    # Multi-output functions: turn a list of tuples into a tuple of lists.
    if ordered and isinstance(ordered[0], tuple):
        width = len(ordered[0])
        return tuple([row[k] for row in ordered] for k in range(width))

    return ordered
92
+
93
+
94
+ # def run(
95
+ # func: Callable,
96
+ # items: List[Any],
97
+ # n_jobs: int = -1,
98
+ # desc: str = "Processing",
99
+ # ) -> List[Any]:
100
+ # """Runs function in parallel using ThreadPoolExecutor.
101
+
102
+ # Parameters
103
+ # ----------
104
+ # func : Callable
105
+ # Function to run in parallel
106
+ # items : List[Any]
107
+ # List of items to process
108
+ # n_jobs : int, optional
109
+ # Number of jobs to run in parallel. -1 means using all processors
110
+ # desc : str, optional
111
+ # Description for progress bar
112
+
113
+ # Returns
114
+ # -------
115
+ # List[Any]
116
+ # Results of parallel execution
117
+ # """
118
+ # if not items:
119
+ # raise ValueError("Items list cannot be empty")
120
+ # if not callable(func):
121
+ # raise ValueError("Func must be callable")
122
+ # if not isinstance(items, (list, tuple)):
123
+ # raise TypeError("Items must be a list or tuple")
124
+ # if not isinstance(n_jobs, int):
125
+ # raise TypeError("n_jobs must be an integer")
126
+
127
+ # cpu_count = multiprocessing.cpu_count()
128
+ # n_jobs = cpu_count if n_jobs < 0 else n_jobs
129
+
130
+ # if n_jobs > cpu_count:
131
+ # warnings.warn(f"n_jobs ({n_jobs}) is greater than CPU count ({cpu_count})")
132
+ # if n_jobs < 1:
133
+ # raise ValueError("n_jobs must be >= 1 or -1")
134
+
135
+ # results = [None] * len(items) # Pre-allocate list
136
+ # with ThreadPoolExecutor(max_workers=n_jobs) as executor:
137
+ # futures = {executor.submit(func, item): idx
138
+ # for idx, item in enumerate(items)}
139
+ # for future in tqdm(as_completed(futures), total=len(items), desc=desc):
140
+ # idx = futures[future]
141
+ # results[idx] = future.result()
142
+
143
+ # # If results contain multiple values (tuples), transpose them
144
+ # if results and isinstance(results[0], tuple):
145
+ # n_vars = len(results[0])
146
+ # return tuple([result[i] for result in results] for i in range(n_vars))
147
+
148
+ # return results
149
+
150
+
151
+ # EOF