scitex 2.0.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (572)
  1. scitex/__init__.py +73 -0
  2. scitex/__main__.py +89 -0
  3. scitex/__version__.py +14 -0
  4. scitex/_sh.py +59 -0
  5. scitex/ai/_LearningCurveLogger.py +583 -0
  6. scitex/ai/__Classifiers.py +101 -0
  7. scitex/ai/__init__.py +55 -0
  8. scitex/ai/_gen_ai/_Anthropic.py +173 -0
  9. scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
  10. scitex/ai/_gen_ai/_DeepSeek.py +175 -0
  11. scitex/ai/_gen_ai/_Google.py +161 -0
  12. scitex/ai/_gen_ai/_Groq.py +97 -0
  13. scitex/ai/_gen_ai/_Llama.py +142 -0
  14. scitex/ai/_gen_ai/_OpenAI.py +230 -0
  15. scitex/ai/_gen_ai/_PARAMS.py +565 -0
  16. scitex/ai/_gen_ai/_Perplexity.py +191 -0
  17. scitex/ai/_gen_ai/__init__.py +32 -0
  18. scitex/ai/_gen_ai/_calc_cost.py +78 -0
  19. scitex/ai/_gen_ai/_format_output_func.py +183 -0
  20. scitex/ai/_gen_ai/_genai_factory.py +71 -0
  21. scitex/ai/act/__init__.py +8 -0
  22. scitex/ai/act/_define.py +11 -0
  23. scitex/ai/classification/__init__.py +7 -0
  24. scitex/ai/classification/classification_reporter.py +1137 -0
  25. scitex/ai/classification/classifier_server.py +131 -0
  26. scitex/ai/classification/classifiers.py +101 -0
  27. scitex/ai/classification_reporter.py +1161 -0
  28. scitex/ai/classifier_server.py +131 -0
  29. scitex/ai/clustering/__init__.py +11 -0
  30. scitex/ai/clustering/_pca.py +115 -0
  31. scitex/ai/clustering/_umap.py +376 -0
  32. scitex/ai/early_stopping.py +149 -0
  33. scitex/ai/feature_extraction/__init__.py +56 -0
  34. scitex/ai/feature_extraction/vit.py +148 -0
  35. scitex/ai/genai/__init__.py +277 -0
  36. scitex/ai/genai/anthropic.py +177 -0
  37. scitex/ai/genai/anthropic_provider.py +320 -0
  38. scitex/ai/genai/anthropic_refactored.py +109 -0
  39. scitex/ai/genai/auth_manager.py +200 -0
  40. scitex/ai/genai/base_genai.py +336 -0
  41. scitex/ai/genai/base_provider.py +291 -0
  42. scitex/ai/genai/calc_cost.py +78 -0
  43. scitex/ai/genai/chat_history.py +307 -0
  44. scitex/ai/genai/cost_tracker.py +276 -0
  45. scitex/ai/genai/deepseek.py +188 -0
  46. scitex/ai/genai/deepseek_provider.py +251 -0
  47. scitex/ai/genai/format_output_func.py +183 -0
  48. scitex/ai/genai/genai_factory.py +71 -0
  49. scitex/ai/genai/google.py +169 -0
  50. scitex/ai/genai/google_provider.py +228 -0
  51. scitex/ai/genai/groq.py +104 -0
  52. scitex/ai/genai/groq_provider.py +248 -0
  53. scitex/ai/genai/image_processor.py +250 -0
  54. scitex/ai/genai/llama.py +155 -0
  55. scitex/ai/genai/llama_provider.py +214 -0
  56. scitex/ai/genai/mock_provider.py +127 -0
  57. scitex/ai/genai/model_registry.py +304 -0
  58. scitex/ai/genai/openai.py +230 -0
  59. scitex/ai/genai/openai_provider.py +293 -0
  60. scitex/ai/genai/params.py +565 -0
  61. scitex/ai/genai/perplexity.py +202 -0
  62. scitex/ai/genai/perplexity_provider.py +205 -0
  63. scitex/ai/genai/provider_base.py +302 -0
  64. scitex/ai/genai/provider_factory.py +370 -0
  65. scitex/ai/genai/response_handler.py +235 -0
  66. scitex/ai/layer/_Pass.py +21 -0
  67. scitex/ai/layer/__init__.py +10 -0
  68. scitex/ai/layer/_switch.py +8 -0
  69. scitex/ai/loss/_L1L2Losses.py +34 -0
  70. scitex/ai/loss/__init__.py +12 -0
  71. scitex/ai/loss/multi_task_loss.py +47 -0
  72. scitex/ai/metrics/__init__.py +9 -0
  73. scitex/ai/metrics/_bACC.py +51 -0
  74. scitex/ai/metrics/silhoute_score_block.py +496 -0
  75. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
  76. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
  77. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
  78. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
  79. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
  80. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
  81. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
  82. scitex/ai/optim/__init__.py +13 -0
  83. scitex/ai/optim/_get_set.py +31 -0
  84. scitex/ai/optim/_optimizers.py +71 -0
  85. scitex/ai/plt/__init__.py +21 -0
  86. scitex/ai/plt/_conf_mat.py +592 -0
  87. scitex/ai/plt/_learning_curve.py +194 -0
  88. scitex/ai/plt/_optuna_study.py +111 -0
  89. scitex/ai/plt/aucs/__init__.py +2 -0
  90. scitex/ai/plt/aucs/example.py +60 -0
  91. scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
  92. scitex/ai/plt/aucs/roc_auc.py +246 -0
  93. scitex/ai/sampling/undersample.py +29 -0
  94. scitex/ai/sk/__init__.py +11 -0
  95. scitex/ai/sk/_clf.py +58 -0
  96. scitex/ai/sk/_to_sktime.py +100 -0
  97. scitex/ai/sklearn/__init__.py +26 -0
  98. scitex/ai/sklearn/clf.py +58 -0
  99. scitex/ai/sklearn/to_sktime.py +100 -0
  100. scitex/ai/training/__init__.py +7 -0
  101. scitex/ai/training/early_stopping.py +150 -0
  102. scitex/ai/training/learning_curve_logger.py +555 -0
  103. scitex/ai/utils/__init__.py +22 -0
  104. scitex/ai/utils/_check_params.py +50 -0
  105. scitex/ai/utils/_default_dataset.py +46 -0
  106. scitex/ai/utils/_format_samples_for_sktime.py +26 -0
  107. scitex/ai/utils/_label_encoder.py +134 -0
  108. scitex/ai/utils/_merge_labels.py +22 -0
  109. scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
  110. scitex/ai/utils/_under_sample.py +51 -0
  111. scitex/ai/utils/_verify_n_gpus.py +16 -0
  112. scitex/ai/utils/grid_search.py +148 -0
  113. scitex/context/__init__.py +9 -0
  114. scitex/context/_suppress_output.py +38 -0
  115. scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
  116. scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
  117. scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
  118. scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
  119. scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
  120. scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
  121. scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
  122. scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
  123. scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
  124. scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
  125. scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
  126. scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
  127. scitex/db/_BaseMixins/__init__.py +30 -0
  128. scitex/db/_PostgreSQL.py +126 -0
  129. scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
  130. scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
  131. scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
  132. scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
  133. scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
  134. scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
  135. scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
  136. scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
  137. scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
  138. scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
  139. scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
  140. scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
  141. scitex/db/_PostgreSQLMixins/__init__.py +34 -0
  142. scitex/db/_SQLite3.py +2136 -0
  143. scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
  144. scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
  145. scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
  146. scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
  147. scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
  148. scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
  149. scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
  150. scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
  151. scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
  152. scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
  153. scitex/db/_SQLite3Mixins/__init__.py +30 -0
  154. scitex/db/__init__.py +14 -0
  155. scitex/db/_delete_duplicates.py +397 -0
  156. scitex/db/_inspect.py +163 -0
  157. scitex/decorators/__init__.py +54 -0
  158. scitex/decorators/_auto_order.py +172 -0
  159. scitex/decorators/_batch_fn.py +127 -0
  160. scitex/decorators/_cache_disk.py +32 -0
  161. scitex/decorators/_cache_mem.py +12 -0
  162. scitex/decorators/_combined.py +98 -0
  163. scitex/decorators/_converters.py +282 -0
  164. scitex/decorators/_deprecated.py +26 -0
  165. scitex/decorators/_not_implemented.py +30 -0
  166. scitex/decorators/_numpy_fn.py +86 -0
  167. scitex/decorators/_pandas_fn.py +121 -0
  168. scitex/decorators/_preserve_doc.py +19 -0
  169. scitex/decorators/_signal_fn.py +95 -0
  170. scitex/decorators/_timeout.py +55 -0
  171. scitex/decorators/_torch_fn.py +136 -0
  172. scitex/decorators/_wrap.py +39 -0
  173. scitex/decorators/_xarray_fn.py +88 -0
  174. scitex/dev/__init__.py +15 -0
  175. scitex/dev/_analyze_code_flow.py +284 -0
  176. scitex/dev/_reload.py +59 -0
  177. scitex/dict/_DotDict.py +442 -0
  178. scitex/dict/__init__.py +18 -0
  179. scitex/dict/_listed_dict.py +42 -0
  180. scitex/dict/_pop_keys.py +36 -0
  181. scitex/dict/_replace.py +13 -0
  182. scitex/dict/_safe_merge.py +62 -0
  183. scitex/dict/_to_str.py +32 -0
  184. scitex/dsp/__init__.py +72 -0
  185. scitex/dsp/_crop.py +122 -0
  186. scitex/dsp/_demo_sig.py +331 -0
  187. scitex/dsp/_detect_ripples.py +212 -0
  188. scitex/dsp/_ensure_3d.py +18 -0
  189. scitex/dsp/_hilbert.py +78 -0
  190. scitex/dsp/_listen.py +702 -0
  191. scitex/dsp/_misc.py +30 -0
  192. scitex/dsp/_mne.py +32 -0
  193. scitex/dsp/_modulation_index.py +79 -0
  194. scitex/dsp/_pac.py +319 -0
  195. scitex/dsp/_psd.py +102 -0
  196. scitex/dsp/_resample.py +65 -0
  197. scitex/dsp/_time.py +36 -0
  198. scitex/dsp/_transform.py +68 -0
  199. scitex/dsp/_wavelet.py +212 -0
  200. scitex/dsp/add_noise.py +111 -0
  201. scitex/dsp/example.py +253 -0
  202. scitex/dsp/filt.py +155 -0
  203. scitex/dsp/norm.py +18 -0
  204. scitex/dsp/params.py +51 -0
  205. scitex/dsp/reference.py +43 -0
  206. scitex/dsp/template.py +25 -0
  207. scitex/dsp/utils/__init__.py +15 -0
  208. scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
  209. scitex/dsp/utils/_ensure_3d.py +18 -0
  210. scitex/dsp/utils/_ensure_even_len.py +10 -0
  211. scitex/dsp/utils/_zero_pad.py +48 -0
  212. scitex/dsp/utils/filter.py +408 -0
  213. scitex/dsp/utils/pac.py +177 -0
  214. scitex/dt/__init__.py +8 -0
  215. scitex/dt/_linspace.py +130 -0
  216. scitex/etc/__init__.py +15 -0
  217. scitex/etc/wait_key.py +34 -0
  218. scitex/gen/_DimHandler.py +196 -0
  219. scitex/gen/_TimeStamper.py +244 -0
  220. scitex/gen/__init__.py +95 -0
  221. scitex/gen/_alternate_kwarg.py +13 -0
  222. scitex/gen/_cache.py +11 -0
  223. scitex/gen/_check_host.py +34 -0
  224. scitex/gen/_ci.py +12 -0
  225. scitex/gen/_close.py +222 -0
  226. scitex/gen/_embed.py +78 -0
  227. scitex/gen/_inspect_module.py +257 -0
  228. scitex/gen/_is_ipython.py +12 -0
  229. scitex/gen/_less.py +48 -0
  230. scitex/gen/_list_packages.py +139 -0
  231. scitex/gen/_mat2py.py +88 -0
  232. scitex/gen/_norm.py +170 -0
  233. scitex/gen/_paste.py +18 -0
  234. scitex/gen/_print_config.py +84 -0
  235. scitex/gen/_shell.py +48 -0
  236. scitex/gen/_src.py +111 -0
  237. scitex/gen/_start.py +451 -0
  238. scitex/gen/_symlink.py +55 -0
  239. scitex/gen/_symlog.py +27 -0
  240. scitex/gen/_tee.py +238 -0
  241. scitex/gen/_title2path.py +60 -0
  242. scitex/gen/_title_case.py +88 -0
  243. scitex/gen/_to_even.py +84 -0
  244. scitex/gen/_to_odd.py +34 -0
  245. scitex/gen/_to_rank.py +39 -0
  246. scitex/gen/_transpose.py +37 -0
  247. scitex/gen/_type.py +78 -0
  248. scitex/gen/_var_info.py +73 -0
  249. scitex/gen/_wrap.py +17 -0
  250. scitex/gen/_xml2dict.py +76 -0
  251. scitex/gen/misc.py +730 -0
  252. scitex/gen/path.py +0 -0
  253. scitex/general/__init__.py +5 -0
  254. scitex/gists/_SigMacro_processFigure_S.py +128 -0
  255. scitex/gists/_SigMacro_toBlue.py +172 -0
  256. scitex/gists/__init__.py +12 -0
  257. scitex/io/_H5Explorer.py +292 -0
  258. scitex/io/__init__.py +82 -0
  259. scitex/io/_cache.py +101 -0
  260. scitex/io/_flush.py +24 -0
  261. scitex/io/_glob.py +103 -0
  262. scitex/io/_json2md.py +113 -0
  263. scitex/io/_load.py +168 -0
  264. scitex/io/_load_configs.py +146 -0
  265. scitex/io/_load_modules/__init__.py +38 -0
  266. scitex/io/_load_modules/_catboost.py +66 -0
  267. scitex/io/_load_modules/_con.py +20 -0
  268. scitex/io/_load_modules/_db.py +24 -0
  269. scitex/io/_load_modules/_docx.py +42 -0
  270. scitex/io/_load_modules/_eeg.py +110 -0
  271. scitex/io/_load_modules/_hdf5.py +196 -0
  272. scitex/io/_load_modules/_image.py +19 -0
  273. scitex/io/_load_modules/_joblib.py +19 -0
  274. scitex/io/_load_modules/_json.py +18 -0
  275. scitex/io/_load_modules/_markdown.py +103 -0
  276. scitex/io/_load_modules/_matlab.py +37 -0
  277. scitex/io/_load_modules/_numpy.py +39 -0
  278. scitex/io/_load_modules/_optuna.py +155 -0
  279. scitex/io/_load_modules/_pandas.py +69 -0
  280. scitex/io/_load_modules/_pdf.py +31 -0
  281. scitex/io/_load_modules/_pickle.py +24 -0
  282. scitex/io/_load_modules/_torch.py +16 -0
  283. scitex/io/_load_modules/_txt.py +126 -0
  284. scitex/io/_load_modules/_xml.py +49 -0
  285. scitex/io/_load_modules/_yaml.py +23 -0
  286. scitex/io/_mv_to_tmp.py +19 -0
  287. scitex/io/_path.py +286 -0
  288. scitex/io/_reload.py +78 -0
  289. scitex/io/_save.py +539 -0
  290. scitex/io/_save_modules/__init__.py +66 -0
  291. scitex/io/_save_modules/_catboost.py +22 -0
  292. scitex/io/_save_modules/_csv.py +89 -0
  293. scitex/io/_save_modules/_excel.py +49 -0
  294. scitex/io/_save_modules/_hdf5.py +249 -0
  295. scitex/io/_save_modules/_html.py +48 -0
  296. scitex/io/_save_modules/_image.py +140 -0
  297. scitex/io/_save_modules/_joblib.py +25 -0
  298. scitex/io/_save_modules/_json.py +25 -0
  299. scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
  300. scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
  301. scitex/io/_save_modules/_matlab.py +24 -0
  302. scitex/io/_save_modules/_mp4.py +29 -0
  303. scitex/io/_save_modules/_numpy.py +57 -0
  304. scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
  305. scitex/io/_save_modules/_pickle.py +45 -0
  306. scitex/io/_save_modules/_plotly.py +27 -0
  307. scitex/io/_save_modules/_text.py +23 -0
  308. scitex/io/_save_modules/_torch.py +26 -0
  309. scitex/io/_save_modules/_yaml.py +29 -0
  310. scitex/life/__init__.py +10 -0
  311. scitex/life/_monitor_rain.py +49 -0
  312. scitex/linalg/__init__.py +17 -0
  313. scitex/linalg/_distance.py +63 -0
  314. scitex/linalg/_geometric_median.py +64 -0
  315. scitex/linalg/_misc.py +73 -0
  316. scitex/nn/_AxiswiseDropout.py +27 -0
  317. scitex/nn/_BNet.py +126 -0
  318. scitex/nn/_BNet_Res.py +164 -0
  319. scitex/nn/_ChannelGainChanger.py +44 -0
  320. scitex/nn/_DropoutChannels.py +50 -0
  321. scitex/nn/_Filters.py +489 -0
  322. scitex/nn/_FreqGainChanger.py +110 -0
  323. scitex/nn/_GaussianFilter.py +48 -0
  324. scitex/nn/_Hilbert.py +111 -0
  325. scitex/nn/_MNet_1000.py +157 -0
  326. scitex/nn/_ModulationIndex.py +221 -0
  327. scitex/nn/_PAC.py +414 -0
  328. scitex/nn/_PSD.py +40 -0
  329. scitex/nn/_ResNet1D.py +120 -0
  330. scitex/nn/_SpatialAttention.py +25 -0
  331. scitex/nn/_Spectrogram.py +161 -0
  332. scitex/nn/_SwapChannels.py +50 -0
  333. scitex/nn/_TransposeLayer.py +19 -0
  334. scitex/nn/_Wavelet.py +183 -0
  335. scitex/nn/__init__.py +63 -0
  336. scitex/os/__init__.py +8 -0
  337. scitex/os/_mv.py +50 -0
  338. scitex/parallel/__init__.py +8 -0
  339. scitex/parallel/_run.py +151 -0
  340. scitex/path/__init__.py +33 -0
  341. scitex/path/_clean.py +52 -0
  342. scitex/path/_find.py +108 -0
  343. scitex/path/_get_module_path.py +51 -0
  344. scitex/path/_get_spath.py +35 -0
  345. scitex/path/_getsize.py +18 -0
  346. scitex/path/_increment_version.py +87 -0
  347. scitex/path/_mk_spath.py +51 -0
  348. scitex/path/_path.py +19 -0
  349. scitex/path/_split.py +23 -0
  350. scitex/path/_this_path.py +19 -0
  351. scitex/path/_version.py +101 -0
  352. scitex/pd/__init__.py +41 -0
  353. scitex/pd/_find_indi.py +126 -0
  354. scitex/pd/_find_pval.py +113 -0
  355. scitex/pd/_force_df.py +154 -0
  356. scitex/pd/_from_xyz.py +71 -0
  357. scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
  358. scitex/pd/_melt_cols.py +81 -0
  359. scitex/pd/_merge_columns.py +221 -0
  360. scitex/pd/_mv.py +63 -0
  361. scitex/pd/_replace.py +62 -0
  362. scitex/pd/_round.py +93 -0
  363. scitex/pd/_slice.py +63 -0
  364. scitex/pd/_sort.py +91 -0
  365. scitex/pd/_to_numeric.py +53 -0
  366. scitex/pd/_to_xy.py +59 -0
  367. scitex/pd/_to_xyz.py +110 -0
  368. scitex/plt/__init__.py +36 -0
  369. scitex/plt/_subplots/_AxesWrapper.py +182 -0
  370. scitex/plt/_subplots/_AxisWrapper.py +249 -0
  371. scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
  372. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
  373. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
  374. scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
  375. scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
  376. scitex/plt/_subplots/_FigWrapper.py +226 -0
  377. scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
  378. scitex/plt/_subplots/__init__.py +111 -0
  379. scitex/plt/_subplots/_export_as_csv.py +232 -0
  380. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
  381. scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
  382. scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
  383. scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
  384. scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
  385. scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
  386. scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
  387. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
  388. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
  389. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
  390. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
  391. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
  392. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
  393. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
  394. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
  395. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
  396. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
  397. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
  398. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
  399. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
  400. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
  401. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
  402. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
  403. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
  404. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
  405. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
  406. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
  407. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
  408. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
  409. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
  410. scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
  411. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
  412. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
  413. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
  414. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
  415. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
  416. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
  417. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
  418. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
  419. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
  420. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
  421. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
  422. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
  423. scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
  424. scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
  425. scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
  426. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
  427. scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
  428. scitex/plt/_tpl.py +28 -0
  429. scitex/plt/ax/__init__.py +114 -0
  430. scitex/plt/ax/_plot/__init__.py +53 -0
  431. scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
  432. scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
  433. scitex/plt/ax/_plot/_plot_cube.py +57 -0
  434. scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
  435. scitex/plt/ax/_plot/_plot_fillv.py +55 -0
  436. scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
  437. scitex/plt/ax/_plot/_plot_image.py +94 -0
  438. scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
  439. scitex/plt/ax/_plot/_plot_raster.py +172 -0
  440. scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
  441. scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
  442. scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
  443. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
  444. scitex/plt/ax/_plot/_plot_violin.py +343 -0
  445. scitex/plt/ax/_style/__init__.py +38 -0
  446. scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
  447. scitex/plt/ax/_style/_add_panel.py +92 -0
  448. scitex/plt/ax/_style/_extend.py +64 -0
  449. scitex/plt/ax/_style/_force_aspect.py +37 -0
  450. scitex/plt/ax/_style/_format_label.py +23 -0
  451. scitex/plt/ax/_style/_hide_spines.py +84 -0
  452. scitex/plt/ax/_style/_map_ticks.py +182 -0
  453. scitex/plt/ax/_style/_rotate_labels.py +215 -0
  454. scitex/plt/ax/_style/_sci_note.py +279 -0
  455. scitex/plt/ax/_style/_set_log_scale.py +299 -0
  456. scitex/plt/ax/_style/_set_meta.py +261 -0
  457. scitex/plt/ax/_style/_set_n_ticks.py +37 -0
  458. scitex/plt/ax/_style/_set_size.py +16 -0
  459. scitex/plt/ax/_style/_set_supxyt.py +116 -0
  460. scitex/plt/ax/_style/_set_ticks.py +276 -0
  461. scitex/plt/ax/_style/_set_xyt.py +121 -0
  462. scitex/plt/ax/_style/_share_axes.py +264 -0
  463. scitex/plt/ax/_style/_shift.py +139 -0
  464. scitex/plt/ax/_style/_show_spines.py +333 -0
  465. scitex/plt/color/_PARAMS.py +70 -0
  466. scitex/plt/color/__init__.py +52 -0
  467. scitex/plt/color/_add_hue_col.py +41 -0
  468. scitex/plt/color/_colors.py +205 -0
  469. scitex/plt/color/_get_colors_from_cmap.py +134 -0
  470. scitex/plt/color/_interpolate.py +29 -0
  471. scitex/plt/color/_vizualize_colors.py +54 -0
  472. scitex/plt/utils/__init__.py +44 -0
  473. scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
  474. scitex/plt/utils/_calc_nice_ticks.py +101 -0
  475. scitex/plt/utils/_close.py +68 -0
  476. scitex/plt/utils/_colorbar.py +96 -0
  477. scitex/plt/utils/_configure_mpl.py +295 -0
  478. scitex/plt/utils/_histogram_utils.py +132 -0
  479. scitex/plt/utils/_im2grid.py +70 -0
  480. scitex/plt/utils/_is_valid_axis.py +78 -0
  481. scitex/plt/utils/_mk_colorbar.py +65 -0
  482. scitex/plt/utils/_mk_patches.py +26 -0
  483. scitex/plt/utils/_scientific_captions.py +638 -0
  484. scitex/plt/utils/_scitex_config.py +223 -0
  485. scitex/reproduce/__init__.py +14 -0
  486. scitex/reproduce/_fix_seeds.py +45 -0
  487. scitex/reproduce/_gen_ID.py +55 -0
  488. scitex/reproduce/_gen_timestamp.py +35 -0
  489. scitex/res/__init__.py +5 -0
  490. scitex/resource/__init__.py +13 -0
  491. scitex/resource/_get_processor_usages.py +281 -0
  492. scitex/resource/_get_specs.py +280 -0
  493. scitex/resource/_log_processor_usages.py +190 -0
  494. scitex/resource/_utils/__init__.py +31 -0
  495. scitex/resource/_utils/_get_env_info.py +481 -0
  496. scitex/resource/limit_ram.py +33 -0
  497. scitex/scholar/__init__.py +24 -0
  498. scitex/scholar/_local_search.py +454 -0
  499. scitex/scholar/_paper.py +244 -0
  500. scitex/scholar/_pdf_downloader.py +325 -0
  501. scitex/scholar/_search.py +393 -0
  502. scitex/scholar/_vector_search.py +370 -0
  503. scitex/scholar/_web_sources.py +457 -0
  504. scitex/stats/__init__.py +31 -0
  505. scitex/stats/_calc_partial_corr.py +17 -0
  506. scitex/stats/_corr_test_multi.py +94 -0
  507. scitex/stats/_corr_test_wrapper.py +115 -0
  508. scitex/stats/_describe_wrapper.py +90 -0
  509. scitex/stats/_multiple_corrections.py +63 -0
  510. scitex/stats/_nan_stats.py +93 -0
  511. scitex/stats/_p2stars.py +116 -0
  512. scitex/stats/_p2stars_wrapper.py +56 -0
  513. scitex/stats/_statistical_tests.py +73 -0
  514. scitex/stats/desc/__init__.py +40 -0
  515. scitex/stats/desc/_describe.py +189 -0
  516. scitex/stats/desc/_nan.py +289 -0
  517. scitex/stats/desc/_real.py +94 -0
  518. scitex/stats/multiple/__init__.py +14 -0
  519. scitex/stats/multiple/_bonferroni_correction.py +72 -0
  520. scitex/stats/multiple/_fdr_correction.py +400 -0
  521. scitex/stats/multiple/_multicompair.py +28 -0
  522. scitex/stats/tests/__corr_test.py +277 -0
  523. scitex/stats/tests/__corr_test_multi.py +343 -0
  524. scitex/stats/tests/__corr_test_single.py +277 -0
  525. scitex/stats/tests/__init__.py +22 -0
  526. scitex/stats/tests/_brunner_munzel_test.py +192 -0
  527. scitex/stats/tests/_nocorrelation_test.py +28 -0
  528. scitex/stats/tests/_smirnov_grubbs.py +98 -0
  529. scitex/str/__init__.py +113 -0
  530. scitex/str/_clean_path.py +75 -0
  531. scitex/str/_color_text.py +52 -0
  532. scitex/str/_decapitalize.py +58 -0
  533. scitex/str/_factor_out_digits.py +281 -0
  534. scitex/str/_format_plot_text.py +498 -0
  535. scitex/str/_grep.py +48 -0
  536. scitex/str/_latex.py +155 -0
  537. scitex/str/_latex_fallback.py +471 -0
  538. scitex/str/_mask_api.py +39 -0
  539. scitex/str/_mask_api_key.py +8 -0
  540. scitex/str/_parse.py +158 -0
  541. scitex/str/_print_block.py +47 -0
  542. scitex/str/_print_debug.py +68 -0
  543. scitex/str/_printc.py +62 -0
  544. scitex/str/_readable_bytes.py +38 -0
  545. scitex/str/_remove_ansi.py +23 -0
  546. scitex/str/_replace.py +134 -0
  547. scitex/str/_search.py +125 -0
  548. scitex/str/_squeeze_space.py +36 -0
  549. scitex/tex/__init__.py +10 -0
  550. scitex/tex/_preview.py +103 -0
  551. scitex/tex/_to_vec.py +116 -0
  552. scitex/torch/__init__.py +18 -0
  553. scitex/torch/_apply_to.py +34 -0
  554. scitex/torch/_nan_funcs.py +77 -0
  555. scitex/types/_ArrayLike.py +44 -0
  556. scitex/types/_ColorLike.py +21 -0
  557. scitex/types/__init__.py +14 -0
  558. scitex/types/_is_listed_X.py +70 -0
  559. scitex/utils/__init__.py +22 -0
  560. scitex/utils/_compress_hdf5.py +116 -0
  561. scitex/utils/_email.py +120 -0
  562. scitex/utils/_grid.py +148 -0
  563. scitex/utils/_notify.py +247 -0
  564. scitex/utils/_search.py +121 -0
  565. scitex/web/__init__.py +38 -0
  566. scitex/web/_search_pubmed.py +438 -0
  567. scitex/web/_summarize_url.py +158 -0
  568. scitex-2.0.0.dist-info/METADATA +307 -0
  569. scitex-2.0.0.dist-info/RECORD +572 -0
  570. scitex-2.0.0.dist-info/WHEEL +6 -0
  571. scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
  572. scitex-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1137 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Timestamp: "2025-02-15 01:38:28 (ywatanabe)"
4
+ # File: ./src/scitex/ai/ClassificationReporter.py
5
+
6
+ THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/ClassificationReporter.py"
7
+
8
+ import os as _os
9
+ import random as _random
10
+ import sys as _sys
11
+ from collections import defaultdict as _defaultdict
12
+ from glob import glob as _glob
13
+ from pprint import pprint as _pprint
14
+
15
+ import matplotlib as _matplotlib
16
+ import matplotlib.pyplot as _plt
17
+ import scitex as _scitex
18
+ import numpy as _np
19
+ import pandas as _pd
20
+ import torch as _torch
21
+ from sklearn.metrics import (
22
+ balanced_accuracy_score as _balanced_accuracy_score,
23
+ classification_report as _classification_report,
24
+ confusion_matrix as _confusion_matrix,
25
+ matthews_corrcoef as _matthews_corrcoef,
26
+ )
27
+
28
+ # Lazy import to avoid circular dependency
29
def _get_fix_seeds():
    """Return ``scitex.reproduce.fix_seeds``, importing ``scitex`` at call time.

    The import lives inside the function so it is executed only when the
    seed-fixing helper is actually requested, not while this module is being
    imported (the surrounding file describes this as avoiding a circular
    dependency).
    """
    import scitex as _pkg

    return _pkg.reproduce.fix_seeds
32
+
33
+ _fix_seeds = None # Will be loaded when needed
34
+
35
+
36
class MultiClassificationReporter(object):
    """Fan reporting calls out to one ``ClassificationReporter`` per target.

    Each target name in ``tgts`` gets its own reporter whose output directory
    is ``sdir + tgt + "/"``.  Every public method takes a ``tgt`` keyword that
    selects which reporter receives the call; all other arguments are passed
    through unchanged.

    When ``tgts`` is ``None`` a single reporter rooted at ``sdir`` is created
    and registered under the key ``None``, so the ``tgt=None`` default of every
    method resolves to it.

    NOTE(review): the delegating methods require the per-target reporter to
    provide ``add``, ``calc_metrics``, ``summarize``, ``save`` and
    ``plot_and_save_conf_mats``; confirm the reporter class implements all of
    them.
    """

    # Plot styling used by ``calc_metrics`` when the caller supplies none.
    # Kept at class level and copied per call so no caller can mutate a
    # default shared between calls.
    _DEFAULT_AUC_PLT_CONFIG = dict(
        figsize=(7, 7),
        labelsize=8,
        fontsize=7,
        legendfontsize=6,
        tick_size=0.8,
        tick_width=0.2,
    )

    def __init__(self, sdir, tgts=None):
        """Create one reporter per target.

        Args:
            sdir: Base output directory.  It is concatenated with the target
                name as a plain string, so it should end with a path
                separator.
            tgts: Iterable of target names (strings), or ``None`` for a
                single, un-named target.
        """
        self.tgt2id, sdirs = self._resolve_targets(sdir, tgts)
        self.reporters = [ClassificationReporter(tgt_sdir) for tgt_sdir in sdirs]

    @staticmethod
    def _resolve_targets(sdir, tgts):
        """Return ``(tgt2id, sdirs)``: target -> reporter index, and output dirs."""
        if tgts is None:
            # Previously this path raised TypeError (``enumerate(None)``) and
            # never registered a ``None`` key, so ``tgt=None`` could not work.
            return {None: 0}, [sdir]

        tgts = list(tgts)  # allow one-shot iterables; we iterate twice
        tgt2id = {tgt: i_tgt for i_tgt, tgt in enumerate(tgts)}
        sdirs = [sdir + tgt + "/" for tgt in tgts]
        return tgt2id, sdirs

    def _reporter_for(self, tgt):
        """Return the reporter registered for ``tgt`` (``KeyError`` if unknown)."""
        return self.reporters[self.tgt2id[tgt]]

    def add(self, obj_name, obj, tgt=None):
        """Forward ``add(obj_name, obj)`` to the reporter selected by ``tgt``."""
        self._reporter_for(tgt).add(obj_name, obj)

    def calc_metrics(
        self,
        true_class,
        pred_class,
        pred_proba,
        labels=None,
        i_fold=None,
        show=True,
        auc_plt_config=None,
        tgt=None,
    ):
        """Forward ``calc_metrics`` to the reporter selected by ``tgt``.

        ``auc_plt_config`` defaults to a copy of ``_DEFAULT_AUC_PLT_CONFIG``
        (figsize, labelsize, fontsize, legendfontsize, tick_size, tick_width).
        """
        if auc_plt_config is None:
            auc_plt_config = dict(self._DEFAULT_AUC_PLT_CONFIG)

        self._reporter_for(tgt).calc_metrics(
            true_class,
            pred_class,
            pred_proba,
            labels=labels,
            i_fold=i_fold,
            show=show,
            auc_plt_config=auc_plt_config,
        )

    def summarize(
        self,
        n_round=3,
        show=False,
        tgt=None,
    ):
        """Forward ``summarize`` to the reporter selected by ``tgt``."""
        self._reporter_for(tgt).summarize(
            n_round=n_round,
            show=show,
        )

    def save(
        self,
        files_to_reproduce=None,
        meta_dict=None,
        tgt=None,
    ):
        """Forward ``save`` to the reporter selected by ``tgt``."""
        self._reporter_for(tgt).save(
            files_to_reproduce=files_to_reproduce,
            meta_dict=meta_dict,
        )

    def plot_and_save_conf_mats(
        self,
        plt,
        extend_ratio=1.0,
        colorbar=True,
        confmat_plt_config=None,
        sci_notation_kwargs=None,
        tgt=None,
    ):
        """Forward ``plot_and_save_conf_mats`` to the reporter selected by ``tgt``."""
        self._reporter_for(tgt).plot_and_save_conf_mats(
            plt,
            extend_ratio=extend_ratio,
            colorbar=colorbar,
            confmat_plt_config=confmat_plt_config,
            sci_notation_kwargs=sci_notation_kwargs,
        )
121
+
122
+
123
class ClassificationReporter(object):
    """Accumulate per-fold classification metrics and curve figures.

    Results are appended fold by fold to ``self.folds_dict``, a
    ``defaultdict(list)`` keyed by object name.  The methods below cover:

    - Balanced Accuracy
    - MCC
    - Confusion Matrix
    - Classification Report
    - ROC AUC score / curve
    - PRE-REC curve

    ``sdir`` is kept as the intended output directory; the methods defined in
    this class only store it.

    NOTE(review): ``calc_AUCs`` dispatches to ``self._calc_AUCs_multiple`` for
    more than two classes, but no such method is defined in this class, so the
    multi-class path currently raises ``AttributeError``.
    """

    # Styling for the ROC / precision-recall figures.  Copied per call (see
    # ``_resolve_auc_plt_config``) so callers never share a mutable default.
    _DEFAULT_AUC_PLT_CONFIG = dict(
        figsize=(7, 7),
        labelsize=8,
        fontsize=7,
        legendfontsize=6,
        tick_size=0.8,
        tick_width=0.2,
    )

    def __init__(self, sdir):
        """Store ``sdir``, create the empty fold store and fix random seeds."""
        self.sdir = sdir
        self.folds_dict = _defaultdict(list)
        # ``_fix_seeds`` is a module-level placeholder; fall back to the lazy
        # loader whenever it has not been populated.
        fix_seeds_func = _fix_seeds or _get_fix_seeds()
        fix_seeds_func(os=_os, random=_random, np=_np, torch=_torch, verbose=False)

    @classmethod
    def _resolve_auc_plt_config(cls, auc_plt_config):
        """Return a fresh config dict: class defaults overlaid with caller keys."""
        resolved = dict(cls._DEFAULT_AUC_PLT_CONFIG)
        if auc_plt_config is not None:
            resolved.update(auc_plt_config)
        return resolved

    def add(
        self,
        obj_name,
        obj,
    ):
        """Append an arbitrary object to ``self.folds_dict[obj_name]``.

        Example:
            ## fig
            fig, ax = plt.subplots()
            ax.plot(np.random.rand(10))
            reporter.add("manu_figs", fig)

            ## DataFrame
            df = pd.DataFrame(np.random.rand(5, 3))
            reporter.add("manu_dfs", df)

            ## scalar
            scalar = random.random()
            reporter.add("manu_scalars", scalar)
        """
        assert isinstance(obj_name, str)
        self.folds_dict[obj_name].append(obj)

    @staticmethod
    def calc_bACC(true_class, pred_class, i_fold, show=False):
        """Return the balanced accuracy; print it for fold ``i_fold`` if ``show``."""
        balanced_acc = _balanced_accuracy_score(true_class, pred_class)
        if show:
            print(f"\nBalanced ACC in fold#{i_fold} was {balanced_acc:.3f}\n")
        return balanced_acc

    @staticmethod
    def calc_mcc(true_class, pred_class, i_fold, show=False):
        """Return the Matthews correlation coefficient as a Python ``float``."""
        mcc = float(_matthews_corrcoef(true_class, pred_class))
        if show:
            print(f"\nMCC in fold#{i_fold} was {mcc:.3f}\n")
        return mcc

    @staticmethod
    def calc_conf_mat(true_class, pred_class, labels, i_fold, show=False):
        """Return the confusion matrix as a DataFrame labelled with ``labels``.

        Classes are looked up as the integers ``0 .. len(labels) - 1``, so
        ``true_class`` / ``pred_class`` must hold integer class ids and
        ``labels[i]`` is the display name of class ``i``.  Rows are true
        classes, columns are predicted classes.
        """
        conf_mat = _pd.DataFrame(
            data=_confusion_matrix(
                true_class, pred_class, labels=_np.arange(len(labels))
            ),
            columns=labels,
        ).set_index(_pd.Series(list(labels)))

        if show:
            print(f"\nConfusion Matrix in fold#{i_fold}: \n")
            _pprint(conf_mat)
            print()

        return conf_mat

    @staticmethod
    def calc_clf_report(
        true_class, pred_class, labels, balanced_acc, i_fold, show=False
    ):
        """Return sklearn's classification report as a rounded DataFrame.

        The "accuracy" column is overwritten with ``balanced_acc`` and renamed
        to "balanced accuracy"; the "support" row is renamed to "sample size".
        Columns are ordered: one per label, then "balanced accuracy",
        "macro avg", "weighted avg".
        """
        # A list (not a tuple/array) so that DataFrame indexing below always
        # selects several columns instead of treating it as one key.
        label_names = list(labels)

        clf_report = _pd.DataFrame(
            _classification_report(
                true_class,
                pred_class,
                labels=_np.arange(len(label_names)),
                target_names=label_names,
                output_dict=True,
            )
        )

        # ACC -> balanced ACC
        clf_report["accuracy"] = balanced_acc
        clf_report = _pd.concat(
            [
                clf_report[label_names],
                clf_report[["accuracy", "macro avg", "weighted avg"]],
            ],
            axis=1,
        )
        clf_report = clf_report.rename(columns={"accuracy": "balanced accuracy"})
        clf_report = clf_report.round(3)
        clf_report = clf_report.rename(index={"support": "sample size"})
        clf_report.index.name = None

        if show:
            print(f"\nClassification Report for fold#{i_fold}:\n")
            _pprint(clf_report)
            print()
        return clf_report

    def calc_AUCs(
        self,
        true_class,
        pred_proba,
        labels,
        i_fold,
        show=True,
        auc_plt_config=None,
    ):
        """Compute the ROC AUC and store ROC / PRE-REC figures for one fold.

        Args:
            true_class: True class ids; must contain every one of the
                ``len(labels)`` classes.
            pred_proba: Predicted scores / probabilities.
            labels: Class display names; its length selects the binary or
                multi-class code path.
            i_fold: Fold identifier used in printed messages.
            show: Whether to print the score.
            auc_plt_config: Optional dict overriding any of the keys of
                ``_DEFAULT_AUC_PLT_CONFIG``.

        Returns:
            The ROC AUC returned by the selected code path.
        """
        auc_plt_config = self._resolve_auc_plt_config(auc_plt_config)

        n_classes = len(labels)
        assert (
            len(_np.unique(true_class)) == n_classes
        ), "true_class must contain every class listed in labels"

        if n_classes == 2:
            return self._calc_AUCs_binary(
                true_class,
                pred_proba,
                i_fold,
                show=show,
                auc_plt_config=auc_plt_config,
            )

        # NOTE(review): `_calc_AUCs_multiple` is not defined in this class.
        return self._calc_AUCs_multiple(
            true_class,
            pred_proba,
            labels,
            i_fold,
            show=show,
            auc_plt_config=auc_plt_config,
        )

    @staticmethod
    def _style_curve_axes(ax, xlabel, ylabel, title, cfg):
        """Apply the shared label / title / legend / tick styling to ``ax``."""
        ax.set_xlabel(xlabel, fontsize=cfg["labelsize"])
        ax.set_ylabel(ylabel, fontsize=cfg["labelsize"])
        ax.set_title(title, fontsize=cfg["fontsize"])
        ax.legend(fontsize=cfg["legendfontsize"])
        # `tick_size` is passed as the tick-mark length.  It used to be passed
        # as `labelsize`, which set the tick-label font to 0.8 pt and made the
        # labels unreadable.  NOTE(review): confirm the intended unit.
        ax.tick_params(
            axis="both",
            which="major",
            length=cfg["tick_size"],
            width=cfg["tick_width"],
        )

    def _calc_AUCs_binary(
        self,
        true_class,
        pred_proba,
        i_fold,
        show=False,
        auc_plt_config=None,
    ):
        """Compute the ROC AUC for a binary problem and store both curves.

        The ROC figure is appended to ``self.folds_dict["ROC_fig"]`` and the
        precision-recall figure to ``self.folds_dict["PRE_REC_fig"]``.

        Args:
            true_class: True class ids with exactly two distinct values.
            pred_proba: Either 1-D scores for the positive class, or a 2-D
                array with one column per class, in which case the last
                column is used.
            i_fold: Fold identifier used in the printed message.
            show: Whether to print the ROC AUC.
            auc_plt_config: Optional dict overriding any of the keys of
                ``_DEFAULT_AUC_PLT_CONFIG``.

        Returns:
            The area under the ROC curve.
        """
        from sklearn.metrics import (
            PrecisionRecallDisplay,
            RocCurveDisplay,
            auc,
            roc_curve,
        )

        cfg = self._resolve_auc_plt_config(auc_plt_config)

        n_classes = len(_np.unique(true_class))
        assert n_classes == 2, "This method is only for binary classification"

        # roc_curve needs one score per sample; reduce a per-class probability
        # matrix to the positive-class (last) column.
        scores = _np.asarray(pred_proba)
        if scores.ndim == 2:
            scores = scores[:, -1]

        # ROC curve
        fpr, tpr, _ = roc_curve(true_class, scores)
        roc_auc = auc(fpr, tpr)

        fig_roc, ax_roc = _plt.subplots(figsize=cfg["figsize"])
        RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot(ax=ax_roc)
        ax_roc.plot([0, 1], [0, 1], "k:")  # chance-level diagonal
        self._style_curve_axes(
            ax_roc, "False Positive Rate", "True Positive Rate", "ROC Curve", cfg
        )
        self.folds_dict["ROC_fig"].append(fig_roc)
        if show:
            print(f"\nROC AUC in fold#{i_fold} is {roc_auc:.3f}\n")

        # PRE-REC curve
        fig_prerec, ax_prerec = _plt.subplots(figsize=cfg["figsize"])
        PrecisionRecallDisplay.from_predictions(
            true_class,
            scores,
            ax=ax_prerec,
        )
        self._style_curve_axes(
            ax_prerec, "Recall", "Precision", "Precision-Recall Curve", cfg
        )
        self.folds_dict["PRE_REC_fig"].append(fig_prerec)

        return roc_auc
353
+
354
+
355
+ # #!/usr/bin/env python3
356
+ # # -*- coding: utf-8 -*-
357
+ # # Time-stamp: "2024-11-20 00:15:08 (ywatanabe)"
358
+ # # File: ./scitex_repo/src/scitex/ai/ClassificationReporter.py
359
+
360
+ # THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/ClassificationReporter.py"
361
+
362
+ # #!/usr/bin/env python3
363
+ # # -*- coding: utf-8 -*-
364
+ # # Time-stamp: "2024-11-13 12:54:17 (ywatanabe)"
365
+ # # File: ./scitex_repo/src/scitex/ai/ClassificationReporter.py
366
+
367
+ # import os
368
+ # import random
369
+ # import sys
370
+ # from collections import defaultdict as _defaultdict
371
+ # from glob import glob as _glob
372
+ # from pprint import pprint as _pprint
373
+
374
+ # import matplotlib
375
+ # import matplotlib.pyplot as plt
376
+ # import scitex
377
+ # import numpy as np
378
+ # import pandas as pd
379
+ # import torch
380
+ # from sklearn.metrics import (
381
+ # balanced_accuracy_score,
382
+ # classification_report,
383
+ # confusion_matrix,
384
+ # matthews_corrcoef,
385
+ # )
386
+
387
+ # from ..reproduce import fix_seeds
388
+
389
+
390
+ # class MultiClassificationReporter(object):
391
+ # def __init__(self, sdir, tgts=None):
392
+ # if tgts is None:
393
+ # sdirs = [""]
394
+ # else:
395
+ # sdirs = [os.path.join(sdir, tgt, "/") for tgt in tgts]
396
+ # sdirs = [sdir + tgt + "/" for tgt in tgts]
397
+
398
+ # self.tgt2id = {tgt: i_tgt for i_tgt, tgt in enumerate(tgts)}
399
+ # self.reporters = [ClassificationReporter(sdir) for sdir in sdirs]
400
+
401
+ # def add(self, obj_name, obj, tgt=None):
402
+ # i_tgt = self.tgt2id[tgt]
403
+ # self.reporters[i_tgt].add(obj_name, obj)
404
+
405
+ # def calc_metrics(
406
+ # self,
407
+ # true_class,
408
+ # pred_class,
409
+ # pred_proba,
410
+ # labels=None,
411
+ # i_fold=None,
412
+ # show=True,
413
+ # auc_plt_config=dict(
414
+ # figsize=(7, 7),
415
+ # labelsize=8,
416
+ # fontsize=7,
417
+ # legendfontsize=6,
418
+ # tick_size=0.8,
419
+ # tick_width=0.2,
420
+ # ),
421
+ # tgt=None,
422
+ # ):
423
+ # i_tgt = self.tgt2id[tgt]
424
+ # self.reporters[i_tgt].calc_metrics(
425
+ # true_class,
426
+ # pred_class,
427
+ # pred_proba,
428
+ # labels=labels,
429
+ # i_fold=i_fold,
430
+ # show=show,
431
+ # auc_plt_config=auc_plt_config,
432
+ # )
433
+
434
+ # def summarize(
435
+ # self,
436
+ # n_round=3,
437
+ # show=False,
438
+ # tgt=None,
439
+ # ):
440
+ # i_tgt = self.tgt2id[tgt]
441
+ # self.reporters[i_tgt].summarize(
442
+ # n_round=n_round,
443
+ # show=show,
444
+ # )
445
+
446
+ # def save(
447
+ # self,
448
+ # files_to_reproduce=None,
449
+ # meta_dict=None,
450
+ # tgt=None,
451
+ # ):
452
+ # i_tgt = self.tgt2id[tgt]
453
+ # self.reporters[i_tgt].save(
454
+ # files_to_reproduce=files_to_reproduce,
455
+ # meta_dict=meta_dict,
456
+ # )
457
+
458
+ # def plot_and_save_conf_mats(
459
+ # self,
460
+ # plt,
461
+ # extend_ratio=1.0,
462
+ # colorbar=True,
463
+ # confmat_plt_config=None,
464
+ # sci_notation_kwargs=None,
465
+ # tgt=None,
466
+ # ):
467
+ # i_tgt = self.tgt2id[tgt]
468
+ # self.reporters[i_tgt].plot_and_save_conf_mats(
469
+ # plt,
470
+ # extend_ratio=extend_ratio,
471
+ # colorbar=colorbar,
472
+ # confmat_plt_config=confmat_plt_config,
473
+ # sci_notation_kwargs=sci_notation_kwargs,
474
+ # )
475
+
476
+
477
+ # class ClassificationReporter(object):
478
+ # """Saves the following metrics under sdir.
479
+ # - Balanced Accuracy
480
+ # - MCC
481
+ # - Confusion Matrix
482
+ # - Classification Report
483
+ # - ROC AUC score / curve
484
+ # - PRE-REC AUC score / curve
485
+
486
+ # Example is described in this file.
487
+ # """
488
+
489
+ # def __init__(self, sdir):
490
+ # self.sdir = sdir
491
+ # self.folds_dict = _defaultdict(list)
492
+ # fix_seeds(os=os, random=random, np=np, torch=torch, show=False)
493
+
494
+ # def add(
495
+ # self,
496
+ # obj_name,
497
+ # obj,
498
+ # ):
499
+ # """
500
+ # ## fig
501
+ # fig, ax = plt.subplots()
502
+ # ax.plot(np.random.rand(10))
503
+ # reporter.add("manu_figs", fig)
504
+
505
+ # ## DataFrame
506
+ # df = pd.DataFrame(np.random.rand(5, 3))
507
+ # reporter.add("manu_dfs", df)
508
+
509
+ # ## scalar
510
+ # scalar = random.random()
511
+ # reporter.add("manu_scalers", scalar)
512
+ # """
513
+ # assert isinstance(obj_name, str)
514
+ # self.folds_dict[obj_name].append(obj)
515
+
516
+ # @staticmethod
517
+ # def calc_bACC(true_class, pred_class, i_fold, show=False):
518
+ # """Balanced ACC"""
519
+ # balanced_acc = balanced_accuracy_score(true_class, pred_class)
520
+ # if show:
521
+ # print(f"\nBalanced ACC in fold#{i_fold} was {balanced_acc:.3f}\n")
522
+ # return balanced_acc
523
+
524
+ # @staticmethod
525
+ # def calc_mcc(true_class, pred_class, i_fold, show=False):
526
+ # """MCC"""
527
+ # mcc = float(matthews_corrcoef(true_class, pred_class))
528
+ # if show:
529
+ # print(f"\nMCC in fold#{i_fold} was {mcc:.3f}\n")
530
+ # return mcc
531
+
532
+ # @staticmethod
533
+ # def calc_conf_mat(true_class, pred_class, labels, i_fold, show=False):
534
+ # """
535
+ # Confusion Matrix
536
+ # This method assumes unique classes of true_class and pred_class are the same.
537
+ # """
538
+ # # conf_mat = pd.DataFrame(
539
+ # # data=confusion_matrix(true_class, pred_class),
540
+ # # columns=pred_labels,
541
+ # # index=true_labels,
542
+ # # )
543
+
544
+ # conf_mat = pd.DataFrame(
545
+ # data=confusion_matrix(
546
+ # true_class, pred_class, labels=np.arange(len(labels))
547
+ # ),
548
+ # columns=labels,
549
+ # ).set_index(pd.Series(list(labels)))
550
+
551
+ # if show:
552
+ # print(f"\nConfusion Matrix in fold#{i_fold}: \n")
553
+ # _pprint(conf_mat)
554
+ # print()
555
+
556
+ # return conf_mat
557
+
558
+ # @staticmethod
559
+ # def calc_clf_report(
560
+ # true_class, pred_class, labels, balanced_acc, i_fold, show=False
561
+ # ):
562
+ # """Classification Report"""
563
+ # clf_report = pd.DataFrame(
564
+ # classification_report(
565
+ # true_class,
566
+ # pred_class,
567
+ # labels=np.arange(len(labels)),
568
+ # target_names=labels,
569
+ # output_dict=True,
570
+ # )
571
+ # )
572
+
573
+ # # ACC to bACC
574
+ # clf_report["accuracy"] = balanced_acc
575
+ # clf_report = pd.concat(
576
+ # [
577
+ # clf_report[labels],
578
+ # clf_report[["accuracy", "macro avg", "weighted avg"]],
579
+ # ],
580
+ # axis=1,
581
+ # )
582
+ # clf_report = clf_report.rename(
583
+ # columns={"accuracy": "balanced accuracy"}
584
+ # )
585
+ # clf_report = clf_report.round(3)
586
+ # # Renames 'support' to 'sample size'
587
+ # clf_report["index"] = clf_report.index
588
+ # clf_report.loc["support", "index"] = "sample size"
589
+ # clf_report.set_index("index", drop=True, inplace=True)
590
+ # clf_report.index.name = None
591
+ # if show:
592
+ # print(f"\nClassification Report for fold#{i_fold}:\n")
593
+ # _pprint(clf_report)
594
+ # print()
595
+ # return clf_report
596
+
597
+ # @staticmethod
598
+ # def calc_and_plot_roc_curve(
599
+ # true_class, pred_proba, labels, sdir_for_csv=None
600
+ # ):
601
+ # # ROC-AUC
602
+ # fig_roc, metrics_roc_auc_dict = scitex.ml.plt.roc_auc(
603
+ # plt,
604
+ # true_class,
605
+ # pred_proba,
606
+ # labels,
607
+ # sdir_for_csv=sdir_for_csv,
608
+ # )
609
+ # plt.close()
610
+ # return fig_roc, metrics_roc_auc_dict
611
+
612
+ # @staticmethod
613
+ # def calc_and_plot_pre_rec_curve(true_class, pred_proba, labels):
614
+ # # PRE-REC AUC
615
+ # fig_pre_rec, metrics_pre_rec_auc_dict = scitex.ml.plt.pre_rec_auc(
616
+ # plt, true_class, pred_proba, labels
617
+ # )
618
+ # plt.close()
619
+ # return fig_pre_rec, metrics_pre_rec_auc_dict
620
+
621
+ # def calc_metrics(
622
+ # self,
623
+ # true_class,
624
+ # pred_class,
625
+ # pred_proba,
626
+ # labels=None,
627
+ # i_fold=None,
628
+ # show=True,
629
+ # auc_plt_config=dict(
630
+ # figsize=(7, 7),
631
+ # labelsize=8,
632
+ # fontsize=7,
633
+ # legendfontsize=6,
634
+ # tick_size=0.8,
635
+ # tick_width=0.2,
636
+ # ),
637
+ # ):
638
+ # """
639
+ # Calculates ACC, Confusion Matrix, Classification Report, and ROC-AUC score on a fold.
640
+ # Metrics and curves will be kept in self.folds_dict.
641
+ # """
642
+
643
+ # ## Preparation
644
+ # # for convenience
645
+ # true_class = scitex.gen.torch_to_arr(true_class).astype(int).reshape(-1)
646
+ # pred_class = (
647
+ # scitex.gen.torch_to_arr(pred_class).astype(np.float64).reshape(-1)
648
+ # )
649
+ # pred_proba = scitex.gen.torch_to_arr(pred_proba).astype(np.float64)
650
+
651
+ # # for curves
652
+ # scitex.plt.configure_mpl(
653
+ # plt,
654
+ # **auc_plt_config,
655
+ # )
656
+
657
+ # ## Calc metrics
658
+ # # Balanced ACC
659
+ # bacc = self.calc_bACC(true_class, pred_class, i_fold, show=show)
660
+ # self.folds_dict["balanced_acc"].append(bacc)
661
+
662
+ # # MCC
663
+ # self.folds_dict["mcc"].append(
664
+ # self.calc_mcc(true_class, pred_class, i_fold, show=show)
665
+ # )
666
+
667
+ # # Confusion Matrix
668
+ # self.folds_dict["conf_mat/conf_mat"].append(
669
+ # self.calc_conf_mat(
670
+ # true_class,
671
+ # pred_class,
672
+ # labels,
673
+ # i_fold,
674
+ # show=show,
675
+ # )
676
+ # )
677
+
678
+ # # Classification Report
679
+ # self.folds_dict["clf_report"].append(
680
+ # self.calc_clf_report(
681
+ # true_class, pred_class, labels, bacc, i_fold, show=show
682
+ # )
683
+ # )
684
+
685
+ # ## Curves
686
+ # # ROC curve
687
+ # self.sdir_for_roc_csv = f"{self.sdir}roc/csv/"
688
+ # fig_roc, metrics_roc_auc_dict = self.calc_and_plot_roc_curve(
689
+ # true_class,
690
+ # pred_proba,
691
+ # labels,
692
+ # sdir_for_csv=self.sdir_for_roc_csv + f"fold#{i_fold}/",
693
+ # )
694
+ # self.folds_dict["roc/micro"].append(
695
+ # metrics_roc_auc_dict["roc_auc"]["micro"]
696
+ # )
697
+ # self.folds_dict["roc/macro"].append(
698
+ # metrics_roc_auc_dict["roc_auc"]["macro"]
699
+ # )
700
+ # self.folds_dict["roc/figs"].append(fig_roc)
701
+
702
+ # # PRE-REC curve
703
+ # fig_pre_rec, metrics_pre_rec_auc_dict = (
704
+ # self.calc_and_plot_pre_rec_curve(true_class, pred_proba, labels)
705
+ # )
706
+ # self.folds_dict["pre_rec/micro"].append(
707
+ # metrics_pre_rec_auc_dict["pre_rec_auc"]["micro"]
708
+ # )
709
+ # self.folds_dict["pre_rec/macro"].append(
710
+ # metrics_pre_rec_auc_dict["pre_rec_auc"]["macro"]
711
+ # )
712
+ # self.folds_dict["pre_rec/figs"].append(fig_pre_rec)
713
+
714
+ # @staticmethod
715
+ # def _mk_cv_index(n_folds):
716
+ # return [
717
+ # f"{n_folds}-folds_CV_mean",
718
+ # f"{n_folds}-fold_CV_std",
719
+ # ] + [f"fold#{i_fold}" for i_fold in range(n_folds)]
720
+
721
+ # def summarize_roc(
722
+ # self,
723
+ # ):
724
+
725
+ # folds_dirs = _glob(self.sdir_for_roc_csv + "fold#*")
726
+ # n_folds = len(folds_dirs)
727
+
728
+ # # get class names
729
+ # _csv_files = _glob(os.path.join(folds_dirs[0], "*"))
730
+ # classes_str = [
731
+ # csv_file.split("/")[-1].split(".csv")[0] for csv_file in _csv_files
732
+ # ]
733
+
734
+ # # dfs_classes = []
735
+ # # take mean and std by each class
736
+ # for cls_str in classes_str:
737
+
738
+ # fpaths_cls = [
739
+ # os.path.join(fold_dir, f"{cls_str}.csv")
740
+ # for fold_dir in folds_dirs
741
+ # ]
742
+
743
+ # ys = []
744
+ # roc_aucs = []
745
+ # for fpath_cls in fpaths_cls:
746
+ # loaded_df = scitex.io.load(fpath_cls)
747
+ # ys.append(loaded_df["y"])
748
+ # roc_aucs.append(loaded_df["roc_auc"])
749
+ # ys = pd.concat(ys, axis=1)
750
+ # roc_aucs = pd.concat(roc_aucs, axis=1)
751
+
752
+ # df_cls = loaded_df[["x"]].copy()
753
+ # df_cls["y_mean"] = ys.mean(axis=1)
754
+ # df_cls["y_std"] = ys.std(axis=1)
755
+ # df_cls["roc_auc_mean"] = roc_aucs.mean(axis=1)
756
+ # df_cls["roc_auc_std"] = roc_aucs.std(axis=1)
757
+
758
+ # spath_cls = os.path.join(
759
+ # self.sdir_for_roc_csv, f"k-fold_mean_std/{cls_str}.csv"
760
+ # )
761
+ # scitex.io.save(df_cls, spath_cls)
762
+ # # dfs_classes.append(df_cls)
763
+
764
+ # def summarize(
765
+ # self,
766
+ # n_round=3,
767
+ # show=False,
768
+ # ):
769
+ # """
770
+ # 1) Take mean and std of scalars/pd.Dataframes for folds.
771
+ # 2) Replace self.folds_dict with the summarized DataFrames.
772
+ # """
773
+ # self.summarize_roc()
774
+
775
+ # _n_folds_all = [
776
+ # len(self.folds_dict[k]) for k in self.folds_dict.keys()
777
+ # ] # sometimes includes 0 because AUC curves are not always defined.
778
+ # self.n_folds_intended = max(_n_folds_all)
779
+
780
+ # for i_k, k in enumerate(self.folds_dict.keys()):
781
+ # n_folds = _n_folds_all[i_k]
782
+
783
+ # if n_folds != 0:
784
+ # ## listed scalars
785
+ # if is_listed_X(self.folds_dict[k], [float, int]):
786
+ # mm = np.mean(self.folds_dict[k])
787
+ # ss = np.std(self.folds_dict[k], ddof=1)
788
+ # sr = pd.DataFrame(
789
+ # data=[mm, ss] + self.folds_dict[k],
790
+ # index=self._mk_cv_index(n_folds),
791
+ # columns=[k],
792
+ # )
793
+ # self.folds_dict[k] = sr.round(n_round)
794
+
795
+ # ## listed pd.DataFrames
796
+ # elif is_listed_X(self.folds_dict[k], pd.DataFrame):
797
+ # zero_df_for_mm = 0 * self.folds_dict[k][0].copy()
798
+ # zero_df_for_ss = 0 * self.folds_dict[k][0].copy()
799
+
800
+ # mm = (
801
+ # zero_df_for_mm
802
+ # + np.stack(self.folds_dict[k]).mean(axis=0)
803
+ # ).round(n_round)
804
+
805
+ # ss = (
806
+ # zero_df_for_ss
807
+ # + np.stack(self.folds_dict[k]).std(axis=0, ddof=1)
808
+ # ).round(n_round)
809
+
810
+ # self.folds_dict[k] = [mm, ss] + [
811
+ # df_fold.round(n_round)
812
+ # for df_fold in self.folds_dict[k]
813
+ # ]
814
+
815
+ # if show:
816
+ # print(
817
+ # "\n----------------------------------------\n"
818
+ # f"\n{k}\n"
819
+ # f"\n{n_folds}-fold-CV mean:\n"
820
+ # )
821
+ # _pprint(self.folds_dict[k][0])
822
+ # print(f"\n\n{n_folds}-fold-CV std.:\n")
823
+ # _pprint(self.folds_dict[k][1])
824
+ # print("\n\n----------------------------------------\n")
825
+
826
+ # ## listed figures
827
+ # elif is_listed_X(self.folds_dict[k], matplotlib.figure.Figure):
828
+ # pass
829
+
830
+ # else:
831
+ # print(f"{k} was not summarized")
832
+ # print(type(self.folds_dict[k][0]))
833
+
834
+ # def save(
835
+ # self,
836
+ # files_to_reproduce=None,
837
+ # meta_dict=None,
838
+ # ):
839
+ # """
840
+ # 1) Saves the content of self.folds_dict.
841
+ # 2) Plots the colormap of confusion matrices and saves them.
842
+ # 3) Saves passed meta_dict under self.sdir
843
+
844
+ # Example:
845
+ # meta_df_1 = pd.DataFrame(data=np.random.rand(3,3))
846
+ # meta_dict_1 = {"a": 0}
847
+ # meta_dict_2 = {"b": 0}
848
+ # meta_dict = {"meta_1.csv": meta_df_1,
849
+ # "meta_1.yaml": meta_dict_1,
850
+ # "meta_2.yaml": meta_dict_2,
851
+ # }
852
+
853
+ # """
854
+ # if meta_dict is not None:
855
+ # for k, v in meta_dict.items():
856
+ # scitex.io.save(v, self.sdir + k)
857
+
858
+ # for k in self.folds_dict.keys():
859
+
860
+ # ## pd.Series / pd.DataFrame
861
+ # if isinstance(self.folds_dict[k], pd.Series) or isinstance(
862
+ # self.folds_dict[k], pd.DataFrame
863
+ # ):
864
+ # scitex.io.save(self.folds_dict[k], self.sdir + f"{k}.csv")
865
+
866
+ # ## listed pd.DataFrame
867
+ # elif is_listed_X(self.folds_dict[k], pd.DataFrame):
868
+ # scitex.io.save(
869
+ # self.folds_dict[k],
870
+ # self.sdir + f"{k}.csv",
871
+ # # indi_suffix=self.cv_index,
872
+ # indi_suffix=self._mk_cv_index(len(self.folds_dict[k])),
873
+ # )
874
+
875
+ # ## listed figures
876
+ # elif is_listed_X(self.folds_dict[k], matplotlib.figure.Figure):
877
+ # for i_fold, fig in enumerate(self.folds_dict[k]):
878
+ # scitex.io.save(
879
+ # self.folds_dict[k][i_fold],
880
+ # self.sdir + f"{k}/fold#{i_fold}.png",
881
+ # )
882
+
883
+ # else:
884
+ # print(f"{k} was not saved")
885
+ # print(type(self.folds_dict[k]))
886
+
887
+ # if files_to_reproduce is not None:
888
+ # if isinstance(files_to_reproduce, list):
889
+ # files_to_reproduce = [files_to_reproduce]
890
+ # for f in files_to_reproduce:
891
+ # scitex.io.save(f, self.sdir)
892
+
893
+ # def plot_and_save_conf_mats(
894
+ # self,
895
+ # plt,
896
+ # extend_ratio=1.0,
897
+ # colorbar=True,
898
+ # confmat_plt_config=None,
899
+ # sci_notation_kwargs=None,
900
+ # ):
901
+ # def _inner_plot_conf_mat(
902
+ # plt,
903
+ # cm_df,
904
+ # title,
905
+ # extend_ratio=1.0,
906
+ # colorbar=True,
907
+ # sci_notation_kwargs=None,
908
+ # ):
909
+ # labels = list(cm_df.columns)
910
+ # fig_conf_mat = scitex.ml.plt.confusion_matrix(
911
+ # plt,
912
+ # cm_df.T,
913
+ # labels=labels,
914
+ # title=title,
915
+ # x_extend_ratio=extend_ratio,
916
+ # y_extend_ratio=extend_ratio,
917
+ # colorbar=colorbar,
918
+ # )
919
+
920
+ # if sci_notation_kwargs is not None:
921
+ # fig_conf_mat.axes[-1] = scitex.plt.ax_scientific_notation(
922
+ # fig_conf_mat.axes[-1], **sci_notation_kwargs
923
+ # )
924
+ # return fig_conf_mat
925
+
926
+ # ## Configures mpl
927
+ # scitex.plt.configure_mpl(
928
+ # plt,
929
+ # **confmat_plt_config,
930
+ # )
931
+
932
+ # ########################################
933
+ # ## Prepares confmats dfs
934
+ # ########################################
935
+ # ## Drops mean and std for the folds
936
+ # try:
937
+ # conf_mats = self.folds_dict["conf_mat/conf_mat"][
938
+ # -self.n_folds_intended :
939
+ # ]
940
+
941
+ # except Exception as e:
942
+ # print(e)
943
+ # conf_mats = self.folds_dict["conf_mat/conf_mat"]
944
+
945
+ # ## Prepares conf_mat_overall_sum
946
+ # conf_mat_zero = 0 * conf_mats[0].copy() # get the table format
947
+ # conf_mat_overall_sum = conf_mat_zero + np.stack(conf_mats).sum(axis=0)
948
+
949
+ # ########################################
950
+ # ## Plots & Saves
951
+ # ########################################
952
+ # # each fold's conf
953
+ # for i_fold, cm in enumerate(conf_mats):
954
+ # title = f"Test fold#{i_fold}"
955
+ # fig_conf_mat_fold = _inner_plot_conf_mat(
956
+ # plt,
957
+ # cm,
958
+ # title,
959
+ # extend_ratio=extend_ratio,
960
+ # colorbar=colorbar,
961
+ # sci_notation_kwargs=sci_notation_kwargs,
962
+ # )
963
+ # scitex.io.save(
964
+ # fig_conf_mat_fold,
965
+ # self.sdir + f"conf_mat/figs/fold#{i_fold}.png",
966
+ # )
967
+ # plt.close()
968
+
969
+ # ## overall_sum conf_mat
970
+ # title = f"{self.n_folds_intended}-CV overall sum"
971
+ # fig_conf_mat_overall_sum = _inner_plot_conf_mat(
972
+ # plt,
973
+ # conf_mat_overall_sum,
974
+ # title,
975
+ # extend_ratio=extend_ratio,
976
+ # colorbar=colorbar,
977
+ # sci_notation_kwargs=sci_notation_kwargs,
978
+ # )
979
+ # scitex.io.save(
980
+ # fig_conf_mat_overall_sum,
981
+ # self.sdir
982
+ # + f"conf_mat/figs/{self.n_folds_intended}-fold_cv_overall-sum.png",
983
+ # )
984
+ # plt.close()
985
+
986
+
987
+ # if __name__ == "__main__":
988
+ # import random
989
+ # import sys
990
+
991
+ # import scitex
992
+ # import numpy as np
993
+ # from catboost import CatBoostClassifier, Pool
994
+ # from sklearn.datasets import load_digits
995
+ # from sklearn.model_selection import StratifiedKFold
996
+
997
+ # ################################################################################
998
+ # ## Sets tee
999
+ # ################################################################################
1000
+ # sdir = scitex.io.mk_spath(
1001
+ # "./tmp/sdir-ClassificationReporter/"
1002
+ # ) # "/tmp/sdir/"
1003
+ # sys.stdout, sys.stderr = scitex.gen.tee(sys, sdir)
1004
+
1005
+ # ################################################################################
1006
+ # ## Fixes seeds
1007
+ # ################################################################################
1008
+ # fix_seeds(np=np)
1009
+
1010
+ # ## Loads
1011
+ # mnist = load_digits()
1012
+ # X, T = mnist.data, mnist.target
1013
+ # labels = mnist.target_names.astype(str)
1014
+
1015
+ # ## Main
1016
+ # skf = StratifiedKFold(n_splits=5, shuffle=True)
1017
+ # # reporter = ClassificationReporter(sdir)
1018
+ # mreporter = MultiClassificationReporter(sdir, tgts=["Test1", "Test2"])
1019
+ # for i_fold, (indi_tra, indi_tes) in enumerate(skf.split(X, T)):
1020
+ # X_tra, T_tra = X[indi_tra], T[indi_tra]
1021
+ # X_tes, T_tes = X[indi_tes], T[indi_tes]
1022
+
1023
+ # clf = CatBoostClassifier(verbose=False)
1024
+
1025
+ # clf.fit(X_tra, T_tra, verbose=False)
1026
+
1027
+ # ## Prediction
1028
+ # pred_proba_tes = clf.predict_proba(X_tes)
1029
+ # pred_cls_tes = np.argmax(pred_proba_tes, axis=1)
1030
+
1031
+ # pred_cls_tes[pred_cls_tes == 9] = 8 # override 9 as 8 # fixme
1032
+
1033
+ # ##############################
1034
+ # ## Manually adds objects to reporter to save
1035
+ # ##############################
1036
+ # ## Figure
1037
+ # fig, ax = plt.subplots()
1038
+ # ax.plot(np.arange(10))
1039
+ # # reporter.add("manu_figs", fig)
1040
+ # mreporter.add("manu_figs", fig, tgt="Test1")
1041
+ # mreporter.add("manu_figs", fig, tgt="Test2")
1042
+
1043
+ # ## DataFrame
1044
+ # df = pd.DataFrame(np.random.rand(5, 3))
1045
+ # # reporter.add("manu_dfs", df)
1046
+ # mreporter.add("manu_dfs", df, tgt="Test1")
1047
+ # mreporter.add("manu_dfs", df, tgt="Test2")
1048
+
1049
+ # ## Scalar
1050
+ # scalar = random.random()
1051
+ # # reporter.add(
1052
+ # # "manu_scalars",
1053
+ # # scalar,
1054
+ # # )
1055
+ # mreporter.add("manu_scalars", scalar, tgt="Test1")
1056
+ # mreporter.add("manu_scalars", scalar, tgt="Test2")
1057
+
1058
+ # ########################################
1059
+ # ## Metrics
1060
+ # ########################################
1061
+ # mreporter.calc_metrics(
1062
+ # T_tes,
1063
+ # pred_cls_tes,
1064
+ # pred_proba_tes,
1065
+ # labels=labels,
1066
+ # i_fold=i_fold,
1067
+ # tgt="Test1",
1068
+ # )
1069
+ # mreporter.calc_metrics(
1070
+ # T_tes,
1071
+ # pred_cls_tes,
1072
+ # pred_proba_tes,
1073
+ # labels=labels,
1074
+ # i_fold=i_fold,
1075
+ # tgt="Test2",
1076
+ # )
1077
+
1078
+ # # reporter.summarize(show=True)
1079
+ # mreporter.summarize(show=True, tgt="Test1")
1080
+ # mreporter.summarize(show=True, tgt="Test2")
1081
+
1082
+ # fake_fpaths = ["fake_file_1.txt", "fake_file_2.txt"]
1083
+ # for ff in fake_fpaths:
1084
+ # scitex.io.touch(ff)
1085
+
1086
+ # files_to_reproduce = [
1087
+ # scitex.gen.get_this_fpath(when_ipython="/dev/null"),
1088
+ # *fake_fpaths,
1089
+ # ]
1090
+ # # reporter.save(files_to_reproduce=files_to_reproduce)
1091
+ # mreporter.save(files_to_reproduce=files_to_reproduce, tgt="Test1")
1092
+ # mreporter.save(files_to_reproduce=files_to_reproduce, tgt="Test2")
1093
+
1094
+ # confmat_plt_config = dict(
1095
+ # figsize=(8, 8),
1096
+ # # labelsize=8,
1097
+ # # fontsize=6,
1098
+ # # legendfontsize=6,
1099
+ # figscale=2,
1100
+ # tick_size=0.8,
1101
+ # tick_width=0.2,
1102
+ # )
1103
+
1104
+ # sci_notation_kwargs = dict(
1105
+ # order=1,
1106
+ # fformat="%1.0d",
1107
+ # scilimits=(-3, 3),
1108
+ # x=False,
1109
+ # y=True,
1110
+ # ) # "%3.1f"
1111
+
1112
+ # # sci_notation_kwargs = None
1113
+ # # reporter.plot_and_save_conf_mats(
1114
+ # # plt,
1115
+ # # extend_ratio=1.0,
1116
+ # # confmat_plt_config=confmat_plt_config,
1117
+ # # sci_notation_kwargs=sci_notation_kwargs,
1118
+ # # )
1119
+
1120
+ # mreporter.plot_and_save_conf_mats(
1121
+ # plt,
1122
+ # extend_ratio=1.0,
1123
+ # confmat_plt_config=confmat_plt_config,
1124
+ # sci_notation_kwargs=sci_notation_kwargs,
1125
+ # tgt="Test1",
1126
+ # )
1127
+ # mreporter.plot_and_save_conf_mats(
1128
+ # plt,
1129
+ # extend_ratio=1.0,
1130
+ # confmat_plt_config=confmat_plt_config,
1131
+ # sci_notation_kwargs=sci_notation_kwargs,
1132
+ # tgt="Test2",
1133
+ # )
1134
+
1135
+ # python -m scitex.ai.ClassificationReporter
1136
+
1137
+ # EOF