scitex-2.0.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (572)
  1. scitex/__init__.py +73 -0
  2. scitex/__main__.py +89 -0
  3. scitex/__version__.py +14 -0
  4. scitex/_sh.py +59 -0
  5. scitex/ai/_LearningCurveLogger.py +583 -0
  6. scitex/ai/__Classifiers.py +101 -0
  7. scitex/ai/__init__.py +55 -0
  8. scitex/ai/_gen_ai/_Anthropic.py +173 -0
  9. scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
  10. scitex/ai/_gen_ai/_DeepSeek.py +175 -0
  11. scitex/ai/_gen_ai/_Google.py +161 -0
  12. scitex/ai/_gen_ai/_Groq.py +97 -0
  13. scitex/ai/_gen_ai/_Llama.py +142 -0
  14. scitex/ai/_gen_ai/_OpenAI.py +230 -0
  15. scitex/ai/_gen_ai/_PARAMS.py +565 -0
  16. scitex/ai/_gen_ai/_Perplexity.py +191 -0
  17. scitex/ai/_gen_ai/__init__.py +32 -0
  18. scitex/ai/_gen_ai/_calc_cost.py +78 -0
  19. scitex/ai/_gen_ai/_format_output_func.py +183 -0
  20. scitex/ai/_gen_ai/_genai_factory.py +71 -0
  21. scitex/ai/act/__init__.py +8 -0
  22. scitex/ai/act/_define.py +11 -0
  23. scitex/ai/classification/__init__.py +7 -0
  24. scitex/ai/classification/classification_reporter.py +1137 -0
  25. scitex/ai/classification/classifier_server.py +131 -0
  26. scitex/ai/classification/classifiers.py +101 -0
  27. scitex/ai/classification_reporter.py +1161 -0
  28. scitex/ai/classifier_server.py +131 -0
  29. scitex/ai/clustering/__init__.py +11 -0
  30. scitex/ai/clustering/_pca.py +115 -0
  31. scitex/ai/clustering/_umap.py +376 -0
  32. scitex/ai/early_stopping.py +149 -0
  33. scitex/ai/feature_extraction/__init__.py +56 -0
  34. scitex/ai/feature_extraction/vit.py +148 -0
  35. scitex/ai/genai/__init__.py +277 -0
  36. scitex/ai/genai/anthropic.py +177 -0
  37. scitex/ai/genai/anthropic_provider.py +320 -0
  38. scitex/ai/genai/anthropic_refactored.py +109 -0
  39. scitex/ai/genai/auth_manager.py +200 -0
  40. scitex/ai/genai/base_genai.py +336 -0
  41. scitex/ai/genai/base_provider.py +291 -0
  42. scitex/ai/genai/calc_cost.py +78 -0
  43. scitex/ai/genai/chat_history.py +307 -0
  44. scitex/ai/genai/cost_tracker.py +276 -0
  45. scitex/ai/genai/deepseek.py +188 -0
  46. scitex/ai/genai/deepseek_provider.py +251 -0
  47. scitex/ai/genai/format_output_func.py +183 -0
  48. scitex/ai/genai/genai_factory.py +71 -0
  49. scitex/ai/genai/google.py +169 -0
  50. scitex/ai/genai/google_provider.py +228 -0
  51. scitex/ai/genai/groq.py +104 -0
  52. scitex/ai/genai/groq_provider.py +248 -0
  53. scitex/ai/genai/image_processor.py +250 -0
  54. scitex/ai/genai/llama.py +155 -0
  55. scitex/ai/genai/llama_provider.py +214 -0
  56. scitex/ai/genai/mock_provider.py +127 -0
  57. scitex/ai/genai/model_registry.py +304 -0
  58. scitex/ai/genai/openai.py +230 -0
  59. scitex/ai/genai/openai_provider.py +293 -0
  60. scitex/ai/genai/params.py +565 -0
  61. scitex/ai/genai/perplexity.py +202 -0
  62. scitex/ai/genai/perplexity_provider.py +205 -0
  63. scitex/ai/genai/provider_base.py +302 -0
  64. scitex/ai/genai/provider_factory.py +370 -0
  65. scitex/ai/genai/response_handler.py +235 -0
  66. scitex/ai/layer/_Pass.py +21 -0
  67. scitex/ai/layer/__init__.py +10 -0
  68. scitex/ai/layer/_switch.py +8 -0
  69. scitex/ai/loss/_L1L2Losses.py +34 -0
  70. scitex/ai/loss/__init__.py +12 -0
  71. scitex/ai/loss/multi_task_loss.py +47 -0
  72. scitex/ai/metrics/__init__.py +9 -0
  73. scitex/ai/metrics/_bACC.py +51 -0
  74. scitex/ai/metrics/silhoute_score_block.py +496 -0
  75. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
  76. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
  77. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
  78. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
  79. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
  80. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
  81. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
  82. scitex/ai/optim/__init__.py +13 -0
  83. scitex/ai/optim/_get_set.py +31 -0
  84. scitex/ai/optim/_optimizers.py +71 -0
  85. scitex/ai/plt/__init__.py +21 -0
  86. scitex/ai/plt/_conf_mat.py +592 -0
  87. scitex/ai/plt/_learning_curve.py +194 -0
  88. scitex/ai/plt/_optuna_study.py +111 -0
  89. scitex/ai/plt/aucs/__init__.py +2 -0
  90. scitex/ai/plt/aucs/example.py +60 -0
  91. scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
  92. scitex/ai/plt/aucs/roc_auc.py +246 -0
  93. scitex/ai/sampling/undersample.py +29 -0
  94. scitex/ai/sk/__init__.py +11 -0
  95. scitex/ai/sk/_clf.py +58 -0
  96. scitex/ai/sk/_to_sktime.py +100 -0
  97. scitex/ai/sklearn/__init__.py +26 -0
  98. scitex/ai/sklearn/clf.py +58 -0
  99. scitex/ai/sklearn/to_sktime.py +100 -0
  100. scitex/ai/training/__init__.py +7 -0
  101. scitex/ai/training/early_stopping.py +150 -0
  102. scitex/ai/training/learning_curve_logger.py +555 -0
  103. scitex/ai/utils/__init__.py +22 -0
  104. scitex/ai/utils/_check_params.py +50 -0
  105. scitex/ai/utils/_default_dataset.py +46 -0
  106. scitex/ai/utils/_format_samples_for_sktime.py +26 -0
  107. scitex/ai/utils/_label_encoder.py +134 -0
  108. scitex/ai/utils/_merge_labels.py +22 -0
  109. scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
  110. scitex/ai/utils/_under_sample.py +51 -0
  111. scitex/ai/utils/_verify_n_gpus.py +16 -0
  112. scitex/ai/utils/grid_search.py +148 -0
  113. scitex/context/__init__.py +9 -0
  114. scitex/context/_suppress_output.py +38 -0
  115. scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
  116. scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
  117. scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
  118. scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
  119. scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
  120. scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
  121. scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
  122. scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
  123. scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
  124. scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
  125. scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
  126. scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
  127. scitex/db/_BaseMixins/__init__.py +30 -0
  128. scitex/db/_PostgreSQL.py +126 -0
  129. scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
  130. scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
  131. scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
  132. scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
  133. scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
  134. scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
  135. scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
  136. scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
  137. scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
  138. scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
  139. scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
  140. scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
  141. scitex/db/_PostgreSQLMixins/__init__.py +34 -0
  142. scitex/db/_SQLite3.py +2136 -0
  143. scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
  144. scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
  145. scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
  146. scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
  147. scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
  148. scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
  149. scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
  150. scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
  151. scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
  152. scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
  153. scitex/db/_SQLite3Mixins/__init__.py +30 -0
  154. scitex/db/__init__.py +14 -0
  155. scitex/db/_delete_duplicates.py +397 -0
  156. scitex/db/_inspect.py +163 -0
  157. scitex/decorators/__init__.py +54 -0
  158. scitex/decorators/_auto_order.py +172 -0
  159. scitex/decorators/_batch_fn.py +127 -0
  160. scitex/decorators/_cache_disk.py +32 -0
  161. scitex/decorators/_cache_mem.py +12 -0
  162. scitex/decorators/_combined.py +98 -0
  163. scitex/decorators/_converters.py +282 -0
  164. scitex/decorators/_deprecated.py +26 -0
  165. scitex/decorators/_not_implemented.py +30 -0
  166. scitex/decorators/_numpy_fn.py +86 -0
  167. scitex/decorators/_pandas_fn.py +121 -0
  168. scitex/decorators/_preserve_doc.py +19 -0
  169. scitex/decorators/_signal_fn.py +95 -0
  170. scitex/decorators/_timeout.py +55 -0
  171. scitex/decorators/_torch_fn.py +136 -0
  172. scitex/decorators/_wrap.py +39 -0
  173. scitex/decorators/_xarray_fn.py +88 -0
  174. scitex/dev/__init__.py +15 -0
  175. scitex/dev/_analyze_code_flow.py +284 -0
  176. scitex/dev/_reload.py +59 -0
  177. scitex/dict/_DotDict.py +442 -0
  178. scitex/dict/__init__.py +18 -0
  179. scitex/dict/_listed_dict.py +42 -0
  180. scitex/dict/_pop_keys.py +36 -0
  181. scitex/dict/_replace.py +13 -0
  182. scitex/dict/_safe_merge.py +62 -0
  183. scitex/dict/_to_str.py +32 -0
  184. scitex/dsp/__init__.py +72 -0
  185. scitex/dsp/_crop.py +122 -0
  186. scitex/dsp/_demo_sig.py +331 -0
  187. scitex/dsp/_detect_ripples.py +212 -0
  188. scitex/dsp/_ensure_3d.py +18 -0
  189. scitex/dsp/_hilbert.py +78 -0
  190. scitex/dsp/_listen.py +702 -0
  191. scitex/dsp/_misc.py +30 -0
  192. scitex/dsp/_mne.py +32 -0
  193. scitex/dsp/_modulation_index.py +79 -0
  194. scitex/dsp/_pac.py +319 -0
  195. scitex/dsp/_psd.py +102 -0
  196. scitex/dsp/_resample.py +65 -0
  197. scitex/dsp/_time.py +36 -0
  198. scitex/dsp/_transform.py +68 -0
  199. scitex/dsp/_wavelet.py +212 -0
  200. scitex/dsp/add_noise.py +111 -0
  201. scitex/dsp/example.py +253 -0
  202. scitex/dsp/filt.py +155 -0
  203. scitex/dsp/norm.py +18 -0
  204. scitex/dsp/params.py +51 -0
  205. scitex/dsp/reference.py +43 -0
  206. scitex/dsp/template.py +25 -0
  207. scitex/dsp/utils/__init__.py +15 -0
  208. scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
  209. scitex/dsp/utils/_ensure_3d.py +18 -0
  210. scitex/dsp/utils/_ensure_even_len.py +10 -0
  211. scitex/dsp/utils/_zero_pad.py +48 -0
  212. scitex/dsp/utils/filter.py +408 -0
  213. scitex/dsp/utils/pac.py +177 -0
  214. scitex/dt/__init__.py +8 -0
  215. scitex/dt/_linspace.py +130 -0
  216. scitex/etc/__init__.py +15 -0
  217. scitex/etc/wait_key.py +34 -0
  218. scitex/gen/_DimHandler.py +196 -0
  219. scitex/gen/_TimeStamper.py +244 -0
  220. scitex/gen/__init__.py +95 -0
  221. scitex/gen/_alternate_kwarg.py +13 -0
  222. scitex/gen/_cache.py +11 -0
  223. scitex/gen/_check_host.py +34 -0
  224. scitex/gen/_ci.py +12 -0
  225. scitex/gen/_close.py +222 -0
  226. scitex/gen/_embed.py +78 -0
  227. scitex/gen/_inspect_module.py +257 -0
  228. scitex/gen/_is_ipython.py +12 -0
  229. scitex/gen/_less.py +48 -0
  230. scitex/gen/_list_packages.py +139 -0
  231. scitex/gen/_mat2py.py +88 -0
  232. scitex/gen/_norm.py +170 -0
  233. scitex/gen/_paste.py +18 -0
  234. scitex/gen/_print_config.py +84 -0
  235. scitex/gen/_shell.py +48 -0
  236. scitex/gen/_src.py +111 -0
  237. scitex/gen/_start.py +451 -0
  238. scitex/gen/_symlink.py +55 -0
  239. scitex/gen/_symlog.py +27 -0
  240. scitex/gen/_tee.py +238 -0
  241. scitex/gen/_title2path.py +60 -0
  242. scitex/gen/_title_case.py +88 -0
  243. scitex/gen/_to_even.py +84 -0
  244. scitex/gen/_to_odd.py +34 -0
  245. scitex/gen/_to_rank.py +39 -0
  246. scitex/gen/_transpose.py +37 -0
  247. scitex/gen/_type.py +78 -0
  248. scitex/gen/_var_info.py +73 -0
  249. scitex/gen/_wrap.py +17 -0
  250. scitex/gen/_xml2dict.py +76 -0
  251. scitex/gen/misc.py +730 -0
  252. scitex/gen/path.py +0 -0
  253. scitex/general/__init__.py +5 -0
  254. scitex/gists/_SigMacro_processFigure_S.py +128 -0
  255. scitex/gists/_SigMacro_toBlue.py +172 -0
  256. scitex/gists/__init__.py +12 -0
  257. scitex/io/_H5Explorer.py +292 -0
  258. scitex/io/__init__.py +82 -0
  259. scitex/io/_cache.py +101 -0
  260. scitex/io/_flush.py +24 -0
  261. scitex/io/_glob.py +103 -0
  262. scitex/io/_json2md.py +113 -0
  263. scitex/io/_load.py +168 -0
  264. scitex/io/_load_configs.py +146 -0
  265. scitex/io/_load_modules/__init__.py +38 -0
  266. scitex/io/_load_modules/_catboost.py +66 -0
  267. scitex/io/_load_modules/_con.py +20 -0
  268. scitex/io/_load_modules/_db.py +24 -0
  269. scitex/io/_load_modules/_docx.py +42 -0
  270. scitex/io/_load_modules/_eeg.py +110 -0
  271. scitex/io/_load_modules/_hdf5.py +196 -0
  272. scitex/io/_load_modules/_image.py +19 -0
  273. scitex/io/_load_modules/_joblib.py +19 -0
  274. scitex/io/_load_modules/_json.py +18 -0
  275. scitex/io/_load_modules/_markdown.py +103 -0
  276. scitex/io/_load_modules/_matlab.py +37 -0
  277. scitex/io/_load_modules/_numpy.py +39 -0
  278. scitex/io/_load_modules/_optuna.py +155 -0
  279. scitex/io/_load_modules/_pandas.py +69 -0
  280. scitex/io/_load_modules/_pdf.py +31 -0
  281. scitex/io/_load_modules/_pickle.py +24 -0
  282. scitex/io/_load_modules/_torch.py +16 -0
  283. scitex/io/_load_modules/_txt.py +126 -0
  284. scitex/io/_load_modules/_xml.py +49 -0
  285. scitex/io/_load_modules/_yaml.py +23 -0
  286. scitex/io/_mv_to_tmp.py +19 -0
  287. scitex/io/_path.py +286 -0
  288. scitex/io/_reload.py +78 -0
  289. scitex/io/_save.py +539 -0
  290. scitex/io/_save_modules/__init__.py +66 -0
  291. scitex/io/_save_modules/_catboost.py +22 -0
  292. scitex/io/_save_modules/_csv.py +89 -0
  293. scitex/io/_save_modules/_excel.py +49 -0
  294. scitex/io/_save_modules/_hdf5.py +249 -0
  295. scitex/io/_save_modules/_html.py +48 -0
  296. scitex/io/_save_modules/_image.py +140 -0
  297. scitex/io/_save_modules/_joblib.py +25 -0
  298. scitex/io/_save_modules/_json.py +25 -0
  299. scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
  300. scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
  301. scitex/io/_save_modules/_matlab.py +24 -0
  302. scitex/io/_save_modules/_mp4.py +29 -0
  303. scitex/io/_save_modules/_numpy.py +57 -0
  304. scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
  305. scitex/io/_save_modules/_pickle.py +45 -0
  306. scitex/io/_save_modules/_plotly.py +27 -0
  307. scitex/io/_save_modules/_text.py +23 -0
  308. scitex/io/_save_modules/_torch.py +26 -0
  309. scitex/io/_save_modules/_yaml.py +29 -0
  310. scitex/life/__init__.py +10 -0
  311. scitex/life/_monitor_rain.py +49 -0
  312. scitex/linalg/__init__.py +17 -0
  313. scitex/linalg/_distance.py +63 -0
  314. scitex/linalg/_geometric_median.py +64 -0
  315. scitex/linalg/_misc.py +73 -0
  316. scitex/nn/_AxiswiseDropout.py +27 -0
  317. scitex/nn/_BNet.py +126 -0
  318. scitex/nn/_BNet_Res.py +164 -0
  319. scitex/nn/_ChannelGainChanger.py +44 -0
  320. scitex/nn/_DropoutChannels.py +50 -0
  321. scitex/nn/_Filters.py +489 -0
  322. scitex/nn/_FreqGainChanger.py +110 -0
  323. scitex/nn/_GaussianFilter.py +48 -0
  324. scitex/nn/_Hilbert.py +111 -0
  325. scitex/nn/_MNet_1000.py +157 -0
  326. scitex/nn/_ModulationIndex.py +221 -0
  327. scitex/nn/_PAC.py +414 -0
  328. scitex/nn/_PSD.py +40 -0
  329. scitex/nn/_ResNet1D.py +120 -0
  330. scitex/nn/_SpatialAttention.py +25 -0
  331. scitex/nn/_Spectrogram.py +161 -0
  332. scitex/nn/_SwapChannels.py +50 -0
  333. scitex/nn/_TransposeLayer.py +19 -0
  334. scitex/nn/_Wavelet.py +183 -0
  335. scitex/nn/__init__.py +63 -0
  336. scitex/os/__init__.py +8 -0
  337. scitex/os/_mv.py +50 -0
  338. scitex/parallel/__init__.py +8 -0
  339. scitex/parallel/_run.py +151 -0
  340. scitex/path/__init__.py +33 -0
  341. scitex/path/_clean.py +52 -0
  342. scitex/path/_find.py +108 -0
  343. scitex/path/_get_module_path.py +51 -0
  344. scitex/path/_get_spath.py +35 -0
  345. scitex/path/_getsize.py +18 -0
  346. scitex/path/_increment_version.py +87 -0
  347. scitex/path/_mk_spath.py +51 -0
  348. scitex/path/_path.py +19 -0
  349. scitex/path/_split.py +23 -0
  350. scitex/path/_this_path.py +19 -0
  351. scitex/path/_version.py +101 -0
  352. scitex/pd/__init__.py +41 -0
  353. scitex/pd/_find_indi.py +126 -0
  354. scitex/pd/_find_pval.py +113 -0
  355. scitex/pd/_force_df.py +154 -0
  356. scitex/pd/_from_xyz.py +71 -0
  357. scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
  358. scitex/pd/_melt_cols.py +81 -0
  359. scitex/pd/_merge_columns.py +221 -0
  360. scitex/pd/_mv.py +63 -0
  361. scitex/pd/_replace.py +62 -0
  362. scitex/pd/_round.py +93 -0
  363. scitex/pd/_slice.py +63 -0
  364. scitex/pd/_sort.py +91 -0
  365. scitex/pd/_to_numeric.py +53 -0
  366. scitex/pd/_to_xy.py +59 -0
  367. scitex/pd/_to_xyz.py +110 -0
  368. scitex/plt/__init__.py +36 -0
  369. scitex/plt/_subplots/_AxesWrapper.py +182 -0
  370. scitex/plt/_subplots/_AxisWrapper.py +249 -0
  371. scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
  372. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
  373. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
  374. scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
  375. scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
  376. scitex/plt/_subplots/_FigWrapper.py +226 -0
  377. scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
  378. scitex/plt/_subplots/__init__.py +111 -0
  379. scitex/plt/_subplots/_export_as_csv.py +232 -0
  380. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
  381. scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
  382. scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
  383. scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
  384. scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
  385. scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
  386. scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
  387. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
  388. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
  389. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
  390. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
  391. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
  392. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
  393. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
  394. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
  395. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
  396. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
  397. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
  398. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
  399. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
  400. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
  401. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
  402. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
  403. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
  404. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
  405. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
  406. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
  407. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
  408. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
  409. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
  410. scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
  411. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
  412. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
  413. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
  414. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
  415. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
  416. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
  417. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
  418. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
  419. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
  420. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
  421. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
  422. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
  423. scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
  424. scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
  425. scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
  426. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
  427. scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
  428. scitex/plt/_tpl.py +28 -0
  429. scitex/plt/ax/__init__.py +114 -0
  430. scitex/plt/ax/_plot/__init__.py +53 -0
  431. scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
  432. scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
  433. scitex/plt/ax/_plot/_plot_cube.py +57 -0
  434. scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
  435. scitex/plt/ax/_plot/_plot_fillv.py +55 -0
  436. scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
  437. scitex/plt/ax/_plot/_plot_image.py +94 -0
  438. scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
  439. scitex/plt/ax/_plot/_plot_raster.py +172 -0
  440. scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
  441. scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
  442. scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
  443. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
  444. scitex/plt/ax/_plot/_plot_violin.py +343 -0
  445. scitex/plt/ax/_style/__init__.py +38 -0
  446. scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
  447. scitex/plt/ax/_style/_add_panel.py +92 -0
  448. scitex/plt/ax/_style/_extend.py +64 -0
  449. scitex/plt/ax/_style/_force_aspect.py +37 -0
  450. scitex/plt/ax/_style/_format_label.py +23 -0
  451. scitex/plt/ax/_style/_hide_spines.py +84 -0
  452. scitex/plt/ax/_style/_map_ticks.py +182 -0
  453. scitex/plt/ax/_style/_rotate_labels.py +215 -0
  454. scitex/plt/ax/_style/_sci_note.py +279 -0
  455. scitex/plt/ax/_style/_set_log_scale.py +299 -0
  456. scitex/plt/ax/_style/_set_meta.py +261 -0
  457. scitex/plt/ax/_style/_set_n_ticks.py +37 -0
  458. scitex/plt/ax/_style/_set_size.py +16 -0
  459. scitex/plt/ax/_style/_set_supxyt.py +116 -0
  460. scitex/plt/ax/_style/_set_ticks.py +276 -0
  461. scitex/plt/ax/_style/_set_xyt.py +121 -0
  462. scitex/plt/ax/_style/_share_axes.py +264 -0
  463. scitex/plt/ax/_style/_shift.py +139 -0
  464. scitex/plt/ax/_style/_show_spines.py +333 -0
  465. scitex/plt/color/_PARAMS.py +70 -0
  466. scitex/plt/color/__init__.py +52 -0
  467. scitex/plt/color/_add_hue_col.py +41 -0
  468. scitex/plt/color/_colors.py +205 -0
  469. scitex/plt/color/_get_colors_from_cmap.py +134 -0
  470. scitex/plt/color/_interpolate.py +29 -0
  471. scitex/plt/color/_vizualize_colors.py +54 -0
  472. scitex/plt/utils/__init__.py +44 -0
  473. scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
  474. scitex/plt/utils/_calc_nice_ticks.py +101 -0
  475. scitex/plt/utils/_close.py +68 -0
  476. scitex/plt/utils/_colorbar.py +96 -0
  477. scitex/plt/utils/_configure_mpl.py +295 -0
  478. scitex/plt/utils/_histogram_utils.py +132 -0
  479. scitex/plt/utils/_im2grid.py +70 -0
  480. scitex/plt/utils/_is_valid_axis.py +78 -0
  481. scitex/plt/utils/_mk_colorbar.py +65 -0
  482. scitex/plt/utils/_mk_patches.py +26 -0
  483. scitex/plt/utils/_scientific_captions.py +638 -0
  484. scitex/plt/utils/_scitex_config.py +223 -0
  485. scitex/reproduce/__init__.py +14 -0
  486. scitex/reproduce/_fix_seeds.py +45 -0
  487. scitex/reproduce/_gen_ID.py +55 -0
  488. scitex/reproduce/_gen_timestamp.py +35 -0
  489. scitex/res/__init__.py +5 -0
  490. scitex/resource/__init__.py +13 -0
  491. scitex/resource/_get_processor_usages.py +281 -0
  492. scitex/resource/_get_specs.py +280 -0
  493. scitex/resource/_log_processor_usages.py +190 -0
  494. scitex/resource/_utils/__init__.py +31 -0
  495. scitex/resource/_utils/_get_env_info.py +481 -0
  496. scitex/resource/limit_ram.py +33 -0
  497. scitex/scholar/__init__.py +24 -0
  498. scitex/scholar/_local_search.py +454 -0
  499. scitex/scholar/_paper.py +244 -0
  500. scitex/scholar/_pdf_downloader.py +325 -0
  501. scitex/scholar/_search.py +393 -0
  502. scitex/scholar/_vector_search.py +370 -0
  503. scitex/scholar/_web_sources.py +457 -0
  504. scitex/stats/__init__.py +31 -0
  505. scitex/stats/_calc_partial_corr.py +17 -0
  506. scitex/stats/_corr_test_multi.py +94 -0
  507. scitex/stats/_corr_test_wrapper.py +115 -0
  508. scitex/stats/_describe_wrapper.py +90 -0
  509. scitex/stats/_multiple_corrections.py +63 -0
  510. scitex/stats/_nan_stats.py +93 -0
  511. scitex/stats/_p2stars.py +116 -0
  512. scitex/stats/_p2stars_wrapper.py +56 -0
  513. scitex/stats/_statistical_tests.py +73 -0
  514. scitex/stats/desc/__init__.py +40 -0
  515. scitex/stats/desc/_describe.py +189 -0
  516. scitex/stats/desc/_nan.py +289 -0
  517. scitex/stats/desc/_real.py +94 -0
  518. scitex/stats/multiple/__init__.py +14 -0
  519. scitex/stats/multiple/_bonferroni_correction.py +72 -0
  520. scitex/stats/multiple/_fdr_correction.py +400 -0
  521. scitex/stats/multiple/_multicompair.py +28 -0
  522. scitex/stats/tests/__corr_test.py +277 -0
  523. scitex/stats/tests/__corr_test_multi.py +343 -0
  524. scitex/stats/tests/__corr_test_single.py +277 -0
  525. scitex/stats/tests/__init__.py +22 -0
  526. scitex/stats/tests/_brunner_munzel_test.py +192 -0
  527. scitex/stats/tests/_nocorrelation_test.py +28 -0
  528. scitex/stats/tests/_smirnov_grubbs.py +98 -0
  529. scitex/str/__init__.py +113 -0
  530. scitex/str/_clean_path.py +75 -0
  531. scitex/str/_color_text.py +52 -0
  532. scitex/str/_decapitalize.py +58 -0
  533. scitex/str/_factor_out_digits.py +281 -0
  534. scitex/str/_format_plot_text.py +498 -0
  535. scitex/str/_grep.py +48 -0
  536. scitex/str/_latex.py +155 -0
  537. scitex/str/_latex_fallback.py +471 -0
  538. scitex/str/_mask_api.py +39 -0
  539. scitex/str/_mask_api_key.py +8 -0
  540. scitex/str/_parse.py +158 -0
  541. scitex/str/_print_block.py +47 -0
  542. scitex/str/_print_debug.py +68 -0
  543. scitex/str/_printc.py +62 -0
  544. scitex/str/_readable_bytes.py +38 -0
  545. scitex/str/_remove_ansi.py +23 -0
  546. scitex/str/_replace.py +134 -0
  547. scitex/str/_search.py +125 -0
  548. scitex/str/_squeeze_space.py +36 -0
  549. scitex/tex/__init__.py +10 -0
  550. scitex/tex/_preview.py +103 -0
  551. scitex/tex/_to_vec.py +116 -0
  552. scitex/torch/__init__.py +18 -0
  553. scitex/torch/_apply_to.py +34 -0
  554. scitex/torch/_nan_funcs.py +77 -0
  555. scitex/types/_ArrayLike.py +44 -0
  556. scitex/types/_ColorLike.py +21 -0
  557. scitex/types/__init__.py +14 -0
  558. scitex/types/_is_listed_X.py +70 -0
  559. scitex/utils/__init__.py +22 -0
  560. scitex/utils/_compress_hdf5.py +116 -0
  561. scitex/utils/_email.py +120 -0
  562. scitex/utils/_grid.py +148 -0
  563. scitex/utils/_notify.py +247 -0
  564. scitex/utils/_search.py +121 -0
  565. scitex/web/__init__.py +38 -0
  566. scitex/web/_search_pubmed.py +438 -0
  567. scitex/web/_summarize_url.py +158 -0
  568. scitex-2.0.0.dist-info/METADATA +307 -0
  569. scitex-2.0.0.dist-info/RECORD +572 -0
  570. scitex-2.0.0.dist-info/WHEEL +6 -0
  571. scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
  572. scitex-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,291 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2025-05-31 10:30:00"
4
+ # Author: ywatanabe
5
+ # File: ./src/scitex/ai/genai/base_provider.py
6
+
7
+ """
8
+ Abstract base class for AI provider implementations.
9
+
10
+ This module defines the interface that all AI providers must implement,
11
+ ensuring consistency across different providers (OpenAI, Anthropic, etc.).
12
+ """
13
+
14
+ from abc import ABC, abstractmethod
15
+ from dataclasses import dataclass, field
16
+ from enum import Enum
17
+ from typing import Any, Dict, List, Generator, Optional
18
+
19
+
20
class Provider(str, Enum):
    """Identifiers for the AI back-ends this module can target.

    Members compare equal to their plain-string values
    (``Provider.GROQ == "groq"``), and ``str()`` returns that value rather
    than the qualified member name ``"Provider.GROQ"``.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    GROQ = "groq"
    DEEPSEEK = "deepseek"
    LLAMA = "llama"
    PERPLEXITY = "perplexity"
    MOCK = "mock"  # testing-only stand-in provider

    def __str__(self):
        # A str-mixin member *is* a string whose contents are its value, so
        # delegating to str's own __str__ yields e.g. "openai".
        return str.__str__(self)
34
+
35
+
36
class Role(str, Enum):
    """Message roles for chat conversations.

    Like ``Provider``, members compare equal to their raw string values
    (``Role.USER == "user"``), and ``str()`` / f-string formatting yield that
    value, so a member can be placed directly into a message payload.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

    def __str__(self):
        # Without this override str(Role.USER) is "Role.USER" (and on
        # Python 3.11+ so is f"{Role.USER}"), which would leak into any
        # message payload built through string formatting.
        return self.value
42
+
43
+
44
@dataclass
class ProviderConfig:
    """Configuration for AI providers.

    Attributes
    ----------
    provider : str
        Provider identifier, e.g. ``"openai"``.
    model : str
        Model name to request from the provider.
    api_key : Optional[str]
        Provider credential; ``None`` by default. Excluded from ``repr()`` so
        that printing or logging a config does not expose the secret.
    system_prompt : str
        System prompt text; empty by default.
    temperature : float
        Sampling temperature; defaults to 1.0.
    max_tokens : int
        Token limit for a completion; defaults to 4096.
    stream : bool
        Whether a streaming response is requested; defaults to False.
    seed : Optional[int]
        Optional seed; ``None`` by default.
    n_keep : int
        Defaults to 1. NOTE(review): presumably the number of history
        entries retained between calls — confirm against callers.
    """

    provider: str
    model: str
    # repr=False keeps the secret out of repr()/logs; equality comparison
    # and dataclasses.asdict() still include the field.
    api_key: Optional[str] = field(default=None, repr=False)
    system_prompt: str = ""
    temperature: float = 1.0
    max_tokens: int = 4096
    stream: bool = False
    seed: Optional[int] = None
    n_keep: int = 1
57
+
58
+
59
@dataclass
class CompletionResponse:
    """Provider-independent result of a single completion request.

    Attributes
    ----------
    content : str
        The generated text.
    input_tokens : int
        Input (prompt) token count.
    output_tokens : int
        Output (generated) token count.
    finish_reason : str
        Reason reported for ending generation; ``"stop"`` unless set.
    provider_response : Optional[Any]
        The raw provider object, if the caller keeps it; ``None`` otherwise.
    """

    content: str
    input_tokens: int
    output_tokens: int
    finish_reason: str = field(default="stop")
    provider_response: Optional[Any] = field(default=None)
68
+
69
+
70
+ class BaseProvider(ABC):
71
+ """Abstract base class for AI providers.
72
+
73
+ All AI provider implementations must inherit from this class
74
+ and implement the required abstract methods.
75
+
76
+ Example
77
+ -------
78
+ >>> class MyProvider(BaseProvider):
79
+ ... def init_client(self) -> Any:
80
+ ... return MyAPIClient(self.api_key)
81
+ ...
82
+ ... def format_history(self, history: List[Dict]) -> List[Dict]:
83
+ ... # Provider-specific formatting
84
+ ... return history
85
+ ...
86
+ ... def call_static(self, messages: List[Dict], **kwargs) -> Any:
87
+ ... # Make API call
88
+ ... return self.client.complete(messages)
89
+ ...
90
+ ... def call_stream(self, messages: List[Dict], **kwargs) -> Generator:
91
+ ... # Make streaming API call
92
+ ... for chunk in self.client.stream(messages):
93
+ ... yield chunk
94
+ """
95
+
96
+ @abstractmethod
97
+ def init_client(self) -> Any:
98
+ """Initialize the provider-specific client.
99
+
100
+ This method should create and configure the API client
101
+ for the specific provider (e.g., OpenAI client, Anthropic client).
102
+
103
+ Returns
104
+ -------
105
+ Any
106
+ The initialized client object
107
+ """
108
+ pass
109
+
110
+ @abstractmethod
111
+ def format_history(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
112
+ """Format conversation history for the provider's API.
113
+
114
+ Different providers may expect different formats for conversation
115
+ history. This method converts the standard format to the
116
+ provider-specific format.
117
+
118
+ Parameters
119
+ ----------
120
+ history : List[Dict[str, Any]]
121
+ Standard format conversation history
122
+
123
+ Returns
124
+ -------
125
+ List[Dict[str, Any]]
126
+ Provider-specific formatted history
127
+ """
128
+ pass
129
+
130
@abstractmethod
def call_static(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
    """Send one non-streaming request to the provider.

    Parameters
    ----------
    messages : List[Dict[str, Any]]
        Conversation messages, already formatted for the provider.
    **kwargs
        Extra provider-specific request options.

    Returns
    -------
    Any
        The raw response object returned by the provider.
    """
147
+
148
@abstractmethod
def call_stream(
    self, messages: List[Dict[str, Any]], **kwargs
) -> Generator[str, None, None]:
    """Send a streaming request to the provider.

    Parameters
    ----------
    messages : List[Dict[str, Any]]
        Conversation messages, already formatted for the provider.
    **kwargs
        Extra provider-specific request options.

    Yields
    ------
    str
        Successive chunks of response text.
    """
167
+
168
@property
@abstractmethod
def supports_streaming(self) -> bool:
    """Report whether the provider can stream its responses.

    Returns
    -------
    bool
        ``True`` when streaming responses are available.
    """
179
+
180
@property
@abstractmethod
def supports_images(self) -> bool:
    """Report whether the provider accepts image inputs.

    Returns
    -------
    bool
        ``True`` when images may be included in requests.
    """
191
+
192
@property
@abstractmethod
def max_context_length(self) -> int:
    """Report the largest context the provider accepts, in tokens.

    Returns
    -------
    int
        Upper bound on the number of tokens per request context.
    """
203
+
204
def get_capabilities(self) -> Dict[str, Any]:
    """Collect the provider's capability flags into one dictionary.

    The values are read from the ``supports_streaming``,
    ``supports_images`` and ``max_context_length`` attributes of ``self``.

    Returns
    -------
    Dict[str, Any]
        Mapping of capability name to its current value.
    """
    capability_names = (
        "supports_streaming",
        "supports_images",
        "max_context_length",
    )
    return {name: getattr(self, name) for name in capability_names}
217
+
218
def extract_tokens_from_response(self, response: Any) -> Dict[str, int]:
    """Return the token usage recorded in a provider response.

    The base implementation does not inspect ``response`` at all and
    reports zero usage. Providers override it to pull the real counts
    out of their own response objects.

    Parameters
    ----------
    response : Any
        Provider-specific response object.

    Returns
    -------
    Dict[str, int]
        A fresh dict with ``input_tokens`` and ``output_tokens`` keys.
    """
    return dict(input_tokens=0, output_tokens=0)
235
+
236
def handle_rate_limit(self, error: Exception) -> bool:
    """Decide whether a rate-limit error should lead to a retry.

    The base class never retries, so it always answers ``False``.
    Providers that can back off and retry override this hook.

    Parameters
    ----------
    error : Exception
        The error raised by the failed call.

    Returns
    -------
    bool
        ``True`` if the error was handled and the call should be retried.
    """
    should_retry = False
    return should_retry
253
+
254
def validate_model(self, model: str) -> bool:
    """Check whether ``model`` is a model this provider can serve.

    The base class accepts every name. Providers override this to check
    against the models they actually support.

    Parameters
    ----------
    model : str
        Model name to check.

    Returns
    -------
    bool
        ``True`` if the model is supported.
    """
    is_supported = True
    return is_supported
271
+
272
def get_error_message(self, error: Exception) -> str:
    """Extract a user-friendly error message.

    The default implementation returns ``str(error)``. When that string is
    empty, which happens for exceptions raised without arguments such as
    ``TimeoutError()``, the exception's class name is returned instead, so
    the message is never blank. Providers can override for better
    messages.

    Parameters
    ----------
    error : Exception
        The error that occurred.

    Returns
    -------
    str
        Non-empty, user-friendly error message.
    """
    message = str(error)
    # An argument-less exception stringifies to "", which tells the user
    # nothing; its type name is the most informative fallback available.
    return message if message else type(error).__name__
289
+
290
+
291
+ # EOF
@@ -0,0 +1,78 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2024-11-04 01:37:36 (ywatanabe)"
4
+ # File: ./scitex_repo/src/scitex/ai/_gen_ai/_calc_cost.py
5
+
6
+ """
7
+ Functionality:
8
+ - Calculates usage costs for AI model API calls
9
+ - Handles token-based pricing for different models
10
+ Input:
11
+ - Model name
12
+ - Number of input and output tokens used
13
+ Output:
14
+ - Total cost in USD based on token usage
15
+ Prerequisites:
16
+ - MODELS parameter dictionary with pricing information
17
+ - pandas package
18
+ """
19
+
20
+ from typing import Union, Any
21
+ import pandas as pd
22
+
23
+ from .params import MODELS
24
+
25
+
26
def calc_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Return the USD cost of a single API call from its token counts.

    The per-token prices come from the ``MODELS`` pricing table
    (``name``, ``input_cost``, ``output_cost`` columns). The weighted token
    sum is divided by 1,000,000, so the table's prices are treated as
    prices per million tokens. If several rows share the model name, the
    first one is used.

    Example
    -------
    >>> cost = calc_cost("gpt-4", 100, 50)

    Parameters
    ----------
    model : str
        Name of the AI model, matched exactly against the ``name`` column.
    input_tokens : int
        Number of prompt/input tokens.
    output_tokens : int
        Number of completion/output tokens.

    Returns
    -------
    float
        Total cost in USD.

    Raises
    ------
    ValueError
        If ``model`` does not appear in the pricing table.
    """
    pricing = pd.DataFrame(MODELS)
    rows = pricing.loc[pricing["name"] == model, ["input_cost", "output_cost"]]

    if rows.empty:
        raise ValueError(f"Model '{model}' not found in pricing table")

    per_row_cost = (
        rows["input_cost"] * input_tokens + rows["output_cost"] * output_tokens
    ) / 1_000_000
    return per_row_cost.iloc[0]
66
+
67
+
68
+ # def calc_cost(model, input_tokens, output_tokens):
69
+ # indi = MODELS["name"] == model
70
+ # costs = MODELS[["input_cost", "output_cost"]][indi]
71
+ # cost = (
72
+ # input_tokens * costs["input_cost"]
73
+ # + output_tokens * costs["output_cost"]
74
+ # ) / 1_000_000
75
+ # return cost.iloc[0]
76
+
77
+
78
+ # EOF
@@ -0,0 +1,307 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Time-stamp: "2025-05-31 10:10:00"
4
+ # Author: ywatanabe
5
+ # File: ./src/scitex/ai/genai/chat_history.py
6
+
7
+ """
8
+ Manages conversation history for AI providers.
9
+
10
+ This module handles chat history management including:
11
+ - Message storage and retrieval
12
+ - Role alternation enforcement
13
+ - System message handling
14
+ - History truncation
15
+ """
16
+
17
+ from typing import List, Dict, Optional, Any
18
+ from dataclasses import dataclass, field
19
+ from copy import deepcopy
20
+
21
+
22
@dataclass
class Message:
    """Represents a single message in chat history.

    Attributes
    ----------
    role : str
        Message role (system, user, assistant)
    content : str
        Message content
    images : Optional[List[str]]
        Optional base64-encoded images
    """

    role: str
    content: str
    images: Optional[List[str]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert message to dictionary.

        Returns
        -------
        Dict[str, Any]
            ``{"role": ..., "content": ...}``, plus an ``"images"`` entry
            only when the message carries at least one image.
        """
        d = {"role": self.role, "content": self.content}
        if self.images:
            d["images"] = self.images
        return d


class ChatHistory:
    """Manages conversation history with role enforcement.

    Invariant: the history holds at most one system message, and when
    present it is ``messages[0]``. Trimming, clearing and the provider
    formatters all rely on that position.

    Example
    -------
    >>> history = ChatHistory(n_keep=5)
    >>> history.add_message("user", "Hello")
    >>> history.add_message("assistant", "Hi there!")
    >>> messages = history.get_messages()
    >>> print(len(messages))
    2

    Parameters
    ----------
    system_prompt : Optional[str]
        Optional system prompt to prepend
    n_keep : int
        Number of recent exchanges to keep (default: 1); a negative value
        (conventionally -1) keeps everything
    """

    VALID_ROLES = {"system", "user", "assistant"}

    # Gemini ``contents`` only accept the roles "user" and "model".
    _GOOGLE_ROLE_MAP = {"user": "user", "assistant": "model"}

    def __init__(self, system_prompt: Optional[str] = None, n_keep: int = 1):
        """Initialize chat history manager.

        Parameters
        ----------
        system_prompt : Optional[str]
            Optional system prompt
        n_keep : int
            Number of recent exchanges to keep (negative, e.g. -1, keeps all)
        """
        self.system_prompt = system_prompt or ""
        self.n_keep = n_keep
        self.messages: List[Message] = []

        # Add system message if provided
        if system_prompt:
            self.messages.append(Message(role="system", content=system_prompt))

    def _has_system(self) -> bool:
        """Return True when the first stored message is the system message."""
        return bool(self.messages) and self.messages[0].role == "system"

    def add_message(
        self, role: str, content: str, images: Optional[List[str]] = None
    ) -> None:
        """Add a message to the history.

        A ``"system"`` message never goes to the end of the history. It
        replaces the existing system message, or is inserted at index 0 if
        there is none, and ``self.system_prompt`` is updated to match.
        Images are ignored for system messages.

        Parameters
        ----------
        role : str
            Message role ("user", "assistant", "system")
        content : str
            Message content
        images : Optional[List[str]]
            Optional images for multimodal messages

        Raises
        ------
        ValueError
            If role is invalid
        """
        if role not in self.VALID_ROLES:
            raise ValueError(f"Invalid role: {role}. Must be one of {self.VALID_ROLES}")

        if role == "system":
            system_msg = Message(role=role, content=content)
            if self._has_system():
                self.messages[0] = system_msg
            else:
                # Appending would put it at the end, where trimming and the
                # formatters no longer recognise it as the system message.
                self.messages.insert(0, system_msg)
            self.system_prompt = content
            return

        self.messages.append(Message(role=role, content=content, images=images))
        self._trim_history()

    def _trim_history(self) -> None:
        """Trim history to the most recent ``n_keep`` exchanges.

        An exchange counts as two messages. The system message at index 0,
        if any, is always preserved. A negative ``n_keep`` disables
        trimming. The retained window is a raw message count, so it can
        start with an assistant message; ``ensure_valid_sequence`` repairs
        that if needed.
        """
        if self.n_keep < 0:
            return

        start_idx = 1 if self._has_system() else 0
        max_msgs = self.n_keep * 2

        if len(self.messages) - start_idx > max_msgs:
            # Slice with an explicit start index: ``messages[-0:]`` would
            # return the whole list (duplicating the system message) when
            # n_keep == 0.
            tail_start = len(self.messages) - max_msgs
            self.messages = self.messages[:start_idx] + self.messages[tail_start:]

    def format_for_api(self, provider: str) -> List[Dict[str, Any]]:
        """Format messages for specific provider API.

        Parameters
        ----------
        provider : str
            Provider name (openai, anthropic, google), case-insensitive

        Returns
        -------
        List[Dict[str, Any]]
            Formatted messages
        """
        provider = provider.lower()

        if provider == "openai":
            return self._format_for_openai()
        elif provider == "anthropic":
            return self._format_for_anthropic()
        elif provider == "google":
            return self._format_for_google()
        else:
            # Default format
            return [msg.to_dict() for msg in self.messages]

    def _format_for_openai(self) -> List[Dict[str, Any]]:
        """Format messages for OpenAI API (images as base64 JPEG data URLs)."""
        formatted = []

        for msg in self.messages:
            if msg.images:
                # Multimodal message
                content = [{"type": "text", "text": msg.content}]
                for img in msg.images:
                    content.append(
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{img}"},
                        }
                    )
                formatted.append({"role": msg.role, "content": content})
            else:
                formatted.append({"role": msg.role, "content": msg.content})

        return formatted

    def _format_for_anthropic(self) -> List[Dict[str, Any]]:
        """Format messages for Anthropic API (excludes system).

        Images are sent as base64 ``image`` content blocks rather than being
        dropped, mirroring the OpenAI and Google formatters.
        """
        formatted = []

        for msg in self.messages:
            if msg.role == "system":
                continue  # Anthropic handles system separately
            if msg.images:
                content = [{"type": "text", "text": msg.content}]
                for img in msg.images:
                    content.append(
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/jpeg",
                                "data": img,
                            },
                        }
                    )
                formatted.append({"role": msg.role, "content": content})
            else:
                formatted.append({"role": msg.role, "content": msg.content})

        return formatted

    def _format_for_google(self) -> List[Dict[str, Any]]:
        """Format messages for Google API (excludes system).

        The assistant role is renamed to ``"model"``. The system message is
        left out because it is not a valid ``contents`` role; its text
        remains available as ``self.system_prompt``.
        """
        formatted = []

        for msg in self.messages:
            if msg.role == "system":
                continue
            parts = [{"text": msg.content}]
            for img in msg.images or []:
                parts.append({"inline_data": {"mime_type": "image/jpeg", "data": img}})
            role = self._GOOGLE_ROLE_MAP.get(msg.role, msg.role)
            formatted.append({"role": role, "parts": parts})

        return formatted

    def ensure_valid_sequence(self) -> None:
        """Ensure messages follow valid sequence rules.

        - Must start with user message (after system); a placeholder user
          message "Hello" is inserted if an assistant message comes first
        - Must alternate between user and assistant; placeholder "..."
          messages are inserted between consecutive same-role messages
        """
        if not self.messages:
            return

        # Skip system message if present
        start_idx = 1 if self._has_system() else 0

        # Ensure starts with user
        if (
            len(self.messages) > start_idx
            and self.messages[start_idx].role == "assistant"
        ):
            self.messages.insert(start_idx, Message(role="user", content="Hello"))

        # Ensure alternating
        i = start_idx
        while i < len(self.messages) - 1:
            current = self.messages[i]
            next_msg = self.messages[i + 1]

            if current.role == next_msg.role:
                # Insert appropriate message
                if current.role == "user":
                    self.messages.insert(
                        i + 1, Message(role="assistant", content="...")
                    )
                else:
                    self.messages.insert(i + 1, Message(role="user", content="..."))
            i += 1

    def clear(self) -> None:
        """Clear history, keeping only system message if present."""
        if self._has_system():
            self.messages = [self.messages[0]]
        else:
            self.messages = []

    def get_messages(self) -> List[Message]:
        """Get copy of messages.

        Returns
        -------
        List[Message]
            Deep copy of message list; mutating it does not affect history
        """
        return deepcopy(self.messages)

    def __len__(self) -> int:
        """Get number of messages in history (including system)."""
        return len(self.messages)

    def __repr__(self) -> str:
        """String representation of ChatHistory."""
        return f"ChatHistory(messages={len(self.messages)}, n_keep={self.n_keep})"
274
+
275
+
276
+ # Backward compatibility aliases
277
def get_history(self) -> List[Dict[str, Any]]:
    """Legacy accessor: return the history as plain dictionaries.

    Each entry comes from the stored message's ``to_dict()``, in the same
    order as ``self.messages``.
    """
    return [message.to_dict() for message in self.messages]
280
+
281
+
282
def ensure_alternating(self) -> None:
    """Legacy alias that delegates to ``ensure_valid_sequence()``."""
    validate = self.ensure_valid_sequence
    validate()
285
+
286
+
287
def ensure_user_first(self) -> None:
    """Legacy alias that delegates to ``ensure_valid_sequence()``."""
    validate = self.ensure_valid_sequence
    validate()
290
+
291
+
292
def reset(self, system_message: Optional[str] = None) -> None:
    """Reset history (backward compatibility).

    Clears all non-system messages. If ``system_message`` is given, it
    replaces any existing system message instead of being added next to it.
    This avoids the duplicate system messages produced by appending after
    ``clear()``, which keeps the old one. ``self.system_prompt`` is updated
    to the new text.

    Parameters
    ----------
    system_message : Optional[str]
        New system prompt; when falsy, the current system message (if any)
        is kept.
    """
    self.clear()
    if system_message:
        # add_message replaces a system message at index 0 when one exists,
        # otherwise adds it to the (now otherwise empty) history.
        self.add_message("system", system_message)
        self.system_prompt = system_message
298
+
299
+
300
# Attach the legacy helpers defined above to ChatHistory as ordinary
# methods, so existing callers of the backward-compatible API keep working.
for _legacy_name, _legacy_func in (
    ("get_history", get_history),
    ("ensure_alternating", ensure_alternating),
    ("ensure_user_first", ensure_user_first),
    ("reset", reset),
):
    setattr(ChatHistory, _legacy_name, _legacy_func)
del _legacy_name, _legacy_func
305
+
306
+
307
+ # EOF