scitex 2.0.0 (scitex-2.0.0-py2.py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (572)
  1. scitex/__init__.py +73 -0
  2. scitex/__main__.py +89 -0
  3. scitex/__version__.py +14 -0
  4. scitex/_sh.py +59 -0
  5. scitex/ai/_LearningCurveLogger.py +583 -0
  6. scitex/ai/__Classifiers.py +101 -0
  7. scitex/ai/__init__.py +55 -0
  8. scitex/ai/_gen_ai/_Anthropic.py +173 -0
  9. scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
  10. scitex/ai/_gen_ai/_DeepSeek.py +175 -0
  11. scitex/ai/_gen_ai/_Google.py +161 -0
  12. scitex/ai/_gen_ai/_Groq.py +97 -0
  13. scitex/ai/_gen_ai/_Llama.py +142 -0
  14. scitex/ai/_gen_ai/_OpenAI.py +230 -0
  15. scitex/ai/_gen_ai/_PARAMS.py +565 -0
  16. scitex/ai/_gen_ai/_Perplexity.py +191 -0
  17. scitex/ai/_gen_ai/__init__.py +32 -0
  18. scitex/ai/_gen_ai/_calc_cost.py +78 -0
  19. scitex/ai/_gen_ai/_format_output_func.py +183 -0
  20. scitex/ai/_gen_ai/_genai_factory.py +71 -0
  21. scitex/ai/act/__init__.py +8 -0
  22. scitex/ai/act/_define.py +11 -0
  23. scitex/ai/classification/__init__.py +7 -0
  24. scitex/ai/classification/classification_reporter.py +1137 -0
  25. scitex/ai/classification/classifier_server.py +131 -0
  26. scitex/ai/classification/classifiers.py +101 -0
  27. scitex/ai/classification_reporter.py +1161 -0
  28. scitex/ai/classifier_server.py +131 -0
  29. scitex/ai/clustering/__init__.py +11 -0
  30. scitex/ai/clustering/_pca.py +115 -0
  31. scitex/ai/clustering/_umap.py +376 -0
  32. scitex/ai/early_stopping.py +149 -0
  33. scitex/ai/feature_extraction/__init__.py +56 -0
  34. scitex/ai/feature_extraction/vit.py +148 -0
  35. scitex/ai/genai/__init__.py +277 -0
  36. scitex/ai/genai/anthropic.py +177 -0
  37. scitex/ai/genai/anthropic_provider.py +320 -0
  38. scitex/ai/genai/anthropic_refactored.py +109 -0
  39. scitex/ai/genai/auth_manager.py +200 -0
  40. scitex/ai/genai/base_genai.py +336 -0
  41. scitex/ai/genai/base_provider.py +291 -0
  42. scitex/ai/genai/calc_cost.py +78 -0
  43. scitex/ai/genai/chat_history.py +307 -0
  44. scitex/ai/genai/cost_tracker.py +276 -0
  45. scitex/ai/genai/deepseek.py +188 -0
  46. scitex/ai/genai/deepseek_provider.py +251 -0
  47. scitex/ai/genai/format_output_func.py +183 -0
  48. scitex/ai/genai/genai_factory.py +71 -0
  49. scitex/ai/genai/google.py +169 -0
  50. scitex/ai/genai/google_provider.py +228 -0
  51. scitex/ai/genai/groq.py +104 -0
  52. scitex/ai/genai/groq_provider.py +248 -0
  53. scitex/ai/genai/image_processor.py +250 -0
  54. scitex/ai/genai/llama.py +155 -0
  55. scitex/ai/genai/llama_provider.py +214 -0
  56. scitex/ai/genai/mock_provider.py +127 -0
  57. scitex/ai/genai/model_registry.py +304 -0
  58. scitex/ai/genai/openai.py +230 -0
  59. scitex/ai/genai/openai_provider.py +293 -0
  60. scitex/ai/genai/params.py +565 -0
  61. scitex/ai/genai/perplexity.py +202 -0
  62. scitex/ai/genai/perplexity_provider.py +205 -0
  63. scitex/ai/genai/provider_base.py +302 -0
  64. scitex/ai/genai/provider_factory.py +370 -0
  65. scitex/ai/genai/response_handler.py +235 -0
  66. scitex/ai/layer/_Pass.py +21 -0
  67. scitex/ai/layer/__init__.py +10 -0
  68. scitex/ai/layer/_switch.py +8 -0
  69. scitex/ai/loss/_L1L2Losses.py +34 -0
  70. scitex/ai/loss/__init__.py +12 -0
  71. scitex/ai/loss/multi_task_loss.py +47 -0
  72. scitex/ai/metrics/__init__.py +9 -0
  73. scitex/ai/metrics/_bACC.py +51 -0
  74. scitex/ai/metrics/silhoute_score_block.py +496 -0
  75. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
  76. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
  77. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
  78. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
  79. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
  80. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
  81. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
  82. scitex/ai/optim/__init__.py +13 -0
  83. scitex/ai/optim/_get_set.py +31 -0
  84. scitex/ai/optim/_optimizers.py +71 -0
  85. scitex/ai/plt/__init__.py +21 -0
  86. scitex/ai/plt/_conf_mat.py +592 -0
  87. scitex/ai/plt/_learning_curve.py +194 -0
  88. scitex/ai/plt/_optuna_study.py +111 -0
  89. scitex/ai/plt/aucs/__init__.py +2 -0
  90. scitex/ai/plt/aucs/example.py +60 -0
  91. scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
  92. scitex/ai/plt/aucs/roc_auc.py +246 -0
  93. scitex/ai/sampling/undersample.py +29 -0
  94. scitex/ai/sk/__init__.py +11 -0
  95. scitex/ai/sk/_clf.py +58 -0
  96. scitex/ai/sk/_to_sktime.py +100 -0
  97. scitex/ai/sklearn/__init__.py +26 -0
  98. scitex/ai/sklearn/clf.py +58 -0
  99. scitex/ai/sklearn/to_sktime.py +100 -0
  100. scitex/ai/training/__init__.py +7 -0
  101. scitex/ai/training/early_stopping.py +150 -0
  102. scitex/ai/training/learning_curve_logger.py +555 -0
  103. scitex/ai/utils/__init__.py +22 -0
  104. scitex/ai/utils/_check_params.py +50 -0
  105. scitex/ai/utils/_default_dataset.py +46 -0
  106. scitex/ai/utils/_format_samples_for_sktime.py +26 -0
  107. scitex/ai/utils/_label_encoder.py +134 -0
  108. scitex/ai/utils/_merge_labels.py +22 -0
  109. scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
  110. scitex/ai/utils/_under_sample.py +51 -0
  111. scitex/ai/utils/_verify_n_gpus.py +16 -0
  112. scitex/ai/utils/grid_search.py +148 -0
  113. scitex/context/__init__.py +9 -0
  114. scitex/context/_suppress_output.py +38 -0
  115. scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
  116. scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
  117. scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
  118. scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
  119. scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
  120. scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
  121. scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
  122. scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
  123. scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
  124. scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
  125. scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
  126. scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
  127. scitex/db/_BaseMixins/__init__.py +30 -0
  128. scitex/db/_PostgreSQL.py +126 -0
  129. scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
  130. scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
  131. scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
  132. scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
  133. scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
  134. scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
  135. scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
  136. scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
  137. scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
  138. scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
  139. scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
  140. scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
  141. scitex/db/_PostgreSQLMixins/__init__.py +34 -0
  142. scitex/db/_SQLite3.py +2136 -0
  143. scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
  144. scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
  145. scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
  146. scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
  147. scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
  148. scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
  149. scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
  150. scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
  151. scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
  152. scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
  153. scitex/db/_SQLite3Mixins/__init__.py +30 -0
  154. scitex/db/__init__.py +14 -0
  155. scitex/db/_delete_duplicates.py +397 -0
  156. scitex/db/_inspect.py +163 -0
  157. scitex/decorators/__init__.py +54 -0
  158. scitex/decorators/_auto_order.py +172 -0
  159. scitex/decorators/_batch_fn.py +127 -0
  160. scitex/decorators/_cache_disk.py +32 -0
  161. scitex/decorators/_cache_mem.py +12 -0
  162. scitex/decorators/_combined.py +98 -0
  163. scitex/decorators/_converters.py +282 -0
  164. scitex/decorators/_deprecated.py +26 -0
  165. scitex/decorators/_not_implemented.py +30 -0
  166. scitex/decorators/_numpy_fn.py +86 -0
  167. scitex/decorators/_pandas_fn.py +121 -0
  168. scitex/decorators/_preserve_doc.py +19 -0
  169. scitex/decorators/_signal_fn.py +95 -0
  170. scitex/decorators/_timeout.py +55 -0
  171. scitex/decorators/_torch_fn.py +136 -0
  172. scitex/decorators/_wrap.py +39 -0
  173. scitex/decorators/_xarray_fn.py +88 -0
  174. scitex/dev/__init__.py +15 -0
  175. scitex/dev/_analyze_code_flow.py +284 -0
  176. scitex/dev/_reload.py +59 -0
  177. scitex/dict/_DotDict.py +442 -0
  178. scitex/dict/__init__.py +18 -0
  179. scitex/dict/_listed_dict.py +42 -0
  180. scitex/dict/_pop_keys.py +36 -0
  181. scitex/dict/_replace.py +13 -0
  182. scitex/dict/_safe_merge.py +62 -0
  183. scitex/dict/_to_str.py +32 -0
  184. scitex/dsp/__init__.py +72 -0
  185. scitex/dsp/_crop.py +122 -0
  186. scitex/dsp/_demo_sig.py +331 -0
  187. scitex/dsp/_detect_ripples.py +212 -0
  188. scitex/dsp/_ensure_3d.py +18 -0
  189. scitex/dsp/_hilbert.py +78 -0
  190. scitex/dsp/_listen.py +702 -0
  191. scitex/dsp/_misc.py +30 -0
  192. scitex/dsp/_mne.py +32 -0
  193. scitex/dsp/_modulation_index.py +79 -0
  194. scitex/dsp/_pac.py +319 -0
  195. scitex/dsp/_psd.py +102 -0
  196. scitex/dsp/_resample.py +65 -0
  197. scitex/dsp/_time.py +36 -0
  198. scitex/dsp/_transform.py +68 -0
  199. scitex/dsp/_wavelet.py +212 -0
  200. scitex/dsp/add_noise.py +111 -0
  201. scitex/dsp/example.py +253 -0
  202. scitex/dsp/filt.py +155 -0
  203. scitex/dsp/norm.py +18 -0
  204. scitex/dsp/params.py +51 -0
  205. scitex/dsp/reference.py +43 -0
  206. scitex/dsp/template.py +25 -0
  207. scitex/dsp/utils/__init__.py +15 -0
  208. scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
  209. scitex/dsp/utils/_ensure_3d.py +18 -0
  210. scitex/dsp/utils/_ensure_even_len.py +10 -0
  211. scitex/dsp/utils/_zero_pad.py +48 -0
  212. scitex/dsp/utils/filter.py +408 -0
  213. scitex/dsp/utils/pac.py +177 -0
  214. scitex/dt/__init__.py +8 -0
  215. scitex/dt/_linspace.py +130 -0
  216. scitex/etc/__init__.py +15 -0
  217. scitex/etc/wait_key.py +34 -0
  218. scitex/gen/_DimHandler.py +196 -0
  219. scitex/gen/_TimeStamper.py +244 -0
  220. scitex/gen/__init__.py +95 -0
  221. scitex/gen/_alternate_kwarg.py +13 -0
  222. scitex/gen/_cache.py +11 -0
  223. scitex/gen/_check_host.py +34 -0
  224. scitex/gen/_ci.py +12 -0
  225. scitex/gen/_close.py +222 -0
  226. scitex/gen/_embed.py +78 -0
  227. scitex/gen/_inspect_module.py +257 -0
  228. scitex/gen/_is_ipython.py +12 -0
  229. scitex/gen/_less.py +48 -0
  230. scitex/gen/_list_packages.py +139 -0
  231. scitex/gen/_mat2py.py +88 -0
  232. scitex/gen/_norm.py +170 -0
  233. scitex/gen/_paste.py +18 -0
  234. scitex/gen/_print_config.py +84 -0
  235. scitex/gen/_shell.py +48 -0
  236. scitex/gen/_src.py +111 -0
  237. scitex/gen/_start.py +451 -0
  238. scitex/gen/_symlink.py +55 -0
  239. scitex/gen/_symlog.py +27 -0
  240. scitex/gen/_tee.py +238 -0
  241. scitex/gen/_title2path.py +60 -0
  242. scitex/gen/_title_case.py +88 -0
  243. scitex/gen/_to_even.py +84 -0
  244. scitex/gen/_to_odd.py +34 -0
  245. scitex/gen/_to_rank.py +39 -0
  246. scitex/gen/_transpose.py +37 -0
  247. scitex/gen/_type.py +78 -0
  248. scitex/gen/_var_info.py +73 -0
  249. scitex/gen/_wrap.py +17 -0
  250. scitex/gen/_xml2dict.py +76 -0
  251. scitex/gen/misc.py +730 -0
  252. scitex/gen/path.py +0 -0
  253. scitex/general/__init__.py +5 -0
  254. scitex/gists/_SigMacro_processFigure_S.py +128 -0
  255. scitex/gists/_SigMacro_toBlue.py +172 -0
  256. scitex/gists/__init__.py +12 -0
  257. scitex/io/_H5Explorer.py +292 -0
  258. scitex/io/__init__.py +82 -0
  259. scitex/io/_cache.py +101 -0
  260. scitex/io/_flush.py +24 -0
  261. scitex/io/_glob.py +103 -0
  262. scitex/io/_json2md.py +113 -0
  263. scitex/io/_load.py +168 -0
  264. scitex/io/_load_configs.py +146 -0
  265. scitex/io/_load_modules/__init__.py +38 -0
  266. scitex/io/_load_modules/_catboost.py +66 -0
  267. scitex/io/_load_modules/_con.py +20 -0
  268. scitex/io/_load_modules/_db.py +24 -0
  269. scitex/io/_load_modules/_docx.py +42 -0
  270. scitex/io/_load_modules/_eeg.py +110 -0
  271. scitex/io/_load_modules/_hdf5.py +196 -0
  272. scitex/io/_load_modules/_image.py +19 -0
  273. scitex/io/_load_modules/_joblib.py +19 -0
  274. scitex/io/_load_modules/_json.py +18 -0
  275. scitex/io/_load_modules/_markdown.py +103 -0
  276. scitex/io/_load_modules/_matlab.py +37 -0
  277. scitex/io/_load_modules/_numpy.py +39 -0
  278. scitex/io/_load_modules/_optuna.py +155 -0
  279. scitex/io/_load_modules/_pandas.py +69 -0
  280. scitex/io/_load_modules/_pdf.py +31 -0
  281. scitex/io/_load_modules/_pickle.py +24 -0
  282. scitex/io/_load_modules/_torch.py +16 -0
  283. scitex/io/_load_modules/_txt.py +126 -0
  284. scitex/io/_load_modules/_xml.py +49 -0
  285. scitex/io/_load_modules/_yaml.py +23 -0
  286. scitex/io/_mv_to_tmp.py +19 -0
  287. scitex/io/_path.py +286 -0
  288. scitex/io/_reload.py +78 -0
  289. scitex/io/_save.py +539 -0
  290. scitex/io/_save_modules/__init__.py +66 -0
  291. scitex/io/_save_modules/_catboost.py +22 -0
  292. scitex/io/_save_modules/_csv.py +89 -0
  293. scitex/io/_save_modules/_excel.py +49 -0
  294. scitex/io/_save_modules/_hdf5.py +249 -0
  295. scitex/io/_save_modules/_html.py +48 -0
  296. scitex/io/_save_modules/_image.py +140 -0
  297. scitex/io/_save_modules/_joblib.py +25 -0
  298. scitex/io/_save_modules/_json.py +25 -0
  299. scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
  300. scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
  301. scitex/io/_save_modules/_matlab.py +24 -0
  302. scitex/io/_save_modules/_mp4.py +29 -0
  303. scitex/io/_save_modules/_numpy.py +57 -0
  304. scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
  305. scitex/io/_save_modules/_pickle.py +45 -0
  306. scitex/io/_save_modules/_plotly.py +27 -0
  307. scitex/io/_save_modules/_text.py +23 -0
  308. scitex/io/_save_modules/_torch.py +26 -0
  309. scitex/io/_save_modules/_yaml.py +29 -0
  310. scitex/life/__init__.py +10 -0
  311. scitex/life/_monitor_rain.py +49 -0
  312. scitex/linalg/__init__.py +17 -0
  313. scitex/linalg/_distance.py +63 -0
  314. scitex/linalg/_geometric_median.py +64 -0
  315. scitex/linalg/_misc.py +73 -0
  316. scitex/nn/_AxiswiseDropout.py +27 -0
  317. scitex/nn/_BNet.py +126 -0
  318. scitex/nn/_BNet_Res.py +164 -0
  319. scitex/nn/_ChannelGainChanger.py +44 -0
  320. scitex/nn/_DropoutChannels.py +50 -0
  321. scitex/nn/_Filters.py +489 -0
  322. scitex/nn/_FreqGainChanger.py +110 -0
  323. scitex/nn/_GaussianFilter.py +48 -0
  324. scitex/nn/_Hilbert.py +111 -0
  325. scitex/nn/_MNet_1000.py +157 -0
  326. scitex/nn/_ModulationIndex.py +221 -0
  327. scitex/nn/_PAC.py +414 -0
  328. scitex/nn/_PSD.py +40 -0
  329. scitex/nn/_ResNet1D.py +120 -0
  330. scitex/nn/_SpatialAttention.py +25 -0
  331. scitex/nn/_Spectrogram.py +161 -0
  332. scitex/nn/_SwapChannels.py +50 -0
  333. scitex/nn/_TransposeLayer.py +19 -0
  334. scitex/nn/_Wavelet.py +183 -0
  335. scitex/nn/__init__.py +63 -0
  336. scitex/os/__init__.py +8 -0
  337. scitex/os/_mv.py +50 -0
  338. scitex/parallel/__init__.py +8 -0
  339. scitex/parallel/_run.py +151 -0
  340. scitex/path/__init__.py +33 -0
  341. scitex/path/_clean.py +52 -0
  342. scitex/path/_find.py +108 -0
  343. scitex/path/_get_module_path.py +51 -0
  344. scitex/path/_get_spath.py +35 -0
  345. scitex/path/_getsize.py +18 -0
  346. scitex/path/_increment_version.py +87 -0
  347. scitex/path/_mk_spath.py +51 -0
  348. scitex/path/_path.py +19 -0
  349. scitex/path/_split.py +23 -0
  350. scitex/path/_this_path.py +19 -0
  351. scitex/path/_version.py +101 -0
  352. scitex/pd/__init__.py +41 -0
  353. scitex/pd/_find_indi.py +126 -0
  354. scitex/pd/_find_pval.py +113 -0
  355. scitex/pd/_force_df.py +154 -0
  356. scitex/pd/_from_xyz.py +71 -0
  357. scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
  358. scitex/pd/_melt_cols.py +81 -0
  359. scitex/pd/_merge_columns.py +221 -0
  360. scitex/pd/_mv.py +63 -0
  361. scitex/pd/_replace.py +62 -0
  362. scitex/pd/_round.py +93 -0
  363. scitex/pd/_slice.py +63 -0
  364. scitex/pd/_sort.py +91 -0
  365. scitex/pd/_to_numeric.py +53 -0
  366. scitex/pd/_to_xy.py +59 -0
  367. scitex/pd/_to_xyz.py +110 -0
  368. scitex/plt/__init__.py +36 -0
  369. scitex/plt/_subplots/_AxesWrapper.py +182 -0
  370. scitex/plt/_subplots/_AxisWrapper.py +249 -0
  371. scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
  372. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
  373. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
  374. scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
  375. scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
  376. scitex/plt/_subplots/_FigWrapper.py +226 -0
  377. scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
  378. scitex/plt/_subplots/__init__.py +111 -0
  379. scitex/plt/_subplots/_export_as_csv.py +232 -0
  380. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
  381. scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
  382. scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
  383. scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
  384. scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
  385. scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
  386. scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
  387. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
  388. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
  389. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
  390. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
  391. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
  392. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
  393. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
  394. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
  395. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
  396. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
  397. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
  398. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
  399. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
  400. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
  401. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
  402. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
  403. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
  404. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
  405. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
  406. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
  407. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
  408. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
  409. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
  410. scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
  411. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
  412. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
  413. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
  414. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
  415. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
  416. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
  417. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
  418. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
  419. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
  420. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
  421. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
  422. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
  423. scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
  424. scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
  425. scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
  426. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
  427. scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
  428. scitex/plt/_tpl.py +28 -0
  429. scitex/plt/ax/__init__.py +114 -0
  430. scitex/plt/ax/_plot/__init__.py +53 -0
  431. scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
  432. scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
  433. scitex/plt/ax/_plot/_plot_cube.py +57 -0
  434. scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
  435. scitex/plt/ax/_plot/_plot_fillv.py +55 -0
  436. scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
  437. scitex/plt/ax/_plot/_plot_image.py +94 -0
  438. scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
  439. scitex/plt/ax/_plot/_plot_raster.py +172 -0
  440. scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
  441. scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
  442. scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
  443. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
  444. scitex/plt/ax/_plot/_plot_violin.py +343 -0
  445. scitex/plt/ax/_style/__init__.py +38 -0
  446. scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
  447. scitex/plt/ax/_style/_add_panel.py +92 -0
  448. scitex/plt/ax/_style/_extend.py +64 -0
  449. scitex/plt/ax/_style/_force_aspect.py +37 -0
  450. scitex/plt/ax/_style/_format_label.py +23 -0
  451. scitex/plt/ax/_style/_hide_spines.py +84 -0
  452. scitex/plt/ax/_style/_map_ticks.py +182 -0
  453. scitex/plt/ax/_style/_rotate_labels.py +215 -0
  454. scitex/plt/ax/_style/_sci_note.py +279 -0
  455. scitex/plt/ax/_style/_set_log_scale.py +299 -0
  456. scitex/plt/ax/_style/_set_meta.py +261 -0
  457. scitex/plt/ax/_style/_set_n_ticks.py +37 -0
  458. scitex/plt/ax/_style/_set_size.py +16 -0
  459. scitex/plt/ax/_style/_set_supxyt.py +116 -0
  460. scitex/plt/ax/_style/_set_ticks.py +276 -0
  461. scitex/plt/ax/_style/_set_xyt.py +121 -0
  462. scitex/plt/ax/_style/_share_axes.py +264 -0
  463. scitex/plt/ax/_style/_shift.py +139 -0
  464. scitex/plt/ax/_style/_show_spines.py +333 -0
  465. scitex/plt/color/_PARAMS.py +70 -0
  466. scitex/plt/color/__init__.py +52 -0
  467. scitex/plt/color/_add_hue_col.py +41 -0
  468. scitex/plt/color/_colors.py +205 -0
  469. scitex/plt/color/_get_colors_from_cmap.py +134 -0
  470. scitex/plt/color/_interpolate.py +29 -0
  471. scitex/plt/color/_vizualize_colors.py +54 -0
  472. scitex/plt/utils/__init__.py +44 -0
  473. scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
  474. scitex/plt/utils/_calc_nice_ticks.py +101 -0
  475. scitex/plt/utils/_close.py +68 -0
  476. scitex/plt/utils/_colorbar.py +96 -0
  477. scitex/plt/utils/_configure_mpl.py +295 -0
  478. scitex/plt/utils/_histogram_utils.py +132 -0
  479. scitex/plt/utils/_im2grid.py +70 -0
  480. scitex/plt/utils/_is_valid_axis.py +78 -0
  481. scitex/plt/utils/_mk_colorbar.py +65 -0
  482. scitex/plt/utils/_mk_patches.py +26 -0
  483. scitex/plt/utils/_scientific_captions.py +638 -0
  484. scitex/plt/utils/_scitex_config.py +223 -0
  485. scitex/reproduce/__init__.py +14 -0
  486. scitex/reproduce/_fix_seeds.py +45 -0
  487. scitex/reproduce/_gen_ID.py +55 -0
  488. scitex/reproduce/_gen_timestamp.py +35 -0
  489. scitex/res/__init__.py +5 -0
  490. scitex/resource/__init__.py +13 -0
  491. scitex/resource/_get_processor_usages.py +281 -0
  492. scitex/resource/_get_specs.py +280 -0
  493. scitex/resource/_log_processor_usages.py +190 -0
  494. scitex/resource/_utils/__init__.py +31 -0
  495. scitex/resource/_utils/_get_env_info.py +481 -0
  496. scitex/resource/limit_ram.py +33 -0
  497. scitex/scholar/__init__.py +24 -0
  498. scitex/scholar/_local_search.py +454 -0
  499. scitex/scholar/_paper.py +244 -0
  500. scitex/scholar/_pdf_downloader.py +325 -0
  501. scitex/scholar/_search.py +393 -0
  502. scitex/scholar/_vector_search.py +370 -0
  503. scitex/scholar/_web_sources.py +457 -0
  504. scitex/stats/__init__.py +31 -0
  505. scitex/stats/_calc_partial_corr.py +17 -0
  506. scitex/stats/_corr_test_multi.py +94 -0
  507. scitex/stats/_corr_test_wrapper.py +115 -0
  508. scitex/stats/_describe_wrapper.py +90 -0
  509. scitex/stats/_multiple_corrections.py +63 -0
  510. scitex/stats/_nan_stats.py +93 -0
  511. scitex/stats/_p2stars.py +116 -0
  512. scitex/stats/_p2stars_wrapper.py +56 -0
  513. scitex/stats/_statistical_tests.py +73 -0
  514. scitex/stats/desc/__init__.py +40 -0
  515. scitex/stats/desc/_describe.py +189 -0
  516. scitex/stats/desc/_nan.py +289 -0
  517. scitex/stats/desc/_real.py +94 -0
  518. scitex/stats/multiple/__init__.py +14 -0
  519. scitex/stats/multiple/_bonferroni_correction.py +72 -0
  520. scitex/stats/multiple/_fdr_correction.py +400 -0
  521. scitex/stats/multiple/_multicompair.py +28 -0
  522. scitex/stats/tests/__corr_test.py +277 -0
  523. scitex/stats/tests/__corr_test_multi.py +343 -0
  524. scitex/stats/tests/__corr_test_single.py +277 -0
  525. scitex/stats/tests/__init__.py +22 -0
  526. scitex/stats/tests/_brunner_munzel_test.py +192 -0
  527. scitex/stats/tests/_nocorrelation_test.py +28 -0
  528. scitex/stats/tests/_smirnov_grubbs.py +98 -0
  529. scitex/str/__init__.py +113 -0
  530. scitex/str/_clean_path.py +75 -0
  531. scitex/str/_color_text.py +52 -0
  532. scitex/str/_decapitalize.py +58 -0
  533. scitex/str/_factor_out_digits.py +281 -0
  534. scitex/str/_format_plot_text.py +498 -0
  535. scitex/str/_grep.py +48 -0
  536. scitex/str/_latex.py +155 -0
  537. scitex/str/_latex_fallback.py +471 -0
  538. scitex/str/_mask_api.py +39 -0
  539. scitex/str/_mask_api_key.py +8 -0
  540. scitex/str/_parse.py +158 -0
  541. scitex/str/_print_block.py +47 -0
  542. scitex/str/_print_debug.py +68 -0
  543. scitex/str/_printc.py +62 -0
  544. scitex/str/_readable_bytes.py +38 -0
  545. scitex/str/_remove_ansi.py +23 -0
  546. scitex/str/_replace.py +134 -0
  547. scitex/str/_search.py +125 -0
  548. scitex/str/_squeeze_space.py +36 -0
  549. scitex/tex/__init__.py +10 -0
  550. scitex/tex/_preview.py +103 -0
  551. scitex/tex/_to_vec.py +116 -0
  552. scitex/torch/__init__.py +18 -0
  553. scitex/torch/_apply_to.py +34 -0
  554. scitex/torch/_nan_funcs.py +77 -0
  555. scitex/types/_ArrayLike.py +44 -0
  556. scitex/types/_ColorLike.py +21 -0
  557. scitex/types/__init__.py +14 -0
  558. scitex/types/_is_listed_X.py +70 -0
  559. scitex/utils/__init__.py +22 -0
  560. scitex/utils/_compress_hdf5.py +116 -0
  561. scitex/utils/_email.py +120 -0
  562. scitex/utils/_grid.py +148 -0
  563. scitex/utils/_notify.py +247 -0
  564. scitex/utils/_search.py +121 -0
  565. scitex/web/__init__.py +38 -0
  566. scitex/web/_search_pubmed.py +438 -0
  567. scitex/web/_summarize_url.py +158 -0
  568. scitex-2.0.0.dist-info/METADATA +307 -0
  569. scitex-2.0.0.dist-info/RECORD +572 -0
  570. scitex-2.0.0.dist-info/WHEEL +6 -0
  571. scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
  572. scitex-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,150 @@
1
+ #!/usr/bin/env python3
2
+ # Time-stamp: "2024-09-07 01:09:38 (ywatanabe)"
3
+
4
+ import os
5
+
6
+ import scitex
7
+ import numpy as np
8
+
9
+
10
class EarlyStopping:
    """
    Early stops the training if the validation score doesn't improve after a given patience period.

    Each call compares the current score with the best score seen so far.
    An improvement by more than ``|delta|`` (in the direction given by
    ``direction``) saves every model's ``state_dict()`` and resets the
    patience counter. Otherwise the counter is incremented, and the call
    returns True once the counter reaches ``patience``.
    """

    def __init__(self, patience=7, verbose=False, delta=1e-5, direction="minimize"):
        """
        Args:
            patience (int): Number of consecutive non-improving calls tolerated
                before an early stop is signalled.
                Default: 7
            verbose (bool): If True, prints a message on every best-score
                update and on every non-improving call.
                Default: False
            delta (float): Minimum change in the monitored quantity to qualify
                as an improvement; its absolute value is used.
                Default: 1e-5
            direction (str): "minimize" if lower scores are better. Any other
                value is treated as "maximize" (higher scores are better).
                Default: "minimize"
        """
        self.patience = patience
        self.verbose = verbose
        self.direction = direction

        self.delta = delta

        # default state
        self.counter = 0
        # Start from the worst possible score so that the first finite score
        # always counts as an improvement (and is therefore saved).
        self.best_score = np.inf if direction == "minimize" else -np.inf
        self.best_i_global = None
        self.models_spaths_dict = {}

    def is_best(self, val_score):
        """Return True if ``val_score`` beats ``best_score`` by more than ``|delta|``."""
        margin = abs(self.delta)
        if self.direction == "minimize":
            return val_score < self.best_score - margin
        return self.best_score + margin < val_score

    def __call__(self, current_score, models_spaths_dict, i_global):
        """
        Update the early-stopping state with a new validation score.

        Args:
            current_score: Latest validation score.
            models_spaths_dict (dict): Mapping of model -> save path. Each
                model's ``state_dict()`` is saved to its path when the score
                improves.
            i_global (int): Global step index, stored as ``best_i_global`` on
                improvement.

        Returns:
            bool: True when training should stop, False otherwise.
        """
        # No special first-call case is needed: best_score starts at +/-inf,
        # so is_best() accepts the first finite score.
        if self.is_best(current_score):
            self.save(current_score, models_spaths_dict, i_global)
            self.counter = 0
            return False

        self.counter += 1
        if self.verbose:
            print(
                f"\nEarlyStopping counter: {self.counter} out of {self.patience}\n"
            )
        if self.counter >= self.patience:
            if self.verbose:
                scitex.str.printc("Early-stopped.", c="yellow")
            return True
        return False

    def save(self, current_score, models_spaths_dict, i_global):
        """Record a new best score and save each model's state_dict to its path."""

        if self.verbose:
            print(
                f"\nUpdate the best score: ({self.best_score:.6f} --> {current_score:.6f})"
            )

        self.best_score = current_score
        self.best_i_global = i_global

        # Keys are models, values are their save paths.
        for model, spath in models_spaths_dict.items():
            scitex.io.save(model.state_dict(), spath)

        self.models_spaths_dict = models_spaths_dict
82
+
83
+
84
if __name__ == "__main__":
    pass
    # Usage sketch, kept as comments because it relies on objects defined
    # elsewhere (data-loader factory `dlf`, `model`, multi-task loss `mtl`,
    # `optimizer`, `device`, `utils.base_step`, `merged_conf`, `args`, ...).
    #
    # # starts the current fold's loop
    # i_global = 0
    # lc_logger = scitex.ml.LearningCurveLogger()
    # early_stopping = utils.EarlyStopping(patience=50, verbose=True)
    # for i_epoch, epoch in enumerate(tqdm(range(merged_conf["MAX_EPOCHS"]))):
    #
    #     dlf.fill(i_fold, reset_fill_counter=False)
    #
    #     step_str = "Validation"
    #     for i_batch, batch in enumerate(dlf.dl_val):
    #         _, loss_diag_val = utils.base_step(
    #             step_str,
    #             model,
    #             mtl,
    #             batch,
    #             device,
    #             i_fold,
    #             i_epoch,
    #             i_batch,
    #             i_global,
    #             lc_logger,
    #             no_mtl=args.no_mtl,
    #             print_batch_interval=False,
    #         )
    #     lc_logger.print(step_str)
    #
    #     step_str = "Training"
    #     for i_batch, batch in enumerate(dlf.dl_tra):
    #         optimizer.zero_grad()
    #         loss, _ = utils.base_step(
    #             step_str,
    #             model,
    #             mtl,
    #             batch,
    #             device,
    #             i_fold,
    #             i_epoch,
    #             i_batch,
    #             i_global,
    #             lc_logger,
    #             no_mtl=args.no_mtl,
    #             print_batch_interval=False,
    #         )
    #         loss.backward()
    #         optimizer.step()
    #         i_global += 1
    #     lc_logger.print(step_str)
    #
    #     bACC_val = np.array(lc_logger.logged_dict["Validation"]["bACC_diag_plot"])[
    #         np.array(lc_logger.logged_dict["Validation"]["i_epoch"]) == i_epoch
    #     ].mean()
    #
    #     model_spath = (
    #         merged_conf["sdir"]
    #         + f"checkpoints/model_fold#{i_fold}_epoch#{i_epoch:03d}.pth"
    #     )
    #     mtl_spath = model_spath.replace("model_fold", "mtl_fold")
    #     # EarlyStopping.save() iterates `for model, spath in ...items()`,
    #     # so the mapping must be model -> save path.
    #     models_spaths_dict = {model: model_spath, mtl: mtl_spath}
    #
    #     # EarlyStopping.__call__(current_score, models_spaths_dict, i_global)
    #     # returns True once the score has not improved for `patience` calls.
    #     should_stop = early_stopping(loss_diag_val, models_spaths_dict, i_global)
    #     # To monitor balanced accuracy instead, pass -bACC_val, or construct
    #     # EarlyStopping(..., direction="maximize") and pass bACC_val.
    #
    #     if should_stop:
    #         print("Early stopping")
    #         break
@@ -0,0 +1,555 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-11-20 08:49:50 (ywatanabe)"
4
+ # File: ./scitex_repo/src/scitex/ai/_LearningCurveLogger.py
5
+
6
"""
Functionality:
    - Records and visualizes learning curves during model training
    - Supports tracking of multiple metrics across training/validation/test phases
    - Generates plots showing training progress over iterations and epochs

Input:
    - Training metrics dictionary containing loss, accuracy, predictions etc.
    - Step information (Training/Validation/Test)

Output:
    - Learning curve plots
    - Dataframes with recorded metrics
    - Training progress prints

Prerequisites:
    - PyTorch
    - matplotlib
    - pandas
    - numpy
    - scikit-learn and torchvision (only for the ``__main__`` demo)
"""

# The docstring above must be the first statement of the module to become its
# ``__doc__``; it previously followed this assignment and was discarded.
# Absolute path on the original author's machine; informational only.
THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/_LearningCurveLogger.py"
30
+
31
+ import re as _re
32
+ from collections import defaultdict as _defaultdict
33
+ from pprint import pprint as _pprint
34
+ from typing import Dict as _Dict
35
+ from typing import List as _List
36
+ from typing import Union as _Union
37
+ from typing import Optional as _Optional
38
+ from typing import Any as _Any
39
+
40
+ import matplotlib as _matplotlib
41
+ import pandas as _pd
42
+ import numpy as _np
43
+ import warnings as _warnings
44
+ import torch as _torch
45
+
46
+
47
class LearningCurveLogger:
    """Records and visualizes learning metrics during model training.

    Every call appends one value per key to ``logged_dict[step][key]``, so
    each key holds a list with one entry per logged batch. Keys ending in
    ``"_plot"`` are the scalar metrics drawn by :meth:`plot_learning_curves`.
    ``"i_global"`` and ``"i_epoch"`` are the grouping keys used by
    :meth:`plot_learning_curves` and :meth:`print` respectively, so they must
    be logged for those methods to work.

    Example
    -------
    >>> logger = LearningCurveLogger()
    >>> metrics = {
    ...     "loss_plot": 0.5,
    ...     "balanced_ACC_plot": 0.8,
    ...     "pred_proba": pred_proba,
    ...     "true_class": labels,
    ...     "i_fold": 0,
    ...     "i_epoch": 1,
    ...     "i_global": 100,
    ... }
    >>> logger(metrics, "Training")
    >>> fig = logger.plot_learning_curves(plt)
    """

    def __init__(self) -> None:
        # step name (e.g. "Training") -> metric name -> list of logged values.
        self.logged_dict: _Dict[str, _Dict] = _defaultdict(dict)

    def __call__(self, dict_to_log: _Dict[str, _Any], step: str) -> None:
        """Logs metrics for a training step.

        Parameters
        ----------
        dict_to_log : Dict[str, Any]
            Dictionary containing metrics to log. The deprecated key
            ``"gt_label"`` is stored under ``"true_class"`` and triggers a
            DeprecationWarning; the caller's dictionary is left unmodified.
        step : str
            Phase of training ('Training', 'Validation', or 'Test')
        """
        if "gt_label" in dict_to_log:
            # Warn only when the deprecated key is actually used; stacklevel=2
            # attributes the warning to the caller instead of this module.
            _warnings.warn(
                '\n"gt_label" will be removed in the future. Please use "true_class" instead.\n',
                DeprecationWarning,
                stacklevel=2,
            )
            # Rename on a shallow copy so the caller's dict is not mutated.
            dict_to_log = dict(dict_to_log)
            dict_to_log["true_class"] = dict_to_log.pop("gt_label")

        step_log = self.logged_dict[step]
        for key, value in dict_to_log.items():
            step_log.setdefault(key, []).append(value)

    @property
    def dfs(self) -> _Dict[str, _pd.DataFrame]:
        """Returns DataFrames of logged metrics.

        Returns
        -------
        Dict[str, pd.DataFrame]
            One un-aggregated DataFrame per step (one row per logged call).
        """
        return self._to_dfs_pivot(
            self.logged_dict,
            pivot_column=None,
        )

    def get_x_of_i_epoch(self, x: str, step: str, i_epoch: int) -> _np.ndarray:
        """Gets metric values for a specific epoch.

        Parameters
        ----------
        x : str
            Name of metric to retrieve
        step : str
            Training phase
        i_epoch : int
            Epoch number

        Returns
        -------
        np.ndarray
            Array of metric values logged during the specified epoch
        """
        indi = _np.array(self.logged_dict[step]["i_epoch"]) == i_epoch
        x_all_arr = _np.array(self.logged_dict[step][x])
        assert len(indi) == len(x_all_arr), (
            f"'{x}' and 'i_epoch' have different lengths in step '{step}'"
        )
        return x_all_arr[indi]

    def plot_learning_curves(
        self,
        plt: _Any,
        plt_config_dict: _Optional[_Dict] = None,
        title: _Optional[str] = None,
        max_n_ticks: int = 4,
        linewidth: float = 1,
        scattersize: float = 50,
    ) -> "_matplotlib.figure.Figure":
        """Plots learning curves from logged metrics.

        One subplot is drawn per ``*_plot`` key, with iteration number
        (``i_global``) on the x-axis. Steps whose name starts with
        "Train"/"train" are drawn as lines, with dashed vertical markers at
        each epoch start. "Validation" and "Test" are drawn as scatter points.

        Parameters
        ----------
        plt : matplotlib.pyplot
            Matplotlib pyplot object
        plt_config_dict : dict, optional
            Accepted for backward compatibility; currently ignored.
        title : str, optional
            Plot title
        max_n_ticks : int
            Maximum number of ticks on axes
        linewidth : float
            Width of plot lines
        scattersize : float
            Size of scatter points

        Returns
        -------
        matplotlib.figure.Figure
            Figure containing learning curves
        """
        # plt_config_dict is intentionally unused: styling via configure_mpl
        # is not wired up in this module.
        COLOR_DICT = {
            "Training": "blue",
            "Validation": "green",
            "Test": "red",
        }

        self.dfs_pivot_i_global = self._to_dfs_pivot(
            self.logged_dict, pivot_column="i_global"
        )

        keys_to_plot = self._find_keys_to_plot(self.logged_dict)

        if not keys_to_plot:
            # No keys to plot, return empty figure
            fig, ax = plt.subplots(1, 1)
            ax.text(0.5, 0.5, "No data to plot", ha="center", va="center")
            return fig

        fig, axes = plt.subplots(len(keys_to_plot), 1, sharex=True, sharey=False)
        if len(keys_to_plot) == 1:
            axes = [axes]  # plt.subplots returns a bare Axes for a 1x1 grid
        axes[-1].set_xlabel("Iteration#")
        if title is not None:
            fig.text(0.5, 0.95, title, ha="center")

        epoch_line_color = _matplotlib.colors.to_rgba("gray", alpha=0.5)

        for ax, plt_k in zip(axes, keys_to_plot):
            ax.set_ylabel(self._rename_if_key_to_plot(plt_k))
            ax.xaxis.set_major_locator(_matplotlib.ticker.MaxNLocator(max_n_ticks))
            ax.yaxis.set_major_locator(_matplotlib.ticker.MaxNLocator(max_n_ticks))

            if _re.search("[aA][cC][cC]", plt_k):
                ax.set_ylim(0, 1)
                ax.set_yticks([0, 0.5, 1.0])

            for step_k, df in self.dfs_pivot_i_global.items():
                if plt_k not in df.columns:
                    # This step did not log the metric; nothing to draw.
                    continue
                # Unknown step names fall back to matplotlib's color cycle.
                color = COLOR_DICT.get(step_k)

                if _re.search("^[Tt]rain", step_k):
                    ax.plot(
                        df.index,
                        df[plt_k],
                        label=step_k,
                        color=color,
                        linewidth=linewidth,
                    )
                    ax.legend()

                    if "i_epoch" in df.columns:
                        for i_global_epoch_start in self._epoch_start_indices(
                            df["i_epoch"]
                        ):
                            ax.axvline(
                                x=i_global_epoch_start,
                                linestyle="--",
                                color=epoch_line_color,
                            )

                if step_k in ("Validation", "Test"):
                    ax.scatter(
                        df.index,
                        df[plt_k],
                        label=step_k,
                        color=color,
                        s=scattersize,
                        alpha=0.9,
                    )
                    ax.legend()

        return fig

    def print(self, step: str) -> None:
        """Prints per-epoch means of the numeric metrics logged for a step.

        Parameters
        ----------
        step : str
            Training phase to print metrics for. A KeyError is raised if
            nothing has been logged for it.
        """
        df_pivot_i_epoch_step = self._to_dfs_pivot(
            self.logged_dict, pivot_column="i_epoch"
        )[step]
        df_pivot_i_epoch_step.columns = self._rename_if_key_to_plot(
            df_pivot_i_epoch_step.columns
        )
        print("\n----------------------------------------\n")
        print(f"\n{step}: (mean of batches)\n")
        _pprint(df_pivot_i_epoch_step)
        print("\n----------------------------------------\n")

    @staticmethod
    def _epoch_start_indices(i_epoch: _pd.Series) -> _List:
        """Return the index labels at which the epoch number changes.

        The first row always counts as an epoch start. Any change in value
        counts as a new epoch, not only a change of exactly one.
        """
        changed = i_epoch.ne(i_epoch.shift())
        return list(i_epoch.index[changed])

    @staticmethod
    def _find_keys_to_plot(logged_dict: _Dict) -> _List[str]:
        """Find metrics to plot from logged dictionary.

        Only the keys of the first logged step are inspected.

        Parameters
        ----------
        logged_dict : Dict
            Dictionary of logged metrics

        Returns
        -------
        List[str]
            Metric names ending in ``"_plot"``, or an empty list when nothing
            has been logged yet.
        """
        first_step = next(iter(logged_dict), None)
        if first_step is None:
            return []
        return [k for k in logged_dict[first_step] if _re.search("_plot$", k)]

    @staticmethod
    def _rename_if_key_to_plot(x: _Union[str, _pd.Index]) -> _Union[str, _pd.Index]:
        """Strip the trailing ``"_plot"`` suffix from metric name(s).

        Parameters
        ----------
        x : str or pd.Index
            Metric name(s) to rename

        Returns
        -------
        str or pd.Index
            Renamed metric name(s). Only a trailing ``"_plot"`` is removed;
            occurrences elsewhere in the name are kept.
        """
        if isinstance(x, str):
            return _re.sub("_plot$", "", x)
        return x.str.replace("_plot$", "", regex=True)

    @staticmethod
    def _to_dfs_pivot(
        logged_dict: _Dict[str, _Dict],
        pivot_column: _Optional[str] = None,
    ) -> _Dict[str, _pd.DataFrame]:
        """Convert logged dictionary to (optionally aggregated) DataFrames.

        Parameters
        ----------
        logged_dict : Dict[str, Dict]
            Dictionary of logged metrics
        pivot_column : str, optional
            Column to group by. When given, numeric columns are averaged per
            group and the group values become the index. Non-numeric columns
            (e.g. per-batch prediction arrays) are dropped because they cannot
            be averaged.

        Returns
        -------
        Dict[str, pd.DataFrame]
            One DataFrame per step
        """
        dfs_pivot = {}
        for step_k, step_log in logged_dict.items():
            df = _pd.DataFrame(step_log)
            if pivot_column is not None:
                df = df.groupby(pivot_column).mean(numeric_only=True)
            dfs_pivot[step_k] = df
        return dfs_pivot
339
+
340
+
341
if __name__ == "__main__":
    # Demo: train a small two-layer network on MNIST for a few epochs while
    # logging Validation / Training / Test metrics, then plot the learning
    # curves. Needs torchvision and scikit-learn; downloads MNIST to /tmp/mnist.
    import sys
    import warnings

    import matplotlib.pyplot as plt
    import scitex  # used for the tee below; previously referenced but never imported
    import torch
    import torch.nn as nn
    from sklearn.metrics import balanced_accuracy_score
    from torch.utils.data import DataLoader, TensorDataset
    from torchvision import datasets

    ################################################################################
    ## Sets tee
    ################################################################################
    sdir = scitex.io.path.mk_spath("")  # "/tmp/sdir/"
    sys.stdout, sys.stderr = scitex.gen.tee(sys, sdir)

    ################################################################################
    ## NN
    ################################################################################
    class Perceptron(nn.Module):
        """Flatten a 28x28 image and map it to 10 logits via two linear layers."""

        def __init__(self):
            super().__init__()
            self.l1 = nn.Linear(28 * 28, 50)
            self.l2 = nn.Linear(50, 10)

        def forward(self, x):
            x = x.view(-1, 28 * 28)
            x = self.l1(x)
            x = self.l2(x)
            return x

    ################################################################################
    ## Prepares demo data
    ################################################################################
    ## Downloads
    _ds_tra_val = datasets.MNIST("/tmp/mnist", train=True, download=True)
    _ds_tes = datasets.MNIST("/tmp/mnist", train=False, download=True)

    ## Training-Validation splitting: first 80% for training, rest for validation
    n_samples = len(_ds_tra_val)  # n_samples is 60000
    train_size = int(n_samples * 0.8)  # train_size is 48000

    def _to_tensor_dataset(images, targets):
        """Scale uint8 pixel values to [0, 1] floats and pair them with targets."""
        return TensorDataset(images.to(torch.float32) / 255.0, targets)

    # Slice the raw tensors so each split only contains its own samples.
    # Going through `Subset(...).dataset` would return the full, unsplit parent
    # dataset, so validation would duplicate the training data.
    ds_tra = _to_tensor_dataset(
        _ds_tra_val.data[:train_size], _ds_tra_val.targets[:train_size]
    )
    ds_val = _to_tensor_dataset(
        _ds_tra_val.data[train_size:], _ds_tra_val.targets[train_size:]
    )
    ds_tes = _to_tensor_dataset(_ds_tes.data, _ds_tes.targets)

    ## to dataloaders
    batch_size = 64
    dl_tra = DataLoader(
        dataset=ds_tra,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    dl_val = DataLoader(
        dataset=ds_val,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
    )

    dl_tes = DataLoader(
        dataset=ds_tes,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
    )

    ################################################################################
    ## Preparation
    ################################################################################
    model = Perceptron()
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    softmax = nn.Softmax(dim=-1)

    ################################################################################
    ## Main
    ################################################################################
    lc_logger = LearningCurveLogger()
    i_global = 0
    i_fold = 0
    max_epochs = 3

    def run_step(step, dataloader, i_epoch, i_global, train=False):
        """Run `model` over `dataloader` and log per-batch metrics under `step`.

        Parameters are updated only when `train` is True. In that case the
        global iteration counter advances once per batch, after the batch is
        logged. Returns the (possibly advanced) counter.
        """
        model.train(train)
        for X, T in dataloader:
            if train:
                optimizer.zero_grad()
            # Evaluation passes do not need an autograd graph.
            with torch.set_grad_enabled(train):
                logits = model(X)
                loss = loss_func(logits, T)
            if train:
                loss.backward()
                optimizer.step()

            pred_proba = softmax(logits.detach())
            pred_class = pred_proba.argmax(dim=-1)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                bACC = balanced_accuracy_score(T, pred_class)

            lc_logger(
                {
                    "loss_plot": float(loss),
                    "balanced_ACC_plot": float(bACC),
                    "pred_proba": pred_proba.cpu().numpy(),
                    "true_class": T.cpu().numpy(),
                    "i_fold": i_fold,
                    "i_epoch": i_epoch,
                    "i_global": i_global,
                },
                step,
            )

            if train:
                i_global += 1

        lc_logger.print(step)
        return i_global

    for i_epoch in range(max_epochs):
        run_step("Validation", dl_val, i_epoch, i_global)
        i_global = run_step("Training", dl_tra, i_epoch, i_global, train=True)
        run_step("Test", dl_tes, i_epoch, i_global)

    plt_config_dict = dict(
        # figsize=(8.7, 10),
        figscale=2.5,
        labelsize=16,
        fontsize=12,
        legendfontsize=12,
        tick_size=0.8,
        tick_width=0.2,
    )

    fig = lc_logger.plot_learning_curves(
        plt,
        plt_config_dict=plt_config_dict,
        title=f"fold#{i_fold}",
        linewidth=1,
        scattersize=50,
    )
    plt.show()
    # scitex.gen.save(fig, sdir + f"fold#{i_fold}.png")

# EOF
+ # EOF