scitex 2.0.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (572)
  1. scitex/__init__.py +73 -0
  2. scitex/__main__.py +89 -0
  3. scitex/__version__.py +14 -0
  4. scitex/_sh.py +59 -0
  5. scitex/ai/_LearningCurveLogger.py +583 -0
  6. scitex/ai/__Classifiers.py +101 -0
  7. scitex/ai/__init__.py +55 -0
  8. scitex/ai/_gen_ai/_Anthropic.py +173 -0
  9. scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
  10. scitex/ai/_gen_ai/_DeepSeek.py +175 -0
  11. scitex/ai/_gen_ai/_Google.py +161 -0
  12. scitex/ai/_gen_ai/_Groq.py +97 -0
  13. scitex/ai/_gen_ai/_Llama.py +142 -0
  14. scitex/ai/_gen_ai/_OpenAI.py +230 -0
  15. scitex/ai/_gen_ai/_PARAMS.py +565 -0
  16. scitex/ai/_gen_ai/_Perplexity.py +191 -0
  17. scitex/ai/_gen_ai/__init__.py +32 -0
  18. scitex/ai/_gen_ai/_calc_cost.py +78 -0
  19. scitex/ai/_gen_ai/_format_output_func.py +183 -0
  20. scitex/ai/_gen_ai/_genai_factory.py +71 -0
  21. scitex/ai/act/__init__.py +8 -0
  22. scitex/ai/act/_define.py +11 -0
  23. scitex/ai/classification/__init__.py +7 -0
  24. scitex/ai/classification/classification_reporter.py +1137 -0
  25. scitex/ai/classification/classifier_server.py +131 -0
  26. scitex/ai/classification/classifiers.py +101 -0
  27. scitex/ai/classification_reporter.py +1161 -0
  28. scitex/ai/classifier_server.py +131 -0
  29. scitex/ai/clustering/__init__.py +11 -0
  30. scitex/ai/clustering/_pca.py +115 -0
  31. scitex/ai/clustering/_umap.py +376 -0
  32. scitex/ai/early_stopping.py +149 -0
  33. scitex/ai/feature_extraction/__init__.py +56 -0
  34. scitex/ai/feature_extraction/vit.py +148 -0
  35. scitex/ai/genai/__init__.py +277 -0
  36. scitex/ai/genai/anthropic.py +177 -0
  37. scitex/ai/genai/anthropic_provider.py +320 -0
  38. scitex/ai/genai/anthropic_refactored.py +109 -0
  39. scitex/ai/genai/auth_manager.py +200 -0
  40. scitex/ai/genai/base_genai.py +336 -0
  41. scitex/ai/genai/base_provider.py +291 -0
  42. scitex/ai/genai/calc_cost.py +78 -0
  43. scitex/ai/genai/chat_history.py +307 -0
  44. scitex/ai/genai/cost_tracker.py +276 -0
  45. scitex/ai/genai/deepseek.py +188 -0
  46. scitex/ai/genai/deepseek_provider.py +251 -0
  47. scitex/ai/genai/format_output_func.py +183 -0
  48. scitex/ai/genai/genai_factory.py +71 -0
  49. scitex/ai/genai/google.py +169 -0
  50. scitex/ai/genai/google_provider.py +228 -0
  51. scitex/ai/genai/groq.py +104 -0
  52. scitex/ai/genai/groq_provider.py +248 -0
  53. scitex/ai/genai/image_processor.py +250 -0
  54. scitex/ai/genai/llama.py +155 -0
  55. scitex/ai/genai/llama_provider.py +214 -0
  56. scitex/ai/genai/mock_provider.py +127 -0
  57. scitex/ai/genai/model_registry.py +304 -0
  58. scitex/ai/genai/openai.py +230 -0
  59. scitex/ai/genai/openai_provider.py +293 -0
  60. scitex/ai/genai/params.py +565 -0
  61. scitex/ai/genai/perplexity.py +202 -0
  62. scitex/ai/genai/perplexity_provider.py +205 -0
  63. scitex/ai/genai/provider_base.py +302 -0
  64. scitex/ai/genai/provider_factory.py +370 -0
  65. scitex/ai/genai/response_handler.py +235 -0
  66. scitex/ai/layer/_Pass.py +21 -0
  67. scitex/ai/layer/__init__.py +10 -0
  68. scitex/ai/layer/_switch.py +8 -0
  69. scitex/ai/loss/_L1L2Losses.py +34 -0
  70. scitex/ai/loss/__init__.py +12 -0
  71. scitex/ai/loss/multi_task_loss.py +47 -0
  72. scitex/ai/metrics/__init__.py +9 -0
  73. scitex/ai/metrics/_bACC.py +51 -0
  74. scitex/ai/metrics/silhoute_score_block.py +496 -0
  75. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
  76. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
  77. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
  78. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
  79. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
  80. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
  81. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
  82. scitex/ai/optim/__init__.py +13 -0
  83. scitex/ai/optim/_get_set.py +31 -0
  84. scitex/ai/optim/_optimizers.py +71 -0
  85. scitex/ai/plt/__init__.py +21 -0
  86. scitex/ai/plt/_conf_mat.py +592 -0
  87. scitex/ai/plt/_learning_curve.py +194 -0
  88. scitex/ai/plt/_optuna_study.py +111 -0
  89. scitex/ai/plt/aucs/__init__.py +2 -0
  90. scitex/ai/plt/aucs/example.py +60 -0
  91. scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
  92. scitex/ai/plt/aucs/roc_auc.py +246 -0
  93. scitex/ai/sampling/undersample.py +29 -0
  94. scitex/ai/sk/__init__.py +11 -0
  95. scitex/ai/sk/_clf.py +58 -0
  96. scitex/ai/sk/_to_sktime.py +100 -0
  97. scitex/ai/sklearn/__init__.py +26 -0
  98. scitex/ai/sklearn/clf.py +58 -0
  99. scitex/ai/sklearn/to_sktime.py +100 -0
  100. scitex/ai/training/__init__.py +7 -0
  101. scitex/ai/training/early_stopping.py +150 -0
  102. scitex/ai/training/learning_curve_logger.py +555 -0
  103. scitex/ai/utils/__init__.py +22 -0
  104. scitex/ai/utils/_check_params.py +50 -0
  105. scitex/ai/utils/_default_dataset.py +46 -0
  106. scitex/ai/utils/_format_samples_for_sktime.py +26 -0
  107. scitex/ai/utils/_label_encoder.py +134 -0
  108. scitex/ai/utils/_merge_labels.py +22 -0
  109. scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
  110. scitex/ai/utils/_under_sample.py +51 -0
  111. scitex/ai/utils/_verify_n_gpus.py +16 -0
  112. scitex/ai/utils/grid_search.py +148 -0
  113. scitex/context/__init__.py +9 -0
  114. scitex/context/_suppress_output.py +38 -0
  115. scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
  116. scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
  117. scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
  118. scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
  119. scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
  120. scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
  121. scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
  122. scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
  123. scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
  124. scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
  125. scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
  126. scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
  127. scitex/db/_BaseMixins/__init__.py +30 -0
  128. scitex/db/_PostgreSQL.py +126 -0
  129. scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
  130. scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
  131. scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
  132. scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
  133. scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
  134. scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
  135. scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
  136. scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
  137. scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
  138. scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
  139. scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
  140. scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
  141. scitex/db/_PostgreSQLMixins/__init__.py +34 -0
  142. scitex/db/_SQLite3.py +2136 -0
  143. scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
  144. scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
  145. scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
  146. scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
  147. scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
  148. scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
  149. scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
  150. scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
  151. scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
  152. scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
  153. scitex/db/_SQLite3Mixins/__init__.py +30 -0
  154. scitex/db/__init__.py +14 -0
  155. scitex/db/_delete_duplicates.py +397 -0
  156. scitex/db/_inspect.py +163 -0
  157. scitex/decorators/__init__.py +54 -0
  158. scitex/decorators/_auto_order.py +172 -0
  159. scitex/decorators/_batch_fn.py +127 -0
  160. scitex/decorators/_cache_disk.py +32 -0
  161. scitex/decorators/_cache_mem.py +12 -0
  162. scitex/decorators/_combined.py +98 -0
  163. scitex/decorators/_converters.py +282 -0
  164. scitex/decorators/_deprecated.py +26 -0
  165. scitex/decorators/_not_implemented.py +30 -0
  166. scitex/decorators/_numpy_fn.py +86 -0
  167. scitex/decorators/_pandas_fn.py +121 -0
  168. scitex/decorators/_preserve_doc.py +19 -0
  169. scitex/decorators/_signal_fn.py +95 -0
  170. scitex/decorators/_timeout.py +55 -0
  171. scitex/decorators/_torch_fn.py +136 -0
  172. scitex/decorators/_wrap.py +39 -0
  173. scitex/decorators/_xarray_fn.py +88 -0
  174. scitex/dev/__init__.py +15 -0
  175. scitex/dev/_analyze_code_flow.py +284 -0
  176. scitex/dev/_reload.py +59 -0
  177. scitex/dict/_DotDict.py +442 -0
  178. scitex/dict/__init__.py +18 -0
  179. scitex/dict/_listed_dict.py +42 -0
  180. scitex/dict/_pop_keys.py +36 -0
  181. scitex/dict/_replace.py +13 -0
  182. scitex/dict/_safe_merge.py +62 -0
  183. scitex/dict/_to_str.py +32 -0
  184. scitex/dsp/__init__.py +72 -0
  185. scitex/dsp/_crop.py +122 -0
  186. scitex/dsp/_demo_sig.py +331 -0
  187. scitex/dsp/_detect_ripples.py +212 -0
  188. scitex/dsp/_ensure_3d.py +18 -0
  189. scitex/dsp/_hilbert.py +78 -0
  190. scitex/dsp/_listen.py +702 -0
  191. scitex/dsp/_misc.py +30 -0
  192. scitex/dsp/_mne.py +32 -0
  193. scitex/dsp/_modulation_index.py +79 -0
  194. scitex/dsp/_pac.py +319 -0
  195. scitex/dsp/_psd.py +102 -0
  196. scitex/dsp/_resample.py +65 -0
  197. scitex/dsp/_time.py +36 -0
  198. scitex/dsp/_transform.py +68 -0
  199. scitex/dsp/_wavelet.py +212 -0
  200. scitex/dsp/add_noise.py +111 -0
  201. scitex/dsp/example.py +253 -0
  202. scitex/dsp/filt.py +155 -0
  203. scitex/dsp/norm.py +18 -0
  204. scitex/dsp/params.py +51 -0
  205. scitex/dsp/reference.py +43 -0
  206. scitex/dsp/template.py +25 -0
  207. scitex/dsp/utils/__init__.py +15 -0
  208. scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
  209. scitex/dsp/utils/_ensure_3d.py +18 -0
  210. scitex/dsp/utils/_ensure_even_len.py +10 -0
  211. scitex/dsp/utils/_zero_pad.py +48 -0
  212. scitex/dsp/utils/filter.py +408 -0
  213. scitex/dsp/utils/pac.py +177 -0
  214. scitex/dt/__init__.py +8 -0
  215. scitex/dt/_linspace.py +130 -0
  216. scitex/etc/__init__.py +15 -0
  217. scitex/etc/wait_key.py +34 -0
  218. scitex/gen/_DimHandler.py +196 -0
  219. scitex/gen/_TimeStamper.py +244 -0
  220. scitex/gen/__init__.py +95 -0
  221. scitex/gen/_alternate_kwarg.py +13 -0
  222. scitex/gen/_cache.py +11 -0
  223. scitex/gen/_check_host.py +34 -0
  224. scitex/gen/_ci.py +12 -0
  225. scitex/gen/_close.py +222 -0
  226. scitex/gen/_embed.py +78 -0
  227. scitex/gen/_inspect_module.py +257 -0
  228. scitex/gen/_is_ipython.py +12 -0
  229. scitex/gen/_less.py +48 -0
  230. scitex/gen/_list_packages.py +139 -0
  231. scitex/gen/_mat2py.py +88 -0
  232. scitex/gen/_norm.py +170 -0
  233. scitex/gen/_paste.py +18 -0
  234. scitex/gen/_print_config.py +84 -0
  235. scitex/gen/_shell.py +48 -0
  236. scitex/gen/_src.py +111 -0
  237. scitex/gen/_start.py +451 -0
  238. scitex/gen/_symlink.py +55 -0
  239. scitex/gen/_symlog.py +27 -0
  240. scitex/gen/_tee.py +238 -0
  241. scitex/gen/_title2path.py +60 -0
  242. scitex/gen/_title_case.py +88 -0
  243. scitex/gen/_to_even.py +84 -0
  244. scitex/gen/_to_odd.py +34 -0
  245. scitex/gen/_to_rank.py +39 -0
  246. scitex/gen/_transpose.py +37 -0
  247. scitex/gen/_type.py +78 -0
  248. scitex/gen/_var_info.py +73 -0
  249. scitex/gen/_wrap.py +17 -0
  250. scitex/gen/_xml2dict.py +76 -0
  251. scitex/gen/misc.py +730 -0
  252. scitex/gen/path.py +0 -0
  253. scitex/general/__init__.py +5 -0
  254. scitex/gists/_SigMacro_processFigure_S.py +128 -0
  255. scitex/gists/_SigMacro_toBlue.py +172 -0
  256. scitex/gists/__init__.py +12 -0
  257. scitex/io/_H5Explorer.py +292 -0
  258. scitex/io/__init__.py +82 -0
  259. scitex/io/_cache.py +101 -0
  260. scitex/io/_flush.py +24 -0
  261. scitex/io/_glob.py +103 -0
  262. scitex/io/_json2md.py +113 -0
  263. scitex/io/_load.py +168 -0
  264. scitex/io/_load_configs.py +146 -0
  265. scitex/io/_load_modules/__init__.py +38 -0
  266. scitex/io/_load_modules/_catboost.py +66 -0
  267. scitex/io/_load_modules/_con.py +20 -0
  268. scitex/io/_load_modules/_db.py +24 -0
  269. scitex/io/_load_modules/_docx.py +42 -0
  270. scitex/io/_load_modules/_eeg.py +110 -0
  271. scitex/io/_load_modules/_hdf5.py +196 -0
  272. scitex/io/_load_modules/_image.py +19 -0
  273. scitex/io/_load_modules/_joblib.py +19 -0
  274. scitex/io/_load_modules/_json.py +18 -0
  275. scitex/io/_load_modules/_markdown.py +103 -0
  276. scitex/io/_load_modules/_matlab.py +37 -0
  277. scitex/io/_load_modules/_numpy.py +39 -0
  278. scitex/io/_load_modules/_optuna.py +155 -0
  279. scitex/io/_load_modules/_pandas.py +69 -0
  280. scitex/io/_load_modules/_pdf.py +31 -0
  281. scitex/io/_load_modules/_pickle.py +24 -0
  282. scitex/io/_load_modules/_torch.py +16 -0
  283. scitex/io/_load_modules/_txt.py +126 -0
  284. scitex/io/_load_modules/_xml.py +49 -0
  285. scitex/io/_load_modules/_yaml.py +23 -0
  286. scitex/io/_mv_to_tmp.py +19 -0
  287. scitex/io/_path.py +286 -0
  288. scitex/io/_reload.py +78 -0
  289. scitex/io/_save.py +539 -0
  290. scitex/io/_save_modules/__init__.py +66 -0
  291. scitex/io/_save_modules/_catboost.py +22 -0
  292. scitex/io/_save_modules/_csv.py +89 -0
  293. scitex/io/_save_modules/_excel.py +49 -0
  294. scitex/io/_save_modules/_hdf5.py +249 -0
  295. scitex/io/_save_modules/_html.py +48 -0
  296. scitex/io/_save_modules/_image.py +140 -0
  297. scitex/io/_save_modules/_joblib.py +25 -0
  298. scitex/io/_save_modules/_json.py +25 -0
  299. scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
  300. scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
  301. scitex/io/_save_modules/_matlab.py +24 -0
  302. scitex/io/_save_modules/_mp4.py +29 -0
  303. scitex/io/_save_modules/_numpy.py +57 -0
  304. scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
  305. scitex/io/_save_modules/_pickle.py +45 -0
  306. scitex/io/_save_modules/_plotly.py +27 -0
  307. scitex/io/_save_modules/_text.py +23 -0
  308. scitex/io/_save_modules/_torch.py +26 -0
  309. scitex/io/_save_modules/_yaml.py +29 -0
  310. scitex/life/__init__.py +10 -0
  311. scitex/life/_monitor_rain.py +49 -0
  312. scitex/linalg/__init__.py +17 -0
  313. scitex/linalg/_distance.py +63 -0
  314. scitex/linalg/_geometric_median.py +64 -0
  315. scitex/linalg/_misc.py +73 -0
  316. scitex/nn/_AxiswiseDropout.py +27 -0
  317. scitex/nn/_BNet.py +126 -0
  318. scitex/nn/_BNet_Res.py +164 -0
  319. scitex/nn/_ChannelGainChanger.py +44 -0
  320. scitex/nn/_DropoutChannels.py +50 -0
  321. scitex/nn/_Filters.py +489 -0
  322. scitex/nn/_FreqGainChanger.py +110 -0
  323. scitex/nn/_GaussianFilter.py +48 -0
  324. scitex/nn/_Hilbert.py +111 -0
  325. scitex/nn/_MNet_1000.py +157 -0
  326. scitex/nn/_ModulationIndex.py +221 -0
  327. scitex/nn/_PAC.py +414 -0
  328. scitex/nn/_PSD.py +40 -0
  329. scitex/nn/_ResNet1D.py +120 -0
  330. scitex/nn/_SpatialAttention.py +25 -0
  331. scitex/nn/_Spectrogram.py +161 -0
  332. scitex/nn/_SwapChannels.py +50 -0
  333. scitex/nn/_TransposeLayer.py +19 -0
  334. scitex/nn/_Wavelet.py +183 -0
  335. scitex/nn/__init__.py +63 -0
  336. scitex/os/__init__.py +8 -0
  337. scitex/os/_mv.py +50 -0
  338. scitex/parallel/__init__.py +8 -0
  339. scitex/parallel/_run.py +151 -0
  340. scitex/path/__init__.py +33 -0
  341. scitex/path/_clean.py +52 -0
  342. scitex/path/_find.py +108 -0
  343. scitex/path/_get_module_path.py +51 -0
  344. scitex/path/_get_spath.py +35 -0
  345. scitex/path/_getsize.py +18 -0
  346. scitex/path/_increment_version.py +87 -0
  347. scitex/path/_mk_spath.py +51 -0
  348. scitex/path/_path.py +19 -0
  349. scitex/path/_split.py +23 -0
  350. scitex/path/_this_path.py +19 -0
  351. scitex/path/_version.py +101 -0
  352. scitex/pd/__init__.py +41 -0
  353. scitex/pd/_find_indi.py +126 -0
  354. scitex/pd/_find_pval.py +113 -0
  355. scitex/pd/_force_df.py +154 -0
  356. scitex/pd/_from_xyz.py +71 -0
  357. scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
  358. scitex/pd/_melt_cols.py +81 -0
  359. scitex/pd/_merge_columns.py +221 -0
  360. scitex/pd/_mv.py +63 -0
  361. scitex/pd/_replace.py +62 -0
  362. scitex/pd/_round.py +93 -0
  363. scitex/pd/_slice.py +63 -0
  364. scitex/pd/_sort.py +91 -0
  365. scitex/pd/_to_numeric.py +53 -0
  366. scitex/pd/_to_xy.py +59 -0
  367. scitex/pd/_to_xyz.py +110 -0
  368. scitex/plt/__init__.py +36 -0
  369. scitex/plt/_subplots/_AxesWrapper.py +182 -0
  370. scitex/plt/_subplots/_AxisWrapper.py +249 -0
  371. scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
  372. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
  373. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
  374. scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
  375. scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
  376. scitex/plt/_subplots/_FigWrapper.py +226 -0
  377. scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
  378. scitex/plt/_subplots/__init__.py +111 -0
  379. scitex/plt/_subplots/_export_as_csv.py +232 -0
  380. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
  381. scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
  382. scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
  383. scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
  384. scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
  385. scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
  386. scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
  387. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
  388. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
  389. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
  390. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
  391. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
  392. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
  393. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
  394. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
  395. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
  396. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
  397. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
  398. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
  399. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
  400. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
  401. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
  402. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
  403. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
  404. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
  405. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
  406. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
  407. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
  408. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
  409. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
  410. scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
  411. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
  412. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
  413. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
  414. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
  415. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
  416. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
  417. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
  418. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
  419. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
  420. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
  421. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
  422. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
  423. scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
  424. scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
  425. scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
  426. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
  427. scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
  428. scitex/plt/_tpl.py +28 -0
  429. scitex/plt/ax/__init__.py +114 -0
  430. scitex/plt/ax/_plot/__init__.py +53 -0
  431. scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
  432. scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
  433. scitex/plt/ax/_plot/_plot_cube.py +57 -0
  434. scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
  435. scitex/plt/ax/_plot/_plot_fillv.py +55 -0
  436. scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
  437. scitex/plt/ax/_plot/_plot_image.py +94 -0
  438. scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
  439. scitex/plt/ax/_plot/_plot_raster.py +172 -0
  440. scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
  441. scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
  442. scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
  443. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
  444. scitex/plt/ax/_plot/_plot_violin.py +343 -0
  445. scitex/plt/ax/_style/__init__.py +38 -0
  446. scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
  447. scitex/plt/ax/_style/_add_panel.py +92 -0
  448. scitex/plt/ax/_style/_extend.py +64 -0
  449. scitex/plt/ax/_style/_force_aspect.py +37 -0
  450. scitex/plt/ax/_style/_format_label.py +23 -0
  451. scitex/plt/ax/_style/_hide_spines.py +84 -0
  452. scitex/plt/ax/_style/_map_ticks.py +182 -0
  453. scitex/plt/ax/_style/_rotate_labels.py +215 -0
  454. scitex/plt/ax/_style/_sci_note.py +279 -0
  455. scitex/plt/ax/_style/_set_log_scale.py +299 -0
  456. scitex/plt/ax/_style/_set_meta.py +261 -0
  457. scitex/plt/ax/_style/_set_n_ticks.py +37 -0
  458. scitex/plt/ax/_style/_set_size.py +16 -0
  459. scitex/plt/ax/_style/_set_supxyt.py +116 -0
  460. scitex/plt/ax/_style/_set_ticks.py +276 -0
  461. scitex/plt/ax/_style/_set_xyt.py +121 -0
  462. scitex/plt/ax/_style/_share_axes.py +264 -0
  463. scitex/plt/ax/_style/_shift.py +139 -0
  464. scitex/plt/ax/_style/_show_spines.py +333 -0
  465. scitex/plt/color/_PARAMS.py +70 -0
  466. scitex/plt/color/__init__.py +52 -0
  467. scitex/plt/color/_add_hue_col.py +41 -0
  468. scitex/plt/color/_colors.py +205 -0
  469. scitex/plt/color/_get_colors_from_cmap.py +134 -0
  470. scitex/plt/color/_interpolate.py +29 -0
  471. scitex/plt/color/_vizualize_colors.py +54 -0
  472. scitex/plt/utils/__init__.py +44 -0
  473. scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
  474. scitex/plt/utils/_calc_nice_ticks.py +101 -0
  475. scitex/plt/utils/_close.py +68 -0
  476. scitex/plt/utils/_colorbar.py +96 -0
  477. scitex/plt/utils/_configure_mpl.py +295 -0
  478. scitex/plt/utils/_histogram_utils.py +132 -0
  479. scitex/plt/utils/_im2grid.py +70 -0
  480. scitex/plt/utils/_is_valid_axis.py +78 -0
  481. scitex/plt/utils/_mk_colorbar.py +65 -0
  482. scitex/plt/utils/_mk_patches.py +26 -0
  483. scitex/plt/utils/_scientific_captions.py +638 -0
  484. scitex/plt/utils/_scitex_config.py +223 -0
  485. scitex/reproduce/__init__.py +14 -0
  486. scitex/reproduce/_fix_seeds.py +45 -0
  487. scitex/reproduce/_gen_ID.py +55 -0
  488. scitex/reproduce/_gen_timestamp.py +35 -0
  489. scitex/res/__init__.py +5 -0
  490. scitex/resource/__init__.py +13 -0
  491. scitex/resource/_get_processor_usages.py +281 -0
  492. scitex/resource/_get_specs.py +280 -0
  493. scitex/resource/_log_processor_usages.py +190 -0
  494. scitex/resource/_utils/__init__.py +31 -0
  495. scitex/resource/_utils/_get_env_info.py +481 -0
  496. scitex/resource/limit_ram.py +33 -0
  497. scitex/scholar/__init__.py +24 -0
  498. scitex/scholar/_local_search.py +454 -0
  499. scitex/scholar/_paper.py +244 -0
  500. scitex/scholar/_pdf_downloader.py +325 -0
  501. scitex/scholar/_search.py +393 -0
  502. scitex/scholar/_vector_search.py +370 -0
  503. scitex/scholar/_web_sources.py +457 -0
  504. scitex/stats/__init__.py +31 -0
  505. scitex/stats/_calc_partial_corr.py +17 -0
  506. scitex/stats/_corr_test_multi.py +94 -0
  507. scitex/stats/_corr_test_wrapper.py +115 -0
  508. scitex/stats/_describe_wrapper.py +90 -0
  509. scitex/stats/_multiple_corrections.py +63 -0
  510. scitex/stats/_nan_stats.py +93 -0
  511. scitex/stats/_p2stars.py +116 -0
  512. scitex/stats/_p2stars_wrapper.py +56 -0
  513. scitex/stats/_statistical_tests.py +73 -0
  514. scitex/stats/desc/__init__.py +40 -0
  515. scitex/stats/desc/_describe.py +189 -0
  516. scitex/stats/desc/_nan.py +289 -0
  517. scitex/stats/desc/_real.py +94 -0
  518. scitex/stats/multiple/__init__.py +14 -0
  519. scitex/stats/multiple/_bonferroni_correction.py +72 -0
  520. scitex/stats/multiple/_fdr_correction.py +400 -0
  521. scitex/stats/multiple/_multicompair.py +28 -0
  522. scitex/stats/tests/__corr_test.py +277 -0
  523. scitex/stats/tests/__corr_test_multi.py +343 -0
  524. scitex/stats/tests/__corr_test_single.py +277 -0
  525. scitex/stats/tests/__init__.py +22 -0
  526. scitex/stats/tests/_brunner_munzel_test.py +192 -0
  527. scitex/stats/tests/_nocorrelation_test.py +28 -0
  528. scitex/stats/tests/_smirnov_grubbs.py +98 -0
  529. scitex/str/__init__.py +113 -0
  530. scitex/str/_clean_path.py +75 -0
  531. scitex/str/_color_text.py +52 -0
  532. scitex/str/_decapitalize.py +58 -0
  533. scitex/str/_factor_out_digits.py +281 -0
  534. scitex/str/_format_plot_text.py +498 -0
  535. scitex/str/_grep.py +48 -0
  536. scitex/str/_latex.py +155 -0
  537. scitex/str/_latex_fallback.py +471 -0
  538. scitex/str/_mask_api.py +39 -0
  539. scitex/str/_mask_api_key.py +8 -0
  540. scitex/str/_parse.py +158 -0
  541. scitex/str/_print_block.py +47 -0
  542. scitex/str/_print_debug.py +68 -0
  543. scitex/str/_printc.py +62 -0
  544. scitex/str/_readable_bytes.py +38 -0
  545. scitex/str/_remove_ansi.py +23 -0
  546. scitex/str/_replace.py +134 -0
  547. scitex/str/_search.py +125 -0
  548. scitex/str/_squeeze_space.py +36 -0
  549. scitex/tex/__init__.py +10 -0
  550. scitex/tex/_preview.py +103 -0
  551. scitex/tex/_to_vec.py +116 -0
  552. scitex/torch/__init__.py +18 -0
  553. scitex/torch/_apply_to.py +34 -0
  554. scitex/torch/_nan_funcs.py +77 -0
  555. scitex/types/_ArrayLike.py +44 -0
  556. scitex/types/_ColorLike.py +21 -0
  557. scitex/types/__init__.py +14 -0
  558. scitex/types/_is_listed_X.py +70 -0
  559. scitex/utils/__init__.py +22 -0
  560. scitex/utils/_compress_hdf5.py +116 -0
  561. scitex/utils/_email.py +120 -0
  562. scitex/utils/_grid.py +148 -0
  563. scitex/utils/_notify.py +247 -0
  564. scitex/utils/_search.py +121 -0
  565. scitex/web/__init__.py +38 -0
  566. scitex/web/_search_pubmed.py +438 -0
  567. scitex/web/_summarize_url.py +158 -0
  568. scitex-2.0.0.dist-info/METADATA +307 -0
  569. scitex-2.0.0.dist-info/RECORD +572 -0
  570. scitex-2.0.0.dist-info/WHEEL +6 -0
  571. scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
  572. scitex-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,583 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-11-20 08:49:50 (ywatanabe)"
4
+ # File: ./scitex_repo/src/scitex/ai/_LearningCurveLogger.py
5
+
6
"""
Functionality:
    - Records and visualizes learning curves during model training
    - Supports tracking of multiple metrics across training/validation/test phases
    - Generates plots showing training progress over iterations and epochs

Input:
    - Training metrics dictionary containing loss, accuracy, predictions, etc.
    - Step information (Training/Validation/Test)

Output:
    - Learning curve plots
    - DataFrames with recorded metrics
    - Training progress prints

Prerequisites:
    - PyTorch
    - scikit-learn
    - matplotlib
    - pandas
    - numpy
"""

# The descriptive string above must be the first statement of the module to be
# picked up as its docstring (``__doc__``); it therefore precedes this constant.
THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/_LearningCurveLogger.py"
30
+
31
+ import re as _re
32
+ from collections import defaultdict as _defaultdict
33
+ from pprint import pprint as _pprint
34
+ from typing import Dict as _Dict
35
+ from typing import List as _List
36
+ from typing import Union as _Union
37
+ from typing import Optional as _Optional
38
+ from typing import Any as _Any
39
+
40
+ import matplotlib as _matplotlib
41
+ import pandas as _pd
42
+ import numpy as _np
43
+ import warnings as _warnings
44
+ import torch as _torch
45
+
46
+ # Import scitex modules for plotting
47
+ try:
48
+ import scitex.plt.utils as _plt_utils
49
+ import scitex.plt.color as _plt_color
50
+
51
+ class _plt_module:
52
+ configure_mpl = _plt_utils.configure_mpl
53
+ colors = _plt_color
54
+ except ImportError:
55
+ # Mock for testing when scitex is not available
56
+ class _plt_module:
57
+ @staticmethod
58
+ def configure_mpl(*args, **kwargs):
59
+ pass
60
+
61
+ class colors:
62
+ @staticmethod
63
+ def to_RGBA(color, alpha=1.0):
64
+ return color
65
+
66
+ @staticmethod
67
+ def to_rgba(color, alpha=1.0):
68
+ return color
69
+
70
+
71
class LearningCurveLogger:
    """Records and visualizes learning metrics during model training.

    Every call appends the values of a metrics dictionary to per-step lists
    (``self.logged_dict[step][metric]``). Keys ending in ``"_plot"`` are the
    metrics drawn by :meth:`plot_learning_curves`; ``"i_epoch"`` and
    ``"i_global"`` are used as pivot columns when aggregating.

    Example
    -------
    >>> logger = LearningCurveLogger()
    >>> metrics = {
    ...     "loss_plot": 0.5,
    ...     "balanced_ACC_plot": 0.8,
    ...     "pred_proba": pred_proba,
    ...     "true_class": labels,
    ...     "i_fold": 0,
    ...     "i_epoch": 1,
    ...     "i_global": 100,
    ... }
    >>> logger(metrics, "Training")
    >>> fig = logger.plot_learning_curves(plt)
    """

    def __init__(self) -> None:
        # step name -> {metric name -> list of values, one entry per __call__}
        self.logged_dict: _Dict[str, _Dict[str, _List[_Any]]] = _defaultdict(dict)

    def __call__(self, dict_to_log: _Dict[str, _Any], step: str) -> None:
        """Log the metrics of one iteration for a training phase.

        Parameters
        ----------
        dict_to_log : Dict[str, Any]
            Metrics to log. The deprecated key ``"gt_label"`` is renamed to
            ``"true_class"`` (with a DeprecationWarning); the caller's
            dictionary itself is not modified.
        step : str
            Phase of training ('Training', 'Validation', or 'Test').
        """
        # Shallow copy so the rename below does not mutate the caller's dict.
        dict_to_log = dict(dict_to_log)
        if "gt_label" in dict_to_log:
            # Warn only when the deprecated key is actually used.
            _warnings.warn(
                '\n"gt_label" will be removed in the future. Please use "true_class" instead.\n',
                DeprecationWarning,
                stacklevel=2,
            )
            dict_to_log["true_class"] = dict_to_log.pop("gt_label")

        step_log = self.logged_dict[step]
        for key, value in dict_to_log.items():
            step_log.setdefault(key, []).append(value)

    @property
    def dfs(self) -> _Dict[str, _pd.DataFrame]:
        """Returns DataFrames of logged metrics (one row per logged call).

        Returns
        -------
        Dict[str, pd.DataFrame]
            Dictionary of DataFrames for each step
        """
        return self._to_dfs_pivot(
            self.logged_dict,
            pivot_column=None,
        )

    def get_x_of_i_epoch(self, x: str, step: str, i_epoch: int) -> _np.ndarray:
        """Gets the logged values of a metric for a specific epoch.

        Parameters
        ----------
        x : str
            Name of metric to retrieve
        step : str
            Training phase
        i_epoch : int
            Epoch number

        Returns
        -------
        np.ndarray
            Array of metric values logged during the specified epoch
        """
        indi = _np.array(self.logged_dict[step]["i_epoch"]) == i_epoch
        x_all_arr = _np.array(self.logged_dict[step][x])
        # The metric and "i_epoch" must have been logged on the same calls.
        assert len(indi) == len(x_all_arr)
        return x_all_arr[indi]

    def plot_learning_curves(
        self,
        plt: _Any,
        plt_config_dict: _Optional[_Dict] = None,
        title: _Optional[str] = None,
        max_n_ticks: int = 4,
        linewidth: float = 1,
        scattersize: float = 50,
    ) -> "_matplotlib.figure.Figure":
        """Plots learning curves from logged metrics.

        One subplot is drawn per ``*_plot`` key (taken from the first logged
        step), with values averaged per ``i_global``. Steps whose name starts
        with "Train"/"train" are drawn as lines with dashed vertical markers
        at the first iteration of each epoch; "Validation" and "Test" steps
        are drawn as scatter points.

        Parameters
        ----------
        plt : matplotlib.pyplot
            Matplotlib pyplot object (or a compatible wrapper)
        plt_config_dict : Dict, optional
            Plot configuration parameters passed to ``configure_mpl``
        title : str, optional
            Plot title
        max_n_ticks : int
            Maximum number of ticks on axes
        linewidth : float
            Width of plot lines
        scattersize : float
            Size of scatter points

        Returns
        -------
        matplotlib.figure.Figure
            Figure containing learning curves
        """
        # Import the submodule explicitly instead of relying on it having been
        # loaded as a side effect of other imports.
        from matplotlib.ticker import MaxNLocator

        if plt_config_dict is not None:
            _plt_module.configure_mpl(plt, **plt_config_dict)

        self.dfs_pivot_i_global = self._to_dfs_pivot(
            self.logged_dict, pivot_column="i_global"
        )

        COLOR_DICT = {
            "Training": "blue",
            "Validation": "green",
            "Test": "red",
        }

        keys_to_plot = self._find_keys_to_plot(self.logged_dict)

        if len(keys_to_plot) == 0:
            # Create empty plot when no plot keys found
            fig, axes = plt.subplots(1, 1)
            axes.set_xlabel("Iteration#")
            axes.set_ylabel("No metrics to plot")
            fig.text(0.5, 0.95, title, ha="center")
            return fig

        fig, axes = plt.subplots(len(keys_to_plot), 1, sharex=True, sharey=False)

        # Handle both single and multiple axes cases
        if len(keys_to_plot) == 1:
            axes = [axes]  # Make it a list for consistent indexing

        axes[-1].set_xlabel("Iteration#")
        fig.text(0.5, 0.95, title, ha="center")

        for ax, plt_k in zip(axes, keys_to_plot):
            ax.set_ylabel(self._rename_if_key_to_plot(plt_k))
            ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks))
            ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks))

            # Accuracy-like metrics live in [0, 1].
            if _re.search("[aA][cC][cC]", plt_k):
                ax.set_ylim(0, 1)
                ax.set_yticks([0, 0.5, 1.0])

            has_labeled_artist = False
            for step_k, df in self.dfs_pivot_i_global.items():
                # A metric may have been logged for only some of the steps.
                if plt_k not in df.columns:
                    continue
                color = _plt_module.colors.to_rgba(
                    COLOR_DICT.get(step_k, "gray"), alpha=0.9
                )

                if _re.search("^[Tt]rain", step_k):
                    ax.plot(
                        df.index,
                        df[plt_k],
                        label=step_k,
                        color=color,
                        linewidth=linewidth,
                    )
                    has_labeled_artist = True

                    if "i_epoch" in df.columns:
                        for i_global_epoch_start in self._epoch_start_indices(df):
                            ax.axvline(
                                x=i_global_epoch_start,
                                linestyle="--",
                                color=_plt_module.colors.to_rgba("gray", alpha=0.5),
                            )

                elif step_k in ("Validation", "Test"):
                    ax.scatter(
                        df.index,
                        df[plt_k],
                        label=step_k,
                        color=color,
                        s=scattersize,
                        alpha=0.9,
                    )
                    has_labeled_artist = True

            if has_labeled_artist:
                ax.legend()

        return fig

    def print(self, step: str) -> None:
        """Prints per-epoch means of the metrics logged for a step.

        Parameters
        ----------
        step : str
            Training phase to print metrics for
        """
        df_pivot_i_epoch = self._to_dfs_pivot(self.logged_dict, pivot_column="i_epoch")
        df_pivot_i_epoch_step = df_pivot_i_epoch[step]
        df_pivot_i_epoch_step.columns = self._rename_if_key_to_plot(
            df_pivot_i_epoch_step.columns
        )
        print("\n----------------------------------------\n")
        print(f"\n{step}: (mean of batches)\n")
        _pprint(df_pivot_i_epoch_step)
        print("\n----------------------------------------\n")

    @staticmethod
    def _epoch_start_indices(df: _pd.DataFrame) -> _List[_Any]:
        """Return the index labels (i_global) at which each epoch begins.

        ``df`` must contain an ``"i_epoch"`` column. The first row always
        starts an epoch; afterwards, any row whose ``i_epoch`` differs from
        the previous row's starts a new one.
        """
        if df.empty:
            return []
        changed = (df["i_epoch"].diff().fillna(0) != 0).to_numpy()
        return [df.index[0]] + list(df.index[changed])

    @staticmethod
    def _find_keys_to_plot(logged_dict: _Dict) -> _List[str]:
        """Find metrics to plot from logged dictionary.

        Parameters
        ----------
        logged_dict : Dict
            Dictionary of logged metrics, keyed by step

        Returns
        -------
        List[str]
            Names ending in "_plot" logged for the first step (empty list when
            nothing has been logged yet)
        """
        if not logged_dict:
            return []
        first_step = next(iter(logged_dict))
        return [
            k
            for k in logged_dict[first_step]
            if isinstance(k, str) and k.endswith("_plot")
        ]

    @staticmethod
    def _rename_if_key_to_plot(x: _Union[str, _pd.Index]) -> _Union[str, _pd.Index]:
        """Strip a trailing "_plot" suffix from metric name(s).

        Parameters
        ----------
        x : str or pd.Index
            Metric name(s) to rename

        Returns
        -------
        str or pd.Index
            Renamed metric name(s); names without the suffix are unchanged
        """
        if isinstance(x, str):
            return _re.sub(r"_plot$", "", x)
        return x.str.replace(r"_plot$", "", regex=True)

    @staticmethod
    def _to_dfs_pivot(
        logged_dict: _Dict[str, _Dict],
        pivot_column: _Optional[str] = None,
    ) -> _Dict[str, _pd.DataFrame]:
        """Convert logged dictionary to (optionally pivoted) DataFrames.

        Parameters
        ----------
        logged_dict : Dict[str, Dict]
            Dictionary of logged metrics
        pivot_column : str, optional
            Column to group by; numeric columns are averaged per group and
            the group key becomes the index. When None, no aggregation.

        Returns
        -------
        Dict[str, pd.DataFrame]
            Dictionary of DataFrames, one per step
        """
        dfs_pivot = {}
        for step_k, step_log in logged_dict.items():
            df = _pd.DataFrame(step_log)
            if pivot_column is not None:
                # numeric_only=True drops per-batch array columns (e.g.
                # pred_proba, true_class), which cannot be averaged and make
                # GroupBy.mean raise TypeError on pandas >= 2.0.
                df = df.groupby(pivot_column).mean(numeric_only=True)
            dfs_pivot[step_k] = df
        return dfs_pivot
367
+
368
+
369
if __name__ == "__main__":
    import sys
    import warnings

    import matplotlib.pyplot as plt
    import scitex
    import torch
    import torch.nn as nn
    from sklearn.metrics import balanced_accuracy_score
    from torch.utils.data import DataLoader, TensorDataset
    from torchvision import datasets

    ################################################################################
    ## Sets tee
    ################################################################################
    # `scitex` is imported explicitly above; this module's imports do not bind it.
    sdir = scitex.io.path.mk_spath("")  # "/tmp/sdir/"
    sys.stdout, sys.stderr = scitex.gen.tee(sys, sdir)

    ################################################################################
    ## NN
    ################################################################################
    class Perceptron(nn.Module):
        """Two-layer linear classifier for 28x28 MNIST images."""

        def __init__(self):
            super().__init__()
            self.l1 = nn.Linear(28 * 28, 50)
            self.l2 = nn.Linear(50, 10)

        def forward(self, x):
            x = x.view(-1, 28 * 28)
            x = self.l1(x)
            x = self.l2(x)
            return x

    ################################################################################
    ## Prepares demo data
    ################################################################################
    ## Downloads
    _ds_tra_val = datasets.MNIST("/tmp/mnist", train=True, download=True)
    _ds_tes = datasets.MNIST("/tmp/mnist", train=False, download=True)

    ## Training-validation split: first 80% / last 20%
    n_samples = len(_ds_tra_val)  # n_samples is 60000
    train_size = int(n_samples * 0.8)  # train_size is 48000

    # Slice the underlying tensors directly. (Reading `.dataset.data` from a
    # torch Subset returns the *whole* training set, which made the training and
    # validation splits identical.) Pixels are scaled from [0, 255] to [0, 1].
    X_tra_val = _ds_tra_val.data.to(torch.float32) / 255.0
    T_tra_val = _ds_tra_val.targets
    ds_tra = TensorDataset(X_tra_val[:train_size], T_tra_val[:train_size])
    ds_val = TensorDataset(X_tra_val[train_size:], T_tra_val[train_size:])
    ds_tes = TensorDataset(_ds_tes.data.to(torch.float32) / 255.0, _ds_tes.targets)

    ## to dataloaders
    batch_size = 64
    dl_tra = DataLoader(ds_tra, batch_size=batch_size, shuffle=True, drop_last=True)
    dl_val = DataLoader(ds_val, batch_size=batch_size, shuffle=False, drop_last=True)
    dl_tes = DataLoader(ds_tes, batch_size=batch_size, shuffle=False, drop_last=True)

    ################################################################################
    ## Preparation
    ################################################################################
    model = Perceptron()
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    softmax = nn.Softmax(dim=-1)

    ################################################################################
    ## Main
    ################################################################################
    lc_logger = LearningCurveLogger()
    i_fold = 0
    max_epochs = 3

    def run_step(dataloader, step, i_epoch, i_global, train):
        """Run one pass over `dataloader`, logging per-batch metrics under `step`.

        Gradients are computed and the optimizer stepped only when `train` is
        True. Returns the global iteration counter, which is advanced once per
        batch only while training.
        """
        model.train(train)
        for X, T in dataloader:
            with torch.set_grad_enabled(train):
                logits = model(X)
                loss = loss_func(logits, T)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pred_proba = softmax(logits.detach())
            pred_class = pred_proba.argmax(dim=-1)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                bACC = balanced_accuracy_score(T, pred_class)

            lc_logger(
                {
                    "loss_plot": float(loss),
                    "balanced_ACC_plot": float(bACC),
                    "pred_proba": pred_proba.cpu().numpy(),
                    "true_class": T.cpu().numpy(),
                    "i_fold": i_fold,
                    "i_epoch": i_epoch,
                    "i_global": i_global,
                },
                step,
            )

            if train:
                i_global += 1

        lc_logger.print(step)
        return i_global

    i_global = 0
    for i_epoch in range(max_epochs):
        i_global = run_step(dl_val, "Validation", i_epoch, i_global, train=False)
        i_global = run_step(dl_tra, "Training", i_epoch, i_global, train=True)

    run_step(dl_tes, "Test", max_epochs - 1, i_global, train=False)

    plt_config_dict = dict(
        # figsize=(8.7, 10),
        figscale=2.5,
        labelsize=16,
        fontsize=12,
        legendfontsize=12,
        tick_size=0.8,
        tick_width=0.2,
    )

    fig = lc_logger.plot_learning_curves(
        plt,
        plt_config_dict=plt_config_dict,
        title=f"fold#{i_fold}",
        linewidth=1,
        scattersize=50,
    )
    # plt.show() blocks until the window is closed; fig.show() would return
    # immediately and the script would exit.
    plt.show()
    # scitex.gen.save(fig, sdir + f"fold#{i_fold}.png")
583
+ # EOF
@@ -0,0 +1,101 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2021-12-12 14:50:37 (ywatanabe)"
4
+
5
+
6
+ from catboost import CatBoostClassifier
7
+ from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
8
+ from sklearn.ensemble import AdaBoostClassifier
9
+ from sklearn.gaussian_process import GaussianProcessClassifier
10
+ from sklearn.linear_model import (
11
+ LogisticRegression,
12
+ PassiveAggressiveClassifier,
13
+ Perceptron,
14
+ RidgeClassifier,
15
+ SGDClassifier,
16
+ )
17
+ from sklearn.neighbors import KNeighborsClassifier
18
+ from sklearn.pipeline import make_pipeline
19
+ from sklearn.preprocessing import StandardScaler
20
+ from sklearn.svm import SVC, LinearSVC
21
+
22
+
23
class Classifiers(object):
    """Instantiates one of several scikit-learn-like classifiers in the same manner.

    Example:
        clf_server = Classifiers(class_weight={0: 1.0, 1: 2.0}, random_state=42)
        clf_str = "SVC"
        clf = clf_server(clf_str, scaler=StandardScaler())

    Note:
        clf_str must be one of the names below (also available as
        ``Classifiers().list``).

        ['CatBoostClassifier',
         'Perceptron',
         'PassiveAggressiveClassifier',
         'LogisticRegression',
         'SGDClassifier',
         'RidgeClassifier',
         'QuadraticDiscriminantAnalysis',
         'GaussianProcessClassifier',
         'KNeighborsClassifier',
         'AdaBoostClassifier',
         'LinearSVC',
         'SVC']
    """

    def __init__(self, class_weight=None, random_state=42):
        self.class_weight = class_weight
        self.random_state = random_state

        # Unfitted template estimators; __call__ hands out copies of these.
        self.clf_candi = {
            "CatBoostClassifier": CatBoostClassifier(
                class_weights=self.class_weight, verbose=False
            ),
            "Perceptron": Perceptron(
                penalty="l2", class_weight=self.class_weight, random_state=random_state
            ),
            "PassiveAggressiveClassifier": PassiveAggressiveClassifier(
                class_weight=self.class_weight, random_state=random_state
            ),
            "LogisticRegression": LogisticRegression(
                class_weight=self.class_weight, random_state=random_state
            ),
            "SGDClassifier": SGDClassifier(
                class_weight=self.class_weight, random_state=random_state
            ),
            "RidgeClassifier": RidgeClassifier(
                class_weight=self.class_weight, random_state=random_state
            ),
            "QuadraticDiscriminantAnalysis": QuadraticDiscriminantAnalysis(),
            "GaussianProcessClassifier": GaussianProcessClassifier(
                random_state=random_state
            ),
            "KNeighborsClassifier": KNeighborsClassifier(),
            "AdaBoostClassifier": AdaBoostClassifier(random_state=random_state),
            "LinearSVC": LinearSVC(
                class_weight=self.class_weight, random_state=random_state
            ),
            "SVC": SVC(class_weight=self.class_weight, random_state=random_state),
        }

    def __call__(self, clf_str, scaler=None):
        """Return a fresh, unfitted classifier, optionally preceded by ``scaler``.

        Parameters
        ----------
        clf_str : str
            One of the names in ``self.list``.
        scaler : estimator, optional
            Transformer placed before the classifier via ``make_pipeline``.

        Raises
        ------
        KeyError
            If ``clf_str`` is not a known classifier name.
        """
        import copy

        try:
            template = self.clf_candi[clf_str]
        except KeyError:
            raise KeyError(
                f"Unknown classifier {clf_str!r}; choose one of {self.list}"
            ) from None

        # Copy the template so repeated calls never share (and silently refit)
        # a single estimator instance.
        clf = copy.deepcopy(template)
        if scaler is not None:
            clf = make_pipeline(scaler, clf)
        return clf

    @property
    def list(
        self,
    ):
        """Names of the available classifiers."""
        clf_list = list(self.clf_candi.keys())
        return clf_list
96
+
97
+
98
if __name__ == "__main__":
    # The class defined in this module is `Classifiers`; `ClassifierServer`
    # is not defined here and raised NameError.
    clf_server = Classifiers()
    # l = clf_server.list
    clf = clf_server("SVC", scaler=StandardScaler())