scitex 2.0.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (572)
  1. scitex/__init__.py +73 -0
  2. scitex/__main__.py +89 -0
  3. scitex/__version__.py +14 -0
  4. scitex/_sh.py +59 -0
  5. scitex/ai/_LearningCurveLogger.py +583 -0
  6. scitex/ai/__Classifiers.py +101 -0
  7. scitex/ai/__init__.py +55 -0
  8. scitex/ai/_gen_ai/_Anthropic.py +173 -0
  9. scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
  10. scitex/ai/_gen_ai/_DeepSeek.py +175 -0
  11. scitex/ai/_gen_ai/_Google.py +161 -0
  12. scitex/ai/_gen_ai/_Groq.py +97 -0
  13. scitex/ai/_gen_ai/_Llama.py +142 -0
  14. scitex/ai/_gen_ai/_OpenAI.py +230 -0
  15. scitex/ai/_gen_ai/_PARAMS.py +565 -0
  16. scitex/ai/_gen_ai/_Perplexity.py +191 -0
  17. scitex/ai/_gen_ai/__init__.py +32 -0
  18. scitex/ai/_gen_ai/_calc_cost.py +78 -0
  19. scitex/ai/_gen_ai/_format_output_func.py +183 -0
  20. scitex/ai/_gen_ai/_genai_factory.py +71 -0
  21. scitex/ai/act/__init__.py +8 -0
  22. scitex/ai/act/_define.py +11 -0
  23. scitex/ai/classification/__init__.py +7 -0
  24. scitex/ai/classification/classification_reporter.py +1137 -0
  25. scitex/ai/classification/classifier_server.py +131 -0
  26. scitex/ai/classification/classifiers.py +101 -0
  27. scitex/ai/classification_reporter.py +1161 -0
  28. scitex/ai/classifier_server.py +131 -0
  29. scitex/ai/clustering/__init__.py +11 -0
  30. scitex/ai/clustering/_pca.py +115 -0
  31. scitex/ai/clustering/_umap.py +376 -0
  32. scitex/ai/early_stopping.py +149 -0
  33. scitex/ai/feature_extraction/__init__.py +56 -0
  34. scitex/ai/feature_extraction/vit.py +148 -0
  35. scitex/ai/genai/__init__.py +277 -0
  36. scitex/ai/genai/anthropic.py +177 -0
  37. scitex/ai/genai/anthropic_provider.py +320 -0
  38. scitex/ai/genai/anthropic_refactored.py +109 -0
  39. scitex/ai/genai/auth_manager.py +200 -0
  40. scitex/ai/genai/base_genai.py +336 -0
  41. scitex/ai/genai/base_provider.py +291 -0
  42. scitex/ai/genai/calc_cost.py +78 -0
  43. scitex/ai/genai/chat_history.py +307 -0
  44. scitex/ai/genai/cost_tracker.py +276 -0
  45. scitex/ai/genai/deepseek.py +188 -0
  46. scitex/ai/genai/deepseek_provider.py +251 -0
  47. scitex/ai/genai/format_output_func.py +183 -0
  48. scitex/ai/genai/genai_factory.py +71 -0
  49. scitex/ai/genai/google.py +169 -0
  50. scitex/ai/genai/google_provider.py +228 -0
  51. scitex/ai/genai/groq.py +104 -0
  52. scitex/ai/genai/groq_provider.py +248 -0
  53. scitex/ai/genai/image_processor.py +250 -0
  54. scitex/ai/genai/llama.py +155 -0
  55. scitex/ai/genai/llama_provider.py +214 -0
  56. scitex/ai/genai/mock_provider.py +127 -0
  57. scitex/ai/genai/model_registry.py +304 -0
  58. scitex/ai/genai/openai.py +230 -0
  59. scitex/ai/genai/openai_provider.py +293 -0
  60. scitex/ai/genai/params.py +565 -0
  61. scitex/ai/genai/perplexity.py +202 -0
  62. scitex/ai/genai/perplexity_provider.py +205 -0
  63. scitex/ai/genai/provider_base.py +302 -0
  64. scitex/ai/genai/provider_factory.py +370 -0
  65. scitex/ai/genai/response_handler.py +235 -0
  66. scitex/ai/layer/_Pass.py +21 -0
  67. scitex/ai/layer/__init__.py +10 -0
  68. scitex/ai/layer/_switch.py +8 -0
  69. scitex/ai/loss/_L1L2Losses.py +34 -0
  70. scitex/ai/loss/__init__.py +12 -0
  71. scitex/ai/loss/multi_task_loss.py +47 -0
  72. scitex/ai/metrics/__init__.py +9 -0
  73. scitex/ai/metrics/_bACC.py +51 -0
  74. scitex/ai/metrics/silhoute_score_block.py +496 -0
  75. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
  76. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
  77. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
  78. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
  79. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
  80. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
  81. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
  82. scitex/ai/optim/__init__.py +13 -0
  83. scitex/ai/optim/_get_set.py +31 -0
  84. scitex/ai/optim/_optimizers.py +71 -0
  85. scitex/ai/plt/__init__.py +21 -0
  86. scitex/ai/plt/_conf_mat.py +592 -0
  87. scitex/ai/plt/_learning_curve.py +194 -0
  88. scitex/ai/plt/_optuna_study.py +111 -0
  89. scitex/ai/plt/aucs/__init__.py +2 -0
  90. scitex/ai/plt/aucs/example.py +60 -0
  91. scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
  92. scitex/ai/plt/aucs/roc_auc.py +246 -0
  93. scitex/ai/sampling/undersample.py +29 -0
  94. scitex/ai/sk/__init__.py +11 -0
  95. scitex/ai/sk/_clf.py +58 -0
  96. scitex/ai/sk/_to_sktime.py +100 -0
  97. scitex/ai/sklearn/__init__.py +26 -0
  98. scitex/ai/sklearn/clf.py +58 -0
  99. scitex/ai/sklearn/to_sktime.py +100 -0
  100. scitex/ai/training/__init__.py +7 -0
  101. scitex/ai/training/early_stopping.py +150 -0
  102. scitex/ai/training/learning_curve_logger.py +555 -0
  103. scitex/ai/utils/__init__.py +22 -0
  104. scitex/ai/utils/_check_params.py +50 -0
  105. scitex/ai/utils/_default_dataset.py +46 -0
  106. scitex/ai/utils/_format_samples_for_sktime.py +26 -0
  107. scitex/ai/utils/_label_encoder.py +134 -0
  108. scitex/ai/utils/_merge_labels.py +22 -0
  109. scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
  110. scitex/ai/utils/_under_sample.py +51 -0
  111. scitex/ai/utils/_verify_n_gpus.py +16 -0
  112. scitex/ai/utils/grid_search.py +148 -0
  113. scitex/context/__init__.py +9 -0
  114. scitex/context/_suppress_output.py +38 -0
  115. scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
  116. scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
  117. scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
  118. scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
  119. scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
  120. scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
  121. scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
  122. scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
  123. scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
  124. scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
  125. scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
  126. scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
  127. scitex/db/_BaseMixins/__init__.py +30 -0
  128. scitex/db/_PostgreSQL.py +126 -0
  129. scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
  130. scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
  131. scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
  132. scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
  133. scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
  134. scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
  135. scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
  136. scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
  137. scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
  138. scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
  139. scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
  140. scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
  141. scitex/db/_PostgreSQLMixins/__init__.py +34 -0
  142. scitex/db/_SQLite3.py +2136 -0
  143. scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
  144. scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
  145. scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
  146. scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
  147. scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
  148. scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
  149. scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
  150. scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
  151. scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
  152. scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
  153. scitex/db/_SQLite3Mixins/__init__.py +30 -0
  154. scitex/db/__init__.py +14 -0
  155. scitex/db/_delete_duplicates.py +397 -0
  156. scitex/db/_inspect.py +163 -0
  157. scitex/decorators/__init__.py +54 -0
  158. scitex/decorators/_auto_order.py +172 -0
  159. scitex/decorators/_batch_fn.py +127 -0
  160. scitex/decorators/_cache_disk.py +32 -0
  161. scitex/decorators/_cache_mem.py +12 -0
  162. scitex/decorators/_combined.py +98 -0
  163. scitex/decorators/_converters.py +282 -0
  164. scitex/decorators/_deprecated.py +26 -0
  165. scitex/decorators/_not_implemented.py +30 -0
  166. scitex/decorators/_numpy_fn.py +86 -0
  167. scitex/decorators/_pandas_fn.py +121 -0
  168. scitex/decorators/_preserve_doc.py +19 -0
  169. scitex/decorators/_signal_fn.py +95 -0
  170. scitex/decorators/_timeout.py +55 -0
  171. scitex/decorators/_torch_fn.py +136 -0
  172. scitex/decorators/_wrap.py +39 -0
  173. scitex/decorators/_xarray_fn.py +88 -0
  174. scitex/dev/__init__.py +15 -0
  175. scitex/dev/_analyze_code_flow.py +284 -0
  176. scitex/dev/_reload.py +59 -0
  177. scitex/dict/_DotDict.py +442 -0
  178. scitex/dict/__init__.py +18 -0
  179. scitex/dict/_listed_dict.py +42 -0
  180. scitex/dict/_pop_keys.py +36 -0
  181. scitex/dict/_replace.py +13 -0
  182. scitex/dict/_safe_merge.py +62 -0
  183. scitex/dict/_to_str.py +32 -0
  184. scitex/dsp/__init__.py +72 -0
  185. scitex/dsp/_crop.py +122 -0
  186. scitex/dsp/_demo_sig.py +331 -0
  187. scitex/dsp/_detect_ripples.py +212 -0
  188. scitex/dsp/_ensure_3d.py +18 -0
  189. scitex/dsp/_hilbert.py +78 -0
  190. scitex/dsp/_listen.py +702 -0
  191. scitex/dsp/_misc.py +30 -0
  192. scitex/dsp/_mne.py +32 -0
  193. scitex/dsp/_modulation_index.py +79 -0
  194. scitex/dsp/_pac.py +319 -0
  195. scitex/dsp/_psd.py +102 -0
  196. scitex/dsp/_resample.py +65 -0
  197. scitex/dsp/_time.py +36 -0
  198. scitex/dsp/_transform.py +68 -0
  199. scitex/dsp/_wavelet.py +212 -0
  200. scitex/dsp/add_noise.py +111 -0
  201. scitex/dsp/example.py +253 -0
  202. scitex/dsp/filt.py +155 -0
  203. scitex/dsp/norm.py +18 -0
  204. scitex/dsp/params.py +51 -0
  205. scitex/dsp/reference.py +43 -0
  206. scitex/dsp/template.py +25 -0
  207. scitex/dsp/utils/__init__.py +15 -0
  208. scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
  209. scitex/dsp/utils/_ensure_3d.py +18 -0
  210. scitex/dsp/utils/_ensure_even_len.py +10 -0
  211. scitex/dsp/utils/_zero_pad.py +48 -0
  212. scitex/dsp/utils/filter.py +408 -0
  213. scitex/dsp/utils/pac.py +177 -0
  214. scitex/dt/__init__.py +8 -0
  215. scitex/dt/_linspace.py +130 -0
  216. scitex/etc/__init__.py +15 -0
  217. scitex/etc/wait_key.py +34 -0
  218. scitex/gen/_DimHandler.py +196 -0
  219. scitex/gen/_TimeStamper.py +244 -0
  220. scitex/gen/__init__.py +95 -0
  221. scitex/gen/_alternate_kwarg.py +13 -0
  222. scitex/gen/_cache.py +11 -0
  223. scitex/gen/_check_host.py +34 -0
  224. scitex/gen/_ci.py +12 -0
  225. scitex/gen/_close.py +222 -0
  226. scitex/gen/_embed.py +78 -0
  227. scitex/gen/_inspect_module.py +257 -0
  228. scitex/gen/_is_ipython.py +12 -0
  229. scitex/gen/_less.py +48 -0
  230. scitex/gen/_list_packages.py +139 -0
  231. scitex/gen/_mat2py.py +88 -0
  232. scitex/gen/_norm.py +170 -0
  233. scitex/gen/_paste.py +18 -0
  234. scitex/gen/_print_config.py +84 -0
  235. scitex/gen/_shell.py +48 -0
  236. scitex/gen/_src.py +111 -0
  237. scitex/gen/_start.py +451 -0
  238. scitex/gen/_symlink.py +55 -0
  239. scitex/gen/_symlog.py +27 -0
  240. scitex/gen/_tee.py +238 -0
  241. scitex/gen/_title2path.py +60 -0
  242. scitex/gen/_title_case.py +88 -0
  243. scitex/gen/_to_even.py +84 -0
  244. scitex/gen/_to_odd.py +34 -0
  245. scitex/gen/_to_rank.py +39 -0
  246. scitex/gen/_transpose.py +37 -0
  247. scitex/gen/_type.py +78 -0
  248. scitex/gen/_var_info.py +73 -0
  249. scitex/gen/_wrap.py +17 -0
  250. scitex/gen/_xml2dict.py +76 -0
  251. scitex/gen/misc.py +730 -0
  252. scitex/gen/path.py +0 -0
  253. scitex/general/__init__.py +5 -0
  254. scitex/gists/_SigMacro_processFigure_S.py +128 -0
  255. scitex/gists/_SigMacro_toBlue.py +172 -0
  256. scitex/gists/__init__.py +12 -0
  257. scitex/io/_H5Explorer.py +292 -0
  258. scitex/io/__init__.py +82 -0
  259. scitex/io/_cache.py +101 -0
  260. scitex/io/_flush.py +24 -0
  261. scitex/io/_glob.py +103 -0
  262. scitex/io/_json2md.py +113 -0
  263. scitex/io/_load.py +168 -0
  264. scitex/io/_load_configs.py +146 -0
  265. scitex/io/_load_modules/__init__.py +38 -0
  266. scitex/io/_load_modules/_catboost.py +66 -0
  267. scitex/io/_load_modules/_con.py +20 -0
  268. scitex/io/_load_modules/_db.py +24 -0
  269. scitex/io/_load_modules/_docx.py +42 -0
  270. scitex/io/_load_modules/_eeg.py +110 -0
  271. scitex/io/_load_modules/_hdf5.py +196 -0
  272. scitex/io/_load_modules/_image.py +19 -0
  273. scitex/io/_load_modules/_joblib.py +19 -0
  274. scitex/io/_load_modules/_json.py +18 -0
  275. scitex/io/_load_modules/_markdown.py +103 -0
  276. scitex/io/_load_modules/_matlab.py +37 -0
  277. scitex/io/_load_modules/_numpy.py +39 -0
  278. scitex/io/_load_modules/_optuna.py +155 -0
  279. scitex/io/_load_modules/_pandas.py +69 -0
  280. scitex/io/_load_modules/_pdf.py +31 -0
  281. scitex/io/_load_modules/_pickle.py +24 -0
  282. scitex/io/_load_modules/_torch.py +16 -0
  283. scitex/io/_load_modules/_txt.py +126 -0
  284. scitex/io/_load_modules/_xml.py +49 -0
  285. scitex/io/_load_modules/_yaml.py +23 -0
  286. scitex/io/_mv_to_tmp.py +19 -0
  287. scitex/io/_path.py +286 -0
  288. scitex/io/_reload.py +78 -0
  289. scitex/io/_save.py +539 -0
  290. scitex/io/_save_modules/__init__.py +66 -0
  291. scitex/io/_save_modules/_catboost.py +22 -0
  292. scitex/io/_save_modules/_csv.py +89 -0
  293. scitex/io/_save_modules/_excel.py +49 -0
  294. scitex/io/_save_modules/_hdf5.py +249 -0
  295. scitex/io/_save_modules/_html.py +48 -0
  296. scitex/io/_save_modules/_image.py +140 -0
  297. scitex/io/_save_modules/_joblib.py +25 -0
  298. scitex/io/_save_modules/_json.py +25 -0
  299. scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
  300. scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
  301. scitex/io/_save_modules/_matlab.py +24 -0
  302. scitex/io/_save_modules/_mp4.py +29 -0
  303. scitex/io/_save_modules/_numpy.py +57 -0
  304. scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
  305. scitex/io/_save_modules/_pickle.py +45 -0
  306. scitex/io/_save_modules/_plotly.py +27 -0
  307. scitex/io/_save_modules/_text.py +23 -0
  308. scitex/io/_save_modules/_torch.py +26 -0
  309. scitex/io/_save_modules/_yaml.py +29 -0
  310. scitex/life/__init__.py +10 -0
  311. scitex/life/_monitor_rain.py +49 -0
  312. scitex/linalg/__init__.py +17 -0
  313. scitex/linalg/_distance.py +63 -0
  314. scitex/linalg/_geometric_median.py +64 -0
  315. scitex/linalg/_misc.py +73 -0
  316. scitex/nn/_AxiswiseDropout.py +27 -0
  317. scitex/nn/_BNet.py +126 -0
  318. scitex/nn/_BNet_Res.py +164 -0
  319. scitex/nn/_ChannelGainChanger.py +44 -0
  320. scitex/nn/_DropoutChannels.py +50 -0
  321. scitex/nn/_Filters.py +489 -0
  322. scitex/nn/_FreqGainChanger.py +110 -0
  323. scitex/nn/_GaussianFilter.py +48 -0
  324. scitex/nn/_Hilbert.py +111 -0
  325. scitex/nn/_MNet_1000.py +157 -0
  326. scitex/nn/_ModulationIndex.py +221 -0
  327. scitex/nn/_PAC.py +414 -0
  328. scitex/nn/_PSD.py +40 -0
  329. scitex/nn/_ResNet1D.py +120 -0
  330. scitex/nn/_SpatialAttention.py +25 -0
  331. scitex/nn/_Spectrogram.py +161 -0
  332. scitex/nn/_SwapChannels.py +50 -0
  333. scitex/nn/_TransposeLayer.py +19 -0
  334. scitex/nn/_Wavelet.py +183 -0
  335. scitex/nn/__init__.py +63 -0
  336. scitex/os/__init__.py +8 -0
  337. scitex/os/_mv.py +50 -0
  338. scitex/parallel/__init__.py +8 -0
  339. scitex/parallel/_run.py +151 -0
  340. scitex/path/__init__.py +33 -0
  341. scitex/path/_clean.py +52 -0
  342. scitex/path/_find.py +108 -0
  343. scitex/path/_get_module_path.py +51 -0
  344. scitex/path/_get_spath.py +35 -0
  345. scitex/path/_getsize.py +18 -0
  346. scitex/path/_increment_version.py +87 -0
  347. scitex/path/_mk_spath.py +51 -0
  348. scitex/path/_path.py +19 -0
  349. scitex/path/_split.py +23 -0
  350. scitex/path/_this_path.py +19 -0
  351. scitex/path/_version.py +101 -0
  352. scitex/pd/__init__.py +41 -0
  353. scitex/pd/_find_indi.py +126 -0
  354. scitex/pd/_find_pval.py +113 -0
  355. scitex/pd/_force_df.py +154 -0
  356. scitex/pd/_from_xyz.py +71 -0
  357. scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
  358. scitex/pd/_melt_cols.py +81 -0
  359. scitex/pd/_merge_columns.py +221 -0
  360. scitex/pd/_mv.py +63 -0
  361. scitex/pd/_replace.py +62 -0
  362. scitex/pd/_round.py +93 -0
  363. scitex/pd/_slice.py +63 -0
  364. scitex/pd/_sort.py +91 -0
  365. scitex/pd/_to_numeric.py +53 -0
  366. scitex/pd/_to_xy.py +59 -0
  367. scitex/pd/_to_xyz.py +110 -0
  368. scitex/plt/__init__.py +36 -0
  369. scitex/plt/_subplots/_AxesWrapper.py +182 -0
  370. scitex/plt/_subplots/_AxisWrapper.py +249 -0
  371. scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
  372. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
  373. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
  374. scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
  375. scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
  376. scitex/plt/_subplots/_FigWrapper.py +226 -0
  377. scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
  378. scitex/plt/_subplots/__init__.py +111 -0
  379. scitex/plt/_subplots/_export_as_csv.py +232 -0
  380. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
  381. scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
  382. scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
  383. scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
  384. scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
  385. scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
  386. scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
  387. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
  388. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
  389. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
  390. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
  391. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
  392. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
  393. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
  394. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
  395. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
  396. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
  397. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
  398. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
  399. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
  400. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
  401. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
  402. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
  403. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
  404. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
  405. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
  406. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
  407. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
  408. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
  409. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
  410. scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
  411. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
  412. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
  413. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
  414. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
  415. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
  416. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
  417. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
  418. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
  419. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
  420. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
  421. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
  422. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
  423. scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
  424. scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
  425. scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
  426. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
  427. scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
  428. scitex/plt/_tpl.py +28 -0
  429. scitex/plt/ax/__init__.py +114 -0
  430. scitex/plt/ax/_plot/__init__.py +53 -0
  431. scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
  432. scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
  433. scitex/plt/ax/_plot/_plot_cube.py +57 -0
  434. scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
  435. scitex/plt/ax/_plot/_plot_fillv.py +55 -0
  436. scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
  437. scitex/plt/ax/_plot/_plot_image.py +94 -0
  438. scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
  439. scitex/plt/ax/_plot/_plot_raster.py +172 -0
  440. scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
  441. scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
  442. scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
  443. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
  444. scitex/plt/ax/_plot/_plot_violin.py +343 -0
  445. scitex/plt/ax/_style/__init__.py +38 -0
  446. scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
  447. scitex/plt/ax/_style/_add_panel.py +92 -0
  448. scitex/plt/ax/_style/_extend.py +64 -0
  449. scitex/plt/ax/_style/_force_aspect.py +37 -0
  450. scitex/plt/ax/_style/_format_label.py +23 -0
  451. scitex/plt/ax/_style/_hide_spines.py +84 -0
  452. scitex/plt/ax/_style/_map_ticks.py +182 -0
  453. scitex/plt/ax/_style/_rotate_labels.py +215 -0
  454. scitex/plt/ax/_style/_sci_note.py +279 -0
  455. scitex/plt/ax/_style/_set_log_scale.py +299 -0
  456. scitex/plt/ax/_style/_set_meta.py +261 -0
  457. scitex/plt/ax/_style/_set_n_ticks.py +37 -0
  458. scitex/plt/ax/_style/_set_size.py +16 -0
  459. scitex/plt/ax/_style/_set_supxyt.py +116 -0
  460. scitex/plt/ax/_style/_set_ticks.py +276 -0
  461. scitex/plt/ax/_style/_set_xyt.py +121 -0
  462. scitex/plt/ax/_style/_share_axes.py +264 -0
  463. scitex/plt/ax/_style/_shift.py +139 -0
  464. scitex/plt/ax/_style/_show_spines.py +333 -0
  465. scitex/plt/color/_PARAMS.py +70 -0
  466. scitex/plt/color/__init__.py +52 -0
  467. scitex/plt/color/_add_hue_col.py +41 -0
  468. scitex/plt/color/_colors.py +205 -0
  469. scitex/plt/color/_get_colors_from_cmap.py +134 -0
  470. scitex/plt/color/_interpolate.py +29 -0
  471. scitex/plt/color/_vizualize_colors.py +54 -0
  472. scitex/plt/utils/__init__.py +44 -0
  473. scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
  474. scitex/plt/utils/_calc_nice_ticks.py +101 -0
  475. scitex/plt/utils/_close.py +68 -0
  476. scitex/plt/utils/_colorbar.py +96 -0
  477. scitex/plt/utils/_configure_mpl.py +295 -0
  478. scitex/plt/utils/_histogram_utils.py +132 -0
  479. scitex/plt/utils/_im2grid.py +70 -0
  480. scitex/plt/utils/_is_valid_axis.py +78 -0
  481. scitex/plt/utils/_mk_colorbar.py +65 -0
  482. scitex/plt/utils/_mk_patches.py +26 -0
  483. scitex/plt/utils/_scientific_captions.py +638 -0
  484. scitex/plt/utils/_scitex_config.py +223 -0
  485. scitex/reproduce/__init__.py +14 -0
  486. scitex/reproduce/_fix_seeds.py +45 -0
  487. scitex/reproduce/_gen_ID.py +55 -0
  488. scitex/reproduce/_gen_timestamp.py +35 -0
  489. scitex/res/__init__.py +5 -0
  490. scitex/resource/__init__.py +13 -0
  491. scitex/resource/_get_processor_usages.py +281 -0
  492. scitex/resource/_get_specs.py +280 -0
  493. scitex/resource/_log_processor_usages.py +190 -0
  494. scitex/resource/_utils/__init__.py +31 -0
  495. scitex/resource/_utils/_get_env_info.py +481 -0
  496. scitex/resource/limit_ram.py +33 -0
  497. scitex/scholar/__init__.py +24 -0
  498. scitex/scholar/_local_search.py +454 -0
  499. scitex/scholar/_paper.py +244 -0
  500. scitex/scholar/_pdf_downloader.py +325 -0
  501. scitex/scholar/_search.py +393 -0
  502. scitex/scholar/_vector_search.py +370 -0
  503. scitex/scholar/_web_sources.py +457 -0
  504. scitex/stats/__init__.py +31 -0
  505. scitex/stats/_calc_partial_corr.py +17 -0
  506. scitex/stats/_corr_test_multi.py +94 -0
  507. scitex/stats/_corr_test_wrapper.py +115 -0
  508. scitex/stats/_describe_wrapper.py +90 -0
  509. scitex/stats/_multiple_corrections.py +63 -0
  510. scitex/stats/_nan_stats.py +93 -0
  511. scitex/stats/_p2stars.py +116 -0
  512. scitex/stats/_p2stars_wrapper.py +56 -0
  513. scitex/stats/_statistical_tests.py +73 -0
  514. scitex/stats/desc/__init__.py +40 -0
  515. scitex/stats/desc/_describe.py +189 -0
  516. scitex/stats/desc/_nan.py +289 -0
  517. scitex/stats/desc/_real.py +94 -0
  518. scitex/stats/multiple/__init__.py +14 -0
  519. scitex/stats/multiple/_bonferroni_correction.py +72 -0
  520. scitex/stats/multiple/_fdr_correction.py +400 -0
  521. scitex/stats/multiple/_multicompair.py +28 -0
  522. scitex/stats/tests/__corr_test.py +277 -0
  523. scitex/stats/tests/__corr_test_multi.py +343 -0
  524. scitex/stats/tests/__corr_test_single.py +277 -0
  525. scitex/stats/tests/__init__.py +22 -0
  526. scitex/stats/tests/_brunner_munzel_test.py +192 -0
  527. scitex/stats/tests/_nocorrelation_test.py +28 -0
  528. scitex/stats/tests/_smirnov_grubbs.py +98 -0
  529. scitex/str/__init__.py +113 -0
  530. scitex/str/_clean_path.py +75 -0
  531. scitex/str/_color_text.py +52 -0
  532. scitex/str/_decapitalize.py +58 -0
  533. scitex/str/_factor_out_digits.py +281 -0
  534. scitex/str/_format_plot_text.py +498 -0
  535. scitex/str/_grep.py +48 -0
  536. scitex/str/_latex.py +155 -0
  537. scitex/str/_latex_fallback.py +471 -0
  538. scitex/str/_mask_api.py +39 -0
  539. scitex/str/_mask_api_key.py +8 -0
  540. scitex/str/_parse.py +158 -0
  541. scitex/str/_print_block.py +47 -0
  542. scitex/str/_print_debug.py +68 -0
  543. scitex/str/_printc.py +62 -0
  544. scitex/str/_readable_bytes.py +38 -0
  545. scitex/str/_remove_ansi.py +23 -0
  546. scitex/str/_replace.py +134 -0
  547. scitex/str/_search.py +125 -0
  548. scitex/str/_squeeze_space.py +36 -0
  549. scitex/tex/__init__.py +10 -0
  550. scitex/tex/_preview.py +103 -0
  551. scitex/tex/_to_vec.py +116 -0
  552. scitex/torch/__init__.py +18 -0
  553. scitex/torch/_apply_to.py +34 -0
  554. scitex/torch/_nan_funcs.py +77 -0
  555. scitex/types/_ArrayLike.py +44 -0
  556. scitex/types/_ColorLike.py +21 -0
  557. scitex/types/__init__.py +14 -0
  558. scitex/types/_is_listed_X.py +70 -0
  559. scitex/utils/__init__.py +22 -0
  560. scitex/utils/_compress_hdf5.py +116 -0
  561. scitex/utils/_email.py +120 -0
  562. scitex/utils/_grid.py +148 -0
  563. scitex/utils/_notify.py +247 -0
  564. scitex/utils/_search.py +121 -0
  565. scitex/web/__init__.py +38 -0
  566. scitex/web/_search_pubmed.py +438 -0
  567. scitex/web/_summarize_url.py +158 -0
  568. scitex-2.0.0.dist-info/METADATA +307 -0
  569. scitex-2.0.0.dist-info/RECORD +572 -0
  570. scitex-2.0.0.dist-info/WHEEL +6 -0
  571. scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
  572. scitex-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,131 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-12-12 06:49:15 (ywatanabe)"
4
+ # File: ./scitex_repo/src/scitex/ai/ClassifierServer.py
5
+
6
"""
Functionality:
    * Provides a unified interface for initializing various scikit-learn classifiers
    * Supports optional preprocessing with a user-supplied scaler
      (e.g. sklearn.preprocessing.StandardScaler), combined via a pipeline

Input:
    * Classifier name as string
    * Optional class weights for imbalanced datasets
    * Optional scaler for feature preprocessing

Output:
    * Initialized (unfitted) classifier, or a pipeline with the scaler in front

Prerequisites:
    * scikit-learn
"""

# Placed after the module docstring so the text above is exposed as __doc__.
THIS_FILE = (
    "/data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/ai/ClassifierServer.py"
)
27
+
28
+ from typing import Dict, List, Optional, Union
29
+
30
+ from sklearn.base import BaseEstimator as _BaseEstimator
31
+ from sklearn.discriminant_analysis import (
32
+ QuadraticDiscriminantAnalysis as _QuadraticDiscriminantAnalysis,
33
+ )
34
+ from sklearn.ensemble import AdaBoostClassifier as _AdaBoostClassifier
35
+ from sklearn.gaussian_process import (
36
+ GaussianProcessClassifier as _GaussianProcessClassifier,
37
+ )
38
+ from sklearn.linear_model import LogisticRegression as _LogisticRegression
39
+ from sklearn.linear_model import (
40
+ PassiveAggressiveClassifier as _PassiveAggressiveClassifier,
41
+ )
42
+ from sklearn.linear_model import Perceptron as _Perceptron
43
+ from sklearn.linear_model import RidgeClassifier as _RidgeClassifier
44
+ from sklearn.linear_model import SGDClassifier as _SGDClassifier
45
+ from sklearn.neighbors import KNeighborsClassifier as _KNeighborsClassifier
46
+ from sklearn.pipeline import Pipeline as _Pipeline
47
+ from sklearn.pipeline import make_pipeline as _make_pipeline
48
+ from sklearn.preprocessing import StandardScaler as _StandardScaler
49
+ from sklearn.svm import SVC as _SVC
50
+ from sklearn.svm import LinearSVC as _LinearSVC
51
+
52
+
53
class ClassifierServer:
    """
    Server for initializing various scikit-learn classifiers with a consistent interface.

    Template estimators are built once in ``__init__`` and stored in
    ``clf_candi``. Every call returns a fresh, unfitted clone of the requested
    template, so estimators obtained from different calls never share fitted
    state with each other or with the stored templates.

    Example
    -------
    >>> from sklearn.preprocessing import StandardScaler
    >>> clf_server = ClassifierServer(class_weight={0: 1.0, 1: 2.0}, random_state=42)
    >>> clf = clf_server("SVC", scaler=StandardScaler())
    >>> print(clf_server.list)
    ['Perceptron', 'PassiveAggressiveClassifier', 'LogisticRegression', ...]

    Parameters
    ----------
    class_weight : Optional[Union[Dict[int, float], str]]
        Class weights for handling imbalanced datasets: a ``{label: weight}``
        dict, or ``"balanced"``. Forwarded to Perceptron,
        PassiveAggressiveClassifier, LogisticRegression, SGDClassifier,
        RidgeClassifier, LinearSVC and SVC. The remaining classifiers do not
        receive it.
    random_state : int
        Random seed for reproducibility. Forwarded to every registered
        classifier except QuadraticDiscriminantAnalysis and
        KNeighborsClassifier.
    """

    def __init__(
        self,
        class_weight: Optional[Union[Dict[int, float], str]] = None,
        random_state: int = 42,
    ):
        self.class_weight = class_weight
        self.random_state = random_state

        # Name -> unfitted template estimator. ``__call__`` hands out clones
        # of these, never the template objects themselves.
        self.clf_candi = {
            "Perceptron": _Perceptron(
                penalty="l2",
                class_weight=self.class_weight,
                random_state=random_state,
            ),
            "PassiveAggressiveClassifier": _PassiveAggressiveClassifier(
                class_weight=self.class_weight, random_state=random_state
            ),
            "LogisticRegression": _LogisticRegression(
                class_weight=self.class_weight, random_state=random_state
            ),
            "SGDClassifier": _SGDClassifier(
                class_weight=self.class_weight, random_state=random_state
            ),
            "RidgeClassifier": _RidgeClassifier(
                class_weight=self.class_weight, random_state=random_state
            ),
            "QuadraticDiscriminantAnalysis": _QuadraticDiscriminantAnalysis(),
            "GaussianProcessClassifier": _GaussianProcessClassifier(
                random_state=random_state
            ),
            "KNeighborsClassifier": _KNeighborsClassifier(),
            "AdaBoostClassifier": _AdaBoostClassifier(random_state=random_state),
            "LinearSVC": _LinearSVC(
                class_weight=self.class_weight, random_state=random_state
            ),
            "SVC": _SVC(class_weight=self.class_weight, random_state=random_state),
        }

    def __call__(
        self, clf_str: str, scaler: Optional[_BaseEstimator] = None
    ) -> Union[_BaseEstimator, _Pipeline]:
        """
        Return a fresh, unfitted classifier, optionally preceded by a scaler.

        Parameters
        ----------
        clf_str : str
            Name of the classifier; must be one of ``self.list``.
        scaler : Optional[BaseEstimator]
            Transformer placed in front of the classifier via
            ``make_pipeline``. It is used exactly as passed (not cloned).

        Returns
        -------
        BaseEstimator or Pipeline
            A clone of the stored template classifier, or
            ``make_pipeline(scaler, <clone>)`` when ``scaler`` is given.

        Raises
        ------
        ValueError
            If ``clf_str`` is not a registered classifier name.
        """
        # Local import keeps the module-level import block untouched.
        from sklearn.base import clone as _clone

        if clf_str not in self.clf_candi:
            raise ValueError(
                f"Unknown classifier: {clf_str}. Available options: {self.list}"
            )

        # Clone so that fitting one returned estimator (or pipeline) cannot
        # overwrite the fitted state of another obtained from a separate call.
        clf = _clone(self.clf_candi[clf_str])
        if scaler is not None:
            return _make_pipeline(scaler, clf)
        return clf

    @property
    def list(self) -> List[str]:
        """Names of the registered classifiers, in registration order."""
        # Inside a method body, ``list`` resolves to the builtin; class-level
        # names (including this property) are not in scope here.
        return list(self.clf_candi.keys())
127
+
128
+
129
if __name__ == "__main__":
    # Smoke run: build an SVC wrapped behind a standard scaler.
    clf_server = ClassifierServer()
    demo_scaler = _StandardScaler()
    clf = clf_server("SVC", scaler=demo_scaler)
@@ -0,0 +1,11 @@
1
+ #!/usr/bin/env python3
2
+ """Scitex clustering module."""
3
+
4
+ from ._pca import pca
5
+ from ._umap import main, umap
6
+
7
# Public API of the clustering subpackage; ``main`` is an alias of ``umap``.
__all__ = ["main", "pca", "umap"]
@@ -0,0 +1,115 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-05-14 00:58:26 (ywatanabe)"
4
+
5
+ import matplotlib.pyplot as plt
6
+ import scitex
7
+ import numpy as np
8
+ import seaborn as sns
9
+ from natsort import natsorted
10
+ from sklearn.decomposition import PCA
11
+ from sklearn.preprocessing import LabelEncoder
12
+
13
+
14
def pca(
    data_all,
    labels_all,
    axes_titles=None,
    title="PCA Clustering",
    alpha=0.1,
    s=3,
    use_independent_legend=False,
    add_super_imposed=False,
    palette="viridis",
):
    """Project datasets onto two principal components and plot them.

    The PCA model is fitted on the first dataset only; every dataset
    (including the first) is then projected with that model, so all panels
    share one coordinate system.

    Parameters
    ----------
    data_all : sequence of array-like
        Datasets to project, one panel each.
    labels_all : sequence of array-like
        Label arrays, one per dataset in ``data_all``.
    axes_titles : sequence of str, optional
        Title for each per-dataset panel.
    title : str
        Figure super-title.
    alpha : float
        Marker transparency.
    s : float
        Marker size.
    use_independent_legend : bool
        If True, each panel's legend is copied into its own figure (also
        written to ``legend_{i}.png`` in the current working directory) and
        detached from the panel.
    add_super_imposed : bool
        If True, panel 0 overlays all datasets.
    palette : str
        Seaborn palette name.

    Returns
    -------
    fig : matplotlib.figure.Figure
    legend_figs : list of Figure or None
        ``None`` unless ``use_independent_legend`` is True.
    pca_model : sklearn.decomposition.PCA
        The model fitted on ``data_all[0]``.
    """
    # Materialize once so tuples and other iterables behave like lists.
    data_all = list(data_all)
    labels_all = list(labels_all)
    assert len(data_all) == len(labels_all)

    # Encode over the union of all label sets (natural sort order) so every
    # panel uses the same label mapping.
    le = LabelEncoder()
    le.fit(natsorted(np.hstack(labels_all)))
    labels_all = [le.transform(labels) for labels in labels_all]

    pca_model = PCA(n_components=2)

    ncols = len(data_all) + 1 if add_super_imposed else len(data_all)
    share = ncols > 1
    fig, axes = plt.subplots(ncols=ncols, sharex=share, sharey=share)
    # plt.subplots returns a bare Axes when ncols == 1; normalize to a list so
    # indexing and iteration below work for any panel count.
    axes_list = [axes] if ncols == 1 else list(axes)

    fig.suptitle(title)
    fig.supxlabel("PCA 1")
    fig.supylabel("PCA 2")

    for ii, (data, labels) in enumerate(zip(data_all, labels_all)):
        if ii == 0:
            pca_model.fit(data)
        embedding = pca_model.transform(data)
        label_names = le.inverse_transform(labels)

        ax = axes_list[ii + 1] if add_super_imposed else axes_list[ii]
        sns.scatterplot(
            x=embedding[:, 0],
            y=embedding[:, 1],
            hue=label_names,
            ax=ax,
            palette=palette,
            s=s,
            alpha=alpha,
        )
        ax.set_box_aspect(1)

        if axes_titles is not None:
            ax.set_title(axes_titles[ii])

        if not use_independent_legend:
            ax.legend(loc="upper left")

        if add_super_imposed:
            super_ax = axes_list[0]
            super_ax.set_title("Superimposed")
            super_ax.set_aspect("equal")
            sns.scatterplot(
                x=embedding[:, 0],
                y=embedding[:, 1],
                hue=label_names,
                ax=super_ax,
                palette=palette,
                # Only the first dataset contributes legend entries.
                legend="full" if ii == 0 else False,
                s=s,
                alpha=alpha,
            )

    if not use_independent_legend:
        return fig, None, pca_model

    legend_figs = []
    for i, ax in enumerate(axes_list):
        legend = ax.get_legend()
        if legend is None:
            continue
        # Matplotlib >= 3.7 exposes ``legend_handles``; ``legendHandles`` was
        # the older name and is gone in 3.9.
        handles = getattr(legend, "legend_handles", None)
        if handles is None:
            handles = legend.legendHandles
        legend_fig = plt.figure(figsize=(3, 2))
        legend_fig.gca().legend(
            handles=handles,
            # Legend labels must be strings, not the Text artists themselves.
            labels=[text.get_text() for text in legend.texts],
            loc="center",
        )
        legend_fig.canvas.draw()
        legend_fig.savefig(f"legend_{i}.png", bbox_inches="tight")
        legend_figs.append(legend_fig)
        plt.close(legend_fig)

    for ax in axes_list:
        ax.legend_ = None
    return fig, legend_figs, pca_model
@@ -0,0 +1,376 @@
1
+ #!./env/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-09-12 05:37:55 (ywatanabe)"
4
+ # _umap_dev.py
5
+
6
+
7
+ """
8
+ UMAP-based clustering and visualization utilities.
9
+ """
10
+
11
+
12
+ """
13
+ Imports
14
+ """
15
+ import sys
16
+
17
+ import matplotlib.pyplot as plt
18
+ import scitex
19
+ import numpy as np
20
+ import umap.umap_ as umap_orig
21
+ from natsort import natsorted
22
+ from sklearn.preprocessing import LabelEncoder
23
+
24
+ # sys.path = ["."] + sys.path
25
+ # from scripts import utils, load
26
+
27
+ """
28
+ Warnings
29
+ """
30
+ # warnings.simplefilter("ignore", UserWarning)
31
+
32
+
33
+ """
34
+ Config
35
+ """
36
+ # CONFIG = scitex.gen.load_configs()
37
+
38
+
39
+ """
40
+ Functions & Classes
41
+ """
42
+
43
+
44
def umap(
    data,
    labels,
    hues=None,
    hues_colors=None,
    axes=None,
    axes_titles=None,
    supervised=False,
    title="UMAP Clustering",
    alpha=1.0,
    s=3,
    use_independent_legend=False,
    add_super_imposed=False,
    umap_model=None,
):
    """Perform UMAP clustering and visualization.

    Parameters
    ----------
    data : list of array-like
        Datasets to embed (must be a ``list``). Unless ``umap_model`` is
        given, a UMAP model is fitted on ``data[0]`` and then used to
        transform every dataset.
    labels : list of array-like
        Label arrays, one per dataset (must be a ``list``).
    hues : list, optional
        Per-dataset arrays used to group points for plotting; a ``None``
        entry falls back to that dataset's labels.
    hues_colors : list, optional
        Per-dataset per-point colors (stacked with ``np.vstack``); a ``None``
        entry uses matplotlib's default colors.
    axes : array of Axes, optional
        Existing axes to plot on; its length must equal the number of panels
        (number of datasets, plus one if ``add_super_imposed``).
    axes_titles : list of str, optional
        Titles for each per-dataset panel.
    supervised : bool, optional
        If True, ``labels[0]`` is passed as the target when fitting UMAP.
        Ignored when ``umap_model`` is given.
    title : str, optional
        Main title for the figure.
    alpha : float, optional
        Transparency of points.
    s : int, optional
        Size of points.
    use_independent_legend : bool, optional
        Whether to move legends into separate figures.
    add_super_imposed : bool, optional
        Whether to add a panel overlaying all datasets.
    umap_model : umap.UMAP, optional
        Pre-fitted UMAP model, used as-is.

    Returns
    -------
    tuple
        ``(fig, legend_figs, umap_model)``; ``legend_figs`` is ``None``
        unless ``use_independent_legend`` is True.
    """
    data_all, labels_all, hues_all, hues_colors_all = _check_input_vars(
        data, labels, hues, hues_colors
    )

    # Encode over the union of all label sets (natural sort order) so every
    # panel shares the same label mapping.
    le = LabelEncoder()
    le.fit(natsorted(np.hstack(labels_all)))
    labels_all = [le.transform(lbl) for lbl in labels_all]

    _umap = _run_umap(umap_model, data_all, labels_all, supervised, title)

    fig, legend_figs = _plot(
        _umap,
        le,
        data_all,
        labels_all,
        hues_all,
        hues_colors_all,
        add_super_imposed,
        axes,
        title,
        axes_titles,
        use_independent_legend,
        s,
        alpha,
    )
    return fig, legend_figs, _umap
133
+
134
+
135
def _plot(
    _umap,
    le,
    data_all,
    labels_all,
    hues_all,
    hues_colors_all,
    add_super_imposed,
    axes,
    title,
    axes_titles,
    use_independent_legend,
    s,
    alpha,
):
    """Draw the UMAP embedding of every dataset, one panel per dataset.

    Parameters
    ----------
    _umap : fitted UMAP model
        Used through ``_umap.transform`` to embed each dataset.
    le : sklearn.preprocessing.LabelEncoder
        Fitted encoder; ``inverse_transform`` recovers label names when a
        dataset has no explicit hues.
    data_all, labels_all, hues_all, hues_colors_all : list
        Per-dataset inputs as normalized by ``_check_input_vars``.
    add_super_imposed : bool
        If True, panel 0 overlays all datasets and per-dataset panels start
        at index 1.
    axes : array of axes or None
        Axes to draw on; created with ``scitex.plt.subplots`` when None.
    title : str
        Figure title.
    axes_titles : list of str or None
        Per-dataset panel titles.
    use_independent_legend : bool
        If True, each panel's legend entries are moved to a separate figure.
    s, alpha : float
        Marker size and transparency.

    Returns
    -------
    fig : Figure
    legend_figs : list of Figure or None
        ``None`` unless ``use_independent_legend`` is True.
    """
    ncols = len(data_all) + 1 if add_super_imposed else len(data_all)
    share = ncols > 1

    if axes is None:
        fig, axes = scitex.plt.subplots(ncols=ncols, sharex=share, sharey=share)
    else:
        assert len(axes) == ncols
        if isinstance(
            axes, (np.ndarray, scitex.plt._subplots._AxesWrapper.AxesWrapper)
        ):
            fig = axes[0].get_figure()
        else:
            fig = axes.get_figure()

    fig.supxyt("UMAP 1", "UMAP 2", title)

    for ii, (data, labels, hues, hues_colors) in enumerate(
        zip(data_all, labels_all, hues_all, hues_colors_all)
    ):
        embedding = _umap.transform(data)

        if ncols == 1:
            ax = axes
        else:
            ax = axes[ii + 1] if add_super_imposed else axes[ii]

        hue_values = le.inverse_transform(labels) if hues is None else hues
        # ``hues_colors`` may be a list or an ndarray; ``if hues_colors`` is
        # ambiguous for arrays, so test for None/emptiness explicitly.
        if hues_colors is not None and len(hues_colors) > 0:
            point_colors = np.vstack(hues_colors)
        else:
            point_colors = None

        _scatter_by_hue(ax, embedding, hue_values, point_colors, s, alpha)
        ax.set_box_aspect(1)

        if axes_titles is not None:
            ax.set_title(axes_titles[ii])

        if add_super_imposed:
            super_ax = axes[0]
            _scatter_by_hue(
                super_ax,
                embedding,
                hue_values,
                point_colors,
                s,
                alpha,
                one_color_per_hue=True,
            )
            super_ax.set_title("Superimposed")
            super_ax.set_box_aspect(1)

    if share:
        scitex.plt.ax.sharex(axes)
        scitex.plt.ax.sharey(axes)

    # A single panel is a bare axis object; normalize to a flat list.
    axes_list = [axes] if ncols == 1 else list(getattr(axes, "flat", axes))

    if not use_independent_legend:
        for ax in axes_list:
            ax.legend(loc="upper left")
        return fig, None

    legend_figs = []
    for ax in axes_list:
        # Read entries straight from the plotted artists. No legend was
        # created in this branch, and ``Legend.get_lines()`` returns no
        # handles for scatter collections.
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            legend_fig = plt.figure(figsize=(3, 2))
            legend_fig.gca().legend(handles=handles, labels=labels, loc="center")
            legend_figs.append(legend_fig)
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    return fig, legend_figs


def _scatter_by_hue(
    ax, embedding, hue_values, point_colors, s, alpha, one_color_per_hue=False
):
    """Scatter a 2-D embedding on ``ax``, one labelled collection per hue.

    ``point_colors`` is either None (matplotlib default colors) or an array
    with one color row per point. With ``one_color_per_hue`` the first color
    of each group is used for the whole group.
    """
    hue_array = np.array(hue_values)
    for hue in np.unique(hue_values):
        mask = hue == hue_array
        if point_colors is None:
            colors = None
        elif one_color_per_hue:
            # Keep the single color 2-D so matplotlib treats it as one RGB(A)
            # color rather than as values to colormap (which it would do when
            # the group size happens to equal the row length).
            colors = np.atleast_2d(point_colors[mask][0])
        else:
            colors = list(point_colors[mask])
        ax.scatter(
            x=embedding[:, 0][mask],
            y=embedding[:, 1][mask],
            label=hue,
            c=colors,
            s=s,
            alpha=alpha,
        )
287
+
288
+
289
def _run_umap(umap_model, data_all, labels_all, supervised, title):
    """Return a fitted UMAP model, fitting a new one only when none is given.

    When ``umap_model`` is None, ``UMAP(random_state=42)`` is fitted on
    ``data_all[0]``, using ``labels_all[0]`` as the target if ``supervised``.
    A provided model is returned unchanged and ``supervised`` is ignored.
    ``title`` is accepted for call compatibility but is not used.
    """
    if umap_model is not None:
        return umap_model

    target = labels_all[0] if supervised else None
    return umap_orig.UMAP(random_state=42).fit(data_all[0], y=target)
300
+
301
+
302
def _check_input_vars(data_all, labels_all, hues_all, hues_colors_all):
    """Fill in optional per-dataset inputs and validate them.

    ``hues_all`` and ``hues_colors_all`` default to one ``None`` per dataset.
    All four arguments must be lists of equal length; otherwise an
    ``AssertionError`` is raised.

    Returns
    -------
    tuple
        ``(data_all, labels_all, hues_all, hues_colors_all)``.
    """
    n_datasets = len(data_all)
    if hues_all is None:
        hues_all = [None] * n_datasets
    if hues_colors_all is None:
        hues_colors_all = [None] * n_datasets

    assert n_datasets == len(labels_all) == len(hues_all) == len(hues_colors_all)

    inputs = (data_all, labels_all, hues_all, hues_colors_all)
    assert all(isinstance(arg, list) for arg in inputs)
    return inputs
319
+
320
+
321
def _test(dataset_str="iris"):
    """Smoke-test ``umap`` on a scikit-learn toy dataset and save the figures.

    Parameters
    ----------
    dataset_str : {"iris", "mnist"}
        ``"mnist"`` selects scikit-learn's small ``load_digits`` dataset, not
        the full MNIST dataset.
    """
    from sklearn.datasets import load_digits, load_iris
    from sklearn.model_selection import train_test_split

    loaders = {"iris": load_iris, "mnist": load_digits}
    dataset = loaders[dataset_str]()
    X, y = dataset.data, dataset.target

    # Split into two halves; UMAP is fitted on the first, both are plotted.
    X1, X2, y1, y2 = train_test_split(X, y, test_size=0.5, random_state=42)

    fig, legend_figs, _ = umap(
        data=[X1, X2],
        labels=[y1, y2],
        axes_titles=[f"{dataset_str} Set 1", f"{dataset_str} Set 2"],
        supervised=True,
        title=dataset_str,
        use_independent_legend=True,
        s=10,
    )

    scitex.io.save(fig, f"/tmp/scitex/umap/{dataset_str}.jpg")

    # Save legend figures if any were produced.
    for i, leg_fig in enumerate(legend_figs or []):
        scitex.io.save(leg_fig, f"/tmp/scitex/umap/{dataset_str}_legend_{i}.jpg")
356
+
357
+
358
# ``main`` is kept as an alias of ``umap`` (the package imports both names).
main = umap

if __name__ == "__main__":
    # Run the digits smoke test inside a scitex session (started with
    # agg=True); the session rebinds sys.stdout/sys.stderr and plt.
    session = scitex.gen.start(sys, plt, verbose=False, agg=True)
    CONFIG, sys.stdout, sys.stderr, plt, CC = session
    _test(dataset_str="mnist")
    scitex.gen.close(CONFIG, verbose=False, notify=False)
375
+
376
+ # EOF