scitex 2.0.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (572)
  1. scitex/__init__.py +73 -0
  2. scitex/__main__.py +89 -0
  3. scitex/__version__.py +14 -0
  4. scitex/_sh.py +59 -0
  5. scitex/ai/_LearningCurveLogger.py +583 -0
  6. scitex/ai/__Classifiers.py +101 -0
  7. scitex/ai/__init__.py +55 -0
  8. scitex/ai/_gen_ai/_Anthropic.py +173 -0
  9. scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
  10. scitex/ai/_gen_ai/_DeepSeek.py +175 -0
  11. scitex/ai/_gen_ai/_Google.py +161 -0
  12. scitex/ai/_gen_ai/_Groq.py +97 -0
  13. scitex/ai/_gen_ai/_Llama.py +142 -0
  14. scitex/ai/_gen_ai/_OpenAI.py +230 -0
  15. scitex/ai/_gen_ai/_PARAMS.py +565 -0
  16. scitex/ai/_gen_ai/_Perplexity.py +191 -0
  17. scitex/ai/_gen_ai/__init__.py +32 -0
  18. scitex/ai/_gen_ai/_calc_cost.py +78 -0
  19. scitex/ai/_gen_ai/_format_output_func.py +183 -0
  20. scitex/ai/_gen_ai/_genai_factory.py +71 -0
  21. scitex/ai/act/__init__.py +8 -0
  22. scitex/ai/act/_define.py +11 -0
  23. scitex/ai/classification/__init__.py +7 -0
  24. scitex/ai/classification/classification_reporter.py +1137 -0
  25. scitex/ai/classification/classifier_server.py +131 -0
  26. scitex/ai/classification/classifiers.py +101 -0
  27. scitex/ai/classification_reporter.py +1161 -0
  28. scitex/ai/classifier_server.py +131 -0
  29. scitex/ai/clustering/__init__.py +11 -0
  30. scitex/ai/clustering/_pca.py +115 -0
  31. scitex/ai/clustering/_umap.py +376 -0
  32. scitex/ai/early_stopping.py +149 -0
  33. scitex/ai/feature_extraction/__init__.py +56 -0
  34. scitex/ai/feature_extraction/vit.py +148 -0
  35. scitex/ai/genai/__init__.py +277 -0
  36. scitex/ai/genai/anthropic.py +177 -0
  37. scitex/ai/genai/anthropic_provider.py +320 -0
  38. scitex/ai/genai/anthropic_refactored.py +109 -0
  39. scitex/ai/genai/auth_manager.py +200 -0
  40. scitex/ai/genai/base_genai.py +336 -0
  41. scitex/ai/genai/base_provider.py +291 -0
  42. scitex/ai/genai/calc_cost.py +78 -0
  43. scitex/ai/genai/chat_history.py +307 -0
  44. scitex/ai/genai/cost_tracker.py +276 -0
  45. scitex/ai/genai/deepseek.py +188 -0
  46. scitex/ai/genai/deepseek_provider.py +251 -0
  47. scitex/ai/genai/format_output_func.py +183 -0
  48. scitex/ai/genai/genai_factory.py +71 -0
  49. scitex/ai/genai/google.py +169 -0
  50. scitex/ai/genai/google_provider.py +228 -0
  51. scitex/ai/genai/groq.py +104 -0
  52. scitex/ai/genai/groq_provider.py +248 -0
  53. scitex/ai/genai/image_processor.py +250 -0
  54. scitex/ai/genai/llama.py +155 -0
  55. scitex/ai/genai/llama_provider.py +214 -0
  56. scitex/ai/genai/mock_provider.py +127 -0
  57. scitex/ai/genai/model_registry.py +304 -0
  58. scitex/ai/genai/openai.py +230 -0
  59. scitex/ai/genai/openai_provider.py +293 -0
  60. scitex/ai/genai/params.py +565 -0
  61. scitex/ai/genai/perplexity.py +202 -0
  62. scitex/ai/genai/perplexity_provider.py +205 -0
  63. scitex/ai/genai/provider_base.py +302 -0
  64. scitex/ai/genai/provider_factory.py +370 -0
  65. scitex/ai/genai/response_handler.py +235 -0
  66. scitex/ai/layer/_Pass.py +21 -0
  67. scitex/ai/layer/__init__.py +10 -0
  68. scitex/ai/layer/_switch.py +8 -0
  69. scitex/ai/loss/_L1L2Losses.py +34 -0
  70. scitex/ai/loss/__init__.py +12 -0
  71. scitex/ai/loss/multi_task_loss.py +47 -0
  72. scitex/ai/metrics/__init__.py +9 -0
  73. scitex/ai/metrics/_bACC.py +51 -0
  74. scitex/ai/metrics/silhoute_score_block.py +496 -0
  75. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
  76. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
  77. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
  78. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
  79. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
  80. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
  81. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
  82. scitex/ai/optim/__init__.py +13 -0
  83. scitex/ai/optim/_get_set.py +31 -0
  84. scitex/ai/optim/_optimizers.py +71 -0
  85. scitex/ai/plt/__init__.py +21 -0
  86. scitex/ai/plt/_conf_mat.py +592 -0
  87. scitex/ai/plt/_learning_curve.py +194 -0
  88. scitex/ai/plt/_optuna_study.py +111 -0
  89. scitex/ai/plt/aucs/__init__.py +2 -0
  90. scitex/ai/plt/aucs/example.py +60 -0
  91. scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
  92. scitex/ai/plt/aucs/roc_auc.py +246 -0
  93. scitex/ai/sampling/undersample.py +29 -0
  94. scitex/ai/sk/__init__.py +11 -0
  95. scitex/ai/sk/_clf.py +58 -0
  96. scitex/ai/sk/_to_sktime.py +100 -0
  97. scitex/ai/sklearn/__init__.py +26 -0
  98. scitex/ai/sklearn/clf.py +58 -0
  99. scitex/ai/sklearn/to_sktime.py +100 -0
  100. scitex/ai/training/__init__.py +7 -0
  101. scitex/ai/training/early_stopping.py +150 -0
  102. scitex/ai/training/learning_curve_logger.py +555 -0
  103. scitex/ai/utils/__init__.py +22 -0
  104. scitex/ai/utils/_check_params.py +50 -0
  105. scitex/ai/utils/_default_dataset.py +46 -0
  106. scitex/ai/utils/_format_samples_for_sktime.py +26 -0
  107. scitex/ai/utils/_label_encoder.py +134 -0
  108. scitex/ai/utils/_merge_labels.py +22 -0
  109. scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
  110. scitex/ai/utils/_under_sample.py +51 -0
  111. scitex/ai/utils/_verify_n_gpus.py +16 -0
  112. scitex/ai/utils/grid_search.py +148 -0
  113. scitex/context/__init__.py +9 -0
  114. scitex/context/_suppress_output.py +38 -0
  115. scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
  116. scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
  117. scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
  118. scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
  119. scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
  120. scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
  121. scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
  122. scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
  123. scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
  124. scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
  125. scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
  126. scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
  127. scitex/db/_BaseMixins/__init__.py +30 -0
  128. scitex/db/_PostgreSQL.py +126 -0
  129. scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
  130. scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
  131. scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
  132. scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
  133. scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
  134. scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
  135. scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
  136. scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
  137. scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
  138. scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
  139. scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
  140. scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
  141. scitex/db/_PostgreSQLMixins/__init__.py +34 -0
  142. scitex/db/_SQLite3.py +2136 -0
  143. scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
  144. scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
  145. scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
  146. scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
  147. scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
  148. scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
  149. scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
  150. scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
  151. scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
  152. scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
  153. scitex/db/_SQLite3Mixins/__init__.py +30 -0
  154. scitex/db/__init__.py +14 -0
  155. scitex/db/_delete_duplicates.py +397 -0
  156. scitex/db/_inspect.py +163 -0
  157. scitex/decorators/__init__.py +54 -0
  158. scitex/decorators/_auto_order.py +172 -0
  159. scitex/decorators/_batch_fn.py +127 -0
  160. scitex/decorators/_cache_disk.py +32 -0
  161. scitex/decorators/_cache_mem.py +12 -0
  162. scitex/decorators/_combined.py +98 -0
  163. scitex/decorators/_converters.py +282 -0
  164. scitex/decorators/_deprecated.py +26 -0
  165. scitex/decorators/_not_implemented.py +30 -0
  166. scitex/decorators/_numpy_fn.py +86 -0
  167. scitex/decorators/_pandas_fn.py +121 -0
  168. scitex/decorators/_preserve_doc.py +19 -0
  169. scitex/decorators/_signal_fn.py +95 -0
  170. scitex/decorators/_timeout.py +55 -0
  171. scitex/decorators/_torch_fn.py +136 -0
  172. scitex/decorators/_wrap.py +39 -0
  173. scitex/decorators/_xarray_fn.py +88 -0
  174. scitex/dev/__init__.py +15 -0
  175. scitex/dev/_analyze_code_flow.py +284 -0
  176. scitex/dev/_reload.py +59 -0
  177. scitex/dict/_DotDict.py +442 -0
  178. scitex/dict/__init__.py +18 -0
  179. scitex/dict/_listed_dict.py +42 -0
  180. scitex/dict/_pop_keys.py +36 -0
  181. scitex/dict/_replace.py +13 -0
  182. scitex/dict/_safe_merge.py +62 -0
  183. scitex/dict/_to_str.py +32 -0
  184. scitex/dsp/__init__.py +72 -0
  185. scitex/dsp/_crop.py +122 -0
  186. scitex/dsp/_demo_sig.py +331 -0
  187. scitex/dsp/_detect_ripples.py +212 -0
  188. scitex/dsp/_ensure_3d.py +18 -0
  189. scitex/dsp/_hilbert.py +78 -0
  190. scitex/dsp/_listen.py +702 -0
  191. scitex/dsp/_misc.py +30 -0
  192. scitex/dsp/_mne.py +32 -0
  193. scitex/dsp/_modulation_index.py +79 -0
  194. scitex/dsp/_pac.py +319 -0
  195. scitex/dsp/_psd.py +102 -0
  196. scitex/dsp/_resample.py +65 -0
  197. scitex/dsp/_time.py +36 -0
  198. scitex/dsp/_transform.py +68 -0
  199. scitex/dsp/_wavelet.py +212 -0
  200. scitex/dsp/add_noise.py +111 -0
  201. scitex/dsp/example.py +253 -0
  202. scitex/dsp/filt.py +155 -0
  203. scitex/dsp/norm.py +18 -0
  204. scitex/dsp/params.py +51 -0
  205. scitex/dsp/reference.py +43 -0
  206. scitex/dsp/template.py +25 -0
  207. scitex/dsp/utils/__init__.py +15 -0
  208. scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
  209. scitex/dsp/utils/_ensure_3d.py +18 -0
  210. scitex/dsp/utils/_ensure_even_len.py +10 -0
  211. scitex/dsp/utils/_zero_pad.py +48 -0
  212. scitex/dsp/utils/filter.py +408 -0
  213. scitex/dsp/utils/pac.py +177 -0
  214. scitex/dt/__init__.py +8 -0
  215. scitex/dt/_linspace.py +130 -0
  216. scitex/etc/__init__.py +15 -0
  217. scitex/etc/wait_key.py +34 -0
  218. scitex/gen/_DimHandler.py +196 -0
  219. scitex/gen/_TimeStamper.py +244 -0
  220. scitex/gen/__init__.py +95 -0
  221. scitex/gen/_alternate_kwarg.py +13 -0
  222. scitex/gen/_cache.py +11 -0
  223. scitex/gen/_check_host.py +34 -0
  224. scitex/gen/_ci.py +12 -0
  225. scitex/gen/_close.py +222 -0
  226. scitex/gen/_embed.py +78 -0
  227. scitex/gen/_inspect_module.py +257 -0
  228. scitex/gen/_is_ipython.py +12 -0
  229. scitex/gen/_less.py +48 -0
  230. scitex/gen/_list_packages.py +139 -0
  231. scitex/gen/_mat2py.py +88 -0
  232. scitex/gen/_norm.py +170 -0
  233. scitex/gen/_paste.py +18 -0
  234. scitex/gen/_print_config.py +84 -0
  235. scitex/gen/_shell.py +48 -0
  236. scitex/gen/_src.py +111 -0
  237. scitex/gen/_start.py +451 -0
  238. scitex/gen/_symlink.py +55 -0
  239. scitex/gen/_symlog.py +27 -0
  240. scitex/gen/_tee.py +238 -0
  241. scitex/gen/_title2path.py +60 -0
  242. scitex/gen/_title_case.py +88 -0
  243. scitex/gen/_to_even.py +84 -0
  244. scitex/gen/_to_odd.py +34 -0
  245. scitex/gen/_to_rank.py +39 -0
  246. scitex/gen/_transpose.py +37 -0
  247. scitex/gen/_type.py +78 -0
  248. scitex/gen/_var_info.py +73 -0
  249. scitex/gen/_wrap.py +17 -0
  250. scitex/gen/_xml2dict.py +76 -0
  251. scitex/gen/misc.py +730 -0
  252. scitex/gen/path.py +0 -0
  253. scitex/general/__init__.py +5 -0
  254. scitex/gists/_SigMacro_processFigure_S.py +128 -0
  255. scitex/gists/_SigMacro_toBlue.py +172 -0
  256. scitex/gists/__init__.py +12 -0
  257. scitex/io/_H5Explorer.py +292 -0
  258. scitex/io/__init__.py +82 -0
  259. scitex/io/_cache.py +101 -0
  260. scitex/io/_flush.py +24 -0
  261. scitex/io/_glob.py +103 -0
  262. scitex/io/_json2md.py +113 -0
  263. scitex/io/_load.py +168 -0
  264. scitex/io/_load_configs.py +146 -0
  265. scitex/io/_load_modules/__init__.py +38 -0
  266. scitex/io/_load_modules/_catboost.py +66 -0
  267. scitex/io/_load_modules/_con.py +20 -0
  268. scitex/io/_load_modules/_db.py +24 -0
  269. scitex/io/_load_modules/_docx.py +42 -0
  270. scitex/io/_load_modules/_eeg.py +110 -0
  271. scitex/io/_load_modules/_hdf5.py +196 -0
  272. scitex/io/_load_modules/_image.py +19 -0
  273. scitex/io/_load_modules/_joblib.py +19 -0
  274. scitex/io/_load_modules/_json.py +18 -0
  275. scitex/io/_load_modules/_markdown.py +103 -0
  276. scitex/io/_load_modules/_matlab.py +37 -0
  277. scitex/io/_load_modules/_numpy.py +39 -0
  278. scitex/io/_load_modules/_optuna.py +155 -0
  279. scitex/io/_load_modules/_pandas.py +69 -0
  280. scitex/io/_load_modules/_pdf.py +31 -0
  281. scitex/io/_load_modules/_pickle.py +24 -0
  282. scitex/io/_load_modules/_torch.py +16 -0
  283. scitex/io/_load_modules/_txt.py +126 -0
  284. scitex/io/_load_modules/_xml.py +49 -0
  285. scitex/io/_load_modules/_yaml.py +23 -0
  286. scitex/io/_mv_to_tmp.py +19 -0
  287. scitex/io/_path.py +286 -0
  288. scitex/io/_reload.py +78 -0
  289. scitex/io/_save.py +539 -0
  290. scitex/io/_save_modules/__init__.py +66 -0
  291. scitex/io/_save_modules/_catboost.py +22 -0
  292. scitex/io/_save_modules/_csv.py +89 -0
  293. scitex/io/_save_modules/_excel.py +49 -0
  294. scitex/io/_save_modules/_hdf5.py +249 -0
  295. scitex/io/_save_modules/_html.py +48 -0
  296. scitex/io/_save_modules/_image.py +140 -0
  297. scitex/io/_save_modules/_joblib.py +25 -0
  298. scitex/io/_save_modules/_json.py +25 -0
  299. scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
  300. scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
  301. scitex/io/_save_modules/_matlab.py +24 -0
  302. scitex/io/_save_modules/_mp4.py +29 -0
  303. scitex/io/_save_modules/_numpy.py +57 -0
  304. scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
  305. scitex/io/_save_modules/_pickle.py +45 -0
  306. scitex/io/_save_modules/_plotly.py +27 -0
  307. scitex/io/_save_modules/_text.py +23 -0
  308. scitex/io/_save_modules/_torch.py +26 -0
  309. scitex/io/_save_modules/_yaml.py +29 -0
  310. scitex/life/__init__.py +10 -0
  311. scitex/life/_monitor_rain.py +49 -0
  312. scitex/linalg/__init__.py +17 -0
  313. scitex/linalg/_distance.py +63 -0
  314. scitex/linalg/_geometric_median.py +64 -0
  315. scitex/linalg/_misc.py +73 -0
  316. scitex/nn/_AxiswiseDropout.py +27 -0
  317. scitex/nn/_BNet.py +126 -0
  318. scitex/nn/_BNet_Res.py +164 -0
  319. scitex/nn/_ChannelGainChanger.py +44 -0
  320. scitex/nn/_DropoutChannels.py +50 -0
  321. scitex/nn/_Filters.py +489 -0
  322. scitex/nn/_FreqGainChanger.py +110 -0
  323. scitex/nn/_GaussianFilter.py +48 -0
  324. scitex/nn/_Hilbert.py +111 -0
  325. scitex/nn/_MNet_1000.py +157 -0
  326. scitex/nn/_ModulationIndex.py +221 -0
  327. scitex/nn/_PAC.py +414 -0
  328. scitex/nn/_PSD.py +40 -0
  329. scitex/nn/_ResNet1D.py +120 -0
  330. scitex/nn/_SpatialAttention.py +25 -0
  331. scitex/nn/_Spectrogram.py +161 -0
  332. scitex/nn/_SwapChannels.py +50 -0
  333. scitex/nn/_TransposeLayer.py +19 -0
  334. scitex/nn/_Wavelet.py +183 -0
  335. scitex/nn/__init__.py +63 -0
  336. scitex/os/__init__.py +8 -0
  337. scitex/os/_mv.py +50 -0
  338. scitex/parallel/__init__.py +8 -0
  339. scitex/parallel/_run.py +151 -0
  340. scitex/path/__init__.py +33 -0
  341. scitex/path/_clean.py +52 -0
  342. scitex/path/_find.py +108 -0
  343. scitex/path/_get_module_path.py +51 -0
  344. scitex/path/_get_spath.py +35 -0
  345. scitex/path/_getsize.py +18 -0
  346. scitex/path/_increment_version.py +87 -0
  347. scitex/path/_mk_spath.py +51 -0
  348. scitex/path/_path.py +19 -0
  349. scitex/path/_split.py +23 -0
  350. scitex/path/_this_path.py +19 -0
  351. scitex/path/_version.py +101 -0
  352. scitex/pd/__init__.py +41 -0
  353. scitex/pd/_find_indi.py +126 -0
  354. scitex/pd/_find_pval.py +113 -0
  355. scitex/pd/_force_df.py +154 -0
  356. scitex/pd/_from_xyz.py +71 -0
  357. scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
  358. scitex/pd/_melt_cols.py +81 -0
  359. scitex/pd/_merge_columns.py +221 -0
  360. scitex/pd/_mv.py +63 -0
  361. scitex/pd/_replace.py +62 -0
  362. scitex/pd/_round.py +93 -0
  363. scitex/pd/_slice.py +63 -0
  364. scitex/pd/_sort.py +91 -0
  365. scitex/pd/_to_numeric.py +53 -0
  366. scitex/pd/_to_xy.py +59 -0
  367. scitex/pd/_to_xyz.py +110 -0
  368. scitex/plt/__init__.py +36 -0
  369. scitex/plt/_subplots/_AxesWrapper.py +182 -0
  370. scitex/plt/_subplots/_AxisWrapper.py +249 -0
  371. scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
  372. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
  373. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
  374. scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
  375. scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
  376. scitex/plt/_subplots/_FigWrapper.py +226 -0
  377. scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
  378. scitex/plt/_subplots/__init__.py +111 -0
  379. scitex/plt/_subplots/_export_as_csv.py +232 -0
  380. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
  381. scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
  382. scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
  383. scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
  384. scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
  385. scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
  386. scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
  387. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
  388. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
  389. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
  390. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
  391. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
  392. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
  393. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
  394. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
  395. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
  396. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
  397. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
  398. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
  399. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
  400. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
  401. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
  402. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
  403. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
  404. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
  405. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
  406. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
  407. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
  408. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
  409. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
  410. scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
  411. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
  412. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
  413. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
  414. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
  415. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
  416. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
  417. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
  418. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
  419. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
  420. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
  421. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
  422. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
  423. scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
  424. scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
  425. scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
  426. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
  427. scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
  428. scitex/plt/_tpl.py +28 -0
  429. scitex/plt/ax/__init__.py +114 -0
  430. scitex/plt/ax/_plot/__init__.py +53 -0
  431. scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
  432. scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
  433. scitex/plt/ax/_plot/_plot_cube.py +57 -0
  434. scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
  435. scitex/plt/ax/_plot/_plot_fillv.py +55 -0
  436. scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
  437. scitex/plt/ax/_plot/_plot_image.py +94 -0
  438. scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
  439. scitex/plt/ax/_plot/_plot_raster.py +172 -0
  440. scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
  441. scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
  442. scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
  443. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
  444. scitex/plt/ax/_plot/_plot_violin.py +343 -0
  445. scitex/plt/ax/_style/__init__.py +38 -0
  446. scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
  447. scitex/plt/ax/_style/_add_panel.py +92 -0
  448. scitex/plt/ax/_style/_extend.py +64 -0
  449. scitex/plt/ax/_style/_force_aspect.py +37 -0
  450. scitex/plt/ax/_style/_format_label.py +23 -0
  451. scitex/plt/ax/_style/_hide_spines.py +84 -0
  452. scitex/plt/ax/_style/_map_ticks.py +182 -0
  453. scitex/plt/ax/_style/_rotate_labels.py +215 -0
  454. scitex/plt/ax/_style/_sci_note.py +279 -0
  455. scitex/plt/ax/_style/_set_log_scale.py +299 -0
  456. scitex/plt/ax/_style/_set_meta.py +261 -0
  457. scitex/plt/ax/_style/_set_n_ticks.py +37 -0
  458. scitex/plt/ax/_style/_set_size.py +16 -0
  459. scitex/plt/ax/_style/_set_supxyt.py +116 -0
  460. scitex/plt/ax/_style/_set_ticks.py +276 -0
  461. scitex/plt/ax/_style/_set_xyt.py +121 -0
  462. scitex/plt/ax/_style/_share_axes.py +264 -0
  463. scitex/plt/ax/_style/_shift.py +139 -0
  464. scitex/plt/ax/_style/_show_spines.py +333 -0
  465. scitex/plt/color/_PARAMS.py +70 -0
  466. scitex/plt/color/__init__.py +52 -0
  467. scitex/plt/color/_add_hue_col.py +41 -0
  468. scitex/plt/color/_colors.py +205 -0
  469. scitex/plt/color/_get_colors_from_cmap.py +134 -0
  470. scitex/plt/color/_interpolate.py +29 -0
  471. scitex/plt/color/_vizualize_colors.py +54 -0
  472. scitex/plt/utils/__init__.py +44 -0
  473. scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
  474. scitex/plt/utils/_calc_nice_ticks.py +101 -0
  475. scitex/plt/utils/_close.py +68 -0
  476. scitex/plt/utils/_colorbar.py +96 -0
  477. scitex/plt/utils/_configure_mpl.py +295 -0
  478. scitex/plt/utils/_histogram_utils.py +132 -0
  479. scitex/plt/utils/_im2grid.py +70 -0
  480. scitex/plt/utils/_is_valid_axis.py +78 -0
  481. scitex/plt/utils/_mk_colorbar.py +65 -0
  482. scitex/plt/utils/_mk_patches.py +26 -0
  483. scitex/plt/utils/_scientific_captions.py +638 -0
  484. scitex/plt/utils/_scitex_config.py +223 -0
  485. scitex/reproduce/__init__.py +14 -0
  486. scitex/reproduce/_fix_seeds.py +45 -0
  487. scitex/reproduce/_gen_ID.py +55 -0
  488. scitex/reproduce/_gen_timestamp.py +35 -0
  489. scitex/res/__init__.py +5 -0
  490. scitex/resource/__init__.py +13 -0
  491. scitex/resource/_get_processor_usages.py +281 -0
  492. scitex/resource/_get_specs.py +280 -0
  493. scitex/resource/_log_processor_usages.py +190 -0
  494. scitex/resource/_utils/__init__.py +31 -0
  495. scitex/resource/_utils/_get_env_info.py +481 -0
  496. scitex/resource/limit_ram.py +33 -0
  497. scitex/scholar/__init__.py +24 -0
  498. scitex/scholar/_local_search.py +454 -0
  499. scitex/scholar/_paper.py +244 -0
  500. scitex/scholar/_pdf_downloader.py +325 -0
  501. scitex/scholar/_search.py +393 -0
  502. scitex/scholar/_vector_search.py +370 -0
  503. scitex/scholar/_web_sources.py +457 -0
  504. scitex/stats/__init__.py +31 -0
  505. scitex/stats/_calc_partial_corr.py +17 -0
  506. scitex/stats/_corr_test_multi.py +94 -0
  507. scitex/stats/_corr_test_wrapper.py +115 -0
  508. scitex/stats/_describe_wrapper.py +90 -0
  509. scitex/stats/_multiple_corrections.py +63 -0
  510. scitex/stats/_nan_stats.py +93 -0
  511. scitex/stats/_p2stars.py +116 -0
  512. scitex/stats/_p2stars_wrapper.py +56 -0
  513. scitex/stats/_statistical_tests.py +73 -0
  514. scitex/stats/desc/__init__.py +40 -0
  515. scitex/stats/desc/_describe.py +189 -0
  516. scitex/stats/desc/_nan.py +289 -0
  517. scitex/stats/desc/_real.py +94 -0
  518. scitex/stats/multiple/__init__.py +14 -0
  519. scitex/stats/multiple/_bonferroni_correction.py +72 -0
  520. scitex/stats/multiple/_fdr_correction.py +400 -0
  521. scitex/stats/multiple/_multicompair.py +28 -0
  522. scitex/stats/tests/__corr_test.py +277 -0
  523. scitex/stats/tests/__corr_test_multi.py +343 -0
  524. scitex/stats/tests/__corr_test_single.py +277 -0
  525. scitex/stats/tests/__init__.py +22 -0
  526. scitex/stats/tests/_brunner_munzel_test.py +192 -0
  527. scitex/stats/tests/_nocorrelation_test.py +28 -0
  528. scitex/stats/tests/_smirnov_grubbs.py +98 -0
  529. scitex/str/__init__.py +113 -0
  530. scitex/str/_clean_path.py +75 -0
  531. scitex/str/_color_text.py +52 -0
  532. scitex/str/_decapitalize.py +58 -0
  533. scitex/str/_factor_out_digits.py +281 -0
  534. scitex/str/_format_plot_text.py +498 -0
  535. scitex/str/_grep.py +48 -0
  536. scitex/str/_latex.py +155 -0
  537. scitex/str/_latex_fallback.py +471 -0
  538. scitex/str/_mask_api.py +39 -0
  539. scitex/str/_mask_api_key.py +8 -0
  540. scitex/str/_parse.py +158 -0
  541. scitex/str/_print_block.py +47 -0
  542. scitex/str/_print_debug.py +68 -0
  543. scitex/str/_printc.py +62 -0
  544. scitex/str/_readable_bytes.py +38 -0
  545. scitex/str/_remove_ansi.py +23 -0
  546. scitex/str/_replace.py +134 -0
  547. scitex/str/_search.py +125 -0
  548. scitex/str/_squeeze_space.py +36 -0
  549. scitex/tex/__init__.py +10 -0
  550. scitex/tex/_preview.py +103 -0
  551. scitex/tex/_to_vec.py +116 -0
  552. scitex/torch/__init__.py +18 -0
  553. scitex/torch/_apply_to.py +34 -0
  554. scitex/torch/_nan_funcs.py +77 -0
  555. scitex/types/_ArrayLike.py +44 -0
  556. scitex/types/_ColorLike.py +21 -0
  557. scitex/types/__init__.py +14 -0
  558. scitex/types/_is_listed_X.py +70 -0
  559. scitex/utils/__init__.py +22 -0
  560. scitex/utils/_compress_hdf5.py +116 -0
  561. scitex/utils/_email.py +120 -0
  562. scitex/utils/_grid.py +148 -0
  563. scitex/utils/_notify.py +247 -0
  564. scitex/utils/_search.py +121 -0
  565. scitex/web/__init__.py +38 -0
  566. scitex/web/_search_pubmed.py +438 -0
  567. scitex/web/_summarize_url.py +158 -0
  568. scitex-2.0.0.dist-info/METADATA +307 -0
  569. scitex-2.0.0.dist-info/RECORD +572 -0
  570. scitex-2.0.0.dist-info/WHEEL +6 -0
  571. scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
  572. scitex-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,246 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import warnings
4
+ from itertools import cycle
5
+
6
+ import numpy as np
7
+ from sklearn.metrics import roc_auc_score, roc_curve
8
+ import pandas as pd
9
+
10
+ import scitex
11
+
12
+
13
def interpolate_roc_data_points(df):
    """Resample a ROC curve onto a fixed grid of 1001 FPR values in [0, 1].

    Each grid point ``x`` takes the TPR and threshold of the *left* endpoint
    of every ROC segment ``[fpr[k], fpr[k + 1]]`` that contains it. When
    several segments contain the same ``x`` (e.g. vertical steps where fpr
    repeats), the later segment wins. The first and last grid points are
    pinned to the first and last rows of ``df``.

    Parameters
    ----------
    df : pd.DataFrame
        Rows ordered along the curve, with columns "fpr", "tpr", "threshold"
        and "roc_auc". Only the first "roc_auc" value is used.

    Returns
    -------
    pd.DataFrame
        1001 rows with columns "x", "y", "threshold", "roc_auc". Grid points
        not covered by any segment stay NaN; an empty ``df`` yields NaN
        columns for "y", "threshold" and "roc_auc".
    """
    x_grid = np.arange(1001) / 1000
    y_grid = np.full_like(x_grid, np.nan)
    t_grid = np.full_like(x_grid, np.nan)
    roc_auc_value = np.nan

    if len(df) > 0:
        fprs = df["fpr"].to_numpy(dtype=float)
        tprs = df["tpr"].to_numpy(dtype=float)
        thresholds = df["threshold"].to_numpy(dtype=float)

        # Fill plain NumPy arrays: chained pandas assignment such as
        # frame["y"][mask] = value does not write through under Copy-on-Write.
        for k in range(len(df) - 1):
            in_segment = (fprs[k] <= x_grid) & (x_grid <= fprs[k + 1])
            y_grid[in_segment] = tprs[k]
            t_grid[in_segment] = thresholds[k]

        # Pin the grid endpoints to the curve's own endpoints.
        y_grid[0], y_grid[-1] = tprs[0], tprs[-1]
        t_grid[0], t_grid[-1] = thresholds[0], thresholds[-1]
        roc_auc_value = df["roc_auc"].iloc[0]

    return pd.DataFrame(
        {
            "x": x_grid,
            "y": y_grid,
            "threshold": t_grid,
            "roc_auc": roc_auc_value,
        }
    )
48
+
49
+
50
def to_onehot(labels, n_classes):
    """Return integer one-hot rows for ``labels``.

    Row ``labels[j]`` of an ``n_classes`` x ``n_classes`` integer identity
    matrix is selected for each entry, so ``labels`` must be valid integer
    indices into that matrix.
    """
    return np.eye(n_classes, dtype=int)[labels]
53
+
54
+
55
def _save_per_class_roc_csvs(sdir_for_csv, labels, fpr, tpr, threshold, auc_scores):
    """Save one interpolated ROC table per class to ``{sdir_for_csv}{class_name}.csv``."""
    for i, label in enumerate(labels):
        # str() so non-string labels (e.g. ints) do not break .replace().
        class_name = str(label).replace(" ", "_")
        n_points = len(fpr[i])
        df = pd.DataFrame(
            data={
                "fpr": fpr[i],
                "tpr": tpr[i],
                "threshold": threshold[i],
                "roc_auc": np.full(n_points, auc_scores[i]),
            },
            index=pd.Index(data=np.arange(n_points), name=class_name),
        )
        df = interpolate_roc_data_points(df)
        # Plain concatenation: a directory argument needs a trailing separator.
        scitex.io.save(df, f"{sdir_for_csv}{class_name}.csv")


def _plot_roc_curves(plt, labels, fpr, tpr, auc_scores):
    """Draw the chance diagonal plus one ROC curve per class; return the figure."""
    colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])

    fig, ax = plt.subplots()
    ax.set_box_aspect(1)
    lines = []
    legends = []

    # Chance level (the diagonal line)
    (chance_line,) = ax.plot(
        np.linspace(0.01, 1),
        np.linspace(0.01, 1),
        color="gray",
        lw=2,
        linestyle="--",
        alpha=0.8,
    )
    lines.append(chance_line)
    legends.append("Chance")

    # One curve per class, drawn on the axes created above (not global state).
    for i, (label, color) in enumerate(zip(labels, colors)):
        (line,) = ax.plot(fpr[i], tpr[i], color=color, lw=2)
        lines.append(line)
        legends.append(f"{label} (AUC = {auc_scores[i]:0.2f})")

    fig.subplots_adjust(bottom=0.25)
    ax.set_xlim([-0.01, 1.01])
    ax.set_ylim([-0.01, 1.01])
    ax.set_xticks([0.0, 0.5, 1.0])
    ax.set_yticks([0.0, 0.5, 1.0])
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")
    ax.set_title("ROC Curve")
    ax.legend(lines, legends, loc="lower right")
    return fig


def roc_auc(plt, true_class, pred_proba, labels, sdir_for_csv=None):
    """Compute one-vs-rest ROC curves / AUCs and plot them.

    Parameters
    ----------
    plt : module-like
        Plotting namespace whose ``subplots()`` returns ``(fig, ax)``
        (e.g. ``matplotlib.pyplot``).
    true_class : array-like of int, shape (n_samples,)
        Integer class indices, converted to one-hot via ``to_onehot``.
    pred_proba : np.ndarray, shape (n_samples, n_classes)
        Predicted scores, one column per class.
    labels : sequence
        Class names; ``len(labels)`` is the number of classes.
    sdir_for_csv : str or None
        If given, an interpolated ROC table per class is saved to
        ``f"{sdir_for_csv}{class_name}.csv"`` (string concatenation, so pass
        a trailing separator for a directory).

    Returns
    -------
    fig : figure returned by ``plt.subplots()``
    metrics : dict
        Keys "roc_auc", "fpr", "tpr", "threshold". Each value is a dict keyed
        by class index plus "micro"; "roc_auc" additionally has "macro"
        (NaN-ignoring mean of the per-class AUCs). Classes whose ROC is
        undefined (ValueError from sklearn) are NaN-filled with a warning.
    """
    n_classes = len(labels)
    true_class_onehot = to_onehot(true_class, n_classes)

    # Per-class (one-vs-rest) curves and scores.
    fpr, tpr, threshold, auc_scores = {}, {}, {}, {}
    for i in range(n_classes):
        y_true_i = true_class_onehot[:, i]
        y_score_i = pred_proba[:, i]
        try:
            fpr[i], tpr[i], threshold[i] = roc_curve(y_true_i, y_score_i)
            auc_scores[i] = roc_auc_score(y_true_i, y_score_i)
        except ValueError as err:
            warnings.warn(
                f'ROC-AUC for "{labels[i]}" is undefined and NaN-filled: {err}'
            )
            fpr[i], tpr[i], threshold[i] = [np.nan], [np.nan], [np.nan]
            auc_scores[i] = np.nan

    # Micro average: all classes pooled into one binary problem.
    fpr["micro"], tpr["micro"], threshold["micro"] = roc_curve(
        true_class_onehot.ravel(), pred_proba.ravel()
    )
    auc_scores["micro"] = roc_auc_score(
        true_class_onehot, pred_proba, average="micro"
    )

    # Macro average: mean of the per-class AUCs computed above, ignoring NaNs.
    auc_scores["macro"] = np.nanmean([auc_scores[i] for i in range(n_classes)])

    if sdir_for_csv is not None:
        _save_per_class_roc_csvs(
            sdir_for_csv, labels, fpr, tpr, threshold, auc_scores
        )

    fig = _plot_roc_curves(plt, labels, fpr, tpr, auc_scores)

    metrics = dict(roc_auc=auc_scores, fpr=fpr, tpr=tpr, threshold=threshold)
    return fig, metrics
169
+
170
+
171
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    from sklearn import datasets, svm
    from sklearn.model_selection import train_test_split

    # Fix seed (SVC(probability=True) draws from NumPy's global RNG when
    # random_state is not set).
    np.random.seed(42)

    ################################################################################
    ## Demo: SVC on the scikit-learn "digits" dataset (8x8 images)
    ################################################################################
    digits = datasets.load_digits()

    # Flatten the images into feature vectors
    n_samples = len(digits.images)
    data = digits.images.reshape((n_samples, -1))

    # Support vector classifier with probability outputs
    clf = svm.SVC(gamma=0.001, probability=True)

    # 50% train / 50% test split, no shuffling
    X_train, X_test, y_train, y_test = train_test_split(
        data, digits.target, test_size=0.5, shuffle=False
    )

    clf.fit(X_train, y_train)
    predicted_proba = clf.predict_proba(X_test)

    n_classes = len(np.unique(digits.target))
    labels = [f"Class {i}" for i in range(n_classes)]

    ## Configure matplotlib
    plt.rcParams["font.size"] = 20
    plt.rcParams["legend.fontsize"] = "xx-small"
    plt.rcParams["figure.figsize"] = (16 * 1.2, 9 * 1.2)

    # Relabel every 9 as 8 so class 9 has no positives: this exercises the
    # NaN-filled per-class AUC path.
    y_test[y_test == 9] = 8

    ## Main
    fig, metrics_dict = roc_auc(
        plt, y_test, predicted_proba, labels, sdir_for_csv="./tmp/roc_test/"
    )

    print(metrics_dict.keys())
    # dict_keys(['roc_auc', 'fpr', 'tpr', 'threshold'])

    # plt.show() blocks until the window is closed; fig.show() would return
    # immediately and the script would exit.
    plt.show()
@@ -0,0 +1,29 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-11-24 10:13:17 (ywatanabe)"
4
+ # File: ./scitex_repo/src/scitex/ai/sampling/undersample.py
5
+
6
+ THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/sampling/undersample.py"
7
+
8
+ from typing import Tuple
9
+ from ...types import ArrayLike
10
+ from imblearn.under_sampling import RandomUnderSampler
11
+
12
+
13
def undersample(
    X: ArrayLike, y: ArrayLike, random_state: int = 42
) -> Tuple[ArrayLike, ArrayLike]:
    """Balance a labelled dataset by random undersampling.

    Delegates to imblearn's ``RandomUnderSampler`` with its default sampling
    strategy; only the random seed is configurable here.

    Args:
        X: Feature array-like of shape (n_samples, n_features).
        y: Label array-like of shape (n_samples,).
        random_state: Seed passed to ``RandomUnderSampler`` so the selection
            is reproducible.

    Returns:
        The resampled ``(X, y)`` pair exactly as returned by
        ``RandomUnderSampler.fit_resample``, which keeps the input container
        type where imblearn supports it.
    """
    resampled = RandomUnderSampler(random_state=random_state).fit_resample(X, y)
    X_out, y_out = resampled
    return X_out, y_out
27
+
28
+
29
+ # EOF
@@ -0,0 +1,11 @@
1
+ #!/usr/bin/env python3
2
+ """Scitex sk module."""
3
+
4
+ from ._clf import GB_pipeline, rocket_pipeline
5
+ from ._to_sktime import to_sktime_df
6
+
7
+ __all__ = [
8
+ "GB_pipeline",
9
+ "rocket_pipeline",
10
+ "to_sktime_df",
11
+ ]
scitex/ai/sk/_clf.py ADDED
@@ -0,0 +1,58 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-03-23 17:36:05 (ywatanabe)"
4
+
5
+ import numpy as np
6
+ from sklearn.decomposition import PCA
7
+ from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
8
+ from sklearn.feature_selection import SelectKBest, f_classif
9
+ from sklearn.linear_model import LogisticRegression, RidgeClassifierCV
10
+ from sklearn.pipeline import make_pipeline
11
+ from sklearn.svm import SVC, LinearSVC
12
+ from sktime.classification.deep_learning.cnn import CNNClassifier
13
+ from sktime.classification.deep_learning.inceptiontime import (
14
+ InceptionTimeClassifier,
15
+ )
16
+ from sktime.classification.deep_learning.lstmfcn import LSTMFCNClassifier
17
+ from sktime.classification.dummy import DummyClassifier
18
+ from sktime.classification.feature_based import TSFreshClassifier
19
+ from sktime.classification.hybrid import HIVECOTEV2
20
+ from sktime.classification.interval_based import TimeSeriesForestClassifier
21
+ from sktime.classification.kernel_based import RocketClassifier, TimeSeriesSVC
22
+ from sktime.transformations.panel.reduce import Tabularizer
23
+ from sktime.transformations.panel.rocket import Rocket
24
+
25
+ # _rocket_pipeline = make_pipeline(
26
+ # Rocket(n_jobs=-1),
27
+ # RidgeClassifierCV(alphas=np.logspace(-3, 3, 10)),
28
+ # )
29
+
30
+
31
+ # def rocket_pipeline(*args, **kwargs):
32
+ # return _rocket_pipeline
33
+
34
+
35
def rocket_pipeline(*args, **kwargs):
    """Build a new ROCKET -> logistic-regression pipeline.

    All positional and keyword arguments are forwarded unchanged to the
    ``Rocket`` transformer; the classifier settings are fixed.

    Returns:
        sklearn.pipeline.Pipeline: ``Rocket(*args, **kwargs)`` followed by
        ``LogisticRegression(max_iter=1000)``. A fresh pipeline is created on
        every call, so separate callers never share fitted state.
    """
    transformer = Rocket(*args, **kwargs)
    # max_iter=1000 gives the solver more iterations to converge.
    classifier = LogisticRegression(max_iter=1000)
    return make_pipeline(transformer, classifier)
44
+
45
+
46
+ # def rocket_pipeline(*args, **kwargs):
47
+ # return make_pipeline(
48
+ # Rocket(*args, **kwargs),
49
+ # SelectKBest(f_classif, k=500),
50
+ # PCA(n_components=100),
51
+ # LinearSVC(dual=False, tol=1e-3, C=0.1, probability=True),
52
+ # )
53
+
54
+
55
# Two-step pipeline: sktime's Tabularizer (converts panel data to a tabular
# layout) feeding sklearn's GradientBoostingClassifier.
# NOTE(review): this is ONE module-level instance shared by every importer,
# and Pipeline.fit() fits its steps in place. Use
# sklearn.base.clone(GB_pipeline) to get an independent, unfitted copy.
GB_pipeline = make_pipeline(Tabularizer(), GradientBoostingClassifier())
@@ -0,0 +1,100 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-03-05 13:17:04 (ywatanabe)"
4
+
5
+ # import warnings
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+
11
+
12
def to_sktime_df(X):
    """
    Convert a 3D dataset into sktime's nested pandas DataFrame format.

    Each row of the result is one sample and each column one channel. Every
    cell holds a ``pandas.Series`` with that channel's univariate time series,
    cast to float64.

    Arguments:
        - X (numpy.ndarray or torch.Tensor or pandas.DataFrame): The input
          dataset with shape (n_samples, n_chs, seq_len).

    Return:
        - sktime_df (pandas.DataFrame): A DataFrame of shape
          (n_samples, n_chs) with columns labelled 0..n_chs-1. Each cell is a
          pandas Series of length seq_len named ``dim_{channel}``.

    Data Types and Shapes:
        - numpy.ndarray: must have shape (n_samples, n_chs, seq_len).
        - torch.Tensor: same shape. The tensor is detached from the autograd
          graph and copied to CPU before conversion, so tensors that require
          grad or live on an accelerator are accepted.
        - pandas.DataFrame: assumed to already be in sktime format; it is
          returned unchanged (no copy, no validation).

    Raises:
        - ValueError: if X is not one of the supported types, or if an
          array/tensor input is not 3-dimensional.

    References:
        - sktime: https://github.com/alan-turing-institute/sktime

    Examples:
    --------
    >>> X_np = np.random.rand(64, 160, 1024)
    >>> sktime_df = to_sktime_df(X_np)
    >>> type(sktime_df)
    <class 'pandas.core.frame.DataFrame'>
    """
    if isinstance(X, pd.DataFrame):
        return X
    elif torch.is_tensor(X):
        # Tensor.numpy() raises for tensors that require grad or are not on
        # the CPU, so detach and move to host memory first.
        X = X.detach().cpu().numpy()
    elif not isinstance(X, np.ndarray):
        raise ValueError(
            "Input X must be a numpy.ndarray, torch.Tensor, or pandas.DataFrame"
        )

    if X.ndim != 3:
        # Without this check a 2D input would be silently turned into
        # length-1 "series" (one per time point) instead of failing.
        raise ValueError(
            f"Input X must be 3D (n_samples, n_chs, seq_len); got shape {X.shape}"
        )

    X = X.astype(np.float64)

    def _format_a_sample_for_sktime(x):
        """
        Formats a single sample for sktime compatibility.

        Arguments:
            - x (numpy.ndarray): A 2D array with shape (n_chs, seq_len)
              representing a single sample.

        Return:
            - dims (pandas.Series): A Series where each element is a pandas
              Series holding one channel's time series.
        """
        return pd.Series([pd.Series(x[d], name=f"dim_{d}") for d in range(x.shape[0])])

    sktime_df = pd.DataFrame(
        [_format_a_sample_for_sktime(X[i]) for i in range(X.shape[0])]
    )
    return sktime_df
65
+
66
+
67
+ # # Obsolete warning for future compatibility
68
+ # def to_sktime(*args, **kwargs):
69
+ # warnings.warn(
70
+ # "to_sktime is deprecated; use to_sktime_df instead.", FutureWarning
71
+ # )
72
+ # return to_sktime_df(*args, **kwargs)
73
+
74
+
75
+ # import pandas as pd
76
+ # import numpy as np
77
+ # import torch
78
+
79
+ # def to_sktime(X):
80
+ # """
81
+ # X.shape: (n_samples, n_chs, seq_len)
82
+ # """
83
+
84
+ # def _format_a_sample_for_sktime(x):
85
+ # """
86
+ # x.shape: (n_chs, seq_len)
87
+ # """
88
+ # dims = pd.Series(
89
+ # [pd.Series(x[d], name=f"dim_{d}") for d in range(len(x))],
90
+ # index=[f"dim_{i}" for i in np.arange(len(x))],
91
+ # )
92
+ # return dims
93
+
94
+ # if torch.is_tensor(X):
95
+ # X = X.numpy()
96
+ # X = X.astype(np.float64)
97
+
98
+ # return pd.DataFrame(
99
+ # [_format_a_sample_for_sktime(X[i]) for i in range(len(X))]
100
+ # )
@@ -0,0 +1,26 @@
1
+ #!/usr/bin/env python3
2
+ """Sklearn wrappers and utilities."""
3
+
4
+ import warnings
5
+
6
# Optional submodules: each wildcard import is guarded so that a submodule
# whose dependencies are missing degrades to a warning instead of making the
# whole package fail to import.
# NOTE: ImportWarning is ignored by Python's default warning filters, so these
# messages are only shown when warnings are enabled (e.g. `python -W default`).
try:
    from .clf import *
except ImportError as e:
    warnings.warn(
        f"Could not import clf from scitex.ai.sklearn: {e}. "
        "Some functionality may be unavailable. "
        "Consider installing missing dependencies if you need this module.",
        ImportWarning,
        stacklevel=2,
    )

try:
    from .to_sktime import *
except ImportError as e:
    warnings.warn(
        f"Could not import to_sktime from scitex.ai.sklearn: {e}. "
        "Some functionality may be unavailable. "
        "Consider installing missing dependencies if you need this module.",
        ImportWarning,
        stacklevel=2,
    )
@@ -0,0 +1,58 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-03-23 17:36:05 (ywatanabe)"
4
+
5
+ import numpy as np
6
+ from sklearn.decomposition import PCA
7
+ from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
8
+ from sklearn.feature_selection import SelectKBest, f_classif
9
+ from sklearn.linear_model import LogisticRegression, RidgeClassifierCV
10
+ from sklearn.pipeline import make_pipeline
11
+ from sklearn.svm import SVC, LinearSVC
12
+ from sktime.classification.deep_learning.cnn import CNNClassifier
13
+ from sktime.classification.deep_learning.inceptiontime import (
14
+ InceptionTimeClassifier,
15
+ )
16
+ from sktime.classification.deep_learning.lstmfcn import LSTMFCNClassifier
17
+ from sktime.classification.dummy import DummyClassifier
18
+ from sktime.classification.feature_based import TSFreshClassifier
19
+ from sktime.classification.hybrid import HIVECOTEV2
20
+ from sktime.classification.interval_based import TimeSeriesForestClassifier
21
+ from sktime.classification.kernel_based import RocketClassifier, TimeSeriesSVC
22
+ from sktime.transformations.panel.reduce import Tabularizer
23
+ from sktime.transformations.panel.rocket import Rocket
24
+
25
+ # _rocket_pipeline = make_pipeline(
26
+ # Rocket(n_jobs=-1),
27
+ # RidgeClassifierCV(alphas=np.logspace(-3, 3, 10)),
28
+ # )
29
+
30
+
31
+ # def rocket_pipeline(*args, **kwargs):
32
+ # return _rocket_pipeline
33
+
34
+
35
def rocket_pipeline(*args, **kwargs):
    """Build a new ROCKET -> logistic-regression pipeline.

    All positional and keyword arguments are forwarded unchanged to the
    ``Rocket`` transformer; the classifier settings are fixed.

    Returns:
        sklearn.pipeline.Pipeline: ``Rocket(*args, **kwargs)`` followed by
        ``LogisticRegression(max_iter=1000)``. A fresh pipeline is created on
        every call, so separate callers never share fitted state.
    """
    transformer = Rocket(*args, **kwargs)
    # max_iter=1000 gives the solver more iterations to converge.
    classifier = LogisticRegression(max_iter=1000)
    return make_pipeline(transformer, classifier)
44
+
45
+
46
+ # def rocket_pipeline(*args, **kwargs):
47
+ # return make_pipeline(
48
+ # Rocket(*args, **kwargs),
49
+ # SelectKBest(f_classif, k=500),
50
+ # PCA(n_components=100),
51
+ # LinearSVC(dual=False, tol=1e-3, C=0.1, probability=True),
52
+ # )
53
+
54
+
55
# Two-step pipeline: sktime's Tabularizer (converts panel data to a tabular
# layout) feeding sklearn's GradientBoostingClassifier.
# NOTE(review): this is ONE module-level instance shared by every importer,
# and Pipeline.fit() fits its steps in place. Use
# sklearn.base.clone(GB_pipeline) to get an independent, unfitted copy.
GB_pipeline = make_pipeline(Tabularizer(), GradientBoostingClassifier())
@@ -0,0 +1,100 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-03-05 13:17:04 (ywatanabe)"
4
+
5
+ # import warnings
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+
11
+
12
def to_sktime_df(X):
    """
    Convert a 3D dataset into sktime's nested pandas DataFrame format.

    Each row of the result is one sample and each column one channel. Every
    cell holds a ``pandas.Series`` with that channel's univariate time series,
    cast to float64.

    Arguments:
        - X (numpy.ndarray or torch.Tensor or pandas.DataFrame): The input
          dataset with shape (n_samples, n_chs, seq_len).

    Return:
        - sktime_df (pandas.DataFrame): A DataFrame of shape
          (n_samples, n_chs) with columns labelled 0..n_chs-1. Each cell is a
          pandas Series of length seq_len named ``dim_{channel}``.

    Data Types and Shapes:
        - numpy.ndarray: must have shape (n_samples, n_chs, seq_len).
        - torch.Tensor: same shape. The tensor is detached from the autograd
          graph and copied to CPU before conversion, so tensors that require
          grad or live on an accelerator are accepted.
        - pandas.DataFrame: assumed to already be in sktime format; it is
          returned unchanged (no copy, no validation).

    Raises:
        - ValueError: if X is not one of the supported types, or if an
          array/tensor input is not 3-dimensional.

    References:
        - sktime: https://github.com/alan-turing-institute/sktime

    Examples:
    --------
    >>> X_np = np.random.rand(64, 160, 1024)
    >>> sktime_df = to_sktime_df(X_np)
    >>> type(sktime_df)
    <class 'pandas.core.frame.DataFrame'>
    """
    if isinstance(X, pd.DataFrame):
        return X
    elif torch.is_tensor(X):
        # Tensor.numpy() raises for tensors that require grad or are not on
        # the CPU, so detach and move to host memory first.
        X = X.detach().cpu().numpy()
    elif not isinstance(X, np.ndarray):
        raise ValueError(
            "Input X must be a numpy.ndarray, torch.Tensor, or pandas.DataFrame"
        )

    if X.ndim != 3:
        # Without this check a 2D input would be silently turned into
        # length-1 "series" (one per time point) instead of failing.
        raise ValueError(
            f"Input X must be 3D (n_samples, n_chs, seq_len); got shape {X.shape}"
        )

    X = X.astype(np.float64)

    def _format_a_sample_for_sktime(x):
        """
        Formats a single sample for sktime compatibility.

        Arguments:
            - x (numpy.ndarray): A 2D array with shape (n_chs, seq_len)
              representing a single sample.

        Return:
            - dims (pandas.Series): A Series where each element is a pandas
              Series holding one channel's time series.
        """
        return pd.Series([pd.Series(x[d], name=f"dim_{d}") for d in range(x.shape[0])])

    sktime_df = pd.DataFrame(
        [_format_a_sample_for_sktime(X[i]) for i in range(X.shape[0])]
    )
    return sktime_df
65
+
66
+
67
+ # # Obsolete warning for future compatibility
68
+ # def to_sktime(*args, **kwargs):
69
+ # warnings.warn(
70
+ # "to_sktime is deprecated; use to_sktime_df instead.", FutureWarning
71
+ # )
72
+ # return to_sktime_df(*args, **kwargs)
73
+
74
+
75
+ # import pandas as pd
76
+ # import numpy as np
77
+ # import torch
78
+
79
+ # def to_sktime(X):
80
+ # """
81
+ # X.shape: (n_samples, n_chs, seq_len)
82
+ # """
83
+
84
+ # def _format_a_sample_for_sktime(x):
85
+ # """
86
+ # x.shape: (n_chs, seq_len)
87
+ # """
88
+ # dims = pd.Series(
89
+ # [pd.Series(x[d], name=f"dim_{d}") for d in range(len(x))],
90
+ # index=[f"dim_{i}" for i in np.arange(len(x))],
91
+ # )
92
+ # return dims
93
+
94
+ # if torch.is_tensor(X):
95
+ # X = X.numpy()
96
+ # X = X.astype(np.float64)
97
+
98
+ # return pd.DataFrame(
99
+ # [_format_a_sample_for_sktime(X[i]) for i in range(len(X))]
100
+ # )
@@ -0,0 +1,7 @@
1
+ #!/usr/bin/env python3
2
+ """Training utilities."""
3
+
4
+ from .early_stopping import EarlyStopping
5
+ from .learning_curve_logger import LearningCurveLogger
6
+
7
+ __all__ = ["EarlyStopping", "LearningCurveLogger"]