scitex 2.0.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (572) hide show
  1. scitex/__init__.py +73 -0
  2. scitex/__main__.py +89 -0
  3. scitex/__version__.py +14 -0
  4. scitex/_sh.py +59 -0
  5. scitex/ai/_LearningCurveLogger.py +583 -0
  6. scitex/ai/__Classifiers.py +101 -0
  7. scitex/ai/__init__.py +55 -0
  8. scitex/ai/_gen_ai/_Anthropic.py +173 -0
  9. scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
  10. scitex/ai/_gen_ai/_DeepSeek.py +175 -0
  11. scitex/ai/_gen_ai/_Google.py +161 -0
  12. scitex/ai/_gen_ai/_Groq.py +97 -0
  13. scitex/ai/_gen_ai/_Llama.py +142 -0
  14. scitex/ai/_gen_ai/_OpenAI.py +230 -0
  15. scitex/ai/_gen_ai/_PARAMS.py +565 -0
  16. scitex/ai/_gen_ai/_Perplexity.py +191 -0
  17. scitex/ai/_gen_ai/__init__.py +32 -0
  18. scitex/ai/_gen_ai/_calc_cost.py +78 -0
  19. scitex/ai/_gen_ai/_format_output_func.py +183 -0
  20. scitex/ai/_gen_ai/_genai_factory.py +71 -0
  21. scitex/ai/act/__init__.py +8 -0
  22. scitex/ai/act/_define.py +11 -0
  23. scitex/ai/classification/__init__.py +7 -0
  24. scitex/ai/classification/classification_reporter.py +1137 -0
  25. scitex/ai/classification/classifier_server.py +131 -0
  26. scitex/ai/classification/classifiers.py +101 -0
  27. scitex/ai/classification_reporter.py +1161 -0
  28. scitex/ai/classifier_server.py +131 -0
  29. scitex/ai/clustering/__init__.py +11 -0
  30. scitex/ai/clustering/_pca.py +115 -0
  31. scitex/ai/clustering/_umap.py +376 -0
  32. scitex/ai/early_stopping.py +149 -0
  33. scitex/ai/feature_extraction/__init__.py +56 -0
  34. scitex/ai/feature_extraction/vit.py +148 -0
  35. scitex/ai/genai/__init__.py +277 -0
  36. scitex/ai/genai/anthropic.py +177 -0
  37. scitex/ai/genai/anthropic_provider.py +320 -0
  38. scitex/ai/genai/anthropic_refactored.py +109 -0
  39. scitex/ai/genai/auth_manager.py +200 -0
  40. scitex/ai/genai/base_genai.py +336 -0
  41. scitex/ai/genai/base_provider.py +291 -0
  42. scitex/ai/genai/calc_cost.py +78 -0
  43. scitex/ai/genai/chat_history.py +307 -0
  44. scitex/ai/genai/cost_tracker.py +276 -0
  45. scitex/ai/genai/deepseek.py +188 -0
  46. scitex/ai/genai/deepseek_provider.py +251 -0
  47. scitex/ai/genai/format_output_func.py +183 -0
  48. scitex/ai/genai/genai_factory.py +71 -0
  49. scitex/ai/genai/google.py +169 -0
  50. scitex/ai/genai/google_provider.py +228 -0
  51. scitex/ai/genai/groq.py +104 -0
  52. scitex/ai/genai/groq_provider.py +248 -0
  53. scitex/ai/genai/image_processor.py +250 -0
  54. scitex/ai/genai/llama.py +155 -0
  55. scitex/ai/genai/llama_provider.py +214 -0
  56. scitex/ai/genai/mock_provider.py +127 -0
  57. scitex/ai/genai/model_registry.py +304 -0
  58. scitex/ai/genai/openai.py +230 -0
  59. scitex/ai/genai/openai_provider.py +293 -0
  60. scitex/ai/genai/params.py +565 -0
  61. scitex/ai/genai/perplexity.py +202 -0
  62. scitex/ai/genai/perplexity_provider.py +205 -0
  63. scitex/ai/genai/provider_base.py +302 -0
  64. scitex/ai/genai/provider_factory.py +370 -0
  65. scitex/ai/genai/response_handler.py +235 -0
  66. scitex/ai/layer/_Pass.py +21 -0
  67. scitex/ai/layer/__init__.py +10 -0
  68. scitex/ai/layer/_switch.py +8 -0
  69. scitex/ai/loss/_L1L2Losses.py +34 -0
  70. scitex/ai/loss/__init__.py +12 -0
  71. scitex/ai/loss/multi_task_loss.py +47 -0
  72. scitex/ai/metrics/__init__.py +9 -0
  73. scitex/ai/metrics/_bACC.py +51 -0
  74. scitex/ai/metrics/silhoute_score_block.py +496 -0
  75. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
  76. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
  77. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
  78. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
  79. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
  80. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
  81. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
  82. scitex/ai/optim/__init__.py +13 -0
  83. scitex/ai/optim/_get_set.py +31 -0
  84. scitex/ai/optim/_optimizers.py +71 -0
  85. scitex/ai/plt/__init__.py +21 -0
  86. scitex/ai/plt/_conf_mat.py +592 -0
  87. scitex/ai/plt/_learning_curve.py +194 -0
  88. scitex/ai/plt/_optuna_study.py +111 -0
  89. scitex/ai/plt/aucs/__init__.py +2 -0
  90. scitex/ai/plt/aucs/example.py +60 -0
  91. scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
  92. scitex/ai/plt/aucs/roc_auc.py +246 -0
  93. scitex/ai/sampling/undersample.py +29 -0
  94. scitex/ai/sk/__init__.py +11 -0
  95. scitex/ai/sk/_clf.py +58 -0
  96. scitex/ai/sk/_to_sktime.py +100 -0
  97. scitex/ai/sklearn/__init__.py +26 -0
  98. scitex/ai/sklearn/clf.py +58 -0
  99. scitex/ai/sklearn/to_sktime.py +100 -0
  100. scitex/ai/training/__init__.py +7 -0
  101. scitex/ai/training/early_stopping.py +150 -0
  102. scitex/ai/training/learning_curve_logger.py +555 -0
  103. scitex/ai/utils/__init__.py +22 -0
  104. scitex/ai/utils/_check_params.py +50 -0
  105. scitex/ai/utils/_default_dataset.py +46 -0
  106. scitex/ai/utils/_format_samples_for_sktime.py +26 -0
  107. scitex/ai/utils/_label_encoder.py +134 -0
  108. scitex/ai/utils/_merge_labels.py +22 -0
  109. scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
  110. scitex/ai/utils/_under_sample.py +51 -0
  111. scitex/ai/utils/_verify_n_gpus.py +16 -0
  112. scitex/ai/utils/grid_search.py +148 -0
  113. scitex/context/__init__.py +9 -0
  114. scitex/context/_suppress_output.py +38 -0
  115. scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
  116. scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
  117. scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
  118. scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
  119. scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
  120. scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
  121. scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
  122. scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
  123. scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
  124. scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
  125. scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
  126. scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
  127. scitex/db/_BaseMixins/__init__.py +30 -0
  128. scitex/db/_PostgreSQL.py +126 -0
  129. scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
  130. scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
  131. scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
  132. scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
  133. scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
  134. scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
  135. scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
  136. scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
  137. scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
  138. scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
  139. scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
  140. scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
  141. scitex/db/_PostgreSQLMixins/__init__.py +34 -0
  142. scitex/db/_SQLite3.py +2136 -0
  143. scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
  144. scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
  145. scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
  146. scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
  147. scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
  148. scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
  149. scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
  150. scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
  151. scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
  152. scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
  153. scitex/db/_SQLite3Mixins/__init__.py +30 -0
  154. scitex/db/__init__.py +14 -0
  155. scitex/db/_delete_duplicates.py +397 -0
  156. scitex/db/_inspect.py +163 -0
  157. scitex/decorators/__init__.py +54 -0
  158. scitex/decorators/_auto_order.py +172 -0
  159. scitex/decorators/_batch_fn.py +127 -0
  160. scitex/decorators/_cache_disk.py +32 -0
  161. scitex/decorators/_cache_mem.py +12 -0
  162. scitex/decorators/_combined.py +98 -0
  163. scitex/decorators/_converters.py +282 -0
  164. scitex/decorators/_deprecated.py +26 -0
  165. scitex/decorators/_not_implemented.py +30 -0
  166. scitex/decorators/_numpy_fn.py +86 -0
  167. scitex/decorators/_pandas_fn.py +121 -0
  168. scitex/decorators/_preserve_doc.py +19 -0
  169. scitex/decorators/_signal_fn.py +95 -0
  170. scitex/decorators/_timeout.py +55 -0
  171. scitex/decorators/_torch_fn.py +136 -0
  172. scitex/decorators/_wrap.py +39 -0
  173. scitex/decorators/_xarray_fn.py +88 -0
  174. scitex/dev/__init__.py +15 -0
  175. scitex/dev/_analyze_code_flow.py +284 -0
  176. scitex/dev/_reload.py +59 -0
  177. scitex/dict/_DotDict.py +442 -0
  178. scitex/dict/__init__.py +18 -0
  179. scitex/dict/_listed_dict.py +42 -0
  180. scitex/dict/_pop_keys.py +36 -0
  181. scitex/dict/_replace.py +13 -0
  182. scitex/dict/_safe_merge.py +62 -0
  183. scitex/dict/_to_str.py +32 -0
  184. scitex/dsp/__init__.py +72 -0
  185. scitex/dsp/_crop.py +122 -0
  186. scitex/dsp/_demo_sig.py +331 -0
  187. scitex/dsp/_detect_ripples.py +212 -0
  188. scitex/dsp/_ensure_3d.py +18 -0
  189. scitex/dsp/_hilbert.py +78 -0
  190. scitex/dsp/_listen.py +702 -0
  191. scitex/dsp/_misc.py +30 -0
  192. scitex/dsp/_mne.py +32 -0
  193. scitex/dsp/_modulation_index.py +79 -0
  194. scitex/dsp/_pac.py +319 -0
  195. scitex/dsp/_psd.py +102 -0
  196. scitex/dsp/_resample.py +65 -0
  197. scitex/dsp/_time.py +36 -0
  198. scitex/dsp/_transform.py +68 -0
  199. scitex/dsp/_wavelet.py +212 -0
  200. scitex/dsp/add_noise.py +111 -0
  201. scitex/dsp/example.py +253 -0
  202. scitex/dsp/filt.py +155 -0
  203. scitex/dsp/norm.py +18 -0
  204. scitex/dsp/params.py +51 -0
  205. scitex/dsp/reference.py +43 -0
  206. scitex/dsp/template.py +25 -0
  207. scitex/dsp/utils/__init__.py +15 -0
  208. scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
  209. scitex/dsp/utils/_ensure_3d.py +18 -0
  210. scitex/dsp/utils/_ensure_even_len.py +10 -0
  211. scitex/dsp/utils/_zero_pad.py +48 -0
  212. scitex/dsp/utils/filter.py +408 -0
  213. scitex/dsp/utils/pac.py +177 -0
  214. scitex/dt/__init__.py +8 -0
  215. scitex/dt/_linspace.py +130 -0
  216. scitex/etc/__init__.py +15 -0
  217. scitex/etc/wait_key.py +34 -0
  218. scitex/gen/_DimHandler.py +196 -0
  219. scitex/gen/_TimeStamper.py +244 -0
  220. scitex/gen/__init__.py +95 -0
  221. scitex/gen/_alternate_kwarg.py +13 -0
  222. scitex/gen/_cache.py +11 -0
  223. scitex/gen/_check_host.py +34 -0
  224. scitex/gen/_ci.py +12 -0
  225. scitex/gen/_close.py +222 -0
  226. scitex/gen/_embed.py +78 -0
  227. scitex/gen/_inspect_module.py +257 -0
  228. scitex/gen/_is_ipython.py +12 -0
  229. scitex/gen/_less.py +48 -0
  230. scitex/gen/_list_packages.py +139 -0
  231. scitex/gen/_mat2py.py +88 -0
  232. scitex/gen/_norm.py +170 -0
  233. scitex/gen/_paste.py +18 -0
  234. scitex/gen/_print_config.py +84 -0
  235. scitex/gen/_shell.py +48 -0
  236. scitex/gen/_src.py +111 -0
  237. scitex/gen/_start.py +451 -0
  238. scitex/gen/_symlink.py +55 -0
  239. scitex/gen/_symlog.py +27 -0
  240. scitex/gen/_tee.py +238 -0
  241. scitex/gen/_title2path.py +60 -0
  242. scitex/gen/_title_case.py +88 -0
  243. scitex/gen/_to_even.py +84 -0
  244. scitex/gen/_to_odd.py +34 -0
  245. scitex/gen/_to_rank.py +39 -0
  246. scitex/gen/_transpose.py +37 -0
  247. scitex/gen/_type.py +78 -0
  248. scitex/gen/_var_info.py +73 -0
  249. scitex/gen/_wrap.py +17 -0
  250. scitex/gen/_xml2dict.py +76 -0
  251. scitex/gen/misc.py +730 -0
  252. scitex/gen/path.py +0 -0
  253. scitex/general/__init__.py +5 -0
  254. scitex/gists/_SigMacro_processFigure_S.py +128 -0
  255. scitex/gists/_SigMacro_toBlue.py +172 -0
  256. scitex/gists/__init__.py +12 -0
  257. scitex/io/_H5Explorer.py +292 -0
  258. scitex/io/__init__.py +82 -0
  259. scitex/io/_cache.py +101 -0
  260. scitex/io/_flush.py +24 -0
  261. scitex/io/_glob.py +103 -0
  262. scitex/io/_json2md.py +113 -0
  263. scitex/io/_load.py +168 -0
  264. scitex/io/_load_configs.py +146 -0
  265. scitex/io/_load_modules/__init__.py +38 -0
  266. scitex/io/_load_modules/_catboost.py +66 -0
  267. scitex/io/_load_modules/_con.py +20 -0
  268. scitex/io/_load_modules/_db.py +24 -0
  269. scitex/io/_load_modules/_docx.py +42 -0
  270. scitex/io/_load_modules/_eeg.py +110 -0
  271. scitex/io/_load_modules/_hdf5.py +196 -0
  272. scitex/io/_load_modules/_image.py +19 -0
  273. scitex/io/_load_modules/_joblib.py +19 -0
  274. scitex/io/_load_modules/_json.py +18 -0
  275. scitex/io/_load_modules/_markdown.py +103 -0
  276. scitex/io/_load_modules/_matlab.py +37 -0
  277. scitex/io/_load_modules/_numpy.py +39 -0
  278. scitex/io/_load_modules/_optuna.py +155 -0
  279. scitex/io/_load_modules/_pandas.py +69 -0
  280. scitex/io/_load_modules/_pdf.py +31 -0
  281. scitex/io/_load_modules/_pickle.py +24 -0
  282. scitex/io/_load_modules/_torch.py +16 -0
  283. scitex/io/_load_modules/_txt.py +126 -0
  284. scitex/io/_load_modules/_xml.py +49 -0
  285. scitex/io/_load_modules/_yaml.py +23 -0
  286. scitex/io/_mv_to_tmp.py +19 -0
  287. scitex/io/_path.py +286 -0
  288. scitex/io/_reload.py +78 -0
  289. scitex/io/_save.py +539 -0
  290. scitex/io/_save_modules/__init__.py +66 -0
  291. scitex/io/_save_modules/_catboost.py +22 -0
  292. scitex/io/_save_modules/_csv.py +89 -0
  293. scitex/io/_save_modules/_excel.py +49 -0
  294. scitex/io/_save_modules/_hdf5.py +249 -0
  295. scitex/io/_save_modules/_html.py +48 -0
  296. scitex/io/_save_modules/_image.py +140 -0
  297. scitex/io/_save_modules/_joblib.py +25 -0
  298. scitex/io/_save_modules/_json.py +25 -0
  299. scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
  300. scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
  301. scitex/io/_save_modules/_matlab.py +24 -0
  302. scitex/io/_save_modules/_mp4.py +29 -0
  303. scitex/io/_save_modules/_numpy.py +57 -0
  304. scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
  305. scitex/io/_save_modules/_pickle.py +45 -0
  306. scitex/io/_save_modules/_plotly.py +27 -0
  307. scitex/io/_save_modules/_text.py +23 -0
  308. scitex/io/_save_modules/_torch.py +26 -0
  309. scitex/io/_save_modules/_yaml.py +29 -0
  310. scitex/life/__init__.py +10 -0
  311. scitex/life/_monitor_rain.py +49 -0
  312. scitex/linalg/__init__.py +17 -0
  313. scitex/linalg/_distance.py +63 -0
  314. scitex/linalg/_geometric_median.py +64 -0
  315. scitex/linalg/_misc.py +73 -0
  316. scitex/nn/_AxiswiseDropout.py +27 -0
  317. scitex/nn/_BNet.py +126 -0
  318. scitex/nn/_BNet_Res.py +164 -0
  319. scitex/nn/_ChannelGainChanger.py +44 -0
  320. scitex/nn/_DropoutChannels.py +50 -0
  321. scitex/nn/_Filters.py +489 -0
  322. scitex/nn/_FreqGainChanger.py +110 -0
  323. scitex/nn/_GaussianFilter.py +48 -0
  324. scitex/nn/_Hilbert.py +111 -0
  325. scitex/nn/_MNet_1000.py +157 -0
  326. scitex/nn/_ModulationIndex.py +221 -0
  327. scitex/nn/_PAC.py +414 -0
  328. scitex/nn/_PSD.py +40 -0
  329. scitex/nn/_ResNet1D.py +120 -0
  330. scitex/nn/_SpatialAttention.py +25 -0
  331. scitex/nn/_Spectrogram.py +161 -0
  332. scitex/nn/_SwapChannels.py +50 -0
  333. scitex/nn/_TransposeLayer.py +19 -0
  334. scitex/nn/_Wavelet.py +183 -0
  335. scitex/nn/__init__.py +63 -0
  336. scitex/os/__init__.py +8 -0
  337. scitex/os/_mv.py +50 -0
  338. scitex/parallel/__init__.py +8 -0
  339. scitex/parallel/_run.py +151 -0
  340. scitex/path/__init__.py +33 -0
  341. scitex/path/_clean.py +52 -0
  342. scitex/path/_find.py +108 -0
  343. scitex/path/_get_module_path.py +51 -0
  344. scitex/path/_get_spath.py +35 -0
  345. scitex/path/_getsize.py +18 -0
  346. scitex/path/_increment_version.py +87 -0
  347. scitex/path/_mk_spath.py +51 -0
  348. scitex/path/_path.py +19 -0
  349. scitex/path/_split.py +23 -0
  350. scitex/path/_this_path.py +19 -0
  351. scitex/path/_version.py +101 -0
  352. scitex/pd/__init__.py +41 -0
  353. scitex/pd/_find_indi.py +126 -0
  354. scitex/pd/_find_pval.py +113 -0
  355. scitex/pd/_force_df.py +154 -0
  356. scitex/pd/_from_xyz.py +71 -0
  357. scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
  358. scitex/pd/_melt_cols.py +81 -0
  359. scitex/pd/_merge_columns.py +221 -0
  360. scitex/pd/_mv.py +63 -0
  361. scitex/pd/_replace.py +62 -0
  362. scitex/pd/_round.py +93 -0
  363. scitex/pd/_slice.py +63 -0
  364. scitex/pd/_sort.py +91 -0
  365. scitex/pd/_to_numeric.py +53 -0
  366. scitex/pd/_to_xy.py +59 -0
  367. scitex/pd/_to_xyz.py +110 -0
  368. scitex/plt/__init__.py +36 -0
  369. scitex/plt/_subplots/_AxesWrapper.py +182 -0
  370. scitex/plt/_subplots/_AxisWrapper.py +249 -0
  371. scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
  372. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
  373. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
  374. scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
  375. scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
  376. scitex/plt/_subplots/_FigWrapper.py +226 -0
  377. scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
  378. scitex/plt/_subplots/__init__.py +111 -0
  379. scitex/plt/_subplots/_export_as_csv.py +232 -0
  380. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
  381. scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
  382. scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
  383. scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
  384. scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
  385. scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
  386. scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
  387. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
  388. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
  389. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
  390. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
  391. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
  392. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
  393. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
  394. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
  395. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
  396. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
  397. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
  398. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
  399. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
  400. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
  401. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
  402. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
  403. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
  404. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
  405. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
  406. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
  407. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
  408. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
  409. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
  410. scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
  411. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
  412. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
  413. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
  414. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
  415. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
  416. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
  417. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
  418. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
  419. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
  420. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
  421. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
  422. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
  423. scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
  424. scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
  425. scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
  426. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
  427. scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
  428. scitex/plt/_tpl.py +28 -0
  429. scitex/plt/ax/__init__.py +114 -0
  430. scitex/plt/ax/_plot/__init__.py +53 -0
  431. scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
  432. scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
  433. scitex/plt/ax/_plot/_plot_cube.py +57 -0
  434. scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
  435. scitex/plt/ax/_plot/_plot_fillv.py +55 -0
  436. scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
  437. scitex/plt/ax/_plot/_plot_image.py +94 -0
  438. scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
  439. scitex/plt/ax/_plot/_plot_raster.py +172 -0
  440. scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
  441. scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
  442. scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
  443. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
  444. scitex/plt/ax/_plot/_plot_violin.py +343 -0
  445. scitex/plt/ax/_style/__init__.py +38 -0
  446. scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
  447. scitex/plt/ax/_style/_add_panel.py +92 -0
  448. scitex/plt/ax/_style/_extend.py +64 -0
  449. scitex/plt/ax/_style/_force_aspect.py +37 -0
  450. scitex/plt/ax/_style/_format_label.py +23 -0
  451. scitex/plt/ax/_style/_hide_spines.py +84 -0
  452. scitex/plt/ax/_style/_map_ticks.py +182 -0
  453. scitex/plt/ax/_style/_rotate_labels.py +215 -0
  454. scitex/plt/ax/_style/_sci_note.py +279 -0
  455. scitex/plt/ax/_style/_set_log_scale.py +299 -0
  456. scitex/plt/ax/_style/_set_meta.py +261 -0
  457. scitex/plt/ax/_style/_set_n_ticks.py +37 -0
  458. scitex/plt/ax/_style/_set_size.py +16 -0
  459. scitex/plt/ax/_style/_set_supxyt.py +116 -0
  460. scitex/plt/ax/_style/_set_ticks.py +276 -0
  461. scitex/plt/ax/_style/_set_xyt.py +121 -0
  462. scitex/plt/ax/_style/_share_axes.py +264 -0
  463. scitex/plt/ax/_style/_shift.py +139 -0
  464. scitex/plt/ax/_style/_show_spines.py +333 -0
  465. scitex/plt/color/_PARAMS.py +70 -0
  466. scitex/plt/color/__init__.py +52 -0
  467. scitex/plt/color/_add_hue_col.py +41 -0
  468. scitex/plt/color/_colors.py +205 -0
  469. scitex/plt/color/_get_colors_from_cmap.py +134 -0
  470. scitex/plt/color/_interpolate.py +29 -0
  471. scitex/plt/color/_vizualize_colors.py +54 -0
  472. scitex/plt/utils/__init__.py +44 -0
  473. scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
  474. scitex/plt/utils/_calc_nice_ticks.py +101 -0
  475. scitex/plt/utils/_close.py +68 -0
  476. scitex/plt/utils/_colorbar.py +96 -0
  477. scitex/plt/utils/_configure_mpl.py +295 -0
  478. scitex/plt/utils/_histogram_utils.py +132 -0
  479. scitex/plt/utils/_im2grid.py +70 -0
  480. scitex/plt/utils/_is_valid_axis.py +78 -0
  481. scitex/plt/utils/_mk_colorbar.py +65 -0
  482. scitex/plt/utils/_mk_patches.py +26 -0
  483. scitex/plt/utils/_scientific_captions.py +638 -0
  484. scitex/plt/utils/_scitex_config.py +223 -0
  485. scitex/reproduce/__init__.py +14 -0
  486. scitex/reproduce/_fix_seeds.py +45 -0
  487. scitex/reproduce/_gen_ID.py +55 -0
  488. scitex/reproduce/_gen_timestamp.py +35 -0
  489. scitex/res/__init__.py +5 -0
  490. scitex/resource/__init__.py +13 -0
  491. scitex/resource/_get_processor_usages.py +281 -0
  492. scitex/resource/_get_specs.py +280 -0
  493. scitex/resource/_log_processor_usages.py +190 -0
  494. scitex/resource/_utils/__init__.py +31 -0
  495. scitex/resource/_utils/_get_env_info.py +481 -0
  496. scitex/resource/limit_ram.py +33 -0
  497. scitex/scholar/__init__.py +24 -0
  498. scitex/scholar/_local_search.py +454 -0
  499. scitex/scholar/_paper.py +244 -0
  500. scitex/scholar/_pdf_downloader.py +325 -0
  501. scitex/scholar/_search.py +393 -0
  502. scitex/scholar/_vector_search.py +370 -0
  503. scitex/scholar/_web_sources.py +457 -0
  504. scitex/stats/__init__.py +31 -0
  505. scitex/stats/_calc_partial_corr.py +17 -0
  506. scitex/stats/_corr_test_multi.py +94 -0
  507. scitex/stats/_corr_test_wrapper.py +115 -0
  508. scitex/stats/_describe_wrapper.py +90 -0
  509. scitex/stats/_multiple_corrections.py +63 -0
  510. scitex/stats/_nan_stats.py +93 -0
  511. scitex/stats/_p2stars.py +116 -0
  512. scitex/stats/_p2stars_wrapper.py +56 -0
  513. scitex/stats/_statistical_tests.py +73 -0
  514. scitex/stats/desc/__init__.py +40 -0
  515. scitex/stats/desc/_describe.py +189 -0
  516. scitex/stats/desc/_nan.py +289 -0
  517. scitex/stats/desc/_real.py +94 -0
  518. scitex/stats/multiple/__init__.py +14 -0
  519. scitex/stats/multiple/_bonferroni_correction.py +72 -0
  520. scitex/stats/multiple/_fdr_correction.py +400 -0
  521. scitex/stats/multiple/_multicompair.py +28 -0
  522. scitex/stats/tests/__corr_test.py +277 -0
  523. scitex/stats/tests/__corr_test_multi.py +343 -0
  524. scitex/stats/tests/__corr_test_single.py +277 -0
  525. scitex/stats/tests/__init__.py +22 -0
  526. scitex/stats/tests/_brunner_munzel_test.py +192 -0
  527. scitex/stats/tests/_nocorrelation_test.py +28 -0
  528. scitex/stats/tests/_smirnov_grubbs.py +98 -0
  529. scitex/str/__init__.py +113 -0
  530. scitex/str/_clean_path.py +75 -0
  531. scitex/str/_color_text.py +52 -0
  532. scitex/str/_decapitalize.py +58 -0
  533. scitex/str/_factor_out_digits.py +281 -0
  534. scitex/str/_format_plot_text.py +498 -0
  535. scitex/str/_grep.py +48 -0
  536. scitex/str/_latex.py +155 -0
  537. scitex/str/_latex_fallback.py +471 -0
  538. scitex/str/_mask_api.py +39 -0
  539. scitex/str/_mask_api_key.py +8 -0
  540. scitex/str/_parse.py +158 -0
  541. scitex/str/_print_block.py +47 -0
  542. scitex/str/_print_debug.py +68 -0
  543. scitex/str/_printc.py +62 -0
  544. scitex/str/_readable_bytes.py +38 -0
  545. scitex/str/_remove_ansi.py +23 -0
  546. scitex/str/_replace.py +134 -0
  547. scitex/str/_search.py +125 -0
  548. scitex/str/_squeeze_space.py +36 -0
  549. scitex/tex/__init__.py +10 -0
  550. scitex/tex/_preview.py +103 -0
  551. scitex/tex/_to_vec.py +116 -0
  552. scitex/torch/__init__.py +18 -0
  553. scitex/torch/_apply_to.py +34 -0
  554. scitex/torch/_nan_funcs.py +77 -0
  555. scitex/types/_ArrayLike.py +44 -0
  556. scitex/types/_ColorLike.py +21 -0
  557. scitex/types/__init__.py +14 -0
  558. scitex/types/_is_listed_X.py +70 -0
  559. scitex/utils/__init__.py +22 -0
  560. scitex/utils/_compress_hdf5.py +116 -0
  561. scitex/utils/_email.py +120 -0
  562. scitex/utils/_grid.py +148 -0
  563. scitex/utils/_notify.py +247 -0
  564. scitex/utils/_search.py +121 -0
  565. scitex/web/__init__.py +38 -0
  566. scitex/web/_search_pubmed.py +438 -0
  567. scitex/web/_summarize_url.py +158 -0
  568. scitex-2.0.0.dist-info/METADATA +307 -0
  569. scitex-2.0.0.dist-info/RECORD +572 -0
  570. scitex-2.0.0.dist-info/WHEEL +6 -0
  571. scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
  572. scitex-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1161 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Timestamp: "2025-02-15 01:38:28 (ywatanabe)"
4
+ # File: ./src/scitex/ai/ClassificationReporter.py
5
+
6
+ THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/ClassificationReporter.py"
7
+
8
+ import os as _os
9
+ import random as _random
10
+ import sys as _sys
11
+ from collections import defaultdict as _defaultdict
12
+ from glob import glob as _glob
13
+ from pprint import pprint as _pprint
14
+
15
+ import matplotlib as _matplotlib
16
+ import matplotlib.pyplot as _plt
17
+ import scitex as _scitex
18
+ import numpy as _np
19
+ import pandas as _pd
20
+ import torch as _torch
21
+ from sklearn.metrics import (
22
+ balanced_accuracy_score as _balanced_accuracy_score,
23
+ classification_report as _classification_report,
24
+ confusion_matrix as _confusion_matrix,
25
+ matthews_corrcoef as _matthews_corrcoef,
26
+ )
27
+
28
+ from ..reproduce import fix_seeds as _fix_seeds
29
+
30
+
31
class MultiClassificationReporter(object):
    """Route reporting calls to one ``ClassificationReporter`` per target.

    One ``ClassificationReporter`` is created for every entry of ``tgts`` and
    each public method forwards its arguments to the reporter selected by the
    ``tgt`` keyword.

    NOTE(review): ``calc_metrics``, ``summarize``, ``save`` and
    ``plot_and_save_conf_mats`` are forwarded to methods of the same name on
    ``ClassificationReporter``; confirm that the reporter class in use
    actually provides them.
    """

    @staticmethod
    def _default_auc_plt_config():
        """Return a fresh copy of the default plotting options for AUC figures."""
        return dict(
            figsize=(7, 7),
            labelsize=8,
            fontsize=7,
            legendfontsize=6,
            tick_size=0.8,
            tick_width=0.2,
        )

    @staticmethod
    def _resolve_targets(sdir, tgts):
        """Return ``(targets, save_dirs)`` for the given constructor arguments.

        With ``tgts=None`` a single un-named target (``None``) is used and it
        saves directly under ``sdir``. Otherwise every target gets the
        directory ``sdir + tgt + "/"`` (plain concatenation, so ``sdir`` is
        expected to already end with a path separator).
        """
        if tgts is None:
            return [None], [sdir]
        tgts = list(tgts)  # allow one-shot iterables; we iterate twice
        return tgts, [sdir + tgt + "/" for tgt in tgts]

    def __init__(self, sdir, tgts=None):
        """Create one reporter per target.

        Args:
            sdir: Base save directory (string).
            tgts: Iterable of target names, or ``None`` for a single default
                target that is addressed with ``tgt=None``.
        """
        tgts, sdirs = self._resolve_targets(sdir, tgts)
        self.tgt2id = {tgt: i_tgt for i_tgt, tgt in enumerate(tgts)}
        self.reporters = [ClassificationReporter(sdir_tgt) for sdir_tgt in sdirs]

    def _reporter_for(self, tgt):
        """Return the reporter registered for ``tgt``.

        Raises:
            KeyError: If ``tgt`` was not passed to the constructor.
        """
        try:
            i_tgt = self.tgt2id[tgt]
        except KeyError:
            raise KeyError(
                f"Unknown target {tgt!r}; known targets: {list(self.tgt2id)}"
            ) from None
        return self.reporters[i_tgt]

    def add(self, obj_name, obj, tgt=None):
        """Forward ``add(obj_name, obj)`` to the reporter of ``tgt``."""
        self._reporter_for(tgt).add(obj_name, obj)

    def calc_metrics(
        self,
        true_class,
        pred_class,
        pred_proba,
        labels=None,
        i_fold=None,
        show=True,
        auc_plt_config=None,
        tgt=None,
    ):
        """Forward ``calc_metrics`` to the reporter of ``tgt``.

        ``auc_plt_config=None`` selects the default plotting options
        (see ``_default_auc_plt_config``); the reporter always receives a dict.
        """
        if auc_plt_config is None:
            auc_plt_config = self._default_auc_plt_config()
        self._reporter_for(tgt).calc_metrics(
            true_class,
            pred_class,
            pred_proba,
            labels=labels,
            i_fold=i_fold,
            show=show,
            auc_plt_config=auc_plt_config,
        )

    def summarize(
        self,
        n_round=3,
        show=False,
        tgt=None,
    ):
        """Forward ``summarize`` to the reporter of ``tgt``."""
        self._reporter_for(tgt).summarize(
            n_round=n_round,
            show=show,
        )

    def save(
        self,
        files_to_reproduce=None,
        meta_dict=None,
        tgt=None,
    ):
        """Forward ``save`` to the reporter of ``tgt``."""
        self._reporter_for(tgt).save(
            files_to_reproduce=files_to_reproduce,
            meta_dict=meta_dict,
        )

    def plot_and_save_conf_mats(
        self,
        plt,
        extend_ratio=1.0,
        colorbar=True,
        confmat_plt_config=None,
        sci_notation_kwargs=None,
        tgt=None,
    ):
        """Forward ``plot_and_save_conf_mats`` to the reporter of ``tgt``."""
        self._reporter_for(tgt).plot_and_save_conf_mats(
            plt,
            extend_ratio=extend_ratio,
            colorbar=colorbar,
            confmat_plt_config=confmat_plt_config,
            sci_notation_kwargs=sci_notation_kwargs,
        )
116
+
117
+
118
class ClassificationReporter(object):
    """Collect per-fold classification metrics and figures for one experiment.

    Helpers are provided for:
    - Balanced Accuracy
    - MCC
    - Confusion Matrix
    - Classification Report
    - ROC AUC score / curve
    - Precision-Recall curve

    Objects registered through ``add`` and the ROC / precision-recall figures
    created by ``calc_AUCs`` are appended, fold by fold, to
    ``self.folds_dict`` (a ``defaultdict(list)``).

    NOTE(review): as defined here the class has no ``calc_metrics``,
    ``summarize``, ``save`` or ``plot_and_save_conf_mats`` method and no
    ``_calc_AUCs_multiple``; confirm whether they are meant to be restored.
    """

    @staticmethod
    def _default_auc_plt_config():
        """Return a fresh copy of the default plotting options for AUC figures."""
        return dict(
            figsize=(7, 7),
            labelsize=8,
            fontsize=7,
            legendfontsize=6,
            tick_size=0.8,
            tick_width=0.2,
        )

    def __init__(self, sdir):
        """Remember the save directory and start with an empty fold store.

        Args:
            sdir: Directory (string) associated with this reporter.
        """
        self.sdir = sdir
        self.folds_dict = _defaultdict(list)
        _fix_seeds(os=_os, random=_random, np=_np, torch=_torch, verbose=False)

    def add(
        self,
        obj_name,
        obj,
    ):
        """Append ``obj`` to the list stored under ``obj_name``.

        Example:
            ## fig
            fig, ax = plt.subplots()
            ax.plot(np.random.rand(10))
            reporter.add("manu_figs", fig)

            ## DataFrame
            df = pd.DataFrame(np.random.rand(5, 3))
            reporter.add("manu_dfs", df)

            ## scalar
            scalar = random.random()
            reporter.add("manu_scalers", scalar)
        """
        assert isinstance(obj_name, str)
        self.folds_dict[obj_name].append(obj)

    @staticmethod
    def calc_bACC(true_class, pred_class, i_fold, show=False):
        """Return the balanced accuracy; print it for fold ``i_fold`` if ``show``."""
        balanced_acc = _balanced_accuracy_score(true_class, pred_class)
        if show:
            print(f"\nBalanced ACC in fold#{i_fold} was {balanced_acc:.3f}\n")
        return balanced_acc

    @staticmethod
    def calc_balanced_accuracy(true_class, pred_class, i_fold, show=False):
        """Balanced accuracy (snake_case alias for calc_bACC)."""
        return ClassificationReporter.calc_bACC(true_class, pred_class, i_fold, show)

    @staticmethod
    def calc_mcc(true_class, pred_class, i_fold, show=False):
        """Return the Matthews correlation coefficient as a Python float."""
        mcc = float(_matthews_corrcoef(true_class, pred_class))
        if show:
            print(f"\nMCC in fold#{i_fold} was {mcc:.3f}\n")
        return mcc

    @staticmethod
    def calc_conf_mat(true_class, pred_class, labels, i_fold, show=False):
        """Return the confusion matrix as a DataFrame labelled with ``labels``.

        Classes are expected to be encoded as the integers
        ``0 .. len(labels) - 1``: the matrix is computed for exactly those
        class ids, and ``labels[i]`` names row and column ``i``.
        """
        counts = _confusion_matrix(
            true_class, pred_class, labels=_np.arange(len(labels))
        )
        conf_mat = _pd.DataFrame(data=counts, columns=labels).set_index(
            _pd.Series(list(labels))
        )

        if show:
            print(f"\nConfusion Matrix in fold#{i_fold}: \n")
            _pprint(conf_mat)
            print()

        return conf_mat

    @staticmethod
    def calc_clf_report(
        true_class, pred_class, labels, balanced_acc, i_fold, show=False
    ):
        """Return sklearn's classification report as a rounded DataFrame.

        The ``accuracy`` column is overwritten with ``balanced_acc`` and
        renamed to ``balanced accuracy``; the ``support`` row is renamed to
        ``sample size``. Columns are ordered as the per-class columns followed
        by ``balanced accuracy``, ``macro avg`` and ``weighted avg``.
        """
        clf_report = _pd.DataFrame(
            _classification_report(
                true_class,
                pred_class,
                labels=_np.arange(len(labels)),
                target_names=labels,
                output_dict=True,
            )
        )

        # Report balanced accuracy in place of plain accuracy.
        clf_report["accuracy"] = balanced_acc
        clf_report = _pd.concat(
            [
                # list(...) so a tuple of labels is not read as a single key
                clf_report[list(labels)],
                clf_report[["accuracy", "macro avg", "weighted avg"]],
            ],
            axis=1,
        )
        clf_report = clf_report.rename(columns={"accuracy": "balanced accuracy"})
        clf_report = clf_report.round(3)
        # Renaming the index entry directly avoids a temporary "index" column,
        # which would overwrite a class that happens to be named "index".
        clf_report = clf_report.rename(index={"support": "sample size"})
        clf_report.index.name = None
        if show:
            print(f"\nClassification Report for fold#{i_fold}:\n")
            _pprint(clf_report)
            print()
        return clf_report

    def calc_AUCs(
        self,
        true_class,
        pred_proba,
        labels,
        i_fold,
        show=True,
        auc_plt_config=None,
    ):
        """Compute the ROC AUC and store the ROC / precision-recall figures.

        Args:
            true_class: True class ids; must contain exactly ``len(labels)``
                distinct values.
            pred_proba: Predicted scores (see ``_calc_AUCs_binary``).
            labels: Class names; only their count is used for dispatching.
            i_fold: Fold identifier used in printed messages.
            show: Whether to print the score.
            auc_plt_config: Plot options; ``None`` selects the defaults.

        Returns:
            The ROC AUC value.

        Raises:
            NotImplementedError: For more than two classes when no
                ``_calc_AUCs_multiple`` method is available on the instance.
        """
        if auc_plt_config is None:
            auc_plt_config = self._default_auc_plt_config()

        n_classes = len(labels)
        assert len(_np.unique(true_class)) == n_classes, (
            "true_class must contain every class listed in labels"
        )

        if n_classes == 2:
            return self._calc_AUCs_binary(
                true_class,
                pred_proba,
                i_fold,
                show=show,
                auc_plt_config=auc_plt_config,
            )

        # This class does not define the multiclass handler; look it up
        # dynamically so a subclass may provide it, and fail with a clear
        # message otherwise.
        calc_multiple = getattr(self, "_calc_AUCs_multiple", None)
        if calc_multiple is None:
            raise NotImplementedError(
                "AUC calculation for more than two classes is not implemented "
                f"(got {n_classes} classes)"
            )
        return calc_multiple(
            true_class,
            pred_proba,
            labels,
            i_fold,
            show=show,
            auc_plt_config=auc_plt_config,
        )

    def calc_aucs(self, true_class, pred_proba, labels, i_fold, show=True, auc_plt_config=None):
        """Calculate AUCs (snake_case alias for calc_AUCs)."""
        return self.calc_AUCs(true_class, pred_proba, labels, i_fold, show, auc_plt_config)

    @staticmethod
    def _style_auc_axes(ax, xlabel, ylabel, title, auc_plt_config):
        """Apply axis labels, title, legend and tick options to one AUC plot."""
        ax.set_xlabel(xlabel, fontsize=auc_plt_config["labelsize"])
        ax.set_ylabel(ylabel, fontsize=auc_plt_config["labelsize"])
        ax.set_title(title, fontsize=auc_plt_config["fontsize"])
        ax.legend(fontsize=auc_plt_config["legendfontsize"])
        # NOTE(review): tick_size is applied as the tick *label* size
        # (default 0.8) — confirm that tick length was not intended.
        ax.tick_params(
            axis="both",
            which="major",
            labelsize=auc_plt_config["tick_size"],
            width=auc_plt_config["tick_width"],
        )

    def _calc_AUCs_binary(
        self,
        true_class,
        pred_proba,
        i_fold,
        show=False,
        auc_plt_config=None,
    ):
        """Compute the ROC AUC for a binary problem and store both curves.

        The ROC figure is appended to ``self.folds_dict["ROC_fig"]`` and the
        precision-recall figure to ``self.folds_dict["PRE_REC_fig"]``.

        Args:
            true_class: True class ids with exactly two distinct values.
            pred_proba: Scores of the positive class, either 1-D or 2-D with
                one or two columns (e.g. ``predict_proba`` output), in which
                case the last column is used.
            i_fold: Fold identifier used in the printed message.
            show: Whether to print the ROC AUC.
            auc_plt_config: Plot options; ``None`` selects the defaults.

        Returns:
            The ROC AUC value.
        """
        from sklearn.metrics import (
            PrecisionRecallDisplay,
            RocCurveDisplay,
            auc,
            roc_curve,
        )

        if auc_plt_config is None:
            auc_plt_config = self._default_auc_plt_config()

        n_classes = len(_np.unique(true_class))
        assert n_classes == 2, "This method is only for binary classification"

        # sklearn expects 1-D scores; accept (n, 1) / (n, 2) probability
        # matrices by taking the last column (the higher class id).
        scores = pred_proba
        proba_arr = _np.asarray(pred_proba)
        if proba_arr.ndim == 2 and proba_arr.shape[1] in (1, 2):
            scores = proba_arr[:, -1]

        fig_size = auc_plt_config["figsize"]

        # ROC curve
        fpr, tpr, _ = roc_curve(true_class, scores)
        roc_auc = auc(fpr, tpr)

        fig_roc, ax_roc = _plt.subplots(figsize=fig_size)
        RocCurveDisplay(
            fpr=fpr,
            tpr=tpr,
            roc_auc=roc_auc,
        ).plot(ax=ax_roc)
        ax_roc.plot([0, 1], [0, 1], "k:")  # chance-level diagonal
        self._style_auc_axes(
            ax_roc,
            "False Positive Rate",
            "True Positive Rate",
            "ROC Curve",
            auc_plt_config,
        )
        self.folds_dict["ROC_fig"].append(fig_roc)
        if show:
            print(f"\nROC AUC in fold#{i_fold} is {roc_auc:.3f}\n")

        # PRE-REC curve
        fig_prerec, ax_prerec = _plt.subplots(figsize=fig_size)
        PrecisionRecallDisplay.from_predictions(
            true_class,
            scores,
            ax=ax_prerec,
        )
        self._style_auc_axes(
            ax_prerec,
            "Recall",
            "Precision",
            "Precision-Recall Curve",
            auc_plt_config,
        )
        self.folds_dict["PRE_REC_fig"].append(fig_prerec)

        return roc_auc

    def _calc_aucs_binary(self, true_class, pred_proba, i_fold, show=False, auc_plt_config=None):
        """Calculates metrics for binary classification (snake_case alias)."""
        return self._calc_AUCs_binary(true_class, pred_proba, i_fold, show, auc_plt_config)
377
+
378
+
379
+ # #!/usr/bin/env python3
380
+ # # -*- coding: utf-8 -*-
381
+ # # Time-stamp: "2024-11-20 00:15:08 (ywatanabe)"
382
+ # # File: ./scitex_repo/src/scitex/ai/ClassificationReporter.py
383
+
384
+ # THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/ClassificationReporter.py"
385
+
386
+ # #!/usr/bin/env python3
387
+ # # -*- coding: utf-8 -*-
388
+ # # Time-stamp: "2024-11-13 12:54:17 (ywatanabe)"
389
+ # # File: ./scitex_repo/src/scitex/ai/ClassificationReporter.py
390
+
391
+ # import os
392
+ # import random
393
+ # import sys
394
+ # from collections import defaultdict as _defaultdict
395
+ # from glob import glob as _glob
396
+ # from pprint import pprint as _pprint
397
+
398
+ # import matplotlib
399
+ # import matplotlib.pyplot as plt
400
+ # import scitex
401
+ # import numpy as np
402
+ # import pandas as pd
403
+ # import torch
404
+ # from sklearn.metrics import (
405
+ # balanced_accuracy_score,
406
+ # classification_report,
407
+ # confusion_matrix,
408
+ # matthews_corrcoef,
409
+ # )
410
+
411
+ # from ..reproduce import fix_seeds
412
+
413
+
414
+ # class MultiClassificationReporter(object):
415
+ # def __init__(self, sdir, tgts=None):
416
+ # if tgts is None:
417
+ # sdirs = [""]
418
+ # else:
419
+ # sdirs = [os.path.join(sdir, tgt, "/") for tgt in tgts]
420
+ # sdirs = [sdir + tgt + "/" for tgt in tgts]
421
+
422
+ # self.tgt2id = {tgt: i_tgt for i_tgt, tgt in enumerate(tgts)}
423
+ # self.reporters = [ClassificationReporter(sdir) for sdir in sdirs]
424
+
425
+ # def add(self, obj_name, obj, tgt=None):
426
+ # i_tgt = self.tgt2id[tgt]
427
+ # self.reporters[i_tgt].add(obj_name, obj)
428
+
429
+ # def calc_metrics(
430
+ # self,
431
+ # true_class,
432
+ # pred_class,
433
+ # pred_proba,
434
+ # labels=None,
435
+ # i_fold=None,
436
+ # show=True,
437
+ # auc_plt_config=dict(
438
+ # figsize=(7, 7),
439
+ # labelsize=8,
440
+ # fontsize=7,
441
+ # legendfontsize=6,
442
+ # tick_size=0.8,
443
+ # tick_width=0.2,
444
+ # ),
445
+ # tgt=None,
446
+ # ):
447
+ # i_tgt = self.tgt2id[tgt]
448
+ # self.reporters[i_tgt].calc_metrics(
449
+ # true_class,
450
+ # pred_class,
451
+ # pred_proba,
452
+ # labels=labels,
453
+ # i_fold=i_fold,
454
+ # show=show,
455
+ # auc_plt_config=auc_plt_config,
456
+ # )
457
+
458
+ # def summarize(
459
+ # self,
460
+ # n_round=3,
461
+ # show=False,
462
+ # tgt=None,
463
+ # ):
464
+ # i_tgt = self.tgt2id[tgt]
465
+ # self.reporters[i_tgt].summarize(
466
+ # n_round=n_round,
467
+ # show=show,
468
+ # )
469
+
470
+ # def save(
471
+ # self,
472
+ # files_to_reproduce=None,
473
+ # meta_dict=None,
474
+ # tgt=None,
475
+ # ):
476
+ # i_tgt = self.tgt2id[tgt]
477
+ # self.reporters[i_tgt].save(
478
+ # files_to_reproduce=files_to_reproduce,
479
+ # meta_dict=meta_dict,
480
+ # )
481
+
482
+ # def plot_and_save_conf_mats(
483
+ # self,
484
+ # plt,
485
+ # extend_ratio=1.0,
486
+ # colorbar=True,
487
+ # confmat_plt_config=None,
488
+ # sci_notation_kwargs=None,
489
+ # tgt=None,
490
+ # ):
491
+ # i_tgt = self.tgt2id[tgt]
492
+ # self.reporters[i_tgt].plot_and_save_conf_mats(
493
+ # plt,
494
+ # extend_ratio=extend_ratio,
495
+ # colorbar=colorbar,
496
+ # confmat_plt_config=confmat_plt_config,
497
+ # sci_notation_kwargs=sci_notation_kwargs,
498
+ # )
499
+
500
+
501
+ # class ClassificationReporter(object):
502
+ # """Saves the following metrics under sdir.
503
+ # - Balanced Accuracy
504
+ # - MCC
505
+ # - Confusion Matrix
506
+ # - Classification Report
507
+ # - ROC AUC score / curve
508
+ # - PRE-REC AUC score / curve
509
+
510
+ # Example is described in this file.
511
+ # """
512
+
513
+ # def __init__(self, sdir):
514
+ # self.sdir = sdir
515
+ # self.folds_dict = _defaultdict(list)
516
+ # fix_seeds(os=os, random=random, np=np, torch=torch, show=False)
517
+
518
+ # def add(
519
+ # self,
520
+ # obj_name,
521
+ # obj,
522
+ # ):
523
+ # """
524
+ # ## fig
525
+ # fig, ax = plt.subplots()
526
+ # ax.plot(np.random.rand(10))
527
+ # reporter.add("manu_figs", fig)
528
+
529
+ # ## DataFrame
530
+ # df = pd.DataFrame(np.random.rand(5, 3))
531
+ # reporter.add("manu_dfs", df)
532
+
533
+ # ## scalar
534
+ # scalar = random.random()
535
+ # reporter.add("manu_scalers", scalar)
536
+ # """
537
+ # assert isinstance(obj_name, str)
538
+ # self.folds_dict[obj_name].append(obj)
539
+
540
+ # @staticmethod
541
+ # def calc_bACC(true_class, pred_class, i_fold, show=False):
542
+ # """Balanced ACC"""
543
+ # balanced_acc = balanced_accuracy_score(true_class, pred_class)
544
+ # if show:
545
+ # print(f"\nBalanced ACC in fold#{i_fold} was {balanced_acc:.3f}\n")
546
+ # return balanced_acc
547
+
548
+ # @staticmethod
549
+ # def calc_mcc(true_class, pred_class, i_fold, show=False):
550
+ # """MCC"""
551
+ # mcc = float(matthews_corrcoef(true_class, pred_class))
552
+ # if show:
553
+ # print(f"\nMCC in fold#{i_fold} was {mcc:.3f}\n")
554
+ # return mcc
555
+
556
+ # @staticmethod
557
+ # def calc_conf_mat(true_class, pred_class, labels, i_fold, show=False):
558
+ # """
559
+ # Confusion Matrix
560
+ # This method assumes unique classes of true_class and pred_class are the same.
561
+ # """
562
+ # # conf_mat = pd.DataFrame(
563
+ # # data=confusion_matrix(true_class, pred_class),
564
+ # # columns=pred_labels,
565
+ # # index=true_labels,
566
+ # # )
567
+
568
+ # conf_mat = pd.DataFrame(
569
+ # data=confusion_matrix(
570
+ # true_class, pred_class, labels=np.arange(len(labels))
571
+ # ),
572
+ # columns=labels,
573
+ # ).set_index(pd.Series(list(labels)))
574
+
575
+ # if show:
576
+ # print(f"\nConfusion Matrix in fold#{i_fold}: \n")
577
+ # _pprint(conf_mat)
578
+ # print()
579
+
580
+ # return conf_mat
581
+
582
+ # @staticmethod
583
+ # def calc_clf_report(
584
+ # true_class, pred_class, labels, balanced_acc, i_fold, show=False
585
+ # ):
586
+ # """Classification Report"""
587
+ # clf_report = pd.DataFrame(
588
+ # classification_report(
589
+ # true_class,
590
+ # pred_class,
591
+ # labels=np.arange(len(labels)),
592
+ # target_names=labels,
593
+ # output_dict=True,
594
+ # )
595
+ # )
596
+
597
+ # # ACC to bACC
598
+ # clf_report["accuracy"] = balanced_acc
599
+ # clf_report = pd.concat(
600
+ # [
601
+ # clf_report[labels],
602
+ # clf_report[["accuracy", "macro avg", "weighted avg"]],
603
+ # ],
604
+ # axis=1,
605
+ # )
606
+ # clf_report = clf_report.rename(
607
+ # columns={"accuracy": "balanced accuracy"}
608
+ # )
609
+ # clf_report = clf_report.round(3)
610
+ # # Renames 'support' to 'sample size'
611
+ # clf_report["index"] = clf_report.index
612
+ # clf_report.loc["support", "index"] = "sample size"
613
+ # clf_report.set_index("index", drop=True, inplace=True)
614
+ # clf_report.index.name = None
615
+ # if show:
616
+ # print(f"\nClassification Report for fold#{i_fold}:\n")
617
+ # _pprint(clf_report)
618
+ # print()
619
+ # return clf_report
620
+
621
+ # @staticmethod
622
+ # def calc_and_plot_roc_curve(
623
+ # true_class, pred_proba, labels, sdir_for_csv=None
624
+ # ):
625
+ # # ROC-AUC
626
+ # fig_roc, metrics_roc_auc_dict = scitex.ml.plt.roc_auc(
627
+ # plt,
628
+ # true_class,
629
+ # pred_proba,
630
+ # labels,
631
+ # sdir_for_csv=sdir_for_csv,
632
+ # )
633
+ # plt.close()
634
+ # return fig_roc, metrics_roc_auc_dict
635
+
636
+ # @staticmethod
637
+ # def calc_and_plot_pre_rec_curve(true_class, pred_proba, labels):
638
+ # # PRE-REC AUC
639
+ # fig_pre_rec, metrics_pre_rec_auc_dict = scitex.ml.plt.pre_rec_auc(
640
+ # plt, true_class, pred_proba, labels
641
+ # )
642
+ # plt.close()
643
+ # return fig_pre_rec, metrics_pre_rec_auc_dict
644
+
645
+ # def calc_metrics(
646
+ # self,
647
+ # true_class,
648
+ # pred_class,
649
+ # pred_proba,
650
+ # labels=None,
651
+ # i_fold=None,
652
+ # show=True,
653
+ # auc_plt_config=dict(
654
+ # figsize=(7, 7),
655
+ # labelsize=8,
656
+ # fontsize=7,
657
+ # legendfontsize=6,
658
+ # tick_size=0.8,
659
+ # tick_width=0.2,
660
+ # ),
661
+ # ):
662
+ # """
663
+ # Calculates ACC, Confusion Matrix, Classification Report, and ROC-AUC score on a fold.
664
+ # Metrics and curves will be kept in self.folds_dict.
665
+ # """
666
+
667
+ # ## Preparation
668
+ # # for convenience
669
+ # true_class = scitex.gen.torch_to_arr(true_class).astype(int).reshape(-1)
670
+ # pred_class = (
671
+ # scitex.gen.torch_to_arr(pred_class).astype(np.float64).reshape(-1)
672
+ # )
673
+ # pred_proba = scitex.gen.torch_to_arr(pred_proba).astype(np.float64)
674
+
675
+ # # for curves
676
+ # scitex.plt.configure_mpl(
677
+ # plt,
678
+ # **auc_plt_config,
679
+ # )
680
+
681
+ # ## Calc metrics
682
+ # # Balanced ACC
683
+ # bacc = self.calc_bACC(true_class, pred_class, i_fold, show=show)
684
+ # self.folds_dict["balanced_acc"].append(bacc)
685
+
686
+ # # MCC
687
+ # self.folds_dict["mcc"].append(
688
+ # self.calc_mcc(true_class, pred_class, i_fold, show=show)
689
+ # )
690
+
691
+ # # Confusion Matrix
692
+ # self.folds_dict["conf_mat/conf_mat"].append(
693
+ # self.calc_conf_mat(
694
+ # true_class,
695
+ # pred_class,
696
+ # labels,
697
+ # i_fold,
698
+ # show=show,
699
+ # )
700
+ # )
701
+
702
+ # # Classification Report
703
+ # self.folds_dict["clf_report"].append(
704
+ # self.calc_clf_report(
705
+ # true_class, pred_class, labels, bacc, i_fold, show=show
706
+ # )
707
+ # )
708
+
709
+ # ## Curves
710
+ # # ROC curve
711
+ # self.sdir_for_roc_csv = f"{self.sdir}roc/csv/"
712
+ # fig_roc, metrics_roc_auc_dict = self.calc_and_plot_roc_curve(
713
+ # true_class,
714
+ # pred_proba,
715
+ # labels,
716
+ # sdir_for_csv=self.sdir_for_roc_csv + f"fold#{i_fold}/",
717
+ # )
718
+ # self.folds_dict["roc/micro"].append(
719
+ # metrics_roc_auc_dict["roc_auc"]["micro"]
720
+ # )
721
+ # self.folds_dict["roc/macro"].append(
722
+ # metrics_roc_auc_dict["roc_auc"]["macro"]
723
+ # )
724
+ # self.folds_dict["roc/figs"].append(fig_roc)
725
+
726
+ # # PRE-REC curve
727
+ # fig_pre_rec, metrics_pre_rec_auc_dict = (
728
+ # self.calc_and_plot_pre_rec_curve(true_class, pred_proba, labels)
729
+ # )
730
+ # self.folds_dict["pre_rec/micro"].append(
731
+ # metrics_pre_rec_auc_dict["pre_rec_auc"]["micro"]
732
+ # )
733
+ # self.folds_dict["pre_rec/macro"].append(
734
+ # metrics_pre_rec_auc_dict["pre_rec_auc"]["macro"]
735
+ # )
736
+ # self.folds_dict["pre_rec/figs"].append(fig_pre_rec)
737
+
738
+ # @staticmethod
739
+ # def _mk_cv_index(n_folds):
740
+ # return [
741
+ # f"{n_folds}-folds_CV_mean",
742
+ # f"{n_folds}-fold_CV_std",
743
+ # ] + [f"fold#{i_fold}" for i_fold in range(n_folds)]
744
+
745
+ # def summarize_roc(
746
+ # self,
747
+ # ):
748
+
749
+ # folds_dirs = _glob(self.sdir_for_roc_csv + "fold#*")
750
+ # n_folds = len(folds_dirs)
751
+
752
+ # # get class names
753
+ # _csv_files = _glob(os.path.join(folds_dirs[0], "*"))
754
+ # classes_str = [
755
+ # csv_file.split("/")[-1].split(".csv")[0] for csv_file in _csv_files
756
+ # ]
757
+
758
+ # # dfs_classes = []
759
+ # # take mean and std by each class
760
+ # for cls_str in classes_str:
761
+
762
+ # fpaths_cls = [
763
+ # os.path.join(fold_dir, f"{cls_str}.csv")
764
+ # for fold_dir in folds_dirs
765
+ # ]
766
+
767
+ # ys = []
768
+ # roc_aucs = []
769
+ # for fpath_cls in fpaths_cls:
770
+ # loaded_df = scitex.io.load(fpath_cls)
771
+ # ys.append(loaded_df["y"])
772
+ # roc_aucs.append(loaded_df["roc_auc"])
773
+ # ys = pd.concat(ys, axis=1)
774
+ # roc_aucs = pd.concat(roc_aucs, axis=1)
775
+
776
+ # df_cls = loaded_df[["x"]].copy()
777
+ # df_cls["y_mean"] = ys.mean(axis=1)
778
+ # df_cls["y_std"] = ys.std(axis=1)
779
+ # df_cls["roc_auc_mean"] = roc_aucs.mean(axis=1)
780
+ # df_cls["roc_auc_std"] = roc_aucs.std(axis=1)
781
+
782
+ # spath_cls = os.path.join(
783
+ # self.sdir_for_roc_csv, f"k-fold_mean_std/{cls_str}.csv"
784
+ # )
785
+ # scitex.io.save(df_cls, spath_cls)
786
+ # # dfs_classes.append(df_cls)
787
+
788
+ # def summarize(
789
+ # self,
790
+ # n_round=3,
791
+ # show=False,
792
+ # ):
793
+ # """
794
+ # 1) Take mean and std of scalars/pd.Dataframes for folds.
795
+ # 2) Replace self.folds_dict with the summarized DataFrames.
796
+ # """
797
+ # self.summarize_roc()
798
+
799
+ # _n_folds_all = [
800
+ # len(self.folds_dict[k]) for k in self.folds_dict.keys()
801
+ # ] # sometimes includes 0 because AUC curves are not always defined.
802
+ # self.n_folds_intended = max(_n_folds_all)
803
+
804
+ # for i_k, k in enumerate(self.folds_dict.keys()):
805
+ # n_folds = _n_folds_all[i_k]
806
+
807
+ # if n_folds != 0:
808
+ # ## listed scalars
809
+ # if is_listed_X(self.folds_dict[k], [float, int]):
810
+ # mm = np.mean(self.folds_dict[k])
811
+ # ss = np.std(self.folds_dict[k], ddof=1)
812
+ # sr = pd.DataFrame(
813
+ # data=[mm, ss] + self.folds_dict[k],
814
+ # index=self._mk_cv_index(n_folds),
815
+ # columns=[k],
816
+ # )
817
+ # self.folds_dict[k] = sr.round(n_round)
818
+
819
+ # ## listed pd.DataFrames
820
+ # elif is_listed_X(self.folds_dict[k], pd.DataFrame):
821
+ # zero_df_for_mm = 0 * self.folds_dict[k][0].copy()
822
+ # zero_df_for_ss = 0 * self.folds_dict[k][0].copy()
823
+
824
+ # mm = (
825
+ # zero_df_for_mm
826
+ # + np.stack(self.folds_dict[k]).mean(axis=0)
827
+ # ).round(n_round)
828
+
829
+ # ss = (
830
+ # zero_df_for_ss
831
+ # + np.stack(self.folds_dict[k]).std(axis=0, ddof=1)
832
+ # ).round(n_round)
833
+
834
+ # self.folds_dict[k] = [mm, ss] + [
835
+ # df_fold.round(n_round)
836
+ # for df_fold in self.folds_dict[k]
837
+ # ]
838
+
839
+ # if show:
840
+ # print(
841
+ # "\n----------------------------------------\n"
842
+ # f"\n{k}\n"
843
+ # f"\n{n_folds}-fold-CV mean:\n"
844
+ # )
845
+ # _pprint(self.folds_dict[k][0])
846
+ # print(f"\n\n{n_folds}-fold-CV std.:\n")
847
+ # _pprint(self.folds_dict[k][1])
848
+ # print("\n\n----------------------------------------\n")
849
+
850
+ # ## listed figures
851
+ # elif is_listed_X(self.folds_dict[k], matplotlib.figure.Figure):
852
+ # pass
853
+
854
+ # else:
855
+ # print(f"{k} was not summarized")
856
+ # print(type(self.folds_dict[k][0]))
857
+
858
+ # def save(
859
+ # self,
860
+ # files_to_reproduce=None,
861
+ # meta_dict=None,
862
+ # ):
863
+ # """
864
+ # 1) Saves the content of self.folds_dict.
865
+ # 2) Plots the colormap of confusion matrices and saves them.
866
+ # 3) Saves passed meta_dict under self.sdir
867
+
868
+ # Example:
869
+ # meta_df_1 = pd.DataFrame(data=np.random.rand(3,3))
870
+ # meta_dict_1 = {"a": 0}
871
+ # meta_dict_2 = {"b": 0}
872
+ # meta_dict = {"meta_1.csv": meta_df_1,
873
+ # "meta_1.yaml": meta_dict_1,
874
+ # "meta_2.yaml": meta_dict_1,
875
+ # }
876
+
877
+ # """
878
+ # if meta_dict is not None:
879
+ # for k, v in meta_dict.items():
880
+ # scitex.io.save(v, self.sdir + k)
881
+
882
+ # for k in self.folds_dict.keys():
883
+
884
+ # ## pd.Series / pd.DataFrame
885
+ # if isinstance(self.folds_dict[k], pd.Series) or isinstance(
886
+ # self.folds_dict[k], pd.DataFrame
887
+ # ):
888
+ # scitex.io.save(self.folds_dict[k], self.sdir + f"{k}.csv")
889
+
890
+ # ## listed pd.DataFrame
891
+ # elif is_listed_X(self.folds_dict[k], pd.DataFrame):
892
+ # scitex.io.save(
893
+ # self.folds_dict[k],
894
+ # self.sdir + f"{k}.csv",
895
+ # # indi_suffix=self.cv_index,
896
+ # indi_suffix=self._mk_cv_index(len(self.folds_dict[k])),
897
+ # )
898
+
899
+ # ## listed figures
900
+ # elif is_listed_X(self.folds_dict[k], matplotlib.figure.Figure):
901
+ # for i_fold, fig in enumerate(self.folds_dict[k]):
902
+ # scitex.io.save(
903
+ # self.folds_dict[k][i_fold],
904
+ # self.sdir + f"{k}/fold#{i_fold}.png",
905
+ # )
906
+
907
+ # else:
908
+ # print(f"{k} was not saved")
909
+ # print(type(self.folds_dict[k]))
910
+
911
+ # if files_to_reproduce is not None:
912
+ # if isinstance(files_to_reproduce, list):
913
+ # files_to_reproduce = [files_to_reproduce]
914
+ # for f in files_to_reproduce:
915
+ # scitex.io.save(f, self.sdir)
916
+
917
+ # def plot_and_save_conf_mats(
918
+ # self,
919
+ # plt,
920
+ # extend_ratio=1.0,
921
+ # colorbar=True,
922
+ # confmat_plt_config=None,
923
+ # sci_notation_kwargs=None,
924
+ # ):
925
+ # def _inner_plot_conf_mat(
926
+ # plt,
927
+ # cm_df,
928
+ # title,
929
+ # extend_ratio=1.0,
930
+ # colorbar=True,
931
+ # sci_notation_kwargs=None,
932
+ # ):
933
+ # labels = list(cm_df.columns)
934
+ # fig_conf_mat = scitex.ml.plt.confusion_matrix(
935
+ # plt,
936
+ # cm_df.T,
937
+ # labels=labels,
938
+ # title=title,
939
+ # x_extend_ratio=extend_ratio,
940
+ # y_extend_ratio=extend_ratio,
941
+ # colorbar=colorbar,
942
+ # )
943
+
944
+ # if sci_notation_kwargs is not None:
945
+ # fig_conf_mat.axes[-1] = scitex.plt.ax_scientific_notation(
946
+ # fig_conf_mat.axes[-1], **sci_notation_kwargs
947
+ # )
948
+ # return fig_conf_mat
949
+
950
+ # ## Configures mpl
951
+ # scitex.plt.configure_mpl(
952
+ # plt,
953
+ # **confmat_plt_config,
954
+ # )
955
+
956
+ # ########################################
957
+ # ## Prepares confmats dfs
958
+ # ########################################
959
+ # ## Drops mean and std for the folds
960
+ # try:
961
+ # conf_mats = self.folds_dict["conf_mat/conf_mat"][
962
+ # -self.n_folds_intended :
963
+ # ]
964
+
965
+ # except Exception as e:
966
+ # print(e)
967
+ # conf_mats = self.folds_dict["conf_mat/conf_mat"]
968
+
969
+ # ## Prepares conf_mat_overall_sum
970
+ # conf_mat_zero = 0 * conf_mats[0].copy() # get the table format
971
+ # conf_mat_overall_sum = conf_mat_zero + np.stack(conf_mats).sum(axis=0)
972
+
973
+ # ########################################
974
+ # ## Plots & Saves
975
+ # ########################################
976
+ # # each fold's conf
977
+ # for i_fold, cm in enumerate(conf_mats):
978
+ # title = f"Test fold#{i_fold}"
979
+ # fig_conf_mat_fold = _inner_plot_conf_mat(
980
+ # plt,
981
+ # cm,
982
+ # title,
983
+ # extend_ratio=extend_ratio,
984
+ # colorbar=colorbar,
985
+ # sci_notation_kwargs=sci_notation_kwargs,
986
+ # )
987
+ # scitex.io.save(
988
+ # fig_conf_mat_fold,
989
+ # self.sdir + f"conf_mat/figs/fold#{i_fold}.png",
990
+ # )
991
+ # plt.close()
992
+
993
+ # ## overall_sum conf_mat
994
+ # title = f"{self.n_folds_intended}-CV overall sum"
995
+ # fig_conf_mat_overall_sum = _inner_plot_conf_mat(
996
+ # plt,
997
+ # conf_mat_overall_sum,
998
+ # title,
999
+ # extend_ratio=extend_ratio,
1000
+ # colorbar=colorbar,
1001
+ # sci_notation_kwargs=sci_notation_kwargs,
1002
+ # )
1003
+ # scitex.io.save(
1004
+ # fig_conf_mat_overall_sum,
1005
+ # self.sdir
1006
+ # + f"conf_mat/figs/{self.n_folds_intended}-fold_cv_overall-sum.png",
1007
+ # )
1008
+ # plt.close()
1009
+
1010
+
1011
+ # if __name__ == "__main__":
1012
+ # import random
1013
+ # import sys
1014
+
1015
+ # import scitex
1016
+ # import numpy as np
1017
+ # from catboost import CatBoostClassifier, Pool
1018
+ # from sklearn.datasets import load_digits
1019
+ # from sklearn.model_selection import StratifiedKFold
1020
+
1021
+ # ################################################################################
1022
+ # ## Sets tee
1023
+ # ################################################################################
1024
+ # sdir = scitex.io.mk_spath(
1025
+ # "./tmp/sdir-ClassificationReporter/"
1026
+ # ) # "/tmp/sdir/"
1027
+ # sys.stdout, sys.stderr = scitex.gen.tee(sys, sdir)
1028
+
1029
+ # ################################################################################
1030
+ # ## Fixes seeds
1031
+ # ################################################################################
1032
+ # fix_seeds(np=np)
1033
+
1034
+ # ## Loads
1035
+ # mnist = load_digits()
1036
+ # X, T = mnist.data, mnist.target
1037
+ # labels = mnist.target_names.astype(str)
1038
+
1039
+ # ## Main
1040
+ # skf = StratifiedKFold(n_splits=5, shuffle=True)
1041
+ # # reporter = ClassificationReporter(sdir)
1042
+ # mreporter = MultiClassificationReporter(sdir, tgts=["Test1", "Test2"])
1043
+ # for i_fold, (indi_tra, indi_tes) in enumerate(skf.split(X, T)):
1044
+ # X_tra, T_tra = X[indi_tra], T[indi_tra]
1045
+ # X_tes, T_tes = X[indi_tes], T[indi_tes]
1046
+
1047
+ # clf = CatBoostClassifier(verbose=False)
1048
+
1049
+ # clf.fit(X_tra, T_tra, verbose=False)
1050
+
1051
+ # ## Prediction
1052
+ # pred_proba_tes = clf.predict_proba(X_tes)
1053
+ # pred_cls_tes = np.argmax(pred_proba_tes, axis=1)
1054
+
1055
+ # pred_cls_tes[pred_cls_tes == 9] = 8 # override 9 as 8 # fixme
1056
+
1057
+ # ##############################
1058
+ # ## Manually adds objects to reporter to save
1059
+ # ##############################
1060
+ # ## Figure
1061
+ # fig, ax = plt.subplots()
1062
+ # ax.plot(np.arange(10))
1063
+ # # reporter.add("manu_figs", fig)
1064
+ # mreporter.add("manu_figs", fig, tgt="Test1")
1065
+ # mreporter.add("manu_figs", fig, tgt="Test2")
1066
+
1067
+ # ## DataFrame
1068
+ # df = pd.DataFrame(np.random.rand(5, 3))
1069
+ # # reporter.add("manu_dfs", df)
1070
+ # mreporter.add("manu_dfs", df, tgt="Test1")
1071
+ # mreporter.add("manu_dfs", df, tgt="Test2")
1072
+
1073
+ # ## Scalar
1074
+ # scalar = random.random()
1075
+ # # reporter.add(
1076
+ # # "manu_scalars",
1077
+ # # scalar,
1078
+ # # )
1079
+ # mreporter.add("manu_scalars", scalar, tgt="Test1")
1080
+ # mreporter.add("manu_scalars", scalar, tgt="Test2")
1081
+
1082
+ # ########################################
1083
+ # ## Metrics
1084
+ # ########################################
1085
+ # mreporter.calc_metrics(
1086
+ # T_tes,
1087
+ # pred_cls_tes,
1088
+ # pred_proba_tes,
1089
+ # labels=labels,
1090
+ # i_fold=i_fold,
1091
+ # tgt="Test1",
1092
+ # )
1093
+ # mreporter.calc_metrics(
1094
+ # T_tes,
1095
+ # pred_cls_tes,
1096
+ # pred_proba_tes,
1097
+ # labels=labels,
1098
+ # i_fold=i_fold,
1099
+ # tgt="Test2",
1100
+ # )
1101
+
1102
+ # # reporter.summarize(show=True)
1103
+ # mreporter.summarize(show=True, tgt="Test1")
1104
+ # mreporter.summarize(show=True, tgt="Test2")
1105
+
1106
+ # fake_fpaths = ["fake_file_1.txt", "fake_file_2.txt"]
1107
+ # for ff in fake_fpaths:
1108
+ # scitex.io.touch(ff)
1109
+
1110
+ # files_to_reproduce = [
1111
+ # scitex.gen.get_this_fpath(when_ipython="/dev/null"),
1112
+ # *fake_fpaths,
1113
+ # ]
1114
+ # # reporter.save(files_to_reproduce=files_to_reproduce)
1115
+ # mreporter.save(files_to_reproduce=files_to_reproduce, tgt="Test1")
1116
+ # mreporter.save(files_to_reproduce=files_to_reproduce, tgt="Test2")
1117
+
1118
+ # confmat_plt_config = dict(
1119
+ # figsize=(8, 8),
1120
+ # # labelsize=8,
1121
+ # # fontsize=6,
1122
+ # # legendfontsize=6,
1123
+ # figscale=2,
1124
+ # tick_size=0.8,
1125
+ # tick_width=0.2,
1126
+ # )
1127
+
1128
+ # sci_notation_kwargs = dict(
1129
+ # order=1,
1130
+ # fformat="%1.0d",
1131
+ # scilimits=(-3, 3),
1132
+ # x=False,
1133
+ # y=True,
1134
+ # ) # "%3.1f"
1135
+
1136
+ # # sci_notation_kwargs = None
1137
+ # # reporter.plot_and_save_conf_mats(
1138
+ # # plt,
1139
+ # # extend_ratio=1.0,
1140
+ # # confmat_plt_config=confmat_plt_config,
1141
+ # # sci_notation_kwargs=sci_notation_kwargs,
1142
+ # # )
1143
+
1144
+ # mreporter.plot_and_save_conf_mats(
1145
+ # plt,
1146
+ # extend_ratio=1.0,
1147
+ # confmat_plt_config=confmat_plt_config,
1148
+ # sci_notation_kwargs=sci_notation_kwargs,
1149
+ # tgt="Test1",
1150
+ # )
1151
+ # mreporter.plot_and_save_conf_mats(
1152
+ # plt,
1153
+ # extend_ratio=1.0,
1154
+ # confmat_plt_config=confmat_plt_config,
1155
+ # sci_notation_kwargs=sci_notation_kwargs,
1156
+ # tgt="Test2",
1157
+ # )
1158
+
1159
+ # python -m scitex.ai.ClassificationReporter
1160
+
1161
+ # EOF