scitex 2.0.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (572)
  1. scitex/__init__.py +73 -0
  2. scitex/__main__.py +89 -0
  3. scitex/__version__.py +14 -0
  4. scitex/_sh.py +59 -0
  5. scitex/ai/_LearningCurveLogger.py +583 -0
  6. scitex/ai/__Classifiers.py +101 -0
  7. scitex/ai/__init__.py +55 -0
  8. scitex/ai/_gen_ai/_Anthropic.py +173 -0
  9. scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
  10. scitex/ai/_gen_ai/_DeepSeek.py +175 -0
  11. scitex/ai/_gen_ai/_Google.py +161 -0
  12. scitex/ai/_gen_ai/_Groq.py +97 -0
  13. scitex/ai/_gen_ai/_Llama.py +142 -0
  14. scitex/ai/_gen_ai/_OpenAI.py +230 -0
  15. scitex/ai/_gen_ai/_PARAMS.py +565 -0
  16. scitex/ai/_gen_ai/_Perplexity.py +191 -0
  17. scitex/ai/_gen_ai/__init__.py +32 -0
  18. scitex/ai/_gen_ai/_calc_cost.py +78 -0
  19. scitex/ai/_gen_ai/_format_output_func.py +183 -0
  20. scitex/ai/_gen_ai/_genai_factory.py +71 -0
  21. scitex/ai/act/__init__.py +8 -0
  22. scitex/ai/act/_define.py +11 -0
  23. scitex/ai/classification/__init__.py +7 -0
  24. scitex/ai/classification/classification_reporter.py +1137 -0
  25. scitex/ai/classification/classifier_server.py +131 -0
  26. scitex/ai/classification/classifiers.py +101 -0
  27. scitex/ai/classification_reporter.py +1161 -0
  28. scitex/ai/classifier_server.py +131 -0
  29. scitex/ai/clustering/__init__.py +11 -0
  30. scitex/ai/clustering/_pca.py +115 -0
  31. scitex/ai/clustering/_umap.py +376 -0
  32. scitex/ai/early_stopping.py +149 -0
  33. scitex/ai/feature_extraction/__init__.py +56 -0
  34. scitex/ai/feature_extraction/vit.py +148 -0
  35. scitex/ai/genai/__init__.py +277 -0
  36. scitex/ai/genai/anthropic.py +177 -0
  37. scitex/ai/genai/anthropic_provider.py +320 -0
  38. scitex/ai/genai/anthropic_refactored.py +109 -0
  39. scitex/ai/genai/auth_manager.py +200 -0
  40. scitex/ai/genai/base_genai.py +336 -0
  41. scitex/ai/genai/base_provider.py +291 -0
  42. scitex/ai/genai/calc_cost.py +78 -0
  43. scitex/ai/genai/chat_history.py +307 -0
  44. scitex/ai/genai/cost_tracker.py +276 -0
  45. scitex/ai/genai/deepseek.py +188 -0
  46. scitex/ai/genai/deepseek_provider.py +251 -0
  47. scitex/ai/genai/format_output_func.py +183 -0
  48. scitex/ai/genai/genai_factory.py +71 -0
  49. scitex/ai/genai/google.py +169 -0
  50. scitex/ai/genai/google_provider.py +228 -0
  51. scitex/ai/genai/groq.py +104 -0
  52. scitex/ai/genai/groq_provider.py +248 -0
  53. scitex/ai/genai/image_processor.py +250 -0
  54. scitex/ai/genai/llama.py +155 -0
  55. scitex/ai/genai/llama_provider.py +214 -0
  56. scitex/ai/genai/mock_provider.py +127 -0
  57. scitex/ai/genai/model_registry.py +304 -0
  58. scitex/ai/genai/openai.py +230 -0
  59. scitex/ai/genai/openai_provider.py +293 -0
  60. scitex/ai/genai/params.py +565 -0
  61. scitex/ai/genai/perplexity.py +202 -0
  62. scitex/ai/genai/perplexity_provider.py +205 -0
  63. scitex/ai/genai/provider_base.py +302 -0
  64. scitex/ai/genai/provider_factory.py +370 -0
  65. scitex/ai/genai/response_handler.py +235 -0
  66. scitex/ai/layer/_Pass.py +21 -0
  67. scitex/ai/layer/__init__.py +10 -0
  68. scitex/ai/layer/_switch.py +8 -0
  69. scitex/ai/loss/_L1L2Losses.py +34 -0
  70. scitex/ai/loss/__init__.py +12 -0
  71. scitex/ai/loss/multi_task_loss.py +47 -0
  72. scitex/ai/metrics/__init__.py +9 -0
  73. scitex/ai/metrics/_bACC.py +51 -0
  74. scitex/ai/metrics/silhoute_score_block.py +496 -0
  75. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
  76. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
  77. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
  78. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
  79. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
  80. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
  81. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
  82. scitex/ai/optim/__init__.py +13 -0
  83. scitex/ai/optim/_get_set.py +31 -0
  84. scitex/ai/optim/_optimizers.py +71 -0
  85. scitex/ai/plt/__init__.py +21 -0
  86. scitex/ai/plt/_conf_mat.py +592 -0
  87. scitex/ai/plt/_learning_curve.py +194 -0
  88. scitex/ai/plt/_optuna_study.py +111 -0
  89. scitex/ai/plt/aucs/__init__.py +2 -0
  90. scitex/ai/plt/aucs/example.py +60 -0
  91. scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
  92. scitex/ai/plt/aucs/roc_auc.py +246 -0
  93. scitex/ai/sampling/undersample.py +29 -0
  94. scitex/ai/sk/__init__.py +11 -0
  95. scitex/ai/sk/_clf.py +58 -0
  96. scitex/ai/sk/_to_sktime.py +100 -0
  97. scitex/ai/sklearn/__init__.py +26 -0
  98. scitex/ai/sklearn/clf.py +58 -0
  99. scitex/ai/sklearn/to_sktime.py +100 -0
  100. scitex/ai/training/__init__.py +7 -0
  101. scitex/ai/training/early_stopping.py +150 -0
  102. scitex/ai/training/learning_curve_logger.py +555 -0
  103. scitex/ai/utils/__init__.py +22 -0
  104. scitex/ai/utils/_check_params.py +50 -0
  105. scitex/ai/utils/_default_dataset.py +46 -0
  106. scitex/ai/utils/_format_samples_for_sktime.py +26 -0
  107. scitex/ai/utils/_label_encoder.py +134 -0
  108. scitex/ai/utils/_merge_labels.py +22 -0
  109. scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
  110. scitex/ai/utils/_under_sample.py +51 -0
  111. scitex/ai/utils/_verify_n_gpus.py +16 -0
  112. scitex/ai/utils/grid_search.py +148 -0
  113. scitex/context/__init__.py +9 -0
  114. scitex/context/_suppress_output.py +38 -0
  115. scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
  116. scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
  117. scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
  118. scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
  119. scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
  120. scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
  121. scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
  122. scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
  123. scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
  124. scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
  125. scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
  126. scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
  127. scitex/db/_BaseMixins/__init__.py +30 -0
  128. scitex/db/_PostgreSQL.py +126 -0
  129. scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
  130. scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
  131. scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
  132. scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
  133. scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
  134. scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
  135. scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
  136. scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
  137. scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
  138. scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
  139. scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
  140. scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
  141. scitex/db/_PostgreSQLMixins/__init__.py +34 -0
  142. scitex/db/_SQLite3.py +2136 -0
  143. scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
  144. scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
  145. scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
  146. scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
  147. scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
  148. scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
  149. scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
  150. scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
  151. scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
  152. scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
  153. scitex/db/_SQLite3Mixins/__init__.py +30 -0
  154. scitex/db/__init__.py +14 -0
  155. scitex/db/_delete_duplicates.py +397 -0
  156. scitex/db/_inspect.py +163 -0
  157. scitex/decorators/__init__.py +54 -0
  158. scitex/decorators/_auto_order.py +172 -0
  159. scitex/decorators/_batch_fn.py +127 -0
  160. scitex/decorators/_cache_disk.py +32 -0
  161. scitex/decorators/_cache_mem.py +12 -0
  162. scitex/decorators/_combined.py +98 -0
  163. scitex/decorators/_converters.py +282 -0
  164. scitex/decorators/_deprecated.py +26 -0
  165. scitex/decorators/_not_implemented.py +30 -0
  166. scitex/decorators/_numpy_fn.py +86 -0
  167. scitex/decorators/_pandas_fn.py +121 -0
  168. scitex/decorators/_preserve_doc.py +19 -0
  169. scitex/decorators/_signal_fn.py +95 -0
  170. scitex/decorators/_timeout.py +55 -0
  171. scitex/decorators/_torch_fn.py +136 -0
  172. scitex/decorators/_wrap.py +39 -0
  173. scitex/decorators/_xarray_fn.py +88 -0
  174. scitex/dev/__init__.py +15 -0
  175. scitex/dev/_analyze_code_flow.py +284 -0
  176. scitex/dev/_reload.py +59 -0
  177. scitex/dict/_DotDict.py +442 -0
  178. scitex/dict/__init__.py +18 -0
  179. scitex/dict/_listed_dict.py +42 -0
  180. scitex/dict/_pop_keys.py +36 -0
  181. scitex/dict/_replace.py +13 -0
  182. scitex/dict/_safe_merge.py +62 -0
  183. scitex/dict/_to_str.py +32 -0
  184. scitex/dsp/__init__.py +72 -0
  185. scitex/dsp/_crop.py +122 -0
  186. scitex/dsp/_demo_sig.py +331 -0
  187. scitex/dsp/_detect_ripples.py +212 -0
  188. scitex/dsp/_ensure_3d.py +18 -0
  189. scitex/dsp/_hilbert.py +78 -0
  190. scitex/dsp/_listen.py +702 -0
  191. scitex/dsp/_misc.py +30 -0
  192. scitex/dsp/_mne.py +32 -0
  193. scitex/dsp/_modulation_index.py +79 -0
  194. scitex/dsp/_pac.py +319 -0
  195. scitex/dsp/_psd.py +102 -0
  196. scitex/dsp/_resample.py +65 -0
  197. scitex/dsp/_time.py +36 -0
  198. scitex/dsp/_transform.py +68 -0
  199. scitex/dsp/_wavelet.py +212 -0
  200. scitex/dsp/add_noise.py +111 -0
  201. scitex/dsp/example.py +253 -0
  202. scitex/dsp/filt.py +155 -0
  203. scitex/dsp/norm.py +18 -0
  204. scitex/dsp/params.py +51 -0
  205. scitex/dsp/reference.py +43 -0
  206. scitex/dsp/template.py +25 -0
  207. scitex/dsp/utils/__init__.py +15 -0
  208. scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
  209. scitex/dsp/utils/_ensure_3d.py +18 -0
  210. scitex/dsp/utils/_ensure_even_len.py +10 -0
  211. scitex/dsp/utils/_zero_pad.py +48 -0
  212. scitex/dsp/utils/filter.py +408 -0
  213. scitex/dsp/utils/pac.py +177 -0
  214. scitex/dt/__init__.py +8 -0
  215. scitex/dt/_linspace.py +130 -0
  216. scitex/etc/__init__.py +15 -0
  217. scitex/etc/wait_key.py +34 -0
  218. scitex/gen/_DimHandler.py +196 -0
  219. scitex/gen/_TimeStamper.py +244 -0
  220. scitex/gen/__init__.py +95 -0
  221. scitex/gen/_alternate_kwarg.py +13 -0
  222. scitex/gen/_cache.py +11 -0
  223. scitex/gen/_check_host.py +34 -0
  224. scitex/gen/_ci.py +12 -0
  225. scitex/gen/_close.py +222 -0
  226. scitex/gen/_embed.py +78 -0
  227. scitex/gen/_inspect_module.py +257 -0
  228. scitex/gen/_is_ipython.py +12 -0
  229. scitex/gen/_less.py +48 -0
  230. scitex/gen/_list_packages.py +139 -0
  231. scitex/gen/_mat2py.py +88 -0
  232. scitex/gen/_norm.py +170 -0
  233. scitex/gen/_paste.py +18 -0
  234. scitex/gen/_print_config.py +84 -0
  235. scitex/gen/_shell.py +48 -0
  236. scitex/gen/_src.py +111 -0
  237. scitex/gen/_start.py +451 -0
  238. scitex/gen/_symlink.py +55 -0
  239. scitex/gen/_symlog.py +27 -0
  240. scitex/gen/_tee.py +238 -0
  241. scitex/gen/_title2path.py +60 -0
  242. scitex/gen/_title_case.py +88 -0
  243. scitex/gen/_to_even.py +84 -0
  244. scitex/gen/_to_odd.py +34 -0
  245. scitex/gen/_to_rank.py +39 -0
  246. scitex/gen/_transpose.py +37 -0
  247. scitex/gen/_type.py +78 -0
  248. scitex/gen/_var_info.py +73 -0
  249. scitex/gen/_wrap.py +17 -0
  250. scitex/gen/_xml2dict.py +76 -0
  251. scitex/gen/misc.py +730 -0
  252. scitex/gen/path.py +0 -0
  253. scitex/general/__init__.py +5 -0
  254. scitex/gists/_SigMacro_processFigure_S.py +128 -0
  255. scitex/gists/_SigMacro_toBlue.py +172 -0
  256. scitex/gists/__init__.py +12 -0
  257. scitex/io/_H5Explorer.py +292 -0
  258. scitex/io/__init__.py +82 -0
  259. scitex/io/_cache.py +101 -0
  260. scitex/io/_flush.py +24 -0
  261. scitex/io/_glob.py +103 -0
  262. scitex/io/_json2md.py +113 -0
  263. scitex/io/_load.py +168 -0
  264. scitex/io/_load_configs.py +146 -0
  265. scitex/io/_load_modules/__init__.py +38 -0
  266. scitex/io/_load_modules/_catboost.py +66 -0
  267. scitex/io/_load_modules/_con.py +20 -0
  268. scitex/io/_load_modules/_db.py +24 -0
  269. scitex/io/_load_modules/_docx.py +42 -0
  270. scitex/io/_load_modules/_eeg.py +110 -0
  271. scitex/io/_load_modules/_hdf5.py +196 -0
  272. scitex/io/_load_modules/_image.py +19 -0
  273. scitex/io/_load_modules/_joblib.py +19 -0
  274. scitex/io/_load_modules/_json.py +18 -0
  275. scitex/io/_load_modules/_markdown.py +103 -0
  276. scitex/io/_load_modules/_matlab.py +37 -0
  277. scitex/io/_load_modules/_numpy.py +39 -0
  278. scitex/io/_load_modules/_optuna.py +155 -0
  279. scitex/io/_load_modules/_pandas.py +69 -0
  280. scitex/io/_load_modules/_pdf.py +31 -0
  281. scitex/io/_load_modules/_pickle.py +24 -0
  282. scitex/io/_load_modules/_torch.py +16 -0
  283. scitex/io/_load_modules/_txt.py +126 -0
  284. scitex/io/_load_modules/_xml.py +49 -0
  285. scitex/io/_load_modules/_yaml.py +23 -0
  286. scitex/io/_mv_to_tmp.py +19 -0
  287. scitex/io/_path.py +286 -0
  288. scitex/io/_reload.py +78 -0
  289. scitex/io/_save.py +539 -0
  290. scitex/io/_save_modules/__init__.py +66 -0
  291. scitex/io/_save_modules/_catboost.py +22 -0
  292. scitex/io/_save_modules/_csv.py +89 -0
  293. scitex/io/_save_modules/_excel.py +49 -0
  294. scitex/io/_save_modules/_hdf5.py +249 -0
  295. scitex/io/_save_modules/_html.py +48 -0
  296. scitex/io/_save_modules/_image.py +140 -0
  297. scitex/io/_save_modules/_joblib.py +25 -0
  298. scitex/io/_save_modules/_json.py +25 -0
  299. scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
  300. scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
  301. scitex/io/_save_modules/_matlab.py +24 -0
  302. scitex/io/_save_modules/_mp4.py +29 -0
  303. scitex/io/_save_modules/_numpy.py +57 -0
  304. scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
  305. scitex/io/_save_modules/_pickle.py +45 -0
  306. scitex/io/_save_modules/_plotly.py +27 -0
  307. scitex/io/_save_modules/_text.py +23 -0
  308. scitex/io/_save_modules/_torch.py +26 -0
  309. scitex/io/_save_modules/_yaml.py +29 -0
  310. scitex/life/__init__.py +10 -0
  311. scitex/life/_monitor_rain.py +49 -0
  312. scitex/linalg/__init__.py +17 -0
  313. scitex/linalg/_distance.py +63 -0
  314. scitex/linalg/_geometric_median.py +64 -0
  315. scitex/linalg/_misc.py +73 -0
  316. scitex/nn/_AxiswiseDropout.py +27 -0
  317. scitex/nn/_BNet.py +126 -0
  318. scitex/nn/_BNet_Res.py +164 -0
  319. scitex/nn/_ChannelGainChanger.py +44 -0
  320. scitex/nn/_DropoutChannels.py +50 -0
  321. scitex/nn/_Filters.py +489 -0
  322. scitex/nn/_FreqGainChanger.py +110 -0
  323. scitex/nn/_GaussianFilter.py +48 -0
  324. scitex/nn/_Hilbert.py +111 -0
  325. scitex/nn/_MNet_1000.py +157 -0
  326. scitex/nn/_ModulationIndex.py +221 -0
  327. scitex/nn/_PAC.py +414 -0
  328. scitex/nn/_PSD.py +40 -0
  329. scitex/nn/_ResNet1D.py +120 -0
  330. scitex/nn/_SpatialAttention.py +25 -0
  331. scitex/nn/_Spectrogram.py +161 -0
  332. scitex/nn/_SwapChannels.py +50 -0
  333. scitex/nn/_TransposeLayer.py +19 -0
  334. scitex/nn/_Wavelet.py +183 -0
  335. scitex/nn/__init__.py +63 -0
  336. scitex/os/__init__.py +8 -0
  337. scitex/os/_mv.py +50 -0
  338. scitex/parallel/__init__.py +8 -0
  339. scitex/parallel/_run.py +151 -0
  340. scitex/path/__init__.py +33 -0
  341. scitex/path/_clean.py +52 -0
  342. scitex/path/_find.py +108 -0
  343. scitex/path/_get_module_path.py +51 -0
  344. scitex/path/_get_spath.py +35 -0
  345. scitex/path/_getsize.py +18 -0
  346. scitex/path/_increment_version.py +87 -0
  347. scitex/path/_mk_spath.py +51 -0
  348. scitex/path/_path.py +19 -0
  349. scitex/path/_split.py +23 -0
  350. scitex/path/_this_path.py +19 -0
  351. scitex/path/_version.py +101 -0
  352. scitex/pd/__init__.py +41 -0
  353. scitex/pd/_find_indi.py +126 -0
  354. scitex/pd/_find_pval.py +113 -0
  355. scitex/pd/_force_df.py +154 -0
  356. scitex/pd/_from_xyz.py +71 -0
  357. scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
  358. scitex/pd/_melt_cols.py +81 -0
  359. scitex/pd/_merge_columns.py +221 -0
  360. scitex/pd/_mv.py +63 -0
  361. scitex/pd/_replace.py +62 -0
  362. scitex/pd/_round.py +93 -0
  363. scitex/pd/_slice.py +63 -0
  364. scitex/pd/_sort.py +91 -0
  365. scitex/pd/_to_numeric.py +53 -0
  366. scitex/pd/_to_xy.py +59 -0
  367. scitex/pd/_to_xyz.py +110 -0
  368. scitex/plt/__init__.py +36 -0
  369. scitex/plt/_subplots/_AxesWrapper.py +182 -0
  370. scitex/plt/_subplots/_AxisWrapper.py +249 -0
  371. scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
  372. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
  373. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
  374. scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
  375. scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
  376. scitex/plt/_subplots/_FigWrapper.py +226 -0
  377. scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
  378. scitex/plt/_subplots/__init__.py +111 -0
  379. scitex/plt/_subplots/_export_as_csv.py +232 -0
  380. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
  381. scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
  382. scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
  383. scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
  384. scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
  385. scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
  386. scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
  387. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
  388. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
  389. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
  390. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
  391. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
  392. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
  393. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
  394. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
  395. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
  396. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
  397. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
  398. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
  399. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
  400. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
  401. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
  402. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
  403. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
  404. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
  405. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
  406. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
  407. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
  408. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
  409. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
  410. scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
  411. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
  412. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
  413. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
  414. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
  415. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
  416. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
  417. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
  418. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
  419. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
  420. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
  421. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
  422. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
  423. scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
  424. scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
  425. scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
  426. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
  427. scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
  428. scitex/plt/_tpl.py +28 -0
  429. scitex/plt/ax/__init__.py +114 -0
  430. scitex/plt/ax/_plot/__init__.py +53 -0
  431. scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
  432. scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
  433. scitex/plt/ax/_plot/_plot_cube.py +57 -0
  434. scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
  435. scitex/plt/ax/_plot/_plot_fillv.py +55 -0
  436. scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
  437. scitex/plt/ax/_plot/_plot_image.py +94 -0
  438. scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
  439. scitex/plt/ax/_plot/_plot_raster.py +172 -0
  440. scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
  441. scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
  442. scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
  443. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
  444. scitex/plt/ax/_plot/_plot_violin.py +343 -0
  445. scitex/plt/ax/_style/__init__.py +38 -0
  446. scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
  447. scitex/plt/ax/_style/_add_panel.py +92 -0
  448. scitex/plt/ax/_style/_extend.py +64 -0
  449. scitex/plt/ax/_style/_force_aspect.py +37 -0
  450. scitex/plt/ax/_style/_format_label.py +23 -0
  451. scitex/plt/ax/_style/_hide_spines.py +84 -0
  452. scitex/plt/ax/_style/_map_ticks.py +182 -0
  453. scitex/plt/ax/_style/_rotate_labels.py +215 -0
  454. scitex/plt/ax/_style/_sci_note.py +279 -0
  455. scitex/plt/ax/_style/_set_log_scale.py +299 -0
  456. scitex/plt/ax/_style/_set_meta.py +261 -0
  457. scitex/plt/ax/_style/_set_n_ticks.py +37 -0
  458. scitex/plt/ax/_style/_set_size.py +16 -0
  459. scitex/plt/ax/_style/_set_supxyt.py +116 -0
  460. scitex/plt/ax/_style/_set_ticks.py +276 -0
  461. scitex/plt/ax/_style/_set_xyt.py +121 -0
  462. scitex/plt/ax/_style/_share_axes.py +264 -0
  463. scitex/plt/ax/_style/_shift.py +139 -0
  464. scitex/plt/ax/_style/_show_spines.py +333 -0
  465. scitex/plt/color/_PARAMS.py +70 -0
  466. scitex/plt/color/__init__.py +52 -0
  467. scitex/plt/color/_add_hue_col.py +41 -0
  468. scitex/plt/color/_colors.py +205 -0
  469. scitex/plt/color/_get_colors_from_cmap.py +134 -0
  470. scitex/plt/color/_interpolate.py +29 -0
  471. scitex/plt/color/_vizualize_colors.py +54 -0
  472. scitex/plt/utils/__init__.py +44 -0
  473. scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
  474. scitex/plt/utils/_calc_nice_ticks.py +101 -0
  475. scitex/plt/utils/_close.py +68 -0
  476. scitex/plt/utils/_colorbar.py +96 -0
  477. scitex/plt/utils/_configure_mpl.py +295 -0
  478. scitex/plt/utils/_histogram_utils.py +132 -0
  479. scitex/plt/utils/_im2grid.py +70 -0
  480. scitex/plt/utils/_is_valid_axis.py +78 -0
  481. scitex/plt/utils/_mk_colorbar.py +65 -0
  482. scitex/plt/utils/_mk_patches.py +26 -0
  483. scitex/plt/utils/_scientific_captions.py +638 -0
  484. scitex/plt/utils/_scitex_config.py +223 -0
  485. scitex/reproduce/__init__.py +14 -0
  486. scitex/reproduce/_fix_seeds.py +45 -0
  487. scitex/reproduce/_gen_ID.py +55 -0
  488. scitex/reproduce/_gen_timestamp.py +35 -0
  489. scitex/res/__init__.py +5 -0
  490. scitex/resource/__init__.py +13 -0
  491. scitex/resource/_get_processor_usages.py +281 -0
  492. scitex/resource/_get_specs.py +280 -0
  493. scitex/resource/_log_processor_usages.py +190 -0
  494. scitex/resource/_utils/__init__.py +31 -0
  495. scitex/resource/_utils/_get_env_info.py +481 -0
  496. scitex/resource/limit_ram.py +33 -0
  497. scitex/scholar/__init__.py +24 -0
  498. scitex/scholar/_local_search.py +454 -0
  499. scitex/scholar/_paper.py +244 -0
  500. scitex/scholar/_pdf_downloader.py +325 -0
  501. scitex/scholar/_search.py +393 -0
  502. scitex/scholar/_vector_search.py +370 -0
  503. scitex/scholar/_web_sources.py +457 -0
  504. scitex/stats/__init__.py +31 -0
  505. scitex/stats/_calc_partial_corr.py +17 -0
  506. scitex/stats/_corr_test_multi.py +94 -0
  507. scitex/stats/_corr_test_wrapper.py +115 -0
  508. scitex/stats/_describe_wrapper.py +90 -0
  509. scitex/stats/_multiple_corrections.py +63 -0
  510. scitex/stats/_nan_stats.py +93 -0
  511. scitex/stats/_p2stars.py +116 -0
  512. scitex/stats/_p2stars_wrapper.py +56 -0
  513. scitex/stats/_statistical_tests.py +73 -0
  514. scitex/stats/desc/__init__.py +40 -0
  515. scitex/stats/desc/_describe.py +189 -0
  516. scitex/stats/desc/_nan.py +289 -0
  517. scitex/stats/desc/_real.py +94 -0
  518. scitex/stats/multiple/__init__.py +14 -0
  519. scitex/stats/multiple/_bonferroni_correction.py +72 -0
  520. scitex/stats/multiple/_fdr_correction.py +400 -0
  521. scitex/stats/multiple/_multicompair.py +28 -0
  522. scitex/stats/tests/__corr_test.py +277 -0
  523. scitex/stats/tests/__corr_test_multi.py +343 -0
  524. scitex/stats/tests/__corr_test_single.py +277 -0
  525. scitex/stats/tests/__init__.py +22 -0
  526. scitex/stats/tests/_brunner_munzel_test.py +192 -0
  527. scitex/stats/tests/_nocorrelation_test.py +28 -0
  528. scitex/stats/tests/_smirnov_grubbs.py +98 -0
  529. scitex/str/__init__.py +113 -0
  530. scitex/str/_clean_path.py +75 -0
  531. scitex/str/_color_text.py +52 -0
  532. scitex/str/_decapitalize.py +58 -0
  533. scitex/str/_factor_out_digits.py +281 -0
  534. scitex/str/_format_plot_text.py +498 -0
  535. scitex/str/_grep.py +48 -0
  536. scitex/str/_latex.py +155 -0
  537. scitex/str/_latex_fallback.py +471 -0
  538. scitex/str/_mask_api.py +39 -0
  539. scitex/str/_mask_api_key.py +8 -0
  540. scitex/str/_parse.py +158 -0
  541. scitex/str/_print_block.py +47 -0
  542. scitex/str/_print_debug.py +68 -0
  543. scitex/str/_printc.py +62 -0
  544. scitex/str/_readable_bytes.py +38 -0
  545. scitex/str/_remove_ansi.py +23 -0
  546. scitex/str/_replace.py +134 -0
  547. scitex/str/_search.py +125 -0
  548. scitex/str/_squeeze_space.py +36 -0
  549. scitex/tex/__init__.py +10 -0
  550. scitex/tex/_preview.py +103 -0
  551. scitex/tex/_to_vec.py +116 -0
  552. scitex/torch/__init__.py +18 -0
  553. scitex/torch/_apply_to.py +34 -0
  554. scitex/torch/_nan_funcs.py +77 -0
  555. scitex/types/_ArrayLike.py +44 -0
  556. scitex/types/_ColorLike.py +21 -0
  557. scitex/types/__init__.py +14 -0
  558. scitex/types/_is_listed_X.py +70 -0
  559. scitex/utils/__init__.py +22 -0
  560. scitex/utils/_compress_hdf5.py +116 -0
  561. scitex/utils/_email.py +120 -0
  562. scitex/utils/_grid.py +148 -0
  563. scitex/utils/_notify.py +247 -0
  564. scitex/utils/_search.py +121 -0
  565. scitex/web/__init__.py +38 -0
  566. scitex/web/_search_pubmed.py +438 -0
  567. scitex/web/_summarize_url.py +158 -0
  568. scitex-2.0.0.dist-info/METADATA +307 -0
  569. scitex-2.0.0.dist-info/RECORD +572 -0
  570. scitex-2.0.0.dist-info/WHEEL +6 -0
  571. scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
  572. scitex-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,149 @@
1
+ #!/usr/bin/env python3
2
+ # Time-stamp: "2024-09-07 01:09:38 (ywatanabe)"
3
+
4
+ import os
5
+
6
+ import scitex
7
+ import numpy as np
8
+
9
+
10
class EarlyStopping:
    """
    Early stops the training if the validation score doesn't improve after a
    given patience period.

    Call the instance once per validation round. It returns True when training
    should stop and False otherwise. Whenever the score improves, each model's
    ``state_dict()`` is saved via ``scitex.io.save``.
    """

    def __init__(self, patience=7, verbose=False, delta=1e-5, direction="minimize"):
        """
        Args:
            patience (int): Number of consecutive non-improving calls tolerated
                before stopping is signalled. Default: 7
            verbose (bool): If True, prints a message on every improvement and
                on every non-improving call. Default: False
            delta (float): Minimum change in the monitored score to qualify as
                an improvement; its absolute value is used. Default: 1e-5
            direction (str): "minimize" treats lower scores as better; any
                other value treats higher scores as better. Default: "minimize"
        """
        self.patience = patience
        self.verbose = verbose
        self.direction = direction

        self.delta = delta

        # Running state.
        self.counter = 0
        # Start from the worst possible score so the first finite score is
        # always accepted as an improvement by is_best().
        self.best_score = np.inf if direction == "minimize" else -np.inf
        self.best_i_global = None
        self.models_spaths_dict = {}

    def is_best(self, val_score):
        """Return True if ``val_score`` beats ``best_score`` by more than ``|delta|``."""
        margin = abs(self.delta)
        if self.direction == "minimize":
            return val_score < self.best_score - margin
        return self.best_score + margin < val_score

    def __call__(self, current_score, models_spaths_dict, i_global):
        """
        Update the early-stopping state with a new validation score.

        Args:
            current_score: Score of the current validation round.
            models_spaths_dict (dict): Mapping of model -> save path; each
                model's ``state_dict()`` is saved when the score improves.
            i_global: Global step index, recorded as ``best_i_global`` on
                improvement.

        Returns:
            bool: True if training should stop, False otherwise.
        """
        # Defensive guard: the constructor initialises best_score to +/-inf,
        # so this only triggers if best_score was reset to None externally.
        if self.best_score is None:
            self.save(current_score, models_spaths_dict, i_global)
            return False

        if self.is_best(current_score):
            self.save(current_score, models_spaths_dict, i_global)
            self.counter = 0
            return False

        self.counter += 1
        if self.verbose:
            print(
                f"\nEarlyStopping counter: {self.counter} out of {self.patience}\n"
            )
        if self.counter >= self.patience:
            if self.verbose:
                # NOTE(review): print_block is defined in
                # scitex/str/_print_block.py; confirm scitex.gen re-exports it.
                scitex.gen.print_block("Early-stopped.", c="yellow")
            return True
        # Patience not yet exhausted: return an explicit bool instead of
        # falling through to an implicit None.
        return False

    def save(self, current_score, models_spaths_dict, i_global):
        """Record ``current_score`` as the new best and save every model's state_dict."""

        if self.verbose:
            # Guard the format spec: ':.6f' cannot format None.
            prev = "None" if self.best_score is None else f"{self.best_score:.6f}"
            print(f"\nUpdate the best score: ({prev} --> {current_score:.6f})")

        self.best_score = current_score
        self.best_i_global = i_global

        for model, spath in models_spaths_dict.items():
            scitex.io.save(model.state_dict(), spath)

        self.models_spaths_dict = models_spaths_dict
81
+
82
+
83
if __name__ == "__main__":
    pass
    # NOTE(review): the usage sketch below is not runnable as-is. It relies on
    # names not defined in this file (utils, dlf, model, mtl, optimizer,
    # merged_conf, args, tqdm, i_fold, device). `scitex.ml` does not appear in
    # this package's file list; the logger module lives under scitex/ai
    # (scitex/ai/_LearningCurveLogger.py) — confirm the import path.
    # The early-stopping calls at the end match EarlyStopping's signature in
    # this file: `early_stopping(score, models_spaths_dict, i_global)`. That
    # call returns True when training should stop. `models_spaths_dict` maps
    # model -> save path, because EarlyStopping.save iterates
    # `for model, spath in ...items()`.
    #
    # # starts the current fold's loop
    # i_global = 0
    # lc_logger = scitex.ml.LearningCurveLogger()
    # early_stopping = utils.EarlyStopping(patience=50, verbose=True)
    # for i_epoch, epoch in enumerate(tqdm(range(merged_conf["MAX_EPOCHS"]))):
    #
    #     dlf.fill(i_fold, reset_fill_counter=False)
    #
    #     step_str = "Validation"
    #     for i_batch, batch in enumerate(dlf.dl_val):
    #         _, loss_diag_val = utils.base_step(
    #             step_str,
    #             model,
    #             mtl,
    #             batch,
    #             device,
    #             i_fold,
    #             i_epoch,
    #             i_batch,
    #             i_global,
    #             lc_logger,
    #             no_mtl=args.no_mtl,
    #             print_batch_interval=False,
    #         )
    #     lc_logger.print(step_str)
    #
    #     step_str = "Training"
    #     for i_batch, batch in enumerate(dlf.dl_tra):
    #         optimizer.zero_grad()
    #         loss, _ = utils.base_step(
    #             step_str,
    #             model,
    #             mtl,
    #             batch,
    #             device,
    #             i_fold,
    #             i_epoch,
    #             i_batch,
    #             i_global,
    #             lc_logger,
    #             no_mtl=args.no_mtl,
    #             print_batch_interval=False,
    #         )
    #         loss.backward()
    #         optimizer.step()
    #         i_global += 1
    #     lc_logger.print(step_str)
    #
    #     bACC_val = np.array(lc_logger.logged_dict["Validation"]["bACC_diag_plot"])[
    #         np.array(lc_logger.logged_dict["Validation"]["i_epoch"]) == i_epoch
    #     ].mean()
    #
    #     model_spath = (
    #         merged_conf["sdir"]
    #         + f"checkpoints/model_fold#{i_fold}_epoch#{i_epoch:03d}.pth"
    #     )
    #     mtl_spath = model_spath.replace("model_fold", "mtl_fold")
    #     models_spaths_dict = {model: model_spath, mtl: mtl_spath}
    #
    #     if early_stopping(loss_diag_val, models_spaths_dict, i_global):
    #         # or: early_stopping(-bACC_val, models_spaths_dict, i_global)
    #         print("Early stopping")
    #         break
@@ -0,0 +1,56 @@
1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-11-20 10:53:22 (ywatanabe)"
# File: ./scitex_repo/src/scitex/ai/feature_extraction/__init__.py

THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/feature_extraction/__init__.py"

# Package initializer. It imports every sibling ``*.py`` module except
# dunder files such as this one, then re-exports each module's public
# functions and classes at package level.
# A module that raises ImportError (e.g. a missing dependency) is skipped
# with an ImportWarning, so the rest of the package still imports.


def _import_public_members():
    """Populate this package's namespace from its sibling modules.

    Files are visited in sorted name order. ``os.listdir`` returns entries
    in arbitrary order, so sorting makes the result deterministic when two
    modules export the same name: the later file (by name) wins.

    All temporaries are function locals. Nothing needs to be cleaned out of
    the package namespace afterwards, and a re-exported name can never be
    deleted by accident.
    """
    import importlib
    import inspect
    import os
    import warnings

    package_globals = globals()
    current_dir = os.path.dirname(__file__)

    for filename in sorted(os.listdir(current_dir)):
        if not filename.endswith(".py") or filename.startswith("__"):
            continue
        module_name = filename[:-3]  # Remove .py extension
        try:
            module = importlib.import_module(f".{module_name}", package=__name__)
        except ImportError as e:
            # Warn about modules that couldn't be imported due to missing dependencies
            warnings.warn(
                f"Could not import {module_name} from scitex.ai.feature_extraction: {str(e)}. "
                "Some functionality may be unavailable. "
                "Consider installing missing dependencies if you need this module.",
                ImportWarning,
                # This call sits one frame deeper than the original
                # module-level warn(stacklevel=2), so use 3 to keep the
                # warning attributed to the same frame.
                stacklevel=3,
            )
            continue

        # Re-export only public functions and classes. This includes ones
        # the module itself imported from elsewhere (getmembers does not
        # filter by defining module).
        for name, obj in inspect.getmembers(module):
            if name.startswith("_"):
                continue
            if inspect.isfunction(obj) or inspect.isclass(obj):
                package_globals[name] = obj


_import_public_members()
# Drop the helper so only THIS_FILE and the re-exported names remain.
del _import_public_members

# EOF
@@ -0,0 +1,148 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-11-27 21:36:51 (ywatanabe)"
4
+ # File: ./scitex_repo/src/scitex/ai/feature_extraction/vit.py
5
+
6
+ THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/ai/feature_extraction/vit.py"
7
+
8
+ """
9
+ Functionality:
10
+ Extracts features from images using Vision Transformer (ViT) models
11
+ Input:
12
+ Image arrays of arbitrary dimensions
13
+ Output:
14
+ Feature vectors (1000-dimensional embeddings)
15
+ Prerequisites:
16
+ torch, PIL, torchvision
17
+ """
18
+
19
+ import os as _os
20
+ from typing import Tuple, Union
21
+
22
+ import torch
23
+ import torch as _torch
24
+ from pytorch_pretrained_vit import ViT
25
+ from torchvision import transforms as _transforms
26
+
27
+ # from ...decorators import batch_torch_fn
28
+
29
+
30
def _setup_device(device: Union[str, None]) -> str:
    """Resolve which torch device string to use.

    Args:
        device: An explicit device string, or ``None`` to auto-select.

    Returns:
        ``device`` itself when it is not ``None``. Otherwise ``"cuda"`` if
        CUDA is available, else ``"cpu"``.
    """
    if device is not None:
        return device
    return "cuda" if _torch.cuda.is_available() else "cpu"
34
+
35
+
36
class VitFeatureExtractor:
    """Extract Vision Transformer (ViT) features from 2-D slices of an array.

    Preprocessing of each 2-D slice selected by ``axis``:

    1. Treat the slice as a single-channel image and replicate it to three
       channels.
    2. Convert it to a PIL image.
    3. Resize it to the model's input size.
    4. Convert it back to a tensor and normalize it with mean 0.5 / std 0.5.

    The preprocessed images are then passed through a pretrained
    ``pytorch_pretrained_vit.ViT`` model.

    Args:
        model_name: Model variant; must be one of ``valid_models``.
        torch_home: Existing directory that is exported as ``TORCH_HOME``
            before the pretrained model is loaded.
        device: Torch device string. ``None`` picks ``"cuda"`` when
            available, otherwise ``"cpu"``.

    Raises:
        ValueError: If ``model_name`` is not supported.
        FileNotFoundError: If ``torch_home`` does not exist.
    """

    def __init__(
        self,
        model_name="B_16",
        torch_home="./models",
        device=None,
    ):
        self.valid_models = [
            "B_16",
            "B_32",
            "L_16",
            "L_32",
            "B_16_imagenet1k",
            "B_32_imagenet1k",
            "L_16_imagenet1k",
            "L_32_imagenet1k",
        ]
        self.model_name = model_name
        self.torch_home = torch_home

        # Validate before touching process-wide state, so a rejected
        # configuration does not leave a modified TORCH_HOME behind.
        self._validate_inputs()

        self.device = _setup_device(device)
        # Presumably read by the weight-download machinery used by ViT(...,
        # pretrained=True); TODO confirm against pytorch_pretrained_vit.
        _os.environ["TORCH_HOME"] = torch_home
        self.model = ViT(model_name, pretrained=True).to(self.device).eval()
        self.transform = _transforms.Compose(
            [
                _transforms.ToPILImage(),
                _transforms.Resize(self.model.image_size),
                _transforms.ToTensor(),
                _transforms.Normalize(0.5, 0.5),
            ]
        )

    def _validate_inputs(self):
        """Check the model name and model directory.

        Raises:
            ValueError: If ``self.model_name`` is not in ``self.valid_models``.
            FileNotFoundError: If ``self.torch_home`` does not exist.
        """
        if self.model_name not in self.valid_models:
            raise ValueError(f"Invalid model name. Choose from: {self.valid_models}")
        if not _os.path.exists(self.torch_home):
            raise FileNotFoundError(f"Model directory not found: {self.torch_home}")

    def _preprocess_array(
        self,
        arr,
        dim: Tuple[int, int],
        channel_dim: Union[int, None],
    ) -> Tuple[_torch.Tensor, _torch.Size]:
        """Turn every 2-D slice of ``arr`` into a model-ready 3-channel image.

        Args:
            arr: Tensor or array-like (e.g. a NumPy array) of rank >= 2.
            dim: The two axes that form each image. Negative values are
                allowed.
            channel_dim: Currently unused.

        Returns:
            A pair of values:

            - The stacked transformed images, shaped
              ``(n_images, 3, H', W')``.
            - The shape of the remaining (non-image) axes, used later to
              restore the layout.
        """
        # Accept NumPy arrays and other array-likes as well as tensors.
        # as_tensor returns an existing tensor unchanged.
        arr = _torch.as_tensor(arr)
        ndim = arr.ndim
        dim = tuple(d if d >= 0 else ndim + d for d in dim)

        # Move the image axes to the end in ascending axis order. The
        # remaining axes keep their original relative order.
        perm = [d for d in range(ndim) if d not in dim] + sorted(dim)
        arr = arr.permute(perm)

        # Flatten all dimensions except the last two (spatial dimensions)
        batch_shape = arr.shape[:-2]
        images = arr.reshape(-1, *arr.shape[-2:])

        # Grayscale -> 3 channels, then apply the model's transform.
        # NOTE(review): ToPILImage presumably expects float input in [0, 1];
        # confirm the intended value range with callers.
        transformed = [
            self.transform(img.unsqueeze(0).repeat(3, 1, 1)) for img in images
        ]
        return _torch.stack(transformed), batch_shape

    # @batch_method
    # @torch_method
    # @batch_torch_fn
    def extract_features(
        self,
        arr,
        axis=(-2, -1),
        dim=None,
        channel_dim=None,
        batch_size=None,
        device="cuda",
    ):
        """Return ViT features for every 2-D slice of ``arr``.

        Args:
            arr: Tensor or array-like (e.g. a NumPy array) of rank >= 2.
            axis: The two axes that form each image.
            dim: Unused; kept for backward compatibility.
            channel_dim: Forwarded to preprocessing, which currently ignores
                it.
            batch_size: Unused; all slices go through a single forward pass.
            device: Unused; the device chosen at construction
                (``self.device``) is always used.

        Returns:
            CPU tensor of shape ``(*batch_shape, n_features)``.
            ``batch_shape`` is ``arr.shape`` with the ``axis`` dimensions
            removed. ``n_features`` is the model's output width.
        """
        processed_arr, batch_shape = self._preprocess_array(
            arr,
            axis,
            channel_dim,
        )

        processed_arr = processed_arr.to(self.device)
        with _torch.no_grad():
            features = self.model(processed_arr).cpu()

        return features.reshape(*batch_shape, -1)
131
+
132
+
133
if __name__ == "__main__":
    # Local import: NumPy is only needed by this demo.
    import numpy as np

    import scitex

    extractor = scitex.ai.feature_extraction.VitFeatureExtractor(
        model_name="B_16_imagenet1k"
    )
    tensor = torch.randn(3, 2, 4, 5, 32, 32)
    processed = extractor.extract_features(tensor, (-2, -1), None)
    print(processed.shape)
    # torch.Size([3, 2, 4, 5, n_features]): the two image axes are replaced
    # by the model's feature axis (the module docstring states 1000).

    # NumPy input, converted explicitly so the demo does not depend on
    # extract_features accepting non-tensor inputs.
    arr = np.random.rand(3, 2, 4, 5, 32, 32)
    processed = extractor.extract_features(torch.as_tensor(arr), (-2, -1), None)
    print(processed.shape)

# EOF
@@ -0,0 +1,277 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-11-25 12:00:00"
4
+ # Author: Yusuke Watanabe (ywatanabe@alumni.u-tokyo.ac.jp)
5
+ # scitex/src/scitex/ai/genai/__init__.py
6
+
7
+ """
8
+ GenAI module for unified access to multiple AI providers.
9
+
10
+ This module provides a consistent interface for interacting with various
11
+ AI providers (OpenAI, Anthropic, Google, etc.) with built-in cost tracking,
12
+ chat history management, and error handling.
13
+ """
14
+
15
+ from typing import List, Dict, Any, Optional, Union
16
+ import logging
17
+
18
+ from .provider_factory import Provider, create_provider, GenAI as GenAIFactory
19
+ from .auth_manager import AuthManager
20
+ from .chat_history import ChatHistory
21
+ from .cost_tracker import CostTracker
22
+ from .response_handler import ResponseHandler
23
+ from .base_provider import BaseProvider, CompletionResponse
24
+
25
+ # Import legacy providers for backward compatibility
26
+ from .anthropic import Anthropic
27
+ from .openai import OpenAI
28
+ from .google import Google
29
+ from .groq import Groq
30
+ from .deepseek import DeepSeek
31
+ from .llama import Llama
32
+ from .perplexity import Perplexity
33
+ from .genai_factory import genai_factory
34
+
35
# Module-level logger named after this module; used by GenAI for its
# info/error messages.
logger = logging.getLogger(__name__)
36
+
37
+
38
class GenAI:
    """
    Unified interface for multiple AI providers.

    This class provides a consistent API for interacting with various AI providers
    while handling authentication, chat history, cost tracking, and response processing.

    Args:
        provider: Provider name (e.g., 'openai', 'anthropic', 'google')
        api_key: Optional API key (if not provided, will use environment variable)
        model: Model name (if not provided, will use provider's default)
        system_prompt: Optional system prompt to prepend to conversations
        **kwargs: Additional provider-specific configuration

    Example:
        >>> from scitex.ai.genai import GenAI
        >>>
        >>> # Basic usage
        >>> ai = GenAI(provider="openai")
        >>> response = ai.complete("What is the capital of France?")
        >>> print(response)
        "The capital of France is Paris."
        >>>
        >>> # With specific model and system prompt
        >>> ai = GenAI(
        ...     provider="anthropic",
        ...     model="claude-3-opus-20240229",
        ...     system_prompt="You are a helpful geography expert."
        ... )
        >>> response = ai.complete("Tell me about Paris.")
        >>>
        >>> # Check costs
        >>> print(ai.get_cost_summary())
        "Total cost: $0.015 | Requests: 2 | Tokens: 1,234"
    """

    def __init__(
        self,
        provider: Union[str, Provider],
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        **kwargs
    ):
        """Initialize GenAI with specified provider.

        Steps:

        1. Resolve the provider name.
        2. Build the auth, history and response helpers.
        3. Create the provider via ``create_provider``.
        4. Create a ``CostTracker`` for the resolved model.
        5. If ``system_prompt`` is given, seed the history with it.
        """
        # Normalize to a lowercase string name. Non-str values are treated
        # as Provider enum members and their ``.value`` is used.
        if isinstance(provider, str):
            self.provider_name = provider.lower()
        else:
            self.provider_name = provider.value

        # Initialize components
        self.auth_manager = AuthManager(api_key, self.provider_name)
        self.chat_history = ChatHistory(n_keep=-1)  # Keep all messages by default
        self.response_handler = ResponseHandler()

        # Get API key from auth manager if not provided
        if api_key is None:
            api_key = self.auth_manager.api_key

        # Create provider instance
        self.provider = create_provider(
            provider=self.provider_name, api_key=api_key, model=model, **kwargs
        )

        # Prefer the model the provider actually settled on (it may apply a
        # default), then the requested one, then a placeholder.
        actual_model = getattr(self.provider, "model", None) or model or "unknown"
        self.cost_tracker = CostTracker(provider=self.provider_name, model=actual_model)

        # Add system prompt if provided
        if system_prompt:
            self.chat_history.add_message("system", system_prompt)

        logger.info("Initialized GenAI with provider: %s", self.provider_name)

    def complete(
        self, prompt: str, images: Optional[List[str]] = None, **kwargs
    ) -> str:
        """
        Generate a completion for the given prompt.

        The prompt (and any images) is appended to the chat history. The
        whole history is sent to the provider. The reply is appended to the
        history, and the token usage is added to the cost tracker.

        Args:
            prompt: The input prompt
            images: Optional list of image URLs or base64 strings
            **kwargs: Additional provider-specific parameters

        Returns:
            The generated response text

        Raises:
            Exception: Whatever the provider's ``complete`` raises; it is
                logged and re-raised unchanged.
        """
        # NOTE(review): if the provider call below raises, this user message
        # stays in the history without an assistant reply. It will
        # presumably be resent on the next call. Consider rolling it back
        # once ChatHistory exposes a way to remove messages.
        self.chat_history.add_message("user", prompt, images)

        # Get messages for API call
        messages = [msg.to_dict() for msg in self.chat_history.get_messages()]

        # Call provider
        try:
            response: CompletionResponse = self.provider.complete(
                messages=messages, **kwargs
            )
        except Exception as e:
            logger.error("Provider %s failed: %s", self.provider_name, e)
            raise

        # Process response - CompletionResponse has a content attribute
        content = response.content

        # Add assistant message to history
        self.chat_history.add_message("assistant", content)

        # Track costs - CompletionResponse has input_tokens and output_tokens
        self.cost_tracker.update(
            input_tokens=response.input_tokens, output_tokens=response.output_tokens
        )

        return content

    def complete_async(self, prompt: str, images: Optional[List[str]] = None, **kwargs):
        """
        Async version of complete method (not implemented yet).

        Args:
            prompt: The input prompt
            images: Optional list of image URLs or base64 strings
            **kwargs: Additional provider-specific parameters

        Raises:
            NotImplementedError: Always; async completion is not available yet.
        """
        raise NotImplementedError("Async completion not yet implemented")

    def stream(self, prompt: str, images: Optional[List[str]] = None, **kwargs):
        """
        Stream completions for the given prompt (not implemented yet).

        Args:
            prompt: The input prompt
            images: Optional list of image URLs or base64 strings
            **kwargs: Additional provider-specific parameters

        Raises:
            NotImplementedError: Always; streaming is not available yet.
        """
        raise NotImplementedError("Streaming not yet implemented")

    def clear_history(self):
        """Clear the chat history."""
        self.chat_history.clear()
        logger.info("Chat history cleared")

    def get_history(self) -> List[Dict[str, str]]:
        """Get the current chat history.

        NOTE(review): this returns ``self.chat_history.messages`` as stored.
        Elsewhere, history items are converted with ``.to_dict()`` before
        use, so the elements may be message objects rather than plain
        dicts. Verify against ChatHistory.
        """
        return self.chat_history.messages

    def get_cost_summary(self) -> str:
        """Get a summary of costs incurred."""
        return self.cost_tracker.get_summary()

    def get_detailed_costs(self) -> Dict[str, Any]:
        """Get a detailed cost breakdown read from the cost tracker's attributes."""
        return {
            "total_cost": self.cost_tracker.total_cost,
            "total_prompt_tokens": self.cost_tracker.total_prompt_tokens,
            "total_completion_tokens": self.cost_tracker.total_completion_tokens,
            "request_count": self.cost_tracker.request_count,
            "cost_by_model": self.cost_tracker.cost_by_model,
        }

    def reset_costs(self):
        """Reset cost tracking."""
        self.cost_tracker.reset()
        logger.info("Cost tracking reset")

    def __repr__(self) -> str:
        """String representation of GenAI instance."""
        # Same defensive access as __init__: a provider without a ``model``
        # attribute must not make repr() raise.
        model = getattr(self.provider, "model", "unknown")
        return (
            f"GenAI(provider='{self.provider_name}', "
            f"model='{model}', "
            f"requests={self.cost_tracker.request_count})"
        )
223
+
224
+
225
+ # Convenience function for one-off completions
226
def complete(
    prompt: str,
    provider: Union[str, Provider] = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs
) -> str:
    """
    Run a single completion with a throwaway GenAI session.

    A fresh ``GenAI`` is built for every call, so no chat history or cost
    totals carry over between calls.

    Args:
        prompt: Text sent as the user message
        provider: Provider name or Provider enum member
        model: Optional model name (provider default when omitted)
        api_key: Optional API key (resolved by GenAI when omitted)
        **kwargs: Forwarded to ``GenAI.complete``

    Returns:
        The generated response text

    Example:
        >>> from scitex.ai.genai import complete
        >>> response = complete("What is 2+2?", provider="anthropic")
        >>> print(response)
        "2 + 2 = 4"
    """
    return GenAI(provider=provider, model=model, api_key=api_key).complete(
        prompt, **kwargs
    )
254
+
255
+
256
+ # Export public API
257
# Names exported by ``from scitex.ai.genai import *``.
# NOTE(review): BaseProvider and CompletionResponse are imported at the top
# of this module but are not listed here; add them if they are meant to be
# part of the public API.
__all__ = [
    # New API
    "GenAI",
    "GenAIFactory",
    "complete",
    "Provider",
    "create_provider",
    "AuthManager",
    "ChatHistory",
    "CostTracker",
    "ResponseHandler",
    # Legacy API for backward compatibility
    "genai_factory",
    "Anthropic",
    "OpenAI",
    "Google",
    "Groq",
    "DeepSeek",
    "Llama",
    "Perplexity",
]